SanskarModi commited on
Commit
397b9f3
·
verified ·
1 Parent(s): 04bb1f8

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +26 -17
prediction.py CHANGED
@@ -281,19 +281,19 @@ class Prediction:
281
  def predict(self, video, seq_length=None):
282
  """
283
  Predict whether a video is real or fake.
284
-
285
  Args:
286
  video (str): Path to the video file
287
  seq_length (int, optional): Number of frames to use
288
-
289
  Returns:
290
  tuple: (prediction_result, gradcam_image, classification_details)
291
  """
292
  frames, raw_frames = self.preprocess(video, seq_length)
293
-
294
  if not frames:
295
  return "No faces detected in the video", None, None
296
-
297
  # Prepare input tensor for the model
298
  target_seq_length = (
299
  seq_length if seq_length is not None else self.default_frame_count
@@ -302,14 +302,23 @@ class Prediction:
302
  input_tensor = input_tensor.view(1, target_seq_length, 3, *self.resolution)
303
  input_tensor = input_tensor.to(self.device)
304
  input_tensor.requires_grad_()
305
-
306
- # Forward pass to get feature maps and final output
307
- fmap, attn_wts, output = self.model(input_tensor)
 
 
 
 
 
 
 
 
 
308
  fmap.register_hook(self.save_gradients)
309
-
310
  # Get predictions for all classes
311
- class_probs = F.softmax(output, dim=1).detach().cpu().numpy()[0]
312
-
313
  # Get the predicted class
314
  predicted_class_idx = np.argmax(class_probs)
315
  predicted_class = (
@@ -318,7 +327,7 @@ class Prediction:
318
  else "Unknown"
319
  )
320
  prediction = "Deepfake" if predicted_class_idx > 0 else "Real"
321
-
322
  # Format confidence values to 2 decimal places
323
  confidence_class = round(class_probs[predicted_class_idx] * 100, 2)
324
  confidence_deepfake_real = (
@@ -327,7 +336,7 @@ class Prediction:
327
  else round(class_probs[0] * 100, 2)
328
  )
329
  prediction_string = f"{prediction} {confidence_deepfake_real:.2f}% Confidence"
330
-
331
  # Create detailed classification results
332
  classification_details = (
333
  {
@@ -340,18 +349,18 @@ class Prediction:
340
  "confidence(%)": f"{confidence_class:.2f}",
341
  }
342
  )
343
-
344
  # Backpropagate for Grad-CAM
345
  self.model.zero_grad()
346
- output[0, predicted_class_idx].backward()
347
  grads = self.gradients
348
-
349
  # Generate Grad-CAM visualization for the best frame
350
  if raw_frames:
351
- # Choose middle frame for visualization
352
  middle_idx = len(raw_frames) // 2
353
  gradcam_image = self.generate_gradcam(fmap, raw_frames[middle_idx], grads)
354
  else:
355
  gradcam_image = None
356
-
357
  return prediction_string, gradcam_image, classification_details
 
 
281
  def predict(self, video, seq_length=None):
282
  """
283
  Predict whether a video is real or fake.
284
+
285
  Args:
286
  video (str): Path to the video file
287
  seq_length (int, optional): Number of frames to use
288
+
289
  Returns:
290
  tuple: (prediction_result, gradcam_image, classification_details)
291
  """
292
  frames, raw_frames = self.preprocess(video, seq_length)
293
+
294
  if not frames:
295
  return "No faces detected in the video", None, None
296
+
297
  # Prepare input tensor for the model
298
  target_seq_length = (
299
  seq_length if seq_length is not None else self.default_frame_count
 
302
  input_tensor = input_tensor.view(1, target_seq_length, 3, *self.resolution)
303
  input_tensor = input_tensor.to(self.device)
304
  input_tensor.requires_grad_()
305
+
306
+ # Forward pass to get model output dict
307
+ with torch.no_grad():
308
+ output_dict = self.model(input_tensor)
309
+
310
+ # Extract relevant outputs
311
+ fmap = output_dict["fmap"]
312
+ attn_wts = output_dict["attn"]
313
+ logits = output_dict["logits"]
314
+
315
+ # Register hook for Grad-CAM
316
+ fmap.requires_grad_()
317
  fmap.register_hook(self.save_gradients)
318
+
319
  # Get predictions for all classes
320
+ class_probs = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
321
+
322
  # Get the predicted class
323
  predicted_class_idx = np.argmax(class_probs)
324
  predicted_class = (
 
327
  else "Unknown"
328
  )
329
  prediction = "Deepfake" if predicted_class_idx > 0 else "Real"
330
+
331
  # Format confidence values to 2 decimal places
332
  confidence_class = round(class_probs[predicted_class_idx] * 100, 2)
333
  confidence_deepfake_real = (
 
336
  else round(class_probs[0] * 100, 2)
337
  )
338
  prediction_string = f"{prediction} {confidence_deepfake_real:.2f}% Confidence"
339
+
340
  # Create detailed classification results
341
  classification_details = (
342
  {
 
349
  "confidence(%)": f"{confidence_class:.2f}",
350
  }
351
  )
352
+
353
  # Backpropagate for Grad-CAM
354
  self.model.zero_grad()
355
+ logits[0, predicted_class_idx].backward()
356
  grads = self.gradients
357
+
358
  # Generate Grad-CAM visualization for the best frame
359
  if raw_frames:
 
360
  middle_idx = len(raw_frames) // 2
361
  gradcam_image = self.generate_gradcam(fmap, raw_frames[middle_idx], grads)
362
  else:
363
  gradcam_image = None
364
+
365
  return prediction_string, gradcam_image, classification_details
366
+