# MVP alignment step: map long audio segments detected earlier in the
# pipeline to canonical surah words using WavLM frame embeddings.
import json
import os

import librosa
import numpy as np
import torch
from dtw import dtw
from transformers import AutoFeatureExtractor, AutoModel

from arabic_phonemizer import ArabicPhonemizer
# Input recording to align (mono, resampled to 16 kHz on load).
AUDIO_PATH = "sample_trim.wav"
# Canonical Surah Al-Fatiha word list (fallback JSON produced by an earlier step).
CANON_PATH = "data/fatiha_canonical_fallback.json"
# Destination for the MVP alignment result written by main().
OUT_PATH = "output/alignment_wavlm.json"
# Hugging Face model id used for frame embeddings.
MODEL_ID = "microsoft/wavlm-base"
def wavlm_embeddings(audio_16k: np.ndarray, sr: int) -> np.ndarray:
    """Return per-frame WavLM hidden states for a mono waveform.

    Args:
        audio_16k: 1-D float waveform (expected to be sampled at 16 kHz).
        sr: sampling rate, passed through to the feature extractor.

    Returns:
        np.ndarray of shape (frames, hidden_size) — the last hidden state
        with the batch dimension removed.
    """
    # Build the feature extractor and model only once and reuse them on
    # later calls; the original re-instantiated both on every invocation,
    # which is slow and may hit the network each time.
    cache = getattr(wavlm_embeddings, "_cache", None)
    if cache is None:
        fe = AutoFeatureExtractor.from_pretrained(MODEL_ID)
        model = AutoModel.from_pretrained(MODEL_ID)
        model.eval()
        cache = (fe, model)
        wavlm_embeddings._cache = cache
    fe, model = cache

    inputs = fe(audio_16k, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs)
    # Drop the batch dimension: (frames, hidden).
    return out.last_hidden_state[0].cpu().numpy()
def mean_pool(emb: np.ndarray) -> np.ndarray:
    """Collapse a (frames, hidden) embedding matrix into one vector by averaging over frames."""
    return np.mean(emb, axis=0)
def load_audio_segment(path, start_s, end_s, sr=16000):
    """Load the [start_s, end_s) window of an audio file as a mono waveform at `sr` Hz."""
    window_len = float(end_s - start_s)
    segment, _ = librosa.load(
        path,
        sr=sr,
        mono=True,
        offset=float(start_s),
        duration=window_len,
    )
    return segment
def canonical_word_list(canon):
    """Flatten the canonical surah JSON into an ordered list of word records.

    Each record carries the ayah number plus the word's surface form and base form.
    """
    return [
        {"ayah": ayah["ayah"], "word": info["word"], "base": info["base"]}
        for ayah in canon["ayahs"]
        for info in ayah["word_info"]
    ]
def vad_segments_from_step8(feedback_path="output/feedback_madd.json"):
    """Return (start, end) time pairs for the long segments detected in step 8.

    Reads the feedback JSON produced earlier in the pipeline and extracts the
    `segments_detected` windows.

    Args:
        feedback_path: path to the step-8 feedback JSON.

    Returns:
        List of (start, end) tuples in seconds.
    """
    # `with` guarantees the file handle is closed; the original leaked it
    # via json.load(open(...)).
    with open(feedback_path, encoding="utf-8") as fh:
        data = json.load(fh)
    return [(seg["start"], seg["end"]) for seg in data["segments_detected"]]
def cosine(a, b):
    """Cosine similarity between two vectors, epsilon-stabilized against zero norms."""
    eps = 1e-9
    norm_a = np.linalg.norm(a) + eps
    norm_b = np.linalg.norm(b) + eps
    return float(np.dot(a / norm_a, b / norm_b))
def main():
    """Write an MVP time-based alignment of detected segments to canonical words.

    Strategy: for each long segment detected in step 8, estimate a canonical
    word index from the segment's relative time position in the recitation and
    record a candidate window around it. A later step replaces this with real
    phoneme/CTC forced alignment.
    """
    # Load the canonical word list (file handle closed via `with`; the
    # original leaked it).
    with open(CANON_PATH, encoding="utf-8") as fh:
        canon = json.load(fh)
    canon_words = canonical_word_list(canon)

    segs = vad_segments_from_step8()

    # Compute the full-audio WavLM embedding frames once.
    full_audio, sr = librosa.load(AUDIO_PATH, sr=16000, mono=True)
    full_emb = wavlm_embeddings(full_audio, sr)

    # Estimate the model's effective frame rate from the embedding length
    # (WavLM is roughly 50 fps; deriving fps empirically avoids hard-coding).
    total_sec = len(full_audio) / sr
    frames = full_emb.shape[0]
    fps = frames / total_sec

    W = 6  # half-width of the candidate word window around the estimate
    results = []
    for i, (s, e) in enumerate(segs, 1):
        # Convert the time window to a frame range; skip degenerate slices.
        f0 = int(max(0, np.floor(s * fps)))
        f1 = int(min(frames, np.ceil(e * fps)))
        if f1 <= f0 + 1:
            continue
        # NOTE(review): the original also mean-pooled full_emb[f0:f1] here,
        # but the vector was never used (candidate scoring is deferred to the
        # next step), so the dead computation is dropped.

        # Estimate the word index from the segment midpoint's time ratio.
        t_mid = (s + e) / 2.0
        ratio = t_mid / total_sec
        est_idx = int(ratio * (len(canon_words) - 1))
        cand_lo = max(0, est_idx - W)
        cand_hi = min(len(canon_words), est_idx + W + 1)

        results.append({
            "segment_index": i,
            "timestamp": {"start": round(float(s), 3), "end": round(float(e), 3)},
            "estimated_word_index": est_idx,
            "candidate_word_indices": list(range(cand_lo, cand_hi)),
            "mapped_word": canon_words[est_idx],
            "note": "MVP time-based alignment using WavLM frame mapping. Next step replaces this with phoneme/CTC alignment."
        })

    out = {
        "audio_path": AUDIO_PATH,
        "total_sec": round(float(total_sec), 3),
        "wavlm": {"model_id": MODEL_ID, "frames": int(frames), "fps_est": round(float(fps), 2)},
        "num_canonical_words": len(canon_words),
        "segments_used": len(results),
        "alignment": results
    }

    # Ensure the output directory exists, then write with a managed handle
    # (the original leaked the output file handle too).
    os.makedirs(os.path.dirname(OUT_PATH) or ".", exist_ok=True)
    with open(OUT_PATH, "w", encoding="utf-8") as fh:
        json.dump(out, fh, ensure_ascii=False, indent=2)

    print("OK ✅ wrote", OUT_PATH)
    print("Segments aligned:", len(results))
    if results:
        print("Sample:", results[0])


if __name__ == "__main__":
    main()