"""Compose a shared batter-vs-pitcher matchup context.

Combines the zone, family-zone, arsenal, trajectory and pitch-sequencing
models into one dictionary (see ``compose_shared_matchup_context``).
"""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

import pandas as pd

from models.arsenal_matchup_model import compute_arsenal_matchup_adjustment
from models.batter_arsenal_model import build_batter_arsenal_feature_row
from models.batter_zone_model import build_batter_zone_feature_row
from models.family_zone_profile_store import (
    build_batter_family_zone_feature_row,
    build_pitcher_family_zone_feature_row,
)
from models.matchup_model import compute_family_zone_matchup_adjustment
from models.pitch_sequence_model import (
    build_sequence_features,
    predict_next_pitch_distribution,
)
from models.pitcher_arsenal_model import build_pitcher_arsenal_feature_row
from models.pitcher_zone_model import build_pitcher_zone_feature_row
from models.trajectory_model import build_trajectory_features
from models.zone_matchup_model import compute_zone_matchup_adjustment

# (balls, strikes) counts for which a next-pitch distribution is predicted.
# The per-count profiles are later averaged with equal weight.
_COUNT_STATES: tuple[tuple[int, int], ...] = (
    (0, 0),
    (1, 0),
    (0, 1),
    (1, 1),
    (0, 2),
    (1, 2),
    (2, 2),
    (3, 2),
)
_PITCH_FAMILIES: tuple[str, ...] = ("fastball", "breaking", "offspeed")
_ZONES: tuple[str, ...] = ("heart", "shadow", "chase", "waste")

# Keys of the per-component feature rows exposed under ``_component_rows``.
_COMPONENT_ROW_KEYS: tuple[str, ...] = (
    "batter_zone_row",
    "pitcher_zone_row",
    "batter_family_zone_row",
    "pitcher_family_zone_row",
    "batter_arsenal_row",
    "pitcher_arsenal_row",
)


def _safe_float(value: Any, default: float = 0.0) -> float:
    """Coerce *value* to ``float``, falling back to *default* when that is impossible.

    ``None``, empty/whitespace-only strings and anything whose text form is
    ``"nan"`` or ``"none"`` (case-insensitive, so float NaN too) map to *default*.
    """
    try:
        if value is None:
            return default
        text = str(value).strip().lower()
        if text in {"", "nan", "none"}:
            return default
        return float(value)
    except Exception:
        # Deliberate best-effort coercion: any conversion failure yields the default.
        return default


def _clamp(value: float, lo: float, hi: float) -> float:
    """Limit *value* to the closed interval [lo, hi]."""
    return max(lo, min(hi, value))


def _normalize_name(value: Any) -> str:
    """Return *value* as a stripped, lower-cased string ("" for falsy input)."""
    return str(value or "").strip().lower()


def _reliability(sample_size: Any, k: float) -> float:
    """Map a sample size to a [0, 1] reliability score.

    Computes ``n / (n + k)``. Here ``n`` is the sample size, floored at 0 and
    parsed with ``_safe_float``, and ``k`` is floored at 1.0.
    """
    sample = max(0.0, _safe_float(sample_size, 0.0))
    return _clamp(sample / (sample + max(1.0, float(k))), 0.0, 1.0)


def _component_source_map() -> dict[str, dict[str, str]]:
    """Describe, per output component, its classification and source module."""
    return {
        "zone_matchup": {"classification": "upgrade_existing_module", "source_module": "models.zone_matchup_model"},
        "family_zone_matchup": {"classification": "reuse_as_is", "source_module": "models.matchup_model"},
        "arsenal_matchup": {"classification": "upgrade_existing_module", "source_module": "models.arsenal_matchup_model"},
        "trajectory": {"classification": "reuse_as_is", "source_module": "models.trajectory_model"},
        "sequencing": {"classification": "upgrade_existing_module", "source_module": "models.pitch_sequence_model"},
        "count_context": {"classification": "upgrade_existing_module", "source_module": "models.pitch_sequence_model"},
        "shared_composer": {"classification": "new_source_of_truth_component", "source_module": "models.shared_matchup_engine"},
    }


def _runtime_bucket(runtime_cache: dict[str, Any] | None, key: str) -> dict[str, Any]:
    """Return the named sub-dictionary of *runtime_cache*, creating it if needed.

    When *runtime_cache* is ``None`` a fresh, throwaway ``{}`` is returned.
    An existing non-dict value under *key* is replaced by an empty dict.
    """
    if runtime_cache is None:
        return {}
    bucket = runtime_cache.get(key)
    if not isinstance(bucket, dict):
        bucket = {}
        runtime_cache[key] = bucket
    return bucket


