# 2026_MLB_Model/models/strikeout_probability_engine.py
# Author: Syntrex · commit d3b6f35 ("Add props confidence breakdown diagnostics") · 19.9 kB
from __future__ import annotations
import math
from typing import Any
import pandas as pd
from models.arsenal_matchup_model import compute_arsenal_matchup_adjustment
from models.batter_arsenal_model import build_batter_arsenal_feature_row
from models.batter_zone_model import build_batter_zone_feature_row, normalize_pitch_family
from models.family_zone_profile_store import (
build_batter_family_zone_feature_row,
build_pitcher_family_zone_feature_row,
)
from models.matchup_model import (
compute_family_zone_matchup_adjustment,
compute_zone_matchup_adjustment,
)
from models.pitcher_adjustment import build_pitcher_feature_row
from models.pitcher_arsenal_model import build_pitcher_arsenal_feature_row
from models.pitcher_zone_model import build_pitcher_zone_feature_row
from models.trajectory_model import build_trajectory_features
def _safe_float(value: Any) -> float | None:
    """Coerce *value* to ``float``, returning ``None`` when it is missing or unparseable.

    ``None``, blank strings and the tokens "nan"/"none" (case-insensitive,
    surrounding whitespace ignored) count as missing. Any value whose ``str()``
    or ``float()`` conversion raises also yields ``None``.
    """
    if value is None:
        return None
    try:
        # Text form catches both string placeholders and float NaN ("nan").
        if str(value).strip().lower() in ("", "nan", "none"):
            return None
        return float(value)
    except Exception:
        return None
def _clamp(value: float, lo: float, hi: float) -> float:
    """Bound *value* to ``[lo, hi]`` by applying the upper bound, then the lower bound."""
    upper_bounded = min(hi, value)
    return max(lo, upper_bounded)
def _reliability(sample_size: Any, k: float = 120.0) -> float:
    """Return the shrinkage weight ``n / (n + k)`` for a sample of size *n*, bounded to [0, 1].

    A falsy, unparseable or negative *sample_size* is treated as ``n = 0``,
    which gives a weight of 0.0 (for positive *k*).
    """
    try:
        n = float(sample_size or 0.0)
    except Exception:
        n = 0.0
    n = max(0.0, n)
    return max(0.0, min(1.0, n / (n + k)))
def _poisson_prob_over(expected_value: float, line: float) -> float:
    """Return ``P(X > line)`` for ``X ~ Poisson(expected_value)``, bounded to [0, 1].

    The over needs a count strictly above the line, so this is
    ``1 - P(X <= floor(line))``. A non-positive mean returns 0.0.
    """
    if expected_value <= 0:
        return 0.0
    threshold = int(math.floor(line))
    base = math.exp(-expected_value)
    at_or_below = sum(
        base * (expected_value ** count) / math.factorial(count)
        for count in range(threshold + 1)
    )
    return max(0.0, min(1.0, 1.0 - at_or_below))
def _poisson_prob_under(expected_value: float, line: float) -> float:
    """Return ``P(X < line)`` for ``X ~ Poisson(expected_value)``, bounded to [0, 1].

    The under only wins when the count lands strictly below the line. On a
    whole-number line the outcome ``X == line`` is a push and is excluded,
    matching ``_poisson_prob_over``, which also excludes it. For half-point
    lines the result equals ``1 - P(X > line)``.

    A non-positive mean is treated as all mass at zero strikeouts.
    """
    # Highest winning count: 5 for a 5.5 line, 4 for a 5.0 line.
    target = int(math.ceil(line)) - 1
    if target < 0:
        return 0.0
    if expected_value <= 0:
        return 1.0
    # Accumulate pmf terms iteratively (term_k = term_{k-1} * lambda / k) so
    # large counts never need an explicit factorial.
    term = math.exp(-expected_value)
    cumulative = term
    for count in range(1, target + 1):
        term *= expected_value / count
        cumulative += term
    return max(0.0, min(1.0, cumulative))
