caixiaoshun committed on
Commit
4ecd3c9
·
verified ·
1 Parent(s): 462b2cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -26
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  from typing import List, Dict, Optional, Tuple
 
5
 
6
  # ==========================================
7
  # Helper: dtype map & loader with simple cache
@@ -21,7 +22,9 @@ def _dtype_from_name(name: str):
21
 
22
 
23
  def load_model_and_tokenizer(repo_id: str, device_map: str = "cpu", dtype_name: str = "auto"):
24
-
 
 
25
  key = (repo_id, device_map, dtype_name)
26
  if key in _MODEL_CACHE:
27
  return _MODEL_CACHE[key]
@@ -48,7 +51,7 @@ def load_model_and_tokenizer(repo_id: str, device_map: str = "cpu", dtype_name:
48
 
49
 
50
  # ==========================================
51
- # Chat utilities
52
  # ==========================================
53
 
54
  def messages_to_pairs(messages: List[Dict[str, str]]) -> List[Tuple[str, str]]:
@@ -72,33 +75,66 @@ def messages_to_pairs(messages: List[Dict[str, str]]) -> List[Tuple[str, str]]:
72
  return pairs
73
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  def predict(user_text: str,
76
  messages_state: List[Dict[str, str]],
77
  repo_id: str, device_map: str, dtype_name: str,
78
- max_new_token: int, top_k: int):
79
-
 
 
 
 
 
80
  messages_state = messages_state or []
 
81
 
82
- # Append user message
 
 
 
83
  messages_state.append({"role": "user", "content": user_text or ""})
 
84
 
85
- # Build initial display and show immediately
86
  chat_display = messages_to_pairs(messages_state)
87
- yield chat_display, messages_state
88
-
89
- # Load model/tokenizer lazily
90
- try:
91
- tokenizer, model = load_model_and_tokenizer(repo_id, device_map=device_map, dtype_name=dtype_name)
92
- except Exception as e:
93
- err = f"[加载错误] {e}"
94
- # Show error as assistant reply
95
- chat_display[-1] = (chat_display[-1][0], err)
96
- messages_state.append({"role": "assistant", "content": err})
97
- yield chat_display, messages_state
98
- return
99
 
100
- # Inference
101
  try:
 
 
102
  try:
103
  output = model.chat(
104
  messages_state,
@@ -107,35 +143,53 @@ def predict(user_text: str,
107
  top_k=int(top_k),
108
  )
109
  except TypeError:
110
- # Fallback to minimal signature if custom signature not supported in current build
111
  output = model.chat(messages_state, tokenizer)
112
 
113
  partial = ""
114
  for ch in str(output):
115
  partial += ch
116
  chat_display[-1] = (chat_display[-1][0], partial)
117
- yield chat_display, messages_state
118
 
119
- # Finalize state
120
  messages_state.append({"role": "assistant", "content": str(output)})
121
- yield chat_display, messages_state
 
122
  except Exception as e:
123
  err = f"[推理错误] {e}"
 
124
  chat_display[-1] = (chat_display[-1][0], err)
125
  messages_state.append({"role": "assistant", "content": err})
126
- yield chat_display, messages_state
127
 
128
 
129
  def clear_chat():
130
  return [], [] # chatbot pairs, messages_state
131
 
132
 
 
 
 
 
 
 
 
 
 
 
 
133
  # ==========================================
134
  # Gradio UI
135
  # ==========================================
136
  with gr.Blocks(title="mini-moe Chat (Gradio)") as demo:
 
 
 
 
 
137
 
138
  messages_state = gr.State([]) # 保存 role/content 历史
 
 
139
 
140
  with gr.Row():
141
  with gr.Column(scale=2):
@@ -156,10 +210,15 @@ with gr.Blocks(title="mini-moe Chat (Gradio)") as demo:
156
  dtype_dd = gr.Dropdown(label="精度 (dtype/torch_dtype)", choices=["auto", "float32", "bfloat16", "float16"], value="auto")
157
  max_new_num = gr.Number(label="max_new_token", value=256, precision=0)
158
  top_k_num = gr.Number(label="top_k", value=5, precision=0)
 
 
 
159
 
160
  # Events: send / submit
161
- send_evt_inputs = [user_box, messages_state, repo_dd, device_dd, dtype_dd, max_new_num, top_k_num]
162
- send_evt_outputs = [chatbot, messages_state]
 
 
163
 
164
  send_btn.click(predict, inputs=send_evt_inputs, outputs=send_evt_outputs)
165
  user_box.submit(predict, inputs=send_evt_inputs, outputs=send_evt_outputs)
@@ -173,6 +232,14 @@ with gr.Blocks(title="mini-moe Chat (Gradio)") as demo:
173
  # Clear chat
174
  clear_btn.click(clear_chat, inputs=None, outputs=[chatbot, messages_state])
175
 
 
 
 
 
 
 
 
 
176
 
177
  if __name__ == "__main__":
178
  demo.queue().launch() # set share=True if you want a public link
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  from typing import List, Dict, Optional, Tuple
5
+ import time
6
 
7
  # ==========================================
8
  # Helper: dtype map & loader with simple cache
 
22
 
23
 
24
  def load_model_and_tokenizer(repo_id: str, device_map: str = "cpu", dtype_name: str = "auto"):
25
+ """Load & cache (tokenizer, model) keyed by (repo_id, device_map, dtype). No low_cpu_mem_usage.
26
+ Prefer `dtype=...`; on TypeError fallback to `torch_dtype=` or omit.
27
+ """
28
  key = (repo_id, device_map, dtype_name)
29
  if key in _MODEL_CACHE:
30
  return _MODEL_CACHE[key]
 
51
 
52
 
53
  # ==========================================
54
+ # Chat utilities & logging helpers
55
  # ==========================================
56
 
57
  def messages_to_pairs(messages: List[Dict[str, str]]) -> List[Tuple[str, str]]:
 
75
  return pairs
76
 
77
 
78
+ def _ts() -> str:
79
+ return time.strftime("%H:%M:%S")
80
+
81
+
82
def append_log(logs: str, msg: str) -> str:
    """Append one timestamped line (``[HH:MM:SS] msg``) to *logs* and return the result."""
    entry = f"[{_ts()}] {msg}\n"
    if not logs:
        return entry
    return logs + entry
85
+
86
+
87
+ # ==========================================
88
+ # Model state helpers (reload only when repo_id changes)
89
+ # ==========================================
90
+
91
def ensure_model(model_state: Dict, repo_id: str, device_map: str, dtype_name: str, logs: str):
    """Ensure ``model_state`` holds a loaded (tokenizer, model) for *repo_id*.

    A (re)load happens only when no model is cached yet or the repo id
    differs from the cached one; changing device_map / dtype alone does
    not trigger a reload.  Loading/cache-hit steps are appended to *logs*.

    Returns ``(model_state, tokenizer, model, logs)``.
    """
    state = model_state if model_state else {"repo_id": None, "tok": None, "model": None}
    needs_load = state.get("model") is None or state.get("repo_id") != repo_id
    if needs_load:
        logs = append_log(logs, f"加载模型 {repo_id}(触发:repo 变更)…")
        tok, mdl = load_model_and_tokenizer(repo_id, device_map=device_map, dtype_name=dtype_name)
        state = {"repo_id": repo_id, "tok": tok, "model": mdl}
        logs = append_log(logs, "模型加载完成。")
    else:
        logs = append_log(logs, f"使用已加载模型 {repo_id}(缓存)")
    return state, state["tok"], state["model"], logs
105
+
106
+
107
+ # ==========================================
108
+ # Predict
109
+ # ==========================================
110
+
111
  def predict(user_text: str,
112
  messages_state: List[Dict[str, str]],
113
  repo_id: str, device_map: str, dtype_name: str,
114
+ max_new_token: int, top_k: int,
115
+ logs_state: str,
116
+ model_state: Dict):
117
+ """Generator for streaming output + live logs.
118
+ Only reload when repo_id changes.
119
+ Expects custom model.chat(conversations, tokenizer, max_new_token=..., top_k=...).
120
+ """
121
  messages_state = messages_state or []
122
+ logs_state = logs_state or ""
123
 
124
+ # 1) Ensure model based on repo_id only
125
+ model_state, tokenizer, model, logs_state = ensure_model(model_state, repo_id, device_map, dtype_name, logs_state)
126
+
127
+ # 2) Append user & paint
128
  messages_state.append({"role": "user", "content": user_text or ""})
129
+ logs_state = append_log(logs_state, f"收到输入:{(user_text or '').strip()[:60]}")
130
 
 
131
  chat_display = messages_to_pairs(messages_state)
132
+ yield chat_display, messages_state, logs_state, logs_state, model_state
 
 
 
 
 
 
 
 
 
 
 
133
 
134
+ # 3) Inference
135
  try:
136
+ logs_state = append_log(logs_state, f"开始推理:max_new_token={int(max_new_token)}, top_k={int(top_k)}")
137
+ yield chat_display, messages_state, logs_state, logs_state, model_state
138
  try:
139
  output = model.chat(
140
  messages_state,
 
143
  top_k=int(top_k),
144
  )
145
  except TypeError:
 
146
  output = model.chat(messages_state, tokenizer)
147
 
148
  partial = ""
149
  for ch in str(output):
150
  partial += ch
151
  chat_display[-1] = (chat_display[-1][0], partial)
152
+ yield chat_display, messages_state, logs_state, logs_state, model_state
153
 
 
154
  messages_state.append({"role": "assistant", "content": str(output)})
155
+ logs_state = append_log(logs_state, f"推理完成,输出长度 {len(str(output))} 字符。")
156
+ yield chat_display, messages_state, logs_state, logs_state, model_state
157
  except Exception as e:
158
  err = f"[推理错误] {e}"
159
+ logs_state = append_log(logs_state, err)
160
  chat_display[-1] = (chat_display[-1][0], err)
161
  messages_state.append({"role": "assistant", "content": err})
162
+ yield chat_display, messages_state, logs_state, logs_state, model_state
163
 
164
 
165
def clear_chat():
    """Reset the conversation: empty chatbot pair list and empty messages_state."""
    fresh_pairs: list = []
    fresh_history: list = []
    return fresh_pairs, fresh_history
167
 
168
 
169
def clear_logs_fn():
    """Reset the log panel: empty strings for both logs_box and logs_state."""
    return ("", "")
171
+
172
+
173
def preload_on_repo_change(repo_id: str, device_map: str, dtype_name: str, logs_state: str, model_state: Dict):
    """Preload the model (logging each step) when the repo dropdown changes."""
    current_logs = logs_state if logs_state else ""
    model_state, _tok, _mdl, current_logs = ensure_model(
        model_state, repo_id, device_map, dtype_name, current_logs
    )
    return current_logs, model_state
178
+
179
+
180
  # ==========================================
181
  # Gradio UI
182
  # ==========================================
183
  with gr.Blocks(title="mini-moe Chat (Gradio)") as demo:
184
+ gr.Markdown("""
185
+ # 🤖 mini-moe Chat UI (Gradio)
186
+ 仅在 **repo 变更** 时重新加载模型;设备/精度变更不会触发重新加载(按你的要求)。
187
+ 右侧含 **日志面板**,实时显示加载与推理步骤;**不使用 system prompt**。
188
+ """)
189
 
190
  messages_state = gr.State([]) # 保存 role/content 历史
191
+ logs_state = gr.State("") # 保存日志文本
192
+ model_state = gr.State({"repo_id": None, "tok": None, "model": None}) # 当前已加载模型
193
 
194
  with gr.Row():
195
  with gr.Column(scale=2):
 
210
  dtype_dd = gr.Dropdown(label="精度 (dtype/torch_dtype)", choices=["auto", "float32", "bfloat16", "float16"], value="auto")
211
  max_new_num = gr.Number(label="max_new_token", value=256, precision=0)
212
  top_k_num = gr.Number(label="top_k", value=5, precision=0)
213
+ with gr.Accordion("📜 日志 (展开查看)", open=False):
214
+ logs_box = gr.Textbox(label="运行日志", lines=12, interactive=False)
215
+ log_clear_btn = gr.Button("清空日志")
216
 
217
  # Events: send / submit
218
+ send_evt_inputs = [
219
+ user_box, messages_state, repo_dd, device_dd, dtype_dd, max_new_num, top_k_num, logs_state, model_state
220
+ ]
221
+ send_evt_outputs = [chatbot, messages_state, logs_box, logs_state, model_state]
222
 
223
  send_btn.click(predict, inputs=send_evt_inputs, outputs=send_evt_outputs)
224
  user_box.submit(predict, inputs=send_evt_inputs, outputs=send_evt_outputs)
 
232
  # Clear chat
233
  clear_btn.click(clear_chat, inputs=None, outputs=[chatbot, messages_state])
234
 
235
+ # Clear logs
236
+ log_clear_btn.click(clear_logs_fn, inputs=None, outputs=[logs_box, logs_state])
237
+
238
+ # Preload on repo change (only reload on repo change)
239
+ repo_dd.change(preload_on_repo_change,
240
+ inputs=[repo_dd, device_dd, dtype_dd, logs_state, model_state],
241
+ outputs=[logs_box, model_state])
242
+
243
 
244
  if __name__ == "__main__":
245
  demo.queue().launch() # set share=True if you want a public link