Spaces:
Running
Running
| from __future__ import annotations | |
| from typing import Any | |
| import pandas as pd | |
| from models.arsenal_matchup_model import compute_arsenal_matchup_adjustment | |
| from models.batter_arsenal_model import build_batter_arsenal_feature_row | |
| from models.batter_zone_model import build_batter_zone_feature_row | |
| from models.family_zone_profile_store import ( | |
| build_batter_family_zone_feature_row, | |
| build_pitcher_family_zone_feature_row, | |
| ) | |
| from models.matchup_model import compute_family_zone_matchup_adjustment | |
| from models.pitch_sequence_model import build_sequence_features, predict_next_pitch_distribution | |
| from models.pitcher_arsenal_model import build_pitcher_arsenal_feature_row | |
| from models.pitcher_zone_model import build_pitcher_zone_feature_row | |
| from models.trajectory_model import build_trajectory_features | |
| from models.zone_matchup_model import compute_zone_matchup_adjustment | |
| _COUNT_STATES: tuple[tuple[int, int], ...] = ( | |
| (0, 0), | |
| (1, 0), | |
| (0, 1), | |
| (1, 1), | |
| (0, 2), | |
| (1, 2), | |
| (2, 2), | |
| (3, 2), | |
| ) | |
| _PITCH_FAMILIES: tuple[str, ...] = ("fastball", "breaking", "offspeed") | |
| _ZONES: tuple[str, ...] = ("heart", "shadow", "chase", "waste") | |
| def _safe_float(value: Any, default: float = 0.0) -> float: | |
| try: | |
| if value is None: | |
| return default | |
| text = str(value).strip().lower() | |
| if text in {"", "nan", "none"}: | |
| return default | |
| return float(value) | |
| except Exception: | |
| return default | |
| def _clamp(value: float, lo: float, hi: float) -> float: | |
| return max(lo, min(hi, value)) | |
def _reliability(sample_size: Any, k: float) -> float:
    """Return the sample-size weight ``n / (n + k)``, clamped to ``[0, 1]``.

    ``n`` is ``sample_size`` parsed with ``_safe_float`` (unparseable -> 0) and
    floored at 0. ``k`` is floored at 1, which keeps the denominator at least 1
    for finite ``n``. Larger samples push the weight toward 1.
    """
    n = max(0.0, _safe_float(sample_size, 0.0))
    floored_k = max(1.0, float(k))
    return _clamp(n / (n + floored_k), 0.0, 1.0)
| def _component_source_map() -> dict[str, dict[str, str]]: | |
| return { | |
| "zone_matchup": {"classification": "upgrade_existing_module", "source_module": "models.zone_matchup_model"}, | |
| "family_zone_matchup": {"classification": "reuse_as_is", "source_module": "models.matchup_model"}, | |
| "arsenal_matchup": {"classification": "upgrade_existing_module", "source_module": "models.arsenal_matchup_model"}, | |
| "trajectory": {"classification": "reuse_as_is", "source_module": "models.trajectory_model"}, | |
| "sequencing": {"classification": "upgrade_existing_module", "source_module": "models.pitch_sequence_model"}, | |
| "count_context": {"classification": "upgrade_existing_module", "source_module": "models.pitch_sequence_model"}, | |
| "shared_composer": {"classification": "new_source_of_truth_component", "source_module": "models.shared_matchup_engine"}, | |
| } | |
| def _runtime_bucket(runtime_cache: dict[str, Any] | None, key: str) -> dict[str, Any]: | |
| if runtime_cache is None: | |
| return {} | |
| bucket = runtime_cache.get(key) | |
| if not isinstance(bucket, dict): | |
| bucket = {} | |
| runtime_cache[key] = bucket | |
| return bucket | |
def _cache_get_or_build(
    runtime_cache: dict[str, Any] | None,
    bucket_name: str,
    cache_key: tuple[Any, ...],
    builder,
):
    """Memoize ``builder()`` under ``cache_key`` inside a named cache bucket.

    When ``runtime_cache`` is ``None``, nothing is cached and ``builder`` runs on
    every call. Otherwise the bucket is fetched (or created) via
    ``_runtime_bucket``, and ``builder`` runs only when the key is absent.
    """
    if runtime_cache is None:
        return builder()
    bucket = _runtime_bucket(runtime_cache, bucket_name)
    if cache_key in bucket:
        return bucket[cache_key]
    built = builder()
    bucket[cache_key] = built
    return built
