# Source: Hugging Face Space by nesticot -- commit 1913b9d ("Update app.py", verified)
import polars as pl
import numpy as np
import pandas as pd
import api_scraper
scrape = api_scraper.MLB_Scrape()
from functions import df_update
from functions import pitch_summary_functions
update = df_update.df_update()
from stuff_model import feature_engineering as fe
from stuff_model import stuff_apply
import requests
import joblib
from matplotlib.gridspec import GridSpec
from shiny import App, reactive, ui, render
from shiny.ui import h2, tags
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from functions.pitch_summary_functions import *
from starlette.applications import Starlette
from starlette.responses import StreamingResponse, FileResponse
from starlette.routing import Route
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
import io
from PIL import Image
from io import BytesIO
import tempfile
import hashlib
import os
from pathlib import Path
from datetime import datetime, timedelta
import asyncio
import html
from urllib.parse import parse_qs, unquote
import re
# ===== CACHE CONFIG =====
# Rendered PNGs are kept under /tmp: on Hugging Face Spaces the rest of the
# filesystem is read-only, but /tmp is writable.
CACHE_DIR = Path("/tmp") / "plot_cache"
CACHE_TTL_HOURS = 24  # Age (in hours) after which a cached PNG counts as stale
# cache_key -> asyncio Future for renders currently underway, so a concurrent
# request for the same plot can wait on it instead of rendering it again.
IN_PROGRESS_REQUESTS = dict()
CACHE_DIR.mkdir(exist_ok=True)
def get_cache_key(pitcher_id, year, level_id, game_type='R'):
    """Return an MD5 hex digest identifying one plot request.

    The four request fields are joined with underscores and hashed, which
    gives a fixed-length key that is safe to use as a filename.
    """
    raw = "_".join(f"{part}" for part in (pitcher_id, year, level_id, game_type))
    digest = hashlib.md5(raw.encode())
    return digest.hexdigest()
def get_cache_path(cache_key):
    """Map a cache key to the PNG file that stores its plot inside CACHE_DIR."""
    filename = "{}.png".format(cache_key)
    return CACHE_DIR.joinpath(filename)
def is_cache_valid(cache_path, ttl_hours=None):
    """Return True if ``cache_path`` exists and is younger than the cache TTL.

    Args:
        cache_path: ``pathlib.Path`` of the cached file.
        ttl_hours: Maximum age in hours. Defaults to ``CACHE_TTL_HOURS``.

    Returns:
        False when the file is missing (including when it is deleted while
        being checked, e.g. by ``clear_old_cache``) or is at least
        ``ttl_hours`` old; True otherwise.
    """
    if ttl_hours is None:
        ttl_hours = CACHE_TTL_HOURS
    try:
        # A single stat() both tests existence and reads the mtime, so a
        # concurrent delete between an exists() check and a stat() cannot
        # raise here.
        mtime = cache_path.stat().st_mtime
    except (FileNotFoundError, NotADirectoryError):
        return False
    file_age = datetime.now() - datetime.fromtimestamp(mtime)
    return file_age < timedelta(hours=ttl_hours)
def clear_old_cache(cache_dir=None, ttl_hours=None):
    """Delete cached ``*.png`` files older than the TTL.

    Best-effort: a file that disappears while being inspected (for example,
    removed by a concurrent call) is skipped instead of aborting the sweep.

    Args:
        cache_dir: Directory to sweep (``Path`` or ``str``). Defaults to
            ``CACHE_DIR``.
        ttl_hours: Maximum age in hours. Defaults to ``CACHE_TTL_HOURS``.

    Returns:
        The number of files removed.
    """
    cache_dir = CACHE_DIR if cache_dir is None else Path(cache_dir)
    if ttl_hours is None:
        ttl_hours = CACHE_TTL_HOURS
    max_age = timedelta(hours=ttl_hours)
    now = datetime.now()  # one reference time for the whole sweep
    removed = 0
    for cache_file in cache_dir.glob("*.png"):
        try:
            file_age = now - datetime.fromtimestamp(cache_file.stat().st_mtime)
            if file_age > max_age:
                cache_file.unlink()
                removed += 1
        except FileNotFoundError:
            # Already gone -- nothing left to clean up for this entry.
            continue
    return removed
