Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| from data import load_statcast, default_window | |
| from featurize import infer_ivb_sign, engineer_pitch_features | |
| # ⬇️ NEW: import the updated API | |
| from model import fit_kmeans, nearest_comps | |
| from tags import xy_cluster_tags | |
| from plots import movement_scatter_xy | |
| from utils import ensure_dirs, ARTIFACTS_DIR | |
| import plotly.io as pio | |
| def main(): | |
| parser = argparse.ArgumentParser( | |
| description="PitchXY: handedness-aware pitch archetypes" | |
| ) | |
| parser.add_argument("--start", type=str, help="YYYY-MM-DD") | |
| parser.add_argument("--end", type=str, help="YYYY-MM-DD") | |
| parser.add_argument("-k", type=int, default=8) | |
| parser.add_argument( | |
| "--pitcher", type=str, help='Filter pitcher by name (e.g. "Cole")' | |
| ) | |
| parser.add_argument( | |
| "--save-html", action="store_true", help="Save plots to artifacts/" | |
| ) | |
| parser.add_argument( | |
| "--force", action="store_true", help="Force re-download Statcast" | |
| ) | |
| args = parser.parse_args() | |
| ensure_dirs() | |
| start, end = ( | |
| (args.start, args.end) if (args.start and args.end) else default_window() | |
| ) | |
| print(f"Window: {start} → {end}") | |
| df_raw = load_statcast(start, end, force=args.force) | |
| ivb_sign = infer_ivb_sign(df_raw) | |
| print(f"IVB sign inferred = {ivb_sign} (ride should be positive)") | |
| df_feat = engineer_pitch_features(df_raw, ivb_sign) | |
| # ⬇️ NEW: fit the improved model | |
| model = fit_pitch_clusters(df_feat, k=args.k) | |
| df_fit = model.df_fit # contains all original cols + 'cluster' | |
| # Tag clusters with human-readable names | |
| cluster_names = xy_cluster_tags(df_fit) | |
| df_fit["cluster_name"] = df_fit["cluster"].map(cluster_names) | |
| # Save artifacts | |
| feat_p = ARTIFACTS_DIR / "pitch_features.parquet" | |
| fit_p = ARTIFACTS_DIR / "pitch_features_clusters.parquet" | |
| df_feat.to_parquet(feat_p, index=False) | |
| df_fit.to_parquet(fit_p, index=False) | |
| print(f"Saved: {feat_p}, {fit_p}") | |
| # Optional pitcher card + comps | |
| if args.pitcher: | |
| sub = df_fit[ | |
| df_fit["player_name"].str.contains(args.pitcher, case=False, na=False) | |
| ] | |
| if sub.empty: | |
| print(f"No pitcher matched '{args.pitcher}'") | |
| else: | |
| name = sub["player_name"].iloc[0] | |
| df_p = df_fit[df_fit["player_name"] == name].sort_values("pitch_type") | |
| print(f"\n=== Scouting Card: {name} ===") | |
| print( | |
| df_p[ | |
| [ | |
| "pitch_type", | |
| "p_throws", | |
| "n", | |
| "velo", | |
| "ivb_in", | |
| "hb_as_in", | |
| "csw", | |
| "whiff_rate", | |
| "gb_rate", | |
| "zone_pct", | |
| "cluster_name", | |
| ] | |
| ].to_string(index=False) | |
| ) | |
| for _, row in df_p.iterrows(): | |
| # ⬇️ UPDATED: pass the model, not (df_fit, scaler, nn) | |
| comps = nearest_comps(row, model, k=5, allow_cross_type=False) | |
| print(f"\nNearest comps — {row['pitch_type']} ({row['cluster_name']}):") | |
| print(comps.to_string(index=False)) | |
| # Movement plot | |
| fig = movement_scatter_xy(df_fit, color="cluster_name") | |
| if args.save_html: | |
| out = ARTIFACTS_DIR / "movement_all.html" | |
| pio.write_html(fig, file=str(out), auto_open=False, include_plotlyjs="cdn") | |
| print(f"Saved plot: {out}") | |
| if __name__ == "__main__": | |
| main() | |