File size: 6,017 Bytes
6b91a97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import sys
import shutil
import tempfile
import logging
from contextlib import asynccontextmanager
from typing import List, Optional

import torch
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

# ---------------------------------------------------------------------------
# Cache locations for Hugging Face / PyTorch downloads.
# Both default to directories under /tmp (required by HF Spaces); a value
# that is already present in the environment is left untouched.
# ---------------------------------------------------------------------------
if "HF_HOME" not in os.environ:
    os.environ["HF_HOME"] = "/tmp/hf_cache"
if "TORCH_HOME" not in os.environ:
    os.environ["TORCH_HOME"] = "/tmp/torch_cache"

# Now import the aligner – it will honour the cache env vars.
from ctc_forced_aligner import (
    load_audio,
    load_alignment_model,
    generate_emissions,
    preprocess_text,
    get_alignments,
    get_spans,
    postprocess_results,
)

# ---------------------------------------------------------------------------
# Logging: INFO-level root configuration plus a module-scoped logger
# ---------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Shared runtime state: model/tokenizer start out unset and are filled in at
# application startup; device and dtype are decided once at import time.
# ---------------------------------------------------------------------------
model = tokenizer = None

if torch.cuda.is_available():
    # GPU present: run on CUDA in half precision.
    device = "cuda"
    dtype = torch.float16
else:
    # No GPU: run on CPU in single precision.
    device = "cpu"
    dtype = torch.float32

# ---------------------------------------------------------------------------
# Pydantic models for Swagger documentation
# ---------------------------------------------------------------------------
class Segment(BaseModel):
    # One aligned piece of text together with the time span it occupies.
    # All three fields are required (Ellipsis default).
    start: float = Field(
        default=...,
        description="Segment start time in seconds",
    )
    end: float = Field(
        default=...,
        description="Segment end time in seconds",
    )
    text: str = Field(
        default=...,
        description="Aligned text of the segment",
    )

class AlignmentResponse(BaseModel):
    # Response body of the alignment endpoint: the text that was aligned
    # plus the list of timed segments. Both fields are required.
    text: str = Field(
        default=...,
        description="Full, joined text that was aligned",
    )
    segments: List[Segment] = Field(
        default=...,
        description="List of aligned word segments",
    )

# ---------------------------------------------------------------------------
# App lifespan – download/load the model once at startup
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the alignment model once at startup and release it at shutdown.

    On entry the module-level ``model`` and ``tokenizer`` are populated via
    ``load_alignment_model`` using the module-level ``device`` and ``dtype``.
    On exit both names are reset to ``None``.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    global model, tokenizer
    logger.info("Loading alignment model on device: %s", device)
    model, tokenizer = load_alignment_model(
        device=device,
        model_path="MahmoudAshraf/mms-300m-1130-forced-aligner",
        dtype=dtype,
    )
    logger.info("Model loaded successfully")
    try:
        yield
    finally:
        # Drop the references so the objects can be reclaimed, but keep the
        # global names bound: `del model, tokenizer` would remove them from
        # the module namespace and turn later `model is not None` checks
        # into NameError instead of a clean "not loaded" answer.
        model = tokenizer = None

# The ASGI application object; model loading/unloading is delegated to
# `lifespan`, the remaining arguments are API metadata.
app = FastAPI(
    lifespan=lifespan,
    title="Forced Alignment API",
    version="1.0.0",
    description=(
        "Align text to audio using the MMS‑300M forced aligner model. Supports 1130+ languages."
    ),
)

# ---------------------------------------------------------------------------
# Health endpoint
# ---------------------------------------------------------------------------
@app.get("/health", tags=["health"])
async def health():
    # Liveness probe: reports a fixed "ok" status, the device chosen at
    # import time, and whether the global model has been populated.
    is_loaded = model is not None
    payload = {
        "status": "ok",
        "device": device,
        "model_loaded": is_loaded,
    }
    return payload

# ---------------------------------------------------------------------------
# Core alignment endpoint
# ---------------------------------------------------------------------------
@app.post("/align", response_model=AlignmentResponse, tags=["alignment"])
async def align(
    audio: UploadFile = File(..., description="Audio file (WAV, MP3, etc.)"),
    text: str = Form(..., description="Text to align (plain string)"),
    language: str = Form(
        ..., description="ISO‑639‑3 language code (e.g., 'eng', 'ara', 'rus')"
    ),
    romanize: bool = Form(
        True,
        description="Whether to romanise non‑Latin scripts (required for default model)",
    ),
    batch_size: int = Form(4, description="Batch size for inference"),
):
    """
    Align `text` to the provided `audio` and return word‑level timestamps.
    """
    # Maintainer notes:
    #   * Returns an AlignmentResponse holding the stripped input text and one
    #     Segment (start, end, text) per entry produced by postprocess_results.
    #   * Raises HTTPException 400 for blank text or a batch_size below 1,
    #     503 when the model/tokenizer globals are not populated, and 500
    #     (detail = str(exception)) for any failure in the pipeline itself.
    #   * The uploaded file is copied into a fresh directory under /tmp that
    #     is removed again on every exit path.

    # ----- 0. Validate inputs before touching the disk or the model -----
    text_clean = text.strip()
    if not text_clean:
        raise HTTPException(status_code=400, detail="Text must not be empty")
    if batch_size < 1:
        raise HTTPException(status_code=400, detail="batch_size must be at least 1")
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Alignment model is not loaded")

    # Save uploaded audio to a temporary file (under /tmp for HF Spaces)
    tmp_dir = tempfile.mkdtemp(dir="/tmp")
    audio_path = os.path.join(tmp_dir, "audio")
    try:
        with open(audio_path, "wb") as buffer:
            shutil.copyfileobj(audio.file, buffer)

        # ----- 1. Load audio waveform -----
        audio_waveform = load_audio(audio_path, model.dtype, model.device)

        # ----- 2. Generate emissions (log probabilities) -----
        emissions, stride = generate_emissions(
            model, audio_waveform, batch_size=batch_size
        )

        # ----- 3. Pre‑process text (star tokens, romanisation) -----
        tokens_starred, text_starred = preprocess_text(
            text_clean, romanize=romanize, language=language
        )

        # ----- 4. Get alignments -----
        segments_raw, scores, blank_id = get_alignments(
            emissions, tokens_starred, tokenizer
        )

        # ----- 5. Convert to word spans -----
        spans = get_spans(tokens_starred, segments_raw, blank_id)

        # ----- 6. Post‑process into final word timestamps -----
        word_timestamps = postprocess_results(text_starred, spans, stride, scores)

        # Build response
        segments_out = [
            Segment(start=seg["start"], end=seg["end"], text=seg["text"])
            for seg in word_timestamps
        ]
        return AlignmentResponse(text=text_clean, segments=segments_out)

    except HTTPException:
        # Deliberate HTTP errors must keep their own status code rather than
        # being rewrapped as a generic 500 by the handler below.
        raise
    except Exception as e:
        logger.exception("Alignment failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temporary folder (best effort)
        shutil.rmtree(tmp_dir, ignore_errors=True)