# Hugging Face Spaces status banner (extraction residue): Sleeping
| from flask import Flask, jsonify, request | |
| from flask_cors import CORS | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.models as models | |
| from torchvision import transforms | |
| import torchaudio | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import base64 | |
| import io | |
| from PIL import Image | |
| from ultralytics import YOLO | |
| from PIL import Image | |
# Flask application with CORS enabled so browser front-ends on other
# origins can call the API.
app = Flask(__name__)
CORS(app)

# Class-index -> human-readable label maps for each face anti-spoofing model.
idx_to_class_resnet50 = {0: "Genuine", 1: 'Printed Paper', 2: 'Replayed'}
# Fixed: original had a duplicated self-assignment
# (`idx_to_class_yolo9 = idx_to_class_yolo9 = {...}`).
idx_to_class_yolo9 = {0: 'Genuine', 1: 'Printed Paper', 2: 'Replayed', 3: 'Paper Mask'}
idx_to_class_resnet50_celeba = {0: "Genuine", 1: 'Printed Paper', 2: 'Paper Cut', 3: 'Replayed', 4: '3D Mask'}

# Labels for the binary voice anti-spoofing classifier (index = model output).
binary_labels = ['real', 'spoof']

# Preprocessing pipelines for the two ResNet-50 variants; both produce
# 224x224 tensors, matching ResNet-50's expected input size.
transform_data_resnet50 = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor()
])
transform_data_resnet50_celeba = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True)
])
def process_audio(encoded_audio, target_length=400, n_mels=80):
    """Decode a base64-encoded audio clip into a fixed-length Mel spectrogram.

    Parameters
    ----------
    encoded_audio : str
        Base64-encoded bytes of an audio file in any format torchaudio can load.
    target_length : int, optional
        Number of time frames the spectrogram is padded/truncated to.
        Default 400 — presumably the length the voice model was trained
        with; confirm before changing.
    n_mels : int, optional
        Number of Mel filterbanks. Default 80, matching the Conformer's
        input_dim below.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        A (1, target_length, n_mels) spectrogram batch and a one-element
        tensor holding the frame count (always target_length here).
    """
    decoded_audio = base64.b64decode(encoded_audio)
    audio_bytes = io.BytesIO(decoded_audio)
    waveform, sample_rate = torchaudio.load(audio_bytes)
    # Convert multi-channel audio to mono by averaging the channels.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(n_mels=n_mels)(waveform).squeeze(0)
    # Zero-pad (silence) or truncate so every clip has the same width.
    num_frames = mel_spectrogram.size(1)
    if num_frames < target_length:
        padding = target_length - num_frames
        mel_spectrogram = torch.cat([mel_spectrogram, torch.zeros(mel_spectrogram.size(0), padding)], dim=1)
    else:
        mel_spectrogram = mel_spectrogram[:, :target_length]
    # (n_mels, time) -> (time, n_mels): the Conformer expects time-major input.
    mel_spectrogram = mel_spectrogram.transpose(0, 1)
    length = torch.tensor([mel_spectrogram.size(0)])
    return mel_spectrogram.unsqueeze(0), length
# --- Image anti-spoofing models (CPU inference) ---

# ResNet-50 fine-tuned on ROSE (3 classes). weights=None (not the boolean
# False, which abuses the legacy `pretrained` compatibility path) skips
# downloading ImageNet weights before our checkpoint overwrites them anyway.
model_resnet50 = models.resnet50(weights=None)
num_classes = 3
model_resnet50.fc = nn.Linear(model_resnet50.fc.in_features, num_classes)
model_resnet50.load_state_dict(torch.load('resnet50_pytorch_rose_weights.pth', map_location=torch.device('cpu')))
model_resnet50.eval()

# ResNet-50 fine-tuned on CelebA-Spoof (5 classes).
model_resnet50_celeba = models.resnet50(weights=None)
num_classes = 5
model_resnet50_celeba.fc = nn.Linear(model_resnet50_celeba.fc.in_features, num_classes)
model_resnet50_celeba.load_state_dict(torch.load('resnet50_model_weights_celeba.pth', map_location=torch.device('cpu')))
model_resnet50_celeba.eval()

# YOLOv9 detector checkpoint for face-spoof detection.
model_yolo9 = YOLO('yolo9_best.pt')
class ConformerClassifier(torch.nn.Module):
    """Conformer encoder followed by a linear head for utterance classification.

    The Conformer output is mean-pooled over the time dimension before the
    final linear projection to class logits.
    """

    def __init__(self, input_dim, num_classes, num_heads, ffn_dim, num_layers,
                 depthwise_conv_kernel_size, dropout=0.0, use_group_norm=False,
                 convolution_first=False):
        super().__init__()  # modern zero-arg super()
        self.conformer = torchaudio.models.Conformer(
            input_dim=input_dim,
            num_heads=num_heads,
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            depthwise_conv_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
            use_group_norm=use_group_norm,
            convolution_first=convolution_first,
        )
        self.fc = torch.nn.Linear(input_dim, num_classes)

    def forward(self, x, lengths):
        """Classify a batch of feature sequences.

        `x` is assumed to be (batch, time, input_dim) and `lengths` the
        valid frame count per batch element — TODO confirm against caller.
        Returns raw logits of shape (batch, num_classes).
        """
        x, _ = self.conformer(x, lengths)  # output lengths are not needed
        x = x.mean(dim=1)                  # mean-pool over time
        return self.fc(x)
