caobin committed on
Commit
a90f54e
·
verified ·
1 Parent(s): 31dc697

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -26
app.py CHANGED
@@ -1,55 +1,70 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
 
4
 
5
- MODEL_ID = "caobin/llm-caobin"
6
-
7
- tokenizer = AutoTokenizer.from_pretrained(
8
- MODEL_ID,
9
- trust_remote_code=True
10
- )
11
 
 
 
12
  model = AutoModelForCausalLM.from_pretrained(
13
  MODEL_ID,
14
  torch_dtype=torch.float16,
15
- device_map="auto",
16
  trust_remote_code=True
17
  )
18
 
19
- def chat_fn(message, history):
 
 
 
 
 
 
 
 
20
  full_prompt = ""
21
- for user_msg, bot_msg in history:
22
  full_prompt += f"<|user|>{user_msg}<|assistant|>{bot_msg}"
23
- full_prompt += f"<|user|>{message}<|assistant|>"
24
 
25
- inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
26
 
27
- output_ids = model.generate(
 
 
28
  **inputs,
29
- max_new_tokens=512,
30
- temperature=0.7,
31
- top_p=0.9,
 
32
  do_sample=True,
33
  )
34
 
35
- output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
36
- if "<|assistant|>" in output_text:
37
- output_text = output_text.split("<|assistant|>")[-1]
38
 
39
- return output_text.strip()
 
 
 
 
40
 
 
 
 
 
 
 
 
 
41
 
42
- with gr.Blocks(title="caobin LLM chatbot") as demo:
 
43
  gr.Markdown("# 🤖 caobin 自定义 LLM 对话 Demo")
44
 
45
  chatbot = gr.Chatbot(height=450)
46
  msg = gr.Textbox(label="输入你的问题")
47
 
48
- def respond(message, chat_history):
49
- response = chat_fn(message, chat_history)
50
- chat_history.append((message, response))
51
- return "", chat_history
52
-
53
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
54
 
55
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
  import torch
4
+ import threading
5
 
6
+ MODEL_ID = "caobin/llm-caobin"
 
 
 
 
 
7
 
8
+ # 加载 tokenizer 和模型
9
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
10
  model = AutoModelForCausalLM.from_pretrained(
11
  MODEL_ID,
12
  torch_dtype=torch.float16,
 
13
  trust_remote_code=True
14
  )
15
 
16
+ # 判断是否有 GPU
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ model.to(device)
19
+ model.eval()
20
+
21
# Streaming generation: yields the accumulated reply as tokens arrive.
def generate_stream(prompt, max_new_tokens=512, temperature=0.7, top_p=0.9, max_history=3, history=None):
    """Generate a reply to *prompt*, yielding the partial text as it streams.

    history is a list of (user_msg, bot_msg) pairs; only the most recent
    max_history turns are folded into the prompt to bound its length.
    """
    # BUG FIX: the original used a mutable default (history=[]), which is
    # shared across calls and silently accumulates state. Use None as the
    # sentinel instead (backward compatible: omitting history still means
    # "no prior turns").
    if history is None:
        history = []

    # Keep only the most recent max_history conversation turns.
    recent_history = history[-max_history:]
    full_prompt = ""
    for user_msg, bot_msg in recent_history:
        full_prompt += f"<|user|>{user_msg}<|assistant|>{bot_msg}"
    full_prompt += f"<|user|>{prompt}<|assistant|>"

    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)

    # BUG FIX: skip_prompt=True so the streamer yields only newly generated
    # text. Without it the echoed prompt leaks into the reply — the
    # pre-streaming implementation stripped it by splitting on
    # "<|assistant|>", and that safeguard was dropped in the rewrite.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # Run generation on a background thread; this generator consumes the
    # streamer on the calling thread.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # Accumulate and yield the growing reply text.
    output_text = ""
    for new_text in streamer:
        output_text += new_text
        yield output_text.strip()
51
 
52
# Gradio callback: streams the partial reply back into the Chatbot.
def respond(message, chat_history):
    """Yield ("", updated_history) pairs as the model's reply streams in."""
    for partial_reply in generate_stream(message, history=chat_history):
        yield "", chat_history + [(message, partial_reply)]
60
 
61
# Assemble the Gradio UI: a chat window plus a text box wired to respond().
with gr.Blocks(title="caobin LLM Chatbot") as demo:
    gr.Markdown("# 🤖 caobin 自定义 LLM 对话 Demo")

    chatbot = gr.Chatbot(height=450)
    msg = gr.Textbox(label="输入你的问题")

    # Submitting the textbox streams generator output into the chatbot
    # and clears the input field.
    msg.submit(respond, [msg, chatbot], [msg, chatbot])

demo.launch()