File size: 8,562 Bytes
099a252
a3d2995
6d38798
 
a3d2995
099a252
 
 
 
 
 
6d38798
 
 
 
 
 
 
 
099a252
6d38798
099a252
 
 
6d38798
099a252
 
 
 
 
 
 
6d38798
099a252
 
6d38798
 
099a252
 
 
6d38798
099a252
6d38798
099a252
 
 
 
 
6d38798
 
 
 
 
 
 
 
 
099a252
6d38798
 
099a252
 
 
 
 
6d38798
099a252
6d38798
 
 
 
 
 
 
 
099a252
 
6d38798
099a252
6d38798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
099a252
6d38798
 
 
 
099a252
6d38798
 
 
 
099a252
6d38798
 
099a252
 
 
 
 
6d38798
 
 
 
 
 
 
 
099a252
 
 
6d38798
 
099a252
6d38798
099a252
6d38798
099a252
 
6d38798
 
 
 
 
099a252
 
 
6d38798
099a252
 
6d38798
099a252
6d38798
099a252
6d38798
 
099a252
 
 
 
6d38798
099a252
6d38798
 
099a252
6d38798
 
099a252
6d38798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c721945
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
#!/usr/bin/env python
"""
VerMind All-in-One App (FastAPI + Streamlit)

Adapted for Docker CPU environments (automatically falls back to float32).
"""
import os
import sys
import time
import argparse
import threading
from functools import lru_cache
from typing import List, Dict

# ---- 1. Path setup: make sure the src directory is importable ----
# In Docker the WORKDIR is /app and the sources live in /app/src.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_dir = os.path.join(current_dir, "src")
if os.path.exists(src_dir) and src_dir not in sys.path:
    sys.path.append(src_dir)

# ---- 2. Drop proxy env vars (prevents in-container requests from being
#         routed through a host-machine proxy) ----
for var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"]:
    os.environ.pop(var, None)

# ---- Third-party dependencies ----
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# httpx is optional: it is only used by the Streamlit UI to call the local API.
try:
    import httpx
except ImportError:
    httpx = None

# ========== Global configuration ==========
# Hugging Face Hub model id; overridable via the HF_MODEL_ID env var.
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "nev8rz/vermind")
DEFAULT_API_HOST = "0.0.0.0"
DEFAULT_API_PORT = 8000
# ========== 核心工具函数 ==========
def pick_device() -> str:
    """Select the torch device to run inference on.

    Honors the DEVICE env var when it is one of "cpu"/"cuda"/"mps";
    otherwise auto-detects, preferring CUDA, then Apple MPS, then CPU.

    Returns:
        Device string suitable for ``tensor.to(...)``.
    """
    forced = os.getenv("DEVICE", "").strip().lower()
    if forced in {"cpu", "cuda", "mps"}:
        return forced
    if torch.cuda.is_available():
        return "cuda"
    # Fix: the original accepted a forced "mps" but auto-detection could
    # never pick it. Guard with getattr for older torch builds without
    # the mps backend attribute.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"

def pick_dtype(device: str):
    """Resolve the torch dtype to load the model with.

    The DTYPE env var ("fp16"/"float16", "bf16"/"bfloat16", "fp32"/"float32")
    takes precedence when set. Without it, CPU is forced to float32 — half
    precision matmuls are not implemented on CPU ("addmm_impl_cpu_" not
    implemented for 'Half') — and accelerators default to float16.
    """
    requested = os.getenv("DTYPE", "").strip().lower()
    aliases = {
        "fp16": torch.float16, "float16": torch.float16,
        "bf16": torch.bfloat16, "bfloat16": torch.bfloat16,
        "fp32": torch.float32, "float32": torch.float32,
    }
    if requested in aliases:
        return aliases[requested]
    # No explicit request: CPU must run float32, everything else float16.
    return torch.float32 if device == "cpu" else torch.float16

