"""Forced alignment core module using ctc-forced-aligner."""

import logging
from pathlib import Path
from typing import Dict, List, Union

from config import (
    MIN_CAPTION_DURATION_MS,
    GAP_BETWEEN_CAPTIONS_MS
)

# Set up logging for this module
logger = logging.getLogger(__name__)


def _remove_temp_file(path) -> None:
    """Delete a temporary file, logging (not raising) if removal fails.

    Args:
        path: Filesystem path of the file to delete, or ``None`` when the
            file was never created (in which case nothing happens).
    """
    if path is None:
        return
    try:
        Path(path).unlink(missing_ok=True)
    except OSError as e:
        # Cleanup is best-effort: a stale temp file must not mask the real
        # result (or the real error) of the alignment.
        logger.warning("Could not remove temporary file %s: %s", path, e)


def align(audio_path: Union[str, Path], sentences: List[str], language: str = "ara") -> List[Dict]:
    """Perform forced alignment on audio with provided sentences.

    Uses the ctc-forced-aligner library to align text sentences with audio
    timestamps. The sentences are written one per line to a temporary script
    file, the library generates a temporary SRT file from it, and that SRT is
    parsed back into segment dicts with millisecond timestamps suitable for
    SRT generation. Overlapping segments are then corrected.

    Args:
        audio_path: Path to the audio file to align against.
        sentences: Sentences to align; blank/whitespace-only entries are
            dropped and the rest are stripped.
        language: Accepted for interface compatibility; not currently used by
            this function.

    Returns:
        A list of dicts with keys ``"index"``, ``"text"``, ``"start_ms"`` and
        ``"end_ms"``.

    Raises:
        RuntimeError: If the alignment library cannot be imported, or if any
            step of the alignment itself fails.
        FileNotFoundError: If ``audio_path`` does not exist.
        ValueError: If ``sentences`` is empty or contains only blank entries.
    """
    import tempfile

    try:
        # Imported lazily so the module can be loaded without the (heavy)
        # alignment dependency installed.
        from ctc_forced_aligner import AlignmentTorchSingleton
    except ImportError as e:
        raise RuntimeError(
            f"Required alignment libraries not installed: {e}\n"
            "Install with: pip install ctc-forced-aligner torch torchaudio"
        ) from e

    audio_path = Path(audio_path)
    if not audio_path.exists():
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    if not sentences:
        raise ValueError("No sentences provided for alignment")

    # Clean sentences - remove empty ones
    clean_sentences = [s.strip() for s in sentences if s.strip()]
    if not clean_sentences:
        raise ValueError("No non-empty sentences provided for alignment")

    logger.info("Starting alignment for %d sentences", len(clean_sentences))

    # Both paths start as None so the single ``finally`` below can clean up
    # whichever files were actually created, on success and failure alike.
    temp_script_path = None
    temp_srt_path = None
    try:
        # Create a temporary text file with the script
        with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False, encoding='utf-8') as f:
            temp_script_path = f.name
            f.write('\n'.join(clean_sentences))

        try:
            print("📥 Loading facebook/mms-300m model (cached after first run)...")

            # Create alignment instance (singleton pattern - downloads model on first use)
            aligner = AlignmentTorchSingleton()

            # Reserve a temporary output SRT path for the library to write to
            with tempfile.NamedTemporaryFile(suffix='.srt', delete=False) as f:
                temp_srt_path = f.name

            # Perform alignment using the built-in SRT generation with MMS_FA model
            success = aligner.generate_srt(
                input_audio_path=str(audio_path),
                input_text_path=temp_script_path,
                output_srt_path=temp_srt_path,
                model_type='MMS_FA'  # Use facebook/mms-300m model
            )

            if not success:
                raise RuntimeError("Alignment failed - no SRT file generated")

            # Parse the generated SRT to extract our format
            segments = _parse_generated_srt(temp_srt_path)
        except Exception as e:
            # Present every alignment-stage failure as a RuntimeError, keeping
            # the original exception attached as the cause.
            raise RuntimeError(f"Forced alignment failed: {e}") from e
    finally:
        _remove_temp_file(temp_script_path)
        _remove_temp_file(temp_srt_path)

    # Apply smart gap correction
    segments = _apply_smart_gap_correction(segments)

    logger.info("Alignment completed: %d segments", len(segments))
    return segments


