# MERT-v1-330M — HuggingFace Inference Endpoints handler (handler.py)
from transformers import AutoModel, AutoFeatureExtractor
import torch
import torchaudio
from io import BytesIO
import base64
SAMPLING_RATE = 24000
class EndpointHandler:
    """Inference Endpoints handler producing MERT audio embeddings.

    Loads the model and feature extractor once at startup; each call turns an
    audio payload (raw bytes or a base64-encoded string) into a single
    mean-pooled embedding vector.
    """

    def __init__(self, model_dir):
        """Load the MERT model and feature extractor from ``model_dir``.

        ``trust_remote_code=True`` is required because MERT ships custom
        model/extractor code on the Hub.
        """
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            model_dir, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data):
        """Compute one embedding for one audio clip.

        Args:
            data: dict whose ``"inputs"`` key holds either raw audio bytes
                (any format torchaudio can decode) or a base64-encoded string
                of such bytes.

        Returns:
            ``{"embedding": [float, ...]}`` — the model's last hidden state
            mean-pooled over the time dimension, as a plain Python list.

        Raises:
            ValueError: if ``"inputs"`` is missing or not str/bytes/bytearray.
        """
        inputs = data.get("inputs")
        # Accept raw bytes or a base64-encoded string payload.
        if isinstance(inputs, str):
            audio_bytes = base64.b64decode(inputs)
        elif isinstance(inputs, (bytes, bytearray)):
            # Normalize bytearray -> bytes so downstream handling is uniform.
            audio_bytes = bytes(inputs)
        else:
            raise ValueError(f"Unexpected inputs type: {type(inputs)}")

        # Decode, then resample to the rate the feature extractor is fed.
        waveform, sr = torchaudio.load(BytesIO(audio_bytes))
        if sr != SAMPLING_RATE:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=sr, new_freq=SAMPLING_RATE
            )

        # Downmix to mono by averaging channels, then hand a numpy array
        # to the feature extractor.
        waveform = waveform.mean(dim=0).numpy()
        processed = self.feature_extractor(
            waveform,
            return_tensors="pt",
            sampling_rate=SAMPLING_RATE,
        ).to(self.device)

        # No autograd bookkeeping needed at inference time.
        with torch.no_grad():
            outputs = self.model(**processed)

        # Mean-pool the last hidden state over time -> [1, hidden_dim],
        # then move to CPU and convert to a JSON-serializable list.
        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().tolist()
        return {"embedding": embedding[0]}