import gradio as gr
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import torchaudio
from facenet_pytorch import MTCNN
binary_labels = ['Real', 'Spoof']
binary_face_outputs = ['Real Face Detected', 'Spoof Face Detected']
binary_voice_outputs = ['Real Audio Detected','Spoof Audio Detected']
binary_finger_outputs = ['Real Fingerprint Detected','Spoof Fingerprint Detected']
multi_face_labels = ['Genuine','Printed Photo','Paper Cut','Replayed','3D Mask']
multi_face_outputs = ['Genuine Face Detected','Printed Photo Detected','Paper Cut Detected','Replayed Face Detected','3D Mask Detected']
multi_voice_outputs = ['Real Audio Detected','Text to Speech Detected','Voice Conversion Detected','Text to Speech + Voice Conversion Detected']
multi_voice_labels = ['Real','Text to Speech','Voice Conversion','Text to Speech + Voice Conversion']
# Transformer Architecture
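# Each ViT encoder block is extended with bottleneck adapters (run in ensembled pairs) and a feature-wise transformation (FWT) layer; see the classes below.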
def initialize_weights(m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
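# Bottleneck adapter: down-project to bottleneck_dim, GELU, up-project back, with a residual connection.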
class NativeAdapter(nn.Module):
def __init__(self, input_dim=1024, bottleneck_dim=64):
super(NativeAdapter, self).__init__()
self.linear1 = nn.Linear(input_dim, bottleneck_dim)
self.activ = nn.GELU()
self.linear2 = nn.Linear(bottleneck_dim, input_dim)
self.apply(initialize_weights)
def forward(self, x):
residual = x
out = self.linear1(x)
out = self.activ(out)
out = self.linear2(out)
return out + residual
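# Runs two bottleneck adapters in parallel, averages their outputs, and also returns their mean cosine similarity (an auxiliary term that is discarded at inference time).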
class EnsembleAdapter(nn.Module):
def __init__(self):
super(EnsembleAdapter, self).__init__()
self.adapter1 = NativeAdapter()
self.adapter2 = NativeAdapter()
def forward(self, x):
out1 = self.adapter1(x)
out2 = self.adapter2(x)
out = (out1 + out2) / 2
cos_sim = torch.nn.functional.cosine_similarity(out1, out2, dim=-1)
cos_sim_loss = cos_sim.mean()
return out, cos_sim_loss
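# Feature-wise transformation: perturbs features with random scale (alpha) and shift (beta) noise whose magnitude is gated by learnable softplus-activated parameters.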
class FWTLayer(nn.Module):
def __init__(self, hidden_dim=1024, std=0.02):
super(FWTLayer, self).__init__()
self.hidden_dim = hidden_dim
self.std = std
self.W_alpha = nn.Parameter(torch.randn(hidden_dim))
self.W_beta = nn.Parameter(torch.randn(hidden_dim))
def forward(self, x):
alpha = torch.randn(self.hidden_dim).to(x.device) * self.std * F.softplus(self.W_alpha)
beta = torch.randn(self.hidden_dim).to(x.device) * self.std * F.softplus(self.W_beta)
x_transformed = x + alpha * x + beta
return x_transformed
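# Wraps a torchvision ViT encoder block: an ensemble adapter is inserted after self-attention and after the MLP, and the FWT perturbation is applied only in training mode.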
class UpdatedBlock(nn.Module):
def __init__(self, encoder_block):
super(UpdatedBlock, self).__init__()
self.ln_1 = encoder_block.ln_1
self.self_attention = encoder_block.self_attention
self.dropout = encoder_block.dropout
self.ensemble_adapter1 = EnsembleAdapter()
self.ln_2 = encoder_block.ln_2
self.mlp = encoder_block.mlp
self.ensemble_adapter2 = EnsembleAdapter()
self.fwt_layer = FWTLayer()
def forward(self, input):
x = self.ln_1(input)
x, _ = self.self_attention(x, x, x, need_weights=False)
x = self.dropout(x)
x, loss_1 = self.ensemble_adapter1(x)
x = x + input
y = self.ln_2(x)
y = self.mlp(y)
y, loss_2 = self.ensemble_adapter2(y)
out = x + y
if self.training:
out = self.fwt_layer(out)
return out, (loss_1 + loss_2) / 2
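# Same flow as torchvision's ViT Encoder, but each block also returns an adapter cosine-similarity term, summed over layers.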
class UpdatedEncoder(nn.Module):
def __init__(self, encoder):
super(UpdatedEncoder, self).__init__()
self.pos_embedding = encoder.pos_embedding
self.dropout = encoder.dropout
self.layers = nn.ModuleList([UpdatedBlock(layer) for layer in encoder.layers])
self.ln = encoder.ln
def forward(self, x):
out = x + self.pos_embedding
out = self.dropout(out)
total_loss = 0
for layer in self.layers:
out, loss = layer(out)
total_loss += loss
out = self.ln(out)
return out, total_loss
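# ViT wrapper that reuses the base model's patch embedding, class token and heads; returns (logits, layer-averaged adapter cosine-similarity term).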
class UpdatedViT(nn.Module):
def __init__(self, base_model):
super(UpdatedViT, self).__init__()
self.conv_proj = base_model.conv_proj
self.encoder = UpdatedEncoder(base_model.encoder)
self.heads = base_model.heads
self._process_input = base_model._process_input
self.class_token = base_model.class_token
def forward(self, x):
x = self._process_input(x)
n = x.shape[0]
batch_class_token = self.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1)
x, cos_loss = self.encoder(x)
x = x[:, 0]
x = self.heads(x)
return x, cos_loss / len(self.encoder.layers)
# Conformer Architecture
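# Conformer encoder over mel-spectrogram frames; frame embeddings are mean-pooled over time and classified by a single linear layer.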
class ConformerClassifier(torch.nn.Module):
def __init__(self, input_dim, num_classes, num_heads, ffn_dim, num_layers, depthwise_conv_kernel_size,dropout=0.0,use_group_norm=False,convolution_first=False):
super(ConformerClassifier, self).__init__()
self.conformer = torchaudio.models.Conformer(
input_dim=input_dim,
num_heads=num_heads,
ffn_dim=ffn_dim,
num_layers=num_layers,
depthwise_conv_kernel_size=depthwise_conv_kernel_size,
dropout=dropout,
use_group_norm=use_group_norm,
convolution_first=convolution_first
)
self.fc = torch.nn.Linear(input_dim, num_classes)
def forward(self, x, lengths):
x,length = self.conformer(x, lengths)
x = x.mean(dim=1)
x = self.fc(x)
return x
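# Model construction: each classifier below rebuilds its backbone with a task-specific head, loads a fine-tuned checkpoint on CPU, and is kept in eval mode.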
# Vision Transformer Binary Model
vit_model_binary = torchvision.models.vit_l_16(weights=None,progress=True)
vit_model_binary.heads=nn.Sequential(
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 2)
)
vit_binary_model = UpdatedViT(vit_model_binary)
vit_binary_model.load_state_dict(torch.load('Correct_vit_model_binary.pth',map_location='cpu'))
vit_binary_model.eval()
# Vision Transformer Multi Model
vit_model_multi = torchvision.models.vit_l_16(weights=None,progress=True)
vit_model_multi.heads=nn.Sequential(
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 5)
)
vit_multi_model = UpdatedViT(vit_model_multi)
vit_multi_model.load_state_dict(torch.load('multi_vit_model.pth',map_location='cpu'))
vit_multi_model.eval()
# ConvNext Binary Model
convnext_binary_model = torchvision.models.convnext_base(weights=None,progress=False)
convnext_binary_model.classifier[2]=nn.Sequential(
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 2),
)
convnext_binary_model.load_state_dict(torch.load('binary_convnext_model.pth',map_location='cpu'))
convnext_binary_model.eval()
# ConvNext Multi Model
convnext_multi_model = torchvision.models.convnext_base(weights=None,progress=False)
convnext_multi_model.classifier[2]=nn.Sequential(
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 5),
)
convnext_multi_model.load_state_dict(torch.load('multi_convnext_model.pth',map_location='cpu'))
convnext_multi_model.eval()
# Voice Binary Model
voice_binary_model = ConformerClassifier(
input_dim=80,
num_classes=2,
num_heads=4,
ffn_dim=128,
num_layers=4,
depthwise_conv_kernel_size=7,
dropout=0.3,
use_group_norm=False,
convolution_first=True
)
voice_binary_model.load_state_dict(torch.load('binary_voice_model.pth',map_location='cpu'))
voice_binary_model.eval()
# Voice Multi Model
voice_multi_model = ConformerClassifier(
input_dim=80,
num_classes=4,
num_heads=4,
ffn_dim=128,
num_layers=4,
depthwise_conv_kernel_size=31,
dropout=0.3,
use_group_norm=False,
convolution_first=True
)
voice_multi_model.load_state_dict(torch.load('multi_voice_model.pth',map_location='cpu'))
voice_multi_model.eval()
# Fingerprint Binary Model
finger_print_binary = torchvision.models.vit_l_16(weights=None,progress=True)
finger_print_binary.heads=nn.Sequential(
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 2)
)
finger_print_binary_model = UpdatedViT(finger_print_binary)
finger_print_binary_model.load_state_dict(torch.load('binary_finger_print_model.pth',map_location='cpu'))
finger_print_binary_model.eval()
# Image processing function
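# Detects a face with MTCNN, crops it (optionally padded by `extend` pixels), resizes to 224x224 and returns a (1, 3, 224, 224) tensor plus a face-found flag.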
def process_image(img, extend=0):
mtcnn = MTCNN(keep_all=False, device='cuda' if torch.cuda.is_available() else 'cpu')
boxes, _ = mtcnn.detect(img)
face_detected = boxes is not None
if face_detected:
real_w, real_h = img.size
box = boxes[0]
bbox = list(map(float, box))
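# MTCNN boxes are [x1, y1, x2, y2], so w1/h1 below actually hold the right and bottom edges of the face box.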
x1 = int(bbox[0])
y1 = int(bbox[1])
w1 = int(bbox[2])
h1 = int(bbox[3])
c1 = max(0, x1 - extend)
c2 = max(0, y1 - extend)
c3 = min(real_w, w1 + extend)
c4 = min(real_h, h1 + extend)
img = img.crop((c1, c2, c3, c4))
transformer = torchvision.transforms.Compose([
transforms.ToTensor(),
transforms.Resize((224, 224), antialias=True)
])
img = transformer(img)
return img.unsqueeze(0), face_detected
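# Fingerprint images are used as-is (no detector), so the "detected" flag is always True.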
def process_image_finger(img):
transformer = torchvision.transforms.Compose([
transforms.ToTensor(),
transforms.Resize((224, 224), antialias=True)
])
img = transformer(img)
return img.unsqueeze(0), True
# Audio processing function
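# Loads the file, downmixes to mono, computes an 80-bin mel spectrogram, pads or truncates to 400 frames, and returns a (1, 400, 80) tensor plus its frame length.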
def process_audio(audio):
waveform, sample_rate = torchaudio.load(audio)
if waveform.size(0) > 1:
waveform = waveform.mean(dim=0, keepdim=True)
mel_spectrogram = torchaudio.transforms.MelSpectrogram(n_mels=80)(waveform).squeeze(0)
num_frames = mel_spectrogram.size(1)
target_length = 400
if num_frames < target_length:
padding = target_length - num_frames
mel_spectrogram = torch.cat([mel_spectrogram, torch.zeros(mel_spectrogram.size(0), padding)], dim=1)
else:
mel_spectrogram = mel_spectrogram[:, :target_length]
mel_spectrogram = mel_spectrogram.transpose(0, 1)
length = torch.tensor([mel_spectrogram.size(0)])
return mel_spectrogram.unsqueeze(0) ,length
# Helper functions
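# HTML builders for the status panel: per-class probabilities plus the prediction and, for faces, the selected mode and model.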
def get_details(probs_dict):
string = ''
for key, value in probs_dict.items():
prob = round(value*100,2)
string += f'<h3><b>{key}:</b> {prob}%<br></h3>'
return string
def get_formated_output(output_category,output_print,mode,model,probs_dict):
string = '<div>'
string += f'''<h1><b><span style="color: blue">{output_category}</span></b> <span style="color: cornflowerblue">{output_print}%</span></h1>'''
string += f'''<br>'''
string += f'''<h2>Model Details:</h2>'''
string += f'''<h3><b>Mode:</b> {mode}</h3>'''
string += f'''<h3><b>Model:</b> {model}</h3>'''
string += f'''<br>'''
string += f'''<h2>Classification Details:</h2>'''
string += get_details(probs_dict)
string += '</div>'
return string
def get_output_details(category,output_prob,probs_dict):
string = '<div>'
string += f'''<h1><b><span style="color: blue">{category}</span></b> <span style="color: cornflowerblue">{output_prob}%</span></h1>'''
string += f'''<br>'''
string += f'''<h2>Classification Details:</h2>'''
string += get_details(probs_dict)
string += '</div>'
return string
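# Face pipeline: crop the detected face, run the chosen mode (binary/multi) and model (transformer/convnext), and return the cropped image plus an HTML report.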
def update_status(image,mode,model):
if image:
pil_image = Image.open(image)
extend = 20 if model == 'transformer' else 0
processed_image, is_face_detected = process_image(pil_image,extend)
if not is_face_detected:
return image, """<h1 style="color: red;font-size: 18px; font-weight: bold; ">No Face Detected</h1>"""
with torch.no_grad():
if mode == 'binary':
if model=='transformer':
output, _ = vit_binary_model(processed_image)
else:
output= convnext_binary_model(processed_image)
prob = torch.nn.functional.softmax(output[0], dim=0)
pred = torch.argmax(prob).item()
category = binary_labels[pred]
output_category = binary_face_outputs[pred]
output_prob = prob[pred].item()
probs_dict = {binary_labels[i]: prob[i].item() for i in range(len(binary_labels))}
else:
if model=='transformer':
output, _ = vit_multi_model(processed_image)
else:
output= convnext_multi_model(processed_image)
prob = torch.nn.functional.softmax(output[0], dim=0)
pred = torch.argmax(prob).item()
category = multi_face_labels[pred]
output_category = multi_face_outputs[pred]
output_prob = prob[pred].item()
probs_dict = {multi_face_labels[i]: prob[i].item() for i in range(len(multi_face_labels))}
to_pil = transforms.ToPILImage()
cropped_pil_image = to_pil(processed_image.squeeze(0))
output_print = round(output_prob*100,2)
return cropped_pil_image, get_formated_output(output_category,output_print,mode,model,probs_dict)
else:
return image, """<h1 style="color: red;font-size: 18px; font-weight: bold; ">No image uploaded yet.</h1>"""
def handle_button_click_face(image, mode, model):
image, status = update_status(image, mode, model)
return image, status
def handle_button_click_finger(image):
if image:
pil_image = Image.open(image)
processed_image, is_finger_detected = process_image_finger(pil_image)
if not is_finger_detected:
return image, """<h1 style="color: red;font-size: 18px; font-weight: bold; ">No Finger Print Detected</h1>"""
with torch.no_grad():
output, _ = finger_print_binary_model(processed_image)
prob = torch.nn.functional.softmax(output[0], dim=0)
pred = torch.argmax(prob).item()
output_category = binary_finger_outputs[pred]
output_prob = prob[pred].item()
probs_dict = {binary_labels[i]: prob[i].item() for i in range(len(binary_labels))}
output_prob = round(prob[pred].item()*100,2)
return get_output_details(output_category,output_prob,probs_dict)
else:
return """<h1 style="color: red;font-size: 18px; font-weight: bold; ">No image file uploaded yet.</h1>"""
def handle_button_click_voice(audio,mode):
if audio:
mel_spectrogram, length = process_audio(audio)
with torch.no_grad():
if mode=='binary':
output = voice_binary_model(mel_spectrogram, length)
prob = torch.nn.functional.softmax(output[0], dim=0)
pred = torch.argmax(prob).item()
category = binary_voice_outputs[pred]
probs_dict = {binary_labels[i]: prob[i].item() for i in range(len(binary_labels))}
output_prob = round(prob[pred].item()*100,2)
else:
output = voice_multi_model(mel_spectrogram, length)
prob = torch.nn.functional.softmax(output[0], dim=0)
pred = torch.argmax(prob).item()
category = multi_voice_outputs[pred]
probs_dict = {multi_voice_labels[i]: prob[i].item() for i in range(len(multi_voice_labels))}
output_prob = round(prob[pred].item()*100,2)
return get_output_details(category,output_prob,probs_dict)
else:
return """<h1 style="color: red;font-size: 18px; font-weight: bold; ">No audio file uploaded yet.</h1>"""
def update_visibility(modality):
if modality == 'face':
return (
gr.update(visible=True), # dropdown_mode_face
gr.update(visible=True), # dropdown_model_face
gr.update(visible=True), # row_face
gr.update(visible=True), # process_button_face
gr.update(visible=False), # row_finger
gr.update(visible=False), # process_button_finger
gr.update(visible=False), # dropdown_mode_voice
gr.update(visible=False), # row_voice
gr.update(visible=False) # process_button_voice
)
elif modality == 'finger print':
return (
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=True),
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False)
)
elif modality == 'voice':
return (
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=True),
gr.update(visible=True),
gr.update(visible=True)
)
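# Gradio UI: a modality dropdown switches between the face, fingerprint and voice panels.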
with gr.Blocks() as demo:
gr.Markdown("# Multimodal Anti-Spoofing Detection")
dropdown_modality = gr.Dropdown(label="Choose modality", choices=['face','voice','finger print'], value="face")
dropdown_mode_face = gr.Dropdown(label="Choose mode", choices=['binary', 'multi'], value="binary", visible=True)
dropdown_model_face = gr.Dropdown(label="Choose model", choices=['transformer', 'convnext'], value="transformer", visible=True)
with gr.Row(visible=True) as row_face:
image_input_face = gr.Image(label="Upload Image", type="filepath")
status_text_face = gr.HTML(label="Output", value='''<h1 style="color: blue; font-size: 18px; font-weight: bold;">
Please upload an image and press the process button!
</h1>''')
process_button_face = gr.Button("Process Image", visible=True)
with gr.Row(visible=False) as row_finger:
image_input_finger = gr.Image(label="Upload Image", type="filepath")
status_text_finger = gr.HTML(label="Output", value='''<h1 style="color: blue; font-size: 18px; font-weight: bold;">
Please upload a fingerprint image and press the process button!
</h1>''')
process_button_finger = gr.Button("Process Image", visible=False)
dropdown_mode_voice = gr.Dropdown(label="Choose mode", choices=['binary', 'multi'], value="binary", visible=False)
with gr.Row(visible=False) as row_voice:
audio_input_voice = gr.File(label="Upload Audio", file_types=["audio"])
status_text_voice = gr.HTML(label="Output", value='''<h1 style="color: blue; font-size: 18px; font-weight: bold;">
Please upload an audio file and press the process button!
</h1>''')
process_button_voice = gr.Button("Process Audio", visible=False)
dropdown_modality.change(
fn=update_visibility,
inputs=[dropdown_modality],
outputs=[
dropdown_mode_face,
dropdown_model_face,
row_face,
process_button_face,
row_finger,
process_button_finger,
dropdown_mode_voice,
row_voice,
process_button_voice
]
)
process_button_face.click(
fn=handle_button_click_face,
inputs=[image_input_face, dropdown_mode_face, dropdown_model_face],
outputs=[image_input_face, status_text_face]
)
process_button_voice.click(
fn=handle_button_click_voice,
inputs=[audio_input_voice, dropdown_mode_voice],
outputs=status_text_voice
)
process_button_finger.click(
fn=handle_button_click_finger,
inputs=image_input_finger,
outputs=status_text_finger
)
demo.launch()