yuwan0's picture
initial
0212735
"""
OurNet Model Definition with ConvNeXt backbone
This model is designed for image forgery detection using a ConvNeXt backbone
with dual projection heads and a detection head.
"""
import torch
import torch.nn as nn
import timm
class OurNet(nn.Module):
    """Forgery-detection network: timm backbone plus projection and detection heads.

    The backbone (``convnext_base`` unless configured otherwise) is built with
    ``pretrained=False`` and its classifier is replaced by ``nn.Identity``.
    Both forward passes pool ``backbone.forward_features(x)`` by averaging over
    its last two dimensions and feed the result to the heads:

    * ``aux_fc1`` / ``aux_fc2``: Linear-ReLU-Linear heads with 128 outputs,
      used by :meth:`forward_proj`.
    * ``det_fc1``: Linear-ReLU-Linear head with 128 outputs, and ``det_fc2``:
      Linear(256)-ReLU-Dropout(0.5)-Linear(1); both used by :meth:`forward_det`.

    Args:
        config: Optional mapping. ``config["backbone"]["name"]`` selects the
            timm model (default ``"convnext_base"``).
            ``config["backbone"]["n_features"]`` (default ``1024``) is only
            used as the feature width when it cannot be read from the backbone.
    """

    def __init__(self, config=None):
        super().__init__()
        # A "backbone" key that is present but None (e.g. an empty YAML
        # section) is treated like a missing one instead of crashing.
        backbone_cfg = (config or {}).get("backbone") or {}
        backbone_name = backbone_cfg.get("name", "convnext_base")
        n_features = backbone_cfg.get("n_features", 1024)

        self.backbone = timm.create_model(backbone_name, pretrained=False)
        # The heads consume pooled forward_features directly, so the
        # classifier module is swapped out for Identity.
        self.n_features = self._strip_classifier(self.backbone, n_features)

        # Projection heads
        self.aux_fc1 = self._embedding_head(self.n_features)
        self.aux_fc2 = self._embedding_head(self.n_features)
        # Detection heads
        self.det_fc1 = self._embedding_head(self.n_features)
        self.det_fc2 = nn.Sequential(
            nn.Linear(self.n_features, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
        )

    @staticmethod
    def _strip_classifier(backbone, fallback_features):
        """Replace the backbone's classifier with ``nn.Identity``; return its input width.

        The classifier is looked up as ``head``, then ``fc``, then
        ``classifier`` (the first attribute present wins). The width is the
        classifier's ``in_features``; if it has none (e.g. an ``nn.Sequential``
        head), ``backbone.num_features`` is used, and ``fallback_features`` if
        that attribute is missing as well.

        Raises:
            ValueError: if the backbone has none of the three attributes.
        """
        for attr in ("head", "fc", "classifier"):
            if hasattr(backbone, attr):
                break
        else:
            raise ValueError("Unsupported backbone architecture")

        classifier = getattr(backbone, attr)
        n_features = getattr(classifier, "in_features", None)
        if n_features is None:
            n_features = getattr(backbone, "num_features", fallback_features)
        setattr(backbone, attr, nn.Identity())
        return n_features

    @staticmethod
    def _embedding_head(in_features, out_features=128):
        """Return a Linear(in, in) -> ReLU -> Linear(in, out) MLP."""
        return nn.Sequential(
            nn.Linear(in_features, in_features),
            nn.ReLU(),
            nn.Linear(in_features, out_features),
        )

    def _pooled_features(self, x):
        """Run ``backbone.forward_features`` and average over the last two dims."""
        # NOTE(review): presumably forward_features returns (B, C, H, W), as
        # timm ConvNeXt does; a token-shaped (B, N, C) output would be pooled
        # over the wrong axes -- confirm before using other backbones.
        return self.backbone.forward_features(x).mean([-2, -1])

    def forward_det(self, x):
        """Detection forward pass.

        Returns:
            ``(homo_head, det_head)``: the ``det_fc1`` output (last dim 128)
            and the ``det_fc2`` output (last dim 1), both computed from the
            same pooled backbone features.
        """
        feats = self._pooled_features(x)
        return self.det_fc1(feats), self.det_fc2(feats)

    def forward_proj(self, x):
        """Projection forward pass.

        Returns:
            ``(heter_head, homo_head)``: the ``aux_fc1`` and ``aux_fc2``
            outputs (last dim 128 each), computed from the same pooled
            backbone features.
        """
        feats = self._pooled_features(x)
        return self.aux_fc1(feats), self.aux_fc2(feats)