emmanuelq2 committed on
Commit
47e116f
·
verified ·
1 Parent(s): 0d9b311

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +333 -54
app.py CHANGED
@@ -1,70 +1,349 @@
1
- import gradio as gr
2
- from huggingface_hub import InferenceClient
3
 
 
 
 
 
4
 
5
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken,
):
    """
    Stream a chat completion for ``message`` and yield the growing reply.

    The prompt is the system message, followed by ``history``, followed by
    the new user message. After every streamed chunk the text accumulated
    so far is yielded, so the caller can render the reply incrementally.

    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
    """
    client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")

    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": message})

    response = ""

    # The loop variable is named `chunk` so it no longer rebinds the
    # `message` parameter while the stream is being consumed.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        choices = chunk.choices
        token = ""
        # Some chunks carry no choices or an empty delta; treat them as "".
        if choices and choices[0].delta.content:
            token = choices[0].delta.content

        response += token
        yield response
41
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  )
62
 
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
 
 
 
68
 
69
- if __name__ == "__main__":
70
- demo.launch()
 
1
+ ## **Setting Up the Development Environment**
 
2
 
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
5
+ from datasets import load_dataset
6
+ import gradio as gr
7
 
8
+ import torch
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # Check if GPU is available
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
+ print(f"Using device: {device}")
14
 
15
+ """## **Building a Baseline Chatbot**"""
16
 
17
+ from transformers import AutoModelForCausalLM, AutoTokenizer
18
 
19
+ # Load the pretrained DialoGPT model and tokenizer
20
+ MODEL_NAME= "microsoft/DialoGPT-medium"
21
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
22
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
 
23
 
24
+ # Baseline chatbot function
25
+ chat_history_ids = None
26
 
27
def chatbot_response(user_input, chat_history_ids=None):
    """Generate a single reply to ``user_input`` with the loaded model.

    Args:
        user_input: Text of the user's message.
        chat_history_ids: Optional tensor of token ids from earlier turns.
            When given, the new message is appended to it along the last
            dimension before generation; when ``None`` the message is used
            on its own.

    Returns:
        The decoded reply text with special tokens stripped. The extended
        history tensor is NOT returned, so a caller that wants multi-turn
        context has to keep track of it separately.
    """
    # Encode the message terminated by EOS, and keep it on the same device
    # as the model so generation also works once the model lives on a GPU.
    new_input_ids = tokenizer.encode(
        user_input + tokenizer.eos_token, return_tensors="pt"
    ).to(model.device)

    # Add conversational history (if any) in front of the new message.
    if chat_history_ids is not None:
        bot_input_ids = torch.cat(
            [chat_history_ids.to(model.device), new_input_ids], dim=-1
        )
    else:
        bot_input_ids = new_input_ids

    # Generate a response; the output contains the prompt followed by the reply.
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Slice off the prompt tokens and decode only the newly generated ones.
    prompt_length = bot_input_ids.shape[-1]
    response = tokenizer.decode(
        chat_history_ids[:, prompt_length:][0], skip_special_tokens=True
    )
    return response
39
 
