File size: 8,562 Bytes
099a252 a3d2995 6d38798 a3d2995 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 099a252 6d38798 c721945 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 | #!/usr/bin/env python
"""
VerMind All-in-One App (FastAPI + Streamlit)
适配 Docker CPU 环境 (自动降级 float32)
"""
import os
import sys
import time
import argparse
import threading
from functools import lru_cache
from typing import List, Dict
# ---- 1. Path setup: make the src directory importable ----
# In the Docker image WORKDIR is /app, so the sources live under /app/src.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_dir = os.path.join(current_dir, "src")
if os.path.exists(src_dir) and src_dir not in sys.path:
    sys.path.append(src_dir)
# ---- 2. Strip proxy variables so in-container requests are not routed
# through a proxy inherited from the host environment ----
for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
    os.environ.pop(proxy_var, None)
# ---- Dependency imports ----
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
try:
import httpx
except ImportError:
httpx = None
# ========== Global configuration ==========
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "nev8rz/vermind")  # Hugging Face model repo id (env-overridable)
DEFAULT_API_HOST = "0.0.0.0"  # bind all interfaces — container-friendly default
DEFAULT_API_PORT = 8000
# ========== Core utility functions ==========
def pick_device() -> str:
    """Select the compute device.

    The DEVICE environment variable wins when it names a known backend
    ("cpu", "cuda", or "mps"); otherwise CUDA is used when available,
    with CPU as the fallback.
    """
    override = os.getenv("DEVICE", "").strip().lower()
    if override in {"cpu", "cuda", "mps"}:
        return override
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def pick_dtype(device: str):
    """Pick a torch dtype for *device*.

    The DTYPE environment variable ("fp16"/"float16", "bf16"/"bfloat16",
    "fp32"/"float32") takes precedence when set.  Without an override,
    CPU is forced to float32 — half precision on CPU fails with
    "addmm_impl_cpu_" not implemented for 'Half' in Docker — and every
    other device defaults to float16.
    """
    aliases = {
        "fp16": torch.float16,
        "float16": torch.float16,
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp32": torch.float32,
        "float32": torch.float32,
    }
    requested = os.getenv("DTYPE", "").strip().lower()
    if requested in aliases:
        return aliases[requested]
    return torch.float32 if device == "cpu" else torch.float16
@lru_cache(maxsize=1)
def load_tokenizer_and_model():
    """Load the tokenizer and model once (cached singleton via lru_cache).

    Returns:
        (tokenizer, model, device) — device is the string from pick_device().

    Raises:
        Whatever AutoTokenizer/AutoModelForCausalLM raise on failure; the
        error is logged, then re-raised with its original traceback.
    """
    device = pick_device()
    dtype = pick_dtype(device)
    print(f"[Init] Model: {HF_MODEL_ID} | Device: {device} | Dtype: {dtype}")
    token = os.getenv("HF_TOKEN", None)
    try:
        tok = AutoTokenizer.from_pretrained(
            HF_MODEL_ID,
            trust_remote_code=True,
            token=token
        )
        if device == "cuda":
            # Let accelerate shard/place the weights across available GPUs.
            model = AutoModelForCausalLM.from_pretrained(
                HF_MODEL_ID,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto",
                token=token
            )
        else:
            # CPU/MPS: load normally, then move explicitly — more reliable
            # than device_map on non-CUDA backends.
            model = AutoModelForCausalLM.from_pretrained(
                HF_MODEL_ID,
                trust_remote_code=True,
                torch_dtype=dtype,
                token=token
            ).to(device)
        model.eval()
        return tok, model, device
    except Exception as e:
        print(f"[Error] Failed to load model: {e}")
        # Bare `raise` preserves the original traceback; `raise e`
        # would re-anchor it at this line.
        raise
def generate_reply(messages: List[Dict], max_new_tokens=256, temperature=0.7, top_p=0.9) -> str:
    """Run one chat-completion round over *messages*.

    Args:
        messages: chat history as [{"role": ..., "content": ...}, ...].
        max_new_tokens: generation budget.
        temperature: 0 disables sampling (greedy decode); >0 enables it.
        top_p: nucleus-sampling cutoff, only applied when sampling.

    Returns:
        The newly generated assistant text (prompt stripped, whitespace trimmed).
    """
    tok, model, device = load_tokenizer_and_model()
    # Prefer the model's chat template; fall back to naive "role: content" lines.
    try:
        prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Fix: was a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit.
        prompt = "".join(f"{m['role']}: {m['content']}\n" for m in messages)
        prompt += "assistant: "
    inputs = tok(prompt, return_tensors="pt")
    # Fix: always move input tensors to wherever the model lives.  The
    # original skipped this on CUDA, leaving CPU tensors paired with a
    # device_map="auto" CUDA model ("tensors on different devices" error).
    target = getattr(model, "device", device)
    inputs = {k: v.to(target) for k, v in inputs.items()}
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature > 0,
        "pad_token_id": tok.eos_token_id,
    }
    if temperature > 0:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    # Decode only the continuation, skipping the echoed prompt tokens.
    input_len = inputs["input_ids"].shape[1]
    return tok.decode(outputs[0][input_len:], skip_special_tokens=True).strip()
