# ASR_AGENT_/adapters/whisper_transformers.py
# Transformers-based Whisper adapter (switched from openai-whisper for HF Spaces compatibility).
from __future__ import annotations

import time
from typing import Dict, Optional

import librosa
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

from core.interfaces import ASRModel
from core.schemas import ASRConfig, ASROutput, Segment

class TransformersWhisperAdapter(ASRModel):
    """Whisper ASR adapter backed by Hugging Face transformers."""

    def __init__(self, model_name: str = "openai/whisper-small", device: Optional[str] = None):
        self.model_name = model_name
        # Prefer GPU when available; fall back to CPU (the common case on HF Spaces).
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = WhisperProcessor.from_pretrained(model_name)
        self.model = WhisperForConditionalGeneration.from_pretrained(model_name).to(self.device)

    def model_info(self) -> Dict:
        return {"name": "transformers-whisper", "model_name": self.model_name, "device": self.device}
    def transcribe(self, utt_id: str, audio_uri: str, config: Optional[ASRConfig] = None) -> ASROutput:
        config = config or ASRConfig()

        # Whisper expects 16 kHz mono input; librosa resamples on load.
        y, _sr = librosa.load(audio_uri, sr=16000, mono=True)
        duration_s = len(y) / 16000.0

        inputs = self.processor(y, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(self.device)

        # Optionally pin the language/task tokens in the decoder prompt;
        # otherwise Whisper auto-detects the language.
        forced_decoder_ids = None
        if config.language:
            forced_decoder_ids = self.processor.get_decoder_prompt_ids(
                language=config.language, task=config.task
            )

        t0 = time.time()
        with torch.no_grad():  # inference only; skip gradient bookkeeping
            predicted_ids = self.model.generate(
                input_features,
                forced_decoder_ids=forced_decoder_ids,
                num_beams=config.beam_size,
            )
        latency_ms = (time.time() - t0) * 1000.0

        text = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
        # generate() does not return word/segment timestamps in this setup,
        # so report a single segment spanning the whole utterance.
        return ASROutput(
            utt_id=utt_id,
            hyp_text=text,
            segments=[Segment(start=0.0, end=duration_s, text=text)],
            language=config.language,
            duration_s=duration_s,
            latency_ms=latency_ms,
            confidence=None,
            extras={},
        )
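
# A minimal usage sketch of the adapter above. Assumptions (not from the source):
# ASRConfig accepts `language`, `task`, and `beam_size` as constructor keywords
# (those fields are read in transcribe()), and "sample.wav" is a hypothetical
# local audio file.
if __name__ == "__main__":
    adapter = TransformersWhisperAdapter()
    print(adapter.model_info())
    out = adapter.transcribe(
        utt_id="demo-0001",
        audio_uri="sample.wav",
        config=ASRConfig(language="en", task="transcribe", beam_size=5),
    )
    print(out.hyp_text, f"({out.latency_ms:.0f} ms)")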