import torch
import torch.nn as nn


class DeepSetClassifier(nn.Module):
    """Permutation-invariant binary classifier over padded, variable-length sets.

    Each set element is embedded independently by ``phi``, the embeddings of
    the valid (non-padded) elements are mean-pooled, and ``rho`` maps the
    pooled vector to a single probability.

    Args:
        hidden_dim: Width of the hidden and embedding layers.
        input_dim: Number of features per set element (defaults to 1, the
            previously hard-coded value).
    """

    def __init__(self, hidden_dim: int = 256, input_dim: int = 1) -> None:
        super().__init__()
        # Per-element encoder, applied along the last dimension.
        self.phi = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # Set-level decoder, applied to the pooled embedding.
        self.rho = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Return one probability per set in the batch.

        Args:
            x: Padded batch of sets, shape ``[B, T, input_dim]``.
            lengths: Number of valid leading elements in each set, shape
                ``[B]``. Values are expected to lie in ``[0, T]``; a length
                of 0 pools to a zero vector instead of producing NaN.

        Returns:
            Tensor of shape ``[B]`` with values in ``(0, 1)``.
        """
        phi_x = self.phi(x)  # [B, T, D]

        # Build the validity mask directly on x's device; moving `lengths`
        # there as well avoids a device mismatch in the comparison.
        lengths = lengths.to(x.device)
        positions = torch.arange(x.size(1), device=x.device)
        valid = positions.unsqueeze(0) < lengths.unsqueeze(1)  # [B, T]

        # masked_fill rather than multiplication: NaN/Inf sitting in padded
        # slots would survive `* 0` and contaminate the pooled sum.
        phi_x = phi_x.masked_fill(~valid.unsqueeze(-1), 0.0)

        # Clamp so empty sets divide by 1 (0 / 1 = 0) instead of 0 / 0 = NaN,
        # and cast so float64 lengths cannot promote the features' dtype.
        counts = lengths.clamp(min=1).unsqueeze(-1).to(phi_x.dtype)  # [B, 1]
        agg = phi_x.sum(dim=1) / counts  # mean over valid elements, [B, D]

        out = self.rho(agg)  # [B, 1]
        return torch.sigmoid(out).squeeze(-1)