# ===== HARDCODED INPUTS =====
# Fixed plot configuration constants.
HARDCODED_SPLIT = 'all'
# Fixed season window (calendar year 2025).
HARDCODED_START_DATE, HARDCODED_END_DATE = '2025-01-01', '2025-12-31'
HARDCODED_GAME_TYPE = 'R'
# Keys into function_dict selecting the three middle-row panels.
HARDCODED_PLOT_1, HARDCODED_PLOT_2, HARDCODED_PLOT_3 = (
    'velocity_kdes',
    'break_plot',
    'pitch_usage',
)
HARDCODED_ROLLING_WINDOW = 50
USE_CUSTOM_LOGO = False

# Colours passed to seaborn's theme as its palette.
colour_palette = [
    '#FFB000', '#648FFF', '#785EF0', '#DC267F', '#FE6100',
    '#3D1EB2', '#894D80', '#16AA02', '#B5592B', '#A3C1ED',
]

# Split option -> batter-hand codes kept for that split.
split_dict_hand = dict(all=['L', 'R'], left=['L'], right=['R'])

# Game-type code -> display label.
type_dict = dict(R='Regular Season', S='Spring', P='Playoffs')

# Level id (as a string) -> display label; insertion order is the order the
# options appear in the level selector.
level_dict = {
    '1': 'MLB',
    '11': 'AAA',
    '14': 'A',
    '16': 'ROK',
    '17': 'AFL',
    '22': 'College',
    '21': 'Prospects',
    '51': 'International',
}

# Plot-panel key -> panel title. The 'roling' spelling is part of the key
# strings matched when dispatching panels, so it must not be corrected here
# alone.
function_dict = {
    'velocity_kdes': 'Velocity Distributions',
    'break_plot': 'Pitch Movement',
    'tj_stuff_roling': 'Rolling tjStuff+ by Pitch',
    'tj_stuff_roling_game': 'Rolling tjStuff+ by Game',
    'location_plot_lhb': 'Locations vs LHB',
    'location_plot_rhb': 'Locations vs RHB',
    'pitch_usage': 'Pitch Usage',
}
# ===== API ENDPOINTS =====
def decode_query_string(query_string):
    """Parse a query string that WordPress/LiteSpeed may have HTML-escaped.

    Example: ``pitcher_id=691587&amp;#038;year=2025&amp;#038;level_id=1`` is
    read as ``pitcher_id=691587&year=2025&level_id=1``.

    Returns:
        A dict mapping each parameter name to its first value (blank values
        are kept as ``''``). An empty or missing query string yields ``{}``.
    """
    if not query_string:
        return {}
    # html.unescape makes a single pass, so "&amp;#038;" only becomes
    # "&#038;"; the literal replacements below finish the job.
    cleaned = html.unescape(query_string)
    for entity in ('&#038;', '&amp;'):
        cleaned = cleaned.replace(entity, '&')
    # Undo any percent-encoding layered on top of the HTML escaping.
    cleaned = unquote(cleaned)
    parsed = parse_qs(cleaned, keep_blank_values=True)
    result = {}
    for name, values in parsed.items():
        result[name] = values[0] if values else None
    return result
def _plain_text_response(message, status_code):
    """Return ``message`` as a ``text/plain`` response with ``status_code``."""
    return StreamingResponse(
        io.BytesIO(message.encode()),
        media_type="text/plain",
        status_code=status_code,
    )


