Update test_detector.py
Browse files- test_detector.py +23 -20
test_detector.py
CHANGED
|
@@ -15,14 +15,15 @@ from PIL import Image
|
|
| 15 |
from huggingface_hub import hf_hub_download
|
| 16 |
import gc
|
| 17 |
|
| 18 |
-
from
|
| 19 |
-
from
|
|
|
|
| 20 |
|
| 21 |
-
from
|
| 22 |
-
from
|
| 23 |
-
from
|
|
|
|
| 24 |
|
| 25 |
-
import bittensor as bt
|
| 26 |
|
| 27 |
@DETECTOR_REGISTRY.register_module(module_name='test')
|
| 28 |
class TestDetector(DeepfakeDetector):
|
|
@@ -34,11 +35,11 @@ class TestDetector(DeepfakeDetector):
|
|
| 34 |
model_name (str): Name of the detector instance.
|
| 35 |
config (str): Name of the YAML file in deepfake_detectors/config/ to load
|
| 36 |
attributes from.
|
| 37 |
-
|
| 38 |
"""
|
| 39 |
|
| 40 |
-
def __init__(self, model_name: str = 'UCF', config: str = 'ucf.yaml',
|
| 41 |
-
super().__init__(model_name, config,
|
| 42 |
|
| 43 |
def ensure_weights_are_available(self, weight_filename):
|
| 44 |
destination_path = Path(WEIGHTS_DIR) / Path(weight_filename)
|
|
@@ -51,7 +52,7 @@ class TestDetector(DeepfakeDetector):
|
|
| 51 |
|
| 52 |
def load_train_config(self):
|
| 53 |
destination_path = Path(CONFIGS_DIR) / Path(self.train_config)
|
| 54 |
-
|
| 55 |
if not destination_path.exists():
|
| 56 |
local_config_path = hf_hub_download(self.hf_repo, self.train_config)
|
| 57 |
print(f"Downloaded {self.hf_repo}/{self.train_config} to {local_config_path}")
|
|
@@ -83,9 +84,8 @@ class TestDetector(DeepfakeDetector):
|
|
| 83 |
self.init_cudnn()
|
| 84 |
self.init_seed()
|
| 85 |
self.ensure_weights_are_available(self.weights)
|
| 86 |
-
self.ensure_weights_are_available(self.
|
| 87 |
model_class = DETECTOR[self.train_config['model_name']]
|
| 88 |
-
bt.logging.info(f"Loaded config from training run: {self.train_config}")
|
| 89 |
self.model = model_class(self.train_config).to(self.device)
|
| 90 |
self.model.eval()
|
| 91 |
weights_path = Path(WEIGHTS_DIR) / self.weights
|
|
@@ -115,17 +115,20 @@ class TestDetector(DeepfakeDetector):
|
|
| 115 |
torch.Tensor: The preprocessed image tensor, ready for model inference.
|
| 116 |
"""
|
| 117 |
# Convert image to RGB format to ensure consistent color handling.
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
transforms.
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
| 126 |
|
| 127 |
# Apply transformations and add a batch dimension for model inference.
|
| 128 |
image_tensor = transform(image).unsqueeze(0)
|
|
|
|
| 129 |
|
| 130 |
# Move the image tensor to the specified device (e.g., GPU).
|
| 131 |
return image_tensor.to(self.device)
|
|
|
|
| 15 |
from huggingface_hub import hf_hub_download
|
| 16 |
import gc
|
| 17 |
|
| 18 |
+
from bitmind.image_transforms import ucf_transforms, ConvertToRGB, CenterCrop, CLAHE
|
| 19 |
+
from arena.detectors.UCF.config.constants import CONFIGS_DIR, WEIGHTS_DIR
|
| 20 |
+
from arena.detectors.gating_mechanisms import FaceGate
|
| 21 |
|
| 22 |
+
from arena.detectors.UCF.detectors import DETECTOR
|
| 23 |
+
from arena.detectors.deepfake_detectors import DeepfakeDetector
|
| 24 |
+
from arena.detectors import DETECTOR_REGISTRY, GATE_REGISTRY
|
| 25 |
+
from arena.utils.image_transforms import CLAHE
|
| 26 |
|
|
|
|
| 27 |
|
| 28 |
@DETECTOR_REGISTRY.register_module(module_name='test')
|
| 29 |
class TestDetector(DeepfakeDetector):
|
|
|
|
| 35 |
model_name (str): Name of the detector instance.
|
| 36 |
config (str): Name of the YAML file in deepfake_detectors/config/ to load
|
| 37 |
attributes from.
|
| 38 |
+
cuda (bool): Whether to enable cuda (GPU).
|
| 39 |
"""
|
| 40 |
|
| 41 |
+
def __init__(self, model_name: str = 'UCF', config: str = 'ucf.yaml', cuda: bool = True):
|
| 42 |
+
super().__init__(model_name, config, cuda)
|
| 43 |
|
| 44 |
def ensure_weights_are_available(self, weight_filename):
|
| 45 |
destination_path = Path(WEIGHTS_DIR) / Path(weight_filename)
|
|
|
|
| 52 |
|
| 53 |
def load_train_config(self):
|
| 54 |
destination_path = Path(CONFIGS_DIR) / Path(self.train_config)
|
| 55 |
+
|
| 56 |
if not destination_path.exists():
|
| 57 |
local_config_path = hf_hub_download(self.hf_repo, self.train_config)
|
| 58 |
print(f"Downloaded {self.hf_repo}/{self.train_config} to {local_config_path}")
|
|
|
|
| 84 |
self.init_cudnn()
|
| 85 |
self.init_seed()
|
| 86 |
self.ensure_weights_are_available(self.weights)
|
| 87 |
+
self.ensure_weights_are_available(self.backbone_weights)
|
| 88 |
model_class = DETECTOR[self.train_config['model_name']]
|
|
|
|
| 89 |
self.model = model_class(self.train_config).to(self.device)
|
| 90 |
self.model.eval()
|
| 91 |
weights_path = Path(WEIGHTS_DIR) / self.weights
|
|
|
|
| 115 |
torch.Tensor: The preprocessed image tensor, ready for model inference.
|
| 116 |
"""
|
| 117 |
# Convert image to RGB format to ensure consistent color handling.
|
| 118 |
+
tforms = [
|
| 119 |
+
ConvertToRGB(),
|
| 120 |
+
CenterCrop(),
|
| 121 |
+
transforms.Resize((256,256)),
|
| 122 |
+
transforms.ToTensor()
|
| 123 |
+
]
|
| 124 |
+
if self.dataset_type == 'real':
|
| 125 |
+
tforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
|
| 126 |
+
|
| 127 |
+
transform = transforms.Compose(tforms)
|
| 128 |
|
| 129 |
# Apply transformations and add a batch dimension for model inference.
|
| 130 |
image_tensor = transform(image).unsqueeze(0)
|
| 131 |
+
|
| 132 |
|
| 133 |
# Move the image tensor to the specified device (e.g., GPU).
|
| 134 |
return image_tensor.to(self.device)
|