# genai-deepdetect / modules / m3_fallback.py
# akagtag's picture
# Initial deploy: M1 SyncNet + M2 CLIP + M3 ViT + M5 Llama NIM
# 16d70ee verified
"""
M3 Fallback — ViT temporal deepfake detector (ACTIVE TONIGHT).
Model: prithivMLmods/Deep-Fake-Detector-v2-Model (image-classification).
Samples 32 frames, averages fake probability.
Swap for m3_sstgnn after L40S training.
"""
from __future__ import annotations
import cv2
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoProcessor
class M3FallbackModule:
    """Frame-level ViT deepfake scorer used as the M3 fallback.

    Loads the image-classification checkpoint named by ``MODEL_ID``, samples
    frames evenly across a video, and reports the mean probability of the
    "fake" class as ``s3``.
    """

    # Single source of truth for the checkpoint used by model and processor.
    MODEL_ID = "prithivMLmods/Deep-Fake-Detector-v2-Model"

    def __init__(self, cache_dir: str = "/data/model_cache"):
        """Load the classifier and its processor on CPU.

        Args:
            cache_dir: Directory passed to ``from_pretrained`` for caching
                the downloaded weights.
        """
        self.device = "cpu"
        self.model = AutoModelForImageClassification.from_pretrained(
            self.MODEL_ID, cache_dir=cache_dir
        )
        self.processor = AutoProcessor.from_pretrained(
            self.MODEL_ID, cache_dir=cache_dir
        )
        self.model.eval()

        # Resolve the "fake" class index once: first label whose name
        # contains "fake" (case-insensitive), otherwise index 1.
        # Keys are cast to int so the value is always usable as a tensor index.
        id2label = self.model.config.id2label or {}
        self._fake_idx = next(
            (int(i) for i, v in id2label.items() if "fake" in str(v).lower()),
            1,  # default: index 1 = fake
        )

    def _move(self, device: str) -> None:
        """Move the model to *device*, recording it only after success.

        Moving first keeps ``self.device`` truthful when ``.to()`` raises
        (e.g. CUDA is unavailable), so later ``score`` calls still place
        inputs on the device that actually holds the weights.
        """
        self.model = self.model.to(device)
        self.device = device

    def to_gpu(self):
        """Move the model to CUDA; state is unchanged if the move fails."""
        self._move("cuda")

    def to_cpu(self):
        """Move the model back to the CPU."""
        self._move("cpu")

    @torch.no_grad()
    def score(self, video_path: str, n_frames: int = 32, batch_size: int = 8) -> dict:
        """Return the mean fake probability over frames sampled from a video.

        Args:
            video_path: Path of the video to analyse.
            n_frames: Number of evenly spaced frames to sample.
            batch_size: Frames per forward pass; values below 1 are
                treated as 1.

        Returns:
            ``{"s3": <mean fake probability>}``, or
            ``{"s3": 0.5, "note": "no_frames"}`` when no frame could be read.
        """
        frames = self._extract_frames(video_path, n=n_frames)
        if not frames:
            return {"s3": 0.5, "note": "no_frames"}

        step = max(int(batch_size), 1)
        fake_scores: list[float] = []
        # Classify in mini-batches instead of one forward pass per frame.
        for start in range(0, len(frames), step):
            batch = frames[start:start + step]
            inputs = self.processor(images=batch, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            logits = self.model(**inputs).logits
            probs = torch.softmax(logits, dim=-1)
            fake_scores.extend(probs[:, self._fake_idx].tolist())

        return {"s3": float(np.mean(fake_scores))}

    def _extract_frames(self, video_path: str, n: int = 32) -> list[Image.Image]:
        """Read up to *n* evenly spaced frames from a video as RGB PIL images.

        Args:
            video_path: Path of the video to read.
            n: Maximum number of frames to sample; non-positive values
                yield an empty list.

        Returns:
            The decoded frames in playback order. Empty when the video
            cannot be opened, reports no frames, or no read succeeds.
        """
        if n <= 0:
            return []

        frames: list[Image.Image] = []
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                return frames
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total <= 0:
                return frames
            # Integer truncation of linspace repeats indices when the clip
            # has fewer than n frames; de-duplicate so each frame is decoded
            # and scored once and carries equal weight in the mean.
            indices = np.unique(np.linspace(0, total - 1, n, dtype=int))
            for idx in indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ok, frame = cap.read()
                if ok:
                    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(rgb))
        finally:
            # Always free the capture handle, even if decoding raises.
            cap.release()
        return frames