sidmazak committed on
Commit
6b91a97
·
verified ·
1 Parent(s): 1979eb3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import shutil
4
+ import tempfile
5
+ import logging
6
+ from contextlib import asynccontextmanager
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
11
+ from fastapi.responses import JSONResponse
12
+ from pydantic import BaseModel, Field
13
+
14
# ---------------------------------------------------------------------------
# Redirect HF / PyTorch caches to /tmp (required by HF Spaces)
# ---------------------------------------------------------------------------
# setdefault() keeps any value already present in the environment. This must
# run before the aligner package is imported so the cache paths are honoured.
for _cache_var, _cache_dir in (
    ("HF_HOME", "/tmp/hf_cache"),
    ("TORCH_HOME", "/tmp/torch_cache"),
):
    os.environ.setdefault(_cache_var, _cache_dir)
19
+
20
+ # Now import the aligner – it will honour the cache env vars.
21
+ from ctc_forced_aligner import (
22
+ load_audio,
23
+ load_alignment_model,
24
+ generate_emissions,
25
+ preprocess_text,
26
+ get_alignments,
27
+ get_spans,
28
+ postprocess_results,
29
+ )
30
+
31
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Global variables for model, tokenizer, device and dtype
# ---------------------------------------------------------------------------
# `model` and `tokenizer` stay None until the lifespan handler populates them.
model = None
tokenizer = None

# Use half precision on GPU, full precision on CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # Pydantic models for Swagger documentation
47
+ # ---------------------------------------------------------------------------
48
# One aligned piece of text with its time span; used as the element type of
# the alignment response. All three fields are required.
class Segment(BaseModel):
    start: float = Field(default=..., description="Segment start time in seconds")
    end: float = Field(default=..., description="Segment end time in seconds")
    text: str = Field(default=..., description="Aligned text of the segment")
52
+
53
# Response body of the /align endpoint: the text that was aligned plus the
# per-segment timestamps. Both fields are required.
class AlignmentResponse(BaseModel):
    text: str = Field(
        default=...,
        description="Full, joined text that was aligned",
    )
    segments: List[Segment] = Field(
        default=...,
        description="List of aligned word segments",
    )
56
+
57
+ # ---------------------------------------------------------------------------
58
+ # App lifespan – download/load the model once at startup
59
+ # ---------------------------------------------------------------------------
60
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the alignment model once at startup and release it at shutdown.

    Populates the module-level ``model`` and ``tokenizer`` globals before the
    application starts serving, and resets them to ``None`` when the
    application shuts down.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    global model, tokenizer
    logger.info("Loading alignment model on device: %s", device)
    model, tokenizer = load_alignment_model(
        device=device,
        model_path="MahmoudAshraf/mms-300m-1130-forced-aligner",
        dtype=dtype,
    )
    logger.info("Model loaded successfully")
    try:
        yield
    finally:
        # Reset to None instead of `del`: deleting the globals would unbind
        # the module-level names, so any later `model is not None` check
        # would raise NameError rather than report "not loaded".
        model = None
        tokenizer = None
73
+
74
# Application object; the lifespan handler loads the model before serving.
app = FastAPI(
    lifespan=lifespan,
    title="Forced Alignment API",
    version="1.0.0",
    description=(
        "Align text to audio using the MMS‑300M forced aligner model. "
        "Supports 1130+ languages."
    ),
)
81
+
82
+ # ---------------------------------------------------------------------------
83
+ # Health endpoint
84
+ # ---------------------------------------------------------------------------
85
@app.get("/health", tags=["health"])
async def health():
    # Report the selected device and whether the model global is populated.
    payload = {
        "status": "ok",
        "device": device,
        "model_loaded": model is not None,
    }
    return payload
88
+
89
+ # ---------------------------------------------------------------------------
90
+ # Core alignment endpoint
91
+ # ---------------------------------------------------------------------------
92
@app.post("/align", response_model=AlignmentResponse, tags=["alignment"])
async def align(
    audio: UploadFile = File(..., description="Audio file (WAV, MP3, etc.)"),
    text: str = Form(..., description="Text to align (plain string)"),
    language: str = Form(
        ..., description="ISO‑639‑3 language code (e.g., 'eng', 'ara', 'rus')"
    ),
    romanize: bool = Form(
        True,
        description="Whether to romanise non‑Latin scripts (required for default model)",
    ),
    batch_size: int = Form(4, description="Batch size for inference"),
):
    """
    Align `text` to the provided `audio` and return word‑level timestamps.

    Responds with 400 when `text` is blank or `batch_size` is not positive,
    503 when the model has not been loaded, and 500 when alignment itself
    fails.
    """
    # Validate inputs up front, outside the try block below, so client errors
    # are reported as 4xx and are not converted into a 500 by the generic
    # handler. This also avoids writing/decoding audio for a doomed request.
    text_clean = text.strip()
    if not text_clean:
        raise HTTPException(status_code=400, detail="Text must not be empty")
    if batch_size < 1:
        raise HTTPException(
            status_code=400, detail="batch_size must be a positive integer"
        )
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    # Save uploaded audio to a temporary file (under /tmp for HF Spaces)
    tmp_dir = tempfile.mkdtemp(dir="/tmp")
    audio_path = os.path.join(tmp_dir, "audio")
    try:
        with open(audio_path, "wb") as buffer:
            shutil.copyfileobj(audio.file, buffer)

        # ----- 1. Load audio waveform -----
        audio_waveform = load_audio(audio_path, model.dtype, model.device)

        # ----- 2. Generate emissions (log probabilities) -----
        emissions, stride = generate_emissions(
            model, audio_waveform, batch_size=batch_size
        )

        # ----- 3. Pre-process text (star tokens, romanisation) -----
        tokens_starred, text_starred = preprocess_text(
            text_clean, romanize=romanize, language=language
        )

        # ----- 4. Get alignments -----
        segments_raw, scores, blank_id = get_alignments(
            emissions, tokens_starred, tokenizer
        )

        # ----- 5. Convert to word spans -----
        spans = get_spans(tokens_starred, segments_raw, blank_id)

        # ----- 6. Post-process into final word timestamps -----
        word_timestamps = postprocess_results(text_starred, spans, stride, scores)

        # Build response
        segments_out = [
            Segment(start=seg["start"], end=seg["end"], text=seg["text"])
            for seg in word_timestamps
        ]
        return AlignmentResponse(text=text_clean, segments=segments_out)

    except HTTPException:
        # Deliberate HTTP errors must keep their own status code.
        raise
    except Exception as e:
        logger.exception("Alignment failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temporary folder
        shutil.rmtree(tmp_dir, ignore_errors=True)