# deepfake_detection / cross_efficient_vit_model.py
# Pranithkumar7's picture
# Upload 11 files
# a972d65 verified
import torch
from torch import einsum, nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from efficientnet_pytorch import EfficientNet
# Default hyper-parameters consumed by CrossEfficientViT. Keys use hyphens and
# are looked up verbatim under config["model"].
CROSS_EFFICIENT_VIT_CONFIG = {
    "model": {
        # Input resolution and width of the final linear heads.
        "image-size": 224,
        "num-classes": 1,
        # Number of (small encoder, large encoder, cross-attention) stages.
        "depth": 4,
        # "sm" branch: token width, patch size and its transformer encoder.
        "sm-dim": 192,
        "sm-patch-size": 7,
        "sm-enc-depth": 2,
        "sm-enc-dim-head": 64,
        "sm-enc-heads": 8,
        "sm-enc-mlp-dim": 2048,
        # "lg" branch: token width, patch size and its transformer encoder.
        "lg-dim": 384,
        "lg-patch-size": 56,
        "lg-enc-depth": 3,
        "lg-enc-dim-head": 64,
        "lg-enc-heads": 8,
        "lg-enc-mlp-dim": 2048,
        # Cross-attention between the CLS token of one branch and the patch
        # tokens of the other.
        "cross-attn-depth": 2,
        "cross-attn-dim-head": 64,
        "cross-attn-heads": 8,
        # Channel counts of the feature maps fed to each patch embedding.
        "lg-channels": 24,
        "sm-channels": 1280,
        # Dropout inside the encoders and after the token embedding.
        "dropout": 0.15,
        "emb-dropout": 0.15,
    },
}
def default(value, fallback):
    """Return ``value`` unless it is ``None``; otherwise return ``fallback``.

    Only ``None`` triggers the fallback, so falsy values such as ``0`` or an
    empty container are returned unchanged.
    """
    if value is None:
        return fallback
    return value
class EfficientNetBackbone(EfficientNet):
    """EfficientNet subclass that can return intermediate feature maps."""

    def delete_blocks(self, limit):
        """Do nothing: ``limit`` is accepted but ignored and no block is removed."""

    def extract_features_at_block(self, inputs, selected_block):
        """Run the stem and a prefix of the blocks, optionally followed by the head.

        Args:
            inputs: Batch passed to the convolutional stem.
            selected_block: Block index controlling where the loop stops.

        Returns:
            The feature map after the last executed stage. The head
            (``_conv_head`` -> ``_bn1`` -> swish) is applied only when
            ``selected_block`` is greater than or equal to the number of blocks.
        """
        features = self._swish(self._bn0(self._conv_stem(inputs)))
        total_blocks = len(self._blocks)

        for position, stage in enumerate(self._blocks):
            rate = self._global_params.drop_connect_rate
            if rate:
                # Scale drop-connect linearly with the block's position.
                rate *= float(position) / total_blocks
            features = stage(features, drop_connect_rate=rate)
            # NOTE(review): the comparison is `>` and happens after the block
            # has run, so blocks 0..selected_block + 1 are executed. Confirm
            # this is intended before changing it; trained weights may rely on
            # the current behaviour.
            if position > selected_block:
                break

        if selected_block >= total_blocks:
            features = self._swish(self._bn1(self._conv_head(features)))
        return features
class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input, then call ``fn`` on the result.

    Args:
        dim: Size of the normalised (last) dimension.
        fn: Callable invoked with the normalised tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalise ``x`` and forward it, with any keyword arguments, to ``fn``."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Feature size between the two linear layers.
        dropout: Probability used by both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        # Layer positions are kept as-is so parameter names (net.0, net.3)
        # stay stable.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` passed through the MLP; the last dimension stays ``dim``."""
        return self.net(x)
