# iRecite-MVP-API / step16_ctc_word_timestamps.py
# Author: didodev — "Deploy iRecite MVP API (Docker + FastAPI)" (commit 4ca6263)
import json
import os
import re

import librosa
import numpy as np
import torch
from transformers import AutoProcessor, AutoModelForCTC
# Input audio clip to timestamp (loaded as 16 kHz mono in main()).
AUDIO_PATH = "sample_trim.wav"
# Word-level alignment JSON — presumably produced by an earlier pipeline step.
ALIGN_PATH = "output/text_alignment_global.json"
# Destination for the per-word timestamp JSON written by main().
OUT_PATH = "output/word_timestamps.json"
# Hugging Face wav2vec2 CTC model used for frame-level character predictions.
MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-arabic"
# Diacritic marks (tashkeel), superscript alef, and hamza combining marks
# to strip during normalization.
ARABIC_DIACRITICS = re.compile(r"[\u064B-\u0652\u0670\u0653\u0654\u0655]")
TATWEEL = "\u0640"  # kashida/elongation character — purely typographic

# Single-pass character folding table: hamza-carrying alefs -> bare alef,
# alef maqsura -> ya, ta marbuta -> ha, tatweel -> removed. One C-level
# translate() pass replaces five chained str.replace() calls.
_CHAR_FOLD = str.maketrans({"أ": "ا", "إ": "ا", "آ": "ا", "ى": "ي", "ة": "ه", TATWEEL: None})
# Hoisted so it is compiled once, not per call.
_WHITESPACE = re.compile(r"\s+")


def normalize_ar(s: str) -> str:
    """Normalize Arabic text for fuzzy matching.

    Removes tatweel and diacritics, folds hamza variants of alef to bare
    alef, alef maqsura to ya, ta marbuta to ha, and collapses runs of
    whitespace to a single space.

    Args:
        s: Raw Arabic string.

    Returns:
        The normalized string, stripped of leading/trailing whitespace.
    """
    s = ARABIC_DIACRITICS.sub("", s.translate(_CHAR_FOLD))
    return _WHITESPACE.sub(" ", s).strip()
def _collapse_ctc(pred_ids, inv_vocab, blank_id):
    """Collapse repeated CTC ids and drop blanks, keeping frame indices.

    Standard CTC decoding step: merge consecutive identical predictions,
    then discard blank/empty tokens.

    Args:
        pred_ids: Per-frame argmax token ids.
        inv_vocab: Mapping of token id -> token string.
        blank_id: The CTC blank id, or None if unknown.

    Returns:
        List of (frame_index, token) pairs in temporal order.
    """
    collapsed = []
    prev = None
    for t_idx, tok_id in enumerate(pred_ids):
        if tok_id == prev:
            continue
        prev = tok_id
        if blank_id is not None and tok_id == blank_id:
            continue
        tok = inv_vocab.get(tok_id, "")
        if tok.strip():
            collapsed.append((t_idx, tok))
    return collapsed


def _find_span_for_word(collapsed, word_norm, start_search_idx):
    """Locate the earliest in-order frame span covering word_norm's letters.

    Scans `collapsed` from start_search_idx, consuming collapsed tokens
    until each character of the (space-stripped) word has been matched in
    order. Heuristic — adequate for the MVP; a forced aligner would be
    more accurate.

    Returns:
        ((start_frame, end_frame), next_search_index) on success, or
        (None, start_search_idx) if the letters cannot be matched in order.
    """
    if not word_norm:
        return None, start_search_idx
    target = word_norm.replace(" ", "")
    if not target:
        return None, start_search_idx
    i = start_search_idx
    start_idx = last_idx = None
    for ch in target:
        found = False
        while i < len(collapsed):
            t_idx, tok = collapsed[i]
            i += 1
            # Tokens may be single characters or pieces; match on containment.
            if ch in tok:
                if start_idx is None:
                    start_idx = t_idx
                last_idx = t_idx
                found = True
                break
        if not found:
            return None, start_search_idx
    return (start_idx, last_idx), i


def main():
    """Produce approximate per-word timestamps for an aligned recitation.

    Reads the global text alignment, runs the wav2vec2 CTC model over the
    audio, and heuristically maps each aligned ASR token onto a span of
    CTC frames whose characters appear in order. Writes the result to
    OUT_PATH as JSON.
    """
    # Load alignment; keep only entries tied to a canonical word.
    with open(ALIGN_PATH, encoding="utf-8") as f:
        align = json.load(f)
    alignment = [a for a in align["alignment"] if a.get("canon")]

    # Load audio at the model's expected 16 kHz mono rate.
    audio, sr = librosa.load(AUDIO_PATH, sr=16000, mono=True)
    total_sec = len(audio) / sr

    # Load the CTC model and run a single forward pass over the whole clip.
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForCTC.from_pretrained(MODEL_ID)
    model.eval()
    inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits[0]  # (T, V)
    pred_ids = torch.argmax(logits, dim=-1).cpu().numpy().tolist()

    # Invert vocab (id -> token) and determine the CTC blank id.
    vocab = processor.tokenizer.get_vocab()
    inv_vocab = {i: t for t, i in vocab.items()}
    blank_id = processor.tokenizer.pad_token_id
    if blank_id is None:
        # Fallback: common wav2vec2 blank is vocab["<pad>"].
        blank_id = vocab.get("<pad>", None)

    collapsed = _collapse_ctc(pred_ids, inv_vocab, blank_id)

    # Map a CTC frame index to seconds: model frames span the full audio
    # evenly, so a linear mapping suffices for this approximation.
    T = logits.shape[0]

    def idx_to_time(i):
        return (i / T) * total_sec

    # Walk the alignment left-to-right, matching each normalized ASR token
    # to the earliest in-order span of collapsed CTC characters. The search
    # pointer only moves forward, preserving temporal order.
    out_rows = []
    search_ptr = 0
    for a in alignment:
        cw = a["canon"]
        tok = a["asr_token"]
        tok_norm = normalize_ar(tok) if tok else None
        if tok_norm:
            # On failure the helper returns the pointer unchanged.
            span, search_ptr = _find_span_for_word(collapsed, tok_norm, search_ptr)
        else:
            span = None
        if span is None:
            start_t = None
            end_t = None
        else:
            s_idx, e_idx = span
            start_t = round(float(idx_to_time(s_idx)), 3)
            end_t = round(float(idx_to_time(e_idx)), 3)
        out_rows.append({
            "ayah": cw["ayah"],
            "word": cw["word"],
            "asr_token": tok,
            "score": a["score"],
            "match": a["match"],
            "timestamp": None if start_t is None else {"start": start_t, "end": end_t},
        })

    out = {
        "audio_path": AUDIO_PATH,
        "model": MODEL_ID,
        "note": "CTC-based approximate word timestamps; upgrade later with forced alignment for higher accuracy.",
        "stats": {
            "words": len(out_rows),
            "timestamped": sum(1 for r in out_rows if r["timestamp"] is not None)
        },
        "words": out_rows
    }
    # Ensure the output directory exists, then write atomically-enough via
    # a context manager so the handle is flushed and closed.
    out_dir = os.path.dirname(OUT_PATH)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(OUT_PATH, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)
    print("OK ✅ wrote", OUT_PATH)
    print("Timestamped:", out["stats"]["timestamped"], "/", out["stats"]["words"])
    if out_rows:  # guard: an empty alignment would otherwise raise IndexError
        print("Sample:", out_rows[0])


if __name__ == "__main__":
    main()