ShaswatRobotics commited on
Commit
ee04203
·
verified ·
1 Parent(s): f71cfcd

Upload random_torch.py

Browse files
Files changed (1) hide show
  1. ant/pwm/random_torch.py +78 -0
ant/pwm/random_torch.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch WM models covering ALL CallModes
3
+ """
4
+
5
+ import os
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ OBS_DIM = 105
10
+ ACTION_DIM = 8
11
+ LATENT_DIM = 32
12
+ HIDDEN_DIM = 256
13
+
14
+ OUT_DIR = "weights_torch_signatures"
15
+ os.makedirs(OUT_DIR, exist_ok=True)
16
+
17
+ # ---------------- Encoder ----------------
18
+ class Encoder(nn.Module):
19
+ def __init__(self):
20
+ super().__init__()
21
+ self.net = nn.Sequential(
22
+ nn.Linear(OBS_DIM, HIDDEN_DIM),
23
+ nn.ReLU(),
24
+ nn.Linear(HIDDEN_DIM, LATENT_DIM),
25
+ )
26
+
27
+ def forward(self, obs):
28
+ return self.net(obs)
29
+
30
+ # ---------------- Positional / Kwargs ----------------
31
+ class TransitionPositional(nn.Module):
32
+ def __init__(self):
33
+ super().__init__()
34
+ self.net = nn.Sequential(
35
+ nn.Linear(LATENT_DIM + ACTION_DIM, HIDDEN_DIM),
36
+ nn.ReLU(),
37
+ nn.Linear(HIDDEN_DIM, LATENT_DIM),
38
+ )
39
+
40
+ def forward(self, z, a):
41
+ return self.net(torch.cat([z, a], dim=-1))
42
+
43
+ # ---------------- Tuple ----------------
44
+ class TransitionTuple(nn.Module):
45
+ def __init__(self):
46
+ super().__init__()
47
+ self.net = nn.Sequential(
48
+ nn.Linear(LATENT_DIM + ACTION_DIM, HIDDEN_DIM),
49
+ nn.ReLU(),
50
+ nn.Linear(HIDDEN_DIM, LATENT_DIM),
51
+ )
52
+
53
+ def forward(self, inputs):
54
+ z, a = inputs
55
+ return self.net(torch.cat([z, a], dim=-1))
56
+
57
+ # ---------------- Concat ----------------
58
+ class TransitionConcat(nn.Module):
59
+ def __init__(self):
60
+ super().__init__()
61
+ self.net = nn.Sequential(
62
+ nn.Linear(LATENT_DIM + ACTION_DIM, HIDDEN_DIM),
63
+ nn.ReLU(),
64
+ nn.Linear(HIDDEN_DIM, LATENT_DIM),
65
+ )
66
+
67
+ def forward(self, za):
68
+ return self.net(za)
69
+
70
+ def save(model, name):
71
+ torch.save(model.state_dict(), f"{OUT_DIR}/{name}.pth")
72
+
73
+ save(Encoder(), "encoder")
74
+ save(TransitionPositional(), "transition_positional")
75
+ save(TransitionTuple(), "transition_tuple")
76
+ save(TransitionConcat(), "transition_concat")
77
+
78
+ print("✅ PyTorch models saved")