def _bucket(score: float) -> str:
    """Map a confidence score to "high" (>= 75), "medium" (>= 55) or "low"."""
    for floor, label in ((75, "high"), (55, "medium")):
        if score >= floor:
            return label
    return "low"
def _confidence_component(label: str, value: float, direction: str) -> dict[str, Any]:
    """Build one confidence-breakdown entry, with *value* rounded to one decimal."""
    rounded = round(float(value), 1)
    return dict(label=label, value=rounded, direction=direction)
def _normalize_name(value: Any) -> str:
    """Lower-case *value* and collapse every whitespace run to a single space.

    Falsy inputs (``None``, ``""``, ``0``) normalize to ``""``.
    """
    tokens = str(value or "").lower().split()
    return " ".join(tokens)
def _compute_sequencing_score(pitcher_statcast_df: pd.DataFrame, pitcher_name: str) -> dict[str, Any]:
    """Score how varied a pitcher's pitch-family transitions are.

    The pitcher's rows are matched on ``player_name`` (case-insensitively) and
    mapped to families with ``normalize_pitch_family``. They are ordered by
    whichever of game_date / game_pk / at_bat_number / pitch_number exist.
    Consecutive pitch pairs where neither family is "unknown" are then
    scored as::

        score = clamp(0.55 * distinct_pairs / 9 + 0.45 * family_change_rate, 0, 1)

    Returns:
        A dict with:

        - ``sequencing_score``: 0.5 when there is not enough data.
        - ``sequencing_sample_size``: number of usable pairs.
        - ``sequencing_reason_tags``: one tag when the score is >= 0.65 or
          <= 0.35, otherwise empty.

        The neutral result is returned for a missing or empty frame, a blank
        name, no ``player_name`` column, no matching rows, fewer than 12
        pitches, or no usable pairs.
    """

    def _neutral() -> dict[str, Any]:
        return {
            "sequencing_score": 0.5,
            "sequencing_sample_size": 0,
            "sequencing_reason_tags": [],
        }

    if pitcher_statcast_df is None or pitcher_statcast_df.empty or not pitcher_name:
        return _neutral()
    if "player_name" not in pitcher_statcast_df.columns:
        return _neutral()

    wanted = str(pitcher_name).casefold()
    name_mask = pitcher_statcast_df["player_name"].astype(str).str.casefold() == wanted
    pitches = pitcher_statcast_df.loc[name_mask].copy()
    if pitches.empty:
        return _neutral()

    # Prefer the descriptive pitch name, then the pitch-type code.
    source_col = next((col for col in ("pitch_name", "pitch_type") if col in pitches.columns), None)
    if source_col is None:
        raw_names = pd.Series(["unknown"] * len(pitches), index=pitches.index)
    else:
        raw_names = pitches[source_col]
    pitches["pitch_family"] = raw_names.apply(normalize_pitch_family)

    order_by = [
        col for col in ("game_date", "game_pk", "at_bat_number", "pitch_number") if col in pitches.columns
    ]
    if order_by:
        pitches = pitches.sort_values(order_by, na_position="last")

    family_seq = pitches["pitch_family"].astype(str).tolist()
    if len(family_seq) < 12:
        return _neutral()

    # NOTE(review): pairs are formed across at-bat and game boundaries as well
    # (last pitch of one PA -> first pitch of the next); confirm that is intended.
    pairs = [
        (prev, nxt)
        for prev, nxt in zip(family_seq, family_seq[1:])
        if prev != "unknown" and nxt != "unknown"
    ]
    if not pairs:
        return _neutral()

    total = len(pairs)
    changes = sum(1 for prev, nxt in pairs if prev != nxt)
    # NOTE(review): 9.0 presumes three families (3 x 3 possible pairs); verify
    # against what normalize_pitch_family can return.
    diversity = len(set(pairs)) / 9.0
    change_rate = changes / total
    score = _clamp((diversity * 0.55) + (change_rate * 0.45), 0.0, 1.0)

    if score >= 0.65:
        tags = ["Mixes sequences well"]
    elif score <= 0.35:
        tags = ["Predictable sequencing"]
    else:
        tags = []

    return {
        "sequencing_score": score,
        "sequencing_sample_size": int(total),
        "sequencing_reason_tags": tags,
    }
