PinHsuan commited on
Commit
13c0689
·
verified ·
1 Parent(s): d882bc5

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -4
model.py CHANGED
@@ -54,19 +54,17 @@ class ArcMarginProduct(nn.Module):
54
  class DualStreamTransformer(nn.Module):
55
  def __init__(self, feat_num_1, feat_num_2, d_model=64, num_classes=3, dropout=0.3):
56
  super().__init__()
57
- # Stream 1: CCMQ
58
  self.feat_tokenizers_1 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(feat_num_1)])
59
  enc_layer_1 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128, dropout=dropout, batch_first=True)
60
  self.encoder_1 = nn.TransformerEncoder(enc_layer_1, num_layers=2)
61
  self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, d_model))
62
 
63
- # Stream 2: OSDI
64
  self.feat_tokenizers_2 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(feat_num_2)])
65
  enc_layer_2 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128, dropout=dropout, batch_first=True)
66
  self.encoder_2 = nn.TransformerEncoder(enc_layer_2, num_layers=2)
67
  self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, d_model))
68
 
69
- # Fusion
70
  self.fusion = nn.Sequential(
71
  nn.Linear(d_model * 2, d_model),
72
  nn.LayerNorm(d_model),
@@ -89,5 +87,5 @@ class DualStreamTransformer(nn.Module):
89
  def forward(self, x1, x2):
90
  feat_1 = self.forward_stream(x1, self.feat_tokenizers_1, self.encoder_1, self.cls_token_1)
91
  feat_2 = self.forward_stream(x2, self.feat_tokenizers_2, self.encoder_2, self.cls_token_2)
92
- combined = torch.cat((feat_1, feat2), dim=1)
93
  return self.fusion(combined)
 
54
  class DualStreamTransformer(nn.Module):
55
  def __init__(self, feat_num_1, feat_num_2, d_model=64, num_classes=3, dropout=0.3):
56
  super().__init__()
 
57
  self.feat_tokenizers_1 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(feat_num_1)])
58
  enc_layer_1 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128, dropout=dropout, batch_first=True)
59
  self.encoder_1 = nn.TransformerEncoder(enc_layer_1, num_layers=2)
60
  self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, d_model))
61
 
 
62
  self.feat_tokenizers_2 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(feat_num_2)])
63
  enc_layer_2 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128, dropout=dropout, batch_first=True)
64
  self.encoder_2 = nn.TransformerEncoder(enc_layer_2, num_layers=2)
65
  self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, d_model))
66
 
67
+
68
  self.fusion = nn.Sequential(
69
  nn.Linear(d_model * 2, d_model),
70
  nn.LayerNorm(d_model),
 
87
  def forward(self, x1, x2):
88
  feat_1 = self.forward_stream(x1, self.feat_tokenizers_1, self.encoder_1, self.cls_token_1)
89
  feat_2 = self.forward_stream(x2, self.feat_tokenizers_2, self.encoder_2, self.cls_token_2)
90
+ combined = torch.cat((feat_1, feat_2), dim=1)
91
  return self.fusion(combined)