# pitching_summary_complete / functions /heat_map_functions.py
# nesticot's picture
# Upload 9 files
# 05458eb verified
import pandas as pd
import numpy as np
import json
from matplotlib.ticker import FuncFormatter
from matplotlib.ticker import MaxNLocator
import math
from matplotlib.patches import Ellipse
import matplotlib.transforms as transforms
import matplotlib.colors
import matplotlib.colors as mcolors
import seaborn as sns
import matplotlib.pyplot as plt
import requests
import polars as pl
from PIL import Image
import requests
from io import BytesIO
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import PIL
from matplotlib.transforms import Bbox
import matplotlib.image as mpimg
from scipy.stats import gaussian_kde
from statsmodels.nonparametric.kernel_regression import KernelReg
# Format specs shared by several metrics, named after the kind of value they render.
_PERCENT_ONE_DECIMAL = '{:.1%}'
_WHOLE_NUMBER = '{:.0f}'
_THREE_DECIMALS = '{:.3f}'

# Maps each summary-metric key to the ``str.format`` template used to render its value.
format_dict = {
    'pitch_percent': _PERCENT_ONE_DECIMAL,
    'pitches': _WHOLE_NUMBER,
    'heart_zone_percent': _PERCENT_ONE_DECIMAL,
    'shadow_zone_percent': _PERCENT_ONE_DECIMAL,
    'chase_zone_percent': _PERCENT_ONE_DECIMAL,
    'waste_zone_percent': _PERCENT_ONE_DECIMAL,
    'csw_percent': _PERCENT_ONE_DECIMAL,
    'whiff_rate': _PERCENT_ONE_DECIMAL,
    'zone_whiff_percent': _PERCENT_ONE_DECIMAL,
    'chase_percent': _PERCENT_ONE_DECIMAL,
    'bip': _WHOLE_NUMBER,
    'xwoba_percent_contact': _THREE_DECIMALS,
}
# Maps each summary-metric key to the label shown for it in the rendered tables.
label_translation_dict = dict(
    pitch_percent='Pitch%',
    pitches='Pitches',
    heart_zone_percent='Heart%',
    shadow_zone_percent='Shadow%',
    chase_zone_percent='Chase%',
    waste_zone_percent='Waste%',
    csw_percent='CSW%',
    whiff_rate='Whiff%',
    zone_whiff_percent='Z-Whiff%',
    chase_percent='O-Swing%',
    bip='BBE',
    xwoba_percent_contact='xwOBACON',
)
def pitch_heat_map(pitch_input, df):
    """Add per-count usage columns to ``df`` and keep one pitch type.

    Three columns are added before filtering:

    * ``h_s_b`` - count of ``pitcher_id`` within each
      (batter_hand, strikes, balls) group.
    * ``h_s_b_pitch`` - the same count, further split by ``pitch_type``.
    * ``h_s_b_pitch_percent`` - ``h_s_b_pitch / h_s_b``.

    Parameters
    ----------
    pitch_input
        Value compared against the ``pitch_type`` column.
    df
        Polars frame with ``pitcher_id``, ``batter_hand``, ``strikes``,
        ``balls`` and ``pitch_type`` columns.

    Returns
    -------
    The rows of the augmented frame whose ``pitch_type`` equals
    ``pitch_input``.
    """
    count_keys = ['batter_hand', 'strikes', 'balls']
    n_pitches = pl.col('pitcher_id').count()

    with_counts = df.with_columns([
        n_pitches.over(count_keys).alias('h_s_b'),
        n_pitches.over(count_keys + ['pitch_type']).alias('h_s_b_pitch'),
    ])
    # The share must be computed in a second pass because it reads the two
    # columns created above.
    with_share = with_counts.with_columns([
        (pl.col('h_s_b_pitch') / pl.col('h_s_b')).alias('h_s_b_pitch_percent'),
    ])
    return with_share.filter(pl.col('pitch_type') == pitch_input)
