Spaces:
Sleeping
Sleeping
File size: 6,017 Bytes
6b91a97 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 | import os
import logging
import shutil
import sys
import tempfile
import threading
from contextlib import asynccontextmanager
from typing import List, Optional

import torch
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Point the Hugging Face and PyTorch caches at /tmp (required by HF Spaces).
# Values already present in the environment are left untouched, so a
# deployment can still override either location.
# ---------------------------------------------------------------------------
os.environ.update(
    {
        env_name: cache_dir
        for env_name, cache_dir in (
            ("HF_HOME", "/tmp/hf_cache"),
            ("TORCH_HOME", "/tmp/torch_cache"),
        )
        if env_name not in os.environ
    }
)
# Now import the aligner – it will honour the cache env vars.
from ctc_forced_aligner import (
load_audio,
load_alignment_model,
generate_emissions,
preprocess_text,
get_alignments,
get_spans,
postprocess_results,
)
# ---------------------------------------------------------------------------
# Logging: module logger for this app, root logger configured at INFO so the
# startup and alignment messages below are emitted.
# ---------------------------------------------------------------------------
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# ---------------------------------------------------------------------------
# Shared state: the model/tokenizer are filled in by `lifespan` at startup;
# device and dtype are chosen once at import time (half precision on GPU,
# full precision on CPU).
# ---------------------------------------------------------------------------
model = None
tokenizer = None
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32
# ---------------------------------------------------------------------------
# Pydantic models for Swagger documentation
# ---------------------------------------------------------------------------
# One aligned segment in the /align response. `align` builds these from the
# dicts returned by `postprocess_results`, reading their "start", "end" and
# "text" keys. Documented with comments rather than a class docstring because
# pydantic would copy a docstring into the OpenAPI schema description.
class Segment(BaseModel):
    start: float = Field(..., description="Segment start time in seconds")
    end: float = Field(..., description="Segment end time in seconds")
    text: str = Field(..., description="Aligned text of the segment")
# Response body of POST /align (it is that route's `response_model`): the
# whitespace-stripped input text plus the list of aligned word segments.
# Comments rather than a class docstring, so the OpenAPI schema is unchanged.
class AlignmentResponse(BaseModel):
    text: str = Field(..., description="Full, joined text that was aligned")
    segments: List[Segment] = Field(..., description="List of aligned word segments")
