Upload 42 files
Browse files- config/__init__.py +7 -0
- config/__pycache__/__init__.cpython-310.pyc +0 -0
- config/__pycache__/constants.cpython-310.pyc +0 -0
- config/constants.py +15 -0
- config/pretrained_config.yaml +94 -0
- config/pretrained_face_config.yaml +94 -0
- config/train_config.yaml +9 -0
- config/ucf.yaml +73 -0
- config/xception.yaml +86 -0
- detectors/__init__.py +11 -0
- detectors/__pycache__/__init__.cpython-310.pyc +0 -0
- detectors/__pycache__/base_detector.cpython-310.pyc +0 -0
- detectors/__pycache__/ucf_detector.cpython-310.pyc +0 -0
- detectors/base_detector.py +71 -0
- detectors/ucf_detector.py +472 -0
- loss/__init__.py +13 -0
- loss/__pycache__/__init__.cpython-310.pyc +0 -0
- loss/__pycache__/abstract_loss_func.cpython-310.pyc +0 -0
- loss/__pycache__/contrastive_regularization.cpython-310.pyc +0 -0
- loss/__pycache__/cross_entropy_loss.cpython-310.pyc +0 -0
- loss/__pycache__/l1_loss.cpython-310.pyc +0 -0
- loss/abstract_loss_func.py +17 -0
- loss/contrastive_regularization.py +78 -0
- loss/cross_entropy_loss.py +26 -0
- loss/l1_loss.py +19 -0
- metrics/__init__.py +7 -0
- metrics/__pycache__/__init__.cpython-310.pyc +0 -0
- metrics/__pycache__/base_metrics_class.cpython-310.pyc +0 -0
- metrics/__pycache__/registry.cpython-310.pyc +0 -0
- metrics/base_metrics_class.py +205 -0
- metrics/registry.py +20 -0
- metrics/utils.py +88 -0
- networks/__init__.py +11 -0
- networks/__pycache__/__init__.cpython-310.pyc +0 -0
- networks/__pycache__/xception.cpython-310.pyc +0 -0
- networks/xception.py +285 -0
- optimizor/LinearLR.py +20 -0
- optimizor/SAM.py +77 -0
- trainer/trainer.py +441 -0
config/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Make sibling packages importable when ``config`` is imported.

Appends the package's parent directory and the project root to
``sys.path`` so modules such as ``detectors``, ``networks`` and ``loss``
can be imported with bare names regardless of the caller's working
directory.
"""
import os
import sys

# Absolute path of this __init__.py file.
current_file_path = os.path.abspath(__file__)
# Directory that contains the ``config`` package (the UCF base dir).
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
# One level above that: the project root.
project_root_dir = os.path.dirname(parent_dir)

# Append only when missing: several packages in this project run the same
# boilerplate, and unconditional appends grow sys.path with duplicates.
for _path in (parent_dir, project_root_dir):
    if _path not in sys.path:
        sys.path.append(_path)
config/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (350 Bytes). View file
|
|
|
config/__pycache__/constants.cpython-310.pyc
ADDED
|
Binary file (543 Bytes). View file
|
|
|
config/constants.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os

# Path to the directory containing the constants.py file
# (i.e. the ``config`` package directory).
CONFIGS_DIR = os.path.dirname(os.path.abspath(__file__))

# The base directory for UCF-related files, i.e., UCF directory
UCF_BASE_PATH = os.path.abspath(os.path.join(CONFIGS_DIR, ".."))  # Points to bitmind-subnet/base_miner/UCF/

# Absolute paths for the required files and directories
CONFIG_PATH = os.path.join(CONFIGS_DIR, "ucf.yaml")  # Path to the ucf.yaml file
WEIGHTS_DIR = os.path.join(UCF_BASE_PATH, "weights/")  # Path to pretrained weights directory

# Hugging Face repository that hosts the pretrained UCF weights.
HF_REPO = "bitmind/ucf"
# Checkpoint filename for the pretrained Xception backbone inside HF_REPO.
BACKBONE_CKPT = "xception_best.pth"

# dlib 81-point facial landmark model, resolved two levels above
# UCF_BASE_PATH -- assumes the bitmind-subnet repo layout; TODO confirm.
DLIB_FACE_PREDICTOR_PATH = os.path.abspath(os.path.join(UCF_BASE_PATH, "../../utils/dlib_tools/shape_predictor_81_face_landmarks.dat"))
config/pretrained_config.yaml
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
SWA: false
|
| 2 |
+
backbone_config:
|
| 3 |
+
dropout: false
|
| 4 |
+
inc: 3
|
| 5 |
+
mode: adjust_channel
|
| 6 |
+
num_classes: 2
|
| 7 |
+
backbone_name: xception
|
| 8 |
+
compression: c23
|
| 9 |
+
cuda: true
|
| 10 |
+
cudnn: true
|
| 11 |
+
dataset_json_folder: preprocessing/dataset_json_v3
|
| 12 |
+
dataset_meta:
|
| 13 |
+
fake:
|
| 14 |
+
- create_splits: false
|
| 15 |
+
path: bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
|
| 16 |
+
- create_splits: false
|
| 17 |
+
path: bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
|
| 18 |
+
real:
|
| 19 |
+
- create_splits: false
|
| 20 |
+
path: bitmind/celeb-a-hq_training_faces
|
| 21 |
+
- create_splits: false
|
| 22 |
+
path: bitmind/ffhq-256_training_faces
|
| 23 |
+
ddp: false
|
| 24 |
+
dry_run: false
|
| 25 |
+
encoder_feat_dim: 512
|
| 26 |
+
faces_only: true
|
| 27 |
+
frame_num:
|
| 28 |
+
test: 32
|
| 29 |
+
train: 32
|
| 30 |
+
lmdb: true
|
| 31 |
+
lmdb_dir: ./datasets/lmdb
|
| 32 |
+
local_rank: 0
|
| 33 |
+
log_dir: ./logs/training/ucf_2024-09-17-16-44-50
|
| 34 |
+
logdir: ./logs
|
| 35 |
+
loss_func:
|
| 36 |
+
cls_loss: cross_entropy
|
| 37 |
+
con_loss: contrastive_regularization
|
| 38 |
+
rec_loss: l1loss
|
| 39 |
+
spe_loss: cross_entropy
|
| 40 |
+
losstype: null
|
| 41 |
+
lr_scheduler: null
|
| 42 |
+
manualSeed: 1024
|
| 43 |
+
mean:
|
| 44 |
+
- 0.5
|
| 45 |
+
- 0.5
|
| 46 |
+
- 0.5
|
| 47 |
+
metric_scoring: auc
|
| 48 |
+
mode: train
|
| 49 |
+
model_name: ucf
|
| 50 |
+
nEpochs: 2
|
| 51 |
+
optimizer:
|
| 52 |
+
adam:
|
| 53 |
+
amsgrad: false
|
| 54 |
+
beta1: 0.9
|
| 55 |
+
beta2: 0.999
|
| 56 |
+
eps: 1.0e-08
|
| 57 |
+
lr: 0.0002
|
| 58 |
+
weight_decay: 0.0005
|
| 59 |
+
sgd:
|
| 60 |
+
lr: 0.0002
|
| 61 |
+
momentum: 0.9
|
| 62 |
+
weight_decay: 0.0005
|
| 63 |
+
type: adam
|
| 64 |
+
pretrained: ../weights/xception_best.pth
|
| 65 |
+
rec_iter: 100
|
| 66 |
+
resolution: 256
|
| 67 |
+
rgb_dir: ./datasets/rgb
|
| 68 |
+
save_avg: true
|
| 69 |
+
save_ckpt: true
|
| 70 |
+
save_epoch: 1
|
| 71 |
+
save_feat: true
|
| 72 |
+
specific_task_number: 2
|
| 73 |
+
split_transforms:
|
| 74 |
+
test:
|
| 75 |
+
name: base_transforms
|
| 76 |
+
train:
|
| 77 |
+
name: random_aug_transforms
|
| 78 |
+
validation:
|
| 79 |
+
name: base_transforms
|
| 80 |
+
start_epoch: 0
|
| 81 |
+
std:
|
| 82 |
+
- 0.5
|
| 83 |
+
- 0.5
|
| 84 |
+
- 0.5
|
| 85 |
+
test_batchSize: 32
|
| 86 |
+
train_batchSize: 32
|
| 87 |
+
train_dataset:
|
| 88 |
+
- bitmind/celeb-a-hq_training_faces
|
| 89 |
+
- bitmind/ffhq-256_training_faces
|
| 90 |
+
- bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
|
| 91 |
+
- bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
|
| 92 |
+
with_landmark: false
|
| 93 |
+
with_mask: false
|
| 94 |
+
workers: 7
|
config/pretrained_face_config.yaml
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
SWA: false
|
| 2 |
+
backbone_config:
|
| 3 |
+
dropout: false
|
| 4 |
+
inc: 3
|
| 5 |
+
mode: adjust_channel
|
| 6 |
+
num_classes: 2
|
| 7 |
+
backbone_name: xception
|
| 8 |
+
compression: c23
|
| 9 |
+
cuda: true
|
| 10 |
+
cudnn: true
|
| 11 |
+
dataset_json_folder: preprocessing/dataset_json_v3
|
| 12 |
+
dataset_meta:
|
| 13 |
+
fake:
|
| 14 |
+
- create_splits: false
|
| 15 |
+
path: bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
|
| 16 |
+
- create_splits: false
|
| 17 |
+
path: bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
|
| 18 |
+
real:
|
| 19 |
+
- create_splits: false
|
| 20 |
+
path: bitmind/celeb-a-hq_training_faces
|
| 21 |
+
- create_splits: false
|
| 22 |
+
path: bitmind/ffhq-256_training_faces
|
| 23 |
+
ddp: false
|
| 24 |
+
dry_run: false
|
| 25 |
+
encoder_feat_dim: 512
|
| 26 |
+
faces_only: true
|
| 27 |
+
frame_num:
|
| 28 |
+
test: 32
|
| 29 |
+
train: 32
|
| 30 |
+
lmdb: true
|
| 31 |
+
lmdb_dir: ./datasets/lmdb
|
| 32 |
+
local_rank: 0
|
| 33 |
+
log_dir: ./logs/training/ucf_2024-09-17-16-44-50
|
| 34 |
+
logdir: ./logs
|
| 35 |
+
loss_func:
|
| 36 |
+
cls_loss: cross_entropy
|
| 37 |
+
con_loss: contrastive_regularization
|
| 38 |
+
rec_loss: l1loss
|
| 39 |
+
spe_loss: cross_entropy
|
| 40 |
+
losstype: null
|
| 41 |
+
lr_scheduler: null
|
| 42 |
+
manualSeed: 1024
|
| 43 |
+
mean:
|
| 44 |
+
- 0.5
|
| 45 |
+
- 0.5
|
| 46 |
+
- 0.5
|
| 47 |
+
metric_scoring: auc
|
| 48 |
+
mode: train
|
| 49 |
+
model_name: ucf
|
| 50 |
+
nEpochs: 2
|
| 51 |
+
optimizer:
|
| 52 |
+
adam:
|
| 53 |
+
amsgrad: false
|
| 54 |
+
beta1: 0.9
|
| 55 |
+
beta2: 0.999
|
| 56 |
+
eps: 1.0e-08
|
| 57 |
+
lr: 0.0002
|
| 58 |
+
weight_decay: 0.0005
|
| 59 |
+
sgd:
|
| 60 |
+
lr: 0.0002
|
| 61 |
+
momentum: 0.9
|
| 62 |
+
weight_decay: 0.0005
|
| 63 |
+
type: adam
|
| 64 |
+
pretrained: ../weights/xception_best.pth
|
| 65 |
+
rec_iter: 100
|
| 66 |
+
resolution: 256
|
| 67 |
+
rgb_dir: ./datasets/rgb
|
| 68 |
+
save_avg: true
|
| 69 |
+
save_ckpt: true
|
| 70 |
+
save_epoch: 1
|
| 71 |
+
save_feat: true
|
| 72 |
+
specific_task_number: 2
|
| 73 |
+
split_transforms:
|
| 74 |
+
test:
|
| 75 |
+
name: base_transforms
|
| 76 |
+
train:
|
| 77 |
+
name: random_aug_transforms
|
| 78 |
+
validation:
|
| 79 |
+
name: base_transforms
|
| 80 |
+
start_epoch: 0
|
| 81 |
+
std:
|
| 82 |
+
- 0.5
|
| 83 |
+
- 0.5
|
| 84 |
+
- 0.5
|
| 85 |
+
test_batchSize: 32
|
| 86 |
+
train_batchSize: 32
|
| 87 |
+
train_dataset:
|
| 88 |
+
- bitmind/celeb-a-hq_training_faces
|
| 89 |
+
- bitmind/ffhq-256_training_faces
|
| 90 |
+
- bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
|
| 91 |
+
- bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
|
| 92 |
+
with_landmark: false
|
| 93 |
+
with_mask: false
|
| 94 |
+
workers: 7
|
config/train_config.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mode: train
|
| 2 |
+
lmdb: True
|
| 3 |
+
dry_run: false
|
| 4 |
+
rgb_dir: './datasets/rgb'
|
| 5 |
+
lmdb_dir: './datasets/lmdb'
|
| 6 |
+
dataset_json_folder: './preprocessing/dataset_json'
|
| 7 |
+
SWA: False
|
| 8 |
+
save_avg: True
|
| 9 |
+
log_dir: ./logs/training/
|
config/ucf.yaml
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# log dir
|
| 2 |
+
log_dir: ../debug_logs/ucf
|
| 3 |
+
|
| 4 |
+
# model setting
|
| 5 |
+
pretrained: ../weights/xception_best.pth # path to a pre-trained model, if using one
|
| 6 |
+
model_name: ucf # model name
|
| 7 |
+
backbone_name: xception # backbone name
|
| 8 |
+
encoder_feat_dim: 512 # feature dimension of the backbone
|
| 9 |
+
|
| 10 |
+
#backbone setting
|
| 11 |
+
backbone_config:
|
| 12 |
+
mode: adjust_channel
|
| 13 |
+
num_classes: 2
|
| 14 |
+
inc: 3
|
| 15 |
+
dropout: false
|
| 16 |
+
|
| 17 |
+
compression: c23 # compression-level for videos
|
| 18 |
+
train_batchSize: 32 # training batch size
|
| 19 |
+
test_batchSize: 32 # test batch size
|
| 20 |
+
workers: 8 # number of data loading workers
|
| 21 |
+
frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing
|
| 22 |
+
resolution: 256 # resolution of output image to network
|
| 23 |
+
with_mask: false # whether to include mask information in the input
|
| 24 |
+
with_landmark: false # whether to include facial landmark information in the input
|
| 25 |
+
save_ckpt: true # whether to save checkpoint
|
| 26 |
+
save_feat: true # whether to save features
|
| 27 |
+
specific_task_number: 2 # default num datasets in FF++ used by DFB, overwritten in training
|
| 28 |
+
|
| 29 |
+
# mean and std for normalization
|
| 30 |
+
mean: [0.5, 0.5, 0.5]
|
| 31 |
+
std: [0.5, 0.5, 0.5]
|
| 32 |
+
|
| 33 |
+
# optimizer config
|
| 34 |
+
optimizer:
|
| 35 |
+
# choose between 'adam' and 'sgd'
|
| 36 |
+
type: adam
|
| 37 |
+
adam:
|
| 38 |
+
lr: 0.0002 # learning rate
|
| 39 |
+
beta1: 0.9 # beta1 for Adam optimizer
|
| 40 |
+
beta2: 0.999 # beta2 for Adam optimizer
|
| 41 |
+
eps: 0.00000001 # epsilon for Adam optimizer
|
| 42 |
+
weight_decay: 0.0005 # weight decay for regularization
|
| 43 |
+
amsgrad: false
|
| 44 |
+
sgd:
|
| 45 |
+
lr: 0.0002 # learning rate
|
| 46 |
+
momentum: 0.9 # momentum for SGD optimizer
|
| 47 |
+
weight_decay: 0.0005 # weight decay for regularization
|
| 48 |
+
|
| 49 |
+
# training config
|
| 50 |
+
lr_scheduler: null # learning rate scheduler
|
| 51 |
+
nEpochs: 20 # number of epochs to train for
|
| 52 |
+
start_epoch: 0 # manual epoch number (useful for restarts)
|
| 53 |
+
save_epoch: 1 # interval epochs for saving models
|
| 54 |
+
rec_iter: 100 # interval iterations for recording
|
| 55 |
+
logdir: ./logs # folder to output images and logs
|
| 56 |
+
manualSeed: 1024 # manual seed for random number generation
|
| 57 |
+
save_ckpt: false # whether to save checkpoint
|
| 58 |
+
|
| 59 |
+
# loss function
|
| 60 |
+
loss_func:
|
| 61 |
+
cls_loss: cross_entropy # loss function to use
|
| 62 |
+
spe_loss: cross_entropy
|
| 63 |
+
con_loss: contrastive_regularization
|
| 64 |
+
rec_loss: l1loss
|
| 65 |
+
losstype: null
|
| 66 |
+
|
| 67 |
+
# metric
|
| 68 |
+
metric_scoring: auc # metric for evaluation (auc, acc, eer, ap)
|
| 69 |
+
|
| 70 |
+
# cuda
|
| 71 |
+
|
| 72 |
+
cuda: true # whether to use CUDA acceleration
|
| 73 |
+
cudnn: true # whether to use CuDNN for convolution operations
|
config/xception.yaml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# log dir
|
| 2 |
+
log_dir: /data/home/zhiyuanyan/DeepfakeBench/logs/testing_bench
|
| 3 |
+
|
| 4 |
+
# model setting
|
| 5 |
+
pretrained: /data/home/zhiyuanyan/DeepfakeBench/training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one
|
| 6 |
+
model_name: xception # model name
|
| 7 |
+
backbone_name: xception # backbone name
|
| 8 |
+
|
| 9 |
+
#backbone setting
|
| 10 |
+
backbone_config:
|
| 11 |
+
mode: original
|
| 12 |
+
num_classes: 2
|
| 13 |
+
inc: 3
|
| 14 |
+
dropout: false
|
| 15 |
+
|
| 16 |
+
# dataset
|
| 17 |
+
all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV]
|
| 18 |
+
train_dataset: [FaceForensics++]
|
| 19 |
+
test_dataset: [FaceForensics++, DeepFakeDetection]
|
| 20 |
+
|
| 21 |
+
compression: c23 # compression-level for videos
|
| 22 |
+
train_batchSize: 32 # training batch size
|
| 23 |
+
test_batchSize: 32 # test batch size
|
| 24 |
+
workers: 8 # number of data loading workers
|
| 25 |
+
frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing
|
| 26 |
+
resolution: 256 # resolution of output image to network
|
| 27 |
+
with_mask: false # whether to include mask information in the input
|
| 28 |
+
with_landmark: false # whether to include facial landmark information in the input
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# data augmentation
|
| 32 |
+
use_data_augmentation: true # Add this flag to enable/disable data augmentation
|
| 33 |
+
data_aug:
|
| 34 |
+
flip_prob: 0.5
|
| 35 |
+
rotate_prob: 0.0
|
| 36 |
+
rotate_limit: [-10, 10]
|
| 37 |
+
blur_prob: 0.5
|
| 38 |
+
blur_limit: [3, 7]
|
| 39 |
+
brightness_prob: 0.5
|
| 40 |
+
brightness_limit: [-0.1, 0.1]
|
| 41 |
+
contrast_limit: [-0.1, 0.1]
|
| 42 |
+
quality_lower: 40
|
| 43 |
+
quality_upper: 100
|
| 44 |
+
|
| 45 |
+
# mean and std for normalization
|
| 46 |
+
mean: [0.5, 0.5, 0.5]
|
| 47 |
+
std: [0.5, 0.5, 0.5]
|
| 48 |
+
|
| 49 |
+
# optimizer config
|
| 50 |
+
optimizer:
|
| 51 |
+
# choose between 'adam' and 'sgd'
|
| 52 |
+
type: adam
|
| 53 |
+
adam:
|
| 54 |
+
lr: 0.0002 # learning rate
|
| 55 |
+
beta1: 0.9 # beta1 for Adam optimizer
|
| 56 |
+
beta2: 0.999 # beta2 for Adam optimizer
|
| 57 |
+
eps: 0.00000001 # epsilon for Adam optimizer
|
| 58 |
+
weight_decay: 0.0005 # weight decay for regularization
|
| 59 |
+
amsgrad: false
|
| 60 |
+
sgd:
|
| 61 |
+
lr: 0.0002 # learning rate
|
| 62 |
+
momentum: 0.9 # momentum for SGD optimizer
|
| 63 |
+
weight_decay: 0.0005 # weight decay for regularization
|
| 64 |
+
|
| 65 |
+
# training config
|
| 66 |
+
lr_scheduler: null # learning rate scheduler
|
| 67 |
+
nEpochs: 10 # number of epochs to train for
|
| 68 |
+
start_epoch: 0 # manual epoch number (useful for restarts)
|
| 69 |
+
save_epoch: 1 # interval epochs for saving models
|
| 70 |
+
rec_iter: 100 # interval iterations for recording
|
| 71 |
+
logdir: ./logs # folder to output images and logs
|
| 72 |
+
manualSeed: 1024 # manual seed for random number generation
|
| 73 |
+
save_ckpt: true # whether to save checkpoint
|
| 74 |
+
save_feat: true # whether to save features
|
| 75 |
+
|
| 76 |
+
# loss function
|
| 77 |
+
loss_func: cross_entropy # loss function to use
|
| 78 |
+
losstype: null
|
| 79 |
+
|
| 80 |
+
# metric
|
| 81 |
+
metric_scoring: auc # metric for evaluation (auc, acc, eer, ap)
|
| 82 |
+
|
| 83 |
+
# cuda
|
| 84 |
+
|
| 85 |
+
cuda: true # whether to use CUDA acceleration
|
| 86 |
+
cudnn: true # whether to use CuDNN for convolution operations
|
detectors/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Package init for ``detectors``.

Extends ``sys.path`` so sibling packages resolve with bare names, then
imports the detector registry and the concrete UCF detector so that
importing ``detectors`` registers ``UCFDetector`` with ``DETECTOR``.
"""
import os
import sys

