caobin commited on
Commit
073fff5
·
verified ·
1 Parent(s): 611d2af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -14
app.py CHANGED
@@ -2,30 +2,55 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
 
 
 
5
  MODEL_ID = "caobin/llm-caobin"
6
 
7
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
8
  model = AutoModelForCausalLM.from_pretrained(
9
  MODEL_ID,
10
- device_map="auto", # CPU 上也可以
11
  trust_remote_code=True
12
  )
13
 
14
- def chat_fn(message, history):
15
- recent_history = history[-6:] # 只保留最近 3 轮
16
- full_prompt = ""
17
- for msg in recent_history:
 
 
 
 
 
18
  content = msg['content']
19
  if isinstance(content, list):
 
20
  content = " ".join([str(c) for c in content])
 
 
 
 
 
 
 
 
 
 
 
 
21
  if msg["role"] == "user":
22
- full_prompt += f"<|user|>{content}<|assistant|>"
23
  elif msg["role"] == "assistant":
24
- full_prompt += content
25
-
 
26
  full_prompt += f"<|user|>{message}<|assistant|>"
27
 
 
28
  inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
 
 
29
  output_ids = model.generate(
30
  **inputs,
31
  max_new_tokens=256,
@@ -33,12 +58,15 @@ def chat_fn(message, history):
33
  top_p=0.9,
34
  do_sample=True,
35
  )
 
36
  output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
37
  if "<|assistant|>" in output_text:
38
  output_text = output_text.split("<|assistant|>")[-1]
39
  return output_text.strip()
40
 
41
-
 
 
42
  with gr.Blocks(title="caobin LLM Chatbot") as demo:
43
  gr.Markdown("# 🤖 caobin's AI assistant")
44
  chatbot = gr.Chatbot(height=450)
@@ -53,9 +81,7 @@ with gr.Blocks(title="caobin LLM Chatbot") as demo:
53
 
54
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
55
 
 
 
 
56
  demo.launch()
57
-
58
-
59
-
60
-
61
-
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
# -------------------------------
# Model loading
# -------------------------------
# Hugging Face Hub id of the chat model served by this Space.
MODEL_ID = "caobin/llm-caobin"

# trust_remote_code=True: the repo ships custom model/tokenizer code that
# must be executed to load it — only safe because the repo is our own.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",  # auto-maps to CPU when no GPU is available
    trust_remote_code=True
)
16
 
17
# -------------------------------
# Helper: normalize chat history
# -------------------------------
def clean_history(history):
    """Coerce each history message's ``content`` into a plain string.

    Gradio (messages format) may deliver ``content`` as a list — or a
    tuple, e.g. file attachments — of parts; passing those through
    unjoined leads to empty/garbled prompts. Non-string scalars are
    coerced with ``str`` and ``None`` becomes the empty string, so the
    returned history is always safe to interpolate into the prompt.

    Args:
        history: iterable of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        A new list of dicts with string ``content``; input not mutated.
    """
    cleaned = []
    for msg in history:
        content = msg["content"]
        if content is None:
            # Defensive: missing content -> empty text.
            content = ""
        elif isinstance(content, (list, tuple)):
            # multi-part content -> single space-joined string
            content = " ".join(str(part) for part in content)
        cleaned.append({"role": msg["role"], "content": str(content)})
    return cleaned
32
+
33
# -------------------------------
# Chat function
# -------------------------------
def chat_fn(message, history):
    """Generate a reply for *message* given the recent chat *history*.

    Builds a ``<|user|>...<|assistant|>`` prompt from the most recent
    3 turns, runs sampling generation, and returns only the newly
    generated text.

    Args:
        message: current user message (str).
        history: Gradio messages-style history (list of role/content dicts).

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    history = clean_history(history)
    recent_history = history[-6:]  # 6 messages == last 3 user/assistant turns

    full_prompt = ""
    for msg in recent_history:
        if msg["role"] == "user":
            full_prompt += f"<|user|>{msg['content']}<|assistant|>"
        elif msg["role"] == "assistant":
            full_prompt += msg["content"]

    # Current user question
    full_prompt += f"<|user|>{message}<|assistant|>"

    # Tokenize and move tensors to the model's device.
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    output_ids = model.generate(
        **inputs,
        max_new_tokens=256,
        # NOTE(review): the diff view elides one generate kwarg here
        # (likely a temperature setting) — restore it if present upstream.
        top_p=0.9,
        do_sample=True,
    )

    # Decode ONLY the newly generated tokens. Decoding the full sequence
    # and splitting on "<|assistant|>" breaks when skip_special_tokens
    # strips that marker, which would echo the entire prompt back to the
    # user; slicing past the prompt length is robust either way.
    prompt_len = inputs["input_ids"].shape[1]
    output_text = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    )
    return output_text.strip()
66
 
67
+ # -------------------------------
68
+ # Gradio UI
69
+ # -------------------------------
70
  with gr.Blocks(title="caobin LLM Chatbot") as demo:
71
  gr.Markdown("# 🤖 caobin's AI assistant")
72
  chatbot = gr.Chatbot(height=450)
 
81
 
82
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
83
 
84
+ # -------------------------------
85
+ # 启动
86
+ # -------------------------------
87
  demo.launch()