import json
import math
from io import BytesIO

import matplotlib.colors
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import numpy as np
import pandas as pd
import PIL
import polars as pl
import requests
import seaborn as sns
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from matplotlib.patches import Ellipse, Rectangle
from matplotlib.ticker import FuncFormatter, MaxNLocator
from matplotlib.transforms import Bbox
from PIL import Image
from scipy.stats import gaussian_kde
from statsmodels.nonparametric.kernel_regression import KernelReg

# Timeout (seconds) applied to every HTTP request made by plot_header, so a
# slow or unresponsive server cannot block the caller indefinitely.
_REQUEST_TIMEOUT_S = 10

# Display format string for each summary-table metric, keyed by metric name.
format_dict = {
    'pitch_percent': '{:.1%}',
    'pitches': '{:.0f}',
    'heart_zone_percent': '{:.1%}',
    'shadow_zone_percent': '{:.1%}',
    'chase_zone_percent': '{:.1%}',
    'waste_zone_percent': '{:.1%}',
    'csw_percent': '{:.1%}',
    'whiff_rate': '{:.1%}',
    'zone_whiff_percent': '{:.1%}',
    'chase_percent': '{:.1%}',
    'bip': '{:.0f}',
    'xwoba_percent_contact': '{:.3f}',
}

# Row label shown in the summary table for each metric name.
label_translation_dict = {
    'pitch_percent': 'Pitch%',
    'pitches': 'Pitches',
    'heart_zone_percent': 'Heart%',
    'shadow_zone_percent': 'Shadow%',
    'chase_zone_percent': 'Chase%',
    'waste_zone_percent': 'Waste%',
    'csw_percent': 'CSW%',
    'whiff_rate': 'Whiff%',
    'zone_whiff_percent': 'Z-Whiff%',
    'chase_percent': 'O-Swing%',
    'bip': 'BBE',
    'xwoba_percent_contact': 'xwOBACON',
}


def pitch_heat_map(pitch_input, df):
    """Return the rows of ``df`` for one pitch type, annotated with usage shares.

    Before filtering, two window counts are added over the whole frame:
    ``h_s_b`` (non-null ``pitcher_id`` rows per batter_hand/strikes/balls) and
    ``h_s_b_pitch`` (the same, additionally split by ``pitch_type``). Their
    ratio is stored as ``h_s_b_pitch_percent``.

    Parameters
    ----------
    pitch_input
        Value compared against the ``pitch_type`` column.
    df : pl.DataFrame
        Needs ``pitcher_id``, ``batter_hand``, ``strikes``, ``balls`` and
        ``pitch_type`` columns.

    Returns
    -------
    pl.DataFrame
        The rows whose ``pitch_type`` equals ``pitch_input``.
    """
    count_keys = ['batter_hand', 'strikes', 'balls']
    df = df.with_columns([
        pl.col('pitcher_id').count().over(count_keys).alias('h_s_b'),
        pl.col('pitcher_id').count().over(count_keys + ['pitch_type']).alias('h_s_b_pitch'),
    ])
    df = df.with_columns([
        (pl.col('h_s_b_pitch') / pl.col('h_s_b')).alias('h_s_b_pitch_percent'),
    ])
    return df.filter(pl.col('pitch_type') == pitch_input)


def pitch_prop(df: pl.DataFrame, hand: str = 'R') -> pl.DataFrame:
    """Build a strikes x balls grid of mean ``h_s_b_pitch_percent`` for one batter hand.

    Parameters
    ----------
    df : pl.DataFrame
        Output of :func:`pitch_heat_map` (needs ``batter_hand``, ``strikes``,
        ``balls`` and ``h_s_b_pitch_percent``).
    hand : str
        Batter hand to keep (compared against ``batter_hand``).

    Returns
    -------
    pl.DataFrame
        A ``strikes`` column (0-2) plus one column per ball count (0-3).
        Counts with no data are filled with 0. (The original annotation said
        ``pd.DataFrame``, but a polars frame has always been returned.)
    """
    df_plot_pd = df.to_pandas()
    pivot_table = (df_plot_pd[df_plot_pd['batter_hand'].isin([hand])]
                   .groupby(['batter_hand', 'strikes', 'balls'])[['h_s_b_pitch_percent']]
                   .mean()
                   .reset_index()
                   .pivot(index='strikes', columns='balls', values='h_s_b_pitch_percent'))

    # Force the full 3 x 4 count grid even when some counts never occurred.
    pivot_table = pivot_table.reindex(index=range(3), columns=range(4))
    pivot_table = pivot_table.fillna(0)

    return pl.DataFrame(pivot_table.reset_index())


