import torch
import torch.nn as nn


class STOPPolicy(nn.Module):
    """Two-way stop/continue classifier head over input features."""

    def __init__(self, feature_size, intermediate_size=512):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(feature_size, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, 2),
        )

    def forward(self, x, temperature=1.0):
        # Temperature-scale the two logits; temperature=1.0 is a no-op.
        return self.fc(x) / temperature


class LatentPolicy(nn.Module):
    """Diagonal-Gaussian policy over a latent feature vector.

    With deterministic=True the distribution collapses to a near point
    mass at the predicted mean; otherwise separate mean and log-variance
    heads parameterize the Gaussian. Both branches return a Normal.
    """

    def __init__(self, feature_size, intermediate_size=512, deterministic=False):
        super().__init__()
        self.deterministic = deterministic
        self.fc = nn.Sequential(
            nn.Linear(feature_size, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, intermediate_size),
            nn.LayerNorm(intermediate_size),
        )
        self.mean = nn.Linear(intermediate_size, feature_size)
        if not deterministic:
            self.log_var = nn.Linear(intermediate_size, feature_size)

    def forward(self, x, temperature=1.0):
        x = self.fc(x)
        mean = self.mean(x)
        if self.deterministic:
            # Near-zero std: sampling effectively returns the mean.
            return torch.distributions.Normal(mean, torch.ones_like(mean) * 1e-9)
        log_var = self.log_var(x)
        # std = exp(0.5 * log_var); temperature > 1 broadens sampling, < 1 sharpens it.
        std = torch.exp(0.5 * log_var) * temperature
        return torch.distributions.Normal(mean, std)
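

# A minimal usage sketch (assumed values: feature_size=256 and the batch
# size of 4 are illustrative choices, not taken from the original source).
if __name__ == "__main__":
    feature_size = 256
    x = torch.randn(4, feature_size)

    stop_policy = STOPPolicy(feature_size)
    stop_probs = torch.softmax(stop_policy(x, temperature=0.5), dim=-1)

    latent_policy = LatentPolicy(feature_size)
    dist = latent_policy(x)
    z = dist.rsample()  # reparameterized sample, keeps gradients flowing
    log_prob = dist.log_prob(z).sum(-1)

    print(stop_probs.shape, z.shape, log_prob.shape)  # (4, 2) (4, 256) (4,)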