"""
TimeSformer-based Video Anomaly Detection Model
For deepfake detection via reconstruction
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class Factorized3DConv(nn.Module):
    """(2+1)D convolution: a spatial (1, k, k) conv followed by a temporal (k, 1, 1) conv.

    Each conv is followed by BatchNorm3d and ReLU. Only the spatial conv is
    strided (by ``s`` along H and W). The temporal conv uses stride 1 and
    padding ``p``, so the temporal length is preserved whenever ``2 * p == k - 1``.

    Args:
        in_c: Input channels.
        out_c: Output channels (used by both convs).
        k: Kernel size for both the spatial and temporal convs.
        s: Spatial stride.
        p: Padding applied spatially (spatial conv) and temporally (temporal conv).
    """

    def __init__(self, in_c, out_c, k=3, s=1, p=1):
        super().__init__()
        self.spatial = nn.Conv3d(in_c, out_c, (1, k, k), (1, s, s), (0, p, p), bias=False)
        self.temporal = nn.Conv3d(out_c, out_c, (k, 1, 1), (1, 1, 1), (p, 0, 0), bias=False)
        self.bn1 = nn.BatchNorm3d(out_c)
        self.bn2 = nn.BatchNorm3d(out_c)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply spatial conv -> BN -> ReLU -> temporal conv -> BN -> ReLU to (B, C, T, H, W)."""
        x = self.relu(self.bn1(self.spatial(x)))
        return self.relu(self.bn2(self.temporal(x)))


class OpticalFlowEstimator(nn.Module):
    """Small CNN that maps a pair of frames to a 2-channel motion map.

    The two frames are concatenated along channels, passed through three
    stride-2 convs and a 1x1 head, then bilinearly upsampled back to the
    first frame's spatial size. No flow supervision is visible in this file;
    NOTE(review): the output is presumably learned end-to-end and need not be
    physical optical flow.
    """

    def __init__(self, in_c=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c * 2, 64, 7, 2, 3)
        self.conv2 = nn.Conv2d(64, 128, 5, 2, 2)
        self.conv3 = nn.Conv2d(128, 256, 3, 2, 1)
        self.flow_head = nn.Conv2d(256, 2, 1)

    def forward(self, f1, f2):
        """Return a (B, 2, H, W) map for frames ``f1`` and ``f2`` of shape (B, in_c, H, W)."""
        x = torch.cat([f1, f2], 1)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        return F.interpolate(self.flow_head(x), f1.shape[2:], mode='bilinear', align_corners=False)


class PatchEmbedding3D(nn.Module):
    """Embed non-overlapping (t_sz, p_sz, p_sz) tubelets with a strided Conv3d.

    Args:
        img_sz: Accepted for API compatibility; not used.
        p_sz: Spatial tubelet size (kernel == stride, no padding).
        in_c: Input channels.
        emb: Embedding width.
        t_sz: Temporal tubelet size. Trailing frames that do not fill a whole
            tubelet are dropped by the conv.
    """

    def __init__(self, img_sz=224, p_sz=16, in_c=3, emb=768, t_sz=2):
        super().__init__()
        self.proj = nn.Conv3d(in_c, emb, (t_sz, p_sz, p_sz), (t_sz, p_sz, p_sz))

    def forward(self, x):
        """Return ``(tokens, n)``: tokens of shape (B, n, emb) in time-major (t, h, w) order."""
        x = self.proj(x)
        _, _, T, H, W = x.shape
        return rearrange(x, 'b e t h w -> b (t h w) e'), T * H * W


class MultiHeadAttention3D(nn.Module):
    """Multi-head scaled dot-product self-attention over a (B, N, dim) token sequence.

    Attention is computed jointly over all N tokens (no separate space/time
    attention). Dropout is applied to the attention weights only.

    Raises:
        ValueError: if ``dim`` is not divisible by ``heads``.
    """

    def __init__(self, dim, heads=12, drop=0.):
        super().__init__()
        # Fail at construction instead of with an opaque reshape error in forward().
        if dim % heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Attend over tokens ``x`` of shape (B, N, dim); returns the same shape."""
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.drop(attn.softmax(-1))
        # (B, heads, N, head_dim) -> (B, N, C)
        return self.proj((attn @ v).transpose(1, 2).reshape(B, N, C))
class TransformerBlock3D(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x = x + attn(norm1(x))`` then ``x + mlp(norm2(x))``.

    Args:
        dim: Token embedding width.
        heads: Number of attention heads.
        mlp_r: MLP hidden-width multiplier (hidden = int(dim * mlp_r)).
        drop: Dropout probability for attention weights and the MLP.
    """

    def __init__(self, dim, heads, mlp_r=4., drop=0.):
        super().__init__()
        hidden = int(dim * mlp_r)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention3D(dim, heads, drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden, dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        """Apply the block to tokens of shape (B, N, dim)."""
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))


