File size: 3,629 Bytes
c75151e
 
 
 
752a595
 
c75151e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
752a595
 
 
 
 
 
c75151e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
752a595
 
c75151e
 
 
 
 
 
 
 
 
752a595
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from __future__ import annotations
import argparse
from data import load_statcast, default_window
from featurize import infer_ivb_sign, engineer_pitch_features

# ⬇️ NEW: import the updated API
from model import fit_kmeans, nearest_comps
from tags import xy_cluster_tags
from plots import movement_scatter_xy
from utils import ensure_dirs, ARTIFACTS_DIR
import plotly.io as pio


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the PitchXY pipeline."""
    parser = argparse.ArgumentParser(
        description="PitchXY: handedness-aware pitch archetypes"
    )
    parser.add_argument("--start", type=str, help="YYYY-MM-DD")
    parser.add_argument("--end", type=str, help="YYYY-MM-DD")
    parser.add_argument("-k", type=int, default=8)
    parser.add_argument(
        "--pitcher", type=str, help='Filter pitcher by name (e.g. "Cole")'
    )
    parser.add_argument(
        "--save-html", action="store_true", help="Save plots to artifacts/"
    )
    parser.add_argument(
        "--force", action="store_true", help="Force re-download Statcast"
    )
    return parser


# Columns shown in the per-pitcher scouting card.
_CARD_COLUMNS = [
    "pitch_type",
    "p_throws",
    "n",
    "velo",
    "ivb_in",
    "hb_as_in",
    "csw",
    "whiff_rate",
    "gb_rate",
    "zone_pct",
    "cluster_name",
]


def _print_scouting_card(df_fit, model, query: str) -> None:
    """Print the scouting card and nearest comps for the pitcher matching *query*.

    Matching is a case-insensitive literal substring search on ``player_name``.
    When several pitchers match, the first one (in frame order) is shown.

    Args:
        df_fit: Clustered pitch table with ``player_name`` and ``cluster_name``.
        model: Fitted model object, passed straight to ``nearest_comps``.
        query: Name fragment supplied via ``--pitcher``.
    """
    # regex=False: treat user input literally so names containing regex
    # metacharacters (".", "(", "*") neither raise nor silently mis-match.
    mask = df_fit["player_name"].str.contains(
        query, case=False, na=False, regex=False
    )
    sub = df_fit[mask]
    if sub.empty:
        print(f"No pitcher matched '{query}'")
        return

    # unique() keeps first-appearance order, so candidates[0] is the first match.
    candidates = sub["player_name"].unique()
    name = candidates[0]
    if len(candidates) > 1:
        print(f"'{query}' matched {len(candidates)} pitchers; showing {name}")

    df_p = df_fit[df_fit["player_name"] == name].sort_values("pitch_type")
    print(f"\n=== Scouting Card: {name} ===")
    print(df_p[_CARD_COLUMNS].to_string(index=False))
    for _, row in df_p.iterrows():
        # nearest_comps takes the fitted model object (not df_fit/scaler/nn).
        comps = nearest_comps(row, model, k=5, allow_cross_type=False)
        print(f"\nNearest comps — {row['pitch_type']} ({row['cluster_name']}):")
        print(comps.to_string(index=False))


def main():
    """Run the PitchXY pipeline from the command line.

    Loads Statcast data for the requested (or default) window, engineers pitch
    features, fits clusters, tags them with human-readable names, and writes
    both tables to ``ARTIFACTS_DIR`` as parquet. Optionally prints a scouting
    card with nearest comps (``--pitcher``) and saves an HTML movement plot
    (``--save-html``).
    """
    # The fitting call below uses fit_pitch_clusters, which the file-level
    # import does not bring in (it imports fit_kmeans). Importing it here makes
    # a missing name fail immediately instead of after the Statcast download.
    from model import fit_pitch_clusters

    parser = _build_parser()
    args = parser.parse_args()

    # Previously a half-specified window was silently replaced by the default.
    if bool(args.start) != bool(args.end):
        parser.error("--start and --end must be given together")

    ensure_dirs()
    if args.start and args.end:
        start, end = args.start, args.end
    else:
        start, end = default_window()
    print(f"Window: {start} to {end}")

    df_raw = load_statcast(start, end, force=args.force)
    ivb_sign = infer_ivb_sign(df_raw)
    print(f"IVB sign inferred = {ivb_sign} (ride should be positive)")

    df_feat = engineer_pitch_features(df_raw, ivb_sign)

    # Fit the clustering model; df_fit contains all original cols + 'cluster'.
    model = fit_pitch_clusters(df_feat, k=args.k)
    df_fit = model.df_fit

    # Tag clusters with human-readable names.
    cluster_names = xy_cluster_tags(df_fit)
    df_fit["cluster_name"] = df_fit["cluster"].map(cluster_names)

    # Save artifacts.
    feat_p = ARTIFACTS_DIR / "pitch_features.parquet"
    fit_p = ARTIFACTS_DIR / "pitch_features_clusters.parquet"
    df_feat.to_parquet(feat_p, index=False)
    df_fit.to_parquet(fit_p, index=False)
    print(f"Saved: {feat_p}, {fit_p}")

    # Optional pitcher card + comps.
    if args.pitcher:
        _print_scouting_card(df_fit, model, args.pitcher)

    # Movement plot: the figure is neither shown nor returned, so only build
    # it when it is going to be written to disk.
    if args.save_html:
        fig = movement_scatter_xy(df_fit, color="cluster_name")
        out = ARTIFACTS_DIR / "movement_all.html"
        pio.write_html(fig, file=str(out), auto_open=False, include_plotlyjs="cdn")
        print(f"Saved plot: {out}")


if __name__ == "__main__":
    main()