def _build_pitch_zone_mix(
    sequence_profiles: dict[str, dict[str, Any]],
) -> dict[str, float]:
    """Blend the per-count predictions into one ``"<family>_<zone>"`` distribution.

    Each profile contributes with equal weight. Inside a profile, a cell's joint
    weight is ``<family>_prob * zone_probs[zone]``, with missing or unparseable
    values read as 0. The blend is rescaled to sum to 1 when its total is
    positive. An empty ``sequence_profiles`` gives ``{}``.
    """
    mix: dict[str, float] = {}
    if not sequence_profiles:
        return mix
    share = 1.0 / float(len(sequence_profiles))
    for profile in sequence_profiles.values():
        family_weights = (
            ("fastball", _safe_float(profile.get("fastball_prob"))),
            ("breaking", _safe_float(profile.get("breaking_prob"))),
            ("offspeed", _safe_float(profile.get("offspeed_prob"))),
        )
        zone_probs = profile.get("zone_probs", {}) or {}
        for family, family_weight in family_weights:
            for zone in _ZONES:
                zone_weight = _safe_float(zone_probs.get(zone))
                cell = f"{family}_{zone}"
                mix[cell] = mix.get(cell, 0.0) + (family_weight * zone_weight * share)
    total = sum(mix.values())
    if total > 0:
        mix = {cell: weight / total for cell, weight in mix.items()}
    return mix
def _build_pitch_family_mix(
    sequence_profiles: dict[str, dict[str, Any]],
) -> dict[str, float]:
    """Average the per-count ``<family>_prob`` values into one family mix.

    Always returns one entry per family in ``_PITCH_FAMILIES``. Profiles are
    weighted equally, and missing or unparseable probabilities count as 0.
    The result is rescaled to sum to 1 when its total is positive.
    """
    mix = dict.fromkeys(_PITCH_FAMILIES, 0.0)
    if not sequence_profiles:
        return mix
    share = 1.0 / float(len(sequence_profiles))
    for profile in sequence_profiles.values():
        for family in _PITCH_FAMILIES:
            mix[family] += _safe_float(profile.get(f"{family}_prob")) * share
    total = sum(mix.values())
    if total > 0:
        return {family: weight / total for family, weight in mix.items()}
    return mix
| def _top_regions(weighted_map: dict[str, float], limit: int = 4) -> list[dict[str, Any]]: | |
| rows = [ | |
| {"region": key, "score": round(float(val), 6)} | |
| for key, val in weighted_map.items() | |
| if float(val) > 0 | |
| ] | |
| rows.sort(key=lambda item: item["score"], reverse=True) | |
| return rows[:limit] | |
# Run the next-pitch model once per tracked count and collect per-count summaries.
def _count_state_profiles(
    game_row: dict[str, Any],
    pitcher_row: dict[str, Any],
    batter_features: dict[str, Any],
    pitcher_family_zone_row: Any,
) -> tuple[
    dict[str, dict[str, Any]],
    dict[str, dict[str, float]],
    dict[str, dict[str, float]],
    dict[str, dict[str, Any]],
]:
    """Predict the next-pitch distribution for every count in ``_COUNT_STATES``.

    Returns ``(count_context_profile, expected_pitch_mix_by_count,
    expected_zone_mix_by_count, sequence_profiles)``, each keyed by
    ``"<balls>-<strikes>"``.
    """
    count_context_profile: dict[str, dict[str, Any]] = {}
    expected_pitch_mix_by_count: dict[str, dict[str, float]] = {}
    expected_zone_mix_by_count: dict[str, dict[str, float]] = {}
    sequence_profiles: dict[str, dict[str, Any]] = {}
    for balls, strikes in _COUNT_STATES:
        count_key = f"{balls}-{strikes}"
        # Same game context, with only the count overridden for this state.
        seq_features = build_sequence_features(
            game_row={**game_row, "balls": balls, "strikes": strikes},
            pitcher_row=pitcher_row,
            batter_row=batter_features,
            pitcher_family_zone_row=pitcher_family_zone_row,
        )
        seq_profile = predict_next_pitch_distribution(seq_features)
        sequence_profiles[count_key] = seq_profile
        expected_pitch_mix_by_count[count_key] = {
            family: round(_safe_float(seq_profile.get(f"{family}_prob")), 6)
            for family in _PITCH_FAMILIES
        }
        zone_probs = seq_profile.get("zone_probs") or {}
        expected_zone_mix_by_count[count_key] = {
            zone: round(_safe_float(zone_probs.get(zone)), 6) for zone in _ZONES
        }
        # Two strikes takes precedence over a hitter's count (so 3-2 is "putaway").
        if strikes >= 2:
            leverage = "putaway"
        elif balls >= 2:
            leverage = "hitter_ahead"
        else:
            leverage = "neutral"
        count_context_profile[count_key] = {
            "balls": balls,
            "strikes": strikes,
            "count_leverage": leverage,
        }
    return (
        count_context_profile,
        expected_pitch_mix_by_count,
        expected_zone_mix_by_count,
        sequence_profiles,
    )


