ShaswatRobotics committed on
Commit
7ee5e5e
·
verified ·
1 Parent(s): 6d5221f

Upload 4 files

Browse files
ant/pwm_torch_seperate/encoder.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ae346f29cece02fefdc82310919c046f5c930a3eb165a5f521dfccbd1388154
3
+ size 407381
ant/pwm_torch_seperate/random_torch.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ generate_random_wm.py
3
+
4
+ Creates random (but valid) world-model networks:
5
+ - Encoder
6
+ - Transition model
7
+ - Reward model
8
+
9
+ Saves class-compatible state_dict weights.
10
+ """
11
+
12
+ import os
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
# -------------------------
# Config (Ant-style)
# -------------------------
OBS_DIM = 105      # input width of the encoder's first linear layer
ACTION_DIM = 8     # action width concatenated onto the latent vector
LATENT_DIM = 32    # encoder output width / transition output width
HIDDEN_DIM = 256   # width of every hidden linear layer
SEED = 42          # passed to torch.manual_seed in main()

OUT_DIR = "weights"  # directory the .pth state_dict files are written to
27
+
28
+
29
+ # -------------------------
30
+ # Models
31
+ # -------------------------
32
class Encoder(nn.Module):
    """Three-layer ReLU MLP mapping an observation vector to a latent vector.

    All layers are held in ``self.net`` (an ``nn.Sequential``), so the
    parameter names in the state_dict are ``net.0.*``, ``net.2.*`` and
    ``net.4.*``. Keep the layer order stable: saved weight files depend on it.
    """

    def __init__(self, obs_dim: int, latent_dim: int, hidden_dim: int) -> None:
        """Build the network.

        Args:
            obs_dim: Size of the last dimension of the input.
            latent_dim: Size of the last dimension of the output.
            hidden_dim: Width of the two hidden layers.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Return ``self.net(obs)``; the last dimension becomes ``latent_dim``."""
        return self.net(obs)
45
+
46
+
47
class TransitionModel(nn.Module):
    """Three-layer ReLU MLP over the concatenation of a latent and an action.

    The output has ``latent_dim`` features (presumably the predicted next
    latent state -- confirm against the training code). Layers live in
    ``self.net`` so state_dict keys are ``net.0.*``, ``net.2.*``, ``net.4.*``.
    """

    def __init__(self, latent_dim: int, action_dim: int, hidden_dim: int) -> None:
        """Build the network.

        Args:
            latent_dim: Size of the latent input and of the output.
            action_dim: Size of the action input.
            hidden_dim: Width of the two hidden layers.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, z: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Concatenate ``z`` and ``action`` on the last dim and run the MLP."""
        x = torch.cat([z, action], dim=-1)
        return self.net(x)
61
+
62
+
63
class RewardModel(nn.Module):
    """Two-layer ReLU MLP producing one scalar per (latent, action) pair.

    Layers live in ``self.net`` so state_dict keys are ``net.0.*`` and
    ``net.2.*``.
    """

    def __init__(self, latent_dim: int, action_dim: int, hidden_dim: int) -> None:
        """Build the network.

        Args:
            latent_dim: Size of the latent input.
            action_dim: Size of the action input.
            hidden_dim: Width of the single hidden layer.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, z: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Concatenate inputs on the last dim, run the MLP, drop the size-1 dim."""
        x = torch.cat([z, action], dim=-1)
        # The final Linear emits a trailing dimension of size 1; squeeze it
        # so the result has the same leading shape as the inputs.
        return self.net(x).squeeze(-1)
75
+
76
+
77
+ # -------------------------
78
+ # Initialization
79
+ # -------------------------
80
def init_weights(m: nn.Module) -> None:
    """Initialise an ``nn.Linear`` in place; ignore every other module type.

    Intended for use with ``Module.apply``. The weight gets an orthogonal
    initialisation and the bias, when present, is set to zero.

    Args:
        m: Any module; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.orthogonal_(m.weight)
    # A Linear built with bias=False has bias set to None, and
    # nn.init.zeros_(None) raises, so only touch the bias when it exists.
    if m.bias is not None:
        nn.init.zeros_(m.bias)
84
+
85
+
86
+ # -------------------------
87
+ # Main
88
+ # -------------------------
89
def main() -> None:
    """Create the three networks, re-initialise them, and save their weights.

    Seeds torch with ``SEED``, builds ``Encoder``, ``TransitionModel`` and
    ``RewardModel`` from the module-level dimensions, applies
    ``init_weights`` to each, then writes one state_dict per model into
    ``OUT_DIR`` (created if missing) and prints the written paths.
    """
    torch.manual_seed(SEED)

    # Construction order and the later apply() order both consume the global
    # RNG; keep encoder -> transition -> reward so a given seed keeps
    # producing the same weights.
    models = {
        "encoder": Encoder(OBS_DIM, LATENT_DIM, HIDDEN_DIM),
        "transition": TransitionModel(LATENT_DIM, ACTION_DIM, HIDDEN_DIM),
        "reward": RewardModel(LATENT_DIM, ACTION_DIM, HIDDEN_DIM),
    }
    for model in models.values():
        model.apply(init_weights)

    os.makedirs(OUT_DIR, exist_ok=True)

    # Build each output path once so the saved file and the printed report
    # cannot drift apart.
    paths = {name: f"{OUT_DIR}/{name}.pth" for name in models}
    for name, model in models.items():
        torch.save(model.state_dict(), paths[name])

    print("✅ Random world-model weights saved:")
    for path in paths.values():
        print(f" {path}")
110
+
111
+
112
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
ant/pwm_torch_seperate/reward.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad796e55dfaebcba46a0baaec334c8e996794aace6950df447e6168f735426ab
3
+ size 45331
ant/pwm_torch_seperate/transition.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd055e79288d6f6d3e10d83dab550a1dfb843261d44b7af029b6c7b5d4d4e5cd
3
+ size 341177