| """ |
| OurNet Model Definition with FasterViT backbone |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from fastervit import create_model |
|
|
|
|
class OurNet(nn.Module):
    """FasterViT backbone followed by auxiliary projection and detection heads.

    The backbone's ``forward_features`` output is reduced to a
    ``(batch, n_features)`` tensor (4-D outputs are averaged over their last
    two dimensions) and then fed to one of two pairs of heads:

    * ``forward_proj`` (Stage 1) uses ``aux_fc1`` / ``aux_fc2``.
    * ``forward_det`` (Stage 2) uses ``det_fc1`` / ``det_fc2``.

    Attributes:
        backbone: Model returned by ``fastervit.create_model``; built with
            ``pretrained=False``.
        n_features: Size of dimension 1 of the pooled backbone output,
            measured once at construction time with a dummy input.
        aux_fc1, aux_fc2, det_fc1: ``n_features -> n_features -> 128`` MLPs.
        det_fc2: ``n_features -> 256 -> 1`` MLP with dropout.
    """

    _DEFAULT_BACKBONE = "faster_vit_2_224"

    def __init__(self, config=None):
        """Build the backbone and the four heads.

        Args:
            config: Optional mapping. The backbone name is read from
                ``config["backbone"]["name"]``; a missing or ``None`` config,
                ``backbone`` section or ``name`` falls back to
                ``"faster_vit_2_224"``.
        """
        super().__init__()

        # `or {}` also covers a config whose "backbone" key is present but None.
        backbone_cfg = {} if config is None else (config.get("backbone") or {})
        backbone_name = backbone_cfg.get("name") or self._DEFAULT_BACKBONE

        self.backbone = create_model(backbone_name, pretrained=False)
        self.n_features = self._probe_feature_dim(self.backbone)

        # Creation order is kept (aux_fc1, aux_fc2, det_fc1, det_fc2) so the
        # registration order of the heads in the module stays the same.
        self.aux_fc1 = self._make_projection_head(self.n_features)
        self.aux_fc2 = self._make_projection_head(self.n_features)

        self.det_fc1 = self._make_projection_head(self.n_features)
        self.det_fc2 = nn.Sequential(
            nn.Linear(self.n_features, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
        )

    @staticmethod
    def _make_projection_head(in_dim, out_dim=128):
        """Return an ``in_dim -> in_dim -> out_dim`` MLP with one ReLU."""
        return nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, out_dim),
        )

    @staticmethod
    def _pool_features(feats):
        """Average a 4-D tensor over its last two dims; pass others through."""
        if feats.dim() == 4:
            feats = feats.mean(dim=(-2, -1))
        return feats

    @staticmethod
    def _probe_feature_dim(backbone, input_shape=(1, 3, 224, 224)):
        """Return dimension 1 of the pooled backbone output for a dummy input.

        The probe runs in eval mode: ``torch.no_grad()`` alone does not stop
        BatchNorm layers from updating their running statistics (or dropout
        from firing), so a training-mode probe would pollute the backbone's
        buffers with statistics of random noise. Every submodule's original
        train/eval flag is restored afterwards.
        """
        saved_modes = [(module, module.training) for module in backbone.modules()]
        backbone.eval()
        try:
            with torch.no_grad():
                probe = torch.randn(input_shape)
                pooled = OurNet._pool_features(backbone.forward_features(probe))
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
        return pooled.shape[1]

    def _extract_features(self, x):
        """Run the backbone on ``x`` and reduce the result via ``_pool_features``."""
        return self._pool_features(self.backbone.forward_features(x))

    def forward_det(self, x):
        """Forward pass for detection (Stage 2).

        Args:
            x: Input batch accepted by ``backbone.forward_features``.

        Returns:
            Tuple ``(homo_head, det_head)``: the outputs of ``det_fc1``
            (last dim 128) and ``det_fc2`` (last dim 1) on the pooled features.
        """
        feats = self._extract_features(x)
        return self.det_fc1(feats), self.det_fc2(feats)

    def forward_proj(self, x):
        """Forward pass for auxiliary projection (Stage 1).

        Args:
            x: Input batch accepted by ``backbone.forward_features``.

        Returns:
            Tuple ``(heter_head, homo_head)``: the outputs of ``aux_fc1`` and
            ``aux_fc2`` (both last dim 128) on the pooled features.
        """
        feats = self._extract_features(x)
        return self.aux_fc1(feats), self.aux_fc2(feats)
|
|