File size: 14,321 Bytes
7d5f092
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
"""Forced Alignment — WebMAUS (BAS Web Services) + MFA (Montreal Forced Aligner)

Provides phone-level segmentation with precise boundaries, replacing
approximate Wav2Vec2 CTC phoneme timing.

Pipeline priority (default order; ``forced_align(prefer="webmaus")`` swaps 1 and 2):
  1. MFA (local, faster, no network dependency)
  2. WebMAUS (BAS REST API fallback; skipped when SKIP_WEBMAUS is set)
  3. Wav2Vec2 CTC (existing fallback — already in ai_classification.py)
"""

from __future__ import annotations

import json
import logging
import os
import shutil
import subprocess
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

# BAS Web Services endpoint for MAUS
BAS_MAUS_URL = "https://clarin.phonetik.uni-muenchen.de/BASWebServices/services/runMAUS"

# Allow skipping the network fallback entirely (useful for offline / slow links)
SKIP_WEBMAUS = os.getenv("SKIP_WEBMAUS", "0").lower() in {"1", "true", "yes"}
WEBMAUS_TIMEOUT = int(os.getenv("WEBMAUS_TIMEOUT", "15"))


@dataclass
class PhoneSegment:
    """A single phone-level segment from forced alignment.

    All times are in milliseconds. The TextGrid parsers in this module convert
    Praat's seconds by multiplying by 1000 and rounding to 2 decimals.
    """
    phone: str  # phone label exactly as it appears in the aligner's TextGrid
    start_ms: float  # segment start, milliseconds
    end_ms: float  # segment end, milliseconds
    duration_ms: float  # end minus start, milliseconds (computed before rounding)
    confidence: float = 1.0  # the TextGrid parsers here never set it, so it stays 1.0
    source: str = "mfa"  # "mfa" | "webmaus" | "wav2vec"


@dataclass
class AlignmentResult:
    """Complete forced alignment output.

    A default-constructed instance means "no alignment": empty lists and
    ``success=False``. The aligner functions in this module fill ``error`` when
    they fail. On success they set ``success=True`` and ``textgrid_path``.
    """
    phones: list[PhoneSegment] = field(default_factory=list)  # phone-tier segments, in TextGrid order
    words: list[dict[str, Any]] = field(default_factory=list)  # word intervals: word/start_ms/end_ms/duration_ms
    source: str = "none"  # aligner that produced this result ("mfa", "webmaus"), or "none"
    textgrid_path: str | None = None  # path of the persisted TextGrid, when one was written
    success: bool = False  # True only when an aligner ran and its output was parsed
    error: str | None = None  # human-readable failure reason, None on success


# ── TextGrid Parser ──────────────────────────────────────────────────────

def _parse_textgrid(tg_path: Path) -> tuple[list[PhoneSegment], list[dict]]:
    """Parse a Praat TextGrid file (long or short text format) into segments.

    The format is detected from content. Only the long format contains
    ``xmin = ...`` assignments; the short format lists bare values one per
    line. Both formats contain the literal ``"IntervalTier"``, so that string
    cannot distinguish them.

    Decoding: a file that starts with a UTF-16 byte-order mark is decoded as
    UTF-16 (Praat can write non-ASCII TextGrids that way). Anything else is
    decoded as UTF-8, with a leading UTF-8 BOM removed. Undecodable bytes are
    replaced rather than raising.

    Args:
        tg_path: Path to the ``.TextGrid`` file.

    Returns:
        ``(phones, words)`` as produced by the long- or short-format parser.
    """
    import re

    raw = tg_path.read_bytes()
    if raw.startswith((b"\xff\xfe", b"\xfe\xff")):
        # The "utf-16" codec reads the BOM to pick endianness and strips it.
        text = raw.decode("utf-16", errors="replace")
    else:
        text = raw.decode("utf-8-sig", errors="replace")

    if re.search(r"^\s*xmin\s*=", text, re.MULTILINE):
        return _parse_textgrid_long(text)
    return _parse_textgrid_short(text)


