# Deepfake-detection Gradio app: ViFi-CLIP image encoder + binary classifier head.
import os
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms as T
from torchvision.transforms.v2 import ToDtype
from decord import VideoReader, cpu
import gradio as gr
# -------------------------
# Step 0: Download model from Google Drive if not exists
# -------------------------
model_path = 'vifi_clip_30_epochs_k400_full_finetuned.pth'
if not os.path.exists(model_path):
    # First run only: fetch the fine-tuned ViFi-CLIP checkpoint from Drive.
    print(f"Downloading model to {model_path}...")
    os.system("pip install -q gdown")
    # NOTE(review): `gdown --id` is deprecated in recent gdown releases in
    # favor of the positional-id form — confirm against the installed version.
    os.system(f"gdown --id 1Nx30Kbu5xnv6dPwz4I3Ivy380LCdp1Md -O {model_path}")
# -------------------------
# Transform
# -------------------------
def _transform(n_px=224):
    """Build the CLIP-style preprocessing pipeline for video frames.

    Converts uint8 frames to scaled float32, resizes the short side to
    *n_px*, center-crops an n_px x n_px square, and normalizes with the
    CLIP RGB mean/std.
    """
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        ToDtype(torch.float32, scale=True),
        T.Resize(n_px, antialias=True),
        T.CenterCrop(n_px),
        T.Normalize(clip_mean, clip_std),
    ]
    return T.Compose(steps)
# -------------------------
# Classifier Head
# -------------------------
class ClassificationHead(nn.Module):
    """Single linear projection from pooled video features to class logits.

    The layer is named ``dense`` so checkpoints saved with keys
    ``dense.weight`` / ``dense.bias`` load cleanly via ``load_state_dict``.
    No activation is applied here; callers apply sigmoid on the logits.
    """

    def __init__(self, input_dim=512, num_classes=1):
        super().__init__()
        self.dense = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        logits = self.dense(x)
        return logits
# -------------------------
# Load ViFi-CLIP Model
# -------------------------
# Project-local modules from the ViFi-CLIP repository (not on PyPI).
from trainers import vificlip
from utils.config import get_config
from utils.logger import create_logger
# YAML config for zero-shot K400 evaluation — presumably the 16-frame
# variant from the ViFi-CLIP repo layout; confirm against that repo.
cfgpth = 'configs/zero_shot/train/k400/16_16_vifi_clip.yaml'
# State dict for ClassificationHead (expects keys dense.weight / dense.bias).
classifier_path = 'best_detector_model.pt'
# Prefer GPU when available; everything below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class parse_option:
    """Minimal stand-in for the ViFi-CLIP argparse namespace.

    ``get_config`` expects an argparse-like object; this class mirrors the
    attributes it reads, hard-wired for inference-only use of the
    downloaded checkpoint.
    """

    def __init__(self):
        self.config = cfgpth        # path to the YAML model/eval config
        self.output = "exp"         # output / log directory
        self.resume = model_path    # checkpoint to restore weights from
        self.only_test = True       # inference only, no training loop
        # Remaining knobs left at their "not provided" values:
        self.opts = None
        self.batch_size = None
        self.pretrained = None
        self.accumulation_steps = None
        self.local_rank = 0         # single-process run
args = parse_option()
config = get_config(args)
logger = create_logger(output_dir=args.output, name=f"{config.MODEL.ARCH}")
# The class names are placeholders: only the image encoder is used below
# (see predict_video); the CLIP text branch is never queried here.
model = vificlip.returnCLIP(config, logger, class_names=["true", "false"])
model = model.float().to(device)
# Fix: put the feature extractor in eval mode too — without this, dropout
# and batch-norm run with training-time behavior during inference.
model.eval()
feature_extractor = model
classifier = ClassificationHead()
classifier.load_state_dict(torch.load(classifier_path, map_location=device))
classifier.to(device)
classifier.eval()
# -------------------------
# Inference Function
# -------------------------
def _sample_indices(total_frames, num_frames):
    """Pick `num_frames` frame indices from a video of `total_frames`.

    Uses a random contiguous window when the video is long enough;
    otherwise takes every frame and pads by repeating the last index.
    """
    if total_frames > num_frames:
        start = np.random.randint(0, total_frames - num_frames)
        return list(range(start, start + num_frames))
    indices = list(range(total_frames))
    indices += [total_frames - 1] * (num_frames - len(indices))
    return indices


def predict_video(video_path, threshold=0.5):
    """Classify a video file as Real or Fake.

    Decodes 16 frames, preprocesses them with the CLIP transform,
    mean-pools the per-frame image-encoder features over time, and applies
    the binary classifier head. Returns a human-readable result string;
    any failure is returned as an error string (Gradio-friendly, never
    raises).
    """
    preprocess = _transform(224)
    try:
        reader = VideoReader(video_path, ctx=cpu(0))
        num_frames = 16
        indices = _sample_indices(len(reader), num_frames)
        frames = reader.get_batch(indices).asnumpy()          # (T, H, W, C) uint8
        clip = torch.from_numpy(frames).permute(0, 3, 1, 2)   # -> (T, C, H, W)
        clip = preprocess(clip).unsqueeze(0).to(device)       # -> (1, T, C, H, W)
        # Renamed from B, T, ... to avoid shadowing the torchvision alias `T`.
        batch, steps, channels, height, width = clip.shape
        flat = clip.view(batch * steps, channels, height, width)
        with torch.no_grad():
            features = feature_extractor.image_encoder(flat)
            # Temporal mean-pool back to one feature vector per video.
            features = features.view(batch, steps, -1).mean(dim=1)
            logits = classifier(features)
            prob = torch.sigmoid(logits).item()
        label = "Real" if prob >= threshold else "Fake"
        return f"{label} (prob: {prob:.4f}, threshold: {threshold})"
    except Exception as e:
        # Best-effort: surface the error in the UI instead of crashing Gradio.
        return f"Error: {str(e)}"
# -------------------------
# Gradio UI
# -------------------------
# User-facing strings cleaned up: the originals contained mis-encoded
# emoji/symbols ("π§", "β₯") from a bad encoding round-trip.
demo = gr.Interface(
    fn=predict_video,
    inputs=[
        # NOTE(review): `type=` on gr.Video exists only in Gradio 3.x (it is
        # always a filepath in 4.x) — drop the kwarg if upgrading Gradio.
        gr.Video(type="filepath", label="Upload Video (.mp4)"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.01,
                  label="Threshold (Real >= Threshold)"),
    ],
    outputs="text",
    title="Deepfake Detection with ViFi-CLIP",
    description="Upload a video to classify it as Real or Fake. "
                "Threshold slider lets you adjust sensitivity.",
)
demo.launch()