rsm-roguchi commited on
Commit
752a595
·
1 Parent(s): cb73dd6
Files changed (6) hide show
  1. app.py +7 -32
  2. bin/cli.py +14 -4
  3. pyproject.toml +1 -0
  4. requirements.txt +2 -1
  5. src/model.py +205 -16
  6. src/tags.py +88 -114
app.py CHANGED
@@ -2,16 +2,16 @@
2
  import os, sys
3
  from datetime import datetime
4
 
5
- # Ensure we can import from ./src even on HF Spaces
6
  BASE_DIR = os.path.dirname(__file__)
7
  sys.path.append(os.path.join(BASE_DIR, "src"))
8
 
9
  import streamlit as st
10
  import pandas as pd
11
 
12
- # Your local modules
13
  from data import load_statcast, default_window
14
  from featurize import infer_ivb_sign, engineer_pitch_features
 
 
15
  from model import fit_kmeans, nearest_comps
16
  from tags import xy_cluster_tags
17
  from plots import movement_scatter_xy, radar_quality
@@ -26,38 +26,22 @@ except Exception:
26
  st.set_page_config(page_title="PitchXY (Handedness-Aware)", layout="wide")
27
  st.title("⚾ PitchXY — Handedness-Aware Pitch Archetypes & Scouting Cards")
28
 
29
- # ---- Helpers
30
-
31
 
32
  @st.cache_data(show_spinner=False, ttl=24 * 3600)
33
  def load_statcast_cached(start: str, end: str, force: bool = False) -> pd.DataFrame:
34
- """
35
- Cached wrapper around your loader. On Spaces, expensive network calls during
36
- app init are the #1 cause of infinite 'Starting...'. This keeps it fast.
37
- """
38
  return load_statcast(start, end, force=force)
39
 
40
 
41
  @st.cache_data(show_spinner=False)
42
  def load_sample_fallback() -> pd.DataFrame:
43
- """
44
- Optional: fallback sample data so the app is usable even if MLB/Statcast
45
- endpoints are rate limited / blocked in Spaces.
46
- - Put a small parquet or CSV in your Space repo: data/sample_statcast.parquet
47
- - Or host it under a HF Dataset repo and set SAMPLE_DATA_REPO, SAMPLE_DATA_FILE.
48
- """
49
  local_path = os.path.join(BASE_DIR, "data", "sample_statcast.parquet")
50
  if os.path.exists(local_path):
51
  return pd.read_parquet(local_path)
52
-
53
- # If not bundled locally, try HF Hub (if available)
54
  repo_id = os.getenv("SAMPLE_DATA_REPO", "").strip()
55
  file_name = os.getenv("SAMPLE_DATA_FILE", "sample_statcast.parquet").strip()
56
  if HF_HUB_OK and repo_id:
57
  path = hf_hub_download(repo_id=repo_id, filename=file_name, repo_type="dataset")
58
  return pd.read_parquet(path)
