badanwang commited on
Commit
7aee45a
·
verified ·
1 Parent(s): b965102

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -34
app.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
@@ -5,46 +9,69 @@ from threading import Thread
5
  import logging
6
  import time
7
  import json
 
8
 
9
- # ===================================================================
10
- # 最终版 app.py (适配 Gradio 4.x+ 的推荐模式)
11
- # ===================================================================
12
 
13
- # 1. 配置详细的日志记录
14
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
15
  logger = logging.getLogger(__name__)
16
 
17
- MODEL_ID = "badanwang/teacher_basic_qwen3-0.6b"
 
18
 
19
  logger.info("===== Application Startup =====")
20
  logger.info(f"正在加载模型和分词器: {MODEL_ID}")
 
 
21
  try:
22
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
23
  model = AutoModelForCausalLM.from_pretrained(
24
  MODEL_ID,
25
- torch_dtype=torch.float32, # CPU 上使用 float32 以获得最佳稳定性和兼容性
26
- device_map="auto"
 
27
  )
28
  logger.info("模型和分词器加载成功!")
29
  except Exception as e:
30
  logger.error(f"加载模型时发生致命错误: {e}", exc_info=True)
 
31
  model, tokenizer = None, None
 
 
32
 
33
- # --- 2. 核心推理函数 (已根据 Gradio 新模式重构) ---
34
- def predict(messages: list):
 
 
35
  """
36
- 接收一个包含完整对话历史的 OpenAI 格式列表,返回模型的流式响应。
 
 
 
 
37
  """
38
  start_time = time.time()
39
  logger.info("\n--- [START] 进入 predict 函数 ---")
40
-
41
- # 使用 json.dumps 美化输出,方便阅读
42
- logger.info(f"[INPUT] 收到的 messages 列表:\n{json.dumps(messages, indent=2, ensure_ascii=False)}")
43
 
 
44
  if model is None or tokenizer is None:
45
- logger.warning("[HANDLER] 模型或分词器为 None,返回错误信息。")
46
- yield "错误:模型未能成功加载,请检查后台日志。"
47
- return
 
 
 
 
 
 
 
 
 
 
48
 
49
  try:
50
  logger.info("[HANDLER] 正在应用聊天模板...")
@@ -56,52 +83,78 @@ def predict(messages: list):
56
  ).to(model.device)
57
  logger.info(f"[HANDLER] 模板应用成功,输入 token 数量: {prompt_tokens.shape[-1]}")
58
 
 
59
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
60
-
 
61
  generation_kwargs = {
62
  "input_ids": prompt_tokens,
63
  "streamer": streamer,
64
  "max_new_tokens": 1024,
65
  "do_sample": True,
66
  "temperature": 0.7,
67
- "top_p": 0.9,
68
  }
69
 
70
- logger.info("[HANDLER] 准备启动生成线程...")
71
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
72
  thread.start()
73
  logger.info("[HANDLER] 生成线程已启动,开始从 streamer 中读取数据...")
74
 
 
75
  buffer = ""
76
  token_count = 0
77
  for new_text in streamer:
78
  token_count += 1
79
- logger.info(f"[STREAM] 正在生成第 {token_count} 个 token: {repr(new_text)}")
 
 
80
  buffer += new_text
81
  yield buffer
82
 
83
  logger.info(f"[HANDLER] Streamer 读取完毕,共生成 {token_count} 个 token。")
 
84
 
85
  except Exception as e:
86
  logger.error(f"[HANDLER] 在推理过程中发生错误: {e}", exc_info=True)
87
- yield "抱歉,处理您的请求时遇到了一个内部错误。"
 
88
 
89
- end_time = time.time()
90
- logger.info(f"--- [END] predict 函数结束,总耗时: {end_time - start_time:.2f} 秒 ---")
 
 
 
91
 