def pitch_prop(df: pl.DataFrame, hand: str = 'R') -> pl.DataFrame:
    """Build a strikes-by-balls table of mean ``h_s_b_pitch_percent`` for one batter side.

    Parameters
    ----------
    df : pl.DataFrame
        Frame with ``batter_hand``, ``strikes``, ``balls`` and
        ``h_s_b_pitch_percent`` columns.
    hand : str
        Value of ``batter_hand`` to keep (default ``'R'``).

    Returns
    -------
    pl.DataFrame
        One row per strike count 0-2, holding the ``strikes`` column followed
        by one column per ball count 0-3. Counts with no data are 0.
        (The previous ``-> pd.DataFrame`` annotation did not match the value
        actually returned.)
    """
    df_plot_pd = df.to_pandas()
    hand_rows = df_plot_pd[df_plot_pd['batter_hand'].isin([hand])]

    pivot_table = (hand_rows
                   .groupby(['batter_hand', 'strikes', 'balls'])[['h_s_b_pitch_percent']]
                   .mean()
                   .reset_index()
                   .pivot(index='strikes', columns='balls', values='h_s_b_pitch_percent'))

    # Force the full 3 x 4 count grid so the table always has the same shape,
    # then zero-fill counts that never occurred.
    pivot_table = pivot_table.reindex(index=range(3), columns=range(4)).fillna(0)

    return pl.DataFrame(pivot_table.reset_index())
# Strike-zone outline as a closed path: one (side, height) corner per row,
# with the first corner repeated at the end so the rectangle closes.
strike_zone = pd.DataFrame(
    [(-0.9, 1.5), (-0.9, 3.5), (0.9, 3.5), (0.9, 1.5), (-0.9, 1.5)],
    columns=['PlateLocSide', 'PlateLocHeight'],
)


def draw_line(axis, alpha_spot=1, catcher_p=True):
    """Plot the strike-zone outline and a home-plate outline on ``axis``.

    Parameters
    ----------
    axis
        Object exposing a matplotlib-style ``plot(x, y, **kwargs)`` method.
    alpha_spot
        Alpha value forwarded to every ``plot`` call.
    catcher_p
        Selects which of the two plate outlines is drawn: ``True`` for the
        catcher-perspective plate, ``False`` for the other perspective.
    """
    # Strike zone first, drawn above the plate (zorder 3 vs 1).
    axis.plot(strike_zone['PlateLocSide'].to_numpy(),
              strike_zone['PlateLocHeight'].to_numpy(),
              color='black', linewidth=1.3, zorder=3, alpha=alpha_spot)

    # Each entry is one straight segment of the plate: ([x0, x1], [y0, y1]).
    if catcher_p:
        plate_segments = [
            ([-0.708, 0.708], [0.15, 0.15]),
            ([-0.708, -0.708], [0.15, 0.3]),
            ([-0.708, 0], [0.3, 0.5]),
            ([0, 0.708], [0.5, 0.3]),
            ([0.708, 0.708], [0.3, 0.15]),
        ]
    else:
        plate_segments = [
            ([-0.708, 0.708], [0.4, 0.4]),
            ([-0.708, -0.9], [0.4, -0.1]),
            ([-0.9, 0], [-0.1, -0.35]),
            ([0, 0.9], [-0.35, -0.1]),
            ([0.9, 0.708], [-0.1, 0.4]),
        ]

    for side, height in plate_segments:
        axis.plot(side, height, color='black', linewidth=1, alpha=alpha_spot, zorder=1)
