caobin committed on
Commit
968eb3d
·
verified ·
1 Parent(s): 346f904

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -57
app.py CHANGED
@@ -1,87 +1,89 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
 
4
 
5
# -------------------------------
# Model loading
# -------------------------------
# Hugging Face Hub model id served by this Space.
MODEL_ID = "caobin/llm-caobin"

# trust_remote_code=True: the model repo ships custom tokenizer/model code.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",  # automatic placement; maps to CPU when no GPU is present
    trust_remote_code=True
)
16
 
17
# -------------------------------
# Helper: normalize chat history
# -------------------------------
def clean_history(history):
    """
    Coerce every message's ``content`` field to a plain string.

    Gradio may hand ``content`` back as a list of fragments; joining them
    into one string avoids a list leaking into the prompt (which produced
    empty answers).
    """
    def _as_text(value):
        # A list of fragments becomes one space-separated string.
        if isinstance(value, list):
            return " ".join(str(part) for part in value)
        return value

    return [
        {"role": entry["role"], "content": _as_text(entry["content"])}
        for entry in history
    ]
32
-
33
# -------------------------------
# Chat function
# -------------------------------
def chat_fn(message, history):
    """
    Build a prompt from recent history plus the new user message, run
    generation, and return the assistant's reply text.

    message: the current user question (str).
    history: list of {"role": ..., "content": ...} dicts from the UI.
    """
    history = clean_history(history)
    recent_history = history[-6:]  # keep the last 3 turns (6 messages = 3 user/assistant pairs)
    full_prompt = ""

    # Flatten history into the model's "<|user|>...<|assistant|>..." format.
    for msg in recent_history:
        if msg["role"] == "user":
            full_prompt += f"<|user|>{msg['content']}<|assistant|>"
        elif msg["role"] == "assistant":
            full_prompt += msg['content']

    # Current user question
    full_prompt += f"<|user|>{message}<|assistant|>"

    # tokenizer -> tensor, moved to the model's device
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    # Generate the answer (sampling, capped at 128 new tokens)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=128,
        temperature=0.5,
        top_p=0.7,
        do_sample=True,
    )

    # NOTE(review): output_ids[0] contains the prompt tokens as well; if
    # "<|assistant|>" is registered as a special token, skip_special_tokens=True
    # strips it and the split below never fires, so the reply may echo the
    # whole prompt — verify against the tokenizer config.
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    if "<|assistant|>" in output_text:
        output_text = output_text.split("<|assistant|>")[-1]

    return output_text.strip()
66
 
67
# -------------------------------
# Gradio UI
# -------------------------------
with gr.Blocks(title="caobin LLM Chatbot") as demo:
    gr.Markdown("# 🤖 caobin's AI assistant")
    chatbot = gr.Chatbot(height=450)
    msg = gr.Textbox(label="输入你的问题")

    def respond(message, chat_history):
        # One chat turn: generate the reply, then record both messages.
        response = chat_fn(message, chat_history)
        # Append messages in dict ("messages") format
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": response})
        # First output clears the textbox; second refreshes the chat panel.
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])

# -------------------------------
# Launch
# -------------------------------
demo.launch()
 
1
  import gradio as gr
 
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
# ===============================
# Model loading
# ===============================
# Hugging Face Hub model id served by this Space.
MODEL_ID = "caobin/llm-caobin"

# trust_remote_code=True: the model repo ships custom tokenizer/model code.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",  # automatic placement; maps to CPU when no GPU is present
    trust_remote_code=True
)

# Inference only — switch off dropout / training-mode layers.
model.eval()
22
+
23
# ===============================
# Core chat logic (tuple-style history)
# ===============================
def chat_fn(message, history):
    """
    Generate one assistant reply for *message* given the chat so far.

    Parameters
    ----------
    message : str
        The user's current question.
    history : list[tuple[str, str]]
        Previous (user, assistant) turns, as appended by the Gradio
        ``respond`` callback.

    Returns
    -------
    str
        The model's reply with surrounding whitespace stripped.
    """
    # Keep only the last 3 turns to bound the prompt length.
    history = history[-3:]

    prompt = ""
    for user, assistant in history:
        prompt += f"<|user|>{user}<|assistant|>{assistant}"

    prompt += f"<|user|>{message}<|assistant|>"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # no_grad: inference only, skip autograd bookkeeping.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=128,
            temperature=0.5,
            top_p=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # BUG FIX: decode only the newly generated tokens. generate() returns
    # prompt + completion; decoding the full sequence with
    # skip_special_tokens=True can strip the "<|assistant|>" markers, in
    # which case the split() fallback below is a no-op and the whole
    # prompt was echoed back to the user.
    prompt_len = inputs["input_ids"].shape[1]
    output_text = tokenizer.decode(
        output_ids[0][prompt_len:],
        skip_special_tokens=True
    )

    # Defensive: if a marker survived decoding, keep only the final reply.
    if "<|assistant|>" in output_text:
        output_text = output_text.split("<|assistant|>")[-1]

    return output_text.strip()
61
 
62
+
63
# ===============================
# Gradio UI
# ===============================
with gr.Blocks(title="caobin LLM Chatbot") as demo:
    gr.Markdown("# 🤖 caobin's AI Assistant")

    chatbot = gr.Chatbot(height=450)
    msg = gr.Textbox(
        label="输入你的问题",
        placeholder="请输入你的问题,支持多轮对话"
    )

    def respond(user_message, chat_history):
        # One chat turn: generate the reply, record the (user, assistant)
        # tuple, and clear the textbox by returning "" as the first output.
        reply = chat_fn(user_message, chat_history)
        chat_history.append((user_message, reply))
        return "", chat_history

    # Pressing Enter in the textbox runs one chat turn.
    msg.submit(respond, [msg, chatbot], [msg, chatbot])
85
 
86
# ===============================
# Launch
# ===============================
# Start the Gradio server for this Space (blocking call).
demo.launch()