"""Flask inference service for anti-spoofing detection.

Exposes three endpoints:
    POST /api/voice       - binary real/spoof classification of base64 audio
    POST /api/face        - binary or multi-class face anti-spoofing (ViT or ConvNeXt)
    POST /api/fingerprint - binary or multi-class fingerprint anti-spoofing

All models are loaded once at import time from local weight files.
"""

from flask import Flask, jsonify, request
from flask_cors import CORS
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models  # kept: may be used elsewhere in the project
from torchvision import transforms
import torchaudio
import numpy as np
import matplotlib.pyplot as plt  # kept: may be used elsewhere in the project
import base64
import io
from PIL import Image
from ultralytics import YOLO
from facenet_pytorch import MTCNN

app = Flask(__name__)
CORS(app)

# Class-label lookup tables; index order must match each model's output head.
binary_labels = ['Real', 'Spoof']
multi_voice_labels = ['Real', 'Text to Speech', 'Voice Conversion',
                      'Text to Speech + Voice Conversion']
multi_face_labels = ['Genuine Face', 'Printed Photo', 'Paper Cut',
                     'Replayed Face', '3D Mask']
multi_finger_print_labels = ['Real Fingerprint', 'Printed Image',
                             'Gelatin Mold', 'Silicone Mask']

# YOLO detector that locates a fingerprint region (class 0) in the input image.
finger_print_detector = YOLO('fingerprint_best.pt')

# Face detector used by process_image(); constructed once instead of per request.
mtcnn = MTCNN(keep_all=False,
              device='cuda' if torch.cuda.is_available() else 'cpu')


def process_audio(encoded_audio):
    """Decode base64 audio into a normalized (1, 3, 224, 224) mel-spectrogram image.

    The spectrogram is padded/truncated to 400 frames, resized to 224x224,
    replicated to 3 channels and normalized so it can be fed to the ViT model.

    Args:
        encoded_audio: base64-encoded audio file bytes (any format torchaudio reads).

    Returns:
        torch.Tensor of shape (1, 3, 224, 224).
    """
    decoded_audio = base64.b64decode(encoded_audio)
    waveform, _sample_rate = torchaudio.load(io.BytesIO(decoded_audio))

    # Compute the mel spectrogram once (the original recomputed it three times)
    # and keep only the first channel: (80, num_frames).
    mel = torchaudio.transforms.MelSpectrogram(n_mels=80)(waveform)[0]

    # Pad with zeros or truncate along the time axis to exactly 400 frames.
    target_length = 400
    num_frames = mel.size(1)
    if num_frames < target_length:
        padding = torch.zeros(mel.size(0), target_length - num_frames)
        mel = torch.cat([mel, padding], dim=1)
    else:
        mel = mel[:, :target_length]

    # Resize to the ViT input resolution, replicate to 3 channels, normalize.
    mel = torchvision.transforms.Resize((224, 224))(mel.unsqueeze(0)).squeeze(0)
    mel = mel.unsqueeze(0).repeat(3, 1, 1)
    normalize = torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
    return normalize(mel).unsqueeze(0)


def process_image(base64_img, extend=0):
    """Decode a base64 image, detect a face with MTCNN and crop/resize it.

    Args:
        base64_img: base64-encoded image bytes.
        extend: unused margin parameter, kept for interface compatibility.

    Returns:
        (tensor, face_detected, bbox) where tensor is (1, 3, 224, 224) or None
        when no face is found, face_detected is a bool, and bbox is
        [x1, y1, x2, y2] in pixel coordinates or None.
    """
    image_data = base64.b64decode(base64_img)
    img = Image.open(io.BytesIO(image_data)).convert('RGB')

    boxes, _ = mtcnn.detect(img)
    if boxes is None:
        # No face: return cleanly so the caller can answer 400 instead of
        # crashing on a PIL image that has no .unsqueeze().
        return None, False, None

    # MTCNN boxes are [x1, y1, x2, y2]; PIL.crop expects (left, upper, right, lower).
    x1, y1, x2, y2 = (int(coord) for coord in boxes[0])
    face = img.crop((x1, y1, x2, y2))

    transformer = torchvision.transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
    ])
    tensor = transformer(face)
    return tensor.unsqueeze(0), True, [x1, y1, x2, y2]


def process_fingerprint_image(base64_img):
    """Decode a base64 image and check it contains a fingerprint via YOLO.

    Returns:
        (tensor, is_detected) where tensor is (1, 3, 224, 224) and is_detected
        is True when the YOLO detector found at least one class-0 (fingerprint) box.
    """
    image_data = base64.b64decode(base64_img)
    img = Image.open(io.BytesIO(image_data)).convert('RGB')

    transformer = torchvision.transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
    ])
    batch = transformer(img).unsqueeze(0)

    results = finger_print_detector(batch)
    is_detected = any(np.array(results[0].boxes.cls.cpu()) == 0)
    return batch, is_detected


