import io

import torch
import torchaudio

from models import StudentForAudioClassification


class EndpointHandler:
    """Inference-endpoint handler: raw encoded audio bytes -> class probabilities.

    Pipeline: decode audio, down-mix to mono, resample to 16 kHz, extract
    wav2vec2-BASE features, mean-pool over time, truncate to the feature
    width the student model consumes, classify.
    """

    TARGET_SR = 16000   # sample rate expected by the WAV2VEC2_BASE pipeline
    FEATURE_DIM = 512   # feature width fed to the student classifier

    def __init__(self, model_dir, *args, **kwargs):
        """Load the student classifier and the wav2vec2 feature extractor.

        Args:
            model_dir: path to the pretrained student-model checkpoint.
        """
        # NOTE(review): trust_remote_code executes code shipped with the
        # checkpoint -- only point model_dir at a trusted source.
        self.model = StudentForAudioClassification.from_pretrained(
            model_dir, trust_remote_code=True
        )
        self.model.eval()
        bundle = torchaudio.pipelines.WAV2VEC2_BASE
        self.w2v_model = bundle.get_model()
        self.w2v_model.eval()

    def __call__(self, data):
        """Classify a single audio clip.

        Args:
            data: dict with key "inputs" holding the raw encoded audio bytes.

        Returns:
            dict with "probabilities" (list[float], softmax over classes)
            and "label" (int index of the argmax class).
        """
        waveform, orig_sr = torchaudio.load(io.BytesIO(data["inputs"]))
        # Down-mix multi-channel audio to mono, keeping the channel dim.
        waveform = waveform.mean(dim=0, keepdim=True)
        if orig_sr != self.TARGET_SR:
            resampler = torchaudio.transforms.Resample(orig_sr, self.TARGET_SR)
            waveform = resampler(waveform)
        # FIX: the original only wrapped the wav2vec2 feature extraction in
        # no_grad; the student-model forward ran with autograd enabled,
        # needlessly building a graph at inference time. Guard both.
        with torch.no_grad():
            features, _ = self.w2v_model(waveform)  # (1, frames, 768) for BASE
            x_w2v = features.mean(dim=1)            # temporal mean-pooling
            # Assumes the student was trained on the first FEATURE_DIM of
            # the 768 wav2vec2 dims -- TODO confirm against training code.
            x_w2v = x_w2v[:, : self.FEATURE_DIM]
            outputs = self.model(x_w2v)
        probs = torch.softmax(outputs.logits, dim=-1)
        return {
            "probabilities": probs.squeeze(0).tolist(),
            "label": int(probs.argmax(dim=-1)[0]),
        }