40
  """
41
+ ## **Launch Your First Chatbot Locally**"""
42
+
43
+ css = """
44
+ /* Container */
45
+ .container {
46
+ background-color: #fdf4f4;
47
+ border-radius: 15px;
48
+ box-shadow: 0 6px 20px rgba(0, 0, 0, 0.1);
49
+ padding: 25px;
50
+ font-family: 'Comic Sans MS', sans-serif;
51
+ }
52
+
53
+ /* Title */
54
+ h1 {
55
+ text-align: center;
56
+ font-size: 32px;
57
+ color: #ff7f7f;
58
+ font-weight: 600;
59
+ margin-bottom: 25px;
60
+ font-family: 'Pacifico', sans-serif;
61
+ }
62
+
63
+ /* Outer box */
64
+ .input_output_outerbox {
65
+ background-color: #f8d3d3; /* Light pink */
66
+ padding: 10px;
67
+ border-radius: 15px;
68
+ margin-bottom: 15px;
69
+ }
70
+
71
+ /* Input and Text area */
72
+ input[type="text"], textarea {
73
+ width: 100%;
74
+ padding: 18px 22px;
75
+ font-size: 18px;
76
+ border-radius: 25px;
77
+ border: 2px solid #ff6f61;
78
+ background-color: #fff9e6; /* Cream color */
79
+ color: brown;
80
+ font-weight: bold;
81
+ outline: none;
82
+ transition: border-color 0.3s ease;
83
+ }
84
+
85
+ /* Keep background and text color on focus */
86
+ input[type="text"]:focus, textarea:focus {
87
+ border-color: #ff1493;
88
+ background-color: #fff9e6 !important;
89
+ color: brown;
90
+ font-weight: bold;
91
+ box-shadow: none;
92
+ }
93
+
94
+ /* Output */
95
+ .output_text {
96
+ padding: 16px 22px;
97
+ background-color: #2e082e;
98
+ border-radius: 20px;
99
+ font-size: 18px;
100
+ color: brown;
101
+ font-weight: bold;
102
+ border: 1px solid #ff6f61;
103
+ word-wrap: break-word;
104
+ min-height: 60px;
105
+ }
106
+
107
+ /* Button */
108
+ button {
109
+ background-color: #ff6f61;
110
+ color: red;
111
+ padding: 16px 28px;
112
+ font-size: 20px;
113
+ font-weight: bold;
114
+ border-radius: 30px;
115
+ border: none;
116
+ cursor: pointer;
117
+ width: 100%;
118
+ transition: background-color 0.3s ease, transform 0.2s;
119
+ }
120
+
121
+ /* Button hover effect with animation */
122
+ button:hover {
123
+ background-color: #ff1493;
124
+ transform: scale(1.1);
125
+ }
126
+
127
+ /* Cute footer with smaller text */
128
+ footer {
129
+ text-align: center;
130
+ margin-top: 20px;
131
+ font-size: 16px;
132
+ color: #ff6f61;
133
+ }
134
+
135
  """
136
+
137
# Build a text-in / text-out Gradio interface around the chatbot function.
# `fn` receives the function object that `chatbot_response` names at this
# point in the script.
iface = gr.Interface(
    fn=chatbot_response,
    theme="default",
    inputs="text",
    outputs="text",
    title="Baseline Chatbot",
    css=css,
)
iface.launch()
144
+
145
+ """## **Fine-Tuning the Chatbot for Better Conversations (Most effective upgrade)**"""
146
+
147
# Load the SAMSum dataset (robust alternative to DailyDialog)
# Using the full namespace 'knkarthick/samsum' to ensure access
dataset = load_dataset("knkarthick/samsum")

# Rename 'dialogue' to 'dialog' to match the expected variable name
dataset = dataset.rename_column("dialogue", "dialog")

