# Hugging Face model-page artifacts from the original upload
# (uploader: ash12321, upload note: "Upload complete model package with
# all files", commit 6ecece2) — kept as a comment so the file parses.
"""
Simple inference script for deepfake detection
"""
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from pathlib import Path
import json
from model import create_model
def load_model(checkpoint_path="pytorch_model.ckpt", device="cuda"):
    """Load the trained detector from a checkpoint and move it to ``device``.

    Accepts either a raw state dict or a wrapper dict holding one under
    a ``'state_dict'`` key (e.g. a Lightning checkpoint), and strips
    training-time ``model.`` / ``model.model.`` key prefixes.

    Args:
        checkpoint_path: Path to the torch checkpoint file.
        device: Device to map the weights onto and run the model on.
            Defaults to "cuda", matching the original behavior.

    Returns:
        The model in eval mode on ``device``.
    """
    model = create_model()
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    # Strip wrapper prefixes via slicing, NOT str.replace: replace() would
    # also rewrite 'model.' occurring mid-key (e.g. 'encoder.model.weight').
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('model.model.'):
            new_state_dict[key[len('model.model.'):]] = value
        elif key.startswith('model.'):
            new_state_dict[key[len('model.'):]] = value
        else:
            new_state_dict[key] = value
    # strict=False tolerates keys present in only one of checkpoint/model.
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()
    return model.to(device)
def process_video(video_path, n_frames=16, size=224):
    """Sample ``n_frames`` evenly spaced frames from a video for the model.

    Each frame is converted BGR->RGB, resized to ``size`` x ``size``, and
    scaled to [-1, 1].

    Returns:
        A float tensor of shape (1, 3, n_frames, size, size), or None if
        the video is shorter than ``n_frames`` or any sampled frame fails
        to decode.
    """
    capture = cv2.VideoCapture(str(video_path))
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    if frame_count < n_frames:
        capture.release()
        return None
    sampled = []
    for position in np.linspace(0, frame_count - 1, n_frames, dtype=int):
        capture.set(cv2.CAP_PROP_POS_FRAMES, position)
        ok, bgr = capture.read()
        if not ok:
            continue
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        rgb = cv2.resize(rgb, (size, size))
        # HWC uint8 -> CHW float in [0, 1]
        sampled.append(torch.from_numpy(rgb).float().div(255.0).permute(2, 0, 1))
    capture.release()
    if len(sampled) != n_frames:
        return None
    clip = torch.stack(sampled, dim=1)  # (C, T, H, W)
    clip = clip * 2 - 1  # map [0, 1] -> [-1, 1]
    return clip.unsqueeze(0)  # add batch dim
def detect_deepfake(model, video_tensor, threshold=0.3137):
    """Classify a video clip as REAL or FAKE via reconstruction error.

    The model reconstructs the temporally middle frame of the clip; a mean
    squared error above ``threshold`` is treated as evidence of manipulation.

    Args:
        model: Trained detector, called as ``model(video_tensor)`` and
            expected to return ``(frame_pred, flow_pred)`` with
            ``frame_pred`` shaped like one frame batch (B, C, H, W).
        video_tensor: Clip tensor of shape (B, C, T, H, W), values in
            [-1, 1] (as produced by ``process_video``).
        threshold: MSE decision boundary — presumably tuned on validation
            data; TODO confirm provenance of 0.3137.

    Returns:
        dict with keys 'prediction' ('FAKE'/'REAL'), 'mse', 'mae',
        'threshold', 'confidence' (distance from threshold, capped at 1.0),
        and 'is_fake'.
    """
    with torch.no_grad():
        # Run on whatever device the model lives on instead of hard-coding
        # CUDA, so CPU-only environments work too. If the model exposes no
        # parameters, leave the tensor where it is.
        device = next(model.parameters(), video_tensor).device
        video_tensor = video_tensor.to(device)
        frame_pred, flow_pred = model(video_tensor)
        # Reconstruction target is the middle frame along the time axis.
        mid = video_tensor.shape[2] // 2
        target = video_tensor[:, :, mid]
        mse = F.mse_loss(frame_pred, target).item()
        mae = F.l1_loss(frame_pred, target).item()
    is_fake = mse > threshold
    # Normalized distance from the decision boundary, capped at 100%.
    confidence = min(abs(mse - threshold) / threshold, 1.0)
    return {
        'prediction': 'FAKE' if is_fake else 'REAL',
        'mse': mse,
        'mae': mae,
        'threshold': threshold,
        'confidence': confidence,
        'is_fake': is_fake
    }
if __name__ == "__main__":
    import sys

    # CLI: a single positional argument, the path to the video to score.
    if len(sys.argv) < 2:
        print("Usage: python inference.py <video_path>")
        sys.exit(1)
    video_path = sys.argv[1]

    print("Loading model...")
    model = load_model()

    print(f"Processing video: {video_path}")
    video_tensor = process_video(video_path)
    if video_tensor is None:
        print("Error: Could not process video")
        sys.exit(1)

    print("Detecting deepfake...")
    result = detect_deepfake(model, video_tensor)

    # Report the decision and its supporting metrics.
    separator = "=" * 50
    print("\n" + separator)
    print("RESULTS:")
    print(separator)
    print(f"Prediction: {result['prediction']}")
    print(f"MSE: {result['mse']:.4f}")
    print(f"MAE: {result['mae']:.4f}")
    print(f"Threshold: {result['threshold']:.4f}")
    print(f"Confidence: {result['confidence']:.2%}")
    print(separator)