Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
class FocalLoss(nn.Module):
    """Focal loss built on top of multi-class cross-entropy.

    Each per-sample cross-entropy term ``CE`` is scaled by
    ``alpha * (1 - p_t) ** gamma``, where ``p_t = exp(-CE)`` is the softmax
    probability the model assigns to the target class. Well-classified
    samples (``p_t`` close to 1) are therefore down-weighted.

    Args:
        alpha: Scalar multiplier applied to every per-sample loss.
        gamma: Focusing exponent; ``gamma=0`` reduces to ``alpha * CE``.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``
            (return the unreduced per-sample losses).
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: Unnormalised logits, as accepted by ``F.cross_entropy``.
            targets: Class indices, as accepted by ``F.cross_entropy``.

        Returns:
            A scalar for ``'mean'``/``'sum'``; for ``'none'`` a tensor shaped
            like the unreduced cross-entropy.

        Raises:
            ValueError: If ``self.reduction`` is not a supported mode.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) recovers the softmax probability of the target class.
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        if self.reduction == 'none':
            return focal_loss
        raise ValueError(
            f"Unsupported reduction {self.reduction!r}; "
            "expected 'mean', 'sum' or 'none'"
        )
class ArcMarginProduct(nn.Module):
    """Additive angular margin (ArcFace-style) classification head.

    Logits are cosine similarities between L2-normalised inputs and
    L2-normalised class weights, multiplied by ``s``. In ``forward`` the
    target class's cosine is replaced by ``cos(theta + m)`` before scaling.

    Args:
        in_features: Size of each input embedding (columns of ``weight``).
        out_features: Number of classes (rows of ``weight``).
        s: Scale applied to the (margin-adjusted) cosines.
        m: Additive angular margin, in radians.
        easy_margin: If True, apply the margin only where ``cos(theta) > 0``.
            Otherwise fall back to ``cos(theta) - sin(pi - m) * m`` wherever
            ``theta + m`` would exceed ``pi``.
        eps: Lower bound on ``1 - cos^2`` before the square root. This keeps
            the backward pass finite when a cosine is exactly +-1.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False, eps=1e-7):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.eps = eps
        # Non-legacy equivalent of torch.FloatTensor(out, in): float32 and
        # uninitialised; values are set by the Xavier init right below.
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=torch.float32))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # cosine > cos(pi - m)  <=>  theta + m < pi (for theta in [0, pi]).
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Return margin-adjusted logits scaled by ``s``.

        Args:
            input: Embeddings; assumed (batch, in_features) — TODO confirm.
            label: Target class index for each row of ``input``.

        Returns:
            Tensor with one column per class. The target column holds
            ``s * phi`` and every other column holds ``s * cosine``.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # sqrt has an infinite derivative at 0. Without the eps floor, a cosine
        # of exactly +-1 produces 0 * inf = NaN gradients, even for entries
        # that are later masked out by where()/one_hot.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(self.eps, 1))
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = torch.zeros(cosine.size(), device=input.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # Margin-adjusted value on the target column, plain cosine elsewhere.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

    def predict(self, input):
        """Return ``s``-scaled cosine logits with no margin (inference path)."""
        return F.linear(F.normalize(input), F.normalize(self.weight)) * self.s
| class DualStreamTransformer(nn.Module): | |
| def __init__(self, feat_num_1, feat_num_2, d_model=64, num_classes=3, dropout=0.3): | |
| super().__init__() | |
| self.feat_tokenizers_1 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(feat_num_1)]) | |
| enc_layer_1 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128, dropout=dropout, batch_first=True) | |
| self.encoder_1 = nn.TransformerEncoder(enc_layer_1, num_layers=2) | |
| self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, d_model)) | |
| self.feat_tokenizers_2 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(feat_num_2)]) | |
| enc_layer_2 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128, dropout=dropout, batch_first=True) | |
| self.encoder_2 = nn.TransformerEncoder(enc_layer_2, num_layers=2) | |
| self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, d_model)) | |
| self.fusion = nn.Sequential( | |
| nn.Linear(d_model * 2, d_model), | |
| nn.LayerNorm(d_model), | |
| nn.ReLU(), | |
| nn.Dropout(dropout) | |
| ) | |
| def forward_stream(self, x, tokenizers, encoder, cls_token): | |
| batch_size = x.size(0) | |
| tokens = [] | |
| for i, tokenizer in enumerate(tokenizers): | |
| val = x[:, i].unsqueeze(1) | |
| tokens.append(tokenizer(val)) | |
| x_emb = torch.stack(tokens, dim=1) | |
| cls_tokens = cls_token.expand(batch_size, -1, -1) | |
| x_emb = torch.cat((cls_tokens, x_emb), dim=1) | |
| x_out = encoder(x_emb) | |
| return x_out[:, 0, :] | |
| def forward(self, x1, x2): | |
| feat_1 = self.forward_stream(x1, self.feat_tokenizers_1, self.encoder_1, self.cls_token_1) | |
| feat_2 = self.forward_stream(x2, self.feat_tokenizers_2, self.encoder_2, self.cls_token_2) | |
| combined = torch.cat((feat_1, feat_2), dim=1) | |
| return self.fusion(combined) |