# yuwan0's picture
# initial
# 0212735
"""
OurNet Model Definition with DINOv2 backbone
Uses transformers.AutoModel to load DINOv2-base
"""
import torch
import torch.nn as nn
from transformers import AutoModel
class OurNet(nn.Module):
    """DINOv2-backed network with separate heads for two training stages.

    A pretrained ``facebook/dinov2-base`` backbone is loaded through
    ``transformers.AutoModel``. The first token of its ``last_hidden_state``
    (the CLS token) is fed to four heads:

    * ``aux_fc1`` / ``aux_fc2`` -- auxiliary projection heads (Stage 1),
      used by :meth:`forward_proj`.
    * ``det_fc1`` -- projection head (Stage 2), used by :meth:`forward_det`.
    * ``det_fc2`` -- head ending in ``nn.Linear(256, 1)`` (Stage 2), used by
      :meth:`forward_det`.

    No ``forward`` method is defined; callers are expected to invoke
    :meth:`forward_proj` or :meth:`forward_det` explicitly.

    Args:
        config: Optional mapping. Only the ``'proj_dim'`` key is read
            (default 768). ``None`` or an empty mapping selects the default.
    """

    # Hidden width of DINOv2-base; only used as a fallback when the loaded
    # backbone does not expose ``config.hidden_size``.
    _DEFAULT_FEATURES = 768

    def __init__(self, config=None):
        """Load the backbone and build the Stage 1 / Stage 2 heads."""
        super().__init__()

        # local_files_only=True: the checkpoint must already be available
        # locally; no download is attempted.
        self.backbone = AutoModel.from_pretrained(
            'facebook/dinov2-base',
            local_files_only=True,
        )

        # Take the feature width from the loaded backbone instead of trusting
        # a literal, so the heads always match the checkpoint that was loaded.
        backbone_cfg = getattr(self.backbone, 'config', None)
        self.n_features = getattr(
            backbone_cfg, 'hidden_size', self._DEFAULT_FEATURES
        )

        # Projection dimension from config (falsy config -> default).
        self.proj_dim = (config or {}).get('proj_dim', 768)

        # Heads are created in the same order as before so parameter
        # registration order (and RNG consumption at init) is unchanged.
        # Auxiliary projection heads (for Stage 1)
        self.aux_fc1 = self._projection_head(self.n_features, self.proj_dim)
        self.aux_fc2 = self._projection_head(self.n_features, self.proj_dim)

        # Detection heads (for Stage 2)
        self.det_fc1 = self._projection_head(self.n_features, self.proj_dim)
        self.det_fc2 = nn.Sequential(
            nn.Linear(self.n_features, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),  # only active in training mode
            nn.Linear(256, 1),
        )

    @staticmethod
    def _projection_head(in_dim, out_dim):
        """Return the Linear -> ReLU -> Linear head shared by three heads."""
        return nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, out_dim),
        )

    def _cls_features(self, x):
        """Run the backbone on ``x`` and return its CLS-token features.

        Index 0 along the second axis of ``last_hidden_state`` is selected,
        which drops that axis.
        """
        return self.backbone(x).last_hidden_state[:, 0]

    def forward_proj(self, x):
        """Forward pass for auxiliary projection (Stage 1).

        Args:
            x: Input passed unchanged to the backbone
                (presumably a batch of preprocessed images -- confirm
                against the caller).

        Returns:
            Tuple ``(heter_head, homo_head)``: the outputs of ``aux_fc1``
            and ``aux_fc2`` applied to the same CLS features.
        """
        feats = self._cls_features(x)
        return self.aux_fc1(feats), self.aux_fc2(feats)

    def forward_det(self, x):
        """Forward pass for detection (Stage 2).

        Args:
            x: Input passed unchanged to the backbone
                (presumably a batch of preprocessed images -- confirm
                against the caller).

        Returns:
            Tuple ``(homo_head, det_head)``: the outputs of ``det_fc1`` and
            ``det_fc2`` applied to the same CLS features. ``det_head`` has a
            trailing dimension of size 1.
        """
        feats = self._cls_features(x)
        return self.det_fc1(feats), self.det_fc2(feats)