badanwang commited on
Commit
9eb5e7a
·
verified ·
1 Parent(s): 237b0a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -65
app.py CHANGED
@@ -1,90 +1,165 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
4
  import os
5
 
6
- # --- 1. 配置与模型加载 ---
7
- # 假设运行环境的硬件资源是充足的。
8
  MODEL_ID = os.getenv("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b")
9
- print(f"INFO: 正在加载模型: {MODEL_ID}")
10
 
11
- # 使用 try-except 来捕获任何可能的加载错误 (例如网络问题、模型名称错误等)
12
  try:
13
- # 加载分词器和模型
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
15
- # device_map="auto" 会自动利用可用的硬件 (如 CPU 或 GPU)
16
  model = AutoModelForCausalLM.from_pretrained(
17
  MODEL_ID,
18
- torch_dtype="auto", # 自动选择最佳数据类型
19
  device_map="auto",
20
  trust_remote_code=True
21
  )
22
- print("INFO: 模型和分词器加载成功!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # 将核心推理逻辑定义为一个函数
25
- # 只有在模型成功加载后,这个函数才会被有效定义
26
- def predict(prompt: str, history: list[list[str]]):
27
- """
28
- 接收用户输入和对话历史,返回更新后的完整对话历史。
29
- Gradio 会自动为这个函数创建 API 端点。
30
- """
31
- print(f"INFO: 收到API/UI请求: prompt='{prompt}'")
32
-
33
- # 1. 构建符合模型要求的消息列表
34
- messages = []
35
- for user_message, bot_message in history:
36
- messages.append({"role": "user", "content": user_message})
37
- messages.append({"role": "assistant", "content": bot_message})
38
- messages.append({"role": "user", "content": prompt})
39
-
40
- # 2. 应用聊天模板并进行分词
41
  input_ids = tokenizer.apply_chat_template(
42
  messages,
43
  add_generation_prompt=True,
44
  tokenize=True,
45
  return_tensors="pt"
46
  ).to(model.device)
 
 
47
 
48
- # 3. 生成回复
49
- # 使用简单的 .generate(),不加额外的采样参数以保持简洁
50
- outputs = model.generate(input_ids, max_new_tokens=1024)
51
-
52
- # 4. 解码生成的文本,跳过输入的token
53
- response_text = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- print(f"INFO: 生成回复: {response_text}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- # 5. 更新并返回对话历史
58
- history.append([prompt, response_text])
59
- return history
 
 
 
 
 
 
60
 
61
- except Exception as e:
62
- print(f"FATAL: 加载模型或分词器时发生致命错误: {e}")
63
- # 如果模型加载失败,则定义一个专门用于报错的函数
64
- # 这能确保Gradio界面依然可以启动,并向用户显示一个清晰的错误信息
65
- def predict(*args, **kwargs):
66
- raise gr.Error(f"模型未能加载,应用无法工作。请检查后台日志获取详细错误信息。错误: {e}")
67
-
68
- # --- 2. 创建并启动 Gradio 应用 ---
69
- # 使用 gr.Blocks 来自定义界面布局
70
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
71
- gr.Markdown(f"## 模型聊天机器人\n当前模型: `{MODEL_ID}`")
72
 
73
- # 定义聊天机器人组件和输入框
74
- chatbot = gr.Chatbot(label="对话历史", height=600)
75
- msg_input = gr.Textbox(label="在这里输入你的问题...", placeholder="例如:你好,你是谁?")
76
- clear_button = gr.Button("清除对话")
77
-
78
- # 设定组件的交互逻辑
79
- # 当用户在输入框中按回车时,调用 predict 函数
80
- msg_input.submit(predict, [msg_input, chatbot], chatbot)
81
- # 当用户点击“清除对话”按钮时,清空聊天机器人组件
82
- clear_button.click(lambda: [], None, chatbot)
83
-
84
- # --- 3. 启动应用并开放API ---
85
- print("INFO: 准备启动Gradio应用...")
86
-
87
- # .queue() 使应用能够处理多个排队的请求,并且在 4.29.0 版本中会自动开放API。
88
- # share=True 是解决CORS问题的关键。它会生成一个公开的、已配置好CORS的 .gradio.live 网址。
89
- # *** 已移除 'api_open=True' 参数以适配 gradio==4.29.0 ***
90
- demo.queue().launch(share=True)
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
+ from threading import Thread
5
  import os
6
 
7
+ # --- 配置 ---
 
8
  MODEL_ID = os.getenv("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b")
9
+ print(f"INFO: Application startup. Loading model: {MODEL_ID}")
10
 
11
+ # --- 1. 模型加载 (内置健壮的错误处理) ---
12
  try:
 
13
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
14
  model = AutoModelForCausalLM.from_pretrained(
15
  MODEL_ID,
16
+ torch_dtype="auto",
17
  device_map="auto",
18
  trust_remote_code=True
19
  )
20
+ print("INFO: Model and tokenizer loaded successfully!")
21
+ model_loaded = True
22
+ except Exception as e:
23
+ print(f"FATAL: Failed to load model or tokenizer: {e}")
24
+ model_loaded = False
25
+ model_load_error = e
26
+
27
+ # --- 2. 核心流式推理函数 ---
28
+ def stream_predict(prompt: str, history: list[list[str]]):
29
+ """
30
+ 一个生成器函数,用于流式生成对话。
31
+ 它会逐步 (yield) 返回完整的对话历史。
32
+ """
33
+ if not model_loaded:
34
+ # 如果模型加载失败,则立即抛出错误
35
+ raise gr.Error(f"Model is not loaded. Please check logs. Error: {model_load_error}")
36
+
37
+ print(f"INFO: Received prompt: '{prompt}'")
38
 
39
+ # 将历史记录和新提示转换为模型需要的格式
40
+ messages = []
41
+ for user_msg, assistant_msg in history:
42
+ messages.append({"role": "user", "content": user_msg})
43
+ messages.append({"role": "assistant", "content": assistant_msg})
44
+ messages.append({"role": "user", "content": prompt})
45
+
46
+ # 应用聊天模板
47
+ try:
 
 
 
 
 
 
 
 
48
  input_ids = tokenizer.apply_chat_template(
49
  messages,
50
  add_generation_prompt=True,
51
  tokenize=True,
52
  return_tensors="pt"
53
  ).to(model.device)
54
+ except Exception as e:
55
+ raise gr.Error(f"Error applying chat template: {e}")
56
 
57
+ # 初始化 streamer 和生成线程
58
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
59
+
60
+ generation_kwargs = dict(
61
+ input_ids=input_ids,
62
+ streamer=streamer,
63
+ max_new_tokens=1024,
64
+ do_sample=True,
65
+ temperature=0.7,
66
+ top_p=0.9
67
+ )
68
+
69
+ # 在独立线程中运行生成,防止阻塞UI
70
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
71
+ thread.start()
72
+
73
+ # 流式输出
74
+ try:
75
+ # 初始化一个空的字符串来存放助手的回复
76
+ assistant_response = ""
77
+ # 每次从streamer中获取一个新的文本片段
78
+ for new_text in streamer:
79
+ if not new_text:
80
+ continue
81
+ assistant_response += new_text
82
+ # 将当前用户输入和不断增长的助手回复组合成新的对话历史
83
+ # 然后使用 yield 返回,Gradio会用它来更新UI
84
+ yield history + [[prompt, assistant_response]]
85
 
86
+ print("INFO: Streaming finished.")
87
+
88
+ except Exception as e:
89
+ print(f"ERROR: An error occurred during streaming: {e}")
90
+ raise gr.Error(f"An error occurred during generation: {e}")
91
+ finally:
92
+ # 确保线程结束
93
+ thread.join()
94
+
95
+
96
+ # --- 3. Gradio Blocks 界面布局 ---
97
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), css="footer {visibility: hidden}") as demo:
98
+ gr.Markdown(f"# 流式对话机器人\n### 模型: `{MODEL_ID}`")
99
+
100
+ # 使用 gr.State 来存储对话历史
101
+ # 这是实现多轮对话的关键
102
+ chatbot_state = gr.State([])
103
+
104
+ # Chatbot 组件用于显示对话
105
+ chatbot_ui = gr.Chatbot(label="对话窗口", height=600)
106
+
107
+ with gr.Row():
108
+ # Textbox 用于用户输入
109
+ prompt_input = gr.Textbox(
110
+ show_label=False,
111
+ placeholder="请在这里输入您的问题...",
112
+ scale=4,
113
+ )
114
+ # Button 用于提交
115
+ submit_button = gr.Button("发送", variant="primary", scale=1)
116
+
117
+ # 清除按钮
118
+ clear_button = gr.Button("清除对话历史")
119
+
120
+ # --- 4. 事件处理逻辑 ---
121
+
122
+ # 提交逻辑:
123
+ # 1. 点击"发送"按钮或在输入框按回车时触发
124
+ # 2. 调用 stream_predict 函数
125
+ # 3. 输入是用户输入框(prompt_input)和对话历史状态(chatbot_state)
126
+ # 4. 输出会实时更新聊天机器人界面(chatbot_ui)
127
+ # 5. 在函数开始前,将用户输入添加到聊天记录的末尾,并清空输入框
128
+ def on_submit(prompt, history):
129
+ # 将用户输入加入历史,形成 "用户: XXX" 的临时记录
130
+ return "", history + [[prompt, None]]
131
 
132
+ prompt_input.submit(
133
+ on_submit,
134
+ [prompt_input, chatbot_state],
135
+ [prompt_input, chatbot_ui]
136
+ ).then(
137
+ stream_predict,
138
+ [prompt_input, chatbot_state],
139
+ chatbot_ui
140
+ )
141
 
142
+ submit_button.click(
143
+ on_submit,
144
+ [prompt_input, chatbot_state],
145
+ [prompt_input, chatbot_ui]
146
+ ).then(
147
+ stream_predict,
148
+ [prompt_input, chatbot_state],
149
+ chatbot_ui
150
+ )
 
 
151
 
152
+ # 清除逻辑:
153
+ # 点击按钮时,清空状态和UI
154
+ def on_clear():
155
+ return []
156
+
157
+ clear_button.click(on_clear, [], chatbot_state)
158
+ clear_button.click(on_clear, [], chatbot_ui)
159
+
160
+
161
+ # --- 5. 启动应用 ---
162
+ print("INFO: Preparing to launch Gradio app...")
163
+ # .queue() 启用请求队列,对于流式应用是必需的
164
+ # 在Hugging Face Spaces上, 无需 share=True, Gradio会自动处理
165
+ demo.queue().launch()