update styles in app.py

#1
by sohn12 - opened
Files changed (1) hide show
  1. app.py +855 -847
app.py CHANGED
@@ -1,848 +1,856 @@
1
- import gradio as gr
2
- import torch
3
- import os
4
- import json
5
- import re
6
- import random
7
- from transformers import (
8
- AutoTokenizer,
9
- AutoModelForSequenceClassification,
10
- AutoModelForCausalLM,
11
- pipeline,
12
- )
13
- import datetime
14
- import sys
15
-
16
- # Define emotion label mapping
17
- EMOTION_LABELS = [
18
- "admiration", "amusement", "anger", "annoyance", "approval", "caring", "confusion",
19
- "curiosity", "desire", "disappointment", "disapproval", "disgust", "embarrassment",
20
- "excitement", "fear", "gratitude", "grief", "joy", "love", "nervousness", "optimism",
21
- "pride", "realization", "relief", "remorse", "sadness", "surprise", "neutral"
22
- ]
23
-
24
- # Map similar emotions to our response categories
25
- EMOTION_MAPPING = {
26
- "admiration": "joy",
27
- "amusement": "joy",
28
- "anger": "anger",
29
- "annoyance": "anger",
30
- "approval": "joy",
31
- "caring": "joy",
32
- "confusion": "neutral",
33
- "curiosity": "neutral",
34
- "desire": "neutral",
35
- "disappointment": "sadness",
36
- "disapproval": "anger",
37
- "disgust": "disgust",
38
- "embarrassment": "sadness",
39
- "excitement": "joy",
40
- "fear": "fear",
41
- "gratitude": "joy",
42
- "grief": "sadness",
43
- "joy": "joy",
44
- "love": "joy",
45
- "nervousness": "fear",
46
- "optimism": "joy",
47
- "pride": "joy",
48
- "realization": "neutral",
49
- "relief": "joy",
50
- "remorse": "sadness",
51
- "sadness": "sadness",
52
- "surprise": "surprise",
53
- "neutral": "neutral"
54
- }
55
-
56
- class ChatbotContext:
57
- """Class to maintain conversation context and history"""
58
- def __init__(self):
59
- self.conversation_history = []
60
- self.detected_emotions = []
61
- self.user_feedback = []
62
- self.current_session_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
63
- # Track emotional progression for therapeutic conversation flow
64
- self.conversation_stage = "initial" # initial, middle, advanced
65
- self.emotion_trajectory = [] # track emotion changes over time
66
- self.consecutive_positive_count = 0
67
- self.consecutive_negative_count = 0
68
- # Add user name tracking
69
- self.user_name = None
70
- self.bot_name = "Mira" # Friendly, easy to remember name
71
- self.introduced = False
72
- self.waiting_for_name = False
73
-
74
- def add_message(self, role, text, emotions=None):
75
- """Add a message to the conversation history"""
76
- timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
77
- message = {
78
- "role": role,
79
- "text": text,
80
- "timestamp": timestamp
81
- }
82
- if emotions and role == "user":
83
- message["emotions"] = emotions
84
- self.detected_emotions.append(emotions)
85
- self._update_emotional_trajectory(emotions)
86
-
87
- self.conversation_history.append(message)
88
- return message
89
-
90
- def _update_emotional_trajectory(self, emotions):
91
- """Update the emotional trajectory based on newly detected emotions"""
92
- # Get the primary emotion
93
- primary_emotion = emotions[0]["emotion"] if emotions else "neutral"
94
-
95
- # Add to trajectory
96
- self.emotion_trajectory.append(primary_emotion)
97
-
98
- # Classify as positive, negative, or neutral
99
- positive_emotions = ["joy", "admiration", "amusement", "excitement",
100
- "optimism", "gratitude", "pride", "love", "relief"]
101
- negative_emotions = ["sadness", "anger", "fear", "disgust", "disappointment",
102
- "annoyance", "disapproval", "embarrassment", "grief",
103
- "remorse", "nervousness"]
104
-
105
- if primary_emotion in positive_emotions:
106
- self.consecutive_positive_count += 1
107
- self.consecutive_negative_count = 0
108
- elif primary_emotion in negative_emotions:
109
- self.consecutive_negative_count += 1
110
- self.consecutive_positive_count = 0
111
- else: # neutral or other
112
- # Don't reset counters for neutral emotions to maintain progress
113
- pass
114
-
115
- # Update conversation stage based on trajectory and message count
116
- msg_count = len(self.conversation_history) // 2 # Count actual exchanges (user/bot pairs)
117
- if msg_count <= 1: # First real exchange
118
- self.conversation_stage = "initial"
119
- elif msg_count <= 3: # First few exchanges
120
- self.conversation_stage = "middle"
121
- else: # More established conversation
122
- self.conversation_stage = "advanced"
123
-
124
- def get_emotional_state(self):
125
- """Get the current emotional state of the conversation"""
126
- if len(self.emotion_trajectory) < 2:
127
- return "unknown"
128
-
129
- # Get the last few emotions (with 'neutral' having less weight)
130
- recent_emotions = self.emotion_trajectory[-3:]
131
- positive_emotions = ["joy", "admiration", "amusement", "excitement",
132
- "optimism", "gratitude", "pride", "love", "relief"]
133
- negative_emotions = ["sadness", "anger", "fear", "disgust", "disappointment",
134
- "annoyance", "disapproval", "embarrassment", "grief",
135
- "remorse", "nervousness"]
136
-
137
- # Count positive and negative emotions
138
- pos_count = sum(1 for e in recent_emotions if e in positive_emotions)
139
- neg_count = sum(1 for e in recent_emotions if e in negative_emotions)
140
-
141
- if self.consecutive_positive_count >= 2:
142
- return "positive"
143
- elif self.consecutive_negative_count >= 2:
144
- return "negative"
145
- elif pos_count > neg_count:
146
- return "improving"
147
- elif neg_count > pos_count:
148
- return "declining"
149
- else:
150
- return "neutral"
151
-
152
- def add_feedback(self, rating, comments=None):
153
- """Add user feedback about the chatbot's response"""
154
- feedback = {
155
- "rating": rating,
156
- "comments": comments,
157
- "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
158
- }
159
- self.user_feedback.append(feedback)
160
- return feedback
161
-
162
- def get_recent_messages(self, count=5):
163
- """Get the most recent messages from the conversation history"""
164
- return self.conversation_history[-count:] if len(self.conversation_history) >= count else self.conversation_history
165
-
166
- def save_conversation(self, filepath=None):
167
- """Save the conversation history to a JSON file"""
168
- if not filepath:
169
- os.makedirs("./conversations", exist_ok=True)
170
- filepath = f"./conversations/conversation_{self.current_session_id}.json"
171
-
172
- data = {
173
- "conversation_history": self.conversation_history,
174
- "user_feedback": self.user_feedback,
175
- "emotion_trajectory": self.emotion_trajectory,
176
- "session_id": self.current_session_id,
177
- "start_time": self.conversation_history[0]["timestamp"] if self.conversation_history else None,
178
- "end_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
179
- }
180
-
181
- with open(filepath, 'w') as f:
182
- json.dump(data, f, indent=2)
183
- print(f"Conversation saved to {filepath}")
184
- return filepath
185
-
186
- def clean_response_text(response, user_name):
187
- """Clean up the response text to make it more natural"""
188
- # Remove repeated name mentions
189
- if user_name:
190
- # Replace patterns like "Hey user_name," or "Hi user_name,"
191
- response = re.sub(r'^(Hey|Hi|Hello)\s+' + re.escape(user_name) + r',?\s+', '', response, flags=re.IGNORECASE)
192
-
193
- # Replace duplicate name mentions
194
- pattern = re.escape(user_name) + r',?\s+.*' + re.escape(user_name)
195
- if re.search(pattern, response, re.IGNORECASE):
196
- response = re.sub(r',?\s+' + re.escape(user_name) + r'([,.!?])', r'\1', response, flags=re.IGNORECASE)
197
-
198
- # Remove name at the end of sentences if it appears earlier
199
- if response.count(user_name) > 1:
200
- response = re.sub(r',\s+' + re.escape(user_name) + r'([.!?])(\s|$)', r'\1\2', response, flags=re.IGNORECASE)
201
-
202
- # Remove phrases that feel repetitive or formulaic
203
- phrases_to_remove = [
204
- r"let me know what you'd prefer,?\s+",
205
- r"i'm here to listen,?\s+",
206
- r"let me know if there's anything else,?\s+",
207
- r"i'm all ears,?\s+",
208
- r"i'm here for you,?\s+"
209
- ]
210
-
211
- for phrase in phrases_to_remove:
212
- response = re.sub(phrase, "", response, flags=re.IGNORECASE)
213
-
214
- # Fix multiple punctuation
215
- response = re.sub(r'([.!?])\s+\1', r'\1', response)
216
-
217
- # Fix missing space after punctuation
218
- response = re.sub(r'([.!?])([A-Za-z])', r'\1 \2', response)
219
-
220
- # Make sure first letter is capitalized
221
- if response and len(response) > 0:
222
- response = response[0].upper() + response[1:]
223
-
224
- return response.strip()
225
-
226
- class GradioEmotionChatbot:
227
- def __init__(self, emotion_model_id, response_model_id=None, confidence_threshold=0.3):
228
- self.emotion_model_id = emotion_model_id
229
- self.response_model_id = response_model_id or "mistralai/Mistral-7B-Instruct-v0.2"
230
- self.confidence_threshold = confidence_threshold
231
- self.context = ChatbotContext()
232
- self.initialize_models()
233
-
234
- def initialize_models(self):
235
- # Initialize emotion classification model
236
- print(f"Loading emotion classification model: {self.emotion_model_id}")
237
- try:
238
- self.emotion_model = AutoModelForSequenceClassification.from_pretrained(self.emotion_model_id)
239
- self.emotion_tokenizer = AutoTokenizer.from_pretrained(self.emotion_model_id)
240
-
241
- self.emotion_classifier = pipeline(
242
- "text-classification",
243
- model=self.emotion_model,
244
- tokenizer=self.emotion_tokenizer,
245
- top_k=None # Returns scores for all labels
246
- )
247
- print("Emotion classification model loaded successfully!")
248
- except Exception as e:
249
- print(f"Error loading emotion classification model: {e}")
250
- # Fallback to a dummy classifier for demo purposes
251
- self.emotion_classifier = lambda text: [[{"label": "neutral", "score": 1.0}]]
252
-
253
- # Initialize response generation model (or use fallback)
254
- print(f"Loading response generation model: {self.response_model_id}")
255
- try:
256
- self.response_model = AutoModelForCausalLM.from_pretrained(
257
- self.response_model_id,
258
- torch_dtype=torch.float16,
259
- device_map="auto"
260
- )
261
- self.response_tokenizer = AutoTokenizer.from_pretrained(self.response_model_id)
262
-
263
- self.response_generator = pipeline(
264
- "text-generation",
265
- model=self.response_model,
266
- tokenizer=self.response_tokenizer,
267
- do_sample=True,
268
- top_p=0.92,
269
- top_k=50,
270
- temperature=0.7,
271
- max_new_tokens=100
272
- )
273
- print("Response generation model loaded successfully!")
274
- except Exception as e:
275
- print(f"Using fallback response generation. Reason: {e}")
276
- self.response_generator = self.fallback_response_generator
277
-
278
- def fallback_response_generator(self, prompt, **kwargs):
279
- """Fallback response generator using templates"""
280
- # Try to extract emotion from the prompt
281
- emotion_match = re.search(r"emotion: (\w+)", prompt.lower())
282
- if emotion_match:
283
- emotion = emotion_match.group(1)
284
- else:
285
- emotion = "neutral"
286
-
287
- # Default user name
288
- user_name = "friend"
289
- name_match = re.search(r"Your friend \((.*?)\)", prompt.lower())
290
- if name_match:
291
- user_name = name_match.group(1)
292
-
293
- # Extract user message
294
- message_match = re.search(r"message: \"(.*?)\"", prompt)
295
- user_message = message_match.group(1) if message_match else ""
296
-
297
- # Generate response using fallback method
298
- response = self.natural_fallback_response(user_message, emotion, user_name)
299
-
300
- # Format as if coming from the pipeline
301
- return [{"generated_text": response}]
302
-
303
- def natural_fallback_response(self, user_message, primary_emotion, user_name):
304
- """Conversational fallback responses that sound like a supportive friend"""
305
- # Define emotion categories
306
- sad_emotions = ["sadness", "disappointment", "grief", "remorse"]
307
- fear_emotions = ["fear", "nervousness", "anxiety"]
308
- anger_emotions = ["anger", "annoyance", "disapproval", "disgust"]
309
- joy_emotions = ["joy", "admiration", "amusement", "excitement", "optimism",
310
- "gratitude", "pride", "love", "relief"]
311
-
312
- # Multi-stage response templates - more natural and varied
313
- if primary_emotion in joy_emotions:
314
- responses = [
315
- f"That's awesome, {user_name}! What made you feel that way?",
316
- f"I'm so glad to hear that! Tell me more about it?",
317
- f"That's great news! What else is going on with you lately?"
318
- ]
319
- elif primary_emotion in sad_emotions:
320
- responses = [
321
- f"I'm sorry to hear that, {user_name}. Want to talk about what happened?",
322
- f"That sounds rough. What's been going on?",
323
- f"Ugh, that's tough. How are you handling it?"
324
- ]
325
- elif primary_emotion in anger_emotions:
326
- responses = [
327
- f"That sounds really frustrating. What happened?",
328
- f"Oh no, that would upset me too. Want to vent about it?",
329
- f"I can see why you'd be upset about that. What are you thinking of doing?"
330
- ]
331
- elif primary_emotion in fear_emotions:
332
- responses = [
333
- f"That sounds scary, {user_name}. What's got you worried?",
334
- f"I can imagine that would be stressful. What's on your mind about it?",
335
- f"I get feeling anxious about that. What's the biggest concern for you?"
336
- ]
337
- else: # neutral emotions
338
- responses = [
339
- f"What's been on your mind lately, {user_name}?",
340
- f"How's everything else going with you?",
341
- f"Tell me more about what's going on in your life these days."
342
- ]
343
-
344
- return random.choice(responses)
345
-
346
- def classify_text(self, text):
347
- """Classify text and return emotion data"""
348
- try:
349
- results = self.emotion_classifier(text)
350
-
351
- # Sort emotions by score in descending order
352
- sorted_emotions = sorted(results[0], key=lambda x: x['score'], reverse=True)
353
-
354
- # Process emotions above threshold
355
- detected_emotions = []
356
- for emotion in sorted_emotions:
357
- # Map numerical label to emotion name
358
- try:
359
- label_id = int(emotion['label'].split('_')[-1]) if '_' in emotion['label'] else int(emotion['label'])
360
- if 0 <= label_id < len(EMOTION_LABELS):
361
- emotion_name = EMOTION_LABELS[label_id]
362
- else:
363
- emotion_name = emotion['label']
364
- except (ValueError, IndexError):
365
- emotion_name = emotion['label']
366
-
367
- score = emotion['score']
368
-
369
- if score >= self.confidence_threshold:
370
- detected_emotions.append({"emotion": emotion_name, "score": score})
371
-
372
- # If no emotions detected above threshold, add neutral
373
- if not detected_emotions:
374
- detected_emotions.append({"emotion": "neutral", "score": 1.0})
375
-
376
- return detected_emotions
377
- except Exception as e:
378
- print(f"Error during classification: {e}")
379
- # Return neutral as fallback
380
- return [{"emotion": "neutral", "score": 1.0}]
381
-
382
- def format_emotion_text(self, emotion_data):
383
- """Create a simple emotion text display"""
384
- if not emotion_data:
385
- return ""
386
-
387
- # Define emotion emojis
388
- emotion_emojis = {
389
- "joy": "😊", "admiration": "🤩", "amusement": "😄", "approval": "👍",
390
- "excitement": "🎉", "gratitude": "🙏", "love": "❤️", "optimism": "🌟",
391
- "pride": "🦚", "relief": "😌", "sadness": "😢", "disappointment": "😞",
392
- "grief": "💔", "remorse": "😔", "embarrassment": "😳", "anger": "😠",
393
- "annoyance": "😤", "disapproval": "👎", "disgust": "🤢", "fear": "😨",
394
- "nervousness": "😰", "surprise": "😲", "confusion": "😕", "curiosity": "🤔",
395
- "neutral": "😐", "realization": "💡", "desire": "✨"
396
- }
397
-
398
- # Format the primary emotion
399
- primary = emotion_data[0]["emotion"]
400
- emoji = emotion_emojis.get(primary, "😐")
401
- score = emotion_data[0]["score"]
402
-
403
- return f"Detected: {emoji} {primary.capitalize()} ({score:.2f})"
404
-
405
- def generate_response(self, user_message, emotion_data):
406
- """Generate a response based on the user's message and detected emotions"""
407
- # Get the primary emotion with context awareness
408
- primary_emotion = emotion_data[0]["emotion"] if emotion_data else "neutral"
409
-
410
- # Get recent conversation history for context
411
- recent_exchanges = self.context.get_recent_messages(6)
412
- conversation_history = ""
413
- for msg in recent_exchanges:
414
- role = "Friend" if msg["role"] == "user" else self.context.bot_name
415
- conversation_history += f"{role}: {msg['text']}\n"
416
-
417
- # Check if this is a greeting
418
- is_greeting = any(greeting in user_message.lower() for greeting in ["hi", "hello", "hey", "greetings"])
419
- is_question_about_bot = "how are you" in user_message.lower() or any(q in user_message.lower() for q in ["what can you do", "who are you", "what are you", "your purpose"])
420
-
421
- # Handle special cases
422
- if is_greeting:
423
- if len(self.context.conversation_history) <= 4: # First greeting exchange
424
- return f"Hi! I'm {self.context.bot_name}. It's nice to meet you. How are you feeling today?"
425
- else:
426
- return f"Hey! Good to chat with you again. What's been going on with you?"
427
-
428
- elif is_question_about_bot:
429
- return f"I'm doing well, thanks for asking! I'm {self.context.bot_name}, here as a friend to chat whenever you need someone to talk to. What's on your mind today?"
430
-
431
- # Create a more conversational prompt based on emotion
432
- system_instruction = f"""You are {self.context.bot_name}, having a natural conversation with your friend. You should respond in a casual, warm way like a supportive friend would - not like a therapist or clinical chatbot.
433
-
434
- Your friend seems to be feeling {primary_emotion}. In your response:
435
- 1. Be genuinely empathetic but natural - like how a real friend would respond
436
- 2. Keep your response short (1-3 sentences) and conversational
437
- 3. Don't use phrases like "I understand" or "I'm here for you" too much - vary your language
438
- 4. Use casual language, contractions (don't instead of do not), and occasional sentence fragments
439
- 5. Don't sound formulaic or overly positive - be authentic
440
- 6. Keep the same emotional tone throughout your response
441
- 7. Don't explain what you're doing or add meta-commentary
442
- 8. DON'T address them by name multiple times or at the end of sentences - it sounds unnatural
443
- 9. Don't end with "Let me know what you'd prefer" or similar phrases
444
-
445
- Recent conversation:
446
- {conversation_history}
447
-
448
- Your friend's message: "{user_message}"
449
- Current emotion: {primary_emotion}
450
-
451
- Respond naturally as a supportive friend (without using their name more than once if at all):"""
452
-
453
- try:
454
- # Generate the response
455
- generated = self.response_generator(
456
- system_instruction,
457
- max_new_tokens=100,
458
- do_sample=True,
459
- temperature=0.8,
460
- top_p=0.92,
461
- top_k=50,
462
- )
463
-
464
- # Extract the generated text
465
- if isinstance(generated, list):
466
- response_text = generated[0].get('generated_text', '')
467
- else:
468
- response_text = generated.get('generated_text', '')
469
-
470
- # Clean up the response - extract only the actual response without system prompt
471
- if "[/INST]" in response_text:
472
- parts = response_text.split("[/INST]")
473
- if len(parts) > 1:
474
- response_text = parts[1].strip()
475
-
476
- # If we're still getting the system instruction, try an alternative approach
477
- if "Your friend seems to be feeling" in response_text:
478
- # Try to extract just the bot's response using pattern matching
479
- match = re.search(r'Respond naturally as a supportive friend.*?:\s*(.*?)$', response_text, re.DOTALL)
480
- if match:
481
- response_text = match.group(1).strip()
482
- else:
483
- # If that fails, try another approach - take text after the last numbered instruction
484
- match = re.search(r'9\.\s+[^\n]+\s*(.*?)$', response_text, re.DOTALL)
485
- if match:
486
- response_text = match.group(1).strip()
487
- else:
488
- # Last resort: pick a fallback response based on emotion
489
- response_text = self.natural_fallback_response(user_message, primary_emotion, self.context.user_name or "friend")
490
-
491
- # Remove any model-specific markers
492
- response_text = response_text.replace("<s>", "").replace("</s>", "")
493
-
494
- # Remove any internal notes or debugging info that might appear
495
- if "Note:" in response_text:
496
- response_text = response_text.split("Note:")[0].strip()
497
-
498
- # Remove any metadata or system-like text
499
- response_text = response_text.replace("Assistant:", "").replace(f"{self.context.bot_name}:", "").strip()
500
-
501
- # Remove any quotation marks surrounding the response
502
- response_text = response_text.strip('"').strip()
503
-
504
- # Handle potential model halt mid-sentence
505
- if response_text.endswith((".", "!", "?")):
506
- pass # Response ends with proper punctuation
507
- else:
508
- # Try to find the last complete sentence
509
- last_period = max(response_text.rfind("."), response_text.rfind("!"), response_text.rfind("?"))
510
- if last_period > len(response_text) * 0.5: # If we've got at least half the response
511
- response_text = response_text[:last_period+1]
512
-
513
- # FINAL CHECK: If we still have parts of the system prompt, use fallback response
514
- if any(phrase in response_text for phrase in ["Your friend seems to be feeling", "Keep your response short", "Be genuinely empathetic"]):
515
- response_text = self.natural_fallback_response(user_message, primary_emotion, self.context.user_name or "friend")
516
-
517
- return clean_response_text(response_text.strip(), self.context.user_name)
518
-
519
- except Exception as e:
520
- print(f"Error generating response: {e}")
521
- return self.natural_fallback_response(user_message, primary_emotion, self.context.user_name or "friend")
522
-
523
- def process_message(self, user_message, chatbot_history):
524
- """Process a user message and return the chatbot response"""
525
- # Initialize context if first message
526
- if not self.context.conversation_history:
527
- initial_greeting = f"Hi! I'm {self.context.bot_name}, your friendly emotional support chatbot. Who am I talking to today?"
528
- self.context.add_message("bot", initial_greeting)
529
- self.context.waiting_for_name = True
530
- return [[None, initial_greeting]]
531
-
532
- # Handle name collection if this is the first user message
533
- if self.context.waiting_for_name and not self.context.introduced:
534
- common_greetings = ["hi", "hey", "hello", "greetings", "howdy", "hiya"]
535
- words = user_message.strip().split()
536
- potential_name = None
537
-
538
- if "i'm" in user_message.lower() or "im" in user_message.lower():
539
- parts = user_message.lower().replace("i'm", "im").split("im")
540
- if len(parts) > 1 and parts[1].strip():
541
- potential_name = parts[1].strip().split()[0].capitalize()
542
-
543
- elif "my name is" in user_message.lower():
544
- parts = user_message.lower().split("my name is")
545
- if len(parts) > 1 and parts[1].strip():
546
- potential_name = parts[1].strip().split()[0].capitalize()
547
-
548
- elif len(words) <= 3 and words[0].lower() not in common_greetings:
549
- potential_name = words[0].capitalize()
550
-
551
- if potential_name:
552
- potential_name = ''.join(c for c in potential_name if c.isalnum())
553
-
554
- if potential_name and len(potential_name) >= 2 and potential_name.lower() not in common_greetings:
555
- self.context.user_name = potential_name
556
- greeting_response = f"Nice to meet you, {self.context.user_name}! How are you feeling today?"
557
- else:
558
- self.context.user_name = "friend"
559
- greeting_response = "Nice to meet you! How are you feeling today?"
560
-
561
- self.context.introduced = True
562
- self.context.waiting_for_name = False
563
- self.context.add_message("user", user_message)
564
- self.context.add_message("bot", greeting_response)
565
-
566
- return chatbot_history + [[user_message, greeting_response]]
567
-
568
- # Regular message processing
569
- emotion_data = self.classify_text(user_message)
570
- self.context.add_message("user", user_message, emotion_data)
571
-
572
- # Generate the response
573
- bot_response = self.generate_response(user_message, emotion_data)
574
- self.context.add_message("bot", bot_response)
575
-
576
- # Create a simple emotion display text
577
- emotion_text = self.format_emotion_text(emotion_data)
578
-
579
- # Combine emotion text with bot response
580
- full_response = f"{emotion_text}\n\n{bot_response}" if emotion_text else bot_response
581
-
582
- # Return updated chat history in the expected tuple format
583
- return chatbot_history + [[user_message, full_response]]
584
-
585
- def reset_conversation(self):
586
- """Reset the conversation context"""
587
- self.context = ChatbotContext()
588
- return []
589
-
590
- # Create the Gradio interface
591
- def create_gradio_interface():
592
- # Initialize the chatbot with default models
593
- emotion_model_id = os.environ.get("EMOTION_MODEL_ID", "suku9/emotion-classifier")
594
- response_model_id = os.environ.get("RESPONSE_MODEL_ID", "mistralai/Mistral-7B-Instruct-v0.2")
595
-
596
- chatbot = GradioEmotionChatbot(emotion_model_id, response_model_id)
597
-
598
- # Create the Gradio interface with dark mode styling
599
- custom_css = """
600
- /* Dark mode styling */
601
- body {
602
- background-color: #1a1a1a !important;
603
- color: #e0e0e0 !important;
604
- }
605
-
606
- .gradio-container {
607
- max-width: 750px !important;
608
- margin: auto !important;
609
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
610
- border-radius: 12px !important;
611
- background: #2d2d2d !important;
612
- padding: 15px !important;
613
- }
614
-
615
- /* Chatbot header styling */
616
- .gradio-container h1, #header {
617
- color: #a29bfe !important;
618
- text-align: center !important;
619
- font-size: 1.8rem !important;
620
- margin-bottom: 5px !important;
621
- font-weight: 600 !important;
622
- }
623
-
624
- .gradio-container p, #subheader {
625
- text-align: center !important;
626
- color: #b0b0b0 !important;
627
- margin-bottom: 15px !important;
628
- font-size: 0.9rem !important;
629
- }
630
-
631
- /* Chatbot window styling */
632
- #chatbot {
633
- height: 380px !important;
634
- overflow: auto !important;
635
- border-radius: 10px !important;
636
- background-color: #1a1a1a !important;
637
- border: 1px solid #3d3d3d !important;
638
- padding: 10px !important;
639
- margin-bottom: 15px !important;
640
- }
641
-
642
- /* Force horizontal text orientation for ALL elements */
643
- * {
644
- writing-mode: horizontal-tb !important;
645
- text-orientation: mixed !important;
646
- direction: ltr !important;
647
- }
648
-
649
- /* Message styling */
650
- .message {
651
- border-radius: 12px !important;
652
- padding: 8px 12px !important;
653
- margin: 5px 0 !important;
654
- max-width: 85% !important;
655
- width: 250px !important;
656
- word-break: break-word !important;
657
- writing-mode: horizontal-tb !important;
658
- text-orientation: mixed !important;
659
- direction: ltr !important;
660
- }
661
-
662
- .user-message {
663
- background-color: #4a5568 !important;
664
- color: #e2e8f0 !important;
665
- writing-mode: horizontal-tb !important;
666
- text-orientation: mixed !important;
667
- }
668
-
669
- .bot-message {
670
- background-color: #553c9a !important;
671
- color: #f8f9fa !important;
672
- writing-mode: horizontal-tb !important;
673
- text-orientation: mixed !important;
674
- }
675
-
676
- /* User input styling - FIX FOR VERTICAL TEXT */
677
- #user-input, .gradio-container textarea, .gradio-container input[type="text"] {
678
- background-color: #2d2d2d !important;
679
- color: #e0e0e0 !important;
680
- border-radius: 20px !important;
681
- padding: 10px 15px !important;
682
- border: 1px solid #3d3d3d !important;
683
- margin-bottom: 10px !important;
684
- writing-mode: horizontal-tb !important;
685
- text-orientation: mixed !important;
686
- direction: ltr !important;
687
- width: 100% !important;
688
- min-height: 45px !important;
689
- height: auto !important;
690
- resize: none !important;
691
- }
692
-
693
- /* Force text orientation for any text inputs */
694
- .cm-editor, .cm-scroller, .cm-content, .cm-line {
695
- writing-mode: horizontal-tb !important;
696
- text-orientation: mixed !important;
697
- }
698
-
699
- /* Ensure row is horizontal */
700
- .gradio-row {
701
- flex-direction: row !important;
702
- }
703
-
704
- /* Fix for chat bubbles */
705
- .chat, .chat > div, .chat > div > div, .chat-msg, .chat-msg > div, .chat-msg-content {
706
- writing-mode: horizontal-tb !important;
707
- text-orientation: mixed !important;
708
- }
709
-
710
- /* Apply horizontal text to all text elements in chatbot */
711
- .prose, .prose p, .prose span, .text-input-with-enter {
712
- writing-mode: horizontal-tb !important;
713
- text-orientation: mixed !important;
714
- direction: ltr !important;
715
- }
716
-
717
- /* Target the specific user bubble on the right side */
718
- .gradio-chatbot > div > div {
719
- writing-mode: horizontal-tb !important;
720
- text-orientation: mixed !important;
721
- direction: ltr !important;
722
- }
723
-
724
- /* Target any text inside chatbot bubbles */
725
- .gradio-chatbot * {
726
- writing-mode: horizontal-tb !important;
727
- text-orientation: mixed !important;
728
- direction: ltr !important;
729
- }
730
-
731
- /* AVATAR AND USERNAME FIXES */
732
- .avatar, .avatar-container, .avatar-image, .user-avatar, .bot-avatar {
733
- writing-mode: horizontal-tb !important;
734
- text-orientation: mixed !important;
735
- direction: ltr !important;
736
- }
737
-
738
- /* Fix for specific containers that might be causing issues */
739
- [class*="message"], [class*="bubble"], [class*="avatar"], [class*="chat"] {
740
- writing-mode: horizontal-tb !important;
741
- text-orientation: mixed !important;
742
- direction: ltr !important;
743
- }
744
-
745
- /* Button styling */
746
- .send-btn, .clear-btn {
747
- background-color: #6c5ce7 !important;
748
- color: white !important;
749
- border: none !important;
750
- border-radius: 20px !important;
751
- padding: 8px 16px !important;
752
- font-weight: 600 !important;
753
- cursor: pointer !important;
754
- transition: all 0.3s ease !important;
755
- }
756
-
757
- .send-btn:hover, .clear-btn:hover {
758
- background-color: #5649c1 !important;
759
- transform: translateY(-1px) !important;
760
- }
761
-
762
- .clear-btn {
763
- background-color: #e74c3c !important;
764
- }
765
-
766
- .clear-btn:hover {
767
- background-color: #c0392b !important;
768
- }
769
-
770
- /* Hide footer */
771
- footer {
772
- display: none !important;
773
- }
774
-
775
- /* Fix scrollbar */
776
- ::-webkit-scrollbar {
777
- width: 6px;
778
- background-color: #1a1a1a;
779
- }
780
-
781
- ::-webkit-scrollbar-thumb {
782
- background-color: #4a4a4a;
783
- border-radius: 3px;
784
- }
785
- """
786
-
787
- with gr.Blocks(css=custom_css) as demo:
788
- gr.Markdown("# EmotionChat", elem_id="header")
789
- gr.Markdown("A supportive chatbot that understands how you feel", elem_id="subheader")
790
-
791
- # Chat interface with improved styling
792
- chatbot_interface = gr.Chatbot(
793
- elem_id="chatbot",
794
- show_label=False,
795
- height=380,
796
- avatar_images=["https://em-content.zobj.net/source/microsoft-teams/363/bust-in-silhouette_1f464.png",
797
- "https://em-content.zobj.net/source/microsoft-teams/363/robot_1f916.png"],
798
- )
799
-
800
- # Input and button row with better styling
801
- with gr.Row():
802
- user_input = gr.Textbox(
803
- placeholder="Type your message here...",
804
- show_label=False,
805
- container=False,
806
- scale=9,
807
- elem_id="user-input",
808
- lines=1,
809
- max_lines=1,
810
- rtl=False
811
- )
812
- submit_btn = gr.Button("Send", scale=1, elem_classes="send-btn")
813
-
814
- # New conversation button
815
- clear_btn = gr.Button("New Conversation", elem_classes="clear-btn")
816
-
817
- # Set up the event handlers
818
- submit_btn.click(
819
- chatbot.process_message,
820
- inputs=[user_input, chatbot_interface],
821
- outputs=[chatbot_interface],
822
- ).then(
823
- lambda: "", # Clear the input box after sending
824
- None,
825
- [user_input],
826
- )
827
-
828
- user_input.submit(
829
- chatbot.process_message,
830
- inputs=[user_input, chatbot_interface],
831
- outputs=[chatbot_interface],
832
- ).then(
833
- lambda: "", # Clear the input box after sending
834
- None,
835
- [user_input],
836
- )
837
-
838
- clear_btn.click(
839
- chatbot.reset_conversation,
840
- inputs=None,
841
- outputs=[chatbot_interface],
842
- )
843
-
844
- return demo
845
-
846
- if __name__ == "__main__":
847
- demo = create_gradio_interface()
 
 
 
 
 
 
 
 
848
  demo.launch(debug=True, share=True)
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import json
5
+ import re
6
+ import random
7
+ from transformers import (
8
+ AutoTokenizer,
9
+ AutoModelForSequenceClassification,
10
+ AutoModelForCausalLM,
11
+ pipeline,
12
+ )
13
+ import datetime
14
+ import sys
15
+
16
+ # Define emotion label mapping
17
+ EMOTION_LABELS = [
18
+ "admiration", "amusement", "anger", "annoyance", "approval", "caring", "confusion",
19
+ "curiosity", "desire", "disappointment", "disapproval", "disgust", "embarrassment",
20
+ "excitement", "fear", "gratitude", "grief", "joy", "love", "nervousness", "optimism",
21
+ "pride", "realization", "relief", "remorse", "sadness", "surprise", "neutral"
22
+ ]
23
+
24
+ # Map similar emotions to our response categories
25
+ EMOTION_MAPPING = {
26
+ "admiration": "joy",
27
+ "amusement": "joy",
28
+ "anger": "anger",
29
+ "annoyance": "anger",
30
+ "approval": "joy",
31
+ "caring": "joy",
32
+ "confusion": "neutral",
33
+ "curiosity": "neutral",
34
+ "desire": "neutral",
35
+ "disappointment": "sadness",
36
+ "disapproval": "anger",
37
+ "disgust": "disgust",
38
+ "embarrassment": "sadness",
39
+ "excitement": "joy",
40
+ "fear": "fear",
41
+ "gratitude": "joy",
42
+ "grief": "sadness",
43
+ "joy": "joy",
44
+ "love": "joy",
45
+ "nervousness": "fear",
46
+ "optimism": "joy",
47
+ "pride": "joy",
48
+ "realization": "neutral",
49
+ "relief": "joy",
50
+ "remorse": "sadness",
51
+ "sadness": "sadness",
52
+ "surprise": "surprise",
53
+ "neutral": "neutral"
54
+ }
55
+
56
+ class ChatbotContext:
57
+ """Class to maintain conversation context and history"""
58
+ def __init__(self):
59
+ self.conversation_history = []
60
+ self.detected_emotions = []
61
+ self.user_feedback = []
62
+ self.current_session_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
63
+ # Track emotional progression for therapeutic conversation flow
64
+ self.conversation_stage = "initial" # initial, middle, advanced
65
+ self.emotion_trajectory = [] # track emotion changes over time
66
+ self.consecutive_positive_count = 0
67
+ self.consecutive_negative_count = 0
68
+ # Add user name tracking
69
+ self.user_name = None
70
+ self.bot_name = "Mira" # Friendly, easy to remember name
71
+ self.introduced = False
72
+ self.waiting_for_name = False
73
+
74
+ def add_message(self, role, text, emotions=None):
75
+ """Add a message to the conversation history"""
76
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
77
+ message = {
78
+ "role": role,
79
+ "text": text,
80
+ "timestamp": timestamp
81
+ }
82
+ if emotions and role == "user":
83
+ message["emotions"] = emotions
84
+ self.detected_emotions.append(emotions)
85
+ self._update_emotional_trajectory(emotions)
86
+
87
+ self.conversation_history.append(message)
88
+ return message
89
+
90
+ def _update_emotional_trajectory(self, emotions):
91
+ """Update the emotional trajectory based on newly detected emotions"""
92
+ # Get the primary emotion
93
+ primary_emotion = emotions[0]["emotion"] if emotions else "neutral"
94
+
95
+ # Add to trajectory
96
+ self.emotion_trajectory.append(primary_emotion)
97
+
98
+ # Classify as positive, negative, or neutral
99
+ positive_emotions = ["joy", "admiration", "amusement", "excitement",
100
+ "optimism", "gratitude", "pride", "love", "relief"]
101
+ negative_emotions = ["sadness", "anger", "fear", "disgust", "disappointment",
102
+ "annoyance", "disapproval", "embarrassment", "grief",
103
+ "remorse", "nervousness"]
104
+
105
+ if primary_emotion in positive_emotions:
106
+ self.consecutive_positive_count += 1
107
+ self.consecutive_negative_count = 0
108
+ elif primary_emotion in negative_emotions:
109
+ self.consecutive_negative_count += 1
110
+ self.consecutive_positive_count = 0
111
+ else: # neutral or other
112
+ # Don't reset counters for neutral emotions to maintain progress
113
+ pass
114
+
115
+ # Update conversation stage based on trajectory and message count
116
+ msg_count = len(self.conversation_history) // 2 # Count actual exchanges (user/bot pairs)
117
+ if msg_count <= 1: # First real exchange
118
+ self.conversation_stage = "initial"
119
+ elif msg_count <= 3: # First few exchanges
120
+ self.conversation_stage = "middle"
121
+ else: # More established conversation
122
+ self.conversation_stage = "advanced"
123
+
124
+ def get_emotional_state(self):
125
+ """Get the current emotional state of the conversation"""
126
+ if len(self.emotion_trajectory) < 2:
127
+ return "unknown"
128
+
129
+ # Get the last few emotions (with 'neutral' having less weight)
130
+ recent_emotions = self.emotion_trajectory[-3:]
131
+ positive_emotions = ["joy", "admiration", "amusement", "excitement",
132
+ "optimism", "gratitude", "pride", "love", "relief"]
133
+ negative_emotions = ["sadness", "anger", "fear", "disgust", "disappointment",
134
+ "annoyance", "disapproval", "embarrassment", "grief",
135
+ "remorse", "nervousness"]
136
+
137
+ # Count positive and negative emotions
138
+ pos_count = sum(1 for e in recent_emotions if e in positive_emotions)
139
+ neg_count = sum(1 for e in recent_emotions if e in negative_emotions)
140
+
141
+ if self.consecutive_positive_count >= 2:
142
+ return "positive"
143
+ elif self.consecutive_negative_count >= 2:
144
+ return "negative"
145
+ elif pos_count > neg_count:
146
+ return "improving"
147
+ elif neg_count > pos_count:
148
+ return "declining"
149
+ else:
150
+ return "neutral"
151
+
152
+ def add_feedback(self, rating, comments=None):
153
+ """Add user feedback about the chatbot's response"""
154
+ feedback = {
155
+ "rating": rating,
156
+ "comments": comments,
157
+ "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
158
+ }
159
+ self.user_feedback.append(feedback)
160
+ return feedback
161
+
162
+ def get_recent_messages(self, count=5):
163
+ """Get the most recent messages from the conversation history"""
164
+ return self.conversation_history[-count:] if len(self.conversation_history) >= count else self.conversation_history
165
+
166
+ def save_conversation(self, filepath=None):
167
+ """Save the conversation history to a JSON file"""
168
+ if not filepath:
169
+ os.makedirs("./conversations", exist_ok=True)
170
+ filepath = f"./conversations/conversation_{self.current_session_id}.json"
171
+
172
+ data = {
173
+ "conversation_history": self.conversation_history,
174
+ "user_feedback": self.user_feedback,
175
+ "emotion_trajectory": self.emotion_trajectory,
176
+ "session_id": self.current_session_id,
177
+ "start_time": self.conversation_history[0]["timestamp"] if self.conversation_history else None,
178
+ "end_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
179
+ }
180
+
181
+ with open(filepath, 'w') as f:
182
+ json.dump(data, f, indent=2)
183
+ print(f"Conversation saved to {filepath}")
184
+ return filepath
185
+
186
+ def clean_response_text(response, user_name):
187
+ """Clean up the response text to make it more natural"""
188
+ # Remove repeated name mentions
189
+ if user_name:
190
+ # Replace patterns like "Hey user_name," or "Hi user_name,"
191
+ response = re.sub(r'^(Hey|Hi|Hello)\s+' + re.escape(user_name) + r',?\s+', '', response, flags=re.IGNORECASE)
192
+
193
+ # Replace duplicate name mentions
194
+ pattern = re.escape(user_name) + r',?\s+.*' + re.escape(user_name)
195
+ if re.search(pattern, response, re.IGNORECASE):
196
+ response = re.sub(r',?\s+' + re.escape(user_name) + r'([,.!?])', r'\1', response, flags=re.IGNORECASE)
197
+
198
+ # Remove name at the end of sentences if it appears earlier
199
+ if response.count(user_name) > 1:
200
+ response = re.sub(r',\s+' + re.escape(user_name) + r'([.!?])(\s|$)', r'\1\2', response, flags=re.IGNORECASE)
201
+
202
+ # Remove phrases that feel repetitive or formulaic
203
+ phrases_to_remove = [
204
+ r"let me know what you'd prefer,?\s+",
205
+ r"i'm here to listen,?\s+",
206
+ r"let me know if there's anything else,?\s+",
207
+ r"i'm all ears,?\s+",
208
+ r"i'm here for you,?\s+"
209
+ ]
210
+
211
+ for phrase in phrases_to_remove:
212
+ response = re.sub(phrase, "", response, flags=re.IGNORECASE)
213
+
214
+ # Fix multiple punctuation
215
+ response = re.sub(r'([.!?])\s+\1', r'\1', response)
216
+
217
+ # Fix missing space after punctuation
218
+ response = re.sub(r'([.!?])([A-Za-z])', r'\1 \2', response)
219
+
220
+ # Make sure first letter is capitalized
221
+ if response and len(response) > 0:
222
+ response = response[0].upper() + response[1:]
223
+
224
+ return response.strip()
225
+
226
+ class GradioEmotionChatbot:
227
+ def __init__(self, emotion_model_id, response_model_id=None, confidence_threshold=0.3):
228
+ self.emotion_model_id = emotion_model_id
229
+ self.response_model_id = response_model_id or "mistralai/Mistral-7B-Instruct-v0.2"
230
+ self.confidence_threshold = confidence_threshold
231
+ self.context = ChatbotContext()
232
+ self.initialize_models()
233
+
234
+ def initialize_models(self):
235
+ # Initialize emotion classification model
236
+ print(f"Loading emotion classification model: {self.emotion_model_id}")
237
+ try:
238
+ self.emotion_model = AutoModelForSequenceClassification.from_pretrained(self.emotion_model_id)
239
+ self.emotion_tokenizer = AutoTokenizer.from_pretrained(self.emotion_model_id)
240
+
241
+ self.emotion_classifier = pipeline(
242
+ "text-classification",
243
+ model=self.emotion_model,
244
+ tokenizer=self.emotion_tokenizer,
245
+ top_k=None # Returns scores for all labels
246
+ )
247
+ print("Emotion classification model loaded successfully!")
248
+ except Exception as e:
249
+ print(f"Error loading emotion classification model: {e}")
250
+ # Fallback to a dummy classifier for demo purposes
251
+ self.emotion_classifier = lambda text: [[{"label": "neutral", "score": 1.0}]]
252
+
253
+ # Initialize response generation model (or use fallback)
254
+ print(f"Loading response generation model: {self.response_model_id}")
255
+ try:
256
+ self.response_model = AutoModelForCausalLM.from_pretrained(
257
+ self.response_model_id,
258
+ torch_dtype=torch.float16,
259
+ device_map="auto"
260
+ )
261
+ self.response_tokenizer = AutoTokenizer.from_pretrained(self.response_model_id)
262
+
263
+ self.response_generator = pipeline(
264
+ "text-generation",
265
+ model=self.response_model,
266
+ tokenizer=self.response_tokenizer,
267
+ do_sample=True,
268
+ top_p=0.92,
269
+ top_k=50,
270
+ temperature=0.7,
271
+ max_new_tokens=100
272
+ )
273
+ print("Response generation model loaded successfully!")
274
+ except Exception as e:
275
+ print(f"Using fallback response generation. Reason: {e}")
276
+ self.response_generator = self.fallback_response_generator
277
+
278
+ def fallback_response_generator(self, prompt, **kwargs):
279
+ """Fallback response generator using templates"""
280
+ # Try to extract emotion from the prompt
281
+ emotion_match = re.search(r"emotion: (\w+)", prompt.lower())
282
+ if emotion_match:
283
+ emotion = emotion_match.group(1)
284
+ else:
285
+ emotion = "neutral"
286
+
287
+ # Default user name
288
+ user_name = "friend"
289
+ name_match = re.search(r"Your friend \((.*?)\)", prompt.lower())
290
+ if name_match:
291
+ user_name = name_match.group(1)
292
+
293
+ # Extract user message
294
+ message_match = re.search(r"message: \"(.*?)\"", prompt)
295
+ user_message = message_match.group(1) if message_match else ""
296
+
297
+ # Generate response using fallback method
298
+ response = self.natural_fallback_response(user_message, emotion, user_name)
299
+
300
+ # Format as if coming from the pipeline
301
+ return [{"generated_text": response}]
302
+
303
+ def natural_fallback_response(self, user_message, primary_emotion, user_name):
304
+ """Conversational fallback responses that sound like a supportive friend"""
305
+ # Define emotion categories
306
+ sad_emotions = ["sadness", "disappointment", "grief", "remorse"]
307
+ fear_emotions = ["fear", "nervousness", "anxiety"]
308
+ anger_emotions = ["anger", "annoyance", "disapproval", "disgust"]
309
+ joy_emotions = ["joy", "admiration", "amusement", "excitement", "optimism",
310
+ "gratitude", "pride", "love", "relief"]
311
+
312
+ # Multi-stage response templates - more natural and varied
313
+ if primary_emotion in joy_emotions:
314
+ responses = [
315
+ f"That's awesome, {user_name}! What made you feel that way?",
316
+ f"I'm so glad to hear that! Tell me more about it?",
317
+ f"That's great news! What else is going on with you lately?"
318
+ ]
319
+ elif primary_emotion in sad_emotions:
320
+ responses = [
321
+ f"I'm sorry to hear that, {user_name}. Want to talk about what happened?",
322
+ f"That sounds rough. What's been going on?",
323
+ f"Ugh, that's tough. How are you handling it?"
324
+ ]
325
+ elif primary_emotion in anger_emotions:
326
+ responses = [
327
+ f"That sounds really frustrating. What happened?",
328
+ f"Oh no, that would upset me too. Want to vent about it?",
329
+ f"I can see why you'd be upset about that. What are you thinking of doing?"
330
+ ]
331
+ elif primary_emotion in fear_emotions:
332
+ responses = [
333
+ f"That sounds scary, {user_name}. What's got you worried?",
334
+ f"I can imagine that would be stressful. What's on your mind about it?",
335
+ f"I get feeling anxious about that. What's the biggest concern for you?"
336
+ ]
337
+ else: # neutral emotions
338
+ responses = [
339
+ f"What's been on your mind lately, {user_name}?",
340
+ f"How's everything else going with you?",
341
+ f"Tell me more about what's going on in your life these days."
342
+ ]
343
+
344
+ return random.choice(responses)
345
+
346
+ def classify_text(self, text):
347
+ """Classify text and return emotion data"""
348
+ try:
349
+ results = self.emotion_classifier(text)
350
+
351
+ # Sort emotions by score in descending order
352
+ sorted_emotions = sorted(results[0], key=lambda x: x['score'], reverse=True)
353
+
354
+ # Process emotions above threshold
355
+ detected_emotions = []
356
+ for emotion in sorted_emotions:
357
+ # Map numerical label to emotion name
358
+ try:
359
+ label_id = int(emotion['label'].split('_')[-1]) if '_' in emotion['label'] else int(emotion['label'])
360
+ if 0 <= label_id < len(EMOTION_LABELS):
361
+ emotion_name = EMOTION_LABELS[label_id]
362
+ else:
363
+ emotion_name = emotion['label']
364
+ except (ValueError, IndexError):
365
+ emotion_name = emotion['label']
366
+
367
+ score = emotion['score']
368
+
369
+ if score >= self.confidence_threshold:
370
+ detected_emotions.append({"emotion": emotion_name, "score": score})
371
+
372
+ # If no emotions detected above threshold, add neutral
373
+ if not detected_emotions:
374
+ detected_emotions.append({"emotion": "neutral", "score": 1.0})
375
+
376
+ return detected_emotions
377
+ except Exception as e:
378
+ print(f"Error during classification: {e}")
379
+ # Return neutral as fallback
380
+ return [{"emotion": "neutral", "score": 1.0}]
381
+
382
+ def format_emotion_text(self, emotion_data):
383
+ """Create a simple emotion text display"""
384
+ if not emotion_data:
385
+ return ""
386
+
387
+ # Define emotion emojis
388
+ emotion_emojis = {
389
+ "joy": "😊", "admiration": "🤩", "amusement": "😄", "approval": "👍",
390
+ "excitement": "🎉", "gratitude": "🙏", "love": "❤️", "optimism": "🌟",
391
+ "pride": "🦚", "relief": "😌", "sadness": "😢", "disappointment": "😞",
392
+ "grief": "💔", "remorse": "😔", "embarrassment": "😳", "anger": "😠",
393
+ "annoyance": "😤", "disapproval": "👎", "disgust": "🤢", "fear": "😨",
394
+ "nervousness": "😰", "surprise": "😲", "confusion": "😕", "curiosity": "🤔",
395
+ "neutral": "😐", "realization": "💡", "desire": "✨"
396
+ }
397
+
398
+ # Format the primary emotion
399
+ primary = emotion_data[0]["emotion"]
400
+ emoji = emotion_emojis.get(primary, "😐")
401
+ score = emotion_data[0]["score"]
402
+
403
+ return f"Detected: {emoji} {primary.capitalize()} ({score:.2f})"
404
+
405
+ def generate_response(self, user_message, emotion_data):
406
+ """Generate a response based on the user's message and detected emotions"""
407
+ # Get the primary emotion with context awareness
408
+ primary_emotion = emotion_data[0]["emotion"] if emotion_data else "neutral"
409
+
410
+ # Get recent conversation history for context
411
+ recent_exchanges = self.context.get_recent_messages(6)
412
+ conversation_history = ""
413
+ for msg in recent_exchanges:
414
+ role = "Friend" if msg["role"] == "user" else self.context.bot_name
415
+ conversation_history += f"{role}: {msg['text']}\n"
416
+
417
+ # Check if this is a greeting
418
+ is_greeting = any(greeting in user_message.lower() for greeting in ["hi", "hello", "hey", "greetings"])
419
+ is_question_about_bot = "how are you" in user_message.lower() or any(q in user_message.lower() for q in ["what can you do", "who are you", "what are you", "your purpose"])
420
+
421
+ # Handle special cases
422
+ if is_greeting:
423
+ if len(self.context.conversation_history) <= 4: # First greeting exchange
424
+ return f"Hi! I'm {self.context.bot_name}. It's nice to meet you. How are you feeling today?"
425
+ else:
426
+ return f"Hey! Good to chat with you again. What's been going on with you?"
427
+
428
+ elif is_question_about_bot:
429
+ return f"I'm doing well, thanks for asking! I'm {self.context.bot_name}, here as a friend to chat whenever you need someone to talk to. What's on your mind today?"
430
+
431
+ # Create a more conversational prompt based on emotion
432
+ system_instruction = f"""You are {self.context.bot_name}, having a natural conversation with your friend. You should respond in a casual, warm way like a supportive friend would - not like a therapist or clinical chatbot.
433
+
434
+ Your friend seems to be feeling {primary_emotion}. In your response:
435
+ 1. Be genuinely empathetic but natural - like how a real friend would respond
436
+ 2. Keep your response short (1-3 sentences) and conversational
437
+ 3. Don't use phrases like "I understand" or "I'm here for you" too much - vary your language
438
+ 4. Use casual language, contractions (don't instead of do not), and occasional sentence fragments
439
+ 5. Don't sound formulaic or overly positive - be authentic
440
+ 6. Keep the same emotional tone throughout your response
441
+ 7. Don't explain what you're doing or add meta-commentary
442
+ 8. DON'T address them by name multiple times or at the end of sentences - it sounds unnatural
443
+ 9. Don't end with "Let me know what you'd prefer" or similar phrases
444
+
445
+ Recent conversation:
446
+ {conversation_history}
447
+
448
+ Your friend's message: "{user_message}"
449
+ Current emotion: {primary_emotion}
450
+
451
+ Respond naturally as a supportive friend (without using their name more than once if at all):"""
452
+
453
+ try:
454
+ # Generate the response
455
+ generated = self.response_generator(
456
+ system_instruction,
457
+ max_new_tokens=100,
458
+ do_sample=True,
459
+ temperature=0.8,
460
+ top_p=0.92,
461
+ top_k=50,
462
+ )
463
+
464
+ # Extract the generated text
465
+ if isinstance(generated, list):
466
+ response_text = generated[0].get('generated_text', '')
467
+ else:
468
+ response_text = generated.get('generated_text', '')
469
+
470
+ # Clean up the response - extract only the actual response without system prompt
471
+ if "[/INST]" in response_text:
472
+ parts = response_text.split("[/INST]")
473
+ if len(parts) > 1:
474
+ response_text = parts[1].strip()
475
+
476
+ # If we're still getting the system instruction, try an alternative approach
477
+ if "Your friend seems to be feeling" in response_text:
478
+ # Try to extract just the bot's response using pattern matching
479
+ match = re.search(r'Respond naturally as a supportive friend.*?:\s*(.*?)$', response_text, re.DOTALL)
480
+ if match:
481
+ response_text = match.group(1).strip()
482
+ else:
483
+ # If that fails, try another approach - take text after the last numbered instruction
484
+ match = re.search(r'9\.\s+[^\n]+\s*(.*?)$', response_text, re.DOTALL)
485
+ if match:
486
+ response_text = match.group(1).strip()
487
+ else:
488
+ # Last resort: pick a fallback response based on emotion
489
+ response_text = self.natural_fallback_response(user_message, primary_emotion, self.context.user_name or "friend")
490
+
491
+ # Remove any model-specific markers
492
+ response_text = response_text.replace("<s>", "").replace("</s>", "")
493
+
494
+ # Remove any internal notes or debugging info that might appear
495
+ if "Note:" in response_text:
496
+ response_text = response_text.split("Note:")[0].strip()
497
+
498
+ # Remove any metadata or system-like text
499
+ response_text = response_text.replace("Assistant:", "").replace(f"{self.context.bot_name}:", "").strip()
500
+
501
+ # Remove any quotation marks surrounding the response
502
+ response_text = response_text.strip('"').strip()
503
+
504
+ # Handle potential model halt mid-sentence
505
+ if response_text.endswith((".", "!", "?")):
506
+ pass # Response ends with proper punctuation
507
+ else:
508
+ # Try to find the last complete sentence
509
+ last_period = max(response_text.rfind("."), response_text.rfind("!"), response_text.rfind("?"))
510
+ if last_period > len(response_text) * 0.5: # If we've got at least half the response
511
+ response_text = response_text[:last_period+1]
512
+
513
+ # FINAL CHECK: If we still have parts of the system prompt, use fallback response
514
+ if any(phrase in response_text for phrase in ["Your friend seems to be feeling", "Keep your response short", "Be genuinely empathetic"]):
515
+ response_text = self.natural_fallback_response(user_message, primary_emotion, self.context.user_name or "friend")
516
+
517
+ return clean_response_text(response_text.strip(), self.context.user_name)
518
+
519
+ except Exception as e:
520
+ print(f"Error generating response: {e}")
521
+ return self.natural_fallback_response(user_message, primary_emotion, self.context.user_name or "friend")
522
+
523
+ def process_message(self, user_message, chatbot_history):
524
+ """Process a user message and return the chatbot response"""
525
+ # Initialize context if first message
526
+ if not self.context.conversation_history:
527
+ initial_greeting = f"Hi! I'm {self.context.bot_name}, your friendly emotional support chatbot. Who am I talking to today?"
528
+ self.context.add_message("bot", initial_greeting)
529
+ self.context.waiting_for_name = True
530
+ return [[None, initial_greeting]]
531
+
532
+ # Handle name collection if this is the first user message
533
+ if self.context.waiting_for_name and not self.context.introduced:
534
+ common_greetings = ["hi", "hey", "hello", "greetings", "howdy", "hiya"]
535
+ words = user_message.strip().split()
536
+ potential_name = None
537
+
538
+ if "i'm" in user_message.lower() or "im" in user_message.lower():
539
+ parts = user_message.lower().replace("i'm", "im").split("im")
540
+ if len(parts) > 1 and parts[1].strip():
541
+ potential_name = parts[1].strip().split()[0].capitalize()
542
+
543
+ elif "my name is" in user_message.lower():
544
+ parts = user_message.lower().split("my name is")
545
+ if len(parts) > 1 and parts[1].strip():
546
+ potential_name = parts[1].strip().split()[0].capitalize()
547
+
548
+ elif len(words) <= 3 and words[0].lower() not in common_greetings:
549
+ potential_name = words[0].capitalize()
550
+
551
+ if potential_name:
552
+ potential_name = ''.join(c for c in potential_name if c.isalnum())
553
+
554
+ if potential_name and len(potential_name) >= 2 and potential_name.lower() not in common_greetings:
555
+ self.context.user_name = potential_name
556
+ greeting_response = f"Nice to meet you, {self.context.user_name}! How are you feeling today?"
557
+ else:
558
+ self.context.user_name = "friend"
559
+ greeting_response = "Nice to meet you! How are you feeling today?"
560
+
561
+ self.context.introduced = True
562
+ self.context.waiting_for_name = False
563
+ self.context.add_message("user", user_message)
564
+ self.context.add_message("bot", greeting_response)
565
+
566
+ return chatbot_history + [[user_message, greeting_response]]
567
+
568
+ # Regular message processing
569
+ emotion_data = self.classify_text(user_message)
570
+ self.context.add_message("user", user_message, emotion_data)
571
+
572
+ # Generate the response
573
+ bot_response = self.generate_response(user_message, emotion_data)
574
+ self.context.add_message("bot", bot_response)
575
+
576
+ # Create a simple emotion display text
577
+ emotion_text = self.format_emotion_text(emotion_data)
578
+
579
+ # Combine emotion text with bot response
580
+ full_response = f"{emotion_text}\n\n{bot_response}" if emotion_text else bot_response
581
+
582
+ # Return updated chat history in the expected tuple format
583
+ return chatbot_history + [[user_message, full_response]]
584
+
585
+ def reset_conversation(self):
586
+ """Reset the conversation context"""
587
+ self.context = ChatbotContext()
588
+ return []
589
+
590
+ # Create the Gradio interface
591
+ import gradio as gr
592
+ import os
593
+
594
+ def create_gradio_interface():
595
+ # Initialize the chatbot with default models
596
+ emotion_model_id = os.environ.get("EMOTION_MODEL_ID", "suku9/emotion-classifier")
597
+ response_model_id = os.environ.get("RESPONSE_MODEL_ID", "mistralai/Mistral-7B-Instruct-v0.2")
598
+
599
+ chatbot = GradioEmotionChatbot(emotion_model_id, response_model_id)
600
+
601
+ # Create the Gradio interface with dark mode styling
602
+ custom_css = """
603
+ /* Dark mode styling */
604
+ body {
605
+ background-color: #1a1a1a !important;
606
+ color: #e0e0e0 !important;
607
+ }
608
+
609
+ .gradio-container {
610
+ max-width: 1200px !important; /* Increased width for horizontal expansion */
611
+ margin: auto !important;
612
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
613
+ border-radius: 12px !important;
614
+ background: #2d2d2d !important;
615
+ padding: 20px !important;
616
+ }
617
+
618
+ /* Chatbot header styling */
619
+ .gradio-container h1, #header {
620
+ color: #a29bfe !important;
621
+ text-align: center !important;
622
+ font-size: 2.2rem !important; /* Larger font for better visibility */
623
+ margin-bottom: 8px !important;
624
+ font-weight: 700 !important;
625
+ text-shadow: 0 0 2px rgba(0,0,0,0.5) !important; /* Subtle shadow for clarity */
626
+ }
627
+
628
+ .gradio-container p, #subheader {
629
+ text-align: center !important;
630
+ color: #d0d0d0 !important; /* Lighter color for better contrast */
631
+ margin-bottom: 20px !important;
632
+ font-size: 1.1rem !important; /* Slightly larger font */
633
+ font-weight: 400 !important;
634
+ }
635
+
636
+ /* Chatbot window styling */
637
+ #chatbot {
638
+ height: 450px !important; /* Slightly taller for better content display */
639
+ overflow: auto !important;
640
+ border-radius: 10px !important;
641
+ background-color: #1a1a1a !important;
642
+ border: 1px solid #3d3d3d !important;
643
+ padding: 15px !important;
644
+ margin-bottom: 20px !important;
645
+ }
646
+
647
+ /* Force horizontal text orientation for ALL elements */
648
+ * {
649
+ writing-mode: horizontal-tb !important;
650
+ text-orientation: mixed !important;
651
+ direction: ltr !important;
652
+ }
653
+
654
+ /* Message styling */
655
+ .message {
656
+ border-radius: 12px !important;
657
+ padding: 10px 15px !important;
658
+ margin: 8px 0 !important;
659
+ max-width: 70% !important; /* Adjusted for horizontal expansion */
660
+ width: auto !important; /* Allow messages to expand */
661
+ word-break: break-word !important;
662
+ font-size: 1rem !important; /* Clearer font size */
663
+ line-height: 1.4 !important; /* Improved readability */
664
+ text-shadow: 0 0 1px rgba(0,0,0,0.3) !important; /* Subtle shadow for text clarity */
665
+ }
666
+
667
+ .user-message {
668
+ background-color: #4a5568 !important;
669
+ color: #f0f4f8 !important; /* Lighter text for contrast */
670
+ margin-left: auto !important; /* Align user messages to the right */
671
+ }
672
+
673
+ .bot-message {
674
+ background-color: #553c9a !important;
675
+ color: #ffffff !important; /* Pure white for maximum clarity */
676
+ margin-right: auto !important; /* Align bot messages to the left */
677
+ }
678
+
679
+ /* User input styling */
680
+ #user-input, .gradio-container textarea, .gradio-container input[type="text"] {
681
+ background-color: #2d2d2d !important;
682
+ color: #e0e0e0 !important;
683
+ border-radius: 20px !important;
684
+ padding: 12px 18px !important;
685
+ border: 1px solid #4a4a4a !important;
686
+ margin-bottom: 15px !important;
687
+ writing-mode: horizontal-tb !important;
688
+ text-orientation: mixed !important;
689
+ direction: ltr !important;
690
+ width: 100% !important;
691
+ min-height: 50px !important;
692
+ height: auto !important;
693
+ resize: none !important;
694
+ font-size: 1rem !important; /* Clearer font size */
695
+ }
696
+
697
+ /* Force text orientation for any text inputs */
698
+ .cm-editor, .cm-scroller, .cm-content, .cm-line {
699
+ writing-mode: horizontal-tb !important;
700
+ text-orientation: mixed !important;
701
+ }
702
+
703
+ /* Ensure row is horizontal */
704
+ .gradio-row {
705
+ flex-direction: row !important;
706
+ gap: 10px !important; /* Add spacing between elements */
707
+ }
708
+
709
+ /* Fix for chat bubbles */
710
+ .chat, .chat > div, .chat > div > div, .chat-msg, .chat-msg > div, .chat-msg-content {
711
+ writing-mode: horizontal-tb !important;
712
+ text-orientation: mixed !important;
713
+ }
714
+
715
+ /* Apply horizontal text to all text elements in chatbot */
716
+ .prose, .prose p, .prose span, .text-input-with-enter {
717
+ writing-mode: horizontal-tb !important;
718
+ text-orientation: mixed !important;
719
+ direction: ltr !important;
720
+ }
721
+
722
+ /* Target the specific user bubble on the right side */
723
+ .gradio-chatbot > div > div {
724
+ writing-mode: horizontal-tb !important;
725
+ text-orientation: mixed !important;
726
+ direction: ltr !important;
727
+ }
728
+
729
+ /* Target any text inside chatbot bubbles */
730
+ .gradio-chatbot * {
731
+ writing-mode: horizontal-tb !important;
732
+ text-orientation: mixed !important;
733
+ direction: ltr !important;
734
+ }
735
+
736
+ /* AVATAR AND USERNAME FIXES */
737
+ .avatar, .avatar-container, .avatar-image, .user-avatar, .bot-avatar {
738
+ writing-mode: horizontal-tb !important;
739
+ text-orientation: mixed !important;
740
+ direction: ltr !important;
741
+ }
742
+
743
+ /* Fix for specific containers */
744
+ [class*="message"], [class*="bubble"], [class*="avatar"], [class*="chat"] {
745
+ writing-mode: horizontal-tb !important;
746
+ text-orientation: mixed !important;
747
+ direction: ltr !important;
748
+ }
749
+
750
+ /* Button styling */
751
+ .send-btn, .clear-btn {
752
+ background-color: #6c5ce7 !important;
753
+ color: #ffffff !important;
754
+ border: none !important;
755
+ border-radius: 20px !important;
756
+ padding: 10px 20px !important;
757
+ font-weight: 600 !important;
758
+ cursor: pointer !important;
759
+ transition: all 0.3s ease !important;
760
+ font-size: 1rem !important;
761
+ }
762
+
763
+ .send-btn:hover, .clear-btn:hover {
764
+ background-color: #5649c1 !important;
765
+ transform: translateY(-1px) !important;
766
+ }
767
+
768
+ .clear-btn {
769
+ background-color: #e74c3c !important;
770
+ }
771
+
772
+ .clear-btn:hover {
773
+ background-color: #c0392b !important;
774
+ }
775
+
776
+ /* Hide footer */
777
+ footer {
778
+ display: none !important;
779
+ }
780
+
781
+ /* Fix scrollbar */
782
+ ::-webkit-scrollbar {
783
+ width: 8px;
784
+ background-color: #1a1a1a;
785
+ }
786
+
787
+ ::-webkit-scrollbar-thumb {
788
+ background-color: #4a4a4a;
789
+ border-radius: 4px;
790
+ }
791
+ """
792
+
793
+ with gr.Blocks(css=custom_css) as demo:
794
+ gr.Markdown("# EmotionChat", elem_id="header")
795
+ gr.Markdown("A supportive chatbot that understands how you feel", elem_id="subheader")
796
+
797
+ # Chat interface with improved styling
798
+ chatbot_interface = gr.Chatbot(
799
+ elem_id="chatbot",
800
+ show_label=False,
801
+ height=450,
802
+ avatar_images=["https://em-content.zobj.net/source/microsoft-teams/363/bust-in-silhouette_1f464.png",
803
+ "https://em-content.zobj.net/source/microsoft-teams/363/robot_1f916.png"],
804
+ )
805
+
806
+ # Input and button row with better styling
807
+ with gr.Row():
808
+ user_input = gr.Textbox(
809
+ placeholder="Type your message here...",
810
+ show_label=False,
811
+ container=False,
812
+ scale=8,
813
+ elem_id="user-input",
814
+ lines=1,
815
+ max_lines=1,
816
+ rtl=False
817
+ )
818
+ submit_btn = gr.Button("Send", scale=2, elem_classes="send-btn")
819
+
820
+ # New conversation button
821
+ clear_btn = gr.Button("New Conversation", elem_classes="clear-btn")
822
+
823
+ # Set up the event handlers
824
+ submit_btn.click(
825
+ chatbot.process_message,
826
+ inputs=[user_input, chatbot_interface],
827
+ outputs=[chatbot_interface],
828
+ ).then(
829
+ lambda: "", # Clear the input box after sending
830
+ None,
831
+ [user_input],
832
+ )
833
+
834
+ user_input.submit(
835
+ chatbot.process_message,
836
+ inputs=[user_input, chatbot_interface],
837
+ outputs=[chatbot_interface],
838
+ ).then(
839
+ lambda: "", # Clear
840
+
841
+ the input box after sending
842
+ None,
843
+ [user_input],
844
+ )
845
+
846
+ clear_btn.click(
847
+ chatbot.reset_conversation,
848
+ inputs=None,
849
+ outputs=[chatbot_interface],
850
+ )
851
+
852
+ return demo
853
+
854
+ if __name__ == "__main__":
855
+ demo = create_gradio_interface()
856
  demo.launch(debug=True, share=True)