Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import numpy as np | |
| import json | |
| from matplotlib.ticker import FuncFormatter | |
| from matplotlib.ticker import MaxNLocator | |
| import math | |
| from matplotlib.patches import Ellipse | |
| import matplotlib.transforms as transforms | |
| import matplotlib.colors | |
| import matplotlib.colors as mcolors | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
| import requests | |
| import polars as pl | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
| from matplotlib.offsetbox import OffsetImage, AnnotationBbox | |
| import matplotlib.pyplot as plt | |
| import matplotlib.gridspec as gridspec | |
| import PIL | |
| from matplotlib.transforms import Bbox | |
| import matplotlib.image as mpimg | |
| from scipy.stats import gaussian_kde | |
| from statsmodels.nonparametric.kernel_regression import KernelReg | |
# Format string per statistic key, applied by table_plot to the raw value:
# rates render as one-decimal percentages, counts as integers and
# xwOBA on contact with three decimals.
format_dict = dict(
    pitch_percent='{:.1%}',
    pitches='{:.0f}',
    heart_zone_percent='{:.1%}',
    shadow_zone_percent='{:.1%}',
    chase_zone_percent='{:.1%}',
    waste_zone_percent='{:.1%}',
    csw_percent='{:.1%}',
    whiff_rate='{:.1%}',
    zone_whiff_percent='{:.1%}',
    chase_percent='{:.1%}',
    bip='{:.0f}',
    xwoba_percent_contact='{:.3f}',
)
# Display label for each statistic key (the same keys as format_dict);
# table_plot uses these as the row labels of its stat table.
label_translation_dict = dict(
    pitch_percent='Pitch%',
    pitches='Pitches',
    heart_zone_percent='Heart%',
    shadow_zone_percent='Shadow%',
    chase_zone_percent='Chase%',
    waste_zone_percent='Waste%',
    csw_percent='CSW%',
    whiff_rate='Whiff%',
    zone_whiff_percent='Z-Whiff%',
    chase_percent='O-Swing%',
    bip='BBE',
    xwoba_percent_contact='xwOBACON',
)
def pitch_heat_map(pitch_input, df):
    """Add count-level usage columns to ``df`` and keep one pitch type.

    Over the full frame (all pitch types), two window counts of
    ``pitcher_id`` are added:

    - ``h_s_b``: per (batter_hand, strikes, balls).
    - ``h_s_b_pitch``: per (batter_hand, strikes, balls, pitch_type).

    Their ratio is stored as ``h_s_b_pitch_percent``, the share of pitches
    in that hand/count that were this pitch type. Only then are the rows
    filtered to ``pitch_type == pitch_input``.

    Parameters
    ----------
    pitch_input
        Pitch-type code to keep.
    df : pl.DataFrame
        Pitch-level data with ``pitcher_id``, ``batter_hand``, ``strikes``,
        ``balls`` and ``pitch_type`` columns.

    Returns
    -------
    pl.DataFrame
        The filtered rows with the three extra columns.
    """
    count_keys = ['batter_hand', 'strikes', 'balls']
    pitch_count_keys = count_keys + ['pitch_type']

    with_counts = df.with_columns([
        pl.col('pitcher_id').count().over(count_keys).alias('h_s_b'),
        pl.col('pitcher_id').count().over(pitch_count_keys).alias('h_s_b_pitch'),
    ])
    # The ratio needs both window columns to exist, hence a second pass.
    with_share = with_counts.with_columns([
        (pl.col('h_s_b_pitch') / pl.col('h_s_b')).alias('h_s_b_pitch_percent'),
    ])
    return with_share.filter(pl.col('pitch_type') == pitch_input)