92
- # --- 3. 创建Gradio界面 (已优化) ---
93
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
94
- gr.Markdown(f"# 你的自定义Qwen模型聊天机器人\n## 模型: {MODEL_ID}")
95
 
 
96
  chat_interface = gr.ChatInterface(
97
  fn=predict,
98
- title="聊天机器人",
99
- description="向你的微调Qwen模型提问吧!",
100
- examples=[["你好,你是谁?"], ["用Python写一个快速排序算法"]],
101
- type="messages" # <-- 【最关键的优化】告诉 Gradio 使用新的 OpenAI 格式
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  )
103
 
104
  # --- 4. 启动应用 ---
105
- logger.info("准备启动 Gradio 应用...")
106
- demo.queue().launch()
107
- logger.info("Gradio 应用已启动。")
 
 
 
 
1
+ # ===================================================================
2
+ # 优化版 app.py (为 Gradio 5.x 优化)
3
+ # ===================================================================
4
+
5
  import gradio as gr
6
  import torch
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
9
  import logging
10
  import time
11
  import json
12
+ import os
13
 
14
+ # --- 1. 配置与初始化 ---
 
 
15
 
16
+ # 日志记录配置
17
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
18
  logger = logging.getLogger(__name__)
19
 
20
+ # 从环境变量或默认值加载模型ID,增加灵活性
21
+ MODEL_ID = os.getenv("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b")
22
 
23
  logger.info("===== Application Startup =====")
24
  logger.info(f"正在加载模型和分词器: {MODEL_ID}")
25
+
26
+ # 异常处理以优雅地处理模型加载失败
27
  try:
28
+ # 推荐使用 trust_remote_code=True 以确保所有模型组件正确加载
29
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
30
  model = AutoModelForCausalLM.from_pretrained(
31
  MODEL_ID,
32
+ torch_dtype="auto", # 推荐使用 "auto" 以获得最佳性能和兼容性
33
+ device_map="auto",
34
+ trust_remote_code=True
35
  )
36
  logger.info("模型和分词器加载成功!")
37
  except Exception as e:
38
  logger.error(f"加载模型时发生致命错误: {e}", exc_info=True)
39
+ # 在无法加载模型时,创建一个占位符函数,以便Gradio界面仍能启动并显示错误
40
  model, tokenizer = None, None
41
+ def model_load_error_placeholder(*args, **kwargs):
42
+ raise gr.Error(f"致命错误:无法加载模型 '{MODEL_ID}'。请检查后台日志以获取详细信息。")
43
 
44
+ # --- 2. 核心推理函数 ---
45
+
46
+ # Gradio 5.x 的 ChatInterface `fn` 函数接收两个参数: message 和 history
47
+ def predict(message: str, history: list[list[str]]):
48
  """
49
+ 核心推理函数,接收用户输入和聊天历史,并以流式方式返回模型输出。
50
+
51
+ Args:
52
+ message (str): 用户的最新输入。
53
+ history (list[list[str]]): 聊天历史,格式为 [[user_msg_1, bot_msg_1], [user_msg_2, bot_msg_2], ...]。
54
  """
55
  start_time = time.time()
56
  logger.info("\n--- [START] 进入 predict 函数 ---")
57
+ logger.info(f"[INPUT] Message: {message}")
58
+ logger.info(f"[INPUT] History:\n{json.dumps(history, indent=2, ensure_ascii=False)}")
 
59
 
60
+ # 如果模型加载失败,使用占位符函数抛出错误
61
  if model is None or tokenizer is None:
62
+ model_load_error_placeholder()
63
+
64
+ # 将 Gradio 的 history 格式转换为 Hugging Face 模板所需的格式
65
+ # history 的格式: [[user, assistant], [user, assistant], ...]
66
+ # messages 的格式: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
67
+ messages = []
68
+ for turn in history:
69
+ user_message, bot_message = turn
70
+ messages.append({"role": "user", "content": user_message})
71
+ messages.append({"role": "assistant", "content": bot_message})
72
+ messages.append({"role": "user", "content": message})
73
+
74
+ logger.info(f"[HANDLER] 转换后的 messages 列表:\n{json.dumps(messages, indent=2, ensure_ascii=False)}")
75
 
