import os

# --- Robust fix: HF/K8s may set OMP_NUM_THREADS to a value like "7500m",
# which libgomp cannot parse. Fall back to a single thread in that case. ---
_raw_omp = os.getenv("OMP_NUM_THREADS", "")
if not _raw_omp.isdigit():
    os.environ["OMP_NUM_THREADS"] = "1"

import html
import threading

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

MODEL_ID = os.getenv("MODEL_ID", "Milkfish033/deepseek-r1-1.5b-merged")

# 🔒 Fixed system prompt (not exposed in the UI)
SYSTEM_PROMPT = "你是 Bello,一个友好的智能助手。请用清晰、简洁的中文回答用户问题。"

theme = gr.themes.Soft()

css = """
.gradio-container { background: #ffffff !important; }
footer { display: none !important; }

.page-wrap { max-width: 980px; margin: 0 auto; padding: 16px 12px 28px 12px; }
.chat-card {
  border: 1px solid #e5e7eb; border-radius: 16px; background: #ffffff;
  box-shadow: 0 1px 10px rgba(0,0,0,0.06); padding: 14px;
}
.chat-window {
  border: 1px solid #e5e7eb; border-radius: 14px; background: #ffffff;
  padding: 14px; height: 520px; overflow-y: auto;
}

/* Chat bubbles */
.bubble-row { display: flex; margin: 10px 0; }
.bubble-user { justify-content: flex-end; }
.bubble-bot { justify-content: flex-start; }
.bubble {
  max-width: 78%; padding: 10px 12px; border-radius: 16px; line-height: 1.45;
  white-space: pre-wrap; word-break: break-word; border: 1px solid transparent;
}
.bubble.user { background: #eef2ff; border-color: #e0e7ff; }
.bubble.bot { background: #f8fafc; border-color: #eef2f7; }

/* Input area */
.input-row { margin-top: 12px; display: flex; gap: 10px; align-items: center; }
.input-row textarea {
  border: 1px solid #d1d5db !important; border-radius: 14px !important;
  background: #ffffff !important;
}
.input-row button { border-radius: 14px !important; padding: 10px 14px !important; }
"""

# -------------------------
# Load model once
# -------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
    trust_remote_code=True,
)
model.eval()


def build_prompt(history_msgs, user_msg: str) -> str:
    """
    history_msgs: [{"role": "user"/"assistant", "content": "..."}, ...]
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages.extend(history_msgs)
    messages.append({"role": "user", "content": user_msg})

    if hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    # Fallback: plain-text prompt if the tokenizer has no chat template
    prompt = f"System: {SYSTEM_PROMPT}\n"
    for m in history_msgs:
        if m["role"] == "user":
            prompt += f"User: {m['content']}\n"
        else:
            prompt += f"Assistant: {m['content']}\n"
    prompt += f"User: {user_msg}\nAssistant:"
    return prompt


def stream_generate(prompt: str, max_new_tokens: int, temperature: float, top_p: float):
    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,  # ✅ do not echo the prompt back
    )

    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=(float(temperature) > 0),
        temperature=float(temperature),
        top_p=float(top_p),
        pad_token_id=tokenizer.eos_token_id,
    )

    # Run generation in a background thread; the streamer yields text pieces
    t = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()

    out = ""
    for piece in streamer:
        out += piece
        yield out.strip()


def render_chat(history_msgs):
    """
    Render the history as HTML (stable, not constrained by the Chatbot component's format).
    """
    rows = []
    for m in history_msgs:
        role = m.get("role")
        content = html.escape(m.get("content", ""))
        if role == "user":
            rows.append(
                f'<div class="bubble-row bubble-user">'
                f'<div class="bubble user">{content}</div>'
                '</div>'
            )
        else:
            rows.append(
                f'<div class="bubble-row bubble-bot">'
                f'<div class="bubble bot">{content}</div>'
                '</div>'
            )
    if not rows:
        rows.append(
            '<div class="bubble-row bubble-bot">'
            '<div class="bubble bot">你好!我在这儿~有什么能帮到您?</div>'
            '</div>'
        )
    return f'<div class="chat-window">{"".join(rows)}</div>'


def on_user_submit(user_text, history_msgs):
    history_msgs = history_msgs or []
    user_text = (user_text or "").strip()
    if not user_text:
        return gr.update(value=""), history_msgs, render_chat(history_msgs)

    # Append the user message plus an empty assistant placeholder
    history_msgs = history_msgs + [
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": ""},
    ]
    return gr.update(value=""), history_msgs, render_chat(history_msgs)


def on_bot_stream(history_msgs, max_tokens, temperature, top_p):
    history_msgs = history_msgs or []
    if len(history_msgs) < 2:
        yield history_msgs, render_chat(history_msgs)
        return

    # The last two entries are the user message and the assistant placeholder
    user_msg = history_msgs[-2]["content"]
    prior = history_msgs[:-2]
    prompt = build_prompt(prior, user_msg)

    gen = stream_generate(prompt, max_tokens, temperature, top_p)
    partial = ""
    for chunk in gen:
        partial = chunk
        history_msgs[-1]["content"] = partial
        yield history_msgs, render_chat(history_msgs)


with gr.Blocks(theme=theme, css=css) as demo:
    with gr.Column(elem_classes=["page-wrap"]):
        gr.Markdown("# 我是 Bello,有什么能帮到您?")

        with gr.Column(elem_classes=["chat-card"]):
            history_state = gr.State([])  # [{"role", "content"}, ...]
            chat_html = gr.HTML(render_chat([]))

            with gr.Row(elem_classes=["input-row"]):
                user_in = gr.Textbox(
                    placeholder="请输入问题...",
                    show_label=False,
                    lines=2,
                    scale=8,
                )
                send = gr.Button("发送", scale=1)

            with gr.Row():
                max_tokens = gr.Slider(1, 2048, value=512, step=1, label="Max new tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")

    # Enter / Click
    user_in.submit(
        on_user_submit, [user_in, history_state], [user_in, history_state, chat_html], queue=False
    ).then(
        on_bot_stream, [history_state, max_tokens, temperature, top_p], [history_state, chat_html]
    )
    send.click(
        on_user_submit, [user_in, history_state], [user_in, history_state, chat_html], queue=False
    ).then(
        on_bot_stream, [history_state, max_tokens, temperature, top_p], [history_state, chat_html]
    )

demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)
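# ---------------------------------------------------------------------------
# Assumed runtime dependencies (not pinned anywhere in this file): gradio,
# torch, transformers, and accelerate (transformers requires accelerate when
# device_map="auto" is used above). A typical local run, assuming this module
# is saved under the usual Spaces name app.py:
#   pip install gradio torch transformers accelerate
#   python app.py
# ---------------------------------------------------------------------------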