def align_word_level(audio_path: Union[str, Path], sentences: List[str], language: str = "ara", max_chars: int = 42) -> List[Dict]:
    """Perform true word-level forced alignment using facebook/mms-300m (MMS_FA).

    Arabic text is romanised with unidecode so the MMS_FA CTC model can align
    every word — Arabic, French and mixed tokens alike — at word granularity.
    Original script text is preserved unchanged in the output.

    Returns a flat list of per-word dicts (grouped later by srt_writer.group_words):
        [{"index": 1, "text": "كنت", "start_ms": 0, "end_ms": 300}, ...]

    Args:
        audio_path: Path to the audio file to align against.
        sentences: Sentences to align; blank entries are dropped, the rest are
            joined with spaces and split on whitespace into words.
        language: Accepted for interface compatibility; not currently used by
            this function.
        max_chars: Accepted for interface compatibility; not currently used by
            this function.

    Raises:
        RuntimeError: If a required library cannot be imported, or if every
            word is filtered out during romanisation.
        FileNotFoundError: If ``audio_path`` does not exist.
        ValueError: If ``sentences`` contains no non-empty entries.
    """
    try:
        import torch
        import torchaudio
        import torchaudio.functional as F
        from unidecode import unidecode
        from ctc_forced_aligner import (
            load_audio as cfa_load_audio,
            align as cfa_align,
            unflatten,
            _postprocess_results,
        )
    except ImportError as e:
        raise RuntimeError(
            f"Required libraries not installed: {e}\n"
            "Install with: pip install ctc-forced-aligner torch torchaudio"
        ) from e

    audio_path = Path(audio_path)
    if not audio_path.exists():
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    clean_sentences = [s.strip() for s in sentences if s.strip()]
    if not clean_sentences:
        raise ValueError("No non-empty sentences provided for alignment")

    logger.info("Starting word-level alignment: %d sentences", len(clean_sentences))

    full_text = " ".join(clean_sentences)
    original_words = full_text.split()

    print("📥 Loading facebook/mms-300m model (cached after first run)...")
    device = torch.device("cpu")
    bundle = torchaudio.pipelines.MMS_FA
    dictionary = bundle.get_dict(star=None)
    model = bundle.get_model(with_star=False).to(device)

    waveform = cfa_load_audio(str(audio_path), ret_type="torch").to(device)

    print("🔊 Generating CTC emissions...")
    with torch.inference_mode():
        emission, _ = model(waveform)

    # Romanise each script word via unidecode, then filter to MMS_FA phoneme set.
    # Arabic "كنت" → "knt", French "cellulite" → "cellulite", "100%" → ""
    # Characters mapped to id 0 are excluded as well.
    romanized = [unidecode(w).lower() for w in original_words]
    cleaned = [
        "".join(c for c in rom if c in dictionary and dictionary[c] != 0)
        for rom in romanized
    ]

    # Build aligned transcript and a map back to original word positions
    transcript: List[str] = []
    pos_map: List[int] = []  # pos_map[i] = original_words index of transcript[i]
    for orig_idx, cw in enumerate(cleaned):
        if cw:
            transcript.append(cw)
            pos_map.append(orig_idx)

    if not transcript:
        raise RuntimeError("All script words were filtered during romanisation")

    print(f"🔗 Running forced alignment ({len(transcript)} tokens)...")
    # Every character left in ``transcript`` already passed the dictionary
    # filter above, so a direct lookup is safe here.
    tokenized = [dictionary[c] for word in transcript for c in word]

    aligned_tokens, alignment_scores = cfa_align(emission, tokenized, device)
    token_spans = F.merge_tokens(aligned_tokens[0], alignment_scores[0])
    word_spans = unflatten(token_spans, [len(w) for w in transcript])
    word_ts = _postprocess_results(
        transcript, word_spans, waveform,
        emission.size(1), bundle.sample_rate, alignment_scores
    )
    # word_ts[i]: {"start": sec, "end": sec, "text": cleaned_word}

    # Map aligned timestamps back to original words by position
    ts_by_orig: Dict[int, Dict] = {
        orig_idx: word_ts[i] for i, orig_idx in enumerate(pos_map)
    }

    word_segments: List[Dict] = []
    for orig_idx, orig_word in enumerate(original_words):
        if orig_idx in ts_by_orig:
            wt = ts_by_orig[orig_idx]
            word_segments.append({
                "index": orig_idx + 1,
                "text": orig_word,
                "start_ms": int(wt["start"] * 1000),
                "end_ms": int(wt["end"] * 1000),
            })
        else:
            # Word had no phoneme tokens (e.g. "100%") — place after prev word
            prev_end = word_segments[-1]["end_ms"] if word_segments else 0
            word_segments.append({
                "index": orig_idx + 1,
                "text": orig_word,
                "start_ms": prev_end,
                "end_ms": prev_end + MIN_CAPTION_DURATION_MS,
            })

    word_segments = _apply_smart_gap_correction(word_segments)

    # Renumber sequentially from 1 after correction
    for i, seg in enumerate(word_segments):
        seg["index"] = i + 1

    logger.info("Word-level alignment completed: %d words", len(word_segments))
    return word_segments


