Spaces:
Build error
Build error
File size: 14,321 Bytes
"""Forced Alignment — WebMAUS (BAS Web Services) + MFA (Montreal Forced Aligner)

Provides phone-level segmentation with precise boundaries, replacing
approximate Wav2Vec2 CTC phoneme timing.

Pipeline priority:
  1. MFA (local, faster, no network dependency)
  2. WebMAUS (BAS REST API fallback)
  3. Wav2Vec2 CTC (existing fallback — already in ai_classification.py)
"""
from __future__ import annotations
import json
import logging
import os
import shutil
import subprocess
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)

# BAS Web Services endpoint for MAUS
BAS_MAUS_URL = "https://clarin.phonetik.uni-muenchen.de/BASWebServices/services/runMAUS"

# Allow skipping the network fallback entirely (useful for offline / slow links)
SKIP_WEBMAUS = os.getenv("SKIP_WEBMAUS", "0").lower() in {"1", "true", "yes"}

# Timeout, in seconds, handed to each WebMAUS HTTP request. A malformed or
# non-positive WEBMAUS_TIMEOUT must not make this module un-importable, so
# such values are replaced by the default instead of raising at import time.
_DEFAULT_WEBMAUS_TIMEOUT = 15
try:
    WEBMAUS_TIMEOUT = int(os.getenv("WEBMAUS_TIMEOUT", str(_DEFAULT_WEBMAUS_TIMEOUT)))
except ValueError:
    logger.warning(
        "Invalid WEBMAUS_TIMEOUT=%r; using default of %ds",
        os.getenv("WEBMAUS_TIMEOUT"),
        _DEFAULT_WEBMAUS_TIMEOUT,
    )
    WEBMAUS_TIMEOUT = _DEFAULT_WEBMAUS_TIMEOUT
if WEBMAUS_TIMEOUT <= 0:
    logger.warning(
        "WEBMAUS_TIMEOUT must be positive (got %d); using default of %ds",
        WEBMAUS_TIMEOUT,
        _DEFAULT_WEBMAUS_TIMEOUT,
    )
    WEBMAUS_TIMEOUT = _DEFAULT_WEBMAUS_TIMEOUT
@dataclass
class PhoneSegment:
    """One labelled phone interval produced by forced alignment.

    All ``*_ms`` fields are expressed in milliseconds.
    """

    # Phone label exactly as it appears in the aligner output.
    phone: str
    # Interval boundaries and length.
    start_ms: float
    end_ms: float
    duration_ms: float
    # Defaults to 1.0 when the producer supplies no score.
    confidence: float = 1.0
    # Which aligner produced the segment: "mfa" | "webmaus" | "wav2vec"
    source: str = "mfa"
@dataclass
class AlignmentResult:
    """Outcome of one forced-alignment attempt.

    A default-constructed instance represents "nothing aligned": empty
    segment lists, ``source="none"`` and ``success=False``.
    """

    # Phone-level segments; each instance gets its own fresh list.
    phones: list[PhoneSegment] = field(default_factory=list)
    # Word-level entries as plain dictionaries.
    words: list[dict[str, Any]] = field(default_factory=list)
    # Name of the aligner that produced (or attempted) this result.
    source: str = "none"
    # Location of the saved TextGrid, when one was written.
    textgrid_path: str | None = None
    # True only when the aligner completed and its output was parsed.
    success: bool = False
    # Human-readable failure reason, if any.
    error: str | None = None
# ── TextGrid Parser ──────────────────────────────────────────────────────
def _parse_textgrid(tg_path: Path) -> tuple[list[PhoneSegment], list[dict]]:
    """Parse a Praat TextGrid file (long or short format) into phone segments.

    Args:
        tg_path: Location of the TextGrid file. It is read as UTF-8 and
            undecodable bytes are replaced rather than raising.

    Returns:
        A ``(phones, words)`` tuple as produced by the format-specific parser.
    """
    import re

    text = tg_path.read_text(encoding="utf-8", errors="replace")

    # The string "IntervalTier" occurs in BOTH TextGrid flavours, so it cannot
    # be used to tell them apart. Only the long format spells values out as
    # `xmin = ...` key/value pairs; the short format lists bare values.
    if re.search(r"\bxmin\s*=", text):
        return _parse_textgrid_long(text)
    return _parse_textgrid_short(text)
