Spaces:
Running on Zero
Running on Zero
Delete bench_wraparound.py
Browse files- bench_wraparound.py +0 -273
bench_wraparound.py
DELETED
|
@@ -1,273 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Benchmark: compare pure-Python align_wraparound vs Cython cy_align_wraparound.
|
| 4 |
-
|
| 5 |
-
Loads 50 verses from the repetition test set, runs both implementations,
|
| 6 |
-
verifies results match, and reports timing + speedup factor.
|
| 7 |
-
|
| 8 |
-
Usage:
|
| 9 |
-
python3 bench_wraparound.py
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
import json
|
| 13 |
-
import sys
|
| 14 |
-
import time
|
| 15 |
-
from pathlib import Path
|
| 16 |
-
|
| 17 |
-
# ---------------------------------------------------------------------------
|
| 18 |
-
# Paths
|
| 19 |
-
# ---------------------------------------------------------------------------
|
| 20 |
-
# Anchor every path on the absolute location of this script. Without
# resolve(), a relative __file__ (e.g. "bench_wraparound.py", as seen when a
# script is launched on Python < 3.9) makes Path(__file__).parent.parent
# collapse to "." and REPO_ROOT / DATA_DIR end up pointing at the wrong place.
SCRIPT_DIR = Path(__file__).resolve().parent  # quranic_universal_aligner/
REPO_ROOT = SCRIPT_DIR.parent                 # quranic-universal-audio/
DATA_DIR = REPO_ROOT / "data"

# Put the script directory first on sys.path so the `src.*` imports used by
# this script resolve (assumes the `src` package sits next to this file).
sys.path.insert(0, str(SCRIPT_DIR))
|
| 25 |
-
|
| 26 |
-
# ---------------------------------------------------------------------------
|
| 27 |
-
# Import Python implementation from test harness (without modifying it)
|
| 28 |
-
# ---------------------------------------------------------------------------
|
| 29 |
-
sys.path.insert(0, str(SCRIPT_DIR / "docs" / "repetition_detection"))
|
| 30 |
-
from test_wraparound_dp import (
|
| 31 |
-
align_wraparound as py_align_wraparound,
|
| 32 |
-
build_ref_from_phonemizer,
|
| 33 |
-
load_substitution_costs,
|
| 34 |
-
COST_SUBSTITUTION, COST_DELETION, COST_INSERTION,
|
| 35 |
-
WRAP_PENALTY, MAX_WRAPS,
|
| 36 |
-
)
|
| 37 |
-
|
| 38 |
-
# ---------------------------------------------------------------------------
|
| 39 |
-
# Import Cython implementation
|
| 40 |
-
# ---------------------------------------------------------------------------
|
| 41 |
-
from src.alignment._dp_core import cy_align_wraparound, init_substitution_matrix
|
| 42 |
-
|
| 43 |
-
# ---------------------------------------------------------------------------
|
| 44 |
-
# Setup — defer init_substitution_matrix until all phonemes are registered
|
| 45 |
-
# ---------------------------------------------------------------------------
|
| 46 |
-
# Substitution-cost table, loaded once at import time via the test harness
# helper. main() copies it, adds a zero-cost self-pair for every phoneme seen
# in the test data, and passes that copy to init_substitution_matrix().
SUB_COSTS = load_substitution_costs()
|
| 47 |
-
# NOTE: init_substitution_matrix is called in main() AFTER collecting all
|
| 48 |
-
# unique phonemes from the test data. This avoids _grow_matrix() being
|
| 49 |
-
# triggered during alignment, which would discard custom sub costs
|
| 50 |
-
# (a known limitation of the current _grow_matrix implementation).
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def load_test_data():
    """Load the repetition test set used as benchmark input.

    Reads ``repetition_test_set_base.json`` from ``DATA_DIR`` and returns the
    parsed JSON document as-is.

    Raises:
        OSError: if the file cannot be opened (e.g. it does not exist).
        json.JSONDecodeError: if the file is not valid JSON.
    """
    path = DATA_DIR / "repetition_test_set_base.json"
    # Decode explicitly as UTF-8: relying on the locale's default encoding
    # can corrupt or reject non-ASCII phoneme text on some platforms.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def _select_cases(test_data, limit):
    """Collect (reciter, verse_key, verse_data) tuples until `limit` is reached.

    The top-level "_meta" entry of the test set is not a reciter and is skipped.
    """
    cases = []
    for reciter, verses in test_data.items():
        if reciter == "_meta":
            continue
        for verse_key, verse_data in verses.items():
            cases.append((reciter, verse_key, verse_data))
            if len(cases) >= limit:
                return cases
    return cases


