File size: 7,537 Bytes
b575114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import json
import base64
import io
import torch
import numpy as np
from typing import Dict, List, Any
import os
import sys

# Add the current directory to Python path for local imports
# (the custom emotion-av modules live next to this handler file).
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

try:
    # Project-local model / config / feature-extractor implementations.
    from modeling_emotion_av import EmotionAVModel, EmotionAVConfig
    from feature_extraction_emotion_av import EmotionAVFeatureExtractor
except ImportError as e:
    print(f"Warning: Could not import custom modules: {e}")
    # Fallback imports
    # NOTE(review): these Auto* names are bound ONLY on this failure path.
    # If the custom import above succeeds, the AutoModel/AutoFeatureExtractor
    # fallbacks inside EndpointHandler.__init__ would hit a NameError when
    # referenced — confirm that is the intended behavior.
    from transformers import AutoModel, AutoConfig, AutoFeatureExtractor


class EndpointHandler:
    """Inference endpoint handler for the emotion-av audio-emotion model.

    Loads the custom model and feature extractor from ``model_dir`` (falling
    back to the generic transformers Auto* classes on failure) and serves
    audio-emotion classification requests in the Hugging Face pipeline
    format: ``[{"label": str, "score": float}, ...]``.
    """

    def __init__(self, model_dir: str = ""):
        """
        Initialize the handler for the emotion-av model.

        Args:
            model_dir (str): Path to the model directory.

        Raises:
            FileNotFoundError: If ``config.json`` is missing from ``model_dir``.
            ValueError: If ``config.json`` is empty.
            json.JSONDecodeError: If ``config.json`` is not valid JSON.
            Exception: Re-raised from any model/extractor loading failure
                after both the custom class and the Auto* fallback fail.
        """
        try:
            print(f"Initializing handler with model_dir: {model_dir}")

            # Fail fast with a clear message if the config is missing,
            # empty, or malformed before attempting any model load.
            config_path = os.path.join(model_dir, "config.json")
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"Config file not found: {config_path}")

            with open(config_path, 'r', encoding='utf-8') as f:
                config_content = f.read().strip()
                if not config_content:
                    raise ValueError("Config file is empty")

                # Validate JSON before handing the directory to from_pretrained.
                config_data = json.loads(config_content)
                print(f"Successfully loaded config with keys: {list(config_data.keys())}")

            # Load the custom model; fall back to AutoModel on any failure.
            try:
                self.model = EmotionAVModel.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded EmotionAVModel")
            except Exception as e:
                print(f"Failed to load with EmotionAVModel: {e}")
                # Import locally: the module-level fallback import only binds
                # AutoModel when the custom-module import itself failed, so
                # without this line this branch would raise NameError instead
                # of actually falling back.
                from transformers import AutoModel
                self.model = AutoModel.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded with AutoModel")

            # Same pattern for the feature extractor.
            try:
                self.feature_extractor = EmotionAVFeatureExtractor.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded EmotionAVFeatureExtractor")
            except Exception as e:
                print(f"Failed to load with EmotionAVFeatureExtractor: {e}")
                # Local import for the same NameError reason as above.
                from transformers import AutoFeatureExtractor
                self.feature_extractor = AutoFeatureExtractor.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded with AutoFeatureExtractor")

            # Inference mode: disables dropout / batch-norm updates.
            self.model.eval()
            print("Handler initialization completed successfully")

        except Exception as e:
            print(f"Error during handler initialization: {e}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle an inference request.

        Args:
            data (Dict): Request payload. ``data["inputs"]`` may be a base64
                string of encoded audio or a raw sample array; if the key is
                absent, ``data`` itself is treated as the inputs.
                ``data["parameters"]["sampling_rate"]`` (default 16000) is
                forwarded to the feature extractor.

        Returns:
            List[Dict]: All emotions sorted by descending confidence as
            ``{"label": str, "score": float}`` dicts, or a single-element
            list ``[{"error": str}]`` on any failure.
        """
        try:
            inputs = data.get("inputs", data)
            parameters = data.get("parameters", {})

            # Accept either base64-encoded audio bytes or a raw sample array.
            if isinstance(inputs, str):
                try:
                    audio_bytes = base64.b64decode(inputs)
                    audio_data = self._process_audio_bytes(audio_bytes)
                except Exception as e:
                    return [{"error": f"Failed to decode base64 audio: {str(e)}"}]
            elif isinstance(inputs, (list, np.ndarray)):
                audio_data = np.array(inputs, dtype=np.float32)
            else:
                return [{"error": "Invalid input format. Expected base64 string or audio array."}]

            # NOTE(review): the decoded audio's native sample rate is not
            # checked here; the extractor is simply told the requested
            # sampling_rate — confirm inputs are actually at that rate.
            features = self.feature_extractor(
                audio_data,
                sampling_rate=parameters.get("sampling_rate", 16000),
                return_tensors="pt"
            )

            with torch.no_grad():
                outputs = self.model(features["input_features"])

            # outputs.arousal_valence is deliberately NOT returned: the
            # response strictly follows the HF classification format of
            # {label, score} entries only.
            emotion_probs = torch.softmax(outputs.emotion_logits, dim=-1)

            # All emotions, sorted by descending confidence.
            probs_sorted, indices = torch.sort(emotion_probs[0], descending=True)
            results = []
            for score, idx in zip(probs_sorted.tolist(), indices.tolist()):
                results.append({
                    "label": self.model.config.id2label[idx],
                    "score": score
                })

            return results

        except Exception as e:
            return [{"error": f"Inference failed: {str(e)}"}]

    def _process_audio_bytes(self, audio_bytes: bytes) -> np.ndarray:
        """
        Decode raw audio container bytes into a mono float32 waveform.

        Tries soundfile first; on any failure falls back to librosa (which
        also resamples to 16 kHz). NOTE(review): the soundfile path keeps
        the file's native sample rate without resampling — confirm callers
        pass a matching "sampling_rate" parameter.

        Args:
            audio_bytes (bytes): Encoded audio bytes (e.g. wav/flac).

        Returns:
            np.ndarray: 1-D float32 audio samples.

        Raises:
            Exception: If both decoders fail (chains the librosa error).
        """
        try:
            import soundfile as sf

            # soundfile reads most common containers directly from memory.
            audio_data, _sample_rate = sf.read(io.BytesIO(audio_bytes))

            # Down-mix multi-channel audio to mono by averaging channels.
            if len(audio_data.shape) > 1:
                audio_data = np.mean(audio_data, axis=1)

            return audio_data.astype(np.float32)

        except Exception:
            # soundfile unavailable or format unsupported: try librosa,
            # chaining its error for a useful traceback if it also fails.
            try:
                import librosa
                audio_data, _sample_rate = librosa.load(
                    io.BytesIO(audio_bytes), sr=16000, mono=True
                )
                return audio_data.astype(np.float32)
            except Exception as e2:
                raise Exception(f"Failed to process audio: {str(e2)}") from e2