Simma7 commited on
Commit
deae56e
·
verified ·
1 Parent(s): 5ad0794

Create interface.py

Browse files
Files changed (1) hide show
  1. interface.py +75 -0
interface.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
from model import load_model

# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Download the three ensemble checkpoints from the Hugging Face Hub.
# hf_hub_download caches locally and returns the local file path.
model_paths = {
    "m1": hf_hub_download("Simma7/deepfake_model", "video1.pth"),
    "m2": hf_hub_download("Simma7/deepfake_model", "video2.pth"),
    "m3": hf_hub_download("Simma7/deepfake_model", "video3.pt"),
}

# Instantiate the models via the project-local loader.
# NOTE(review): the second argument presumably selects the backbone
# architecture ("clip" vs "vit") — confirm against model.load_model.
models = {
    "m1": load_model(model_paths["m1"], "clip"),
    "m2": load_model(model_paths["m2"], "vit"),
    "m3": load_model(model_paths["m3"], "clip"),
}

# Per-frame preprocessing: resize to the models' 224x224 input size and
# convert PIL image -> float tensor in [0, 1] (CHW).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
30
+
31
+ # extract frames
32
def extract_frames(video_path, num_frames=16):
    """Sample up to ``num_frames`` frames evenly spaced across a video.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Number of frame positions to sample (default 16).
            If the video has fewer frames, positions repeat, so duplicate
            frames may be returned.

    Returns:
        A list of RGB ``PIL.Image`` frames. Empty if the video cannot be
        opened, reports no frames, or every seek/read fails.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        # Guard: an unopenable file or a zero frame count would otherwise
        # feed np.linspace(0, -1, ...) negative indices and silently fail.
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if not cap.isOpened() or total <= 0:
            return []

        # Evenly spaced frame indices over [0, total - 1].
        indices = np.linspace(0, total - 1, num_frames, dtype=int)

        frames = []
        for i in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes BGR; convert to RGB for PIL/torchvision.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame))
        return frames
    finally:
        # Always release the capture handle, even if a seek/read raises.
        cap.release()
48
+
49
+ # predict
50
def predict(video_path):
    """Return the ensemble's average deepfake probability for a video.

    For each sampled frame, every model in ``models`` produces a
    probability; frame probabilities are the per-frame model average,
    and the final score averages over all frames.

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        A float in [0, 1]: mean probability across frames and models.

    Raises:
        ValueError: If no frames could be extracted from the video
            (previously this surfaced as an opaque ZeroDivisionError).
    """
    frames = extract_frames(video_path)

    # Fail loudly and clearly on unreadable/empty videos instead of
    # dividing by zero below.
    if not frames:
        raise ValueError(f"Could not extract any frames from {video_path!r}")

    all_probs = []

    with torch.no_grad():
        for frame in frames:
            # Preprocess to a (1, C, H, W) batch on the inference device.
            x = transform(frame).unsqueeze(0).to(DEVICE)

            probs = []
            # Keys are unused here, so iterate values directly.
            for model in models.values():
                out = model(x)

                # Single-logit head -> sigmoid; two-class head -> softmax
                # and take the probability of class 1 ("fake", presumably
                # — TODO confirm class ordering against training code).
                if out.shape[-1] == 1:
                    prob = torch.sigmoid(out).item()
                else:
                    prob = torch.softmax(out, dim=1)[0, 1].item()

                probs.append(prob)

            # Average the ensemble for this frame.
            all_probs.append(sum(probs) / len(probs))

    # Average across all sampled frames.
    return sum(all_probs) / len(all_probs)