badanwang committed on
Commit
b965102
·
verified ·
1 Parent(s): e906ca1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -54
app.py CHANGED
@@ -2,82 +2,106 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
  from threading import Thread
 
 
 
5
 
 
 
 
 
 
 
 
6
 
7
  MODEL_ID = "badanwang/teacher_basic_qwen3-0.6b"
8
 
9
- print("正在加载模型和分词器...")
 
10
  try:
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
12
-
13
- # --- 【关键修改 1】---
14
- # 移除 torch_dtype=torch.bfloat16,使用默认的 float32,这是在 CPU 上最稳妥的选择。
15
  model = AutoModelForCausalLM.from_pretrained(
16
  MODEL_ID,
 
17
  device_map="auto"
18
  )
19
- print("模型和分词器加载成功!")
20
  except Exception as e:
21
- print(f"加载模型时出错: {e}")
22
  model, tokenizer = None, None
23
 
24
- # --- 3. 定义核心推理函数 ---
25
- def predict(message, history):
26
- print("\n--- [DEBUG] 进入 predict 函数 ---") # 调试日志
27
- print(f"[DEBUG] 收到的 Message: {message}")
28
- print(f"[DEBUG] 收到的 History: {history}")
 
 
 
 
 
29
 
30
  if model is None or tokenizer is None:
31
- print("[DEBUG] 模型或分词器为 None,返回错误。")
32
- yield "错误:模型未能成功加载,请检查后台日志和模型ID。"
33
  return
34
 
35
- # ... (将 history 转换为 chat_history_for_model 的代码保持不变) ...
36
- chat_history_for_model = []
37
- for user_msg, assistant_msg in history:
38
- chat_history_for_model.append({"role": "user", "content": user_msg})
39
- chat_history_for_model.append({"role": "assistant", "content": assistant_msg})
40
- chat_history_for_model.append({"role": "user", "content": message})
41
-
42
- print("[DEBUG] 正在应用聊天模板...")
43
- prompt_tokens = tokenizer.apply_chat_template(
44
- chat_history_for_model,
45
- add_generation_prompt=True,
46
- tokenize=True,
47
- return_tensors="pt"
48
- ).to(model.device)
49
- print(f"[DEBUG] 模板应用成功,输入 token 数量: {prompt_tokens.shape[-1]}")
50
 
51
- streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
52
 
53
- generation_kwargs = {
54
- "input_ids": prompt_tokens,
55
- "streamer": streamer,
56
- "max_new_tokens": 1024,
57
- "do_sample": True,
58
- "temperature": 0.7,
59
- "top_p": 0.9,
60
- }
61
 
62
- print("[DEBUG] 准备启动生成线程...")
63
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
64
- thread.start()
65
- print("[DEBUG] 生成线程已启动,开始从 streamer 中读取数据...")
66
 
67
- buffer = ""
68
- token_count = 0
69
- for new_text in streamer:
70
- token_count += 1
71
- print(f"[DEBUG] 正在生成第 {token_count} 个 token: '{new_text}'") # 逐个 token 打印
72
- buffer += new_text
73
- yield buffer
 
 
 
 
 
 
74
 
75
- print("[DEBUG] Streamer 读取完毕,函数结束。")
 
76
 
77
- # ... (gr.Blocks demo.launch() 的代码保持不变) ...
78
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
79
- # ...
80
- chat_interface = gr.ChatInterface(fn=predict, #...
 
 
 
 
 
 
81
  )
82
- demo.queue()
83
- demo.launch()
 
 
 
 
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
  from threading import Thread
5
+ import logging
6
+ import time
7
+ import json
8
 
9
+ # ===================================================================
10
+ # 最终版 app.py (适配 Gradio 4.x+ 的推荐模式)
11
+ # ===================================================================
12
+
13
+ # 1. 配置详细的日志记录
14
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
15
+ logger = logging.getLogger(__name__)
16
 
17
  MODEL_ID = "badanwang/teacher_basic_qwen3-0.6b"
18
 
