PinHsuan commited on
Commit
c39c656
·
verified ·
1 Parent(s): 6105886

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +17 -29
model.py CHANGED
@@ -4,43 +4,32 @@ import torch.nn as nn
4
  class DualStreamTransformer(nn.Module):
5
  def __init__(self, n_feat1=25, n_feat2=12, d_model=32, num_classes=2):
6
  super(DualStreamTransformer, self).__init__()
7
-
8
- # Stream 1: CCMQ Tokenizer & Encoder
9
  self.feat_tokenizers_1 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_feat1)])
10
  self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, d_model))
11
- encoder_layer_1 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
12
- self.encoder_1 = nn.TransformerEncoder(encoder_layer_1, num_layers=3)
13
-
14
- # Stream 2: OSDI Tokenizer & Encoder
 
 
 
 
15
  self.feat_tokenizers_2 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_feat2)])
16
  self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, d_model))
17
- encoder_layer_2 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
18
- self.encoder_2 = nn.TransformerEncoder(encoder_layer_2, num_layers=3)
19
-
20
- # Fusion 層
 
 
 
21
  self.mlp_head = nn.Sequential(
22
  nn.Linear(d_model * 2, d_model),
23
  nn.ReLU(),
24
- nn.Linear(d_model, d_model) # 輸出 Embedding 給 ArcMargin
25
  )
26
 
27
- def forward(self, x1, x2):
28
- # Stream 1 推論
29
- tokens1 = [layer(x1[:, i].unsqueeze(1)) for i, layer in enumerate(self.feat_tokenizers_1)]
30
- x1_emb = torch.stack(tokens1, dim=1)
31
- x1_emb = torch.cat((self.cls_token_1.expand(x1.size(0), -1, -1), x1_emb), dim=1)
32
- feat1 = self.encoder_1(x1_emb)[:, 0, :] # 取 CLS token
33
-
34
- # Stream 2 推論
35
- tokens2 = [layer(x2[:, i].unsqueeze(1)) for i, layer in enumerate(self.feat_tokenizers_2)]
36
- x2_emb = torch.stack(tokens2, dim=1)
37
- x2_emb = torch.cat((self.cls_token_2.expand(x2.size(0), -1, -1), x2_emb), dim=1)
38
- feat2 = self.encoder_2(x2_emb)[:, 0, :] # 取 CLS token
39
-
40
- # 特徵融合
41
- combined = torch.cat((feat1, feat2), dim=1)
42
- return self.mlp_head(combined)
43
-
44
  class ArcMarginProduct(nn.Module):
45
  def __init__(self, in_features, out_features, s=30.0, m=0.5):
46
  super(ArcMarginProduct, self).__init__()
@@ -48,6 +37,5 @@ class ArcMarginProduct(nn.Module):
48
  nn.init.xavier_uniform_(self.weight)
49
 
50
  def predict(self, x):
51
- # 推論時直接做線性映射或餘弦相似度
52
  cosine = torch.matmul(nn.functional.normalize(x), nn.functional.normalize(self.weight).t())
53
  return cosine
 
4
  class DualStreamTransformer(nn.Module):
5
  def __init__(self, n_feat1=25, n_feat2=12, d_model=32, num_classes=2):
6
  super(DualStreamTransformer, self).__init__()
7
+ dim_ff = 128
 
8
  self.feat_tokenizers_1 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_feat1)])
9
  self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, d_model))
10
+
11
+ encoder_layer_1 = nn.TransformerEncoderLayer(
12
+ d_model=d_model,
13
+ nhead=4,
14
+ dim_feedforward=dim_ff,
15
+ batch_first=True
16
+ )
17
+ self.encoder_1 = nn.TransformerEncoder(encoder_layer_1, num_layers=2)
18
  self.feat_tokenizers_2 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(n_feat2)])
19
  self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, d_model))
20
+ encoder_layer_2 = nn.TransformerEncoderLayer(
21
+ d_model=d_model,
22
+ nhead=4,
23
+ dim_feedforward=dim_ff,
24
+ batch_first=True
25
+ )
26
+ self.encoder_2 = nn.TransformerEncoder(encoder_layer_2, num_layers=2)
27
  self.mlp_head = nn.Sequential(
28
  nn.Linear(d_model * 2, d_model),
29
  nn.ReLU(),
30
+ nn.Linear(d_model, d_model)
31
  )
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  class ArcMarginProduct(nn.Module):
34
  def __init__(self, in_features, out_features, s=30.0, m=0.5):
35
  super(ArcMarginProduct, self).__init__()
 
37
  nn.init.xavier_uniform_(self.weight)
38
 
39
  def predict(self, x):
 
40
  cosine = torch.matmul(nn.functional.normalize(x), nn.functional.normalize(self.weight).t())
41
  return cosine