# vaani-cavp-engine / modules / forced_alignment.py
# Shaankar39's picture
# init: Vaani CAVP engine (CPU, accuracy-first — Whisper large-v3, spaCy trf)
# 7d5f092
"""Forced Alignment — WebMAUS (BAS Web Services) + MFA (Montreal Forced Aligner)

Provides phone-level segmentation with precise boundaries, replacing
approximate Wav2Vec2 CTC phoneme timing.

Pipeline priority:
1. MFA (local, faster, no network dependency)
2. WebMAUS (BAS REST API fallback)
3. Wav2Vec2 CTC (existing fallback — already in ai_classification.py)
"""
from __future__ import annotations
import json
import logging
import os
import shutil
import subprocess
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)

# REST endpoint of the BAS Web Services "runMAUS" operation.
BAS_MAUS_URL = "https://clarin.phonetik.uni-muenchen.de/BASWebServices/services/runMAUS"

# Set SKIP_WEBMAUS to 1 / true / yes (any letter case) to disable the network
# fallback entirely, e.g. when running offline or on a slow link.
SKIP_WEBMAUS = os.environ.get("SKIP_WEBMAUS", "0").lower() in ("1", "true", "yes")

# Timeout applied to each WebMAUS HTTP request; read from the environment,
# defaulting to 15.
WEBMAUS_TIMEOUT = int(os.environ.get("WEBMAUS_TIMEOUT", "15"))
@dataclass
class PhoneSegment:
    """A single phone-level segment from forced alignment.

    All times are in milliseconds. The TextGrid parsers in this module build
    instances from TextGrid interval boundaries (seconds multiplied by 1000
    and rounded to two decimals).
    """

    # Phone label as read from the aligner's tier.
    phone: str
    # Start of the interval, in milliseconds.
    start_ms: float
    # End of the interval, in milliseconds.
    end_ms: float
    # Length of the interval, in milliseconds.
    duration_ms: float
    # Confidence attached to the segment; 1.0 unless the producer sets it.
    confidence: float = 1.0
    # Which aligner produced the segment.
    source: str = "mfa"  # "mfa" | "webmaus" | "wav2vec"
@dataclass
class AlignmentResult:
    """Complete forced alignment output.

    Returned by every aligner wrapper in this module, both on success and on
    failure; check ``success`` before using ``phones`` / ``words``.
    """

    # Phone-level segments (empty when alignment failed).
    phones: list[PhoneSegment] = field(default_factory=list)
    # Word-level intervals as dicts with the keys
    # "word", "start_ms", "end_ms" and "duration_ms".
    words: list[dict[str, Any]] = field(default_factory=list)
    # Aligner that produced this result; "none" when no aligner was run.
    source: str = "none"
    # Location of the TextGrid saved on disk, when one was written.
    textgrid_path: str | None = None
    # True only when the aligner completed and its output was parsed.
    success: bool = False
    # Human-readable failure reason; None when nothing went wrong.
    error: str | None = None
# ── TextGrid Parser ──────────────────────────────────────────────────────
def _parse_textgrid(tg_path: Path) -> tuple[list[PhoneSegment], list[dict]]:
    """Parse a Praat TextGrid file (long or short format) into phone segments.

    Args:
        tg_path: Location of the TextGrid file on disk.

    Returns:
        A ``(phones, words)`` tuple as produced by the format-specific parser.
    """
    import re

    raw = tg_path.read_bytes()
    # Praat may write UTF-16 (with a BOM) when labels are non-ASCII; everything
    # else is treated as UTF-8, tolerating a BOM and undecodable bytes.
    if raw.startswith((b"\xff\xfe", b"\xfe\xff")):
        text = raw.decode("utf-16", errors="replace")
    else:
        text = raw.decode("utf-8-sig", errors="replace")

    # Both formats contain the "IntervalTier" class string, so it cannot tell
    # them apart. Only the long format spells out keys such as `xmin = ...`.
    if re.search(r"\bxmin\s*=", text):
        return _parse_textgrid_long(text)
    return _parse_textgrid_short(text)