# Strike-zone rectangle (closed polyline) in plate coordinates (feet).
strike_zone = pd.DataFrame({
    'PlateLocSide': [-0.9, -0.9, 0.9, 0.9, -0.9],
    'PlateLocHeight': [1.5, 3.5, 3.5, 1.5, 1.5],
})

# Home-plate outline segments drawn by draw_line, each as (x-coords, y-coords).
# Used when catcher_p=True ("catcher perspective" in the original comments).
_PLATE_SEGMENTS_CATCHER = (
    ((-0.708, 0.708), (0.15, 0.15)),
    ((-0.708, -0.708), (0.15, 0.3)),
    ((-0.708, 0), (0.3, 0.5)),
    ((0, 0.708), (0.5, 0.3)),
    ((0.708, 0.708), (0.3, 0.15)),
)
# Used when catcher_p=False ("other perspective" in the original comments).
_PLATE_SEGMENTS_OTHER = (
    ((-0.708, 0.708), (0.4, 0.4)),
    ((-0.708, -0.9), (0.4, -0.1)),
    ((-0.9, 0), (-0.1, -0.35)),
    ((0, 0.9), (-0.35, -0.1)),
    ((0.9, 0.708), (-0.1, 0.4)),
)


### STRIKE ZONE ###
def draw_line(axis, alpha_spot=1, catcher_p=True):
    """Draw the strike zone and a home-plate outline on ``axis``.

    Parameters
    ----------
    axis : plt.Axes
        Target axes (data coordinates are plate feet).
    alpha_spot : float
        Alpha applied to every line.
    catcher_p : bool
        Selects which of the two home-plate outlines is drawn.
    """
    # Strike zone: thicker and drawn above the heat map (zorder 3).
    axis.plot(strike_zone['PlateLocSide'].to_numpy(),
              strike_zone['PlateLocHeight'].to_numpy(),
              color='black', linewidth=1.3, zorder=3, alpha=alpha_spot)

    # Home plate: solid thin segments, drawn low in the stack (zorder 1).
    segments = _PLATE_SEGMENTS_CATCHER if catcher_p else _PLATE_SEGMENTS_OTHER
    for xs, ys in segments:
        axis.plot(xs, ys, color='black', linewidth=1, alpha=alpha_spot, zorder=1)


def heat_map_plot(df: pl.DataFrame, ax: plt.Axes, cmap: matplotlib.colors.LinearSegmentedColormap, hand: str):
    """Plot pitch locations (``px``, ``pz``) against batters of ``hand``.

    With more than 3 pitches a filled 2-D KDE is drawn. With 1-3 pitches a
    scatter plot is drawn. With none, only the zone outline is drawn.
    The axes are then hidden, made square and fixed to x in [-2.75, 2.75],
    y in [-0.5, 5].
    """
    hand_df = df.filter(pl.col('batter_hand') == hand)
    n_pitches = hand_df.shape[0]
    if n_pitches > 3:
        # `fill=True` replaces the `shade` keyword deprecated in seaborn 0.11.
        sns.kdeplot(data=hand_df, x='px', y='pz', cmap=cmap, fill=True,
                    ax=ax, thresh=0.3, bw_adjust=1)
    elif n_pitches > 0:
        # Too few points for a meaningful density estimate.
        sns.scatterplot(data=hand_df, x='px', y='pz', cmap=cmap, ax=ax)

    draw_line(ax, alpha_spot=1, catcher_p=False)
    ax.axis('off')
    ax.axis('square')
    ax.set_xlim(-2.75, 2.75)
    ax.set_ylim(-0.5, 5)


def format_as_percentage(val):
    """Format a fraction as a whole-number percentage, e.g. ``0.256 -> '26%'``."""
    return f'{val * 100:.0f}%'


