| | import torch
|
| | import torch.nn as nn
|
| | from torchvision import models
|
| |
|
class FeatureExtractor(nn.Module):
    """
    Per-frame spatial feature extractor built on a pre-trained ResNeXt-50.

    The classification head (`fc`) is replaced with an identity mapping so
    the module emits the pooled convolutional features instead of logits.
    The original head's input width is kept in ``self.feature_dim`` so
    downstream models know the embedding size (2048 for ResNeXt-50).
    """

    def __init__(self, freeze=True):
        """
        Args:
            freeze (bool): When True, disable gradients on every backbone
                parameter so only downstream layers train.
        """
        super().__init__()

        # ImageNet-pretrained ResNeXt-50 (32x4d), V2 recipe weights.
        backbone = models.resnext50_32x4d(
            weights=models.ResNeXt50_32X4D_Weights.IMAGENET1K_V2
        )

        # Optionally freeze the backbone so it acts as a fixed encoder.
        if freeze:
            for p in backbone.parameters():
                p.requires_grad = False

        # Record the embedding width *before* discarding the classifier head.
        self.feature_dim = backbone.fc.in_features

        # Identity head: forward() now yields pooled features, not logits.
        backbone.fc = nn.Identity()
        self.model = backbone

    def forward(self, x):
        """Return the backbone's pooled features for a batch of images."""
        return self.model(x)
|
| |
|
class DeepfakeDetector(nn.Module):
    """
    Combines the CNN extractor and LSTM sequencer to classify a video.

    Input:  a batch of frame sequences shaped (batch, seq_len, C, H, W).
    Output: raw class logits shaped (batch, num_classes).
    """

    def __init__(self, cnn_feature_dim, lstm_hidden_size=512, lstm_layers=2, num_classes=2, dropout=0.5):
        """
        Args:
            cnn_feature_dim (int): The output dimension from our FeatureExtractor (e.g., 2048 for ResNeXt50)
            lstm_hidden_size (int): The number of features in the LSTM's hidden state.
            lstm_layers (int): The number of stacked LSTM layers.
            num_classes (int): The number of output classes (2: Real/Fake).
            dropout (float): Dropout probability for regularization.
        """
        super(DeepfakeDetector, self).__init__()

        # Frozen per-frame CNN encoder (backbone gradients disabled).
        self.feature_extractor = FeatureExtractor(freeze=True)
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_layers = lstm_layers

        # Temporal model over the per-frame CNN features.
        self.lstm = nn.LSTM(
            input_size=cnn_feature_dim,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            # nn.LSTM applies inter-layer dropout only; it warns (and the
            # setting is meaningless) when num_layers == 1, so zero it there.
            dropout=dropout if lstm_layers > 1 else 0,
        )

        # Classification head. Bidirectional LSTM output is 2 * hidden wide.
        self.fc1 = nn.Linear(
            lstm_hidden_size * 2,
            lstm_hidden_size // 2,
        )
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(lstm_hidden_size // 2, num_classes)

    def forward(self, x):
        """
        Args:
            x (Tensor): Frame sequences shaped (batch, seq_len, C, H, W).

        Returns:
            Tensor: Class logits shaped (batch, num_classes).
        """
        batch_size, seq_len, c, h, w = x.shape

        # Fold time into the batch dimension so the CNN sees ordinary
        # 4-D image batches: (batch * seq_len, C, H, W).
        x_flat = x.view(batch_size * seq_len, c, h, w)

        features = self.feature_extractor(x_flat)

        # Restore the sequence axis: (batch, seq_len, feature_dim).
        features_seq = features.view(batch_size, seq_len, -1)

        lstm_out, (h_n, c_n) = self.lstm(features_seq)

        # BUG FIX: the previous code used lstm_out[:, -1, :]. With a
        # bidirectional LSTM that concatenates the forward direction's
        # full-sequence summary with the backward direction's output at the
        # *last* time step — which is the backward pass's first step and has
        # seen only the final frame. Instead take the final hidden state of
        # the top layer for each direction from h_n, which is shaped
        # (num_layers * 2, batch, hidden): h_n[-2] is the top-layer forward
        # state, h_n[-1] the top-layer backward state; each summarizes the
        # whole clip. The concatenated width (2 * hidden) is unchanged.
        summary = torch.cat((h_n[-2], h_n[-1]), dim=1)

        x = self.dropout(self.relu(self.fc1(summary)))
        out = self.fc2(x)

        return out
|
| |
|