hetchyy commited on
Commit
e80c3ea
·
verified ·
1 Parent(s): 343f470

Delete bench_wraparound.py

Browse files
Files changed (1) hide show
  1. bench_wraparound.py +0 -273
bench_wraparound.py DELETED
@@ -1,273 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Benchmark: compare pure-Python align_wraparound vs Cython cy_align_wraparound.
4
-
5
- Loads 50 verses from the repetition test set, runs both implementations,
6
- verifies results match, and reports timing + speedup factor.
7
-
8
- Usage:
9
- python3 bench_wraparound.py
10
- """
11
-
12
- import json
13
- import sys
14
- import time
15
- from pathlib import Path
16
-
17
- # ---------------------------------------------------------------------------
18
- # Paths
19
- # ---------------------------------------------------------------------------
20
- SCRIPT_DIR = Path(__file__).parent # quranic_universal_aligner/
21
- REPO_ROOT = SCRIPT_DIR.parent # quranic-universal-audio/
22
- DATA_DIR = REPO_ROOT / "data"
23
-
24
- sys.path.insert(0, str(SCRIPT_DIR))
25
-
26
- # ---------------------------------------------------------------------------
27
- # Import Python implementation from test harness (without modifying it)
28
- # ---------------------------------------------------------------------------
29
- sys.path.insert(0, str(SCRIPT_DIR / "docs" / "repetition_detection"))
30
- from test_wraparound_dp import (
31
- align_wraparound as py_align_wraparound,
32
- build_ref_from_phonemizer,
33
- load_substitution_costs,
34
- COST_SUBSTITUTION, COST_DELETION, COST_INSERTION,
35
- WRAP_PENALTY, MAX_WRAPS,
36
- )
37
-
38
- # ---------------------------------------------------------------------------
39
- # Import Cython implementation
40
- # ---------------------------------------------------------------------------
41
- from src.alignment._dp_core import cy_align_wraparound, init_substitution_matrix
42
-
43
- # ---------------------------------------------------------------------------
44
- # Setup — defer init_substitution_matrix until all phonemes are registered
45
- # ---------------------------------------------------------------------------
46
- SUB_COSTS = load_substitution_costs()
47
- # NOTE: init_substitution_matrix is called in main() AFTER collecting all
48
- # unique phonemes from the test data. This avoids _grow_matrix() being
49
- # triggered during alignment, which would discard custom sub costs
50
- # (a known limitation of the current _grow_matrix implementation).
51
-
52
-
53
- def load_test_data():
54
- path = DATA_DIR / "repetition_test_set_base.json"
55
- with open(path) as f:
56
- return json.load(f)
57
-
58
-
59
- def main():
60
- N = 50 # number of verses to benchmark
61
- print(f"\n{'='*70}")
62
- print(f" Wraparound DP Benchmark: Python vs Cython")
63
- print(f" Verses: {N}")
64
- print(f"{'='*70}\n")
65
-
66
- # Load test data
67
- print("Loading test data...", end=" ", flush=True)
68
- test_data = load_test_data()
69
- print("done.")
70
-
71
- # Initialize phonemizer
72
- print("Initializing phonemizer...", end=" ", flush=True)
73
- from src.alignment.phonemizer_utils import get_phonemizer
74
- pm = get_phonemizer()
75
- print("done.\n")
76
-
77
- # Collect verse cases
78
- cases = []
79
- for reciter in [k for k in test_data if k != "_meta"]:
80
- for verse_key, verse_data in test_data[reciter].items():
81
- cases.append((reciter, verse_key, verse_data))
82
- if len(cases) >= N:
83
- break
84
- if len(cases) >= N:
85
- break
86
-
87
- # Prepare all inputs first (exclude phonemizer time from benchmark)
88
- print(f"Preparing {len(cases)} verse inputs...", end=" ", flush=True)
89
- prepared = []
90
- all_phonemes = set()
91
- for reciter, verse_key, verse_data in cases:
92
- surah, ayah = map(int, verse_key.split(":"))
93
- P = verse_data["asr_phonemes"].split()
94
- try:
95
- R, R_phone_to_word, _ = build_ref_from_phonemizer(pm, surah, ayah)
96
- except Exception as e:
97
- print(f"\n SKIP {reciter}/{verse_key}: {e}")
98
- continue
99
- if len(R) == 0:
100
- continue
101
- all_phonemes.update(P)
102
- all_phonemes.update(R)
103
- prepared.append({
104
- "reciter": reciter,
105
- "verse_key": verse_key,
106
- "P": P,
107
- "R": R,
108
- "R_phone_to_word": R_phone_to_word,
109
- "p_len": len(P),
110
- "r_len": len(R),
111
- "num_reps": verse_data["num_reps"],
112
- })
113
- print(f"done. ({len(prepared)} usable)")
114
-
115
- # Pre-register ALL phonemes in the substitution cost dict so that
116
- # _grow_matrix() is never triggered during alignment. This avoids
117
- # a known limitation where _grow_matrix discards custom sub costs.
118
- print(f"Initializing substitution matrix ({len(all_phonemes)} phonemes)...", end=" ", flush=True)
119
- augmented_costs = dict(SUB_COSTS)
120
- for ph in all_phonemes:
121
- # Add a self-pair entry so the phoneme gets an ID during init
122
- augmented_costs[(ph, ph)] = 0.0
123
- init_substitution_matrix(augmented_costs, COST_SUBSTITUTION)
124
- print("done.\n")
125
-
126
- # Common kwargs
127
- common = dict(
128
- expected_word=0,
129
- prior_weight=0.0,
130
- cost_sub=COST_SUBSTITUTION,
131
- cost_del=COST_DELETION,
132
- cost_ins=COST_INSERTION,
133
- wrap_penalty=WRAP_PENALTY,
134
- max_wraps=MAX_WRAPS,
135
- scoring_mode="subtract",
136
- wrap_score_cost=0.01,
137
- )
138
-
139
- # --- Warmup (1 run each) ---
140
- print("Warmup run...", end=" ", flush=True)
141
- d = prepared[0]
142
- py_align_wraparound(d["P"], d["R"], d["R_phone_to_word"], **common)
143
- # Cython version doesn't take scoring_mode/wrap_score_cost in common if default
144
- cy_align_wraparound(
145
- d["P"], d["R"], d["R_phone_to_word"],
146
- expected_word=0, prior_weight=0.0,
147
- cost_sub=COST_SUBSTITUTION, cost_del=COST_DELETION, cost_ins=COST_INSERTION,
148
- wrap_penalty=WRAP_PENALTY, max_wraps=MAX_WRAPS,
149
- scoring_mode="subtract", wrap_score_cost=0.01,
150
- )
151
- print("done.\n")
152
-
153
- # --- Benchmark Python ---
154
- print(f"Running Python align_wraparound on {len(prepared)} verses...")
155
- py_results = []
156
- t0 = time.perf_counter()
157
- for d in prepared:
158
- result = py_align_wraparound(d["P"], d["R"], d["R_phone_to_word"], **common)
159
- py_results.append(result)
160
- py_total = time.perf_counter() - t0
161
- print(f" Python total: {py_total*1000:.1f} ms ({py_total*1000/len(prepared):.1f} ms/verse)\n")
162
-
163
- # --- Benchmark Cython ---
164
- print(f"Running Cython cy_align_wraparound on {len(prepared)} verses...")
165
- cy_results = []
166
- t0 = time.perf_counter()
167
- for d in prepared:
168
- result = cy_align_wraparound(
169
- d["P"], d["R"], d["R_phone_to_word"],
170
- expected_word=0, prior_weight=0.0,
171
- cost_sub=COST_SUBSTITUTION, cost_del=COST_DELETION, cost_ins=COST_INSERTION,
172
- wrap_penalty=WRAP_PENALTY, max_wraps=MAX_WRAPS,
173
- scoring_mode="subtract", wrap_score_cost=0.01,
174
- )
175
- cy_results.append(result)
176
- cy_total = time.perf_counter() - t0
177
- print(f" Cython total: {cy_total*1000:.1f} ms ({cy_total*1000/len(prepared):.1f} ms/verse)\n")
178
-
179
- # --- Compare results ---
180
- print(f"{'='*70}")
181
- print(f" Verification: comparing Python vs Cython results")
182
- print(f"{'='*70}\n")
183
-
184
- mismatches = 0
185
- tol = 1e-6
186
-
187
- for i, (d, py_r, cy_r) in enumerate(zip(prepared, py_results, cy_results)):
188
- # Python returns 7-tuple: (best_j, best_j_start, best_cost, best_norm, n_wraps, max_j, wrap_points)
189
- # Cython returns 6-tuple: (best_j, best_j_start, best_cost, best_norm, n_wraps, max_j)
190
- py_j, py_js, py_cost, py_norm, py_k, py_mj, py_wp = py_r
191
- cy_j, cy_js, cy_cost, cy_norm, cy_k, cy_mj = cy_r
192
-
193
- match = True
194
- errors = []
195
-
196
- if py_j != cy_j:
197
- errors.append(f"best_j: py={py_j} cy={cy_j}")
198
- match = False
199
- if py_js != cy_js:
200
- errors.append(f"best_j_start: py={py_js} cy={cy_js}")
201
- match = False
202
- if py_j is not None and cy_j is not None:
203
- if abs(py_cost - cy_cost) > tol:
204
- errors.append(f"best_cost: py={py_cost:.6f} cy={cy_cost:.6f}")
205
- match = False
206
- if abs(py_norm - cy_norm) > tol:
207
- errors.append(f"best_norm: py={py_norm:.6f} cy={cy_norm:.6f}")
208
- match = False
209
- if py_k != cy_k:
210
- errors.append(f"n_wraps: py={py_k} cy={cy_k}")
211
- match = False
212
- # max_j comparison (Python uses max(max_j, j) for end_j; Cython returns raw max_j)
213
- if py_j is not None and cy_j is not None and py_mj != cy_mj:
214
- errors.append(f"max_j: py={py_mj} cy={cy_mj}")
215
- match = False
216
-
217
- if not match:
218
- mismatches += 1
219
- print(f" MISMATCH [{i}] {d['reciter']}/{d['verse_key']} "
220
- f"(P={d['p_len']}, R={d['r_len']}, reps={d['num_reps']})")
221
- for e in errors:
222
- print(f" {e}")
223
-
224
- # --- Summary ---
225
- print(f"\n{'='*70}")
226
- print(f" SUMMARY")
227
- print(f"{'='*70}")
228
- print(f" Verses benchmarked: {len(prepared)}")
229
- print(f" Python total: {py_total*1000:>8.1f} ms ({py_total*1000/len(prepared):>6.1f} ms/verse)")
230
- print(f" Cython total: {cy_total*1000:>8.1f} ms ({cy_total*1000/len(prepared):>6.1f} ms/verse)")
231
- speedup = py_total / cy_total if cy_total > 0 else float('inf')
232
- print(f" Speedup: {speedup:>8.1f}x")
233
- print(f" Mismatches: {mismatches}/{len(prepared)}")
234
- if mismatches == 0:
235
- print(f" Result: ALL MATCH")
236
- else:
237
- print(f" Result: {mismatches} MISMATCHES FOUND")
238
- print(f"{'='*70}\n")
239
-
240
- # Also test scoring modes
241
- print("Testing scoring modes (no_subtract, additive)...")
242
- for mode in ["no_subtract", "additive"]:
243
- d = prepared[0]
244
- py_r = py_align_wraparound(
245
- d["P"], d["R"], d["R_phone_to_word"],
246
- expected_word=0, prior_weight=0.0,
247
- cost_sub=COST_SUBSTITUTION, cost_del=COST_DELETION, cost_ins=COST_INSERTION,
248
- wrap_penalty=WRAP_PENALTY, max_wraps=MAX_WRAPS,
249
- scoring_mode=mode, wrap_score_cost=0.01,
250
- )
251
- cy_r = cy_align_wraparound(
252
- d["P"], d["R"], d["R_phone_to_word"],
253
- expected_word=0, prior_weight=0.0,
254
- cost_sub=COST_SUBSTITUTION, cost_del=COST_DELETION, cost_ins=COST_INSERTION,
255
- wrap_penalty=WRAP_PENALTY, max_wraps=MAX_WRAPS,
256
- scoring_mode=mode, wrap_score_cost=0.01,
257
- )
258
- py_j, py_js, py_cost, py_norm, py_k, py_mj, _ = py_r
259
- cy_j, cy_js, cy_cost, cy_norm, cy_k, cy_mj = cy_r
260
-
261
- ok = (py_j == cy_j and py_js == cy_js and py_k == cy_k)
262
- if py_j is not None and cy_j is not None:
263
- ok = ok and abs(py_cost - cy_cost) < tol and abs(py_norm - cy_norm) < tol
264
- status = "OK" if ok else "MISMATCH"
265
- print(f" {mode}: {status} "
266
- f"(py: j={py_j},js={py_js},cost={py_cost:.4f},norm={py_norm:.4f},k={py_k} | "
267
- f"cy: j={cy_j},js={cy_js},cost={cy_cost:.4f},norm={cy_norm:.4f},k={cy_k})")
268
-
269
- print()
270
-
271
-
272
- if __name__ == "__main__":
273
- main()