def table_plot(ax: plt.Axes, table: pd.DataFrame, hand='R'):
    """Draw a two-column metric summary table beside the heat map.

    Parameters
    ----------
    ax : plt.Axes
        Axes to draw on. The table's position is given in data coordinates.
    table : pd.DataFrame
        One-column frame indexed by keys of ``format_dict`` /
        ``label_translation_dict``. The pandas API is used here
        (``apply(axis=1)``, ``.index``), so a polars frame will not work.
        Cells equal to ``'—'`` are shown unchanged.
    hand : str
        ``'R'`` places the table on the right with an "Against RHH" title;
        anything else places it on the left with "Against LHH".
    """
    # Transform from data coordinates to axes-fraction coordinates.
    trans = ax.transData + ax.transAxes.inverted()
    is_right = hand == 'R'
    bbox_data = Bbox.from_bounds(1.7 if is_right else -4.2, -0.5, 2.5, 5)
    bbox_axes = trans.transform_bbox(bbox_data)

    if is_right:
        ax.text(s='Against RHH', x=2.95, y=4.65, fontsize=18, fontweight='bold', ha='center')
    else:
        ax.text(s='Against LHH', x=-2.95, y=4.65, fontsize=18, fontweight='bold', ha='center')

    # Format each row's single value with its metric-specific format string.
    # `.iloc[0]` is explicit positional access; plain `[0]` on a label-indexed
    # Series is deprecated in pandas 2.x.
    formatted = table.apply(
        lambda row: format_dict[row.name].format(row.iloc[0]) if row.iloc[0] != '—' else '—',
        axis=1,
    )
    formatted.index = [label_translation_dict[name] for name in formatted.index]

    mpl_table = ax.table(cellText=formatted.reset_index().values,
                         loc='right',
                         cellLoc='center',
                         colWidths=[0.52, 0.3],
                         bbox=bbox_axes.bounds,
                         zorder=100)
    mpl_table.auto_set_font_size(False)
    mpl_table.set_fontsize(14)


def table_plot_pivot(ax: plt.Axes, pivot_table: pl.DataFrame, df_colour: pd.DataFrame):
    """Draw the strikes x balls usage grid as a colour-coded table above the plot.

    Parameters
    ----------
    ax : plt.Axes
        Axes to draw on.
    pivot_table : pl.DataFrame
        Frame whose last four columns hold the values shown (formatted with
        :func:`format_as_percentage`). Their names become the column labels.
    df_colour : pd.DataFrame
        Cell colours; its last four columns are used, row-aligned with
        ``pivot_table``.
    """
    trans = ax.transData + ax.transAxes.inverted()
    bbox_axes = trans.transform_bbox(Bbox.from_bounds(-0.75, 5, 2.5, 1))

    ball_columns = pivot_table.columns[-4:]
    cell_text = [[format_as_percentage(val) for val in row]
                 for row in pivot_table.select(ball_columns).to_numpy()]

    mpl_table = ax.table(cellText=cell_text,
                         colLabels=ball_columns,
                         rowLabels=[' 0 ', ' 1 ', ' 2 '],
                         loc='center',
                         cellLoc='center',
                         colWidths=[0.3, 0.3, 0.30, 0.3],
                         bbox=bbox_axes.bounds,
                         zorder=100,
                         cellColours=df_colour[df_colour.columns[-4:]].values)
    mpl_table.auto_set_font_size(False)
    mpl_table.set_fontsize(11)

    # Axis titles for the count grid.
    ax.text(x=-2.0, y=5.08, s='Strikes', rotation=90, fontweight='bold')
    ax.text(x=0, y=6.05, s='Balls', fontweight='bold', ha='center')


# Team abbreviation -> ESPN logo code. Both 'OAK' and 'ATH' use the 'oak' logo.
_ESPN_LOGO_CODES = {
    'AZ': 'ari', 'ATL': 'atl', 'BAL': 'bal', 'BOS': 'bos', 'CHC': 'chc',
    'CWS': 'chw', 'CIN': 'cin', 'CLE': 'cle', 'COL': 'col', 'DET': 'det',
    'HOU': 'hou', 'KC': 'kc', 'LAA': 'laa', 'LAD': 'lad', 'MIA': 'mia',
    'MIL': 'mil', 'MIN': 'min', 'NYM': 'nym', 'NYY': 'nyy', 'OAK': 'oak',
    'PHI': 'phi', 'PIT': 'pit', 'SD': 'sd', 'SF': 'sf', 'SEA': 'sea',
    'STL': 'stl', 'TB': 'tb', 'TEX': 'tex', 'TOR': 'tor', 'WSH': 'wsh',
    'ATH': 'oak',
}
_ESPN_LOGO_URL = 'https://a.espncdn.com/combiner/i?img=/i/teamlogos/mlb/500/scoreboard/{code}.png&h=500&w=500'
# Team abbreviation -> full ESPN logo URL.
_MLB_TEAM_LOGOS = {team: _ESPN_LOGO_URL.format(code=code) for team, code in _ESPN_LOGO_CODES.items()}


