ben-dlwlrma commited on
Commit
3ae4f81
·
verified ·
1 Parent(s): 2fa41f4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +97 -0
README.md CHANGED
@@ -1,3 +1,100 @@
1
  ---
 
 
 
 
 
 
 
 
2
  license: cc-by-4.0
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ tags:
3
+ - reinforcement-learning
4
+ - pytorch
5
+ - custom-implementation
6
+ - ppo
7
+ - deep-reinforcement-learning
8
+ - gym
9
+ - lunar-lander
10
  license: cc-by-4.0
11
+ library_name: pytorch
12
  ---
13
+
14
+ # Representation over Routing: Overcoming Surrogate Hacking in Multi-Timescale PPO
15
+
16
+ [![arXiv](https://img.shields.io/badge/arXiv-2604.13517-b31b1b.svg)](https://arxiv.org/abs/2604.13517)
17
+ [![GitHub](https://img.shields.io/badge/GitHub-Codebase-blue?logo=github)](https://github.com/ben-dlwlrma/Representation-Over-Routing)
18
+
19
+ This repository hosts the **pre-trained PyTorch model weights** for the 4-stage ablation study presented in the paper: *"Representation over Routing: Overcoming Surrogate Hacking in Multi-Timescale PPO"*.
20
+
21
+ Our work identifies severe optimization pathologies in multi-timescale RL (**Surrogate Objective Hacking** and **the Paradox of Temporal Uncertainty**) and introduces **Target Decoupling** to align agents with true long-term objectives without collapsing into short-term behavioral traps.
22
+
23
+ ## Model Weights Overview
24
+
25
+ We provide four standalone `.pth` weight files, corresponding to the isolated stages of our ablation study on the `LunarLander-v2` environment:
26
+
27
+ * **`1_baseline.pth` (Baseline)**: Suffers from hovering local optima, wasting fuel to hoard small centering rewards due to a fear of crashing.
28
+ * **`2_surrogate_hacking_attention.pth` (Surrogate Hacking)**: Demonstrates multi-timescale collapse. The policy artificially minimizes the surrogate loss by manipulating attention weights instead of improving physical control.
29
+ * **`3_temporal_paradox_variance.pth` (Temporal Paradox)**: Exhibits aimless wandering caused by the inability to confidently attribute credit over long horizons.
30
+ * **`4_target_decoupling_final.pth` (Target Decoupling)**: **Our proposed solution.** The agent uncovers true intelligence, executing a highly fuel-efficient and safe landing by understanding the ultimate long-term goal ($\gamma = 0.999$).
31
+
32
+ ## Usage & Inference
33
+
34
+ To fully reproduce the training process or run the visual evaluations (GIFs), please refer to the [official GitHub repository](https://github.com/ben-dlwlrma/Representation-Over-Routing).
35
+
36
+ Because the published weights only contain the parameters for the Actor networks, inference is exceptionally lightweight. You do not need to import the full training architecture. You can directly load the weights into a standard PyTorch `nn.Sequential` module using the following minimal snippet:
37
+
38
+ ```python
39
+ import torch
40
+ import torch.nn as nn
41
+ import numpy as np
42
+ import gymnasium as gym
43
+ from huggingface_hub import hf_hub_download
44
+
45
+ # 1. Download a specific stage's weight from Hugging Face
46
+ weight_path = hf_hub_download(
47
+ repo_id="ben-dlwlrma/Representation-Over-Routing",
48
+ filename="4_target_decoupling_final.pth"
49
+ )
50
+
51
+ # 2. Define the exact Actor network architecture
52
+ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
53
+ nn.init.orthogonal_(layer.weight, std)
54
+ nn.init.constant_(layer.bias, bias_const)
55
+ return layer
56
+
57
+ actor = nn.Sequential(
58
+ layer_init(nn.Linear(8, 64)),
59
+ nn.Tanh(),
60
+ layer_init(nn.Linear(64, 64)),
61
+ nn.Tanh(),
62
+ layer_init(nn.Linear(64, 4), std=0.01),
63
+ )
64
+
65
+ # 3. Load weights
66
+ actor.load_state_dict(torch.load(weight_path, weights_only=True))
67
+ actor.eval()
68
+
69
+ # 4. Run Inference in environment
70
+ env = gym.make("LunarLander-v2")
71
+ state, _ = env.reset()
72
+ done = False
73
+
74
+ while not done:
75
+ state_tensor = torch.FloatTensor(state).unsqueeze(0)
76
+ with torch.no_grad():
77
+ logits = actor(state_tensor)
78
+ action = torch.argmax(logits, dim=1).item()
79
+
80
+ state, reward, terminated, truncated, _ = env.step(action)
81
+ done = terminated or truncated
82
+ ```
83
+
84
+ ## Citation
85
+
86
+ If you find this code or our insights useful in your research, please consider citing our work:
87
+
88
+ ```bibtex
89
+ @misc{sunRepresentationRoutingOvercoming2026b,
90
+ title = {Representation over {{Routing}}: {{Overcoming Surrogate Hacking}} in {{Multi-Timescale PPO}}},
91
+ shorttitle = {Representation over {{Routing}}},
92
+ author = {Sun, Jing},
93
+ year = 2026,
94
+ publisher = {arXiv},
95
+ doi = {10.48550/ARXIV.2604.13517},
96
+ urldate = {2026-04-16},
97
+ copyright = {Creative Commons Attribution 4.0 International},
98
+ keywords = {Artificial Intelligence (cs.AI),FOS: Computer and information sciences,Machine Learning (cs.LG)}
99
+ }
100
+ ```