| """ |
| OurNet Model Definition with DINOv2 backbone |
| |
| Uses transformers.AutoModel to load DINOv2-base |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import AutoModel |
|
|
|
|
class OurNet(nn.Module):
    """Two-stage network on top of a Hugging Face DINOv2 backbone.

    The first token of the backbone's ``last_hidden_state`` (the CLS token
    for DINOv2) feeds four heads:

    * Stage 1 (``forward_proj``): projection MLPs ``aux_fc1`` and ``aux_fc2``.
    * Stage 2 (``forward_det``): projection MLP ``det_fc1`` and ``det_fc2``,
      a head producing a single output per sample.

    Args:
        config: Optional mapping. Recognised keys:
            ``'proj_dim'`` (int, default 768) -- output width of the three
            projection MLPs;
            ``'backbone'`` (str, default ``'facebook/dinov2-base'``) --
            model id or local path passed to ``AutoModel.from_pretrained``.
    """

    DEFAULT_BACKBONE = 'facebook/dinov2-base'
    DEFAULT_PROJ_DIM = 768

    def __init__(self, config=None):
        super().__init__()
        config = config or {}

        # local_files_only=True: the checkpoint must already be available
        # locally (HF cache or a path); nothing is downloaded.
        self.backbone = AutoModel.from_pretrained(
            config.get('backbone', self.DEFAULT_BACKBONE),
            local_files_only=True,
        )

        # Hidden width of the backbone (768 for dinov2-base). Taken from the
        # loaded model rather than hard-coded so the heads always match the
        # configured checkpoint.
        self.n_features = self.backbone.config.hidden_size

        self.proj_dim = config.get('proj_dim', self.DEFAULT_PROJ_DIM)

        # Heads are created in the same order and with the same attribute
        # names / layer indices as before, so existing state_dict
        # checkpoints load unchanged and seeded initialization is identical.

        # Stage 1 projection heads.
        self.aux_fc1 = self._mlp_head(self.n_features, self.proj_dim)
        self.aux_fc2 = self._mlp_head(self.n_features, self.proj_dim)

        # Stage 2 heads: projection + single-output detection head.
        self.det_fc1 = self._mlp_head(self.n_features, self.proj_dim)
        self.det_fc2 = nn.Sequential(
            nn.Linear(self.n_features, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
        )

    @staticmethod
    def _mlp_head(in_dim, out_dim):
        """Build ``Linear(in_dim, in_dim) -> ReLU -> Linear(in_dim, out_dim)``."""
        return nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, out_dim),
        )

    def _cls_features(self, x):
        """Run the backbone on ``x`` and return the first-token embedding.

        Assumes the usual ``(batch, seq_len, hidden)`` layout of
        ``last_hidden_state``, giving a ``(batch, n_features)`` result.
        Presumably ``x`` is a batch of preprocessed pixel values -- TODO
        confirm with callers.
        """
        return self.backbone(x).last_hidden_state[:, 0]

    # NOTE(review): forward_proj / forward_det are called directly instead
    # of through nn.Module.__call__, so forward hooks (and wrappers such as
    # DistributedDataParallel that act on forward) do not see these calls.

    def forward_proj(self, x):
        """Forward pass for auxiliary projection (Stage 1).

        Returns:
            ``(heter_head, homo_head)``: outputs of ``aux_fc1`` and
            ``aux_fc2`` applied to the shared backbone features.
        """
        feats = self._cls_features(x)
        heter_head = self.aux_fc1(feats)
        homo_head = self.aux_fc2(feats)
        return heter_head, homo_head

    def forward_det(self, x):
        """Forward pass for detection (Stage 2).

        Returns:
            ``(homo_head, det_head)``: outputs of ``det_fc1`` (projection)
            and ``det_fc2`` (one value per sample) applied to the shared
            backbone features.
        """
        feats = self._cls_features(x)
        homo_head = self.det_fc1(feats)
        det_head = self.det_fc2(feats)
        return homo_head, det_head