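# Provenance text captured alongside the source (presumably author, commit message, commit hash):
# Ben
# Enable ZeroGPU inference
# 7ed4dd8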
import os
import tempfile
# When the SPACE_ID environment variable exists (presumably meaning we run on a
# Hugging Face Space, which is assumed to be headless), route SDL's video and
# audio output to its "dummy" drivers. This sits before the third-party imports
# below, presumably so the variables are set before any SDL/pygame-based
# renderer is loaded — TODO confirm that ordering is still required.
if "SPACE_ID" in os.environ:
    os.environ.update(SDL_VIDEODRIVER="dummy", SDL_AUDIODRIVER="dummy")
import torch
import torch.nn as nn
import numpy as np
import gymnasium as gym
import imageio
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and hand the same module back.

    The weight receives an orthogonal initialisation scaled by ``std`` (passed
    as the gain), and every bias entry is set to ``bias_const``. Returning the
    module lets callers wrap a constructor inline, e.g.
    ``layer_init(nn.Linear(a, b))``.
    """
    weight, bias = layer.weight, layer.bias
    nn.init.orthogonal_(weight, gain=std)
    nn.init.constant_(bias, val=bias_const)
    return layer
def get_actor_network(state_dim=8, action_dim=4):
    """Build the actor MLP: state_dim -> 64 -> 64 -> action_dim with Tanh between layers.

    Both hidden Linear layers use ``layer_init``'s default gain; the output
    layer uses a small gain (std=0.01). The result is an ``nn.Sequential``
    whose Linear layers sit at indices 0, 2 and 4 (the state_dict keys
    depend on these positions).
    """
    hidden_width = 64
    modules = []
    # Each Linear is constructed and initialised immediately, in network order.
    for fan_in, fan_out in ((state_dim, hidden_width), (hidden_width, hidden_width)):
        modules += [layer_init(nn.Linear(fan_in, fan_out)), nn.Tanh()]
    modules.append(layer_init(nn.Linear(hidden_width, action_dim), std=0.01))
    return nn.Sequential(*modules)
def _resolve_weights_path(stage_selection):
    """Map a UI stage label to its checkpoint file, download it, and return the local path.

    Raises:
        gr.Error: if the label names no known stage, or the Hub download fails.
    """
    weight_mapping = {
        "Stage 1: Baseline": "1_baseline.pth",
        "Stage 2: Surrogate Hacking": "2_surrogate_hacking_attention.pth",
        "Stage 3: Temporal Paradox": "3_temporal_paradox_variance.pth",
        "Stage 4: Target Decoupling": "4_target_decoupling_final.pth",
    }
    # Compare on stripped labels so a label with stray surrounding whitespace
    # (e.g. "Stage 3: Temporal Paradox ") still resolves; a cleared dropdown
    # (None) falls through to the unknown-stage error.
    filename = weight_mapping.get((stage_selection or "").strip())
    if filename is None:
        # Reject up front instead of passing filename=None to hf_hub_download,
        # whose failure would be misreported as a download error.
        raise gr.Error(f"Unknown model stage: {stage_selection!r}")
    repo_id = "ben-dlwlrma/Representation-Over-Routing"
    try:
        return hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        raise gr.Error(f"Weight download failed. Error: {e}") from e


def _load_actor(weights_path, device):
    """Build the 8-input / 4-output actor on ``device``, load ``weights_path`` into it, set eval mode.

    Raises:
        gr.Error: if the checkpoint cannot be loaded into the network.
    """
    actor = get_actor_network(state_dim=8, action_dim=4).to(device)
    try:
        # weights_only=True: per the torch.load docs, only tensors and
        # primitive containers are unpickled.
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
        actor.load_state_dict(state_dict)
    except Exception as e:
        raise gr.Error(f"Architecture mismatch. Error: {e}") from e
    actor.eval()
    return actor


def _render_frame(env, frames):
    """Render ``env`` and append the frame to ``frames`` unless render returned None.

    Raises:
        gr.Error: if ``env.render()`` raises.
    """
    try:
        frame = env.render()
    except Exception as e:
        raise gr.Error(f"Render failed: {e}") from e
    if frame is not None:
        frames.append(frame)


def _rollout_frames(actor, device, seed=32, max_steps=600):
    """Play one greedy LunarLander-v3 episode with ``actor`` and return the rendered frames.

    The episode ends on termination, truncation, or after ``max_steps``
    actions. The environment is closed on every exit path, including errors.
    """
    env = gym.make("LunarLander-v3", render_mode="rgb_array")
    frames = []
    try:
        state, _ = env.reset(seed=seed)
        with torch.no_grad():
            for _ in range(max_steps):
                _render_frame(env, frames)
                state_tensor = torch.as_tensor(
                    state, dtype=torch.float32, device=device
                ).unsqueeze(0)
                # Greedy policy: pick the action with the highest network output.
                action = torch.argmax(actor(state_tensor), dim=1).item()
                state, _, terminated, truncated, _ = env.step(action)
                if terminated or truncated:
                    break
        # Frames are captured before each action; render once more so the
        # video includes the state reached by the final step.
        _render_frame(env, frames)
    finally:
        env.close()
    return frames


def _write_video(frames, fps=30):
    """Encode ``frames`` into an MP4 in a fresh temporary file and return its path.

    A unique file per call means concurrent requests cannot overwrite each
    other's video, as they could with one fixed filename in the working
    directory. The file is not deleted here.

    Raises:
        gr.Error: if there are no frames or encoding fails.
    """
    if not frames:
        raise gr.Error("Render failed: no frames were captured.")
    # delete=False: the file must outlive this handle so imageio can write to
    # it and Gradio can serve it afterwards.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        video_path = tmp.name
    try:
        # libx264 + yuv420p, presumably chosen for broad browser playback.
        imageio.mimsave(video_path, frames, fps=fps, codec='libx264', pixelformat='yuv420p')
    except Exception as e:
        raise gr.Error(f"Video encoding failed: {e}") from e
    return video_path


# ZeroGPU: the decorated call is given a GPU (duration=60 is the requested
# budget, presumably in seconds).
@spaces.GPU(duration=60)
def simulate_agent(stage_selection):
    """Run the selected ablation-stage policy on LunarLander-v3 and return a video path.

    Args:
        stage_selection: Dropdown label, e.g. "Stage 4: Target Decoupling".

    Returns:
        Path to an MP4 of one greedy episode (reset seed 32, at most 600 steps).

    Raises:
        gr.Error: on an unknown stage, or on download, checkpoint-load,
            render, or video-encoding failure.
    """
    weights_path = _resolve_weights_path(stage_selection)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    actor = _load_actor(weights_path, device)
    frames = _rollout_frames(actor, device)
    return _write_video(frames)
# Stage labels offered in the UI. They are passed verbatim to simulate_agent,
# which maps them to checkpoint files, so keep them in sync with its lookup
# table (note: the Stage 3 label ends with a space).
_STAGE_CHOICES = [
    "Stage 1: Baseline",
    "Stage 2: Surrogate Hacking",
    "Stage 3: Temporal Paradox ",
    "Stage 4: Target Decoupling",
]

# Page layout: title + description, then a narrow control column (stage picker
# and run button) beside a wider column holding the rendered episode video.
with gr.Blocks(title="Representation over Routing", theme=gr.themes.Base()) as demo:
    gr.Markdown("## Representation over Routing")
    gr.Markdown("Multi-timescale RL evaluation environment. Select an ablation stage to visualize policy behavior.")

    with gr.Row():
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=_STAGE_CHOICES,
                value=_STAGE_CHOICES[-1],  # default: "Stage 4: Target Decoupling"
                label="Model Stage",
            )
            run_button = gr.Button("Run Inference", variant="primary")
        with gr.Column(scale=2):
            video_output = gr.Video(label="Environment Render", autoplay=True)

    # Clicking the button runs one episode for the selected stage and shows the video.
    run_button.click(
        fn=simulate_agent,
        inputs=[model_dropdown],
        outputs=[video_output],
    )
def main():
    """Launch the Gradio app defined above."""
    demo.launch()


# Launch only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()