# NOTE(review): removed non-code extraction artifacts ("Spaces:" / "Sleeping" x2)
# that preceded the module and would not parse as Python.
| from functools import partial | |
| from timm.models import xception | |
| from model.common import SeparableConv2d, Block | |
| from model.common import GuidedAttention, GraphReasoning | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
# Registry of supported backbone encoders.
#   features: channel width of the encoder's final feature map (fed to the classifier head)
#   init_op:  zero-arg constructor; timm's Xception with ImageNet-pretrained weights
encoder_params = {
    "xception": {
        "features": 2048,
        "init_op": partial(xception, pretrained=True)
    }
}
class Recce(nn.Module):
    """End-to-End Reconstruction-Classification Learning for Face Forgery Detection.

    Wraps a (pretrained) Xception encoder with:
      * a reconstruction decoder (decoder1..decoder6) that regresses the input image,
      * a graph-reasoning fusion module and a guided-attention module injected
        between encoder stages,
      * a linear classification head over globally pooled features.

    Side channel: ``self.loss_inputs`` is (re)populated on every ``forward`` with
    intermediate tensors ('recons': reconstructed images, 'contra': correlation
    matrices) that an external training loop consumes to build auxiliary losses.
    """

    def __init__(self, num_classes, drop_rate=0.2):
        """
        Args:
            num_classes: output dimension of the final linear classifier.
            drop_rate: dropout probability, shared by the head and attention module.
        """
        super(Recce, self).__init__()
        self.name = "xception"
        self.loss_inputs = dict()
        # NOTE: downloads pretrained weights on first use (timm xception).
        self.encoder = encoder_params[self.name]["init_op"]()
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(drop_rate)
        self.fc = nn.Linear(encoder_params[self.name]["features"], num_classes)
        # 728 is the channel width of Xception's middle-flow blocks.
        self.attention = GuidedAttention(depth=728, drop_rate=drop_rate)
        self.reasoning = GraphReasoning(728, 256, 256, 256, 128, 256, [2, 4], drop_rate)

        # Reconstruction decoder: three x2 upsampling stages (728->256->128->64)
        # interleaved with residual Blocks, ending in a 1x1 conv + tanh to produce
        # a 3-channel image in [-1, 1].
        self.decoder1 = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            SeparableConv2d(728, 256, 3, 1, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        self.decoder2 = Block(256, 256, 3, 1)
        self.decoder3 = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            SeparableConv2d(256, 128, 3, 1, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.decoder4 = Block(128, 128, 3, 1)
        self.decoder5 = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            SeparableConv2d(128, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.decoder6 = nn.Sequential(
            nn.Conv2d(64, 3, 1, 1, bias=False),
            nn.Tanh()
        )

    def norm_n_corr(self, x):
        """Return (L2-normalized pooled embedding, pairwise-similarity matrix).

        The correlation is cosine similarity across the batch, affinely mapped
        from [-1, 1] into [0, 1].
        NOTE(review): ``squeeze()`` assumes batch size > 1 — with a single sample
        the matmul operates on a 1-D vector; confirm callers never pass batch=1.
        """
        norm_embed = F.normalize(self.global_pool(x), p=2, dim=1)
        corr = (torch.matmul(norm_embed.squeeze(), norm_embed.squeeze().T) + 1.) / 2.
        return norm_embed, corr

    # BUGFIX: this helper takes the raw tensor as its first argument but was
    # declared as a plain method without ``self``; the bound call
    # ``self.add_white_noise(x)`` would have passed the instance as ``tensor``
    # and the image as ``mean``. Declaring it @staticmethod restores the
    # intended signature without changing any call site.
    @staticmethod
    def add_white_noise(tensor, mean=0., std=1e-6):
        """Add tiny Gaussian noise to ~half the samples in the batch (train-time aug).

        A per-sample Bernoulli(0.5) mask gates the noise; the result is clipped
        back to the [-1, 1] image range.
        """
        rand = torch.rand([tensor.shape[0], 1, 1, 1])
        rand = torch.where(rand > 0.5, 1., 0.).to(tensor.device)
        white_noise = torch.normal(mean, std, size=tensor.shape, device=tensor.device)
        noise_t = tensor + white_noise * rand
        noise_t = torch.clip(noise_t, -1., 1.)
        return noise_t

    def forward(self, x):
        """Classify ``x``; stash reconstruction/contrastive tensors in ``self.loss_inputs``.

        Returns the raw logits of shape (batch, num_classes).
        """
        # Clear the loss inputs accumulated on the previous step.
        self.loss_inputs = dict(recons=[], contra=[])
        noise_x = self.add_white_noise(x) if self.training else x

        # Xception entry flow up to block4 -> shared 728-channel embedding.
        out = self.encoder.conv1(noise_x)
        out = self.encoder.bn1(out)
        out = self.encoder.act1(out)
        out = self.encoder.conv2(out)
        out = self.encoder.bn2(out)
        out = self.encoder.act2(out)
        out = self.encoder.block1(out)
        out = self.encoder.block2(out)
        out = self.encoder.block3(out)
        embedding = self.encoder.block4(out)

        norm_embed, corr = self.norm_n_corr(embedding)
        self.loss_inputs['contra'].append(corr)

        # Reconstruction branch; contrastive correlations are also collected
        # from two intermediate decoder stages.
        out = self.dropout(embedding)
        out = self.decoder1(out)
        out_d2 = self.decoder2(out)

        norm_embed, corr = self.norm_n_corr(out_d2)
        self.loss_inputs['contra'].append(corr)

        out = self.decoder3(out_d2)
        out_d4 = self.decoder4(out)

        norm_embed, corr = self.norm_n_corr(out_d4)
        self.loss_inputs['contra'].append(corr)

        out = self.decoder5(out_d4)
        pred = self.decoder6(out)

        # Upsample the reconstruction to the exact input resolution.
        recons_x = F.interpolate(pred, size=x.shape[-2:], mode='bilinear', align_corners=True)
        self.loss_inputs['recons'].append(recons_x)

        # Classification branch: continue the encoder, fusing in graph reasoning
        # (over decoder features) and reconstruction-guided attention.
        embedding = self.encoder.block5(embedding)
        embedding = self.encoder.block6(embedding)
        embedding = self.encoder.block7(embedding)

        fusion = self.reasoning(embedding, out_d2, out_d4) + embedding

        embedding = self.encoder.block8(fusion)
        img_att = self.attention(x, recons_x, embedding)

        embedding = self.encoder.block9(img_att)
        embedding = self.encoder.block10(embedding)
        embedding = self.encoder.block11(embedding)
        embedding = self.encoder.block12(embedding)

        embedding = self.encoder.conv3(embedding)
        embedding = self.encoder.bn3(embedding)
        embedding = self.encoder.act3(embedding)
        embedding = self.encoder.conv4(embedding)
        embedding = self.encoder.bn4(embedding)
        embedding = self.encoder.act4(embedding)

        embedding = self.global_pool(embedding).squeeze()

        out = self.dropout(embedding)
        return self.fc(out)