# TheKernel01's picture
# Sync from GitHub via hub-sync
# 0788e19 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from transformers import AutoModel
class AttentionPooling(nn.Module):
    """Collapse a token sequence into one vector with a single learned query.

    The input is unpacked as a 3-D tensor ``(batch, tokens, dim)``. Keys and
    values are linear projections of the input; one trainable query vector
    (shared across the batch) scores every token, the scores are
    softmax-normalised over the token axis, and the weighted sum of the
    values is returned with shape ``(batch, dim)``.
    """

    def __init__(self, dim):
        super().__init__()
        # Creation order (query, key, value) is kept so seeded initialisation
        # yields the same parameters.
        self.query = nn.Parameter(torch.randn(1, 1, dim))
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, _, width = x.shape

        keys = self.key(x)
        values = self.value(x)

        # Broadcast the single learned query to every sample in the batch.
        learned_query = self.query.expand(n_batch, -1, -1)

        # (batch, 1, tokens) scaled dot-product scores.
        scores = torch.matmul(learned_query, keys.transpose(-2, -1)) / (width**0.5)
        weights = self.softmax(scores)

        pooled = torch.matmul(weights, values)  # (batch, 1, dim)
        return pooled.squeeze(1)
class ForensicBranch(nn.Module):
    """Forensic Branch — captures high-frequency artifacts, initialized with SRM filters.

    A grouped 5x5 convolution (five filters per RGB channel, 15 outputs in
    total) is seeded with fixed high-pass kernels and applied to the residual
    between the image and a 3x3 box-blurred copy of it. A small
    conv/BatchNorm/GELU encoder followed by global average pooling turns the
    filter responses into a ``(batch, out_dim)`` feature vector.
    """

    @staticmethod
    def _srm_kernels():
        """Return the five 5x5 seed kernels as a ``(5, 5, 5)`` float32 tensor."""
        return torch.tensor(
            [
                [
                    [0, 0, -1, 0, 0],
                    [0, -1, -2, -1, 0],
                    [-1, -2, 16, -2, -1],
                    [0, -1, -2, -1, 0],
                    [0, 0, -1, 0, 0],
                ],
                [
                    [0, 0, 0, 0, 0],
                    [0, 1, -2, 1, 0],
                    [0, -2, 4, -2, 0],
                    [0, 1, -2, 1, 0],
                    [0, 0, 0, 0, 0],
                ],
                [
                    [0, 0, 0, 0, 0],
                    [0, -1, 2, -1, 0],
                    [0, 2, -4, 2, 0],
                    [0, -1, 2, -1, 0],
                    [0, 0, 0, 0, 0],
                ],
                [
                    [0, 0, -1, 0, 0],
                    [0, 0, 2, 0, 0],
                    [0, 0, -1, 0, 0],
                    [0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0],
                ],
                [
                    [0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0],
                    [-1, 2, -1, 0, 0],
                    [0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0],
                ],
            ],
            dtype=torch.float32,
        )

    @staticmethod
    def _stage(in_channels, out_channels, stride):
        """Return one 3x3 conv -> BatchNorm -> GELU stage as a list of modules."""
        return [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        ]

    def __init__(self, out_dim=256):
        super().__init__()

        # groups=3 gives each RGB channel its own bank of 5 filters
        # (weight shape 15 x 1 x 5 x 5).
        self.conv1 = nn.Conv2d(3, 15, kernel_size=5, padding=2, bias=False, groups=3)
        bank = self._srm_kernels()
        # Tile the 5-kernel bank three times so every input channel receives
        # the same five filters.
        self.conv1.weight.data = bank[:, None].repeat(3, 1, 1, 1)

        # First stage keeps resolution; the remaining three each halve it.
        plan = (
            (15, 32, 1),
            (32, 64, 2),
            (64, 128, 2),
            (128, out_dim, 2),
        )
        modules = []
        for in_channels, out_channels, stride in plan:
            modules.extend(self._stage(in_channels, out_channels, stride))
        self.encoder = nn.Sequential(*modules)

        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        # High-frequency residual: the image minus its 3x3 box-blurred copy.
        smoothed = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        responses = self.conv1(x - smoothed)
        encoded = self.encoder(responses)
        return self.pool(encoded).flatten(1)
