# abdullahsajid's picture
# Update app.py
# 53c7fc4 verified
from flask import Flask, jsonify, request
from flask_cors import CORS
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torchvision import transforms
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import base64
import io
from PIL import Image
from ultralytics import YOLO
from PIL import Image
from facenet_pytorch import MTCNN
from ultralytics import YOLO
app = Flask(__name__)
CORS(app)
# Class-name lists, indexed by the models' argmax outputs.
binary_labels = ['Real','Spoof']
multi_voice_labels = ['Real','Text to Speech','Voice Conversion','Text to Speech + Voice Conversion']
multi_face_labels = ['Genuine Face','Printed Photo','Paper Cut','Replayed Face','3D Mask']
multi_finger_print_labels = ['Real Fingerprint','Printed Image','Gelatin Mold','Silicone Mask']
# YOLO detector used only to confirm a fingerprint is present before classifying.
finger_print_detector = YOLO('fingerprint_best.pt')
# def process_audio(encoded_audio):
# decoded_audio = base64.b64decode(encoded_audio)
# audio_bytes = io.BytesIO(decoded_audio)
# waveform, sample_rate = torchaudio.load(audio_bytes)
# if waveform.size(0) > 1:
# waveform = waveform.mean(dim=0, keepdim=True) # Convert to mono by averaging channels
# mel_spectrogram = torchaudio.transforms.MelSpectrogram(n_mels=80)(waveform).squeeze(0)
# num_frames = mel_spectrogram.size(1)
# target_length = 400
# if num_frames < target_length:
# padding = target_length - num_frames
# mel_spectrogram = torch.cat([mel_spectrogram, torch.zeros(mel_spectrogram.size(0), padding)], dim=1)
# else:
# mel_spectrogram = mel_spectrogram[:, :target_length]
# mel_spectrogram = mel_spectrogram.transpose(0, 1)
# length = torch.tensor([mel_spectrogram.size(0)])
# return mel_spectrogram.unsqueeze(0) ,length
def process_audio(encoded_audio):
    """Decode base64 audio into a normalized 3-channel mel-spectrogram tensor.

    The clip is converted to an 80-bin mel spectrogram, padded/truncated to
    400 frames, resized to 224x224, replicated to 3 channels and normalized,
    yielding a (1, 3, 224, 224) tensor suitable for an ImageNet-style model.

    Args:
        encoded_audio: base64-encoded bytes of an audio file torchaudio can load.
    Returns:
        torch.Tensor of shape (1, 3, 224, 224).
    """
    decoded_audio = base64.b64decode(encoded_audio)
    audio_bytes = io.BytesIO(decoded_audio)
    waveform, sample_rate = torchaudio.load(audio_bytes)
    # Compute the spectrogram ONCE (the original re-ran the transform three
    # times behind an always-true len(...) > 0 guard) and keep channel 0:
    # result shape is (n_mels, time).
    mel = torchaudio.transforms.MelSpectrogram(n_mels=80)(waveform)[0].squeeze(0)
    target_length = 400
    num_frames = mel.size(1)
    if num_frames < target_length:
        # Zero-pad short clips along the time axis.
        mel = torch.cat([mel, torch.zeros(mel.size(0), target_length - num_frames)], dim=1)
    else:
        mel = mel[:, :target_length]
    # Resize to the vision backbone's expected spatial size.
    mel = torchvision.transforms.Resize((224, 224))(mel.unsqueeze(0)).squeeze(0)
    # Fake an RGB image by repeating the single channel, then normalize.
    mel = mel.unsqueeze(0).repeat(3, 1, 1)
    normalize = torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
    return normalize(mel).unsqueeze(0)
