pavankumarvk committed on
Commit
7f307a0
·
verified ·
1 Parent(s): 156fc9b

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +127 -91
pipeline.py CHANGED
@@ -8,9 +8,7 @@ import tensorflow as tf
8
  from facenet_pytorch import MTCNN
9
  from rawnet import RawNet
10
 
11
-
12
-
13
- #Set random seed for reproducibility.
14
  tf.random.set_seed(42)
15
 
16
  # Extract model if not already extracted
@@ -22,39 +20,23 @@ if not os.path.exists("efficientnet-b0"):
22
  zip_ref.close()
23
  print("Model extracted successfully!")
24
 
25
- # Load models.
26
  # Load model without compiling to avoid optimizer dependency issues
27
  model = tf.keras.models.load_model("efficientnet-b0/", compile=False)
28
 
29
 
30
-
31
  class DetectionPipeline:
32
  """Pipeline class for detecting faces in the frames of a video file."""
33
 
34
- def __init__(self, n_frames=None, batch_size=60, resize=None, input_modality = 'video'):
35
- """Constructor for DetectionPipeline class.
36
-
37
- Keyword Arguments:
38
- n_frames {int} -- Total number of frames to load. These will be evenly spaced
39
- throughout the video. If not specified (i.e., None), all frames will be loaded.
40
- (default: {None})
41
- batch_size {int} -- Batch size to use with MTCNN face detector. (default: {32})
42
- resize {float} -- Fraction by which to resize frames from original prior to face
43
- detection. A value less than 1 results in downsampling and a value greater than
44
- 1 result in upsampling. (default: {None})
45
- """
46
  self.n_frames = n_frames
47
  self.batch_size = batch_size
48
  self.resize = resize
49
  self.input_modality = input_modality
50
 
51
  def __call__(self, filename):
52
- """Load frames from an MP4 video and detect faces.
53
-
54
- Arguments:
55
- filename {str} -- Path to video.
56
- """
57
- # Create video reader and find length
58
  if self.input_modality == 'video':
59
  print('Input modality is video.')
60
  v_cap = cv2.VideoCapture(filename)
@@ -80,11 +62,15 @@ class DetectionPipeline:
80
 
81
  # Resize frame to desired size
82
  if self.resize is not None:
83
- frame = frame.resize([int(d * self.resize) for d in frame.size])
 
84
  frames.append(frame)
85
 
86
  # When batch is full, detect faces and reset frame list
87
  if len(frames) % self.batch_size == 0 or j == sample[-1]:
 
 
 
88
  face2 = cv2.resize(frame, (224, 224))
89
  faces.append(face2)
90
 
@@ -93,55 +79,51 @@ class DetectionPipeline:
93
 
94
  elif self.input_modality == 'image':
95
  print('Input modality is image.')
96
- #Perform inference for image modality.
97
- print('Reading image')
98
- # print(f"Image path is: {filename}")
99
  image = cv2.cvtColor(filename, cv2.COLOR_BGR2RGB)
100
  image = cv2.resize(image, (224, 224))
101
-
102
- # if not face.any():
103
- # print("No faces found...")
104
-
105
  return image
106
 
107
  elif self.input_modality == 'audio':
108
- print("INput modality is audio.")
109
-
110
- #Load audio.
111
- x, sr = librosa.load(filename)
112
- x_pt = torch.Tensor(x)
113
- x_pt = torch.unsqueeze(x_pt, dim = 0)
114
- return x_pt
115
-
116
- else:
117
- raise ValueError("Invalid input modality. Must be either 'video' or image")
118
 
 
119
  detection_video_pipeline = DetectionPipeline(n_frames=5, batch_size=1, input_modality='video')
120
- detection_image_pipeline = DetectionPipeline(batch_size = 1, input_modality = 'image')
121
 
122
- def deepfakes_video_predict(input_video):
 
 
123
 
 
124
  faces = detection_video_pipeline(input_video)
125
  total = 0
126
  real_res = []
