caliangandrew commited on
Commit
ccc7fee
·
verified ·
1 Parent(s): 8f57ce7

Update test_detector.py

Browse files
Files changed (1) hide show
  1. test_detector.py +23 -20
test_detector.py CHANGED
@@ -15,14 +15,15 @@ from PIL import Image
15
  from huggingface_hub import hf_hub_download
16
  import gc
17
 
18
- from base_miner.UCF.config.constants import CONFIGS_DIR, WEIGHTS_DIR
19
- from base_miner.gating_mechanisms import FaceGate
 
20
 
21
- from base_miner.UCF.detectors import DETECTOR
22
- from base_miner.deepfake_detectors import DeepfakeDetector
23
- from base_miner import DETECTOR_REGISTRY, GATE_REGISTRY
 
24
 
25
- import bittensor as bt
26
 
27
  @DETECTOR_REGISTRY.register_module(module_name='test')
28
  class TestDetector(DeepfakeDetector):
@@ -34,11 +35,11 @@ class TestDetector(DeepfakeDetector):
34
  model_name (str): Name of the detector instance.
35
  config (str): Name of the YAML file in deepfake_detectors/config/ to load
36
  attributes from.
37
- device (str): The type of device ('cpu' or 'cuda').
38
  """
39
 
40
- def __init__(self, model_name: str = 'UCF', config: str = 'ucf.yaml', device: str = 'cpu'):
41
- super().__init__(model_name, config, device)
42
 
43
  def ensure_weights_are_available(self, weight_filename):
44
  destination_path = Path(WEIGHTS_DIR) / Path(weight_filename)
@@ -51,7 +52,7 @@ class TestDetector(DeepfakeDetector):
51
 
52
  def load_train_config(self):
53
  destination_path = Path(CONFIGS_DIR) / Path(self.train_config)
54
-
55
  if not destination_path.exists():
56
  local_config_path = hf_hub_download(self.hf_repo, self.train_config)
57
  print(f"Downloaded {self.hf_repo}/{self.train_config} to {local_config_path}")
@@ -83,9 +84,8 @@ class TestDetector(DeepfakeDetector):
83
  self.init_cudnn()
84
  self.init_seed()
85
  self.ensure_weights_are_available(self.weights)
86
- self.ensure_weights_are_available(self.train_config['pretrained'].split('/')[-1])
87
  model_class = DETECTOR[self.train_config['model_name']]
88
- bt.logging.info(f"Loaded config from training run: {self.train_config}")
89
  self.model = model_class(self.train_config).to(self.device)
90
  self.model.eval()
91
  weights_path = Path(WEIGHTS_DIR) / self.weights
@@ -115,17 +115,20 @@ class TestDetector(DeepfakeDetector):
115
  torch.Tensor: The preprocessed image tensor, ready for model inference.
116
  """
117
  # Convert image to RGB format to ensure consistent color handling.
118
- image = image.convert('RGB')
119
-
120
- # Define transformation sequence for image preprocessing.
121
- transform = transforms.Compose([
122
- transforms.Resize((res, res), interpolation=Image.LANCZOS), # Resize image to specified resolution.
123
- transforms.ToTensor(), # Convert the image to a PyTorch tensor.
124
- transforms.Normalize(mean=self.train_config['mean'], std=self.train_config['std']) # Normalize the image tensor.
125
- ])
 
 
126
 
127
  # Apply transformations and add a batch dimension for model inference.
128
  image_tensor = transform(image).unsqueeze(0)
 
129
 
130
  # Move the image tensor to the specified device (e.g., GPU).
131
  return image_tensor.to(self.device)
 
15
  from huggingface_hub import hf_hub_download
16
  import gc
17
 
18
+ from bitmind.image_transforms import ucf_transforms, ConvertToRGB, CenterCrop, CLAHE
19
+ from arena.detectors.UCF.config.constants import CONFIGS_DIR, WEIGHTS_DIR
20
+ from arena.detectors.gating_mechanisms import FaceGate
21
 
22
+ from arena.detectors.UCF.detectors import DETECTOR
23
+ from arena.detectors.deepfake_detectors import DeepfakeDetector
24
+ from arena.detectors import DETECTOR_REGISTRY, GATE_REGISTRY
25
+ from arena.utils.image_transforms import CLAHE
26
 
 
27
 
28
  @DETECTOR_REGISTRY.register_module(module_name='test')
29
  class TestDetector(DeepfakeDetector):
 
35
  model_name (str): Name of the detector instance.
36
  config (str): Name of the YAML file in deepfake_detectors/config/ to load
37
  attributes from.
38
+ cuda (bool): Whether to enable cuda (GPU).
39
  """
40
 
41
+ def __init__(self, model_name: str = 'UCF', config: str = 'ucf.yaml', cuda: bool = True):
42
+ super().__init__(model_name, config, cuda)
43
 
44
  def ensure_weights_are_available(self, weight_filename):
45
  destination_path = Path(WEIGHTS_DIR) / Path(weight_filename)
 
52
 
53
  def load_train_config(self):
54
  destination_path = Path(CONFIGS_DIR) / Path(self.train_config)
55
+
56
  if not destination_path.exists():
57
  local_config_path = hf_hub_download(self.hf_repo, self.train_config)
58
  print(f"Downloaded {self.hf_repo}/{self.train_config} to {local_config_path}")
 
84
  self.init_cudnn()
85
  self.init_seed()
86
  self.ensure_weights_are_available(self.weights)
87
+ self.ensure_weights_are_available(self.backbone_weights)
88
  model_class = DETECTOR[self.train_config['model_name']]
 
89
  self.model = model_class(self.train_config).to(self.device)
90
  self.model.eval()
91
  weights_path = Path(WEIGHTS_DIR) / self.weights
 
115
  torch.Tensor: The preprocessed image tensor, ready for model inference.
116
  """
117
  # Convert image to RGB format to ensure consistent color handling.
118
+ tforms = [
119
+ ConvertToRGB(),
120
+ CenterCrop(),
121
+ transforms.Resize((256,256)),
122
+ transforms.ToTensor()
123
+ ]
124
+ if self.dataset_type == 'real':
125
+ tforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
126
+
127
+ transform = transforms.Compose(tforms)
128
 
129
  # Apply transformations and add a batch dimension for model inference.
130
  image_tensor = transform(image).unsqueeze(0)
131
+
132
 
133
  # Move the image tensor to the specified device (e.g., GPU).
134
  return image_tensor.to(self.device)