def _parse_textgrid_long(text: str) -> tuple[list[PhoneSegment], list[dict]]:
    """Parse long-format TextGrid."""
    import re

    phones: list[PhoneSegment] = []
    words: list[dict] = []

    # Split into tiers
    tier_blocks = re.split(r'item\s*\[\d+\]', text)

    for block in tier_blocks:
        is_phone_tier = bool(re.search(r'name\s*=\s*"(phones?|segments?)"', block, re.I))
        is_word_tier = bool(re.search(r'name\s*=\s*"(words?)"', block, re.I))

        if not (is_phone_tier or is_word_tier):
            continue

        intervals = re.findall(
            r'xmin\s*=\s*([\d.]+)\s*xmax\s*=\s*([\d.]+)\s*text\s*=\s*"([^"]*)"',
            block,
        )

        for xmin_s, xmax_s, label in intervals:
            xmin = float(xmin_s)
            xmax = float(xmax_s)
            label = label.strip()
            if not label or label in {"", "sp", "sil", "SIL", "<p:>"}:
                continue

            start_ms = round(xmin * 1000, 2)
            end_ms = round(xmax * 1000, 2)
            dur_ms = round((xmax - xmin) * 1000, 2)

            if is_phone_tier:
                phones.append(PhoneSegment(
                    phone=label, start_ms=start_ms, end_ms=end_ms,
                    duration_ms=dur_ms, source="mfa",
                ))
            elif is_word_tier:
                words.append({
                    "word": label, "start_ms": start_ms,
                    "end_ms": end_ms, "duration_ms": dur_ms,
                })

    return phones, words


def _parse_textgrid_short(text: str) -> tuple[list[PhoneSegment], list[dict]]:
    """Parse short-format TextGrid (fallback)."""
    import re

    phones: list[PhoneSegment] = []
    lines = text.strip().split("\n")

    i = 0
    while i < len(lines):
        line = lines[i].strip().strip('"')
        if line.lower() in ("phones", "phone"):
            # Skip ahead to intervals
            while i < len(lines) and not lines[i].strip().replace('"', '').replace('.', '').replace('-', '').isdigit():
                i += 1
            # Parse intervals: xmin, xmax, label triplets
            while i + 2 < len(lines):
                try:
                    xmin = float(lines[i].strip())
                    xmax = float(lines[i + 1].strip())
                    label = lines[i + 2].strip().strip('"')
                    i += 3
                    if not label or label in {"", "sp", "sil"}:
                        continue
                    phones.append(PhoneSegment(
                        phone=label,
                        start_ms=round(xmin * 1000, 2),
                        end_ms=round(xmax * 1000, 2),
                        duration_ms=round((xmax - xmin) * 1000, 2),
                        source="mfa",
                    ))
                except (ValueError, IndexError):
                    break
        i += 1

    return phones, []


# ── MFA (Montreal Forced Aligner) ────────────────────────────────────────

def _mfa_available() -> bool:
    """Check if MFA is installed and accessible."""
    return shutil.which("mfa") is not None


