import os
# --- Robust fix: HF/K8s may set OMP_NUM_THREADS like "7500m" (invalid for libgomp) ---
_raw_omp = os.getenv("OMP_NUM_THREADS", "")
if not _raw_omp.isdigit():
    os.environ["OMP_NUM_THREADS"] = "1"
import html
import threading
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
MODEL_ID = os.getenv("MODEL_ID", "Milkfish033/deepseek-r1-1.5b-merged")
# 🔒 Fixed system prompt (not exposed in the UI):
# "You are Bello, a friendly AI assistant. Answer the user's questions in clear, concise Chinese."
SYSTEM_PROMPT = "你是 Bello,一个友好的智能助手。请用清晰、简洁的中文回答用户问题。"
theme = gr.themes.Soft()
css = """
.gradio-container { background: #ffffff !important; }
footer { display: none !important; }

.page-wrap {
    max-width: 980px;
    margin: 0 auto;
    padding: 16px 12px 28px 12px;
}

.chat-card {
    border: 1px solid #e5e7eb;
    border-radius: 16px;
    background: #ffffff;
    box-shadow: 0 1px 10px rgba(0,0,0,0.06);
    padding: 14px;
}

.chat-window {
    border: 1px solid #e5e7eb;
    border-radius: 14px;
    background: #ffffff;
    padding: 14px;
    height: 520px;
    overflow-y: auto;
}

/* Chat bubbles */
.bubble-row { display: flex; margin: 10px 0; }
.bubble-user { justify-content: flex-end; }
.bubble-bot { justify-content: flex-start; }
.bubble {
    max-width: 78%;
    padding: 10px 12px;
    border-radius: 16px;
    line-height: 1.45;
    white-space: pre-wrap;
    word-break: break-word;
    border: 1px solid transparent;
}
.bubble.user {
    background: #eef2ff;
    border-color: #e0e7ff;
}
.bubble.bot {
    background: #f8fafc;
    border-color: #eef2f7;
}

/* Input area */
.input-row { margin-top: 12px; display: flex; gap: 10px; align-items: center; }
.input-row textarea {
    border: 1px solid #d1d5db !important;
    border-radius: 14px !important;
    background: #ffffff !important;
}
.input-row button {
    border-radius: 14px !important;
    padding: 10px 14px !important;
}
"""
# -------------------------
# Load model once
# -------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
    trust_remote_code=True,
)
model.eval()
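# Note: device_map="auto" requires the `accelerate` package; on a CPU-only Space the
# model is loaded in float32 on the CPU instead.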
def build_prompt(history_msgs, user_msg: str) -> str:
    """
    history_msgs: [{"role": "user"/"assistant", "content": "..."}, ...]
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages.extend(history_msgs)
    messages.append({"role": "user", "content": user_msg})

    if hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    # fallback: plain-text prompt for tokenizers without a chat template
    prompt = f"System: {SYSTEM_PROMPT}\n"
    for m in history_msgs:
        if m["role"] == "user":
            prompt += f"User: {m['content']}\n"
        else:
            prompt += f"Assistant: {m['content']}\n"
    prompt += f"User: {user_msg}\nAssistant:"
    return prompt
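# Illustration (hypothetical values): with an empty history and user_msg == "Hi",
# the fallback branch above yields:
#   "System: <SYSTEM_PROMPT>\nUser: Hi\nAssistant:"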
def stream_generate(prompt: str, max_new_tokens: int, temperature: float, top_p: float):
    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,  # ✅ do not echo the prompt back
    )
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=(float(temperature) > 0),
        temperature=float(temperature),
        top_p=float(top_p),
        pad_token_id=tokenizer.eos_token_id,
    )
    # Run generation in a background thread; the streamer is consumed in this thread
    t = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()

    out = ""
    for piece in streamer:
        out += piece
        yield out.strip()
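# Note: each yield above is the full accumulated text so far (not a delta), so callers
# can simply overwrite their current value with the latest chunk.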
def render_chat(history_msgs):
    """
    Render the message history as HTML (stable, not constrained by the Chatbot component's format).
    """
    rows = []
    for m in history_msgs:
        role = m.get("role")
        content = html.escape(m.get("content", ""))
        if role == "user":
            rows.append(
                f'<div class="bubble-row bubble-user"><div class="bubble user">{content}</div></div>'
            )
        else:
            rows.append(
                f'<div class="bubble-row bubble-bot"><div class="bubble bot">{content}</div></div>'
            )
    if not rows:
        # Greeting shown before any messages: "Hello! I'm here. How can I help you?"
        rows.append(
            '<div class="bubble-row bubble-bot"><div class="bubble bot">你好!我在这儿~有什么能帮到您?</div></div>'
        )
    return f'<div class="chat-window">{"".join(rows)}</div>'
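# Note: html.escape keeps user and model text from injecting raw HTML into the page;
# it also means any Markdown in the model output is shown as plain text.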
def on_user_submit(user_text, history_msgs):
    history_msgs = history_msgs or []
    user_text = (user_text or "").strip()
    if not user_text:
        return gr.update(value=""), history_msgs, render_chat(history_msgs)
    # Append the user message plus an empty assistant placeholder to stream into
    history_msgs = history_msgs + [{"role": "user", "content": user_text}, {"role": "assistant", "content": ""}]
    return gr.update(value=""), history_msgs, render_chat(history_msgs)
def on_bot_stream(history_msgs, max_tokens, temperature, top_p):
    history_msgs = history_msgs or []
    if len(history_msgs) < 2:
        yield history_msgs, render_chat(history_msgs)
        return
    # The last two entries are the latest user message and the empty assistant placeholder
    user_msg = history_msgs[-2]["content"]
    prior = history_msgs[:-2]
    prompt = build_prompt(prior, user_msg)

    gen = stream_generate(prompt, max_tokens, temperature, top_p)
    partial = ""
    for chunk in gen:
        partial = chunk
        history_msgs[-1]["content"] = partial
        yield history_msgs, render_chat(history_msgs)
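# Each yield re-renders the whole chat window, so the assistant bubble appears to
# fill in gradually as generation streams.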
with gr.Blocks(theme=theme, css=css) as demo:
    with gr.Column(elem_classes=["page-wrap"]):
        gr.Markdown("# 我是 Bello,有什么能帮到您?")  # "I'm Bello. How can I help you?"

        with gr.Column(elem_classes=["chat-card"]):
            history_state = gr.State([])  # [{"role": ..., "content": ...}, ...]
            chat_html = gr.HTML(render_chat([]))

            with gr.Row(elem_classes=["input-row"]):
                user_in = gr.Textbox(
                    placeholder="请输入问题...",  # "Enter your question..."
                    show_label=False,
                    lines=2,
                    scale=8,
                )
                send = gr.Button("发送", scale=1)  # "Send"

            with gr.Row():
                max_tokens = gr.Slider(1, 2048, value=512, step=1, label="Max new tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")

    # Enter / Click
    user_in.submit(on_user_submit, [user_in, history_state], [user_in, history_state, chat_html], queue=False).then(
        on_bot_stream, [history_state, max_tokens, temperature, top_p], [history_state, chat_html]
    )
    send.click(on_user_submit, [user_in, history_state], [user_in, history_state, chat_html], queue=False).then(
        on_bot_stream, [history_state, max_tokens, temperature, top_p], [history_state, chat_html]
    )
demo.queue(default_concurrency_limit=1)
if __name__ == "__main__":
    demo.launch(ssr_mode=False)