def _parse_generated_srt(srt_path: str) -> List[Dict]:
    """Parse SRT file generated by ctc-forced-aligner into our format.

    Args:
        srt_path: Path of the UTF-8 SRT file to read.

    Returns:
        A list of dicts with keys ``"index"``, ``"text"``, ``"start_ms"`` and
        ``"end_ms"``, one per successfully parsed SRT block. Blocks with fewer
        than three lines are skipped silently; blocks that fail to parse are
        skipped with a warning.
    """
    segments = []

    with open(srt_path, 'r', encoding='utf-8') as f:
        content = f.read().strip()

    # Split by double newlines to get SRT blocks
    blocks = [block.strip() for block in content.split('\n\n') if block.strip()]

    for block in blocks:
        lines = block.split('\n')
        if len(lines) < 3:
            continue

        try:
            # Line 0: numeric index
            index = int(lines[0])

            # Line 1: timestamp line "00:00:01,234 --> 00:00:02,567"
            start_str, end_str = lines[1].split(' --> ')
            start_ms = _srt_time_to_ms(start_str)
            end_ms = _srt_time_to_ms(end_str)

            # Remaining lines: caption text (may be multiple lines)
            text = '\n'.join(lines[2:]).strip()
        except (ValueError, IndexError) as e:
            logger.warning("Failed to parse SRT block: %s... Error: %s", block[:50], e)
            continue

        segments.append({
            "index": index,
            "text": text,
            "start_ms": start_ms,
            "end_ms": end_ms
        })

    return segments


def _srt_time_to_ms(time_str: str) -> int:
    """Convert SRT time format (HH:MM:SS,mmm) to milliseconds.

    Args:
        time_str: Timestamp such as ``"00:00:01,234"``.

    Returns:
        The timestamp as a whole number of milliseconds.

    Raises:
        ValueError: If ``time_str`` does not split into exactly one comma
            separated pair and three colon separated integer fields.
    """
    time_part, ms_part = time_str.split(',')
    hours, minutes, seconds = map(int, time_part.split(':'))
    return (hours * 3600 + minutes * 60 + seconds) * 1000 + int(ms_part)


def _apply_smart_gap_correction(segments: List[Dict]) -> List[Dict]:
    """Apply smart gap correction to prevent overlapping captions.

    If consecutive captions overlap (end_ms[i] > start_ms[i+1]):
    - Set end_ms[i] = start_ms[i+1] - GAP_BETWEEN_CAPTIONS_MS
    - If that would make caption i shorter than MIN_CAPTION_DURATION_MS, keep
      caption i at the minimum duration and push caption i+1's start forward
      instead (extending its end if needed so it never ends before it starts)
    - Log which segments were corrected

    Args:
        segments: Segment dicts with ``"start_ms"`` and ``"end_ms"`` keys, in
            playback order.

    Returns:
        A new list holding copies of the segment dicts with corrected timings.
        The input list and its dicts are not modified.
    """
    # Copy each dict: a shallow list copy would still mutate the caller's
    # segment dicts in place.
    corrected_segments = [dict(seg) for seg in segments]
    if len(corrected_segments) <= 1:
        return corrected_segments

    corrections_made = 0

    for i in range(len(corrected_segments) - 1):
        current = corrected_segments[i]
        next_segment = corrected_segments[i + 1]

        if current["end_ms"] <= next_segment["start_ms"]:
            continue

        # Calculate new end time with gap
        new_end_ms = next_segment["start_ms"] - GAP_BETWEEN_CAPTIONS_MS

        # Ensure minimum caption duration
        min_end_ms = current["start_ms"] + MIN_CAPTION_DURATION_MS

        if new_end_ms < min_end_ms:
            # If corrected end would be too short, adjust next segment start instead
            current["end_ms"] = min_end_ms
            next_segment["start_ms"] = min_end_ms + GAP_BETWEEN_CAPTIONS_MS
            # Moving the start forward may pass the next segment's end; keep
            # it a valid caption. Any overlap this creates with the segment
            # after it is handled on the following iteration.
            next_segment["end_ms"] = max(
                next_segment["end_ms"],
                next_segment["start_ms"] + MIN_CAPTION_DURATION_MS,
            )
        else:
            current["end_ms"] = new_end_ms

        logger.debug("Corrected overlap between segments %d and %d", i + 1, i + 2)
        corrections_made += 1

    if corrections_made > 0:
        logger.info("Smart gap correction applied to %d segment pairs", corrections_made)

    return corrected_segments