File size: 1,945 Bytes
deae56e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
from model import load_model

# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Download the three ensemble checkpoints from the Hugging Face Hub at
# import time (cached locally by hf_hub_download on subsequent runs).
model_paths = {
    "m1": hf_hub_download("Simma7/deepfake_model", "video1.pth"),
    "m2": hf_hub_download("Simma7/deepfake_model", "video2.pth"),
    "m3": hf_hub_download("Simma7/deepfake_model", "video3.pt"),
}

# Instantiate each model via the project-local loader; the second argument
# selects the backbone architecture ("clip" or "vit") — see model.load_model.
models = {
    "m1": load_model(model_paths["m1"], "clip"),
    "m2": load_model(model_paths["m2"], "vit"),
    "m3": load_model(model_paths["m3"], "clip"),
}

# Preprocessing applied to every sampled frame: resize to the 224x224 input
# size the backbones expect, then convert PIL image -> float tensor in [0, 1].
# NOTE(review): no normalization (mean/std) is applied here — presumably the
# models were trained on unnormalized [0, 1] inputs; confirm against training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# extract frames
def extract_frames(video_path, num_frames=16):
    """Sample up to ``num_frames`` frames evenly spaced across a video.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Number of evenly spaced frames to sample.

    Returns:
        A list of RGB ``PIL.Image`` frames. Empty if the video cannot be
        opened or reports no frames.
    """
    cap = cv2.VideoCapture(video_path)

    # Guard against unreadable files: previously a failed open (or a
    # container reporting CAP_PROP_FRAME_COUNT == 0) made linspace produce
    # negative indices and the loop silently seek to garbage positions.
    if not cap.isOpened():
        cap.release()
        return []

    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total <= 0:
        cap.release()
        return []

    indices = np.linspace(0, total - 1, num_frames, dtype=int)
    frames = []

    for i in indices:
        # Cast to plain int: some OpenCV builds reject numpy integers here.
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes BGR; convert to RGB for PIL / torchvision.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))

    cap.release()
    return frames

# predict
def predict(video_path):
    """Return the ensemble fake-probability for a video.

    Each sampled frame is preprocessed and scored by every model in
    ``models``; per-frame probabilities are the mean across models, and the
    final score is the mean across frames.

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        A float in [0, 1] — the averaged probability of class 1.

    Raises:
        ValueError: If no frames could be extracted (previously this
            surfaced as an opaque ``ZeroDivisionError``).
    """
    frames = extract_frames(video_path)
    if not frames:
        raise ValueError(f"No frames could be extracted from {video_path!r}")

    all_probs = []

    with torch.no_grad():
        for frame in frames:
            # (1, C, 224, 224) batch on the inference device.
            x = transform(frame).unsqueeze(0).to(DEVICE)

            probs = []
            # Keys are unused — iterate values only.
            for model in models.values():
                out = model(x)

                # Single-logit head -> sigmoid; two-class head -> softmax
                # probability of class index 1 ("fake", presumably — TODO
                # confirm class order against training labels).
                if out.shape[-1] == 1:
                    prob = torch.sigmoid(out).item()
                else:
                    prob = torch.softmax(out, dim=1)[0, 1].item()

                probs.append(prob)

            all_probs.append(sum(probs) / len(probs))

    return sum(all_probs) / len(all_probs)