def _aggregate_opponent_whiff_overlay(
    batter_statcast_df: pd.DataFrame,
    opponent_batters: list[str] | None,
    opponent_team: str | None = None,
) -> dict[str, Any]:
    """Average whiff tendencies over (up to) nine opponent batters.

    The lineup is *opponent_batters*, with blank names dropped. When that is
    empty and *opponent_team* is given, the lineup is inferred instead: it is
    the ``player_name`` values of rows where that team was batting (away team
    in "top" halves, home team in "bot"/"bottom" halves), in first-appearance
    order.

    Returns:
        A dict with:

        - ``lineup_whiff_risk``: mean over batters of each batter's mean
          ``whiff_prob_{fastball,breaking,offspeed}``.
        - ``lineup_zone_whiff_risk``: mean over batters of each batter's mean
          ``whiff_rate_{family}_{heart,shadow,chase,waste}``.
        - ``lineup_sample_size``: number of names considered, at most 9.

        Risks with no usable values stay at 0.0.
    """
    summary: dict[str, Any] = {
        "lineup_whiff_risk": 0.0,
        "lineup_zone_whiff_risk": 0.0,
        "lineup_sample_size": 0,
    }
    if batter_statcast_df is None or batter_statcast_df.empty:
        return summary

    stripped = (str(raw).strip() for raw in (opponent_batters or []))
    names = [text for text in stripped if text]

    if not names and opponent_team:
        required = {"inning_topbot", "home_team", "away_team", "player_name"}
        if required.issubset(batter_statcast_df.columns):
            team_key = _normalize_name(opponent_team)
            half = batter_statcast_df["inning_topbot"].astype(str).str.lower()
            is_top = half.str.contains("top", na=False)
            is_bottom = half.str.contains("bot|bottom", na=False)
            away_keys = batter_statcast_df["away_team"].fillna("").astype(str).map(_normalize_name)
            home_keys = batter_statcast_df["home_team"].fillna("").astype(str).map(_normalize_name)
            # Away club bats in the top half, home club in the bottom half.
            batting = (is_top & away_keys.eq(team_key)) | (is_bottom & home_keys.eq(team_key))
            names = batter_statcast_df.loc[batting, "player_name"].dropna().astype(str).unique().tolist()

    if not names:
        return summary

    lineup = names[:9]
    families = ("fastball", "breaking", "offspeed")
    zones = ("heart", "shadow", "chase", "waste")
    per_batter_arsenal: list[float] = []
    per_batter_zone: list[float] = []

    for name in lineup:
        arsenal_row = build_batter_arsenal_feature_row(batter_statcast_df, name)
        zone_row = build_batter_family_zone_feature_row(batter_statcast_df, name)

        family_whiffs = [
            val
            for val in (_safe_float(arsenal_row.get(f"whiff_prob_{fam}")) for fam in families)
            if val is not None
        ]
        if family_whiffs:
            per_batter_arsenal.append(sum(family_whiffs) / len(family_whiffs))

        cell_whiffs = [
            val
            for val in (
                _safe_float(zone_row.get(f"whiff_rate_{fam}_{zone}"))
                for fam in families
                for zone in zones
            )
            if val is not None
        ]
        if cell_whiffs:
            per_batter_zone.append(sum(cell_whiffs) / len(cell_whiffs))

    if per_batter_arsenal:
        summary["lineup_whiff_risk"] = float(sum(per_batter_arsenal) / len(per_batter_arsenal))
    if per_batter_zone:
        summary["lineup_zone_whiff_risk"] = float(sum(per_batter_zone) / len(per_batter_zone))
    # NOTE(review): counts names considered, not batters that produced whiff data.
    summary["lineup_sample_size"] = len(lineup)
    return summary