@lru_cache(maxsize=1)
def load_tokenizer_and_model():
    """Load the tokenizer and model once (memoized singleton).

    Returns:
        Tuple ``(tokenizer, model, device)``; the model is in eval mode
        and already placed on the selected device.

    Raises:
        Whatever the Hugging Face loaders raise (re-raised after logging).
    """
    device = pick_device()
    dtype = pick_dtype(device)

    print(f"[Init] Model: {HF_MODEL_ID} | Device: {device} | Dtype: {dtype}")

    # Optional auth token for private Hub repositories.
    token = os.getenv("HF_TOKEN", None)

    try:
        tok = AutoTokenizer.from_pretrained(
            HF_MODEL_ID,
            trust_remote_code=True,
            token=token
        )

        if device == "cuda":
            # Let accelerate place/shard the weights across GPUs.
            model = AutoModelForCausalLM.from_pretrained(
                HF_MODEL_ID,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto",
                token=token
            )
        else:
            # CPU/MPS: load normally, then move explicitly — more reliable
            # than device_map on non-CUDA backends.
            model = AutoModelForCausalLM.from_pretrained(
                HF_MODEL_ID,
                trust_remote_code=True,
                torch_dtype=dtype,
                token=token
            ).to(device)

        model.eval()
        return tok, model, device
    except Exception as e:
        print(f"[Error] Failed to load model: {e}")
        # Fix: bare `raise` preserves the original traceback (was `raise e`,
        # which restarts the traceback at this line).
        raise

def generate_reply(messages: List[Dict], max_new_tokens=256, temperature=0.7, top_p=0.9) -> str:
    """Run one chat completion and return the assistant's reply text.

    Args:
        messages: OpenAI-style history: ``[{"role": ..., "content": ...}, ...]``.
        max_new_tokens: generation budget.
        temperature: ``0`` disables sampling (greedy decoding); ``>0`` samples.
        top_p: nucleus-sampling cutoff (only applied when sampling).

    Returns:
        The newly generated text, stripped of the echoed prompt and
        special tokens.
    """
    tok, model, device = load_tokenizer_and_model()

    # Prefer the model's own chat template; fall back to a naive transcript.
    try:
        prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Fix: was a bare `except:` which would also swallow
        # KeyboardInterrupt/SystemExit.
        prompt = ""
        for m in messages:
            prompt += f"{m['role']}: {m['content']}\n"
        prompt += "assistant: "

    inputs = tok(prompt, return_tensors="pt")
    if device != "cuda":
        inputs = {k: v.to(device) for k, v in inputs.items()}
    # NOTE(review): on CUDA the inputs stay on CPU; placement appears to be
    # delegated to device_map="auto" — confirm on a GPU machine.

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature > 0,
        "pad_token_id": tok.eos_token_id
    }
    if temperature > 0:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tail, not the echoed prompt.
    input_len = inputs["input_ids"].shape[1]
    return tok.decode(outputs[0][input_len:], skip_special_tokens=True).strip()

# ========== FastAPI section ==========
app = FastAPI()

class ChatReq(BaseModel):
    """Request body for /v1/chat/completions (OpenAI-compatible subset)."""
    messages: List[dict]      # chat history: [{"role": ..., "content": ...}]
    max_tokens: int = 256     # forwarded to generate_reply as max_new_tokens
    temperature: float = 0.7  # 0 => greedy decoding
    top_p: float = 0.9        # nucleus-sampling cutoff

@app.get("/health")
async def health():
    """Liveness probe; the Streamlit UI polls this before starting the API thread."""
    return {"status": "ok"}

@app.post("/v1/chat/completions")
async def chat(req: ChatReq):
    """OpenAI-style chat-completion endpoint (non-streaming, single choice)."""
    try:
        text = generate_reply(req.messages, req.max_tokens, req.temperature, req.top_p)
        return {
            "choices": [{"message": {"role": "assistant", "content": text}}],
            "model": HF_MODEL_ID
        }
    except Exception as e:
        # Any load/inference failure surfaces as an HTTP 500 with the error text.
        raise HTTPException(status_code=500, detail=str(e))

