|
|
import os |
|
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' |
|
|
|
|
|
import random |
|
|
import warnings |
|
|
warnings.filterwarnings("ignore", category=FutureWarning) |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.backends.cudnn as cudnn |
|
|
import torchvision.transforms as transforms |
|
|
import yaml |
|
|
from PIL import Image |
|
|
from huggingface_hub import hf_hub_download |
|
|
import gc |
|
|
|
|
|
from bitmind.image_transforms import ucf_transforms, ConvertToRGB, CenterCrop, CLAHE |
|
|
from arena.detectors.test.config.constants import CONFIGS_DIR, WEIGHTS_DIR |
|
|
from arena.detectors.gating_mechanisms import FaceGate |
|
|
|
|
|
from arena.detectors.test.detectors import DETECTOR |
|
|
from arena.detectors.deepfake_detectors import DeepfakeDetector |
|
|
from arena.detectors import DETECTOR_REGISTRY, GATE_REGISTRY |
|
|
from arena.utils.image_transforms import CLAHE |
|
|
|
|
|
|
|
|
@DETECTOR_REGISTRY.register_module(module_name='test')
class TestDetector(DeepfakeDetector):
    """
    DeepfakeDetector subclass that initializes a pretrained UCF model
    for binary classification of fake and real images.

    Besides the constructor arguments, the methods below read several
    instance attributes that are not assigned in this class (``hf_repo``,
    ``device``, ``train_config``, ``weights``, ``backbone_weights``,
    ``dataset_type``).
    NOTE(review): presumably these are populated by ``DeepfakeDetector``
    from the detector YAML config -- confirm against the base class.

    Attributes:
        model_name (str): Name of the detector instance.
        config (str): Name of the YAML file in deepfake_detectors/config/ to load
                      attributes from.
        cuda (bool): Whether to enable cuda (GPU).
    """

    def __init__(self, model_name: str = 'test', config: str = 'test_config.yaml', cuda: bool = True):
        """Forward all arguments unchanged to ``DeepfakeDetector.__init__``."""
        super().__init__(model_name, config, cuda)

    def ensure_weights_are_available(self, weight_filename):
        """Make sure ``weight_filename`` exists under ``WEIGHTS_DIR``.

        When the file is missing it is fetched from ``self.hf_repo`` with
        ``hf_hub_download``, loaded with ``torch.load`` (mapped onto
        ``self.device``) and re-saved at the destination with ``torch.save``.

        Args:
            weight_filename: File name (or relative path) of the weights,
                used both as the hub file name and the local relative path.
        """
        destination_path = Path(WEIGHTS_DIR) / Path(weight_filename)
        if destination_path.exists():
            return

        # exist_ok=True makes this a no-op when the directory is already there.
        destination_path.parent.mkdir(parents=True, exist_ok=True)

        model_path = hf_hub_download(self.hf_repo, weight_filename)
        # NOTE(review): torch.load unpickles the downloaded file, so
        # ``hf_repo`` must only ever point at a trusted repository.
        model = torch.load(model_path, map_location=self.device)
        torch.save(model, destination_path)

    def load_train_config(self):
        """Return the parsed training config named by ``self.train_config``.

        The config is looked up under ``CONFIGS_DIR``. If it is not there yet
        it is downloaded from ``self.hf_repo``, parsed with ``yaml.safe_load``
        and written to the local path first. In both cases the returned value
        is what ``yaml.safe_load`` yields for the local file.

        Expects ``self.train_config`` to hold the config *file name* (it is
        used to build a path here).
        """
        destination_path = Path(CONFIGS_DIR) / Path(self.train_config)

        if destination_path.exists():
            print(f"Loaded local config from {destination_path}")
        else:
            local_config_path = hf_hub_download(self.hf_repo, self.train_config)
            print(f"Downloaded {self.hf_repo}/{self.train_config} to {local_config_path}")
            with open(local_config_path, 'r') as f:
                config_dict = yaml.safe_load(f)
            # The config directory may not exist on a fresh checkout.
            destination_path.parent.mkdir(parents=True, exist_ok=True)
            with open(destination_path, 'w') as f:
                yaml.dump(config_dict, f, default_flow_style=False)

        # Single read path: always return what is stored locally.
        with destination_path.open('r') as f:
            return yaml.safe_load(f)

    def init_cudnn(self):
        """Enable ``cudnn.benchmark`` when the train config has a truthy ``'cudnn'`` entry."""
        if self.train_config.get('cudnn'):
            cudnn.benchmark = True

    def init_seed(self):
        """Seed ``random`` and torch (CPU and all CUDA devices) from ``'manualSeed'``.

        Nothing is seeded when the train config has no ``'manualSeed'`` entry.
        """
        seed_value = self.train_config.get('manualSeed')
        # Compare against None rather than truthiness so a seed of 0 is honoured.
        if seed_value is not None:
            random.seed(seed_value)
            torch.manual_seed(seed_value)
            torch.cuda.manual_seed_all(seed_value)

    def load_model(self):
        """Build the detector model and load its weights into ``self.model``.

        Steps: load the train config, apply cudnn/seed settings, make sure the
        model and backbone weight files are present locally, instantiate the
        class registered in ``DETECTOR`` under the config's ``'model_name'``,
        put it in eval mode and load the checkpoint with ``strict=True``.

        Raises:
            RuntimeError: If ``load_state_dict`` fails. A failure mentioning
                'size mismatch' is re-raised with an explanatory message about
                ``specific_task_num``; any other failure propagates unchanged.
        """
        # This method replaces the config *file name* stored in
        # self.train_config with the parsed dict. load_train_config() needs
        # the file name, so on a repeated call reuse the parsed config
        # instead of trying to build a path from a dict.
        if not isinstance(self.train_config, dict):
            self.train_config = self.load_train_config()

        self.init_cudnn()
        self.init_seed()
        self.ensure_weights_are_available(self.weights)
        self.ensure_weights_are_available(self.backbone_weights)

        model_class = DETECTOR[self.train_config['model_name']]
        self.model = model_class(self.train_config).to(self.device)
        self.model.eval()

        weights_path = Path(WEIGHTS_DIR) / self.weights
        checkpoint = torch.load(weights_path, map_location=self.device)
        try:
            self.model.load_state_dict(checkpoint, strict=True)
        except RuntimeError as e:
            if 'size mismatch' not in str(e):
                # Unrelated failure: re-raise as-is, keeping the traceback.
                raise
            custom_message = (
                "\n\n Error: Incorrect specific_task_num in model config. The 'specific_task_num' "
                "in 'config_path' yaml should match the value used during training. "
                "A mismatch results in an incorrect output layer shape for UCF's learned disentanglement"
                " of different forgery methods/sources.\n\n"
                "Solution: Use the same config.yaml to intialize UCFDetector ('config_path' arg) "
                "as output during training (config.yaml saved alongside weights in the training run's "
                "logs directory). Or simply modify your config.yaml to ensure 'specific_task_num' equals "
                "the value set during training (defaults to num fake training datasets + 1).\n"
            )
            raise RuntimeError(custom_message) from e

    def preprocess(self, image, res=256):
        """Preprocess the image for model inference.

        The image goes through ``ConvertToRGB``, ``CenterCrop``, a resize to
        ``res`` x ``res`` and ``ToTensor``. When ``self.dataset_type`` is
        ``'real'`` a ``Normalize`` step with the mean/std listed below is
        appended. A leading batch dimension is added before the tensor is
        moved to ``self.device``.

        Args:
            image: Input image handed to the transform pipeline
                (presumably a PIL image, as annotated on ``__call__``).
            res (int): Side length, in pixels, of the square the image is
                resized to. Defaults to 256.

        Returns:
            torch.Tensor: The preprocessed image tensor, ready for model inference.
        """
        tforms = [
            ConvertToRGB(),
            CenterCrop(),
            # Honour the ``res`` argument instead of a hard-coded size.
            transforms.Resize((res, res)),
            transforms.ToTensor(),
        ]
        if self.dataset_type == 'real':
            tforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))

        transform = transforms.Compose(tforms)

        # unsqueeze(0) adds the batch dimension.
        image_tensor = transform(image).unsqueeze(0)

        return image_tensor.to(self.device)

    def infer(self, image_tensor):
        """ Perform inference using the model.

        Runs the model without gradient tracking on ``{'image': image_tensor}``
        with ``inference=True`` and returns the last element of
        ``self.model.prob``.
        """
        with torch.no_grad():
            self.model({'image': image_tensor}, inference=True)
        return self.model.prob[-1]

    def __call__(self, image: Image.Image) -> float:
        """Preprocess ``image`` and return the result of ``infer`` on it."""
        image_tensor = self.preprocess(image)
        return self.infer(image_tensor)

    def free_memory(self):
        """ Frees up memory by setting model and large data structures to None.

        The model is moved to the CPU before its reference is dropped; then
        the garbage collector is run and, when CUDA is available, the CUDA
        cache is emptied.
        """
        # getattr: nothing in this class guarantees these attributes were set
        # (e.g. load_model may have failed, or no face components exist).
        if getattr(self, 'model', None) is not None:
            self.model.cpu()
            self.model = None

        for attr_name in ('face_detector', 'face_predictor'):
            if getattr(self, attr_name, None) is not None:
                setattr(self, attr_name, None)

        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
|