DDDDEvvvvv commited on
Commit
59ffec9
·
verified ·
1 Parent(s): 995c510

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -14
app.py CHANGED
@@ -2,28 +2,55 @@ import gradio as gr
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  import torch
4
 
 
5
  model_name = "facebook/blenderbot-400M-distill"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
8
 
9
- history = []
 
 
10
 
11
- def respond(message):
12
- global history
 
 
 
 
 
13
  history.append({"role": "user", "content": message})
14
- # Only use last 3 messages for context
15
- last_msgs = history[-3:]
16
- input_text = " ".join([m["content"] for m in last_msgs])
17
- inputs = tokenizer(input_text, return_tensors="pt")
18
- outputs = model.generate(**inputs, max_new_tokens=100)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
20
  history.append({"role": "assistant", "content": response_text})
21
- return history
22
 
23
  def reset_chat():
24
- global history
25
- history = []
26
- return history
27
 
28
  with gr.Blocks(css="""
29
  body {background-color: #000 !important; color: #fff !important;}
@@ -35,10 +62,24 @@ body {background-color: #000 !important; color: #fff !important;}
35
  .gr-button {background-color: #0ff !important; color: #000 !important; border-radius: 8px;}
36
  footer {display: none !important;}
37
  """) as demo:
 
 
 
38
  chatbot = gr.Chatbot(label="DevMegaBlack")
39
  msg = gr.Textbox(placeholder="Say something...")
40
  reset_btn = gr.Button("Reset Chat")
41
- msg.submit(respond, msg, chatbot)
42
- reset_btn.click(reset_chat, [], chatbot)
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  demo.launch()
 
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  import torch
4
 
5
+ # Model setup
6
  model_name = "facebook/blenderbot-400M-distill"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
 
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ model.to(device)
12
+ model.eval()
13
 
14
+ # Optional speed boost on GPU
15
+ if device == "cuda":
16
+ model = model.half()
17
+
18
+ persona = "You are a helpful, concise, friendly assistant."
19
+
20
+ def respond(message, history):
21
  history.append({"role": "user", "content": message})
22
+
23
+ # Build context from last 3 turns
24
+ context = persona + "\n"
25
+ for msg in history[-6:]:
26
+ role = "User" if msg["role"] == "user" else "Bot"
27
+ context += f"{role}: {msg['content']}\n"
28
+ context += "Bot:"
29
+
30
+ inputs = tokenizer(
31
+ context,
32
+ return_tensors="pt",
33
+ truncation=True,
34
+ max_length=512
35
+ ).to(device)
36
+
37
+ with torch.no_grad():
38
+ outputs = model.generate(
39
+ **inputs,
40
+ max_new_tokens=120,
41
+ do_sample=True,
42
+ temperature=0.7,
43
+ top_p=0.9,
44
+ repetition_penalty=1.1
45
+ )
46
+
47
  response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
48
+
49
  history.append({"role": "assistant", "content": response_text})
50
+ return history, history
51
 
52
  def reset_chat():
53
+ return [], []
 
 
54
 
55
  with gr.Blocks(css="""
56
  body {background-color: #000 !important; color: #fff !important;}
 
62
  .gr-button {background-color: #0ff !important; color: #000 !important; border-radius: 8px;}
63
  footer {display: none !important;}
64
  """) as demo:
65
+
66
+ state = gr.State([])
67
+
68
  chatbot = gr.Chatbot(label="DevMegaBlack")
69
  msg = gr.Textbox(placeholder="Say something...")
70
  reset_btn = gr.Button("Reset Chat")
 
 
71
 
72
+ msg.submit(
73
+ respond,
74
+ [msg, state],
75
+ [chatbot, state]
76
+ ).then(lambda: "", None, msg)
77
+
78
+ reset_btn.click(
79
+ reset_chat,
80
+ [],
81
+ [chatbot, state]
82
+ )
83
+
84
+ demo.queue()
85
  demo.launch()