Spaces:
Running on Zero
Running on Zero
import os
import tempfile

# Presumably selects SDL's headless drivers when running inside a Hugging Face
# Space (SPACE_ID set); kept ahead of the gymnasium import, as in the original.
if os.getenv("SPACE_ID") is not None:
    os.environ["SDL_VIDEODRIVER"] = "dummy"
    os.environ["SDL_AUDIODRIVER"] = "dummy"

import gradio as gr
import gymnasium as gym
import imageio
import numpy as np
import spaces
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialize a layer in place with orthogonal weights and a constant bias.

    Args:
        layer: Module exposing a ``weight`` parameter and, optionally, a
            ``bias`` parameter (for example ``nn.Linear``).
        std: Gain forwarded to ``nn.init.orthogonal_``.
        bias_const: Value written to every bias element.

    Returns:
        The same ``layer`` object, so the call can be nested inside a
        container constructor.
    """
    nn.init.orthogonal_(layer.weight, std)
    # Layers created with ``bias=False`` expose ``bias`` as None; skip them
    # rather than failing inside ``nn.init.constant_``.
    if getattr(layer, "bias", None) is not None:
        nn.init.constant_(layer.bias, bias_const)
    return layer
def get_actor_network(state_dim=8, action_dim=4):
    """Build the policy network: two 64-unit Tanh hidden layers plus an output head.

    Args:
        state_dim: Size of the input feature vector.
        action_dim: Number of outputs produced by the final linear layer.

    Returns:
        An ``nn.Sequential`` whose linear layers sit at indices 0, 2 and 4.
    """
    # Each layer is created and initialized before the next one is built,
    # matching the original construction order.
    input_layer = layer_init(nn.Linear(state_dim, 64))
    hidden_layer = layer_init(nn.Linear(64, 64))
    # The output head uses a much smaller orthogonal gain than the hidden layers.
    output_layer = layer_init(nn.Linear(64, action_dim), std=0.01)

    return nn.Sequential(
        input_layer,
        nn.Tanh(),
        hidden_layer,
        nn.Tanh(),
        output_layer,
    )
def simulate_agent(stage_selection):
    """Roll out one LunarLander episode with the selected checkpoint and record it.

    The checkpoint for ``stage_selection`` is downloaded from the Hugging Face
    Hub, loaded into the actor network, and run greedily (argmax over the
    network output) for at most 600 steps from a fixed seed. One frame is
    rendered before every step and the frames are encoded to an MP4 file.

    Args:
        stage_selection: Stage label chosen in the UI dropdown. Surrounding
            whitespace is ignored when matching it to a checkpoint file.

    Returns:
        Path of the encoded MP4 file.

    Raises:
        gr.Error: If the stage is unknown, the weights cannot be downloaded or
            loaded, rendering fails, no frame was produced, or video encoding
            fails.
    """
    weight_mapping = {
        "Stage 1: Baseline": "1_baseline.pth",
        "Stage 2: Surrogate Hacking": "2_surrogate_hacking_attention.pth",
        "Stage 3: Temporal Paradox": "3_temporal_paradox_variance.pth",
        "Stage 4: Target Decoupling": "4_target_decoupling_final.pth",
    }
    # Normalise whitespace so the label matches whether or not the caller
    # sends the trailing space present in the dropdown's Stage 3 entry.
    stage_key = stage_selection.strip() if isinstance(stage_selection, str) else stage_selection
    filename = weight_mapping.get(stage_key)
    if filename is None:
        # Fail early instead of passing filename=None to the Hub client.
        raise gr.Error(f"Unknown model stage: {stage_selection!r}")

    repo_id = "ben-dlwlrma/Representation-Over-Routing"
    try:
        weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        raise gr.Error(f"Weight download failed. Error: {e}") from e

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    actor = get_actor_network(state_dim=8, action_dim=4).to(device)
    try:
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
        actor.load_state_dict(state_dict)
    except Exception as e:
        raise gr.Error(f"Architecture mismatch. Error: {e}") from e
    actor.eval()

    max_steps = 600
    frames = []
    env = gym.make("LunarLander-v3", render_mode="rgb_array")
    try:
        state, _ = env.reset(seed=32)
        done = False
        step_count = 0
        while not done and step_count < max_steps:
            try:
                frame = env.render()
            except Exception as e:
                raise gr.Error(f"Render failed: {e}") from e
            if frame is not None:
                frames.append(frame)

            state_tensor = torch.as_tensor(
                state, dtype=torch.float32, device=device
            ).unsqueeze(0)
            with torch.no_grad():
                action_logits = actor(state_tensor)
            action = torch.argmax(action_logits, dim=1).item()

            state, _, terminated, truncated, _ = env.step(action)
            step_count += 1
            done = terminated or truncated
    finally:
        # Release the environment on every exit path, including failures in
        # reset, step or the policy forward pass.
        env.close()

    if not frames:
        raise gr.Error("Video encoding failed: no frames were rendered.")

    # Each call writes to its own file so overlapping requests cannot
    # overwrite each other's video.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        video_filename = tmp.name
    fps = 30
    try:
        imageio.mimsave(
            video_filename, frames, fps=fps, codec="libx264", pixelformat="yuv420p"
        )
    except Exception as e:
        # Do not leave a broken or empty file behind when encoding fails.
        try:
            os.remove(video_filename)
        except OSError:
            pass
        raise gr.Error(f"Video encoding failed: {e}") from e
    return video_filename
# Stage labels offered in the dropdown. They are kept byte-for-byte (including
# the trailing space on the Stage 3 label) because the selected label is the
# value simulate_agent receives as its lookup key.
_STAGE_CHOICES = [
    "Stage 1: Baseline",
    "Stage 2: Surrogate Hacking",
    "Stage 3: Temporal Paradox ",
    "Stage 4: Target Decoupling",
]

# Page layout: stage selector and run button in the left column, rendered
# episode video in the wider right column.
with gr.Blocks(title="Representation over Routing", theme=gr.themes.Base()) as demo:
    gr.Markdown("## Representation over Routing")
    gr.Markdown(
        "Multi-timescale RL evaluation environment. "
        "Select an ablation stage to visualize policy behavior."
    )

    with gr.Row():
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=_STAGE_CHOICES,
                value=_STAGE_CHOICES[-1],  # "Stage 4: Target Decoupling"
                label="Model Stage",
            )
            run_button = gr.Button("Run Inference", variant="primary")

        with gr.Column(scale=2):
            video_output = gr.Video(label="Environment Render", autoplay=True)

    # Clicking the button runs one rollout and shows the resulting video.
    run_button.click(fn=simulate_agent, inputs=[model_dropdown], outputs=[video_output])


if __name__ == "__main__":
    demo.launch()