# ---------------------------------------------------------------------------
# App lifespan – download/load the model once at startup
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the alignment model at startup and release it at shutdown.

    The loaded model and tokenizer are stored in the module-level ``model``
    and ``tokenizer`` globals, which the request handlers read.
    """
    global model, tokenizer
    logger.info("Loading alignment model on device: %s", device)
    model, tokenizer = load_alignment_model(
        device=device,
        model_path="MahmoudAshraf/mms-300m-1130-forced-aligner",
        dtype=dtype,
    )
    logger.info("Model loaded successfully")
    try:
        yield
    finally:
        # Reset instead of `del`: deleting the globals would make any later
        # read of `model` (e.g. /health's `model is not None`) raise NameError
        # rather than report that the model is not loaded. `finally` ensures
        # the release also happens when the app exits with an error.
        model = None
        tokenizer = None
        if device == "cuda":
            # Hand cached GPU memory back to the driver once the model is gone.
            torch.cuda.empty_cache()
# The ASGI application; `lifespan` loads the model before requests are served.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="Forced Alignment API",
    description=(
        "Align text to audio using the MMS‑300M forced aligner model. "
        "Supports 1130+ languages."
    ),
)
# ---------------------------------------------------------------------------
# Health endpoint
# ---------------------------------------------------------------------------
@app.get("/health", tags=["health"])
async def health():
    # Liveness probe that also reports the compute device and whether the
    # global `model` has been populated by `lifespan`. (A comment rather than
    # a docstring so the route's OpenAPI description stays unchanged.)
    model_loaded = model is not None
    payload = {"status": "ok", "device": device}
    payload["model_loaded"] = model_loaded
    return payload
# ---------------------------------------------------------------------------
# Core alignment endpoint
# ---------------------------------------------------------------------------
# The alignment pipeline is blocking (file I/O plus CPU/GPU inference), so it
# runs in a worker thread instead of on the event loop. This lock keeps model
# use to one request at a time, as it effectively was when that work blocked
# the event loop, so concurrent requests cannot multiply inference memory.
_inference_lock = threading.Lock()


def _align_file(
    source,
    audio_path: str,
    text: str,
    language: str,
    romanize: bool,
    batch_size: int,
) -> List[dict]:
    """Save the upload to ``audio_path`` and run the forced-alignment pipeline.

    Blocking; intended to run in a worker thread. ``source`` is the uploaded
    file object. Returns the word-timestamp dicts from ``postprocess_results``.
    """
    # ----- 0. Persist the upload so load_audio can read it from disk -----
    with open(audio_path, "wb") as buffer:
        shutil.copyfileobj(source, buffer)

    with _inference_lock:
        # ----- 1. Pre-process text (star tokens, romanisation) -----
        # Done before touching the audio so bad text/language input fails
        # without paying for audio decoding and inference first.
        tokens_starred, text_starred = preprocess_text(
            text, romanize=romanize, language=language
        )
        # ----- 2. Load audio waveform in the model's dtype/device -----
        audio_waveform = load_audio(audio_path, model.dtype, model.device)
        # ----- 3. Generate emissions (log probabilities) -----
        emissions, stride = generate_emissions(
            model, audio_waveform, batch_size=batch_size
        )
        # ----- 4. Get alignments -----
        segments_raw, scores, blank_id = get_alignments(
            emissions, tokens_starred, tokenizer
        )
        # ----- 5. Convert to word spans -----
        spans = get_spans(tokens_starred, segments_raw, blank_id)
        # ----- 6. Post-process into final word timestamps -----
        return postprocess_results(text_starred, spans, stride, scores)


@app.post("/align", response_model=AlignmentResponse, tags=["alignment"])
async def align(
    audio: UploadFile = File(..., description="Audio file (WAV, MP3, etc.)"),
    text: str = Form(..., description="Text to align (plain string)"),
    language: str = Form(
        ..., description="ISO‑639‑3 language code (e.g., 'eng', 'ara', 'rus')"
    ),
    romanize: bool = Form(
        True,
        description="Whether to romanise non‑Latin scripts (required for default model)",
    ),
    batch_size: int = Form(4, description="Batch size for inference"),
):
    """
    Align `text` to the provided `audio` and return word‑level timestamps.

    Returns 400 for empty `text` or a `batch_size` below 1, 503 while the
    model is not loaded, and 500 if the alignment itself fails.
    """
    # Validate cheap preconditions first. These run outside the try/except
    # below, so client errors keep their 4xx/503 status instead of being
    # converted to 500, and nothing is written to disk for a bad request.
    text_clean = text.strip()
    if not text_clean:
        raise HTTPException(status_code=400, detail="Text must not be empty")
    if batch_size < 1:
        raise HTTPException(status_code=400, detail="batch_size must be at least 1")
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Alignment model is not loaded")

    # Save uploaded audio to a temporary file (under /tmp for HF Spaces)
    tmp_dir = tempfile.mkdtemp(dir="/tmp")
    audio_path = os.path.join(tmp_dir, "audio")
    try:
        # run_in_threadpool (AnyIO-backed) waits for the worker to finish even
        # if the request is cancelled, so the `finally` below never deletes
        # tmp_dir while the worker is still using it.
        word_timestamps = await run_in_threadpool(
            _align_file,
            audio.file,
            audio_path,
            text_clean,
            language,
            romanize,
            batch_size,
        )
        segments_out = [
            Segment(start=seg["start"], end=seg["end"], text=seg["text"])
            for seg in word_timestamps
        ]
        return AlignmentResponse(text=text_clean, segments=segments_out)
    except Exception as e:
        logger.exception("Alignment failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temporary folder
        shutil.rmtree(tmp_dir, ignore_errors=True)