def _prepare_inputs(pm, cases):
    """Build alignment inputs for each case; return (prepared, all_phonemes).

    Cases whose reference cannot be built, or whose reference is empty, are
    skipped. `all_phonemes` is the set of every phoneme appearing in either
    the ASR sequence or the reference of a usable case.
    """
    prepared = []
    all_phonemes = set()
    for reciter, verse_key, verse_data in cases:
        surah, ayah = map(int, verse_key.split(":"))
        P = verse_data["asr_phonemes"].split()
        try:
            R, R_phone_to_word, _ = build_ref_from_phonemizer(pm, surah, ayah)
        except Exception as e:
            # Deliberately best-effort: one bad verse must not abort the run.
            print(f"\n SKIP {reciter}/{verse_key}: {e}")
            continue
        if len(R) == 0:
            continue
        all_phonemes.update(P)
        all_phonemes.update(R)
        prepared.append({
            "reciter": reciter,
            "verse_key": verse_key,
            "P": P,
            "R": R,
            "R_phone_to_word": R_phone_to_word,
            "p_len": len(P),
            "r_len": len(R),
            "num_reps": verse_data["num_reps"],
        })
    return prepared, all_phonemes


def _run_timed(align, prepared, kwargs):
    """Run `align` over every prepared verse; return (results, elapsed_seconds)."""
    results = []
    t0 = time.perf_counter()
    for d in prepared:
        results.append(align(d["P"], d["R"], d["R_phone_to_word"], **kwargs))
    return results, time.perf_counter() - t0


def _diff_results(py_r, cy_r, tol, check_max_j=True):
    """Return a list of human-readable differences between two results.

    Python returns a 7-tuple:
        (best_j, best_j_start, best_cost, best_norm, n_wraps, max_j, wrap_points)
    Cython returns a 6-tuple:
        (best_j, best_j_start, best_cost, best_norm, n_wraps, max_j)
    An empty list means the results agree.
    """
    py_j, py_js, py_cost, py_norm, py_k, py_mj, _wrap_points = py_r
    cy_j, cy_js, cy_cost, cy_norm, cy_k, cy_mj = cy_r
    # Cost, norm and max_j are only compared when both sides found an alignment.
    both_found = py_j is not None and cy_j is not None

    errors = []
    if py_j != cy_j:
        errors.append(f"best_j: py={py_j} cy={cy_j}")
    if py_js != cy_js:
        errors.append(f"best_j_start: py={py_js} cy={cy_js}")
    if both_found:
        if abs(py_cost - cy_cost) > tol:
            errors.append(f"best_cost: py={py_cost:.6f} cy={cy_cost:.6f}")
        if abs(py_norm - cy_norm) > tol:
            errors.append(f"best_norm: py={py_norm:.6f} cy={cy_norm:.6f}")
    if py_k != cy_k:
        errors.append(f"n_wraps: py={py_k} cy={cy_k}")
    # max_j comparison (Python uses max(max_j, j) for end_j; Cython returns raw max_j)
    if check_max_j and both_found and py_mj != cy_mj:
        errors.append(f"max_j: py={py_mj} cy={cy_mj}")
    return errors


def _describe_result(result):
    """Format the first five fields of a result tuple for the scoring-mode report."""
    j, js, cost, norm, k = result[:5]

    def fmt(value):
        # NOTE(review): cost/norm are presumably None when no alignment is
        # found (best_j is None); ":.4f" would raise on None, so fall back.
        return "None" if value is None else f"{value:.4f}"

    return f"j={j},js={js},cost={fmt(cost)},norm={fmt(norm)},k={k}"


