Spaces:
Running on Zero
Running on Zero
File size: 3,917 Bytes
69f399c 7ed4dd8 69f399c 7ed4dd8 69f399c 7ed4dd8 69f399c 7ed4dd8 69f399c b13b7c1 69f399c c533f9b 69f399c 7ed4dd8 69f399c b13b7c1 69f399c b13b7c1 69f399c b13b7c1 69f399c b13b7c1 69f399c b13b7c1 69f399c 7ed4dd8 69f399c b13b7c1 69f399c b13b7c1 69f399c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 | import os
# Headless rendering: when SPACE_ID is present (presumably set only when running
# inside a Hugging Face Space), point SDL at its dummy video/audio drivers. This
# happens before the heavy imports below so that SDL (presumably via pygame, used by
# gymnasium's rgb_array renderer) never tries to open a real display or sound device.
if "SPACE_ID" in os.environ:
    os.environ.update(SDL_VIDEODRIVER="dummy", SDL_AUDIODRIVER="dummy")
import torch
import torch.nn as nn
import numpy as np
import gymnasium as gym
import imageio
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise ``layer.weight`` and fill ``layer.bias`` in place.

    Args:
        layer: A module exposing a ``weight`` tensor and, optionally, a ``bias``
            tensor (e.g. ``nn.Linear``).
        std: Gain passed to ``nn.init.orthogonal_``; defaults to sqrt(2).
        bias_const: Value written into every element of the bias, if present.

    Returns:
        The same ``layer`` object, so the call can be nested directly inside
        ``nn.Sequential(...)``.
    """
    nn.init.orthogonal_(layer.weight, std)
    # Layers built with bias=False register ``bias`` as None; skip them instead of
    # letting nn.init.constant_ fail on a missing tensor.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, bias_const)
    return layer
def get_actor_network(state_dim=8, action_dim=4):
    """Build the policy MLP that the stage checkpoints are loaded into.

    Layout: Linear(state_dim, 64) -> Tanh -> Linear(64, 64) -> Tanh ->
    Linear(64, action_dim). The output layer has no activation, so the network
    returns raw action logits. Every Linear goes through ``layer_init``. The
    output layer uses a small gain (0.01).

    The Sequential indices (0, 2, 4 for the Linear layers) determine the
    state_dict keys, so this layout must stay in step with the saved weights.
    """
    # Each Linear is built and initialised before the next one is constructed,
    # keeping the same RNG consumption order as a single nn.Sequential(...) literal.
    modules = [layer_init(nn.Linear(state_dim, 64)), nn.Tanh()]
    modules += [layer_init(nn.Linear(64, 64)), nn.Tanh()]
    modules.append(layer_init(nn.Linear(64, action_dim), std=0.01))
    return nn.Sequential(*modules)
def _capture_frame(env, frames):
    """Render ``env``'s current state and append the frame to ``frames``.

    A ``None`` render result is skipped. Rendering exceptions are re-raised as
    ``gr.Error`` so Gradio shows them in the UI.
    """
    try:
        frame = env.render()
    except Exception as e:
        raise gr.Error(f"Render failed: {e}") from e
    if frame is not None:
        frames.append(frame)


@spaces.GPU(duration=60)
def simulate_agent(stage_selection):
    """Roll out the selected checkpoint on LunarLander-v3 and return an MP4 path.

    Args:
        stage_selection: A dropdown label; it must be a key of ``weight_mapping``
            below, matched byte-for-byte (including the trailing space in the
            Stage 3 label).

    Returns:
        Path of the H.264 / yuv420p MP4 rendering of one greedy episode.

    Raises:
        gr.Error: If the stage is unknown, the checkpoint download fails,
            the state dict cannot be loaded, rendering fails or yields no frames,
            or video encoding fails.
    """
    weight_mapping = {
        "Stage 1: Baseline": "1_baseline.pth",
        "Stage 2: Surrogate Hacking": "2_surrogate_hacking_attention.pth",
        "Stage 3: Temporal Paradox ": "3_temporal_paradox_variance.pth",
        "Stage 4: Target Decoupling": "4_target_decoupling_final.pth",
    }
    repo_id = "ben-dlwlrma/Representation-Over-Routing"
    max_steps = 600  # hard cap on rollout length, and therefore on video length
    seed = 32        # fixed reset seed: every stage starts from the same state
    fps = 30
    # NOTE(review): a fixed path in the working directory means concurrent calls
    # would overwrite each other's file. This is presumably safe only while the
    # event handles one request at a time (Gradio's default); confirm before
    # raising concurrency.
    video_filename = "eval_output.mp4"

    filename = weight_mapping.get(stage_selection)
    if filename is None:
        # Without this check, hf_hub_download(filename=None) fails and the user
        # sees a misleading "Weight download failed" message.
        raise gr.Error(f"Unknown model stage: {stage_selection!r}")

    try:
        weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        raise gr.Error(f"Weight download failed. Error: {e}") from e

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    actor = get_actor_network(state_dim=8, action_dim=4).to(device)
    try:
        # weights_only=True restricts unpickling to tensors and plain containers.
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
        actor.load_state_dict(state_dict)
    except Exception as e:
        raise gr.Error(f"Architecture mismatch. Error: {e}") from e
    actor.eval()

    frames = []
    env = gym.make("LunarLander-v3", render_mode="rgb_array")
    try:  # the env is closed on every exit path, including errors mid-rollout
        state, _ = env.reset(seed=seed)
        with torch.no_grad():
            for _ in range(max_steps):
                _capture_frame(env, frames)
                state_tensor = torch.as_tensor(
                    state, dtype=torch.float32, device=device
                ).unsqueeze(0)
                # Greedy policy: take the highest-logit action.
                action = torch.argmax(actor(state_tensor), dim=1).item()
                state, _, terminated, truncated, _ = env.step(action)
                if terminated or truncated:
                    break
        # Also capture the state after the last step, so the video shows the
        # episode's final outcome rather than stopping one frame short.
        _capture_frame(env, frames)
    finally:
        env.close()

    if not frames:
        raise gr.Error("Render failed: the environment produced no frames.")

    try:
        imageio.mimsave(video_filename, frames, fps=fps, codec='libx264', pixelformat='yuv420p')
    except Exception as e:
        raise gr.Error(f"Video encoding failed: {e}") from e
    return video_filename
# Dropdown labels. Each must match a key of simulate_agent's stage->checkpoint
# mapping byte-for-byte (note the trailing space in the Stage 3 label).
_STAGE_CHOICES = [
    "Stage 1: Baseline",
    "Stage 2: Surrogate Hacking",
    "Stage 3: Temporal Paradox ",
    "Stage 4: Target Decoupling",
]

with gr.Blocks(title="Representation over Routing", theme=gr.themes.Base()) as demo:
    gr.Markdown("## Representation over Routing")
    gr.Markdown("Multi-timescale RL evaluation environment. Select an ablation stage to visualize policy behavior.")
    with gr.Row():
        # Left column: stage picker and trigger button.
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                label="Model Stage",
                choices=_STAGE_CHOICES,
                value=_STAGE_CHOICES[-1],  # preselect the last stage
            )
            run_button = gr.Button("Run Inference", variant="primary")
        # Right column: the rendered rollout video.
        with gr.Column(scale=2):
            video_output = gr.Video(label="Environment Render", autoplay=True)
    # Each click runs one rollout of the chosen stage and shows the returned MP4.
    run_button.click(fn=simulate_agent, inputs=[model_dropdown], outputs=[video_output])

if __name__ == "__main__":
    demo.launch()
|