def pitch_prop(df: pl.DataFrame, hand: str = 'R') -> pl.DataFrame:
    """Lay out pitch usage for one batter hand as a strikes x balls grid.

    The mean of ``h_s_b_pitch_percent`` (presumably the column added by
    ``pitch_heat_map``) is taken for each (strikes, balls) count among rows
    whose ``batter_hand`` equals ``hand``.

    Parameters
    ----------
    df : pl.DataFrame
        Pitch rows with ``batter_hand``, ``strikes``, ``balls`` and
        ``h_s_b_pitch_percent`` columns.
    hand : str
        Batter hand to keep (default ``'R'``).

    Returns
    -------
    pl.DataFrame
        A polars frame, not pandas: the pivot is built in pandas and
        converted back before returning. It has a ``strikes`` column
        (0-2) followed by one column per ball count (0-3). Counts absent
        from the data are 0.
    """
    df_pd = df.to_pandas()
    df_pd = df_pd[df_pd['batter_hand'].isin([hand])]

    pivot_table = (
        df_pd.groupby(['batter_hand', 'strikes', 'balls'])[['h_s_b_pitch_percent']]
        .mean()
        .reset_index()
        .pivot(index='strikes', columns='balls', values='h_s_b_pitch_percent')
    )

    # Force the full 3 (strikes) x 4 (balls) grid so every count gets a
    # cell, then treat never-seen counts as 0% usage.
    pivot_table = pivot_table.reindex(index=range(3), columns=range(4)).fillna(0)

    return pl.DataFrame(pivot_table.reset_index())
# DEFINE STRIKE ZONE
# Closed outline of the strike-zone rectangle: five (side, height) vertices,
# the last repeating the first so a line plot closes the box.
strike_zone = pd.DataFrame(
    [(-0.9, 1.5), (-0.9, 3.5), (0.9, 3.5), (0.9, 1.5), (-0.9, 1.5)],
    columns=['PlateLocSide', 'PlateLocHeight'],
)


### STRIKE ZONE ###
def draw_line(axis, alpha_spot=1, catcher_p=True):
    """Draw the strike-zone box and a home-plate outline on ``axis``.

    Parameters
    ----------
    axis
        Target axes (anything with a matplotlib-style ``plot`` method).
    alpha_spot : float
        Opacity applied to every line drawn.
    catcher_p : bool
        Truthy draws a small plate between y=0.15 and y=0.5. Falsy draws a
        wider plate between y=-0.35 and y=0.4. Both outlines use solid
        lines.
    """
    # Zone rectangle: thicker than the plate and layered above it.
    axis.plot(strike_zone['PlateLocSide'].to_numpy(),
              strike_zone['PlateLocHeight'].to_numpy(),
              color='black', linewidth=1.3, zorder=3, alpha=alpha_spot)

    # Home plate as (x-pair, y-pair) line segments.
    if catcher_p:
        plate_segments = (
            ([-0.708, 0.708], [0.15, 0.15]),
            ([-0.708, -0.708], [0.15, 0.3]),
            ([-0.708, 0], [0.3, 0.5]),
            ([0, 0.708], [0.5, 0.3]),
            ([0.708, 0.708], [0.3, 0.15]),
        )
    else:
        plate_segments = (
            ([-0.708, 0.708], [0.4, 0.4]),
            ([-0.708, -0.9], [0.4, -0.1]),
            ([-0.9, 0], [-0.1, -0.35]),
            ([0, 0.9], [-0.35, -0.1]),
            ([0.9, 0.708], [-0.1, 0.4]),
        )
    for seg_x, seg_y in plate_segments:
        axis.plot(seg_x, seg_y, color='black', linewidth=1,
                  alpha=alpha_spot, zorder=1)