def _parse_textgrid_long(text: str) -> tuple[list[PhoneSegment], list[dict]]:
    """Parse a long-format TextGrid into phone segments and word entries.

    Tiers are recognised by name (case-insensitive):

    * phone tiers: ``phone(s)``, ``segment(s)`` and ``MAU`` (the segment tier
      name used in WebMAUS TextGrids);
    * word tiers: ``word(s)`` and ``ORT-MAU`` (the WebMAUS orthography tier).

    Any other tier is ignored. Intervals whose label is empty or a known
    silence/pause marker are dropped.

    Args:
        text: Full contents of the TextGrid file.

    Returns:
        ``(phones, words)`` where ``words`` holds dictionaries with the keys
        ``word``, ``start_ms``, ``end_ms`` and ``duration_ms``. Times are
        converted from seconds to milliseconds and rounded to 2 decimals.
    """
    import re

    skip_labels = {"sp", "sil", "SIL", "<p:>"}
    phone_tier_re = re.compile(r'name\s*=\s*"(phones?|segments?|MAU)"', re.I)
    word_tier_re = re.compile(r'name\s*=\s*"(words?|ORT-MAU)"', re.I)
    interval_re = re.compile(
        r'xmin\s*=\s*([\d.]+)\s*xmax\s*=\s*([\d.]+)\s*text\s*=\s*"([^"]*)"'
    )

    phones: list[PhoneSegment] = []
    words: list[dict] = []

    # Each `item [n]` header starts a new tier; the chunk before the first
    # header is the file preamble and matches neither tier pattern.
    for block in re.split(r'item\s*\[\d+\]', text):
        is_phone_tier = phone_tier_re.search(block) is not None
        is_word_tier = word_tier_re.search(block) is not None
        if not (is_phone_tier or is_word_tier):
            continue

        for xmin_s, xmax_s, raw_label in interval_re.findall(block):
            label = raw_label.strip()
            if not label or label in skip_labels:
                continue

            xmin = float(xmin_s)
            xmax = float(xmax_s)
            start_ms = round(xmin * 1000, 2)
            end_ms = round(xmax * 1000, 2)
            dur_ms = round((xmax - xmin) * 1000, 2)

            if is_phone_tier:
                phones.append(PhoneSegment(
                    phone=label, start_ms=start_ms, end_ms=end_ms,
                    duration_ms=dur_ms, source="mfa",
                ))
            else:
                words.append({
                    "word": label, "start_ms": start_ms,
                    "end_ms": end_ms, "duration_ms": dur_ms,
                })

    return phones, words
def _parse_textgrid_short(text: str) -> tuple[list[PhoneSegment], list[dict]]:
    """Parse a short-format TextGrid (fallback) into phone segments.

    In the short format a tier is laid out one value per line: the tier
    name, then the tier's own ``xmin``, ``xmax`` and interval count, and only
    then the ``xmin`` / ``xmax`` / label triplets. The three header values
    are skipped explicitly so they are not mistaken for the first interval.

    Only tiers named ``phones`` or ``phone`` (case-insensitive) are read.
    Intervals with an empty label or a silence/pause marker are dropped.

    Args:
        text: Full contents of the TextGrid file.

    Returns:
        ``(phones, [])`` — word tiers are not extracted from this format.
    """
    phone_tier_names = {"phones", "phone"}
    # Same silence markers the long-format parser discards.
    skip_labels = {"sp", "sil", "SIL", "<p:>"}

    # Blank lines carry no data here: even an empty label is written as `""`.
    lines = [ln.strip() for ln in text.split("\n") if ln.strip()]
    total = len(lines)

    phones: list[PhoneSegment] = []
    i = 0
    while i < total:
        if lines[i].strip('"').lower() not in phone_tier_names:
            i += 1
            continue

        # Tier header: xmin, xmax, number of intervals. If these are not
        # numeric, the matching line was not really a tier name.
        try:
            float(lines[i + 1])
            float(lines[i + 2])
            remaining = int(lines[i + 3])
        except (ValueError, IndexError):
            i += 1
            continue
        i += 4

        while remaining > 0 and i + 2 < total:
            try:
                xmin = float(lines[i])
                xmax = float(lines[i + 1])
            except ValueError:
                # Malformed or truncated tier: stop reading it.
                break
            label = lines[i + 2].strip('"').strip()
            i += 3
            remaining -= 1
            if not label or label in skip_labels:
                continue
            phones.append(PhoneSegment(
                phone=label,
                start_ms=round(xmin * 1000, 2),
                end_ms=round(xmax * 1000, 2),
                duration_ms=round((xmax - xmin) * 1000, 2),
                source="mfa",
            ))

    return phones, []