59
-
60
- # Give a tiny empty frame with expected columns to keep UI alive
61
  return pd.DataFrame(
62
  columns=[
63
  "game_date",
@@ -81,12 +65,8 @@ def load_sample_fallback() -> pd.DataFrame:
81
 
82
 
83
  def safe_load_data(start: str, end: str, force: bool) -> pd.DataFrame:
84
- """
85
- Try cached real data first; if it errors or returns empty, fall back to a sample.
86
- """
87
  try:
88
  df = load_statcast_cached(start, end, force)
89
- # Basic sanity check – empty windows are common; handle gracefully
90
  if df is not None and not df.empty:
91
  return df
92
  st.info("No live data returned for that window — showing sample data instead.")
@@ -95,19 +75,15 @@ def safe_load_data(start: str, end: str, force: bool) -> pd.DataFrame:
95
  return load_sample_fallback()
96
 
97
 
98
- # ---- Sidebar
99
-
100
  with st.sidebar:
101
  st.header("Data Window")
102
  dstart, dend = default_window()
103
  start = st.text_input("Start YYYY-MM-DD", dstart)
104
  end = st.text_input("End YYYY-MM-DD", dend)
105
- k = st.slider("Clusters (k)", 5, 12, 8)
106
  force = st.checkbox("Force re-download (discouraged on Spaces)", value=False)
107
  st.caption("Tip: avoid 'Force re-download' on Spaces to keep startup snappy.")
108
 
109
- # ---- Data pipeline
110
-
111
  with st.spinner("Loading data…"):
112
  df_raw = safe_load_data(start, end, force)
113
 
@@ -120,7 +96,6 @@ if df_raw.empty:
120
  st.stop()
121
 
122
 
123
- # Feature engineering (cache stable steps)
124
  @st.cache_data(show_spinner=False)
125
  def _featurize(df_raw_in: pd.DataFrame):
126
  ivb_sign = infer_ivb_sign(df_raw_in)
@@ -131,9 +106,11 @@ def _featurize(df_raw_in: pd.DataFrame):
131
  df_feat = _featurize(df_raw)
132
 
133
 
134
- @st.cache_data(show_spinner=False)
 
135
  def _fit_model(df_feat_in: pd.DataFrame, k_val: int):
136
  df_fit_local, scaler, km, nn = fit_kmeans(df_feat_in, k=k_val)
 
137
  cluster_names_local = xy_cluster_tags(df_fit_local)
138
  df_fit_local = df_fit_local.copy()
139
  df_fit_local["cluster_name"] = df_fit_local["cluster"].map(cluster_names_local)
@@ -143,8 +120,6 @@ def _fit_model(df_feat_in: pd.DataFrame, k_val: int):
143
  with st.spinner("Clustering & tagging…"):
144
  df_fit, scaler, km, nn = _fit_model(df_feat, k)
145
 
146
- # ---- UI
147
-
148
  pitcher = st.selectbox("Pitcher", sorted(df_fit["player_name"].dropna().unique()))
149
  df_p = df_fit[df_fit["player_name"] == pitcher].sort_values("pitch_type")
150
 
@@ -190,6 +165,6 @@ with tab2:
190
  with tab3:
191
  for _, row in df_p.iterrows():
192
  st.markdown(f"#### {row['pitch_type']} comps")
 
193
  comps = nearest_comps(row, df_fit, scaler, nn, within_pitch_type=True, k=6)
194
  st.dataframe(comps, use_container_width=True)
195
-
 
2
  import os, sys
3
  from datetime import datetime
4
 
 
5
  BASE_DIR = os.path.dirname(__file__)
6
  sys.path.append(os.path.join(BASE_DIR, "src"))
7
 
8
  import streamlit as st
9
  import pandas as pd
10
 
 
11
  from data import load_statcast, default_window
12
  from featurize import infer_ivb_sign, engineer_pitch_features
13
+
14
+ # ⬇️ Revert to older API
15
  from model import fit_kmeans, nearest_comps
16
  from tags import xy_cluster_tags
17
  from plots import movement_scatter_xy, radar_quality
 
26
  st.set_page_config(page_title="PitchXY (Handedness-Aware)", layout="wide")
27
  st.title("⚾ PitchXY — Handedness-Aware Pitch Archetypes & Scouting Cards")
28
 
 
 
29
 
30
  @st.cache_data(show_spinner=False, ttl=24 * 3600)
31
  def load_statcast_cached(start: str, end: str, force: bool = False) -> pd.DataFrame:
 
 
 
 
32
  return load_statcast(start, end, force=force)
33
 
34
 
35
  @st.cache_data(show_spinner=False)
36
  def load_sample_fallback() -> pd.DataFrame:
 
 
 
 
 
 
37
  local_path = os.path.join(BASE_DIR, "data", "sample_statcast.parquet")
38
  if os.path.exists(local_path):
39
  return pd.read_parquet(local_path)
 
 
40
  repo_id = os.getenv("SAMPLE_DATA_REPO", "").strip()
41
  file_name = os.getenv("SAMPLE_DATA_FILE", "sample_statcast.parquet").strip()
42
  if HF_HUB_OK and repo_id:
43
  path = hf_hub_download(repo_id=repo_id, filename=file_name, repo_type="dataset")
44
  return pd.read_parquet(path)
 
 
45
  return pd.DataFrame(
46
  columns=[
47
  "game_date",
 
65
 
66
 
67
  def safe_load_data(start: str, end: str, force: bool) -> pd.DataFrame:
 
 
 
68
  try:
69
  df = load_statcast_cached(start, end, force)
 
70
  if df is not None and not df.empty:
71
  return df
72
  st.info("No live data returned for that window — showing sample data instead.")
 
75
  return load_sample_fallback()
76
 
77
 
 
 
78
  with st.sidebar:
79
  st.header("Data Window")
80
  dstart, dend = default_window()
81
  start = st.text_input("Start YYYY-MM-DD", dstart)
82
  end = st.text_input("End YYYY-MM-DD", dend)
83
+ k = st.slider("Clusters (k)", 5, 40, 25)
84
  force = st.checkbox("Force re-download (discouraged on Spaces)", value=False)
85
  st.caption("Tip: avoid 'Force re-download' on Spaces to keep startup snappy.")
86
 
 
 
87
  with st.spinner("Loading data…"):
88
  df_raw = safe_load_data(start, end, force)
89
 
 
96
  st.stop()
97
 
98
 
 
99
  @st.cache_data(show_spinner=False)
100
  def _featurize(df_raw_in: pd.DataFrame):
101
  ivb_sign = infer_ivb_sign(df_raw_in)
 
106
  df_feat = _featurize(df_raw)
107
 
108
 
109
+ # ✅ Cache the fitted artifacts from the older API
110
+ @st.cache_resource(show_spinner=False)
111
  def _fit_model(df_feat_in: pd.DataFrame, k_val: int):
112
  df_fit_local, scaler, km, nn = fit_kmeans(df_feat_in, k=k_val)
113
+ # Tag clusters with readable names
114
  cluster_names_local = xy_cluster_tags(df_fit_local)
115
  df_fit_local = df_fit_local.copy()
116
  df_fit_local["cluster_name"] = df_fit_local["cluster"].map(cluster_names_local)
 
120
  with st.spinner("Clustering & tagging…"):
121
  df_fit, scaler, km, nn = _fit_model(df_feat, k)
122
 
 
 
123
  pitcher = st.selectbox("Pitcher", sorted(df_fit["player_name"].dropna().unique()))
124
  df_p = df_fit[df_fit["player_name"] == pitcher].sort_values("pitch_type")
125
 
 
165
  with tab3:
166
  for _, row in df_p.iterrows():
167
  st.markdown(f"#### {row['pitch_type']} comps")
168
+ # ⬇️ Old signature again
169
  comps = nearest_comps(row, df_fit, scaler, nn, within_pitch_type=True, k=6)
170
  st.dataframe(comps, use_container_width=True)
 
bin/cli.py CHANGED
@@ -2,6 +2,8 @@ from __future__ import annotations
2
  import argparse
3
  from data import load_statcast, default_window
4
  from featurize import infer_ivb_sign, engineer_pitch_features
 
 
5
  from model import fit_kmeans, nearest_comps
6
  from tags import xy_cluster_tags
7
  from plots import movement_scatter_xy
@@ -38,7 +40,12 @@ def main():
38
  print(f"IVB sign inferred = {ivb_sign} (ride should be positive)")
39
 
40
  df_feat = engineer_pitch_features(df_raw, ivb_sign)
41
- df_fit, scaler, km, nn = fit_kmeans(df_feat, k=args.k)
 
 
 
 
 
42
  cluster_names = xy_cluster_tags(df_fit)
43
  df_fit["cluster_name"] = df_fit["cluster"].map(cluster_names)
44
 
@@ -78,9 +85,8 @@ def main():
78
  ].to_string(index=False)
79
  )
80
  for _, row in df_p.iterrows():
81
- comps = nearest_comps(
82
- row, df_fit, scaler, nn, within_pitch_type=True, k=6
83
- )
84
  print(f"\nNearest comps — {row['pitch_type']} ({row['cluster_name']}):")
85
  print(comps.to_string(index=False))
86
 
@@ -90,3 +96,7 @@ def main():
90
  out = ARTIFACTS_DIR / "movement_all.html"
91
  pio.write_html(fig, file=str(out), auto_open=False, include_plotlyjs="cdn")
92
  print(f"Saved plot: {out}")
 
 
 
 
 
2
  import argparse
3
  from data import load_statcast, default_window
4
  from featurize import infer_ivb_sign, engineer_pitch_features
5
+
6
+ # ⬇️ NEW: import the updated API
7
  from model import fit_kmeans, nearest_comps
8
  from tags import xy_cluster_tags
9
  from plots import movement_scatter_xy
 
40
  print(f"IVB sign inferred = {ivb_sign} (ride should be positive)")
41
 
42
  df_feat = engineer_pitch_features(df_raw, ivb_sign)
43
+
44
+ # ⬇️ NEW: fit the improved model
45
+ model = fit_pitch_clusters(df_feat, k=args.k)
46
+ df_fit = model.df_fit # contains all original cols + 'cluster'
47
+
48
+ # Tag clusters with human-readable names
49
  cluster_names = xy_cluster_tags(df_fit)
50
  df_fit["cluster_name"] = df_fit["cluster"].map(cluster_names)
51
 
 
85
  ].to_string(index=False)
86
  )
