# ATP-Latent-Master / modules/projector.py
# Last commit: "Upload folder using huggingface_hub" by zz1358m (7a92ec5, verified)
import torch
import torch.nn as nn
class STOPPolicy(nn.Module):
    """Small MLP head that maps a feature vector to two output logits.

    Architecture (stored as ``self.fc``)::

        Linear(feature_size, intermediate_size) -> GELU -> Linear(intermediate_size, 2)

    NOTE(review): by its name this presumably scores a binary stop /
    continue decision — confirm against the caller.
    """

    def __init__(self, feature_size, intermediate_size=512):
        super().__init__()
        layers = [
            nn.Linear(feature_size, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, 2),
        ]
        # A single nn.Sequential named ``fc`` keeps the state_dict keys
        # (fc.0.*, fc.2.*) identical, so existing checkpoints still load.
        self.fc = nn.Sequential(*layers)

    def forward(self, x, temperature=1.0):
        """Return the raw, unnormalised 2-way logits for ``x``.

        ``temperature`` is accepted for interface compatibility only; it is
        not applied, so the logits come back unscaled.
        """
        logits = self.fc(x)
        return logits
class LatentPolicy(nn.Module):
    """MLP head that maps a feature vector back into ``feature_size`` space.

    The trunk ``self.fc`` is ``Linear -> GELU -> Linear -> LayerNorm`` of
    width ``intermediate_size``. A ``mean`` projection maps the trunk output
    to ``feature_size``. A ``log_var`` projection of the same shape is
    created only when ``deterministic`` is False.
    """

    def __init__(self, feature_size, intermediate_size=512, deterministic=False):
        super().__init__()
        self.deterministic = deterministic
        width = intermediate_size
        self.fc = nn.Sequential(
            nn.Linear(feature_size, width),
            nn.GELU(),
            nn.Linear(width, width),
            nn.LayerNorm(width),
        )
        self.mean = nn.Linear(width, feature_size)
        # The variance head exists only for the stochastic variant; in
        # deterministic mode ``self.log_var`` is never registered.
        if not deterministic:
            self.log_var = nn.Linear(width, feature_size)

    def forward(self, x, temperature=1.0):
        """Project ``x`` through the trunk and the output head(s).

        Returns:
            * when ``self.deterministic`` is False: a ``(mean, log_var)``
              tuple of raw head outputs (no exponentiation is applied here);
            * when ``self.deterministic`` is True: a
              ``torch.distributions.Normal`` centred on ``mean`` with a fixed
              scale of 1e-9, i.e. effectively a point mass.

        NOTE(review): the return type differs between the two modes, so
        callers must branch on it. ``temperature`` is accepted but unused.
        """
        features = self.fc(x)
        mean = self.mean(features)
        if not self.deterministic:
            return mean, self.log_var(features)
        tiny_scale = torch.ones_like(mean) * 1e-9
        return torch.distributions.Normal(mean, tiny_scale)