async def plot_api_endpoint(request):
    """Generate pitcher summary plot via API with caching

    Usage: /api/pitcher_plot?pitcher_id=694973&year=2025&level_id=1&type=R

    Query Parameters:
    - pitcher_id (int): Required. The pitcher's MLB ID
    - year (int): Required. The season year (e.g., 2025)
    - level_id (int): Required. The league level ID (1=MLB, 11=AAA, 14=A, 16=ROK, 17=AFL, 22=College, 21=Prospects, 51=International)
    - type (str): Optional. Game type - R=Regular Season (default), S=Spring, P=Playoffs

    Responses:
    - 200 image/png: the plot, served from the on-disk cache (rendered and
      stored there first on a cache miss).
    - 400 text/plain: a required parameter is missing or is not an integer.
    - 500 text/plain: any other failure. The traceback is printed to the
      server log, not returned to the client.

    NOTE(review): generate_pitcher_plot is synchronous and runs on the event
    loop, so no other request is served while a plot renders. It also means
    nothing can suspend between registering in IN_PROGRESS_REQUESTS and
    removing the entry, so the "already in progress" wait path is currently
    unreachable. It only becomes live if rendering moves off-loop. If that
    happens, pyplot's global state must be protected first, because pyplot is
    not thread-safe.
    """
    try:
        # WordPress/LiteSpeed may HTML-escape the query string (&amp;#038;),
        # so read a cleaned copy first and fall back to Starlette's parser.
        raw_query = str(request.url.query) if request.url.query else ""
        decoded_params = decode_query_string(raw_query)
        print("\n=== Query String Debug ===")
        print(f"Raw query: {raw_query}")
        print(f"Decoded params: {decoded_params}")

        def _param(name):
            return decoded_params.get(name) or request.query_params.get(name)

        pitcher_id_param = _param('pitcher_id')
        year_param = _param('year')
        level_id_param = _param('level_id')
        type_param = _param('type') or 'R'  # Default to Regular Season

        if not pitcher_id_param or not year_param or not level_id_param:
            return _plain_text_response(
                "Missing required parameters. Need: pitcher_id, year, level_id", 400
            )
        try:
            pitcher_id = int(pitcher_id_param)
            year = int(year_param)
            level_id = int(level_id_param)
        except ValueError:
            return _plain_text_response(
                "Invalid parameters: pitcher_id, year and level_id must be integers", 400
            )
        game_type = str(type_param).upper()  # Normalize to uppercase

        cache_key = get_cache_key(pitcher_id, year, level_id, game_type)
        cache_path = get_cache_path(cache_key)
        print("\n=== API Request ===")
        print(f"Pitcher ID: {pitcher_id}, Year: {year}, Level: {level_id}, Type: {game_type}")

        if is_cache_valid(cache_path):
            print(f"✅ Cache HIT - serving from disk ({cache_path.stat().st_size / 1024:.1f} KB)")
            return FileResponse(cache_path, media_type="image/png")

        pending = IN_PROGRESS_REQUESTS.get(cache_key)
        if pending is not None:
            print("⏳ Request already in progress - waiting for result")
            await pending
            if is_cache_valid(cache_path):
                print("✅ Now serving from cache after concurrent request completed")
                return FileResponse(cache_path, media_type="image/png")
            # The other render did not produce a file; try again ourselves.

        future = asyncio.get_running_loop().create_future()
        IN_PROGRESS_REQUESTS[cache_key] = future
        try:
            print("📊 Cache MISS - generating plot")
            fig = generate_pitcher_plot(
                pitcher_id=pitcher_id,
                year=year,
                level_id=level_id,
                split=HARDCODED_SPLIT,
                start_date=f"{year}-01-01",
                end_date=f"{year}-12-31",
                game_type=game_type,
                plot_1=HARDCODED_PLOT_1,
                plot_2=HARDCODED_PLOT_2,
                plot_3=HARDCODED_PLOT_3,
                rolling_window=HARDCODED_ROLLING_WINDOW,
                full_resolution=True,
            )
            # Render to a side file and rename it into place. A failed or
            # interrupted save must never leave a truncated PNG at cache_path,
            # because is_cache_valid() only checks existence and age.
            partial_path = cache_path.with_name(cache_path.name + ".part")
            try:
                fig.savefig(partial_path, format='png', dpi=100)
                partial_path.replace(cache_path)
            finally:
                plt.close(fig)
                partial_path.unlink(missing_ok=True)

            cache_size_kb = cache_path.stat().st_size / 1024
            print(f"✅ Plot generated and cached - {cache_size_kb:.1f} KB")
            future.set_result(True)
            return FileResponse(cache_path, media_type="image/png")
        finally:
            # Resolve the future even on failure. Use a result rather than an
            # exception, because an exception nobody awaits is reported as
            # "never retrieved". Any waiter re-checks the cache itself.
            if not future.done():
                future.set_result(False)
            if IN_PROGRESS_REQUESTS.get(cache_key) is future:
                del IN_PROGRESS_REQUESTS[cache_key]
            # Periodically clean old cache files while nothing is rendering.
            if not IN_PROGRESS_REQUESTS:
                try:
                    clear_old_cache()
                except OSError as exc:
                    # Housekeeping must never replace the real response.
                    print(f"Cache cleanup failed (non-critical): {exc}")
    except Exception as e:
        import traceback
        print(f"API Error:\n{traceback.format_exc()}")
        return _plain_text_response(f"Error: {str(e)}", 500)
