| | |
"""
VerMind All-in-One App (FastAPI + Streamlit).

Adapted for Docker CPU environments (automatically falls back to float32).
"""
import os
import sys
import time
import argparse
import threading
from functools import lru_cache
from typing import List, Dict
| |
|
| | |
| | |
# Make the local "src" directory importable (if it exists) without
# requiring an installed package.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_dir = os.path.join(current_dir, "src")
if os.path.exists(src_dir) and src_dir not in sys.path:
    sys.path.append(src_dir)
| |
|
| | |
# Drop any inherited proxy settings so in-container HTTP calls
# (e.g. the UI talking to the loopback API) go direct.
for var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"]:
    os.environ.pop(var, None)
| |
|
| | |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# httpx is optional: it is only needed by the Streamlit UI to reach the API.
try:
    import httpx
except ImportError:
    httpx = None
| |
|
| | |
# Hugging Face model repo id; override with the HF_MODEL_ID env var.
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "nev8rz/vermind")
DEFAULT_API_HOST = "0.0.0.0"
DEFAULT_API_PORT = 8000
| |
|
| | |
def pick_device() -> str:
    """Select the compute device, honoring a DEVICE env-var override.

    Returns "cpu", "cuda", or "mps"; falls back to CUDA-if-available,
    otherwise CPU, when no valid override is set.
    """
    override = os.getenv("DEVICE", "").strip().lower()
    if override in ("cpu", "cuda", "mps"):
        return override
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
| |
|
def pick_dtype(device: str):
    """Choose the tensor dtype for *device*, honoring a DTYPE env override.

    Warning: in Docker CPU environments float32 is mandatory — half precision
    fails with "addmm_impl_cpu_" not implemented for 'Half'.
    """
    named_dtypes = {
        "fp16": torch.float16,
        "float16": torch.float16,
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp32": torch.float32,
        "float32": torch.float32,
    }
    override = os.getenv("DTYPE", "").strip().lower()
    if override in named_dtypes:
        return named_dtypes[override]

    # No explicit override: CPU must stay in full precision (see warning above).
    return torch.float32 if device == "cpu" else torch.float16
| |
|
@lru_cache(maxsize=1)
def load_tokenizer_and_model():
    """Load the tokenizer and model exactly once (singleton via lru_cache).

    Returns:
        (tokenizer, model, device) — device is "cpu", "cuda", or "mps".

    Raises:
        Re-raises whatever transformers raises when the download/load fails,
        after logging it.
    """
    device = pick_device()
    dtype = pick_dtype(device)

    print(f"[Init] Model: {HF_MODEL_ID} | Device: {device} | Dtype: {dtype}")

    # Optional auth token for gated/private Hugging Face repos.
    token = os.getenv("HF_TOKEN", None)

    try:
        tok = AutoTokenizer.from_pretrained(
            HF_MODEL_ID,
            trust_remote_code=True,
            token=token
        )

        if device == "cuda":
            # device_map="auto" lets accelerate place/shard the weights.
            model = AutoModelForCausalLM.from_pretrained(
                HF_MODEL_ID,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto",
                token=token
            )
        else:
            # CPU / MPS: load then move explicitly.
            model = AutoModelForCausalLM.from_pretrained(
                HF_MODEL_ID,
                trust_remote_code=True,
                torch_dtype=dtype,
                token=token
            ).to(device)

        model.eval()
        return tok, model, device
    except Exception as e:
        print(f"[Error] Failed to load model: {e}")
        # Bare raise preserves the original traceback (unlike `raise e`).
        raise
| |
|
def generate_reply(messages: List[Dict], max_new_tokens=256, temperature=0.7, top_p=0.9) -> str:
    """Run one chat completion and return only the newly generated text.

    Args:
        messages: chat history as [{"role": ..., "content": ...}, ...].
        max_new_tokens: generation length cap.
        temperature: 0 disables sampling (greedy); >0 enables sampling.
        top_p: nucleus-sampling cutoff (only used when sampling).

    Returns:
        The decoded completion with the prompt stripped.
    """
    tok, model, device = load_tokenizer_and_model()

    # Prefer the model's own chat template; fall back to a plain
    # "role: content" transcript for tokenizers that don't define one.
    # (Was a bare `except:` — narrowed so Ctrl-C/SystemExit still propagate.)
    try:
        prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        prompt = "".join(f"{m['role']}: {m['content']}\n" for m in messages)
        prompt += "assistant: "

    inputs = tok(prompt, return_tensors="pt")
    if device != "cuda":
        # cpu/mps need the tensors moved explicitly; on cuda the
        # device_map="auto" hooks handle input placement.
        inputs = {k: v.to(device) for k, v in inputs.items()}

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature > 0,
        "pad_token_id": tok.eos_token_id
    }
    if temperature > 0:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Decode only the tokens generated after the prompt.
    input_len = inputs["input_ids"].shape[1]
    return tok.decode(outputs[0][input_len:], skip_special_tokens=True).strip()
