File size: 3,917 Bytes
69f399c
 
 
 
 
 
 
 
 
 
 
7ed4dd8
69f399c
 
7ed4dd8
69f399c
 
 
 
 
7ed4dd8
69f399c
 
 
 
 
 
 
 
 
 
7ed4dd8
 
69f399c
 
 
 
 
 
 
 
 
 
 
 
 
b13b7c1
69f399c
c533f9b
69f399c
7ed4dd8
69f399c
 
 
 
 
 
 
b13b7c1
69f399c
 
 
 
 
 
 
 
 
 
 
 
 
b13b7c1
69f399c
 
 
 
 
 
b13b7c1
69f399c
 
 
 
 
 
 
 
 
 
b13b7c1
69f399c
b13b7c1
69f399c
7ed4dd8
69f399c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b13b7c1
69f399c
 
 
 
b13b7c1
69f399c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
# SPACE_ID is only present when running inside a Hugging Face Space. Those
# containers have no display or audio device, so SDL (presumably pulled in by
# the environment's pygame renderer) is pointed at its dummy drivers.
if "SPACE_ID" in os.environ:
    os.environ.update(SDL_VIDEODRIVER="dummy", SDL_AUDIODRIVER="dummy")

import torch
import torch.nn as nn
import numpy as np
import gymnasium as gym
import imageio
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    The weight receives an orthogonal initialisation scaled by ``std``; the
    bias, when the layer has one, is filled with ``bias_const``.

    Args:
        layer: A module exposing a ``weight`` tensor (at least 2-D, as
            required by ``nn.init.orthogonal_``) and optionally a ``bias``
            tensor, e.g. ``nn.Linear``. ``bias`` may be ``None`` when the layer
            was constructed with ``bias=False``.
        std: Gain passed to ``nn.init.orthogonal_``.
        bias_const: Constant written into every element of the bias.

    Returns:
        The same ``layer`` object, so calls can be nested directly inside
        ``nn.Sequential(...)``.
    """
    nn.init.orthogonal_(layer.weight, std)
    # Layers built with bias=False store bias=None; constant_ would fail on it.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, bias_const)
    return layer


def get_actor_network(state_dim=8, action_dim=4):
    """Build the policy MLP mapping a state vector to action logits.

    Topology: ``state_dim -> 64 -> 64 -> action_dim`` with Tanh between the
    linear layers. Every linear layer goes through ``layer_init``; the output
    layer uses a small gain (0.01) so its initial weights are small.
    The module indices (0..4) define the checkpoint's state-dict keys, so the
    layer order must not change.
    """
    width = 64
    modules = [
        layer_init(nn.Linear(state_dim, width)),
        nn.Tanh(),
        layer_init(nn.Linear(width, width)),
        nn.Tanh(),
        layer_init(nn.Linear(width, action_dim), std=0.01),
    ]
    return nn.Sequential(*modules)


def _resolve_weight_filename(stage_selection):
    """Map a dropdown label to its checkpoint filename in the model repo.

    Labels are compared after stripping surrounding whitespace, so the exact
    trailing whitespace of a dropdown label is not load-bearing.

    Raises:
        gr.Error: if ``stage_selection`` is empty or not a known stage.
    """
    weight_mapping = {
        "Stage 1: Baseline": "1_baseline.pth",
        "Stage 2: Surrogate Hacking": "2_surrogate_hacking_attention.pth",
        "Stage 3: Temporal Paradox": "3_temporal_paradox_variance.pth",
        "Stage 4: Target Decoupling": "4_target_decoupling_final.pth",
    }
    filename = weight_mapping.get((stage_selection or "").strip())
    if filename is None:
        raise gr.Error(f"Unknown model stage: {stage_selection!r}")
    return filename


def _load_actor(weights_path, device):
    """Build the actor network on ``device``, load ``weights_path`` into it, and set eval mode.

    Raises:
        gr.Error: if the checkpoint cannot be loaded into the architecture.
    """
    actor = get_actor_network(state_dim=8, action_dim=4).to(device)
    try:
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
        actor.load_state_dict(state_dict)
    except Exception as e:
        raise gr.Error(f"Architecture mismatch. Error: {str(e)}") from e
    actor.eval()
    return actor


def _rollout(env, actor, device, seed, max_steps):
    """Play one greedy (argmax) episode and return the rendered RGB frames.

    A frame is captured before every step and once more after the loop, so
    the final state (terminal, truncated, or step-capped) is in the video.

    Raises:
        gr.Error: if ``env.render()`` fails.
    """
    frames = []

    def capture():
        try:
            frame = env.render()
        except Exception as e:
            raise gr.Error(f"Render failed: {str(e)}") from e
        if frame is not None:
            frames.append(frame)

    state, _ = env.reset(seed=seed)
    for _ in range(max_steps):
        capture()
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            action = torch.argmax(actor(state_tensor), dim=1).item()
        state, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            break
    capture()
    return frames


def _encode_video(frames, fps):
    """Encode ``frames`` as an H.264/yuv420p MP4 and return the file path.

    Raises:
        gr.Error: if there are no frames or encoding fails.
    """
    import shutil
    import tempfile

    if not frames:
        raise gr.Error("Video encoding failed: no frames were rendered.")
    # One directory per request: a fixed path in the CWD would let concurrent
    # users overwrite each other's video before Gradio serves it.
    out_dir = tempfile.mkdtemp(prefix="eval_")
    video_filename = os.path.join(out_dir, "eval_output.mp4")
    try:
        imageio.mimsave(video_filename, frames, fps=fps, codec='libx264', pixelformat='yuv420p')
    except Exception as e:
        shutil.rmtree(out_dir, ignore_errors=True)
        raise gr.Error(f"Video encoding failed: {str(e)}") from e
    return video_filename


@spaces.GPU(duration=60)
def simulate_agent(stage_selection):
    """Roll out the selected stage's policy in LunarLander-v3 and return a video path.

    Args:
        stage_selection: A label from the "Model Stage" dropdown.

    Returns:
        Path to an MP4 of one seeded (seed=32) greedy episode, capped at
        600 steps.

    Raises:
        gr.Error: on an unknown stage, download failure, checkpoint/architecture
            mismatch, render failure, or video encoding failure.
    """
    filename = _resolve_weight_filename(stage_selection)
    repo_id = "ben-dlwlrma/Representation-Over-Routing"
    try:
        weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        raise gr.Error(f"Weight download failed. Error: {str(e)}") from e

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    actor = _load_actor(weights_path, device)

    env = gym.make("LunarLander-v3", render_mode="rgb_array")
    try:
        frames = _rollout(env, actor, device, seed=32, max_steps=600)
    finally:
        # Close on every path, including errors from the policy or env.step.
        env.close()

    return _encode_video(frames, fps=30)


with gr.Blocks(title="Representation over Routing", theme=gr.themes.Base()) as demo:
    # These labels are passed to simulate_agent, which maps them to checkpoint
    # files, so they must stay in sync with that mapping.
    stage_choices = [
        "Stage 1: Baseline",
        "Stage 2: Surrogate Hacking",
        "Stage 3: Temporal Paradox ",
        "Stage 4: Target Decoupling",
    ]

    gr.Markdown("## Representation over Routing")
    gr.Markdown("Multi-timescale RL evaluation environment. Select an ablation stage to visualize policy behavior.")

    with gr.Row():
        # Left: controls.
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                label="Model Stage",
                choices=stage_choices,
                value=stage_choices[-1],
            )
            run_button = gr.Button("Run Inference", variant="primary")

        # Right: rendered episode.
        with gr.Column(scale=2):
            video_output = gr.Video(label="Environment Render", autoplay=True)

    # Each click runs one episode for the selected stage and shows the video.
    run_button.click(simulate_agent, inputs=[model_dropdown], outputs=[video_output])

if __name__ == "__main__":
    demo.launch()