76
  try:
77
  logger.info("[HANDLER] 正在应用聊天模板...")
 
83
  ).to(model.device)
84
  logger.info(f"[HANDLER] 模板应用成功,输入 token 数量: {prompt_tokens.shape[-1]}")
85
 
86
+ # 使用 TextIteratorStreamer 实现流式输出
87
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
88
+
89
+ # 定义生成参数
90
  generation_kwargs = {
91
  "input_ids": prompt_tokens,
92
  "streamer": streamer,
93
  "max_new_tokens": 1024,
94
  "do_sample": True,
95
  "temperature": 0.7,
96
+ "top_p": 0.9
97
  }
98
 
99
+ # 在单独的线程中运行模型生成,以避免阻塞UI
100
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
101
  thread.start()
102
  logger.info("[HANDLER] 生成线程已启动,开始从 streamer 中读取数据...")
103
 
104
+ # 从 streamer 中逐个 token 地 yield,实现实时流式效果
105
  buffer = ""
106
  token_count = 0
107
  for new_text in streamer:
108
  token_count += 1
109
+ if "�" in new_text: # 过滤掉解码不完全的特殊字符
110
+ continue
111
+ logger.debug(f"[STREAM] 正在生成第 {token_count} 个 token: {repr(new_text)}")
112
  buffer += new_text
113
  yield buffer
114
 
115
  logger.info(f"[HANDLER] Streamer 读取完毕,共生成 {token_count} 个 token。")
116
+ thread.join() # 确保线程执行完毕
117
 
118
  except Exception as e:
119
  logger.error(f"[HANDLER] 在推理过程中发生错误: {e}", exc_info=True)
120
+ # 使用 gr.Error 在界面上优雅地显示错误信息
121
+ raise gr.Error(f"抱歉,处理您的请求时遇到了一个内部错误: {e}")
122
 
123
+ finally:
124
+ end_time = time.time()
125
+ logger.info(f"--- [END] predict 函数结束,总耗时: {end_time - start_time:.2f} 秒 ---")
126
+
127
+ # --- 3. 创建并配置Gradio界面 ---
128
 
129
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css="footer {visibility: hidden}") as demo:
130
+ gr.Markdown(f"# 你的自定义Qwen模型聊天机器人\n## 模型: `{MODEL_ID}`")
 
131
 
132
+ # gr.ChatInterface 是 Gradio 5.x 中构建聊天机器人的推荐方式
133
  chat_interface = gr.ChatInterface(
134
  fn=predict,
135
+ # Gradio 5.x 的 `fn` 自动接收 message 和 history,无需手动管理状态
136
+ chatbot=gr.Chatbot(
137
+ height=600,
138
+ show_copy_button=True,
139
+ avatar_images=(None, "https://s2.loli.net/2024/07/17/iPqD3uVgW9eBkbT.png") # (user, bot)
140
+ ),
141
+ title="Qwen 大模型聊天室",
142
+ description="向你的微调Qwen模型提问吧!这是一个流式输出的例子。",
143
+ examples=[
144
+ ["你好,你是谁?"],
145
+ ["用 Python 写一个快速排序算法。"],
146
+ ["解释一下什么是大型语言模型(LLM)。"]
147
+ ],
148
+ submit_btn="发送",
149
+ retry_btn="🔄 重试",
150
+ undo_btn="↩️ 撤销",
151
+ clear_btn="🗑️ 清除"
152
  )
153
 
154
  # --- 4. 启动应用 ---
155
+
156
+ if __name__ == "__main__":
157
+ logger.info("准备启动 Gradio 应用...")
158
+ # 使用 queue() 实现请求排队,concurrency_count 控制并发数
159
+ demo.queue(concurrency_count=2).launch(server_name="0.0.0.0", server_port=7860)
160
+ logger.info("Gradio 应用已启动。")