# Absolute path of this __init__.py file.
current_file_path = os.path.abspath(__file__)
# Directory containing the ``detectors`` package.
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
# One level above that: the project root.
project_root_dir = os.path.dirname(parent_dir)

# Guarded appends: avoid duplicating entries when several packages run
# the same path-setup boilerplate.
for _path in (parent_dir, project_root_dir):
    if _path not in sys.path:
        sys.path.append(_path)

from metrics.registry import DETECTOR

# Importing the module triggers its @DETECTOR.register_module decorator.
from .ucf_detector import UCFDetector
detectors/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (455 Bytes). View file
|
|
|
detectors/__pycache__/base_detector.cpython-310.pyc
ADDED
|
Binary file (2.57 kB). View file
|
|
|
detectors/__pycache__/ucf_detector.cpython-310.pyc
ADDED
|
Binary file (12.9 kB). View file
|
|
|
detectors/base_detector.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# author: Zhiyuan Yan
|
| 2 |
+
# email: zhiyuanyan@link.cuhk.edu.cn
|
| 3 |
+
# date: 2023-0706
|
| 4 |
+
# description: Abstract Class for the Deepfake Detector
|
| 5 |
+
|
| 6 |
+
import abc
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from typing import Union
|
| 10 |
+
|
class AbstractDetector(nn.Module, metaclass=abc.ABCMeta):
    """
    All deepfake detectors should subclass this class.

    Subclasses must implement backbone/loss construction, feature
    extraction, classification, loss computation and training metrics.
    """
    def __init__(self, config=None, load_param: Union[bool, str] = False):
        """
        config: (dict)
            configurations for the model
        load_param: (False | True | Path(str))
            False Do not read; True Read the default path; Path Read the required path

        NOTE(review): both arguments are currently unused here; subclasses
        are expected to consume them in their own __init__.
        """
        super().__init__()

    @abc.abstractmethod
    def features(self, data_dict: dict) -> torch.Tensor:
        """
        Returns the features from the backbone given the input data.
        """
        pass

    @abc.abstractmethod
    def forward(self, data_dict: dict, inference=False) -> dict:
        """
        Forward pass through the model, returning the prediction dictionary.
        """
        pass

    @abc.abstractmethod
    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """
        Classifies the features into classes.
        """
        pass

    @abc.abstractmethod
    def build_backbone(self, config):
        """
        Builds the backbone of the model.
        """
        pass

    @abc.abstractmethod
    def build_loss(self, config):
        """
        Builds the loss function for the model.
        """
        pass

    @abc.abstractmethod
    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """
        Returns the losses for the model.
        """
        pass

    @abc.abstractmethod
    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """
        Returns the training metrics for the model.
        """
        pass