class APIMiddleware(BaseHTTPMiddleware):
    """Answer ``/api/pitcher_plot`` requests directly; forward the rest to Shiny.

    Every request's URL, path and method are printed for debugging.
    """

    async def dispatch(self, request, call_next):
        url = request.url
        print("\n=== MIDDLEWARE DEBUG ===")
        print(f"Request URL: {url}")
        print(f"Request Path: {url.path}")
        print(f"Request Method: {request.method}")

        if not url.path.startswith("/api/pitcher_plot"):
            print("➡️ Not API request - passing to Shiny")
            response = await call_next(request)
            status = getattr(response, 'status_code', 'Unknown')
            print(f"✅ Shiny response: {type(response)} - {status}")
            return response

        print("✅ API request detected - calling plot_api_endpoint")
        try:
            result = await plot_api_endpoint(request)
            print(f"✅ API endpoint returned: {type(result)}")
            return result
        except Exception as exc:
            # Last-resort guard so an endpoint failure becomes a 500 text reply.
            print(f"❌ API endpoint error: {exc}")
            return StreamingResponse(
                io.BytesIO(f"API Middleware Error: {str(exc)}".encode()),
                media_type="text/plain",
                status_code=500,
            )
def _placeholder_figure(message, fontsize):
    """Return a small figure that shows only ``message`` (no-data / error cases)."""
    fig = plt.figure(figsize=(9, 9), dpi=43)
    fig.text(x=0.1, y=0.9, s=message, fontsize=fontsize, ha='left')
    return fig


def _load_pitch_data(pitcher_id, year, level_id, split, start_date, end_date, game_type):
    """Scrape the pitcher's games and return the processed pitch-level frame.

    Keeps only this pitcher's rows with is_pitch true, start_speed >= 50 and
    a batter hand allowed by ``split``. The rows go through df_update, then
    feature engineering, then the stuff model, and finally gain a
    per-pitch-type ``pitch_count`` column.
    """
    game_list = scrape.get_player_games_list(
        sport_id=level_id,
        season=year,
        player_id=pitcher_id,
        start_date=start_date,
        end_date=end_date,
        game_type=[game_type],
    )
    data_list = scrape.get_data(game_list_input=game_list)
    pitches = scrape.get_data_df(data_list=data_list).filter(
        (pl.col("pitcher_id") == pitcher_id)
        & (pl.col("is_pitch") == True)  # noqa: E712 -- polars expression, not a Python bool test
        & (pl.col("start_speed") >= 50)
        & (pl.col('batter_hand').is_in(split_dict_hand[split]))
    )
    processed = stuff_apply.stuff_apply(fe.feature_engineering(update.update(pitches)))
    return processed.with_columns(
        pl.col('pitch_type').count().over('pitch_type').alias('pitch_count')
    )


def _draw_panel(plot_func, df, ax, fig, gs, col, rolling_window):
    """Draw the middle-row panel named ``plot_func`` onto ``ax``.

    ``col`` is the panel's first grid column in row 3 (it spans col..col+1).
    """
    if plot_func == 'velocity_kdes':
        # velocity_kdes also receives the grid and this panel's row/column
        # span, presumably to build its own sub-axes -- confirm in
        # pitch_summary_functions.
        velocity_kdes(df, ax=ax, gs=gs, gs_x=[3, 4], gs_y=[col, col + 2], fig=fig)
    elif plot_func == 'tj_stuff_roling':
        tj_stuff_roling(df=df, window=rolling_window, ax=ax)
    elif plot_func == 'tj_stuff_roling_game':
        tj_stuff_roling_game(df=df, window=rolling_window, ax=ax)
    elif plot_func == 'break_plot':
        break_plot(df=df, ax=ax)
    elif plot_func == 'location_plot_lhb':
        location_plot(df=df, ax=ax, hand='L')
    elif plot_func == 'location_plot_rhb':
        location_plot(df=df, ax=ax, hand='R')
    elif plot_func == 'pitch_usage':
        pitch_usage(df=df, ax=ax)
    else:
        print(f"Unknown plot option '{plot_func}' - panel left empty")


def _add_watermark(fig, gs):
    """Overlay the watermark image across grid rows 4-5, columns 1-3 (best-effort)."""
    ax_watermark = fig.add_subplot(gs[-2:, 1:4], zorder=1)
    ax_watermark.set_xlim(0, 1)
    ax_watermark.set_ylim(0, 1)
    ax_watermark.set_xticks([])
    ax_watermark.set_yticks([])
    ax_watermark.set_frame_on(False)
    try:
        # imshow converts the PIL image to an array immediately, so the file
        # can be closed as soon as the call returns.
        with Image.open('tj stats circle-01_new.jpg') as img:
            ax_watermark.imshow(img, extent=[0.26, 0.46, 0.0, 0.2], origin='upper', zorder=-1, alpha=1)
    except Exception as e:
        # Deliberately non-fatal: the sheet is still useful without the logo.
        print(f"Watermark error (non-critical): {e}")


