# Provenance (was raw web-page residue, which broke the file's syntax):
# uploaded by ash12321 — "Upload complete model package with all files",
# commit 6ecece2 (verified).
"""
TimeSformer-based Video Anomaly Detection Model
For deepfake detection via reconstruction
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class Factorized3DConv(nn.Module):
    """(2+1)D factorized convolution.

    Applies a spatial 1 x k x k conv and then a temporal k x 1 x 1 conv,
    each followed by BatchNorm3d and ReLU. Factorizing a full 3D kernel
    this way cuts parameters while keeping spatio-temporal coverage.
    """

    def __init__(self, in_c, out_c, k=3, s=1, p=1):
        super().__init__()
        # Spatial conv strides/pads only H and W; temporal conv only T.
        self.spatial = nn.Conv3d(in_c, out_c, (1, k, k), (1, s, s), (0, p, p), bias=False)
        self.temporal = nn.Conv3d(out_c, out_c, (k, 1, 1), (1, 1, 1), (p, 0, 0), bias=False)
        self.bn1 = nn.BatchNorm3d(out_c)
        self.bn2 = nn.BatchNorm3d(out_c)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Spatial stage: conv -> BN -> ReLU.
        h = self.relu(self.bn1(self.spatial(x)))
        # Temporal stage: conv -> BN -> ReLU.
        h = self.bn2(self.temporal(h))
        return self.relu(h)
class OpticalFlowEstimator(nn.Module):
    """Coarse optical-flow head for a pair of frames.

    Concatenates the two frames channel-wise, encodes them with three
    strided convs (1/8 resolution), predicts a 2-channel flow map, and
    bilinearly upsamples it back to the input frame size.
    """

    def __init__(self, in_c=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c * 2, 64, 7, 2, 3)
        self.conv2 = nn.Conv2d(64, 128, 5, 2, 2)
        self.conv3 = nn.Conv2d(128, 256, 3, 2, 1)
        self.flow_head = nn.Conv2d(256, 2, 1)

    def forward(self, f1, f2):
        feat = torch.cat([f1, f2], dim=1)
        for conv in (self.conv1, self.conv2, self.conv3):
            feat = F.relu(conv(feat))
        flow = self.flow_head(feat)
        # Upsample the 1/8-resolution flow back to the frame resolution.
        return F.interpolate(flow, f1.shape[2:], mode='bilinear', align_corners=False)
class PatchEmbedding3D(nn.Module):
    """Tubelet patch embedding.

    A single Conv3d with kernel == stride carves the clip into
    t_sz x p_sz x p_sz tubelets and projects each to an `emb`-dim token.
    Returns the token sequence and the token count.
    """

    def __init__(self, img_sz=224, p_sz=16, in_c=3, emb=768, t_sz=2):
        super().__init__()
        # img_sz is accepted for interface parity but not needed here:
        # the conv handles any input divisible by its stride.
        self.proj = nn.Conv3d(in_c, emb, (t_sz, p_sz, p_sz), (t_sz, p_sz, p_sz))

    def forward(self, x):
        x = self.proj(x)                              # (B, E, T', H', W')
        n_tokens = x.shape[2] * x.shape[3] * x.shape[4]
        # (B, E, T', H', W') -> (B, T'*H'*W', E); token order is (t, h, w).
        tokens = x.flatten(2).transpose(1, 2)
        return tokens, n_tokens
class MultiHeadAttention3D(nn.Module):
    """Standard multi-head self-attention over a flattened token sequence.

    Dropout is applied to the attention weights only; the output
    projection is not followed by dropout.
    """

    def __init__(self, dim, heads=12, drop=0.):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        batch, n_tok, dim = x.shape
        head_dim = dim // self.heads
        # Project once, then split into per-head q, k, v:
        # (B, N, 3*C) -> (3, B, heads, N, head_dim).
        qkv = self.qkv(x).view(batch, n_tok, 3, self.heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.drop(torch.softmax(scores, dim=-1))
        ctx = torch.matmul(weights, v)                # (B, heads, N, head_dim)
        ctx = ctx.transpose(1, 2).reshape(batch, n_tok, dim)
        return self.proj(ctx)
class TransformerBlock3D(nn.Module):
    """Pre-norm transformer encoder block.

    x -> x + Attn(LN(x)) -> x + MLP(LN(x)), with a GELU MLP whose hidden
    width is dim * mlp_r and dropout after each MLP linear.
    """

    def __init__(self, dim, heads, mlp_r=4., drop=0.):
        super().__init__()
        hidden = int(dim * mlp_r)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention3D(dim, heads, drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden, dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
class EnhancedTimeSformer(nn.Module):
    """TimeSformer-style video model that reconstructs a frame and a flow map.

    Pipeline: a factorized-3D-conv stem downsamples the clip 4x spatially,
    an optical-flow branch injects motion features, a Conv3d tubelet
    embedding tokenizes the volume, a transformer stack processes the
    tokens, and two MLP decoders reconstruct an RGB frame and a 2-channel
    flow field from a single middle token. Anomaly scoring from the
    reconstructions is presumably done by the caller — not visible here.

    Args:
        img_sz: input frame height/width (assumed square).
        p_sz: nominal patch size at full resolution; the embedding actually
            uses p_sz // 4 on the 4x-downsampled stem output, so the token
            grid matches a p_sz patching of the original frame.
        in_c: input channels per frame.
        n_fr: nominal clip length. NOTE(review): never stored or used —
            the forward pass handles whatever T it receives.
        emb: transformer embedding dimension.
        depth: number of transformer blocks.
        heads: attention heads per block.
    """
    def __init__(self, img_sz=224, p_sz=16, in_c=3, n_fr=16, emb=768, depth=12, heads=12):
        super().__init__()
        self.img_sz = img_sz
        self.p_sz = p_sz
        # Stem: two (2+1)D conv stages, each halving H and W (4x total).
        self.stem = nn.Sequential(
            Factorized3DConv(in_c, 64, 7, 2, 3),
            Factorized3DConv(64, 128, 3, 2, 1)
        )
        self.flow_est = OpticalFlowEstimator(in_c)
        # Encodes a 2-channel flow map to 128 channels, matching the stem output.
        self.flow_enc = nn.Sequential(
            nn.Conv2d(2, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 1, 1)
        )
        # Patch-embed the 4x-downsampled, 128-channel stem output with
        # tubelets of 2 frames x (p_sz // 4)^2 pixels.
        self.patch_emb = PatchEmbedding3D(img_sz // 4, p_sz // 4, 128, emb, 2)
        self.cls_tok = nn.Parameter(torch.zeros(1, 1, emb))
        # Fixed capacity of 2048 positions; forward() linearly interpolates
        # the table if the actual sequence is longer.
        self.pos_emb = nn.Parameter(torch.zeros(1, 2048, emb))
        self.blocks = nn.ModuleList([
            TransformerBlock3D(emb, heads, 4., 0.1) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(emb)
        # Side length of the spatial token grid; the decoders predict one
        # value per grid cell per output channel.
        self.recon_grid = (img_sz // 4) // (p_sz // 4)
        self.dec = nn.ModuleDict({
            'frame': nn.Sequential(
                nn.Linear(emb, emb * 2),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(emb * 2, self.recon_grid ** 2 * in_c)
            ),
            'flow': nn.Sequential(
                nn.Linear(emb, emb),
                nn.GELU(),
                nn.Linear(emb, self.recon_grid ** 2 * 2)
            )
        })
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        nn.init.trunc_normal_(self.cls_tok, std=0.02)

    def forward(self, x):
        """Reconstruct a frame and a flow map from clip x.

        Args:
            x: video tensor of shape (B, C, T, H, W).

        Returns:
            Tuple (frame, flow): frame is (B, C, H, W) and flow is
            (B, 2, H, W), both bilinearly upsampled from the
            recon_grid x recon_grid decoder output.
        """
        B, C, T, H, W = x.shape
        # Pairwise optical flow between consecutive frames: T-1 maps.
        flows = [self.flow_est(x[:, :, t], x[:, :, t+1]) for t in range(T-1)]
        if flows:
            flows = torch.stack(flows, 2)
            # Encode each flow map independently (fold time into batch).
            flows = rearrange(
                self.flow_enc(rearrange(flows, 'b c t h w -> (b t) c h w')),
                '(b t) c h w -> b c t h w',
                b=B
            )
        else:
            # Single-frame clip: no flow branch.
            flows = None
        x = self.stem(x)
        if flows is not None:
            # Resize flow features to the stem's spatial resolution.
            flows = F.interpolate(
                rearrange(flows, 'b c t h w -> (b t) c h w'),
                size=(x.shape[3], x.shape[4]),
                mode='bilinear',
                align_corners=False
            )
            flows = rearrange(flows, '(b t) c h w -> b c t h w', b=B)
            if flows.shape[2] < x.shape[2]:
                # Flow has T-1 steps vs. T frames: zero-pad the trailing
                # temporal slot (F.pad pads last dims first: W, H, then T).
                flows = F.pad(flows, (0, 0, 0, 0, 0, x.shape[2] - flows.shape[2]))
            # Inject motion features with a small fixed weight.
            x = x + 0.1 * flows
        x, n_p = self.patch_emb(x)
        x = torch.cat([self.cls_tok.expand(B, -1, -1), x], 1)
        if x.shape[1] > self.pos_emb.shape[1]:
            # Sequence longer than the 2048-slot table: interpolate along
            # the position axis.
            pe = F.interpolate(
                self.pos_emb.permute(0, 2, 1),
                x.shape[1],
                mode='linear',
                align_corners=False
            ).permute(0, 2, 1)
        else:
            pe = self.pos_emb[:, :x.shape[1]]
        x = x + pe
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # Decode from the middle token of the sequence.
        # NOTE(review): x includes the cls token at index 0, so x[:, mid]
        # is patch token mid-1, not the exact middle patch token (and is the
        # cls token itself when n_p < 2). Likely intended 1 + n_p // 2 —
        # confirm against the training code before changing, since trained
        # weights depend on this indexing.
        mid = n_p // 2
        tok = x[:, mid:mid+1]
        # Linear output is grouped (p1, p2, c) channel-last, then reshaped
        # to a (c, p1, p2) grid.
        fr = rearrange(
            self.dec['frame'](tok),
            'b 1 (p1 p2 c) -> b c p1 p2',
            p1=self.recon_grid,
            p2=self.recon_grid,
            c=C
        )
        fl = rearrange(
            self.dec['flow'](tok),
            'b 1 (p1 p2 c) -> b c p1 p2',
            p1=self.recon_grid,
            p2=self.recon_grid,
            c=2
        )
        # Upsample both reconstructions to the input resolution.
        fr = F.interpolate(fr, (H, W), mode='bilinear', align_corners=False)
        fl = F.interpolate(fl, (H, W), mode='bilinear', align_corners=False)
        return fr, fl
def create_model():
    """Factory: build an EnhancedTimeSformer with the default configuration."""
    config = {
        'img_sz': 224,
        'p_sz': 16,
        'in_c': 3,
        'n_fr': 16,
        'emb': 768,
        'depth': 12,
        'heads': 12,
    }
    return EnhancedTimeSformer(**config)