def plot_header(pitcher_id: str, ax: plt.Axes, df_team: pl.DataFrame, df_players: pl.DataFrame, sport_id: int):
    """
    Display the pitcher's headshot and parent-organisation logo on the specified axis.

    The headshot is fetched from img.mlbstatic.com (MLB or MiLB path chosen by
    ``sport_id``) and drawn on the left. The pitcher's team is then looked up
    via statsapi.mlb.com, mapped to a parent-org abbreviation through
    ``df_team``, and the matching ESPN logo is drawn on the right. Both steps
    are best-effort: if an image cannot be fetched, decoded or resolved, the
    axis is turned off instead of raising.

    Parameters
    ----------
    pitcher_id : str
        The ID of the pitcher.
    ax : plt.Axes
        The axis to display the images on.
    df_team : pl.DataFrame
        Team data (needs ``team_id``, ``parent_org_id``,
        ``parent_org_abbreviation``).
    df_players : pl.DataFrame
        Player data (needs ``player_id``, ``team``).
    sport_id : int
        ``1`` (or anything ``int()`` maps to 1) selects the MLB headshot and
        layout; any other value selects MiLB.
    """
    # Evaluate once so the URL and the image extent always agree, even if
    # sport_id arrives as a string.
    is_mlb = int(sport_id) == 1

    if is_mlb:
        url = f'https://img.mlbstatic.com/mlb-photos/image/upload/d_people:generic:headshot:67:current.png/w_640,q_auto:best/v1/people/{pitcher_id}/headshot/silo/current.png'
        headshot_extent = [-11.5, -9.5, 0, 2]
    else:
        url = f'https://img.mlbstatic.com/mlb-photos/image/upload/c_fill,g_auto/w_640/v1/people/{pitcher_id}/headshot/milb/current.png'
        headshot_extent = [-11.5 + 2 / 6, -9.5 - 2 / 6, 0, 2]

    try:
        response = requests.get(url, timeout=_REQUEST_TIMEOUT_S)
        img = Image.open(BytesIO(response.content))
        ax.imshow(img, extent=headshot_extent, origin='upper')
    except (PIL.UnidentifiedImageError, requests.RequestException):
        ax.axis('off')

    try:
        team_id = df_players.filter(pl.col('player_id') == pitcher_id)['team'][0]

        url_team = f'https://statsapi.mlb.com/api/v1/teams/{team_id}'
        team_info = requests.get(url_team, timeout=_REQUEST_TIMEOUT_S).json()['teams'][0]

        # If the team id appears as a parent org it is matched on team_id;
        # otherwise look up the row for its parentOrgId.
        if team_info['id'] in df_team['parent_org_id']:
            team_abb = df_team.filter(pl.col('team_id') == team_info['id'])['parent_org_abbreviation'][0]
        else:
            team_abb = df_team.filter(pl.col('parent_org_id') == team_info['parentOrgId'])['parent_org_abbreviation'][0]

        logo_url = _MLB_TEAM_LOGOS[team_abb]
        response = requests.get(logo_url, timeout=_REQUEST_TIMEOUT_S)
        img = Image.open(BytesIO(response.content))
        ax.imshow(img, extent=[9.5, 11.5, 0, 2], origin='upper')
    except (KeyError, IndexError, PIL.UnidentifiedImageError, requests.RequestException):
        ax.axis('off')
    return


