| """ |
| Simple inference script for deepfake detection |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| import cv2 |
| import numpy as np |
| from pathlib import Path |
| import json |
| from model import create_model |
|
|
def load_model(checkpoint_path="pytorch_model.ckpt"):
    """Load the trained model from a checkpoint and move it to the best device.

    Args:
        checkpoint_path: Path to a checkpoint file — either a raw state dict
            or a dict with a 'state_dict' entry (e.g. a Lightning-style
            checkpoint).

    Returns:
        The model in eval mode, on CUDA when available, otherwise CPU.
    """
    model = create_model()
    # Detect the device instead of hard-coding 'cuda' so the script also
    # runs on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    # Strip wrapper prefixes added by the training harness. Slice the
    # prefix off rather than str.replace(), which would also mangle any
    # later occurrence of the same substring inside a key.
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('model.model.'):
            new_state_dict[k[len('model.model.'):]] = v
        elif k.startswith('model.'):
            new_state_dict[k[len('model.'):]] = v
        else:
            new_state_dict[k] = v

    # strict=False tolerates missing/unexpected keys (e.g. auxiliary-head
    # weights that are not part of the inference graph).
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()
    return model.to(device)
|
|
def process_video(video_path, n_frames=16, size=224):
    """Decode a video into a normalized clip tensor for the model.

    Samples `n_frames` evenly spaced frames, resizes each to
    (size, size), scales pixels to [-1, 1], and returns a tensor of
    shape (1, 3, n_frames, size, size). Returns None when the video is
    shorter than `n_frames` or some frames fail to decode.
    """
    capture = cv2.VideoCapture(str(video_path))
    frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))

    # Not enough frames to sample from — bail out early.
    if frame_total < n_frames:
        capture.release()
        return None

    sampled = []
    for position in np.linspace(0, frame_total - 1, n_frames, dtype=int):
        capture.set(cv2.CAP_PROP_POS_FRAMES, position)
        ok, bgr = capture.read()
        if not ok:
            continue
        # OpenCV decodes BGR; convert to RGB and resize before tensorizing.
        rgb = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (size, size))
        # HWC uint8 -> CHW float in [0, 1].
        sampled.append(torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0)

    capture.release()

    # Any decode failure leaves us short; treat the video as unusable.
    if len(sampled) != n_frames:
        return None

    clip = torch.stack(sampled, dim=1)  # (3, n_frames, size, size)
    # Rescale [0, 1] -> [-1, 1] and add the batch dimension.
    return (clip * 2 - 1).unsqueeze(0)
|
|
def detect_deepfake(model, video_tensor, threshold=0.3137):
    """Classify a video clip as REAL or FAKE via reconstruction error.

    The model predicts the clip's middle frame; a high MSE between that
    prediction and the actual middle frame is treated as evidence of
    manipulation.

    Args:
        model: Module returning (frame_prediction, flow_prediction) for a
            clip input.
        video_tensor: Clip of shape (1, C, T, H, W), normalized to [-1, 1]
            — assumes the layout produced by process_video; TODO confirm
            for other callers.
        threshold: MSE decision boundary (default presumably tuned
            offline).

    Returns:
        Dict with 'prediction' ('FAKE'/'REAL'), 'mse', 'mae', 'threshold',
        'confidence' (distance from the boundary, normalized and capped at
        1.0) and boolean 'is_fake'.
    """
    # Run on the model's own device instead of hard-coding .cuda(), so
    # CPU-only inference works too.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with torch.no_grad():
        video_tensor = video_tensor.to(device)
        frame_pred, _flow_pred = model(video_tensor)  # flow output unused here

        # Compare the reconstruction against the clip's middle frame.
        mid = video_tensor.shape[2] // 2
        target = video_tensor[:, :, mid]

        mse = F.mse_loss(frame_pred, target).item()
        mae = F.l1_loss(frame_pred, target).item()

    is_fake = mse > threshold
    # Normalized distance from the decision boundary, capped at 1.0.
    confidence = min(abs(mse - threshold) / threshold, 1.0)

    return {
        'prediction': 'FAKE' if is_fake else 'REAL',
        'mse': mse,
        'mae': mae,
        'threshold': threshold,
        'confidence': confidence,
        'is_fake': is_fake
    }
|
|
if __name__ == "__main__":
    import sys

    # Require exactly one positional argument: the video to analyze.
    if len(sys.argv) < 2:
        print("Usage: python inference.py <video_path>")
        sys.exit(1)

    video_path = sys.argv[1]

    print("Loading model...")
    model = load_model()

    print(f"Processing video: {video_path}")
    video_tensor = process_video(video_path)

    # process_video returns None for short or undecodable videos.
    if video_tensor is None:
        print("Error: Could not process video")
        sys.exit(1)

    print("Detecting deepfake...")
    result = detect_deepfake(model, video_tensor)

    separator = "=" * 50
    print("\n" + separator)
    print("RESULTS:")
    print(separator)
    for line in (
        f"Prediction: {result['prediction']}",
        f"MSE: {result['mse']:.4f}",
        f"MAE: {result['mae']:.4f}",
        f"Threshold: {result['threshold']:.4f}",
        f"Confidence: {result['confidence']:.2%}",
    ):
        print(line)
    print(separator)
|
|