Nharen commited on
Commit
b6f3e6e
·
verified ·
1 Parent(s): 9d5c669

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +67 -56
README.md CHANGED
@@ -41,15 +41,15 @@ The SAC actor is a multi-layer perceptron with the following specifications:
41
  - Linear(17, 256) -> ReLU
42
  - Linear(256, 256) -> ReLU
43
  - Linear(256, 6) for `mean` + Linear(256, 6) for `log_std`
44
- - **Note:** The actor outputs the mean and log standard deviation for each action. For inference, only the mean is used, passed through a tanh activation to bound actions to [-1, 1].
45
 
46
  ## Common Mistakes to Avoid
47
 
48
- - **Layer Names:** The checkpoint uses `mean` and `log_std` layers. Do not try to map them to a generic `nn.Sequential` unless keys are remapped.
49
- - **Output Dimensions:** Ensure the model matches the checkpoint (6 actions).
50
- - **Continuous Actions:** HalfCheetah requires a numpy array for actions. Flatten tensors and convert to numpy.
51
- - **Episode Evaluation:** Always test over full episodes (100 episodes recommended) to properly evaluate policy performance.
52
- - **Checkpoint Loading:** Use `weights_only=True` with `torch.load` to safely load state dicts.
53
 
54
  ## Download and Test Code
55
 
@@ -60,58 +60,69 @@ import gymnasium as gym
60
  import numpy as np
61
  from huggingface_hub import hf_hub_download