def heat_map_plot(df: pl.DataFrame,
                  ax: plt.Axes,
                  cmap: matplotlib.colors.LinearSegmentedColormap,
                  hand: str,
                  scatter: bool):
    """Plot pitch locations against batters of one hand on ``ax``.

    - More than three matching rows: a filled 2-D KDE of (``px``, ``pz``)
      is drawn.
    - One to three rows: the pitches are scattered instead.
    - ``scatter`` true: the pitches are also scattered over whatever was
      drawn, exactly once.

    The zone/plate outline is then added, and the axes are made square,
    limited to x in [-2.75, 2.75] and y in [-0.5, 5], and hidden.

    Parameters
    ----------
    df : pl.DataFrame
        Pitch rows with at least ``batter_hand``, ``px`` and ``pz``.
    ax : plt.Axes
        Target axes.
    cmap : matplotlib colormap
        Colormap for the KDE fill.
    hand : str
        ``batter_hand`` value to keep.
    scatter : bool
        Overlay individual pitch locations.
    """
    # Filter once and reuse the result for every plotting call.
    df_hand = df.filter(pl.col('batter_hand') == hand)
    n_pitches = df_hand.shape[0]

    if n_pitches > 3:
        # ``fill`` is the seaborn >= 0.11 replacement for the deprecated
        # ``shade`` keyword.
        sns.kdeplot(data=df_hand,
                    x='px',
                    y='pz',
                    cmap=cmap,
                    fill=True,
                    ax=ax,
                    thresh=0.3,
                    bw_adjust=1)

    # One scatter call covers both the "too few points for a KDE" case and
    # the explicit overlay, so small samples are not drawn twice.
    if scatter or 0 < n_pitches <= 3:
        sns.scatterplot(data=df_hand,
                        x='px',
                        y='pz',
                        ec='black',
                        color="#FFB000",
                        ax=ax)

    draw_line(ax, alpha_spot=1, catcher_p=False)
    ax.axis('off')
    ax.axis('square')
    ax.set_xlim(-2.75, 2.75)
    ax.set_ylim(-0.5, 5)
def format_as_percentage(val):
    """Render a fraction as a whole-number percent string (0.256 -> '26%')."""
    scaled = val * 100
    return format(scaled, '.0f') + '%'
def table_plot(ax: plt.Axes,
               table: pd.DataFrame,
               hand='R'):
    """Draw a two-column (label, value) stat table beside the zone on ``ax``.

    Each row of ``table`` is indexed by a statistic key. Its first value is
    formatted with ``format_dict[key]``, except the placeholder '—', which
    is shown unchanged. The key is then replaced by
    ``label_translation_dict[key]``.

    Placement depends on ``hand``:

    - ``'R'``: table to the right of the zone, titled "Against RHH".
    - Anything else: table to the left, titled "Against LHH".

    Parameters
    ----------
    ax : plt.Axes
        Axes whose data coordinates position the table and title.
    table : pd.DataFrame
        Pandas frame (it is used through ``.apply(axis=1)`` and ``.index``).
    hand : str
        ``'R'`` for right-handed hitters, otherwise left-handed.

    Raises
    ------
    KeyError
        If an index key is missing from ``format_dict`` or
        ``label_translation_dict``.
    """
    # Data -> axes-fraction transform, because ``ax.table(bbox=...)`` takes
    # axes coordinates.
    trans = ax.transData + ax.transAxes.inverted()

    if hand == 'R':
        bbox_data = Bbox.from_bounds(1.7, -0.5, 2.5, 5)
        title, title_x = 'Against RHH', 2.95
    else:
        bbox_data = Bbox.from_bounds(-4.2, -0.5, 2.5, 5)
        title, title_x = 'Against LHH', -2.95
    bbox_axes = trans.transform_bbox(bbox_data)
    ax.text(s=title, x=title_x, y=4.65, fontsize=18, fontweight='bold', ha='center')

    def _format_row(row):
        # .iloc: positional access; ``row[0]`` on a string-indexed Series
        # relies on a deprecated positional fallback in pandas.
        value = row.iloc[0]
        if value != '—':
            return format_dict[row.name].format(value)
        return '—'

    formatted = table.apply(_format_row, axis=1)
    formatted.index = [label_translation_dict[key] for key in formatted.index]

    stat_table = ax.table(cellText=formatted.reset_index().values,
                          loc='right',
                          cellLoc='center',
                          colWidths=[0.52, 0.3],
                          bbox=bbox_axes.bounds,
                          zorder=100)
    # Fixed font size; matplotlib's auto-sizing is switched off.
    stat_table.auto_set_font_size(False)
    stat_table.set_fontsize(14)