def process_image(base64_img, extend=0):
    """Decode a base64 image, detect a face with MTCNN, and prepare a tensor.

    Args:
        base64_img: base64-encoded image bytes.
        extend: unused; kept for backward compatibility with existing callers.
    Returns:
        (tensor, face_detected, bbox) where tensor is (1, 3, 224, 224),
        face_detected is True when MTCNN found a face, and bbox is
        [x1, y1, x2, y2] of the first detection or None.
    """
    image_data = base64.b64decode(base64_img)
    # Image.open(...).convert('RGB') always yields a PIL image, so the old
    # torch.Tensor / np.ndarray branches (which referenced a never-imported
    # cv2 and would have raised NameError) were dead code and are removed.
    img = Image.open(io.BytesIO(image_data)).convert('RGB')
    # NOTE(review): MTCNN is rebuilt on every request; hoisting it to module
    # scope would avoid repeated construction — confirm thread-safety first.
    mtcnn = MTCNN(keep_all=False, device='cuda' if torch.cuda.is_available() else 'cpu')
    boxes, _ = mtcnn.detect(img)
    face_detected = boxes is not None
    bbox = None
    if face_detected:
        # Crop to the first detected face; coords are (left, top, right, bottom).
        x1, y1, x2, y2 = (int(v) for v in map(float, boxes[0]))
        bbox = [x1, y1, x2, y2]
        img = img.crop((x1, y1, x2, y2))
    transformer = torchvision.transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True)
    ])
    img = transformer(img)
    return img.unsqueeze(0), face_detected, bbox
def process_fingerprint_image(base64_img):
    """Decode a base64 fingerprint image and check it with the YOLO detector.

    Args:
        base64_img: base64-encoded image bytes.
    Returns:
        (tensor, is_detected) where tensor is (1, 3, 224, 224) and
        is_detected is True when YOLO predicts at least one box of class 0.
    """
    image_data = base64.b64decode(base64_img)
    # Always a PIL image after convert('RGB'); the old torch.Tensor /
    # np.ndarray branches referenced a never-imported cv2 (latent NameError)
    # and could never execute, so they were removed.
    img = Image.open(io.BytesIO(image_data)).convert('RGB')
    transformer = torchvision.transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True)
    ])
    img = transformer(img)
    results = finger_print_detector(img.unsqueeze(0))
    # Class id 0 is treated as "fingerprint present" in the detector's labels.
    is_detected = any(np.array(results[0].boxes.cls.cpu()) == 0)
    return img.unsqueeze(0), is_detected
class ConformerClassifier(torch.nn.Module):
    """Audio classifier: a torchaudio Conformer encoder whose frame outputs
    are mean-pooled over time and mapped to class logits by a linear head."""

    def __init__(self, input_dim, num_classes, num_heads, ffn_dim, num_layers, depthwise_conv_kernel_size, dropout=0.0, use_group_norm=False, convolution_first=False):
        super().__init__()
        self.conformer = torchaudio.models.Conformer(
            input_dim=input_dim,
            num_heads=num_heads,
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            depthwise_conv_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
            use_group_norm=use_group_norm,
            convolution_first=convolution_first,
        )
        # Maps the time-pooled encoder features to per-class logits.
        self.fc = torch.nn.Linear(input_dim, num_classes)

    def forward(self, x, lengths):
        """x: (batch, time, input_dim); lengths: (batch,) valid frame counts."""
        encoded, _ = self.conformer(x, lengths)
        pooled = encoded.mean(dim=1)
        return self.fc(pooled)
def initialize_weights(m):
    """Xavier-initialize nn.Linear weights and zero their biases; no-op otherwise."""
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)
class NativeAdapter(nn.Module):
    """Residual bottleneck adapter: Linear -> GELU -> Linear plus a skip path."""

    def __init__(self, input_dim=1024, bottleneck_dim=64):
        super(NativeAdapter, self).__init__()
        self.linear1 = nn.Linear(input_dim, bottleneck_dim)
        self.activ = nn.GELU()
        self.linear2 = nn.Linear(bottleneck_dim, input_dim)
        # Xavier-init both linear layers and zero their biases.
        self.apply(initialize_weights)

    def forward(self, x):
        """Return x + linear2(gelu(linear1(x)))."""
        bottleneck = self.linear2(self.activ(self.linear1(x)))
        return bottleneck + x
class EnsembleAdapter(nn.Module):
    """Average of two NativeAdapters, plus their mean cosine similarity.

    Returns (mean of the two adapter outputs, scalar mean cosine similarity
    between the outputs) — the similarity is used upstream as an auxiliary
    loss term.
    """

    def __init__(self):
        super(EnsembleAdapter, self).__init__()
        self.adapter1 = NativeAdapter()
        self.adapter2 = NativeAdapter()

    def forward(self, x):
        a = self.adapter1(x)
        b = self.adapter2(x)
        # Scalar auxiliary term: mean cosine similarity of the two branches.
        similarity = F.cosine_similarity(a, b, dim=-1).mean()
        return (a + b) / 2, similarity
