""" OurNet Model Definition with FasterViT backbone """ import torch import torch.nn as nn from fastervit import create_model class OurNet(nn.Module): def __init__(self, config=None): super().__init__() # Load config if config is None: backbone_name = "faster_vit_2_224" else: backbone_name = config.get("backbone", {}).get("name", "faster_vit_2_224") # Create FasterViT backbone (without pretrained weights) self.backbone = create_model(backbone_name, pretrained=False) # Dynamically get feature dimension using dummy input dummy_input = torch.randn(1, 3, 224, 224) with torch.no_grad(): dummy_feat = self.backbone.forward_features(dummy_input) # Handle 4D output [B, C, H, W] if len(dummy_feat.shape) == 4: dummy_feat = dummy_feat.mean([-2, -1]) self.n_features = dummy_feat.shape[1] # Auxiliary projection heads (for Stage 1) self.aux_fc1 = nn.Sequential( nn.Linear(self.n_features, self.n_features), nn.ReLU(), nn.Linear(self.n_features, 128), ) self.aux_fc2 = nn.Sequential( nn.Linear(self.n_features, self.n_features), nn.ReLU(), nn.Linear(self.n_features, 128), ) # Detection heads (for Stage 2) self.det_fc1 = nn.Sequential( nn.Linear(self.n_features, self.n_features), nn.ReLU(), nn.Linear(self.n_features, 128), ) self.det_fc2 = nn.Sequential( nn.Linear(self.n_features, 256), nn.ReLU(inplace=True), nn.Dropout(0.5), nn.Linear(256, 1), ) def forward_det(self, x): """Forward pass for detection (Stage 2)""" feats = self.backbone.forward_features(x) # Handle 4D output if len(feats.shape) == 4: feats = feats.mean([-2, -1]) homo_head = self.det_fc1(feats) det_head = self.det_fc2(feats) return homo_head, det_head def forward_proj(self, x): """Forward pass for auxiliary projection (Stage 1)""" feats = self.backbone.forward_features(x) # Handle 4D output if len(feats.shape) == 4: feats = feats.mean([-2, -1]) heter_head = self.aux_fc1(feats) homo_head = self.aux_fc2(feats) return heter_head, homo_head