def _run_mfa(audio_path: Path, transcript: str, language: str = "english") -> AlignmentResult:
    """Run Montreal Forced Aligner on one audio file plus its transcript.

    Steps:

    1. Build a throw-away single-file corpus (``input.wav`` + ``input.txt``)
       in a temp directory. Non-WAV input is first converted to 16 kHz mono
       WAV with ``ffmpeg``.
    2. Run ``mfa align`` with the pretrained dictionary and acoustic model
       chosen by ``_mfa_model_name(language, ...)``.
    3. Parse the long-format TextGrid it writes.
    4. Copy that TextGrid next to *audio_path* (same stem, ``.TextGrid``
       suffix) so it outlives the temp directory.

    Args:
        audio_path: Audio file to align. A ``str`` path is also accepted.
        transcript: Plain-text transcript of the audio.
        language: Language key passed to ``_mfa_model_name``. Unknown values
            use that function's fallback model.

    Returns:
        ``AlignmentResult`` with ``source="mfa"``. ``success`` is True only
        when MFA ran and the TextGrid contained at least one phone. If the
        TextGrid cannot be copied next to the audio, the alignment is still
        returned as successful, with ``textgrid_path=None``. Failures are
        reported through ``error``; this function does not raise.
    """
    if not _mfa_available():
        return AlignmentResult(source="mfa", error="MFA not installed")

    audio_path = Path(audio_path)
    tmpdir = Path(tempfile.mkdtemp(prefix="mfa_"))
    try:
        # Corpus and output are sibling directories, so MFA's corpus scan
        # never sees its own output directory.
        corpus_dir = tmpdir / "corpus"
        out_dir = tmpdir / "output"
        corpus_dir.mkdir()
        out_dir.mkdir()

        # MFA pairs audio and transcript by file stem.
        stem = "input"
        wav_dest = corpus_dir / f"{stem}.wav"
        txt_dest = corpus_dir / f"{stem}.txt"

        # Copy WAV as-is; convert anything else to WAV 16kHz mono.
        if audio_path.suffix.lower() == ".wav":
            shutil.copy2(audio_path, wav_dest)
        else:
            if shutil.which("ffmpeg") is None:
                return AlignmentResult(
                    source="mfa", error="ffmpeg not installed; cannot convert audio for MFA",
                )
            proc = subprocess.run(
                ["ffmpeg", "-y", "-i", str(audio_path), "-ar", "16000", "-ac", "1", str(wav_dest)],
                capture_output=True, timeout=60,
            )
            if proc.returncode != 0:
                logger.warning(
                    "ffmpeg conversion failed: %s",
                    proc.stderr[-500:].decode("utf-8", errors="replace"),
                )
                return AlignmentResult(source="mfa", error="Audio conversion failed for MFA")

        txt_dest.write_text(transcript.strip(), encoding="utf-8")

        dict_name = _mfa_model_name(language, "dictionary")
        acoustic_name = _mfa_model_name(language, "acoustic")

        cmd = [
            "mfa", "align",
            str(corpus_dir),
            dict_name,
            acoustic_name,
            str(out_dir),
            "--clean",
            "--single_speaker",
            "--output_format", "long_textgrid",
        ]

        logger.info("Running MFA: %s", " ".join(cmd))
        proc = subprocess.run(cmd, capture_output=True, timeout=300, text=True)

        if proc.returncode != 0:
            # Log the tail: the final error line is usually at the end of the output.
            logger.warning("MFA failed: %s", (proc.stderr or proc.stdout or "")[-500:])
            return AlignmentResult(source="mfa", error=f"MFA exit code {proc.returncode}")

        tg_files = sorted(out_dir.rglob("*.TextGrid"))
        if not tg_files:
            return AlignmentResult(source="mfa", error="MFA produced no TextGrid output")

        tg_path = tg_files[0]
        phones, words = _parse_textgrid(tg_path)
        if not phones:
            # Report failure so the caller can fall back instead of using an empty alignment.
            return AlignmentResult(
                source="mfa", words=words, error="MFA TextGrid contained no phone segments",
            )

        # Persist the TextGrid next to the audio. If that fails, keep the alignment anyway.
        persistent_tg = audio_path.with_suffix(".TextGrid")
        textgrid_path: str | None = str(persistent_tg)
        try:
            shutil.copy2(tg_path, persistent_tg)
        except OSError:
            logger.warning("Could not persist MFA TextGrid to %s", persistent_tg, exc_info=True)
            textgrid_path = None

        return AlignmentResult(
            phones=phones,
            words=words,
            source="mfa",
            textgrid_path=textgrid_path,
            success=True,
        )

    except subprocess.TimeoutExpired as exc:
        if exc.cmd and exc.cmd[0] == "ffmpeg":
            return AlignmentResult(
                source="mfa", error=f"Audio conversion for MFA timed out ({exc.timeout}s)",
            )
        return AlignmentResult(source="mfa", error="MFA timed out (300s)")
    except Exception as exc:
        logger.exception("MFA alignment failed")
        return AlignmentResult(source="mfa", error=str(exc))
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)