def main():
    """Benchmark the Python and Cython wraparound aligners and verify they agree.

    Loads up to N verses from the test set, prepares inputs (so phonemizer
    time is excluded from the timings), initializes the Cython substitution
    matrix, times both implementations over the same inputs, compares their
    results field by field, and finally spot-checks the "no_subtract" and
    "additive" scoring modes on the first prepared verse.

    Returns:
        int: number of failed checks (mismatching verses plus mismatching
        scoring modes); 0 when everything matches.

    Raises:
        SystemExit: if no usable verse could be prepared.
    """
    N = 50  # number of verses to benchmark
    rule = "=" * 70
    print(f"\n{rule}")
    print(" Wraparound DP Benchmark: Python vs Cython")
    print(f" Verses: {N}")
    print(f"{rule}\n")

    # Load test data
    print("Loading test data...", end=" ", flush=True)
    test_data = load_test_data()
    print("done.")

    # Initialize phonemizer (imported lazily, after the progress message)
    print("Initializing phonemizer...", end=" ", flush=True)
    from src.alignment.phonemizer_utils import get_phonemizer
    pm = get_phonemizer()
    print("done.\n")

    cases = _select_cases(test_data, N)

    # Prepare all inputs first (exclude phonemizer time from benchmark)
    print(f"Preparing {len(cases)} verse inputs...", end=" ", flush=True)
    prepared, all_phonemes = _prepare_inputs(pm, cases)
    print(f"done. ({len(prepared)} usable)")
    if not prepared:
        # Everything below indexes prepared[0] and divides by len(prepared).
        raise SystemExit("No usable verses to benchmark; aborting.")
    n_verses = len(prepared)

    # Pre-register ALL phonemes in the substitution cost dict so that
    # _grow_matrix() is never triggered during alignment. This avoids
    # a known limitation where _grow_matrix discards custom sub costs.
    print(f"Initializing substitution matrix ({len(all_phonemes)} phonemes)...", end=" ", flush=True)
    augmented_costs = dict(SUB_COSTS)
    for ph in all_phonemes:
        # Add a self-pair entry so the phoneme gets an ID during init
        augmented_costs[(ph, ph)] = 0.0
    init_substitution_matrix(augmented_costs, COST_SUBSTITUTION)
    print("done.\n")

    # Keyword arguments shared by both implementations
    common = dict(
        expected_word=0,
        prior_weight=0.0,
        cost_sub=COST_SUBSTITUTION,
        cost_del=COST_DELETION,
        cost_ins=COST_INSERTION,
        wrap_penalty=WRAP_PENALTY,
        max_wraps=MAX_WRAPS,
        scoring_mode="subtract",
        wrap_score_cost=0.01,
    )

    # --- Warmup (1 run each) ---
    print("Warmup run...", end=" ", flush=True)
    first = prepared[0]
    for align in (py_align_wraparound, cy_align_wraparound):
        align(first["P"], first["R"], first["R_phone_to_word"], **common)
    print("done.\n")

    # --- Benchmark Python ---
    print(f"Running Python align_wraparound on {n_verses} verses...")
    py_results, py_total = _run_timed(py_align_wraparound, prepared, common)
    print(f" Python total: {py_total*1000:.1f} ms ({py_total*1000/n_verses:.1f} ms/verse)\n")

    # --- Benchmark Cython ---
    print(f"Running Cython cy_align_wraparound on {n_verses} verses...")
    cy_results, cy_total = _run_timed(cy_align_wraparound, prepared, common)
    print(f" Cython total: {cy_total*1000:.1f} ms ({cy_total*1000/n_verses:.1f} ms/verse)\n")

    # --- Compare results ---
    print(rule)
    print(" Verification: comparing Python vs Cython results")
    print(f"{rule}\n")

    mismatches = 0
    tol = 1e-6

    for i, (d, py_r, cy_r) in enumerate(zip(prepared, py_results, cy_results)):
        errors = _diff_results(py_r, cy_r, tol)
        if not errors:
            continue
        mismatches += 1
        print(f" MISMATCH [{i}] {d['reciter']}/{d['verse_key']} "
              f"(P={d['p_len']}, R={d['r_len']}, reps={d['num_reps']})")
        for e in errors:
            print(f" {e}")

    # --- Summary ---
    print(f"\n{rule}")
    print(" SUMMARY")
    print(rule)
    print(f" Verses benchmarked: {n_verses}")
    print(f" Python total: {py_total*1000:>8.1f} ms ({py_total*1000/n_verses:>6.1f} ms/verse)")
    print(f" Cython total: {cy_total*1000:>8.1f} ms ({cy_total*1000/n_verses:>6.1f} ms/verse)")
    speedup = py_total / cy_total if cy_total > 0 else float('inf')
    print(f" Speedup: {speedup:>8.1f}x")
    print(f" Mismatches: {mismatches}/{n_verses}")
    if mismatches == 0:
        print(" Result: ALL MATCH")
    else:
        print(f" Result: {mismatches} MISMATCHES FOUND")
    print(f"{rule}\n")

    # Also test scoring modes (first verse only; max_j is not compared here)
    print("Testing scoring modes (no_subtract, additive)...")
    mode_failures = 0
    for mode in ["no_subtract", "additive"]:
        kwargs = dict(common, scoring_mode=mode)
        py_r = py_align_wraparound(first["P"], first["R"], first["R_phone_to_word"], **kwargs)
        cy_r = cy_align_wraparound(first["P"], first["R"], first["R_phone_to_word"], **kwargs)

        ok = not _diff_results(py_r, cy_r, tol, check_max_j=False)
        if not ok:
            mode_failures += 1
        status = "OK" if ok else "MISMATCH"
        print(f" {mode}: {status} "
              f"(py: {_describe_result(py_r)} | "
              f"cy: {_describe_result(cy_r)})")

    print()
    return mismatches + mode_failures
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
if __name__ == "__main__":
    # Reflect the outcome in the process exit status so shell scripts and CI
    # can detect a failed verification: any truthy value returned by main()
    # maps to exit status 1, while a falsy or None return maps to 0.
    sys.exit(1 if main() else 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|