PinHsuan commited on
Commit
98cfc7c
·
verified ·
1 Parent(s): e694e57

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +90 -44
model.py CHANGED
@@ -1,49 +1,95 @@
1
  import torch
2
  import torch.nn as nn
 
 
3
 
4
- class DualStreamTransformer(nn.Module):
5
- def __init__(self, n_feat1=24, n_feat2=10, d_model=32, num_classes=2):
6
- super(DualStreamTransformer, self).__init__()
7
-
8
- dim_ff = 128
9
-
10
- self.feat_tokenizers_1 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_feat1)])
11
- self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, d_model))
12
- encoder_layer_1 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=dim_ff, batch_first=True)
13
- self.encoder_1 = nn.TransformerEncoder(encoder_layer_1, num_layers=2)
14
-
15
- self.feat_tokenizers_2 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_feat2)])
16
- self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, d_model))
17
- encoder_layer_2 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=dim_ff, batch_first=True)
18
- self.encoder_2 = nn.TransformerEncoder(encoder_layer_2, num_layers=2)
19
-
20
-
21
- self.fusion = nn.Sequential(
22
- nn.Linear(d_model * 2, d_model),
23
- nn.ReLU()
24
- )
25
-
26
- def forward(self, x1, x2):
27
- tokens1 = [layer(x1[:, i].unsqueeze(1)) for i, layer in enumerate(self.feat_tokenizers_1)]
28
- x1_emb = torch.stack(tokens1, dim=1)
29
- x1_emb = torch.cat((self.cls_token_1.expand(x1.size(0), -1, -1), x1_emb), dim=1)
30
- feat1 = self.encoder_1(x1_emb)[:, 0, :]
31
-
32
-
33
- tokens2 = [layer(x2[:, i].unsqueeze(1)) for i, layer in enumerate(self.feat_tokenizers_2)]
34
- x2_emb = torch.stack(tokens2, dim=1)
35
- x2_emb = torch.cat((self.cls_token_2.expand(x2.size(0), -1, -1), x2_emb), dim=1)
36
- feat2 = self.encoder_2(x2_emb)[:, 0, :]
37
-
38
- combined = torch.cat((feat1, feat2), dim=1)
39
- return self.fusion(combined)
40
 
41
  class ArcMarginProduct(nn.Module):
42
- def __init__(self, in_features, out_features, s=30.0, m=0.5):
43
- super(ArcMarginProduct, self).__init__()
44
- self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
45
- nn.init.xavier_uniform_(self.weight)
46
-
47
- def predict(self, x):
48
- cosine = torch.matmul(nn.functional.normalize(x), nn.functional.normalize(self.weight).t())
49
- return cosine
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
 
6
+
7
+ class FocalLoss(nn.Module):
8
+     def __init__(self, alpha=1, gamma=2, reduction='mean'):
9
+         super(FocalLoss, self).__init__()
10
+         self.alpha = alpha
11
+         self.gamma = gamma
12
+         self.reduction = reduction
13
+
14
+     def forward(self, inputs, targets):
15
+         ce_loss = F.cross_entropy(inputs, targets, reduction='none')
16
+         pt = torch.exp(-ce_loss)
17
+         focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
18
+         if self.reduction == 'mean': return focal_loss.mean()
19
+         return focal_loss.sum()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  class ArcMarginProduct(nn.Module):
22
+     def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
23
+         super(ArcMarginProduct, self).__init__()
24
+         self.in_features = in_features
25
+         self.out_features = out_features
26
+         self.s = s
27
+         self.m = m
28
+         self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
29
+         nn.init.xavier_uniform_(self.weight)
30
+
31
+         self.easy_margin = easy_margin
32
+         self.cos_m = math.cos(m)
33
+         self.sin_m = math.sin(m)
34
+         self.th = math.cos(math.pi - m)
35
+         self.mm = math.sin(math.pi - m) * m
36
+
37
+     def forward(self, input, label):
38
+         cosine = F.linear(F.normalize(input), F.normalize(self.weight))
39
+         sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
40
+         phi = cosine * self.cos_m - sine * self.sin_m
41
+         if self.easy_margin:
42
+             phi = torch.where(cosine > 0, phi, cosine)
43
+         else:
44
+             phi = torch.where(cosine > self.th, phi, cosine - self.mm)
45
+         
46
+         one_hot = torch.zeros(cosine.size(), device=input.device)
47
+         one_hot.scatter_(1, label.view(-1, 1).long(), 1)
48
+         output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
49
+         output *= self.s
50
+         return output
51
+
52
+     def predict(self, input):
53
+         return F.linear(F.normalize(input), F.normalize(self.weight)) * self.s
54
+
55
+
56
+ class DualStreamTransformer(nn.Module):
57
+     def __init__(self, feat_num_1, feat_num_2, d_model=64, num_classes=3, dropout=0.3):
58
+         super().__init__()
59
+         # Stream 1: CCMQ
60
+         self.feat_tokenizers_1 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(feat_num_1)])
61
+         enc_layer_1 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128, dropout=dropout, batch_first=True)
62
+         self.encoder_1 = nn.TransformerEncoder(enc_layer_1, num_layers=2)
63
+         self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, d_model))
64
+         
65
+         # Stream 2: OSDI
66
+         self.feat_tokenizers_2 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(feat_num_2)])
67
+         enc_layer_2 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128, dropout=dropout, batch_first=True)
68
+         self.encoder_2 = nn.TransformerEncoder(enc_layer_2, num_layers=2)
69
+         self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, d_model))
70
+
71
+         # Fusion
72
+         self.fusion = nn.Sequential(
73
+             nn.Linear(d_model * 2, d_model),
74
+             nn.LayerNorm(d_model),
75
+             nn.ReLU(),
76
+             nn.Dropout(dropout)
77
+         )
78
+
79
+     def forward_stream(self, x, tokenizers, encoder, cls_token):
80
+         batch_size = x.size(0)
81
+         tokens = []
82
+         for i, tokenizer in enumerate(tokenizers):
83
+             val = x[:, i].unsqueeze(1)
84
+             tokens.append(tokenizer(val))
85
+         x_emb = torch.stack(tokens, dim=1)
86
+         cls_tokens = cls_token.expand(batch_size, -1, -1)
87
+         x_emb = torch.cat((cls_tokens, x_emb), dim=1)
88
+         x_out = encoder(x_emb)
89
+         return x_out[:, 0, :] 
90
+
91
+     def forward(self, x1, x2):
92
+         feat_1 = self.forward_stream(x1, self.feat_tokenizers_1, self.encoder_1, self.cls_token_1)
93
+         feat_2 = self.forward_stream(x2, self.feat_tokenizers_2, self.encoder_2, self.cls_token_2)
94
+         combined = torch.cat((feat_1, feat_2), dim=1)
95
+         return self.fusion(combined)