# Binary voice anti-spoofing classifier; the hyperparameters must match
# the configuration the checkpoint was trained with.
_VOICE_MODEL_CONFIG = dict(
    input_dim=80,
    num_classes=2,
    num_heads=4,
    ffn_dim=128,
    num_layers=4,
    depthwise_conv_kernel_size=7,
    dropout=0.3,
    use_group_norm=False,
    convolution_first=True,
)
voice_binary_model = ConformerClassifier(**_VOICE_MODEL_CONFIG)
voice_binary_model.load_state_dict(torch.load('binary_voice_model.pth', map_location='cpu'))
voice_binary_model.eval()
print('Models Loaded Successfully')
def home():
    """Return the landing-page greeting string.

    NOTE(review): no @app.route decorator is visible in this chunk —
    confirm the route registration was not lost during extraction.
    """
    return "Welcome to the Flask API!"
def get_data():
    """Return a demo JSON payload containing a base64-encoded test image.

    Loads 'test1.jpeg' from the working directory, re-encodes it as JPEG
    and embeds it as a base64 string alongside dummy fields.

    NOTE(review): no @app.route decorator is visible in this chunk —
    confirm the route registration was not lost during extraction.
    """
    # plt.imread on a JPEG already yields a uint8 array; the original
    # np.array(...).astype(np.uint8) round-trip was redundant.
    img_arr = np.asarray(plt.imread('test1.jpeg'), dtype=np.uint8)
    pil_img = Image.fromarray(img_arr)
    buffered = io.BytesIO()
    pil_img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    data = {
        'message': 'Hello, World!',
        'items': [1, 2, 3, 4, 5],
        'image': img_str
    }
    return jsonify(data)
def post_data():
    """Classify a base64-encoded face image with the selected model.

    Expects JSON: {"imageData": <base64 image>, "model": <str>} where
    "resnet" selects the ROSE ResNet-50, "resnet50" the CelebA ResNet-50,
    and anything else the YOLOv9 detector. Returns the predicted class
    name and its confidence in percent (201), or an error (400/500).

    NOTE(review): no @app.route decorator is visible in this chunk —
    confirm the route registration was not lost during extraction.
    """
    try:
        data = request.json
        if 'imageData' not in data or 'model' not in data:
            return jsonify({"error": "Missing required fields: 'imageData' or 'model'"}), 400
        base64_image = data['imageData']
        try:
            image_data = base64.b64decode(base64_image)
        except base64.binascii.Error:
            return jsonify({"error": "Invalid base64 string"}), 400
        try:
            image = Image.open(io.BytesIO(image_data)).convert('RGB')
        except IOError:
            return jsonify({"error": "Invalid image data"}), 400

        if data['model'] == 'resnet':
            name, prob = _classify_with_resnet(
                image, model_resnet50, transform_data_resnet50, idx_to_class_resnet50)
        elif data['model'] == 'resnet50':
            name, prob = _classify_with_resnet(
                image, model_resnet50_celeba, transform_data_resnet50_celeba, idx_to_class_resnet50_celeba)
        else:
            results = model_yolo9(image)
            name = 'not detectable'
            prob = 0.00
            # Keep the highest-confidence detection. The original loop
            # silently kept whichever box happened to come LAST, which is
            # arbitrary when multiple boxes are detected.
            for result in results[0].boxes:
                conf = round(result.conf.item() * 100, 2)
                if conf >= prob:
                    name = idx_to_class_yolo9[int(result.cls.item())]
                    prob = conf

        response = {
            'message': 'Data received!',
            'your_base64': data['imageData'],
            'class': name,
            'prob': prob
        }
        return jsonify(response), 201
    except Exception as e:
        # Fixed key: was "errorg", inconsistent with every other handler.
        return jsonify({"error": str(e)}), 500


def _classify_with_resnet(image, model, transform, idx_to_class):
    """Run one PIL image through a ResNet classifier; return (label, confidence %)."""
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)[0]
    probabilities = F.softmax(logits, dim=0)
    cat = torch.argmax(logits).item()
    return idx_to_class[cat], round((probabilities[cat] * 100).item(), 2)
def post_test_data():
    """Echo back the 'name' field of the posted JSON (connectivity test).

    NOTE(review): no @app.route decorator is visible in this chunk —
    confirm the route registration was not lost during extraction.
    """
    data = request.json
    # Validate instead of raising an unhandled KeyError/TypeError (which
    # would surface as an opaque 500 from Flask).
    if not data or 'name' not in data:
        return jsonify({'error': "Missing required field: 'name'"}), 400
    response = {
        'message': 'Data received!',
        'name': data['name']
    }
    return jsonify(response), 201
def post_api_voice():
    """Classify a base64-encoded audio clip as real or spoofed speech.

    Expects JSON: {"base64": <base64 audio>}. Returns the predicted label
    and the per-class probabilities (201), or an error (400/500).

    NOTE(review): no @app.route decorator is visible in this chunk —
    confirm the route registration was not lost during extraction.
    """
    try:
        data = request.json
        if not data or 'base64' not in data:
            return jsonify({'error': 'Invalid input. No base64 data provided.'}), 400
        encoded_audio = data['base64']
        # Decode the clip into a fixed-length Mel spectrogram batch.
        mel_spectrogram, length = process_audio(encoded_audio)
        with torch.no_grad():
            output = voice_binary_model(mel_spectrogram, length)
        # Use the file-level F alias for consistency with the image handlers
        # (was spelled torch.nn.functional.softmax here only).
        prob = F.softmax(output[0], dim=0)
        pred = torch.argmax(prob).item()
        category = binary_labels[pred]
        probs_dict = {label: p.item() for label, p in zip(binary_labels, prob)}
        response = {
            'message': 'Data received!',
            'class': category,
            'probs': probs_dict
        }
        return jsonify(response), 201
    except KeyError as e:
        return jsonify({'error': f'Missing key: {str(e)}'}), 400
    except Exception as e:
        return jsonify({'error': str(e)}), 500