# GenConViT deepfake-detection demo (Gradio app).
# NOTE: this file was recovered from a Hugging Face Spaces page; the
# surrounding page chrome (commit hashes, line-number gutter, status
# banners) has been stripped so the module is runnable.
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import cv2
import numpy as np
import gradio as gr
# ------------------------------------------------------------------
# 1. Define the GenConViT Model Architecture (Minimal Version)
# ------------------------------------------------------------------
class GenConViT(nn.Module):
    """Minimal convolutional classifier standing in for the full GenConViT.

    Three stride-2 convolutions downsample the input, global average
    pooling collapses the spatial dims, and a linear head produces
    ``num_classes`` logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Lightweight demo backbone (adjust to the real architecture).
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map an image batch (B, 3, H, W) to class logits (B, num_classes)."""
        features = self.feature_extractor(x)
        return self.fc(torch.flatten(features, start_dim=1))
# ------------------------------------------------------------------
# 2. Load Model From genconvit_ed_inference.pth
# ------------------------------------------------------------------
model_path = "genconvit_ed_inference.pth"
model = GenConViT(num_classes=2)
# weights_only=True restricts unpickling to tensor/container data and
# prevents arbitrary code execution from a tampered checkpoint file.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
# Some export scripts wrap the weights as {"state_dict": ...}; unwrap
# that layer if present (a plain state_dict has keys like "fc.weight",
# so this check cannot misfire on one).
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
    checkpoint = checkpoint["state_dict"]
model.load_state_dict(checkpoint)
model.eval()  # inference mode: freezes dropout / batch-norm statistics
# ------------------------------------------------------------------
# 3. Preprocessing
# ------------------------------------------------------------------
# Resize to the 224x224 input the network expects, convert to a float
# tensor in [0, 1], then normalize each RGB channel to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
# ------------------------------------------------------------------
# 4. Video Deepfake Detection Function
# ------------------------------------------------------------------
def _frame_fake_probability(frame):
    """Score one BGR frame; return (fake_probability, PIL RGB image)."""
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(rgb)
    inp = transform(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(inp)
        probs = torch.softmax(logits, dim=1)[0]
    # Index 1 is assumed to be the "fake" class — matches the training
    # convention of the checkpoint; confirm against the training code.
    return probs[1].item(), img

def detect_deepfake(video, frame_interval=10):
    """Estimate the deepfake probability of a video file.

    Samples every ``frame_interval``-th frame, averages the per-frame
    fake probabilities, and returns a Markdown verdict string plus the
    last sampled frame (or None when nothing could be processed).
    """
    # Gradio passes None when no file was uploaded — guard instead of
    # letting cv2.VideoCapture crash on a non-path argument.
    if video is None:
        return "Error: Cannot open video", None
    cap = cv2.VideoCapture(video)
    if not cap.isOpened():
        return "Error: Cannot open video", None
    scores = []
    sample_frame = None
    try:
        i = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if i % frame_interval == 0:
                fake_prob, img = _frame_fake_probability(frame)
                sample_frame = img  # keep last processed frame for display
                scores.append(fake_prob)
            i += 1
    finally:
        # Release the capture even if decoding/inference raises.
        cap.release()
    if not scores:
        return "No frames processed", None
    avg = float(np.mean(scores))
    label = "🔴 Deepfake" if avg > 0.5 else "🟢 Real"
    output = f"""
    ### **Prediction: {label}**
    **Fake confidence: {avg:.4f}**
    """
    return output, sample_frame
# ------------------------------------------------------------------
# 5. Gradio App UI
# ------------------------------------------------------------------
app = gr.Interface(
    fn=detect_deepfake,
    inputs=gr.Video(label="Upload a video"),
    outputs=[
        gr.Markdown(label="Prediction"),
        gr.Image(label="Sample Frame"),
    ],
    title="GenConViT Deepfake Detector (Local .pth Model)",
    description="Upload a video. The system loads genconvit_ed_inference.pth and predicts deepfake probability.",
)

# Launch only when executed as a script, so the module can be imported
# (tests, hosting wrappers) without starting a server as a side effect.
if __name__ == "__main__":
    app.launch()