def table_plot_pivot(ax: plt.Axes,
                     pivot_table: pl.DataFrame,
                     df_colour: pd.DataFrame):
    """Draw the strikes x balls usage grid above the zone on ``ax``.

    Table contents:

    - Cells: the last four columns of ``pivot_table`` (presumably balls
      0-3, as built by ``pitch_prop``), shown as whole-number percentages.
    - Header row: those four column names.
    - Row labels: strikes 0-2.
    - Cell fill: the matching last four columns of ``df_colour``.

    "Strikes" and "Balls" captions are added next to the grid.
    """
    ball_columns = pivot_table.columns[-4:]
    cell_text = [
        [format_as_percentage(cell) for cell in row]
        for row in pivot_table.select(ball_columns).to_numpy()
    ]
    cell_colours = df_colour[df_colour.columns[-4:]].values

    # Position the grid in data coordinates, converted to axes fractions.
    data_to_axes = ax.transData + ax.transAxes.inverted()
    grid_bbox = data_to_axes.transform_bbox(Bbox.from_bounds(-0.75, 5, 2.5, 1))

    count_table = ax.table(
        cellText=cell_text,
        colLabels=ball_columns,
        rowLabels=[' 0 ', ' 1 ', ' 2 '],
        loc='center',
        cellLoc='center',
        colWidths=[0.3, 0.3, 0.30, 0.3],
        bbox=grid_bbox.bounds,
        zorder=100,
        cellColours=cell_colours,
    )
    count_table.auto_set_font_size(False)
    count_table.set_fontsize(11)

    ax.text(x=-2.0, y=5.08, s='Strikes', rotation=90, fontweight='bold')
    ax.text(x=0, y=6.05, s='Balls', fontweight='bold', ha='center')
def _plot_player_headshot(ax: plt.Axes, pitcher_id, is_mlb: bool, timeout: float) -> None:
    """Fetch the pitcher's headshot and draw it at x in [-11.5, -9.5] on ``ax`` (best-effort)."""
    if is_mlb:
        url = f'https://img.mlbstatic.com/mlb-photos/image/upload/d_people:generic:headshot:67:current.png/w_640,q_auto:best/v1/people/{pitcher_id}/headshot/silo/current.png'
        extent = [-11.5, -9.5, 0, 2]
    else:
        url = f'https://img.mlbstatic.com/mlb-photos/image/upload/c_fill,g_auto/w_640/v1/people/{pitcher_id}/headshot/milb/current.png'
        # The MiLB image is inset by 1/3 on each side horizontally.
        extent = [-11.5 + 2 / 6, -9.5 - 2 / 6, 0, 2]
    try:
        response = requests.get(url, timeout=timeout)
        img = Image.open(BytesIO(response.content))
        ax.imshow(img, extent=extent, origin='upper')
    except (requests.RequestException, PIL.UnidentifiedImageError):
        # Decorative image: a failed download or undecodable payload only
        # hides the axis instead of aborting the whole figure.
        ax.axis('off')


