""" Production-ready FastAPI service for speech-to-text using WhisperX. Optimized for CPU-only environments with controlled single-flight inference. Word timestamps are obtained via forced alignment (Wav2Vec2). Chunking logic automatically handles large audio files by splitting them into overlapping chunks, transcribing each, and merging the results smartly. """ import asyncio import logging import os import tempfile import time from contextlib import asynccontextmanager from pathlib import Path from typing import Optional, List, Dict, Any import aiofiles import httpx import whisperx import torch import numpy as np from fastapi import FastAPI, UploadFile, File, Form, HTTPException from pydantic import BaseModel, Field, HttpUrl from starlette.middleware.base import BaseHTTPMiddleware from whisperx.alignment import load_align_model, align from huggingface_hub import snapshot_download from dotenv import load_dotenv load_dotenv() # ---------------------------- # Config # ---------------------------- logging.basicConfig( level=os.getenv("LOG_LEVEL", "INFO"), format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) MODEL_SIZE = os.getenv("MODEL_SIZE", "small") MAX_FILE_SIZE_MB = int(os.getenv("MAX_FILE_SIZE_MB", "200")) PORT = int(os.getenv("PORT", "7860")) TRANSCRIBE_TIMEOUT_SEC = int(os.getenv("TRANSCRIBE_TIMEOUT_SEC", "300")) MAX_CONCURRENT_TRANSCRIBES = int(os.getenv("MAX_CONCURRENT_TRANSCRIBES", "1")) # Chunking parameters for large audio files (seconds) CHUNK_SIZE_SEC = int(os.getenv("CHUNK_SIZE_SEC", "30")) OVERLAP_SEC = int(os.getenv("OVERLAP_SEC", "5")) DEVICE = "cpu" HOST_CPU = os.cpu_count() or 4 CPU_THREADS = int(os.getenv("CPU_THREADS", str(max(1, HOST_CPU - 1)))) os.environ["OMP_NUM_THREADS"] = str(CPU_THREADS) os.environ["MKL_NUM_THREADS"] = str(CPU_THREADS) os.environ["OPENBLAS_NUM_THREADS"] = str(CPU_THREADS) os.environ["NUMEXPR_NUM_THREADS"] = str(CPU_THREADS) torch.set_num_threads(CPU_THREADS) 
torch.set_num_interop_threads(1)

# Custom alignment model mapping – values are directory paths or HF repo IDs
CUSTOM_ALIGN_MODELS = {
    "hi": os.getenv("HINDI_ALIGN_MODEL_PATH", "./models/hindi_alignment"),
}

# Global ASR model (set by the lifespan handler at startup, cleared at shutdown)
asr_model: Optional[Any] = None

# Alignment models cache: key = language code, value = (model, metadata)
align_cache: Dict[str, tuple] = {}

# Concurrency gate
transcribe_gate = asyncio.Semaphore(MAX_CONCURRENT_TRANSCRIBES)

# Samples per second used to convert between seconds and sample indices;
# matches the "-ar 16000" that convert_to_wav passes to ffmpeg.
_SAMPLE_RATE = 16000


# ----------------------------
# Schemas
# ----------------------------
class TranscriptionOptions(BaseModel):
    """Per-request transcription settings accepted by the JSON endpoint."""

    language: Optional[str] = Field(None, description="Language code, e.g. 'hi'")
    task: str = Field("transcribe", pattern="^(transcribe|translate)$")
    beam_size: int = Field(5, ge=1, le=10)  # not used by WhisperX
    temperature: float = Field(0.0, ge=0.0, le=1.0)  # not used
    return_timestamps: bool = Field(False, description="Include segment and word timestamps")


class TranscriptionRequest(BaseModel):
    """Body of POST /transcribe-url: the audio location plus options."""

    url: HttpUrl
    options: TranscriptionOptions = Field(default_factory=TranscriptionOptions)


class TranscriptionWord(BaseModel):
    """One word with its time span (seconds) and optional alignment score."""

    word: str
    start: float
    end: float
    probability: Optional[float] = None