# ── MFA (Montreal Forced Aligner) ────────────────────────────────────────
def _mfa_available() -> bool:
    """Report whether an ``mfa`` executable can be found on the PATH."""
    mfa_executable = shutil.which("mfa")
    return mfa_executable is not None
def _run_mfa(audio_path: Path, transcript: str, language: str = "english") -> AlignmentResult:
    """Run Montreal Forced Aligner on audio + transcript.

    A throw-away corpus directory holding ``input.wav`` and ``input.txt`` is
    created, ``mfa align`` is run on it, and the first TextGrid it produces
    is parsed. The TextGrid is also copied next to ``audio_path`` (same name,
    ``.TextGrid`` suffix) so it outlives the temporary directory.

    Args:
        audio_path: Source audio. ``.wav`` files are copied unchanged; any
            other suffix is converted with ffmpeg to 16 kHz mono WAV.
        transcript: Text spoken in the audio.
        language: Language key passed to ``_mfa_model_name``.

    Returns:
        An ``AlignmentResult`` with ``source="mfa"``. This function does not
        raise: every failure is reported through ``error`` with
        ``success=False``.
    """
    if not _mfa_available():
        return AlignmentResult(source="mfa", error="MFA not installed")

    ffmpeg_timeout_s = 60
    mfa_timeout_s = 300

    tmpdir = tempfile.mkdtemp(prefix="mfa_")
    try:
        # MFA expects a directory with matched .wav + .txt files
        stem = "input"
        wav_dest = Path(tmpdir) / f"{stem}.wav"
        txt_dest = Path(tmpdir) / f"{stem}.txt"
        out_dir = Path(tmpdir) / "output"
        out_dir.mkdir()

        # Copy/convert audio to WAV 16kHz mono
        if audio_path.suffix.lower() == ".wav":
            shutil.copy2(audio_path, wav_dest)
        else:
            try:
                proc = subprocess.run(
                    ["ffmpeg", "-y", "-i", str(audio_path), "-ar", "16000", "-ac", "1", str(wav_dest)],
                    capture_output=True, timeout=ffmpeg_timeout_s,
                )
            except subprocess.TimeoutExpired:
                # Handled here so a slow conversion is not misreported as an
                # MFA timeout by the outer handler.
                return AlignmentResult(
                    source="mfa",
                    error=f"Audio conversion timed out ({ffmpeg_timeout_s}s)",
                )
            if proc.returncode != 0:
                logger.warning(
                    "ffmpeg failed: %s",
                    proc.stderr[-500:].decode("utf-8", errors="replace"),
                )
                return AlignmentResult(source="mfa", error="Audio conversion failed for MFA")

        # Write transcript
        txt_dest.write_text(transcript.strip(), encoding="utf-8")

        # Map language to MFA dictionary/acoustic model names
        dict_name = _mfa_model_name(language, "dictionary")
        acoustic_name = _mfa_model_name(language, "acoustic")

        # Run MFA align
        cmd = [
            "mfa", "align",
            str(tmpdir),
            dict_name,
            acoustic_name,
            str(out_dir),
            "--clean",
            "--single_speaker",
            "--output_format", "long_textgrid",
        ]
        logger.info("Running MFA: %s", " ".join(cmd))
        proc = subprocess.run(cmd, capture_output=True, timeout=mfa_timeout_s, text=True)
        if proc.returncode != 0:
            logger.warning("MFA failed: %s", proc.stderr[:500])
            return AlignmentResult(source="mfa", error=f"MFA exit code {proc.returncode}")

        # Find output TextGrid
        tg_files = list(out_dir.rglob("*.TextGrid"))
        if not tg_files:
            return AlignmentResult(source="mfa", error="MFA produced no TextGrid output")

        tg_path = tg_files[0]
        phones, words = _parse_textgrid(tg_path)

        # Copy the TextGrid next to the audio file so it survives cleanup
        persistent_tg = audio_path.with_suffix(".TextGrid")
        shutil.copy2(tg_path, persistent_tg)

        return AlignmentResult(
            phones=phones,
            words=words,
            source="mfa",
            textgrid_path=str(persistent_tg),
            success=True,
        )
    except subprocess.TimeoutExpired:
        return AlignmentResult(source="mfa", error=f"MFA timed out ({mfa_timeout_s}s)")
    except Exception as exc:
        # Best-effort aligner: report the failure so the caller can fall back.
        logger.exception("MFA alignment failed")
        return AlignmentResult(source="mfa", error=str(exc))
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
def _mfa_model_name(language: str, model_type: str) -> str:
    """Map a language code to an MFA model name.

    Two-letter and three-letter codes are both accepted, matching the codes
    the WebMAUS language map understands. Matching ignores case, surrounding
    whitespace and any region suffix, so ``"HI"``, ``"hi-IN"`` and ``"hin"``
    all select the Hindi models.

    Args:
        language: Language code or name, e.g. ``"en"``, ``"eng-US"``.
        model_type: ``"dictionary"`` selects the pronunciation dictionary;
            any other value selects the acoustic model.

    Returns:
        The model name; unknown languages fall back to the English models.
    """
    english = ("english_mfa", "english_mfa")
    hindi = ("hindi_cv", "hindi_cv")
    bengali = ("bengali_cv", "bengali_cv")
    odia = ("odia_cv", "odia_cv")
    # language key -> (dictionary model, acoustic model)
    models = {
        "en": english, "eng": english, "english": english,
        "hi": hindi, "hin": hindi,
        "bn": bengali, "ben": bengali,
        "or": odia, "ori": odia,
    }

    # Keep only the primary subtag: "eng-US" -> "eng", "hi_IN" -> "hi".
    key = (language or "").strip().lower().replace("_", "-").split("-", 1)[0]
    dictionary_name, acoustic_name = models.get(key, english)
    return dictionary_name if model_type == "dictionary" else acoustic_name