def _cache_get_or_build(
    runtime_cache: dict[str, Any] | None,
    bucket_name: str,
    cache_key: tuple[Any, ...],
    builder: Callable[[], Any],
) -> Any:
    """Memoise ``builder()`` under ``runtime_cache[bucket_name][cache_key]``.

    Without a cache, ``builder()`` runs on every call.

    NOTE(review): callers key on ``id(DataFrame)``. Python only guarantees ids
    are unique while the object is alive. If the cache outlives a frame, a new
    frame can reuse the id and hit a stale entry. Confirm the cache lifetime
    does not exceed the frames' lifetime.
    """
    if runtime_cache is None:
        return builder()
    bucket = _runtime_bucket(runtime_cache, bucket_name)
    if cache_key not in bucket:
        bucket[cache_key] = builder()
    return bucket[cache_key]


def _normalize_weights(weights: dict[str, float]) -> dict[str, float]:
    """Rescale *weights* to sum to 1; return them unchanged if the sum is not positive."""
    total = sum(weights.values())
    if total > 0:
        return {key: value / total for key, value in weights.items()}
    return weights


def _build_pitch_zone_mix(
    sequence_profiles: dict[str, dict[str, Any]],
) -> dict[str, float]:
    """Average per-count (family x zone) probabilities into one normalised mix.

    Each count contributes equally. Within a count, the weight of
    ``"<family>_<zone>"`` is ``<family>_prob * zone_probs[zone]``. Returns
    ``{}`` when no profiles are given.
    """
    combined: dict[str, float] = {}
    if not sequence_profiles:
        return combined
    count_weight = 1.0 / float(len(sequence_profiles))
    for payload in sequence_profiles.values():
        zone_probs = payload.get("zone_probs", {}) or {}
        zone_weights = [(zone, _safe_float(zone_probs.get(zone))) for zone in _ZONES]
        for family in _PITCH_FAMILIES:
            family_prob = _safe_float(payload.get(f"{family}_prob"))
            for zone, zone_prob in zone_weights:
                region = f"{family}_{zone}"
                combined[region] = combined.get(region, 0.0) + (family_prob * zone_prob * count_weight)
    return _normalize_weights(combined)


def _build_pitch_family_mix(
    sequence_profiles: dict[str, dict[str, Any]],
) -> dict[str, float]:
    """Average per-count pitch-family probabilities into one normalised mix.

    Always returns every family in ``_PITCH_FAMILIES``. Values are all zero
    when no profiles are given.
    """
    combined = {family: 0.0 for family in _PITCH_FAMILIES}
    if not sequence_profiles:
        return combined
    count_weight = 1.0 / float(len(sequence_profiles))
    for payload in sequence_profiles.values():
        for family in _PITCH_FAMILIES:
            combined[family] += _safe_float(payload.get(f"{family}_prob")) * count_weight
    return _normalize_weights(combined)


def _top_regions(weighted_map: dict[str, float], limit: int = 4) -> list[dict[str, Any]]:
    """Return up to *limit* positive-scoring regions, highest score first.

    Scores are rounded to 6 decimals before sorting. Ties keep insertion
    order, because the sort is stable.
    """
    rows = [
        {"region": key, "score": round(float(val), 6)}
        for key, val in weighted_map.items()
        if float(val) > 0
    ]
    rows.sort(key=lambda item: item["score"], reverse=True)
    return rows[:limit]


def _empty_context() -> dict[str, Any]:
    """Build the payload returned when inputs are insufficient.

    It has the same keys as a populated context, ``_component_rows`` included,
    so callers can index either shape the same way.
    """
    return {
        "expected_pitch_mix_by_count": {},
        "expected_zone_mix_by_count": {},
        "expected_pitch_zone_mix_by_count": {},
        "expected_pitch_family_mix": {},
        "tunnel_pair_scores": [],
        "predicted_attack_regions": [],
        "predicted_damage_regions": [],
        "predicted_whiff_regions": [],
        "handedness_context": {},
        "count_context_profile": {},
        "matchup_coverage_confidence": 0.0,
        "component_source_map": _component_source_map(),
        "zone_matchup": {},
        "family_zone_matchup": {},
        "arsenal_matchup": {},
        "trajectory": {},
        "sequence_profiles": {},
        "_component_rows": {row_key: {} for row_key in _COMPONENT_ROW_KEYS},
    }