class Attention(nn.Module):
    """Multi-head scaled dot-product attention with an optional external context.

    Queries are always computed from ``x``; keys and values are computed from
    ``context`` (or from ``x`` when no context is given).

    Args:
        dim: Feature size of the input and output tokens.
        heads: Number of attention heads.
        dim_head: Feature size of each head.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head**-0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # Keys and values share one projection and are split afterwards.
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, context=None, kv_include_self=False):
        """Attend from ``x`` to ``context``.

        Args:
            x: Query tokens, laid out as (batch, tokens, dim).
            context: Tokens that provide keys and values; ``x`` is used when
                this is ``None``.
            kv_include_self: When true, ``x`` is prepended to ``context`` along
                the token dimension before keys and values are computed. With
                ``context=None`` this means ``x`` appears twice.

        Returns:
            Tensor with the same (batch, tokens, dim) layout as ``x``.
        """
        # Read the head count directly instead of unpacking it alongside
        # x.shape into throwaway names.
        heads = self.heads

        context = default(context, x)
        if kv_include_self:
            context = torch.cat((x, context), dim=1)

        keys, values = self.to_kv(context).chunk(2, dim=-1)
        q, k, v = (
            rearrange(t, "b n (h d) -> b h n d", h=heads)
            for t in (self.to_q(x), keys, values)
        )

        dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
        attn = self.attend(dots)

        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
class Transformer(nn.Module):
    """Stack of pre-norm attention and feed-forward layers with a final LayerNorm.

    Args:
        dim: Token feature size.
        depth: Number of (attention, feed-forward) pairs.
        heads: Attention heads per layer.
        dim_head: Feature size of each attention head.
        mlp_dim: Hidden size of each feed-forward block.
        dropout: Dropout probability passed to attention and feed-forward.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            attention = PreNorm(
                dim,
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
            )
            feed_forward = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(self, x):
        """Apply every layer with residual connections, then the final norm."""
        for attention, feed_forward in self.layers:
            x = attention(x) + x
            x = feed_forward(x) + x
        return self.norm(x)
class ProjectInOut(nn.Module):
    """Wrap ``fn`` with linear projections into and back out of its feature size.

    When ``dim_in`` equals ``dim_out`` both projections are ``nn.Identity``.

    Args:
        dim_in: Feature size of the tensors given to and returned by this module.
        dim_out: Feature size that ``fn`` operates on.
        fn: Callable applied between the two projections.
    """

    def __init__(self, dim_in, dim_out, fn):
        super().__init__()
        self.fn = fn

        if dim_in == dim_out:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()
        else:
            self.project_in = nn.Linear(dim_in, dim_out)
            self.project_out = nn.Linear(dim_out, dim_in)

    def forward(self, x, *args, **kwargs):
        """Project ``x``, call ``fn`` with the extra arguments, project back."""
        projected = self.project_in(x)
        transformed = self.fn(projected, *args, **kwargs)
        return self.project_out(transformed)
class CrossTransformer(nn.Module):
    """Exchange information between two token streams through their first tokens.

    Each layer lets the first ("CLS") token of one stream attend to the
    remaining tokens of the other stream, in both directions.

    Args:
        sm_dim: Feature size of the small-branch tokens.
        lg_dim: Feature size of the large-branch tokens.
        depth: Number of cross-attention layers.
        heads: Attention heads per layer.
        dim_head: Feature size of each attention head.
        dropout: Dropout probability passed to the attention modules.
    """

    def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
        super().__init__()

        def build(query_dim, context_dim):
            # Attention runs at the context width; the query token is projected
            # to that width on the way in and back on the way out.
            attention = Attention(
                context_dim, heads=heads, dim_head=dim_head, dropout=dropout
            )
            return ProjectInOut(query_dim, context_dim, PreNorm(context_dim, attention))

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            sm_attend_lg = build(sm_dim, lg_dim)
            lg_attend_sm = build(lg_dim, sm_dim)
            self.layers.append(nn.ModuleList([sm_attend_lg, lg_attend_sm]))

    def forward(self, sm_tokens, lg_tokens):
        """Update the first token of each stream and return both full streams.

        The tokens after position 0 are passed through unchanged.
        """
        sm_cls, sm_patch_tokens = sm_tokens[:, :1], sm_tokens[:, 1:]
        lg_cls, lg_patch_tokens = lg_tokens[:, :1], lg_tokens[:, 1:]

        for sm_attend_lg, lg_attend_sm in self.layers:
            sm_update = sm_attend_lg(
                sm_cls, context=lg_patch_tokens, kv_include_self=True
            )
            sm_cls = sm_update + sm_cls
            lg_update = lg_attend_sm(
                lg_cls, context=sm_patch_tokens, kv_include_self=True
            )
            lg_cls = lg_update + lg_cls

        sm_out = torch.cat((sm_cls, sm_patch_tokens), dim=1)
        lg_out = torch.cat((lg_cls, lg_patch_tokens), dim=1)
        return sm_out, lg_out