# ── WebMAUS (BAS Web Services) ───────────────────────────────────────────
def _run_webmaus(audio_path: Path, transcript: str, language: str = "eng-US") -> AlignmentResult:
    """Run BAS WebMAUS for phonetic segmentation via REST API.

    The audio and transcript are posted to ``BAS_MAUS_URL``. The XML reply is
    checked for success, the TextGrid it links to is downloaded and saved
    next to ``audio_path`` with the suffix ``.WebMAUS.TextGrid``, and that
    file is parsed.

    Args:
        audio_path: Audio file to upload (sent with an ``audio/wav`` type).
        transcript: Text spoken in the audio.
        language: Language code; unknown codes fall back to ``"eng-US"``.

    Returns:
        An ``AlignmentResult`` with ``source="webmaus"``. This function does
        not raise: every failure is reported through ``error`` with
        ``success=False``. When ``SKIP_WEBMAUS`` is set no request is made.
    """
    if SKIP_WEBMAUS:
        return AlignmentResult(source="webmaus", error="SKIP_WEBMAUS=1")

    try:
        import requests
    except ImportError:
        return AlignmentResult(source="webmaus", error="requests library not installed")
    import xml.etree.ElementTree as ET

    # Map language codes to BAS MAUS language codes
    lang_map = {
        "en": "eng-US", "hi": "hin", "bn": "ben", "or": "ori",
        "eng": "eng-US", "hin": "hin", "ben": "ben", "ori": "ori",
    }
    maus_lang = lang_map.get(language, "eng-US")

    try:
        with open(audio_path, "rb") as af:
            files = {"SIGNAL": (audio_path.name, af, "audio/wav")}
            data = {
                "TEXT": transcript,
                "LANGUAGE": maus_lang,
                "OUTFORMAT": "TextGrid",
                "MODUS": "standard",
                "INSKANTEXTGRID": "true",
                "INSORTTEXTGRID": "true",
            }
            logger.info("Calling WebMAUS API for language=%s (timeout=%ds)", maus_lang, WEBMAUS_TIMEOUT)
            resp = requests.post(BAS_MAUS_URL, files=files, data=data, timeout=WEBMAUS_TIMEOUT)

        if resp.status_code != 200:
            return AlignmentResult(source="webmaus", error=f"WebMAUS HTTP {resp.status_code}")

        # BAS returns XML with download link
        root = ET.fromstring(resp.text)
        success_el = root.find(".//success")
        if success_el is None or success_el.text != "true":
            err_msg = root.findtext(".//message", "Unknown WebMAUS error")
            return AlignmentResult(source="webmaus", error=err_msg)

        download_url = root.findtext(".//downloadLink")
        if not download_url:
            return AlignmentResult(source="webmaus", error="No download link in response")

        # Download the TextGrid
        tg_resp = requests.get(download_url, timeout=WEBMAUS_TIMEOUT)
        if tg_resp.status_code != 200:
            return AlignmentResult(source="webmaus", error="Failed to download TextGrid")

        # Save TextGrid
        tg_path = audio_path.with_suffix(".WebMAUS.TextGrid")
        tg_path.write_text(tg_resp.text, encoding="utf-8")

        phones, words = _parse_textgrid(tg_path)
        # The parser tags every segment "mfa"; re-tag for this aligner.
        for p in phones:
            p.source = "webmaus"

        return AlignmentResult(
            phones=phones,
            words=words,
            source="webmaus",
            textgrid_path=str(tg_path),
            success=True,
        )
    except requests.Timeout:
        # Report the timeout that was actually applied to the request.
        return AlignmentResult(
            source="webmaus",
            error=f"WebMAUS request timed out ({WEBMAUS_TIMEOUT}s)",
        )
    except Exception as exc:
        # Best-effort aligner: report the failure so the caller can fall back.
        logger.exception("WebMAUS alignment failed")
        return AlignmentResult(source="webmaus", error=str(exc))
