2026_MLB_Model / models /shared_matchup_engine.py
Syntrex's picture
Optimize props load path and reuse modeled state
2885bcc
raw
history blame
15.3 kB
from __future__ import annotations
from typing import Any
import pandas as pd
from models.arsenal_matchup_model import compute_arsenal_matchup_adjustment
from models.batter_arsenal_model import build_batter_arsenal_feature_row
from models.batter_zone_model import build_batter_zone_feature_row
from models.family_zone_profile_store import (
build_batter_family_zone_feature_row,
build_pitcher_family_zone_feature_row,
)
from models.matchup_model import compute_family_zone_matchup_adjustment
from models.pitch_sequence_model import build_sequence_features, predict_next_pitch_distribution
from models.pitcher_arsenal_model import build_pitcher_arsenal_feature_row
from models.pitcher_zone_model import build_pitcher_zone_feature_row
from models.trajectory_model import build_trajectory_features
from models.zone_matchup_model import compute_zone_matchup_adjustment
# (balls, strikes) counts for which a next-pitch distribution is predicted.
# NOTE(review): no listed count has balls >= 2 with fewer than two strikes, so
# the "hitter_ahead" leverage label is never produced from this tuple --
# confirm that leaving out 2-0, 3-0, 2-1 and 3-1 is intentional.
_COUNT_STATES: tuple[tuple[int, int], ...] = (
    (0, 0), (1, 0), (0, 1), (1, 1),
    (0, 2), (1, 2), (2, 2), (3, 2),
)

# Pitch-family labels; a sequence profile is read through "<family>_prob" keys.
_PITCH_FAMILIES: tuple[str, ...] = ("fastball", "breaking", "offspeed")

# Location buckets looked up inside a sequence profile's "zone_probs" mapping.
_ZONES: tuple[str, ...] = ("heart", "shadow", "chase", "waste")
def _safe_float(value: Any, default: float = 0.0) -> float:
    """Convert ``value`` to ``float``, falling back to ``default`` when unusable.

    ``default`` is returned for ``None``, for blank text, for the textual
    markers ``"nan"`` and ``"none"`` (case-insensitive, surrounding whitespace
    ignored), for anything ``float()`` rejects, and for any value that converts
    to NaN.

    Args:
        value: Arbitrary input such as a number, a numeric string or ``None``.
        default: Value returned when ``value`` cannot be used.

    Returns:
        The converted float, or ``default``.
    """
    if value is None:
        return default
    try:
        if str(value).strip().lower() in {"", "nan", "none"}:
            return default
        number = float(value)
    except Exception:
        # Deliberately broad: this is a best-effort parser and callers rely on
        # it never raising, whatever object they hand in.
        return default
    # Signed spellings such as "-nan" get past the text filter above yet still
    # parse to NaN. NaN is the only float that is not equal to itself.
    if number != number:
        return default
    return number
def _clamp(value: float, lo: float, hi: float) -> float:
    """Limit ``value`` to the range between ``lo`` and ``hi``.

    The upper bound is applied first and the lower bound second, so ``lo`` is
    returned when the bounds are inverted (``lo > hi``).
    """
    capped = min(hi, value)
    return max(lo, capped)
def _reliability(sample_size: Any, k: float) -> float:
    """Map a sample size to a weight in ``[0, 1]`` using ``n / (n + k)``.

    Args:
        sample_size: Observed sample count; parsed with ``_safe_float`` and
            floored at zero, so missing or negative input yields ``0.0``.
        k: Constant added to the sample in the denominator; values below
            ``1.0`` are raised to ``1.0``, which keeps the denominator positive.

    Returns:
        The clamped ratio; it grows toward ``1.0`` as the sample grows.
    """
    sample = max(0.0, _safe_float(sample_size, 0.0))
    damping = max(1.0, float(k))
    ratio = sample / (sample + damping)
    return _clamp(ratio, 0.0, 1.0)
def _component_source_map() -> dict[str, dict[str, str]]:
    """Describe which module backs each matchup component.

    Returns:
        A freshly built mapping from component name to a dict holding its
        ``"classification"`` and ``"source_module"`` strings. A new object is
        created on every call, so callers may mutate the result freely.
    """
    upgraded = "upgrade_existing_module"
    reused = "reuse_as_is"
    provenance = (
        ("zone_matchup", upgraded, "models.zone_matchup_model"),
        ("family_zone_matchup", reused, "models.matchup_model"),
        ("arsenal_matchup", upgraded, "models.arsenal_matchup_model"),
        ("trajectory", reused, "models.trajectory_model"),
        ("sequencing", upgraded, "models.pitch_sequence_model"),
        ("count_context", upgraded, "models.pitch_sequence_model"),
        ("shared_composer", "new_source_of_truth_component", "models.shared_matchup_engine"),
    )
    return {
        component: {"classification": classification, "source_module": module}
        for component, classification, module in provenance
    }