class MultiScaleEncoder(nn.Module):
    """Repeated stages of per-branch transformers followed by cross-attention.

    All arguments are keyword-only.

    Args:
        depth: Number of stages.
        sm_dim: Feature size of the small-branch tokens.
        lg_dim: Feature size of the large-branch tokens.
        sm_enc_params: Extra keyword arguments for the small-branch Transformer.
        lg_enc_params: Extra keyword arguments for the large-branch Transformer.
        cross_attn_heads: Attention heads in each CrossTransformer.
        cross_attn_depth: Layers in each CrossTransformer.
        cross_attn_dim_head: Head feature size in each CrossTransformer.
        dropout: Dropout probability shared by all sub-modules.
    """

    def __init__(
        self,
        *,
        depth,
        sm_dim,
        lg_dim,
        sm_enc_params,
        lg_enc_params,
        cross_attn_heads,
        cross_attn_depth,
        cross_attn_dim_head=64,
        dropout=0.0
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            sm_encoder = Transformer(dim=sm_dim, dropout=dropout, **sm_enc_params)
            lg_encoder = Transformer(dim=lg_dim, dropout=dropout, **lg_enc_params)
            cross_attention = CrossTransformer(
                sm_dim=sm_dim,
                lg_dim=lg_dim,
                depth=cross_attn_depth,
                heads=cross_attn_heads,
                dim_head=cross_attn_dim_head,
                dropout=dropout,
            )
            self.layers.append(
                nn.ModuleList([sm_encoder, lg_encoder, cross_attention])
            )

    def forward(self, sm_tokens, lg_tokens):
        """Run both token streams through every stage and return them as a pair."""
        for sm_enc, lg_enc, cross_attend in self.layers:
            sm_tokens = sm_enc(sm_tokens)
            lg_tokens = lg_enc(lg_tokens)
            sm_tokens, lg_tokens = cross_attend(sm_tokens, lg_tokens)
        return sm_tokens, lg_tokens
class ImageEmbedder(nn.Module):
    """Turn an image into a token sequence using EfficientNet features.

    The image is passed through an ``efficientnet-b0`` backbone, the resulting
    feature map is cut into patches that are linearly projected to ``dim``, a
    learned class token is prepended and positional embeddings are added.

    All arguments are keyword-only.

    Args:
        dim: Token feature size.
        image_size: Used with ``patch_size`` to size the positional embedding.
        patch_size: Side length of the patches cut from the feature map.
        dropout: Dropout probability applied to the final tokens.
        efficient_block: Block index handed to the backbone when extracting
            features.
        channels: Channel count expected in the extracted feature map.
    """

    def __init__(
        self,
        *,
        dim,
        image_size,
        patch_size,
        dropout=0.0,
        efficient_block=8,
        channels
    ):
        super().__init__()
        self.efficient_net = EfficientNetBackbone.from_name("efficientnet-b0")
        self.efficient_net.delete_blocks(efficient_block)
        self.efficient_block = efficient_block

        patches_per_side = image_size // patch_size
        num_patches = patches_per_side**2
        patch_dim = channels * patch_size**2

        flatten_patches = Rearrange(
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=patch_size,
            p2=patch_size,
        )
        self.to_patch_embedding = nn.Sequential(
            flatten_patches,
            nn.Linear(patch_dim, dim),
        )

        # One extra position for the class token. The table is sized from the
        # image resolution and sliced to the real token count in forward().
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, img):
        """Return the embedded tokens for ``img``, class token first."""
        feature_map = self.efficient_net.extract_features_at_block(
            img, self.efficient_block
        )
        tokens = self.to_patch_embedding(feature_map)
        batch_size, num_tokens, _ = tokens.shape

        cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=batch_size)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens += self.pos_embedding[:, : num_tokens + 1]
        return self.dropout(tokens)
