"""Whisper-based speech-to-text wrapper with optional Silero VAD pre-filtering."""

import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Union

import numpy as np
import torch
import torchaudio

# --- Constants ---
TARGET_SR = 16000       # sample rate every waveform is resampled to before VAD / ASR
WINDOW_SIZE_SEC = 30    # Whisper's native window; used as the pipeline chunk length
VAD_THRESHOLD = 0.50    # threshold passed to Silero's get_speech_timestamps

logger = logging.getLogger(__name__)


class WhisperTranscriber:
    """Load a Whisper checkpoint once and transcribe audio files with it.

    The original author notes this pipeline is tuned for Egyptian Arabic real
    estate calls: decoding is forced to ``language="arabic"`` and an optional
    domain prompt file is fed to the decoder when present.
    """

    def __init__(self, model_path: str = "openai/whisper-large-v3", device: Optional[str] = None):
        """Load the Whisper model, the ASR pipeline, the domain prompt and Silero VAD.

        Args:
            model_path: Hugging Face model id or local checkpoint directory.
            device: Torch device string. When omitted, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.

        Raises:
            Exception: Whatever the Whisper backend raised while loading; the
                error is logged with its traceback and re-raised unchanged.
                A failure to load Silero VAD is only logged as a warning.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Loading model from %s on %s", model_path, self.device)

        try:
            # Imported lazily so importing this module does not pull in transformers.
            from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

            self.processor = AutoProcessor.from_pretrained(model_path)

            # Half precision only on CUDA devices; full precision everywhere else.
            dtype = torch.float16 if "cuda" in self.device else torch.float32

            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_path,
                dtype=dtype,
                low_cpu_mem_usage=True,
            ).to(self.device)

            # Larger batches are only used on GPU (8); on CPU keep it to 2 chunks at a time.
            batch_size = 8 if "cuda" in self.device else 2

            self.pipe = pipeline(
                "automatic-speech-recognition",
                model=self.model,
                tokenizer=self.processor.tokenizer,
                feature_extractor=self.processor.feature_extractor,
                chunk_length_s=WINDOW_SIZE_SEC,
                batch_size=batch_size,
                return_timestamps=True,
                dtype=dtype,
                device=self.device,
                generate_kwargs={"max_new_tokens": 128},
            )
        except Exception as e:
            # logger.exception keeps the ERROR level but also records the traceback.
            logger.exception("Failed to load Whisper backend: %s", e)
            raise

        # Load domain-specific initial prompt
        self.initial_prompt = self._load_prompt()

        # Load Silero VAD. This is best-effort: when it cannot be loaded both
        # attributes stay None and _apply_vad() passes audio through unchanged.
        self._vad_model: Optional[torch.nn.Module] = None
        self._get_speech_timestamps = None
        try:
            from silero_vad import load_silero_vad, get_speech_timestamps

            self._vad_model = load_silero_vad().to(self.device)
            self._get_speech_timestamps = get_speech_timestamps
            logger.info("Silero VAD loaded (threshold=%.2f).", VAD_THRESHOLD)
        except Exception as exc:
            logger.warning("Silero VAD could not be loaded: %s", exc)

    def _load_prompt(self) -> str:
        """Return the stripped domain prompt text, or ``""`` when the file is missing.

        The file is looked up at ``prompts/whisper_initial_prompt.txt`` under the
        directory three levels above this source file.
        """
        prompt_path = Path(__file__).parent.parent.parent / "prompts" / "whisper_initial_prompt.txt"
        if prompt_path.exists():
            prompt = prompt_path.read_text(encoding="utf-8").strip()
            logger.info("Domain prompt loaded (%d characters).", len(prompt))
            return prompt
        return ""

    def _load(self, audio_path: Union[str, Path]) -> torch.Tensor:
        """Read an audio file and return it as a mono float32 waveform at ``TARGET_SR``.

        Args:
            audio_path: Path of the audio file, read with ``soundfile``.

        Returns:
            A 1-D float tensor: channels are averaged and the signal is
            resampled when the file's sample rate differs from ``TARGET_SR``.
        """
        import soundfile as sf

        audio_data, sr = sf.read(audio_path)

        # Convert to torch tensor
        waveform = torch.from_numpy(audio_data).float()

        # Handle multi-channel (soundfile returns [samples, channels]): downmix to mono.
        if waveform.ndim > 1:
            waveform = torch.mean(waveform, dim=1)

        # Resample if necessary
        if sr != TARGET_SR:
            import torchaudio.transforms as T

            resampler = T.Resample(sr, TARGET_SR)
            # Add a leading channel dimension for the resampler, then drop it again.
            waveform = resampler(waveform.unsqueeze(0)).squeeze(0)

        return waveform

    def _apply_vad(self, audio: torch.Tensor) -> torch.Tensor:
        """Keep only the speech segments of ``audio`` according to Silero VAD.

        Args:
            audio: 1-D waveform sampled at ``TARGET_SR``.

        Returns:
            ``audio`` unchanged when VAD is not loaded; otherwise the detected
            speech segments concatenated in order, or an empty tensor (same
            dtype and device as ``audio``) when no speech was detected.
        """
        if self._vad_model is None or self._get_speech_timestamps is None:
            return audio

        speech_timestamps = self._get_speech_timestamps(
            audio,
            self._vad_model,
            sampling_rate=TARGET_SR,
            threshold=VAD_THRESHOLD,
        )

        if not speech_timestamps:
            return torch.empty(0, dtype=audio.dtype, device=audio.device)

        # Each timestamp's 'start'/'end' values are used directly as slice
        # indices into the waveform.
        return torch.cat([audio[ts['start']:ts['end']] for ts in speech_timestamps])

    def transcribe(self, audio_path: Union[str, Path]) -> str:
        """Transcribe one audio file and return the stripped transcript text.

        The audio is loaded, filtered through VAD, and decoded with
        ``language="arabic"``. When a domain prompt was loaded, its token ids
        are passed to generation as ``prompt_ids``.

        Args:
            audio_path: Path of the audio file to transcribe.

        Returns:
            The transcript, or ``""`` when nothing is left after VAD.
        """
        logger.info("Transcribing: %s", audio_path)

        audio = self._load(audio_path).to(self.device)
        audio_clean = self._apply_vad(audio)

        if audio_clean.numel() == 0:
            logger.warning("No speech detected by VAD.")
            return ""

        generate_kwargs: dict = {"language": "arabic"}
        if self.initial_prompt:
            prompt_ids = self.processor.get_prompt_ids(self.initial_prompt)
            # Convert a NumPy result to a tensor on the model's device.
            if isinstance(prompt_ids, np.ndarray):
                prompt_ids = torch.from_numpy(prompt_ids).to(self.device)
            generate_kwargs["prompt_ids"] = prompt_ids

        # Inference using the optimized pipeline; it is fed a CPU NumPy array.
        result = self.pipe(
            audio_clean.cpu().numpy(),
            generate_kwargs=generate_kwargs,
        )

        transcript = result.get("text", "").strip()
        logger.info("Transcription complete. Length: %d chars.", len(transcript))
        return transcript


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Quick test with merged model if available
    t = WhisperTranscriber(model_path="outputs/checkpoints/merged_model")
    print(t.transcribe("1775560189.41808.wav"))