def _runtime_bucket(runtime_cache: dict[str, Any] | None, key: str) -> dict[str, Any]:
    """Return the named sub-dictionary of ``runtime_cache``, creating it if needed.

    Args:
        runtime_cache: Shared cache dict, or ``None`` when caching is disabled.
        key: Name of the bucket inside the cache.

    Returns:
        The bucket stored under ``key``. A missing or non-dict entry is replaced
        by a new empty dict that is stored back into the cache. When
        ``runtime_cache`` is ``None`` a throwaway empty dict is returned.
    """
    if runtime_cache is None:
        return {}
    existing = runtime_cache.get(key)
    if isinstance(existing, dict):
        return existing
    fresh: dict[str, Any] = {}
    runtime_cache[key] = fresh
    return fresh
def _cache_get_or_build(
    runtime_cache: dict[str, Any] | None,
    bucket_name: str,
    cache_key: tuple[Any, ...],
    builder,
):
    """Return the cached value for ``cache_key``, building it on a miss.

    Args:
        runtime_cache: Shared cache dict, or ``None`` to bypass caching.
        bucket_name: Name of the bucket inside ``runtime_cache`` to use.
        cache_key: Hashable key identifying the value within the bucket.
        builder: Zero-argument callable that produces the value.

    Returns:
        The stored value on a hit; otherwise the result of ``builder()``, which
        is stored first unless ``runtime_cache`` is ``None``.
    """
    if runtime_cache is None:
        # No cache to consult: build every time.
        return builder()
    bucket = _runtime_bucket(runtime_cache, bucket_name)
    if cache_key in bucket:
        return bucket[cache_key]
    built = builder()
    bucket[cache_key] = built
    return built
def _build_pitch_zone_mix(
    sequence_profiles: dict[str, dict[str, Any]],
) -> dict[str, float]:
    """Blend per-count profiles into one ``"<family>_<zone>"`` distribution.

    Each profile contributes ``family_prob * zone_prob`` for every family/zone
    pair, and every count carries the same weight. The blended weights are
    rescaled to sum to one when their total is positive.

    Args:
        sequence_profiles: Mapping of count label to a profile exposing
            ``"<family>_prob"`` entries and an optional ``"zone_probs"`` mapping.

    Returns:
        Weights keyed by ``"<family>_<zone>"`` in family-major order, or an
        empty dict when no profiles are supplied.
    """
    if not sequence_profiles:
        return {}

    per_count_weight = 1.0 / float(len(sequence_profiles))
    mix: dict[str, float] = {
        f"{family}_{zone}": 0.0 for family in _PITCH_FAMILIES for zone in _ZONES
    }

    for profile in sequence_profiles.values():
        zone_lookup = profile.get("zone_probs", {}) or {}
        # Zone shares do not depend on the family, so parse them once per profile.
        zone_shares = [(zone, _safe_float(zone_lookup.get(zone))) for zone in _ZONES]
        for family in _PITCH_FAMILIES:
            family_share = _safe_float(profile.get(f"{family}_prob"))
            for zone, zone_share in zone_shares:
                mix[f"{family}_{zone}"] += family_share * zone_share * per_count_weight

    mass = sum(mix.values())
    if mass > 0:
        return {cell: weight / mass for cell, weight in mix.items()}
    return mix
def _build_pitch_family_mix(
    sequence_profiles: dict[str, dict[str, Any]],
) -> dict[str, float]:
    """Average the per-count family probabilities into one family mix.

    Every count carries the same weight. The averaged values are rescaled to
    sum to one when their total is positive.

    Args:
        sequence_profiles: Mapping of count label to a profile exposing
            ``"<family>_prob"`` entries.

    Returns:
        A weight for every entry of ``_PITCH_FAMILIES``; all zeros when no
        profiles are supplied.
    """
    if not sequence_profiles:
        return {family: 0.0 for family in _PITCH_FAMILIES}

    per_count_weight = 1.0 / float(len(sequence_profiles))
    profiles = list(sequence_profiles.values())

    blended: dict[str, float] = {}
    for family in _PITCH_FAMILIES:
        prob_key = f"{family}_prob"
        running = 0.0
        for profile in profiles:
            running += _safe_float(profile.get(prob_key)) * per_count_weight
        blended[family] = running

    mass = sum(blended.values())
    if mass > 0:
        return {family: weight / mass for family, weight in blended.items()}
    return blended
