deepdetection / modules /m1_lipsync.py
akagtag's picture
Fix ZeroGPU startup and local GPU inference path
19d9b40
from __future__ import annotations
import cv2
import librosa
import numpy as np
import torch
from huggingface_hub import hf_hub_download
class LipSyncModule:
    """
    LipFD pretrained lip-sync deepfake detector.

    Output score is in [0, 1], higher means more likely fake.

    Construction never raises: if the checkpoint or the model code cannot be
    loaded, ``available`` stays ``False``, ``load_error`` holds the reason and
    ``score`` returns a neutral result (``s1 == 0.5``).
    """

    def __init__(self, cache_dir: str = "/data/model_cache"):
        """Load the model on CPU, recording (not raising) any failure.

        Args:
            cache_dir: Directory handed to ``hf_hub_download`` as its cache.
        """
        self.device = "cpu"
        self.cache_dir = cache_dir
        self.available = False
        self.load_error = ""
        try:
            self._load_model()
            self.available = True
        except Exception as exc:
            # Deliberate best-effort boundary: a missing checkpoint or package
            # must not take the caller down; score() degrades instead.
            self.model = None
            self.load_error = str(exc)
            print(f"LipSyncModule unavailable: {exc}")

    def _load_model(self) -> None:
        """Download the checkpoint and load every compatible tensor.

        Raises:
            RuntimeError: If no checkpoint tensor matches the model by name
                and shape (the model would otherwise run with random weights).
            Exception: Anything raised by the download, the ``lipfd`` import
                or ``torch.load`` propagates to the caller.
        """
        ckpt_path = hf_hub_download(
            repo_id="akagtag/LipFD-checkpoint",
            filename="ckpt.pth",
            cache_dir=self.cache_dir,
        )
        from lipfd.model import LipFDNet

        self.model = LipFDNet()
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. The checkpoint comes from a remote repository;
        # consider ``weights_only=True`` if the minimum supported torch
        # version allows it.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        if isinstance(state_dict, dict) and "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]

        current = self.model.state_dict()
        compatible = {}
        for key, value in state_dict.items():
            # Keys may carry a "module." prefix; strip it before matching.
            name = key.removeprefix("module.")
            if name in current and current[name].shape == value.shape:
                compatible[name] = value

        if not compatible:
            # With strict=False an empty dict loads "successfully" and leaves
            # the network randomly initialised; fail loudly instead.
            raise RuntimeError(
                "checkpoint contains no tensors matching LipFDNet"
            )
        skipped = len(state_dict) - len(compatible)
        if skipped:
            print(f"LipSyncModule: skipped {skipped} incompatible checkpoint tensors")

        self.model.load_state_dict(compatible, strict=False)
        self.model.eval()

    def to_gpu(self) -> None:
        """Move the model to CUDA; no-op when the module is unavailable."""
        if not self.available:
            return
        self.device = "cuda"
        self.model = self.model.to("cuda")

    def to_cpu(self) -> None:
        """Move the model back to CPU; no-op when the module is unavailable."""
        if not self.available:
            return
        self.device = "cpu"
        self.model = self.model.to("cpu")

    @torch.no_grad()
    def score(self, video_path: str) -> dict:
        """Score a video for lip-sync manipulation.

        Args:
            video_path: Path to the video file to analyse.

        Returns:
            A dict with ``"s1"`` (mean sigmoid of the model output) and
            ``"segments"`` (entries from ``_get_segments``). When the module
            is unavailable or preprocessing yields no usable frames/audio,
            ``"s1"`` is 0.5, ``"segments"`` is empty and ``"note"`` explains
            why.
        """
        if not self.available:
            return {
                "s1": 0.5,
                "segments": [],
                "note": f"module_unavailable: {self.load_error}",
            }

        frames, audio, fps = self._preprocess(video_path)
        if frames is None or audio is None:
            return {"s1": 0.5, "segments": [], "note": "no_face_or_audio"}

        frames_t = torch.tensor(frames, dtype=torch.float32).to(self.device)
        audio_t = torch.tensor(audio, dtype=torch.float32).to(self.device)
        logits = self.model(frames_t, audio_t)
        score = torch.sigmoid(logits).mean().item()
        return {"s1": score, "segments": self._get_segments(logits, fps)}

    def _preprocess(self, video_path: str):
        """Extract lip crops and a mel spectrogram from a video file.

        Args:
            video_path: Path to the video file.

        Returns:
            ``(frames, mel, fps)``. ``frames`` is an array of 96x96 lip crops
            with the channel axis moved to position 1 and values scaled by
            1/255; ``mel`` is the output of ``librosa.feature.melspectrogram``
            on audio loaded at 16 kHz. ``frames`` and ``mel`` are both
            ``None`` when fewer than 5 lip crops were found or the audio is
            missing or cannot be decoded.
        """
        # Build the detector once per video, not once per frame.
        face_cascade = self._new_face_cascade()

        cap = cv2.VideoCapture(video_path)
        try:
            raw_fps = cap.get(cv2.CAP_PROP_FPS)
            # The comparison also rejects NaN and negative values, not just 0.
            fps = raw_fps if raw_fps > 0 else 30.0

            frames = []
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                lip_crop = self._extract_lip_region(frame, face_cascade)
                if lip_crop is not None and lip_crop.size > 0:
                    frames.append(cv2.resize(lip_crop, (96, 96)))
        finally:
            # Release the decoder even if detection or resizing raises.
            cap.release()

        if len(frames) < 5:
            return None, None, fps

        try:
            audio, sr = librosa.load(video_path, sr=16000)
        except Exception as exc:
            # The video itself decoded fine, so a failure here most likely
            # means a missing or undecodable audio track; report it through
            # the same "no audio" path score() already handles.
            print(f"LipSyncModule: audio load failed: {exc}")
            return None, None, fps
        if audio.size == 0:
            return None, None, fps

        mel = librosa.feature.melspectrogram(y=audio, sr=sr)
        # (N, H, W, C) -> (N, C, H, W), scaled to [0, 1].
        frames_arr = np.array(frames).transpose(0, 3, 1, 2) / 255.0
        return frames_arr, mel, fps

    def _new_face_cascade(self):
        """Create the frontal-face Haar cascade bundled with OpenCV."""
        return cv2.CascadeClassifier(
            cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
        )

    def _extract_lip_region(self, frame, face_cascade=None):
        """Crop the lower-central part of the first detected face.

        Args:
            frame: Colour frame as returned by ``cv2.VideoCapture.read``.
            face_cascade: Optional pre-built ``cv2.CascadeClassifier`` to
                reuse across frames; a fresh one is created when omitted.

        Returns:
            The cropped region (a view into ``frame``), or ``None`` when no
            face is detected. The crop may be empty if the computed box falls
            outside the frame; callers check ``.size``.
        """
        if face_cascade is None:
            face_cascade = self._new_face_cascade()

        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.3, 5)
        # ``faces`` may be a NumPy array, so test its length explicitly
        # rather than relying on truthiness.
        if len(faces) == 0:
            return None

        x, y, w, h = faces[0]
        # Bottom 35% of the face box, middle 60% horizontally.
        lip_y = y + int(h * 0.65)
        lip_h = int(h * 0.35)
        lip_x = x + int(w * 0.2)
        lip_w = int(w * 0.6)
        return frame[lip_y : lip_y + lip_h, lip_x : lip_x + lip_w]

    def _get_segments(self, logits, fps: float) -> list[dict]:
        """List the model outputs whose sigmoid exceeds 0.6.

        Args:
            logits: Raw model output tensor; it is flattened before use.
            fps: Divisor used to turn a flattened index into a timestamp.

        Returns:
            One ``{"time", "score"}`` dict per flagged element, with ``time``
            equal to ``index / fps`` rounded to 2 decimals and ``score``
            rounded to 3 decimals.
        """
        scores = torch.sigmoid(logits).detach().cpu().flatten().numpy()
        return [
            {"time": round(i / fps, 2), "score": round(float(score), 3)}
            for i, score in enumerate(scores)
            if score > 0.6
        ]