import contextlib
import hashlib
import os
from tempfile import TemporaryDirectory

import gradio as gr
import numpy as np
import pandas as pd
import spaces
import torch
from huggingface_hub import Repository
from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, \
    SB_GAME_TIMER_SECONDS, SB_BLUE_SCORE, SB_ORANGE_SCORE
from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
from tqdm import trange, tqdm

# os.chmod("/usr/local/lib/python3.10/site-packages/rlgym_tools/rocket_league/replays/carball", 0o755)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Clone (or update) the model repository, then load the TorchScript model once at startup.
repo = Repository(local_dir="vortex-ngp", clone_from="Rolv-Arild/vortex-ngp", token=os.getenv("HF_TOKEN"))
repo.git_pull()

MODEL = torch.jit.load("vortex-ngp/vortex-ngp-glowing-dragon.pt", map_location=DEVICE)
MODEL.eval()

# Number of frames passed to the model per forward call.
BATCH_SIZE = 900
# Scoreboard timer values above this are treated as overtime frames.
OT_TIMER_THRESHOLD = 450
# Maximum number of prediction CSVs kept in the temporary cache directory.
MAX_CACHED_PREDICTIONS = 100


def _touch_string(state, pid_to_name):
    """Return the "Team|Name|Team|Name..." touch annotation for one game state.

    Only cars with ``ball_touches > 0`` are included. Pipes inside player names are
    replaced by spaces so they cannot be confused with the separator. Returns ""
    when no car touched the ball on this frame.
    """
    parts = []
    for aid, car in state.cars.items():
        if car.ball_touches > 0:
            team = "Blue" if car.is_blue else "Orange"
            # Fall back to the raw id instead of crashing if the player is not in the metadata.
            name = str(pid_to_name.get(aid, aid)).replace("|", " ")
            parts.append(f"{team}|{name}")
    return "|".join(parts)


@spaces.GPU
@torch.inference_mode()
def infer(model, replay_file, nullify_goal_difference=False, ignore_ties=False):
    """Run the next-goal model on every frame of a replay.

    Args:
        model: TorchScript model called as ``model(states, scoreboards)``; its last
            output dimension holds per-team time bins followed by a single "Tie" class.
        replay_file: Path to the replay file, loaded with ``ParsedReplay.load``.
        nullify_goal_difference: Zero both team scores before feeding the model.
        ignore_ties: Zero the scores and set the game timer to ``inf`` (pretend overtime),
            then drop the "Tie" column from the result.

    Returns:
        DataFrame indexed by "Frame" with columns Timer (seconds, negated in overtime),
        Blue, Orange, Tie (omitted when ``ignore_ties``), Goal (frame-to-frame change
        of blue-minus-orange score), Touch ("Team|Name|..." strings), followed by the
        per-bin probability columns.
    """
    replay = ParsedReplay.load(replay_file)
    pid_to_name = {int(p["unique_id"]): p["name"]
                   for p in replay.metadata["players"]
                   if p["unique_id"] in replay.player_dfs}

    serialized_states = []
    serialized_scoreboards = []
    seconds_remaining = []
    touches = []
    with tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df)) as frames:
        for replay_frame in frames:
            serialized_states.append(serialize_game_state(replay_frame.state))
            serialized_scoreboards.append(serialize_scoreboard(replay_frame.scoreboard))
            seconds_remaining.append(replay_frame.episode_seconds_remaining)
            # Computed here so we don't have to keep every replay frame in memory.
            touches.append(_touch_string(replay_frame.state, pid_to_name))
    serialized_states = torch.from_numpy(np.stack(serialized_states))
    serialized_scoreboards = torch.from_numpy(np.stack(serialized_scoreboards))
    seconds_remaining = torch.tensor(seconds_remaining)

    timer = serialized_scoreboards[:, SB_GAME_TIMER_SECONDS].clone()
    is_ot = timer > OT_TIMER_THRESHOLD
    ot_time_remaining = seconds_remaining[is_ot]
    if len(ot_time_remaining) > 0:
        ot_timer = ot_time_remaining[0] - ot_time_remaining
        timer[is_ot] = -ot_timer  # Negate to indicate overtime

    goal_diff = serialized_scoreboards[:, SB_BLUE_SCORE] - serialized_scoreboards[:, SB_ORANGE_SCORE]
    goal_diff_diff = goal_diff.diff(prepend=goal_diff.new_zeros(1))

    predictions = []
    with trange(len(serialized_states), desc="Running model") as pbar:
        for start in range(0, len(serialized_states), BATCH_SIZE):
            states = serialized_states[start:start + BATCH_SIZE].to(DEVICE)
            # Cloned because the option handling below edits the scoreboards in place.
            scoreboards = serialized_scoreboards[start:start + BATCH_SIZE].clone().to(DEVICE)
            if nullify_goal_difference or ignore_ties:
                scoreboards[:, SB_BLUE_SCORE] = 0
                scoreboards[:, SB_ORANGE_SCORE] = 0
            if ignore_ties:
                scoreboards[:, SB_GAME_TIMER_SECONDS] = float("inf")
            predictions.append(model(states, scoreboards))
            pbar.update(len(states))
    probs = torch.cat(predictions, dim=0).softmax(dim=-1)

    # Output layout: [Blue bins..., Orange bins..., Tie]; bins span 0-60 seconds.
    num_outputs = probs.shape[-1]
    bin_seconds = torch.linspace(0, 60, num_outputs // 2)
    class_names = [
        f"{t}: {s:g}s"
        for t in ["Blue", "Orange"]
        for s in bin_seconds.tolist()
    ]
    class_names.append("Tie")

    preds = pd.DataFrame(data=probs.cpu().numpy(), columns=class_names)
    preds["Blue"] = preds[[c for c in class_names if c.startswith("Blue")]].sum(axis=1)
    preds["Orange"] = preds[[c for c in class_names if c.startswith("Orange")]].sum(axis=1)
    preds["Timer"] = timer.numpy()
    preds["Goal"] = goal_diff_diff.numpy()
    preds["Touch"] = touches

    # Sort columns
    main_cols = ["Timer", "Blue", "Orange", "Tie", "Goal", "Touch"]
    preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]]
    preds.index.name = "Frame"

    # Ties cannot happen in overtime (and "ignore ties" pretends every frame is overtime),
    # so remove the tie mass on those frames and renormalize the team probabilities.
    remove_ties_mask = torch.ones(len(preds), dtype=torch.bool) if ignore_ties else is_ot
    remove_ties_mask = remove_ties_mask.numpy()
    if remove_ties_mask.any():
        q = 1 - preds.loc[remove_ties_mask, "Tie"]
        team_cols = [c for c in preds.columns if c.startswith(("Blue", "Orange"))]
        preds.loc[remove_ties_mask, team_cols] = preds.loc[remove_ties_mask, team_cols].div(q, axis=0)
        if not ignore_ties:
            preds.loc[remove_ties_mask, "Tie"] = 0.0
    if ignore_ties:
        preds = preds.drop("Tie", axis=1)
    return preds