127
  fake_res = []
 
 
 
 
128
 
129
  for face in faces:
130
-
131
- face2 = face/255
132
  pred = model.predict(np.expand_dims(face2, axis=0))[0]
133
  real, fake = pred[0], pred[1]
134
  real_res.append(real)
135
  fake_res.append(fake)
136
 
137
- total+=1
138
 
139
- pred2 = pred[1]
140
 
141
  if pred2 > 0.5:
142
- fake+=1
143
  else:
144
- real+=1
 
145
  real_mean = np.mean(real_res)
146
  fake_mean = np.mean(fake_res)
147
  print(f"Real Faces: {real_mean}")
@@ -149,65 +131,119 @@ def deepfakes_video_predict(input_video):
149
  text = ""
150
 
151
  if real_mean >= 0.5:
152
- text = "The video is REAL. \n Deepfakes Confidence: " + str(round(100 - (real_mean*100), 3)) + "%"
153
  else:
154
- text = "The video is FAKE. \n Deepfakes Confidence: " + str(round(fake_mean*100, 3)) + "%"
155
 
156
  return text
157
 
158
 
159
  def deepfakes_image_predict(input_image):
160
  faces = detection_image_pipeline(input_image)
161
- face2 = faces/255
162
- pred = model.predict(np.expand_dims(face2, axis = 0))[0]
163
  real, fake = pred[0], pred[1]
164
  if real > 0.5:
165
- text2 = "The image is REAL. \n Deepfakes Confidence: " + str(round(100 - (real*100), 3)) + "%"
166
  else:
167
- text2 = "The image is FAKE. \n Deepfakes Confidence: " + str(round(fake*100, 3)) + "%"
168
  return text2
169
-
 
 
 
 
170
  def load_audio_model():
171
  d_args = {
172
- "nb_samp": 64600,
173
- "first_conv": 1024,
174
- "in_channels": 1,
175
- "filts": [20, [20, 20], [20, 128], [128, 128]],
176
- "blocks": [2, 4],
177
- "nb_fc_node": 1024,
178
- "gru_node": 1024,
179
- "nb_gru_layer": 3,
180
- "nb_classes": 2}
 
181
 
