address / benchmark.py
rain1024's picture
Add Rust address converter and use underthesea for normalization
1efa4be
"""Benchmark: Python vs Rust address converter."""
import json
import subprocess
import time
from pathlib import Path
# JSON mapping file next to this script; read by load_test_addresses() and
# handed to the Rust binary through the MAPPING_JSON environment variable.
DATA_PATH = Path(__file__).parent / "data" / "mapping.json"
# Release build of the Rust converter (produced by `cd rust && cargo build --release`).
RUST_BIN = Path(__file__).parent / "rust" / "target" / "release" / "address-converter"
def load_test_addresses(n: int | None = None) -> list[str]:
    """Build test addresses from the ward records in mapping.json.

    Every record under the ``ward_mapping`` key becomes one string of the
    form ``"<old_ward>, <old_district>, <old_province>"``, in file order.
    When ``n`` is not ``None`` the list is cut down with ``[:n]``.
    """
    with open(DATA_PATH, encoding="utf-8") as fh:
        mapping = json.load(fh)

    formatted = [
        f"{ward['old_ward']}, {ward['old_district']}, {ward['old_province']}"
        for ward in mapping["ward_mapping"]
    ]
    return formatted if n is None else formatted[:n]
def bench_python(addresses: list[str]) -> tuple[list[str], float]:
    """Time the Python converter over ``addresses``.

    Returns a ``(results, elapsed_seconds)`` pair; ``results`` holds the
    ``converted`` attribute of each conversion, in input order.
    """
    from src.converter import convert_address

    # One untimed call up front so one-time setup (presumably lazy data
    # loading inside the converter) does not end up in the measurement.
    convert_address(addresses[0])

    t_begin = time.perf_counter()
    converted = [convert_address(address).converted for address in addresses]
    t_end = time.perf_counter()
    return converted, t_end - t_begin
def bench_rust(addresses: list[str]) -> tuple[list[str], float]:
    """Benchmark Rust implementation, return (results, elapsed_seconds).

    Runs ``RUST_BIN bench`` as a subprocess, writing the addresses to its
    stdin one per line. Converted addresses are read back from stdout, one
    per line; the elapsed time is taken from the first stderr line that
    starts with ``BENCH:``. If no such line is present, the elapsed time is
    reported as 0.0.

    Raises:
        FileNotFoundError: The Rust binary does not exist at ``RUST_BIN``.
        RuntimeError: The binary exited with a non-zero return code.
    """
    import os

    if not RUST_BIN.exists():
        raise FileNotFoundError(
            f"Rust binary not found at {RUST_BIN}. Run: cd rust && cargo build --release"
        )

    input_data = "\n".join(addresses) + "\n"
    # Extend the inherited environment instead of replacing it: passing a
    # one-entry dict as env= would strip PATH, SystemRoot, LD_LIBRARY_PATH,
    # etc. from the child process.
    env = {**os.environ, "MAPPING_JSON": str(DATA_PATH)}
    proc = subprocess.run(
        [str(RUST_BIN), "bench"],
        input=input_data,
        capture_output=True,
        text=True,
        # Pin the pipe encoding; the locale default is not UTF-8 everywhere
        # and the addresses contain non-ASCII text.
        encoding="utf-8",
        env=env,
    )
    if proc.returncode != 0:
        raise RuntimeError(f"Rust bench failed: {proc.stderr}")

    # Split on newlines and drop only the empty tail left by the final line
    # terminator. Stripping the whole output would also discard empty results
    # at the start or end and misalign the list with the inputs.
    results = proc.stdout.split("\n")
    if results[-1] == "":
        results.pop()

    # Parse timing from stderr
    elapsed = 0.0
    for line in proc.stderr.strip().split("\n"):
        if line.startswith("BENCH:"):
            # "BENCH: 10602 addresses in 0.012345 s (1.234 us/addr)"
            parts = line.split()
            elapsed = float(parts[4])
            break
    return results, elapsed
def verify_correctness(py_results: list[str], rs_results: list[str], addresses: list[str]) -> bool:
    """Check that Rust output matches Python output.

    Results are compared position by position. A position present in only
    one of the two lists counts as a mismatch, so a truncated or over-long
    Rust output can never be reported as a full match. The first 10
    mismatches are printed, followed by a summary line.

    Returns:
        True when every position matches (also when both lists are empty),
        False otherwise.
    """
    max_shown = 10
    n_py, n_rs = len(py_results), len(rs_results)
    # Use the longer list as the denominator so missing/extra entries count.
    total = max(n_py, n_rs)

    mismatches = 0
    for i in range(total):
        py = py_results[i] if i < n_py else None
        rs = rs_results[i] if i < n_rs else None
        if py == rs:
            continue
        mismatches += 1
        if mismatches <= max_shown:
            addr = addresses[i] if i < len(addresses) else "<unknown>"
            py_shown = "<missing>" if py is None else py
            rs_shown = "<missing>" if rs is None else rs
            print(f" MISMATCH [{i}] input: {addr}")
            print(f" Python: {py_shown}")
            print(f" Rust: {rs_shown}")

    match = total - mismatches
    # With nothing to compare, report a vacuous 100% instead of dividing by zero.
    percent = 100 * match / total if total else 100.0
    print(f"\nCorrectness: {match}/{total} match ({percent:.1f}%)")
    if mismatches > max_shown:
        print(f" ... and {mismatches - max_shown} more mismatches")
    return mismatches == 0
def main():
    """Run both benchmarks over every mapping.json address and print a comparison."""
    print("Loading test addresses from mapping.json ...")
    addresses = load_test_addresses()
    count = len(addresses)
    print(f" {count} addresses loaded\n")

    # Python side
    print("Running Python benchmark ...")
    py_results, py_time = bench_python(addresses)
    py_us = py_time / count * 1e6
    print(f" Python: {py_time:.4f} s ({py_us:.1f} us/addr)\n")

    # Rust side
    print("Running Rust benchmark ...")
    rs_results, rs_time = bench_rust(addresses)
    rs_us = rs_time / count * 1e6
    print(f" Rust: {rs_time:.4f} s ({rs_us:.1f} us/addr)\n")

    # Output agreement between the two implementations
    print("Verifying correctness ...")
    verify_correctness(py_results, rs_results, addresses)

    # Summary table; a zero Rust time is shown as an infinite speedup.
    if rs_time > 0:
        speedup = py_time / rs_time
    else:
        speedup = float("inf")

    rule = "=" * 55
    print(f"\n{rule}")
    print(f" {'':20s} {'Total (s)':>10s} {'Per-addr (us)':>14s}")
    for label, total_s, per_addr_us in (
        ("Python", py_time, py_us),
        ("Rust", rs_time, rs_us),
    ):
        print(f" {label:20s} {total_s:>10.4f} {per_addr_us:>14.1f}")
    print(f" {'Speedup':20s} {speedup:>10.1f}x")
    print(f"{rule}")
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()