class TranscriptionSegment(BaseModel):
    """One transcript segment; ``words`` is filled only when words were aligned."""

    start: float
    end: float
    text: str
    words: Optional[List[TranscriptionWord]] = None


class TranscriptionResponse(BaseModel):
    """Payload returned by both transcription endpoints."""

    text: str
    language: Optional[str] = None
    duration: Optional[float] = None
    segments: Optional[List[TranscriptionSegment]] = None


class ErrorResponse(BaseModel):
    """Error shape advertised in the OpenAPI ``responses`` tables."""

    error: str
    detail: Optional[str] = None


# ----------------------------
# Lifespan
# ----------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the ASR model at startup and release all models at shutdown.

    The Hindi alignment model is pre-loaded on a best-effort basis: a failure
    is logged as a warning and does not prevent the service from starting.
    """
    global asr_model
    logger.info(
        f"Loading WhisperX model={MODEL_SIZE} on device={DEVICE}, "
        f"CPU_THREADS={CPU_THREADS} (host={HOST_CPU})"
    )

    # Load ASR model
    asr_model = whisperx.load_model(
        MODEL_SIZE,
        device=DEVICE,
        compute_type="float32",
        language=None,
    )
    logger.info("WhisperX ASR model loaded")

    # Optionally pre-load Hindi alignment so the first Hindi request is not slowed down
    try:
        _get_align_model("hi")
        logger.info("Hindi alignment model loaded and cached")
    except Exception as e:
        logger.warning(f"Hindi alignment model not pre-loaded: {e}")

    yield

    asr_model = None
    align_cache.clear()
    logger.info("Models unloaded")


app = FastAPI(
    title="Speech-to-Text API",
    version="2.1.0",
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc",
)


# ----------------------------
# Helpers
# ----------------------------
def download_hindi_alignment_model(target_dir: str) -> str:
    """Download the Hindi Wav2Vec2 alignment model into ``target_dir``.

    The repository defaults to ``theainerd/Wav2Vec2-large-xlsr-hindi`` and can
    be overridden with the ``HINDI_ALIGN_REPO`` environment variable. The
    download is skipped when ``pytorch_model.bin`` already exists in the
    directory.

    Args:
        target_dir: Directory to download into; created if missing.

    Returns:
        ``target_dir``, for convenient chaining.
    """
    os.makedirs(target_dir, exist_ok=True)
    if os.path.isfile(os.path.join(target_dir, "pytorch_model.bin")):
        logger.info("Hindi alignment model already present at %s", target_dir)
        return target_dir

    token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    repo_id = os.getenv("HINDI_ALIGN_REPO", "theainerd/Wav2Vec2-large-xlsr-hindi")
    logger.info("Downloading Hindi alignment model from %s...", repo_id)
    snapshot_download(
        repo_id=repo_id,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
        token=token,
    )
    logger.info("Download complete to %s", target_dir)
    return target_dir


async def download_audio(url: str, timeout_sec: int = 90) -> Path:
    """Download ``url`` into a temporary file and return its path.

    The body is streamed to disk and the transfer is aborted as soon as it
    exceeds ``MAX_FILE_SIZE_MB``, so an oversized response is never held in
    memory. The temporary file is removed if the download fails.

    NOTE(review): ``url`` is caller-supplied and fetched as-is (redirects
    followed); no host allow-list is applied here.

    Args:
        url: Location of the audio file.
        timeout_sec: Timeout handed to the HTTP client.

    Returns:
        Path of the downloaded file; the caller is responsible for deleting it.

    Raises:
        HTTPException: 413 when the body is too large, 504 on timeout, 400 on
            an HTTP error status, 500 on any other failure.
    """
    max_bytes = MAX_FILE_SIZE_MB * 1024 * 1024

    # Only the path is needed; closing the handle right away avoids leaking a descriptor.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".audio") as tmp:
        target = Path(tmp.name)

    completed = False
    try:
        size = 0
        async with httpx.AsyncClient(timeout=timeout_sec, follow_redirects=True) as client:
            async with client.stream("GET", url) as response:
                response.raise_for_status()
                async with aiofiles.open(target, "wb") as f:
                    async for chunk in response.aiter_bytes():
                        size += len(chunk)
                        if size > max_bytes:
                            raise HTTPException(
                                status_code=413,
                                detail=f"File too large: exceeds {MAX_FILE_SIZE_MB} MB",
                            )
                        await f.write(chunk)
        logger.info(f"Downloaded audio from {url} to {target} ({size} bytes)")
        completed = True
        return target
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="Download timeout") from e
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=400, detail=f"Download failed: {str(e)}") from e
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected download error: {str(e)}") from e
    finally:
        if not completed:
            # Do not leave a partial download behind.
            cleanup_paths(target)


async def convert_to_wav(input_path: Path) -> Path:
    """Convert ``input_path`` to 16 kHz mono 16-bit PCM WAV using ffmpeg.

    Args:
        input_path: Audio file in any format ffmpeg can read.

    Returns:
        Path of the temporary ``.wav`` file; the caller deletes it.

    Raises:
        HTTPException: 400 when ffmpeg exits with a non-zero status.
        OSError: when ffmpeg cannot be started (e.g. it is not installed).
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_out:
        output_path = Path(tmp_out.name)

    cmd = [
        "ffmpeg",
        "-i", str(input_path),
        "-acodec", "pcm_s16le",
        "-ar", "16000",
        "-ac", "1",
        "-y",
        str(output_path),
    ]
    try:
        process = await asyncio.create_subprocess_exec(
            *cmd,
            # ffmpeg must never wait on, or consume, the server's stdin.
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await process.communicate()
    except BaseException:
        # ffmpeg could not be launched or the request was cancelled: the
        # caller never receives output_path, so it must be removed here.
        cleanup_paths(output_path)
        raise

    if process.returncode != 0:
        cleanup_paths(output_path)
        logger.error(f"ffmpeg conversion failed: {stderr.decode(errors='ignore')}")
        raise HTTPException(status_code=400, detail="Audio conversion failed")

    logger.info(f"Converted {input_path} -> {output_path}")
    return output_path


async def transcribe_audio(
    audio_path: Path,
    language: Optional[str] = None,
    task: str = "transcribe",
    beam_size: int = 5,
    temperature: float = 0.0,
    return_timestamps: bool = False,
) -> Dict[str, Any]:
    """Run transcription in a worker thread, gated by ``transcribe_gate``.

    Args:
        audio_path: WAV file to transcribe.
        language: Language code, or None to let the model detect it.
        task: ``"transcribe"`` or ``"translate"``.
        beam_size: Accepted for API compatibility; not forwarded to the model.
        temperature: Accepted for API compatibility; not forwarded to the model.
        return_timestamps: Whether to align words and return segments.

    Returns:
        Dict with ``text``, ``language``, ``duration`` and ``segments``.

    Raises:
        RuntimeError: when the ASR model has not been loaded.
        HTTPException: 504 when the work exceeds ``TRANSCRIBE_TIMEOUT_SEC``.
    """
    global asr_model
    if asr_model is None:
        raise RuntimeError("ASR model not loaded")

    async with transcribe_gate:
        logger.info("Transcription slot acquired")
        try:
            # NOTE(review): on timeout only the await is abandoned; the worker
            # thread cannot be interrupted and may keep running after the gate
            # has been released.
            return await asyncio.wait_for(
                asyncio.to_thread(
                    _transcribe_sync,
                    audio_path,
                    language,
                    task,
                    return_timestamps,
                ),
                timeout=TRANSCRIBE_TIMEOUT_SEC,
            )
        except asyncio.TimeoutError:
            raise HTTPException(
                status_code=504,
                detail=f"Transcription timed out after {TRANSCRIBE_TIMEOUT_SEC}s",
            )


def _get_align_model(lang_code: str):
    """Load (or retrieve from cache) the alignment model for a language.

    Languages listed in ``CUSTOM_ALIGN_MODELS`` are loaded from the configured
    location; the Hindi model is downloaded automatically when that location
    does not exist. Every other language uses WhisperX's default model.

    Args:
        lang_code: Language code such as ``"hi"`` or ``"en"``.

    Returns:
        ``(model, metadata)`` as returned by ``load_align_model``.

    Raises:
        HTTPException: 400 when no model is available for the language,
            500 when a custom model exists but fails to load.
    """
    if lang_code in align_cache:
        logger.info("Using cached alignment model for %s", lang_code)
        return align_cache[lang_code]

    # Custom model (e.g., Hindi)
    if lang_code in CUSTOM_ALIGN_MODELS:
        model_source = CUSTOM_ALIGN_MODELS[lang_code]  # directory path or HF repo ID

        # If it's a local directory that doesn't exist yet, auto-download
        if not os.path.exists(model_source):
            if lang_code == "hi":
                logger.info("Hindi alignment model not found at %s, initiating download", model_source)
                model_source = download_hindi_alignment_model(model_source)
                CUSTOM_ALIGN_MODELS[lang_code] = model_source  # update mapping
            else:
                raise HTTPException(
                    status_code=400,
                    detail=f"No alignment model found for language '{lang_code}' at {model_source}.",
                )

        logger.info("Loading custom alignment model for %s from %s", lang_code, model_source)
        try:
            # The checkpoint location is passed as model_name= (not model_path=).
            model, metadata = load_align_model(
                language_code=lang_code,
                device=DEVICE,
                model_name=model_source,
            )
        except Exception as e:
            logger.error("Custom alignment model for %s failed: %s", lang_code, e)
            raise HTTPException(
                status_code=500,
                detail=f"Alignment model for '{lang_code}' could not be loaded",
            ) from e
        align_cache[lang_code] = (model, metadata)
        return model, metadata

    # Default WhisperX models (en, fr, de, …)
    logger.info("Loading default alignment model for language: %s", lang_code)
    try:
        model, metadata = load_align_model(language_code=lang_code, device=DEVICE)
    except Exception as e:
        logger.error("Could not load alignment model for %s: %s", lang_code, e)
        raise HTTPException(
            status_code=400,
            detail=f"Alignment model not available for language '{lang_code}'.",
        ) from e
    align_cache[lang_code] = (model, metadata)
    return model, metadata


def _avg_word_prob(words: Optional[List[Dict]]) -> float:
    """Return the mean ``score`` over word dicts; 0.0 for a missing/empty list.

    Words without a ``score`` key count as 0.0.
    """
    if not words:
        return 0.0
    probs = [w.get("score", 0.0) for w in words]
    return sum(probs) / len(probs)


def _merge_segments(segments: List[Dict]) -> List[Dict]:
    """Resolve overlapping segments by keeping the more confident one.

    Segments are sorted by start time. When a segment starts more than 10 ms
    before the previously kept segment ends, only one of the two survives: the
    one with the higher average word score, or the longer one on a tie
    (the earlier one if they are equally long).

    NOTE(review): the losing segment is dropped whole, including any part of
    it that lies outside the overlap.

    Args:
        segments: Segment dicts with ``start``/``end`` and optional ``words``.

    Returns:
        New list of the surviving segments in start order.
    """
    if not segments:
        return []

    merged: List[Dict] = []
    for seg in sorted(segments, key=lambda s: s["start"]):
        if not merged:
            merged.append(seg)
            continue

        prev = merged[-1]
        # Overlap if current segment starts before previous ends (with small tolerance)
        if seg["start"] < prev["end"] - 0.01:
            prob_prev = _avg_word_prob(prev.get("words"))
            prob_seg = _avg_word_prob(seg.get("words"))
            if prob_seg > prob_prev:
                merged[-1] = seg
            elif prob_seg == prob_prev:
                # Keep the longer segment if probabilities equal
                if (seg["end"] - seg["start"]) > (prev["end"] - prev["start"]):
                    merged[-1] = seg
            # else keep previous
        else:
            merged.append(seg)
    return merged


def _format_segments(segments: List[Dict], return_timestamps: bool):
    """Turn raw model segments into the API shape.

    Args:
        segments: Segment dicts with ``start``, ``end``, ``text`` and
            optionally ``words``.
        return_timestamps: When true, word dicts are included for segments
            that carry a ``words`` list.

    Returns:
        ``(text, segments_out)``: the space-joined transcript and the list of
        formatted segment dicts.
    """
    texts = []
    segments_out = []
    for seg in segments:
        text = seg["text"].strip()
        texts.append(text)

        word_list = None
        if return_timestamps and "words" in seg:
            word_list = [
                {
                    "word": w["word"].strip(),
                    # Words the aligner could not place carry no timestamps.
                    "start": w.get("start", 0.0),
                    "end": w.get("end", 0.0),
                    "probability": w.get("score", None),
                }
                for w in seg["words"]
            ]

        segments_out.append({
            "start": seg["start"],
            "end": seg["end"],
            "text": text,
            "words": word_list,
        })
    return " ".join(texts).strip(), segments_out


def _transcribe_sync(
    audio_path: Path,
    language: Optional[str],
    task: str,
    return_timestamps: bool,
) -> Dict[str, Any]:
    """Transcribe a file, chunking automatically when it is long.

    Audio no longer than ``CHUNK_SIZE_SEC`` is transcribed in one pass. Longer
    audio is cut into windows of ``CHUNK_SIZE_SEC`` seconds that overlap by
    ``OVERLAP_SEC`` seconds; each window is transcribed (and aligned when
    requested), timestamps are shifted to file time, and overlapping segments
    are reconciled with ``_merge_segments``. Alignment failures are logged and
    the unaligned segments are used instead.

    Args:
        audio_path: WAV file to transcribe.
        language: Language code, or None for auto-detection.
        task: ``"transcribe"`` or ``"translate"``.
        return_timestamps: Whether to run alignment and return segments.

    Returns:
        Dict with ``text``, ``language``, ``duration`` (seconds) and
        ``segments`` (None unless ``return_timestamps``).

    Raises:
        ValueError: when the chunking configuration cannot make progress.
    """
    full_audio = whisperx.load_audio(str(audio_path))
    full_duration = full_audio.shape[0] / _SAMPLE_RATE

    # If audio is short, use single-pass
    if full_duration <= CHUNK_SIZE_SEC:
        logger.info("Short audio, using single-pass transcription")
        result = asr_model.transcribe(
            full_audio,
            batch_size=16,
            language=language,
            task=task,
        )
        # Remember the language now: the dict returned by align() replaces
        # `result` and is not relied upon to carry a "language" entry.
        detected_lang = result.get("language")

        if return_timestamps:
            lang_code = language if language else result.get("language", "en")
            try:
                align_model, align_metadata = _get_align_model(lang_code)
                result = align(
                    result["segments"],
                    align_model,
                    align_metadata,
                    full_audio,
                    device=DEVICE,
                    return_char_alignments=False,
                )
            except HTTPException:
                logger.warning(
                    "Alignment model unavailable for %s; returning segments without word timestamps",
                    lang_code,
                )
            except Exception as e:
                logger.warning(f"Alignment failed: {e}; returning without word timestamps")

        text, segments_out = _format_segments(result.get("segments", []), return_timestamps)
        return {
            "text": text,
            "language": detected_lang,
            "duration": full_duration,
            "segments": segments_out if return_timestamps else None,
        }

    # -----------------------------------
    # Chunking for large audio files
    # -----------------------------------
    logger.info(f"Audio duration {full_duration:.1f}s > {CHUNK_SIZE_SEC}s, processing in overlapping chunks")
    stride = CHUNK_SIZE_SEC - OVERLAP_SEC
    if CHUNK_SIZE_SEC <= 0 or stride <= 0:
        # A non-positive stride would never advance and loop forever.
        raise ValueError(
            f"Invalid chunking configuration: CHUNK_SIZE_SEC={CHUNK_SIZE_SEC} must be positive "
            f"and greater than OVERLAP_SEC={OVERLAP_SEC}"
        )

    # Generate chunk start times
    start_times = []
    t = 0.0
    while t < full_duration:
        start_times.append(t)
        if t + CHUNK_SIZE_SEC >= full_duration:
            # This chunk already reaches the end of the audio; any later
            # chunk would lie entirely inside it.
            break
        t += stride

    all_segments = []
    detected_lang = None

    for idx, chunk_start in enumerate(start_times):
        chunk_end = min(chunk_start + CHUNK_SIZE_SEC, full_duration)
        start_sample = int(chunk_start * _SAMPLE_RATE)
        end_sample = int(chunk_end * _SAMPLE_RATE)
        chunk_audio = full_audio[start_sample:end_sample]

        logger.debug(f"Chunk {idx+1}/{len(start_times)}: {chunk_start:.1f}s – {chunk_end:.1f}s")

        result_chunk = asr_model.transcribe(
            chunk_audio,
            batch_size=16,
            language=language,
            task=task,
        )

        # The first chunk decides the language reported for the whole file.
        if detected_lang is None:
            detected_lang = result_chunk.get("language")

        # Align if requested
        if return_timestamps:
            lang_code = language if language else result_chunk.get("language", "en")
            try:
                align_model, align_metadata = _get_align_model(lang_code)
                result_chunk = align(
                    result_chunk["segments"],
                    align_model,
                    align_metadata,
                    chunk_audio,
                    device=DEVICE,
                    return_char_alignments=False,
                )
            except HTTPException:
                logger.warning(f"Alignment model unavailable for {lang_code} in chunk {idx+1}")
            except Exception as e:
                logger.warning(f"Alignment failed for chunk {idx+1}: {e}")

        # Adjust timestamps to global time and collect segments
        for seg in result_chunk.get("segments", []):
            seg["start"] += chunk_start
            seg["end"] += chunk_start
            for w in seg.get("words", []):
                if "start" in w:
                    w["start"] += chunk_start
                if "end" in w:
                    w["end"] += chunk_start
            all_segments.append(seg)

    # Merge overlapping segments intelligently
    merged_segments = _merge_segments(all_segments)

    # Assemble final output
    text, segments_out = _format_segments(merged_segments, return_timestamps)
    return {
        "text": text,
        "language": detected_lang,
        "duration": full_duration,
        "segments": segments_out if return_timestamps else None,
    }


def cleanup_paths(*paths: Optional[Path]):
    """Delete the given files, ignoring ``None`` entries and missing files.

    Failures are logged as warnings and never raised, so this is safe to call
    from ``finally`` blocks.
    """
    for p in paths:
        if p is None:
            continue
        try:
            p.unlink(missing_ok=True)
        except Exception as e:
            logger.warning(f"Failed to delete temp file {p}: {e}")


def _build_response(result: Dict[str, Any]) -> TranscriptionResponse:
    """Convert the dict produced by ``transcribe_audio`` into the response model.

    An empty or missing segment list is reported as ``None``.
    """
    segments = result["segments"]
    return TranscriptionResponse(
        text=result["text"],
        language=result["language"],
        duration=result["duration"],
        segments=[TranscriptionSegment(**s) for s in segments] if segments else None,
    )


# ----------------------------
# Endpoints
# ----------------------------
@app.get("/")
async def root():
    """Return a short service banner with pointers to the docs and health routes."""
    return {"message": "Speech-to-Text API (WhisperX)", "docs": "/docs", "health": "/health"}


@app.get("/health")
async def health():
    """Report whether the model is loaded together with the active configuration."""
    return {
        "status": "healthy",
        "model_loaded": asr_model is not None,
        "model_size": MODEL_SIZE,
        "device": DEVICE,
        "cpu_threads": CPU_THREADS,
        "max_concurrent_transcribes": MAX_CONCURRENT_TRANSCRIBES,
        "transcribe_timeout_sec": TRANSCRIBE_TIMEOUT_SEC,
        "chunk_size_sec": CHUNK_SIZE_SEC,
        "overlap_sec": OVERLAP_SEC,
    }


@app.post(
    "/transcribe",
    response_model=TranscriptionResponse,
    responses={
        400: {"model": ErrorResponse},
        413: {"model": ErrorResponse},
        429: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
        504: {"model": ErrorResponse},
    },
)
async def transcribe_file(
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    # Same constraint as TranscriptionOptions.task on the JSON endpoint.
    task: str = Form("transcribe", pattern="^(transcribe|translate)$"),
    beam_size: int = Form(5, ge=1, le=10),
    temperature: float = Form(0.0, ge=0.0, le=1.0),
    return_timestamps: bool = Form(False),
):
    """Transcribe an uploaded audio file.

    The upload is written to a temporary file, converted to WAV and
    transcribed; both temporary files are removed before returning.

    Raises:
        HTTPException: 413 when the upload exceeds ``MAX_FILE_SIZE_MB``;
            500 for unexpected failures; other statuses as raised by the
            conversion and transcription helpers.
    """
    temp_input_path: Optional[Path] = None
    temp_wav_path: Optional[Path] = None

    try:
        content = await file.read()
        size = len(content)
        if size > MAX_FILE_SIZE_MB * 1024 * 1024:
            raise HTTPException(
                status_code=413,
                detail=f"File too large: {size / (1024*1024):.2f} MB > {MAX_FILE_SIZE_MB} MB",
            )

        # Keep the original extension for the temporary copy.
        suffix = Path(file.filename or "upload.audio").suffix or ".audio"
        # Only the name is needed; the context manager closes the handle.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            temp_input_path = Path(tmp.name)

        async with aiofiles.open(temp_input_path, "wb") as f:
            await f.write(content)

        temp_wav_path = await convert_to_wav(temp_input_path)

        result = await transcribe_audio(
            audio_path=temp_wav_path,
            language=language,
            task=task,
            beam_size=beam_size,
            temperature=temperature,
            return_timestamps=return_timestamps,
        )
        return _build_response(result)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Transcription error: {str(e)}")
    finally:
        cleanup_paths(temp_input_path, temp_wav_path)


@app.post(
    "/transcribe-url",
    response_model=TranscriptionResponse,
    responses={
        400: {"model": ErrorResponse},
        413: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
        504: {"model": ErrorResponse},
    },
)
async def transcribe_url(request: TranscriptionRequest):
    """Download the audio at ``request.url`` and transcribe it.

    Both the downloaded file and the converted WAV are removed before
    returning.

    Raises:
        HTTPException: statuses raised by the download, conversion and
            transcription helpers; 500 for unexpected failures.
    """
    temp_audio_path: Optional[Path] = None
    temp_wav_path: Optional[Path] = None

    try:
        temp_audio_path = await download_audio(str(request.url))
        temp_wav_path = await convert_to_wav(temp_audio_path)

        result = await transcribe_audio(
            audio_path=temp_wav_path,
            language=request.options.language,
            task=request.options.task,
            beam_size=request.options.beam_size,
            temperature=request.options.temperature,
            return_timestamps=request.options.return_timestamps,
        )
        return _build_response(result)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("URL transcription failed")
        raise HTTPException(status_code=500, detail=f"Transcription error: {str(e)}")
    finally:
        cleanup_paths(temp_audio_path, temp_wav_path)


# ----------------------------
# Logging middleware
# ----------------------------
class LoggingMiddleware(BaseHTTPMiddleware):
    """Log method, path, status code and elapsed time for every request."""

    async def dispatch(self, request, call_next):
        # perf_counter is monotonic, so the duration is immune to clock changes.
        start = time.perf_counter()
        response = await call_next(request)
        logger.info(
            f"{request.method} {request.url.path} - {response.status_code} - "
            f"{time.perf_counter() - start:.3f}s"
        )
        return response


app.add_middleware(LoggingMiddleware)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=PORT)