def _parse_textgrid_long(text: str) -> tuple[list[PhoneSegment], list[dict]]:
    """Parse long-format TextGrid.

    Tiers are recognised by name (case-insensitive): ``phone(s)``,
    ``segment(s)`` and the MAUS ``MAU`` tier are phone tiers; ``word(s)`` and
    the MAUS ``ORT`` / ``ORT-MAU`` tiers are word tiers. Every other tier is
    ignored, as are empty and silence-labelled intervals.

    Args:
        text: Full contents of a long-format TextGrid.

    Returns:
        ``(phones, words)``: phones as ``PhoneSegment`` objects (tagged with
        source "mfa"; callers re-tag when needed) and words as dicts with the
        keys "word", "start_ms", "end_ms" and "duration_ms". Times are
        milliseconds rounded to two decimals.
    """
    import re

    phones: list[PhoneSegment] = []
    words: list[dict] = []

    phone_tier = re.compile(r'name\s*=\s*"(?:phones?|segments?|MAU)"', re.I)
    word_tier = re.compile(r'name\s*=\s*"(?:words?|ORT(?:-MAU)?)"', re.I)
    number = r'(-?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
    # Praat escapes a double quote inside a label by doubling it (""), so the
    # label group accepts `""` pairs instead of stopping at the first quote.
    interval = re.compile(
        r'xmin\s*=\s*' + number
        + r'\s*xmax\s*=\s*' + number
        + r'\s*text\s*=\s*"((?:[^"]|"")*)"'
    )
    silence_labels = {"sp", "sil", "SIL", "<p:>"}

    # Each chunk after an `item [n]` marker holds exactly one tier.
    for block in re.split(r'item\s*\[\d+\]', text):
        is_phone_tier = phone_tier.search(block) is not None
        is_word_tier = word_tier.search(block) is not None
        if not (is_phone_tier or is_word_tier):
            continue

        # The tier's own `xmin`/`xmax` pair is not followed by `text =`, so
        # only real intervals match.
        for xmin_s, xmax_s, raw_label in interval.findall(block):
            label = raw_label.replace('""', '"').strip()
            if not label or label in silence_labels:
                continue
            xmin = float(xmin_s)
            xmax = float(xmax_s)
            start_ms = round(xmin * 1000, 2)
            end_ms = round(xmax * 1000, 2)
            dur_ms = round((xmax - xmin) * 1000, 2)
            if is_phone_tier:
                phones.append(PhoneSegment(
                    phone=label, start_ms=start_ms, end_ms=end_ms,
                    duration_ms=dur_ms, source="mfa",
                ))
            else:
                words.append({
                    "word": label, "start_ms": start_ms,
                    "end_ms": end_ms, "duration_ms": dur_ms,
                })
    return phones, words