62
 
 
 
 
 
 
 
 
 
 
 
 
63
class SACActor(nn.Module):
    """SAC Gaussian policy actor for HalfCheetah (17-dim obs, 6-dim action).

    Layer names match the published checkpoint: the two hidden layers live in
    an ``nn.Sequential`` named ``net`` (state-dict keys ``net.0.*`` and
    ``net.2.*``), with separate ``mean`` and ``log_std`` heads.  The previous
    ``fc1``/``fc2`` attribute names produced state-dict keys that do not exist
    in the checkpoint, so ``load_state_dict`` raised on loading.
    """

    def __init__(self, obs_dim=17, action_dim=6, hidden_dim=256):
        super().__init__()
        # Hidden trunk: Linear -> ReLU -> Linear -> ReLU.  Naming it `net`
        # makes the keys (`net.0.*`, `net.2.*`) line up with the checkpoint.
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # Separate Gaussian-policy heads; only `mean` is used at inference,
        # but `log_std` must exist for the checkpoint to load cleanly.
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return the deterministic action: tanh(mean), bounded to [-1, 1]."""
        x = self.net(x)
        mean = self.mean(x)
        return torch.tanh(mean)
77
 
78
def run_half_cheetah_test():
    """Download the SAC checkpoint and evaluate it on HalfCheetah-v4.

    Runs 100 full episodes with the deterministic (tanh-mean) policy and
    prints the average undiscounted return.
    """
    ckpt_path = hf_hub_download(
        repo_id="Nharen/Reward_Rush_SAC_Half_Cheetah",
        filename="half_cheetah.pth"
    )

    actor = SACActor()
    # weights_only=True avoids arbitrary code execution from untrusted pickles.
    loaded = torch.load(ckpt_path, map_location='cpu', weights_only=True)

    # The file may hold either a bare state dict or a wrapper dict keyed by
    # "actor_state_dict" / "state_dict" — unwrap before loading.
    if isinstance(loaded, dict) and "actor_state_dict" in loaded:
        loaded = loaded["actor_state_dict"]
    elif isinstance(loaded, dict) and "state_dict" in loaded:
        loaded = loaded["state_dict"]

    actor.load_state_dict(loaded)
    actor.eval()

    env = gym.make("HalfCheetah-v4")
    total_rewards = []

    for _ in range(100):
        obs, _ = env.reset()
        ep_return = 0
        finished = False
        while not finished:
            obs_tensor = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                act = actor(obs_tensor).cpu().numpy().flatten()

            obs, reward, terminated, truncated, _ = env.step(act)
            ep_return += reward
            finished = terminated or truncated
        total_rewards.append(ep_return)

    print(f"Average Reward over 100 episodes: {np.mean(total_rewards)}")
    env.close()

if __name__ == "__main__":
    run_half_cheetah_test()
117
  ```
 
41
  - Linear(17, 256) -> ReLU
42
  - Linear(256, 256) -> ReLU
43
  - Linear(256, 6) for `mean` + Linear(256, 6) for `log_std`
44
+ - **Note:** The actor outputs the mean and log standard deviation for each action. For inference, only the mean is used, passed through a tanh activation to bound actions to [-1, 1].
45
 
46
  ## Common Mistakes to Avoid
47
 
48
+ - **Layer Names:** The checkpoint uses `net`, `mean`, and `log_std`. Do not try to redefine layers with different names (`fc1`, `fc2`) unless you remap the keys.
49
+ - **Output Dimensions:** Ensure the actor matches the checkpoint dimensions (6 actions).
50
+ - **Continuous Actions:** HalfCheetah requires numpy arrays for actions. Flatten tensors and convert to numpy.
51
+ - **Episode Evaluation:** Always test over full episodes (100 recommended) to properly evaluate performance.
52
+ - **Checkpoint Loading:** Use `weights_only=True` when loading `.pth` state dicts for safety.
53
 
54
  ## Download and Test Code
55
 
 
60
  import numpy as np
61
  from huggingface_hub import hf_hub_download
62
 
63
# Fetch the stripped checkpoint from the Hub and load it safely
# (weights_only=True restricts unpickling to plain tensors/containers).
ckpt_path = hf_hub_download("Nharen/Reward_Rush_SAC_Half_Cheetah", "half_cheetah.pth")
ckpt = torch.load(ckpt_path, weights_only=True)

# Network dimensions are stored alongside the weights; hidden_dim falls
# back to 256 when the checkpoint omits it.
obs_dim, act_dim = ckpt["obs_dim"], ckpt["act_dim"]
hidden_dim = ckpt.get("hidden_dim", 256)
73
+ # SAC Gaussian Actor
74
# SAC Gaussian Actor
class SACActor(nn.Module):
    """Gaussian SAC actor; inference uses only the tanh-squashed mean head."""

    def __init__(self, obs_dim, act_dim, hidden_dim=256):
        super().__init__()
        # Trunk layer names (`net.0`, `net.2`) must match the checkpoint keys.
        trunk = [
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*trunk)
        # Policy heads: `log_std` is unused at inference but required so the
        # checkpoint's state dict loads without missing/unexpected keys.
        self.mean = nn.Linear(hidden_dim, act_dim)
        self.log_std = nn.Linear(hidden_dim, act_dim)

    def forward(self, obs):
        """Return a deterministic action in [-1, 1]: tanh of the mean head."""
        features = self.net(obs)
        return torch.tanh(self.mean(features))
90
 
91
# Rebuild the actor and load the stored weights for evaluation.
actor = SACActor(obs_dim, act_dim, hidden_dim)
actor.load_state_dict(ckpt["actor_state_dict"])
actor.eval()

# HalfCheetah evaluation setup.
env = gym.make("HalfCheetah-v4")
num_episodes = 100
episode_rewards = []

# Roll out full episodes with the deterministic policy.
for ep in range(num_episodes):
    obs, _ = env.reset()
    ep_reward = 0.0

    while True:
        with torch.no_grad():
            obs_t = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
            action = actor(obs_t).squeeze(0).cpu().numpy()
        obs, reward, terminated, truncated, _ = env.step(action)
        ep_reward += reward
        if terminated or truncated:
            break

    episode_rewards.append(ep_reward)
    print(f"Episode {ep+1:3d} | Reward: {ep_reward:.2f}")

env.close()

# Aggregate statistics over all evaluation episodes.
episode_rewards = np.array(episode_rewards)
print("\n===== Evaluation Summary =====")
print(f"Episodes run: {num_episodes}")
print(f"Mean reward: {episode_rewards.mean():.2f}")
print(f"Std reward: {episode_rewards.std():.2f}")
print(f"Min reward: {episode_rewards.min():.2f}")
print(f"Max reward: {episode_rewards.max():.2f}")
 
 
128
  ```