caobin committed on
Commit
a622589
·
verified ·
1 Parent(s): ee595ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -1,44 +1,44 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
3
  import torch
4
 
5
  MODEL_ID = "caobin/llm-caobin"
6
 
7
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 
 
 
8
  model = AutoModelForCausalLM.from_pretrained(
9
  MODEL_ID,
10
  torch_dtype=torch.float16,
11
- device_map="auto"
 
12
  )
13
 
14
  def chat_fn(message, history):
15
- input_text = ""
16
-
17
-
18
  for user_msg, bot_msg in history:
19
- input_text += f"<|user|>{user_msg}<|assistant|>{bot_msg}"
20
- input_text += f"<|user|>{message}<|assistant|>"
21
 
22
- inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
23
 
24
  output_ids = model.generate(
25
  **inputs,
26
  max_new_tokens=512,
27
- do_sample=True,
28
  temperature=0.7,
29
  top_p=0.9,
 
30
  )
31
 
32
  output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
33
-
34
- # 只取 assistant 最新的回答
35
  if "<|assistant|>" in output_text:
36
  output_text = output_text.split("<|assistant|>")[-1]
37
 
38
- return output_text
39
 
40
 
41
- # Gradio UI
42
  with gr.Blocks(title="caobin LLM chatbot") as demo:
43
  gr.Markdown("# 🤖 caobin 自定义 LLM 对话 Demo")
44
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
  MODEL_ID = "caobin/llm-caobin"
6
 
7
+ tokenizer = AutoTokenizer.from_pretrained(
8
+ MODEL_ID,
9
+ trust_remote_code=True
10
+ )
11
+
12
  model = AutoModelForCausalLM.from_pretrained(
13
  MODEL_ID,
14
  torch_dtype=torch.float16,
15
+ device_map="auto",
16
+ trust_remote_code=True
17
  )
18
 
19
  def chat_fn(message, history):
20
+ full_prompt = ""
 
 
21
  for user_msg, bot_msg in history:
22
+ full_prompt += f"<|user|>{user_msg}<|assistant|>{bot_msg}"
23
+ full_prompt += f"<|user|>{message}<|assistant|>"
24
 
25
+ inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
26
 
27
  output_ids = model.generate(
28
  **inputs,
29
  max_new_tokens=512,
 
30
  temperature=0.7,
31
  top_p=0.9,
32
+ do_sample=True,
33
  )
34
 
35
  output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 
36
  if "<|assistant|>" in output_text:
37
  output_text = output_text.split("<|assistant|>")[-1]
38
 
39
+ return output_text.strip()
40
 
41
 
 
42
  with gr.Blocks(title="caobin LLM chatbot") as demo:
43
  gr.Markdown("# 🤖 caobin 自定义 LLM 对话 Demo")
44