yuwan0's picture
initial
0212735
"""
OurNet Model Definition with FasterViT backbone
"""
import torch
import torch.nn as nn
from fastervit import create_model
class OurNet(nn.Module):
    """FasterViT backbone shared by two pairs of MLP heads.

    Stage 1 uses ``forward_proj`` (heads ``aux_fc1`` / ``aux_fc2``, each
    producing a 128-d vector). Stage 2 uses ``forward_det`` (``det_fc1``
    producing a 128-d vector and ``det_fc2`` producing one scalar per sample).

    Args:
        config: Optional mapping. ``config["backbone"]["name"]`` selects the
            model name passed to ``fastervit.create_model``. When the config,
            its ``"backbone"`` entry or the ``"name"`` key is missing, the
            default ``"faster_vit_2_224"`` is used.
    """

    DEFAULT_BACKBONE = "faster_vit_2_224"
    # Height/width of the random probe image used to discover the feature width.
    PROBE_SIZE = 224

    def __init__(self, config=None):
        super().__init__()
        # `or {}` also tolerates an explicit `"backbone": None` entry.
        backbone_cfg = (config or {}).get("backbone") or {}
        backbone_name = backbone_cfg.get("name", self.DEFAULT_BACKBONE)

        # Backbone is created without pretrained weights.
        self.backbone = create_model(backbone_name, pretrained=False)
        self.n_features = self._infer_feature_dim(self.backbone, self.PROBE_SIZE)
        n = self.n_features

        # Auxiliary projection heads (Stage 1).
        self.aux_fc1 = self._projection_head(n)
        self.aux_fc2 = self._projection_head(n)
        # Detection heads (Stage 2).
        self.det_fc1 = self._projection_head(n)
        self.det_fc2 = nn.Sequential(
            nn.Linear(n, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
        )

    @staticmethod
    def _projection_head(in_features, out_features=128):
        """Build the Linear -> ReLU -> Linear head shared by three branches."""
        return nn.Sequential(
            nn.Linear(in_features, in_features),
            nn.ReLU(),
            nn.Linear(in_features, out_features),
        )

    @staticmethod
    def _pool(feats):
        """Global-average-pool a 4-D ``[B, C, H, W]`` map to ``[B, C]``.

        Tensors of any other rank are returned unchanged.
        """
        if feats.dim() == 4:
            return feats.mean(dim=(-2, -1))
        return feats

    @classmethod
    def _infer_feature_dim(cls, backbone, input_size=224):
        """Return the channel width of ``backbone``'s (pooled) features.

        Pushes one random ``[1, 3, input_size, input_size]`` image through
        ``backbone.forward_features``. The backbone is put in eval mode for the
        probe because ``torch.no_grad()`` alone does not stop normalisation
        layers (e.g. BatchNorm, if present) from folding the random input into
        their running statistics. The original train/eval mode is restored
        afterwards, even if the probe raises.
        """
        was_training = backbone.training
        backbone.eval()
        try:
            probe = torch.randn(1, 3, input_size, input_size)
            with torch.no_grad():
                feats = cls._pool(backbone.forward_features(probe))
        finally:
            backbone.train(was_training)
        return feats.shape[1]

    def _extract(self, x):
        """Run the backbone on ``x`` and pool 4-D feature maps to ``[B, C]``."""
        return self._pool(self.backbone.forward_features(x))

    def forward_det(self, x):
        """Forward pass for detection (Stage 2).

        Returns:
            ``(homo_head, det_head)``: the outputs of ``det_fc1`` and
            ``det_fc2`` applied to the shared pooled features.
        """
        feats = self._extract(x)
        homo_head = self.det_fc1(feats)
        det_head = self.det_fc2(feats)
        return homo_head, det_head

    def forward_proj(self, x):
        """Forward pass for auxiliary projection (Stage 1).

        Returns:
            ``(heter_head, homo_head)``: the outputs of ``aux_fc1`` and
            ``aux_fc2`` applied to the shared pooled features.
        """
        feats = self._extract(x)
        heter_head = self.aux_fc1(feats)
        homo_head = self.aux_fc2(feats)
        return heter_head, homo_head