Spaces:
Running
Running
| import os | |
| from tempfile import TemporaryDirectory | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import spaces | |
| import torch | |
| from huggingface_hub import Repository | |
| from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, \ | |
| SB_GAME_TIMER_SECONDS, SB_BLUE_SCORE, SB_ORANGE_SCORE | |
| from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym | |
| from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay | |
| from tqdm import trange, tqdm | |
# os.chmod("/usr/local/lib/python3.10/site-packages/rlgym_tools/rocket_league/replays/carball", 0o755)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Clone (or reuse) the model repository into ./vortex-ngp and bring it up to
# date. The access token is read from the HF_TOKEN environment variable.
repo = Repository(local_dir="vortex-ngp", clone_from="Rolv-Arild/vortex-ngp", token=os.getenv("HF_TOKEN"))
repo.git_pull()

# Load the TorchScript checkpoint directly onto DEVICE and switch it to
# evaluation mode; it is only ever used for inference in this file.
MODEL = torch.jit.load("vortex-ngp/vortex-ngp-glowing-dragon.pt", map_location=DEVICE)
MODEL.eval()
def infer(model, replay_file,
          nullify_goal_difference=False,
          ignore_ties=False):
    """Run the next-goal model over every frame of a replay.

    Args:
        model: Callable taking ``(serialized_states, serialized_scoreboards)``
            batches on ``DEVICE`` and returning logits with 123 classes per
            frame (61 time bins for Blue, 61 for Orange, then "Tie").
        replay_file: Path of the replay, passed to ``ParsedReplay.load``.
        nullify_goal_difference: If true, both scores fed to the model are
            zeroed so it cannot see the goal difference.
        ignore_ties: If true, scores are zeroed, the timer fed to the model is
            set to infinity, tie probability is redistributed over the team
            classes for every frame, and the "Tie" column is dropped.

    Returns:
        A ``pandas.DataFrame`` indexed by "Frame" with the columns "Timer",
        "Blue", "Orange", "Tie" (absent when ``ignore_ties``), "Goal",
        "Touch", followed by one column per team/time-bin class.
        "Timer" holds the scoreboard timer, with overtime frames replaced by
        the negated elapsed overtime. "Goal" is the frame-to-frame change in
        (blue score - orange score). "Touch" is a ``|``-separated list of
        alternating team and player name for cars with ``ball_touches > 0``.

    Raises:
        ValueError: If the replay yields no frames.
    """
    num_outputs = 123  # 61 bins per team (0..60 s) + 1 tie class

    replay = ParsedReplay.load(replay_file)

    replay_frames = []
    serialized_states = []
    serialized_scoreboards = []
    seconds_remaining = []
    # The context manager closes the progress bar even if conversion fails.
    with tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df)) as it:
        for replay_frame in it:
            replay_frames.append(replay_frame)
            serialized_states.append(serialize_game_state(replay_frame.state))
            serialized_scoreboards.append(serialize_scoreboard(replay_frame.scoreboard))
            seconds_remaining.append(replay_frame.episode_seconds_remaining)

    if not replay_frames:
        # np.stack would fail below with a much less helpful message.
        raise ValueError("Replay contains no frames")

    serialized_states = torch.from_numpy(np.stack(serialized_states))
    serialized_scoreboards = torch.from_numpy(np.stack(serialized_scoreboards))
    seconds_remaining = torch.tensor(seconds_remaining)

    timer = serialized_scoreboards[:, SB_GAME_TIMER_SECONDS].clone()
    # Frames whose timer exceeds 450 are treated as overtime
    # (NOTE(review): presumably the scoreboard reports inf there - confirm).
    is_ot = timer > 450
    ot_time_remaining = seconds_remaining[is_ot]
    if len(ot_time_remaining) > 0:
        ot_timer = ot_time_remaining[0] - ot_time_remaining
        # Negate to indicate overtime; cast so the masked assignment never
        # trips over a dtype mismatch between the two tensors.
        timer[is_ot] = -ot_timer.to(timer.dtype)

    goal_diff = serialized_scoreboards[:, SB_BLUE_SCORE] - serialized_scoreboards[:, SB_ORANGE_SCORE]
    # Non-zero exactly on frames where the goal difference changed.
    goal_diff_diff = goal_diff.diff(prepend=goal_diff.new_zeros(1))

    bs = 900
    predictions = []
    # Inference only: skip autograd bookkeeping for the whole loop.
    with torch.no_grad(), trange(len(serialized_states), desc="Running model") as progress:
        for i in range(0, len(serialized_states), bs):
            # Clone before editing so the original scoreboards stay intact.
            states = serialized_states[i:i + bs].clone().to(DEVICE)
            scoreboards = serialized_scoreboards[i:i + bs].clone().to(DEVICE)
            if nullify_goal_difference or ignore_ties:
                scoreboards[:, SB_BLUE_SCORE] = 0
                scoreboards[:, SB_ORANGE_SCORE] = 0
            if ignore_ties:
                scoreboards[:, SB_GAME_TIMER_SECONDS] = float("inf")
            predictions.append(model(states, scoreboards))
            progress.update(len(states))
    predictions = torch.cat(predictions, dim=0)
    probs = predictions.softmax(dim=-1)

    bin_seconds = torch.linspace(0, 60, num_outputs // 2)
    class_names = [
        f"{t}: {s:g}s" for t in ["Blue", "Orange"]
        for s in bin_seconds.tolist()
    ]
    class_names.append("Tie")

    preds = pd.DataFrame(data=probs.cpu().numpy(), columns=class_names)
    # Team totals; computed before any other "Blue"/"Orange" column exists.
    preds["Blue"] = preds[[c for c in preds.columns if c.startswith("Blue")]].sum(axis=1)
    preds["Orange"] = preds[[c for c in preds.columns if c.startswith("Orange")]].sum(axis=1)
    # Convert explicitly so the columns are plain numeric regardless of how
    # the installed pandas version treats torch tensors.
    preds["Timer"] = timer.numpy()
    preds["Goal"] = goal_diff_diff.numpy()

    pid_to_name = {int(p["unique_id"]): p["name"]
                   for p in replay.metadata["players"]
                   if p["unique_id"] in replay.player_dfs}
    touches = []
    for replay_frame in replay_frames:
        parts = []
        for aid, car in replay_frame.state.cars.items():
            if car.ball_touches > 0:
                team = "Blue" if car.is_blue else "Orange"
                # Fall back to the raw id if the car has no metadata entry.
                name = pid_to_name.get(aid, str(aid))
                # Replace pipe with space to not conflict with sep
                name = name.replace("|", " ")
                parts.append(f"{team}|{name}")
        touches.append("|".join(parts))
    preds["Touch"] = touches

    # Sort columns
    main_cols = ["Timer", "Blue", "Orange", "Tie", "Goal", "Touch"]
    preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]]

    # Set index name
    preds.index.name = "Frame"

    # Ties cannot happen in overtime (or anywhere, when ignoring ties), so
    # renormalise the team probabilities over the non-tie mass on those rows.
    remove_ties_mask = is_ot if not ignore_ties else torch.ones(len(preds), dtype=torch.bool)
    remove_ties_mask = remove_ties_mask.numpy()
    if remove_ties_mask.any():
        q = 1 - preds.loc[remove_ties_mask, "Tie"]
        team_cols = [c for c in preds.columns if c.startswith(("Blue", "Orange"))]
        preds.loc[remove_ties_mask, team_cols] = preds.loc[remove_ties_mask, team_cols].div(q, axis=0)

    if ignore_ties:
        preds = preds.drop("Tie", axis=1)
    else:
        preds.loc[remove_ties_mask, "Tie"] = 0.0

    return preds
