caobin commited on
Commit
f41e6d1
·
verified ·
1 Parent(s): b5e4873

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -38
app.py CHANGED
@@ -1,70 +1,84 @@
 
 
 
 
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
  import torch
4
- import threading
5
 
6
  MODEL_ID = "caobin/llm-caobin"
7
 
8
  # 加载 tokenizer 和模型
9
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 
 
 
 
10
  model = AutoModelForCausalLM.from_pretrained(
11
  MODEL_ID,
12
- torch_dtype=torch.float16,
13
  trust_remote_code=True
14
  )
15
-
16
- # 判断是否有 GPU
17
- device = "cuda" if torch.cuda.is_available() else "cpu"
18
  model.to(device)
19
  model.eval()
20
 
21
- # 边生成边输出的函数
22
- def generate_stream(prompt, max_new_tokens=512, temperature=0.7, top_p=0.9, max_history=3, history=[]):
23
- # 只保留最近 max_history 轮对话
24
- recent_history = history[-max_history:]
 
 
 
 
 
 
 
25
  full_prompt = ""
26
- for user_msg, bot_msg in recent_history:
27
- full_prompt += f"<|user|>{user_msg}<|assistant|>{bot_msg}"
28
- full_prompt += f"<|user|>{prompt}<|assistant|>"
 
 
 
29
 
30
  inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
31
 
32
- # 使用流式输出
33
- streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
34
- generate_kwargs = dict(
35
  **inputs,
36
- streamer=streamer,
37
- max_new_tokens=max_new_tokens,
38
- temperature=temperature,
39
- top_p=top_p,
40
  do_sample=True,
 
41
  )
42
 
43
- thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
44
- thread.start()
45
 
46
- # 边生成边返回文本
47
- output_text = ""
48
- for new_text in streamer:
49
- output_text += new_text
50
- yield output_text.strip()
51
 
52
- # Gradio 回调函数
53
  def respond(message, chat_history):
54
- # 返回一个生成器,用于流式更新
55
- generator = generate_stream(message, history=chat_history)
56
- bot_response = ""
57
- for partial in generator:
58
- bot_response = partial
59
- yield "", chat_history + [(message, bot_response)]
60
-
61
- # 创建 Gradio 界面
 
 
62
  with gr.Blocks(title="caobin LLM Chatbot") as demo:
63
- gr.Markdown("# 🤖 caobins AI assistant")
64
 
65
- chatbot = gr.Chatbot(height=450)
66
  msg = gr.Textbox(label="输入你的问题")
67
 
68
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
69
 
70
  demo.launch()
 
 
 
 
1
+
2
+
3
+
4
+
5
  import gradio as gr
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
  import torch
 
8
 
9
  MODEL_ID = "caobin/llm-caobin"
10
 
11
  # 加载 tokenizer 和模型
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
13
+
14
+ # 根据是否有 GPU 自动设置 dtype
15
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+
18
  model = AutoModelForCausalLM.from_pretrained(
19
  MODEL_ID,
20
+ torch_dtype=dtype,
21
  trust_remote_code=True
22
  )
 
 
 
23
  model.to(device)
24
  model.eval()
25
 
26
+ MAX_HISTORY = 3 # 只保留最近几轮对话
27
+
28
+ def chat_fn(message, history):
29
+ """
30
+ message: 用户最新输入
31
+ history: [{"role": "user"/"assistant", "content": str}, ...]
32
+ """
33
+ # 只保留最近 MAX_HISTORY 轮
34
+ recent_history = history[-MAX_HISTORY*2:] # user+assistant = 2 条消息一轮
35
+
36
+ # 拼接 prompt
37
  full_prompt = ""
38
+ for msg in recent_history:
39
+ if msg["role"] == "user":
40
+ full_prompt += f"<|user|>{msg['content']}"
41
+ elif msg["role"] == "assistant":
42
+ full_prompt += f"<|assistant|>{msg['content']}"
43
+ full_prompt += f"<|user|>{message}<|assistant|>"
44
 
45
  inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
46
 
47
+ # 生成回复
48
+ output_ids = model.generate(
 
49
  **inputs,
50
+ max_new_tokens=512,
51
+ temperature=0.7,
52
+ top_p=0.9,
 
53
  do_sample=True,
54
+ pad_token_id=tokenizer.eos_token_id
55
  )
56
 
57
+ # decode 新生成部分
58
+ generated_text = tokenizer.decode(output_ids[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
59
 
60
+ return generated_text.strip()
 
 
 
 
61
 
 
62
  def respond(message, chat_history):
63
+ # chat_history 是 Gradio 最新格式 [{"role":..., "content":...}, ...]
64
+ response = chat_fn(message, chat_history)
65
+ # 更新聊天历史
66
+ new_history = chat_history + [
67
+ {"role": "user", "content": message},
68
+ {"role": "assistant", "content": response}
69
+ ]
70
+ return "", new_history
71
+
72
+ # Gradio 界面
73
  with gr.Blocks(title="caobin LLM Chatbot") as demo:
74
+ gr.Markdown("# 🤖 caobin's AI assistant")
75
 
76
+ chatbot = gr.Chatbot([], height=450) # 初始化为空列表
77
  msg = gr.Textbox(label="输入你的问题")
78
 
79
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
80
 
81
  demo.launch()
82
+
83
+
84
+