# pitch_dash/bin/cli.py
# rsm-roguchi
# update
# 752a595
from __future__ import annotations

import argparse

import plotly.io as pio

from data import load_statcast, default_window
from featurize import infer_ivb_sign, engineer_pitch_features
# main() fits clusters through fit_pitch_clusters; fit_kmeans is kept for
# backward compatibility with any code that still references it.
from model import fit_kmeans, fit_pitch_clusters, nearest_comps
from tags import xy_cluster_tags
from plots import movement_scatter_xy
from utils import ensure_dirs, ARTIFACTS_DIR
def main():
    """Run the PitchXY command-line pipeline end to end.

    Steps, in order:
      1. Parse the command-line flags (--start/--end, -k, --pitcher,
         --save-html, --force).
      2. Load Statcast data for the requested window (or the default window
         when no complete window is supplied).
      3. Infer the IVB sign, engineer per-pitch features, and fit the
         cluster model with ``k`` clusters.
      4. Attach a human-readable ``cluster_name`` to every fitted row and
         write both feature tables to ``ARTIFACTS_DIR`` as parquet files.
      5. If ``--pitcher`` is given, print a scouting card for the first
         matching pitcher plus the nearest comps for each of their rows.
      6. Build the movement scatter plot and, with ``--save-html``, write it
         to ``ARTIFACTS_DIR / "movement_all.html"``.

    Returns:
        None. All output goes to stdout and to files under ``ARTIFACTS_DIR``.
    """
    parser = argparse.ArgumentParser(
        description="PitchXY: handedness-aware pitch archetypes"
    )
    parser.add_argument("--start", type=str, help="YYYY-MM-DD")
    parser.add_argument("--end", type=str, help="YYYY-MM-DD")
    parser.add_argument("-k", type=int, default=8)
    parser.add_argument(
        "--pitcher", type=str, help='Filter pitcher by name (e.g. "Cole")'
    )
    parser.add_argument(
        "--save-html", action="store_true", help="Save plots to artifacts/"
    )
    parser.add_argument(
        "--force", action="store_true", help="Force re-download Statcast"
    )
    args = parser.parse_args()

    ensure_dirs()

    if args.start and args.end:
        start, end = args.start, args.end
    else:
        if args.start or args.end:
            # A lone --start/--end used to be dropped without any notice;
            # keep the fallback but tell the user their flag was ignored.
            print(
                "Warning: --start and --end must be given together; "
                "using the default window."
            )
        start, end = default_window()
    # The two dates were previously printed back to back with no separator.
    print(f"Window: {start} to {end}")

    df_raw = load_statcast(start, end, force=args.force)
    ivb_sign = infer_ivb_sign(df_raw)
    print(f"IVB sign inferred = {ivb_sign} (ride should be positive)")
    df_feat = engineer_pitch_features(df_raw, ivb_sign)

    # Fit the cluster model; model.df_fit carries the original columns plus
    # a 'cluster' column.
    model = fit_pitch_clusters(df_feat, k=args.k)
    df_fit = model.df_fit

    # Tag clusters with human-readable names.
    cluster_names = xy_cluster_tags(df_fit)
    df_fit["cluster_name"] = df_fit["cluster"].map(cluster_names)

    # Save artifacts.
    feat_p = ARTIFACTS_DIR / "pitch_features.parquet"
    fit_p = ARTIFACTS_DIR / "pitch_features_clusters.parquet"
    df_feat.to_parquet(feat_p, index=False)
    df_fit.to_parquet(fit_p, index=False)
    print(f"Saved: {feat_p}, {fit_p}")

    # Optional pitcher card + comps.
    if args.pitcher:
        # regex=False: the flag is a plain name fragment, so characters such
        # as "(" or "." must match literally instead of being parsed as a
        # regular expression (which could raise or match unrelated names).
        sub = df_fit[
            df_fit["player_name"].str.contains(
                args.pitcher, case=False, na=False, regex=False
            )
        ]
        if sub.empty:
            print(f"No pitcher matched '{args.pitcher}'")
        else:
            # When several pitchers match, only the first one is reported.
            name = sub["player_name"].iloc[0]
            df_p = df_fit[df_fit["player_name"] == name].sort_values("pitch_type")
            card_cols = [
                "pitch_type",
                "p_throws",
                "n",
                "velo",
                "ivb_in",
                "hb_as_in",
                "csw",
                "whiff_rate",
                "gb_rate",
                "zone_pct",
                "cluster_name",
            ]
            print(f"\n=== Scouting Card: {name} ===")
            print(df_p[card_cols].to_string(index=False))
            for _, row in df_p.iterrows():
                # nearest_comps takes the fitted model object directly.
                comps = nearest_comps(row, model, k=5, allow_cross_type=False)
                print(
                    f"\nNearest comps — {row['pitch_type']} ({row['cluster_name']}):"
                )
                print(comps.to_string(index=False))

    # Movement plot.
    fig = movement_scatter_xy(df_fit, color="cluster_name")
    if args.save_html:
        out = ARTIFACTS_DIR / "movement_all.html"
        pio.write_html(fig, file=str(out), auto_open=False, include_plotlyjs="cdn")
        print(f"Saved plot: {out}")
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()