# deepdetect/prediction.py
import os

# Point Matplotlib at a writable config dir before importing libraries that may
# pull it in transitively (needed on read-only hosts such as Hugging Face Spaces).
os.environ["MPLCONFIGDIR"] = "/tmp"

import cv2
import mediapipe as mp
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from pathlib import Path

from common import read_yaml

PARAMS_FILE_PATH = Path("params.yaml")
class Prediction:
    def __init__(self):
        """
        Initialize the Prediction class with a pre-trained model and necessary parameters.
        """
        self.device = torch.device("cpu")
        self.model = torch.jit.load("model.pt", map_location=self.device)
        self.model.eval()

        params = read_yaml(PARAMS_FILE_PATH)
        self.expansion_factor = params.expansion_factor
        self.resolution = params.resolution
        self.default_frame_count = params.sequence_length

        # Populated by the Grad-CAM backward hook in save_gradients()
        self.gradients = None

        # Initialize MediaPipe face detector
        self.face_detection = mp.solutions.face_detection.FaceDetection(
            model_selection=0, min_detection_confidence=0.6
        )

        # Define the classes for prediction
        self.classes = [
            "original",
            "Deepfake (Face2Face)",
            "Deepfake (FaceShifter)",
            "Deepfake (FaceSwap)",
            "Deepfake (NeuralTextures)",
        ]
    def get_frames(self, video):
        """
        Yields frames from the given video file.
        """
        vidobj = cv2.VideoCapture(video)
        try:
            success, image = vidobj.read()
            while success:
                yield image
                success, image = vidobj.read()
        finally:
            # Release the capture even if the consumer stops iterating early
            vidobj.release()
    def get_face(self, frame):
        """
        Detect faces in a frame using MediaPipe.
        Args:
            frame (np.ndarray): Input frame
        Returns:
            tuple: (top, right, bottom, left) coordinates of the face, or None if no face is detected
        """
        try:
            # Convert frame from BGR (OpenCV) to RGB
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Detect faces
            results = self.face_detection.process(rgb_frame)
            if results.detections:
                detection = results.detections[0]  # Use the first detected face
                h, w, _ = frame.shape
                bboxC = detection.location_data.relative_bounding_box
                # Convert relative coordinates to absolute pixel coordinates
                xmin = int(bboxC.xmin * w)
                ymin = int(bboxC.ymin * h)
                box_width = int(bboxC.width * w)
                box_height = int(bboxC.height * h)
                # Return in (top, right, bottom, left) order, clamped to the frame
                top = max(ymin, 0)
                right = min(xmin + box_width, w)
                bottom = min(ymin + box_height, h)
                left = max(xmin, 0)
                return (top, right, bottom, left)
            return None  # No face detected
        except Exception as e:
            print(f"Error in get_face: {e}")
            print(f"Frame shape: {frame.shape}, dtype: {frame.dtype}")
            raise
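    # Worked example of the coordinate conversion above (illustrative numbers only):
    # for a 640x480 frame and a relative box (xmin=0.25, ymin=0.20, width=0.50,
    # height=0.50), the absolute box is xmin=160, ymin=96, width=320, height=240,
    # which is returned as (top=96, right=480, bottom=336, left=160).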
    def color_jitter(self, image):
        """
        Applies color jitter to the given image for data augmentation.
        Args:
            image (np.ndarray): The input image
        Returns:
            np.ndarray: The color jittered image
        """
        # Fixed seed keeps the jitter deterministic across inference runs
        rng = np.random.default_rng(seed=42)
        # Convert to HSV for easier manipulation
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(hsv)
        # Adjust brightness
        value = rng.uniform(0.8, 1.2)
        v = cv2.multiply(v, value)
        # Adjust contrast: blend each pixel with the mean brightness
        # (cv2.addWeighted needs an array for its second operand, not a bare scalar)
        mean = np.mean(v)
        value = rng.uniform(0.8, 1.2)
        v = cv2.addWeighted(v, value, np.full_like(v, mean), 1 - value, 0)
        # Adjust saturation
        value = rng.uniform(0.8, 1.2)
        s = cv2.multiply(s, value)
        final_hsv = cv2.merge((h, s, v))
        image = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
        return image
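    # The contrast step above computes value * v + (1 - value) * mean per pixel:
    # e.g. with value = 1.2 and mean = 100, a pixel at 150 becomes
    # 1.2 * 150 - 0.2 * 100 = 160, pushing it further from the mean.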
    def preprocess(self, video, seq_length=None):
        """
        Preprocess the video by extracting frames, detecting faces, and resizing.
        Applies the same preprocessing as the training pipeline.
        Args:
            video (str): Path to the video file
            seq_length (int, optional): Number of frames to extract
        Returns:
            tuple: (model-ready frame tensors, original cropped frames for visualization)
        """
        frames = []
        raw_frames = []  # Store original cropped frames for visualization
        # Use the provided sequence length or the default from params
        target_seq_length = (
            seq_length if seq_length is not None else self.default_frame_count
        )
        transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(
                    tuple(self.resolution),
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        buffer = []  # Process frames in batches of 4, like the training pipeline
        for frame in self.get_frames(video):
            if len(frames) >= target_seq_length:
                break
            buffer.append(frame)
            if len(buffer) == 4:  # Process in batches of 4
                faces = [self.get_face(f) for f in buffer]
                for i, face in enumerate(faces):
                    if face is not None:
                        top, right, bottom, left = face
                        face_height = bottom - top
                        face_width = right - left
                        # Expand the face region by expansion_factor / 2 on each
                        # side, clamped to the frame boundaries
                        expanded_top = max(
                            0, top - int(self.expansion_factor / 2 * face_height)
                        )
                        expanded_bottom = min(
                            buffer[i].shape[0],
                            bottom + int(self.expansion_factor / 2 * face_height),
                        )
                        expanded_left = max(
                            0, left - int(self.expansion_factor / 2 * face_width)
                        )
                        expanded_right = min(
                            buffer[i].shape[1],
                            right + int(self.expansion_factor / 2 * face_width),
                        )
                        # Crop and resize
                        cropped_face = cv2.resize(
                            buffer[i][
                                expanded_top:expanded_bottom,
                                expanded_left:expanded_right,
                                :,
                            ],
                            tuple(self.resolution),
                        )
                        # Store the original cropped face for visualization
                        raw_frames.append(cropped_face.copy())
                        # Apply color jitter, as in training
                        cropped_face = self.color_jitter(cropped_face)
                        # Transform for model input
                        frames.append(transform(cropped_face))
                buffer = []  # Reset buffer (a trailing partial batch is dropped)
        # Pad by duplicating the last frame if the video yielded too few faces
        if len(frames) < target_seq_length:
            if frames:
                while len(frames) < target_seq_length:
                    frames.append(frames[-1])
                    raw_frames.append(raw_frames[-1])
            else:
                return [], []  # No faces detected
        return frames[:target_seq_length], raw_frames[:target_seq_length]
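    # Worked example of the crop expansion above (illustrative numbers only):
    # with expansion_factor = 0.4 and a face box of height 100 at top = 50,
    # each side grows by int(0.4 / 2 * 100) = 20 pixels, so expanded_top = 30
    # (and symmetrically for bottom/left/right, clamped to the frame edges).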
    def save_gradients(self, grad):
        """
        Hook function to capture gradients for Grad-CAM.
        """
        self.gradients = grad
    def grad_cam(self, fmap, grads):
        """
        Compute Grad-CAM using feature maps and gradients.
        """
        # Work on a detached copy so the model's output tensor is not mutated in place
        fmap = fmap.detach().clone()
        # Weight each feature-map channel by its average gradient
        pooled_grads = torch.mean(grads, dim=[0])
        for i in range(fmap.shape[1]):
            fmap[:, i, :, :] *= pooled_grads[i]
        cam = torch.mean(fmap, dim=1).squeeze().cpu().numpy()
        # Apply ReLU to retain only positive activations
        cam = np.maximum(cam, 0)
        # Normalize Grad-CAM to [0, 1]
        cam = cam - np.min(cam)
        cam = cam / np.max(cam) if np.max(cam) > 0 else cam  # Prevent division by zero
        # Resize the cam to match the resolution of the original image
        cam = cv2.resize(cam, tuple(self.resolution))
        # Collapse to a single channel only if a channel dimension remains
        # (the original checked cam.shape[-1], which on a 2-D map is just its width)
        cam = np.sum(cam, axis=-1) if cam.ndim == 3 else cam
        return cam
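    # In Grad-CAM terms, the loop above computes L = ReLU(mean_k(w_k * A_k)),
    # where A_k is feature-map channel k and w_k is its spatially averaged
    # gradient; using the mean over channels instead of the usual sum only
    # rescales the map, which the normalization step cancels out.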
    def generate_gradcam(self, fmap, video_frame, grads):
        """
        Generate the Grad-CAM heatmap and overlay it on the frame.
        """
        cam = self.grad_cam(fmap, grads)
        # Scale the normalized cam to a single-channel 8-bit image
        cam = np.uint8(255 * cam)
        heatmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET)  # Apply colormap
        # raw_frames come straight from cv2.VideoCapture, so the frame is already
        # BGR uint8 in [0, 255]; no channel swap or rescaling is needed here
        # (multiplying by 255 again would overflow and corrupt the image)
        video_frame = np.uint8(video_frame)
        # Blend heatmap and original image with a low heatmap weight
        # so the face stays clearly visible
        alpha = 0.01  # Heatmap weight
        beta = 1 - alpha  # Original-frame weight
        overlayed_img = cv2.addWeighted(heatmap, alpha, video_frame, beta, 0)
        return overlayed_img
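    # The blend above computes alpha * heatmap + beta * frame per pixel:
    # with alpha = 0.01, a heatmap value of 255 over a frame value of 100
    # contributes 0.01 * 255 + 0.99 * 100 = 101.55, i.e. a very subtle tint;
    # raise alpha (e.g. to 0.4) if a more prominent overlay is wanted.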
    def predict(self, video, seq_length=None):
        """
        Predict whether a video is real or fake.
        Args:
            video (str): Path to the video file
            seq_length (int, optional): Number of frames to use
        Returns:
            tuple: (prediction_result, gradcam_image, classification_details)
        """
        frames, raw_frames = self.preprocess(video, seq_length)
        if not frames:
            return "No faces detected in the video", None, None
        # Prepare the input tensor for the model
        target_seq_length = (
            seq_length if seq_length is not None else self.default_frame_count
        )
        input_tensor = torch.stack(frames).unsqueeze(0)
        input_tensor = input_tensor.view(1, target_seq_length, 3, *self.resolution)
        input_tensor = input_tensor.to(self.device)
        # Gradients are required for Grad-CAM, so no torch.no_grad() here
        input_tensor.requires_grad_(True)
        # Forward pass with gradient tracking enabled
        fmap, attn_wts, logits = self.model(input_tensor)
        # Register the hook that captures the feature-map gradients for Grad-CAM
        fmap.register_hook(self.save_gradients)
        # Get probabilities for all classes
        class_probs = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
        # Get the predicted class
        predicted_class_idx = np.argmax(class_probs)
        predicted_class = (
            self.classes[predicted_class_idx]
            if predicted_class_idx < len(self.classes)
            else "Unknown"
        )
        # Index 0 is "original"; every other class is a deepfake variant
        prediction = "Deepfake" if predicted_class_idx > 0 else "Real"
        # Format the confidence value to 2 decimal places
        confidence_deepfake_real = (
            round(class_probs[1:].max() * 100, 2)
            if prediction == "Deepfake"
            else round(class_probs[0] * 100, 2)
        )
        prediction_string = f"{prediction} {confidence_deepfake_real:.2f}% Confidence"
        # Create detailed classification results as a dictionary
        if prediction == "Deepfake":
            # For deepfakes, show the probability of each deepfake type
            classification_details = {
                self.classes[i]: float(class_probs[i])
                for i in range(1, len(self.classes))
            }
        else:
            # For real videos, just show the real confidence
            classification_details = {"Real": float(class_probs[0])}
        # Backpropagate from the predicted class logit for Grad-CAM
        self.model.zero_grad()
        logits[0, predicted_class_idx].backward()
        grads = self.gradients
        # Generate the Grad-CAM visualization for the middle frame
        if raw_frames:
            middle_idx = len(raw_frames) // 2
            gradcam_image = self.generate_gradcam(fmap, raw_frames[middle_idx], grads)
        else:
            gradcam_image = None
        return prediction_string, gradcam_image, classification_details
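
# Minimal usage sketch (illustrative only): "sample.mp4" is a hypothetical path,
# and model.pt / params.yaml are assumed to sit in the working directory, as the
# constructor above expects.
if __name__ == "__main__":
    predictor = Prediction()
    result, gradcam_img, details = predictor.predict("sample.mp4")
    print(result)
    print(details)
    if gradcam_img is not None:
        # Save the Grad-CAM overlay (already BGR, so cv2.imwrite works directly)
        cv2.imwrite("gradcam_overlay.png", gradcam_img)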