# deepfake-detector / inference.py
# dappai's picture
# Upload 9 files
# 459fc8b verified
import torch
import cv2
import numpy as np
import torchvision.transforms as T
from collections import OrderedDict
import base64
from model import DeepfakeEffNetTransformer
from cam import GradCAM, overlay_heatmap
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# LOAD MODEL
model = DeepfakeEffNetTransformer()

# Deserialize onto the CPU first; the model is moved to `device` further down.
# NOTE(review): torch.load unpickles the file, so "best_model.pth" must come
# from a trusted source.
state_dict = torch.load("best_model.pth", map_location="cpu")

# Remove every "module." occurrence from the parameter names so they line up
# with the bare model (presumably the checkpoint was saved from a
# DataParallel-wrapped model -- confirm against the training script).
new_state = OrderedDict(
    (key.replace("module.", ""), weights)
    for key, weights in state_dict.items()
)

model.load_state_dict(new_state)
model = model.to(device)
# Inference only: switch the model to evaluation mode.
model.eval()
print("Model loaded")
# GRADCAM TARGET LAYER
# Grad-CAM is attached to the last entry of the CNN backbone's `blocks`.
target_layer = model.cnn.blocks[-1]
grad_cam = GradCAM(model, target_layer)

# FACE DETECTOR
# Frontal-face Haar cascade taken from OpenCV's bundled data directory.
face_detector = cv2.CascadeClassifier(
    f"{cv2.data.haarcascades}haarcascade_frontalface_default.xml"
)