def heat_map_plot(df:pl.DataFrame,
                  ax:plt.Axes,
                  cmap:matplotlib.colors.LinearSegmentedColormap,
                  hand:str):
    """Draw a pitch-location density plot for one batter side on ``ax``.

    With more than three rows for ``hand`` a filled 2-D KDE of ``px``/``pz``
    is drawn; with one to three rows the individual points are scattered
    instead; with no rows only the strike zone is drawn.

    Parameters
    ----------
    df : pl.DataFrame
        Frame with ``batter_hand``, ``px`` and ``pz`` columns.
    ax : plt.Axes
        Axis to draw on. Its axes are hidden and its limits fixed.
    cmap : matplotlib.colors.LinearSegmentedColormap
        Colormap for the density plot.
    hand : str
        Value of ``batter_hand`` to plot.
    """
    # Filter once instead of re-running the same filter for every branch.
    df_hand = df.filter(pl.col('batter_hand') == hand)
    n_rows = df_hand.shape[0]

    if n_rows > 3:
        sns.kdeplot(data=df_hand,
                    x='px',
                    y='pz',
                    cmap=cmap,
                    # `fill` replaces the deprecated `shade` keyword.
                    fill=True,
                    ax=ax,
                    thresh=0.3,
                    bw_adjust=1)
    elif n_rows > 0:
        # NOTE(review): `cmap` is forwarded without any colour-mapped data, so
        # it presumably has no visual effect here - confirm before removing.
        sns.scatterplot(data=df_hand,
                        x='px',
                        y='pz',
                        cmap=cmap,
                        ax=ax)

    draw_line(ax, alpha_spot=1, catcher_p=False)
    ax.axis('off')
    ax.axis('square')
    ax.set_xlim(-2.75, 2.75)
    ax.set_ylim(-0.5, 5)
def format_as_percentage(val):
    """Render a fraction as a whole-number percentage string, e.g. 0.5 -> '50%'."""
    scaled = val * 100
    return format(scaled, '.0f') + '%'
def table_plot(ax: plt.Axes,
               table: pd.DataFrame,
               hand='R'):
    """Draw a two-column metric table beside the heat map on ``ax``.

    Parameters
    ----------
    ax : plt.Axes
        Axis to draw on. The table position is given in data coordinates and
        converted to axes coordinates.
    table : pd.DataFrame
        Frame indexed by metric key (keys of ``format_dict`` and
        ``label_translation_dict``). The first value of each row is formatted
        with the matching ``format_dict`` template, unless it is the
        placeholder ``'—'``, which is passed through unchanged. The object is
        used through the pandas API (``apply(axis=1)``, ``.index``,
        ``.reset_index()``), hence the pandas annotation.
    hand : str
        ``'R'`` places the table and an 'Against RHH' title on the right; any
        other value places the table and an 'Against LHH' title on the left.
    """
    # Converts data coordinates into axes coordinates for the table bbox.
    trans = ax.transData + ax.transAxes.inverted()

    if hand == 'R':
        bbox_data = Bbox.from_bounds(1.7, -0.5, 2.5, 5)
        title, title_x = 'Against RHH', 2.95
    else:
        bbox_data = Bbox.from_bounds(-4.2, -0.5, 2.5, 5)
        title, title_x = 'Against LHH', -2.95
    bbox_axes = trans.transform_bbox(bbox_data)

    ax.text(s=title, x=title_x, y=4.65, fontsize=18, fontweight='bold', ha='center')

    # `.iloc[0]` takes the first value by position explicitly; plain `row[0]`
    # depends on the deprecated positional fallback of Series.__getitem__ when
    # the column labels are not integers.
    formatted = table.apply(
        lambda row: format_dict[row.name].format(row.iloc[0]) if row.iloc[0] != '—' else '—',
        axis=1)
    formatted.index = [label_translation_dict[key] for key in formatted.index]

    stat_table = ax.table(cellText=formatted.reset_index().values,
                          loc='right',
                          cellLoc='center',
                          colWidths=[0.52, 0.3],
                          bbox=bbox_axes.bounds, zorder=100)

    # Fixed font size: disable matplotlib's automatic table font scaling.
    min_font_size = 14
    stat_table.auto_set_font_size(False)
    stat_table.set_fontsize(min_font_size)
