import os

import cv2
import numpy as np
import torch
from decord import VideoReader, cpu
from torchvision import transforms
from tqdm import tqdm

from dataset.loader import normalize_data
from .config import load_config
from .face_detection import detect_faces
from .genconvit import GenConViT

# Device used for all tensors and model inference in this module.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Video container extensions accepted by is_video().
_VIDEO_EXTENSIONS = (".avi", ".mp4", ".mpg", ".mpeg", ".mov")


def load_genconvit(config, net, ed_weight, vae_weight, fp16):
    """Build a GenConViT model, move it to `device`, and set eval mode.

    Args:
        config: model configuration (e.g. produced by ``load_config()``).
        net: network variant identifier forwarded to GenConViT.
        ed_weight: encoder-decoder weights (path/identifier) for GenConViT.
        vae_weight: VAE weights (path/identifier) for GenConViT.
        fp16: if truthy, convert the model to half precision.

    Returns:
        The GenConViT model, ready for inference.
    """
    model = GenConViT(config, ed=ed_weight, vae=vae_weight, net=net, fp16=fp16)
    model.to(device)
    model.eval()
    if fp16:
        model.half()
    return model


def face_rec(frames, p=None, klass=None):
    """Detect faces in `frames` via detect_faces().

    `p` and `klass` are accepted for backward compatibility with the legacy
    signature but are ignored.
    """
    return detect_faces(frames)


def preprocess_frame(frame):
    """Convert a batch of HWC uint8 frames into a normalized NCHW float tensor.

    Assumes `frame` is array-like with shape (N, H, W, C) and values in
    0..255 — TODO confirm against detect_faces() output.
    """
    df_tensor = torch.tensor(frame, device=device).float()
    df_tensor = df_tensor.permute((0, 3, 1, 2))
    # Hoisted out of the loop: normalize_data() is loop-invariant, so the
    # transform only needs to be built once.
    vid_transform = normalize_data()["vid"]
    for i in range(len(df_tensor)):
        df_tensor[i] = vid_transform(df_tensor[i] / 255.0)
    return df_tensor


def pred_vid(df, model):
    """Run `model` on a preprocessed frame batch and reduce to one prediction.

    Returns:
        Tuple ``(class_index, confidence)`` from max_prediction_value().
    """
    with torch.no_grad():
        return max_prediction_value(torch.sigmoid(model(df).squeeze()))


def max_prediction_value(y_pred):
    """Return (argmax class index, confidence) from per-frame class scores.

    Scores are averaged over the frame dimension. The reported confidence is
    the class-0 mean when class 0 wins, otherwise ``abs(1 - class-1 mean)``
    (original project convention, deliberately preserved).
    """
    mean_val = torch.mean(y_pred, dim=0)
    return (
        torch.argmax(mean_val).item(),
        mean_val[0].item()
        if mean_val[0] > mean_val[1]
        else abs(1 - mean_val[1]).item(),
    )


def real_or_fake(prediction):
    """Map a class index to its label.

    NOTE: the XOR flips the index first, so 0 -> "FAKE" and 1 -> "REAL"
    (original project convention, deliberately preserved).
    """
    return {0: "REAL", 1: "FAKE"}[prediction ^ 1]


def extract_frames(video_file, frames_nums=15):
    """Sample up to `frames_nums` evenly spaced frames from a video.

    Args:
        video_file: path to the video readable by decord.VideoReader.
        frames_nums: maximum number of frames to return.

    Returns:
        A numpy array of the sampled frames.
    """
    vr = VideoReader(video_file, ctx=cpu(0))
    # Step between sampled frames; clamp to 1 so clips shorter than
    # `frames_nums` still yield every frame instead of step 0.
    step_size = max(1, len(vr) // frames_nums)
    indices = list(range(0, len(vr), step_size))[:frames_nums]
    return vr.get_batch(indices).asnumpy()


def df_face(vid, num_frames, net):
    """Extract frames from `vid`, detect faces, and return a model-ready tensor.

    Returns an empty list when no face is found (legacy sentinel, preserved —
    callers are expected to check for emptiness). `net` is accepted for
    signature compatibility but unused here.
    """
    img = extract_frames(vid, num_frames)
    face, count = face_rec(img)
    return preprocess_frame(face) if count > 0 else []


def is_video(vid):
    """Return True if `vid` is an existing file with a known video extension."""
    # Debug print removed; str.endswith accepts a tuple of suffixes directly.
    return os.path.isfile(vid) and vid.endswith(_VIDEO_EXTENSIONS)


def set_result():
    """Return a fresh result accumulator for video predictions."""
    return {
        "video": {
            "name": [],
            "pred": [],
            "klass": [],
            "pred_label": [],
            "correct_label": [],
        }
    }


def store_result(
    result, filename, y, y_val, klass, correct_label=None, compression=None
):
    """Append one prediction row to `result` and return it.

    Args:
        result: accumulator dict, as created by set_result().
        filename: name of the processed video.
        y: predicted class index (fed to real_or_fake()).
        y_val: prediction confidence value.
        klass: class label string (stored lowercased).
        correct_label: optional ground-truth label.
        compression: optional compression tag.

    Bug fix: set_result() does not pre-create the "compression" list, so the
    original code raised KeyError whenever `compression` was supplied;
    setdefault creates the list on first use.
    """
    result["video"]["name"].append(filename)
    result["video"]["pred"].append(y_val)
    result["video"]["klass"].append(klass.lower())
    result["video"]["pred_label"].append(real_or_fake(y))
    if correct_label is not None:
        result["video"]["correct_label"].append(correct_label)
    if compression is not None:
        result["video"].setdefault("compression", []).append(compression)
    return result