def _format_timer(t):
    """Format a timer value in seconds as M:SS, prefixed with '+' for overtime (negative) values.

    The total is rounded to whole seconds before splitting, so e.g. 119.6 becomes
    "2:00" rather than "1:60".
    """
    sign = '+' if t < 0 else ''
    minutes, seconds = divmod(int(round(abs(float(t)))), 60)
    return f"{sign}{minutes}:{seconds:02}"


def plot_plotly(preds: pd.DataFrame):
    """Build an interactive Plotly figure from the DataFrame produced by ``infer``.

    Plots the Blue/Orange (and Tie, if present) totals as percentages per frame.
    It also marks goals with vertical lines, marks ball touches with points, and shows
    the game timer on a secondary top x-axis.
    """
    import plotly.graph_objects as go

    prob_cols = [c for c in ("Blue", "Orange", "Tie") if c in preds.columns]
    preds_df = preds[prob_cols] * 100
    timer = preds["Timer"]
    timer_text = [_format_timer(t) for t in timer.to_numpy()]

    fig = go.Figure()
    hovertemplate = 'Frame %{x}<br>Prob: %{y:.3g}%<br>Timer: %{customdata}'

    # Add traces for Blue, Orange, and Tie probabilities from the DataFrame
    colors = {"Blue": "blue", "Orange": "orange", "Tie": "gray"}
    for col in prob_cols:
        fig.add_trace(
            go.Scatter(x=preds_df.index, y=preds_df[col], mode='lines', name=col,
                       line=dict(color=colors[col]), customdata=timer_text,
                       hovertemplate=hovertemplate))

    # Add the horizontal line at y=50%
    fig.add_hline(y=50, line_dash="dash", line_color="black", name="50% Probability")

    # Add goal indicators annotated with the running score
    b = o = 0
    goals = preds["Goal"]
    for goal_frame, delta in goals[goals != 0].items():
        if delta > 0:
            b += 1
        elif delta < 0:
            o += 1
        fig.add_vline(x=goal_frame, line_dash="dash", line_color="red",
                      annotation_text=f"{b}-{o}", annotation_position="top right")

    # Add touch indicators as points; Touch cells are "Team|Name|Team|Name..."
    touches = {}
    for touch_frame, touch in preds["Touch"].items():
        if not touch:
            continue
        teams_players = touch.split('|')
        for team, player in zip(teams_players[::2], teams_players[1::2]):
            touches.setdefault(team.strip(), []).append((touch_frame, player.strip()))

    for team in "Blue", "Orange":
        team_touches = touches.get(team, [])
        if not team_touches:
            continue
        x = [frame for frame, _ in team_touches]
        y = [preds_df.at[frame, team] for frame in x]
        custom_data = [f"{timer_text[f]}<br>Touch by {p}" for f, p in team_touches]
        fig.add_trace(
            go.Scatter(x=x, y=y, mode='markers', name=f'{team} touches',
                       marker=dict(size=5, color=team.lower(), symbol='circle-open-dot'),
                       customdata=custom_data,
                       hovertemplate=hovertemplate))

    # Up to 10 evenly spaced integer frame positions for the secondary (timer) axis.
    # Timer values are already in seconds, so they reuse the hover formatting.
    tick_positions = np.unique(np.linspace(0, len(preds_df) - 1, 10).round().astype(int))
    tick_labels = [timer_text[p] for p in tick_positions]

    fig.update_layout(
        title="Interactive Probability Plot",
        xaxis=dict(
            title="Frame",
            gridcolor='#e5e7eb'
        ),
        yaxis=dict(
            title="Probability",
            gridcolor='#e5e7eb'
        ),
        xaxis2=dict(
            title="Timer",
            overlaying='x',
            matches='x',  # Keep the timer ticks aligned with the frame axis (also when zooming)
            side='top',
            tickmode='array',
            tickvals=tick_positions,
            ticktext=tick_labels
        ),
        legend=dict(x=0.01, y=0.99, yanchor="top", xanchor="left"),
        plot_bgcolor='white'
    )
    return fig


