# handler.py — custom inference handler for the emotion-av model.
# (Fix handler for HF Inference API compatibility, commit b575114)
import json
import base64
import io
import torch
import numpy as np
from typing import Dict, List, Any
import os
import sys
# Add the current directory to the Python path so the model's bundled custom
# modules (shipped alongside the weights) are importable with plain imports.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Import the generic transformers classes unconditionally: EndpointHandler
# uses them as a runtime fallback even when the custom classes import fine
# (previously they were only imported in the except branch, so the fallback
# path raised NameError when the custom import had succeeded).
from transformers import AutoModel, AutoConfig, AutoFeatureExtractor

# Prefer the repo's custom model/feature-extractor classes when available.
try:
    from modeling_emotion_av import EmotionAVModel, EmotionAVConfig
    from feature_extraction_emotion_av import EmotionAVFeatureExtractor
except ImportError as e:
    # Fallback: the handler will load via the Auto* classes imported above.
    print(f"Warning: Could not import custom modules: {e}")
class EndpointHandler:
    """Inference endpoint handler for the emotion-av audio model.

    Loads the custom EmotionAV model and feature extractor (falling back to
    the generic transformers ``Auto*`` classes on failure) and serves requests
    in the HF Inference API format: a list of ``{"label": str, "score": float}``
    dicts sorted by confidence.
    """

    def __init__(self, model_dir: str = ""):
        """
        Initialize the handler for the emotion-av model.

        Args:
            model_dir (str): Path to the model directory.

        Raises:
            FileNotFoundError: If ``config.json`` is missing from model_dir.
            ValueError: If ``config.json`` is empty.
            json.JSONDecodeError: If ``config.json`` is not valid JSON.
        """
        try:
            print(f"Initializing handler with model_dir: {model_dir}")

            # Validate the config file up front so a broken repo fails fast
            # with a clear message instead of deep inside from_pretrained.
            config_path = os.path.join(model_dir, "config.json")
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"Config file not found: {config_path}")

            with open(config_path, 'r', encoding='utf-8') as f:
                config_content = f.read().strip()
            if not config_content:
                raise ValueError("Config file is empty")
            config_data = json.loads(config_content)
            print(f"Successfully loaded config with keys: {list(config_data.keys())}")

            # Load the custom model; fall back to AutoModel if the custom
            # class fails (e.g. its module could not be imported).
            try:
                self.model = EmotionAVModel.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded EmotionAVModel")
            except Exception as e:
                print(f"Failed to load with EmotionAVModel: {e}")
                self.model = AutoModel.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded with AutoModel")

            # Same two-step strategy for the feature extractor.
            try:
                self.feature_extractor = EmotionAVFeatureExtractor.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded EmotionAVFeatureExtractor")
            except Exception as e:
                print(f"Failed to load with EmotionAVFeatureExtractor: {e}")
                self.feature_extractor = AutoFeatureExtractor.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True
                )
                print("Successfully loaded with AutoFeatureExtractor")

            # Inference-only: disable dropout/batch-norm training behavior.
            self.model.eval()
            print("Handler initialization completed successfully")
        except Exception as e:
            print(f"Error during handler initialization: {e}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle inference requests.

        Args:
            data (Dict): Request payload — either the standard envelope
                ``{"inputs": ..., "parameters": {...}}`` or the bare inputs
                themselves. ``inputs`` may be a base64-encoded audio file
                (str) or a raw audio sample array (list / np.ndarray).

        Returns:
            List[Dict]: All emotions sorted by confidence, each as
                ``{"label": str, "score": float}`` (HF-compatible), or a
                single ``{"error": str}`` entry on failure.
        """
        try:
            # Accept both the standard {"inputs": ...} envelope and a bare
            # payload (a raw list/array would otherwise hit AttributeError
            # on .get and surface as a confusing generic failure).
            if isinstance(data, dict):
                inputs = data.get("inputs", data)
                parameters = data.get("parameters", {})
            else:
                inputs, parameters = data, {}

            # Normalize the input into a float32 audio array.
            if isinstance(inputs, str):
                # Base64 encoded audio file
                try:
                    audio_bytes = base64.b64decode(inputs)
                    audio_data = self._process_audio_bytes(audio_bytes)
                except Exception as e:
                    return [{"error": f"Failed to decode base64 audio: {str(e)}"}]
            elif isinstance(inputs, (list, np.ndarray)):
                # Raw audio array
                audio_data = np.asarray(inputs, dtype=np.float32)
            else:
                return [{"error": "Invalid input format. Expected base64 string or audio array."}]

            # Extract features (default sampling rate 16 kHz unless the
            # caller overrides it via parameters).
            features = self.feature_extractor(
                audio_data,
                sampling_rate=parameters.get("sampling_rate", 16000),
                return_tensors="pt"
            )

            # Run inference without building autograd state.
            with torch.no_grad():
                outputs = self.model(features["input_features"])

            # Convert logits to probabilities and sort all emotions by
            # confidence, descending.
            emotion_probs = torch.softmax(outputs.emotion_logits, dim=-1)
            probs_sorted, indices = torch.sort(emotion_probs[0], descending=True)

            # Strictly follow the HF format: only label and score.
            # NOTE(review): earlier versions also read outputs.arousal_valence
            # here, denormalized it, and then discarded it — dead work that
            # could also raise on models without that head, so it was removed.
            results = []
            for prob, idx in zip(probs_sorted, indices):
                results.append({
                    "label": self.model.config.id2label[idx.item()],
                    "score": prob.item()
                })
            return results
        except Exception as e:
            return [{"error": f"Inference failed: {str(e)}"}]

    def _process_audio_bytes(self, audio_bytes: bytes) -> np.ndarray:
        """
        Decode raw audio file bytes into a mono float32 numpy array.

        Tries soundfile first; falls back to librosa (which also resamples
        to 16 kHz) if soundfile cannot decode the bytes.

        Args:
            audio_bytes (bytes): Raw audio file bytes (e.g. WAV).

        Returns:
            np.ndarray: 1-D float32 audio samples.

        Raises:
            Exception: If both decoders fail. The message reports both
                failures, and the librosa error is chained as the cause.
        """
        try:
            import soundfile as sf
            audio_data, _sample_rate = sf.read(io.BytesIO(audio_bytes))
            # Down-mix multi-channel audio to mono by averaging channels.
            if len(audio_data.shape) > 1:
                audio_data = np.mean(audio_data, axis=1)
            return audio_data.astype(np.float32)
        except Exception as sf_error:
            # soundfile failed — try librosa, and if that also fails report
            # BOTH errors (the original code silently dropped the first one
            # and raised an unchained Exception).
            try:
                import librosa
                audio_data, _sample_rate = librosa.load(
                    io.BytesIO(audio_bytes), sr=16000, mono=True
                )
                return audio_data.astype(np.float32)
            except Exception as librosa_error:
                raise Exception(
                    f"Failed to process audio: soundfile error: {sf_error}; "
                    f"librosa error: {librosa_error}"
                ) from librosa_error