import os # --- Robust fix: HF/K8s may set OMP_NUM_THREADS like "7500m" (invalid for libgomp) --- _raw_omp = os.getenv("OMP_NUM_THREADS", "") if not _raw_omp.isdigit(): os.environ["OMP_NUM_THREADS"] = "1" import html import threading import gradio as gr import torch from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer MODEL_ID = os.getenv("MODEL_ID", "Milkfish033/deepseek-r1-1.5b-merged") # 🔒 固定 system prompt(不在 UI 暴露) SYSTEM_PROMPT = "你是 Bello,一个友好的智能助手。请用清晰、简洁的中文回答用户问题。" theme = gr.themes.Soft() css = """ .gradio-container { background: #ffffff !important; } footer { display: none !important; } .page-wrap { max-width: 980px; margin: 0 auto; padding: 16px 12px 28px 12px; } .chat-card { border: 1px solid #e5e7eb; border-radius: 16px; background: #ffffff; box-shadow: 0 1px 10px rgba(0,0,0,0.06); padding: 14px; } .chat-window { border: 1px solid #e5e7eb; border-radius: 14px; background: #ffffff; padding: 14px; height: 520px; overflow-y: auto; } /* 气泡 */ .bubble-row { display: flex; margin: 10px 0; } .bubble-user { justify-content: flex-end; } .bubble-bot { justify-content: flex-start; } .bubble { max-width: 78%; padding: 10px 12px; border-radius: 16px; line-height: 1.45; white-space: pre-wrap; word-break: break-word; border: 1px solid transparent; } .bubble.user { background: #eef2ff; border-color: #e0e7ff; } .bubble.bot { background: #f8fafc; border-color: #eef2f7; } /* 输入区 */ .input-row { margin-top: 12px; display: flex; gap: 10px; align-items: center; } .input-row textarea { border: 1px solid #d1d5db !important; border-radius: 14px !important; background: #ffffff !important; } .input-row button { border-radius: 14px !important; padding: 10px 14px !important; } """ # ------------------------- # Load model once # ------------------------- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto" if torch.cuda.is_available() else None, trust_remote_code=True, ) model.eval() def build_prompt(history_msgs, user_msg: str) -> str: """ history_msgs: [{"role":"user"/"assistant", "content": "..."} ...] """ messages = [{"role": "system", "content": SYSTEM_PROMPT}] messages.extend(history_msgs) messages.append({"role": "user", "content": user_msg}) if hasattr(tokenizer, "apply_chat_template"): return tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) # fallback prompt = f"System: {SYSTEM_PROMPT}\n" for m in history_msgs: if m["role"] == "user": prompt += f"User: {m['content']}\n" else: prompt += f"Assistant: {m['content']}\n" prompt += f"User: {user_msg}\nAssistant:" return prompt def stream_generate(prompt: str, max_new_tokens: int, temperature: float, top_p: float): inputs = tokenizer(prompt, return_tensors="pt") if torch.cuda.is_available(): inputs = {k: v.to(model.device) for k, v in inputs.items()} streamer = TextIteratorStreamer( tokenizer, skip_special_tokens=True, skip_prompt=True, # ✅ 不回显 prompt ) gen_kwargs = dict( **inputs, streamer=streamer, max_new_tokens=int(max_new_tokens), do_sample=(float(temperature) > 0), temperature=float(temperature), top_p=float(top_p), pad_token_id=tokenizer.eos_token_id, ) t = threading.Thread(target=model.generate, kwargs=gen_kwargs) t.start() out = "" for piece in streamer: out += piece yield out.strip() def render_chat(history_msgs): """ 把 history 渲染为 HTML(稳定,不受 Chatbot 格式限制) """ rows = [] for m in history_msgs: role = m.get("role") content = html.escape(m.get("content", "")) if role == "user": rows.append( f'