File size: 4,599 Bytes
4ca6263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import json
import numpy as np
import librosa
import torch
from dtw import dtw
from transformers import AutoFeatureExtractor, AutoModel
from arabic_phonemizer import ArabicPhonemizer

# Input recitation recording (loaded as mono, resampled to 16 kHz).
AUDIO_PATH = "sample_trim.wav"
# Canonical surah word list (per-ayah word_info records).
CANON_PATH = "data/fatiha_canonical_fallback.json"
# Destination for the alignment result JSON (directory must already exist).
OUT_PATH = "output/alignment_wavlm.json"

# Hugging Face model id used for frame-level speech embeddings.
MODEL_ID = "microsoft/wavlm-base"

def wavlm_embeddings(audio_16k: np.ndarray, sr: int) -> np.ndarray:
    """Return per-frame WavLM embeddings for a mono audio signal.

    Args:
        audio_16k: 1-D waveform samples (expected at 16 kHz).
        sr: sampling rate of ``audio_16k``, passed through to the
            feature extractor.

    Returns:
        ``(frames, hidden)`` float array of last-hidden-state embeddings.
    """
    # Cache the feature extractor and model on the function object so repeated
    # calls don't re-download / re-instantiate the pretrained weights
    # (the original rebuilt both on every invocation).
    if not hasattr(wavlm_embeddings, "_cache"):
        fe = AutoFeatureExtractor.from_pretrained(MODEL_ID)
        model = AutoModel.from_pretrained(MODEL_ID)
        model.eval()  # inference mode: disables dropout etc.
        wavlm_embeddings._cache = (fe, model)
    fe, model = wavlm_embeddings._cache

    inputs = fe(audio_16k, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():  # no gradients needed for feature extraction
        out = model(**inputs)
    # last_hidden_state is (batch=1, frames, hidden); drop the batch dim.
    return out.last_hidden_state[0].cpu().numpy()

def mean_pool(emb: np.ndarray) -> np.ndarray:
    """Collapse a (frames, hidden) embedding matrix into one hidden vector
    by averaging over the frame axis."""
    return np.mean(emb, axis=0)

def load_audio_segment(path, start_s, end_s, sr=16000):
    """Load the mono audio slice ``[start_s, end_s)`` from *path*,
    resampled to ``sr`` Hz."""
    duration = float(end_s - start_s)
    signal, _ = librosa.load(
        path,
        sr=sr,
        mono=True,
        offset=float(start_s),
        duration=duration,
    )
    return signal

def canonical_word_list(canon):
    """Flatten the canonical surah JSON into an ordered list of word records.

    Each record carries the ayah number plus the word's surface form and base.
    """
    return [
        {"ayah": ayah["ayah"], "word": info["word"], "base": info["base"]}
        for ayah in canon["ayahs"]
        for info in ayah["word_info"]
    ]

def vad_segments_from_step8(feedback_path="output/feedback_madd.json"):
    """Return (start, end) time tuples for the long segments already
    detected in the step-8 feedback JSON.

    Args:
        feedback_path: path to a JSON file with a ``segments_detected``
            list of ``{"start": ..., "end": ...}`` objects.

    Returns:
        List of ``(start, end)`` tuples in file order.
    """
    # Context manager fixes the original's leaked file handle.
    with open(feedback_path, encoding="utf-8") as fh:
        data = json.load(fh)
    return [(seg["start"], seg["end"]) for seg in data["segments_detected"]]

def cosine(a, b):
    """Cosine similarity of two vectors; a small epsilon on each norm
    guards against division by zero for all-zero inputs."""
    unit_a = a / (np.linalg.norm(a) + 1e-9)
    unit_b = b / (np.linalg.norm(b) + 1e-9)
    return float(unit_a @ unit_b)

def main():
    """Time-based MVP alignment of detected long segments to canonical words.

    Pipeline:
      1. Load the canonical word list and the long segments from step 8.
      2. Embed the full recitation once with WavLM.
      3. Map each segment's time window to embedding frames and to an
         estimated word index by its relative position in the surah.
      4. Write the alignment JSON to ``OUT_PATH``.

    NOTE(review): this is a placeholder alignment — a later step replaces
    the time-ratio mapping with real phoneme/CTC forced alignment.
    """
    with open(CANON_PATH, encoding="utf-8") as fh:  # closes the handle the original leaked
        canon = json.load(fh)
    canon_words = canonical_word_list(canon)

    # Reuse the long segments already detected in step 8's feedback JSON.
    segs = vad_segments_from_step8()

    # Compute full-audio embedding frames once.
    full_audio, sr = librosa.load(AUDIO_PATH, sr=16000, mono=True)
    full_emb = wavlm_embeddings(full_audio, sr)

    # Estimate the model's frame rate from the embedding length instead of
    # hard-coding it (WavLM is roughly 50 fps after feature extraction).
    total_sec = len(full_audio) / sr
    frames = full_emb.shape[0]
    fps = frames / total_sec

    results = []
    for i, (s, e) in enumerate(segs, 1):
        # Slice the embedding frames covering this time window.
        f0 = int(max(0, np.floor(s * fps)))
        f1 = int(min(frames, np.ceil(e * fps)))
        if f1 <= f0 + 1:
            continue  # window too short to yield a meaningful vector
        # Mean-pooled segment vector. Unused for scoring yet — kept so the
        # next iteration can compare it against real word prototypes.
        seg_vec = mean_pool(full_emb[f0:f1])

        # Estimate position in the surah by time ratio, then record a small
        # search window of candidate word indices around that estimate.
        t_mid = (s + e) / 2.0
        ratio = t_mid / total_sec
        est_idx = int(ratio * (len(canon_words) - 1))

        W = 6  # half-width of the candidate window, in words
        cand_range = range(max(0, est_idx - W), min(len(canon_words), est_idx + W + 1))

        # MVP: no word-audio prototypes to score against, so just pick the
        # estimated index and expose the candidate window for the next step.
        chosen = est_idx

        results.append({
            "segment_index": i,
            "timestamp": {"start": round(float(s), 3), "end": round(float(e), 3)},
            "estimated_word_index": est_idx,
            "candidate_word_indices": list(cand_range),
            "mapped_word": canon_words[chosen],
            "note": "MVP time-based alignment using WavLM frame mapping. Next step replaces this with phoneme/CTC alignment."
        })

    out = {
        "audio_path": AUDIO_PATH,
        "total_sec": round(float(total_sec), 3),
        "wavlm": {"model_id": MODEL_ID, "frames": int(frames), "fps_est": round(float(fps), 2)},
        "num_canonical_words": len(canon_words),
        "segments_used": len(results),
        "alignment": results
    }

    with open(OUT_PATH, "w", encoding="utf-8") as fh:  # context manager ensures flush/close
        json.dump(out, fh, ensure_ascii=False, indent=2)
    print("OK ✅ wrote", OUT_PATH)
    print("Segments aligned:", len(results))
    if results:
        print("Sample:", results[0])

# Script entry point: run the alignment pipeline when executed directly.
if __name__ == "__main__":
    main()