File size: 10,708 Bytes
b575114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import json
import base64
import io
import torch
import numpy as np
from typing import Dict, List, Any
import os
import sys

# Add the current directory to Python path so the custom modeling /
# feature-extraction modules shipped next to this handler can be imported.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

try:
    from modeling_emotion_av import EmotionAVModel, EmotionAVConfig
    from feature_extraction_emotion_av import EmotionAVFeatureExtractor
except ImportError as e:
    print(f"Warning: Could not import custom modules: {e}")

# Always attempt the transformers fallback imports in their own guarded
# block. The handler's from_pretrained fallback paths reference
# AutoModel/AutoFeatureExtractor even when the custom modules imported
# successfully but fail to *load* — previously those names were only
# imported when the custom import itself failed, risking a NameError.
try:
    from transformers import AutoModel, AutoConfig, AutoFeatureExtractor
except ImportError as e:
    print(f"Warning: transformers not available: {e}")


class ExtendedEndpointHandler:
    """
    Extended handler that provides both HF-compatible output and detailed emotion analysis.

    By default ``__call__`` returns the standard HuggingFace
    audio-classification format (a list of ``{label, score}`` dicts).
    When the request parameters set ``extended_output`` or
    ``include_arousal_valence``, a single-element list containing the top
    emotion plus arousal/valence and the full distribution is returned.
    """

    def __init__(self, model_dir: str = ""):
        """
        Initialize the handler for the emotion-av model.

        Args:
            model_dir (str): Path to the model directory. Must contain a
                non-empty, valid ``config.json``.

        Raises:
            FileNotFoundError: If ``config.json`` is missing.
            ValueError: If ``config.json`` is empty.
            json.JSONDecodeError: If ``config.json`` is not valid JSON.
            Exception: Re-raised from model / feature-extractor loading.
        """
        try:
            print(f"Initializing extended handler with model_dir: {model_dir}")

            # Fail fast with a clear error if the config is missing or
            # corrupt, before attempting the (slower) model load.
            config_path = os.path.join(model_dir, "config.json")
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"Config file not found: {config_path}")

            with open(config_path, 'r', encoding='utf-8') as f:
                config_content = f.read().strip()
                if not config_content:
                    raise ValueError("Config file is empty")

                # Validate JSON before handing the directory to from_pretrained.
                config_data = json.loads(config_content)
                print(f"Successfully loaded config with keys: {list(config_data.keys())}")

            # Prefer the custom model class; fall back to AutoModel. The
            # fallback import is done locally because the module-level
            # `from transformers import ...` may never have run (it is
            # only executed when the custom modules fail to import).
            try:
                self.model = EmotionAVModel.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded EmotionAVModel")
            except Exception as e:
                print(f"Failed to load with EmotionAVModel: {e}")
                from transformers import AutoModel
                self.model = AutoModel.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded with AutoModel")

            # Same pattern for the feature extractor.
            try:
                self.feature_extractor = EmotionAVFeatureExtractor.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded EmotionAVFeatureExtractor")
            except Exception as e:
                print(f"Failed to load with EmotionAVFeatureExtractor: {e}")
                from transformers import AutoFeatureExtractor
                self.feature_extractor = AutoFeatureExtractor.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded with AutoFeatureExtractor")

            # Inference-only handler: switch off dropout/batch-norm updates.
            self.model.eval()
            print("Extended handler initialization completed successfully")

        except Exception as e:
            print(f"Error during extended handler initialization: {e}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle inference requests with both standard and extended formats.

        Args:
            data (Dict): Request payload. ``inputs`` holds the audio as a
                base64 string or a raw sample array; optional ``parameters``
                may set ``extended_output`` / ``include_arousal_valence``
                (bools) and ``sampling_rate`` (int, default 16000).

        Returns:
            List[Dict]: HF-compatible ``[{label, score}, ...]`` by default,
            or a single-element extended result when requested. Errors are
            reported as ``[{"error": ...}]`` rather than raised.
        """
        try:
            inputs = data.get("inputs", data)
            parameters = data.get("parameters", {})

            extended_output = parameters.get("extended_output", False)
            include_arousal_valence = parameters.get("include_arousal_valence", False)

            # Accept either base64-encoded audio bytes or a raw sample array.
            if isinstance(inputs, str):
                try:
                    audio_bytes = base64.b64decode(inputs)
                    audio_data = self._process_audio_bytes(audio_bytes)
                except Exception as e:
                    return [{"error": f"Failed to decode base64 audio: {str(e)}"}]
            elif isinstance(inputs, (list, np.ndarray)):
                audio_data = np.array(inputs, dtype=np.float32)
            else:
                return [{"error": "Invalid input format. Expected base64 string or audio array."}]

            features = self.feature_extractor(
                audio_data,
                sampling_rate=parameters.get("sampling_rate", 16000),
                return_tensors="pt"
            )

            with torch.no_grad():
                outputs = self.model(features["input_features"])

            # NOTE(review): assumes the model output exposes
            # `emotion_logits` and `arousal_valence` — the AutoModel
            # fallback may not; an AttributeError here is caught below.
            emotion_logits = outputs.emotion_logits
            arousal_valence = outputs.arousal_valence

            emotion_probs = torch.softmax(emotion_logits, dim=-1)

            # Model emits arousal/valence in [0, 1]; denormalize to [-1, 1].
            arousal = (arousal_valence[0, 0].item() * 2) - 1
            valence = (arousal_valence[0, 1].item() * 2) - 1

            if extended_output or include_arousal_valence:
                return self._format_extended_output(emotion_probs[0], arousal, valence)
            # HF-compatible format: Array<{label: string, score: number}>
            return self._format_standard_output(emotion_probs[0])

        except Exception as e:
            return [{"error": f"Inference failed: {str(e)}"}]

    def _ranked_emotions(self, emotion_probs: torch.Tensor) -> List[Dict[str, Any]]:
        """Return every emotion as a {label, score} dict, best score first."""
        probs_sorted, indices = torch.sort(emotion_probs, descending=True)
        return [
            {"label": self.model.config.id2label[index.item()], "score": prob.item()}
            for prob, index in zip(probs_sorted, indices)
        ]

    def _format_standard_output(self, emotion_probs: torch.Tensor) -> List[Dict[str, Any]]:
        """
        Format output in HuggingFace-compatible format.

        Args:
            emotion_probs: 1-D emotion probability tensor.

        Returns:
            List of {label, score} dictionaries, sorted by score descending.
        """
        return self._ranked_emotions(emotion_probs)

    def _format_extended_output(self, emotion_probs: torch.Tensor, arousal: float, valence: float) -> List[Dict[str, Any]]:
        """
        Format output with extended emotion information including arousal/valence.

        Args:
            emotion_probs: 1-D emotion probability tensor.
            arousal: Denormalized arousal value in [-1, 1].
            valence: Denormalized valence value in [-1, 1].

        Returns:
            Single-element list: the top emotion plus ``arousal``,
            ``valence``, a ranked ``all_emotions`` list, and the full
            label->probability ``emotion_distribution`` mapping.
        """
        predicted_id = torch.argmax(emotion_probs).item()
        confidence = emotion_probs.max().item()

        result = {
            "label": self.model.config.id2label[predicted_id],
            "score": confidence,
            "arousal": arousal,
            "valence": valence,
            "all_emotions": self._ranked_emotions(emotion_probs),
            "emotion_distribution": {
                self.model.config.id2label[j]: prob.item()
                for j, prob in enumerate(emotion_probs)
            },
        }

        return [result]

    def _process_audio_bytes(self, audio_bytes: bytes) -> np.ndarray:
        """
        Decode raw audio bytes into a mono float32 numpy array.

        Tries soundfile first, then librosa. NOTE(review): the soundfile
        path returns samples at the file's native sample rate without
        resampling, while the librosa fallback resamples to 16 kHz —
        callers relying on a fixed rate should verify their input files.

        Args:
            audio_bytes (bytes): Encoded audio file contents (e.g. WAV).

        Returns:
            np.ndarray: Mono float32 samples.

        Raises:
            Exception: If both decoding backends fail.
        """
        try:
            import soundfile as sf

            audio_io = io.BytesIO(audio_bytes)
            audio_data, sample_rate = sf.read(audio_io)

            # Downmix multi-channel audio to mono by averaging channels.
            if audio_data.ndim > 1:
                audio_data = np.mean(audio_data, axis=1)

            return audio_data.astype(np.float32)

        except Exception:
            # soundfile failed (missing library or unsupported format):
            # fall back to librosa, which also forces 16 kHz mono.
            try:
                import librosa
                audio_io = io.BytesIO(audio_bytes)
                audio_data, sample_rate = librosa.load(audio_io, sr=16000, mono=True)
                return audio_data.astype(np.float32)
            except Exception as e2:
                # Chain the cause so the underlying failure stays visible.
                raise Exception(f"Failed to process audio: {str(e2)}") from e2


# For compatibility, provide the standard handler as well
class EndpointHandler(ExtendedEndpointHandler):
    """
    Standard handler that ensures HF compatibility by default.
    """

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle inference requests with HF-compatible output by default.

        Forces the standard ``{label, score}`` list format regardless of
        what the request's parameters ask for.
        """
        # Work on shallow copies so the caller's dicts are not mutated:
        # the previous implementation wrote the overrides back into the
        # caller-supplied `data`, a side effect visible to reused requests.
        parameters = dict(data.get("parameters", {}))
        parameters["extended_output"] = False
        parameters["include_arousal_valence"] = False

        payload = dict(data)
        payload["parameters"] = parameters

        return super().__call__(payload)