Spaces:
Sleeping
Sleeping
| import json | |
| import re | |
| import numpy as np | |
| import torch | |
| import librosa | |
| from transformers import AutoProcessor, AutoModelForCTC | |
# Input/output locations and the ASR checkpoint used for CTC decoding.
AUDIO_PATH = "sample_trim.wav"  # audio clip to timestamp (loaded as 16 kHz mono below)
ALIGN_PATH = "output/text_alignment_global.json"  # word-level text alignment produced upstream
OUT_PATH = "output/word_timestamps.json"  # destination for the per-word timestamp JSON
MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-arabic"  # HF wav2vec2 CTC model for Arabic
# Combining marks stripped before matching: fathatan..sukun plus the dagger
# alif and the combining hamza/madda marks.
ARABIC_DIACRITICS = re.compile(r"[\u064B-\u0652\u0670\u0653\u0654\u0655]")
TATWEEL = "\u0640"  # kashida (elongation) character — purely typographic


def normalize_ar(s: str) -> str:
    """Normalize Arabic text for fuzzy matching.

    Removes tatweel and diacritics, folds the hamza/madda alif variants to
    bare alif, alif maqsura to ya, ta marbuta to ha, and collapses all
    whitespace runs to single spaces (trimming the ends).
    """
    stripped = ARABIC_DIACRITICS.sub("", s.replace(TATWEEL, ""))
    folded = stripped.translate(
        str.maketrans({"أ": "ا", "إ": "ا", "آ": "ا", "ى": "ي", "ة": "ه"})
    )
    return " ".join(folded.split())
def _collapse_predictions(pred_ids, inv_vocab, blank_id):
    """Collapse repeated CTC predictions and drop blank/whitespace tokens.

    Standard CTC post-processing: consecutive identical ids are merged and
    blanks removed, but unlike a plain decode we keep the first frame index
    of each run so every surviving token retains a time position.

    Returns a list of (frame_index, token) pairs in frame order.
    """
    collapsed = []
    prev = None
    for t_idx, tok_id in enumerate(pred_ids):
        if tok_id == prev:
            continue
        prev = tok_id
        if blank_id is not None and tok_id == blank_id:
            continue
        tok = inv_vocab.get(tok_id, "")
        if tok.strip() == "":
            continue
        collapsed.append((t_idx, tok))
    return collapsed


def _find_span_for_word(collapsed, word_norm, start_search_idx):
    """Find the CTC frame span covering *word_norm*.

    Scans ``collapsed`` forward from ``start_search_idx`` and matches the
    word's characters in order (a character matches a token if it appears
    anywhere in it — tokens may be single characters or pieces). Heuristic,
    but adequate for an MVP; forced alignment would be more accurate.

    Returns ((start_frame, end_frame), next_search_idx) on success, or
    (None, start_search_idx) when the word cannot be located.
    """
    if not word_norm:
        return None, start_search_idx
    target = word_norm.replace(" ", "")
    if not target:
        return None, start_search_idx
    i = start_search_idx
    start_idx = None
    last_idx = None
    for ch in target:
        found = False
        while i < len(collapsed):
            t_idx, tok = collapsed[i]
            if ch in tok:
                if start_idx is None:
                    start_idx = t_idx
                last_idx = t_idx
                i += 1
                found = True
                break
            i += 1
        if not found:
            # Character never appeared: give up on this word and leave the
            # search pointer where it was so later words can still match.
            return None, start_search_idx
    return (start_idx, last_idx), i


def main():
    """Produce approximate per-word timestamps for an aligned recitation.

    Pipeline: load the word alignment JSON, run a wav2vec2 CTC model over
    the audio in one forward pass, collapse the frame-level predictions,
    then map each aligned ASR token to a span of CTC frames and convert
    frame indices to seconds. Writes the result to OUT_PATH.
    """
    # Load alignment; keep only rows carrying a canonical word.
    with open(ALIGN_PATH, encoding="utf-8") as f:
        align = json.load(f)
    alignment = [a for a in align["alignment"] if a.get("canon")]

    # Load audio at the model's expected 16 kHz mono rate.
    audio, sr = librosa.load(AUDIO_PATH, sr=16000, mono=True)
    total_sec = len(audio) / sr

    # Load CTC model and run a single forward pass over the whole clip.
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForCTC.from_pretrained(MODEL_ID)
    model.eval()
    inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits[0]  # (T, V)
    pred_ids = torch.argmax(logits, dim=-1).cpu().numpy().tolist()

    # Invert the vocab (id -> token) and resolve the CTC blank id.
    vocab = processor.tokenizer.get_vocab()
    inv_vocab = {i: t for t, i in vocab.items()}
    blank_id = processor.tokenizer.pad_token_id
    if blank_id is None:
        # fallback: common wav2vec2 blank is vocab["<pad>"]
        blank_id = vocab.get("<pad>")

    collapsed = _collapse_predictions(pred_ids, inv_vocab, blank_id)

    # Frame index -> seconds: the T model frames span the full audio evenly.
    T = logits.shape[0]

    def idx_to_time(i):
        return (i / T) * total_sec

    out_rows = []
    search_ptr = 0  # monotone cursor into `collapsed`, keeps words in order
    for a in alignment:
        cw = a["canon"]
        tok = a["asr_token"]
        tok_norm = normalize_ar(tok) if tok else None
        if tok_norm:
            span, search_ptr = _find_span_for_word(collapsed, tok_norm, search_ptr)
        else:
            span = None
        if span is None:
            start_t = None
            end_t = None
        else:
            s_idx, e_idx = span
            start_t = round(float(idx_to_time(s_idx)), 3)
            end_t = round(float(idx_to_time(e_idx)), 3)
        out_rows.append({
            "ayah": cw["ayah"],
            "word": cw["word"],
            "asr_token": tok,
            "score": a["score"],
            "match": a["match"],
            "timestamp": None if start_t is None else {"start": start_t, "end": end_t}
        })

    out = {
        "audio_path": AUDIO_PATH,
        "model": MODEL_ID,
        "note": "CTC-based approximate word timestamps; upgrade later with forced alignment for higher accuracy.",
        "stats": {
            "words": len(out_rows),
            "timestamped": sum(1 for r in out_rows if r["timestamp"] is not None)
        },
        "words": out_rows
    }
    # Fix: close the output file deterministically (original leaked the handle).
    with open(OUT_PATH, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)
    print("OK ✅ wrote", OUT_PATH)
    print("Timestamped:", out["stats"]["timestamped"], "/", out["stats"]["words"])
    if out_rows:  # fix: original raised IndexError on an empty alignment
        print("Sample:", out_rows[0])


if __name__ == "__main__":
    main()