def generate_pitcher_plot(pitcher_id, year, level_id, split, start_date, end_date,
                          game_type, plot_1, plot_2, plot_3, rolling_window, full_resolution=False):
    """Generate the complete pitcher summary plot.

    Args:
        pitcher_id: MLB player id of the pitcher.
        year: Season passed to the scraper and headshot/bio helpers.
        level_id: Sport/level id (see ``level_dict``).
        split: Key of ``split_dict_hand`` selecting batter handedness.
        start_date, end_date: Date window passed to the game-list scraper.
        game_type: Single game-type code, e.g. ``'R'``.
        plot_1, plot_2, plot_3: Keys of ``function_dict`` for the three
            middle-row panels.
        rolling_window: Window passed to the rolling tjStuff+ panels.
        full_resolution: True for a 21x21in @ 100 dpi sheet, False for
            8x8in @ 37 dpi.

    Returns:
        A matplotlib Figure the caller must close. When there is no data, or
        when anything fails, this is a small text-only placeholder. Failures
        are printed, never raised.
    """
    fig = None
    try:
        df = _load_pitch_data(pitcher_id, year, level_id, split, start_date, end_date, game_type)
        if df is None or len(df) == 0:
            return _placeholder_figure('No Statcast Data For This Pitcher', 36)

        figsize = (21, 21) if full_resolution else (8, 8)
        dpi = 100 if full_resolution else 37
        fig = plt.figure(figsize=figsize, dpi=dpi)
        # NOTE(review): this updates the global rcParams *after* fig exists,
        # so it only affects figures created later (including the next
        # call's). The first sheet rendered in a process may therefore lay out
        # differently from later ones. Left in place to avoid changing output.
        plt.rcParams.update({'figure.autolayout': True})
        fig.set_facecolor('white')
        sns.set_theme(style="whitegrid", palette=colour_palette)

        # Rows: header, headshot/bio/logo, season table, plot panels, pitch
        # table, footer. Columns 0 and 7 are blank side margins.
        gs = gridspec.GridSpec(6, 8,
                               height_ratios=[6, 20, 12, 36, 36, 6],
                               width_ratios=[4, 18, 18, 18, 18, 18, 18, 4])
        gs.update(hspace=0.2, wspace=0.5)

        ax_headshot = fig.add_subplot(gs[1, 1:3])
        ax_bio = fig.add_subplot(gs[1, 3:5])
        ax_logo = fig.add_subplot(gs[1, 5:7])
        ax_season_table = fig.add_subplot(gs[2, 1:7])
        ax_plot_1 = fig.add_subplot(gs[3, 1:3])
        ax_plot_2 = fig.add_subplot(gs[3, 3:5])
        ax_plot_3 = fig.add_subplot(gs[3, 5:7])
        ax_table = fig.add_subplot(gs[4, 1:7])
        ax_footer = fig.add_subplot(gs[-1, 1:7])
        ax_header = fig.add_subplot(gs[0, 1:7])
        ax_left = fig.add_subplot(gs[:, 0])
        ax_right = fig.add_subplot(gs[:, -1])
        for blank_ax in (ax_footer, ax_header, ax_left, ax_right):
            blank_ax.axis('off')

        df_teams = scrape.get_teams()
        player_headshot(player_input=pitcher_id, ax=ax_headshot, sport_id=level_id, season=year)
        player_bio(pitcher_id=pitcher_id, ax=ax_bio, sport_id=level_id, year_input=year)
        plot_logo(pitcher_id=pitcher_id, ax=ax_logo, df_team=df_teams,
                  df_players=scrape.get_players(level_id, year, game_type=[game_type]))
        stat_summary_table(df=df, ax=ax_season_table, player_input=pitcher_id,
                           split=split, sport_id=level_id, game_type=[game_type],
                           start_date_input=start_date, end_date_input=end_date)

        for plot_func, ax, col in ((plot_1, ax_plot_1, 1),
                                   (plot_2, ax_plot_2, 3),
                                   (plot_3, ax_plot_3, 5)):
            _draw_panel(plot_func, df=df, ax=ax, fig=fig, gs=gs, col=col,
                        rolling_window=rolling_window)

        summary_table(df=df, ax=ax_table)
        plot_footer(ax_footer)
        _add_watermark(fig, gs)

        fig.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01)
        return fig
    except Exception as e:
        import traceback
        print(f"Error generating plot: {str(e)}")
        traceback.print_exc()
        if fig is not None:
            # Release the half-built sheet: pyplot keeps every figure it
            # created alive until it is explicitly closed.
            plt.close(fig)
        return _placeholder_figure(f'Error: {str(e)}', 24)
