Spaces:
Running
Running
| """ | |
| Production-ready FastAPI service for speech-to-text using WhisperX. | |
| Optimized for CPU-only environments with controlled single-flight inference. | |
| Word timestamps are obtained via forced alignment (Wav2Vec2). | |
| Audio longer than CHUNK_SIZE_SEC is split into overlapping chunks; each chunk is | |
| transcribed separately, and where segments from adjacent chunks overlap, the one | |
| with the higher average word score is kept. | |
| """ | |
| import asyncio | |
| import logging | |
| import os | |
| import tempfile | |
| import time | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| from typing import Optional, List, Dict, Any | |
| import aiofiles | |
| import httpx | |
| import whisperx | |
| import torch | |
| import numpy as np | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from pydantic import BaseModel, Field, HttpUrl | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| from whisperx.alignment import load_align_model, align | |
| from huggingface_hub import snapshot_download | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
# ----------------------------
# Config
# ----------------------------
# The level name is upper-cased so values such as LOG_LEVEL=debug are accepted;
# logging.basicConfig raises ValueError for unknown (e.g. lower-case) names.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO").upper(),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

MODEL_SIZE = os.getenv("MODEL_SIZE", "small")
MAX_FILE_SIZE_MB = int(os.getenv("MAX_FILE_SIZE_MB", "200"))
PORT = int(os.getenv("PORT", "7860"))
TRANSCRIBE_TIMEOUT_SEC = int(os.getenv("TRANSCRIBE_TIMEOUT_SEC", "300"))
MAX_CONCURRENT_TRANSCRIBES = int(os.getenv("MAX_CONCURRENT_TRANSCRIBES", "1"))

# Chunking parameters for large audio files (seconds)
CHUNK_SIZE_SEC = int(os.getenv("CHUNK_SIZE_SEC", "30"))
OVERLAP_SEC = int(os.getenv("OVERLAP_SEC", "5"))

DEVICE = "cpu"
HOST_CPU = os.cpu_count() or 4
# Default: leave one core free for the event loop / OS.
CPU_THREADS = int(os.getenv("CPU_THREADS", str(max(1, HOST_CPU - 1))))

