Delete test_deps
Browse files- test_deps/README.md +0 -14
- test_deps/config/__init__.py +0 -7
- test_deps/config/__pycache__/__init__.cpython-310.pyc +0 -0
- test_deps/config/__pycache__/constants.cpython-310.pyc +0 -0
- test_deps/config/constants.py +0 -15
- test_deps/config/pretrained_config.yaml +0 -94
- test_deps/config/pretrained_face_config.yaml +0 -94
- test_deps/config/train_config.yaml +0 -9
- test_deps/config/ucf.yaml +0 -73
- test_deps/config/xception.yaml +0 -86
- test_deps/detectors/__init__.py +0 -11
- test_deps/detectors/__pycache__/__init__.cpython-310.pyc +0 -0
- test_deps/detectors/__pycache__/base_detector.cpython-310.pyc +0 -0
- test_deps/detectors/__pycache__/ucf_detector.cpython-310.pyc +0 -0
- test_deps/detectors/base_detector.py +0 -71
- test_deps/detectors/ucf_detector.py +0 -472
- test_deps/logger.py +0 -36
- test_deps/loss/__init__.py +0 -13
- test_deps/loss/__pycache__/__init__.cpython-310.pyc +0 -0
- test_deps/loss/__pycache__/abstract_loss_func.cpython-310.pyc +0 -0
- test_deps/loss/__pycache__/contrastive_regularization.cpython-310.pyc +0 -0
- test_deps/loss/__pycache__/cross_entropy_loss.cpython-310.pyc +0 -0
- test_deps/loss/__pycache__/l1_loss.cpython-310.pyc +0 -0
- test_deps/loss/abstract_loss_func.py +0 -17
- test_deps/loss/contrastive_regularization.py +0 -78
- test_deps/loss/cross_entropy_loss.py +0 -26
- test_deps/loss/l1_loss.py +0 -19
- test_deps/metrics/__init__.py +0 -7
- test_deps/metrics/__pycache__/__init__.cpython-310.pyc +0 -0
- test_deps/metrics/__pycache__/base_metrics_class.cpython-310.pyc +0 -0
- test_deps/metrics/__pycache__/registry.cpython-310.pyc +0 -0
- test_deps/metrics/base_metrics_class.py +0 -205
- test_deps/metrics/registry.py +0 -20
- test_deps/metrics/utils.py +0 -88
- test_deps/networks/__init__.py +0 -11
- test_deps/networks/__pycache__/__init__.cpython-310.pyc +0 -0
- test_deps/networks/__pycache__/xception.cpython-310.pyc +0 -0
- test_deps/networks/xception.py +0 -285
- test_deps/optimizor/LinearLR.py +0 -20
- test_deps/optimizor/SAM.py +0 -77
- test_deps/train_detector.py +0 -460
- test_deps/trainer/trainer.py +0 -441
test_deps/README.md
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
## UCF
|
| 2 |
-
|
| 3 |
-
This model has been adapted from [DeepfakeBench](https://github.com/SCLBD/DeepfakeBench).
|
| 4 |
-
|
| 5 |
-
##
|
| 6 |
-
|
| 7 |
-
- **Train UCF model**:
|
| 8 |
-
- Use `train_ucf.py`, which will download necessary pretrained `xception` backbone weights from HuggingFace (if not present locally) and start a training job with logging outputs in `.logs/`.
|
| 9 |
-
- Customize the training job by editing `config/ucf.yaml`
|
| 10 |
-
- `pm2 start train_ucf.py --no-autorestart` to train a generalist detector on datasets from `DATASET_META`
|
| 11 |
-
- `pm2 start train_ucf.py --no-autorestart -- --faces_only` to train a face expert detector on preprocessed-face only datasets
|
| 12 |
-
|
| 13 |
-
- **Miner Neurons**:
|
| 14 |
-
- The `UCF` class in `pretrained_ucf.py` is used by miner neurons to load and perform inference with pretrained UCF model weights.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/config/__init__.py
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
current_file_path = os.path.abspath(__file__)
|
| 4 |
-
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
|
| 5 |
-
project_root_dir = os.path.dirname(parent_dir)
|
| 6 |
-
sys.path.append(parent_dir)
|
| 7 |
-
sys.path.append(project_root_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/config/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (350 Bytes)
|
|
|
test_deps/config/__pycache__/constants.cpython-310.pyc
DELETED
|
Binary file (543 Bytes)
|
|
|
test_deps/config/constants.py
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
|
| 3 |
-
# Path to the directory containing the constants.py file
|
| 4 |
-
CONFIGS_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 5 |
-
|
| 6 |
-
# The base directory for UCF-related files, i.e., UCF directory
|
| 7 |
-
UCF_BASE_PATH = os.path.abspath(os.path.join(CONFIGS_DIR, "..")) # Points to bitmind-subnet/base_miner/UCF/
|
| 8 |
-
# Absolute paths for the required files and directories
|
| 9 |
-
CONFIG_PATH = os.path.join(CONFIGS_DIR, "ucf.yaml") # Path to the ucf.yaml file
|
| 10 |
-
WEIGHTS_DIR = os.path.join(UCF_BASE_PATH, "weights/") # Path to pretrained weights directory
|
| 11 |
-
|
| 12 |
-
HF_REPO = "bitmind/ucf"
|
| 13 |
-
BACKBONE_CKPT = "xception_best.pth"
|
| 14 |
-
|
| 15 |
-
DLIB_FACE_PREDICTOR_PATH = os.path.abspath(os.path.join(UCF_BASE_PATH, "../../utils/dlib_tools/shape_predictor_81_face_landmarks.dat"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/config/pretrained_config.yaml
DELETED
|
@@ -1,94 +0,0 @@
|
|
| 1 |
-
SWA: false
|
| 2 |
-
backbone_config:
|
| 3 |
-
dropout: false
|
| 4 |
-
inc: 3
|
| 5 |
-
mode: adjust_channel
|
| 6 |
-
num_classes: 2
|
| 7 |
-
backbone_name: xception
|
| 8 |
-
compression: c23
|
| 9 |
-
cuda: true
|
| 10 |
-
cudnn: true
|
| 11 |
-
dataset_json_folder: preprocessing/dataset_json_v3
|
| 12 |
-
dataset_meta:
|
| 13 |
-
fake:
|
| 14 |
-
- create_splits: false
|
| 15 |
-
path: bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
|
| 16 |
-
- create_splits: false
|
| 17 |
-
path: bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
|
| 18 |
-
real:
|
| 19 |
-
- create_splits: false
|
| 20 |
-
path: bitmind/celeb-a-hq_training_faces
|
| 21 |
-
- create_splits: false
|
| 22 |
-
path: bitmind/ffhq-256_training_faces
|
| 23 |
-
ddp: false
|
| 24 |
-
dry_run: false
|
| 25 |
-
encoder_feat_dim: 512
|
| 26 |
-
faces_only: true
|
| 27 |
-
frame_num:
|
| 28 |
-
test: 32
|
| 29 |
-
train: 32
|
| 30 |
-
lmdb: true
|
| 31 |
-
lmdb_dir: ./datasets/lmdb
|
| 32 |
-
local_rank: 0
|
| 33 |
-
log_dir: ./logs/training/ucf_2024-09-17-16-44-50
|
| 34 |
-
logdir: ./logs
|
| 35 |
-
loss_func:
|
| 36 |
-
cls_loss: cross_entropy
|
| 37 |
-
con_loss: contrastive_regularization
|
| 38 |
-
rec_loss: l1loss
|
| 39 |
-
spe_loss: cross_entropy
|
| 40 |
-
losstype: null
|
| 41 |
-
lr_scheduler: null
|
| 42 |
-
manualSeed: 1024
|
| 43 |
-
mean:
|
| 44 |
-
- 0.5
|
| 45 |
-
- 0.5
|
| 46 |
-
- 0.5
|
| 47 |
-
metric_scoring: auc
|
| 48 |
-
mode: train
|
| 49 |
-
model_name: ucf
|
| 50 |
-
nEpochs: 2
|
| 51 |
-
optimizer:
|
| 52 |
-
adam:
|
| 53 |
-
amsgrad: false
|
| 54 |
-
beta1: 0.9
|
| 55 |
-
beta2: 0.999
|
| 56 |
-
eps: 1.0e-08
|
| 57 |
-
lr: 0.0002
|
| 58 |
-
weight_decay: 0.0005
|
| 59 |
-
sgd:
|
| 60 |
-
lr: 0.0002
|
| 61 |
-
momentum: 0.9
|
| 62 |
-
weight_decay: 0.0005
|
| 63 |
-
type: adam
|
| 64 |
-
pretrained: ../weights/xception_best.pth
|
| 65 |
-
rec_iter: 100
|
| 66 |
-
resolution: 256
|
| 67 |
-
rgb_dir: ./datasets/rgb
|
| 68 |
-
save_avg: true
|
| 69 |
-
save_ckpt: true
|
| 70 |
-
save_epoch: 1
|
| 71 |
-
save_feat: true
|
| 72 |
-
specific_task_number: 2
|
| 73 |
-
split_transforms:
|
| 74 |
-
test:
|
| 75 |
-
name: base_transforms
|
| 76 |
-
train:
|
| 77 |
-
name: random_aug_transforms
|
| 78 |
-
validation:
|
| 79 |
-
name: base_transforms
|
| 80 |
-
start_epoch: 0
|
| 81 |
-
std:
|
| 82 |
-
- 0.5
|
| 83 |
-
- 0.5
|
| 84 |
-
- 0.5
|
| 85 |
-
test_batchSize: 32
|
| 86 |
-
train_batchSize: 32
|
| 87 |
-
train_dataset:
|
| 88 |
-
- bitmind/celeb-a-hq_training_faces
|
| 89 |
-
- bitmind/ffhq-256_training_faces
|
| 90 |
-
- bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
|
| 91 |
-
- bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
|
| 92 |
-
with_landmark: false
|
| 93 |
-
with_mask: false
|
| 94 |
-
workers: 7
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/config/pretrained_face_config.yaml
DELETED
|
@@ -1,94 +0,0 @@
|
|
| 1 |
-
SWA: false
|
| 2 |
-
backbone_config:
|
| 3 |
-
dropout: false
|
| 4 |
-
inc: 3
|
| 5 |
-
mode: adjust_channel
|
| 6 |
-
num_classes: 2
|
| 7 |
-
backbone_name: xception
|
| 8 |
-
compression: c23
|
| 9 |
-
cuda: true
|
| 10 |
-
cudnn: true
|
| 11 |
-
dataset_json_folder: preprocessing/dataset_json_v3
|
| 12 |
-
dataset_meta:
|
| 13 |
-
fake:
|
| 14 |
-
- create_splits: false
|
| 15 |
-
path: bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
|
| 16 |
-
- create_splits: false
|
| 17 |
-
path: bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
|
| 18 |
-
real:
|
| 19 |
-
- create_splits: false
|
| 20 |
-
path: bitmind/celeb-a-hq_training_faces
|
| 21 |
-
- create_splits: false
|
| 22 |
-
path: bitmind/ffhq-256_training_faces
|
| 23 |
-
ddp: false
|
| 24 |
-
dry_run: false
|
| 25 |
-
encoder_feat_dim: 512
|
| 26 |
-
faces_only: true
|
| 27 |
-
frame_num:
|
| 28 |
-
test: 32
|
| 29 |
-
train: 32
|
| 30 |
-
lmdb: true
|
| 31 |
-
lmdb_dir: ./datasets/lmdb
|
| 32 |
-
local_rank: 0
|
| 33 |
-
log_dir: ./logs/training/ucf_2024-09-17-16-44-50
|
| 34 |
-
logdir: ./logs
|
| 35 |
-
loss_func:
|
| 36 |
-
cls_loss: cross_entropy
|
| 37 |
-
con_loss: contrastive_regularization
|
| 38 |
-
rec_loss: l1loss
|
| 39 |
-
spe_loss: cross_entropy
|
| 40 |
-
losstype: null
|
| 41 |
-
lr_scheduler: null
|
| 42 |
-
manualSeed: 1024
|
| 43 |
-
mean:
|
| 44 |
-
- 0.5
|
| 45 |
-
- 0.5
|
| 46 |
-
- 0.5
|
| 47 |
-
metric_scoring: auc
|
| 48 |
-
mode: train
|
| 49 |
-
model_name: ucf
|
| 50 |
-
nEpochs: 2
|
| 51 |
-
optimizer:
|
| 52 |
-
adam:
|
| 53 |
-
amsgrad: false
|
| 54 |
-
beta1: 0.9
|
| 55 |
-
beta2: 0.999
|
| 56 |
-
eps: 1.0e-08
|
| 57 |
-
lr: 0.0002
|
| 58 |
-
weight_decay: 0.0005
|
| 59 |
-
sgd:
|
| 60 |
-
lr: 0.0002
|
| 61 |
-
momentum: 0.9
|
| 62 |
-
weight_decay: 0.0005
|
| 63 |
-
type: adam
|
| 64 |
-
pretrained: ../weights/xception_best.pth
|
| 65 |
-
rec_iter: 100
|
| 66 |
-
resolution: 256
|
| 67 |
-
rgb_dir: ./datasets/rgb
|
| 68 |
-
save_avg: true
|
| 69 |
-
save_ckpt: true
|
| 70 |
-
save_epoch: 1
|
| 71 |
-
save_feat: true
|
| 72 |
-
specific_task_number: 2
|
| 73 |
-
split_transforms:
|
| 74 |
-
test:
|
| 75 |
-
name: base_transforms
|
| 76 |
-
train:
|
| 77 |
-
name: random_aug_transforms
|
| 78 |
-
validation:
|
| 79 |
-
name: base_transforms
|
| 80 |
-
start_epoch: 0
|
| 81 |
-
std:
|
| 82 |
-
- 0.5
|
| 83 |
-
- 0.5
|
| 84 |
-
- 0.5
|
| 85 |
-
test_batchSize: 32
|
| 86 |
-
train_batchSize: 32
|
| 87 |
-
train_dataset:
|
| 88 |
-
- bitmind/celeb-a-hq_training_faces
|
| 89 |
-
- bitmind/ffhq-256_training_faces
|
| 90 |
-
- bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
|
| 91 |
-
- bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
|
| 92 |
-
with_landmark: false
|
| 93 |
-
with_mask: false
|
| 94 |
-
workers: 7
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/config/train_config.yaml
DELETED
|
@@ -1,9 +0,0 @@
|
|
| 1 |
-
mode: train
|
| 2 |
-
lmdb: True
|
| 3 |
-
dry_run: false
|
| 4 |
-
rgb_dir: './datasets/rgb'
|
| 5 |
-
lmdb_dir: './datasets/lmdb'
|
| 6 |
-
dataset_json_folder: './preprocessing/dataset_json'
|
| 7 |
-
SWA: False
|
| 8 |
-
save_avg: True
|
| 9 |
-
log_dir: ./logs/training/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/config/ucf.yaml
DELETED
|
@@ -1,73 +0,0 @@
|
|
| 1 |
-
# log dir
|
| 2 |
-
log_dir: ../debug_logs/ucf
|
| 3 |
-
|
| 4 |
-
# model setting
|
| 5 |
-
pretrained: ../weights/xception_best.pth # path to a pre-trained model, if using one
|
| 6 |
-
model_name: ucf # model name
|
| 7 |
-
backbone_name: xception # backbone name
|
| 8 |
-
encoder_feat_dim: 512 # feature dimension of the backbone
|
| 9 |
-
|
| 10 |
-
#backbone setting
|
| 11 |
-
backbone_config:
|
| 12 |
-
mode: adjust_channel
|
| 13 |
-
num_classes: 2
|
| 14 |
-
inc: 3
|
| 15 |
-
dropout: false
|
| 16 |
-
|
| 17 |
-
compression: c23 # compression-level for videos
|
| 18 |
-
train_batchSize: 32 # training batch size
|
| 19 |
-
test_batchSize: 32 # test batch size
|
| 20 |
-
workers: 8 # number of data loading workers
|
| 21 |
-
frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing
|
| 22 |
-
resolution: 256 # resolution of output image to network
|
| 23 |
-
with_mask: false # whether to include mask information in the input
|
| 24 |
-
with_landmark: false # whether to include facial landmark information in the input
|
| 25 |
-
save_ckpt: true # whether to save checkpoint
|
| 26 |
-
save_feat: true # whether to save features
|
| 27 |
-
specific_task_number: 2 # default num datasets in FF++ used by DFB, overwritten in training
|
| 28 |
-
|
| 29 |
-
# mean and std for normalization
|
| 30 |
-
mean: [0.5, 0.5, 0.5]
|
| 31 |
-
std: [0.5, 0.5, 0.5]
|
| 32 |
-
|
| 33 |
-
# optimizer config
|
| 34 |
-
optimizer:
|
| 35 |
-
# choose between 'adam' and 'sgd'
|
| 36 |
-
type: adam
|
| 37 |
-
adam:
|
| 38 |
-
lr: 0.0002 # learning rate
|
| 39 |
-
beta1: 0.9 # beta1 for Adam optimizer
|
| 40 |
-
beta2: 0.999 # beta2 for Adam optimizer
|
| 41 |
-
eps: 0.00000001 # epsilon for Adam optimizer
|
| 42 |
-
weight_decay: 0.0005 # weight decay for regularization
|
| 43 |
-
amsgrad: false
|
| 44 |
-
sgd:
|
| 45 |
-
lr: 0.0002 # learning rate
|
| 46 |
-
momentum: 0.9 # momentum for SGD optimizer
|
| 47 |
-
weight_decay: 0.0005 # weight decay for regularization
|
| 48 |
-
|
| 49 |
-
# training config
|
| 50 |
-
lr_scheduler: null # learning rate scheduler
|
| 51 |
-
nEpochs: 20 # number of epochs to train for
|
| 52 |
-
start_epoch: 0 # manual epoch number (useful for restarts)
|
| 53 |
-
save_epoch: 1 # interval epochs for saving models
|
| 54 |
-
rec_iter: 100 # interval iterations for recording
|
| 55 |
-
logdir: ./logs # folder to output images and logs
|
| 56 |
-
manualSeed: 1024 # manual seed for random number generation
|
| 57 |
-
save_ckpt: false # whether to save checkpoint
|
| 58 |
-
|
| 59 |
-
# loss function
|
| 60 |
-
loss_func:
|
| 61 |
-
cls_loss: cross_entropy # loss function to use
|
| 62 |
-
spe_loss: cross_entropy
|
| 63 |
-
con_loss: contrastive_regularization
|
| 64 |
-
rec_loss: l1loss
|
| 65 |
-
losstype: null
|
| 66 |
-
|
| 67 |
-
# metric
|
| 68 |
-
metric_scoring: auc # metric for evaluation (auc, acc, eer, ap)
|
| 69 |
-
|
| 70 |
-
# cuda
|
| 71 |
-
|
| 72 |
-
cuda: true # whether to use CUDA acceleration
|
| 73 |
-
cudnn: true # whether to use CuDNN for convolution operations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/config/xception.yaml
DELETED
|
@@ -1,86 +0,0 @@
|
|
| 1 |
-
# log dir
|
| 2 |
-
log_dir: /data/home/zhiyuanyan/DeepfakeBench/logs/testing_bench
|
| 3 |
-
|
| 4 |
-
# model setting
|
| 5 |
-
pretrained: /data/home/zhiyuanyan/DeepfakeBench/training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one
|
| 6 |
-
model_name: xception # model name
|
| 7 |
-
backbone_name: xception # backbone name
|
| 8 |
-
|
| 9 |
-
#backbone setting
|
| 10 |
-
backbone_config:
|
| 11 |
-
mode: original
|
| 12 |
-
num_classes: 2
|
| 13 |
-
inc: 3
|
| 14 |
-
dropout: false
|
| 15 |
-
|
| 16 |
-
# dataset
|
| 17 |
-
all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV]
|
| 18 |
-
train_dataset: [FaceForensics++]
|
| 19 |
-
test_dataset: [FaceForensics++, DeepFakeDetection]
|
| 20 |
-
|
| 21 |
-
compression: c23 # compression-level for videos
|
| 22 |
-
train_batchSize: 32 # training batch size
|
| 23 |
-
test_batchSize: 32 # test batch size
|
| 24 |
-
workers: 8 # number of data loading workers
|
| 25 |
-
frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing
|
| 26 |
-
resolution: 256 # resolution of output image to network
|
| 27 |
-
with_mask: false # whether to include mask information in the input
|
| 28 |
-
with_landmark: false # whether to include facial landmark information in the input
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
# data augmentation
|
| 32 |
-
use_data_augmentation: true # Add this flag to enable/disable data augmentation
|
| 33 |
-
data_aug:
|
| 34 |
-
flip_prob: 0.5
|
| 35 |
-
rotate_prob: 0.0
|
| 36 |
-
rotate_limit: [-10, 10]
|
| 37 |
-
blur_prob: 0.5
|
| 38 |
-
blur_limit: [3, 7]
|
| 39 |
-
brightness_prob: 0.5
|
| 40 |
-
brightness_limit: [-0.1, 0.1]
|
| 41 |
-
contrast_limit: [-0.1, 0.1]
|
| 42 |
-
quality_lower: 40
|
| 43 |
-
quality_upper: 100
|
| 44 |
-
|
| 45 |
-
# mean and std for normalization
|
| 46 |
-
mean: [0.5, 0.5, 0.5]
|
| 47 |
-
std: [0.5, 0.5, 0.5]
|
| 48 |
-
|
| 49 |
-
# optimizer config
|
| 50 |
-
optimizer:
|
| 51 |
-
# choose between 'adam' and 'sgd'
|
| 52 |
-
type: adam
|
| 53 |
-
adam:
|
| 54 |
-
lr: 0.0002 # learning rate
|
| 55 |
-
beta1: 0.9 # beta1 for Adam optimizer
|
| 56 |
-
beta2: 0.999 # beta2 for Adam optimizer
|
| 57 |
-
eps: 0.00000001 # epsilon for Adam optimizer
|
| 58 |
-
weight_decay: 0.0005 # weight decay for regularization
|
| 59 |
-
amsgrad: false
|
| 60 |
-
sgd:
|
| 61 |
-
lr: 0.0002 # learning rate
|
| 62 |
-
momentum: 0.9 # momentum for SGD optimizer
|
| 63 |
-
weight_decay: 0.0005 # weight decay for regularization
|
| 64 |
-
|
| 65 |
-
# training config
|
| 66 |
-
lr_scheduler: null # learning rate scheduler
|
| 67 |
-
nEpochs: 10 # number of epochs to train for
|
| 68 |
-
start_epoch: 0 # manual epoch number (useful for restarts)
|
| 69 |
-
save_epoch: 1 # interval epochs for saving models
|
| 70 |
-
rec_iter: 100 # interval iterations for recording
|
| 71 |
-
logdir: ./logs # folder to output images and logs
|
| 72 |
-
manualSeed: 1024 # manual seed for random number generation
|
| 73 |
-
save_ckpt: true # whether to save checkpoint
|
| 74 |
-
save_feat: true # whether to save features
|
| 75 |
-
|
| 76 |
-
# loss function
|
| 77 |
-
loss_func: cross_entropy # loss function to use
|
| 78 |
-
losstype: null
|
| 79 |
-
|
| 80 |
-
# metric
|
| 81 |
-
metric_scoring: auc # metric for evaluation (auc, acc, eer, ap)
|
| 82 |
-
|
| 83 |
-
# cuda
|
| 84 |
-
|
| 85 |
-
cuda: true # whether to use CUDA acceleration
|
| 86 |
-
cudnn: true # whether to use CuDNN for convolution operations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/detectors/__init__.py
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
current_file_path = os.path.abspath(__file__)
|
| 4 |
-
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
|
| 5 |
-
project_root_dir = os.path.dirname(parent_dir)
|
| 6 |
-
sys.path.append(parent_dir)
|
| 7 |
-
sys.path.append(project_root_dir)
|
| 8 |
-
|
| 9 |
-
from metrics.registry import DETECTOR
|
| 10 |
-
|
| 11 |
-
from .ucf_detector import UCFDetector
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/detectors/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (455 Bytes)
|
|
|
test_deps/detectors/__pycache__/base_detector.cpython-310.pyc
DELETED
|
Binary file (2.57 kB)
|
|
|
test_deps/detectors/__pycache__/ucf_detector.cpython-310.pyc
DELETED
|
Binary file (12.9 kB)
|
|
|
test_deps/detectors/base_detector.py
DELETED
|
@@ -1,71 +0,0 @@
|
|
| 1 |
-
# author: Zhiyuan Yan
|
| 2 |
-
# email: zhiyuanyan@link.cuhk.edu.cn
|
| 3 |
-
# date: 2023-0706
|
| 4 |
-
# description: Abstract Class for the Deepfake Detector
|
| 5 |
-
|
| 6 |
-
import abc
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn as nn
|
| 9 |
-
from typing import Union
|
| 10 |
-
|
| 11 |
-
class AbstractDetector(nn.Module, metaclass=abc.ABCMeta):
|
| 12 |
-
"""
|
| 13 |
-
All deepfake detectors should subclass this class.
|
| 14 |
-
"""
|
| 15 |
-
def __init__(self, config=None, load_param: Union[bool, str] = False):
|
| 16 |
-
"""
|
| 17 |
-
config: (dict)
|
| 18 |
-
configurations for the model
|
| 19 |
-
load_param: (False | True | Path(str))
|
| 20 |
-
False Do not read; True Read the default path; Path Read the required path
|
| 21 |
-
"""
|
| 22 |
-
super().__init__()
|
| 23 |
-
|
| 24 |
-
@abc.abstractmethod
|
| 25 |
-
def features(self, data_dict: dict) -> torch.tensor:
|
| 26 |
-
"""
|
| 27 |
-
Returns the features from the backbone given the input data.
|
| 28 |
-
"""
|
| 29 |
-
pass
|
| 30 |
-
|
| 31 |
-
@abc.abstractmethod
|
| 32 |
-
def forward(self, data_dict: dict, inference=False) -> dict:
|
| 33 |
-
"""
|
| 34 |
-
Forward pass through the model, returning the prediction dictionary.
|
| 35 |
-
"""
|
| 36 |
-
pass
|
| 37 |
-
|
| 38 |
-
@abc.abstractmethod
|
| 39 |
-
def classifier(self, features: torch.tensor) -> torch.tensor:
|
| 40 |
-
"""
|
| 41 |
-
Classifies the features into classes.
|
| 42 |
-
"""
|
| 43 |
-
pass
|
| 44 |
-
|
| 45 |
-
@abc.abstractmethod
|
| 46 |
-
def build_backbone(self, config):
|
| 47 |
-
"""
|
| 48 |
-
Builds the backbone of the model.
|
| 49 |
-
"""
|
| 50 |
-
pass
|
| 51 |
-
|
| 52 |
-
@abc.abstractmethod
|
| 53 |
-
def build_loss(self, config):
|
| 54 |
-
"""
|
| 55 |
-
Builds the loss function for the model.
|
| 56 |
-
"""
|
| 57 |
-
pass
|
| 58 |
-
|
| 59 |
-
@abc.abstractmethod
|
| 60 |
-
def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
|
| 61 |
-
"""
|
| 62 |
-
Returns the losses for the model.
|
| 63 |
-
"""
|
| 64 |
-
pass
|
| 65 |
-
|
| 66 |
-
@abc.abstractmethod
|
| 67 |
-
def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
|
| 68 |
-
"""
|
| 69 |
-
Returns the training metrics for the model.
|
| 70 |
-
"""
|
| 71 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/detectors/ucf_detector.py
DELETED
|
@@ -1,472 +0,0 @@
|
|
| 1 |
-
'''
|
| 2 |
-
# Source: https://github.com/SCLBD/DeepfakeBench/blob/main/training/detectors/ucf_detector.py
|
| 3 |
-
# author: Zhiyuan Yan
|
| 4 |
-
# email: zhiyuanyan@link.cuhk.edu.cn
|
| 5 |
-
# date: 2023-0706
|
| 6 |
-
# description: Class for the UCFDetector
|
| 7 |
-
|
| 8 |
-
Functions in the Class are summarized as:
|
| 9 |
-
1. __init__: Initialization
|
| 10 |
-
2. build_backbone: Backbone-building
|
| 11 |
-
3. build_loss: Loss-function-building
|
| 12 |
-
4. features: Feature-extraction
|
| 13 |
-
5. classifier: Classification
|
| 14 |
-
6. get_losses: Loss-computation
|
| 15 |
-
7. get_train_metrics: Training-metrics-computation
|
| 16 |
-
8. get_test_metrics: Testing-metrics-computation
|
| 17 |
-
9. forward: Forward-propagation
|
| 18 |
-
|
| 19 |
-
Reference:
|
| 20 |
-
@article{yan2023ucf,
|
| 21 |
-
title={UCF: Uncovering Common Features for Generalizable Deepfake Detection},
|
| 22 |
-
author={Yan, Zhiyuan and Zhang, Yong and Fan, Yanbo and Wu, Baoyuan},
|
| 23 |
-
journal={arXiv preprint arXiv:2304.13949},
|
| 24 |
-
year={2023}
|
| 25 |
-
}
|
| 26 |
-
'''
|
| 27 |
-
|
| 28 |
-
import os
|
| 29 |
-
import datetime
|
| 30 |
-
import logging
|
| 31 |
-
import random
|
| 32 |
-
import numpy as np
|
| 33 |
-
from sklearn import metrics
|
| 34 |
-
from typing import Union
|
| 35 |
-
from collections import defaultdict
|
| 36 |
-
|
| 37 |
-
import torch
|
| 38 |
-
import torch.nn as nn
|
| 39 |
-
import torch.nn.functional as F
|
| 40 |
-
import torch.optim as optim
|
| 41 |
-
from torch.nn import DataParallel
|
| 42 |
-
from torch.utils.tensorboard import SummaryWriter
|
| 43 |
-
|
| 44 |
-
from metrics.base_metrics_class import calculate_metrics_for_train
|
| 45 |
-
|
| 46 |
-
from .base_detector import AbstractDetector
|
| 47 |
-
from arena.detectors.UCF.detectors import DETECTOR
|
| 48 |
-
from networks import BACKBONE
|
| 49 |
-
from loss import LOSSFUNC
|
| 50 |
-
|
| 51 |
-
logger = logging.getLogger(__name__)
|
| 52 |
-
|
| 53 |
-
@DETECTOR.register_module(module_name='ucf')
|
| 54 |
-
class UCFDetector(AbstractDetector):
|
| 55 |
-
def __init__(self, config):
|
| 56 |
-
super().__init__()
|
| 57 |
-
self.config = config
|
| 58 |
-
self.num_classes = config['backbone_config']['num_classes']
|
| 59 |
-
self.encoder_feat_dim = config['encoder_feat_dim']
|
| 60 |
-
self.half_fingerprint_dim = self.encoder_feat_dim//2
|
| 61 |
-
|
| 62 |
-
self.encoder_f = self.build_backbone(config)
|
| 63 |
-
self.encoder_c = self.build_backbone(config)
|
| 64 |
-
|
| 65 |
-
self.loss_func = self.build_loss(config)
|
| 66 |
-
self.prob, self.label = [], []
|
| 67 |
-
self.correct, self.total = 0, 0
|
| 68 |
-
|
| 69 |
-
# basic function
|
| 70 |
-
self.lr = nn.LeakyReLU(inplace=True)
|
| 71 |
-
self.do = nn.Dropout(0.2)
|
| 72 |
-
self.pool = nn.AdaptiveAvgPool2d(1)
|
| 73 |
-
|
| 74 |
-
# conditional gan
|
| 75 |
-
self.con_gan = Conditional_UNet()
|
| 76 |
-
|
| 77 |
-
# head
|
| 78 |
-
specific_task_number = config['specific_task_number']
|
| 79 |
-
|
| 80 |
-
self.head_spe = Head(
|
| 81 |
-
in_f=self.half_fingerprint_dim,
|
| 82 |
-
hidden_dim=self.encoder_feat_dim,
|
| 83 |
-
out_f=specific_task_number
|
| 84 |
-
)
|
| 85 |
-
self.head_sha = Head(
|
| 86 |
-
in_f=self.half_fingerprint_dim,
|
| 87 |
-
hidden_dim=self.encoder_feat_dim,
|
| 88 |
-
out_f=self.num_classes
|
| 89 |
-
)
|
| 90 |
-
self.block_spe = Conv2d1x1(
|
| 91 |
-
in_f=self.encoder_feat_dim,
|
| 92 |
-
hidden_dim=self.half_fingerprint_dim,
|
| 93 |
-
out_f=self.half_fingerprint_dim
|
| 94 |
-
)
|
| 95 |
-
self.block_sha = Conv2d1x1(
|
| 96 |
-
in_f=self.encoder_feat_dim,
|
| 97 |
-
hidden_dim=self.half_fingerprint_dim,
|
| 98 |
-
out_f=self.half_fingerprint_dim
|
| 99 |
-
)
|
| 100 |
-
|
| 101 |
-
def build_backbone(self, config):
|
| 102 |
-
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 103 |
-
pretrained_path = os.path.join(current_dir, config['pretrained'])
|
| 104 |
-
# prepare the backbone
|
| 105 |
-
backbone_class = BACKBONE[config['backbone_name']]
|
| 106 |
-
model_config = config['backbone_config']
|
| 107 |
-
backbone = backbone_class(model_config)
|
| 108 |
-
# if donot load the pretrained weights, fail to get good results
|
| 109 |
-
state_dict = torch.load(pretrained_path)
|
| 110 |
-
for name, weights in state_dict.items():
|
| 111 |
-
if 'pointwise' in name:
|
| 112 |
-
state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)
|
| 113 |
-
state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k}
|
| 114 |
-
backbone.load_state_dict(state_dict, False)
|
| 115 |
-
logger.info('Load pretrained model successfully!')
|
| 116 |
-
return backbone
|
| 117 |
-
|
| 118 |
-
def build_loss(self, config):
|
| 119 |
-
cls_loss_class = LOSSFUNC[config['loss_func']['cls_loss']]
|
| 120 |
-
spe_loss_class = LOSSFUNC[config['loss_func']['spe_loss']]
|
| 121 |
-
con_loss_class = LOSSFUNC[config['loss_func']['con_loss']]
|
| 122 |
-
rec_loss_class = LOSSFUNC[config['loss_func']['rec_loss']]
|
| 123 |
-
cls_loss_func = cls_loss_class()
|
| 124 |
-
spe_loss_func = spe_loss_class()
|
| 125 |
-
con_loss_func = con_loss_class(margin=3.0)
|
| 126 |
-
rec_loss_func = rec_loss_class()
|
| 127 |
-
loss_func = {
|
| 128 |
-
'cls': cls_loss_func,
|
| 129 |
-
'spe': spe_loss_func,
|
| 130 |
-
'con': con_loss_func,
|
| 131 |
-
'rec': rec_loss_func,
|
| 132 |
-
}
|
| 133 |
-
return loss_func
|
| 134 |
-
|
| 135 |
-
def features(self, data_dict: dict) -> torch.tensor:
|
| 136 |
-
cat_data = data_dict['image']
|
| 137 |
-
# encoder
|
| 138 |
-
f_all = self.encoder_f.features(cat_data)
|
| 139 |
-
c_all = self.encoder_c.features(cat_data)
|
| 140 |
-
feat_dict = {'forgery': f_all, 'content': c_all}
|
| 141 |
-
return feat_dict
|
| 142 |
-
|
| 143 |
-
def classifier(self, features: torch.tensor) -> torch.tensor:
|
| 144 |
-
# classification, multi-task
|
| 145 |
-
# split the features into the specific and common forgery
|
| 146 |
-
f_spe = self.block_spe(features)
|
| 147 |
-
f_share = self.block_sha(features)
|
| 148 |
-
return f_spe, f_share
|
| 149 |
-
|
| 150 |
-
def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
|
| 151 |
-
if 'label_spe' in data_dict and 'recontruction_imgs' in pred_dict:
|
| 152 |
-
return self.get_train_losses(data_dict, pred_dict)
|
| 153 |
-
else: # test mode
|
| 154 |
-
return self.get_test_losses(data_dict, pred_dict)
|
| 155 |
-
|
| 156 |
-
def get_train_losses(self, data_dict: dict, pred_dict: dict) -> dict:
|
| 157 |
-
# get combined, real, fake imgs
|
| 158 |
-
cat_data = data_dict['image']
|
| 159 |
-
real_img, fake_img = cat_data.chunk(2, dim=0)
|
| 160 |
-
# get the reconstruction imgs
|
| 161 |
-
reconstruction_image_1, \
|
| 162 |
-
reconstruction_image_2, \
|
| 163 |
-
self_reconstruction_image_1, \
|
| 164 |
-
self_reconstruction_image_2 \
|
| 165 |
-
= pred_dict['recontruction_imgs']
|
| 166 |
-
# get label
|
| 167 |
-
label = data_dict['label']
|
| 168 |
-
label_spe = data_dict['label_spe']
|
| 169 |
-
# get pred
|
| 170 |
-
pred = pred_dict['cls']
|
| 171 |
-
pred_spe = pred_dict['cls_spe']
|
| 172 |
-
|
| 173 |
-
# 1. classification loss for common features
|
| 174 |
-
loss_sha = self.loss_func['cls'](pred, label)
|
| 175 |
-
|
| 176 |
-
# 2. classification loss for specific features
|
| 177 |
-
loss_spe = self.loss_func['spe'](pred_spe, label_spe)
|
| 178 |
-
|
| 179 |
-
# 3. reconstruction loss
|
| 180 |
-
self_loss_reconstruction_1 = self.loss_func['rec'](fake_img, self_reconstruction_image_1)
|
| 181 |
-
self_loss_reconstruction_2 = self.loss_func['rec'](real_img, self_reconstruction_image_2)
|
| 182 |
-
cross_loss_reconstruction_1 = self.loss_func['rec'](fake_img, reconstruction_image_2)
|
| 183 |
-
cross_loss_reconstruction_2 = self.loss_func['rec'](real_img, reconstruction_image_1)
|
| 184 |
-
loss_reconstruction = \
|
| 185 |
-
self_loss_reconstruction_1 + self_loss_reconstruction_2 + \
|
| 186 |
-
cross_loss_reconstruction_1 + cross_loss_reconstruction_2
|
| 187 |
-
|
| 188 |
-
# 4. constrative loss
|
| 189 |
-
common_features = pred_dict['feat']
|
| 190 |
-
specific_features = pred_dict['feat_spe']
|
| 191 |
-
loss_con = self.loss_func['con'](common_features, specific_features, label_spe)
|
| 192 |
-
|
| 193 |
-
# 5. total loss
|
| 194 |
-
loss = loss_sha + 0.1*loss_spe + 0.3*loss_reconstruction + 0.05*loss_con
|
| 195 |
-
loss_dict = {
|
| 196 |
-
'overall': loss,
|
| 197 |
-
'common': loss_sha,
|
| 198 |
-
'specific': loss_spe,
|
| 199 |
-
'reconstruction': loss_reconstruction,
|
| 200 |
-
'contrastive': loss_con,
|
| 201 |
-
}
|
| 202 |
-
return loss_dict
|
| 203 |
-
|
| 204 |
-
def get_test_losses(self, data_dict: dict, pred_dict: dict) -> dict:
|
| 205 |
-
# get label
|
| 206 |
-
label = data_dict['label']
|
| 207 |
-
# get pred
|
| 208 |
-
pred = pred_dict['cls']
|
| 209 |
-
# for test mode, only classification loss for common features
|
| 210 |
-
loss = self.loss_func['cls'](pred, label)
|
| 211 |
-
loss_dict = {'common': loss}
|
| 212 |
-
return loss_dict
|
| 213 |
-
|
| 214 |
-
def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
|
| 215 |
-
def get_accracy(label, output):
|
| 216 |
-
_, prediction = torch.max(output, 1) # argmax
|
| 217 |
-
correct = (prediction == label).sum().item()
|
| 218 |
-
accuracy = correct / prediction.size(0)
|
| 219 |
-
return accuracy
|
| 220 |
-
|
| 221 |
-
# get pred and label
|
| 222 |
-
label = data_dict['label']
|
| 223 |
-
pred = pred_dict['cls']
|
| 224 |
-
label_spe = data_dict['label_spe']
|
| 225 |
-
pred_spe = pred_dict['cls_spe']
|
| 226 |
-
|
| 227 |
-
# compute metrics for batch data
|
| 228 |
-
auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
|
| 229 |
-
acc_spe = get_accracy(label_spe.detach(), pred_spe.detach())
|
| 230 |
-
metric_batch_dict = {'acc': acc, 'acc_spe': acc_spe, 'auc': auc, 'eer': eer, 'ap': ap}
|
| 231 |
-
# we dont compute the video-level metrics for training
|
| 232 |
-
return metric_batch_dict
|
| 233 |
-
|
| 234 |
-
def forward(self, data_dict: dict, inference=False) -> dict:
|
| 235 |
-
# split the features into the content and forgery
|
| 236 |
-
features = self.features(data_dict)
|
| 237 |
-
forgery_features, content_features = features['forgery'], features['content']
|
| 238 |
-
# get the prediction by classifier (split the common and specific forgery)
|
| 239 |
-
f_spe, f_share = self.classifier(forgery_features)
|
| 240 |
-
|
| 241 |
-
if inference:
|
| 242 |
-
# inference only consider share loss
|
| 243 |
-
out_sha, sha_feat = self.head_sha(f_share)
|
| 244 |
-
out_spe, spe_feat = self.head_spe(f_spe)
|
| 245 |
-
prob_sha = torch.softmax(out_sha, dim=1)[:, 1]
|
| 246 |
-
self.prob.append(
|
| 247 |
-
prob_sha
|
| 248 |
-
.detach()
|
| 249 |
-
.squeeze()
|
| 250 |
-
.cpu()
|
| 251 |
-
.numpy()
|
| 252 |
-
)
|
| 253 |
-
_, prediction_class = torch.max(out_sha, 1)
|
| 254 |
-
if 'label' in data_dict:
|
| 255 |
-
self.label.append(
|
| 256 |
-
data_dict['label']
|
| 257 |
-
.detach()
|
| 258 |
-
.squeeze()
|
| 259 |
-
.cpu()
|
| 260 |
-
.numpy()
|
| 261 |
-
)
|
| 262 |
-
# deal with acc
|
| 263 |
-
common_label = (data_dict['label'] >= 1)
|
| 264 |
-
correct = (prediction_class == common_label).sum().item()
|
| 265 |
-
self.correct += correct
|
| 266 |
-
self.total += data_dict['label'].size(0)
|
| 267 |
-
|
| 268 |
-
pred_dict = {'cls': out_sha, 'feat': sha_feat}
|
| 269 |
-
return pred_dict
|
| 270 |
-
|
| 271 |
-
bs = f_share.size(0)
|
| 272 |
-
# using idx aug in the training mode
|
| 273 |
-
aug_idx = random.random()
|
| 274 |
-
if aug_idx < 0.7:
|
| 275 |
-
# real
|
| 276 |
-
idx_list = list(range(0, bs//2))
|
| 277 |
-
random.shuffle(idx_list)
|
| 278 |
-
f_share[0: bs//2] = f_share[idx_list]
|
| 279 |
-
# fake
|
| 280 |
-
idx_list = list(range(bs//2, bs))
|
| 281 |
-
random.shuffle(idx_list)
|
| 282 |
-
f_share[bs//2: bs] = f_share[idx_list]
|
| 283 |
-
|
| 284 |
-
# concat spe and share to obtain new_f_all
|
| 285 |
-
f_all = torch.cat((f_spe, f_share), dim=1)
|
| 286 |
-
|
| 287 |
-
# reconstruction loss
|
| 288 |
-
f2, f1 = f_all.chunk(2, dim=0)
|
| 289 |
-
c2, c1 = content_features.chunk(2, dim=0)
|
| 290 |
-
|
| 291 |
-
# ==== self reconstruction ==== #
|
| 292 |
-
# f1 + c1 -> f11, f11 + c1 -> near~I1
|
| 293 |
-
self_reconstruction_image_1 = self.con_gan(f1, c1)
|
| 294 |
-
|
| 295 |
-
# f2 + c2 -> f2, f2 + c2 -> near~I2
|
| 296 |
-
self_reconstruction_image_2 = self.con_gan(f2, c2)
|
| 297 |
-
|
| 298 |
-
# ==== cross combine ==== #
|
| 299 |
-
reconstruction_image_1 = self.con_gan(f1, c2)
|
| 300 |
-
reconstruction_image_2 = self.con_gan(f2, c1)
|
| 301 |
-
|
| 302 |
-
# head for spe and sha
|
| 303 |
-
out_spe, spe_feat = self.head_spe(f_spe)
|
| 304 |
-
out_sha, sha_feat = self.head_sha(f_share)
|
| 305 |
-
|
| 306 |
-
# get the probability of the pred
|
| 307 |
-
prob_sha = torch.softmax(out_sha, dim=1)[:, 1]
|
| 308 |
-
prob_spe = torch.softmax(out_spe, dim=1)[:, 1]
|
| 309 |
-
|
| 310 |
-
# build the prediction dict for each output
|
| 311 |
-
pred_dict = {
|
| 312 |
-
'cls': out_sha,
|
| 313 |
-
'prob': prob_sha,
|
| 314 |
-
'feat': sha_feat,
|
| 315 |
-
'cls_spe': out_spe,
|
| 316 |
-
'prob_spe': prob_spe,
|
| 317 |
-
'feat_spe': spe_feat,
|
| 318 |
-
'feat_content': content_features,
|
| 319 |
-
'recontruction_imgs': (
|
| 320 |
-
reconstruction_image_1,
|
| 321 |
-
reconstruction_image_2,
|
| 322 |
-
self_reconstruction_image_1,
|
| 323 |
-
self_reconstruction_image_2
|
| 324 |
-
)
|
| 325 |
-
}
|
| 326 |
-
return pred_dict
|
| 327 |
-
|
| 328 |
-
def sn_double_conv(in_channels, out_channels):
|
| 329 |
-
return nn.Sequential(
|
| 330 |
-
nn.utils.spectral_norm(
|
| 331 |
-
nn.Conv2d(in_channels, in_channels, 3, padding=1)),
|
| 332 |
-
nn.utils.spectral_norm(
|
| 333 |
-
nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=2)),
|
| 334 |
-
nn.LeakyReLU(0.2, inplace=True)
|
| 335 |
-
)
|
| 336 |
-
|
| 337 |
-
def r_double_conv(in_channels, out_channels):
|
| 338 |
-
return nn.Sequential(
|
| 339 |
-
nn.Conv2d(in_channels, out_channels, 3, padding=1),
|
| 340 |
-
nn.ReLU(inplace=True),
|
| 341 |
-
nn.Conv2d(out_channels, out_channels, 3, padding=1),
|
| 342 |
-
nn.ReLU(inplace=True)
|
| 343 |
-
)
|
| 344 |
-
|
| 345 |
-
class AdaIN(nn.Module):
|
| 346 |
-
def __init__(self, eps=1e-5):
|
| 347 |
-
super().__init__()
|
| 348 |
-
self.eps = eps
|
| 349 |
-
# self.l1 = nn.Linear(num_classes, in_channel*4, bias=True) #bias is good :)
|
| 350 |
-
|
| 351 |
-
def c_norm(self, x, bs, ch, eps=1e-7):
|
| 352 |
-
# assert isinstance(x, torch.cuda.FloatTensor)
|
| 353 |
-
x_var = x.var(dim=-1) + eps
|
| 354 |
-
x_std = x_var.sqrt().view(bs, ch, 1, 1)
|
| 355 |
-
x_mean = x.mean(dim=-1).view(bs, ch, 1, 1)
|
| 356 |
-
return x_std, x_mean
|
| 357 |
-
|
| 358 |
-
def forward(self, x, y):
|
| 359 |
-
assert x.size(0)==y.size(0)
|
| 360 |
-
size = x.size()
|
| 361 |
-
bs, ch = size[:2]
|
| 362 |
-
x_ = x.view(bs, ch, -1)
|
| 363 |
-
y_ = y.reshape(bs, ch, -1)
|
| 364 |
-
x_std, x_mean = self.c_norm(x_, bs, ch, eps=self.eps)
|
| 365 |
-
y_std, y_mean = self.c_norm(y_, bs, ch, eps=self.eps)
|
| 366 |
-
out = ((x - x_mean.expand(size)) / x_std.expand(size)) \
|
| 367 |
-
* y_std.expand(size) + y_mean.expand(size)
|
| 368 |
-
return out
|
| 369 |
-
|
| 370 |
-
class Conditional_UNet(nn.Module):
|
| 371 |
-
|
| 372 |
-
def init_weight(self, std=0.2):
|
| 373 |
-
for m in self.modules():
|
| 374 |
-
cn = m.__class__.__name__
|
| 375 |
-
if cn.find('Conv') != -1:
|
| 376 |
-
m.weight.data.normal_(0., std)
|
| 377 |
-
elif cn.find('Linear') != -1:
|
| 378 |
-
m.weight.data.normal_(1., std)
|
| 379 |
-
m.bias.data.fill_(0)
|
| 380 |
-
|
| 381 |
-
def __init__(self):
|
| 382 |
-
super(Conditional_UNet, self).__init__()
|
| 383 |
-
|
| 384 |
-
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
| 385 |
-
self.maxpool = nn.MaxPool2d(2)
|
| 386 |
-
self.dropout = nn.Dropout(p=0.3)
|
| 387 |
-
#self.dropout_half = HalfDropout(p=0.3)
|
| 388 |
-
|
| 389 |
-
self.adain3 = AdaIN()
|
| 390 |
-
self.adain2 = AdaIN()
|
| 391 |
-
self.adain1 = AdaIN()
|
| 392 |
-
|
| 393 |
-
self.dconv_up3 = r_double_conv(512, 256)
|
| 394 |
-
self.dconv_up2 = r_double_conv(256, 128)
|
| 395 |
-
self.dconv_up1 = r_double_conv(128, 64)
|
| 396 |
-
|
| 397 |
-
self.conv_last = nn.Conv2d(64, 3, 1)
|
| 398 |
-
self.up_last = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
|
| 399 |
-
self.activation = nn.Tanh()
|
| 400 |
-
#self.init_weight()
|
| 401 |
-
|
| 402 |
-
def forward(self, c, x): # c is the style and x is the content
|
| 403 |
-
x = self.adain3(x, c)
|
| 404 |
-
x = self.upsample(x)
|
| 405 |
-
x = self.dropout(x)
|
| 406 |
-
x = self.dconv_up3(x)
|
| 407 |
-
c = self.upsample(c)
|
| 408 |
-
c = self.dropout(c)
|
| 409 |
-
c = self.dconv_up3(c)
|
| 410 |
-
|
| 411 |
-
x = self.adain2(x, c)
|
| 412 |
-
x = self.upsample(x)
|
| 413 |
-
x = self.dropout(x)
|
| 414 |
-
x = self.dconv_up2(x)
|
| 415 |
-
c = self.upsample(c)
|
| 416 |
-
c = self.dropout(c)
|
| 417 |
-
c = self.dconv_up2(c)
|
| 418 |
-
|
| 419 |
-
x = self.adain1(x, c)
|
| 420 |
-
x = self.upsample(x)
|
| 421 |
-
x = self.dropout(x)
|
| 422 |
-
x = self.dconv_up1(x)
|
| 423 |
-
|
| 424 |
-
x = self.conv_last(x)
|
| 425 |
-
out = self.up_last(x)
|
| 426 |
-
|
| 427 |
-
return self.activation(out)
|
| 428 |
-
|
| 429 |
-
class MLP(nn.Module):
|
| 430 |
-
def __init__(self, in_f, hidden_dim, out_f):
|
| 431 |
-
super(MLP, self).__init__()
|
| 432 |
-
self.pool = nn.AdaptiveAvgPool2d(1)
|
| 433 |
-
self.mlp = nn.Sequential(nn.Linear(in_f, hidden_dim),
|
| 434 |
-
nn.LeakyReLU(inplace=True),
|
| 435 |
-
nn.Linear(hidden_dim, hidden_dim),
|
| 436 |
-
nn.LeakyReLU(inplace=True),
|
| 437 |
-
nn.Linear(hidden_dim, out_f),)
|
| 438 |
-
|
| 439 |
-
def forward(self, x):
|
| 440 |
-
x = self.pool(x)
|
| 441 |
-
x = self.mlp(x)
|
| 442 |
-
return x
|
| 443 |
-
|
| 444 |
-
class Conv2d1x1(nn.Module):
|
| 445 |
-
def __init__(self, in_f, hidden_dim, out_f):
|
| 446 |
-
super(Conv2d1x1, self).__init__()
|
| 447 |
-
self.conv2d = nn.Sequential(nn.Conv2d(in_f, hidden_dim, 1, 1),
|
| 448 |
-
nn.LeakyReLU(inplace=True),
|
| 449 |
-
nn.Conv2d(hidden_dim, hidden_dim, 1, 1),
|
| 450 |
-
nn.LeakyReLU(inplace=True),
|
| 451 |
-
nn.Conv2d(hidden_dim, out_f, 1, 1),)
|
| 452 |
-
|
| 453 |
-
def forward(self, x):
|
| 454 |
-
x = self.conv2d(x)
|
| 455 |
-
return x
|
| 456 |
-
|
| 457 |
-
class Head(nn.Module):
|
| 458 |
-
def __init__(self, in_f, hidden_dim, out_f):
|
| 459 |
-
super(Head, self).__init__()
|
| 460 |
-
self.do = nn.Dropout(0.2)
|
| 461 |
-
self.pool = nn.AdaptiveAvgPool2d(1)
|
| 462 |
-
self.mlp = nn.Sequential(nn.Linear(in_f, hidden_dim),
|
| 463 |
-
nn.LeakyReLU(inplace=True),
|
| 464 |
-
nn.Linear(hidden_dim, out_f),)
|
| 465 |
-
|
| 466 |
-
def forward(self, x):
|
| 467 |
-
bs = x.size()[0]
|
| 468 |
-
x_feat = self.pool(x).view(bs, -1)
|
| 469 |
-
x = self.mlp(x_feat)
|
| 470 |
-
x = self.do(x)
|
| 471 |
-
return x, x_feat
|
| 472 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/logger.py
DELETED
|
@@ -1,36 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import logging
|
| 3 |
-
|
| 4 |
-
import torch.distributed as dist
|
| 5 |
-
|
| 6 |
-
class RankFilter(logging.Filter):
|
| 7 |
-
def __init__(self, rank):
|
| 8 |
-
super().__init__()
|
| 9 |
-
self.rank = rank
|
| 10 |
-
|
| 11 |
-
def filter(self, record):
|
| 12 |
-
return dist.get_rank() == self.rank
|
| 13 |
-
|
| 14 |
-
def create_logger(log_path):
|
| 15 |
-
# Create log path
|
| 16 |
-
if os.path.isdir(os.path.dirname(log_path)):
|
| 17 |
-
os.makedirs(os.path.dirname(log_path), exist_ok=True)
|
| 18 |
-
|
| 19 |
-
# Create logger object
|
| 20 |
-
logger = logging.getLogger()
|
| 21 |
-
logger.setLevel(logging.INFO)
|
| 22 |
-
# Create file handler and set the formatter
|
| 23 |
-
fh = logging.FileHandler(log_path)
|
| 24 |
-
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
| 25 |
-
fh.setFormatter(formatter)
|
| 26 |
-
|
| 27 |
-
# Add the file handler to the logger
|
| 28 |
-
logger.addHandler(fh)
|
| 29 |
-
|
| 30 |
-
# Add a stream handler to print to console
|
| 31 |
-
sh = logging.StreamHandler()
|
| 32 |
-
sh.setLevel(logging.INFO) # Set logging level for stream handler
|
| 33 |
-
sh.setFormatter(formatter)
|
| 34 |
-
logger.addHandler(sh)
|
| 35 |
-
|
| 36 |
-
return logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/loss/__init__.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
current_file_path = os.path.abspath(__file__)
|
| 4 |
-
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
|
| 5 |
-
project_root_dir = os.path.dirname(parent_dir)
|
| 6 |
-
sys.path.append(parent_dir)
|
| 7 |
-
sys.path.append(project_root_dir)
|
| 8 |
-
|
| 9 |
-
from metrics.registry import LOSSFUNC
|
| 10 |
-
|
| 11 |
-
from .cross_entropy_loss import CrossEntropyLoss
|
| 12 |
-
from .contrastive_regularization import ContrastiveLoss
|
| 13 |
-
from .l1_loss import L1Loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/loss/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (565 Bytes)
|
|
|
test_deps/loss/__pycache__/abstract_loss_func.cpython-310.pyc
DELETED
|
Binary file (977 Bytes)
|
|
|
test_deps/loss/__pycache__/contrastive_regularization.cpython-310.pyc
DELETED
|
Binary file (2.38 kB)
|
|
|
test_deps/loss/__pycache__/cross_entropy_loss.cpython-310.pyc
DELETED
|
Binary file (1.26 kB)
|
|
|
test_deps/loss/__pycache__/l1_loss.cpython-310.pyc
DELETED
|
Binary file (892 Bytes)
|
|
|
test_deps/loss/abstract_loss_func.py
DELETED
|
@@ -1,17 +0,0 @@
|
|
| 1 |
-
import torch.nn as nn
|
| 2 |
-
|
| 3 |
-
class AbstractLossClass(nn.Module):
|
| 4 |
-
"""Abstract class for loss functions."""
|
| 5 |
-
def __init__(self):
|
| 6 |
-
super(AbstractLossClass, self).__init__()
|
| 7 |
-
|
| 8 |
-
def forward(self, pred, label):
|
| 9 |
-
"""
|
| 10 |
-
Args:
|
| 11 |
-
pred: prediction of the model
|
| 12 |
-
label: ground truth label
|
| 13 |
-
|
| 14 |
-
Return:
|
| 15 |
-
loss: loss value
|
| 16 |
-
"""
|
| 17 |
-
raise NotImplementedError('Each subclass should implement the forward method.')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/loss/contrastive_regularization.py
DELETED
|
@@ -1,78 +0,0 @@
|
|
| 1 |
-
import random
|
| 2 |
-
from collections import defaultdict
|
| 3 |
-
import torch
|
| 4 |
-
import torch.nn as nn
|
| 5 |
-
import torch.nn.functional as F
|
| 6 |
-
from .abstract_loss_func import AbstractLossClass
|
| 7 |
-
from metrics.registry import LOSSFUNC
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def swap_spe_features(type_list, value_list):
|
| 11 |
-
type_list = type_list.cpu().numpy().tolist()
|
| 12 |
-
# get index
|
| 13 |
-
index_list = list(range(len(type_list)))
|
| 14 |
-
|
| 15 |
-
# init a dict, where its key is the type and value is the index
|
| 16 |
-
spe_dict = defaultdict(list)
|
| 17 |
-
|
| 18 |
-
# do for-loop to get spe dict
|
| 19 |
-
for i, one_type in enumerate(type_list):
|
| 20 |
-
spe_dict[one_type].append(index_list[i])
|
| 21 |
-
|
| 22 |
-
# shuffle the value list of each key
|
| 23 |
-
for keys in spe_dict.keys():
|
| 24 |
-
random.shuffle(spe_dict[keys])
|
| 25 |
-
|
| 26 |
-
# generate a new index list for the value list
|
| 27 |
-
new_index_list = []
|
| 28 |
-
for one_type in type_list:
|
| 29 |
-
value = spe_dict[one_type].pop()
|
| 30 |
-
new_index_list.append(value)
|
| 31 |
-
|
| 32 |
-
# swap the value_list by new_index_list
|
| 33 |
-
value_list_new = value_list[new_index_list]
|
| 34 |
-
|
| 35 |
-
return value_list_new
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
@LOSSFUNC.register_module(module_name="contrastive_regularization")
|
| 39 |
-
class ContrastiveLoss(AbstractLossClass):
|
| 40 |
-
def __init__(self, margin=1.0):
|
| 41 |
-
super().__init__()
|
| 42 |
-
self.margin = margin
|
| 43 |
-
|
| 44 |
-
def contrastive_loss(self, anchor, positive, negative):
|
| 45 |
-
dist_pos = F.pairwise_distance(anchor, positive)
|
| 46 |
-
dist_neg = F.pairwise_distance(anchor, negative)
|
| 47 |
-
# Compute loss as the distance between anchor and negative minus the distance between anchor and positive
|
| 48 |
-
loss = torch.mean(torch.clamp(dist_pos - dist_neg + self.margin, min=0.0))
|
| 49 |
-
return loss
|
| 50 |
-
|
| 51 |
-
def forward(self, common, specific, spe_label):
|
| 52 |
-
# prepare
|
| 53 |
-
bs = common.shape[0]
|
| 54 |
-
real_common, fake_common = common.chunk(2)
|
| 55 |
-
### common real
|
| 56 |
-
idx_list = list(range(0, bs//2))
|
| 57 |
-
random.shuffle(idx_list)
|
| 58 |
-
real_common_anchor = common[idx_list]
|
| 59 |
-
### common fake
|
| 60 |
-
idx_list = list(range(bs//2, bs))
|
| 61 |
-
random.shuffle(idx_list)
|
| 62 |
-
fake_common_anchor = common[idx_list]
|
| 63 |
-
### specific
|
| 64 |
-
specific_anchor = swap_spe_features(spe_label, specific)
|
| 65 |
-
real_specific_anchor, fake_specific_anchor = specific_anchor.chunk(2)
|
| 66 |
-
real_specific, fake_specific = specific.chunk(2)
|
| 67 |
-
|
| 68 |
-
# Compute the contrastive loss of common between real and fake
|
| 69 |
-
loss_realcommon = self.contrastive_loss(real_common, real_common_anchor, fake_common_anchor)
|
| 70 |
-
loss_fakecommon = self.contrastive_loss(fake_common, fake_common_anchor, real_common_anchor)
|
| 71 |
-
|
| 72 |
-
# Comupte the constrastive loss of specific between real and fake
|
| 73 |
-
loss_realspecific = self.contrastive_loss(real_specific, real_specific_anchor, fake_specific_anchor)
|
| 74 |
-
loss_fakespecific = self.contrastive_loss(fake_specific, fake_specific_anchor, real_specific_anchor)
|
| 75 |
-
|
| 76 |
-
# Compute the final loss as the sum of all contrastive losses
|
| 77 |
-
loss = loss_realcommon + loss_fakecommon + loss_fakespecific + loss_realspecific
|
| 78 |
-
return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/loss/cross_entropy_loss.py
DELETED
|
@@ -1,26 +0,0 @@
|
|
| 1 |
-
import torch.nn as nn
|
| 2 |
-
from .abstract_loss_func import AbstractLossClass
|
| 3 |
-
from metrics.registry import LOSSFUNC
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
@LOSSFUNC.register_module(module_name="cross_entropy")
|
| 7 |
-
class CrossEntropyLoss(AbstractLossClass):
|
| 8 |
-
def __init__(self):
|
| 9 |
-
super().__init__()
|
| 10 |
-
self.loss_fn = nn.CrossEntropyLoss()
|
| 11 |
-
|
| 12 |
-
def forward(self, inputs, targets):
|
| 13 |
-
"""
|
| 14 |
-
Computes the cross-entropy loss.
|
| 15 |
-
|
| 16 |
-
Args:
|
| 17 |
-
inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores.
|
| 18 |
-
targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices.
|
| 19 |
-
|
| 20 |
-
Returns:
|
| 21 |
-
A scalar tensor representing the cross-entropy loss.
|
| 22 |
-
"""
|
| 23 |
-
# Compute the cross-entropy loss
|
| 24 |
-
loss = self.loss_fn(inputs, targets)
|
| 25 |
-
|
| 26 |
-
return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/loss/l1_loss.py
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
import torch.nn as nn
|
| 2 |
-
from .abstract_loss_func import AbstractLossClass
|
| 3 |
-
from metrics.registry import LOSSFUNC
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
@LOSSFUNC.register_module(module_name="l1loss")
|
| 7 |
-
class L1Loss(AbstractLossClass):
|
| 8 |
-
def __init__(self):
|
| 9 |
-
super().__init__()
|
| 10 |
-
self.loss_fn = nn.L1Loss()
|
| 11 |
-
|
| 12 |
-
def forward(self, inputs, targets):
|
| 13 |
-
"""
|
| 14 |
-
Computes the l1 loss.
|
| 15 |
-
"""
|
| 16 |
-
# Compute the l1 loss
|
| 17 |
-
loss = self.loss_fn(inputs, targets)
|
| 18 |
-
|
| 19 |
-
return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/metrics/__init__.py
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
current_file_path = os.path.abspath(__file__)
|
| 4 |
-
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
|
| 5 |
-
project_root_dir = os.path.dirname(parent_dir)
|
| 6 |
-
sys.path.append(parent_dir)
|
| 7 |
-
sys.path.append(project_root_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/metrics/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (351 Bytes)
|
|
|
test_deps/metrics/__pycache__/base_metrics_class.cpython-310.pyc
DELETED
|
Binary file (6.21 kB)
|
|
|
test_deps/metrics/__pycache__/registry.cpython-310.pyc
DELETED
|
Binary file (1.01 kB)
|
|
|
test_deps/metrics/base_metrics_class.py
DELETED
|
@@ -1,205 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
from sklearn import metrics
|
| 3 |
-
from collections import defaultdict
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn as nn
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def get_accracy(output, label):
|
| 9 |
-
_, prediction = torch.max(output, 1) # argmax
|
| 10 |
-
correct = (prediction == label).sum().item()
|
| 11 |
-
accuracy = correct / prediction.size(0)
|
| 12 |
-
return accuracy
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def get_prediction(output, label):
|
| 16 |
-
prob = nn.functional.softmax(output, dim=1)[:, 1]
|
| 17 |
-
prob = prob.view(prob.size(0), 1)
|
| 18 |
-
label = label.view(label.size(0), 1)
|
| 19 |
-
#print(prob.size(), label.size())
|
| 20 |
-
datas = torch.cat((prob, label.float()), dim=1)
|
| 21 |
-
return datas
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def calculate_metrics_for_train(label, output):
|
| 25 |
-
if output.size(1) == 2:
|
| 26 |
-
prob = torch.softmax(output, dim=1)[:, 1]
|
| 27 |
-
else:
|
| 28 |
-
prob = output
|
| 29 |
-
|
| 30 |
-
# Accuracy
|
| 31 |
-
_, prediction = torch.max(output, 1)
|
| 32 |
-
correct = (prediction == label).sum().item()
|
| 33 |
-
accuracy = correct / prediction.size(0)
|
| 34 |
-
|
| 35 |
-
# Average Precision
|
| 36 |
-
y_true = label.cpu().detach().numpy()
|
| 37 |
-
y_pred = prob.cpu().detach().numpy()
|
| 38 |
-
ap = metrics.average_precision_score(y_true, y_pred)
|
| 39 |
-
|
| 40 |
-
# AUC and EER
|
| 41 |
-
try:
|
| 42 |
-
fpr, tpr, thresholds = metrics.roc_curve(label.squeeze().cpu().numpy(),
|
| 43 |
-
prob.squeeze().cpu().numpy(),
|
| 44 |
-
pos_label=1)
|
| 45 |
-
except:
|
| 46 |
-
# for the case when we only have one sample
|
| 47 |
-
return None, None, accuracy, ap
|
| 48 |
-
|
| 49 |
-
if np.isnan(fpr[0]) or np.isnan(tpr[0]):
|
| 50 |
-
# for the case when all the samples within a batch is fake/real
|
| 51 |
-
auc, eer = None, None
|
| 52 |
-
else:
|
| 53 |
-
auc = metrics.auc(fpr, tpr)
|
| 54 |
-
fnr = 1 - tpr
|
| 55 |
-
eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
|
| 56 |
-
|
| 57 |
-
return auc, eer, accuracy, ap
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
# ------------ compute average metrics of batches---------------------
|
| 61 |
-
class Metrics_batch():
|
| 62 |
-
def __init__(self):
|
| 63 |
-
self.tprs = []
|
| 64 |
-
self.mean_fpr = np.linspace(0, 1, 100)
|
| 65 |
-
self.aucs = []
|
| 66 |
-
self.eers = []
|
| 67 |
-
self.aps = []
|
| 68 |
-
|
| 69 |
-
self.correct = 0
|
| 70 |
-
self.total = 0
|
| 71 |
-
self.losses = []
|
| 72 |
-
|
| 73 |
-
def update(self, label, output):
|
| 74 |
-
acc = self._update_acc(label, output)
|
| 75 |
-
if output.size(1) == 2:
|
| 76 |
-
prob = torch.softmax(output, dim=1)[:, 1]
|
| 77 |
-
else:
|
| 78 |
-
prob = output
|
| 79 |
-
#label = 1-label
|
| 80 |
-
#prob = torch.softmax(output, dim=1)[:, 1]
|
| 81 |
-
auc, eer = self._update_auc(label, prob)
|
| 82 |
-
ap = self._update_ap(label, prob)
|
| 83 |
-
|
| 84 |
-
return acc, auc, eer, ap
|
| 85 |
-
|
| 86 |
-
def _update_auc(self, lab, prob):
|
| 87 |
-
fpr, tpr, thresholds = metrics.roc_curve(lab.squeeze().cpu().numpy(),
|
| 88 |
-
prob.squeeze().cpu().numpy(),
|
| 89 |
-
pos_label=1)
|
| 90 |
-
if np.isnan(fpr[0]) or np.isnan(tpr[0]):
|
| 91 |
-
return -1, -1
|
| 92 |
-
|
| 93 |
-
auc = metrics.auc(fpr, tpr)
|
| 94 |
-
interp_tpr = np.interp(self.mean_fpr, fpr, tpr)
|
| 95 |
-
interp_tpr[0] = 0.0
|
| 96 |
-
self.tprs.append(interp_tpr)
|
| 97 |
-
self.aucs.append(auc)
|
| 98 |
-
|
| 99 |
-
# return auc
|
| 100 |
-
|
| 101 |
-
# EER
|
| 102 |
-
fnr = 1 - tpr
|
| 103 |
-
eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
|
| 104 |
-
self.eers.append(eer)
|
| 105 |
-
|
| 106 |
-
return auc, eer
|
| 107 |
-
|
| 108 |
-
def _update_acc(self, lab, output):
|
| 109 |
-
_, prediction = torch.max(output, 1) # argmax
|
| 110 |
-
correct = (prediction == lab).sum().item()
|
| 111 |
-
accuracy = correct / prediction.size(0)
|
| 112 |
-
# self.accs.append(accuracy)
|
| 113 |
-
self.correct = self.correct+correct
|
| 114 |
-
self.total = self.total+lab.size(0)
|
| 115 |
-
return accuracy
|
| 116 |
-
|
| 117 |
-
def _update_ap(self, label, prob):
|
| 118 |
-
y_true = label.cpu().detach().numpy()
|
| 119 |
-
y_pred = prob.cpu().detach().numpy()
|
| 120 |
-
ap = metrics.average_precision_score(y_true,y_pred)
|
| 121 |
-
self.aps.append(ap)
|
| 122 |
-
|
| 123 |
-
return np.mean(ap)
|
| 124 |
-
|
| 125 |
-
def get_mean_metrics(self):
|
| 126 |
-
mean_acc, std_acc = self.correct/self.total, 0
|
| 127 |
-
mean_auc, std_auc = self._mean_auc()
|
| 128 |
-
mean_err, std_err = np.mean(self.eers), np.std(self.eers)
|
| 129 |
-
mean_ap, std_ap = np.mean(self.aps), np.std(self.aps)
|
| 130 |
-
|
| 131 |
-
return {'acc':mean_acc, 'auc':mean_auc, 'eer':mean_err, 'ap':mean_ap}
|
| 132 |
-
|
| 133 |
-
def _mean_auc(self):
|
| 134 |
-
mean_tpr = np.mean(self.tprs, axis=0)
|
| 135 |
-
mean_tpr[-1] = 1.0
|
| 136 |
-
mean_auc = metrics.auc(self.mean_fpr, mean_tpr)
|
| 137 |
-
std_auc = np.std(self.aucs)
|
| 138 |
-
return mean_auc, std_auc
|
| 139 |
-
|
| 140 |
-
def clear(self):
|
| 141 |
-
self.tprs.clear()
|
| 142 |
-
self.aucs.clear()
|
| 143 |
-
# self.accs.clear()
|
| 144 |
-
self.correct=0
|
| 145 |
-
self.total=0
|
| 146 |
-
self.eers.clear()
|
| 147 |
-
self.aps.clear()
|
| 148 |
-
self.losses.clear()
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
# ------------ compute average metrics of all data ---------------------
|
| 152 |
-
class Metrics_all():
|
| 153 |
-
def __init__(self):
|
| 154 |
-
self.probs = []
|
| 155 |
-
self.labels = []
|
| 156 |
-
self.correct = 0
|
| 157 |
-
self.total = 0
|
| 158 |
-
|
| 159 |
-
def store(self, label, output):
|
| 160 |
-
prob = torch.softmax(output, dim=1)[:, 1]
|
| 161 |
-
_, prediction = torch.max(output, 1) # argmax
|
| 162 |
-
correct = (prediction == label).sum().item()
|
| 163 |
-
self.correct += correct
|
| 164 |
-
self.total += label.size(0)
|
| 165 |
-
self.labels.append(label.squeeze().cpu().numpy())
|
| 166 |
-
self.probs.append(prob.squeeze().cpu().numpy())
|
| 167 |
-
|
| 168 |
-
def get_metrics(self):
|
| 169 |
-
y_pred = np.concatenate(self.probs)
|
| 170 |
-
y_true = np.concatenate(self.labels)
|
| 171 |
-
# auc
|
| 172 |
-
fpr, tpr, thresholds = metrics.roc_curve(y_true,y_pred,pos_label=1)
|
| 173 |
-
auc = metrics.auc(fpr, tpr)
|
| 174 |
-
# eer
|
| 175 |
-
fnr = 1 - tpr
|
| 176 |
-
eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
|
| 177 |
-
# ap
|
| 178 |
-
ap = metrics.average_precision_score(y_true,y_pred)
|
| 179 |
-
# acc
|
| 180 |
-
acc = self.correct / self.total
|
| 181 |
-
return {'acc':acc, 'auc':auc, 'eer':eer, 'ap':ap}
|
| 182 |
-
|
| 183 |
-
def clear(self):
|
| 184 |
-
self.probs.clear()
|
| 185 |
-
self.labels.clear()
|
| 186 |
-
self.correct = 0
|
| 187 |
-
self.total = 0
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
# only used to record a series of scalar value
|
| 191 |
-
class Recorder:
|
| 192 |
-
def __init__(self):
|
| 193 |
-
self.sum = 0
|
| 194 |
-
self.num = 0
|
| 195 |
-
def update(self, item, num=1):
|
| 196 |
-
if item is not None:
|
| 197 |
-
self.sum += item * num
|
| 198 |
-
self.num += num
|
| 199 |
-
def average(self):
|
| 200 |
-
if self.num == 0:
|
| 201 |
-
return None
|
| 202 |
-
return self.sum/self.num
|
| 203 |
-
def clear(self):
|
| 204 |
-
self.sum = 0
|
| 205 |
-
self.num = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/metrics/registry.py
DELETED
|
@@ -1,20 +0,0 @@
|
|
| 1 |
-
class Registry(object):
|
| 2 |
-
def __init__(self):
|
| 3 |
-
self.data = {}
|
| 4 |
-
|
| 5 |
-
def register_module(self, module_name=None):
|
| 6 |
-
def _register(cls):
|
| 7 |
-
name = module_name
|
| 8 |
-
if module_name is None:
|
| 9 |
-
name = cls.__name__
|
| 10 |
-
self.data[name] = cls
|
| 11 |
-
return cls
|
| 12 |
-
return _register
|
| 13 |
-
|
| 14 |
-
def __getitem__(self, key):
|
| 15 |
-
return self.data[key]
|
| 16 |
-
|
| 17 |
-
BACKBONE = Registry()
|
| 18 |
-
DETECTOR = Registry()
|
| 19 |
-
TRAINER = Registry()
|
| 20 |
-
LOSSFUNC = Registry()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/metrics/utils.py
DELETED
|
@@ -1,88 +0,0 @@
|
|
| 1 |
-
from sklearn import metrics
|
| 2 |
-
import numpy as np
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
def parse_metric_for_print(metric_dict):
|
| 6 |
-
if metric_dict is None:
|
| 7 |
-
return "\n"
|
| 8 |
-
str = "\n"
|
| 9 |
-
str += "================================ Each dataset best metric ================================ \n"
|
| 10 |
-
for key, value in metric_dict.items():
|
| 11 |
-
if key != 'avg':
|
| 12 |
-
str= str+ f"| {key}: "
|
| 13 |
-
for k,v in value.items():
|
| 14 |
-
str = str + f" {k}={v} "
|
| 15 |
-
str= str+ "| \n"
|
| 16 |
-
else:
|
| 17 |
-
str += "============================================================================================= \n"
|
| 18 |
-
str += "================================== Average best metric ====================================== \n"
|
| 19 |
-
avg_dict = value
|
| 20 |
-
for avg_key, avg_value in avg_dict.items():
|
| 21 |
-
if avg_key == 'dataset_dict':
|
| 22 |
-
for key,value in avg_value.items():
|
| 23 |
-
str = str + f"| {key}: {value} | \n"
|
| 24 |
-
else:
|
| 25 |
-
str = str + f"| avg {avg_key}: {avg_value} | \n"
|
| 26 |
-
str += "============================================================================================="
|
| 27 |
-
return str
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def get_test_metrics(y_pred, y_true, img_names=None, logger=None):
|
| 31 |
-
def get_video_metrics(image, pred, label):
|
| 32 |
-
result_dict = {}
|
| 33 |
-
new_label = []
|
| 34 |
-
new_pred = []
|
| 35 |
-
# print(image[0])
|
| 36 |
-
# print(pred.shape)
|
| 37 |
-
# print(label.shape)
|
| 38 |
-
for item in np.transpose(np.stack((image, pred, label)), (1, 0)):
|
| 39 |
-
|
| 40 |
-
s = item[0]
|
| 41 |
-
if '\\' in s:
|
| 42 |
-
parts = s.split('\\')
|
| 43 |
-
else:
|
| 44 |
-
parts = s.split('/')
|
| 45 |
-
a = parts[-2]
|
| 46 |
-
b = parts[-1]
|
| 47 |
-
|
| 48 |
-
if a not in result_dict:
|
| 49 |
-
result_dict[a] = []
|
| 50 |
-
|
| 51 |
-
result_dict[a].append(item)
|
| 52 |
-
image_arr = list(result_dict.values())
|
| 53 |
-
|
| 54 |
-
for video in image_arr:
|
| 55 |
-
pred_sum = 0
|
| 56 |
-
label_sum = 0
|
| 57 |
-
leng = 0
|
| 58 |
-
for frame in video:
|
| 59 |
-
pred_sum += float(frame[1])
|
| 60 |
-
label_sum += int(frame[2])
|
| 61 |
-
leng += 1
|
| 62 |
-
new_pred.append(pred_sum / leng)
|
| 63 |
-
new_label.append(int(label_sum / leng))
|
| 64 |
-
fpr, tpr, thresholds = metrics.roc_curve(new_label, new_pred)
|
| 65 |
-
v_auc = metrics.auc(fpr, tpr)
|
| 66 |
-
fnr = 1 - tpr
|
| 67 |
-
v_eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
|
| 68 |
-
return v_auc, v_eer
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
y_pred = y_pred.squeeze()
|
| 72 |
-
|
| 73 |
-
# For UCF, where labels for different manipulations are not consistent.
|
| 74 |
-
y_true[y_true >= 1] = 1
|
| 75 |
-
# auc
|
| 76 |
-
fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1)
|
| 77 |
-
auc = metrics.auc(fpr, tpr)
|
| 78 |
-
# eer
|
| 79 |
-
fnr = 1 - tpr
|
| 80 |
-
eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
|
| 81 |
-
# ap
|
| 82 |
-
ap = metrics.average_precision_score(y_true, y_pred)
|
| 83 |
-
# acc
|
| 84 |
-
prediction_class = (y_pred > 0.5).astype(int)
|
| 85 |
-
correct = (prediction_class == np.clip(y_true, a_min=0, a_max=1)).sum().item()
|
| 86 |
-
acc = correct / len(prediction_class)
|
| 87 |
-
|
| 88 |
-
return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap, 'pred': y_pred, 'label': y_true}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/networks/__init__.py
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
current_file_path = os.path.abspath(__file__)
|
| 4 |
-
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
|
| 5 |
-
project_root_dir = os.path.dirname(parent_dir)
|
| 6 |
-
sys.path.append(parent_dir)
|
| 7 |
-
sys.path.append(project_root_dir)
|
| 8 |
-
|
| 9 |
-
from metrics.registry import BACKBONE
|
| 10 |
-
|
| 11 |
-
from .xception import Xception
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/networks/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (447 Bytes)
|
|
|
test_deps/networks/__pycache__/xception.cpython-310.pyc
DELETED
|
Binary file (6.7 kB)
|
|
|
test_deps/networks/xception.py
DELETED
|
@@ -1,285 +0,0 @@
|
|
| 1 |
-
'''
|
| 2 |
-
# author: Zhiyuan Yan
|
| 3 |
-
# email: zhiyuanyan@link.cuhk.edu.cn
|
| 4 |
-
# date: 2023-0706
|
| 5 |
-
|
| 6 |
-
The code is mainly modified from GitHub link below:
|
| 7 |
-
https://github.com/ondyari/FaceForensics/blob/master/classification/network/xception.py
|
| 8 |
-
'''
|
| 9 |
-
|
| 10 |
-
import os
|
| 11 |
-
import argparse
|
| 12 |
-
import logging
|
| 13 |
-
|
| 14 |
-
import math
|
| 15 |
-
import torch
|
| 16 |
-
# import pretrainedmodels
|
| 17 |
-
import torch.nn as nn
|
| 18 |
-
import torch.nn.functional as F
|
| 19 |
-
|
| 20 |
-
import torch.utils.model_zoo as model_zoo
|
| 21 |
-
from torch.nn import init
|
| 22 |
-
from typing import Union
|
| 23 |
-
from metrics.registry import BACKBONE
|
| 24 |
-
|
| 25 |
-
logger = logging.getLogger(__name__)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
class SeparableConv2d(nn.Module):
|
| 30 |
-
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
|
| 31 |
-
super(SeparableConv2d, self).__init__()
|
| 32 |
-
|
| 33 |
-
self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size,
|
| 34 |
-
stride, padding, dilation, groups=in_channels, bias=bias)
|
| 35 |
-
self.pointwise = nn.Conv2d(
|
| 36 |
-
in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
|
| 37 |
-
|
| 38 |
-
def forward(self, x):
|
| 39 |
-
x = self.conv1(x)
|
| 40 |
-
x = self.pointwise(x)
|
| 41 |
-
return x
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
class Block(nn.Module):
|
| 45 |
-
def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
|
| 46 |
-
super(Block, self).__init__()
|
| 47 |
-
|
| 48 |
-
if out_filters != in_filters or strides != 1:
|
| 49 |
-
self.skip = nn.Conv2d(in_filters, out_filters,
|
| 50 |
-
1, stride=strides, bias=False)
|
| 51 |
-
self.skipbn = nn.BatchNorm2d(out_filters)
|
| 52 |
-
else:
|
| 53 |
-
self.skip = None
|
| 54 |
-
|
| 55 |
-
self.relu = nn.ReLU(inplace=True)
|
| 56 |
-
rep = []
|
| 57 |
-
|
| 58 |
-
filters = in_filters
|
| 59 |
-
if grow_first: # whether the number of filters grows first
|
| 60 |
-
rep.append(self.relu)
|
| 61 |
-
rep.append(SeparableConv2d(in_filters, out_filters,
|
| 62 |
-
3, stride=1, padding=1, bias=False))
|
| 63 |
-
rep.append(nn.BatchNorm2d(out_filters))
|
| 64 |
-
filters = out_filters
|
| 65 |
-
|
| 66 |
-
for i in range(reps-1):
|
| 67 |
-
rep.append(self.relu)
|
| 68 |
-
rep.append(SeparableConv2d(filters, filters,
|
| 69 |
-
3, stride=1, padding=1, bias=False))
|
| 70 |
-
rep.append(nn.BatchNorm2d(filters))
|
| 71 |
-
|
| 72 |
-
if not grow_first:
|
| 73 |
-
rep.append(self.relu)
|
| 74 |
-
rep.append(SeparableConv2d(in_filters, out_filters,
|
| 75 |
-
3, stride=1, padding=1, bias=False))
|
| 76 |
-
rep.append(nn.BatchNorm2d(out_filters))
|
| 77 |
-
|
| 78 |
-
if not start_with_relu:
|
| 79 |
-
rep = rep[1:]
|
| 80 |
-
else:
|
| 81 |
-
rep[0] = nn.ReLU(inplace=False)
|
| 82 |
-
|
| 83 |
-
if strides != 1:
|
| 84 |
-
rep.append(nn.MaxPool2d(3, strides, 1))
|
| 85 |
-
self.rep = nn.Sequential(*rep)
|
| 86 |
-
|
| 87 |
-
def forward(self, inp):
|
| 88 |
-
x = self.rep(inp)
|
| 89 |
-
|
| 90 |
-
if self.skip is not None:
|
| 91 |
-
skip = self.skip(inp)
|
| 92 |
-
skip = self.skipbn(skip)
|
| 93 |
-
else:
|
| 94 |
-
skip = inp
|
| 95 |
-
|
| 96 |
-
x += skip
|
| 97 |
-
return x
|
| 98 |
-
|
| 99 |
-
def add_gaussian_noise(ins, mean=0, stddev=0.2):
|
| 100 |
-
noise = ins.data.new(ins.size()).normal_(mean, stddev)
|
| 101 |
-
return ins + noise
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
@BACKBONE.register_module(module_name="xception")
|
| 105 |
-
class Xception(nn.Module):
|
| 106 |
-
"""
|
| 107 |
-
Xception optimized for the ImageNet dataset, as specified in
|
| 108 |
-
https://arxiv.org/pdf/1610.02357.pdf
|
| 109 |
-
"""
|
| 110 |
-
|
| 111 |
-
def __init__(self, xception_config):
|
| 112 |
-
""" Constructor
|
| 113 |
-
Args:
|
| 114 |
-
xception_config: configuration file with the dict format
|
| 115 |
-
"""
|
| 116 |
-
super(Xception, self).__init__()
|
| 117 |
-
self.num_classes = xception_config["num_classes"]
|
| 118 |
-
self.mode = xception_config["mode"]
|
| 119 |
-
inc = xception_config["inc"]
|
| 120 |
-
dropout = xception_config["dropout"]
|
| 121 |
-
|
| 122 |
-
# Entry flow
|
| 123 |
-
self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
|
| 124 |
-
|
| 125 |
-
self.bn1 = nn.BatchNorm2d(32)
|
| 126 |
-
self.relu = nn.ReLU(inplace=True)
|
| 127 |
-
|
| 128 |
-
self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
|
| 129 |
-
self.bn2 = nn.BatchNorm2d(64)
|
| 130 |
-
# do relu here
|
| 131 |
-
|
| 132 |
-
self.block1 = Block(
|
| 133 |
-
64, 128, 2, 2, start_with_relu=False, grow_first=True)
|
| 134 |
-
self.block2 = Block(
|
| 135 |
-
128, 256, 2, 2, start_with_relu=True, grow_first=True)
|
| 136 |
-
self.block3 = Block(
|
| 137 |
-
256, 728, 2, 2, start_with_relu=True, grow_first=True)
|
| 138 |
-
|
| 139 |
-
# middle flow
|
| 140 |
-
self.block4 = Block(
|
| 141 |
-
728, 728, 3, 1, start_with_relu=True, grow_first=True)
|
| 142 |
-
self.block5 = Block(
|
| 143 |
-
728, 728, 3, 1, start_with_relu=True, grow_first=True)
|
| 144 |
-
self.block6 = Block(
|
| 145 |
-
728, 728, 3, 1, start_with_relu=True, grow_first=True)
|
| 146 |
-
self.block7 = Block(
|
| 147 |
-
728, 728, 3, 1, start_with_relu=True, grow_first=True)
|
| 148 |
-
|
| 149 |
-
self.block8 = Block(
|
| 150 |
-
728, 728, 3, 1, start_with_relu=True, grow_first=True)
|
| 151 |
-
self.block9 = Block(
|
| 152 |
-
728, 728, 3, 1, start_with_relu=True, grow_first=True)
|
| 153 |
-
self.block10 = Block(
|
| 154 |
-
728, 728, 3, 1, start_with_relu=True, grow_first=True)
|
| 155 |
-
self.block11 = Block(
|
| 156 |
-
728, 728, 3, 1, start_with_relu=True, grow_first=True)
|
| 157 |
-
|
| 158 |
-
# Exit flow
|
| 159 |
-
self.block12 = Block(
|
| 160 |
-
728, 1024, 2, 2, start_with_relu=True, grow_first=False)
|
| 161 |
-
|
| 162 |
-
self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
|
| 163 |
-
self.bn3 = nn.BatchNorm2d(1536)
|
| 164 |
-
|
| 165 |
-
# do relu here
|
| 166 |
-
self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
|
| 167 |
-
self.bn4 = nn.BatchNorm2d(2048)
|
| 168 |
-
# used for iid
|
| 169 |
-
final_channel = 2048
|
| 170 |
-
if self.mode == 'adjust_channel_iid':
|
| 171 |
-
final_channel = 512
|
| 172 |
-
self.mode = 'adjust_channel'
|
| 173 |
-
self.last_linear = nn.Linear(final_channel, self.num_classes)
|
| 174 |
-
if dropout:
|
| 175 |
-
self.last_linear = nn.Sequential(
|
| 176 |
-
nn.Dropout(p=dropout),
|
| 177 |
-
nn.Linear(final_channel, self.num_classes)
|
| 178 |
-
)
|
| 179 |
-
|
| 180 |
-
self.adjust_channel = nn.Sequential(
|
| 181 |
-
nn.Conv2d(2048, 512, 1, 1),
|
| 182 |
-
nn.BatchNorm2d(512),
|
| 183 |
-
nn.ReLU(inplace=False),
|
| 184 |
-
)
|
| 185 |
-
|
| 186 |
-
def fea_part1_0(self, x):
|
| 187 |
-
x = self.conv1(x)
|
| 188 |
-
x = self.bn1(x)
|
| 189 |
-
x = self.relu(x)
|
| 190 |
-
|
| 191 |
-
return x
|
| 192 |
-
|
| 193 |
-
def fea_part1_1(self, x):
|
| 194 |
-
|
| 195 |
-
x = self.conv2(x)
|
| 196 |
-
x = self.bn2(x)
|
| 197 |
-
x = self.relu(x)
|
| 198 |
-
|
| 199 |
-
return x
|
| 200 |
-
|
| 201 |
-
def fea_part1(self, x):
|
| 202 |
-
x = self.conv1(x)
|
| 203 |
-
x = self.bn1(x)
|
| 204 |
-
x = self.relu(x)
|
| 205 |
-
|
| 206 |
-
x = self.conv2(x)
|
| 207 |
-
x = self.bn2(x)
|
| 208 |
-
x = self.relu(x)
|
| 209 |
-
|
| 210 |
-
return x
|
| 211 |
-
|
| 212 |
-
def fea_part2(self, x):
|
| 213 |
-
x = self.block1(x)
|
| 214 |
-
x = self.block2(x)
|
| 215 |
-
x = self.block3(x)
|
| 216 |
-
|
| 217 |
-
return x
|
| 218 |
-
|
| 219 |
-
def fea_part3(self, x):
|
| 220 |
-
if self.mode == "shallow_xception":
|
| 221 |
-
return x
|
| 222 |
-
else:
|
| 223 |
-
x = self.block4(x)
|
| 224 |
-
x = self.block5(x)
|
| 225 |
-
x = self.block6(x)
|
| 226 |
-
x = self.block7(x)
|
| 227 |
-
return x
|
| 228 |
-
|
| 229 |
-
def fea_part4(self, x):
|
| 230 |
-
if self.mode == "shallow_xception":
|
| 231 |
-
x = self.block12(x)
|
| 232 |
-
else:
|
| 233 |
-
x = self.block8(x)
|
| 234 |
-
x = self.block9(x)
|
| 235 |
-
x = self.block10(x)
|
| 236 |
-
x = self.block11(x)
|
| 237 |
-
x = self.block12(x)
|
| 238 |
-
return x
|
| 239 |
-
|
| 240 |
-
def fea_part5(self, x):
|
| 241 |
-
x = self.conv3(x)
|
| 242 |
-
x = self.bn3(x)
|
| 243 |
-
x = self.relu(x)
|
| 244 |
-
|
| 245 |
-
x = self.conv4(x)
|
| 246 |
-
x = self.bn4(x)
|
| 247 |
-
|
| 248 |
-
return x
|
| 249 |
-
|
| 250 |
-
def features(self, input):
|
| 251 |
-
x = self.fea_part1(input)
|
| 252 |
-
|
| 253 |
-
x = self.fea_part2(x)
|
| 254 |
-
x = self.fea_part3(x)
|
| 255 |
-
x = self.fea_part4(x)
|
| 256 |
-
|
| 257 |
-
x = self.fea_part5(x)
|
| 258 |
-
|
| 259 |
-
if self.mode == 'adjust_channel':
|
| 260 |
-
x = self.adjust_channel(x)
|
| 261 |
-
|
| 262 |
-
return x
|
| 263 |
-
|
| 264 |
-
def classifier(self, features,id_feat=None):
|
| 265 |
-
# for iid
|
| 266 |
-
if self.mode == 'adjust_channel':
|
| 267 |
-
x = features
|
| 268 |
-
else:
|
| 269 |
-
x = self.relu(features)
|
| 270 |
-
|
| 271 |
-
if len(x.shape) == 4:
|
| 272 |
-
x = F.adaptive_avg_pool2d(x, (1, 1))
|
| 273 |
-
x = x.view(x.size(0), -1)
|
| 274 |
-
self.last_emb = x
|
| 275 |
-
# for iid
|
| 276 |
-
if id_feat!=None:
|
| 277 |
-
out = self.last_linear(x-id_feat)
|
| 278 |
-
else:
|
| 279 |
-
out = self.last_linear(x)
|
| 280 |
-
return out
|
| 281 |
-
|
| 282 |
-
def forward(self, input):
|
| 283 |
-
x = self.features(input)
|
| 284 |
-
out = self.classifier(x)
|
| 285 |
-
return out, x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/optimizor/LinearLR.py
DELETED
|
@@ -1,20 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from torch.optim import SGD
|
| 3 |
-
from torch.optim.lr_scheduler import _LRScheduler
|
| 4 |
-
|
| 5 |
-
class LinearDecayLR(_LRScheduler):
|
| 6 |
-
def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1):
|
| 7 |
-
self.start_decay=start_decay
|
| 8 |
-
self.n_epoch=n_epoch
|
| 9 |
-
super(LinearDecayLR, self).__init__(optimizer, last_epoch)
|
| 10 |
-
|
| 11 |
-
def get_lr(self):
|
| 12 |
-
last_epoch = self.last_epoch
|
| 13 |
-
n_epoch=self.n_epoch
|
| 14 |
-
b_lr=self.base_lrs[0]
|
| 15 |
-
start_decay=self.start_decay
|
| 16 |
-
if last_epoch>start_decay:
|
| 17 |
-
lr=b_lr-b_lr/(n_epoch-start_decay)*(last_epoch-start_decay)
|
| 18 |
-
else:
|
| 19 |
-
lr=b_lr
|
| 20 |
-
return [lr]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/optimizor/SAM.py
DELETED
|
@@ -1,77 +0,0 @@
|
|
| 1 |
-
# borrowed from
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
|
| 8 |
-
def disable_running_stats(model):
|
| 9 |
-
def _disable(module):
|
| 10 |
-
if isinstance(module, nn.BatchNorm2d):
|
| 11 |
-
module.backup_momentum = module.momentum
|
| 12 |
-
module.momentum = 0
|
| 13 |
-
|
| 14 |
-
model.apply(_disable)
|
| 15 |
-
|
| 16 |
-
def enable_running_stats(model):
|
| 17 |
-
def _enable(module):
|
| 18 |
-
if isinstance(module, nn.BatchNorm2d) and hasattr(module, "backup_momentum"):
|
| 19 |
-
module.momentum = module.backup_momentum
|
| 20 |
-
|
| 21 |
-
model.apply(_enable)
|
| 22 |
-
|
| 23 |
-
class SAM(torch.optim.Optimizer):
|
| 24 |
-
def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
|
| 25 |
-
assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
|
| 26 |
-
|
| 27 |
-
defaults = dict(rho=rho, **kwargs)
|
| 28 |
-
super(SAM, self).__init__(params, defaults)
|
| 29 |
-
|
| 30 |
-
self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
|
| 31 |
-
self.param_groups = self.base_optimizer.param_groups
|
| 32 |
-
|
| 33 |
-
@torch.no_grad()
|
| 34 |
-
def first_step(self, zero_grad=False):
|
| 35 |
-
grad_norm = self._grad_norm()
|
| 36 |
-
for group in self.param_groups:
|
| 37 |
-
scale = group["rho"] / (grad_norm + 1e-12)
|
| 38 |
-
|
| 39 |
-
for p in group["params"]:
|
| 40 |
-
if p.grad is None: continue
|
| 41 |
-
e_w = p.grad * scale.to(p)
|
| 42 |
-
p.add_(e_w) # climb to the local maximum "w + e(w)"
|
| 43 |
-
self.state[p]["e_w"] = e_w
|
| 44 |
-
|
| 45 |
-
if zero_grad: self.zero_grad()
|
| 46 |
-
|
| 47 |
-
@torch.no_grad()
|
| 48 |
-
def second_step(self, zero_grad=False):
|
| 49 |
-
for group in self.param_groups:
|
| 50 |
-
for p in group["params"]:
|
| 51 |
-
if p.grad is None: continue
|
| 52 |
-
p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)"
|
| 53 |
-
|
| 54 |
-
self.base_optimizer.step() # do the actual "sharpness-aware" update
|
| 55 |
-
|
| 56 |
-
if zero_grad: self.zero_grad()
|
| 57 |
-
|
| 58 |
-
@torch.no_grad()
|
| 59 |
-
def step(self, closure=None):
|
| 60 |
-
assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
|
| 61 |
-
closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
|
| 62 |
-
|
| 63 |
-
self.first_step(zero_grad=True)
|
| 64 |
-
closure()
|
| 65 |
-
self.second_step()
|
| 66 |
-
|
| 67 |
-
def _grad_norm(self):
|
| 68 |
-
shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism
|
| 69 |
-
norm = torch.norm(
|
| 70 |
-
torch.stack([
|
| 71 |
-
p.grad.norm(p=2).to(shared_device)
|
| 72 |
-
for group in self.param_groups for p in group["params"]
|
| 73 |
-
if p.grad is not None
|
| 74 |
-
]),
|
| 75 |
-
p=2
|
| 76 |
-
)
|
| 77 |
-
return norm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/train_detector.py
DELETED
|
@@ -1,460 +0,0 @@
|
|
| 1 |
-
# This script was adapted from the DeepfakeBench training code,
|
| 2 |
-
# originally authored by Zhiyuan Yan (zhiyuanyan@link.cuhk.edu.cn)
|
| 3 |
-
|
| 4 |
-
# Original: https://github.com/SCLBD/DeepfakeBench/blob/main/training/train.py
|
| 5 |
-
|
| 6 |
-
# BitMind's modifications include adding a testing phase, changing the
|
| 7 |
-
# data load/split pipeline to work with subnet 34's image augmentations
|
| 8 |
-
# and datasets from BitMind HuggingFace repositories, quality of life CLI args,
|
| 9 |
-
# logging changes, etc.
|
| 10 |
-
|
| 11 |
-
import os
|
| 12 |
-
import sys
|
| 13 |
-
import argparse
|
| 14 |
-
from os.path import join
|
| 15 |
-
import random
|
| 16 |
-
import datetime
|
| 17 |
-
import time
|
| 18 |
-
import yaml
|
| 19 |
-
from tqdm import tqdm
|
| 20 |
-
import numpy as np
|
| 21 |
-
from datetime import timedelta
|
| 22 |
-
from copy import deepcopy
|
| 23 |
-
from PIL import Image as pil_image
|
| 24 |
-
from pathlib import Path
|
| 25 |
-
import gc
|
| 26 |
-
|
| 27 |
-
import torch
|
| 28 |
-
import torch.nn as nn
|
| 29 |
-
import torch.nn.parallel
|
| 30 |
-
import torch.backends.cudnn as cudnn
|
| 31 |
-
import torch.utils.data
|
| 32 |
-
import torch.optim as optim
|
| 33 |
-
from torch.utils.data.distributed import DistributedSampler
|
| 34 |
-
import torch.distributed as dist
|
| 35 |
-
from torch.utils.data import DataLoader
|
| 36 |
-
|
| 37 |
-
from optimizor.SAM import SAM
|
| 38 |
-
from optimizor.LinearLR import LinearDecayLR
|
| 39 |
-
|
| 40 |
-
from trainer.trainer import Trainer
|
| 41 |
-
from arena.detectors.UCF.detectors import DETECTOR
|
| 42 |
-
from metrics.utils import parse_metric_for_print
|
| 43 |
-
from logger import create_logger, RankFilter
|
| 44 |
-
|
| 45 |
-
from huggingface_hub import hf_hub_download
|
| 46 |
-
|
| 47 |
-
# BitMind imports (not from original Deepfake Bench repo)
|
| 48 |
-
from bitmind.dataset_processing.load_split_data import load_datasets, create_real_fake_datasets
|
| 49 |
-
from bitmind.image_transforms import base_transforms, random_aug_transforms
|
| 50 |
-
from bitmind.constants import DATASET_META, FACE_TRAINING_DATASET_META
|
| 51 |
-
from config.constants import (
|
| 52 |
-
CONFIG_PATH,
|
| 53 |
-
WEIGHTS_DIR,
|
| 54 |
-
HF_REPO,
|
| 55 |
-
BACKBONE_CKPT
|
| 56 |
-
)
|
| 57 |
-
|
| 58 |
-
parser = argparse.ArgumentParser(description='Process some paths.')
|
| 59 |
-
parser.add_argument('--detector_path', type=str, default=CONFIG_PATH, help='path to detector YAML file')
|
| 60 |
-
parser.add_argument('--faces_only', dest='faces_only', action='store_true', default=False)
|
| 61 |
-
parser.add_argument('--no-save_ckpt', dest='save_ckpt', action='store_false', default=True)
|
| 62 |
-
parser.add_argument('--no-save_feat', dest='save_feat', action='store_false', default=True)
|
| 63 |
-
parser.add_argument("--ddp", action='store_true', default=False)
|
| 64 |
-
parser.add_argument('--local_rank', type=int, default=0)
|
| 65 |
-
parser.add_argument('--workers', type=int, default=os.cpu_count() - 1,
|
| 66 |
-
help='number of workers for data loading')
|
| 67 |
-
parser.add_argument('--epochs', type=int, default=None, help='number of training epochs')
|
| 68 |
-
|
| 69 |
-
args = parser.parse_args()
|
| 70 |
-
torch.cuda.set_device(args.local_rank)
|
| 71 |
-
print(f"torch.cuda.device(0): {torch.cuda.device(0)}")
|
| 72 |
-
print(f"torch.cuda.get_device_name(0): {torch.cuda.get_device_name(0)}")
|
| 73 |
-
|
| 74 |
-
def ensure_backbone_is_available(logger,
|
| 75 |
-
weights_dir=WEIGHTS_DIR,
|
| 76 |
-
model_filename=BACKBONE_CKPT,
|
| 77 |
-
hugging_face_repo_name=HF_REPO):
|
| 78 |
-
|
| 79 |
-
destination_path = Path(weights_dir) / Path(model_filename)
|
| 80 |
-
if not destination_path.parent.exists():
|
| 81 |
-
destination_path.parent.mkdir(parents=True, exist_ok=True)
|
| 82 |
-
logger.info(f"Created directory {destination_path.parent}.")
|
| 83 |
-
if not destination_path.exists():
|
| 84 |
-
model_path = hf_hub_download(hugging_face_repo_name, model_filename)
|
| 85 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 86 |
-
model = torch.load(model_path, map_location=device)
|
| 87 |
-
torch.save(model, destination_path)
|
| 88 |
-
del model
|
| 89 |
-
if torch.cuda.is_available():
|
| 90 |
-
torch.cuda.empty_cache()
|
| 91 |
-
logger.info(f"Downloaded backbone {model_filename} to {destination_path}.")
|
| 92 |
-
else:
|
| 93 |
-
logger.info(f"{model_filename} backbone already present at {destination_path}.")
|
| 94 |
-
|
| 95 |
-
def init_seed(config):
|
| 96 |
-
if config['manualSeed'] is None:
|
| 97 |
-
config['manualSeed'] = random.randint(1, 10000)
|
| 98 |
-
random.seed(config['manualSeed'])
|
| 99 |
-
if config['cuda']:
|
| 100 |
-
torch.manual_seed(config['manualSeed'])
|
| 101 |
-
torch.cuda.manual_seed_all(config['manualSeed'])
|
| 102 |
-
|
| 103 |
-
def custom_collate_fn(batch):
|
| 104 |
-
images, labels, source_labels = zip(*batch)
|
| 105 |
-
|
| 106 |
-
images = torch.stack(images, dim=0) # Stack image tensors into a single tensor
|
| 107 |
-
labels = torch.LongTensor(labels)
|
| 108 |
-
source_labels = torch.LongTensor(source_labels)
|
| 109 |
-
|
| 110 |
-
data_dict = {
|
| 111 |
-
'image': images,
|
| 112 |
-
'label': labels,
|
| 113 |
-
'label_spe': source_labels,
|
| 114 |
-
'landmark': None,
|
| 115 |
-
'mask': None
|
| 116 |
-
}
|
| 117 |
-
return data_dict
|
| 118 |
-
|
| 119 |
-
def prepare_datasets(config, logger):
|
| 120 |
-
start_time = log_start_time(logger, "Loading and splitting individual datasets")
|
| 121 |
-
|
| 122 |
-
real_datasets, fake_datasets = load_datasets(dataset_meta=config['dataset_meta'],
|
| 123 |
-
expert=config['faces_only'],
|
| 124 |
-
split_transforms=config['split_transforms'])
|
| 125 |
-
|
| 126 |
-
log_finish_time(logger, "Loading and splitting individual datasets", start_time)
|
| 127 |
-
|
| 128 |
-
start_time = log_start_time(logger, "Creating real fake dataset splits")
|
| 129 |
-
train_dataset, val_dataset, test_dataset = \
|
| 130 |
-
create_real_fake_datasets(real_datasets,
|
| 131 |
-
fake_datasets,
|
| 132 |
-
config['split_transforms']['train']['transform'],
|
| 133 |
-
config['split_transforms']['validation']['transform'],
|
| 134 |
-
config['split_transforms']['test']['transform'],
|
| 135 |
-
source_labels=True)
|
| 136 |
-
|
| 137 |
-
log_finish_time(logger, "Creating real fake dataset splits", start_time)
|
| 138 |
-
|
| 139 |
-
train_loader = torch.utils.data.DataLoader(train_dataset,
|
| 140 |
-
batch_size=config['train_batchSize'],
|
| 141 |
-
shuffle=True,
|
| 142 |
-
num_workers=config['workers'],
|
| 143 |
-
drop_last=True,
|
| 144 |
-
collate_fn=custom_collate_fn)
|
| 145 |
-
val_loader = torch.utils.data.DataLoader(val_dataset,
|
| 146 |
-
batch_size=config['train_batchSize'],
|
| 147 |
-
shuffle=True,
|
| 148 |
-
num_workers=config['workers'],
|
| 149 |
-
drop_last=True,
|
| 150 |
-
collate_fn=custom_collate_fn)
|
| 151 |
-
test_loader = torch.utils.data.DataLoader(test_dataset,
|
| 152 |
-
batch_size=config['train_batchSize'],
|
| 153 |
-
shuffle=True,
|
| 154 |
-
num_workers=config['workers'],
|
| 155 |
-
drop_last=True,
|
| 156 |
-
collate_fn=custom_collate_fn)
|
| 157 |
-
|
| 158 |
-
print(f"Train size: {len(train_loader.dataset)}")
|
| 159 |
-
print(f"Validation size: {len(val_loader.dataset)}")
|
| 160 |
-
print(f"Test size: {len(test_loader.dataset)}")
|
| 161 |
-
|
| 162 |
-
return train_loader, val_loader, test_loader
|
| 163 |
-
|
| 164 |
-
def choose_optimizer(model, config):
|
| 165 |
-
opt_name = config['optimizer']['type']
|
| 166 |
-
if opt_name == 'sgd':
|
| 167 |
-
optimizer = optim.SGD(
|
| 168 |
-
params=model.parameters(),
|
| 169 |
-
lr=config['optimizer'][opt_name]['lr'],
|
| 170 |
-
momentum=config['optimizer'][opt_name]['momentum'],
|
| 171 |
-
weight_decay=config['optimizer'][opt_name]['weight_decay']
|
| 172 |
-
)
|
| 173 |
-
return optimizer
|
| 174 |
-
elif opt_name == 'adam':
|
| 175 |
-
optimizer = optim.Adam(
|
| 176 |
-
params=model.parameters(),
|
| 177 |
-
lr=config['optimizer'][opt_name]['lr'],
|
| 178 |
-
weight_decay=config['optimizer'][opt_name]['weight_decay'],
|
| 179 |
-
betas=(config['optimizer'][opt_name]['beta1'], config['optimizer'][opt_name]['beta2']),
|
| 180 |
-
eps=config['optimizer'][opt_name]['eps'],
|
| 181 |
-
amsgrad=config['optimizer'][opt_name]['amsgrad'],
|
| 182 |
-
)
|
| 183 |
-
return optimizer
|
| 184 |
-
elif opt_name == 'sam':
|
| 185 |
-
optimizer = SAM(
|
| 186 |
-
model.parameters(),
|
| 187 |
-
optim.SGD,
|
| 188 |
-
lr=config['optimizer'][opt_name]['lr'],
|
| 189 |
-
momentum=config['optimizer'][opt_name]['momentum'],
|
| 190 |
-
)
|
| 191 |
-
else:
|
| 192 |
-
raise NotImplementedError('Optimizer {} is not implemented'.format(config['optimizer']))
|
| 193 |
-
return optimizer
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
def choose_scheduler(config, optimizer):
|
| 197 |
-
if config['lr_scheduler'] is None:
|
| 198 |
-
return None
|
| 199 |
-
elif config['lr_scheduler'] == 'step':
|
| 200 |
-
scheduler = optim.lr_scheduler.StepLR(
|
| 201 |
-
optimizer,
|
| 202 |
-
step_size=config['lr_step'],
|
| 203 |
-
gamma=config['lr_gamma'],
|
| 204 |
-
)
|
| 205 |
-
return scheduler
|
| 206 |
-
elif config['lr_scheduler'] == 'cosine':
|
| 207 |
-
scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
| 208 |
-
optimizer,
|
| 209 |
-
T_max=config['lr_T_max'],
|
| 210 |
-
eta_min=config['lr_eta_min'],
|
| 211 |
-
)
|
| 212 |
-
return scheduler
|
| 213 |
-
elif config['lr_scheduler'] == 'linear':
|
| 214 |
-
scheduler = LinearDecayLR(
|
| 215 |
-
optimizer,
|
| 216 |
-
config['nEpochs'],
|
| 217 |
-
int(config['nEpochs']/4),
|
| 218 |
-
)
|
| 219 |
-
else:
|
| 220 |
-
raise NotImplementedError('Scheduler {} is not implemented'.format(config['lr_scheduler']))
|
| 221 |
-
|
| 222 |
-
def choose_metric(config):
|
| 223 |
-
metric_scoring = config['metric_scoring']
|
| 224 |
-
if metric_scoring not in ['eer', 'auc', 'acc', 'ap']:
|
| 225 |
-
raise NotImplementedError('metric {} is not implemented'.format(metric_scoring))
|
| 226 |
-
return metric_scoring
|
| 227 |
-
|
| 228 |
-
def log_start_time(logger, process_name):
|
| 229 |
-
"""Log the start time of a process."""
|
| 230 |
-
start_time = time.time()
|
| 231 |
-
logger.info(f"{process_name} Start Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}")
|
| 232 |
-
return start_time
|
| 233 |
-
|
| 234 |
-
def log_finish_time(logger, process_name, start_time):
|
| 235 |
-
"""Log the finish time and elapsed time of a process."""
|
| 236 |
-
finish_time = time.time()
|
| 237 |
-
elapsed_time = finish_time - start_time
|
| 238 |
-
|
| 239 |
-
# Convert elapsed time into hours, minutes, and seconds
|
| 240 |
-
hours, rem = divmod(elapsed_time, 3600)
|
| 241 |
-
minutes, seconds = divmod(rem, 60)
|
| 242 |
-
|
| 243 |
-
# Log the finish time and elapsed time
|
| 244 |
-
logger.info(f"{process_name} Finish Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(finish_time))}")
|
| 245 |
-
logger.info(f"{process_name} Elapsed Time: {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds")
|
| 246 |
-
|
| 247 |
-
def save_config(config, outputs_dir):
|
| 248 |
-
"""
|
| 249 |
-
Saves a config dictionary as both a pickle file and a YAML file, ensuring only basic types are saved.
|
| 250 |
-
Also, lists like 'mean' and 'std' are saved in flow style (on a single line).
|
| 251 |
-
|
| 252 |
-
Args:
|
| 253 |
-
config (dict): The configuration dictionary to save.
|
| 254 |
-
outputs_dir (str): The directory path where the files will be saved.
|
| 255 |
-
"""
|
| 256 |
-
|
| 257 |
-
def is_basic_type(value):
|
| 258 |
-
"""
|
| 259 |
-
Check if a value is a basic data type that can be saved in YAML.
|
| 260 |
-
Basic types include int, float, str, bool, list, and dict.
|
| 261 |
-
"""
|
| 262 |
-
return isinstance(value, (int, float, str, bool, list, dict, type(None)))
|
| 263 |
-
|
| 264 |
-
def filter_dict(data_dict):
|
| 265 |
-
"""
|
| 266 |
-
Recursively filter out any keys from the dictionary whose values contain non-basic types (e.g., objects).
|
| 267 |
-
"""
|
| 268 |
-
if not isinstance(data_dict, dict):
|
| 269 |
-
return data_dict
|
| 270 |
-
|
| 271 |
-
filtered_dict = {}
|
| 272 |
-
for key, value in data_dict.items():
|
| 273 |
-
if isinstance(value, dict):
|
| 274 |
-
# Recursively filter nested dictionaries
|
| 275 |
-
nested_dict = filter_dict(value)
|
| 276 |
-
if nested_dict: # Only add non-empty dictionaries
|
| 277 |
-
filtered_dict[key] = nested_dict
|
| 278 |
-
elif is_basic_type(value):
|
| 279 |
-
# Add if the value is a basic type
|
| 280 |
-
filtered_dict[key] = value
|
| 281 |
-
else:
|
| 282 |
-
# Skip the key if the value is not a basic type (e.g., an object)
|
| 283 |
-
print(f"Skipping key '{key}' because its value is of type {type(value)}")
|
| 284 |
-
|
| 285 |
-
return filtered_dict
|
| 286 |
-
|
| 287 |
-
def save_dict_to_yaml(data_dict, file_path):
|
| 288 |
-
"""
|
| 289 |
-
Saves a dictionary to a YAML file, excluding any keys where the value is an object or contains an object.
|
| 290 |
-
Additionally, ensures that specific lists (like 'mean' and 'std') are saved in flow style.
|
| 291 |
-
|
| 292 |
-
Args:
|
| 293 |
-
data_dict (dict): The dictionary to save.
|
| 294 |
-
file_path (str): The local file path where the YAML file will be saved.
|
| 295 |
-
"""
|
| 296 |
-
|
| 297 |
-
# Custom representer for lists to force flow style (compact lists)
|
| 298 |
-
class FlowStyleList(list):
|
| 299 |
-
pass
|
| 300 |
-
|
| 301 |
-
def flow_style_list_representer(dumper, data):
|
| 302 |
-
return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=True)
|
| 303 |
-
|
| 304 |
-
yaml.add_representer(FlowStyleList, flow_style_list_representer)
|
| 305 |
-
|
| 306 |
-
# Preprocess specific lists to be in flow style
|
| 307 |
-
if 'mean' in data_dict:
|
| 308 |
-
data_dict['mean'] = FlowStyleList(data_dict['mean'])
|
| 309 |
-
if 'std' in data_dict:
|
| 310 |
-
data_dict['std'] = FlowStyleList(data_dict['std'])
|
| 311 |
-
|
| 312 |
-
try:
|
| 313 |
-
# Filter the dictionary
|
| 314 |
-
filtered_dict = filter_dict(data_dict)
|
| 315 |
-
|
| 316 |
-
# Save the filtered dictionary as YAML
|
| 317 |
-
with open(file_path, 'w') as f:
|
| 318 |
-
yaml.dump(filtered_dict, f, default_flow_style=False) # Save with default block style except for FlowStyleList
|
| 319 |
-
print(f"Filtered dictionary successfully saved to {file_path}")
|
| 320 |
-
except Exception as e:
|
| 321 |
-
print(f"Error saving dictionary to YAML: {e}")
|
| 322 |
-
|
| 323 |
-
# Save as YAML
|
| 324 |
-
save_dict_to_yaml(config, outputs_dir + '/config.yaml')
|
| 325 |
-
|
| 326 |
-
def main():
|
| 327 |
-
torch.cuda.empty_cache()
|
| 328 |
-
gc.collect()
|
| 329 |
-
# parse options and load config
|
| 330 |
-
with open(args.detector_path, 'r') as f:
|
| 331 |
-
config = yaml.safe_load(f)
|
| 332 |
-
with open(os.getcwd() + '/config/train_config.yaml', 'r') as f:
|
| 333 |
-
config2 = yaml.safe_load(f)
|
| 334 |
-
if 'label_dict' in config:
|
| 335 |
-
config2['label_dict']=config['label_dict']
|
| 336 |
-
config.update(config2)
|
| 337 |
-
|
| 338 |
-
config['workers'] = args.workers
|
| 339 |
-
|
| 340 |
-
config['local_rank']=args.local_rank
|
| 341 |
-
if config['dry_run']:
|
| 342 |
-
config['nEpochs'] = 0
|
| 343 |
-
config['save_feat']=False
|
| 344 |
-
|
| 345 |
-
if args.epochs: config['nEpochs'] = args.epochs
|
| 346 |
-
|
| 347 |
-
config['split_transforms'] = {'train': {'name': 'base_transforms',
|
| 348 |
-
'transform': base_transforms},
|
| 349 |
-
'validation': {'name': 'base_transforms',
|
| 350 |
-
'transform': base_transforms},
|
| 351 |
-
'test': {'name': 'base_transforms',
|
| 352 |
-
'transform': base_transforms}}
|
| 353 |
-
config['faces_only'] = args.faces_only
|
| 354 |
-
config['dataset_meta'] = FACE_TRAINING_DATASET_META if config['faces_only'] else DATASET_META
|
| 355 |
-
dataset_names = [item["path"] for datasets in config['dataset_meta'].values() for item in datasets]
|
| 356 |
-
config['train_dataset'] = dataset_names
|
| 357 |
-
config['save_ckpt'] = args.save_ckpt
|
| 358 |
-
config['save_feat'] = args.save_feat
|
| 359 |
-
|
| 360 |
-
config['specific_task_number'] = len(config['dataset_meta']["fake"]) + 1
|
| 361 |
-
|
| 362 |
-
if config['lmdb']:
|
| 363 |
-
config['dataset_json_folder'] = 'preprocessing/dataset_json_v3'
|
| 364 |
-
|
| 365 |
-
# create logger
|
| 366 |
-
timenow=datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
| 367 |
-
|
| 368 |
-
outputs_dir = os.path.join(
|
| 369 |
-
config['log_dir'],
|
| 370 |
-
config['model_name'] + '_' + timenow
|
| 371 |
-
)
|
| 372 |
-
|
| 373 |
-
os.makedirs(outputs_dir, exist_ok=True)
|
| 374 |
-
logger = create_logger(os.path.join(outputs_dir, 'training.log'))
|
| 375 |
-
config['log_dir'] = outputs_dir
|
| 376 |
-
logger.info('Save log to {}'.format(outputs_dir))
|
| 377 |
-
|
| 378 |
-
config['ddp']= args.ddp
|
| 379 |
-
|
| 380 |
-
# init seed
|
| 381 |
-
init_seed(config)
|
| 382 |
-
|
| 383 |
-
# set cudnn benchmark if needed
|
| 384 |
-
if config['cudnn']:
|
| 385 |
-
cudnn.benchmark = True
|
| 386 |
-
if config['ddp']:
|
| 387 |
-
# dist.init_process_group(backend='gloo')
|
| 388 |
-
dist.init_process_group(
|
| 389 |
-
backend='nccl',
|
| 390 |
-
timeout=timedelta(minutes=30)
|
| 391 |
-
)
|
| 392 |
-
logger.addFilter(RankFilter(0))
|
| 393 |
-
|
| 394 |
-
ensure_backbone_is_available(logger=logger,
|
| 395 |
-
model_filename=config['pretrained'].split('/')[-1],
|
| 396 |
-
hugging_face_repo_name='bitmind/' + config['model_name'])
|
| 397 |
-
|
| 398 |
-
# prepare the model (detector)
|
| 399 |
-
model_class = DETECTOR[config['model_name']]
|
| 400 |
-
model = model_class(config)
|
| 401 |
-
|
| 402 |
-
# prepare the optimizer
|
| 403 |
-
optimizer = choose_optimizer(model, config)
|
| 404 |
-
|
| 405 |
-
# prepare the scheduler
|
| 406 |
-
scheduler = choose_scheduler(config, optimizer)
|
| 407 |
-
|
| 408 |
-
# prepare the metric
|
| 409 |
-
metric_scoring = choose_metric(config)
|
| 410 |
-
|
| 411 |
-
# prepare the trainer
|
| 412 |
-
trainer = Trainer(config, model, optimizer, scheduler, logger, metric_scoring)
|
| 413 |
-
|
| 414 |
-
# prepare the data loaders
|
| 415 |
-
train_loader, val_loader, test_loader = prepare_datasets(config, logger)
|
| 416 |
-
|
| 417 |
-
# print configuration
|
| 418 |
-
logger.info("--------------- Configuration ---------------")
|
| 419 |
-
params_string = "Parameters: \n"
|
| 420 |
-
for key, value in config.items():
|
| 421 |
-
params_string += "{}: {}".format(key, value) + "\n"
|
| 422 |
-
logger.info(params_string)
|
| 423 |
-
|
| 424 |
-
# save training configs
|
| 425 |
-
save_config(config, outputs_dir)
|
| 426 |
-
|
| 427 |
-
# start training
|
| 428 |
-
start_time = log_start_time(logger, "Training")
|
| 429 |
-
for epoch in range(config['start_epoch'], config['nEpochs'] + 1):
|
| 430 |
-
trainer.model.epoch = epoch
|
| 431 |
-
best_metric = trainer.train_epoch(
|
| 432 |
-
epoch,
|
| 433 |
-
train_data_loader=train_loader,
|
| 434 |
-
validation_data_loaders={'val':val_loader}
|
| 435 |
-
)
|
| 436 |
-
if best_metric is not None:
|
| 437 |
-
logger.info(f"===> Epoch[{epoch}] end with validation {metric_scoring}: {parse_metric_for_print(best_metric)}!")
|
| 438 |
-
logger.info("Stop Training on best Validation metric {}".format(parse_metric_for_print(best_metric)))
|
| 439 |
-
log_finish_time(logger, "Training", start_time)
|
| 440 |
-
|
| 441 |
-
# test
|
| 442 |
-
start_time = log_start_time(logger, "Test")
|
| 443 |
-
trainer.eval(eval_data_loaders={'test':test_loader}, eval_stage="test")
|
| 444 |
-
log_finish_time(logger, "Test", start_time)
|
| 445 |
-
|
| 446 |
-
# update
|
| 447 |
-
if 'svdd' in config['model_name']:
|
| 448 |
-
model.update_R(epoch)
|
| 449 |
-
if scheduler is not None:
|
| 450 |
-
scheduler.step()
|
| 451 |
-
|
| 452 |
-
# close the tensorboard writers
|
| 453 |
-
for writer in trainer.writers.values():
|
| 454 |
-
writer.close()
|
| 455 |
-
|
| 456 |
-
torch.cuda.empty_cache()
|
| 457 |
-
gc.collect()
|
| 458 |
-
|
| 459 |
-
if __name__ == '__main__':
|
| 460 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_deps/trainer/trainer.py
DELETED
|
@@ -1,441 +0,0 @@
|
|
| 1 |
-
# This script was adapted from the DeepfakeBench training code,
|
| 2 |
-
# originally authored by Zhiyuan Yan (zhiyuanyan@link.cuhk.edu.cn)
|
| 3 |
-
|
| 4 |
-
# Original: https://github.com/SCLBD/DeepfakeBench/blob/main/training/train.py
|
| 5 |
-
|
| 6 |
-
import os
|
| 7 |
-
import sys
|
| 8 |
-
current_file_path = os.path.abspath(__file__)
|
| 9 |
-
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
|
| 10 |
-
project_root_dir = os.path.dirname(parent_dir)
|
| 11 |
-
sys.path.append(parent_dir)
|
| 12 |
-
sys.path.append(project_root_dir)
|
| 13 |
-
|
| 14 |
-
import pickle
|
| 15 |
-
import datetime
|
| 16 |
-
import logging
|
| 17 |
-
import numpy as np
|
| 18 |
-
from copy import deepcopy
|
| 19 |
-
from collections import defaultdict
|
| 20 |
-
from tqdm import tqdm
|
| 21 |
-
import time
|
| 22 |
-
import torch
|
| 23 |
-
import torch.nn as nn
|
| 24 |
-
import torch.nn.functional as F
|
| 25 |
-
import torch.optim as optim
|
| 26 |
-
from torch.nn import DataParallel
|
| 27 |
-
from torch.utils.tensorboard import SummaryWriter
|
| 28 |
-
from metrics.base_metrics_class import Recorder
|
| 29 |
-
from torch.optim.swa_utils import AveragedModel, SWALR
|
| 30 |
-
from torch import distributed as dist
|
| 31 |
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 32 |
-
from sklearn import metrics
|
| 33 |
-
from metrics.utils import get_test_metrics
|
| 34 |
-
|
| 35 |
-
FFpp_pool=['FaceForensics++','FF-DF','FF-F2F','FF-FS','FF-NT']#
|
| 36 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
class Trainer(object):
|
| 40 |
-
def __init__(
|
| 41 |
-
self,
|
| 42 |
-
config,
|
| 43 |
-
model,
|
| 44 |
-
optimizer,
|
| 45 |
-
scheduler,
|
| 46 |
-
logger,
|
| 47 |
-
metric_scoring='auc',
|
| 48 |
-
swa_model=None
|
| 49 |
-
):
|
| 50 |
-
# check if all the necessary components are implemented
|
| 51 |
-
if config is None or model is None or optimizer is None or logger is None:
|
| 52 |
-
raise ValueError("config, model, optimizier, logger, and tensorboard writer must be implemented")
|
| 53 |
-
|
| 54 |
-
self.config = config
|
| 55 |
-
self.model = model
|
| 56 |
-
self.optimizer = optimizer
|
| 57 |
-
self.scheduler = scheduler
|
| 58 |
-
self.swa_model = swa_model
|
| 59 |
-
self.writers = {} # dict to maintain different tensorboard writers for each dataset and metric
|
| 60 |
-
self.logger = logger
|
| 61 |
-
self.metric_scoring = metric_scoring
|
| 62 |
-
# maintain the best metric of all epochs
|
| 63 |
-
self.best_metrics_all_time = defaultdict(
|
| 64 |
-
lambda: defaultdict(lambda: float('-inf')
|
| 65 |
-
if self.metric_scoring != 'eer' else float('inf'))
|
| 66 |
-
)
|
| 67 |
-
self.speed_up() # move model to GPU
|
| 68 |
-
|
| 69 |
-
# create directory path
|
| 70 |
-
self.log_dir = self.config['log_dir']
|
| 71 |
-
print("Making dir ", self.log_dir)
|
| 72 |
-
os.makedirs(self.log_dir, exist_ok=True)
|
| 73 |
-
|
| 74 |
-
def get_writer(self, phase, dataset_key, metric_key):
|
| 75 |
-
phase = phase.split('/')[-1]
|
| 76 |
-
dataset_key = dataset_key.split('/')[-1]
|
| 77 |
-
metric_key = metric_key.split('/')[-1]
|
| 78 |
-
writer_key = f"{phase}-{dataset_key}-{metric_key}"
|
| 79 |
-
if writer_key not in self.writers:
|
| 80 |
-
# update directory path
|
| 81 |
-
writer_path = os.path.join(
|
| 82 |
-
self.log_dir,
|
| 83 |
-
phase,
|
| 84 |
-
dataset_key,
|
| 85 |
-
metric_key,
|
| 86 |
-
"metric_board"
|
| 87 |
-
)
|
| 88 |
-
os.makedirs(writer_path, exist_ok=True)
|
| 89 |
-
# update writers dictionary
|
| 90 |
-
self.writers[writer_key] = SummaryWriter(writer_path)
|
| 91 |
-
return self.writers[writer_key]
|
| 92 |
-
|
| 93 |
-
def speed_up(self):
|
| 94 |
-
self.model.to(device)
|
| 95 |
-
self.model.device = device
|
| 96 |
-
if self.config['ddp'] == True:
|
| 97 |
-
num_gpus = torch.cuda.device_count()
|
| 98 |
-
print(f'avai gpus: {num_gpus}')
|
| 99 |
-
# local_rank=[i for i in range(0,num_gpus)]
|
| 100 |
-
self.model = DDP(self.model, device_ids=[self.config['local_rank']],find_unused_parameters=True, output_device=self.config['local_rank'])
|
| 101 |
-
#self.optimizer = nn.DataParallel(self.optimizer, device_ids=[int(os.environ['LOCAL_RANK'])])
|
| 102 |
-
|
| 103 |
-
def setTrain(self):
|
| 104 |
-
self.model.train()
|
| 105 |
-
self.train = True
|
| 106 |
-
|
| 107 |
-
def setEval(self):
|
| 108 |
-
self.model.eval()
|
| 109 |
-
self.train = False
|
| 110 |
-
|
| 111 |
-
def load_ckpt(self, model_path):
|
| 112 |
-
if os.path.isfile(model_path):
|
| 113 |
-
saved = torch.load(model_path, map_location='cpu')
|
| 114 |
-
suffix = model_path.split('.')[-1]
|
| 115 |
-
if suffix == 'p':
|
| 116 |
-
self.model.load_state_dict(saved.state_dict())
|
| 117 |
-
else:
|
| 118 |
-
self.model.load_state_dict(saved)
|
| 119 |
-
self.logger.info('Model found in {}'.format(model_path))
|
| 120 |
-
else:
|
| 121 |
-
raise NotImplementedError(
|
| 122 |
-
"=> no model found at '{}'".format(model_path))
|
| 123 |
-
|
| 124 |
-
def save_ckpt(self, phase, dataset_key,ckpt_info=None):
|
| 125 |
-
save_dir = self.log_dir
|
| 126 |
-
os.makedirs(save_dir, exist_ok=True)
|
| 127 |
-
ckpt_name = f"ckpt_best.pth"
|
| 128 |
-
save_path = os.path.join(save_dir, ckpt_name)
|
| 129 |
-
if self.config['ddp'] == True:
|
| 130 |
-
torch.save(self.model.state_dict(), save_path)
|
| 131 |
-
else:
|
| 132 |
-
if 'svdd' in self.config['model_name']:
|
| 133 |
-
torch.save({'R': self.model.R,
|
| 134 |
-
'c': self.model.c,
|
| 135 |
-
'state_dict': self.model.state_dict(),}, save_path)
|
| 136 |
-
else:
|
| 137 |
-
torch.save(self.model.state_dict(), save_path)
|
| 138 |
-
self.logger.info(f"Checkpoint saved to {save_path}, current ckpt is {ckpt_info}")
|
| 139 |
-
|
| 140 |
-
def save_swa_ckpt(self):
|
| 141 |
-
save_dir = self.log_dir
|
| 142 |
-
os.makedirs(save_dir, exist_ok=True)
|
| 143 |
-
ckpt_name = f"swa.pth"
|
| 144 |
-
save_path = os.path.join(save_dir, ckpt_name)
|
| 145 |
-
torch.save(self.swa_model.state_dict(), save_path)
|
| 146 |
-
self.logger.info(f"SWA Checkpoint saved to {save_path}")
|
| 147 |
-
|
| 148 |
-
def save_feat(self, phase, fea, dataset_key):
|
| 149 |
-
save_dir = os.path.join(self.log_dir, phase, dataset_key)
|
| 150 |
-
os.makedirs(save_dir, exist_ok=True)
|
| 151 |
-
features = fea
|
| 152 |
-
feat_name = f"feat_best.npy"
|
| 153 |
-
save_path = os.path.join(save_dir, feat_name)
|
| 154 |
-
np.save(save_path, features)
|
| 155 |
-
self.logger.info(f"Feature saved to {save_path}")
|
| 156 |
-
|
| 157 |
-
def save_data_dict(self, phase, data_dict, dataset_key):
|
| 158 |
-
save_dir = os.path.join(self.log_dir, phase, dataset_key)
|
| 159 |
-
os.makedirs(save_dir, exist_ok=True)
|
| 160 |
-
file_path = os.path.join(save_dir, f'data_dict_{phase}.pickle')
|
| 161 |
-
with open(file_path, 'wb') as file:
|
| 162 |
-
pickle.dump(data_dict, file)
|
| 163 |
-
self.logger.info(f"data_dict saved to {file_path}")
|
| 164 |
-
|
| 165 |
-
def save_metrics(self, phase, metric_one_dataset, dataset_key):
|
| 166 |
-
save_dir = os.path.join(self.log_dir, phase, dataset_key)
|
| 167 |
-
os.makedirs(save_dir, exist_ok=True)
|
| 168 |
-
file_path = os.path.join(save_dir, 'metric_dict_best.pickle')
|
| 169 |
-
with open(file_path, 'wb') as file:
|
| 170 |
-
pickle.dump(metric_one_dataset, file)
|
| 171 |
-
self.logger.info(f"Metrics saved to {file_path}")
|
| 172 |
-
|
| 173 |
-
def train_step(self,data_dict):
|
| 174 |
-
if self.config['optimizer']['type']=='sam':
|
| 175 |
-
for i in range(2):
|
| 176 |
-
predictions = self.model(data_dict)
|
| 177 |
-
losses = self.model.get_losses(data_dict, predictions)
|
| 178 |
-
if i == 0:
|
| 179 |
-
pred_first = predictions
|
| 180 |
-
losses_first = losses
|
| 181 |
-
self.optimizer.zero_grad()
|
| 182 |
-
losses['overall'].backward()
|
| 183 |
-
if i == 0:
|
| 184 |
-
self.optimizer.first_step(zero_grad=True)
|
| 185 |
-
else:
|
| 186 |
-
self.optimizer.second_step(zero_grad=True)
|
| 187 |
-
return losses_first, pred_first
|
| 188 |
-
else:
|
| 189 |
-
predictions = self.model(data_dict)
|
| 190 |
-
if type(self.model) is DDP:
|
| 191 |
-
losses = self.model.module.get_losses(data_dict, predictions)
|
| 192 |
-
else:
|
| 193 |
-
losses = self.model.get_losses(data_dict, predictions)
|
| 194 |
-
self.optimizer.zero_grad()
|
| 195 |
-
losses['overall'].backward()
|
| 196 |
-
self.optimizer.step()
|
| 197 |
-
|
| 198 |
-
return losses,predictions
|
| 199 |
-
|
| 200 |
-
def train_epoch(
|
| 201 |
-
self,
|
| 202 |
-
epoch,
|
| 203 |
-
train_data_loader,
|
| 204 |
-
validation_data_loaders=None
|
| 205 |
-
):
|
| 206 |
-
|
| 207 |
-
self.logger.info("===> Epoch[{}] start!".format(epoch))
|
| 208 |
-
if epoch>=1:
|
| 209 |
-
times_per_epoch = 2
|
| 210 |
-
else:
|
| 211 |
-
times_per_epoch = 1
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
#times_per_epoch=4
|
| 215 |
-
validation_step = len(train_data_loader) // times_per_epoch # validate 10 times per epoch
|
| 216 |
-
step_cnt = epoch * len(train_data_loader)
|
| 217 |
-
|
| 218 |
-
# define training recorder
|
| 219 |
-
train_recorder_loss = defaultdict(Recorder)
|
| 220 |
-
train_recorder_metric = defaultdict(Recorder)
|
| 221 |
-
|
| 222 |
-
for iteration, data_dict in tqdm(enumerate(train_data_loader),total=len(train_data_loader)):
|
| 223 |
-
self.setTrain()
|
| 224 |
-
# more elegant and more scalable way of moving data to GPU
|
| 225 |
-
for key in data_dict.keys():
|
| 226 |
-
if data_dict[key]!=None and key!='name':
|
| 227 |
-
data_dict[key]=data_dict[key].cuda()
|
| 228 |
-
|
| 229 |
-
losses, predictions=self.train_step(data_dict)
|
| 230 |
-
# update learning rate
|
| 231 |
-
|
| 232 |
-
if 'SWA' in self.config and self.config['SWA'] and epoch>self.config['swa_start']:
|
| 233 |
-
self.swa_model.update_parameters(self.model)
|
| 234 |
-
|
| 235 |
-
# compute training metric for each batch data
|
| 236 |
-
if type(self.model) is DDP:
|
| 237 |
-
batch_metrics = self.model.module.get_train_metrics(data_dict, predictions)
|
| 238 |
-
else:
|
| 239 |
-
batch_metrics = self.model.get_train_metrics(data_dict, predictions)
|
| 240 |
-
|
| 241 |
-
# store data by recorder
|
| 242 |
-
## store metric
|
| 243 |
-
for name, value in batch_metrics.items():
|
| 244 |
-
train_recorder_metric[name].update(value)
|
| 245 |
-
## store loss
|
| 246 |
-
for name, value in losses.items():
|
| 247 |
-
train_recorder_loss[name].update(value)
|
| 248 |
-
|
| 249 |
-
# run tensorboard to visualize the training process
|
| 250 |
-
if iteration % 300 == 0 and self.config['local_rank']==0:
|
| 251 |
-
if self.config['SWA'] and (epoch>self.config['swa_start'] or self.config['dry_run']):
|
| 252 |
-
self.scheduler.step()
|
| 253 |
-
# info for loss
|
| 254 |
-
loss_str = f"Iter: {step_cnt} "
|
| 255 |
-
for k, v in train_recorder_loss.items():
|
| 256 |
-
v_avg = v.average()
|
| 257 |
-
if v_avg == None:
|
| 258 |
-
loss_str += f"training-loss, {k}: not calculated"
|
| 259 |
-
continue
|
| 260 |
-
loss_str += f"training-loss, {k}: {v_avg} "
|
| 261 |
-
# tensorboard-1. loss
|
| 262 |
-
processed_train_dataset = [dataset.split('/')[-1] for dataset in self.config['train_dataset']]
|
| 263 |
-
processed_train_dataset = ','.join(processed_train_dataset)
|
| 264 |
-
writer = self.get_writer('train', processed_train_dataset, k)
|
| 265 |
-
writer.add_scalar(f'train_loss/{k}', v_avg, global_step=step_cnt)
|
| 266 |
-
self.logger.info(loss_str)
|
| 267 |
-
# info for metric
|
| 268 |
-
metric_str = f"Iter: {step_cnt} "
|
| 269 |
-
for k, v in train_recorder_metric.items():
|
| 270 |
-
v_avg = v.average()
|
| 271 |
-
if v_avg == None:
|
| 272 |
-
metric_str += f"training-metric, {k}: not calculated "
|
| 273 |
-
continue
|
| 274 |
-
metric_str += f"training-metric, {k}: {v_avg} "
|
| 275 |
-
# tensorboard-2. metric
|
| 276 |
-
processed_train_dataset = [dataset.split('/')[-1] for dataset in self.config['train_dataset']]
|
| 277 |
-
processed_train_dataset = ','.join(processed_train_dataset)
|
| 278 |
-
writer = self.get_writer('train', processed_train_dataset, k)
|
| 279 |
-
writer.add_scalar(f'train_metric/{k}', v_avg, global_step=step_cnt)
|
| 280 |
-
self.logger.info(metric_str)
|
| 281 |
-
|
| 282 |
-
# clear recorder.
|
| 283 |
-
# Note we only consider the current 300 samples for computing batch-level loss/metric
|
| 284 |
-
for name, recorder in train_recorder_loss.items(): # clear loss recorder
|
| 285 |
-
recorder.clear()
|
| 286 |
-
for name, recorder in train_recorder_metric.items(): # clear metric recorder
|
| 287 |
-
recorder.clear()
|
| 288 |
-
|
| 289 |
-
# run validation
|
| 290 |
-
if (step_cnt+1) % validation_step == 0:
|
| 291 |
-
if validation_data_loaders is not None and ((not self.config['ddp']) or (self.config['ddp'] and dist.get_rank() == 0)):
|
| 292 |
-
self.logger.info("===> Validation start!")
|
| 293 |
-
validation_best_metric = self.eval(
|
| 294 |
-
eval_data_loaders=validation_data_loaders,
|
| 295 |
-
eval_stage="validation",
|
| 296 |
-
step=step_cnt,
|
| 297 |
-
epoch=epoch,
|
| 298 |
-
iteration=iteration
|
| 299 |
-
)
|
| 300 |
-
else:
|
| 301 |
-
validation_best_metric = None
|
| 302 |
-
|
| 303 |
-
step_cnt += 1
|
| 304 |
-
|
| 305 |
-
for key in data_dict.keys():
|
| 306 |
-
if data_dict[key]!=None and key!='name':
|
| 307 |
-
data_dict[key]=data_dict[key].cpu()
|
| 308 |
-
return validation_best_metric
|
| 309 |
-
|
| 310 |
-
def get_respect_acc(self,prob,label):
|
| 311 |
-
pred = np.where(prob > 0.5, 1, 0)
|
| 312 |
-
judge = (pred == label)
|
| 313 |
-
zero_num = len(label) - np.count_nonzero(label)
|
| 314 |
-
acc_fake = np.count_nonzero(judge[zero_num:]) / len(judge[zero_num:])
|
| 315 |
-
acc_real = np.count_nonzero(judge[:zero_num]) / len(judge[:zero_num])
|
| 316 |
-
return acc_real,acc_fake
|
| 317 |
-
|
| 318 |
-
def eval_one_dataset(self, data_loader):
|
| 319 |
-
# define eval recorder
|
| 320 |
-
eval_recorder_loss = defaultdict(Recorder)
|
| 321 |
-
prediction_lists = []
|
| 322 |
-
feature_lists=[]
|
| 323 |
-
label_lists = []
|
| 324 |
-
for i, data_dict in tqdm(enumerate(data_loader),total=len(data_loader)):
|
| 325 |
-
# get data
|
| 326 |
-
if 'label_spe' in data_dict:
|
| 327 |
-
data_dict.pop('label_spe') # remove the specific label
|
| 328 |
-
data_dict['label'] = torch.where(data_dict['label']!=0, 1, 0) # fix the label to 0 and 1 only
|
| 329 |
-
# move data to GPU elegantly
|
| 330 |
-
for key in data_dict.keys():
|
| 331 |
-
if data_dict[key]!=None:
|
| 332 |
-
data_dict[key]=data_dict[key].cuda()
|
| 333 |
-
# model forward without considering gradient computation
|
| 334 |
-
predictions = self.inference(data_dict) #dict with keys cls, feat
|
| 335 |
-
|
| 336 |
-
label_lists += list(data_dict['label'].cpu().detach().numpy())
|
| 337 |
-
# Get the predicted class for each sample in the batch
|
| 338 |
-
_, predicted_classes = torch.max(predictions['cls'], dim=1)
|
| 339 |
-
# Convert the predicted class indices to a list and add to prediction_lists
|
| 340 |
-
prediction_lists += predicted_classes.cpu().detach().numpy().tolist()
|
| 341 |
-
feature_lists += list(predictions['feat'].cpu().detach().numpy())
|
| 342 |
-
if type(self.model) is not AveragedModel:
|
| 343 |
-
# compute all losses for each batch data
|
| 344 |
-
if type(self.model) is DDP:
|
| 345 |
-
losses = self.model.module.get_losses(data_dict, predictions)
|
| 346 |
-
else:
|
| 347 |
-
losses = self.model.get_losses(data_dict, predictions)
|
| 348 |
-
|
| 349 |
-
# store data by recorder
|
| 350 |
-
for name, value in losses.items():
|
| 351 |
-
eval_recorder_loss[name].update(value)
|
| 352 |
-
return eval_recorder_loss, np.array(prediction_lists), np.array(label_lists),np.array(feature_lists)
|
| 353 |
-
|
| 354 |
-
def save_best(self,epoch,iteration,step,losses_one_dataset_recorder,key,metric_one_dataset,eval_stage):
|
| 355 |
-
best_metric = self.best_metrics_all_time[key].get(self.metric_scoring,
|
| 356 |
-
float('-inf') if self.metric_scoring != 'eer' else float(
|
| 357 |
-
'inf'))
|
| 358 |
-
# Check if the current score is an improvement
|
| 359 |
-
improved = (metric_one_dataset[self.metric_scoring] > best_metric) if self.metric_scoring != 'eer' else (
|
| 360 |
-
metric_one_dataset[self.metric_scoring] < best_metric)
|
| 361 |
-
if improved:
|
| 362 |
-
# Update the best metric
|
| 363 |
-
self.best_metrics_all_time[key][self.metric_scoring] = metric_one_dataset[self.metric_scoring]
|
| 364 |
-
if key == 'avg':
|
| 365 |
-
self.best_metrics_all_time[key]['dataset_dict'] = metric_one_dataset['dataset_dict']
|
| 366 |
-
# Save checkpoint, feature, and metrics if specified in config
|
| 367 |
-
if eval_stage=='validation' and self.config['save_ckpt'] and key not in FFpp_pool:
|
| 368 |
-
self.save_ckpt(eval_stage, key, f"{epoch}+{iteration}")
|
| 369 |
-
self.save_metrics(eval_stage, metric_one_dataset, key)
|
| 370 |
-
if losses_one_dataset_recorder is not None:
|
| 371 |
-
# info for each dataset
|
| 372 |
-
loss_str = f"dataset: {key} step: {step} "
|
| 373 |
-
for k, v in losses_one_dataset_recorder.items():
|
| 374 |
-
writer = self.get_writer(eval_stage, key, k)
|
| 375 |
-
v_avg = v.average()
|
| 376 |
-
if v_avg == None:
|
| 377 |
-
print(f'{k} is not calculated')
|
| 378 |
-
continue
|
| 379 |
-
# tensorboard-1. loss
|
| 380 |
-
writer.add_scalar(f'{eval_stage}_losses/{k}', v_avg, global_step=step)
|
| 381 |
-
loss_str += f"{eval_stage}-loss, {k}: {v_avg} "
|
| 382 |
-
self.logger.info(loss_str)
|
| 383 |
-
# tqdm.write(loss_str)
|
| 384 |
-
metric_str = f"dataset: {key} step: {step} "
|
| 385 |
-
for k, v in metric_one_dataset.items():
|
| 386 |
-
if k == 'pred' or k == 'label' or k=='dataset_dict':
|
| 387 |
-
continue
|
| 388 |
-
metric_str += f"{eval_stage}-metric, {k}: {v} "
|
| 389 |
-
# tensorboard-2. metric
|
| 390 |
-
writer = self.get_writer(eval_stage, key, k)
|
| 391 |
-
writer.add_scalar(f'{eval_stage}_metrics/{k}', v, global_step=step)
|
| 392 |
-
if 'pred' in metric_one_dataset:
|
| 393 |
-
acc_real, acc_fake = self.get_respect_acc(metric_one_dataset['pred'], metric_one_dataset['label'])
|
| 394 |
-
metric_str += f'{eval_stage}-metric, acc_real:{acc_real}; acc_fake:{acc_fake}'
|
| 395 |
-
writer.add_scalar(f'{eval_stage}_metrics/acc_real', acc_real, global_step=step)
|
| 396 |
-
writer.add_scalar(f'{eval_stage}_metrics/acc_fake', acc_fake, global_step=step)
|
| 397 |
-
self.logger.info(metric_str)
|
| 398 |
-
|
| 399 |
-
def eval(self, eval_data_loaders, eval_stage, step=None, epoch=None, iteration=None):
|
| 400 |
-
# set model to eval mode
|
| 401 |
-
self.setEval()
|
| 402 |
-
|
| 403 |
-
# define eval recorder
|
| 404 |
-
losses_all_datasets = {}
|
| 405 |
-
metrics_all_datasets = {}
|
| 406 |
-
best_metrics_per_dataset = defaultdict(dict) # best metric for each dataset, for each metric
|
| 407 |
-
avg_metric = {'acc': 0, 'auc': 0, 'eer': 0, 'ap': 0,'dataset_dict':{}} #'video_auc': 0
|
| 408 |
-
keys = eval_data_loaders.keys()
|
| 409 |
-
for key in keys:
|
| 410 |
-
# compute loss for each dataset
|
| 411 |
-
losses_one_dataset_recorder, predictions_nps, label_nps, feature_nps = self.eval_one_dataset(eval_data_loaders[key])
|
| 412 |
-
losses_all_datasets[key] = losses_one_dataset_recorder
|
| 413 |
-
metric_one_dataset=get_test_metrics(y_pred=predictions_nps,y_true=label_nps, logger=self.logger)
|
| 414 |
-
|
| 415 |
-
for metric_name, value in metric_one_dataset.items():
|
| 416 |
-
if metric_name in avg_metric:
|
| 417 |
-
avg_metric[metric_name]+=value
|
| 418 |
-
avg_metric['dataset_dict'][key] = metric_one_dataset[self.metric_scoring]
|
| 419 |
-
if type(self.model) is AveragedModel:
|
| 420 |
-
metric_str = f"Iter Final for SWA: "
|
| 421 |
-
for k, v in metric_one_dataset.items():
|
| 422 |
-
metric_str += f"{eval_stage}-metric, {k}: {v} "
|
| 423 |
-
self.logger.info(metric_str)
|
| 424 |
-
continue
|
| 425 |
-
self.save_best(epoch,iteration,step,losses_one_dataset_recorder,key,metric_one_dataset,eval_stage)
|
| 426 |
-
|
| 427 |
-
if len(keys)>0 and self.config.get('save_avg',False):
|
| 428 |
-
# calculate avg value
|
| 429 |
-
for key in avg_metric:
|
| 430 |
-
if key != 'dataset_dict':
|
| 431 |
-
avg_metric[key] /= len(keys)
|
| 432 |
-
self.save_best(epoch, iteration, step, None, 'avg', avg_metric, eval_stage)
|
| 433 |
-
|
| 434 |
-
self.logger.info(f'===> {eval_stage} Done!')
|
| 435 |
-
return self.best_metrics_all_time # return all types of mean metrics for determining the best ckpt
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
@torch.no_grad()
|
| 439 |
-
def inference(self, data_dict):
|
| 440 |
-
predictions = self.model(data_dict, inference=True)
|
| 441 |
-
return predictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|