def _calibrate(probability: float) -> float:
    """Shrink *probability* 8% of the way toward 0.5, then bound it to [0.02, 0.98]."""
    shrunk = 0.50 + ((probability - 0.50) * 0.92)
    return max(0.02, min(0.98, shrunk))
def build_strikeout_probability_result(
    pitcher_statcast_df: pd.DataFrame,
    pitcher_name: str,
    batter_statcast_df: pd.DataFrame | None = None,
    opponent_batters: list[str] | None = None,
    opponent_team: str | None = None,
    line: float | None = None,
    selection_side: str | None = None,
    game_row: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Estimate the probability that a pitcher's strikeout prop lands on *selection_side*.

    The model starts from a baseline of 4.4 expected strikeouts and adds
    reliability-weighted shifts from:

    - the pitcher's swinging-strike, CSW and ball rates;
    - the opponent lineup's whiff risk (arsenal and family-zone);
    - trajectory features (tunneling, release consistency);
    - pitch sequencing.

    The mean is bounded to [1, 12] and turned into a Poisson probability
    for *line*, which is then calibrated. A confidence score with its
    bonus/penalty breakdown is attached.

    Args:
        pitcher_statcast_df: Pitch-level data used for all pitcher features.
        pitcher_name: Pitcher to evaluate.
        batter_statcast_df: Optional batter pitch-level data for lineup features.
        opponent_batters: Optional opponent lineup; the first nine names are used.
        opponent_team: Used by the overlay to infer a lineup when
            *opponent_batters* is empty.
        line: Strikeout prop line.
        selection_side: "over" or "under".
        game_row: Accepted for interface compatibility; not read here.

    Returns:
        The result dict. When the pitcher data, name, line or side is missing,
        only ``skipped_layers="missing_pitcher_or_line"`` is set and every
        model output stays ``None``. Otherwise ``skipped_layers`` lists the
        lineup layers ("arsenal", "location") that had no usable data.
    """
    result: dict[str, Any] = {
        "mode": "pregame",
        "raw_k_prob": None,
        "calibrated_k_prob": None,
        "fair_prob": None,
        "expected_strikeouts": None,
        "pitcher_swstr_rate": None,
        "pitcher_csw_rate": None,
        "pitcher_ball_rate": None,
        "arsenal_whiff_risk": None,
        "family_zone_whiff_risk": None,
        "zone_whiff_risk": None,
        "trajectory_tunnel_score": None,
        "trajectory_release_consistency_score": None,
        "sequencing_score": None,
        "confidence_score": None,
        "confidence_score_raw": None,
        "confidence_score_display": None,
        "confidence_source": "strikeout_v1_live",
        "confidence_bucket": None,
        "confidence_reasons": [],
        "confidence_component_bonuses": [],
        "confidence_component_penalties": [],
        "confidence_primary_driver": None,
        "confidence_summary_label": None,
        "applied_layers": "",
        "skipped_layers": "",
        "reason_tags_for": [],
        "reason_tags_against": [],
    }
    if (
        pitcher_statcast_df is None
        or pitcher_statcast_df.empty
        or not pitcher_name
        or line is None
        or selection_side not in {"over", "under"}
    ):
        result["skipped_layers"] = "missing_pitcher_or_line"
        return result

    # --- Pitcher-side feature rows ---------------------------------------
    pitcher_row = build_pitcher_feature_row(pitcher_statcast_df, pitcher_name)
    pitcher_arsenal_row = build_pitcher_arsenal_feature_row(pitcher_statcast_df, pitcher_name)
    pitcher_zone_row = build_pitcher_zone_feature_row(pitcher_statcast_df, pitcher_name)
    pitcher_family_zone_row = build_pitcher_family_zone_feature_row(pitcher_statcast_df, pitcher_name)
    traj_row = build_trajectory_features(pitcher_statcast_df, pitcher_name)
    sequencing = _compute_sequencing_score(pitcher_statcast_df, pitcher_name)
    opponent_overlay = _aggregate_opponent_whiff_overlay(
        batter_statcast_df=batter_statcast_df if batter_statcast_df is not None else pd.DataFrame(),
        opponent_batters=opponent_batters,
        opponent_team=opponent_team,
    )

    # --- Pitcher-vs-batter matchup risks (explicit lineup only) ----------
    # 0.0 doubles as "not computed"; the overlay values are the fallback.
    lineup_family_zone_risk = 0.0
    lineup_arsenal_risk = 0.0
    if opponent_batters and batter_statcast_df is not None and not batter_statcast_df.empty:
        family_zone_risks: list[float] = []
        arsenal_risks: list[float] = []
        for batter_name in opponent_batters[:9]:
            batter_zone_row = build_batter_zone_feature_row(batter_statcast_df, batter_name)
            batter_arsenal_row = build_batter_arsenal_feature_row(batter_statcast_df, batter_name)
            batter_family_zone_row = build_batter_family_zone_feature_row(batter_statcast_df, batter_name)
            zone_adj = compute_zone_matchup_adjustment(batter_zone_row, pitcher_zone_row)
            arsenal_adj = compute_arsenal_matchup_adjustment(batter_arsenal_row, pitcher_arsenal_row)
            family_zone_adj = compute_family_zone_matchup_adjustment(
                batter_family_zone_row,
                pitcher_family_zone_row,
            )
            # NOTE(review): a falsy (None or 0.0) family-zone whiff risk falls back
            # to zone_adj["hit_zone_boost"]. By its name that may be a
            # hitter-favorable signal rather than a whiff risk; confirm intent.
            zone_val = _safe_float(
                family_zone_adj.get("family_zone_whiff_risk")
                or zone_adj.get("hit_zone_boost")
            )
            arsenal_val = _safe_float(arsenal_adj.get("arsenal_whiff_risk"))
            if zone_val is not None:
                family_zone_risks.append(zone_val)
            if arsenal_val is not None:
                arsenal_risks.append(arsenal_val)
        if family_zone_risks:
            lineup_family_zone_risk = float(sum(family_zone_risks) / len(family_zone_risks))
        if arsenal_risks:
            lineup_arsenal_risk = float(sum(arsenal_risks) / len(arsenal_risks))

    # Lineup-level whiff risks; a falsy value means neither source produced data.
    arsenal_risk = lineup_arsenal_risk or opponent_overlay.get("lineup_whiff_risk")
    family_zone_risk = lineup_family_zone_risk or opponent_overlay.get("lineup_zone_whiff_risk")

    swstr = _safe_float(pitcher_row.get("swstr_rate"))
    csw = _safe_float(pitcher_row.get("csw_rate"))
    ball = _safe_float(pitcher_row.get("ball_rate"))
    # Parse via _safe_float: int() would raise on NaN or on strings like "450.0".
    sample_size = int(_safe_float(pitcher_row.get("sample_size")) or 0)

    # Shrinkage weights n / (n + k) for each evidence source.
    reliability = _reliability(sample_size, k=180.0)
    lineup_reliability = _reliability(opponent_overlay.get("lineup_sample_size"), k=6.0)
    traj_reliability = _reliability(traj_row.get("trajectory_sample_size"), k=220.0)
    seq_reliability = _reliability(sequencing.get("sequencing_sample_size"), k=220.0)

    # --- Expected strikeouts: baseline plus reliability-weighted shifts ---
    expected_ks = 4.4
    applied_layers: list[str] = []
    skipped_layers: list[str] = []
    reasons_for: list[str] = []
    reasons_against: list[str] = []

    if swstr is not None:
        shift = ((swstr - 0.11) * 20.0) * reliability
        expected_ks += shift
        applied_layers.append("swstr")
        if shift >= 0.30:
            reasons_for.append("Misses bats consistently")
        elif shift <= -0.25:
            reasons_against.append("Swinging-strike rate is light")

    if csw is not None:
        shift = ((csw - 0.28) * 10.0) * reliability
        expected_ks += shift
        applied_layers.append("csw")
        if shift >= 0.25:
            reasons_for.append("Strong called plus whiff strike mix")
        elif shift <= -0.20:
            reasons_against.append("CSW profile is weak")

    if ball is not None:
        # Fewer balls than the 0.36 reference raises the projection.
        shift = ((0.36 - ball) * 8.0) * reliability
        expected_ks += shift
        applied_layers.append("ball_rate")
        if shift >= 0.20:
            reasons_for.append("Limits free balls and stays in leverage counts")
        elif shift <= -0.20:
            reasons_against.append("High ball rate can shorten outings")

    # Without lineup whiff data these layers are skipped. Treating the 0.0
    # placeholder as a real whiff rate would subtract up to 1.5 (arsenal) and
    # 1.2 (location) strikeouts at full reliability, and would emit
    # "resists"/"handles" reasons that have no data behind them.
    if arsenal_risk:
        arsenal_shift = (arsenal_risk - 0.25) * 6.0 * lineup_reliability
        expected_ks += arsenal_shift
        if abs(arsenal_shift) > 1e-6:
            applied_layers.append("arsenal")
        if arsenal_shift >= 0.20:
            reasons_for.append("Opponent whiff profile fits the arsenal mix")
        elif arsenal_shift <= -0.15:
            reasons_against.append("Opponent profile resists the primary mix")
    else:
        skipped_layers.append("arsenal")

    if family_zone_risk:
        family_zone_shift = (family_zone_risk - 0.24) * 5.0 * lineup_reliability
        expected_ks += family_zone_shift
        if abs(family_zone_shift) > 1e-6:
            applied_layers.append("location")
        if family_zone_shift >= 0.18:
            reasons_for.append("Location profile creates chase and miss risk")
        elif family_zone_shift <= -0.14:
            reasons_against.append("Lineup handles these family-zone looks well")
    else:
        skipped_layers.append("location")

    tunnel = _safe_float(traj_row.get("tunnel_score"))
    release_consistency = _safe_float(traj_row.get("release_consistency_score"))
    if tunnel is not None:
        shift = ((tunnel - 0.50) * 1.6) * traj_reliability
        expected_ks += shift
        applied_layers.append("tunneling")
        if shift >= 0.10:
            reasons_for.append("Strong pitch tunneling")
        elif shift <= -0.10:
            reasons_against.append("Tunneling is below average")

    if release_consistency is not None:
        shift = ((release_consistency - 0.50) * 1.2) * traj_reliability
        expected_ks += shift
        applied_layers.append("release")
        if shift >= 0.08:
            reasons_for.append("Repeatable release supports command")
        elif shift <= -0.08:
            reasons_against.append("Release consistency is shaky")

    sequencing_score = _safe_float(sequencing.get("sequencing_score"))
    if sequencing_score is not None:
        shift = ((sequencing_score - 0.50) * 1.0) * seq_reliability
        expected_ks += shift
        applied_layers.append("sequencing")
        if shift >= 0.08:
            reasons_for.append("Sequencing keeps hitters off balance")
        elif shift <= -0.08:
            reasons_against.append("Pitch sequencing looks predictable")

    # Price the same bounded mean that is reported as expected_strikeouts.
    expected_ks = _clamp(expected_ks, 1.0, 12.0)

    line_value = float(line)
    if selection_side == "over":
        raw_prob = _poisson_prob_over(expected_ks, line_value)
    else:
        raw_prob = _poisson_prob_under(expected_ks, line_value)
    calibrated_prob = _calibrate(raw_prob)

    # --- Confidence score and its breakdown ------------------------------
    confidence = 52.0
    confidence_reasons: list[str] = []
    confidence_component_bonuses: list[dict[str, Any]] = []
    confidence_component_penalties: list[dict[str, Any]] = []

    if sample_size >= 400:
        confidence += 10
        confidence_component_bonuses.append(_confidence_component("Strong pitcher sample", 10, "bonus"))
    elif sample_size < 150:
        confidence -= 12
        confidence_reasons.append("Limited pitcher pitch sample")
        confidence_component_penalties.append(_confidence_component("Limited pitcher pitch sample", 12, "penalty"))

    if opponent_overlay.get("lineup_sample_size", 0) >= 7:
        confidence += 8
        confidence_component_bonuses.append(_confidence_component("Projected lineup mostly complete", 8, "bonus"))
    else:
        confidence -= 6
        confidence_reasons.append("Projected opponent lineup is incomplete")
        confidence_component_penalties.append(_confidence_component("Projected opponent lineup is incomplete", 6, "penalty"))

    if traj_reliability >= 0.45:
        confidence += 5
        confidence_component_bonuses.append(_confidence_component("Strong telemetry coverage", 5, "bonus"))
    else:
        # Informational only: recorded with a 0-point value.
        confidence_reasons.append("Trajectory/tunneling sample is thin")
        confidence_component_penalties.append(_confidence_component("Trajectory/tunneling sample is thin", 0, "penalty"))

    if seq_reliability >= 0.40:
        confidence += 4
        confidence_component_bonuses.append(_confidence_component("Sequencing sample is stable", 4, "bonus"))
    else:
        # Informational only: recorded with a 0-point value.
        confidence_reasons.append("Sequencing signal is still noisy")
        confidence_component_penalties.append(_confidence_component("Sequencing signal is still noisy", 0, "penalty"))

    if abs(calibrated_prob - 0.50) > 0.28:
        confidence -= 5
        confidence_reasons.append("Fair probability is still high-variance")
        confidence_component_penalties.append(_confidence_component("Fair probability is still high-variance", 5, "penalty"))

    confidence_raw = _clamp(confidence, 1.0, 100.0)

    # The largest non-zero penalty drives the summary; otherwise the largest bonus.
    primary_penalty = max(
        [item for item in confidence_component_penalties if float(item.get("value") or 0.0) > 0.0],
        key=lambda item: float(item.get("value") or 0.0),
        default=None,
    )
    primary_bonus = max(
        [item for item in confidence_component_bonuses if float(item.get("value") or 0.0) > 0.0],
        key=lambda item: float(item.get("value") or 0.0),
        default=None,
    )
    primary_driver = primary_penalty or primary_bonus
    summary_label = str((primary_driver or {}).get("label") or "").strip() or None

    result.update(
        {
            "raw_k_prob": raw_prob,
            "calibrated_k_prob": calibrated_prob,
            "fair_prob": calibrated_prob,
            "expected_strikeouts": expected_ks,
            "pitcher_swstr_rate": swstr,
            "pitcher_csw_rate": csw,
            "pitcher_ball_rate": ball,
            "arsenal_whiff_risk": arsenal_risk,
            "family_zone_whiff_risk": family_zone_risk,
            "zone_whiff_risk": family_zone_risk,
            "trajectory_tunnel_score": tunnel,
            "trajectory_release_consistency_score": release_consistency,
            "sequencing_score": sequencing_score,
            "confidence_score": confidence_raw,
            "confidence_score_raw": confidence_raw,
            "confidence_score_display": confidence_raw,
            "confidence_bucket": _bucket(confidence_raw),
            "confidence_reasons": confidence_reasons[:5],
            "confidence_component_bonuses": confidence_component_bonuses,
            "confidence_component_penalties": confidence_component_penalties,
            "confidence_primary_driver": primary_driver,
            "confidence_summary_label": summary_label,
            "applied_layers": "|".join(applied_layers),
            "skipped_layers": "|".join(skipped_layers),
            "reason_tags_for": reasons_for[:4],
            "reason_tags_against": reasons_against[:4],
            "pitcher_reliability": reliability,
            "lineup_reliability": lineup_reliability,
            "trajectory_reliability": traj_reliability,
            "sequencing_reliability": seq_reliability,
        }
    )
    return result