import json
import base64
import io
import os
import sys
from typing import Any, Dict, List

import numpy as np
import torch

# Ensure the model directory is on sys.path so the custom modules resolve.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

try:
    from modeling_emotion_av import EmotionAVModel, EmotionAVConfig
    from feature_extraction_emotion_av import EmotionAVFeatureExtractor
except ImportError as e:
    print(f"Warning: Could not import custom modules: {e}")

from transformers import AutoModel, AutoConfig, AutoFeatureExtractor

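# The handler expects request payloads shaped roughly like the sketch below.
# Illustrative only, not an official schema: "inputs" may be a base64-encoded
# audio file (e.g. base64.b64encode(open("clip.wav", "rb").read()).decode())
# or a raw float array; "parameters" is optional and sampling_rate defaults
# to 16000.
#
#   {"inputs": "<base64 WAV bytes>", "parameters": {"sampling_rate": 16000}}
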
class EndpointHandler:
    def __init__(self, model_dir: str = ""):
        """
        Initialize the handler for the emotion-av model.

        Args:
            model_dir (str): Path to the model directory.
        """
        try:
            print(f"Initializing handler with model_dir: {model_dir}")

            # Sanity-check the config file before loading any weights.
            config_path = os.path.join(model_dir, "config.json")
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"Config file not found: {config_path}")

            with open(config_path, "r", encoding="utf-8") as f:
                config_content = f.read().strip()
            if not config_content:
                raise ValueError("Config file is empty")

            config_data = json.loads(config_content)
            print(f"Successfully loaded config with keys: {list(config_data.keys())}")

            # Prefer the custom model class; fall back to AutoModel with
            # trust_remote_code if the direct import or load failed.
            try:
                self.model = EmotionAVModel.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True,
                )
                print("Successfully loaded EmotionAVModel")
            except Exception as e:
                print(f"Failed to load with EmotionAVModel: {e}")
                self.model = AutoModel.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True,
                )
                print("Successfully loaded with AutoModel")

            # Same pattern for the feature extractor.
            try:
                self.feature_extractor = EmotionAVFeatureExtractor.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True,
                )
                print("Successfully loaded EmotionAVFeatureExtractor")
            except Exception as e:
                print(f"Failed to load with EmotionAVFeatureExtractor: {e}")
                self.feature_extractor = AutoFeatureExtractor.from_pretrained(
                    model_dir,
                    trust_remote_code=True,
                    local_files_only=True,
                )
                print("Successfully loaded with AutoFeatureExtractor")

            self.model.eval()
            print("Handler initialization completed successfully")

        except Exception as e:
            print(f"Error during handler initialization: {e}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle inference requests.

        Args:
            data (Dict): Input data containing an 'inputs' key with audio data
                (a base64-encoded audio file or a raw float array) and an
                optional 'parameters' dict.

        Returns:
            List[Dict]: Prediction results in the HF audio-classification
                format: a list of {"label": ..., "score": ...} dicts sorted
                by descending score.
        """
        try:
            inputs = data.get("inputs", data)
            parameters = data.get("parameters", {})

            # Accept either a base64-encoded audio file or a raw float array.
            if isinstance(inputs, str):
                try:
                    audio_bytes = base64.b64decode(inputs)
                    audio_data = self._process_audio_bytes(audio_bytes)
                except Exception as e:
                    return [{"error": f"Failed to decode base64 audio: {str(e)}"}]
            elif isinstance(inputs, (list, np.ndarray)):
                audio_data = np.array(inputs, dtype=np.float32)
            else:
                return [{"error": "Invalid input format. Expected base64 string or audio array."}]

            features = self.feature_extractor(
                audio_data,
                sampling_rate=parameters.get("sampling_rate", 16000),
                return_tensors="pt",
            )

            with torch.no_grad():
                outputs = self.model(features["input_features"])

            emotion_logits = outputs.emotion_logits
            arousal_valence = outputs.arousal_valence

            emotion_probs = torch.softmax(emotion_logits, dim=-1)

            # The model emits arousal/valence in [0, 1]; rescale to [-1, 1].
            # These values are computed here but are not included in the HF
            # audio-classification response built below.
            arousal = (arousal_valence[0, 0].item() * 2) - 1
            valence = (arousal_valence[0, 1].item() * 2) - 1

            # Build the response: every label, sorted by descending probability.
            results = []
            probs_sorted, indices = torch.sort(emotion_probs[0], descending=True)
            for i in range(len(indices)):
                idx = indices[i].item()
                label = self.model.config.id2label[idx]
                score = probs_sorted[i].item()
                results.append({
                    "label": label,
                    "score": score,
                })

            return results

        except Exception as e:
            return [{"error": f"Inference failed: {str(e)}"}]

    def _process_audio_bytes(self, audio_bytes: bytes) -> np.ndarray:
        """
        Process raw audio bytes and convert them to a mono float32 array.

        Args:
            audio_bytes (bytes): Raw audio file bytes (e.g. WAV/FLAC).

        Returns:
            np.ndarray: Processed audio array.
        """
        try:
            import soundfile as sf

            audio_io = io.BytesIO(audio_bytes)
            audio_data, sample_rate = sf.read(audio_io)

            # Downmix multi-channel audio to mono.
            if len(audio_data.shape) > 1:
                audio_data = np.mean(audio_data, axis=1)

            return audio_data.astype(np.float32)

        except Exception as e:
            # Fall back to librosa, which can decode additional formats.
            try:
                import librosa

                audio_io = io.BytesIO(audio_bytes)
                audio_data, sample_rate = librosa.load(audio_io, sr=16000, mono=True)
                return audio_data.astype(np.float32)
            except Exception as e2:
                raise Exception(
                    f"Failed to process audio: {str(e2)} (soundfile error: {str(e)})"
                )
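

# Minimal local smoke test: a sketch of how the handler might be exercised
# outside the endpoint. The one-second 440 Hz sine wave and the "." model
# directory are assumptions for illustration, not part of the deployed setup.
if __name__ == "__main__":
    # Build a synthetic one-second mono signal at 16 kHz.
    t = np.linspace(0, 1, 16000, dtype=np.float32)
    waveform = 0.1 * np.sin(2 * np.pi * 440 * t)

    # Assumes the model files sit alongside this script.
    handler = EndpointHandler(model_dir=".")
    predictions = handler(
        {
            "inputs": waveform.tolist(),
            "parameters": {"sampling_rate": 16000},
        }
    )
    print(predictions)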