from transformers import AutoModel, AutoFeatureExtractor
import torch
import torchaudio
from io import BytesIO
import base64
SAMPLING_RATE = 24000
class EndpointHandler:
    """Inference endpoint handler that turns an audio clip into a fixed-size embedding.

    Loads a Hugging Face model and feature extractor from ``model_dir`` once at
    startup; each call decodes the incoming audio payload, converts it to 24 kHz
    mono, runs it through the model, and returns the mean-pooled last hidden
    state as a plain Python list.
    """

    def __init__(self, model_dir: str) -> None:
        """Load the model/feature extractor and move the model to GPU if available.

        Args:
            model_dir: Local path (or hub id) of the pretrained model.
        """
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            model_dir, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _decode_audio(inputs) -> bytes:
        """Normalize the request payload to raw audio bytes.

        Accepts raw ``bytes``/``bytearray`` or a base64-encoded string.

        Raises:
            ValueError: if ``inputs`` is missing or of an unsupported type.
        """
        # Explicit check: previously a missing "inputs" key produced the
        # confusing message "Unexpected inputs type: <class 'NoneType'>".
        if inputs is None:
            raise ValueError('Request is missing the "inputs" field')
        if isinstance(inputs, str):
            return base64.b64decode(inputs)
        if isinstance(inputs, (bytes, bytearray)):
            return bytes(inputs)
        raise ValueError(f"Unexpected inputs type: {type(inputs)}")

    def __call__(self, data: dict) -> dict:
        """Compute an embedding for the audio payload in ``data["inputs"]``.

        Returns:
            ``{"embedding": [...]}`` — the model's last hidden state mean-pooled
            over the time axis, as a flat list of floats.
        """
        audio_bytes = self._decode_audio(data.get("inputs"))

        # Load audio; down-mix to mono *before* resampling so the (linear)
        # resampler processes one channel instead of several — same result,
        # roughly half the work for stereo input.
        waveform, sr = torchaudio.load(BytesIO(audio_bytes))
        waveform = waveform.mean(dim=0, keepdim=True)
        if sr != SAMPLING_RATE:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=sr, new_freq=SAMPLING_RATE
            )

        # Feature extractor expects a 1-D numpy array at the target rate.
        processed = self.feature_extractor(
            waveform.squeeze(0).numpy(),
            return_tensors="pt",
            sampling_rate=SAMPLING_RATE,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**processed)

        # Mean-pool over time: [1, seq, hidden] -> [1, hidden] -> flat list.
        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().tolist()
        return {"embedding": embedding[0]}