# fastwhisper_api / app.py
# Author: sidmazak — last commit faea4ca ("Update app.py")
"""
Production-ready FastAPI service for speech-to-text using WhisperX.
Optimized for CPU-only environments; the number of concurrent inferences is
capped by MAX_CONCURRENT_TRANSCRIBES (one at a time by default).
Word timestamps are obtained via forced alignment (Wav2Vec2).
Audio longer than CHUNK_SIZE_SEC is split into overlapping chunks; each chunk is
transcribed separately, and where segments from neighbouring chunks overlap, the
one with the higher average word-alignment score is kept.
"""
import asyncio
import logging
import os
import tempfile
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional, List, Dict, Any
import aiofiles
import httpx
import whisperx
import torch
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from pydantic import BaseModel, Field, HttpUrl
from starlette.middleware.base import BaseHTTPMiddleware
from whisperx.alignment import load_align_model, align
from huggingface_hub import snapshot_download
from dotenv import load_dotenv
load_dotenv()
# ----------------------------
# Config
# ----------------------------
logging.basicConfig(
    # basicConfig only recognises upper-case level names, so normalise e.g. "debug".
    level=os.getenv("LOG_LEVEL", "INFO").upper(),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
MODEL_SIZE = os.getenv("MODEL_SIZE", "small")
MAX_FILE_SIZE_MB = int(os.getenv("MAX_FILE_SIZE_MB", "200"))
PORT = int(os.getenv("PORT", "7860"))
TRANSCRIBE_TIMEOUT_SEC = int(os.getenv("TRANSCRIBE_TIMEOUT_SEC", "300"))
# Clamp to >= 1: a semaphore with zero permits would make every request wait forever.
MAX_CONCURRENT_TRANSCRIBES = max(1, int(os.getenv("MAX_CONCURRENT_TRANSCRIBES", "1")))
# Chunking parameters for large audio files (seconds)
CHUNK_SIZE_SEC = int(os.getenv("CHUNK_SIZE_SEC", "30"))
OVERLAP_SEC = int(os.getenv("OVERLAP_SEC", "5"))
# Consecutive chunks start CHUNK_SIZE_SEC - OVERLAP_SEC seconds apart. Fail fast on
# settings where that stride would not advance (endless chunk loop) or would leave
# gaps between chunks (skipped audio).
if CHUNK_SIZE_SEC <= 0 or not 0 <= OVERLAP_SEC < CHUNK_SIZE_SEC:
    raise ValueError(
        "Invalid chunking config: need CHUNK_SIZE_SEC > 0 and "
        f"0 <= OVERLAP_SEC < CHUNK_SIZE_SEC (got {CHUNK_SIZE_SEC}, {OVERLAP_SEC})"
    )
DEVICE = "cpu"
HOST_CPU = os.cpu_count() or 4
# Clamp to >= 1: torch rejects a non-positive thread count.
CPU_THREADS = max(1, int(os.getenv("CPU_THREADS", str(max(1, HOST_CPU - 1)))))
# NOTE(review): these variables are set after torch and the other native libraries
# were imported above; runtimes that read them only at initialisation may ignore
# them. torch.set_num_threads below applies the limit to torch explicitly.
os.environ["OMP_NUM_THREADS"] = str(CPU_THREADS)
os.environ["MKL_NUM_THREADS"] = str(CPU_THREADS)
os.environ["OPENBLAS_NUM_THREADS"] = str(CPU_THREADS)
os.environ["NUMEXPR_NUM_THREADS"] = str(CPU_THREADS)
torch.set_num_threads(CPU_THREADS)
torch.set_num_interop_threads(1)
# Custom alignment models keyed by language code. Values are local directories;
# for "hi" a missing directory is filled by downloading the model on first use.
CUSTOM_ALIGN_MODELS = {
    "hi": os.getenv("HINDI_ALIGN_MODEL_PATH", "./models/hindi_alignment"),
}
# Global ASR model (assigned by the lifespan handler)
asr_model: Optional[Any] = None
# Alignment models cache: key = language code, value = (model, metadata)
align_cache: Dict[str, tuple] = {}
# Concurrency gate: caps how many transcriptions run at once.
# NOTE(review): created at import time; on Python < 3.10 asyncio primitives bind to
# the loop current at construction, which may not be the server's loop — confirm
# the deployed Python version.
transcribe_gate = asyncio.Semaphore(MAX_CONCURRENT_TRANSCRIBES)
# ----------------------------
# Schemas
# ----------------------------
class TranscriptionOptions(BaseModel):
    # Decoding options sent alongside a URL to /transcribe-url.
    # Leave `language` unset to let the ASR model detect it.
    language: Optional[str] = Field(default=None, description="Language code, e.g. 'hi'")
    # Either "transcribe" or "translate"; forwarded to the ASR model as `task`.
    task: str = Field(default="transcribe", pattern="^(transcribe|translate)$")
    # Validated but not used for decoding: transcribe_audio does not forward
    # beam_size or temperature to WhisperX.
    beam_size: int = Field(default=5, ge=1, le=10)
    temperature: float = Field(default=0.0, ge=0.0, le=1.0)
    # When true the response carries segments (with word timings if alignment succeeds).
    return_timestamps: bool = Field(default=False, description="Include segment and word timestamps")
class TranscriptionRequest(BaseModel):
    # JSON body accepted by POST /transcribe-url.
    url: HttpUrl = Field(...)  # location of the audio to download
    options: TranscriptionOptions = Field(default_factory=TranscriptionOptions)
class TranscriptionWord(BaseModel):
    # One word inside a segment, with its timing in seconds from the audio start.
    word: str = Field(...)
    start: float = Field(...)
    end: float = Field(...)
    # Alignment score for the word, when the aligner produced one.
    probability: Optional[float] = Field(default=None)
class TranscriptionSegment(BaseModel):
    # A transcribed span of audio; times are seconds from the audio start.
    start: float = Field(...)
    end: float = Field(...)
    text: str = Field(...)
    # Per-word timings; absent when timestamps were not requested or alignment failed.
    words: Optional[List[TranscriptionWord]] = Field(default=None)
class TranscriptionResponse(BaseModel):
    # Result returned by both transcription endpoints.
    text: str = Field(...)  # full transcript, segment texts joined by spaces
    language: Optional[str] = Field(default=None)
    duration: Optional[float] = Field(default=None)  # audio length in seconds
    # Only populated when timestamps were requested.
    segments: Optional[List[TranscriptionSegment]] = Field(default=None)
class ErrorResponse(BaseModel):
    # Documents error bodies in the OpenAPI schema; in this file it is used only in
    # the endpoints' `responses=` declarations. Errors raised via HTTPException are
    # serialised by FastAPI as {"detail": "..."} with no "error" field, so `error`
    # is optional to match what clients actually receive.
    error: Optional[str] = None
    detail: Optional[str] = None
# ----------------------------
# Lifespan
# ----------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models at startup and release them at shutdown.

    The WhisperX ASR model is mandatory: a load failure propagates out of the
    handler. The Hindi alignment model is pre-loaded best-effort so the first
    Hindi request with timestamps does not pay its (possibly download) cost;
    a failure is only logged, and ``_get_align_model`` retries on demand because
    it caches successes only.
    """
    global asr_model
    logger.info(
        f"Loading WhisperX model={MODEL_SIZE} on device={DEVICE}, "
        f"CPU_THREADS={CPU_THREADS} (host={HOST_CPU})"
    )
    # Load ASR model
    asr_model = whisperx.load_model(
        MODEL_SIZE,
        device=DEVICE,
        compute_type="float32",
        language=None,
    )
    logger.info("WhisperX ASR model loaded")
    # Best-effort pre-load of the Hindi alignment model.
    try:
        _get_align_model("hi")
        logger.info("Hindi alignment model loaded and cached")
    except Exception as e:
        logger.warning(f"Hindi alignment model not pre-loaded: {e}")
    try:
        yield
    finally:
        # Runs even when shutdown throws an exception/cancellation in at the yield.
        asr_model = None
        align_cache.clear()
        logger.info("Models unloaded")
# Application instance; models are loaded/unloaded by `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="Speech-to-Text API",
    version="2.1.0",
    redoc_url="/redoc",
    docs_url="/docs",
)
# ----------------------------
# Helpers
# ----------------------------
def download_hindi_alignment_model(target_dir: str) -> str:
    """Ensure the Hindi Wav2Vec2 alignment model is present in *target_dir*.

    Unless a weights file is already there, downloads a snapshot of the Hugging
    Face repo named by ``HINDI_ALIGN_REPO`` (default
    ``theainerd/Wav2Vec2-large-xlsr-hindi``) into *target_dir*, authenticating
    with ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN`` when either is set.

    Returns:
        *target_dir*, for convenient chaining.
    """
    # Repos ship either PyTorch pickles or safetensors; either one means the
    # model has already been downloaded.
    weight_files = ("pytorch_model.bin", "model.safetensors")
    if any(os.path.isfile(os.path.join(target_dir, name)) for name in weight_files):
        logger.info("Hindi alignment model already present at %s", target_dir)
        return target_dir
    os.makedirs(target_dir, exist_ok=True)
    token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    repo_id = os.getenv("HINDI_ALIGN_REPO", "theainerd/Wav2Vec2-large-xlsr-hindi")
    logger.info("Downloading Hindi alignment model from %s...", repo_id)
    snapshot_download(
        repo_id=repo_id,
        local_dir=target_dir,
        # Real files instead of links into the HF cache (older huggingface_hub).
        local_dir_use_symlinks=False,
        token=token,
    )
    logger.info("Download complete to %s", target_dir)
    return target_dir
async def download_audio(url: str, timeout_sec: int = 90) -> Path:
    """Download *url* into a new temporary ``.audio`` file and return its path.

    The body is streamed to disk and the download is aborted as soon as it
    exceeds ``MAX_FILE_SIZE_MB``, so an oversized (or endless) response never
    has to fit in memory. The caller owns the returned file and must delete it;
    on any failure the partial file is removed here.

    NOTE(review): *url* comes from the client and redirects are followed, so
    internal/loopback addresses are reachable (SSRF). Consider restricting
    hosts if this service is exposed publicly.

    Raises:
        HTTPException: 413 if the file is too large, 504 on timeout, 400 if the
            server answers with an error status, 500 for anything else.
    """
    max_bytes = MAX_FILE_SIZE_MB * 1024 * 1024

    def _too_large(size: int) -> HTTPException:
        return HTTPException(
            status_code=413,
            detail=f"File too large: {size / (1024*1024):.2f} MB > {MAX_FILE_SIZE_MB} MB",
        )

    fd, tmp_name = tempfile.mkstemp(suffix=".audio")
    os.close(fd)  # reopened by name below; don't leak the descriptor
    temp_path = Path(tmp_name)
    completed = False
    try:
        async with httpx.AsyncClient(timeout=timeout_sec, follow_redirects=True) as client:
            async with client.stream("GET", url) as response:
                response.raise_for_status()
                # Reject early when the server announces an oversized body.
                declared = response.headers.get("content-length", "")
                if declared.isdigit() and int(declared) > max_bytes:
                    raise _too_large(int(declared))
                size = 0
                async with aiofiles.open(temp_path, "wb") as f:
                    async for chunk in response.aiter_bytes():
                        size += len(chunk)
                        if size > max_bytes:
                            raise _too_large(size)
                        await f.write(chunk)
        logger.info(f"Downloaded audio from {url} to {temp_path} ({size} bytes)")
        completed = True
        return temp_path
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="Download timeout") from e
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=400, detail=f"Download failed: {str(e)}") from e
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected download error: {str(e)}") from e
    finally:
        if not completed:
            temp_path.unlink(missing_ok=True)
async def convert_to_wav(input_path: Path) -> Path:
    """Transcode *input_path* to 16 kHz mono 16-bit PCM WAV using ffmpeg.

    Returns the path of a new temporary ``.wav`` file owned by the caller. The
    temporary output is removed on every failure path, including a missing
    ffmpeg binary and request cancellation (which also kills ffmpeg).

    Raises:
        HTTPException: 400 if ffmpeg exits with a non-zero status.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_out:
        output_path = Path(tmp_out.name)

    def _discard_output() -> None:
        try:
            output_path.unlink(missing_ok=True)
        except OSError:
            pass

    cmd = [
        "ffmpeg",
        "-i", str(input_path),
        "-acodec", "pcm_s16le",
        "-ar", "16000",
        "-ac", "1",
        "-y",
        str(output_path),
    ]
    process = None
    try:
        process = await asyncio.create_subprocess_exec(
            *cmd,
            # ffmpeg reads stdin for interactive commands by default; keep it
            # away from the server's stdin.
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await process.communicate()
    except BaseException:
        # Spawn failure (e.g. ffmpeg not installed) or cancellation mid-conversion.
        if process is not None and process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                pass
        _discard_output()
        raise
    if process.returncode != 0:
        _discard_output()
        logger.error(f"ffmpeg conversion failed: {stderr.decode(errors='ignore')}")
        raise HTTPException(status_code=400, detail="Audio conversion failed")
    logger.info(f"Converted {input_path} -> {output_path}")
    return output_path
async def transcribe_audio(
    audio_path: Path,
    language: Optional[str] = None,
    task: str = "transcribe",
    beam_size: int = 5,
    temperature: float = 0.0,
    return_timestamps: bool = False,
) -> Dict[str, Any]:
    """Run ``_transcribe_sync`` in a worker thread, bounded by ``transcribe_gate``.

    ``beam_size`` and ``temperature`` are accepted for API compatibility but
    are not forwarded to the model.

    The concurrency slot is held until the worker thread has actually finished.
    A thread cannot be interrupted, so after a timeout (or a client disconnect)
    inference keeps running; releasing the slot at that point would let another
    job start alongside it and exceed ``MAX_CONCURRENT_TRANSCRIBES``.

    Raises:
        RuntimeError: if the ASR model has not been loaded.
        HTTPException: 504 if transcription exceeds ``TRANSCRIBE_TIMEOUT_SEC``.
    """
    if asr_model is None:
        raise RuntimeError("ASR model not loaded")
    await transcribe_gate.acquire()
    logger.info("Transcription slot acquired")
    worker = asyncio.ensure_future(
        asyncio.to_thread(
            _transcribe_sync,
            audio_path, language, task, return_timestamps,
        )
    )

    def _release_slot(done: asyncio.Future) -> None:
        transcribe_gate.release()
        # Observe the outcome: after a timeout nobody awaits the worker, and an
        # unobserved exception would otherwise be reported as "never retrieved".
        if not done.cancelled() and done.exception() is not None:
            logger.debug("Transcription worker finished with error: %r", done.exception())

    worker.add_done_callback(_release_slot)
    try:
        # shield: a timeout/cancel here must not detach the slot from the running thread.
        return await asyncio.wait_for(asyncio.shield(worker), timeout=TRANSCRIBE_TIMEOUT_SEC)
    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=504,
            detail=f"Transcription timed out after {TRANSCRIBE_TIMEOUT_SEC}s",
        ) from None
def _get_align_model(lang_code: str):
    """Return the ``(model, metadata)`` forced-alignment pair for *lang_code*.

    Successful loads are cached in ``align_cache``. Failures are not cached, so
    the next call retries. Languages listed in ``CUSTOM_ALIGN_MODELS`` load from
    the configured local source; Hindi is downloaded automatically when that
    directory is missing or empty. Any other language uses WhisperX's default
    alignment model.

    Raises:
        HTTPException: 400 when no model exists for the language, 500 when a
            configured custom model fails to load.
    """
    if lang_code in align_cache:
        # debug, not info: chunked transcription looks the model up once per chunk.
        logger.debug("Using cached alignment model for %s", lang_code)
        return align_cache[lang_code]

    if lang_code in CUSTOM_ALIGN_MODELS:
        model_source = CUSTOM_ALIGN_MODELS[lang_code]
        # An empty directory counts as missing. download_hindi_alignment_model
        # creates the directory before fetching, so a failed download leaves one
        # behind; otherwise the model would never be re-downloaded.
        missing = not os.path.exists(model_source) or (
            os.path.isdir(model_source) and not os.listdir(model_source)
        )
        if missing:
            if lang_code != "hi":
                raise HTTPException(
                    status_code=400,
                    detail=f"No alignment model found for language '{lang_code}' at {model_source}."
                )
            logger.info("Hindi alignment model not found at %s, initiating download", model_source)
            model_source = download_hindi_alignment_model(model_source)
            CUSTOM_ALIGN_MODELS[lang_code] = model_source  # update mapping
        logger.info("Loading custom alignment model for %s from %s", lang_code, model_source)
        try:
            # Custom models are selected via model_name= (not model_path=).
            model, metadata = load_align_model(
                language_code=lang_code,
                device=DEVICE,
                model_name=model_source,
            )
        except Exception as e:
            logger.error("Custom alignment model for %s failed: %s", lang_code, e)
            raise HTTPException(
                status_code=500,
                detail=f"Alignment model for '{lang_code}' could not be loaded",
            ) from e
    else:
        # Default WhisperX models (en, fr, de, …)
        logger.info("Loading default alignment model for language: %s", lang_code)
        try:
            model, metadata = load_align_model(language_code=lang_code, device=DEVICE)
        except Exception as e:
            logger.error("Could not load alignment model for %s: %s", lang_code, e)
            raise HTTPException(
                status_code=400,
                detail=f"Alignment model not available for language '{lang_code}'."
            ) from e

    align_cache[lang_code] = (model, metadata)
    return model, metadata
def _avg_word_prob(words: Optional[List[Dict]]) -> float:
"""Return average probability from a list of word dicts (looks for 'score' key)."""
if not words:
return 0.0
probs = [w.get("score", 0.0) for w in words]
return sum(probs) / len(probs) if probs else 0.0
def _merge_segments(segments: List[Dict]) -> List[Dict]:
"""Merge overlapping segments by keeping the one with higher average word probability."""
if not segments:
return []
sorted_segs = sorted(segments, key=lambda s: s["start"])
merged = []
for seg in sorted_segs:
if not merged:
merged.append(seg)
continue
prev = merged[-1]
# Overlap if current segment starts before previous ends (with small tolerance)
if seg["start"] < prev["end"] - 0.01:
prob_prev = _avg_word_prob(prev.get("words"))
prob_seg = _avg_word_prob(seg.get("words"))
if prob_seg > prob_prev:
merged[-1] = seg
elif prob_seg == prob_prev:
# Keep the longer segment if probabilities equal
if (seg["end"] - seg["start"]) > (prev["end"] - prev["start"]):
merged[-1] = seg
# else keep previous
else:
merged.append(seg)
return merged
def _chunk_start_times(duration: float, chunk_sec: float, overlap_sec: float) -> List[float]:
    """Return start offsets (seconds) of overlapping windows that cover ``[0, duration)``.

    Windows are ``chunk_sec`` long and consecutive starts are
    ``chunk_sec - overlap_sec`` apart. Generation stops at the first window that
    already reaches ``duration``; any later window would lie entirely inside it
    and only duplicate work.

    Raises:
        ValueError: if ``overlap_sec >= chunk_sec`` (the stride would not advance).
    """
    stride = chunk_sec - overlap_sec
    if stride <= 0:
        raise ValueError(
            f"OVERLAP_SEC ({overlap_sec}) must be smaller than CHUNK_SIZE_SEC ({chunk_sec})"
        )
    starts: List[float] = []
    t = 0.0
    while t < duration:
        starts.append(t)
        if t + chunk_sec >= duration:
            break
        t += stride
    return starts


def _align_or_passthrough(result: Dict[str, Any], audio, lang_code: str, where: str = "") -> Dict[str, Any]:
    """Best-effort forced alignment of ``result["segments"]`` against *audio*.

    Returns the aligner's output (segments carrying ``words``) on success. If
    no alignment model is available or alignment fails, returns *result*
    unchanged, so the transcript is still delivered without word timings.
    *where* is appended to log messages (e.g. " in chunk 3").
    """
    try:
        align_model, align_metadata = _get_align_model(lang_code)
        return align(
            result["segments"],
            align_model,
            align_metadata,
            audio,
            device=DEVICE,
            return_char_alignments=False,
        )
    except HTTPException:
        logger.warning(
            "Alignment model unavailable for %s%s; returning segments without word timestamps",
            lang_code, where,
        )
    except Exception as e:
        logger.warning("Alignment failed%s: %s; returning without word timestamps", where, e)
    return result


def _shift_segment(seg: Dict, offset: float) -> None:
    """Shift a chunk-relative segment and its words' timestamps by *offset* seconds, in place."""
    seg["start"] += offset
    seg["end"] += offset
    for word in seg.get("words") or []:
        for key in ("start", "end"):
            if key in word:
                word[key] += offset


def _format_transcription(
    segments: List[Dict],
    language: Optional[str],
    duration: float,
    return_timestamps: bool,
) -> Dict[str, Any]:
    """Build the service's result dict from WhisperX-style segment dicts.

    Words without a start/end timestamp are reported with 0.0.
    """
    texts: List[str] = []
    segments_out: List[Dict[str, Any]] = []
    for seg in segments:
        text = seg["text"].strip()
        if text:  # avoid double spaces in the joined transcript
            texts.append(text)
        word_list = None
        if return_timestamps and "words" in seg:
            word_list = [
                {
                    "word": w["word"].strip(),
                    "start": w.get("start", 0.0),
                    "end": w.get("end", 0.0),
                    "probability": w.get("score", None),
                }
                for w in seg["words"]
            ]
        segments_out.append({
            "start": seg["start"],
            "end": seg["end"],
            "text": text,
            "words": word_list,
        })
    return {
        "text": " ".join(texts).strip(),
        "language": language,
        "duration": duration,
        "segments": segments_out if return_timestamps else None,
    }


def _transcribe_sync(
    audio_path: Path,
    language: Optional[str],
    task: str,
    return_timestamps: bool,
) -> Dict[str, Any]:
    """Blocking transcription of *audio_path*, chunking long audio.

    Audio no longer than ``CHUNK_SIZE_SEC`` is transcribed in a single pass.
    Longer audio is processed as follows:

    - It is cut into ``CHUNK_SIZE_SEC`` windows that overlap by ``OVERLAP_SEC``.
    - Each window is transcribed, and aligned if requested.
    - Each window's timestamps are shifted to absolute time.
    - Overlapping segments from neighbouring windows are resolved by
      ``_merge_segments``.

    Returns:
        A dict with ``text``, ``language``, ``duration`` (seconds) and
        ``segments``. ``language`` is the one the ASR model reported (for chunked
        audio, the first chunk's). ``segments`` is a list only when
        *return_timestamps* is true, otherwise ``None``.
    """
    # Sample rate this code assumes whisperx.load_audio returns
    # (convert_to_wav also writes 16 kHz).
    sample_rate = 16000
    full_audio = whisperx.load_audio(str(audio_path))
    full_duration = full_audio.shape[0] / sample_rate

    if full_duration <= CHUNK_SIZE_SEC:
        logger.info("Short audio, using single-pass transcription")
        result = asr_model.transcribe(
            full_audio,
            batch_size=16,
            language=language,
            task=task,
        )
        # Capture the language now: align() returns only segment data (no
        # "language" key), so reading it after alignment would yield None.
        detected_lang = result.get("language")
        if return_timestamps:
            result = _align_or_passthrough(result, full_audio, language or detected_lang or "en")
        return _format_transcription(
            result.get("segments", []), detected_lang, full_duration, return_timestamps
        )

    # -----------------------------------
    # Chunking for large audio files
    # -----------------------------------
    logger.info(f"Audio duration {full_duration:.1f}s > {CHUNK_SIZE_SEC}s, processing in overlapping chunks")
    start_times = _chunk_start_times(full_duration, CHUNK_SIZE_SEC, OVERLAP_SEC)
    all_segments: List[Dict] = []
    detected_lang = None
    for idx, chunk_start in enumerate(start_times, start=1):
        chunk_end = min(chunk_start + CHUNK_SIZE_SEC, full_duration)
        chunk_audio = full_audio[int(chunk_start * sample_rate):int(chunk_end * sample_rate)]
        logger.debug(f"Chunk {idx}/{len(start_times)}: {chunk_start:.1f}s – {chunk_end:.1f}s")
        chunk_result = asr_model.transcribe(
            chunk_audio,
            batch_size=16,
            language=language,
            task=task,
        )
        chunk_lang = chunk_result.get("language")
        if detected_lang is None:
            detected_lang = chunk_lang
        if return_timestamps:
            chunk_result = _align_or_passthrough(
                chunk_result, chunk_audio, language or chunk_lang or "en", f" in chunk {idx}"
            )
        # Chunk-relative times -> absolute times.
        for seg in chunk_result.get("segments", []):
            _shift_segment(seg, chunk_start)
            all_segments.append(seg)

    return _format_transcription(
        _merge_segments(all_segments), detected_lang, full_duration, return_timestamps
    )
def cleanup_paths(*paths: Optional[Path]):
    """Delete each given temporary file, best-effort.

    ``None`` entries and already-missing files are ignored; any other deletion
    error is logged as a warning and does not stop the remaining deletions.
    """
    for target in (candidate for candidate in paths if candidate is not None):
        try:
            target.unlink(missing_ok=True)
        except Exception as err:
            logger.warning(f"Failed to delete temp file {target}: {err}")
# ----------------------------
# Endpoints
# ----------------------------
@app.get("/")
async def root():
    """Landing endpoint pointing clients at the docs and health check."""
    payload = {
        "message": "Speech-to-Text API (WhisperX)",
        "docs": "/docs",
        "health": "/health",
    }
    return payload
@app.get("/health")
async def health():
    """Report liveness, whether the ASR model is loaded, and effective runtime settings."""
    return dict(
        status="healthy",
        model_loaded=asr_model is not None,
        model_size=MODEL_SIZE,
        device=DEVICE,
        cpu_threads=CPU_THREADS,
        max_concurrent_transcribes=MAX_CONCURRENT_TRANSCRIBES,
        transcribe_timeout_sec=TRANSCRIBE_TIMEOUT_SEC,
        chunk_size_sec=CHUNK_SIZE_SEC,
        overlap_sec=OVERLAP_SEC,
    )
@app.post(
    "/transcribe",
    response_model=TranscriptionResponse,
    responses={
        400: {"model": ErrorResponse},
        413: {"model": ErrorResponse},
        429: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
        504: {"model": ErrorResponse},
    },
)
async def transcribe_file(
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    # Same constraint as TranscriptionOptions.task on the JSON endpoint.
    task: str = Form("transcribe", pattern="^(transcribe|translate)$"),
    beam_size: int = Form(5, ge=1, le=10),
    temperature: float = Form(0.0, ge=0.0, le=1.0),
    return_timestamps: bool = Form(False),
):
    """Transcribe an uploaded audio/video file.

    The upload is copied to a temporary file in 1 MiB pieces with a running
    size check, so an oversized upload is rejected (413) without ever being
    held in memory in full. Temporary files are always removed.
    """
    temp_input_path: Optional[Path] = None
    temp_wav_path: Optional[Path] = None
    max_bytes = MAX_FILE_SIZE_MB * 1024 * 1024
    try:
        # The client-supplied extension only names the temp file (ffmpeg may use
        # it as a format hint); fall back to ".audio" unless it is a plain extension.
        suffix = Path(file.filename or "upload.audio").suffix or ".audio"
        if len(suffix) > 16 or not suffix[1:].isalnum():
            suffix = ".audio"
        fd, tmp_name = tempfile.mkstemp(suffix=suffix)
        os.close(fd)  # reopened by name below; don't leak the descriptor
        temp_input_path = Path(tmp_name)

        size = 0
        async with aiofiles.open(temp_input_path, "wb") as out:
            while chunk := await file.read(1024 * 1024):
                size += len(chunk)
                if size > max_bytes:
                    raise HTTPException(
                        status_code=413,
                        detail=f"File too large: {size / (1024*1024):.2f} MB > {MAX_FILE_SIZE_MB} MB",
                    )
                await out.write(chunk)

        temp_wav_path = await convert_to_wav(temp_input_path)
        result = await transcribe_audio(
            audio_path=temp_wav_path,
            language=language,
            task=task,
            beam_size=beam_size,
            temperature=temperature,
            return_timestamps=return_timestamps,
        )
        return TranscriptionResponse(
            text=result["text"],
            language=result["language"],
            duration=result["duration"],
            segments=[TranscriptionSegment(**s) for s in result["segments"]] if result["segments"] else None,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Transcription error: {str(e)}") from e
    finally:
        cleanup_paths(temp_input_path, temp_wav_path)
@app.post(
    "/transcribe-url",
    response_model=TranscriptionResponse,
    responses={
        400: {"model": ErrorResponse},
        413: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
        504: {"model": ErrorResponse},
    },
)
async def transcribe_url(request: TranscriptionRequest):
    """Download audio from ``request.url``, transcribe it, and return the result.

    Downloaded and converted temporary files are removed whatever the outcome.
    """
    downloaded: Optional[Path] = None
    wav: Optional[Path] = None
    try:
        opts = request.options
        downloaded = await download_audio(str(request.url))
        wav = await convert_to_wav(downloaded)
        result = await transcribe_audio(
            audio_path=wav,
            language=opts.language,
            task=opts.task,
            beam_size=opts.beam_size,
            temperature=opts.temperature,
            return_timestamps=opts.return_timestamps,
        )
        raw_segments = result["segments"]
        return TranscriptionResponse(
            text=result["text"],
            language=result["language"],
            duration=result["duration"],
            segments=[TranscriptionSegment(**s) for s in raw_segments] if raw_segments else None,
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("URL transcription failed")
        raise HTTPException(status_code=500, detail=f"Transcription error: {str(exc)}")
    finally:
        cleanup_paths(downloaded, wav)
# ----------------------------
# Logging middleware
# ----------------------------
class LoggingMiddleware(BaseHTTPMiddleware):
    """Log method, path, status code and latency for every request."""

    async def dispatch(self, request, call_next):
        # perf_counter is monotonic, so latency is immune to system clock changes.
        start = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception as exc:
            # No response was produced; still record the request before propagating.
            logger.error(
                f"{request.method} {request.url.path} - unhandled {type(exc).__name__} - "
                f"{time.perf_counter() - start:.3f}s"
            )
            raise
        logger.info(f"{request.method} {request.url.path} - {response.status_code} - {time.perf_counter() - start:.3f}s")
        return response
app.add_middleware(LoggingMiddleware)

if __name__ == "__main__":
    import uvicorn

    # HOST is optional; the default keeps listening on all interfaces.
    uvicorn.run(app, host=os.getenv("HOST", "0.0.0.0"), port=PORT)