def _count_leverage(balls: int, strikes: int) -> str:
    """Label a count; two strikes takes precedence over two-plus balls (3-2 is "putaway")."""
    if strikes >= 2:
        return "putaway"
    if balls >= 2:
        return "hitter_ahead"
    return "neutral"


def _build_count_profiles(
    game_row: dict[str, Any],
    pitcher_row: dict[str, Any],
    batter_features: dict[str, Any],
    pitcher_family_zone_row: Any,
) -> tuple[
    dict[str, dict[str, Any]],
    dict[str, dict[str, float]],
    dict[str, dict[str, float]],
    dict[str, dict[str, Any]],
]:
    """Predict the next-pitch distribution for every count in ``_COUNT_STATES``.

    Returns ``(sequence_profiles, pitch_mix_by_count, zone_mix_by_count,
    count_context_profile)``. Each dict is keyed by ``"balls-strikes"``.
    """
    sequence_profiles: dict[str, dict[str, Any]] = {}
    pitch_mix_by_count: dict[str, dict[str, float]] = {}
    zone_mix_by_count: dict[str, dict[str, float]] = {}
    count_context_profile: dict[str, dict[str, Any]] = {}
    for balls, strikes in _COUNT_STATES:
        count_key = f"{balls}-{strikes}"
        seq_features = build_sequence_features(
            # Override only the count; all other game context passes through.
            game_row={**game_row, "balls": balls, "strikes": strikes},
            pitcher_row=pitcher_row,
            batter_row=batter_features,
            pitcher_family_zone_row=pitcher_family_zone_row,
        )
        seq_profile = predict_next_pitch_distribution(seq_features)
        sequence_profiles[count_key] = seq_profile
        pitch_mix_by_count[count_key] = {
            family: round(_safe_float(seq_profile.get(f"{family}_prob")), 6)
            for family in _PITCH_FAMILIES
        }
        zone_probs = seq_profile.get("zone_probs") or {}
        zone_mix_by_count[count_key] = {
            zone: round(_safe_float(zone_probs.get(zone)), 6) for zone in _ZONES
        }
        count_context_profile[count_key] = {
            "balls": balls,
            "strikes": strikes,
            "count_leverage": _count_leverage(balls, strikes),
        }
    return sequence_profiles, pitch_mix_by_count, zone_mix_by_count, count_context_profile


def _build_region_maps(
    pitch_zone_mix: dict[str, float],
    batter_family_zone_row: Any,
    pitcher_family_zone_row: Any,
) -> tuple[dict[str, float], dict[str, float]]:
    """Weight per-region damage and whiff rates by expected attack frequency.

    For each ``"<family>_<zone>"`` region:
      - damage = weight * (0.6 * batter ``damage_rate_*`` + 0.4 * pitcher ``damage_allowed_rate_*``)
      - whiff  = weight * (0.55 * batter ``whiff_rate_*`` + 0.45 * pitcher ``whiff_rate_*``)

    Regions whose key has no ``"_"`` are skipped.
    """
    damage_map: dict[str, float] = {}
    whiff_map: dict[str, float] = {}
    for region, attack_weight in pitch_zone_mix.items():
        if "_" not in region:
            continue
        batter_damage = _safe_float(batter_family_zone_row.get(f"damage_rate_{region}"))
        pitcher_damage = _safe_float(pitcher_family_zone_row.get(f"damage_allowed_rate_{region}"))
        batter_whiff = _safe_float(batter_family_zone_row.get(f"whiff_rate_{region}"))
        pitcher_whiff = _safe_float(pitcher_family_zone_row.get(f"whiff_rate_{region}"))
        damage_map[region] = attack_weight * ((batter_damage * 0.6) + (pitcher_damage * 0.4))
        whiff_map[region] = attack_weight * ((batter_whiff * 0.55) + (pitcher_whiff * 0.45))
    return damage_map, whiff_map


def _build_tunnel_pair_scores(trajectory: Any) -> list[dict[str, Any]]:
    """Summarise trajectory scores; return ``[]`` when tunnel and release scores are both zero."""
    tunnel_score = _safe_float(trajectory.get("tunnel_score"))
    release_score = _safe_float(trajectory.get("release_consistency_score"))
    if not (tunnel_score or release_score):
        return []
    return [
        {
            "pair": "arsenal_tunnel_profile",
            "tunnel_score": round(tunnel_score, 6),
            "release_consistency_score": round(release_score, 6),
            "deception_score": round(_safe_float(trajectory.get("deception_score")), 6),
        }
    ]


