import os

# Matplotlib (pulled in transitively by dependencies) needs a writable config
# directory when the process home is read-only (e.g. inside containers).
os.environ['MPLCONFIGDIR'] = '/tmp'

from pathlib import Path

import cv2
import mediapipe as mp
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms

from common import read_yaml

PARAMS_FILE_PATH = Path("params.yaml")


class Prediction:
    """Deepfake-detection inference pipeline with Grad-CAM visualisation.

    Loads a TorchScript model ("model.pt") and pipeline parameters
    ("params.yaml"), extracts MediaPipe face crops from video frames,
    classifies the frame sequence, and blends a Grad-CAM heat map onto
    one representative frame.
    """

    def __init__(self):
        """Load the model, the pipeline parameters and the face detector."""
        self.device = torch.device("cpu")
        self.model = torch.jit.load("model.pt", map_location=self.device)
        self.model.eval()

        params = read_yaml(PARAMS_FILE_PATH)
        self.expansion_factor = params.expansion_factor
        self.resolution = params.resolution
        self.default_frame_count = params.sequence_length

        # Short-range MediaPipe detector (model_selection=0 targets faces
        # close to the camera).
        self.face_detection = mp.solutions.face_detection.FaceDetection(
            model_selection=0, min_detection_confidence=0.6
        )

        # Index 0 is the genuine class; the rest are manipulation methods.
        self.classes = [
            "original",
            "Deepfake (Face2Face)",
            "Deepfake (FaceShifter)",
            "Deepfake (FaceSwap)",
            "Deepfake (NeuralTextures)",
        ]

        # Populated by the backward hook registered in predict().
        self.gradients = None

    def get_frames(self, video):
        """Yield successive BGR frames from the video file at *video*.

        Args:
            video (str): Path to the video file.

        Yields:
            np.ndarray: The next decoded frame.
        """
        vidobj = cv2.VideoCapture(video)
        try:
            success, image = vidobj.read()
            while success:
                yield image
                success, image = vidobj.read()
        finally:
            # FIX: release the capture handle; it was previously leaked.
            vidobj.release()

    def get_face(self, frame):
        """Detect the most prominent face in *frame* using MediaPipe.

        Args:
            frame (np.ndarray): BGR input frame.

        Returns:
            tuple | None: (top, right, bottom, left) pixel coordinates of
            the first detected face, clamped to the frame bounds, or None
            when no face is found.
        """
        try:
            # MediaPipe expects RGB; OpenCV frames are BGR.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = self.face_detection.process(rgb_frame)

            if results.detections:
                # Use the first (highest-ranked) detection only.
                detection = results.detections[0]
                h, w, _ = frame.shape
                bboxC = detection.location_data.relative_bounding_box

                # Relative bounding box -> absolute pixel coordinates.
                xmin = int(bboxC.xmin * w)
                ymin = int(bboxC.ymin * h)
                box_width = int(bboxC.width * w)
                box_height = int(bboxC.height * h)

                # Clamp to the frame; return (top, right, bottom, left).
                top = max(ymin, 0)
                right = min(xmin + box_width, w)
                bottom = min(ymin + box_height, h)
                left = max(xmin, 0)
                return (top, right, bottom, left)

            return None  # No face detected
        except Exception as e:
            print(f"Error in get_face: {e}")
            print(f"Frame shape: {frame.shape}, dtype: {frame.dtype}")
            raise

    def color_jitter(self, image):
        """Apply brightness/contrast/saturation jitter to a BGR image.

        NOTE(review): the RNG is seeded with a constant, so the "jitter" is
        the same deterministic transform on every call. This keeps inference
        reproducible, but confirm it matches the training augmentation.

        Args:
            image (np.ndarray): BGR uint8 image.

        Returns:
            np.ndarray: The colour-jittered BGR image.
        """
        rng = np.random.default_rng(seed=42)

        # Work in HSV for independent brightness/saturation control.
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(hsv)

        # Brightness: scale the V channel.
        value = rng.uniform(0.8, 1.2)
        v = cv2.multiply(v, value)

        # Contrast: pull V towards/away from its mean.
        # NOTE(review): passing a scalar as src2 to addWeighted is not
        # accepted by every OpenCV build — verify against the cv2 version
        # actually deployed.
        mean = np.mean(v)
        value = rng.uniform(0.8, 1.2)
        v = cv2.addWeighted(v, value, mean, 1 - value, 0)

        # Saturation: scale the S channel.
        value = rng.uniform(0.8, 1.2)
        s = cv2.multiply(s, value)

        final_hsv = cv2.merge((h, s, v))
        image = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
        return image

    def _expand_and_crop(self, frame, face):
        """Crop *face* from *frame* with symmetric expansion, then resize."""
        top, right, bottom, left = face
        face_height = bottom - top
        face_width = right - left

        # Grow the box by expansion_factor/2 on each side, clamped to frame.
        pad_h = int(self.expansion_factor / 2 * face_height)
        pad_w = int(self.expansion_factor / 2 * face_width)
        expanded_top = max(0, top - pad_h)
        expanded_bottom = min(frame.shape[0], bottom + pad_h)
        expanded_left = max(0, left - pad_w)
        expanded_right = min(frame.shape[1], right + pad_w)

        return cv2.resize(
            frame[expanded_top:expanded_bottom, expanded_left:expanded_right, :],
            tuple(self.resolution),
        )

    def preprocess(self, video, seq_length=None):
        """Extract, crop and normalise face frames from *video*.

        Mirrors the training pipeline: faces are detected in batches of
        four frames, expanded, resized, colour-jittered and normalised.

        Args:
            video (str): Path to the video file.
            seq_length (int, optional): Number of frames to extract;
                defaults to the configured sequence length.

        Returns:
            tuple[list, list]: (model-ready frame tensors, raw BGR crops
            for visualisation). Both lists are empty when no face is found.
        """
        frames = []
        raw_frames = []  # original cropped frames kept for visualisation

        target_seq_length = (
            seq_length if seq_length is not None else self.default_frame_count
        )

        transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(
                    tuple(self.resolution),
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        def process_batch(batch):
            # Detect, crop, jitter and transform every usable frame.
            for frame in batch:
                face = self.get_face(frame)
                if face is None:
                    continue
                cropped_face = self._expand_and_crop(frame, face)
                raw_frames.append(cropped_face.copy())
                # Same colour jitter as the training pipeline.
                frames.append(transform(self.color_jitter(cropped_face)))

        buffer = []
        for frame in self.get_frames(video):
            if len(frames) >= target_seq_length:
                break
            buffer.append(frame)
            if len(buffer) == 4:  # process in batches of 4, as in training
                process_batch(buffer)
                buffer = []

        # FIX: also process a final partial batch (< 4 frames); previously
        # these trailing frames were silently dropped.
        if buffer and len(frames) < target_seq_length:
            process_batch(buffer)

        # Pad by repeating the last frame when the video is too short.
        if len(frames) < target_seq_length:
            if not frames:
                return [], []  # no faces detected at all
            while len(frames) < target_seq_length:
                frames.append(frames[-1])
                raw_frames.append(raw_frames[-1])

        return frames[:target_seq_length], raw_frames[:target_seq_length]

    def save_gradients(self, grad):
        """Backward hook: stash the gradient flowing into the feature maps."""
        self.gradients = grad

    def grad_cam(self, fmap, grads):
        """Compute a normalised 2-D Grad-CAM heat map.

        Args:
            fmap (torch.Tensor): Feature maps — assumed (N, C, H, W);
                TODO confirm against the model's actual output shape.
            grads (torch.Tensor): Gradients of the class score w.r.t.
                *fmap*, same shape as *fmap*.

        Returns:
            np.ndarray: Heat map in [0, 1], resized to self.resolution.
        """
        # Channel importance: gradients averaged over the leading axis.
        pooled_grads = torch.mean(grads, dim=[0])

        # FIX: weight a detached copy instead of mutating `fmap` in place —
        # the original modified the live autograd tensor.
        weighted = fmap.detach() * pooled_grads.unsqueeze(0)
        cam = torch.mean(weighted, dim=1).squeeze().cpu().numpy()

        # Keep only positive contributions.
        cam = np.maximum(cam, 0)

        # FIX: collapse any leading frame axis before resizing; the original
        # fed a 3-D array straight to cv2.resize (which misreads the axes)
        # and then summed away a spatial axis of the 2-D map.
        if cam.ndim == 3:
            cam = cam.mean(axis=0)

        # Normalise to [0, 1]; guard against an all-zero map.
        cam = cam - np.min(cam)
        max_val = np.max(cam)
        if max_val > 0:
            cam = cam / max_val

        # Match the visualisation resolution.
        cam = cv2.resize(cam, tuple(self.resolution))
        return cam

    def generate_gradcam(self, fmap, video_frame, grads):
        """Overlay the Grad-CAM heat map on *video_frame*.

        Args:
            fmap (torch.Tensor): Feature maps from the model.
            video_frame (np.ndarray): Raw cropped face (BGR uint8, as
                produced by preprocess()).
            grads (torch.Tensor): Gradients captured by the hook.

        Returns:
            np.ndarray: BGR image with the heat map blended in.
        """
        cam = self.grad_cam(fmap, grads)

        # Scale [0, 1] -> uint8 and colourise (JET colormap, BGR output).
        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)

        # FIX: the raw crops are already BGR uint8 in [0, 255]; the original
        # multiplied them by 255 again (wrapping uint8 and garbling the
        # image) and channel-swapped via an RGB2BGR conversion.
        video_frame = np.uint8(np.clip(video_frame, 0, 255))

        # Low heat-map weight keeps the face clearly visible.
        alpha = 0.01
        beta = 1 - alpha
        return cv2.addWeighted(heatmap, alpha, video_frame, beta, 0)

    def predict(self, video, seq_length=None):
        """Classify *video* as real or deepfake and build a Grad-CAM overlay.

        Args:
            video (str): Path to the video file.
            seq_length (int, optional): Number of frames to use.

        Returns:
            tuple: (prediction string, Grad-CAM image or None,
            per-class probability dict or None).
        """
        frames, raw_frames = self.preprocess(video, seq_length)
        if not frames:
            return "No faces detected in the video", None, None

        target_seq_length = (
            seq_length if seq_length is not None else self.default_frame_count
        )

        # (T, 3, H, W) -> (1, T, 3, H, W) batch for the model.
        input_tensor = torch.stack(frames).unsqueeze(0)
        input_tensor = input_tensor.view(1, target_seq_length, 3, *self.resolution)
        input_tensor = input_tensor.to(self.device)

        # Gradients are needed for Grad-CAM, so no torch.no_grad() here.
        input_tensor.requires_grad_(True)

        fmap, attn_wts, logits = self.model(input_tensor)

        # Capture d(score)/d(fmap) during the backward pass below.
        fmap.register_hook(self.save_gradients)

        class_probs = F.softmax(logits, dim=1).detach().cpu().numpy()[0]

        predicted_class_idx = np.argmax(class_probs)
        # Index 0 is the genuine ("original") class.
        prediction = "Deepfake" if predicted_class_idx > 0 else "Real"

        # Headline confidence: best deepfake class vs. the real class.
        confidence_deepfake_real = (
            round(class_probs[1:].max() * 100, 2)
            if prediction == "Deepfake"
            else round(class_probs[0] * 100, 2)
        )
        prediction_string = f"{prediction} {confidence_deepfake_real:.2f}% Confidence"

        if prediction == "Deepfake":
            # Per-manipulation-type probabilities for the detailed view.
            classification_details = {
                self.classes[i]: float(class_probs[i])
                for i in range(1, len(self.classes))
            }
        else:
            classification_details = {"Real": float(class_probs[0])}

        # Backpropagate the winning logit to populate self.gradients.
        self.model.zero_grad()
        logits[0, predicted_class_idx].backward()
        grads = self.gradients

        # Visualise the middle frame of the sequence.
        if raw_frames:
            middle_idx = len(raw_frames) // 2
            gradcam_image = self.generate_gradcam(fmap, raw_frames[middle_idx], grads)
        else:
            gradcam_image = None

        return prediction_string, gradcam_image, classification_details