def table_plot_pivot(ax:plt.Axes,
                     pivot_table:pl.DataFrame,
                     df_colour:pd.DataFrame):
    """Draw the strikes-by-balls usage table above the plot area of ``ax``.

    Parameters
    ----------
    ax : plt.Axes
        Axis to draw on. The table position is given in data coordinates and
        converted to axes coordinates.
    pivot_table : pl.DataFrame
        Frame whose last four columns hold the values to show; each value is
        rendered with ``format_as_percentage`` and the column names become
        the column labels.
    df_colour : pd.DataFrame
        Frame whose last four columns supply the cell colours.
    """
    # Table bounds: data coordinates -> axes coordinates.
    to_axes = ax.transData + ax.transAxes.inverted()
    table_bounds = to_axes.transform_bbox(Bbox.from_bounds(-0.75, 5, 2.5, 1)).bounds

    ball_columns = pivot_table.columns[-4:]
    cell_text = []
    for row in pivot_table.select(ball_columns).to_numpy():
        cell_text.append([format_as_percentage(val) for val in row])
    cell_colours = df_colour[df_colour.columns[-4:]].values

    count_table = ax.table(cellText=cell_text,
                           colLabels=ball_columns,
                           rowLabels=[' 0 ',' 1 ',' 2 '],
                           loc='center',
                           cellLoc='center',
                           colWidths=[0.3, 0.3, 0.30, 0.3],
                           bbox=table_bounds, zorder=100,
                           cellColours=cell_colours)

    # Fixed font size: disable matplotlib's automatic table font scaling.
    count_table.auto_set_font_size(False)
    count_table.set_fontsize(11)

    # Axis titles for the count grid.
    ax.text(x=-2.0, y=5.08, s='Strikes', rotation=90, fontweight='bold')
    ax.text(x=0, y=6.05, s='Balls', fontweight='bold', ha='center')
def plot_header(pitcher_id: str, ax: plt.Axes, df_team: pl.DataFrame, df_players: pl.DataFrame,sport_id:int):
    """
    Draw the pitcher's headshot and the parent club's logo on the given axis.

    Parameters
    ----------
    pitcher_id : str
        The ID of the pitcher.
    ax : plt.Axes
        The axis to draw the headshot (left) and logo (right) on.
    df_team : pl.DataFrame
        Team data with ``team_id``, ``parent_org_id`` and
        ``parent_org_abbreviation`` columns.
    df_players : pl.DataFrame
        Player data with ``player_id`` and ``team`` columns.
    sport_id : int
        Sport identifier. A value equal to 1 (after ``int()``) selects the
        MLB 'silo' headshot URL and extent; anything else selects the 'milb'
        headshot URL and a narrower extent.

    Notes
    -----
    Both images are fetched over HTTP. If the headshot bytes are not a
    recognisable image, or the team/logo lookup raises ``KeyError`` or
    ``IndexError``, the axis is switched off instead.
    NOTE(review): the ``requests.get`` calls have no timeout - consider
    adding one together with handling for the resulting exception.
    """
    # ESPN logo URLs differ only in the file slug, so build them from a
    # team-abbreviation -> slug table.
    logo_url_template = 'https://a.espncdn.com/combiner/i?img=/i/teamlogos/mlb/500/scoreboard/{}.png&h=500&w=500'
    logo_slugs = {
        "AZ": "ari", "ATL": "atl", "BAL": "bal", "BOS": "bos", "CHC": "chc",
        "CWS": "chw", "CIN": "cin", "CLE": "cle", "COL": "col", "DET": "det",
        "HOU": "hou", "KC": "kc", "LAA": "laa", "LAD": "lad", "MIA": "mia",
        "MIL": "mil", "MIN": "min", "NYM": "nym", "NYY": "nyy", "OAK": "oak",
        "PHI": "phi", "PIT": "pit", "SD": "sd", "SF": "sf", "SEA": "sea",
        "STL": "stl", "TB": "tb", "TEX": "tex", "TOR": "tor", "WSH": "wsh",
        # ATH reuses the OAK logo file.
        "ATH": "oak",
    }
    image_dict = {team: logo_url_template.format(slug) for team, slug in logo_slugs.items()}

    # Decide MLB vs. non-MLB once so the URL and the image extent always
    # agree (previously the extent compared the raw value, so e.g. '1' got
    # the MLB URL but the non-MLB extent).
    is_mlb = int(sport_id) == 1

    try:
        if is_mlb:
            url = f'https://img.mlbstatic.com/mlb-photos/image/upload/d_people:generic:headshot:67:current.png/w_640,q_auto:best/v1/people/{pitcher_id}/headshot/silo/current.png'
        else:
            url = f'https://img.mlbstatic.com/mlb-photos/image/upload/c_fill,g_auto/w_640/v1/people/{pitcher_id}/headshot/milb/current.png'

        response = requests.get(url)
        img = Image.open(BytesIO(response.content))

        headshot_extent = [-11.5, -9.5, 0, 2] if is_mlb else [-11.5+2/6, -9.5-2/6, 0, 2]
        ax.imshow(img, extent=headshot_extent, origin='upper')
    except PIL.UnidentifiedImageError:
        ax.axis('off')

    try:
        # Team the pitcher is listed with.
        team_id = df_players.filter(pl.col('player_id') == pitcher_id)['team'][0]

        url_team = f'https://statsapi.mlb.com/api/v1/teams/{team_id}'
        data_team = requests.get(url_team).json()
        team_info = data_team['teams'][0]

        # A team that is itself a parent organisation is looked up by its own
        # id; otherwise resolve the abbreviation through its parent org id.
        if team_info['id'] in df_team['parent_org_id']:
            team_abb = df_team.filter(pl.col('team_id') == team_info['id'])['parent_org_abbreviation'][0]
        else:
            team_abb = df_team.filter(pl.col('parent_org_id') == team_info['parentOrgId'])['parent_org_abbreviation'][0]

        logo_url = image_dict[team_abb]
        response = requests.get(logo_url)
        img = Image.open(BytesIO(response.content))
        ax.imshow(img, extent=[9.5, 11.5, 0, 2], origin='upper')
    except (KeyError, IndexError):
        ax.axis('off')
    return