class ConformerClassifier(torch.nn.Module):
    """Conformer encoder followed by mean pooling and a linear classifier."""

    def __init__(self, input_dim, num_classes, num_heads, ffn_dim, num_layers,
                 depthwise_conv_kernel_size, dropout=0.0, use_group_norm=False,
                 convolution_first=False):
        super(ConformerClassifier, self).__init__()
        self.conformer = torchaudio.models.Conformer(
            input_dim=input_dim,
            num_heads=num_heads,
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            depthwise_conv_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
            use_group_norm=use_group_norm,
            convolution_first=convolution_first,
        )
        self.fc = torch.nn.Linear(input_dim, num_classes)

    def forward(self, x, lengths):
        """x: (batch, frames, input_dim); lengths: (batch,) valid frame counts."""
        x, _ = self.conformer(x, lengths)
        x = x.mean(dim=1)  # average over the time dimension
        return self.fc(x)


def initialize_weights(m):
    """Xavier-initialize Linear layers; zero their biases."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


class NativeAdapter(nn.Module):
    """Bottleneck adapter (down-project, GELU, up-project) with a residual add."""

    def __init__(self, input_dim=1024, bottleneck_dim=64):
        super(NativeAdapter, self).__init__()
        self.linear1 = nn.Linear(input_dim, bottleneck_dim)
        self.activ = nn.GELU()
        self.linear2 = nn.Linear(bottleneck_dim, input_dim)
        self.apply(initialize_weights)

    def forward(self, x):
        residual = x
        out = self.linear2(self.activ(self.linear1(x)))
        return out + residual


class EnsembleAdapter(nn.Module):
    """Average of two NativeAdapters plus their cosine-similarity (diversity) loss."""

    def __init__(self):
        super(EnsembleAdapter, self).__init__()
        self.adapter1 = NativeAdapter()
        self.adapter2 = NativeAdapter()

    def forward(self, x):
        out1 = self.adapter1(x)
        out2 = self.adapter2(x)
        out = (out1 + out2) / 2
        # Mean cosine similarity between the two adapter outputs; used as an
        # auxiliary loss encouraging the adapters to diverge during training.
        cos_sim_loss = F.cosine_similarity(out1, out2, dim=-1).mean()
        return out, cos_sim_loss


class FWTLayer(nn.Module):
    """Feature-wise transform: adds learned, randomly scaled noise to features.

    Only invoked at training time (see UpdatedBlock.forward); the softplus keeps
    the learned noise scales positive.
    """

    def __init__(self, hidden_dim=1024, std=0.02):
        super(FWTLayer, self).__init__()
        self.hidden_dim = hidden_dim
        self.std = std
        self.W_alpha = nn.Parameter(torch.randn(hidden_dim))
        self.W_beta = nn.Parameter(torch.randn(hidden_dim))

    def forward(self, x):
        alpha = torch.randn(self.hidden_dim).to(x.device) * self.std * F.softplus(self.W_alpha)
        beta = torch.randn(self.hidden_dim).to(x.device) * self.std * F.softplus(self.W_beta)
        return x + alpha * x + beta


class UpdatedBlock(nn.Module):
    """A ViT encoder block augmented with ensemble adapters and an FWT layer.

    Reuses the pretrained block's sublayers and inserts an EnsembleAdapter after
    attention and after the MLP; FWT noise is applied only in training mode.
    """

    def __init__(self, encoder_block):
        super(UpdatedBlock, self).__init__()
        self.ln_1 = encoder_block.ln_1
        self.self_attention = encoder_block.self_attention
        self.dropout = encoder_block.dropout
        self.ensemble_adapter1 = EnsembleAdapter()
        self.ln_2 = encoder_block.ln_2
        self.mlp = encoder_block.mlp
        self.ensemble_adapter2 = EnsembleAdapter()
        self.fwt_layer = FWTLayer()

    def forward(self, input):
        # Attention sub-block with adapter, then residual connection.
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x, loss_1 = self.ensemble_adapter1(x)
        x = x + input
        # MLP sub-block with adapter, then residual connection.
        y = self.ln_2(x)
        y = self.mlp(y)
        y, loss_2 = self.ensemble_adapter2(y)
        out = x + y
        if self.training:
            out = self.fwt_layer(out)
        return out, (loss_1 + loss_2) / 2


class UpdatedEncoder(nn.Module):
    """ViT encoder whose blocks are wrapped in UpdatedBlock; sums adapter losses."""

    def __init__(self, encoder):
        super(UpdatedEncoder, self).__init__()
        self.pos_embedding = encoder.pos_embedding
        self.dropout = encoder.dropout
        self.layers = nn.ModuleList([UpdatedBlock(layer) for layer in encoder.layers])
        self.ln = encoder.ln

    def forward(self, x):
        out = self.dropout(x + self.pos_embedding)
        total_loss = 0
        for layer in self.layers:
            out, loss = layer(out)
            total_loss += loss
        return self.ln(out), total_loss


class UpdatedViT(nn.Module):
    """ViT-L/16 with its encoder replaced by UpdatedEncoder.

    forward() returns (logits, mean_adapter_cosine_loss).
    """

    def __init__(self, base_model):
        super(UpdatedViT, self).__init__()
        self.conv_proj = base_model.conv_proj
        self.encoder = UpdatedEncoder(base_model.encoder)
        self.heads = base_model.heads
        # Bound method of the base model; its submodules (conv_proj) are shared
        # with this wrapper so parameters stay consistent.
        self._process_input = base_model._process_input
        self.class_token = base_model.class_token

    def forward(self, x):
        x = self._process_input(x)
        n = x.shape[0]
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)
        x, cos_loss = self.encoder(x)
        x = x[:, 0]  # classification token
        x = self.heads(x)
        return x, cos_loss / len(self.encoder.layers)


def _vit_with_head(num_classes):
    """Build a ViT-L/16 backbone with a fresh 2-layer classification head."""
    vit = torchvision.models.vit_l_16(weights=None, progress=True)
    vit.heads = nn.Sequential(
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes),
    )
    return vit


def _convnext_with_head(num_classes):
    """Build a ConvNeXt-base backbone with a fresh 2-layer classification head."""
    net = torchvision.models.convnext_base(weights=None, progress=False)
    net.classifier[2] = nn.Sequential(
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes),
    )
    return net


# Voice binary model (adapter-augmented ViT over mel-spectrogram images).
voice_binary_model = UpdatedViT(_vit_with_head(2))
voice_binary_model.load_state_dict(torch.load('voice_weights.pth', map_location='cpu'))
voice_binary_model.eval()

# Voice multi-class model (Conformer over raw mel frames). NOTE(review): the
# current process_audio() produces ViT-style image tensors, which this model
# cannot consume; it is loaded but not served by any endpoint.
voice_multi_model = ConformerClassifier(
    input_dim=80,
    num_classes=4,
    num_heads=4,
    ffn_dim=128,
    num_layers=4,
    depthwise_conv_kernel_size=31,
    dropout=0.3,
    use_group_norm=False,
    convolution_first=True,
)
voice_multi_model.load_state_dict(torch.load('multi_voice_model.pth', map_location='cpu'))
voice_multi_model.eval()

# Face models: ViT (binary / 5-class) and ConvNeXt (binary / 5-class).
vit_binary_model = UpdatedViT(_vit_with_head(2))
vit_binary_model.load_state_dict(torch.load('binary_vit_model_correct.pth', map_location='cpu'))
vit_binary_model.eval()

vit_multi_model = UpdatedViT(_vit_with_head(5))
vit_multi_model.load_state_dict(torch.load('multi_vit_model_correct.pth', map_location='cpu'))
vit_multi_model.eval()

convnext_binary_model = _convnext_with_head(2)
convnext_binary_model.load_state_dict(torch.load('binary_convnext_model_correct.pth', map_location='cpu'))
convnext_binary_model.eval()

convnext_multi_model = _convnext_with_head(5)
convnext_multi_model.load_state_dict(torch.load('multi_convnext_model_correct.pth', map_location='cpu'))
convnext_multi_model.eval()

# Fingerprint models: ViT binary (2-class) and multi (4-class).
binary_fingerprint_model = UpdatedViT(_vit_with_head(2))
binary_fingerprint_model.load_state_dict(torch.load('binary_finger_print_correct1.pth', map_location='cpu'))
binary_fingerprint_model.eval()

multi_fingerprint_model = UpdatedViT(_vit_with_head(4))
multi_fingerprint_model.load_state_dict(torch.load('multi_finger_print_correct1.pth', map_location='cpu'))
multi_fingerprint_model.eval()

print('Models Loaded Successfully')


@app.route('/')
def home():
    """Health-check / landing endpoint."""
    return "Welcome to the Antispoofing Solutions!"


@app.route('/api/voice', methods=['POST'])
def post_api_voice():
    """Classify base64 audio as Real/Spoof.

    Only the binary mode is served: the multi-class Conformer expects the old
    (spectrogram, lengths) preprocessing and cannot consume the current
    image-style tensors (the original multi branch referenced an undefined
    `length` variable and was unreachable dead code).
    """
    try:
        data = request.json
        if not data or 'base64' not in data:
            return jsonify({'error': 'Invalid input. No base64 data provided.',
                            'status': 400}), 400
        mel_spectrogram = process_audio(data['base64'])
        with torch.no_grad():
            output, _ = voice_binary_model(mel_spectrogram)
        prob = F.softmax(output[0], dim=0)
        pred = torch.argmax(prob).item()
        response = {
            'message': 'Data received!',
            'class': binary_labels[pred],
            'mode': 'binary',
            'probs': {binary_labels[i]: prob[i].item() * 100
                      for i in range(len(binary_labels))},
            'status': 200,
        }
        return jsonify(response), 201
    except KeyError as e:
        return jsonify({'error': f'Missing key: {str(e)}', 'status': 400}), 400
    except Exception as e:
        return jsonify({'error': str(e), 'status': 400}), 500


@app.route('/api/face', methods=['POST'])
def post_api_face():
    """Classify a base64 face image.

    Query params:
        binary: 'true' for Real/Spoof, otherwise 5-class attack typing.
        model:  'transformer' for the ViT models, anything else -> ConvNeXt.
    """
    try:
        binary_mode = request.args.get('binary', 'False').lower() == 'true'
        model_name = request.args.get('model', 'convnext').lower()
        data = request.json
        if not data or 'base64' not in data:
            return jsonify({'error': 'Invalid input. No base64 data provided.'}), 400

        processed_image, is_face_detected, bbox = process_image(data['base64'])
        if not is_face_detected:
            return jsonify({'error': 'No Face Detected.', 'status': 400}), 400

        with torch.no_grad():
            if binary_mode:
                if model_name == 'transformer':
                    output, _ = vit_binary_model(processed_image)
                else:
                    output = convnext_binary_model(processed_image)
                    model_name = 'convnext'
                labels = binary_labels
            else:
                if model_name == 'transformer':
                    output, _ = vit_multi_model(processed_image)
                else:
                    output = convnext_multi_model(processed_image)
                    model_name = 'convnext'
                labels = multi_face_labels

        prob = F.softmax(output[0], dim=0)
        pred = torch.argmax(prob).item()
        response = {
            'message': 'Data received!',
            'class': labels[pred],
            'probs': {labels[i]: prob[i].item() * 100 for i in range(len(labels))},
            'model': model_name,
            'mode': 'binary' if binary_mode else 'multi',
            'bbox': bbox,
            'status': 200,
        }
        return jsonify(response), 201
    except KeyError as e:
        return jsonify({'error': f'Missing key: {str(e)}', 'status': 400}), 400
    except Exception as e:
        return jsonify({'error': str(e), 'status': 400}), 500


@app.route('/api/fingerprint', methods=['POST'])
def post_api_fingerprint():
    """Classify a base64 fingerprint image.

    Query params:
        binary: 'true' for Real/Spoof, otherwise 4-class attack typing.

    Bug fix: the original routed through the FACE ViT models
    (vit_binary_model / vit_multi_model); the multi face model emits 5 logits
    while multi_finger_print_labels has 4 entries, so a prediction of class 4
    raised IndexError and all probabilities were mislabeled. The dedicated
    fingerprint models loaded above are used instead.
    """
    try:
        data = request.json
        binary_mode = request.args.get('binary', 'False').lower() == 'true'
        if not data or 'base64' not in data:
            return jsonify({'error': 'Invalid input. No base64 data provided.',
                            'status': 400}), 400

        processed_image, is_detected = process_fingerprint_image(data['base64'])
        if not is_detected:
            return jsonify({'error': 'No Fingerprint Detected', 'status': 400}), 400

        with torch.no_grad():
            if binary_mode:
                output, _ = binary_fingerprint_model(processed_image)
                labels = binary_labels
            else:
                output, _ = multi_fingerprint_model(processed_image)
                labels = multi_finger_print_labels

        prob = F.softmax(output[0], dim=0)
        pred = torch.argmax(prob).item()
        response = {
            'message': 'Data received!',
            'class': labels[pred],
            'probs': {labels[i]: prob[i].item() * 100 for i in range(len(labels))},
            'status': 200,
        }
        return jsonify(response), 201
    except KeyError as e:
        return jsonify({'error': f'Missing key: {str(e)}', 'status': 400}), 400
    except Exception as e:
        return jsonify({'error': str(e), 'status': 400}), 500