87
  for _, row in df_p.iterrows():
88
+ # ⬇️ UPDATED: pass the model, not (df_fit, scaler, nn)
89
+ comps = nearest_comps(row, model, k=5, allow_cross_type=False)
 
90
  print(f"\nNearest comps — {row['pitch_type']} ({row['cluster_name']}):")
91
  print(comps.to_string(index=False))
92
 
 
96
  out = ARTIFACTS_DIR / "movement_all.html"
97
  pio.write_html(fig, file=str(out), auto_open=False, include_plotlyjs="cdn")
98
  print(f"Saved plot: {out}")
99
+
100
+
101
+ if __name__ == "__main__":
102
+ main()
pyproject.toml CHANGED
@@ -9,6 +9,7 @@ dependencies = [
9
  "numpy",
10
  "pybaseball",
11
  "scikit-learn",
 
12
  "plotly",
13
  "pyarrow",
14
  "streamlit" # needed for HF Space app below
 
9
  "numpy",
10
  "pybaseball",
11
  "scikit-learn",
12
+ "scikit-learn-extra",
13
  "plotly",
14
  "pyarrow",
15
  "streamlit" # needed for HF Space app below
requirements.txt CHANGED
@@ -6,4 +6,5 @@ scikit-learn==1.5.1
6
  pyarrow==16.1.0
7
  huggingface_hub==0.25.2
8
  pybaseball==2.2.7
9
- requests>=2.31.0
 
 
6
  pyarrow==16.1.0
7
  huggingface_hub==0.25.2
8
  pybaseball==2.2.7
9
+ requests>=2.31.0
10
+ scikit-learn-extra
src/model.py CHANGED
@@ -1,9 +1,15 @@
1
  from __future__ import annotations
 
2
  import pandas as pd
3
- from sklearn.preprocessing import StandardScaler
4
- from sklearn.cluster import KMeans
 
 
5
  from sklearn.neighbors import NearestNeighbors
6
 
 
 
 
7
  ARCH_FEATURES = [
8
  "velo",
9
  "ivb_in",
@@ -17,29 +23,169 @@ ARCH_FEATURES = [
17
  "zone_pct",
18
  ]
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- def fit_kmeans(df_feat: pd.DataFrame, k: int = 8, random_state: int = 42):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  df = df_feat.dropna(subset=ARCH_FEATURES).copy()
23
- X = df[ARCH_FEATURES].values
24
- scaler = StandardScaler()
25
- Xs = scaler.fit_transform(X)
26
- km = KMeans(n_clusters=k, n_init=20, random_state=random_state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  labels = km.fit_predict(Xs)
28
  df["cluster"] = labels
29
 
30
- nn = NearestNeighbors(n_neighbors=6, metric="euclidean")
31
- nn.fit(Xs)
32
  return df, scaler, km, nn
33
 
34
 
35
  def nearest_comps(
36
- row: pd.Series, df_fit: pd.DataFrame, scaler, nn, within_pitch_type=True, k=6
 
 
 
 
 
37
  ):
 
 
 
 
 
 
 
 
 
 
38
  xq = scaler.transform(row[ARCH_FEATURES].values.reshape(1, -1))
39
- dists, idxs = nn.kneighbors(xq, n_neighbors=k)
40
- comps = df_fit.iloc[idxs[0]].copy()
41
- if within_pitch_type:
42
- comps = comps[comps["pitch_type"] == row["pitch_type"]]
43
  cols = [
44
  "player_name",
45
  "pitch_type",
@@ -49,6 +195,49 @@ def nearest_comps(
49
  "hb_as_in",
50
  "whiff_rate",
51
  "gb_rate",
52
- "cluster_name",
53
  ]
54
- return comps[cols].head(k - 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
+ import numpy as np
3
  import pandas as pd
4
+ from typing import Dict, Optional, Tuple
5
+ from sklearn.impute import SimpleImputer
6
+ from sklearn.preprocessing import RobustScaler, StandardScaler, FunctionTransformer
7
+ from sklearn.pipeline import Pipeline
8
  from sklearn.neighbors import NearestNeighbors
9
 
10
+ # NEW: medoids (robust, nearest-exemplar clustering)
11
+ from sklearn_extra.cluster import KMedoids
12
+
13
  ARCH_FEATURES = [
14
  "velo",
15
  "ivb_in",
 
23
  "zone_pct",
24
  ]
25
 
26
+ # ---------- existing helpers (unchanged API) ----------
27
+
28
+
29
+ def winsorize_df(df: pd.DataFrame, cols, lower=0.01, upper=0.99):
30
+ q_low = df[cols].quantile(lower)
31
+ q_hi = df[cols].quantile(upper)
32
+ return df.assign(**{c: df[c].clip(q_low[c], q_hi[c]) for c in cols})
33
+
34
+
35
+ def groupwise_z(df: pd.DataFrame, cols, group_col="pitch_type"):
36
+ df = df.copy()
37
+
38
+ def _z(g):
39
+ return (g - g.mean()) / (g.std(ddof=0) + 1e-8)
40
+
41
+ gz_cols = []
42
+ for c in cols:
43
+ gz = f"{c}_gz"
44
+ df[gz] = df.groupby(group_col)[c].transform(_z)
45
+ gz_cols.append(gz)
46
+ return df, gz_cols
47
+
48
+
49
+ def _preprocessor(
50
+ gz_feats: list[str], weights: Optional[Dict[str, float]] = None
51
+ ) -> Pipeline:
52
+ """
53
+ Consistent preprocessing for clustering and neighbor search.
54
+ Applies impute -> robust scale -> standardize -> optional weights.
55
+ """
56
+ steps = [
57
+ ("imputer", SimpleImputer(strategy="median")),
58
+ ("robust", RobustScaler()),
59
+ ("std", StandardScaler(with_mean=True, with_std=True)),
60
+ ]
61
+ if weights:
62
+ w = np.array(
63
+ [weights.get(f.replace("_gz", ""), 1.0) for f in gz_feats], dtype=float
64
+ )
65
+ steps.append(
66
+ (
67
+ "weights",
68
+ FunctionTransformer(lambda X: X * w, feature_names_out="one-to-one"),
69
+ )
70
+ )
71
+ return Pipeline(steps)
72
+
73
+
74
+ # ---------- local, neighbor-aware label smoothing (kept) ----------
75
+
76
+
77
+ def _contextual_smooth_labels(
78
+ Xs: np.ndarray,
79
+ labels: np.ndarray,
80
+ n_neighbors: int = 15,
81
+ vote_thresh: float = 0.6,
82
+ margin: float = 0.0,
83
+ max_iters: int = 2,
84
+ ) -> np.ndarray:
85
+ """
86
+ Reassign labels by local kNN majority with a confidence threshold.
87
+ - vote_thresh: minimum fraction of neighbors that must agree to flip (e.g., 0.6)
88
+ - margin: require the neighbor-majority centroid to be at least 'margin' closer
89
+ than the current cluster center (0.0 = no distance guard)
90
+ """
91
+ n = len(labels)
92
+ labels = labels.copy()
93
+
94
+ knn = NearestNeighbors(n_neighbors=min(n, n_neighbors + 1), metric="manhattan").fit(
95
+ Xs
96
+ )
97
+ dists, idxs = knn.kneighbors(Xs)
98
 
99
+ def centroids(lbls):
100
+ Cs = []
101
+ for k in np.unique(lbls):
102
+ Cs.append(Xs[lbls == k].mean(axis=0))
103
+ return {k: c for k, c in zip(np.unique(lbls), Cs)}
104
+
105
+ for _ in range(max_iters):
106
+ C = centroids(labels)
107
+ changed = 0
108
+ for i in range(n):
109
+ neigh = idxs[i][1:] # drop self
110
+ neigh_lbls = labels[neigh]
111
+ vals, counts = np.unique(neigh_lbls, return_counts=True)
112
+ j = np.argmax(counts)
113
+ maj_label, maj_frac = vals[j], counts[j] / len(neigh_lbls)
114
+ if maj_frac < vote_thresh or maj_label == labels[i]:
115
+ continue
116
+ if margin > 0.0:
117
+ cur_c = C[labels[i]]
118
+ maj_c = C[maj_label]
119
+ di_cur = np.linalg.norm(Xs[i] - cur_c)
120
+ di_maj = np.linalg.norm(Xs[i] - maj_c)
121
+ if di_maj >= di_cur - margin:
122
+ continue
123
+ labels[i] = maj_label
124
+ changed += 1
125
+ if changed == 0:
126
+ break
127
+ return labels
128
+
129
+
130
+ # ---------- API: fit + comps (drop-in) ----------
131
+
132
+
133
+ def fit_kmeans(df_feat: pd.DataFrame, k: int = 20, random_state: int = 42):
134
+ """
135
+ DROP-IN REPLACEMENT:
136
+ - Uses K-MEDOIDS with MANHATTAN distance (closest-neighbor–friendly).
137
+ - Returns (df_with_clusters, scaler_pipeline, kmedoids_model, knn_index).
138
+ """
139
  df = df_feat.dropna(subset=ARCH_FEATURES).copy()
140
+
141
+ # Light winsorization: dampen outliers without warping scale
142
+ df[ARCH_FEATURES] = df[ARCH_FEATURES].clip(
143
+ df[ARCH_FEATURES].quantile(0.01),
144
+ df[ARCH_FEATURES].quantile(0.99),
145
+ axis=1,
146
+ )
147
+
148
+ # Consistent preprocessing for clustering and neighbors
149
+ scaler = _preprocessor(ARCH_FEATURES, weights=None)
150
+ Xs = scaler.fit_transform(df[ARCH_FEATURES].values)
151
+
152
+ # K-Medoids with Manhattan distance -> emphasizes true nearest relationships
153
+ km = KMedoids(
154
+ n_clusters=k,
155
+ metric="manhattan",
156
+ init="k-medoids++",
157
+ max_iter=500,
158
+ random_state=random_state,
159
+ )
160
  labels = km.fit_predict(Xs)
161
  df["cluster"] = labels
162
 
163
+ # NN index in the SAME space & metric
164
+ nn = NearestNeighbors(n_neighbors=8, metric="manhattan").fit(Xs)
165
  return df, scaler, km, nn
166
 
167
 
168
  def nearest_comps(
169
+ row: pd.Series,
170
+ df_fit: pd.DataFrame,
171
+ scaler: Pipeline,
172
+ nn: NearestNeighbors,
173
+ within_pitch_type: bool = True,
174
+ k: int = 6,
175
  ):
176
+ """
177
+ Nearest comps in the SAME preprocessed space and metric (Manhattan).
178
+ If within_pitch_type=True, restricts candidates to the same pitch_type.
179
+ """
180
+ # Ensure all required features exist
181
+ missing = [c for c in ARCH_FEATURES if c not in df_fit.columns]
182
+ if missing:
183
+ raise KeyError(f"nearest_comps: df_fit is missing required features: {missing}")
184
+
185
+ # Query vector in the exact same space as clustering
186
  xq = scaler.transform(row[ARCH_FEATURES].values.reshape(1, -1))
187
+
188
+ # Columns to return
 
 
189
  cols = [
190
  "player_name",
191
  "pitch_type",
 
195
  "hb_as_in",
196
  "whiff_rate",
197
  "gb_rate",
198
+ "cluster",
199
  ]
200
+
201
+ # Per-pitch-type neighborhood (preferred)
202
+ if within_pitch_type and "pitch_type" in df_fit.columns:
203
+ ptype = row.get("pitch_type")
204
+ if isinstance(ptype, str):
205
+ sub = df_fit[df_fit["pitch_type"] == ptype].copy()
206
+ if not sub.empty:
207
+ Xsub = scaler.transform(sub[ARCH_FEATURES].values)
208
+ k_loc = min(len(sub), max(2, k + 1)) # +1 to allow excluding self
209
+ knn_local = NearestNeighbors(n_neighbors=k_loc, metric="manhattan").fit(
210
+ Xsub
211
+ )
212
+ dists, inds = knn_local.kneighbors(xq, n_neighbors=k_loc)
213
+ cand = sub.iloc[inds[0]].copy()
214
+ cand["_dist"] = dists[0]
215
+ # Prefer excluding the same player if present
216
+ pname = row.get("player_name", None)
217
+ if pname is not None and "player_name" in cand.columns:
218
+ cand = cand[cand["player_name"] != pname]
219
+ return (
220
+ cand.sort_values("_dist")
221
+ .drop(columns=["_dist"], errors="ignore")[cols]
222
+ .head(k)
223
+ )
224
+
225
+ # Global fallback: use provided NN (already fit in Manhattan space)
226
+ k_glob = min(len(df_fit), max(2, k + 1))
227
+ dists, inds = nn.kneighbors(xq, n_neighbors=k_glob)
228
+ cand = df_fit.iloc[inds[0]].copy()
229
+ if within_pitch_type and "pitch_type" in df_fit.columns:
230
+ ptype = row.get("pitch_type")
231
+ if isinstance(ptype, str):
232
+ cand = cand[cand["pitch_type"] == ptype]
233
+ pname = row.get("player_name", None)
234
+ if pname is not None and "player_name" in cand.columns:
235
+ cand = cand[cand["player_name"] != pname]
236
+ cand["_dist"] = dists[0][: len(cand)] if len(dists[0]) >= len(cand) else 0.0
237
+ return (
238
+ cand.sort_values("_dist").drop(columns=["_dist"], errors="ignore")[cols].head(k)
239
+ )
240
+
241
+
242
+ # Make public API explicit (unchanged)
243
+ __all__ = ["ARCH_FEATURES", "fit_kmeans", "nearest_comps"]
src/tags.py CHANGED
@@ -1,9 +1,17 @@
1
  from __future__ import annotations
2
  import numpy as np
3
  import pandas as pd
 
4
 
5
 
6
- def _mag_label(v, q25, q75, small="Subtle", mid="Moderate", big="Heavy"):
 
 
 
 
 
 
 
7
  if pd.isna(v):
8
  return mid
9
  if v >= q75:
@@ -13,139 +21,97 @@ def _mag_label(v, q25, q75, small="Subtle", mid="Moderate", big="Heavy"):
13
  return mid
14
 
15
 
16
- def _vert_label(ivb):
17
- if pd.isna(ivb):
18
  return "Neutral"
19
- return "Ride" if ivb >= 0 else "Drop"
20
 
21
 
22
- def _armside_from_raw_hb(hb_raw: float, throws: str) -> str:
23
- """Return 'Arm-Side' or 'Glove-Side' from raw HB (catcher view) and dominant throws.
24
- Statcast convention (catcher view): positive = to catcher’s left (3B side).
25
- Arm-side mapping commonly used:
26
- - RHP arm-side run → negative hb_raw
27
- - LHP arm-side run → positive hb_raw
28
- """
29
- if pd.isna(hb_raw) or throws not in ("R", "L"):
30
  return "Neutral"
31
- if (throws == "R" and hb_raw < 0) or (throws == "L" and hb_raw > 0):
32
- return "Arm-Side"
33
- return "Glove-Side"
34
-
35
-
36
- def _infer_side_series(sub: pd.DataFrame) -> pd.Series:
37
- """Infer per-pitch side (Arm/Glove) robustly, using raw hb if available,
38
- else reconstruct a raw-ish value from hb_as_in and p_throws."""
39
- has_raw = "hb_in" in sub.columns
40
- if has_raw:
41
- hb_raw = sub["hb_in"]
42
- else:
43
- # Reconstruct raw-ish: if hb_as_in is arm-side-adjusted (positive toward arm-side),
44
- # then flip sign for RHP to get a catcher-view-like raw sign.
45
- # raw +hb_as for LHP, raw ≈ -hb_as for RHP
46
- if "hb_as_in" in sub.columns and "p_throws" in sub.columns:
47
- hb_raw = np.where(sub["p_throws"] == "L", sub["hb_as_in"], -sub["hb_as_in"])
48
- hb_raw = pd.Series(hb_raw, index=sub.index)
49
- else:
50
- return pd.Series(["Neutral"] * len(sub), index=sub.index)
51
-
52
- throws = sub["p_throws"].fillna(
53
- sub["p_throws"].mode().iloc[0] if not sub["p_throws"].mode().empty else "R"
54
- )
55
- return pd.Series(
56
- np.where(
57
- ((throws == "R") & (hb_raw < 0)) | ((throws == "L") & (hb_raw > 0)),
58
- "Arm-Side",
59
- "Glove-Side",
60
- ),
61
- index=sub.index,
62
- )
63
-
64
-
65
- def xy_cluster_tags(df_with_clusters: pd.DataFrame) -> dict[int, str]:
66
  df = df_with_clusters.copy()
67
 
68
- # Quantiles for magnitude bucketing
69
- q_abs_ivb25 = np.nanquantile(np.abs(df["ivb_in"]), 0.25)
70
- q_abs_ivb75 = np.nanquantile(np.abs(df["ivb_in"]), 0.75)
71
- q_abs_hb25 = np.nanquantile(np.abs(df["hb_as_in"]), 0.25)
72
- q_abs_hb75 = np.nanquantile(np.abs(df["hb_as_in"]), 0.75)
73
 
74
  # Quality quantiles
75
- q_wh75 = np.nanquantile(df["whiff_rate"], 0.75)
76
- q_gb75 = np.nanquantile(df["gb_rate"], 0.75)
77
- q_zn75 = np.nanquantile(df["zone_pct"], 0.75)
78
- q_wh50 = np.nanquantile(df["whiff_rate"], 0.50)
79
- q_gb50 = np.nanquantile(df["gb_rate"], 0.50)
80
- q_zn50 = np.nanquantile(df["zone_pct"], 0.50)
81
-
82
- tags = {}
 
83
  for c, sub in df.groupby("cluster"):
84
  # Robust central tendency
85
  row = sub.median(numeric_only=True)
86
 
87
- # Dominant metadata
88
- dom_pt = (
89
- sub["pitch_type"].mode().iloc[0]
90
- if "pitch_type" in sub and not sub["pitch_type"].mode().empty
91
- else "Pitch"
92
- )
93
- dom_throw = (
94
- sub["p_throws"].mode().iloc[0]
95
- if "p_throws" in sub and not sub["p_throws"].mode().empty
96
- else "R"
97
- )
98
 
99
- # Robust side inference
100
- per_pitch_side = _infer_side_series(sub)
101
- side_counts = per_pitch_side.value_counts(dropna=False)
102
- side = side_counts.idxmax() if not side_counts.empty else "Neutral"
103
 
104
- # If nearly tied or Neutral, fall back to median raw
105
- if side in ("Neutral",) or (
106
- len(side_counts) > 1 and (side_counts.max() - side_counts.min()) <= 2
107
- ):
108
- # Use hb_raw median logic
109
- if "hb_in" in sub.columns:
110
- hb_raw_med = sub["hb_in"].median()
111
- else:
112
- # Reconstruct raw-ish median from hb_as_in + throws
113
- if "hb_as_in" in sub.columns:
114
- hb_raw_med = sub.apply(
115
- lambda r: (
116
- r["hb_as_in"]
117
- if r.get("p_throws", dom_throw) == "L"
118
- else -r["hb_as_in"]
119
- ),
120
- axis=1,
121
- ).median()
122
- else:
123
- hb_raw_med = np.nan
124
- side = _armside_from_raw_hb(hb_raw_med, dom_throw)
125
-
126
- # Vertical shape from ivb sign (already handedness-invariant)
127
- vert = _vert_label(row.get("ivb_in", np.nan))
128
-
129
- # Magnitudes from absolute, handedness-invariant features
130
- mag_side = _mag_label(abs(row.get("hb_as_in", np.nan)), q_abs_hb25, q_abs_hb75)
131
- mag_vert = _mag_label(abs(row.get("ivb_in", np.nan)), q_abs_ivb25, q_abs_ivb75)
132
-
133
- # Flavor tags
134
  flavor = []
135
- if row.get("whiff_rate", 0) >= q_wh75:
136
  flavor.append("Whiff-First")
137
- if row.get("gb_rate", 0) >= q_gb75:
138
  flavor.append("Grounder-First")
139
- if row.get("zone_pct", 0) >= q_zn75:
140
  flavor.append("Strike-Throwing")
141
  if not flavor:
142
  diffs = {
143
- "Whiff-First": row.get("whiff_rate", 0) - q_wh50,
144
- "Grounder-First": row.get("gb_rate", 0) - q_gb50,
145
- "Strike-Throwing": row.get("zone_pct", 0) - q_zn50,
146
  }
147
  flavor.append(max(diffs, key=diffs.get))
148
 
 
149
  side_noun = (
150
  "Run"
151
  if side == "Arm-Side"
@@ -154,9 +120,17 @@ def xy_cluster_tags(df_with_clusters: pd.DataFrame) -> dict[int, str]:
154
  vert_noun = (
155
  "Ride" if vert == "Ride" else ("Drop" if vert == "Drop" else "Ride/Drop")
156
  )
157
- shape = f"{side} • {mag_side} {side_noun}, {mag_vert} {vert_noun}"
158
 
159
- tags[c] = f"{dom_pt}: {shape} " + " / ".join(flavor)
 
 
 
 
 
 
 
 
160
 
161
- return tags
162
 
 
 
1
  from __future__ import annotations
2
  import numpy as np
3
  import pandas as pd
4
+ from typing import Dict, Optional
5
 
6
 
7
+ def _safe_q(s: pd.Series, q: float, default: float) -> float:
8
+ s = pd.to_numeric(s, errors="coerce").dropna()
9
+ return float(s.quantile(q)) if len(s) else default
10
+
11
+
12
+ def _mag_label(
13
+ v: float, q25: float, q75: float, small="Subtle", mid="Moderate", big="Heavy"
14
+ ):
15
  if pd.isna(v):
16
  return mid
17
  if v >= q75:
 
21
  return mid
22
 
23
 
24
+ def _vert_label(ivb: float, eps: float = 0.5) -> str:
25
+ if pd.isna(ivb) or abs(ivb) <= eps:
26
  return "Neutral"
27
+ return "Ride" if ivb > 0 else "Drop"
28
 
29
 
30
+ def _side_label(hb_as: float, eps: float = 0.5) -> str:
31
+ """+hb_as_in = Arm-Side, -hb_as_in = Glove-Side; small |hb| -> Neutral."""
32
+ if pd.isna(hb_as) or abs(hb_as) <= eps:
 
 
 
 
 
33
  return "Neutral"
34
+ return "Arm-Side" if hb_as > 0 else "Glove-Side"
35
+
36
+
37
+ def xy_cluster_tags(
38
+ df_with_clusters: pd.DataFrame,
39
+ *,
40
+ eps_lat: float = 0.5, # dead-band for side near 0 (inches)
41
+ eps_vert: float = 0.5, # dead-band for ride/drop near 0
42
+ prefix_pitch_type: bool = False, # True to prepend dominant pitch_type like "SL:"
43
+ ) -> Dict[int, str]:
44
+ """
45
+ Cluster -> name using only movement characteristics:
46
+ - Side: sign(hb_as_in) (+ -> Arm-Side, - -> Glove-Side)
47
+ - Vert: sign(ivb_in) (+ -> Ride, - -> Drop)
48
+ Magnitude adjectives via quantiles (Subtle/Moderate/Heavy). Adds flavor tags
49
+ (Whiff-First / Grounder-First / Strike-Throwing) based on medians.
50
+
51
+ Returns {cluster_id: label}
52
+ """
53
+ if df_with_clusters.empty or "cluster" not in df_with_clusters.columns:
54
+ return {}
55
+
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  df = df_with_clusters.copy()
57
 
58
+ # Quantiles for magnitude bucketing (robust, adaptive per window)
59
+ q_abs_ivb25 = _safe_q(df.get("ivb_in", pd.Series([])), 0.25, 1.0)
60
+ q_abs_ivb75 = _safe_q(df.get("ivb_in", pd.Series([])).abs(), 0.75, 8.0)
61
+ q_abs_hb25 = _safe_q(df.get("hb_as_in", pd.Series([])).abs(), 0.25, 1.5)
62
+ q_abs_hb75 = _safe_q(df.get("hb_as_in", pd.Series([])).abs(), 0.75, 10.0)
63
 
64
  # Quality quantiles
65
+ q_wh75 = _safe_q(df.get("whiff_rate", pd.Series([])), 0.75, 0.30)
66
+ q_gb75 = _safe_q(df.get("gb_rate", pd.Series([])), 0.75, 0.45)
67
+ q_zn75 = _safe_q(df.get("zone_pct", pd.Series([])), 0.75, 0.52)
68
+ q_wh50 = _safe_q(df.get("whiff_rate", pd.Series([])), 0.50, 0.25)
69
+ q_gb50 = _safe_q(df.get("gb_rate", pd.Series([])), 0.50, 0.40)
70
+ q_zn50 = _safe_q(df.get("zone_pct", pd.Series([])), 0.50, 0.49)
71
+
72
+ tags: Dict[int, str] = {}
73
+
74
  for c, sub in df.groupby("cluster"):
75
  # Robust central tendency
76
  row = sub.median(numeric_only=True)
77
 
78
+ # Optional dominant metadata (NOT used for geometry)
79
+ prefix = ""
80
+ if (
81
+ prefix_pitch_type
82
+ and "pitch_type" in sub.columns
83
+ and not sub["pitch_type"].mode().empty
84
+ ):
85
+ prefix = f"{sub['pitch_type'].mode().iloc[0]}: "
 
 
 
86
 
87
+ # Geometry: use hb_as_in & ivb_in directly (signs define AS/GS and Ride/Drop)
88
+ hb_med = row.get("hb_as_in", np.nan)
89
+ ivb_med = row.get("ivb_in", np.nan)
 
90
 
91
+ side = _side_label(hb_med, eps=eps_lat) # Arm-Side / Glove-Side / Neutral
92
+ vert = _vert_label(ivb_med, eps=eps_vert) # Ride / Drop / Neutral
93
+
94
+ # Magnitude adjectives (absolute)
95
+ mag_side = _mag_label(abs(hb_med), q_abs_hb25, q_abs_hb75)
96
+ mag_vert = _mag_label(abs(ivb_med), q_abs_ivb25, q_abs_ivb75)
97
+
98
+ # Flavor tags (pick strongest; if none exceed 75th pct, choose highest vs median)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  flavor = []
100
+ if "whiff_rate" in row and row["whiff_rate"] >= q_wh75:
101
  flavor.append("Whiff-First")
102
+ if "gb_rate" in row and row["gb_rate"] >= q_gb75:
103
  flavor.append("Grounder-First")
104
+ if "zone_pct" in row and row["zone_pct"] >= q_zn75:
105
  flavor.append("Strike-Throwing")
106
  if not flavor:
107
  diffs = {
108
+ "Whiff-First": float(row.get("whiff_rate", 0) - q_wh50),
109
+ "Grounder-First": float(row.get("gb_rate", 0) - q_gb50),
110
+ "Strike-Throwing": float(row.get("zone_pct", 0) - q_zn50),
111
  }
112
  flavor.append(max(diffs, key=diffs.get))
113
 
114
+ # Compose human-readable shape
115
  side_noun = (
116
  "Run"
117
  if side == "Arm-Side"
 
120
  vert_noun = (
121
  "Ride" if vert == "Ride" else ("Drop" if vert == "Drop" else "Ride/Drop")
122
  )
 
123
 
124
+ # If Neutral on either axis, simplify the phrase
125
+ if side == "Neutral" and vert == "Neutral":
126
+ shape = "Neutral • Moderate Run/Sweep, Moderate Ride/Drop"
127
+ elif side == "Neutral":
128
+ shape = f"{vert} • Moderate Run/Sweep, {mag_vert} {vert_noun}"
129
+ elif vert == "Neutral":
130
+ shape = f"{side} • {mag_side} {side_noun}, Moderate Ride/Drop"
131
+ else:
132
+ shape = f"{side} • {mag_side} {side_noun}, {mag_vert} {vert_noun}"
133
 
134
+ tags[int(c)] = f"{prefix}{shape} • " + " / ".join(flavor)
135
 
136
+ return tags