# NOTE(review): strike_zone and draw_line are also defined earlier in this
# file with the same content; this definition rebinds both names.
# Strike-zone outline as a closed path: one (side, height) corner per row,
# with the first corner repeated at the end so the rectangle closes.
strike_zone = pd.DataFrame(
    [(-0.9, 1.5), (-0.9, 3.5), (0.9, 3.5), (0.9, 1.5), (-0.9, 1.5)],
    columns=['PlateLocSide', 'PlateLocHeight'],
)


def draw_line(axis, alpha_spot=1, catcher_p=True):
    """Plot the strike-zone outline and a home-plate outline on ``axis``.

    Parameters
    ----------
    axis
        Object exposing a matplotlib-style ``plot(x, y, **kwargs)`` method.
    alpha_spot
        Alpha value forwarded to every ``plot`` call.
    catcher_p
        Selects which of the two plate outlines is drawn: ``True`` for the
        catcher-perspective plate, ``False`` for the other perspective.
    """
    # Strike zone first, drawn above the plate (zorder 3 vs 1).
    axis.plot(strike_zone['PlateLocSide'].to_numpy(),
              strike_zone['PlateLocHeight'].to_numpy(),
              color='black', linewidth=1.3, zorder=3, alpha=alpha_spot)

    # Each entry is one straight segment of the plate: ([x0, x1], [y0, y1]).
    if catcher_p:
        plate_segments = [
            ([-0.708, 0.708], [0.15, 0.15]),
            ([-0.708, -0.708], [0.15, 0.3]),
            ([-0.708, 0], [0.3, 0.5]),
            ([0, 0.708], [0.5, 0.3]),
            ([0.708, 0.708], [0.3, 0.15]),
        ]
    else:
        plate_segments = [
            ([-0.708, 0.708], [0.4, 0.4]),
            ([-0.708, -0.9], [0.4, -0.1]),
            ([-0.9, 0], [-0.1, -0.35]),
            ([0, 0.9], [-0.35, -0.1]),
            ([0.9, 0.708], [-0.1, 0.4]),
        ]

    for side, height in plate_segments:
        axis.plot(side, height, color='black', linewidth=1, alpha=alpha_spot, zorder=1)