def _plot_team_logo(ax: plt.Axes, pitcher_id, df_team: pl.DataFrame,
                    df_players: pl.DataFrame, timeout: float) -> None:
    """Resolve the pitcher's parent organisation and draw its logo at x in [9.5, 11.5] (best-effort)."""
    # ESPN logo slug per team abbreviation. ATH reuses the OAK logo.
    logo_slugs = {
        'AZ': 'ari', 'ATL': 'atl', 'BAL': 'bal', 'BOS': 'bos', 'CHC': 'chc',
        'CWS': 'chw', 'CIN': 'cin', 'CLE': 'cle', 'COL': 'col', 'DET': 'det',
        'HOU': 'hou', 'KC': 'kc', 'LAA': 'laa', 'LAD': 'lad', 'MIA': 'mia',
        'MIL': 'mil', 'MIN': 'min', 'NYM': 'nym', 'NYY': 'nyy', 'OAK': 'oak',
        'PHI': 'phi', 'PIT': 'pit', 'SD': 'sd', 'SF': 'sf', 'SEA': 'sea',
        'STL': 'stl', 'TB': 'tb', 'TEX': 'tex', 'TOR': 'tor', 'WSH': 'wsh',
        'ATH': 'oak',
    }
    try:
        team_id = df_players.filter(pl.col('player_id') == pitcher_id)['team'][0]
        data_team = requests.get(f'https://statsapi.mlb.com/api/v1/teams/{team_id}',
                                 timeout=timeout).json()
        team = data_team['teams'][0]
        # A team that is itself a parent org is matched on team_id. Any
        # other team is mapped through its parentOrgId.
        if team['id'] in df_team['parent_org_id']:
            team_abb = df_team.filter(pl.col('team_id') == team['id'])['parent_org_abbreviation'][0]
        else:
            team_abb = df_team.filter(pl.col('parent_org_id') == team['parentOrgId'])['parent_org_abbreviation'][0]
        logo_url = ('https://a.espncdn.com/combiner/i?img=/i/teamlogos/mlb/500/scoreboard/'
                    f'{logo_slugs[team_abb]}.png&h=500&w=500')
        img = Image.open(BytesIO(requests.get(logo_url, timeout=timeout).content))
        ax.imshow(img, extent=[9.5, 11.5, 0, 2], origin='upper')
    except (KeyError, IndexError, ValueError,
            requests.RequestException, PIL.UnidentifiedImageError):
        # ValueError covers a non-JSON team response on older requests
        # versions. Lookup misses and network errors only hide the axis.
        ax.axis('off')


def plot_header(pitcher_id: str, ax: plt.Axes, df_team: pl.DataFrame, df_players: pl.DataFrame, sport_id: int,
                timeout: float = 10.0):
    """Display the pitcher's headshot and his organisation's logo on ``ax``.

    - Headshot: drawn at x in [-11.5, -9.5].
    - Team logo: drawn at x in [9.5, 11.5].
    - Both span y in [0, 2].

    Each image is best-effort. If it cannot be downloaded, decoded or looked
    up, the axis is turned off and the other image is still attempted.

    Parameters
    ----------
    pitcher_id : str
        The ID of the pitcher, used in the headshot URL and to find his
        team in ``df_players``.
    ax : plt.Axes
        The axis to draw on.
    df_team : pl.DataFrame
        Team data with ``team_id``, ``parent_org_id`` and
        ``parent_org_abbreviation`` columns.
    df_players : pl.DataFrame
        Player data with ``player_id`` and ``team`` columns.
    sport_id : int
        ``1`` selects the MLB headshot URL and placement; anything else
        selects the MiLB ones. The value is passed through ``int()``, so
        ``'1'`` behaves like ``1`` for both URL and placement.
    timeout : float
        Seconds to wait for each HTTP request before giving up.

    Raises
    ------
    ValueError, TypeError
        If ``sport_id`` cannot be converted with ``int()``.
    """
    # Decide once so the headshot URL and its placement always agree.
    is_mlb = int(sport_id) == 1
    _plot_player_headshot(ax, pitcher_id, is_mlb, timeout)
    _plot_team_logo(ax, pitcher_id, df_team, df_players, timeout)
# DEFINE STRIKE ZONE
# NOTE(review): this strike_zone/draw_line pair duplicates the definitions
# earlier in this file; consider keeping a single copy.
# Closed strike-zone outline (last vertex repeats the first).
strike_zone = pd.DataFrame(dict(
    PlateLocSide=[-0.9, -0.9, 0.9, 0.9, -0.9],
    PlateLocHeight=[1.5, 3.5, 3.5, 1.5, 1.5],
))


