#!/usr/bin/env python
"""
VerMind all-in-one app (FastAPI + Streamlit).
Adapted for Docker CPU environments (automatically falls back to float32).
"""
import os
import sys
import time
import argparse
import threading
from functools import lru_cache
from typing import List, Dict
# ---- 1. Path setup: make sure the src directory is importable ----
# In Docker the WORKDIR is /app, and src lives at /app/src.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_dir = os.path.join(current_dir, "src")
if os.path.exists(src_dir) and src_dir not in sys.path:
    sys.path.append(src_dir)
# ---- 2. Clear proxy settings (avoid host proxies interfering with in-container requests) ----
for var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"]:
    os.environ.pop(var, None)
# ---- Dependencies ----
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
try:
    import httpx
except ImportError:
    httpx = None  # optional; the UI reports a clear error if it is missing
# ========== Global configuration ==========
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "nev8rz/vermind")
DEFAULT_API_HOST = "0.0.0.0"
DEFAULT_API_PORT = 8000
# ========== Core helpers ==========
def pick_device() -> str:
    """Pick the inference device, honoring an explicit DEVICE override."""
    forced = os.getenv("DEVICE", "").strip().lower()
    if forced in {"cpu", "cuda", "mps"}:
        return forced
    return "cuda" if torch.cuda.is_available() else "cpu"
def pick_dtype(device: str):
    """
    Pick the precision for the given device.
    Warning: in a Docker CPU environment float32 is required; float16 fails with
    '"addmm_impl_cpu_" not implemented for 'Half''.
    """
    s = os.getenv("DTYPE", "").strip().lower()
    if s in {"fp16", "float16"}:
        return torch.float16
    if s in {"bf16", "bfloat16"}:
        return torch.bfloat16
    if s in {"fp32", "float32"}:
        return torch.float32
    if device == "cpu":
        # CPU defaults to float32
        return torch.float32
    return torch.float16
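# For instance, with no DTYPE override, pick_dtype("cpu") returns torch.float32
# (the safe Docker CPU path above) while pick_dtype("cuda") returns torch.float16.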
@lru_cache(maxsize=1)
def load_tokenizer_and_model():
    """Load the tokenizer and model once (lru_cache acts as a singleton)."""
    device = pick_device()
    dtype = pick_dtype(device)
    print(f"[Init] Model: {HF_MODEL_ID} | Device: {device} | Dtype: {dtype}")
    token = os.getenv("HF_TOKEN", None)
    try:
        tok = AutoTokenizer.from_pretrained(
            HF_MODEL_ID,
            trust_remote_code=True,
            token=token,
        )
        if device == "cuda":
            # Let accelerate place the weights across available GPUs.
            model = AutoModelForCausalLM.from_pretrained(
                HF_MODEL_ID,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto",
                token=token,
            )
        else:
            # CPU/MPS: an explicit .to(device) is more reliable than device_map.
            model = AutoModelForCausalLM.from_pretrained(
                HF_MODEL_ID,
                trust_remote_code=True,
                torch_dtype=dtype,
                token=token,
            ).to(device)
        model.eval()
        return tok, model, device
    except Exception as e:
        print(f"[Error] Failed to load model: {e}")
        raise
def generate_reply(messages: List[Dict], max_new_tokens=256, temperature=0.7, top_p=0.9) -> str:
    tok, model, device = load_tokenizer_and_model()
    # Build the prompt: use the model's chat_template if it has one, else fall back.
    try:
        prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Simple concatenation fallback
        prompt = ""
        for m in messages:
            prompt += f"{m['role']}: {m['content']}\n"
        prompt += "assistant: "
    inputs = tok(prompt, return_tensors="pt")
    if device != "cuda":
        # On CUDA with device_map="auto", accelerate handles input placement;
        # on CPU/MPS we move the tensors explicitly.
        inputs = {k: v.to(device) for k, v in inputs.items()}
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature > 0,
        "pad_token_id": tok.eos_token_id,
    }
    if temperature > 0:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    # Decode only the newly generated tokens.
    input_len = inputs["input_ids"].shape[1]
    return tok.decode(outputs[0][input_len:], skip_special_tokens=True).strip()
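# Example direct call (hypothetical; bypasses the HTTP layer, handy for local
# debugging once the model weights are available):
#   print(generate_reply([{"role": "user", "content": "Hello!"}], max_new_tokens=64))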
# ========== FastAPI section ==========
app = FastAPI()

class ChatReq(BaseModel):
    messages: List[dict]
    max_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.9
@app.get("/health")
async def health():
    return {"status": "ok"}
@app.post("/v1/chat/completions")
async def chat(req: ChatReq):
    try:
        text = generate_reply(req.messages, req.max_tokens, req.temperature, req.top_p)
        return {
            "choices": [{"message": {"role": "assistant", "content": text}}],
            "model": HF_MODEL_ID,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
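# Example request against this endpoint (assuming the API listens on port 8000):
#   curl -s http://127.0.0.1:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hi"}], "max_tokens": 64}'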
def run_uvicorn_server(host: str, port: int):
    import uvicorn
    print(f"[API] Starting Uvicorn on {host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="warning")
# ========== Streamlit UI section ==========
def run_streamlit_ui(api_host: str, api_port: int):
    import streamlit as st
    # In-container API address (loopback; api_host is accepted for symmetry,
    # but the UI always talks to the local API).
    internal_api_url = f"http://127.0.0.1:{api_port}/v1/chat/completions"
    health_url = f"http://127.0.0.1:{api_port}/health"
    # --- Start the API thread in the background ---
    if httpx:
        api_up = False
        try:
            if httpx.get(health_url, timeout=0.1).status_code == 200:
                api_up = True
        except Exception:
            pass
        # If the API is not reachable and no startup thread exists yet, launch one.
        target_thread_name = f"uvicorn_{api_port}"
        running_threads = [t.name for t in threading.enumerate()]
        if not api_up and target_thread_name not in running_threads:
            t = threading.Thread(
                target=run_uvicorn_server,
                args=("0.0.0.0", api_port),  # bind 0.0.0.0 so the API is reachable for external debugging
                daemon=True,
                name=target_thread_name,
            )
            t.start()
            # Crude wait for the server to start accepting connections
            time.sleep(2)
    # --- Render the UI ---
    st.set_page_config(page_title="VerMind Chat", layout="centered")
    st.title("🤖 VerMind AI (Docker Edition)")
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # Chat history
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])
    # Input box
    if prompt := st.chat_input("Input your question here..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("Model is thinking... (CPU may be slow)"):
                try:
                    if httpx is None:
                        raise RuntimeError("httpx is not installed; cannot reach the API")
                    payload = {
                        "messages": st.session_state.messages,
                        "temperature": 0.7,
                        "max_tokens": 512,
                    }
                    # Generous timeout: CPU inference can be slow.
                    resp = httpx.post(internal_api_url, json=payload, timeout=300.0)
                    if resp.status_code == 200:
                        reply = resp.json()["choices"][0]["message"]["content"]
                        st.markdown(reply)
                        st.session_state.messages.append({"role": "assistant", "content": reply})
                    else:
                        st.error(f"API Error: {resp.text}")
                except Exception as e:
                    st.error(f"Connection Failed. Ensure API is running. Error: {e}")
# ========== Entry point ==========
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--api", action="store_true", help="Run FastAPI only")
    parser.add_argument("--ui", action="store_true", help="Run Streamlit UI")
    parser.add_argument("--host", default=DEFAULT_API_HOST)
    parser.add_argument("--port", type=int, default=DEFAULT_API_PORT)
    # Swallow unknown arguments injected by Streamlit (its own flags).
    args, unknown = parser.parse_known_args()
    if args.api:
        run_uvicorn_server(args.host, args.port)
    elif args.ui:
        # In Streamlit mode the API reuses the --port argument (default 8000)
        run_streamlit_ui(api_host=args.host, api_port=args.port)
    else:
        # Default fallback (keeps `python app.py` working in Docker)
        run_uvicorn_server(args.host, args.port)
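# Typical invocations (illustrative; the actual container entrypoint may differ):
#   python src/streamlit_app.py --api --port 8000     # FastAPI server only
#   streamlit run src/streamlit_app.py -- --ui        # Streamlit UI + background API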