from matplotlib.patches import Rectangle
# Function to draw the strike zone and home plate
# Import necessary libraries
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
from matplotlib import gridspec
import numpy as np
import pandas as pd
from statsmodels.nonparametric.kernel_regression import KernelReg
# NOTE(review): strike_zone and draw_line are also defined earlier in this
# file with the same content; this definition rebinds both names.
# Strike-zone outline as a closed path: one (side, height) corner per row,
# with the first corner repeated at the end so the rectangle closes.
strike_zone = pd.DataFrame(
    [(-0.9, 1.5), (-0.9, 3.5), (0.9, 3.5), (0.9, 1.5), (-0.9, 1.5)],
    columns=['PlateLocSide', 'PlateLocHeight'],
)


def draw_line(axis, alpha_spot=1, catcher_p=True):
    """Plot the strike-zone outline and a home-plate outline on ``axis``.

    Parameters
    ----------
    axis
        Object exposing a matplotlib-style ``plot(x, y, **kwargs)`` method.
    alpha_spot
        Alpha value forwarded to every ``plot`` call.
    catcher_p
        Selects which of the two plate outlines is drawn: ``True`` for the
        catcher-perspective plate, ``False`` for the other perspective.
    """
    # Strike zone first, drawn above the plate (zorder 3 vs 1).
    axis.plot(strike_zone['PlateLocSide'].to_numpy(),
              strike_zone['PlateLocHeight'].to_numpy(),
              color='black', linewidth=1.3, zorder=3, alpha=alpha_spot)

    # Each entry is one straight segment of the plate: ([x0, x1], [y0, y1]).
    if catcher_p:
        plate_segments = [
            ([-0.708, 0.708], [0.15, 0.15]),
            ([-0.708, -0.708], [0.15, 0.3]),
            ([-0.708, 0], [0.3, 0.5]),
            ([0, 0.708], [0.5, 0.3]),
            ([0.708, 0.708], [0.3, 0.15]),
        ]
    else:
        plate_segments = [
            ([-0.708, 0.708], [0.4, 0.4]),
            ([-0.708, -0.9], [0.4, -0.1]),
            ([-0.9, 0], [-0.1, -0.35]),
            ([0, 0.9], [-0.35, -0.1]),
            ([0.9, 0.708], [-0.1, 0.4]),
        ]

    for side, height in plate_segments:
        axis.plot(side, height, color='black', linewidth=1, alpha=alpha_spot, zorder=1)
