suku9 commited on
Commit
60922b6
·
verified ·
1 Parent(s): b89d70e

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +848 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,848 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import json
5
+ import re
6
+ import random
7
+ from transformers import (
8
+ AutoTokenizer,
9
+ AutoModelForSequenceClassification,
10
+ AutoModelForCausalLM,
11
+ pipeline,
12
+ )
13
+ import datetime
14
+ import sys
15
+
16
+ # Define emotion label mapping
17
+ EMOTION_LABELS = [
18
+ "admiration", "amusement", "anger", "annoyance", "approval", "caring", "confusion",
19
+ "curiosity", "desire", "disappointment", "disapproval", "disgust", "embarrassment",
20
+ "excitement", "fear", "gratitude", "grief", "joy", "love", "nervousness", "optimism",
21
+ "pride", "realization", "relief", "remorse", "sadness", "surprise", "neutral"
22
+ ]
23
+
24
+ # Map similar emotions to our response categories
25
+ EMOTION_MAPPING = {
26
+ "admiration": "joy",
27
+ "amusement": "joy",
28
+ "anger": "anger",
29
+ "annoyance": "anger",
30
+ "approval": "joy",
31
+ "caring": "joy",
32
+ "confusion": "neutral",
33
+ "curiosity": "neutral",
34
+ "desire": "neutral",
35
+ "disappointment": "sadness",
36
+ "disapproval": "anger",
37
+ "disgust": "disgust",
38
+ "embarrassment": "sadness",
39
+ "excitement": "joy",
40
+ "fear": "fear",
41
+ "gratitude": "joy",
42
+ "grief": "sadness",
43
+ "joy": "joy",
44
+ "love": "joy",
45
+ "nervousness": "fear",
46
+ "optimism": "joy",
47
+ "pride": "joy",
48
+ "realization": "neutral",
49
+ "relief": "joy",
50
+ "remorse": "sadness",
51
+ "sadness": "sadness",
52
+ "surprise": "surprise",
53
+ "neutral": "neutral"
54
+ }
55
+
56
+ class ChatbotContext:
57
+ """Class to maintain conversation context and history"""
58
+ def __init__(self):
59
+ self.conversation_history = []
60
+ self.detected_emotions = []
61
+ self.user_feedback = []
62
+ self.current_session_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
63
+ # Track emotional progression for therapeutic conversation flow
64
+ self.conversation_stage = "initial" # initial, middle, advanced
65
+ self.emotion_trajectory = [] # track emotion changes over time
66
+ self.consecutive_positive_count = 0
67
+ self.consecutive_negative_count = 0
68
+ # Add user name tracking
69
+ self.user_name = None
70
+ self.bot_name = "Mira" # Friendly, easy to remember name
71
+ self.introduced = False
72
+ self.waiting_for_name = False
73
+
74
+ def add_message(self, role, text, emotions=None):
75
+ """Add a message to the conversation history"""
76
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
77
+ message = {
78
+ "role": role,
79
+ "text": text,
80
+ "timestamp": timestamp
81
+ }
82
+ if emotions and role == "user":
83
+ message["emotions"] = emotions
84
+ self.detected_emotions.append(emotions)
85
+ self._update_emotional_trajectory(emotions)
86
+
87
+ self.conversation_history.append(message)
88
+ return message
89
+
90
+ def _update_emotional_trajectory(self, emotions):
91
+ """Update the emotional trajectory based on newly detected emotions"""
92
+ # Get the primary emotion
93
+ primary_emotion = emotions[0]["emotion"] if emotions else "neutral"
94
+
95
+ # Add to trajectory
96
+ self.emotion_trajectory.append(primary_emotion)
97
+
98
+ # Classify as positive, negative, or neutral
99
+ positive_emotions = ["joy", "admiration", "amusement", "excitement",
100
+ "optimism", "gratitude", "pride", "love", "relief"]
101
+ negative_emotions = ["sadness", "anger", "fear", "disgust", "disappointment",
102
+ "annoyance", "disapproval", "embarrassment", "grief",
103
+ "remorse", "nervousness"]
104
+
105
+ if primary_emotion in positive_emotions:
106
+ self.consecutive_positive_count += 1
107
+ self.consecutive_negative_count = 0
108
+ elif primary_emotion in negative_emotions:
109
+ self.consecutive_negative_count += 1
110
+ self.consecutive_positive_count = 0
111
+ else: # neutral or other
112
+ # Don't reset counters for neutral emotions to maintain progress
113
+ pass
114
+
115
+ # Update conversation stage based on trajectory and message count
116
+ msg_count = len(self.conversation_history) // 2 # Count actual exchanges (user/bot pairs)
117
+ if msg_count <= 1: # First real exchange
118
+ self.conversation_stage = "initial"
119
+ elif msg_count <= 3: # First few exchanges
120
+ self.conversation_stage = "middle"
121
+ else: # More established conversation
122
+ self.conversation_stage = "advanced"
123
+
124
+ def get_emotional_state(self):
125
+ """Get the current emotional state of the conversation"""
126
+ if len(self.emotion_trajectory) < 2:
127
+ return "unknown"
128
+
129
+ # Get the last few emotions (with 'neutral' having less weight)
130
+ recent_emotions = self.emotion_trajectory[-3:]
131
+ positive_emotions = ["joy", "admiration", "amusement", "excitement",
132
+ "optimism", "gratitude", "pride", "love", "relief"]
133
+ negative_emotions = ["sadness", "anger", "fear", "disgust", "disappointment",
134
+ "annoyance", "disapproval", "embarrassment", "grief",
135
+ "remorse", "nervousness"]
136
+
137
+ # Count positive and negative emotions
138
+ pos_count = sum(1 for e in recent_emotions if e in positive_emotions)
139
+ neg_count = sum(1 for e in recent_emotions if e in negative_emotions)
140
+
141
+ if self.consecutive_positive_count >= 2:
142
+ return "positive"
143
+ elif self.consecutive_negative_count >= 2:
144
+ return "negative"
145
+ elif pos_count > neg_count:
146
+ return "improving"
147
+ elif neg_count > pos_count:
148
+ return "declining"
149
+ else:
150
+ return "neutral"
151
+
152
+ def add_feedback(self, rating, comments=None):
153
+ """Add user feedback about the chatbot's response"""
154
+ feedback = {
155
+ "rating": rating,
156
+ "comments": comments,
157
+ "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
158
+ }
159
+ self.user_feedback.append(feedback)
160
+ return feedback
161
+
162
+ def get_recent_messages(self, count=5):
163
+ """Get the most recent messages from the conversation history"""
164
+ return self.conversation_history[-count:] if len(self.conversation_history) >= count else self.conversation_history
165
+
166
+ def save_conversation(self, filepath=None):
167
+ """Save the conversation history to a JSON file"""
168
+ if not filepath:
169
+ os.makedirs("./conversations", exist_ok=True)
170
+ filepath = f"./conversations/conversation_{self.current_session_id}.json"
171
+
172
+ data = {
173
+ "conversation_history": self.conversation_history,
174
+ "user_feedback": self.user_feedback,
175
+ "emotion_trajectory": self.emotion_trajectory,
176
+ "session_id": self.current_session_id,
177
+ "start_time": self.conversation_history[0]["timestamp"] if self.conversation_history else None,
178
+ "end_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
179
+ }
180
+
181
+ with open(filepath, 'w') as f:
182
+ json.dump(data, f, indent=2)
183
+ print(f"Conversation saved to {filepath}")
184
+ return filepath
185
+
186
+ def clean_response_text(response, user_name):
187
+ """Clean up the response text to make it more natural"""
188
+ # Remove repeated name mentions
189
+ if user_name:
190
+ # Replace patterns like "Hey user_name," or "Hi user_name,"
191
+ response = re.sub(r'^(Hey|Hi|Hello)\s+' + re.escape(user_name) + r',?\s+', '', response, flags=re.IGNORECASE)
192
+
193
+ # Replace duplicate name mentions
194
+ pattern = re.escape(user_name) + r',?\s+.*' + re.escape(user_name)
195
+ if re.search(pattern, response, re.IGNORECASE):
196
+ response = re.sub(r',?\s+' + re.escape(user_name) + r'([,.!?])', r'\1', response, flags=re.IGNORECASE)
197
+
198
+ # Remove name at the end of sentences if it appears earlier
199
+ if response.count(user_name) > 1:
200
+ response = re.sub(r',\s+' + re.escape(user_name) + r'([.!?])(\s|$)', r'\1\2', response, flags=re.IGNORECASE)
201
+
202
+ # Remove phrases that feel repetitive or formulaic
203
+ phrases_to_remove = [
204
+ r"let me know what you'd prefer,?\s+",
205
+ r"i'm here to listen,?\s+",
206
+ r"let me know if there's anything else,?\s+",
207
+ r"i'm all ears,?\s+",
208
+ r"i'm here for you,?\s+"
209
+ ]
210
+
211
+ for phrase in phrases_to_remove:
212
+ response = re.sub(phrase, "", response, flags=re.IGNORECASE)
213
+
214
+ # Fix multiple punctuation
215
+ response = re.sub(r'([.!?])\s+\1', r'\1', response)
216
+
217
+ # Fix missing space after punctuation
218
+ response = re.sub(r'([.!?])([A-Za-z])', r'\1 \2', response)
219
+
220
+ # Make sure first letter is capitalized
221
+ if response and len(response) > 0:
222
+ response = response[0].upper() + response[1:]
223
+
224
+ return response.strip()
225
+
226
+ class GradioEmotionChatbot:
227
+ def __init__(self, emotion_model_id, response_model_id=None, confidence_threshold=0.3):
228
+ self.emotion_model_id = emotion_model_id
229
+ self.response_model_id = response_model_id or "mistralai/Mistral-7B-Instruct-v0.2"
230
+ self.confidence_threshold = confidence_threshold
231
+ self.context = ChatbotContext()
232
+ self.initialize_models()
233
+
234
+ def initialize_models(self):
235
+ # Initialize emotion classification model
236
+ print(f"Loading emotion classification model: {self.emotion_model_id}")
237
+ try:
238
+ self.emotion_model = AutoModelForSequenceClassification.from_pretrained(self.emotion_model_id)
239
+ self.emotion_tokenizer = AutoTokenizer.from_pretrained(self.emotion_model_id)
240
+
241
+ self.emotion_classifier = pipeline(
242
+ "text-classification",
243
+ model=self.emotion_model,
244
+ tokenizer=self.emotion_tokenizer,
245
+ top_k=None # Returns scores for all labels
246
+ )
247
+ print("Emotion classification model loaded successfully!")
248
+ except Exception as e:
249
+ print(f"Error loading emotion classification model: {e}")
250
+ # Fallback to a dummy classifier for demo purposes
251
+ self.emotion_classifier = lambda text: [[{"label": "neutral", "score": 1.0}]]
252
+
253
+ # Initialize response generation model (or use fallback)
254
+ print(f"Loading response generation model: {self.response_model_id}")
255
+ try:
256
+ self.response_model = AutoModelForCausalLM.from_pretrained(
257
+ self.response_model_id,
258
+ torch_dtype=torch.float16,
259
+ device_map="auto"
260
+ )
261
+ self.response_tokenizer = AutoTokenizer.from_pretrained(self.response_model_id)
262
+
263
+ self.response_generator = pipeline(
264
+ "text-generation",
265
+ model=self.response_model,
266
+ tokenizer=self.response_tokenizer,
267
+ do_sample=True,
268
+ top_p=0.92,
269
+ top_k=50,
270
+ temperature=0.7,
271
+ max_new_tokens=100
272
+ )
273
+ print("Response generation model loaded successfully!")
274
+ except Exception as e:
275
+ print(f"Using fallback response generation. Reason: {e}")
276
+ self.response_generator = self.fallback_response_generator
277
+
278
+ def fallback_response_generator(self, prompt, **kwargs):
279
+ """Fallback response generator using templates"""
280
+ # Try to extract emotion from the prompt
281
+ emotion_match = re.search(r"emotion: (\w+)", prompt.lower())
282
+ if emotion_match:
283
+ emotion = emotion_match.group(1)
284
+ else:
285
+ emotion = "neutral"
286
+
287
+ # Default user name
288
+ user_name = "friend"
289
+ name_match = re.search(r"Your friend \((.*?)\)", prompt.lower())
290
+ if name_match:
291
+ user_name = name_match.group(1)
292
+
293
+ # Extract user message
294
+ message_match = re.search(r"message: \"(.*?)\"", prompt)
295
+ user_message = message_match.group(1) if message_match else ""
296
+
297
+ # Generate response using fallback method
298
+ response = self.natural_fallback_response(user_message, emotion, user_name)
299
+
300
+ # Format as if coming from the pipeline
301
+ return [{"generated_text": response}]
302
+
303
+ def natural_fallback_response(self, user_message, primary_emotion, user_name):
304
+ """Conversational fallback responses that sound like a supportive friend"""
305
+ # Define emotion categories
306
+ sad_emotions = ["sadness", "disappointment", "grief", "remorse"]
307
+ fear_emotions = ["fear", "nervousness", "anxiety"]
308
+ anger_emotions = ["anger", "annoyance", "disapproval", "disgust"]
309
+ joy_emotions = ["joy", "admiration", "amusement", "excitement", "optimism",
310
+ "gratitude", "pride", "love", "relief"]
311
+
312
+ # Multi-stage response templates - more natural and varied
313
+ if primary_emotion in joy_emotions:
314
+ responses = [
315
+ f"That's awesome, {user_name}! What made you feel that way?",
316
+ f"I'm so glad to hear that! Tell me more about it?",
317
+ f"That's great news! What else is going on with you lately?"
318
+ ]
319
+ elif primary_emotion in sad_emotions:
320
+ responses = [
321
+ f"I'm sorry to hear that, {user_name}. Want to talk about what happened?",
322
+ f"That sounds rough. What's been going on?",
323
+ f"Ugh, that's tough. How are you handling it?"
324
+ ]
325
+ elif primary_emotion in anger_emotions:
326
+ responses = [
327
+ f"That sounds really frustrating. What happened?",
328
+ f"Oh no, that would upset me too. Want to vent about it?",
329
+ f"I can see why you'd be upset about that. What are you thinking of doing?"
330
+ ]
331
+ elif primary_emotion in fear_emotions:
332
+ responses = [
333
+ f"That sounds scary, {user_name}. What's got you worried?",
334
+ f"I can imagine that would be stressful. What's on your mind about it?",
335
+ f"I get feeling anxious about that. What's the biggest concern for you?"
336
+ ]
337
+ else: # neutral emotions
338
+ responses = [
339
+ f"What's been on your mind lately, {user_name}?",
340
+ f"How's everything else going with you?",
341
+ f"Tell me more about what's going on in your life these days."
342
+ ]
343
+
344
+ return random.choice(responses)
345
+
346
+ def classify_text(self, text):
347
+ """Classify text and return emotion data"""
348
+ try:
349
+ results = self.emotion_classifier(text)
350
+
351
+ # Sort emotions by score in descending order
352
+ sorted_emotions = sorted(results[0], key=lambda x: x['score'], reverse=True)
353
+
354
+ # Process emotions above threshold
355
+ detected_emotions = []
356
+ for emotion in sorted_emotions:
357
+ # Map numerical label to emotion name
358
+ try:
359
+ label_id = int(emotion['label'].split('_')[-1]) if '_' in emotion['label'] else int(emotion['label'])
360
+ if 0 <= label_id < len(EMOTION_LABELS):
361
+ emotion_name = EMOTION_LABELS[label_id]
362
+ else:
363
+ emotion_name = emotion['label']
364
+ except (ValueError, IndexError):
365
+ emotion_name = emotion['label']
366
+
367
+ score = emotion['score']
368
+
369
+ if score >= self.confidence_threshold:
370
+ detected_emotions.append({"emotion": emotion_name, "score": score})
371
+
372
+ # If no emotions detected above threshold, add neutral
373
+ if not detected_emotions:
374
+ detected_emotions.append({"emotion": "neutral", "score": 1.0})
375
+
376
+ return detected_emotions
377
+ except Exception as e:
378
+ print(f"Error during classification: {e}")
379
+ # Return neutral as fallback
380
+ return [{"emotion": "neutral", "score": 1.0}]
381
+
382
+ def format_emotion_text(self, emotion_data):
383
+ """Create a simple emotion text display"""
384
+ if not emotion_data:
385
+ return ""
386
+
387
+ # Define emotion emojis
388
+ emotion_emojis = {
389
+ "joy": "😊", "admiration": "🤩", "amusement": "😄", "approval": "👍",
390
+ "excitement": "🎉", "gratitude": "🙏", "love": "❤️", "optimism": "🌟",
391
+ "pride": "🦚", "relief": "😌", "sadness": "😢", "disappointment": "😞",
392
+ "grief": "💔", "remorse": "😔", "embarrassment": "😳", "anger": "😠",
393
+ "annoyance": "😤", "disapproval": "👎", "disgust": "🤢", "fear": "😨",
394
+ "nervousness": "😰", "surprise": "😲", "confusion": "😕", "curiosity": "🤔",
395
+ "neutral": "😐", "realization": "💡", "desire": "✨"
396
+ }
397
+
398
+ # Format the primary emotion
399
+ primary = emotion_data[0]["emotion"]
400
+ emoji = emotion_emojis.get(primary, "😐")
401
+ score = emotion_data[0]["score"]
402
+
403
+ return f"Detected: {emoji} {primary.capitalize()} ({score:.2f})"
404
+
405
+ def generate_response(self, user_message, emotion_data):
406
+ """Generate a response based on the user's message and detected emotions"""
407
+ # Get the primary emotion with context awareness
408
+ primary_emotion = emotion_data[0]["emotion"] if emotion_data else "neutral"
409
+
410
+ # Get recent conversation history for context
411
+ recent_exchanges = self.context.get_recent_messages(6)
412
+ conversation_history = ""
413
+ for msg in recent_exchanges:
414
+ role = "Friend" if msg["role"] == "user" else self.context.bot_name
415
+ conversation_history += f"{role}: {msg['text']}\n"
416
+
417
+ # Check if this is a greeting
418
+ is_greeting = any(greeting in user_message.lower() for greeting in ["hi", "hello", "hey", "greetings"])
419
+ is_question_about_bot = "how are you" in user_message.lower() or any(q in user_message.lower() for q in ["what can you do", "who are you", "what are you", "your purpose"])
420
+
421
+ # Handle special cases
422
+ if is_greeting:
423
+ if len(self.context.conversation_history) <= 4: # First greeting exchange
424
+ return f"Hi! I'm {self.context.bot_name}. It's nice to meet you. How are you feeling today?"
425
+ else:
426
+ return f"Hey! Good to chat with you again. What's been going on with you?"
427
+
428
+ elif is_question_about_bot:
429
+ return f"I'm doing well, thanks for asking! I'm {self.context.bot_name}, here as a friend to chat whenever you need someone to talk to. What's on your mind today?"
430
+
431
+ # Create a more conversational prompt based on emotion
432
+ system_instruction = f"""You are {self.context.bot_name}, having a natural conversation with your friend. You should respond in a casual, warm way like a supportive friend would - not like a therapist or clinical chatbot.
433
+
434
+ Your friend seems to be feeling {primary_emotion}. In your response:
435
+ 1. Be genuinely empathetic but natural - like how a real friend would respond
436
+ 2. Keep your response short (1-3 sentences) and conversational
437
+ 3. Don't use phrases like "I understand" or "I'm here for you" too much - vary your language
438
+ 4. Use casual language, contractions (don't instead of do not), and occasional sentence fragments
439
+ 5. Don't sound formulaic or overly positive - be authentic
440
+ 6. Keep the same emotional tone throughout your response
441
+ 7. Don't explain what you're doing or add meta-commentary
442
+ 8. DON'T address them by name multiple times or at the end of sentences - it sounds unnatural
443
+ 9. Don't end with "Let me know what you'd prefer" or similar phrases
444
+
445
+ Recent conversation:
446
+ {conversation_history}
447
+
448
+ Your friend's message: "{user_message}"
449
+ Current emotion: {primary_emotion}
450
+
451
+ Respond naturally as a supportive friend (without using their name more than once if at all):"""
452
+
453
+ try:
454
+ # Generate the response
455
+ generated = self.response_generator(
456
+ system_instruction,
457
+ max_new_tokens=100,
458
+ do_sample=True,
459
+ temperature=0.8,
460
+ top_p=0.92,
461
+ top_k=50,
462
+ )
463
+
464
+ # Extract the generated text
465
+ if isinstance(generated, list):
466
+ response_text = generated[0].get('generated_text', '')
467
+ else:
468
+ response_text = generated.get('generated_text', '')
469
+
470
+ # Clean up the response - extract only the actual response without system prompt
471
+ if "[/INST]" in response_text:
472
+ parts = response_text.split("[/INST]")
473
+ if len(parts) > 1:
474
+ response_text = parts[1].strip()
475
+
476
+ # If we're still getting the system instruction, try an alternative approach
477
+ if "Your friend seems to be feeling" in response_text:
478
+ # Try to extract just the bot's response using pattern matching
479
+ match = re.search(r'Respond naturally as a supportive friend.*?:\s*(.*?)$', response_text, re.DOTALL)
480
+ if match:
481
+ response_text = match.group(1).strip()
482
+ else:
483
+ # If that fails, try another approach - take text after the last numbered instruction
484
+ match = re.search(r'9\.\s+[^\n]+\s*(.*?)$', response_text, re.DOTALL)
485
+ if match:
486
+ response_text = match.group(1).strip()
487
+ else:
488
+ # Last resort: pick a fallback response based on emotion
489
+ response_text = self.natural_fallback_response(user_message, primary_emotion, self.context.user_name or "friend")
490
+
491
+ # Remove any model-specific markers
492
+ response_text = response_text.replace("<s>", "").replace("</s>", "")
493
+
494
+ # Remove any internal notes or debugging info that might appear
495
+ if "Note:" in response_text:
496
+ response_text = response_text.split("Note:")[0].strip()
497
+
498
+ # Remove any metadata or system-like text
499
+ response_text = response_text.replace("Assistant:", "").replace(f"{self.context.bot_name}:", "").strip()
500
+
501
+ # Remove any quotation marks surrounding the response
502
+ response_text = response_text.strip('"').strip()
503
+
504
+ # Handle potential model halt mid-sentence
505
+ if response_text.endswith((".", "!", "?")):
506
+ pass # Response ends with proper punctuation
507
+ else:
508
+ # Try to find the last complete sentence
509
+ last_period = max(response_text.rfind("."), response_text.rfind("!"), response_text.rfind("?"))
510
+ if last_period > len(response_text) * 0.5: # If we've got at least half the response
511
+ response_text = response_text[:last_period+1]
512
+
513
+ # FINAL CHECK: If we still have parts of the system prompt, use fallback response
514
+ if any(phrase in response_text for phrase in ["Your friend seems to be feeling", "Keep your response short", "Be genuinely empathetic"]):
515
+ response_text = self.natural_fallback_response(user_message, primary_emotion, self.context.user_name or "friend")
516
+
517
+ return clean_response_text(response_text.strip(), self.context.user_name)
518
+
519
+ except Exception as e:
520
+ print(f"Error generating response: {e}")
521
+ return self.natural_fallback_response(user_message, primary_emotion, self.context.user_name or "friend")
522
+
523
+ def process_message(self, user_message, chatbot_history):
524
+ """Process a user message and return the chatbot response"""
525
+ # Initialize context if first message
526
+ if not self.context.conversation_history:
527
+ initial_greeting = f"Hi! I'm {self.context.bot_name}, your friendly emotional support chatbot. Who am I talking to today?"
528
+ self.context.add_message("bot", initial_greeting)
529
+ self.context.waiting_for_name = True
530
+ return [[None, initial_greeting]]
531
+
532
+ # Handle name collection if this is the first user message
533
+ if self.context.waiting_for_name and not self.context.introduced:
534
+ common_greetings = ["hi", "hey", "hello", "greetings", "howdy", "hiya"]
535
+ words = user_message.strip().split()
536
+ potential_name = None
537
+
538
+ if "i'm" in user_message.lower() or "im" in user_message.lower():
539
+ parts = user_message.lower().replace("i'm", "im").split("im")
540
+ if len(parts) > 1 and parts[1].strip():
541
+ potential_name = parts[1].strip().split()[0].capitalize()
542
+
543
+ elif "my name is" in user_message.lower():
544
+ parts = user_message.lower().split("my name is")
545
+ if len(parts) > 1 and parts[1].strip():
546
+ potential_name = parts[1].strip().split()[0].capitalize()
547
+
548
+ elif len(words) <= 3 and words[0].lower() not in common_greetings:
549
+ potential_name = words[0].capitalize()
550
+
551
+ if potential_name:
552
+ potential_name = ''.join(c for c in potential_name if c.isalnum())
553
+
554
+ if potential_name and len(potential_name) >= 2 and potential_name.lower() not in common_greetings:
555
+ self.context.user_name = potential_name
556
+ greeting_response = f"Nice to meet you, {self.context.user_name}! How are you feeling today?"
557
+ else:
558
+ self.context.user_name = "friend"
559
+ greeting_response = "Nice to meet you! How are you feeling today?"
560
+
561
+ self.context.introduced = True
562
+ self.context.waiting_for_name = False
563
+ self.context.add_message("user", user_message)
564
+ self.context.add_message("bot", greeting_response)
565
+
566
+ return chatbot_history + [[user_message, greeting_response]]
567
+
568
+ # Regular message processing
569
+ emotion_data = self.classify_text(user_message)
570
+ self.context.add_message("user", user_message, emotion_data)
571
+
572
+ # Generate the response
573
+ bot_response = self.generate_response(user_message, emotion_data)
574
+ self.context.add_message("bot", bot_response)
575
+
576
+ # Create a simple emotion display text
577
+ emotion_text = self.format_emotion_text(emotion_data)
578
+
579
+ # Combine emotion text with bot response
580
+ full_response = f"{emotion_text}\n\n{bot_response}" if emotion_text else bot_response
581
+
582
+ # Return updated chat history in the expected tuple format
583
+ return chatbot_history + [[user_message, full_response]]
584
+
585
+ def reset_conversation(self):
586
+ """Reset the conversation context"""
587
+ self.context = ChatbotContext()
588
+ return []
589
+
590
+ # Create the Gradio interface
591
+ def create_gradio_interface():
592
+ # Initialize the chatbot with default models
593
+ emotion_model_id = os.environ.get("EMOTION_MODEL_ID", "suku9/emotion-classifier")
594
+ response_model_id = os.environ.get("RESPONSE_MODEL_ID", "mistralai/Mistral-7B-Instruct-v0.2")
595
+
596
+ chatbot = GradioEmotionChatbot(emotion_model_id, response_model_id)
597
+
598
+ # Create the Gradio interface with dark mode styling
599
+ custom_css = """
600
+ /* Dark mode styling */
601
+ body {
602
+ background-color: #1a1a1a !important;
603
+ color: #e0e0e0 !important;
604
+ }
605
+
606
+ .gradio-container {
607
+ max-width: 750px !important;
608
+ margin: auto !important;
609
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
610
+ border-radius: 12px !important;
611
+ background: #2d2d2d !important;
612
+ padding: 15px !important;
613
+ }
614
+
615
+ /* Chatbot header styling */
616
+ .gradio-container h1, #header {
617
+ color: #a29bfe !important;
618
+ text-align: center !important;
619
+ font-size: 1.8rem !important;
620
+ margin-bottom: 5px !important;
621
+ font-weight: 600 !important;
622
+ }
623
+
624
+ .gradio-container p, #subheader {
625
+ text-align: center !important;
626
+ color: #b0b0b0 !important;
627
+ margin-bottom: 15px !important;
628
+ font-size: 0.9rem !important;
629
+ }
630
+
631
+ /* Chatbot window styling */
632
+ #chatbot {
633
+ height: 380px !important;
634
+ overflow: auto !important;
635
+ border-radius: 10px !important;
636
+ background-color: #1a1a1a !important;
637
+ border: 1px solid #3d3d3d !important;
638
+ padding: 10px !important;
639
+ margin-bottom: 15px !important;
640
+ }
641
+
642
+ /* Force horizontal text orientation for ALL elements */
643
+ * {
644
+ writing-mode: horizontal-tb !important;
645
+ text-orientation: mixed !important;
646
+ direction: ltr !important;
647
+ }
648
+
649
+ /* Message styling */
650
+ .message {
651
+ border-radius: 12px !important;
652
+ padding: 8px 12px !important;
653
+ margin: 5px 0 !important;
654
+ max-width: 85% !important;
655
+ width: 250px !important;
656
+ word-break: break-word !important;
657
+ writing-mode: horizontal-tb !important;
658
+ text-orientation: mixed !important;
659
+ direction: ltr !important;
660
+ }
661
+
662
+ .user-message {
663
+ background-color: #4a5568 !important;
664
+ color: #e2e8f0 !important;
665
+ writing-mode: horizontal-tb !important;
666
+ text-orientation: mixed !important;
667
+ }
668
+
669
+ .bot-message {
670
+ background-color: #553c9a !important;
671
+ color: #f8f9fa !important;
672
+ writing-mode: horizontal-tb !important;
673
+ text-orientation: mixed !important;
674
+ }
675
+
676
+ /* User input styling - FIX FOR VERTICAL TEXT */
677
+ #user-input, .gradio-container textarea, .gradio-container input[type="text"] {
678
+ background-color: #2d2d2d !important;
679
+ color: #e0e0e0 !important;
680
+ border-radius: 20px !important;
681
+ padding: 10px 15px !important;
682
+ border: 1px solid #3d3d3d !important;
683
+ margin-bottom: 10px !important;
684
+ writing-mode: horizontal-tb !important;
685
+ text-orientation: mixed !important;
686
+ direction: ltr !important;
687
+ width: 100% !important;
688
+ min-height: 45px !important;
689
+ height: auto !important;
690
+ resize: none !important;
691
+ }
692
+
693
+ /* Force text orientation for any text inputs */
694
+ .cm-editor, .cm-scroller, .cm-content, .cm-line {
695
+ writing-mode: horizontal-tb !important;
696
+ text-orientation: mixed !important;
697
+ }
698
+
699
+ /* Ensure row is horizontal */
700
+ .gradio-row {
701
+ flex-direction: row !important;
702
+ }
703
+
704
+ /* Fix for chat bubbles */
705
+ .chat, .chat > div, .chat > div > div, .chat-msg, .chat-msg > div, .chat-msg-content {
706
+ writing-mode: horizontal-tb !important;
707
+ text-orientation: mixed !important;
708
+ }
709
+
710
+ /* Apply horizontal text to all text elements in chatbot */
711
+ .prose, .prose p, .prose span, .text-input-with-enter {
712
+ writing-mode: horizontal-tb !important;
713
+ text-orientation: mixed !important;
714
+ direction: ltr !important;
715
+ }
716
+
717
+ /* Target the specific user bubble on the right side */
718
+ .gradio-chatbot > div > div {
719
+ writing-mode: horizontal-tb !important;
720
+ text-orientation: mixed !important;
721
+ direction: ltr !important;
722
+ }
723
+
724
+ /* Target any text inside chatbot bubbles */
725
+ .gradio-chatbot * {
726
+ writing-mode: horizontal-tb !important;
727
+ text-orientation: mixed !important;
728
+ direction: ltr !important;
729
+ }
730
+
731
+ /* AVATAR AND USERNAME FIXES */
732
+ .avatar, .avatar-container, .avatar-image, .user-avatar, .bot-avatar {
733
+ writing-mode: horizontal-tb !important;
734
+ text-orientation: mixed !important;
735
+ direction: ltr !important;
736
+ }
737
+
738
+ /* Fix for specific containers that might be causing issues */
739
+ [class*="message"], [class*="bubble"], [class*="avatar"], [class*="chat"] {
740
+ writing-mode: horizontal-tb !important;
741
+ text-orientation: mixed !important;
742
+ direction: ltr !important;
743
+ }
744
+
745
+ /* Button styling */
746
+ .send-btn, .clear-btn {
747
+ background-color: #6c5ce7 !important;
748
+ color: white !important;
749
+ border: none !important;
750
+ border-radius: 20px !important;
751
+ padding: 8px 16px !important;
752
+ font-weight: 600 !important;
753
+ cursor: pointer !important;
754
+ transition: all 0.3s ease !important;
755
+ }
756
+
757
+ .send-btn:hover, .clear-btn:hover {
758
+ background-color: #5649c1 !important;
759
+ transform: translateY(-1px) !important;
760
+ }
761
+
762
+ .clear-btn {
763
+ background-color: #e74c3c !important;
764
+ }
765
+
766
+ .clear-btn:hover {
767
+ background-color: #c0392b !important;
768
+ }
769
+
770
+ /* Hide footer */
771
+ footer {
772
+ display: none !important;
773
+ }
774
+
775
+ /* Fix scrollbar */
776
+ ::-webkit-scrollbar {
777
+ width: 6px;
778
+ background-color: #1a1a1a;
779
+ }
780
+
781
+ ::-webkit-scrollbar-thumb {
782
+ background-color: #4a4a4a;
783
+ border-radius: 3px;
784
+ }
785
+ """
786
+
787
+ with gr.Blocks(css=custom_css) as demo:
788
+ gr.Markdown("# EmotionChat", elem_id="header")
789
+ gr.Markdown("A supportive chatbot that understands how you feel", elem_id="subheader")
790
+
791
+ # Chat interface with improved styling
792
+ chatbot_interface = gr.Chatbot(
793
+ elem_id="chatbot",
794
+ show_label=False,
795
+ height=380,
796
+ avatar_images=["https://em-content.zobj.net/source/microsoft-teams/363/bust-in-silhouette_1f464.png",
797
+ "https://em-content.zobj.net/source/microsoft-teams/363/robot_1f916.png"],
798
+ )
799
+
800
+ # Input and button row with better styling
801
+ with gr.Row():
802
+ user_input = gr.Textbox(
803
+ placeholder="Type your message here...",
804
+ show_label=False,
805
+ container=False,
806
+ scale=9,
807
+ elem_id="user-input",
808
+ lines=1,
809
+ max_lines=1,
810
+ rtl=False
811
+ )
812
+ submit_btn = gr.Button("Send", scale=1, elem_classes="send-btn")
813
+
814
+ # New conversation button
815
+ clear_btn = gr.Button("New Conversation", elem_classes="clear-btn")
816
+
817
+ # Set up the event handlers
818
+ submit_btn.click(
819
+ chatbot.process_message,
820
+ inputs=[user_input, chatbot_interface],
821
+ outputs=[chatbot_interface],
822
+ ).then(
823
+ lambda: "", # Clear the input box after sending
824
+ None,
825
+ [user_input],
826
+ )
827
+
828
+ user_input.submit(
829
+ chatbot.process_message,
830
+ inputs=[user_input, chatbot_interface],
831
+ outputs=[chatbot_interface],
832
+ ).then(
833
+ lambda: "", # Clear the input box after sending
834
+ None,
835
+ [user_input],
836
+ )
837
+
838
+ clear_btn.click(
839
+ chatbot.reset_conversation,
840
+ inputs=None,
841
+ outputs=[chatbot_interface],
842
+ )
843
+
844
+ return demo
845
+
846
+ if __name__ == "__main__":
847
+ demo = create_gradio_interface()
848
+ demo.launch(debug=True, share=True)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ transformers>=4.36.0
4
+ numpy>=1.20.0
5
+ accelerate>=0.24.0
6
+ sentencepiece>=0.1.99
7
+ protobuf>=3.20.0