# deepfake-detector / model.py
# dappai's picture
# Upload 9 files
# 459fc8b verified
import torch
import torch.nn as nn
import timm
class DeepfakeEffNetTransformer(nn.Module):
    """Clip classifier: per-frame CNN features followed by a temporal transformer.

    Every frame of a clip is passed through a timm CNN backbone, the resulting
    feature map is average-pooled to one vector per frame, projected to
    ``embed_dim``, offset by a learned positional embedding, run through a
    transformer encoder over the frame axis, mean-pooled over frames and
    classified.

    The default arguments reproduce the original fixed architecture, and the
    submodule/parameter names (``cnn``, ``pool``, ``proj``, ``pos_embed``,
    ``transformer``, ``classifier``) are unchanged so existing state dicts
    keep loading.

    Args:
        backbone: timm model name used as the per-frame feature extractor.
        pretrained: Whether timm should load pretrained backbone weights.
        embed_dim: Width of the frame tokens fed to the transformer.
        max_frames: Number of learned positional embeddings, i.e. the longest
            clip (in frames) the model accepts.
        num_heads: Attention heads per transformer layer.
        ff_dim: Hidden width of the transformer feed-forward sublayer.
        num_layers: Number of stacked transformer encoder layers.
        hidden_dim: Hidden width of the classifier head.
        dropout: Dropout probability inside the classifier head.
        num_classes: Number of output logits.
    """

    def __init__(
        self,
        *,
        backbone="tf_efficientnetv2_b1",
        pretrained=False,
        embed_dim=512,
        max_frames=32,
        num_heads=8,
        ff_dim=1024,
        num_layers=2,
        hidden_dim=128,
        dropout=0.3,
        num_classes=2,
    ):
        super().__init__()

        # CNN BACKBONE
        # num_classes=0 and global_pool="" ask timm for the backbone without
        # its classifier head and without its built-in pooling; pooling is
        # done explicitly by self.pool below.
        self.cnn = timm.create_model(
            backbone,
            pretrained=pretrained,
            num_classes=0,
            global_pool="",
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        feat_dim = self.cnn.num_features

        # PROJECTION
        self.proj = nn.Linear(feat_dim, embed_dim)

        # POSITIONAL EMBEDDING
        # One learned vector per frame index, up to max_frames.
        self.pos_embed = nn.Parameter(torch.randn(1, max_frames, embed_dim))

        # TRANSFORMER
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )

        # CLASSIFIER
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: Tensor of shape ``(B, T, C, H, W)`` holding ``T`` frames per
                clip. ``T`` may be any length up to the size of the
                positional-embedding table.

        Returns:
            Tensor of shape ``(B, num_classes)`` with unnormalized class
            scores (no softmax is applied).

        Raises:
            ValueError: If ``x`` is not 5-dimensional, or if ``T`` exceeds the
                number of positional embeddings.
        """
        batch, frames, channels, height, width = x.shape

        # Read the limit from the parameter itself so it stays correct even
        # when pos_embed was replaced by a loaded checkpoint.
        max_frames = self.pos_embed.size(1)
        if frames > max_frames:
            raise ValueError(
                f"Clip has {frames} frames but the model only has "
                f"{max_frames} positional embeddings."
            )

        # Fold time into the batch axis so the 2D CNN sees single frames.
        # reshape (not view) so non-contiguous clips are accepted too.
        x = x.reshape(batch * frames, channels, height, width)

        # CNN FEATURES
        feats = self.pool(self.cnn(x))
        feats = feats.reshape(batch, frames, -1)

        # PROJECTION
        feats = self.proj(feats)
        # Use only the first `frames` positions: adding the full table would
        # fail to broadcast for shorter clips (and silently expand T == 1).
        feats = feats + self.pos_embed[:, :frames]

        # TEMPORAL TRANSFORMER
        out = self.transformer(feats)
        out = out.mean(dim=1)

        return self.classifier(out)