# ===== SHINY UI =====
# Sidebar controls: the pitcher, season and level to render, plus the
# generate/download actions.
_sidebar = ui.sidebar(
    ui.input_numeric("pitcher_id", "Pitcher ID:", value=0),
    ui.input_numeric("year_input", "Year:", value=2025),
    ui.input_select("level_input", "Level:", level_dict, selected='1'),
    ui.input_action_button("generate_plot", "Generate Plot", class_="btn-primary"),
    ui.download_button("download_plot", "Download Plot", class_="btn-secondary"),
)
# Main area: title, status line and the rendered summary plot.
app_ui = ui.page_sidebar(
    _sidebar,
    ui.h1("Pitcher Summary Dashboard"),
    ui.output_text("status"),
    ui.output_plot("plot", width="800px", height="800px"),
)
# ===== SHINY SERVER =====
def server(input, output, session):
    """Shiny server: status line, on-demand summary plot and PNG download.

    The plot and the download both render the pitcher, season and level
    currently selected in the sidebar, using the fixed HARDCODED_* plot
    settings.
    """

    def _build_figure(full_resolution):
        """Render the summary sheet for the sidebar's current selections."""
        year = int(input.year_input())
        return generate_pitcher_plot(
            pitcher_id=int(input.pitcher_id()),
            year=year,
            level_id=int(input.level_input()),
            split=HARDCODED_SPLIT,
            # Derive the date window from the selected year, as the API
            # endpoint does. The fixed HARDCODED_START_DATE/END_DATE pinned
            # every request to 2025, so other seasons found no games.
            start_date=f"{year}-01-01",
            end_date=f"{year}-12-31",
            game_type=HARDCODED_GAME_TYPE,
            plot_1=HARDCODED_PLOT_1,
            plot_2=HARDCODED_PLOT_2,
            plot_3=HARDCODED_PLOT_3,
            rolling_window=HARDCODED_ROLLING_WINDOW,
            full_resolution=full_resolution,
        )

    @render.text
    def status():
        if input.generate_plot() == 0:
            return ""
        return f"Generating plot for Pitcher {input.pitcher_id()} ({input.year_input()}, Level: {input.level_input()})"

    @render.plot
    # NOTE(review): with ignore_none=False this also runs before the first
    # click (button value 0), i.e. for the default pitcher_id of 0 -- confirm
    # that is intended.
    @reactive.event(input.generate_plot, ignore_none=False)
    def plot():
        return _build_figure(full_resolution=False)

    @session.download(filename="pitcher_summary.png")
    def download_plot():
        # Yield the PNG bytes directly. The previous
        # NamedTemporaryFile(delete=False) left one orphaned file in the temp
        # directory per download.
        fig = _build_figure(full_resolution=True)
        buffer = io.BytesIO()
        try:
            fig.savefig(buffer, format='png', dpi=100)
        finally:
            plt.close(fig)
        yield buffer.getvalue()
# ===== CREATE SHINY APP WITH API MIDDLEWARE =====
# Build the Shiny app, then register APIMiddleware on its underlying
# Starlette app so /api/pitcher_plot requests are answered before Shiny
# sees them.
shiny_app = App(app_ui, server)
shiny_app.starlette_app.add_middleware(APIMiddleware)
# Exported ASGI entry point.
app = shiny_app

print("=" * 60)
print("API Endpoint: /api/pitcher_plot")
print("Query Parameters: pitcher_id, year, level_id (required); type (optional: R, S, P; default R)")
print(f"Cache Directory: {CACHE_DIR}")
print(f"Cache TTL: {CACHE_TTL_HOURS} hours")
print("\nExample (Paul Skenes):")
print("/api/pitcher_plot?pitcher_id=694973&year=2025&level_id=1")
print("=" * 60)