caobin committed on
Commit
3a8e995
·
verified ·
1 Parent(s): f41e6d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -48
app.py CHANGED
@@ -1,84 +1,59 @@
1
-
2
-
3
-
4
-
5
  import gradio as gr
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  import torch
8
 
 
9
  MODEL_ID = "caobin/llm-caobin"
10
 
11
  # 加载 tokenizer 和模型
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
13
-
14
- # 根据是否有 GPU 自动设置 dtype
15
- dtype = torch.float16 if torch.cuda.is_available() else torch.float32
16
- device = "cuda" if torch.cuda.is_available() else "cpu"
17
-
18
  model = AutoModelForCausalLM.from_pretrained(
19
  MODEL_ID,
20
- torch_dtype=dtype,
21
  trust_remote_code=True
22
  )
23
- model.to(device)
24
- model.eval()
25
-
26
- MAX_HISTORY = 3 # 只保留最近几轮对话
27
 
 
28
  def chat_fn(message, history):
29
- """
30
- message: 用户最新输入
31
- history: [{"role": "user"/"assistant", "content": str}, ...]
32
- """
33
- # 只保留最近 MAX_HISTORY 轮
34
- recent_history = history[-MAX_HISTORY*2:] # user+assistant = 2 条消息一轮
35
-
36
- # 拼接 prompt
37
  full_prompt = ""
38
- for msg in recent_history:
39
- if msg["role"] == "user":
40
- full_prompt += f"<|user|>{msg['content']}"
41
- elif msg["role"] == "assistant":
42
- full_prompt += f"<|assistant|>{msg['content']}"
43
  full_prompt += f"<|user|>{message}<|assistant|>"
44
 
45
- inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
 
46
 
47
- # 生成回复
48
  output_ids = model.generate(
49
  **inputs,
50
- max_new_tokens=512,
51
  temperature=0.7,
52
  top_p=0.9,
53
  do_sample=True,
54
- pad_token_id=tokenizer.eos_token_id
55
  )
56
 
57
- # decode 新生成部分
58
- generated_text = tokenizer.decode(output_ids[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
59
-
60
- return generated_text.strip()
61
-
62
- def respond(message, chat_history):
63
- # chat_history 是 Gradio 最新格式 [{"role":..., "content":...}, ...]
64
- response = chat_fn(message, chat_history)
65
- # 更新聊天历史
66
- new_history = chat_history + [
67
- {"role": "user", "content": message},
68
- {"role": "assistant", "content": response}
69
- ]
70
- return "", new_history
71
 
72
- # Gradio 界面
73
  with gr.Blocks(title="caobin LLM Chatbot") as demo:
74
  gr.Markdown("# 🤖 caobin's AI assistant")
75
-
76
- chatbot = gr.Chatbot([], height=450) # 初始化为空列表
77
  msg = gr.Textbox(label="输入你的问题")
78
 
 
 
 
 
 
79
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
80
 
81
  demo.launch()
82
 
83
 
84
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
+ # 模型 ID
6
  MODEL_ID = "caobin/llm-caobin"
7
 
8
  # 加载 tokenizer 和模型
9
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 
 
 
 
10
  model = AutoModelForCausalLM.from_pretrained(
11
  MODEL_ID,
12
+ device_map="auto", # CPU 上会自动映射到 CPU
13
  trust_remote_code=True
14
  )
 
 
 
 
15
 
16
+ # 聊天函数
17
  def chat_fn(message, history):
18
+ # 只保留最近 3 轮历史
19
+ history = history[-3:]
 
 
 
 
 
 
20
  full_prompt = ""
21
+ for user_msg, bot_msg in history:
22
+ full_prompt += f"<|user|>{user_msg}<|assistant|>{bot_msg}"
 
 
 
23
  full_prompt += f"<|user|>{message}<|assistant|>"
24
 
25
+ # tokenizer 转 tensor
26
+ inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
27
 
28
+ # 生成回答
29
  output_ids = model.generate(
30
  **inputs,
31
+ max_new_tokens=256,
32
  temperature=0.7,
33
  top_p=0.9,
34
  do_sample=True,
 
35
  )
36
 
37
+ output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
38
+ if "<|assistant|>" in output_text:
39
+ output_text = output_text.split("<|assistant|>")[-1]
40
+ return output_text.strip()
 
 
 
 
 
 
 
 
 
 
41
 
42
+ # Gradio UI
43
  with gr.Blocks(title="caobin LLM Chatbot") as demo:
44
  gr.Markdown("# 🤖 caobin's AI assistant")
45
+ chatbot = gr.Chatbot(height=450)
 
46
  msg = gr.Textbox(label="输入你的问题")
47
 
48
+ def respond(message, chat_history):
49
+ response = chat_fn(message, chat_history)
50
+ chat_history.append((message, response))
51
+ return "", chat_history
52
+
53
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
54
 
55
  demo.launch()
56
 
57
 
58
 
59
+