File size: 4,042 Bytes
1efa4be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""Benchmark: Python vs Rust address converter."""

import json
import os
import subprocess
import time
from pathlib import Path

# Ward-mapping JSON file, located under data/ next to this script.
DATA_PATH = Path(__file__).parent / "data" / "mapping.json"
# Release build of the Rust converter, located under rust/target/release/ next to this script.
RUST_BIN = Path(__file__).parent / "rust" / "target" / "release" / "address-converter"


def load_test_addresses(n: int | None = None) -> list[str]:
    """Build comma-separated test addresses from the ward records in mapping.json.

    Every record under the ``ward_mapping`` key yields one string of the form
    ``"<old_ward>, <old_district>, <old_province>"``.

    Args:
        n: When given, the result is sliced to ``[:n]``; ``None`` returns
            every address.

    Returns:
        The generated address strings, in file order.
    """
    with open(DATA_PATH, encoding="utf-8") as fh:
        mapping = json.load(fh)

    built = [
        f"{ward['old_ward']}, {ward['old_district']}, {ward['old_province']}"
        for ward in mapping["ward_mapping"]
    ]
    return built if n is None else built[:n]


def bench_python(addresses: list[str]) -> tuple[list[str], float]:
    """Time the Python converter over ``addresses``.

    Args:
        addresses: Address strings to convert.

    Returns:
        ``(results, elapsed_seconds)``. ``results`` holds the ``converted``
        attribute of each ``convert_address`` result, in input order.
        ``elapsed_seconds`` is measured with ``time.perf_counter`` around the
        conversion loop only. Empty input yields ``([], 0.0)``.
    """
    # Nothing to time; the warm-up below would fail on addresses[0].
    if not addresses:
        return [], 0.0

    from src.converter import convert_address

    # Warm-up call outside the timed region (presumably triggers one-time
    # data loading inside the converter -- confirm against src.converter).
    convert_address(addresses[0])

    start = time.perf_counter()
    results = [convert_address(addr).converted for addr in addresses]
    elapsed = time.perf_counter() - start
    return results, elapsed


def bench_rust(addresses: list[str]) -> tuple[list[str], float]:
    """Run the Rust binary in ``bench`` mode and collect its results and timing.

    The addresses are sent to the child's stdin, one per line, and its stdout
    is split into one result per line. The elapsed time is taken from the
    first stderr line that starts with ``BENCH:``.

    Args:
        addresses: Address strings to convert.

    Returns:
        ``(results, elapsed_seconds)``. ``elapsed_seconds`` is ``0.0`` when
        stderr contains no ``BENCH:`` line.

    Raises:
        FileNotFoundError: If ``RUST_BIN`` does not exist.
        RuntimeError: If the binary exits with a non-zero status.
    """
    if not RUST_BIN.exists():
        raise FileNotFoundError(
            f"Rust binary not found at {RUST_BIN}. Run: cd rust && cargo build --release"
        )

    input_data = "\n".join(addresses) + "\n"
    # Extend the inherited environment instead of replacing it: passing a
    # one-key mapping would strip PATH, SystemRoot, loader paths, etc. from
    # the child process.
    env = {**os.environ, "MAPPING_JSON": str(DATA_PATH)}

    proc = subprocess.run(
        [str(RUST_BIN), "bench"],
        input=input_data,
        capture_output=True,
        text=True,
        # Pin the pipe encoding so non-ASCII addresses do not depend on the
        # locale codec (assumes the binary reads and writes UTF-8, matching
        # how mapping.json is opened in this file -- confirm).
        encoding="utf-8",
        env=env,
    )

    if proc.returncode != 0:
        raise RuntimeError(f"Rust bench failed: {proc.stderr}")

    results = proc.stdout.strip().split("\n") if proc.stdout.strip() else []

    # Parse timing from stderr
    elapsed = 0.0
    for line in proc.stderr.strip().split("\n"):
        if line.startswith("BENCH:"):
            # "BENCH: 10602 addresses in 0.012345 s (1.234 us/addr)"
            # Whitespace-split token 4 is the elapsed-seconds figure.
            parts = line.split()
            elapsed = float(parts[4])
            break

    return results, elapsed


def verify_correctness(
    py_results: list[str], rs_results: list[str], addresses: list[str]
) -> bool:
    """Check that Rust output matches Python output.

    Results are compared position by position. The first ten differing pairs
    are printed together with their input address. Results present on only
    one side (the two lists differ in length) count as mismatches rather
    than being silently dropped by ``zip``.

    Args:
        py_results: Converted addresses from the Python implementation.
        rs_results: Converted addresses from the Rust implementation.
        addresses: Input addresses, indexed in parallel with the results.

    Returns:
        ``True`` only when both lists have the same length and every pair is
        equal. Two empty lists compare as equal.
    """
    max_shown = 10
    mismatches = 0
    for i, (py, rs) in enumerate(zip(py_results, rs_results)):
        if py != rs:
            mismatches += 1
            if mismatches <= max_shown:
                print(f"  MISMATCH [{i}] input: {addresses[i]}")
                print(f"    Python: {py}")
                print(f"    Rust:   {rs}")

    # zip() stops at the shorter list; account for the unpaired tail explicitly.
    unpaired = abs(len(py_results) - len(rs_results))
    if unpaired:
        print(
            f"  LENGTH MISMATCH: Python produced {len(py_results)} results, "
            f"Rust produced {len(rs_results)}"
        )

    total = max(len(py_results), len(rs_results))
    match = total - mismatches - unpaired
    # Avoid dividing by zero when there is nothing to compare.
    percent = 100 * match / total if total else 100.0
    print(f"\nCorrectness: {match}/{total} match ({percent:.1f}%)")
    if mismatches > max_shown:
        print(f"  ... and {mismatches - max_shown} more mismatches")
    return mismatches == 0 and unpaired == 0


def main() -> None:
    """Run both benchmarks, verify the outputs agree, and print a comparison.

    Loads every address from mapping.json, times the Python and Rust
    converters, compares their outputs, and prints a summary table with
    total time, per-address time, and the Python/Rust speedup. Exits early
    when no addresses were loaded, since the per-address figures divide by
    the address count.
    """
    print("Loading test addresses from mapping.json ...")
    addresses = load_test_addresses()
    n = len(addresses)
    print(f"  {n} addresses loaded\n")

    if n == 0:
        print("No addresses to benchmark; nothing to do.")
        return

    # Python benchmark
    print("Running Python benchmark ...")
    py_results, py_time = bench_python(addresses)
    py_us = py_time / n * 1e6
    print(f"  Python: {py_time:.4f} s  ({py_us:.1f} us/addr)\n")

    # Rust benchmark
    print("Running Rust benchmark ...")
    rs_results, rs_time = bench_rust(addresses)
    rs_us = rs_time / n * 1e6
    print(f"  Rust:   {rs_time:.4f} s  ({rs_us:.1f} us/addr)\n")

    # Correctness
    print("Verifying correctness ...")
    verify_correctness(py_results, rs_results, addresses)

    # Comparison; bench_rust reports 0.0 when it finds no timing line, hence the guard.
    speedup = py_time / rs_time if rs_time > 0 else float("inf")
    rule = "=" * 55
    print(f"\n{rule}")
    print(f"  {'':20s} {'Total (s)':>10s} {'Per-addr (us)':>14s}")
    print(f"  {'Python':20s} {py_time:>10.4f} {py_us:>14.1f}")
    print(f"  {'Rust':20s} {rs_time:>10.4f} {rs_us:>14.1f}")
    print(f"  {'Speedup':20s} {speedup:>10.1f}x")
    print(f"{rule}")


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()