caobin commited on
Commit
22dabde
·
verified ·
1 Parent(s): 4e7b745

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -40
app.py CHANGED
@@ -1,68 +1,87 @@
1
  import gradio as gr
 
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
 
 
 
5
  MODEL_ID = "caobin/llm-caobin"
6
 
7
- tokenizer = AutoTokenizer.from_pretrained(
8
- MODEL_ID,
9
- trust_remote_code=True
10
- )
11
-
12
  model = AutoModelForCausalLM.from_pretrained(
13
  MODEL_ID,
14
- device_map="auto",
15
  trust_remote_code=True
16
  )
17
 
18
- model.eval()
19
-
20
- def chat_fn(message, history):
 
21
  """
22
- history: List[Tuple[user, assistant]]
23
  """
24
- history = history[-3:]
25
-
26
- prompt = ""
27
- for user, assistant in history:
28
- prompt += f"<|user|>{user}<|assistant|>{assistant}"
29
-
30
- prompt += f"<|user|>{message}<|assistant|>"
31
-
32
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
33
-
34
- with torch.no_grad():
35
- output_ids = model.generate(
36
- **inputs,
37
- max_new_tokens=128,
38
- temperature=0.5,
39
- top_p=0.7,
40
- do_sample=True,
41
- pad_token_id=tokenizer.eos_token_id,
42
- )
43
-
44
- output_text = tokenizer.decode(
45
- output_ids[0],
46
- skip_special_tokens=True
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  )
48
 
 
49
  if "<|assistant|>" in output_text:
50
  output_text = output_text.split("<|assistant|>")[-1]
51
-
52
  return output_text.strip()
53
 
54
-
 
 
55
  with gr.Blocks(title="caobin LLM Chatbot") as demo:
56
- gr.Markdown("# 🤖 caobin's AI Assistant")
57
-
58
  chatbot = gr.Chatbot(height=450)
59
  msg = gr.Textbox(label="输入你的问题")
60
 
61
  def respond(message, chat_history):
62
  response = chat_fn(message, chat_history)
63
- chat_history.append((message, response))
 
 
64
  return "", chat_history
65
 
66
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
67
 
68
- demo.launch()
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
 
4
 
5
+ # -------------------------------
6
+ # 模型加载
7
+ # -------------------------------
8
  MODEL_ID = "caobin/llm-caobin"
9
 
10
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 
 
 
11
  model = AutoModelForCausalLM.from_pretrained(
12
  MODEL_ID,
13
+ device_map="auto", # CPU 上自动映射到 CPU
14
  trust_remote_code=True
15
  )
16
 
17
# -------------------------------
# Helper: normalize chat history
# -------------------------------
def clean_history(history):
    """Return a copy of *history* whose ``content`` values are plain strings.

    Gradio can hand a message's content back as a list of parts; joining the
    parts into one space-separated string keeps the prompt-building step from
    producing an empty answer.
    """
    normalized = []
    for entry in history:
        text = entry['content']
        if isinstance(text, list):
            # list of parts -> single space-separated string
            text = " ".join(str(part) for part in text)
        normalized.append({"role": entry['role'], "content": text})
    return normalized
32
+
33
# -------------------------------
# Chat function
# -------------------------------
def chat_fn(message, history):
    """Generate a reply to *message* given the Gradio chat *history*.

    Parameters
    ----------
    message : str
        The user's current input.
    history : list[dict]
        Messages in ``{"role": ..., "content": ...}`` form; ``content`` may
        arrive as a list of parts and is normalized to a string first.

    Returns
    -------
    str
        The model's reply with surrounding whitespace stripped.
    """
    history = clean_history(history)
    recent_history = history[-6:]  # keep the last 3 user/assistant turns
    full_prompt = ""

    for msg in recent_history:
        if msg["role"] == "user":
            full_prompt += f"<|user|>{msg['content']}<|assistant|>"
        elif msg["role"] == "assistant":
            full_prompt += msg['content']

    # Current user question
    full_prompt += f"<|user|>{message}<|assistant|>"

    # Tokenize and move the tensors onto the model's device
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    # no_grad: inference only — avoids building an autograd graph (this guard
    # existed in the previous revision and was dropped by mistake).
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=128,
            temperature=0.5,
            top_p=0.7,
            do_sample=True,
            # Silences the "pad_token_id not set" warning for models
            # that define no pad token of their own.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode ONLY the newly generated tokens. If "<|assistant|>" is a special
    # token, skip_special_tokens strips it from the decoded text, so splitting
    # the full sequence on it would return prompt + answer instead of just the
    # answer.
    prompt_len = inputs["input_ids"].shape[1]
    output_text = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    # Fallback for tokenizers where the marker survives decoding as plain text
    if "<|assistant|>" in output_text:
        output_text = output_text.split("<|assistant|>")[-1]

    return output_text.strip()
66
 
67
# -------------------------------
# Gradio UI
# -------------------------------
with gr.Blocks(title="caobin LLM Chatbot") as demo:
    gr.Markdown("# 🤖 caobin's AI assistant")

    # type="messages" matches the {"role", "content"} dicts appended in
    # respond() below; the default tuple format rejects them in Gradio 4+.
    chatbot = gr.Chatbot(height=450, type="messages")
    msg = gr.Textbox(label="输入你的问题")

    def respond(message, chat_history):
        """Run one chat turn: append the user/assistant pair, clear the box."""
        response = chat_fn(message, chat_history)
        # Append both sides of the turn in openai-style dict format
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": response})
        return "", chat_history

    # Submitting the textbox sends (msg, chatbot) in and writes back
    # ("", updated_history), which also clears the input field.
    msg.submit(respond, [msg, chatbot], [msg, chatbot])

# -------------------------------
# Launch
# -------------------------------
demo.launch()