# deepfake_model/interface.py
# Author: Simma7 (Hugging Face Hub) — commit deae56e ("Create interface.py")
import torch
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
from model import load_model
# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Download model checkpoints from the Hub at import time (cached locally by
# huggingface_hub after the first run).
model_paths = {
"m1": hf_hub_download("Simma7/deepfake_model", "video1.pth"),
"m2": hf_hub_download("Simma7/deepfake_model", "video2.pth"),
"m3": hf_hub_download("Simma7/deepfake_model", "video3.pt"),
}
# Build the three-model ensemble; the second argument selects the backbone
# architecture understood by the project-local `load_model` helper.
# NOTE(review): presumably `load_model` moves each model to DEVICE and calls
# .eval() — confirm, since `predict` below sends inputs to DEVICE.
models = {
"m1": load_model(model_paths["m1"], "clip"),
"m2": load_model(model_paths["m2"], "vit"),
"m3": load_model(model_paths["m3"], "clip"),
}
# Shared preprocessing: 224x224 resize then [0, 1] tensor conversion.
# NOTE(review): no normalization applied — verify the checkpoints were
# trained without mean/std normalization.
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
# extract frames
def extract_frames(video_path, num_frames=16):
    """Sample `num_frames` evenly spaced RGB frames from a video.

    Parameters
    ----------
    video_path : str
        Path to a video file readable by OpenCV.
    num_frames : int, optional
        Number of frames to sample (default 16).

    Returns
    -------
    list[PIL.Image.Image]
        RGB frames. May contain fewer than `num_frames` entries (possibly
        none) if the video cannot be opened, reports no frame count, or
        individual frames fail to decode.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        # Guard against unopenable files: the original code would proceed
        # with total == 0 and np.linspace(0, -1, ...) would produce
        # negative frame indices.
        if not cap.isOpened():
            return []
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total <= 0:
            return []
        indices = np.linspace(0, total - 1, num_frames, dtype=int)
        frames = []
        for i in indices:
            # Cast numpy int to Python int for the OpenCV property setter.
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes to BGR; PIL / torchvision expect RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame))
        return frames
    finally:
        # Release the capture handle even if decoding raises.
        cap.release()
# predict
def predict(video_path):
    """Return the ensemble's average fake-probability for a video, in [0, 1].

    For each sampled frame, averages the probabilities of all models in the
    module-level ensemble, then averages across frames.

    Parameters
    ----------
    video_path : str
        Path to a video file readable by OpenCV.

    Returns
    -------
    float
        Mean probability over frames and models.

    Raises
    ------
    ValueError
        If no frames could be extracted (e.g. the file is unreadable).
        The original code crashed here with ZeroDivisionError.
    """
    frames = extract_frames(video_path)
    if not frames:
        raise ValueError(f"No frames could be extracted from {video_path!r}")
    all_probs = []
    with torch.no_grad():
        for frame in frames:
            x = transform(frame).unsqueeze(0).to(DEVICE)
            probs = []
            # Keys are unused here; iterate values only.
            for model in models.values():
                out = model(x)
                # Single-logit heads: sigmoid gives P(fake).
                # Two-class heads: softmax, take class index 1 as "fake".
                if out.shape[-1] == 1:
                    prob = torch.sigmoid(out).item()
                else:
                    prob = torch.softmax(out, dim=1)[0, 1].item()
                probs.append(prob)
            all_probs.append(sum(probs) / len(probs))
    return sum(all_probs) / len(all_probs)