""" M3 Fallback — ViT temporal deepfake detector (ACTIVE TONIGHT). Model: prithivMLmods/Deep-Fake-Detector-v2-Model (image-classification). Samples 32 frames, averages fake probability. Swap for m3_sstgnn after L40S training. """ from __future__ import annotations import cv2 import numpy as np import torch from PIL import Image from transformers import AutoModelForImageClassification, AutoProcessor class M3FallbackModule: def __init__(self, cache_dir: str = "/data/model_cache"): self.device = "cpu" self.model = AutoModelForImageClassification.from_pretrained( "prithivMLmods/Deep-Fake-Detector-v2-Model", cache_dir=cache_dir ) self.processor = AutoProcessor.from_pretrained( "prithivMLmods/Deep-Fake-Detector-v2-Model", cache_dir=cache_dir ) self.model.eval() # Determine fake label index once id2label = self.model.config.id2label self._fake_idx = next( (i for i, v in id2label.items() if "fake" in str(v).lower()), 1, # default: index 1 = fake ) def to_gpu(self): self.device = "cuda" self.model = self.model.to("cuda") def to_cpu(self): self.device = "cpu" self.model = self.model.to("cpu") @torch.no_grad() def score(self, video_path: str) -> dict: frames = self._extract_frames(video_path, n=32) if not frames: return {"s3": 0.5, "note": "no_frames"} fake_scores: list[float] = [] for frame in frames: inputs = self.processor(images=frame, return_tensors="pt") inputs = {k: v.to(self.device) for k, v in inputs.items()} logits = self.model(**inputs).logits probs = torch.softmax(logits, dim=-1) fake_p = probs[0, self._fake_idx].item() fake_scores.append(fake_p) s3 = float(np.mean(fake_scores)) return {"s3": s3} def _extract_frames(self, video_path: str, n: int = 32) -> list[Image.Image]: cap = cv2.VideoCapture(video_path) total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) indices = np.linspace(0, max(total - 1, 0), n, dtype=int) if total > 0 else [] frames: list[Image.Image] = [] for idx in indices: cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx)) ret, frame = cap.read() if ret: frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))) cap.release() return frames