def _top_regions(weighted_map: dict[str, float], limit: int = 4) -> list[dict[str, Any]]:
    """Rank the positively weighted regions, highest score first.

    Args:
        weighted_map: Mapping of region label to weight.
        limit: Maximum number of rows to return.

    Returns:
        Up to ``limit`` dicts of the form ``{"region": label, "score": weight}``
        with the weight rounded to six decimals. Entries whose weight is not
        strictly positive are dropped, and equal scores keep their input order.
    """
    candidates: list[dict[str, Any]] = []
    for region, raw_weight in weighted_map.items():
        weight = float(raw_weight)
        if weight > 0:
            candidates.append({"region": region, "score": round(weight, 6)})
    ranked = sorted(candidates, key=lambda entry: entry["score"], reverse=True)
    return ranked[:limit]
def compose_shared_matchup_context(
    *,
    batter_name: str,
    pitcher_name: str,
    batter_statcast_df: pd.DataFrame | None,
    pitcher_statcast_df: pd.DataFrame | None,
    batter_features: dict[str, Any] | None = None,
    pitcher_row: dict[str, Any] | None = None,
    game_row: dict[str, Any] | None = None,
    runtime_cache: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Build the shared batter-versus-pitcher matchup context.

    A next-pitch distribution is predicted for every count in
    ``_COUNT_STATES``. The per-count profiles are blended into a family mix and
    a family-by-zone mix, which weight the arsenal and zone matchup
    adjustments. The family-zone matchup, trajectory features, ranked
    attack/damage/whiff regions and a coverage confidence score are added to
    the same payload.

    Args:
        batter_name: Batter name; a blank value yields the placeholder payload.
        pitcher_name: Pitcher name; a blank value yields the placeholder payload.
        batter_statcast_df: Frame handed to the batter feature builders;
            ``None`` is treated as an empty frame.
        pitcher_statcast_df: Frame handed to the pitcher feature builders; when
            ``None`` the batter frame is used in its place.
        batter_features: Optional batter attributes. ``"batter_stand"`` is read
            here and the whole mapping is passed on as ``batter_row``.
        pitcher_row: Optional pitcher attributes. ``"p_throws"`` is read here
            and the whole mapping is passed to the sequence feature builder.
        game_row: Optional game attributes. ``"away_team"``, ``"home_team"``,
            ``"projected_starter_match_status"`` and ``"pitcher_id"`` are read
            here; the mapping plus the count is passed to the sequence builder.
        runtime_cache: Optional mutable dict used to memoise the feature rows
            and the finished payload. It is updated in place.

    Returns:
        The context dictionary. When either frame is empty or either name is
        blank, a placeholder with empty containers and a confidence of ``0.0``
        is returned; that placeholder carries no ``"_component_rows"`` entry
        and is never cached. On a cache hit the stored dict itself is returned,
        not a copy.
    """
    batter_df = pd.DataFrame() if batter_statcast_df is None else batter_statcast_df
    pitcher_df = batter_df if pitcher_statcast_df is None else pitcher_statcast_df

    if batter_df.empty or pitcher_df.empty or not batter_name or not pitcher_name:
        # Nothing to model: hand back the neutral payload.
        return {
            "expected_pitch_mix_by_count": {},
            "expected_zone_mix_by_count": {},
            "expected_pitch_zone_mix_by_count": {},
            "expected_pitch_family_mix": {},
            "tunnel_pair_scores": [],
            "predicted_attack_regions": [],
            "predicted_damage_regions": [],
            "predicted_whiff_regions": [],
            "handedness_context": {},
            "count_context_profile": {},
            "matchup_coverage_confidence": 0.0,
            "component_source_map": _component_source_map(),
            "zone_matchup": {},
            "family_zone_matchup": {},
            "arsenal_matchup": {},
            "trajectory": {},
            "sequence_profiles": {},
        }

    # Work on private copies so the caller's mappings are never mutated.
    batter_features = dict(batter_features or {})
    pitcher_row = dict(pitcher_row or {})
    game_row = dict(game_row or {})

    # Normalised identifiers shared by every cache key below.
    batter_key = str(batter_name or "").strip().lower()
    pitcher_key = str(pitcher_name or "").strip().lower()
    pitcher_id_key = str(game_row.get("pitcher_id") or "").strip()
    batter_stand = str(batter_features.get("batter_stand") or "").strip().upper()
    pitcher_hand = str(pitcher_row.get("p_throws") or "").strip().upper()

    cache_key = (
        batter_key,
        pitcher_key,
        *(
            str(game_row.get(field) or "").strip().lower()
            for field in ("away_team", "home_team", "projected_starter_match_status")
        ),
        pitcher_id_key,
        batter_stand,
        pitcher_hand,
    )
    shared_bucket = _runtime_bucket(runtime_cache, "shared_matchup_context")
    if runtime_cache is not None and cache_key in shared_bucket:
        return shared_bucket[cache_key]

    def cached(bucket_name: str, key: tuple[Any, ...], builder):
        """Memoise ``builder`` in ``runtime_cache`` under ``bucket_name``."""
        return _cache_get_or_build(runtime_cache, bucket_name, key, builder)

    # Feature rows are keyed by the identity of the frame they were built from.
    batter_frame_key = (id(batter_df), batter_key)
    pitcher_frame_key = (id(pitcher_df), pitcher_key)

    handedness_context = {
        "batter_stand": batter_stand,
        "pitcher_hand": pitcher_hand,
    }

    pitcher_family_zone_row = cached(
        "pitcher_family_zone_rows",
        pitcher_frame_key,
        lambda: build_pitcher_family_zone_feature_row(pitcher_df, pitcher_name),
    )

    # Predict a pitch distribution for each count.
    count_context_profile: dict[str, dict[str, Any]] = {}
    expected_pitch_mix_by_count: dict[str, dict[str, float]] = {}
    expected_zone_mix_by_count: dict[str, dict[str, float]] = {}
    sequence_profiles: dict[str, dict[str, Any]] = {}
    for balls, strikes in _COUNT_STATES:
        count_key = f"{balls}-{strikes}"
        seq_profile = predict_next_pitch_distribution(
            build_sequence_features(
                game_row=dict(game_row, balls=balls, strikes=strikes),
                pitcher_row=pitcher_row,
                batter_row=batter_features,
                pitcher_family_zone_row=pitcher_family_zone_row,
            )
        )
        sequence_profiles[count_key] = seq_profile

        pitch_mix: dict[str, float] = {}
        for family in _PITCH_FAMILIES:
            pitch_mix[family] = round(_safe_float(seq_profile.get(f"{family}_prob")), 6)
        expected_pitch_mix_by_count[count_key] = pitch_mix

        zone_mix: dict[str, float] = {}
        for zone in _ZONES:
            zone_mix[zone] = round(
                _safe_float((seq_profile.get("zone_probs") or {}).get(zone)), 6
            )
        expected_zone_mix_by_count[count_key] = zone_mix

        # Two strikes takes precedence over the ball count.
        leverage = "putaway" if strikes >= 2 else ("hitter_ahead" if balls >= 2 else "neutral")
        count_context_profile[count_key] = {
            "balls": balls,
            "strikes": strikes,
            "count_leverage": leverage,
        }

    expected_pitch_zone_mix_by_count = _build_pitch_zone_mix(sequence_profiles)
    expected_pitch_family_mix = _build_pitch_family_mix(sequence_profiles)

    # Zone matchup.
    batter_zone_row = cached(
        "batter_zone_rows",
        batter_frame_key,
        lambda: build_batter_zone_feature_row(batter_df, batter_name),
    )
    pitcher_zone_row = cached(
        "pitcher_zone_rows",
        pitcher_frame_key,
        lambda: build_pitcher_zone_feature_row(pitcher_df, pitcher_name),
    )
    zone_matchup = compute_zone_matchup_adjustment(
        batter_zone_row,
        pitcher_zone_row,
        pitch_zone_weights=expected_pitch_zone_mix_by_count,
        handedness_context=handedness_context,
    )

    # Family-zone matchup.
    batter_family_zone_row = cached(
        "batter_family_zone_rows",
        batter_frame_key,
        lambda: build_batter_family_zone_feature_row(batter_df, batter_name),
    )
    family_zone_matchup = compute_family_zone_matchup_adjustment(
        batter_family_zone_row=batter_family_zone_row,
        pitcher_family_zone_row=pitcher_family_zone_row,
    )

    # Arsenal matchup.
    batter_arsenal_row = cached(
        "batter_arsenal_rows",
        batter_frame_key,
        lambda: build_batter_arsenal_feature_row(batter_df, batter_name),
    )
    pitcher_arsenal_row = cached(
        "pitcher_arsenal_rows",
        pitcher_frame_key,
        lambda: build_pitcher_arsenal_feature_row(pitcher_df, pitcher_name),
    )
    arsenal_matchup = compute_arsenal_matchup_adjustment(
        batter_arsenal_row=batter_arsenal_row,
        pitcher_arsenal_row=pitcher_arsenal_row,
        pitch_family_weights=expected_pitch_family_mix,
        handedness_context=handedness_context,
    )

    # Trajectory features.
    trajectory = cached(
        "trajectory_rows",
        (id(pitcher_df), pitcher_key, pitcher_id_key),
        lambda: build_trajectory_features(
            statcast_df=pitcher_df,
            pitcher_name=pitcher_name,
            pitcher_id=game_row.get("pitcher_id"),
        ),
    )

    # Score each family/zone cell by how often it is attacked and how each side
    # fares there.
    damage_region_map: dict[str, float] = {}
    whiff_region_map: dict[str, float] = {}
    for cell, attack_weight in expected_pitch_zone_mix_by_count.items():
        family, separator, zone = cell.partition("_")
        if not separator:
            # Not a "<family>_<zone>" key.
            continue
        suffix = f"{family}_{zone}"
        batter_damage = _safe_float(batter_family_zone_row.get(f"damage_rate_{suffix}"))
        pitcher_damage = _safe_float(
            pitcher_family_zone_row.get(f"damage_allowed_rate_{suffix}")
        )
        batter_whiff = _safe_float(batter_family_zone_row.get(f"whiff_rate_{suffix}"))
        pitcher_whiff = _safe_float(pitcher_family_zone_row.get(f"whiff_rate_{suffix}"))
        damage_region_map[cell] = attack_weight * (
            (batter_damage * 0.6) + (pitcher_damage * 0.4)
        )
        whiff_region_map[cell] = attack_weight * (
            (batter_whiff * 0.55) + (pitcher_whiff * 0.45)
        )

    # Emit a tunnel row only when at least one of the two scores is non-zero.
    tunnel_score = _safe_float(trajectory.get("tunnel_score"))
    release_score = _safe_float(trajectory.get("release_consistency_score"))
    tunnel_pair_scores: list[dict[str, Any]] = []
    if tunnel_score or release_score:
        tunnel_pair_scores.append(
            {
                "pair": "arsenal_tunnel_profile",
                "tunnel_score": round(tunnel_score, 6),
                "release_consistency_score": round(release_score, 6),
                "deception_score": round(_safe_float(trajectory.get("deception_score")), 6),
            }
        )

    # Coverage confidence is the mean sample-size reliability of every input.
    coverage_inputs = (
        (batter_zone_row, "zone_sample_size", 200.0),
        (pitcher_zone_row, "zone_sample_size", 200.0),
        (batter_family_zone_row, "family_zone_sample_size", 220.0),
        (pitcher_family_zone_row, "family_zone_sample_size", 220.0),
        (batter_arsenal_row, "arsenal_sample_size", 180.0),
        (pitcher_arsenal_row, "arsenal_sample_size", 180.0),
        (trajectory, "trajectory_sample_size", 240.0),
    )
    coverage_signals = [
        _reliability(row.get(field), k) for row, field, k in coverage_inputs
    ]
    matchup_coverage_confidence = round(sum(coverage_signals) / len(coverage_signals), 4)

    result = {
        "expected_pitch_mix_by_count": expected_pitch_mix_by_count,
        "expected_zone_mix_by_count": expected_zone_mix_by_count,
        "expected_pitch_zone_mix_by_count": expected_pitch_zone_mix_by_count,
        "expected_pitch_family_mix": expected_pitch_family_mix,
        "tunnel_pair_scores": tunnel_pair_scores,
        "predicted_attack_regions": _top_regions(expected_pitch_zone_mix_by_count, limit=5),
        "predicted_damage_regions": _top_regions(damage_region_map, limit=5),
        "predicted_whiff_regions": _top_regions(whiff_region_map, limit=5),
        "handedness_context": handedness_context,
        "count_context_profile": count_context_profile,
        "matchup_coverage_confidence": matchup_coverage_confidence,
        "component_source_map": _component_source_map(),
        "zone_matchup": zone_matchup,
        "family_zone_matchup": family_zone_matchup,
        "arsenal_matchup": arsenal_matchup,
        "trajectory": trajectory,
        "sequence_profiles": sequence_profiles,
        "_component_rows": {
            "batter_zone_row": batter_zone_row,
            "pitcher_zone_row": pitcher_zone_row,
            "batter_family_zone_row": batter_family_zone_row,
            "pitcher_family_zone_row": pitcher_family_zone_row,
            "batter_arsenal_row": batter_arsenal_row,
            "pitcher_arsenal_row": pitcher_arsenal_row,
        },
    }
    if runtime_cache is not None:
        shared_bucket[cache_key] = result
    return result