# Uploaded by s1ghhh via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 91aec72 (verified).
import torch
import torch.nn as nn
class DeepSetClassifier(nn.Module):
    """Permutation-invariant binary classifier over variable-length sets.

    Each set element is embedded independently by ``phi``. The embeddings of
    the valid (non-padded) elements are mean-pooled, and the pooled vector is
    passed through ``rho`` to produce a single score per set.

    Args:
        hidden_dim: Width of the hidden layers in ``phi`` and ``rho``.
        input_dim: Feature size of each set element. Defaults to 1, which
            matches the original one-scalar-per-element behavior.
    """

    def __init__(self, hidden_dim=256, input_dim=1):
        super().__init__()
        # Per-element encoder: [..., input_dim] -> [..., hidden_dim].
        self.phi = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # Set-level head: pooled [B, hidden_dim] -> one logit per set.
        self.rho = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x, lengths):
        """Score a padded batch of sets.

        Args:
            x: Padded batch of shape [B, T, input_dim]. In row ``b``,
                positions with index ``>= lengths[b]`` are padding and are
                ignored.
            lengths: Tensor of shape [B] with the number of valid elements in
                each set.

        Returns:
            Tensor of shape [B] holding sigmoid probabilities in [0, 1].
        """
        phi_x = self.phi(x)  # [B, T, hidden_dim]
        # Build the mask directly on x's device, with no CPU round-trip.
        # Moving lengths there too lets CPU lengths be used with GPU inputs.
        positions = torch.arange(x.size(1), device=x.device)
        mask = positions.unsqueeze(0) < lengths.to(x.device).unsqueeze(1)  # [B, T]
        phi_x = phi_x * mask.unsqueeze(-1).to(phi_x.dtype)
        # Divide by the number of positions actually summed, not by raw
        # `lengths`. This keeps the mean correct when a length exceeds T.
        # The clamp stops empty sets (length 0) from producing 0/0 = NaN.
        counts = mask.sum(dim=1, keepdim=True).clamp(min=1).to(phi_x.dtype)  # [B, 1]
        agg = phi_x.sum(dim=1) / counts  # masked mean pooling, [B, hidden_dim]
        out = self.rho(agg)  # [B, 1]
        return torch.sigmoid(out).squeeze(-1)