import gradio as gr
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import torchvision.models as models
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from facenet_pytorch import MTCNN
from ultralytics import YOLO


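# Class labels and the human-readable strings shown in the UI for each modality and mode.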
binary_labels = ['Real', 'Spoof']
binary_face_outputs = ['Real Face Detected', 'Spoof Face Detected']
binary_voice_outputs = ['Real Audio Detected', 'Spoof Audio Detected']
binary_finger_outputs = ['Real Finger Print Detected', 'Spoof Finger Print Detected']
multi_face_labels = ['Genuine', 'Printed Photo', 'Paper Cut', 'Replayed', '3D Mask']
multi_face_outputs = ['Genuine Face Detected', 'Printed Photo Detected', 'Paper Cut Detected', 'Replayed Face Detected', '3D Mask Detected']
multi_voice_outputs = ['Real Audio Detected', 'Text to Speech Detected', 'Voice Conversion Detected', 'Text to Speech + Voice Conversion Detected']
multi_voice_labels = ['Real', 'Text to Speech', 'Voice Conversion', 'Text to Speech + Voice Conversion']


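# Xavier-initialise the adapter's linear layers (applied via nn.Module.apply below).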
def initialize_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


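# Bottleneck adapter: linear down-projection -> GELU -> linear up-projection with a residual connection.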
class NativeAdapter(nn.Module):
    def __init__(self, input_dim=1024, bottleneck_dim=64):
        super(NativeAdapter, self).__init__()
        self.linear1 = nn.Linear(input_dim, bottleneck_dim)
        self.activ = nn.GELU()
        self.linear2 = nn.Linear(bottleneck_dim, input_dim)
        self.apply(initialize_weights)

    def forward(self, x):
        residual = x
        out = self.linear1(x)
        out = self.activ(out)
        out = self.linear2(out)
        return out + residual


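# Two parallel adapters whose outputs are averaged; the mean cosine similarity between them
# is returned as an auxiliary loss term.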
class EnsembleAdapter(nn.Module):
    def __init__(self):
        super(EnsembleAdapter, self).__init__()
        self.adapter1 = NativeAdapter()
        self.adapter2 = NativeAdapter()

    def forward(self, x):
        out1 = self.adapter1(x)
        out2 = self.adapter2(x)
        out = (out1 + out2) / 2
        cos_sim = torch.nn.functional.cosine_similarity(out1, out2, dim=-1)
        cos_sim_loss = cos_sim.mean()
        return out, cos_sim_loss


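# Feature-wise transformation layer: during training, features are perturbed with random
# affine noise whose scale is controlled by softplus-constrained learnable parameters.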
class FWTLayer(nn.Module):
    def __init__(self, hidden_dim=1024, std=0.02):
        super(FWTLayer, self).__init__()
        self.hidden_dim = hidden_dim
        self.std = std

        self.W_alpha = nn.Parameter(torch.randn(hidden_dim))
        self.W_beta = nn.Parameter(torch.randn(hidden_dim))

    def forward(self, x):
        alpha = torch.randn(self.hidden_dim).to(x.device) * self.std * F.softplus(self.W_alpha)
        beta = torch.randn(self.hidden_dim).to(x.device) * self.std * F.softplus(self.W_beta)

        x_transformed = x + alpha * x + beta
        return x_transformed


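# ViT encoder block rewired to insert an EnsembleAdapter after the attention and MLP
# sub-blocks; the FWT perturbation is applied only in training mode.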
class UpdatedBlock(nn.Module):
    def __init__(self, encoder_block):
        super(UpdatedBlock, self).__init__()
        self.ln_1 = encoder_block.ln_1
        self.self_attention = encoder_block.self_attention
        self.dropout = encoder_block.dropout
        self.ensemble_adapter1 = EnsembleAdapter()
        self.ln_2 = encoder_block.ln_2
        self.mlp = encoder_block.mlp
        self.ensemble_adapter2 = EnsembleAdapter()
        self.fwt_layer = FWTLayer()

    def forward(self, input):
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x, loss_1 = self.ensemble_adapter1(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        y, loss_2 = self.ensemble_adapter2(y)
        out = x + y
        if self.training:
            out = self.fwt_layer(out)
        return out, (loss_1 + loss_2) / 2


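# Drop-in replacement for the ViT encoder that wraps every original block in UpdatedBlock
# and accumulates the per-block adapter losses.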
class UpdatedEncoder(nn.Module):
    def __init__(self, encoder):
        super(UpdatedEncoder, self).__init__()
        self.pos_embedding = encoder.pos_embedding
        self.dropout = encoder.dropout
        self.layers = nn.ModuleList([UpdatedBlock(layer) for layer in encoder.layers])
        self.ln = encoder.ln

    def forward(self, x):
        out = x + self.pos_embedding
        out = self.dropout(out)
        total_loss = 0
        for layer in self.layers:
            out, loss = layer(out)
            total_loss += loss
        out = self.ln(out)
        return out, total_loss


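# Adapter-augmented ViT: reuses the patch projection, class token and classification heads of the
# torchvision model but routes tokens through UpdatedEncoder. Returns logits plus the
# cosine-similarity loss averaged over encoder layers.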
class UpdatedViT(nn.Module):
    def __init__(self, base_model):
        super(UpdatedViT, self).__init__()
        self.conv_proj = base_model.conv_proj
        self.encoder = UpdatedEncoder(base_model.encoder)
        self.heads = base_model.heads
        self._process_input = base_model._process_input
        self.class_token = base_model.class_token

    def forward(self, x):
        x = self._process_input(x)
        n = x.shape[0]
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)
        x, cos_loss = self.encoder(x)
        x = x[:, 0]
        x = self.heads(x)
        return x, cos_loss / len(self.encoder.layers)


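# Conformer encoder over mel-spectrogram frames, mean-pooled over time and followed by a
# linear classification head.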
class ConformerClassifier(torch.nn.Module):
    def __init__(self, input_dim, num_classes, num_heads, ffn_dim, num_layers,
                 depthwise_conv_kernel_size, dropout=0.0, use_group_norm=False,
                 convolution_first=False):
        super(ConformerClassifier, self).__init__()
        self.conformer = torchaudio.models.Conformer(
            input_dim=input_dim,
            num_heads=num_heads,
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            depthwise_conv_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
            use_group_norm=use_group_norm,
            convolution_first=convolution_first,
        )
        self.fc = torch.nn.Linear(input_dim, num_classes)

    def forward(self, x, lengths):
        x, _ = self.conformer(x, lengths)
        x = x.mean(dim=1)
        x = self.fc(x)
        return x


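# Binary (real/spoof) face model: ViT-L/16 with a custom classification head, wrapped with the
# adapter modules and loaded from a local checkpoint on CPU.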
vit_model_binary = torchvision.models.vit_l_16(weights=None, progress=True)
vit_model_binary.heads = nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 2)
)

vit_binary_model = UpdatedViT(vit_model_binary)
vit_binary_model.load_state_dict(torch.load('Correct_vit_model_binary.pth', map_location='cpu'))
vit_binary_model.eval()


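# Multi-class face model: same ViT-L/16 backbone with a 5-way head (genuine + 4 attack types).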
vit_model_multi = torchvision.models.vit_l_16(weights=None, progress=True)
vit_model_multi.heads = nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 5)
)

vit_multi_model = UpdatedViT(vit_model_multi)
vit_multi_model.load_state_dict(torch.load('multi_vit_model.pth', map_location='cpu'))
vit_multi_model.eval()


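# Binary face model based on ConvNeXt-Base (alternative to the transformer).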
convnext_binary_model = torchvision.models.convnext_base(weights=None, progress=False)
convnext_binary_model.classifier[2] = nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 2),
)
convnext_binary_model.load_state_dict(torch.load('binary_convnext_model.pth', map_location='cpu'))
convnext_binary_model.eval()


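# Multi-class face model based on ConvNeXt-Base.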
convnext_multi_model = torchvision.models.convnext_base(weights=None, progress=False)
convnext_multi_model.classifier[2] = nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 5),
)
convnext_multi_model.load_state_dict(torch.load('multi_convnext_model.pth', map_location='cpu'))
convnext_multi_model.eval()


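# Binary (real/spoof) voice model: Conformer classifier over 80-bin mel spectrograms.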
voice_binary_model = ConformerClassifier(
    input_dim=80,
    num_classes=2,
    num_heads=4,
    ffn_dim=128,
    num_layers=4,
    depthwise_conv_kernel_size=7,
    dropout=0.3,
    use_group_norm=False,
    convolution_first=True
)
voice_binary_model.load_state_dict(torch.load('binary_voice_model.pth', map_location='cpu'))
voice_binary_model.eval()


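# Multi-class voice model: distinguishes real audio, text-to-speech, voice conversion, and TTS + VC.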
voice_multi_model = ConformerClassifier(
    input_dim=80,
    num_classes=4,
    num_heads=4,
    ffn_dim=128,
    num_layers=4,
    depthwise_conv_kernel_size=31,
    dropout=0.3,
    use_group_norm=False,
    convolution_first=True
)

voice_multi_model.load_state_dict(torch.load('multi_voice_model.pth', map_location='cpu'))
voice_multi_model.eval()


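# Binary fingerprint model: the same adapter-augmented ViT-L/16 architecture as the face models.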
finger_print_binary = torchvision.models.vit_l_16(weights=None, progress=True)
finger_print_binary.heads = nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 2)
)

finger_print_binary_model = UpdatedViT(finger_print_binary)
finger_print_binary_model.load_state_dict(torch.load('binary_finger_print_model.pth', map_location='cpu'))
finger_print_binary_model.eval()


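# Detect the largest face with MTCNN, crop it (optionally padded by `extend` pixels),
# and return a (1, 3, 224, 224) tensor plus a flag telling whether a face was found.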
def process_image(img, extend=0):
    mtcnn = MTCNN(keep_all=False, device='cuda' if torch.cuda.is_available() else 'cpu')
    boxes, _ = mtcnn.detect(img)
    face_detected = boxes is not None
    if face_detected:
        real_w, real_h = img.size
        box = boxes[0]
        bbox = list(map(float, box))
        x1 = int(bbox[0])
        y1 = int(bbox[1])
        x2 = int(bbox[2])
        y2 = int(bbox[3])
        c1 = max(0, x1 - extend)
        c2 = max(0, y1 - extend)
        c3 = min(real_w, x2 + extend)
        c4 = min(real_h, y2 + extend)
        img = img.crop((c1, c2, c3, c4))

    transformer = torchvision.transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True)
    ])

    img = transformer(img)

    return img.unsqueeze(0), face_detected


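# Fingerprint images are used as-is: resize to 224x224 and return the tensor with a detection
# flag that is always True (no fingerprint detector is applied).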
def process_image_finger(img):
    transformer = torchvision.transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True)
    ])

    img = transformer(img)
    return img.unsqueeze(0), True


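# Convert an uploaded audio file to a mono 80-bin mel spectrogram, padded or truncated to
# 400 frames. MelSpectrogram is used with its default 16 kHz sample rate; the waveform is not
# resampled here.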
def process_audio(audio):
    waveform, sample_rate = torchaudio.load(audio)

    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    mel_spectrogram = torchaudio.transforms.MelSpectrogram(n_mels=80)(waveform).squeeze(0)
    num_frames = mel_spectrogram.size(1)
    target_length = 400

    if num_frames < target_length:
        padding = target_length - num_frames
        mel_spectrogram = torch.cat([mel_spectrogram, torch.zeros(mel_spectrogram.size(0), padding)], dim=1)
    else:
        mel_spectrogram = mel_spectrogram[:, :target_length]

    mel_spectrogram = mel_spectrogram.transpose(0, 1)
    length = torch.tensor([mel_spectrogram.size(0)])
    return mel_spectrogram.unsqueeze(0), length


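# Helpers that format the prediction and per-class probabilities as HTML for the Gradio output panel.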
def get_details(probs_dict):
    string = ''
    for key, value in probs_dict.items():
        prob = round(value * 100, 2)
        string += f'<h3><b>{key}:</b> {prob}%<br></h3>'
    return string


def get_formatted_output(output_category, output_print, mode, model, probs_dict):
    string = '<div>'
    string += f'''<h1><b><span style="color: blue">{output_category}</span></b> <span style="color: cornflowerblue">{output_print}%</span></h1>'''
    string += '<br>'
    string += '<h2>Model Details:</h2>'
    string += f'''<h3><b>Mode:</b> {mode}</h3>'''
    string += f'''<h3><b>Model:</b> {model}</h3>'''
    string += '<br>'
    string += '<h2>Classification Details:</h2>'
    string += get_details(probs_dict)
    string += '</div>'
    return string


def get_output_details(category, output_prob, probs_dict):
    string = '<div>'
    string += f'''<h1><b><span style="color: blue">{category}</span></b> <span style="color: cornflowerblue">{output_prob}%</span></h1>'''
    string += '<br>'
    string += '<h2>Classification Details:</h2>'
    string += get_details(probs_dict)
    string += '</div>'
    return string


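# Face pipeline: crop the face, run the selected model (binary/multi, transformer/convnext),
# and return the cropped image together with the formatted prediction HTML.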
def update_status(image, mode, model):
    if image:
        pil_image = Image.open(image)
        extend = 20 if model == 'transformer' else 0
        processed_image, is_face_detected = process_image(pil_image, extend)
        if not is_face_detected:
            return image, """<h1 style="color: red; font-size: 18px; font-weight: bold;">No Face Detected</h1>"""

        with torch.no_grad():
            if mode == 'binary':
                if model == 'transformer':
                    output, _ = vit_binary_model(processed_image)
                else:
                    output = convnext_binary_model(processed_image)

                prob = torch.nn.functional.softmax(output[0], dim=0)
                pred = torch.argmax(prob).item()
                category = binary_labels[pred]
                output_category = binary_face_outputs[pred]
                output_prob = prob[pred].item()
                probs_dict = {binary_labels[i]: prob[i].item() for i in range(len(binary_labels))}
            else:
                if model == 'transformer':
                    output, _ = vit_multi_model(processed_image)
                else:
                    output = convnext_multi_model(processed_image)
                prob = torch.nn.functional.softmax(output[0], dim=0)
                pred = torch.argmax(prob).item()
                category = multi_face_labels[pred]
                output_category = multi_face_outputs[pred]
                output_prob = prob[pred].item()
                probs_dict = {multi_face_labels[i]: prob[i].item() for i in range(len(multi_face_labels))}
        to_pil = transforms.ToPILImage()
        cropped_pil_image = to_pil(processed_image.squeeze(0))
        output_print = round(output_prob * 100, 2)
        return cropped_pil_image, get_formatted_output(output_category, output_print, mode, model, probs_dict)

    else:
        return image, """<h1 style="color: red; font-size: 18px; font-weight: bold;">No image uploaded yet.</h1>"""


def handle_button_click_face(image, mode, model):
    image, status = update_status(image, mode, model)
    return image, status


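# Fingerprint pipeline: preprocess the image, run the binary fingerprint model, and return the
# formatted prediction HTML (this handler feeds a single output component).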
def handle_button_click_finger(image):
    if image:
        pil_image = Image.open(image)
        processed_image, is_finger_detected = process_image_finger(pil_image)
        if not is_finger_detected:
            return """<h1 style="color: red; font-size: 18px; font-weight: bold;">No Finger Print Detected</h1>"""
        with torch.no_grad():
            output, _ = finger_print_binary_model(processed_image)
        prob = torch.nn.functional.softmax(output[0], dim=0)
        pred = torch.argmax(prob).item()
        output_category = binary_finger_outputs[pred]
        probs_dict = {binary_labels[i]: prob[i].item() for i in range(len(binary_labels))}
        output_prob = round(prob[pred].item() * 100, 2)
        return get_output_details(output_category, output_prob, probs_dict)
    else:
        return """<h1 style="color: red; font-size: 18px; font-weight: bold;">No image file uploaded yet.</h1>"""


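# Voice pipeline: convert the upload to a mel spectrogram and classify it with the binary or
# multi-class Conformer model.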
def handle_button_click_voice(audio, mode):
    if audio:
        mel_spectrogram, length = process_audio(audio)
        with torch.no_grad():
            if mode == 'binary':
                output = voice_binary_model(mel_spectrogram, length)
                prob = torch.nn.functional.softmax(output[0], dim=0)
                pred = torch.argmax(prob).item()
                category = binary_voice_outputs[pred]
                probs_dict = {binary_labels[i]: prob[i].item() for i in range(len(binary_labels))}
                output_prob = round(prob[pred].item() * 100, 2)
            else:
                output = voice_multi_model(mel_spectrogram, length)
                prob = torch.nn.functional.softmax(output[0], dim=0)
                pred = torch.argmax(prob).item()
                category = multi_voice_outputs[pred]
                probs_dict = {multi_voice_labels[i]: prob[i].item() for i in range(len(multi_voice_labels))}
                output_prob = round(prob[pred].item() * 100, 2)

        return get_output_details(category, output_prob, probs_dict)
    else:
        return """<h1 style="color: red; font-size: 18px; font-weight: bold;">No audio file uploaded yet.</h1>"""


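# Show/hide the controls of the selected modality. The tuple order must match the `outputs`
# list of dropdown_modality.change below.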
def update_visibility(modality):
    if modality == 'face':
        return (
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False)
        )
    elif modality == 'finger print':
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False)
        )
    elif modality == 'voice':
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=True)
        )


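# Gradio UI: a modality selector plus per-modality rows, dropdowns, and process buttons.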
with gr.Blocks() as demo:
    gr.Markdown("# Multimodal Anti-Spoofing Detection")

    dropdown_modality = gr.Dropdown(label="Choose modality", choices=['face', 'voice', 'finger print'], value="face")

    dropdown_mode_face = gr.Dropdown(label="Choose mode", choices=['binary', 'multi'], value="binary", visible=True)
    dropdown_model_face = gr.Dropdown(label="Choose model", choices=['transformer', 'convnext'], value="transformer", visible=True)

    with gr.Row(visible=True) as row_face:
        image_input_face = gr.Image(label="Upload Image", type="filepath")
        status_text_face = gr.HTML(label="Output", value='''<h1 style="color: blue; font-size: 18px; font-weight: bold;">
            Please upload an image and press the process button!
            </h1>''')

    process_button_face = gr.Button("Process Image", visible=True)

    with gr.Row(visible=False) as row_finger:
        image_input_finger = gr.Image(label="Upload Image", type="filepath")
        status_text_finger = gr.HTML(label="Output", value='''<h1 style="color: blue; font-size: 18px; font-weight: bold;">
            Please upload a fingerprint image and press the process button!
            </h1>''')

    process_button_finger = gr.Button("Process Image", visible=False)

    dropdown_mode_voice = gr.Dropdown(label="Choose mode", choices=['binary', 'multi'], value="binary", visible=False)
    with gr.Row(visible=False) as row_voice:
        audio_input_voice = gr.File(label="Upload Audio", file_types=["audio"])
        status_text_voice = gr.HTML(label="Output", value='''<h1 style="color: blue; font-size: 18px; font-weight: bold;">
            Please upload an audio file and press the process button!
            </h1>''')

    process_button_voice = gr.Button("Process Audio", visible=False)

    dropdown_modality.change(
        fn=update_visibility,
        inputs=[dropdown_modality],
        outputs=[
            dropdown_mode_face,
            dropdown_model_face,
            row_face,
            process_button_face,
            row_finger,
            process_button_finger,
            dropdown_mode_voice,
            row_voice,
            process_button_voice
        ]
    )

    process_button_face.click(
        fn=handle_button_click_face,
        inputs=[image_input_face, dropdown_mode_face, dropdown_model_face],
        outputs=[image_input_face, status_text_face]
    )

    process_button_voice.click(
        fn=handle_button_click_voice,
        inputs=[audio_input_voice, dropdown_mode_voice],
        outputs=status_text_voice
    )

    process_button_finger.click(
        fn=handle_button_click_finger,
        inputs=image_input_finger,
        outputs=status_text_finger
    )

demo.launch()