tosei0000 committed on
Commit
1448c88
·
verified ·
1 Parent(s): fdd3f6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -11
app.py CHANGED
@@ -1,31 +1,27 @@
1
- # app.py
2
-
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
  import gradio as gr
6
 
7
- model_name = "tosei0000/tosei" # 替换为你的模型路径或名称
8
 
9
- # 加载 tokenizer 和 model
10
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
11
  model = AutoModelForCausalLM.from_pretrained(
12
  model_name,
13
  torch_dtype=torch.bfloat16,
14
- device_map="auto", # 依赖 accelerate
15
  trust_remote_code=True
16
  )
17
 
18
- # 明确设置 pad_token_id
 
 
19
  tokenizer.pad_token_id = tokenizer.eos_token_id
20
  model.config.pad_token_id = tokenizer.eos_token_id
21
 
22
- # 多轮对话函数
23
  def chat(user_input, history):
24
- # 拼接历史
25
  prompt = "".join(
26
  f"User: {u}\nAssistant: {a}\n" for u, a in history
27
  ) + f"User: {user_input}\nAssistant:"
28
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
29
  output = model.generate(
30
  **inputs,
31
  max_new_tokens=256,
@@ -36,12 +32,10 @@ def chat(user_input, history):
36
  eos_token_id=tokenizer.eos_token_id
37
  )
38
  text = tokenizer.decode(output[0], skip_special_tokens=True)
39
- # 提取回复
40
  reply = text[len(prompt):].strip().split("\n")[0]
41
  history.append((user_input, reply))
42
  return history, history
43
 
44
- # Gradio 界面
45
  with gr.Blocks(title="Qwen2 Chatbot") as demo:
46
  gr.Markdown("## 🤖 Qwen2 聊天机器人")
47
  chatbot = gr.Chatbot()
@@ -56,6 +50,7 @@ if __name__ == "__main__":
56
  demo.launch()
57
 
58
 
 
59
  # from transformers import AutoTokenizer, AutoModelForCausalLM
60
  # import torch
61
 
 
 
 
# app.py — load the chat model/tokenizer and prepare them for inference.

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gradio as gr

model_name = "tosei0000/chatbot"

# Pick the device first so the load dtype can match the hardware:
# bfloat16 is appropriate on CUDA GPUs, but CPU inference is far more
# reliable (and much faster) in float32.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    trust_remote_code=True,
)
model = model.to(device)
model.eval()  # inference only: disables dropout and other training-mode behavior

# generate() warns when no pad token is set; reuse EOS as the pad token,
# on both the tokenizer and the model config so they agree.
tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.eos_token_id
19
 
 
20
  def chat(user_input, history):
 
21
  prompt = "".join(
22
  f"User: {u}\nAssistant: {a}\n" for u, a in history
23
  ) + f"User: {user_input}\nAssistant:"
24
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
25
  output = model.generate(
26
  **inputs,
27
  max_new_tokens=256,
 
32
  eos_token_id=tokenizer.eos_token_id
33
  )
34
  text = tokenizer.decode(output[0], skip_special_tokens=True)
 
35
  reply = text[len(prompt):].strip().split("\n")[0]
36
  history.append((user_input, reply))
37
  return history, history
38
 
 
39
  with gr.Blocks(title="Qwen2 Chatbot") as demo:
40
  gr.Markdown("## 🤖 Qwen2 聊天机器人")
41
  chatbot = gr.Chatbot()
 
50
  demo.launch()
51
 
52
 
53
+
54
  # from transformers import AutoTokenizer, AutoModelForCausalLM
55
  # import torch
56