class FWTLayer(nn.Module):
    """Feature-wise transformation layer.

    Perturbs features as x * (1 + alpha) + beta, where alpha and beta are
    sampled per call from Gaussian noise scaled by `std` and gated by the
    softplus of learned per-dimension parameters.
    """

    def __init__(self, hidden_dim=1024, std=0.02):
        super(FWTLayer, self).__init__()
        self.hidden_dim = hidden_dim
        self.std = std
        self.W_alpha = nn.Parameter(torch.randn(hidden_dim))
        self.W_beta = nn.Parameter(torch.randn(hidden_dim))

    def forward(self, x):
        """Apply a freshly-sampled random scale/shift to x."""
        def _sample(weight):
            # softplus(weight) >= 0 keeps the learned gate non-negative.
            return torch.randn(self.hidden_dim).to(x.device) * self.std * F.softplus(weight)

        alpha = _sample(self.W_alpha)   # multiplicative perturbation
        beta = _sample(self.W_beta)     # additive perturbation
        return x + alpha * x + beta
class UpdatedBlock(nn.Module):
    """A ViT encoder block augmented with ensemble adapters and an FWT layer.

    Reuses the layer-norm / attention / dropout / MLP submodules of an
    existing encoder block, inserting an EnsembleAdapter after the attention
    branch and another after the MLP branch. During training the output is
    additionally perturbed by an FWTLayer. forward returns
    (output, mean of the two adapters' cosine-similarity terms).
    """

    def __init__(self, encoder_block):
        super(UpdatedBlock, self).__init__()
        # Submodules shared with (not copied from) the source block.
        self.ln_1 = encoder_block.ln_1
        self.self_attention = encoder_block.self_attention
        self.dropout = encoder_block.dropout
        self.ln_2 = encoder_block.ln_2
        self.mlp = encoder_block.mlp
        # Newly introduced trainable components.
        self.ensemble_adapter1 = EnsembleAdapter()
        self.ensemble_adapter2 = EnsembleAdapter()
        self.fwt_layer = FWTLayer()

    def forward(self, input):
        # Attention branch: norm -> self-attention -> dropout -> adapter.
        attn = self.ln_1(input)
        attn, _ = self.self_attention(attn, attn, attn, need_weights=False)
        attn, sim_attn = self.ensemble_adapter1(self.dropout(attn))
        residual = attn + input
        # MLP branch: norm -> MLP -> adapter, then second residual.
        hidden, sim_mlp = self.ensemble_adapter2(self.mlp(self.ln_2(residual)))
        out = residual + hidden
        # Random feature-wise perturbation only while training.
        if self.training:
            out = self.fwt_layer(out)
        return out, (sim_attn + sim_mlp) / 2
class UpdatedEncoder(nn.Module):
    """ViT encoder whose blocks are wrapped in UpdatedBlock.

    forward returns (encoded tokens, sum of every block's auxiliary
    cosine-similarity loss).
    """

    def __init__(self, encoder):
        super(UpdatedEncoder, self).__init__()
        # Positional embedding, dropout and final norm come from the original.
        self.pos_embedding = encoder.pos_embedding
        self.dropout = encoder.dropout
        self.ln = encoder.ln
        # Wrap every original block with the adapter-augmented variant.
        self.layers = nn.ModuleList(UpdatedBlock(block) for block in encoder.layers)

    def forward(self, x):
        out = self.dropout(x + self.pos_embedding)
        total_loss = 0
        for block in self.layers:
            out, block_loss = block(out)
            total_loss = total_loss + block_loss
        return self.ln(out), total_loss
class UpdatedViT(nn.Module):
    """A torchvision ViT with its encoder replaced by UpdatedEncoder.

    forward returns (logits, auxiliary cosine-similarity loss averaged over
    the encoder layers).
    """

    def __init__(self, base_model):
        super(UpdatedViT, self).__init__()
        # Patch projection, heads and class token are reused from the base model.
        self.conv_proj = base_model.conv_proj
        self.encoder = UpdatedEncoder(base_model.encoder)
        self.heads = base_model.heads
        self._process_input = base_model._process_input
        self.class_token = base_model.class_token

    def forward(self, x):
        tokens = self._process_input(x)
        batch = tokens.shape[0]
        # Prepend one class token per batch element.
        cls = self.class_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        encoded, cos_loss = self.encoder(tokens)
        # Classify from the class-token representation only.
        logits = self.heads(encoded[:, 0])
        return logits, cos_loss / len(self.encoder.layers)
