abdullahsajid committed on
Commit
190dda7
·
verified ·
1 Parent(s): bb3fd34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -11
app.py CHANGED
@@ -1,15 +1,16 @@
1
  from flask import Flask, jsonify, request
2
  from flask_cors import CORS
3
- import numpy as np
4
- import matplotlib.pyplot as plt
5
- import base64
6
- import io
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
  import torchvision.models as models
11
- from PIL import Image
12
  from torchvision import transforms
 
 
 
 
 
 
13
  from ultralytics import YOLO
14
 
15
 
@@ -21,6 +22,9 @@ CORS(app)
21
  idx_to_class_resnet50 = {0 : "Genuine" , 1:'Printed Paper' , 2 : 'Replayed'}
22
  idx_to_class_yolo9 = idx_to_class_yolo9 = {0: 'Genuine', 1: 'Printed Paper', 2: 'Replayed', 3: 'Paper Mask'}
23
  idx_to_class_resnet50_celeba = {0 : "Genuine" , 1:'Printed Paper' , 2 : 'Paper Cut',3:'Replayed',4:'3D Mask'}
 
 
 
24
  transform_data_resnet50=transforms.Compose([
25
  transforms.Resize(size=(224,224)),
26
  transforms.ToTensor()
@@ -31,6 +35,28 @@ transform_data_resnet50_celeba=transforms.Compose([
31
  transforms.Resize((224,224), antialias=True)
32
  ])
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  model_resnet50 = models.resnet50(weights=False)
35
  num_classes = 3
36
  model_resnet50.fc = nn.Linear(model_resnet50.fc.in_features, num_classes)
@@ -44,12 +70,55 @@ model_resnet50_celeba.load_state_dict(torch.load('resnet50_model_weights_celeba.
44
  model_resnet50_celeba.eval()
45
 
46
  model_yolo9 = YOLO('yolo9_best.pt')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  print('Models Loaded Successfully')
48
 
49
 
50
 
51
 
52
- @app.route('/', methods=['GET'])
 
 
 
 
 
53
  def get_data():
54
  img = plt.imread('test1.jpeg')
55
  img_arr = np.array(img)
@@ -65,11 +134,6 @@ def get_data():
65
  }
66
  return jsonify(data)
67
 
68
-
69
- @app.route('/test')
70
- def home():
71
- return "Welcome to the Flask API!"
72
-
73
  @app.route('/', methods=['POST'])
74
  def post_data():
75
  try:
@@ -146,3 +210,20 @@ def post_test_data():
146
  return jsonify(response), 201
147
 
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from flask import Flask, jsonify, request
2
  from flask_cors import CORS
 
 
 
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  import torchvision.models as models
 
7
  from torchvision import transforms
8
+ import torchaudio
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+ import base64
12
+ import io
13
+ from PIL import Image
14
  from ultralytics import YOLO
15
 
16
 
 
# Class-index -> human-readable label tables used when decoding model outputs.
idx_to_class_resnet50 = {0: "Genuine", 1: 'Printed Paper', 2: 'Replayed'}
# The original bound this name twice in one chained assignment; once is enough.
idx_to_class_yolo9 = {0: 'Genuine', 1: 'Printed Paper', 2: 'Replayed', 3: 'Paper Mask'}
idx_to_class_resnet50_celeba = {0: "Genuine", 1: 'Printed Paper', 2: 'Paper Cut', 3: 'Replayed', 4: '3D Mask'}
# Index -> label for the two outputs of the voice classifier.
binary_labels = ['real', 'spoof']
26
+
27
+
28
  transform_data_resnet50=transforms.Compose([
29
  transforms.Resize(size=(224,224)),
30
  transforms.ToTensor()
 
35
  transforms.Resize((224,224), antialias=True)
36
  ])
37
 
38
def process_audio(encoded_audio):
    """Turn base64-encoded audio into a fixed-length mel-spectrogram batch.

    Args:
        encoded_audio: Base64 text/bytes of an audio file readable by
            ``torchaudio.load``.

    Returns:
        A ``(features, length)`` tuple: ``features`` is the spectrogram with a
        leading batch axis of size 1, 400 frames and 80 mel bins; ``length`` is
        a 1-element tensor holding the frame count after padding/cropping.
    """
    raw_bytes = base64.b64decode(encoded_audio)
    # NOTE(review): the file's sample rate is read but never used, so
    # MelSpectrogram runs with its default sample_rate. Confirm that incoming
    # audio matches the rate the model was trained on.
    waveform, _sample_rate = torchaudio.load(io.BytesIO(raw_bytes))

    # Collapse multi-channel audio to mono by averaging the channels.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80)
    mel = to_mel(waveform).squeeze(0)  # (n_mels, frames)

    # Force exactly `target_frames` frames: zero-pad short clips on the right,
    # crop long ones.
    target_frames = 400
    missing = target_frames - mel.size(1)
    if missing > 0:
        filler = torch.zeros(mel.size(0), missing)
        mel = torch.cat([mel, filler], dim=1)
    else:
        mel = mel[:, :target_frames]

    # Time-major layout: (frames, n_mels).
    features = mel.transpose(0, 1)
    length = torch.tensor([features.size(0)])
    return features.unsqueeze(0), length
59
+
60
  model_resnet50 = models.resnet50(weights=False)
61
  num_classes = 3
62
  model_resnet50.fc = nn.Linear(model_resnet50.fc.in_features, num_classes)
 
70
  model_resnet50_celeba.eval()
71
 
72
  model_yolo9 = YOLO('yolo9_best.pt')
73
+
74
+
75
+
76
class ConformerClassifier(torch.nn.Module):
    """Conformer encoder with mean pooling and a linear classification head.

    Args:
        input_dim: Feature size of each input frame; also the input size of the
            final linear layer.
        num_classes: Number of output logits.
        num_heads: Forwarded to ``torchaudio.models.Conformer``.
        ffn_dim: Forwarded to ``torchaudio.models.Conformer``.
        num_layers: Forwarded to ``torchaudio.models.Conformer``.
        depthwise_conv_kernel_size: Forwarded to ``torchaudio.models.Conformer``.
        dropout: Forwarded to ``torchaudio.models.Conformer``.
        use_group_norm: Forwarded to ``torchaudio.models.Conformer``.
        convolution_first: Forwarded to ``torchaudio.models.Conformer``.
    """

    def __init__(self, input_dim, num_classes, num_heads, ffn_dim, num_layers,
                 depthwise_conv_kernel_size, dropout=0.0, use_group_norm=False,
                 convolution_first=False):
        super().__init__()
        # The attribute names `conformer` and `fc` are the prefixes of the
        # state_dict keys, so they must stay as-is for saved weights to load.
        self.conformer = torchaudio.models.Conformer(
            input_dim=input_dim,
            num_heads=num_heads,
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            depthwise_conv_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
            use_group_norm=use_group_norm,
            convolution_first=convolution_first,
        )
        self.fc = torch.nn.Linear(input_dim, num_classes)

    def forward(self, x, lengths):
        """Return class logits for a batch of frame sequences.

        Args:
            x: Input features; assumed (batch, frames, input_dim) per the
                torchaudio Conformer convention — TODO confirm.
            lengths: Per-sequence frame counts, passed through to the encoder.

        Returns:
            Logits of shape (batch, num_classes).
        """
        # The encoder also returns output lengths; they are not needed here.
        encoded, _ = self.conformer(x, lengths)
        # Plain average over dim 1 (all frames, regardless of `lengths`).
        pooled = encoded.mean(dim=1)
        return self.fc(pooled)
96
+
97
# Hyperparameters for the voice real/spoof classifier. They have to match the
# architecture that produced 'binary_voice_model.pth', otherwise
# load_state_dict will reject the checkpoint.
_voice_model_kwargs = dict(
    input_dim=80,
    num_classes=2,
    num_heads=4,
    ffn_dim=128,
    num_layers=4,
    depthwise_conv_kernel_size=7,
    dropout=0.3,
    use_group_norm=False,
    convolution_first=True,
)
voice_binary_model = ConformerClassifier(**_voice_model_kwargs)

# Load the weights onto the CPU, then switch to inference mode.
_voice_state = torch.load('binary_voice_model.pth', map_location='cpu')
voice_binary_model.load_state_dict(_voice_state)
voice_binary_model.eval()
110
+
111
  print('Models Loaded Successfully')
112
 
113
 
114
 
115
 
116
@app.route('/')
def home():
    """Return the greeting string served at the root URL."""
    greeting = "Welcome to the Flask API!"
    return greeting
119
+
120
+
121
+ @app.route('/api/data', methods=['GET'])
122
  def get_data():
123
  img = plt.imread('test1.jpeg')
124
  img_arr = np.array(img)
 
134
  }
135
  return jsonify(data)
136
 
 
 
 
 
 
137
  @app.route('/', methods=['POST'])
138
  def post_data():
139
  try:
 
210
  return jsonify(response), 201
211
 
212
 
213
@app.route('/api/voice', methods=['POST'])
def post_api_voice():
    """Classify a base64-encoded audio clip as real or spoof.

    Expects a JSON body of the form ``{"base64": "<encoded audio>"}``.

    Returns:
        On success, a JSON object with ``message``, ``class`` (the predicted
        label) and ``probs`` (label -> probability) and status 201.
        On a missing/invalid body or undecodable audio, a JSON error message
        and status 400.
    """
    # silent=True yields None instead of raising on a non-JSON body, so every
    # malformed request goes through the same 400 path below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'base64' not in data:
        return jsonify({'message': "Request body must be JSON with a 'base64' field."}), 400

    # Only the decoding step is guarded, so genuine model failures still
    # surface as server errors rather than being reported as bad input.
    # NOTE(review): binascii.Error is a ValueError; torchaudio load failures
    # are presumed to raise RuntimeError — confirm for the installed backend.
    try:
        mel_spectrogram, length = process_audio(data['base64'])
    except (TypeError, ValueError, RuntimeError):
        return jsonify({'message': 'Could not decode audio from the base64 field.'}), 400

    with torch.no_grad():
        output = voice_binary_model(mel_spectrogram, length)

    # output[0] holds the logits of the single item in the batch.
    prob = torch.nn.functional.softmax(output[0], dim=0)
    pred = torch.argmax(prob).item()
    category = binary_labels[pred]
    # Convert to plain Python floats: jsonify cannot serialize torch.Tensor
    # values, which made the original response construction raise TypeError.
    probs_dict = dict(zip(binary_labels, prob.tolist()))

    response = {
        'message': 'Data received!',
        'class': category,
        'probs': probs_dict,
    }
    return jsonify(response), 201