| """ | |
| TimeSformer-based Video Anomaly Detection Model | |
| For deepfake detection via reconstruction | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
class Factorized3DConv(nn.Module):
    """(2+1)D convolution: a spatial-only Conv3d followed by a temporal-only Conv3d.

    Each conv is bias-free and followed by BatchNorm3d + ReLU. The stride ``s``
    is applied to the two spatial axes only; the temporal conv always uses
    stride 1, so with ``p == (k - 1) // 2`` the number of frames is preserved.

    Args:
        in_c: input channels.
        out_c: output channels (used by both convs).
        k: kernel size, used as ``(1, k, k)`` spatially and ``(k, 1, 1)`` temporally.
        s: spatial stride.
        p: padding, used spatially for the first conv and temporally for the second.
    """

    def __init__(self, in_c, out_c, k=3, s=1, p=1):
        super().__init__()
        spatial_kernel, spatial_stride, spatial_pad = (1, k, k), (1, s, s), (0, p, p)
        temporal_kernel, temporal_pad = (k, 1, 1), (p, 0, 0)
        self.spatial = nn.Conv3d(
            in_c, out_c,
            kernel_size=spatial_kernel, stride=spatial_stride,
            padding=spatial_pad, bias=False,
        )
        self.temporal = nn.Conv3d(
            out_c, out_c,
            kernel_size=temporal_kernel, stride=1,
            padding=temporal_pad, bias=False,
        )
        self.bn1 = nn.BatchNorm3d(out_c)
        self.bn2 = nn.BatchNorm3d(out_c)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # spatial conv -> BN -> ReLU
        h = self.relu(self.bn1(self.spatial(x)))
        # temporal conv -> BN -> ReLU
        h = self.temporal(h)
        return self.relu(self.bn2(h))
class OpticalFlowEstimator(nn.Module):
    """Small CNN that predicts a 2-channel flow map from a pair of frames.

    The two frames are concatenated on the channel axis, passed through three
    stride-2 conv+ReLU stages and a 1x1 head, and the coarse 2-channel output is
    bilinearly resized back to the spatial size of ``f1``.

    Args:
        in_c: channels per frame (the network sees ``2 * in_c`` channels).
    """

    def __init__(self, in_c=3):
        super().__init__()
        self.conv1 = nn.Conv2d(2 * in_c, 64, kernel_size=7, stride=2, padding=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.flow_head = nn.Conv2d(256, 2, kernel_size=1)

    def forward(self, f1, f2):
        """Return a ``(N, 2, H, W)`` map, where ``(H, W)`` is ``f1.shape[2:]``."""
        feats = torch.cat((f1, f2), dim=1)
        for conv in (self.conv1, self.conv2, self.conv3):
            feats = F.relu(conv(feats))
        coarse = self.flow_head(feats)
        return F.interpolate(coarse, size=f1.shape[2:], mode='bilinear', align_corners=False)
class PatchEmbedding3D(nn.Module):
    """Cut a video feature volume into non-overlapping tubelets and project them.

    A Conv3d whose kernel equals its stride ``(t_sz, p_sz, p_sz)`` maps each
    tubelet to an ``emb``-dim vector. Tokens are ordered time-major, then row,
    then column.

    Args:
        img_sz: accepted for signature compatibility; not used (the conv works
            on any spatial size).
        p_sz: spatial tubelet size.
        in_c: input channels.
        emb: output token width.
        t_sz: temporal tubelet size.

    Returns (from ``forward``):
        ``(tokens, n)`` where ``tokens`` is ``(B, n, emb)`` and ``n`` is the
        number of tubelets.
    """

    def __init__(self, img_sz=224, p_sz=16, in_c=3, emb=768, t_sz=2):
        super().__init__()
        tube = (t_sz, p_sz, p_sz)
        self.proj = nn.Conv3d(in_c, emb, kernel_size=tube, stride=tube)

    def forward(self, x):
        feat = self.proj(x)
        _, _, n_t, n_h, n_w = feat.shape
        # (B, E, T, H, W) -> (B, E, T*H*W) -> (B, T*H*W, E)
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, n_t * n_h * n_w
class MultiHeadAttention3D(nn.Module):
    """Standard multi-head self-attention over a flat token sequence.

    A single Linear produces packed Q/K/V. Scores are scaled by
    ``(dim // heads) ** -0.5``, softmax-normalised, passed through dropout, and
    the merged heads go through an output Linear.

    Args:
        dim: token width; ``forward`` reshapes it into ``heads`` chunks, so it
            is expected to be divisible by ``heads``.
        heads: number of attention heads.
        drop: dropout probability applied to the attention weights.
    """

    def __init__(self, dim, heads=12, drop=0.):
        super().__init__()
        head_dim = dim // heads
        self.heads = heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, N, C)``; returns the same shape."""
        batch, n_tok, width = x.shape
        head_dim = width // self.heads
        packed = self.qkv(x).reshape(batch, n_tok, 3, self.heads, head_dim)
        # -> (3, B, heads, N, head_dim), then split into q, k, v
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.drop(torch.softmax(scores, dim=-1))
        context = torch.matmul(weights, v)
        merged = context.transpose(1, 2).reshape(batch, n_tok, width)
        return self.proj(merged)
class TransformerBlock3D(nn.Module):
    """Pre-norm transformer block: LayerNorm -> attention and LayerNorm -> MLP,
    each wrapped in a residual connection.

    Args:
        dim: token width.
        heads: attention heads.
        mlp_r: hidden-width multiplier for the MLP (hidden = ``int(dim * mlp_r)``).
        drop: dropout used in attention and after both MLP linears.
    """

    def __init__(self, dim, heads, mlp_r=4., drop=0.):
        super().__init__()
        hidden = int(dim * mlp_r)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention3D(dim, heads, drop)
        self.norm2 = nn.LayerNorm(dim)
        # Keep the Sequential layout fixed: parameter keys are mlp.0 / mlp.3.
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden, dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
class EnhancedTimeSformer(nn.Module):
    """Video transformer that reconstructs one frame and one flow field from a clip.

    Pipeline, as implemented below:

    1. ``stem``: two ``Factorized3DConv`` layers, each with spatial stride 2, so
       features are at 1/4 of the input resolution with 128 channels. The time
       axis keeps all ``T`` frames.
    2. Motion branch: ``flow_est`` runs on every consecutive frame pair, and
       ``flow_enc`` lifts the 2-channel flow to 128 channels. The result is
       resized to the stem grid, zero-padded on the last time step, and added
       to the stem features with a fixed weight of 0.1.
    3. ``patch_emb`` cuts the fused features into 2-frame tubelets of
       ``(p_sz // 4) x (p_sz // 4)`` cells. A CLS token is prepended, and a
       learned 2048-slot positional table is added. The table is linearly
       resampled when the sequence is longer than 2048.
    4. ``depth`` transformer blocks, then a final LayerNorm.
    5. A single token is decoded by ``dec['frame']`` and ``dec['flow']`` into
       ``recon_grid x recon_grid`` maps, which are bilinearly upsampled to
       ``(H, W)``.

    Args:
        img_sz: nominal frame size; sets ``recon_grid = (img_sz // 4) // (p_sz // 4)``.
        p_sz: patch size in input pixels (``p_sz // 4`` cells on the stem grid).
        in_c: input channels; also the channel count of the reconstructed frame.
        n_fr: accepted for configuration compatibility; not used by this class.
        emb: token width.
        depth: number of transformer blocks.
        heads: attention heads per block.
    """

    def __init__(self, img_sz=224, p_sz=16, in_c=3, n_fr=16, emb=768, depth=12, heads=12):
        super().__init__()
        self.img_sz = img_sz
        self.p_sz = p_sz
        self.stem = nn.Sequential(
            Factorized3DConv(in_c, 64, 7, 2, 3),
            Factorized3DConv(64, 128, 3, 2, 1)
        )
        self.flow_est = OpticalFlowEstimator(in_c)
        # 2-channel flow -> 128 channels so it can be added to the stem output.
        self.flow_enc = nn.Sequential(
            nn.Conv2d(2, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 1, 1)
        )
        self.patch_emb = PatchEmbedding3D(img_sz // 4, p_sz // 4, 128, emb, 2)
        self.cls_tok = nn.Parameter(torch.zeros(1, 1, emb))
        # Fixed-size table; forward() resamples it if the sequence is longer.
        self.pos_emb = nn.Parameter(torch.zeros(1, 2048, emb))
        self.blocks = nn.ModuleList([
            TransformerBlock3D(emb, heads, 4., 0.1) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(emb)
        self.recon_grid = (img_sz // 4) // (p_sz // 4)
        self.dec = nn.ModuleDict({
            'frame': nn.Sequential(
                nn.Linear(emb, emb * 2),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(emb * 2, self.recon_grid ** 2 * in_c)
            ),
            'flow': nn.Sequential(
                nn.Linear(emb, emb),
                nn.GELU(),
                nn.Linear(emb, self.recon_grid ** 2 * 2)
            )
        })
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        nn.init.trunc_normal_(self.cls_tok, std=0.02)

    def _flow_features(self, x, size):
        """Encode consecutive-frame flow and align it with the stem output.

        Args:
            x: raw clip, unpacked as ``(B, C, T, H, W)``.
            size: ``(T_s, h, w)`` of the stem output.

        Returns:
            ``(B, 128, T_s, h, w)`` flow features; time steps without a frame
            pair are zero-padded at the end.
        """
        B = x.shape[0]
        # Run all T-1 pairs through the estimator in one batched call instead of
        # a per-frame Python loop. The estimator is per-sample (no BatchNorm or
        # dropout), so this matches the looped result. Ordering is (b t).
        prev = rearrange(x[:, :, :-1], 'b c t h w -> (b t) c h w')
        nxt = rearrange(x[:, :, 1:], 'b c t h w -> (b t) c h w')
        # NOTE(review): flow_enc runs at full input resolution with 128 output
        # channels before downsampling, which is memory-heavy. Kept as-is
        # because swapping the order would change the model's outputs.
        flow = self.flow_enc(self.flow_est(prev, nxt))
        flow = F.interpolate(flow, size=(size[1], size[2]), mode='bilinear', align_corners=False)
        flow = rearrange(flow, '(b t) c h w -> b c t h w', b=B)
        # The stem keeps T frames but there are only T-1 pairs: pad at the end.
        missing = size[0] - flow.shape[2]
        if missing > 0:
            flow = F.pad(flow, (0, 0, 0, 0, 0, missing))
        return flow

    def _pos_embedding(self, n_tok):
        """Return a ``(1, n_tok, emb)`` positional embedding.

        The stored table is sliced when it is long enough. Otherwise it is
        linearly resampled along the token axis.
        """
        if n_tok > self.pos_emb.shape[1]:
            return F.interpolate(
                self.pos_emb.permute(0, 2, 1),
                n_tok,
                mode='linear',
                align_corners=False
            ).permute(0, 2, 1)
        return self.pos_emb[:, :n_tok]

    def _decode(self, key, tok, channels, out_hw):
        """Decode ``tok`` (B, 1, emb) with ``self.dec[key]`` and resize it to ``out_hw``."""
        grid = self.recon_grid
        patch = rearrange(
            self.dec[key](tok),
            'b 1 (p1 p2 c) -> b c p1 p2',
            p1=grid,
            p2=grid,
            c=channels
        )
        return F.interpolate(patch, out_hw, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Reconstruct a frame and a flow field from a clip.

        Args:
            x: clip tensor unpacked as ``(B, C, T, H, W)``. ``C`` must equal
                ``in_c``, because the frame decoder emits
                ``recon_grid ** 2 * in_c`` values. ``T`` must be at least 2.

        Returns:
            ``(frame, flow)`` with shapes ``(B, C, H, W)`` and ``(B, 2, H, W)``.

        Raises:
            ValueError: if ``T < 2``. Flow needs a frame pair, and the patch
                embedding's temporal kernel of 2 cannot fit a single frame
                (the stem preserves ``T``).
        """
        B, C, T, H, W = x.shape
        if T < 2:
            raise ValueError(
                f"EnhancedTimeSformer needs at least 2 frames, got T={T}"
            )
        feats = self.stem(x)
        # Fuse motion cues into the appearance features with a fixed 0.1 weight.
        feats = feats + 0.1 * self._flow_features(x, feats.shape[2:])
        tokens, n_p = self.patch_emb(feats)
        tokens = torch.cat([self.cls_tok.expand(B, -1, -1), tokens], 1)
        tokens = tokens + self._pos_embedding(tokens.shape[1])
        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm(tokens)
        # NOTE(review): index 0 is the CLS token, so index n_p // 2 addresses
        # patch token n_p // 2 - 1. Confirm this is the intended token.
        mid = n_p // 2
        tok = tokens[:, mid:mid + 1]
        fr = self._decode('frame', tok, C, (H, W))
        fl = self._decode('flow', tok, 2, (H, W))
        return fr, fl
def create_model(**overrides):
    """Factory function to create the model.

    Builds ``EnhancedTimeSformer`` with ``img_sz=224, p_sz=16, in_c=3,
    n_fr=16, emb=768, depth=12, heads=12``. Any of these can be overridden by
    keyword, e.g. ``create_model(depth=2, emb=192, heads=3)`` for a small
    model. Calling with no arguments builds the same model as before.

    Raises:
        TypeError: if an override is not an ``EnhancedTimeSformer`` argument.
    """
    config = {
        'img_sz': 224,
        'p_sz': 16,
        'in_c': 3,
        'n_fr': 16,
        'emb': 768,
        'depth': 12,
        'heads': 12,
    }
    config.update(overrides)
    return EnhancedTimeSformer(**config)