detectors/ucf_detector.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
# Source: https://github.com/SCLBD/DeepfakeBench/blob/main/training/detectors/ucf_detector.py
|
| 3 |
+
# author: Zhiyuan Yan
|
| 4 |
+
# email: zhiyuanyan@link.cuhk.edu.cn
|
| 5 |
+
# date: 2023-0706
|
| 6 |
+
# description: Class for the UCFDetector
|
| 7 |
+
|
| 8 |
+
Functions in the Class are summarized as:
|
| 9 |
+
1. __init__: Initialization
|
| 10 |
+
2. build_backbone: Backbone-building
|
| 11 |
+
3. build_loss: Loss-function-building
|
| 12 |
+
4. features: Feature-extraction
|
| 13 |
+
5. classifier: Classification
|
| 14 |
+
6. get_losses: Loss-computation
|
| 15 |
+
7. get_train_metrics: Training-metrics-computation
|
| 16 |
+
8. get_test_metrics: Testing-metrics-computation
|
| 17 |
+
9. forward: Forward-propagation
|
| 18 |
+
|
| 19 |
+
Reference:
|
| 20 |
+
@article{yan2023ucf,
|
| 21 |
+
title={UCF: Uncovering Common Features for Generalizable Deepfake Detection},
|
| 22 |
+
author={Yan, Zhiyuan and Zhang, Yong and Fan, Yanbo and Wu, Baoyuan},
|
| 23 |
+
journal={arXiv preprint arXiv:2304.13949},
|
| 24 |
+
year={2023}
|
| 25 |
+
}
|
| 26 |
+
'''
|
| 27 |
+
|
| 28 |
+
import os
|
| 29 |
+
import datetime
|
| 30 |
+
import logging
|
| 31 |
+
import random
|
| 32 |
+
import numpy as np
|
| 33 |
+
from sklearn import metrics
|
| 34 |
+
from typing import Union
|
| 35 |
+
from collections import defaultdict
|
| 36 |
+
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
import torch.nn.functional as F
|
| 40 |
+
import torch.optim as optim
|
| 41 |
+
from torch.nn import DataParallel
|
| 42 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 43 |
+
|
| 44 |
+
from metrics.base_metrics_class import calculate_metrics_for_train
|
| 45 |
+
|
| 46 |
+
from .base_detector import AbstractDetector
|
| 47 |
+
from arena.detectors.UCF.detectors import DETECTOR
|
| 48 |
+
from networks import BACKBONE
|
| 49 |
+
from loss import LOSSFUNC
|
| 50 |
+
|
| 51 |
+
logger = logging.getLogger(__name__)
|
| 52 |
+
|
| 53 |
+
@DETECTOR.register_module(module_name='ucf')
|
| 54 |
+
class UCFDetector(AbstractDetector):
|
    def __init__(self, config):
        """Build the two-branch UCF detector.

        config: dict providing at least 'backbone_config.num_classes',
        'encoder_feat_dim', 'specific_task_number', plus the keys consumed
        by build_backbone() and build_loss().
        """
        super().__init__()
        self.config = config
        self.num_classes = config['backbone_config']['num_classes']
        self.encoder_feat_dim = config['encoder_feat_dim']
        # Encoder features are later split into a "specific" and a "shared"
        # half by classifier(), hence half of the encoder dimension.
        self.half_fingerprint_dim = self.encoder_feat_dim//2

        # Two separate backbones: one producing 'forgery' features, one
        # producing 'content' features (see features()).
        self.encoder_f = self.build_backbone(config)
        self.encoder_c = self.build_backbone(config)

        self.loss_func = self.build_loss(config)
        # Running accumulators -- presumably for test-time metrics; the
        # code updating them is not visible here (TODO confirm).
        self.prob, self.label = [], []
        self.correct, self.total = 0, 0

        # basic function
        self.lr = nn.LeakyReLU(inplace=True)
        self.do = nn.Dropout(0.2)
        self.pool = nn.AdaptiveAvgPool2d(1)

        # conditional gan
        # NOTE(review): Conditional_UNet is defined elsewhere in this module
        # (not visible in this chunk); used to reconstruct images from features.
        self.con_gan = Conditional_UNet()

        # head
        specific_task_number = config['specific_task_number']

        # Head over the specific half (out_f = number of specific tasks);
        # trained against 'label_spe' in get_losses().
        self.head_spe = Head(
            in_f=self.half_fingerprint_dim,
            hidden_dim=self.encoder_feat_dim,
            out_f=specific_task_number
        )
        # Head over the shared half (out_f = num_classes, e.g. real/fake).
        self.head_sha = Head(
            in_f=self.half_fingerprint_dim,
            hidden_dim=self.encoder_feat_dim,
            out_f=self.num_classes
        )
        # 1x1 conv blocks that project encoder output down to the
        # specific / shared halves consumed by the heads above.
        self.block_spe = Conv2d1x1(
            in_f=self.encoder_feat_dim,
            hidden_dim=self.half_fingerprint_dim,
            out_f=self.half_fingerprint_dim
        )
        self.block_sha = Conv2d1x1(
            in_f=self.encoder_feat_dim,
            hidden_dim=self.half_fingerprint_dim,
            out_f=self.half_fingerprint_dim
        )
| 100 |
+
|
| 101 |
+
def build_backbone(self, config):
|
| 102 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 103 |
+
pretrained_path = os.path.join(current_dir, config['pretrained'])
|
| 104 |
+
# prepare the backbone
|
| 105 |
+
backbone_class = BACKBONE[config['backbone_name']]
|
| 106 |
+
model_config = config['backbone_config']
|
| 107 |
+
backbone = backbone_class(model_config)
|
| 108 |
+
# if donot load the pretrained weights, fail to get good results
|
| 109 |
+
state_dict = torch.load(pretrained_path)
|
| 110 |
+
for name, weights in state_dict.items():
|
| 111 |
+
if 'pointwise' in name:
|
| 112 |
+
state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)
|
| 113 |
+
state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k}
|
| 114 |
+
backbone.load_state_dict(state_dict, False)
|
| 115 |
+
logger.info('Load pretrained model successfully!')
|
| 116 |
+
return backbone
|
| 117 |
+
|
| 118 |
+
def build_loss(self, config):
|
| 119 |
+
cls_loss_class = LOSSFUNC[config['loss_func']['cls_loss']]
|
| 120 |
+
spe_loss_class = LOSSFUNC[config['loss_func']['spe_loss']]
|
| 121 |
+
con_loss_class = LOSSFUNC[config['loss_func']['con_loss']]
|
| 122 |
+
rec_loss_class = LOSSFUNC[config['loss_func']['rec_loss']]
|
| 123 |
+
cls_loss_func = cls_loss_class()
|
| 124 |
+
spe_loss_func = spe_loss_class()
|
| 125 |
+
con_loss_func = con_loss_class(margin=3.0)
|
| 126 |
+
rec_loss_func = rec_loss_class()
|
| 127 |
+
loss_func = {
|
| 128 |
+
'cls': cls_loss_func,
|
| 129 |
+
'spe': spe_loss_func,
|
| 130 |
+
'con': con_loss_func,
|
| 131 |
+
'rec': rec_loss_func,
|
| 132 |
+
}
|
| 133 |
+
return loss_func
|
| 134 |
+
|
| 135 |
+
def features(self, data_dict: dict) -> torch.tensor:
|
| 136 |
+
cat_data = data_dict['image']
|
| 137 |
+
# encoder
|
| 138 |
+
f_all = self.encoder_f.features(cat_data)
|
| 139 |
+
c_all = self.encoder_c.features(cat_data)
|
| 140 |
+
feat_dict = {'forgery': f_all, 'content': c_all}
|
| 141 |
+
return feat_dict
|
| 142 |
+
|
| 143 |
+
def classifier(self, features: torch.tensor) -> torch.tensor:
|
| 144 |
+
# classification, multi-task
|
| 145 |
+
# split the features into the specific and common forgery
|
| 146 |
+
f_spe = self.block_spe(features)
|
| 147 |
+
f_share = self.block_sha(features)
|
| 148 |
+
return f_spe, f_share
|
| 149 |
+
|
| 150 |
+
def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
|
| 151 |
+
if 'label_spe' in data_dict and 'recontruction_imgs' in pred_dict:
|
| 152 |
+
return self.get_train_losses(data_dict, pred_dict)
|
| 153 |
+
else: # test mode
|
| 154 |
+
return self.get_test_losses(data_dict, pred_dict)
|
| 155 |
+
|
| 156 |
+
def get_train_losses(self, data_dict: dict, pred_dict: dict) -> dict:
|
| 157 |
+
# get combined, real, fake imgs
|
| 158 |
+
cat_data = data_dict['image']
|
| 159 |
+
real_img, fake_img = cat_data.chunk(2, dim=0)
|
| 160 |
+
# get the reconstruction imgs
|
| 161 |
+
reconstruction_image_1, \
|
| 162 |
+
reconstruction_image_2, \
|
| 163 |
+
self_reconstruction_image_1, \
|
| 164 |
+
self_reconstruction_image_2 \
|
| 165 |
+
= pred_dict['recontruction_imgs']
|
| 166 |
+
# get label
|
| 167 |
+
label = data_dict['label']
|
| 168 |
+
label_spe = data_dict['label_spe']
|
| 169 |
+
# get pred
|
| 170 |
+
pred = pred_dict['cls']
|
| 171 |
+
pred_spe = pred_dict['cls_spe']
|
| 172 |
+
|
| 173 |
+
# 1. classification loss for common features
|
| 174 |
+
loss_sha = self.loss_func['cls'](pred, label)
|
| 175 |
+
|
| 176 |
+
# 2. classification loss for specific features
|
| 177 |
+
loss_spe = self.loss_func['spe'](pred_spe, label_spe)
|
| 178 |
+
|
| 179 |
+
# 3. reconstruction loss
|
| 180 |
+
self_loss_reconstruction_1 = self.loss_func['rec'](fake_img, self_reconstruction_image_1)
|
| 181 |
+
self_loss_reconstruction_2 = self.loss_func['rec'](real_img, self_reconstruction_image_2)
|
| 182 |
+
cross_loss_reconstruction_1 = self.loss_func['rec'](fake_img, reconstruction_image_2)
|
| 183 |
+
cross_loss_reconstruction_2 = self.loss_func['rec'](real_img, reconstruction_image_1)
|
| 184 |
+
loss_reconstruction = \
|
| 185 |
+
self_loss_reconstruction_1 + self_loss_reconstruction_2 + \
|
| 186 |
+
cross_loss_reconstruction_1 + cross_loss_reconstruction_2
|
| 187 |
+
|
| 188 |
+
# 4. constrative loss
|
| 189 |
+
common_features = pred_dict['feat']
|
| 190 |
+
specific_features = pred_dict['feat_spe']
|
| 191 |
+
loss_con = self.loss_func['con'](common_features, specific_features, label_spe)
|
| 192 |
+
|
| 193 |
+
# 5. total loss
|
| 194 |
+
loss = loss_sha + 0.1*loss_spe + 0.3*loss_reconstruction + 0.05*loss_con
|
| 195 |
+
loss_dict = {
|
| 196 |
+
'overall': loss,
|
| 197 |
+
'common': loss_sha,
|
| 198 |
+
'specific': loss_spe,
|
| 199 |
+
'reconstruction': loss_reconstruction,
|
| 200 |
+
'contrastive': loss_con,
|
| 201 |
+
}
|
| 202 |
+
return loss_dict
|
| 203 |
+
|
| 204 |
+
def get_test_losses(self, data_dict: dict, pred_dict: dict) -> dict:
|
| 205 |
+
# get label
|
| 206 |
+
label = data_dict['label']
|
| 207 |
+
# get pred
|
| 208 |
+
pred = pred_dict['cls']
|
| 209 |
+
# for test mode, only classification loss for common features
|
| 210 |
+
loss = self.loss_func['cls'](pred, label)
|
| 211 |
+
loss_dict = {'common': loss}
|
| 212 |
+
return loss_dict
|
| 213 |
+
|
| 214 |
+
def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
|
| 215 |
+
def get_accracy(label, output):
|
| 216 |
+
_, prediction = torch.max(output, 1) # argmax
|
| 217 |
+
correct = (prediction == label).sum().item()
|
| 218 |
+
accuracy = correct / prediction.size(0)
|
| 219 |
+
return accuracy
|
| 220 |
+
|
| 221 |
+
# get pred and label
|
| 222 |
+
label = data_dict['label']
|
| 223 |
+
pred = pred_dict['cls']
|
| 224 |
+
label_spe = data_dict['label_spe']
|
| 225 |
+
pred_spe = pred_dict['cls_spe']
|
| 226 |
+
|
| 227 |
+
# compute metrics for batch data
|
| 228 |
+
auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
|
| 229 |
+
acc_spe = get_accracy(label_spe.detach(), pred_spe.detach())
|
| 230 |
+
metric_batch_dict = {'acc': acc, 'acc_spe': acc_spe, 'auc': auc, 'eer': eer, 'ap': ap}
|
| 231 |
+
# we dont compute the video-level metrics for training
|
| 232 |
+
return metric_batch_dict
|
| 233 |
+
|
| 234 |
+
    def forward(self, data_dict: dict, inference=False) -> dict:
        """Full UCF forward pass.

        Disentangles content/forgery features, splits forgery features into
        specific and shared parts, and (in training mode) reconstructs
        images with the conditional decoder.  In inference mode only the
        shared branch is evaluated and the running prob/label/accuracy
        buffers on self are updated.
        """
        # split the features into the content and forgery
        features = self.features(data_dict)
        forgery_features, content_features = features['forgery'], features['content']
        # get the prediction by classifier (split the common and specific forgery)
        f_spe, f_share = self.classifier(forgery_features)

        if inference:
            # inference only consider share loss
            out_sha, sha_feat = self.head_sha(f_share)
            out_spe, spe_feat = self.head_spe(f_spe)
            # probability of class 1 (fake)
            prob_sha = torch.softmax(out_sha, dim=1)[:, 1]
            self.prob.append(
                prob_sha
                .detach()
                .squeeze()
                .cpu()
                .numpy()
            )
            _, prediction_class = torch.max(out_sha, 1)
            if 'label' in data_dict:
                self.label.append(
                    data_dict['label']
                    .detach()
                    .squeeze()
                    .cpu()
                    .numpy()
                )
                # deal with acc: collapse multi-class labels to binary real/fake
                common_label = (data_dict['label'] >= 1)
                correct = (prediction_class == common_label).sum().item()
                self.correct += correct
                self.total += data_dict['label'].size(0)

            pred_dict = {'cls': out_sha, 'feat': sha_feat}
            return pred_dict

        bs = f_share.size(0)
        # using idx aug in the training mode: with probability 0.7, shuffle
        # the shared features in place within the real half and within the
        # fake half of the batch
        aug_idx = random.random()
        if aug_idx < 0.7:
            # real
            idx_list = list(range(0, bs//2))
            random.shuffle(idx_list)
            f_share[0: bs//2] = f_share[idx_list]
            # fake
            idx_list = list(range(bs//2, bs))
            random.shuffle(idx_list)
            f_share[bs//2: bs] = f_share[idx_list]

        # concat spe and share to obtain new_f_all
        f_all = torch.cat((f_spe, f_share), dim=1)

        # reconstruction loss inputs; note the chunk assignment order: f2/c2
        # come from the first half of the batch, f1/c1 from the second half
        # (presumably real first / fake second, matching get_train_losses —
        # TODO confirm against the data loader)
        f2, f1 = f_all.chunk(2, dim=0)
        c2, c1 = content_features.chunk(2, dim=0)

        # ==== self reconstruction ==== #
        # f1 + c1 -> f11, f11 + c1 -> near~I1
        self_reconstruction_image_1 = self.con_gan(f1, c1)

        # f2 + c2 -> f2, f2 + c2 -> near~I2
        self_reconstruction_image_2 = self.con_gan(f2, c2)

        # ==== cross combine ==== #
        reconstruction_image_1 = self.con_gan(f1, c2)
        reconstruction_image_2 = self.con_gan(f2, c1)

        # head for spe and sha
        out_spe, spe_feat = self.head_spe(f_spe)
        out_sha, sha_feat = self.head_sha(f_share)

        # get the probability of the pred (class 1)
        prob_sha = torch.softmax(out_sha, dim=1)[:, 1]
        prob_spe = torch.softmax(out_spe, dim=1)[:, 1]

        # build the prediction dict for each output
        pred_dict = {
            'cls': out_sha,
            'prob': prob_sha,
            'feat': sha_feat,
            'cls_spe': out_spe,
            'prob_spe': prob_spe,
            'feat_spe': spe_feat,
            'feat_content': content_features,
            'recontruction_imgs': (
                reconstruction_image_1,
                reconstruction_image_2,
                self_reconstruction_image_1,
                self_reconstruction_image_2
            )
        }
        return pred_dict
|
| 327 |
+
|
| 328 |
+
def sn_double_conv(in_channels, out_channels):
    """Spectrally-normalized conv pair: a stride-1 3x3 conv, a stride-2 3x3
    downsampling conv, then LeakyReLU(0.2)."""
    conv_keep = nn.utils.spectral_norm(
        nn.Conv2d(in_channels, in_channels, 3, padding=1))
    conv_down = nn.utils.spectral_norm(
        nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=2))
    return nn.Sequential(conv_keep, conv_down, nn.LeakyReLU(0.2, inplace=True))
|
| 336 |
+
|
| 337 |
+
def r_double_conv(in_channels, out_channels):
    """Two 3x3 convolutions, each followed by ReLU (spatial size preserved)."""
    layers = [
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
|
| 344 |
+
|
| 345 |
+
class AdaIN(nn.Module):
    """Adaptive instance normalization: re-normalizes x per channel to the
    mean/std statistics of y."""

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def c_norm(self, x, bs, ch, eps=1e-7):
        """Per-channel std/mean of a (bs, ch, -1) tensor, shaped (bs, ch, 1, 1)
        for broadcasting."""
        std = (x.var(dim=-1) + eps).sqrt().view(bs, ch, 1, 1)
        mean = x.mean(dim=-1).view(bs, ch, 1, 1)
        return std, mean

    def forward(self, x, y):
        assert x.size(0) == y.size(0)
        size = x.size()
        bs, ch = size[:2]
        x_std, x_mean = self.c_norm(x.view(bs, ch, -1), bs, ch, eps=self.eps)
        y_std, y_mean = self.c_norm(y.reshape(bs, ch, -1), bs, ch, eps=self.eps)
        # whiten x, then apply y's statistics
        out = ((x - x_mean.expand(size)) / x_std.expand(size)) \
              * y_std.expand(size) + y_mean.expand(size)
        return out
|
| 369 |
+
|
| 370 |
+
class Conditional_UNet(nn.Module):
    """Decoder that reconstructs an RGB image from a style input c (forgery
    features) and a content input x via three AdaIN + upsample + conv stages."""

    def init_weight(self, std=0.2):
        # NOTE(review): unused — the call in __init__ is commented out.  The
        # Linear branch initializes weights around mean 1.0, which looks like
        # a BatchNorm-style init; confirm before enabling.
        for m in self.modules():
            cn = m.__class__.__name__
            if cn.find('Conv') != -1:
                m.weight.data.normal_(0., std)
            elif cn.find('Linear') != -1:
                m.weight.data.normal_(1., std)
                m.bias.data.fill_(0)

    def __init__(self):
        super(Conditional_UNet, self).__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.maxpool = nn.MaxPool2d(2)
        self.dropout = nn.Dropout(p=0.3)
        #self.dropout_half = HalfDropout(p=0.3)

        self.adain3 = AdaIN()
        self.adain2 = AdaIN()
        self.adain1 = AdaIN()

        # decoder channel schedule: 512 -> 256 -> 128 -> 64
        self.dconv_up3 = r_double_conv(512, 256)
        self.dconv_up2 = r_double_conv(256, 128)
        self.dconv_up1 = r_double_conv(128, 64)

        self.conv_last = nn.Conv2d(64, 3, 1)
        self.up_last = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        # Tanh keeps the reconstruction in [-1, 1]
        self.activation = nn.Tanh()
        #self.init_weight()

    def forward(self, c, x):  # c is the style and x is the content
        """Reconstruct an image; c and x must share batch size and channel
        count at each AdaIN stage (512 channels on entry)."""
        # stage 3: re-style content with AdaIN, upsample and refine; the
        # style branch c is upsampled/refined in parallel for later stages
        x = self.adain3(x, c)
        x = self.upsample(x)
        x = self.dropout(x)
        x = self.dconv_up3(x)
        c = self.upsample(c)
        c = self.dropout(c)
        c = self.dconv_up3(c)

        # stage 2
        x = self.adain2(x, c)
        x = self.upsample(x)
        x = self.dropout(x)
        x = self.dconv_up2(x)
        c = self.upsample(c)
        c = self.dropout(c)
        c = self.dconv_up2(c)

        # stage 1: the style branch is no longer needed afterwards
        x = self.adain1(x, c)
        x = self.upsample(x)
        x = self.dropout(x)
        x = self.dconv_up1(x)

        # project to 3 channels and upsample to the output resolution
        x = self.conv_last(x)
        out = self.up_last(x)

        return self.activation(out)
|
| 428 |
+
|
| 429 |
+
class MLP(nn.Module):
    """Global-average-pool followed by a three-layer MLP.

    Fix: the pooled (bs, in_f, 1, 1) tensor is flattened to (bs, in_f)
    before the Linear stack.  Previously nn.Linear received a 4-D tensor
    whose last dimension was 1, which raises a shape-mismatch error for any
    in_f != 1 (compare Head, which flattens via .view(bs, -1)).
    """

    def __init__(self, in_f, hidden_dim, out_f):
        super(MLP, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.mlp = nn.Sequential(nn.Linear(in_f, hidden_dim),
                                 nn.LeakyReLU(inplace=True),
                                 nn.Linear(hidden_dim, hidden_dim),
                                 nn.LeakyReLU(inplace=True),
                                 nn.Linear(hidden_dim, out_f),)

    def forward(self, x):
        # (bs, C, H, W) -> (bs, C, 1, 1) -> (bs, C)
        x = self.pool(x).flatten(1)
        x = self.mlp(x)
        return x
|
| 443 |
+
|
| 444 |
+
class Conv2d1x1(nn.Module):
    """Three stacked 1x1 convolutions with LeakyReLU in between — a
    per-pixel MLP over the channel dimension."""

    def __init__(self, in_f, hidden_dim, out_f):
        super(Conv2d1x1, self).__init__()
        self.conv2d = nn.Sequential(
            nn.Conv2d(in_f, hidden_dim, 1, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 1, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(hidden_dim, out_f, 1, 1),
        )

    def forward(self, x):
        return self.conv2d(x)
|
| 456 |
+
|
| 457 |
+
class Head(nn.Module):
    """Classification head: global-average-pool to a feature vector, then a
    two-layer MLP.  Returns (logits, pooled_features).

    Note: dropout is applied to the logits, after the MLP — this mirrors
    the original implementation.
    """

    def __init__(self, in_f, hidden_dim, out_f):
        super(Head, self).__init__()
        self.do = nn.Dropout(0.2)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.mlp = nn.Sequential(nn.Linear(in_f, hidden_dim),
                                 nn.LeakyReLU(inplace=True),
                                 nn.Linear(hidden_dim, out_f),)

    def forward(self, x):
        batch = x.size(0)
        x_feat = self.pool(x).view(batch, -1)
        logits = self.do(self.mlp(x_feat))
        return logits, x_feat
|
| 472 |
+
|
loss/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys
# Make sibling packages importable regardless of the working directory:
# add the package's parent directory and the project root to sys.path.
current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
project_root_dir = os.path.dirname(parent_dir)
sys.path.append(parent_dir)
sys.path.append(project_root_dir)

from metrics.registry import LOSSFUNC

# Importing the loss modules runs their @LOSSFUNC.register_module
# decorators, populating the registry as a side effect.
from .cross_entropy_loss import CrossEntropyLoss
from .contrastive_regularization import ContrastiveLoss
from .l1_loss import L1Loss
|
loss/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (565 Bytes). View file
|
|
|
loss/__pycache__/abstract_loss_func.cpython-310.pyc
ADDED
|
Binary file (977 Bytes). View file
|
|
|
loss/__pycache__/contrastive_regularization.cpython-310.pyc
ADDED
|
Binary file (2.38 kB). View file
|
|
|
loss/__pycache__/cross_entropy_loss.cpython-310.pyc
ADDED
|
Binary file (1.26 kB). View file
|
|
|
loss/__pycache__/l1_loss.cpython-310.pyc
ADDED
|
Binary file (892 Bytes). View file
|
|
|
loss/abstract_loss_func.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
class AbstractLossClass(nn.Module):
    """Base class for loss functions; subclasses must override forward."""

    def __init__(self):
        super(AbstractLossClass, self).__init__()

    def forward(self, pred, label):
        """Compute the loss value for (pred, label).

        Args:
            pred: prediction of the model
            label: ground truth label

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError('Each subclass should implement the forward method.')
|
loss/contrastive_regularization.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from .abstract_loss_func import AbstractLossClass
|
| 7 |
+
from metrics.registry import LOSSFUNC
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def swap_spe_features(type_list, value_list):
    """Permute the rows of value_list so that each row is replaced by a
    random row sharing the same type label (rows never cross type groups)."""
    labels = type_list.cpu().numpy().tolist()

    # bucket row indices by their type label
    buckets = defaultdict(list)
    for idx, lab in enumerate(labels):
        buckets[lab].append(idx)

    # shuffle each bucket independently
    for lab in buckets.keys():
        random.shuffle(buckets[lab])

    # draw one shuffled index per row, keeping the type alignment
    permutation = []
    for lab in labels:
        permutation.append(buckets[lab].pop())

    return value_list[permutation]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@LOSSFUNC.register_module(module_name="contrastive_regularization")
class ContrastiveLoss(AbstractLossClass):
    """Margin-based contrastive regularizer over the common and specific
    feature spaces, contrasting the real and fake halves of the batch."""

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def contrastive_loss(self, anchor, positive, negative):
        """Mean hinge on d(anchor, positive) - d(anchor, negative) + margin."""
        d_pos = F.pairwise_distance(anchor, positive)
        d_neg = F.pairwise_distance(anchor, negative)
        return torch.clamp(d_pos - d_neg + self.margin, min=0.0).mean()

    def forward(self, common, specific, spe_label):
        # batch layout: first half real, second half fake
        bs = common.shape[0]
        real_common, fake_common = common.chunk(2)

        # shuffled real rows act as positives for real anchors (and as
        # negatives for fake anchors); likewise for the fake half
        real_idx = list(range(0, bs // 2))
        random.shuffle(real_idx)
        real_common_anchor = common[real_idx]

        fake_idx = list(range(bs // 2, bs))
        random.shuffle(fake_idx)
        fake_common_anchor = common[fake_idx]

        # specific anchors: rows swapped within the same manipulation type
        specific_anchor = swap_spe_features(spe_label, specific)
        real_specific_anchor, fake_specific_anchor = specific_anchor.chunk(2)
        real_specific, fake_specific = specific.chunk(2)

        # contrast in the common space between real and fake
        loss_realcommon = self.contrastive_loss(real_common, real_common_anchor, fake_common_anchor)
        loss_fakecommon = self.contrastive_loss(fake_common, fake_common_anchor, real_common_anchor)

        # contrast in the specific space between real and fake
        loss_realspecific = self.contrastive_loss(real_specific, real_specific_anchor, fake_specific_anchor)
        loss_fakespecific = self.contrastive_loss(fake_specific, fake_specific_anchor, real_specific_anchor)

        return loss_realcommon + loss_fakecommon + loss_fakespecific + loss_realspecific
|
loss/cross_entropy_loss.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from .abstract_loss_func import AbstractLossClass
|
| 3 |
+
from metrics.registry import LOSSFUNC
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@LOSSFUNC.register_module(module_name="cross_entropy")
class CrossEntropyLoss(AbstractLossClass):
    """nn.CrossEntropyLoss wrapper registered as "cross_entropy"."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        """Cross-entropy of class logits against ground-truth indices.

        Args:
            inputs: (batch_size, num_classes) predicted scores.
            targets: (batch_size,) ground-truth class indices.

        Returns:
            Scalar cross-entropy loss tensor.
        """
        return self.loss_fn(inputs, targets)
|
loss/l1_loss.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from .abstract_loss_func import AbstractLossClass
|
| 3 |
+
from metrics.registry import LOSSFUNC
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@LOSSFUNC.register_module(module_name="l1loss")
class L1Loss(AbstractLossClass):
    """nn.L1Loss wrapper registered as "l1loss"."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.L1Loss()

    def forward(self, inputs, targets):
        """Mean absolute error between inputs and targets."""
        return self.loss_fn(inputs, targets)
|
metrics/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys
# Make sibling packages importable regardless of the working directory:
# add the package's parent directory and the project root to sys.path.
current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
project_root_dir = os.path.dirname(parent_dir)
sys.path.append(parent_dir)
sys.path.append(project_root_dir)
|
metrics/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (351 Bytes). View file
|
|
|
metrics/__pycache__/base_metrics_class.cpython-310.pyc
ADDED
|
Binary file (6.21 kB). View file
|
|
|
metrics/__pycache__/registry.cpython-310.pyc
ADDED
|
Binary file (1.01 kB). View file
|
|
|
metrics/base_metrics_class.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from sklearn import metrics
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_accracy(output, label):
    """Top-1 accuracy of class logits `output` against integer labels.

    Note: the (mis-spelled) name is kept for API compatibility with callers.
    """
    predicted = output.argmax(dim=1)
    return (predicted == label).sum().item() / predicted.size(0)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_prediction(output, label):
    """Stack the class-1 softmax probability and the label as a
    (batch, 2) float tensor: column 0 = prob, column 1 = label."""
    prob = nn.functional.softmax(output, dim=1)[:, 1].view(-1, 1)
    label_col = label.view(-1, 1).float()
    return torch.cat((prob, label_col), dim=1)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def calculate_metrics_for_train(label, output):
    """Compute (auc, eer, accuracy, ap) for one training batch.

    AUC/EER are returned as None when they are undefined (single-sample or
    single-class batches); accuracy and average precision are always
    computed.
    """
    # class-1 probability when output is two-class logits; otherwise the
    # output is assumed to already be a score/probability
    if output.size(1) == 2:
        prob = torch.softmax(output, dim=1)[:, 1]
    else:
        prob = output

    # Accuracy
    _, prediction = torch.max(output, 1)
    correct = (prediction == label).sum().item()
    accuracy = correct / prediction.size(0)

    # Average Precision
    y_true = label.cpu().detach().numpy()
    y_pred = prob.cpu().detach().numpy()
    ap = metrics.average_precision_score(y_true, y_pred)

    # AUC and EER
    try:
        fpr, tpr, thresholds = metrics.roc_curve(label.squeeze().cpu().numpy(),
                                                 prob.squeeze().cpu().numpy(),
                                                 pos_label=1)
    except ValueError:
        # roc_curve raises ValueError on degenerate input (e.g. a single
        # sample).  Fix: the previous bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit.
        return None, None, accuracy, ap

    if np.isnan(fpr[0]) or np.isnan(tpr[0]):
        # all samples within the batch are fake/real: AUC/EER undefined
        auc, eer = None, None
    else:
        auc = metrics.auc(fpr, tpr)
        fnr = 1 - tpr
        # EER: FPR at the point where FPR and FNR are closest
        eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]

    return auc, eer, accuracy, ap
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ------------ compute average metrics of batches---------------------
|
| 61 |
+
class Metrics_batch():
    """Accumulates per-batch classification metrics (acc/auc/eer/ap) and
    reports their running means across batches."""

    def __init__(self):
        # per-batch TPR curves interpolated on a shared FPR grid, so the
        # ROC curves can be averaged in _mean_auc
        self.tprs = []
        self.mean_fpr = np.linspace(0, 1, 100)
        self.aucs = []
        self.eers = []
        self.aps = []

        self.correct = 0
        self.total = 0
        self.losses = []

    def update(self, label, output):
        """Update all running metrics with one batch.

        Returns the batch's (acc, auc, eer, ap); auc/eer are -1 for
        degenerate (single-class) batches.
        """
        acc = self._update_acc(label, output)
        # class-1 probability when output is two-class logits; otherwise the
        # output is assumed to already be a score
        if output.size(1) == 2:
            prob = torch.softmax(output, dim=1)[:, 1]
        else:
            prob = output
        #label = 1-label
        #prob = torch.softmax(output, dim=1)[:, 1]
        auc, eer = self._update_auc(label, prob)
        ap = self._update_ap(label, prob)

        return acc, auc, eer, ap

    def _update_auc(self, lab, prob):
        # batch-level ROC; returns (-1, -1) when AUC/EER are undefined
        fpr, tpr, thresholds = metrics.roc_curve(lab.squeeze().cpu().numpy(),
                                                 prob.squeeze().cpu().numpy(),
                                                 pos_label=1)
        if np.isnan(fpr[0]) or np.isnan(tpr[0]):
            return -1, -1

        auc = metrics.auc(fpr, tpr)
        # interpolate TPR on the shared FPR grid for later curve averaging
        interp_tpr = np.interp(self.mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0
        self.tprs.append(interp_tpr)
        self.aucs.append(auc)

        # return auc

        # EER: FPR at the point where FPR and FNR are closest
        fnr = 1 - tpr
        eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
        self.eers.append(eer)

        return auc, eer

    def _update_acc(self, lab, output):
        _, prediction = torch.max(output, 1)  # argmax
        correct = (prediction == lab).sum().item()
        accuracy = correct / prediction.size(0)
        # self.accs.append(accuracy)
        self.correct = self.correct+correct
        self.total = self.total+lab.size(0)
        return accuracy

    def _update_ap(self, label, prob):
        y_true = label.cpu().detach().numpy()
        y_pred = prob.cpu().detach().numpy()
        ap = metrics.average_precision_score(y_true,y_pred)
        self.aps.append(ap)

        # ap is already a scalar, so np.mean is a pass-through here
        return np.mean(ap)

    def get_mean_metrics(self):
        """Means over all batches seen so far (std values are computed but
        only the means are returned)."""
        mean_acc, std_acc = self.correct/self.total, 0
        mean_auc, std_auc = self._mean_auc()
        mean_err, std_err = np.mean(self.eers), np.std(self.eers)
        mean_ap, std_ap = np.mean(self.aps), np.std(self.aps)

        return {'acc':mean_acc, 'auc':mean_auc, 'eer':mean_err, 'ap':mean_ap}

    def _mean_auc(self):
        # AUC of the averaged ROC curve (not the mean of per-batch AUCs)
        mean_tpr = np.mean(self.tprs, axis=0)
        mean_tpr[-1] = 1.0
        mean_auc = metrics.auc(self.mean_fpr, mean_tpr)
        std_auc = np.std(self.aucs)
        return mean_auc, std_auc

    def clear(self):
        """Reset every accumulator."""
        self.tprs.clear()
        self.aucs.clear()
        # self.accs.clear()
        self.correct=0
        self.total=0
        self.eers.clear()
        self.aps.clear()
        self.losses.clear()
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# ------------ compute average metrics of all data ---------------------
|
| 152 |
+
class Metrics_all():
    """Accumulates predictions over the whole evaluation set and computes
    dataset-level metrics at the end (instead of averaging batch metrics)."""

    def __init__(self):
        self.probs = []
        self.labels = []
        self.correct = 0
        self.total = 0

    def store(self, label, output):
        """Record one batch of two-class logits and labels."""
        # class-1 probability
        prob = torch.softmax(output, dim=1)[:, 1]
        _, prediction = torch.max(output, 1)  # argmax
        correct = (prediction == label).sum().item()
        self.correct += correct
        self.total += label.size(0)
        self.labels.append(label.squeeze().cpu().numpy())
        self.probs.append(prob.squeeze().cpu().numpy())

    def get_metrics(self):
        """Compute acc/auc/eer/ap over every stored sample."""
        y_pred = np.concatenate(self.probs)
        y_true = np.concatenate(self.labels)
        # auc
        fpr, tpr, thresholds = metrics.roc_curve(y_true,y_pred,pos_label=1)
        auc = metrics.auc(fpr, tpr)
        # eer: FPR at the point where FPR and FNR are closest
        fnr = 1 - tpr
        eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
        # ap
        ap = metrics.average_precision_score(y_true,y_pred)
        # acc
        acc = self.correct / self.total
        return {'acc':acc, 'auc':auc, 'eer':eer, 'ap':ap}

    def clear(self):
        """Reset stored predictions and counters."""
        self.probs.clear()
        self.labels.clear()
        self.correct = 0
        self.total = 0
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# only used to record a series of scalar value
|
| 191 |
+
class Recorder:
    """Running (weighted) average of a scalar series."""

    def __init__(self):
        self.sum = 0   # weighted sum of recorded items
        self.num = 0   # total weight recorded

    def update(self, item, num=1):
        """Record `item` with weight `num`; None items are ignored."""
        if item is None:
            return
        self.sum += item * num
        self.num += num

    def average(self):
        """Weighted mean of everything recorded, or None if nothing was."""
        return None if self.num == 0 else self.sum / self.num

    def clear(self):
        """Reset the accumulator."""
        self.sum = 0
        self.num = 0
|
metrics/registry.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class Registry(object):
    """Minimal name -> class registry with a decorator-based API."""

    def __init__(self):
        self.data = {}

    def register_module(self, module_name=None):
        """Return a class decorator that registers the class under
        `module_name` (or the class's own __name__ when omitted) and
        returns the class unchanged."""
        def _register(cls):
            key = module_name if module_name is not None else cls.__name__
            self.data[key] = cls
            return cls
        return _register

    def __getitem__(self, key):
        return self.data[key]

# global registries shared across the project
BACKBONE = Registry()
DETECTOR = Registry()
TRAINER = Registry()
LOSSFUNC = Registry()
|
metrics/utils.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn import metrics
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def parse_metric_for_print(metric_dict):
    """Format the best-metric dict into a human-readable report string.

    `metric_dict` maps dataset name -> {metric: value}, with a special
    'avg' entry holding the averaged metrics plus a 'dataset_dict'
    per-dataset breakdown.  Returns "\n" when metric_dict is None.

    Fix: the original shadowed the builtin `str` and built the report with
    quadratic `+=` concatenation; this version collects parts and joins
    them, producing byte-identical output.
    """
    if metric_dict is None:
        return "\n"
    parts = ["\n"]
    parts.append("================================ Each dataset best metric ================================ \n")
    for key, value in metric_dict.items():
        if key != 'avg':
            parts.append(f"| {key}: ")
            for k, v in value.items():
                parts.append(f" {k}={v} ")
            parts.append("| \n")
        else:
            parts.append("============================================================================================= \n")
            parts.append("================================== Average best metric ====================================== \n")
            for avg_key, avg_value in value.items():
                if avg_key == 'dataset_dict':
                    for ds_name, ds_value in avg_value.items():
                        parts.append(f"| {ds_name}: {ds_value} | \n")
                else:
                    parts.append(f"| avg {avg_key}: {avg_value} | \n")
    parts.append("=============================================================================================")
    return "".join(parts)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_test_metrics(y_pred, y_true, img_names=None, logger=None):
    """Compute frame-level test metrics (acc, auc, eer, ap) from predictions.

    Args:
        y_pred: 1-D (or squeezable) array of prediction scores in [0, 1].
        y_true: array of integer labels; any label >= 1 is collapsed to 1.
            NOTE: mutated in place by the label collapsing below.
        img_names: unused in this code path (the nested video helper that
            would consume it is defined but never called here).
        logger: unused here; kept for interface compatibility with callers.

    Returns:
        Dict with 'acc', 'auc', 'eer', 'ap' plus the (squeezed) 'pred' and
        the (binarized) 'label' arrays.
    """
    def get_video_metrics(image, pred, label):
        # Aggregate frame-level scores into per-video scores by averaging all
        # frames that share the same parent directory (assumed one dir per video),
        # then compute video-level AUC/EER.
        # NOTE(review): defined but not invoked in this function body.
        result_dict = {}
        new_label = []
        new_pred = []
        # print(image[0])
        # print(pred.shape)
        # print(label.shape)
        for item in np.transpose(np.stack((image, pred, label)), (1, 0)):

            s = item[0]
            # Handle both Windows and POSIX path separators in frame paths.
            if '\\' in s:
                parts = s.split('\\')
            else:
                parts = s.split('/')
            a = parts[-2]  # parent directory -> video id
            b = parts[-1]  # frame file name (unused)

            if a not in result_dict:
                result_dict[a] = []

            result_dict[a].append(item)
        image_arr = list(result_dict.values())

        for video in image_arr:
            pred_sum = 0
            label_sum = 0
            leng = 0
            for frame in video:
                pred_sum += float(frame[1])
                label_sum += int(frame[2])
                leng += 1
            new_pred.append(pred_sum / leng)
            new_label.append(int(label_sum / leng))
        fpr, tpr, thresholds = metrics.roc_curve(new_label, new_pred)
        v_auc = metrics.auc(fpr, tpr)
        fnr = 1 - tpr
        # EER: FPR at the threshold where FPR and FNR are closest.
        v_eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
        return v_auc, v_eer


    y_pred = y_pred.squeeze()

    # For UCF, where labels for different manipulations are not consistent.
    y_true[y_true >= 1] = 1
    # auc
    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1)
    auc = metrics.auc(fpr, tpr)
    # eer: FPR at the point where FPR and FNR intersect (approximately)
    fnr = 1 - tpr
    eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    # ap
    ap = metrics.average_precision_score(y_true, y_pred)
    # acc: threshold scores at 0.5
    prediction_class = (y_pred > 0.5).astype(int)
    correct = (prediction_class == np.clip(y_true, a_min=0, a_max=1)).sum().item()
    acc = correct / len(prediction_class)

    return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap, 'pred': y_pred, 'label': y_true}
|
networks/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys
# Make sibling packages importable regardless of the working directory by
# pushing the training dir and the project root onto sys.path before the
# project-local imports below.
current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
project_root_dir = os.path.dirname(parent_dir)
sys.path.append(parent_dir)
sys.path.append(project_root_dir)

from metrics.registry import BACKBONE

# Importing the module triggers its @BACKBONE.register_module decorator.
from .xception import Xception
|
networks/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (447 Bytes). View file
|
|
|
networks/__pycache__/xception.cpython-310.pyc
ADDED
|
Binary file (6.7 kB). View file
|
|
|
networks/xception.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
# author: Zhiyuan Yan
|
| 3 |
+
# email: zhiyuanyan@link.cuhk.edu.cn
|
| 4 |
+
# date: 2023-0706
|
| 5 |
+
|
| 6 |
+
The code is mainly modified from GitHub link below:
|
| 7 |
+
https://github.com/ondyari/FaceForensics/blob/master/classification/network/xception.py
|
| 8 |
+
'''
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import argparse
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
import math
|
| 15 |
+
import torch
|
| 16 |
+
# import pretrainedmodels
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
|
| 20 |
+
import torch.utils.model_zoo as model_zoo
|
| 21 |
+
from torch.nn import init
|
| 22 |
+
from typing import Union
|
| 23 |
+
from metrics.registry import BACKBONE
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution.

    A spatial depthwise conv (one filter per input channel) followed by a
    1x1 pointwise conv that mixes channels, as used throughout Xception.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()
        # Depthwise stage: groups == in_channels keeps channels independent.
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride,
                               padding, dilation, groups=in_channels, bias=bias)
        # Pointwise stage: 1x1 conv across channels, no spatial extent.
        self.pointwise = nn.Conv2d(in_channels, out_channels,
                                   kernel_size=1, stride=1, padding=0,
                                   dilation=1, groups=1, bias=bias)

    def forward(self, x):
        return self.pointwise(self.conv1(x))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Block(nn.Module):
    """Xception residual block.

    A stack of `reps` separable 3x3 convolutions with BatchNorm, plus an
    additive skip path: a 1x1 conv + BN projection when the channel count or
    stride changes, identity otherwise.
    """

    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()

        # Projection skip only when shape changes; identity skip otherwise.
        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(in_filters, out_filters,
                                  1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)
        rep = []

        filters = in_filters
        if grow_first:  # whether the number of filters grows first
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters,
                                       3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        # Remaining separable convs keep the channel count fixed.
        for i in range(reps-1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters,
                                       3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            # Channel growth happens at the end of the stack instead.
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters,
                                       3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            rep = rep[1:]
        else:
            # The first activation must not be in-place: the block input is
            # reused by the skip path in forward().
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        # Residual addition.
        x += skip
        return x
|
| 98 |
+
|
| 99 |
+
def add_gaussian_noise(ins, mean=0, stddev=0.2):
    """Return `ins` plus i.i.d. Gaussian noise of the same shape.

    Args:
        ins: input tensor; its device and dtype are matched by the noise.
        mean: mean of the Gaussian noise.
        stddev: standard deviation of the Gaussian noise.

    Returns:
        A new tensor `ins + noise`; `ins` itself is not modified.
    """
    # torch.empty_like matches device/dtype and, like the original
    # `ins.data.new(ins.size())`, creates the noise outside the autograd
    # graph; it avoids the deprecated `.data` accessor and legacy `.new()`.
    noise = torch.empty_like(ins).normal_(mean, stddev)
    return ins + noise
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@BACKBONE.register_module(module_name="xception")
class Xception(nn.Module):
    """
    Xception optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1610.02357.pdf

    Modes observed in this code: 'shallow_xception' skips the middle flow and
    most of block8-11, 'adjust_channel' appends a 2048->512 projection after
    the features, and 'adjust_channel_iid' additionally shrinks the classifier
    input to 512 before behaving like 'adjust_channel'.
    """

    def __init__(self, xception_config):
        """ Constructor
        Args:
            xception_config: configuration file with the dict format
                (keys used here: "num_classes", "mode", "inc", "dropout")
        """
        super(Xception, self).__init__()
        self.num_classes = xception_config["num_classes"]
        self.mode = xception_config["mode"]
        inc = xception_config["inc"]  # number of input channels
        dropout = xception_config["dropout"]  # dropout prob; falsy disables dropout

        # Entry flow
        self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)

        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        # do relu here

        self.block1 = Block(
            64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(
            128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(
            256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # middle flow
        self.block4 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)

        self.block8 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)

        # Exit flow
        self.block12 = Block(
            728, 1024, 2, 2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        # do relu here
        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)
        # used for iid
        final_channel = 2048
        if self.mode == 'adjust_channel_iid':
            # Classifier sees the 512-ch projected features; from here on the
            # model behaves exactly like 'adjust_channel'.
            final_channel = 512
            self.mode = 'adjust_channel'
        self.last_linear = nn.Linear(final_channel, self.num_classes)
        if dropout:
            # Dropout-wrapped classifier replaces the plain Linear above.
            self.last_linear = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(final_channel, self.num_classes)
            )

        # 1x1 projection used only when self.mode == 'adjust_channel'.
        self.adjust_channel = nn.Sequential(
            nn.Conv2d(2048, 512, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=False),
        )

    def fea_part1_0(self, x):
        # First stem stage: conv1 -> bn1 -> relu.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        return x

    def fea_part1_1(self, x):
        # Second stem stage: conv2 -> bn2 -> relu.

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

    def fea_part1(self, x):
        # Full stem: both conv stages in one call.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

    def fea_part2(self, x):
        # Entry-flow blocks (each downsamples by 2).
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        return x

    def fea_part3(self, x):
        # First half of the middle flow; skipped entirely for the shallow variant.
        if self.mode == "shallow_xception":
            return x
        else:
            x = self.block4(x)
            x = self.block5(x)
            x = self.block6(x)
            x = self.block7(x)
            return x

    def fea_part4(self, x):
        # Second half of the middle flow plus the exit block; the shallow
        # variant jumps straight to block12.
        if self.mode == "shallow_xception":
            x = self.block12(x)
        else:
            x = self.block8(x)
            x = self.block9(x)
            x = self.block10(x)
            x = self.block11(x)
            x = self.block12(x)
        return x

    def fea_part5(self, x):
        # Exit-flow separable convs; note: no ReLU after bn4 here, the
        # classifier applies it (unless in 'adjust_channel' mode).
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)

        return x

    def features(self, input):
        """Run the full backbone and return the (B, C, H, W) feature map."""
        x = self.fea_part1(input)

        x = self.fea_part2(x)
        x = self.fea_part3(x)
        x = self.fea_part4(x)

        x = self.fea_part5(x)

        if self.mode == 'adjust_channel':
            x = self.adjust_channel(x)

        return x

    def classifier(self, features,id_feat=None):
        """Pool features and apply the final linear head.

        Args:
            features: backbone output; 4-D maps are globally average-pooled.
            id_feat: optional identity feature subtracted before the linear
                layer (used by the IID variant).
        """
        # for iid: adjust_channel output already ends in a ReLU.
        if self.mode == 'adjust_channel':
            x = features
        else:
            x = self.relu(features)

        if len(x.shape) == 4:
            x = F.adaptive_avg_pool2d(x, (1, 1))
            x = x.view(x.size(0), -1)
        # Cache the pooled embedding for callers that inspect it after forward.
        self.last_emb = x
        # for iid
        if id_feat!=None:
            out = self.last_linear(x-id_feat)
        else:
            out = self.last_linear(x)
        return out

    def forward(self, input):
        """Return (logits, feature_map) for a batch of images."""
        x = self.features(input)
        out = self.classifier(x)
        return out, x