def _parse_textgrid_short(text: str) -> tuple[list[PhoneSegment], list[dict]]:
    """Parse short-format TextGrid (fallback).

    In the short format each interval tier is a run of bare values: the class
    line ``"IntervalTier"``, the quoted tier name, the tier's xmin and xmax,
    the interval count, and then ``xmin`` / ``xmax`` / ``"label"`` triplets.
    Tier names are matched the same way as in the long-format parser.

    Args:
        text: Full contents of a short-format TextGrid.

    Returns:
        ``(phones, words)`` with the same shapes as the long-format parser.
    """
    import re

    phones: list[PhoneSegment] = []
    words: list[dict] = []

    phone_name = re.compile(r"phones?|segments?|MAU", re.I)
    word_name = re.compile(r"words?|ORT(?:-MAU)?", re.I)
    silence_labels = {"sp", "sil", "SIL", "<p:>"}

    # Blank lines carry no data here: an empty label is written as `""`.
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    total = len(lines)

    i = 0
    while i < total:
        # A tier starts at its class line; name, xmin, xmax and count follow.
        if lines[i].strip('"') != "IntervalTier" or i + 4 >= total:
            i += 1
            continue
        name = lines[i + 1].strip('"').strip()
        is_phone_tier = phone_name.fullmatch(name) is not None
        is_word_tier = word_name.fullmatch(name) is not None
        if not (is_phone_tier or is_word_tier):
            i += 1
            continue
        try:
            count = int(lines[i + 4])
        except ValueError:
            i += 1
            continue

        # Skip the tier header (xmin, xmax, count) so that it is never
        # mistaken for an interval, then read exactly `count` triplets.
        cursor = i + 5
        for _ in range(count):
            if cursor + 2 >= total:
                break
            try:
                xmin = float(lines[cursor])
                xmax = float(lines[cursor + 1])
            except ValueError:
                break
            raw_label = lines[cursor + 2]
            cursor += 3
            if len(raw_label) >= 2 and raw_label[0] == '"' and raw_label[-1] == '"':
                # Praat doubles any quote that appears inside a label.
                label = raw_label[1:-1].replace('""', '"').strip()
            else:
                label = raw_label.strip('"').strip()
            if not label or label in silence_labels:
                continue
            start_ms = round(xmin * 1000, 2)
            end_ms = round(xmax * 1000, 2)
            dur_ms = round((xmax - xmin) * 1000, 2)
            if is_phone_tier:
                phones.append(PhoneSegment(
                    phone=label, start_ms=start_ms, end_ms=end_ms,
                    duration_ms=dur_ms, source="mfa",
                ))
            else:
                words.append({
                    "word": label, "start_ms": start_ms,
                    "end_ms": end_ms, "duration_ms": dur_ms,
                })
        i = cursor
    return phones, words
# ── MFA (Montreal Forced Aligner) ────────────────────────────────────────
def _mfa_available() -> bool:
    """Check if MFA is installed and accessible.

    Returns:
        True when an executable named ``mfa`` can be found on the search
        path, False otherwise.
    """
    return shutil.which("mfa") is not None