# Take a shuffled 1/20 subsample of the existing 'train' and 'validation'
# splits (fixed seed so the subsample is reproducible).
train_data = dataset["train"].shuffle(seed=42).select(range(len(dataset["train"]) // 20))
valid_data = dataset["validation"].shuffle(seed=42).select(range(len(dataset["validation"]) // 20))

# tokenize_function pads with padding="max_length", which needs a pad token;
# reuse the EOS token for that purpose.
tokenizer.pad_token = tokenizer.eos_token
162
+
163
def tokenize_function(examples):
    """Tokenize a batch of dialogs for causal-LM fine-tuning.

    Args:
        examples: A batch mapping with a "dialog" entry whose items are
            either strings or lists of strings (turns).

    Returns:
        The tokenizer output (padded/truncated to 128 tokens) with an added
        "labels" entry equal to "input_ids", except that padded positions
        are set to -100.
    """
    # Flatten multi-turn dialog structure: lists of turns are concatenated,
    # plain strings pass through unchanged.
    text_list = [
        "".join(dialog) if isinstance(dialog, list) else dialog
        for dialog in examples["dialog"]
    ]

    # Tokenize each conversation
    model_inputs = tokenizer(text_list, padding="max_length", truncation=True, max_length=128)

    # Labels mirror input_ids, but padded positions (attention_mask == 0) are
    # replaced by -100, the index ignored by the loss. The pad token is the
    # EOS token here, so the attention mask is the only reliable way to tell
    # padding apart from real tokens.
    model_inputs["labels"] = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(model_inputs["input_ids"], model_inputs["attention_mask"])
    ]

    return model_inputs
173
+
174
# Tokenize both subsets in batches; the raw "dialog" text column is dropped
# once it has been converted to token ids.
tokenized_train = train_data.map(tokenize_function, batched=True, remove_columns=["dialog"])
tokenized_valid = valid_data.map(tokenize_function, batched=True, remove_columns=["dialog"])

# Expose only the model inputs, as torch tensors.
tensor_columns = ["input_ids", "attention_mask", "labels"]
tokenized_train.set_format(type="torch", columns=tensor_columns)
tokenized_valid.set_format(type="torch", columns=tensor_columns)
181
+
182
import os

# Disable wandb logging BEFORE the Trainer is created, so no wandb callback
# is attached in the first place (no after-the-fact callback removal needed).
os.environ["WANDB_DISABLED"] = "true"

training_args = TrainingArguments(
    output_dir="./fine_tuned_chatbot",
    learning_rate=5e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    save_steps=500,
    save_total_limit=2,  # keep only the two most recent checkpoints
    report_to="none",  # do not report to any logging integration
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_valid,
)

# Train the model
trainer.train()
214
+
215
def chatbot_response(user_input):
    """Sample a reply to ``user_input`` from the model.

    Args:
        user_input: Text of the user's message.

    Returns:
        The decoded reply (at most 30 newly generated tokens) with special
        tokens stripped.
    """
    # Encode the message terminated by EOS, on the same device as the model.
    input_ids = tokenizer.encode(
        user_input + tokenizer.eos_token, return_tensors="pt"
    ).to(model.device)

    output_ids = model.generate(
        input_ids,
        max_new_tokens=30,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        temperature=0.7,
        repetition_penalty=1.2,
    )

    # The output starts with the prompt; decode only what comes after it.
    prompt_length = input_ids.shape[-1]
    response = tokenizer.decode(output_ids[:, prompt_length:][0], skip_special_tokens=True)
    return response
229
+
230
+
231
+
232
# Gradio UI
# The earlier `iface` was constructed with the baseline function object, so
# relaunching it would not serve the `chatbot_response` redefined after
# fine-tuning. Build a fresh interface bound to the current function.
iface = gr.Interface(
    fn=chatbot_response,
    theme="default",
    inputs="text",
    outputs="text",
    title="Fine-Tuned Chatbot",
    css=css,
)
iface.launch()
235
+
236
+ """#### **TESTED QUERIES**
237
+
238
+ Ex: How is it going?
239
+
240
+ Ex: I am feeling a bit stressed today. Any advice?
241
+
242
+ Ex: Can you explain why people dream?
243
+
244
+ Ex: Purple elephants dance faster in the rain, right?
245
+
246
+ ## **Further Upgrading Chatbot Responses**
247
+
248
+ ### **Upgrade 1: RAG (Retrieval-Augmented Generation)**
249
+ """
250
+
251
# Small knowledge base: keyword -> one-sentence fact.
knowledge_base = {
    "huggingface": "Hugging Face is a company specializing in Natural Language Processing technologies.",
    "transformers": "Transformers are a type of deep learning model introduced in the paper 'Attention is All You Need'.",
    "gradio": "Gradio is a Python library that allows you to rapidly create user interfaces for machine learning models.",
}


def retrieve_relevant_info(query):
    """Return the knowledge-base entry whose keyword occurs in ``query``.

    Matching is a case-insensitive substring test. Keywords are checked in
    the insertion order of ``knowledge_base`` and the first hit wins.

    Args:
        query: The user's text. ``None`` or an empty string yields no match.

    Returns:
        The matching fact, or "" when no keyword occurs in the query.
    """
    if not query:
        return ""

    # Simple keyword matching (used here instead of BM25 or Dense Passage
    # Retrieval methods). Lower-case the query once, not once per keyword.
    lowered_query = query.lower()
    for keyword, info in knowledge_base.items():
        if keyword.lower() in lowered_query:
            return info
    return ""
265
+
266
def chatbot_response(user_input):
    """Reply to ``user_input``, prepending any retrieved knowledge-base fact.

    Args:
        user_input: Text of the user's message.

    Returns:
        The decoded reply (at most 50 newly generated tokens), stripped of
        special tokens and surrounding whitespace.
    """
    # Prompt layout: optional retrieved fact on its own line, then the
    # "User: ... / Bot:" frame the model is asked to continue.
    retrieved_info = retrieve_relevant_info(user_input)
    context_prefix = retrieved_info + "\n" if retrieved_info else ""
    augmented_prompt = context_prefix + "User: " + user_input + "\nBot:"

    input_ids = tokenizer.encode(augmented_prompt, return_tensors="pt").to(model.device)

    output_ids = model.generate(
        input_ids,
        max_new_tokens=50,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_p=0.85,
        temperature=0.7,
        top_k=50,
        repetition_penalty=1.1,
    )

    # The output starts with the prompt; decode only what comes after it.
    prompt_length = input_ids.shape[-1]
    response = tokenizer.decode(output_ids[:, prompt_length:][0], skip_special_tokens=True)
    return response.strip()
286
+
287
+ """### **Upgrade 2: Improving Response Coherence and Context Awareness**"""
288
+
289
# Module-level transcript of "User: ..." / "Bot: ..." lines. It is shared by
# every call to chatbot_response in this process.
conversation_history = []


def chatbot_response(user_input):
    """Reply to ``user_input`` using the recent conversation as context.

    The user message is appended to the module-level
    ``conversation_history``; the history is trimmed to its last 6 entries
    before the prompt is built, and the generated reply is appended
    afterwards.

    Args:
        user_input: Text of the user's message.

    Returns:
        The decoded reply (at most 50 newly generated tokens), stripped of
        special tokens and surrounding whitespace.
    """
    global conversation_history
    conversation_history.append(f"User: {user_input}")
    if len(conversation_history) > 6:  # Limit to the last 6 entries
        conversation_history = conversation_history[-6:]

    prompt = "\n".join(conversation_history) + "\nBot:"

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    output_ids = model.generate(
        input_ids,
        max_new_tokens=50,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_p=0.85,
        temperature=0.7,
        top_k=50,
        repetition_penalty=1.1,
    )

    # The output starts with the prompt; decode only what comes after it.
    prompt_length = input_ids.shape[-1]
    response = tokenizer.decode(
        output_ids[:, prompt_length:][0], skip_special_tokens=True
    ).strip()

    conversation_history.append(f"Bot: {response}")
    return response
316
+
317
+ """### **Upgrade 3: Handle Uncertain Responses with Fallback Mechanism**"""
318
+
319
# Module-level transcript of "User: ..." / "Bot: ..." lines. It is shared by
# every call to chatbot_response in this process.
conversation_history = []


def chatbot_response(user_input):
    """Reply to ``user_input`` with history context and a fallback answer.

    The user message is appended to the module-level
    ``conversation_history``; the history is trimmed to its last 6 entries
    before the prompt is built. If the generated reply is empty or has two
    words or fewer, a fixed clarification request is returned instead. The
    returned text (reply or fallback) is appended to the history.

    Args:
        user_input: Text of the user's message.

    Returns:
        The reply text, or the fallback sentence for too-short replies.
    """
    global conversation_history
    conversation_history.append(f"User: {user_input}")
    if len(conversation_history) > 6:
        conversation_history = conversation_history[-6:]

    prompt = "\n".join(conversation_history) + "\nBot:"

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    output_ids = model.generate(
        input_ids,
        max_new_tokens=50,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
        top_k=50,
        repetition_penalty=1.2,
    )

    # The output starts with the prompt; decode only what comes after it.
    prompt_length = input_ids.shape[-1]
    response = tokenizer.decode(
        output_ids[:, prompt_length:][0], skip_special_tokens=True
    ).strip()

    # Fallback if response is too short or vague
    if not response or len(response.split()) <= 2:
        response = "I'm not sure I understood that. Could you please rephrase?"

    conversation_history.append(f"Bot: {response}")
    return response