| import torch | |
| import torch.nn as nn | |
class DeepSetClassifier(nn.Module):
    """Permutation-invariant classifier over variable-length, padded sets.

    Implements the Deep Sets decomposition ``rho(mean_i phi(x_i))``:
    - ``phi`` embeds every set element independently.
    - The embeddings of the valid (non-padded) elements are mean-pooled.
    - ``rho`` maps the pooled vector to a single logit, which is passed
      through a sigmoid.

    Args:
        hidden_dim: Width of the hidden layers in ``phi`` and ``rho`` and of
            the pooled set representation.
        input_dim: Number of features per set element. Defaults to 1, which
            matches the original scalar-element behaviour.
    """

    def __init__(self, hidden_dim: int = 256, input_dim: int = 1) -> None:
        super().__init__()
        # Element-wise encoder, applied independently to each set member.
        self.phi = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # Set-level decoder, applied to the pooled representation.
        self.rho = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid of the set-level logit for each set in the batch.

        Args:
            x: Padded batch of sets, shape ``[B, T, input_dim]``. Only the
                first ``lengths[b]`` rows of ``x[b]`` contribute. The padded
                rows may hold arbitrary values, even NaN or inf.
            lengths: Number of valid elements per set, shape ``[B]``. It is
                converted with ``torch.as_tensor`` and placed on ``x.device``.

        Returns:
            Values in ``(0, 1)``, shape ``[B]``. An empty set
            (``lengths[b] == 0``) is pooled to the zero vector instead of
            producing NaN.
        """
        lengths = torch.as_tensor(lengths, device=x.device)
        phi_x = self.phi(x)  # [B, T, D]

        # Build the validity mask directly on x's device (no CPU round-trip).
        positions = torch.arange(x.size(1), device=x.device)
        mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(-1)  # [B, T, 1]

        # Use masked_fill rather than multiplying by the mask, so non-finite
        # activations from padded rows cannot reach the sum (nan * 0 == nan).
        phi_x = phi_x.masked_fill(~mask, 0.0)

        # Mean pooling over the valid elements. The clamp makes empty sets
        # divide by 1 (a zero vector) instead of 0 (NaN).
        counts = lengths.clamp(min=1).unsqueeze(-1).to(phi_x.dtype)  # [B, 1]
        agg = phi_x.sum(dim=1) / counts  # [B, D]

        out = self.rho(agg)  # [B, 1]
        return torch.sigmoid(out).squeeze(-1)