# FRAME CACHE
# Face crops of the most recently analysed video; written by run_inference
# and read by generate_heatmap.
LAST_FRAMES = []
# FRAME EXTRACTION
def extract_and_crop(video_path, num_frames=32):
    """Sample evenly spaced frames from a video and crop each to a face.

    For every sampled frame the Haar cascade ``face_detector`` is run on a
    grayscale copy. If at least one face is found the frame is cropped to the
    first detection; otherwise the whole frame is used. The result is resized
    to 240x240 and converted from BGR to RGB.

    Args:
        video_path: Source passed straight to ``cv2.VideoCapture``.
        num_frames: Number of evenly spaced frame positions to sample.

    Returns:
        A list of 240x240 RGB images, one per frame that could be read. The
        list is shorter than ``num_frames`` when some reads fail, and empty
        when the video cannot be opened or no frame can be read.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []

    # try/finally guarantees the capture handle is released even when
    # decoding, face detection or resizing raises part-way through.
    try:
        if not cap.isOpened():
            return frames

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # If the reported frame count is 0 or negative, `total_frames - 1`
        # would be negative and produce an invalid seek position; clamp to 0.
        last_index = max(total_frames - 1, 0)
        idx = np.linspace(0, last_index, num_frames).astype(int)

        for i in idx:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            ret, frame = cap.read()
            if not ret:
                # Unreadable position: skip it rather than abort the video.
                continue

            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = face_detector.detectMultiScale(
                gray,
                scaleFactor=1.3,
                minNeighbors=5,
            )

            if len(faces) > 0:
                # NOTE(review): this is the first detection returned, which is
                # not necessarily the largest or most central face.
                x, y, w, h = faces[0]
                face = frame[y:y + h, x:x + w]
            else:
                # No face found: fall back to the full frame.
                face = frame

            face = cv2.resize(face, (240, 240))
            face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
            frames.append(face)
    finally:
        cap.release()

    return frames
# TRANSFORM
# Per-frame preprocessing: array -> PIL image -> 240x240 -> tensor, then each
# of the three channels is normalised as (x - 0.5) / 0.5.
transform = T.Compose(
    [
        T.ToPILImage(),
        T.Resize((240, 240)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
# INFERENCE
def run_inference(video_path):
    """Classify a video as real or fake from its sampled face crops.

    The crops returned by ``extract_and_crop`` are cached in the module-level
    ``LAST_FRAMES`` (for later use by ``generate_heatmap``), transformed,
    stacked into one sequence with a leading batch dimension of 1 and passed
    through the model.

    Args:
        video_path: Video source forwarded to ``extract_and_crop``.

    Returns:
        A dict with:
          * ``"label"``: ``"Real"`` when the predicted class index is 0,
            ``"Fake"`` otherwise, or ``"Video tidak terbaca"`` (Indonesian
            for "video unreadable") when no frame could be extracted.
          * ``"confidence"``: softmax probability of the predicted class as a
            percentage (0 when no frame could be extracted).
          * ``"frames"``: the crops as base64-encoded JPEG strings.
    """
    global LAST_FRAMES

    sampled = extract_and_crop(video_path)
    # Cache even an empty result so stale frames from a previous video are
    # never served.
    LAST_FRAMES = sampled

    if not sampled:
        return {
            "label": "Video tidak terbaca",
            "confidence": 0,
            "frames": [],
        }

    # Stack the transformed frames and add a leading batch dimension of 1.
    batch = torch.stack([transform(face) for face in sampled])
    batch = batch.unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        class_probs = torch.softmax(logits, dim=1)[0]
        winner = torch.argmax(class_probs).item()

    verdict = "Real" if winner == 0 else "Fake"
    percent = class_probs[winner].item() * 100

    # The crops are stored as RGB; convert back to BGR before JPEG encoding.
    jpeg_buffers = (
        cv2.imencode(".jpg", cv2.cvtColor(face, cv2.COLOR_RGB2BGR))[1]
        for face in sampled
    )
    encoded_frames = [
        base64.b64encode(jpeg).decode("utf-8") for jpeg in jpeg_buffers
    ]

    return {
        "label": verdict,
        "confidence": percent,
        "frames": encoded_frames,
    }
# REGION IMPORTANCE
def compute_regions(cam):
    """Split an activation map into horizontal face bands and score each one.

    The map is cut into five row bands (Forehead, Eyes, Cheeks, Mouth, Chin).
    The band limits are defined on a 240-row layout (0-60, 60-110, 110-170,
    170-220, 220-240) and scaled proportionally to the actual number of rows,
    so a 240-row map uses exactly those rows and other heights no longer
    yield empty slices and NaN scores.

    Args:
        cam: 2-D array-like indexed as ``cam[rows, cols]`` (presumably the
            Grad-CAM map of an upright face crop -- confirm with the caller).

    Returns:
        A list of ``{"name": str, "value": float}`` dicts in top-to-bottom
        order, where ``value`` is the band's mean activation divided by the
        sum of all five band means.
    """
    reference_rows = 240
    bands = (
        ("Forehead", 0, 60),
        ("Eyes", 60, 110),
        ("Cheeks", 110, 170),
        ("Mouth", 170, 220),
        ("Chin", 220, 240),
    )

    rows = cam.shape[0]
    means = {}
    for name, top, bottom in bands:
        # Integer scaling reproduces the reference limits exactly at 240 rows.
        band = cam[top * rows // reference_rows:bottom * rows // reference_rows, :]
        # A very short map can leave a band without rows; score it 0 rather
        # than propagating the NaN that the mean of an empty slice produces.
        means[name] = float(band.mean()) if band.shape[0] > 0 else 0.0

    # The epsilon keeps the division defined when every band mean is zero.
    total = sum(means.values()) + 1e-8
    return [
        {"name": name, "value": mean / total}
        for name, mean in means.items()
    ]
# HEATMAP GENERATION
def generate_heatmap(frame_index):
    """Build a Grad-CAM overlay and region scores for one cached frame.

    Args:
        frame_index: Zero-based position of the frame in ``LAST_FRAMES``
            (the crops cached by the most recent ``run_inference`` call).

    Returns:
        ``(heatmap, regions)``, where ``heatmap`` is the result of
        ``overlay_heatmap`` on the BGR version of the frame and ``regions``
        is the output of ``compute_regions``; ``(None, None)`` when
        ``frame_index`` is outside the cache.
    """
    # Reject negative indices as well: Python would otherwise wrap them
    # around and silently return a frame from the end of the cache (or raise
    # IndexError for large negative values).
    if not 0 <= frame_index < len(LAST_FRAMES):
        return None, None

    frame = LAST_FRAMES[frame_index]

    # The single frame is repeated 32 times to form the sequence fed to
    # Grad-CAM (32 matches the default num_frames of extract_and_crop;
    # presumably the clip length the model expects -- confirm).
    seq = torch.stack([transform(frame)] * 32).unsqueeze(0).to(device)

    cam = grad_cam.generate(seq)
    regions = compute_regions(cam)

    # Cached frames are RGB; overlay_heatmap is given the BGR version.
    heatmap = overlay_heatmap(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), cam)
    return heatmap, regions