badanwang commited on
Commit
43e2809
·
verified ·
1 Parent(s): d92cc67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -23
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import gradio as gr
3
  from threading import Thread
4
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
@@ -6,34 +5,25 @@ import torch
6
  import os
7
 
8
  # --- 配置 ---
9
- # 我们不再需要API Token,因为模型在本地运行
10
  MODEL_ID = "badanwang/teacher_basic_qwen3-0.6b"
11
 
12
-
13
  print("开始加载模型和分词器...")
14
  try:
15
- # 确保使用 trust_remote_code=True,因为Qwen模型需要加载自定义代码
16
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
17
  model = AutoModelForCausalLM.from_pretrained(
18
  MODEL_ID,
19
- torch_dtype="auto", # 使用适合CPU的类型,如torch.float32
20
- device_map="auto", # 自动将模型加载到可用设备(这里是CPU)
21
  trust_remote_code=True
22
  )
23
  print("模型和分词器加载成功!")
24
  except Exception as e:
25
  print(f"模型加载失败: {e}")
26
- # 如果模型加载失败,应用将无法工作,这里可以抛出异常或退出
27
  raise gr.Error(f"关键错误:无法加载模型 {MODEL_ID}。错误信息: {e}")
28
 
29
-
30
  # --- 核心对话函数 ---
31
  def predict(message, history):
32
- """
33
- 主函数,使用加载到本地的模型进行流式对话。
34
- """
35
- # 1. 格式化对话历史
36
- # Qwen的模板要求一个特殊的列表格式
37
  messages = []
38
  for turn in history:
39
  user_msg, assistant_msg = turn
@@ -41,17 +31,14 @@ def predict(message, history):
41
  messages.append({"role": "assistant", "content": assistant_msg})
42
  messages.append({"role": "user", "content": message})
43
 
44
- # 使用分词器的 apply_chat_template 方法来正确格式化输入
45
  model_inputs = tokenizer.apply_chat_template(
46
  messages,
47
  add_generation_prompt=True,
48
  return_tensors="pt"
49
- ).to(model.device) # 确保输入张量和模型在同一设备上
50
 
51
- # 2. 设置流式输出
52
  streamer = TextIteratorStreamer(tokenizer, timeout=300.0, skip_prompt=True, skip_special_tokens=True)
53
 
54
- # 3. 在一个单独的线程中运行生成,以避免阻塞UI
55
  generation_kwargs = dict(
56
  inputs=model_inputs,
57
  streamer=streamer,
@@ -63,21 +50,19 @@ def predict(message, history):
63
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
64
  thread.start()
65
 
66
- # 4. 从streamer中yield每个新生成的token
67
  full_response = ""
68
  for new_text in streamer:
69
  full_response += new_text
70
  yield full_response
71
 
72
  # --- 创建并启动Gradio界面 ---
 
73
  demo = gr.ChatInterface(
74
  fn=predict,
75
- title="老师傅",
76
  description=f"直接在Space中运行 {MODEL_ID} 模型进行流式对话。CPU推理可能较慢,请耐心等待。",
77
- examples=[["你好"], ["请用python写一个快速排序算法"], ["给我讲个笑话吧"]],
78
- cache_examples=False,
79
  )
80
 
81
  if __name__ == "__main__":
82
- demo.launch(# 这是修改后的代码
83
- demo.launch(share=True))
 
 
1
  import gradio as gr
2
  from threading import Thread
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
5
  import os
6
 
7
# --- Configuration ---
# Hugging Face Hub repo of the fine-tuned Qwen3-0.6B model loaded at startup.
MODEL_ID = "badanwang/teacher_basic_qwen3-0.6b"
9
 
10
# --- Load model and tokenizer ---
print("开始加载模型和分词器...")
try:
    # trust_remote_code=True: Qwen checkpoints ship custom modeling code
    # (per the original author's note in an earlier revision).
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",   # let transformers pick a dtype suited to the device (CPU here)
        device_map="auto",    # place the model on whatever device is available
        trust_remote_code=True,
    )
    print("模型和分词器加载成功!")
except Exception as e:
    print(f"模型加载失败: {e}")
    # Fail fast at startup — the app cannot work without the model.
    # Chain the original exception (`from e`) so the startup log shows
    # the real cause; the original raise dropped the explicit chain.
    raise gr.Error(f"关键错误:无法加载模型 {MODEL_ID}。错误信息: {e}") from e
24
 
 
25
  # --- 核心对话函数 ---
26
  def predict(message, history):
 
 
 
 
 
27
  messages = []
28
  for turn in history:
29
  user_msg, assistant_msg = turn
 
31
  messages.append({"role": "assistant", "content": assistant_msg})
32
  messages.append({"role": "user", "content": message})
33
 
 
34
  model_inputs = tokenizer.apply_chat_template(
35
  messages,
36
  add_generation_prompt=True,
37
  return_tensors="pt"
38
+ ).to(model.device)
39
 
 
40
  streamer = TextIteratorStreamer(tokenizer, timeout=300.0, skip_prompt=True, skip_special_tokens=True)
41
 
 
42
  generation_kwargs = dict(
43
  inputs=model_inputs,
44
  streamer=streamer,
 
50
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
51
  thread.start()
52
 
 
53
  full_response = ""
54
  for new_text in streamer:
55
  full_response += new_text
56
  yield full_response
57
 
58
# --- Build and launch the Gradio UI ---
# NOTE(review): `examples`/`cache_examples` were removed in this revision,
# reportedly to fix an error raised when clicking an example.
demo = gr.ChatInterface(
    fn=predict,
    title="小Q老师 - 基础问答 (本地加载)",
    description=f"直接在Space中运行 {MODEL_ID} 模型进行流式对话。CPU推理可能较慢,请耐心等待。",
)

if __name__ == "__main__":
    # share=True asks Gradio to open a temporary public *.gradio.live tunnel.
    # It does NOT control cross-origin WebSocket access (the original comment
    # was wrong), and inside a Hugging Face Space — which is already publicly
    # hosted — it is unnecessary and typically ignored with a warning.
    demo.launch(share=True)