class CrossEfficientViT(nn.Module):
    """Two-branch vision transformer over EfficientNet feature maps.

    A "small" and a "large" branch embed the same image separately, exchange
    information in a MultiScaleEncoder, and each branch classifies from its
    first token. The two head outputs are summed.

    Args:
        config: Mapping with a ``"model"`` entry holding the hyper-parameters;
            defaults to ``CROSS_EFFICIENT_VIT_CONFIG``. Keyword-only.
    """

    def __init__(self, *, config=CROSS_EFFICIENT_VIT_CONFIG):
        super().__init__()
        model_config = config["model"]

        image_size = model_config["image-size"]
        num_classes = model_config["num-classes"]
        sm_dim = model_config["sm-dim"]
        sm_channels = model_config["sm-channels"]
        lg_dim = model_config["lg-dim"]
        lg_channels = model_config["lg-channels"]
        sm_patch_size = model_config["sm-patch-size"]
        lg_patch_size = model_config["lg-patch-size"]

        # The two branches use different backbone block indices (16 and 1).
        self.sm_image_embedder = ImageEmbedder(
            dim=sm_dim,
            image_size=image_size,
            patch_size=sm_patch_size,
            dropout=model_config["emb-dropout"],
            efficient_block=16,
            channels=sm_channels,
        )
        self.lg_image_embedder = ImageEmbedder(
            dim=lg_dim,
            image_size=image_size,
            patch_size=lg_patch_size,
            dropout=model_config["emb-dropout"],
            efficient_block=1,
            channels=lg_channels,
        )

        sm_enc_params = {
            "depth": model_config["sm-enc-depth"],
            "heads": model_config["sm-enc-heads"],
            "mlp_dim": model_config["sm-enc-mlp-dim"],
            "dim_head": model_config["sm-enc-dim-head"],
        }
        lg_enc_params = {
            "depth": model_config["lg-enc-depth"],
            "heads": model_config["lg-enc-heads"],
            "mlp_dim": model_config["lg-enc-mlp-dim"],
            "dim_head": model_config["lg-enc-dim-head"],
        }
        self.multi_scale_encoder = MultiScaleEncoder(
            depth=model_config["depth"],
            sm_dim=sm_dim,
            lg_dim=lg_dim,
            cross_attn_heads=model_config["cross-attn-heads"],
            cross_attn_dim_head=model_config["cross-attn-dim-head"],
            cross_attn_depth=model_config["cross-attn-depth"],
            sm_enc_params=sm_enc_params,
            lg_enc_params=lg_enc_params,
            dropout=model_config["dropout"],
        )

        self.sm_mlp_head = nn.Sequential(
            nn.LayerNorm(sm_dim), nn.Linear(sm_dim, num_classes)
        )
        self.lg_mlp_head = nn.Sequential(
            nn.LayerNorm(lg_dim), nn.Linear(lg_dim, num_classes)
        )

    def forward(self, img):
        """Return the sum of both branch heads for ``img``.

        No activation is applied to the summed output here.
        """
        sm_tokens = self.sm_image_embedder(img)
        lg_tokens = self.lg_image_embedder(img)
        sm_tokens, lg_tokens = self.multi_scale_encoder(sm_tokens, lg_tokens)

        # Position 0 of each stream is the class token.
        sm_logits = self.sm_mlp_head(sm_tokens[:, 0])
        lg_logits = self.lg_mlp_head(lg_tokens[:, 0])
        return sm_logits + lg_logits