PinHsuan commited on
Commit
3bf0f8d
·
verified ·
1 Parent(s): 445065e

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +28 -20
model.py CHANGED
@@ -2,34 +2,42 @@ import torch
2
  import torch.nn as nn
3
 
4
  class DualStreamTransformer(nn.Module):
5
- def __init__(self, n_feat1=25, n_feat2=12, d_model=32, num_classes=2):
6
  super(DualStreamTransformer, self).__init__()
 
7
  dim_ff = 128
 
8
  self.feat_tokenizers_1 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_feat1)])
9
  self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, d_model))
10
-
11
- encoder_layer_1 = nn.TransformerEncoderLayer(
12
- d_model=d_model,
13
- nhead=4,
14
- dim_feedforward=dim_ff,
15
- batch_first=True
16
- )
17
- self.encoder_1 = nn.TransformerEncoder(encoder_layer_1, num_layers=2)
18
  self.feat_tokenizers_2 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_feat2)])
19
  self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, d_model))
20
- encoder_layer_2 = nn.TransformerEncoderLayer(
21
- d_model=d_model,
22
- nhead=4,
23
- dim_feedforward=dim_ff,
24
- batch_first=True
25
- )
26
- self.encoder_2 = nn.TransformerEncoder(encoder_layer_2, num_layers=2)
27
- self.mlp_head = nn.Sequential(
28
- nn.Linear(d_model * 2, d_model),
29
- nn.ReLU(),
30
- nn.Linear(d_model, d_model)
31
  )
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  class ArcMarginProduct(nn.Module):
34
  def __init__(self, in_features, out_features, s=30.0, m=0.5):
35
  super(ArcMarginProduct, self).__init__()
 
2
  import torch.nn as nn
3
 
4
  class DualStreamTransformer(nn.Module):
5
+ def __init__(self, n_feat1=24, n_feat2=10, d_model=32, num_classes=2):
6
  super(DualStreamTransformer, self).__init__()
7
+
8
  dim_ff = 128
9
+
10
  self.feat_tokenizers_1 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_feat1)])
11
  self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, d_model))
12
+ encoder_layer_1 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=dim_ff, batch_first=True)
13
+ self.encoder_1 = nn.TransformerEncoder(encoder_layer_1, num_layers=2)
14
+
 
 
 
 
 
15
  self.feat_tokenizers_2 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_feat2)])
16
  self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, d_model))
17
+ encoder_layer_2 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=dim_ff, batch_first=True)
18
+ self.encoder_2 = nn.TransformerEncoder(encoder_layer_2, num_layers=2)
19
+
20
+
21
+ self.fusion = nn.Sequential(
22
+ nn.Linear(d_model * 2, d_model),
23
+ nn.ReLU()
 
 
 
 
24
  )
25
 
26
+ def forward(self, x1, x2):
27
+ tokens1 = [layer(x1[:, i].unsqueeze(1)) for i, layer in enumerate(self.feat_tokenizers_1)]
28
+ x1_emb = torch.stack(tokens1, dim=1)
29
+ x1_emb = torch.cat((self.cls_token_1.expand(x1.size(0), -1, -1), x1_emb), dim=1)
30
+ feat1 = self.encoder_1(x1_emb)[:, 0, :]
31
+
32
+
33
+ tokens2 = [layer(x2[:, i].unsqueeze(1)) for i, layer in enumerate(self.feat_tokenizers_2)]
34
+ x2_emb = torch.stack(tokens2, dim=1)
35
+ x2_emb = torch.cat((self.cls_token_2.expand(x2.size(0), -1, -1), x2_emb), dim=1)
36
+ feat2 = self.encoder_2(x2_emb)[:, 0, :]
37
+
38
+ combined = torch.cat((feat1, feat2), dim=1)
39
+ return self.fusion(combined)
40
+
41
  class ArcMarginProduct(nn.Module):
42
  def __init__(self, in_features, out_features, s=30.0, m=0.5):
43
  super(ArcMarginProduct, self).__init__()