def plot_plotly(preds: pd.DataFrame):
    """Build an interactive Plotly figure from a predictions frame.

    Args:
        preds: Frame with "Timer", "Goal" and "Touch" columns plus the
            probability columns "Blue", "Orange" and optionally "Tie". The
            index is used as the x-axis (frame number).

    Returns:
        A ``plotly.graph_objects.Figure`` with one line per outcome (in
        percent), a dashed 50% line, a dashed vertical line with the running
        score at every frame where "Goal" is non-zero, and open markers on
        the team's curve for every recorded touch.
    """
    import math

    import plotly.graph_objects as go

    preds_df = preds.drop(["Touch", "Timer", "Goal"], axis=1) * 100
    timer = preds["Timer"]

    fig = go.Figure()

    def format_timer(t):
        """Format seconds as M:SS; negative values (overtime) get a '+'."""
        t = float(t)
        if not math.isfinite(t):
            return "--:--"
        sign = '+' if t < 0 else ''
        # Truncate before splitting so 59.9 s can never render as ":60".
        minutes, seconds = divmod(int(abs(t)), 60)
        return f"{sign}{minutes}:{seconds:02}"

    timer_text = [format_timer(t) for t in timer]
    # Look timers up by index label so this works for any index, not only
    # one whose labels happen to equal positions.
    timer_by_frame = dict(zip(preds.index, timer_text))
    hovertemplate = '<b>Frame %{x}</b><br>Prob: %{y:.3g}%<br>Timer: %{customdata}<extra></extra>'

    # Add traces for Blue, Orange, and Tie probabilities from the DataFrame
    outcomes = [("Blue", "blue"), ("Orange", "orange")]
    if "Tie" in preds.columns:
        outcomes.append(("Tie", "gray"))
    for outcome, color in outcomes:
        fig.add_trace(
            go.Scatter(x=preds_df.index, y=preds_df[outcome],
                       mode='lines', name=outcome, line=dict(color=color),
                       customdata=timer_text, hovertemplate=hovertemplate))

    # Add the horizontal line at y=50%
    fig.add_hline(y=50, line_dash="dash", line_color="black", name="50% Probability")

    # Add goal indicators, annotated with the running blue-orange score
    b = o = 0
    goals = preds["Goal"]
    for goal_frame, delta in goals[goals != 0].items():
        if delta > 0:
            b += 1
        elif delta < 0:
            o += 1
        fig.add_vline(x=goal_frame, line_dash="dash", line_color="red",
                      annotation_text=f"{b}-{o}", annotation_position="top right")

    # Add touch indicators as points
    touches = {}
    for touch_frame in preds.index[preds["Touch"] != ""]:
        # "Touch" alternates team and player name, separated by '|'.
        teams_players = preds.at[touch_frame, "Touch"].split('|')
        for team, player in zip(teams_players[::2], teams_players[1::2]):
            touches.setdefault(team.strip(), []).append((touch_frame, player.strip()))

    for team in "Blue", "Orange":
        team_touches = touches.get(team, [])
        if not team_touches:
            continue
        x = [frame for frame, _ in team_touches]
        y = [preds_df.at[frame, team] for frame in x]
        custom_data = [f"{timer_by_frame[frame]}<br>Touch by {player}"
                       for frame, player in team_touches]
        fig.add_trace(
            go.Scatter(x=x, y=y,
                       mode='markers',
                       name=f'{team} touches',
                       marker=dict(size=5, color=team.lower(), symbol='circle-open-dot'),
                       customdata=custom_data,
                       hovertemplate=hovertemplate
                       ))

    # Generate positions and labels for the secondary axis ticks:
    # 10 evenly spaced frames, labelled with the same formatter as the hover
    # text so both always agree on the timer's unit.
    if len(timer) > 0:
        tick_positions = np.linspace(0, len(preds_df) - 1, 10)
        tick_labels = [format_timer(timer.iloc[int(pos)]) for pos in tick_positions]
    else:
        tick_positions, tick_labels = [], []

    # Configure the figure's layout, titles, and both x-axes
    fig.update_layout(
        title="Interactive Probability Plot",
        xaxis=dict(
            title="Frame",
            gridcolor='#e5e7eb'  # A light gray grid for a modern look
        ),
        yaxis=dict(
            title="Probability",
            gridcolor='#e5e7eb'
        ),
        # --- Secondary X-Axis Configuration ---
        # NOTE(review): no trace is bound to x2; confirm Plotly actually draws
        # this axis in the deployed version.
        xaxis2=dict(
            title="Timer",
            overlaying='x',  # This makes it a secondary axis
            side='top',  # Position it at the top
            tickmode='array',
            tickvals=tick_positions,
            ticktext=tick_labels
        ),
        legend=dict(x=0.01, y=0.99, yanchor="top", xanchor="left"),  # Position legend inside plot
        plot_bgcolor='white'  # A clean white background
    )

    return fig