def heat_map_plot_hex_whiff(df:pl.DataFrame,
                            ax:plt.Axes,
                            cmap:matplotlib.colors.LinearSegmentedColormap,
                            hand:str):
    """Draw a kernel-smoothed ``is_whiff`` heat map for swings against one batter side.

    Rows of ``df`` with ``batter_hand == hand`` and a truthy ``is_swing`` are
    snapped to a 0.1-step location grid, joined onto a full grid of cells,
    smoothed with a kernel regression, and shown with ``ax.imshow``. Grid
    cells without a matching swing are filled with ``v_center`` (0.25)
    before smoothing.

    Parameters
    ----------
    df : pl.DataFrame
        Frame with ``batter_hand``, ``is_swing``, ``is_whiff``, ``px`` and
        ``pz`` columns.
    ax : plt.Axes
        Axis to draw on. Its axes are hidden and its limits fixed.
    cmap : matplotlib.colors.LinearSegmentedColormap
        Colormap for the image, applied over the range 0.15-0.35.
    hand : str
        Value of ``batter_hand`` to plot.
    """
    heatmap_df = df.filter((pl.col('batter_hand')==hand)&((pl.col('is_swing')))).to_pandas()
    heatmap_df['is_whiff'] = heatmap_df['is_whiff'].fillna(0)

    # The previous max(0.1, min(0.1, ...)) expression always evaluated to 0.1
    # (and divided by zero on an empty frame), so the constant is used directly.
    bin_size = 0.1

    # Build the full grid of cell coordinates in one go; appending row by row
    # with `.loc` re-allocated the frame for every one of the ~3,400 cells.
    grid_points = [[round(x, 1), round(y, 1)]
                   for x in np.arange(-2.75, 2.85, bin_size)
                   for y in np.arange(-0.5, 5.6, bin_size)]
    zone_df = pd.DataFrame(grid_points, columns=['px', 'pz'])

    # Snap each location to one decimal (truncating toward zero) and clamp it
    # to the plotted window.
    has_px = heatmap_df['px'].notna()
    heatmap_df.loc[has_px, 'kde_x'] = np.clip(
        heatmap_df.loc[has_px, 'px'].astype('float').mul(10).astype('int').div(10),
        -2.75,
        2.75)
    has_pz = heatmap_df['pz'].notna()
    heatmap_df.loc[has_pz, 'kde_z'] = np.clip(
        heatmap_df.loc[has_pz, 'pz'].astype('float').mul(10).astype('int').div(10),
        -0.5,
        5)

    # NOTE(review): the bandwidth is derived from len(df) (the unfiltered
    # frame), not from the filtered swings - confirm this is intended.
    bandwidth = np.clip(1 / np.sqrt(len(df)), 0.3, 0.5)

    # Value assigned to grid cells that have no observed swing.
    v_center = 0.25
    kde_df = pd.merge(zone_df,
                      heatmap_df
                      .dropna(subset=['is_whiff', 'px', 'pz'])
                      [['kde_x', 'kde_z', 'is_whiff']],
                      how='left',
                      left_on=['px', 'pz'],
                      right_on=['kde_x', 'kde_z']).fillna({'is_whiff': v_center})

    # Kernel regression of the metric on the two (continuous) coordinates.
    kernel_regression = KernelReg(endog=kde_df['is_whiff'],
                                  exog=[kde_df['px'], kde_df['pz']],
                                  bw=[bandwidth, bandwidth],
                                  var_type='cc')
    kde_df['kernel_stat'] = kernel_regression.fit([kde_df['px'], kde_df['pz']])[0]

    # One row per pz, one column per px, ready for imshow.
    kde_df = kde_df.pivot_table(columns='px', index='pz', values='kernel_stat', aggfunc='mean')
    kde_df = kde_df.round(3)

    # NOTE(review): the image extent ([-2.25, 2.25] x [-0.5, 5]) is narrower
    # than the grid built above - confirm the scaling is deliberate.
    ax.imshow(kde_df.values, extent=[-2.25, 2.25, -0.5, 5], origin='lower', cmap=cmap, vmin=0.15, vmax=0.35,
              interpolation='bilinear')

    ax.axis('square')
    ax.set(xlabel=None, ylabel=None)
    ax.set_xlim(-2.75, 2.75)
    ax.set_ylim(-0.5, 5)
    ax.axis('off')
    draw_line(ax, alpha_spot=1, catcher_p=False)