# ========== FastAPI section ==========
app = FastAPI()

class ChatReq(BaseModel):
    # OpenAI-style chat completion request payload.
    messages: List[dict]      # chat history: [{"role": ..., "content": ...}, ...]
    max_tokens: int = 256     # forwarded to generation as max_new_tokens
    temperature: float = 0.7  # 0 means greedy decoding (sampling disabled)
    top_p: float = 0.9        # nucleus-sampling cutoff, used only when sampling
@app.get("/health")
async def health():
return {"status": "ok"}
@app.post("/v1/chat/completions")
async def chat(req: ChatReq):
try:
text = generate_reply(req.messages, req.max_tokens, req.temperature, req.top_p)
return {
"choices": [{"message": {"role": "assistant", "content": text}}],
"model": HF_MODEL_ID
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
def run_uvicorn_server(host: str, port: int):
    """Serve the FastAPI app with uvicorn (blocks until shutdown)."""
    import uvicorn  # deferred: only needed in API mode

    print(f"[API] Starting Uvicorn on {host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="warning")
# ========== Streamlit UI section ==========
def run_streamlit_ui(api_host: str, api_port: int):
    """Render the Streamlit chat UI, auto-starting the API in-process.

    Args:
        api_host: accepted for signature compatibility; the embedded API
            thread binds 0.0.0.0 and the UI talks to it over loopback.
        api_port: port the embedded FastAPI server listens on.
    """
    import streamlit as st

    internal_api_url = f"http://127.0.0.1:{api_port}/v1/chat/completions"
    health_url = f"http://127.0.0.1:{api_port}/health"

    # Fix: httpx is an optional import at module load (may be None).  The
    # original only guarded the health probe, then crashed later with
    # AttributeError on httpx.post — fail loudly up front instead.
    if httpx is None:
        st.error("httpx is not installed; the chat UI cannot reach the API.")
        return

    # --- Ensure the API is up, spawning a background thread if needed ---
    api_up = False
    try:
        if httpx.get(health_url, timeout=0.1).status_code == 200:
            api_up = True
    except Exception:
        # Fix: was a bare `except:`.  Any connection error here simply
        # means "not up yet".
        pass
    target_thread_name = f"uvicorn_{api_port}"
    already_starting = any(t.name == target_thread_name for t in threading.enumerate())
    if not api_up and not already_starting:
        t = threading.Thread(
            target=run_uvicorn_server,
            args=("0.0.0.0", api_port),  # bind all interfaces for external debugging
            daemon=True,
            name=target_thread_name,
        )
        t.start()
        time.sleep(2)  # crude wait for the server to come up

    # --- Page chrome ---
    st.set_page_config(page_title="VerMind Chat", layout="centered")
    st.title("🤖 VerMind AI (Docker Edition)")
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay chat history
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    # New user turn
    if prompt := st.chat_input("Input your question here..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("Model is thinking... (CPU may be slow)"):
                try:
                    payload = {
                        "messages": st.session_state.messages,
                        "temperature": 0.7,
                        "max_tokens": 512,
                    }
                    # Generous timeout: CPU inference is slow.
                    resp = httpx.post(internal_api_url, json=payload, timeout=300.0)
                    if resp.status_code == 200:
                        reply = resp.json()["choices"][0]["message"]["content"]
                        st.markdown(reply)
                        st.session_state.messages.append({"role": "assistant", "content": reply})
                    else:
                        st.error(f"API Error: {resp.text}")
                except Exception as e:
                    st.error(f"Connection Failed. Ensure API is running. Error: {e}")
# ========== Entry point ==========
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--api", action="store_true", help="Run FastAPI only")
    parser.add_argument("--ui", action="store_true", help="Run Streamlit UI")
    parser.add_argument("--host", default=DEFAULT_API_HOST)
    parser.add_argument("--port", type=int, default=DEFAULT_API_PORT)
    # parse_known_args tolerates extra flags injected by Streamlit itself.
    args, _unknown = parser.parse_known_args()

    # --api wins over --ui when both are given; with no flag at all
    # (plain `python app.py` in Docker) we default to the API server.
    if args.api or not args.ui:
        run_uvicorn_server(args.host, args.port)
    else:
        # UI mode reuses --port for the embedded API (default 8000).
        run_streamlit_ui(api_host=args.host, api_port=args.port)