Spaces:
Sleeping
Sleeping
Update model.py
Browse files
model.py
CHANGED
|
@@ -54,19 +54,17 @@ class ArcMarginProduct(nn.Module):
|
|
| 54 |
class DualStreamTransformer(nn.Module):
|
| 55 |
def __init__(self, feat_num_1, feat_num_2, d_model=64, num_classes=3, dropout=0.3):
|
| 56 |
super().__init__()
|
| 57 |
-
# Stream 1: CCMQ
|
| 58 |
self.feat_tokenizers_1 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(feat_num_1)])
|
| 59 |
enc_layer_1 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128, dropout=dropout, batch_first=True)
|
| 60 |
self.encoder_1 = nn.TransformerEncoder(enc_layer_1, num_layers=2)
|
| 61 |
self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, d_model))
|
| 62 |
|
| 63 |
-
# Stream 2: OSDI
|
| 64 |
self.feat_tokenizers_2 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(feat_num_2)])
|
| 65 |
enc_layer_2 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128, dropout=dropout, batch_first=True)
|
| 66 |
self.encoder_2 = nn.TransformerEncoder(enc_layer_2, num_layers=2)
|
| 67 |
self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, d_model))
|
| 68 |
|
| 69 |
-
|
| 70 |
self.fusion = nn.Sequential(
|
| 71 |
nn.Linear(d_model * 2, d_model),
|
| 72 |
nn.LayerNorm(d_model),
|
|
@@ -89,5 +87,5 @@ class DualStreamTransformer(nn.Module):
|
|
| 89 |
def forward(self, x1, x2):
|
| 90 |
feat_1 = self.forward_stream(x1, self.feat_tokenizers_1, self.encoder_1, self.cls_token_1)
|
| 91 |
feat_2 = self.forward_stream(x2, self.feat_tokenizers_2, self.encoder_2, self.cls_token_2)
|
| 92 |
-
combined = torch.cat((feat_1,
|
| 93 |
return self.fusion(combined)
|
|
|
|
| 54 |
class DualStreamTransformer(nn.Module):
|
| 55 |
def __init__(self, feat_num_1, feat_num_2, d_model=64, num_classes=3, dropout=0.3):
|
| 56 |
super().__init__()
|
|
|
|
| 57 |
self.feat_tokenizers_1 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(feat_num_1)])
|
| 58 |
enc_layer_1 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128, dropout=dropout, batch_first=True)
|
| 59 |
self.encoder_1 = nn.TransformerEncoder(enc_layer_1, num_layers=2)
|
| 60 |
self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, d_model))
|
| 61 |
|
|
|
|
| 62 |
self.feat_tokenizers_2 = nn.ModuleList([nn.Linear(1, d_model) for _ in range(feat_num_2)])
|
| 63 |
enc_layer_2 = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128, dropout=dropout, batch_first=True)
|
| 64 |
self.encoder_2 = nn.TransformerEncoder(enc_layer_2, num_layers=2)
|
| 65 |
self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, d_model))
|
| 66 |
|
| 67 |
+
|
| 68 |
self.fusion = nn.Sequential(
|
| 69 |
nn.Linear(d_model * 2, d_model),
|
| 70 |
nn.LayerNorm(d_model),
|
|
|
|
| 87 |
def forward(self, x1, x2):
|
| 88 |
feat_1 = self.forward_stream(x1, self.feat_tokenizers_1, self.encoder_1, self.cls_token_1)
|
| 89 |
feat_2 = self.forward_stream(x2, self.feat_tokenizers_2, self.encoder_2, self.cls_token_2)
|
| 90 |
+
combined = torch.cat((feat_1, feat_2), dim=1)
|
| 91 |
return self.fusion(combined)
|