# Weight blended batter/pitcher damage and whiff rates by how often each region is attacked.
def _region_score_maps(
    pitch_zone_mix: dict[str, float],
    batter_family_zone_row: Any,
    pitcher_family_zone_row: Any,
) -> tuple[dict[str, float], dict[str, float]]:
    """Return ``(damage_region_map, whiff_region_map)`` keyed like ``pitch_zone_mix``.

    For each ``"<family>_<zone>"`` key with attack weight ``w``:
    damage = ``w * (0.6 * batter damage_rate + 0.4 * pitcher damage_allowed_rate)``
    and whiff = ``w * (0.55 * batter whiff_rate + 0.45 * pitcher whiff_rate)``.
    Missing or unparseable rates count as 0. Keys without an underscore are skipped.
    """
    damage_region_map: dict[str, float] = {}
    whiff_region_map: dict[str, float] = {}
    for key, attack_weight in pitch_zone_mix.items():
        if "_" not in key:
            continue
        # key is "<family>_<zone>", so these match the row column suffixes.
        batter_damage = _safe_float(batter_family_zone_row.get(f"damage_rate_{key}"))
        pitcher_damage = _safe_float(pitcher_family_zone_row.get(f"damage_allowed_rate_{key}"))
        batter_whiff = _safe_float(batter_family_zone_row.get(f"whiff_rate_{key}"))
        pitcher_whiff = _safe_float(pitcher_family_zone_row.get(f"whiff_rate_{key}"))
        damage_region_map[key] = attack_weight * ((batter_damage * 0.6) + (pitcher_damage * 0.4))
        whiff_region_map[key] = attack_weight * ((batter_whiff * 0.55) + (pitcher_whiff * 0.45))
    return damage_region_map, whiff_region_map


# Summarize the trajectory component's tunnel/release/deception scores.
def _tunnel_pair_scores(trajectory: Any) -> list[dict[str, Any]]:
    """Return a one-row tunnel summary, or ``[]`` if tunnel and release scores are both 0.

    Missing or unparseable scores count as 0.
    """
    tunnel_score = _safe_float(trajectory.get("tunnel_score"))
    release_score = _safe_float(trajectory.get("release_consistency_score"))
    if not (tunnel_score or release_score):
        return []
    return [
        {
            "pair": "arsenal_tunnel_profile",
            "tunnel_score": round(tunnel_score, 6),
            "release_consistency_score": round(release_score, 6),
            "deception_score": round(_safe_float(trajectory.get("deception_score")), 6),
        }
    ]


