Spaces:
Running
Running
| from __future__ import annotations | |
| from math import sqrt | |
| from typing import Any | |
| from simulation.pa_outcome_sampler import ( | |
| sample_contact_event, | |
| sample_pitch_family, | |
| sample_pitch_result, | |
| sample_zone_bucket, | |
| ) | |
| def _safe_count_int(value: Any, default: int = 0) -> int: | |
| try: | |
| if value is None: | |
| return default | |
| text = str(value).strip().lower() | |
| if text in {"", "nan", "none"}: | |
| return default | |
| return int(float(value)) | |
| except Exception: | |
| return default | |
| def _clamp(value: float, low: float, high: float) -> float: | |
| return max(low, min(high, value)) | |
def _pitch_movement_magnitude(pitcher_row: dict[str, Any]) -> "float | None":
    """Return the Euclidean norm of the pitcher's average movement, or ``None``.

    Reads ``avg_pfx_x`` and ``avg_pfx_z`` from *pitcher_row*. ``None`` is
    returned when either value is missing, cannot be converted to ``float``,
    or the resulting magnitude is NaN/infinite.
    """
    from math import isfinite  # local import keeps the block self-contained

    try:
        pfx_x = pitcher_row.get("avg_pfx_x")
        pfx_z = pitcher_row.get("avg_pfx_z")
        if pfx_x is None or pfx_z is None:
            return None
        magnitude = sqrt(float(pfx_x) ** 2 + float(pfx_z) ** 2)
    except Exception:
        # Deliberately best-effort: movement is an optional refinement, so a
        # malformed pitcher row must not abort the whole simulation.
        return None
    # A NaN input would otherwise flow into the sampler as NaN; report it as
    # "no movement data" instead.
    return magnitude if isfinite(magnitude) else None


