"""FastAPI service exposing CTC forced alignment (MMS-300M) over HTTP.

Endpoints:
    GET  /health - reports the inference device and whether the model is loaded.
    POST /align  - aligns a transcript to an uploaded audio file and returns
                   word-level timestamps.
"""
import os
import sys
import shutil
import tempfile
import logging
import threading
from contextlib import asynccontextmanager
from typing import List, Optional

import torch
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

# ---------------------------------------------------------------------------
# Redirect HF / PyTorch caches to /tmp (required by HF Spaces)
# ---------------------------------------------------------------------------
os.environ.setdefault("HF_HOME", "/tmp/hf_cache")
os.environ.setdefault("TORCH_HOME", "/tmp/torch_cache")

# Imported only after the cache env vars above are set, so that libraries which
# read them at import time pick up the /tmp locations.
from ctc_forced_aligner import (  # noqa: E402
    load_audio,
    load_alignment_model,
    generate_emissions,
    preprocess_text,
    get_alignments,
    get_spans,
    postprocess_results,
)

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Global state: model, tokenizer and device
# ---------------------------------------------------------------------------
MODEL_PATH = "MahmoudAshraf/mms-300m-1130-forced-aligner"

# Populated by `lifespan` at startup; None before that and after shutdown.
model = None
tokenizer = None
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on GPU; CPU inference stays in float32.
dtype = torch.float16 if device == "cuda" else torch.float32

# `align` is a sync endpoint, so FastAPI may run several requests at once on
# worker threads. This lock serialises use of the single shared model so that
# concurrent requests cannot each hold inference memory at the same time.
_inference_lock = threading.Lock()


# ---------------------------------------------------------------------------
# Pydantic models for Swagger documentation
# ---------------------------------------------------------------------------
class Segment(BaseModel):
    start: float = Field(..., description="Segment start time in seconds")
    end: float = Field(..., description="Segment end time in seconds")
    text: str = Field(..., description="Aligned text of the segment")


class AlignmentResponse(BaseModel):
    text: str = Field(..., description="Full, joined text that was aligned")
    segments: List[Segment] = Field(..., description="List of aligned word segments")


# ---------------------------------------------------------------------------
# App lifespan – download/load the model once at startup
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the alignment model once at startup and release it on shutdown."""
    global model, tokenizer
    logger.info("Loading alignment model on device: %s", device)
    model, tokenizer = load_alignment_model(
        device=device,
        model_path=MODEL_PATH,
        dtype=dtype,
    )
    logger.info("Model loaded successfully")
    try:
        yield
    finally:
        # Reset rather than `del` the globals: after a `del`, any later
        # reference (e.g. /health) would raise NameError instead of seeing None.
        model = None
        tokenizer = None
        if device == "cuda":
            torch.cuda.empty_cache()


app = FastAPI(
    title="Forced Alignment API",
    description="Align text to audio using the MMS‑300M forced aligner model. "
    "Supports 1130+ languages.",
    version="1.0.0",
    lifespan=lifespan,
)


# ---------------------------------------------------------------------------
# Health endpoint
# ---------------------------------------------------------------------------
@app.get("/health", tags=["health"])
async def health():
    return {"status": "ok", "device": device, "model_loaded": model is not None}


# ---------------------------------------------------------------------------
# Core alignment endpoint
# ---------------------------------------------------------------------------
def _upload_suffix(filename: Optional[str]) -> str:
    """Return a safe lowercase extension (e.g. ".wav") from an upload name, or ""."""
    if not filename:
        return ""
    suffix = os.path.splitext(os.path.basename(filename))[1]
    ext = suffix[1:]
    # Keep only short, plain alphanumeric extensions; anything else is dropped
    # so a client-supplied name can never influence the temp path beyond this.
    if ext and len(ext) <= 10 and ext.isascii() and ext.isalnum():
        return suffix.lower()
    return ""


# Declared as a plain `def` (not `async def`): FastAPI runs sync endpoints in a
# worker thread, so the blocking disk I/O and model inference below do not
# stall the event loop (and /health stays responsive during alignment).
@app.post("/align", response_model=AlignmentResponse, tags=["alignment"])
def align(
    audio: UploadFile = File(..., description="Audio file (WAV, MP3, etc.)"),
    text: str = Form(..., description="Text to align (plain string)"),
    language: str = Form(
        ..., description="ISO‑639‑3 language code (e.g., 'eng', 'ara', 'rus')"
    ),
    romanize: bool = Form(
        True,
        description="Whether to romanise non‑Latin scripts (required for default model)",
    ),
    batch_size: int = Form(4, ge=1, description="Batch size for inference"),
):
    """
    Align `text` to the provided `audio` and return word‑level timestamps.

    Returns 400 for empty text, 503 if the model has not finished loading,
    and 500 (with the underlying error message) if alignment itself fails.
    """
    # Validate cheap inputs before touching disk or the model. This is done
    # outside the try-block below so these 4xx/503 errors are not rewrapped
    # as 500s by the generic handler.
    text_clean = text.strip()
    if not text_clean:
        raise HTTPException(status_code=400, detail="Text must not be empty")

    # Snapshot the globals so a concurrent shutdown cannot swap them mid-request.
    current_model, current_tokenizer = model, tokenizer
    if current_model is None or current_tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    # Save uploaded audio to a temporary file (under /tmp for HF Spaces).
    # The original extension is kept in case the audio loader relies on it
    # to choose a decoder.
    tmp_dir = tempfile.mkdtemp(dir="/tmp")
    audio_path = os.path.join(tmp_dir, "audio" + _upload_suffix(audio.filename))
    try:
        with open(audio_path, "wb") as buffer:
            shutil.copyfileobj(audio.file, buffer)

        # ----- 1. Pre‑process text (star tokens, romanisation) -----
        # Runs before the model pass so invalid text/language input fails fast.
        tokens_starred, text_starred = preprocess_text(
            text_clean, romanize=romanize, language=language
        )

        with _inference_lock:
            # ----- 2. Load audio waveform onto the model's device/dtype -----
            audio_waveform = load_audio(
                audio_path, current_model.dtype, current_model.device
            )

            # ----- 3. Generate emissions (log probabilities) -----
            emissions, stride = generate_emissions(
                current_model, audio_waveform, batch_size=batch_size
            )

            # ----- 4. Get alignments -----
            segments_raw, scores, blank_id = get_alignments(
                emissions, tokens_starred, current_tokenizer
            )

        # ----- 5. Convert to word spans -----
        spans = get_spans(tokens_starred, segments_raw, blank_id)

        # ----- 6. Post‑process into final word timestamps -----
        word_timestamps = postprocess_results(text_starred, spans, stride, scores)

        # Build response
        segments_out = [
            Segment(start=seg["start"], end=seg["end"], text=seg["text"])
            for seg in word_timestamps
        ]
        return AlignmentResponse(text=text_clean, segments=segments_out)

    except HTTPException:
        # Deliberate HTTP errors keep their own status code.
        raise
    except Exception as e:
        logger.exception("Alignment failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temporary folder
        shutil.rmtree(tmp_dir, ignore_errors=True)