# handler_extended.py — pph-emotion-classification-model
# Author: vishrutjha — "Fix handler for HF Inference API compatibility"
# (HF Hub commit b575114, verified)
import json
import base64
import io
import torch
import numpy as np
from typing import Dict, List, Any
import os
import sys
# Add the current directory to Python path so the custom modules bundled next
# to this handler (modeling_emotion_av.py, feature_extraction_emotion_av.py)
# can be imported by name.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Import the transformers Auto* fallbacks unconditionally: __init__ falls back
# to them when the custom classes import successfully but fail to *load* —
# previously they were only imported when the custom import itself failed, so
# that fallback path could die with a NameError instead of loading the model.
from transformers import AutoModel, AutoConfig, AutoFeatureExtractor

try:
    from modeling_emotion_av import EmotionAVModel, EmotionAVConfig
    from feature_extraction_emotion_av import EmotionAVFeatureExtractor
except ImportError as e:
    print(f"Warning: Could not import custom modules: {e}")
    # Sentinels: calling None.from_pretrained(...) raises inside __init__'s
    # try/except, which then falls through to the Auto* loading path.
    EmotionAVModel = None
    EmotionAVConfig = None
    EmotionAVFeatureExtractor = None
class ExtendedEndpointHandler:
    """
    Extended handler that provides both HF-compatible output and detailed emotion analysis.

    By default ``__call__`` returns the standard HuggingFace audio-classification
    format (a list of ``{label, score}`` dicts, best first). When the request
    sets ``parameters.extended_output`` or ``parameters.include_arousal_valence``,
    the response additionally carries arousal/valence estimates and the full
    emotion distribution.
    """

    def __init__(self, model_dir: str = ""):
        """
        Initialize the handler for the emotion-av model.

        Loads the custom ``EmotionAVModel``/``EmotionAVFeatureExtractor`` from
        ``model_dir``, falling back to the generic transformers ``Auto*``
        loaders when the custom classes are unavailable or fail to load.

        Args:
            model_dir (str): Path to the model directory.

        Raises:
            FileNotFoundError: If ``config.json`` is missing from ``model_dir``.
            ValueError: If ``config.json`` is empty.
            Exception: Any loader error when both the custom and the fallback
                loading paths fail (logged, then re-raised).
        """
        try:
            print(f"Initializing extended handler with model_dir: {model_dir}")
            self._validate_config(model_dir)
            self.model = self._load_model(model_dir)
            self.feature_extractor = self._load_feature_extractor(model_dir)
            # Inference-only handler: freeze dropout / batch-norm behavior.
            self.model.eval()
            print("Extended handler initialization completed successfully")
        except Exception as e:
            print(f"Error during extended handler initialization: {e}")
            raise

    def _validate_config(self, model_dir: str) -> None:
        """Ensure ``config.json`` exists, is non-empty, and parses as JSON."""
        config_path = os.path.join(model_dir, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found: {config_path}")
        with open(config_path, 'r', encoding='utf-8') as f:
            config_content = f.read().strip()
        if not config_content:
            raise ValueError("Config file is empty")
        # json.loads both validates the file and lets us log its top-level keys.
        config_data = json.loads(config_content)
        print(f"Successfully loaded config with keys: {list(config_data.keys())}")

    def _load_model(self, model_dir: str):
        """Load the custom EmotionAVModel, falling back to AutoModel."""
        try:
            model = EmotionAVModel.from_pretrained(
                model_dir,
                trust_remote_code=True,
                local_files_only=True
            )
            print("Successfully loaded EmotionAVModel")
        except Exception as e:
            print(f"Failed to load with EmotionAVModel: {e}")
            model = AutoModel.from_pretrained(
                model_dir,
                trust_remote_code=True,
                local_files_only=True
            )
            print("Successfully loaded with AutoModel")
        return model

    def _load_feature_extractor(self, model_dir: str):
        """Load the custom EmotionAVFeatureExtractor, falling back to AutoFeatureExtractor."""
        try:
            extractor = EmotionAVFeatureExtractor.from_pretrained(
                model_dir,
                trust_remote_code=True,
                local_files_only=True
            )
            print("Successfully loaded EmotionAVFeatureExtractor")
        except Exception as e:
            print(f"Failed to load with EmotionAVFeatureExtractor: {e}")
            extractor = AutoFeatureExtractor.from_pretrained(
                model_dir,
                trust_remote_code=True,
                local_files_only=True
            )
            print("Successfully loaded with AutoFeatureExtractor")
        return extractor

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle an inference request.

        Args:
            data: Request payload. ``data["inputs"]`` is either a base64-encoded
                audio file or a raw audio sample array; ``data["parameters"]``
                may carry ``sampling_rate`` (default 16000) plus the
                ``extended_output`` / ``include_arousal_valence`` flags.

        Returns:
            HF-compatible ``[{label, score}, ...]`` by default, a single-element
            extended payload when requested, or ``[{"error": ...}]`` on failure
            (errors are reported in-band rather than raised).
        """
        try:
            inputs = data.get("inputs", data)
            parameters = data.get("parameters", {})
            extended_output = parameters.get("extended_output", False)
            include_arousal_valence = parameters.get("include_arousal_valence", False)

            if isinstance(inputs, str):
                # Base64-encoded audio file content.
                try:
                    audio_bytes = base64.b64decode(inputs)
                    audio_data = self._process_audio_bytes(audio_bytes)
                except Exception as e:
                    return [{"error": f"Failed to decode base64 audio: {str(e)}"}]
            elif isinstance(inputs, (list, np.ndarray)):
                # Raw audio sample array.
                audio_data = np.array(inputs, dtype=np.float32)
            else:
                return [{"error": "Invalid input format. Expected base64 string or audio array."}]

            features = self.feature_extractor(
                audio_data,
                sampling_rate=parameters.get("sampling_rate", 16000),
                return_tensors="pt"
            )

            with torch.no_grad():
                outputs = self.model(features["input_features"])

            emotion_logits = outputs.emotion_logits
            arousal_valence = outputs.arousal_valence
            emotion_probs = torch.softmax(emotion_logits, dim=-1)

            # Model emits arousal/valence in [0, 1]; denormalize to [-1, 1].
            arousal = (arousal_valence[0, 0].item() * 2) - 1
            valence = (arousal_valence[0, 1].item() * 2) - 1

            if extended_output or include_arousal_valence:
                return self._format_extended_output(emotion_probs[0], arousal, valence)
            # HF-compatible format: Array<{label: string, score: number}>
            return self._format_standard_output(emotion_probs[0])
        except Exception as e:
            return [{"error": f"Inference failed: {str(e)}"}]

    def _ranked_emotions(self, emotion_probs: torch.Tensor) -> List[Dict[str, Any]]:
        """Return every emotion as a ``{label, score}`` dict, highest score first."""
        probs_sorted, indices = torch.sort(emotion_probs, descending=True)
        return [
            {"label": self.model.config.id2label[idx.item()], "score": prob.item()}
            for prob, idx in zip(probs_sorted, indices)
        ]

    def _format_standard_output(self, emotion_probs: torch.Tensor) -> List[Dict[str, Any]]:
        """
        Format output in HuggingFace-compatible format.

        Args:
            emotion_probs: 1-D emotion probabilities tensor.

        Returns:
            List of ``{label, score}`` dictionaries, sorted by descending score.
        """
        return self._ranked_emotions(emotion_probs)

    def _format_extended_output(self, emotion_probs: torch.Tensor, arousal: float, valence: float) -> List[Dict[str, Any]]:
        """
        Format output with extended emotion information including arousal/valence.

        Args:
            emotion_probs: 1-D emotion probabilities tensor.
            arousal: Arousal value in [-1, 1].
            valence: Valence value in [-1, 1].

        Returns:
            Single-element list: the top emotion with its score, arousal/valence,
            the ranked ``all_emotions`` list, and a label->probability map.
        """
        predicted_id = torch.argmax(emotion_probs).item()
        confidence = emotion_probs.max().item()
        emotion_label = self.model.config.id2label[predicted_id]
        result = {
            "label": emotion_label,
            "score": confidence,
            "arousal": arousal,
            "valence": valence,
            "all_emotions": self._ranked_emotions(emotion_probs),
            "emotion_distribution": {
                self.model.config.id2label[j]: prob.item()
                for j, prob in enumerate(emotion_probs)
            }
        }
        return [result]

    def _process_audio_bytes(self, audio_bytes: bytes) -> np.ndarray:
        """
        Decode raw audio file bytes into a mono float32 waveform.

        Tries soundfile first, then librosa (which resamples to 16 kHz) as a
        fallback. NOTE(review): the soundfile path keeps the file's native
        sample rate unresampled — confirm callers send 16 kHz audio, or that
        the feature extractor resamples.

        Args:
            audio_bytes (bytes): Raw audio file bytes (e.g. WAV/FLAC content).

        Returns:
            np.ndarray: Mono float32 audio samples.

        Raises:
            Exception: If neither decoding backend can process the bytes.
        """
        try:
            import soundfile as sf
            audio_data, sample_rate = sf.read(io.BytesIO(audio_bytes))
            if len(audio_data.shape) > 1:
                # Average channels to produce mono audio.
                audio_data = np.mean(audio_data, axis=1)
            return audio_data.astype(np.float32)
        except Exception as sf_error:
            # soundfile failed; try librosa. Previously the soundfile error was
            # silently discarded — surface both and chain the cause.
            try:
                import librosa
                audio_data, sample_rate = librosa.load(
                    io.BytesIO(audio_bytes), sr=16000, mono=True
                )
                return audio_data.astype(np.float32)
            except Exception as e2:
                raise Exception(
                    f"Failed to process audio: {str(e2)} (soundfile error: {sf_error})"
                ) from e2
# For compatibility, provide the standard handler as well
class EndpointHandler(ExtendedEndpointHandler):
    """
    Standard handler that ensures HF compatibility by default.

    Always returns the plain ``[{label, score}, ...]`` format, overriding any
    extended-output flags present in the request parameters.
    """

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle inference requests with HF-compatible output by default.

        Applies the output-format overrides to shallow copies so the caller's
        ``data`` / ``parameters`` dicts are never mutated (the previous
        implementation wrote the overrides back into the request in place).
        """
        # Copy; `or {}` also tolerates an explicit "parameters": None.
        parameters = dict(data.get("parameters") or {})
        # Force standard output for HF compatibility
        parameters["extended_output"] = False
        parameters["include_arousal_valence"] = False
        payload = {**data, "parameters": parameters}
        return super().__call__(payload)