### STRIKE ZONE ###
def draw_line(axis, alpha_spot=1, catcher_p=True):
    """Outline the strike zone and home plate on ``axis``.

    ``alpha_spot`` sets the opacity of every line. A truthy ``catcher_p``
    draws the small plate between y=0.15 and y=0.5; otherwise a wider
    plate is drawn between y=-0.35 and y=0.4. All lines are solid.
    """
    zone_x = strike_zone['PlateLocSide'].to_numpy()
    zone_y = strike_zone['PlateLocHeight'].to_numpy()
    axis.plot(zone_x, zone_y, color='black', linewidth=1.3, zorder=3, alpha=alpha_spot)

    if catcher_p:
        xs_list = [[-0.708, 0.708], [-0.708, -0.708], [-0.708, 0], [0, 0.708], [0.708, 0.708]]
        ys_list = [[0.15, 0.15], [0.15, 0.3], [0.3, 0.5], [0.5, 0.3], [0.3, 0.15]]
    else:
        xs_list = [[-0.708, 0.708], [-0.708, -0.9], [-0.9, 0], [0, 0.9], [0.9, 0.708]]
        ys_list = [[0.4, 0.4], [0.4, -0.1], [-0.1, -0.35], [-0.35, -0.1], [-0.1, 0.4]]

    plate_style = dict(color='black', linewidth=1, alpha=alpha_spot, zorder=1)
    for xs, ys in zip(xs_list, ys_list):
        axis.plot(xs, ys, **plate_style)
| from matplotlib.patches import Rectangle | |
| # Function to draw the strike zone and home plate | |
| # Import necessary libraries | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from matplotlib.patches import Rectangle | |
| from matplotlib import gridspec | |
| import numpy as np | |
| import pandas as pd | |
| from statsmodels.nonparametric.kernel_regression import KernelReg | |
# DEFINE STRIKE ZONE
# Closed strike-zone outline as (side, height) records; the last vertex
# repeats the first. NOTE(review): re-creates the same frame already
# assigned to strike_zone earlier in this file.
strike_zone = pd.DataFrame.from_records(
    [(-0.9, 1.5), (-0.9, 3.5), (0.9, 3.5), (0.9, 1.5), (-0.9, 1.5)],
    columns=['PlateLocSide', 'PlateLocHeight'],
)
def heat_map_plot_hex_whiff(df: pl.DataFrame,
                            ax: plt.Axes,
                            cmap: matplotlib.colors.LinearSegmentedColormap,
                            hand: str,
                            scatter: bool):
    """Draw a kernel-smoothed whiff map for swings against one batter hand.

    Steps:

    1. Select swings (``is_swing`` true) against ``hand``; a missing
       ``is_whiff`` counts as 0.
    2. Optionally scatter the swings: whiffs as orange X, other swings as
       blue circles.
    3. Truncate each swing's (px, pz) to one decimal and join it onto a
       0.1-spaced grid. Grid nodes with no swing get ``v_center`` (0.25).
    4. Smooth the grid values with statsmodels ``KernelReg``.
    5. Draw the surface with ``imshow`` on a 0.15-0.35 colour scale, then
       the zone/plate outline. Axes are square, limited to x [-2.75, 2.75]
       and y [-0.5, 5], and hidden.

    Parameters
    ----------
    df : pl.DataFrame
        Pitch rows with ``batter_hand``, ``is_swing``, ``is_whiff``,
        ``px`` and ``pz``.
    ax : plt.Axes
        Target axes.
    cmap : matplotlib colormap
        Colormap for the smoothed surface.
    hand : str
        ``batter_hand`` value to keep.
    scatter : bool
        Overlay individual swings.
    """
    swings = df.filter((pl.col('batter_hand') == hand) & pl.col('is_swing')).to_pandas()
    swings['is_whiff'] = swings['is_whiff'].fillna(0)

    if scatter:
        sns.scatterplot(
            data=swings,
            x='px',
            y='pz',
            ec='black',
            hue='is_whiff',
            palette={True: '#FFB000', False: '#648FFF'},
            style='is_whiff',
            markers={True: 'X', False: 'o'},
            legend=False,
            ax=ax,
        )

    # Fixed 0.1 spacing, matching the one-decimal truncation of the swing
    # locations below so every truncated location can hit a grid node.
    bin_size = 0.1
    grid_x = np.arange(-2.75, 2.85, bin_size)
    grid_z = np.arange(-0.5, 5.6, bin_size)
    # Build the grid in one pass; growing a DataFrame with .loc per row is
    # quadratic. round() is applied to the same values as before.
    zone_df = pd.DataFrame(
        [(round(x, 1), round(z, 1)) for x in grid_x for z in grid_z],
        columns=['px', 'pz'],
    )

    # Truncate locations toward zero to one decimal, clipped to the plot
    # window, so they can be joined onto grid nodes.
    has_x = swings['px'].notna()
    swings.loc[has_x, 'kde_x'] = np.clip(
        swings.loc[has_x, 'px'].astype('float').mul(10).astype('int').div(10), -2.75, 2.75)
    has_z = swings['pz'].notna()
    swings.loc[has_z, 'kde_z'] = np.clip(
        swings.loc[has_z, 'pz'].astype('float').mul(10).astype('int').div(10), -0.5, 5)

    # NOTE(review): bandwidth uses len(df) (every row, both hands), not the
    # filtered swings, so with 12+ rows it is always the 0.3 floor.
    bandwidth = np.clip(1 / np.sqrt(len(df)), 0.3, 0.5)

    # Grid nodes with no matching swing are filled with v_center before
    # smoothing.
    v_center = 0.25
    kde_df = pd.merge(
        zone_df,
        swings.dropna(subset=['is_whiff', 'px', 'pz'])[['kde_x', 'kde_z', 'is_whiff']],
        how='left',
        left_on=['px', 'pz'],
        right_on=['kde_x', 'kde_z'],
    ).fillna({'is_whiff': v_center})

    kernel_regression = KernelReg(endog=kde_df['is_whiff'],
                                  exog=[kde_df['px'], kde_df['pz']],
                                  bw=[bandwidth, bandwidth],
                                  var_type='cc')
    kde_df['kernel_stat'] = kernel_regression.fit([kde_df['px'], kde_df['pz']])[0]
    kde_df = kde_df.pivot_table(columns='px', index='pz', values='kernel_stat', aggfunc='mean')
    kde_df = kde_df.round(3)

    # NOTE(review): the grid covers roughly px -2.8..2.8 and pz -0.5..5.5,
    # but the image is stretched onto x -2.25..2.25 and y -0.5..5, so the
    # surface may not line up with the zone outline and scatter - confirm.
    ax.imshow(kde_df.values, extent=[-2.25, 2.25, -0.5, 5], origin='lower', cmap=cmap,
              vmin=0.15, vmax=0.35, interpolation='bilinear')

    ax.axis('square')
    ax.set(xlabel=None, ylabel=None)
    ax.set_xlim(-2.75, 2.75)
    ax.set_ylim(-0.5, 5)
    ax.axis('off')
    draw_line(ax, alpha_spot=1, catcher_p=False)
