"""PyTorch WM models covering ALL CallModes.

Every transition module maps a latent ``z`` (width ``LATENT_DIM``) and an
action ``a`` (width ``ACTION_DIM``) to a new latent (width ``LATENT_DIM``),
but each exposes a different call signature:

* ``TransitionPositional`` -- ``model(z, a)``, also callable as ``model(z=z, a=a)``.
* ``TransitionTuple``      -- ``model((z, a))``.
* ``TransitionConcat``     -- ``model(torch.cat([z, a], dim=-1))``.

All transition variants share one layer stack, so their state_dicts are
interchangeable. Importing or running this module writes freshly initialised
(untrained) ``state_dict`` files for every module into ``OUT_DIR``.
"""
import os

import torch
import torch.nn as nn

OBS_DIM = 105
ACTION_DIM = 8
LATENT_DIM = 512
HIDDEN_DIM = 256

OUT_DIR = "weights_torch_signatures"
os.makedirs(OUT_DIR, exist_ok=True)


def _make_transition_net() -> nn.Sequential:
    """Build the MLP shared by every transition CallMode.

    Input width is ``LATENT_DIM + ACTION_DIM`` (latent and action concatenated
    on the last axis); output width is ``LATENT_DIM``. Keeping one definition
    guarantees that all variants produce identical state_dict keys
    (``net.0.*`` / ``net.2.*``) when assigned to ``self.net``.
    """
    return nn.Sequential(
        nn.Linear(LATENT_DIM + ACTION_DIM, HIDDEN_DIM),
        nn.ReLU(),
        nn.Linear(HIDDEN_DIM, LATENT_DIM),
    )


# ---------------- Encoder ----------------
class Encoder(nn.Module):
    """Map an observation (last dim ``OBS_DIM``) to a latent (last dim ``LATENT_DIM``)."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(OBS_DIM, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, LATENT_DIM),
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


# ---------------- Positional / Kwargs ----------------
class TransitionPositional(nn.Module):
    """Transition called as ``model(z, a)`` or ``model(z=z, a=a)``."""

    def __init__(self) -> None:
        super().__init__()
        self.net = _make_transition_net()

    def forward(self, z: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([z, a], dim=-1))


# ---------------- Tuple ----------------
class TransitionTuple(nn.Module):
    """Transition called with a single 2-element sequence: ``model((z, a))``."""

    def __init__(self) -> None:
        super().__init__()
        self.net = _make_transition_net()

    def forward(self, inputs) -> torch.Tensor:
        # Unpacking raises ValueError unless exactly two items are supplied.
        z, a = inputs
        return self.net(torch.cat([z, a], dim=-1))


# ---------------- Concat ----------------
class TransitionConcat(nn.Module):
    """Transition called with latent and action already concatenated on the last axis."""

    def __init__(self) -> None:
        super().__init__()
        self.net = _make_transition_net()

    def forward(self, za: torch.Tensor) -> torch.Tensor:
        return self.net(za)


def save(model: nn.Module, name: str, out_dir: str = OUT_DIR) -> str:
    """Save ``model.state_dict()`` to ``<out_dir>/<name>.pth`` and return that path.

    The target directory is created if missing, so this works even when the
    import-time ``os.makedirs`` did not run for ``out_dir``.
    """
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"{name}.pth")
    torch.save(model.state_dict(), path)
    return path


save(Encoder(), "encoder")
save(TransitionPositional(), "transition_positional")
save(TransitionTuple(), "transition_tuple")
save(TransitionConcat(), "transition_concat")
print("✅ PyTorch models saved")