def run_uvicorn_server(host: str, port: int):
    """Serve the FastAPI ``app`` with uvicorn (blocking call)."""
    # Local import: uvicorn is only needed when running in API mode.
    import uvicorn

    banner = f"[API] Starting Uvicorn on {host}:{port}"
    print(banner)
    uvicorn.run(app, host=host, port=port, log_level="warning")

# ========== Streamlit UI section ==========
def run_streamlit_ui(api_host: str, api_port: int):
    """Render the Streamlit chat UI, auto-starting the API in-process if needed.

    Args:
        api_host: accepted for interface compatibility; the embedded API is
            always reached via loopback.
        api_port: port the FastAPI backend listens on / is probed on.
    """
    import streamlit as st

    # In-container API endpoints (loopback).
    internal_api_url = f"http://127.0.0.1:{api_port}/v1/chat/completions"
    health_url = f"http://127.0.0.1:{api_port}/health"

    # --- Start the API thread in the background if it is not already up ---
    if httpx:
        api_up = False
        try:
            if httpx.get(health_url, timeout=0.1).status_code == 200:
                api_up = True
        except Exception:
            # Fix: was a bare `except:`; only probe failures should be ignored,
            # not KeyboardInterrupt/SystemExit.
            pass

        # Start uvicorn only when the API is down and no starter thread exists.
        target_thread_name = f"uvicorn_{api_port}"
        running_threads = [t.name for t in threading.enumerate()]

        if not api_up and target_thread_name not in running_threads:
            t = threading.Thread(
                target=run_uvicorn_server,
                args=("0.0.0.0", api_port), # bind 0.0.0.0 so it stays reachable for external debugging
                daemon=True,
                name=target_thread_name
            )
            t.start()
            # Crude readiness wait; the first request may still race the server.
            time.sleep(2)

    # --- UI rendering ---
    st.set_page_config(page_title="VerMind Chat", layout="centered")
    st.title("🤖 VerMind AI (Docker Edition)")

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay chat history.
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    # Input box.
    if prompt := st.chat_input("Input your question here..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Model is thinking... (CPU may be slow)"):
                if httpx is None:
                    # Fix: the original crashed with AttributeError on
                    # `httpx.post` when httpx was missing; fail clearly instead.
                    st.error("Connection Failed. Ensure API is running. Error: httpx is not installed")
                    return
                try:
                    payload = {
                        "messages": st.session_state.messages,
                        "temperature": 0.7,
                        "max_tokens": 512
                    }
                    # Generous timeout: CPU inference is slow.
                    resp = httpx.post(internal_api_url, json=payload, timeout=300.0)
                    if resp.status_code == 200:
                        reply = resp.json()["choices"][0]["message"]["content"]
                        st.markdown(reply)
                        st.session_state.messages.append({"role": "assistant", "content": reply})
                    else:
                        st.error(f"API Error: {resp.text}")
                except Exception as e:
                    st.error(f"Connection Failed. Ensure API is running. Error: {e}")

# ========== Entry point ==========
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--api", action="store_true", help="Run FastAPI only")
    cli.add_argument("--ui", action="store_true", help="Run Streamlit UI")
    cli.add_argument("--host", default=DEFAULT_API_HOST)
    cli.add_argument("--port", type=int, default=DEFAULT_API_PORT)

    # parse_known_args: tolerate extra flags injected by `streamlit run`.
    opts, _extra = cli.parse_known_args()

    if opts.api:
        run_uvicorn_server(opts.host, opts.port)
    elif opts.ui:
        # UI mode reuses --port for the embedded API (default 8000).
        run_streamlit_ui(api_host=opts.host, api_port=opts.port)
    else:
        # Default: behave like --api (supports `python app.py` in Docker).
        run_uvicorn_server(opts.host, opts.port)