Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import shutil | |
| import tempfile | |
| import logging | |
| from contextlib import asynccontextmanager | |
| from typing import List, Optional | |
| import torch | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel, Field | |
| # --------------------------------------------------------------------------- | |
| # Redirect HF / PyTorch caches to /tmp (required by HF Spaces) | |
| # --------------------------------------------------------------------------- | |
# Default cache locations under /tmp. `setdefault` leaves any value that is
# already present in the environment untouched.
_CACHE_DIR_DEFAULTS = {
    "HF_HOME": "/tmp/hf_cache",
    "TORCH_HOME": "/tmp/torch_cache",
}
for _env_name, _cache_dir in _CACHE_DIR_DEFAULTS.items():
    os.environ.setdefault(_env_name, _cache_dir)
| # Now import the aligner – it will honour the cache env vars. | |
| from ctc_forced_aligner import ( | |
| load_audio, | |
| load_alignment_model, | |
| generate_emissions, | |
| preprocess_text, | |
| get_alignments, | |
| get_spans, | |
| postprocess_results, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Logging | |
| # --------------------------------------------------------------------------- | |
# Emit INFO-level records and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Module-level state: model, tokenizer, device and dtype
# ---------------------------------------------------------------------------
# `model` and `tokenizer` stay None until the lifespan handler assigns them.
model = None
tokenizer = None

# Half precision is selected only together with CUDA; CPU uses float32.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32
| # --------------------------------------------------------------------------- | |
| # Pydantic models for Swagger documentation | |
| # --------------------------------------------------------------------------- | |
class Segment(BaseModel):
    # One aligned piece of text with its time boundaries. All three fields
    # are required (`...` marks them as having no default).
    start: float = Field(
        ...,
        description="Segment start time in seconds",
    )
    end: float = Field(
        ...,
        description="Segment end time in seconds",
    )
    text: str = Field(
        ...,
        description="Aligned text of the segment",
    )
class AlignmentResponse(BaseModel):
    # Response body of the alignment endpoint: the text that was aligned plus
    # the list of timed segments. Both fields are required.
    text: str = Field(
        ...,
        description="Full, joined text that was aligned",
    )
    segments: List[Segment] = Field(
        ...,
        description="List of aligned word segments",
    )
| # --------------------------------------------------------------------------- | |
| # App lifespan – download/load the model once at startup | |
| # --------------------------------------------------------------------------- | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the alignment model once at startup and release it at shutdown.

    The function is passed to ``FastAPI(lifespan=...)``, which expects an
    async context manager; hence the ``asynccontextmanager`` decorator. The
    code before ``yield`` runs before the application starts serving, the
    code after it runs on shutdown.

    The loaded objects are published through the module-level ``model`` and
    ``tokenizer`` globals, which the request handlers read.
    """
    global model, tokenizer

    logger.info("Loading alignment model on device: %s", device)
    model, tokenizer = load_alignment_model(
        device=device,
        model_path="MahmoudAshraf/mms-300m-1130-forced-aligner",
        dtype=dtype,
    )
    logger.info("Model loaded successfully")

    try:
        yield
    finally:
        # Reset to None instead of `del`: deleting the globals would make any
        # later `model is not None` check raise NameError rather than report
        # that the model is not loaded.
        model = None
        tokenizer = None
# Human-readable summary shown in the generated API documentation.
_API_DESCRIPTION = (
    "Align text to audio using the MMS‑300M forced aligner model. "
    "Supports 1130+ languages."
)

# The application object; `lifespan` handles model loading at startup.
app = FastAPI(
    lifespan=lifespan,
    title="Forced Alignment API",
    version="1.0.0",
    description=_API_DESCRIPTION,
)
| # --------------------------------------------------------------------------- | |
| # Health endpoint | |
| # --------------------------------------------------------------------------- | |
# Register the handler on the application; without a route decorator the
# function is defined but never reachable over HTTP.
@app.get("/health")
async def health():
    """Return a small status document.

    Returns:
        A dict with ``status`` (always ``"ok"``), ``device`` (the module-level
        device string) and ``model_loaded`` (whether the global ``model`` has
        been assigned).
    """
    return {"status": "ok", "device": device, "model_loaded": model is not None}
| # --------------------------------------------------------------------------- | |
| # Core alignment endpoint | |
| # --------------------------------------------------------------------------- | |
# Register the handler on the application and declare the response schema so
# it is validated and documented.
@app.post("/align", response_model=AlignmentResponse)
async def align(
    audio: UploadFile = File(..., description="Audio file (WAV, MP3, etc.)"),
    text: str = Form(..., description="Text to align (plain string)"),
    language: str = Form(
        ..., description="ISO‑639‑3 language code (e.g., 'eng', 'ara', 'rus')"
    ),
    romanize: bool = Form(
        True,
        description="Whether to romanise non‑Latin scripts (required for default model)",
    ),
    batch_size: int = Form(4, description="Batch size for inference"),
):
    """
    Align `text` to the provided `audio` and return word‑level timestamps.

    Responds with 400 when `text` is blank or `batch_size` is not positive,
    503 when the model has not been loaded, and 500 when alignment itself fails.
    """
    # Validate the cheap inputs before touching the disk or the model, and
    # outside the try block below, so these client errors are not caught by
    # the generic handler and reported as 500.
    text_clean = text.strip()
    if not text_clean:
        raise HTTPException(status_code=400, detail="Text must not be empty")
    if batch_size < 1:
        raise HTTPException(
            status_code=400, detail="batch_size must be a positive integer"
        )
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    # Save uploaded audio to a temporary file (under /tmp for HF Spaces)
    tmp_dir = tempfile.mkdtemp(dir="/tmp")
    audio_path = os.path.join(tmp_dir, "audio")
    try:
        with open(audio_path, "wb") as buffer:
            shutil.copyfileobj(audio.file, buffer)

        # NOTE(review): the steps below run synchronously inside this
        # coroutine; consider offloading them to a worker thread if
        # concurrent requests matter.

        # ----- 1. Load audio waveform -----
        audio_waveform = load_audio(audio_path, model.dtype, model.device)

        # ----- 2. Generate emissions (log probabilities) -----
        emissions, stride = generate_emissions(
            model, audio_waveform, batch_size=batch_size
        )

        # ----- 3. Pre-process text (star tokens, romanisation) -----
        tokens_starred, text_starred = preprocess_text(
            text_clean, romanize=romanize, language=language
        )

        # ----- 4. Get alignments -----
        segments_raw, scores, blank_id = get_alignments(
            emissions, tokens_starred, tokenizer
        )

        # ----- 5. Convert to word spans -----
        spans = get_spans(tokens_starred, segments_raw, blank_id)

        # ----- 6. Post-process into final word timestamps -----
        word_timestamps = postprocess_results(text_starred, spans, stride, scores)

        # Build response
        segments_out = [
            Segment(start=seg["start"], end=seg["end"], text=seg["text"])
            for seg in word_timestamps
        ]
        return AlignmentResponse(text=text_clean, segments=segments_out)
    except HTTPException:
        # Deliberate HTTP errors keep their own status code.
        raise
    except Exception as e:
        logger.exception("Alignment failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Best-effort removal of the temporary folder, on success and failure.
        shutil.rmtree(tmp_dir, ignore_errors=True)