import os

# Must be set before any TensorFlow-backed import happens.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Ignore INFO and WARN messages

import gc
import random
import warnings
from pathlib import Path

warnings.filterwarnings("ignore", category=FutureWarning)

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import yaml
from PIL import Image
from huggingface_hub import hf_hub_download

from bitmind.image_transforms import ucf_transforms, ConvertToRGB, CenterCrop, CLAHE
from arena.detectors.test.config.constants import CONFIGS_DIR, WEIGHTS_DIR
from arena.detectors.gating_mechanisms import FaceGate
from arena.detectors.test.detectors import DETECTOR
from arena.detectors.deepfake_detectors import DeepfakeDetector
from arena.detectors import DETECTOR_REGISTRY, GATE_REGISTRY
# NOTE(review): this CLAHE shadows the bitmind CLAHE imported above — confirm
# which implementation is intended before removing either import.
from arena.utils.image_transforms import CLAHE


@DETECTOR_REGISTRY.register_module(module_name='test')
class TestDetector(DeepfakeDetector):
    """
    DeepfakeDetector subclass that initializes a pretrained UCF model
    for binary classification of fake and real images.

    Attributes:
        model_name (str): Name of the detector instance.
        config (str): Name of the YAML file in deepfake_detectors/config/
            to load attributes from.
        cuda (bool): Whether to enable cuda (GPU).
    """

    def __init__(self, model_name: str = 'test', config: str = 'test_config.yaml', cuda: bool = True):
        super().__init__(model_name, config, cuda)

    def ensure_weights_are_available(self, weight_filename: str) -> None:
        """Download `weight_filename` from `self.hf_repo` into WEIGHTS_DIR if not cached.

        The checkpoint is round-tripped through torch.load/torch.save so the
        cached copy is re-serialized with `map_location=self.device`.
        """
        destination_path = Path(WEIGHTS_DIR) / Path(weight_filename)
        destination_path.parent.mkdir(parents=True, exist_ok=True)
        if not destination_path.exists():
            model_path = hf_hub_download(self.hf_repo, weight_filename)
            # NOTE(review): torch.load without weights_only=True unpickles
            # arbitrary objects from a downloaded file — consider
            # weights_only=True if the checkpoint is a plain state dict.
            model = torch.load(model_path, map_location=self.device)
            torch.save(model, destination_path)

    def load_train_config(self) -> dict:
        """Return the training config dict, downloading and caching it on first use.

        If the YAML file is not present under CONFIGS_DIR, it is fetched from
        `self.hf_repo` and written locally; subsequent calls read the cached copy.
        """
        destination_path = Path(CONFIGS_DIR) / Path(self.train_config)
        if not destination_path.exists():
            local_config_path = hf_hub_download(self.hf_repo, self.train_config)
            print(f"Downloaded {self.hf_repo}/{self.train_config} to {local_config_path}")
            with open(local_config_path, 'r') as f:
                config_dict = yaml.safe_load(f)
            # Cache a normalized copy locally so later runs skip the download.
            with open(destination_path, 'w') as f:
                yaml.dump(config_dict, f, default_flow_style=False)
        else:
            print(f"Loaded local config from {destination_path}")
        with destination_path.open('r') as f:
            return yaml.safe_load(f)

    def init_cudnn(self) -> None:
        """Enable cudnn autotuner benchmarking when the train config asks for it."""
        if self.train_config.get('cudnn'):
            cudnn.benchmark = True

    def init_seed(self) -> None:
        """Seed python/torch RNGs from the train config's `manualSeed`, if set."""
        seed_value = self.train_config.get('manualSeed')
        if seed_value:
            random.seed(seed_value)
            torch.manual_seed(seed_value)
            torch.cuda.manual_seed_all(seed_value)

    def load_model(self) -> None:
        """Build the detector model from the train config and load its weights.

        Raises:
            RuntimeError: with a descriptive message when the checkpoint's
                output-layer shape does not match `specific_task_num` in the
                config; re-raised unchanged for any other load failure.
        """
        self.train_config = self.load_train_config()
        self.init_cudnn()
        self.init_seed()
        self.ensure_weights_are_available(self.weights)
        self.ensure_weights_are_available(self.backbone_weights)

        model_class = DETECTOR[self.train_config['model_name']]
        self.model = model_class(self.train_config).to(self.device)
        self.model.eval()

        weights_path = Path(WEIGHTS_DIR) / self.weights
        checkpoint = torch.load(weights_path, map_location=self.device)
        try:
            self.model.load_state_dict(checkpoint, strict=True)
        except RuntimeError as e:
            # A size mismatch almost always means the config's task count
            # differs from the one used at training time.
            if 'size mismatch' in str(e):
                custom_message = (
                    "\n\n Error: Incorrect specific_task_num in model config. The 'specific_task_num' "
                    "in 'config_path' yaml should match the value used during training. "
                    "A mismatch results in an incorrect output layer shape for UCF's learned disentanglement"
                    " of different forgery methods/sources.\n\n"
                    "Solution: Use the same config.yaml to intialize UCFDetector ('config_path' arg) "
                    "as output during training (config.yaml saved alongside weights in the training run's "
                    "logs directory). Or simply modify your config.yaml to ensure 'specific_task_num' equals "
                    "the value set during training (defaults to num fake training datasets + 1).\n"
                )
                raise RuntimeError(custom_message) from e
            raise

    def preprocess(self, image, res: int = 256) -> torch.Tensor:
        """Preprocess the image for model inference.

        Args:
            image: Input image (anything ConvertToRGB accepts, e.g. a PIL image).
            res: Side length the image is resized to (previously hard-coded 256;
                the default preserves the original behavior).

        Returns:
            torch.Tensor: Batched (1, C, res, res) tensor on `self.device`.
        """
        tforms = [
            ConvertToRGB(),
            CenterCrop(),
            transforms.Resize((res, res)),  # fix: honor the `res` parameter
            transforms.ToTensor(),
        ]
        # ImageNet normalization is only applied for the 'real' dataset type,
        # matching the original implementation.
        if self.dataset_type == 'real':
            tforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                               std=[0.229, 0.224, 0.225]))
        transform = transforms.Compose(tforms)
        # Add a batch dimension and move to the configured device.
        image_tensor = transform(image).unsqueeze(0)
        return image_tensor.to(self.device)

    def infer(self, image_tensor: torch.Tensor):
        """Run the model and return the most recent probability it recorded.

        The model is invoked for its side effect of appending to `self.model.prob`;
        the forward return value itself is ignored.
        """
        with torch.no_grad():
            self.model({'image': image_tensor}, inference=True)
        return self.model.prob[-1]

    def __call__(self, image: Image.Image) -> float:
        # fix: annotate with the PIL Image class, not the PIL.Image module.
        image_tensor = self.preprocess(image)
        return self.infer(image_tensor)

    def free_memory(self) -> None:
        """Free model and face-detection resources, then release cached GPU memory."""
        if getattr(self, 'model', None) is not None:
            self.model.cpu()  # Move model to CPU to free up GPU memory (if applicable)
            del self.model
            self.model = None
        # fix: guard with getattr — these attributes are never set in this class,
        # so attribute access alone could raise AttributeError.
        if getattr(self, 'face_detector', None) is not None:
            del self.face_detector
            self.face_detector = None
        if getattr(self, 'face_predictor', None) is not None:
            del self.face_predictor
            self.face_predictor = None
        gc.collect()
        # If using GPUs and PyTorch, clear the cache as well
        if torch.cuda.is_available():
            torch.cuda.empty_cache()