Milkfish033 committed on
Commit fd44fb7 · verified · 1 Parent(s): 2cffd9a

Update app.py

Files changed (1)
  1. app.py +144 -44
app.py CHANGED
@@ -1,19 +1,21 @@
 import os
 
-# ---- Robust fix for OMP_NUM_THREADS (HF / K8s may set it to '7500m') ----
+# --- Robust fix: HF/K8s may set OMP_NUM_THREADS like "7500m" (invalid for libgomp) ---
 _raw_omp = os.getenv("OMP_NUM_THREADS", "")
 if not _raw_omp.isdigit():
     os.environ["OMP_NUM_THREADS"] = "1"
 
-
 import threading
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 
+# -------------------------
+# Config
+# -------------------------
 MODEL_ID = os.getenv("MODEL_ID", "Milkfish033/deepseek-r1-1.5b-merged")
 
-# 🔒 Fixed system prompt (exposed in the UI)
+# 🔒 Fixed system prompt (not exposed in the UI)
 SYSTEM_PROMPT = "你是 Bello,一个友好的智能助手。请用清晰、简洁的中文回答用户问题。"
 
 theme = gr.themes.Soft()
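
Note on the first hunk: Kubernetes-style CPU quantities such as "7500m" mean 7.5 cores, but libgomp only accepts plain integers, so the app falls back to a single thread whenever the value is not all digits. A slightly more faithful alternative (a sketch with a hypothetical sanitize_omp_threads helper, not part of this commit) would convert millicores into a usable thread count instead of discarding them:

import math
import os

def sanitize_omp_threads(raw: str, fallback: int = 1) -> int:
    """Coerce a K8s CPU quantity (e.g. '7500m' = 7.5 cores) into a valid thread count."""
    if raw.isdigit():
        return int(raw)
    if raw.endswith("m") and raw[:-1].isdigit():
        # millicores -> whole cores, rounded up
        return max(fallback, math.ceil(int(raw[:-1]) / 1000))
    return fallback

os.environ["OMP_NUM_THREADS"] = str(sanitize_omp_threads(os.getenv("OMP_NUM_THREADS", "")))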
@@ -35,7 +37,7 @@ footer { display: none !important; }
     padding: 12px;
 }
 
-/* Input box border */
+/* Make the input box border more visible */
 .chat-card textarea,
 .chat-card input {
     border: 1px solid #d1d5db !important;
@@ -46,25 +48,38 @@ footer { display: none !important; }
 /* Rounded send button */
 .chat-card button { border-radius: 14px !important; }
 
-/* Bubble styles (class names differ across Gradio versions; write several to improve the hit rate) */
+/* Bubble styles (class names differ across Gradio versions; add multiple selectors to hit them) */
 .chat-card .message.user,
-.chat-card .bubble.user {
+.chat-card .bubble.user,
+.chat-card [data-testid="chatbot"] .message.user,
+.chat-card [data-testid="chatbot"] .bubble.user {
     background: #eef2ff !important;
     border: 1px solid #e0e7ff !important;
     border-radius: 16px !important;
 }
 
-.chat-card .message.bot,
 .chat-card .message.assistant,
+.chat-card .message.bot,
+.chat-card .bubble.assistant,
 .chat-card .bubble.bot,
-.chat-card .bubble.assistant {
+.chat-card [data-testid="chatbot"] .message.assistant,
+.chat-card [data-testid="chatbot"] .bubble.assistant {
     background: #f8fafc !important;
     border: 1px solid #eef2f7 !important;
     border-radius: 16px !important;
 }
+
+/* Per-message spacing */
+.chat-card .message,
+.chat-card .bubble {
+    padding: 10px 12px !important;
+    margin: 8px 0 !important;
+}
 """
 
-# ---- Load model once ----
+# -------------------------
+# Load model once
+# -------------------------
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
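
The hunk context stops mid-call, so the remaining from_pretrained arguments (unchanged by this commit) are not shown. Purely to illustrate the load-once pattern this section relies on, a typical call shape looks like the following; the argument list is an assumption, not a quote of the elided lines:

import torch
from transformers import AutoModelForCausalLM

# Hypothetical completion; the real file's remaining kwargs are not visible in this diff.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True,
)
model.eval()  # loaded once at import, so every request shares the same weights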
@@ -74,20 +89,60 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
-
-def _build_prompt(history_pairs, user_msg: str) -> str:
+# -------------------------
+# History adapter (CRITICAL FIX)
+# -------------------------
+def history_to_messages(history) -> list[dict]:
     """
-    The legacy ChatInterface passes history as [(user, bot), ...];
-    we convert it to messages, then build the prompt with the chat template.
+    Handle the history formats different Gradio ChatInterface versions/turns may pass:
+    A) legacy format: [(user, bot), ...]
+    B) new format: [{"role": "user"/"assistant", "content": "..."}, ...]
+    C) anything else: tolerate it, never raise
     """
-    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
-    for u, a in history_pairs:
-        if u:
-            messages.append({"role": "user", "content": u})
-        if a:
-            messages.append({"role": "assistant", "content": a})
+    msgs = [{"role": "system", "content": SYSTEM_PROMPT}]
+
+    if not history:
+        return msgs
+
+    # 1) If history is a list, inspect its first element to detect the structure
+    first = history[0]
+
+    # Case A: tuple/list pairs
+    if isinstance(first, (tuple, list)):
+        for item in history:
+            if not isinstance(item, (tuple, list)):
+                continue
+            # May occasionally be (user, bot, meta, ...) with length > 2
+            user = item[0] if len(item) > 0 else ""
+            bot = item[1] if len(item) > 1 else ""
+            if user:
+                msgs.append({"role": "user", "content": str(user)})
+            if bot:
+                msgs.append({"role": "assistant", "content": str(bot)})
+        return msgs
+
+    # Case B: dict messages
+    if isinstance(first, dict) and "role" in first:
+        for m in history:
+            if not isinstance(m, dict):
+                continue
+            role = m.get("role")
+            content = m.get("content", "")
+            if role in ("user", "assistant"):
+                msgs.append({"role": role, "content": str(content)})
+        return msgs
+
+    # Case C: unknown -> stringify
+    for item in history:
+        msgs.append({"role": "assistant", "content": str(item)})
+    return msgs
+
+
+def build_prompt(history, user_msg: str) -> str:
+    messages = history_to_messages(history)
     messages.append({"role": "user", "content": user_msg})
 
+    # Prefer the model's own chat template (DeepSeek ships a Jinja template)
     if hasattr(tokenizer, "apply_chat_template"):
         try:
             return tokenizer.apply_chat_template(
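
To see what the adapter buys, feed it the two history shapes it targets (hypothetical inputs; both normalize to the same messages list, always prefixed with the system prompt):

# Legacy tuple-pair history (older Gradio ChatInterface)
legacy = [("Hi", "Hello! How can I help?")]

# Newer dict-message history
modern = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]

assert history_to_messages(legacy) == history_to_messages(modern)
# Both yield:
# [{"role": "system", "content": SYSTEM_PROMPT},
#  {"role": "user", "content": "Hi"},
#  {"role": "assistant", "content": "Hello! How can I help?"}]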
@@ -98,20 +153,21 @@ def _build_prompt(history_pairs, user_msg: str) -> str:
         except Exception:
             pass
 
-    # fallback
+    # Fallback: simple concatenation
     prompt = f"System: {SYSTEM_PROMPT}\n"
-    for u, a in history_pairs:
-        prompt += f"User: {u}\nAssistant: {a}\n"
-    prompt += f"User: {user_msg}\nAssistant:"
+    for m in messages:
+        if m["role"] == "user":
+            prompt += f"User: {m['content']}\n"
+        elif m["role"] == "assistant":
+            prompt += f"Assistant: {m['content']}\n"
+    prompt += "Assistant:"
     return prompt
 
-
-def respond(message: str, history):
-    """
-    Compatible with the legacy gradio.ChatInterface: fn(message, history) -> str or generator
-    history: List[Tuple[str, str]]
-    """
-    prompt = _build_prompt(history, message)
+# -------------------------
+# Generation (streaming)
+# -------------------------
+def respond(message, history, max_tokens=512, temperature=0.7, top_p=0.95):
+    prompt = build_prompt(history, message)
 
     inputs = tokenizer(prompt, return_tensors="pt")
     if torch.cuda.is_available():
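
The rewritten fallback renders the normalized messages instead of raw pairs, so it works for either history shape. For one hypothetical completed turn plus a new question, build_prompt([("What is 2+2?", "4")], "And 3+3?") produces:

System: 你是 Bello,一个友好的智能助手。请用清晰、简洁的中文回答用户问题。
User: What is 2+2?
Assistant: 4
User: And 3+3?
Assistant: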
@@ -120,16 +176,16 @@ def respond(message: str, history):
     streamer = TextIteratorStreamer(
         tokenizer,
         skip_special_tokens=True,
-        skip_prompt=True,  # ✅ don't echo the prompt (fixes the <|User|> issue)
+        skip_prompt=True,  # ✅ don't echo the prompt (fixes <|User|>...)
     )
 
     gen_kwargs = dict(
         **inputs,
         streamer=streamer,
-        max_new_tokens=512,
-        do_sample=True,
-        temperature=0.7,
-        top_p=0.95,
+        max_new_tokens=int(max_tokens),
+        do_sample=(float(temperature) > 0),
+        temperature=float(temperature),
+        top_p=float(top_p),
         pad_token_id=tokenizer.eos_token_id,
     )
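
Both sides of the diff elide the unchanged lines between gen_kwargs and the streaming loop. The standard TextIteratorStreamer pattern they presumably contain (a sketch, not the file's actual lines) starts model.generate on a background thread and drains the streamer on the caller's thread:

# Inside respond(), after building gen_kwargs:
thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()

out = ""
for piece in streamer:  # blocks until the next decoded chunk is available
    out += piece
    yield out.strip()
thread.join()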
 
@@ -141,20 +197,64 @@ def respond(message: str, history):
         out += piece
         yield out.strip()
 
-
-with gr.Blocks(theme=theme, css=css) as demo:
+# -------------------------
+# UI
+# -------------------------
+with gr.Blocks() as demo:
     with gr.Column(elem_classes=["page-wrap"]):
         gr.Markdown("# 我是 Bello,有什么能帮到您?")
 
     with gr.Column(elem_classes=["chat-card"]):
-        # ✅ older versions don't support type="messages"; don't pass type
-        gr.ChatInterface(
-            fn=respond,
-            title="",
-            description="",
+        # ✅ don't pass type="messages" (avoids errors on older versions)
+        # ✅ add the sliders manually as extra inputs (more robust across versions)
+        chatbot = gr.Chatbot(height=520)
+        msg = gr.Textbox(placeholder="请输入问题...", show_label=False)
+        send = gr.Button("发送")
+
+        with gr.Row():
+            max_tokens = gr.Slider(1, 2048, value=512, step=1, label="Max new tokens")
+            temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
+            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
+
+        def _user_submit(user_message, chat_history):
+            # chat_history is [(user, bot), ...] but may carry metadata; don't rely on the exact structure
+            if chat_history is None:
+                chat_history = []
+            chat_history = list(chat_history)
+            chat_history.append((user_message, ""))  # placeholder for the streamed reply
+            return "", chat_history
+
+        def _bot_stream(chat_history, max_tokens, temperature, top_p):
+            # Take the last user message
+            if not chat_history:
+                return chat_history
+            last_user = chat_history[-1][0]
+
+            # History for the model: drop the trailing placeholder (pass only completed turns)
+            prior = chat_history[:-1]
+
+            # Use our respond() (it accepts either tuple pairs or dict messages)
+            gen = respond(last_user, prior, max_tokens=max_tokens, temperature=temperature, top_p=top_p)
+
+            partial = ""
+            for chunk in gen:
+                partial = chunk
+                chat_history[-1] = (last_user, partial)
+                yield chat_history
+
+        # Submit on Enter
+        msg.submit(_user_submit, [msg, chatbot], [msg, chatbot], queue=False).then(
+            _bot_stream, [chatbot, max_tokens, temperature, top_p], chatbot
+        )
+
+        # Submit on button click
+        send.click(_user_submit, [msg, chatbot], [msg, chatbot], queue=False).then(
+            _bot_stream, [chatbot, max_tokens, temperature, top_p], chatbot
         )
 
-demo.queue(default_concurrency_limit=4)
+# Keep concurrency low at first (stability first); raise it once confirmed stable
+demo.queue(default_concurrency_limit=1)
 
 if __name__ == "__main__":
-    demo.launch(ssr_mode=False)
+    # Gradio 6: theme/css are better passed to launch()
+    demo.launch(ssr_mode=False, theme=theme, css=css)
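
Because respond() is a plain Python generator, the streaming path can be smoke-tested without launching the UI. A minimal check (hypothetical, run locally) before raising default_concurrency_limit:

chunks = list(respond("你好,请介绍一下你自己", history=[], max_tokens=64))
print(chunks[-1])  # the final accumulated answer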
 
 