def _file_digest(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of the file at ``path``, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()


def _evict_oldest(directory, max_files):
    """Delete the single least recently modified file if ``directory`` holds more than ``max_files`` entries."""
    # Requests may run concurrently, so another call may already have removed a file.
    with contextlib.suppress(FileNotFoundError):
        entries = os.listdir(directory)
        if len(entries) <= max_files:
            return
        oldest = min(entries, key=lambda f: os.path.getmtime(os.path.join(directory, f)))
        os.remove(os.path.join(directory, oldest))


DESCRIPTION = """
# Next Goal Predictor
Upload a replay file to get a plot of the next goal prediction.
The model is trained on about 9000 hours of SSL and RLCS replays in 1v1, 2v2, and 3v3 using [this dataset](https://www.kaggle.com/datasets/rolvarild/high-level-rocket-league-replay-dataset).<br>
It predicts the probability that each team will score at 1 second intervals up to 60+ seconds. It also predicts ties (ball hitting the ground at 0s).<br>
The plot only shows the totals for each team, but you can download the full predictions if you want.<br>
(Sorry if it's a bit slow, it's running on CPU.)
""".strip()

RADIO_OPTIONS = ["Default", "Nullify goal difference", "Ignore ties"]
RADIO_INFO = """
- **Default**: Uses the model as it is trained, with no modifications.
- **Nullify goal difference**: Makes the model think the goal difference is always 0, so it doesn't have a bias towards one team.
- **Ignore ties**: Makes the model pretend every situation is an overtime (e.g. ties are impossible).
""".strip()

with TemporaryDirectory() as temp_dir:
    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION)
        # Use gr.Column to stack components vertically
        with gr.Column():
            file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
            checkboxes = gr.Radio(label="Options", choices=RADIO_OPTIONS, type="index", value=RADIO_OPTIONS[0],
                                  info=RADIO_INFO)
            submit_button = gr.Button("Generate Predictions")
            plot_output = gr.Plot(label="Predictions")
            download_button = gr.DownloadButton("Download Predictions", visible=False)

        def make_plot(replay_file, radio_option, progress=gr.Progress(track_tqdm=True)):
            """Compute (or load cached) predictions for the uploaded replay and return the plot and CSV download."""
            if replay_file is None:
                raise gr.Error("Please upload a replay file first.")
            # radio_option is the index into RADIO_OPTIONS (type="index").
            nullify_goal_difference = radio_option == 1
            ignore_ties = radio_option == 2
            print(f"Processing file: {replay_file}")

            replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
            postfix = ""
            if nullify_goal_difference:
                postfix += "_nullify_goal_difference"
            elif ignore_ties:
                postfix += "_ignore_ties"
            # Include a content hash so different replays sharing a filename don't collide in the cache.
            digest = _file_digest(replay_file)[:12]
            preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}_{digest}{postfix}.csv")

            if os.path.exists(preds_file):
                print(f"Predictions file already exists: {preds_file}")
                # index_col restores the "Frame" index written by to_csv instead of adding a column.
                preds = pd.read_csv(preds_file, index_col="Frame", dtype={"Touch": str})
                preds["Touch"] = preds["Touch"].fillna("")
                with contextlib.suppress(FileNotFoundError):
                    os.utime(preds_file)  # Mark as recently used for cache eviction
            else:
                preds = infer(MODEL, replay_file,
                              nullify_goal_difference=nullify_goal_difference,
                              ignore_ties=ignore_ties)
                preds.to_csv(preds_file)
                _evict_oldest(temp_dir, MAX_CACHED_PREDICTIONS)

            fig = plot_plotly(preds)
            print(f"Plot generated for file: {replay_file}")
            return fig, gr.DownloadButton(value=preds_file, visible=True)

        submit_button.click(
            fn=make_plot,
            inputs=[file_input, checkboxes],
            outputs=[plot_output, download_button],
            show_progress="full",
        )

    demo.queue(default_concurrency_limit=None)
    demo.launch()