# Markdown shown at the top of the page.
DESCRIPTION = """
# Next Goal Predictor
Upload a replay file to get a plot of the next goal prediction.
The model is trained on about 9000 hours of SSL and RLCS replays in 1v1, 2v2, and 3v3 using [this dataset](https://www.kaggle.com/datasets/rolvarild/high-level-rocket-league-replay-dataset).<br>
It predicts the probability that each team will score at 1 second intervals up to 60+ seconds.
It also predicts ties (ball hitting the ground at 0s).<br>
The plot only shows the totals for each team, but you can download the full predictions if you want.<br>
<sub>(Sorry if it's a bit slow, it's running on CPU.)</sub>
""".strip()

# Choices of the options radio; the selected *index* is what the click
# handler receives, so the order here is significant.
RADIO_OPTIONS = ["Default", "Nullify goal difference", "Ignore ties"]

# Help text rendered under the radio, one bullet per entry in RADIO_OPTIONS.
RADIO_INFO = """
- **Default**: Uses the model as it is trained, with no modifications.
- **Nullify goal difference**: Makes the model think the goal difference is always 0, so it doesn't have a bias towards one team.
- **Ignore ties**: Makes the model pretend every situation is an overtime (e.g. ties are impossible).
""".strip()
# The temporary directory caches prediction CSVs for the lifetime of the app;
# launch() is called inside the `with` so the directory outlives every request.
with TemporaryDirectory() as temp_dir:
    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION)

        # Use gr.Column to stack components vertically
        with gr.Column():
            file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
            checkboxes = gr.Radio(label="Options", choices=RADIO_OPTIONS, type="index", value=RADIO_OPTIONS[0],
                                  info=RADIO_INFO)
            submit_button = gr.Button("Generate Predictions")
            plot_output = gr.Plot(label="Predictions")
            download_button = gr.DownloadButton("Download Predictions", visible=False)


        def make_plot(replay_file, radio_option, progress=gr.Progress(track_tqdm=True)):
            """Click handler: produce the plot and the downloadable CSV.

            Args:
                replay_file: Path of the uploaded replay, or None when nothing
                    has been uploaded.
                radio_option: Index into RADIO_OPTIONS of the selected mode.
                progress: Gradio progress tracker; mirrors tqdm bars.

            Returns:
                The Plotly figure and an updated, visible download button
                pointing at the predictions CSV.
            """
            if not replay_file:
                raise gr.Error("Please upload a replay file first.")

            nullify_goal_difference = radio_option == 1
            ignore_ties = radio_option == 2
            print(f"Processing file: {replay_file}")

            replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
            postfix = ""
            if nullify_goal_difference:
                postfix += "_nullify_goal_difference"
            elif ignore_ties:
                postfix += "_ignore_ties"
            # NOTE(review): the cache key is the file name only, so two
            # different replays sharing a name would share predictions.
            preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}{postfix}.csv")

            if os.path.exists(preds_file):
                print(f"Predictions file already exists: {preds_file}")
                # Restore "Frame" as the index so the frame matches what
                # infer() returns instead of gaining a stray data column.
                preds = pd.read_csv(preds_file, index_col="Frame", dtype={"Touch": str})
                preds["Touch"] = preds["Touch"].fillna("")
                # Mark the entry as recently used for the eviction below.
                os.utime(preds_file)
            else:
                preds = infer(MODEL, replay_file,
                              nullify_goal_difference=nullify_goal_difference,
                              ignore_ties=ignore_ties)
                preds.to_csv(preds_file)

            fig = plot_plotly(preds)
            print(f"Plot generated for file: {replay_file}")

            # Keep at most 100 cached files: drop the least recently used one,
            # never the file that is about to be served.
            cached = [os.path.join(temp_dir, f) for f in os.listdir(temp_dir)]
            if len(cached) > 100:
                try:
                    oldest_file = min((p for p in cached if p != preds_file), key=os.path.getmtime)
                    os.remove(oldest_file)
                except FileNotFoundError:
                    # Another request evicted it first; nothing left to do.
                    pass

            return fig, gr.DownloadButton(value=preds_file, visible=True)


        submit_button.click(
            fn=make_plot,
            inputs=[file_input, checkboxes],
            outputs=[plot_output, download_button],
            show_progress="full",
        )

    demo.queue(default_concurrency_limit=None)
    demo.launch()