caliangandrew committed on
Commit
8f57ce7
·
verified ·
1 Parent(s): 2a84f8d

Upload 42 files

Browse files
Files changed (39) hide show
  1. config/__init__.py +7 -0
  2. config/__pycache__/__init__.cpython-310.pyc +0 -0
  3. config/__pycache__/constants.cpython-310.pyc +0 -0
  4. config/constants.py +15 -0
  5. config/pretrained_config.yaml +94 -0
  6. config/pretrained_face_config.yaml +94 -0
  7. config/train_config.yaml +9 -0
  8. config/ucf.yaml +73 -0
  9. config/xception.yaml +86 -0
  10. detectors/__init__.py +11 -0
  11. detectors/__pycache__/__init__.cpython-310.pyc +0 -0
  12. detectors/__pycache__/base_detector.cpython-310.pyc +0 -0
  13. detectors/__pycache__/ucf_detector.cpython-310.pyc +0 -0
  14. detectors/base_detector.py +71 -0
  15. detectors/ucf_detector.py +472 -0
  16. loss/__init__.py +13 -0
  17. loss/__pycache__/__init__.cpython-310.pyc +0 -0
  18. loss/__pycache__/abstract_loss_func.cpython-310.pyc +0 -0
  19. loss/__pycache__/contrastive_regularization.cpython-310.pyc +0 -0
  20. loss/__pycache__/cross_entropy_loss.cpython-310.pyc +0 -0
  21. loss/__pycache__/l1_loss.cpython-310.pyc +0 -0
  22. loss/abstract_loss_func.py +17 -0
  23. loss/contrastive_regularization.py +78 -0
  24. loss/cross_entropy_loss.py +26 -0
  25. loss/l1_loss.py +19 -0
  26. metrics/__init__.py +7 -0
  27. metrics/__pycache__/__init__.cpython-310.pyc +0 -0
  28. metrics/__pycache__/base_metrics_class.cpython-310.pyc +0 -0
  29. metrics/__pycache__/registry.cpython-310.pyc +0 -0
  30. metrics/base_metrics_class.py +205 -0
  31. metrics/registry.py +20 -0
  32. metrics/utils.py +88 -0
  33. networks/__init__.py +11 -0
  34. networks/__pycache__/__init__.cpython-310.pyc +0 -0
  35. networks/__pycache__/xception.cpython-310.pyc +0 -0
  36. networks/xception.py +285 -0
  37. optimizor/LinearLR.py +20 -0
  38. optimizor/SAM.py +77 -0
  39. trainer/trainer.py +441 -0
config/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
"""Package bootstrap for ``config``.

Makes the directory that contains this package, and the directory one level
above it, importable by appending both to ``sys.path``.
"""
import os
import sys

current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
project_root_dir = os.path.dirname(parent_dir)

# The same bootstrap is executed by other packages of this project; only add
# entries that are missing so sys.path does not accumulate duplicates.
for _path in (parent_dir, project_root_dir):
    if _path not in sys.path:
        sys.path.append(_path)
config/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (350 Bytes). View file
 
config/__pycache__/constants.cpython-310.pyc ADDED
Binary file (543 Bytes). View file
 
config/constants.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Filesystem locations and artifact names used by the UCF detector."""
import os

# Directory holding this constants module.
CONFIGS_DIR = os.path.dirname(os.path.abspath(__file__))

# One level above the config directory: the base directory of the UCF files.
UCF_BASE_PATH = os.path.abspath(os.path.join(CONFIGS_DIR, os.pardir))

# Absolute paths derived from the two directories above.
CONFIG_PATH = os.path.join(CONFIGS_DIR, "ucf.yaml")    # detector YAML config
WEIGHTS_DIR = os.path.join(UCF_BASE_PATH, "weights/")  # pretrained weights directory

# Names of the remote repository and of the backbone checkpoint file.
HF_REPO = "bitmind/ucf"
BACKBONE_CKPT = "xception_best.pth"

# Landmark predictor file, addressed relative to the UCF base directory.
_DLIB_PREDICTOR_RELATIVE = "../../utils/dlib_tools/shape_predictor_81_face_landmarks.dat"
DLIB_FACE_PREDICTOR_PATH = os.path.abspath(
    os.path.join(UCF_BASE_PATH, _DLIB_PREDICTOR_RELATIVE)
)
config/pretrained_config.yaml ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ SWA: false
2
+ backbone_config:
3
+ dropout: false
4
+ inc: 3
5
+ mode: adjust_channel
6
+ num_classes: 2
7
+ backbone_name: xception
8
+ compression: c23
9
+ cuda: true
10
+ cudnn: true
11
+ dataset_json_folder: preprocessing/dataset_json_v3
12
+ dataset_meta:
13
+ fake:
14
+ - create_splits: false
15
+ path: bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
16
+ - create_splits: false
17
+ path: bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
18
+ real:
19
+ - create_splits: false
20
+ path: bitmind/celeb-a-hq_training_faces
21
+ - create_splits: false
22
+ path: bitmind/ffhq-256_training_faces
23
+ ddp: false
24
+ dry_run: false
25
+ encoder_feat_dim: 512
26
+ faces_only: true
27
+ frame_num:
28
+ test: 32
29
+ train: 32
30
+ lmdb: true
31
+ lmdb_dir: ./datasets/lmdb
32
+ local_rank: 0
33
+ log_dir: ./logs/training/ucf_2024-09-17-16-44-50
34
+ logdir: ./logs
35
+ loss_func:
36
+ cls_loss: cross_entropy
37
+ con_loss: contrastive_regularization
38
+ rec_loss: l1loss
39
+ spe_loss: cross_entropy
40
+ losstype: null
41
+ lr_scheduler: null
42
+ manualSeed: 1024
43
+ mean:
44
+ - 0.5
45
+ - 0.5
46
+ - 0.5
47
+ metric_scoring: auc
48
+ mode: train
49
+ model_name: ucf
50
+ nEpochs: 2
51
+ optimizer:
52
+ adam:
53
+ amsgrad: false
54
+ beta1: 0.9
55
+ beta2: 0.999
56
+ eps: 1.0e-08
57
+ lr: 0.0002
58
+ weight_decay: 0.0005
59
+ sgd:
60
+ lr: 0.0002
61
+ momentum: 0.9
62
+ weight_decay: 0.0005
63
+ type: adam
64
+ pretrained: ../weights/xception_best.pth
65
+ rec_iter: 100
66
+ resolution: 256
67
+ rgb_dir: ./datasets/rgb
68
+ save_avg: true
69
+ save_ckpt: true
70
+ save_epoch: 1
71
+ save_feat: true
72
+ specific_task_number: 2
73
+ split_transforms:
74
+ test:
75
+ name: base_transforms
76
+ train:
77
+ name: random_aug_transforms
78
+ validation:
79
+ name: base_transforms
80
+ start_epoch: 0
81
+ std:
82
+ - 0.5
83
+ - 0.5
84
+ - 0.5
85
+ test_batchSize: 32
86
+ train_batchSize: 32
87
+ train_dataset:
88
+ - bitmind/celeb-a-hq_training_faces
89
+ - bitmind/ffhq-256_training_faces
90
+ - bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
91
+ - bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
92
+ with_landmark: false
93
+ with_mask: false
94
+ workers: 7
config/pretrained_face_config.yaml ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ SWA: false
2
+ backbone_config:
3
+ dropout: false
4
+ inc: 3
5
+ mode: adjust_channel
6
+ num_classes: 2
7
+ backbone_name: xception
8
+ compression: c23
9
+ cuda: true
10
+ cudnn: true
11
+ dataset_json_folder: preprocessing/dataset_json_v3
12
+ dataset_meta:
13
+ fake:
14
+ - create_splits: false
15
+ path: bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
16
+ - create_splits: false
17
+ path: bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
18
+ real:
19
+ - create_splits: false
20
+ path: bitmind/celeb-a-hq_training_faces
21
+ - create_splits: false
22
+ path: bitmind/ffhq-256_training_faces
23
+ ddp: false
24
+ dry_run: false
25
+ encoder_feat_dim: 512
26
+ faces_only: true
27
+ frame_num:
28
+ test: 32
29
+ train: 32
30
+ lmdb: true
31
+ lmdb_dir: ./datasets/lmdb
32
+ local_rank: 0
33
+ log_dir: ./logs/training/ucf_2024-09-17-16-44-50
34
+ logdir: ./logs
35
+ loss_func:
36
+ cls_loss: cross_entropy
37
+ con_loss: contrastive_regularization
38
+ rec_loss: l1loss
39
+ spe_loss: cross_entropy
40
+ losstype: null
41
+ lr_scheduler: null
42
+ manualSeed: 1024
43
+ mean:
44
+ - 0.5
45
+ - 0.5
46
+ - 0.5
47
+ metric_scoring: auc
48
+ mode: train
49
+ model_name: ucf
50
+ nEpochs: 2
51
+ optimizer:
52
+ adam:
53
+ amsgrad: false
54
+ beta1: 0.9
55
+ beta2: 0.999
56
+ eps: 1.0e-08
57
+ lr: 0.0002
58
+ weight_decay: 0.0005
59
+ sgd:
60
+ lr: 0.0002
61
+ momentum: 0.9
62
+ weight_decay: 0.0005
63
+ type: adam
64
+ pretrained: ../weights/xception_best.pth
65
+ rec_iter: 100
66
+ resolution: 256
67
+ rgb_dir: ./datasets/rgb
68
+ save_avg: true
69
+ save_ckpt: true
70
+ save_epoch: 1
71
+ save_feat: true
72
+ specific_task_number: 2
73
+ split_transforms:
74
+ test:
75
+ name: base_transforms
76
+ train:
77
+ name: random_aug_transforms
78
+ validation:
79
+ name: base_transforms
80
+ start_epoch: 0
81
+ std:
82
+ - 0.5
83
+ - 0.5
84
+ - 0.5
85
+ test_batchSize: 32
86
+ train_batchSize: 32
87
+ train_dataset:
88
+ - bitmind/celeb-a-hq_training_faces
89
+ - bitmind/ffhq-256_training_faces
90
+ - bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces
91
+ - bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces
92
+ with_landmark: false
93
+ with_mask: false
94
+ workers: 7
config/train_config.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ mode: train
2
+ lmdb: True
3
+ dry_run: false
4
+ rgb_dir: './datasets/rgb'
5
+ lmdb_dir: './datasets/lmdb'
6
+ dataset_json_folder: './preprocessing/dataset_json'
7
+ SWA: False
8
+ save_avg: True
9
+ log_dir: ./logs/training/
config/ucf.yaml ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # log dir
2
+ log_dir: ../debug_logs/ucf
3
+
4
+ # model setting
5
+ pretrained: ../weights/xception_best.pth # path to a pre-trained model, if using one
6
+ model_name: ucf # model name
7
+ backbone_name: xception # backbone name
8
+ encoder_feat_dim: 512 # feature dimension of the backbone
9
+
10
+ #backbone setting
11
+ backbone_config:
12
+ mode: adjust_channel
13
+ num_classes: 2
14
+ inc: 3
15
+ dropout: false
16
+
17
+ compression: c23 # compression-level for videos
18
+ train_batchSize: 32 # training batch size
19
+ test_batchSize: 32 # test batch size
20
+ workers: 8 # number of data loading workers
21
+ frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing
22
+ resolution: 256 # resolution of output image to network
23
+ with_mask: false # whether to include mask information in the input
24
+ with_landmark: false # whether to include facial landmark information in the input
25
+ save_ckpt: true # whether to save checkpoint
26
+ save_feat: true # whether to save features
27
+ specific_task_number: 2 # default num datasets in FF++ used by DFB, overwritten in training
28
+
29
+ # mean and std for normalization
30
+ mean: [0.5, 0.5, 0.5]
31
+ std: [0.5, 0.5, 0.5]
32
+
33
+ # optimizer config
34
+ optimizer:
35
+ # choose between 'adam' and 'sgd'
36
+ type: adam
37
+ adam:
38
+ lr: 0.0002 # learning rate
39
+ beta1: 0.9 # beta1 for Adam optimizer
40
+ beta2: 0.999 # beta2 for Adam optimizer
41
+ eps: 0.00000001 # epsilon for Adam optimizer
42
+ weight_decay: 0.0005 # weight decay for regularization
43
+ amsgrad: false
44
+ sgd:
45
+ lr: 0.0002 # learning rate
46
+ momentum: 0.9 # momentum for SGD optimizer
47
+ weight_decay: 0.0005 # weight decay for regularization
48
+
49
+ # training config
50
+ lr_scheduler: null # learning rate scheduler
51
+ nEpochs: 20 # number of epochs to train for
52
+ start_epoch: 0 # manual epoch number (useful for restarts)
53
+ save_epoch: 1 # interval epochs for saving models
54
+ rec_iter: 100 # interval iterations for recording
55
+ logdir: ./logs # folder to output images and logs
56
+ manualSeed: 1024 # manual seed for random number generation
57
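+ save_ckpt: false # whether to save checkpoint (NOTE: duplicate key; "save_ckpt: true" is also set earlier in this file, and loaders that accept duplicate keys typically keep this later value)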
58
+
59
+ # loss function
60
+ loss_func:
61
+ cls_loss: cross_entropy # loss function to use
62
+ spe_loss: cross_entropy
63
+ con_loss: contrastive_regularization
64
+ rec_loss: l1loss
65
+ losstype: null
66
+
67
+ # metric
68
+ metric_scoring: auc # metric for evaluation (auc, acc, eer, ap)
69
+
70
+ # cuda
71
+
72
+ cuda: true # whether to use CUDA acceleration
73
+ cudnn: true # whether to use CuDNN for convolution operations
config/xception.yaml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # log dir
2
+ log_dir: /data/home/zhiyuanyan/DeepfakeBench/logs/testing_bench
3
+
4
+ # model setting
5
+ pretrained: /data/home/zhiyuanyan/DeepfakeBench/training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one
6
+ model_name: xception # model name
7
+ backbone_name: xception # backbone name
8
+
9
+ #backbone setting
10
+ backbone_config:
11
+ mode: original
12
+ num_classes: 2
13
+ inc: 3
14
+ dropout: false
15
+
16
+ # dataset
17
+ all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV]
18
+ train_dataset: [FaceForensics++]
19
+ test_dataset: [FaceForensics++, DeepFakeDetection]
20
+
21
+ compression: c23 # compression-level for videos
22
+ train_batchSize: 32 # training batch size
23
+ test_batchSize: 32 # test batch size
24
+ workers: 8 # number of data loading workers
25
+ frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing
26
+ resolution: 256 # resolution of output image to network
27
+ with_mask: false # whether to include mask information in the input
28
+ with_landmark: false # whether to include facial landmark information in the input
29
+
30
+
31
+ # data augmentation
32
+ use_data_augmentation: true # Add this flag to enable/disable data augmentation
33
+ data_aug:
34
+ flip_prob: 0.5
35
+ rotate_prob: 0.0
36
+ rotate_limit: [-10, 10]
37
+ blur_prob: 0.5
38
+ blur_limit: [3, 7]
39
+ brightness_prob: 0.5
40
+ brightness_limit: [-0.1, 0.1]
41
+ contrast_limit: [-0.1, 0.1]
42
+ quality_lower: 40
43
+ quality_upper: 100
44
+
45
+ # mean and std for normalization
46
+ mean: [0.5, 0.5, 0.5]
47
+ std: [0.5, 0.5, 0.5]
48
+
49
+ # optimizer config
50
+ optimizer:
51
+ # choose between 'adam' and 'sgd'
52
+ type: adam
53
+ adam:
54
+ lr: 0.0002 # learning rate
55
+ beta1: 0.9 # beta1 for Adam optimizer
56
+ beta2: 0.999 # beta2 for Adam optimizer
57
+ eps: 0.00000001 # epsilon for Adam optimizer
58
+ weight_decay: 0.0005 # weight decay for regularization
59
+ amsgrad: false
60
+ sgd:
61
+ lr: 0.0002 # learning rate
62
+ momentum: 0.9 # momentum for SGD optimizer
63
+ weight_decay: 0.0005 # weight decay for regularization
64
+
65
+ # training config
66
+ lr_scheduler: null # learning rate scheduler
67
+ nEpochs: 10 # number of epochs to train for
68
+ start_epoch: 0 # manual epoch number (useful for restarts)
69
+ save_epoch: 1 # interval epochs for saving models
70
+ rec_iter: 100 # interval iterations for recording
71
+ logdir: ./logs # folder to output images and logs
72
+ manualSeed: 1024 # manual seed for random number generation
73
+ save_ckpt: true # whether to save checkpoint
74
+ save_feat: true # whether to save features
75
+
76
+ # loss function
77
+ loss_func: cross_entropy # loss function to use
78
+ losstype: null
79
+
80
+ # metric
81
+ metric_scoring: auc # metric for evaluation (auc, acc, eer, ap)
82
+
83
+ # cuda
84
+
85
+ cuda: true # whether to use CUDA acceleration
86
+ cudnn: true # whether to use CuDNN for convolution operations
detectors/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Package bootstrap for ``detectors``.

Extends ``sys.path`` so that sibling packages can be imported by absolute
name, then exposes the detector registry and the UCF detector class.
"""
import os
import sys

current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
project_root_dir = os.path.dirname(parent_dir)

# The same bootstrap is executed by other packages of this project; only add
# entries that are missing so sys.path does not accumulate duplicates.
for _path in (parent_dir, project_root_dir):
    if _path not in sys.path:
        sys.path.append(_path)

# These imports must stay below the sys.path setup: ``metrics`` is imported by
# absolute name and is resolved through the directories added above.
from metrics.registry import DETECTOR

from .ucf_detector import UCFDetector
detectors/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (455 Bytes). View file
 
detectors/__pycache__/base_detector.cpython-310.pyc ADDED
Binary file (2.57 kB). View file
 
detectors/__pycache__/ucf_detector.cpython-310.pyc ADDED
Binary file (12.9 kB). View file
 
detectors/base_detector.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # author: Zhiyuan Yan
2
+ # email: zhiyuanyan@link.cuhk.edu.cn
3
+ # date: 2023-0706
4
+ # description: Abstract Class for the Deepfake Detector
5
+
6
+ import abc
7
+ import torch
8
+ import torch.nn as nn
9
+ from typing import Union
10
+
11
class AbstractDetector(nn.Module, metaclass=abc.ABCMeta):
    """Abstract base class that every deepfake detector should subclass.

    The class is an ``nn.Module`` whose metaclass is ``abc.ABCMeta``; a
    subclass can only be instantiated once it overrides every method marked
    ``@abc.abstractmethod`` below.
    """

    def __init__(self, config=None, load_param: Union[bool, str] = False):
        """Initialise the underlying ``nn.Module``.

        Args:
            config: (dict) configurations for the model. Not used by this
                base implementation.
            load_param: ``False`` - do not read; ``True`` - read the default
                path; a path string - read the required path. Not used by
                this base implementation.
        """
        super().__init__()

    @abc.abstractmethod
    def features(self, data_dict: dict) -> torch.Tensor:
        """Return the features from the backbone given the input data."""
        pass

    @abc.abstractmethod
    def forward(self, data_dict: dict, inference=False) -> dict:
        """Run the forward pass and return the prediction dictionary."""
        pass

    @abc.abstractmethod
    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """Classify the features into classes."""
        pass

    @abc.abstractmethod
    def build_backbone(self, config):
        """Build the backbone of the model."""
        pass

    @abc.abstractmethod
    def build_loss(self, config):
        """Build the loss function for the model."""
        pass

    @abc.abstractmethod
    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Return the losses for the model."""
        pass

    @abc.abstractmethod
    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """Return the training metrics for the model."""
        pass
detectors/ucf_detector.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ # Source: https://github.com/SCLBD/DeepfakeBench/blob/main/training/detectors/ucf_detector.py
3
+ # author: Zhiyuan Yan
4
+ # email: zhiyuanyan@link.cuhk.edu.cn
5
+ # date: 2023-0706
6
+ # description: Class for the UCFDetector
7
+
8
+ Functions in the Class are summarized as:
9
+ 1. __init__: Initialization
10
+ 2. build_backbone: Backbone-building
11
+ 3. build_loss: Loss-function-building
12
+ 4. features: Feature-extraction
13
+ 5. classifier: Classification
14
+ 6. get_losses: Loss-computation
15
+ 7. get_train_metrics: Training-metrics-computation
16
+ 8. get_test_metrics: Testing-metrics-computation (not defined in the UCFDetector class in this file)
17
+ 9. forward: Forward-propagation
18
+
19
+ Reference:
20
+ @article{yan2023ucf,
21
+ title={UCF: Uncovering Common Features for Generalizable Deepfake Detection},
22
+ author={Yan, Zhiyuan and Zhang, Yong and Fan, Yanbo and Wu, Baoyuan},
23
+ journal={arXiv preprint arXiv:2304.13949},
24
+ year={2023}
25
+ }
26
+ '''
27
+
28
+ import os
29
+ import datetime
30
+ import logging
31
+ import random
32
+ import numpy as np
33
+ from sklearn import metrics
34
+ from typing import Union
35
+ from collections import defaultdict
36
+
37
+ import torch
38
+ import torch.nn as nn
39
+ import torch.nn.functional as F
40
+ import torch.optim as optim
41
+ from torch.nn import DataParallel
42
+ from torch.utils.tensorboard import SummaryWriter
43
+
44
+ from metrics.base_metrics_class import calculate_metrics_for_train
45
+
46
+ from .base_detector import AbstractDetector
47
+ from arena.detectors.UCF.detectors import DETECTOR
48
+ from networks import BACKBONE
49
+ from loss import LOSSFUNC
50
+
51
+ logger = logging.getLogger(__name__)
52
+
53
@DETECTOR.register_module(module_name='ucf')
class UCFDetector(AbstractDetector):
    """UCF deepfake detector: two backbones, two heads and a conditional decoder.

    ``encoder_f`` yields the "forgery" features and ``encoder_c`` the
    "content" features of the same input. The forgery features are split by
    two 1x1-convolution blocks into a "specific" part and a "shared" (common)
    part, each classified by its own head. Outside inference mode the decoder
    ``con_gan`` also reconstructs images from forgery/content feature pairs.

    The training code paths split the batch in two along dim 0 and treat the
    first half as real and the second half as fake (see ``get_train_losses``
    and ``forward``).
    """

    def __init__(self, config):
        """Build the encoders, losses, decoder and heads from ``config``.

        Args:
            config: mapping that provides ``backbone_config`` (with
                ``num_classes``), ``encoder_feat_dim``,
                ``specific_task_number`` and the keys read by
                ``build_backbone`` and ``build_loss``.
        """
        super().__init__()
        self.config = config
        self.num_classes = config['backbone_config']['num_classes']
        self.encoder_feat_dim = config['encoder_feat_dim']
        self.half_fingerprint_dim = self.encoder_feat_dim//2

        self.encoder_f = self.build_backbone(config)
        self.encoder_c = self.build_backbone(config)

        self.loss_func = self.build_loss(config)
        # Accumulators that are filled by forward(..., inference=True).
        self.prob, self.label = [], []
        self.correct, self.total = 0, 0

        # basic function
        self.lr = nn.LeakyReLU(inplace=True)
        self.do = nn.Dropout(0.2)
        self.pool = nn.AdaptiveAvgPool2d(1)

        # conditional gan
        self.con_gan = Conditional_UNet()

        # head
        specific_task_number = config['specific_task_number']

        self.head_spe = Head(
            in_f=self.half_fingerprint_dim,
            hidden_dim=self.encoder_feat_dim,
            out_f=specific_task_number
        )
        self.head_sha = Head(
            in_f=self.half_fingerprint_dim,
            hidden_dim=self.encoder_feat_dim,
            out_f=self.num_classes
        )
        self.block_spe = Conv2d1x1(
            in_f=self.encoder_feat_dim,
            hidden_dim=self.half_fingerprint_dim,
            out_f=self.half_fingerprint_dim
        )
        self.block_sha = Conv2d1x1(
            in_f=self.encoder_feat_dim,
            hidden_dim=self.half_fingerprint_dim,
            out_f=self.half_fingerprint_dim
        )

    def build_backbone(self, config):
        """Instantiate the backbone and load its pretrained weights.

        ``config['pretrained']`` is joined onto the directory of this source
        file. Weights whose name contains ``'pointwise'`` get two trailing
        singleton dimensions, and entries whose name contains ``'fc'`` are
        dropped before a non-strict ``load_state_dict``.

        Returns:
            The backbone module built from ``config['backbone_config']``.
        """
        current_dir = os.path.dirname(os.path.abspath(__file__))
        pretrained_path = os.path.join(current_dir, config['pretrained'])
        # prepare the backbone
        backbone_class = BACKBONE[config['backbone_name']]
        model_config = config['backbone_config']
        backbone = backbone_class(model_config)
        # Without the pretrained weights the model fails to reach good results.
        # map_location='cpu' lets a checkpoint saved on a GPU be read on a
        # CPU-only host; load_state_dict copies the values into the backbone's
        # own parameters afterwards.
        state_dict = torch.load(pretrained_path, map_location='cpu')
        state_dict = {
            name: (weights.unsqueeze(-1).unsqueeze(-1) if 'pointwise' in name else weights)
            for name, weights in state_dict.items()
            if 'fc' not in name
        }
        backbone.load_state_dict(state_dict, False)
        logger.info('Load pretrained model successfully!')
        return backbone

    def build_loss(self, config):
        """Instantiate the four loss functions named in ``config['loss_func']``.

        Returns:
            dict with keys ``'cls'``, ``'spe'``, ``'con'`` and ``'rec'``; the
            contrastive loss is created with ``margin=3.0``.
        """
        cls_loss_class = LOSSFUNC[config['loss_func']['cls_loss']]
        spe_loss_class = LOSSFUNC[config['loss_func']['spe_loss']]
        con_loss_class = LOSSFUNC[config['loss_func']['con_loss']]
        rec_loss_class = LOSSFUNC[config['loss_func']['rec_loss']]
        cls_loss_func = cls_loss_class()
        spe_loss_func = spe_loss_class()
        con_loss_func = con_loss_class(margin=3.0)
        rec_loss_func = rec_loss_class()
        loss_func = {
            'cls': cls_loss_func,
            'spe': spe_loss_func,
            'con': con_loss_func,
            'rec': rec_loss_func,
        }
        return loss_func

    def features(self, data_dict: dict) -> dict:
        """Run both encoders on ``data_dict['image']``.

        Returns:
            dict with the ``'forgery'`` features (from ``encoder_f``) and the
            ``'content'`` features (from ``encoder_c``).
        """
        cat_data = data_dict['image']
        # encoder
        f_all = self.encoder_f.features(cat_data)
        c_all = self.encoder_c.features(cat_data)
        feat_dict = {'forgery': f_all, 'content': c_all}
        return feat_dict

    def classifier(self, features: torch.Tensor) -> tuple:
        """Split forgery features into their specific and shared parts.

        Returns:
            ``(f_spe, f_share)``: outputs of ``block_spe`` and ``block_sha``.
        """
        # classification, multi-task
        # split the features into the specific and common forgery
        f_spe = self.block_spe(features)
        f_share = self.block_sha(features)
        return f_spe, f_share

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Dispatch to the training or the test loss computation.

        The training losses are used when ``data_dict`` carries
        ``'label_spe'`` and ``pred_dict`` carries ``'recontruction_imgs'``;
        otherwise only the test loss is computed.
        """
        if 'label_spe' in data_dict and 'recontruction_imgs' in pred_dict:
            return self.get_train_losses(data_dict, pred_dict)
        else:  # test mode
            return self.get_test_losses(data_dict, pred_dict)

    def get_train_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute the weighted training objective and its components.

        Returns:
            dict with ``'overall'``, ``'common'``, ``'specific'``,
            ``'reconstruction'`` and ``'contrastive'`` losses.
        """
        # get combined, real, fake imgs
        cat_data = data_dict['image']
        real_img, fake_img = cat_data.chunk(2, dim=0)
        # get the reconstruction imgs
        reconstruction_image_1, \
        reconstruction_image_2, \
        self_reconstruction_image_1, \
        self_reconstruction_image_2 \
            = pred_dict['recontruction_imgs']
        # get label
        label = data_dict['label']
        label_spe = data_dict['label_spe']
        # get pred
        pred = pred_dict['cls']
        pred_spe = pred_dict['cls_spe']

        # 1. classification loss for common features
        loss_sha = self.loss_func['cls'](pred, label)

        # 2. classification loss for specific features
        loss_spe = self.loss_func['spe'](pred_spe, label_spe)

        # 3. reconstruction loss
        self_loss_reconstruction_1 = self.loss_func['rec'](fake_img, self_reconstruction_image_1)
        self_loss_reconstruction_2 = self.loss_func['rec'](real_img, self_reconstruction_image_2)
        cross_loss_reconstruction_1 = self.loss_func['rec'](fake_img, reconstruction_image_2)
        cross_loss_reconstruction_2 = self.loss_func['rec'](real_img, reconstruction_image_1)
        loss_reconstruction = \
            self_loss_reconstruction_1 + self_loss_reconstruction_2 + \
            cross_loss_reconstruction_1 + cross_loss_reconstruction_2

        # 4. contrastive loss
        common_features = pred_dict['feat']
        specific_features = pred_dict['feat_spe']
        loss_con = self.loss_func['con'](common_features, specific_features, label_spe)

        # 5. total loss
        loss = loss_sha + 0.1*loss_spe + 0.3*loss_reconstruction + 0.05*loss_con
        loss_dict = {
            'overall': loss,
            'common': loss_sha,
            'specific': loss_spe,
            'reconstruction': loss_reconstruction,
            'contrastive': loss_con,
        }
        return loss_dict

    def get_test_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute the classification loss on the shared-feature prediction.

        Returns:
            dict with the single key ``'common'``.
        """
        # get label
        label = data_dict['label']
        # get pred
        pred = pred_dict['cls']
        # for test mode, only classification loss for common features
        loss = self.loss_func['cls'](pred, label)
        loss_dict = {'common': loss}
        return loss_dict

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute batch-level training metrics.

        Returns:
            dict with ``'acc'``, ``'acc_spe'``, ``'auc'``, ``'eer'`` and
            ``'ap'``.
        """
        def get_accracy(label, output):
            _, prediction = torch.max(output, 1)    # argmax
            correct = (prediction == label).sum().item()
            accuracy = correct / prediction.size(0)
            return accuracy

        # get pred and label
        label = data_dict['label']
        pred = pred_dict['cls']
        label_spe = data_dict['label_spe']
        pred_spe = pred_dict['cls_spe']

        # compute metrics for batch data
        auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
        acc_spe = get_accracy(label_spe.detach(), pred_spe.detach())
        metric_batch_dict = {'acc': acc, 'acc_spe': acc_spe, 'auc': auc, 'eer': eer, 'ap': ap}
        # we dont compute the video-level metrics for training
        return metric_batch_dict

    def forward(self, data_dict: dict, inference=False) -> dict:
        """Run the detector.

        Args:
            data_dict: must contain ``'image'``; ``'label'`` is optional in
                inference mode.
            inference: when true, only the classification outputs are
                returned and the running accumulators on ``self`` are
                updated; otherwise the full prediction dict including the
                reconstructed images is returned.

        Returns:
            Inference: ``{'cls', 'feat'}``. Otherwise: ``'cls'``, ``'prob'``,
            ``'feat'``, ``'cls_spe'``, ``'prob_spe'``, ``'feat_spe'``,
            ``'feat_content'`` and ``'recontruction_imgs'``.
        """
        # split the features into the content and forgery
        features = self.features(data_dict)
        forgery_features, content_features = features['forgery'], features['content']
        # get the prediction by classifier (split the common and specific forgery)
        f_spe, f_share = self.classifier(forgery_features)

        if inference:
            # inference only consider share loss
            out_sha, sha_feat = self.head_sha(f_share)
            out_spe, spe_feat = self.head_spe(f_spe)
            prob_sha = torch.softmax(out_sha, dim=1)[:, 1]
            # NOTE(review): self.prob / self.label grow on every inference
            # call and nothing in this class clears them - confirm that the
            # caller resets them.
            self.prob.append(
                prob_sha
                .detach()
                .squeeze()
                .cpu()
                .numpy()
            )
            _, prediction_class = torch.max(out_sha, 1)
            if 'label' in data_dict:
                self.label.append(
                    data_dict['label']
                    .detach()
                    .squeeze()
                    .cpu()
                    .numpy()
                )
                # deal with acc: any label >= 1 counts as the positive class
                common_label = (data_dict['label'] >= 1)
                correct = (prediction_class == common_label).sum().item()
                self.correct += correct
                self.total += data_dict['label'].size(0)

            pred_dict = {'cls': out_sha, 'feat': sha_feat}
            return pred_dict

        bs = f_share.size(0)
        # using idx aug in the training mode: with probability 0.7 shuffle the
        # shared features within each half of the batch
        aug_idx = random.random()
        if aug_idx < 0.7:
            # real
            idx_list = list(range(0, bs//2))
            random.shuffle(idx_list)
            f_share[0: bs//2] = f_share[idx_list]
            # fake
            idx_list = list(range(bs//2, bs))
            random.shuffle(idx_list)
            f_share[bs//2: bs] = f_share[idx_list]

        # concat spe and share to obtain new_f_all
        f_all = torch.cat((f_spe, f_share), dim=1)

        # reconstruction loss
        f2, f1 = f_all.chunk(2, dim=0)
        c2, c1 = content_features.chunk(2, dim=0)

        # ==== self reconstruction ==== #
        # f1 + c1 -> f11, f11 + c1 -> near~I1
        self_reconstruction_image_1 = self.con_gan(f1, c1)

        # f2 + c2 -> f2, f2 + c2 -> near~I2
        self_reconstruction_image_2 = self.con_gan(f2, c2)

        # ==== cross combine ==== #
        reconstruction_image_1 = self.con_gan(f1, c2)
        reconstruction_image_2 = self.con_gan(f2, c1)

        # head for spe and sha
        out_spe, spe_feat = self.head_spe(f_spe)
        out_sha, sha_feat = self.head_sha(f_share)

        # get the probability of the pred
        prob_sha = torch.softmax(out_sha, dim=1)[:, 1]
        prob_spe = torch.softmax(out_spe, dim=1)[:, 1]

        # build the prediction dict for each output
        pred_dict = {
            'cls': out_sha,
            'prob': prob_sha,
            'feat': sha_feat,
            'cls_spe': out_spe,
            'prob_spe': prob_spe,
            'feat_spe': spe_feat,
            'feat_content': content_features,
            'recontruction_imgs': (
                reconstruction_image_1,
                reconstruction_image_2,
                self_reconstruction_image_1,
                self_reconstruction_image_2
            )
        }
        return pred_dict
327
+
328
def sn_double_conv(in_channels, out_channels):
    """Return a spectrally normalised two-convolution block.

    The first 3x3 convolution keeps ``in_channels`` channels; the second 3x3
    convolution maps to ``out_channels`` with stride 2. Both are wrapped in
    ``nn.utils.spectral_norm`` and followed by a single LeakyReLU(0.2).
    """
    keep_width = nn.utils.spectral_norm(
        nn.Conv2d(in_channels, in_channels, 3, padding=1))
    strided = nn.utils.spectral_norm(
        nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=2))
    activation = nn.LeakyReLU(0.2, inplace=True)
    return nn.Sequential(keep_width, strided, activation)
336
+
337
def r_double_conv(in_channels, out_channels):
    """Return two 3x3 (padding 1) convolutions, each followed by ReLU.

    The first convolution maps ``in_channels`` to ``out_channels``; the
    second keeps ``out_channels``.
    """
    layers = []
    width = in_channels
    for _ in range(2):
        layers.append(nn.Conv2d(width, out_channels, 3, padding=1))
        layers.append(nn.ReLU(inplace=True))
        width = out_channels
    return nn.Sequential(*layers)
344
+
345
class AdaIN(nn.Module):
    """Re-normalise ``x`` so its per-channel statistics match those of ``y``.

    For every (sample, channel) pair ``x`` is shifted and scaled to zero mean
    and unit standard deviation over its remaining dimensions, then scaled
    and shifted by the standard deviation and mean of ``y``.
    """

    def __init__(self, eps=1e-5):
        """Store ``eps``, the constant added to the variance before the sqrt."""
        super().__init__()
        self.eps = eps

    def c_norm(self, x, bs, ch, eps=1e-7):
        """Return ``(std, mean)`` of ``x`` over its last dimension.

        Both results are viewed as ``(bs, ch, 1, 1)``; ``eps`` is added to
        the variance before taking the square root.
        """
        x_var = x.var(dim=-1) + eps
        x_std = x_var.sqrt().view(bs, ch, 1, 1)
        x_mean = x.mean(dim=-1).view(bs, ch, 1, 1)
        return x_std, x_mean

    def forward(self, x, y):
        """Apply the statistics of ``y`` to ``x``; the result has ``x``'s size."""
        assert x.size(0)==y.size(0)
        size = x.size()
        bs, ch = size[:2]
        # reshape (not view) for both inputs, so a non-contiguous x is
        # accepted just like a non-contiguous y.
        x_ = x.reshape(bs, ch, -1)
        y_ = y.reshape(bs, ch, -1)
        x_std, x_mean = self.c_norm(x_, bs, ch, eps=self.eps)
        y_std, y_mean = self.c_norm(y_, bs, ch, eps=self.eps)
        out = ((x - x_mean.expand(size)) / x_std.expand(size)) \
            * y_std.expand(size) + y_mean.expand(size)
        return out
369
+
370
class Conditional_UNet(nn.Module):
    """Decoder that renders an image from a content map and a style map.

    Three AdaIN + upsample + dropout + double-conv stages take 512 channels
    down to 64; a 1x1 convolution then produces 3 channels, which are
    upsampled by 4 and passed through ``tanh``.
    """

    def init_weight(self, std=0.2):
        """Re-initialise weights of Conv*/Linear* sub-modules with normals.

        Modules whose class name contains ``'Conv'`` get weights drawn from
        N(0, std); those containing ``'Linear'`` get weights from N(1, std)
        and a zero bias. The constructor does not call this method.
        """
        for module in self.modules():
            kind = type(module).__name__
            if 'Conv' in kind:
                module.weight.data.normal_(0., std)
            elif 'Linear' in kind:
                module.weight.data.normal_(1., std)
                module.bias.data.fill_(0)

    def __init__(self):
        super().__init__()

        # Shared, parameter-free operators.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.maxpool = nn.MaxPool2d(2)
        self.dropout = nn.Dropout(p=0.3)

        # One AdaIN per decoder stage.
        self.adain3 = AdaIN()
        self.adain2 = AdaIN()
        self.adain1 = AdaIN()

        # Double-conv stages: 512 -> 256 -> 128 -> 64 channels.
        self.dconv_up3 = r_double_conv(512, 256)
        self.dconv_up2 = r_double_conv(256, 128)
        self.dconv_up1 = r_double_conv(128, 64)

        # Output projection.
        self.conv_last = nn.Conv2d(64, 3, 1)
        self.up_last = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.activation = nn.Tanh()

    def _stage(self, tensor, conv):
        """Upsample by 2, apply dropout, then run ``conv``."""
        return conv(self.dropout(self.upsample(tensor)))

    def forward(self, c, x):  # c is the style and x is the content
        """Decode content ``x`` under style ``c`` into a tanh-bounded image."""
        # Stage 3: the same conv stack processes content first, then style.
        x = self._stage(self.adain3(x, c), self.dconv_up3)
        c = self._stage(c, self.dconv_up3)

        # Stage 2.
        x = self._stage(self.adain2(x, c), self.dconv_up2)
        c = self._stage(c, self.dconv_up2)

        # Stage 1: only the content branch continues.
        x = self._stage(self.adain1(x, c), self.dconv_up1)

        projected = self.conv_last(x)
        return self.activation(self.up_last(projected))
428
+
429
class MLP(nn.Module):
    """Global-average-pool a feature map and pass it through a 3-layer MLP."""

    def __init__(self, in_f, hidden_dim, out_f):
        """Create the pooling layer and the Linear/LeakyReLU stack.

        Args:
            in_f: number of input features, i.e. channels of the pooled map.
            hidden_dim: width of the two hidden layers.
            out_f: number of output features.
        """
        super(MLP, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.mlp = nn.Sequential(nn.Linear(in_f, hidden_dim),
                                 nn.LeakyReLU(inplace=True),
                                 nn.Linear(hidden_dim, hidden_dim),
                                 nn.LeakyReLU(inplace=True),
                                 nn.Linear(hidden_dim, out_f),)

    def forward(self, x):
        """Return the MLP output for the spatially averaged input."""
        x = self.pool(x)
        # The pooled map keeps two trailing singleton dims; flatten them so
        # the channel dimension is what nn.Linear(in_f, ...) consumes.
        x = x.flatten(1)
        x = self.mlp(x)
        return x
443
+
444
class Conv2d1x1(nn.Module):
    """Three stacked 1x1 convolutions with LeakyReLU between them."""

    def __init__(self, in_f, hidden_dim, out_f):
        """Build the stack ``in_f -> hidden_dim -> hidden_dim -> out_f``."""
        super().__init__()
        first = nn.Conv2d(in_f, hidden_dim, 1, 1)
        middle = nn.Conv2d(hidden_dim, hidden_dim, 1, 1)
        last = nn.Conv2d(hidden_dim, out_f, 1, 1)
        self.conv2d = nn.Sequential(
            first,
            nn.LeakyReLU(inplace=True),
            middle,
            nn.LeakyReLU(inplace=True),
            last,
        )

    def forward(self, x):
        """Apply the convolution stack to ``x``."""
        return self.conv2d(x)
456
+
457
class Head(nn.Module):
    """Classification head: global average pool, 2-layer MLP, then dropout."""

    def __init__(self, in_f, hidden_dim, out_f):
        """Create the dropout (p=0.2), pooling and Linear/LeakyReLU layers."""
        super().__init__()
        self.do = nn.Dropout(0.2)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.mlp = nn.Sequential(
            nn.Linear(in_f, hidden_dim),
            nn.LeakyReLU(inplace=True),
            nn.Linear(hidden_dim, out_f),
        )

    def forward(self, x):
        """Return ``(output, pooled_features)`` for the feature map ``x``.

        ``pooled_features`` is the spatially averaged input flattened per
        sample; ``output`` is the MLP result after dropout.
        """
        pooled = self.pool(x)
        x_feat = pooled.view(pooled.size(0), -1)
        scores = self.do(self.mlp(x_feat))
        return scores, x_feat
472
+
loss/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ current_file_path = os.path.abspath(__file__)
4
+ parent_dir = os.path.dirname(os.path.dirname(current_file_path))
5
+ project_root_dir = os.path.dirname(parent_dir)
6
+ sys.path.append(parent_dir)
7
+ sys.path.append(project_root_dir)
8
+
9
+ from metrics.registry import LOSSFUNC
10
+
11
+ from .cross_entropy_loss import CrossEntropyLoss
12
+ from .contrastive_regularization import ContrastiveLoss
13
+ from .l1_loss import L1Loss
loss/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (565 Bytes). View file
 
loss/__pycache__/abstract_loss_func.cpython-310.pyc ADDED
Binary file (977 Bytes). View file
 
loss/__pycache__/contrastive_regularization.cpython-310.pyc ADDED
Binary file (2.38 kB). View file
 
loss/__pycache__/cross_entropy_loss.cpython-310.pyc ADDED
Binary file (1.26 kB). View file
 
loss/__pycache__/l1_loss.cpython-310.pyc ADDED
Binary file (892 Bytes). View file
 
loss/abstract_loss_func.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
class AbstractLossClass(nn.Module):
    """Common base class for the loss modules in this package.

    Subclasses are expected to override :meth:`forward`.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, label):
        """Compute the loss (must be overridden).

        Args:
            pred: prediction of the model.
            label: ground truth label.

        Returns:
            The loss value.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('Each subclass should implement the forward method.')
loss/contrastive_regularization.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from collections import defaultdict
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from .abstract_loss_func import AbstractLossClass
7
+ from metrics.registry import LOSSFUNC
8
+
9
+
10
def swap_spe_features(type_list, value_list):
    """Randomly permute ``value_list`` among entries that share a type.

    Args:
        type_list: tensor holding one type/label per sample.
        value_list: indexable (e.g. tensor) with one entry per sample.

    Returns:
        ``value_list`` re-indexed so that every position receives the value
        of a randomly chosen sample of the same type (each sample is used
        exactly once). Uses the global ``random`` module state.
    """
    labels = type_list.cpu().numpy().tolist()

    # Bucket the sample positions by label (buckets keep first-seen order).
    positions_by_label = defaultdict(list)
    for position, label in enumerate(labels):
        positions_by_label[label].append(position)

    # Shuffle inside every bucket so the swap partner is random.
    for bucket in positions_by_label.values():
        random.shuffle(bucket)

    # For each sample, draw (and consume) one position from its own bucket.
    permutation = [positions_by_label[label].pop() for label in labels]

    return value_list[permutation]
36
+
37
+
38
@LOSSFUNC.register_module(module_name="contrastive_regularization")
class ContrastiveLoss(AbstractLossClass):
    """Margin-based contrastive regularization over common/specific features.

    Args:
        margin: margin added inside the hinge of every triplet term.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def contrastive_loss(self, anchor, positive, negative):
        """Return mean(max(d(anchor, positive) - d(anchor, negative) + margin, 0))."""
        pos_dist = F.pairwise_distance(anchor, positive)
        neg_dist = F.pairwise_distance(anchor, negative)
        hinge = torch.clamp(pos_dist - neg_dist + self.margin, min=0.0)
        return torch.mean(hinge)

    def forward(self, common, specific, spe_label):
        """Sum the four triplet terms (real/fake x common/specific).

        The batch is split in two halves with ``chunk(2)``; going by the
        variable naming, the first half holds real samples and the second
        half fake ones (NOTE(review): confirm against the data loader).

        Args:
            common: common features, one row per sample.
            specific: specific features, one row per sample.
            spe_label: per-sample label used to pick same-type swap partners
                for the specific features.
        """
        batch_size = common.shape[0]
        half = batch_size // 2
        real_common, fake_common = common.chunk(2)

        # Shuffled rows from the first half act as the "real" reference set.
        real_order = list(range(0, half))
        random.shuffle(real_order)
        real_common_anchor = common[real_order]

        # Shuffled rows from the second half act as the "fake" reference set.
        fake_order = list(range(half, batch_size))
        random.shuffle(fake_order)
        fake_common_anchor = common[fake_order]

        # Specific features are swapped only between samples of equal label.
        specific_anchor = swap_spe_features(spe_label, specific)
        real_specific_anchor, fake_specific_anchor = specific_anchor.chunk(2)
        real_specific, fake_specific = specific.chunk(2)

        # Contrastive terms on the common features.
        loss_realcommon = self.contrastive_loss(
            real_common, real_common_anchor, fake_common_anchor)
        loss_fakecommon = self.contrastive_loss(
            fake_common, fake_common_anchor, real_common_anchor)

        # Contrastive terms on the specific features.
        loss_realspecific = self.contrastive_loss(
            real_specific, real_specific_anchor, fake_specific_anchor)
        loss_fakespecific = self.contrastive_loss(
            fake_specific, fake_specific_anchor, real_specific_anchor)

        # Summation order is kept as in the original implementation.
        return loss_realcommon + loss_fakecommon + loss_fakespecific + loss_realspecific
loss/cross_entropy_loss.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from .abstract_loss_func import AbstractLossClass
3
+ from metrics.registry import LOSSFUNC
4
+
5
+
6
@LOSSFUNC.register_module(module_name="cross_entropy")
class CrossEntropyLoss(AbstractLossClass):
    """Registry wrapper around ``torch.nn.CrossEntropyLoss`` (default settings)."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        """Return the cross-entropy loss.

        Args:
            inputs: tensor of size (batch_size, num_classes) with the
                predicted scores.
            targets: tensor of size (batch_size) with the ground-truth
                class indices.

        Returns:
            A scalar tensor holding the cross-entropy loss.
        """
        return self.loss_fn(inputs, targets)
loss/l1_loss.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from .abstract_loss_func import AbstractLossClass
3
+ from metrics.registry import LOSSFUNC
4
+
5
+
6
@LOSSFUNC.register_module(module_name="l1loss")
class L1Loss(AbstractLossClass):
    """Registry wrapper around ``torch.nn.L1Loss`` (default settings)."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.L1Loss()

    def forward(self, inputs, targets):
        """Return the L1 loss between ``inputs`` and ``targets``."""
        return self.loss_fn(inputs, targets)
metrics/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ current_file_path = os.path.abspath(__file__)
4
+ parent_dir = os.path.dirname(os.path.dirname(current_file_path))
5
+ project_root_dir = os.path.dirname(parent_dir)
6
+ sys.path.append(parent_dir)
7
+ sys.path.append(project_root_dir)
metrics/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (351 Bytes). View file
 
metrics/__pycache__/base_metrics_class.cpython-310.pyc ADDED
Binary file (6.21 kB). View file
 
metrics/__pycache__/registry.cpython-310.pyc ADDED
Binary file (1.01 kB). View file
 
metrics/base_metrics_class.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from sklearn import metrics
3
+ from collections import defaultdict
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
def get_accracy(output, label):
    """Return the fraction of rows of ``output`` whose arg-max equals ``label``."""
    predicted = torch.max(output, 1)[1]  # index of the largest score per row
    hits = (predicted == label).sum().item()
    return hits / predicted.size(0)
13
+
14
+
15
def get_prediction(output, label):
    """Pair the class-1 softmax probability with its label.

    Returns:
        A tensor with one row per sample: column 0 is the softmax
        probability of class index 1, column 1 is the label as float.
    """
    class1_prob = nn.functional.softmax(output, dim=1)[:, 1]
    prob_column = class1_prob.view(class1_prob.size(0), 1)
    label_column = label.view(label.size(0), 1).float()
    return torch.cat((prob_column, label_column), dim=1)
22
+
23
+
24
def calculate_metrics_for_train(label, output):
    """Compute AUC, EER, accuracy and average precision for one batch.

    Args:
        label: tensor of ground-truth labels.
        output: tensor of model outputs. When it has exactly two columns the
            softmax probability of column 1 is used as the score; otherwise
            ``output`` itself is used as the score.

    Returns:
        ``(auc, eer, accuracy, ap)``. ``auc`` and ``eer`` are ``None`` when
        the ROC curve cannot be computed (e.g. a single sample) or is
        undefined (first fpr/tpr value is NaN, e.g. only one class present).
    """
    if output.size(1) == 2:
        prob = torch.softmax(output, dim=1)[:, 1]
    else:
        prob = output

    # Accuracy
    _, prediction = torch.max(output, 1)
    correct = (prediction == label).sum().item()
    accuracy = correct / prediction.size(0)

    # Average Precision
    y_true = label.cpu().detach().numpy()
    y_pred = prob.cpu().detach().numpy()
    ap = metrics.average_precision_score(y_true, y_pred)

    # AUC and EER
    # Reuse the detached arrays: calling .numpy() on a tensor that still
    # tracks gradients raises, and that error used to be swallowed below,
    # silently turning AUC/EER into None.
    try:
        fpr, tpr, thresholds = metrics.roc_curve(np.squeeze(y_true),
                                                 np.squeeze(y_pred),
                                                 pos_label=1)
    except Exception:
        # Best effort: e.g. a single sample squeezes to a 0-d array, which
        # roc_curve rejects. (Not a bare except, so KeyboardInterrupt and
        # SystemExit still propagate.)
        return None, None, accuracy, ap

    if np.isnan(fpr[0]) or np.isnan(tpr[0]):
        # for the case when all the samples within a batch are fake/real
        auc, eer = None, None
    else:
        auc = metrics.auc(fpr, tpr)
        fnr = 1 - tpr
        # EER: false-positive rate at the point where |FNR - FPR| is smallest
        eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]

    return auc, eer, accuracy, ap
58
+
59
+
60
+ # ------------ compute average metrics of batches---------------------
61
class Metrics_batch():
    """Accumulate per-batch ACC/AUC/EER/AP and report their averages."""

    def __init__(self):
        self.tprs = []                          # interpolated TPR curve per batch
        self.mean_fpr = np.linspace(0, 1, 100)  # common FPR grid for averaging
        self.aucs = []
        self.eers = []
        self.aps = []

        self.correct = 0                        # running number of correct predictions
        self.total = 0                          # running number of samples
        self.losses = []

    def update(self, label, output):
        """Update all running statistics with one batch.

        Returns:
            ``(acc, auc, eer, ap)`` of this batch.
        """
        acc = self._update_acc(label, output)
        # Two columns are treated as logits; anything else is used as the score.
        prob = torch.softmax(output, dim=1)[:, 1] if output.size(1) == 2 else output
        auc, eer = self._update_auc(label, prob)
        ap = self._update_ap(label, prob)
        return acc, auc, eer, ap

    def _update_auc(self, lab, prob):
        """Record this batch's ROC statistics; return ``(auc, eer)``.

        Returns ``(-1, -1)`` without recording anything when the first
        fpr/tpr value is NaN.
        """
        labels_np = lab.squeeze().cpu().numpy()
        scores_np = prob.squeeze().cpu().numpy()
        fpr, tpr, thresholds = metrics.roc_curve(labels_np, scores_np, pos_label=1)
        if np.isnan(fpr[0]) or np.isnan(tpr[0]):
            return -1, -1

        auc = metrics.auc(fpr, tpr)
        # Resample the curve on the shared FPR grid so curves can be averaged.
        resampled_tpr = np.interp(self.mean_fpr, fpr, tpr)
        resampled_tpr[0] = 0.0
        self.tprs.append(resampled_tpr)
        self.aucs.append(auc)

        # EER: FPR at the operating point where |FNR - FPR| is smallest.
        fnr = 1 - tpr
        eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
        self.eers.append(eer)

        return auc, eer

    def _update_acc(self, lab, output):
        """Add this batch to the running counts; return the batch accuracy."""
        prediction = torch.max(output, 1)[1]  # argmax
        hits = (prediction == lab).sum().item()
        batch_accuracy = hits / prediction.size(0)
        self.correct = self.correct + hits
        self.total = self.total + lab.size(0)
        return batch_accuracy

    def _update_ap(self, label, prob):
        """Record and return this batch's average precision."""
        truth = label.cpu().detach().numpy()
        scores = prob.cpu().detach().numpy()
        ap = metrics.average_precision_score(truth, scores)
        self.aps.append(ap)
        return np.mean(ap)

    def get_mean_metrics(self):
        """Return ``{'acc', 'auc', 'eer', 'ap'}`` averaged over the batches seen."""
        mean_acc, _std_acc = self.correct / self.total, 0
        mean_auc, _std_auc = self._mean_auc()
        mean_err, _std_err = np.mean(self.eers), np.std(self.eers)
        mean_ap, _std_ap = np.mean(self.aps), np.std(self.aps)
        return {'acc': mean_acc, 'auc': mean_auc, 'eer': mean_err, 'ap': mean_ap}

    def _mean_auc(self):
        """Return (AUC of the mean ROC curve, std of the per-batch AUCs)."""
        averaged_tpr = np.mean(self.tprs, axis=0)
        averaged_tpr[-1] = 1.0
        mean_auc = metrics.auc(self.mean_fpr, averaged_tpr)
        std_auc = np.std(self.aucs)
        return mean_auc, std_auc

    def clear(self):
        """Reset every accumulator."""
        for history in (self.tprs, self.aucs):
            history.clear()
        self.correct = 0
        self.total = 0
        for history in (self.eers, self.aps, self.losses):
            history.clear()
149
+
150
+
151
+ # ------------ compute average metrics of all data ---------------------
152
class Metrics_all():
    """Store every prediction and compute metrics over all data at once."""

    def __init__(self):
        self.probs = []    # one 1-D numpy array of class-1 probabilities per batch
        self.labels = []   # one 1-D numpy array of labels per batch
        self.correct = 0   # running number of correct predictions
        self.total = 0     # running number of samples

    def store(self, label, output):
        """Record one batch of labels and model outputs.

        Args:
            label: tensor of ground-truth labels for the batch.
            output: tensor of logits; column 1 of its softmax is stored as
                the score.
        """
        prob = torch.softmax(output, dim=1)[:, 1]
        _, prediction = torch.max(output, 1)  # argmax
        correct = (prediction == label).sum().item()
        self.correct += correct
        self.total += label.size(0)
        # Flatten to 1-D instead of squeeze(): squeeze() turns a batch of
        # size 1 into a 0-d array, which np.concatenate in get_metrics
        # cannot handle. detach() keeps .numpy() valid for grad-tracking
        # tensors.
        self.labels.append(label.detach().reshape(-1).cpu().numpy())
        self.probs.append(prob.detach().reshape(-1).cpu().numpy())

    def get_metrics(self):
        """Return ``{'acc', 'auc', 'eer', 'ap'}`` over everything stored so far."""
        y_pred = np.concatenate(self.probs)
        y_true = np.concatenate(self.labels)
        # auc
        fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1)
        auc = metrics.auc(fpr, tpr)
        # eer: FPR at the point where |FNR - FPR| is smallest
        fnr = 1 - tpr
        eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
        # ap
        ap = metrics.average_precision_score(y_true, y_pred)
        # acc
        acc = self.correct / self.total
        return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}

    def clear(self):
        """Forget all stored predictions and counters."""
        self.probs.clear()
        self.labels.clear()
        self.correct = 0
        self.total = 0
188
+
189
+
190
+ # only used to record a series of scalar value
191
class Recorder:
    """Running weighted average of a stream of scalar values."""

    def __init__(self):
        self.sum = 0
        self.num = 0

    def update(self, item, num=1):
        """Add ``item`` with weight ``num``; ``None`` items are ignored."""
        if item is None:
            return
        self.sum += item * num
        self.num += num

    def average(self):
        """Return the weighted mean, or ``None`` when nothing was recorded."""
        return None if self.num == 0 else self.sum / self.num

    def clear(self):
        """Reset the accumulator."""
        self.sum = 0
        self.num = 0
metrics/registry.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class Registry(object):
    """Name -> class lookup table filled through a class decorator."""

    def __init__(self):
        self.data = {}

    def register_module(self, module_name=None):
        """Return a decorator that stores the decorated class in the registry.

        Args:
            module_name: key to register under; defaults to the class's
                ``__name__`` when ``None``.
        """
        def _register(cls):
            key = cls.__name__ if module_name is None else module_name
            self.data[key] = cls
            return cls  # hand the class back unchanged
        return _register

    def __getitem__(self, key):
        """Return the class registered under ``key`` (``KeyError`` if absent)."""
        return self.data[key]


# One registry per kind of pluggable component.
BACKBONE = Registry()
DETECTOR = Registry()
TRAINER = Registry()
LOSSFUNC = Registry()
metrics/utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn import metrics
2
+ import numpy as np
3
+
4
+
5
def parse_metric_for_print(metric_dict):
    """Format a nested metric dictionary as a printable text table.

    Args:
        metric_dict: mapping of dataset name -> {metric name: value}. The
            special key ``'avg'`` maps to the averaged metrics; inside it the
            key ``'dataset_dict'`` maps to a {name: value} dict that is
            listed entry by entry. ``None`` is accepted.

    Returns:
        The formatted string (just ``"\\n"`` when ``metric_dict`` is None).
    """
    if metric_dict is None:
        return "\n"

    # Collect the pieces and join once; this also avoids shadowing the
    # builtin ``str`` as the previous implementation did.
    pieces = [
        "\n",
        "================================ Each dataset best metric ================================ \n",
    ]
    for key, value in metric_dict.items():
        if key != 'avg':
            pieces.append(f"| {key}: ")
            pieces.extend(f" {k}={v} " for k, v in value.items())
            pieces.append("| \n")
        else:
            pieces.append("============================================================================================= \n")
            pieces.append("================================== Average best metric ====================================== \n")
            for avg_key, avg_value in value.items():
                if avg_key == 'dataset_dict':
                    pieces.extend(
                        f"| {name}: {score} | \n" for name, score in avg_value.items()
                    )
                else:
                    pieces.append(f"| avg {avg_key}: {avg_value} | \n")
    pieces.append("=============================================================================================")
    return "".join(pieces)
28
+
29
+
30
def get_test_metrics(y_pred, y_true, img_names=None, logger=None):
    """Compute frame-level ACC / AUC / EER / AP for test predictions.

    Args:
        y_pred: numpy array of predicted scores (squeezed before use).
        y_true: numpy array of labels. NOTE: modified in place -- every
            label >= 1 is set to 1.
        img_names: not used by the current implementation.
        logger: not used by the current implementation.

    Returns:
        Dict with keys ``'acc'``, ``'auc'``, ``'eer'``, ``'ap'``, ``'pred'``
        (the squeezed scores) and ``'label'`` (the binarized labels).
    """
    def get_video_metrics(image, pred, label):
        # Video-level AUC/EER: frames are grouped by their parent directory
        # name and scores/labels are averaged per group. This helper is not
        # called anywhere in this function.
        frames_by_video = {}
        video_labels = []
        video_preds = []
        for item in np.transpose(np.stack((image, pred, label)), (1, 0)):
            path = item[0]
            separator = '\\' if '\\' in path else '/'
            parts = path.split(separator)
            video_name = parts[-2]
            frames_by_video.setdefault(video_name, []).append(item)

        for frames in frames_by_video.values():
            pred_total = 0
            label_total = 0
            frame_count = 0
            for frame in frames:
                pred_total += float(frame[1])
                label_total += int(frame[2])
                frame_count += 1
            video_preds.append(pred_total / frame_count)
            video_labels.append(int(label_total / frame_count))

        v_fpr, v_tpr, _ = metrics.roc_curve(video_labels, video_preds)
        v_auc = metrics.auc(v_fpr, v_tpr)
        v_fnr = 1 - v_tpr
        v_eer = v_fpr[np.nanargmin(np.absolute((v_fnr - v_fpr)))]
        return v_auc, v_eer

    y_pred = y_pred.squeeze()

    # For UCF, where labels for different manipulations are not consistent:
    # collapse every label >= 1 to 1.
    y_true[y_true >= 1] = 1

    # AUC
    fpr, tpr, _ = metrics.roc_curve(y_true, y_pred, pos_label=1)
    auc = metrics.auc(fpr, tpr)

    # EER: FPR at the point where |FNR - FPR| is smallest
    fnr = 1 - tpr
    eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]

    # Average precision
    ap = metrics.average_precision_score(y_true, y_pred)

    # Accuracy at a fixed 0.5 decision threshold
    prediction_class = (y_pred > 0.5).astype(int)
    binary_truth = np.clip(y_true, a_min=0, a_max=1)
    correct = (prediction_class == binary_truth).sum().item()
    acc = correct / len(prediction_class)

    return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap, 'pred': y_pred, 'label': y_true}
networks/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ current_file_path = os.path.abspath(__file__)
4
+ parent_dir = os.path.dirname(os.path.dirname(current_file_path))
5
+ project_root_dir = os.path.dirname(parent_dir)
6
+ sys.path.append(parent_dir)
7
+ sys.path.append(project_root_dir)
8
+
9
+ from metrics.registry import BACKBONE
10
+
11
+ from .xception import Xception
networks/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (447 Bytes). View file
 
networks/__pycache__/xception.cpython-310.pyc ADDED
Binary file (6.7 kB). View file
 
networks/xception.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ # author: Zhiyuan Yan
3
+ # email: zhiyuanyan@link.cuhk.edu.cn
4
+ # date: 2023-0706
5
+
6
+ The code is mainly modified from GitHub link below:
7
+ https://github.com/ondyari/FaceForensics/blob/master/classification/network/xception.py
8
+ '''
9
+
10
+ import os
11
+ import argparse
12
+ import logging
13
+
14
+ import math
15
+ import torch
16
+ # import pretrainedmodels
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ import torch.utils.model_zoo as model_zoo
21
+ from torch.nn import init
22
+ from typing import Union
23
+ from metrics.registry import BACKBONE
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+
29
class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: number of input channels (also the number of groups of
            the depthwise convolution).
        out_channels: number of output channels of the pointwise convolution.
        kernel_size, stride, padding, dilation: forwarded to the depthwise
            convolution.
        bias: whether both convolutions use a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super().__init__()

        # Depthwise: one filter per input channel (groups == in_channels).
        self.conv1 = nn.Conv2d(
            in_channels, in_channels, kernel_size, stride, padding, dilation,
            groups=in_channels, bias=bias,
        )
        # Pointwise: 1x1 convolution that mixes the channels.
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        """Apply the depthwise then the pointwise convolution."""
        return self.pointwise(self.conv1(x))
42
+
43
+
44
class Block(nn.Module):
    """Residual block of ``reps`` (ReLU, SeparableConv2d 3x3, BatchNorm) units.

    Args:
        in_filters: number of input channels.
        out_filters: number of output channels.
        reps: number of separable-convolution units in the main branch.
        strides: when != 1, a 3x3 max-pool with this stride ends the main
            branch and the shortcut uses a strided 1x1 convolution.
        start_with_relu: when False the leading ReLU is dropped; when True
            it is replaced by a non-in-place ReLU.
        grow_first: whether the channel change happens in the first unit
            (True) or in the last one (False).
    """

    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super().__init__()

        # Shortcut branch: project only when the shape changes.
        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)

        def unit(channels_in, channels_out):
            # One (ReLU, separable 3x3 conv, BN) triple; the ReLU instance is shared.
            return [
                self.relu,
                SeparableConv2d(channels_in, channels_out, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(channels_out),
            ]

        layers = []
        width = in_filters
        if grow_first:  # whether the number of filters grows first
            layers.extend(unit(in_filters, out_filters))
            width = out_filters

        for _ in range(reps - 1):
            layers.extend(unit(width, width))

        if not grow_first:
            layers.extend(unit(in_filters, out_filters))

        if start_with_relu:
            # The first ReLU must not be in-place: its input is also the
            # input of the shortcut branch in forward().
            layers[0] = nn.ReLU(inplace=False)
        else:
            layers = layers[1:]

        if strides != 1:
            layers.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*layers)

    def forward(self, inp):
        """Return main branch + shortcut."""
        x = self.rep(inp)

        if self.skip is None:
            skip = inp
        else:
            skip = self.skipbn(self.skip(inp))

        x += skip
        return x
98
+
99
def add_gaussian_noise(ins, mean=0, stddev=0.2):
    """Return ``ins`` plus freshly sampled N(mean, stddev) noise of the same size."""
    gaussian = ins.data.new(ins.size())  # same dtype/device as ``ins``
    gaussian.normal_(mean, stddev)
    return ins + gaussian
102
+
103
+
104
@BACKBONE.register_module(module_name="xception")
class Xception(nn.Module):
    """
    Xception optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1610.02357.pdf
    """

    def __init__(self, xception_config):
        """ Constructor
        Args:
            xception_config: configuration in dict format. Keys read here:
                "num_classes", "mode", "inc" (input channels) and "dropout"
                (a falsy value disables the dropout before the classifier).
        """
        super(Xception, self).__init__()
        self.num_classes = xception_config["num_classes"]
        self.mode = xception_config["mode"]
        inc = xception_config["inc"]
        dropout = xception_config["dropout"]

        # Entry flow
        self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)

        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        # do relu here

        self.block1 = Block(
            64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(
            128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(
            256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # middle flow
        self.block4 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)

        self.block8 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)

        # Exit flow
        self.block12 = Block(
            728, 1024, 2, 2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        # do relu here
        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)
        # used for iid: the 'adjust_channel_iid' mode classifies on the
        # 512-channel output of `adjust_channel` and then behaves like
        # 'adjust_channel' everywhere else.
        final_channel = 2048
        if self.mode == 'adjust_channel_iid':
            final_channel = 512
            self.mode = 'adjust_channel'
        self.last_linear = nn.Linear(final_channel, self.num_classes)
        if dropout:
            # Replaces the plain linear layer created just above.
            self.last_linear = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(final_channel, self.num_classes)
            )

        self.adjust_channel = nn.Sequential(
            nn.Conv2d(2048, 512, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=False),
        )

    def fea_part1_0(self, x):
        """First stem stage: conv1 -> bn1 -> relu."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        return x

    def fea_part1_1(self, x):
        """Second stem stage: conv2 -> bn2 -> relu."""
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

    def fea_part1(self, x):
        """Full stem: both stem stages in sequence."""
        # Same layers as before, expressed through the two stage helpers
        # instead of duplicating their bodies.
        x = self.fea_part1_0(x)
        x = self.fea_part1_1(x)

        return x

    def fea_part2(self, x):
        """Entry flow: block1 .. block3."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        return x

    def fea_part3(self, x):
        """Middle flow, first half: block4 .. block7 (identity in 'shallow_xception' mode)."""
        if self.mode == "shallow_xception":
            return x
        else:
            x = self.block4(x)
            x = self.block5(x)
            x = self.block6(x)
            x = self.block7(x)
        return x

    def fea_part4(self, x):
        """block8 .. block12 (only block12 in 'shallow_xception' mode)."""
        if self.mode == "shallow_xception":
            x = self.block12(x)
        else:
            x = self.block8(x)
            x = self.block9(x)
            x = self.block10(x)
            x = self.block11(x)
            x = self.block12(x)
        return x

    def fea_part5(self, x):
        """Exit flow: conv3 -> bn3 -> relu -> conv4 -> bn4 (no final ReLU)."""
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)

        return x

    def features(self, input):
        """Run the whole backbone and return the final feature map.

        In 'adjust_channel' mode the 2048-channel map is additionally
        reduced to 512 channels.
        """
        x = self.fea_part1(input)

        x = self.fea_part2(x)
        x = self.fea_part3(x)
        x = self.fea_part4(x)

        x = self.fea_part5(x)

        if self.mode == 'adjust_channel':
            x = self.adjust_channel(x)

        return x

    def classifier(self, features, id_feat=None):
        """Pool ``features`` (if 4-D) and apply the final linear layer.

        Args:
            features: feature map or already-pooled feature vector.
            id_feat: optional tensor subtracted from the pooled embedding
                before classification.

        Side effect: the pooled embedding is stored in ``self.last_emb``.
        """
        # for iid
        if self.mode == 'adjust_channel':
            x = features
        else:
            # self.relu is in-place, so `features` itself is rectified too.
            x = self.relu(features)

        if len(x.shape) == 4:
            x = F.adaptive_avg_pool2d(x, (1, 1))
            x = x.view(x.size(0), -1)
        self.last_emb = x
        # for iid
        # Identity check: `!= None` on a tensor relies on comparison
        # fallback behaviour rather than testing for "no argument given".
        if id_feat is not None:
            out = self.last_linear(x - id_feat)
        else:
            out = self.last_linear(x)
        return out

    def forward(self, input):
        """Return ``(class scores, feature map)`` for ``input``."""
        x = self.features(input)
        out = self.classifier(x)
        return out, x
optimizor/LinearLR.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.optim import SGD
3
+ from torch.optim.lr_scheduler import _LRScheduler
4
+
5
class LinearDecayLR(_LRScheduler):
    """Keep the base learning rate, then decay it linearly.

    The learning rate stays at its base value while
    ``last_epoch <= start_decay`` and afterwards decreases linearly so that
    it reaches 0 at ``last_epoch == n_epoch``.

    Args:
        optimizer: wrapped optimizer.
        n_epoch: epoch at which the learning rate reaches 0.
        start_decay: last epoch that still uses the base learning rate.
        last_epoch: index of the last epoch (``-1`` starts from scratch).
    """

    def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1):
        # Must be set before the base constructor runs: it already calls
        # get_lr() once.
        self.start_decay = start_decay
        self.n_epoch = n_epoch
        super(LinearDecayLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per parameter group."""
        last_epoch = self.last_epoch
        n_epoch = self.n_epoch
        start_decay = self.start_decay
        if last_epoch > start_decay:
            # One value per group, each derived from that group's own base
            # lr. Returning a single value would leave every group after
            # the first unscheduled.
            return [
                b_lr - b_lr / (n_epoch - start_decay) * (last_epoch - start_decay)
                for b_lr in self.base_lrs
            ]
        return list(self.base_lrs)
optimizor/SAM.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # borrowed from
2
+
3
+ import torch
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
def disable_running_stats(model):
    """Freeze BatchNorm2d running statistics by setting their momentum to 0.

    The previous momentum of every ``nn.BatchNorm2d`` inside ``model`` is
    saved in ``backup_momentum`` so that ``enable_running_stats`` can
    restore it.
    """
    for module in model.modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        module.backup_momentum = module.momentum
        module.momentum = 0
15
+
16
def enable_running_stats(model):
    """Restore the BatchNorm2d momentum saved in ``backup_momentum``.

    ``nn.BatchNorm2d`` layers without a ``backup_momentum`` attribute are
    left untouched.
    """
    for module in model.modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        if hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum
22
+
23
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer.

    Args:
        params: parameters (or parameter groups) to optimize.
        base_optimizer: optimizer class instantiated on the same parameter
            groups to perform the actual update.
        rho: size of the perturbation applied in ``first_step``
            (must be non-negative).
        **kwargs: forwarded to ``base_optimizer`` (e.g. ``lr``).
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        # Both optimizers share the very same param_groups list.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Move every parameter to ``w + e(w)`` along its gradient direction."""
        total_norm = self._grad_norm()
        for group in self.param_groups:
            step_scale = group["rho"] / (total_norm + 1e-12)

            for param in group["params"]:
                if param.grad is not None:
                    perturbation = param.grad * step_scale.to(param)
                    param.add_(perturbation)  # climb to the local maximum "w + e(w)"
                    self.state[param]["e_w"] = perturbation

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Undo the perturbation, then let the base optimizer update the weights."""
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    param.sub_(self.state[param]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run both SAM steps; ``closure`` must do a full forward-backward pass."""
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        # step() itself runs under no_grad, so re-enable grad for the closure.
        closure = torch.enable_grad()(closure)

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Return the L2 norm over the gradients of all parameters."""
        # put everything on the same device, in case of model parallelism
        shared_device = self.param_groups[0]["params"][0].device
        per_param_norms = [
            param.grad.norm(p=2).to(shared_device)
            for group in self.param_groups
            for param in group["params"]
            if param.grad is not None
        ]
        return torch.norm(torch.stack(per_param_norms), p=2)
trainer/trainer.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script was adapted from the DeepfakeBench training code,
2
+ # originally authored by Zhiyuan Yan (zhiyuanyan@link.cuhk.edu.cn)
3
+
4
+ # Original: https://github.com/SCLBD/DeepfakeBench/blob/main/training/train.py
5
+
6
+ import os
7
+ import sys
8
+ current_file_path = os.path.abspath(__file__)
9
+ parent_dir = os.path.dirname(os.path.dirname(current_file_path))
10
+ project_root_dir = os.path.dirname(parent_dir)
11
+ sys.path.append(parent_dir)
12
+ sys.path.append(project_root_dir)
13
+
14
+ import pickle
15
+ import datetime
16
+ import logging
17
+ import numpy as np
18
+ from copy import deepcopy
19
+ from collections import defaultdict
20
+ from tqdm import tqdm
21
+ import time
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ import torch.optim as optim
26
+ from torch.nn import DataParallel
27
+ from torch.utils.tensorboard import SummaryWriter
28
+ from metrics.base_metrics_class import Recorder
29
+ from torch.optim.swa_utils import AveragedModel, SWALR
30
+ from torch import distributed as dist
31
+ from torch.nn.parallel import DistributedDataParallel as DDP
32
+ from sklearn import metrics
33
+ from metrics.utils import get_test_metrics
34
+
35
+ FFpp_pool=['FaceForensics++','FF-DF','FF-F2F','FF-FS','FF-NT']#
36
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+
38
+
39
class Trainer(object):
    """Drive training, validation and checkpointing of a detector model.

    The trainer owns the model, optimizer and (optional) scheduler / SWA
    model, keeps one tensorboard ``SummaryWriter`` per
    (phase, dataset, metric) triple, and remembers the best value of
    ``metric_scoring`` seen so far for every evaluated dataset.

    The model is expected to expose ``get_losses(data_dict, predictions)``
    and ``get_train_metrics(data_dict, predictions)`` and to accept
    ``model(data_dict)`` / ``model(data_dict, inference=True)``, as used by
    the methods below.
    """

    def __init__(
        self,
        config,
        model,
        optimizer,
        scheduler,
        logger,
        metric_scoring='auc',
        swa_model=None
    ):
        """Store the training components and prepare the log directory.

        Args:
            config: Mapping with at least ``'log_dir'`` and ``'ddp'``; other
                keys are read lazily by the individual methods.
            model: Detector to train; moved to ``device`` by ``speed_up``.
            optimizer: Optimizer (a SAM-style optimizer when
                ``config['optimizer']['type'] == 'sam'``).
            scheduler: Learning-rate scheduler; may be ``None``.
            logger: Logger used for progress messages.
            metric_scoring: Name of the metric that decides what "best"
                means. ``'eer'`` is minimised, anything else is maximised.
            swa_model: Optional averaged model updated during training.

        Raises:
            ValueError: If ``config``, ``model``, ``optimizer`` or ``logger``
                is ``None``.
        """
        # check if all the necessary components are implemented
        if config is None or model is None or optimizer is None or logger is None:
            raise ValueError("config, model, optimizer, and logger must be implemented")

        self.config = config
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.swa_model = swa_model
        self.writers = {}  # dict to maintain different tensorboard writers for each dataset and metric
        self.logger = logger
        self.metric_scoring = metric_scoring
        # maintain the best metric of all epochs; unseen entries start at the
        # worst possible score so that the first evaluation always improves
        self.best_metrics_all_time = defaultdict(
            lambda: defaultdict(self._worst_score)
        )
        self.speed_up()  # move model to GPU

        # create directory path
        self.log_dir = self.config['log_dir']
        print("Making dir ", self.log_dir)
        os.makedirs(self.log_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # small private helpers
    # ------------------------------------------------------------------
    def _worst_score(self):
        """Return the score every real result beats (+inf for EER, else -inf)."""
        return float('inf') if self.metric_scoring == 'eer' else float('-inf')

    def _core_model(self):
        """Return the underlying model, unwrapping a DDP wrapper if present."""
        return self.model.module if isinstance(self.model, DDP) else self.model

    @staticmethod
    def _move_batch(data_dict, target, skip_keys=()):
        """Move every device-movable value of ``data_dict`` to ``target`` in place.

        Values without a ``.to`` method (``None``, lists of names, ...) and
        keys listed in ``skip_keys`` are left untouched.
        """
        for key, value in data_dict.items():
            if key in skip_keys or not hasattr(value, 'to'):
                continue
            # re-assigning an existing key does not resize the dict, so this
            # is safe while iterating
            data_dict[key] = value.to(target)

    def _log_train_recorders(self, recorders, kind, tag_root, dataset_tag, step_cnt):
        """Log the running averages of one recorder group and mirror them to tensorboard.

        ``kind`` is the word used in the log line (``'loss'`` / ``'metric'``)
        and ``tag_root`` the tensorboard tag prefix.
        """
        message = f"Iter: {step_cnt} "
        for name, recorder in recorders.items():
            average = recorder.average()
            if average is None:
                message += f"training-{kind}, {name}: not calculated "
                continue
            message += f"training-{kind}, {name}: {average} "
            writer = self.get_writer('train', dataset_tag, name)
            writer.add_scalar(f'{tag_root}/{name}', average, global_step=step_cnt)
        self.logger.info(message)

    # ------------------------------------------------------------------
    # setup
    # ------------------------------------------------------------------
    def get_writer(self, phase, dataset_key, metric_key):
        """Return (creating on first use) the tensorboard writer for a triple.

        Only the last ``/``-separated component of each argument is used, so
        path-like keys map to a flat directory name. Writers live under
        ``<log_dir>/<phase>/<dataset_key>/<metric_key>/metric_board``.
        """
        phase = phase.split('/')[-1]
        dataset_key = dataset_key.split('/')[-1]
        metric_key = metric_key.split('/')[-1]
        writer_key = f"{phase}-{dataset_key}-{metric_key}"
        if writer_key not in self.writers:
            # update directory path
            writer_path = os.path.join(
                self.log_dir,
                phase,
                dataset_key,
                metric_key,
                "metric_board"
            )
            os.makedirs(writer_path, exist_ok=True)
            # update writers dictionary
            self.writers[writer_key] = SummaryWriter(writer_path)
        return self.writers[writer_key]

    def speed_up(self):
        """Move the model to ``device`` and wrap it in DDP when ``config['ddp']`` is set."""
        self.model.to(device)
        self.model.device = device
        if self.config['ddp'] == True:
            num_gpus = torch.cuda.device_count()
            print(f'avai gpus: {num_gpus}')
            self.model = DDP(
                self.model,
                device_ids=[self.config['local_rank']],
                find_unused_parameters=True,
                output_device=self.config['local_rank'],
            )

    def setTrain(self):
        """Put the model in training mode and record that on the trainer."""
        self.model.train()
        self.train = True

    def setEval(self):
        """Put the model in evaluation mode and record that on the trainer."""
        self.model.eval()
        self.train = False

    # ------------------------------------------------------------------
    # persistence
    # ------------------------------------------------------------------
    def load_ckpt(self, model_path):
        """Load model weights from ``model_path``.

        A file whose extension is ``p`` is treated as a saved model object
        (its ``state_dict()`` is used); any other file is treated as a state
        dict.

        Raises:
            NotImplementedError: If ``model_path`` is not an existing file.
        """
        if not os.path.isfile(model_path):
            raise NotImplementedError(
                "=> no model found at '{}'".format(model_path))

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        saved = torch.load(model_path, map_location='cpu')
        suffix = model_path.split('.')[-1]
        if suffix == 'p':
            self.model.load_state_dict(saved.state_dict())
        else:
            self.model.load_state_dict(saved)
        self.logger.info('Model found in {}'.format(model_path))

    def save_ckpt(self, phase, dataset_key, ckpt_info=None):
        """Write the current model weights to ``<log_dir>/ckpt_best.pth``.

        ``phase`` and ``dataset_key`` are accepted for interface
        compatibility but do not influence the path. For non-DDP models whose
        ``config['model_name']`` contains ``'svdd'`` the model's ``R`` and
        ``c`` attributes are stored next to the state dict.
        """
        save_dir = self.log_dir
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "ckpt_best.pth")
        is_ddp = self.config['ddp'] == True
        if not is_ddp and 'svdd' in self.config['model_name']:
            torch.save({'R': self.model.R,
                        'c': self.model.c,
                        'state_dict': self.model.state_dict(), }, save_path)
        else:
            # NOTE(review): under DDP this is the wrapper's state dict --
            # confirm that the loader expects those key names.
            torch.save(self.model.state_dict(), save_path)
        self.logger.info(f"Checkpoint saved to {save_path}, current ckpt is {ckpt_info}")

    def save_swa_ckpt(self):
        """Write the SWA model weights to ``<log_dir>/swa.pth``."""
        save_dir = self.log_dir
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "swa.pth")
        torch.save(self.swa_model.state_dict(), save_path)
        self.logger.info(f"SWA Checkpoint saved to {save_path}")

    def save_feat(self, phase, fea, dataset_key):
        """Save ``fea`` with ``np.save`` to ``<log_dir>/<phase>/<dataset_key>/feat_best.npy``."""
        save_dir = os.path.join(self.log_dir, phase, dataset_key)
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "feat_best.npy")
        np.save(save_path, fea)
        self.logger.info(f"Feature saved to {save_path}")

    def save_data_dict(self, phase, data_dict, dataset_key):
        """Pickle ``data_dict`` to ``<log_dir>/<phase>/<dataset_key>/data_dict_<phase>.pickle``."""
        save_dir = os.path.join(self.log_dir, phase, dataset_key)
        os.makedirs(save_dir, exist_ok=True)
        file_path = os.path.join(save_dir, f'data_dict_{phase}.pickle')
        with open(file_path, 'wb') as file:
            pickle.dump(data_dict, file)
        self.logger.info(f"data_dict saved to {file_path}")

    def save_metrics(self, phase, metric_one_dataset, dataset_key):
        """Pickle ``metric_one_dataset`` to ``<log_dir>/<phase>/<dataset_key>/metric_dict_best.pickle``."""
        save_dir = os.path.join(self.log_dir, phase, dataset_key)
        os.makedirs(save_dir, exist_ok=True)
        file_path = os.path.join(save_dir, 'metric_dict_best.pickle')
        with open(file_path, 'wb') as file:
            pickle.dump(metric_one_dataset, file)
        self.logger.info(f"Metrics saved to {file_path}")

    # ------------------------------------------------------------------
    # training
    # ------------------------------------------------------------------
    def train_step(self, data_dict):
        """Run one optimisation step on a batch.

        With a SAM optimizer the batch is forwarded twice (``first_step`` /
        ``second_step``) and the losses and predictions of the first pass are
        returned; otherwise a single forward/backward/step is performed.

        Returns:
            tuple: ``(losses, predictions)`` where ``losses['overall']`` is
            the tensor that was back-propagated.
        """
        # get_losses lives on the wrapped module when the model is under DDP
        core_model = self._core_model()

        if self.config['optimizer']['type'] == 'sam':
            for i in range(2):
                predictions = self.model(data_dict)
                losses = core_model.get_losses(data_dict, predictions)
                if i == 0:
                    pred_first = predictions
                    losses_first = losses
                self.optimizer.zero_grad()
                losses['overall'].backward()
                if i == 0:
                    self.optimizer.first_step(zero_grad=True)
                else:
                    self.optimizer.second_step(zero_grad=True)
            return losses_first, pred_first

        predictions = self.model(data_dict)
        losses = core_model.get_losses(data_dict, predictions)
        self.optimizer.zero_grad()
        losses['overall'].backward()
        self.optimizer.step()
        return losses, predictions

    def train_epoch(
        self,
        epoch,
        train_data_loader,
        validation_data_loaders=None
    ):
        """Train for one epoch, logging and validating along the way.

        Validation runs once per epoch for epoch 0 and twice per epoch
        afterwards. Running loss/metric averages are logged (and reset) every
        300 iterations on local rank 0.

        Args:
            epoch: Zero-based epoch index; also used to derive the global step.
            train_data_loader: Iterable of batch dicts with a usable ``len``.
            validation_data_loaders: Optional mapping of dataset key to
                loader, passed to ``eval``.

        Returns:
            The result of the last validation trigger in this epoch
            (``self.best_metrics_all_time``), or ``None`` when no validation
            ran on this process.
        """
        self.logger.info("===> Epoch[{}] start!".format(epoch))
        times_per_epoch = 2 if epoch >= 1 else 1

        num_batches = len(train_data_loader)
        # never let the modulus below become zero for very short loaders
        validation_step = max(1, num_batches // times_per_epoch)
        step_cnt = epoch * num_batches
        # defined up front so an empty loader cannot leave it unbound
        validation_best_metric = None
        swa_enabled = bool(self.config.get('SWA'))

        # define training recorder
        train_recorder_loss = defaultdict(Recorder)
        train_recorder_metric = defaultdict(Recorder)

        for iteration, data_dict in tqdm(enumerate(train_data_loader), total=num_batches):
            self.setTrain()
            # move the batch to the training device; 'name' holds non-tensor data
            self._move_batch(data_dict, device, skip_keys=('name',))

            losses, predictions = self.train_step(data_dict)

            if swa_enabled and epoch > self.config['swa_start']:
                self.swa_model.update_parameters(self.model)

            # compute training metric for each batch data
            batch_metrics = self._core_model().get_train_metrics(data_dict, predictions)

            # store data by recorder
            for name, value in batch_metrics.items():
                train_recorder_metric[name].update(value)
            for name, value in losses.items():
                train_recorder_loss[name].update(value)

            # run tensorboard to visualize the training process
            if iteration % 300 == 0 and self.config['local_rank'] == 0:
                # NOTE(review): the scheduler is stepped only here, i.e. every
                # 300 iterations on rank 0 while SWA is active -- kept as in
                # the original, confirm that this cadence is intended.
                if swa_enabled and (epoch > self.config['swa_start'] or self.config.get('dry_run', False)):
                    self.scheduler.step()

                dataset_tag = ','.join(
                    dataset.split('/')[-1] for dataset in self.config['train_dataset']
                )
                self._log_train_recorders(
                    train_recorder_loss, 'loss', 'train_loss', dataset_tag, step_cnt)
                self._log_train_recorders(
                    train_recorder_metric, 'metric', 'train_metric', dataset_tag, step_cnt)

                # clear recorder.
                # Note we only consider the current 300 samples for computing batch-level loss/metric
                for recorder in train_recorder_loss.values():
                    recorder.clear()
                for recorder in train_recorder_metric.values():
                    recorder.clear()

            # run validation
            if (step_cnt + 1) % validation_step == 0:
                # under DDP only rank 0 evaluates
                if validation_data_loaders is not None and (
                        (not self.config['ddp']) or dist.get_rank() == 0):
                    self.logger.info("===> Validation start!")
                    validation_best_metric = self.eval(
                        eval_data_loaders=validation_data_loaders,
                        eval_stage="validation",
                        step=step_cnt,
                        epoch=epoch,
                        iteration=iteration
                    )
                else:
                    validation_best_metric = None

            step_cnt += 1

            # hand the batch back to the CPU so it does not pin device memory
            self._move_batch(data_dict, 'cpu', skip_keys=('name',))
        return validation_best_metric

    # ------------------------------------------------------------------
    # evaluation
    # ------------------------------------------------------------------
    def get_respect_acc(self, prob, label):
        """Return the accuracy on real (label 0) and fake (label != 0) samples.

        Predictions are ``prob > 0.5``. Samples are selected by label value,
        so the result does not depend on the order of the inputs.

        Returns:
            tuple: ``(acc_real, acc_fake)`` as floats; a class with no
            samples yields ``nan`` instead of raising.
        """
        prob = np.asarray(prob)
        label = np.asarray(label)
        pred = np.where(prob > 0.5, 1, 0)
        correct = (pred == label)
        real_mask = (label == 0)
        fake_mask = ~real_mask

        def _accuracy(mask):
            total = int(np.count_nonzero(mask))
            if total == 0:
                return float('nan')
            return float(np.count_nonzero(correct[mask]) / total)

        return _accuracy(real_mask), _accuracy(fake_mask)

    def eval_one_dataset(self, data_loader):
        """Run inference over one loader and collect losses, predictions and features.

        Labels are binarised (non-zero becomes 1) and any ``'label_spe'``
        entry is dropped before the forward pass.

        Returns:
            tuple: ``(loss_recorders, predictions, labels, features)`` where
            ``predictions`` are the arg-max class indices of
            ``predictions['cls']`` and the last three items are numpy arrays.
            Losses are not recorded when the model is an ``AveragedModel``.
        """
        # define eval recorder
        eval_recorder_loss = defaultdict(Recorder)
        prediction_lists = []
        feature_lists = []
        label_lists = []
        for i, data_dict in tqdm(enumerate(data_loader), total=len(data_loader)):
            # get data
            if 'label_spe' in data_dict:
                data_dict.pop('label_spe')  # remove the specific label
            data_dict['label'] = torch.where(data_dict['label'] != 0, 1, 0)  # fix the label to 0 and 1 only
            # move data to the evaluation device (non-tensor entries are skipped)
            self._move_batch(data_dict, device)
            # model forward without considering gradient computation
            predictions = self.inference(data_dict)  # dict with keys cls, feat

            label_lists += list(data_dict['label'].cpu().detach().numpy())
            # Get the predicted class for each sample in the batch
            # NOTE(review): these are hard class indices, not probabilities;
            # confirm that get_test_metrics is meant to score them.
            _, predicted_classes = torch.max(predictions['cls'], dim=1)
            prediction_lists += predicted_classes.cpu().detach().numpy().tolist()
            feature_lists += list(predictions['feat'].cpu().detach().numpy())
            if not isinstance(self.model, AveragedModel):
                # compute all losses for each batch data
                losses = self._core_model().get_losses(data_dict, predictions)
                # store data by recorder
                for name, value in losses.items():
                    eval_recorder_loss[name].update(value)
        return eval_recorder_loss, np.array(prediction_lists), np.array(label_lists), np.array(feature_lists)

    def save_best(self, epoch, iteration, step, losses_one_dataset_recorder, key, metric_one_dataset, eval_stage):
        """Update the best score for ``key`` and log the evaluation result.

        When ``metric_one_dataset[self.metric_scoring]`` improves on the
        stored best, the best score is updated, the metrics are pickled and,
        for the ``'validation'`` stage with ``config['save_ckpt']`` set and a
        key outside ``FFpp_pool``, a checkpoint is written. Losses and metrics
        are logged and sent to tensorboard regardless of improvement.
        """
        lower_is_better = self.metric_scoring == 'eer'
        best_metric = self.best_metrics_all_time[key].get(
            self.metric_scoring, self._worst_score())
        current = metric_one_dataset[self.metric_scoring]
        # Check if the current score is an improvement
        improved = (current < best_metric) if lower_is_better else (current > best_metric)
        if improved:
            # Update the best metric
            self.best_metrics_all_time[key][self.metric_scoring] = current
            if key == 'avg':
                self.best_metrics_all_time[key]['dataset_dict'] = metric_one_dataset['dataset_dict']
            # Save checkpoint, feature, and metrics if specified in config
            if eval_stage == 'validation' and self.config['save_ckpt'] and key not in FFpp_pool:
                self.save_ckpt(eval_stage, key, f"{epoch}+{iteration}")
            self.save_metrics(eval_stage, metric_one_dataset, key)

        if losses_one_dataset_recorder is not None:
            # info for each dataset
            loss_str = f"dataset: {key} step: {step} "
            for k, v in losses_one_dataset_recorder.items():
                writer = self.get_writer(eval_stage, key, k)
                v_avg = v.average()
                if v_avg is None:
                    print(f'{k} is not calculated')
                    continue
                # tensorboard-1. loss
                writer.add_scalar(f'{eval_stage}_losses/{k}', v_avg, global_step=step)
                loss_str += f"{eval_stage}-loss, {k}: {v_avg} "
            self.logger.info(loss_str)

        metric_str = f"dataset: {key} step: {step} "
        for k, v in metric_one_dataset.items():
            if k in ('pred', 'label', 'dataset_dict'):
                continue
            metric_str += f"{eval_stage}-metric, {k}: {v} "
            # tensorboard-2. metric
            writer = self.get_writer(eval_stage, key, k)
            writer.add_scalar(f'{eval_stage}_metrics/{k}', v, global_step=step)
        if 'pred' in metric_one_dataset:
            acc_real, acc_fake = self.get_respect_acc(metric_one_dataset['pred'], metric_one_dataset['label'])
            metric_str += f'{eval_stage}-metric, acc_real:{acc_real}; acc_fake:{acc_fake}'
            # each scalar gets its own writer instead of reusing whichever
            # writer the loop above happened to leave behind
            self.get_writer(eval_stage, key, 'acc_real').add_scalar(
                f'{eval_stage}_metrics/acc_real', acc_real, global_step=step)
            self.get_writer(eval_stage, key, 'acc_fake').add_scalar(
                f'{eval_stage}_metrics/acc_fake', acc_fake, global_step=step)
        self.logger.info(metric_str)

    def eval(self, eval_data_loaders, eval_stage, step=None, epoch=None, iteration=None):
        """Evaluate the model on every loader in ``eval_data_loaders``.

        Args:
            eval_data_loaders: Mapping of dataset key to data loader.
            eval_stage: Stage label used in log lines, tensorboard tags and
                output paths (e.g. ``"validation"``).
            step, epoch, iteration: Bookkeeping values forwarded to
                ``save_best``.

        Returns:
            ``self.best_metrics_all_time`` -- the best scores recorded so far
            for every dataset key.
        """
        # set model to eval mode
        self.setEval()

        avg_metric = {'acc': 0, 'auc': 0, 'eer': 0, 'ap': 0, 'dataset_dict': {}}
        keys = eval_data_loaders.keys()
        for key in keys:
            # compute loss for each dataset
            losses_one_dataset_recorder, predictions_nps, label_nps, _ = self.eval_one_dataset(eval_data_loaders[key])
            metric_one_dataset = get_test_metrics(y_pred=predictions_nps, y_true=label_nps, logger=self.logger)

            # accumulate the metrics that take part in the cross-dataset average
            for metric_name, value in metric_one_dataset.items():
                if metric_name in avg_metric:
                    avg_metric[metric_name] += value
            avg_metric['dataset_dict'][key] = metric_one_dataset[self.metric_scoring]

            if isinstance(self.model, AveragedModel):
                # SWA models are only reported, never checkpointed as "best"
                metric_str = "Iter Final for SWA: "
                for k, v in metric_one_dataset.items():
                    metric_str += f"{eval_stage}-metric, {k}: {v} "
                self.logger.info(metric_str)
                continue
            self.save_best(epoch, iteration, step, losses_one_dataset_recorder, key, metric_one_dataset, eval_stage)

        if len(keys) > 0 and self.config.get('save_avg', False):
            # calculate avg value
            for key in avg_metric:
                if key != 'dataset_dict':
                    avg_metric[key] /= len(keys)
            self.save_best(epoch, iteration, step, None, 'avg', avg_metric, eval_stage)

        self.logger.info(f'===> {eval_stage} Done!')
        return self.best_metrics_all_time  # return all types of mean metrics for determining the best ckpt

    @torch.no_grad()
    def inference(self, data_dict):
        """Forward ``data_dict`` through the model in inference mode without gradients."""
        predictions = self.model(data_dict, inference=True)
        return predictions