import os
# --- Robust fix: HF/K8s may set OMP_NUM_THREADS like "7500m" (invalid for libgomp) ---
_raw_omp = os.getenv("OMP_NUM_THREADS", "")
if not _raw_omp.isdigit():
    os.environ["OMP_NUM_THREADS"] = "1"
import html
import threading
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
MODEL_ID = os.getenv("MODEL_ID", "Milkfish033/deepseek-r1-1.5b-merged")
# 🔒 Fixed system prompt (not exposed in the UI).
# English gloss: "You are Bello, a friendly AI assistant. Please answer the
# user's questions in clear, concise Chinese."
SYSTEM_PROMPT = "你是 Bello,一个友好的智能助手。请用清晰、简洁的中文回答用户问题。"
theme = gr.themes.Soft()
css = """
.gradio-container { background: #ffffff !important; }
footer { display: none !important; }
.page-wrap {
  max-width: 980px;
  margin: 0 auto;
  padding: 16px 12px 28px 12px;
}
.chat-card {
  border: 1px solid #e5e7eb;
  border-radius: 16px;
  background: #ffffff;
  box-shadow: 0 1px 10px rgba(0,0,0,0.06);
  padding: 14px;
}
.chat-window {
  border: 1px solid #e5e7eb;
  border-radius: 14px;
  background: #ffffff;
  padding: 14px;
  height: 520px;
  overflow-y: auto;
}
/* Chat bubbles */
.bubble-row { display: flex; margin: 10px 0; }
.bubble-user { justify-content: flex-end; }
.bubble-bot { justify-content: flex-start; }
.bubble {
  max-width: 78%;
  padding: 10px 12px;
  border-radius: 16px;
  line-height: 1.45;
  white-space: pre-wrap;
  word-break: break-word;
  border: 1px solid transparent;
}
.bubble.user {
  background: #eef2ff;
  border-color: #e0e7ff;
}
.bubble.bot {
  background: #f8fafc;
  border-color: #eef2f7;
}
/* Input area */
.input-row { margin-top: 12px; display: flex; gap: 10px; align-items: center; }
.input-row textarea {
  border: 1px solid #d1d5db !important;
  border-radius: 14px !important;
  background: #ffffff !important;
}
.input-row button {
  border-radius: 14px !important;
  padding: 10px 14px !important;
}
"""
# -------------------------
# Load model once
# -------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
    trust_remote_code=True,
)
model.eval()
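# Optional smoke test (illustrative only, not part of the app flow): a short,
# non-streaming generation to confirm the checkpoint loads and decodes.
#   ids = tokenizer("你好", return_tensors="pt").to(model.device)
#   print(tokenizer.decode(model.generate(**ids, max_new_tokens=8)[0],
#                          skip_special_tokens=True))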
def build_prompt(history_msgs, user_msg: str) -> str:
    """
    history_msgs: [{"role": "user"/"assistant", "content": "..."}, ...]
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages.extend(history_msgs)
    messages.append({"role": "user", "content": user_msg})
    if hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    # Fallback for tokenizers without a chat template
    prompt = f"System: {SYSTEM_PROMPT}\n"
    for m in history_msgs:
        if m["role"] == "user":
            prompt += f"User: {m['content']}\n"
        else:
            prompt += f"Assistant: {m['content']}\n"
    prompt += f"User: {user_msg}\nAssistant:"
    return prompt
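# Shape of the fallback path, for reference (the sample turns below are
# hypothetical):
#   build_prompt([{"role": "user", "content": "hi"},
#                 {"role": "assistant", "content": "hello"}], "how are you?")
#   -> "System: <SYSTEM_PROMPT>\nUser: hi\nAssistant: hello\nUser: how are you?\nAssistant:"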
def stream_generate(prompt: str, max_new_tokens: int, temperature: float, top_p: float):
    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,  # ✅ don't echo the prompt back
    )
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=(float(temperature) > 0),
        temperature=float(temperature),
        top_p=float(top_p),
        pad_token_id=tokenizer.eos_token_id,
    )
    # Run generation in a background thread; the streamer yields decoded
    # text pieces as they are produced.
    t = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()
    out = ""
    for piece in streamer:
        out += piece
        yield out.strip()
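# Minimal usage sketch (illustrative; in the app this generator is driven by
# on_bot_stream below):
#   for partial in stream_generate(build_prompt([], "你好"), 64, 0.7, 0.95):
#       print(partial)  # each yield is the full response accumulated so far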
def render_chat(history_msgs):
    """
    Render the history as HTML (stable, and not constrained by the Chatbot
    component's message format).
    """
    rows = []
    for m in history_msgs:
        role = m.get("role")
        content = html.escape(m.get("content", ""))
        if role == "user":
            rows.append(
                f'<div class="bubble-row bubble-user"><div class="bubble user">{content}</div></div>'
            )
        else:
            rows.append(
                f'<div class="bubble-row bubble-bot"><div class="bubble bot">{content}</div></div>'
            )
    if not rows:
        # Empty history: show a greeting ("Hello! I'm here. How can I help you?")
        rows.append(
            '<div class="bubble-row bubble-bot"><div class="bubble bot">你好!我在这儿~有什么能帮到您?</div></div>'
        )
    return f'<div class="chat-window">{"".join(rows)}</div>'
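# For reference, an empty history renders roughly as (whitespace added here
# for readability; the function emits one line):
#   <div class="chat-window">
#     <div class="bubble-row bubble-bot"><div class="bubble bot">你好!我在这儿~有什么能帮到您?</div></div>
#   </div>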
def on_user_submit(user_text, history_msgs):
    history_msgs = history_msgs or []
    user_text = (user_text or "").strip()
    if not user_text:
        return gr.update(value=""), history_msgs, render_chat(history_msgs)
    # Append the user message plus an empty assistant placeholder to stream into
    history_msgs = history_msgs + [{"role": "user", "content": user_text}, {"role": "assistant", "content": ""}]
    return gr.update(value=""), history_msgs, render_chat(history_msgs)
def on_bot_stream(history_msgs, max_tokens, temperature, top_p):
    history_msgs = history_msgs or []
    if len(history_msgs) < 2:
        yield history_msgs, render_chat(history_msgs)
        return
    # The last two entries are the user message and the assistant placeholder
    user_msg = history_msgs[-2]["content"]
    prior = history_msgs[:-2]
    prompt = build_prompt(prior, user_msg)
    gen = stream_generate(prompt, max_tokens, temperature, top_p)
    partial = ""
    for chunk in gen:
        partial = chunk
        history_msgs[-1]["content"] = partial
        yield history_msgs, render_chat(history_msgs)
with gr.Blocks(theme=theme, css=css) as demo:
    with gr.Column(elem_classes=["page-wrap"]):
        # Page title: "I'm Bello, how can I help you?"
        gr.Markdown("# 我是 Bello,有什么能帮到您?")
        with gr.Column(elem_classes=["chat-card"]):
            history_state = gr.State([])  # [{"role", "content"}, ...]
            chat_html = gr.HTML(render_chat([]))
            with gr.Row(elem_classes=["input-row"]):
                user_in = gr.Textbox(
                    placeholder="请输入问题...",
                    show_label=False,
                    lines=2,
                    scale=8,
                )
                send = gr.Button("发送", scale=1)
            with gr.Row():
                max_tokens = gr.Slider(1, 2048, value=512, step=1, label="Max new tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
    # Enter / Click: both append the message, then stream the reply
    user_in.submit(on_user_submit, [user_in, history_state], [user_in, history_state, chat_html], queue=False).then(
        on_bot_stream, [history_state, max_tokens, temperature, top_p], [history_state, chat_html]
    )
    send.click(on_user_submit, [user_in, history_state], [user_in, history_state, chat_html], queue=False).then(
        on_bot_stream, [history_state, max_tokens, temperature, top_p], [history_state, chat_html]
    )
demo.queue(default_concurrency_limit=1)
if __name__ == "__main__":
    # theme/css are gr.Blocks() arguments, not launch() arguments, so they are
    # passed to the constructor above; launch() only disables SSR here.
    demo.launch(ssr_mode=False)
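# To run locally (assumed environment; the standard Gradio workflow):
#   pip install gradio torch transformers
#   python app.py
# On Hugging Face Spaces, a file named app.py is typically launched automatically.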