def _mfa_model_name(language: str, model_type: str) -> str:
    """Map language code to MFA model name."""
    mapping = {
        "en": ("english_mfa", "english_mfa"),
        "hi": ("hindi_cv", "hindi_cv"),
        "bn": ("bengali_cv", "bengali_cv"),
        "or": ("odia_cv", "odia_cv"),
    }
    pair = mapping.get(language, ("english_mfa", "english_mfa"))
    return pair[0] if model_type == "dictionary" else pair[1]


# ── WebMAUS (BAS Web Services) ───────────────────────────────────────────

def _run_webmaus(audio_path: Path, transcript: str, language: str = "eng-US") -> AlignmentResult:
    """Run BAS WebMAUS for phonetic segmentation via its REST API.

    Steps:

    1. Upload the audio and transcript to ``BAS_MAUS_URL``.
    2. Read the XML reply for a download link.
    3. Fetch the resulting TextGrid and save it next to *audio_path* as
       ``<stem>.WebMAUS.TextGrid``.
    4. Parse the saved TextGrid.

    Each HTTP request uses a ``WEBMAUS_TIMEOUT``-second timeout. Nothing is
    sent when ``SKIP_WEBMAUS`` is set.

    Args:
        audio_path: Audio file to upload (sent as ``audio/wav``). A ``str``
            path is also accepted.
        transcript: Plain-text transcript of the audio.
        language: ISO 639-1 or 639-3 code (``en``, ``hin``, ...) mapped to a
            MAUS language code. Unmapped values, including the default, are
            sent as ``eng-US``.

    Returns:
        ``AlignmentResult`` with ``source="webmaus"``. ``success`` is True only
        when the TextGrid yielded at least one phone. Failures are reported
        through ``error``; this function does not raise.
    """
    if SKIP_WEBMAUS:
        return AlignmentResult(source="webmaus", error="SKIP_WEBMAUS=1")

    try:
        import requests
    except ImportError:
        return AlignmentResult(source="webmaus", error="requests library not installed")

    import xml.etree.ElementTree as ET

    # Map language codes to BAS MAUS language codes
    lang_map = {
        "en": "eng-US", "hi": "hin", "bn": "ben", "or": "ori",
        "eng": "eng-US", "hin": "hin", "ben": "ben", "ori": "ori",
    }
    maus_lang = lang_map.get(language, "eng-US")
    audio_path = Path(audio_path)

    try:
        with open(audio_path, "rb") as af:
            files = {"SIGNAL": (audio_path.name, af, "audio/wav")}
            data = {
                "TEXT": transcript,
                "LANGUAGE": maus_lang,
                "OUTFORMAT": "TextGrid",
                "MODUS": "standard",
                "INSKANTEXTGRID": "true",
                "INSORTTEXTGRID": "true",
            }

            logger.info("Calling WebMAUS API for language=%s (timeout=%ds)", maus_lang, WEBMAUS_TIMEOUT)
            resp = requests.post(BAS_MAUS_URL, files=files, data=data, timeout=WEBMAUS_TIMEOUT)

        if resp.status_code != 200:
            return AlignmentResult(source="webmaus", error=f"WebMAUS HTTP {resp.status_code}")

        # BAS replies with a small XML document. Parse the raw bytes so the
        # document's own encoding declaration is honoured.
        root = ET.fromstring(resp.content)

        success_el = root.find(".//success")
        if success_el is None or (success_el.text or "").strip() != "true":
            err_msg = (
                (root.findtext(".//message") or "").strip()
                or (root.findtext(".//output") or "").strip()
                or "Unknown WebMAUS error"
            )
            return AlignmentResult(source="webmaus", error=err_msg)

        download_url = (root.findtext(".//downloadLink") or "").strip()
        if not download_url:
            return AlignmentResult(source="webmaus", error="No download link in response")

        tg_resp = requests.get(download_url, timeout=WEBMAUS_TIMEOUT)
        if tg_resp.status_code != 200:
            return AlignmentResult(
                source="webmaus",
                error=f"Failed to download TextGrid (HTTP {tg_resp.status_code})",
            )

        # Save the raw bytes. For text/* replies without a charset,
        # ``Response.text`` decodes as ISO-8859-1, which corrupts non-ASCII
        # (e.g. IPA) labels.
        tg_path = audio_path.with_suffix(".WebMAUS.TextGrid")
        tg_path.write_bytes(tg_resp.content)

        phones, words = _parse_textgrid(tg_path)
        if not phones:
            # Report failure so the caller can fall back instead of using an empty alignment.
            return AlignmentResult(
                words=words,
                source="webmaus",
                textgrid_path=str(tg_path),
                error="WebMAUS TextGrid contained no phone segments",
            )

        # The TextGrid parsers tag segments as "mfa"; re-tag them with this aligner.
        for p in phones:
            p.source = "webmaus"

        return AlignmentResult(
            phones=phones,
            words=words,
            source="webmaus",
            textgrid_path=str(tg_path),
            success=True,
        )

    except requests.Timeout:
        return AlignmentResult(
            source="webmaus", error=f"WebMAUS request timed out ({WEBMAUS_TIMEOUT}s)",
        )
    except Exception as exc:
        logger.exception("WebMAUS alignment failed")
        return AlignmentResult(source="webmaus", error=str(exc))


