Spaces:
Sleeping
Sleeping
| import logging | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| from pathlib import Path | |
| from typing import List, Dict, Any, Optional, Union | |
# --- Constants ---
# Sample rate (Hz) that loaded audio is resampled to and that is reported to the VAD.
TARGET_SR = 16000
# Whisper's native window, in seconds.
WINDOW_SIZE_SEC = 30
# Speech threshold handed to Silero VAD's get_speech_timestamps.
VAD_THRESHOLD = 0.5

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class WhisperTranscriber:
    """Whisper speech-to-text wrapper with optional Silero VAD silence removal.

    Construction loads a Whisper checkpoint into a Hugging Face ASR pipeline,
    reads an optional domain prompt from disk and tries (best effort) to load
    Silero VAD. ``transcribe`` is the public entry point.
    """

    def __init__(self, model_path: str = "openai/whisper-large-v3", device: Optional[str] = None):
        """
        Initialize the high-performance Whisper pipeline.
        Optimized for Egyptian Arabic real estate calls.

        Args:
            model_path: Hub id or local directory passed to ``from_pretrained``.
            device: Device string such as ``"cuda"`` or ``"cpu"``. When omitted,
                CUDA is used if ``torch.cuda.is_available()``, otherwise CPU.

        Raises:
            Exception: Any error raised while importing transformers or loading
                the model/pipeline is logged with its traceback and re-raised.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        on_gpu = "cuda" in self.device
        logger.info("Loading model from %s on %s", model_path, self.device)

        try:
            # Imported lazily so the module can be imported without transformers.
            from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

            self.processor = AutoProcessor.from_pretrained(model_path)
            # Half precision on GPU, full precision on CPU.
            dtype = torch.float16 if on_gpu else torch.float32
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_path,
                dtype=dtype,
                low_cpu_mem_usage=True,
            ).to(self.device)

            # Larger batches only pay off on GPU; keep CPU batches small.
            batch_size = 8 if on_gpu else 2
            self.pipe = pipeline(
                "automatic-speech-recognition",
                model=self.model,
                tokenizer=self.processor.tokenizer,
                feature_extractor=self.processor.feature_extractor,
                chunk_length_s=WINDOW_SIZE_SEC,
                batch_size=batch_size,
                return_timestamps=True,
                dtype=dtype,
                device=self.device,
                generate_kwargs={"max_new_tokens": 128},
            )
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Failed to load Whisper backend: %s", e)
            raise

        # Load domain-specific initial prompt
        self.initial_prompt = self._load_prompt()

        # Load Silero VAD. This is deliberately best effort: when it cannot be
        # loaded, both attributes stay None and _apply_vad becomes a no-op.
        self._vad_model: Optional[torch.nn.Module] = None
        self._get_speech_timestamps = None
        try:
            from silero_vad import load_silero_vad, get_speech_timestamps

            self._vad_model = load_silero_vad().to(self.device)
            self._get_speech_timestamps = get_speech_timestamps
            logger.info("Silero VAD loaded (threshold=%.2f).", VAD_THRESHOLD)
        except Exception as exc:
            logger.warning("Silero VAD could not be loaded: %s", exc)

    def _load_prompt(self) -> str:
        """Return the stripped contents of the domain prompt file, or "" if absent.

        The file is looked up at ``prompts/whisper_initial_prompt.txt`` three
        directory levels above this source file.
        """
        prompt_path = Path(__file__).parent.parent.parent / "prompts" / "whisper_initial_prompt.txt"
        if not prompt_path.exists():
            return ""
        prompt = prompt_path.read_text(encoding="utf-8").strip()
        logger.info("Domain prompt loaded (%d characters).", len(prompt))
        return prompt

    def _load(self, audio_path: Union[str, Path]) -> torch.Tensor:
        """Read an audio file and return a mono float32 waveform at TARGET_SR.

        Args:
            audio_path: Path of the audio file, handed to ``soundfile.read``.

        Returns:
            A 1-D float32 CPU tensor of samples.
        """
        import soundfile as sf

        # Decode straight to float32 instead of float64-then-cast, which would
        # double the peak memory on long recordings.
        audio_data, sr = sf.read(audio_path, dtype="float32")
        waveform = torch.from_numpy(audio_data)

        # Handle multi-channel (soundfile returns [samples, channels])
        if waveform.dim() > 1:
            waveform = torch.mean(waveform, dim=1)

        # Resample if necessary
        if sr != TARGET_SR:
            import torchaudio.transforms as T

            resampler = T.Resample(sr, TARGET_SR)
            waveform = resampler(waveform.unsqueeze(0)).squeeze(0)
        return waveform

    def _apply_vad(self, audio: torch.Tensor) -> torch.Tensor:
        """Drop non-speech samples by concatenating the segments the VAD keeps.

        Args:
            audio: 1-D waveform tensor sampled at TARGET_SR.

        Returns:
            ``audio`` unchanged when VAD is unavailable or ``audio`` is empty;
            an empty tensor on ``audio``'s device when no speech is found;
            otherwise the detected speech segments joined back to back.
        """
        # Nothing to filter: either no VAD model was loaded or there is no audio.
        if self._vad_model is None or audio.numel() == 0:
            return audio

        speech_timestamps = self._get_speech_timestamps(
            audio, self._vad_model, sampling_rate=TARGET_SR, threshold=VAD_THRESHOLD
        )
        if not speech_timestamps:
            return torch.tensor([], device=audio.device)

        # 'start'/'end' are used directly as slice bounds into the waveform.
        return torch.cat([audio[ts['start']:ts['end']] for ts in speech_timestamps])

    def transcribe(self, audio_path: Union[str, Path]) -> str:
        """
        Pure, high-performance transcription.
        Returns a clean, single-stream Egyptian Arabic transcript.

        Args:
            audio_path: Path of the audio file to transcribe.

        Returns:
            The stripped transcript text, or "" when no speech remains after VAD.
        """
        logger.info("Transcribing: %s", audio_path)
        audio = self._load(audio_path).to(self.device)
        audio_clean = self._apply_vad(audio)
        if audio_clean.numel() == 0:
            logger.warning("No speech detected by VAD.")
            return ""

        generate_kwargs: dict = {"language": "arabic"}
        if self.initial_prompt:
            prompt_ids = self.processor.get_prompt_ids(self.initial_prompt)
            # Convert a NumPy result to a tensor on the model's device.
            if isinstance(prompt_ids, np.ndarray):
                prompt_ids = torch.from_numpy(prompt_ids).to(self.device)
            generate_kwargs["prompt_ids"] = prompt_ids

        # NOTE(review): whether the constructor's generate_kwargs (max_new_tokens)
        # are merged with these call-time kwargs or replaced by them depends on
        # the installed transformers version -- confirm.
        # Inference using the optimized pipeline; it is fed a NumPy array on CPU.
        result = self.pipe(
            audio_clean.cpu().numpy(),
            generate_kwargs=generate_kwargs,
        )
        transcript = result.get("text", "").strip()
        logger.info("Transcription complete. Length: %d chars.", len(transcript))
        return transcript
if __name__ == "__main__":
    # Manual smoke test: python <this file> [audio_path] [model_path]
    # Both arguments are optional and fall back to the original hard-coded values.
    import sys

    logging.basicConfig(level=logging.INFO)
    cli_args = sys.argv[1:]
    audio_file = cli_args[0] if len(cli_args) > 0 else "1775560189.41808.wav"
    checkpoint = cli_args[1] if len(cli_args) > 1 else "outputs/checkpoints/merged_model"
    # Quick test with merged model if available
    t = WhisperTranscriber(model_path=checkpoint)
    print(t.transcribe(audio_file))