def heat_map_plot_hex_damage(df:pl.DataFrame,
                             ax:plt.Axes,
                             cmap:matplotlib.colors.LinearSegmentedColormap,
                             hand:str):
    """Draw a kernel-smoothed ``woba_pred_contact`` heat map against one batter side.

    Rows of ``df`` with ``batter_hand == hand`` and ``launch_speed > 0`` are
    snapped to a 0.1-step location grid, joined onto a grid of cells,
    smoothed with a kernel regression, and shown with ``ax.imshow``. Grid
    cells without a matching row are filled with ``v_center`` (0.375) before
    smoothing.

    Parameters
    ----------
    df : pl.DataFrame
        Frame with ``batter_hand``, ``launch_speed``, ``woba_pred_contact``,
        ``px`` and ``pz`` columns.
    ax : plt.Axes
        Axis to draw on. Its axes are hidden and its limits fixed.
    cmap : matplotlib.colors.LinearSegmentedColormap
        Colormap for the image, applied over the range 0.25-0.5.
    hand : str
        Value of ``batter_hand`` to plot.
    """
    heatmap_df = df.filter((pl.col('batter_hand')==hand)&((pl.col('launch_speed')>0))).to_pandas()
    heatmap_df['woba_pred_contact'] = heatmap_df['woba_pred_contact'].fillna(0)

    # Cell size shrinks with sample size, bounded to [0.2, 0.3].
    bin_size = max(0.2, min(0.3, 1 / np.sqrt(len(heatmap_df))))

    # Build the full grid of cell coordinates in one go; appending row by row
    # with `.loc` re-allocated the frame for every cell.
    grid_points = [[round(x, 1), round(y, 1)]
                   for x in np.arange(-2.75, 2.95, bin_size)
                   for y in np.arange(-0.5, 5.7, bin_size)]
    zone_df = pd.DataFrame(grid_points, columns=['px', 'pz'])

    # Snap each location to one decimal (truncating toward zero) and clamp it
    # to the plotted window.
    # NOTE(review): locations are snapped to 0.1 steps while the grid steps by
    # bin_size (>= 0.2), so rows that do not land on a grid coordinate appear
    # to be left out by the left merge below - confirm this is intended.
    has_px = heatmap_df['px'].notna()
    heatmap_df.loc[has_px, 'kde_x'] = np.clip(
        heatmap_df.loc[has_px, 'px'].astype('float').mul(10).astype('int').div(10),
        -2.75,
        2.75)
    has_pz = heatmap_df['pz'].notna()
    heatmap_df.loc[has_pz, 'kde_z'] = np.clip(
        heatmap_df.loc[has_pz, 'pz'].astype('float').mul(10).astype('int').div(10),
        -0.5,
        5)

    # NOTE(review): the bandwidth is derived from len(df) (the unfiltered
    # frame), not from the filtered rows - confirm this is intended.
    bandwidth = np.clip(1 / np.sqrt(len(df)), 0.3, 0.5)

    # Value assigned to grid cells that have no observed row.
    v_center = 0.375
    kde_df = pd.merge(zone_df,
                      heatmap_df
                      .dropna(subset=['woba_pred_contact', 'px', 'pz'])
                      [['kde_x', 'kde_z', 'woba_pred_contact']],
                      how='left',
                      left_on=['px', 'pz'],
                      right_on=['kde_x', 'kde_z']).fillna({'woba_pred_contact': v_center})

    # Kernel regression of the metric on the two (continuous) coordinates.
    kernel_regression = KernelReg(endog=kde_df['woba_pred_contact'],
                                  exog=[kde_df['px'], kde_df['pz']],
                                  bw=[bandwidth, bandwidth],
                                  var_type='cc')
    kde_df['kernel_stat'] = kernel_regression.fit([kde_df['px'], kde_df['pz']])[0]

    # One row per pz, one column per px, ready for imshow.
    kde_df = kde_df.pivot_table(columns='px', index='pz', values='kernel_stat', aggfunc='mean')
    kde_df = kde_df.round(3)

    # NOTE(review): the image extent ([-2.25, 2.25] x [-0.5, 5]) is narrower
    # than the grid built above - confirm the scaling is deliberate.
    ax.imshow(kde_df.values, extent=[-2.25, 2.25, -0.5, 5], origin='lower', cmap=cmap, vmin=0.25, vmax=0.5,
              interpolation='bilinear')

    ax.axis('square')
    ax.set(xlabel=None, ylabel=None)
    ax.set_xlim(-2.75, 2.75)
    ax.set_ylim(-0.5, 5)
    ax.axis('off')
    draw_line(ax, alpha_spot=1, catcher_p=False)