"""
PyTorch WM models covering ALL CallModes
"""
import os
import torch
import torch.nn as nn
OBS_DIM = 105      # observation vector size
ACTION_DIM = 8     # action vector size
LATENT_DIM = 512   # latent state size
HIDDEN_DIM = 256   # hidden layer width
OUT_DIR = "weights_torch_signatures"
os.makedirs(OUT_DIR, exist_ok=True)
# ---------------- Encoder ----------------
class Encoder(nn.Module):
    """Maps a raw observation to a latent vector."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(OBS_DIM, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, LATENT_DIM),
        )

    def forward(self, obs):
        return self.net(obs)
# ---------------- Positional / Kwargs ----------------
class TransitionPositional(nn.Module):
    """Transition model called with positional args: model(z, a)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(LATENT_DIM + ACTION_DIM, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, LATENT_DIM),
        )

    def forward(self, z, a):
        return self.net(torch.cat([z, a], dim=-1))
# ---------------- Tuple ----------------
class TransitionTuple(nn.Module):
    """Transition model called with a single tuple: model((z, a))."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(LATENT_DIM + ACTION_DIM, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, LATENT_DIM),
        )

    def forward(self, inputs):
        z, a = inputs
        return self.net(torch.cat([z, a], dim=-1))
# ---------------- Concat ----------------
class TransitionConcat(nn.Module):
    """Transition model called with a pre-concatenated tensor: model(za)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(LATENT_DIM + ACTION_DIM, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, LATENT_DIM),
        )

    def forward(self, za):
        return self.net(za)
# ---------------- Save ----------------
def save(model, name):
    """Save a model's state dict under OUT_DIR."""
    torch.save(model.state_dict(), f"{OUT_DIR}/{name}.pth")

save(Encoder(), "encoder")
save(TransitionPositional(), "transition_positional")
save(TransitionTuple(), "transition_tuple")
save(TransitionConcat(), "transition_concat")
print("✅ PyTorch models saved")