Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import timm | |
class DeepfakeEffNetTransformer(nn.Module):
    """Clip classifier: per-frame CNN features fused by a temporal Transformer.

    Every frame of a clip is encoded independently by a timm backbone,
    average-pooled to a vector, linearly projected to ``embed_dim``, offset by
    a learned positional embedding, and passed through a Transformer encoder.
    The encoder outputs are averaged over time and mapped to per-clip scores
    by a small MLP head.

    All constructor arguments are keyword-only and default to the values this
    model was originally hard-coded with, so ``DeepfakeEffNetTransformer()``
    builds the same architecture with the same ``state_dict`` keys and shapes
    as before.

    Args:
        backbone: timm model name used as the per-frame feature extractor.
        pretrained: Passed to ``timm.create_model``. Defaults to False —
            presumably trained weights are loaded from a checkpoint afterwards;
            confirm with the caller.
        max_frames: Number of learned positional embeddings, i.e. the longest
            clip length ``T`` that ``forward`` accepts.
        embed_dim: Transformer model width.
        num_heads: Attention heads per encoder layer (must divide ``embed_dim``).
        ff_dim: Hidden width of each encoder layer's feed-forward block.
        num_layers: Number of Transformer encoder layers.
        hidden_dim: Hidden width of the MLP classification head.
        num_classes: Number of output scores per clip. The meaning of each
            index (presumably real vs. fake) is not visible here — confirm
            against the training labels.
        dropout: Dropout probability inside the classification head.
    """

    def __init__(
        self,
        *,
        backbone: str = "tf_efficientnetv2_b1",
        pretrained: bool = False,
        max_frames: int = 32,
        embed_dim: int = 512,
        num_heads: int = 8,
        ff_dim: int = 1024,
        num_layers: int = 2,
        hidden_dim: int = 128,
        num_classes: int = 2,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        # Submodules are created in the original order so that, under a fixed
        # seed, default construction draws the same initial weights as before.

        # CNN backbone. num_classes=0 / global_pool="" ask timm for the
        # unpooled feature map; it is reduced to a vector by self.pool.
        self.cnn = timm.create_model(
            backbone,
            pretrained=pretrained,
            num_classes=0,
            global_pool="",
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        feat_dim = self.cnn.num_features

        # Projection from backbone channels to the Transformer width.
        self.proj = nn.Linear(feat_dim, embed_dim)

        # One learned positional embedding per frame slot; forward() uses
        # only the first T slots of a T-frame clip.
        self.pos_embed = nn.Parameter(torch.randn(1, max_frames, embed_dim))

        # Temporal Transformer over the per-frame tokens (batch-first layout).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )

        # Classification head.
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of clips.

        Args:
            x: Frames shaped ``(B, T, C, H, W)`` with
                ``1 <= T <= pos_embed.size(1)`` (``max_frames``).

        Returns:
            Raw (un-softmaxed) scores shaped ``(B, num_classes)``.

        Raises:
            ValueError: If ``x`` is not 5-D, or ``T`` is 0 or exceeds the
                number of learned positional embeddings.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (B, T, C, H, W), got {tuple(x.shape)}"
            )
        B, T, C, H, W = x.shape
        max_frames = self.pos_embed.size(1)
        if not 1 <= T <= max_frames:
            raise ValueError(
                f"clip length T={T} must be between 1 and {max_frames} "
                "(the number of positional embeddings)"
            )

        # Fold time into the batch so the backbone sees independent frames.
        # reshape (unlike view) also accepts inputs whose B/T dims are not
        # mergeable in memory, e.g. a (T, B, ...) tensor permuted to (B, T, ...).
        frames = x.reshape(B * T, C, H, W)

        # Per-frame CNN features: (B*T, feat, 1, 1) -> (B*T, feat) -> (B, T, feat).
        feats = self.pool(self.cnn(frames)).flatten(1)
        feats = feats.reshape(B, T, -1)

        # Add positional embeddings for the first T slots only. Adding all
        # max_frames slots fails for most T != max_frames and, for T == 1,
        # silently broadcasts the single frame into max_frames tokens.
        tokens = self.proj(feats) + self.pos_embed[:, :T]

        # Temporal Transformer, then average the tokens over time.
        pooled = self.transformer(tokens).mean(dim=1)
        return self.classifier(pooled)