# ── Public API ───────────────────────────────────────────────────────────

def forced_align(
    audio_path: Path,
    transcript: str,
    language: str = "en",
    prefer: str = "mfa",
) -> AlignmentResult:
    """Run forced alignment with a fallback chain: MFA -> WebMAUS.

    Args:
        audio_path: Path to the audio file (16kHz mono WAV recommended). A
            plain ``str`` path is also accepted.
        transcript: Plain text transcript of the audio.
        language: ISO 639-1 language code (en, hi, bn, or).
        prefer: Aligner to try first, ``"mfa"`` (default) or ``"webmaus"``.
            Any other value behaves like ``"mfa"``.

    Returns:
        The first successful ``AlignmentResult``. If every aligner fails, the
        last attempted aligner's result is returned, with ``error`` listing
        every aligner's failure reason (``"mfa: ...; webmaus: ..."``).
    """
    if not transcript or not transcript.strip():
        return AlignmentResult(error="No transcript provided for alignment")

    audio_path = Path(audio_path)

    # (error-prefix key, display name, runner) in the order they are tried.
    chain = [("mfa", "MFA", _run_mfa), ("webmaus", "WebMAUS", _run_webmaus)]
    if prefer == "webmaus":
        chain.reverse()

    failures: list[str] = []
    for position, (key, display, run) in enumerate(chain):
        result = run(audio_path, transcript, language)
        if result.success:
            logger.info("%s alignment succeeded: %d phones", display, len(result.phones))
            return result
        failures.append(f"{key}: {result.error}")
        if position + 1 < len(chain):
            logger.warning(
                "%s failed (%s), falling back to %s", display, result.error, chain[position + 1][1],
            )

    # Every aligner failed: return the last result but keep every reason.
    result.error = "; ".join(failures)
    logger.warning("All forced alignment methods failed: %s", result.error)
    return result


def alignment_to_phoneme_spans(alignment: AlignmentResult) -> list[dict[str, Any]]:
    """Flatten an ``AlignmentResult`` into the pipeline's phoneme-span dicts.

    Each phone becomes one dict with ``phoneme``, ``start_ms``, ``end_ms``,
    ``duration_ms``, ``confidence`` and ``source`` keys. Order matches
    ``alignment.phones``. The shape mirrors the Wav2Vec2 CTC output in
    ai_classification.py, so downstream modules (phoneme_analysis,
    connected_speech, etc.) can consume it unchanged.
    """
    return [
        {
            "phoneme": segment.phone,
            "start_ms": segment.start_ms,
            "end_ms": segment.end_ms,
            "duration_ms": segment.duration_ms,
            "confidence": segment.confidence,
            "source": segment.source,
        }
        for segment in alignment.phones
    ]