# Voice Binary Model
# voice_binary_model = ConformerClassifier(
# input_dim=80,
# num_classes=2,
# num_heads=4,
# ffn_dim=128,
# num_layers=4,
# depthwise_conv_kernel_size=7,
# dropout=0.3,
# use_group_norm=False,
# convolution_first=True
# )
# voice_binary_model.load_state_dict(torch.load('binary_voice_model.pth',map_location='cpu'))
# voice_binary_model.eval()
# Voice anti-spoofing (binary): adapter-augmented ViT-L/16 fed with the
# mel-spectrogram "images" produced by process_audio(); 2-class head.
voice_model_binary = torchvision.models.vit_l_16(weights=None,progress=True)
voice_model_binary.heads=nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 2)
)
voice_binary_model = UpdatedViT(voice_model_binary)
voice_binary_model.load_state_dict(torch.load('voice_weights.pth',map_location='cpu'))
voice_binary_model.eval()
# Voice Multi Model: Conformer classifier over 80-bin mel-spectrogram frames.
# NOTE(review): this model expects (spectrogram, lengths) inputs like the
# commented-out process_audio() variant above produced; the active
# process_audio() returns image-style tensors, so this model is effectively
# unreachable from /api/voice while binary_mode is forced there — confirm.
voice_multi_model = ConformerClassifier(
    input_dim=80,
    num_classes=4,
    num_heads=4,
    ffn_dim=128,
    num_layers=4,
    depthwise_conv_kernel_size=31,
    dropout=0.3,
    use_group_norm=False,
    convolution_first=True
)
voice_multi_model.load_state_dict(torch.load('multi_voice_model.pth',map_location='cpu'))
voice_multi_model.eval()
# Face anti-spoofing ViT models: ViT-L/16 backbones (1024-dim features)
# wrapped in UpdatedViT with a two-layer classification head.
# Vision Transformer Binary Model (Real vs Spoof)
vit_model_binary = torchvision.models.vit_l_16(weights=None,progress=True)
vit_model_binary.heads=nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 2)
)
vit_binary_model = UpdatedViT(vit_model_binary)
vit_binary_model.load_state_dict(torch.load('binary_vit_model_correct.pth',map_location='cpu'))
vit_binary_model.eval()
# Vision Transformer Multi Model (5 face-spoof categories, see multi_face_labels)
vit_model_multi = torchvision.models.vit_l_16(weights=None,progress=True)
vit_model_multi.heads=nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 5)
)
vit_multi_model = UpdatedViT(vit_model_multi)
vit_multi_model.load_state_dict(torch.load('multi_vit_model_correct.pth',map_location='cpu'))
vit_multi_model.eval()
# ConvNeXt-Base alternatives for the face endpoint (default when the
# request does not ask for model=transformer).
# ConvNext Binary Model
convnext_binary_model = torchvision.models.convnext_base(weights=None,progress=False)
convnext_binary_model.classifier[2]=nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 2),
)
convnext_binary_model.load_state_dict(torch.load('binary_convnext_model_correct.pth',map_location='cpu'))
convnext_binary_model.eval()
# ConvNext Multi Model (5 face-spoof categories)
convnext_multi_model = torchvision.models.convnext_base(weights=None,progress=False)
convnext_multi_model.classifier[2]=nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 5),
)
convnext_multi_model.load_state_dict(torch.load('multi_convnext_model_correct.pth',map_location='cpu'))
convnext_multi_model.eval()
# Fingerprint anti-spoofing models: same UpdatedViT architecture as the
# face models, with 2-class and 4-class heads (see multi_finger_print_labels).
# Fingerprint Binary Model
fingerprint_binary = torchvision.models.vit_l_16(weights=None,progress=True)
fingerprint_binary.heads=nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 2)
)
binary_fingerprint_model = UpdatedViT(fingerprint_binary)
binary_fingerprint_model.load_state_dict(torch.load('binary_finger_print_correct1.pth',map_location='cpu'))
# NOTE(review): eval() on the base model is redundant — its submodules are
# shared with the wrapper, which is set to eval() on the next line.
fingerprint_binary.eval()
binary_fingerprint_model.eval()
# Fingerprint Multi Model
fingerprint_multi = torchvision.models.vit_l_16(weights=None,progress=True)
fingerprint_multi.heads=nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 4)
)
multi_fingerprint_model = UpdatedViT(fingerprint_multi)
multi_fingerprint_model.load_state_dict(torch.load('multi_finger_print_correct1.pth',map_location='cpu'))
fingerprint_multi.eval()
multi_fingerprint_model.eval()
print('Models Loaded Successfully')
@app.route('/')
def home():
    """Landing / health-check endpoint."""
    greeting = "Welcome to the Antispoofing Solutions!"
    return greeting
