"""
OurNet Model Definition with ConvNeXt backbone

This model is designed for image forgery detection using a ConvNeXt backbone
with dual projection heads and a detection head.
"""

import torch
import torch.nn as nn
import timm


class OurNet(nn.Module):
    """Backbone feature extractor with two projection heads and two detection heads.

    The backbone is created through ``timm.create_model`` with
    ``pretrained=False`` and its classification layer is replaced by
    ``nn.Identity``. Four small MLP heads are attached to the spatially
    averaged backbone features:

    * ``aux_fc1`` / ``aux_fc2`` -- 128-d projection heads used by
      :meth:`forward_proj`.
    * ``det_fc1`` -- 128-d projection head used by :meth:`forward_det`.
    * ``det_fc2`` -- single-output head (with dropout) used by
      :meth:`forward_det`.

    Args:
        config: Optional mapping. Only ``config["backbone"]["name"]`` (timm
            model name, default ``"convnext_base"``) and
            ``config["backbone"]["n_features"]`` (default ``1024``) are read.
            ``n_features`` is only a fallback: the width reported by the
            backbone itself takes precedence.

    Raises:
        ValueError: If the backbone exposes none of the ``head``, ``fc`` or
            ``classifier`` attributes.
    """

    # Attribute names under which the classification layer is looked up,
    # in priority order.
    _CLASSIFIER_ATTRS = ("head", "fc", "classifier")

    def __init__(self, config=None):
        super().__init__()

        backbone_cfg = {} if config is None else config.get("backbone", {})
        backbone_name = backbone_cfg.get("name", "convnext_base")
        configured_features = backbone_cfg.get("n_features", 1024)

        self.backbone = timm.create_model(backbone_name, pretrained=False)
        self.n_features = self._strip_classifier(self.backbone, configured_features)

        # Heads are built in the original order (aux_fc1, aux_fc2, det_fc1,
        # det_fc2) so that seeded parameter initialisation and state_dict
        # keys are unchanged.

        # Projection heads
        self.aux_fc1 = self._projection_head(self.n_features)
        self.aux_fc2 = self._projection_head(self.n_features)

        # Detection heads
        self.det_fc1 = self._projection_head(self.n_features)
        self.det_fc2 = nn.Sequential(
            nn.Linear(self.n_features, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
        )

    @classmethod
    def _strip_classifier(cls, backbone, fallback_features):
        """Replace the backbone's classifier with ``nn.Identity``; return its input width.

        The width is taken from the classifier's ``in_features`` when present.
        Some classifier modules are containers without that attribute, so the
        backbone's ``num_features`` and finally ``fallback_features`` are used
        instead of failing with ``AttributeError``.
        """
        for attr in cls._CLASSIFIER_ATTRS:
            if not hasattr(backbone, attr):
                continue
            classifier = getattr(backbone, attr)
            n_features = getattr(classifier, "in_features", None)
            if n_features is None:
                n_features = getattr(backbone, "num_features", fallback_features)
            setattr(backbone, attr, nn.Identity())
            return n_features
        raise ValueError("Unsupported backbone architecture")

    @staticmethod
    def _projection_head(n_features, out_dim=128):
        """Build the Linear -> ReLU -> Linear projection MLP shared by three heads."""
        return nn.Sequential(
            nn.Linear(n_features, n_features),
            nn.ReLU(),
            nn.Linear(n_features, out_dim),
        )

    def _pooled_features(self, x):
        """Run ``backbone.forward_features`` and average over the last two dims.

        NOTE(review): assumes the backbone returns a (N, C, H, W) feature map,
        so the result is (N, C) -- confirm for non-convolutional backbones.
        """
        feats = self.backbone.forward_features(x)
        return feats.mean([-2, -1])

    def forward_det(self, x):
        """Detection forward pass.

        Args:
            x: Input batch passed to the backbone.

        Returns:
            Tuple ``(homo_head, det_head)``: the ``det_fc1`` projection and
            the ``det_fc2`` single-output prediction, both computed from the
            same pooled features.
        """
        feats = self._pooled_features(x)
        homo_head = self.det_fc1(feats)
        det_head = self.det_fc2(feats)
        return homo_head, det_head

    def forward_proj(self, x):
        """Projection forward pass.

        Args:
            x: Input batch passed to the backbone.

        Returns:
            Tuple ``(heter_head, homo_head)``: the ``aux_fc1`` and ``aux_fc2``
            projections of the same pooled features.
        """
        feats = self._pooled_features(x)
        heter_head = self.aux_fc1(feats)
        homo_head = self.aux_fc2(feats)
        return heter_head, homo_head