def _coverage_confidence(
    batter_zone_row: Any,
    pitcher_zone_row: Any,
    batter_family_zone_row: Any,
    pitcher_family_zone_row: Any,
    batter_arsenal_row: Any,
    pitcher_arsenal_row: Any,
    trajectory: Any,
) -> float:
    """Average the sample-size reliabilities of all component rows (rounded to 4 dp)."""
    signals = [
        _reliability(batter_zone_row.get("zone_sample_size"), 200.0),
        _reliability(pitcher_zone_row.get("zone_sample_size"), 200.0),
        _reliability(batter_family_zone_row.get("family_zone_sample_size"), 220.0),
        _reliability(pitcher_family_zone_row.get("family_zone_sample_size"), 220.0),
        _reliability(batter_arsenal_row.get("arsenal_sample_size"), 180.0),
        _reliability(pitcher_arsenal_row.get("arsenal_sample_size"), 180.0),
        _reliability(trajectory.get("trajectory_sample_size"), 240.0),
    ]
    return round(sum(signals) / len(signals), 4)


def compose_shared_matchup_context(
    *,
    batter_name: str,
    pitcher_name: str,
    batter_statcast_df: pd.DataFrame | None,
    pitcher_statcast_df: pd.DataFrame | None,
    batter_features: dict[str, Any] | None = None,
    pitcher_row: dict[str, Any] | None = None,
    game_row: dict[str, Any] | None = None,
    runtime_cache: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Build the shared batter-vs-pitcher matchup context.

    Args:
        batter_name: Batter identifier passed to the batter feature builders.
        pitcher_name: Pitcher identifier passed to the pitcher feature builders.
        batter_statcast_df: Frame for batter features. ``None`` is treated as empty.
        pitcher_statcast_df: Frame for pitcher features. When ``None``, the
            batter frame is used instead.
        batter_features: Batter row; ``batter_stand`` is read for handedness.
        pitcher_row: Pitcher row; ``p_throws`` is read for handedness.
        game_row: Game context; ``pitcher_id``, team and starter fields feed cache keys.
        runtime_cache: Optional dict used to memoise component rows and the
            final result across calls.

    Returns:
        A dict of expected pitch/zone mixes, predicted attack/damage/whiff
        regions, component matchup outputs, a coverage confidence and
        ``_component_rows``. The same key set is returned by the empty payload
        when either frame is empty or a name is missing.

    NOTE(review): on a cache hit the stored dict is returned by reference.
    Mutating it will affect later hits.
    """
    batter_df = batter_statcast_df if batter_statcast_df is not None else pd.DataFrame()
    # NOTE(review): presumably callers may pass one combined frame for both
    # sides; confirm this fallback is intended.
    pitcher_df = pitcher_statcast_df if pitcher_statcast_df is not None else batter_df
    if batter_df.empty or pitcher_df.empty or not batter_name or not pitcher_name:
        return _empty_context()

    # Shallow copies so the caller's dicts are never mutated.
    batter_features = dict(batter_features or {})
    pitcher_row = dict(pitcher_row or {})
    game_row = dict(game_row or {})

    batter_key = _normalize_name(batter_name)
    pitcher_key = _normalize_name(pitcher_name)
    pitcher_id_key = str(game_row.get("pitcher_id") or "").strip()
    batter_stand = str(batter_features.get("batter_stand") or "").strip().upper()
    pitcher_hand = str(pitcher_row.get("p_throws") or "").strip().upper()

    cache_key = (
        batter_key,
        pitcher_key,
        _normalize_name(game_row.get("away_team")),
        _normalize_name(game_row.get("home_team")),
        _normalize_name(game_row.get("projected_starter_match_status")),
        pitcher_id_key,
        batter_stand,
        pitcher_hand,
    )
    shared_bucket = _runtime_bucket(runtime_cache, "shared_matchup_context")
    if runtime_cache is not None and cache_key in shared_bucket:
        return shared_bucket[cache_key]

    handedness_context = {"batter_stand": batter_stand, "pitcher_hand": pitcher_hand}

    pitcher_family_zone_row = _cache_get_or_build(
        runtime_cache,
        "pitcher_family_zone_rows",
        (id(pitcher_df), pitcher_key),
        lambda: build_pitcher_family_zone_feature_row(pitcher_df, pitcher_name),
    )
    (
        sequence_profiles,
        expected_pitch_mix_by_count,
        expected_zone_mix_by_count,
        count_context_profile,
    ) = _build_count_profiles(game_row, pitcher_row, batter_features, pitcher_family_zone_row)

    # Despite the output key name, this mix is aggregated over all counts.
    expected_pitch_zone_mix_by_count = _build_pitch_zone_mix(sequence_profiles)
    expected_pitch_family_mix = _build_pitch_family_mix(sequence_profiles)

    batter_zone_row = _cache_get_or_build(
        runtime_cache,
        "batter_zone_rows",
        (id(batter_df), batter_key),
        lambda: build_batter_zone_feature_row(batter_df, batter_name),
    )
    pitcher_zone_row = _cache_get_or_build(
        runtime_cache,
        "pitcher_zone_rows",
        (id(pitcher_df), pitcher_key),
        lambda: build_pitcher_zone_feature_row(pitcher_df, pitcher_name),
    )
    zone_matchup = compute_zone_matchup_adjustment(
        batter_zone_row,
        pitcher_zone_row,
        pitch_zone_weights=expected_pitch_zone_mix_by_count,
        handedness_context=handedness_context,
    )

    batter_family_zone_row = _cache_get_or_build(
        runtime_cache,
        "batter_family_zone_rows",
        (id(batter_df), batter_key),
        lambda: build_batter_family_zone_feature_row(batter_df, batter_name),
    )
    family_zone_matchup = compute_family_zone_matchup_adjustment(
        batter_family_zone_row=batter_family_zone_row,
        pitcher_family_zone_row=pitcher_family_zone_row,
    )

    batter_arsenal_row = _cache_get_or_build(
        runtime_cache,
        "batter_arsenal_rows",
        (id(batter_df), batter_key),
        lambda: build_batter_arsenal_feature_row(batter_df, batter_name),
    )
    pitcher_arsenal_row = _cache_get_or_build(
        runtime_cache,
        "pitcher_arsenal_rows",
        (id(pitcher_df), pitcher_key),
        lambda: build_pitcher_arsenal_feature_row(pitcher_df, pitcher_name),
    )
    arsenal_matchup = compute_arsenal_matchup_adjustment(
        batter_arsenal_row=batter_arsenal_row,
        pitcher_arsenal_row=pitcher_arsenal_row,
        pitch_family_weights=expected_pitch_family_mix,
        handedness_context=handedness_context,
    )

    trajectory = _cache_get_or_build(
        runtime_cache,
        "trajectory_rows",
        (id(pitcher_df), pitcher_key, pitcher_id_key),
        lambda: build_trajectory_features(
            statcast_df=pitcher_df,
            pitcher_name=pitcher_name,
            pitcher_id=game_row.get("pitcher_id"),
        ),
    )

    damage_region_map, whiff_region_map = _build_region_maps(
        expected_pitch_zone_mix_by_count,
        batter_family_zone_row,
        pitcher_family_zone_row,
    )
    tunnel_pair_scores = _build_tunnel_pair_scores(trajectory)
    matchup_coverage_confidence = _coverage_confidence(
        batter_zone_row,
        pitcher_zone_row,
        batter_family_zone_row,
        pitcher_family_zone_row,
        batter_arsenal_row,
        pitcher_arsenal_row,
        trajectory,
    )

    result = {
        "expected_pitch_mix_by_count": expected_pitch_mix_by_count,
        "expected_zone_mix_by_count": expected_zone_mix_by_count,
        "expected_pitch_zone_mix_by_count": expected_pitch_zone_mix_by_count,
        "expected_pitch_family_mix": expected_pitch_family_mix,
        "tunnel_pair_scores": tunnel_pair_scores,
        "predicted_attack_regions": _top_regions(expected_pitch_zone_mix_by_count, limit=5),
        "predicted_damage_regions": _top_regions(damage_region_map, limit=5),
        "predicted_whiff_regions": _top_regions(whiff_region_map, limit=5),
        "handedness_context": handedness_context,
        "count_context_profile": count_context_profile,
        "matchup_coverage_confidence": matchup_coverage_confidence,
        "component_source_map": _component_source_map(),
        "zone_matchup": zone_matchup,
        "family_zone_matchup": family_zone_matchup,
        "arsenal_matchup": arsenal_matchup,
        "trajectory": trajectory,
        "sequence_profiles": sequence_profiles,
        "_component_rows": {
            "batter_zone_row": batter_zone_row,
            "pitcher_zone_row": pitcher_zone_row,
            "batter_family_zone_row": batter_family_zone_row,
            "pitcher_family_zone_row": pitcher_family_zone_row,
            "batter_arsenal_row": batter_arsenal_row,
            "pitcher_arsenal_row": pitcher_arsenal_row,
        },
    }
    if runtime_cache is not None:
        shared_bucket[cache_key] = result
    return result