# @app.route('/api/voice', methods=['POST'])
# def post_api_voice():
# try:
# binary_mode = request.args.get('binary', 'False').lower() == 'true'
# data = request.json
# if not data or 'base64' not in data:
# return jsonify({'error': 'Invalid input. No base64 data provided.','status':400}), 400
# encoded_audio = data['base64']
# mel_spectrogram, length = process_audio(encoded_audio)
# with torch.no_grad():
# if binary_mode:
# output = voice_binary_model(mel_spectrogram, length)
# prob = torch.nn.functional.softmax(output[0], dim=0)
# pred = torch.argmax(prob).item()
# category = binary_labels[pred]
# probs_dict = {binary_labels[i]: prob[i].item()*100 for i in range(len(binary_labels))}
# else:
# output = voice_multi_model(mel_spectrogram, length)
# prob = torch.nn.functional.softmax(output[0], dim=0)
# pred = torch.argmax(prob).item()
# category = multi_voice_labels[pred]
# probs_dict = {multi_voice_labels[i]: prob[i].item()*100 for i in range(len(multi_voice_labels))}
# mode = 'binary' if binary_mode else 'multi'
# response = {
# 'message': 'Data received!',
# 'class': category,
# 'mode' : mode,
# 'probs': probs_dict,
# 'status':200
# }
# return jsonify(response), 201
# except KeyError as e:
# return jsonify({'error': f'Missing key: {str(e)}','status':400}), 400
# except Exception as e:
# return jsonify({'error': str(e),'status':400}), 500
@app.route('/api/voice', methods=['POST'])
def post_api_voice():
    """Classify a base64-encoded audio clip as Real or Spoof.

    Body: JSON {"base64": "<base64 audio bytes>"}.
    Returns 201 with {class, mode, probs, status} on success.

    NOTE(review): binary_mode is hard-coded True. The old multi-class branch
    referenced an undefined `length` variable (a latent NameError) and fed
    the Conformer model image-style tensors it cannot consume, so it is now
    an explicit error until the Conformer preprocessing is restored.
    """
    try:
        binary_mode = True  # multi-class voice mode disabled (see docstring)
        data = request.json
        if not data or 'base64' not in data:
            return jsonify({'error': 'Invalid input. No base64 data provided.','status':400}), 400
        mel_spectrogram = process_audio(data['base64'])
        with torch.no_grad():
            if binary_mode:
                # UpdatedViT returns (logits, aux_loss); keep only the logits.
                output, _ = voice_binary_model(mel_spectrogram)
                labels = binary_labels
            else:
                # Unreachable while binary_mode is True; fail loudly instead of
                # crashing on the previously undefined `length`.
                return jsonify({'error': 'Multi-class voice mode unavailable.','status':400}), 400
            prob = torch.nn.functional.softmax(output[0], dim=0)
            pred = torch.argmax(prob).item()
            category = labels[pred]
            probs_dict = {labels[i]: prob[i].item()*100 for i in range(len(labels))}
        response = {
            'message': 'Data received!',
            'class': category,
            'mode': 'binary' if binary_mode else 'multi',
            'probs': probs_dict,
            'status': 200
        }
        return jsonify(response), 201
    except KeyError as e:
        return jsonify({'error': f'Missing key: {str(e)}','status':400}), 400
    except Exception as e:
        return jsonify({'error': str(e),'status':400}), 500
