# Scraped page header (not part of the program): Spaces: Sleeping; File size: 1,547 Bytes; 459fc8b
import torch
import torch.nn as nn
import timm
class DeepfakeEffNetTransformer(nn.Module):
    """Per-frame CNN encoder followed by a temporal transformer and a 2-way head.

    Each frame of a clip is encoded by a timm ``tf_efficientnetv2_b1`` backbone,
    average-pooled to one vector, projected to 512 dimensions, offset by a
    learned positional embedding, passed through a 2-layer transformer encoder,
    mean-pooled over the frame axis and classified into 2 logits.

    The positional table holds 32 positions, so clips of up to 32 frames are
    supported.
    """

    def __init__(self):
        """Build the backbone, projection, positional table, encoder and head."""
        super().__init__()

        embed_dim = 512
        max_frames = 32

        # CNN BACKBONE
        # num_classes=0 and global_pool="" ask timm for a headless, unpooled
        # model; pooling is done explicitly by self.pool below.
        self.cnn = timm.create_model(
            "tf_efficientnetv2_b1",
            pretrained=False,
            num_classes=0,
            global_pool="",
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        feat_dim = self.cnn.num_features

        # PROJECTION
        self.proj = nn.Linear(feat_dim, embed_dim)

        # POSITIONAL EMBEDDING
        # One learned vector per frame position; shape kept at (1, 32, 512)
        # so previously saved state dicts remain loadable.
        self.pos_embed = nn.Parameter(
            torch.randn(1, max_frames, embed_dim)
        )

        # TRANSFORMER
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=8,
            dim_feedforward=1024,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=2,
        )

        # CLASSIFIER
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of frame sequences.

        Args:
            x: 5-D tensor unpacked as ``(B, T, C, H, W)``; ``T`` must not
                exceed the number of positions in ``self.pos_embed``.

        Returns:
            Tensor of shape ``(B, 2)`` produced by ``self.classifier``.

        Raises:
            ValueError: If ``x`` is not 5-D (from the shape unpacking) or if
                ``T`` is larger than the positional table.
        """
        batch, frames, channels, height, width = x.shape

        max_frames = self.pos_embed.shape[1]
        if frames > max_frames:
            raise ValueError(
                f"Got {frames} frames but the positional embedding only "
                f"covers {max_frames}"
            )

        # reshape (not view) so non-contiguous inputs, e.g. a permuted clip
        # tensor, are accepted; contiguous inputs are still not copied.
        x = x.reshape(batch * frames, channels, height, width)

        # CNN FEATURES
        feats = self.cnn(x)
        feats = self.pool(feats)
        feats = feats.reshape(batch, frames, -1)

        # PROJECTION
        feats = self.proj(feats)
        # Use only the first `frames` positions: adding the full table would
        # fail to broadcast for 1 < T < 32 and would silently expand a
        # single-frame clip into 32 tokens.
        feats = feats + self.pos_embed[:, :frames]

        # TEMPORAL TRANSFORMER
        out = self.transformer(feats)
        out = out.mean(dim=1)
        return self.classifier(out)