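# Minimal Gradio chat UI for a Hugging Face Space: streams replies from a merged
# DeepSeek-R1 1.5B model ("Bello") via transformers' TextIteratorStreamer.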
import os

# --- Robust fix: HF/K8s may set OMP_NUM_THREADS like "7500m" (invalid for libgomp) ---
_raw_omp = os.getenv("OMP_NUM_THREADS", "")
if not _raw_omp.isdigit():
    os.environ["OMP_NUM_THREADS"] = "1"
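# The override must happen before torch is imported below: libgomp parses
# OMP_NUM_THREADS when it loads, and a Kubernetes-style CPU quantity such as
# "7500m" is not a valid thread count.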
import html
import threading

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

MODEL_ID = os.getenv("MODEL_ID", "Milkfish033/deepseek-r1-1.5b-merged")

# 🔒 Fixed system prompt (not exposed in the UI)
# ("You are Bello, a friendly AI assistant. Answer user questions in clear,
#  concise Chinese.")
SYSTEM_PROMPT = "你是 Bello,一个友好的智能助手。请用清晰、简洁的中文回答用户问题。"
theme = gr.themes.Soft()

css = """
.gradio-container { background: #ffffff !important; }
footer { display: none !important; }

.page-wrap {
  max-width: 980px;
  margin: 0 auto;
  padding: 16px 12px 28px 12px;
}

.chat-card {
  border: 1px solid #e5e7eb;
  border-radius: 16px;
  background: #ffffff;
  box-shadow: 0 1px 10px rgba(0,0,0,0.06);
  padding: 14px;
}

.chat-window {
  border: 1px solid #e5e7eb;
  border-radius: 14px;
  background: #ffffff;
  padding: 14px;
  height: 520px;
  overflow-y: auto;
}

/* Chat bubbles */
.bubble-row { display: flex; margin: 10px 0; }
.bubble-user { justify-content: flex-end; }
.bubble-bot { justify-content: flex-start; }

.bubble {
  max-width: 78%;
  padding: 10px 12px;
  border-radius: 16px;
  line-height: 1.45;
  white-space: pre-wrap;
  word-break: break-word;
  border: 1px solid transparent;
}
.bubble.user {
  background: #eef2ff;
  border-color: #e0e7ff;
}
.bubble.bot {
  background: #f8fafc;
  border-color: #eef2f7;
}

/* Input area */
.input-row { margin-top: 12px; display: flex; gap: 10px; align-items: center; }
.input-row textarea {
  border: 1px solid #d1d5db !important;
  border-radius: 14px !important;
  background: #ffffff !important;
}
.input-row button {
  border-radius: 14px !important;
  padding: 10px 14px !important;
}
"""
# -------------------------
# Load model once
# -------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
    trust_remote_code=True,
)
model.eval()
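# Note: device_map="auto" requires the `accelerate` package to be installed;
# model.eval() disables dropout so inference is deterministic.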

def build_prompt(history_msgs, user_msg: str) -> str:
    """
    history_msgs: [{"role": "user"/"assistant", "content": "..."}, ...]
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages.extend(history_msgs)
    messages.append({"role": "user", "content": user_msg})

    if hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    # Fallback for tokenizers that ship no chat template
    prompt = f"System: {SYSTEM_PROMPT}\n"
    for m in history_msgs:
        if m["role"] == "user":
            prompt += f"User: {m['content']}\n"
        else:
            prompt += f"Assistant: {m['content']}\n"
    prompt += f"User: {user_msg}\nAssistant:"
    return prompt
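# Illustrative: the fallback path yields e.g. "System: <SYSTEM_PROMPT>\nUser: 你好\nAssistant:";
# with apply_chat_template the exact string depends on the tokenizer's bundled
# template (DeepSeek-R1 distills typically define their own role tokens).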

def stream_generate(prompt: str, max_new_tokens: int, temperature: float, top_p: float):
    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,  # ✅ do not echo the prompt back
    )

    do_sample = float(temperature) > 0
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        # Only pass sampling knobs when sampling; greedy decoding ignores them
        # and transformers warns if they are set with do_sample=False.
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)

    # Run generation in a background thread; the streamer feeds this generator.
    t = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()

    out = ""
    for piece in streamer:
        out += piece
        yield out.strip()
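# Usage sketch (illustrative): each yield is the full text so far, not a delta.
#   for partial in stream_generate(build_prompt([], "你好"), 64, 0.7, 0.95):
#       print(partial)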

def render_chat(history_msgs):
    """
    Render the history as HTML (stable; avoids the Chatbot component's format
    restrictions).
    """
    rows = []
    for m in history_msgs:
        role = m.get("role")
        content = html.escape(m.get("content", ""))
        if role == "user":
            rows.append(
                f'<div class="bubble-row bubble-user"><div class="bubble user">{content}</div></div>'
            )
        else:
            rows.append(
                f'<div class="bubble-row bubble-bot"><div class="bubble bot">{content}</div></div>'
            )
    if not rows:
        rows.append(
            '<div class="bubble-row bubble-bot"><div class="bubble bot">你好!我在这儿~有什么能帮到您?</div></div>'
        )
    return f'<div class="chat-window">{"".join(rows)}</div>'
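# html.escape above is load-bearing: model output is injected into raw HTML,
# so an unescaped "<" or "&" in a reply would otherwise be parsed as markup.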

def on_user_submit(user_text, history_msgs):
    history_msgs = history_msgs or []
    user_text = (user_text or "").strip()
    if not user_text:
        return gr.update(value=""), history_msgs, render_chat(history_msgs)
    # Append the user message plus an empty assistant placeholder to stream into
    history_msgs = history_msgs + [
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": ""},
    ]
    return gr.update(value=""), history_msgs, render_chat(history_msgs)

def on_bot_stream(history_msgs, max_tokens, temperature, top_p):
    history_msgs = history_msgs or []
    if len(history_msgs) < 2:
        yield history_msgs, render_chat(history_msgs)
        return
    # The last two entries are the user message and the empty assistant placeholder
    user_msg = history_msgs[-2]["content"]
    prior = history_msgs[:-2]
    prompt = build_prompt(prior, user_msg)

    gen = stream_generate(prompt, max_tokens, temperature, top_p)
    partial = ""
    for chunk in gen:
        partial = chunk
        history_msgs[-1]["content"] = partial
        yield history_msgs, render_chat(history_msgs)
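# Two-phase design: on_user_submit appends the message and placeholder right
# away (so the UI updates immediately), then on_bot_stream fills the
# placeholder chunk by chunk as the generator yields.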

with gr.Blocks(theme=theme, css=css) as demo:
    with gr.Column(elem_classes=["page-wrap"]):
        gr.Markdown("# 我是 Bello,有什么能帮到您?")
        with gr.Column(elem_classes=["chat-card"]):
            history_state = gr.State([])  # [{"role": ..., "content": ...}, ...]
            chat_html = gr.HTML(render_chat([]))
            with gr.Row(elem_classes=["input-row"]):
                user_in = gr.Textbox(
                    placeholder="请输入问题...",
                    show_label=False,
                    lines=2,
                    scale=8,
                )
                send = gr.Button("发送", scale=1)
            with gr.Row():
                max_tokens = gr.Slider(1, 2048, value=512, step=1, label="Max new tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")

    # Enter / Click: clear the input and append messages, then stream the reply
    user_in.submit(
        on_user_submit, [user_in, history_state], [user_in, history_state, chat_html], queue=False
    ).then(
        on_bot_stream, [history_state, max_tokens, temperature, top_p], [history_state, chat_html]
    )
    send.click(
        on_user_submit, [user_in, history_state], [user_in, history_state, chat_html], queue=False
    ).then(
        on_bot_stream, [history_state, max_tokens, temperature, top_p], [history_state, chat_html]
    )
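# queue=False lets the submit step bypass the queue so the textbox clears
# instantly; the .then() streaming step still runs through the queue.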
demo.queue(default_concurrency_limit=1)
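# A concurrency limit of 1 serializes requests, so a single GPU (or CPU) never
# runs overlapping model.generate calls.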
if __name__ == "__main__":
    # theme/css belong on gr.Blocks (set above); launch() does not accept them.
    demo.launch(ssr_mode=False)