| |
|
| | |
# FastAPI application exposing the chat API.
app = FastAPI()
| |
|
class ChatReq(BaseModel):
    """Request body for /v1/chat/completions (OpenAI-like schema)."""
    messages: List[dict]  # chat history: [{"role": ..., "content": ...}, ...]
    max_tokens: int = 256  # generation length cap
    temperature: float = 0.7  # 0 means greedy decoding
    top_p: float = 0.9  # nucleus sampling cutoff
| |
|
@app.get("/health")
async def health():
    """Liveness probe; the UI polls this to detect a running API."""
    status_payload = {"status": "ok"}
    return status_payload
| |
|
@app.post("/v1/chat/completions")
def chat(req: ChatReq):
    """OpenAI-style chat completion endpoint.

    Declared as a sync `def` (not `async def`) so FastAPI executes the
    blocking model inference in its threadpool instead of stalling the
    event loop — generate_reply can take many seconds on CPU. The HTTP
    interface is unchanged.

    Raises:
        HTTPException: 500 wrapping any inference failure.
    """
    try:
        text = generate_reply(req.messages, req.max_tokens, req.temperature, req.top_p)
        return {
            "choices": [{"message": {"role": "assistant", "content": text}}],
            "model": HF_MODEL_ID
        }
    except Exception as e:
        # `from e` keeps the original exception chained for server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
| |
|
def run_uvicorn_server(host: str, port: int):
    """Serve the FastAPI app with Uvicorn (blocks until shutdown)."""
    # Imported lazily so UI-only runs don't require uvicorn at import time.
    import uvicorn

    print(f"[API] Starting Uvicorn on {host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="warning")
| |
|
| | |
def run_streamlit_ui(api_host: str, api_port: int):
    """Render the Streamlit chat UI, auto-starting the API thread if needed.

    Args:
        api_host: accepted for CLI symmetry; the UI always reaches the API
            over loopback inside the container.
        api_port: port the backend API listens on (and is probed at).
    """
    import streamlit as st

    internal_api_url = f"http://127.0.0.1:{api_port}/v1/chat/completions"
    health_url = f"http://127.0.0.1:{api_port}/health"

    if httpx:
        # Probe the API quickly; Streamlit reruns this script on every
        # interaction, so the named thread guards against spawning
        # duplicate servers.
        api_up = False
        try:
            if httpx.get(health_url, timeout=0.1).status_code == 200:
                api_up = True
        except Exception:
            # Was a bare `except:`; narrowed so KeyboardInterrupt escapes.
            pass

        target_thread_name = f"uvicorn_{api_port}"
        running_threads = [t.name for t in threading.enumerate()]

        if not api_up and target_thread_name not in running_threads:
            t = threading.Thread(
                target=run_uvicorn_server,
                args=("0.0.0.0", api_port),
                daemon=True,
                name=target_thread_name
            )
            t.start()

            # Give uvicorn a moment to bind before the first request.
            time.sleep(2)

    st.set_page_config(page_title="VerMind Chat", layout="centered")
    st.title("🤖 VerMind AI (Docker Edition)")

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far.
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    if prompt := st.chat_input("Input your question here..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Model is thinking... (CPU may be slow)"):
                try:
                    payload = {
                        "messages": st.session_state.messages,
                        "temperature": 0.7,
                        "max_tokens": 512
                    }

                    # NOTE(review): if httpx failed to import, this raises
                    # AttributeError and is surfaced by the handler below —
                    # consider a dedicated "httpx missing" message.
                    resp = httpx.post(internal_api_url, json=payload, timeout=300.0)
                    if resp.status_code == 200:
                        reply = resp.json()["choices"][0]["message"]["content"]
                        st.markdown(reply)
                        st.session_state.messages.append({"role": "assistant", "content": reply})
                    else:
                        st.error(f"API Error: {resp.text}")
                except Exception as e:
                    st.error(f"Connection Failed. Ensure API is running. Error: {e}")
| |
|
| | |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--api", action="store_true", help="Run FastAPI only")
    parser.add_argument("--ui", action="store_true", help="Run Streamlit UI")
    parser.add_argument("--host", default=DEFAULT_API_HOST)
    parser.add_argument("--port", type=int, default=DEFAULT_API_PORT)

    # parse_known_args tolerates extra flags injected by `streamlit run`.
    args, unknown = parser.parse_known_args()

    # --api wins over --ui; with neither flag, default to serving the API.
    if args.ui and not args.api:
        run_streamlit_ui(api_host=args.host, api_port=args.port)
    else:
        run_uvicorn_server(args.host, args.port)