from transformers import AutoModel, AutoFeatureExtractor
import torch
import torchaudio
from io import BytesIO
import base64

SAMPLING_RATE = 24000


class EndpointHandler:
    """Inference Endpoints handler that returns a mean-pooled audio embedding.

    The request payload carries audio as raw bytes or a base64 string; it is
    decoded, resampled to 24 kHz mono, passed through the model, and reduced
    to a single embedding vector.
    """

    def __init__(self, model_dir):
        """Load the feature extractor and model from *model_dir*; the model
        is placed on GPU when one is available and switched to eval mode."""
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            model_dir, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _payload_to_bytes(payload):
        """Normalize the request payload to raw audio bytes.

        Strings are treated as base64-encoded audio; bytes-like objects pass
        through unchanged. Anything else is rejected.

        Raises:
            ValueError: if *payload* is neither ``str`` nor bytes-like.
        """
        if isinstance(payload, str):
            return base64.b64decode(payload)
        if isinstance(payload, (bytes, bytearray)):
            return payload
        raise ValueError(f"Unexpected inputs type: {type(payload)}")

    @staticmethod
    def _load_mono_24k(raw):
        """Decode *raw* audio bytes into a mono numpy waveform at 24 kHz."""
        audio, source_rate = torchaudio.load(BytesIO(raw))
        if source_rate != SAMPLING_RATE:
            audio = torchaudio.functional.resample(
                audio, orig_freq=source_rate, new_freq=SAMPLING_RATE
            )
        # Averaging the channel axis collapses stereo (or any layout) to mono.
        return audio.mean(dim=0).numpy()

    def __call__(self, data):
        """Handle one request.

        Expects ``{"inputs": <raw bytes | base64 str>}`` and returns
        ``{"embedding": [float, ...]}`` — the model's last hidden state
        mean-pooled over time for the single input clip.
        """
        raw = self._payload_to_bytes(data.get("inputs"))
        waveform = self._load_mono_24k(raw)

        features = self.feature_extractor(
            waveform,
            return_tensors="pt",
            sampling_rate=SAMPLING_RATE,
        ).to(self.device)

        with torch.no_grad():
            model_out = self.model(**features)

        # Mean over the time dimension yields one vector per clip; the batch
        # holds exactly one clip, so row 0 is the answer.
        pooled = model_out.last_hidden_state.mean(dim=1).cpu().numpy().tolist()
        return {"embedding": pooled[0]}