import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Ignore INFO and WARN messages
import random
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
from pathlib import Path
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import yaml
from PIL import Image
from huggingface_hub import hf_hub_download
import gc
from bitmind.image_transforms import ucf_transforms, ConvertToRGB, CenterCrop, CLAHE
from arena.detectors.test.config.constants import CONFIGS_DIR, WEIGHTS_DIR
from arena.detectors.gating_mechanisms import FaceGate
from arena.detectors.test.detectors import DETECTOR
from arena.detectors.deepfake_detectors import DeepfakeDetector
from arena.detectors import DETECTOR_REGISTRY, GATE_REGISTRY
from arena.utils.image_transforms import CLAHE
@DETECTOR_REGISTRY.register_module(module_name='test')
class TestDetector(DeepfakeDetector):
    """
    DeepfakeDetector subclass that initializes a pretrained UCF model
    for binary classification of fake and real images.

    Attributes:
        model_name (str): Name of the detector instance.
        config (str): Name of the YAML file in deepfake_detectors/config/ to load
            attributes from.
        cuda (bool): Whether to enable cuda (GPU).
    """

    def __init__(self, model_name: str = 'test', config: str = 'test_config.yaml', cuda: bool = True):
        # All setup (config parsing, attribute assignment, model loading)
        # is driven by the DeepfakeDetector base class.
        super().__init__(model_name, config, cuda)

    def ensure_weights_are_available(self, weight_filename):
        """Download ``weight_filename`` from ``self.hf_repo`` into WEIGHTS_DIR
        unless a local copy already exists.

        Args:
            weight_filename (str): Checkpoint filename within the HF repo.
        """
        destination_path = Path(WEIGHTS_DIR) / Path(weight_filename)
        # mkdir(exist_ok=True) is a no-op when the directory already exists,
        # so no separate existence pre-check is needed.
        destination_path.parent.mkdir(parents=True, exist_ok=True)
        if not destination_path.exists():
            model_path = hf_hub_download(self.hf_repo, weight_filename)
            # Round-trip through torch.load/torch.save so the cached checkpoint
            # is re-serialized mapped to self.device rather than copied raw.
            model = torch.load(model_path, map_location=self.device)
            torch.save(model, destination_path)

    def load_train_config(self):
        """Return the training configuration as a dict.

        On first use, downloads ``self.train_config`` from ``self.hf_repo``
        and caches a normalized YAML dump under CONFIGS_DIR; later calls read
        the cached copy directly.

        Returns:
            dict: Parsed training configuration.
        """
        destination_path = Path(CONFIGS_DIR) / Path(self.train_config)
        if not destination_path.exists():
            local_config_path = hf_hub_download(self.hf_repo, self.train_config)
            print(f"Downloaded {self.hf_repo}/{self.train_config} to {local_config_path}")
            with open(local_config_path, 'r') as f:
                config_dict = yaml.safe_load(f)
            # Cache a normalized dump locally so the download happens only once.
            with open(destination_path, 'w') as f:
                yaml.dump(config_dict, f, default_flow_style=False)
        else:
            print(f"Loaded local config from {destination_path}")
        # Single read-and-return path for both branches (the original
        # duplicated this block in each branch).
        with destination_path.open('r') as f:
            return yaml.safe_load(f)

    def init_cudnn(self):
        """Enable the cudnn auto-tuner when the training config requests it."""
        if self.train_config.get('cudnn'):
            cudnn.benchmark = True

    def init_seed(self):
        """Seed Python and torch RNGs from the config's 'manualSeed', if set."""
        seed_value = self.train_config.get('manualSeed')
        if seed_value:
            random.seed(seed_value)
            torch.manual_seed(seed_value)
            torch.cuda.manual_seed_all(seed_value)

    def load_model(self):
        """Instantiate the detector model and load its pretrained weights.

        Raises:
            RuntimeError: with actionable guidance when the checkpoint's
                output-layer shape mismatches 'specific_task_num' in the
                config; any other load failure is re-raised unchanged.
        """
        self.train_config = self.load_train_config()
        self.init_cudnn()
        self.init_seed()
        self.ensure_weights_are_available(self.weights)
        self.ensure_weights_are_available(self.backbone_weights)
        model_class = DETECTOR[self.train_config['model_name']]
        self.model = model_class(self.train_config).to(self.device)
        self.model.eval()
        weights_path = Path(WEIGHTS_DIR) / self.weights
        checkpoint = torch.load(weights_path, map_location=self.device)
        try:
            self.model.load_state_dict(checkpoint, strict=True)
        except RuntimeError as e:
            # Only translate shape mismatches; everything else propagates
            # unchanged (bare `raise` preserves the original traceback).
            if 'size mismatch' not in str(e):
                raise
            custom_message = (
                "\n\n Error: Incorrect specific_task_num in model config. The 'specific_task_num' "
                "in 'config_path' yaml should match the value used during training. "
                "A mismatch results in an incorrect output layer shape for UCF's learned disentanglement"
                " of different forgery methods/sources.\n\n"
                "Solution: Use the same config.yaml to intialize UCFDetector ('config_path' arg) "
                "as output during training (config.yaml saved alongside weights in the training run's "
                "logs directory). Or simply modify your config.yaml to ensure 'specific_task_num' equals "
                "the value set during training (defaults to num fake training datasets + 1).\n"
            )
            raise RuntimeError(custom_message) from e

    def preprocess(self, image, res=256):
        """Preprocess the image for model inference.

        Args:
            image: Input image (PIL.Image or anything the transforms accept).
            res (int): Target square resolution. Bug fix: previously declared
                but ignored (Resize was hard-coded to 256); now honored. The
                default preserves the original 256x256 behavior.

        Returns:
            torch.Tensor: The preprocessed image tensor with a leading batch
            dimension, moved to ``self.device``.
        """
        # Convert to RGB first for consistent color handling, then center-crop
        # and resize to the requested square resolution.
        tforms = [
            ConvertToRGB(),
            CenterCrop(),
            transforms.Resize((res, res)),
            transforms.ToTensor()
        ]
        # NOTE(review): ImageNet normalization is applied only when
        # dataset_type == 'real' — confirm this asymmetry is intentional.
        if self.dataset_type == 'real':
            tforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
        transform = transforms.Compose(tforms)
        # Add a batch dimension and move to the model's device.
        image_tensor = transform(image).unsqueeze(0)
        return image_tensor.to(self.device)

    def infer(self, image_tensor):
        """Run the model on a preprocessed tensor and return its probability.

        The model stores per-sample probabilities on ``self.model.prob``;
        the last entry corresponds to this call.
        """
        with torch.no_grad():
            self.model({'image': image_tensor}, inference=True)
        return self.model.prob[-1]

    def __call__(self, image: Image) -> float:
        """Preprocess ``image`` and return the model's probability score."""
        image_tensor = self.preprocess(image)
        return self.infer(image_tensor)

    def free_memory(self):
        """ Frees up memory by setting model and large data structures to None. """
        if self.model is not None:
            self.model.cpu()  # Move model to CPU to free up GPU memory (if applicable)
            del self.model
            self.model = None
        if self.face_detector is not None:
            del self.face_detector
            self.face_detector = None
        if self.face_predictor is not None:
            del self.face_predictor
            self.face_predictor = None
        gc.collect()
        # Release cached GPU allocations back to the driver as well.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()