tosei0000 committed on
Commit
be5d4bb
·
verified ·
1 Parent(s): 9b5ea7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -33
app.py CHANGED
@@ -1,33 +1,31 @@
 
 
1
  from transformers import AutoTokenizer, AutoModelForCausalLM
2
  import torch
3
  import gradio as gr
4
 
5
- # 模型名称(可以换成你自己的Qwen2模型)
6
- model_name = "tosei0000/chatbot"
7
 
8
  # 加载 tokenizer 和 model
9
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
10
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
11
-
12
- # 设置 pad_token_id(避免警告和生成错误)
 
 
 
 
 
13
  tokenizer.pad_token_id = tokenizer.eos_token_id
14
  model.config.pad_token_id = tokenizer.eos_token_id
15
 
16
- # 聊天历史存储
17
- chat_history = []
18
-
19
- # 多轮对话生成函数
20
  def chat(user_input, history):
21
- # 构造 prompt(把历史拼接起来)
22
- prompt = ""
23
- for i, (user_msg, bot_msg) in enumerate(history):
24
- prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
25
- prompt += f"User: {user_input}\nAssistant:"
26
-
27
- # 编码输入
28
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
29
-
30
- # 生成
31
  output = model.generate(
32
  **inputs,
33
  max_new_tokens=256,
@@ -37,30 +35,23 @@ def chat(user_input, history):
37
  pad_token_id=tokenizer.pad_token_id,
38
  eos_token_id=tokenizer.eos_token_id
39
  )
40
-
41
- # 解码
42
- decoded = tokenizer.decode(output[0], skip_special_tokens=True)
43
-
44
- # 提取模型最新回复部分(去掉前面的prompt)
45
- response = decoded[len(prompt):].strip().split("\n")[0]
46
-
47
- # 更新历史
48
- history.append((user_input, response))
49
  return history, history
50
 
51
- # 创建 Gradio 接口
52
- with gr.Blocks(title="Qwen2 聊天机器人") as demo:
53
- gr.Markdown("## 🤖 Qwen2 Chatbot")
54
  chatbot = gr.Chatbot()
55
  msg = gr.Textbox(label="输入你的问题")
56
  clear = gr.Button("清除对话")
57
-
58
- state = gr.State([]) # 存储历史
59
 
60
  msg.submit(chat, [msg, state], [chatbot, state])
61
  clear.click(lambda: ([], []), None, [chatbot, state])
62
 
63
- # 启动 Gradio
64
  if __name__ == "__main__":
65
  demo.launch()
66
 
 
1
# app.py — minimal Gradio chat demo around a causal language model.

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gradio as gr

# Model repo id; swap in your own checkpoint path or name if needed.
model_name = "Qwen/Qwen2-1.5B"

# Load tokenizer and model once at startup.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # relies on the `accelerate` package
    trust_remote_code=True,
)

# Explicitly use EOS as the pad token: many chat checkpoints ship without a
# dedicated pad token, and generate() warns / misbehaves without one.
tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.eos_token_id
21
 
22
+ # 多轮对话函数
 
 
 
23
# Multi-turn chat callback.
def chat(user_input, history):
    """Run one chat turn: build a prompt from the history, generate a reply.

    Args:
        user_input: The user's latest message (str).
        history: List of (user_msg, assistant_msg) tuples from prior turns;
            mutated in place by appending the new turn.

    Returns:
        (history, history): the updated history twice, matching the Gradio
        [chatbot, state] output wiring.
    """
    # Flatten the turn history into a plain-text transcript prompt.
    prompt = "".join(
        f"User: {u}\nAssistant: {a}\n" for u, a in history
    ) + f"User: {user_input}\nAssistant:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # NOTE(review): the diff view collapsed a few generate() kwargs here
    # (original lines ~32-34, likely sampling settings) — confirm against
    # the repository before relying on this exact call.
    output = model.generate(
        **inputs,
        max_new_tokens=256,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Fix: decode only the newly generated token ids instead of slicing the
    # decoded string with len(prompt). Decoding does not always round-trip
    # the prompt text exactly (special-token handling, whitespace
    # normalization), so the old `decoded[len(prompt):]` could cut the reply
    # at the wrong offset.
    prompt_len = inputs["input_ids"].shape[1]
    reply = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    # Keep only the first line so a hallucinated next "User:" turn is dropped.
    reply = reply.strip().split("\n")[0]

    history.append((user_input, reply))
    return history, history
43
 
44
# Assemble the Gradio UI.
with gr.Blocks(title="Qwen2 Chatbot") as demo:
    gr.Markdown("## 🤖 Qwen2 聊天机器人")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="输入你的问题")
    clear = gr.Button("清除对话")
    state = gr.State([])  # per-session conversation history

    # Pressing Enter in the textbox runs one chat turn.
    msg.submit(chat, [msg, state], [chatbot, state])
    # The clear button resets both the visible log and the stored history.
    clear.click(lambda: ([], []), None, [chatbot, state])

if __name__ == "__main__":
    demo.launch()
57