def heat_map_plot_hex_damage(df: pl.DataFrame,
                             ax: plt.Axes,
                             cmap: matplotlib.colors.LinearSegmentedColormap,
                             hand: str,
                             scatter: bool):
    """Draw a kernel-smoothed map of ``woba_pred_contact`` on contact vs one hand.

    Steps:

    1. Select batted balls (``launch_speed`` > 0) against ``hand``; a
       missing ``woba_pred_contact`` counts as 0.25 for the smoothing.
    2. Optionally scatter the batted balls, coloured (orange -> white ->
       blue over 0-0.5) and sized by ``woba_pred_contact``.
    3. Truncate each location to one decimal and join it onto a grid whose
       spacing is 1/sqrt(n) clamped to [0.2, 0.3]. Grid nodes with no
       match get ``v_center`` (0.375).
    4. Smooth the grid values with statsmodels ``KernelReg``.
    5. Draw the surface with ``imshow`` on a 0.25-0.5 colour scale, then
       the zone/plate outline. Axes are square, limited to x [-2.75, 2.75]
       and y [-0.5, 5], and hidden.

    Parameters
    ----------
    df : pl.DataFrame
        Pitch rows with ``batter_hand``, ``launch_speed``,
        ``woba_pred_contact``, ``px`` and ``pz``.
    ax : plt.Axes
        Target axes.
    cmap : matplotlib colormap
        Colormap for the smoothed surface.
    hand : str
        ``batter_hand`` value to keep.
    scatter : bool
        Overlay individual batted balls.
    """
    batted_balls = df.filter((pl.col('batter_hand') == hand) & (pl.col('launch_speed') > 0))
    heatmap_df = batted_balls.to_pandas()
    heatmap_df['woba_pred_contact'] = heatmap_df['woba_pred_contact'].fillna(0.25)

    # Empty input maps to 0.3 directly instead of dividing by zero
    # (same value, no RuntimeWarning).
    n_bbe = len(heatmap_df)
    bin_size = max(0.2, min(0.3, 1 / np.sqrt(n_bbe) if n_bbe else np.inf))

    if scatter:
        cmap_sum_r = matplotlib.colors.LinearSegmentedColormap.from_list("", ['#FFB000', '#FFFFFF', '#648FFF'])
        # Uses the un-filled values so missing predictions are not
        # recoloured as 0.25.
        sns.scatterplot(data=batted_balls,
                        x='px',
                        y='pz',
                        ec='black',
                        palette=cmap_sum_r,
                        hue='woba_pred_contact',
                        hue_norm=mcolors.Normalize(vmin=0, vmax=0.5),
                        size='woba_pred_contact',
                        legend=False,
                        ax=ax)

    # Build the evaluation grid in one pass; growing a DataFrame with .loc
    # per row is quadratic.
    grid_x = np.arange(-2.75, 2.95, bin_size)
    grid_z = np.arange(-0.5, 5.7, bin_size)
    zone_df = pd.DataFrame(
        [(round(x, 1), round(z, 1)) for x in grid_x for z in grid_z],
        columns=['px', 'pz'],
    )

    # NOTE(review): locations are truncated to 0.1, but the grid step is
    # 0.2-0.3. A batted ball whose truncated cell is not a grid node finds
    # no match in the left join below and is silently left out of the
    # regression. Confirm whether locations should instead snap to the
    # nearest grid node.
    has_x = heatmap_df['px'].notna()
    heatmap_df.loc[has_x, 'kde_x'] = np.clip(
        heatmap_df.loc[has_x, 'px'].astype('float').mul(10).astype('int').div(10), -2.75, 2.75)
    has_z = heatmap_df['pz'].notna()
    heatmap_df.loc[has_z, 'kde_z'] = np.clip(
        heatmap_df.loc[has_z, 'pz'].astype('float').mul(10).astype('int').div(10), -0.5, 5)

    # NOTE(review): based on len(df) (all rows), not the filtered batted
    # balls; with 12+ rows this is always the 0.3 floor.
    bandwidth = np.clip(1 / np.sqrt(len(df)), 0.3, 0.5)

    # Grid nodes with no matching batted ball are filled with v_center
    # before smoothing.
    v_center = 0.375
    kde_df = pd.merge(
        zone_df,
        heatmap_df.dropna(subset=['woba_pred_contact', 'px', 'pz'])[['kde_x', 'kde_z', 'woba_pred_contact']],
        how='left',
        left_on=['px', 'pz'],
        right_on=['kde_x', 'kde_z'],
    ).fillna({'woba_pred_contact': v_center})

    kernel_regression = KernelReg(endog=kde_df['woba_pred_contact'],
                                  exog=[kde_df['px'], kde_df['pz']],
                                  bw=[bandwidth, bandwidth],
                                  var_type='cc')
    kde_df['kernel_stat'] = kernel_regression.fit([kde_df['px'], kde_df['pz']])[0]
    kde_df = kde_df.pivot_table(columns='px', index='pz', values='kernel_stat', aggfunc='mean')
    kde_df = kde_df.round(3)

    # NOTE(review): the grid's px/pz range is wider than the extent below
    # (x -2.25..2.25, y -0.5..5), so the surface is compressed relative to
    # the zone outline and scatter - confirm the intended registration.
    ax.imshow(kde_df.values, extent=[-2.25, 2.25, -0.5, 5], origin='lower', cmap=cmap,
              vmin=0.25, vmax=0.5, interpolation='bilinear')

    ax.axis('square')
    ax.set(xlabel=None, ylabel=None)
    ax.set_xlim(-2.75, 2.75)
    ax.set_ylim(-0.5, 5)
    ax.axis('off')
    draw_line(ax, alpha_spot=1, catcher_p=False)