#!/usr/bin/env python
"""
VerMind All-in-One App (FastAPI + Streamlit)
Adapted for Docker CPU environments (automatically falls back to float32).
"""
import os
import sys
import time
import argparse
import threading
from functools import lru_cache
from typing import List, Dict

# ---- 1. Path setup: make sure the src directory is importable ----
# In Docker the WORKDIR is /app, so src lives at /app/src.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_dir = os.path.join(current_dir, "src")
if os.path.exists(src_dir) and src_dir not in sys.path:
    sys.path.append(src_dir)

# ---- 2. Clear proxy variables (keep in-container requests from being hijacked by a host proxy) ----
for var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"]:
    os.environ.pop(var, None)

# ---- Third-party imports ----
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

try:
    import httpx
except ImportError:
    httpx = None

# ========== Global configuration ==========
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "nev8rz/vermind")
DEFAULT_API_HOST = "0.0.0.0"
DEFAULT_API_PORT = 8000


# ========== Core helpers ==========
def pick_device() -> str:
    """Pick the device automatically, honoring a DEVICE override."""
    forced = os.getenv("DEVICE", "").strip().lower()
    if forced in {"cpu", "cuda", "mps"}:
        return forced
    return "cuda" if torch.cuda.is_available() else "cpu"


def pick_dtype(device: str) -> torch.dtype:
    """
    Pick the precision for the given device.

    Warning: in a Docker CPU environment float32 is mandatory; float16
    fails with `"addmm_impl_cpu_" not implemented for 'Half'`.
    """
    s = os.getenv("DTYPE", "").strip().lower()
    if s in {"fp16", "float16"}:
        return torch.float16
    if s in {"bf16", "bfloat16"}:
        return torch.bfloat16
    if s in {"fp32", "float32"}:
        return torch.float32
    if device == "cpu":
        # Force float32 by default on CPU.
        return torch.float32
    return torch.float16


@lru_cache(maxsize=1)
def load_tokenizer_and_model():
    """Load the tokenizer and model once (lru_cache makes this a singleton)."""
    device = pick_device()
    dtype = pick_dtype(device)
    print(f"[Init] Model: {HF_MODEL_ID} | Device: {device} | Dtype: {dtype}")
    token = os.getenv("HF_TOKEN", None)
    try:
        tok = AutoTokenizer.from_pretrained(
            HF_MODEL_ID, trust_remote_code=True, token=token
        )
        if device == "cuda":
            model = AutoModelForCausalLM.from_pretrained(
                HF_MODEL_ID,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto",
                token=token,
            )
        else:
            # On CPU/MPS an explicit .to(device) is the more reliable path.
            model = AutoModelForCausalLM.from_pretrained(
                HF_MODEL_ID,
                trust_remote_code=True,
                torch_dtype=dtype,
                token=token,
            ).to(device)
        model.eval()
        return tok, model, device
    except Exception as e:
        print(f"[Error] Failed to load model: {e}")
        raise


def generate_reply(messages: List[Dict], max_new_tokens=256, temperature=0.7, top_p=0.9) -> str:
    tok, model, device = load_tokenizer_and_model()

    # Build the prompt: use the model's chat_template if it has one,
    # otherwise fall back to naive role-prefix concatenation.
    try:
        prompt = tok.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        prompt = ""
        for m in messages:
            prompt += f"{m['role']}: {m['content']}\n"
        prompt += "assistant: "

    # Move inputs to the model's device (covers CPU, MPS, and device_map="auto").
    inputs = tok(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature > 0,
        "pad_token_id": tok.eos_token_id,
    }
    if temperature > 0:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens.
    input_len = inputs["input_ids"].shape[1]
    return tok.decode(outputs[0][input_len:], skip_special_tokens=True).strip()


# ========== FastAPI section ==========
app = FastAPI()


class ChatReq(BaseModel):
    messages: List[dict]
    max_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.9


@app.get("/health")
async def health():
    return {"status": "ok"}
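
# The endpoint below mirrors the OpenAI chat-completions request shape that
# ChatReq defines above. A minimal request body looks like this (roles
# "system"/"user"/"assistant" are the usual choices; only "messages" is required):
#   {"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64}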
@app.post("/v1/chat/completions")
async def chat(req: ChatReq):
    try:
        text = generate_reply(req.messages, req.max_tokens, req.temperature, req.top_p)
        return {
            "choices": [{"message": {"role": "assistant", "content": text}}],
            "model": HF_MODEL_ID,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def run_uvicorn_server(host: str, port: int):
    import uvicorn
    print(f"[API] Starting Uvicorn on {host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="warning")


# ========== Streamlit UI section ==========
def run_streamlit_ui(api_host: str, api_port: int):
    import streamlit as st

    # In-container API address (loopback).
    internal_api_url = f"http://127.0.0.1:{api_port}/v1/chat/completions"
    health_url = f"http://127.0.0.1:{api_port}/health"

    # The UI talks to the API over HTTP, so httpx is a hard requirement here.
    if httpx is None:
        st.error("httpx is not installed; the UI cannot reach the API. Run `pip install httpx`.")
        st.stop()

    # --- Start the API thread in the background ---
    api_up = False
    try:
        if httpx.get(health_url, timeout=0.1).status_code == 200:
            api_up = True
    except Exception:
        pass

    # If the API is not reachable and no server thread is running yet, start one.
    target_thread_name = f"uvicorn_{api_port}"
    running_threads = [t.name for t in threading.enumerate()]
    if not api_up and target_thread_name not in running_threads:
        t = threading.Thread(
            target=run_uvicorn_server,
            args=("0.0.0.0", api_port),  # bind 0.0.0.0 so the API is reachable for external debugging
            daemon=True,
            name=target_thread_name,
        )
        t.start()
        # Give the server a moment to come up.
        time.sleep(2)

    # --- Render the UI ---
    st.set_page_config(page_title="VerMind Chat", layout="centered")
    st.title("🤖 VerMind AI (Docker Edition)")

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay chat history.
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    # Input box.
    if prompt := st.chat_input("Input your question here..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("Model is thinking... (CPU may be slow)"):
                try:
                    payload = {
                        "messages": st.session_state.messages,
                        "temperature": 0.7,
                        "max_tokens": 512,
                    }
                    # Generous timeout: CPU inference is slow.
                    resp = httpx.post(internal_api_url, json=payload, timeout=300.0)
                    if resp.status_code == 200:
                        reply = resp.json()["choices"][0]["message"]["content"]
                        st.markdown(reply)
                        st.session_state.messages.append(
                            {"role": "assistant", "content": reply}
                        )
                    else:
                        st.error(f"API Error: {resp.text}")
                except Exception as e:
                    st.error(f"Connection Failed. Ensure API is running. Error: {e}")


# ========== Entry point ==========
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--api", action="store_true", help="Run FastAPI only")
    parser.add_argument("--ui", action="store_true", help="Run Streamlit UI")
    parser.add_argument("--host", default=DEFAULT_API_HOST)
    parser.add_argument("--port", type=int, default=DEFAULT_API_PORT)
    # Tolerate unknown arguments injected by Streamlit (its own flags).
    args, unknown = parser.parse_known_args()

    if args.api:
        run_uvicorn_server(args.host, args.port)
    elif args.ui:
        # In Streamlit mode the API reuses the --port argument (default 8000).
        run_streamlit_ui(api_host=args.host, api_port=args.port)
    else:
        # Default fallback (so `python app.py` in Docker still serves the API).
        run_uvicorn_server(args.host, args.port)
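
# ---- Example invocations (a sketch; ports and defaults are the ones defined above) ----
# API only, inside the container:
#   python app.py --api --port 8000
# Streamlit UI (arguments after `--` are forwarded by Streamlit to this script):
#   streamlit run app.py --server.port 8501 -- --ui --port 8000
# Smoke-test the API from the host once the container port is published:
#   curl -X POST http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hi"}], "max_tokens": 64}'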