def _resolve_attr_path(module, attr_path):
current = module
for attr in attr_path.split('.'):
if not hasattr(current, attr):
return None
current = getattr(current, attr)
return current
def _find_transformer_layers(backbone):
    """Locate the stack of transformer blocks inside ``backbone``.

    A fixed list of dotted attribute paths is tried in order; the first one
    that resolves to a non-empty ``nn.ModuleList``, list or tuple wins. If
    none matches, the first ``nn.ModuleList`` with more than four entries
    found while walking the module tree is used instead.

    Returns:
        A plain ``list`` of the blocks, or ``[]`` when nothing suitable exists.
    """
    known_locations = (
        'encoder.layer',
        'layers',
        'blocks',
        'transformer.layer',
        'transformer.layers',
        'model.encoder.layer',
        'model.layers',
        'backbone.encoder.layer',
        'backbone.layers',
        'backbone.blocks',
    )
    container_types = (nn.ModuleList, list, tuple)

    for location in known_locations:
        found = _resolve_attr_path(backbone, location)
        if not isinstance(found, container_types):
            continue
        if len(found) > 0:
            return list(found)

    # Heuristic fallback: the first reasonably long ModuleList in the tree.
    return next(
        (
            list(submodule)
            for submodule in backbone.modules()
            if isinstance(submodule, nn.ModuleList) and len(submodule) > 4
        ),
        [],
    )