182
- model = RawNet(d_args = d_args, device='cpu')
183
-
184
- #Load ckpt.
185
- model_dict = model.state_dict()
186
- ckpt = torch.load('RawNet2.pth', map_location=torch.device('cpu'))
187
- model.load_state_dict(ckpt, model_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  return model
189
 
 
 
 
190
  audio_label_map = {
191
- 0: "Real audio",
192
- 1: "Fake audio"
193
  }
194
 
195
  def deepfakes_audio_predict(input_audio):
196
- #Perform inference on audio.
197
- x, sr = input_audio
198
- x_pt = torch.Tensor(x)
199
- x_pt = torch.unsqueeze(x_pt, dim = 0)
200
-
201
- #Load model.
202
- model = load_audio_model()
203
-
204
- #Perform inference.
205
- grads = model(x_pt)
206
-
207
- #Get the argmax.
208
- grads_np = grads.detach().numpy()
209
- result = np.argmax(grads_np)
210
-
211
- return audio_label_map[result]
212
-
213
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from facenet_pytorch import MTCNN
9
  from rawnet import RawNet
10
 
11
+ # Set random seed for reproducibility.
 
 
12
  tf.random.set_seed(42)
13
 
14
  # Extract model if not already extracted
 
20
  zip_ref.close()
21
  print("Model extracted successfully!")
22
 
23
+ # Load Video/Image models.
24
  # Load model without compiling to avoid optimizer dependency issues
25
  model = tf.keras.models.load_model("efficientnet-b0/", compile=False)
26
 
27
 
 
28
  class DetectionPipeline:
29
  """Pipeline class for detecting faces in the frames of a video file."""
30
 
31
+ def __init__(self, n_frames=None, batch_size=60, resize=None, input_modality='video'):
32
+ """Constructor for DetectionPipeline class."""
 
 
 
 
 
 
 
 
 
 
33
  self.n_frames = n_frames
34
  self.batch_size = batch_size
35
  self.resize = resize
36
  self.input_modality = input_modality
37
 
38
  def __call__(self, filename):
39
+ """Load frames from an MP4 video and detect faces."""
 
 
 
 
 
40
  if self.input_modality == 'video':
41
  print('Input modality is video.')
42
  v_cap = cv2.VideoCapture(filename)
 
62
 
63
  # Resize frame to desired size
64
  if self.resize is not None:
65
+ frame = cv2.resize(frame, None, fx=self.resize, fy=self.resize)
66
+
67
  frames.append(frame)
68
 
69
  # When batch is full, detect faces and reset frame list
70
  if len(frames) % self.batch_size == 0 or j == sample[-1]:
71
+ # Simple resizing for the EfficientNet model (assuming face is centered or whole frame is analyzed)
72
+ # For a more robust solution, MTCNN should be used here to extract faces first.
73
+ # Based on your provided logic, we resize the frame directly.
74
  face2 = cv2.resize(frame, (224, 224))
75
  faces.append(face2)
76
 
 
79
 
80
  elif self.input_modality == 'image':
81
  print('Input modality is image.')
82
+ # Perform inference for image modality.
83
+ # Note: 'filename' here is actually the numpy array from Gradio Image component
 
84
  image = cv2.cvtColor(filename, cv2.COLOR_BGR2RGB)
85
  image = cv2.resize(image, (224, 224))
 
 
 
 
86
  return image
87
 
88
  elif self.input_modality == 'audio':
89
+ # Audio is handled by deepfakes_audio_predict directly,
90
+ # but if you use this class, return placeholder or raw audio.
91
+ return None
 
 
 
 
 
 
 
92
 
93
+ # Instantiate pipelines
94
  detection_video_pipeline = DetectionPipeline(n_frames=5, batch_size=1, input_modality='video')
95
+ detection_image_pipeline = DetectionPipeline(batch_size=1, input_modality='image')
96
 
97
+ # ---------------------------------------------------------
98
+ # Video & Image Prediction Functions
99
+ # ---------------------------------------------------------
100
 
101
+ def deepfakes_video_predict(input_video):
102
  faces = detection_video_pipeline(input_video)
103
  total = 0
104
  real_res = []
105
  fake_res = []
106
+
107
+ # Initialize counters for the simple voting logic
108
+ real_count = 0
109
+ fake_count = 0
110
 
111
  for face in faces:
112
+ face2 = face / 255.0
 
113
  pred = model.predict(np.expand_dims(face2, axis=0))[0]
114
  real, fake = pred[0], pred[1]
115
  real_res.append(real)
116
  fake_res.append(fake)
117
 
118
+ total += 1
119
 
120
+ pred2 = pred[1] # Probability of Fake
121
 
122
  if pred2 > 0.5:
123
+ fake_count += 1
124
  else:
125
+ real_count += 1
126
+
127
  real_mean = np.mean(real_res)
128
  fake_mean = np.mean(fake_res)
129
  print(f"Real Faces: {real_mean}")
 
131
  text = ""
132
 
133
  if real_mean >= 0.5:
134
+ text = "The video is REAL. \n Deepfakes Confidence: " + str(round(100 - (real_mean * 100), 3)) + "%"
135
  else:
136
+ text = "The video is FAKE. \n Deepfakes Confidence: " + str(round(fake_mean * 100, 3)) + "%"
137
 
138
  return text
139
 
140
 
141
  def deepfakes_image_predict(input_image):
142
  faces = detection_image_pipeline(input_image)
143
+ face2 = faces / 255.0
144
+ pred = model.predict(np.expand_dims(face2, axis=0))[0]
145
  real, fake = pred[0], pred[1]
146
  if real > 0.5:
147
+ text2 = "The image is REAL. \n Deepfakes Confidence: " + str(round(100 - (real * 100), 3)) + "%"
148
  else:
149
+ text2 = "The image is FAKE. \n Deepfakes Confidence: " + str(round(fake * 100, 3)) + "%"
150
  return text2
151
+
152
+ # ---------------------------------------------------------
153
+ # Audio Prediction Functions
154
+ # ---------------------------------------------------------
155
+
156
  def load_audio_model():
157
  d_args = {
158
+ "nb_samp": 64600,
159
+ "first_conv": 1024,
160
+ "in_channels": 1,
161
+ "filts": [20, [20, 20], [20, 128], [128, 128]],
162
+ "blocks": [2, 4],
163
+ "nb_fc_node": 1024,
164
+ "gru_node": 1024,
165
+ "nb_gru_layer": 3,
166
+ "nb_classes": 2
167
+ }
168
 
169
+ device = torch.device('cpu')
170
+ model = RawNet(d_args=d_args, device=device)
171
+ model.eval()
172
+
173
+ # Load weights
174
+ # Ensure 'RawNet2.pth' is in your repository root
175
+ if os.path.exists('RawNet2.pth'):
176
+ try:
177
+ checkpoint = torch.load('RawNet2.pth', map_location=device)
178
+ # Handle different checkpoint formats (strict or not)
179
+ if isinstance(checkpoint, dict):
180
+ if 'model' in checkpoint:
181
+ model.load_state_dict(checkpoint['model'])
182
+ elif 'state_dict' in checkpoint:
183
+ model.load_state_dict(checkpoint['state_dict'])
184
+ else:
185
+ model.load_state_dict(checkpoint, strict=False)
186
+ else:
187
+ model.load_state_dict(checkpoint, strict=False)
188
+ print("Audio model loaded successfully.")
189
+ except Exception as e:
190
+ print(f"Error loading audio model weights: {e}")
191
+ else:
192
+ print("Warning: 'RawNet2.pth' not found. Audio detection will not work.")
193
+
194
  return model
195
 
196
+ # Load the audio model globally to avoid reloading it on every request
197
+ audio_model = load_audio_model()
198
+
199
  audio_label_map = {
200
+ 0: "Real",
201
+ 1: "Fake"
202
  }
203
 
204
  def deepfakes_audio_predict(input_audio):
205
+ """
206
+ input_audio: tuple (sample_rate, audio_data) provided by Gradio
207
+ """
208
+ if audio_model is None:
209
+ return "Error: Audio model not loaded."
210
+
211
+ try:
212
+ sr, x = input_audio
213
+ except ValueError:
214
+ # Fallback if input format is different (e.g. just file path)
215
+ return "Error: Invalid audio input format."
216
+
217
+ # Target sampling rate and length for RawNet
218
+ target_sr = 16000
219
+ target_len = 64600
220
+
221
+ # Resample if necessary
222
+ if sr != target_sr:
223
+ x = librosa.resample(x, orig_sr=sr, target_sr=target_sr)
224
+
225
+ # Pad or crop to target length
226
+ len_x = x.shape[0]
227
+ if len_x < target_len:
228
+ # Pad with zeros
229
+ x = np.pad(x, (0, target_len - len_x), mode='constant')
230
+ elif len_x > target_len:
231
+ # Center crop
232
+ start = (len_x - target_len) // 2
233
+ x = x[start:start + target_len]
234
+
235
+ # Convert to Tensor and add dimensions (Batch, Channel, Length)
236
+ x_pt = torch.from_numpy(x).float().unsqueeze(0).unsqueeze(0)
237
+
238
+ # Perform inference
239
+ with torch.no_grad():
240
+ output = audio_model(x_pt)
241
+
242
+ # Output is LogSoftmax, convert to probabilities
243
+ probs = torch.exp(output)
244
+ confidence, prediction = torch.max(probs, 1)
245
+
246
+ label = audio_label_map[prediction.item()]
247
+ confidence_score = confidence.item() * 100
248
+
249
+ return f"The audio is {label}.\nConfidence: {confidence_score:.2f}%"