def _kernel_smoothed_heat_map(heatmap_df: pd.DataFrame, ax: plt.Axes, cmap, *, metric: str,
                              bin_size: float, x_stop: float, z_stop: float, bandwidth: float,
                              v_center: float, vmin: float, vmax: float) -> None:
    """Kernel-smooth ``metric`` over a px/pz grid and draw it as an image on ``ax``.

    Shared implementation of heat_map_plot_hex_whiff and
    heat_map_plot_hex_damage. Grid points that match no observed pitch get
    the pseudo-value ``v_center``, which pulls sparse regions toward it.
    ``heatmap_df`` is modified in place (``metric`` NaNs -> 0, ``kde_x`` and
    ``kde_z`` columns added).
    """
    heatmap_df[metric] = heatmap_df[metric].fillna(0)

    # Evaluation grid, x-major order. Coordinates are rounded to 0.1 so they
    # can line up with the 0.1-truncated pitch locations used as join keys.
    # Built in one constructor call instead of row-by-row .loc appends.
    z_values = np.arange(-0.5, z_stop, bin_size)
    zone_df = pd.DataFrame(
        [(round(x, 1), round(z, 1)) for x in np.arange(-2.75, x_stop, bin_size) for z in z_values],
        columns=['px', 'pz'],
    )

    # Join keys: locations truncated toward zero to 0.1 ft, then clipped.
    for loc_col, key_col, low, high in (('px', 'kde_x', -2.75, 2.75), ('pz', 'kde_z', -0.5, 5)):
        has_loc = heatmap_df[loc_col].notna()
        heatmap_df.loc[has_loc, key_col] = np.clip(
            heatmap_df.loc[has_loc, loc_col].astype('float').mul(10).astype('int').div(10),
            low, high,
        )

    observed = heatmap_df.dropna(subset=[metric, 'px', 'pz'])[['kde_x', 'kde_z', metric]]
    kde_df = pd.merge(zone_df, observed, how='left',
                      left_on=['px', 'pz'], right_on=['kde_x', 'kde_z']).fillna({metric: v_center})

    # Nadaraya-Watson style kernel regression of the metric on (px, pz),
    # evaluated back at the same points.
    kernel_regression = KernelReg(endog=kde_df[metric],
                                  exog=[kde_df['px'], kde_df['pz']],
                                  bw=[bandwidth, bandwidth],
                                  var_type='cc')
    kde_df['kernel_stat'] = kernel_regression.fit([kde_df['px'], kde_df['pz']])[0]
    kde_df = kde_df.pivot_table(columns='px', index='pz', values='kernel_stat', aggfunc='mean')
    kde_df = kde_df.round(3)

    # NOTE(review): this extent does not match the span of the px/pz grid
    # built above (px starts at -2.75, pz runs past 5), so the image is
    # rescaled relative to the strike-zone outline. Confirm it is intentional.
    ax.imshow(kde_df.values, extent=[-2.25, 2.25, -0.5, 5], origin='lower',
              cmap=cmap, vmin=vmin, vmax=vmax, interpolation='bilinear')
    ax.axis('square')
    ax.set(xlabel=None, ylabel=None)
    ax.set_xlim(-2.75, 2.75)
    ax.set_ylim(-0.5, 5)
    ax.axis('off')
    draw_line(ax, alpha_spot=1, catcher_p=False)


def heat_map_plot_hex_whiff(df: pl.DataFrame, ax: plt.Axes, cmap: matplotlib.colors.LinearSegmentedColormap, hand: str):
    """Draw a kernel-smoothed ``is_whiff`` surface over swings against batters of ``hand``.

    Uses a 0.1 ft grid, pseudo-value 0.25 for empty grid points, and a colour
    scale clipped to [0.15, 0.35].
    """
    heatmap_df = df.filter((pl.col('batter_hand') == hand) & pl.col('is_swing')).to_pandas()
    _kernel_smoothed_heat_map(
        heatmap_df, ax, cmap,
        metric='is_whiff',
        # The original clamp max(0.1, min(0.1, 1/sqrt(n))) always equals 0.1.
        bin_size=0.1,
        x_stop=2.85,
        z_stop=5.6,
        # NOTE(review): bandwidth is sized from the unfiltered df, not the
        # filtered swings — confirm this is intended.
        bandwidth=np.clip(1 / np.sqrt(len(df)), 0.3, 0.5),
        v_center=0.25,
        vmin=0.15,
        vmax=0.35,
    )


def heat_map_plot_hex_damage(df: pl.DataFrame, ax: plt.Axes, cmap: matplotlib.colors.LinearSegmentedColormap, hand: str):
    """Draw a kernel-smoothed ``woba_pred_contact`` surface over batted balls against ``hand``.

    Only rows with ``launch_speed > 0`` are used. The grid spacing is
    1/sqrt(n) clamped to [0.2, 0.3], empty grid points get pseudo-value
    0.375, and the colour scale is clipped to [0.25, 0.5].
    """
    heatmap_df = df.filter((pl.col('batter_hand') == hand) & (pl.col('launch_speed') > 0)).to_pandas()
    _kernel_smoothed_heat_map(
        heatmap_df, ax, cmap,
        metric='woba_pred_contact',
        bin_size=max(0.2, min(0.3, 1 / np.sqrt(len(heatmap_df)))),
        x_stop=2.95,
        z_stop=5.7,
        # NOTE(review): bandwidth is sized from the unfiltered df, not the
        # filtered batted balls — confirm this is intended.
        bandwidth=np.clip(1 / np.sqrt(len(df)), 0.3, 0.5),
        v_center=0.375,
        vmin=0.25,
        vmax=0.5,
    )