class DeForge_AI_Model(nn.Module):
    """Single-logit image classifier fusing a LoRA-adapted backbone with a forensic CNN.

    Two feature streams are computed from the same input batch:

    * RGB stream: the pretrained backbone's ``last_hidden_state`` is split
      into its first token and the remaining tokens; the remaining tokens are
      attention-pooled, concatenated with the first token and projected back
      to ``hidden_size``.
    * Forensic stream: ``ForensicBranch`` features, scaled by a learnable
      scalar gate (initialised to 0.3).

    Both are concatenated and passed through an MLP head that emits one logit
    per sample. The head's last layer is zero-initialised, so a freshly built
    model outputs logit 0 (probability 0.5) for every input.

    Args:
        model_name: Hub identifier handed to ``AutoModel.from_pretrained``.
        lora_r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        lora_dropout: Dropout applied inside the LoRA adapters.
        lora_target_modules: Names of the modules to adapt. ``None`` selects
            ``q_proj, k_proj, v_proj, out_proj, fc1, fc2``.
        forensic_dim: Width of the forensic feature vector.
        unfreeze_last_blocks: Number of trailing transformer blocks (plus the
            backbone's final norm layer, when found) to make fully trainable.
            ``0`` keeps the base weights frozen.
        image_size: Stored on the instance; not otherwise used by the code in
            this class.
        hf_token: Optional access token forwarded to ``from_pretrained``.
    """

    def __init__(
        self,
        model_name='facebook/dinov3-vitl16-pretrain-lvd1689m',
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.5,
        lora_target_modules=None,
        forensic_dim=256,
        unfreeze_last_blocks=0,
        image_size=256,
        hf_token=None,
    ):
        super().__init__()
        if lora_target_modules is None:
            # NOTE(review): these names must match the Linear layers of the
            # chosen backbone; names that do not occur are simply not adapted.
            # Confirm against the checkpoint in use.
            lora_target_modules = [
                'q_proj',
                'k_proj',
                'v_proj',
                'out_proj',
                'fc1',
                'fc2',
            ]
        self.image_size = image_size
        self.forensic_dim = forensic_dim

        # NOTE(security): trust_remote_code=True lets the checkpoint repository
        # run its own Python code at load time; only use trusted model names.
        backbone = AutoModel.from_pretrained(
            model_name, trust_remote_code=True, token=hf_token
        )
        backbone.requires_grad_(False)

        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias='none',
        )
        self.backbone = get_peft_model(backbone, lora_config)

        # Unfreeze only AFTER the LoRA wrap: injecting the adapters marks every
        # non-adapter parameter as frozen, which would silently undo an
        # earlier unfreeze. get_peft_model modifies ``backbone`` in place, so
        # the original reference still addresses the wrapped model's blocks.
        self._unfreeze_last_blocks(backbone, unfreeze_last_blocks)

        hidden_size = self.backbone.config.hidden_size
        self.attn_pool = AttentionPooling(hidden_size)
        # Fuses [first token ; attention-pooled tokens] back down to hidden_size.
        self.rgb_proj = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(0.1),
        )
        self.forensic_branch = ForensicBranch(out_dim=forensic_dim)
        # Learnable scalar controlling how strongly forensic features contribute.
        self.forensic_gate = nn.Parameter(torch.tensor(0.3))
        self.head = nn.Sequential(
            nn.Linear(hidden_size + forensic_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Linear(128, 1),
        )
        # Zero-init the output layer so training starts from logit 0.
        nn.init.zeros_(self.head[-1].weight)
        nn.init.zeros_(self.head[-1].bias)

    def _unfreeze_last_blocks(self, backbone, unfreeze_last_blocks):
        """Make the last ``unfreeze_last_blocks`` blocks of ``backbone`` trainable.

        Also enables gradients on the first-level attribute named
        ``layernorm``, ``norm``, ``post_layernorm``, ``ln_post`` or
        ``final_layer_norm`` when such a module exists. Does nothing for a
        non-positive count; warns (instead of failing silently) when no
        transformer block list can be located.
        """
        if unfreeze_last_blocks <= 0:
            return

        layers = _find_transformer_layers(backbone)
        if not layers:
            import warnings

            warnings.warn(
                'unfreeze_last_blocks was requested but no transformer block '
                'list could be located in the backbone; nothing was unfrozen.',
                RuntimeWarning,
                stacklevel=2,
            )
            return

        for block in layers[-unfreeze_last_blocks:]:
            block.requires_grad_(True)

        for attr_path in (
            'layernorm',
            'norm',
            'post_layernorm',
            'ln_post',
            'final_layer_norm',
        ):
            module = _resolve_attr_path(backbone, attr_path)
            if isinstance(module, nn.Module):
                module.requires_grad_(True)

    def forward(self, x):
        """Return raw logits of shape ``(batch, 1)`` for an image batch ``x``."""
        outputs = self.backbone(x)
        last_hidden_state = outputs.last_hidden_state

        cls_token = last_hidden_state[:, 0, :]
        # NOTE(review): everything after index 0 is pooled; if the backbone
        # emits extra special tokens after the first one they are pooled
        # together with the patch tokens — confirm this is intended.
        patch_tokens = last_hidden_state[:, 1:, :]

        token_features = self.attn_pool(patch_tokens)
        rgb_features = self.rgb_proj(torch.cat([cls_token, token_features], dim=1))

        forensic_features = self.forensic_branch(x)
        fused = torch.cat(
            [rgb_features, self.forensic_gate * forensic_features], dim=1
        )
        return self.head(fused)

    def detect(self, x):
        """Return per-sample probabilities of shape ``(batch,)``.

        The forward pass runs under ``torch.inference_mode`` with every
        submodule temporarily in eval mode, so dropout is disabled and
        BatchNorm uses (and does not update) its running statistics even if
        the model was left in training mode. Each submodule's previous
        train/eval flag is restored afterwards.
        """
        # Remember flags per module: a blanket self.train() afterwards would
        # wrongly flip submodules the caller deliberately keeps in eval mode.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.inference_mode():
                return torch.sigmoid(self(x)).squeeze(1)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training