2026_MLB_Model / models /batter_zone_model.py
Syntrex's picture
Update models/batter_zone_model.py
26e1c80 verified
raw
history blame
2.92 kB
from __future__ import annotations

import math
from typing import Any

import pandas as pd

from models.batter_zone_store import load_batter_zone_store_metrics
# Lower-cased pitch name -> coarse pitch family ("fastball", "breaking",
# "offspeed"). Built from a per-family listing so each family's members sit
# together; the resulting dict (including key order) is a flat name->family map.
PITCH_FAMILY_MAP: dict[str, str] = {
    pitch: family
    for family, pitches in (
        (
            "fastball",
            ("4-seam fastball", "four-seam fastball", "fastball", "sinker", "cutter"),
        ),
        (
            "breaking",
            ("slider", "sweeper", "curveball", "knuckle curve", "slurve"),
        ),
        (
            "offspeed",
            ("changeup", "splitter", "forkball", "split-finger", "circle change"),
        ),
    )
    for pitch in pitches
}
def normalize_pitch_family(pitch_name: Any) -> str:
    """Map a raw pitch name to its coarse family via ``PITCH_FAMILY_MAP``.

    Matching is case-insensitive and ignores leading, trailing and repeated
    internal whitespace (e.g. ``"Knuckle  Curve"`` matches ``"knuckle curve"``).

    Args:
        pitch_name: A pitch name, typically a string. ``None``, NaN, ``pd.NA``
            and other pandas missing-value scalars are treated as missing.

    Returns:
        The family stored in ``PITCH_FAMILY_MAP`` for the normalized name, or
        ``"unknown"`` when the value is missing or not in the map.
    """
    # An explicit missing check is required: the old ``pitch_name or ""``
    # raised TypeError for pd.NA because ``bool(pd.NA)`` is ambiguous.
    # ``is_scalar`` keeps ``pd.isna`` from returning an array for list-likes.
    if pitch_name is None or (
        pd.api.types.is_scalar(pitch_name) and pd.isna(pitch_name)
    ):
        return "unknown"
    text = " ".join(str(pitch_name).split()).lower()
    # Already-stringified missing markers (e.g. a column cast with astype(str)).
    if text in {"", "nan", "none", "<na>"}:
        return "unknown"
    return PITCH_FAMILY_MAP.get(text, "unknown")
def classify_zone_bucket(plate_x: Any, plate_z: Any) -> str:
    """Bucket a pitch location into "heart", "shadow", "chase" or "waste".

    The buckets are nested axis-aligned rectangles:

    * "heart"  -- inside the inner rectangle (|x| <= 0.45, 1.90 <= z <= 3.10);
    * "shadow" -- inside the zone rectangle (|x| <= 0.83, 1.50 <= z <= 3.50)
      but outside the heart;
    * "chase"  -- outside the zone but inside |x| <= 1.20, 1.10 <= z <= 3.90;
    * "waste"  -- anything else.

    All bounds are inclusive. The bounds look like feet in Statcast
    plate_x/plate_z coordinates -- NOTE(review): confirm against the data source.

    Args:
        plate_x: Horizontal location; anything ``float()`` accepts.
        plate_z: Vertical location; anything ``float()`` accepts.

    Returns:
        The bucket name, or "unknown" when either coordinate is missing,
        non-numeric, or not finite.
    """
    try:
        x = float(plate_x)
        z = float(plate_z)
    except (TypeError, ValueError, OverflowError):
        # None / pd.NA -> TypeError, non-numeric strings -> ValueError,
        # huge ints -> OverflowError.
        return "unknown"

    # NaN compares False against every bound, so without this guard a missing
    # location (NaN in a DataFrame) would silently fall through to "waste".
    if not (math.isfinite(x) and math.isfinite(z)):
        return "unknown"

    zone_left, zone_right = -0.83, 0.83
    zone_bottom, zone_top = 1.50, 3.50
    if zone_left <= x <= zone_right and zone_bottom <= z <= zone_top:
        inner_left, inner_right = -0.45, 0.45
        inner_bottom, inner_top = 1.90, 3.10
        if inner_left <= x <= inner_right and inner_bottom <= z <= inner_top:
            return "heart"
        return "shadow"

    chase_left, chase_right = -1.20, 1.20
    chase_bottom, chase_top = 1.10, 3.90
    if chase_left <= x <= chase_right and chase_bottom <= z <= chase_top:
        return "chase"
    return "waste"
def build_batter_zone_feature_row(
    statcast_df: pd.DataFrame,
    player_name: str,
) -> dict[str, Any]:
    """Assemble one feature row for ``player_name`` from the stored zone metrics.

    The row holds:

    * ``player_name``;
    * ``zone_sample_size``, from ``stored_zone_sample_size`` (0 if absent);
    * for every pitch family (fastball, breaking, offspeed) and zone bucket
      (heart, shadow, chase, waste), the six metrics ``hr_prob``, ``hit_prob``,
      ``tb2p_prob``, ``whiff_prob``, ``damage_prob`` and ``sample_size``. Each
      is copied from the ``stored_``-prefixed key of the same name (None if
      absent).

    Note: ``statcast_df`` is not read by this function; every value comes from
    ``load_batter_zone_store_metrics``. NOTE(review): presumably kept for
    signature compatibility with callers -- confirm.
    """
    metrics = load_batter_zone_store_metrics(player_name)

    feature_row: dict[str, Any] = {"player_name": player_name}
    feature_row["zone_sample_size"] = metrics.get("stored_zone_sample_size", 0)

    metric_prefixes = (
        "hr_prob",
        "hit_prob",
        "tb2p_prob",
        "whiff_prob",
        "damage_prob",
        "sample_size",
    )
    for family in ("fastball", "breaking", "offspeed"):
        for zone in ("heart", "shadow", "chase", "waste"):
            for prefix in metric_prefixes:
                feature_key = f"{prefix}_{family}_{zone}"
                # Stored keys mirror the feature keys with a "stored_" prefix.
                feature_row[feature_key] = metrics.get(f"stored_{feature_key}")

    return feature_row