File size: 5,459 Bytes
cff1e0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
Empathy EX (Exploration) Evaluator

Measures the exploration component of empathy in therapeutic responses.
"""
from typing import List
import torch
from transformers import AutoModel, AutoTokenizer

from evaluators.base import Evaluator
from evaluators.registry import register_evaluator
from custom_types import Utterance, EvaluationResult
from utils.evaluation_helpers import create_categorical_score, create_utterance_result


@register_evaluator(
    "empathy_ex",
    label="Empathy EX (Exploration)",
    description="Measures exploration component of empathy",
    category="Empathy"
)
class EmpathyEXEvaluator(Evaluator):
    """Evaluator for Empathy Exploration (EX).

    Scores each therapist response against the most recent seeker
    (patient) utterance using a pretrained bi-encoder classifier that
    outputs one of three empathy-exploration levels: Low, Medium, High.
    The underlying model is loaded lazily on first prediction.
    """

    METRIC_NAME = "empathy_ex"
    MODEL_NAME = "RyanDDD/empathy-mental-health-reddit-EX"
    LABELS = ["Low", "Medium", "High"]
    # Speaker roles (compared case-insensitively) treated as the
    # responder (scored) vs. the seeker (context). Hoisted to class
    # level so they are not rebuilt on every loop iteration.
    _RESPONDER_ROLES = frozenset({"therapist", "counselor", "provider"})
    _SEEKER_ROLES = frozenset({"patient", "seeker", "client"})

    def __init__(self, api_key: str = None):
        """Initialize the evaluator.

        Args:
            api_key: Unused; accepted for interface compatibility with
                evaluators that require API access.
        """
        super().__init__()
        self.tokenizer = None
        self.model = None
        self._model_loaded = False

    def _load_model(self):
        """Load the model and tokenizer (lazy loading).

        Idempotent: returns immediately if the model is already loaded.

        Raises:
            RuntimeError: If the model or tokenizer cannot be loaded;
                the original exception is chained as the cause.
        """
        if self._model_loaded:
            return
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
            # trust_remote_code is required: this checkpoint ships a
            # custom bi-encoder architecture (separate SP/RP inputs).
            self.model = AutoModel.from_pretrained(
                self.MODEL_NAME,
                trust_remote_code=True,
                torch_dtype=torch.float32
            )
            # Ensure model is on CPU (or move to appropriate device)
            self.model = self.model.to('cpu')
            self.model.eval()
            self._model_loaded = True
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to load {self.MODEL_NAME}: {e}") from e

    def _predict_single(self, seeker_text: str, response_text: str) -> dict:
        """
        Predict empathy level for a single seeker-response pair.

        Args:
            seeker_text: The seeker's (patient's) utterance
            response_text: The response (therapist's) utterance

        Returns:
            Dict with label, confidence, and probabilities
        """
        # Lazy load model on first use
        self._load_model()

        # Tokenize seeker (SP) and response (RP) separately; the model
        # consumes them as two independent encoder inputs.
        encoded_sp = self.tokenizer(
            seeker_text,
            max_length=64,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        encoded_rp = self.tokenizer(
            response_text,
            max_length=64,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Ensure tensors are on the same device as model
        device = next(self.model.parameters()).device
        encoded_sp = {k: v.to(device) for k, v in encoded_sp.items()}
        encoded_rp = {k: v.to(device) for k, v in encoded_rp.items()}

        # Predict
        with torch.no_grad():
            outputs = self.model(
                input_ids_SP=encoded_sp['input_ids'],
                input_ids_RP=encoded_rp['input_ids'],
                attention_mask_SP=encoded_sp['attention_mask'],
                attention_mask_RP=encoded_rp['attention_mask']
            )
            logits_empathy = outputs[0]
            probs = torch.softmax(logits_empathy, dim=1)

        empathy_level = torch.argmax(logits_empathy, dim=1).item()
        probs_row = probs[0]
        confidence = probs_row[empathy_level].item()

        return {
            "label": self.LABELS[empathy_level],
            "confidence": confidence,
            # Derive keys from LABELS so the mapping cannot drift from
            # the class constant used for the predicted label.
            "probabilities": {
                label: probs_row[idx].item()
                for idx, label in enumerate(self.LABELS)
            }
        }

    def execute(self, conversation: List[Utterance], **kwargs) -> EvaluationResult:
        """
        Evaluate empathy EX for each therapist response in the conversation.

        Non-therapist utterances, and therapist utterances with no prior
        seeker context, receive an empty score dict so per-utterance
        alignment with the conversation is preserved.

        Args:
            conversation: List of utterances with 'speaker' and 'text'

        Returns:
            EvaluationResult with per-utterance scores
        """
        scores_per_utterance = []

        # Find seeker-response pairs
        for i, utt in enumerate(conversation):
            # Only evaluate therapist responses
            if utt["speaker"].lower() in self._RESPONDER_ROLES:
                # Find the most recent patient/seeker utterance
                seeker_text = ""
                for j in range(i - 1, -1, -1):
                    if conversation[j]["speaker"].lower() in self._SEEKER_ROLES:
                        seeker_text = conversation[j]["text"]
                        break

                # If we found a seeker utterance, evaluate
                if seeker_text:
                    prediction = self._predict_single(seeker_text, utt["text"])
                    scores_per_utterance.append({
                        "empathy_ex": create_categorical_score(
                            label=prediction["label"],
                            confidence=prediction["confidence"]
                        )
                    })
                else:
                    # No seeker context, skip evaluation
                    scores_per_utterance.append({})
            else:
                # Not a therapist utterance, skip
                scores_per_utterance.append({})

        return create_utterance_result(conversation, scores_per_utterance)