def simulate_plate_appearance(
    game_row: dict[str, Any],
    batter_row: dict[str, Any],
    pitcher_row: dict[str, Any],
    context_adjustment: dict[str, Any],
    pitcher_adjustment: dict[str, Any],
    bullpen_adjustment: dict[str, Any],
    sequence_distribution: dict[str, Any],
    batter_baseline: dict[str, Any],
    n_sims: int = 10000,
) -> dict[str, float]:
    """Estimate hit / HR / tb2p rates for one plate appearance by simulation.

    Each target probability is the batter baseline plus the pitcher, context
    and bullpen adjustments, clamped to a fixed range. The plate appearance
    is then played out pitch by pitch ``n_sims`` times, starting from the
    count found in ``game_row`` (``balls`` / ``strikes``).

    Args:
        game_row: Supplies the starting ``balls`` and ``strikes``; missing or
            unparseable values count as 0.
        batter_row: Accepted for interface compatibility; not read here.
        pitcher_row: Supplies ``avg_pfx_x`` / ``avg_pfx_z`` for the optional
            movement magnitude.
        context_adjustment: Must contain ``hit_adj``, ``hr_adj``, ``tb2p_adj``.
        pitcher_adjustment: Must contain ``hit_adj``, ``hr_adj``, ``tb2p_adj``.
        bullpen_adjustment: Must contain ``bullpen_risk_adj_hit``,
            ``bullpen_risk_adj_hr`` and ``bullpen_risk_adj_tb2p``.
        sequence_distribution: Passed to ``sample_pitch_family``; its optional
            ``zone_probs`` entry is passed to ``sample_zone_bucket``.
        batter_baseline: Must contain ``hit_prob_base``, ``hr_prob_base`` and
            ``tb2p_prob_base``.
        n_sims: Number of simulated plate appearances; must be positive.

    Returns:
        A dict with ``hit_prob``, ``hr_prob`` and ``tb2p_prob``: the fraction
        of simulations in which the sampled contact event flagged each outcome.

    Raises:
        ValueError: If ``n_sims`` is not positive.
        KeyError: If a required baseline/adjustment key is missing.
    """
    # Fail early with a clear message instead of a ZeroDivisionError (or
    # negative-zero "probabilities") when the final averages are computed.
    if n_sims <= 0:
        raise ValueError(f"n_sims must be a positive integer, got {n_sims!r}")

    hit_prob = _clamp(
        batter_baseline["hit_prob_base"]
        + pitcher_adjustment["hit_adj"]
        + context_adjustment["hit_adj"]
        + bullpen_adjustment["bullpen_risk_adj_hit"],
        0.05,
        0.55,
    )
    hr_prob = _clamp(
        batter_baseline["hr_prob_base"]
        + pitcher_adjustment["hr_adj"]
        + context_adjustment["hr_adj"]
        + bullpen_adjustment["bullpen_risk_adj_hr"],
        0.005,
        0.25,
    )
    tb2p_prob = _clamp(
        batter_baseline["tb2p_prob_base"]
        + pitcher_adjustment["tb2p_adj"]
        + context_adjustment["tb2p_adj"]
        + bullpen_adjustment["bullpen_risk_adj_tb2p"],
        0.03,
        0.45,
    )

    # The starting count is identical for every simulation, so clamp it once.
    # Capping at 3 balls / 2 strikes keeps the plate appearance unresolved.
    initial_balls = int(_clamp(_safe_count_int(game_row.get("balls"), 0), 0, 3))
    initial_strikes = int(_clamp(_safe_count_int(game_row.get("strikes"), 0), 0, 2))

    zone_probs = sequence_distribution.get("zone_probs", {}) or {}
    movement_magnitude = _pitch_movement_magnitude(pitcher_row)

    max_pitches = 8  # pitches simulated before the long-PA fallback applies
    hit_total = 0
    hr_total = 0
    tb2p_total = 0

    for _ in range(n_sims):
        balls = initial_balls
        strikes = initial_strikes
        event = None  # contact event, if the plate appearance ends on contact
        resolved = False

        # Simulate the remainder of the plate appearance pitch by pitch.
        for _pitch_num in range(max_pitches):
            pitch_family = sample_pitch_family(sequence_distribution)
            zone_bucket = sample_zone_bucket(zone_probs)
            pitch_result = sample_pitch_result(
                balls=balls,
                strikes=strikes,
                zone_bucket=zone_bucket,
                pitch_family=pitch_family,
                movement_magnitude=movement_magnitude,
            )["pitch_result"]

            if pitch_result == "ball":
                balls += 1
                resolved = balls >= 4  # walk: ends with no contact event
            elif pitch_result in {"called_strike", "whiff"}:
                strikes += 1
                resolved = strikes >= 3  # strikeout: ends with no contact event
            elif pitch_result == "foul":
                # A foul never produces the third strike.
                if strikes < 2:
                    strikes += 1
            elif pitch_result == "ball_in_play":
                event = sample_contact_event(
                    hit_prob=hit_prob,
                    hr_prob=hr_prob,
                    tb2p_prob=tb2p_prob,
                    zone_bucket=zone_bucket,
                    pitch_family=pitch_family,
                )
                resolved = True

            if resolved:
                break

        # Unresolved long plate appearance: force a contact event with
        # slightly reduced probabilities.
        if not resolved:
            event = sample_contact_event(
                hit_prob=hit_prob * 0.95,
                hr_prob=hr_prob * 0.95,
                tb2p_prob=tb2p_prob * 0.95,
                zone_bucket="shadow",
                pitch_family="breaking",
            )

        if event is not None:
            hit_total += int(event["hit"])
            hr_total += int(event["hr"])
            tb2p_total += int(event["tb2p"])

    return {
        "hit_prob": hit_total / n_sims,
        "hr_prob": hr_total / n_sims,
        "tb2p_prob": tb2p_total / n_sims,
    }
def simulate_upcoming_hitter(
    game_row: dict[str, Any],
    batter_row: dict[str, Any],
    pitcher_row: dict[str, Any],
    context_adjustment: dict[str, Any],
    pitcher_adjustment: dict[str, Any],
    bullpen_adjustment: dict[str, Any],
    sequence_distribution: dict[str, Any],
    batter_baseline: dict[str, Any],
    n_sims: int = 10000,
) -> dict[str, float]:
    """Delegate to ``simulate_plate_appearance`` with every argument unchanged.

    Returns whatever ``simulate_plate_appearance`` returns for the same inputs.
    """
    # Forward by keyword so the call does not depend on parameter order.
    forwarded = {
        "game_row": game_row,
        "batter_row": batter_row,
        "pitcher_row": pitcher_row,
        "context_adjustment": context_adjustment,
        "pitcher_adjustment": pitcher_adjustment,
        "bullpen_adjustment": bullpen_adjustment,
        "sequence_distribution": sequence_distribution,
        "batter_baseline": batter_baseline,
        "n_sims": n_sims,
    }
    return simulate_plate_appearance(**forwarded)