import base64
import io
import json
import os
import sys
from typing import Any, Dict, List

import numpy as np
import torch

# Make the model directory importable so the custom model modules resolve
# when this handler is loaded from inside the model repo.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

try:
    from modeling_emotion_av import EmotionAVModel, EmotionAVConfig
    from feature_extraction_emotion_av import EmotionAVFeatureExtractor
except ImportError as e:
    print(f"Warning: Could not import custom modules: {e}")
    # Sentinels so __init__ can detect the missing custom classes and fall
    # back to the transformers Auto* loaders below.
    EmotionAVModel = None
    EmotionAVFeatureExtractor = None

# Import the Auto* fallbacks eagerly: __init__ needs them even when the
# custom modules imported fine (its per-load fallback path referenced
# AutoModel, which previously was only bound inside the except branch and
# raised NameError when the custom import had succeeded).
try:
    from transformers import AutoModel, AutoConfig, AutoFeatureExtractor
except ImportError:
    AutoModel = AutoConfig = AutoFeatureExtractor = None


class ExtendedEndpointHandler:
    """
    Extended handler that provides both HF-compatible output and detailed
    emotion analysis.

    By default the handler returns the standard HuggingFace audio-
    classification format (a list of ``{label, score}`` dicts). When the
    request parameters set ``extended_output`` or ``include_arousal_valence``,
    a single extended result with arousal/valence and the full emotion
    distribution is returned instead.
    """

    def __init__(self, model_dir: str = ""):
        """
        Initialize the handler for the emotion-av model.

        Args:
            model_dir (str): Path to the model directory containing
                ``config.json`` and the model weights.

        Raises:
            FileNotFoundError: If ``config.json`` is missing.
            ValueError: If ``config.json`` is empty.
            Exception: Any error raised while loading model or feature
                extractor (re-raised after logging).
        """
        try:
            print(f"Initializing extended handler with model_dir: {model_dir}")

            # Validate config file exists and is readable before attempting
            # the (much slower) model load, so failures are diagnosable.
            config_path = os.path.join(model_dir, "config.json")
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"Config file not found: {config_path}")

            with open(config_path, 'r', encoding='utf-8') as f:
                config_content = f.read().strip()

            if not config_content:
                raise ValueError("Config file is empty")

            # Validate JSON
            config_data = json.loads(config_content)
            print(f"Successfully loaded config with keys: {list(config_data.keys())}")

            # Load the custom model, falling back to AutoModel on any failure
            # (including the custom class being unavailable, in which case
            # EmotionAVModel is None and the attribute access raises).
            try:
                self.model = EmotionAVModel.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded EmotionAVModel")
            except Exception as e:
                print(f"Failed to load with EmotionAVModel: {e}")
                # Fallback to AutoModel
                self.model = AutoModel.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded with AutoModel")

            # Same pattern for the feature extractor.
            try:
                self.feature_extractor = EmotionAVFeatureExtractor.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded EmotionAVFeatureExtractor")
            except Exception as e:
                print(f"Failed to load with EmotionAVFeatureExtractor: {e}")
                # Fallback to AutoFeatureExtractor
                self.feature_extractor = AutoFeatureExtractor.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded with AutoFeatureExtractor")

            self.model.eval()
            print("Extended handler initialization completed successfully")

        except Exception as e:
            print(f"Error during extended handler initialization: {e}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle inference requests with both standard and extended formats.

        Args:
            data (Dict): Input data containing an ``inputs`` key with audio
                data (base64 string or raw audio array) and an optional
                ``parameters`` dict (``extended_output``,
                ``include_arousal_valence``, ``sampling_rate``).

        Returns:
            List[Dict]: Prediction results. HF-compatible
            ``[{label, score}, ...]`` by default; a single extended result
            if requested. Errors are returned as ``[{"error": ...}]``
            rather than raised.
        """
        try:
            # Allow both {"inputs": ..., "parameters": ...} and bare payloads.
            inputs = data.get("inputs", data)
            parameters = data.get("parameters", {})

            # Check if extended output is requested
            extended_output = parameters.get("extended_output", False)
            include_arousal_valence = parameters.get("include_arousal_valence", False)

            # Handle different input formats
            if isinstance(inputs, str):
                # Base64 encoded audio
                try:
                    audio_bytes = base64.b64decode(inputs)
                    audio_data = self._process_audio_bytes(audio_bytes)
                except Exception as e:
                    return [{"error": f"Failed to decode base64 audio: {str(e)}"}]
            elif isinstance(inputs, (list, np.ndarray)):
                # Raw audio array
                audio_data = np.array(inputs, dtype=np.float32)
            else:
                return [{"error": "Invalid input format. Expected base64 string or audio array."}]

            # Extract features
            features = self.feature_extractor(
                audio_data,
                sampling_rate=parameters.get("sampling_rate", 16000),
                return_tensors="pt"
            )

            # Run inference
            with torch.no_grad():
                outputs = self.model(features["input_features"])

            # Process outputs
            emotion_logits = outputs.emotion_logits
            arousal_valence = outputs.arousal_valence

            # Get emotion probabilities
            emotion_probs = torch.softmax(emotion_logits, dim=-1)

            # Denormalize arousal-valence from [0,1] to [-1,1]
            arousal = (arousal_valence[0, 0].item() * 2) - 1
            valence = (arousal_valence[0, 1].item() * 2) - 1

            if extended_output or include_arousal_valence:
                # Return extended format with arousal/valence information
                return self._format_extended_output(emotion_probs[0], arousal, valence)
            else:
                # Return HF-compatible format: Array<{label: string, score: number}>
                return self._format_standard_output(emotion_probs[0])

        except Exception as e:
            return [{"error": f"Inference failed: {str(e)}"}]

    def _format_standard_output(self, emotion_probs: torch.Tensor) -> List[Dict[str, Any]]:
        """
        Format output in HuggingFace-compatible format.

        Args:
            emotion_probs: 1-D emotion probabilities tensor.

        Returns:
            List of ``{label, score}`` dicts sorted by descending score.
        """
        results = []
        probs_sorted, indices = torch.sort(emotion_probs, descending=True)

        for i in range(len(indices)):
            idx = indices[i].item()
            label = self.model.config.id2label[idx]
            score = probs_sorted[i].item()
            results.append({
                "label": label,
                "score": score
            })

        return results

    def _format_extended_output(self, emotion_probs: torch.Tensor,
                                arousal: float, valence: float) -> List[Dict[str, Any]]:
        """
        Format output with extended emotion information including
        arousal/valence.

        Args:
            emotion_probs: 1-D emotion probabilities tensor.
            arousal: Arousal value in [-1, 1].
            valence: Valence value in [-1, 1].

        Returns:
            Single-element list with the primary emotion plus ``arousal``,
            ``valence``, ``all_emotions`` (sorted) and
            ``emotion_distribution`` (by label).
        """
        # Get top emotion
        predicted_id = torch.argmax(emotion_probs).item()
        confidence = emotion_probs.max().item()
        emotion_label = self.model.config.id2label[predicted_id]

        # Create all emotions list, sorted by descending probability.
        all_emotions = []
        probs_sorted, indices = torch.sort(emotion_probs, descending=True)
        for i in range(len(indices)):
            idx = indices[i].item()
            label = self.model.config.id2label[idx]
            score = probs_sorted[i].item()
            all_emotions.append({"label": label, "score": score})

        # Return primary result with extended information
        result = {
            "label": emotion_label,
            "score": confidence,
            "arousal": arousal,
            "valence": valence,
            "all_emotions": all_emotions,
            "emotion_distribution": {
                self.model.config.id2label[j]: prob.item()
                for j, prob in enumerate(emotion_probs)
            }
        }

        return [result]

    def _process_audio_bytes(self, audio_bytes: bytes) -> np.ndarray:
        """
        Decode raw audio bytes into a mono float32 numpy array.

        Tries soundfile first, then librosa as a fallback.

        Args:
            audio_bytes (bytes): Raw encoded audio bytes (e.g. WAV/FLAC).

        Returns:
            np.ndarray: Mono float32 audio samples.

        Raises:
            Exception: If both decoders fail (chained from the librosa error).
        """
        try:
            import soundfile as sf

            audio_io = io.BytesIO(audio_bytes)
            # NOTE(review): the native sample rate is discarded here — the
            # soundfile path does not resample to 16 kHz, unlike the librosa
            # fallback below. Confirm upstream audio is already 16 kHz.
            audio_data, _sample_rate = sf.read(audio_io)

            # Downmix to mono and convert to float32.
            if len(audio_data.shape) > 1:
                audio_data = np.mean(audio_data, axis=1)
            audio_data = audio_data.astype(np.float32)

            return audio_data

        except Exception:
            # If soundfile fails, try librosa (decodes more formats and
            # resamples to 16 kHz mono).
            try:
                import librosa
                audio_io = io.BytesIO(audio_bytes)
                audio_data, _sample_rate = librosa.load(audio_io, sr=16000, mono=True)
                return audio_data.astype(np.float32)
            except Exception as e2:
                # Chain the cause so the original decode error is preserved.
                raise Exception(f"Failed to process audio: {str(e2)}") from e2


# For compatibility, provide the standard handler as well
class EndpointHandler(ExtendedEndpointHandler):
    """
    Standard handler that ensures HF compatibility by default.
    """

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle inference requests with HF-compatible output by default.

        Forces the standard ``[{label, score}, ...]`` output format by
        disabling the extended-output flags, without mutating the caller's
        request dict (the previous version wrote into ``data`` in place).
        """
        # Shallow-copy request and parameters so the caller's dicts are
        # left untouched.
        request = dict(data)
        parameters = dict(request.get("parameters", {}))
        parameters["extended_output"] = False
        parameters["include_arousal_valence"] = False
        request["parameters"] = parameters
        return super().__call__(request)