import base64
from io import BytesIO

import torch
import torchaudio
from transformers import AutoFeatureExtractor, AutoModel

# Target sample rate expected by the feature extractor; all incoming audio
# is resampled to this rate before feature extraction.
SAMPLING_RATE = 24000
|
class EndpointHandler:
    """Inference handler that turns raw audio into a pooled embedding.

    Loads a feature extractor and model from ``model_dir`` (both with
    ``trust_remote_code=True``, so the model repo's custom code is executed),
    moves the model to GPU when available, and serves requests via
    ``__call__``.
    """

    def __init__(self, model_dir):
        """Load model and feature extractor from ``model_dir`` and set eval mode."""
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            model_dir, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # eval() disables dropout/batch-norm updates for deterministic inference.
        self.model.eval()

    def __call__(self, data):
        """Compute an audio embedding for one request.

        Parameters
        ----------
        data : dict
            Request payload. ``data["inputs"]`` must be either a base64-encoded
            audio string or raw audio bytes (any container torchaudio can read).

        Returns
        -------
        dict
            ``{"embedding": [...]}`` — the mean-pooled last hidden state of the
            model as a flat list of floats.

        Raises
        ------
        ValueError
            If ``inputs`` is missing or of an unsupported type.
        """
        inputs = data.get("inputs")
        if inputs is None:
            # Fail early with a clear message instead of a confusing
            # type error further down.
            raise ValueError('Request payload must contain an "inputs" field')

        if isinstance(inputs, str):
            # JSON payloads carry audio as base64 text.
            audio_bytes = base64.b64decode(inputs)
        elif isinstance(inputs, (bytes, bytearray)):
            audio_bytes = inputs
        else:
            raise ValueError(f"Unexpected inputs type: {type(inputs)}")

        # Decode the audio container; resample if the file's native rate
        # differs from what the feature extractor expects.
        waveform, sr = torchaudio.load(BytesIO(audio_bytes))
        if sr != SAMPLING_RATE:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=sr, new_freq=SAMPLING_RATE
            )

        # Downmix multi-channel audio to mono by averaging channels
        # (dim 0 is the channel axis in torchaudio's (channels, frames) layout).
        waveform = waveform.mean(dim=0).numpy()

        processed = self.feature_extractor(
            waveform,
            return_tensors="pt",
            sampling_rate=SAMPLING_RATE,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**processed)

        # Mean-pool over the time dimension to get one vector per clip;
        # batch size is 1, so return the single embedding.
        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().tolist()
        return {"embedding": embedding[0]}