def _run_mfa(audio_path: Path, transcript: str, language: str = "english") -> AlignmentResult:
    """Run Montreal Forced Aligner on audio + transcript.

    Builds a throw-away one-file corpus, runs ``mfa align`` on it, parses the
    resulting TextGrid and copies it next to the audio file (same name with a
    ``.TextGrid`` suffix) so it outlives the temporary directory.

    Args:
        audio_path: Audio file to align. ``.wav`` files are copied as-is;
            anything else is converted with ffmpeg to 16 kHz mono WAV.
        transcript: Text spoken in the audio.
        language: Language key handed to ``_mfa_model_name`` to pick the
            dictionary and acoustic model.

    Returns:
        AlignmentResult with ``source="mfa"``. ``success`` is True only when
        MFA ran and at least one phone segment was parsed; otherwise ``error``
        carries the reason. This function does not raise for tool failures.
    """
    if not _mfa_available():
        return AlignmentResult(source="mfa", error="MFA not installed")

    convert_timeout_s = 60
    align_timeout_s = 300
    audio_path = Path(audio_path)
    tmpdir = Path(tempfile.mkdtemp(prefix="mfa_"))
    try:
        # MFA expects a directory with matched .wav + .txt files. Keep the
        # output directory outside the corpus so MFA never scans it as input.
        corpus_dir = tmpdir / "corpus"
        out_dir = tmpdir / "output"
        corpus_dir.mkdir()
        out_dir.mkdir()
        stem = "input"
        wav_dest = corpus_dir / f"{stem}.wav"
        txt_dest = corpus_dir / f"{stem}.txt"

        # Copy/convert audio to WAV 16kHz mono
        if audio_path.suffix.lower() == ".wav":
            shutil.copy2(audio_path, wav_dest)
        else:
            try:
                proc = subprocess.run(
                    ["ffmpeg", "-y", "-i", str(audio_path), "-ar", "16000", "-ac", "1", str(wav_dest)],
                    capture_output=True, timeout=convert_timeout_s,
                )
            except subprocess.TimeoutExpired:
                # Reported separately so it is not mistaken for an MFA timeout.
                return AlignmentResult(
                    source="mfa",
                    error=f"Audio conversion timed out ({convert_timeout_s}s)",
                )
            if proc.returncode != 0:
                return AlignmentResult(source="mfa", error="Audio conversion failed for MFA")

        # Write transcript
        txt_dest.write_text(transcript.strip(), encoding="utf-8")

        # Map language to MFA dictionary/acoustic model names
        dict_name = _mfa_model_name(language, "dictionary")
        acoustic_name = _mfa_model_name(language, "acoustic")

        # Run MFA align
        cmd = [
            "mfa", "align",
            str(corpus_dir),
            dict_name,
            acoustic_name,
            str(out_dir),
            "--clean",
            "--single_speaker",
            "--output_format", "long_textgrid",
        ]
        logger.info("Running MFA: %s", " ".join(cmd))
        try:
            # errors="replace" keeps undecodable tool output from raising.
            proc = subprocess.run(
                cmd, capture_output=True, timeout=align_timeout_s,
                text=True, errors="replace",
            )
        except subprocess.TimeoutExpired:
            return AlignmentResult(source="mfa", error=f"MFA timed out ({align_timeout_s}s)")
        if proc.returncode != 0:
            logger.warning("MFA failed: %s", proc.stderr[:500])
            return AlignmentResult(source="mfa", error=f"MFA exit code {proc.returncode}")

        # Find output TextGrid
        tg_files = list(out_dir.rglob("*.TextGrid"))
        if not tg_files:
            return AlignmentResult(source="mfa", error="MFA produced no TextGrid output")
        tg_path = tg_files[0]
        phones, words = _parse_textgrid(tg_path)

        # Copy TextGrid to uploads for persistence
        persistent_tg = audio_path.with_suffix(".TextGrid")
        shutil.copy2(tg_path, persistent_tg)

        if not phones:
            # An alignment without phones is useless downstream; report it as
            # a failure so the caller's fallback chain can continue.
            return AlignmentResult(
                words=words,
                source="mfa",
                textgrid_path=str(persistent_tg),
                error="MFA TextGrid contained no phone segments",
            )
        return AlignmentResult(
            phones=phones,
            words=words,
            source="mfa",
            textgrid_path=str(persistent_tg),
            success=True,
        )
    except Exception as exc:
        # Boundary handler: aligner problems must not crash the pipeline.
        logger.exception("MFA alignment failed")
        return AlignmentResult(source="mfa", error=str(exc))
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
def _mfa_model_name(language: str, model_type: str) -> str:
    """Map language code to MFA model name.

    The lookup is case-insensitive and ignores any region suffix, so "HI",
    "hi-IN" and "hi_IN" all resolve like "hi". The three-letter codes accepted
    by the WebMAUS language map ("eng", "hin", "ben", "ori") and the word
    "english" are accepted too.

    Args:
        language: Language code; unknown or empty values fall back to English.
        model_type: "dictionary" selects the pronunciation dictionary; any
            other value selects the acoustic model.

    Returns:
        The MFA model name to pass on the ``mfa align`` command line.
    """
    mapping = {
        "en": ("english_mfa", "english_mfa"),
        "hi": ("hindi_cv", "hindi_cv"),
        "bn": ("bengali_cv", "bengali_cv"),
        "or": ("odia_cv", "odia_cv"),
    }
    aliases = {
        "eng": "en", "english": "en",
        "hin": "hi", "ben": "bn", "ori": "or",
    }
    # Keep only the primary subtag: "hi-IN" / "hi_IN" -> "hi".
    code = (language or "").strip().lower().replace("_", "-").split("-", 1)[0]
    code = aliases.get(code, code)
    dictionary, acoustic = mapping.get(code, mapping["en"])
    return dictionary if model_type == "dictionary" else acoustic