19
+ logger.info("===== Application Startup =====")
20
+ logger.info(f"正在加载模型和分词器: {MODEL_ID}")
21
  try:
22
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 
 
23
  model = AutoModelForCausalLM.from_pretrained(
24
  MODEL_ID,
25
+ torch_dtype=torch.float32, # 在 CPU 上使用 float32 以获得最佳稳定性和兼容性
26
  device_map="auto"
27
  )
28
+ logger.info("模型和分词器加载成功!")
29
  except Exception as e:
30
+ logger.error(f"加载模型时发生致命错误: {e}", exc_info=True)
31
  model, tokenizer = None, None
32
 
33
+ # --- 2. 核心推理函数 (已根据 Gradio 新模式重构) ---
34
+ def predict(messages: list):
35
+ """
36
+ 接收一个包含完整对话历史的 OpenAI 格式列表,返回模型的流式响应。
37
+ """
38
+ start_time = time.time()
39
+ logger.info("\n--- [START] 进入 predict 函数 ---")
40
+
41
+ # 使用 json.dumps 美化输出,方便阅读
42
+ logger.info(f"[INPUT] 收到的 messages 列表:\n{json.dumps(messages, indent=2, ensure_ascii=False)}")
43
 
44
  if model is None or tokenizer is None:
45
+ logger.warning("[HANDLER] 模型或分词器为 None,返回错误信息。")
46
+ yield "错误:模型未能成功加载,请检查后台日志。"
47
  return
48
 
49
+ try:
50
+ logger.info("[HANDLER] 正在应用聊天模板...")
51
+ prompt_tokens = tokenizer.apply_chat_template(
52
+ messages,
53
+ add_generation_prompt=True,
54
+ tokenize=True,
55
+ return_tensors="pt"
56
+ ).to(model.device)
57
+ logger.info(f"[HANDLER] 模板应用成功,输入 token 数量: {prompt_tokens.shape[-1]}")
 
 
 
 
 
 
58
 
59
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
60
 
61
+ generation_kwargs = {
62
+ "input_ids": prompt_tokens,
63
+ "streamer": streamer,
64
+ "max_new_tokens": 1024,
65
+ "do_sample": True,
66
+ "temperature": 0.7,
67
+ "top_p": 0.9,
68
+ }
69
 
70
+ logger.info("[HANDLER] 准备启动生成线程...")
71
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
72
+ thread.start()
73
+ logger.info("[HANDLER] 生成线程已启动,开始从 streamer 中读取数据...")
74
 
75
+ buffer = ""
76
+ token_count = 0
77
+ for new_text in streamer:
78
+ token_count += 1
79
+ logger.info(f"[STREAM] 正在生成第 {token_count} 个 token: {repr(new_text)}")
80
+ buffer += new_text
81
+ yield buffer
82
+
83
+ logger.info(f"[HANDLER] Streamer 读取完毕,共生成 {token_count} 个 token。")
84
+
85
+ except Exception as e:
86
+ logger.error(f"[HANDLER] 在推理过程中发生错误: {e}", exc_info=True)
87
+ yield "抱歉,处理您的请求时遇到了一个内部错误。"
88
 
89
+ end_time = time.time()
90
+ logger.info(f"--- [END] predict 函数结束,总耗时: {end_time - start_time:.2f} 秒 ---")
91
 
92
+ # --- 3. 创建Gradio界面 (已优化) ---
93
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
94
+ gr.Markdown(f"# 你的自定义Qwen模型聊天机器人\n## 模型: {MODEL_ID}")
95
+
96
+ chat_interface = gr.ChatInterface(
97
+ fn=predict,
98
+ title="聊天机器人",
99
+ description="向你的微调Qwen模型提问吧!",
100
+ examples=[["你好,你是谁?"], ["用Python写一个快速排序算法"]],
101
+ type="messages" # <-- 【最关键的优化】告诉 Gradio 使用新的 OpenAI 格式
102
  )
103
+
104
+ # --- 4. 启动应用 ---
105
+ logger.info("准备启动 Gradio 应用...")
106
+ demo.queue().launch()
107
+ logger.info("Gradio 应用已启动。")