File size: 3,453 Bytes
f96d7b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539512b
f96d7b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539512b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f96d7b3
 
 
 
 
539512b
 
 
f96d7b3
 
 
 
 
539512b
 
 
f96d7b3
539512b
f96d7b3
539512b
 
f96d7b3
 
 
 
539512b
f96d7b3
539512b
f96d7b3
 
539512b
 
f96d7b3
539512b
f96d7b3
 
539512b
 
f96d7b3
539512b
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image
import numpy as np
import warnings

# Suppress every warning for the whole process.
warnings.filterwarnings("ignore")

# Download and Load Model.
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Face detector used by predict_frame to crop the face out of each frame.
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE)
mtcnn.to(DEVICE)
mtcnn.eval()

# Binary classifier (a single output logit) built on the VGGFace2-pretrained backbone.
model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)

# Fine-tuned weights are deserialised onto the CPU first; the whole model is
# moved to DEVICE and switched to inference mode afterwards.
checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device("cpu"))
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE).eval()

# Model Inference
def predict_frame(frame):
    """Classify the face in a single video frame as real or fake.

    Args:
        frame: A BGR image, e.g. as returned by ``cv2.VideoCapture.read``;
            it is converted to RGB before face detection.

    Returns:
        ``(prediction, face_with_mask)`` where ``prediction`` is ``"real"`` or
        ``"fake"`` and ``face_with_mask`` is a uint8 RGB image of the cropped
        face with a Grad-CAM heat-map blended on top. Returns ``(None, None)``
        when MTCNN finds no face.
    """
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    face = mtcnn(Image.fromarray(frame))
    if face is None:
        return None, None  # No face detected

    # Preprocess: resize to the classifier input size, then scale to [0, 1].
    # The division by 255 treats MTCNN's crop (post_process=False) as raw
    # 0-255 pixel values.
    face = F.interpolate(face.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False)
    face = face.to(DEVICE, dtype=torch.float32) / 255.0

    # Predict: the single logit passed through a sigmoid is the "fake" score.
    with torch.no_grad():
        fake_prob = torch.sigmoid(model(face).squeeze(0)).item()
    prediction = "real" if fake_prob < 0.5 else "fake"

    # Visualize (outside no_grad: Grad-CAM needs gradients).
    face_with_mask = _gradcam_overlay(face)
    return prediction, face_with_mask


def _gradcam_overlay(face):
    """Return *face* (1x3xHxW in [0, 1]) as uint8 RGB with a Grad-CAM heat-map blended in."""
    target_layers = [model.block8.branch1[-1]]
    targets = [ClassifierOutputTarget(0)]
    # The model and the input already live on DEVICE, so no device flag is
    # passed (``use_cuda`` no longer exists in pytorch-grad-cam >= 1.5).
    # The context manager releases the hooks GradCAM registers on the layer.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)[0, :]

    face_np = face.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # show_cam_on_image takes a float image in [0, 1] and already returns
    # uint8 in [0, 255]; it must not be multiplied by 255 again (that
    # overflows uint8 and corrupts the overlay).
    visualization = show_cam_on_image(face_np, grayscale_cam, use_rgb=True)
    return cv2.addWeighted((face_np * 255).astype(np.uint8), 1, visualization, 0.5, 0)

def predict_video(input_video, skip_frames=20):
    """Classify a video as real or fake by majority vote over sampled frames.

    Every ``skip_frames``-th frame, starting with the first, is passed to
    ``predict_frame``. Frames in which no face is detected do not vote.

    Args:
        input_video: Video path (or any source accepted by ``cv2.VideoCapture``).
        skip_frames: Sampling stride in frames; must be >= 1.

    Returns:
        ``"fake"`` if strictly more sampled faces were classified fake than
        real, otherwise ``"real"``. Returns ``"no face detected"`` when no
        sampled frame contained a face.

    Raises:
        ValueError: If ``skip_frames`` < 1 or the video cannot be opened.
    """
    if skip_frames < 1:
        raise ValueError(f"skip_frames must be >= 1, got {skip_frames}")

    cap = cv2.VideoCapture(input_video)
    votes = {"real": 0, "fake": 0}
    frame_index = 0
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {input_video!r}")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Sample frames 0, skip_frames, 2*skip_frames, ... so that clips
            # shorter than skip_frames still get one frame analysed.
            if frame_index % skip_frames == 0:
                prediction, _ = predict_frame(frame)  # overlay image not needed here
                if prediction is not None:
                    votes[prediction] += 1
            frame_index += 1
    finally:
        # Release the capture even if predict_frame raises.
        cap.release()

    if votes["real"] + votes["fake"] == 0:
        return "no face detected"
    # Majority vote; ties resolve to "real".
    return "fake" if votes["fake"] > votes["real"] else "real"

# Gradio Interface: a single uploaded video in, a single class label out.
video_input = gr.Video(label="Input Video")
class_output = gr.Label(label="Class")

interface = gr.Interface(
    fn=predict_video,
    inputs=[video_input],
    outputs=[class_output],
    title="Deep fake video Detection",
    description="Detect whether the  Video is fake or real",
)

interface.launch()