badanwang commited on
Commit
128f145
·
verified ·
1 Parent(s): e20c625

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -25
app.py CHANGED
@@ -4,35 +4,40 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import os
5
 
6
  # --- 1. 配置与模型加载 ---
 
7
  MODEL_ID = os.getenv("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b")
8
- print(f"正在加载模型: {MODEL_ID}")
9
 
10
- # 尝试加载模型,如果失败则在界面上显示错误
11
  try:
 
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
13
- # 使用 device_map="auto" accelerate 库自动处理设备分配
14
  model = AutoModelForCausalLM.from_pretrained(
15
  MODEL_ID,
16
- torch_dtype="auto",
17
  device_map="auto",
18
  trust_remote_code=True
19
  )
20
- print("模型和分词器加载成功!")
21
 
22
- # 定义核心推理函数
 
23
  def predict(prompt: str, history: list[list[str]]):
24
  """
25
- 接收输入和历史,返回更新后的历史。
26
- Gradio 会自动为此函数创建 API 端点。
27
  """
28
- print(f"收到请求: prompt='{prompt}'")
29
 
 
30
  messages = []
31
  for user_message, bot_message in history:
32
  messages.append({"role": "user", "content": user_message})
33
  messages.append({"role": "assistant", "content": bot_message})
34
  messages.append({"role": "user", "content": prompt})
35
 
 
36
  input_ids = tokenizer.apply_chat_template(
37
  messages,
38
  add_generation_prompt=True,
@@ -40,31 +45,46 @@ try:
40
  return_tensors="pt"
41
  ).to(model.device)
42
 
 
 
43
  outputs = model.generate(input_ids, max_new_tokens=1024)
 
 
44
  response_text = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
45
 
46
- print(f"生成回复: {response_text}")
47
 
 
48
  history.append([prompt, response_text])
49
  return history
50
 
51
  except Exception as e:
52
- print(f"加载模型时发生致命错误: {e}")
53
- # 如果模型加载失败,则定义一个报错函数
 
54
  def predict(*args, **kwargs):
55
- raise gr.Error(f"模型加载失败,请检查Space后台日志以确认是否为内存不足。错误详情: {e}")
 
 
 
 
 
 
 
 
 
 
56
 
57
- # --- 2. 创建并启动Gradio应用 ---
58
- with gr.Blocks(theme=gr.themes.Default()) as demo:
59
- gr.Markdown(f"## 简易模型聊天 ({MODEL_ID})")
60
- chatbot = gr.Chatbot(label="对话窗口", height=600)
61
- msg = gr.Textbox(label="输入你的问题")
62
- clear = gr.Button("清除对话")
63
 
64
- msg.submit(predict, [msg, chatbot], chatbot)
65
- clear.click(lambda: [], None, chatbot)
66
 
67
- # .queue() 允许处理排队请求
68
- # api_open=True 是关键,它会自动创建 /run/predict API 端点
69
- print("准备启动Gradio应用...")
70
- demo.queue().launch(api_open=True)
 
4
  import os
5
 
6
  # --- 1. 配置与模型加载 ---
7
+ # 假设运行环境的硬件资源是充足的。
8
  MODEL_ID = os.getenv("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b")
9
+ print(f"INFO: 正在加载模型: {MODEL_ID}")
10
 
11
+ # 使用 try-except 来捕获任何可能的加载错误 (例如网络问题、模型名称错误等)
12
  try:
13
+ # 加载分词器和模型
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
15
+ # device_map="auto" 会自动利用可用的硬件 (如 CPU 或 GPU)
16
  model = AutoModelForCausalLM.from_pretrained(
17
  MODEL_ID,
18
+ torch_dtype="auto", # 自动选择最佳数据类型
19
  device_map="auto",
20
  trust_remote_code=True
21
  )
22
+ print("INFO: 模型和分词器加载成功!")
23
 
24
+ # 将核心推理逻辑定义为一个函数
25
+ # 只有在模型成功加载后,这个函数才会被有效定义
26
  def predict(prompt: str, history: list[list[str]]):
27
  """
28
+ 接收用户输入和对话历史,返回更新后的完整对话历史。
29
+ Gradio 会自动为这个函数创建 API 端点。
30
  """
31
+ print(f"INFO: 收到API/UI请求: prompt='{prompt}'")
32
 
33
+ # 1. 构建符合模型要求的消息列表
34
  messages = []
35
  for user_message, bot_message in history:
36
  messages.append({"role": "user", "content": user_message})
37
  messages.append({"role": "assistant", "content": bot_message})
38
  messages.append({"role": "user", "content": prompt})
39
 
40
+ # 2. 应用聊天模板并进行分词
41
  input_ids = tokenizer.apply_chat_template(
42
  messages,
43
  add_generation_prompt=True,
 
45
  return_tensors="pt"
46
  ).to(model.device)
47
 
48
+ # 3. 生成回复
49
+ # 使用简单的 .generate(),不加额外的采样参数以保持简洁
50
  outputs = model.generate(input_ids, max_new_tokens=1024)
51
+
52
+ # 4. 解码生成的文本,跳过输入的token
53
  response_text = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
54
 
55
+ print(f"INFO: 生成回复: {response_text}")
56
 
57
+ # 5. 更新并返回对话历史
58
  history.append([prompt, response_text])
59
  return history
60
 
61
  except Exception as e:
62
+ print(f"FATAL: 加载模型或分词器时发生致命错误: {e}")
63
+ # 如果模型加载失败,则定义一个专门用于报错的函数
64
+ # 这能确保Gradio界面依然可以启动,并向用户显示一个清晰的错误信息
65
  def predict(*args, **kwargs):
66
+ raise gr.Error(f"模型未能加载,应用无法工作。请检查后台日志获取详细错误信息。错误: {e}")
67
+
68
+ # --- 2. 创建并启动 Gradio 应用 ---
69
+ # 使用 gr.Blocks 来自定义界面布局
70
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
71
+ gr.Markdown(f"## 模型聊天机器人\n当前模型: `{MODEL_ID}`")
72
+
73
+ # 定义聊天机器人组件和输入框
74
+ chatbot = gr.Chatbot(label="对话历史", height=600)
75
+ msg_input = gr.Textbox(label="在这里输入你的问题...", placeholder="例如:你好,你是谁?")
76
+ clear_button = gr.Button("清除对话")
77
 
78
+ # 设定组件的交互逻辑
79
+ # 当用户在输入框中按回车时,调用 predict 函数
80
+ msg_input.submit(predict, [msg_input, chatbot], chatbot)
81
+ # 当用户点击“清除对话”按钮时,清空聊天机器人组件
82
+ clear_button.click(lambda: [], None, chatbot)
 
83
 
84
+ # --- 3. 启动应用并开放API ---
85
+ print("INFO: 准备启动Gradio应用...")
86
 
87
+ # .queue() 使应用���够处理多个排队的请求
88
+ # share=True 是解决CORS问题的关键。它会生成一个公开的、已配置好CORS的 .gradio.live 网址。
89
+ # api_open=True 明确地开启所有公共函数的API访问功能。
90
+ demo.queue().launch(share=True, api_open=True)