# ── WebMAUS (BAS Web Services) ───────────────────────────────────────────
def _run_webmaus(audio_path: Path, transcript: str, language: str = "eng-US") -> AlignmentResult:
    """Run BAS WebMAUS for phonetic segmentation via REST API.

    Uploads the audio and transcript to ``BAS_MAUS_URL``, downloads the
    TextGrid the service links to, stores it next to the audio file (suffix
    ``.WebMAUS.TextGrid``) and parses it.

    Args:
        audio_path: Audio file to upload.
        transcript: Text spoken in the audio.
        language: Language code; mapped to a BAS MAUS language code, with
            "eng-US" used for anything unrecognised.

    Returns:
        AlignmentResult with ``source="webmaus"``. ``success`` is True only
        when the service succeeded and at least one phone segment was parsed;
        otherwise ``error`` carries the reason. Returns immediately with an
        error when ``SKIP_WEBMAUS`` is set or ``requests`` is not installed.
    """
    if SKIP_WEBMAUS:
        return AlignmentResult(source="webmaus", error="SKIP_WEBMAUS=1")
    try:
        import requests
    except ImportError:
        return AlignmentResult(source="webmaus", error="requests library not installed")
    # NOTE(review): the stdlib XML parser is not hardened against malicious
    # documents; acceptable only while the endpoint is a trusted HTTPS service.
    import xml.etree.ElementTree as ET

    audio_path = Path(audio_path)

    # Map language codes to BAS MAUS language codes
    lang_map = {
        "en": "eng-US", "hi": "hin", "bn": "ben", "or": "ori",
        "eng": "eng-US", "hin": "hin", "ben": "ben", "ori": "ori",
    }
    # Case-insensitive lookup that also tolerates region suffixes ("hi-IN").
    key = (language or "").strip().lower().replace("_", "-")
    maus_lang = lang_map.get(key) or lang_map.get(key.split("-", 1)[0], "eng-US")

    try:
        with open(audio_path, "rb") as af:
            files = {"SIGNAL": (audio_path.name, af, "audio/wav")}
            data = {
                "TEXT": transcript,
                "LANGUAGE": maus_lang,
                "OUTFORMAT": "TextGrid",
                "MODUS": "standard",
                "INSKANTEXTGRID": "true",
                "INSORTTEXTGRID": "true",
            }
            logger.info("Calling WebMAUS API for language=%s (timeout=%ds)", maus_lang, WEBMAUS_TIMEOUT)
            resp = requests.post(BAS_MAUS_URL, files=files, data=data, timeout=WEBMAUS_TIMEOUT)
        if resp.status_code != 200:
            return AlignmentResult(source="webmaus", error=f"WebMAUS HTTP {resp.status_code}")

        # BAS returns XML with download link
        root = ET.fromstring(resp.text)
        if root.findtext(".//success") != "true":
            # Try the known text-carrying elements before giving up.
            err_msg = (
                root.findtext(".//message")
                or root.findtext(".//output")
                or "Unknown WebMAUS error"
            )
            return AlignmentResult(source="webmaus", error=err_msg.strip())
        download_url = root.findtext(".//downloadLink")
        if not download_url:
            return AlignmentResult(source="webmaus", error="No download link in response")

        # Download the TextGrid
        tg_resp = requests.get(download_url, timeout=WEBMAUS_TIMEOUT)
        if tg_resp.status_code != 200:
            return AlignmentResult(source="webmaus", error="Failed to download TextGrid")

        # Save the raw bytes: going through `tg_resp.text` would re-decode with
        # whatever charset requests guesses and can corrupt non-ASCII labels.
        tg_path = audio_path.with_suffix(".WebMAUS.TextGrid")
        tg_path.write_bytes(tg_resp.content)
        phones, words = _parse_textgrid(tg_path)

        # Re-tag source
        for p in phones:
            p.source = "webmaus"

        if not phones:
            # An alignment without phones is useless downstream; report it as
            # a failure instead of an empty "success".
            return AlignmentResult(
                words=words,
                source="webmaus",
                textgrid_path=str(tg_path),
                error="WebMAUS TextGrid contained no phone segments",
            )
        return AlignmentResult(
            phones=phones,
            words=words,
            source="webmaus",
            textgrid_path=str(tg_path),
            success=True,
        )
    except requests.Timeout:
        return AlignmentResult(
            source="webmaus",
            error=f"WebMAUS request timed out ({WEBMAUS_TIMEOUT}s)",
        )
    except Exception as exc:
        # Boundary handler: network/parse problems must not crash the pipeline.
        logger.exception("WebMAUS alignment failed")
        return AlignmentResult(source="webmaus", error=str(exc))
