import torch
import torch.nn as nn
import timm


class DeepfakeEffNetTransformer(nn.Module):
    """Frame-level CNN features followed by a temporal Transformer encoder.

    The forward pass runs the following steps:

    1. Encode each frame of a clip independently with a timm backbone.
    2. Pool the feature map to a vector and project it to ``embed_dim``.
    3. Add a learned positional embedding.
    4. Mix the tokens across time with a Transformer encoder.
    5. Mean-pool over time and classify with a small MLP head.

    All constructor arguments are keyword-only. Their defaults are the values
    the model was originally hard-coded with, so ``DeepfakeEffNetTransformer()``
    builds the same architecture (with the same parameter names and shapes) as
    before.

    Args:
        backbone: timm model name used for the per-frame CNN.
        pretrained: Whether timm should load pretrained backbone weights.
        embed_dim: Token width fed to the Transformer (``d_model``).
        max_frames: Number of rows in the learned positional table. Clips may
            have any number of frames ``1 <= T <= max_frames``.
        num_heads: Attention heads per encoder layer (must divide
            ``embed_dim``).
        ff_dim: Hidden width of each encoder layer's feed-forward block.
        num_layers: Number of stacked encoder layers.
        hidden_dim: Hidden width of the classifier MLP.
        dropout: Dropout probability inside the classifier MLP.
        num_classes: Number of output scores per clip.
    """

    def __init__(
        self,
        *,
        backbone: str = "tf_efficientnetv2_b1",
        pretrained: bool = False,
        embed_dim: int = 512,
        max_frames: int = 32,
        num_heads: int = 8,
        ff_dim: int = 1024,
        num_layers: int = 2,
        hidden_dim: int = 128,
        dropout: float = 0.3,
        num_classes: int = 2,
    ) -> None:
        super().__init__()

        # CNN BACKBONE
        # num_classes=0 drops timm's classifier head and global_pool=""
        # disables its pooling, so the backbone yields a spatial feature map
        # that is pooled explicitly by self.pool.
        self.cnn = timm.create_model(
            backbone, pretrained=pretrained, num_classes=0, global_pool=""
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        feat_dim = self.cnn.num_features

        # PROJECTION: backbone channels -> Transformer token width.
        self.proj = nn.Linear(feat_dim, embed_dim)

        # POSITIONAL EMBEDDING: one learned vector per frame slot. forward()
        # uses only the first T rows, so clips shorter than max_frames work.
        # NOTE(review): unit-variance init (torch.randn) is kept to preserve
        # existing behaviour; ViT-style models usually init with std ~0.02.
        self.pos_embed = nn.Parameter(torch.randn(1, max_frames, embed_dim))

        # TEMPORAL TRANSFORMER
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            batch_first=True,  # tokens are laid out as (B, T, embed_dim)
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )

        # CLASSIFIER
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of clips.

        Args:
            x: Tensor of shape ``(B, T, C, H, W)`` with
                ``1 <= T <= max_frames``.

        Returns:
            Unnormalised class scores of shape ``(B, num_classes)``.

        Raises:
            ValueError: If ``x`` is not 5-D or ``T`` is outside the supported
                range.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (B, T, C, H, W), got {tuple(x.shape)}"
            )
        B, T, C, H, W = x.shape
        max_frames = self.pos_embed.size(1)
        if not 1 <= T <= max_frames:
            raise ValueError(
                f"number of frames T={T} must be between 1 and {max_frames}"
            )

        # CNN FEATURES: fold time into the batch so every frame is encoded
        # independently. reshape (not view) also accepts non-contiguous
        # input, e.g. a clip that was permuted from (B, C, T, H, W).
        frames = x.reshape(B * T, C, H, W)
        feats = self.pool(self.cnn(frames))  # (B*T, feat_dim, 1, 1)
        feats = feats.reshape(B, T, -1)  # (B, T, feat_dim)

        # PROJECTION plus the positional embedding for the first T slots.
        feats = self.proj(feats) + self.pos_embed[:, :T]

        # TEMPORAL TRANSFORMER, then average the tokens over time.
        out = self.transformer(feats).mean(dim=1)
        return self.classifier(out)