from flask import Flask, jsonify, request
from flask_cors import CORS
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
import torchaudio
import base64
import binascii
import io
from PIL import Image
from ultralytics import YOLO
app = Flask(__name__)
CORS(app)
idx_to_class_resnet50 = {0: 'Genuine', 1: 'Printed Paper', 2: 'Replayed'}
idx_to_class_yolo9 = {0: 'Genuine', 1: 'Printed Paper', 2: 'Replayed', 3: 'Paper Mask'}
idx_to_class_resnet50_celeba = {0: 'Genuine', 1: 'Printed Paper', 2: 'Paper Cut', 3: 'Replayed', 4: '3D Mask'}
binary_labels = ['real','spoof']
transform_data_resnet50 = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor()
])
transform_data_resnet50_celeba = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True)
])
def process_audio(encoded_audio):
    """Decode a base64 audio payload into a fixed-length mel-spectrogram batch."""
    decoded_audio = base64.b64decode(encoded_audio)
    audio_bytes = io.BytesIO(decoded_audio)
    waveform, sample_rate = torchaudio.load(audio_bytes)
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)  # Convert to mono by averaging channels
    # Note: MelSpectrogram is left at its default 16 kHz sample rate; inputs at
    # other rates are not resampled here.
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(n_mels=80)(waveform).squeeze(0)
    # Pad or truncate along the time axis to a fixed 400 frames.
    num_frames = mel_spectrogram.size(1)
    target_length = 400
    if num_frames < target_length:
        padding = target_length - num_frames
        mel_spectrogram = torch.cat([mel_spectrogram, torch.zeros(mel_spectrogram.size(0), padding)], dim=1)
    else:
        mel_spectrogram = mel_spectrogram[:, :target_length]
    # The Conformer expects (batch, time, features).
    mel_spectrogram = mel_spectrogram.transpose(0, 1)
    length = torch.tensor([mel_spectrogram.size(0)])
    return mel_spectrogram.unsqueeze(0), length
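# Example (hypothetical local-testing helper, not used by the API itself):
# encode an audio file into the base64 string process_audio() expects.
#
#   def encode_audio(path):
#       with open(path, 'rb') as f:
#           return base64.b64encode(f.read()).decode()
#
#   mel, length = process_audio(encode_audio('sample.wav'))  # mel: (1, 400, 80)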
# ResNet-50 fine-tuned for 3-class face anti-spoofing (ROSE weights).
# Note: weights=None (no pretrained weights) is required here; torchvision's
# `weights` argument does not accept a bare False.
model_resnet50 = models.resnet50(weights=None)
num_classes = 3
model_resnet50.fc = nn.Linear(model_resnet50.fc.in_features, num_classes)
model_resnet50.load_state_dict(torch.load('resnet50_pytorch_rose_weights.pth', map_location=torch.device('cpu')))
model_resnet50.eval()

# ResNet-50 fine-tuned for 5-class face anti-spoofing (CelebA weights).
model_resnet50_celeba = models.resnet50(weights=None)
num_classes = 5
model_resnet50_celeba.fc = nn.Linear(model_resnet50_celeba.fc.in_features, num_classes)
model_resnet50_celeba.load_state_dict(torch.load('resnet50_model_weights_celeba.pth', map_location=torch.device('cpu')))
model_resnet50_celeba.eval()

# YOLOv9 detector for face anti-spoofing.
model_yolo9 = YOLO('yolo9_best.pt')
class ConformerClassifier(torch.nn.Module):
    """Conformer encoder with mean pooling over time and a linear classification head."""

    def __init__(self, input_dim, num_classes, num_heads, ffn_dim, num_layers,
                 depthwise_conv_kernel_size, dropout=0.0, use_group_norm=False,
                 convolution_first=False):
        super(ConformerClassifier, self).__init__()
        self.conformer = torchaudio.models.Conformer(
            input_dim=input_dim,
            num_heads=num_heads,
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            depthwise_conv_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
            use_group_norm=use_group_norm,
            convolution_first=convolution_first
        )
        self.fc = torch.nn.Linear(input_dim, num_classes)

    def forward(self, x, lengths):
        x, _ = self.conformer(x, lengths)  # (batch, time, input_dim)
        x = x.mean(dim=1)                  # Mean-pool over the time axis
        x = self.fc(x)
        return x
# Binary (real/spoof) voice classifier built on the Conformer encoder.
voice_binary_model = ConformerClassifier(
    input_dim=80,
    num_classes=2,
    num_heads=4,
    ffn_dim=128,
    num_layers=4,
    depthwise_conv_kernel_size=7,
    dropout=0.3,
    use_group_norm=False,
    convolution_first=True
)
voice_binary_model.load_state_dict(torch.load('binary_voice_model.pth', map_location='cpu'))
voice_binary_model.eval()
print('Models Loaded Successfully')
@app.route('/')
def home():
return "Welcome to the Flask API!"
@app.route('/api/face', methods=['GET'])
def get_data():
    # Return a sample payload with a base64-encoded test image.
    pil_img = Image.open('test1.jpeg').convert('RGB')
    buffered = io.BytesIO()
    pil_img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    data = {
        'message': 'Hello, World!',
        'items': [1, 2, 3, 4, 5],
        'image': img_str
    }
    return jsonify(data)
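# Example (a sketch; assumes Flask's default dev server on localhost:5000):
#   curl http://localhost:5000/api/face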
@app.route('/api/face', methods=['POST'])
def post_data():
try:
# Parse the JSON request
data = request.json
# Ensure the necessary fields are present
if 'imageData' not in data or 'model' not in data:
return jsonify({"error": "Missing required fields: 'imageData' or 'model'"}), 400
base64_image = data['imageData']
# Decode the image data
try:
image_data = base64.b64decode(base64_image)
        except binascii.Error:
return jsonify({"error": "Invalid base64 string"}), 400
# Convert image data to PIL image
try:
image = Image.open(io.BytesIO(image_data)).convert('RGB')
        except IOError:
return jsonify({"error": "Invalid image data"}), 400
# Model prediction logic
        if data['model'] == 'resnet':  # 3-class ResNet-50 (ROSE weights)
transform_img = transform_data_resnet50(image).unsqueeze(0)
with torch.no_grad():
pred = model_resnet50(transform_img)
probabilities = F.softmax(pred[0], dim=0)
cat = torch.argmax(pred[0]).item()
prob = round((probabilities[cat] * 100).item(), 2)
name = idx_to_class_resnet50[cat]
        elif data['model'] == 'resnet50':  # 5-class ResNet-50 (CelebA weights)
transform_img = transform_data_resnet50_celeba(image).unsqueeze(0)
with torch.no_grad():
pred = model_resnet50_celeba(transform_img)
probabilities = F.softmax(pred[0], dim=0)
cat = torch.argmax(pred[0]).item()
prob = round((probabilities[cat] * 100).item(), 2)
name = idx_to_class_resnet50_celeba[cat]
        else:
            # YOLOv9 detector: keep the highest-confidence detection rather
            # than whichever box happens to come last.
            results = model_yolo9(image)
            name = 'not detectable'
            prob = 0.00
            for result in results[0].boxes:
                conf = round(result.conf.item() * 100, 2)
                if conf > prob:
                    name = idx_to_class_yolo9[int(result.cls.item())]
                    prob = conf
# Return the successful response
response = {
'message': 'Data received!',
'your_base64': data['imageData'],
'class': name,
'prob': prob
}
return jsonify(response), 201
except Exception as e:
# Return an error response if something goes wrong
return jsonify({"errorg": str(e)}), 500
@app.route('/test', methods=['POST'])
def post_test_data():
data = request.json
response = {
'message': 'Data received!',
'name': data['name']
}
return jsonify(response), 201
@app.route('/api/voice', methods=['POST'])
def post_api_voice():
try:
data = request.json
if not data or 'base64' not in data:
return jsonify({'error': 'Invalid input. No base64 data provided.'}), 400
encoded_audio = data['base64']
# Process the audio to get Mel spectrogram and length
mel_spectrogram, length = process_audio(encoded_audio)
# Ensure the model and input dimensions are correct
with torch.no_grad():
output = voice_binary_model(mel_spectrogram, length)
        prob = F.softmax(output[0], dim=0)
pred = torch.argmax(prob).item()
category = binary_labels[pred]
probs_dict = {binary_labels[i]: prob[i].item() for i in range(len(binary_labels))}
response = {
'message': 'Data received!',
'class': category,
'probs': probs_dict
}
return jsonify(response), 201
except KeyError as e:
return jsonify({'error': f'Missing key: {str(e)}'}), 400
except Exception as e:
return jsonify({'error': str(e)}), 500
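# Example request for /api/voice (same assumptions as above):
#   curl -X POST http://localhost:5000/api/voice \
#        -H "Content-Type: application/json" \
#        -d '{"base64": "<base64-encoded audio clip>"}'

if __name__ == '__main__':
    # Local development entry point; a minimal sketch assuming the built-in
    # Flask dev server. Production deployments typically run the app through
    # a WSGI server such as gunicorn instead.
    app.run(host='0.0.0.0', port=5000)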