# NOTE(review): these variables are exported after torch/numpy were imported
# at the top of the file, so libraries that only read them at import time may
# not pick them up — confirm. torch itself is configured explicitly below.
for _thread_var in (
    "OMP_NUM_THREADS",
    "MKL_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ[_thread_var] = str(CPU_THREADS)
torch.set_num_threads(CPU_THREADS)
torch.set_num_interop_threads(1)

# Custom alignment model mapping – values are directory paths or HF repo IDs
CUSTOM_ALIGN_MODELS = {
    "hi": os.getenv("HINDI_ALIGN_MODEL_PATH", "./models/hindi_alignment"),
}

# Global ASR model (assigned during application startup)
asr_model: Optional[Any] = None

# Alignment models cache: key = language code, value = (model, metadata)
align_cache: Dict[str, tuple] = {}

# Concurrency gate limiting simultaneous transcriptions
transcribe_gate = asyncio.Semaphore(MAX_CONCURRENT_TRANSCRIBES)
| # ---------------------------- | |
| # Schemas | |
| # ---------------------------- | |
# Decoding options carried inside a TranscriptionRequest.
# `beam_size` and `temperature` are validated but not forwarded to WhisperX.
class TranscriptionOptions(BaseModel):
    language: Optional[str] = Field(
        default=None,
        description="Language code, e.g. 'hi'",
    )
    task: str = Field(default="transcribe", pattern="^(transcribe|translate)$")
    beam_size: int = Field(default=5, ge=1, le=10)  # not used by WhisperX
    temperature: float = Field(default=0.0, ge=0.0, le=1.0)  # not used
    return_timestamps: bool = Field(
        default=False,
        description="Include segment and word timestamps",
    )
# JSON body for URL-based transcription: the audio location plus options.
class TranscriptionRequest(BaseModel):
    # Remote audio location; validated as an HTTP(S) URL by pydantic.
    url: HttpUrl
    # A fresh TranscriptionOptions (all defaults) is created when omitted.
    options: TranscriptionOptions = Field(default_factory=TranscriptionOptions)
# One word of a transcript segment with its time span.
class TranscriptionWord(BaseModel):
    word: str
    start: float
    end: float
    # Filled from the word's "score" entry when one is present.
    probability: Optional[float] = None
# One transcript segment: time span, text and optional per-word entries.
class TranscriptionSegment(BaseModel):
    start: float
    end: float
    text: str
    # None when the segment carries no word-level data.
    words: Optional[List[TranscriptionWord]] = None
# Response body returned by the transcription endpoints.
class TranscriptionResponse(BaseModel):
    text: str
    language: Optional[str] = None
    duration: Optional[float] = None
    # None unless timestamps were requested and segments were produced.
    segments: Optional[List[TranscriptionSegment]] = None
# Error payload schema (short error label plus optional detail text).
class ErrorResponse(BaseModel):
    error: str
    detail: Optional[str] = None
| # ---------------------------- | |
| # Lifespan | |
| # ---------------------------- | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on startup and release them on shutdown.

    Startup loads the WhisperX ASR model into the module-level ``asr_model``
    and tries to warm the Hindi alignment model; a failure of the latter is
    logged and does not prevent the application from starting. Shutdown
    always clears ``asr_model`` and ``align_cache``.
    """
    global asr_model
    logger.info(
        f"Loading WhisperX model={MODEL_SIZE} on device={DEVICE}, "
        f"CPU_THREADS={CPU_THREADS} (host={HOST_CPU})"
    )
    # Load ASR model
    asr_model = whisperx.load_model(
        MODEL_SIZE,
        device=DEVICE,
        compute_type="float32",
        language=None,
    )
    logger.info("WhisperX ASR model loaded")

    # Optionally pre-load Hindi alignment; best-effort only.
    try:
        _get_align_model("hi")
        logger.info("Hindi alignment model loaded and cached")
    except Exception as e:
        logger.warning(f"Hindi alignment model not pre-loaded: {e}")

    try:
        yield
    finally:
        # Runs even if the application body raised, so models never linger.
        asr_model = None
        align_cache.clear()
        logger.info("Models unloaded")
# ASGI application object; model loading/unloading is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="Speech-to-Text API",
    version="2.1.0",
    redoc_url="/redoc",
    docs_url="/docs",
)
| # ---------------------------- | |
| # Helpers | |
| # ---------------------------- | |
def download_hindi_alignment_model(target_dir: str) -> str:
    """Ensure the Hindi Wav2Vec2 alignment model is present in *target_dir*.

    The repository is taken from ``HINDI_ALIGN_REPO`` (default
    ``theainerd/Wav2Vec2-large-xlsr-hindi``) and fetched with
    ``snapshot_download``. The download is skipped when the directory
    already holds a weights file in either the ``.bin`` or ``.safetensors``
    layout.

    Args:
        target_dir: Directory to populate; created if missing.

    Returns:
        ``target_dir`` unchanged.
    """
    os.makedirs(target_dir, exist_ok=True)

    # Accept both checkpoint layouts so an already-populated directory that
    # only ships safetensors weights is not downloaded again.
    weight_files = ("pytorch_model.bin", "model.safetensors")
    if any(os.path.isfile(os.path.join(target_dir, name)) for name in weight_files):
        logger.info("Hindi alignment model already present at %s", target_dir)
        return target_dir

    token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    repo_id = os.getenv("HINDI_ALIGN_REPO", "theainerd/Wav2Vec2-large-xlsr-hindi")
    logger.info("Downloading Hindi alignment model from %s...", repo_id)
    snapshot_download(
        repo_id=repo_id,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
        token=token,
    )
    logger.info("Download complete to %s", target_dir)
    return target_dir
async def download_audio(url: str, timeout_sec: int = 90) -> Path:
    """Download *url* to a temporary file and return its path.

    The body is streamed to disk and the size limit is enforced while
    streaming, so an oversized response is rejected without first being held
    in memory. The temporary file is removed again if anything fails.

    Args:
        url: Location of the audio to fetch (redirects are followed).
        timeout_sec: httpx timeout in seconds.

    Returns:
        Path of the downloaded file; the caller is responsible for deleting it.

    Raises:
        HTTPException: 413 if the body exceeds ``MAX_FILE_SIZE_MB``, 504 on
            timeout, 400 on an HTTP error status, 500 on any other failure.
    """
    max_bytes = MAX_FILE_SIZE_MB * 1024 * 1024
    temp_path: Optional[Path] = None
    completed = False
    try:
        # Reserve a destination path and close the handle right away; the
        # original kept this descriptor open for the life of the process.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".audio") as tmp:
            temp_path = Path(tmp.name)

        size = 0
        async with httpx.AsyncClient(timeout=timeout_sec, follow_redirects=True) as client:
            async with client.stream("GET", url) as response:
                response.raise_for_status()
                async with aiofiles.open(temp_path, "wb") as f:
                    async for chunk in response.aiter_bytes():
                        size += len(chunk)
                        if size > max_bytes:
                            # Total size is unknown when aborting mid-stream,
                            # so only the limit is reported.
                            raise HTTPException(
                                status_code=413,
                                detail=f"File too large: exceeds {MAX_FILE_SIZE_MB} MB",
                            )
                        await f.write(chunk)

        logger.info(f"Downloaded audio from {url} to {temp_path} ({size} bytes)")
        completed = True
        return temp_path
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="Download timeout") from e
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=400, detail=f"Download failed: {str(e)}") from e
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected download error: {str(e)}") from e
    finally:
        # Do not leave partial downloads behind on any failure path.
        if not completed and temp_path is not None:
            try:
                temp_path.unlink(missing_ok=True)
            except OSError as cleanup_error:
                logger.warning(f"Failed to delete temp file {temp_path}: {cleanup_error}")
async def convert_to_wav(input_path: Path) -> Path:
    """Convert *input_path* to 16 kHz mono 16-bit PCM WAV using ffmpeg.

    Args:
        input_path: Audio file in any format ffmpeg can read.

    Returns:
        Path of a new temporary ``.wav`` file; the caller deletes it.

    Raises:
        HTTPException: 400 if ffmpeg exits with a non-zero status.
        Exception: Whatever spawning ffmpeg raises (e.g. the binary is
            missing); the temporary output file is removed first.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_out:
        output_path = Path(tmp_out.name)

    def _discard_output() -> None:
        # Best-effort removal; a failure here must not mask the real error.
        try:
            output_path.unlink(missing_ok=True)
        except OSError as cleanup_error:
            logger.warning(f"Failed to delete temp file {output_path}: {cleanup_error}")

    cmd = [
        "ffmpeg",
        "-i", str(input_path),
        "-acodec", "pcm_s16le",
        "-ar", "16000",
        "-ac", "1",
        "-y",
        str(output_path),
    ]

    process = None
    try:
        process = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await process.communicate()
    except BaseException:
        # Spawn failure or cancellation: the caller never receives
        # output_path, so it has to be cleaned up here, and a still-running
        # ffmpeg must not be left orphaned.
        if process is not None and process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                pass  # already exited
        _discard_output()
        raise

    if process.returncode != 0:
        _discard_output()
        logger.error(f"ffmpeg conversion failed: {stderr.decode(errors='ignore')}")
        raise HTTPException(status_code=400, detail="Audio conversion failed")

    logger.info(f"Converted {input_path} -> {output_path}")
    return output_path
async def transcribe_audio(
    audio_path: Path,
    language: Optional[str] = None,
    task: str = "transcribe",
    beam_size: int = 5,
    temperature: float = 0.0,
    return_timestamps: bool = False,
) -> Dict[str, Any]:
    """Run ``_transcribe_sync`` in a worker thread behind the concurrency gate.

    ``beam_size`` and ``temperature`` are accepted for interface
    compatibility but are not forwarded to the worker.

    The gate slot is held until the worker thread has actually finished, not
    merely until this coroutine returns. A thread cannot be interrupted, so
    releasing the slot on timeout or cancellation would let a new
    transcription start while the old one is still consuming CPU.

    Returns:
        The dict produced by ``_transcribe_sync``.

    Raises:
        RuntimeError: If the ASR model has not been loaded.
        HTTPException: 504 if the worker exceeds ``TRANSCRIBE_TIMEOUT_SEC``.
    """
    if asr_model is None:
        raise RuntimeError("ASR model not loaded")

    await transcribe_gate.acquire()
    logger.info("Transcription slot acquired")

    def _release_slot(finished: "asyncio.Future") -> None:
        # Retrieve a late exception so asyncio does not report it as
        # "never retrieved" when the caller has already timed out.
        if not finished.cancelled():
            finished.exception()
        transcribe_gate.release()

    try:
        worker = asyncio.ensure_future(
            asyncio.to_thread(
                _transcribe_sync,
                audio_path, language, task, return_timestamps,
            )
        )
    except BaseException:
        transcribe_gate.release()
        raise
    worker.add_done_callback(_release_slot)

    try:
        # shield(): a timeout/cancellation abandons the wait, not the worker.
        return await asyncio.wait_for(
            asyncio.shield(worker),
            timeout=TRANSCRIBE_TIMEOUT_SEC,
        )
    except asyncio.TimeoutError:
        raise HTTPException(status_code=504, detail=f"Transcription timed out after {TRANSCRIBE_TIMEOUT_SEC}s")
def _get_align_model(lang_code: str):
    """Load (or retrieve from cache) the alignment model for a language.

    Languages listed in ``CUSTOM_ALIGN_MODELS`` are loaded from the configured
    source; for Hindi the model is downloaded first when the directory is
    missing or holds no weights (e.g. after an interrupted download). All
    other languages use WhisperX's default model for that language.

    Args:
        lang_code: Language code such as ``"hi"`` or ``"en"``.

    Returns:
        ``(model, metadata)`` as returned by ``load_align_model``.

    Raises:
        HTTPException: 400 if no model is available for the language, 500 if
            a configured custom model fails to load.
    """
    cached = align_cache.get(lang_code)
    if cached is not None:
        logger.info("Using cached alignment model for %s", lang_code)
        return cached

    if lang_code not in CUSTOM_ALIGN_MODELS:
        # Default WhisperX models (en, fr, de, …)
        logger.info("Loading default alignment model for language: %s", lang_code)
        try:
            model, metadata = load_align_model(language_code=lang_code, device=DEVICE)
        except Exception as e:
            logger.error("Could not load alignment model for %s: %s", lang_code, e)
            raise HTTPException(
                status_code=400,
                detail=f"Alignment model not available for language '{lang_code}'."
            ) from e
        align_cache[lang_code] = (model, metadata)
        return model, metadata

    # Custom model (e.g., Hindi)
    model_source = CUSTOM_ALIGN_MODELS[lang_code]  # directory path or HF repo ID
    missing = not os.path.exists(model_source)
    # A directory created by a download that later failed exists but has no
    # weights; treating it as "present" would block every future retry.
    incomplete = os.path.isdir(model_source) and not any(
        name.endswith((".bin", ".safetensors")) for name in os.listdir(model_source)
    )
    if missing or incomplete:
        if lang_code == "hi":
            logger.info("Hindi alignment model not found at %s, initiating download", model_source)
            model_source = download_hindi_alignment_model(model_source)
            CUSTOM_ALIGN_MODELS[lang_code] = model_source  # update mapping
        elif missing:
            raise HTTPException(
                status_code=400,
                detail=f"No alignment model found for language '{lang_code}' at {model_source}."
            )

    logger.info("Loading custom alignment model for %s from %s", lang_code, model_source)
    try:
        # load_align_model takes the source via model_name= (not model_path=)
        model, metadata = load_align_model(
            language_code=lang_code,
            device=DEVICE,
            model_name=model_source,
        )
    except Exception as e:
        logger.error("Custom alignment model for %s failed: %s", lang_code, e)
        raise HTTPException(
            status_code=500,
            detail=f"Alignment model for '{lang_code}' could not be loaded",
        ) from e
    align_cache[lang_code] = (model, metadata)
    return model, metadata
def _avg_word_prob(words: Optional[List[Dict]]) -> float:
    """Mean of the ``score`` values in *words*.

    A word without a ``score`` key counts as 0.0; ``None`` or an empty list
    yields 0.0.
    """
    if not words:
        return 0.0
    scores = [entry.get("score", 0.0) for entry in words]
    if not scores:
        return 0.0
    return sum(scores) / len(scores)
def _merge_segments(segments: List[Dict]) -> List[Dict]:
    """Resolve overlaps between segments, returning them ordered by start.

    Segments are visited in order of ``start``. When a segment begins more
    than 0.01 s before the previously kept one ends, only one of the two
    survives: the one with the higher average word score, or, on a tie, the
    longer one (the earlier segment wins if both are equally long).
    """
    if not segments:
        return []

    ordered = sorted(segments, key=lambda item: item["start"])
    kept = [ordered[0]]
    for candidate in ordered[1:]:
        current = kept[-1]
        overlaps = candidate["start"] < current["end"] - 0.01
        if not overlaps:
            kept.append(candidate)
            continue

        current_score = _avg_word_prob(current.get("words"))
        candidate_score = _avg_word_prob(candidate.get("words"))
        if candidate_score > current_score:
            kept[-1] = candidate
        elif candidate_score == current_score:
            candidate_len = candidate["end"] - candidate["start"]
            current_len = current["end"] - current["start"]
            if candidate_len > current_len:
                kept[-1] = candidate
    return kept
def _chunk_start_times(duration: float, chunk_size: float, overlap: float) -> List[float]:
    """Return chunk start offsets (seconds) that cover ``[0, duration)``."""
    if chunk_size <= 0:
        raise ValueError("CHUNK_SIZE_SEC must be positive")
    if not 0 <= overlap < chunk_size:
        # A non-positive stride would never advance (infinite loop) and a
        # negative overlap would leave gaps, so fall back to no overlap.
        logger.warning(
            "OVERLAP_SEC=%s is outside [0, CHUNK_SIZE_SEC=%s); chunking without overlap",
            overlap, chunk_size,
        )
        overlap = 0
    stride = chunk_size - overlap

    starts: List[float] = []
    t = 0.0
    while t < duration:
        starts.append(t)
        if t + chunk_size >= duration:
            # This chunk already reaches the end; a further chunk would lie
            # entirely inside it and only duplicate work.
            break
        t += stride
    return starts


def _format_segments(segments: List[Dict], return_timestamps: bool):
    """Build ``(full_text, segments_out)`` for the response payload."""
    texts = []
    formatted = []
    for seg in segments:
        text = seg["text"].strip()
        texts.append(text)
        word_list = None
        if return_timestamps and "words" in seg:
            word_list = [
                {
                    "word": w["word"].strip(),
                    "start": w.get("start", 0.0),
                    "end": w.get("end", 0.0),
                    "probability": w.get("score", None),
                }
                for w in seg["words"]
            ]
        formatted.append({
            "start": seg["start"],
            "end": seg["end"],
            "text": text,
            "words": word_list,
        })
    return " ".join(texts).strip(), (formatted if return_timestamps else None)


def _transcribe_sync(
    audio_path: Path,
    language: Optional[str],
    task: str,
    return_timestamps: bool,
) -> Dict[str, Any]:
    """Transcribe *audio_path*, chunking automatically for long audio.

    Audio no longer than ``CHUNK_SIZE_SEC`` is transcribed in one pass.
    Longer audio is cut into overlapping chunks which are transcribed (and,
    if requested, aligned) independently; their timestamps are shifted to
    global time and overlapping segments are resolved by ``_merge_segments``.
    Alignment failures are logged and the unaligned segments are used.

    Args:
        audio_path: File readable by ``whisperx.load_audio``.
        language: Language code, or ``None`` to let the model detect it.
        task: Passed through to the ASR model.
        return_timestamps: Whether to run alignment and return segments.

    Returns:
        Dict with ``text``, ``language``, ``duration`` (seconds) and
        ``segments`` (``None`` unless *return_timestamps* is true).

    Raises:
        ValueError: If chunking is needed and ``CHUNK_SIZE_SEC`` is not
            positive.
    """
    sample_rate = 16000  # the duration math below assumes 16 kHz samples
    full_audio = whisperx.load_audio(str(audio_path))
    full_duration = full_audio.shape[0] / sample_rate

    # If audio is short, use single-pass
    if full_duration <= CHUNK_SIZE_SEC:
        logger.info("Short audio, using single-pass transcription")
        result = asr_model.transcribe(
            full_audio,
            batch_size=16,
            language=language,
            task=task,
        )
        # Capture the language now: the dict returned by align() replaces
        # `result`, and reading "language" from it afterwards yielded None.
        detected_lang = result.get("language")
        if return_timestamps:
            lang_code = language if language else result.get("language", "en")
            try:
                align_model, align_metadata = _get_align_model(lang_code)
                result = align(
                    result["segments"],
                    align_model,
                    align_metadata,
                    full_audio,
                    device=DEVICE,
                    return_char_alignments=False,
                )
            except HTTPException:
                logger.warning("Alignment model unavailable for %s; returning segments without word timestamps", lang_code)
            except Exception as e:
                logger.warning(f"Alignment failed: {e}; returning without word timestamps")

        text, segments_out = _format_segments(result.get("segments", []), return_timestamps)
        return {
            "text": text,
            "language": detected_lang,
            "duration": full_duration,
            "segments": segments_out,
        }

    # -----------------------------------
    # Chunking for large audio files
    # -----------------------------------
    logger.info(f"Audio duration {full_duration:.1f}s > {CHUNK_SIZE_SEC}s, processing in overlapping chunks")
    start_times = _chunk_start_times(full_duration, CHUNK_SIZE_SEC, OVERLAP_SEC)

    all_segments = []
    detected_lang = None
    for idx, chunk_start in enumerate(start_times):
        chunk_end = min(chunk_start + CHUNK_SIZE_SEC, full_duration)
        start_sample = int(chunk_start * sample_rate)
        end_sample = int(chunk_end * sample_rate)
        chunk_audio = full_audio[start_sample:end_sample]
        logger.debug(f"Chunk {idx+1}/{len(start_times)}: {chunk_start:.1f}s – {chunk_end:.1f}s")

        result_chunk = asr_model.transcribe(
            chunk_audio,
            batch_size=16,
            language=language,
            task=task,
        )
        # The first chunk's language is the one reported for the whole file.
        if detected_lang is None:
            detected_lang = result_chunk.get("language")
        lang_code = language if language else result_chunk.get("language", "en")

        # Align if requested
        if return_timestamps:
            try:
                align_model, align_metadata = _get_align_model(lang_code)
                result_chunk = align(
                    result_chunk["segments"],
                    align_model,
                    align_metadata,
                    chunk_audio,
                    device=DEVICE,
                    return_char_alignments=False,
                )
            except HTTPException:
                logger.warning(f"Alignment model unavailable for {lang_code} in chunk {idx+1}")
            except Exception as e:
                logger.warning(f"Alignment failed for chunk {idx+1}: {e}")

        # Adjust timestamps to global time and collect segments
        for seg in result_chunk.get("segments", []):
            seg["start"] += chunk_start
            seg["end"] += chunk_start
            for w in seg.get("words", ()):
                if "start" in w:
                    w["start"] += chunk_start
                if "end" in w:
                    w["end"] += chunk_start
            all_segments.append(seg)

    # Resolve segments duplicated in the overlap regions
    merged_segments = _merge_segments(all_segments)

    text, segments_out = _format_segments(merged_segments, return_timestamps)
    return {
        "text": text,
        "language": detected_lang,
        "duration": full_duration,
        "segments": segments_out,
    }
def cleanup_paths(*paths: Optional[Path]):
    """Delete each given file, ignoring ``None`` entries and missing files.

    A failure to delete is logged as a warning and does not stop the
    remaining paths from being processed.
    """
    for target in (candidate for candidate in paths if candidate is not None):
        try:
            target.unlink(missing_ok=True)
        except Exception as exc:
            logger.warning(f"Failed to delete temp file {target}: {exc}")
| # ---------------------------- | |
| # Endpoints | |
| # ---------------------------- | |
# Route registration restored: without it this handler is never served.
@app.get("/")
async def root():
    """Return a small index payload pointing at the docs and health routes."""
    return {"message": "Speech-to-Text API (WhisperX)", "docs": "/docs", "health": "/health"}
# Route registration restored; the path matches the "health" link that the
# index payload advertises.
@app.get("/health")
async def health():
    """Report model availability and the effective runtime configuration."""
    return {
        "status": "healthy",
        "model_loaded": asr_model is not None,
        "model_size": MODEL_SIZE,
        "device": DEVICE,
        "cpu_threads": CPU_THREADS,
        "max_concurrent_transcribes": MAX_CONCURRENT_TRANSCRIBES,
        "transcribe_timeout_sec": TRANSCRIBE_TIMEOUT_SEC,
        "chunk_size_sec": CHUNK_SIZE_SEC,
        "overlap_sec": OVERLAP_SEC,
    }
# NOTE(review): the route decorator is not visible in the reviewed source, so
# the path below is an assumption — confirm against deployed clients.
@app.post("/transcribe")
async def transcribe_file(
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    task: str = Form("transcribe"),
    beam_size: int = Form(5, ge=1, le=10),
    temperature: float = Form(0.0, ge=0.0, le=1.0),
    return_timestamps: bool = Form(False),
):
    """Transcribe an uploaded audio file.

    The upload is copied to a temporary file in bounded chunks (enforcing
    ``MAX_FILE_SIZE_MB`` while copying), converted to WAV and transcribed.
    Both temporary files are removed before returning.

    Raises:
        HTTPException: 422 for an unsupported *task*, 413 for an oversized
            upload, 500 for unexpected failures; HTTPExceptions raised by the
            helpers are passed through unchanged.
    """
    temp_input_path: Optional[Path] = None
    temp_wav_path: Optional[Path] = None
    try:
        # Same constraint as TranscriptionOptions.task on the JSON endpoint.
        if task not in ("transcribe", "translate"):
            raise HTTPException(
                status_code=422,
                detail="task must be 'transcribe' or 'translate'",
            )

        max_bytes = MAX_FILE_SIZE_MB * 1024 * 1024
        suffix = Path(file.filename or "upload.audio").suffix or ".audio"
        # Only the path is needed; close the handle immediately so the
        # descriptor is not leaked.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            temp_input_path = Path(tmp.name)

        size = 0
        async with aiofiles.open(temp_input_path, "wb") as f:
            # Copy in 1 MiB pieces so a large upload is never fully in memory.
            while chunk := await file.read(1024 * 1024):
                size += len(chunk)
                if size > max_bytes:
                    raise HTTPException(
                        status_code=413,
                        detail=f"File too large: exceeds {MAX_FILE_SIZE_MB} MB",
                    )
                await f.write(chunk)

        temp_wav_path = await convert_to_wav(temp_input_path)
        result = await transcribe_audio(
            audio_path=temp_wav_path,
            language=language,
            task=task,
            beam_size=beam_size,
            temperature=temperature,
            return_timestamps=return_timestamps,
        )
        return TranscriptionResponse(
            text=result["text"],
            language=result["language"],
            duration=result["duration"],
            segments=[TranscriptionSegment(**s) for s in result["segments"]] if result["segments"] else None,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Transcription error: {str(e)}")
    finally:
        cleanup_paths(temp_input_path, temp_wav_path)
# NOTE(review): the route decorator is not visible in the reviewed source, so
# the path below is an assumption — confirm against deployed clients.
@app.post("/transcribe-url")
async def transcribe_url(request: TranscriptionRequest):
    """Download the audio at ``request.url`` and transcribe it.

    The download and the converted WAV are temporary files that are removed
    before returning.

    Raises:
        HTTPException: Passed through from the download/convert/transcribe
            helpers; any other failure is reported as 500.
    """
    # NOTE(security): the URL is caller-supplied and fetched server-side with
    # no host restrictions. If this service is reachable by untrusted
    # clients, consider restricting which hosts may be fetched.
    temp_audio_path: Optional[Path] = None
    temp_wav_path: Optional[Path] = None
    try:
        temp_audio_path = await download_audio(str(request.url))
        temp_wav_path = await convert_to_wav(temp_audio_path)
        opts = request.options
        result = await transcribe_audio(
            audio_path=temp_wav_path,
            language=opts.language,
            task=opts.task,
            beam_size=opts.beam_size,
            temperature=opts.temperature,
            return_timestamps=opts.return_timestamps,
        )
        return TranscriptionResponse(
            text=result["text"],
            language=result["language"],
            duration=result["duration"],
            segments=[TranscriptionSegment(**s) for s in result["segments"]] if result["segments"] else None,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("URL transcription failed")
        raise HTTPException(status_code=500, detail=f"Transcription error: {str(e)}")
    finally:
        cleanup_paths(temp_audio_path, temp_wav_path)
| # ---------------------------- | |
| # Logging middleware | |
| # ---------------------------- | |
class LoggingMiddleware(BaseHTTPMiddleware):
    """Log method, path, status code and elapsed time of each request."""

    async def dispatch(self, request, call_next):
        # perf_counter is monotonic; time.time() can jump when the system
        # clock is adjusted, which would yield wrong or negative durations.
        start = time.perf_counter()
        response = await call_next(request)
        logger.info(f"{request.method} {request.url.path} - {response.status_code} - {time.perf_counter() - start:.3f}s")
        return response
# Install the request-logging middleware on the application.
app.add_middleware(LoggingMiddleware)

if __name__ == "__main__":
    # uvicorn is imported here so it is only needed when run as a script.
    import uvicorn

    # Listen on all interfaces at the configured PORT.
    uvicorn.run(app, host="0.0.0.0", port=PORT)