|
optimizor/LinearLR.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.optim import SGD
|
| 3 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 4 |
+
|
| 5 |
+
class LinearDecayLR(_LRScheduler):
    """Constant LR until `start_decay`, then linear decay to 0 at `n_epoch`.

    For epoch e > start_decay each group's LR is
        base_lr * (1 - (e - start_decay) / (n_epoch - start_decay)).

    Fix over the original: `get_lr` now returns one LR per parameter group
    (the original read only `base_lrs[0]` and returned a single-element list,
    which breaks optimizers with multiple param groups). Single-group
    behavior is unchanged.
    """

    def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1):
        """
        Args:
            optimizer: wrapped optimizer.
            n_epoch: epoch at which the LR reaches 0.
            start_decay: last epoch with the full base LR.
            last_epoch: index of the last epoch (standard scheduler arg).
        """
        self.start_decay = start_decay
        self.n_epoch = n_epoch
        super(LinearDecayLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        last_epoch = self.last_epoch
        if last_epoch > self.start_decay:
            # Fraction of the decay window already consumed.
            decay_span = self.n_epoch - self.start_decay
            progressed = last_epoch - self.start_decay
            return [b_lr - b_lr / decay_span * progressed
                    for b_lr in self.base_lrs]
        return list(self.base_lrs)
|
optimizor/SAM.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# borrowed from
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
def disable_running_stats(model):
    """Freeze BatchNorm2d running statistics across `model`.

    Each BatchNorm2d module's momentum is saved on the module as
    `backup_momentum` and then zeroed, so the running mean/var stop updating
    (used for the second forward pass of SAM); `enable_running_stats`
    restores the saved values.
    """
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.backup_momentum = module.momentum
            module.momentum = 0
|
| 15 |
+
|
| 16 |
+
def enable_running_stats(model):
    """Restore BatchNorm2d momentum previously saved by disable_running_stats.

    Modules without a `backup_momentum` attribute are left untouched.
    """
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum
|
| 22 |
+
|
| 23 |
+
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization (SAM) wrapper optimizer.

    Two-phase update: `first_step` perturbs each parameter towards the local
    loss maximum within an L2 ball of radius `rho` (scaled by the global
    gradient norm); after a second forward/backward pass, `second_step`
    undoes the perturbation and applies the wrapped base optimizer's update
    using the gradients computed at the perturbed point.
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        # `base_optimizer` is an optimizer CLASS (e.g. torch.optim.SGD);
        # it is instantiated below over our param_groups.
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share param_groups with the base optimizer so LR schedulers attached
        # to either one stay in sync.
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Ascend to w + e(w); the perturbation e(w) is cached per-parameter."""
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # epsilon to avoid division by zero when all grads vanish.
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Return to w and apply the base optimizer using the new gradients."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Full SAM step: requires a closure performing forward + backward."""
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Global L2 norm of all parameter gradients (on one device)."""
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
            torch.stack([
                p.grad.norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm
|
trainer/trainer.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This script was adapted from the DeepfakeBench training code,
|
| 2 |
+
# originally authored by Zhiyuan Yan (zhiyuanyan@link.cuhk.edu.cn)
|
| 3 |
+
|
| 4 |
+
# Original: https://github.com/SCLBD/DeepfakeBench/blob/main/training/train.py
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
current_file_path = os.path.abspath(__file__)
|
| 9 |
+
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
|
| 10 |
+
project_root_dir = os.path.dirname(parent_dir)
|
| 11 |
+
sys.path.append(parent_dir)
|
| 12 |
+
sys.path.append(project_root_dir)
|
| 13 |
+
|
| 14 |
+
import pickle
|
| 15 |
+
import datetime
|
| 16 |
+
import logging
|
| 17 |
+
import numpy as np
|
| 18 |
+
from copy import deepcopy
|
| 19 |
+
from collections import defaultdict
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
import time
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
import torch.optim as optim
|
| 26 |
+
from torch.nn import DataParallel
|
| 27 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 28 |
+
from metrics.base_metrics_class import Recorder
|
| 29 |
+
from torch.optim.swa_utils import AveragedModel, SWALR
|
| 30 |
+
from torch import distributed as dist
|
| 31 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 32 |
+
from sklearn import metrics
|
| 33 |
+
from metrics.utils import get_test_metrics
|
| 34 |
+
|
| 35 |
+
# Dataset keys treated as FaceForensics++ variants by the trainer.
FFpp_pool=['FaceForensics++','FF-DF','FF-F2F','FF-FS','FF-NT']#
# Default device for the whole training run (single-process case).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Trainer(object):
|
| 40 |
+
    def __init__(
        self,
        config,
        model,
        optimizer,
        scheduler,
        logger,
        metric_scoring='auc',
        swa_model=None
    ):
        """Bundle all training components and prepare the logging directory.

        Args:
            config: dict of training options; 'log_dir' and 'ddp' are read here.
            model: detector exposing get_losses / get_train_metrics.
            optimizer: torch optimizer (or SAM wrapper).
            scheduler: LR scheduler; not validated by the None-check below.
            logger: logging.Logger for progress messages.
            metric_scoring: metric used to pick the best checkpoint; for 'eer'
                lower is better, for everything else higher is better.
            swa_model: optional AveragedModel for stochastic weight averaging.
        """
        # check if all the necessary components are implemented
        if config is None or model is None or optimizer is None or logger is None:
            raise ValueError("config, model, optimizier, logger, and tensorboard writer must be implemented")

        self.config = config
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.swa_model = swa_model
        self.writers = {}  # dict to maintain different tensorboard writers for each dataset and metric
        self.logger = logger
        self.metric_scoring = metric_scoring
        # maintain the best metric of all epochs; the default value is -inf
        # (or +inf when scoring by EER, where lower is better).
        self.best_metrics_all_time = defaultdict(
            lambda: defaultdict(lambda: float('-inf')
            if self.metric_scoring != 'eer' else float('inf'))
        )
        self.speed_up()  # move model to GPU

        # create directory path
        self.log_dir = self.config['log_dir']
        print("Making dir ", self.log_dir)
        os.makedirs(self.log_dir, exist_ok=True)
|
| 73 |
+
|
| 74 |
+
def get_writer(self, phase, dataset_key, metric_key):
|
| 75 |
+
phase = phase.split('/')[-1]
|
| 76 |
+
dataset_key = dataset_key.split('/')[-1]
|
| 77 |
+
metric_key = metric_key.split('/')[-1]
|
| 78 |
+
writer_key = f"{phase}-{dataset_key}-{metric_key}"
|
| 79 |
+
if writer_key not in self.writers:
|
| 80 |
+
# update directory path
|
| 81 |
+
writer_path = os.path.join(
|
| 82 |
+
self.log_dir,
|
| 83 |
+
phase,
|
| 84 |
+
dataset_key,
|
| 85 |
+
metric_key,
|
| 86 |
+
"metric_board"
|
| 87 |
+
)
|
| 88 |
+
os.makedirs(writer_path, exist_ok=True)
|
| 89 |
+
# update writers dictionary
|
| 90 |
+
self.writers[writer_key] = SummaryWriter(writer_path)
|
| 91 |
+
return self.writers[writer_key]
|
| 92 |
+
|
| 93 |
+
    def speed_up(self):
        """Move the model to the selected device and optionally wrap it in DDP."""
        self.model.to(device)
        # Stash the device on the model so detector code can query it later.
        self.model.device = device
        if self.config['ddp'] == True:
            num_gpus = torch.cuda.device_count()
            print(f'avai gpus: {num_gpus}')
            # local_rank=[i for i in range(0,num_gpus)]
            # find_unused_parameters=True: some detector branches may be skipped
            # in a forward pass, so DDP must tolerate unused parameters.
            self.model = DDP(self.model, device_ids=[self.config['local_rank']],find_unused_parameters=True, output_device=self.config['local_rank'])
            #self.optimizer = nn.DataParallel(self.optimizer, device_ids=[int(os.environ['LOCAL_RANK'])])
|
| 102 |
+
|
| 103 |
+
def setTrain(self):
|
| 104 |
+
self.model.train()
|
| 105 |
+
self.train = True
|
| 106 |
+
|
| 107 |
+
def setEval(self):
|
| 108 |
+
self.model.eval()
|
| 109 |
+
self.train = False
|
| 110 |
+
|
| 111 |
+
    def load_ckpt(self, model_path):
        """Load model weights from `model_path`.

        A '.p' suffix is treated as a pickled full model whose state_dict is
        extracted; any other suffix is assumed to be a plain state_dict.

        Raises:
            NotImplementedError: if `model_path` is not an existing file.
        """
        if os.path.isfile(model_path):
            # map_location='cpu' so checkpoints load regardless of saving device.
            saved = torch.load(model_path, map_location='cpu')
            suffix = model_path.split('.')[-1]
            if suffix == 'p':
                self.model.load_state_dict(saved.state_dict())
            else:
                self.model.load_state_dict(saved)
            self.logger.info('Model found in {}'.format(model_path))
        else:
            raise NotImplementedError(
                "=> no model found at '{}'".format(model_path))
|
| 123 |
+
|
| 124 |
+
    def save_ckpt(self, phase, dataset_key,ckpt_info=None):
        """Save the current model weights as <log_dir>/ckpt_best.pth.

        Args:
            phase: training phase label; only used in the log message context.
                NOTE(review): `phase` and `dataset_key` do not affect the save
                path — every call overwrites the same ckpt_best.pth.
            dataset_key: dataset identifier (unused for the path, see above).
            ckpt_info: free-form info string included in the log message.
        """
        save_dir = self.log_dir
        os.makedirs(save_dir, exist_ok=True)
        ckpt_name = f"ckpt_best.pth"
        save_path = os.path.join(save_dir, ckpt_name)
        if self.config['ddp'] == True:
            torch.save(self.model.state_dict(), save_path)
        else:
            if 'svdd' in self.config['model_name']:
                # Deep-SVDD style models also need their hypersphere radius R
                # and center c to be restored later.
                torch.save({'R': self.model.R,
                            'c': self.model.c,
                            'state_dict': self.model.state_dict(),}, save_path)
            else:
                torch.save(self.model.state_dict(), save_path)
        self.logger.info(f"Checkpoint saved to {save_path}, current ckpt is {ckpt_info}")
|
| 139 |
+
|
| 140 |
+
def save_swa_ckpt(self):
|
| 141 |
+
save_dir = self.log_dir
|
| 142 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 143 |
+
ckpt_name = f"swa.pth"
|
| 144 |
+
save_path = os.path.join(save_dir, ckpt_name)
|
| 145 |
+
torch.save(self.swa_model.state_dict(), save_path)
|
| 146 |
+
self.logger.info(f"SWA Checkpoint saved to {save_path}")
|
| 147 |
+
|
| 148 |
+
def save_feat(self, phase, fea, dataset_key):
|
| 149 |
+
save_dir = os.path.join(self.log_dir, phase, dataset_key)
|
| 150 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 151 |
+
features = fea
|
| 152 |
+
feat_name = f"feat_best.npy"
|
| 153 |
+
save_path = os.path.join(save_dir, feat_name)
|
| 154 |
+
np.save(save_path, features)
|
| 155 |
+
self.logger.info(f"Feature saved to {save_path}")
|
| 156 |
+
|
| 157 |
+
def save_data_dict(self, phase, data_dict, dataset_key):
|
| 158 |
+
save_dir = os.path.join(self.log_dir, phase, dataset_key)
|
| 159 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 160 |
+
file_path = os.path.join(save_dir, f'data_dict_{phase}.pickle')
|
| 161 |
+
with open(file_path, 'wb') as file:
|
| 162 |
+
pickle.dump(data_dict, file)
|
| 163 |
+
self.logger.info(f"data_dict saved to {file_path}")
|
| 164 |
+
|
| 165 |
+
def save_metrics(self, phase, metric_one_dataset, dataset_key):
|
| 166 |
+
save_dir = os.path.join(self.log_dir, phase, dataset_key)
|
| 167 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 168 |
+
file_path = os.path.join(save_dir, 'metric_dict_best.pickle')
|
| 169 |
+
with open(file_path, 'wb') as file:
|
| 170 |
+
pickle.dump(metric_one_dataset, file)
|
| 171 |
+
self.logger.info(f"Metrics saved to {file_path}")
|
| 172 |
+
|
| 173 |
+
    def train_step(self,data_dict):
        """Run one optimization step and return (losses, predictions).

        For the SAM optimizer this performs the required two forward/backward
        passes (first_step perturbs the weights, second_step applies the real
        update) and returns the losses/predictions from the first, un-perturbed
        pass. Otherwise a standard single forward/backward/step is taken.
        """
        if self.config['optimizer']['type']=='sam':
            # NOTE(review): this branch calls get_losses directly on self.model;
            # under DDP it would need self.model.module — confirm SAM is never
            # combined with DDP in the configs.
            for i in range(2):
                predictions = self.model(data_dict)
                losses = self.model.get_losses(data_dict, predictions)
                if i == 0:
                    # Report metrics from the un-perturbed first pass.
                    pred_first = predictions
                    losses_first = losses
                self.optimizer.zero_grad()
                losses['overall'].backward()
                if i == 0:
                    self.optimizer.first_step(zero_grad=True)
                else:
                    self.optimizer.second_step(zero_grad=True)
            return losses_first, pred_first
        else:
            predictions = self.model(data_dict)
            # Under DDP the detector-specific API lives on .module.
            if type(self.model) is DDP:
                losses = self.model.module.get_losses(data_dict, predictions)
            else:
                losses = self.model.get_losses(data_dict, predictions)
            self.optimizer.zero_grad()
            losses['overall'].backward()
            self.optimizer.step()

            return losses,predictions
|
| 199 |
+
|
| 200 |
+
def train_epoch(
    self,
    epoch,
    train_data_loader,
    validation_data_loaders=None
):
    """Train the model for one epoch.

    Runs the optimization loop over ``train_data_loader``, logging batch
    losses/metrics to tensorboard every 300 iterations (rank 0 only) and
    running validation ``times_per_epoch`` times per epoch.

    Returns:
        ``self.best_metrics_all_time`` as returned by the last validation
        run on this rank, or ``None`` if no validation was performed.
    """
    self.logger.info("===> Epoch[{}] start!".format(epoch))
    # validate twice per epoch after the first epoch, once during it
    if epoch >= 1:
        times_per_epoch = 2
    else:
        times_per_epoch = 1

    # Fix: guard against loaders shorter than times_per_epoch, which would
    # make validation_step == 0 and crash the modulo check below.
    validation_step = max(1, len(train_data_loader) // times_per_epoch)
    step_cnt = epoch * len(train_data_loader)

    # Fix: initialise so the final return cannot raise UnboundLocalError
    # when the validation condition never fires in this epoch.
    validation_best_metric = None

    # batch-level running averages, cleared after every tensorboard flush
    train_recorder_loss = defaultdict(Recorder)
    train_recorder_metric = defaultdict(Recorder)

    for iteration, data_dict in tqdm(enumerate(train_data_loader), total=len(train_data_loader)):
        self.setTrain()
        # move all tensor entries to GPU ('name' is a string field, stays on CPU)
        for key in data_dict.keys():
            if data_dict[key] is not None and key != 'name':
                data_dict[key] = data_dict[key].cuda()

        losses, predictions = self.train_step(data_dict)

        # update stochastic-weight-averaging model once past the SWA start epoch
        if 'SWA' in self.config and self.config['SWA'] and epoch > self.config['swa_start']:
            self.swa_model.update_parameters(self.model)

        # compute training metric for each batch (unwrap DDP if needed)
        if type(self.model) is DDP:
            batch_metrics = self.model.module.get_train_metrics(data_dict, predictions)
        else:
            batch_metrics = self.model.get_train_metrics(data_dict, predictions)

        # store metrics by recorder
        for name, value in batch_metrics.items():
            train_recorder_metric[name].update(value)
        # store losses by recorder
        for name, value in losses.items():
            train_recorder_loss[name].update(value)

        # tensorboard logging every 300 iterations, rank 0 only
        if iteration % 300 == 0 and self.config['local_rank'] == 0:
            # NOTE(review): the scheduler only steps inside this logging
            # branch (every 300 iters, rank 0) — confirm this is intended.
            if self.config['SWA'] and (epoch > self.config['swa_start'] or self.config['dry_run']):
                self.scheduler.step()
            # info for loss
            loss_str = f"Iter: {step_cnt} "
            for k, v in train_recorder_loss.items():
                v_avg = v.average()
                if v_avg is None:
                    loss_str += f"training-loss, {k}: not calculated"
                    continue
                loss_str += f"training-loss, {k}: {v_avg} "
                # tensorboard-1. loss
                processed_train_dataset = [dataset.split('/')[-1] for dataset in self.config['train_dataset']]
                processed_train_dataset = ','.join(processed_train_dataset)
                writer = self.get_writer('train', processed_train_dataset, k)
                writer.add_scalar(f'train_loss/{k}', v_avg, global_step=step_cnt)
            self.logger.info(loss_str)
            # info for metric
            metric_str = f"Iter: {step_cnt} "
            for k, v in train_recorder_metric.items():
                v_avg = v.average()
                if v_avg is None:
                    metric_str += f"training-metric, {k}: not calculated "
                    continue
                metric_str += f"training-metric, {k}: {v_avg} "
                # tensorboard-2. metric
                processed_train_dataset = [dataset.split('/')[-1] for dataset in self.config['train_dataset']]
                processed_train_dataset = ','.join(processed_train_dataset)
                writer = self.get_writer('train', processed_train_dataset, k)
                writer.add_scalar(f'train_metric/{k}', v_avg, global_step=step_cnt)
            self.logger.info(metric_str)

            # clear recorders: each log line reflects only the last 300 batches
            for name, recorder in train_recorder_loss.items():
                recorder.clear()
            for name, recorder in train_recorder_metric.items():
                recorder.clear()

        # run validation (only on rank 0 when using DDP)
        if (step_cnt + 1) % validation_step == 0:
            if validation_data_loaders is not None and ((not self.config['ddp']) or (self.config['ddp'] and dist.get_rank() == 0)):
                self.logger.info("===> Validation start!")
                validation_best_metric = self.eval(
                    eval_data_loaders=validation_data_loaders,
                    eval_stage="validation",
                    step=step_cnt,
                    epoch=epoch,
                    iteration=iteration
                )
            else:
                validation_best_metric = None

        step_cnt += 1

        # move tensors back to CPU to release GPU memory between iterations
        for key in data_dict.keys():
            if data_dict[key] is not None and key != 'name':
                data_dict[key] = data_dict[key].cpu()
    return validation_best_metric
|
| 309 |
+
|
| 310 |
+
def get_respect_acc(self, prob, label):
    """Compute per-class accuracy from probability scores.

    Assumes the batch is ordered with all real samples (label 0) first,
    followed by all fake samples (label 1); thresholds probs at 0.5.

    Returns:
        (acc_real, acc_fake) as floats.
    """
    hard_pred = (prob > 0.5).astype(int)
    correct = hard_pred == label
    n_real = len(label) - np.count_nonzero(label)
    real_part, fake_part = correct[:n_real], correct[n_real:]
    acc_real = np.count_nonzero(real_part) / len(real_part)
    acc_fake = np.count_nonzero(fake_part) / len(fake_part)
    return acc_real, acc_fake
|
| 317 |
+
|
| 318 |
+
def eval_one_dataset(self, data_loader):
    """Run inference over one evaluation dataloader and collect outputs.

    Binarises labels (any non-zero method-specific label becomes 1),
    gathers per-sample predicted classes and feature vectors, and
    accumulates losses per batch. Loss computation is skipped when the
    model is an SWA ``AveragedModel`` (it has no ``get_losses`` hook).

    Returns:
        Tuple of (loss recorders dict, predictions array, labels array,
        features array).
    """
    # define eval recorder
    eval_recorder_loss = defaultdict(Recorder)
    prediction_lists = []
    feature_lists = []
    label_lists = []
    for i, data_dict in tqdm(enumerate(data_loader), total=len(data_loader)):
        # get data
        if 'label_spe' in data_dict:
            data_dict.pop('label_spe')  # remove the specific label
        data_dict['label'] = torch.where(data_dict['label'] != 0, 1, 0)  # fix the label to 0 and 1 only
        # move data to GPU; idiom fix: identity check instead of `!= None`
        # NOTE(review): unlike train_epoch this also moves the 'name' key —
        # assumes eval dicts contain only tensors; confirm against the loader.
        for key in data_dict.keys():
            if data_dict[key] is not None:
                data_dict[key] = data_dict[key].cuda()
        # model forward without considering gradient computation
        predictions = self.inference(data_dict)  # dict with keys cls, feat

        label_lists += list(data_dict['label'].cpu().detach().numpy())
        # Get the predicted class for each sample in the batch
        _, predicted_classes = torch.max(predictions['cls'], dim=1)
        # Convert the predicted class indices to a list and collect them
        prediction_lists += predicted_classes.cpu().detach().numpy().tolist()
        feature_lists += list(predictions['feat'].cpu().detach().numpy())
        if type(self.model) is not AveragedModel:
            # compute all losses for each batch (unwrap DDP if needed)
            if type(self.model) is DDP:
                losses = self.model.module.get_losses(data_dict, predictions)
            else:
                losses = self.model.get_losses(data_dict, predictions)

            # store losses by recorder
            for name, value in losses.items():
                eval_recorder_loss[name].update(value)
    return eval_recorder_loss, np.array(prediction_lists), np.array(label_lists), np.array(feature_lists)
|
| 353 |
+
|
| 354 |
+
def save_best(self,epoch,iteration,step,losses_one_dataset_recorder,key,metric_one_dataset,eval_stage):
    """Track the best score for `key`, checkpoint on improvement, and log losses/metrics.

    epoch/iteration/step: current training position (ckpt naming / tensorboard x-axis).
    losses_one_dataset_recorder: dict of loss recorders for this dataset, or None
        (e.g. for the synthetic 'avg' dataset).
    key: dataset name, or 'avg' for the cross-dataset average.
    metric_one_dataset: metric dict; may also carry 'pred'/'label'/'dataset_dict' entries.
    eval_stage: evaluation phase name (e.g. 'validation') used as log/tensorboard prefix.
    """
    # Best score so far for this dataset: 'eer' is lower-is-better,
    # every other metric is higher-is-better.
    best_metric = self.best_metrics_all_time[key].get(self.metric_scoring,
                                                      float('-inf') if self.metric_scoring != 'eer' else float(
                                                          'inf'))
    # Check if the current score is an improvement
    improved = (metric_one_dataset[self.metric_scoring] > best_metric) if self.metric_scoring != 'eer' else (
        metric_one_dataset[self.metric_scoring] < best_metric)
    if improved:
        # Update the best metric
        self.best_metrics_all_time[key][self.metric_scoring] = metric_one_dataset[self.metric_scoring]
        if key == 'avg':
            # for the average entry also remember the per-dataset breakdown
            self.best_metrics_all_time[key]['dataset_dict'] = metric_one_dataset['dataset_dict']
        # Save checkpoint, feature, and metrics if specified in config
        # (FFpp sub-datasets are excluded from checkpointing)
        if eval_stage=='validation' and self.config['save_ckpt'] and key not in FFpp_pool:
            self.save_ckpt(eval_stage, key, f"{epoch}+{iteration}")
        self.save_metrics(eval_stage, metric_one_dataset, key)
    if losses_one_dataset_recorder is not None:
        # info for each dataset
        loss_str = f"dataset: {key} step: {step} "
        for k, v in losses_one_dataset_recorder.items():
            writer = self.get_writer(eval_stage, key, k)
            v_avg = v.average()
            if v_avg == None:
                print(f'{k} is not calculated')
                continue
            # tensorboard-1. loss
            writer.add_scalar(f'{eval_stage}_losses/{k}', v_avg, global_step=step)
            loss_str += f"{eval_stage}-loss, {k}: {v_avg} "
        self.logger.info(loss_str)
    # tqdm.write(loss_str)
    metric_str = f"dataset: {key} step: {step} "
    for k, v in metric_one_dataset.items():
        # skip the raw arrays / breakdown entries, log scalar metrics only
        if k == 'pred' or k == 'label' or k=='dataset_dict':
            continue
        metric_str += f"{eval_stage}-metric, {k}: {v} "
        # tensorboard-2. metric
        writer = self.get_writer(eval_stage, key, k)
        writer.add_scalar(f'{eval_stage}_metrics/{k}', v, global_step=step)
    if 'pred' in metric_one_dataset:
        # per-class accuracy; get_respect_acc assumes reals (label 0) come first
        acc_real, acc_fake = self.get_respect_acc(metric_one_dataset['pred'], metric_one_dataset['label'])
        metric_str += f'{eval_stage}-metric, acc_real:{acc_real}; acc_fake:{acc_fake}'
        # NOTE(review): `writer` here is whichever writer the loop above bound
        # last — and is unbound if every metric key was skipped; confirm intended.
        writer.add_scalar(f'{eval_stage}_metrics/acc_real', acc_real, global_step=step)
        writer.add_scalar(f'{eval_stage}_metrics/acc_fake', acc_fake, global_step=step)
    self.logger.info(metric_str)
|
| 398 |
+
|
| 399 |
+
def eval(self, eval_data_loaders, eval_stage, step=None, epoch=None, iteration=None):
    """Evaluate on every dataset in `eval_data_loaders` and update best metrics.

    For each dataset: run inference, compute test metrics, accumulate them
    into a cross-dataset average, and delegate best-tracking/checkpointing
    to save_best. SWA AveragedModel runs are only logged, never saved.

    Returns:
        self.best_metrics_all_time — all best metrics for determining the best ckpt.
    """
    # set model to eval mode
    self.setEval()

    # define eval recorder
    losses_all_datasets = {}
    metrics_all_datasets = {}
    best_metrics_per_dataset = defaultdict(dict)  # best metric for each dataset, for each metric
    # running sums across datasets, divided by len(keys) at the end
    avg_metric = {'acc': 0, 'auc': 0, 'eer': 0, 'ap': 0,'dataset_dict':{}} #'video_auc': 0
    keys = eval_data_loaders.keys()
    for key in keys:
        # compute loss for each dataset
        losses_one_dataset_recorder, predictions_nps, label_nps, feature_nps = self.eval_one_dataset(eval_data_loaders[key])
        losses_all_datasets[key] = losses_one_dataset_recorder
        metric_one_dataset=get_test_metrics(y_pred=predictions_nps,y_true=label_nps, logger=self.logger)

        # accumulate the scalar metrics into the cross-dataset sums
        for metric_name, value in metric_one_dataset.items():
            if metric_name in avg_metric:
                avg_metric[metric_name]+=value
        avg_metric['dataset_dict'][key] = metric_one_dataset[self.metric_scoring]
        if type(self.model) is AveragedModel:
            # SWA final evaluation: log metrics only, skip best-tracking/saving
            metric_str = f"Iter Final for SWA: "
            for k, v in metric_one_dataset.items():
                metric_str += f"{eval_stage}-metric, {k}: {v} "
            self.logger.info(metric_str)
            continue
        self.save_best(epoch,iteration,step,losses_one_dataset_recorder,key,metric_one_dataset,eval_stage)

    if len(keys)>0 and self.config.get('save_avg',False):
        # calculate avg value
        for key in avg_metric:
            if key != 'dataset_dict':
                avg_metric[key] /= len(keys)
        self.save_best(epoch, iteration, step, None, 'avg', avg_metric, eval_stage)

    self.logger.info(f'===> {eval_stage} Done!')
    return self.best_metrics_all_time  # return all types of mean metrics for determining the best ckpt
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
@torch.no_grad()
def inference(self, data_dict):
    """Forward pass in inference mode with gradient tracking disabled."""
    return self.model(data_dict, inference=True)
|