# iRecite-MVP-API / step12_align_segments_wavlm.py
# Author: didodev — commit 4ca6263 "Deploy iRecite MVP API (Docker + FastAPI)"
import json
import os
from functools import lru_cache

import librosa
import numpy as np
import torch
from dtw import dtw
from transformers import AutoFeatureExtractor, AutoModel

from arabic_phonemizer import ArabicPhonemizer
# --- Paths and model configuration ---
AUDIO_PATH = "sample_trim.wav"                      # input recitation audio (expected 16 kHz mono after load)
CANON_PATH = "data/fatiha_canonical_fallback.json"  # canonical surah word data (ayahs -> word_info)
OUT_PATH = "output/alignment_wavlm.json"            # alignment result written here
MODEL_ID = "microsoft/wavlm-base"                   # Hugging Face model used for frame embeddings
@lru_cache(maxsize=1)
def _get_wavlm():
    """Load and cache the WavLM feature extractor and model (expensive; do once)."""
    fe = AutoFeatureExtractor.from_pretrained(MODEL_ID)
    model = AutoModel.from_pretrained(MODEL_ID)
    model.eval()  # inference only; disables dropout etc.
    return fe, model


def wavlm_embeddings(audio_16k: np.ndarray, sr: int) -> np.ndarray:
    """Return frame-level WavLM embeddings for a mono audio signal.

    Fix over original: the pretrained extractor/model were re-instantiated
    (and potentially re-downloaded) on every call; they are now loaded once
    via the cached ``_get_wavlm`` helper.

    Parameters
    ----------
    audio_16k : 1-D float array of samples (assumed 16 kHz mono — TODO confirm at callers)
    sr : sample rate matching ``audio_16k``

    Returns
    -------
    np.ndarray of shape (frames, hidden_size)
    """
    fe, model = _get_wavlm()
    inputs = fe(audio_16k, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs)
    # First (and only) batch item -> (frames, hidden).
    return out.last_hidden_state[0].cpu().numpy()
def mean_pool(emb: np.ndarray) -> np.ndarray:
    """Collapse a (frames, hidden) embedding matrix into one hidden-size vector."""
    return np.mean(emb, axis=0)
def load_audio_segment(path, start_s, end_s, sr=16000):
    """Load the mono audio between ``start_s`` and ``end_s`` (seconds) at ``sr`` Hz."""
    duration = float(end_s - start_s)
    segment, _ = librosa.load(path, sr=sr, mono=True, offset=float(start_s), duration=duration)
    return segment
def canonical_word_list(canon):
    """Flatten the canonical surah JSON into an ordered list of word records.

    Each record carries the ayah number plus the word's surface form and base.
    """
    return [
        {"ayah": ayah["ayah"], "word": info["word"], "base": info["base"]}
        for ayah in canon["ayahs"]
        for info in ayah["word_info"]
    ]
def vad_segments_from_step8(feedback_path="output/feedback_madd.json"):
    """Return (start, end) time pairs for the long segments detected in step 8.

    Fix over original: the feedback JSON is now opened with a context manager,
    so the file handle is always closed (the original leaked it).

    Parameters
    ----------
    feedback_path : path to the step-8 feedback JSON with a
        ``segments_detected`` list of {"start": ..., "end": ...} dicts.
    """
    with open(feedback_path, encoding="utf-8") as f:
        data = json.load(f)
    return [(seg["start"], seg["end"]) for seg in data["segments_detected"]]
def cosine(a, b):
    """Cosine similarity of two vectors (norms regularized by 1e-9 to avoid /0)."""
    unit_a = a / (np.linalg.norm(a) + 1e-9)
    unit_b = b / (np.linalg.norm(b) + 1e-9)
    return float(unit_a @ unit_b)
def main():
    """Map the step-8 long segments to canonical word indices (MVP, time-based).

    Pipeline:
      1. Load the canonical word list and the step-8 segment times.
      2. Embed the full audio once with WavLM; estimate frames-per-second
         from the embedding length.
      3. For each segment, mean-pool its embedding frames and map its time
         midpoint to an estimated word index (plus a +/-6 word search window).
      4. Write the alignment JSON to OUT_PATH.

    Fixes over original: JSON files are opened via context managers (no leaked
    handles) and the output directory is created if it does not exist.
    """
    with open(CANON_PATH, encoding="utf-8") as f:
        canon = json.load(f)
    canon_words = canonical_word_list(canon)

    # MVP strategy: we have no per-word audio prototypes and don't synthesize
    # any; we keep word order and map each detected long segment to a nearby
    # word index based on its relative time position in the recitation.
    # The real version uses forced alignment / phoneme decoding.
    segs = vad_segments_from_step8()

    # Compute full-audio embedding frames once.
    full_audio, sr = librosa.load(AUDIO_PATH, sr=16000, mono=True)
    full_emb = wavlm_embeddings(full_audio, sr)

    # Map time -> frame index approximately. WavLM's frame rate is roughly
    # 50 fps after feature extraction; we estimate it from the embedding
    # length instead of assuming it.
    total_sec = len(full_audio) / sr
    frames = full_emb.shape[0]
    fps = frames / total_sec

    results = []
    for i, (s, e) in enumerate(segs, 1):
        # Embedding slice for this time window.
        f0 = int(max(0, np.floor(s * fps)))
        f1 = int(min(frames, np.ceil(e * fps)))
        if f1 <= f0 + 1:
            continue  # too short to pool meaningfully
        # Currently unused; placeholder for prototype-based scoring in the
        # next step (phoneme/CTC alignment).
        seg_vec = mean_pool(full_emb[f0:f1])

        # Estimate the position in the surah by time ratio, then record a
        # search window around that word index for the downstream aligner.
        t_mid = (s + e) / 2.0
        ratio = t_mid / total_sec
        est_idx = int(ratio * (len(canon_words) - 1))
        W = 6
        cand_range = range(max(0, est_idx - W), min(len(canon_words), est_idx + W + 1))

        # No word audio prototypes yet, so the estimated index itself is the
        # MVP answer; the candidate window is emitted for the next step.
        chosen = est_idx
        results.append({
            "segment_index": i,
            "timestamp": {"start": round(float(s), 3), "end": round(float(e), 3)},
            "estimated_word_index": est_idx,
            "candidate_word_indices": list(cand_range),
            "mapped_word": canon_words[chosen],
            "note": "MVP time-based alignment using WavLM frame mapping. Next step replaces this with phoneme/CTC alignment."
        })

    out = {
        "audio_path": AUDIO_PATH,
        "total_sec": round(float(total_sec), 3),
        "wavlm": {"model_id": MODEL_ID, "frames": int(frames), "fps_est": round(float(fps), 2)},
        "num_canonical_words": len(canon_words),
        "segments_used": len(results),
        "alignment": results
    }
    # Ensure the output directory exists before writing.
    os.makedirs(os.path.dirname(OUT_PATH) or ".", exist_ok=True)
    with open(OUT_PATH, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)

    print("OK ✅ wrote", OUT_PATH)
    print("Segments aligned:", len(results))
    if results:
        print("Sample:", results[0])
# Script entry point: run the alignment when executed directly.
if __name__ == "__main__":
    main()