def compose_shared_matchup_context(
    *,
    batter_name: str,
    pitcher_name: str,
    batter_statcast_df: pd.DataFrame | None,
    pitcher_statcast_df: pd.DataFrame | None,
    batter_features: dict[str, Any] | None = None,
    pitcher_row: dict[str, Any] | None = None,
    game_row: dict[str, Any] | None = None,
    runtime_cache: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Compose the shared pitcher-vs-batter matchup context from the component models.

    Runs the next-pitch model for each count in ``_COUNT_STATES`` and blends
    those predictions into family and family x zone mixes. It then combines
    them with the zone, family-zone, arsenal and trajectory components.
    Component rows and the final result are memoized in ``runtime_cache`` when
    one is given; the result goes under the ``"shared_matchup_context"`` bucket.

    Args:
        batter_name: Batter name. Lower-cased and stripped for cache keys and
            passed unchanged to the row builders.
        pitcher_name: Pitcher name, handled the same way as ``batter_name``.
        batter_statcast_df: Batter pitch-level frame. ``None`` is treated as empty.
            Required columns are defined by the row builders (not visible here).
        pitcher_statcast_df: Pitcher pitch-level frame. ``None`` falls back to
            ``batter_statcast_df``.
        batter_features: Batter feature row, shallow-copied before use.
            ``"batter_stand"`` feeds the handedness context.
        pitcher_row: Pitcher feature row, shallow-copied before use.
            ``"p_throws"`` feeds the handedness context.
        game_row: Game context (teams, ``"pitcher_id"``, ...), shallow-copied and
            extended with ``balls``/``strikes`` for each sequence prediction.
        runtime_cache: Optional dict shared across calls for memoization.

    Returns:
        Dict with the expected mixes, top attack/damage/whiff regions, tunnel
        scores, handedness and count context, a ``matchup_coverage_confidence``
        in ``[0, 1]``, the component outputs, and the ``"_component_rows"`` used
        to build them. If either frame is empty or either name is falsy, the same
        keys (without ``"_component_rows"``) come back with empty values, and
        nothing is cached.
    """
    batter_df = batter_statcast_df if batter_statcast_df is not None else pd.DataFrame()
    # Without a dedicated pitcher frame, the batter frame serves both sides.
    pitcher_df = pitcher_statcast_df if pitcher_statcast_df is not None else batter_df
    if batter_df.empty or pitcher_df.empty or not batter_name or not pitcher_name:
        return {
            "expected_pitch_mix_by_count": {},
            "expected_zone_mix_by_count": {},
            "expected_pitch_zone_mix_by_count": {},
            "expected_pitch_family_mix": {},
            "tunnel_pair_scores": [],
            "predicted_attack_regions": [],
            "predicted_damage_regions": [],
            "predicted_whiff_regions": [],
            "handedness_context": {},
            "count_context_profile": {},
            "matchup_coverage_confidence": 0.0,
            "component_source_map": _component_source_map(),
            "zone_matchup": {},
            "family_zone_matchup": {},
            "arsenal_matchup": {},
            "trajectory": {},
            "sequence_profiles": {},
        }
    batter_features = dict(batter_features or {})
    pitcher_row = dict(pitcher_row or {})
    game_row = dict(game_row or {})

    batter_key = str(batter_name).strip().lower()
    pitcher_key = str(pitcher_name).strip().lower()
    pitcher_id_key = str(game_row.get("pitcher_id") or "").strip()
    handedness_context = {
        "batter_stand": str(batter_features.get("batter_stand", "") or "").strip().upper(),
        "pitcher_hand": str(pitcher_row.get("p_throws", "") or "").strip().upper(),
    }

    cache_key = (
        batter_key,
        pitcher_key,
        str(game_row.get("away_team") or "").strip().lower(),
        str(game_row.get("home_team") or "").strip().lower(),
        str(game_row.get("projected_starter_match_status") or "").strip().lower(),
        pitcher_id_key,
        handedness_context["batter_stand"],
        handedness_context["pitcher_hand"],
        # The result is derived from these frames, so key on them the same way
        # the per-component caches below do. Otherwise a call with different
        # statcast data but identical names would receive a stale context.
        # NOTE(review): id() values can be reused after a frame is garbage
        # collected; like the component caches, this assumes the frames outlive
        # runtime_cache. The key also covers only a few fields of the
        # feature/pitcher/game rows even though the whole rows feed
        # build_sequence_features. Confirm that is acceptable.
        id(batter_df),
        id(pitcher_df),
    )
    shared_bucket = _runtime_bucket(runtime_cache, "shared_matchup_context")
    if runtime_cache is not None and cache_key in shared_bucket:
        return shared_bucket[cache_key]

    pitcher_family_zone_row = _cache_get_or_build(
        runtime_cache,
        "pitcher_family_zone_rows",
        (id(pitcher_df), pitcher_key),
        lambda: build_pitcher_family_zone_feature_row(pitcher_df, pitcher_name),
    )
    (
        count_context_profile,
        expected_pitch_mix_by_count,
        expected_zone_mix_by_count,
        sequence_profiles,
    ) = _count_state_profiles(game_row, pitcher_row, batter_features, pitcher_family_zone_row)
    # Despite the "_by_count" name (kept for callers), this is blended across all counts.
    expected_pitch_zone_mix_by_count = _build_pitch_zone_mix(sequence_profiles)
    expected_pitch_family_mix = _build_pitch_family_mix(sequence_profiles)

    batter_zone_row = _cache_get_or_build(
        runtime_cache,
        "batter_zone_rows",
        (id(batter_df), batter_key),
        lambda: build_batter_zone_feature_row(batter_df, batter_name),
    )
    pitcher_zone_row = _cache_get_or_build(
        runtime_cache,
        "pitcher_zone_rows",
        (id(pitcher_df), pitcher_key),
        lambda: build_pitcher_zone_feature_row(pitcher_df, pitcher_name),
    )
    zone_matchup = compute_zone_matchup_adjustment(
        batter_zone_row,
        pitcher_zone_row,
        pitch_zone_weights=expected_pitch_zone_mix_by_count,
        handedness_context=handedness_context,
    )

    batter_family_zone_row = _cache_get_or_build(
        runtime_cache,
        "batter_family_zone_rows",
        (id(batter_df), batter_key),
        lambda: build_batter_family_zone_feature_row(batter_df, batter_name),
    )
    family_zone_matchup = compute_family_zone_matchup_adjustment(
        batter_family_zone_row=batter_family_zone_row,
        pitcher_family_zone_row=pitcher_family_zone_row,
    )

    batter_arsenal_row = _cache_get_or_build(
        runtime_cache,
        "batter_arsenal_rows",
        (id(batter_df), batter_key),
        lambda: build_batter_arsenal_feature_row(batter_df, batter_name),
    )
    pitcher_arsenal_row = _cache_get_or_build(
        runtime_cache,
        "pitcher_arsenal_rows",
        (id(pitcher_df), pitcher_key),
        lambda: build_pitcher_arsenal_feature_row(pitcher_df, pitcher_name),
    )
    arsenal_matchup = compute_arsenal_matchup_adjustment(
        batter_arsenal_row=batter_arsenal_row,
        pitcher_arsenal_row=pitcher_arsenal_row,
        pitch_family_weights=expected_pitch_family_mix,
        handedness_context=handedness_context,
    )

    trajectory = _cache_get_or_build(
        runtime_cache,
        "trajectory_rows",
        (id(pitcher_df), pitcher_key, pitcher_id_key),
        lambda: build_trajectory_features(
            statcast_df=pitcher_df,
            pitcher_name=pitcher_name,
            pitcher_id=game_row.get("pitcher_id"),
        ),
    )

    damage_region_map, whiff_region_map = _region_score_maps(
        expected_pitch_zone_mix_by_count,
        batter_family_zone_row,
        pitcher_family_zone_row,
    )

    # Each signal is n / (n + k) in [0, 1]; k is the sample size at which a
    # component counts as "half covered".
    coverage_signals = [
        _reliability(batter_zone_row.get("zone_sample_size"), 200.0),
        _reliability(pitcher_zone_row.get("zone_sample_size"), 200.0),
        _reliability(batter_family_zone_row.get("family_zone_sample_size"), 220.0),
        _reliability(pitcher_family_zone_row.get("family_zone_sample_size"), 220.0),
        _reliability(batter_arsenal_row.get("arsenal_sample_size"), 180.0),
        _reliability(pitcher_arsenal_row.get("arsenal_sample_size"), 180.0),
        _reliability(trajectory.get("trajectory_sample_size"), 240.0),
    ]
    matchup_coverage_confidence = round(sum(coverage_signals) / len(coverage_signals), 4)

    result = {
        "expected_pitch_mix_by_count": expected_pitch_mix_by_count,
        "expected_zone_mix_by_count": expected_zone_mix_by_count,
        "expected_pitch_zone_mix_by_count": expected_pitch_zone_mix_by_count,
        "expected_pitch_family_mix": expected_pitch_family_mix,
        "tunnel_pair_scores": _tunnel_pair_scores(trajectory),
        "predicted_attack_regions": _top_regions(expected_pitch_zone_mix_by_count, limit=5),
        "predicted_damage_regions": _top_regions(damage_region_map, limit=5),
        "predicted_whiff_regions": _top_regions(whiff_region_map, limit=5),
        "handedness_context": handedness_context,
        "count_context_profile": count_context_profile,
        "matchup_coverage_confidence": matchup_coverage_confidence,
        "component_source_map": _component_source_map(),
        "zone_matchup": zone_matchup,
        "family_zone_matchup": family_zone_matchup,
        "arsenal_matchup": arsenal_matchup,
        "trajectory": trajectory,
        "sequence_profiles": sequence_profiles,
        "_component_rows": {
            "batter_zone_row": batter_zone_row,
            "pitcher_zone_row": pitcher_zone_row,
            "batter_family_zone_row": batter_family_zone_row,
            "pitcher_family_zone_row": pitcher_family_zone_row,
            "batter_arsenal_row": batter_arsenal_row,
            "pitcher_arsenal_row": pitcher_arsenal_row,
        },
    }
    if runtime_cache is not None:
        shared_bucket[cache_key] = result
    return result