File size: 2,031 Bytes
ee04203
 
 
 
 
 
 
 
 
 
d2c0e19
ee04203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
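PyTorch WM models covering ALL CallModes.

Defines an Encoder plus one transition network per call mode (positional/kwargs,
tuple, concat) and saves each model's freshly initialized state_dict as
``<OUT_DIR>/<name>.pth``.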
"""

import os
import torch
import torch.nn as nn

# Feature sizes shared by all models in this file.
OBS_DIM = 105  # input width of Encoder
ACTION_DIM = 8  # width of the action input to each transition model
LATENT_DIM = 512  # latent width: Encoder output, transition input and output
HIDDEN_DIM = 256  # width of the single hidden layer in every MLP

# Destination for the saved state_dicts. It is created eagerly at import time;
# exist_ok makes re-running the script harmless.
OUT_DIR = "weights_torch_signatures"
os.makedirs(name=OUT_DIR, exist_ok=True)

# ---------------- Encoder ----------------
class Encoder(nn.Module):
    """Two-layer MLP that maps an observation vector to a latent vector.

    Architecture: ``Linear(obs_dim, hidden_dim) -> ReLU -> Linear(hidden_dim,
    latent_dim)``, held in ``self.net`` so the state_dict keys are ``net.0.*``
    and ``net.2.*``.

    Args:
        obs_dim: Input feature size; defaults to the module-level ``OBS_DIM``.
        hidden_dim: Hidden layer width; defaults to ``HIDDEN_DIM``.
        latent_dim: Output feature size; defaults to ``LATENT_DIM``.
    """

    def __init__(self, *, obs_dim=None, hidden_dim=None, latent_dim=None):
        super().__init__()
        # Resolve defaults at construction time rather than in the signature.
        # Omitting every argument reproduces the original fixed-size model.
        obs_dim = OBS_DIM if obs_dim is None else obs_dim
        hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        latent_dim = LATENT_DIM if latent_dim is None else latent_dim
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, obs):
        """Map ``obs`` (last dimension of size ``obs_dim``) to a latent of size ``latent_dim``."""
        return self.net(obs)

# ---------------- Positional / Kwargs ----------------
class TransitionPositional(nn.Module):
    """Latent transition MLP called as ``model(z, a)`` or ``model(z=z, a=a)``.

    It concatenates ``z`` and ``a`` along the last dimension, so both must
    share their leading dimensions. The result is mapped through
    ``Linear(latent_dim + action_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, latent_dim)``.

    Args:
        latent_dim: Size of ``z`` and of the output; defaults to ``LATENT_DIM``.
        action_dim: Size of ``a``; defaults to ``ACTION_DIM``.
        hidden_dim: Hidden layer width; defaults to ``HIDDEN_DIM``.
    """

    def __init__(self, *, latent_dim=None, action_dim=None, hidden_dim=None):
        super().__init__()
        # Resolve defaults at construction time rather than in the signature.
        # Omitting every argument reproduces the original fixed-size model.
        latent_dim = LATENT_DIM if latent_dim is None else latent_dim
        action_dim = ACTION_DIM if action_dim is None else action_dim
        hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, z, a):
        """Map latent ``z`` and action ``a`` to a new latent of size ``latent_dim``."""
        return self.net(torch.cat([z, a], dim=-1))

# ---------------- Tuple ----------------
class TransitionTuple(nn.Module):
    """Latent transition MLP called with a single pair: ``model((z, a))``.

    The pair is unpacked into ``z`` and ``a``, which are concatenated along
    the last dimension. The result is mapped through
    ``Linear(latent_dim + action_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, latent_dim)``.

    Args:
        latent_dim: Size of ``z`` and of the output; defaults to ``LATENT_DIM``.
        action_dim: Size of ``a``; defaults to ``ACTION_DIM``.
        hidden_dim: Hidden layer width; defaults to ``HIDDEN_DIM``.
    """

    def __init__(self, *, latent_dim=None, action_dim=None, hidden_dim=None):
        super().__init__()
        # Resolve defaults at construction time rather than in the signature.
        # Omitting every argument reproduces the original fixed-size model.
        latent_dim = LATENT_DIM if latent_dim is None else latent_dim
        action_dim = ACTION_DIM if action_dim is None else action_dim
        hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, inputs):
        """Map the 2-item iterable ``(z, a)`` to a new latent of size ``latent_dim``."""
        # Any 2-element iterable (tuple, list) unpacks here; other lengths
        # raise ValueError from the unpacking itself.
        z, a = inputs
        return self.net(torch.cat([z, a], dim=-1))

# ---------------- Concat ----------------
class TransitionConcat(nn.Module):
    """Latent transition MLP called with one pre-concatenated tensor: ``model(za)``.

    ``za`` must have a last dimension of size ``latent_dim + action_dim``. Its
    layout is presumably ``[z, a]`` in that order, mirroring the other
    transition classes (TODO confirm with callers). It is mapped through
    ``Linear(latent_dim + action_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, latent_dim)``.

    Args:
        latent_dim: Latent size (also the output size); defaults to ``LATENT_DIM``.
        action_dim: Action size; defaults to ``ACTION_DIM``.
        hidden_dim: Hidden layer width; defaults to ``HIDDEN_DIM``.
    """

    def __init__(self, *, latent_dim=None, action_dim=None, hidden_dim=None):
        super().__init__()
        # Resolve defaults at construction time rather than in the signature.
        # Omitting every argument reproduces the original fixed-size model.
        latent_dim = LATENT_DIM if latent_dim is None else latent_dim
        action_dim = ACTION_DIM if action_dim is None else action_dim
        hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, za):
        """Map the concatenated latent/action tensor ``za`` to a latent of size ``latent_dim``."""
        return self.net(za)

def save(model, name, out_dir=None):
    """Write ``model.state_dict()`` to ``<out_dir>/<name>.pth`` and return the path.

    Only the state_dict is stored, not the pickled module. Loading it back
    therefore means constructing the matching class and calling
    ``load_state_dict``.

    Args:
        model: The ``nn.Module`` whose state_dict is saved.
        name: File stem; ``.pth`` is appended.
        out_dir: Destination directory; defaults to the module-level ``OUT_DIR``.

    Returns:
        The path of the written file.
    """
    if out_dir is None:
        out_dir = OUT_DIR
    # Create the directory on demand so saving still works if it was removed
    # after import, or when a different out_dir is passed.
    os.makedirs(out_dir, exist_ok=True)
    # os.path.join rather than a hand-built "/" string keeps the path native.
    path = os.path.join(out_dir, f"{name}.pth")
    torch.save(model.state_dict(), path)
    return path

# Build each model (fresh random init) and persist its state_dict. The order
# is fixed because every construction draws from the global torch RNG.
for model_cls, file_stem in (
    (Encoder, "encoder"),
    (TransitionPositional, "transition_positional"),
    (TransitionTuple, "transition_tuple"),
    (TransitionConcat, "transition_concat"),
):
    save(model_cls(), file_stem)
# Drop the loop variables so the module namespace matches the unrolled form.
del model_cls, file_stem

print("✅ PyTorch models saved")