import os import gradio as gr from config import (MFA_SPACE_URL, MFA_TIMEOUT, MFA_PROGRESS_SEGMENT_RATE, MFA_METHOD, MFA_BEAM, MFA_RETRY_BEAM, MFA_SHARED_CMVN) # Lowercase special ref names for case-insensitive matching _SPECIAL_REFS = {"basmala", "isti'adha"} _BASMALA_TEXT = "بِسْمِ ٱللَّهِ ٱلرَّحْمَٰنِ ٱلرَّحِيم" _ISTIATHA_TEXT = "أَعُوذُ بِٱللَّهِ مِنَ الشَّيْطَانِ الرَّجِيم" def _mfa_upload_and_submit(refs, audio_paths, method=MFA_METHOD, beam=MFA_BEAM, retry_beam=MFA_RETRY_BEAM, shared_cmvn=MFA_SHARED_CMVN, padding="forward"): """Upload audio files and submit alignment batch to the MFA Space. Returns (event_id, headers, base_url) so the caller can yield a progress update before blocking on the SSE result stream. Args: refs: List of reference strings. audio_paths: List of audio file paths. method: Alignment method ("kalpy", "align_one", "python_api", "cli"). beam: Viterbi beam width (default 10). retry_beam: Retry beam width (default 40). padding: Gap-padding strategy ("forward", "symmetric", "none"). """ import requests hf_token = os.environ.get("HF_TOKEN", "") headers = {} if hf_token: headers["Authorization"] = f"Bearer {hf_token}" base = MFA_SPACE_URL # Upload all audio files in a single batched request files_payload = [] open_handles = [] for path in audio_paths: fh = open(path, "rb") open_handles.append(fh) files_payload.append(("files", (os.path.basename(path), fh, "audio/wav"))) try: resp = requests.post( f"{base}/gradio_api/upload", headers=headers, files=files_payload, timeout=MFA_TIMEOUT, ) resp.raise_for_status() if "application/json" not in resp.headers.get("content-type", ""): raise gr.Error( "MFA Space is not running (may be paused or restarting). " "Please try again in a minute." ) uploaded_paths = resp.json() finally: for fh in open_handles: fh.close() # Build FileData objects file_data_list = [ {"path": p, "meta": {"_type": "gradio.FileData"}} for p in uploaded_paths ] # Submit batch alignment (7 params: refs, files, method, beam, retry_beam, shared_cmvn, padding) submit_resp = requests.post( f"{base}/gradio_api/call/align_batch", headers={**headers, "Content-Type": "application/json"}, json={"data": [refs, file_data_list, method, str(beam), str(retry_beam), str(shared_cmvn).lower(), padding]}, timeout=MFA_TIMEOUT, ) submit_resp.raise_for_status() if "application/json" not in submit_resp.headers.get("content-type", ""): raise gr.Error( "MFA Space is not running (may be paused or restarting). " "Please try again in a minute." ) event_id = submit_resp.json()["event_id"] return event_id, headers, base def _mfa_wait_result(event_id, headers, base): """Wait for the MFA SSE stream and return parsed results list.""" import requests import json sse_resp = requests.get( f"{base}/gradio_api/call/align_batch/{event_id}", headers=headers, stream=True, timeout=MFA_TIMEOUT, ) sse_resp.raise_for_status() result_data = None current_event = None for line in sse_resp.iter_lines(decode_unicode=True): if line and line.startswith("event: "): current_event = line[7:] elif line and line.startswith("data: "): data_str = line[6:] if current_event == "complete": result_data = data_str elif current_event == "error": # Gradio 6.x may send null as error data; provide actionable message if data_str.strip() in ("null", ""): raise RuntimeError( "MFA align_batch failed: Space returned null error. " "This usually means a parameter count mismatch or " "Gradio input validation failure. Check that the " "client sends all required parameters." ) raise RuntimeError(f"MFA align_batch SSE error: {data_str}") if result_data is None: raise RuntimeError("No data received from MFA align_batch SSE stream") parsed = json.loads(result_data) # Gradio wraps the return value in a list if isinstance(parsed, list) and len(parsed) == 1: parsed = parsed[0] if parsed is None: raise RuntimeError("MFA align_batch returned null result") if not isinstance(parsed, dict) or parsed.get("status") != "ok": raise RuntimeError(f"MFA align_batch failed: {parsed}") return parsed["results"] # --------------------------------------------------------------------------- # MFA split helper (used by pipeline post-processing) # --------------------------------------------------------------------------- def mfa_split_timestamps(audio_int16, sample_rate, mfa_refs, method=MFA_METHOD, beam=MFA_BEAM, retry_beam=MFA_RETRY_BEAM, shared_cmvn=MFA_SHARED_CMVN): """Call MFA to get word timestamps for splitting segments. Args: audio_int16: List of int16 audio arrays (one per segment to split). sample_rate: Audio sample rate. mfa_refs: List of MFA ref strings (one per segment). method: Alignment method ("kalpy", "align_one", "python_api", "cli"). beam: Viterbi beam width (default 10). retry_beam: Retry beam width (default 40). Returns: List of results (one per segment), each a list of {location, start, end} dicts, or None on failure for that segment. """ import tempfile import wave if not mfa_refs or not audio_int16: return [None] * len(mfa_refs) # Write segment audio to temp WAV files audio_paths = [] for audio in audio_int16: tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) with wave.open(tmp.name, "wb") as wf: wf.setnchannels(1) wf.setsampwidth(2) wf.setframerate(sample_rate) wf.writeframes(audio.tobytes()) audio_paths.append(tmp.name) try: event_id, headers, base = _mfa_upload_and_submit( mfa_refs, audio_paths, method=method, beam=beam, retry_beam=retry_beam, shared_cmvn=shared_cmvn) results = _mfa_wait_result(event_id, headers, base) print(f"[MFA_SPLIT] Got {len(results)} results from MFA API") out = [] for result in results: if result.get("status") != "ok": print(f"[MFA_SPLIT] Segment failed: ref={result.get('ref')} error={result.get('error')}") out.append(None) else: out.append(result.get("words", [])) return out except Exception as e: print(f"[MFA_SPLIT] MFA call failed: {e}") return [None] * len(mfa_refs) finally: import os as _os for p in audio_paths: try: _os.unlink(p) except OSError: pass # --------------------------------------------------------------------------- # Reusable helpers (shared by UI generator and API function) # --------------------------------------------------------------------------- def _make_ts_key(result_idx, ref, loc): """Build the composite key used in word/letter timestamp dicts.""" is_special = ref.strip().lower() in _SPECIAL_REFS is_fused = "+" in ref if is_special: base_key = f"{ref}:{loc}" elif is_fused and loc.startswith("0:0:"): base_key = f"{ref}:{loc}" else: base_key = loc return f"{result_idx}:{base_key}" def _build_mfa_ref(seg): """Build the MFA ref string for a single segment. Returns None to skip.""" ref_from = seg.get("ref_from", "") ref_to = seg.get("ref_to", "") confidence = seg.get("confidence", 0) if not ref_from: ref_from = seg.get("special_type", "") ref_to = ref_from if not ref_from or confidence <= 0: return None if ref_from == ref_to: mfa_ref = ref_from else: mfa_ref = f"{ref_from}-{ref_to}" _is_special_ref = ref_from.strip().lower() in _SPECIAL_REFS if not _is_special_ref: matched_text = seg.get("matched_text", "") if matched_text.startswith(_ISTIATHA_TEXT): mfa_ref = f"Isti'adha+{mfa_ref}" elif matched_text.startswith(_BASMALA_TEXT): mfa_ref = f"Basmala+{mfa_ref}" return mfa_ref def _ensure_segment_wavs(segments, segment_dir): """Write individual segment WAVs from full.wav on demand (for MFA). Segments are sliced from the full recording using soundfile's frame-level random access — no need to load the entire file. """ if not segment_dir: return full_path = os.path.join(segment_dir, "full.wav") if not os.path.exists(full_path): return import soundfile as sf info = sf.info(full_path) sr = info.samplerate written = 0 for seg in segments: idx = seg.get("segment", 0) - 1 wav_path = os.path.join(segment_dir, f"seg_{idx}.wav") if os.path.exists(wav_path): continue start_frame = int(seg.get("time_from", 0) * sr) stop_frame = int(seg.get("time_to", 0) * sr) audio_slice, _ = sf.read(full_path, start=start_frame, stop=stop_frame, dtype='int16') sf.write(wav_path, audio_slice, sr, format='WAV', subtype='PCM_16') written += 1 if written: print(f"[MFA] Wrote {written} segment WAVs on demand from full.wav") def _build_mfa_refs(segments, segment_dir): """Build MFA refs and audio paths from segments. Returns (refs, audio_paths, seg_to_result_idx). """ refs = [] audio_paths = [] seg_to_result_idx = {} for seg in segments: seg_idx = seg.get("segment", 0) - 1 mfa_ref = _build_mfa_ref(seg) if mfa_ref is None: continue audio_path = os.path.join(segment_dir, f"seg_{seg_idx}.wav") if segment_dir else None if not audio_path or not os.path.exists(audio_path): continue seg_to_result_idx[seg_idx] = len(refs) refs.append(mfa_ref) audio_paths.append(audio_path) return refs, audio_paths, seg_to_result_idx def _assign_letter_groups(letters, word_location): """Assign group_id to letters sharing identical (start, end) timestamps.""" if not letters: return [] result = [] group_id = 0 prev_ts = None for letter in letters: ts = (letter.get("start"), letter.get("end")) if ts != prev_ts: group_id += 1 prev_ts = ts result.append({ "char": letter.get("char", ""), "start": letter.get("start"), "end": letter.get("end"), "group_id": f"{word_location}:{group_id}", }) return result def _build_timestamp_lookups(results): """Build timestamp lookup dicts from MFA results. Returns (word_timestamps, letter_timestamps, word_to_all_results). """ word_timestamps = {} letter_timestamps = {} word_to_all_results = {} for result_idx, result in enumerate(results): if result.get("status") != "ok": continue ref = result.get("ref", "") is_special = ref.strip().lower() in _SPECIAL_REFS is_fused = "+" in ref for word in result.get("words", []): loc = word.get("location", "") if loc: key = _make_ts_key(result_idx, ref, loc) word_timestamps[key] = (word["start"], word["end"]) letters = word.get("letters") if letters: letter_timestamps[key] = _assign_letter_groups(letters, loc) if not is_special and not (is_fused and loc.startswith("0:0:")): if loc not in word_to_all_results: word_to_all_results[loc] = [] word_to_all_results[loc].append(result_idx) return word_timestamps, letter_timestamps, word_to_all_results def _build_crossword_groups(results, letter_ts_dict): """Build mapping of (key, letter_idx) -> cross-word group_id. Only checks word boundaries: last letter(s) of word N vs first letter(s) of word N+1. """ crossword_groups = {} for result_idx, result in enumerate(results): if result.get("status") != "ok": continue ref = result.get("ref", "") words = result.get("words", []) for word_i in range(len(words) - 1): word_a = words[word_i] word_b = words[word_i + 1] loc_a = word_a.get("location", "") loc_b = word_b.get("location", "") if not loc_a or not loc_b: continue key_a = _make_ts_key(result_idx, ref, loc_a) key_b = _make_ts_key(result_idx, ref, loc_b) letters_a = letter_ts_dict.get(key_a, []) letters_b = letter_ts_dict.get(key_b, []) if not letters_a or not letters_b: continue for idx_a in range(len(letters_a) - 1, max(len(letters_a) - 3, -1), -1): letter_a = letters_a[idx_a] if letter_a.get("start") is None or letter_a.get("end") is None: continue for idx_b in range(min(3, len(letters_b))): letter_b = letters_b[idx_b] if letter_b.get("start") is None or letter_b.get("end") is None: continue if letter_a["start"] == letter_b["start"] and letter_a["end"] == letter_b["end"]: group_id = f"xword-{result_idx}-{word_i}" crossword_groups[(key_a, idx_a)] = group_id crossword_groups[(key_b, idx_b)] = group_id return crossword_groups def _reconstruct_ref_key(seg): """Reconstruct the MFA ref key for a segment (for result matching).""" ref_from = seg.get("ref_from", "") ref_to = seg.get("ref_to", "") if not ref_from: ref_from = seg.get("special_type", "") ref_to = ref_from ref_key = f"{ref_from}-{ref_to}" if ref_from != ref_to else ref_from is_special = ref_from.strip().lower() in _SPECIAL_REFS if not is_special: matched_text = seg.get("matched_text", "") if matched_text.startswith(_ISTIATHA_TEXT): ref_key = f"Isti'adha+{ref_key}" elif matched_text.startswith(_BASMALA_TEXT): ref_key = f"Basmala+{ref_key}" return ref_key def _extend_word_timestamps(word_timestamps, segments, seg_to_result_idx, results, segment_dir): """Extend word ends to fill gaps between consecutive words. Mutates word_timestamps in place. """ import wave for seg in segments: ref_from = seg.get("ref_from", "") confidence = seg.get("confidence", 0) if not ref_from: ref_from = seg.get("special_type", "") if not ref_from or confidence <= 0: continue seg_idx = seg.get("segment", 0) - 1 result_idx = seg_to_result_idx.get(seg_idx) if result_idx is None: continue ref_key = _reconstruct_ref_key(seg) seg_word_locs = [] for result in results: if result.get("ref") == ref_key and result.get("status") == "ok": for w in result.get("words", []): loc = w.get("location", "") if loc: key = _make_ts_key(result_idx, ref_key, loc) if key in word_timestamps: seg_word_locs.append(key) break if not seg_word_locs: continue # Extend each word's end to the next word's start for i in range(len(seg_word_locs) - 1): cur_start, cur_end = word_timestamps[seg_word_locs[i]] nxt_start, _ = word_timestamps[seg_word_locs[i + 1]] if nxt_start > cur_end: word_timestamps[seg_word_locs[i]] = (cur_start, nxt_start) # Extend first word back to time 0 so highlight starts immediately first_loc = seg_word_locs[0] first_start, first_end = word_timestamps[first_loc] if first_start > 0: word_timestamps[first_loc] = (0, first_end) # Extend last word to segment audio duration last_loc = seg_word_locs[-1] last_start, last_end = word_timestamps[last_loc] audio_path = os.path.join(segment_dir, f"seg_{seg_idx}.wav") if segment_dir else None if audio_path and os.path.exists(audio_path): with wave.open(audio_path, 'rb') as wf: seg_duration = wf.getnframes() / wf.getframerate() if seg_duration > last_end: word_timestamps[last_loc] = (last_start, seg_duration) def _build_enriched_json(segments, results, seg_to_result_idx, word_timestamps, letter_timestamps, granularity, *, minimal=False): """Build enriched segments with word (and optionally letter) timestamps. When *minimal* is True (API path), each segment only contains ``segment`` number + ``words`` array. When False (UI path), all original segment fields are preserved. Returns dict with "segments" key. """ from src.core.quran_index import get_quran_index index = get_quran_index() include_letters = (granularity == "words+chars") def _get_word_text(location): if not location or location.startswith("0:0:"): return "" try: parts = location.split(":") if len(parts) >= 3: key = (int(parts[0]), int(parts[1]), int(parts[2])) idx = index.word_lookup.get(key) if idx is not None: return index.words[idx].display_text except (ValueError, IndexError): pass return "" enriched_segments = [] for seg in segments: seg_idx = seg.get("segment", 0) - 1 result_idx = seg_to_result_idx.get(seg_idx) if minimal: segment_data = {"segment": seg.get("segment", 0)} else: segment_data = dict(seg) if result_idx is not None: _ref = seg.get("ref_from", "") or seg.get("special_type", "") is_special = _ref.lower() in _SPECIAL_REFS special_words = seg.get("matched_text", "").replace(" \u06dd ", " ").split() if is_special else [] for i, result in enumerate(results): if i != result_idx or result.get("status") != "ok": continue words_with_ts = [] for word_idx, word in enumerate(result.get("words", [])): if word.get("start") is None or word.get("end") is None: continue location = word.get("location", "") if minimal: # API: compact — [location, start, end] or [location, start, end, letters] word_entry = [location, round(word["start"], 4), round(word["end"], 4)] if include_letters and word.get("letters"): word_entry.append([ [lt.get("char", ""), round(lt["start"], 4), round(lt["end"], 4)] for lt in word.get("letters", []) if lt.get("start") is not None ]) words_with_ts.append(word_entry) else: # UI: keyed objects with display text if is_special or location.startswith("0:0:"): word_text = special_words[word_idx] if word_idx < len(special_words) else "" else: word_text = _get_word_text(location) word_data = { "word": word_text, "location": location, "start": round(word["start"], 4), "end": round(word["end"], 4), } if include_letters and word.get("letters"): word_data["letters"] = [ { "char": lt.get("char", ""), "start": round(lt["start"], 4), "end": round(lt["end"], 4), } for lt in word.get("letters", []) if lt.get("start") is not None ] words_with_ts.append(word_data) if words_with_ts: segment_data["words"] = words_with_ts break enriched_segments.append(segment_data) return {"segments": enriched_segments} # --------------------------------------------------------------------------- # Synchronous API function # --------------------------------------------------------------------------- def compute_mfa_timestamps_api(segments, segment_dir, granularity="words", method=MFA_METHOD, beam=MFA_BEAM, retry_beam=MFA_RETRY_BEAM, shared_cmvn=MFA_SHARED_CMVN): """Run MFA forced alignment and return enriched segments (no UI/HTML). Args: segments: List of segment dicts (same format as alignment response). segment_dir: Path to directory containing per-segment WAV files. granularity: "words" or "words+chars". method: Alignment method ("kalpy", "align_one", "python_api", "cli"). beam: Viterbi beam width (default 10). retry_beam: Retry beam width (default 40). Returns: Dict with "segments" key containing enriched segment data. """ if not granularity or granularity not in ("words", "words+chars"): granularity = "words" # Write individual segment WAVs on demand (sliced from full.wav) _ensure_segment_wavs(segments, segment_dir) refs, audio_paths, seg_to_result_idx = _build_mfa_refs(segments, segment_dir) if not refs: return {"segments": segments} event_id, headers, base = _mfa_upload_and_submit( refs, audio_paths, method=method, beam=beam, retry_beam=retry_beam, shared_cmvn=shared_cmvn) results = _mfa_wait_result(event_id, headers, base) word_ts, letter_ts, _ = _build_timestamp_lookups(results) _build_crossword_groups(results, letter_ts) _extend_word_timestamps(word_ts, segments, seg_to_result_idx, results, segment_dir) return _build_enriched_json(segments, results, seg_to_result_idx, word_ts, letter_ts, granularity, minimal=True) # --------------------------------------------------------------------------- # UI progress bar # --------------------------------------------------------------------------- def _ts_progress_bar_html(total_segments, rate, animated=True): """Return HTML for a progress bar showing Segment x/N. When *animated* is False the bar is static at 0 %. When True the CSS fill animation runs and an img-onerror trick drives the text counter (since Gradio innerHTML doesn't execute