# ── Public API ───────────────────────────────────────────────────────────
def forced_align(
    audio_path: Path,
    transcript: str,
    language: str = "en",
    prefer: str = "mfa",
) -> AlignmentResult:
    """Run forced alignment with a two-step fallback chain: MFA → WebMAUS.

    The preferred aligner runs first; if it does not succeed, the other one
    is tried and its result is returned whether or not it succeeded.

    Args:
        audio_path: Path to WAV audio file (16kHz mono recommended).
        transcript: Plain text transcript of the audio.
        language: ISO 639-1 language code (en, hi, bn, or).
        prefer: ``"webmaus"`` runs WebMAUS first; any other value runs MFA
            first.

    Returns:
        AlignmentResult with phone-level segments and word boundaries. An
        empty or whitespace-only transcript yields an error result without
        running any aligner.
    """
    if not (transcript and transcript.strip()):
        return AlignmentResult(error="No transcript provided for alignment")

    # Choose the aligner order together with the log messages that go with it.
    if prefer == "webmaus":
        first_aligner, second_aligner = _run_webmaus, _run_mfa
        success_msg = "WebMAUS alignment succeeded: %d phones"
        fallback_msg = "WebMAUS failed (%s), falling back to MFA"
    else:
        first_aligner, second_aligner = _run_mfa, _run_webmaus
        success_msg = "MFA alignment succeeded: %d phones"
        fallback_msg = "MFA failed (%s), falling back to WebMAUS"

    outcome = first_aligner(audio_path, transcript, language)
    if outcome.success:
        logger.info(success_msg, len(outcome.phones))
        return outcome
    logger.warning(fallback_msg, outcome.error)

    outcome = second_aligner(audio_path, transcript, language)
    if not outcome.success:
        logger.warning("All forced alignment methods failed: %s", outcome.error)
        return outcome
    logger.info("%s alignment succeeded: %d phones", outcome.source, len(outcome.phones))
    return outcome
def alignment_to_phoneme_spans(alignment: AlignmentResult) -> list[dict[str, Any]]:
    """Flatten an alignment's phone segments into plain dictionaries.

    Each segment becomes one dictionary with the keys ``phoneme``,
    ``start_ms``, ``end_ms``, ``duration_ms``, ``confidence`` and ``source``,
    in the same order as ``alignment.phones``. This is meant to mirror the
    span structure of the Wav2Vec2 CTC output in ai_classification.py so
    downstream modules can consume either source.
    """
    return [
        {
            "phoneme": segment.phone,
            "start_ms": segment.start_ms,
            "end_ms": segment.end_ms,
            "duration_ms": segment.duration_ms,
            "confidence": segment.confidence,
            "source": segment.source,
        }
        for segment in alignment.phones
    ]
|