@app.route('/api/face', methods=['POST'])
def post_api_face():
    """Classify a base64-encoded face image as genuine or spoofed.

    Query params:
        binary: 'true' selects Real/Spoof; anything else the 5-way labels.
        model:  'transformer' selects the adapter ViT; otherwise ConvNeXt.
    Body: JSON {"base64": "<base64 image bytes>"}.
    Returns 201 with {class, probs, model, mode, bbox, status}, or 400 when
    input is missing or no face is detected.
    """
    try:
        binary_mode = request.args.get('binary', 'False').lower() == 'true'
        model_name = request.args.get('model', 'convnext').lower()
        data = request.json
        if not data or 'base64' not in data:
            return jsonify({'error': 'Invalid input. No base64 data provided.'}), 400

        # Decode, crop to the detected face and build the model input.
        face_tensor, face_found, bbox = process_image(data['base64'])
        if not face_found:
            return jsonify({'error': 'No Face Detected.', 'status': 400}), 400

        with torch.no_grad():
            if binary_mode:
                labels = binary_labels
                if model_name == 'transformer':
                    output, _ = vit_binary_model(face_tensor)
                else:
                    output = convnext_binary_model(face_tensor)
                    model_name = 'convnext'
            else:
                labels = multi_face_labels
                if model_name == 'transformer':
                    output, _ = vit_multi_model(face_tensor)
                else:
                    output = convnext_multi_model(face_tensor)
                    model_name = 'convnext'
            # Shared post-processing: softmax over logits, pick top class.
            prob = torch.nn.functional.softmax(output[0], dim=0)
            pred = torch.argmax(prob).item()
            category = labels[pred]
            probs_dict = {labels[i]: prob[i].item() * 100 for i in range(len(labels))}

        response = {
            'message': 'Data received!',
            'class': category,
            'probs': probs_dict,
            'model': model_name,
            'mode': 'binary' if binary_mode else 'multi',
            'bbox': bbox,
            'status': 200
        }
        return jsonify(response), 201
    except KeyError as e:
        return jsonify({'error': f'Missing key: {str(e)}', 'status': 400}), 400
    except Exception as e:
        return jsonify({'error': str(e), 'status': 400}), 500
@app.route('/api/fingerprint', methods=['POST'])
def post_api_fingerprint():
    """Classify a base64-encoded fingerprint image as real or spoofed.

    Query params:
        binary: 'true' selects Real/Spoof; anything else the 4-way labels.
    Body: JSON {"base64": "<base64 image bytes>"}.
    Returns 201 with {class, probs, status}, or 400 when input is missing or
    no fingerprint is detected by YOLO.

    Bug fix: this endpoint previously ran the FACE models
    (vit_binary_model / vit_multi_model) while the dedicated fingerprint
    models were loaded but unused; it now uses binary_fingerprint_model and
    multi_fingerprint_model, matching multi_finger_print_labels.
    """
    try:
        data = request.json
        binary_mode = request.args.get('binary', 'False').lower() == 'true'
        if not data or 'base64' not in data:
            return jsonify({'error': 'Invalid input. No base64 data provided.','status':400}), 400
        processed_image, is_detected = process_fingerprint_image(data['base64'])
        if not is_detected:
            return jsonify({'error': f'No Fingerprint Detected','status':400}), 400
        with torch.no_grad():
            if binary_mode:
                # Fingerprint-specific binary model (was vit_binary_model).
                output, _ = binary_fingerprint_model(processed_image)
                labels = binary_labels
            else:
                # Fingerprint-specific multi model (was vit_multi_model).
                output, _ = multi_fingerprint_model(processed_image)
                labels = multi_finger_print_labels
            prob = torch.nn.functional.softmax(output[0], dim=0)
            pred = torch.argmax(prob).item()
            category = labels[pred]
            probs_dict = {labels[i]: prob[i].item()*100 for i in range(len(labels))}
        response = {
            'message': 'Data received!',
            'class': category,
            'probs': probs_dict,
            'status': 200
        }
        return jsonify(response), 201
    except KeyError as e:
        return jsonify({'error': f'Missing key: {str(e)}','status':400}), 400
    except Exception as e:
        return jsonify({'error': str(e),'status':400}), 500