#!/usr/bin/env python3
"""
TajweedSST - Step 2: Hierarchical Alignment Engine

The Anti-Drift Engine:
1. WhisperX: get word-level anchors (rigid boundaries)
2. MFA: get phoneme-level precision within words
3. Normalization: scale MFA durations so each word sums exactly to its
   WhisperX anchor

Formula:
    Phoneme_New_Duration = Phoneme_Old * (Whisper_Word_Duration / Sum_MFA_Phonemes)
"""

import json
import shutil
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List


@dataclass
class PhonemeAlignment:
    """Single phoneme timing."""
    phoneme: str
    start: float
    end: float
    duration: float

    @property
    def normalized_duration(self) -> float:
        return self.end - self.start


@dataclass
class WordAlignment:
    """Word-level alignment with phoneme breakdown."""
    word_text: str
    whisper_start: float
    whisper_end: float
    phonemes: List[PhonemeAlignment] = field(default_factory=list)

    @property
    def whisper_duration(self) -> float:
        return self.whisper_end - self.whisper_start


@dataclass
class AlignmentResult:
    """Complete alignment for an audio segment."""
    audio_path: str
    surah: int
    ayah: int
    words: List[WordAlignment] = field(default_factory=list)
    metadata: Dict = field(default_factory=dict)
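
# A minimal serialization sketch (not part of the original pipeline; the
# helper name `alignment_to_json` is hypothetical): downstream steps will
# likely want the timings as JSON, and since the result types above are
# plain dataclasses, dataclasses.asdict plus the json module is enough.
from dataclasses import asdict


def alignment_to_json(result: AlignmentResult, indent: int = 2) -> str:
    """Serialize an AlignmentResult, including all phoneme timings, to JSON."""
    # ensure_ascii=False keeps Arabic word text human-readable in the output
    return json.dumps(asdict(result), ensure_ascii=False, indent=indent)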


class AlignmentEngine:
    """Hierarchical alignment using WhisperX + MFA."""

    def __init__(self,
                 whisperx_model: str = "large-v3",
                 mfa_acoustic_model: str = "arabic_mfa",
                 mfa_dictionary: str = "arabic_mfa",
                 device: str = "cuda",
                 compute_type: str = "float16"):
        """
        Initialize alignment engine.

        Args:
            whisperx_model: WhisperX model size
            mfa_acoustic_model: MFA acoustic model for Arabic
            mfa_dictionary: MFA pronunciation dictionary
            device: cuda or cpu
            compute_type: float16 or float32
        """
        self.whisperx_model = whisperx_model
        self.mfa_acoustic_model = mfa_acoustic_model
        self.mfa_dictionary = mfa_dictionary
        self.device = device
        self.compute_type = compute_type
        self._whisperx = None
        self._whisperx_align_model = None
        self._whisperx_align_metadata = None

    def _load_whisperx(self):
        """Lazy-load WhisperX transcription and alignment models."""
        if self._whisperx is None:
            import whisperx
            self._whisperx = whisperx.load_model(
                self.whisperx_model,
                device=self.device,
                compute_type=self.compute_type
            )
            # Load the alignment model for Arabic
            self._whisperx_align_model, self._whisperx_align_metadata = whisperx.load_align_model(
                language_code="ar",
                device=self.device
            )

    def align(self,
              audio_path: str,
              phonetic_words: List[str],
              surah: int = 0,
              ayah: int = 0) -> AlignmentResult:
        """
        Perform hierarchical alignment.

        Args:
            audio_path: Path to audio file
            phonetic_words: List of phonetic transcriptions from TajweedParser
            surah: Surah number for metadata
            ayah: Ayah number for metadata

        Returns:
            AlignmentResult with word and phoneme timings
        """
        result = AlignmentResult(
            audio_path=audio_path,
            surah=surah,
            ayah=ayah
        )

        # Step 1: WhisperX word-level alignment
        whisper_words = self._run_whisperx(audio_path)

        # Step 2: MFA phoneme-level alignment for each word
        mfa_phonemes = self._run_mfa(audio_path, phonetic_words)

        # Step 3: Normalize MFA phonemes to WhisperX word boundaries.
        # Note: zip() truncates to the shorter list, so a word-count
        # mismatch between WhisperX and MFA silently drops trailing words.
        for whisper_word, phonemes in zip(whisper_words, mfa_phonemes):
            word_alignment = WordAlignment(
                word_text=whisper_word['word'],
                whisper_start=whisper_word['start'],
                whisper_end=whisper_word['end']
            )
            word_alignment.phonemes = self._normalize_phonemes(
                phonemes=phonemes,
                target_start=whisper_word['start'],
                target_end=whisper_word['end']
            )
            result.words.append(word_alignment)

        return result

    def _run_whisperx(self, audio_path: str) -> List[Dict]:
        """
        Run WhisperX for word-level timing.

        Returns:
            List of {word, start, end} dicts
        """
        self._load_whisperx()
        import whisperx

        # Transcribe (pin the language so detection cannot wander off Arabic)
        audio = whisperx.load_audio(audio_path)
        result = self._whisperx.transcribe(audio, batch_size=16, language="ar")

        # Align to get word-level timestamps
        aligned = whisperx.align(
            result["segments"],
            self._whisperx_align_model,
            self._whisperx_align_metadata,
            audio,
            self.device,
            return_char_alignments=False
        )

        # Extract word timings; WhisperX omits start/end for words it
        # could not align, so skip those rather than raise KeyError
        words = []
        for segment in aligned["segments"]:
            for word_data in segment.get("words", []):
                if "start" not in word_data or "end" not in word_data:
                    continue
                words.append({
                    "word": word_data["word"],
                    "start": word_data["start"],
                    "end": word_data["end"]
                })
        return words

    def _run_mfa(self, audio_path: str, phonetic_words: List[str]) -> List[List[Dict]]:
        """
        Run MFA for phoneme-level timing within each word.

        Returns:
            List of phoneme lists, one per word
        """
        # Create temp corpus layout for MFA
        temp_dir = Path("/tmp/tajweedsst_mfa")
        input_dir = temp_dir / "input"
        output_dir = temp_dir / "output"
        input_dir.mkdir(parents=True, exist_ok=True)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Write phonetic transcript (space-separated words) next to the audio
        audio_name = Path(audio_path).stem
        transcript_path = input_dir / f"{audio_name}.txt"
        transcript_path.write_text(" ".join(phonetic_words))

        # Copy audio file into the corpus directory
        shutil.copy(audio_path, input_dir / Path(audio_path).name)

        # Run MFA
        try:
            subprocess.run([
                "mfa", "align",
                str(input_dir),
                self.mfa_dictionary,
                self.mfa_acoustic_model,
                str(output_dir),
                "--clean", "--quiet"
            ], check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            print(f"MFA Error: {e.stderr.decode()}")
            return [[] for _ in phonetic_words]

        # Parse TextGrid output
        textgrid_path = output_dir / f"{audio_name}.TextGrid"
        if textgrid_path.exists():
            return self._parse_textgrid(textgrid_path, len(phonetic_words))
        return [[] for _ in phonetic_words]

    def _parse_textgrid(self, textgrid_path: Path, word_count: int) -> List[List[Dict]]:
        """Parse MFA TextGrid output for phoneme timings, grouped by word."""
        try:
            import textgrid
            tg = textgrid.TextGrid.fromFile(str(textgrid_path))

            # Find phones and words tiers
            phones_tier = None
            words_tier = None
            for tier in tg:
                if tier.name == "phones":
                    phones_tier = tier
                elif tier.name == "words":
                    words_tier = tier
            if not phones_tier or not words_tier:
                return [[] for _ in range(word_count)]

            # MFA marks silences as empty intervals; drop them so word
            # indexing stays in sync with the phone groups
            word_intervals = [iv for iv in words_tier if iv.mark]

            # Group phonemes by word boundaries
            result = []
            word_idx = 0
            current_word_phones = []
            for interval in phones_tier:
                if not interval.mark:
                    continue
                # Advance to the next word once past its right boundary
                while (word_idx < len(word_intervals)
                       and interval.minTime >= word_intervals[word_idx].maxTime):
                    result.append(current_word_phones)
                    current_word_phones = []
                    word_idx += 1
                current_word_phones.append({
                    "phoneme": interval.mark,
                    "start": interval.minTime,
                    "end": interval.maxTime
                })

            # Don't forget the last word
            if current_word_phones:
                result.append(current_word_phones)

            # Pad so callers always get exactly one list per expected word
            while len(result) < word_count:
                result.append([])
            return result
        except Exception as e:
            print(f"TextGrid parse error: {e}")
            return [[] for _ in range(word_count)]

    def _normalize_phonemes(self,
                            phonemes: List[Dict],
                            target_start: float,
                            target_end: float) -> List[PhonemeAlignment]:
        """
        Normalize MFA phonemes to fit exactly within WhisperX word boundaries.

        Formula:
            Phoneme_New_Duration = Phoneme_Old * (Whisper_Word_Duration / Sum_MFA_Phonemes)
        """
        if not phonemes:
            return []

        target_duration = target_end - target_start

        # Total MFA duration; guard against a degenerate zero-length word
        mfa_total = sum(p['end'] - p['start'] for p in phonemes)
        if mfa_total == 0:
            return []

        # Scale factor
        scale = target_duration / mfa_total

        # Rescale each phoneme, laying them out contiguously from target_start
        normalized = []
        current_time = target_start
        for phone in phonemes:
            new_duration = (phone['end'] - phone['start']) * scale
            normalized.append(PhonemeAlignment(
                phoneme=phone['phoneme'],
                start=current_time,
                end=current_time + new_duration,
                duration=new_duration
            ))
            current_time += new_duration

        # Ensure the last phoneme ends exactly at target_end (floating-point fix)
        if normalized:
            normalized[-1].end = target_end
            normalized[-1].duration = target_end - normalized[-1].start

        return normalized
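
# Worked example of the normalization above (illustrative numbers, not from
# a real recording): MFA reports phonemes of 0.20 s, 0.30 s, 0.30 s (total
# 0.80 s), but WhisperX anchors the word to 1.00-1.60 s (0.60 s). The scale
# factor is 0.60 / 0.80 = 0.75, giving rescaled durations of 0.15 s, 0.225 s,
# 0.225 s, and the last phoneme is snapped to end exactly at 1.60 s:
#
#   engine = AlignmentEngine()  # __init__ loads no models, so this is safe
#   phones = [{"phoneme": "q", "start": 0.0, "end": 0.2},
#             {"phoneme": "u", "start": 0.2, "end": 0.5},
#             {"phoneme": "l", "start": 0.5, "end": 0.8}]
#   norm = engine._normalize_phonemes(phones, target_start=1.0, target_end=1.6)
#   # norm[0].duration ≈ 0.15, norm[-1].end == 1.6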


class MockAlignmentEngine(AlignmentEngine):
    """Mock alignment engine for testing without WhisperX/MFA installed."""

    def align(self,
              audio_path: str,
              phonetic_words: List[str],
              surah: int = 0,
              ayah: int = 0) -> AlignmentResult:
        """Generate mock alignment data."""
        result = AlignmentResult(
            audio_path=audio_path,
            surah=surah,
            ayah=ayah
        )

        # Mock timing: 0.5 s per word, phonemes spread evenly within the word
        current_time = 0.0
        word_duration = 0.5
        for word in phonetic_words:
            phonemes = word.split()
            phoneme_duration = word_duration / max(len(phonemes), 1)

            word_alignment = WordAlignment(
                word_text=word,
                whisper_start=current_time,
                whisper_end=current_time + word_duration
            )

            phoneme_time = current_time
            for phoneme in phonemes:
                word_alignment.phonemes.append(PhonemeAlignment(
                    phoneme=phoneme,
                    start=phoneme_time,
                    end=phoneme_time + phoneme_duration,
                    duration=phoneme_duration
                ))
                phoneme_time += phoneme_duration

            result.words.append(word_alignment)
            current_time += word_duration + 0.1  # Gap between words

        return result


def main():
    """Test alignment engine."""
    print("=" * 50)
    print("TajweedSST Alignment Engine Test")
    print("=" * 50)

    # Use mock engine for testing
    engine = MockAlignmentEngine()

    # Test phonetic words from TajweedParser
    phonetic_words = ["q l", "h w", "ā l l ā h", "ʾ ḥ d"]

    result = engine.align(
        audio_path="test.wav",
        phonetic_words=phonetic_words,
        surah=112,
        ayah=1
    )

    print(f"Aligned {len(result.words)} words:")
    for word in result.words:
        print(f"\n  Word: '{word.word_text}'")
        print(f"  Anchor: {word.whisper_start:.3f} - {word.whisper_end:.3f}s")
        for phoneme in word.phonemes:
            print(f"    [{phoneme.phoneme}] {phoneme.start:.3f} - {phoneme.end:.3f}s")


if __name__ == "__main__":
    main()
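
# To run against real audio, swap in the real engine. This assumes whisperx,
# montreal-forced-aligner, and the `textgrid` parser are installed, and that
# the default model/dictionary names in __init__ resolve on your machine;
# the audio path below is hypothetical:
#
#   engine = AlignmentEngine(device="cuda", compute_type="float16")
#   result = engine.align("surah112_ayah1.wav", phonetic_words, surah=112, ayah=1)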