File size: 18,210 Bytes
c8b77b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
"""
Empathy Training Module for MangoMAS Local

This module implements specialized training for empathy and emotional intelligence,
adapted from the AWS backup system for local training.
"""

import json
import logging
import os
import random
import string
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig

# Module-level logger used by the dataset and training-module classes in this file.
logger = logging.getLogger(__name__)


class EmpathyDataset(Dataset):
    """Dataset for training empathy and emotional intelligence capabilities.

    Reads a JSONL file in which every record provides ``user_message``,
    ``emotional_state`` and ``empathetic_response`` (``emotional_cues`` and
    ``context`` are optional). Each record is rendered as one training string
    and tokenized to a fixed ``max_length``.
    """

    # Records lacking any of these keys are dropped at load time.
    _REQUIRED_FIELDS = ("user_message", "emotional_state", "empathetic_response")

    def __init__(self, data_path: str, tokenizer, max_length: int = 768):
        """
        Initialize the empathy dataset.

        Args:
            data_path: Path to the empathy data file (one JSON object per line)
            tokenizer: Tokenizer for text processing
            max_length: Maximum sequence length (inputs are padded/truncated to it)
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)

        logger.info(f"Loaded empathy dataset with {len(self.data)} examples")

    def _load_data(self, data_path: str) -> List[Dict]:
        """
        Load and validate empathy records from a JSONL file.

        Blank lines are ignored. Lines that are not valid JSON are logged and
        skipped. Records that are not JSON objects, or that lack a required
        field, are skipped and reported in a single summary warning.

        Args:
            data_path: Path to the JSONL file.

        Returns:
            The list of valid record dicts, in file order.

        Raises:
            OSError: If the file cannot be opened.
        """
        data = []
        incomplete = 0
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) are not malformed records.
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.warning(f"Skipping invalid empathy data: {e}")
                    continue
                # Non-object JSON (lists, numbers, strings) cannot carry the fields;
                # `in` on a string would do substring matching instead.
                if isinstance(item, dict) and all(
                    field in item for field in self._REQUIRED_FIELDS
                ):
                    data.append(item)
                else:
                    incomplete += 1
        if incomplete:
            logger.warning(
                f"Skipped {incomplete} empathy records missing required fields "
                f"{list(self._REQUIRED_FIELDS)}"
            )
        return data

    def __len__(self) -> int:
        """Return the number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Get a tokenized training example.

        Returns:
            Dict holding 1-D ``input_ids``, ``attention_mask`` and ``labels``
            tensors of length ``max_length``, plus the raw ``user_message``,
            ``emotional_state`` and ``empathetic_response`` strings. Padding
            positions in ``labels`` are set to -100 so they are ignored by the loss.
        """
        item = self.data[idx]

        user_message = item["user_message"]
        emotional_state = item["emotional_state"]
        empathetic_response = item["empathetic_response"]

        # Optional fields
        emotional_cues = item.get("emotional_cues", [])
        context = item.get("context", "")

        # Render the example with explicit empathy markers.
        parts = [
            f"User: {user_message}\n\n",
            f"Emotional State: {emotional_state}\n",
        ]
        if emotional_cues:
            parts.append("Emotional Cues:\n")
            parts.extend(
                f"{number}. {cue}\n"
                for number, cue in enumerate(emotional_cues, start=1)
            )
            parts.append("\n")
        if context:
            parts.append(f"Context: {context}\n\n")
        parts.append(f"Empathetic Response: {empathetic_response}")
        text = "".join(parts)

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # squeeze(0) drops only the batch dimension; a bare squeeze() would also
        # collapse a length-1 sequence into a 0-d tensor.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Padding must not contribute to the LM loss: -100 is the ignore_index
        # used by nn.CrossEntropyLoss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "user_message": user_message,
            "emotional_state": emotional_state,
            "empathetic_response": empathetic_response,
        }


class EmpathyEvaluator:
    """Heuristic, keyword-based evaluator for empathy and emotional intelligence.

    Scores each generated response on four axes in [0, 1] and averages them
    over the evaluated examples.
    """

    # Words ignored when measuring overlap between the user message and the response.
    _STOPWORDS = frozenset({"i", "me", "my", "mine", "am", "was", "the", "a", "an"})

    def __init__(self, tokenizer):
        """
        Initialize the empathy evaluator.

        Args:
            tokenizer: Tokenizer for text processing
        """
        self.tokenizer = tokenizer
        self.metrics = {
            "emotional_recognition": 0.0,
            "empathetic_language": 0.0,
            "supportive_tone": 0.0,
            "personalization": 0.0,
        }

        # Empathetic language markers (matched case-insensitively)
        self.empathetic_phrases = [
            "understand",
            "feel",
            "appreciate",
            "recognize",
            "acknowledge",
            "must be",
            "sounds like",
            "seems like",
            "I hear you",
            "that's difficult",
            "that's challenging",
            "I'm sorry",
            "thank you for sharing",
            "I can imagine",
        ]

        # Emotional state categories
        self.emotional_states = {
            "positive": [
                "happy",
                "excited",
                "grateful",
                "proud",
                "hopeful",
                "inspired",
            ],
            "negative": [
                "sad",
                "angry",
                "frustrated",
                "anxious",
                "disappointed",
                "overwhelmed",
            ],
            "neutral": [
                "confused",
                "uncertain",
                "curious",
                "surprised",
                "contemplative",
            ],
        }

    def evaluate(
        self, model, eval_dataset: "EmpathyDataset", max_examples: int = 50
    ) -> Dict[str, float]:
        """
        Evaluate empathy capabilities on the provided dataset.

        The model is switched to eval mode for generation and restored to
        training mode afterwards if it was training on entry.

        Args:
            model: The model to evaluate (must expose ``parameters()`` and ``generate``)
            eval_dataset: Indexable dataset of empathy examples
            max_examples: Evaluate at most this many leading examples (for efficiency)

        Returns:
            A new dict of averaged evaluation metrics (all 0.0 for an empty dataset)
        """
        was_training = getattr(model, "training", False)
        model.eval()
        device = next(model.parameters()).device

        # Reset metrics
        for key in self.metrics:
            self.metrics[key] = 0.0

        total_examples = min(len(eval_dataset), max_examples)

        try:
            with torch.no_grad():
                for idx in range(total_examples):
                    example = eval_dataset[idx]
                    user_message = example["user_message"]
                    expected_emotional_state = example["emotional_state"]

                    # Generate a response without revealing the emotional state.
                    prompt = f"User: {user_message}\n\nProvide an empathetic response:"
                    encoding = self.tokenizer(prompt, return_tensors="pt")
                    input_ids = encoding["input_ids"].to(device)

                    generate_kwargs = {
                        "max_length": 256,
                        "temperature": 0.7,
                        "num_return_sequences": 1,
                    }
                    attention_mask = encoding.get("attention_mask")
                    if attention_mask is not None:
                        generate_kwargs["attention_mask"] = attention_mask.to(device)

                    generated_ids = model.generate(input_ids, **generate_kwargs)

                    # Decoder-only models echo the prompt before the continuation.
                    # Scoring the echoed prompt would credit the model for the
                    # user's own words, so drop it when present.
                    sequence = generated_ids[0]
                    prompt_len = input_ids.shape[-1]
                    if sequence.shape[0] >= prompt_len and torch.equal(
                        sequence[:prompt_len], input_ids[0]
                    ):
                        sequence = sequence[prompt_len:]

                    generated_text = self.tokenizer.decode(
                        sequence, skip_special_tokens=True
                    )

                    # Evaluate empathy quality
                    self._evaluate_empathy(
                        user_message=user_message,
                        expected_emotional_state=expected_emotional_state,
                        expected_response=example["empathetic_response"],
                        generated_response=generated_text,
                    )
        finally:
            if was_training:
                model.train()

        # Calculate averages (guard against an empty dataset)
        if total_examples:
            for key in self.metrics:
                self.metrics[key] /= total_examples

        # Return a copy so later evaluations do not mutate the caller's result.
        return dict(self.metrics)

    @staticmethod
    def _terms(text: str) -> set:
        """Return lower-cased whitespace tokens of *text* with surrounding punctuation stripped."""
        terms = {word.strip(string.punctuation) for word in text.lower().split()}
        terms.discard("")
        return terms

    def _evaluate_empathy(
        self,
        user_message: str,
        expected_emotional_state: str,
        expected_response: str,
        generated_response: str,
    ) -> None:
        """
        Accumulate empathy scores for one example into ``self.metrics``.

        Args:
            user_message: The user's message
            expected_emotional_state: Expected identified emotional state
            expected_response: Expected empathetic response (currently unused)
            generated_response: The response generated by the model
        """
        response = generated_response.lower()
        expected_state = expected_emotional_state.lower()

        # 1. Emotional recognition - does the response use words from the
        #    expected emotion's category?
        emotional_category = None
        for category, emotions in self.emotional_states.items():
            if any(emotion in expected_state for emotion in emotions):
                emotional_category = category
                break

        if emotional_category:
            emotion_words = self.emotional_states[emotional_category]
            emotion_recognition = any(word in response for word in emotion_words)
            self.metrics["emotional_recognition"] += 1.0 if emotion_recognition else 0.0
        else:
            # Default partial score if we couldn't categorize
            self.metrics["emotional_recognition"] += 0.5

        # 2. Empathetic language - phrases are lower-cased too, otherwise
        #    entries such as "I hear you" could never match the lowered response.
        empathy_phrase_count = sum(
            1 for phrase in self.empathetic_phrases if phrase.lower() in response
        )
        self.metrics["empathetic_language"] += min(1.0, empathy_phrase_count / 2)

        # 3. Supportive tone - simplified check for supportive language
        supportive_score = 0.0
        if "here for you" in response or "support" in response:
            supportive_score += 0.5
        if "help" in response or "advice" in response:
            supportive_score += 0.3
        if any(
            phrase in response
            for phrase in ["let me know", "is there anything", "can i"]
        ):
            supportive_score += 0.2
        self.metrics["supportive_tone"] += min(1.0, supportive_score)

        # 4. Personalization - overlap of non-stopword terms with the user message;
        #    punctuation is stripped so "job." and "job" count as the same term.
        user_specific_terms = self._terms(user_message) - self._STOPWORDS
        generated_terms = self._terms(generated_response)
        specific_term_overlap = len(user_specific_terms & generated_terms)
        self.metrics["personalization"] += min(1.0, specific_term_overlap / 3)


class EmpathyTrainingModule(SpecializedTrainingModule):
    """Specialized training module for empathy and emotional intelligence capabilities."""

    # Softening temperature applied to teacher and student logits in the KL term.
    distillation_temperature: float = 2.0
    # Weight of the hard-label cross-entropy term; the KL term gets (1 - weight).
    ce_loss_weight: float = 0.5

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the empathy training module.

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing
        """
        super().__init__(config, tokenizer)

        # Initialize empathy-specific components
        self.data_path = config.data_path or "data/processed/empathy_train.jsonl"
        self.evaluator = EmpathyEvaluator(tokenizer)

        # Empathy-specific loss; -100 targets (padding) are ignored.
        self.empathy_loss = nn.CrossEntropyLoss(ignore_index=-100)

        # Training metrics
        self.metrics = {
            "empathy_loss": 0.0,
            "emotion_recognition_rate": 0.0,
            "empathetic_language_score": 0.0,
        }

        logger.info("Initialized empathy training module")

    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for empathy training.

        Batches that already carry ``user_message`` / ``emotional_state`` /
        ``empathetic_response`` are empathy-specific and need no processing.
        Detecting emotional content in general conversation batches is not
        implemented yet, so every batch is currently returned unchanged.

        Args:
            batch: The input batch from the dataloader

        Returns:
            The batch, unchanged
        """
        return batch

    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the empathy-specific distillation loss.

        The loss is ``ce_loss_weight * CE(student, labels) +
        (1 - ce_loss_weight) * T^2 * KL(teacher_T || student_T)``. Both terms
        are averaged per non-ignored target token. Targets equal to -100, or at
        padded positions according to ``batch["attention_mask"]`` when present,
        are excluded from both terms.

        Args:
            student_outputs: Outputs from the student model (``.logits`` or a logits tensor)
            teacher_outputs: Outputs from the teacher model (``.logits`` or a logits tensor)
            batch: The processed input batch; must contain ``labels``

        Returns:
            Empathy-specific scalar loss tensor
        """
        student_logits = getattr(student_outputs, "logits", student_outputs)
        teacher_logits = getattr(teacher_outputs, "logits", teacher_outputs)

        # Causal-LM alignment: logits at position t predict token t + 1.
        student_logits = student_logits[:, :-1, :]
        # The teacher is a fixed target; never backpropagate into it.
        teacher_logits = teacher_logits[:, :-1, :].detach()
        target_ids = batch["labels"][:, 1:]

        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            # Ignore padded targets even if labels were not masked upstream.
            target_ids = target_ids.masked_fill(attention_mask[:, 1:] == 0, -100)

        flat_student = student_logits.reshape(-1, student_logits.size(-1))
        flat_teacher = teacher_logits.reshape(-1, teacher_logits.size(-1))
        flat_targets = target_ids.reshape(-1)
        valid = flat_targets != -100

        if not bool(valid.any()):
            # Nothing to learn from (CE would be NaN); keep the graph connected.
            loss = flat_student.sum() * 0.0
        else:
            temperature = self.distillation_temperature
            # KL on (num_valid_tokens, vocab): "batchmean" then averages per token.
            # On the original 3-D tensors it divided by batch size only, so the
            # term grew with sequence length and swamped the CE term.
            kl_loss = F.kl_div(
                F.log_softmax(flat_student[valid] / temperature, dim=-1),
                F.softmax(flat_teacher[valid] / temperature, dim=-1),
                reduction="batchmean",
            ) * (temperature**2)

            # Cross-entropy against labels (ignore_index=-100 skips masked targets)
            ce_loss = self.empathy_loss(flat_student, flat_targets)

            loss = self.ce_loss_weight * ce_loss + (1.0 - self.ce_loss_weight) * kl_loss

        # Update metrics
        self.metrics["empathy_loss"] = loss.item()

        return loss

    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to empathy training.

        Returns:
            A copy of the metric names and values
        """
        return dict(self.metrics)

    def generate_synthetic_empathy_data(
        self, output_path: str, num_samples: int = 1000, seed: Optional[int] = None
    ) -> None:
        """
        Generate synthetic empathy training data as JSONL.

        Each line is a copy of one of a few fixed templates, chosen at random.
        Each copy carries a ``metadata`` dict with ``generated``, ``timestamp``
        (UTC ISO-8601 time of this run) and ``device`` (CUDA device 0 name or
        ``"CPU"``).

        Args:
            output_path: Path to save the generated data; parent directories are created
            num_samples: Number of samples to generate
            seed: Optional seed for a private RNG; when None the global
                ``random`` module state is used
        """
        # This is a simplified implementation based on the AWS backup
        # In a full implementation, this would be much more sophisticated

        empathy_templates = [
            {
                "user_message": "I just lost my job after 5 years at the company. I'm feeling really lost and worried about how I'll pay my bills.",
                "emotional_state": "anxious and worried",
                "emotional_cues": [
                    "Lost job unexpectedly",
                    "Financial insecurity",
                    "Identity/purpose uncertainty",
                    "Use of 'lost' to describe feelings",
                ],
                "empathetic_response": "I'm really sorry to hear about losing your job, especially after dedicating 5 years to the company. That's incredibly difficult. It's completely understandable to feel lost and worried about your financial situation right now. Sudden changes like this can be overwhelming. Would it help to talk through some immediate steps you might take, or would you prefer to just share how you're processing this right now?",
                "context": "Professional setback, financial stress",
            },
            {
                "user_message": "My dog passed away yesterday. He was with me for 12 years and I don't know how to handle this emptiness in my home.",
                "emotional_state": "grief and sadness",
                "emotional_cues": [
                    "Loss of longtime companion",
                    "Reference to 'emptiness'",
                    "Uncertainty about coping",
                    "Recent timing (yesterday)",
                ],
                "empathetic_response": "I'm so sorry about the loss of your dog. After 12 years together, they're truly family, and that leaves such a profound emptiness. The love you shared for so many years speaks to what a caring person you are and what a special connection you had. There's no timeline for grief, especially for a companion who was part of your daily life for so long. Would it help to talk about some of your favorite memories of them?",
                "context": "Pet loss, grief",
            },
            {
                "user_message": "I just got accepted to my dream university! I've been working toward this for years and can't believe it finally happened.",
                "emotional_state": "excited and proud",
                "emotional_cues": [
                    "Exclamation mark showing excitement",
                    "Achievement of long-term goal",
                    "Expression of disbelief at success",
                    "Years of effort mentioned",
                ],
                "empathetic_response": "Congratulations! That's absolutely wonderful news about your acceptance to your dream university! Years of hard work and dedication have paid off, and you have every right to feel excited and proud of this achievement. It's such a special feeling when a goal you've worked toward for so long finally becomes reality. I'd love to hear more about the university and what you're most looking forward to as you start this new chapter!",
                "context": "Academic achievement, celebration",
            },
        ]

        rng = random if seed is None else random.Random(seed)

        # Loop-invariant metadata, computed once per run.
        device = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
        generated_at = datetime.now(timezone.utc).isoformat()

        output_data = []
        for _ in range(num_samples):
            # Shallow copy so adding metadata does not touch the shared template.
            # NOTE: records are exact copies of a template; no text variation yet.
            record = dict(rng.choice(empathy_templates))
            record["metadata"] = {
                "generated": True,
                "timestamp": generated_at,
                "device": device,
            }
            output_data.append(record)

        # Save to file; a bare filename has no directory component to create.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(item) + "\n" for item in output_data)

        logger.info(
            f"Generated {len(output_data)} synthetic empathy examples at {output_path}"
        )