class EnhancedTimeSformer(nn.Module):
    """Video transformer that reconstructs one frame and a 2-channel motion map.

    Pipeline (see ``forward``):
      1. ``flow_est`` runs on every consecutive frame pair; its 2-channel output
         is lifted to 128 channels by ``flow_enc``.
      2. The clip passes through a two-stage factorized 3D conv stem (two
         stride-2 spatial convs, temporal length kept).
      3. The motion features are resized to the stem's spatial size,
         zero-padded in time to the stem's length and added with weight 0.1.
      4. (2, p_sz // 4, p_sz // 4) tubelets are embedded, a CLS token is
         prepended, positional embeddings are added, and ``depth`` transformer
         blocks plus a final LayerNorm are applied.
      5. The middle patch token is decoded to a ``recon_grid x recon_grid``
         frame and motion map, then bilinearly upsampled to the input's (H, W).

    Args:
        img_sz: Nominal input side length; determines ``recon_grid``.
        p_sz: Nominal patch size in input pixels (tubelets on the stem output
            are p_sz // 4).
        in_c: Input channels.
        n_fr: Accepted for API compatibility; not used.
        emb: Token embedding width.
        depth: Number of transformer blocks.
        heads: Attention heads per block.
        max_pos: Length of the learned positional-embedding table (including
            CLS). Longer sequences use a linear interpolation of the table.
    """

    def __init__(self, img_sz=224, p_sz=16, in_c=3, n_fr=16, emb=768, depth=12, heads=12,
                 max_pos=2048):
        super().__init__()
        self.img_sz = img_sz
        self.p_sz = p_sz
        self.stem = nn.Sequential(
            Factorized3DConv(in_c, 64, 7, 2, 3),
            Factorized3DConv(64, 128, 3, 2, 1),
        )
        self.flow_est = OpticalFlowEstimator(in_c)
        # Lift the 2-channel motion map to the stem's 128 channels so they can be summed.
        self.flow_enc = nn.Sequential(
            nn.Conv2d(2, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 1, 1),
        )
        self.patch_emb = PatchEmbedding3D(img_sz // 4, p_sz // 4, 128, emb, 2)
        self.cls_tok = nn.Parameter(torch.zeros(1, 1, emb))
        self.pos_emb = nn.Parameter(torch.zeros(1, max_pos, emb))
        self.blocks = nn.ModuleList([
            TransformerBlock3D(emb, heads, 4., 0.1) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(emb)
        # Side length of the low-resolution maps decoded from a single token.
        self.recon_grid = (img_sz // 4) // (p_sz // 4)
        self.dec = nn.ModuleDict({
            'frame': nn.Sequential(
                nn.Linear(emb, emb * 2),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(emb * 2, self.recon_grid ** 2 * in_c),
            ),
            'flow': nn.Sequential(
                nn.Linear(emb, emb),
                nn.GELU(),
                nn.Linear(emb, self.recon_grid ** 2 * 2),
            ),
        })
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        nn.init.trunc_normal_(self.cls_tok, std=0.02)

    def _motion_features(self, video, size):
        """Encode motion between consecutive frames of ``video`` (B, C, T, H, W).

        Returns a (B, 128, T - 1, *size) tensor, or None when T < 2.
        """
        B, T = video.shape[0], video.shape[2]
        if T < 2:
            return None
        # Fold every (t, t + 1) pair into the batch dim: one estimator call instead of T - 1.
        prev = rearrange(video[:, :, :-1], 'b c t h w -> (b t) c h w')
        nxt = rearrange(video[:, :, 1:], 'b c t h w -> (b t) c h w')
        feats = self.flow_enc(self.flow_est(prev, nxt))
        feats = F.interpolate(feats, size=size, mode='bilinear', align_corners=False)
        return rearrange(feats, '(b t) c h w -> b c t h w', b=B)

    def _positional_embedding(self, length):
        """Return (1, length, emb) positional embeddings, interpolating the table if it is too short."""
        if length > self.pos_emb.shape[1]:
            return F.interpolate(
                self.pos_emb.permute(0, 2, 1), length, mode='linear', align_corners=False
            ).permute(0, 2, 1)
        return self.pos_emb[:, :length]

    def forward(self, x):
        """Reconstruct a frame and a motion map from a clip.

        Args:
            x: Clip of shape (B, C, T, H, W) with C == ``in_c``. T must be at
                least 2 (tubelets span 2 frames); with odd T the trailing
                stem frame is dropped by the tubelet conv.

        Returns:
            ``(frame, flow)`` of shapes (B, C, H, W) and (B, 2, H, W).
        """
        B, C, T, H, W = x.shape
        feats = self.stem(x)
        flows = self._motion_features(x, tuple(feats.shape[-2:]))
        if flows is not None:
            # Motion has T - 1 steps while the stem keeps T; zero-pad the tail in time.
            missing = feats.shape[2] - flows.shape[2]
            if missing > 0:
                flows = F.pad(flows, (0, 0, 0, 0, 0, missing))
            feats = feats + 0.1 * flows

        tokens, n_p = self.patch_emb(feats)
        x = torch.cat([self.cls_tok.expand(B, -1, -1), tokens], 1)
        x = x + self._positional_embedding(x.shape[1])
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        # Index 0 is the CLS token, so patch p sits at index 1 + p. Using n_p // 2
        # directly would be off by one (and select CLS itself when n_p == 1).
        mid = 1 + n_p // 2
        tok = x[:, mid:mid + 1]

        grid = self.recon_grid
        fr = rearrange(self.dec['frame'](tok), 'b 1 (p1 p2 c) -> b c p1 p2', p1=grid, p2=grid, c=C)
        fl = rearrange(self.dec['flow'](tok), 'b 1 (p1 p2 c) -> b c p1 p2', p1=grid, p2=grid, c=2)
        fr = F.interpolate(fr, (H, W), mode='bilinear', align_corners=False)
        fl = F.interpolate(fl, (H, W), mode='bilinear', align_corners=False)
        return fr, fl


def create_model():
    """Factory function to create the model (emb=768, depth=12, heads=12, 224px input)."""
    return EnhancedTimeSformer(
        img_sz=224, p_sz=16, in_c=3, n_fr=16, emb=768, depth=12, heads=12
    )