DDDDEvvvvv committed on
Commit
a30a54f
·
verified ·
1 Parent(s): 59ffec9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -15
app.py CHANGED
@@ -2,29 +2,53 @@ import gradio as gr
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  import torch
4
 
5
- # Model setup
 
6
  model_name = "facebook/blenderbot-400M-distill"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
11
  model.to(device)
12
  model.eval()
13
 
14
- # Optional speed boost on GPU
15
  if device == "cuda":
16
  model = model.half()
17
 
18
  persona = "You are a helpful, concise, friendly assistant."
19
 
 
 
20
  def respond(message, history):
 
21
  history.append({"role": "user", "content": message})
22
 
23
- # Build context from last 3 turns
 
 
 
 
 
24
  context = persona + "\n"
25
- for msg in history[-6:]:
26
- role = "User" if msg["role"] == "user" else "Bot"
27
- context += f"{role}: {msg['content']}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  context += "Bot:"
29
 
30
  inputs = tokenizer(
@@ -34,32 +58,56 @@ def respond(message, history):
34
  max_length=512
35
  ).to(device)
36
 
 
 
37
  with torch.no_grad():
38
  outputs = model.generate(
39
  **inputs,
40
- max_new_tokens=120,
41
  do_sample=True,
42
- temperature=0.7,
43
- top_p=0.9,
44
- repetition_penalty=1.1
 
45
  )
46
 
47
  response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
48
 
49
- history.append({"role": "assistant", "content": response_text})
50
- return history, history
 
 
 
 
 
 
 
 
51
 
52
  def reset_chat():
53
  return [], []
54
 
 
 
55
  with gr.Blocks(css="""
56
  body {background-color: #000 !important; color: #fff !important;}
57
  .gr-chatbot {background-color: #111 !important; border-radius: 12px; height: 100% !important;}
58
  .gr-chatbot .message.user {border-color: #0ff; background-color: transparent !important;}
59
  .gr-chatbot .message.bot {border-color: #aaa; background-color: transparent !important;}
60
- .gr-textbox textarea {background-color: transparent !important; color: #fff !important; border: 1px solid #555 !important;}
61
- .gr-textbox textarea::selection {background-color: #0ff !important; color: #000 !important;}
62
- .gr-button {background-color: #0ff !important; color: #000 !important; border-radius: 8px;}
 
 
 
 
 
 
 
 
 
 
 
63
  footer {display: none !important;}
64
  """) as demo:
65
 
 
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  import torch
4
 
5
# ------------------ MODEL SETUP ------------------

# Load the BlenderBot seq2seq checkpoint together with its tokenizer.
model_name = "facebook/blenderbot-400M-distill"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Prefer the GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

model.to(device)
model.eval()  # inference only — disables dropout / training-mode layers

# Half precision ONLY on GPU (fp16 on CPU is unsupported or slow).
if device == "cuda":
    model = model.half()

# Persona string prepended to every generated prompt.
persona = "You are a helpful, concise, friendly assistant."
22
 
23
+ # ------------------ CHAT FUNCTION ------------------
24
+
25
  def respond(message, history):
26
+ # Add user message
27
  history.append({"role": "user", "content": message})
28
 
29
+ # Add loading placeholder
30
+ history.append({"role": "assistant", "content": "⏳ Thinking..."})
31
+ yield history, history
32
+
33
+ # --------- BUILD CONTEXT (TURN-BASED MEMORY) ---------
34
+
35
  context = persona + "\n"
36
+
37
+ # Group messages into turns (user + bot)
38
+ turns = []
39
+ temp = []
40
+ for msg in history[:-1]: # exclude "Thinking..."
41
+ temp.append(msg)
42
+ if len(temp) == 2:
43
+ turns.append(temp)
44
+ temp = []
45
+
46
+ # Keep last 3 full turns
47
+ for turn in turns[-3:]:
48
+ for msg in turn:
49
+ role = "User" if msg["role"] == "user" else "Bot"
50
+ context += f"{role}: {msg['content']}\n"
51
+
52
  context += "Bot:"
53
 
54
  inputs = tokenizer(
 
58
  max_length=512
59
  ).to(device)
60
 
61
+ # ------------------ GENERATION ------------------
62
+
63
  with torch.no_grad():
64
  outputs = model.generate(
65
  **inputs,
66
+ max_new_tokens=80,
67
  do_sample=True,
68
+ temperature=0.65,
69
+ top_p=0.85,
70
+ repetition_penalty=1.1,
71
+ num_beams=1
72
  )
73
 
74
  response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
75
 
76
+ # Replace loading text
77
+ history[-1]["content"] = response_text
78
+
79
+ # Optional hard trim to prevent slowdown
80
+ if len(history) > 12:
81
+ history = history[-10:]
82
+
83
+ yield history, history
84
+
85
+ # ------------------ RESET ------------------
86
 
87
def reset_chat():
    """Wipe the conversation: fresh chatbot history and fresh state."""
    cleared_history = []
    cleared_state = []
    return cleared_history, cleared_state
89
 
90
+ # ------------------ UI ------------------
91
+
92
  with gr.Blocks(css="""
93
  body {background-color: #000 !important; color: #fff !important;}
94
  .gr-chatbot {background-color: #111 !important; border-radius: 12px; height: 100% !important;}
95
  .gr-chatbot .message.user {border-color: #0ff; background-color: transparent !important;}
96
  .gr-chatbot .message.bot {border-color: #aaa; background-color: transparent !important;}
97
+ .gr-textbox textarea {
98
+ background-color: transparent !important;
99
+ color: #fff !important;
100
+ border: 1px solid #555 !important;
101
+ }
102
+ .gr-textbox textarea::selection {
103
+ background-color: #0ff !important;
104
+ color: #000 !important;
105
+ }
106
+ .gr-button {
107
+ background-color: #0ff !important;
108
+ color: #000 !important;
109
+ border-radius: 8px;
110
+ }
111
  footer {display: none !important;}
112
  """) as demo:
113