# ── Public API ───────────────────────────────────────────────────────────
def forced_align(
    audio_path: Path,
    transcript: str,
    language: str = "en",
    prefer: str = "mfa",
) -> AlignmentResult:
    """Run forced alignment with fallback chain: MFA → WebMAUS.

    The preferred aligner runs first; the other one is tried only when the
    first does not succeed.

    Args:
        audio_path: Path to WAV audio file (16kHz mono recommended).
        transcript: Plain text transcript of the audio.
        language: ISO 639-1 language code (en, hi, bn, or).
        prefer: Preferred aligner ("mfa" or "webmaus"). Any value other than
            "webmaus" runs MFA first.

    Returns:
        AlignmentResult with phone-level segments and word boundaries. When
        every aligner fails, the last aligner's result is returned with
        ``error`` listing each aligner's failure reason.
    """
    if not transcript or not transcript.strip():
        return AlignmentResult(error="No transcript provided for alignment")

    audio_path = Path(audio_path)

    # Ordered fallback chain; the preferred aligner goes first.
    chain = [("MFA", _run_mfa), ("WebMAUS", _run_webmaus)]
    if prefer == "webmaus":
        chain.reverse()

    failures: list[str] = []
    result = AlignmentResult(error="No aligner was run")
    for position, (label, run_aligner) in enumerate(chain):
        result = run_aligner(audio_path, transcript, language)
        if result.success:
            logger.info("%s alignment succeeded: %d phones", label, len(result.phones))
            return result
        failures.append(f"{label}: {result.error}")
        if position + 1 < len(chain):
            logger.warning(
                "%s failed (%s), falling back to %s",
                label, result.error, chain[position + 1][0],
            )

    # Keep every failure reason, not only the last one.
    result.error = "; ".join(failures)
    logger.warning("All forced alignment methods failed: %s", result.error)
    return result
def alignment_to_phoneme_spans(alignment: AlignmentResult) -> list[dict[str, Any]]:
    """Convert AlignmentResult to the phoneme span format used by the rest of the pipeline.

    This produces the same structure as Wav2Vec2 CTC output in ai_classification.py,
    so downstream modules (phoneme_analysis, connected_speech, etc.) work unchanged.

    Args:
        alignment: Alignment whose ``phones`` are converted.

    Returns:
        One dict per phone, in order, with the keys "phoneme", "start_ms",
        "end_ms", "duration_ms", "confidence" and "source". Empty when the
        alignment has no phones.
    """
    return [
        {
            "phoneme": seg.phone,
            "start_ms": seg.start_ms,
            "end_ms": seg.end_ms,
            "duration_ms": seg.duration_ms,
            "confidence": seg.confidence,
            "source": seg.source,
        }
        for seg in alignment.phones
    ]