diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..ed3807a39768624594ea748034c160a4525748ef 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,35 +1,2 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +*.json filter=lfs diff=lfs merge=lfs -text +*.dat filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index 7be5fc7f47d5db027d120b8024982df93db95b74..2c4c7df05acff2374d745666735bd7f38e899375 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,271 @@ ---- -license: mit ---- +# DFG - Deepfake Genome Codebase + +## 1. Environment Setup + +Create and activate the conda environment: + +```bash +# Create a new conda environment (Python 3.10 recommended) +conda create -n dfg python=3.10 -y + +# Activate the environment +conda activate dfg + +# Install dependencies +pip install -r requirements.txt +``` + +## 2. Dataset Configuration + +Before training or testing, you need to update the **dataset global path** to match your actual data location. + +Open `training/dataset/abstract_dataset.py` and modify the `DATASET_GLOBAL_PATH` variable: + +```python +# Change this to your actual dataset root path +DATASET_GLOBAL_PATH = "/your/actual/dataset/path/" +``` + +This path should point to the root directory containing your deepfake detection datasets (e.g., `DeepFakeGenome`, `deepfake_detecton_dataset`, etc.). + +## 3. Project and Dataset Structure + +``` +DFG/ +├── preprocessing/ +│ └── dataset_json/ # Dataset index JSON files +│ ├── protocol_2_train.json +│ ├── protocol_2_test.json +│ ├── protocol_3_test.json +│ ├── protocol_4_test.json +│ └── ... +├── training/ +│ ├── config/ +│ │ └── detector/ # Detector config YAML files +│ ├── detectors/ # Detector implementations +│ │ ├── __init__.py # Register all detectors here +│ │ ├── base_detector.py +│ │ └── ... +│ ├── networks/ # Backbone network implementations +│ ├── loss/ # Loss function definitions +│ ├── metrics/ # Evaluation metrics +│ ├── train.py # Training entry point +│ └── test_pall.py # Testing entry point +├── train.sh # Training script examples +├── test.sh # Testing script examples +├── requirements.txt # Python dependencies +└── README.md +``` + +## 4. Training + +Refer to `train.sh` for all training commands. Example: + +```bash +python -m torch.distributed.launch --master_port=29503 --nproc_per_node=8 training/train.py \ + --detector_path ./training/config/detector/clip_large_fft.yaml \ + --no-save_feat --ddp +``` + +Key arguments: +- `--master_port`: port for distributed training (change if port conflicts occur) +- `--nproc_per_node`: number of GPUs +- `--detector_path`: path to the detector config YAML +- `--no-save_feat`: disable feature saving during training +- `--ddp`: enable DistributedDataParallel + +## 5. Testing + +Refer to `test.sh` for all testing commands. Example: + +```bash +# Test on protocol 2 & 3 +python -m torch.distributed.launch --master_port=29510 --nproc_per_node=8 training/test_pall.py --ddp \ + --test_dataset "protocol_2_test" "protocol_3_test" \ + --detector_path ./training/config/detector/clip_large_fft.yaml \ + --weights_path logs/clip_models/clip_large_fft_2025-11-08-13-56-51 + +# Test on protocol 4 +python -m torch.distributed.launch --master_port=29512 --nproc_per_node=8 training/test_pall.py --ddp \ + --test_dataset "protocol_4_test" \ + --detector_path ./training/config/detector/clip_large_fft.yaml \ + --weights_path logs/clip_models/clip_large_fft_2025-11-08-13-56-51 \ + --test_config test_config_p4.yaml +``` + +Key arguments: +- `--test_dataset`: one or more dataset names (must match JSON filenames under `preprocessing/dataset_json/`) +- `--weights_path`: path to trained model checkpoint directory +- `--test_config`: additional test configuration (required for protocol 4) + +## 6. Adding a Custom Detector + +To integrate your own detector into the framework, follow these three steps: + +### Step 1: Create the detector config YAML + +Create a new file under `training/config/detector/`, e.g., `my_detector.yaml`: + +```yaml +# log dir +log_dir: logs/my_detector + +# model setting +pretrained: null +model_name: my_detector +backbone_name: resnet34 + +# backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 +train_batchSize: 64 +test_batchSize: 64 +workers: 8 +frame_num: {'train': 16, 'test': 16} +resolution: 224 +with_mask: false +with_landmark: false + +# data augmentation +use_data_augmentation: false +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.485, 0.456, 0.406] +std: [0.229, 0.224, 0.225] + +# optimizer config +optimizer: + type: adam + adam: + lr: 0.0002 + beta1: 0.9 + beta2: 0.999 + eps: 0.00000001 + weight_decay: 0.0005 + amsgrad: false + +# training config +lr_scheduler: null +nEpochs: 20 +start_epoch: 0 +save_epoch: 1 +rec_iter: 100 +logdir: ./logs +manualSeed: 1024 +save_ckpt: true +save_feat: true + +# loss function +loss_func: cross_entropy +losstype: null + +# metric +metric_scoring: auc + +# cuda +ngpu: 1 +cuda: true +cudnn: true + +save_avg: true +save_latest_ckpt: true +``` + +### Step 2: Create the detector Python file + +Create `training/detectors/my_detector.py`: + +```python +import torch +import torch.nn as nn + +from metrics.base_metrics_class import calculate_metrics_for_train +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + + +@DETECTOR.register_module(module_name='my_detector') +class MyDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = LOSSFUNC[config['loss_func']]() + + def build_backbone(self, config): + backbone = BACKBONE[config['backbone_name']](config['backbone_config']) + return backbone + + def features(self, data_dict: dict) -> torch.Tensor: + return self.backbone(data_dict['image']) + + def classifier(self, features: torch.Tensor) -> torch.Tensor: + return self.fc(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + return {'overall': loss} + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + def forward(self, data_dict: dict, inference=False) -> dict: + features = self.features(data_dict) + pred = self.classifier(features) + prob = torch.softmax(pred, dim=1)[:, 1] + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict +``` + +### Step 3: Register the detector in `__init__.py` + +Add the following import line to `training/detectors/__init__.py`: + +```python +from .my_detector import MyDetector +``` + +That's it! Now you can train and test with your custom detector: + +```bash +# Train +python -m torch.distributed.launch --master_port=29503 --nproc_per_node=8 training/train.py \ + --detector_path ./training/config/detector/my_detector.yaml \ + --no-save_feat --ddp + +# Test +python -m torch.distributed.launch --master_port=29510 --nproc_per_node=8 training/test_pall.py --ddp \ + --test_dataset "protocol_2_test" "protocol_3_test" \ + --detector_path ./training/config/detector/my_detector.yaml \ + --weights_path logs/my_detector/ +``` + + diff --git a/preprocessing/config.yaml b/preprocessing/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4a74cc4de4891fbdc7280610f91d1880382f48e0 --- /dev/null +++ b/preprocessing/config.yaml @@ -0,0 +1,52 @@ +preprocess: + dataset_name: # the name of dataset + choices: ['FaceForensics++','Celeb-DF-v1', 'Celeb-DF-v2', 'DFDCP', 'DFDC', 'DeeperForensics-1.0','UADFV'] + default: 'FaceForensics++' + dataset_root_path: # the root path to the dataset + type: str + default: 'F:\' + comp: # the compression level of videos, only in the dataset of FaceForensics++. + choices: ['raw', 'c23', 'c40'] + default: 'c23' + mode: # based on the numbers of frame or skip the specific stride of frames. + choices: ['fixed_num_frames', 'fixed_stride'] + default: 'fixed_num_frames' + stride: # when 'mode' is 'fixed_stride', 'stride' is the number of frames to skip between each frame extracted. + type: int + default: 10 + num_frames: # when 'mode' is 'fixed_num_frames', 'num_frames' is the number of frames to extract from each video. + type: int + default: 32 + +rearrange: + dataset_name: # the name of dataset + choices: ['FaceForensics++', 'DeepFakeDetection', 'Celeb-DF-v1', 'Celeb-DF-v2','DFDCP', 'DFDC', 'DeeperForensics-1.0','UADFV','FaceShifter'] + default: 'FaceForensics++' + dataset_root_path: # the root path to the dataset + type: str + default: '' + output_file_path: # the json path to the dataset + type: str + default: '../preprocessing/dataset_json_v6' + comp: # the compression level of videos, only in the dataset of FaceForensics++. + choices: ['raw', 'c23', 'c40'] + default: 'c23' + perturbation: # Extensive real-world perturbations are applied to DeeperForensics-1.0 dataset + choices: ['end_to_end','end_to_end_level_1','end_to_end_level_2','end_to_end_level_3','end_to_end_level_4', + 'end_to_end_level_5','end_to_end_mix_2_distortions','end_to_end_mix_3_distortions', + 'end_to_end_mix_4_distortions','end_to_end_random_level','reenact_postprocess'] + default: 'end_to_end' + +to_lmdb: + dataset_name: # the name of dataset + choices: ['FaceForensics++', 'DeepFakeDetection', 'Celeb-DF-v1', 'Celeb-DF-v2','DFDCP', 'DFDC', 'DeeperForensics-1.0','UADFV','FaceShifter'] + default: 'FaceForensics++' + dataset_root_path: # the root path to the dataset + type: str + default: './datasets_v2' + output_lmdb_dir: # the json path to the dataset + type: str + default: './datasets_lmdbs' + comp: # the compression level of videos, only in the dataset of FaceForensics++. + choices: ['raw', 'c23', 'c40'] + default: 'c23' \ No newline at end of file diff --git a/preprocessing/dataset2lmdb_test.py b/preprocessing/dataset2lmdb_test.py new file mode 100644 index 0000000000000000000000000000000000000000..649bf1bc1661c2812907c3349d6c13a66a79b084 --- /dev/null +++ b/preprocessing/dataset2lmdb_test.py @@ -0,0 +1,99 @@ +import os +import json +import cv2 +import lmdb +import yaml +from PIL import Image +import io +import numpy as np +def file_to_binary(file_path): + """convert to binary""" + if file_path.endswith('.npy'): + data = np.load(file_path) + file_binary = data.tobytes() + else: + with open(file_path, 'rb') as f: + file_binary = f.read() + return file_binary + + +def create_lmdb_dataset(source_folder, lmdb_path, dataset_name, map_size): + """create LMDB dataset""" + # open LMDB file,create dataset + db = lmdb.open(lmdb_path, map_size=map_size) + with db.begin(write=True) as txn: + + for root, dirs, files in os.walk(source_folder,followlinks=True): + print(root) + if 'video' in root: + continue + for file in files: + print(file) + image_path = os.path.join(root, file) + # + relative_path = f"{dataset_name}/" + os.path.relpath(image_path, source_folder) + print("relative_path:", relative_path) + key = relative_path.encode('utf-8') + # txn.delete(key) + # relative_path = f"{dataset_name}\\original_sequences" + os.path.relpath(image_path, source_folder) + # key = relative_path.encode('utf-8') + print("image_path:", image_path) + value = file_to_binary(image_path) + + # write dataset + txn.put(key, value) + + + + db.close() + + +def read_lmdb(lmdb_dir_path): + # validate the key and value in the generated LMDB + env = lmdb.open(lmdb_dir_path) + + idx = '%09d' % 5 + with env.begin(write=False) as txn: + # key for validation + key='npy_test\\000_003\\000.npy' + binary = txn.get(key.encode()) + data = np.frombuffer(binary, dtype=np.uint32).reshape((81, 2)) + + # image_buf = np.frombuffer(image_bin, dtype=np.uint8) + # img = cv2.imdecode(image_buf, cv2.IMREAD_COLOR) + # image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + + +# Usage example +import argparse +# Create the ArgumentParser object +parser = argparse.ArgumentParser(description='Process some inputs.') + +# Add the --dataset_size argument +parser.add_argument('--dataset_size', type=int, default=25, required=True, + help='lmdb requires pre-specifying the total dataset size (GB)') + +# Parse the arguments +args = parser.parse_args() + +if __name__ == '__main__': + # from config.yaml load parameters + yaml_path = './config_DFo.yaml' + # open the yaml file + try: + with open(yaml_path, 'r') as f: + config = yaml.safe_load(f) + except yaml.parser.ParserError as e: + print("YAML file parsing error:", e) + + config=config['to_lmdb'] + dataset_name = config['dataset_name']['default'] + dataset_size = args.dataset_size + dataset_root_path = config['dataset_root_path']['default'] + output_lmdb_dir =config['output_lmdb_dir']['default'] + os.makedirs(output_lmdb_dir,exist_ok=True) + + dataset_dir_path = f"{dataset_root_path}/{dataset_name}" + lmdb_path=f"{output_lmdb_dir}/{dataset_name}_lmdb" + create_lmdb_dataset(dataset_dir_path, lmdb_path, dataset_name,map_size=int(dataset_size) * 1024 * 1024 * 1024) + #read_lmdb(lmdb_path) diff --git a/preprocessing/dataset_json/Celeb-DF-v2.json b/preprocessing/dataset_json/Celeb-DF-v2.json new file mode 100644 index 0000000000000000000000000000000000000000..da5f2a14dfe5d0e8cf8f0c51ec9c91e52778ec4b --- /dev/null +++ b/preprocessing/dataset_json/Celeb-DF-v2.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:113fcde0ea7b1a03caf63e2ed2f3e6d80bf99efe18073ca05c606c9d0b260804 +size 20076776 diff --git a/preprocessing/dataset_json/DF40_all.json b/preprocessing/dataset_json/DF40_all.json new file mode 100644 index 0000000000000000000000000000000000000000..5045f417d7bfebc8efb73665fdc4bd01ddae2659 --- /dev/null +++ b/preprocessing/dataset_json/DF40_all.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6308d04ffd0e9da59a7df058bf6a27ae41da0a15f03add8a11f694f510a5b2f6 +size 125339450 diff --git a/preprocessing/dataset_json/DFDC.json b/preprocessing/dataset_json/DFDC.json new file mode 100644 index 0000000000000000000000000000000000000000..215934340c1cd903a2a1f1693780fe4854341733 --- /dev/null +++ b/preprocessing/dataset_json/DFDC.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d1184758620c71b68ad8715e068644ed9792bdc6b2feba9cf0b7f8a98a7e00d +size 44499938 diff --git a/preprocessing/dataset_json/DFDCP.json b/preprocessing/dataset_json/DFDCP.json new file mode 100644 index 0000000000000000000000000000000000000000..0dbc83186c3f19b9b6cbd9ae63fa73f220e3bc41 --- /dev/null +++ b/preprocessing/dataset_json/DFDCP.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed5022e36380b3c1ca21941e95bad7bcf08fc3c58e50441012757189eed1868d +size 27634090 diff --git a/preprocessing/dataset_json/DeepFakeDetection.json b/preprocessing/dataset_json/DeepFakeDetection.json new file mode 100644 index 0000000000000000000000000000000000000000..0111ca7d11e0101e295f6669bf017aeba1c86638 --- /dev/null +++ b/preprocessing/dataset_json/DeepFakeDetection.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aed7e4a257feb622435119bec621dc09b0614823e4e3e1186bc4b280a394fa90 +size 45849312 diff --git a/preprocessing/dataset_json/DiffFace.json b/preprocessing/dataset_json/DiffFace.json new file mode 100644 index 0000000000000000000000000000000000000000..554855c805bc05e29a42939f14afb06955d62977 --- /dev/null +++ b/preprocessing/dataset_json/DiffFace.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3c8ff3368ae4c6ae5950ccc79d63603279c0c103264d2215eebd270e6a7535f +size 7177344 diff --git a/preprocessing/dataset_json/DreamBooth.json b/preprocessing/dataset_json/DreamBooth.json new file mode 100644 index 0000000000000000000000000000000000000000..afaac634a980fe66a19e0e096fef4cf3696c7d9f --- /dev/null +++ b/preprocessing/dataset_json/DreamBooth.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e4cbfa4d0efef6f4b9f8fec1ad4be9efdad7def115bf993e1684c631b1bedbc +size 7841108 diff --git a/preprocessing/dataset_json/FF-DF.json b/preprocessing/dataset_json/FF-DF.json new file mode 100644 index 0000000000000000000000000000000000000000..dcae1dc49f58bdc3c10f9e03b373794b555104e1 --- /dev/null +++ b/preprocessing/dataset_json/FF-DF.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be532fe67b2bebaaaf3a81237ccc518ba8ded043564b187c923e7c6e79bc242b +size 6633592 diff --git a/preprocessing/dataset_json/FF-F2F.json b/preprocessing/dataset_json/FF-F2F.json new file mode 100644 index 0000000000000000000000000000000000000000..14e55b4a42e5203ca04677306c9efca4dd983fb8 --- /dev/null +++ b/preprocessing/dataset_json/FF-F2F.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0cb9c3e9e209dfd45390b7da7a882fcefab4af4a53bd672c999b594dda49cee3 +size 6647968 diff --git a/preprocessing/dataset_json/FF-FS.json b/preprocessing/dataset_json/FF-FS.json new file mode 100644 index 0000000000000000000000000000000000000000..ec87fac7d30f9250879f2d9e6835ee1d48143b82 --- /dev/null +++ b/preprocessing/dataset_json/FF-FS.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f14cc304d80e39597d280234c654aca1fb5af3615a26b3c431e662cb50e2c23 +size 6615423 diff --git a/preprocessing/dataset_json/FF-NT.json b/preprocessing/dataset_json/FF-NT.json new file mode 100644 index 0000000000000000000000000000000000000000..18317361c85c9491e07b8e367f5a33fc798325ca --- /dev/null +++ b/preprocessing/dataset_json/FF-NT.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:112494833ab34ab6b9476dc2632aed5c5928961b004e7267f3690c8b70b1c947 +size 6804515 diff --git a/preprocessing/dataset_json/FaceForensics++.json b/preprocessing/dataset_json/FaceForensics++.json new file mode 100644 index 0000000000000000000000000000000000000000..4999340c40afc2fb9c9c0dfbd47846cba75fb6da --- /dev/null +++ b/preprocessing/dataset_json/FaceForensics++.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18802f85a0de861d07140fafe4fbbdda67167afb110d0b6f5cece1738c7428c8 +size 17184826 diff --git a/preprocessing/dataset_json/FaceShifter.json b/preprocessing/dataset_json/FaceShifter.json new file mode 100644 index 0000000000000000000000000000000000000000..6b55a996476267a7de2f7e01f4bf61d3fed3366a --- /dev/null +++ b/preprocessing/dataset_json/FaceShifter.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2031400f8aea5369ebad30c5b9839db370c1bf7f7ba16183a9f9ed833b30904 +size 6695159 diff --git a/preprocessing/dataset_json/GPT4o.json b/preprocessing/dataset_json/GPT4o.json new file mode 100644 index 0000000000000000000000000000000000000000..10457999bc0ab560d411e5bf8c32ad39bcb69c66 --- /dev/null +++ b/preprocessing/dataset_json/GPT4o.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d5c880684793cc10837cb5b39e2e19cdf5c0ab460f8cb6f4214fe0f65ef8571 +size 247155 diff --git a/preprocessing/dataset_json/HPS.json b/preprocessing/dataset_json/HPS.json new file mode 100644 index 0000000000000000000000000000000000000000..0d42cc3e864419237e8dbb0dee11be13e746ed30 --- /dev/null +++ b/preprocessing/dataset_json/HPS.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8334f78d11be167a14492562af36dc45e11899d6d9e6949e51a5c8d252e8c89b +size 8968435 diff --git a/preprocessing/dataset_json/Hart.json b/preprocessing/dataset_json/Hart.json new file mode 100644 index 0000000000000000000000000000000000000000..840ad96c4f87eae6668503aa52f5d6aaa88efc68 --- /dev/null +++ b/preprocessing/dataset_json/Hart.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f69d96d5d5aed7be81ad805e113c6128d1965b275a9586ff704d72abcdd2df50 +size 4182208 diff --git a/preprocessing/dataset_json/Imagic.json b/preprocessing/dataset_json/Imagic.json new file mode 100644 index 0000000000000000000000000000000000000000..cf242d5bf60e3a17375187eaa513efbe473e93ab --- /dev/null +++ b/preprocessing/dataset_json/Imagic.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8bcfb6a163b2c9e41d9c9444208a99cd04abfa634c209604348518cad56ad3eb +size 7937325 diff --git a/preprocessing/dataset_json/Infinity.json b/preprocessing/dataset_json/Infinity.json new file mode 100644 index 0000000000000000000000000000000000000000..cd915052c5247b01db825be397aa020864c7af50 --- /dev/null +++ b/preprocessing/dataset_json/Infinity.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cd859f6e5aa9fdcb2e18bcc3a095c0ef73d9355fc3d083255291b186bbe7bbe +size 4332314 diff --git a/preprocessing/dataset_json/LoRA.json b/preprocessing/dataset_json/LoRA.json new file mode 100644 index 0000000000000000000000000000000000000000..973ef5b3375b8f45536aa0641c4efbd4ea853315 --- /dev/null +++ b/preprocessing/dataset_json/LoRA.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc22024e17c872fda7beb88c09007e2c7a78b78255248a7f58e98ba58ee58517 +size 7655888 diff --git a/preprocessing/dataset_json/MidJourney.json b/preprocessing/dataset_json/MidJourney.json new file mode 100644 index 0000000000000000000000000000000000000000..19d8ed10d77049a189bb760b060db4586d21176c --- /dev/null +++ b/preprocessing/dataset_json/MidJourney.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f9e087f7d7525fdfe756a017cc5f4b88b8ea8056954f5f5b2bbec5c99192c8b +size 767342 diff --git a/preprocessing/dataset_json/Midjourney_diff.json b/preprocessing/dataset_json/Midjourney_diff.json new file mode 100644 index 0000000000000000000000000000000000000000..aef21419da1f6c0bc4d52e1f5804f1d99059c938 --- /dev/null +++ b/preprocessing/dataset_json/Midjourney_diff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3238e3ba309dc523560756611ae404ebf3d21084211fe05064452a38b5cf6c8 +size 8330096 diff --git a/preprocessing/dataset_json/SRI.json b/preprocessing/dataset_json/SRI.json new file mode 100644 index 0000000000000000000000000000000000000000..d61f9caf48fcd697456353958ab8a71e048d4f57 --- /dev/null +++ b/preprocessing/dataset_json/SRI.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b750334121350c9c4f2e5790b1917400c26fd4af4323aec185415985993afe7 +size 1307793 diff --git a/preprocessing/dataset_json/SRI_hq.json b/preprocessing/dataset_json/SRI_hq.json new file mode 100644 index 0000000000000000000000000000000000000000..2cebb05519935f2feb9f3137526ea4980f87f0ec --- /dev/null +++ b/preprocessing/dataset_json/SRI_hq.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1f4d7b697768c8a37d4409cbd8fc2c11c3d54a37697007c79bc6830c6965a6d +size 1206968 diff --git a/preprocessing/dataset_json/abstract_dataset.py b/preprocessing/dataset_json/abstract_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..867d831722992b2622e50af6a4dcb7f96178e6d6 --- /dev/null +++ b/preprocessing/dataset_json/abstract_dataset.py @@ -0,0 +1,668 @@ +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-03-30 +# description: Abstract Base Class for all types of deepfake datasets. + +import sys + +import lmdb + +sys.path.append('.') + +import os +import math +import yaml +import glob +import json + +import numpy as np +from copy import deepcopy +import cv2 +import random +from PIL import Image +from collections import defaultdict + +import torch +from torch.autograd import Variable +from torch.utils import data +from torchvision import transforms as T + +import albumentations as A + +from .albu import IsotropicResize + +FFpp_pool=['FaceForensics++','FaceShifter','DeepFakeDetection','FF-DF','FF-F2F','FF-FS','FF-NT']# +import pdb + +def all_in_pool(inputs,pool): + for each in inputs: + if each not in pool: + return False + return True + + +class DeepfakeAbstractBaseDataset(data.Dataset): + """ + Abstract base class for all deepfake datasets. + """ + def __init__(self, config=None, mode='train'): + """Initializes the dataset object. + + Args: + config (dict): A dictionary containing configuration parameters. + mode (str): A string indicating the mode (train or test). + + Raises: + NotImplementedError: If mode is not train or test. + """ + + # Set the configuration and mode + self.config = config + self.mode = mode + self.compression = config['compression'] + self.frame_num = config['frame_num'][mode] # + + # Check if 'video_mode' exists in config, otherwise set video_level to False + self.video_level = config.get('video_mode', False) + self.clip_size = config.get('clip_size', None) + self.lmdb = config.get('lmdb', False) + # Dataset dictionary + self.image_list = [] + self.label_list = [] + + # Set the dataset dictionary based on the mode + if mode == 'train': + dataset_list = config['train_dataset'] + + # Training data should be collected together for training + image_list, label_list = [], [] + for one_data in dataset_list: + # if one_data == "ivy_fake_train": + # tmp_image, tmp_label, tmp_name = self.collect_img_and_label_for_one_dataset(one_data) + # tmp_image = list(tmp_image) + # tmp_label = list(tmp_label) + # sample_indices = random.sample(range(len(tmp_image)), 9510) + # tmp_image = [tmp_image[i] for i in sample_indices] + # tmp_label = [tmp_label[i] for i in sample_indices] + + + # if one_data == "FF-DF": + # tmp_image, tmp_label, tmp_name = self.collect_img_and_label_for_one_dataset(one_data) + # tmp_image = list(tmp_image) + # tmp_label = list(tmp_label) + # # print('ffdf') + + # sample_indices = random.sample(range(len(tmp_image)), 7937) + # tmp_image = [tmp_image[i] for i in sample_indices] + # tmp_label = [tmp_label[i] for i in sample_indices] + tmp_image, tmp_label, tmp_name = self.collect_img_and_label_for_one_dataset(one_data) + image_list.extend(tmp_image) + label_list.extend(tmp_label) + if self.lmdb: + if len(dataset_list)>1: + if all_in_pool(dataset_list,FFpp_pool): + lmdb_path = os.path.join(config['lmdb_dir'], f"FaceForensics++_lmdb") + self.env = lmdb.open(lmdb_path, create=False, subdir=True, readonly=True, lock=False) + else: + raise ValueError('Training with multiple dataset and lmdb is not implemented yet.') + else: + lmdb_path = os.path.join(config['lmdb_dir'], f"{dataset_list[0] if dataset_list[0] not in FFpp_pool else 'FaceForensics++'}_lmdb") + self.env = lmdb.open(lmdb_path, create=False, subdir=True, readonly=True, lock=False) + elif mode == 'test': + one_data = config['test_dataset'] + # Test dataset should be evaluated separately. So collect only one dataset each time + image_list, label_list, name_list = self.collect_img_and_label_for_one_dataset(one_data) + if self.lmdb: + lmdb_path = os.path.join(config['lmdb_dir'], f"{one_data}_lmdb" if one_data not in FFpp_pool else 'FaceForensics++_lmdb') + self.env = lmdb.open(lmdb_path, create=False, subdir=True, readonly=True, lock=False) + else: + raise NotImplementedError('Only train and test modes are supported.') + + assert len(image_list)!=0 and len(label_list)!=0, f"Collect nothing for {mode} mode!" + self.image_list, self.label_list = image_list, label_list + + + # Create a dictionary containing the image and label lists + self.data_dict = { + 'image': self.image_list, + 'label': self.label_list, + } + + self.transform = self.init_data_aug_method() + + def init_data_aug_method(self): + # trans = A.Compose([ + # A.HorizontalFlip(p=self.config['data_aug']['flip_prob']), + # A.Rotate(limit=self.config['data_aug']['rotate_limit'], p=self.config['data_aug']['rotate_prob']), + # A.GaussianBlur(blur_limit=self.config['data_aug']['blur_limit'], p=self.config['data_aug']['blur_prob']), + # A.OneOf([ + # IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC), + # IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR), + # IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR), + # ], p = 0 if self.config['with_landmark'] else 1), + # A.OneOf([ + # A.RandomBrightnessContrast(brightness_limit=self.config['data_aug']['brightness_limit'], contrast_limit=self.config['data_aug']['contrast_limit']), + # A.FancyPCA(), + # A.HueSaturationValue() + # ], p=0.5), + # A.ImageCompression(quality_lower=self.config['data_aug']['quality_lower'], quality_upper=self.config['data_aug']['quality_upper'], p=0.5) + # ], + # keypoint_params=A.KeypointParams(format='xy') if self.config['with_landmark'] else None + # ) + + # video no aug + trans = A.Compose([ + A.HorizontalFlip(p=0.5), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5), + A.HueSaturationValue(p=0.3), + A.ImageCompression(quality_lower=40, quality_upper=100, p=0.1), # compression: 40-100, p=0.1 + A.GaussNoise(p=0.1), + A.MotionBlur(p=0.1), + A.CLAHE(p=0.1), + A.ChannelShuffle(p=0.1), + A.Cutout(p=0.1), + A.RandomGamma(p=0.3), + A.GlassBlur(p=0.3), + ]) + + return trans + + def rescale_landmarks(self, landmarks, original_size=256, new_size=224): + scale_factor = new_size / original_size + rescaled_landmarks = landmarks * scale_factor + return rescaled_landmarks + + + def collect_img_and_label_for_one_dataset(self, dataset_name: str): + """Collects image and label lists. + + Args: + dataset_name (str): A list containing one dataset information. e.g., 'FF-F2F' + + Returns: + list: A list of image paths. + list: A list of labels. + + Raises: + ValueError: If image paths or labels are not found. + NotImplementedError: If the dataset is not implemented yet. + """ + # Initialize the label and frame path lists + label_list = [] + frame_path_list = [] + + # Record video name for video-level metrics + video_name_list = [] + + # Try to get the dataset information from the JSON file + if not os.path.exists(self.config['dataset_json_folder']): + self.config['dataset_json_folder'] = self.config['dataset_json_folder'].replace('/Youtu_Pangu_Security_Public', '/Youtu_Pangu_Security/public') + try: + with open(os.path.join(self.config['dataset_json_folder'], dataset_name + '.json'), 'r') as f: + dataset_info = json.load(f) + except Exception as e: + print(e) + raise ValueError(f'dataset {dataset_name} not exist!') + + # If JSON file exists, do the following data collection + # FIXME: ugly, need to be modified here. + cp = None + if dataset_name == 'FaceForensics++_c40': + dataset_name = 'FaceForensics++' + cp = 'c40' + elif dataset_name == 'FF-DF_c40': + dataset_name = 'FF-DF' + cp = 'c40' + elif dataset_name == 'FF-F2F_c40': + dataset_name = 'FF-F2F' + cp = 'c40' + elif dataset_name == 'FF-FS_c40': + dataset_name = 'FF-FS' + cp = 'c40' + elif dataset_name == 'FF-NT_c40': + dataset_name = 'FF-NT' + cp = 'c40' + # Get the information for the current dataset + for label in dataset_info[dataset_name]: + sub_dataset_info = dataset_info[dataset_name][label][self.mode] + # Special case for FaceForensics++ and DeepFakeDetection, choose the compression type + # NOTE + if cp == None and dataset_name in ['FF-DF', 'FF-F2F', 'FF-FS', 'FF-NT', 'FaceForensics++','DeepFakeDetection','FaceShifter','ivy_fake_train','ivy_fake_test', + 'ivy_fake_test_Deepfakes','ivy_fake_test_NeuralTextures','ivy_fake_test_FaceSwap','ivy_fake_test_Face2Face']: + sub_dataset_info = sub_dataset_info[self.compression] + elif cp == 'c40' and dataset_name in ['FF-DF', 'FF-F2F', 'FF-FS', 'FF-NT', 'FaceForensics++','DeepFakeDetection','FaceShifter']: + sub_dataset_info = sub_dataset_info['c40'] + + # Iterate over the videos in the dataset + + for video_name, video_info in sub_dataset_info.items(): + # Unique video name + + unique_video_name = video_info['label'] + '_' + video_name + + # Get the label and frame paths for the current video + if video_info['label'] not in self.config['label_dict']: + raise ValueError(f'Label {video_info["label"]} is not found in the configuration file.') + label = self.config['label_dict'][video_info['label']] + frame_paths = video_info['frames'] + # sorted video path to the lists + if '\\' in frame_paths[0]: + frame_paths = sorted(frame_paths, key=lambda x: int(x.split('\\')[-1].split('.')[0])) + else: + frame_paths = sorted(frame_paths, key=lambda x: int(x.split('/')[-1].split('.')[0])) + + # Consider the case when the actual number of frames (e.g., 270) is larger than the specified (i.e., self.frame_num=32) + # In this case, we select self.frame_num frames from the original 270 frames + total_frames = len(frame_paths) + if self.frame_num < total_frames: + total_frames = self.frame_num + if self.video_level: + # Select clip_size continuous frames + start_frame = random.randint(0, total_frames - self.frame_num) if self.mode == 'train' else 0 + frame_paths = frame_paths[start_frame:start_frame + self.frame_num] # update total_frames + else: + # Select self.frame_num frames evenly distributed throughout the video + step = total_frames // self.frame_num + frame_paths = [frame_paths[i] for i in range(0, total_frames, step)][:self.frame_num] + + # If video-level methods, crop clips from the selected frames if needed + if self.video_level: + if self.clip_size is None: + raise ValueError('clip_size must be specified when video_level is True.') + # Check if the number of total frames is greater than or equal to clip_size + if total_frames >= self.clip_size: + # Initialize an empty list to store the selected continuous frames + selected_clips = [] + + # Calculate the number of clips to select + num_clips = total_frames // self.clip_size + + if num_clips > 1: + # Calculate the step size between each clip + clip_step = (total_frames - self.clip_size) // (num_clips - 1) + + # Select clip_size continuous frames from each part of the video + for i in range(num_clips): + # Ensure start_frame + self.clip_size - 1 does not exceed the index of the last frame + start_frame = random.randrange(i * clip_step, min((i + 1) * clip_step, total_frames - self.clip_size + 1)) if self.mode == 'train' else i * clip_step + continuous_frames = frame_paths[start_frame:start_frame + self.clip_size] + assert len(continuous_frames) == self.clip_size, 'clip_size is not equal to the length of frame_path_list' + selected_clips.append(continuous_frames) + + else: + start_frame = random.randrange(0, total_frames - self.clip_size + 1) if self.mode == 'train' else 0 + continuous_frames = frame_paths[start_frame:start_frame + self.clip_size] + assert len(continuous_frames)==self.clip_size, 'clip_size is not equal to the length of frame_path_list' + selected_clips.append(continuous_frames) + + # Append the list of selected clips and append the label + label_list.extend([label] * len(selected_clips)) + frame_path_list.extend(selected_clips) + # video name save + video_name_list.extend([unique_video_name] * len(selected_clips)) + + else: + print(f"Skipping video {unique_video_name} because it has less than clip_size ({self.clip_size}) frames ({total_frames}).") + + # Otherwise, extend the label and frame paths to the lists according to the number of frames + else: + # Extend the label and frame paths to the lists according to the number of frames + label_list.extend([label] * total_frames) + frame_path_list.extend(frame_paths) + # video name save + video_name_list.extend([unique_video_name] * len(frame_paths)) + + # Shuffle the label and frame path lists in the same order + shuffled = list(zip(label_list, frame_path_list, video_name_list)) + random.shuffle(shuffled) + label_list, frame_path_list, video_name_list = zip(*shuffled) + + return frame_path_list, label_list, video_name_list + + + def load_rgb(self, file_path): + """ + Load an RGB image from a file path and resize it to a specified resolution. + + Args: + file_path: A string indicating the path to the image file. + + Returns: + An Image object containing the loaded and resized image. + + Raises: + ValueError: If the loaded image is None. + """ + size = self.config['resolution'] # if self.mode == "train" else self.config['resolution'] + if not self.lmdb: + # if not file_path[0] == '.': + # file_path = f'./{self.config["rgb_dir"]}\\'+file_path + if not os.path.exists(file_path): + file_path = file_path.replace('\\', '/') + assert os.path.exists(file_path), f"{file_path} does not exist" + img = cv2.imread(file_path) + if img is None: + raise ValueError('Loaded image is None: {}'.format(file_path)) + elif self.lmdb: + with self.env.begin(write=False) as txn: + # transfer the path format from rgb-path to lmdb-key + if file_path[0]=='.': + file_path=file_path.replace('./datasets\\','') + + image_bin = txn.get(file_path.encode()) + image_buf = np.frombuffer(image_bin, dtype=np.uint8) + img = cv2.imdecode(image_buf, cv2.IMREAD_COLOR) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) + return Image.fromarray(np.array(img, dtype=np.uint8)) + + + def load_mask(self, file_path): + """ + Load a binary mask image from a file path and resize it to a specified resolution. + + Args: + file_path: A string indicating the path to the mask file. + + Returns: + A numpy array containing the loaded and resized mask. + + Raises: + None. + """ + size = self.config['resolution'] + if file_path is None: + return np.zeros((size, size, 1)) + if not self.lmdb: + # if not file_path[0] == '.': + # file_path = f'./{self.config["rgb_dir"]}\\'+file_path + if os.path.exists(file_path): + mask = cv2.imread(file_path, 0) + if mask is None: + mask = np.zeros((size, size)) + else: + return np.zeros((size, size, 1)) + else: + with self.env.begin(write=False) as txn: + # transfer the path format from rgb-path to lmdb-key + if file_path[0]=='.': + file_path=file_path.replace('./datasets\\','') + + image_bin = txn.get(file_path.encode()) + if image_bin is None: + mask = np.zeros((size, size,3)) + else: + image_buf = np.frombuffer(image_bin, dtype=np.uint8) + mask = cv2.imdecode(image_buf, cv2.IMREAD_COLOR) + mask = cv2.resize(mask, (size, size)) / 255 + mask = np.expand_dims(mask, axis=2) + return np.float32(mask) + + def load_landmark(self, file_path): + """ + Load 2D facial landmarks from a file path. + + Args: + file_path: A string indicating the path to the landmark file. + + Returns: + A numpy array containing the loaded landmarks. + + Raises: + None. + """ + if file_path is None: + return np.zeros((81, 2)) + if not self.lmdb: + # if not file_path[0] == '.': + # file_path = f'./{self.config["rgb_dir"]}\\'+file_path + if os.path.exists(file_path): + landmark = np.load(file_path) + else: + return np.zeros((81, 2)) + else: + with self.env.begin(write=False) as txn: + # transfer the path format from rgb-path to lmdb-key + if file_path[0]=='.': + file_path=file_path.replace('./datasets\\','') + binary = txn.get(file_path.encode()) + landmark = np.frombuffer(binary, dtype=np.uint32).reshape((81, 2)) + landmark=self.rescale_landmarks(np.float32(landmark), original_size=256, new_size=self.config['resolution']) + return landmark + + def to_tensor(self, img): + """ + Convert an image to a PyTorch tensor. + """ + return T.ToTensor()(img) + + def normalize(self, img): + """ + Normalize an image. + """ + mean = self.config['mean'] + std = self.config['std'] + normalize = T.Normalize(mean=mean, std=std) + return normalize(img) + + def data_aug(self, img, landmark=None, mask=None, augmentation_seed=None): + """ + Apply data augmentation to an image, landmark, and mask. + + Args: + img: An Image object containing the image to be augmented. + landmark: A numpy array containing the 2D facial landmarks to be augmented. + mask: A numpy array containing the binary mask to be augmented. + + Returns: + The augmented image, landmark, and mask. + """ + + # Set the seed for the random number generator + if augmentation_seed is not None: + random.seed(augmentation_seed) + np.random.seed(augmentation_seed) + + # Create a dictionary of arguments + kwargs = {'image': img} + + # Check if the landmark and mask are not None + if landmark is not None: + kwargs['keypoints'] = landmark + kwargs['keypoint_params'] = A.KeypointParams(format='xy') + if mask is not None: + mask = mask.squeeze(2) + if mask.max() > 0: + kwargs['mask'] = mask + + # Apply data augmentation + transformed = self.transform(**kwargs) + + # Get the augmented image, landmark, and mask + # NOTE + # augmented_img = transformed['image'] + augmented_img = kwargs['image'] + augmented_landmark = transformed.get('keypoints') + augmented_mask = transformed.get('mask',mask) + + # Convert the augmented landmark to a numpy array + if augmented_landmark is not None: + augmented_landmark = np.array(augmented_landmark) + + # Reset the seeds to ensure different transformations for different videos + if augmentation_seed is not None: + random.seed() + np.random.seed() + + return augmented_img, augmented_landmark, augmented_mask + + def __getitem__(self, index, no_norm=False): + """ + Returns the data point at the given index. + + Args: + index (int): The index of the data point. + + Returns: + A tuple containing the image tensor, the label tensor, the landmark tensor, + and the mask tensor. + """ + # Get the image paths and label + image_paths = self.data_dict['image'][index] + label = self.data_dict['label'][index] + + # Image-level: FaceForensics++\manipulated_sequences\NeuralTextures\c23\frames\487_477\000.png + # Video-level: image_paths ['FaceForensics++\\original_sequences\\youtube\\c23\\frames\\977\\000.png', ..., 'FaceForensics++\\original_sequences\\youtube\\c23\\frames\\977\\314.png'] + if not isinstance(image_paths, list): + image_paths = [image_paths] # for the image-level IO, only one frame is used + + image_tensors = [] + landmark_tensors = [] + mask_tensors = [] + augmentation_seed = None + + for image_path in image_paths: + # Initialize a new seed for data augmentation at the start of each video + if self.video_level and image_path == image_paths[0]: + augmentation_seed = random.randint(0, 2**32 - 1) + + # Get the mask and landmark paths + mask_path = image_path.replace('frames', 'masks') # Use .png for mask + landmark_path = image_path.replace('frames', 'landmarks').replace('.png', '.npy') # Use .npy for landmark + + # Load the image + try: + image = self.load_rgb(image_path) + except Exception as e: + # Skip this image and return the first one + print(f"Error loading image at index {index}: {e}") + return self.__getitem__(0) + image = np.array(image) # Convert to numpy array for data augmentation + + # Load mask and landmark (if needed) + if self.config['with_mask']: + mask = self.load_mask(mask_path) + else: + mask = None + if self.config['with_landmark']: + landmarks = self.load_landmark(landmark_path) + else: + landmarks = None + + # Do Data Augmentation + if self.mode == 'train' and self.config['use_data_augmentation']: + image_trans, landmarks_trans, mask_trans = self.data_aug(image, landmarks, mask, augmentation_seed) + else: + # if self.mode == 'train': + # print("Train w/o data_augmentation") + image_trans, landmarks_trans, mask_trans = deepcopy(image), deepcopy(landmarks), deepcopy(mask) + + + # To tensor and normalize + if not no_norm: + image_trans = self.normalize(self.to_tensor(image_trans)) + if self.config['with_landmark']: + landmarks_trans = torch.from_numpy(landmarks) + if self.config['with_mask']: + mask_trans = torch.from_numpy(mask_trans) + + image_tensors.append(image_trans) + landmark_tensors.append(landmarks_trans) + mask_tensors.append(mask_trans) + + if self.video_level: + # Stack image tensors along a new dimension (time) + image_tensors = torch.stack(image_tensors, dim=0) + # Stack landmark and mask tensors along a new dimension (time) + if not any(landmark is None or (isinstance(landmark, list) and None in landmark) for landmark in landmark_tensors): + landmark_tensors = torch.stack(landmark_tensors, dim=0) + if not any(m is None or (isinstance(m, list) and None in m) for m in mask_tensors): + mask_tensors = torch.stack(mask_tensors, dim=0) + else: + # Get the first image tensor + image_tensors = image_tensors[0] + # Get the first landmark and mask tensors + if not any(landmark is None or (isinstance(landmark, list) and None in landmark) for landmark in landmark_tensors): + landmark_tensors = landmark_tensors[0] + if not any(m is None or (isinstance(m, list) and None in m) for m in mask_tensors): + mask_tensors = mask_tensors[0] + + return image_tensors, label, landmark_tensors, mask_tensors + + @staticmethod + def collate_fn(batch): + """ + Collate a batch of data points. + + Args: + batch (list): A list of tuples containing the image tensor, the label tensor, + the landmark tensor, and the mask tensor. + + Returns: + A tuple containing the image tensor, the label tensor, the landmark tensor, + and the mask tensor. + """ + # Separate the image, label, landmark, and mask tensors + images, labels, landmarks, masks = zip(*batch) + + # Stack the image, label, landmark, and mask tensors + images = torch.stack(images, dim=0) + labels = torch.LongTensor(labels) + + # Special case for landmarks and masks if they are None + if not any(landmark is None or (isinstance(landmark, list) and None in landmark) for landmark in landmarks): + landmarks = torch.stack(landmarks, dim=0) + else: + landmarks = None + + if not any(m is None or (isinstance(m, list) and None in m) for m in masks): + masks = torch.stack(masks, dim=0) + else: + masks = None + + # Create a dictionary of the tensors + data_dict = {} + data_dict['image'] = images + data_dict['label'] = labels + data_dict['landmark'] = landmarks + data_dict['mask'] = masks + return data_dict + + def __len__(self): + """ + Return the length of the dataset. + + Args: + None. + + Returns: + An integer indicating the length of the dataset. + + Raises: + AssertionError: If the number of images and labels in the dataset are not equal. + """ + assert len(self.image_list) == len(self.label_list), 'Number of images and labels are not equal' + return len(self.image_list) + + +if __name__ == "__main__": + with open('/data/home/zhiyuanyan/DeepfakeBench/training/config/detector/video_baseline.yaml', 'r') as f: + config = yaml.safe_load(f) + train_set = DeepfakeAbstractBaseDataset( + config = config, + mode = 'train', + ) + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=config['train_batchSize'], + shuffle=True, + num_workers=0, + collate_fn=train_set.collate_fn, + ) + from tqdm import tqdm + for iteration, batch in enumerate(tqdm(train_data_loader)): + # print(iteration) + ... + # if iteration > 10: + # break diff --git a/preprocessing/dataset_json/gpa.json b/preprocessing/dataset_json/gpa.json new file mode 100644 index 0000000000000000000000000000000000000000..19a5faa650104bdd94dda7547d6c700ea060334c --- /dev/null +++ b/preprocessing/dataset_json/gpa.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3fb111163a61bdb552b4354897adddb63e73a15153cd5f51c4b76b3c226a5e9c +size 4382973 diff --git a/preprocessing/dataset_json/heygen.json b/preprocessing/dataset_json/heygen.json new file mode 100644 index 0000000000000000000000000000000000000000..776990cd5b67345be821a203c09b687e6cb54a9e --- /dev/null +++ b/preprocessing/dataset_json/heygen.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3bb0f12c3d3fc0d045056b8b04d9f158275dde13c651343a667099f869c8bc96 +size 1681524 diff --git a/preprocessing/dataset_json/others/Chameleon.json b/preprocessing/dataset_json/others/Chameleon.json new file mode 100644 index 0000000000000000000000000000000000000000..f5bf9f1f83fe80c473683fde4b33ec9fef2a03bc --- /dev/null +++ b/preprocessing/dataset_json/others/Chameleon.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a2a154223e90dfc697d094f6a7f44888e22c39e1241be584994e099519970ea4 +size 12495937 diff --git a/preprocessing/dataset_json/others/CoDiff.json b/preprocessing/dataset_json/others/CoDiff.json new file mode 100644 index 0000000000000000000000000000000000000000..1b976604b1dfcb374b59329f832e5e398b428bcf --- /dev/null +++ b/preprocessing/dataset_json/others/CoDiff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d2255d70cec56cbfacec704cbbda4b240d4c110835b3454c82873fb04f2f8ad9 +size 7851178 diff --git a/preprocessing/dataset_json/others/CollabDiff.json b/preprocessing/dataset_json/others/CollabDiff.json new file mode 100644 index 0000000000000000000000000000000000000000..8401a3bc074e7a182f38884b50a53e54639d1f07 --- /dev/null +++ b/preprocessing/dataset_json/others/CollabDiff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:564cf3e540fa21664e2711ae75013343281c880e5584bc2e8d5c6fc06d15990f +size 473682 diff --git a/preprocessing/dataset_json/others/DCFace.json b/preprocessing/dataset_json/others/DCFace.json new file mode 100644 index 0000000000000000000000000000000000000000..1ec46a7b3d684e81dffc81266a54dbd929cd77d1 --- /dev/null +++ b/preprocessing/dataset_json/others/DCFace.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6622e0a0dd8c552ab2b233c394848053a1b20fa461fea8e74ab99de5bcb87d32 +size 8019871 diff --git a/preprocessing/dataset_json/others/DeeperForensics-1.0.json b/preprocessing/dataset_json/others/DeeperForensics-1.0.json new file mode 100644 index 0000000000000000000000000000000000000000..c31c800493663f1d0914d25f28bda0435cf23078 --- /dev/null +++ b/preprocessing/dataset_json/others/DeeperForensics-1.0.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c3d582eae7754cb005a687d607e25861585247a6a81908364f32eab85d6af66 +size 2606195 diff --git a/preprocessing/dataset_json/others/DiT_cdf.json b/preprocessing/dataset_json/others/DiT_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..331287937663cea4153bc32112a39547d2baf286 --- /dev/null +++ b/preprocessing/dataset_json/others/DiT_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb15e8b0b4bd900d6251b39f785145bc628a2e148bc7d5791db06fdc0c8b4492 +size 5204091 diff --git a/preprocessing/dataset_json/others/DiT_ff.json b/preprocessing/dataset_json/others/DiT_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..36ef208e9bea9821fa0daff0a0d3ab80f2776bf3 --- /dev/null +++ b/preprocessing/dataset_json/others/DiT_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1725104dfcc9b459b1e671a16bcc37ca1c1fd02cdad86aa2c1361edb47629c4 +size 4677394 diff --git a/preprocessing/dataset_json/others/EFSAll_cdf.json b/preprocessing/dataset_json/others/EFSAll_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..08c59c89a3d10824970d6c28bcfaf59cf8d37f1f --- /dev/null +++ b/preprocessing/dataset_json/others/EFSAll_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae2487011b7dd66954830da3f84208ac280772159c950f50446ec6114db96e82 +size 55801088 diff --git a/preprocessing/dataset_json/others/EFSAll_ff.json b/preprocessing/dataset_json/others/EFSAll_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..345725f15252ffe2d82a9be4a79d618afa92500e --- /dev/null +++ b/preprocessing/dataset_json/others/EFSAll_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56daa5d2d011ad0bcef828979a11786ab0599f63063aea7c61dbd208f6e053ce +size 48399606 diff --git a/preprocessing/dataset_json/others/FRAll_cdf.json b/preprocessing/dataset_json/others/FRAll_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..4606f17242a367f345925c71757aaf3d9f44a915 --- /dev/null +++ b/preprocessing/dataset_json/others/FRAll_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1db14cdd1e09a50eddeed1260fdc799e2370a5f8944b755d4dbe3cf640a79dfd +size 51106195 diff --git a/preprocessing/dataset_json/others/FRAll_ff.json b/preprocessing/dataset_json/others/FRAll_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..62e5cfcce85b0d0885eb57eb972958a070fd305e --- /dev/null +++ b/preprocessing/dataset_json/others/FRAll_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9cbca5ced198c1f795a7a304c705d2fe4af5006adc41cf7f898f6f6c3de74624 +size 67347051 diff --git a/preprocessing/dataset_json/others/FSAll_cdf.json b/preprocessing/dataset_json/others/FSAll_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..5c6c938e96b67b3d2afa87efc68ecdf343185447 --- /dev/null +++ b/preprocessing/dataset_json/others/FSAll_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73939286d610f0d2da4cc589285690762ce070f77beb2b2cfa2b8068126f7efb +size 34632459 diff --git a/preprocessing/dataset_json/others/FSAll_ff.json b/preprocessing/dataset_json/others/FSAll_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..4cec80253b61ab8e6a298112f89ed12d5f15433b --- /dev/null +++ b/preprocessing/dataset_json/others/FSAll_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ffaac2b1110f1616950364320eef1fd75e60e08273c73e1600413c4aba99fe05 +size 53468909 diff --git a/preprocessing/dataset_json/others/FaceForensics++_vae.json b/preprocessing/dataset_json/others/FaceForensics++_vae.json new file mode 100644 index 0000000000000000000000000000000000000000..273969876bad751bcd2dc5e35eabb41f5e11922f --- /dev/null +++ b/preprocessing/dataset_json/others/FaceForensics++_vae.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a55d9e412945fbf17cd1f1f9c2fbabb43074541ba8707772c343821d108b39a +size 22330288 diff --git a/preprocessing/dataset_json/others/FreeDoM_I.json b/preprocessing/dataset_json/others/FreeDoM_I.json new file mode 100644 index 0000000000000000000000000000000000000000..86ea7a6d9554f417ff816b2c11c6cd8fd2f94ec7 --- /dev/null +++ b/preprocessing/dataset_json/others/FreeDoM_I.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b0b820bb3074d176ce55feabc88effe1d60c1eb9d8c4207916fbe98eb239c0c7 +size 7371858 diff --git a/preprocessing/dataset_json/others/FreeDoM_T.json b/preprocessing/dataset_json/others/FreeDoM_T.json new file mode 100644 index 0000000000000000000000000000000000000000..d68dcd2efe515a553876582cfe7805382564239c --- /dev/null +++ b/preprocessing/dataset_json/others/FreeDoM_T.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:868d8666710d5a52172e20b66b2ff4831fb8371867d91f3ed066af6326cb97e6 +size 7445199 diff --git a/preprocessing/dataset_json/others/MRAA_cdf.json b/preprocessing/dataset_json/others/MRAA_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..e50f57710a2829b85e2eaec2e726276dd1b6b79d --- /dev/null +++ b/preprocessing/dataset_json/others/MRAA_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4aa412d4c096609acf6474322137de6d1723e765002d7a52282e1a8f7f217a9e +size 4030920 diff --git a/preprocessing/dataset_json/others/MRAA_ff.json b/preprocessing/dataset_json/others/MRAA_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..1ab2d308dbaa6641814267ae9d54343c3fdb8273 --- /dev/null +++ b/preprocessing/dataset_json/others/MRAA_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:057ba54e1573ea011a07f9fd1713adac01a96dc822e27ce1acba1d1d3730843b +size 5454366 diff --git a/preprocessing/dataset_json/others/SDXL.json b/preprocessing/dataset_json/others/SDXL.json new file mode 100644 index 0000000000000000000000000000000000000000..39564b963aff3f5b566a96c73e8826f2f61aaec9 --- /dev/null +++ b/preprocessing/dataset_json/others/SDXL.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:beb9ca562ef02ee5ef292bdcc46683c4a6fd42834ae6ad3680b0c4257ca87675 +size 8257844 diff --git a/preprocessing/dataset_json/others/SDXL_Refine.json b/preprocessing/dataset_json/others/SDXL_Refine.json new file mode 100644 index 0000000000000000000000000000000000000000..aeb60e95a310d97bc6767439ec714b43fe22f964 --- /dev/null +++ b/preprocessing/dataset_json/others/SDXL_Refine.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d739eb33786955d6a894c803fc6da0aefc275d78877f8c109718992c32eb9e21 +size 8324707 diff --git a/preprocessing/dataset_json/others/SiT_cdf.json b/preprocessing/dataset_json/others/SiT_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..da5965cfed478781d7dbfc0dd78ca7b7ac07a4ea --- /dev/null +++ b/preprocessing/dataset_json/others/SiT_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b035b250609c64012b99fc9efc12e2ae133502dfb5bee52a4cab6bbf50709c00 +size 5204091 diff --git a/preprocessing/dataset_json/others/SiT_ff.json b/preprocessing/dataset_json/others/SiT_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..8640726cefc2e7af4306330df716f5a8c832694e --- /dev/null +++ b/preprocessing/dataset_json/others/SiT_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e4d46c1ab82ab49fb42e486e6d2b438c829a867864ae337b58fc382862ef736a +size 4677394 diff --git a/preprocessing/dataset_json/others/StyleGAN2_cdf.json b/preprocessing/dataset_json/others/StyleGAN2_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..2d4087357bb8939b40b534988af139c620005cac --- /dev/null +++ b/preprocessing/dataset_json/others/StyleGAN2_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:624a834d28c00ff18fcd4e7fb3b36c044eac6ff5fc5e89a140042bf40b0125e7 +size 5782601 diff --git a/preprocessing/dataset_json/others/StyleGAN2_ff.json b/preprocessing/dataset_json/others/StyleGAN2_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..0a72d356c1e62300dbc74f858e39ce6f00262fbd --- /dev/null +++ b/preprocessing/dataset_json/others/StyleGAN2_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52bc8aae5064965ec20d5dbd34c17d7fa280dd5a47a4465d0b26e7cf92524321 +size 4944954 diff --git a/preprocessing/dataset_json/others/StyleGAN3_cdf.json b/preprocessing/dataset_json/others/StyleGAN3_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..31ac16f0ad5b52f49beab655f39ecf4fd3fea934 --- /dev/null +++ b/preprocessing/dataset_json/others/StyleGAN3_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b07a2022eeece869c748d487c2c67c22053196e147c97c88589dd23882eb5a00 +size 5782601 diff --git a/preprocessing/dataset_json/others/StyleGAN3_ff.json b/preprocessing/dataset_json/others/StyleGAN3_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..7c87518e3949fc5c25a6606e916f0ce74f6a828b --- /dev/null +++ b/preprocessing/dataset_json/others/StyleGAN3_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4de58934104a9ce3d956aad6e8b15ef43099a30f8fd03de45a89428d7d1f9ba0 +size 4944954 diff --git a/preprocessing/dataset_json/others/StyleGANXL_cdf.json b/preprocessing/dataset_json/others/StyleGANXL_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..122460b23da834cf070d5cb60d0e76834261ed60 --- /dev/null +++ b/preprocessing/dataset_json/others/StyleGANXL_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eaf0bb3a419691ca358e276c35c54df8d05d49a3492b5025069c36f15dbe7dc3 +size 5839910 diff --git a/preprocessing/dataset_json/others/StyleGANXL_ff.json b/preprocessing/dataset_json/others/StyleGANXL_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..f80b7f1383259175eaea338d926614c27e4750eb --- /dev/null +++ b/preprocessing/dataset_json/others/StyleGANXL_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90b8446dafdf6275922f7675efc7fe1bf7b31dd133c0af736fa91b2387449218 +size 4972417 diff --git a/preprocessing/dataset_json/others/VQGAN_cdf.json b/preprocessing/dataset_json/others/VQGAN_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..96b299892eb32fe29b82968c95ccf683f618c47f --- /dev/null +++ b/preprocessing/dataset_json/others/VQGAN_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1dda89979d4d6d0266248e0160665e33543af98657e3f760b9b72dea2c4d56a4 +size 5399625 diff --git a/preprocessing/dataset_json/others/VQGAN_ff.json b/preprocessing/dataset_json/others/VQGAN_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..1a8ea14b7fe9b05f94f113829679c267f830ff98 --- /dev/null +++ b/preprocessing/dataset_json/others/VQGAN_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e05bbf40a8f44bb4c2a7cea8efc96a4a293b35a1347c835a1afd235ffe2022e +size 4771000 diff --git a/preprocessing/dataset_json/others/adm.json b/preprocessing/dataset_json/others/adm.json new file mode 100644 index 0000000000000000000000000000000000000000..0324fc8ef50fb1fd491decee56a0edb9c26d0c0f --- /dev/null +++ b/preprocessing/dataset_json/others/adm.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d845c00046e6624f8470389bdc00c7ad7f65c801f981addd51880097040d966e +size 70928395 diff --git a/preprocessing/dataset_json/others/biggan.json b/preprocessing/dataset_json/others/biggan.json new file mode 100644 index 0000000000000000000000000000000000000000..75827051f6e83d6c7b417da35353cd438ca1db2f --- /dev/null +++ b/preprocessing/dataset_json/others/biggan.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cbcab5cefbd7446feab3c93ec401d7c034e111b011edd5f34b06a463b0b47913 +size 75962300 diff --git a/preprocessing/dataset_json/others/blendface_cdf.json b/preprocessing/dataset_json/others/blendface_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..4d9b63c7c926d99d00316f2fc8e036e6b6b86378 --- /dev/null +++ b/preprocessing/dataset_json/others/blendface_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0056ec6189c29f92078cb458dabf8e7ac2ab5bc9c3c5d9e20ac157b0bc1768dc +size 4248757 diff --git a/preprocessing/dataset_json/others/blendface_ff.json b/preprocessing/dataset_json/others/blendface_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..eef77c7fa98bc74538e350cc85bf2234482e701a --- /dev/null +++ b/preprocessing/dataset_json/others/blendface_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3968f25a003f90f17fd13e9b4470f0476a5fc3e5500fdb03f4f39b6d7152f831 +size 5608686 diff --git a/preprocessing/dataset_json/others/cycle_diff.json b/preprocessing/dataset_json/others/cycle_diff.json new file mode 100644 index 0000000000000000000000000000000000000000..1039b9d06516a5ca84f4457613330222eb9ae28d --- /dev/null +++ b/preprocessing/dataset_json/others/cycle_diff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:711dcb423b6d3369984690f7bebb5a9b2bf708994907d91bac31f22213af8646 +size 8802711 diff --git a/preprocessing/dataset_json/others/dalle2.json b/preprocessing/dataset_json/others/dalle2.json new file mode 100644 index 0000000000000000000000000000000000000000..72b81ba72aa54fbd493402b42926b8a2e7f43c70 --- /dev/null +++ b/preprocessing/dataset_json/others/dalle2.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c2b376deea5a41cbf50326622ba0fd92ecdd636c473aa46474b0a6cd00562f7 +size 489450 diff --git a/preprocessing/dataset_json/others/danet_cdf.json b/preprocessing/dataset_json/others/danet_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..fb6380bf1383fdfb82d0a28d45cd718084e2cf49 --- /dev/null +++ b/preprocessing/dataset_json/others/danet_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1064ef3f33e56194af5288108eb33526b4446dd5c4f9ec35837b734f65dc067 +size 3926523 diff --git a/preprocessing/dataset_json/others/danet_ff.json b/preprocessing/dataset_json/others/danet_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..a47b228c6dfdaea0de714fc3621aeac9129b67f0 --- /dev/null +++ b/preprocessing/dataset_json/others/danet_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0932362d828570309cd54c8799d12a85973d9fab96bc434c4e03ee7e46354f10 +size 5476819 diff --git a/preprocessing/dataset_json/others/ddim_cdf.json b/preprocessing/dataset_json/others/ddim_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..636693529806c6a0ca93023d4a72f893b6523274 --- /dev/null +++ b/preprocessing/dataset_json/others/ddim_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f834231ae15d00eab9f5ce48dae79e70f1384675156263fa0935f4c05d2c54c0 +size 6310932 diff --git a/preprocessing/dataset_json/others/ddim_ff.json b/preprocessing/dataset_json/others/ddim_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..a82cc576f01f0216ff6ba1772f3f29e3d4039115 --- /dev/null +++ b/preprocessing/dataset_json/others/ddim_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0635dbf81dde7293ab85c8e10bb13788b19570dce8d3f6e3cdf5849583ef5b6b +size 4769167 diff --git a/preprocessing/dataset_json/others/deepfacelab.json b/preprocessing/dataset_json/others/deepfacelab.json new file mode 100644 index 0000000000000000000000000000000000000000..728acd842fac6ba040d793f4212b3bbacda16aa5 --- /dev/null +++ b/preprocessing/dataset_json/others/deepfacelab.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66df7d43a6fa7356c099743d94d216c8e4ad8ef7794a6ec1382fbec744441f01 +size 738691 diff --git a/preprocessing/dataset_json/others/e4e_cdf_old.json b/preprocessing/dataset_json/others/e4e_cdf_old.json new file mode 100644 index 0000000000000000000000000000000000000000..abbc36c2884e8d9323b7fad1869bd2629e0e3822 --- /dev/null +++ b/preprocessing/dataset_json/others/e4e_cdf_old.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa35d58da4320cbf723e40f34884d8b917ff3d78fc45e894a624195d27c02d2b +size 4160529 diff --git a/preprocessing/dataset_json/others/e4e_ff.json b/preprocessing/dataset_json/others/e4e_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..c3bd9e6d81002a7bb0308495869a2d63b61e44b0 --- /dev/null +++ b/preprocessing/dataset_json/others/e4e_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64acd5f8af8b35fe4e008351c493017060e7755a715998c6a932ca96fac27012 +size 6281120 diff --git a/preprocessing/dataset_json/others/e4e_ff_old.json b/preprocessing/dataset_json/others/e4e_ff_old.json new file mode 100644 index 0000000000000000000000000000000000000000..cf40c563984a619a39efb9c34864d38a805d2cf7 --- /dev/null +++ b/preprocessing/dataset_json/others/e4e_ff_old.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cba5d5007b0b78c47ed559f7e3642942e8c6a87cc5b51fe25fa0089f708dae85 +size 15084316 diff --git a/preprocessing/dataset_json/others/e4e_known.json b/preprocessing/dataset_json/others/e4e_known.json new file mode 100644 index 0000000000000000000000000000000000000000..18a430406ca77239b927aac59cae0a2f3233b09f --- /dev/null +++ b/preprocessing/dataset_json/others/e4e_known.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa83f53006c8bf5596848f1468d7bc5453fe02b091e77c3fb9d8363ec17f8fa8 +size 14111021 diff --git a/preprocessing/dataset_json/others/e4e_old.json b/preprocessing/dataset_json/others/e4e_old.json new file mode 100644 index 0000000000000000000000000000000000000000..8b6beba1132acc318c4d88ce069132fd3b78526d --- /dev/null +++ b/preprocessing/dataset_json/others/e4e_old.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a449090f651dfbbaa8b22e67c186d45e333d01f8c4919b55e9aa1f2fd1a33ff1 +size 7376543 diff --git a/preprocessing/dataset_json/others/e4s_cdf.json b/preprocessing/dataset_json/others/e4s_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..3a80f2a14fb4083a049e43012d4515089f075f19 --- /dev/null +++ b/preprocessing/dataset_json/others/e4s_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0251567863e16988c69b4086f2ea5702c63efb00ca4de5e364cc8efd615f399 +size 2302985 diff --git a/preprocessing/dataset_json/others/e4s_ff.json b/preprocessing/dataset_json/others/e4s_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..bdb4f2929be6ed19230e9be86570493cb6b75027 --- /dev/null +++ b/preprocessing/dataset_json/others/e4s_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9501bf9dc4b2c00dbb74e93ae6133e9eed2c8c43382eab56140d62a77e73f70 +size 5308400 diff --git a/preprocessing/dataset_json/others/facedancer_cdf.json b/preprocessing/dataset_json/others/facedancer_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..acbf100c937306878542abeede4877b963db2929 --- /dev/null +++ b/preprocessing/dataset_json/others/facedancer_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e737e1c8de71c9e77bf9cea146155bfcd825865fdeee6d5b5f2b822c207527d +size 4285468 diff --git a/preprocessing/dataset_json/others/facedancer_ff.json b/preprocessing/dataset_json/others/facedancer_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..5a672a00031259029e5f52d5f54a48254276165e --- /dev/null +++ b/preprocessing/dataset_json/others/facedancer_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:835058f61a618ceb8152198b5388b16a528224410dcdfd3c031ec13bcce7ec38 +size 5660578 diff --git a/preprocessing/dataset_json/others/faceswap_cdf.json b/preprocessing/dataset_json/others/faceswap_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..67fe39a00d30c749966a40e0105b904ff0acc96c --- /dev/null +++ b/preprocessing/dataset_json/others/faceswap_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6101863766c16c7589a7faa2ab29cf2fe28c6fc3ecf9e44af8301996dfb3b474 +size 4182092 diff --git a/preprocessing/dataset_json/others/faceswap_ff.json b/preprocessing/dataset_json/others/faceswap_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..fa857130204255356870edc4bbb1594f4507971e --- /dev/null +++ b/preprocessing/dataset_json/others/faceswap_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d34f307b43e8cba553ede09b7bde86c7eece7a2c55cc60dc94fc4bb0c70d6e7 +size 5598718 diff --git a/preprocessing/dataset_json/others/facevid2vid_cdf.json b/preprocessing/dataset_json/others/facevid2vid_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..07e59da3027a6dd170dee5d72764b8f678dc2d3e --- /dev/null +++ b/preprocessing/dataset_json/others/facevid2vid_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7801026e94031c860cb4cb7537a3a6eb06eafcd39845918228c226348c50f3ff +size 4261519 diff --git a/preprocessing/dataset_json/others/facevid2vid_ff.json b/preprocessing/dataset_json/others/facevid2vid_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..2087d717a078221cee8aceede1dd49ec4a5e51a6 --- /dev/null +++ b/preprocessing/dataset_json/others/facevid2vid_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c2f974b8025527b6f88b7f6db170c2710907fa1971e30929d93292f5f5a56fe4 +size 5702469 diff --git a/preprocessing/dataset_json/others/fomm_cdf.json b/preprocessing/dataset_json/others/fomm_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..f8d42ef9d13a7057f6dad7dd2573d493472da113 --- /dev/null +++ b/preprocessing/dataset_json/others/fomm_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d8447b58ea8d23686c12a9a1b7a0b1813cfb0e0447ccb06ecf919acfac46d017 +size 4036914 diff --git a/preprocessing/dataset_json/others/fomm_ff.json b/preprocessing/dataset_json/others/fomm_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..aadfc22a1da0152958fbca5e1c35e4ebdea29336 --- /dev/null +++ b/preprocessing/dataset_json/others/fomm_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de50cd925ed05cad7688dcbc1d8e483a71367549ebee1c2de505b1165c63caf2 +size 5459521 diff --git a/preprocessing/dataset_json/others/fsgan_cdf.json b/preprocessing/dataset_json/others/fsgan_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..5767df0b63066bdb977de05aef7f9ab803cfaa54 --- /dev/null +++ b/preprocessing/dataset_json/others/fsgan_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:044d93274585f52cc53394c525a13f1d718aaed3cd7d4d48be23c50f15a865a5 +size 3967037 diff --git a/preprocessing/dataset_json/others/fsgan_ff.json b/preprocessing/dataset_json/others/fsgan_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..ceebba48261d120e2786a3effb05504a5829dd14 --- /dev/null +++ b/preprocessing/dataset_json/others/fsgan_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4b78751e0f5736d49578cf4fe09aa8498040f8ddda398ca157b121f2651622e +size 5415933 diff --git a/preprocessing/dataset_json/others/genimage_mj.json b/preprocessing/dataset_json/others/genimage_mj.json new file mode 100644 index 0000000000000000000000000000000000000000..6eaa8904e5dacee51cd72871eadb860fafff50c4 --- /dev/null +++ b/preprocessing/dataset_json/others/genimage_mj.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c41e4e50490c040bf3d6d435371989036b3431e6f2ba57feed57959d920647d +size 73639721 diff --git a/preprocessing/dataset_json/others/glide.json b/preprocessing/dataset_json/others/glide.json new file mode 100644 index 0000000000000000000000000000000000000000..779ca0db2f6a5f9b4f510397b02078ad7ee1dc65 --- /dev/null +++ b/preprocessing/dataset_json/others/glide.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec9e8499ab07103a797322bb6367ae1ec171d5402671c4c17ef663f5126f0b94 +size 78398099 diff --git a/preprocessing/dataset_json/others/hyperreenact_cdf.json b/preprocessing/dataset_json/others/hyperreenact_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..b690a3c97ad87548a88e1f16c597108b3b79c67e --- /dev/null +++ b/preprocessing/dataset_json/others/hyperreenact_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f625c9a16f3283c3d543b4d7d183a1512fc16386b30051329eb8896af32dfd4e +size 4405854 diff --git a/preprocessing/dataset_json/others/hyperreenact_ff.json b/preprocessing/dataset_json/others/hyperreenact_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..1183c0824202ac8e775a8db42224e8150a0f65f7 --- /dev/null +++ b/preprocessing/dataset_json/others/hyperreenact_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db84210c07a1ae2eb62841a0aa4f016e52d2c2bfacc218d492ac73c94bbc826b +size 5415151 diff --git a/preprocessing/dataset_json/others/inswap_cdf.json b/preprocessing/dataset_json/others/inswap_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..b651e27c0a80e0cc754d1910dd6e8a36540641a3 --- /dev/null +++ b/preprocessing/dataset_json/others/inswap_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:204ae51489638f6f7b19f9990007f63889448cf276c5e5db10de07c5435ec7f7 +size 3030826 diff --git a/preprocessing/dataset_json/others/inswap_ff.json b/preprocessing/dataset_json/others/inswap_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..225403c1e436c790022d4a1cb97445309b7c78de --- /dev/null +++ b/preprocessing/dataset_json/others/inswap_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eaf340bd2ab286804f7ce9200a57f8d95ef8f9da2592e18739196ed83f1be79a +size 5812542 diff --git a/preprocessing/dataset_json/others/ivy_fake_test.json b/preprocessing/dataset_json/others/ivy_fake_test.json new file mode 100644 index 0000000000000000000000000000000000000000..70aecd83d9a0a9c6b634d9f0db8219511b1be4ac --- /dev/null +++ b/preprocessing/dataset_json/others/ivy_fake_test.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:998f7a21557869c63aef160700777d46a14f84b6d0603d4087a95c6470902366 +size 180646 diff --git a/preprocessing/dataset_json/others/ivy_fake_test_Deepfakes.json b/preprocessing/dataset_json/others/ivy_fake_test_Deepfakes.json new file mode 100644 index 0000000000000000000000000000000000000000..7215acdb7ac4f9371e89cf85d30d2db51a779b61 --- /dev/null +++ b/preprocessing/dataset_json/others/ivy_fake_test_Deepfakes.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ff46024a7846e7bedf9a71fad8ae94c645cf1cb20aeeb6bf8d1a5f5e45abdec5 +size 105185 diff --git a/preprocessing/dataset_json/others/ivy_fake_test_Deepfakes_purefake.json b/preprocessing/dataset_json/others/ivy_fake_test_Deepfakes_purefake.json new file mode 100644 index 0000000000000000000000000000000000000000..e4dfb4e4578c3048dcde9e06d3b3402aa1c56e59 --- /dev/null +++ b/preprocessing/dataset_json/others/ivy_fake_test_Deepfakes_purefake.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62289430b9214a2cc8fc8a9bd0c2d2e579e11069e031aa69562fd4b1950341db +size 26800 diff --git a/preprocessing/dataset_json/others/ivy_fake_test_Face2Face.json b/preprocessing/dataset_json/others/ivy_fake_test_Face2Face.json new file mode 100644 index 0000000000000000000000000000000000000000..bb73c1d96e61da48763a3d41064393dfbe6059a5 --- /dev/null +++ b/preprocessing/dataset_json/others/ivy_fake_test_Face2Face.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98c93f7ceccbfe967c32d036d0f1e93f51cd607d3f7110273f972bec7ccad500 +size 106470 diff --git a/preprocessing/dataset_json/others/ivy_fake_test_Face2Face_purefake.json b/preprocessing/dataset_json/others/ivy_fake_test_Face2Face_purefake.json new file mode 100644 index 0000000000000000000000000000000000000000..7df3e24573e754dac446d30c0087192a327231ed --- /dev/null +++ b/preprocessing/dataset_json/others/ivy_fake_test_Face2Face_purefake.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97b7ddf62fe7188cdf2a6aadc024f575e1c95c5a2749890e24c59f9888db19f3 +size 28085 diff --git a/preprocessing/dataset_json/others/ivy_fake_test_FaceSwap.json b/preprocessing/dataset_json/others/ivy_fake_test_FaceSwap.json new file mode 100644 index 0000000000000000000000000000000000000000..ba21a09b7ea18afb38166d1c21dff8b1c0c50e1c --- /dev/null +++ b/preprocessing/dataset_json/others/ivy_fake_test_FaceSwap.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e6e0a619ec963e02b61337ecaae621525c1fbe63359945645a4af0843e1aabd +size 102704 diff --git a/preprocessing/dataset_json/others/ivy_fake_test_FaceSwap_purefake.json b/preprocessing/dataset_json/others/ivy_fake_test_FaceSwap_purefake.json new file mode 100644 index 0000000000000000000000000000000000000000..c8ffdea880e9a7aa5d343e48d664db935cdb687e --- /dev/null +++ b/preprocessing/dataset_json/others/ivy_fake_test_FaceSwap_purefake.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0bcbac0d5244021ca1803136153a8a5ee2eb563be98a00af1ccb8590e76aaef +size 24319 diff --git a/preprocessing/dataset_json/others/ivy_fake_test_NeuralTextures.json b/preprocessing/dataset_json/others/ivy_fake_test_NeuralTextures.json new file mode 100644 index 0000000000000000000000000000000000000000..06290bf4d43e591613411b56872131b37385516e --- /dev/null +++ b/preprocessing/dataset_json/others/ivy_fake_test_NeuralTextures.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a8c95f8d8f767149068bc5342b56122a598404d7800444eb9168152c4c4a374 +size 108449 diff --git a/preprocessing/dataset_json/others/ivy_fake_test_NeuralTextures_purefake.json b/preprocessing/dataset_json/others/ivy_fake_test_NeuralTextures_purefake.json new file mode 100644 index 0000000000000000000000000000000000000000..5b0d3b37298ec662e368e829092d4065aefe7cba --- /dev/null +++ b/preprocessing/dataset_json/others/ivy_fake_test_NeuralTextures_purefake.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3a2c7665495a48cbb3509f2d568d4589d3d5fd7eeb184f4d5d5886124304755 +size 30064 diff --git a/preprocessing/dataset_json/others/ivy_fake_train.json b/preprocessing/dataset_json/others/ivy_fake_train.json new file mode 100644 index 0000000000000000000000000000000000000000..43795cab85403bb30211ec9daeeee3b5b8c461b8 --- /dev/null +++ b/preprocessing/dataset_json/others/ivy_fake_train.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71a16e9c6d86b319a82a81916fd261c1fb4c18c3e1301d244b2433eeb0bef729 +size 3693471 diff --git a/preprocessing/dataset_json/others/lia_cdf.json b/preprocessing/dataset_json/others/lia_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..7bfc0bb86ad170082b8a9e2d0ac750665c6d4c33 --- /dev/null +++ b/preprocessing/dataset_json/others/lia_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42206bef50fdc6a2f29165cdc07f7d9ac2c79f5fac4c8a7a9be0fb4b87cb8e4d +size 3918501 diff --git a/preprocessing/dataset_json/others/lia_ff.json b/preprocessing/dataset_json/others/lia_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..ae05cdd1dfbca1989f2b34849c277b1bde3ee188 --- /dev/null +++ b/preprocessing/dataset_json/others/lia_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:825f0a559aa385b7a97b860f3c83c5c0e10f5d8c071002c2b52c8ad5613e47c7 +size 5418184 diff --git a/preprocessing/dataset_json/others/mcnet_cdf.json b/preprocessing/dataset_json/others/mcnet_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..2166bc82348e7f92b5b51dbb9374a0f8912efe44 --- /dev/null +++ b/preprocessing/dataset_json/others/mcnet_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a00023846f877ad377942d638559ab0c324ab57bcf27a83157798c939382ed4 +size 3913813 diff --git a/preprocessing/dataset_json/others/mcnet_ff.json b/preprocessing/dataset_json/others/mcnet_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..2498d4a74ac27b55c9f57c7522ebcec26d944b2c --- /dev/null +++ b/preprocessing/dataset_json/others/mcnet_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fcfa0e1a8388528013e2bd34bb326991511778002ea87f7f5659f1bdd132ca3f +size 5472919 diff --git a/preprocessing/dataset_json/others/mobileswap_cdf.json b/preprocessing/dataset_json/others/mobileswap_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..070c36135f4b4f3df4fde8f3b249eae902c8af87 --- /dev/null +++ b/preprocessing/dataset_json/others/mobileswap_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7497fcc3de4d81d74adbd5cb1f516f91cb8b0c46b9b05a184bea5b2a74ff384 +size 4277540 diff --git a/preprocessing/dataset_json/others/mobileswap_ff.json b/preprocessing/dataset_json/others/mobileswap_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..1b74e5eff3ae856293ff346484997305a1e4b3f2 --- /dev/null +++ b/preprocessing/dataset_json/others/mobileswap_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9666a98e7d3b2fcaf746d7f862a4fd6e781c8582b9e2a32170e33a0c875b35fe +size 9120483 diff --git a/preprocessing/dataset_json/others/one_shot_free_cdf.json b/preprocessing/dataset_json/others/one_shot_free_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..d0e30711697187ca926f26cf4371100679b83b7c --- /dev/null +++ b/preprocessing/dataset_json/others/one_shot_free_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1fde7ae8f33066280fc83507f3773f8f78c6d8597cf941d6db6dec6530810125 +size 4475773 diff --git a/preprocessing/dataset_json/others/one_shot_free_ff.json b/preprocessing/dataset_json/others/one_shot_free_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..07e55e1125e4297c1a35f68673067526e5c71b1d --- /dev/null +++ b/preprocessing/dataset_json/others/one_shot_free_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0da4ec6c7de2f85823aff37159e414dba757a87198489cc7d7b8a103bc34d798 +size 5790228 diff --git a/preprocessing/dataset_json/others/pirender_cdf.json b/preprocessing/dataset_json/others/pirender_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..920cd345d0f09a8a4a847f8275c738fbcdfb9292 --- /dev/null +++ b/preprocessing/dataset_json/others/pirender_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bda90f8a9806714539484bad8a03610091f2ce20723729d3d58b5667a229a5bc +size 4189462 diff --git a/preprocessing/dataset_json/others/pirender_ff.json b/preprocessing/dataset_json/others/pirender_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..4cca3f51c241f8dbd9080f64b466dbfd6563479c --- /dev/null +++ b/preprocessing/dataset_json/others/pirender_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43a8b8b25901b08534c2ccf9f9167e024127a7a1d727000a45de0934731f7c18 +size 5611202 diff --git a/preprocessing/dataset_json/others/pixart_cdf.json b/preprocessing/dataset_json/others/pixart_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..2cbb1ac8c95b06529437fce5241bfcf25b6f38bb --- /dev/null +++ b/preprocessing/dataset_json/others/pixart_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b13445489282402707f9682a60650fc70732674dfa4486b33004e22948488961 +size 6427898 diff --git a/preprocessing/dataset_json/others/pixart_ff.json b/preprocessing/dataset_json/others/pixart_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..5bec3da464ae3c8dfef071101986bbd0e195cfb8 --- /dev/null +++ b/preprocessing/dataset_json/others/pixart_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3f607d4e2fa841999182c0c9502d7c1fb7a3f72b2f3d44302cb365b90b5695c +size 4824027 diff --git a/preprocessing/dataset_json/others/rddm_cdf.json b/preprocessing/dataset_json/others/rddm_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..41d137540f45ad42c10a872406acc3c3e99f809c --- /dev/null +++ b/preprocessing/dataset_json/others/rddm_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f607e2c738b79520f7a09abdb21630f2597f19fa4e2bbbbe1ccd1c5b935b1f65 +size 3331704 diff --git a/preprocessing/dataset_json/others/rddm_ff.json b/preprocessing/dataset_json/others/rddm_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..ab2fe16e32a6018204560662c0d84c240ba9c8c1 --- /dev/null +++ b/preprocessing/dataset_json/others/rddm_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a64fdf740d67ff758bc066ae5045035b05906fc991c07dbcf0b7480a959f9ab6 +size 4882817 diff --git a/preprocessing/dataset_json/others/roop.json b/preprocessing/dataset_json/others/roop.json new file mode 100644 index 0000000000000000000000000000000000000000..9d57f72ab3f735745dcd8ca7ff9c6af626a196d5 --- /dev/null +++ b/preprocessing/dataset_json/others/roop.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81685eeef5d867584c1ae39e6c6bd755b332b25af0f760cadc3acfb22b85ad4b +size 37145766 diff --git a/preprocessing/dataset_json/others/roop_ff.json b/preprocessing/dataset_json/others/roop_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..8875dc0c185041f6bed5fb4a2d11b2e26337d040 --- /dev/null +++ b/preprocessing/dataset_json/others/roop_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0d33779c2b25b664057b23d4965703719e7af1d8a7565e78c12d347f79aa266 +size 27452264 diff --git a/preprocessing/dataset_json/others/roop_ff_ori.json b/preprocessing/dataset_json/others/roop_ff_ori.json new file mode 100644 index 0000000000000000000000000000000000000000..8875dc0c185041f6bed5fb4a2d11b2e26337d040 --- /dev/null +++ b/preprocessing/dataset_json/others/roop_ff_ori.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0d33779c2b25b664057b23d4965703719e7af1d8a7565e78c12d347f79aa266 +size 27452264 diff --git a/preprocessing/dataset_json/others/roop_ori.json b/preprocessing/dataset_json/others/roop_ori.json new file mode 100644 index 0000000000000000000000000000000000000000..9d57f72ab3f735745dcd8ca7ff9c6af626a196d5 --- /dev/null +++ b/preprocessing/dataset_json/others/roop_ori.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81685eeef5d867584c1ae39e6c6bd755b332b25af0f760cadc3acfb22b85ad4b +size 37145766 diff --git a/preprocessing/dataset_json/others/sadtalker_cdf.json b/preprocessing/dataset_json/others/sadtalker_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..1ec9d32bd6d7a2dcdf9ce4a549bac548cb9a34a1 --- /dev/null +++ b/preprocessing/dataset_json/others/sadtalker_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae8a65d334bad6f4c8e365ad9d7f1172398e4e100746cee707ed05794295e62f +size 4934457 diff --git a/preprocessing/dataset_json/others/sadtalker_ff.json b/preprocessing/dataset_json/others/sadtalker_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..36c33d2ba0b2759b1b63372d5550c3c9ec65ebde --- /dev/null +++ b/preprocessing/dataset_json/others/sadtalker_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9db653e0354f7e0def13235f6014176c2feb8b25cd6e346446afcaacd5fa973b +size 6080261 diff --git a/preprocessing/dataset_json/others/sd1.5_cdf.json b/preprocessing/dataset_json/others/sd1.5_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..2afdc8916115e88ab0b1a6f613c21b85b55097d0 --- /dev/null +++ b/preprocessing/dataset_json/others/sd1.5_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f94e1c3302b3e55cc38b363994bb0faa0c3ad0b0094ffc494fdb05dcd78a6b8 +size 6369415 diff --git a/preprocessing/dataset_json/others/sd1.5_ff.json b/preprocessing/dataset_json/others/sd1.5_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..a8389ac4241e18281b55a1488d88409f1d6b120e --- /dev/null +++ b/preprocessing/dataset_json/others/sd1.5_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa876da5df9c727cd05571778cb2bdc9f947fed9aa8b1fa62e973ba2158d69ec +size 4116694 diff --git a/preprocessing/dataset_json/others/sd2.1_cdf.json b/preprocessing/dataset_json/others/sd2.1_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..9ac799efe31abad6d3fc0a57d3f34d413cdc9ec6 --- /dev/null +++ b/preprocessing/dataset_json/others/sd2.1_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a8b5d71b29457901b7c613a0d3e0821395ee592268d9f2e988eabeb13482fcf +size 6369415 diff --git a/preprocessing/dataset_json/others/sd2.1_ff.json b/preprocessing/dataset_json/others/sd2.1_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..d3151d11ed30e4e9e0a2b794096115371a637456 --- /dev/null +++ b/preprocessing/dataset_json/others/sd2.1_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:94c55c0a10e9c3b124b2b296179be73d8099f6a050ffe7f15b262d606db68907 +size 4796597 diff --git a/preprocessing/dataset_json/others/sdv4.json b/preprocessing/dataset_json/others/sdv4.json new file mode 100644 index 0000000000000000000000000000000000000000..fd6abce823f928a90d2a866f1d18c91216778ae8 --- /dev/null +++ b/preprocessing/dataset_json/others/sdv4.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7adf97282d60c6c75e4e7013b8ac520f3d5d17605117e8af994f69d95f843b24 +size 73872424 diff --git a/preprocessing/dataset_json/others/sdv5.json b/preprocessing/dataset_json/others/sdv5.json new file mode 100644 index 0000000000000000000000000000000000000000..fd1b9773de83667592e5d3bf4370c8b45fdf39a2 --- /dev/null +++ b/preprocessing/dataset_json/others/sdv5.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8099217ad4a0a9ace70bdd6d4c945c8b11bb72ccdcdfd014efa95217724ed2a +size 74587310 diff --git a/preprocessing/dataset_json/others/simswap.json b/preprocessing/dataset_json/others/simswap.json new file mode 100644 index 0000000000000000000000000000000000000000..b7511c3630aa91a157e5dc05677d9b70851ba378 --- /dev/null +++ b/preprocessing/dataset_json/others/simswap.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d739acabe69b2b1784f5fdd4a7e0be4282be19b76efb73168fa086d567df66d +size 2439411 diff --git a/preprocessing/dataset_json/others/simswap_cdf.json b/preprocessing/dataset_json/others/simswap_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..eea236e9267c0cfb2e1e007203fa4b19e8610d08 --- /dev/null +++ b/preprocessing/dataset_json/others/simswap_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c280f30e086170d973453b6295aea1bb06b5bd6c691f6e99f6242c53dbf81910 +size 4164167 diff --git a/preprocessing/dataset_json/others/simswap_ff.json b/preprocessing/dataset_json/others/simswap_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..9dc12154138267f15b699e8a35ce32b3b14af830 --- /dev/null +++ b/preprocessing/dataset_json/others/simswap_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8decca2745b37913a1bef44a0ff0c61b146ecc6bf4f20a50bc953772f6f7f9e7 +size 6209156 diff --git a/preprocessing/dataset_json/others/simswap_ff_ori.json b/preprocessing/dataset_json/others/simswap_ff_ori.json new file mode 100644 index 0000000000000000000000000000000000000000..7765ec21679c2185fb9ccf5831b3fbd0e2ed1301 --- /dev/null +++ b/preprocessing/dataset_json/others/simswap_ff_ori.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e90b07851037a7dd6e523090024240d17be740839948b6a76e491edf1bacfb2d +size 2551449 diff --git a/preprocessing/dataset_json/others/simswap_ori.json b/preprocessing/dataset_json/others/simswap_ori.json new file mode 100644 index 0000000000000000000000000000000000000000..b7511c3630aa91a157e5dc05677d9b70851ba378 --- /dev/null +++ b/preprocessing/dataset_json/others/simswap_ori.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d739acabe69b2b1784f5fdd4a7e0be4282be19b76efb73168fa086d567df66d +size 2439411 diff --git a/preprocessing/dataset_json/others/stargan.json b/preprocessing/dataset_json/others/stargan.json new file mode 100644 index 0000000000000000000000000000000000000000..1e1c2d01d6c145fbb98f9e42646c1ff075834666 --- /dev/null +++ b/preprocessing/dataset_json/others/stargan.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:68637d6bf573c61339e6057bd633f66e72891e883bfa68d6c7820249a961bebd +size 832449 diff --git a/preprocessing/dataset_json/others/starganv2.json b/preprocessing/dataset_json/others/starganv2.json new file mode 100644 index 0000000000000000000000000000000000000000..906b1940eb56f60eb581f93d14bcef7b91d3388e --- /dev/null +++ b/preprocessing/dataset_json/others/starganv2.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dfcfb700938aab132f9dbfd7d7c6d3d9975d22193674b3bc90f4f1b748e3aae9 +size 871239 diff --git a/preprocessing/dataset_json/others/styleclip.json b/preprocessing/dataset_json/others/styleclip.json new file mode 100644 index 0000000000000000000000000000000000000000..60a7a5ef37775da4c0d414214234268570cb744b --- /dev/null +++ b/preprocessing/dataset_json/others/styleclip.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd34ed6c3020abf5bd12d341c65fbf20240185540ae6d91832d9265f34e0b2b5 +size 860695 diff --git a/preprocessing/dataset_json/others/tpsm_cdf.json b/preprocessing/dataset_json/others/tpsm_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..66aa18d11d9405bd83f695565a3e59430d4081df --- /dev/null +++ b/preprocessing/dataset_json/others/tpsm_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:69f52ad32ff59abe145dff8d121213832846d0b572ce5f7192713ee022efd85c +size 4078296 diff --git a/preprocessing/dataset_json/others/tpsm_ff.json b/preprocessing/dataset_json/others/tpsm_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..aec17231afd81a182aa6b0392cc9a0b6002a66c9 --- /dev/null +++ b/preprocessing/dataset_json/others/tpsm_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6ff40d81a12ed59d52001aaf7691745c355569b13509514471afb4e07a356e9f +size 5468776 diff --git a/preprocessing/dataset_json/others/uniface.json b/preprocessing/dataset_json/others/uniface.json new file mode 100644 index 0000000000000000000000000000000000000000..0cacd0c439374800143150d229d901282eb8562b --- /dev/null +++ b/preprocessing/dataset_json/others/uniface.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77df229f6f785022b6863d0bd91887e7fd8eaba0a58183474a0541aa4e86f7fb +size 4014051 diff --git a/preprocessing/dataset_json/others/uniface_cdf.json b/preprocessing/dataset_json/others/uniface_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..6a5992fab363772514153ee9aedfcb0cbe97657b --- /dev/null +++ b/preprocessing/dataset_json/others/uniface_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5a484abebf4fb4e28a1bc6d496ba0c170dd99857af2c9201af943f6c8d67ec9 +size 4093379 diff --git a/preprocessing/dataset_json/others/uniface_ff.json b/preprocessing/dataset_json/others/uniface_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..91c87daeb9c74af37f0193a6848d825547a7f61b --- /dev/null +++ b/preprocessing/dataset_json/others/uniface_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb9d8d00dada4cc135424783dc4874ea49e513589aa801b408a0883952231188 +size 4725750 diff --git a/preprocessing/dataset_json/others/uniface_ff_ori.json b/preprocessing/dataset_json/others/uniface_ff_ori.json new file mode 100644 index 0000000000000000000000000000000000000000..4998430abdcc646d82af3ba5ae7254cf46c34946 --- /dev/null +++ b/preprocessing/dataset_json/others/uniface_ff_ori.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f2a35ae784a10903050fb63e66de7f8903f577221f35c3fdd5617bbdeafacb5 +size 2365923 diff --git a/preprocessing/dataset_json/others/uniface_ori.json b/preprocessing/dataset_json/others/uniface_ori.json new file mode 100644 index 0000000000000000000000000000000000000000..0cacd0c439374800143150d229d901282eb8562b --- /dev/null +++ b/preprocessing/dataset_json/others/uniface_ori.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77df229f6f785022b6863d0bd91887e7fd8eaba0a58183474a0541aa4e86f7fb +size 4014051 diff --git a/preprocessing/dataset_json/others/vqdm.json b/preprocessing/dataset_json/others/vqdm.json new file mode 100644 index 0000000000000000000000000000000000000000..9e7afcc076e3efebd314a703cca831b6c1202dd7 --- /dev/null +++ b/preprocessing/dataset_json/others/vqdm.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:567a4af9a879a647f4c0210ab64c1986f5a67de1141b888ced32a77153908246 +size 79789556 diff --git a/preprocessing/dataset_json/others/wav2lip_cdf.json b/preprocessing/dataset_json/others/wav2lip_cdf.json new file mode 100644 index 0000000000000000000000000000000000000000..25857af1ffe431344cccec5d402e029dc15d6f00 --- /dev/null +++ b/preprocessing/dataset_json/others/wav2lip_cdf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e5a728f4d23a779c1315112a755553f3e2483b74278703f3e16fef1c1f4d137 +size 4816587 diff --git a/preprocessing/dataset_json/others/wav2lip_ff.json b/preprocessing/dataset_json/others/wav2lip_ff.json new file mode 100644 index 0000000000000000000000000000000000000000..5550cbaf552b44260f2b1b9278e00212cd3fbcd7 --- /dev/null +++ b/preprocessing/dataset_json/others/wav2lip_ff.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:642452a318e56f64708de6667f2d995a56e480f9fb6b384dcc4a614415be5866 +size 5993071 diff --git a/preprocessing/dataset_json/others/whichisreal.json b/preprocessing/dataset_json/others/whichisreal.json new file mode 100644 index 0000000000000000000000000000000000000000..85aedfaaee1e2858ae458becd177087d2243cf61 --- /dev/null +++ b/preprocessing/dataset_json/others/whichisreal.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0b8600f3038b7c2bdaeb54d27e67fcdc143d2c69b0c9b063d307ee2382a52a1a +size 457913 diff --git a/preprocessing/dataset_json/protocol_2_test.json b/preprocessing/dataset_json/protocol_2_test.json new file mode 100644 index 0000000000000000000000000000000000000000..21b1a8462aca677d9500d3314ddc6f869a14e8cb --- /dev/null +++ b/preprocessing/dataset_json/protocol_2_test.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d340cb74db152168e6f69e04d4cf66d86dd404efd7e116bd5dc44eebf0d0468 +size 18192587 diff --git a/preprocessing/dataset_json/protocol_2_train.json b/preprocessing/dataset_json/protocol_2_train.json new file mode 100644 index 0000000000000000000000000000000000000000..3bbb952f1a2f39a46774968b9feff1463ab57aa9 --- /dev/null +++ b/preprocessing/dataset_json/protocol_2_train.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:095ba28c300d5983212209808c46da58209b290a0e7ab4bc0550fe8a84d8f835 +size 82419353 diff --git a/preprocessing/dataset_json/protocol_3_test.json b/preprocessing/dataset_json/protocol_3_test.json new file mode 100644 index 0000000000000000000000000000000000000000..d69d9c0498e5756325b77a5eaa9c83ace9c17c9b --- /dev/null +++ b/preprocessing/dataset_json/protocol_3_test.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63448b50a87804f8ffc46b22cef88ef79a19d4b7ce77edda333c8a84703963ce +size 77096617 diff --git a/preprocessing/dataset_json/protocol_4_test.json b/preprocessing/dataset_json/protocol_4_test.json new file mode 100644 index 0000000000000000000000000000000000000000..b4d620d5b2de387ed5a876d74cd3e9bcf92f2c50 --- /dev/null +++ b/preprocessing/dataset_json/protocol_4_test.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a91cbab3292ddaae6f669ffded5aa5da0bb52f294656f95721a182e75fce96b +size 101328873 diff --git a/preprocessing/dataset_json/test_DFR.json b/preprocessing/dataset_json/test_DFR.json new file mode 100644 index 0000000000000000000000000000000000000000..17e4e8dbbac02400378028f708e1b376cb0fda19 --- /dev/null +++ b/preprocessing/dataset_json/test_DFR.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29a79bd8fd2e82df51e1a137687a0a6cfb700b00a17e2699d582611cc1b281db +size 532565 diff --git a/preprocessing/dataset_json/test_FFIW.json b/preprocessing/dataset_json/test_FFIW.json new file mode 100644 index 0000000000000000000000000000000000000000..32e41eacc566cac980cec7253982e78aff1523bb --- /dev/null +++ b/preprocessing/dataset_json/test_FFIW.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be52709e895d041c8d8a718fa27fb3ed73145f8947bff199f6159d84be0521cf +size 2710859 diff --git a/preprocessing/dataset_json/test_WDF.json b/preprocessing/dataset_json/test_WDF.json new file mode 100644 index 0000000000000000000000000000000000000000..1c4cb42caff3ed00abaaa50a3d0a61773003faf1 --- /dev/null +++ b/preprocessing/dataset_json/test_WDF.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9269abb33a0c7b53a84b0477e70d1c511bf1a0c072a88ed2172ed59378e9209 +size 2373710 diff --git a/preprocessing/dataset_json/with_stylegan2/protocol_2_test.json b/preprocessing/dataset_json/with_stylegan2/protocol_2_test.json new file mode 100644 index 0000000000000000000000000000000000000000..44cd775d585c0f66cb0d9277a72634c43f759d9e --- /dev/null +++ b/preprocessing/dataset_json/with_stylegan2/protocol_2_test.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea071203922e008a792726bad3469d5dcc8ee4287b200459212186be175a67b5 +size 18553725 diff --git a/preprocessing/dataset_json/with_stylegan2/protocol_2_train.json b/preprocessing/dataset_json/with_stylegan2/protocol_2_train.json new file mode 100644 index 0000000000000000000000000000000000000000..dc8566d710c449da80c67be2c4a6ada2210cd38c --- /dev/null +++ b/preprocessing/dataset_json/with_stylegan2/protocol_2_train.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55ba25834dcf6e638f975cbd036504cac4300ada633b7366178d661485a79c14 +size 84177275 diff --git a/preprocessing/dataset_json/with_stylegan2/protocol_3_test.json b/preprocessing/dataset_json/with_stylegan2/protocol_3_test.json new file mode 100644 index 0000000000000000000000000000000000000000..f8f76c02aa0498d2a04b03e83867b5712eb2ccb9 --- /dev/null +++ b/preprocessing/dataset_json/with_stylegan2/protocol_3_test.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3dbb32c8c6d25084100a05143db75106a1f8754ec9e22f6218e809907dc53892 +size 80307852 diff --git a/preprocessing/dataset_json/wukong.json b/preprocessing/dataset_json/wukong.json new file mode 100644 index 0000000000000000000000000000000000000000..ee8985aeeb85bc289a8c0a91741244efd2af0a99 --- /dev/null +++ b/preprocessing/dataset_json/wukong.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c4b49fa286df26b3e10d50fe68905406dc900e1a8109feda98db2ef14db540c +size 76492382 diff --git a/preprocessing/dlib_tools/readme.md b/preprocessing/dlib_tools/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..4b49c2719af81505c0ddfed82c88249009d7ab3d --- /dev/null +++ b/preprocessing/dlib_tools/readme.md @@ -0,0 +1 @@ +Put the dlib face detector here. \ No newline at end of file diff --git a/preprocessing/dlib_tools/shape_predictor_81_face_landmarks.dat b/preprocessing/dlib_tools/shape_predictor_81_face_landmarks.dat new file mode 100644 index 0000000000000000000000000000000000000000..e6f6c4e3aa9aab54aef4a2ccc997397c8601c68f --- /dev/null +++ b/preprocessing/dlib_tools/shape_predictor_81_face_landmarks.dat @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8cae4375589dd915d9a0a881101bed1bbb4e9887e35e63b024388f1ca25ff869 +size 19743860 diff --git a/preprocessing/logs/readme.md b/preprocessing/logs/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..55bd1c249b312520a2f7e9294864a2eb8b01bfd1 --- /dev/null +++ b/preprocessing/logs/readme.md @@ -0,0 +1 @@ +This folder saves the processing log file for each dataset. \ No newline at end of file diff --git a/preprocessing/preprocess.py b/preprocessing/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e0aeae0b4b942406670a74b50a47b92774b13f --- /dev/null +++ b/preprocessing/preprocess.py @@ -0,0 +1,514 @@ +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-03-29 +# description: Data pre-processing script for deepfake dataset. + + +""" +Original dataset structure before the preprocessing: + +-FaceForensics++ + -original_sequences + -youtube + -c23 + -videos + *.mp4 + -manipulated_sequences + -Deepfakes + -c23 + -videos + -Face2Face + -c23 + -videos + -FaceSwap + -c23 + -videos + -NeuralTextures + -c23 + -videos + -FaceShifter + -c23 + -videos + -DeepFakeDetection + -c23 + -videos + +-Celeb-DF-v1/v2 + -Celeb-synthesis + -videos + -Celeb-real + -videos + -YouTube-real + -videos + +-DFDCP + -method_A + -method_B + -original_videos + +-DeeperForensics-1.0 + -manipulated_videos + -source_videos + +We then additionally obtain "frames", "landmarks", and "mask" directories in same directory as the "videos" folder. +""" + + +import os +import sys +import time +import cv2 +import dlib +import yaml +import logging +import datetime +import glob +import concurrent.futures +import numpy as np +from tqdm import tqdm +from pathlib import Path +from imutils import face_utils +from skimage import transform as trans + + +def create_logger(log_path): + """ + Creates a logger object and saves all messages to a file. + + Args: + log_path (str): The path to save the log file. + + Returns: + logger: The logger object. + """ + # Create logger object + logger = logging.getLogger() + logger.setLevel(logging.INFO) + + # Create file handler and set the formatter + fh = logging.FileHandler(log_path) + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + fh.setFormatter(formatter) + + # Add the file handler to the logger + logger.addHandler(fh) + + # Add a stream handler to print to console + sh = logging.StreamHandler() + sh.setFormatter(formatter) + logger.addHandler(sh) + + return logger + + +def get_keypts(image, face, predictor, face_detector): + # detect the facial landmarks for the selected face + shape = predictor(image, face) + + # select the key points for the eyes, nose, and mouth + leye = np.array([shape.part(37).x, shape.part(37).y]).reshape(-1, 2) + reye = np.array([shape.part(44).x, shape.part(44).y]).reshape(-1, 2) + nose = np.array([shape.part(30).x, shape.part(30).y]).reshape(-1, 2) + lmouth = np.array([shape.part(49).x, shape.part(49).y]).reshape(-1, 2) + rmouth = np.array([shape.part(55).x, shape.part(55).y]).reshape(-1, 2) + + pts = np.concatenate([leye, reye, nose, lmouth, rmouth], axis=0) + + return pts + + +def extract_aligned_face_dlib(face_detector, predictor, image, res=256, mask=None): + def img_align_crop(img, landmark=None, outsize=None, scale=1.3, mask=None): + """ + align and crop the face according to the given bbox and landmarks + landmark: 5 key points + """ + + M = None + target_size = [112, 112] + dst = np.array([ + [30.2946, 51.6963], + [65.5318, 51.5014], + [48.0252, 71.7366], + [33.5493, 92.3655], + [62.7299, 92.2041]], dtype=np.float32) + + if target_size[1] == 112: + dst[:, 0] += 8.0 + + dst[:, 0] = dst[:, 0] * outsize[0] / target_size[0] + dst[:, 1] = dst[:, 1] * outsize[1] / target_size[1] + + target_size = outsize + + margin_rate = scale - 1 + x_margin = target_size[0] * margin_rate / 2. + y_margin = target_size[1] * margin_rate / 2. + + # move + dst[:, 0] += x_margin + dst[:, 1] += y_margin + + # resize + dst[:, 0] *= target_size[0] / (target_size[0] + 2 * x_margin) + dst[:, 1] *= target_size[1] / (target_size[1] + 2 * y_margin) + + src = landmark.astype(np.float32) + + # use skimage tranformation + tform = trans.SimilarityTransform() + tform.estimate(src, dst) + M = tform.params[0:2, :] + + # M: use opencv + # M = cv2.getAffineTransform(src[[0,1,2],:],dst[[0,1,2],:]) + + img = cv2.warpAffine(img, M, (target_size[1], target_size[0])) + + if outsize is not None: + img = cv2.resize(img, (outsize[1], outsize[0])) + + if mask is not None: + mask = cv2.warpAffine(mask, M, (target_size[1], target_size[0])) + mask = cv2.resize(mask, (outsize[1], outsize[0])) + return img, mask + else: + return img, None + + # Image size + height, width = image.shape[:2] + + # Convert to rgb + rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Detect with dlib + faces = face_detector(rgb, 1) + if len(faces): + # For now only take the biggest face + face = max(faces, key=lambda rect: rect.width() * rect.height()) + + # Get the landmarks/parts for the face in box d only with the five key points + landmarks = get_keypts(rgb, face, predictor, face_detector) + + # Align and crop the face + cropped_face, mask_face = img_align_crop(rgb, landmarks, outsize=(res, res), mask=mask) + cropped_face = cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR) + + # Extract the all landmarks from the aligned face + face_align = face_detector(cropped_face, 1) + if len(face_align) == 0: + return None, None, None + landmark = predictor(cropped_face, face_align[0]) + landmark = face_utils.shape_to_np(landmark) + + return cropped_face, landmark, mask_face + + else: + return None, None, None + +def video_manipulate( + movie_path: Path, + mask_path: Path, + dataset_path: Path, + mode: str, + num_frames: int, + stride: int, + ) -> None: + """ + Processes a single video file by detecting and cropping the largest face in each frame and saving the results. + + Args: + movie_path (str): Path to the video file to process. + dataset_path (str): Path to the dataset directory. + mask_path (str): Path to the mask directory. + mode (str): Either 'fixed_num_frames' or 'fixed_stride'. + num_frames (int): Number of frames to extract from the video. + stride (int): Number of frames to skip between each frame extracted. + margin (float): Amount to increase the size of the face bounding box by. + visualization (bool): Whether to save visualization images. + + Returns: + None + """ + + # Define face detector and predictor models + face_detector = dlib.get_frontal_face_detector() + predictor_path = './dlib_tools/shape_predictor_81_face_landmarks.dat' + ## Check if predictor path exists + if not os.path.exists(predictor_path): + logger.error(f"Predictor path does not exist: {predictor_path}") + sys.exit() + face_predictor = dlib.shape_predictor(predictor_path) + + def facecrop( + org_path: Path, + mask_path: Path, + save_path: Path, + mode: str, + num_frames: int, + stride: int, + face_predictor: dlib.shape_predictor, + face_detector: dlib.fhog_object_detector, + margin: float = 0.5, + visualization: bool = False + ) -> None: + """ + Helper function for cropping face and extracting landmarks. + """ + + # Open the video file + assert org_path.exists(), f"Video file {org_path} does not exist." + cap_org = cv2.VideoCapture(str(org_path)) + if not cap_org.isOpened(): + logger.error(f"Failed to open {org_path}") + return + + if mask_path is not None: + cap_mask = cv2.VideoCapture(str(mask_path)) + if not cap_mask.isOpened(): + logger.error(f"Failed to open {mask_path}") + return + + # Get the number of frames in the video + frame_count_org = int(cap_org.get(cv2.CAP_PROP_FRAME_COUNT)) + + # Get the mode + if mode == 'fixed_num_frames': + # Get the frame rate of the video by dividing the number of frames by the duration (same interval between frames) + frame_idxs = np.linspace(0, frame_count_org - 1, num_frames, endpoint=True, dtype=int) + elif mode == 'fixed_stride': + # Get the frame rate of the video by dividing the number of frames by the duration (same interval between frames) + frame_idxs = np.arange(0, frame_count_org, stride, dtype=int) + + # Iterate through the frames + for cnt_frame in range(frame_count_org): + ret_org, frame_org = cap_org.read() + if mask_path is not None: + ret_mask, frame_mask = cap_mask.read() + else: + frame_mask = None + height, width = frame_org.shape[:-1] + + # Check if the frame was successfully read + if not ret_org: + logger.warning(f"Failed to read frame {cnt_frame} of {org_path}") + break + + # Check if the mask was successfully read + if mask_path is not None and not ret_mask: + logger.warning(f"Failed to read mask {cnt_frame} of {mask_path}") + break + # Check if the frame is one of the frames to extract + if cnt_frame not in frame_idxs: + continue + + # Use the function to extract the aligned and cropped face + if mask_path is not None: + cropped_face, landmarks, masks = extract_aligned_face_dlib(face_detector, face_predictor, frame_org, mask=frame_mask) + else: + cropped_face, landmarks, _ = extract_aligned_face_dlib(face_detector, face_predictor, frame_org, mask=frame_mask) + + # Check if a face was detected and cropped + if cropped_face is None: + logger.warning(f"No faces in frame {cnt_frame} of {org_path}") + continue + + # Check if the landmarks were detected + if landmarks is None: + logger.warning(f"No landmarks in frame {cnt_frame} of {org_path}") + continue + + # Save cropped face, landmarks, and visualization image + save_path_ = save_path / 'frames' / org_path.stem + save_path_.mkdir(parents=True, exist_ok=True) + + # Save cropped face + image_path = save_path_ / f"{cnt_frame:03d}.png" + if not image_path.is_file(): + cv2.imwrite(str(image_path), cropped_face) + + # Save landmarks + land_path = save_path / 'landmarks' / org_path.stem / f"{cnt_frame:03d}.npy" + os.makedirs(os.path.dirname(land_path), exist_ok=True) + np.save(str(land_path), landmarks) + + # Save mask + if mask_path is not None: + mask_path = save_path / 'masks' / org_path.stem / f"{cnt_frame:03d}.png" + os.makedirs(os.path.dirname(mask_path), exist_ok=True) + _, binary_mask = cv2.threshold(masks, 1, 255, cv2.THRESH_BINARY) # obtain binary mask only + cv2.imwrite(str(mask_path), binary_mask) + + # Release the video capture + cap_org.release() + if mask_path is not None: + cap_mask.release() + + # Iterate through the videos in the dataset and extract faces + try: + facecrop(movie_path, mask_path, dataset_path, mode, num_frames, stride, face_predictor, face_detector) + except Exception as e: + logger.error(f"Error processing video {movie_path}: {e}") + + +def preprocess(dataset_path, mask_path, mode, num_frames, stride, logger): + # Define paths to videos in dataset + movies_path_list = sorted([Path(p) for p in glob.glob(os.path.join(dataset_path, '**/*.mp4'), recursive=True)]) + if len(movies_path_list) == 0: + logger.error(f"No videos found in {dataset_path}") + sys.exit() + logger.info(f"{len(movies_path_list)} videos found in {dataset_path}") + + # Define paths to masks in dataset + if mask_path is not None: + masks_path_list = sorted([Path(p) for p in glob.glob(os.path.join(mask_path, '**/*.mp4'), recursive=True)]) + if len(masks_path_list) == 0: + logger.error(f"No masks found in {mask_path}") + # sys.exit() + logger.info(f"{len(masks_path_list)} masks found in {mask_path}") + + # Start timer + start_time = time.monotonic() + + # Define the number of processes based on CPU capabilities + num_processes = os.cpu_count() + + # Use multiprocessing to process videos in parallel + with concurrent.futures.ThreadPoolExecutor(max_workers=num_processes) as executor: + futures = [] + for movie_path in movies_path_list: + # Check if there is a mask for the video + if mask_path is not None: + if movie_path.stem not in [path.stem for path in masks_path_list]: + logger.error(f"No mask for video {movie_path}") + # Define the mask path + mask_path = next((path for path in masks_path_list if path.stem == movie_path.stem), None) + if mask_path is None: + logger.error(f"Mask path not found for video {movie_path}") + # Create a future for each video and submit it for processing + futures.append( + executor.submit( + video_manipulate, + movie_path, + mask_path, + dataset_path, + mode, + num_frames, + stride, + ) + ) + # Wait for all futures to complete and log any errors + for future in tqdm(concurrent.futures.as_completed(futures), total=len(movies_path_list)): + # Print the current time + logger.info(f"Current time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + try: + future.result() + except Exception as e: + logger.error(f"Error processing video: {e}") + + # End timer + end_time = time.monotonic() + duration_minutes = (end_time - start_time) / 60 + logger.info(f"Total time taken: {duration_minutes:.2f} minutes") + +if __name__ == '__main__': + # from config.yaml load parameters + yaml_path = './config.yaml' + # open the yaml file + try: + with open(yaml_path, 'r') as f: + config = yaml.safe_load(f) + except yaml.parser.ParserError as e: + print("YAML file parsing error:", e) + + # Get the parameters + dataset_name = config['preprocess']['dataset_name']['default'] + dataset_root_path = config['preprocess']['dataset_root_path']['default'] + comp = config['preprocess']['comp']['default'] + mode = config['preprocess']['mode']['default'] + stride = config['preprocess']['stride']['default'] + num_frames = config['preprocess']['num_frames']['default'] + + # use dataset_name and dataset_root_path to get dataset_path + dataset_path = Path(os.path.join(dataset_root_path, dataset_name)) + + # Create logger + log_path = f'./logs/{dataset_name}.log' + logger = create_logger(log_path) + + # Define dataset path based on the input arguments + ## faceforensic++ + if dataset_name == 'FaceForensics++': + sub_dataset_names = ["original_sequences/youtube","original_sequences/actors", \ + "manipulated_sequences/Deepfakes", \ + "manipulated_sequences/Face2Face", "manipulated_sequences/FaceSwap", \ + "manipulated_sequences/NeuralTextures","manipulated_sequences/FaceShifter",\ + "manipulated_sequences/DeepFakeDetection"] + sub_dataset_paths = [Path(os.path.join(dataset_path, name, comp)) for name in sub_dataset_names] + # mask + mask_dataset_names = ["manipulated_sequences/Deepfakes", "manipulated_sequences/Face2Face", \ + "manipulated_sequences/FaceSwap", "manipulated_sequences/NeuralTextures",\ + "manipulated_sequences/DeepFakeDetection"] + # mask_dataset_names = [] + mask_dataset_paths = [Path(os.path.join(dataset_path, name)) for name in mask_dataset_names] + ## Celeb-DF-v1 + elif dataset_name == 'Celeb-DF-v1': + sub_dataset_names = ['Celeb-real', 'Celeb-synthesis', 'YouTube-real'] + sub_dataset_paths = [Path(os.path.join(dataset_path, name)) for name in sub_dataset_names] + + ## Celeb-DF-v2 + elif dataset_name == 'Celeb-DF-v2': + sub_dataset_names = ['Celeb-real', 'Celeb-synthesis', 'YouTube-real'] + sub_dataset_paths = [Path(os.path.join(dataset_path, name)) for name in sub_dataset_names] + + ## DFDCP + elif dataset_name == 'DFDCP': + sub_dataset_names = ['original_videos', 'method_A', 'method_B'] + sub_dataset_paths = [Path(os.path.join(dataset_path, name)) for name in sub_dataset_names] + + ## DFDC-test + elif dataset_name == 'DFDC': + sub_dataset_names = ['test', 'train'] + # train dataset is too large, so we split it into 50 parts + sub_train_dataset_names = ["dfdc_train_part_" + str(i) for i in range(0,50)] + sub_train_dataset_paths = [Path(os.path.join(dataset_path, 'train', name)) for name in sub_train_dataset_names] + sub_dataset_paths = [Path(os.path.join(dataset_path, 'test'))] + sub_train_dataset_paths + + ## DeeperForensics-1.0 + elif dataset_name == 'DeeperForensics-1.0': + real_sub_dataset_names = ['source_videos/' + name for name in os.listdir(os.path.join(dataset_path, 'source_videos'))] + fake_sub_dataset_names = ['manipulated_videos/' + name for name in os.listdir(os.path.join(dataset_path, 'manipulated_videos'))] + real_sub_dataset_names.extend(fake_sub_dataset_names) + sub_dataset_names = real_sub_dataset_names + sub_dataset_paths = [Path(os.path.join(dataset_path, name)) for name in sub_dataset_names] + + ## UADFV + elif dataset_name == 'UADFV': + sub_dataset_names = ['fake', 'real'] + sub_dataset_paths = [Path(os.path.join(dataset_path, name)) for name in sub_dataset_names] + else: + raise ValueError(f"Dataset {dataset_name} not recognized") + + # Check if dataset path exists + if not Path(dataset_path).exists(): + logger.error(f"Dataset path does not exist: {dataset_path}") + sys.exit() + + if 'sub_dataset_paths' in globals() and len(sub_dataset_paths) != 0: + # Check if sub_dataset path exists + for sub_dataset_path in sub_dataset_paths: + if not Path(sub_dataset_path).exists(): + logger.error(f"Sub Dataset path does not exist: {sub_dataset_path}") + sys.exit() + # preprocess each sub_dataset + for sub_dataset_path in sub_dataset_paths: + # only part of FaceForensics++ has mask + if dataset_name == 'FaceForensics++' and sub_dataset_path.parent in mask_dataset_paths: + mask_dataset_path = os.path.join(sub_dataset_path.parent, "masks") + preprocess(sub_dataset_path, mask_dataset_path, mode, num_frames, stride, logger) + else: + preprocess(sub_dataset_path, None, mode, num_frames, stride, logger) + else: + logger.error(f"Sub Dataset path does not exist: {sub_dataset_paths}") + sys.exit() + logger.info("Face cropping complete!") diff --git a/preprocessing/rearrange.py b/preprocessing/rearrange.py new file mode 100644 index 0000000000000000000000000000000000000000..dc192294ad40cb0a7408baf41268d570ad3411d4 --- /dev/null +++ b/preprocessing/rearrange.py @@ -0,0 +1,517 @@ +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-03-29 +# description: Data pre-processing script for deepfake dataset. + + +""" +After running this code, it will generates a json file looks like the below structure for re-arrange data. + +{ + "FaceForensics++": { + "Deepfakes": { + "video1": { + "label": "fake", + "frames": [ + "/path/to/frames/video1/frame1.png", + "/path/to/frames/video1/frame2.png", + ... + ] + }, + "video2": { + "label": "fake", + "frames": [ + "/path/to/frames/video2/frame1.png", + "/path/to/frames/video2/frame2.png", + ... + ] + }, + ... + }, + "original_sequences": { + "youtube": { + "video1": { + "label": "real", + "frames": [ + "/path/to/frames/video1/frame1.png", + "/path/to/frames/video1/frame2.png", + ... + ] + }, + "video2": { + "label": "real", + "frames": [ + "/path/to/frames/video2/frame1.png", + "/path/to/frames/video2/frame2.png", + ... + ] + }, + ... + } + } + } +} +""" + + +import os +import glob +import re +import cv2 +import json +import yaml +import pandas as pd +from pathlib import Path + + +def generate_dataset_file(dataset_name, dataset_root_path, output_file_path, compression_level='c23', perturbation = 'end_to_end'): + """ + Description: + - Generate a JSON file containing information about the specified datasets' videos and frames. + Args: + - dataset: The name of the dataset. + - dataset_path: The path to the dataset. + - output_file_path: The path to the output JSON file. + - compression_level: The compression level of the dataset. + """ + + # Initialize an empty dictionary to store dataset information. + dataset_dict = {} + + + ## FaceForensics++ dataset or DeepfakeDetection dataset + ## Note: DeepfakeDetection dataset is a subset of FaceForensics++ dataset + if dataset_name == 'FaceForensics++' or dataset_name == 'DeepFakeDetection' or dataset_name == 'FaceShifter': + ff_dict = { + 'Deepfakes': 'FF-DF', + 'Face2Face': 'FF-F2F', + 'FaceSwap': 'FF-FS', + 'Real': 'FF-real', + 'DFD_Real': 'DFD_real', + 'NeuralTextures': 'FF-NT', + 'FaceShifter': 'FF-FH', + 'DeepFakeDetection': 'DFD_fake', + 'DeepFakeDetection_original': 'DFD_real', + } + # Load the JSON files for data split + dataset_path = os.path.join(dataset_root_path, 'FaceForensics++') + + # Load the JSON files for data split + with open(file=os.path.join(os.path.join(dataset_root_path, 'FaceForensics++', 'train.json')), mode='r') as f: + train_json = json.load(f) + with open(file=os.path.join(os.path.join(dataset_root_path, 'FaceForensics++', 'val.json')), mode='r') as f: + val_json = json.load(f) + with open(file=os.path.join(os.path.join(dataset_root_path, 'FaceForensics++', 'test.json')), mode='r') as f: + test_json = json.load(f) + + # Create a dictionary for searching the data split + video_to_mode = dict() + for d1, d2 in train_json: + video_to_mode[d1] = 'train' + video_to_mode[d2] = 'train' + video_to_mode[d1+'_'+d2] = 'train' + video_to_mode[d2+'_'+d1] = 'train' + for d1, d2 in val_json: + video_to_mode[d1] = 'val' + video_to_mode[d2] = 'val' + video_to_mode[d1+'_'+d2] = 'val' + video_to_mode[d2+'_'+d1] = 'val' + for d1, d2 in test_json: + video_to_mode[d1] = 'test' + video_to_mode[d2] = 'test' + video_to_mode[d1+'_'+d2] = 'test' + video_to_mode[d2+'_'+d1] = 'test' + + + # FaceForensics++ real dataset + if os.path.isdir(dataset_path) and os.path.isdir(os.path.join(dataset_path, 'original_sequences')): + label = 'Real' + dataset_dict['FaceForensics++'] = {} + dataset_dict['FaceForensics++']['FF-real'] = {} + dataset_dict['FaceForensics++']['DFD_real'] = {} + + # Iterate over all compression levels: c23, c40, raw + dataset_dict['FaceForensics++']['FF-real']['train'] = {} + dataset_dict['FaceForensics++']['FF-real']['test'] = {} + dataset_dict['FaceForensics++']['FF-real']['val'] = {} + for compression_level in os.scandir(os.path.join(dataset_path, 'original_sequences', 'youtube')): + if compression_level.is_dir(): + compression_level = compression_level.name + dataset_dict['FaceForensics++']['FF-real']['train'][compression_level] = {} + dataset_dict['FaceForensics++']['FF-real']['test'][compression_level] = {} + dataset_dict['FaceForensics++']['FF-real']['val'][compression_level] = {} + + # Iterate over all videos + for video_path in os.scandir(os.path.join(dataset_path, 'original_sequences', 'youtube', compression_level, 'frames')): + if video_path.is_dir(): + video_name = video_path.name + mode = video_to_mode[video_name] + frame_paths = [os.path.join(video_path, frame.name) for frame in os.scandir(video_path)] + dataset_dict['FaceForensics++']['FF-real'][mode][compression_level][video_name] = {'label': ff_dict[label], 'frames': frame_paths} + + label = 'DFD_Real' + # Same operations for DeepfakeDetection real dataset + dataset_dict['FaceForensics++']['DFD_real']['train'] = {} + dataset_dict['FaceForensics++']['DFD_real']['test'] = {} + dataset_dict['FaceForensics++']['DFD_real']['val'] = {} + for compression_level in os.scandir(os.path.join(dataset_path, 'original_sequences', 'actors')): + if compression_level.is_dir() and compression_level.name in ["c23", "c40", "raw"]: + compression_level = compression_level.name + dataset_dict['FaceForensics++']['DFD_real']['train'][compression_level] = {} + dataset_dict['FaceForensics++']['DFD_real']['test'][compression_level] = {} + dataset_dict['FaceForensics++']['DFD_real']['val'][compression_level] = {} + # Iterate over all videos + for video_path in os.scandir(os.path.join(dataset_path, 'original_sequences', 'actors', compression_level, 'frames')): + if video_path.is_dir(): + video_name = video_path.name + frame_paths = [os.path.join(video_path, frame.name) for frame in os.scandir(video_path)] + dataset_dict['FaceForensics++']['DFD_real']['train'][compression_level][video_name] = {'label': ff_dict[label], 'frames': frame_paths} + dataset_dict['FaceForensics++']['DFD_real']['test'][compression_level][video_name] = {'label': ff_dict[label], 'frames': frame_paths} + dataset_dict['FaceForensics++']['DFD_real']['val'][compression_level][video_name] = {'label': ff_dict[label], 'frames': frame_paths} + # FaceForensics++ fake datasets + if os.path.isdir(os.path.join(dataset_path, 'manipulated_sequences')): + for label_dir in os.scandir(os.path.join(dataset_path, 'manipulated_sequences')): + if label_dir.is_dir(): + label = label_dir.name + dataset_dict['FaceForensics++'][ff_dict[label]] = {} + dataset_dict['FaceForensics++'][ff_dict[label]]['train'] = {} + dataset_dict['FaceForensics++'][ff_dict[label]]['test'] = {} + dataset_dict['FaceForensics++'][ff_dict[label]]['val'] = {} + + # Iterate over all compression levels: c23, c40, raw + for compression_level in os.scandir(os.path.join(dataset_path, 'manipulated_sequences', label)): + if compression_level.is_dir() and compression_level.name in ["c23", "c40", "raw"]: + compression_level = compression_level.name + dataset_dict['FaceForensics++'][ff_dict[label]]['train'][compression_level] = {} + dataset_dict['FaceForensics++'][ff_dict[label]]['test'][compression_level] = {} + dataset_dict['FaceForensics++'][ff_dict[label]]['val'][compression_level] = {} + # Iterate over all videos + + for video_path in os.scandir(os.path.join(dataset_path, 'manipulated_sequences', label, compression_level, 'frames')): + if video_path.is_dir(): + video_name = video_path.name + frame_paths = [os.path.join(video_path, frame.name) for frame in os.scandir(video_path)] + if label != 'FaceShifter': + mask_paths = os.path.join(dataset_path, 'manipulated_sequences', label, 'c23','masks', video_name) + # mask is all the same for all compression levels + if os.path.exists(mask_paths): + mask_frames_paths = [os.path.join(mask_paths, frame.name) for frame in os.scandir(mask_paths)] + else: + mask_frames_paths = [] + try: + mode = video_to_mode[video_name] + dataset_dict['FaceForensics++'][ff_dict[label]][mode][compression_level][video_name] = {'label': ff_dict[label], 'frames': frame_paths, 'masks': mask_frames_paths} + # DeepfakeDetection dataset + except: + dataset_dict['FaceForensics++'][ff_dict[label]]['train'][compression_level][video_name] = {'label': ff_dict[label], 'frames': frame_paths, 'masks': mask_frames_paths} + dataset_dict['FaceForensics++'][ff_dict[label]]['val'][compression_level][video_name] = {'label': ff_dict[label], 'frames': frame_paths, 'masks': mask_frames_paths} + dataset_dict['FaceForensics++'][ff_dict[label]]['test'][compression_level][video_name] = {'label': ff_dict[label], 'frames': frame_paths, 'masks': mask_frames_paths} + # FaceShifter dataset + else: + mode = video_to_mode[video_name] + dataset_dict['FaceForensics++'][ff_dict[label]][mode][compression_level][video_name] = {'label': ff_dict[label], 'frames': frame_paths} + + + # get the DeepfakeDetection dataset from FaceForensics++ dataset + if dataset_name == 'FaceForensics++': + # Delete the DeepfakeDetection dataset from FaceForensics++ dataset + del dataset_dict['FaceForensics++']['DFD_fake'] + del dataset_dict['FaceForensics++']['DFD_real'] + del dataset_dict['FaceForensics++']['FF-FH'] + elif dataset_name == 'DeepFakeDetection': + # Check if the DeepfakeDetection dataset is in the FaceForensics++ dataset + if 'DFD_fake' in dataset_dict['FaceForensics++'] and \ + 'DFD_real' in dataset_dict['FaceForensics++']: + # Add the DeepfakeDetection dataset to the dataset_dict + dataset_dict['DeepFakeDetection'] = { + 'DFD_fake': dataset_dict['FaceForensics++']['DFD_fake'], + 'DFD_real': dataset_dict['FaceForensics++']['DFD_real'] + } + del dataset_dict['FaceForensics++'] + elif dataset_name == 'FaceShifter': + if 'FF-FH' in dataset_dict['FaceForensics++'] and \ + 'FF-real' in dataset_dict['FaceForensics++']: + # Add the DeepfakeDetection dataset to the dataset_dict + dataset_dict['FaceShifter'] = { + 'FF-FH': dataset_dict['FaceForensics++']['FF-FH'], + 'FF-real': dataset_dict['FaceForensics++']['FF-real'] + } + del dataset_dict['FaceForensics++'] + else: + # TODO + raise ValueError('DeepfakeDetection dataset not found in FaceForensics++ dataset.') + else: + raise ValueError('Invalid dataset name: {}'.format(dataset_name)) + + # if FaceForensics++, based on label and generate the json + if dataset_name == 'FaceForensics++': + for label, value in dataset_dict['FaceForensics++'].items(): + if label != 'FF-real': + with open(os.path.join(output_file_path,f'{label}.json'), 'w') as f: + data = {label: {'FF-real': dataset_dict['FaceForensics++']['FF-real'], + label: value, + }} + json.dump(data, f) + print(f"Finish writing {label}.json") + + ## Celeb-DF-v1 dataset + ## Note: videos in Celeb-DF-v1/2 are not in the same format as in FaceForensics++ dataset + elif dataset_name == 'Celeb-DF-v1': + dataset_path = os.path.join(dataset_root_path, dataset_name) + dataset_dict[dataset_name] = {} + for folder in os.scandir(dataset_path): + if not os.path.isdir(folder): + continue + if folder.name in ['Celeb-real', 'YouTube-real']: + label = 'CelebDFv1_real' + else: + label = 'CelebDFv1_fake' + assert label in ['CelebDFv1_real', 'CelebDFv1_fake'], 'Invalid label: {}'.format(label) + dataset_dict[dataset_name][label] = {} + dataset_dict[dataset_name][label]['train'] = {} + dataset_dict[dataset_name][label]['val'] = {} + dataset_dict[dataset_name][label]['test'] = {} + for video_path in os.scandir(os.path.join(dataset_path, folder.name, 'frames')): + if video_path.is_dir(): + video_name = video_path.name + frame_paths = [os.path.join(video_path, frame.name) for frame in os.scandir(video_path)] + dataset_dict[dataset_name][label]['train'][video_name] = {'label': label, 'frames': frame_paths} + + # Special case for test&val data of Celeb-DF-v1/2 + with open(os.path.join(dataset_root_path, dataset_name, 'List_of_testing_videos.txt'), 'r') as f: + lines = f.readlines() + for line in lines: + if 'real' in line: + label = 'CelebDFv1_real' + elif 'synthesis' in line: + label = 'CelebDFv1_fake' + else: + raise ValueError(f"wrong in processing vidname {dataset_name}: {line}") + + vidname = line.split('\n')[0].split('/')[-1].split('.mp4')[0] + frame_paths = glob.glob( + os.path.join(dataset_root_path, dataset_name, line.split(' ')[1].split('/')[0], 'frames', vidname, '*png')) + dataset_dict[dataset_name][label]['test'][vidname] = {'label': label, 'frames': frame_paths} + dataset_dict[dataset_name][label]['val'][vidname] = {'label': label, 'frames': frame_paths} + + ## Celeb-DF-v2 dataset + ## Note: videos in Celeb-DF-v1/2 are not in the same format as in FaceForensics++ dataset + elif dataset_name == 'Celeb-DF-v2': + dataset_path = os.path.join(dataset_root_path, dataset_name) + dataset_dict[dataset_name] = {} + for folder in os.scandir(dataset_path): + if not os.path.isdir(folder): + continue + if folder.name in ['Celeb-real', 'YouTube-real']: + label = 'CelebDFv2_real' + else: + label = 'CelebDFv2_fake' + assert label in ['CelebDFv2_real', 'CelebDFv2_fake'], 'Invalid label: {}'.format(label) + dataset_dict[dataset_name][label] = {} + dataset_dict[dataset_name][label]['train'] = {} + dataset_dict[dataset_name][label]['val'] = {} + dataset_dict[dataset_name][label]['test'] = {} + for video_path in os.scandir(os.path.join(dataset_path, folder.name, 'frames')): + if video_path.is_dir(): + video_name = video_path.name + frame_paths = [os.path.join(video_path, frame.name) for frame in os.scandir(video_path)] + dataset_dict[dataset_name][label]['train'][video_name] = {'label': label, 'frames': frame_paths} + + # Special case for test&val data of Celeb-DF-v1/2 + with open(os.path.join(dataset_root_path, dataset_name, 'List_of_testing_videos.txt'), 'r') as f: + lines = f.readlines() + for line in lines: + if 'real' in line: + label = 'CelebDFv2_real' + elif 'synthesis' in line: + label = 'CelebDFv2_fake' + else: + raise ValueError(f"wrong in processing vidname {dataset_name}: {line}") + + vidname = line.split('\n')[0].split('/')[-1].split('.mp4')[0] + frame_paths = glob.glob( + os.path.join(dataset_root_path, dataset_name, line.split(' ')[1].split('/')[0], 'frames', vidname, '*png')) + dataset_dict[dataset_name][label]['test'][vidname] = {'label': label, 'frames': frame_paths} + dataset_dict[dataset_name][label]['val'][vidname] = {'label': label, 'frames': frame_paths} + + ## DFDCP dataset + elif dataset_name == 'DFDCP': + dataset_path = os.path.join(dataset_root_path, dataset_name) + #initialize the dataset dictionary + dataset_dict[dataset_name] = {'DFDCP_Real': {'train': {}, 'test': {}, 'val': {}}, + 'DFDCP_FakeA': {'train': {}, 'test': {}, 'val': {}}, + 'DFDCP_FakeB': {'train': {}, 'test': {}, 'val': {}}} + # Open the dataset information file ('dataset.json') and parse its contents + with open(os.path.join(dataset_path, 'dataset.json' ), 'r') as f: + dataset_info = json.load(f) + # Iterate over the dataset_info dictionary and extract the index and file name for each video + for dataset in dataset_info.keys(): + index = dataset.split('/')[0] + vidname = dataset.split('/')[-1].split(".")[0] + if Path(os.path.join(dataset_path, index, 'frames', vidname)).exists(): + frame_paths = glob.glob(os.path.join(dataset_path, index, 'frames', vidname, '*png')) + if len(frame_paths) == 0: + continue + label = dataset_info[dataset]['label'] + if label == 'real': + label = 'DFDCP_Real' + elif label == 'fake' and index == 'method_A': + label = 'DFDCP_FakeA' + elif label == 'fake' and index == 'method_B': + label = 'DFDCP_FakeB' + else: + raise ValueError(f"wrong in processing vidname {dataset_name}: {line}") + set_attr = dataset_info[dataset]['set'] # train, test, val + dataset_dict[dataset_name][label][set_attr][vidname] = {'label': label, 'frames': frame_paths} + # Special case for val data of DFDCP + for label in ['DFDCP_Real', 'DFDCP_FakeA', 'DFDCP_FakeB']: + dataset_dict[dataset_name][label]['val'] = dataset_dict[dataset_name][label]['test'] + + ## DFDC dataset + elif dataset_name == 'DFDC': + dataset_path = os.path.join(dataset_root_path, dataset_name) + dataset_dict[dataset_name] = {'DFDC_Real': {'train': {}, 'test': {}, 'val': {}}, + 'DFDC_Fake': {'train': {}, 'test': {}, 'val': {}}} + for folder in os.scandir(dataset_path): + if not os.path.isdir(folder): + continue + if folder.name in ['test']: + # read csv file + df = pd.read_csv(os.path.join(dataset_path,folder.name,'labels.csv')) + labels = ['DFDC_Real','DFDC_Fake'] + + for index, row in df.iterrows(): + vidname = row['filename'].split('.mp4')[0] + label = labels[row['label']] + assert label in ['DFDC_Real','DFDC_Fake'], 'Invalid label: {}'.format(label) + frame_paths = glob.glob(os.path.join(dataset_path, folder.name,'frames', vidname, '*png')) + if len(frame_paths) == 0: + continue + dataset_dict[dataset_name][label]['test'][vidname] = {'label': label, 'frames': frame_paths} + dataset_dict[dataset_name][label]['val'] = {'label': label, 'frames': frame_paths} + + elif folder.name in ['train']: + num_file = 0 + for dfdc_train_part in os.scandir(os.path.join(dataset_path, folder.name)): + if not os.path.isdir(dfdc_train_part): + continue + num_file += 1 + print('processing {}th file in 50 files.'.format(num_file)) + with open(os.path.join(dfdc_train_part, 'metadata.json'), 'r') as f: + metadata = json.load(f) + for video_path in os.scandir(os.path.join(dfdc_train_part, 'frames')): + if video_path.is_dir(): + video_name = video_path.name + label = metadata[video_name + ".mp4"]["label"] + assert label in ['REAL', 'FAKE'], 'Invalid label: {}'.format(label) + if label == 'REAL': + label = 'DFDC_Real' + else: + label = 'DFDC_Fake' + frame_paths = [os.path.join(video_path, frame.name) for frame in os.scandir(video_path)] + dataset_dict[dataset_name][label]['train'][video_name] = {'label': label, 'frames': frame_paths} + dataset_dict[dataset_name][label]['val'][video_name] = {'label': label, 'frames': frame_paths} + + ## DeeperForensics-1.0 dataset + elif dataset_name == 'DeeperForensics-1.0': + with open(os.path.join(dataset_root_path, dataset_name, 'lists/splits/train.txt'), 'r') as f: + train_txt = f.readlines() + train_txt = [line.strip().split('.')[0] for line in train_txt] + with open(os.path.join(dataset_root_path, dataset_name, 'lists/splits/test.txt'), 'r') as f: + test_txt = f.readlines() + test_txt = [line.strip().split('.')[0] for line in test_txt] + with open(os.path.join(dataset_root_path, dataset_name, 'lists/splits/val.txt'), 'r') as f: + val_txt = f.readlines() + val_txt = [line.strip().split('.')[0] for line in val_txt] + dataset_path = os.path.join(dataset_root_path, dataset_name) + dataset_dict[dataset_name] = {'DF_real': {'train': {}, 'test': {}, 'val': {}}, + 'DF_fake': {'train': {}, 'test': {}, 'val': {}}} + if not Path(os.path.join(dataset_path, 'manipulated_videos', perturbation)).exists(): + raise ValueError(f"wrong in processing perturbation {perturbation} in manipulated_videos") + print(f"processing perturbation {perturbation} in manipulated_videos") + for video_path in os.scandir(os.path.join(dataset_path, 'manipulated_videos', perturbation, 'frames')): + if video_path.is_dir(): + video_name = video_path.name + if video_name in train_txt: + set_attr = 'train' + elif video_name in test_txt: + set_attr = 'test' + elif video_name in val_txt: + set_attr = 'val' + else: + raise ValueError(f"wrong in processing vidname {dataset_name}: {line}") + label = 'DF_fake' + frame_paths = [os.path.join(video_path, frame.name) for frame in os.scandir(video_path)] + ## if frame image in frame_paths is not the correct png, skip this frame yxh + for frame_path in frame_paths: + if cv2.imread(frame_path) is None: + frame_paths.remove(frame_path) + dataset_dict[dataset_name][label][set_attr][video_name] = {'label': label, 'frames': frame_paths} + for actor_path in os.scandir(os.path.join(dataset_path, 'source_videos')): + print("actor",actor_path.name) + if not os.path.isdir(actor_path): + continue + label = 'DF_real' + video_paths = [os.path.join(actor_path, 'frames', video.name) for video in os.scandir(os.path.join(actor_path, 'frames'))] + for video_path in video_paths: + video_name = video_path.split('/')[-1] + frame_paths = [os.path.join(video_path, frame.name) for frame in os.scandir(video_path)] + ## if frame image in frame_paths is not the correct png, skip this frame yxh + for frame_path in frame_paths: + if cv2.imread(frame_path) is None: + frame_paths.remove(frame_path) + dataset_dict[dataset_name][label]['train'][video_name] = {'label': label, 'frames': frame_paths} + dataset_dict[dataset_name][label]['test'][video_name] = {'label': label, 'frames': frame_paths} + dataset_dict[dataset_name][label]['val'][video_name] = {'label': label, 'frames': frame_paths} + + ## UADFV dataset + elif dataset_name == 'UADFV': + dataset_path = os.path.join(dataset_root_path, dataset_name) + dataset_dict[dataset_name] = {'UADFV_Real': {'train': {}, 'test': {}, 'val': {}}, + 'UADFV_Fake': {'train': {}, 'test': {}, 'val': {}}} + for folder in os.scandir(dataset_path): + if not os.path.isdir(folder): + continue + elif folder.name in ['fake']: + for video_path in os.scandir(os.path.join(dataset_path, folder.name, 'frames')): + if video_path.is_dir(): + video_name = video_path.name + label = 'UADFV_Fake' + frame_paths = [os.path.join(video_path, frame.name) for frame in os.scandir(video_path)] + dataset_dict[dataset_name][label]['train'][video_name] = {'label': label, 'frames': frame_paths} + dataset_dict[dataset_name][label]['test'][video_name] = {'label': label, 'frames': frame_paths} + dataset_dict[dataset_name][label]['val'][video_name] = {'label': label, 'frames': frame_paths} + elif folder.name in ['real']: + for video_path in os.scandir(os.path.join(dataset_path, folder.name, 'frames')): + if video_path.is_dir(): + video_name = video_path.name + label = 'UADFV_Real' + frame_paths = [os.path.join(video_path, frame.name) for frame in os.scandir(video_path)] + dataset_dict[dataset_name][label]['train'][video_name] = {'label': label, 'frames': frame_paths} + dataset_dict[dataset_name][label]['test'][video_name] = {'label': label, 'frames': frame_paths} + dataset_dict[dataset_name][label]['val'][video_name] = {'label': label, 'frames': frame_paths} + + # Convert the dataset dictionary to JSON format and save to file + output_file_path = os.path.join(output_file_path, dataset_name + '.json') + with open(output_file_path, 'w') as f: + json.dump(dataset_dict, f) + # print the successfully generated dataset dictionary + print(f"{dataset_name}.json generated successfully.") + +if __name__ == '__main__': + # from config.yaml load parameters + yaml_path = './config.yaml' + # open the yaml file + try: + with open(yaml_path, 'r') as f: + config = yaml.safe_load(f) + except yaml.parser.ParserError as e: + print("YAML file parsing error:", e) + + dataset_name = config['rearrange']['dataset_name']['default'] + dataset_root_path = config['rearrange']['dataset_root_path']['default'] + output_file_path = config['rearrange']['output_file_path']['default'] + comp = config['rearrange']['comp']['default'] + perturbation = config['rearrange']['perturbation']['default'] + # Call the generate_dataset_file function + generate_dataset_file(dataset_name, dataset_root_path, output_file_path, comp, perturbation) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1b7e20348257022372d65a8fad3bcaa223dbffb4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,117 @@ +absl-py==2.3.0 +accelerate==1.12.0 +albumentations==1.1.0 +cachetools==5.5.2 +certifi==2025.4.26 +charset-normalizer==3.4.2 +clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1 +contourpy==1.3.2 +cycler==0.12.1 +dlib==19.24.0 +efficientnet_pytorch==0.7.1 +einops==0.8.1 +et_xmlfile==2.0.0 +filelock==3.18.0 +filterpy==1.4.5 +fonttools==4.58.1 +fsspec==2025.5.1 +ftfy==6.3.1 +fvcore==0.1.5.post20221221 +google-auth==2.40.2 +google-auth-oauthlib==0.4.6 +grpcio==1.72.1 +hf-xet==1.1.3 +huggingface-hub==0.36.0 +idna==3.10 +imageio==2.37.0 +imgaug==0.4.0 +imutils==0.5.4 +iopath==0.1.10 +Jinja2==3.1.6 +joblib==1.5.1 +kiwisolver==1.4.8 +kornia==0.8.1 +kornia_rs==0.1.9 +lazy_loader==0.4 +lightning-utilities==0.15.2 +lmdb==1.6.2 +loralib==0.1.2 +Markdown==3.8 +MarkupSafe==3.0.2 +matplotlib==3.10.3 +mpmath==1.3.0 +munch==4.0.0 +networkx==3.4.2 +numpy==1.26.0 +nvidia-cublas-cu11==11.11.3.6 +nvidia-cuda-cupti-cu11==11.8.87 +nvidia-cuda-nvrtc-cu11==11.8.89 +nvidia-cuda-runtime-cu11==11.8.89 +nvidia-cudnn-cu11==9.1.0.70 +nvidia-cufft-cu11==10.9.0.58 +nvidia-curand-cu11==10.3.0.86 +nvidia-cusolver-cu11==11.4.1.48 +nvidia-cusparse-cu11==11.7.5.86 +nvidia-ml-py==12.575.51 +nvidia-nccl-cu11==2.21.5 +nvidia-nvtx-cu11==11.8.86 +nvitop==1.5.1 +oauthlib==3.2.2 +opencv-python==4.6.0.66 +opencv-python-headless==4.11.0.86 +openpyxl==3.1.5 +packaging==25.0 +pandas==2.2.3 +peft==0.18.0 +pillow==11.2.1 +portalocker==3.1.1 +pretrainedmodels==0.7.4 +prettytable==3.16.0 +protobuf==3.19.6 +psutil==7.0.0 +pyarrow==20.0.0 +pyasn1==0.6.1 +pyasn1_modules==0.4.2 +pyparsing==3.2.3 +python-dateutil==2.9.0.post0 +pytz==2025.2 +PyWavelets==1.8.0 +PyYAML==6.0 +qudida==0.0.4 +regex==2024.11.6 +requests==2.32.3 +requests-oauthlib==2.0.0 +rsa==4.9.1 +safetensors==0.5.3 +scikit-image==0.25.2 +scikit-learn==1.6.1 +scipy==1.15.3 +seaborn==0.13.2 +segmentation-models-pytorch==0.3.2 +shapely==2.1.1 +simplejson==3.20.1 +six==1.17.0 +sympy==1.14.0 +tabulate==0.9.0 +tensorboard==2.19.0 +tensorboard-data-server==0.7.2 +tensorboard-plugin-wit==1.8.1 +termcolor==3.1.0 +threadpoolctl==3.6.0 +tifffile==2025.5.10 +timm==1.0.24 +tokenizers==0.22.1 +torch==2.7.1+cu118 +torchaudio==2.7.1+cu118 +torchmetrics==1.8.2 +torchtoolbox==0.1.8.2 +torchvision==0.22.1+cu118 +tqdm==4.61.0 +transformers==4.57.3 +triton==3.3.1 +typing_extensions==4.14.0 +tzdata==2025.2 +urllib3==2.4.0 +wcwidth==0.2.13 +Werkzeug==3.1.3 +yacs==0.1.8 diff --git a/test.sh b/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..a89f92da66c5fbae3c0e7b3d1b41aebec40a19f6 --- /dev/null +++ b/test.sh @@ -0,0 +1,13 @@ +#### clip_large_fft inference and evaluation on protocol2&3 +python -m torch.distributed.launch --master_port=29510 --nproc_per_node=8 training/test_pall.py --ddp \ + --test_dataset "protocol_2_test" "protocol_3_test" \ + --detector_path ./training/config/detector/clip_large_fft.yaml \ + --weights_path logs/clip_models/clip_large_fft_2025-11-08-13-56-51 + + +#### clip_large_fft inference and evaluation on protocol4 +python -m torch.distributed.launch --master_port=29512 --nproc_per_node=8 training/test_pall.py --ddp \ + --test_dataset "protocol_4_test" \ + --detector_path ./training/config/detector/clip_large_fft.yaml \ + --weights_path logs/clip_models/clip_large_fft_2025-11-08-13-56-51 \ + --test_config test_config_p4.yaml diff --git a/train.sh b/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..8099aa5885f09ebaac2921b91326f14218e0ac40 --- /dev/null +++ b/train.sh @@ -0,0 +1,5 @@ +python -m torch.distributed.launch --master_port=29505 --nproc_per_node=8 training/train.py \ + --detector_path ./training/config/detector/clip_large_fft.yaml \ + --no-save_feat --ddp + + diff --git a/training/config/__init__.py b/training/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..676145d777810e4a51bdaf59fdec4f5358aae349 --- /dev/null +++ b/training/config/__init__.py @@ -0,0 +1,7 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) diff --git a/training/config/backbone/cls_hrnet_w48.yaml b/training/config/backbone/cls_hrnet_w48.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26cfa87b7e41e12a34c27ee30649a926652fdcb7 --- /dev/null +++ b/training/config/backbone/cls_hrnet_w48.yaml @@ -0,0 +1,103 @@ +CUDNN: + BENCHMARK: true + DETERMINISTIC: false + ENABLED: true +GPUS: (0,1,2,3) +OUTPUT_DIR: 'output' +LOG_DIR: 'log' +WORKERS: 4 +PRINT_FREQ: 100 + +DATASET: + DATASET: lip + ROOT: 'data/' + TEST_SET: 'list/lip/valList.txt' + TRAIN_SET: 'list/lip/trainList.txt' + NUM_CLASSES: 20 +MODEL: + NAME: cls_hrnet + #IMAGE_SIZE: + # - 224 + # - 224 + EXTRA: + STAGE1: + NUM_MODULES: 1 + NUM_RANCHES: 1 + BLOCK: BOTTLENECK + NUM_BLOCKS: + - 4 + NUM_CHANNELS: + - 64 + FUSE_METHOD: SUM + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + - 192 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + - 192 + - 384 + FUSE_METHOD: SUM +LOSS: + USE_OHEM: false + OHEMTHRES: 0.9 + OHEMKEEP: 131072 +TRAIN: + IMAGE_SIZE: + - 473 + - 473 + BASE_SIZE: 473 + BATCH_SIZE_PER_GPU: 10 + SHUFFLE: true + BEGIN_EPOCH: 0 + END_EPOCH: 150 + RESUME: true + OPTIMIZER: sgd + LR: 0.007 + WD: 0.0005 + MOMENTUM: 0.9 + NESTEROV: false + FLIP: true + MULTI_SCALE: true + DOWNSAMPLERATE: 1 + IGNORE_LABEL: 255 + SCALE_FACTOR: 11 +TEST: + IMAGE_SIZE: + - 473 + - 473 + BASE_SIZE: 473 + BATCH_SIZE_PER_GPU: 16 + NUM_SAMPLES: 2000 + FLIP_TEST: false + MULTI_SCALE: false diff --git a/training/config/config/__init__.py b/training/config/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..379ba40bcdf97ea6dc4fca3dc8215da42b68cb31 --- /dev/null +++ b/training/config/config/__init__.py @@ -0,0 +1,7 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) diff --git a/training/config/config/backbone/cls_hrnet_w48.yaml b/training/config/config/backbone/cls_hrnet_w48.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b273cf5cada1c13e6b4c91c1f6b109c7dc0dd57a --- /dev/null +++ b/training/config/config/backbone/cls_hrnet_w48.yaml @@ -0,0 +1,103 @@ +CUDNN: + BENCHMARK: true + DETERMINISTIC: false + ENABLED: true +GPUS: (0,1,2,3) +OUTPUT_DIR: 'output' +LOG_DIR: 'log' +WORKERS: 4 +PRINT_FREQ: 100 + +DATASET: + DATASET: lip + ROOT: 'data/' + TEST_SET: 'list/lip/valList.txt' + TRAIN_SET: 'list/lip/trainList.txt' + NUM_CLASSES: 20 +MODEL: + NAME: cls_hrnet + #IMAGE_SIZE: + # - 224 + # - 224 + EXTRA: + STAGE1: + NUM_MODULES: 1 + NUM_RANCHES: 1 + BLOCK: BOTTLENECK + NUM_BLOCKS: + - 4 + NUM_CHANNELS: + - 64 + FUSE_METHOD: SUM + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + - 192 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + - 192 + - 384 + FUSE_METHOD: SUM +LOSS: + USE_OHEM: false + OHEMTHRES: 0.9 + OHEMKEEP: 131072 +TRAIN: + IMAGE_SIZE: + - 473 + - 473 + BASE_SIZE: 473 + BATCH_SIZE_PER_GPU: 10 + SHUFFLE: true + BEGIN_EPOCH: 0 + END_EPOCH: 150 + RESUME: true + OPTIMIZER: sgd + LR: 0.007 + WD: 0.0005 + MOMENTUM: 0.9 + NESTEROV: false + FLIP: true + MULTI_SCALE: true + DOWNSAMPLERATE: 1 + IGNORE_LABEL: 255 + SCALE_FACTOR: 11 +TEST: + IMAGE_SIZE: + - 473 + - 473 + BASE_SIZE: 473 + BATCH_SIZE_PER_GPU: 16 + NUM_SAMPLES: 2000 + FLIP_TEST: false + MULTI_SCALE: false diff --git a/training/config/config/test_config.yaml b/training/config/config/test_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..336b5cb14dcc03e4e94193bfe6e729ecc9c3de7a --- /dev/null +++ b/training/config/config/test_config.yaml @@ -0,0 +1,38 @@ +mode: test +lmdb: True +dataset_root_rgb: './datasets' +lmdb_dir: 'I:\transform_2_lmdb' +dataset_json_folder: 'preprocessing/dataset_json_v3' +label_dict: + # DFD + DFD_fake: 1 + DFD_real: 0 + # FF++ + FaceShifter(FF-real+FF-FH) + FF-SH: 1 + FF-F2F: 1 + FF-DF: 1 + FF-FS: 1 + FF-NT: 1 + FF-FH: 1 + FF-real: 0 + # CelebDF + CelebDFv1_real: 0 + CelebDFv1_fake: 1 + CelebDFv2_real: 0 + CelebDFv2_fake: 1 + # DFDCP + DFDCP_Real: 0 + DFDCP_FakeA: 1 + DFDCP_FakeB: 1 + # DFDC + DFDC_Fake: 1 + DFDC_Real: 0 + # DeeperForensics-1.0 + DF_fake: 1 + DF_real: 0 + # UADFV + UADFV_Fake: 1 + UADFV_Real: 0 + # Roop + roop_Real: 0 + roop_Fake: 1 \ No newline at end of file diff --git a/training/config/config/train_config.yaml b/training/config/config/train_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eab0431f9ef9c3e880b5947d4b8d0d863d12a29c --- /dev/null +++ b/training/config/config/train_config.yaml @@ -0,0 +1,43 @@ +mode: train +lmdb: True +dry_run: false +dataset_root_rgb: './datasets' +lmdb_dir: 'I:\transform_2_lmdb' +dataset_json_folder: '/data/home/zhiyuanyan/DeepfakeBenchv2/preprocessing/dataset_json' +SWA: False +save_avg: True +# log_dir: ./logs/training/ +# label settings +label_dict: + # DFD + DFD_fake: 1 + DFD_real: 0 + # FF++ + FaceShifter(FF-real+FF-FH) + FF-SH: 1 + FF-F2F: 1 + FF-DF: 1 + FF-FS: 1 + FF-NT: 1 + FF-FH: 1 + FF-real: 0 + # CelebDF + CelebDFv1_real: 0 + CelebDFv1_fake: 1 + CelebDFv2_real: 0 + CelebDFv2_fake: 1 + # DFDCP + DFDCP_Real: 0 + DFDCP_FakeA: 1 + DFDCP_FakeB: 1 + # DFDC + DFDC_Fake: 1 + DFDC_Real: 0 + # DeeperForensics-1.0 + DF_fake: 1 + DF_real: 0 + # UADFV + UADFV_Fake: 1 + UADFV_Real: 0 + # Roop + roop_Real: 0 + roop_Fake: 1 \ No newline at end of file diff --git a/training/config/detector/ae.yaml b/training/config/detector/ae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..71aec9988ef80377fc7c408a01ea42f9c3f0aff0 --- /dev/null +++ b/training/config/detector/ae.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/ae_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: ae_detector # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/ae_resnet.yaml b/training/config/detector/ae_resnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0868b3d2f2db9d23ed1b220cca45581ecebe0faa --- /dev/null +++ b/training/config/detector/ae_resnet.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/ae_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: resnet34_ae_trace # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/altfreezing.yaml b/training/config/detector/altfreezing.yaml new file mode 100644 index 0000000000000000000000000000000000000000..101ae45beea838feee58658711464c14e439c55f --- /dev/null +++ b/training/config/detector/altfreezing.yaml @@ -0,0 +1,87 @@ +# log dir +log_dir: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/logs/altfreezing + +# model setting +pretrained: training/weights/I3D_8x8_R50.pth # path to a pre-trained model, if using one +model_name: altfreezing # model name +backbone_name: null # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 1 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, DeepFakeDetection, FaceShifter] + +compression: c23 # compression-level for videos +train_batchSize: 8 # training batch size +test_batchSize: 8 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 300, 'test': 300} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 16 # number of frames in each clip + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 100 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/capsule_net.yaml b/training/config/detector/capsule_net.yaml new file mode 100644 index 0000000000000000000000000000000000000000..54defc7aedd4b744042ca77061e34335fadf3bcc --- /dev/null +++ b/training/config/detector/capsule_net.yaml @@ -0,0 +1,86 @@ +# log dir +log_dir: /mntcephfs/lab_data/yuanxinhang/benchmark_results/logs_final/capsule_new + +# model setting +pretrained: ./training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +model_name: capsule_net # model name +backbone_name: xception # no backbone VGGextractor +num_classes: 2 + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++, Celeb-DF-v2, DFDCP, DeepFakeDetection] +test_dataset: [FaceForensics++, DeepFakeDetection, Celeb-DF-v2, DFDCP] + +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +compression: c23 # compression-level for videos + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00008 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: capsule_loss # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/clip.yaml b/training/config/detector/clip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ebf015dc16e801078d1c63c61c4fba7a10daf84a --- /dev/null +++ b/training/config/detector/clip.yaml @@ -0,0 +1,86 @@ +# log dir +log_dir: /data/home/zhiyuanyan/DeepfakeBench/logs_debug/clip + +# model setting +pretrained: openai/clip-vit-base-patch16 # path to a pre-trained model, if using one +model_name: clip # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, DeepFakeDetection, FaceShifter, DFDC, DFDCP] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 32} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 100 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/clip_adapter_two_3dconv.yaml b/training/config/detector/clip_adapter_two_3dconv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b0aa826352c575d4a3644f39026cd8ad4b8d37a6 --- /dev/null +++ b/training/config/detector/clip_adapter_two_3dconv.yaml @@ -0,0 +1,86 @@ +# log dir +log_dir: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/logs/clip_adapter_two_3dconv + +# model setting +pretrained: ViT-L14 # path to a pre-trained model, if using one +model_name: clip_adapter_two_3dconv # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, FaceShifter, DeepFakeDetection] + +compression: c23 # compression-level for videos +train_batchSize: 16 # training batch size +test_batchSize: 16 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 32} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 8 # number of frames in each clip + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/clip_base_fft.yaml b/training/config/detector/clip_base_fft.yaml new file mode 100644 index 0000000000000000000000000000000000000000..322adc20409b2bcafae8eb1930fee583ba268982 --- /dev/null +++ b/training/config/detector/clip_base_fft.yaml @@ -0,0 +1,85 @@ +# log dir +log_dir: logs/clip_models + +# model setting +pretrained: openai/clip-vit-base-patch16 # path to a pre-trained model, if using one +model_name: clip_base_fft # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/clip_base_vid.yaml b/training/config/detector/clip_base_vid.yaml new file mode 100644 index 0000000000000000000000000000000000000000..11a8ff3fccd2e228d83d79ed835c3bffa5a480b8 --- /dev/null +++ b/training/config/detector/clip_base_vid.yaml @@ -0,0 +1,86 @@ +# log dir +log_dir: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/logs/clip_base_vid + +# model setting +pretrained: openai/clip-vit-base-patch16 # path to a pre-trained model, if using one +model_name: clip_base_vid # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, FaceShifter, DeepFakeDetection] + +compression: c23 # compression-level for videos +train_batchSize: 16 # training batch size +test_batchSize: 16 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 32} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 8 # number of frames in each clip + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/clip_contrast.yaml b/training/config/detector/clip_contrast.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a556edbdca3e91329d07354de1bcdce88f064e9a --- /dev/null +++ b/training/config/detector/clip_contrast.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/clip_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_contrast # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/clip_contrast_hier.yaml b/training/config/detector/clip_contrast_hier.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1c18625437b074fbf0f8816478902918a72217d1 --- /dev/null +++ b/training/config/detector/clip_contrast_hier.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/clip_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_contrast_hier # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/clip_image.yaml b/training/config/detector/clip_image.yaml new file mode 100644 index 0000000000000000000000000000000000000000..793632817ba193743299ef9fe87ed60dfadb0021 --- /dev/null +++ b/training/config/detector/clip_image.yaml @@ -0,0 +1,86 @@ +log_dir: logs/clip_models + +# model setting +model_name: clip_image # Must match the name registered in `DETECTOR.register_module` +clip_model_name: ViT-B/32 # CLIP model variant (e.g. ViT-B/32 or ViT-L/14) +pretrained: null # Use the built-in CLIP pretrained weights; no extra path is required + +# backbone setting (framework-compatible format) +backbone_name: clip_image # Identify the backbone as the CLIP image encoder +backbone_config: + num_classes: 36 # Number of classes for multi-class classification + inc: 3 # Number of image input channels (RGB) + mode: original # Framework-compatible field + +# classifier config (linear classifier configuration) +classifier_config: + hidden_size_list: [512, 256] # Hidden layer dimensions (matching the original implementation) + num_classes: 36 # Number of classes for multi-class classification (must match the task) + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] +save_ckpt: true # Whether to save model checkpoints +save_feat: true # Whether to save features + +# Data parameters (no text-related settings) +compression: c23 # Video compression level +train_batchSize: 32 # Training batch size (matching the original implementation) +test_batchSize: 32 # Test batch size +workers: 8 # Number of data loading workers +frame_num: {'train': 8, 'test': 16} # Number of frames sampled from each video +resolution: 224 # Input resolution (CLIP defaults to 224) +balance_data: false # Whether to balance real/fake samples in the training set +with_mask: false # No mask input is required +with_landmark: false # No facial landmark input is required + +# Data augmentation +use_data_augmentation: false +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# Normalization parameters (matching the original implementation) +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# Optimizer configuration (matching the original implementation) +optimizer: + type: adam + adam: + lr: 0.0003 # Learning rate (3e-4 in the original implementation) + beta1: 0.9 + beta2: 0.999 + eps: 0.00000001 + weight_decay: 0.0005 + amsgrad: false + +# Training configuration +lr_scheduler: null # No learning rate scheduler +nEpochs: 50 # Number of training epochs (matching the original implementation) +start_epoch: 0 # Starting epoch +save_epoch: 1 # Save a checkpoint every n epochs +rec_iter: 100 # Log once every n steps +logdir: ./logs # Log directory +manualSeed: 1024 # Random seed +save_latest_ckpt: true # Save the latest checkpoint + +# Loss function (multi-class cross-entropy) +loss_func: cross_entropy +losstype: null + +# Evaluation metric +metric_scoring: acc # Primary evaluation metric (accuracy) + +# Hardware configuration +cuda: true # Use CUDA +cudnn: true diff --git a/training/config/detector/clip_large.yaml b/training/config/detector/clip_large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..949e3a9f772f1c0088dd1f0b60c8ac74a778794a --- /dev/null +++ b/training/config/detector/clip_large.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/clip_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/clip_large_fft.yaml b/training/config/detector/clip_large_fft.yaml new file mode 100644 index 0000000000000000000000000000000000000000..71ab0d5e2c273b67cd9cd86be2da351f42e978d2 --- /dev/null +++ b/training/config/detector/clip_large_fft.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/clip_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_fft # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/clip_large_fft_dino_orth.yaml b/training/config/detector/clip_large_fft_dino_orth.yaml new file mode 100644 index 0000000000000000000000000000000000000000..554ddd5485dea1dac4de96844e3284178bbc5f50 --- /dev/null +++ b/training/config/detector/clip_large_fft_dino_orth.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/dis_models + +# model setting +pretrained: dinov2_vitl14 # path to a pre-trained model, if using one +model_name: clip_large_fft_dino_orth # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 128 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.001 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy_orth1 # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/clip_large_fft_dis.yaml b/training/config/detector/clip_large_fft_dis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1c75052feceb93edf5b93505394e97c2c1d03ae0 --- /dev/null +++ b/training/config/detector/clip_large_fft_dis.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/dis_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_fft_dis # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/clip_large_fft_dis_cat1.yaml b/training/config/detector/clip_large_fft_dis_cat1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ab223fd5aca99b9cd80716286faa67ee6d8eb64 --- /dev/null +++ b/training/config/detector/clip_large_fft_dis_cat1.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/dis_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_fft_dis_cat1 # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 128 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy_orth # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/clip_large_fft_dis_cat2.yaml b/training/config/detector/clip_large_fft_dis_cat2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bd8ffdb791b690ceaf6bc76a5edc87d5e159a87d --- /dev/null +++ b/training/config/detector/clip_large_fft_dis_cat2.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/dis_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_fft_dis_cat2 # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 128 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy_orth # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/clip_large_fft_dis_orth.yaml b/training/config/detector/clip_large_fft_dis_orth.yaml new file mode 100644 index 0000000000000000000000000000000000000000..093621c0e2e7b11f6643f36718815ca50b8712df --- /dev/null +++ b/training/config/detector/clip_large_fft_dis_orth.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/dis_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_fft_dis_orth # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy_orth # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/clip_large_fft_dis_orth1.yaml b/training/config/detector/clip_large_fft_dis_orth1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dc1abeef349255e727df2794e822e585204b4ff8 --- /dev/null +++ b/training/config/detector/clip_large_fft_dis_orth1.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/dis_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_fft_dis_orth1 # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy_orth1 # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/clip_large_fft_dis_orth2.yaml b/training/config/detector/clip_large_fft_dis_orth2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..192402157d577ec3759ac6b539d41dbf048ee59e --- /dev/null +++ b/training/config/detector/clip_large_fft_dis_orth2.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/dis_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_fft_dis_orth2 # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy_orth2 # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/clip_large_fft_dis_orth3.yaml b/training/config/detector/clip_large_fft_dis_orth3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0f31e5a5e0317ea9348417b918d51f42f1efc97b --- /dev/null +++ b/training/config/detector/clip_large_fft_dis_orth3.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/dis_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_fft_dis_orth3 # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy_orth1 # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/clip_large_fft_dis_vae1.yaml b/training/config/detector/clip_large_fft_dis_vae1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..36e0392ddeb9ffc2c10d00b5cb6a9f4aafa356dd --- /dev/null +++ b/training/config/detector/clip_large_fft_dis_vae1.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/dis_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_fft_vae1 # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 128 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/clip_large_fft_dis_vae2.yaml b/training/config/detector/clip_large_fft_dis_vae2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91451a06c1285461318fcc84613cab92cdf56253 --- /dev/null +++ b/training/config/detector/clip_large_fft_dis_vae2.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/dis_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_fft_vae2 # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 128 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/clip_large_fft_lsda.yaml b/training/config/detector/clip_large_fft_lsda.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9287930ce29f61e404d9bcd3da6336c2b82d8578 --- /dev/null +++ b/training/config/detector/clip_large_fft_lsda.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/clip_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_fft_lsda # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/clip_large_fft_supcon.yaml b/training/config/detector/clip_large_fft_supcon.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9769cddf05a8c3b3070aea1b02e40a30c6205d09 --- /dev/null +++ b/training/config/detector/clip_large_fft_supcon.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/clip_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_fft_supcon # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 128 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 50 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: supcon # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/clip_large_fft_supcon_cls.yaml b/training/config/detector/clip_large_fft_supcon_cls.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a76bba1c564a451f2031b8e39b81a83430d97c44 --- /dev/null +++ b/training/config/detector/clip_large_fft_supcon_cls.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/clip_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_fft_dis_cat1 # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 128 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: supcon_cls # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/clip_large_lora.yaml b/training/config/detector/clip_large_lora.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f6adca145e61dcf4d34a958d428f29ce7c816891 --- /dev/null +++ b/training/config/detector/clip_large_lora.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/clip_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_lora # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 16, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: false # whether to use video-level data +clip_size: 8 # number of frames in each clip + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +save_latest_ckpt: true diff --git a/training/config/detector/clip_large_vid.yaml b/training/config/detector/clip_large_vid.yaml new file mode 100644 index 0000000000000000000000000000000000000000..78427333053463b0c221ddb005190ed29eb3c771 --- /dev/null +++ b/training/config/detector/clip_large_vid.yaml @@ -0,0 +1,86 @@ +# log dir +log_dir: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/logs/clip_large_vid + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_large_vid # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, FaceShifter, DeepFakeDetection] + +compression: c23 # compression-level for videos +train_batchSize: 16 # training batch size +test_batchSize: 16 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 32} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 8 # number of frames in each clip + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/clip_openai_vid.yaml b/training/config/detector/clip_openai_vid.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a02eab53a64434afd7030af12e307ea4b189dc06 --- /dev/null +++ b/training/config/detector/clip_openai_vid.yaml @@ -0,0 +1,86 @@ +# log dir +log_dir: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/logs/clip_openai_vid + +# model setting +pretrained: ViT-L/14 # path to a pre-trained model, if using one +model_name: clip_openai_vid # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, FaceShifter, DeepFakeDetection] + +compression: c23 # compression-level for videos +train_batchSize: 16 # training batch size +test_batchSize: 16 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 32} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 8 # number of frames in each clip + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/clip_patch_shuffle.yaml b/training/config/detector/clip_patch_shuffle.yaml new file mode 100644 index 0000000000000000000000000000000000000000..084841d83e2a36b1ba2554e3a8ad8ae2d2348118 --- /dev/null +++ b/training/config/detector/clip_patch_shuffle.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/clip_models + +# model setting +pretrained: openai/clip-vit-large-patch14 # path to a pre-trained model, if using one +model_name: clip_patch_shuffle # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/clip_videomae.yaml b/training/config/detector/clip_videomae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d1bdd9ffc9e2bfcff06922aedf3bb91b81141d1e --- /dev/null +++ b/training/config/detector/clip_videomae.yaml @@ -0,0 +1,87 @@ +# log dir +log_dir: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/logs/videomae + +# model setting +pretrained_clip: MCG-NJU/videomae-base +pretrained_videomae: MCG-NJU/videomae-base +model_name: clip_videomae # model name +backbone_name: null # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, DeepFakeDetection, FaceShifter, DFDC, DFDCP] + +compression: c23 # compression-level for videos +train_batchSize: 8 # training batch size +test_batchSize: 8 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: true # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 8 # number of frames in each clip + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 100 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/cnn_dct.yaml b/training/config/detector/cnn_dct.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2c4ad422d9fe5e4c25a93ede535cc222b1fa5168 --- /dev/null +++ b/training/config/detector/cnn_dct.yaml @@ -0,0 +1,91 @@ +# log dir +log_dir: logs/cnn_dct + +# model setting +pretrained: null +model_name: cnn_dct # +backbone_name: simple_cnn # SimpleCNN + +# backbone setting +backbone_config: + dropout: 0.5 + num_classes: 36 # + inc: 3 # input channel + mode: original # CNN + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] +save_ckpt: true +save_feat: true + +compression: c23 +train_batchSize: 32 +test_batchSize: 32 +workers: 8 +frame_num: {'train': 8, 'test': 16} +resolution: 256 +balance_data: false +with_mask: false +with_landmark: false + + +# data augmentation +use_data_augmentation: false +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + type: adam + adam: + lr: 0.0002 + beta1: 0.9 + beta2: 0.999 + eps: 0.00000001 + weight_decay: 0.0005 + amsgrad: false + sgd: + lr: 0.0002 + momentum: 0.9 + weight_decay: 0.0005 + +# training config +lr_scheduler: null +nEpochs: 20 +start_epoch: 0 +save_epoch: 1 +rec_iter: 100 +logdir: ./logs +manualSeed: 1024 +save_ckpt: false + +# loss function +loss_func: cross_entropy +losstype: null + +# metric +metric_scoring: acc + +# cuda +cuda: true +cudnn: true + +# save latest ckpt +save_latest_ckpt: true + + diff --git a/training/config/detector/core.yaml b/training/config/detector/core.yaml new file mode 100644 index 0000000000000000000000000000000000000000..191e9dee49352f5bbedb9a235e7c80ec4dfff878 --- /dev/null +++ b/training/config/detector/core.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: /data/home/zhiyuanyan/logs/testing_bench + +# model setting +# pretrained: /home/yuanxinhang/resnet34-b627a593.pth # path to a pre-trained model, if using one +pretrained: ./training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +model_name: core # model name +backbone_name: xception # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FF-NT] +test_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: consistency_loss # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/d_det.yaml b/training/config/detector/d_det.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0a233bccde71f8aa4205050a22a4b872bad60427 --- /dev/null +++ b/training/config/detector/d_det.yaml @@ -0,0 +1,95 @@ +# log dir +log_dir: logs/dna_det + +# model setting +pretrained: null +model_name: dna_det +backbone_name: dna_simple_cnn + +# backbone setting +backbone_config: + num_classes: 36 + pretrain: false + head: mlp + dim_in: 512 + feat_dim: 128 + +# SupConLoss related settings +temperature: 0.07 # SupConLoss + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 +train_batchSize: 64 +test_batchSize: 64 +workers: 8 +frame_num: {'train': 8, 'test': 16} +resolution: 256 +with_mask: false +with_landmark: false +save_ckpt: true +save_feat: true + +# data augmentation +use_data_augmentation: false +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config + +optimizer: + type: adam + adam: + lr: 0.0002 + beta1: 0.9 + beta2: 0.999 + eps: 0.00000001 + weight_decay: 0.0005 + amsgrad: false + sgd: + lr: 0.0002 + momentum: 0.9 + weight_decay: 0.0005 + +# training config +lr_scheduler: null +step_size: 10 +gamma: 0.1 + +nEpochs: 20 +start_epoch: 0 +save_epoch: 1 +rec_iter: 100 +logdir: ./logs +manualSeed: 1024 +save_ckpt: true + +# loss function +loss_func: ce_supcon +losstype: null + +# metric +metric_scoring: acc + +# cuda +cuda: true +cudnn: true + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/dino_contrast.yaml b/training/config/detector/dino_contrast.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2e342067117cd2af3a0f9a605aab526cfcc38378 --- /dev/null +++ b/training/config/detector/dino_contrast.yaml @@ -0,0 +1,89 @@ +# log dir +log_dir: logs/dino_models + +# model setting +pretrained: dinov2_vitl14 # path to a pre-trained model, if using one +model_name: dinov2_large_fft_contrast # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 128 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations +ddp: true + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/dinov2_large_fft.yaml b/training/config/detector/dinov2_large_fft.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b79ea3e9833318b8d57d4590be551694942fce63 --- /dev/null +++ b/training/config/detector/dinov2_large_fft.yaml @@ -0,0 +1,89 @@ +# log dir +log_dir: logs/dino_models + +# model setting +pretrained: dinov2_vitl14 # path to a pre-trained model, if using one +model_name: dinov2_large_fft # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 128 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations +ddp: true + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/dinov2_large_fft_res.yaml b/training/config/detector/dinov2_large_fft_res.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e2b3a4e73d416a559a108af44db8d8c92abd6f1 --- /dev/null +++ b/training/config/detector/dinov2_large_fft_res.yaml @@ -0,0 +1,89 @@ +# log dir +log_dir: logs/dino_models_res + +# model setting +pretrained: dinov2_vitl14 # path to a pre-trained model, if using one +model_name: dinov2_large_fft # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [res_protocol_2_train] +test_dataset: [res_protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 128 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations +ddp: true + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/dinov3_large_fft.yaml b/training/config/detector/dinov3_large_fft.yaml new file mode 100644 index 0000000000000000000000000000000000000000..44eea0d73cdad099479956c942ee8528cf6550cd --- /dev/null +++ b/training/config/detector/dinov3_large_fft.yaml @@ -0,0 +1,89 @@ +# log dir +log_dir: logs/dino_models + +# model setting +pretrained: dinov3_vitl16 # path to a pre-trained model, if using one +model_name: dinov3_large_fft # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations +ddp: true + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/efficientnetb4.yaml b/training/config/detector/efficientnetb4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..50c8cd22af21b87a8772c814bb204d028a792ad7 --- /dev/null +++ b/training/config/detector/efficientnetb4.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/evaluations/effnb4 + +# model setting +# pretrained: /home/zhiyuanyan/disfin/deepfake_benchmark/training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +pretrained: ./training/pretrained/efficientnet-b4-6ed6700e.pth # path to a pre-trained model, if using one +model_name: efficientnetb4 # model name +backbone_name: efficientnetb4 # backbone name + +#backbone setting +backbone_config: + num_classes: 2 + inc: 3 + dropout: false + mode: Original + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FF-NT] +test_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/effort.yaml b/training/config/detector/effort.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f309b2590e085d154487cacc6caca8b0a96d777b --- /dev/null +++ b/training/config/detector/effort.yaml @@ -0,0 +1,91 @@ +# log dir +log_dir: logs/clip_models + +# model setting +pretrained: openai/clip-vit-large-patch14 +model_name: effort # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 16, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # betanv2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +ngpu: 1 # number of GPUs to use +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +save_avg: true + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/effort_cl.yaml b/training/config/detector/effort_cl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a385df4588e6ac3bf85aa7754d05c55869d99659 --- /dev/null +++ b/training/config/detector/effort_cl.yaml @@ -0,0 +1,93 @@ +# log dir +log_dir: logs/clip_models + +# model setting +pretrained: openai/clip-vit-large-patch14 +model_name: effort_cl # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +# train_dataset: [protocol_2_train] +# test_dataset: [protocol_2_test] +train_dataset: [FaceForensics++] +test_dataset: [FaceForensics++, Celeb-DF-v2] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # betanv2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: supcon_cls # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +ngpu: 1 # number of GPUs to use +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +save_avg: true + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/effort_patch_shuffle.yaml b/training/config/detector/effort_patch_shuffle.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b3880a554a72e2858c7c78986cb386c492a37bd9 --- /dev/null +++ b/training/config/detector/effort_patch_shuffle.yaml @@ -0,0 +1,91 @@ +# log dir +log_dir: logs/effort + +# model setting +pretrained: openai/clip-vit-large-patch14 +model_name: effort_shuffle_ensemble # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 16, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # betanv2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +ngpu: 1 # number of GPUs to use +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +save_avg: true + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/effort_vid.yaml b/training/config/detector/effort_vid.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8c47dff762e5c2ac8bba58dbae8f34f46ac0889f --- /dev/null +++ b/training/config/detector/effort_vid.yaml @@ -0,0 +1,89 @@ +# log dir +log_dir: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/logs/effort_vid + +# model setting +pretrained: openai/clip-vit-large-patch14 +model_name: effort_vid # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, FaceShifter] # , DeeperForensics-1.0 + +compression: c23 # compression-level for videos +train_batchSize: 8 # training batch size +test_batchSize: 8 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 16 # number of frames in each clip + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # betanv2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda +ngpu: 1 # number of GPUs to use +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +save_avg: true \ No newline at end of file diff --git a/training/config/detector/f3net.yaml b/training/config/detector/f3net.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fca6b6273b082c3ea0ae17ee430a2c60dd70fdd1 --- /dev/null +++ b/training/config/detector/f3net.yaml @@ -0,0 +1,91 @@ +# log dir +log_dir: logs/f3net + +# model setting +pretrained: ./training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +model_name: f3net # model name +backbone_name: xception # backbone name + +# backbone setting +backbone_config: + dropout: 0.5 + num_classes: 36 + inc: 3 + mode: original + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +balance_data: false # whether to balance the number of real and fake videos in training data +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations +# save latest ckpt +save_latest_ckpt: true + diff --git a/training/config/detector/facexray.yaml b/training/config/detector/facexray.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d5eae8b0582eee0ee0f2ce9db8c7c1091a815003 --- /dev/null +++ b/training/config/detector/facexray.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: /mntcephfs/lab_data/tianshuoge/benchmark_results/logs_final/facexray_nt + +# model setting +pretrained: ./training/pretrained/efficientnet-b4-6ed6700e.pth # path to a pre-trained model, if using one +model_name: facexray # model name +backbone_name: efficientnetb4 # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FF-NT] +test_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT] +dataset_type: blend + +compression: c23 # compression-level for videos +train_batchSize: 16 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: + cls_loss: cross_entropy # loss function to use + mask_loss: bce +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/ffd.yaml b/training/config/detector/ffd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8e4a0d7773df78cf1770e92ceebaadd3f6cbb3c7 --- /dev/null +++ b/training/config/detector/ffd.yaml @@ -0,0 +1,92 @@ +# log dir +log_dir: /data/home/zhiyuanyan/logs/testing_bench + +# model setting +pretrained: ./training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +model_name: ffd # model name +backbone_name: xception # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# model setting +maptype: reg + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FF-NT] +test_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: + cls_loss: cross_entropy # loss function to use + mask_loss: l1loss +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/ftcn.yaml b/training/config/detector/ftcn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..76aebf5a505dd4060b67bb9c5941bc28e61c5a4f --- /dev/null +++ b/training/config/detector/ftcn.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: /data/home/zhiyuanyan/DeepfakeBench/logs_debug/video_baseline + +# model setting +pretrained: ./training/pretrained/I3D_8x8_R50.pth # path to a pre-trained model, if using one +model_name: videomae # video_baseline # model name +backbone_name: null # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 1 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2] + +compression: c23 # compression-level for videos +train_batchSize: 8 # training batch size +test_batchSize: 8 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 300, 'test': 300} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 16 # number of frames in each clip + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/fwa.yaml b/training/config/detector/fwa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ac8037ce1a4778cd9fa3e433115210c8c4b310a --- /dev/null +++ b/training/config/detector/fwa.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: /mntcephfs/lab_data/yuanxinhang/benchmark_results/logs_analysis/fwa + +# model setting +pretrained: ./training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +model_name: fwa # model name +backbone_name: xception # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FF-NT] +test_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT] +dataset_type: blend + +compression: c23 # compression-level for videos +train_batchSize: 16 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/ganatt.yaml b/training/config/detector/ganatt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..361198cd02144cdfdac550613915de81954f66ba --- /dev/null +++ b/training/config/detector/ganatt.yaml @@ -0,0 +1,87 @@ +# log dir +log_dir: logs/ganatt_models + +# model setting +model_name: ganatt # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/hrnet.yaml b/training/config/detector/hrnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0ed2270073616597e7ecb6b681a28e562aaf03f4 --- /dev/null +++ b/training/config/detector/hrnet.yaml @@ -0,0 +1,122 @@ +# log dir +log_dir: logs/clip_models + +# model setting +model_name: hrnet # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +HRNET: + PRETRAINED_LAYERS: ['*'] + STEM_INPLANES: 64 + FINAL_CONV_KERNEL: 1 + PRETRAINED: 'models/hrnet_w18_small_v2.pth' + STAGE1: + NUM_MODULES: 1 + NUM_BRANCHES: 1 + NUM_BLOCKS: [2] + NUM_CHANNELS: [64] + BLOCK: 'BOTTLENECK' + FUSE_METHOD: 'SUM' + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 4 + NUM_BLOCKS: [2, 2, 2, 2] + NUM_CHANNELS: [18, 36, 72, 144] + BLOCK: 'BASIC' + FUSE_METHOD: 'SUM' + STAGE3: + NUM_MODULES: 1 + NUM_BRANCHES: 4 + NUM_BLOCKS: [2, 2, 2, 2] + NUM_CHANNELS: [18, 36, 72, 144] + BLOCK: 'BASIC' + FUSE_METHOD: 'SUM' + STAGE4: + NUM_MODULES: 1 + NUM_BRANCHES: 4 + NUM_BLOCKS: [2, 2, 2, 2] + NUM_CHANNELS: [18, 36, 72, 144] + BLOCK: 'BASIC' + FUSE_METHOD: 'SUM' + + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/i3d.yaml b/training/config/detector/i3d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a34e09212188b5f223895b84497e3d37f0b0de3a --- /dev/null +++ b/training/config/detector/i3d.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: /data/home/zhiyuanyan/logs/i3d + +# model setting +pretrained: ./training/pretrained/I3D_8x8_R50.pth # path to a pre-trained model, if using one +model_name: i3d # model name +backbone_name: xception # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 1 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, DeepFakeDetection, FaceShifter, DFDC, DFDCP] + +compression: c23 # compression-level for videos +train_batchSize: 8 # training batch size +test_batchSize: 8 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 16 # number of frames in each clip + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 30 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/iid.yaml b/training/config/detector/iid.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e83ed30d16c4514ad5976a252b2f0463b9e64b0a --- /dev/null +++ b/training/config/detector/iid.yaml @@ -0,0 +1,103 @@ +# log dir +log_dir: /data/home/zhiyuanyan/logs/benchv2/iid + +#LMDB dir +lmdb_dir: 'I:\transform_2_lmdb' + +# model setting +pretrained: ./training/pretrained/xception-b5690688.pth #./training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +explicit_extractor_pretrained: ./training/pretrained/backbone.pth +model_name: iid # model name +backbone_name: xception # backbone name +restore_ckpt: 'None' +#backbone setting +backbone_config: + mode: adjust_channel_iid #adjust_channel + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, roop] +dataset_type: iid +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + remove_attribute: false + gridmask_prob: 0 + remove_nose_prob: 0.0 + remove_eyes_prob: 0.0 + + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + sam: + lr: 0.001 # learning rate + momentum: 0.9 # momentum for SGD optimizer + +# training config +lr_scheduler: null # learning rate scheduler step +lr_step: 3 +lr_gamma: 0.4 # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use cross_entropy + +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +embedding_size: 512 \ No newline at end of file diff --git a/training/config/detector/lorax.yaml b/training/config/detector/lorax.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dbfcd4cda891328a10787cb55ea73dfd5ee4b10f --- /dev/null +++ b/training/config/detector/lorax.yaml @@ -0,0 +1,93 @@ +# log dir +log_dir: logs/lorax_convit_tiny + +# model setting +pretrained: null +model_name: lorax_convit +backbone_name: lorax_convit + +# backbone setting: +backbone_config: + num_classes: 36 + + model: convit_pretrain + + # args.r_param + r_param: 8 + reg: ".*\\.attn.qkv|.*\\.attn.qk|.*\\.attn.v|\\.attn\\.pos_proj|\\.attn\\.proj" + r_param: 64 + + # args.r_alpha_ratio + r_alpha_ratio: 2.0 + + + reg: + - qkv + - proj + + # LoRA dropout + lora_dropout: 0.1 + + +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 +train_batchSize: 64 +test_batchSize: 64 +workers: 8 +frame_num: {'train': 8, 'test': 16} +resolution: 224 +with_mask: false +with_landmark: false +save_ckpt: true +save_feat: true + +use_data_augmentation: false +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +optimizer: + type: adam + adam: + lr: 0.0002 + beta1: 0.9 + beta2: 0.999 + eps: 1.0e-8 + weight_decay: 0.0005 + amsgrad: false + sgd: + lr: 0.0002 + momentum: 0.9 + weight_decay: 0.0005 + +lr_scheduler: null +nEpochs: 20 +start_epoch: 0 +save_epoch: 1 +rec_iter: 100 +logdir: ./logs +manualSeed: 1024 +save_ckpt: false + +loss_func: cross_entropy +losstype: null + +metric_scoring: acc + +cuda: true +cudnn: true +save_latest_ckpt: true diff --git a/training/config/detector/lrl.yaml b/training/config/detector/lrl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b599cd58e7f7d6b32ec7cc9619298e37526abfc1 --- /dev/null +++ b/training/config/detector/lrl.yaml @@ -0,0 +1,131 @@ +# log dir +log_dir: /data/home/zhiyuanyan/logs/benchv2/lrl + +#LMDB dir +lmdb_dir: 'I:\transform_2_lmdb' + +pretrained: ./training/pretrained/efficientnet-b4-6ed6700e.pth #./training/pretrained/efficientnet-b4-6ed6700e.pth # path to a pre-trained model, if using one +model_name: lrl # model name +backbone_name: efficientnetb4 # backbone name + +#backbone setting +backbone_config: + num_classes: 2 + inc: 3 + dropout: 0.3 + mode: Original + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, roop] +dataset_type: lrl +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: true # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# label settings +label_dict: + # DFD + DFD_fake: 1 + DFD_real: 0 + # FF++ + FaceShifter(FF-real+FF-FH) + FF-SH: 1 + FF-F2F: 1 + FF-DF: 1 + FF-FS: 1 + FF-NT: 1 +# FF-DF: 1 +# FF-F2F: 2 +# FF-FS: 3 +# FF-NT: 4 + FF-FH: 1 + FF-real: 0 + # CelebDF + CelebDFv1_real: 0 + CelebDFv1_fake: 1 + CelebDFv2_real: 0 + CelebDFv2_fake: 1 + # DFDCP + DFDCP_Real: 0 + DFDCP_FakeA: 1 + DFDCP_FakeB: 1 + # DFDC + DFDC_Fake: 1 + DFDC_Real: 0 + # DeeperForensics-1.0 + DF_fake: 1 + DF_real: 0 + # UADFV + UADFV_Fake: 1 + UADFV_Real: 0 + # Roop + roop_Real: 0 + roop_Fake: 1 + + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler step +lr_step: 3 +lr_gamma: 0.4 +nEpochs: 40 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation + + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/lsda.yaml b/training/config/detector/lsda.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bd714d0e762fdd0cab8d21cdcece69cd5a154aab --- /dev/null +++ b/training/config/detector/lsda.yaml @@ -0,0 +1,135 @@ +# log dir +log_dir: logs/evaluations/lsda + +# model setting +pretrained: ./training/pretrained/efficientnet-b4-6ed6700e.pth # path to a pre-trained model, if using one +model_name: lsda # model name +backbone_name: efficientnetb4 # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV, stargan] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2] +dataset_type: blend + +compression: c23 # compression-level for videos +train_batchSize: 20 # training batch size +test_batchSize: 32 # test batch size +workers: 4 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: false # whether to save checkpoint +save_feat: false # whether to save features + +# label settings +label_dict: + # DFD + DFD_fake: 1 + DFD_real: 0 + # FF++ + FaceShifter(FF-real+FF-FH) + FF-SH: 1 + FF-F2F: 1 + FF-DF: 1 + FF-FS: 1 + FF-NT: 1 + FF-FH: 1 + FF-real: 0 + # CelebDF + CelebDFv1_real: 0 + CelebDFv1_fake: 1 + CelebDFv2_real: 0 + CelebDFv2_fake: 1 + # DFDCP + DFDCP_Real: 0 + DFDCP_FakeA: 1 + DFDCP_FakeB: 1 + # DFDC + DFDC_Fake: 1 + DFDC_Real: 0 + # DeeperForensics-1.0 + DF_fake: 1 + DF_real: 0 + # UADFV + UADFV_Fake: 1 + UADFV_Real: 0 + #stargan + 0_real: 0 + 1_fake: 1 + real: 0 + fake: 1 + + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + sam: + lr: 0.001 # learning rate + momentum: 0.9 # momentum for SGD optimizer + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 50 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + + + +# lsda +teacher: efficientnetb4 +student: efficientnetb4 +real_encoder: null diff --git a/training/config/detector/meso4.yaml b/training/config/detector/meso4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d80200ed31f8e874854c3a6aa8d16bca0c635a3 --- /dev/null +++ b/training/config/detector/meso4.yaml @@ -0,0 +1,85 @@ +# log dir +log_dir: /mntcephfs/lab_data/yuanxinhang/benchmark_results/logs_final/meso4 + +# model setting +pretrained: false # path to a pre-trained model, if using one +model_name: meso4 # model name +backbone_name: meso4 # backbone name + +#backbone setting +backbone_config: + num_classes: 2 + inc: 3 + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FF-NT] +test_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations \ No newline at end of file diff --git a/training/config/detector/meso4Inception.yaml b/training/config/detector/meso4Inception.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7d7e3a3e4b5381d3ce389c7e9272d6ce98d0569c --- /dev/null +++ b/training/config/detector/meso4Inception.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: /mntcephfs/lab_data/yuanxinhang/benchmark_results/logs_final/meso4Inception + +# model setting +pretrained: false # path to a pre-trained model, if using one +model_name: meso4Inception # model name +backbone_name: meso4Inception # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FF-NT] +test_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +tsne: false # whether to calculate and visualize t-SNE embeddings +gradcam: false # whether to generate and visualize Grad-CAM heatmaps +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations \ No newline at end of file diff --git a/training/config/detector/multi_attention.yaml b/training/config/detector/multi_attention.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1249401348f6d81f82b7f2637c5261ff284a0323 --- /dev/null +++ b/training/config/detector/multi_attention.yaml @@ -0,0 +1,117 @@ +# log dir +log_dir: log/multi_attention/ + +# model setting +pretrained: ./training/pretrained/efficientnet-b4-6ed6700e.pth # path to a pre-trained model, if using one +model_name: multi_attention # model name +backbone_name: efficientnetb4 # backbone name + +#backbone setting +backbone_config: + num_classes: 2 + inc: 3 + dropout: false + mode: Original + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [FaceForensics++, Celeb-DF-v1, Celeb-DF-v2] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.000001 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: step # learning rate scheduler +lr_step: 5 +lr_gamma: 0.5 +nEpochs: 50 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: + cls_loss: cross_entropy # loss function to use + ril_loss: region_independent_loss # Region Independent Loss + ril_params: + N: 32 + alpha: 0.05 + alpha_decay: 0.9 + inter_margin: 0.2 + intra_margin: [0.05, 0.1] + weights: [1, 1] # weights for CE_loss and RIL, respectively +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# model parameters +feature_layer: b2 +attention_layer: b5 +num_attentions: 4 +mid_dim: 256 +dropout_rate: 0.25 +dropout_rate_final: 0.5 +AGDA: + kernel_size: 11 + dilation: 2 + sigma: 7 + threshold: [0.4, 0.6] + zoom: [3, 5] + scale_factor: 0.5 + noise_rate: 0.1 + +backbone_nEpochs: 10 +batch_per_epoch: 3591 \ No newline at end of file diff --git a/training/config/detector/npr.yaml b/training/config/detector/npr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..70a31603b25e6908061bc1df43a27ef94b39ea2b --- /dev/null +++ b/training/config/detector/npr.yaml @@ -0,0 +1,94 @@ +# log dir +log_dir: logs/npr + +# model setting +pretrained: null +model_name: npr +backbone_name: resnet50 + +# === NPR preprocessing === +use_npr: true # +npr_factor: 0.5 # interpolate(x, 0.5) +npr_scale: 0.6666667 # NPR * 2/3 + +# backbone setting +backbone_config: + num_classes: 36 + inc: 3 + dropout: false + mode: NPR + +# dataset + +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + + + +compression: c23 +train_batchSize: 64 +test_batchSize: 64 +workers: 8 +frame_num: {'train': 8, 'test': 16} +resolution: 256 +with_mask: false +with_landmark: false +save_ckpt: true +save_feat: true + +# data augmentation +use_data_augmentation: false +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + type: adam + adam: + lr: 0.0002 # --lr 0.0002 + beta1: 0.9 + beta2: 0.999 + eps: 0.00000001 + weight_decay: 0.0005 + amsgrad: false + sgd: + lr: 0.0002 + momentum: 0.9 + weight_decay: 0.0005 + +# training config +lr_scheduler: null +start_epoch: 0 +save_epoch: 1 +rec_iter: 100 +logdir: ./logs +manualSeed: 1024 +save_ckpt: true + +# loss function +loss_func: cross_entropy +losstype: null + +# metric +metric_scoring: acc + +# cuda +cuda: true +cudnn: true + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/ooc.yaml b/training/config/detector/ooc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c36e0c557e1c9c15d1a7fc3849a9693952cdc54e --- /dev/null +++ b/training/config/detector/ooc.yaml @@ -0,0 +1,128 @@ +# log dir +log_dir: logs/coop_ooc_multiclass + +# model setting +model_name: coop_ooc # DETECTOR.register_module(module_name='coop_ooc') +backbone_name: coop_ooc + +# backbone setting (CoOp / OOC) +backbone_config: + backbone: RN50 # 'RN50' or 'ViT' + n_ctx: 8 # prompt context token + class_names: [ + "Real", + "FF-DF", + "FF-F2F", + "FF-NT", + "FF-FH", + "fsgan_Fake", + "faceswap_Fake", + "inswap_Fake", + "simswap_Fake", + "blendface_Fake", + "uniface_Fake", + "e4s_Fake", + "facedancer_Fake", + "mobileswap_Fake", + "sadtalker_Fake", + "wav2lip_Fake", + "fomm_Fake", + "MRAA_Fake", + "one_shot_free_Fake", + "pirender_Fake", + "tpsm_Fake", + "lia_Fake", + "danet_Fake", + "mcnet_Fake", + "hyperreenact_Fake", + "facevid2vid_Fake", + "VQGAN_Fake", + "StyleGAN3_Fake", + "StyleGANXL_Fake", + "ddim_Fake", + "sd2.1_Fake", + "rddm_Fake", + "pixart_Fake", + "DiT_Fake", + "SiT_Fake", + "e4e_Fake" + ] + num_classes: 36 #class_names + inc: 3 + dropout: false + mode: Original + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 +train_batchSize: 64 +test_batchSize: 64 +workers: 8 +frame_num: {'train': 8, 'test': 16} +resolution: 224 +with_mask: false +with_landmark: false +save_ckpt: true +save_feat: true + +# data augmentation +use_data_augmentation: false +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization (CLIP) +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + type: adam + adam: + lr: 0.0001 + beta1: 0.9 + beta2: 0.999 + eps: 0.00000001 + weight_decay: 0.0005 + amsgrad: false + sgd: + lr: 0.0002 + momentum: 0.9 + weight_decay: 0.0005 + +# training config +lr_scheduler: null +nEpochs: 20 +start_epoch: 0 +save_epoch: 1 +rec_iter: 100 +logdir: ./logs +manualSeed: 1024 +save_ckpt: true + +# loss function +loss_func: cross_entropy +losstype: null + +# metric +metric_scoring: acc + +# cuda +cuda: true +cudnn: true + +# save latest ckpt +save_latest_ckpt: true + + diff --git a/training/config/detector/pcl_xception.yaml b/training/config/detector/pcl_xception.yaml new file mode 100644 index 0000000000000000000000000000000000000000..df9148b7a71bf8486b1b02b69f79a6a4b099d6c7 --- /dev/null +++ b/training/config/detector/pcl_xception.yaml @@ -0,0 +1,100 @@ +# log dir +log_dir: logs/evaluations/pcl_xception_best + +#LMDB dir +lmdb_dir: 'I:\transform_2_lmdb' + +# model setting +pretrained: ./training/pretrained/xception-b5690688.pth #./training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +model_name: pcl_xception # model name +backbone_name: xception_sladd # backbone name +restore_ckpt: 'None' +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT] +dataset_type: I2G +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: true # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + remove_attribute: false + gridmask_prob: 0 + remove_nose_prob: 0.0 + remove_eyes_prob: 0.0 + + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + sam: + lr: 0.001 # learning rate + momentum: 0.9 # momentum for SGD optimizer + +# training config +lr_scheduler: step # learning rate scheduler step +lr_step: 3 +lr_gamma: 0.4 # learning rate scheduler +nEpochs: 150 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use cross_entropy +pcl_loss_weight: 1 # PLW +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/pose.yaml b/training/config/detector/pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b79073c9a371d728023fba4039d0cdccd95d72ad --- /dev/null +++ b/training/config/detector/pose.yaml @@ -0,0 +1,86 @@ +# log dir +log_dir: logs/pose_models + +# model setting +model_name: pose # model name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/recce.yaml b/training/config/detector/recce.yaml new file mode 100644 index 0000000000000000000000000000000000000000..604633b039935e6be87eb51e80dae945e8046dca --- /dev/null +++ b/training/config/detector/recce.yaml @@ -0,0 +1,87 @@ +# log dir +log_dir: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/logs/testing_bench + +# model setting +pretrained: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +model_name: recce # model name +backbone_name: xception # backbone name + +#backbone setting none +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FF-NT] +test_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/repmix.yaml b/training/config/detector/repmix.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e8d0eb90913268de8c95bc9257b200ba306da81f --- /dev/null +++ b/training/config/detector/repmix.yaml @@ -0,0 +1,87 @@ +# log dir +log_dir: logs/repmix_models + +# model setting +model_name: repmix # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/resnet34.yaml b/training/config/detector/resnet34.yaml new file mode 100644 index 0000000000000000000000000000000000000000..92ac7264a6c24ae68c1d6b23d35ab2f69e3a3e44 --- /dev/null +++ b/training/config/detector/resnet34.yaml @@ -0,0 +1,90 @@ +# log dir +log_dir: logs/resnet34 + +# model setting +pretrained: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/training/pretrained/resnet34-b627a593.pth # path to a pre-trained model, if using one +model_name: resnet34 # model name +backbone_name: resnet34 # backbone name + +#backbone setting +backbone_config: + num_classes: 36 + inc: 3 + dropout: false + mode: Original + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/resnet_distill.yaml b/training/config/detector/resnet_distill.yaml new file mode 100644 index 0000000000000000000000000000000000000000..199ff0bad004abc01837690b15252913a1cad79d --- /dev/null +++ b/training/config/detector/resnet_distill.yaml @@ -0,0 +1,90 @@ +# log dir +log_dir: logs/resnet34 + +# model setting +pretrained: /Youtu_Pangu_Security_Public/rainxyzhou/resnet34_dinov2_distillation_best.pth # path to a pre-trained model, if using one +model_name: resnet34_distill # model name +backbone_name: resnet34 # backbone name + +#backbone setting +backbone_config: + num_classes: 36 + inc: 3 + dropout: false + mode: Original + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/rfm.yaml b/training/config/detector/rfm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5b15713e476b37e7515b574ab41630605cb7f136 --- /dev/null +++ b/training/config/detector/rfm.yaml @@ -0,0 +1,86 @@ +# log dir +log_dir: /data/home/zhiyuanyan/logs/benchv2/rfm + +# model setting +pretrained: ./training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +model_name: rfm # model name +backbone_name: xception # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, roop] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 4 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/sbi.yaml b/training/config/detector/sbi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..578a1b2dfde50a877006d657cfb18641636fdcf8 --- /dev/null +++ b/training/config/detector/sbi.yaml @@ -0,0 +1,94 @@ +# log dir +log_dir: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/logs/benchv2/sbi_v3 + +# model setting +pretrained: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +model_name: sbi # model name +backbone_name: efficientnetb4 # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 3 + inc: 3 + dropout: false + pretrained: true + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, FaceShifter] # roop +dataset_type: blend + +compression: c23 # compression-level for videos +train_batchSize: 24 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: true # whether to include facial landmark information in the input +save_ckpt: false # whether to save checkpoint +save_feat: false # whether to save features + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.485, 0.456, 0.406] +std: [0.229, 0.224, 0.225] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + sam: + lr: 0.001 # learning rate + momentum: 0.9 # momentum for SGD optimizer + +# training config +lr_scheduler: step # learning rate scheduler step +lr_step: 3 +lr_gamma: 0.4 # learning rate scheduler +nEpochs: 50 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/sia.yaml b/training/config/detector/sia.yaml new file mode 100644 index 0000000000000000000000000000000000000000..55328f6f8c285014a03651d8b83ba72c094c5aa7 --- /dev/null +++ b/training/config/detector/sia.yaml @@ -0,0 +1,123 @@ +# log dir +log_dir: ../loggg/ +# model setting +pretrained: ./training/pretrained/efficientnet-b4-6ed6700e.pth # path to a pre-trained model, if using one +model_name: sia # model name +backbone_name: efficientnetb4 # backbone name + +#backbone setting +backbone_config: + num_classes: 2 + inc: 3 + dropout: false + mode: Original + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FF-NT] +test_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT] +dataset_json_folder: '/workspace/chenyize3/DeepfakeBench/preprocessing/dataset_json' + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# label settings +label_dict: + # DFD + DFD_fake: 1 + DFD_real: 0 + # FF++ + FaceShifter(FF-real+FF-FH) + FF-SH: 1 + FF-F2F: 1 + FF-DF: 1 + FF-FS: 1 + FF-NT: 1 + FF-FH: 1 + FF-real: 0 + # CelebDF + CelebDFv1_real: 0 + CelebDFv1_fake: 1 + CelebDFv2_real: 0 + CelebDFv2_fake: 1 + # DFDCP + DFDCP_Real: 0 + DFDCP_FakeA: 1 + DFDCP_FakeB: 1 + # DFDC + DFDC_Fake: 1 + DFDC_Real: 0 + # DeeperForensics-1.0 + DF_fake: 1 + DF_real: 0 + # UADFV + UADFV_Fake: 1 + UADFV_Real: 0 + + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + step: + lr_step: 5 + lr_gamma: 0.1 + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/sladd_detector.yaml b/training/config/detector/sladd_detector.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7bf87f7b15b8c22da98bcab87ed6a67108da3e22 --- /dev/null +++ b/training/config/detector/sladd_detector.yaml @@ -0,0 +1,102 @@ +# log dir +log_dir: logs/evaluations/sladd_best + +#LMDB dir +lmdb_dir: 'I:\transform_2_lmdb' + +# model setting +pretrained: ./training/pretrained/xception-b5690688.pth #./training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +model_name: sladd # model name +backbone_name: xception_sladd # backbone name +restore_ckpt: 'None' +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT] +dataset_type: pair +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: true # whether to include mask information in the input +with_landmark: true # whether to include facial landmark information in the input +# use_phase_spe: None + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + remove_attribute: false + gridmask_prob: 0 + remove_nose_prob: 0.0 + remove_eyes_prob: 0.0 + + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + sam: + lr: 0.001 # learning rate + momentum: 0.9 # momentum for SGD optimizer + +# training config +lr_scheduler: null # learning rate scheduler step +lr_step: 3 +lr_gamma: 0.4 # learning rate scheduler +nEpochs: 40 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use cross_entropy +typeloss_func: am_softmax + +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/spsl.yaml b/training/config/detector/spsl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..430b5ac0a7832fdad45b1ec29f97f4aae83c59ee --- /dev/null +++ b/training/config/detector/spsl.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: /mntcephfs/lab_data/zhiyuanyan/benchmark_results/logs_final/spsl_4frames + +# model setting +pretrained: ./training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +# pretrained: /home/tianshuoge/resnet34-b627a593.pth # path to a pre-trained model, if using one +model_name: spsl # model name +backbone_name: xception # backbone name + +#backbone setting +backbone_config: + mode: original # shallow_xception + num_classes: 2 + inc: 4 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FF-FS] +test_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 4, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/srm.yaml b/training/config/detector/srm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..034860cbe2a20a928f307c1f618911563834c8ff --- /dev/null +++ b/training/config/detector/srm.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: /data/home/zhiyuanyan/logs/testing_bench + +# model setting +pretrained: ./training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +model_name: srm # model name +backbone_name: xception # backbone name + +#backbone setting +backbone_config: + num_classes: 2 + inc: 3 + dropout: false + mode: original + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FF-F2F] +test_dataset: [FF-F2F, FF-DF, FF-FS, FF-NT] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save checkpoint + +# loss function +loss_func: am_softmax # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/stil.yaml b/training/config/detector/stil.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e067dea758bf6bf0f7663372be6052ac99fed3e5 --- /dev/null +++ b/training/config/detector/stil.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: /data/home/zhiyuanyan/logs/benchv2/stil + +# model setting +pretrained: null # path to a pre-trained model, if using one +model_name: stil # model name +backbone_name: xception # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2] + +compression: c23 # compression-level for videos +train_batchSize: 8 # training batch size +test_batchSize: 8 # test batch size +workers: 4 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 8 # number of frames in each clip + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/tall.yaml b/training/config/detector/tall.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2129df252e71de1be5c95a0ce6df0fe82bb5a742 --- /dev/null +++ b/training/config/detector/tall.yaml @@ -0,0 +1,135 @@ +# log dir +log_dir: ./logs/benchv2/tall/trainOnFS + +# model setting +pretrained: null # path to a pre-trained model, if using one +model_name: tall # model name + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 4 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 4 # number of frames in each clip, should be square number of an integer +dataset_type: tall + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.485, 0.456, 0.406] +std: [0.229, 0.224, 0.225] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 100 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +dataset_json_folder: 'datasets/dataset_json' +rgb_dir: 'datasets/rgb' +label_dict: + # DFD + DFD_fake: 1 + DFD_real: 0 + # FF++ + FaceShifter(FF-real+FF-FH) + FF-SH: 1 + FF-F2F: 1 + FF-DF: 1 + FF-FS: 1 + FF-NT: 1 + FF-FH: 1 + FF-real: 0 + # CelebDF + CelebDFv1_real: 0 + CelebDFv1_fake: 1 + CelebDFv2_real: 0 + CelebDFv2_fake: 1 + # DFDCP + DFDCP_Real: 0 + DFDCP_FakeA: 1 + DFDCP_FakeB: 1 + # DFDC + DFDC_Fake: 1 + DFDC_Real: 0 + # DeeperForensics-1.0 + DF_fake: 1 + DF_real: 0 + # UADFV + UADFV_Fake: 1 + UADFV_Real: 0 + # Roop + roop_Real: 0 + roop_Fake: 1 + + +mask_grid_size: 16 + +num_classes: 2 +embed_dim: 128 +mlp_ratio: 4.0 +patch_size: 4 +window_size: [14, 14, 14, 7] +depths: [2, 2, 18, 2] +num_heads: [4, 8, 16, 32] +ape: true # use absolution position embedding +thumbnail_rows: 2 +drop_rate: 0 +drop_path_rate: 0.1 + +task_target: TALL + diff --git a/training/config/detector/timetransformer.yaml b/training/config/detector/timetransformer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ba14fb6b37f77b77aeaf154ec6c3cdad7f614035 --- /dev/null +++ b/training/config/detector/timetransformer.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: /data/home/zhiyuanyan/DeepfakeBench/logs_debug/time_transformer + +# model setting +pretrained: ./training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +model_name: time_transformer # model name +backbone_name: null # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [FaceForensics++, Celeb-DF-v2, DeepFakeDetection, FaceShifter] + +compression: c23 # compression-level for videos +train_batchSize: 8 # training batch size +test_batchSize: 8 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 8 # number of frames in each clip + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/ucf.yaml b/training/config/detector/ucf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d368c4ffc5cb805aad50f27b2710af373e19859 --- /dev/null +++ b/training/config/detector/ucf.yaml @@ -0,0 +1,130 @@ +# log dir +log_dir: /data/home/zhiyuanyan/DeepfakeBench/debug_logs/ucf + +# model setting +pretrained: ./training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +# pretrained: '/home/zhiyuanyan/.cache/torch/hub/checkpoints/resnet34-b627a593.pth' # path to a pre-trained model, if using one +model_name: ucf # model name +backbone_name: xception # backbone name +encoder_feat_dim: 512 # feature dimension of the backbone + +#backbone setting +backbone_config: + mode: adjust_channel + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FF-F2F, FF-DF, FF-FS, FF-NT,] +test_dataset: [Celeb-DF-v2] +dataset_type: pair + +compression: c23 # compression-level for videos +train_batchSize: 16 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_feat: true # whether to save features + +# label settings +label_dict: + # DFD + DFD_fake: 1 + DFD_real: 0 + FaceShifter: 1 + FF-FH: 1 + # FF++ + FaceShifter(FF-real+FF-FH) + # ucf specific label setting + FF-DF: 1 + FF-F2F: 2 + FF-FS: 3 + FF-NT: 4 + FF-real: 0 + # CelebDF + CelebDFv1_real: 0 + CelebDFv1_fake: 1 + CelebDFv2_real: 0 + CelebDFv2_fake: 1 + # DFDCP + DFDCP_Real: 0 + DFDCP_FakeA: 1 + DFDCP_FakeB: 1 + # DFDC + DFDC_Fake: 1 + DFDC_Real: 0 + # DeeperForensics-1.0 + DF_fake: 1 + DF_real: 0 + # UADFV + UADFV_Fake: 1 + UADFV_Real: 0 + # roop + roop_Fake: 1 + roop_Real: 0 + + + +# data augmentation +use_data_augmentation: true # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 5 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: + cls_loss: cross_entropy # loss function to use + spe_loss: cross_entropy + con_loss: contrastive_regularization + rec_loss: l1loss +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/uia_vit.yaml b/training/config/detector/uia_vit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2384a2e110cba714faa2fb41521deb2ab7165a8 --- /dev/null +++ b/training/config/detector/uia_vit.yaml @@ -0,0 +1,87 @@ +# log dir +log_dir: /data/home/zhiyuanyan/logs/benchv2/uia_vit + +# model setting +model_name: uia_vit # model name +pretrained: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, roop] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00003 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.00001 # weight decay for regularization + amsgrad: false + + +# training config +#lr_scheduler: step +#lr_step: 2 +#lr_gamma: 0.8 +lr_scheduler: null +nEpochs: 10 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: false # whether to save checkpoint + +# loss function +loss_func: + cls_loss: cross_entropy # loss function to use + pcl_loss: patch_consistency_loss # Patch Consistency Loss + weights: [0.06, 0.05, 0.5] +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + + +batch_per_epoch: 1796 # should be equal to number of batch in one epoch, manually required + + diff --git a/training/config/detector/universal.yaml b/training/config/detector/universal.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7e08bdedb8424a1f976074ad611782fe54bf05c5 --- /dev/null +++ b/training/config/detector/universal.yaml @@ -0,0 +1,88 @@ +# configs/universal_vit_clip_linearhead.yaml + +# log dir +log_dir: logs/universal_vit_clip_linearhead + +# model setting +pretrained: null # Let timm download the CLIP weights automatically +model_name: universal # Keep using `universal`, corresponding to the modified class above +backbone_name: vit_base_patch16_clip_224.openai # timm backbone name + +# backbone setting +backbone_config: + num_classes: 36 # Set according to the number of classes in your task + inc: 3 + dropout: false + mode: Original + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 +train_batchSize: 64 +test_batchSize: 64 +workers: 8 +frame_num: {'train': 8, 'test': 16} +resolution: 224 +with_mask: false +with_landmark: false +save_ckpt: true +save_feat: true # Enable this if you also want to save features from the frozen backbone + +# data augmentation +use_data_augmentation: false +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization (CLIP defaults) +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + type: adam + adam: + lr: 0.0002 # This only applies to `classifier_head` + beta1: 0.9 + beta2: 0.999 + eps: 0.00000001 + weight_decay: 0.0005 + amsgrad: false + sgd: + lr: 0.0002 + momentum: 0.9 + weight_decay: 0.0005 + +# training config +lr_scheduler: null +nEpochs: 20 +start_epoch: 0 +save_epoch: 1 +rec_iter: 100 +logdir: ./logs +manualSeed: 1024 +save_ckpt: true +save_latest_ckpt: true + +# loss function +loss_func: cross_entropy +losstype: null + +# metric +metric_scoring: acc + +# cuda +cuda: true +cudnn: true + diff --git a/training/config/detector/videomae.yaml b/training/config/detector/videomae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..24b56b427d7e1c781cfaf77b5bbdafeaeccee931 --- /dev/null +++ b/training/config/detector/videomae.yaml @@ -0,0 +1,87 @@ +# log dir +log_dir: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/logs/videomae + +# model setting +pretrained: MCG-NJU/videomae-base-finetuned-kinetics # path to a pre-trained model, if using one +model_name: videomae # model name +backbone_name: null # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, DeepFakeDetection, FaceShifter] + +compression: c23 # compression-level for videos +train_batchSize: 8 # training batch size +test_batchSize: 8 # test batch size +workers: 16 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 244 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 8 # number of frames in each clip + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.000001 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/videomae_large.yaml b/training/config/detector/videomae_large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd1ea08660cb2ca4a16fd93cb4d562dcfee0cc75 --- /dev/null +++ b/training/config/detector/videomae_large.yaml @@ -0,0 +1,86 @@ +# log dir +log_dir: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/logs/videomae_large + +# model setting +pretrained: MCG-NJU/videomae-large # path to a pre-trained model, if using one +model_name: videomae_large # model name +backbone_name: null # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, DeepFakeDetection, FaceShifter] + +compression: c23 # compression-level for videos +train_batchSize: 8 # training batch size +test_batchSize: 8 # test batch size +workers: 16 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 16 # number of frames in each clip + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.000001 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 50 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/videomae_lora.yaml b/training/config/detector/videomae_lora.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a9a437b77d8c36ccafff976ec96c9dbb7180c230 --- /dev/null +++ b/training/config/detector/videomae_lora.yaml @@ -0,0 +1,87 @@ +# log dir +log_dir: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/logs/videomae_lora + +# model setting +pretrained: MCG-NJU/videomae-base-finetuned-kinetics # path to a pre-trained model, if using one +model_name: videomae_lora # model name +backbone_name: null # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, DeepFakeDetection, FaceShifter] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 16 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 8 # number of frames in each clip + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/detector/vit.yaml b/training/config/detector/vit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a0dc077de2390c014fb4363a85e992d471bec41 --- /dev/null +++ b/training/config/detector/vit.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: logs/vit_models + +# model setting +pretrained: google/vit-large-patch16-224 # path to a pre-trained model, if using one +model_name: vit_large_fft # model name +backbone_name: vit # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 32 # training batch size +test_batchSize: 32 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true diff --git a/training/config/detector/xception.yaml b/training/config/detector/xception.yaml new file mode 100644 index 0000000000000000000000000000000000000000..090b13c0e7c5f4f4cf10fbb0fa2dc513b597a287 --- /dev/null +++ b/training/config/detector/xception.yaml @@ -0,0 +1,89 @@ +# log dir +log_dir: logs/xception + +# model setting +pretrained: ./training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one +model_name: xception # model name +backbone_name: xception # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 36 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [protocol_2_train] +test_dataset: [protocol_2_test] + +compression: c23 # compression-level for videos +train_batchSize: 64 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 8, 'test': 16} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.0 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 20 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: acc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations + +# save latest ckpt +save_latest_ckpt: true \ No newline at end of file diff --git a/training/config/detector/xclip.yaml b/training/config/detector/xclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d443d53a92d67720af09851d3ff00920001a333f --- /dev/null +++ b/training/config/detector/xclip.yaml @@ -0,0 +1,88 @@ +# log dir +log_dir: /Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/logs/xclip + +# model setting +pretrained: microsoft/xclip-base-patch32 # path to a pre-trained model, if using one +model_name: xclip # model name +backbone_name: xception # backbone name + +#backbone setting +backbone_config: + mode: original + num_classes: 2 + inc: 3 + dropout: false + +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV, roop] +train_dataset: [FaceForensics++] +test_dataset: [Celeb-DF-v2, DeepFakeDetection, FaceShifter] + +compression: c23 # compression-level for videos +train_batchSize: 8 # training batch size +test_batchSize: 8 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 150, 'test': 300} # number of frames to use per video in training and testing +resolution: 224 # resolution of output image to network +with_mask: false # whether to include mask information in the input +with_landmark: false # whether to include facial landmark information in the input +video_mode: True # whether to use video-level data +clip_size: 32 # number of frames in each clip + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 100 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda + +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/training/config/test_config.yaml b/training/config/test_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d743c7b0b87f896dd7f34fc4e75d1569a70cf0f4 --- /dev/null +++ b/training/config/test_config.yaml @@ -0,0 +1,50 @@ +mode: test +lmdb: False +rgb_dir: '' +lmdb_dir: './datasets/lmdb' +dataset_json_folder: './preprocessing/dataset_json' +label_dict: + # DFD + DFD_fake: 1 + DFD_real: 0 + # FF++ + FaceShifter(FF-real+FF-FH) + FF-SH: 1 + FF-F2F: 1 + FF-DF: 1 + FF-FS: 1 + FF-NT: 1 + FF-FH: 1 + FF-real: 0 + # CelebDF + CelebDFv1_real: 0 + CelebDFv1_fake: 1 + CelebDFv2_real: 0 + CelebDFv2_fake: 1 + # DFDCP + DFDCP_Real: 0 + DFDCP_FakeA: 1 + DFDCP_FakeB: 1 + # DFDC + DFDC_Fake: 1 + DFDC_Real: 0 + # DeeperForensics-1.0 + DF_fake: 1 + DF_real: 0 + # UADFV + UADFV_Fake: 1 + UADFV_Real: 0 + # Roop + roop_Real: 0 + roop_Fake: 1 + # DFR WDF FFIW + DFR_Real: 0 + DFR_Fake: 1 + WDF_Real: 0 + WDF_Fake: 1 + FFIW_Real: 0 + FFIW_Fake: 1 + # DeepFakeGenome + Infinity_Real: 0 + Infinity_Fake: 1 + Hart_Real: 0 + Hart_Fake: 1 diff --git a/training/config/test_config_p2.yaml b/training/config/test_config_p2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..db42540420cda88b3372f80a5fc67fd93ea46e4a --- /dev/null +++ b/training/config/test_config_p2.yaml @@ -0,0 +1,49 @@ +mode: test +lmdb: False +rgb_dir: '' +lmdb_dir: './datasets/lmdb' +dataset_json_folder: './preprocessing/dataset_json' +ddp: false +label_dict: + # 2 types of real data + fsgan_Real: 0 + FF-real: 0 + # 35 types of fake data + + FF-DF: 1 + FF-F2F: 2 + FF-NT: 3 + FF-FH: 4 + fsgan_Fake: 5 + faceswap_Fake: 6 + inswap_Fake: 7 + simswap_Fake: 8 + blendface_Fake: 9 + uniface_Fake: 10 + e4s_Fake: 11 + facedancer_Fake: 12 + mobileswap_Fake: 13 + sadtalker_Fake: 14 + wav2lip_Fake: 15 + fomm_Fake: 16 + MRAA_Fake: 17 + one_shot_free_Fake: 18 + pirender_Fake: 19 + tpsm_Fake: 20 + lia_Fake: 21 + danet_Fake: 22 + mcnet_Fake: 23 + hyperreenact_Fake: 24 + facevid2vid_Fake: 25 + VQGAN_Fake: 26 + # StyleGAN2_Fake: 27 + StyleGAN3_Fake: 27 + StyleGANXL_Fake: 28 + ddim_Fake: 29 + sd2.1_Fake: 30 + rddm_Fake: 31 + pixart_Fake: 32 + DiT_Fake: 33 + SiT_Fake: 34 + e4e_Fake: 35 + diff --git a/training/config/test_config_p4.yaml b/training/config/test_config_p4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1b651e470714f8a6c687ac7f71a3333d8623211 --- /dev/null +++ b/training/config/test_config_p4.yaml @@ -0,0 +1,153 @@ +mode: test +lmdb: False +rgb_dir: '' +lmdb_dir: './datasets/lmdb' +dataset_json_folder: './preprocessing/dataset_json' +ddp: false +label_dict: + # 4 types of real data + "CelebA_1k_Real": 0 + "CelebA_5k_Real": 1 + "CelebA_1w_Real": 2 + "ForgeryNet_1w_Real": 3 + # 65 types of fake data + "deepfacelab_Fake": 4 + "heygen_Fake": 5 + "stargan_Fake": 6 + "starganv2_Fake": 7 + "styleclip_Fake": 8 + "whichisreal_Fake": 9 + "ATVG-Net_Fake": 10 + "FirstOrderMotion_Fake": 11 + "Talking-headVideo_Fake": 12 + "MMreplacement_Fake": 13 + "DiscoFaceGAN_Fake": 14 + "MaskGAN_Fake": 15 + "SC-FEGAN_Fake": 16 + "DSS_Fake": 17 + "SBS_Fake": 18 + "FaceAPP_Fake": 19 + "StyleGAN_Fake": 20 + "ProGAN_Fake": 21 + "MMDGAN_Fake": 22 + "SNGAN_Fake": 23 + "CramerGAN_Fake": 24 + "InfoMax-GAN_Fake": 25 + "SSGAN_Fake": 26 + "AttGAN_Fake": 27 + "Adobe Firefly_Fake": 28 + "Dall-E 3_Fake": 29 + "Flux.1_Fake": 30 + "Flux.1.1 Pro_Fake": 31 + "Freepik_Fake": 32 + "Leonardo AI_Fake": 33 + "Stable Diffusion 3.5_Fake": 34 + "Stable Diffusion XL_Fake": 35 + "Starry AI_Fake": 36 + "Dall-E 1_Fake": 37 + "Deep AI_Fake": 38 + "Hotpot AI_Fake": 39 + "Nvidia Sana PAG_Fake": 40 + "Stable Cascade_Fake": 41 + "Stable Diffusion Attend and Excite_Fake": 42 + "Tencent_Hunyuan_Fake": 43 + "Midjourney_Fake": 44 + "SDXL_Fake": 45 + "FreeDoM_T_Fake": 46 + "HPS_Fake": 47 + "SDXL_Refine_Fake": 48 + "DreamBooth_Fake": 49 + "DiffFace_Fake": 50 + "DCFace_Fake": 51 + "Imagic_Fake": 52 + "CoDiff_Fake": 53 + "cycle_diff_Fake": 54 + "PNDM_Fake": 55 + "ADM_Fake": 56 + "LDM_Fake": 57 + "SDv15_DS0.7_Fake": 58 + "Inpaint_Fake": 59 + "DiffSwap_Fake": 60 + "Qwen-Image_Fake": 61 + "BAGEL_Fake": 62 + "HunyuanImage-2.1_Fake": 63 + "Infinity_Fake": 64 + "Hart_Fake": 65 + "HunyuanImage-3.0_Fake": 66 + "NanoBanana_Fake": 67 + "GPT4o_Fake": 68 + +label2real_dict: + # 4 types of real data + "CelebA_1k_Real": 0 + "CelebA_5k_Real": 1 + "CelebA_1w_Real": 2 + "ForgeryNet_1w_Real": 3 + # 65 types of fake data + "deepfacelab_Fake": 1 + "heygen_Fake": 1 + "stargan_Fake": 1 + "starganv2_Fake": 1 + "styleclip_Fake": 1 + "whichisreal_Fake": 1 + "ATVG-Net_Fake": 3 + "FirstOrderMotion_Fake": 3 + "Talking-headVideo_Fake": 3 + "MMreplacement_Fake": 3 + "DiscoFaceGAN_Fake": 3 + "MaskGAN_Fake": 3 + "SC-FEGAN_Fake": 3 + "DSS_Fake": 3 + "SBS_Fake": 3 + "FaceAPP_Fake": 2 + "StyleGAN_Fake": 2 + "ProGAN_Fake": 2 + "MMDGAN_Fake": 2 + "SNGAN_Fake": 2 + "CramerGAN_Fake": 2 + "InfoMax-GAN_Fake": 2 + "SSGAN_Fake": 2 + "AttGAN_Fake": 1 + "Adobe Firefly_Fake": 0 + "Dall-E 3_Fake": 0 + "Flux.1_Fake": 0 + "Flux.1.1 Pro_Fake": 0 + "Freepik_Fake": 0 + "Leonardo AI_Fake": 0 + "Stable Diffusion 3.5_Fake": 0 + "Stable Diffusion XL_Fake": 0 + "Starry AI_Fake": 0 + "Dall-E 1_Fake": 0 + "Deep AI_Fake": 0 + "Hotpot AI_Fake": 0 + "Nvidia Sana PAG_Fake": 0 + "Stable Cascade_Fake": 0 + "Stable Diffusion Attend and Excite_Fake": 0 + "Tencent_Hunyuan_Fake": 0 + "Midjourney_Fake": 2 + "SDXL_Fake": 2 + "FreeDoM_T_Fake": 2 + "HPS_Fake": 2 + "SDXL_Refine_Fake": 2 + "DreamBooth_Fake": 2 + "DiffFace_Fake": 2 + "DCFace_Fake": 2 + "Imagic_Fake": 2 + "CoDiff_Fake": 2 + "cycle_diff_Fake": 2 + "PNDM_Fake": 2 + "ADM_Fake": 2 + "LDM_Fake": 2 + "SDv15_DS0.7_Fake": 2 + "Inpaint_Fake": 2 + "DiffSwap_Fake": 2 + "Qwen-Image_Fake": 1 + "BAGEL_Fake": 1 + "HunyuanImage-2.1_Fake": 1 + "Infinity_Fake": 1 + "Hart_Fake": 1 + "HunyuanImage-3.0_Fake": 1 + "NanoBanana_Fake": 0 + "GPT4o_Fake": 0 + + diff --git a/training/config/train_config.yaml b/training/config/train_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9bb3b1f26d8afc5734283afc7086f0082d4b49f3 --- /dev/null +++ b/training/config/train_config.yaml @@ -0,0 +1,50 @@ +mode: train +lmdb: False +dry_run: false +rgb_dir: '' +lmdb_dir: './datasets/lmdb' +dataset_json_folder: './preprocessing/dataset_json' +SWA: False +save_avg: True +# log_dir: ./logs/training/ +# label settings +label_dict: + # DFD + DFD_fake: 1 + DFD_real: 0 + # FF++ + FaceShifter(FF-real+FF-FH) + FF-SH: 1 + FF-F2F: 1 + FF-DF: 1 + FF-FS: 1 + FF-NT: 1 + FF-FH: 1 + FF-real: 0 + # CelebDF + CelebDFv1_real: 0 + CelebDFv1_fake: 1 + CelebDFv2_real: 0 + CelebDFv2_fake: 1 + # DFDCP + DFDCP_Real: 0 + DFDCP_FakeA: 1 + DFDCP_FakeB: 1 + # DFDC + DFDC_Fake: 1 + DFDC_Real: 0 + # DeeperForensics-1.0 + DF_fake: 1 + DF_real: 0 + # UADFV + UADFV_Fake: 1 + UADFV_Real: 0 + # Roop + roop_Real: 0 + roop_Fake: 1 + # DFR WDF FFIW + DFR_Real: 0 + DFR_Fake: 1 + WDF_Real: 0 + WDF_Fake: 1 + FFIW_Real: 0 + FFIW_Fake: 1 \ No newline at end of file diff --git a/training/config/train_config_p2.yaml b/training/config/train_config_p2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2564cdafcd6ff9291cf15944a42a029412b0cef4 --- /dev/null +++ b/training/config/train_config_p2.yaml @@ -0,0 +1,52 @@ +mode: train +lmdb: False +dry_run: false +rgb_dir: '' +lmdb_dir: './datasets/lmdb' +dataset_json_folder: './preprocessing/dataset_json' +SWA: False +save_avg: True +# log_dir: ./logs/training/ +# label settings +label_dict: + # 1 type of real data + FF-real: 0 + # 35 types of fake data + + FF-DF: 1 + FF-F2F: 2 + FF-NT: 3 + FF-FH: 4 + fsgan_Fake: 5 + faceswap_Fake: 6 + inswap_Fake: 7 + simswap_Fake: 8 + blendface_Fake: 9 + uniface_Fake: 10 + e4s_Fake: 11 + facedancer_Fake: 12 + mobileswap_Fake: 13 + sadtalker_Fake: 14 + wav2lip_Fake: 15 + fomm_Fake: 16 + MRAA_Fake: 17 + one_shot_free_Fake: 18 + pirender_Fake: 19 + tpsm_Fake: 20 + lia_Fake: 21 + danet_Fake: 22 + mcnet_Fake: 23 + hyperreenact_Fake: 24 + facevid2vid_Fake: 25 + VQGAN_Fake: 26 + # StyleGAN2_Fake: 27 + StyleGAN3_Fake: 27 + StyleGANXL_Fake: 28 + ddim_Fake: 29 + sd2.1_Fake: 30 + rddm_Fake: 31 + pixart_Fake: 32 + DiT_Fake: 33 + SiT_Fake: 34 + e4e_Fake: 35 + diff --git a/training/dataset/I2G_dataset.py b/training/dataset/I2G_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a09c505fbabd418b159483201ee0ba22ea8b006c --- /dev/null +++ b/training/dataset/I2G_dataset.py @@ -0,0 +1,389 @@ +# Created by: Kaede Shiohara +# Yamasaki Lab at The University of Tokyo +# shiohara@cvm.t.u-tokyo.ac.jp +# Copyright (c) 2021 +# 3rd party softwares' licenses are noticed at https://github.com/mapooon/SelfBlendedImages/blob/master/LICENSE +import logging +import os +import pickle + +import cv2 +import numpy as np +import scipy as sp +import yaml +from skimage.measure import label, regionprops +import random +from PIL import Image +import sys +import albumentations as A +from torch.utils.data import DataLoader +from dataset.utils.bi_online_generation import random_get_hull +from dataset.abstract_dataset import DeepfakeAbstractBaseDataset +from dataset.pair_dataset import pairDataset +import torch + +class RandomDownScale(A.core.transforms_interface.ImageOnlyTransform): + def apply(self, img, ratio_list=None, **params): + if ratio_list is None: + ratio_list = [2, 4] + r = ratio_list[np.random.randint(len(ratio_list))] + return self.randomdownscale(img, r) + + def randomdownscale(self, img, r): + keep_ratio = True + keep_input_shape = True + H, W, C = img.shape + + img_ds = cv2.resize(img, (int(W / r), int(H / r)), interpolation=cv2.INTER_NEAREST) + if keep_input_shape: + img_ds = cv2.resize(img_ds, (W, H), interpolation=cv2.INTER_LINEAR) + + return img_ds + + +''' +from PIL import ImageDraw + +img_pil=Image.fromarray(img) +draw = ImageDraw.Draw(img_pil) + + +for i, point in enumerate(landmark): + x, y = point + radius = 1 + draw.ellipse((x-radius, y-radius, x+radius, y+radius), fill="red") + draw.text((x+radius+2, y-radius), str(i), fill="black") +img_pil.show() + +''' + +def alpha_blend(source, target, mask): + mask_blured = get_blend_mask(mask) + img_blended = (mask_blured * source + (1 - mask_blured) * target) + return img_blended, mask_blured + + +def dynamic_blend(source, target, mask): + mask_blured = get_blend_mask(mask) + # worth consideration, 1 in the official paper, 0.25, 0.5, 0.75,1,1,1 in sbi. + blend_list = [1, 1, 1] + blend_ratio = blend_list[np.random.randint(len(blend_list))] + mask_blured *= blend_ratio + img_blended = (mask_blured * source + (1 - mask_blured) * target) + return img_blended, mask_blured + + +def get_blend_mask(mask): + H, W = mask.shape + size_h = np.random.randint(192, 257) + size_w = np.random.randint(192, 257) + mask = cv2.resize(mask, (size_w, size_h)) + kernel_1 = random.randrange(5, 26, 2) + kernel_1 = (kernel_1, kernel_1) + kernel_2 = random.randrange(5, 26, 2) + kernel_2 = (kernel_2, kernel_2) + + mask_blured = cv2.GaussianBlur(mask, kernel_1, 0) + mask_blured = mask_blured / (mask_blured.max()) + mask_blured[mask_blured < 1] = 0 + + mask_blured = cv2.GaussianBlur(mask_blured, kernel_2, np.random.randint(5, 46)) + mask_blured = mask_blured / (mask_blured.max()) + mask_blured = cv2.resize(mask_blured, (W, H)) + return mask_blured.reshape((mask_blured.shape + (1,))) + + +def get_alpha_blend_mask(mask): + kernel_list = [(11, 11), (9, 9), (7, 7), (5, 5), (3, 3)] + blend_list = [0.25, 0.5, 0.75] + kernel_idxs = random.choices(range(len(kernel_list)), k=2) + blend_ratio = blend_list[random.sample(range(len(blend_list)), 1)[0]] + mask_blured = cv2.GaussianBlur(mask, kernel_list[0], 0) + # print(mask_blured.max()) + mask_blured[mask_blured < mask_blured.max()] = 0 + mask_blured[mask_blured > 0] = 1 + # mask_blured = mask + mask_blured = cv2.GaussianBlur(mask_blured, kernel_list[kernel_idxs[1]], 0) + mask_blured = mask_blured / (mask_blured.max()) + return mask_blured.reshape((mask_blured.shape + (1,))) + + +class I2GDataset(DeepfakeAbstractBaseDataset): + def __init__(self, config=None, mode='train'): + #config['GridShuffle']['p'] = 0 + super().__init__(config, mode) + real_images_list = [img for img, label in zip(self.image_list, self.label_list) if label == 0] + self.real_images_list = list(set(real_images_list)) # de-duplicate since DF,F2F,FS,NT have same real images + self.source_transforms = self.get_source_transforms() + self.transforms = self.get_transforms() + self.init_nearest() + + def init_nearest(self): + if os.path.exists('training/lib/nearest_face_info.pkl'): + with open('training/lib/nearest_face_info.pkl', 'rb') as f: + face_info = pickle.load(f) + self.face_info = face_info + # Check if the dictionary has already been created + if os.path.exists('training/lib/landmark_dict_ffall.pkl'): + with open('training/lib/landmark_dict_ffall.pkl', 'rb') as f: + landmark_dict = pickle.load(f) + self.landmark_dict = landmark_dict + + def reorder_landmark(self, landmark): + landmark = landmark.copy() + landmark_add = np.zeros((13, 2)) + for idx, idx_l in enumerate([77, 75, 76, 68, 69, 70, 71, 80, 72, 73, 79, 74, 78]): + landmark_add[idx] = landmark[idx_l] + landmark[68:] = landmark_add + return landmark + + def hflip(self, img, mask=None, landmark=None, bbox=None): + H, W = img.shape[:2] + landmark = landmark.copy() + if bbox is not None: + bbox = bbox.copy() + + if landmark is not None: + landmark_new = np.zeros_like(landmark) + + landmark_new[:17] = landmark[:17][::-1] + landmark_new[17:27] = landmark[17:27][::-1] + + landmark_new[27:31] = landmark[27:31] + landmark_new[31:36] = landmark[31:36][::-1] + + landmark_new[36:40] = landmark[42:46][::-1] + landmark_new[40:42] = landmark[46:48][::-1] + + landmark_new[42:46] = landmark[36:40][::-1] + landmark_new[46:48] = landmark[40:42][::-1] + + landmark_new[48:55] = landmark[48:55][::-1] + landmark_new[55:60] = landmark[55:60][::-1] + + landmark_new[60:65] = landmark[60:65][::-1] + landmark_new[65:68] = landmark[65:68][::-1] + if len(landmark) == 68: + pass + elif len(landmark) == 81: + landmark_new[68:81] = landmark[68:81][::-1] + else: + raise NotImplementedError + landmark_new[:, 0] = W - landmark_new[:, 0] + + else: + landmark_new = None + + if bbox is not None: + bbox_new = np.zeros_like(bbox) + bbox_new[0, 0] = bbox[1, 0] + bbox_new[1, 0] = bbox[0, 0] + bbox_new[:, 0] = W - bbox_new[:, 0] + bbox_new[:, 1] = bbox[:, 1].copy() + if len(bbox) > 2: + bbox_new[2, 0] = W - bbox[3, 0] + bbox_new[2, 1] = bbox[3, 1] + bbox_new[3, 0] = W - bbox[2, 0] + bbox_new[3, 1] = bbox[2, 1] + bbox_new[4, 0] = W - bbox[4, 0] + bbox_new[4, 1] = bbox[4, 1] + bbox_new[5, 0] = W - bbox[6, 0] + bbox_new[5, 1] = bbox[6, 1] + bbox_new[6, 0] = W - bbox[5, 0] + bbox_new[6, 1] = bbox[5, 1] + else: + bbox_new = None + + if mask is not None: + mask = mask[:, ::-1] + else: + mask = None + img = img[:, ::-1].copy() + return img, mask, landmark_new, bbox_new + + + + def get_source_transforms(self): + return A.Compose([ + A.Compose([ + A.RGBShift((-20, 20), (-20, 20), (-20, 20), p=0.3), + A.HueSaturationValue(hue_shift_limit=(-0.3, 0.3), sat_shift_limit=(-0.3, 0.3), + val_shift_limit=(-0.3, 0.3), p=1), + A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=1), + ], p=1), + + A.OneOf([ + RandomDownScale(p=1), + A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1), + ], p=1), + + ], p=1.) + + def get_fg_bg(self, one_lmk_path): + """ + Get foreground and background paths + """ + bg_lmk_path = one_lmk_path + # Randomly pick one from the nearest neighbors for the foreground + if bg_lmk_path in self.face_info: + fg_lmk_path = random.choice(self.face_info[bg_lmk_path]) + else: + fg_lmk_path = bg_lmk_path + return fg_lmk_path, bg_lmk_path + + def get_transforms(self): + return A.Compose([ + + A.RGBShift((-20, 20), (-20, 20), (-20, 20), p=0.3), + A.HueSaturationValue(hue_shift_limit=(-0.3, 0.3), sat_shift_limit=(-0.3, 0.3), + val_shift_limit=(-0.3, 0.3), p=0.3), + A.RandomBrightnessContrast(brightness_limit=(-0.3, 0.3), contrast_limit=(-0.3, 0.3), p=0.3), + A.ImageCompression(quality_lower=40, quality_upper=100, p=0.5), + + ], + additional_targets={f'image1': 'image'}, + p=1.) + + def randaffine(self, img, mask): + f = A.Affine( + translate_percent={'x': (-0.03, 0.03), 'y': (-0.015, 0.015)}, + scale=[0.95, 1 / 0.95], + fit_output=False, + p=1) + + g = A.ElasticTransform( + alpha=50, + sigma=7, + alpha_affine=0, + p=1, + ) + + transformed = f(image=img, mask=mask) + img = transformed['image'] + + mask = transformed['mask'] + transformed = g(image=img, mask=mask) + mask = transformed['mask'] + return img, mask + + def __len__(self): + return len(self.real_images_list) + + + def colorTransfer(self, src, dst, mask): + transferredDst = np.copy(dst) + maskIndices = np.where(mask != 0) + maskedSrc = src[maskIndices[0], maskIndices[1]].astype(np.float32) + maskedDst = dst[maskIndices[0], maskIndices[1]].astype(np.float32) + + # Compute means and standard deviations + meanSrc = np.mean(maskedSrc, axis=0) + stdSrc = np.std(maskedSrc, axis=0) + meanDst = np.mean(maskedDst, axis=0) + stdDst = np.std(maskedDst, axis=0) + + # Perform color transfer + maskedDst = (maskedDst - meanDst) * (stdSrc / stdDst) + meanSrc + maskedDst = np.clip(maskedDst, 0, 255) + + # Copy the entire background into transferredDst + transferredDst = np.copy(dst) + # Now apply color transfer only to the masked region + transferredDst[maskIndices[0], maskIndices[1]] = maskedDst.astype(np.uint8) + + return transferredDst + + + + def two_blending(self, img_bg, img_fg, landmark): + H, W = len(img_bg), len(img_bg[0]) + if np.random.rand() < 0.25: + landmark = landmark[:68] + logging.disable(logging.FATAL) + mask = random_get_hull(landmark, img_bg) + logging.disable(logging.NOTSET) + source = img_fg.copy() + target = img_bg.copy() + # if np.random.rand() < 0.5: + # source = self.source_transforms(image=source.astype(np.uint8))['image'] + # else: + # target = self.source_transforms(image=target.astype(np.uint8))['image'] + source_v2, mask_v2 = self.randaffine(source, mask) + source_v3=self.colorTransfer(target,source_v2,mask_v2) + img_blended, mask = dynamic_blend(source_v3, target, mask_v2) + img_blended = img_blended.astype(np.uint8) + img = img_bg.astype(np.uint8) + + return img, img_blended, mask.squeeze(2) + + + def __getitem__(self, index): + image_path_bg = self.real_images_list[index] + label = 0 + + # Get the mask and landmark paths + landmark_path_bg = image_path_bg.replace('frames', 'landmarks').replace('.png', '.npy') # Use .npy for landmark + landmark_path_fg, landmark_path_bg = self.get_fg_bg(landmark_path_bg) + image_path_fg = landmark_path_fg.replace('landmarks','frames').replace('.npy','.png') + try: + image_bg = self.load_rgb(image_path_bg) + image_fg = self.load_rgb(image_path_fg) + except Exception as e: + # Skip this image and return the first one + print(f"Error loading image at index {index}: {e}") + return self.__getitem__(0) + image_bg = np.array(image_bg) # Convert to numpy array for data augmentation + image_fg = np.array(image_fg) # Convert to numpy array for data augmentation + + landmarks_bg = self.load_landmark(landmark_path_bg) + landmarks_fg = self.load_landmark(landmark_path_fg) + + + landmarks_bg = np.clip(landmarks_bg, 0, self.config['resolution'] - 1) + landmarks_bg = self.reorder_landmark(landmarks_bg) + + img_r, img_f, mask_f = self.two_blending(image_bg.copy(), image_fg.copy(),landmarks_bg.copy()) + transformed = self.transforms(image=img_f.astype('uint8'), image1=img_r.astype('uint8')) + img_f = transformed['image'] + img_r = transformed['image1'] + # img_f = img_f.transpose((2, 0, 1)) + # img_r = img_r.transpose((2, 0, 1)) + img_f = self.normalize(self.to_tensor(img_f)) + img_r = self.normalize(self.to_tensor(img_r)) + mask_f = self.to_tensor(mask_f) + mask_r=torch.zeros_like(mask_f) # zeros or ones + return img_f, img_r, mask_f,mask_r + + @staticmethod + def collate_fn(batch): + img_f, img_r, mask_f,mask_r = zip(*batch) + data = {} + fake_mask = torch.stack(mask_f,dim=0) + real_mask = torch.stack(mask_r, dim=0) + fake_images = torch.stack(img_f, dim=0) + real_images = torch.stack(img_r, dim=0) + data['image'] = torch.cat([real_images, fake_images], dim=0) + data['label'] = torch.tensor([0] * len(img_r) + [1] * len(img_f)) + data['landmark'] = None + data['mask'] = torch.cat([real_mask, fake_mask], dim=0) + return data + + +if __name__ == '__main__': + detector_path = r"./training/config/detector/xception.yaml" + # weights_path = "./ckpts/xception/CDFv2/tb_v1/ov.pth" + with open(detector_path, 'r') as f: + config = yaml.safe_load(f) + with open('./training/config/train_config.yaml', 'r') as f: + config2 = yaml.safe_load(f) + config2['data_manner'] = 'lmdb' + config['dataset_json_folder'] = 'preprocessing/dataset_json_v3' + config.update(config2) + dataset = I2GDataset(config=config) + batch_size = 2 + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,collate_fn=dataset.collate_fn) + + for i, batch in enumerate(dataloader): + print(f"Batch {i}: {batch}") + continue + diff --git a/training/dataset/__init__.py b/training/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d34851aef959f7487e13f3d8d195e650abb0489 --- /dev/null +++ b/training/dataset/__init__.py @@ -0,0 +1,19 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + + +from .I2G_dataset import I2GDataset +from .iid_dataset import IIDDataset +from .abstract_dataset import DeepfakeAbstractBaseDataset +from .ff_blend import FFBlendDataset +from .fwa_blend import FWABlendDataset +from .lrl_dataset import LRLDataset +from .pair_dataset import pairDataset +from .sbi_dataset import SBIDataset +from .lsda_dataset import LSDADataset +from .tall_dataset import TALLDataset diff --git a/training/dataset/abstract_dataset.py b/training/dataset/abstract_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e17c7c50aa8f128d56534aad6ce716e6cf938854 --- /dev/null +++ b/training/dataset/abstract_dataset.py @@ -0,0 +1,655 @@ +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-03-30 +# description: Abstract Base Class for all types of deepfake datasets. + +import sys + +import lmdb + +sys.path.append('.') + +import os +import math +import yaml +import glob +import json + +import numpy as np +from copy import deepcopy +import cv2 +import random +from PIL import Image +from collections import defaultdict + +import torch +from torch.autograd import Variable +from torch.utils import data +from torchvision import transforms as T + +import albumentations as A + +from .albu import IsotropicResize + +FFpp_pool=['FaceForensics++','FaceShifter','DeepFakeDetection','FF-DF','FF-F2F','FF-FS','FF-NT']# + +# Change this to your actual dataset root path +DATASET_GLOBAL_PATH = "/dockerdata/deepfakes_detection_datasets/" + + +def all_in_pool(inputs,pool): + for each in inputs: + if each not in pool: + return False + return True + + +class DeepfakeAbstractBaseDataset(data.Dataset): + """ + Abstract base class for all deepfake datasets. + """ + def __init__(self, config=None, mode='train'): + """Initializes the dataset object. + + Args: + config (dict): A dictionary containing configuration parameters. + mode (str): A string indicating the mode (train or test). + + Raises: + NotImplementedError: If mode is not train or test. + """ + + # Set the configuration and mode + self.config = config + self.mode = mode + self.compression = config['compression'] + self.frame_num = config['frame_num'][mode] # + + # Check if 'video_mode' exists in config, otherwise set video_level to False + self.video_level = config.get('video_mode', False) + self.clip_size = config.get('clip_size', None) + self.lmdb = config.get('lmdb', False) + # Dataset dictionary + self.image_list = [] + self.label_list = [] + + # Set the dataset dictionary based on the mode + if mode == 'train': + dataset_list = config['train_dataset'] + # Training data should be collected together for training + image_list, label_list = [], [] + for one_data in dataset_list: + tmp_image, tmp_label, tmp_name = self.collect_img_and_label_for_one_dataset(one_data) + image_list.extend(tmp_image) + label_list.extend(tmp_label) + if self.lmdb: + if len(dataset_list)>1: + if all_in_pool(dataset_list,FFpp_pool): + lmdb_path = os.path.join(config['lmdb_dir'], f"FaceForensics++_lmdb") + self.env = lmdb.open(lmdb_path, create=False, subdir=True, readonly=True, lock=False) + else: + raise ValueError('Training with multiple dataset and lmdb is not implemented yet.') + else: + lmdb_path = os.path.join(config['lmdb_dir'], f"{dataset_list[0] if dataset_list[0] not in FFpp_pool else 'FaceForensics++'}_lmdb") + self.env = lmdb.open(lmdb_path, create=False, subdir=True, readonly=True, lock=False) + elif mode == 'test': + one_data = config['test_dataset'] + # Test dataset should be evaluated separately. So collect only one dataset each time + image_list, label_list, name_list = self.collect_img_and_label_for_one_dataset(one_data) + if self.lmdb: + lmdb_path = os.path.join(config['lmdb_dir'], f"{one_data}_lmdb" if one_data not in FFpp_pool else 'FaceForensics++_lmdb') + self.env = lmdb.open(lmdb_path, create=False, subdir=True, readonly=True, lock=False) + else: + raise NotImplementedError('Only train and test modes are supported.') + + assert len(image_list)!=0 and len(label_list)!=0, f"Collect nothing for {mode} mode!" + self.image_list, self.label_list = image_list, label_list + + + # Create a dictionary containing the image and label lists + self.data_dict = { + 'image': self.image_list, + 'label': self.label_list, + } + + self.transform = self.init_data_aug_method() + + def init_data_aug_method(self): + # trans = A.Compose([ + # A.HorizontalFlip(p=self.config['data_aug']['flip_prob']), + # A.Rotate(limit=self.config['data_aug']['rotate_limit'], p=self.config['data_aug']['rotate_prob']), + # A.GaussianBlur(blur_limit=self.config['data_aug']['blur_limit'], p=self.config['data_aug']['blur_prob']), + # A.OneOf([ + # IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC), + # IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR), + # IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR), + # ], p = 0 if self.config['with_landmark'] else 1), + # A.OneOf([ + # A.RandomBrightnessContrast(brightness_limit=self.config['data_aug']['brightness_limit'], contrast_limit=self.config['data_aug']['contrast_limit']), + # A.FancyPCA(), + # A.HueSaturationValue() + # ], p=0.5), + # A.ImageCompression(quality_lower=self.config['data_aug']['quality_lower'], quality_upper=self.config['data_aug']['quality_upper'], p=0.5) + # ], + # keypoint_params=A.KeypointParams(format='xy') if self.config['with_landmark'] else None + # ) + + # video no aug + trans = A.Compose([ + A.HorizontalFlip(p=0.5), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5), + A.HueSaturationValue(p=0.3), + A.ImageCompression(quality_lower=40, quality_upper=100, p=0.1), # Image compression: 40-100, p=0.1 + A.GaussNoise(p=0.1), + A.MotionBlur(p=0.1), + A.CLAHE(p=0.1), + A.ChannelShuffle(p=0.1), + A.Cutout(p=0.1), + A.RandomGamma(p=0.3), + A.GlassBlur(p=0.3), + ]) + + return trans + + def rescale_landmarks(self, landmarks, original_size=256, new_size=224): + scale_factor = new_size / original_size + rescaled_landmarks = landmarks * scale_factor + return rescaled_landmarks + + + def collect_img_and_label_for_one_dataset(self, dataset_name: str): + """Collects image and label lists. + + Args: + dataset_name (str): A list containing one dataset information. e.g., 'FF-F2F' + + Returns: + list: A list of image paths. + list: A list of labels. + + Raises: + ValueError: If image paths or labels are not found. + NotImplementedError: If the dataset is not implemented yet. + """ + # Initialize the label and frame path lists + label_list = [] + frame_path_list = [] + + # Record video name for video-level metrics + video_name_list = [] + + # Try to get the dataset information from the JSON file + if not os.path.exists(self.config['dataset_json_folder']): + self.config['dataset_json_folder'] = self.config['dataset_json_folder'].replace('/Youtu_Pangu_Security_Public', '/Youtu_Pangu_Security/public') + try: + with open(os.path.join(self.config['dataset_json_folder'], dataset_name + '.json'), 'r') as f: + dataset_info = json.load(f) + except Exception as e: + print(e) + raise ValueError(f'dataset {dataset_name} not exist!') + + # If JSON file exists, do the following data collection + # FIXME: ugly, need to be modified here. + cp = None + if dataset_name == 'FaceForensics++_c40': + dataset_name = 'FaceForensics++' + cp = 'c40' + elif dataset_name == 'FF-DF_c40': + dataset_name = 'FF-DF' + cp = 'c40' + elif dataset_name == 'FF-F2F_c40': + dataset_name = 'FF-F2F' + cp = 'c40' + elif dataset_name == 'FF-FS_c40': + dataset_name = 'FF-FS' + cp = 'c40' + elif dataset_name == 'FF-NT_c40': + dataset_name = 'FF-NT' + cp = 'c40' + + # Get the information for the current dataset + dataset_name = list(dataset_info.keys())[0] + for label in dataset_info[dataset_name]: + sub_dataset_info = dataset_info[dataset_name][label][self.mode] + # Special case for FaceForensics++ and DeepFakeDetection, choose the compression type + if cp == None and dataset_name in ['FF-DF', 'FF-F2F', 'FF-FS', 'FF-NT', 'FaceForensics++','DeepFakeDetection','FaceShifter']: + sub_dataset_info = sub_dataset_info[self.compression] + elif cp == 'c40' and dataset_name in ['FF-DF', 'FF-F2F', 'FF-FS', 'FF-NT', 'FaceForensics++','DeepFakeDetection','FaceShifter']: + sub_dataset_info = sub_dataset_info['c40'] + + # Iterate over the videos in the dataset + for video_name, video_info in sub_dataset_info.items(): + # Unique video name + unique_video_name = video_info['label'] + '_' + video_name + + # Get the label and frame paths for the current video + if video_info['label'] not in self.config['label_dict']: + raise ValueError(f'Label {video_info["label"]} is not found in the configuration file.') + label = self.config['label_dict'][video_info['label']] + frame_paths = video_info['frames'] + + frame_paths = [DATASET_GLOBAL_PATH + i for i in frame_paths] + + # sorted video path to the lists + if len(frame_paths) == 0: + continue + + if '\\' in frame_paths[0]: + frame_paths = sorted(frame_paths, key=lambda x: str(x.split('\\')[-1].split('.')[0])) + else: + frame_paths = sorted(frame_paths, key=lambda x: str(x.split('/')[-1].split('.')[0])) + + # Consider the case when the actual number of frames (e.g., 270) is larger than the specified (i.e., self.frame_num=32) + # In this case, we select self.frame_num frames from the original 270 frames + total_frames = len(frame_paths) + if self.frame_num < total_frames: + total_frames = self.frame_num + if self.video_level: + # Select clip_size continuous frames + start_frame = random.randint(0, total_frames - self.frame_num) if self.mode == 'train' else 0 + frame_paths = frame_paths[start_frame:start_frame + self.frame_num] # update total_frames + else: + # Select self.frame_num frames evenly distributed throughout the video + step = total_frames // self.frame_num + frame_paths = [frame_paths[i] for i in range(0, total_frames, step)][:self.frame_num] + + # If video-level methods, crop clips from the selected frames if needed + if self.video_level: + if self.clip_size is None: + raise ValueError('clip_size must be specified when video_level is True.') + # Check if the number of total frames is greater than or equal to clip_size + if total_frames >= self.clip_size: + # Initialize an empty list to store the selected continuous frames + selected_clips = [] + + # Calculate the number of clips to select + num_clips = total_frames // self.clip_size + + if num_clips > 1: + # Calculate the step size between each clip + clip_step = (total_frames - self.clip_size) // (num_clips - 1) + + # Select clip_size continuous frames from each part of the video + for i in range(num_clips): + # Ensure start_frame + self.clip_size - 1 does not exceed the index of the last frame + start_frame = random.randrange(i * clip_step, min((i + 1) * clip_step, total_frames - self.clip_size + 1)) if self.mode == 'train' else i * clip_step + continuous_frames = frame_paths[start_frame:start_frame + self.clip_size] + assert len(continuous_frames) == self.clip_size, 'clip_size is not equal to the length of frame_path_list' + selected_clips.append(continuous_frames) + + else: + start_frame = random.randrange(0, total_frames - self.clip_size + 1) if self.mode == 'train' else 0 + continuous_frames = frame_paths[start_frame:start_frame + self.clip_size] + assert len(continuous_frames)==self.clip_size, 'clip_size is not equal to the length of frame_path_list' + selected_clips.append(continuous_frames) + + # Append the list of selected clips and append the label + label_list.extend([label] * len(selected_clips)) + frame_path_list.extend(selected_clips) + # video name save + video_name_list.extend([unique_video_name] * len(selected_clips)) + + else: + print(f"Skipping video {unique_video_name} because it has less than clip_size ({self.clip_size}) frames ({total_frames}).") + + # Otherwise, extend the label and frame paths to the lists according to the number of frames + else: + # Extend the label and frame paths to the lists according to the number of frames + label_list.extend([label] * total_frames) + frame_path_list.extend(frame_paths) + # video name save + video_name_list.extend([unique_video_name] * len(frame_paths)) + + # Shuffle the label and frame path lists in the same order + shuffled = list(zip(label_list, frame_path_list, video_name_list)) + random.shuffle(shuffled) + label_list, frame_path_list, video_name_list = zip(*shuffled) + + return frame_path_list, label_list, video_name_list + + + def load_rgb(self, file_path): + """ + Load an RGB image from a file path and resize it to a specified resolution. + + Args: + file_path: A string indicating the path to the image file. + + Returns: + An Image object containing the loaded and resized image. + + Raises: + ValueError: If the loaded image is None. + """ + size = self.config['resolution'] # if self.mode == "train" else self.config['resolution'] + if not self.lmdb: + # if not file_path[0] == '.': + # file_path = f'./{self.config["rgb_dir"]}\\'+file_path + if not os.path.exists(file_path): + file_path = file_path.replace('\\', '/') + assert os.path.exists(file_path), f"{file_path} does not exist" + img = cv2.imread(file_path) + if img is None: + raise ValueError('Loaded image is None: {}'.format(file_path)) + elif self.lmdb: + with self.env.begin(write=False) as txn: + # transfer the path format from rgb-path to lmdb-key + if file_path[0]=='.': + file_path=file_path.replace('./datasets\\','') + + image_bin = txn.get(file_path.encode()) + image_buf = np.frombuffer(image_bin, dtype=np.uint8) + img = cv2.imdecode(image_buf, cv2.IMREAD_COLOR) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) + return Image.fromarray(np.array(img, dtype=np.uint8)) + + + def load_mask(self, file_path): + """ + Load a binary mask image from a file path and resize it to a specified resolution. + + Args: + file_path: A string indicating the path to the mask file. + + Returns: + A numpy array containing the loaded and resized mask. + + Raises: + None. + """ + size = self.config['resolution'] + if file_path is None: + return np.zeros((size, size, 1)) + if not self.lmdb: + # if not file_path[0] == '.': + # file_path = f'./{self.config["rgb_dir"]}\\'+file_path + if os.path.exists(file_path): + mask = cv2.imread(file_path, 0) + if mask is None: + mask = np.zeros((size, size)) + else: + return np.zeros((size, size, 1)) + else: + with self.env.begin(write=False) as txn: + # transfer the path format from rgb-path to lmdb-key + if file_path[0]=='.': + file_path=file_path.replace('./datasets\\','') + + image_bin = txn.get(file_path.encode()) + if image_bin is None: + mask = np.zeros((size, size,3)) + else: + image_buf = np.frombuffer(image_bin, dtype=np.uint8) + + mask = cv2.imdecode(image_buf, cv2.IMREAD_COLOR) + mask = cv2.resize(mask, (size, size)) / 255 + mask = np.expand_dims(mask, axis=2) + return np.float32(mask) + + def load_landmark(self, file_path): + """ + Load 2D facial landmarks from a file path. + + Args: + file_path: A string indicating the path to the landmark file. + + Returns: + A numpy array containing the loaded landmarks. + + Raises: + None. + """ + if file_path is None: + return np.zeros((81, 2)) + if not self.lmdb: + # if not file_path[0] == '.': + # file_path = f'./{self.config["rgb_dir"]}\\'+file_path + if os.path.exists(file_path): + landmark = np.load(file_path) + else: + return np.zeros((81, 2)) + else: + with self.env.begin(write=False) as txn: + # transfer the path format from rgb-path to lmdb-key + if file_path[0]=='.': + file_path=file_path.replace('./datasets\\','') + binary = txn.get(file_path.encode()) + landmark = np.frombuffer(binary, dtype=np.uint32).reshape((81, 2)) + landmark=self.rescale_landmarks(np.float32(landmark), original_size=256, new_size=self.config['resolution']) + return landmark + + def to_tensor(self, img): + """ + Convert an image to a PyTorch tensor. + """ + return T.ToTensor()(img) + + def normalize(self, img): + """ + Normalize an image. + """ + mean = self.config['mean'] + std = self.config['std'] + normalize = T.Normalize(mean=mean, std=std) + return normalize(img) + + def data_aug(self, img, landmark=None, mask=None, augmentation_seed=None): + """ + Apply data augmentation to an image, landmark, and mask. + + Args: + img: An Image object containing the image to be augmented. + landmark: A numpy array containing the 2D facial landmarks to be augmented. + mask: A numpy array containing the binary mask to be augmented. + + Returns: + The augmented image, landmark, and mask. + """ + + # Set the seed for the random number generator + if augmentation_seed is not None: + random.seed(augmentation_seed) + np.random.seed(augmentation_seed) + + # Create a dictionary of arguments + kwargs = {'image': img} + + # Check if the landmark and mask are not None + if landmark is not None: + kwargs['keypoints'] = landmark + kwargs['keypoint_params'] = A.KeypointParams(format='xy') + if mask is not None: + mask = mask.squeeze(2) + if mask.max() > 0: + kwargs['mask'] = mask + + # Apply data augmentation + transformed = self.transform(**kwargs) + + # Get the augmented image, landmark, and mask + augmented_img = transformed['image'] + augmented_landmark = transformed.get('keypoints') + augmented_mask = transformed.get('mask',mask) + + # Convert the augmented landmark to a numpy array + if augmented_landmark is not None: + augmented_landmark = np.array(augmented_landmark) + + # Reset the seeds to ensure different transformations for different videos + if augmentation_seed is not None: + random.seed() + np.random.seed() + + return augmented_img, augmented_landmark, augmented_mask + + def __getitem__(self, index, no_norm=False): + """ + Returns the data point at the given index. + + Args: + index (int): The index of the data point. + + Returns: + A tuple containing the image tensor, the label tensor, the landmark tensor, + and the mask tensor. + """ + # Get the image paths and label + image_paths = self.data_dict['image'][index] + label = self.data_dict['label'][index] + + # Image-level: FaceForensics++\manipulated_sequences\NeuralTextures\c23\frames\487_477\000.png + # Video-level: image_paths ['FaceForensics++\\original_sequences\\youtube\\c23\\frames\\977\\000.png', ..., 'FaceForensics++\\original_sequences\\youtube\\c23\\frames\\977\\314.png'] + if not isinstance(image_paths, list): + image_paths = [image_paths] # for the image-level IO, only one frame is used + + image_tensors = [] + landmark_tensors = [] + mask_tensors = [] + augmentation_seed = None + + for image_path in image_paths: + # Initialize a new seed for data augmentation at the start of each video + if self.video_level and image_path == image_paths[0]: + augmentation_seed = random.randint(0, 2**32 - 1) + + # Get the mask and landmark paths + mask_path = image_path.replace('frames', 'masks') # Use .png for mask + landmark_path = image_path.replace('frames', 'landmarks').replace('.png', '.npy') # Use .npy for landmark + + # Load the image + try: + image = self.load_rgb(image_path) + except Exception as e: + # Skip this image and return the first one + print(f"Error loading image at index {index}: {e}") + return self.__getitem__(0) + image = np.array(image) # Convert to numpy array for data augmentation + + # Load mask and landmark (if needed) + if self.config['with_mask']: + mask = self.load_mask(mask_path) + else: + mask = None + if self.config['with_landmark']: + landmarks = self.load_landmark(landmark_path) + else: + landmarks = None + + # Do Data Augmentation + if self.mode == 'train' and self.config['use_data_augmentation']: + image_trans, landmarks_trans, mask_trans = self.data_aug(image, landmarks, mask, augmentation_seed) + else: + # if self.mode == 'train': + # print("Train w/o data_augmentation") + image_trans, landmarks_trans, mask_trans = deepcopy(image), deepcopy(landmarks), deepcopy(mask) + + + # To tensor and normalize + if not no_norm: + image_trans = self.normalize(self.to_tensor(image_trans)) + if self.config['with_landmark']: + landmarks_trans = torch.from_numpy(landmarks) + if self.config['with_mask']: + mask_trans = torch.from_numpy(mask_trans) + + image_tensors.append(image_trans) + landmark_tensors.append(landmarks_trans) + mask_tensors.append(mask_trans) + + if self.video_level: + # Stack image tensors along a new dimension (time) + image_tensors = torch.stack(image_tensors, dim=0) + # Stack landmark and mask tensors along a new dimension (time) + if not any(landmark is None or (isinstance(landmark, list) and None in landmark) for landmark in landmark_tensors): + landmark_tensors = torch.stack(landmark_tensors, dim=0) + if not any(m is None or (isinstance(m, list) and None in m) for m in mask_tensors): + mask_tensors = torch.stack(mask_tensors, dim=0) + else: + # Get the first image tensor + image_tensors = image_tensors[0] + # Get the first landmark and mask tensors + if not any(landmark is None or (isinstance(landmark, list) and None in landmark) for landmark in landmark_tensors): + landmark_tensors = landmark_tensors[0] + if not any(m is None or (isinstance(m, list) and None in m) for m in mask_tensors): + mask_tensors = mask_tensors[0] + + return image_tensors, label, landmark_tensors, mask_tensors + + @staticmethod + def collate_fn(batch): + """ + Collate a batch of data points. + + Args: + batch (list): A list of tuples containing the image tensor, the label tensor, + the landmark tensor, and the mask tensor. + + Returns: + A tuple containing the image tensor, the label tensor, the landmark tensor, + and the mask tensor. + """ + # Separate the image, label, landmark, and mask tensors + images, labels, landmarks, masks = zip(*batch) + + # Stack the image, label, landmark, and mask tensors + images = torch.stack(images, dim=0) + labels = torch.LongTensor(labels) + + # Special case for landmarks and masks if they are None + if not any(landmark is None or (isinstance(landmark, list) and None in landmark) for landmark in landmarks): + landmarks = torch.stack(landmarks, dim=0) + else: + landmarks = None + + if not any(m is None or (isinstance(m, list) and None in m) for m in masks): + masks = torch.stack(masks, dim=0) + else: + masks = None + + # Create a dictionary of the tensors + data_dict = {} + data_dict['image'] = images + data_dict['label'] = labels + data_dict['landmark'] = landmarks + data_dict['mask'] = masks + return data_dict + + def __len__(self): + """ + Return the length of the dataset. + + Args: + None. + + Returns: + An integer indicating the length of the dataset. + + Raises: + AssertionError: If the number of images and labels in the dataset are not equal. + """ + assert len(self.image_list) == len(self.label_list), 'Number of images and labels are not equal' + return len(self.image_list) + + +if __name__ == "__main__": + with open('/data/home/zhiyuanyan/DeepfakeBench/training/config/detector/video_baseline.yaml', 'r') as f: + config = yaml.safe_load(f) + train_set = DeepfakeAbstractBaseDataset( + config = config, + mode = 'train', + ) + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=config['train_batchSize'], + shuffle=True, + num_workers=0, + collate_fn=train_set.collate_fn, + ) + from tqdm import tqdm + for iteration, batch in enumerate(tqdm(train_data_loader)): + # print(iteration) + ... + # if iteration > 10: + # break diff --git a/training/dataset/albu.py b/training/dataset/albu.py new file mode 100644 index 0000000000000000000000000000000000000000..2bd8f3aaa5a91c83893180601094505f672769f4 --- /dev/null +++ b/training/dataset/albu.py @@ -0,0 +1,99 @@ +import random + +import cv2 +import numpy as np +from albumentations import DualTransform, ImageOnlyTransform +from albumentations.augmentations.crops.functional import crop + + +def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC): + h, w = img.shape[:2] + if max(w, h) == size: + return img + if w > h: + scale = size / w + h = h * scale + w = size + else: + scale = size / h + w = w * scale + h = size + interpolation = interpolation_up if scale > 1 else interpolation_down + resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation) + return resized + + +class IsotropicResize(DualTransform): + def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, + always_apply=False, p=1): + super(IsotropicResize, self).__init__(always_apply, p) + self.max_side = max_side + self.interpolation_down = interpolation_down + self.interpolation_up = interpolation_up + + def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params): + return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down, + interpolation_up=interpolation_up) + + def apply_to_mask(self, img, **params): + return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params) + + def get_transform_init_args_names(self): + return ("max_side", "interpolation_down", "interpolation_up") + + +class Resize4xAndBack(ImageOnlyTransform): + def __init__(self, always_apply=False, p=0.5): + super(Resize4xAndBack, self).__init__(always_apply, p) + + def apply(self, img, **params): + h, w = img.shape[:2] + scale = random.choice([2, 4]) + img = cv2.resize(img, (w // scale, h // scale), interpolation=cv2.INTER_AREA) + img = cv2.resize(img, (w, h), + interpolation=random.choice([cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST])) + return img + + +class RandomSizedCropNonEmptyMaskIfExists(DualTransform): + + def __init__(self, min_max_height, w2h_ratio=[0.7, 1.3], always_apply=False, p=0.5): + super(RandomSizedCropNonEmptyMaskIfExists, self).__init__(always_apply, p) + + self.min_max_height = min_max_height + self.w2h_ratio = w2h_ratio + + def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params): + cropped = crop(img, x_min, y_min, x_max, y_max) + return cropped + + @property + def targets_as_params(self): + return ["mask"] + + def get_params_dependent_on_targets(self, params): + mask = params["mask"] + mask_height, mask_width = mask.shape[:2] + crop_height = int(mask_height * random.uniform(self.min_max_height[0], self.min_max_height[1])) + w2h_ratio = random.uniform(*self.w2h_ratio) + crop_width = min(int(crop_height * w2h_ratio), mask_width - 1) + if mask.sum() == 0: + x_min = random.randint(0, mask_width - crop_width + 1) + y_min = random.randint(0, mask_height - crop_height + 1) + else: + mask = mask.sum(axis=-1) if mask.ndim == 3 else mask + non_zero_yx = np.argwhere(mask) + y, x = random.choice(non_zero_yx) + x_min = x - random.randint(0, crop_width - 1) + y_min = y - random.randint(0, crop_height - 1) + x_min = np.clip(x_min, 0, mask_width - crop_width) + y_min = np.clip(y_min, 0, mask_height - crop_height) + + x_max = x_min + crop_height + y_max = y_min + crop_width + y_max = min(mask_height, y_max) + x_max = min(mask_width, x_max) + return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max} + + def get_transform_init_args_names(self): + return "min_max_height", "height", "width", "w2h_ratio" \ No newline at end of file diff --git a/training/dataset/face_utils.py b/training/dataset/face_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..27ae0f48f93b1453950ffdcdef54b909c5c0314f --- /dev/null +++ b/training/dataset/face_utils.py @@ -0,0 +1,238 @@ +import cv2 +import numpy as np +from skimage import transform as trans +# from mtcnn.mtcnn import MTCNN + + +def get_keypts(face): + # get key points from the results of mtcnn + + if len(face['keypoints']) == 0: + return [] + + leye = np.array(face['keypoints']['left_eye'], dtype=np.int).reshape(-1, 2) + reye = np.array(face['keypoints']['right_eye'], + dtype=np.int).reshape(-1, 2) + nose = np.array(face['keypoints']['nose'], dtype=np.int).reshape(-1, 2) + lmouth = np.array(face['keypoints']['mouth_left'], + dtype=np.int).reshape(-1, 2) + rmouth = np.array(face['keypoints']['mouth_right'], + dtype=np.int).reshape(-1, 2) + + pts = np.concatenate([leye, reye, nose, lmouth, rmouth], axis=0) + + return pts + + +def img_align_crop(img, landmark=None, outsize=None, scale=1.3, mask=None): + """ align and crop the face according to the given bbox and landmarks + landmark: 5 key points + """ + + M = None + + target_size = [112, 112] + + dst = np.array([ + [30.2946, 51.6963], + [65.5318, 51.5014], + [48.0252, 71.7366], + [33.5493, 92.3655], + [62.7299, 92.2041]], dtype=np.float32) + + if target_size[1] == 112: + dst[:, 0] += 8.0 + + dst[:, 0] = dst[:, 0] * outsize[0] / target_size[0] + dst[:, 1] = dst[:, 1] * outsize[1] / target_size[1] + + target_size = outsize + + margin_rate = scale - 1 + x_margin = target_size[0] * margin_rate / 2. + y_margin = target_size[1] * margin_rate / 2. + + # move + dst[:, 0] += x_margin + dst[:, 1] += y_margin + + # resize + dst[:, 0] *= target_size[0] / (target_size[0] + 2 * x_margin) + dst[:, 1] *= target_size[1] / (target_size[1] + 2 * y_margin) + + src = landmark.astype(np.float32) + + # use skimage tranformation + tform = trans.SimilarityTransform() + tform.estimate(src, dst) + M = tform.params[0:2, :] + + # M: use opencv + # M = cv2.getAffineTransform(src[[0,1,2],:],dst[[0,1,2],:]) + + img = cv2.warpAffine(img, M, (target_size[1], target_size[0])) + + if outsize is not None: + img = cv2.resize(img, (outsize[1], outsize[0])) + + if mask is not None: + mask = cv2.warpAffine(mask, M, (target_size[1], target_size[0])) + mask = cv2.resize(mask, (outsize[1], outsize[0])) + return img, mask + else: + return img + + + + + +def expand_bbox(bbox, width, height, scale=1.3, minsize=None): + """ + Expand original boundingbox by scale. + :param bbx: original boundingbox + :param width: frame width + :param height: frame height + :param scale: bounding box size multiplier to get a bigger face region + :param minsize: set minimum bounding box size + :return: expanded bbox + """ + x, y, w, h = bbox + + # box center + cx = int(x + w / 2) + cy = int(y + h / 2) + + # expand by scale factor + new_size = max(int(w * scale), int(h * scale)) + new_x = max(0, int(cx - new_size / 2)) + new_y = max(0, int(cy - new_size / 2)) + + # Check for too big bbox for given x, y + new_size = min(width - new_x, new_size) + new_size = min(height - new_size, new_size) + + return new_x, new_y, new_size, new_size + + +def extract_face_MTCNN(face_detector, image, expand_scale=1.3, res=256): + # Image size + height, width = image.shape[:2] + + # Convert to rgb + rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Detect with dlib + faces = face_detector.detect_faces(rgb) + if len(faces): + # For now only take biggest face + face = None + bbox = None + max_region = 0 + for ff in faces: + if max_region == 0: + face = ff + bbox = face['box'] + max_region = bbox[2]*bbox[3] + else: + bb = ff['box'] + region = bb[2]*bb[3] + if region > max_rigion: + max_rigion = region + face = ff + bbox = face['box'] + print(max_region) + #face = faces[0] + + #bbox = face['box'] + + # --- Prediction --------------------------------------------------- + # Face crop with MTCNN and bounding box scale enlargement + x, y, w, h = expand_bbox(bbox, width, height, scale=expand_scale) + cropped_face = rgb[y:y+h, x:x+w] + + cropped_face = cv2.resize( + cropped_face, (res, res), interpolation=cv2.INTER_CUBIC) + cropped_face = cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR) + return cropped_face + + return None + + +def extract_aligned_face_MTCNN(face_detector, image, expand_scale=1.3, res=256, mask=None): + # Image size + height, width = image.shape[:2] + + # Convert to rgb + rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Detect with dlib + faces = face_detector.detect_faces(rgb) + if len(faces): + # For now only take biggest face + face = None + bbox = None + max_region = 0 + for i, ff in enumerate(faces): + if max_region == 0: + face = ff + bbox = face['box'] + max_region = bbox[2]*bbox[3] + else: + bb = ff['box'] + region = bb[2]*bb[3] + if region > max_region: + max_region = region + face = ff + bbox = face['box'] + #print('face {}: {}'.format(i, max_region)) + #face = faces[0] + + landmarks = get_keypts(face) + + # --- Prediction --------------------------------------------------- + # Face aligned crop with MTCNN and bounding box scale enlargement + if mask is not None: + cropped_face, cropped_mask = img_align_crop(rgb, landmarks, outsize=[ + res, res], scale=expand_scale, mask=mask) + cropped_face = cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR) + cropped_mask = cv2.cvtColor(cropped_mask, cv2.COLOR_RGB2GRAY) + return cropped_face, cropped_mask + else: + cropped_face = img_align_crop(rgb, landmarks, outsize=[ + res, res], scale=expand_scale) + cropped_face = cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR) + return cropped_face + + return None + + +def extract_face_DLIB(face_detector, image, expand_scale=1.3, res=256): + # Image size + height, width = image.shape[:2] + + # Convert to gray + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + + # Detect with dlib + faces = face_detector(gray, 1) + if len(faces): + # For now only take biggest face + face = faces[0] + + x1 = face.left() + y1 = face.top() + x2 = face.right() + y2 = face.bottom() + bbox = (x1, y1, x2-x1, y2-y1) + + # --- Prediction --------------------------------------------------- + # Face crop with dlib and bounding box scale enlargement + x, y, w, h = expand_bbox(bbox, width, height, scale=expand_scale) + cropped_face = image[y:y+h, x:x+w] + + cropped_face = cv2.resize( + cropped_face, (res, res), interpolation=cv2.INTER_CUBIC) + + return cropped_face + + return None diff --git a/training/dataset/ff_blend.py b/training/dataset/ff_blend.py new file mode 100644 index 0000000000000000000000000000000000000000..c2b036a581725840764c1143175f26bbeb311083 --- /dev/null +++ b/training/dataset/ff_blend.py @@ -0,0 +1,572 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-03-30 + +The code is designed for Face X-ray. +''' + +import os +import sys +import json +import pickle +import time + +import lmdb +import numpy as np +import albumentations as A +import cv2 +import random +from PIL import Image +from skimage.util import random_noise +from scipy import linalg +import heapq as hq +import lmdb +import torch +from torch.autograd import Variable +from torch.utils import data +from torchvision import transforms as T +import torchvision + +from dataset.utils.face_blend import * +from dataset.utils.face_align import get_align_mat_new +from dataset.utils.color_transfer import color_transfer +from dataset.utils.faceswap_utils import blendImages as alpha_blend_fea +from dataset.utils.faceswap_utils import AlphaBlend as alpha_blend +from dataset.utils.face_aug import aug_one_im, change_res +from dataset.utils.image_ae import get_pretraiend_ae +from dataset.utils.warp import warp_mask +from dataset.utils import faceswap +from scipy.ndimage.filters import gaussian_filter + + +class RandomDownScale(A.core.transforms_interface.ImageOnlyTransform): + def apply(self,img,**params): + return self.randomdownscale(img) + + def randomdownscale(self,img): + keep_ratio=True + keep_input_shape=True + H,W,C=img.shape + ratio_list=[2,4] + r=ratio_list[np.random.randint(len(ratio_list))] + img_ds=cv2.resize(img,(int(W/r),int(H/r)),interpolation=cv2.INTER_NEAREST) + if keep_input_shape: + img_ds=cv2.resize(img_ds,(W,H),interpolation=cv2.INTER_LINEAR) + return img_ds + + +class FFBlendDataset(data.Dataset): + def __init__(self, config=None): + + self.lmdb = config.get('lmdb', False) + if self.lmdb: + lmdb_path = os.path.join(config['lmdb_dir'], f"FaceForensics++_lmdb") + self.env = lmdb.open(lmdb_path, create=False, subdir=True, readonly=True, lock=False) + + # Check if the dictionary has already been created + if os.path.exists('training/lib/nearest_face_info.pkl'): + with open('training/lib/nearest_face_info.pkl', 'rb') as f: + face_info = pickle.load(f) + else: + raise ValueError(f"Need to run the dataset/generate_xray_nearest.py before training the face xray.") + self.face_info = face_info + # Check if the dictionary has already been created + if os.path.exists('training/lib/landmark_dict_ffall.pkl'): + with open('training/lib/landmark_dict_ffall.pkl', 'rb') as f: + landmark_dict = pickle.load(f) + self.landmark_dict = landmark_dict + self.imid_list = self.get_training_imglist() + self.transforms = T.Compose([ + # T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), + # T.ColorJitter(hue=.05, saturation=.05), + # T.RandomHorizontalFlip(), + # T.RandomRotation(20, resample=Image.BILINEAR), + T.ToTensor(), + T.Normalize(mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5]) + ]) + self.data_dict = { + 'imid_list': self.imid_list + } + self.config=config + # def data_aug(self, im): + # """ + # Apply data augmentation on the input image. + # """ + # transform = T.Compose([ + # T.ToPILImage(), + # T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), + # T.ColorJitter(hue=.05, saturation=.05), + # ]) + # # Apply transformations + # im_aug = transform(im) + # return im_aug + + def blended_aug(self, im): + transform = A.Compose([ + A.RGBShift((-20,20),(-20,20),(-20,20),p=0.3), + A.HueSaturationValue(hue_shift_limit=(-0.3,0.3), sat_shift_limit=(-0.3,0.3), val_shift_limit=(-0.3,0.3), p=0.3), + A.RandomBrightnessContrast(brightness_limit=(-0.3,0.3), contrast_limit=(-0.3,0.3), p=0.3), + A.ImageCompression(quality_lower=40, quality_upper=100,p=0.5) + ]) + # Apply transformations + im_aug = transform(image=im) + return im_aug['image'] + + + def data_aug(self, im): + """ + Apply data augmentation on the input image using albumentations. + """ + transform = A.Compose([ + A.Compose([ + A.RGBShift((-20,20),(-20,20),(-20,20),p=0.3), + A.HueSaturationValue(hue_shift_limit=(-0.3,0.3), sat_shift_limit=(-0.3,0.3), val_shift_limit=(-0.3,0.3), p=1), + A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1,0.1), p=1), + ],p=1), + A.OneOf([ + RandomDownScale(p=1), + A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1), + ],p=1), + ], p=1.) + # Apply transformations + im_aug = transform(image=im) + return im_aug['image'] + + + def get_training_imglist(self): + """ + Get the list of training images. + """ + random.seed(1024) # Fix the random seed for reproducibility + imid_list = list(self.landmark_dict.keys()) + # imid_list = [imid.replace('landmarks', 'frames').replace('npy', 'png') for imid in imid_list] + random.shuffle(imid_list) + return imid_list + + def load_rgb(self, file_path): + """ + Load an RGB image from a file path and resize it to a specified resolution. + + Args: + file_path: A string indicating the path to the image file. + + Returns: + An Image object containing the loaded and resized image. + + Raises: + ValueError: If the loaded image is None. + """ + size = self.config['resolution'] # if self.mode == "train" else self.config['resolution'] + if not self.lmdb: + if not file_path[0] == '.': + file_path = f'./{self.config["rgb_dir"]}\\'+file_path + assert os.path.exists(file_path), f"{file_path} does not exist" + img = cv2.imread(file_path) + if img is None: + raise ValueError('Loaded image is None: {}'.format(file_path)) + elif self.lmdb: + with self.env.begin(write=False) as txn: + # transfer the path format from rgb-path to lmdb-key + if file_path[0]=='.': + file_path=file_path.replace('./datasets\\','') + + image_bin = txn.get(file_path.encode()) + image_buf = np.frombuffer(image_bin, dtype=np.uint8) + img = cv2.imdecode(image_buf, cv2.IMREAD_COLOR) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) + return np.array(img, dtype=np.uint8) + + + def load_mask(self, file_path): + """ + Load a binary mask image from a file path and resize it to a specified resolution. + + Args: + file_path: A string indicating the path to the mask file. + + Returns: + A numpy array containing the loaded and resized mask. + + Raises: + None. + """ + size = self.config['resolution'] + if file_path is None: + if not file_path[0] == '.': + file_path = f'./{self.config["rgb_dir"]}\\'+file_path + return np.zeros((size, size, 1)) + if not self.lmdb: + if os.path.exists(file_path): + mask = cv2.imread(file_path, 0) + if mask is None: + mask = np.zeros((size, size)) + else: + return np.zeros((size, size, 1)) + else: + with self.env.begin(write=False) as txn: + # transfer the path format from rgb-path to lmdb-key + if file_path[0]=='.': + file_path=file_path.replace('./datasets\\','') + image_bin = txn.get(file_path.encode()) + image_buf = np.frombuffer(image_bin, dtype=np.uint8) + # cv2.IMREAD_GRAYSCALE reads grayscale images, while cv2.IMREAD_COLOR reads color images + mask = cv2.imdecode(image_buf, cv2.IMREAD_COLOR) + mask = cv2.resize(mask, (size, size)) / 255 + mask = np.expand_dims(mask, axis=2) + return np.float32(mask) + + def load_landmark(self, file_path): + """ + Load 2D facial landmarks from a file path. + + Args: + file_path: A string indicating the path to the landmark file. + + Returns: + A numpy array containing the loaded landmarks. + + Raises: + None. + """ + if file_path is None: + return np.zeros((81, 2)) + if not self.lmdb: + if not file_path[0] == '.': + file_path = f'./{self.config["rgb_dir"]}\\'+file_path + if os.path.exists(file_path): + landmark = np.load(file_path) + else: + return np.zeros((81, 2)) + else: + with self.env.begin(write=False) as txn: + # transfer the path format from rgb-path to lmdb-key + if file_path[0]=='.': + file_path=file_path.replace('./datasets\\','') + binary = txn.get(file_path.encode()) + landmark = np.frombuffer(binary, dtype=np.uint32).reshape((81, 2)) + return np.float32(landmark) + + def preprocess_images(self, imid_fg, imid_bg): + """ + Load foreground and background images and face shapes. + """ + fg_im = self.load_rgb(imid_fg.replace('landmarks', 'frames').replace('npy', 'png')) + fg_im = np.array(self.data_aug(fg_im)) + fg_shape = self.landmark_dict[imid_fg] + fg_shape = np.array(fg_shape, dtype=np.int32) + + bg_im = self.load_rgb(imid_bg.replace('landmarks', 'frames').replace('npy', 'png')) + bg_im = np.array(self.data_aug(bg_im)) + bg_shape = self.landmark_dict[imid_bg] + bg_shape = np.array(bg_shape, dtype=np.int32) + + if fg_im is None: + return bg_im, bg_shape, bg_im, bg_shape + elif bg_im is None: + return fg_im, fg_shape, fg_im, fg_shape + + return fg_im, fg_shape, bg_im, bg_shape + + + def get_fg_bg(self, one_lmk_path): + """ + Get foreground and background paths + """ + bg_lmk_path = one_lmk_path + # Randomly pick one from the nearest neighbors for the foreground + if bg_lmk_path in self.face_info: + fg_lmk_path = random.choice(self.face_info[bg_lmk_path]) + else: + fg_lmk_path = bg_lmk_path + + return fg_lmk_path, bg_lmk_path + + + def generate_masks(self, fg_im, fg_shape, bg_im, bg_shape): + """ + Generate masks for foreground and background images. + """ + fg_mask = get_mask(fg_shape, fg_im, deform=False) + bg_mask = get_mask(bg_shape, bg_im, deform=True) + + # # Only do the postprocess for the background mask + bg_mask_postprocess = warp_mask(bg_mask, std=20) + return fg_mask, bg_mask_postprocess + + + def warp_images(self, fg_im, fg_shape, bg_im, bg_shape, fg_mask): + """ + Warp foreground face onto background image using affine or 3D warping. + """ + H, W, C = bg_im.shape + use_3d_warp = np.random.rand() < 0.5 + + if not use_3d_warp: + aff_param = np.array(get_align_mat_new(fg_shape, bg_shape)).reshape(2, 3) + warped_face = cv2.warpAffine(fg_im, aff_param, (W, H), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REFLECT) + fg_mask = cv2.warpAffine(fg_mask, aff_param, (W, H), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REFLECT) + fg_mask = fg_mask > 0 + else: + warped_face = faceswap.warp_image_3d(fg_im, np.array(fg_shape[:48]), np.array(bg_shape[:48]), (H, W)) + fg_mask = np.mean(warped_face, axis=2) > 0 + + return warped_face, fg_mask + + + def colorTransfer(self, src, dst, mask): + transferredDst = np.copy(dst) + maskIndices = np.where(mask != 0) + maskedSrc = src[maskIndices[0], maskIndices[1]].astype(np.float32) + maskedDst = dst[maskIndices[0], maskIndices[1]].astype(np.float32) + + # Compute means and standard deviations + meanSrc = np.mean(maskedSrc, axis=0) + stdSrc = np.std(maskedSrc, axis=0) + meanDst = np.mean(maskedDst, axis=0) + stdDst = np.std(maskedDst, axis=0) + + # Perform color transfer + maskedDst = (maskedDst - meanDst) * (stdSrc / stdDst) + meanSrc + maskedDst = np.clip(maskedDst, 0, 255) + + # Copy the entire background into transferredDst + transferredDst = np.copy(dst) + # Now apply color transfer only to the masked region + transferredDst[maskIndices[0], maskIndices[1]] = maskedDst.astype(np.uint8) + + return transferredDst + + + def blend_images(self, color_corrected_fg, bg_im, bg_mask, featherAmount=0.2): + """ + Blend foreground and background images together. + """ + # normalize the mask to have values between 0 and 1 + b_mask = bg_mask / 255. + + # Add an extra dimension and repeat the mask to match the number of channels in color_corrected_fg and bg_im + b_mask = np.repeat(b_mask[:, :, np.newaxis], 3, axis=2) + + # Compute the alpha blending + maskIndices = np.where(b_mask != 0) + maskPts = np.hstack((maskIndices[1][:, np.newaxis], maskIndices[0][:, np.newaxis])) + + # FIXME: deal with the bugs of empty maskpts + if maskPts.size == 0: + print(f"No non-zero values found in bg_mask for blending. Skipping this image.") + return color_corrected_fg # or handle this situation differently according to the needs + + faceSize = np.max(maskPts, axis=0) - np.min(maskPts, axis=0) + featherAmount = featherAmount * np.max(faceSize) + + hull = cv2.convexHull(maskPts) + dists = np.zeros(maskPts.shape[0]) + for i in range(maskPts.shape[0]): + dists[i] = cv2.pointPolygonTest(hull, (int(maskPts[i, 0]), int(maskPts[i, 1])), True) + + weights = np.clip(dists / featherAmount, 0, 1) + + # Perform the blending operation + color_corrected_fg = color_corrected_fg.astype(float) + bg_im = bg_im.astype(float) + blended_image = np.copy(bg_im) + blended_image[maskIndices[0], maskIndices[1]] = weights[:, np.newaxis] * color_corrected_fg[maskIndices[0], maskIndices[1]] + (1 - weights[:, np.newaxis]) * bg_im[maskIndices[0], maskIndices[1]] + + # Convert the blended image to 8-bit unsigned integers + blended_image = np.clip(blended_image, 0, 255) + blended_image = blended_image.astype(np.uint8) + return blended_image + + + def process_images(self, imid_fg, imid_bg, index): + """ + Overview: + Process foreground and background images following the data generation pipeline (BI dataset). + + Terminology: + Foreground (fg) image: The image containing the face that will be blended onto the background image. + Background (bg) image: The image onto which the face from the foreground image will be blended. + """ + fg_im, fg_shape, bg_im, bg_shape = self.preprocess_images(imid_fg, imid_bg) + fg_mask, bg_mask = self.generate_masks(fg_im, fg_shape, bg_im, bg_shape) + warped_face, fg_mask = self.warp_images(fg_im, fg_shape, bg_im, bg_shape, fg_mask) + + try: + # add the below two lines to make sure the bg_mask is strictly within the fg_mask + bg_mask[fg_mask == 0] = 0 + color_corrected_fg = self.colorTransfer(bg_im, warped_face, bg_mask) + blended_image = self.blend_images(color_corrected_fg, bg_im, bg_mask) + # FIXME: ugly, in order to fix the problem of mask (all zero values for bg_mask) + except: + color_corrected_fg = self.colorTransfer(bg_im, warped_face, bg_mask) + blended_image = self.blend_images(color_corrected_fg, bg_im, bg_mask) + boundary = get_boundary(bg_mask) + + # # Prepare images and titles for the combined image + # images = [fg_im, np.where(fg_mask>0, 255, 0), bg_im, bg_mask, color_corrected_fg, blended_image, np.where(boundary>0, 255, 0)] + # titles = ["Fg Image", "Fg Mask", "Bg Image", + # "Bg Mask", "Blended Region", + # "Blended Image", "Boundary"] + + # # Save the combined image + # os.makedirs('facexray_examples_3', exist_ok=True) + # self.save_combined_image(images, titles, index, f'facexray_examples_3/combined_image_{index}.png') + return blended_image, boundary, bg_im + + + def post_proc(self, img): + ''' + if self.mode == 'train': + #if np.random.rand() < 0.5: + # img = random_add_noise(img) + #add_gaussian_noise(img) + if np.random.rand() < 0.5: + #img, _ = change_res(img) + img = gaussian_blur(img) + ''' + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + im_aug = self.blended_aug(img) + im_aug = Image.fromarray(np.uint8(img)) + im_aug = self.transforms(im_aug) + return im_aug + + + @staticmethod + def save_combined_image(images, titles, index, save_path): + """ + Save the combined image with titles for each single image. + + Args: + images (List[np.ndarray]): List of images to be combined. + titles (List[str]): List of titles for each image. + index (int): Index of the image. + save_path (str): Path to save the combined image. + """ + # Determine the maximum height and width among the images + max_height = max(image.shape[0] for image in images) + max_width = max(image.shape[1] for image in images) + + # Create the canvas + canvas = np.zeros((max_height * len(images), max_width, 3), dtype=np.uint8) + + # Place the images and titles on the canvas + current_height = 0 + for image, title in zip(images, titles): + height, width = image.shape[:2] + + # Check if image has a third dimension (color channels) + if image.ndim == 2: + # If not, add a third dimension + image = np.tile(image[..., None], (1, 1, 3)) + + canvas[current_height : current_height + height, :width] = image + cv2.putText( + canvas, title, (10, current_height + 30), + cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2 + ) + current_height += height + + # Save the combined image + cv2.imwrite(save_path, canvas) + + + def __getitem__(self, index): + """ + Get an item from the dataset by index. + """ + one_lmk_path = self.imid_list[index] + try: + label = 1 if one_lmk_path.split('/')[6]=='manipulated_sequences' else 0 + except Exception as e: + label = 1 if one_lmk_path.split('\\')[6] == 'manipulated_sequences' else 0 + imid_fg, imid_bg = self.get_fg_bg(one_lmk_path) + manipulate_img, boundary, imid_bg = self.process_images(imid_fg, imid_bg, index) + + manipulate_img = self.post_proc(manipulate_img) + imid_bg = self.post_proc(imid_bg) + boundary = torch.from_numpy(boundary) + boundary = boundary.unsqueeze(2).permute(2, 0, 1) + + # fake data + fake_data_tuple = (manipulate_img, boundary, 1) + # real data + real_data_tuple = (imid_bg, torch.zeros_like(boundary), label) + + return fake_data_tuple, real_data_tuple + + + @staticmethod + def collate_fn(batch): + """ + Collates batches of data and shuffles the images. + """ + # Unzip the batch + fake_data, real_data = zip(*batch) + + # Unzip the fake and real data + fake_images, fake_boundaries, fake_labels = zip(*fake_data) + real_images, real_boundaries, real_labels = zip(*real_data) + + # Combine fake and real data + images = torch.stack(fake_images + real_images) + boundaries = torch.stack(fake_boundaries + real_boundaries) + labels = torch.tensor(fake_labels + real_labels) + + # Combine images, boundaries, and labels into tuples + combined_data = list(zip(images, boundaries, labels)) + + # Shuffle the combined data + random.shuffle(combined_data) + + # Unzip the shuffled data + images, boundaries, labels = zip(*combined_data) + + # Create the data dictionary + data_dict = { + 'image': torch.stack(images), + 'label': torch.tensor(labels), + 'mask': torch.stack(boundaries), # Assuming boundaries are your masks + 'landmark': None # Add your landmark data if available + } + + return data_dict + + + def __len__(self): + """ + Get the length of the dataset. + """ + return len(self.imid_list) + + +if __name__ == "__main__": + dataset = FFBlendDataset() + print('dataset lenth: ', len(dataset)) + + def tensor2bgr(im): + img = im.squeeze().cpu().numpy().transpose(1, 2, 0) + img = (img + 1)/2 * 255 + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def tensor2gray(im): + img = im.squeeze().cpu().numpy() + img = img * 255 + return img + + for i, data_dict in enumerate(dataset): + if i > 20: + break + if label == 1: + if not use_mouth: + img, boudary = im + cv2.imwrite('{}_whole.png'.format(i), tensor2bgr(img)) + cv2.imwrite('{}_boudnary.png'.format(i), tensor2gray(boudary)) + else: + img, mouth, boudary = im + cv2.imwrite('{}_whole.png'.format(i), tensor2bgr(img)) + cv2.imwrite('{}_mouth.png'.format(i), tensor2bgr(mouth)) + cv2.imwrite('{}_boudnary.png'.format(i), tensor2gray(boudary)) diff --git a/training/dataset/fwa_blend.py b/training/dataset/fwa_blend.py new file mode 100644 index 0000000000000000000000000000000000000000..ad3c1df0bb65ccba4cd2d6aa7c31cc5c462c1efe --- /dev/null +++ b/training/dataset/fwa_blend.py @@ -0,0 +1,548 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-03-30 + +The code is designed for FWA and mainly modified from the below link: +https://github.com/yuezunli/DSP-FWA +''' + +import os +import sys +import json +import pickle +import time + +import dlib +import numpy as np +from copy import deepcopy +import cv2 +import random +from PIL import Image +from skimage.util import random_noise +from skimage.draw import polygon +from scipy import linalg +import heapq as hq +import albumentations as A + +import torch +from torch.autograd import Variable +from torch.utils import data +from torchvision import transforms as T +import torchvision + +from dataset.utils.face_blend import * +from dataset.utils.face_align import get_align_mat_new +from dataset.utils.color_transfer import color_transfer +from dataset.utils.faceswap_utils import blendImages as alpha_blend_fea +from dataset.utils.faceswap_utils import AlphaBlend as alpha_blend +from dataset.utils.face_aug import aug_one_im, change_res +from dataset.utils.image_ae import get_pretraiend_ae +from dataset.utils.warp import warp_mask +from dataset.utils import faceswap +from scipy.ndimage.filters import gaussian_filter +from skimage.transform import AffineTransform, warp + +from dataset.abstract_dataset import DeepfakeAbstractBaseDataset + + +# Define face detector and predictor models +face_detector = dlib.get_frontal_face_detector() +predictor_path = 'preprocessing/dlib_tools/shape_predictor_81_face_landmarks.dat' +face_predictor = dlib.shape_predictor(predictor_path) + + +mean_face_x = np.array([ + 0.000213256, 0.0752622, 0.18113, 0.29077, 0.393397, 0.586856, 0.689483, 0.799124, + 0.904991, 0.98004, 0.490127, 0.490127, 0.490127, 0.490127, 0.36688, 0.426036, + 0.490127, 0.554217, 0.613373, 0.121737, 0.187122, 0.265825, 0.334606, 0.260918, + 0.182743, 0.645647, 0.714428, 0.793132, 0.858516, 0.79751, 0.719335, 0.254149, + 0.340985, 0.428858, 0.490127, 0.551395, 0.639268, 0.726104, 0.642159, 0.556721, + 0.490127, 0.423532, 0.338094, 0.290379, 0.428096, 0.490127, 0.552157, 0.689874, + 0.553364, 0.490127, 0.42689]) + +mean_face_y = np.array([ + 0.106454, 0.038915, 0.0187482, 0.0344891, 0.0773906, 0.0773906, 0.0344891, + 0.0187482, 0.038915, 0.106454, 0.203352, 0.307009, 0.409805, 0.515625, 0.587326, + 0.609345, 0.628106, 0.609345, 0.587326, 0.216423, 0.178758, 0.179852, 0.231733, + 0.245099, 0.244077, 0.231733, 0.179852, 0.178758, 0.216423, 0.244077, 0.245099, + 0.780233, 0.745405, 0.727388, 0.742578, 0.727388, 0.745405, 0.780233, 0.864805, + 0.902192, 0.909281, 0.902192, 0.864805, 0.784792, 0.778746, 0.785343, 0.778746, + 0.784792, 0.824182, 0.831803, 0.824182]) + +landmarks_2D = np.stack([mean_face_x, mean_face_y], axis=1) + + +class RandomDownScale(A.core.transforms_interface.ImageOnlyTransform): + def apply(self,img,**params): + return self.randomdownscale(img) + + def randomdownscale(self,img): + keep_ratio=True + keep_input_shape=True + H,W,C=img.shape + ratio_list=[2,4] + r=ratio_list[np.random.randint(len(ratio_list))] + img_ds=cv2.resize(img,(int(W/r),int(H/r)),interpolation=cv2.INTER_NEAREST) + if keep_input_shape: + img_ds=cv2.resize(img_ds,(W,H),interpolation=cv2.INTER_LINEAR) + return img_ds + + +def umeyama( src, dst, estimate_scale ): + """Estimate N-D similarity transformation with or without scaling. + Parameters + ---------- + src : (M, N) array + Source coordinates. + dst : (M, N) array + Destination coordinates. + estimate_scale : bool + Whether to estimate scaling factor. + Returns + ------- + T : (N + 1, N + 1) + The homogeneous similarity transformation matrix. The matrix contains + NaN values only if the problem is not well-conditioned. + References + ---------- + .. [1] "Least-squares estimation of transformation parameters between two + point patterns", Shinji Umeyama, PAMI 1991, DOI: 10.1109/34.88573 + """ + + num = src.shape[0] + dim = src.shape[1] + + # Compute mean of src and dst. + src_mean = src.mean(axis=0) + dst_mean = dst.mean(axis=0) + + # Subtract mean from src and dst. + src_demean = src - src_mean + dst_demean = dst - dst_mean + + # Eq. (38). + A = np.dot(dst_demean.T, src_demean) / num + + # Eq. (39). + d = np.ones((dim,), dtype=np.double) + if np.linalg.det(A) < 0: + d[dim - 1] = -1 + + T = np.eye(dim + 1, dtype=np.double) + + U, S, V = np.linalg.svd(A) + + # Eq. (40) and (43). + rank = np.linalg.matrix_rank(A) + if rank == 0: + return np.nan * T + elif rank == dim - 1: + if np.linalg.det(U) * np.linalg.det(V) > 0: + T[:dim, :dim] = np.dot(U, V) + else: + s = d[dim - 1] + d[dim - 1] = -1 + T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V)) + d[dim - 1] = s + else: + T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V.T)) + + if estimate_scale: + # Eq. (41) and (42). + scale = 1.0 / src_demean.var(axis=0).sum() * np.dot(S, d) + else: + scale = 1.0 + + T[:dim, dim] = dst_mean - scale * np.dot(T[:dim, :dim], src_mean.T) + T[:dim, :dim] *= scale + + return T + + +def shape_to_np(shape, dtype="int"): + # initialize the list of (x, y)-coordinates + coords = np.zeros((68, 2), dtype=dtype) + + # loop over the 68 facial landmarks and convert them + # to a 2-tuple of (x, y)-coordinates + for i in range(0, 68): + coords[i] = (shape.part(i).x, shape.part(i).y) + + # return the list of (x, y)-coordinates + return coords + + +from skimage.transform import AffineTransform, warp + +def get_warped_face(face, landmarks, tform): + """ + Apply the given affine transformation to the face and landmarks. + + Args: + face (np.ndarray): The face image to be transformed. + landmarks (np.ndarray): The facial landmarks to be transformed. + tform (AffineTransform): The transformation to apply. + + Returns: + warped_face (np.ndarray): The transformed face image. + warped_landmarks (np.ndarray): The transformed facial landmarks. + """ + # Apply the transformation to the face + warped_face = warp(face, tform.inverse, output_shape=face.shape) + warped_face = (warped_face * 255).astype(np.uint8) + + # Apply the transformation to the landmarks + warped_landmarks = tform.inverse(landmarks) + + return warped_face, warped_landmarks + + +def warp_face_within_landmarks(face, landmarks, tform): + """ + Apply the given affine transformation to the face and landmarks, + and retain only the area within the landmarks. + + Args: + face (np.ndarray): The face image to be transformed. + landmarks (np.ndarray): The facial landmarks to be transformed. + tform (AffineTransform): The transformation to apply. + + Returns: + warped_face (np.ndarray): The transformed face image. + warped_landmarks (np.ndarray): The transformed facial landmarks. + """ + # Apply the transformation to the face + warped_face = warp(face, tform.inverse, output_shape=face.shape) + warped_face = (warped_face * 255).astype(np.uint8) + + # Apply the transformation to the landmarks + warped_landmarks = np.linalg.inv(landmarks) + + # Generate a mask based on the landmarks + rr, cc = polygon(warped_landmarks[:, 1], warped_landmarks[:, 0]) + mask = np.zeros_like(warped_face, dtype=np.uint8) + mask[rr, cc] = 1 + + # Apply the mask to the face + warped_face *= mask + + return warped_face, warped_landmarks + + +def get_2d_aligned_face(image, mat, size, padding=[0, 0]): + mat = mat * size + mat[0, 2] += padding[0] + mat[1, 2] += padding[1] + return cv2.warpAffine(image, mat, (size + 2 * padding[0], size + 2 * padding[1])) + + +def get_2d_aligned_landmarks(face_cache, aligned_face_size=256, padding=(0, 0)): + mat, points = face_cache + # Mapping landmarks to aligned face + pred_ = np.concatenate([points, np.ones((points.shape[0], 1))], axis=-1) + pred_ = np.transpose(pred_) + mat = mat * aligned_face_size + mat[0, 2] += padding[0] + mat[1, 2] += padding[1] + aligned_shape = np.dot(mat, pred_) + aligned_shape = np.transpose(aligned_shape[:2, :]) + return aligned_shape + + +def get_aligned_face_and_landmarks(im, face_cache, aligned_face_size = 256, padding=(0, 0)): + """ + get all aligned faces and landmarks of all images + :param imgs: origin images + :param fa: face_alignment package + :return: + """ + aligned_cur_shapes = [] + aligned_cur_im = [] + for mat, points in face_cache: + # Get transform matrix + aligned_face = get_2d_aligned_face(im, mat, aligned_face_size, padding) + aligned_shape = get_2d_aligned_landmarks([mat, points], aligned_face_size, padding) + aligned_cur_shapes.append(aligned_shape) + aligned_cur_im.append(aligned_face) + return aligned_cur_im, aligned_cur_shapes + + +def face_warp(im, face, trans_matrix, size, padding): + new_face = np.clip(face, 0, 255).astype(im.dtype) + image_size = im.shape[1], im.shape[0] + + tmp_matrix = trans_matrix * size + delta_matrix = np.array([[0., 0., padding[0]*1.0], [0., 0., padding[1]*1.0]]) + tmp_matrix = tmp_matrix + delta_matrix + + # Warp the new face onto a blank canvas + warped_face = np.zeros_like(im) + cv2.warpAffine(new_face, tmp_matrix, image_size, warped_face, cv2.WARP_INVERSE_MAP, + cv2.BORDER_TRANSPARENT) + + # Create a mask of the warped face + mask = (warped_face > 0).astype(np.uint8) + + # Blend the warped face with the original image + new_image = im * (1 - mask) + warped_face * mask + + return new_image, mask + + +def get_face_loc(im, face_detector, scale=0): + """ get face locations, color order of images is rgb """ + faces = face_detector(np.uint8(im), scale) + face_list = [] + if faces is not None or len(faces) > 0: + for i, d in enumerate(faces): + try: + face_list.append([d.left(), d.top(), d.right(), d.bottom()]) + except: + face_list.append([d.rect.left(), d.rect.top(), d.rect.right(), d.rect.bottom()]) + return face_list + + + +def align(im, face_detector, lmark_predictor, scale=0): + # This version we handle all faces in view + # channel order rgb + im = np.uint8(im) + faces = face_detector(im, scale) + face_list = [] + if faces is not None or len(faces) > 0: + for pred in faces: + try: + points = shape_to_np(lmark_predictor(im, pred)) + except: + points = shape_to_np(lmark_predictor(im, pred.rect)) + trans_matrix = umeyama(points[17:], landmarks_2D, True)[0:2] + face_list.append([trans_matrix, points]) + return face_list + + +class FWABlendDataset(DeepfakeAbstractBaseDataset): + def __init__(self, config=None): + super().__init__(config, mode='train') + self.transforms = T.Compose([ + T.ToTensor(), + T.Normalize(mean=config['mean'], + std=config['std']) + ]) + self.resolution = config['resolution'] + + + def blended_aug(self, im): + transform = A.Compose([ + A.RGBShift((-20,20),(-20,20),(-20,20),p=0.3), + A.HueSaturationValue(hue_shift_limit=(-0.3,0.3), sat_shift_limit=(-0.3,0.3), val_shift_limit=(-0.3,0.3), p=0.3), + A.RandomBrightnessContrast(brightness_limit=(-0.3,0.3), contrast_limit=(-0.3,0.3), p=0.3), + A.ImageCompression(quality_lower=40, quality_upper=100,p=0.5) + ]) + # Apply transformations + im_aug = transform(image=im) + return im_aug['image'] + + + def data_aug(self, im): + """ + Apply data augmentation on the input image using albumentations. + """ + transform = A.Compose([ + A.Compose([ + A.RGBShift((-20,20),(-20,20),(-20,20),p=0.3), + A.HueSaturationValue(hue_shift_limit=(-0.3,0.3), sat_shift_limit=(-0.3,0.3), val_shift_limit=(-0.3,0.3), p=1), + A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1,0.1), p=1), + ],p=1), + A.OneOf([ + RandomDownScale(p=1), + A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1), + ],p=1), + ], p=1.) + # Apply transformations + im_aug = transform(image=im) + return im_aug['image'] + + + def blend_images(self, img_path): + #im = cv2.imread(img_path) + im = np.array(self.load_rgb(img_path)) + + # Get the alignment of the head + face_cache = align(im, face_detector, face_predictor) + + # Get the aligned face and landmarks + aligned_im_head, aligned_shape = get_aligned_face_and_landmarks(im, face_cache) + # If no faces were detected in the image, return None (or any suitable value) + if len(aligned_im_head) == 0 or len(aligned_shape) == 0: + return None, None + aligned_im_head = aligned_im_head[0] + aligned_shape = aligned_shape[0] + + # Apply transformations to the face + scale_factor = random.choice([0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]) + scaled_face = cv2.resize(aligned_im_head, (0, 0), fx=scale_factor, fy=scale_factor) + + # Apply Gaussian blur to the scaled face + blurred_face = cv2.GaussianBlur(scaled_face, (5, 5), 0) + + # Resize the processed image back to the original size + resized_face = cv2.resize(blurred_face, (aligned_im_head.shape[1], aligned_im_head.shape[0])) + + # Generate a random facial mask + mask = get_mask(aligned_shape.astype(np.float32), resized_face, std=20, deform=True) + + # Apply the mask to the resized face + masked_face = cv2.bitwise_and(resized_face, resized_face, mask=mask) + + # do aug before warp + im = np.array(self.blended_aug(im)) + + # Warp the face back to the original image + im, masked_face = face_warp(im, masked_face, face_cache[0][0], self.resolution, [0, 0]) + shape = get_2d_aligned_landmarks(face_cache[0], self.resolution, [0, 0]) + return im, masked_face + + + def process_images(self, img_path, index): + """ + Process an image following the data generation pipeline. + """ + blended_im, mask = self.blend_images(img_path) + + # Prepare images and titles for the combined image + imid_fg = np.array(self.load_rgb(img_path)) + imid_fg = np.array(self.data_aug(imid_fg)) + + if blended_im is None or mask is None: + return imid_fg, None + + # images = [ + # imid_fg, + # np.where(mask.astype(np.uint8)>0, 255, 0), + # blended_im, + # ] + # titles = ["Image", "Mask", "Blended Image"] + + # # Save the combined image + # os.makedirs('fwa_examples_2', exist_ok=True) + # self.save_combined_image(images, titles, index, f'fwa_examples_2/combined_image_{index}.png') + return imid_fg, blended_im + + + def post_proc(self, img): + ''' + if self.mode == 'train': + #if np.random.rand() < 0.5: + # img = random_add_noise(img) + #add_gaussian_noise(img) + if np.random.rand() < 0.5: + #img, _ = change_res(img) + img = gaussian_blur(img) + ''' + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + im_aug = self.blended_aug(img) + im_aug = Image.fromarray(np.uint8(img)) + im_aug = self.transforms(im_aug) + return im_aug + + + @staticmethod + def save_combined_image(images, titles, index, save_path): + """ + Save the combined image with titles for each single image. + + Args: + images (List[np.ndarray]): List of images to be combined. + titles (List[str]): List of titles for each image. + index (int): Index of the image. + save_path (str): Path to save the combined image. + """ + # Determine the maximum height and width among the images + max_height = max(image.shape[0] for image in images) + max_width = max(image.shape[1] for image in images) + + # Create the canvas + canvas = np.zeros((max_height * len(images), max_width, 3), dtype=np.uint8) + + # Place the images and titles on the canvas + current_height = 0 + for image, title in zip(images, titles): + height, width = image.shape[:2] + + # Check if image has a third dimension (color channels) + if image.ndim == 2: + # If not, add a third dimension + image = np.tile(image[..., None], (1, 1, 3)) + + canvas[current_height : current_height + height, :width] = image + cv2.putText( + canvas, title, (10, current_height + 30), + cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2 + ) + current_height += height + + # Save the combined image + cv2.imwrite(save_path, canvas) + + + def __getitem__(self, index): + """ + Get an item from the dataset by index. + """ + one_img_path = self.data_dict['image'][index] + try: + label = 1 if one_img_path.split('/')[6]=='manipulated_sequences' else 0 + except Exception as e: + label = 1 if one_img_path.split('\\')[6] == 'manipulated_sequences' else 0 + blend_label = 1 + imid, manipulate_img = self.process_images(one_img_path, index) + + if manipulate_img is None: + manipulate_img = deepcopy(imid) + blend_label = label + manipulate_img = self.post_proc(manipulate_img) + imid = self.post_proc(imid) + + # blend data + fake_data_tuple = (manipulate_img, blend_label) + # original data + real_data_tuple = (imid, label) + + return fake_data_tuple, real_data_tuple + + + @staticmethod + def collate_fn(batch): + """ + Collates batches of data and shuffles the images. + """ + # Unzip the batch + fake_data, real_data = zip(*batch) + + # Unzip the fake and real data + fake_images, fake_labels = zip(*fake_data) + real_images, real_labels = zip(*real_data) + + # Combine fake and real data + images = torch.stack(fake_images + real_images) + labels = torch.tensor(fake_labels + real_labels) + + # Combine images, boundaries, and labels into tuples + combined_data = list(zip(images, labels)) + + # Shuffle the combined data + random.shuffle(combined_data) + + # Unzip the shuffled data + images, labels = zip(*combined_data) + + # Create the data dictionary + data_dict = { + 'image': torch.stack(images), + 'label': torch.tensor(labels), + 'mask': None, + 'landmark': None # Add your landmark data if available + } + + return data_dict diff --git a/training/dataset/generate_parsing_mask.py b/training/dataset/generate_parsing_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..0282dd764cd262e52fdd96fc2729bc8585efccea --- /dev/null +++ b/training/dataset/generate_parsing_mask.py @@ -0,0 +1,129 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2024-01-26 + +The code is designed for self-blending method (SBI, CVPR 2024). +''' + +import sys +sys.path.append('.') + +import os +import cv2 +import yaml +import random +import torch +import torch.nn as nn +from PIL import Image +import numpy as np +from copy import deepcopy +import albumentations as A +from training.dataset.abstract_dataset import DeepfakeAbstractBaseDataset +from training.dataset.sbi_api import SBI_API +from training.dataset.utils.bi_online_generation_yzy import random_get_hull +from training.dataset.SimSwap.test_one_image import self_blend + +import warnings +warnings.filterwarnings('ignore') + + +from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +image_processor = SegformerImageProcessor.from_pretrained("/Youtu_Pangu_Security_Public/youtu-pangu-public/zhiyuanyan/huggingface/hub/models--jonathandinu--face-parsing/snapshots/a2bf62f39dfd8f8856a3c19be8b0707a8d68abdd") +face_parser = SegformerForSemanticSegmentation.from_pretrained("/Youtu_Pangu_Security_Public/youtu-pangu-public/zhiyuanyan/huggingface/hub/models--jonathandinu--face-parsing/snapshots/a2bf62f39dfd8f8856a3c19be8b0707a8d68abdd").to(device) + + +def create_facial_mask(mask, with_neck=False): + facial_labels = [1, 2, 3, 4, 5, 6, 7, 10, 11, 12] + if with_neck: + facial_labels += [17] + facial_mask = np.zeros_like(mask, dtype=bool) + for label in facial_labels: + facial_mask |= (mask == label) + return facial_mask.astype(np.uint8) * 255 + + +def face_parsing_mask(img1, with_neck=False): + # run inference on image + img1 = Image.fromarray(img1) + inputs = image_processor(images=img1, return_tensors="pt").to(device) + outputs = face_parser(**inputs) + logits = outputs.logits # shape (batch_size, num_labels, ~height/4, ~width/4) + + # resize output to match input image dimensions + upsampled_logits = nn.functional.interpolate(logits, + size=img1.size[::-1], # H x W + mode='bilinear', + align_corners=False) + labels = upsampled_logits.argmax(dim=1)[0] + mask = labels.cpu().numpy() + mask = create_facial_mask(mask, with_neck) + return mask + + +class YZYDataset(DeepfakeAbstractBaseDataset): + def __init__(self, config=None, mode='train'): + super().__init__(config, mode) + + # Get real lists + # Fix the label of real images to be 0 + self.real_imglist = [(img, label) for img, label in zip(self.image_list, self.label_list) if label == 0] + + + def __getitem__(self, index): + # Get the real image paths and labels + real_image_path, real_label = self.real_imglist[index] + # real_image_path = real_image_path.replace('/Youtu_Pangu_Security_Public/', '/Youtu_Pangu_Security/public/') + + # Load the real images + real_image = self.load_rgb(real_image_path) + real_image = np.array(real_image) # Convert to numpy array + + # Face Parsing + mask = face_parsing_mask(real_image, with_neck=False) + parse_mask_path = real_image_path.replace('frames', 'parse_mask') + os.makedirs(os.path.dirname(parse_mask_path), exist_ok=True) + cv2.imwrite(parse_mask_path, mask) + + # # SRI generation + # sri_image = self_blend(real_image) + # sri_path = real_image_path.replace('frames', 'sri_frames') + # os.makedirs(os.path.dirname(sri_path), exist_ok=True) + # cv2.imwrite(sri_path, sri_image) + + @staticmethod + def collate_fn(batch): + data_dict = { + 'image': None, + 'label': None, + 'landmark': None, + 'mask': None, + } + return data_dict + + def __len__(self): + return len(self.real_imglist) + + + +if __name__ == '__main__': + with open('./training/config/detector/sbi.yaml', 'r') as f: + config = yaml.safe_load(f) + with open('./training/config/train_config.yaml', 'r') as f: + config2 = yaml.safe_load(f) + config2['data_manner'] = 'lmdb' + config['dataset_json_folder'] = '/Youtu_Pangu_Security_Public/youtu-pangu-public/zhiyuanyan/DeepfakeBenchv2/preprocessing/dataset_json' + config.update(config2) + train_set = YZYDataset(config=config, mode='train') + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=config['train_batchSize'], + shuffle=True, + num_workers=0, + collate_fn=train_set.collate_fn, + ) + from tqdm import tqdm + for iteration, batch in enumerate(tqdm(train_data_loader)): + print(iteration) \ No newline at end of file diff --git a/training/dataset/generate_xray_nearest.py b/training/dataset/generate_xray_nearest.py new file mode 100644 index 0000000000000000000000000000000000000000..cf946f1c37e5c5cfe03a712caa9891968b9582c9 --- /dev/null +++ b/training/dataset/generate_xray_nearest.py @@ -0,0 +1,136 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-03-30 + +The code is specifically designed for generating nearest sample pairs for Face X-ray. +Alternatively, you can utilize the pre-generated pkl files available in our GitHub repository. Please refer to the "Releases" section on our repository for accessing these files. +''' + +import os +import json +import pickle +import numpy as np +import heapq +import random +from tqdm import tqdm +from scipy.spatial import KDTree + + +def load_landmark(file_path): + """ + Load 2D facial landmarks from a file path. + + Args: + file_path: A string indicating the path to the landmark file. + + Returns: + A numpy array containing the loaded landmarks. + + Raises: + None. + """ + if file_path is None: + return np.zeros((81, 2)) + if os.path.exists(file_path): + landmark = np.load(file_path) + return np.float32(landmark) + else: + return np.zeros((81, 2)) + + +def get_landmark_dict(dataset_folder): + # Check if the dictionary has already been created + if os.path.exists('landmark_dict_ff.pkl'): + with open('landmark_dict_ff.pkl', 'rb') as f: + return pickle.load(f) + # Open the metadata file for the current folder + metadata_path = os.path.join(dataset_folder, "FaceForensics++.json") + with open(metadata_path, "r") as f: + metadata = json.load(f) + # Iterate over the metadata entries and add the landmark paths to the list + ff_real_data = metadata['FaceForensics++']['FF-real'] + # Using dictionary comprehension to generate the landmark_dict + landmark_dict = { + frame_path.replace('frames', 'landmarks').replace(".png", ".npy"): load_landmark( + frame_path.replace('frames', 'landmarks').replace(".png", ".npy") + ) + for mode, value in ff_real_data.items() + for video_name, video_info in tqdm(value['c23'].items()) + for frame_path in video_info['frames'] + } + # Save the dictionary to a pickle file + with open('landmark_dict_ffall.pkl', 'wb') as f: + pickle.dump(landmark_dict, f) + return landmark_dict + + +def get_nearest_faces_fixed_pair(landmark_info, num_neighbors): + ''' + Using KDTree to find the nearest faces for each image (Much faster!!) + ''' + random.seed(1024) # Fix the random seed for reproducibility + + # Check if the dictionary has already been created + if os.path.exists('nearest_face_info.pkl'): + with open('nearest_face_info.pkl', 'rb') as f: + return pickle.load(f) + + landmarks_array = np.array([lmk.flatten() for lmk in landmark_info.values()]) + landmark_ids = list(landmark_info.keys()) + + # Build a KDTree using the flattened landmarks + tree = KDTree(landmarks_array) + + nearest_faces = {} + for idx, this_lmk in tqdm(enumerate(landmarks_array), total=len(landmarks_array)): + # Query the KDTree for the nearest neighbors (excluding itself) + dists, indices = tree.query(this_lmk, k=num_neighbors + 1) + # Randomly pick one from the nearest N neighbors (excluding itself) + picked_idx = random.choice(indices[1:]) + nearest_faces[landmark_ids[idx]] = landmark_ids[picked_idx] + + # Save the dictionary to a pickle file + with open('nearest_face_info.pkl', 'wb') as f: + pickle.dump(nearest_faces, f) + + return nearest_faces + + +def get_nearest_faces(landmark_info, num_neighbors): + ''' + Using KDTree to find the nearest faces for each image (Much faster!!) + ''' + random.seed(1024) # Fix the random seed for reproducibility + + # Check if the dictionary has already been created + if os.path.exists('nearest_face_info.pkl'): + with open('nearest_face_info.pkl', 'rb') as f: + return pickle.load(f) + + landmarks_array = np.array([lmk.flatten() for lmk in landmark_info.values()]) + landmark_ids = list(landmark_info.keys()) + + # Build a KDTree using the flattened landmarks + tree = KDTree(landmarks_array) + + nearest_faces = {} + for idx, this_lmk in tqdm(enumerate(landmarks_array), total=len(landmarks_array)): + # Query the KDTree for the nearest neighbors (excluding itself) + dists, indices = tree.query(this_lmk, k=num_neighbors + 1) + # Store the nearest N neighbors (excluding itself) + nearest_faces[landmark_ids[idx]] = [landmark_ids[i] for i in indices[1:]] + + # Save the dictionary to a pickle file + with open('nearest_face_info.pkl', 'wb') as f: + pickle.dump(nearest_faces, f) + + return nearest_faces + +# Load the landmark dictionary and obtain the landmark dict +dataset_folder = "/home/zhiyuanyan/disfin/deepfake_benchmark/preprocessing/dataset_json/" +landmark_info = get_landmark_dict(dataset_folder) + +# Get the nearest faces for each image (in landmark_dict) +num_neighbors = 100 +nearest_faces_info = get_nearest_faces(landmark_info, num_neighbors) # running time: about 20 mins diff --git a/training/dataset/iid_dataset.py b/training/dataset/iid_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..299af7ee4a6074830e3591d9a53d357a5b07d434 --- /dev/null +++ b/training/dataset/iid_dataset.py @@ -0,0 +1,115 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-03-30 + +The code is designed for scenarios such as disentanglement-based methods where it is necessary to ensure an equal number of positive and negative samples. +''' +import os.path +from copy import deepcopy +import cv2 +import math +import torch +import random + +import yaml +from PIL import Image, ImageDraw +import numpy as np +from torch.utils.data import DataLoader + +from dataset.abstract_dataset import DeepfakeAbstractBaseDataset + +class IIDDataset(DeepfakeAbstractBaseDataset): + def __init__(self, config=None, mode='train'): + super().__init__(config, mode) + + + def __getitem__(self, index): + # Get the image paths and label + image_path = self.data_dict['image'][index] + if '\\' in image_path: + per = image_path.split('\\')[-2] + else: + per = image_path.split('/')[-2] + id_index = int(per.split('_')[-1]) # real video id + label = self.data_dict['label'][index] + + # Load the image + try: + image = self.load_rgb(image_path) + except Exception as e: + # Skip this image and return the first one + print(f"Error loading image at index {index}: {e}") + return self.__getitem__(0) + image = np.array(image) # Convert to numpy array for data augmentation + + # Do Data Augmentation + image_trans,_,_ = self.data_aug(image) + + # To tensor and normalize + image_trans = self.normalize(self.to_tensor(image_trans)) + + return id_index, image_trans, label + + @staticmethod + def collate_fn(batch): + """ + Collate a batch of data points. + + Args: + batch (list): A list of tuples containing the image tensor, the label tensor, + the landmark tensor, and the mask tensor. + + Returns: + A tuple containing the image tensor, the label tensor, the landmark tensor, + and the mask tensor. + """ + # Separate the image, label, landmark, and mask tensors + id_indexes, image_trans, label = zip(*batch) + + # Stack the image, label, landmark, and mask tensors + images = torch.stack(image_trans, dim=0) + labels = torch.LongTensor(label) + ids = torch.LongTensor(id_indexes) + # Create a dictionary of the tensors + data_dict = {} + data_dict['image'] = images + data_dict['label'] = labels + data_dict['id_index'] = ids + data_dict['mask']=None + data_dict['landmark']=None + return data_dict + + +def draw_landmark(img,landmark): + draw = ImageDraw.Draw(img) + + # landmark = np.stack([mean_face_x, mean_face_y], axis=1) + # landmark *=256 + + for i, point in enumerate(landmark): + + draw.ellipse((point[0] - 1, point[1] - 1, point[0] + 1, point[1] + 1), fill=(255, 0, 0)) + + draw.text((point[0], point[1]), str(i), fill=(255, 255, 255)) + return img + + +if __name__ == '__main__': + detector_path = r"./training/config/detector/xception.yaml" + # weights_path = "./ckpts/xception/CDFv2/tb_v1/ov.pth" + with open(detector_path, 'r') as f: + config = yaml.safe_load(f) + with open('./training/config/train_config.yaml', 'r') as f: + config2 = yaml.safe_load(f) + config2['data_manner'] = 'lmdb' + config['dataset_json_folder'] = 'preprocessing/dataset_json_v3' + config.update(config2) + dataset = IIDDataset(config=config) + batch_size = 2 + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,collate_fn=dataset.collate_fn) + + for i, batch in enumerate(dataloader): + print(f"Batch {i}: {batch}") + + img = batch['img'] diff --git a/training/dataset/library/DeepFakeMask.py b/training/dataset/library/DeepFakeMask.py new file mode 100644 index 0000000000000000000000000000000000000000..3ad16cab208910d64476753d82735d79b3571ee3 --- /dev/null +++ b/training/dataset/library/DeepFakeMask.py @@ -0,0 +1,181 @@ +#!/usr/bin/python +# -*- coding: UTF-8 -*- +# Created by: algohunt +# Microsoft Research & Peking University +# lilingzhi@pku.edu.cn +# Copyright (c) 2019 + +#!/usr/bin/env python3 +""" Masks functions for faceswap.py """ + +import inspect +import logging +import sys + +import cv2 +import numpy as np + +# logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def get_available_masks(): + """ Return a list of the available masks for cli """ + masks = sorted([name for name, obj in inspect.getmembers(sys.modules[__name__]) + if inspect.isclass(obj) and name != "Mask"]) + masks.append("none") + # logger.debug(masks) + return masks + + +def get_default_mask(): + """ Set the default mask for cli """ + masks = get_available_masks() + default = "dfl_full" + default = default if default in masks else masks[0] + # logger.debug(default) + return default + + +class Mask(): + """ Parent class for masks + the output mask will be .mask + channels: 1, 3 or 4: + 1 - Returns a single channel mask + 3 - Returns a 3 channel mask + 4 - Returns the original image with the mask in the alpha channel """ + + def __init__(self, landmarks, face, channels=4): + # logger.info("Initializing %s: (face_shape: %s, channels: %s, landmarks: %s)", + # self.__class__.__name__, face.shape, channels, landmarks) + self.landmarks = landmarks + self.face = face + self.channels = channels + + mask = self.build_mask() + self.mask = self.merge_mask(mask) + #logger.info("Initialized %s", self.__class__.__name__) + + def build_mask(self): + """ Override to build the mask """ + raise NotImplementedError + + def merge_mask(self, mask): + """ Return the mask in requested shape """ + #logger.info("mask_shape: %s", mask.shape) + assert self.channels in (1, 3, 4), "Channels should be 1, 3 or 4" + assert mask.shape[2] == 1 and mask.ndim == 3, "Input mask be 3 dimensions with 1 channel" + + if self.channels == 3: + retval = np.tile(mask, 3) + elif self.channels == 4: + retval = np.concatenate((self.face, mask), -1) + else: + retval = mask + + #logger.info("Final mask shape: %s", retval.shape) + return retval + + +class dfl_full(Mask): # pylint: disable=invalid-name + """ DFL facial mask """ + def build_mask(self): + mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32) + + nose_ridge = (self.landmarks[27:31], self.landmarks[33:34]) + jaw = (self.landmarks[0:17], + self.landmarks[48:68], + self.landmarks[0:1], + self.landmarks[8:9], + self.landmarks[16:17]) + eyes = (self.landmarks[17:27], + self.landmarks[0:1], + self.landmarks[27:28], + self.landmarks[16:17], + self.landmarks[33:34]) + parts = [jaw, nose_ridge, eyes] + + for item in parts: + merged = np.concatenate(item) + cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) # pylint: disable=no-member + return mask + + +class components(Mask): # pylint: disable=invalid-name + """ Component model mask """ + def build_mask(self): + mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32) + + r_jaw = (self.landmarks[0:9], self.landmarks[17:18]) + l_jaw = (self.landmarks[8:17], self.landmarks[26:27]) + r_cheek = (self.landmarks[17:20], self.landmarks[8:9]) + l_cheek = (self.landmarks[24:27], self.landmarks[8:9]) + nose_ridge = (self.landmarks[19:25], self.landmarks[8:9],) + r_eye = (self.landmarks[17:22], + self.landmarks[27:28], + self.landmarks[31:36], + self.landmarks[8:9]) + l_eye = (self.landmarks[22:27], + self.landmarks[27:28], + self.landmarks[31:36], + self.landmarks[8:9]) + nose = (self.landmarks[27:31], self.landmarks[31:36]) + parts = [r_jaw, l_jaw, r_cheek, l_cheek, nose_ridge, r_eye, l_eye, nose] + + for item in parts: + merged = np.concatenate(item) + cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) # pylint: disable=no-member + return mask + + +class extended(Mask): # pylint: disable=invalid-name + """ Extended mask + Based on components mask. Attempts to extend the eyebrow points up the forehead + """ + def build_mask(self): + mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32) + + landmarks = self.landmarks.copy() + # mid points between the side of face and eye point + ml_pnt = (landmarks[36] + landmarks[0]) // 2 + mr_pnt = (landmarks[16] + landmarks[45]) // 2 + + # mid points between the mid points and eye + ql_pnt = (landmarks[36] + ml_pnt) // 2 + qr_pnt = (landmarks[45] + mr_pnt) // 2 + + # Top of the eye arrays + bot_l = np.array((ql_pnt, landmarks[36], landmarks[37], landmarks[38], landmarks[39])) + bot_r = np.array((landmarks[42], landmarks[43], landmarks[44], landmarks[45], qr_pnt)) + + # Eyebrow arrays + top_l = landmarks[17:22] + top_r = landmarks[22:27] + + # Adjust eyebrow arrays + landmarks[17:22] = top_l + ((top_l - bot_l) // 2) + landmarks[22:27] = top_r + ((top_r - bot_r) // 2) + + r_jaw = (landmarks[0:9], landmarks[17:18]) + l_jaw = (landmarks[8:17], landmarks[26:27]) + r_cheek = (landmarks[17:20], landmarks[8:9]) + l_cheek = (landmarks[24:27], landmarks[8:9]) + nose_ridge = (landmarks[19:25], landmarks[8:9],) + r_eye = (landmarks[17:22], landmarks[27:28], landmarks[31:36], landmarks[8:9]) + l_eye = (landmarks[22:27], landmarks[27:28], landmarks[31:36], landmarks[8:9]) + nose = (landmarks[27:31], landmarks[31:36]) + parts = [r_jaw, l_jaw, r_cheek, l_cheek, nose_ridge, r_eye, l_eye, nose] + + for item in parts: + merged = np.concatenate(item) + cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) # pylint: disable=no-member + return mask + + +class facehull(Mask): # pylint: disable=invalid-name + """ Basic face hull mask """ + def build_mask(self): + mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32) + hull = cv2.convexHull( # pylint: disable=no-member + np.array(self.landmarks).reshape((-1, 2))) + cv2.fillConvexPoly(mask, hull, 255.0, lineType=cv2.LINE_AA) # pylint: disable=no-member + return mask \ No newline at end of file diff --git a/training/dataset/library/LICENSE b/training/dataset/library/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f288702d2fa16d3cdf0035b15a9fcbc552cd88e7 --- /dev/null +++ b/training/dataset/library/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. diff --git a/training/dataset/library/README.md b/training/dataset/library/README.md new file mode 100644 index 0000000000000000000000000000000000000000..28ec640b626e03f2daabfe2e9b95dac0c1720004 --- /dev/null +++ b/training/dataset/library/README.md @@ -0,0 +1,12 @@ +# Face-X-ray +The author's unofficial PyTorch re-implementation of Face Xray + +This repo contains code for the BI data generation pipeline from [Face X-ray for More General Face Forgery Detection](https://arxiv.org/abs/1912.13458) by Lingzhi Li, Jianmin Bao, Ting Zhang, Hao Yang, Dong Chen, Fang Wen, Baining Guo. + +# Usage + +Just run bi_online_generation.py and you can get the following result. which is describe at Figure.5 in the paper. + +![demo](all_in_one.jpg) + +To get the whole BI dataset, you will need crop all the face and compute the landmarks as describe in the code. \ No newline at end of file diff --git a/training/dataset/library/bi_online_generation.py b/training/dataset/library/bi_online_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..98ad61229ca907f4d9f2e867756efd9d1f5940da --- /dev/null +++ b/training/dataset/library/bi_online_generation.py @@ -0,0 +1,241 @@ +import dlib +from skimage import io +from skimage import transform as sktransform +import numpy as np +from matplotlib import pyplot as plt +import json +import os +import random +from PIL import Image +from imgaug import augmenters as iaa +from .DeepFakeMask import dfl_full,facehull,components,extended +import cv2 +import tqdm + +def name_resolve(path): + name = os.path.splitext(os.path.basename(path))[0] + vid_id, frame_id = name.split('_')[0:2] + return vid_id, frame_id + +def total_euclidean_distance(a,b): + assert len(a.shape) == 2 + return np.sum(np.linalg.norm(a-b,axis=1)) + +def random_get_hull(landmark,img1,hull_type): + if hull_type == 0: + mask = dfl_full(landmarks=landmark.astype('int32'),face=img1, channels=3).mask + return mask/255 + elif hull_type == 1: + mask = extended(landmarks=landmark.astype('int32'),face=img1, channels=3).mask + return mask/255 + elif hull_type == 2: + mask = components(landmarks=landmark.astype('int32'),face=img1, channels=3).mask + return mask/255 + elif hull_type == 3: + mask = facehull(landmarks=landmark.astype('int32'),face=img1, channels=3).mask + return mask/255 + +def random_erode_dilate(mask, ksize=None): + if random.random()>0.5: + if ksize is None: + ksize = random.randint(1,21) + if ksize % 2 == 0: + ksize += 1 + mask = np.array(mask).astype(np.uint8)*255 + kernel = np.ones((ksize,ksize),np.uint8) + mask = cv2.erode(mask,kernel,1)/255 + else: + if ksize is None: + ksize = random.randint(1,5) + if ksize % 2 == 0: + ksize += 1 + mask = np.array(mask).astype(np.uint8)*255 + kernel = np.ones((ksize,ksize),np.uint8) + mask = cv2.dilate(mask,kernel,1)/255 + return mask + + +# borrow from https://github.com/MarekKowalski/FaceSwap +def blendImages(src, dst, mask, featherAmount=0.2): + + maskIndices = np.where(mask != 0) + + src_mask = np.ones_like(mask) + dst_mask = np.zeros_like(mask) + + maskPts = np.hstack((maskIndices[1][:, np.newaxis], maskIndices[0][:, np.newaxis])) + faceSize = np.max(maskPts, axis=0) - np.min(maskPts, axis=0) + featherAmount = featherAmount * np.max(faceSize) + + hull = cv2.convexHull(maskPts) + dists = np.zeros(maskPts.shape[0]) + for i in range(maskPts.shape[0]): + dists[i] = cv2.pointPolygonTest(hull, (maskPts[i, 0], maskPts[i, 1]), True) + + weights = np.clip(dists / featherAmount, 0, 1) + + composedImg = np.copy(dst) + composedImg[maskIndices[0], maskIndices[1]] = weights[:, np.newaxis] * src[maskIndices[0], maskIndices[1]] + (1 - weights[:, np.newaxis]) * dst[maskIndices[0], maskIndices[1]] + + composedMask = np.copy(dst_mask) + composedMask[maskIndices[0], maskIndices[1]] = weights[:, np.newaxis] * src_mask[maskIndices[0], maskIndices[1]] + ( + 1 - weights[:, np.newaxis]) * dst_mask[maskIndices[0], maskIndices[1]] + + return composedImg, composedMask + + +# borrow from https://github.com/MarekKowalski/FaceSwap +def colorTransfer(src, dst, mask): + transferredDst = np.copy(dst) + + maskIndices = np.where(mask != 0) + + + maskedSrc = src[maskIndices[0], maskIndices[1]].astype(np.int32) + maskedDst = dst[maskIndices[0], maskIndices[1]].astype(np.int32) + + meanSrc = np.mean(maskedSrc, axis=0) + meanDst = np.mean(maskedDst, axis=0) + + maskedDst = maskedDst - meanDst + maskedDst = maskedDst + meanSrc + maskedDst = np.clip(maskedDst, 0, 255) + + transferredDst[maskIndices[0], maskIndices[1]] = maskedDst + + return transferredDst + +class BIOnlineGeneration(): + def __init__(self): + with open('precomuted_landmarks.json', 'r') as f: + self.landmarks_record = json.load(f) + for k,v in self.landmarks_record.items(): + self.landmarks_record[k] = np.array(v) + # extract all frame from all video in the name of {videoid}_{frameid} + self.data_list = [ + '000_0000.png', + '001_0000.png' + ] * 10000 + + # predefine mask distortion + self.distortion = iaa.Sequential([iaa.PiecewiseAffine(scale=(0.01, 0.15))]) + + def gen_one_datapoint(self): + background_face_path = random.choice(self.data_list) + data_type = 'real' if random.randint(0,1) else 'fake' + if data_type == 'fake' : + face_img,mask = self.get_blended_face(background_face_path) + mask = ( 1 - mask ) * mask * 4 + else: + face_img = io.imread(background_face_path) + mask = np.zeros((317, 317, 1)) + + # randomly downsample after BI pipeline + if random.randint(0,1): + aug_size = random.randint(64, 317) + face_img = Image.fromarray(face_img) + if random.randint(0,1): + face_img = face_img.resize((aug_size, aug_size), Image.BILINEAR) + else: + face_img = face_img.resize((aug_size, aug_size), Image.NEAREST) + face_img = face_img.resize((317, 317),Image.BILINEAR) + face_img = np.array(face_img) + + # random jpeg compression after BI pipeline + if random.randint(0,1): + quality = random.randint(60, 100) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] + face_img_encode = cv2.imencode('.jpg', face_img, encode_param)[1] + face_img = cv2.imdecode(face_img_encode, cv2.IMREAD_COLOR) + + face_img = face_img[60:317,30:287,:] + mask = mask[60:317,30:287,:] + + # random flip + if random.randint(0,1): + face_img = np.flip(face_img,1) + mask = np.flip(mask,1) + + return face_img,mask,data_type + + def get_blended_face(self,background_face_path): + background_face = io.imread(background_face_path) + background_landmark = self.landmarks_record[background_face_path] + + foreground_face_path = self.search_similar_face(background_landmark,background_face_path) + foreground_face = io.imread(foreground_face_path) + + # down sample before blending + aug_size = random.randint(128,317) + background_landmark = background_landmark * (aug_size/317) + foreground_face = sktransform.resize(foreground_face,(aug_size,aug_size),preserve_range=True).astype(np.uint8) + background_face = sktransform.resize(background_face,(aug_size,aug_size),preserve_range=True).astype(np.uint8) + + # get random type of initial blending mask + mask = random_get_hull(background_landmark, background_face) + + # random deform mask + mask = self.distortion.augment_image(mask) + mask = random_erode_dilate(mask) + + # filte empty mask after deformation + if np.sum(mask) == 0 : + raise NotImplementedError + + # apply color transfer + foreground_face = colorTransfer(background_face, foreground_face, mask*255) + + # blend two face + blended_face, mask = blendImages(foreground_face, background_face, mask*255) + blended_face = blended_face.astype(np.uint8) + + # resize back to default resolution + blended_face = sktransform.resize(blended_face,(317,317),preserve_range=True).astype(np.uint8) + mask = sktransform.resize(mask,(317,317),preserve_range=True) + mask = mask[:,:,0:1] + return blended_face,mask + + def search_similar_face(self,this_landmark,background_face_path): + vid_id, frame_id = name_resolve(background_face_path) + min_dist = 99999999 + + # random sample 5000 frame from all frams: + all_candidate_path = random.sample( self.data_list, k=5000) + + # filter all frame that comes from the same video as background face + all_candidate_path = filter(lambda k:name_resolve(k)[0] != vid_id, all_candidate_path) + all_candidate_path = list(all_candidate_path) + + # loop throungh all candidates frame to get best match + for candidate_path in all_candidate_path: + candidate_landmark = self.landmarks_record[candidate_path].astype(np.float32) + candidate_distance = total_euclidean_distance(candidate_landmark, this_landmark) + if candidate_distance < min_dist: + min_dist = candidate_distance + min_path = candidate_path + + return min_path + +if __name__ == '__main__': + ds = BIOnlineGeneration() + from tqdm import tqdm + all_imgs = [] + for _ in tqdm(range(50)): + img,mask,label = ds.gen_one_datapoint() + mask = np.repeat(mask,3,2) + mask = (mask*255).astype(np.uint8) + img_cat = np.concatenate([img,mask],1) + all_imgs.append(img_cat) + all_in_one = Image.new('RGB', (2570,2570)) + + for x in range(5): + for y in range(10): + idx = x*10+y + im = Image.fromarray(all_imgs[idx]) + + dx = x*514 + dy = y*257 + + all_in_one.paste(im, (dx,dy)) + + all_in_one.save("all_in_one.jpg") \ No newline at end of file diff --git a/training/dataset/library/precomuted_landmarks.json b/training/dataset/library/precomuted_landmarks.json new file mode 100644 index 0000000000000000000000000000000000000000..5414035ad197fccd87ac514a10eafefccbbc042e --- /dev/null +++ b/training/dataset/library/precomuted_landmarks.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5772307690f8d05f62bb8e3d2e90f645ada908e44535a2c8444f0f66e0a71b25 +size 1650 diff --git a/training/dataset/lrl_dataset.py b/training/dataset/lrl_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d80323a0206c086bd6e5b00c6ebbfcd2b97bb2db --- /dev/null +++ b/training/dataset/lrl_dataset.py @@ -0,0 +1,139 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + +import cv2 +import random +import yaml +import torch +import numpy as np +from copy import deepcopy +import albumentations as A +from .abstract_dataset import DeepfakeAbstractBaseDataset +from PIL import Image + +c=0 + +class LRLDataset(DeepfakeAbstractBaseDataset): + def __init__(self, config=None, mode='train'): + super().__init__(config, mode) + global c + c=config + + def multi_pass_filter(self, img, r1=0.33, r2=0.66): + rows, cols = img.shape + k = cols / rows + + mask = np.zeros((rows, cols), np.uint8) + x, y = np.ogrid[:rows, :cols] + mask_area = (k * x + y < r1 * cols) + mask[mask_area] = 1 + low_mask = mask + + mask = np.ones((rows, cols), np.uint8) + x, y = np.ogrid[:rows, :cols] + mask_area = (k * x + y < r2 * cols) + mask[mask_area] = 0 + high_mask = mask + + mask1 = np.zeros((rows, cols), np.uint8) + mask1[low_mask == 0] = 1 + mask2 = np.zeros((rows, cols), np.uint8) + mask2[high_mask == 0] = 1 + mid_mask = mask1 * mask2 + + return low_mask, mid_mask, high_mask + + def image2dct(self,img): + img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + img_gray = np.float32(img_gray) + img_dct = cv2.dct(img_gray) + # img_dct = np.log(np.abs(img_dct)+1e-6) + + low_mask, mid_mask, high_mask = self.multi_pass_filter(img_dct, r1=0.33, r2=0.33) + img_dct_filterd = high_mask * img_dct + img_idct = cv2.idct(img_dct_filterd) + + return img_idct + + def __getitem__(self, index): + image_trans, label, landmark_tensors, mask_trans = super().__getitem__(index, no_norm=True) + + img_idct = self.image2dct(image_trans) + # normalize idct + img_idct = (img_idct / 255 - 0.5) / 0.5 + # img_idct = img_idct[np.newaxis, ...] + + # To tensor and normalize for fake and real images + image_trans = self.normalize(self.to_tensor(image_trans)) + img_idct_trans = self.to_tensor(img_idct) + mask_trans = torch.from_numpy(mask_trans) + mask_trans = mask_trans.squeeze(2).permute(2, 0, 1) + mask_trans = torch.mean(mask_trans, dim=0, keepdim=True) + return image_trans, label, img_idct_trans, mask_trans + + def __len__(self): + return len(self.image_list) + + + @staticmethod + def collate_fn(batch): + """ + Collate a batch of data points. + + Args: + batch (list): A list of tuples containing the image tensor and label tensor. + + Returns: + A tuple containing the image tensor, the label tensor, the landmark tensor, + and the mask tensor. + """ + global c + images, labels, img_idct_trans, masks = zip(*batch) + # Stack the image, label, landmark, and mask tensors + images = torch.stack(images, dim=0) + labels = torch.LongTensor(labels) + masks = torch.stack(masks, dim=0) + img_idct_trans = torch.stack(img_idct_trans, dim=0) + + data_dict = { + 'image': images, + 'label': labels, + 'landmark': None, + 'idct': img_idct_trans, + 'mask': masks, + } + return data_dict + + + +if __name__ == '__main__': + with open(r'H:\code\DeepfakeBench\training\config\detector\lrl_effnb4.yaml', 'r') as f: + config = yaml.safe_load(f) + with open(r'H:\code\DeepfakeBench\training\config\train_config.yaml', 'r') as f: + config2 = yaml.safe_load(f) + random.seed(config['manualSeed']) + torch.manual_seed(config['manualSeed']) + if config['cuda']: + torch.cuda.manual_seed_all(config['manualSeed']) + config2['data_manner'] = 'lmdb' + config['dataset_json_folder'] = 'preprocessing/dataset_json_v3' + config.update(config2) + train_set = LRLDataset(config=config, mode='train') + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=4, + shuffle=True, + num_workers=0, + collate_fn=train_set.collate_fn, + ) + from tqdm import tqdm + for iteration, batch in enumerate(tqdm(train_data_loader)): + print(iteration) + if iteration > 10: + break \ No newline at end of file diff --git a/training/dataset/lsda_dataset.py b/training/dataset/lsda_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..afdea7613ade084e82cf67eb0e669baad3ce90c9 --- /dev/null +++ b/training/dataset/lsda_dataset.py @@ -0,0 +1,380 @@ +import sys +sys.path.append('.') + +import os +import sys +import json +import math +import yaml + +import numpy as np +import cv2 +import random +from PIL import Image + +import torch +from torch.autograd import Variable +from torch.utils import data +from torchvision import transforms as T + + +import skimage.draw +import albumentations as alb +from albumentations import Compose, RandomBrightnessContrast, \ + HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, \ + ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, GaussianBlur, RandomResizedCrop +from torch.utils.data.sampler import Sampler +from .abstract_dataset import DeepfakeAbstractBaseDataset + + +private_path_prefix = '/home/zhaokangran/cvpr24/training' + +fake_dict = { + 'real': 0, + 'Deepfakes': 1, + 'Face2Face': 2, + 'FaceSwap': 3, + 'NeuralTextures': 4, + # 'Deepfakes_Face2Face': 5, + # 'Deepfakes_FaceSwap': 6, + # 'Deepfakes_NeuralTextures': 7, + # 'Deepfakes_real': 8, + # 'Face2Face_FaceSwap': 9, + # 'Face2Face_NeuralTextures': 10, + # 'Face2Face_real': 11, + # 'FaceSwap_NeuralTextures': 12, + # 'FaceSwap_real': 13, + # 'NeuralTextures_real': 14, +} + + + +class RandomDownScale(alb.core.transforms_interface.ImageOnlyTransform): + def apply(self,img,**params): + return self.randomdownscale(img) + + def randomdownscale(self,img): + keep_ratio=True + keep_input_shape=True + H,W,C=img.shape + ratio_list=[2,4] + r=ratio_list[np.random.randint(len(ratio_list))] + img_ds=cv2.resize(img,(int(W/r),int(H/r)),interpolation=cv2.INTER_NEAREST) + if keep_input_shape: + img_ds=cv2.resize(img_ds,(W,H),interpolation=cv2.INTER_LINEAR) + + return img_ds + + +augmentation_methods = alb.Compose([ + # alb.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1,0.1), p=0.5), + # HorizontalFlip(p=0.5), + # RandomDownScale(p=0.5), + # alb.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.5), + alb.ImageCompression(quality_lower=40,quality_upper=100,p=0.5), + GaussianBlur(blur_limit=[3, 7], p=0.5) +], p=1.) + +augmentation_methods2 = alb.Compose([ + alb.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1,0.1), p=0.5), + HorizontalFlip(p=0.5), + RandomDownScale(p=0.5), + alb.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.5), + alb.ImageCompression(quality_lower=40,quality_upper=100,p=0.5), +], +additional_targets={f'image1':'image', f'image2':'image', f'image3':'image', f'image4':'image'}, +p=1.) + +normalize = T.Normalize(mean=[0.5, 0.5, 0.5], + std =[0.5, 0.5, 0.5]) +transforms1 = T.Compose([ + T.ToTensor(), + normalize + ]) + +#========================================== + +def load_rgb(file_path, size=256): + assert os.path.exists(file_path), f"{file_path} is not exists" + img = cv2.imread(file_path) + if img is None: + raise ValueError('Img is None: {}'.format(file_path)) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) + + return Image.fromarray(np.array(img, dtype=np.uint8)) + + +def load_mask(file_path, size=256): + mask = cv2.imread(file_path, 0) + if mask is None: + mask = np.zeros((size, size)) + + mask = cv2.resize(mask, (size, size))/255 + mask = np.expand_dims(mask, axis=2) + return np.float32(mask) + + +def add_gaussian_noise(ins, mean=0, stddev=0.1): + noise = ins.data.new(ins.size()).normal_(mean, stddev) + return torch.clamp(ins + noise, -1, 1) + + +# class RandomBlur(object): +# """ Randomly blur an image +# """ +# def __init__(self, ratio,) + +# class RandomCompression(object): +# """ Randomly compress an image +# """ + +class CustomSampler(Sampler): + def __init__(self, num_groups=2*360, n_frame_per_vid=32, videos_per_group=5, batch_size=10): + self.num_groups = num_groups + self.n_frame_per_vid = n_frame_per_vid + self.videos_per_group = videos_per_group + self.batch_size = batch_size + assert self.batch_size % self.videos_per_group == 0, "Batch size should be a multiple of videos_per_group." + self.groups_per_batch = self.batch_size // self.videos_per_group + + def __iter__(self): + group_indices = list(range(self.num_groups)) + random.shuffle(group_indices) + + # For each batch + for i in range(0, len(group_indices), self.groups_per_batch): + selected_groups = group_indices[i:i+self.groups_per_batch] + + # For each group + for group in selected_groups: + frame_idx = random.randint(0, self.n_frame_per_vid - 1) # Random frame index for this group's videos + + # Return the frame for each video in this group using the same frame_idx + for video_offset in range(self.videos_per_group): + yield group * self.videos_per_group * self.n_frame_per_vid + video_offset * self.n_frame_per_vid + frame_idx + + def __len__(self): + return self.num_groups * self.videos_per_group # Total frames + + + +class LSDADataset(DeepfakeAbstractBaseDataset): + + on_3060 = "3060" in torch.cuda.get_device_name() + transfer_dict = { + 'youtube':'FF-real', + 'Deepfakes':'FF-DF', + 'Face2Face':'FF-F2F', + 'FaceSwap':'FF-FS', + 'NeuralTextures':'FF-NT' + + + } + if on_3060: + data_root = r'F:\Datasets\rgb\FaceForensics++' + else: + data_root = r'./datasets/FaceForensics++' + data_list = { + 'test': r'./datasets/FaceForensics++/test.json', + 'train': r'./datasets/FaceForensics++/train.json', + 'eval': r'./datasets/FaceForensics++/val.json' + } + + def __init__(self, config=None, mode='train', with_dataset=['Deepfakes', 'Face2Face', 'FaceSwap', 'NeuralTextures']): + super().__init__(config, mode) + self.mode = mode + self.res = config['resolution'] + self.fake_dict = fake_dict + # transform + self.normalize = T.Normalize(mean=config['mean'], + std =config['std']) + # data aug and transform + self.transforms1 = T.Compose([ + T.ToTensor(), + self.normalize + ]) + self.img_lines = [] + self.config=config + with open(self.config['dataset_json_folder']+'/FaceForensics++.json', 'r') as fd: + self.img_json = json.load(fd) + with open(self.data_list[mode], 'r') as fd: + data = json.load(fd) + img_lines = [] + for pair in data: + r1, r2 = pair + step = 1 + # collect a group with 1+len(fakes) videos, each video has self.frames[mode] frames。 + for i in range(0, config['frame_num'][mode], step): + # collect real data here(r1) + img_lines.append(('{}/{}'.format('youtube', r1), i, 0, mode)) + + for fake_d in with_dataset: + # collect fake data here(r1_r2 * 4) + for i in range(0, config['frame_num'][mode], step): + img_lines.append( + ('{}/{}_{}'.format(fake_d, r1, r2), i, self.fake_dict[fake_d], mode)) + + for i in range(0, config['frame_num'][mode], step): + # collect real data here(r2) + img_lines.append(('{}/{}'.format('youtube', r2), i, 0, mode)) + + for fake_d in with_dataset: + # collect fake data here(r2_r1 * 4) + for i in range(0, config['frame_num'][mode], step): + img_lines.append( + ('{}/{}_{}'.format(fake_d, r2, r1), i, self.fake_dict[fake_d], mode)) + + # 2*360 (groups) * 1+len(with_dataset) (videos in each group) * self.frames[mode] (frames in each video) + assert len(img_lines) == 2*len(data) * (1 + len(with_dataset)) * config['frame_num'][mode], "to match our custom sampler, the length should be 2*360*(1+len(with_dataset))*frames[mode]" + self.img_lines.extend(img_lines) + + + def get_ids_from_path(self, path): + parts = path.split('/') + try: + if 'youtube' in path: + return [int(parts[-1])] + else: + return list(map(int, parts[-1].split('_'))) + except: + raise ValueError("wrong path: {}".format(path)) + + def load_image(self, name, idx): + instance_type, video_name = name.split('/') + all_frames = self.img_json[self.data_root.split(os.path.sep)[-1]][self.transfer_dict[instance_type]]['train']['c23'][video_name]['frames'] + img_path = all_frames[idx] + + impath = img_path + img = self.load_rgb(impath) + return img + + def __getitem__(self, index): + name, idx, label, mode = self.img_lines[index] + label = int(label) # specific fake label from 1-4 + + + try: + img = self.load_image(name, idx) + except Exception as e: + + # random_idx = random.randint(0, len(self.img_lines)-1) + # print(f'Error loading image {name} at index {idx} due to the loading error. Try another one at index {random_idx}') + # return self.__getitem__(random_idx) + + + if idx==0: + new_index = index+1 + elif idx==31: + new_index = index-1 + else: + new_index = index + random.choice([-1,1]) + print(f'Error loading image {name} at index {idx} due to the loading error. Try another one at index {new_index}') + return self.__getitem__(new_index) + + + if self.mode=='train': + # do augmentation + img = np.asarray(img) # convert PIL to numpy + + img = augmentation_methods2(image=img)['image'] + img = Image.fromarray(np.array(img, dtype=np.uint8)) # covnert numpy to PIL + + # transform with PIL as input + img = self.transforms1(img) + else: + raise ValueError("Not implemented yet") + + return (img, label) + + + + def __len__(self): + return len(self.img_lines) + + + + @staticmethod + def collate_fn(batch): + # Unzip the batch into images and labels + images, labels = zip(*batch) + + # images, labels = zip(batch['image'], batch['label']) + + # image_list = [] + + # for i in range(len(images)//5): + + # img = images[i*5:(i+1)*5] + + # # do augmentation + # imgs_aug = augmentation_methods2(image=np.asarray(img[0]), image1=np.asarray(img[1]), image2=np.asarray(img[2]), image3=np.asarray(img[3]), image4=np.asarray(img[4])) + # for k in imgs_aug: + + # img_aug = Image.fromarray(np.array(imgs_aug[k], dtype=np.uint8)) # covnert numpy to PIL + + # # transform with PIL as input + # img_aug = transforms1(img_aug) + # image_list.append(img_aug) + + # Stack the images and labels + images = torch.stack(images, dim=0) # Shape: (batch_size, c, h, w) + labels = torch.tensor(labels, dtype=torch.long) + + bs, c, h, w = images.shape + + # Assume videos_per_group is 5 in our case + videos_per_group = 5 + num_groups = bs // videos_per_group + + # Reshape to get the group dimension: (num_groups, videos_per_group, c, h, w) + images_grouped = images.view(num_groups, videos_per_group, c, h, w) + labels_grouped = labels.view(num_groups, videos_per_group) + + valid_indices = [] + for i, group in enumerate(labels_grouped): + if set(group.numpy().tolist()) == {0, 1, 2, 3, 4}: + valid_indices.append(i) + # elif set(group.numpy().tolist()) == {0, 1, 2, 3}: + # valid_indices.append(i) + # elif set(group.numpy().tolist()) == {0, 1, 2, 3, 4, 5}: + # valid_indices.append(i) + + images_grouped = images_grouped[valid_indices] + labels_grouped = labels_grouped[valid_indices] + + if not valid_indices: + raise ValueError("No valid groups found in this batch.") + + # # Shuffle the video order within each group + # for i in range(num_groups): + # perm = torch.randperm(videos_per_group) + # images_grouped[i] = images_grouped[i, perm] + # labels_grouped[i] = labels_grouped[i, perm] + + # # Flatten back to original shape but with shuffled video order + # images_shuffled = images_grouped.view(num_groups, videos_per_group, c, h, w) + # labels_shuffled = labels_grouped.view(bs) + + return {'image': images_grouped, 'label': labels_grouped, 'mask': None, 'landmark': None} + + +if __name__ == '__main__': + with open('/data/home/zhiyuanyan/DeepfakeBench/training/config/detector/lsda.yaml', 'r') as f: + config = yaml.safe_load(f) + train_set = LSDADataset(config=config, mode='train') + custom_sampler = CustomSampler(num_groups=2*360, n_frame_per_vid=config['frame_num']['train'], batch_size=config['train_batchSize'], videos_per_group=5) + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=config['train_batchSize'], + num_workers=0, + sampler=custom_sampler, + collate_fn=train_set.collate_fn, + ) + from tqdm import tqdm + for iteration, batch in enumerate(tqdm(train_data_loader)): + print(iteration) + if iteration > 10: + break \ No newline at end of file diff --git a/training/dataset/pair_dataset.py b/training/dataset/pair_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ce3b676e70f215d0261dff5339d2bf56a1f67f96 --- /dev/null +++ b/training/dataset/pair_dataset.py @@ -0,0 +1,150 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-03-30 + +The code is designed for scenarios such as disentanglement-based methods where it is necessary to ensure an equal number of positive and negative samples. +''' + +import torch +import random +import numpy as np +from dataset.abstract_dataset import DeepfakeAbstractBaseDataset + + +class pairDataset(DeepfakeAbstractBaseDataset): + def __init__(self, config=None, mode='train'): + super().__init__(config, mode) + + # Get real and fake image lists + # Fix the label of real images to be 0 and fake images to be 1 + self.fake_imglist = [(img, label, 1) for img, label in zip(self.image_list, self.label_list) if label != 0] + self.real_imglist = [(img, label, 0) for img, label in zip(self.image_list, self.label_list) if label == 0] + + def __getitem__(self, index, norm=True): + # Get the fake and real image paths and labels + fake_image_path, fake_spe_label, fake_label = self.fake_imglist[index] + real_index = random.randint(0, len(self.real_imglist) - 1) # Randomly select a real image + real_image_path, real_spe_label, real_label = self.real_imglist[real_index] + + # Get the mask and landmark paths for fake and real images + fake_mask_path = fake_image_path.replace('frames', 'masks') + fake_landmark_path = fake_image_path.replace('frames', 'landmarks').replace('.png', '.npy') + + real_mask_path = real_image_path.replace('frames', 'masks') + real_landmark_path = real_image_path.replace('frames', 'landmarks').replace('.png', '.npy') + + # Load the fake and real images + fake_image = self.load_rgb(fake_image_path) + real_image = self.load_rgb(real_image_path) + + fake_image = np.array(fake_image) # Convert to numpy array for data augmentation + real_image = np.array(real_image) # Convert to numpy array for data augmentation + + # Load mask and landmark (if needed) for fake and real images + if self.config['with_mask']: + fake_mask = self.load_mask(fake_mask_path) + real_mask = self.load_mask(real_mask_path) + else: + fake_mask, real_mask = None, None + + if self.config['with_landmark']: + fake_landmarks = self.load_landmark(fake_landmark_path) + real_landmarks = self.load_landmark(real_landmark_path) + else: + fake_landmarks, real_landmarks = None, None + + # Do transforms for fake and real images + fake_image_trans, fake_landmarks_trans, fake_mask_trans = self.data_aug(fake_image, fake_landmarks, fake_mask) + real_image_trans, real_landmarks_trans, real_mask_trans = self.data_aug(real_image, real_landmarks, real_mask) + + if not norm: + return {"fake": (fake_image_trans, fake_label), + "real": (real_image_trans, real_label)} + + # To tensor and normalize for fake and real images + fake_image_trans = self.normalize(self.to_tensor(fake_image_trans)) + real_image_trans = self.normalize(self.to_tensor(real_image_trans)) + + # Convert landmarks and masks to tensors if they exist + if self.config['with_landmark']: + fake_landmarks_trans = torch.from_numpy(fake_landmarks_trans) + real_landmarks_trans = torch.from_numpy(real_landmarks_trans) + if self.config['with_mask']: + fake_mask_trans = torch.from_numpy(fake_mask_trans) + real_mask_trans = torch.from_numpy(real_mask_trans) + + return {"fake": (fake_image_trans, fake_label, fake_spe_label, fake_landmarks_trans, fake_mask_trans), + "real": (real_image_trans, real_label, real_spe_label, real_landmarks_trans, real_mask_trans)} + + def __len__(self): + return len(self.fake_imglist) + + @staticmethod + def collate_fn(batch): + """ + Collate a batch of data points. + + Args: + batch (list): A list of tuples containing the image tensor, the label tensor, + the landmark tensor, and the mask tensor. + + Returns: + A tuple containing the image tensor, the label tensor, the landmark tensor, + and the mask tensor. + """ + # Separate the image, label, landmark, and mask tensors for fake and real data + fake_images, fake_labels, fake_spe_labels, fake_landmarks, fake_masks = zip(*[data["fake"] for data in batch]) + real_images, real_labels, real_spe_labels, real_landmarks, real_masks = zip(*[data["real"] for data in batch]) + + # Stack the image, label, landmark, and mask tensors for fake and real data + fake_images = torch.stack(fake_images, dim=0) + fake_labels = torch.LongTensor(fake_labels) + fake_spe_labels = torch.LongTensor(fake_spe_labels) + real_images = torch.stack(real_images, dim=0) + real_labels = torch.LongTensor(real_labels) + real_spe_labels = torch.LongTensor(real_spe_labels) + + # Special case for landmarks and masks if they are None + if fake_landmarks[0] is not None: + fake_landmarks = torch.stack(fake_landmarks, dim=0) + else: + fake_landmarks = None + if real_landmarks[0] is not None: + real_landmarks = torch.stack(real_landmarks, dim=0) + else: + real_landmarks = None + + if fake_masks[0] is not None: + fake_masks = torch.stack(fake_masks, dim=0) + else: + fake_masks = None + if real_masks[0] is not None: + real_masks = torch.stack(real_masks, dim=0) + else: + real_masks = None + + # Combine the fake and real tensors and create a dictionary of the tensors + images = torch.cat([real_images, fake_images], dim=0) + labels = torch.cat([real_labels, fake_labels], dim=0) + spe_labels = torch.cat([real_spe_labels, fake_spe_labels], dim=0) + + if fake_landmarks is not None and real_landmarks is not None: + landmarks = torch.cat([real_landmarks, fake_landmarks], dim=0) + else: + landmarks = None + + if fake_masks is not None and real_masks is not None: + masks = torch.cat([real_masks, fake_masks], dim=0) + else: + masks = None + + data_dict = { + 'image': images, + 'label': labels, + 'label_spe': spe_labels, + 'landmark': landmarks, + 'mask': masks + } + return data_dict + diff --git a/training/dataset/sbi_api.py b/training/dataset/sbi_api.py new file mode 100644 index 0000000000000000000000000000000000000000..461f9b6ab9c3b7c9d26aef1c8428cd01f90c9e4e --- /dev/null +++ b/training/dataset/sbi_api.py @@ -0,0 +1,371 @@ +# Created by: Kaede Shiohara +# Yamasaki Lab at The University of Tokyo +# shiohara@cvm.t.u-tokyo.ac.jp +# Copyright (c) 2021 +# 3rd party softwares' licenses are noticed at https://github.com/mapooon/SelfBlendedImages/blob/master/LICENSE + +import torch +from torchvision import datasets,transforms,utils +from torch.utils.data import Dataset,IterableDataset +from glob import glob +import os +import numpy as np +from PIL import Image +import random +import cv2 +from torch import nn +import sys +import scipy as sp +from skimage.measure import label, regionprops +from training.dataset.library.bi_online_generation import random_get_hull +import albumentations as alb + +import warnings +warnings.filterwarnings('ignore') + + +def alpha_blend(source,target,mask): + mask_blured = get_blend_mask(mask) + img_blended=(mask_blured * source + (1 - mask_blured) * target) + return img_blended,mask_blured + + +def dynamic_blend(source,target,mask): + mask_blured = get_blend_mask(mask) + blend_list=[0.25,0.5,0.75,1,1,1] + blend_ratio = blend_list[np.random.randint(len(blend_list))] + mask_blured*=blend_ratio + img_blended=(mask_blured * source + (1 - mask_blured) * target) + return img_blended,mask_blured + + +def get_blend_mask(mask): + H,W=mask.shape + size_h=np.random.randint(192,257) + size_w=np.random.randint(192,257) + mask=cv2.resize(mask,(size_w,size_h)) + kernel_1=random.randrange(5,26,2) + kernel_1=(kernel_1,kernel_1) + kernel_2=random.randrange(5,26,2) + kernel_2=(kernel_2,kernel_2) + + mask_blured = cv2.GaussianBlur(mask, kernel_1, 0) + mask_blured = mask_blured/(mask_blured.max()) + mask_blured[mask_blured<1]=0 + + mask_blured = cv2.GaussianBlur(mask_blured, kernel_2, np.random.randint(5,46)) + mask_blured = mask_blured/(mask_blured.max()) + mask_blured = cv2.resize(mask_blured,(W,H)) + return mask_blured.reshape((mask_blured.shape+(1,))) + + +def get_alpha_blend_mask(mask): + kernel_list=[(11,11),(9,9),(7,7),(5,5),(3,3)] + blend_list=[0.25,0.5,0.75] + kernel_idxs=random.choices(range(len(kernel_list)), k=2) + blend_ratio = blend_list[random.sample(range(len(blend_list)), 1)[0]] + mask_blured = cv2.GaussianBlur(mask, kernel_list[0], 0) + # print(mask_blured.max()) + mask_blured[mask_blured0]=1 + # mask_blured = mask + mask_blured = cv2.GaussianBlur(mask_blured, kernel_list[kernel_idxs[1]], 0) + mask_blured = mask_blured/(mask_blured.max()) + return mask_blured.reshape((mask_blured.shape+(1,))) + + +class RandomDownScale(alb.core.transforms_interface.ImageOnlyTransform): + def apply(self,img,**params): + return self.randomdownscale(img) + + def randomdownscale(self,img): + keep_ratio=True + keep_input_shape=True + H,W,C=img.shape + ratio_list=[2,4] + r=ratio_list[np.random.randint(len(ratio_list))] + img_ds=cv2.resize(img,(int(W/r),int(H/r)),interpolation=cv2.INTER_NEAREST) + if keep_input_shape: + img_ds=cv2.resize(img_ds,(W,H),interpolation=cv2.INTER_LINEAR) + + return img_ds + + + +def get_boundary(mask, apply_dilation=True, apply_motion_blur=True): + if len(mask.shape) == 3: + mask = mask[:, :, 0] + + mask = cv2.GaussianBlur(mask, (3, 3), 0) + if mask.max() > 1: + boundary = mask / 255. + else: + boundary = mask + boundary = 4 * boundary * (1. - boundary) + + boundary = boundary * 255 + boundary = random_dilate(boundary) + + if apply_motion_blur: + boundary = random_motion_blur(boundary) + boundary = boundary / 255. + return boundary + +def random_dilate(mask, max_kernel_size=5): + kernel_size = random.randint(1, max_kernel_size) + kernel = np.ones((kernel_size, kernel_size), np.uint8) + dilated_mask = cv2.dilate(mask, kernel, iterations=1) + return dilated_mask + +def random_motion_blur(mask, max_kernel_size=5): + kernel_size = random.randint(1, max_kernel_size) + kernel = np.zeros((kernel_size, kernel_size)) + anchor = random.randint(0, kernel_size - 1) + kernel[:, anchor] = 1 / kernel_size + motion_blurred_mask = cv2.filter2D(mask, -1, kernel) + return motion_blurred_mask + + + +class SBI_API: + def __init__(self,phase='train',image_size=256): + + assert phase == 'train', f"Current SBI API only support train phase, but got {phase}" + + self.image_size=(image_size,image_size) + self.phase=phase + + self.transforms=self.get_transforms() + self.source_transforms = self.get_source_transforms() + self.bob_transforms = self.get_source_transforms_for_bob() + + + def __call__(self,img,landmark=None): + try: + assert landmark is not None, "landmark of the facial image should not be None." + # img_r,img_f,mask_f=self.self_blending(img.copy(),landmark.copy()) + + if random.random() < 1.0: + # apply sbi + img_r,img_f,mask_f=self.self_blending(img.copy(),landmark.copy()) + else: + # apply boundary motion blur (bob) + img_r,img_f,mask_f=self.bob(img.copy(),landmark.copy()) + + if self.phase=='train': + transformed=self.transforms(image=img_f.astype('uint8'),image1=img_r.astype('uint8')) + img_f=transformed['image'] + img_r=transformed['image1'] + return img_f,img_r + except Exception as e: + print(e) + return None,None + + + def get_source_transforms(self): + return alb.Compose([ + alb.Compose([ + alb.RGBShift((-20,20),(-20,20),(-20,20),p=0.3), + alb.HueSaturationValue(hue_shift_limit=(-0.3,0.3), sat_shift_limit=(-0.3,0.3), val_shift_limit=(-0.3,0.3), p=1), + alb.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1,0.1), p=1), + ],p=1), + + alb.OneOf([ + RandomDownScale(p=1), + alb.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1), + ],p=1), + + ], p=1.) + + + def get_transforms(self): + return alb.Compose([ + + alb.RGBShift((-20,20),(-20,20),(-20,20),p=0.3), + alb.HueSaturationValue(hue_shift_limit=(-0.3,0.3), sat_shift_limit=(-0.3,0.3), val_shift_limit=(-0.3,0.3), p=0.3), + alb.RandomBrightnessContrast(brightness_limit=(-0.3,0.3), contrast_limit=(-0.3,0.3), p=0.3), + alb.ImageCompression(quality_lower=40,quality_upper=100,p=0.5), + + ], + additional_targets={f'image1': 'image'}, + p=1.) + + + def randaffine(self,img,mask): + f=alb.Affine( + translate_percent={'x':(-0.03,0.03),'y':(-0.015,0.015)}, + scale=[0.95,1/0.95], + fit_output=False, + p=1) + + g=alb.ElasticTransform( + alpha=50, + sigma=7, + alpha_affine=0, + p=1, + ) + + transformed=f(image=img,mask=mask) + img=transformed['image'] + + mask=transformed['mask'] + transformed=g(image=img,mask=mask) + mask=transformed['mask'] + return img,mask + + + def get_source_transforms_for_bob(self): + return alb.Compose([ + alb.Compose([ + alb.ImageCompression(quality_lower=40,quality_upper=100,p=1), + ],p=1), + + alb.OneOf([ + RandomDownScale(p=1), + alb.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1), + ],p=1), + + ], p=1.) + + def bob(self,img,landmark): + H,W=len(img),len(img[0]) + if np.random.rand()<0.25: + landmark=landmark[:68] + # mask=np.zeros_like(img[:,:,0]) + # cv2.fillConvexPoly(mask, cv2.convexHull(landmark), 1.) + hull_type = random.choice([0, 1, 2, 3]) + mask=random_get_hull(landmark,img,hull_type)[:,:,0] + + source = img.copy() + source = self.bob_transforms(image=source.astype(np.uint8))['image'] + source, mask = self.randaffine(source,mask) + mask = get_blend_mask(mask) + + # get boundary with motion blur + boundary = get_boundary(mask) + + blend_list = [0.25,0.5,0.75,1,1,1] + blend_ratio = blend_list[np.random.randint(len(blend_list))] + boundary *= blend_ratio + boundary = np.repeat(boundary[:, :, np.newaxis], 3, axis=2) + img_blended = (boundary * source + (1 - boundary) * img) + + img_blended = img_blended.astype(np.uint8) + img = img.astype(np.uint8) + + return img,img_blended,boundary.squeeze() + + + def self_blending(self,img,landmark): + H,W=len(img),len(img[0]) + if np.random.rand()<0.25: + landmark=landmark[:68] + # mask=np.zeros_like(img[:,:,0]) + # cv2.fillConvexPoly(mask, cv2.convexHull(landmark), 1.) + hull_type = random.choice([0, 1, 2, 3]) + mask=random_get_hull(landmark,img,hull_type)[:,:,0] + + source = img.copy() + if np.random.rand()<0.5: + source = self.source_transforms(image=source.astype(np.uint8))['image'] + else: + img = self.source_transforms(image=img.astype(np.uint8))['image'] + + source, mask = self.randaffine(source,mask) + + img_blended,mask=dynamic_blend(source,img,mask) + img_blended = img_blended.astype(np.uint8) + img = img.astype(np.uint8) + + return img,img_blended,mask + + + def reorder_landmark(self,landmark): + landmark_add=np.zeros((13,2)) + for idx,idx_l in enumerate([77,75,76,68,69,70,71,80,72,73,79,74,78]): + landmark_add[idx]=landmark[idx_l] + landmark[68:]=landmark_add + return landmark + + + def hflip(self,img,mask=None,landmark=None,bbox=None): + H,W=img.shape[:2] + landmark=landmark.copy() + if bbox is not None: + bbox=bbox.copy() + + if landmark is not None: + landmark_new=np.zeros_like(landmark) + + + landmark_new[:17]=landmark[:17][::-1] + landmark_new[17:27]=landmark[17:27][::-1] + + landmark_new[27:31]=landmark[27:31] + landmark_new[31:36]=landmark[31:36][::-1] + + landmark_new[36:40]=landmark[42:46][::-1] + landmark_new[40:42]=landmark[46:48][::-1] + + landmark_new[42:46]=landmark[36:40][::-1] + landmark_new[46:48]=landmark[40:42][::-1] + + landmark_new[48:55]=landmark[48:55][::-1] + landmark_new[55:60]=landmark[55:60][::-1] + + landmark_new[60:65]=landmark[60:65][::-1] + landmark_new[65:68]=landmark[65:68][::-1] + if len(landmark)==68: + pass + elif len(landmark)==81: + landmark_new[68:81]=landmark[68:81][::-1] + else: + raise NotImplementedError + landmark_new[:,0]=W-landmark_new[:,0] + + else: + landmark_new=None + + if bbox is not None: + bbox_new=np.zeros_like(bbox) + bbox_new[0,0]=bbox[1,0] + bbox_new[1,0]=bbox[0,0] + bbox_new[:,0]=W-bbox_new[:,0] + bbox_new[:,1]=bbox[:,1].copy() + if len(bbox)>2: + bbox_new[2,0]=W-bbox[3,0] + bbox_new[2,1]=bbox[3,1] + bbox_new[3,0]=W-bbox[2,0] + bbox_new[3,1]=bbox[2,1] + bbox_new[4,0]=W-bbox[4,0] + bbox_new[4,1]=bbox[4,1] + bbox_new[5,0]=W-bbox[6,0] + bbox_new[5,1]=bbox[6,1] + bbox_new[6,0]=W-bbox[5,0] + bbox_new[6,1]=bbox[5,1] + else: + bbox_new=None + + if mask is not None: + mask=mask[:,::-1] + else: + mask=None + img=img[:,::-1].copy() + return img,mask,landmark_new,bbox_new + + +if __name__=='__main__': + seed=10 + random.seed(seed) + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + api=SBI_API(phase='train',image_size=256) + + img_path = 'FaceForensics++/original_sequences/youtube/c23/frames/000/000.png' + img = cv2.imread(img_path) + landmark_path = img_path.replace('frames', 'landmarks').replace('png', 'npy') + landmark = np.load(landmark_path) + sbi_img, ori_img = api(img, landmark) diff --git a/training/dataset/sbi_dataset.py b/training/dataset/sbi_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d39e38d2ac9f599067cb4efe4cef063eb74eee35 --- /dev/null +++ b/training/dataset/sbi_dataset.py @@ -0,0 +1,139 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2024-01-26 + +The code is designed for self-blending method (SBI, CVPR 2024). +''' + +import sys +sys.path.append('.') + +import cv2 +import yaml +import torch +import numpy as np +from copy import deepcopy +import albumentations as A +from training.dataset.albu import IsotropicResize +from training.dataset.abstract_dataset import DeepfakeAbstractBaseDataset +from training.dataset.sbi_api import SBI_API + + +class SBIDataset(DeepfakeAbstractBaseDataset): + def __init__(self, config=None, mode='train'): + super().__init__(config, mode) + + # Get real lists + # Fix the label of real images to be 0 + self.real_imglist = [(img, label) for img, label in zip(self.image_list, self.label_list) if label == 0] + + # Init SBI + self.sbi = SBI_API(phase=mode,image_size=config['resolution']) + + # Init data augmentation method + self.transform = self.init_data_aug_method() + + def __getitem__(self, index): + # Get the real image paths and labels + real_image_path, real_label = self.real_imglist[index] + + # Get the landmark paths for real images + real_landmark_path = real_image_path.replace('frames', 'landmarks').replace('.png', '.npy') + landmark = self.load_landmark(real_landmark_path).astype(np.int32) + + # Load the real images + real_image = self.load_rgb(real_image_path) + real_image = np.array(real_image) # Convert to numpy array + + # Generate the corresponding SBI sample + fake_image, real_image = self.sbi(real_image, landmark) + if fake_image is None: + fake_image = deepcopy(real_image) + fake_label = 0 + else: + fake_label = 1 + + # To tensor and normalize for fake and real images + fake_image_trans = self.normalize(self.to_tensor(fake_image)) + real_image_trans = self.normalize(self.to_tensor(real_image)) + + return {"fake": (fake_image_trans, fake_label), + "real": (real_image_trans, real_label)} + + def __len__(self): + return len(self.real_imglist) + + @staticmethod + def collate_fn(batch): + """ + Collate a batch of data points. + + Args: + batch (list): A list of tuples containing the image tensor and label tensor. + + Returns: + A tuple containing the image tensor, the label tensor, the landmark tensor, + and the mask tensor. + """ + # Separate the image, label, landmark, and mask tensors for fake and real data + fake_images, fake_labels = zip(*[data["fake"] for data in batch]) + real_images, real_labels = zip(*[data["real"] for data in batch]) + + # Stack the image, label, landmark, and mask tensors for fake and real data + fake_images = torch.stack(fake_images, dim=0) + fake_labels = torch.LongTensor(fake_labels) + real_images = torch.stack(real_images, dim=0) + real_labels = torch.LongTensor(real_labels) + + # Combine the fake and real tensors and create a dictionary of the tensors + images = torch.cat([real_images, fake_images], dim=0) + labels = torch.cat([real_labels, fake_labels], dim=0) + + data_dict = { + 'image': images, + 'label': labels, + 'landmark': None, + 'mask': None, + } + return data_dict + + def init_data_aug_method(self): + trans = A.Compose([ + A.HorizontalFlip(p=self.config['data_aug']['flip_prob']), + A.Rotate(limit=self.config['data_aug']['rotate_limit'], p=self.config['data_aug']['rotate_prob']), + A.GaussianBlur(blur_limit=self.config['data_aug']['blur_limit'], p=self.config['data_aug']['blur_prob']), + A.OneOf([ + IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC), + IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR), + IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR), + ], p = 0 if self.config['with_landmark'] else 1), + A.OneOf([ + A.RandomBrightnessContrast(brightness_limit=self.config['data_aug']['brightness_limit'], contrast_limit=self.config['data_aug']['contrast_limit']), + A.FancyPCA(), + A.HueSaturationValue() + ], p=0.5), + A.ImageCompression(quality_lower=self.config['data_aug']['quality_lower'], quality_upper=self.config['data_aug']['quality_upper'], p=0.5) + ], + additional_targets={'real': 'sbi'}, + ) + return trans + + +if __name__ == '__main__': + with open('/data/home/zhiyuanyan/DeepfakeBench/training/config/detector/sbi.yaml', 'r') as f: + config = yaml.safe_load(f) + train_set = SBIDataset(config=config, mode='train') + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=config['train_batchSize'], + shuffle=True, + num_workers=0, + collate_fn=train_set.collate_fn, + ) + from tqdm import tqdm + for iteration, batch in enumerate(tqdm(train_data_loader)): + print(iteration) + if iteration > 10: + break \ No newline at end of file diff --git a/training/dataset/tall_dataset.py b/training/dataset/tall_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6fe75b5941368727cfe5f4d7a649316a4fa5f1 --- /dev/null +++ b/training/dataset/tall_dataset.py @@ -0,0 +1,183 @@ +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-03-30 +# description: Abstract Base Class for all types of deepfake datasets. + +import sys + +from torch import nn + +sys.path.append('.') + +import yaml +import numpy as np +from copy import deepcopy +import random +import torch +from torch.utils import data +from torchvision.utils import save_image +from training.dataset import DeepfakeAbstractBaseDataset +from einops import rearrange + +FFpp_pool = ['FaceForensics++', 'FaceShifter', 'DeepFakeDetection', 'FF-DF', 'FF-F2F', 'FF-FS', 'FF-NT'] # + + +def all_in_pool(inputs, pool): + for each in inputs: + if each not in pool: + return False + return True + + +class TALLDataset(DeepfakeAbstractBaseDataset): + def __init__(self, config=None, mode='train'): + """Initializes the dataset object. + + Args: + config (dict): A dictionary containing configuration parameters. + mode (str): A string indicating the mode (train or test). + + Raises: + NotImplementedError: If mode is not train or test. + """ + super().__init__(config, mode) + + assert self.video_level, "TALL is a videl-based method" + assert int(self.clip_size ** 0.5) ** 2 == self.clip_size, 'clip_size must be square of an integer, e.g., 4' + + def __getitem__(self, index, no_norm=False): + """ + Returns the data point at the given index. + + Args: + index (int): The index of the data point. + + Returns: + A tuple containing the image tensor, the label tensor, the landmark tensor, + and the mask tensor. + """ + # Get the image paths and label + image_paths = self.data_dict['image'][index] + label = self.data_dict['label'][index] + + if not isinstance(image_paths, list): + image_paths = [image_paths] # for the image-level IO, only one frame is used + + image_tensors = [] + landmark_tensors = [] + mask_tensors = [] + augmentation_seed = None + + for image_path in image_paths: + # Initialize a new seed for data augmentation at the start of each video + if self.video_level and image_path == image_paths[0]: + augmentation_seed = random.randint(0, 2 ** 32 - 1) + + # Get the mask and landmark paths + mask_path = image_path.replace('frames', 'masks') # Use .png for mask + landmark_path = image_path.replace('frames', 'landmarks').replace('.png', '.npy') # Use .npy for landmark + + # Load the image + try: + image = self.load_rgb(image_path) + except Exception as e: + # Skip this image and return the first one + print(f"Error loading image at index {index}: {e}") + return self.__getitem__(0) + image = np.array(image) # Convert to numpy array for data augmentation + + # Load mask and landmark (if needed) + if self.config['with_mask']: + mask = self.load_mask(mask_path) + else: + mask = None + if self.config['with_landmark']: + landmarks = self.load_landmark(landmark_path) + else: + landmarks = None + + # Do Data Augmentation + if self.mode == 'train' and self.config['use_data_augmentation']: + image_trans, landmarks_trans, mask_trans = self.data_aug(image, landmarks, mask, augmentation_seed) + else: + image_trans, landmarks_trans, mask_trans = deepcopy(image), deepcopy(landmarks), deepcopy(mask) + + # To tensor and normalize + if not no_norm: + image_trans = self.normalize(self.to_tensor(image_trans)) + if self.config['with_landmark']: + landmarks_trans = torch.from_numpy(landmarks) + if self.config['with_mask']: + mask_trans = torch.from_numpy(mask_trans) + + image_tensors.append(image_trans) + landmark_tensors.append(landmarks_trans) + mask_tensors.append(mask_trans) + + if self.video_level: + + # Stack image tensors along a new dimension (time) + image_tensors = torch.stack(image_tensors, dim=0) + + # cut out 16x16 patch + F, C, H, W = image_tensors.shape + x, y = np.random.randint(W), np.random.randint(H) + x1 = np.clip(x - self.config['mask_grid_size'] // 2, 0, W) + x2 = np.clip(x + self.config['mask_grid_size'] // 2, 0, W) + y1 = np.clip(y - self.config['mask_grid_size'] // 2, 0, H) + y2 = np.clip(y + self.config['mask_grid_size'] // 2, 0, H) + image_tensors[:, :, y1:y2, x1:x2] = -1 + + # # concatenate sub-image and reszie to 224x224 + # image_tensors = image_tensors.reshape(-1, H, W) + # image_tensors = rearrange(image_tensors, '(rh rw c) h w -> c (rh h) (rw w)', rh=2, c=C) + # image_tensors = nn.functional.interpolate(image_tensors.unsqueeze(0), + # size=(self.config['resolution'], self.config['resolution']), + # mode='bilinear', align_corners=False).squeeze(0) + # Stack landmark and mask tensors along a new dimension (time) + if not any(landmark is None or (isinstance(landmark, list) and None in landmark) for landmark in + landmark_tensors): + landmark_tensors = torch.stack(landmark_tensors, dim=0) + if not any(m is None or (isinstance(m, list) and None in m) for m in mask_tensors): + mask_tensors = torch.stack(mask_tensors, dim=0) + else: + # Get the first image tensor + image_tensors = image_tensors[0] + # Get the first landmark and mask tensors + if not any(landmark is None or (isinstance(landmark, list) and None in landmark) for landmark in + landmark_tensors): + landmark_tensors = landmark_tensors[0] + if not any(m is None or (isinstance(m, list) and None in m) for m in mask_tensors): + mask_tensors = mask_tensors[0] + + return image_tensors, label, landmark_tensors, mask_tensors + + +if __name__ == "__main__": + with open('training/config/detector/tall.yaml', 'r') as f: + config = yaml.safe_load(f) + train_set = TALLDataset( + config=config, + mode='train', + ) + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=config['train_batchSize'], + shuffle=True, + num_workers=0, + collate_fn=train_set.collate_fn, + ) + from tqdm import tqdm + + for iteration, batch in enumerate(tqdm(train_data_loader)): + print(batch['image'].shape) + print(batch['label']) + b, f, c, h, w = batch['image'].shape + for i in range(f): + img_tensor = batch['image'][0][i] + img_tensor = img_tensor * torch.tensor([0.5, 0.5, 0.5]).reshape(-1, 1, 1) + torch.tensor( + [0.5, 0.5, 0.5]).reshape(-1, 1, 1) + save_image(img_tensor, f'{i}.png') + + break diff --git a/training/dataset/utils/DeepFakeMask.py b/training/dataset/utils/DeepFakeMask.py new file mode 100644 index 0000000000000000000000000000000000000000..442e0e80816ea1cf04dde98c92ec16656f928be5 --- /dev/null +++ b/training/dataset/utils/DeepFakeMask.py @@ -0,0 +1,402 @@ +#!/usr/bin/python +# -*- coding: UTF-8 -*- +# Created by: algohunt +# Microsoft Research & Peking University +# lilingzhi@pku.edu.cn +# Copyright (c) 2019 + +#!/usr/bin/env python3 +""" Masks functions for faceswap.py """ + +import inspect +import logging +import sys + +import cv2 +import numpy as np +import random +from math import ceil, floor +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + +def landmarks_to_bbox(landmarks: np.ndarray) -> np.ndarray: + if not isinstance(landmarks, np.ndarray): + landmarks = np.array(landmarks) + assert landmarks.shape[1] == 2 + x0, y0 = np.min(landmarks, axis=0) # minimum values along the x and y axes respectively, [264,97] + x1, y1 = np.max(landmarks, axis=0) # minimum values along the x and y axes respectively, [370,236] + bbox = np.array([x0, y0, x1, y1]) + return bbox + +def mask_from_points(image: np.ndarray, points: np.ndarray) -> np.ndarray: + """8 (or omitted) - 8-connected line. + 4 - 4-connected line. + LINE_AA - antialiased line.""" + h, w = image.shape[:2] + points = points.astype(int) + assert points.shape[1] == 2, f"points.shape: {points.shape}" + out = np.zeros((h, w), dtype=np.uint8) + hull = cv2.convexHull(points.astype(int)) + cv2.fillConvexPoly(out, hull, 255, lineType=4) # cv2.LINE_AA + return out + +def get_available_masks(): + """ Return a list of the available masks for cli """ + masks = sorted([name for name, obj in inspect.getmembers(sys.modules[__name__]) + if inspect.isclass(obj) and name != "Mask"]) + masks.append("none") + # logger.debug(masks) + return masks + +def landmarks_68_symmetries(): + # 68 landmarks symmetry + # + sym_ids = [9, 58, 67, 63, 52, 34, 31, 30, 29, 28] + sym = { + 1: 17, + 2: 16, + 3: 15, + 4: 14, + 5: 13, + 6: 12, + 7: 11, + 8: 10, + # + 51: 53, + 50: 54, + 49: 55, + 60: 56, + 59: 57, + # + 62: 64, + 61: 65, + 68: 66, + # + 33: 35, + 32: 36, + # + 37: 46, + 38: 45, + 39: 44, + 40: 43, + 41: 48, + 42: 47, + # + 18: 27, + 19: 26, + 20: 25, + 21: 24, + 22: 23, + # + # id + 9: 9, + 58: 58, + 67: 67, + 63: 63, + 52: 52, + 34: 34, + 31: 31, + 30: 30, + 29: 29, + 28: 28, + } + return sym, sym_ids + + + +def get_default_mask(): + """ Set the default mask for cli """ + masks = get_available_masks() + default = "dfl_full" + default = default if default in masks else masks[0] + # logger.debug(default) + return default + + +class Mask(): + """ Parent class for masks + the output mask will be .mask + channels: 1, 3 or 4: + 1 - Returns a single channel mask + 3 - Returns a 3 channel mask + 4 - Returns the original image with the mask in the alpha channel """ + + def __init__(self, landmarks, face, channels=4, idx = 0): + # logger.info("Initializing %s: (face_shape: %s, channels: %s, landmarks: %s)", + # self.__class__.__name__, face.shape, channels, landmarks) + self.landmarks = landmarks + self.face = face + self.channels = channels + self.cols = 4 # grid mask + self.rows = 4 # grid mask + self.idx = idx # grid mask + + mask = self.build_mask() + self.mask = self.merge_mask(mask) + # logger.info("Initialized %s", self.__class__.__name__) + + def build_mask(self): + """ Override to build the mask """ + raise NotImplementedError + + def merge_mask(self, mask): + """ Return the mask in requested shape """ + # logger.info("mask_shape: %s", mask.shape) + assert self.channels in (1, 3, 4), "Channels should be 1, 3 or 4" + assert mask.shape[2] == 1 and mask.ndim == 3, "Input mask be 3 dimensions with 1 channel" + + if self.channels == 3: + retval = np.tile(mask, 3) + elif self.channels == 4: + retval = np.concatenate((self.face, mask), -1) + else: + retval = mask + + # logger.info("Final mask shape: %s", retval.shape) + return retval + + +class dfl_full(Mask): # pylint: disable=invalid-name + """ DFL facial mask """ + def build_mask(self): + mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32) + + nose_ridge = (self.landmarks[27:31], self.landmarks[33:34]) + jaw = (self.landmarks[0:17], + self.landmarks[48:68], + self.landmarks[0:1], + self.landmarks[8:9], + self.landmarks[16:17]) + eyes = (self.landmarks[17:27], + self.landmarks[0:1], + self.landmarks[27:28], + self.landmarks[16:17], + self.landmarks[33:34]) + parts = [jaw, nose_ridge, eyes] + + for item in parts: + merged = np.concatenate(item) + cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) # pylint: disable=no-member + return mask + + +class components(Mask): # pylint: disable=invalid-name + """ Component model mask """ + def build_mask(self): + mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32) + + r_jaw = (self.landmarks[0:9], self.landmarks[17:18]) + l_jaw = (self.landmarks[8:17], self.landmarks[26:27]) + r_cheek = (self.landmarks[17:20], self.landmarks[8:9]) + l_cheek = (self.landmarks[24:27], self.landmarks[8:9]) + nose_ridge = (self.landmarks[19:25], self.landmarks[8:9],) + r_eye = (self.landmarks[17:22], + self.landmarks[27:28], + self.landmarks[31:36], + self.landmarks[8:9]) + l_eye = (self.landmarks[22:27], + self.landmarks[27:28], + self.landmarks[31:36], + self.landmarks[8:9]) + nose = (self.landmarks[27:31], self.landmarks[31:36]) + parts = [r_jaw, l_jaw, r_cheek, l_cheek, nose_ridge, r_eye, l_eye, nose] + + # ---change 0531 random select parts --- + # r_face = (self.landmarks[0:9], self.landmarks[17:18],self.landmarks[17:20], self.landmarks[8:9]) + # l_face = (self.landmarks[8:17], self.landmarks[26:27],self.landmarks[24:27], self.landmarks[8:9]) + # nose_final = (self.landmarks[19:25], self.landmarks[8:9],self.landmarks[27:31], self.landmarks[31:36]) + # parts = [r_face,l_face,nose_final,r_eye,l_eye] + # num_to_select = random.randint(1, len(parts)) + # parts = random.sample(parts, num_to_select) + # print(len(parts), parts[0]) + # ---change 0531 random select parts --- + + for item in parts: + merged = np.concatenate(item) + cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) # pylint: disable=no-member + return mask + + +class extended(Mask): # pylint: disable=invalid-name + """ Extended mask + Based on components mask. Attempts to extend the eyebrow points up the forehead + """ + def build_mask(self): + mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32) + + landmarks = self.landmarks.copy() + # mid points between the side of face and eye point + ml_pnt = (landmarks[36] + landmarks[0]) // 2 + mr_pnt = (landmarks[16] + landmarks[45]) // 2 + + # mid points between the mid points and eye + ql_pnt = (landmarks[36] + ml_pnt) // 2 + qr_pnt = (landmarks[45] + mr_pnt) // 2 + + # Top of the eye arrays + bot_l = np.array((ql_pnt, landmarks[36], landmarks[37], landmarks[38], landmarks[39])) + bot_r = np.array((landmarks[42], landmarks[43], landmarks[44], landmarks[45], qr_pnt)) + + # Eyebrow arrays + top_l = landmarks[17:22] + top_r = landmarks[22:27] + + # Adjust eyebrow arrays + landmarks[17:22] = top_l + ((top_l - bot_l) // 2) + landmarks[22:27] = top_r + ((top_r - bot_r) // 2) + + r_jaw = (landmarks[0:9], landmarks[17:18]) + l_jaw = (landmarks[8:17], landmarks[26:27]) + r_cheek = (landmarks[17:20], landmarks[8:9]) + l_cheek = (landmarks[24:27], landmarks[8:9]) + nose_ridge = (landmarks[19:25], landmarks[8:9],) + r_eye = (landmarks[17:22], landmarks[27:28], landmarks[31:36], landmarks[8:9]) + l_eye = (landmarks[22:27], landmarks[27:28], landmarks[31:36], landmarks[8:9]) + nose = (landmarks[27:31], landmarks[31:36]) + parts = [r_jaw, l_jaw, r_cheek, l_cheek, nose_ridge, r_eye, l_eye, nose] + + for item in parts: + merged = np.concatenate(item) + cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) # pylint: disable=no-member + return mask + + +class facehull(Mask): # pylint: disable=invalid-name + """ Basic face hull mask """ + def build_mask(self): + mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32) + hull = cv2.convexHull( # pylint: disable=no-member + np.array(self.landmarks).reshape((-1, 2))) + cv2.fillConvexPoly(mask, hull, 255.0, lineType=cv2.LINE_AA) # pylint: disable=no-member + return mask + # mask = np.zeros(img.shape[0:2] + (1, ), dtype=np.float32) + # hull = cv2.convexHull(np.array(landmark).reshape((-1, 2))) + +class facehull2(Mask): # pylint: disable=invalid-name + """ Basic face hull mask """ + def build_mask(self): + mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.uint8) + hull = cv2.convexHull( # pylint: disable=no-member + np.array(self.landmarks).reshape((-1, 2))) + cv2.fillConvexPoly(mask, hull, 1.0, lineType=cv2.LINE_AA) + return mask + + + +class gridMasking(Mask): + + def build_mask(self): + h, w = self.face.shape[:2] + landmarks = self.landmarks[:68] + # if idx is None: + # idx = np.random.randint(0, self.total) + r, c = divmod(self.idx, self.cols) # Get the quotient and remainder, meaning this idx corresponds to row r and column c + + # pixel related + xmin, ymin, xmax, ymax = landmarks_to_bbox(landmarks) + dx = ceil((xmax - xmin) / self.cols) + dy = ceil((ymax - ymin) / self.rows) + + mask = np.zeros((h, w), dtype=np.uint8) + + # fill the cell mask + x0, y0 = floor(xmin + dx * c), floor(ymin + dy * r) + x1, y1 = floor(x0 + dx), floor(y0 + dy) + cv2.rectangle(mask, (x0, y0), (x1, y1), 255, -1) + + # merge the cell mask with the convex hull + ch = mask_from_points(self.face, landmarks) + # ch = cv2.cvtColor(ch, cv2.COLOR_BGR2GRAY) + # mask = (mask & ch) / 255.0 + mask = cv2.bitwise_and(mask, mask, mask=ch) + mask = mask.reshape([mask.shape[0],mask.shape[1], 1]) + # cv2.bitwise_or(img, d_3c_i) + + return mask + +class MeshgridMasking(Mask): + areas = [ + [1, 2, 3, 4, 5, 6, 7, 49, 32, 40, 41, 42, 37, 18], + [37, 38, 39, 40, 41, 42], # left eye + [18, 19, 20, 21, 22, 28, 40, 39, 38, 37], + [28, 29, 30, 31, 32, 40], + ] + areas_asym = [ + [20, 21, 22, 28, 23, 24, 25], # old [22, 23, 28], + [31, 32, 33, 34, 35, 36], + [32, 33, 34, 35, 36, 55, 54, 53, 52, 51, 50, 49], + [49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60], + [7, 8, 9, 10, 11, 55, 56, 57, 58, 59, 60, 49], + ] + + def init(self, **kwargs): + # super().__init__(**kwargs) + + sym, _ = landmarks_68_symmetries() + # construct list of points paths + paths = [] + paths += self.areas_asym # asymmetrical areas + paths += self.areas # left + paths += [[sym[ld68_id] for ld68_id in area] for area in self.areas] # right + assert len(paths) == self.total + self.paths = paths + + @property + def total(self) -> int: + total = len(self.areas_asym) + len(self.areas) * 2 + return total + + def transform_landmarks(self, landmarks): + """Transform landmarks to extend the eyebrow points up the forehead""" + new_landmarks = landmarks.copy() + # mid points between the side of face and eye point + ml_pnt = (new_landmarks[36] + new_landmarks[0]) // 2 + mr_pnt = (new_landmarks[16] + new_landmarks[45]) // 2 + + # mid points between the mid points and eye + ql_pnt = (new_landmarks[36] + ml_pnt) // 2 + qr_pnt = (new_landmarks[45] + mr_pnt) // 2 + + # Top of the eye arrays + bot_l = np.array( + ( + ql_pnt, + new_landmarks[36], + new_landmarks[37], + new_landmarks[38], + new_landmarks[39], + ) + ) + bot_r = np.array( + ( + new_landmarks[42], + new_landmarks[43], + new_landmarks[44], + new_landmarks[45], + qr_pnt, + ) + ) + + # Eyebrow arrays + top_l = new_landmarks[17:22] + top_r = new_landmarks[22:27] + + # Adjust eyebrow arrays + new_landmarks[17:22] = top_l + ((top_l - bot_l) // 2) + new_landmarks[22:27] = top_r + ((top_r - bot_r) // 2) + + return new_landmarks + + def build_mask(self) -> np.ndarray: + self.init() + h, w = self.face.shape[:2] + + path = self.paths[self.idx] + new_landmarks = self.transform_landmarks(self.landmarks) + points = [new_landmarks[ld68_id - 1] for ld68_id in path] + points = np.array(points, dtype=np.int32) + + # cv2.fillConvexPoly(out, points, 255, lineType=4) + mask = np.zeros((h, w), dtype=np.uint8) + cv2.fillPoly(mask, [points], 255) + mask = mask.reshape([mask.shape[0],mask.shape[1], 1]) + return mask \ No newline at end of file diff --git a/training/dataset/utils/SLADD.py b/training/dataset/utils/SLADD.py new file mode 100644 index 0000000000000000000000000000000000000000..5bce890e2ccedb53a3d781fe37beed4a54381448 --- /dev/null +++ b/training/dataset/utils/SLADD.py @@ -0,0 +1,163 @@ +from enum import Enum +from functools import reduce + +import cv2 +import numpy as np +from scipy.ndimage import binary_dilation + +from .DeepFakeMask import Mask + + +def dist(a, b): + x1, y1 = a + x2, y2 = b + return np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) + # return np.linalg.norm(a-b) + + +def get_five_key(landmarks_68): + # get the five key points by using the landmarks + leye_center = (landmarks_68[36] + landmarks_68[39]) * 0.5 + reye_center = (landmarks_68[42] + landmarks_68[45]) * 0.5 + nose = landmarks_68[33] + lmouth = landmarks_68[48] + rmouth = landmarks_68[54] + leye_left = landmarks_68[36] + leye_right = landmarks_68[39] + reye_left = landmarks_68[42] + reye_right = landmarks_68[45] + out = [ + tuple(x.astype("int32")) + for x in [ + leye_center, + reye_center, + nose, + lmouth, + rmouth, + leye_left, + leye_right, + reye_left, + reye_right, + ] + ] + return out + + +def remove_eyes(image, landmarks, opt): + ##l: left eye; r: right eye, b: both eye + if opt == "l": + (x1, y1), (x2, y2) = landmarks[5:7] + elif opt == "r": + (x1, y1), (x2, y2) = landmarks[7:9] + elif opt == "b": + (x1, y1), (x2, y2) = landmarks[:2] + else: + print("wrong region") + mask = np.zeros_like(image[..., 0]) + line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2) + w = dist((x1, y1), (x2, y2)) + dilation = int(w // 4) + if opt != "b": + dilation *= 4 + line = binary_dilation(line, iterations=dilation) + return line + + +def remove_nose(image, landmarks): + (x1, y1), (x2, y2) = landmarks[:2] + x3, y3 = landmarks[2] + mask = np.zeros_like(image[..., 0]) + x4 = int((x1 + x2) / 2) + y4 = int((y1 + y2) / 2) + line = cv2.line(mask, (x3, y3), (x4, y4), color=(1), thickness=2) + w = dist((x1, y1), (x2, y2)) + dilation = int(w // 4) + line = binary_dilation(line, iterations=dilation) + return line + + +def remove_mouth(image, landmarks): + (x1, y1), (x2, y2) = landmarks[3:5] + mask = np.zeros_like(image[..., 0]) + line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2) + w = dist((x1, y1), (x2, y2)) + dilation = int(w // 3) + line = binary_dilation(line, iterations=dilation) + return line + + +class SladdRegion(Enum): + left_eye = 0 + right_eye = 1 + nose = 2 + mouth = 3 + # composition + both_eyes = left_eye + right_eye # 4 + + +class SladdMasking(Mask): + + # [0, 1, 2, 3, (0, 1), (0, 2), (1, 2), (2, 3), (0, 1, 2), (0, 1, 2, 3)] + # left-eye, right-eye, nose, mouth, ... + ALL_REGIONS = [ + SladdRegion.left_eye, + SladdRegion.right_eye, + SladdRegion.nose, + SladdRegion.mouth, + ] + REGIONS = [ + [SladdRegion.left_eye], + [SladdRegion.right_eye], + [SladdRegion.nose], + [SladdRegion.mouth], + [SladdRegion.left_eye, SladdRegion.right_eye], + [SladdRegion.left_eye, SladdRegion.nose], + [SladdRegion.right_eye, SladdRegion.nose], + [SladdRegion.nose, SladdRegion.mouth], + [SladdRegion.left_eye, SladdRegion.right_eye, SladdRegion.nose], + ALL_REGIONS, + ] + + def init(self, compose: bool = False, single: bool = True, **kwargs): + # super().__init__(**kwargs) + self.compose = compose + if compose: + self.regions = SladdMasking.REGIONS + else: + self.regions = [reg for reg in SladdMasking.REGIONS if len(reg) == 1] + if single: + self.regions = [self.ALL_REGIONS] + + @property + def total(self) -> int: + return len(self.regions) + + @staticmethod + def parse(img, reg, landmarks) -> np.ndarray: + five_key = get_five_key(landmarks) + if reg is SladdRegion.left_eye: + mask = remove_eyes(img, five_key, "l") + elif reg is SladdRegion.right_eye: + mask = remove_eyes(img, five_key, "r") + elif reg is SladdRegion.nose: + mask = remove_nose(img, five_key) + elif reg is SladdRegion.mouth: + mask = remove_mouth(img, five_key) + else: + raise ValueError("Invalid region") + # elif reg == SladdRegion4: + # mask = remove_eyes(img, five_key, "b") + return mask + + def build_mask(self) -> np.ndarray: + self.init() + h, w = self.face.shape[:2] + # print(len(self.regions)) + regs = [self.regions[0][self.idx]] + # if isinstance(reg, int): + # mask = parse(img, reg, landmarks) + masks = [SladdMasking.parse(self.face, reg, self.landmarks) for reg in regs] + mask = reduce(np.maximum, masks) + mask = mask.reshape([mask.shape[0],mask.shape[1], 1]) + + return mask diff --git a/training/dataset/utils/attribution_mask.py b/training/dataset/utils/attribution_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..6c0f9b06f7421fe94e2eae3d2eb44bc493237562 --- /dev/null +++ b/training/dataset/utils/attribution_mask.py @@ -0,0 +1,55 @@ + + +import cv2 +import math +import numpy as np +from scipy.ndimage import binary_erosion, binary_dilation +def dist(p1, p2): + return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) + +def remove_mouth(image, landmarks): + (x1, y1), (x2, y2) = landmarks[3:5] + mask = np.zeros_like(image[..., 0]) + line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2) + w = dist((x1, y1), (x2, y2)) + dilation = int(w // 3) + line = binary_dilation(line, iterations=dilation) + return line + +def remove_eyes(image, landmarks, opt='b'): + ##l: left eye; r: right eye, b: both eye + if opt == 'l': + (x1, y1), (x2, y2) = landmarks[36],landmarks[39] + elif opt == 'r': + (x1, y1), (x2, y2) = landmarks[42],landmarks[46] + elif opt == 'b': + (x1, y1), (x2, y2) = landmarks[36],landmarks[46] + else: + print('wrong region') + mask = np.zeros_like(image[..., 0]) + line = cv2.line(np.array(mask, dtype=np.uint8), (int(x1), int(y1)), (int(x2), int(y2)), color=(1), thickness=2) + w = dist((x1, y1), (x2, y2)) + dilation = int(w // 4) + if opt != 'b': + dilation *= 4 + line = binary_dilation(line, iterations=dilation) + return line + +def remove_nose(image, landmarks): + ##l: left eye; r: right eye, b: both eye + + (x1, y1), (x2, y2) = landmarks[27], landmarks[30] + mask = np.zeros_like(image[..., 0]) + line = cv2.line(np.array(mask, dtype=np.uint8), (int(x1), int(y1)), (int(x2), int(y2)), color=(1), thickness=2) + w = dist((x1, y1), (x2, y2)) + dilation = int(w // 3) + line1 = binary_dilation(line, iterations=dilation) + + (x1, y1), (x2, y2) = landmarks[31], landmarks[35] + mask = np.zeros_like(image[..., 0]) + line = cv2.line(np.array(mask, dtype=np.uint8), (int(x1), int(y1)), (int(x2), int(y2)), color=(1), thickness=2) + w = dist((x1, y1), (x2, y2)) + dilation = int(w //4 ) + line2 = binary_dilation(line, iterations=dilation) + + return line1+line2 \ No newline at end of file diff --git a/training/dataset/utils/bi_online_generation.py b/training/dataset/utils/bi_online_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc30a2dfcda5766a5d8425a7e5496dbec7b53ad --- /dev/null +++ b/training/dataset/utils/bi_online_generation.py @@ -0,0 +1,289 @@ +import dlib +from skimage import io +from skimage import transform as sktransform +import numpy as np +from matplotlib import pyplot as plt +import json +import os +import random +from PIL import Image +from imgaug import augmenters as iaa +from dataset.library.DeepFakeMask import dfl_full,facehull,components,extended +from dataset.utils.attribution_mask import * +import cv2 +import tqdm + +''' +from PIL import ImageDraw +# Create an object that can draw on the image +img_pil=Image.fromarray(img) +draw = ImageDraw.Draw(img_pil) + +# Draw points on the image +for i, point in enumerate(landmark): + x, y = point + radius = 1 # radius of the point + draw.ellipse((x-radius, y-radius, x+radius, y+radius), fill="red") + draw.text((x+radius+2, y-radius), str(i), fill="black") # Add a label next to the point +img_pil.show() +''' + + +def name_resolve(path): + name = os.path.splitext(os.path.basename(path))[0] + vid_id, frame_id = name.split('_')[0:2] + return vid_id, frame_id + +def total_euclidean_distance(a,b): + assert len(a.shape) == 2 + return np.sum(np.linalg.norm(a-b,axis=1)) + +def get_five_key(landmarks_68): + # get the five key points by using the landmarks + leye_center = (landmarks_68[36] + landmarks_68[39])*0.5 + reye_center = (landmarks_68[42] + landmarks_68[45])*0.5 + nose = landmarks_68[33] + lmouth = landmarks_68[48] + rmouth = landmarks_68[54] + leye_left = landmarks_68[36] + leye_right = landmarks_68[39] + reye_left = landmarks_68[42] + reye_right = landmarks_68[45] + out = [ tuple(x.astype('int32')) for x in [ + leye_center,reye_center,nose,lmouth,rmouth,leye_left,leye_right,reye_left,reye_right + ]] + return out + +def random_get_hull(landmark,img1,hull_type=None): + if hull_type==None: + hull_type = random.choice([0,1,2,3]) + if hull_type == 0: + mask = dfl_full(landmarks=landmark.astype('int32'),face=img1, channels=3).mask + return mask[:,:,0]/255 + elif hull_type == 1: + mask = extended(landmarks=landmark.astype('int32'),face=img1, channels=3).mask + return mask[:,:,0]/255 + elif hull_type == 2: + mask = components(landmarks=landmark.astype('int32'),face=img1, channels=3).mask + return mask[:,:,0]/255 + elif hull_type == 3: + mask = facehull(landmarks=landmark.astype('int32'),face=img1, channels=3).mask + return mask[:,:,0]/255 + elif hull_type == 4: + mask = remove_mouth(img1,get_five_key(landmark)) + return mask.astype(np.float32) + elif hull_type == 5: + mask = remove_eyes(img1,landmark) + return mask.astype(np.float32) + elif hull_type == 6: + mask = remove_nose(img1,landmark) + return mask.astype(np.float32) + elif hull_type == 7: + mask = remove_nose(img1,landmark) + remove_eyes(img1,landmark) + remove_mouth(img1,get_five_key(landmark)) + return mask.astype(np.float32) + + +def random_erode_dilate(mask, ksize=None): + if random.random()>0.5: + if ksize is None: + ksize = random.randint(1,21) + if ksize % 2 == 0: + ksize += 1 + mask = np.array(mask).astype(np.uint8)*255 + kernel = np.ones((ksize,ksize),np.uint8) + mask = cv2.erode(mask,kernel,1)/255 + else: + if ksize is None: + ksize = random.randint(1,5) + if ksize % 2 == 0: + ksize += 1 + mask = np.array(mask).astype(np.uint8)*255 + kernel = np.ones((ksize,ksize),np.uint8) + mask = cv2.dilate(mask,kernel,1)/255 + return mask + + +# borrow from https://github.com/MarekKowalski/FaceSwap +def blendImages(src, dst, mask, featherAmount=0.2): + + maskIndices = np.where(mask != 0) + + src_mask = np.ones_like(mask) + dst_mask = np.zeros_like(mask) + + maskPts = np.hstack((maskIndices[1][:, np.newaxis], maskIndices[0][:, np.newaxis])) + faceSize = np.max(maskPts, axis=0) - np.min(maskPts, axis=0) + featherAmount = featherAmount * np.max(faceSize) + + hull = cv2.convexHull(maskPts) + dists = np.zeros(maskPts.shape[0]) + for i in range(maskPts.shape[0]): + dists[i] = cv2.pointPolygonTest(hull, (maskPts[i, 0], maskPts[i, 1]), True) + + weights = np.clip(dists / featherAmount, 0, 1) + + composedImg = np.copy(dst) + composedImg[maskIndices[0], maskIndices[1]] = weights[:, np.newaxis] * src[maskIndices[0], maskIndices[1]] + (1 - weights[:, np.newaxis]) * dst[maskIndices[0], maskIndices[1]] + + composedMask = np.copy(dst_mask) + composedMask[maskIndices[0], maskIndices[1]] = weights[:, np.newaxis] * src_mask[maskIndices[0], maskIndices[1]] + ( + 1 - weights[:, np.newaxis]) * dst_mask[maskIndices[0], maskIndices[1]] + + return composedImg, composedMask + + +# borrow from https://github.com/MarekKowalski/FaceSwap +def colorTransfer(src, dst, mask): + transferredDst = np.copy(dst) + + maskIndices = np.where(mask != 0) + + + maskedSrc = src[maskIndices[0], maskIndices[1]].astype(np.int32) + maskedDst = dst[maskIndices[0], maskIndices[1]].astype(np.int32) + + meanSrc = np.mean(maskedSrc, axis=0) + meanDst = np.mean(maskedDst, axis=0) + + maskedDst = maskedDst - meanDst + maskedDst = maskedDst + meanSrc + maskedDst = np.clip(maskedDst, 0, 255) + + transferredDst[maskIndices[0], maskIndices[1]] = maskedDst + + return transferredDst + +class BIOnlineGeneration(): + def __init__(self): + with open('precomuted_landmarks.json', 'r') as f: + self.landmarks_record = json.load(f) + for k,v in self.landmarks_record.items(): + self.landmarks_record[k] = np.array(v) + # extract all frame from all video in the name of {videoid}_{frameid} + self.data_list = [ + '000_0000.png', + '001_0000.png' + ] * 10000 + + # predefine mask distortion + self.distortion = iaa.Sequential([iaa.PiecewiseAffine(scale=(0.01, 0.15))]) + + def gen_one_datapoint(self): + background_face_path = random.choice(self.data_list) + data_type = 'real' if random.randint(0,1) else 'fake' + if data_type == 'fake' : + face_img,mask = self.get_blended_face(background_face_path) + mask = ( 1 - mask ) * mask * 4 + else: + face_img = io.imread(background_face_path) + mask = np.zeros((317, 317, 1)) + + # randomly downsample after BI pipeline + if random.randint(0,1): + aug_size = random.randint(64, 317) + face_img = Image.fromarray(face_img) + if random.randint(0,1): + face_img = face_img.resize((aug_size, aug_size), Image.BILINEAR) + else: + face_img = face_img.resize((aug_size, aug_size), Image.NEAREST) + face_img = face_img.resize((317, 317),Image.BILINEAR) + face_img = np.array(face_img) + + # random jpeg compression after BI pipeline + if random.randint(0,1): + quality = random.randint(60, 100) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] + face_img_encode = cv2.imencode('.jpg', face_img, encode_param)[1] + face_img = cv2.imdecode(face_img_encode, cv2.IMREAD_COLOR) + + face_img = face_img[60:317,30:287,:] + mask = mask[60:317,30:287,:] + + # random flip + if random.randint(0,1): + face_img = np.flip(face_img,1) + mask = np.flip(mask,1) + + return face_img,mask,data_type + + def get_blended_face(self,background_face_path): + background_face = io.imread(background_face_path) + background_landmark = self.landmarks_record[background_face_path] + + foreground_face_path = self.search_similar_face(background_landmark,background_face_path) + foreground_face = io.imread(foreground_face_path) + + # down sample before blending + aug_size = random.randint(128,317) + background_landmark = background_landmark * (aug_size/317) + foreground_face = sktransform.resize(foreground_face,(aug_size,aug_size),preserve_range=True).astype(np.uint8) + background_face = sktransform.resize(background_face,(aug_size,aug_size),preserve_range=True).astype(np.uint8) + + # get random type of initial blending mask + mask = random_get_hull(background_landmark, background_face) + + # random deform mask + mask = self.distortion.augment_image(mask) + mask = random_erode_dilate(mask) + + # filte empty mask after deformation + if np.sum(mask) == 0 : + raise NotImplementedError + + # apply color transfer + foreground_face = colorTransfer(background_face, foreground_face, mask*255) + + # blend two face + blended_face, mask = blendImages(foreground_face, background_face, mask*255) + blended_face = blended_face.astype(np.uint8) + + # resize back to default resolution + blended_face = sktransform.resize(blended_face,(317,317),preserve_range=True).astype(np.uint8) + mask = sktransform.resize(mask,(317,317),preserve_range=True) + mask = mask[:,:,0:1] + return blended_face,mask + + def search_similar_face(self,this_landmark,background_face_path): + vid_id, frame_id = name_resolve(background_face_path) + min_dist = 99999999 + + # random sample 5000 frame from all frams: + all_candidate_path = random.sample( self.data_list, k=5000) + + # filter all frame that comes from the same video as background face + all_candidate_path = filter(lambda k:name_resolve(k)[0] != vid_id, all_candidate_path) + all_candidate_path = list(all_candidate_path) + + # loop throungh all candidates frame to get best match + for candidate_path in all_candidate_path: + candidate_landmark = self.landmarks_record[candidate_path].astype(np.float32) + candidate_distance = total_euclidean_distance(candidate_landmark, this_landmark) + if candidate_distance < min_dist: + min_dist = candidate_distance + min_path = candidate_path + + return min_path + +if __name__ == '__main__': + ds = BIOnlineGeneration() + from tqdm import tqdm + all_imgs = [] + for _ in tqdm(range(50)): + img,mask,label = ds.gen_one_datapoint() + mask = np.repeat(mask,3,2) + mask = (mask*255).astype(np.uint8) + img_cat = np.concatenate([img,mask],1) + all_imgs.append(img_cat) + all_in_one = Image.new('RGB', (2570,2570)) + + for x in range(5): + for y in range(10): + idx = x*10+y + im = Image.fromarray(all_imgs[idx]) + + dx = x*514 + dy = y*257 + + all_in_one.paste(im, (dx,dy)) + + all_in_one.save("all_in_one.jpg") \ No newline at end of file diff --git a/training/dataset/utils/bi_online_generation_yzy.py b/training/dataset/utils/bi_online_generation_yzy.py new file mode 100644 index 0000000000000000000000000000000000000000..1cc85db2845a9f59142df9e21b5af5a90e91950c --- /dev/null +++ b/training/dataset/utils/bi_online_generation_yzy.py @@ -0,0 +1,268 @@ +import dlib +from skimage import io +from skimage import transform as sktransform +import numpy as np +from matplotlib import pyplot as plt +import json +import os +import random +from PIL import Image +from imgaug import augmenters as iaa +from .DeepFakeMask import dfl_full,facehull,components,extended,gridMasking,MeshgridMasking, facehull2 +from .SLADD import SladdMasking +import cv2 +import torch +import torch.nn as nn +import tqdm +import pdb + + +def name_resolve(path): + name = os.path.splitext(os.path.basename(path))[0] + vid_id, frame_id = name.split('_')[0:2] + return vid_id, frame_id + +def total_euclidean_distance(a,b): + assert len(a.shape) == 2 + return np.sum(np.linalg.norm(a-b,axis=1)) + +def random_get_hull(landmark,img1,hull_type0, idx=0): + # print("in bi online generation----------",hull_type0) + if hull_type0 == -1: + hull_type = random.choice([0,1,2,3]) + else: + # hull_type = int(random.choice(hull_type0)) + hull_type = hull_type0 + # print(hull_type) + if hull_type == 0: + # print("here") + mask = dfl_full(landmarks=landmark.astype('int32'),face=img1, channels=3).mask + return mask/255, idx + elif hull_type == 1: + mask = extended(landmarks=landmark.astype('int32'),face=img1, channels=3).mask + return mask/255, idx + elif hull_type == 2: + mask = components(landmarks=landmark.astype('int32'),face=img1, channels=3).mask + return mask/255, idx + elif hull_type == 3: + mask = facehull(landmarks=landmark.astype('int32'),face=img1, channels=3).mask + return mask/255, idx # --change0628-- mask/255 + + # elif hull_type == 4: # SLADD + # mask = SladdMasking(landmarks=landmark.astype('int32'),face=img1, channels=3, idx=0).mask + # return mask/1., idx + # elif hull_type == 5: # SLADD + # mask = SladdMasking(landmarks=landmark.astype('int32'),face=img1, channels=3, idx=1).mask + # return mask/1., idx + # elif hull_type == 6: # SLADD + # mask = SladdMasking(landmarks=landmark.astype('int32'),face=img1, channels=3, idx=2).mask + # return mask/1., idx + elif hull_type == 6: # SLADD/mouth + mask = SladdMasking(landmarks=landmark.astype('int32'),face=img1, channels=3, idx=3).mask + return mask/1., idx + + +def random_erode_dilate(mask, ksize=None): + if random.random()>0.5: + if ksize is None: + ksize = random.randint(1,21) + if ksize % 2 == 0: + ksize += 1 + mask = np.array(mask).astype(np.uint8)*255 + kernel = np.ones((ksize,ksize),np.uint8) + mask = cv2.erode(mask,kernel,1)/255 + else: + if ksize is None: + ksize = random.randint(1,5) + if ksize % 2 == 0: + ksize += 1 + mask = np.array(mask).astype(np.uint8)*255 + kernel = np.ones((ksize,ksize),np.uint8) + mask = cv2.dilate(mask,kernel,1)/255 + return mask + + +# borrow from https://github.com/MarekKowalski/FaceSwap +def blendImages(src, dst, mask, featherAmount=0.2): + + maskIndices = np.where(mask != 0) + + src_mask = np.ones_like(mask) + dst_mask = np.zeros_like(mask) + + maskPts = np.hstack((maskIndices[1][:, np.newaxis], maskIndices[0][:, np.newaxis])) + faceSize = np.max(maskPts, axis=0) - np.min(maskPts, axis=0) + featherAmount = featherAmount * np.max(faceSize) + + hull = cv2.convexHull(maskPts) + dists = np.zeros(maskPts.shape[0]) + for i in range(maskPts.shape[0]): + dists[i] = cv2.pointPolygonTest(hull, (maskPts[i, 0], maskPts[i, 1]), True) + + weights = np.clip(dists / featherAmount, 0, 1) + + composedImg = np.copy(dst) + composedImg[maskIndices[0], maskIndices[1]] = weights[:, np.newaxis] * src[maskIndices[0], maskIndices[1]] + (1 - weights[:, np.newaxis]) * dst[maskIndices[0], maskIndices[1]] + + composedMask = np.copy(dst_mask) + composedMask[maskIndices[0], maskIndices[1]] = weights[:, np.newaxis] * src_mask[maskIndices[0], maskIndices[1]] + ( + 1 - weights[:, np.newaxis]) * dst_mask[maskIndices[0], maskIndices[1]] + + return composedImg, composedMask + + +# borrow from https://github.com/MarekKowalski/FaceSwap +def colorTransfer(src, dst, mask): + transferredDst = np.copy(dst) + + maskIndices = np.where(mask != 0) + + + maskedSrc = src[maskIndices[0], maskIndices[1]].astype(np.int32) + maskedDst = dst[maskIndices[0], maskIndices[1]].astype(np.int32) + + meanSrc = np.mean(maskedSrc, axis=0) + meanDst = np.mean(maskedDst, axis=0) + + maskedDst = maskedDst - meanDst + maskedDst = maskedDst + meanSrc + maskedDst = np.clip(maskedDst, 0, 255) + + transferredDst[maskIndices[0], maskIndices[1]] = maskedDst + + return transferredDst + +class BIOnlineGeneration(): + def __init__(self): + with open('precomuted_landmarks.json', 'r') as f: + self.landmarks_record = json.load(f) + for k,v in self.landmarks_record.items(): + self.landmarks_record[k] = np.array(v) + # extract all frame from all video in the name of {videoid}_{frameid} + self.data_list = [ + '000_0000.png', + '001_0000.png' + ] * 10000 + + # predefine mask distortion + self.distortion = iaa.Sequential([iaa.PiecewiseAffine(scale=(0.01, 0.15))]) + + def gen_one_datapoint(self): + background_face_path = random.choice(self.data_list) + data_type = 'real' if random.randint(0,1) else 'fake' + if data_type == 'fake' : + face_img,mask = self.get_blended_face(background_face_path) + mask = ( 1 - mask ) * mask * 4 + else: + face_img = io.imread(background_face_path) + mask = np.zeros((317, 317, 1)) + + # randomly downsample after BI pipeline + if random.randint(0,1): + aug_size = random.randint(64, 317) + face_img = Image.fromarray(face_img) + if random.randint(0,1): + face_img = face_img.resize((aug_size, aug_size), Image.BILINEAR) + else: + face_img = face_img.resize((aug_size, aug_size), Image.NEAREST) + face_img = face_img.resize((317, 317),Image.BILINEAR) + face_img = np.array(face_img) + + # random jpeg compression after BI pipeline + if random.randint(0,1): + quality = random.randint(60, 100) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] + face_img_encode = cv2.imencode('.jpg', face_img, encode_param)[1] + face_img = cv2.imdecode(face_img_encode, cv2.IMREAD_COLOR) + + face_img = face_img[60:317,30:287,:] + mask = mask[60:317,30:287,:] + + # random flip + if random.randint(0,1): + face_img = np.flip(face_img,1) + mask = np.flip(mask,1) + + return face_img,mask,data_type + + def get_blended_face(self,background_face_path): + background_face = io.imread(background_face_path) + background_landmark = self.landmarks_record[background_face_path] + + foreground_face_path = self.search_similar_face(background_landmark,background_face_path) + foreground_face = io.imread(foreground_face_path) + + # down sample before blending + aug_size = random.randint(128,317) + background_landmark = background_landmark * (aug_size/317) + foreground_face = sktransform.resize(foreground_face,(aug_size,aug_size),preserve_range=True).astype(np.uint8) + background_face = sktransform.resize(background_face,(aug_size,aug_size),preserve_range=True).astype(np.uint8) + + # get random type of initial blending mask + mask, idx = random_get_hull(background_landmark, background_face) + + # random deform mask + mask = self.distortion.augment_image(mask) + mask = random_erode_dilate(mask) + + # filte empty mask after deformation + if np.sum(mask) == 0 : + raise NotImplementedError + + # apply color transfer + foreground_face = colorTransfer(background_face, foreground_face, mask*255) + + # blend two face + blended_face, mask = blendImages(foreground_face, background_face, mask*255) + blended_face = blended_face.astype(np.uint8) + + # resize back to default resolution + blended_face = sktransform.resize(blended_face,(317,317),preserve_range=True).astype(np.uint8) + mask = sktransform.resize(mask,(317,317),preserve_range=True) + mask = mask[:,:,0:1] + return blended_face,mask + + def search_similar_face(self,this_landmark,background_face_path): + vid_id, frame_id = name_resolve(background_face_path) + min_dist = 99999999 + + # random sample 5000 frame from all frams: + all_candidate_path = random.sample( self.data_list, k=5000) + + # filter all frame that comes from the same video as background face + all_candidate_path = filter(lambda k:name_resolve(k)[0] != vid_id, all_candidate_path) + all_candidate_path = list(all_candidate_path) + + # loop throungh all candidates frame to get best match + for candidate_path in all_candidate_path: + candidate_landmark = self.landmarks_record[candidate_path].astype(np.float32) + candidate_distance = total_euclidean_distance(candidate_landmark, this_landmark) + if candidate_distance < min_dist: + min_dist = candidate_distance + min_path = candidate_path + + return min_path + +if __name__ == '__main__': + ds = BIOnlineGeneration() + from tqdm import tqdm + all_imgs = [] + for _ in tqdm(range(50)): + img,mask,label = ds.gen_one_datapoint() + mask = np.repeat(mask,3,2) + mask = (mask*255).astype(np.uint8) + img_cat = np.concatenate([img,mask],1) + all_imgs.append(img_cat) + all_in_one = Image.new('RGB', (2570,2570)) + + for x in range(5): + for y in range(10): + idx = x*10+y + im = Image.fromarray(all_imgs[idx]) + + dx = x*514 + dy = y*257 + + all_in_one.paste(im, (dx,dy)) + + all_in_one.save("all_in_one.jpg") \ No newline at end of file diff --git a/training/dataset/utils/color_transfer.py b/training/dataset/utils/color_transfer.py new file mode 100644 index 0000000000000000000000000000000000000000..0845dbcf8221b400218411f11968eb4156233022 --- /dev/null +++ b/training/dataset/utils/color_transfer.py @@ -0,0 +1,516 @@ +import cv2 +import numpy as np +from numpy import linalg as npla + +import scipy as sp +import scipy.sparse +from scipy.sparse.linalg import spsolve + + +def color_transfer_sot(src, trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_sigmaV=5.0): + """ + Color Transform via Sliced Optimal Transfer + ported by @iperov from https://github.com/dcoeurjo/OTColorTransfer + + src - any float range any channel image + dst - any float range any channel image, same shape as src + steps - number of solver steps + batch_size - solver batch size + reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0 + reg_sigmaV - sigmaV of filter + + return value - clip it manually + """ + if not np.issubdtype(src.dtype, np.floating): + raise ValueError("src value must be float") + if not np.issubdtype(trg.dtype, np.floating): + raise ValueError("trg value must be float") + + if len(src.shape) != 3: + raise ValueError("src shape must have rank 3 (h,w,c)") + + if src.shape != trg.shape: + raise ValueError("src and trg shapes must be equal") + + src_dtype = src.dtype + h, w, c = src.shape + new_src = src.copy() + + for step in range(steps): + advect = np.zeros((h*w, c), dtype=src_dtype) + for batch in range(batch_size): + dir = np.random.normal(size=c).astype(src_dtype) + dir /= npla.norm(dir) + + projsource = np.sum(new_src*dir, axis=-1).reshape((h*w)) + projtarget = np.sum(trg*dir, axis=-1).reshape((h*w)) + + idSource = np.argsort(projsource) + idTarget = np.argsort(projtarget) + + a = projtarget[idTarget]-projsource[idSource] + for i_c in range(c): + advect[idSource, i_c] += a * dir[i_c] + new_src += advect.reshape((h, w, c)) / batch_size + + if reg_sigmaXY != 0.0: + src_diff = new_src-src + src_diff_filt = cv2.bilateralFilter( + src_diff, 0, reg_sigmaV, reg_sigmaXY) + if len(src_diff_filt.shape) == 2: + src_diff_filt = src_diff_filt[..., None] + new_src = src + src_diff_filt + return new_src + + +def color_transfer_mkl(x0, x1): + eps = np.finfo(float).eps + + h, w, c = x0.shape + h1, w1, c1 = x1.shape + + x0 = x0.reshape((h*w, c)) + x1 = x1.reshape((h1*w1, c1)) + + a = np.cov(x0.T) + b = np.cov(x1.T) + + Da2, Ua = np.linalg.eig(a) + Da = np.diag(np.sqrt(Da2.clip(eps, None))) + + C = np.dot(np.dot(np.dot(np.dot(Da, Ua.T), b), Ua), Da) + + Dc2, Uc = np.linalg.eig(C) + Dc = np.diag(np.sqrt(Dc2.clip(eps, None))) + + Da_inv = np.diag(1./(np.diag(Da))) + + t = np.dot( + np.dot(np.dot(np.dot(np.dot(np.dot(Ua, Da_inv), Uc), Dc), Uc.T), Da_inv), Ua.T) + + mx0 = np.mean(x0, axis=0) + mx1 = np.mean(x1, axis=0) + + result = np.dot(x0-mx0, t) + mx1 + return np.clip(result.reshape((h, w, c)).astype(x0.dtype), 0, 1) + + +def color_transfer_idt(i0, i1, bins=256, n_rot=20): + relaxation = 1 / n_rot + h, w, c = i0.shape + h1, w1, c1 = i1.shape + + i0 = i0.reshape((h*w, c)) + i1 = i1.reshape((h1*w1, c1)) + + n_dims = c + + d0 = i0.T + d1 = i1.T + + for i in range(n_rot): + + r = sp.stats.special_ortho_group.rvs(n_dims).astype(np.float32) + + d0r = np.dot(r, d0) + d1r = np.dot(r, d1) + d_r = np.empty_like(d0) + + for j in range(n_dims): + + lo = min(d0r[j].min(), d1r[j].min()) + hi = max(d0r[j].max(), d1r[j].max()) + + p0r, edges = np.histogram(d0r[j], bins=bins, range=[lo, hi]) + p1r, _ = np.histogram(d1r[j], bins=bins, range=[lo, hi]) + + cp0r = p0r.cumsum().astype(np.float32) + cp0r /= cp0r[-1] + + cp1r = p1r.cumsum().astype(np.float32) + cp1r /= cp1r[-1] + + f = np.interp(cp0r, cp1r, edges[1:]) + + d_r[j] = np.interp(d0r[j], edges[1:], f, left=0, right=bins) + + d0 = relaxation * np.linalg.solve(r, (d_r - d0r)) + d0 + + return np.clip(d0.T.reshape((h, w, c)).astype(i0.dtype), 0, 1) + + +def laplacian_matrix(n, m): + mat_D = scipy.sparse.lil_matrix((m, m)) + mat_D.setdiag(-1, -1) + mat_D.setdiag(4) + mat_D.setdiag(-1, 1) + mat_A = scipy.sparse.block_diag([mat_D] * n).tolil() + mat_A.setdiag(-1, 1*m) + mat_A.setdiag(-1, -1*m) + return mat_A + + +def seamless_clone(source, target, mask): + h, w, c = target.shape + result = [] + + mat_A = laplacian_matrix(h, w) + laplacian = mat_A.tocsc() + + mask[0, :] = 1 + mask[-1, :] = 1 + mask[:, 0] = 1 + mask[:, -1] = 1 + q = np.argwhere(mask == 0) + + k = q[:, 1]+q[:, 0]*w + mat_A[k, k] = 1 + mat_A[k, k + 1] = 0 + mat_A[k, k - 1] = 0 + mat_A[k, k + w] = 0 + mat_A[k, k - w] = 0 + + mat_A = mat_A.tocsc() + mask_flat = mask.flatten() + for channel in range(c): + + source_flat = source[:, :, channel].flatten() + target_flat = target[:, :, channel].flatten() + + mat_b = laplacian.dot(source_flat)*0.75 + mat_b[mask_flat == 0] = target_flat[mask_flat == 0] + + x = spsolve(mat_A, mat_b).reshape((h, w)) + result.append(x) + + return np.clip(np.dstack(result), 0, 1) + + +def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, source_mask=None, target_mask=None): + """ + Transfers the color distribution from the source to the target + image using the mean and standard deviations of the L*a*b* + color space. + + This implementation is (loosely) based on to the "Color Transfer + between Images" paper by Reinhard et al., 2001. + + Parameters: + ------- + source: NumPy array + OpenCV image in BGR color space (the source image) + target: NumPy array + OpenCV image in BGR color space (the target image) + clip: Should components of L*a*b* image be scaled by np.clip before + converting back to BGR color space? + If False then components will be min-max scaled appropriately. + Clipping will keep target image brightness truer to the input. + Scaling will adjust image brightness to avoid washed out portions + in the resulting color transfer that can be caused by clipping. + preserve_paper: Should color transfer strictly follow methodology + layed out in original paper? The method does not always produce + aesthetically pleasing results. + If False then L*a*b* components will scaled using the reciprocal of + the scaling factor proposed in the paper. This method seems to produce + more consistently aesthetically pleasing results + + Returns: + ------- + transfer: NumPy array + OpenCV image (w, h, 3) NumPy array (uint8) + """ + + # convert the images from the RGB to L*ab* color space, being + # sure to utilizing the floating point data type (note: OpenCV + # expects floats to be 32-bit, so use that instead of 64-bit) + source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32) + target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32) + + # compute color statistics for the source and target images + src_input = source if source_mask is None else source*source_mask + tgt_input = target if target_mask is None else target*target_mask + (lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, + bMeanSrc, bStdSrc) = lab_image_stats(src_input) + (lMeanTar, lStdTar, aMeanTar, aStdTar, + bMeanTar, bStdTar) = lab_image_stats(tgt_input) + + # subtract the means from the target image + (l, a, b) = cv2.split(target) + l -= lMeanTar + a -= aMeanTar + b -= bMeanTar + + if preserve_paper: + # scale by the standard deviations using paper proposed factor + l = (lStdTar / lStdSrc) * l + a = (aStdTar / aStdSrc) * a + b = (bStdTar / bStdSrc) * b + else: + # scale by the standard deviations using reciprocal of paper proposed factor + l = (lStdSrc / lStdTar) * l + a = (aStdSrc / aStdTar) * a + b = (bStdSrc / bStdTar) * b + + # add in the source mean + l += lMeanSrc + a += aMeanSrc + b += bMeanSrc + + # clip/scale the pixel intensities to [0, 255] if they fall + # outside this range + l = _scale_array(l, clip=clip) + a = _scale_array(a, clip=clip) + b = _scale_array(b, clip=clip) + + # merge the channels together and convert back to the RGB color + # space, being sure to utilize the 8-bit unsigned integer data + # type + transfer = cv2.merge([l, a, b]) + transfer = cv2.cvtColor(transfer.astype(np.uint8), cv2.COLOR_LAB2BGR) + + # return the color transferred image + return transfer + + +def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5): + ''' + Matches the colour distribution of the target image to that of the source image + using a linear transform. + Images are expected to be of form (w,h,c) and float in [0,1]. + Modes are chol, pca or sym for different choices of basis. + ''' + mu_t = target_img.mean(0).mean(0) + t = target_img - mu_t + t = t.transpose(2, 0, 1).reshape(t.shape[-1], -1) + Ct = t.dot(t.T) / t.shape[1] + eps * np.eye(t.shape[0]) + mu_s = source_img.mean(0).mean(0) + s = source_img - mu_s + s = s.transpose(2, 0, 1).reshape(s.shape[-1], -1) + Cs = s.dot(s.T) / s.shape[1] + eps * np.eye(s.shape[0]) + if mode == 'chol': + chol_t = np.linalg.cholesky(Ct) + chol_s = np.linalg.cholesky(Cs) + ts = chol_s.dot(np.linalg.inv(chol_t)).dot(t) + if mode == 'pca': + eva_t, eve_t = np.linalg.eigh(Ct) + Qt = eve_t.dot(np.sqrt(np.diag(eva_t))).dot(eve_t.T) + eva_s, eve_s = np.linalg.eigh(Cs) + Qs = eve_s.dot(np.sqrt(np.diag(eva_s))).dot(eve_s.T) + ts = Qs.dot(np.linalg.inv(Qt)).dot(t) + if mode == 'sym': + eva_t, eve_t = np.linalg.eigh(Ct) + Qt = eve_t.dot(np.sqrt(np.diag(eva_t))).dot(eve_t.T) + Qt_Cs_Qt = Qt.dot(Cs).dot(Qt) + eva_QtCsQt, eve_QtCsQt = np.linalg.eigh(Qt_Cs_Qt) + QtCsQt = eve_QtCsQt.dot(np.sqrt(np.diag(eva_QtCsQt))).dot(eve_QtCsQt.T) + ts = np.linalg.inv(Qt).dot(QtCsQt).dot(np.linalg.inv(Qt)).dot(t) + matched_img = ts.reshape( + *target_img.transpose(2, 0, 1).shape).transpose(1, 2, 0) + matched_img += mu_s + matched_img[matched_img > 1] = 1 + matched_img[matched_img < 0] = 0 + return np.clip(matched_img.astype(source_img.dtype), 0, 1) + + +def lab_image_stats(image): + # compute the mean and standard deviation of each channel + (l, a, b) = cv2.split(image) + (lMean, lStd) = (l.mean(), l.std()) + (aMean, aStd) = (a.mean(), a.std()) + (bMean, bStd) = (b.mean(), b.std()) + + # return the color statistics + return (lMean, lStd, aMean, aStd, bMean, bStd) + + +def _scale_array(arr, clip=True): + if clip: + return np.clip(arr, 0, 255) + + mn = arr.min() + mx = arr.max() + scale_range = (max([mn, 0]), min([mx, 255])) + + if mn < scale_range[0] or mx > scale_range[1]: + return (scale_range[1] - scale_range[0]) * (arr - mn) / (mx - mn) + scale_range[0] + + return arr + + +def channel_hist_match(source, template, hist_match_threshold=255, mask=None): + # Code borrowed from: + # https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x + masked_source = source + masked_template = template + + if mask is not None: + masked_source = source * mask + masked_template = template * mask + + oldshape = source.shape + source = source.ravel() + template = template.ravel() + masked_source = masked_source.ravel() + masked_template = masked_template.ravel() + s_values, bin_idx, s_counts = np.unique(source, return_inverse=True, + return_counts=True) + t_values, t_counts = np.unique(template, return_counts=True) + + s_quantiles = np.cumsum(s_counts).astype(np.float64) + s_quantiles = hist_match_threshold * s_quantiles / s_quantiles[-1] + t_quantiles = np.cumsum(t_counts).astype(np.float64) + t_quantiles = 255 * t_quantiles / t_quantiles[-1] + interp_t_values = np.interp(s_quantiles, t_quantiles, t_values) + + return interp_t_values[bin_idx].reshape(oldshape) + + +def color_hist_match(src_im, tar_im, hist_match_threshold=255, mask=None): + h, w, c = src_im.shape + matched_R = channel_hist_match( + src_im[:, :, 0], tar_im[:, :, 0], hist_match_threshold, mask) + matched_G = channel_hist_match( + src_im[:, :, 1], tar_im[:, :, 1], hist_match_threshold, mask) + matched_B = channel_hist_match( + src_im[:, :, 2], tar_im[:, :, 2], hist_match_threshold, mask) + + to_stack = (matched_R, matched_G, matched_B) + for i in range(3, c): + to_stack += (src_im[:, :, i],) + + matched = np.stack(to_stack, axis=-1).astype(src_im.dtype) + return matched + + +def color_transfer_mix(img_src, img_trg): + img_src = np.clip(img_src*255.0, 0, 255).astype(np.uint8) + img_trg = np.clip(img_trg*255.0, 0, 255).astype(np.uint8) + + img_src_lab = cv2.cvtColor(img_src, cv2.COLOR_BGR2LAB) + img_trg_lab = cv2.cvtColor(img_trg, cv2.COLOR_BGR2LAB) + + rct_light = np.clip(linear_color_transfer(img_src_lab[..., 0:1].astype(np.float32)/255.0, + img_trg_lab[..., 0:1].astype(np.float32)/255.0)[..., 0]*255.0, + 0, 255).astype(np.uint8) + + img_src_lab[..., 0] = (np.ones_like(rct_light)*100).astype(np.uint8) + img_src_lab = cv2.cvtColor(img_src_lab, cv2.COLOR_LAB2BGR) + + img_trg_lab[..., 0] = (np.ones_like(rct_light)*100).astype(np.uint8) + img_trg_lab = cv2.cvtColor(img_trg_lab, cv2.COLOR_LAB2BGR) + + img_rct = color_transfer_sot(img_src_lab.astype( + np.float32), img_trg_lab.astype(np.float32)) + img_rct = np.clip(img_rct, 0, 255).astype(np.uint8) + + img_rct = cv2.cvtColor(img_rct, cv2.COLOR_BGR2LAB) + img_rct[..., 0] = rct_light + img_rct = cv2.cvtColor(img_rct, cv2.COLOR_LAB2BGR) + + return (img_rct / 255.0).astype(np.float32) + + +def colorTransfer_fs(src_, dst_, mask): + src = dst_ + dst = src_ + transferredDst = np.copy(dst) + # indeksy nie czarnych pikseli maski + maskIndices = np.where(mask != 0) + # src[maskIndices[0], maskIndices[1]] zwraca piksele w nie czarnym obszarze maski + + maskedSrc = src[maskIndices[0], maskIndices[1]].astype(np.int32) + maskedDst = dst[maskIndices[0], maskIndices[1]].astype(np.int32) + + meanSrc = np.mean(maskedSrc, axis=0) + meanDst = np.mean(maskedDst, axis=0) + + maskedDst = maskedDst - meanDst + maskedDst = maskedDst + meanSrc + maskedDst = np.clip(maskedDst, 0, 255) + + transferredDst[maskIndices[0], maskIndices[1]] = maskedDst + return transferredDst + +def colorTransfer_avg(img_src, img_tgt, mask=None): + img_new = img_src.copy() + img_old = img_tgt.copy() + # print(mask) + if mask is not None: + img_new = (img_new*mask)#.astype(np.uint8) + img_old = (img_old*mask)#.astype(np.uint8) + # cv2.imshow('tgt', img_old) + w,h,c = img_new.shape + for i in range(img_new.shape[2]): + old_avg = img_old[:, :, i].mean() + new_avg = img_new[:, :, i].mean() + diff_int = old_avg - new_avg + # print(diff_int) + for m in range(img_new.shape[0]): + for n in range(img_new.shape[1]): + temp = img_new[m,n,i] + diff_int + temp = max(0., temp) + temp = min(1., temp) + # print(img_new[m,n,i], temp) + img_new[m,n,i] = temp + + return img_new + + + +def color_transfer(ct_mode, img_src, img_trg, mask): + """ + color transfer for [0,1] float32 inputs + """ + img_src = img_src.astype(dtype=np.float32) / 255.0 + img_trg = img_trg.astype(dtype=np.float32) / 255.0 + + if ct_mode == 'lct': + out = linear_color_transfer(img_src, img_trg) + elif ct_mode == 'rct': + out = reinhard_color_transfer(np.clip(img_src*255, 0, 255).astype(np.uint8), + np.clip(img_trg*255, 0, + 255).astype(np.uint8), + preserve_paper=np.random.rand() < 0.5, + clip=np.random.rand() < 0.5) + out = np.clip(out.astype(np.float32) / 255.0, 0.0, 1.0) + elif ct_mode == 'rct-m': + out = reinhard_color_transfer(np.clip(img_src*255, 0, 255).astype(np.uint8), + np.clip(img_trg*255, 0, + 255).astype(np.uint8), + source_mask=mask, target_mask=mask) + #preserve_paper=np.random.rand() < 0.5, + #clip=np.random.rand() < 0.5) + out = np.clip(out.astype(np.float32) / 255.0, 0.0, 1.0) + elif ct_mode == 'rct-fs': + out = colorTransfer_fs(np.clip(img_src*255, 0, 255).astype(np.uint8), + np.clip(img_trg*255, 0, 255).astype(np.uint8), mask) + out = np.clip(out.astype(np.float32) / 255.0, 0.0, 1.0) + elif ct_mode == 'mkl': + out = color_transfer_mkl(img_src, img_trg) + elif ct_mode == 'mkl-m': + out = color_transfer_mkl(img_src*mask, img_trg*mask) + elif ct_mode == 'idt': + out = color_transfer_idt(img_src, img_trg) + elif ct_mode == 'idt-m': + out = color_transfer_idt(img_src*mask, img_trg*mask) + elif ct_mode == 'sot': + out = color_transfer_sot(img_src, img_trg) + out = np.clip(out, 0.0, 1.0) + elif ct_mode == 'sot-m': + out = color_transfer_sot( + (img_src*mask).astype(np.float32), (img_trg*mask).astype(np.float32)) + out = np.clip(out, 0.0, 1.0) + elif ct_mode == 'mix-m': + out = color_transfer_mix(img_src*mask, img_trg*mask) + elif ct_mode == 'seamless-hist-match': + out = color_hist_match(img_src, img_trg) + elif ct_mode == 'seamless-hist-match-m': + out = color_hist_match(img_src, img_trg, mask=mask) + elif ct_mode == 'avg-align': + out = colorTransfer_avg(img_src, img_trg, mask=mask) + out = np.clip(out, 0.0, 1.0) + else: + raise ValueError(f"unknown ct_mode {ct_mode}") + + out = np.clip(out*255, 0, 255).astype(np.uint8) + return out \ No newline at end of file diff --git a/training/dataset/utils/face_align.py b/training/dataset/utils/face_align.py new file mode 100644 index 0000000000000000000000000000000000000000..062eee5a6f13d8ae915055923375a476af9afc2c --- /dev/null +++ b/training/dataset/utils/face_align.py @@ -0,0 +1,173 @@ +import numpy + +from .umeyama import umeyama +from numpy.linalg import inv +import cv2 + +mean_face_x = numpy.array([ +0.000213256, 0.0752622, 0.18113, 0.29077, 0.393397, 0.586856, 0.689483, 0.799124, +0.904991, 0.98004, 0.490127, 0.490127, 0.490127, 0.490127, 0.36688, 0.426036, +0.490127, 0.554217, 0.613373, 0.121737, 0.187122, 0.265825, 0.334606, 0.260918, +0.182743, 0.645647, 0.714428, 0.793132, 0.858516, 0.79751, 0.719335, 0.254149, +0.340985, 0.428858, 0.490127, 0.551395, 0.639268, 0.726104, 0.642159, 0.556721, +0.490127, 0.423532, 0.338094, 0.290379, 0.428096, 0.490127, 0.552157, 0.689874, +0.553364, 0.490127, 0.42689 ]) + +mean_face_y = numpy.array([ +0.106454, 0.038915, 0.0187482, 0.0344891, 0.0773906, 0.0773906, 0.0344891, +0.0187482, 0.038915, 0.106454, 0.203352, 0.307009, 0.409805, 0.515625, 0.587326, +0.609345, 0.628106, 0.609345, 0.587326, 0.216423, 0.178758, 0.179852, 0.231733, +0.245099, 0.244077, 0.231733, 0.179852, 0.178758, 0.216423, 0.244077, 0.245099, +0.780233, 0.745405, 0.727388, 0.742578, 0.727388, 0.745405, 0.780233, 0.864805, +0.902192, 0.909281, 0.902192, 0.864805, 0.784792, 0.778746, 0.785343, 0.778746, +0.784792, 0.824182, 0.831803, 0.824182 ]) + +landmarks_2D = numpy.stack( [ mean_face_x, mean_face_y ], axis=1 ) + +def get_align_mat(face, size, should_align_eyes): + mat_umeyama = umeyama(numpy.array(face.landmarks_as_xy()[17:]), landmarks_2D, True)[0:2] + + if should_align_eyes is False: + return mat_umeyama + + mat_umeyama = mat_umeyama * size + + # Convert to matrix + landmarks = numpy.matrix(face.landmarks_as_xy()) + + # cv2 expects points to be in the form np.array([ [[x1, y1]], [[x2, y2]], ... ]), we'll expand the dim + landmarks = numpy.expand_dims(landmarks, axis=1) + + # Align the landmarks using umeyama + umeyama_landmarks = cv2.transform(landmarks, mat_umeyama, landmarks.shape) + + # Determine a rotation matrix to align eyes horizontally + mat_align_eyes = align_eyes(umeyama_landmarks, size) + + # Extend the 2x3 transform matrices to 3x3 so we can multiply them + # and combine them as one + mat_umeyama = numpy.matrix(mat_umeyama) + mat_umeyama.resize((3, 3)) + mat_align_eyes = numpy.matrix(mat_align_eyes) + mat_align_eyes.resize((3, 3)) + mat_umeyama[2] = mat_align_eyes[2] = [0, 0, 1] + + # Combine the umeyama transform with the extra rotation matrix + transform_mat = mat_align_eyes * mat_umeyama + + # Remove the extra row added, shape needs to be 2x3 + transform_mat = numpy.delete(transform_mat, 2, 0) + transform_mat = transform_mat / size + return transform_mat + + +from .face_blend import get_5_keypoint + +def get_align_mat_new(src_lmk, tgt_lmk, size=256, should_align_eyes=False): + mat_umeyama = umeyama(get_5_keypoint(src_lmk), get_5_keypoint(tgt_lmk), True)[0:2] + # mat_umeyama = umeyama(numpy.array(src_lmk[17:]), numpy.array(tgt_lmk[17:]), True)[0:2] + + if should_align_eyes is False: + return mat_umeyama + + mat_umeyama = mat_umeyama * size + + # Convert to matrix + landmarks = numpy.matrix(face.landmarks_as_xy()) + + # cv2 expects points to be in the form np.array([ [[x1, y1]], [[x2, y2]], ... ]), we'll expand the dim + landmarks = numpy.expand_dims(landmarks, axis=1) + + # Align the landmarks using umeyama + umeyama_landmarks = cv2.transform(landmarks, mat_umeyama, landmarks.shape) + + # Determine a rotation matrix to align eyes horizontally + mat_align_eyes = align_eyes(umeyama_landmarks, size) + + # Extend the 2x3 transform matrices to 3x3 so we can multiply them + # and combine them as one + mat_umeyama = numpy.matrix(mat_umeyama) + mat_umeyama.resize((3, 3)) + mat_align_eyes = numpy.matrix(mat_align_eyes) + mat_align_eyes.resize((3, 3)) + mat_umeyama[2] = mat_align_eyes[2] = [0, 0, 1] + + # Combine the umeyama transform with the extra rotation matrix + transform_mat = mat_align_eyes * mat_umeyama + + # Remove the extra row added, shape needs to be 2x3 + transform_mat = numpy.delete(transform_mat, 2, 0) + transform_mat = transform_mat / size + return transform_mat + +# Code borrowed from https://github.com/jrosebr1/imutils/blob/d5cb29d02cf178c399210d5a139a821dfb0ae136/imutils/face_utils/helpers.py +""" +The MIT License (MIT) + +Copyright (c) 2015-2016 Adrian Rosebrock, http://www.pyimagesearch.com + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from collections import OrderedDict +import numpy as np +import cv2 + +# define a dictionary that maps the indexes of the facial +# landmarks to specific face regions +FACIAL_LANDMARKS_IDXS = OrderedDict([ + ("mouth", (48, 68)), + ("right_eyebrow", (17, 22)), + ("left_eyebrow", (22, 27)), + ("right_eye", (36, 42)), + ("left_eye", (42, 48)), + ("nose", (27, 36)), + ("jaw", (0, 17)), + ("chin", (8, 11)) +]) + +# Returns a rotation matrix that when applied to the 68 input facial landmarks +# results in landmarks with eyes aligned horizontally +def align_eyes(landmarks, size): + desiredLeftEye = (0.35, 0.35) # (y, x) value + desiredFaceWidth = desiredFaceHeight = size + + # extract the left and right eye (x, y)-coordinates + (lStart, lEnd) = FACIAL_LANDMARKS_IDXS["left_eye"] + (rStart, rEnd) = FACIAL_LANDMARKS_IDXS["right_eye"] + leftEyePts = landmarks[lStart:lEnd] + rightEyePts = landmarks[rStart:rEnd] + + # compute the center of mass for each eye + leftEyeCenter = leftEyePts.mean(axis=0).astype("int") + rightEyeCenter = rightEyePts.mean(axis=0).astype("int") + + # compute the angle between the eye centroids + dY = rightEyeCenter[0,1] - leftEyeCenter[0,1] + dX = rightEyeCenter[0,0] - leftEyeCenter[0,0] + angle = np.degrees(np.arctan2(dY, dX)) - 180 + + # compute center (x, y)-coordinates (i.e., the median point) + # between the two eyes in the input image + eyesCenter = ((leftEyeCenter[0,0] + rightEyeCenter[0,0]) // 2, (leftEyeCenter[0,1] + rightEyeCenter[0,1]) // 2) + + # grab the rotation matrix for rotating and scaling the face + M = cv2.getRotationMatrix2D(eyesCenter, angle, 1.0) + + return M \ No newline at end of file diff --git a/training/dataset/utils/face_aug.py b/training/dataset/utils/face_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..3e116af4bcd3dc9fb1821652ae8d4afe5ce9e6de --- /dev/null +++ b/training/dataset/utils/face_aug.py @@ -0,0 +1,125 @@ +""" +Exposing DeepFake Videos By Detecting Face Warping Artifacts +Yuezun Li, Siwei Lyu +https://arxiv.org/abs/1811.00656 +""" +import cv2 +import numpy as np +from PIL import Image, ImageEnhance +# We only use opencv3 +# if not (cv2.__version__).startswith('3.'): +# raise ValueError('Only opencv 3. is supported!') + +''' +these two function is implemented by myself, may have some errors QAQ +''' + + +def change_res(img): + init_res = img.shape[0] + fake_res = np.random.randint(init_res//4, init_res*2) + img = cv2.resize(img, (fake_res, fake_res)) + img = cv2.resize(img, (init_res, init_res)) + return img, fake_res + + +def aug_one_im(img, + random_transform_args=None, + color_rng=[0.9, 1.1]): + """ + Augment operation for image list + :param images: image list + :param random_transform_args: shape transform arguments + :param color_rng: color transform arguments + :return: + """ + images = [img] + images = aug(images, random_transform_args, color_rng) + + return images[0] + + +def aug(images, + random_transform_args={ + 'rotation_range': 10, + 'zoom_range': 0.05, + 'shift_range': 0.05, + 'random_flip': 0.5, + }, + color_rng=[0.9, 1.1]): + """ + Augment operation for image list + :param images: image list + :param random_transform_args: shape transform arguments + :param color_rng: color transform arguments + :return: + """ + if random_transform_args is not None: # do aug + # Transform + images = random_transform(images, **random_transform_args) + # Color + if color_rng is not None: + for i, im in enumerate(images): + # im = im[:, :, (2, 1, 0)] + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + im = Image.fromarray(np.uint8(im)) + + # Brightness + factor = np.random.uniform(color_rng[0], color_rng[1]) + enhancer = ImageEnhance.Brightness(im) + im = enhancer.enhance(factor) + # Contrast + factor = np.random.uniform(color_rng[0], color_rng[1]) + enhancer = ImageEnhance.Contrast(im) + im = enhancer.enhance(factor) + # Color distort + factor = np.random.uniform(color_rng[0], color_rng[1]) + enhancer = ImageEnhance.Color(im) + im = enhancer.enhance(factor) + + # Sharpe + factor = np.random.uniform(color_rng[0], color_rng[1]) + enhancer = ImageEnhance.Sharpness(im) + im = enhancer.enhance(factor) + im = np.array(im).astype(np.uint8) + # im = im[:, :, (2, 1, 0)] + im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR) + images[i] = im.copy() + + return images + + +def random_transform(images, rotation_range, zoom_range, shift_range, random_flip): + """ + Random transform images in a list + :param images: + :param rotation_range: + :param zoom_range: + :param shift_range: + :param random_flip: + :return: + """ + h, w = images[0].shape[:2] + rotation = np.random.uniform(-rotation_range, rotation_range) + scale = np.random.uniform(1 - zoom_range, 1 + zoom_range) + tx = np.random.uniform(-shift_range, shift_range) * w + ty = np.random.uniform(-shift_range, shift_range) * h + flip_prob = np.random.random() + for i, image in enumerate(images): + mat = cv2.getRotationMatrix2D((w / 2, h / 2), rotation, scale) + mat[:, 2] += (tx, ty) + result = cv2.warpAffine( + image, mat, (w, h), borderMode=cv2.BORDER_REPLICATE) + if flip_prob < random_flip: + result = result[:, ::-1] + images[i] = result.copy() + return images + + +if __name__ == "__main__": + dirr = '/FaceXray/dataset/utils/' + test_im = cv2.imread('{}test.png'.format(dirr)) + resample_res, fake_res = change_res(test_im) + cv2.imwrite('{}res_{}.png'.format(dirr, fake_res), resample_res) + aug_im = aug_one_im(test_im) + cv2.imwrite('{}auged.png'.format(dirr), aug_im) diff --git a/training/dataset/utils/face_blend.py b/training/dataset/utils/face_blend.py new file mode 100644 index 0000000000000000000000000000000000000000..1b12657f80e8d2305df3e9f86cc402ee2469924e --- /dev/null +++ b/training/dataset/utils/face_blend.py @@ -0,0 +1,469 @@ +''' +Create face mask and face boundary mask according to face landmarks, +so as to supervize the activation of Conv layer. +''' + +import os +import numpy as np +import cv2 +import dlib +import random +import argparse +from tqdm import tqdm +import time +from skimage import transform as trans +# from color_transfer import color_transfer +from .warp import gen_warp_params, warp_by_params, warp_mask + + +def crop_img_bbox(img, bbox, res, scale=1.3): + x, y, w, h = bbox + left, right = x, x+w + top, bottom = y, y+h + + H, W, C = img.shape + cx, cy = (left+right)//2, (top+bottom)//2 + w, h = (right-left)//2, (bottom-top)//2 + + x1 = max(0, int(cx-w*scale)) + x2 = min(W, int(cx+w*scale)) + y1 = max(0, int(cy-h*scale)) + y2 = min(H, int(cy+h*scale)) + + roi = img[y1:y2, x1:x2] + roi = cv2.resize(roi, (res, res)) + + return roi + + +def get_mask_center(mask): + l, t, w, h = cv2.boundingRect(mask[:, :, 0:1].astype(np.uint8)) + center = int(l+w/2), int(t+h/2) + return center + + +def get_5_keypoint(shape): + def get_point(idx): + # return [shape.part(idx).x, shape.part(idx).y] + return shape[idx] + + def center(pt1, pt2): + return [(pt1[0]+pt2[0])//2, (pt1[1]+pt2[1])//2] + + leye = np.array(center(get_point(36), get_point(39)), + dtype=int).reshape(-1, 2) + reye = np.array(center(get_point(45), get_point(42)), + dtype=int).reshape(-1, 2) + nose = np.array(get_point(30), dtype=int).reshape(-1, 2) + lmouth = np.array(get_point(48), + dtype=int).reshape(-1, 2) + rmouth = np.array(get_point(54), + dtype=int).reshape(-1, 2) + + pts = np.concatenate([leye, reye, nose, lmouth, rmouth], axis=0) + + return pts + + +def get_boundary(mask): + if len(mask.shape) == 3: + mask = mask[:, :, 0] + mask = cv2.GaussianBlur(mask, (3, 3), 0) + boundary = mask / 255. + boundary = 4*boundary*(1.-boundary) + return boundary + + +# def get_boundary(mask): +# if len(mask.shape) == 3: +# mask = mask[:, :, 0] +# mask = cv2.GaussianBlur(mask, (3, 3), 0) +# mask = mask.astype(np.uint8) + +# # Dilation and Erosion to find the boundary +# dilated = cv2.dilate(mask, None, iterations=1) +# boundary = cv2.subtract(dilated, mask) + +# # normalize the boundary to have values between 0 and 1 +# boundary = boundary / 255. + +# return boundary + + + +def blur_mask(mask): + blur_k = 2*np.random.randint(1, 10)-1 + + #kernel = np.ones((blur_k+1, blur_k+1), np.uint8) + #mask = cv2.erode(mask, kernel) + + mask = cv2.GaussianBlur(mask, (blur_k, blur_k), 0) + + + return mask + + +def random_deform(pt, tgt, scale=0.3): + x1, y1 = pt + x2, y2 = tgt + + x = x1+(x2-x1)*np.random.rand()*scale + y = y1+(y2-y1)*np.random.rand()*scale + #print('before:', pt, ' after:', [int(x), int(y)]) + return [int(x), int(y)] + + +def get_specific_mask(img, shape, mtype='mouth', random_side=False): + if mtype == 'eyes': + landmarks = shape[42:45] if random.choice([True, False]) else shape[36:39] + + elif mtype == 'nose': + landmarks = shape[27:35] + + elif mtype == 'mouth': + landmarks = shape[48:60] + + elif mtype == 'eyebrows': + landmarks = shape[22:26] if random.choice([True, False]) else shape[17:21] + + else: + raise ValueError(f"Invalid mtype. Choose from 'eyes', 'nose', 'mouth', or 'eyebrows', but got {mtype}") + + # find convex hull + hull = cv2.convexHull(landmarks) + hull = hull.astype(int) + + # mask + hull_mask = np.zeros_like(img) + cv2.fillPoly(hull_mask, [hull], (255, 255, 255)) + mask = hull_mask + return mask + + +def get_hull_mask(img, shape, mtype='hull'): + if mtype == 'normal-hull': + landmarks = np.array(shape) + + # find convex hull + hull = cv2.convexHull(landmarks) + hull = hull.astype(int) + + # full face mask + hull_mask = np.zeros_like(img) + cv2.fillPoly(hull_mask, [hull], (255, 255, 255)) + mask = hull_mask + + elif mtype == 'inner-hull': + landmarks = shape[17:] + landmarks = np.array(landmarks) + + # find convex hull + hull = cv2.convexHull(landmarks) + hull = hull.astype(int) + + # full face mask + hull_mask = np.zeros_like(img) + cv2.fillPoly(hull_mask, [hull], (255, 255, 255)) + + mask = hull_mask + + elif mtype == 'inner-hull-no-eyebrow': + landmarks = shape[27:] + landmarks = np.array(landmarks) + # find convex hull + hull = cv2.convexHull(landmarks) + hull = hull.astype(int) + + # full face mask + hull_mask = np.zeros_like(img) + cv2.fillPoly(hull_mask, [hull], (255, 255, 255)) + + mask = hull_mask + + elif mtype == 'mouth-hull': + landmarks = shape[2:15] + #landmarks.append(shape[29]) + landmarks = np.concatenate([landmarks, shape[29].reshape(1, -1)], axis=0) + + # find convex hull + hull = cv2.convexHull(landmarks) + hull = hull.astype(int) + + # full face mask + hull_mask = np.zeros_like(img) + cv2.fillPoly(hull_mask, [hull], (255, 255, 255)) + + # kernel = np.ones((2, 2), np.uint8) + # c_mask = cv2.dilate(hull_mask, kernel, iterations=1) + mask = hull_mask + + elif mtype == 'whole-hull': + face_height = shape[9][1] - shape[22][1] + landmarks = [] + for i in range(27): + lmk = shape[i] + if i >= 5 and i <= 11: + x, y = lmk[0], lmk[1] + lmk = [x, max(0, y+15)] + # lift the eyebrows to get a larger landmark convex hull + if i >= 18 and i <= 27: + x, y = lmk[0], lmk[1] + lmk = [x, max(0, y-face_height//4)] + + landmarks.append(lmk) + + # find convex hull + landmarks = np.array(landmarks, dtype=np.int32) + hull = cv2.convexHull(landmarks) + hull = np.reshape(hull, (1, -1, 2)) + + # full face mask + hull_mask = np.zeros_like(img) + cv2.fillPoly(hull_mask, [hull], (255, 255, 255)) + + # kernel = np.ones((2, 2), np.uint8) + # c_mask = cv2.dilate(hull_mask, kernel, iterations=1) + mask = hull_mask + ''' + elif mtype == 'rect': + cnt = [] + for idx in [5, 11, 17, 26]: + cnt.append(shape[idx]) + x, y, w, h = cv2.boundingRect(np.array(cnt)) + rect_mask = np.zeros_like(img) + cv2.rectangle(rect_mask, (x, y), (x+w, y+h), + (255, 255, 255), cv2.FILLED) + mask = rect_mask + ''' + return mask + + +def get_mask(shape, img, std=20, deform=True, restrict_mask=None): + mask_type = [ + 'normal-hull', + 'inner-hull', + 'inner-hull-no-eyebrow', + 'mouth-hull', + 'whole-hull' + ] + max_mask = get_hull_mask(img, shape, 'whole-hull') + if deform: + mtype = mask_type[np.random.randint(len(mask_type))] + if mtype == 'rect': + mask = get_hull_mask(img, shape, 'inner-hull-no-eyebrow') + x, y, w, h = cv2.boundingRect(mask[:,:,0]) + for i in range(y, y+h): + for j in range(x, x+w): + for k in range(mask.shape[2]): + mask[i, j, k] = 255 + else: + mask = get_hull_mask(img, shape, mtype) + + # random deform + if np.random.rand() < 0.9: + mask = warp_mask(mask, std=std) + + # # random erode/dilate + # prob = np.random.rand() + # if prob < 0.3: + # erode_k = 2*np.random.randint(1, 10)+1 + # kernel = np.ones((erode_k, erode_k), np.uint8) + # mask = cv2.erode(mask, kernel) + # elif prob < 0.6: + # erode_k = 2*np.random.randint(1, 10)+1 + # kernel = np.ones((erode_k, erode_k), np.uint8) + # mask = cv2.dilate(mask, kernel) + else: + mask = max_mask.copy() + + if restrict_mask is not None: + mask = mask*(restrict_mask//255) + + # restrict mask range + mask = mask *(max_mask//255) + + # random blur + if deform and np.random.rand() < 0.9: + mask = blur_mask(mask) + + return mask[:,:,0] + +def mask_postprocess(mask): + # random erode/dilate + prob = np.random.rand() + if prob < 0.3: + erode_k = 2*np.random.randint(1, 10)+1 + kernel = np.ones((erode_k, erode_k), np.uint8) + mask = cv2.erode(mask, kernel) + elif prob < 0.6: + erode_k = 2*np.random.randint(1, 10)+1 + kernel = np.ones((erode_k, erode_k), np.uint8) + mask = cv2.dilate(mask, kernel) + + # random blur + if np.random.rand() < 0.9: + mask = blur_mask(mask) + + return mask + + +def get_affine_param(from_, to_): + # use skimage tranformation + tform = trans.SimilarityTransform() + tform.estimate(from_.astype(np.float32), to_.astype( + np.float32)) # tform.estimate(from_, to_) + M = tform.params[0:2, :] + + return M + + +def random_sharpen_img(img): + cand = ['bsharpen', 'gsharpen'] # , 'none'] + mode = cand[np.random.randint(len(cand))] + # print('sharpen mode:', mode) + if mode == "bsharpen": + # Sharpening using filter2D + kernel = np.ones((3, 3)) * (-1) + kernel[1, 1] = 9 + #kernel /= 9. + out = cv2.filter2D(img, -1, kernel) + elif mode == "gsharpen": + # Sharpening using Weighted Method + gaussain_blur = cv2.GaussianBlur(img, (0, 0), 3.0) + out = cv2.addWeighted( + img, 1.5, gaussain_blur, -0.5, 0, img) + else: + out = img + + return out + + +def random_blur_img(img): + cand = ['avg', 'gaussion', 'med'] # , 'none'] + mode = cand[np.random.randint(len(cand))] + # print('blur mode:', mode) + ksize = 2*np.random.randint(1, 5)+1 + + if mode == 'avg': + # Averaging + out = cv2.blur(img, (ksize, ksize)) + elif mode == 'gaussion': + # Gaussian Blurring + out = cv2.GaussianBlur(img, (ksize, ksize), 0) + elif mode == 'med': + # Median blurring + out = cv2.medianBlur(img, ksize) + else: + out = img + # elif mode == 'bilateral' + # # Bilateral Filtering + # out = cv2.bilateralFilter(img,9,75,75) + + return out + + +def random_warp_img(img, prob=0.5): + H, W, C = img.shape + param = gen_warp_params(W, flip=False) + choice = [True, False] + + out = warp_by_params(param, img, + can_flip=False, # choice[np.random.randint(2)], + can_transform=False, # choice[np.random.randint(2)], + can_warp=(np.random.randint(10) < int(prob*10)), + border_replicate=choice[np.random.randint(2)]) + return out + + +def main(args): + np.random.seed(int(time.time())) + detector = dlib.get_frontal_face_detector() + landmark_predictor = dlib.shape_predictor(args.model) + + src_im = cv2.imread(args.src) + tgt_im = cv2.imread(args.tgt) + + H, W, C = tgt_im.shape + src_im = cv2.resize(src_im, (W, H)) + + def get_shape(img): + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + dets = detector(img, 1) + det = dets[0] + shape = landmark_predictor(img, det) + + return shape, det + + src_shape, src_det = get_shape(src_im) + src_5_pts = get_5_keypoint(src_shape) + src_mask = get_mask(src_shape, src_im, whole=True, deform=False) + + tgt_shape, tgt_det = get_shape(tgt_im) + tgt_5_pts = get_5_keypoint(tgt_shape) + tgt_mask = get_mask(tgt_shape, tgt_im, whole=False, deform=True) + + #aff_param = get_affine_param(src_5_pts, tgt_5_pts) + + # color transfer: + mask = src_mask[:, :, 0:1]/255. + ct_modes = ['lct', 'rct', 'idt', 'idt-m', 'mkl', 'mkl-m', + 'sot', 'sot-m', 'mix-m'] # , 'seamless-hist-match'] + for mode in ct_modes: + colored_src = color_transfer(mode, src_im, tgt_im, mask) + cv2.imwrite('{}_colored.png'.format(mode), colored_src) + src_im = colored_src + + w1, h1 = src_det.right()-src_det.left(), src_det.bottom()-src_det.top() + w2, h2 = tgt_det.right()-tgt_det.left(), tgt_det.bottom()-tgt_det.top() + w_scale, h_scale = w2/w1, h2/h1 + + scaled_src = cv2.resize(src_im, (int(W*w_scale), int(H*h_scale))) + scaled_mask = cv2.resize(src_mask, (int(W*w_scale), int(H*h_scale))) + + src_5_pts[:, 0] = src_5_pts[:, 0]*w_scale + src_5_pts[:, 1] = src_5_pts[:, 1]*h_scale + aff_param = get_affine_param(src_5_pts, tgt_5_pts) + + aligned_src = cv2.warpAffine( + scaled_src, aff_param, (W, H), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REFLECT) + aligned_mask = cv2.warpAffine( + scaled_mask, aff_param, (W, H), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REFLECT) + + center = get_mask_center(aligned_mask) + print('mask center:', center) + # colored_src = transfer_color(aligned_src, tgt_im) + + init_blend = cv2.seamlessClone( + aligned_src, tgt_im, aligned_mask, center, cv2.NORMAL_CLONE) + cv2.imwrite('init_blended.png', init_blend) + + # aligned_blend = cv2.warpAffine( + # colored_blend, aff_param, (W, H), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REFLECT) + b_mask = tgt_mask[:, :, 0:1]/255. + out_blend = init_blend*b_mask + tgt_im*(1. - b_mask) + cv2.imwrite('out_blended.png', out_blend) + + res = 256 + blend_crop = crop_img_bbox(out_blend, tgt_det, res, scale=1.5) + mask_crop = crop_img_bbox(tgt_mask, tgt_det, res, scale=1.5) + boundary = get_boundary(mask_crop) + + cv2.imwrite('crop_blend.png', blend_crop) + cv2.imwrite('crop_mask.png', mask_crop) + cv2.imwrite('crop_bound.png', boundary*255) + + +if __name__ == "__main__": + p = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + p.add_argument('-s', '--src', type=str, + help='src image') + p.add_argument('-t', '--tgt', type=str, + help='tgt image') + p.add_argument('--model', type=str, default='/data1/yuchen/download/face_landmark/shape_predictor_68_face_landmarks.dat', + help="path to downloaded detector") + args = p.parse_args() + print(args) + + main(args) diff --git a/training/dataset/utils/faceswap.py b/training/dataset/utils/faceswap.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbcb1db91e3d12ab7b0cd82ab0950ef05ad2c82 --- /dev/null +++ b/training/dataset/utils/faceswap.py @@ -0,0 +1,249 @@ +''' +code from https://github.com/wuhuikai/FaceSwap/blob/master/face_swap.py +''' + + + +#! /usr/bin/env python +import cv2 +import numpy as np +import scipy.spatial as spatial +import logging + + +## 3D Transform +def bilinear_interpolate(img, coords): + """ Interpolates over every image channel + http://en.wikipedia.org/wiki/Bilinear_interpolation + :param img: max 3 channel image + :param coords: 2 x _m_ array. 1st row = xcoords, 2nd row = ycoords + :returns: array of interpolated pixels with same shape as coords + """ + int_coords = np.int32(coords) + x0, y0 = int_coords + x0[x0>254] = 254 + y0[y0>254] = 254 + dx, dy = coords - int_coords + + # 4 Neighour pixels + q11 = img[y0, x0] + q21 = img[y0, x0 + 1] + q12 = img[y0 + 1, x0] + q22 = img[y0 + 1, x0 + 1] + + btm = q21.T * dx + q11.T * (1 - dx) + top = q22.T * dx + q12.T * (1 - dx) + inter_pixel = top * dy + btm * (1 - dy) + + return inter_pixel.T + +def grid_coordinates(points): + """ x,y grid coordinates within the ROI of supplied points + :param points: points to generate grid coordinates + :returns: array of (x, y) coordinates + """ + xmin = np.min(points[:, 0]) + xmax = np.max(points[:, 0]) + 1 + ymin = np.min(points[:, 1]) + ymax = np.max(points[:, 1]) + 1 + + return np.asarray([(x, y) for y in range(ymin, ymax) + for x in range(xmin, xmax)], np.uint32) + + +def process_warp(src_img, result_img, tri_affines, dst_points, delaunay): + """ + Warp each triangle from the src_image only within the + ROI of the destination image (points in dst_points). + """ + roi_coords = grid_coordinates(dst_points) + # indices to vertices. -1 if pixel is not in any triangle + roi_tri_indices = delaunay.find_simplex(roi_coords) + + for simplex_index in range(len(delaunay.simplices)): + coords = roi_coords[roi_tri_indices == simplex_index] + num_coords = len(coords) + out_coords = np.dot(tri_affines[simplex_index], + np.vstack((coords.T, np.ones(num_coords)))) + x, y = coords.T + x[x>255] = 255 + y[y>255] = 255 + result_img[y, x] = bilinear_interpolate(src_img, out_coords) + + return None + + +def triangular_affine_matrices(vertices, src_points, dst_points): + """ + Calculate the affine transformation matrix for each + triangle (x,y) vertex from dst_points to src_points + :param vertices: array of triplet indices to corners of triangle + :param src_points: array of [x, y] points to landmarks for source image + :param dst_points: array of [x, y] points to landmarks for destination image + :returns: 2 x 3 affine matrix transformation for a triangle + """ + ones = [1, 1, 1] + for tri_indices in vertices: + #print(tri_indices) + src_tri = np.vstack((src_points[tri_indices, :].T, ones)) + dst_tri = np.vstack((dst_points[tri_indices, :].T, ones)) + mat = np.dot(src_tri, np.linalg.inv(dst_tri))[:2, :] + yield mat + + +def warp_image_3d(src_img, src_points, dst_points, dst_shape, dtype=np.uint8): + rows, cols = dst_shape[:2] + result_img = np.zeros((rows, cols, 3), dtype=dtype) + + delaunay = spatial.Delaunay(dst_points) + tri_affines = np.asarray(list(triangular_affine_matrices( + delaunay.simplices, src_points, dst_points))) + + process_warp(src_img, result_img, tri_affines, dst_points, delaunay) + + return result_img + + +## 2D Transform +def transformation_from_points(points1, points2): + points1 = points1.astype(np.float64) + points2 = points2.astype(np.float64) + + c1 = np.mean(points1, axis=0) + c2 = np.mean(points2, axis=0) + points1 -= c1 + points2 -= c2 + + s1 = np.std(points1) + s2 = np.std(points2) + points1 /= s1 + points2 /= s2 + + U, S, Vt = np.linalg.svd(np.dot(points1.T, points2)) + R = (np.dot(U, Vt)).T + + return np.vstack([np.hstack([s2 / s1 * R, + (c2.T - np.dot(s2 / s1 * R, c1.T))[:, np.newaxis]]), + np.array([[0., 0., 1.]])]) + + +def warp_image_2d(im, M, dshape): + output_im = np.zeros(dshape, dtype=im.dtype) + cv2.warpAffine(im, + M[:2], + (dshape[1], dshape[0]), + dst=output_im, + borderMode=cv2.BORDER_TRANSPARENT, + flags=cv2.WARP_INVERSE_MAP) + + return output_im + + +## Generate Mask +def mask_from_points(size, points,erode_flag=1): + radius = 10 # kernel size + kernel = np.ones((radius, radius), np.uint8) + + mask = np.zeros(size, np.uint8) + cv2.fillConvexPoly(mask, cv2.convexHull(points), 255) + if erode_flag: + mask = cv2.erode(mask, kernel,iterations=1) + + return mask + + +## Color Correction +def correct_colours(im1, im2, landmarks1): + COLOUR_CORRECT_BLUR_FRAC = 0.75 + LEFT_EYE_POINTS = list(range(42, 48)) + RIGHT_EYE_POINTS = list(range(36, 42)) + + blur_amount = COLOUR_CORRECT_BLUR_FRAC * np.linalg.norm( + np.mean(landmarks1[LEFT_EYE_POINTS], axis=0) - + np.mean(landmarks1[RIGHT_EYE_POINTS], axis=0)) + blur_amount = int(blur_amount) + if blur_amount % 2 == 0: + blur_amount += 1 + im1_blur = cv2.GaussianBlur(im1, (blur_amount, blur_amount), 0) + im2_blur = cv2.GaussianBlur(im2, (blur_amount, blur_amount), 0) + + # Avoid divide-by-zero errors. + im2_blur = im2_blur.astype(int) + im2_blur += 128*(im2_blur <= 1) + + result = im2.astype(np.float64) * im1_blur.astype(np.float64) / im2_blur.astype(np.float64) + result = np.clip(result, 0, 255).astype(np.uint8) + + return result + + +## Copy-and-paste +def apply_mask(img, mask): + """ Apply mask to supplied image + :param img: max 3 channel image + :param mask: [0-255] values in mask + :returns: new image with mask applied + """ + masked_img=cv2.bitwise_and(img,img,mask=mask) + + return masked_img + + +## Alpha blending +def alpha_feathering(src_img, dest_img, img_mask, blur_radius=15): + mask = cv2.blur(img_mask, (blur_radius, blur_radius)) + mask = mask / 255.0 + + result_img = np.empty(src_img.shape, np.uint8) + for i in range(3): + result_img[..., i] = src_img[..., i] * mask + dest_img[..., i] * (1-mask) + + return result_img + + +def check_points(img,points): + # Todo: I just consider one situation. + if points[8,1]>img.shape[0]: + logging.error("Jaw part out of image") + else: + return True + return False + + +def face_swap(src_face, dst_face, src_points, dst_points, dst_shape, dst_img, args, end=48): + h, w = dst_face.shape[:2] + + ## 3d warp + warped_src_face = warp_image_3d(src_face, src_points[:end], dst_points[:end], (h, w)) + ## Mask for blending + mask = mask_from_points((h, w), dst_points) + mask_src = np.mean(warped_src_face, axis=2) > 0 + mask = np.asarray(mask * mask_src, dtype=np.uint8) + ## Correct color + if args.correct_color: + warped_src_face = apply_mask(warped_src_face, mask) + dst_face_masked = apply_mask(dst_face, mask) + warped_src_face = correct_colours(dst_face_masked, warped_src_face, dst_points) + ## 2d warp + if args.warp_2d: + unwarped_src_face = warp_image_3d(warped_src_face, dst_points[:end], src_points[:end], src_face.shape[:2]) + warped_src_face = warp_image_2d(unwarped_src_face, transformation_from_points(dst_points, src_points), + (h, w, 3)) + + mask = mask_from_points((h, w), dst_points) + mask_src = np.mean(warped_src_face, axis=2) > 0 + mask = np.asarray(mask * mask_src, dtype=np.uint8) + + ## Shrink the mask + kernel = np.ones((10, 10), np.uint8) + mask = cv2.erode(mask, kernel, iterations=1) + ##Poisson Blending + r = cv2.boundingRect(mask) + center = ((r[0] + int(r[2] / 2), r[1] + int(r[3] / 2))) + output = cv2.seamlessClone(warped_src_face, dst_face, mask, center, cv2.NORMAL_CLONE) + + x, y, w, h = dst_shape + dst_img_cp = dst_img.copy() + dst_img_cp[y:y + h, x:x + w] = output + + return dst_img_cp diff --git a/training/dataset/utils/faceswap_utils.py b/training/dataset/utils/faceswap_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..33b944a90b015a6f329d054285aa67e927b61b32 --- /dev/null +++ b/training/dataset/utils/faceswap_utils.py @@ -0,0 +1,63 @@ +import numpy as np +import cv2 + +def AlphaBlend(foreground, background, alpha): + # Convert uint8 to float + foreground = foreground.astype(float) + background = background.astype(float) + + # Normalize the alpha mask to keep intensity between 0 and 1 + alpha = alpha.astype(float)/255 + if len(alpha.shape) < 3: + alpha = np.expand_dims(alpha, 2) + outImage = alpha * foreground + (1.-alpha) * background + outImage = np.clip(outImage, 0, 255).astype(np.uint8) + + return outImage + +def blendImages(src, dst, mask, featherAmount=0.1): + maskIndices = np.where(mask != 0) + maskPts = np.hstack( + (maskIndices[1][:, np.newaxis], maskIndices[0][:, np.newaxis])) + faceSize = np.max(maskPts, axis=0) - np.min(maskPts, axis=0) + featherAmount = 0.2 + + hull = cv2.convexHull(maskPts) + #hull = hull.astype(np.uint64) + dists = np.zeros(maskPts.shape[0]) + for i in range(maskPts.shape[0]): + point = (maskPts[i, 0], maskPts[i, 1]) + """ + The third paprameter can be set as "True" for more visually diverse images. + We use "False" to add imperceptible image patterns to synthesize new images. + """ + point_x, point_y = point + dists[i] = cv2.pointPolygonTest(hull, (int(point_x),int(point_y)), False) + + weights = np.clip(dists / featherAmount, 0, 1) + composedImg = np.copy(dst) + composedImg[maskIndices[0], maskIndices[1]] = weights[:, np.newaxis] * \ + src[maskIndices[0], maskIndices[1]] + \ + (1 - weights[:, np.newaxis]) * \ + dst[maskIndices[0], maskIndices[1]] + newMask = np.zeros_like(dst).astype(np.float32) + newMask[maskIndices[0], maskIndices[1]] = weights[:, np.newaxis] + + return composedImg, newMask + + +def colorTransfer(src_, dst_, mask): + src = dst_ + dst = src_ + transferredDst = np.copy(dst) + maskIndices = np.where(mask != 0) + maskedSrc = src[maskIndices[0], maskIndices[1]].astype(np.int32) + maskedDst = dst[maskIndices[0], maskIndices[1]].astype(np.int32) + meanSrc = np.mean(maskedSrc, axis=0) + meanDst = np.mean(maskedDst, axis=0) + maskedDst = maskedDst - meanDst + maskedDst = maskedDst + meanSrc + maskedDst = np.clip(maskedDst, 0, 255) + transferredDst[maskIndices[0], maskIndices[1]] = maskedDst + + return transferredDst diff --git a/training/dataset/utils/faceswap_utils_sladd.py b/training/dataset/utils/faceswap_utils_sladd.py new file mode 100644 index 0000000000000000000000000000000000000000..c277a8862ee0a304106626fecff2ba3aaae8c04e --- /dev/null +++ b/training/dataset/utils/faceswap_utils_sladd.py @@ -0,0 +1,62 @@ +import numpy as np +import cv2 + +def AlphaBlend(foreground, background, alpha): + # Convert uint8 to float + foreground = foreground.astype(float) + background = background.astype(float) + + # Normalize the alpha mask to keep intensity between 0 and 1 + alpha = alpha.astype(float)/255 + if len(alpha.shape) < 3: + alpha = np.expand_dims(alpha, 2) + outImage = alpha * foreground + (1.-alpha) * background + outImage = np.clip(outImage, 0, 255).astype(np.uint8) + + return outImage + +def blendImages(src, dst, mask, featherAmount=0.1): + maskIndices = np.where(mask != 0) + maskPts = np.hstack( + (maskIndices[1][:, np.newaxis], maskIndices[0][:, np.newaxis])) + faceSize = np.max(maskPts, axis=0) - np.min(maskPts, axis=0) + featherAmount = featherAmount * np.max(faceSize) + + hull = cv2.convexHull(maskPts) + #hull = hull.astype(np.uint64) + dists = np.zeros(maskPts.shape[0]) + for i in range(maskPts.shape[0]): + point = (int(maskPts[i, 0]), int(maskPts[i, 1])) + """ + The third paprameter can be set as "True" for more visually diverse images. + We use "False" to add imperceptible image patterns to synthesize new images. + """ + dists[i] = cv2.pointPolygonTest(hull, point, False) + + weights = np.clip(dists / featherAmount, 0, 1) + composedImg = np.copy(dst) + composedImg[maskIndices[0], maskIndices[1]] = weights[:, np.newaxis] * \ + src[maskIndices[0], maskIndices[1]] + \ + (1 - weights[:, np.newaxis]) * \ + dst[maskIndices[0], maskIndices[1]] + newMask = np.zeros_like(dst).astype(np.float32) + newMask[maskIndices[0], maskIndices[1]] = weights[:, np.newaxis] + + return composedImg, newMask + + +def colorTransfer(src_, dst_, mask): + src = dst_ + dst = src_ + transferredDst = np.copy(dst) + maskIndices = np.where(mask != 0) + maskedSrc = src[maskIndices[0], maskIndices[1]].astype(np.int32) + maskedDst = dst[maskIndices[0], maskIndices[1]].astype(np.int32) + meanSrc = np.mean(maskedSrc, axis=0) + meanDst = np.mean(maskedDst, axis=0) + maskedDst = maskedDst - meanDst + maskedDst = maskedDst + meanSrc + maskedDst = np.clip(maskedDst, 0, 255) + transferredDst[maskIndices[0], maskIndices[1]] = maskedDst + + return transferredDst diff --git a/training/dataset/utils/image_ae.py b/training/dataset/utils/image_ae.py new file mode 100644 index 0000000000000000000000000000000000000000..ee0e80701e3b0897a07430191c8ca618f5657a67 --- /dev/null +++ b/training/dataset/utils/image_ae.py @@ -0,0 +1,135 @@ +from torch import nn +from torch.autograd import Variable +import torch +import torch.nn.functional as F + +import torchvision.models as models + +def add_gaussian_noise(ins, mean=0, stddev=0.1): + noise = ins.data.new(ins.size()).normal_(mean, stddev) + return ins + noise + +class FlattenLayer(nn.Module): + def __init__(self): + super(FlattenLayer, self).__init__() + + def forward(self, x): + return x.view(x.size(0), -1) + + +class UnflattenLayer(nn.Module): + def __init__(self, width): + super(UnflattenLayer, self).__init__() + self.width = width + + def forward(self, x): + return x.view(x.size(0), -1, self.width, self.width) + +class VAE_Encoder(nn.Module): + ''' + VAE_Encoder: Encode image into std and logvar + ''' + + def __init__(self, latent_dim=256): + super(VAE_Encoder, self).__init__() + self.resnet = models.resnet18(pretrained=True) + self.resnet.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.resnet = nn.Sequential( + *list(self.resnet.children())[:-1], + FlattenLayer() + ) + + self.l_mu = nn.Linear(512, latent_dim) + self.l_var = nn.Linear(512, latent_dim) + + def encode(self, x): + hidden = self.resnet(x) + mu = self.l_mu(hidden) + logvar = self.l_var(hidden) + return mu, logvar + + def reparameterize(self, mu, logvar): + if self.training: + std = torch.exp(0.5*logvar) + eps = torch.randn_like(std) + return mu + eps*std + + else: + return mu + + def forward(self, x): + mu, logvar = self.encode(x) + z = self.reparameterize(mu, logvar) + return z, mu, logvar + + +class VAE_Decoder(nn.Module): + ''' + VAE_Decoder: Decode noise to image + ''' + + def __init__(self, latent_dim, output_dim=3): + super(VAE_Decoder, self).__init__() + self.convs = nn.Sequential( + UnflattenLayer(width=1), + nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(512, 384, 4, 2, 1, bias=False), + nn.BatchNorm2d(384), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(384, 192, 4, 2, 1, bias=False), + nn.BatchNorm2d(192), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(192, 96, 4, 2, 1, bias=False), + nn.BatchNorm2d(96), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(96, 64, 4, 2, 1, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(inplace=True), + nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False), + nn.Tanh() + ) + + def forward(self, z): + return self.convs(z) + +class ImageAE(nn.Module): + # VAE architecture + def __init__(self): + super(ImageAE, self).__init__() + latent_dim = 512 + self.enc = VAE_Encoder(latent_dim) + self.dec = VAE_Decoder(latent_dim) + + def forward(self, x): + z, *_ = self.enc(x) + out = self.dec(z) + + return out + + def load_ckpt(self, enc_path, dec_path): + self.enc.load_state_dict(torch.load(enc_path, map_location='cpu')) + self.dec.load_state_dict(torch.load(dec_path, map_location='cpu')) + + +def get_pretraiend_ae(enc_path='pretrained/ae/vae/enc.pth', dec_path='pretrained/ae/vae/dec1.pth'): + ae = ImageAE() + ae.load_ckpt(enc_path, dec_path) + print('load image auto-encoder') + ae.eval() + return ae + +# from networks.pix2pix_network import UnetGenerator +def get_pretraiend_unet(path='pretrained/ae/unet/ckpt_srm.pth'): + unet = UnetGenerator(3, 3, 8) + unet.load_state_dict(torch.load(path, map_location='cpu')) + print('load Unet') + unet.eval() + return unet + +if __name__ == "__main__": + ae = get_pretraiend_ae() + print(ae) diff --git a/training/dataset/utils/umeyama.py b/training/dataset/utils/umeyama.py new file mode 100644 index 0000000000000000000000000000000000000000..a83548491f16e5e740c1144b9e181fe1587fb5bc --- /dev/null +++ b/training/dataset/utils/umeyama.py @@ -0,0 +1,84 @@ +## License (Modified BSD) +## Copyright (C) 2011, the scikit-image team All rights reserved. +## +## Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +## +## Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +## Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +## Neither the name of skimage nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +## THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# umeyama function from scikit-image/skimage/transform/_geometric.py + +import numpy as np + + +def umeyama(src, dst, estimate_scale): + """Estimate N-D similarity transformation with or without scaling. + Parameters + ---------- + src : (M, N) array + Source coordinates. + dst : (M, N) array + Destination coordinates. + estimate_scale : bool + Whether to estimate scaling factor. + Returns + ------- + T : (N + 1, N + 1) + The homogeneous similarity transformation matrix. The matrix contains + NaN values only if the problem is not well-conditioned. + References + ---------- + .. [1] "Least-squares estimation of transformation parameters between two + point patterns", Shinji Umeyama, PAMI 1991, DOI: 10.1109/34.88573 + """ + + num = src.shape[0] + dim = src.shape[1] + + # Compute mean of src and dst. + src_mean = src.mean(axis=0) + dst_mean = dst.mean(axis=0) + + # Subtract mean from src and dst. + src_demean = src - src_mean + dst_demean = dst - dst_mean + + # Eq. (38). + A = np.dot(dst_demean.T, src_demean) / num + + # Eq. (39). + d = np.ones((dim,), dtype=np.double) + if np.linalg.det(A) < 0: + d[dim - 1] = -1 + + T = np.eye(dim + 1, dtype=np.double) + + U, S, V = np.linalg.svd(A) + + # Eq. (40) and (43). + rank = np.linalg.matrix_rank(A) + if rank == 0: + return np.nan * T + elif rank == dim - 1: + if np.linalg.det(U) * np.linalg.det(V) > 0: + T[:dim, :dim] = np.dot(U, V) + else: + s = d[dim - 1] + d[dim - 1] = -1 + T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V)) + d[dim - 1] = s + else: + T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V.T)) + + if estimate_scale: + # Eq. (41) and (42). + scale = 1.0 / src_demean.var(axis=0).sum() * np.dot(S, d) + else: + scale = 1.0 + + T[:dim, dim] = dst_mean - scale * np.dot(T[:dim, :dim], src_mean.T) + T[:dim, :dim] *= scale + + return T diff --git a/training/dataset/utils/warp.py b/training/dataset/utils/warp.py new file mode 100644 index 0000000000000000000000000000000000000000..437e58f43e9b15476eda06c7aedbec425e259aa0 --- /dev/null +++ b/training/dataset/utils/warp.py @@ -0,0 +1,111 @@ +import numpy as np +import cv2 +# from core import randomex + + +def random_normal(size=(1,), trunc_val=2.5): + len = np.array(size).prod() + result = np.empty((len,), dtype=np.float32) + + for i in range(len): + while True: + x = np.random.normal() + if x >= -trunc_val and x <= trunc_val: + break + result[i] = (x / trunc_val) + + return result.reshape(size) + + +def gen_warp_params(w, flip, rotation_range=[-10, 10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None): + if rnd_state is None: + rnd_state = np.random + + rotation = rnd_state.uniform(rotation_range[0], rotation_range[1]) + scale = rnd_state.uniform(1 + scale_range[0], 1 + scale_range[1]) + tx = rnd_state.uniform(tx_range[0], tx_range[1]) + ty = rnd_state.uniform(ty_range[0], ty_range[1]) + p_flip = flip and rnd_state.randint(10) < 4 + + # random warp by grid + cell_size = [w // (2**i) for i in range(1, 4)][rnd_state.randint(3)] + cell_count = w // cell_size + 1 + + grid_points = np.linspace(0, w, cell_count) + mapx = np.broadcast_to(grid_points, (cell_count, cell_count)).copy() + mapy = mapx.T + + mapx[1:-1, 1:-1] = mapx[1:-1, 1:-1] + \ + random_normal( + size=(cell_count-2, cell_count-2))*(cell_size*0.24) + mapy[1:-1, 1:-1] = mapy[1:-1, 1:-1] + \ + random_normal( + size=(cell_count-2, cell_count-2))*(cell_size*0.24) + + half_cell_size = cell_size // 2 + + mapx = cv2.resize(mapx, (w+cell_size,)*2)[ + half_cell_size:-half_cell_size-1, half_cell_size:-half_cell_size-1].astype(np.float32) + mapy = cv2.resize(mapy, (w+cell_size,)*2)[ + half_cell_size:-half_cell_size-1, half_cell_size:-half_cell_size-1].astype(np.float32) + + # random transform + random_transform_mat = cv2.getRotationMatrix2D( + (w // 2, w // 2), rotation, scale) + random_transform_mat[:, 2] += (tx*w, ty*w) + + params = dict() + params['mapx'] = mapx + params['mapy'] = mapy + params['rmat'] = random_transform_mat + params['w'] = w + params['flip'] = p_flip + + return params + + +def warp_by_params(params, img, can_warp, can_transform, can_flip, border_replicate, cv2_inter=cv2.INTER_CUBIC): + if can_warp: + img = cv2.remap(img, params['mapx'], params['mapy'], cv2_inter) + if can_transform: + img = cv2.warpAffine(img, params['rmat'], (params['w'], params['w']), borderMode=( + cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT), flags=cv2_inter) + if len(img.shape) == 2: + img = img[..., None] + if can_flip and params['flip']: + img = img[:, ::-1, ...] + return img + +from skimage.transform import PiecewiseAffineTransform, warp +def random_deform(imageSize, nrows, ncols, mean=0, std=5): + try: + h, w, c = imageSize + except: + h, w = imageSize + c = 1 + rows = np.linspace(0, h, nrows).astype(np.int32) + cols = np.linspace(0, w, ncols).astype(np.int32) + rows, cols = np.meshgrid(rows, cols) + anchors = np.vstack([rows.flat, cols.flat]).T + assert anchors.shape[1] == 2 and anchors.shape[0] == ncols * nrows + deformed = anchors + np.random.normal(mean, std, size=anchors.shape) + #print(anchors) + #print(deformed) + np.clip(deformed[:,0], 0, h-1, deformed[:,0]) + np.clip(deformed[:,1], 0, w-1, deformed[:,1]) + return anchors.astype(np.float32), deformed.astype(np.float32) + + +def piecewise_affine_transform(image, srcAnchor, tgtAnchor): + trans = PiecewiseAffineTransform() + trans.estimate(srcAnchor, tgtAnchor) + # tform.estimate(from_.astype(np.float32), to_.astype( + # np.float32)) # tform.estimate(from_, to_) + # M = tform.params[0:2, :] + warped = warp(image, trans) + return warped + +def warp_mask(mask, std): + ach, tgt_ach = random_deform(mask.shape, 4, 4, std=std) + warped_mask = piecewise_affine_transform(mask, ach, tgt_ach) + return (warped_mask*255).astype(np.uint8) diff --git a/training/detectors/__init__.py b/training/detectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f529fb42418c6bb909832ddfa61cc084e9b01d58 --- /dev/null +++ b/training/detectors/__init__.py @@ -0,0 +1,123 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + +from metrics.registry import DETECTOR +from .utils import slowfast + +from .facexray_detector import FaceXrayDetector +from .xception_detector import XceptionDetector +from .efficientnetb4_detector import EfficientDetector +from .resnet34_detector import ResnetDetector +from .f3net_detector import F3netDetector +from .meso4_detector import Meso4Detector +from .meso4Inception_detector import Meso4InceptionDetector +from .spsl_detector import SpslDetector +from .core_detector import CoreDetector +from .capsule_net_detector import CapsuleNetDetector +from .srm_detector import SRMDetector +from .ucf_detector import UCFDetector +from .recce_detector import RecceDetector +from .fwa_detector import FWADetector +from .ffd_detector import FFDDetector + +from .clip_detector import CLIPDetector +from .timesformer_detector import TimeSformerDetector +from .xclip_detector import XCLIPDetector +from .sbi_detector import SBIDetector +from .ftcn_detector import FTCNDetector +from .i3d_detector import I3DDetector +from .altfreezing_detector import AltFreezingDetector +from .stil_detector import STILDetector +from .lsda_detector import LSDADetector +from .sladd_detector import SLADDXceptionDetector +from .pcl_xception_detector import PCLXceptionDetector +from .iid_detector import IIDDetector +from .lrl_detector import LRLDetector +from .rfm_detector import RFMDetector +from .uia_vit_detector import UIAViTDetector +from .multi_attention_detector import MultiAttentionDetector +from .sia_detector import SIADetector +from .tall_detector import TALLDetector +from .cnn_dct_detector import CNNDCTDetector +from .clip_image_detector import CLIPImageDetector +from .dino_contrast import DINOv2_Large_FFT_Contrast_Detector + +from .d_det import DNADetector +from .universal import UniversalDetector +from .clip_large_detector import CLIP_Large_Detector +# Effort Models +from .effort_detector import EffortDetector +from .effort_vid_detector import EffortVidDetector # 16 frames + avg +# VideoMAE LoRA +from .videomae_detector import VideoMAEDetector +from .videomae_large_detector import VideoMAELargeDetector +from .videomae_lora_detector import VideoMAELoRADetector +# CLIP-ViT FFT +from .clip_large_fft_detector import CLIP_Large_FFT_Detector +from .clip_base_fft_detector import CLIP_Base_FFT_Detector +# CLIP-ViT LoRA r16 +from .clip_base_vid_detector import CLIP_Base_Vid_Detector +from .clip_large_vid_detector import CLIP_Large_Vid_Detector +from .clip_openai_vid_detector import CLIP_Openai_Large_Vid_Detector +from .clip_large_lora_detector import CLIP_Large_LoRA_Detector +# CLIP-ViT-Adapter +from .clip_adapter_two_3dconv_detector import CLIPAdapter3DConvDetector + +# CLIP Contrast +from .clip_contrast_detector import CLIP_Contrast +from .clip_hier_contrast_detector import CLIP_Contrast_HIER + +#aug +from .clip_large_lsda import CLIP_Large_FFT_LSDA_Detector +from .clip_patch_shuffle import CLIP_PATCH_SHUFFLE_Detector +from .vit_detector import ViT_Large_FFT_Detector +from .effort_patch_shuffle import Effort_Shuffle_Ensenble_Detector + + +# DINO +from .dinov2_large_fft_detector import DINOv2_Large_FFT_Detector +from .dinov3_large_fft_detector import DINOv3_Large_FFT_Detector +from .clip_large_fft_supcon_detector import CLIP_Large_FFT_SupCon_Detector +# SupConCls +from .clip_large_fft_supcon_cls_detector import CLIP_Large_FFT_SupCon_Cls_Detector +# Dis +from .clip_large_fft_dis_detector import CLIP_Large_FFT_Dis_Detector +from .clip_large_fft_dis_orth_detector import CLIP_Large_FFT_Dis_Orth_Detector +from .clip_large_fft_dis_orth1_detector import CLIP_Large_FFT_Dis_Orth1_Detector # Full features -net> semantic features; full features - semantic features; orthogonal loss +from .clip_large_fft_dis_orth2_detector import CLIP_Large_FFT_Dis_Orth2_Detector # Full features -net> semantic features; subtract the projection of full features onto semantic features; orthogonal loss with dual mapping +from .clip_large_fft_dis_orth3_detector import CLIP_Large_FFT_Dis_Orth3_Detector # Full features -net> semantic features; subtract the projection of full features onto semantic features; orthogonal loss with dual mapping, using ReLU in the projection layer +# VAE +from .clip_large_fft_vae1_detector import CLIP_Large_FFT_VAE1_Detector +from .clip_large_fft_vae2_detector import CLIP_Large_FFT_VAE2_Detector +# Concat +from .clip_large_fft_dis_cat1_detector import CLIP_Large_FFT_Dis_Cat1_Detector +from .clip_large_fft_dis_cat2_detector import CLIP_Large_FFT_Dis_Cat2_Detector +from .clip_large_fft_dino_orth_detector import CLIP_Large_FFT_Dino_Orth_Detector +# Effort CL +from .effort_cl_detector import EffortCLDetector +# AE +from .ae_detector import LDM_AE_Classify_Detector +from .ae_detector_resnet34 import ResNet34_AE_Trace_Detector +# POSE +from .pose_detector import POSE_Detector +# HRNet +from .hrnet_detector import HRNet_Detector +# RepMix +from .repmix_detector import RepMix_Detector +# Lorax +from .lorax_detector import LoRAXConvitDetector +# GANAtt +from .ganatt_detector import GANAtt_Detector +# NPR +from .npr_detector import NPR +#OOC +from .ooc_detector import OOCDetector + +from .resnet34_distill_detector import DetectorDistill + + diff --git a/training/detectors/ae_detector.py b/training/detectors/ae_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..ed42ccb41534b325371085648c35eb4b8633d359 --- /dev/null +++ b/training/detectors/ae_detector.py @@ -0,0 +1,385 @@ +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter +from torchvision.models import vgg16 +import torchvision.transforms.functional as F_tv + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +logger = logging.getLogger(__name__) + +# --------------------------------------------------- +class Swish(nn.Module): + """Swish activation commonly used in LDM""" + def forward(self, x): + return x * torch.sigmoid(x) + +def weights_init(m): + """Initialize weights (for GAN and AE)""" + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1 or classname.find('GroupNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + +class PerceptualLoss(nn.Module): + """LDM-style perceptual loss (VGG16 intermediate-layer features) - fixed for batch normalization issues""" + def __init__(self): + super().__init__() + # Load pretrained VGG16 and freeze its parameters + vgg = vgg16(pretrained=True).features[:10] + for param in vgg.parameters(): + param.requires_grad = False + self.vgg = vgg + # Store normalization parameters (use functional ops instead of instantiating transforms.Normalize) + self.mean = torch.tensor([0.485, 0.456, 0.406]) + self.std = torch.tensor([0.229, 0.224, 0.225]) + self.mse = nn.MSELoss() + + def forward(self, x, x_recon): + """ + Fix: support normalization for batched data ([B,3,H,W]) + x: original image ([B,3,H,W], ∈[-1,1]) + x_recon: reconstructed image ([B,3,H,W], ∈[-1,1]) + """ + # 1. [-1,1] → [0,1] + x = (x + 1) / 2 + x_recon = (x_recon + 1) / 2 + + # 2. Move mean/std to the same device as x (GPU/CPU) + mean = self.mean.to(x.device, dtype=x.dtype) + std = self.std.to(x.device, dtype=x.dtype) + + # 3. Use torchvision.functional.normalize (supports batched data [B,C,H,W]) + # Note: mean/std must be expanded to [1,3,1,1] to match the batched tensor dimensions + x_norm = F_tv.normalize(x, mean=mean.view(1, 3, 1, 1), std=std.view(1, 3, 1, 1)) + x_recon_norm = F_tv.normalize(x_recon, mean=mean.view(1, 3, 1, 1), std=std.view(1, 3, 1, 1)) + + # 4. Extract VGG features and compute MSE + feat_x = self.vgg(x_norm) + feat_recon = self.vgg(x_recon_norm) + return self.mse(feat_x, feat_recon) + +class LDMEncoder(nn.Module): + """LDM-style VAE encoder""" + def __init__(self, in_channels=3, latent_channels=4, base_channels=64): + super().__init__() + self.encoder = nn.Sequential( + nn.Conv2d(in_channels, base_channels, kernel_size=4, stride=2, padding=1), + nn.GroupNorm(32, base_channels), + Swish(), + nn.Conv2d(base_channels, base_channels*2, kernel_size=4, stride=2, padding=1), + nn.GroupNorm(32, base_channels*2), + Swish(), + nn.Conv2d(base_channels*2, base_channels*4, kernel_size=4, stride=2, padding=1), + nn.GroupNorm(32, base_channels*4), + Swish(), + nn.Conv2d(base_channels*4, base_channels*8, kernel_size=4, stride=2, padding=1), + nn.GroupNorm(32, base_channels*8), + Swish(), + nn.Conv2d(base_channels*8, latent_channels, kernel_size=3, stride=1, padding=1) + ) + + def forward(self, x): + return self.encoder(x) + +class LDMDecoder(nn.Module): + """LDM-style VAE decoder""" + def __init__(self, out_channels=3, latent_channels=4, base_channels=64): + super().__init__() + self.decoder = nn.Sequential( + nn.Conv2d(latent_channels, base_channels*8, kernel_size=3, stride=1, padding=1), + nn.GroupNorm(32, base_channels*8), + Swish(), + nn.ConvTranspose2d(base_channels*8, base_channels*4, kernel_size=4, stride=2, padding=1), + nn.GroupNorm(32, base_channels*4), + Swish(), + nn.ConvTranspose2d(base_channels*4, base_channels*2, kernel_size=4, stride=2, padding=1), + nn.GroupNorm(32, base_channels*2), + Swish(), + nn.ConvTranspose2d(base_channels*2, base_channels, kernel_size=4, stride=2, padding=1), + nn.GroupNorm(32, base_channels), + Swish(), + nn.ConvTranspose2d(base_channels, out_channels, kernel_size=4, stride=2, padding=1), + nn.Tanh() + ) + + def forward(self, z): + return self.decoder(z) + +class PatchGANDiscriminator(nn.Module): + """PatchGAN discriminator (used for adversarial loss)""" + def __init__(self, in_channels=3, base_channels=64): + super().__init__() + self.discriminator = nn.Sequential( + nn.Conv2d(in_channels, base_channels, kernel_size=4, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(base_channels, base_channels*2, kernel_size=4, stride=2, padding=1), + nn.GroupNorm(32, base_channels*2), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(base_channels*2, base_channels*4, kernel_size=4, stride=2, padding=1), + nn.GroupNorm(32, base_channels*4), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(base_channels*4, base_channels*8, kernel_size=4, stride=2, padding=1), + nn.GroupNorm(32, base_channels*8), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(base_channels*8, 1, kernel_size=4, stride=1, padding=1) + ) + + def forward(self, x): + return self.discriminator(x) + +class AETotalLoss(nn.Module): + """Combined loss: perceptual loss + MSE + adversarial loss + classification loss + LDM KL loss""" + def __init__(self, lambda_perceptual=0, lambda_mse=0, lambda_adv=0, lambda_cls=1.0, lambda_kl=0): + super().__init__() + self.lambda_perceptual = lambda_perceptual + self.lambda_mse = lambda_mse + self.lambda_adv = lambda_adv + self.lambda_cls = lambda_cls + self.lambda_kl = lambda_kl # initial KL weight + self.perceptual_loss = PerceptualLoss() + self.mse = nn.MSELoss() + self.ce = nn.CrossEntropyLoss() + + def forward(self, x, x_recon, pred, labels, z, d_fake): + """ + Generator loss: excludes the discriminator real branch (discriminator is computed separately) + Args: + d_fake: discriminator output for the reconstructed image (already detached or from an independent computation graph) + """ + # 1. perceptual loss + loss_perceptual = self.perceptual_loss(x, x_recon) + # 2. MSE loss + loss_mse = self.mse(x, x_recon) + # 3. adversarial loss (generator part only; d_fake is already detached) + loss_adv_gen = self.mse(d_fake, torch.ones_like(d_fake)) + # 4. classification loss + loss_cls = self.ce(pred, labels) + # 5. simplified LDM KL loss (z is the mean μ) + loss_kl = 0.5 * torch.sum(z ** 2, dim=[1,2,3]).mean() + + # total generator loss + total_loss = ( + self.lambda_perceptual * loss_perceptual + + self.lambda_mse * loss_mse + + self.lambda_adv * loss_adv_gen + + self.lambda_cls * loss_cls + + self.lambda_kl * loss_kl + ) + + return { + 'gen_total': loss_cls, + 'loss_perceptual': loss_perceptual, + 'loss_mse': loss_mse, + 'loss_adv_gen': loss_adv_gen, + 'loss_cls': loss_cls, + 'loss_kl': loss_kl + } + +class DiscriminatorLoss(nn.Module): + """Standalone discriminator loss (computed separately to avoid conflicts with the generator graph)""" + def __init__(self): + super().__init__() + self.mse = nn.MSELoss() + + def forward(self, d_real, d_fake): + """ + discriminator loss (LSGAN) + Args: + d_real: discriminator output for the original image + d_fake: discriminator output for the reconstructed image (detach the generator graph) + """ + loss_real = self.mse(d_real, torch.ones_like(d_real)) + loss_fake = self.mse(d_fake.detach(), torch.zeros_like(d_fake)) + return 0.5 * (loss_real + loss_fake) + +# -------------------------- Bench-adapted detector (core class) -------------------------- +@DETECTOR.register_module(module_name='ae_detector') +class LDM_AE_Classify_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.num_classes = config['backbone_config']['num_classes'] + self.latent_channels = config.get('latent_channels', 1024) + self.base_channels = config.get('base_channels', 64) + self.in_channels = config.get('in_channels', 3) + self.out_channels = config.get('out_channels', 3) + + # KL annealing configuration + self.kl_annealing_steps = config.get('kl_annealing_steps', 2000) + self.max_lambda_kl = config.get('lambda_kl', 0.001) + self.current_step = 0 # global iteration step + + # Build network components + self.encoder, self.decoder, self.classifier_head, self.discriminator = self.build_backbone(config) + # Build the loss function (separate generator and discriminator losses) + self.gen_loss_func, self.disc_loss_func = self.build_loss(config) + + # Logging configuration + logger.info(f"LDM-AE Detector initialized successfully:") + logger.info(f" - number of classes: {self.num_classes}") + logger.info(f" - latent-space channel count: {self.latent_channels}") + logger.info(f" - KL annealing steps: {self.kl_annealing_steps}") + logger.info(f" - maximum KL weight: {self.max_lambda_kl}") + + def build_backbone(self, config): + """Build the AE encoder, decoder, classification head, and discriminator""" + # 1. LDM encoder + encoder = LDMEncoder( + in_channels=self.in_channels, + latent_channels=self.latent_channels, + base_channels=self.base_channels + ) + # 2. LDM decoder + decoder = LDMDecoder( + out_channels=self.out_channels, + latent_channels=self.latent_channels, + base_channels=self.base_channels + ) + # 3. Latent-space classification head (renamed to classifier_head to avoid conflict with the abstract classifier method) + classifier_head = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Flatten(), + nn.Linear(self.latent_channels, 256), + nn.LayerNorm(256), + Swish(), + nn.Dropout(0.5), + nn.Linear(256, self.num_classes) + ) + # 4. PatchGAN discriminator + discriminator = PatchGANDiscriminator(in_channels=self.in_channels) + + # Initialize weights + encoder.apply(weights_init) + decoder.apply(weights_init) + classifier_head.apply(weights_init) + discriminator.apply(weights_init) + + return encoder, decoder, classifier_head, discriminator + + def build_loss(self, config): + """Build separate loss functions for the generator and discriminator""" + # generator loss + gen_loss_kwargs = { + 'lambda_perceptual': config.get('lambda_perceptual', 0.8), + 'lambda_mse': config.get('lambda_mse', 0.2), + 'lambda_adv': config.get('lambda_adv', 0.1), + 'lambda_cls': config.get('lambda_cls', 1.0), + 'lambda_kl': config.get('lambda_kl', 0.001) + } + gen_loss_func = AETotalLoss(**gen_loss_kwargs) + # discriminator loss (independent) + disc_loss_func = DiscriminatorLoss() + return gen_loss_func, disc_loss_func + + def features(self, data_dict: dict) -> tuple: + """Extract features: encode + decode, and return latent features and the reconstructed image""" + x = data_dict['image'] # Input: [B,3,H,W] ∈[-1,1] + z = self.encoder(x) # Encode: [B,4,H/16,W/16] + x_recon = self.decoder(z) # Decode: [B,3,H,W] ∈[-1,1] + return z, x_recon + + def classifier(self, z: torch.Tensor) -> torch.Tensor: + """Implement the abstract classifier method (matching the AbstractDetector interface)""" + return self.classifier_head(z) + + def get_gen_losses(self, data_dict: dict, pred_dict: dict) -> dict: + """Compute generator loss (independent of the discriminator)""" + # 1. Extract inputs and predictions + x = data_dict['image'] # original image + labels = data_dict['label'] # ground-truth labels + x_recon = pred_dict['x_recon'] # reconstructed image + pred = pred_dict['cls'] # classification prediction + z = pred_dict['z'] # latent features + d_fake = pred_dict['d_fake'] # discriminator output for the reconstructed image (already detached) + + # 2. KL annealing: dynamically adjust the KL weight + if self.current_step < self.kl_annealing_steps: + self.gen_loss_func.lambda_kl = self.max_lambda_kl * (self.current_step / self.kl_annealing_steps) + else: + self.gen_loss_func.lambda_kl = self.max_lambda_kl + self.current_step += 1 + + # 3. Compute generator loss + loss_dict = self.gen_loss_func(x, x_recon, pred, labels, z, d_fake) + + # 4. Assemble the loss dictionary + return { + 'overall': loss_dict['gen_total'], # total generator loss + 'loss_perceptual': loss_dict['loss_perceptual'].item(), + 'loss_mse': loss_dict['loss_mse'].item(), + 'loss_adv_gen': loss_dict['loss_adv_gen'].item(), + 'loss_cls': loss_dict['loss_cls'].item(), + 'loss_kl': loss_dict['loss_kl'].item(), + 'lambda_kl': self.gen_loss_func.lambda_kl + } + + def get_disc_losses(self, data_dict: dict, pred_dict: dict) -> torch.Tensor: + """Compute discriminator loss (fully independent)""" + d_real = pred_dict['d_real'] # discriminator output for the original image + d_fake = pred_dict['d_fake'] # discriminator output for the reconstructed image + return self.disc_loss_func(d_real, d_fake) + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + """Compute training metrics (adapted to the Bench interface)""" + labels = data_dict['label'].detach() + pred = pred_dict['cls'].detach() + acc, mAP = calculate_acc_for_train(labels, pred, self.num_classes) + return {'acc': acc, 'mAP': mAP} + + def forward(self, data_dict: dict, inference=False) -> dict: + """Forward pass (separate generator and discriminator computation graphs)""" + x = data_dict['image'] + # 1. Generator forward pass (encode + decode + classify) + z, x_recon = self.features(data_dict) + pred = self.classifier(z) + prob = torch.softmax(pred, dim=1) + + # 2. Discriminator forward pass (computed independently to avoid coupling with the generator graph) + with torch.no_grad(): # Do not compute discriminator gradients during generator inference + d_real = self.discriminator(x) # discriminate the original image + d_fake = self.discriminator(x_recon) # discriminate the reconstructed image + + # 3. Build the prediction dictionary (including discriminator outputs for loss computation) + pred_dict = { + 'cls': pred, # classification logits + 'prob': prob, # classification probabilities + 'z': z, # latent features + 'x_recon': x_recon, # reconstructed image + 'feat': z.mean(dim=[2,3]), # global mean of latent features + 'd_real': d_real, # discriminator real output + 'd_fake': d_fake # discriminator fake output + } + + # Visualization during inference + # if inference and hasattr(self, 'writer'): + # self.writer.add_images('recon/image_gt', (x + 1)/2, global_step=self.current_step, dataformats='NCHW') + # self.writer.add_images('recon/image_pred', (x_recon + 1)/2, global_step=self.current_step, dataformats='NCHW') + + return pred_dict + + def get_discriminator(self): + """Provide an accessor for the discriminator (the Bench training loop needs to optimize it separately)""" + return self.discriminator + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + """Bench-compatible interface: return generator loss (the discriminator loss is obtained separately via get_disc_losses)""" + return self.get_gen_losses(data_dict, pred_dict) \ No newline at end of file diff --git a/training/detectors/ae_detector_clip.py b/training/detectors/ae_detector_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..77073d506a323ba5f0888770519840c56eded12c --- /dev/null +++ b/training/detectors/ae_detector_clip.py @@ -0,0 +1,332 @@ +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union, Dict, Tuple +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter +from torchvision.models import vgg16 +from torchvision import transforms +from PIL import Image + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train +from metrics.recon_metrics import calculate_psnr, calculate_ssim + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + +# ---------------------------------------------------- +class LocalAttention(nn.Module): + + def __init__(self, dim, window_size=4): + super().__init__() + self.dim = dim + self.window_size = window_size + self.query_conv = nn.Conv2d(dim, dim//4, kernel_size=1) + self.key_conv = nn.Conv2d(dim, dim//4, kernel_size=1) + self.value_conv = nn.Conv2d(dim, dim, kernel_size=1) + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + B, C, H, W = x.shape + x_flat = x.permute(0, 2, 3, 1).reshape(B, H*W, C) + x_norm = self.norm(x_flat) + + q = self.query_conv(x).permute(0, 2, 3, 1).reshape(B, H*W, C//4) + k = self.key_conv(x).permute(0, 2, 3, 1).reshape(B, H*W, C//4) + v = self.value_conv(x).permute(0, 2, 3, 1).reshape(B, H*W, C) + + attn_weights = torch.zeros(B, H*W, H*W, device=x.device) + for i in range(H*W): + row = i // W + col = i % W + row_start = max(0, row - self.window_size//2) + row_end = min(H, row + self.window_size//2 + 1) + col_start = max(0, col - self.window_size//2) + col_end = min(W, col + self.window_size//2 + 1) + + window_indices = [] + for r in range(row_start, row_end): + for c in range(col_start, col_end): + window_indices.append(r * W + c) + window_indices = torch.tensor(window_indices, device=x.device) + + q_i = q[:, i:i+1, :] + k_window = k[:, window_indices, :] + attn_i = torch.bmm(q_i, k_window.transpose(1, 2)) / (C//4)**0.5 + attn_weights[:, i, window_indices] = F.softmax(attn_i, dim=-1) + + output = torch.bmm(attn_weights, v) + output = output.reshape(B, H, W, C).permute(0, 3, 1, 2) + return output + x +class PerceptualLossModule(nn.Module): + + def __init__(self): + super().__init__() + vgg = vgg16(pretrained=True).features + self.feature_extractor = nn.Sequential(*list(vgg.children())[:8]).eval() + for param in self.feature_extractor.parameters(): + param.requires_grad = False + + def forward(self, pred, target): + pred_feat = self.feature_extractor(pred) + target_feat = self.feature_extractor(target) + return F.mse_loss(pred_feat, target_feat) + +# -------------------------- -------------------------- +@DETECTOR.register_module(module_name='clip_vit_ae_recon') +class CLIP_ViT_AE_Reconstructor(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + + self.latent_dim = config['backbone_config'].get('latent_dim', 512) + self.freeze_clip_layers = config['backbone_config'].get('freeze_clip_layers', 4) + self.perceptual_weight = config.get('perceptual_weight', 0.3) # + self.num_classes = config['backbone_config'].get('num_classes', 36) + self.enable_classification = config.get('enable_classification', True) + self.recon_weight = config.get('recon_weight', 0.7) + self.cls_weight = config.get('cls_weight', 0.3) + + + self.backbone = self.build_backbone(config) # CLIP-ViT Encoder + self.decoder = self.build_decoder(config) + self.classifier = self.build_classifier() + + + self.mse_loss = nn.MSELoss() + self.perceptual_loss = PerceptualLossModule().to(self.device) + + + self.best_psnr = 0.0 + self.current_epoch = 0 + + def build_backbone(self, config) -> nn.Module: + """Build the CLIP-ViT-Base encoder""" + + model_name = config['backbone_config'].get('pretrained', 'openai/clip-vit-base-patch16') + clip_model = CLIPModel.from_pretrained(model_name).vision_model.to(self.device) + + num_transformer_layers = len(clip_model.encoder.layers) + for i, layer in enumerate(clip_model.encoder.layers): + if i < self.freeze_clip_layers: + for param in layer.parameters(): + param.requires_grad = False + + + feature_compress = nn.Sequential( + nn.Linear(768, self.latent_dim), + nn.LayerNorm(self.latent_dim), + nn.GELU() + ).to(self.device) + + + class CLIPEncoderWrapper(nn.Module): + def __init__(self, clip_vision, compress_layer, latent_dim): + super().__init__() + self.clip_vision = clip_vision + self.compress_layer = compress_layer + self.latent_dim = latent_dim + + def forward(self, x): + # x: [batch, 3, 224, 224] + clip_output = self.clip_vision(x) # last_hidden_state: [batch, 197, 768] + patch_tokens = clip_output.last_hidden_state[:, 1:, :] + compressed = self.compress_layer(patch_tokens) # [batch, 196, latent_dim] + latent = compressed.reshape(-1, 14, 14, self.latent_dim).permute(0, 3, 1, 2) # [batch, latent_dim, 14, 14] + return latent + + return CLIPEncoderWrapper(clip_model, feature_compress, self.latent_dim) + + def build_decoder(self, config) -> nn.Module: + + latent_dim = self.latent_dim + + class Decoder(nn.Module): + def __init__(self, latent_dim): + super().__init__() + self.upsample_blocks = nn.Sequential( + # Block 1: 14x14 -> 28x28, 512 -> 256 + self._build_upsample_block(latent_dim, 256), + # Block 2: 28x28 -> 56x56, 256 -> 128 + self._build_upsample_block(256, 128), + # Block 3: 56x56 -> 112x112, 128 -> 64 + self._build_upsample_block(128, 64), + # Block 4: 112x112 -> 224x224, 64 -> 32 + self._build_upsample_block(64, 32) + ) + self.final_conv = nn.Sequential( + nn.Conv2d(32, 16, kernel_size=3, padding=1), + nn.GELU(), + nn.Conv2d(16, 3, kernel_size=3, padding=1), + nn.Sigmoid() + ) + + def _build_upsample_block(self, in_ch, out_ch): + return nn.Sequential( + nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1, bias=False), + nn.LayerNorm([out_ch, 0, 0]), + nn.GELU(), + LocalAttention(out_ch, window_size=4), + nn.Conv2d(out_ch, out_ch, 3, padding=1), + nn.LayerNorm([out_ch, 0, 0]), + nn.GELU() + ) + + def forward(self, latent): + x = self.upsample_blocks(latent) + return self.final_conv(x) + + return Decoder(latent_dim).to(self.device) + + def build_classifier(self) -> nn.Module: + + return nn.Sequential( + nn.AdaptiveAvgPool2d(1), # [batch, latent_dim, 1, 1] + nn.Flatten(), # [batch, latent_dim] + nn.Linear(self.latent_dim, self.num_classes) + ).to(self.device) + + def features(self, data_dict: dict) -> torch.Tensor: + + return self.backbone(data_dict['image']) + + def reconstruct(self, latent: torch.Tensor) -> torch.Tensor: + + return self.decoder(latent) + + def classify(self, latent: torch.Tensor) -> torch.Tensor: + + return self.classifier(latent) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + + loss_dict = {} + recon = pred_dict['recon'] + target = data_dict['image'] + + + mse_loss_val = self.mse_loss(recon, target) + perceptual_loss_val = self.perceptual_loss(recon, target) + recon_loss = mse_loss_val + self.perceptual_weight * perceptual_loss_val + loss_dict['recon_loss'] = recon_loss + loss_dict['mse_loss'] = mse_loss_val + loss_dict['perceptual_loss'] = perceptual_loss_val + + + if self.enable_classification and 'label' in data_dict: + cls_pred = pred_dict['cls'] + cls_label = data_dict['label'] + cls_loss = F.cross_entropy(cls_pred, cls_label) + loss_dict['cls_loss'] = cls_loss + + + total_loss = self.recon_weight * recon_loss + self.cls_weight * cls_loss + loss_dict['overall'] = total_loss + else: + + loss_dict['overall'] = recon_loss + + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + + labels = data_dict['label'].detach() + pred_cls = pred_dict['cls'].detach() + acc, mAP = calculate_acc_for_train(labels, pred_cls, self.num_classes) + + x = data_dict['image'].detach() + x_recon = pred_dict['x_recon'].detach() + x = (x + 1) / 2 # [-1,1] → [0,1] + x_recon = (x_recon + 1) / 2 + + mse = F.mse_loss(x, x_recon).item() + psnr = 10 * np.log10(1.0 / (mse + 1e-6)) + + + ssim = self._calculate_ssim(x, x_recon) + + return { + 'acc': acc, 'mAP': mAP, + 'psnr': psnr, 'ssim': ssim + } + + def _calculate_ssim(self, x, x_recon, window_size=11, sigma=1.5): + + B, C, H, W = x.shape + + gauss = torch.Tensor(np.exp(-np.arange(0, window_size)**2 / (2 * sigma**2))).to(x.device) + gauss = gauss / gauss.sum() + window = gauss.unsqueeze(1) @ gauss.unsqueeze(0) + window = window.expand(C, 1, window_size, window_size).contiguous() + + + mu1 = F.conv2d(x, window, padding=window_size//2, groups=C) + mu2 = F.conv2d(x_recon, window, padding=window_size//2, groups=C) + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(x*x, window, padding=window_size//2, groups=C) - mu1_sq + sigma2_sq = F.conv2d(x_recon*x_recon, window, padding=window_size//2, groups=C) - mu2_sq + sigma12 = F.conv2d(x*x_recon, window, padding=window_size//2, groups=C) - mu1_mu2 + + C1 = (0.01 * 1)**2 + C2 = (0.03 * 1)**2 + ssim_map = ((2*mu1_mu2 + C1) * (2*sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean().item() + + def forward(self, data_dict: dict, inference=False) -> dict: + + latent = self.features(data_dict) + + recon = self.reconstruct(latent) + + cls_pred = self.classify(latent) if self.enable_classification else None + + cls_prob = torch.softmax(cls_pred, dim=1) if cls_pred is not None else None + + + pred_dict = { + 'latent': latent, + 'recon': recon, + 'cls': cls_pred, + 'cls_prob': cls_prob, + 'psnr': calculate_psnr(recon, data_dict['image']).item() if inference else None, + 'ssim': calculate_ssim(recon, data_dict['image']).item() if inference else None + } + + return pred_dict + + def save_best_model(self, save_path: str, metric: float): + + if metric > self.best_psnr: + self.best_psnr = metric + torch.save({ + 'epoch': self.current_epoch, + 'model_state_dict': self.state_dict(), + 'best_psnr': self.best_psnr, + 'config': self.config + }, os.path.join(save_path, 'best_clip_ae.pth')) + logger.info(f"best save in (PSNR: {self.best_psnr:.2f})") + +# --------------------------------------------------- +def get_clip_processor(model_name: str = "openai/clip-vit-base-patch16") -> AutoProcessor: + + return AutoProcessor.from_pretrained(model_name) diff --git a/training/detectors/ae_detector_resnet34.py b/training/detectors/ae_detector_resnet34.py new file mode 100644 index 0000000000000000000000000000000000000000..9cac15e8b9ab0935726c56976e82bcaa800419c5 --- /dev/null +++ b/training/detectors/ae_detector_resnet34.py @@ -0,0 +1,455 @@ +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter +from torchvision.models import vgg16 +import torchvision.transforms.functional as F_tv + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +logger = logging.getLogger(__name__) + +# -------------------------- ResNet-34 -------------------------- +class BasicBlock(nn.Module): + + expansion = 1 + + def __init__(self, in_channels, out_channels, stride=1, downsample=None, upsample=None): + super().__init__() + self.stride = stride + self.upsample = upsample + self.downsample = downsample + + + if self.upsample is not None: + + self.conv1 = nn.ConvTranspose2d( + in_channels, out_channels, kernel_size=4, stride=stride, padding=1, bias=False + ) + else: + + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False + ) + + self.bn1 = nn.BatchNorm2d(out_channels) + self.conv2 = nn.Conv2d( + out_channels, out_channels * self.expansion, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(out_channels * self.expansion) + self.act = nn.SiLU() + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + + residual = self.downsample(x) + elif self.upsample is not None: + + residual = self.upsample(x) + + out += residual + out = self.act(out) + return out + +class ResNet34_Encoder(nn.Module): + + def __init__(self, in_channels=3, latent_channels=512, base_channels=64): + super().__init__() + self.in_channels = base_channels + + self.conv1 = nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(base_channels) + self.act = nn.SiLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(BasicBlock, base_channels*1, 3, stride=1) + self.layer2 = self._make_layer(BasicBlock, base_channels*2, 4, stride=2) + self.layer3 = self._make_layer(BasicBlock, base_channels*4, 6, stride=2) + self.layer4 = self._make_layer(BasicBlock, base_channels*8, 3, stride=2) + + self.latent_proj = nn.Sequential( + nn.Conv2d(base_channels*8*BasicBlock.expansion, latent_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(latent_channels), + nn.SiLU() + ) + + self.apply(self._weights_init) + + def _make_layer(self, block, out_channels, block_num, stride=1): + """Build encoder residual blocks (downsampling)""" + downsample = None + + if stride != 1 or self.in_channels != out_channels * block.expansion: + downsample = nn.Sequential( + + nn.Conv2d( + self.in_channels, out_channels * block.expansion, + kernel_size=1, stride=stride, bias=False + ), + nn.BatchNorm2d(out_channels * block.expansion) + ) + + layers = [] + + layers.append(block(self.in_channels, out_channels, stride, downsample=downsample, upsample=None)) + self.in_channels = out_channels * block.expansion + + for _ in range(1, block_num): + layers.append(block(self.in_channels, out_channels, stride=1, downsample=None, upsample=None)) + + return nn.Sequential(*layers) + + def _weights_init(self, m): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='linear') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + # input:[B,3,H,W] ∈[-1,1] → output:[B, latent_channels, H/32, W/32] + x = self.conv1(x) + x = self.bn1(x) + x = self.act(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + z = self.latent_proj(x) + return z + +class ResNet34_Decoder(nn.Module): + + def __init__(self, out_channels=3, latent_channels=512, base_channels=64): + super().__init__() + self.base_channels = base_channels + + self.latent_inv_proj = nn.Sequential( + nn.Conv2d(latent_channels, base_channels*8, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(base_channels*8), + nn.SiLU() + ) + + self.in_channels = base_channels*8 + self.layer4 = self._make_layer(BasicBlock, base_channels*8, 3, stride=2) + self.layer3 = self._make_layer(BasicBlock, base_channels*4, 6, stride=2) + self.layer2 = self._make_layer(BasicBlock, base_channels*2, 4, stride=2) + self.layer1 = self._make_layer(BasicBlock, base_channels*1, 3, stride=2) + + self.final_upsample = nn.Sequential( + nn.ConvTranspose2d(base_channels, base_channels//2, kernel_size=4, stride=2, padding=1, bias=False), + nn.BatchNorm2d(base_channels//2), + nn.SiLU(), + nn.Conv2d(base_channels//2, out_channels, kernel_size=3, stride=1, padding=1), + nn.Tanh() + ) + + self.apply(self._weights_init) + + def _make_layer(self, block, out_channels, block_num, stride=2): + + upsample = None + + if stride == 2 or self.in_channels != out_channels * block.expansion: + upsample = nn.Sequential( + + nn.ConvTranspose2d( + self.in_channels, out_channels * block.expansion, + kernel_size=4, stride=stride, padding=1, bias=False + ), + nn.BatchNorm2d(out_channels * block.expansion) + ) + + layers = [] + + layers.append(block(self.in_channels, out_channels, stride=stride, downsample=None, upsample=upsample)) + self.in_channels = out_channels * block.expansion + + for _ in range(1, block_num): + layers.append(block(self.in_channels, out_channels, stride=1, downsample=None, upsample=None)) + + return nn.Sequential(*layers) + + def _weights_init(self, m): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='linear') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, z): + # input:[B, latent_channels, H/32, W/32] → output:[B,3,H,W] ∈[-1,1] + x = self.latent_inv_proj(z) + x = self.layer4(x) + x = self.layer3(x) + x = self.layer2(x) + x = self.layer1(x) + x_recon = self.final_upsample(x) + return x_recon + +class TraceClassifierHead(nn.Module): + + def __init__(self, latent_channels=512, num_classes=10): + super().__init__() + self.classifier = nn.Sequential( + nn.AdaptiveAvgPool2d(1), # [B,C,H/32,W/32] → [B,C,1,1] + nn.Flatten(), # → [B,C] + nn.Linear(latent_channels, 1024), + nn.LayerNorm(1024), + nn.SiLU(), + nn.Dropout(0.5), + nn.Linear(1024, 512), + nn.LayerNorm(512), + nn.SiLU(), + nn.Dropout(0.3), + nn.Linear(512, num_classes) + ) + + self.apply(self._weights_init) + + def _weights_init(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, z): + + return self.classifier(z) + +# ----------------------------------------------- +class PerceptualLoss(nn.Module): + + def __init__(self): + super().__init__() + vgg = vgg16(pretrained=True).features[:10] + for param in vgg.parameters(): + param.requires_grad = False + self.vgg = vgg.eval() + self.mean = torch.tensor([0.485, 0.456, 0.406]) + self.std = torch.tensor([0.229, 0.224, 0.225]) + self.mse = nn.MSELoss() + + def forward(self, x, x_recon): + # x, x_recon: [B,3,H,W] ∈[-1,1] + x = (x + 1) / 2 + x_recon = (x_recon + 1) / 2 + mean = self.mean.to(x.device, dtype=x.dtype).view(1,3,1,1) + std = self.std.to(x.device, dtype=x.dtype).view(1,3,1,1) + x_norm = F_tv.normalize(x, mean=mean, std=std) + x_recon_norm = F_tv.normalize(x_recon, mean=mean, std=std) + feat_x = self.vgg(x_norm) + feat_recon = self.vgg(x_recon_norm) + return self.mse(feat_x, feat_recon) + +class TotalLoss(nn.Module): + def __init__(self, lambda_perceptual=1.0, lambda_cls=1.0, num_classes=10): + super().__init__() + self.lambda_perceptual = lambda_perceptual + self.lambda_cls = lambda_cls + self.perceptual_loss = PerceptualLoss() + self.cls_loss = nn.CrossEntropyLoss() + + def forward(self, x, x_recon, pred_cls, labels): + """ + Args: + x: [B,3,H,W] + x_recon: [B,3,H,W] + pred_cls: logits [B, num_classes] + labels: [B] + """ + loss_perceptual = self.perceptual_loss(x, x_recon) + loss_cls = self.cls_loss(pred_cls, labels) + total_loss = self.lambda_perceptual * loss_perceptual + self.lambda_cls * loss_cls + return { + 'overall': total_loss, + 'loss_perceptual': loss_perceptual, + 'loss_cls': loss_cls + } + +# -------------------------- ------------------------ +@DETECTOR.register_module(module_name='resnet34_ae_trace') +class ResNet34_AE_Trace_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + + self.in_channels = config.get('in_channels', 3) + self.out_channels = config.get('out_channels', 3) + self.latent_channels = config.get('latent_channels', 512) + self.base_channels = config.get('base_channels', 64) + self.num_classes = config['backbone_config']['num_classes'] + + self.lambda_perceptual = config.get('lambda_perceptual', 0.1) + self.lambda_cls = config.get('lambda_cls', 1.0) + + + self.encoder, self.decoder, self.classifier_head = self.build_backbone(config) + + self.loss_func = self.build_loss(config) + + + # for param in self.decoder.parameters(): + # param.requires_grad = False + + + logger.info(f"ResNet34-AE-TRACE Detector:") + logger.info(f" - ResNet34 Encoder + ResNet34 Decoder + classifier") + logger.info(f" -latent channels:{self.latent_channels}") + logger.info(f" - num_class:{self.num_classes}") + logger.info(f" - loss weight{self.lambda_perceptual} | cls loss{self.lambda_cls}") + + def build_backbone(self, config): + """Encoder、Decoder、classifier""" + # Encoder + encoder = ResNet34_Encoder( + in_channels=self.in_channels, + latent_channels=self.latent_channels, + base_channels=self.base_channels + ) + # Decoder + decoder = ResNet34_Decoder( + out_channels=self.out_channels, + latent_channels=self.latent_channels, + base_channels=self.base_channels + ) + classifier_head = TraceClassifierHead( + latent_channels=self.latent_channels, + num_classes=self.num_classes + ) + return encoder, decoder, classifier_head + + def build_loss(self, config): + + return TotalLoss( + lambda_perceptual=self.lambda_perceptual, + lambda_cls=self.lambda_cls, + num_classes=self.num_classes + ) + + def features(self, data_dict: dict) -> torch.tensor: + + x = data_dict['image'] + z = self.encoder(x) + return z + + def classifier(self, features: torch.tensor) -> torch.tensor: + + return self.classifier_head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + + x = data_dict['image'] + labels = data_dict['label'] + x_recon = pred_dict['x_recon'] + pred_cls = pred_dict['cls'] + + loss_dict = self.loss_func(x, x_recon, pred_cls, labels) + + return { + 'overall': loss_dict['overall'], + 'loss_perceptual': loss_dict['loss_perceptual'].item(), + 'loss_cls': loss_dict['loss_cls'].item() + } + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + + # 1. metric(acc、mAP) + labels = data_dict['label'].detach() + pred_cls = pred_dict['cls'].detach() + acc, mAP = calculate_acc_for_train(labels, pred_cls, self.num_classes) + + x = data_dict['image'].detach() + x_recon = pred_dict['x_recon'].detach() + x = (x + 1) / 2 # [-1,1] → [0,1] + x_recon = (x_recon + 1) / 2 + + + mse = F.mse_loss(x, x_recon).item() + psnr = 10 * np.log10(1.0 / (mse + 1e-6)) + + + ssim = self._calculate_ssim(x, x_recon) + + + return { + 'acc': acc, 'mAP': mAP, + 'psnr': psnr, 'ssim': ssim + } + + def _calculate_ssim(self, x, x_recon, window_size=11, sigma=1.5): + + B, C, H, W = x.shape + + gauss = torch.Tensor(np.exp(-np.arange(0, window_size)**2 / (2 * sigma**2))).to(x.device) + gauss = gauss / gauss.sum() + window = gauss.unsqueeze(1) @ gauss.unsqueeze(0) + window = window.expand(C, 1, window_size, window_size).contiguous() + + + mu1 = F.conv2d(x, window, padding=window_size//2, groups=C) + mu2 = F.conv2d(x_recon, window, padding=window_size//2, groups=C) + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(x*x, window, padding=window_size//2, groups=C) - mu1_sq + sigma2_sq = F.conv2d(x_recon*x_recon, window, padding=window_size//2, groups=C) - mu2_sq + sigma12 = F.conv2d(x*x_recon, window, padding=window_size//2, groups=C) - mu1_mu2 + + C1 = (0.01 * 1)**2 + C2 = (0.03 * 1)**2 + ssim_map = ((2*mu1_mu2 + C1) * (2*sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean().item() + + def forward(self, data_dict: dict, inference=False) -> dict: + x = data_dict['image'] + + + z = self.features(data_dict) + + + pred_cls = self.classifier(z) + pred_prob = torch.softmax(pred_cls, dim=1) + + + x_recon = self.decoder(z) + + + pred_dict = { + 'cls': pred_cls, + 'prob': pred_prob, + 'feat': torch.mean(z, dim=[2,3]), + 'z': z, + 'x_recon': x_recon + } + + + # if inference and hasattr(self, 'writer'): + # self.writer.add_images('recon/gt', (x + 1)/2, global_step=self.current_step, dataformats='NCHW') + # self.writer.add_images('recon/pred', (x_recon + 1)/2, global_step=self.current_step, dataformats='NCHW') + + return pred_dict diff --git a/training/detectors/altfreezing_detector.py b/training/detectors/altfreezing_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd64e493d0a7a52780acef4dc7ea00ace0a479c --- /dev/null +++ b/training/detectors/altfreezing_detector.py @@ -0,0 +1,241 @@ +config_text = """ +TRAIN: + ENABLE: True + DATASET: kinetics + BATCH_SIZE: 64 + EVAL_PERIOD: 10 + CHECKPOINT_PERIOD: 1 + AUTO_RESUME: True +DATA: + NUM_FRAMES: 8 + SAMPLING_RATE: 8 + TRAIN_JITTER_SCALES: [256, 320] + TRAIN_CROP_SIZE: 224 + TEST_CROP_SIZE: 256 + INPUT_CHANNEL_NUM: [3] +RESNET: + ZERO_INIT_FINAL_BN: True + WIDTH_PER_GROUP: 64 + NUM_GROUPS: 1 + DEPTH: 50 + TRANS_FUNC: bottleneck_transform + STRIDE_1X1: False + NUM_BLOCK_TEMP_KERNEL: [[3], [4], [6], [3]] +NONLOCAL: + LOCATION: [[[]], [[]], [[]], [[]]] + GROUP: [[1], [1], [1], [1]] + INSTANTIATION: softmax +BN: + USE_PRECISE_STATS: True + NUM_BATCHES_PRECISE: 200 +SOLVER: + BASE_LR: 0.1 + LR_POLICY: cosine + MAX_EPOCH: 196 + MOMENTUM: 0.9 + WEIGHT_DECAY: 1e-4 + WARMUP_EPOCHS: 34.0 + WARMUP_START_LR: 0.01 + OPTIMIZING_METHOD: sgd +MODEL: + NUM_CLASSES: 1 + ARCH: i3d + MODEL_NAME: ResNet + LOSS_FUNC: cross_entropy + DROPOUT_RATE: 0.5 + HEAD_ACT: sigmoid +TEST: + ENABLE: True + DATASET: kinetics + BATCH_SIZE: 64 +DATA_LOADER: + NUM_WORKERS: 8 + PIN_MEMORY: True +NUM_GPUS: 8 +NUM_SHARDS: 1 +RNG_SEED: 0 +OUTPUT_DIR: . +""" + + +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the AltFreezingDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@InProceedings{Wang_2023_CVPR, + author = {Wang, Zhendong and Bao, Jianmin and Zhou, Wengang and Wang, Weilun and Li, Houqiang}, + title = {AltFreezing for More General Video Face Forgery Detection}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2023}, + pages = {4129-4138} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + + +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + +import torch +from .utils.slowfast.models.video_model_builder import ResNet as ResNetOri +from .utils.slowfast.config.defaults import get_cfg +from torch import nn +import random + + +random_select = True +no_time_pool = False + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='altfreezing') +class AltFreezingDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + cfg = get_cfg() + cfg.merge_from_str(config_text) + cfg.NUM_GPUS = 1 + cfg.TEST.BATCH_SIZE = 1 + cfg.TRAIN.BATCH_SIZE = 1 + cfg.DATA.NUM_FRAMES = config['clip_size'] + self.resnet = ResNetOri(cfg) + if config['pretrained'] is not None: + print(f"loading pretrained model from {config['pretrained']}") + pretrained_weights = torch.load(config['pretrained'], map_location='cpu', encoding='latin1') + modified_weights = {k.replace("resnet.", ""): v for k, v in pretrained_weights.items()} + # fit from 400 num_classes to 1 + modified_weights["head.projection.weight"] = modified_weights["head.projection.weight"][:1, :] + modified_weights["head.projection.bias"] = modified_weights["head.projection.bias"][:1] + # load final ckpt + self.resnet.load_state_dict(modified_weights, strict=True) + + self.conv_dict = self.find_conv_layers(self.resnet) + print("1x3x3 Conv: {}\n3x1x1 Conv:{}".format(self.conv_dict['spatial'], self.conv_dict['temporal'])) + self.train_batch_cnt = 0 + + self.loss_func = nn.BCELoss() # The output of the model is a probability value between 0 and 1 (haved used sigmoid) + + def find_conv_layers(self, module, parent_name='', conv_dict=None): + if conv_dict is None: + conv_dict = {'temporal': [], 'spatial': []} + + for name, sub_module in module.named_children(): + full_name = f'{parent_name}.{name}' if parent_name else name + + if isinstance(sub_module, nn.Conv3d): + if sub_module.kernel_size == (3, 1, 1): + conv_dict['temporal'].append(full_name) + if sub_module.kernel_size == (1, 3, 3): + conv_dict['spatial'].append(full_name) + else: + self.find_conv_layers(sub_module, full_name, conv_dict) + + return conv_dict + + def alternate_mode(self, target_mode): + for layer_name in self.conv_dict['temporal']: + layer = dict(self.resnet.named_modules())[layer_name] + layer.weight.requires_grad = True if target_mode == 'temporal' else False + if layer.bias is not None: + layer.bias.requires_grad = True if target_mode == 'temporal' else False + + for layer_name in self.conv_dict['spatial']: + layer = dict(self.resnet.named_modules())[layer_name] + layer.weight.requires_grad = True if target_mode == 'spatial' else False + if layer.bias is not None: + layer.bias.requires_grad = True if target_mode == 'spatial' else False + + + def build_backbone(self, config): + pass + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + inputs = [data_dict['image'].permute(0,2,1,3,4)] + pred = self.resnet(inputs) + output = {"final_output": pred} + + return output["final_output"] + + def classifier(self, features: torch.tensor): + pass + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'].float() + pred = pred_dict['cls'].view(-1) + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the probability + prob = self.features(data_dict) + # build the prediction dict for each output + pred_dict = {'cls': prob, 'prob': prob, 'feat': prob} + if not inference: + if self.train_batch_cnt % (20 + 1) == 0: + self.alternate_mode('spatial') + elif self.train_batch_cnt % (20 + 1) == 1: + self.alternate_mode('temporal') + + self.train_batch_cnt += 1 + + return pred_dict diff --git a/training/detectors/base_detector.py b/training/detectors/base_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..b240972b16a301c456d8836809e99b84a82e6af0 --- /dev/null +++ b/training/detectors/base_detector.py @@ -0,0 +1,71 @@ +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Abstract Class for the Deepfake Detector + +import abc +import torch +import torch.nn as nn +from typing import Union + +class AbstractDetector(nn.Module, metaclass=abc.ABCMeta): + """ + All deepfake detectors should subclass this class. + """ + def __init__(self, config=None, load_param: Union[bool, str] = False): + """ + config: (dict) + configurations for the model + load_param: (False | True | Path(str)) + False Do not read; True Read the default path; Path Read the required path + """ + super().__init__() + + @abc.abstractmethod + def features(self, data_dict: dict) -> torch.tensor: + """ + Returns the features from the backbone given the input data. + """ + pass + + @abc.abstractmethod + def forward(self, data_dict: dict, inference=False) -> dict: + """ + Forward pass through the model, returning the prediction dictionary. + """ + pass + + @abc.abstractmethod + def classifier(self, features: torch.tensor) -> torch.tensor: + """ + Classifies the features into classes. + """ + pass + + @abc.abstractmethod + def build_backbone(self, config): + """ + Builds the backbone of the model. + """ + pass + + @abc.abstractmethod + def build_loss(self, config): + """ + Builds the loss function for the model. + """ + pass + + @abc.abstractmethod + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + """ + Returns the losses for the model. + """ + pass + + @abc.abstractmethod + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + """ + Returns the training metrics for the model. + """ + pass diff --git a/training/detectors/capsule_net_detector.py b/training/detectors/capsule_net_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..b8204a5f5bd9702585c848eaebfcef412af3e9ac --- /dev/null +++ b/training/detectors/capsule_net_detector.py @@ -0,0 +1,271 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CapsuleNetDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{nguyen2019capsule, + title={Capsule-forensics: Using capsule networks to detect forged images and videos}, + author={Nguyen, Huy H and Yamagishi, Junichi and Echizen, Isao}, + booktitle={ICASSP 2019-2019 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, + pages={2307--2311}, + year={2019}, + organization={IEEE} +} + +GitHub Reference: +https://github.com/niyunsheng/CORE +''' + +import os +import datetime +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +import torchvision.models as models + +@DETECTOR.register_module(module_name='capsule_net') +class CapsuleNetDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + #capsule net + self.num_classes = config['num_classes'] + self.vgg_ext = VggExtractor() + self.fea_ext = FeatureExtractor() + self.fea_ext.apply(self.weights_init) + + self.NO_CAPS = 10 + self.routing_stats = RoutingLayer(num_input_capsules=self.NO_CAPS, num_output_capsules= self.num_classes, data_in=8, data_out=4, num_iterations=2) + + def build_backbone(self, config): + ... # do not need one specific backbone for capsule net + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + input = self.vgg_ext(data_dict['image']) + feature = self.fea_ext(input) + return feature + + def classifier(self, features: torch.tensor) -> torch.tensor: + z = self.routing_stats(features, random = False, dropout = 0.0) + # z[b, data, out_caps] + classes = F.softmax(z, dim=-1) + class_ = classes.detach() + class_ = class_.mean(dim=1) + return classes, class_ + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + classes = pred_dict['classes'] + loss = self.loss_func(classes, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + preds, pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features, 'classes': preds} + return pred_dict + + def weights_init(self, m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + +# VGG input(10,3,256,256) +class VggExtractor(nn.Module): + def __init__(self, train=False): + super(VggExtractor, self).__init__() + self.vgg_1 = self.Vgg(models.vgg19(pretrained=True), 0, 18) + if train: + self.vgg_1.train(mode=True) + self.freeze_gradient() + else: + self.vgg_1.eval() + + def Vgg(self, vgg, begin, end): + features = nn.Sequential(*list(vgg.features.children())[begin:(end+1)]) + return features + + def freeze_gradient(self, begin=0, end=9): + for i in range(begin, end+1): + self.vgg_1[i].requires_grad = False + + def forward(self, input): + return self.vgg_1(input) + +class FeatureExtractor(nn.Module): + def __init__(self): + super(FeatureExtractor, self).__init__() + self.NO_CAPS = 10 ##mark yxh + self.capsules = nn.ModuleList([ + nn.Sequential( + nn.Conv2d(256, 64, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + StatsNet(), + + nn.Conv1d(2, 8, kernel_size=5, stride=2, padding=2), + nn.BatchNorm1d(8), + nn.Conv1d(8, 1, kernel_size=3, stride=1, padding=1), + nn.BatchNorm1d(1), + View(-1, 8), + ) + for _ in range(self.NO_CAPS)] + ) + + def squash(self, tensor, dim): + squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True) + scale = squared_norm / (1 + squared_norm) + return scale * tensor / (torch.sqrt(squared_norm)) + + def forward(self, x): + # outputs = [capsule(x.detach()) for capsule in self.capsules] + # outputs = [capsule(x.clone()) for capsule in self.capsules] + outputs = [capsule(x) for capsule in self.capsules] + output = torch.stack(outputs, dim=-1) + + return self.squash(output, dim=-1) + +class StatsNet(nn.Module): + def __init__(self): + super(StatsNet, self).__init__() + + def forward(self, x): + x = x.view(x.data.shape[0], x.data.shape[1], x.data.shape[2]*x.data.shape[3]) + + mean = torch.mean(x, 2) + std = torch.std(x, 2) + + return torch.stack((mean, std), dim=1) + +class View(nn.Module): + def __init__(self, *shape): + super(View, self).__init__() + self.shape = shape + + def forward(self, input): + return input.view(self.shape) + +# Capsule right Dynamic routing +class RoutingLayer(nn.Module): + def __init__(self, num_input_capsules, num_output_capsules, data_in, data_out, num_iterations): + super(RoutingLayer, self).__init__() + + self.num_iterations = num_iterations + self.route_weights = nn.Parameter(torch.randn(num_output_capsules, num_input_capsules, data_out, data_in)) + + + def squash(self, tensor, dim): + squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True) + scale = squared_norm / (1 + squared_norm) + return scale * tensor / (torch.sqrt(squared_norm)) + + def forward(self, x, random, dropout): + # x[b, data, in_caps] + + x = x.transpose(2, 1) + # x[b, in_caps, data] + + if random: + # noise = torch.Tensor(0.01*torch.randn(*self.route_weights.size())).cuda() + noise = torch.Tensor(0.01*torch.randn(*self.route_weights.size())).cuda() + route_weights = self.route_weights + noise + else: + route_weights = self.route_weights + + priors = route_weights[:, None, :, :, :] @ x[None, :, :, :, None] + + # route_weights [out_caps , 1 , in_caps , data_out , data_in] + # x [ 1 , b , in_caps , data_in , 1 ] + # priors [out_caps , b , in_caps , data_out, 1 ] + + priors = priors.transpose(1, 0) + # priors[b, out_caps, in_caps, data_out, 1] + + if dropout > 0.0: + # drop = torch.Tensor(torch.FloatTensor(*priors.size()).bernoulli(1.0- dropout)).cuda() + drop = torch.Tensor(torch.FloatTensor(*priors.size()).bernoulli(1.0- dropout)).cuda() + priors = priors * drop + + + # logits = torch.Tensor(torch.zeros(*priors.size())).cuda() + logits = torch.Tensor(torch.zeros(*priors.size())).to(priors.device) + # logits[b, out_caps, in_caps, data_out, 1] + + num_iterations = self.num_iterations + + for i in range(num_iterations): + probs = F.softmax(logits, dim=2) + outputs = self.squash((probs * priors).sum(dim=2, keepdim=True), dim=3) + + if i != self.num_iterations - 1: + delta_logits = priors * outputs + logits = logits + delta_logits + + # outputs[b, out_caps, 1, data_out, 1] + outputs = outputs.squeeze() + + if len(outputs.shape) == 3: + outputs = outputs.transpose(2, 1).contiguous() + else: + outputs = outputs.unsqueeze_(dim=0).transpose(2, 1).contiguous() + # outputs[b, data_out, out_caps] + + return outputs diff --git a/training/detectors/clip_adapter_two_3dconv_detector.py b/training/detectors/clip_adapter_two_3dconv_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..2f2b102403140676247a4496d345a2440e8b8765 --- /dev/null +++ b/training/detectors/clip_adapter_two_3dconv_detector.py @@ -0,0 +1,578 @@ +from collections import OrderedDict +from typing import Tuple, Union +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +import numpy as np +import torch +from torch import nn, einsum +from einops import rearrange +import math +import torch.nn.functional as F +import clip +from einops import rearrange + + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='clip_adapter_two_3dconv') +class CLIPAdapter3DConvDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = I3DHead( + num_classes=2, + in_channels=1024, + spatial_type='avg', + dropout_ratio=0.5 + ) + self.loss_func = self.build_loss(config) + self.prob, self.label = [], [] + self.correct, self.total = 0, 0 + + def build_backbone(self, config): + assert self.config['resolution'] == 224, 'The resolution of the input image should be 224x224' + # assert self.config['clip_size'] == 8, 'The number of frames should be 8' + + # prepare the backbone + backbone = ViT_CLIP( + input_resolution=224, + num_frames=self.config['clip_size'], + patch_size=14, + width=1024, + layers=14, + heads=16, + drop_path_rate=0.1, + num_tadapter=1, + adapter_scale=0.5, + pretrained=True + ) + + ## freeze some parameters + for name, param in backbone.named_parameters(): + if 'temporal_embedding' not in name and 'ln_post' not in name and 'cls_head' not in name and 'Adapter' not in name: + param.requires_grad = False + + for name, param in backbone.named_parameters(): + print('{}: {}'.format(name, param.requires_grad)) + num_param = sum(p.numel() for p in backbone.parameters() if p.requires_grad) + num_total_param = sum(p.numel() for p in backbone.parameters()) + print('Number of total parameters: {}, tunable parameters: {}'.format(num_total_param, num_param)) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image']) + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label.long()) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def get_test_metrics(self): + y_pred, y_true = self.video_calculation(self.video_names, self.prob, self.label) + # auc + fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1) + auc = metrics.auc(fpr, tpr) + # eer + fnr = 1 - tpr + eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + # ap + ap = metrics.average_precision_score(y_true,y_pred) + # acc + acc = self.correct / self.total + # reset the prob and label + self.prob, self.label = [], [] + self.correct, self.total = 0, 0 + return {'acc':acc, 'auc':auc, 'eer':eer, 'ap':ap, 'pred':y_pred, 'label':y_true} + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + if inference: + self.prob.append( + pred_dict['prob'] + .detach() + .squeeze() + .cpu() + .numpy() + ) + self.label.append( + data_dict['label'] + .detach() + .squeeze() + .cpu() + .numpy() + ) + # deal with acc + _, prediction_class = torch.max(pred, 1) + correct = (prediction_class == data_dict['label']).sum().item() + self.correct += correct + self.total += data_dict['label'].size(0) + + return pred_dict + + +class Adapter(nn.Module): + def __init__(self, D_features, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True): + super().__init__() + self.skip_connect = skip_connect + D_hidden_features = int(D_features * mlp_ratio) + self.act = act_layer() + self.D_fc1 = nn.Linear(D_features, D_hidden_features) + self.D_fc2 = nn.Linear(D_hidden_features, D_features) + + def forward(self, x): + # x is (BT, HW+1, D) + xs = self.D_fc1(x) + xs = self.act(xs) + xs = self.D_fc2(xs) + if self.skip_connect: + x = x + xs + else: + x = xs + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerNormProxy(nn.Module): + def __init__(self, dim): + super().__init__() + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + x = rearrange(x, 'b c t h w -> b t h w c') + x = self.norm(x) + x = rearrange(x, 'b t h w c -> b c t h w') + return x + + +class DepthwiseConv3D(nn.Module): + def __init__(self, in_channels, kernel_size): + super().__init__() + self.conv = nn.Conv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels, padding=(kernel_size[0]//2, kernel_size[1]//2, kernel_size[2]//2)) + self.bn = nn.BatchNorm3d(num_features=in_channels) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class ViT_Adapter(nn.Module): + def __init__(self, num_frames=8, in_channels=1024, out_channels=1024): + super().__init__() + self.num_frames=num_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.adapter_channels = int(1024 * 0.5) + + self.down = nn.Linear(in_features=self.in_channels, out_features=self.adapter_channels) + self.gelu1 = nn.GELU() + + self.s_conv = DepthwiseConv3D(in_channels=self.adapter_channels, kernel_size=(1, 3, 3)) + self.t_conv = DepthwiseConv3D(in_channels=self.adapter_channels, kernel_size=(3, 1, 1)) + + self.gelu = nn.GELU() + + self.up = nn.Linear(in_features=self.adapter_channels, out_features=self.out_channels) + self.gelu2 = nn.GELU() + + def forward(self, x): + # hw+1 bt c + n, bt, c = x.shape + H = round(math.sqrt(n - 1)) + x_in = x + + x = self.down(x) + x = self.gelu1(x) + + cls = x[0, :, :].unsqueeze(0) + x = x[1:, :, :] + + x = rearrange(x, '(h w) (b t) c -> b c t h w', t=self.num_frames, h=H) + + # Apply depthwise 3D convolutions + xs = self.s_conv(x) + xt = self.t_conv(x) + + # Fusion of xs and xt + x = (xs + xt) / 2 + x = self.gelu(x) + x = rearrange(x, 'b c t h w -> (h w) (b t) c') + + x = torch.cat([cls, x], dim=0) + + x = self.up(x) + x = self.gelu2(x) + + # residual + x += x_in + return x + + +class ResidualAttentionBlock(nn.Module): + def __init__(self): + super().__init__() + d_model = 1024 + n_head = 16 + self.ln_1 = LayerNorm(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_2 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.Adapter = ViT_Adapter() + + def attention(self, x): + return self.attn(x, x, x, need_weights=False)[0] + + def forward(self, x): + # x shape [HW+1, BT, C] + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + x = self.Adapter(x) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, num_tadapter=1, scale=1., drop_path=0.1): + super().__init__() + self.width = width + self.layers = layers + dpr = [x.item() for x in torch.linspace(0, drop_path, self.layers)] + self.resblocks = nn.Sequential(*[ResidualAttentionBlock() for i in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class ViT_CLIP(nn.Module): + ## ViT definition in CLIP image encoder + def __init__(self, + input_resolution: int, # 224 + num_frames: int, # 8 + patch_size: int, # 14 + width: int, # 1024 + layers: int, # 14 + heads: int, # 16 + drop_path_rate, # 0.1 + num_tadapter=1, + adapter_scale=0.5, + pretrained=None # pretrained=True + ): + super().__init__() + + self.input_resolution = input_resolution + self.pretrained = pretrained + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) # 3-1024 14*14 bias-free convolution + + scale = width ** -0.5 + self.layers = layers + self.class_embedding = nn.Parameter(scale * torch.randn(width)) # 1024 + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) # (257,1024) + self.ln_pre = LayerNorm(width) + + self.num_frames = num_frames + self.temporal_embedding = nn.Parameter(torch.zeros(1, num_frames, width)) # (1,8,1024) + + self.transformer = Transformer(width, layers, heads, num_tadapter=num_tadapter, scale=adapter_scale, drop_path=drop_path_rate) + + self.ln_post = LayerNorm(width) + + self.init_weights() + + + def init_weights(self): + logger.info(f'load model from: {self.pretrained}') + # Load OpenAI CLIP pretrained weights + clip_model, preprocess = clip.load("ViT-L/14", device="cpu") + pretrain_dict = clip_model.visual.state_dict() + del clip_model + del pretrain_dict['proj'] + msg = self.load_state_dict(pretrain_dict, strict=False) + logger.info('Missing keys: {}'.format(msg.missing_keys)) + logger.info('Unexpected keys: {}'.format(msg.unexpected_keys)) + logger.info(f"=> loaded successfully '{self.pretrained}'") + torch.cuda.empty_cache() + # zero-initialize Adapters + for n1, m1 in self.named_modules(): + if 'Adapter' in n1: + for n2, m2 in m1.named_modules(): + if 'up' in n2: + logger.info('init: {}.{}'.format(n1, n2)) + nn.init.constant_(m2.weight, 0) + nn.init.constant_(m2.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed', 'temporal_embedding'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table', 'temporal_position_bias_table'} + + def extract_class_indices(self, labels, which_class): + class_mask = torch.eq(labels, which_class) + class_mask_indices = torch.nonzero(class_mask, as_tuple=False) + return torch.reshape(class_mask_indices, (-1,)) + + def get_feat(self, x): + x = rearrange(x, 'b t c h w -> (b t) c h w') # merge the B and T dimensions into [BT, C, H, W] + x = self.conv1(x) # pass through the first convolution layer [BT, 1024, 16, 16] + x = x.reshape(x.shape[0], x.shape[1], -1) # [BT,1024,256] + x = x.permute(0, 2, 1) # [BT,256,1024] + x = torch.cat( # add class_embedding + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], + dim=1) # [class_embedding(1024) + zeros(BT, 1, 1024), (BT,256,1024)] => (BT,1+256,1024) + x = x + self.positional_embedding.to(x.dtype) # (BT,1+256,1024) + positional_embedding(257,1024) + # n = h*w+1 + n = x.shape[1] # 257 + + x = rearrange(x, '(b t) n c -> (b n) t c', t=self.num_frames) + x = x + self.temporal_embedding + x = rearrange(x, '(b n) t c -> (b t) n c', n=n) + + x = self.ln_pre(x) + x = x.permute(1, 0, 2) + x = self.transformer(x) + x = x.permute(1, 0, 2) + x = self.ln_post(x) + return x + + def forward(self, x: torch.Tensor): + B, T, C, H, W = x.shape + x = self.get_feat(x) + + x = x[:, 0] + x = rearrange(x, '(b t) d -> b d t',b=B,t=T) + + x = x.unsqueeze(-1).unsqueeze(-1) # BDTHW for I3D head + + return x + + +class I3DHead(nn.Module): + """Classification head for I3D. + + Args: + num_classes (int): Number of classes to be classified. + in_channels (int): Number of channels in input feature. + Default: dict(type='CrossEntropyLoss') + spatial_type (str): Pooling type in spatial dimension. Default: 'avg'. + dropout_ratio (float): Probability of dropout layer. Default: 0.5. + kwargs (dict, optional): Any keyword argument to be used to initialize + the head. + """ + + def __init__(self, + num_classes, + in_channels, + spatial_type='avg', + dropout_ratio=0.5, + **kwargs): + super().__init__() + self.num_classes = num_classes + self.in_channels = in_channels + self.spatial_type = spatial_type + self.dropout_ratio = dropout_ratio + if self.dropout_ratio != 0: + self.dropout = nn.Dropout(p=self.dropout_ratio) + else: + self.dropout = None + self.fc_cls = nn.Linear(self.in_channels, self.num_classes) + + if self.spatial_type == 'avg': + # use `nn.AdaptiveAvgPool3d` to adaptively match the in_channels. + self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) + else: + self.avg_pool = None + + def forward(self, x): + """Defines the computation performed at every call. + + Args: + x (torch.Tensor): The input data. + + Returns: + torch.Tensor: The classification scores for input samples. + """ + # [N, in_channels, 4, 7, 7] + if self.avg_pool is not None: + x = self.avg_pool(x) + # [N, in_channels, 1, 1, 1] + if self.dropout is not None: + x = self.dropout(x) + # [N, in_channels, 1, 1, 1] + x = x.view(x.shape[0], -1) + # [N, in_channels] + cls_score = self.fc_cls(x) + # [N, num_classes] + return cls_score + + +if __name__ == '__main__': + vit_model = ViT_CLIP( + input_resolution=224, + num_frames=8, + patch_size=16, + width=768, + layers=12, + heads=12, + drop_path_rate=0.1, + num_tadapter=1, + adapter_scale=0.5, + pretrained=True + ) + + i3d_head = I3DHead( + num_classes=2, + in_channels=768, + spatial_type='avg', + dropout_ratio=0.5 + ) + + rand_input = torch.rand(2, 8, 3, 224, 224) + feat = vit_model(rand_input) + print(feat.shape) + output = i3d_head(feat) + print(output.shape) + + +''' +VisionTransformer( + (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False) + (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) + (transformer): Transformer( + (resblocks): Sequential( + (0): ResidualAttentionBlock( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) + (mlp): Sequential( + (c_fc): Linear(in_features=1024, out_features=4096, bias=True) + (gelu): QuickGELU() + (c_proj): Linear(in_features=4096, out_features=1024, bias=True) + ) + (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) + ) + + ...... + + (23): ResidualAttentionBlock( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) + (mlp): Sequential( + (c_fc): Linear(in_features=1024, out_features=4096, bias=True) + (gelu): QuickGELU() + (c_proj): Linear(in_features=4096, out_features=1024, bias=True) + ) + (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) + ) + ) + ) + (ln_post): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) +) +''' + + +''' +class_embedding +positional_embedding +proj +conv1.weight +ln_pre.weight +ln_pre.bias +transformer.resblocks.0.attn.in_proj_weight +transformer.resblocks.0.attn.in_proj_bias +transformer.resblocks.0.attn.out_proj.weight +transformer.resblocks.0.attn.out_proj.bias +transformer.resblocks.0.ln_1.weight +transformer.resblocks.0.ln_1.bias +transformer.resblocks.0.mlp.c_fc.weight +transformer.resblocks.0.mlp.c_fc.bias +transformer.resblocks.0.mlp.c_proj.weight +transformer.resblocks.0.mlp.c_proj.bias +transformer.resblocks.0.ln_2.weight +transformer.resblocks.0.ln_2.bias +...... +transformer.resblocks.23.attn.in_proj_weight +transformer.resblocks.23.attn.in_proj_bias +transformer.resblocks.23.attn.out_proj.weight +transformer.resblocks.23.attn.out_proj.bias +transformer.resblocks.23.ln_1.weight +transformer.resblocks.23.ln_1.bias +transformer.resblocks.23.mlp.c_fc.weight +transformer.resblocks.23.mlp.c_fc.bias +transformer.resblocks.23.mlp.c_proj.weight +transformer.resblocks.23.mlp.c_proj.bias +transformer.resblocks.23.ln_2.weight +transformer.resblocks.23.ln_2.bias +ln_post.weight +ln_post.bias +''' \ No newline at end of file diff --git a/training/detectors/clip_adapter_two_3dconv_detector_cvpr.py b/training/detectors/clip_adapter_two_3dconv_detector_cvpr.py new file mode 100644 index 0000000000000000000000000000000000000000..65821b38bc1e29882e71e9faeab04d8cbff8a801 --- /dev/null +++ b/training/detectors/clip_adapter_two_3dconv_detector_cvpr.py @@ -0,0 +1,514 @@ +from collections import OrderedDict +from typing import Tuple, Union +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +import numpy as np +import torch +from torch import nn, einsum +from einops import rearrange +import math +import torch.nn.functional as F +import clip +from einops import rearrange +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the XceptionDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{rossler2019faceforensics++, + title={Faceforensics++: Learning to detect manipulated facial images}, + author={Rossler, Andreas and Cozzolino, Davide and Verdoliva, Luisa and Riess, Christian and Thies, Justus and Nie{\ss}ner, Matthias}, + booktitle={Proceedings of the IEEE/CVF international conference on computer vision}, + pages={1--11}, + year={2019} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='clip_adapter_two_3dconv') +class CLIPAdapter3DConvDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = I3DHead( + num_classes=2, + in_channels=1024, + spatial_type='avg', + dropout_ratio=0.5 + ) + self.loss_func = self.build_loss(config) + self.prob, self.label = [], [] + self.correct, self.total = 0, 0 + + def build_backbone(self, config): + assert self.config['resolution'] == 224, 'The resolution of the input image should be 224x224' + # assert self.config['clip_size'] == 8, 'The number of frames should be 8' + # prepare the backbone + backbone = ViT_CLIP( + input_resolution=224, + num_frames=self.config['clip_size'], + patch_size=14, + width=1024, + layers=14, + heads=16, + drop_path_rate=0.1, + num_tadapter=1, + adapter_scale=0.5, + pretrained=True + ) + + ## freeze some parameters + for name, param in backbone.named_parameters(): + if 'temporal_embedding' not in name and 'ln_post' not in name and 'cls_head' not in name and 'Adapter' not in name: + param.requires_grad = False + + for name, param in backbone.named_parameters(): + print('{}: {}'.format(name, param.requires_grad)) + num_param = sum(p.numel() for p in backbone.parameters() if p.requires_grad) + num_total_param = sum(p.numel() for p in backbone.parameters()) + print('Number of total parameters: {}, tunable parameters: {}'.format(num_total_param, num_param)) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image']) + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label.long()) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def get_test_metrics(self): + y_pred, y_true = self.video_calculation(self.video_names, self.prob, self.label) + # auc + fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1) + auc = metrics.auc(fpr, tpr) + # eer + fnr = 1 - tpr + eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + # ap + ap = metrics.average_precision_score(y_true,y_pred) + # acc + acc = self.correct / self.total + # reset the prob and label + self.prob, self.label = [], [] + self.correct, self.total = 0, 0 + return {'acc':acc, 'auc':auc, 'eer':eer, 'ap':ap, 'pred':y_pred, 'label':y_true} + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + if inference: + self.prob.append( + pred_dict['prob'] + .detach() + .squeeze() + .cpu() + .numpy() + ) + self.label.append( + data_dict['label'] + .detach() + .squeeze() + .cpu() + .numpy() + ) + # deal with acc + _, prediction_class = torch.max(pred, 1) + correct = (prediction_class == data_dict['label']).sum().item() + self.correct += correct + self.total += data_dict['label'].size(0) + + return pred_dict + + + +class Adapter(nn.Module): + def __init__(self, D_features, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True): + super().__init__() + self.skip_connect = skip_connect + D_hidden_features = int(D_features * mlp_ratio) + self.act = act_layer() + self.D_fc1 = nn.Linear(D_features, D_hidden_features) + self.D_fc2 = nn.Linear(D_hidden_features, D_features) + + def forward(self, x): + # x is (BT, HW+1, D) + xs = self.D_fc1(x) + xs = self.act(xs) + xs = self.D_fc2(xs) + if self.skip_connect: + x = x + xs + else: + x = xs + return x + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerNormProxy(nn.Module): + def __init__(self, dim): + super().__init__() + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + x = rearrange(x, 'b c t h w -> b t h w c') + x = self.norm(x) + x = rearrange(x, 'b t h w c -> b c t h w') + return x + + +class DepthwiseConv3D(nn.Module): + def __init__(self, in_channels, kernel_size): + super().__init__() + self.conv = nn.Conv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels, padding=(kernel_size[0]//2, kernel_size[1]//2, kernel_size[2]//2)) + self.bn = nn.BatchNorm3d(num_features=in_channels) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class ViT_Adapter(nn.Module): + def __init__(self, num_frames=8, in_channels=1024, out_channels=1024): + super().__init__() + self.num_frames=num_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.adapter_channels = int(1024 * 0.5) + + self.down = nn.Linear(in_features=self.in_channels, out_features=self.adapter_channels) + self.gelu1 = nn.GELU() + + self.s_conv = DepthwiseConv3D(in_channels=self.adapter_channels, kernel_size=(1, 3, 3)) + self.t_conv = DepthwiseConv3D(in_channels=self.adapter_channels, kernel_size=(3, 1, 1)) + + self.gelu = nn.GELU() + + self.up = nn.Linear(in_features=self.adapter_channels, out_features=self.out_channels) + self.gelu2 = nn.GELU() + + def forward(self, x): + # hw+1 bt c + n, bt, c = x.shape + H = round(math.sqrt(n - 1)) + x_in = x + + x = self.down(x) + x = self.gelu1(x) + + cls = x[0, :, :].unsqueeze(0) + x = x[1:, :, :] + + x = rearrange(x, '(h w) (b t) c -> b c t h w', t=self.num_frames, h=H) + + # Apply depthwise 3D convolutions + xs = self.s_conv(x) + xt = self.t_conv(x) + + # Fusion of xs and xt + x = (xs + xt) / 2 + x = self.gelu(x) + x = rearrange(x, 'b c t h w -> (h w) (b t) c') + + x = torch.cat([cls, x], dim=0) + + x = self.up(x) + x = self.gelu2(x) + + # residual + x += x_in + return x + + + +class ResidualAttentionBlock(nn.Module): + def __init__(self): + super().__init__() + d_model = 1024 + n_head = 16 + self.ln_1 = LayerNorm(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_2 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.Adapter = ViT_Adapter() + + def attention(self, x): + return self.attn(x, x, x, need_weights=False)[0] + + def forward(self, x): + # x shape [HW+1, BT, C] + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + x = self.Adapter(x) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, num_tadapter=1, scale=1., drop_path=0.1): + super().__init__() + self.width = width + self.layers = layers + dpr = [x.item() for x in torch.linspace(0, drop_path, self.layers)] + self.resblocks = nn.Sequential(*[ResidualAttentionBlock() for i in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class ViT_CLIP(nn.Module): + ## ViT definition in CLIP image encoder + def __init__(self, input_resolution: int, num_frames: int, patch_size: int, width: int, layers: int, heads: int, drop_path_rate, num_tadapter=1, adapter_scale=0.5, pretrained=None): + super().__init__() + self.input_resolution = input_resolution + self.pretrained = pretrained + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.layers = layers + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.num_frames = num_frames + self.temporal_embedding = nn.Parameter(torch.zeros(1, num_frames, width)) + + self.transformer = Transformer(width, layers, heads, num_tadapter=num_tadapter, scale=adapter_scale, drop_path=drop_path_rate) + + self.ln_post = LayerNorm(width) + + self.init_weights() + + + def init_weights(self): + logger.info(f'load model from: {self.pretrained}') + # Load OpenAI CLIP pretrained weights + clip_model, preprocess = clip.load("ViT-L/14", device="cpu") + pretrain_dict = clip_model.visual.state_dict() + del clip_model + del pretrain_dict['proj'] + msg = self.load_state_dict(pretrain_dict, strict=False) + logger.info('Missing keys: {}'.format(msg.missing_keys)) + logger.info('Unexpected keys: {}'.format(msg.unexpected_keys)) + logger.info(f"=> loaded successfully '{self.pretrained}'") + torch.cuda.empty_cache() + # zero-initialize Adapters + for n1, m1 in self.named_modules(): + if 'Adapter' in n1: + for n2, m2 in m1.named_modules(): + if 'up' in n2: + logger.info('init: {}.{}'.format(n1, n2)) + nn.init.constant_(m2.weight, 0) + nn.init.constant_(m2.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed', 'temporal_embedding'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table', 'temporal_position_bias_table'} + + def extract_class_indices(self, labels, which_class): + class_mask = torch.eq(labels, which_class) + class_mask_indices = torch.nonzero(class_mask, as_tuple=False) + return torch.reshape(class_mask_indices, (-1,)) + + def get_feat(self, x): + x = rearrange(x, 'b t c h w -> (b t) c h w') + x = self.conv1(x) + x = x.reshape(x.shape[0], x.shape[1], -1) + x = x.permute(0, 2, 1) + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) + x = x + self.positional_embedding.to(x.dtype) + # n = h*w+1 + n = x.shape[1] + + x = rearrange(x, '(b t) n c -> (b n) t c', t=self.num_frames) + x = x + self.temporal_embedding + x = rearrange(x, '(b n) t c -> (b t) n c', n=n) + + x = self.ln_pre(x) + x = x.permute(1, 0, 2) + x = self.transformer(x) + x = x.permute(1, 0, 2) + x = self.ln_post(x) + return x + + def forward(self, x: torch.Tensor): + B, T, C, H, W = x.shape + x = self.get_feat(x) + + x = x[:, 0] + x = rearrange(x, '(b t) d -> b d t',b=B,t=T) + + x = x.unsqueeze(-1).unsqueeze(-1) # BDTHW for I3D head + + return x + + +class I3DHead(nn.Module): + """Classification head for I3D. + + Args: + num_classes (int): Number of classes to be classified. + in_channels (int): Number of channels in input feature. + Default: dict(type='CrossEntropyLoss') + spatial_type (str): Pooling type in spatial dimension. Default: 'avg'. + dropout_ratio (float): Probability of dropout layer. Default: 0.5. + kwargs (dict, optional): Any keyword argument to be used to initialize + the head. + """ + + def __init__(self, + num_classes, + in_channels, + spatial_type='avg', + dropout_ratio=0.5, + **kwargs): + super().__init__() + self.num_classes = num_classes + self.in_channels = in_channels + self.spatial_type = spatial_type + self.dropout_ratio = dropout_ratio + if self.dropout_ratio != 0: + self.dropout = nn.Dropout(p=self.dropout_ratio) + else: + self.dropout = None + self.fc_cls = nn.Linear(self.in_channels, self.num_classes) + + if self.spatial_type == 'avg': + # use `nn.AdaptiveAvgPool3d` to adaptively match the in_channels. + self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) + else: + self.avg_pool = None + + def forward(self, x): + """Defines the computation performed at every call. + + Args: + x (torch.Tensor): The input data. + + Returns: + torch.Tensor: The classification scores for input samples. + """ + # [N, in_channels, 4, 7, 7] + if self.avg_pool is not None: + x = self.avg_pool(x) + # [N, in_channels, 1, 1, 1] + if self.dropout is not None: + x = self.dropout(x) + # [N, in_channels, 1, 1, 1] + x = x.view(x.shape[0], -1) + # [N, in_channels] + cls_score = self.fc_cls(x) + # [N, num_classes] + return cls_score + + +if __name__ == '__main__': + vit_model = ViT_CLIP( + input_resolution=224, + num_frames=8, + patch_size=16, + width=768, + layers=12, + heads=12, + drop_path_rate=0.1, + num_tadapter=1, + adapter_scale=0.5, + pretrained=True + ) + + i3d_head = I3DHead( + num_classes=2, + in_channels=768, + spatial_type='avg', + dropout_ratio=0.5 + ) + + rand_input = torch.rand(2, 8, 3, 224, 224) + feat = vit_model(rand_input) + print(feat.shape) + output = i3d_head(feat) + print(output.shape) + + diff --git a/training/detectors/clip_base_fft_detector.py b/training/detectors/clip_base_fft_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..321c51f574839e1d16a7eee1aa06788f7d951e75 --- /dev/null +++ b/training/detectors/clip_base_fft_detector.py @@ -0,0 +1,120 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_base_fft') +class CLIP_Base_FFT_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(768, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone = get_clip_visual(model_name=config['pretrained']) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image'])['pooler_output'] + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model diff --git a/training/detectors/clip_base_vid_detector.py b/training/detectors/clip_base_vid_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..044a7a293368ad114d99e82af53796d52a936ac2 --- /dev/null +++ b/training/detectors/clip_base_vid_detector.py @@ -0,0 +1,167 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_base_vid') +class CLIP_Base_Vid_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(768, 2) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone = get_clip_visual(model_name=config['pretrained']) + backbone = to_lora(backbone, r=16) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + B, T, C, H, W = data_dict['image'].shape + feat = self.backbone(data_dict['image'].view(B * T, C, H, W))['pooler_output'] + feat = feat.view(B, T, feat.shape[-1]).mean(1) # Temporal avg + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model + + +def to_lora(model, target=[nn.Linear, nn.Conv2d, nn.Embedding], r=16, f_class=None, layers=None, names=None): + + for n, m in model.named_modules(): + if f_class is not None and not isinstance(m, f_class): + continue + if isinstance(m, nn.Sequential) or isinstance(m, nn.ModuleList): + + for name, mod in m.named_children(): + + # print(name, mod) + if isinstance(mod, nn.Linear) and not isinstance(mod, lora.Linear): + mod = change_mod(mod, r=r) + m._modules[name] = mod + else: + if layers is None or any(['layers.' + str(i) in n for i in layers]): + for name, mod in m.named_children(): + # if 'self_attn' in f_name: + # print(name, mod) + if isinstance(mod, nn.Linear) and not isinstance(mod, lora.Linear): + if names is None or any(na in name for na in names): + mod = change_mod(mod, r=r) + setattr(m, name, mod) + + lora.mark_only_lora_as_trainable(model) + return model + + +def change_mod(m, targets=[nn.Linear, nn.Conv2d, nn.Embedding], r=16): + st_dict = m.state_dict() + + if nn.Linear in targets and isinstance(m, nn.Linear): + dtype = m.weight.dtype + new_m = lora.Linear(m.in_features, m.out_features, bias=m.bias is not None, r=r, dtype=dtype) + new_m.load_state_dict(st_dict, strict=False) + # print(new_m) + m = new_m + elif nn.Conv2d in targets and isinstance(m, nn.Conv2d): + new_m = lora.Conv2d(m.in_channels, m.out_channels, m.kernel_size, stride=m.stride, padding=m.padding, \ + dilation=m.dilation, transposed=m.transposed, output_padding=m.output_padding, groups=m.groups, bias=m.bias, r=r) + new_m.load_state_dict(st_dict, strict=False) + m = new_m + elif nn.Embedding in targets and isinstance(m, nn.Embedding): + new_m = lora.Embedding(m.num_embeddings, m.embedding_dim, padding_idx=m.padding_idx, max_norm=m.max_norm, norm_type=m.norm_type, \ + scale_grad_by_freq=m.scale_grad_by_freq, freeze=m.freeze, sparse=m.sparse, r=r) + new_m.load_state_dict(st_dict, strict=False) + m = new_m + + return m diff --git a/training/detectors/clip_contrast_detector.py b/training/detectors/clip_contrast_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..8a4e0027d84f26e195db597a5ad9bffc50f72c60 --- /dev/null +++ b/training/detectors/clip_contrast_detector.py @@ -0,0 +1,111 @@ +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_contrast') +class CLIP_Contrast(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.num_classes = config['backbone_config']['num_classes'] # New: unify the class-count variable + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, self.num_classes) + + # 1. original classification loss + self.cls_loss_func = self.build_loss(config) + # 2. New: SupCon contrastive loss (kept consistent with DINOv2) + supconloss_class = LOSSFUNC["supcon"] + self.supcon_loss_func = supconloss_class() + + def build_backbone(self, config): + # Keep the original logic unchanged + _, backbone = get_clip_visual(model_name=config['pretrained']) + return backbone + + def build_loss(self, config): + # Keep the original logic unchanged + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + # Keep the original logic unchanged + feat = self.backbone(data_dict['image'])['pooler_output'] + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + # Keep the original logic unchanged + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + """Core change: combine classification loss with SupCon contrastive loss""" + # 1. original classification loss (main loss) + label = data_dict['label'] + pred = pred_dict['cls'] + cls_loss = self.cls_loss_func(pred, label) + + # 2. New: SupCon contrastive loss (single-view features, adapted to CLIP 1024-dimensional features) + feat = pred_dict['feat'] # [B, 1024] → SupCon loss automatically handles dimensions + supcon_loss = self.supcon_loss_func(feat, label) + + # 3. Total loss:classification-dominated + contrastive auxiliary loss (configurable weight) + contrast_weight = self.config.get('contrast_weight', 0.1) # default 0.1, consistent with DINOv2 + total_loss = cls_loss + contrast_weight * supcon_loss + + # 4. Loss dictionary: keep the original keys and add detailed loss items + loss_dict = { + 'overall': total_loss, + 'cls_loss': cls_loss, # original classification loss + 'supcon_loss': supcon_loss # new contrastive loss + } + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + # Keep the original logic unchanged + label = data_dict['label'] + pred = pred_dict['cls'] + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.num_classes) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # Keep the original logic unchanged(feat is already included in pred_dict for use by the contrastive loss) + features = self.features(data_dict) + pred = self.classifier(features) + prob = torch.softmax(pred, dim=1) + + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + # Keep the original logic unchanged + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model \ No newline at end of file diff --git a/training/detectors/clip_detector.py b/training/detectors/clip_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..a259aa4af4cb4adbe5aa2b51291e6c5c6907c445 --- /dev/null +++ b/training/detectors/clip_detector.py @@ -0,0 +1,115 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip') +class CLIPDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(768, 2) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone = get_clip_visual(model_name=config['pretrained']) + return backbone + + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image'])['pooler_output'] + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model diff --git a/training/detectors/clip_hier_contrast_detector.py b/training/detectors/clip_hier_contrast_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..9b4e2415a49b3f37cdb900523e2579e00020333d --- /dev/null +++ b/training/detectors/clip_hier_contrast_detector.py @@ -0,0 +1,156 @@ +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + +# Original label -> grouped label +group_map = { + 1: 0, # FF-DF_Fake -> Face Swap + 2: 0, # FF-F2F_Fake -> Face Swap + 3: 0, # FF-NT_Fake -> Face Swap + 4: 0, # FF-FH_Fake -> Face Swap + 5: 0, # fsgan_Fake -> Face Swap + 6: 0, # faceswap_Fake -> Face Swap + 7: 0, # inswap_Fake -> Face Swap + 8: 0, # simswap_Fake -> Face Swap + 9: 0, # blendface_Fake -> Face Swap + 10: 0, # uniface_Fake -> Face Swap + 11: 0, # e4s_Fake -> Face Swap + 12: 0, # facedancer_Fake -> Face Swap + 13: 0, # mobileswap_Fake -> Face Swap + + 14: 1, # sadtalker_Fake -> Reenactment (Audio-driven) + 15: 1, # wav2lip_Fake -> Reenactment (Audio-driven) + 16: 1, # fomm_Fake -> Reenactment (Image-driven) + 17: 1, # MRAA_Fake -> Reenactment (Image-driven) + 18: 1, # one_shot_free_Fake -> Reenactment (Image-driven, OneShot) + 19: 1, # pirender_Fake -> Reenactment (Image-driven) + 20: 1, # tpsm_Fake -> Reenactment (Image-driven, TPSMM) + 21: 1, # lia_Fake -> Reenactment (Image-driven, LIA) + 22: 1, # danet_Fake -> Reenactment (Image-driven, DaNet) + 23: 1, # mcnet_Fake -> Reenactment (Image-driven, MCNet) + 24: 1, # hyperreenact_Fake -> Reenactment + 25: 1, # facevid2vid_Fake -> Reenactment (Landmark-driven FS_vid2vid) + + 26: 2, # VQGAN_Fake -> Entire Face Synthesis (GAN based) + 27: 2, # StyleGAN3_Fake -> Entire Face Synthesis (GAN based) + 28: 2, # StyleGANXL_Fake -> Entire Face Synthesis (GAN based) + 29: 2, # ddim_Fake -> Entire Face Synthesis (Latent Diffusion) + 30: 2, # sd2_1_Fake -> Entire Face Synthesis (Latent Diffusion) + 31: 2, # rddm_Fake -> Entire Face Synthesis (Latent Diffusion) + 32: 2, # pixart_Fake -> Entire Face Synthesis (Latent Diffusion) + 33: 2, # DiT_Fake -> Entire Face Synthesis (Latent Diffusion) + 34: 2, # SiT_Fake -> Entire Face Synthesis (Latent Diffusion) + + 35: 3, # e4e_Fake -> Face Edit (StyleGAN based) +} + +@DETECTOR.register_module(module_name='clip_contrast_hier') +class CLIP_Contrast_HIER(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.num_classes = config['backbone_config']['num_classes'] # New: unify the class-count variable + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, self.num_classes) + + # 1. original classification loss + self.cls_loss_func = self.build_loss(config) + # 2. New: SupCon contrastive loss (kept consistent with DINOv2) + supconloss_class = LOSSFUNC["supcon"] + self.supcon_loss_func = supconloss_class() + + def build_backbone(self, config): + # Keep the original logic unchanged + _, backbone = get_clip_visual(model_name=config['pretrained']) + return backbone + + def build_loss(self, config): + # Keep the original logic unchanged + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + # Keep the original logic unchanged + feat = self.backbone(data_dict['image'])['pooler_output'] + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + # Keep the original logic unchanged + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + """Core change: combine classification loss with SupCon contrastive loss""" + # 1. original classification loss (main loss) + label = data_dict['label'] + pred = pred_dict['cls'] + cls_loss = self.cls_loss_func(pred, label) + group_label = label.clone() + for k, v in group_map.items(): + group_label[label == k] = v+1 + + # 2. New: SupCon contrastive loss (single-view features, adapted to CLIP 1024-dimensional features) + feat = pred_dict['feat'] # [B, 1024] → SupCon loss automatically handles dimensions + supcon_loss = self.supcon_loss_func(feat, label) + supcon_loss_hier=self.supcon_loss_func(feat,group_label) + # 3. Total loss:classification-dominated + contrastive auxiliary loss (configurable weight) + contrast_weight = self.config.get('contrast_weight', 0.1) # default 0.1, consistent with DINOv2 + contrast_weight=0.1 + total_loss = cls_loss + contrast_weight * (supcon_loss_hier) + + # 4. Loss dictionary: keep the original keys and add detailed loss items + loss_dict = { + 'overall': total_loss, + 'cls_loss': cls_loss, # original classification loss + 'supcon_loss': supcon_loss # new contrastive loss + } + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + # Keep the original logic unchanged + label = data_dict['label'] + pred = pred_dict['cls'] + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.num_classes) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # Keep the original logic unchanged(feat is already included in pred_dict for use by the contrastive loss) + features = self.features(data_dict) + pred = self.classifier(features) + prob = torch.softmax(pred, dim=1) + + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + # Keep the original logic unchanged + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model \ No newline at end of file diff --git a/training/detectors/clip_image_detector.py b/training/detectors/clip_image_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..556b140a045ac80bb46a6cf30dba42ec9d53dc86 --- /dev/null +++ b/training/detectors/clip_image_detector.py @@ -0,0 +1,99 @@ +import os +import logging +import numpy as np +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import clip + +from metrics.base_metrics_class import calculate_acc_for_train +from .base_detector import AbstractDetector +from detectors import DETECTOR +from loss import LOSSFUNC + +logger = logging.getLogger(__name__) + +# Linear classifier module (adapted to the 512-dimensional CLIP image embedding) +class LinearClassifier(nn.Module): + def __init__(self, input_size: int, hidden_size_list: List[int], num_classes: int): + super(LinearClassifier, self).__init__() + self.dropout = nn.Dropout(0.5) + self.fc1 = nn.Linear(input_size, hidden_size_list[0]) + self.fc2 = nn.Linear(hidden_size_list[0], hidden_size_list[1]) + self.fc3 = nn.Linear(hidden_size_list[1], num_classes) + + def forward(self, x: torch.tensor) -> torch.tensor: + out = self.fc1(x) + out = F.relu(out) + out = self.dropout(out) + out = self.fc2(out) + out = F.relu(out) + out = self.fc3(out) + return out + +@DETECTOR.register_module(module_name='clip_image') +class CLIPImageDetector(AbstractDetector): + def __init__(self, config): + super().__init__(config) # keep the parent `__init__` signature aligned + self.config = config + self.device = torch.device("cuda" if config['cuda'] else "cpu") + self.backbone = self.build_backbone(config) + self.classifier_module = self.build_classifier(config) # avoid naming conflicts with the parent method + self.loss_func = self.build_loss(config) + + def build_backbone(self, config) -> clip.model.CLIP: + # Load only the CLIP model and use its image encoder. + clip_model, _ = clip.load(config['clip_model_name'], device=self.device) + logger.info(f"Loaded CLIP image encoder: {config['clip_model_name']}") + return clip_model + + def build_loss(self, config) -> nn.CrossEntropyLoss: + # Implement the abstract parent method to build the multi-classification loss. + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func.to(self.device) + + def features(self, data_dict: dict) -> torch.tensor: + # Implement the abstract parent method to extract image features only, without text logic. + images = data_dict['image'].to(self.device) + # Extract the 512-dimensional CLIP image embedding. + with torch.no_grad(): + image_emb = self.backbone.encode_image(images) + return image_emb.float() + + def classifier(self, features: torch.tensor) -> torch.tensor: + # Implement the abstract parent method for feature classification. + return self.classifier_module(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + # Implement the abstract parent method to compute the loss. + labels = data_dict['label'].to(self.device) + preds = pred_dict['cls'] + loss = self.loss_func(preds, labels) + return {'overall': loss} + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + # Implement the abstract parent method to compute training metrics for multi-class classification. + labels = data_dict['label'].detach().cpu() + preds = pred_dict['cls'].detach().cpu() + num_classes = self.config['classifier_config']['num_classes'] + acc, mAP = calculate_acc_for_train(labels, preds, num_classes) + return {'acc': acc, 'mAP': mAP} + + def forward(self, data_dict: dict, inference=False) -> dict: + # Implement the abstract parent method for forward propagation. + features = self.features(data_dict) + cls_pred = self.classifier(features) + prob = torch.softmax(cls_pred, dim=1) + return {'cls': cls_pred, 'prob': prob, 'feat': features} + + def build_classifier(self, config) -> LinearClassifier: + # Helper method to build the linear classifier (input dimension = 512-dimensional CLIP image embedding). + input_size = 512 # the CLIP ViT-B/32 image embedding dimension is fixed at 512 + hidden_size_list = config['classifier_config']['hidden_size_list'] + num_classes = config['classifier_config']['num_classes'] + classifier = LinearClassifier(input_size, hidden_size_list, num_classes) + return classifier.to(self.device) + diff --git a/training/detectors/clip_large_detector.py b/training/detectors/clip_large_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..fbd7d160de8a9553bb61ac368297ff9303da4e8e --- /dev/null +++ b/training/detectors/clip_large_detector.py @@ -0,0 +1,129 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large') +class CLIP_Large_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + # New: freeze all parameters in the CLIP backbone + self.freeze_backbone() + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def freeze_backbone(self): + """Freeze all parameters of the CLIP visual backbone to disable gradient updates""" + for param in self.backbone.parameters(): + param.requires_grad = False + # Verify the freezing effect (optional, for debugging) + logger.info("CLIP backbone has been frozen successfully!") + + def build_backbone(self, config): + # prepare the backbone + _, backbone = get_clip_visual(model_name=config['pretrained']) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image'])['pooler_output'] + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model \ No newline at end of file diff --git a/training/detectors/clip_large_fft_detector.py b/training/detectors/clip_large_fft_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..6042d91df1aa1be8265408a6efb727a7a5c51b74 --- /dev/null +++ b/training/detectors/clip_large_fft_detector.py @@ -0,0 +1,120 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large_fft') +class CLIP_Large_FFT_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone = get_clip_visual(model_name=config['pretrained']) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image'])['pooler_output'] + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model diff --git a/training/detectors/clip_large_fft_dino_orth_detector.py b/training/detectors/clip_large_fft_dino_orth_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..88ce7010de1ba0e59342992f4c59863f7464e0cb --- /dev/null +++ b/training/detectors/clip_large_fft_dino_orth_detector.py @@ -0,0 +1,175 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large_fft_dino_orth') +class CLIP_Large_FFT_Dino_Orth_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone_s, self.backbone_f = self.build_backbone(config) + self.semantic_proj = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)) + self.finger_proj = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone_s = get_clip_visual(model_name=config['pretrained']) # frozen + backbone_f = get_clip_visual(model_name=config['pretrained']) # trainable + + for param in backbone_s.parameters(): + param.requires_grad = False + + return backbone_s, backbone_f + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def project_and_decompose(self, a, b, eps=1e-8): + """ + Decompose the 1024-dimensional vector a into the projection component along b and the orthogonal component perpendicular to b + + Args: + a: a 1024-dimensional vector with shape (..., 1024), supporting batched processing (e.g. (batch_size, 1024)) + b: a 1024-dimensional vector whose shape must match a (with the last dimension being 1024) + eps: a small value to prevent division by zero + + Returns: + proj: projection component along b, with the same shape as the input + ortho: orthogonal component perpendicular to b, with the same shape as the input + """ + # Compute the squared norm of b (..., 1), keeping the last dimension for broadcasting + b_norm_sq = torch.sum(b **2, dim=-1, keepdim=True) + eps + + # Compute the dot product of a and b (..., 1) + a_dot_b = torch.sum(a * b, dim=-1, keepdim=True) + + # Projection coefficient = dot product / squared norm of b (..., 1) + proj_coeff = a_dot_b / b_norm_sq + + # Projection component along b = coefficient * b + proj = proj_coeff * b # Use broadcasting to multiply the coefficients with b element-wise + + # Orthogonal component = original vector a - projection component + ortho = a - proj + + return proj, ortho + + def features(self, data_dict: dict) -> torch.tensor: + feat_clip = self.backbone_s(data_dict['image']) + feat_all = self.backbone_f(data_dict['image']) + + #### Disentanglement operations + # 1. Semantic projection disentanglement + feat_s = self.semantic_proj(feat_all) + # 2. Project onto the semantic direction and obtain the orthogonal component + proj, ortho = self.project_and_decompose(feat_all, feat_s) + # 3. Obtain the fingerprint feature + feat_f = self.finger_proj(ortho) + + return feat_clip, feat_s, feat_f, feat_all + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss_ce, loss_orth, loss_mse = self.loss_func(pred, label, pred_dict['feat_clip'], pred_dict['feat_s'], pred_dict['feat']) + loss_dict = { + 'overall': loss_ce + loss_orth + loss_mse, + 'loss_ce': loss_ce, + 'loss_orth': loss_orth, + 'loss_mse': loss_mse, + } + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + feat_clip, feat_s, feat_f, feat_all = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(feat_f) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f} + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f + feat_s} # Alternative retrieval mode: use f+s features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_all} # Alternative retrieval mode: use all features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_all + feat_f} # Alternative retrieval mode: use all + f features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f + feat_s * 0.1} # Alternative retrieval mode: use all + f features for retrieval + return pred_dict + + +def get_clip_visual(model_name = 'dinov2_vitl14'): + dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', model_name) + return dinov2_vitl14 diff --git a/training/detectors/clip_large_fft_dis_cat1_detector.py b/training/detectors/clip_large_fft_dis_cat1_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..1f6004de5a219693b35df931b3329e4c33e56f4b --- /dev/null +++ b/training/detectors/clip_large_fft_dis_cat1_detector.py @@ -0,0 +1,249 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large_fft_dis_cat1') +class CLIP_Large_FFT_Dis_Cat1_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone_clip_s, self.backbone_f, self.backbone_dino_s = self.build_backbone(config) + # CLIP branch + self.semantic_proj_clip = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)) + self.finger_proj_clip = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)) + # DINO branch + self.semantic_proj_dino = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)) + self.finger_proj_dino = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)) + # classification projection + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone_clip_s = get_clip_visual(model_name=config['pretrained']) # frozen + _, backbone_f = get_clip_visual(model_name=config['pretrained']) # trainable + backbone_dino_s = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14') # frozen + + for param in backbone_clip_s.parameters(): + param.requires_grad = False + for param in backbone_dino_s.parameters(): + param.requires_grad = False + + return backbone_clip_s, backbone_f, backbone_dino_s + + def build_loss(self, config): + # prepare the loss function + loss_func = CLIP_DINO_Loss() + return loss_func + + def project_and_decompose(self, a, b, eps=1e-8): + """ + Decompose the 1024-dimensional vector a into the projection component along b and the orthogonal component perpendicular to b + + Args: + a: a 1024-dimensional vector with shape (..., 1024), supporting batched processing (e.g. (batch_size, 1024)) + b: a 1024-dimensional vector whose shape must match a (with the last dimension being 1024) + eps: a small value to prevent division by zero + + Returns: + proj: projection component along b, with the same shape as the input + ortho: orthogonal component perpendicular to b, with the same shape as the input + """ + # Compute the squared norm of b (..., 1), keeping the last dimension for broadcasting + b_norm_sq = torch.sum(b **2, dim=-1, keepdim=True) + eps + + # Compute the dot product of a and b (..., 1) + a_dot_b = torch.sum(a * b, dim=-1, keepdim=True) + + # Projection coefficient = dot product / squared norm of b (..., 1) + proj_coeff = a_dot_b / b_norm_sq + + # Projection component along b = coefficient * b + proj = proj_coeff * b # Use broadcasting to multiply the coefficients with b element-wise + + # Orthogonal component = original vector a - projection component + ortho = a - proj + + return proj, ortho + + def features(self, data_dict: dict) -> torch.tensor: + feat_clip = self.backbone_clip_s(data_dict['image'])['pooler_output'] + feat_all = self.backbone_f(data_dict['image'])['pooler_output'] + feat_dino = self.backbone_dino_s(data_dict['image']) + + #### Disentanglement operations + # 1. Semantic projection disentanglement + feat_clip_s = self.semantic_proj_clip(feat_all) + # 2. Project onto the semantic direction and obtain the orthogonal component + proj, ortho_clip = self.project_and_decompose(feat_all, feat_clip_s) + # 3. Obtain the fingerprint feature + feat_clip_f = self.finger_proj_clip(ortho_clip) + + #### Disentanglement operations + # 1. Semantic projection disentanglement + feat_dino_s = self.semantic_proj_dino(feat_all) + # 2. Project onto the semantic direction and obtain the orthogonal component + proj, ortho_dino = self.project_and_decompose(feat_all, feat_dino_s) + # 3. Obtain the fingerprint feature + feat_dino_f = self.finger_proj_dino(ortho_dino) + + return feat_clip, feat_dino, feat_clip_s, feat_dino_s, feat_clip_f, feat_dino_f, feat_all + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + loss_dict = self.loss_func(pred_dict, label) + + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + feat_clip, feat_dino, feat_clip_s, feat_dino_s, feat_clip_f, feat_dino_f, feat_all = self.features(data_dict) + # get the prediction by classifier + feat_f = feat_clip_f + feat_dino_f + pred = self.classifier(feat_f) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = { + 'cls': pred, + 'prob': prob, + 'feat_clip': feat_clip, + 'feat_dino': feat_dino, + 'feat_clip_s': feat_clip_s, + 'feat_dino_s': feat_dino_s, + 'feat_clip_f': feat_clip_f, + 'feat_dino_f': feat_dino_f, + 'feat': feat_f, + 'feat_all': feat_all, + } + + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model + + +class OrthogonalLoss(nn.Module): + """Orthogonal loss: minimize the squared normalized dot product of two features""" + def __init__(self, eps=1e-8): + super().__init__() + self.eps = eps + + def forward(self, feat1, feat2): + assert feat1.shape == feat2.shape, "Feature shapes must match" + dot_product = torch.sum(feat1 * feat2, dim=1, keepdim=True) # dot product + norm1 = torch.norm(feat1, dim=1, keepdim=True) + self.eps # norm of feature 1 + norm2 = torch.norm(feat2, dim=1, keepdim=True) + self.eps # norm of feature 2 + normalized_dot = dot_product / (norm1 * norm2) # normalized dot product + return torch.mean(normalized_dot **2) # minimize the squared value (target is 0) + + +class CLIP_DINO_Loss(nn.Module): + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + self.loss_mse = nn.MSELoss() + self.loss_orth = OrthogonalLoss() + + def forward(self, pred_dict, targets): + """ + Computes the cross-entropy loss. + + Args: + inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores. + targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices. + + Returns: + A scalar tensor representing the cross-entropy loss. + """ + # Compute the cross-entropy loss + loss_ce = self.loss_fn(pred_dict['cls'], targets) + # + loss_orth = ( + self.loss_orth(pred_dict['feat_clip_s'], pred_dict['feat_clip_f']) + \ + self.loss_orth(pred_dict['feat_dino_s'], pred_dict['feat_dino_f']) + ) / 2 + # + loss_mse = ( + self.loss_mse(pred_dict['feat_clip_s'], pred_dict['feat_clip']) + \ + self.loss_mse(pred_dict['feat_dino_s'], pred_dict['feat_dino']) + ) / 2 + + return { + "overall": loss_ce + loss_orth + loss_mse, + "loss_ce": loss_ce, + "loss_orth": loss_orth, + "loss_mse": loss_mse # monitor the effect of semantic regularization + } diff --git a/training/detectors/clip_large_fft_dis_cat2_detector.py b/training/detectors/clip_large_fft_dis_cat2_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..0ca9c64cd4678d3e3763c9635e8ba36f12e959d8 --- /dev/null +++ b/training/detectors/clip_large_fft_dis_cat2_detector.py @@ -0,0 +1,249 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large_fft_dis_cat2') +class CLIP_Large_FFT_Dis_Cat2_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone_clip_s, self.backbone_f, self.backbone_dino_s = self.build_backbone(config) + # CLIP branch + self.semantic_proj_clip = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)) + self.finger_proj_clip = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)) + # DINO branch + self.semantic_proj_dino = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)) + self.finger_proj_dino = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)) + # classification projection + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone_clip_s = get_clip_visual(model_name=config['pretrained']) # frozen + backbone_f = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14') # trainable + backbone_dino_s = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14') # frozen + + for param in backbone_clip_s.parameters(): + param.requires_grad = False + for param in backbone_dino_s.parameters(): + param.requires_grad = False + + return backbone_clip_s, backbone_f, backbone_dino_s + + def build_loss(self, config): + # prepare the loss function + loss_func = CLIP_DINO_Loss() + return loss_func + + def project_and_decompose(self, a, b, eps=1e-8): + """ + Decompose the 1024-dimensional vector a into the projection component along b and the orthogonal component perpendicular to b + + Args: + a: a 1024-dimensional vector with shape (..., 1024), supporting batched processing (e.g. (batch_size, 1024)) + b: a 1024-dimensional vector whose shape must match a (with the last dimension being 1024) + eps: a small value to prevent division by zero + + Returns: + proj: projection component along b, with the same shape as the input + ortho: orthogonal component perpendicular to b, with the same shape as the input + """ + # Compute the squared norm of b (..., 1), keeping the last dimension for broadcasting + b_norm_sq = torch.sum(b **2, dim=-1, keepdim=True) + eps + + # Compute the dot product of a and b (..., 1) + a_dot_b = torch.sum(a * b, dim=-1, keepdim=True) + + # Projection coefficient = dot product / squared norm of b (..., 1) + proj_coeff = a_dot_b / b_norm_sq + + # Projection component along b = coefficient * b + proj = proj_coeff * b # Use broadcasting to multiply the coefficients with b element-wise + + # Orthogonal component = original vector a - projection component + ortho = a - proj + + return proj, ortho + + def features(self, data_dict: dict) -> torch.tensor: + feat_clip = self.backbone_clip_s(data_dict['image'])['pooler_output'] + feat_all = self.backbone_f(data_dict['image']) + feat_dino = self.backbone_dino_s(data_dict['image']) + + #### Disentanglement operations + # 1. Semantic projection disentanglement + feat_clip_s = self.semantic_proj_clip(feat_all) + # 2. Project onto the semantic direction and obtain the orthogonal component + proj, ortho_clip = self.project_and_decompose(feat_all, feat_clip_s) + # 3. Obtain the fingerprint feature + feat_clip_f = self.finger_proj_clip(ortho_clip) + + #### Disentanglement operations + # 1. Semantic projection disentanglement + feat_dino_s = self.semantic_proj_dino(feat_all) + # 2. Project onto the semantic direction and obtain the orthogonal component + proj, ortho_dino = self.project_and_decompose(feat_all, feat_dino_s) + # 3. Obtain the fingerprint feature + feat_dino_f = self.finger_proj_dino(ortho_dino) + + return feat_clip, feat_dino, feat_clip_s, feat_dino_s, feat_clip_f, feat_dino_f, feat_all + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + loss_dict = self.loss_func(pred_dict, label) + + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + feat_clip, feat_dino, feat_clip_s, feat_dino_s, feat_clip_f, feat_dino_f, feat_all = self.features(data_dict) + # get the prediction by classifier + feat_f = feat_clip_f + feat_dino_f + pred = self.classifier(feat_f) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = { + 'cls': pred, + 'prob': prob, + 'feat_clip': feat_clip, + 'feat_dino': feat_dino, + 'feat_clip_s': feat_clip_s, + 'feat_dino_s': feat_dino_s, + 'feat_clip_f': feat_clip_f, + 'feat_dino_f': feat_dino_f, + 'feat': feat_f, + 'feat_all': feat_all, + } + + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model + + +class OrthogonalLoss(nn.Module): + """Orthogonal loss: minimize the squared normalized dot product of two features""" + def __init__(self, eps=1e-8): + super().__init__() + self.eps = eps + + def forward(self, feat1, feat2): + assert feat1.shape == feat2.shape, "Feature shapes must match" + dot_product = torch.sum(feat1 * feat2, dim=1, keepdim=True) # dot product + norm1 = torch.norm(feat1, dim=1, keepdim=True) + self.eps # norm of feature 1 + norm2 = torch.norm(feat2, dim=1, keepdim=True) + self.eps # norm of feature 2 + normalized_dot = dot_product / (norm1 * norm2) # normalized dot product + return torch.mean(normalized_dot **2) # minimize the squared value (target is 0) + + +class CLIP_DINO_Loss(nn.Module): + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + self.loss_mse = nn.MSELoss() + self.loss_orth = OrthogonalLoss() + + def forward(self, pred_dict, targets): + """ + Computes the cross-entropy loss. + + Args: + inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores. + targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices. + + Returns: + A scalar tensor representing the cross-entropy loss. + """ + # Compute the cross-entropy loss + loss_ce = self.loss_fn(pred_dict['cls'], targets) + # + loss_orth = ( + self.loss_orth(pred_dict['feat_clip_s'], pred_dict['feat_clip_f']) + \ + self.loss_orth(pred_dict['feat_dino_s'], pred_dict['feat_dino_f']) + ) / 2 + # + loss_mse = ( + self.loss_mse(pred_dict['feat_clip_s'], pred_dict['feat_clip']) + \ + self.loss_mse(pred_dict['feat_dino_s'], pred_dict['feat_dino']) + ) / 2 * 0.1 + + return { + "overall": loss_ce + loss_orth + loss_mse, + "loss_ce": loss_ce, + "loss_orth": loss_orth, + "loss_mse": loss_mse # monitor the effect of semantic regularization + } diff --git a/training/detectors/clip_large_fft_dis_detector.py b/training/detectors/clip_large_fft_dis_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..55b5dc0539fcca049e04895b8d6f9e8d22e45068 --- /dev/null +++ b/training/detectors/clip_large_fft_dis_detector.py @@ -0,0 +1,162 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large_fft_dis') +class CLIP_Large_FFT_Dis_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone_s, self.backbone_f = self.build_backbone(config) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone_s = get_clip_visual(model_name=config['pretrained']) # frozen + _, backbone_f = get_clip_visual(model_name=config['pretrained']) # trainable + + for param in backbone_s.parameters(): + param.requires_grad = False + + return backbone_s, backbone_f + + def parameters(self, recurse: bool = True): + """ + Rewrite the parameter iteration method to return trainable parameters and display full module paths + """ + print("="*70) + print("Trainable parameter list (with module paths):") + print("-"*70) + + # Helper function: recursively obtain all parameters of a module and their full paths + def get_named_params(module, parent_name=""): + named_params = [] + for name, param in module.named_parameters(recurse=recurse): + # Build the full path (parent module name + current parameter name) + full_name = f"{parent_name}.{name}" if parent_name else name + named_params.append((full_name, param)) + return named_params + + # 1. Get parameters of backbone_f and their paths + backbone_f_params = get_named_params(self.backbone_f, parent_name="backbone_f") + for i, (full_name, param) in enumerate(backbone_f_params, 1): + print(f"Parameter {i}: {full_name} | Shape: {param.shape}") + yield param + + # 2. Get parameters of the head and their paths + head_params = get_named_params(self.head, parent_name="head") + for i, (full_name, param) in enumerate(head_params, 1): + print(f"Parameter {i}: {full_name} | Shape: {param.shape}") + yield param + + # Count the total number + total = len(backbone_f_params) + len(head_params) + print("-"*70) + print(f"Total number of trainable parameters: {total}") + print("="*70) + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat_s = self.backbone_s(data_dict['image'])['pooler_output'] + feat_f = self.backbone_f(data_dict['image'])['pooler_output'] + return feat_s, feat_f + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + feat_s, feat_f = self.features(data_dict) + features = feat_f - feat_s # the simplest disentanglement method: directly perform subtraction + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model diff --git a/training/detectors/clip_large_fft_dis_orth1_detector.py b/training/detectors/clip_large_fft_dis_orth1_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..0389860053636c104bfde1b55c26f2b65525ef70 --- /dev/null +++ b/training/detectors/clip_large_fft_dis_orth1_detector.py @@ -0,0 +1,180 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large_fft_dis_orth1') +class CLIP_Large_FFT_Dis_Orth1_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone_s, self.backbone_f = self.build_backbone(config) + self.semantic_proj = nn.Linear(1024, 1024) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone_s = get_clip_visual(model_name=config['pretrained']) # frozen + _, backbone_f = get_clip_visual(model_name=config['pretrained']) # trainable + + for param in backbone_s.parameters(): + param.requires_grad = False + + return backbone_s, backbone_f + + def parameters(self, recurse: bool = True): + """ + Rewrite the parameter iteration method to return trainable parameters and display full module paths + """ + print("="*70) + print("Trainable parameter list (with module paths):") + print("-"*70) + + # Helper function: recursively obtain all parameters of a module and their full paths + def get_named_params(module, parent_name=""): + named_params = [] + for name, param in module.named_parameters(recurse=recurse): + # Build the full path (parent module name + current parameter name) + full_name = f"{parent_name}.{name}" if parent_name else name + named_params.append((full_name, param)) + return named_params + + # 1. Get parameters of backbone_f and their paths + backbone_f_params = get_named_params(self.backbone_f, parent_name="backbone_f") + for i, (full_name, param) in enumerate(backbone_f_params, 1): + print(f"Parameter {i}: {full_name} | Shape: {param.shape}") + yield param + + # 2. Get parameters of the head and their paths + head_params = get_named_params(self.head, parent_name="head") + for i, (full_name, param) in enumerate(head_params, 1): + print(f"Parameter {i}: {full_name} | Shape: {param.shape}") + yield param + + linear_params = get_named_params(self.semantic_proj, parent_name="head") + for i, (full_name, param) in enumerate(linear_params, 1): + print(f"Parameter {i}: {full_name} | Shape: {param.shape}") + yield param + + # Count the total number + total = len(backbone_f_params) + len(head_params) + len(linear_params) + print("-"*70) + print(f"Total number of trainable parameters: {total}") + print("="*70) + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat_clip = self.backbone_s(data_dict['image'])['pooler_output'] + feat_all = self.backbone_f(data_dict['image'])['pooler_output'] + + feat_s = self.semantic_proj(feat_all) + feat_f = feat_all - feat_s + + return feat_clip, feat_s, feat_f, feat_all + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss_ce, loss_orth, loss_mse = self.loss_func(pred, label, pred_dict['feat_clip'], pred_dict['feat_s'], pred_dict['feat']) + loss_dict = { + 'overall': loss_ce + loss_orth + loss_mse, + 'loss_ce': loss_ce, + 'loss_orth': loss_orth, + 'loss_mse': loss_mse, + } + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + feat_clip, feat_s, feat_f, feat_all = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(feat_f) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f} + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f + feat_s} # Alternative retrieval mode: use f+s features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_all} # Alternative retrieval mode: use all features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_all + feat_f} # Alternative retrieval mode: use all + f features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f + feat_s * 0.1} # Alternative retrieval mode: use all + f features for retrieval + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model diff --git a/training/detectors/clip_large_fft_dis_orth2_detector.py b/training/detectors/clip_large_fft_dis_orth2_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f134536d5a151da19f0be322a4ec79cf06e22d --- /dev/null +++ b/training/detectors/clip_large_fft_dis_orth2_detector.py @@ -0,0 +1,189 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large_fft_dis_orth2') +class CLIP_Large_FFT_Dis_Orth2_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone_s, self.backbone_f = self.build_backbone(config) + self.semantic_proj = nn.Linear(1024, 1024) + self.finger_proj = nn.Linear(1024, 1024) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone_s = get_clip_visual(model_name=config['pretrained']) # frozen + _, backbone_f = get_clip_visual(model_name=config['pretrained']) # trainable + + for param in backbone_s.parameters(): + param.requires_grad = False + + return backbone_s, backbone_f + + def parameters(self, recurse: bool = True): + """Exclude the backbone_s network""" + modules_to_include = [ + self.backbone_f, + self.semantic_proj, + self.finger_proj, + self.head + ] + + for module in modules_to_include: + for param in module.parameters(recurse=recurse): + yield param + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def project_and_decompose(self, a, b, eps=1e-8): + """ + Decompose the 1024-dimensional vector a into the projection component along b and the orthogonal component perpendicular to b + + Args: + a: a 1024-dimensional vector with shape (..., 1024), supporting batched processing (e.g. (batch_size, 1024)) + b: a 1024-dimensional vector whose shape must match a (with the last dimension being 1024) + eps: a small value to prevent division by zero + + Returns: + proj: projection component along b, with the same shape as the input + ortho: orthogonal component perpendicular to b, with the same shape as the input + """ + # Compute the squared norm of b (..., 1), keeping the last dimension for broadcasting + b_norm_sq = torch.sum(b **2, dim=-1, keepdim=True) + eps + + # Compute the dot product of a and b (..., 1) + a_dot_b = torch.sum(a * b, dim=-1, keepdim=True) + + # Projection coefficient = dot product / squared norm of b (..., 1) + proj_coeff = a_dot_b / b_norm_sq + + # Projection component along b = coefficient * b + proj = proj_coeff * b # Use broadcasting to multiply the coefficients with b element-wise + + # Orthogonal component = original vector a - projection component + ortho = a - proj + + return proj, ortho + + def features(self, data_dict: dict) -> torch.tensor: + feat_clip = self.backbone_s(data_dict['image'])['pooler_output'] + feat_all = self.backbone_f(data_dict['image'])['pooler_output'] + + #### Disentanglement operations + # 1. Semantic projection disentanglement + feat_s = self.semantic_proj(feat_all) + # 2. Project onto the semantic direction and obtain the orthogonal component + proj, ortho = self.project_and_decompose(feat_all, feat_s) + # 3. Obtain the fingerprint feature + feat_f = self.finger_proj(ortho) + + return feat_clip, feat_s, feat_f, feat_all + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss_ce, loss_orth, loss_mse = self.loss_func(pred, label, pred_dict['feat_clip'], pred_dict['feat_s'], pred_dict['feat']) + loss_dict = { + 'overall': loss_ce + loss_orth + loss_mse, + 'loss_ce': loss_ce, + 'loss_orth': loss_orth, + 'loss_mse': loss_mse, + } + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + feat_clip, feat_s, feat_f, feat_all = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(feat_f) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f} + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f + feat_s} # Alternative retrieval mode: use f+s features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_all} # Alternative retrieval mode: use all features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_all + feat_f} # Alternative retrieval mode: use all + f features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f + feat_s * 0.1} # Alternative retrieval mode: use all + f features for retrieval + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model diff --git a/training/detectors/clip_large_fft_dis_orth3_detector.py b/training/detectors/clip_large_fft_dis_orth3_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..d919536f7023719519673cd49f4f80d1fb06c221 --- /dev/null +++ b/training/detectors/clip_large_fft_dis_orth3_detector.py @@ -0,0 +1,199 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large_fft_dis_orth3') +class CLIP_Large_FFT_Dis_Orth3_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone_s, self.backbone_f = self.build_backbone(config) + self.semantic_proj = nn.Sequential( + nn.Linear(1024, 1024), + nn.ReLU(), + nn.Linear(1024, 1024), + ) + self.finger_proj = nn.Sequential( + nn.Linear(1024, 1024), + nn.ReLU(), + nn.Linear(1024, 1024), + ) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone_s = get_clip_visual(model_name=config['pretrained']) # frozen + _, backbone_f = get_clip_visual(model_name=config['pretrained']) # trainable + + for param in backbone_s.parameters(): + param.requires_grad = False + + return backbone_s, backbone_f + + def parameters(self, recurse: bool = True): + """Exclude the backbone_s network""" + modules_to_include = [ + self.backbone_f, + self.semantic_proj, + self.finger_proj, + self.head + ] + + for module in modules_to_include: + for param in module.parameters(recurse=recurse): + yield param + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def project_and_decompose(self, a, b, eps=1e-8): + """ + Decompose the 1024-dimensional vector a into the projection component along b and the orthogonal component perpendicular to b + + Args: + a: a 1024-dimensional vector with shape (..., 1024), supporting batched processing (e.g. (batch_size, 1024)) + b: a 1024-dimensional vector whose shape must match a (with the last dimension being 1024) + eps: a small value to prevent division by zero + + Returns: + proj: projection component along b, with the same shape as the input + ortho: orthogonal component perpendicular to b, with the same shape as the input + """ + # Compute the squared norm of b (..., 1), keeping the last dimension for broadcasting + b_norm_sq = torch.sum(b **2, dim=-1, keepdim=True) + eps + + # Compute the dot product of a and b (..., 1) + a_dot_b = torch.sum(a * b, dim=-1, keepdim=True) + + # Projection coefficient = dot product / squared norm of b (..., 1) + proj_coeff = a_dot_b / b_norm_sq + + # Projection component along b = coefficient * b + proj = proj_coeff * b # Use broadcasting to multiply the coefficients with b element-wise + + # Orthogonal component = original vector a - projection component + ortho = a - proj + + return proj, ortho + + def features(self, data_dict: dict) -> torch.tensor: + feat_clip = self.backbone_s(data_dict['image'])['pooler_output'] + feat_all = self.backbone_f(data_dict['image'])['pooler_output'] + + #### Disentanglement operations + # 1. Semantic projection disentanglement + feat_s = self.semantic_proj(feat_all) + # 2. Project onto the semantic direction and obtain the orthogonal component + proj, ortho = self.project_and_decompose(feat_all, feat_s) + # 3. Obtain the fingerprint feature + feat_f = self.finger_proj(ortho) + + return feat_clip, feat_s, feat_f, feat_all + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss_ce, loss_orth, loss_mse = self.loss_func(pred, label, pred_dict['feat_clip'], pred_dict['feat_s'], pred_dict['feat']) + loss_dict = { + 'overall': loss_ce + loss_orth + loss_mse, + 'loss_ce': loss_ce, + 'loss_orth': loss_orth, + 'loss_mse': loss_mse, + } + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + feat_clip, feat_s, feat_f, feat_all = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(feat_f) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f} + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f + feat_s} # Alternative retrieval mode: use f+s features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_all} # Alternative retrieval mode: use all features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_all + feat_f} # Alternative retrieval mode: use all + f features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f + feat_s * 0.1} # Alternative retrieval mode: use all + f features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_s} # use semantic features for retrieval + + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model diff --git a/training/detectors/clip_large_fft_dis_orth4_detector.py b/training/detectors/clip_large_fft_dis_orth4_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..41ff7d0f881427a4fea0eafa7e6b534f1ba8a5ed --- /dev/null +++ b/training/detectors/clip_large_fft_dis_orth4_detector.py @@ -0,0 +1,222 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large_fft_dis_orth4') +class CLIP_Large_FFT_Dis_Orth4_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + backbone_s_clip, backbone_s_dino, backbone_f = self.build_backbone(config) + self.semantic_clip_proj = nn.Linear(1024, 1024) + self.semantic_dino_proj = nn.Linear(1024, 1024) + self.finger_proj = nn.Linear(1024, 1024) # single fingerprint feature + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone_s_clip = get_clip_visual(model_name=config['pretrained']) # frozen + backbone_s_dino = torch.hub.load('facebookresearch/dinov2', "dinov2_vitl14") # frozen + + _, backbone_f = get_clip_visual(model_name=config['pretrained']) # trainable + + for param in backbone_s_clip.parameters(): + param.requires_grad = False + + for param in backbone_s_dino.parameters(): + param.requires_grad = False + + return backbone_s_clip, backbone_s_dino, backbone_f + + def parameters(self, recurse: bool = True): + """ + Rewrite the parameter iteration method to return trainable parameters and display full module paths + """ + print("="*70) + print("Trainable parameter list (with module paths):") + print("-"*70) + + # Helper function: recursively obtain all parameters of a module and their full paths + def get_named_params(module, parent_name=""): + named_params = [] + for name, param in module.named_parameters(recurse=recurse): + # Build the full path (parent module name + current parameter name) + full_name = f"{parent_name}.{name}" if parent_name else name + named_params.append((full_name, param)) + return named_params + + # 1. Get parameters of backbone_f and their paths + backbone_f_params = get_named_params(self.backbone_f, parent_name="backbone_f") + for i, (full_name, param) in enumerate(backbone_f_params, 1): + print(f"Parameter {i}: {full_name} | Shape: {param.shape}") + yield param + + # 2. Get parameters of the head and their paths + head_params = get_named_params(self.head, parent_name="head") + for i, (full_name, param) in enumerate(head_params, 1): + print(f"Parameter {i}: {full_name} | Shape: {param.shape}") + yield param + + linear_params = get_named_params(self.semantic_proj, parent_name="head") + for i, (full_name, param) in enumerate(linear_params, 1): + print(f"Parameter {i}: {full_name} | Shape: {param.shape}") + yield param + + # Count the total number + total = len(backbone_f_params) + len(head_params) + len(linear_params) + print("-"*70) + print(f"Total number of trainable parameters: {total}") + print("="*70) + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def project_and_decompose(self, a, b, eps=1e-8): + """ + Decompose the 1024-dimensional vector a into the projection component along b and the orthogonal component perpendicular to b + + Args: + a: a 1024-dimensional vector with shape (..., 1024), supporting batched processing (e.g. (batch_size, 1024)) + b: a 1024-dimensional vector whose shape must match a (with the last dimension being 1024) + eps: a small value to prevent division by zero + + Returns: + proj: projection component along b, with the same shape as the input + ortho: orthogonal component perpendicular to b, with the same shape as the input + """ + # Compute the squared norm of b (..., 1), keeping the last dimension for broadcasting + b_norm_sq = torch.sum(b **2, dim=-1, keepdim=True) + eps + + # Compute the dot product of a and b (..., 1) + a_dot_b = torch.sum(a * b, dim=-1, keepdim=True) + + # Projection coefficient = dot product / squared norm of b (..., 1) + proj_coeff = a_dot_b / b_norm_sq + + # Projection component along b = coefficient * b + proj = proj_coeff * b # Use broadcasting to multiply the coefficients with b element-wise + + # Orthogonal component = original vector a - projection component + ortho = a - proj + + return proj, ortho + + def features(self, data_dict: dict) -> torch.tensor: + feat_clip = self.backbone_s(data_dict['image'])['pooler_output'] + feat_all = self.backbone_f(data_dict['image'])['pooler_output'] + + #### Disentanglement operations + # 1. Semantic projection disentanglement + feat_s = self.semantic_proj(feat_all) + # 2. Project onto the semantic direction and obtain the orthogonal component + proj, ortho = self.project_and_decompose(feat_all, feat_s) + # 3. Obtain the fingerprint feature + feat_f = self.finger_proj(ortho) + + return feat_clip, feat_s, feat_f, feat_all + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss_ce, loss_orth, loss_mse = self.loss_func(pred, label, pred_dict['feat_clip'], pred_dict['feat_s'], pred_dict['feat']) + loss_dict = { + 'overall': loss_ce + loss_orth + loss_mse, + 'loss_ce': loss_ce, + 'loss_orth': loss_orth, + 'loss_mse': loss_mse, + } + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + feat_clip, feat_s, feat_f, feat_all = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(feat_f) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f} + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f + feat_s} # Alternative retrieval mode: use f+s features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_all} # Alternative retrieval mode: use all features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_all + feat_f} # Alternative retrieval mode: use all + f features for retrieval + # pred_dict = {'cls': pred, 'prob': prob, 'feat_clip': feat_clip, 'feat_s': feat_s, 'feat': feat_f + feat_s * 0.1} # Alternative retrieval mode: use all + f features for retrieval + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model diff --git a/training/detectors/clip_large_fft_dis_orth_detector.py b/training/detectors/clip_large_fft_dis_orth_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e45eabb25ddfabefd45857cbfe2bd76791c297 --- /dev/null +++ b/training/detectors/clip_large_fft_dis_orth_detector.py @@ -0,0 +1,166 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large_fft_dis_orth') +class CLIP_Large_FFT_Dis_Orth_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone_s, self.backbone_f = self.build_backbone(config) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone_s = get_clip_visual(model_name=config['pretrained']) # frozen + _, backbone_f = get_clip_visual(model_name=config['pretrained']) # trainable + + for param in backbone_s.parameters(): + param.requires_grad = False + + return backbone_s, backbone_f + + def parameters(self, recurse: bool = True): + """ + Rewrite the parameter iteration method to return trainable parameters and display full module paths + """ + print("="*70) + print("Trainable parameter list (with module paths):") + print("-"*70) + + # Helper function: recursively obtain all parameters of a module and their full paths + def get_named_params(module, parent_name=""): + named_params = [] + for name, param in module.named_parameters(recurse=recurse): + # Build the full path (parent module name + current parameter name) + full_name = f"{parent_name}.{name}" if parent_name else name + named_params.append((full_name, param)) + return named_params + + # 1. Get parameters of backbone_f and their paths + backbone_f_params = get_named_params(self.backbone_f, parent_name="backbone_f") + for i, (full_name, param) in enumerate(backbone_f_params, 1): + print(f"Parameter {i}: {full_name} | Shape: {param.shape}") + yield param + + # 2. Get parameters of the head and their paths + head_params = get_named_params(self.head, parent_name="head") + for i, (full_name, param) in enumerate(head_params, 1): + print(f"Parameter {i}: {full_name} | Shape: {param.shape}") + yield param + + # Count the total number + total = len(backbone_f_params) + len(head_params) + print("-"*70) + print(f"Total number of trainable parameters: {total}") + print("="*70) + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat_s = self.backbone_s(data_dict['image'])['pooler_output'] + feat_f = self.backbone_f(data_dict['image'])['pooler_output'] + return feat_s, feat_f + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss_ce, loss_orth = self.loss_func(pred, label, pred_dict['feat_s'], pred_dict['feat']) + loss_dict = { + 'overall': loss_ce + loss_orth, + 'loss_ce': loss_ce, + 'loss_orth': loss_orth, + } + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + feat_s, feat_f = self.features(data_dict) + features = feat_f - feat_s # the simplest disentanglement method: directly perform subtraction + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features, 'feat_s': feat_s} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model diff --git a/training/detectors/clip_large_fft_supcon_cls_detector.py b/training/detectors/clip_large_fft_supcon_cls_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..a68c923902ccee9334760892bea97064448373cf --- /dev/null +++ b/training/detectors/clip_large_fft_supcon_cls_detector.py @@ -0,0 +1,121 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large_fft_supcon_cls') +class CLIP_Large_FFT_SupCon_Cls_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone = get_clip_visual(model_name=config['pretrained']) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image'])['pooler_output'] + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + feat = pred_dict['feat'] + loss = self.loss_func(feat, pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model diff --git a/training/detectors/clip_large_fft_supcon_detector.py b/training/detectors/clip_large_fft_supcon_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb0222042e039e4fb8f58f8d4583133cfadd93b --- /dev/null +++ b/training/detectors/clip_large_fft_supcon_detector.py @@ -0,0 +1,120 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large_fft_supcon') +class CLIP_Large_FFT_SupCon_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone = get_clip_visual(model_name=config['pretrained']) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image'])['pooler_output'] + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + feat = pred_dict['feat'] + loss = self.loss_func(feat, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model diff --git a/training/detectors/clip_large_fft_vae1_detector.py b/training/detectors/clip_large_fft_vae1_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..6b16dfc23755b62909eb9e49d1ad9dd4de27db39 --- /dev/null +++ b/training/detectors/clip_large_fft_vae1_detector.py @@ -0,0 +1,247 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +def reparameterize(mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + +@DETECTOR.register_module(module_name='clip_large_fft_vae1') +class CLIP_Large_FFT_VAE1_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + # 1.backbone network + self.clip_model, self.vit = self.build_backbone(config) + self.clip_feat_dim = 1024 + # 2.VAE disentanglement encoder + self.feat_dim = 1024 + self.semantic_dim = 512 # semantic feature dimension + self.fingerprint_dim = 512 # fingerprint feature dimension + # Semantic encoder (outputs 512 dimensions and aligns with CLIP features directly or through a projection layer) + self.semantic_encoder = nn.Sequential( + nn.Linear(self.feat_dim, 768), + nn.ReLU(), + nn.Linear(768, self.semantic_dim) + ) + # Projection layer for semantic features (used for alignment when semantic and CLIP dimensions differ) + if self.semantic_dim != self.clip_feat_dim: + self.sem_proj_to_clip = nn.Linear(self.semantic_dim, self.clip_feat_dim) + else: + self.sem_proj_to_clip = nn.Identity() # No projection is needed when dimensions already match + # Fingerprint encoder + self.fingerprint_encoder = nn.Sequential( + nn.Linear(self.feat_dim, 768), + nn.ReLU(), + nn.Linear(768, self.fingerprint_dim) + ) + # 3.decoder and classification head + self.decoder = nn.Sequential( # Simplified decoder that keeps the reconstruction in 1024 dimensions + nn.Linear(self.feat_dim, 768), + nn.ReLU(), + nn.Linear(768, self.feat_dim) + ) + self.head = nn.Linear(self.fingerprint_dim, config['backbone_config']['num_classes']) + # 4.VAE mean and variance layers (simplified to direct outputs without intermediate layers) + self.semantic_mu = nn.Linear(self.semantic_dim, self.semantic_dim) + self.semantic_logvar = nn.Linear(self.semantic_dim, self.semantic_dim) + self.fingerprint_mu = nn.Linear(self.fingerprint_dim, self.fingerprint_dim) + self.fingerprint_logvar = nn.Linear(self.fingerprint_dim, self.fingerprint_dim) + + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, clip_model = get_clip_visual(model_name=config['pretrained']) # frozen + _, vit = get_clip_visual(model_name=config['pretrained']) # trainable + + for param in clip_model.parameters(): + param.requires_grad = False + + return clip_model, vit + + def build_loss(self, config): + loss_func = DeepfakeVAELoss() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + vit_feat = self.vit(data_dict['image'])['pooler_output'] + with torch.no_grad(): + clip_feat = self.clip_model(data_dict['image'])['pooler_output'] + clip_feat = F.normalize(clip_feat, dim=-1) # normalize CLIP features + + return clip_feat, vit_feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + loss_dict = self.loss_func(pred_dict, label) + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # Step 1: extract ViT and CLIP features as the semantic anchor + clip_feat, vit_feat = self.features(data_dict) + + # Step 2: disentangle semantic and fingerprint features + z_semantic_raw = self.semantic_encoder(vit_feat) # (batch_size, 512) + sem_mu = self.semantic_mu(z_semantic_raw) + sem_logvar = self.semantic_logvar(z_semantic_raw) + z_semantic = reparameterize(sem_mu, sem_logvar) # semantic feature + + z_fingerprint_raw = self.fingerprint_encoder(vit_feat) # (batch_size, 512) + fing_mu = self.fingerprint_mu(z_fingerprint_raw) + fing_logvar = self.fingerprint_logvar(z_fingerprint_raw) + z_fingerprint = reparameterize(fing_mu, fing_logvar) # fingerprint feature + + # Step 3: project semantic features into the CLIP space for regularization + z_semantic_clip = self.sem_proj_to_clip(z_semantic) # align CLIP dimensions + z_semantic_clip = F.normalize(z_semantic_clip, dim=-1) # normalize to ensure a consistent scale + + # Step 4: reconstruction and classification + recon_vit_feat = self.decoder(torch.cat([z_semantic, z_fingerprint], dim=-1)) # reconstruct from concatenation + + # get the prediction by classifier + pred = self.classifier(z_fingerprint) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = { + #### Two ViT features from the beginning and the end + "vit_feat": vit_feat, # ViT output features + 'recon_vit_feat': recon_vit_feat, # reconstructed ViT features + #### fingerprint feature used for discrimination + 'feat': z_fingerprint, + #### semantic-related features + "z_semantic": z_semantic, + 'clip_feat': clip_feat, # static CLIP semantic feature + "z_semantic_clip": z_semantic_clip, # semantic features projected into CLIP space + # parameter features + "sem_mu": sem_mu, + "sem_logvar": sem_logvar, + "fing_mu": fing_mu, + "fing_logvar": fing_logvar, + # prediction-score features + 'cls': pred, + 'prob': prob, + } + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model + + +# Loss function: add CLIP semantic regularization +class DeepfakeVAELoss(nn.Module): + def __init__(self, kl_weight=0.001, recon_weight=1.0, clip_sem_weight=0.1): + super().__init__() + self.kl_weight = kl_weight + self.recon_weight = recon_weight + self.clip_sem_weight = clip_sem_weight # weight of the CLIP semantic regularization + self.ce_loss = nn.CrossEntropyLoss() + self.mse_loss = nn.MSELoss() + + def kl_divergence(self, mu, logvar): + return -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(), dim=1).mean() + + def forward(self, model_outputs, targets): + # 1. Original loss + loss_ce = self.ce_loss(model_outputs["cls"], targets) + loss_recon = self.mse_loss(model_outputs["recon_vit_feat"], model_outputs["vit_feat"]) + loss_kl_sem = self.kl_divergence(model_outputs["sem_mu"], model_outputs["sem_logvar"]) + loss_kl_fing = self.kl_divergence(model_outputs["fing_mu"], model_outputs["fing_logvar"]) + loss_kl = loss_kl_sem + loss_kl_fing + + # 2. Added: CLIP semantic regularization (MSE loss) + # Make the projected semantic features as close as possible to CLIP features + loss_clip_sem = self.mse_loss( + model_outputs["z_semantic_clip"], # disentangled semantic features projected into CLIP space + model_outputs["clip_feat"] # semantic features extracted by CLIP (anchor) + ) + + # 3. Total loss + total_loss = ( + loss_ce + + self.recon_weight * loss_recon + + self.kl_weight * loss_kl + + self.clip_sem_weight * loss_clip_sem # add CLIP regularization + ) + return { + "overall": total_loss, + "loss_ce": loss_ce, + "loss_recon": loss_recon, + "loss_kl": loss_kl, + "loss_clip_sem": loss_clip_sem # monitor the effect of semantic regularization + } diff --git a/training/detectors/clip_large_fft_vae2_detector.py b/training/detectors/clip_large_fft_vae2_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca24c818556c3ccc36b811544e8f4be95f9f625 --- /dev/null +++ b/training/detectors/clip_large_fft_vae2_detector.py @@ -0,0 +1,262 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +def reparameterize(mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + +@DETECTOR.register_module(module_name='clip_large_fft_vae2') +class CLIP_Large_FFT_VAE2_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + # 1.backbone network + self.clip_model, self.vit = self.build_backbone(config) + self.clip_feat_dim = 1024 + # 2.VAE disentanglement encoder + self.feat_dim = 1024 + self.semantic_dim = 512 # semantic feature dimension + self.fingerprint_dim = 512 # fingerprint feature dimension + # Semantic encoder (outputs 512 dimensions and aligns with CLIP features directly or through a projection layer) + self.semantic_encoder = nn.Sequential( + nn.Linear(self.feat_dim, 768), + nn.ReLU(), + nn.Linear(768, self.semantic_dim) + ) + # Projection layer for semantic features (used for alignment when semantic and CLIP dimensions differ) + if self.semantic_dim != self.clip_feat_dim: + self.sem_proj_to_clip = nn.Linear(self.semantic_dim, self.clip_feat_dim) + else: + self.sem_proj_to_clip = nn.Identity() # No projection is needed when dimensions already match + # Fingerprint encoder + self.fingerprint_encoder = nn.Sequential( + nn.Linear(self.feat_dim, 768), + nn.ReLU(), + nn.Linear(768, self.fingerprint_dim) + ) + # 3.decoder and classification head + self.decoder = nn.Sequential( # Simplified decoder that keeps the reconstruction in 1024 dimensions + nn.Linear(self.feat_dim, 768), + nn.ReLU(), + nn.Linear(768, self.feat_dim) + ) + self.head = nn.Linear(self.fingerprint_dim, config['backbone_config']['num_classes']) + # 4.VAE mean and variance layers (simplified to direct outputs without intermediate layers) + self.semantic_mu = nn.Linear(self.semantic_dim, self.semantic_dim) + self.semantic_logvar = nn.Linear(self.semantic_dim, self.semantic_dim) + self.fingerprint_mu = nn.Linear(self.fingerprint_dim, self.fingerprint_dim) + self.fingerprint_logvar = nn.Linear(self.fingerprint_dim, self.fingerprint_dim) + + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, clip_model = get_clip_visual(model_name=config['pretrained']) # frozen + _, vit = get_clip_visual(model_name=config['pretrained']) # trainable + + # Freeze CLIP weights + for param in clip_model.parameters(): + param.requires_grad = False + + # Load ViT weights trained for multi-class classification + ckpt = torch.load("/Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench_DFG/logs/clip_models/clip_large_fft_2025-10-10-12-16-50/test/protocol_2_test/ckpt_best.pth", map_location="cuda") # module. + new_state_dict = {k.replace('module.', ''): v + for k, v in ckpt.items() + if k.startswith('module.')} + new_state_dict = {k.replace('backbone.', ''): v + for k, v in new_state_dict.items() + if k.startswith('backbone.')} + vit.load_state_dict(new_state_dict, strict=True) + + # Freeze ViT weights and train only the disentanglement module + for param in vit.parameters(): + param.requires_grad = False + + return clip_model, vit + + def build_loss(self, config): + loss_func = DeepfakeVAELoss() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + vit_feat = self.vit(data_dict['image'])['pooler_output'] + with torch.no_grad(): + clip_feat = self.clip_model(data_dict['image'])['pooler_output'] + clip_feat = F.normalize(clip_feat, dim=-1) # normalize CLIP features + + return clip_feat, vit_feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + loss_dict = self.loss_func(pred_dict, label) + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # Step 1: extract ViT and CLIP features as the semantic anchor + clip_feat, vit_feat = self.features(data_dict) + + # Step 2: disentangle semantic and fingerprint features + z_semantic_raw = self.semantic_encoder(vit_feat) # (batch_size, 512) + sem_mu = self.semantic_mu(z_semantic_raw) + sem_logvar = self.semantic_logvar(z_semantic_raw) + z_semantic = reparameterize(sem_mu, sem_logvar) # semantic feature + + z_fingerprint_raw = self.fingerprint_encoder(vit_feat) # (batch_size, 512) + fing_mu = self.fingerprint_mu(z_fingerprint_raw) + fing_logvar = self.fingerprint_logvar(z_fingerprint_raw) + z_fingerprint = reparameterize(fing_mu, fing_logvar) # fingerprint feature + + # Step 3: project semantic features into the CLIP space for regularization + z_semantic_clip = self.sem_proj_to_clip(z_semantic) # align CLIP dimensions + z_semantic_clip = F.normalize(z_semantic_clip, dim=-1) # normalize to ensure a consistent scale + + # Step 4: reconstruction and classification + recon_vit_feat = self.decoder(torch.cat([z_semantic, z_fingerprint], dim=-1)) # reconstruct from concatenation + + # get the prediction by classifier + pred = self.classifier(z_fingerprint) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = { + #### Two ViT features from the beginning and the end + "vit_feat": vit_feat, # ViT output features + 'recon_vit_feat': recon_vit_feat, # reconstructed ViT features + #### fingerprint feature used for discrimination + 'feat': z_fingerprint, + #### semantic-related features + "z_semantic": z_semantic, + 'clip_feat': clip_feat, # static CLIP semantic feature + "z_semantic_clip": z_semantic_clip, # semantic features projected into CLIP space + # parameter features + "sem_mu": sem_mu, + "sem_logvar": sem_logvar, + "fing_mu": fing_mu, + "fing_logvar": fing_logvar, + # prediction-score features + 'cls': pred, + 'prob': prob, + } + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model + + +# Loss function: add CLIP semantic regularization +class DeepfakeVAELoss(nn.Module): + def __init__(self, kl_weight=0.001, recon_weight=1.0, clip_sem_weight=0.1): + super().__init__() + self.kl_weight = kl_weight + self.recon_weight = recon_weight + self.clip_sem_weight = clip_sem_weight # weight of the CLIP semantic regularization + self.ce_loss = nn.CrossEntropyLoss() + self.mse_loss = nn.MSELoss() + + def kl_divergence(self, mu, logvar): + return -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(), dim=1).mean() + + def forward(self, model_outputs, targets): + # 1. Original loss + loss_ce = self.ce_loss(model_outputs["cls"], targets) + loss_recon = self.mse_loss(model_outputs["recon_vit_feat"], model_outputs["vit_feat"]) + loss_kl_sem = self.kl_divergence(model_outputs["sem_mu"], model_outputs["sem_logvar"]) + loss_kl_fing = self.kl_divergence(model_outputs["fing_mu"], model_outputs["fing_logvar"]) + loss_kl = loss_kl_sem + loss_kl_fing + + # 2. Added: CLIP semantic regularization (MSE loss) + # Make the projected semantic features as close as possible to CLIP features + loss_clip_sem = self.mse_loss( + model_outputs["z_semantic_clip"], # disentangled semantic features projected into CLIP space + model_outputs["clip_feat"] # semantic features extracted by CLIP (anchor) + ) + + # 3. Total loss + total_loss = ( + loss_ce + + self.recon_weight * loss_recon + + self.kl_weight * loss_kl + + self.clip_sem_weight * loss_clip_sem # add CLIP regularization + ) + return { + "overall": total_loss, + "loss_ce": loss_ce, + "loss_recon": loss_recon, + "loss_kl": loss_kl, + "loss_clip_sem": loss_clip_sem # monitor the effect of semantic regularization + } diff --git a/training/detectors/clip_large_lora_detector.py b/training/detectors/clip_large_lora_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..611da4eecd615ce3bdb5c27be8be86c7cbb3a0a2 --- /dev/null +++ b/training/detectors/clip_large_lora_detector.py @@ -0,0 +1,171 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large_lora') +class CLIP_Large_LoRA_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone = get_clip_visual(model_name=config['pretrained']) + backbone = to_lora(backbone, r=1) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image'])['pooler_output'] + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model + + +def to_lora(model, target=[nn.Linear, nn.Conv2d, nn.Embedding], r=16, f_class=None, layers=None, names=None): + + for n, m in model.named_modules(): + if f_class is not None and not isinstance(m, f_class): + continue + if isinstance(m, nn.Sequential) or isinstance(m, nn.ModuleList): + + for name, mod in m.named_children(): + + # print(name, mod) + if isinstance(mod, nn.Linear) and not isinstance(mod, lora.Linear): + mod = change_mod(mod, r=r) + m._modules[name] = mod + else: + if layers is None or any(['layers.' + str(i) in n for i in layers]): + for name, mod in m.named_children(): + # if 'self_attn' in f_name: + # print(name, mod) + if isinstance(mod, nn.Linear) and not isinstance(mod, lora.Linear): + if names is None or any(na in name for na in names): + mod = change_mod(mod, r=r) + setattr(m, name, mod) + + lora.mark_only_lora_as_trainable(model) + return model + + +def change_mod(m, targets=[nn.Linear, nn.Conv2d, nn.Embedding], r=16): + st_dict = m.state_dict() + + if nn.Linear in targets and isinstance(m, nn.Linear): + dtype = m.weight.dtype + new_m = lora.Linear(m.in_features, m.out_features, bias=m.bias is not None, r=r, dtype=dtype) + new_m.load_state_dict(st_dict, strict=False) + # print(new_m) + m = new_m + elif nn.Conv2d in targets and isinstance(m, nn.Conv2d): + new_m = lora.Conv2d(m.in_channels, m.out_channels, m.kernel_size, stride=m.stride, padding=m.padding, \ + dilation=m.dilation, transposed=m.transposed, output_padding=m.output_padding, groups=m.groups, bias=m.bias, r=r) + new_m.load_state_dict(st_dict, strict=False) + m = new_m + elif nn.Embedding in targets and isinstance(m, nn.Embedding): + new_m = lora.Embedding(m.num_embeddings, m.embedding_dim, padding_idx=m.padding_idx, max_norm=m.max_norm, norm_type=m.norm_type, \ + scale_grad_by_freq=m.scale_grad_by_freq, freeze=m.freeze, sparse=m.sparse, r=r) + new_m.load_state_dict(st_dict, strict=False) + m = new_m + + return m diff --git a/training/detectors/clip_large_lsda.py b/training/detectors/clip_large_lsda.py new file mode 100644 index 0000000000000000000000000000000000000000..a22f6aeaa894cc79e549c2dc2c81e51d9de0e7f6 --- /dev/null +++ b/training/detectors/clip_large_lsda.py @@ -0,0 +1,194 @@ +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + +# Keep only intra-domain feature enhancement (centrifugal / affine / additive transforms, all enabled) +class LSDAAugmentor(nn.Module): + def __init__(self, feature_dim=1024, beta_base=0.5, affine_theta_range=(-np.pi/6, np.pi/6), noise_ratio=0.01): + super().__init__() + self.feature_dim = feature_dim + self.beta_base = beta_base # base coefficient for the centrifugal transform + self.affine_theta_range = affine_theta_range # affine rotation angle range + self.noise_ratio = noise_ratio # ratio of noise to the feature standard deviation + + def centrifugal_trans(self, z: torch.Tensor) -> torch.Tensor: + """Original version: adaptive centrifugal transform (the closer a sample is to the domain center, the farther it is pushed)""" + batch_size = z.shape[0] + mu = z.mean(dim=0, keepdim=True) # domain center [1, 1024] + dist = torch.norm(z - mu, dim=1, keepdim=True) # distance from each sample to the center [B, 1] + dist_max = dist.max() + 1e-6 # avoid division by zero + beta = self.beta_base * (1 - dist / dist_max) # the closer the distance, the larger β becomes + z_aug = z + beta * (z - mu) # apply the centrifugal transform + return z_aug + + def affine_trans(self, z: torch.Tensor) -> torch.Tensor: + """Affine transform that truly performs high-dimensional rotation (block-wise 2D rotations)""" + batch_size, feat_dim = z.shape # [32, 1024] + device = z.device + + # 1. Validate the feature dimension: it must be even (1024 satisfies this requirement) + assert feat_dim % 2 == 0, f"feature dimension{feat_dim}must be even in order to perform block-wise 2D rotations" + num_blocks = feat_dim // 2 # 512 two-dimensional blocks + + # 2. Generate per-sample, per-block rotation angles (the actual rotation parameters) + theta = torch.rand(batch_size, num_blocks, device=device) * (self.affine_theta_range[1] - self.affine_theta_range[0]) + self.affine_theta_range[0] + # Generate per-sample scaling factors + scale = torch.rand(batch_size, 1, device=device) * 0.1 + 0.95 # [32, 1] + + # 3. Compute the cos/sin values of the rotation matrix + cos_theta = torch.cos(theta) # [32, 512] + sin_theta = torch.sin(theta) # [32, 512] + + # 4. Split the 1024-dimensional features into 512 two-dimensional blocks + z_reshaped = z.reshape(batch_size, num_blocks, 2) # [32, 512, 2] + + # 5. Perform true 2D rotation on each two-dimensional block + z_rotated = torch.stack([ + z_reshaped[..., 0] * cos_theta - z_reshaped[..., 1] * sin_theta, # x' = xcosθ - ysinθ + z_reshaped[..., 0] * sin_theta + z_reshaped[..., 1] * cos_theta # y' = xsinθ + ycosθ + ], dim=-1) # [32, 512, 2] + + # 6. Reassemble by concatenation and apply global scaling + z_aug = z_rotated.reshape(batch_size, feat_dim) * scale # [32, 1024] + + return z_aug + + def additive_trans(self, z: torch.Tensor) -> torch.Tensor: + """ + Apply adaptive additive noise augmentation to the input tensor based on the feature standard deviation + + Args: + z (torch.Tensor): input feature tensor with shape [batch_size, feature_dim] + + Returns: + torch.Tensor: augmented feature tensor with the same shape as the input + + Note: + 1. noise strength is jointly determined by the standard deviation of the feature dimension and the preset noise_ratio + 2. keep statistics computed independently for each batch dimension + 3. the generated noise follows a standard normal distribution N(0,1) and is then scaled + """ + """Original version: adaptive additive noise (based on the feature standard deviation)""" + # Compute the global feature standard deviation (batch-aware) + z_std = z.std(dim=0, keepdim=True) + noise = torch.randn_like(z) * z_std * self.noise_ratio + z_aug = z + noise + return z_aug + + def forward(self, z: torch.Tensor) -> torch.Tensor: + """Original core logic: sequentially stack all intra-domain augmentations during training, while validation/inference returns the original features""" + if not self.training: + return z + + # Step 1: centrifugal transform + z = self.centrifugal_trans(z) + # Step 2: affine transform + z = self.affine_trans(z) + # Step 3: additive noise + z = self.additive_trans(z) + + return z + + +@DETECTOR.register_module(module_name='clip_large_fft_lsda') +class CLIP_Large_FFT_LSDA_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.num_classes = config['backbone_config']['num_classes'] + + # Original components (kept intact) + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, self.num_classes) + self.loss_func = self.build_loss(config) + + # Keep only the intra-domain augmentor (no other new components) + self.augmentor = LSDAAugmentor(feature_dim=1024) + + def build_backbone(self, config): + _, backbone = get_clip_visual(model_name=config['pretrained']) + return backbone + + def build_loss(self, config): + # Keep only the original classification loss, with no extra losses + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + """Core behavior: use augmented features for training and original features for validation/inference""" + # Extract the original CLIP features + feat = self.backbone(data_dict['image'])['pooler_output'] # [B, 1024] + + # During training, apply intra-domain augmentation; during validation/inference, return the original features directly + feat = self.augmentor(feat) + + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + # Original classification head (unchanged) + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + # Keep only the original classification loss (no domain loss / distillation loss) + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + # Original metric computation (unchanged) + label = data_dict['label'] + pred = pred_dict['cls'] + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.num_classes) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + """ + - When inference=True (validation/inference), the augmentor automatically returns the original features + - When inference=False (training), the augmentor returns augmented features + """ + # Control the augmentor mode (force eval during inference) + if inference: + self.augmentor.eval() + else: + self.augmentor.train() + + # Feature extraction (augmented during training, original during validation/inference) + features = self.features(data_dict) + # classification prediction (unchanged) + pred = self.classifier(features) + prob = torch.softmax(pred, dim=1) + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name="openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model diff --git a/training/detectors/clip_large_vid_detector.py b/training/detectors/clip_large_vid_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..b2d11a306f7d2a7dda5c67602c267c6d834d8a7f --- /dev/null +++ b/training/detectors/clip_large_vid_detector.py @@ -0,0 +1,167 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='clip_large_vid') +class CLIP_Large_Vid_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, 2) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone = get_clip_visual(model_name=config['pretrained']) + backbone = to_lora(backbone, r=16) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + B, T, C, H, W = data_dict['image'].shape + feat = self.backbone(data_dict['image'].view(B * T, C, H, W))['pooler_output'] + feat = feat.view(B, T, feat.shape[-1]).mean(1) # Temporal avg + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model + + +def to_lora(model, target=[nn.Linear, nn.Conv2d, nn.Embedding], r=16, f_class=None, layers=None, names=None): + + for n, m in model.named_modules(): + if f_class is not None and not isinstance(m, f_class): + continue + if isinstance(m, nn.Sequential) or isinstance(m, nn.ModuleList): + + for name, mod in m.named_children(): + + # print(name, mod) + if isinstance(mod, nn.Linear) and not isinstance(mod, lora.Linear): + mod = change_mod(mod, r=r) + m._modules[name] = mod + else: + if layers is None or any(['layers.' + str(i) in n for i in layers]): + for name, mod in m.named_children(): + # if 'self_attn' in f_name: + # print(name, mod) + if isinstance(mod, nn.Linear) and not isinstance(mod, lora.Linear): + if names is None or any(na in name for na in names): + mod = change_mod(mod, r=r) + setattr(m, name, mod) + + lora.mark_only_lora_as_trainable(model) + return model + + +def change_mod(m, targets=[nn.Linear, nn.Conv2d, nn.Embedding], r=16): + st_dict = m.state_dict() + + if nn.Linear in targets and isinstance(m, nn.Linear): + dtype = m.weight.dtype + new_m = lora.Linear(m.in_features, m.out_features, bias=m.bias is not None, r=r, dtype=dtype) + new_m.load_state_dict(st_dict, strict=False) + # print(new_m) + m = new_m + elif nn.Conv2d in targets and isinstance(m, nn.Conv2d): + new_m = lora.Conv2d(m.in_channels, m.out_channels, m.kernel_size, stride=m.stride, padding=m.padding, \ + dilation=m.dilation, transposed=m.transposed, output_padding=m.output_padding, groups=m.groups, bias=m.bias, r=r) + new_m.load_state_dict(st_dict, strict=False) + m = new_m + elif nn.Embedding in targets and isinstance(m, nn.Embedding): + new_m = lora.Embedding(m.num_embeddings, m.embedding_dim, padding_idx=m.padding_idx, max_norm=m.max_norm, norm_type=m.norm_type, \ + scale_grad_by_freq=m.scale_grad_by_freq, freeze=m.freeze, sparse=m.sparse, r=r) + new_m.load_state_dict(st_dict, strict=False) + m = new_m + + return m diff --git a/training/detectors/clip_large_vid_torth_detector.py b/training/detectors/clip_large_vid_torth_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..88974fda44fe59b241acf671ef3f662734c273da --- /dev/null +++ b/training/detectors/clip_large_vid_torth_detector.py @@ -0,0 +1,262 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os,sys +import datetime +import logging +import numpy as np +import yaml +from sklearn import metrics +from typing import Union +from collections import defaultdict + +#### +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) +sys.path.append("/Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/preprocessing") +#### + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + + +# from .base_detector import AbstractDetector +#### +from base_detector import AbstractDetector +#### + +# from detectors import DETECTOR +# from networks import BACKBONE +# from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +class CLIP_Large_Vid_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, 2) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone = get_clip_visual(model_name=config['pretrained']) + backbone = to_lora(backbone, r=16) + return backbone + + def build_loss(self, config): + # prepare the loss function + # loss_class = LOSSFUNC[config['loss_func']] + # loss_func = loss_class() + return 0 # loss_func + + # def features(self, data_dict: dict) -> torch.tensor: + # B, T, C, H, W = data_dict['image'].shape + # feat = self.backbone(data_dict['image'].view(B * T, C, H, W))['pooler_output'] + # feat = feat.view(B, T, feat.shape[-1]).mean(1) # Temporal avg + # return feat + + def features(self, data_dict) -> torch.tensor: + print(data_dict.shape) + B, T, C, H, W = data_dict.shape + feat = self.backbone(data_dict.view(B * T, C, H, W))['pooler_output'] + # How to enforce orthogonality in temporal modeling + feat = feat.view(B, T, feat.shape[-1]).mean(1) # Temporal avg + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model + + +def to_lora(model, target=[nn.Linear, nn.Conv2d, nn.Embedding], r=16, f_class=None, layers=None, names=None): + + for n, m in model.named_modules(): + if f_class is not None and not isinstance(m, f_class): + continue + if isinstance(m, nn.Sequential) or isinstance(m, nn.ModuleList): + + for name, mod in m.named_children(): + + # print(name, mod) + if isinstance(mod, nn.Linear) and not isinstance(mod, lora.Linear): + mod = change_mod(mod, r=r) + m._modules[name] = mod + else: + if layers is None or any(['layers.' + str(i) in n for i in layers]): + for name, mod in m.named_children(): + # if 'self_attn' in f_name: + # print(name, mod) + if isinstance(mod, nn.Linear) and not isinstance(mod, lora.Linear): + if names is None or any(na in name for na in names): + mod = change_mod(mod, r=r) + setattr(m, name, mod) + + lora.mark_only_lora_as_trainable(model) + return model + + +def change_mod(m, targets=[nn.Linear, nn.Conv2d, nn.Embedding], r=16): + st_dict = m.state_dict() + + if nn.Linear in targets and isinstance(m, nn.Linear): + dtype = m.weight.dtype + new_m = lora.Linear(m.in_features, m.out_features, bias=m.bias is not None, r=r, dtype=dtype) + new_m.load_state_dict(st_dict, strict=False) + # print(new_m) + m = new_m + elif nn.Conv2d in targets and isinstance(m, nn.Conv2d): + new_m = lora.Conv2d(m.in_channels, m.out_channels, m.kernel_size, stride=m.stride, padding=m.padding, \ + dilation=m.dilation, transposed=m.transposed, output_padding=m.output_padding, groups=m.groups, bias=m.bias, r=r) + new_m.load_state_dict(st_dict, strict=False) + m = new_m + elif nn.Embedding in targets and isinstance(m, nn.Embedding): + new_m = lora.Embedding(m.num_embeddings, m.embedding_dim, padding_idx=m.padding_idx, max_norm=m.max_norm, norm_type=m.norm_type, \ + scale_grad_by_freq=m.scale_grad_by_freq, freeze=m.freeze, sparse=m.sparse, r=r) + new_m.load_state_dict(st_dict, strict=False) + m = new_m + + return m + + +if __name__ == '__main__': + + with open('/Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench/training/config/detector/clip_large_vid.yaml', 'r', encoding='utf-8') as file: + data = yaml.safe_load(file) + + model = CLIP_Large_Vid_Detector(data) + # print(model) + + print("Trainable parameters:") + for name, param in model.named_parameters(): + if param.requires_grad: + print(f"Parameter name: {name}, shape: {param.shape}, parameter count: {param.numel()}") + + # Method 2: compute and print the total number of model parameters + total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) # 7,079,938 + print(f"\nTotal model parameters: {total_params:,}") + + # Method 3: print the model structure and parameter counts + from torchsummary import summary + summary(model, input_size=(8, 3, 224, 224)) + +''' +CLIP_Large_Vid_Detector( + (backbone): CLIPVisionTransformer( + (embeddings): CLIPVisionEmbeddings( + (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False) + (position_embedding): Embedding(257, 1024) + ) + (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) + (encoder): CLIPEncoder( + (layers): ModuleList( + (0): CLIPEncoderLayer( + (self_attn): CLIPAttention( + (k_proj): Linear(in_features=1024, out_features=1024, bias=True) + (v_proj): Linear(in_features=1024, out_features=1024, bias=True) + (q_proj): Linear(in_features=1024, out_features=1024, bias=True) + (out_proj): Linear(in_features=1024, out_features=1024, bias=True) + ) + (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) + (mlp): CLIPMLP( + (activation_fn): QuickGELUActivation() + (fc1): Linear(in_features=1024, out_features=4096, bias=True) + (fc2): Linear(in_features=4096, out_features=1024, bias=True) + ) + (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) + ) + + ...... + + (23): CLIPEncoderLayer( + (self_attn): CLIPAttention( + (k_proj): Linear(in_features=1024, out_features=1024, bias=True) + (v_proj): Linear(in_features=1024, out_features=1024, bias=True) + (q_proj): Linear(in_features=1024, out_features=1024, bias=True) + (out_proj): Linear(in_features=1024, out_features=1024, bias=True) + ) + (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) + (mlp): CLIPMLP( + (activation_fn): QuickGELUActivation() + (fc1): Linear(in_features=1024, out_features=4096, bias=True) + (fc2): Linear(in_features=4096, out_features=1024, bias=True) + ) + (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) + ) + ) + ) + (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) + ) + (head): Linear(in_features=1024, out_features=2, bias=True) +) +''' diff --git a/training/detectors/clip_openai_vid_detector.py b/training/detectors/clip_openai_vid_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..bec1e8c8b3981f9320c50b7be49ce9ff0d83b7ba --- /dev/null +++ b/training/detectors/clip_openai_vid_detector.py @@ -0,0 +1,259 @@ +''' +OpenAI Official CLIP-ViT +''' + + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from collections import OrderedDict +from typing import Tuple, Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +import clip + +logger = logging.getLogger(__name__) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + # [1+256,BT,1024] + x = x + self.attention(self.ln_1(x)) # hierarchical attention; can we avoid passing through attention twice because it is too slow? + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int): + super().__init__() + self.input_resolution = input_resolution + # self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # [BT,1024,16,16] + x = x.reshape(x.shape[0], x.shape[1], -1) # [BT,1024,256] + x = x.permute(0, 2, 1) # [BT,256,1024] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # [BT,1+256,1024] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND | [BT,1+256,1024] -> [1+256,BT,1024] + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD | [1+256,BT,1024] -> [BT,1+256,1024] + + x = self.ln_post(x[:, 0, :]) + + return x + + +@DETECTOR.register_module(module_name='clip_openai_vid') +class CLIP_Openai_Large_Vid_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, 2) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + clip_model, preprocess = clip.load(config['pretrained'], device="cpu") + pretrain_dict = clip_model.visual.state_dict() + + backbone = VisionTransformer( + input_resolution=config['resolution'], + patch_size=14, + width=1024, + layers=24, + heads=16 + ) + + msg = backbone.load_state_dict(pretrain_dict, strict=False) + print('Missing keys: {}'.format(msg.missing_keys)) + print('Unexpected keys: {}'.format(msg.unexpected_keys)) + + backbone = to_lora(backbone, r=16) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + B, T, C, H, W = data_dict['image'].shape + feat = self.backbone(data_dict['image'].view(B * T, C, H, W)) + feat = feat.view(B, T, feat.shape[-1]).mean(1) # Temporal avg + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def to_lora(model, target=[nn.Linear, nn.Conv2d, nn.Embedding], r=16, f_class=None, layers=None, names=None): + + for n, m in model.named_modules(): + if f_class is not None and not isinstance(m, f_class): + continue + if isinstance(m, nn.Sequential) or isinstance(m, nn.ModuleList): + + for name, mod in m.named_children(): + + # print(name, mod) + if isinstance(mod, nn.Linear) and not isinstance(mod, lora.Linear): + mod = change_mod(mod, r=r) + m._modules[name] = mod + else: + if layers is None or any(['layers.' + str(i) in n for i in layers]): + for name, mod in m.named_children(): + # if 'self_attn' in f_name: + # print(name, mod) + if isinstance(mod, nn.Linear) and not isinstance(mod, lora.Linear): + if names is None or any(na in name for na in names): + mod = change_mod(mod, r=r) + setattr(m, name, mod) + + lora.mark_only_lora_as_trainable(model) + return model + + +def change_mod(m, targets=[nn.Linear, nn.Conv2d, nn.Embedding], r=16): + st_dict = m.state_dict() + + if nn.Linear in targets and isinstance(m, nn.Linear): + dtype = m.weight.dtype + new_m = lora.Linear(m.in_features, m.out_features, bias=m.bias is not None, r=r, dtype=dtype) + new_m.load_state_dict(st_dict, strict=False) + # print(new_m) + m = new_m + elif nn.Conv2d in targets and isinstance(m, nn.Conv2d): + new_m = lora.Conv2d(m.in_channels, m.out_channels, m.kernel_size, stride=m.stride, padding=m.padding, \ + dilation=m.dilation, transposed=m.transposed, output_padding=m.output_padding, groups=m.groups, bias=m.bias, r=r) + new_m.load_state_dict(st_dict, strict=False) + m = new_m + elif nn.Embedding in targets and isinstance(m, nn.Embedding): + new_m = lora.Embedding(m.num_embeddings, m.embedding_dim, padding_idx=m.padding_idx, max_norm=m.max_norm, norm_type=m.norm_type, \ + scale_grad_by_freq=m.scale_grad_by_freq, freeze=m.freeze, sparse=m.sparse, r=r) + new_m.load_state_dict(st_dict, strict=False) + m = new_m + + return m + + +if __name__ == '__main__': + + model = VisionTransformer( + input_resolution=224, + patch_size=14, + width=1024, + layers=24, + heads=16 + ) + + import clip + clip_model, preprocess = clip.load("ViT-L/14", device="cpu") + pretrain_dict = clip_model.visual.state_dict() + + print(model) + + msg = model.load_state_dict(pretrain_dict, strict=False) + print('Missing keys: {}'.format(msg.missing_keys)) + print('Unexpected keys: {}'.format(msg.unexpected_keys)) diff --git a/training/detectors/clip_patch_shuffle.py b/training/detectors/clip_patch_shuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..7517b80b5cd32f01bb3013cb45ddf0ced24c84ec --- /dev/null +++ b/training/detectors/clip_patch_shuffle.py @@ -0,0 +1,171 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + + +def shuffle_patches(images: torch.Tensor, patch_levels: list = [14, 56, 224]) -> torch.Tensor: + """ + Apply patch-level shuffling to the input images, where each image in the batch randomly selects a shuffle level. + images: input tensor of shape [B, C, H, W], requiring H=W=224 so all levels divide evenly. + patch_levels: list of shuffle levels corresponding to patch sizes; 224 means no shuffle. + Returns: image tensor with the same shape [B, C, H, W], shuffled per image by a random level. + """ + B, C, H, W = images.shape + + # Initialize the output tensor used to store the shuffled result for each image. + shuffled_images = torch.empty_like(images, device=images.device) + # Randomly select a shuffle level (patch size) for each image in the batch. + probs = [0.33, 0.33, 0.34] + random_ps = torch.tensor( + [torch.multinomial(torch.tensor(probs), 1).item() for _ in range(B)], + device=images.device + ) + random_ps = torch.tensor([patch_levels[p] for p in random_ps], device=images.device) + # Process each image independently and shuffle or keep it based on the sampled patch size. + for b in range(B): + ps = random_ps[b].item() # patch size selected for the current image + img = images[b:b+1] # take one image while keeping shape [1, C, H, W] + B_single, C, H, W = img.shape + num_patches_h = H // ps + num_patches_w = W // ps + num_patches = num_patches_h * num_patches_w + + # Original patch split logic:[1, C, H, W] -> [1, num_patches, C, ps, ps] + img = img.view(B_single, C, num_patches_h, ps, num_patches_w, ps) + img = img.permute(0, 2, 4, 1, 3, 5).contiguous() + img = img.view(B_single, num_patches, C, ps, ps) + + # Shuffle patches: when `num_patches=1` (`ps=224`), `randperm(1)` is still `[0]`, which is equivalent to no shuffle. + perm = torch.randperm(num_patches, device=img.device) + batch_idx = torch.arange(B_single, device=img.device).unsqueeze(1).expand(B_single, num_patches) + img = img[batch_idx, perm] + + # Restore the original shape: [1, num_patches, C, ps, ps] -> [1, C, H, W] + img = img.view(B_single, num_patches_h, num_patches_w, C, ps, ps) + img = img.permute(0, 3, 1, 4, 2, 5).contiguous() + img = img.view(B_single, C, H, W) + + # Write the current image result back into the output tensor. + shuffled_images[b:b+1] = img + + return shuffled_images + +@DETECTOR.register_module(module_name='clip_patch_shuffle') +class CLIP_PATCH_SHUFFLE_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + _, backbone = get_clip_visual(model_name=config['pretrained']) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + x=data_dict['image'] + if self.training: # shuffle only during training + x = shuffle_patches(x, patch_levels= [14, 14, 14]) + feat = self.backbone(x)['pooler_output'] + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model diff --git a/training/detectors/cnn_dct_detector.py b/training/detectors/cnn_dct_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..9c12ca81d72c7591dfe01a76c48577392827dbd7 --- /dev/null +++ b/training/detectors/cnn_dct_detector.py @@ -0,0 +1,213 @@ +# author: Your Name +# email: your_email@example.com +# date: 2025-11-20 +# description: Class for the CNN_DCT Detector (36-class classifier adapted to the YAML configuration) + +import os +import logging +import numpy as np +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from scipy.fftpack import dct, idct + +from metrics.base_metrics_class import calculate_acc_for_train # multi-class metrics +from .base_detector import AbstractDetector +from detectors import DETECTOR +from loss import LOSSFUNC + +logger = logging.getLogger(__name__) + +# --------------------------- DCT Utils (PyTorch implementation) --------------------------- +def dct2_torch(array: Tensor, batched: bool = True) -> Tensor: + """2D DCT implemented in PyTorch (consistent with the original TensorFlow logic)""" + shape = array.shape + dtype = array.dtype + array = array.float() + + if batched: + array = array.permute(0, 3, 2, 1) # [B, C, W, H] + array = torch.from_numpy(dct(array.cpu().numpy(), type=2, norm='ortho', axis=-1)).to(array.device) + array = array.permute(0, 1, 3, 2) # [B, C, H, W] + array = torch.from_numpy(dct(array.cpu().numpy(), type=2, norm='ortho', axis=-1)).to(array.device) + array = array.permute(0, 2, 3, 1) # [B, H, W, C] + else: + array = array.permute(2, 1, 0) # [C, W, H] + array = torch.from_numpy(dct(array.cpu().numpy(), type=2, norm='ortho', axis=-1)).to(array.device) + array = array.permute(0, 2, 1) # [C, H, W] + array = torch.from_numpy(dct(array.cpu().numpy(), type=2, norm='ortho', axis=-1)).to(array.device) + array = array.permute(1, 2, 0) # [H, W, C] + + array = array.to(dtype) + assert array.shape == shape, f"DCT shape mismatch: {array.shape} vs {shape}" + return array + +# --------------------------- DCT Layer (PyTorch implementation) --------------------------- +class DCTLayer(nn.Module): + def __init__(self, mean: np.ndarray = None, var: np.ndarray = None): + super().__init__() + self.use_normalize = mean is not None and var is not None # whether normalization is applied + if self.use_normalize: + self.mean = torch.tensor(mean, dtype=torch.float32) + self.var = torch.tensor(var, dtype=torch.float32) + self.std = torch.sqrt(self.var) + self.register_buffer('mean_w', self.mean) + self.register_buffer('std_w', self.std) + + def forward(self, inputs: Tensor) -> Tensor: + # DCT transform + logarithmic scaling (core logic preserved) + x = dct2_torch(inputs, batched=True) + x = torch.abs(x) + 1e-13 + x = torch.log(x) + + # Apply normalization only when mean/var are provided + if self.use_normalize: + x = (x - self.mean_w) / self.std_w + return x + +# --------------------------- Simple CNN backbone (adapted to the YAML backbone_config) --------------------------- +class SimpleCNN(nn.Module): + def __init__(self, input_shape: tuple, backbone_config: dict): + super().__init__() + self.backbone_config = backbone_config + self.num_classes = backbone_config['num_classes'] # read the number of classes from backbone_config + self.dropout_rate = backbone_config['dropout'] # read the dropout rate from backbone_config + self.input_shape = input_shape # (H, W, C) + + # CNN backbone (same as the original structure, with dropout added) + self.conv_layers = nn.Sequential( + nn.Conv2d(input_shape[-1], 3, kernel_size=3, padding='same'), + nn.ReLU(inplace=True), + nn.Conv2d(3, 8, kernel_size=3, padding='same'), + nn.ReLU(inplace=True), + nn.AvgPool2d(kernel_size=2), + nn.Conv2d(8, 16, kernel_size=3, padding='same'), + nn.ReLU(inplace=True), + nn.AvgPool2d(kernel_size=2), + nn.Conv2d(16, 32, kernel_size=3, padding='same'), + nn.ReLU(inplace=True), + nn.Flatten() + ) + + + self.head = nn.Linear(32*64*64, 1024) + + # Dynamically compute the flattened dimension + with torch.no_grad(): + dummy_input = torch.randn(1, input_shape[-1], input_shape[0], input_shape[1]) + flat_dim = self.conv_layers(dummy_input).shape[1] + + # Multi-class classification head (with dropout, adapted to backbone_config) + self.classifier = nn.Sequential( + nn.Dropout(self.dropout_rate), + nn.Linear(1024, self.num_classes) + ) + + def forward(self, x: Tensor) -> Tensor: + x = self.conv_layers(x) + x=self.head(x) + x = self.classifier(x) # output logits without activation + return x + + def features(self, x: Tensor) -> Tensor: + return self.head(self.conv_layers(x)) + +# --------------------------- CNN_DCT detector (fully adapted to the YAML configuration) --------------------------- +@DETECTOR.register_module(module_name='cnn_dct') +class CNNDCTDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + + # 1. Read the core configuration from YAML (aligned with backbone_config) + self.backbone_config = config['backbone_config'] # Read backbone_config + self.num_classes = self.backbone_config['num_classes'] # take the 36 classes from backbone_config + self.resolution = config.get('resolution', 256) # read the resolution from YAML (default 256) + self.input_shape = (self.resolution, self.resolution, 3) # (H, W, C) = (256,256,3) + + # 2. Initialize the DCT layer (without normalization, consistent with the previous logic) + self.dct_layer = DCTLayer(mean=None, var=None) + + # 3. Build the backbone (pass in backbone_config) + self.backbone = self.build_backbone(config) + + # 4. Build the loss function (read loss_func from YAML) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + """Build the CNN backbone fully according to backbone_config from the YAML file""" + cnn = SimpleCNN( + input_shape=self.input_shape, + backbone_config=self.backbone_config # pass in the complete backbone_config + ) + logger.info( + f"Built SimpleCNN for {self.num_classes}-class classification. " + f"Input shape: {self.input_shape}, Dropout: {self.backbone_config['dropout']}" + ) + return cnn + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + """Feature extraction (DCT + CNN)""" + # Input image: [B, 3, H, W] (PyTorch format, aligned with YAML resolution=256) + img = data_dict['image'] + # Convert to TF format [B, H, W, C] for the DCT layer + img_tf_format = img.permute(0, 2, 3, 1) + # DCT transform (without normalization) + dct_feat = self.dct_layer(img_tf_format) + # Convert back to PyTorch format [B, C, H, W] + dct_feat_pytorch = dct_feat.permute(0, 3, 1, 2) + # Extract CNN features + cnn_feat = self.backbone.features(dct_feat_pytorch) + return cnn_feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + """Multi-class classifier (outputs logits)""" + pred_logits = self.backbone.classifier(features) + return pred_logits + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + """Compute the multi-classification loss""" + label = data_dict['label'] # Multi-class labels: [B] (0~35) + pred_logits = pred_dict['cls'] # [B, 36] + + loss = self.loss_func(pred_logits, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + """Compute multi-class metrics (acc + mAP)""" + label = data_dict['label'].detach() + pred_logits = pred_dict['cls'].detach() + + acc, mAP = calculate_acc_for_train(label, pred_logits, self.num_classes) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + """Forward pass (adapted for 36-class classification)""" + # 1. Extract features + features = self.features(data_dict) + # 2. Classification prediction (outputs 36-class logits) + pred_logits = self.classifier(features) + # 3. Compute the 36-class probability distribution + pred_prob = F.softmax(pred_logits, dim=1) # [B, 36] + # 4. Predicted class (0~35) + pred_label = torch.argmax(pred_prob, dim=1) # [B] + + pred_dict = { + 'cls': pred_logits, # For training: logits [B,36] + 'prob': pred_prob, # For inference: probabilities [B,36] + 'label': pred_label, # For inference: class [B] + 'feat': features # features + } + return pred_dict diff --git a/training/detectors/core_detector.py b/training/detectors/core_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..072d453945250f49135024d54972ab925a1269af --- /dev/null +++ b/training/detectors/core_detector.py @@ -0,0 +1,124 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CoreDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{ni2022core, + title={Core: Consistent representation learning for face forgery detection}, + author={Ni, Yunsheng and Meng, Depu and Yu, Changqian and Quan, Chengbin and Ren, Dongchun and Zhao, Youjian}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={12--21}, + year={2022} +} + +GitHub Reference: +https://github.com/nii-yamagishilab/Capsule-Forensics-v2 +''' + +import os +import datetime +import logging +import random +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from efficientnet_pytorch import EfficientNet + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='core') +class CoreDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + # if donot load the pretrained weights, fail to get good results + state_dict = torch.load(config['pretrained']) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model successfully!') + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + return self.backbone.features(data_dict['image']) + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + core_feat = pred_dict['core_feat'] + loss = self.loss_func(core_feat, pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the core_feat for loss + core_feat = nn.ReLU(inplace=False)(features) + core_feat= F.adaptive_avg_pool2d(core_feat, (1, 1)) + core_feat = core_feat.view(core_feat.size(0), -1) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features, 'core_feat': core_feat} + + return pred_dict + diff --git a/training/detectors/d_det.py b/training/detectors/d_det.py new file mode 100644 index 0000000000000000000000000000000000000000..3f6bd6a76dd56915c79c8242117e74657bf53225 --- /dev/null +++ b/training/detectors/d_det.py @@ -0,0 +1,335 @@ +# detector_dna.py + +import os +import logging +import numpy as np +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from scipy.fftpack import dct, idct + +from metrics.base_metrics_class import calculate_acc_for_train # multi-class metrics +from .base_detector import AbstractDetector +from detectors import DETECTOR +from loss import LOSSFUNC +logger = logging.getLogger(__name__) + +#################################### +# Original network definitions: Simple_CNN and SupConNet +#################################### + +class vgg_layer(nn.Module): + def __init__(self, nin, nout): + super(vgg_layer, self).__init__() + self.main = nn.Sequential( + nn.Conv2d(nin, nout, 3, 1, 1), + nn.BatchNorm2d(nout), + nn.LeakyReLU(0.2) + ) + + def forward(self, input): + return self.main(input) + +class dcgan_conv(nn.Module): + def __init__(self, nin, nout): + super(dcgan_conv, self).__init__() + self.main = nn.Sequential( + nn.Conv2d(nin, nout, 4, 2, 1), + nn.BatchNorm2d(nout), + nn.LeakyReLU(0.2), + ) + + def forward(self, input): + return self.main(input) + +class Simple_CNN(nn.Module): + def __init__(self, class_num, pretrain=False): + super(Simple_CNN, self).__init__() + nc = 3 + nf = 64 + self.main = nn.Sequential( + dcgan_conv(nc, nf), + vgg_layer(nf, nf), + + dcgan_conv(nf, nf * 2), + vgg_layer(nf * 2, nf * 2), + + dcgan_conv(nf * 2, nf * 4), + vgg_layer(nf * 4, nf * 4), + + dcgan_conv(nf * 4, nf * 8), + vgg_layer(nf * 8, nf * 8), + ) + self.pool = nn.AdaptiveAvgPool2d(1) + self.classification_head = nn.Sequential( + nn.Dropout(p=0.2, inplace=True), + nn.Linear(nf * 8, class_num, bias=True) + ) + self.pretrain = pretrain + + def forward(self, input): + embedding = self.main(input) # [B, nf*8, H', W'] + feature = self.pool(embedding) # [B, nf*8, 1, 1] + feature = feature.view(feature.shape[0], -1) # [B, nf*8] + cls_out = self.classification_head(feature) # [B, num_classes] + if not self.pretrain: + cls_out = F.softmax(cls_out, dim=1) + return cls_out, embedding + +class SupConNet(nn.Module): + """backbone + projection head""" + def __init__(self, backbone, head='mlp', dim_in=512, feat_dim=128): + super(SupConNet, self).__init__() + self.backbone = backbone + if head == 'linear': + self.head = nn.Linear(dim_in, feat_dim) + elif head == 'mlp': + self.head = nn.Sequential( + nn.Linear(dim_in, dim_in), + nn.ReLU(inplace=True), + nn.Linear(dim_in, feat_dim) + ) + else: + raise ValueError(f'Unknown head type: {head}') + + def forward(self, x): + # cls_out: softmax logits from Simple_CNN + cls_out, embedding = self.backbone(x) # embedding: [B, C, H, W] + feat = self.backbone.pool(embedding) # [B, C, 1, 1] + feat = feat.view(feat.shape[0], -1) # [B, C] + feat = F.normalize(self.head(feat), dim=1) # [B, feat_dim] + return cls_out, feat, embedding + + +############################## +# SupConLoss and AWL +############################## + +class SupConLoss(nn.Module): + def __init__(self, temperature=0.07, contrast_mode='all', + base_temperature=0.07): + super(SupConLoss, self).__init__() + self.temperature = temperature + self.contrast_mode = contrast_mode + self.base_temperature = base_temperature + + def forward(self, features, labels=None, mask=None): + device = (torch.device('cuda') + if features.is_cuda + else torch.device('cpu')) + + if len(features.shape) < 3: + raise ValueError('`features` needs to be [bsz, n_views, ...],' + 'at least 3 dimensions are required') + if len(features.shape) > 3: + features = features.view(features.shape[0], features.shape[1], -1) + + batch_size = features.shape[0] + if labels is not None and mask is not None: + raise ValueError('Cannot define both `labels` and `mask`') + elif labels is None and mask is None: + mask = torch.eye(batch_size, dtype=torch.float32).to(device) + elif labels is not None: + labels = labels.contiguous().view(-1, 1) + if labels.shape[0] != batch_size: + raise ValueError('Num of labels does not match num of features') + mask = torch.eq(labels, labels.T).float().to(device) + else: + mask = mask.float().to(device) + + contrast_count = features.shape[1] + contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) + + if self.contrast_mode == 'one': + anchor_feature = features[:, 0] + anchor_count = 1 + elif self.contrast_mode == 'all': + anchor_feature = contrast_feature + anchor_count = contrast_count + else: + raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) + + # compute logits + anchor_dot_contrast = torch.div( + torch.matmul(anchor_feature, contrast_feature.T), + self.temperature) + + # for numerical stability + logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) + logits = anchor_dot_contrast - logits_max.detach() + + # tile mask + mask = mask.repeat(anchor_count, contrast_count) + + # mask-out self-contrast cases + logits_mask = torch.scatter( + torch.ones_like(mask), + 1, + torch.arange(batch_size * anchor_count).view(-1, 1).to(device), + 0 + ) + mask = mask * logits_mask + + # compute log_prob + exp_logits = torch.exp(logits) * logits_mask + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) + pos_per_sample = mask.sum(1) # [bsz * anchor_count] + + valid_mask = pos_per_sample > 0 + mean_log_prob_pos = (mask * log_prob).sum(1) + mean_log_prob_pos = mean_log_prob_pos[valid_mask] / (pos_per_sample[valid_mask] + 1e-8) + + loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos + loss = loss.mean() + + return loss + +class AutomaticWeightedLoss(nn.Module): + def __init__(self, num=2): + super(AutomaticWeightedLoss, self).__init__() + params = torch.ones(num, requires_grad=True) + self.params = torch.nn.Parameter(params) + + def forward(self, *x): + loss_sum = 0 + for i, loss in enumerate(x): + loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2) + return loss_sum + + +#################################### +# DNADetector: inherits from AbstractDetector +#################################### + +@DETECTOR.register_module(module_name='dna_det') +class DNADetector(AbstractDetector): + """ + DNA_DET-style detector: + - backbone: Simple_CNN + SupConNet + - loss: automatically weighted CE + SupConLoss + """ + def __init__(self, config, load_param: Union[bool, str] = False): + super().__init__(config=config, load_param=load_param) + self.config = config + self.backbone_config = config['backbone_config'] # Read backbone_config + self.num_classes = self.backbone_config['num_classes'] # take the 36 classes from backbone_config + # Build the backbone (SupConNet(Simple_CNN)) + self.backbone = self.build_backbone(config) + + # Build the loss-related modules (CE, SupCon, AWL) + self.loss_modules = self.build_loss(config) + + #################################### + # Build backbone / loss + #################################### + def build_backbone(self, config): + bb_cfg = config['backbone_config'] + num_classes = bb_cfg.get('num_classes', 2) + pretrain = bb_cfg.get('pretrain', False) + head_type = bb_cfg.get('head', 'mlp') + dim_in = bb_cfg.get('dim_in', 512) + feat_dim = bb_cfg.get('feat_dim', 128) + + base_cnn = Simple_CNN(num_classes, pretrain=pretrain) + backbone = SupConNet( + backbone=base_cnn, + head=head_type, + dim_in=dim_in, + feat_dim=feat_dim + ) + return backbone + + def build_loss(self, config): + temperature = config.get('temperature', 0.07) + criterion_ce = nn.CrossEntropyLoss() + criterion_con = SupConLoss(temperature=temperature) + awl = AutomaticWeightedLoss(num=2) + return { + 'ce': criterion_ce, + 'con': criterion_con, + 'awl': awl + } + + #################################### + # Interfaces that AbstractDetector must implement + #################################### + def features(self, data_dict: dict) -> torch.Tensor: + """ + Return the embedding features from the backbone network. + """ + x = data_dict['image'] # [B, 3, H, W] + cls_out, feat_vec, embedding = self.backbone(x) + # Return convolutional embeddings for visualization or post-processing. + return embedding + + def classifier(self, features: torch.Tensor) -> torch.Tensor: + """ + Use the backbone classification head to output class logits or probabilities. + For consistency, classification is handled directly in `forward`. + If you want to avoid repeated computation, construct `pred_dict` directly in `forward`. + """ + # Assume the input is `data_dict['image']` rather than an embedding tensor, + # for simplicity, classification is handled directly in `forward`, and `classifier` is not called separately. + raise NotImplementedError( + "Classification in DNADetector is handled inside `forward`; `classifier(features)` is not called separately." + ) + + def forward(self, data_dict: dict, inference: bool = False) -> dict: + """ + Forward process: + - Input: `data_dict['image']` with shape [B, 3, H, W] + - Output: `pred_dict = {'cls', 'prob', 'feat', 'embedding'}` + """ + x = data_dict['image'] + cls_out, feat_vec, embedding = self.backbone(x) # cls_out: [B, C], feat_vec: [B, feat_dim] + + # `prob` is already softmax output. Remove softmax in `Simple_CNN` if logits are preferred here. + prob = cls_out + + # SupConLoss expects features shaped as [B, n_views, feat_dim]. + # Assume a single view here and unsqueeze on dimension 1. + contrast_feat = feat_vec.unsqueeze(1) # [B, 1, feat_dim] + + pred_dict = { + 'cls': cls_out, # [B, num_classes] + 'prob': prob, # same as cls_out (softmax) + 'feat': contrast_feat, # [B, 1, feat_dim] for SupConLoss + 'embedding': embedding # [B, C, H', W'] optional convolutional features + } + return pred_dict + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + """ + Compute the total loss and its components: + - CE loss: classification cross-entropy + - SupCon loss: contrastive learning loss + - AWL: automatically weighted overall loss + """ + label = data_dict['label'] # [B] + cls = pred_dict['cls'] # [B, num_classes] + feat = pred_dict['feat'] # [B, n_views, feat_dim] + + ce = self.loss_modules['ce'](cls, label) + con = self.loss_modules['con'](feat, labels=label) + overall = self.loss_modules['awl'](ce, con) + + loss_dict = { + 'overall': overall, + 'ce': ce, + 'con': con + } + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + """Compute multi-class metrics (acc + mAP)""" + label = data_dict['label'].detach() + pred_logits = pred_dict['cls'].detach() + + acc, mAP = calculate_acc_for_train(label, pred_logits, self.num_classes) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + return metric_batch_dict + diff --git a/training/detectors/dino_contrast.py b/training/detectors/dino_contrast.py new file mode 100644 index 0000000000000000000000000000000000000000..e81171c934056b9a9d392c11db61291861e5a5e6 --- /dev/null +++ b/training/detectors/dino_contrast.py @@ -0,0 +1,112 @@ +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoImageProcessor, AutoModel +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + + +@DETECTOR.register_module(module_name='dinov2_large_fft_contrast') +class DINOv2_Large_FFT_Contrast_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.num_classes = config['backbone_config']['num_classes'] # 36 classes + self.backbone = self.build_backbone(config) + + # 36-class head (DINOv2-Large 1024 dimensions -> 36 dimensions) + self.head = nn.Linear(1024, self.num_classes) + + # Losses: classification loss as the main objective + SupCon contrastive loss as an auxiliary term + self.cls_loss_func = self.build_loss(config) + supconloss_claas=LOSSFUNC["supcon"] + self.supcon_loss_func = supconloss_claas() + + def build_backbone(self, config): + backbone = get_clip_visual(model_name=config['pretrained']) + return backbone + + def build_loss(self, config): + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + """Extract DINOv2-Large features during full fine-tuning.""" + img = data_dict['image'] + feat = self.backbone(data_dict['image']) + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + """36-class head.""" + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + """Combined loss: 36-class classification (90%) + SupCon contrastive loss (10%).""" + # 1. Main 36-class classification loss for full fine-tuning + label = data_dict['label'] + pred = pred_dict['cls'] + cls_loss = self.cls_loss_func(pred, label) + + # 2. SupCon contrastive loss (single view: use image features only, expanded to [B, 1, 1024]) + feat = pred_dict['feat'] # [B, 1024] → [B, 1, 1024] + supcon_loss = self.supcon_loss_func(feat, label) + + # 3. Total loss(classification-dominant with contrastive assistance) + contrast_weight = self.config.get('contrast_weight', 0.05) + contrast_weight=0.1 + total_loss = cls_loss + contrast_weight * supcon_loss + + loss_dict = { + 'overall': total_loss, + 'cls_loss': cls_loss, + 'supcon_loss': supcon_loss + } + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + """36-class metrics (Acc + mAP).""" + label = data_dict['label'].detach() + pred = pred_dict['cls'].detach() + acc, mAP = calculate_acc_for_train(label, pred, self.num_classes) + return {'acc': acc, 'mAP': mAP} + + def forward(self, data_dict: dict, inference=False) -> dict: + + # Single-view feature extraction (unchanged) + features = self.features(data_dict) # [B, 1024] + pred = self.classifier(features) + prob = torch.softmax(pred, dim=1) + + pred_dict = { + 'cls': pred, + 'prob': prob, + 'feat': features # directly use single-view features for contrastive learning + } + return pred_dict + +def get_clip_visual(model_name = 'dinov2_vitl14'): + dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', model_name) + return dinov2_vitl14 + diff --git a/training/detectors/dinov2_large_fft_detector.py b/training/detectors/dinov2_large_fft_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..200124d0563f9977f59cde8fdff7575f0158d418 --- /dev/null +++ b/training/detectors/dinov2_large_fft_detector.py @@ -0,0 +1,119 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoImageProcessor, AutoModel +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='dinov2_large_fft') +class DINOv2_Large_FFT_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone = get_clip_visual(model_name=config['pretrained']) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image']) + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = 'dinov2_vitl14'): + dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', model_name) + return dinov2_vitl14 diff --git a/training/detectors/dinov3_large_fft_detector.py b/training/detectors/dinov3_large_fft_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..2bd9b4c9b7bd831a6fcb7f9d4b0910cfdce424ac --- /dev/null +++ b/training/detectors/dinov3_large_fft_detector.py @@ -0,0 +1,122 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the CLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{radford2021learning, + title={Learning transferable visual models from natural language supervision}, + author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, + booktitle={International conference on machine learning}, + pages={8748--8763}, + year={2021}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoImageProcessor, AutoModel +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='dinov3_large_fft') +class DINOv3_Large_FFT_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone = get_clip_visual(model_name=config['pretrained']) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image']) + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +def get_clip_visual(model_name = 'dinov3_vitl16'): + dinov3_vitl14 = torch.hub.load('facebookresearch/dinov3', model_name) + return dinov3_vitl14 + +# urllib.error.HTTPError: HTTP Error 403: rate limit exceeded +# Use a single GPU first to perform hub loading diff --git a/training/detectors/efficientnetb4_detector.py b/training/detectors/efficientnetb4_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..43d104e40b56a28d8116e37629370e1dd50b2a33 --- /dev/null +++ b/training/detectors/efficientnetb4_detector.py @@ -0,0 +1,114 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the EfficientDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{tan2019efficientnet, + title={Efficientnet: Rethinking model scaling for convolutional neural networks}, + author={Tan, Mingxing and Le, Quoc}, + booktitle={International conference on machine learning}, + pages={6105--6114}, + year={2019}, + organization={PMLR} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +import random + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='efficientnetb4') +class EfficientDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + model_config['pretrained'] = self.config['pretrained'] + backbone = backbone_class(model_config) + if config['pretrained'] != 'None': + logger.info('Load pretrained model successfully!') + else: + logger.info('No pretrained model.') + return backbone + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + x = self.backbone.features(data_dict['image']) + return x + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + + return pred_dict + diff --git a/training/detectors/effort_cl_detector.py b/training/detectors/effort_cl_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..e3adb237aadf5a5ca0a4892b1f27384ab15fb960 --- /dev/null +++ b/training/detectors/effort_cl_detector.py @@ -0,0 +1,386 @@ +import os +import math +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +import loralib as lora +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='effort_cl') +class EffortCLDetector(nn.Module): + def __init__(self, config=None): + super(EffortCLDetector, self).__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + + self.mlp_cl = nn.Sequential( + nn.Linear(1024, 1024), + nn.ReLU(inplace=True), + nn.Linear(1024, 1024), + nn.ReLU(inplace=True), + nn.Linear(1024, 1024), + nn.ReLU(inplace=True) + ) + self.head_cl = nn.Linear(1024, config['backbone_config']['num_classes']) + + loss_class = LOSSFUNC[config['loss_func']] + self.loss_func = loss_class() + + self.prob, self.label = [], [] + self.correct, self.total = 0, 0 + + # Step 1: load the previous weights + ckpt = torch.load("logs/effort/effort_2025-06-09-13-32-29/test/Celeb-DF-v2/ckpt_best.pth") + new_state_dict = {k.replace('module.', ''): v for k, v in ckpt.items()} + self.load_state_dict(new_state_dict, strict=False) + + # Step 2: freeze the backbone and head + for param in self.backbone.parameters(): + param.requires_grad = False + for param in self.head.parameters(): + param.requires_grad = False + + # Step 3: add a new supervised contrastive loss layer + + + def build_backbone(self, config): + # Download model + # https://huggingface.co/openai/clip-vit-large-patch14 + + # mean: [0.48145466, 0.4578275, 0.40821073] + # std: [0.26862954, 0.26130258, 0.27577711] + + # ViT-L/14 224*224 + clip_model = CLIPModel.from_pretrained(self.config["pretrained"]) + + # Apply SVD to self_attn layers only + # ViT-L/14 224*224: 1024-1 + clip_model.vision_model = apply_svd_residual_to_self_attn(clip_model.vision_model, r=1024-1) + + # for name, param in clip_model.vision_model.named_parameters(): + # print('{}: {}'.format(name, param.requires_grad)) + num_param = sum(p.numel() for p in clip_model.vision_model.parameters() if p.requires_grad) + num_total_param = sum(p.numel() for p in clip_model.vision_model.parameters()) + print('Number of total parameters: {}, tunable parameters: {}'.format(num_total_param, num_param)) + + return clip_model.vision_model + + def features(self, data_dict: dict) -> torch.tensor: + # data_dict['image']: torch.Size([32, 3, 224, 224]) + feat = self.backbone(data_dict['image'])['pooler_output'] + # feat torch.Size([32, 1024]) + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + # def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + # label = data_dict['label'] + # pred = pred_dict['cls'] + # loss = self.loss_func(pred, label) + + # # Regularization term + # lambda_reg = 0.1 + # orthogonal_losses = [] + # for module in self.backbone.modules(): + # if isinstance(module, SVDResidualLinear): + # # Apply orthogonal constraints to the U_residual and V_residual matrix + # orthogonal_losses.append(module.compute_orthogonal_loss()) + + # if orthogonal_losses: + # reg_term = sum(orthogonal_losses) + # loss += lambda_reg * reg_term + + # loss_dict = {'overall': loss} + # return loss_dict + + def compute_weight_loss(self): + weight_sum_dict = {} + num_weight_dict = {} + for name, module in self.backbone.named_modules(): + if isinstance(module, SVDResidualLinear): + weight_curr = module.compute_current_weight() + if str(weight_curr.size()) not in weight_sum_dict.keys(): + weight_sum_dict[str(weight_curr.size())] = weight_curr + num_weight_dict[str(weight_curr.size())] = 1 + else: + weight_sum_dict[str(weight_curr.size())] += weight_curr + num_weight_dict[str(weight_curr.size())] += 1 + + loss2 = 0.0 + for k in weight_sum_dict.keys(): + _, S_sum, _ = torch.linalg.svd(weight_sum_dict[k], full_matrices=False) + loss2 += -torch.mean(S_sum) + loss2 /= len(weight_sum_dict.keys()) + return loss2 + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] # Tensor of shape [batch_size] + pred = pred_dict['cls'] # Tensor of shape [batch_size, num_classes] + + # Compute overall loss using all samples + loss = self.loss_func(pred_dict['feat'], pred, label) + + # Create masks for real and fake classes + mask_real = label == 0 # Boolean tensor + mask_fake = label == 1 # Boolean tensor + + # Compute loss for real class + if mask_real.sum() > 0: + feat_real = pred_dict['feat'][mask_real] + pred_real = pred[mask_real] + label_real = label[mask_real] + loss_real = self.loss_func(feat_real, pred_real, label_real) + else: + # No real samples in batch + loss_real = torch.tensor(0.0, device=pred.device) + + # Compute loss for fake class + if mask_fake.sum() > 0: + feat_fake = pred_dict['feat'][mask_fake] + pred_fake = pred[mask_fake] + label_fake = label[mask_fake] + loss_fake = self.loss_func(feat_fake, pred_fake, label_fake) + else: + # No fake samples in batch + loss_fake = torch.tensor(0.0, device=pred.device) + + # loss2 = self.compute_weight_loss() + # overall_loss = loss + loss2 + + # Return a dictionary with all losses + loss_dict = { + 'overall': loss, + 'real_loss': loss_real, + 'fake_loss': loss_fake, + # 'erank_loss': loss2 + } + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # CL + features_cl = self.mlp_cl(features) + # CL Head + pred_cl = self.head_cl(features_cl) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob_cl = torch.softmax(pred_cl, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred_cl, 'prob': prob_cl, 'feat': features_cl} + + return pred_dict + +# Custom module to represent the residual using SVD components +class SVDResidualLinear(nn.Module): + def __init__(self, in_features, out_features, r, bias=True, init_weight=None): + super(SVDResidualLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.r = r # Number of top singular values to exclude + + # Original weights (fixed) + self.weight_main = nn.Parameter(torch.Tensor(out_features, in_features), requires_grad=False) + if init_weight is not None: + self.weight_main.data.copy_(init_weight) + else: + nn.init.kaiming_uniform_(self.weight_main, a=math.sqrt(5)) + + # Bias + if bias: + self.bias = nn.Parameter(torch.Tensor(out_features)) + nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + def compute_current_weight(self): + if self.S_residual is not None: + return self.weight_main + self.U_residual @ torch.diag(self.S_residual) @ self.V_residual + else: + return self.weight_main + + def forward(self, x): + if hasattr(self, 'U_residual') and hasattr(self, 'V_residual') and self.S_residual is not None: + # Reconstruct the residual weight + residual_weight = self.U_residual @ torch.diag(self.S_residual) @ self.V_residual + # Total weight is the fixed main weight plus the residual + weight = self.weight_main + residual_weight + else: + # If residual components are not set, use only the main weight + weight = self.weight_main + + return F.linear(x, weight, self.bias) + + def compute_orthogonal_loss(self): + if self.S_residual is not None: + # According to the properties of orthogonal matrices: A^TA = I + UUT = torch.cat((self.U_r, self.U_residual), dim=1) @ torch.cat((self.U_r, self.U_residual), dim=1).t() + VVT = torch.cat((self.V_r, self.V_residual), dim=0) @ torch.cat((self.V_r, self.V_residual), dim=0).t() + # print(self.U_r.size(), self.U_residual.size()) # torch.Size([1024, 1023]) torch.Size([1024, 1]) + # print(self.V_r.size(), self.V_residual.size()) # torch.Size([1023, 1024]) torch.Size([1, 1024]) + # UUT = self.U_residual @ self.U_residual.t() + # VVT = self.V_residual @ self.V_residual.t() + + # Construct an identity matrix + UUT_identity = torch.eye(UUT.size(0), device=UUT.device) + VVT_identity = torch.eye(VVT.size(0), device=VVT.device) + + # Using frobenius norm to compute loss + loss = 0.5 * torch.norm(UUT - UUT_identity, p='fro') + 0.5 * torch.norm(VVT - VVT_identity, p='fro') + else: + loss = 0.0 + + return loss + + def compute_keepsv_loss(self): + if (self.S_residual is not None) and (self.weight_original_fnorm is not None): + # Total current weight is the fixed main weight plus the residual + weight_current = self.weight_main + self.U_residual @ torch.diag(self.S_residual) @ self.V_residual + # Frobenius norm of current weight + weight_current_fnorm = torch.norm(weight_current, p='fro') + + loss = torch.abs(weight_current_fnorm ** 2 - self.weight_original_fnorm ** 2) + # loss = torch.abs(weight_current_fnorm ** 2 + 0.01 * self.weight_main_fnorm ** 2 - 1.01 * self.weight_original_fnorm ** 2) + else: + loss = 0.0 + + return loss + + def compute_fn_loss(self): + if (self.S_residual is not None): + weight_current = self.weight_main + self.U_residual @ torch.diag(self.S_residual) @ self.V_residual + weight_current_fnorm = torch.norm(weight_current, p='fro') + + loss = weight_current_fnorm ** 2 + else: + loss = 0.0 + + return loss + + +# Function to replace nn.Linear modules within self_attn modules with SVDResidualLinear +def apply_svd_residual_to_self_attn(model, r): + for name, module in model.named_children(): + if 'self_attn' in name: + # Replace nn.Linear layers in this module + for sub_name, sub_module in module.named_modules(): + if isinstance(sub_module, nn.Linear): + # Get parent module within self_attn + parent_module = module + sub_module_names = sub_name.split('.') + for module_name in sub_module_names[:-1]: + parent_module = getattr(parent_module, module_name) + # Replace the nn.Linear layer with SVDResidualLinear + setattr(parent_module, sub_module_names[-1], replace_with_svd_residual(sub_module, r)) + else: + # Recursively apply to child modules + apply_svd_residual_to_self_attn(module, r) + # After replacing, set requires_grad for residual components + for param_name, param in model.named_parameters(): + if any(x in param_name for x in ['S_residual', 'U_residual', 'V_residual']): + param.requires_grad = True + else: + param.requires_grad = False + return model + + +# Function to replace a module with SVDResidualLinear +def replace_with_svd_residual(module, r): + if isinstance(module, nn.Linear): + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + + # Create SVDResidualLinear module + new_module = SVDResidualLinear(in_features, out_features, r, bias=bias, init_weight=module.weight.data.clone()) + + if bias and module.bias is not None: + new_module.bias.data.copy_(module.bias.data) + + new_module.weight_original_fnorm = torch.norm(module.weight.data, p='fro') + + # Perform SVD on the original weight + U, S, Vh = torch.linalg.svd(module.weight.data, full_matrices=False) + + # Determine r based on the rank of the weight matrix + r = min(r, len(S)) # Ensure r does not exceed the number of singular values + + # Keep top r singular components (main weight) + U_r = U[:, :r] # Shape: (out_features, r) + S_r = S[:r] # Shape: (r,) + Vh_r = Vh[:r, :] # Shape: (r, in_features) + + # Reconstruct the main weight (fixed) + weight_main = U_r @ torch.diag(S_r) @ Vh_r + + # Calculate the frobenius norm of main weight + new_module.weight_main_fnorm = torch.norm(weight_main.data, p='fro') + + # Set the main weight + new_module.weight_main.data.copy_(weight_main) + + # Residual components (trainable) + U_residual = U[:, r:] # Shape: (out_features, n - r) + S_residual = S[r:] # Shape: (n - r,) + Vh_residual = Vh[r:, :] # Shape: (n - r, in_features) + + if len(S_residual) > 0: + new_module.S_residual = nn.Parameter(S_residual.clone()) + new_module.U_residual = nn.Parameter(U_residual.clone()) + new_module.V_residual = nn.Parameter(Vh_residual.clone()) + + new_module.S_r = nn.Parameter(S_r.clone(), requires_grad=False) + new_module.U_r = nn.Parameter(U_r.clone(), requires_grad=False) + new_module.V_r = nn.Parameter(Vh_r.clone(), requires_grad=False) + else: + new_module.S_residual = None + new_module.U_residual = None + new_module.V_residual = None + + new_module.S_r = None + new_module.U_r = None + new_module.V_r = None + + return new_module + else: + return module + + diff --git a/training/detectors/effort_detector.py b/training/detectors/effort_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..1f954045779bf08fdc381bce6bec34f6b9eaeaed --- /dev/null +++ b/training/detectors/effort_detector.py @@ -0,0 +1,523 @@ +import os +import math +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +import loralib as lora +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='effort') +class EffortDetector(nn.Module): + def __init__(self, config=None): + super(EffortDetector, self).__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = nn.CrossEntropyLoss() + self.prob, self.label = [], [] + self.correct, self.total = 0, 0 + + def build_backbone(self, config): + # Download model + # https://huggingface.co/openai/clip-vit-large-patch14 + + # mean: [0.48145466, 0.4578275, 0.40821073] + # std: [0.26862954, 0.26130258, 0.27577711] + + # ViT-L/14 224*224 + clip_model = CLIPModel.from_pretrained(self.config["pretrained"]) + + # Apply SVD to self_attn layers only + # ViT-L/14 224*224: 1024-1 + clip_model.vision_model = apply_svd_residual_to_self_attn(clip_model.vision_model, r=1024-64) + + for name, param in clip_model.vision_model.named_parameters(): + print('{}: {}'.format(name, param.requires_grad)) + num_param = sum(p.numel() for p in clip_model.vision_model.parameters() if p.requires_grad) + num_total_param = sum(p.numel() for p in clip_model.vision_model.parameters()) + print('Number of total parameters: {}, tunable parameters: {}'.format(num_total_param, num_param)) + + return clip_model.vision_model + + def features(self, data_dict: dict) -> torch.tensor: + # data_dict['image']: torch.Size([32, 3, 224, 224]) + feat = self.backbone(data_dict['image'])['pooler_output'] + # feat torch.Size([32, 1024]) + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + # def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + # label = data_dict['label'] + # pred = pred_dict['cls'] + # loss = self.loss_func(pred, label) + + # # Regularization term + # lambda_reg = 0.1 + # orthogonal_losses = [] + # for module in self.backbone.modules(): + # if isinstance(module, SVDResidualLinear): + # # Apply orthogonal constraints to the U_residual and V_residual matrix + # orthogonal_losses.append(module.compute_orthogonal_loss()) + + # if orthogonal_losses: + # reg_term = sum(orthogonal_losses) + # loss += lambda_reg * reg_term + + # loss_dict = {'overall': loss} + # return loss_dict + + def compute_weight_loss(self): + weight_sum_dict = {} + num_weight_dict = {} + for name, module in self.backbone.named_modules(): + if isinstance(module, SVDResidualLinear): + weight_curr = module.compute_current_weight() + if str(weight_curr.size()) not in weight_sum_dict.keys(): + weight_sum_dict[str(weight_curr.size())] = weight_curr + num_weight_dict[str(weight_curr.size())] = 1 + else: + weight_sum_dict[str(weight_curr.size())] += weight_curr + num_weight_dict[str(weight_curr.size())] += 1 + + loss2 = 0.0 + for k in weight_sum_dict.keys(): + _, S_sum, _ = torch.linalg.svd(weight_sum_dict[k], full_matrices=False) + loss2 += -torch.mean(S_sum) + loss2 /= len(weight_sum_dict.keys()) + return loss2 + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] # Tensor of shape [batch_size] + pred = pred_dict['cls'] # Tensor of shape [batch_size, num_classes] + + # Compute overall loss using all samples + loss = self.loss_func(pred, label) + + # Create masks for real and fake classes + mask_real = label == 0 # Boolean tensor + mask_fake = label == 1 # Boolean tensor + + # Compute loss for real class + if mask_real.sum() > 0: + pred_real = pred[mask_real] + label_real = label[mask_real] + loss_real = self.loss_func(pred_real, label_real) + else: + # No real samples in batch + loss_real = torch.tensor(0.0, device=pred.device) + + # Compute loss for fake class + if mask_fake.sum() > 0: + pred_fake = pred[mask_fake] + label_fake = label[mask_fake] + loss_fake = self.loss_func(pred_fake, label_fake) + else: + # No fake samples in batch + loss_fake = torch.tensor(0.0, device=pred.device) + + # loss2 = self.compute_weight_loss() + # overall_loss = loss + loss2 + + # Return a dictionary with all losses + loss_dict = { + 'overall': loss, + 'real_loss': loss_real, + 'fake_loss': loss_fake, + # 'erank_loss': loss2 + } + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + + return pred_dict + +# Custom module to represent the residual using SVD components +class SVDResidualLinear(nn.Module): + def __init__(self, in_features, out_features, r, bias=True, init_weight=None): + super(SVDResidualLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.r = r # Number of top singular values to exclude + + # Original weights (fixed) + self.weight_main = nn.Parameter(torch.Tensor(out_features, in_features), requires_grad=False) + if init_weight is not None: + self.weight_main.data.copy_(init_weight) + else: + nn.init.kaiming_uniform_(self.weight_main, a=math.sqrt(5)) + + # Bias + if bias: + self.bias = nn.Parameter(torch.Tensor(out_features)) + nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + def compute_current_weight(self): + if self.S_residual is not None: + return self.weight_main + self.U_residual @ torch.diag(self.S_residual) @ self.V_residual + else: + return self.weight_main + + def forward(self, x): + if hasattr(self, 'U_residual') and hasattr(self, 'V_residual') and self.S_residual is not None: + # Reconstruct the residual weight + residual_weight = self.U_residual @ torch.diag(self.S_residual) @ self.V_residual + # Total weight is the fixed main weight plus the residual + weight = self.weight_main + residual_weight + else: + # If residual components are not set, use only the main weight + weight = self.weight_main + + return F.linear(x, weight, self.bias) + + def compute_orthogonal_loss(self): + if self.S_residual is not None: + # According to the properties of orthogonal matrices: A^TA = I + UUT = torch.cat((self.U_r, self.U_residual), dim=1) @ torch.cat((self.U_r, self.U_residual), dim=1).t() + VVT = torch.cat((self.V_r, self.V_residual), dim=0) @ torch.cat((self.V_r, self.V_residual), dim=0).t() + # print(self.U_r.size(), self.U_residual.size()) # torch.Size([1024, 1023]) torch.Size([1024, 1]) + # print(self.V_r.size(), self.V_residual.size()) # torch.Size([1023, 1024]) torch.Size([1, 1024]) + # UUT = self.U_residual @ self.U_residual.t() + # VVT = self.V_residual @ self.V_residual.t() + + # Construct an identity matrix + UUT_identity = torch.eye(UUT.size(0), device=UUT.device) + VVT_identity = torch.eye(VVT.size(0), device=VVT.device) + + # Using frobenius norm to compute loss + loss = 0.5 * torch.norm(UUT - UUT_identity, p='fro') + 0.5 * torch.norm(VVT - VVT_identity, p='fro') + else: + loss = 0.0 + + return loss + + def compute_keepsv_loss(self): + if (self.S_residual is not None) and (self.weight_original_fnorm is not None): + # Total current weight is the fixed main weight plus the residual + weight_current = self.weight_main + self.U_residual @ torch.diag(self.S_residual) @ self.V_residual + # Frobenius norm of current weight + weight_current_fnorm = torch.norm(weight_current, p='fro') + + loss = torch.abs(weight_current_fnorm ** 2 - self.weight_original_fnorm ** 2) + # loss = torch.abs(weight_current_fnorm ** 2 + 0.01 * self.weight_main_fnorm ** 2 - 1.01 * self.weight_original_fnorm ** 2) + else: + loss = 0.0 + + return loss + + def compute_fn_loss(self): + if (self.S_residual is not None): + weight_current = self.weight_main + self.U_residual @ torch.diag(self.S_residual) @ self.V_residual + weight_current_fnorm = torch.norm(weight_current, p='fro') + + loss = weight_current_fnorm ** 2 + else: + loss = 0.0 + + return loss + + +# Function to replace nn.Linear modules within self_attn modules with SVDResidualLinear +def apply_svd_residual_to_self_attn(model, r): + for name, module in model.named_children(): + if 'self_attn' in name: + # Replace nn.Linear layers in this module + for sub_name, sub_module in module.named_modules(): + if isinstance(sub_module, nn.Linear): + # Get parent module within self_attn + parent_module = module + sub_module_names = sub_name.split('.') + for module_name in sub_module_names[:-1]: + parent_module = getattr(parent_module, module_name) + # Replace the nn.Linear layer with SVDResidualLinear + setattr(parent_module, sub_module_names[-1], replace_with_svd_residual(sub_module, r)) + else: + # Recursively apply to child modules + apply_svd_residual_to_self_attn(module, r) + # After replacing, set requires_grad for residual components + for param_name, param in model.named_parameters(): + if any(x in param_name for x in ['S_residual', 'U_residual', 'V_residual']): + param.requires_grad = True + else: + param.requires_grad = False + return model + + +# Function to replace a module with SVDResidualLinear +def replace_with_svd_residual(module, r): + if isinstance(module, nn.Linear): + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + + # Create SVDResidualLinear module + new_module = SVDResidualLinear(in_features, out_features, r, bias=bias, init_weight=module.weight.data.clone()) + + if bias and module.bias is not None: + new_module.bias.data.copy_(module.bias.data) + + new_module.weight_original_fnorm = torch.norm(module.weight.data, p='fro') + + # Perform SVD on the original weight + U, S, Vh = torch.linalg.svd(module.weight.data, full_matrices=False) + + # Determine r based on the rank of the weight matrix + r = min(r, len(S)) # Ensure r does not exceed the number of singular values + + # Keep top r singular components (main weight) + U_r = U[:, :r] # Shape: (out_features, r) + S_r = S[:r] # Shape: (r,) + Vh_r = Vh[:r, :] # Shape: (r, in_features) + + # Reconstruct the main weight (fixed) + weight_main = U_r @ torch.diag(S_r) @ Vh_r + + # Calculate the frobenius norm of main weight + new_module.weight_main_fnorm = torch.norm(weight_main.data, p='fro') + + # Set the main weight + new_module.weight_main.data.copy_(weight_main) + + # Residual components (trainable) + U_residual = U[:, r:] # Shape: (out_features, n - r) + S_residual = S[r:] # Shape: (n - r,) + Vh_residual = Vh[r:, :] # Shape: (n - r, in_features) + + if len(S_residual) > 0: + new_module.S_residual = nn.Parameter(S_residual.clone()) + new_module.U_residual = nn.Parameter(U_residual.clone()) + new_module.V_residual = nn.Parameter(Vh_residual.clone()) + + new_module.S_r = nn.Parameter(S_r.clone(), requires_grad=False) + new_module.U_r = nn.Parameter(U_r.clone(), requires_grad=False) + new_module.V_r = nn.Parameter(Vh_r.clone(), requires_grad=False) + else: + new_module.S_residual = None + new_module.U_residual = None + new_module.V_residual = None + + new_module.S_r = None + new_module.U_r = None + new_module.V_r = None + + return new_module + else: + return module + +''' +Training Params: + +embeddings.class_embedding: False +embeddings.patch_embedding.weight: False +embeddings.position_embedding.weight: False +pre_layrnorm.weight: False +pre_layrnorm.bias: False + +encoder.layers.0.self_attn.k_proj.weight_main: False +encoder.layers.0.self_attn.k_proj.bias: False +encoder.layers.0.self_attn.k_proj.S_residual: True +encoder.layers.0.self_attn.k_proj.U_residual: True +encoder.layers.0.self_attn.k_proj.V_residual: True +encoder.layers.0.self_attn.v_proj.weight_main: False +encoder.layers.0.self_attn.v_proj.bias: False +encoder.layers.0.self_attn.v_proj.S_residual: True +encoder.layers.0.self_attn.v_proj.U_residual: True +encoder.layers.0.self_attn.v_proj.V_residual: True +encoder.layers.0.self_attn.q_proj.weight_main: False +encoder.layers.0.self_attn.q_proj.bias: False +encoder.layers.0.self_attn.q_proj.S_residual: True +encoder.layers.0.self_attn.q_proj.U_residual: True +encoder.layers.0.self_attn.q_proj.V_residual: True +encoder.layers.0.self_attn.out_proj.weight_main: False +encoder.layers.0.self_attn.out_proj.bias: False +encoder.layers.0.self_attn.out_proj.S_residual: True +encoder.layers.0.self_attn.out_proj.U_residual: True +encoder.layers.0.self_attn.out_proj.V_residual: True +encoder.layers.0.layer_norm1.weight: False +encoder.layers.0.layer_norm1.bias: False +encoder.layers.0.mlp.fc1.weight: False +encoder.layers.0.mlp.fc1.bias: False +encoder.layers.0.mlp.fc2.weight: False +encoder.layers.0.mlp.fc2.bias: False +encoder.layers.0.layer_norm2.weight: False +encoder.layers.0.layer_norm2.bias: False + +encoder.layers.1.self_attn.k_proj.weight_main: False +encoder.layers.1.self_attn.k_proj.bias: False +encoder.layers.1.self_attn.k_proj.S_residual: True +encoder.layers.1.self_attn.k_proj.U_residual: True +encoder.layers.1.self_attn.k_proj.V_residual: True +encoder.layers.1.self_attn.v_proj.weight_main: False +encoder.layers.1.self_attn.v_proj.bias: False +encoder.layers.1.self_attn.v_proj.S_residual: True +encoder.layers.1.self_attn.v_proj.U_residual: True +encoder.layers.1.self_attn.v_proj.V_residual: True +encoder.layers.1.self_attn.q_proj.weight_main: False +encoder.layers.1.self_attn.q_proj.bias: False +encoder.layers.1.self_attn.q_proj.S_residual: True +encoder.layers.1.self_attn.q_proj.U_residual: True +encoder.layers.1.self_attn.q_proj.V_residual: True +encoder.layers.1.self_attn.out_proj.weight_main: False +encoder.layers.1.self_attn.out_proj.bias: False +encoder.layers.1.self_attn.out_proj.S_residual: True +encoder.layers.1.self_attn.out_proj.U_residual: True +encoder.layers.1.self_attn.out_proj.V_residual: True +encoder.layers.1.layer_norm1.weight: False +encoder.layers.1.layer_norm1.bias: False +encoder.layers.1.mlp.fc1.weight: False +encoder.layers.1.mlp.fc1.bias: False +encoder.layers.1.mlp.fc2.weight: False +encoder.layers.1.mlp.fc2.bias: False +encoder.layers.1.layer_norm2.weight: False +encoder.layers.1.layer_norm2.bias: False + +encoder.layers.2.self_attn.k_proj.weight_main: False +encoder.layers.2.self_attn.k_proj.bias: False +encoder.layers.2.self_attn.k_proj.S_residual: True +encoder.layers.2.self_attn.k_proj.U_residual: True +encoder.layers.2.self_attn.k_proj.V_residual: True +encoder.layers.2.self_attn.v_proj.weight_main: False +encoder.layers.2.self_attn.v_proj.bias: False +encoder.layers.2.self_attn.v_proj.S_residual: True +encoder.layers.2.self_attn.v_proj.U_residual: True +encoder.layers.2.self_attn.v_proj.V_residual: True +encoder.layers.2.self_attn.q_proj.weight_main: False +encoder.layers.2.self_attn.q_proj.bias: False +encoder.layers.2.self_attn.q_proj.S_residual: True +encoder.layers.2.self_attn.q_proj.U_residual: True +encoder.layers.2.self_attn.q_proj.V_residual: True +encoder.layers.2.self_attn.out_proj.weight_main: False +encoder.layers.2.self_attn.out_proj.bias: False +encoder.layers.2.self_attn.out_proj.S_residual: True +encoder.layers.2.self_attn.out_proj.U_residual: True +encoder.layers.2.self_attn.out_proj.V_residual: True +encoder.layers.2.layer_norm1.weight: False +encoder.layers.2.layer_norm1.bias: False +encoder.layers.2.mlp.fc1.weight: False +encoder.layers.2.mlp.fc1.bias: False +encoder.layers.2.mlp.fc2.weight: False +encoder.layers.2.mlp.fc2.bias: False +encoder.layers.2.layer_norm2.weight: False +encoder.layers.2.layer_norm2.bias: False +... +encoder.layers.23.self_attn.k_proj.weight_main: False +encoder.layers.23.self_attn.k_proj.bias: False +encoder.layers.23.self_attn.k_proj.S_residual: True +encoder.layers.23.self_attn.k_proj.U_residual: True +encoder.layers.23.self_attn.k_proj.V_residual: True +encoder.layers.23.self_attn.v_proj.weight_main: False +encoder.layers.23.self_attn.v_proj.bias: False +encoder.layers.23.self_attn.v_proj.S_residual: True +encoder.layers.23.self_attn.v_proj.U_residual: True +encoder.layers.23.self_attn.v_proj.V_residual: True +encoder.layers.23.self_attn.q_proj.weight_main: False +encoder.layers.23.self_attn.q_proj.bias: False +encoder.layers.23.self_attn.q_proj.S_residual: True +encoder.layers.23.self_attn.q_proj.U_residual: True +encoder.layers.23.self_attn.q_proj.V_residual: True +encoder.layers.23.self_attn.out_proj.weight_main: False +encoder.layers.23.self_attn.out_proj.bias: False +encoder.layers.23.self_attn.out_proj.S_residual: True +encoder.layers.23.self_attn.out_proj.U_residual: True +encoder.layers.23.self_attn.out_proj.V_residual: True +encoder.layers.23.layer_norm1.weight: False +encoder.layers.23.layer_norm1.bias: False +encoder.layers.23.mlp.fc1.weight: False +encoder.layers.23.mlp.fc1.bias: False +encoder.layers.23.mlp.fc2.weight: False +encoder.layers.23.mlp.fc2.bias: False +encoder.layers.23.layer_norm2.weight: False +encoder.layers.23.layer_norm2.bias: False + +post_layernorm.weight: False +post_layernorm.bias: False +Number of total parameters: 303376480, tunable parameters: 196704 + + +===> Load checkpoint done! +100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 130/130 [01:07<00:00, 1.92it/s] +dataset: Celeb-DF-v2 +acc: 0.7873882580333413 +auc: 0.8674386218546616 +eer: 0.21000704721634955 +ap: 0.9322288761515111 +pred: [0.9752515 0.6580601 0.75344455 ... 0.45359948 0.8914075 0.14674814] +video_auc: 0.9105750165234634 +label: [1 1 0 ... 1 1 0] +100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 851/851 [07:17<00:00, 1.94it/s] +dataset: DeepFakeDetection +acc: 0.8606078424166698 +auc: 0.9048725171315315 +eer: 0.16390041493775934 +ap: 0.9883843861944681 +pred: [0.9912942 0.4690933 0.99789536 ... 0.8104649 0.9893 0.78386295] +video_auc: 0.9373875743738758 +label: [1 1 1 ... 1 1 1] +100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [01:21<00:00, 1.94it/s] +dataset: DFDCP +acc: 0.7002589125672177 +auc: 0.8182703419711848 +eer: 0.28125 +ap: 0.9055071587990912 +pred: [0.38877106 0.606897 0.5618232 ... 0.5063871 0.08330307 0.31897077] +video_auc: 0.849247887904389 +label: [1 0 1 ... 1 0 0] +100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1128/1128 [09:41<00:00, 1.94it/s] +dataset: DFDC +acc: 0.7389942337547128 +auc: 0.8153807538695912 +eer: 0.26452712297642716 +ap: 0.8553135531715989 +pred: [0.24556044 0.2040193 0.4257187 ... 0.82186127 0.9962172 0.50927925] +video_auc: 0.8395751948048426 +label: [0 1 0 ... 0 1 0] +===> Test Done! +''' + diff --git a/training/detectors/effort_patch_shuffle.py b/training/detectors/effort_patch_shuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..b1fa8c027b8ed4b7565785e68d3ac4c31b900a93 --- /dev/null +++ b/training/detectors/effort_patch_shuffle.py @@ -0,0 +1,457 @@ +import os +import math +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +import albumentations as A +import loralib as lora +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig + +logger = logging.getLogger(__name__) + + +def get_clip_visual(model_name = "openai/clip-vit-base-patch16"): + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPModel.from_pretrained(model_name) + return processor, model.vision_model + + +def shuffle_patches(images: torch.Tensor, patch_size: int = 14) -> torch.Tensor: + """ + Apply patch-level shuffling to the input images. + images: [B, C, H, W] + patch_size: patch size used by ViT (for example, 16) + Returns: an image tensor with the same shape [B, C, H, W] + """ + B, C, H, W = images.shape + assert H % patch_size == 0 and W % patch_size == 0, \ + f"H ({H}) and W ({W}) must be divisible by patch_size ({patch_size})" + + num_patches_h = H // patch_size + num_patches_w = W // patch_size + num_patches = num_patches_h * num_patches_w + + # [B, C, H, W] -> [B, C, num_patches_h, patch_size, num_patches_w, patch_size] + images = images.view(B, C, num_patches_h, patch_size, num_patches_w, patch_size) + # -> [B, num_patches_h, num_patches_w, C, patch_size, patch_size] + images = images.permute(0, 2, 4, 1, 3, 5).contiguous() + # -> [B, num_patches, C, patch_size, patch_size] + images = images.view(B, num_patches, C, patch_size, patch_size) + + # Shuffle patch order independently for each image. + # permutation shape: [B, num_patches] + perms = torch.stack( + [torch.randperm(num_patches, device=images.device) for _ in range(B)], + dim=0 + ) + # Use advanced indexing to perform the shuffle. + batch_idx = torch.arange(B, device=images.device).unsqueeze(1).expand(B, num_patches) + images = images[batch_idx, perms] # [B, num_patches, C, patch_size, patch_size] + + # Restore the original image shape. + images = images.view(B, num_patches_h, num_patches_w, C, patch_size, patch_size) + # -> [B, C, num_patches_h, patch_size, num_patches_w, patch_size] + images = images.permute(0, 3, 1, 4, 2, 5).contiguous() + # -> [B, C, H, W] + images = images.view(B, C, H, W) + + return images + +def get_aug_transform(): + return A.Compose([ + A.HorizontalFlip(p=0.5), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5), + A.HueSaturationValue(p=0.3), + A.ImageCompression(quality_lower=40, quality_upper=100, p=0.1), + A.GaussNoise(p=0.1), + A.MotionBlur(p=0.1), + A.CLAHE(p=0.1), + A.ChannelShuffle(p=0.1), + A.Cutout(p=0.1), + A.RandomGamma(p=0.3), + A.GlassBlur(p=0.3), + ]) + + +def data_aug(images: torch.Tensor) -> torch.Tensor: + is_gpu = images.is_cuda + aug = get_aug_transform() + + # Step 1: convert the batch tensor to batch numpy arrays (BHWC uint8 0-255). + imgs_np = images.cpu().detach().numpy() + imgs_np = np.transpose(imgs_np, (0, 2, 3, 1)) # BCHW -> BHWC + imgs_np = (imgs_np * 255).astype(np.uint8) + + # Step 2: augment images one by one to avoid KeyError from batch-style arguments. + imgs_aug_np = [] + for img in imgs_np: + # Pass a single image with `image=img`, which is natively supported by Albumentations. + aug_img = aug(image=img)["image"] + imgs_aug_np.append(aug_img) + imgs_aug_np = np.array(imgs_aug_np) # convert back to batch numpy arrays + + # Step 3: convert back to a tensor while preserving the original logic. + aug_tensor = torch.from_numpy(imgs_aug_np).permute(0, 3, 1, 2) + aug_tensor = aug_tensor.float() / 255.0 + + # Restore the original device. + if is_gpu: + aug_tensor = aug_tensor.cuda() + + return aug_tensor + + +@DETECTOR.register_module(module_name='effort_shuffle_ensemble') +class Effort_Shuffle_Ensenble_Detector(nn.Module): + def __init__(self, config=None): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + #self.head1 = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = nn.CrossEntropyLoss() + self.prob, self.label = [], [] + self.correct, self.total = 0, 0 + #self.backbone2=self.build_clip_backbone(config) + + def build_clip_backbone(self,config): + _, backbone = get_clip_visual(model_name=config['pretrained']) + return backbone + + def build_backbone(self, config): + # Download model + # https://huggingface.co/openai/clip-vit-large-patch14 + + # mean: [0.48145466, 0.4578275, 0.40821073] + # std: [0.26862954, 0.26130258, 0.27577711] + + # ViT-L/14 224*224 + clip_model = CLIPModel.from_pretrained(self.config["pretrained"]) + + # Apply SVD to self_attn layers only + # ViT-L/14 224*224: 1024-1 + clip_model.vision_model = apply_svd_residual_to_self_attn(clip_model.vision_model, r=1024-1) + + for name, param in clip_model.vision_model.named_parameters(): + print('{}: {}'.format(name, param.requires_grad)) + num_param = sum(p.numel() for p in clip_model.vision_model.parameters() if p.requires_grad) + num_total_param = sum(p.numel() for p in clip_model.vision_model.parameters()) + print('Number of total parameters: {}, tunable parameters: {}'.format(num_total_param, num_param)) + + return clip_model.vision_model + + def features(self, data_dict: dict) -> torch.tensor: + # data_dict['image']: torch.Size([32, 3, 224, 224]) + if self.training: + #aug_image=data_aug(data_dict['image']) + shuffle_images=shuffle_patches(data_dict['image'],14) + feat = self.backbone(shuffle_images)['pooler_output'] + #feat1=self.backbone2(shuffle_images)['pooler_output'] + else: + feat = self.backbone(data_dict['image'])['pooler_output'] + #feat1=self.backbone2(data_dict['image'])['pooler_output'] + # feat torch.Size([32, 1024]) + return feat#,feat1 + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + # def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + # label = data_dict['label'] + # pred = pred_dict['cls'] + # loss = self.loss_func(pred, label) + + # # Regularization term + # lambda_reg = 0.1 + # orthogonal_losses = [] + # for module in self.backbone.modules(): + # if isinstance(module, SVDResidualLinear): + # # Apply orthogonal constraints to the U_residual and V_residual matrix + # orthogonal_losses.append(module.compute_orthogonal_loss()) + + # if orthogonal_losses: + # reg_term = sum(orthogonal_losses) + # loss += lambda_reg * reg_term + + # loss_dict = {'overall': loss} + # return loss_dict + + def compute_weight_loss(self): + weight_sum_dict = {} + num_weight_dict = {} + for name, module in self.backbone.named_modules(): + if isinstance(module, SVDResidualLinear): + weight_curr = module.compute_current_weight() + if str(weight_curr.size()) not in weight_sum_dict.keys(): + weight_sum_dict[str(weight_curr.size())] = weight_curr + num_weight_dict[str(weight_curr.size())] = 1 + else: + weight_sum_dict[str(weight_curr.size())] += weight_curr + num_weight_dict[str(weight_curr.size())] += 1 + + loss2 = 0.0 + for k in weight_sum_dict.keys(): + _, S_sum, _ = torch.linalg.svd(weight_sum_dict[k], full_matrices=False) + loss2 += -torch.mean(S_sum) + loss2 /= len(weight_sum_dict.keys()) + return loss2 + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] # Tensor of shape [batch_size] + pred = pred_dict['cls'] # Tensor of shape [batch_size, num_classes] + + # Compute overall loss using all samples + loss = self.loss_func(pred, label) + + # Create masks for real and fake classes + mask_real = label == 0 # Boolean tensor + mask_fake = label == 1 # Boolean tensor + + # Compute loss for real class + if mask_real.sum() > 0: + pred_real = pred[mask_real] + label_real = label[mask_real] + loss_real = self.loss_func(pred_real, label_real) + else: + # No real samples in batch + loss_real = torch.tensor(0.0, device=pred.device) + + # Compute loss for fake class + if mask_fake.sum() > 0: + pred_fake = pred[mask_fake] + label_fake = label[mask_fake] + loss_fake = self.loss_func(pred_fake, label_fake) + else: + # No fake samples in batch + loss_fake = torch.tensor(0.0, device=pred.device) + + # loss2 = self.compute_weight_loss() + # overall_loss = loss + loss2 + + # Return a dictionary with all losses + loss_dict = { + 'overall': loss, + 'real_loss': loss_real, + 'fake_loss': loss_fake, + # 'erank_loss': loss2 + } + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features= self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + #features=features+f1 + #pred=pred+pred1 + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + + return pred_dict + +# Custom module to represent the residual using SVD components +class SVDResidualLinear(nn.Module): + def __init__(self, in_features, out_features, r, bias=True, init_weight=None): + super(SVDResidualLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.r = r # Number of top singular values to exclude + + # Original weights (fixed) + self.weight_main = nn.Parameter(torch.Tensor(out_features, in_features), requires_grad=False) + if init_weight is not None: + self.weight_main.data.copy_(init_weight) + else: + nn.init.kaiming_uniform_(self.weight_main, a=math.sqrt(5)) + + # Bias + if bias: + self.bias = nn.Parameter(torch.Tensor(out_features)) + nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + def compute_current_weight(self): + if self.S_residual is not None: + return self.weight_main + self.U_residual @ torch.diag(self.S_residual) @ self.V_residual + else: + return self.weight_main + + def forward(self, x): + if hasattr(self, 'U_residual') and hasattr(self, 'V_residual') and self.S_residual is not None: + # Reconstruct the residual weight + residual_weight = self.U_residual @ torch.diag(self.S_residual) @ self.V_residual + # Total weight is the fixed main weight plus the residual + weight = self.weight_main + residual_weight + else: + # If residual components are not set, use only the main weight + weight = self.weight_main + + return F.linear(x, weight, self.bias) + + def compute_orthogonal_loss(self): + if self.S_residual is not None: + # According to the properties of orthogonal matrices: A^TA = I + UUT = torch.cat((self.U_r, self.U_residual), dim=1) @ torch.cat((self.U_r, self.U_residual), dim=1).t() + VVT = torch.cat((self.V_r, self.V_residual), dim=0) @ torch.cat((self.V_r, self.V_residual), dim=0).t() + # print(self.U_r.size(), self.U_residual.size()) # torch.Size([1024, 1023]) torch.Size([1024, 1]) + # print(self.V_r.size(), self.V_residual.size()) # torch.Size([1023, 1024]) torch.Size([1, 1024]) + # UUT = self.U_residual @ self.U_residual.t() + # VVT = self.V_residual @ self.V_residual.t() + + # Construct an identity matrix + UUT_identity = torch.eye(UUT.size(0), device=UUT.device) + VVT_identity = torch.eye(VVT.size(0), device=VVT.device) + + # Using frobenius norm to compute loss + loss = 0.5 * torch.norm(UUT - UUT_identity, p='fro') + 0.5 * torch.norm(VVT - VVT_identity, p='fro') + else: + loss = 0.0 + + return loss + + def compute_keepsv_loss(self): + if (self.S_residual is not None) and (self.weight_original_fnorm is not None): + # Total current weight is the fixed main weight plus the residual + weight_current = self.weight_main + self.U_residual @ torch.diag(self.S_residual) @ self.V_residual + # Frobenius norm of current weight + weight_current_fnorm = torch.norm(weight_current, p='fro') + + loss = torch.abs(weight_current_fnorm ** 2 - self.weight_original_fnorm ** 2) + # loss = torch.abs(weight_current_fnorm ** 2 + 0.01 * self.weight_main_fnorm ** 2 - 1.01 * self.weight_original_fnorm ** 2) + else: + loss = 0.0 + + return loss + + def compute_fn_loss(self): + if (self.S_residual is not None): + weight_current = self.weight_main + self.U_residual @ torch.diag(self.S_residual) @ self.V_residual + weight_current_fnorm = torch.norm(weight_current, p='fro') + + loss = weight_current_fnorm ** 2 + else: + loss = 0.0 + + return loss + + +# Function to replace nn.Linear modules within self_attn modules with SVDResidualLinear +def apply_svd_residual_to_self_attn(model, r): + for name, module in model.named_children(): + if 'self_attn' in name: + # Replace nn.Linear layers in this module + for sub_name, sub_module in module.named_modules(): + if isinstance(sub_module, nn.Linear): + # Get parent module within self_attn + parent_module = module + sub_module_names = sub_name.split('.') + for module_name in sub_module_names[:-1]: + parent_module = getattr(parent_module, module_name) + # Replace the nn.Linear layer with SVDResidualLinear + setattr(parent_module, sub_module_names[-1], replace_with_svd_residual(sub_module, r)) + else: + # Recursively apply to child modules + apply_svd_residual_to_self_attn(module, r) + # After replacing, set requires_grad for residual components + for param_name, param in model.named_parameters(): + if any(x in param_name for x in ['S_residual', 'U_residual', 'V_residual']): + param.requires_grad = True + else: + param.requires_grad = False + return model + + +# Function to replace a module with SVDResidualLinear +def replace_with_svd_residual(module, r): + if isinstance(module, nn.Linear): + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + + # Create SVDResidualLinear module + new_module = SVDResidualLinear(in_features, out_features, r, bias=bias, init_weight=module.weight.data.clone()) + + if bias and module.bias is not None: + new_module.bias.data.copy_(module.bias.data) + + new_module.weight_original_fnorm = torch.norm(module.weight.data, p='fro') + + # Perform SVD on the original weight + U, S, Vh = torch.linalg.svd(module.weight.data, full_matrices=False) + + # Determine r based on the rank of the weight matrix + r = min(r, len(S)) # Ensure r does not exceed the number of singular values + + # Keep top r singular components (main weight) + U_r = U[:, :r] # Shape: (out_features, r) + S_r = S[:r] # Shape: (r,) + Vh_r = Vh[:r, :] # Shape: (r, in_features) + + # Reconstruct the main weight (fixed) + weight_main = U_r @ torch.diag(S_r) @ Vh_r + + # Calculate the frobenius norm of main weight + new_module.weight_main_fnorm = torch.norm(weight_main.data, p='fro') + + # Set the main weight + new_module.weight_main.data.copy_(weight_main) + + # Residual components (trainable) + U_residual = U[:, r:] # Shape: (out_features, n - r) + S_residual = S[r:] # Shape: (n - r,) + Vh_residual = Vh[r:, :] # Shape: (n - r, in_features) + + if len(S_residual) > 0: + new_module.S_residual = nn.Parameter(S_residual.clone()) + new_module.U_residual = nn.Parameter(U_residual.clone()) + new_module.V_residual = nn.Parameter(Vh_residual.clone()) + + new_module.S_r = nn.Parameter(S_r.clone(), requires_grad=False) + new_module.U_r = nn.Parameter(U_r.clone(), requires_grad=False) + new_module.V_r = nn.Parameter(Vh_r.clone(), requires_grad=False) + else: + new_module.S_residual = None + new_module.U_residual = None + new_module.V_residual = None + + new_module.S_r = None + new_module.U_r = None + new_module.V_r = None + + return new_module + else: + return module diff --git a/training/detectors/effort_vid_detector.py b/training/detectors/effort_vid_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..41a8e4c4a18a8fb1111d127db7969fea52fe4030 --- /dev/null +++ b/training/detectors/effort_vid_detector.py @@ -0,0 +1,485 @@ +import os +import math +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +import loralib as lora +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='effort_vid') +class EffortVidDetector(nn.Module): + def __init__(self, config=None): + super(EffortVidDetector, self).__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, 2) + self.loss_func = nn.CrossEntropyLoss() + self.prob, self.label = [], [] + self.correct, self.total = 0, 0 + + def build_backbone(self, config): + # Download model + # https://huggingface.co/openai/clip-vit-large-patch14 + + # mean: [0.48145466, 0.4578275, 0.40821073] + # std: [0.26862954, 0.26130258, 0.27577711] + + # ViT-L/14 224*224 + clip_model = CLIPModel.from_pretrained(self.config["pretrained"]) + + # Apply SVD to self_attn layers only + # ViT-L/14 224*224: 1024-1 + clip_model.vision_model = apply_svd_residual_to_self_attn(clip_model.vision_model, r=1024-1) + + for name, param in clip_model.vision_model.named_parameters(): + print('{}: {}'.format(name, param.requires_grad)) + num_param = sum(p.numel() for p in clip_model.vision_model.parameters() if p.requires_grad) + num_total_param = sum(p.numel() for p in clip_model.vision_model.parameters()) + print('Number of total parameters: {}, tunable parameters: {}'.format(num_total_param, num_param)) + + return clip_model.vision_model + + def features(self, data_dict: dict) -> torch.tensor: + # data_dict['image']: torch.Size([8, 16, 3, 224, 224]) + B, T, C, H, W = data_dict['image'].shape + feat = self.backbone(data_dict['image'].view(B * T, C, H, W))['pooler_output'] + feat = feat.view(B, T, feat.shape[-1]).mean(1) # Temporal avg + # feat torch.Size([32, 1024]) + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + # def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + # label = data_dict['label'] + # pred = pred_dict['cls'] + # loss = self.loss_func(pred, label) + + # # Regularization term + # lambda_reg = 0.1 + # orthogonal_losses = [] + # for module in self.backbone.modules(): + # if isinstance(module, SVDResidualLinear): + # # Apply orthogonal constraints to the U_residual and V_residual matrix + # orthogonal_losses.append(module.compute_orthogonal_loss()) + + # if orthogonal_losses: + # reg_term = sum(orthogonal_losses) + # loss += lambda_reg * reg_term + + # loss_dict = {'overall': loss} + # return loss_dict + + def compute_weight_loss(self): + weight_sum_dict = {} + num_weight_dict = {} + for name, module in self.backbone.named_modules(): + if isinstance(module, SVDResidualLinear): + weight_curr = module.compute_current_weight() + if str(weight_curr.size()) not in weight_sum_dict.keys(): + weight_sum_dict[str(weight_curr.size())] = weight_curr + num_weight_dict[str(weight_curr.size())] = 1 + else: + weight_sum_dict[str(weight_curr.size())] += weight_curr + num_weight_dict[str(weight_curr.size())] += 1 + + loss2 = 0.0 + for k in weight_sum_dict.keys(): + _, S_sum, _ = torch.linalg.svd(weight_sum_dict[k], full_matrices=False) + loss2 += -torch.mean(S_sum) + loss2 /= len(weight_sum_dict.keys()) + return loss2 + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] # Tensor of shape [batch_size] + pred = pred_dict['cls'] # Tensor of shape [batch_size, num_classes] + + # Compute overall loss using all samples + loss = self.loss_func(pred, label) + + # Create masks for real and fake classes + mask_real = label == 0 # Boolean tensor + mask_fake = label == 1 # Boolean tensor + + # Compute loss for real class + if mask_real.sum() > 0: + pred_real = pred[mask_real] + label_real = label[mask_real] + loss_real = self.loss_func(pred_real, label_real) + else: + # No real samples in batch + loss_real = torch.tensor(0.0, device=pred.device) + + # Compute loss for fake class + if mask_fake.sum() > 0: + pred_fake = pred[mask_fake] + label_fake = label[mask_fake] + loss_fake = self.loss_func(pred_fake, label_fake) + else: + # No fake samples in batch + loss_fake = torch.tensor(0.0, device=pred.device) + + #### With erank_loss loss + # loss2 = self.compute_weight_loss() + # overall_loss = loss + loss2 + # Return a dictionary with all losses + # loss_dict = { + # 'overall': overall_loss, + # 'real_loss': loss_real, + # 'fake_loss': loss_fake, + # 'erank_loss': loss2 + # } + + #### Without erank_loss loss + loss_dict = { + 'overall': loss, + 'real_loss': loss_real, + 'fake_loss': loss_fake, + } + + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + + return pred_dict + + +# Custom module to represent the residual using SVD components +class SVDResidualLinear(nn.Module): + def __init__(self, in_features, out_features, r, bias=True, init_weight=None): + super(SVDResidualLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.r = r # Number of top singular values to exclude + + # Original weights (fixed) + self.weight_main = nn.Parameter(torch.Tensor(out_features, in_features), requires_grad=False) + if init_weight is not None: + self.weight_main.data.copy_(init_weight) + else: + nn.init.kaiming_uniform_(self.weight_main, a=math.sqrt(5)) + + # Bias + if bias: + self.bias = nn.Parameter(torch.Tensor(out_features)) + nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + def compute_current_weight(self): + if self.S_residual is not None: + return self.weight_main + self.U_residual @ torch.diag(self.S_residual) @ self.V_residual + else: + return self.weight_main + + def forward(self, x): + if hasattr(self, 'U_residual') and hasattr(self, 'V_residual') and self.S_residual is not None: + # Reconstruct the residual weight + residual_weight = self.U_residual @ torch.diag(self.S_residual) @ self.V_residual + # Total weight is the fixed main weight plus the residual + weight = self.weight_main + residual_weight + else: + # If residual components are not set, use only the main weight + weight = self.weight_main + + return F.linear(x, weight, self.bias) + + def compute_orthogonal_loss(self): + # According to the properties of orthogonal matrices: A^TA = I + UUT_residual = self.U_residual @ self.U_residual.t() + VVT_residual = self.V_residual @ self.V_residual.t() + + # Construct an identity matrix + UUT_residual_identity = torch.eye(UUT_residual.size(0), device=UUT_residual.device) + VVT_residual_identity = torch.eye(VVT_residual.size(0), device=VVT_residual.device) + + # Frobenius norm + loss = 0.5 * torch.norm(UUT_residual - UUT_residual_identity, p='fro') + 0.5 * torch.norm(VVT_residual - VVT_residual_identity, p='fro') + + return loss + + +# Function to replace nn.Linear modules within self_attn modules with SVDResidualLinear +def apply_svd_residual_to_self_attn(model, r): + for name, module in model.named_children(): + if 'self_attn' in name: + # Replace nn.Linear layers in this module + for sub_name, sub_module in module.named_modules(): + if isinstance(sub_module, nn.Linear): + # Get parent module within self_attn + parent_module = module + sub_module_names = sub_name.split('.') + for module_name in sub_module_names[:-1]: + parent_module = getattr(parent_module, module_name) + # Replace the nn.Linear layer with SVDResidualLinear + setattr(parent_module, sub_module_names[-1], replace_with_svd_residual(sub_module, r)) + else: + # Recursively apply to child modules + apply_svd_residual_to_self_attn(module, r) + # After replacing, set requires_grad for residual components + for param_name, param in model.named_parameters(): + if any(x in param_name for x in ['S_residual', 'U_residual', 'V_residual']): + param.requires_grad = True + else: + param.requires_grad = False + return model + + +# Function to replace a module with SVDResidualLinear +def replace_with_svd_residual(module, r): + if isinstance(module, nn.Linear): + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + + # Create SVDResidualLinear module + new_module = SVDResidualLinear(in_features, out_features, r, bias=bias, init_weight=module.weight.data.clone()) + + if bias and module.bias is not None: + new_module.bias.data.copy_(module.bias.data) + + # Perform SVD on the original weight + U, S, Vh = torch.linalg.svd(module.weight.data, full_matrices=False) + + # Determine r based on the rank of the weight matrix + r = min(r, len(S)) # Ensure r does not exceed the number of singular values + + # Keep top r singular components (main weight) + U_r = U[:, :r] # Shape: (out_features, r) + S_r = S[:r] # Shape: (r,) + Vh_r = Vh[:r, :] # Shape: (r, in_features) + + # Reconstruct the main weight (fixed) + weight_main = U_r @ torch.diag(S_r) @ Vh_r + + # Set the main weight + new_module.weight_main.data.copy_(weight_main) + + # Residual components (trainable) + U_residual = U[:, r:] # Shape: (out_features, n - r) + S_residual = S[r:] # Shape: (n - r,) + Vh_residual = Vh[r:, :] # Shape: (n - r, in_features) + + if len(S_residual) > 0: + # S_residual is trainable + new_module.S_residual = nn.Parameter(S_residual.clone()) + # U_residual and V_residual are also trainable + new_module.U_residual = nn.Parameter(U_residual.clone()) + new_module.V_residual = nn.Parameter(Vh_residual.clone()) + else: + # If no residual components, set placeholders + new_module.S_residual = None + new_module.U_residual = None + new_module.V_residual = None + + return new_module + else: + return module + +''' +Training Params: + +embeddings.class_embedding: False +embeddings.patch_embedding.weight: False +embeddings.position_embedding.weight: False +pre_layrnorm.weight: False +pre_layrnorm.bias: False + +encoder.layers.0.self_attn.k_proj.weight_main: False +encoder.layers.0.self_attn.k_proj.bias: False +encoder.layers.0.self_attn.k_proj.S_residual: True +encoder.layers.0.self_attn.k_proj.U_residual: True +encoder.layers.0.self_attn.k_proj.V_residual: True +encoder.layers.0.self_attn.v_proj.weight_main: False +encoder.layers.0.self_attn.v_proj.bias: False +encoder.layers.0.self_attn.v_proj.S_residual: True +encoder.layers.0.self_attn.v_proj.U_residual: True +encoder.layers.0.self_attn.v_proj.V_residual: True +encoder.layers.0.self_attn.q_proj.weight_main: False +encoder.layers.0.self_attn.q_proj.bias: False +encoder.layers.0.self_attn.q_proj.S_residual: True +encoder.layers.0.self_attn.q_proj.U_residual: True +encoder.layers.0.self_attn.q_proj.V_residual: True +encoder.layers.0.self_attn.out_proj.weight_main: False +encoder.layers.0.self_attn.out_proj.bias: False +encoder.layers.0.self_attn.out_proj.S_residual: True +encoder.layers.0.self_attn.out_proj.U_residual: True +encoder.layers.0.self_attn.out_proj.V_residual: True +encoder.layers.0.layer_norm1.weight: False +encoder.layers.0.layer_norm1.bias: False +encoder.layers.0.mlp.fc1.weight: False +encoder.layers.0.mlp.fc1.bias: False +encoder.layers.0.mlp.fc2.weight: False +encoder.layers.0.mlp.fc2.bias: False +encoder.layers.0.layer_norm2.weight: False +encoder.layers.0.layer_norm2.bias: False + +encoder.layers.1.self_attn.k_proj.weight_main: False +encoder.layers.1.self_attn.k_proj.bias: False +encoder.layers.1.self_attn.k_proj.S_residual: True +encoder.layers.1.self_attn.k_proj.U_residual: True +encoder.layers.1.self_attn.k_proj.V_residual: True +encoder.layers.1.self_attn.v_proj.weight_main: False +encoder.layers.1.self_attn.v_proj.bias: False +encoder.layers.1.self_attn.v_proj.S_residual: True +encoder.layers.1.self_attn.v_proj.U_residual: True +encoder.layers.1.self_attn.v_proj.V_residual: True +encoder.layers.1.self_attn.q_proj.weight_main: False +encoder.layers.1.self_attn.q_proj.bias: False +encoder.layers.1.self_attn.q_proj.S_residual: True +encoder.layers.1.self_attn.q_proj.U_residual: True +encoder.layers.1.self_attn.q_proj.V_residual: True +encoder.layers.1.self_attn.out_proj.weight_main: False +encoder.layers.1.self_attn.out_proj.bias: False +encoder.layers.1.self_attn.out_proj.S_residual: True +encoder.layers.1.self_attn.out_proj.U_residual: True +encoder.layers.1.self_attn.out_proj.V_residual: True +encoder.layers.1.layer_norm1.weight: False +encoder.layers.1.layer_norm1.bias: False +encoder.layers.1.mlp.fc1.weight: False +encoder.layers.1.mlp.fc1.bias: False +encoder.layers.1.mlp.fc2.weight: False +encoder.layers.1.mlp.fc2.bias: False +encoder.layers.1.layer_norm2.weight: False +encoder.layers.1.layer_norm2.bias: False + +encoder.layers.2.self_attn.k_proj.weight_main: False +encoder.layers.2.self_attn.k_proj.bias: False +encoder.layers.2.self_attn.k_proj.S_residual: True +encoder.layers.2.self_attn.k_proj.U_residual: True +encoder.layers.2.self_attn.k_proj.V_residual: True +encoder.layers.2.self_attn.v_proj.weight_main: False +encoder.layers.2.self_attn.v_proj.bias: False +encoder.layers.2.self_attn.v_proj.S_residual: True +encoder.layers.2.self_attn.v_proj.U_residual: True +encoder.layers.2.self_attn.v_proj.V_residual: True +encoder.layers.2.self_attn.q_proj.weight_main: False +encoder.layers.2.self_attn.q_proj.bias: False +encoder.layers.2.self_attn.q_proj.S_residual: True +encoder.layers.2.self_attn.q_proj.U_residual: True +encoder.layers.2.self_attn.q_proj.V_residual: True +encoder.layers.2.self_attn.out_proj.weight_main: False +encoder.layers.2.self_attn.out_proj.bias: False +encoder.layers.2.self_attn.out_proj.S_residual: True +encoder.layers.2.self_attn.out_proj.U_residual: True +encoder.layers.2.self_attn.out_proj.V_residual: True +encoder.layers.2.layer_norm1.weight: False +encoder.layers.2.layer_norm1.bias: False +encoder.layers.2.mlp.fc1.weight: False +encoder.layers.2.mlp.fc1.bias: False +encoder.layers.2.mlp.fc2.weight: False +encoder.layers.2.mlp.fc2.bias: False +encoder.layers.2.layer_norm2.weight: False +encoder.layers.2.layer_norm2.bias: False +... +encoder.layers.23.self_attn.k_proj.weight_main: False +encoder.layers.23.self_attn.k_proj.bias: False +encoder.layers.23.self_attn.k_proj.S_residual: True +encoder.layers.23.self_attn.k_proj.U_residual: True +encoder.layers.23.self_attn.k_proj.V_residual: True +encoder.layers.23.self_attn.v_proj.weight_main: False +encoder.layers.23.self_attn.v_proj.bias: False +encoder.layers.23.self_attn.v_proj.S_residual: True +encoder.layers.23.self_attn.v_proj.U_residual: True +encoder.layers.23.self_attn.v_proj.V_residual: True +encoder.layers.23.self_attn.q_proj.weight_main: False +encoder.layers.23.self_attn.q_proj.bias: False +encoder.layers.23.self_attn.q_proj.S_residual: True +encoder.layers.23.self_attn.q_proj.U_residual: True +encoder.layers.23.self_attn.q_proj.V_residual: True +encoder.layers.23.self_attn.out_proj.weight_main: False +encoder.layers.23.self_attn.out_proj.bias: False +encoder.layers.23.self_attn.out_proj.S_residual: True +encoder.layers.23.self_attn.out_proj.U_residual: True +encoder.layers.23.self_attn.out_proj.V_residual: True +encoder.layers.23.layer_norm1.weight: False +encoder.layers.23.layer_norm1.bias: False +encoder.layers.23.mlp.fc1.weight: False +encoder.layers.23.mlp.fc1.bias: False +encoder.layers.23.mlp.fc2.weight: False +encoder.layers.23.mlp.fc2.bias: False +encoder.layers.23.layer_norm2.weight: False +encoder.layers.23.layer_norm2.bias: False + +post_layernorm.weight: False +post_layernorm.bias: False +Number of total parameters: 303376480, tunable parameters: 196704 + + +===> Load checkpoint done! +100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 130/130 [01:07<00:00, 1.92it/s] +dataset: Celeb-DF-v2 +acc: 0.7873882580333413 +auc: 0.8674386218546616 +eer: 0.21000704721634955 +ap: 0.9322288761515111 +pred: [0.9752515 0.6580601 0.75344455 ... 0.45359948 0.8914075 0.14674814] +video_auc: 0.9105750165234634 +label: [1 1 0 ... 1 1 0] +100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 851/851 [07:17<00:00, 1.94it/s] +dataset: DeepFakeDetection +acc: 0.8606078424166698 +auc: 0.9048725171315315 +eer: 0.16390041493775934 +ap: 0.9883843861944681 +pred: [0.9912942 0.4690933 0.99789536 ... 0.8104649 0.9893 0.78386295] +video_auc: 0.9373875743738758 +label: [1 1 1 ... 1 1 1] +100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [01:21<00:00, 1.94it/s] +dataset: DFDCP +acc: 0.7002589125672177 +auc: 0.8182703419711848 +eer: 0.28125 +ap: 0.9055071587990912 +pred: [0.38877106 0.606897 0.5618232 ... 0.5063871 0.08330307 0.31897077] +video_auc: 0.849247887904389 +label: [1 0 1 ... 1 0 0] +100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1128/1128 [09:41<00:00, 1.94it/s] +dataset: DFDC +acc: 0.7389942337547128 +auc: 0.8153807538695912 +eer: 0.26452712297642716 +ap: 0.8553135531715989 +pred: [0.24556044 0.2040193 0.4257187 ... 0.82186127 0.9962172 0.50927925] +video_auc: 0.8395751948048426 +label: [0 1 0 ... 0 1 0] +===> Test Done! +''' + diff --git a/training/detectors/f3net_detector.py b/training/detectors/f3net_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..1b87f9e17bb3a9d2417ee740860ce7b3eb6355f0 --- /dev/null +++ b/training/detectors/f3net_detector.py @@ -0,0 +1,221 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the F3netDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{qian2020thinking, + title={Thinking in frequency: Face forgery detection by mining frequency-aware clues}, + author={Qian, Yuyang and Yin, Guojun and Sheng, Lu and Chen, Zixuan and Shao, Jing}, + booktitle={European conference on computer vision}, + pages={86--103}, + year={2020}, + organization={Springer} +} + +GitHub Reference: +https://github.com/yyk-wew/F3Net + +Notes: +We replicate the results by solely utilizing the FAD branch, following the reference GitHub implementation (https://github.com/yyk-wew/F3Net). +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train +from metrics.base_metrics_class import calculate_acc_for_train +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='f3net') +class F3netDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + # modules only use in FAD + img_size = config['resolution'] + self.FAD_head = FAD_Head(img_size) + self.fc = nn.Linear(in_features=2048, out_features=1024) + # Step 2: global pooling (compress 18x18 -> 1x1) + self.gap = nn.AdaptiveAvgPool2d(1) + # Step 3: map to 1024 dimensions with a fully connected layer + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + + # To get a good performance, use the ImageNet-pretrained Xception model + state_dict = torch.load(config['pretrained']) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + conv1_data = state_dict['conv1.weight'].data + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model from {}'.format(config['pretrained'])) + + # copy on conv1 + # let new conv1 use old param to balance the network + backbone.conv1 = nn.Conv2d(12, 32, 3, 2, 0, bias=False) + for i in range(4): + backbone.conv1.weight.data[:, i*3:(i+1)*3, :, :] = conv1_data / 4.0 + logger.info('Copy conv1 from pretrained model') + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + fea_FAD = self.FAD_head(data_dict['image']) # [B, 12, 256, 256] + return self.backbone.features(fea_FAD) + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + #muti-classification + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + #binary + #auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + #metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + features=self.gap(features) + features = torch.flatten(features, start_dim=1) + # features=self.fc(features) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + + return pred_dict + + +# ===================================== other modules for F3Net # ===================================== + + +# Filter Module +class Filter(nn.Module): + def __init__(self, size, band_start, band_end, use_learnable=True, norm=False): + super(Filter, self).__init__() + self.use_learnable = use_learnable + + self.base = nn.Parameter(torch.tensor(generate_filter(band_start, band_end, size)), requires_grad=False) + if self.use_learnable: + self.learnable = nn.Parameter(torch.randn(size, size), requires_grad=True) + self.learnable.data.normal_(0., 0.1) + + self.norm = norm + if norm: + self.ft_num = nn.Parameter(torch.sum(torch.tensor(generate_filter(band_start, band_end, size))), requires_grad=False) + + + def forward(self, x): + if self.use_learnable: + filt = self.base + norm_sigma(self.learnable) + else: + filt = self.base + + if self.norm: + y = x * filt / self.ft_num + else: + y = x * filt + return y + + +# FAD Module +class FAD_Head(nn.Module): + def __init__(self, size): + super(FAD_Head, self).__init__() + + # init DCT matrix + self._DCT_all = nn.Parameter(torch.tensor(DCT_mat(size)).float(), requires_grad=False) + self._DCT_all_T = nn.Parameter(torch.transpose(torch.tensor(DCT_mat(size)).float(), 0, 1), requires_grad=False) + + # define base filters and learnable + # 0 - 1/16 || 1/16 - 1/8 || 1/8 - 1 + low_filter = Filter(size, 0, size // 2.82) + middle_filter = Filter(size, size // 2.82, size // 2) + high_filter = Filter(size, size // 2, size * 2) + all_filter = Filter(size, 0, size * 2) + + self.filters = nn.ModuleList([low_filter, middle_filter, high_filter, all_filter]) + + def forward(self, x): + # DCT + x_freq = self._DCT_all @ x @ self._DCT_all_T # [N, 3, 299, 299] + + # 4 kernel + y_list = [] + for i in range(4): + x_pass = self.filters[i](x_freq) # [N, 3, 299, 299] + y = self._DCT_all_T @ x_pass @ self._DCT_all # [N, 3, 299, 299] + y_list.append(y) + out = torch.cat(y_list, dim=1) # [N, 12, 299, 299] + return out + +# utils +def DCT_mat(size): + m = [[ (np.sqrt(1./size) if i == 0 else np.sqrt(2./size)) * np.cos((j + 0.5) * np.pi * i / size) for j in range(size)] for i in range(size)] + return m + +def generate_filter(start, end, size): + return [[0. if i + j > end or i + j < start else 1. for j in range(size)] for i in range(size)] + +def norm_sigma(x): + return 2. * torch.sigmoid(x) - 1. + + diff --git a/training/detectors/facexray_detector.py b/training/detectors/facexray_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..b8407fa5c00207b4cf1ae593d94f8553ade62d0c --- /dev/null +++ b/training/detectors/facexray_detector.py @@ -0,0 +1,156 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the UCFDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{li2020face, + title={Face x-ray for more general face forgery detection}, + author={Li, Lingzhi and Bao, Jianmin and Zhang, Ting and Yang, Hao and Chen, Dong and Wen, Fang and Guo, Baining}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, + pages={5001--5010}, + year={2020} +} + +Notes: +To implement Face X-ray, we utilize the pretrained hrnetv2_w48 as the backbone. Despite our efforts to experiment with alternative backbones, we were unable to attain comparable results with other detectors. +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from networks.cls_hrnet import get_cls_net +import yaml + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='facexray') +class FaceXrayDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + # build model + self.backbone = self.build_backbone(config) + self.post_process = nn.Sequential( + nn.Conv2d(in_channels=720, out_channels=256, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(128), + nn.ReLU(), + nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0), + nn.Upsample(size=(256, 256), mode='bilinear', align_corners=True), + ) + self.fc = nn.Sequential( + nn.Linear(128*128, 1024), + nn.BatchNorm1d(1024), + nn.ReLU(), + nn.Linear(1024, 128), + nn.BatchNorm1d(128), + nn.ReLU(), + nn.Linear(128, 2), + ) + + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + cfg_path = './training/config/backbone/cls_hrnet_w48.yaml' + # parse options and load config + with open(cfg_path, 'r') as f: + cfg_config = yaml.safe_load(f) + convnet = get_cls_net(cfg_config) + saved = torch.load('./training/pretrained/hrnetv2_w48_imagenet_pretrained.pth', map_location='cpu') + convnet.load_state_dict(saved, False) + print('Load HRnet') + return convnet + + def build_loss(self, config): + cls_loss_class = LOSSFUNC[config['loss_func']['cls_loss']] + mask_loss_class = LOSSFUNC[config['loss_func']['mask_loss']] + cls_loss_func = cls_loss_class() + mask_loss_func = mask_loss_class() + loss_func = {'cls': cls_loss_func, 'mask': mask_loss_func} + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + return self.backbone.features(data_dict['image']) + + def classifier(self, features: list) -> torch.tensor: + # mask + mask = self.post_process(features) + # feat + feat = F.adaptive_avg_pool2d(mask, 128).view(mask.size(0), -1) + # cls + score = self.fc(feat) + return feat, score, mask + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + # label + label = data_dict['label'] + mask_gt = data_dict['mask'] if data_dict['mask'] is not None else None + # pred + pred_cls = pred_dict['cls'] + pred_mask = pred_dict['mask_pred'] if data_dict['mask'] is not None else None + # loss + loss_cls = self.loss_func['cls'](pred_cls, label) + if data_dict['mask'] is not None: + # Move tensors to the same device + mask_gt = mask_gt.to(pred_mask.device) + loss_mask = F.mse_loss(pred_mask.squeeze().float(), mask_gt.squeeze().float()) + # follow the original paper, + # FIXME: we set λ = 1000 to force the network focusing more on learning the face X-ray prediction + loss = loss_cls + 1000. * loss_mask + loss_dict = {'overall': loss, 'mask': loss_mask, 'cls': loss_cls} + else: # mask_gt is none (during the testing or inference) + loss = loss_cls + loss_dict = {'overall': loss, 'cls': loss_cls} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + features = self.features(data_dict) + features, pred, mask_pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features, 'mask_pred': mask_pred} + + return pred_dict + diff --git a/training/detectors/ffd_detector.py b/training/detectors/ffd_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..25f08d6c394a154e003d26309f05b2d9602991bd --- /dev/null +++ b/training/detectors/ffd_detector.py @@ -0,0 +1,221 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the FFDDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{dang2020detection, + title={On the detection of digital face manipulation}, + author={Dang, Hao and Liu, Feng and Stehouwer, Joel and Liu, Xiaoming and Jain, Anil K}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern recognition}, + pages={5781--5790}, + year={2020} +} + +GitHub Reference: +https://github.com/JStehouwer/FFD_CVPR2020 +''' + +import os +import datetime +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter +from imageio import imread +from torchvision import transforms +from metrics.base_metrics_class import calculate_metrics_for_train +from networks.xception import Block, SeparableConv2d +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +import logging +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='ffd') +class FFDDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + # model + templates = get_templates() + maptype = config['maptype'] + if maptype == 'none': + self.map = [1, None] + elif maptype == 'reg': + self.map = RegressionMap(728) + elif maptype == 'tmp': + self.map = TemplateMap(728, templates) + elif maptype == 'pca_tmp': + self.map = PCATemplateMap(templates) + else: + print('Unknown map type: `{0}`'.format(maptype)) + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + # if donot load the pretrained weights, fail to get good results + state_dict = torch.load(config['pretrained']) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model successfully!') + return backbone + + def build_loss(self, config): + # prepare the loss function + cls_loss_class = LOSSFUNC[config['loss_func']['cls_loss']] + mask_loss_class = LOSSFUNC[config['loss_func']['mask_loss']] + cls_loss_func = cls_loss_class() + mask_loss_func = mask_loss_class() + loss_func = {'cls': cls_loss_func, 'mask': mask_loss_func} + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + # Pass the input through the Xception backbone + x = self.backbone.fea_part1(data_dict['image']) + x = self.backbone.fea_part2(x) + x = self.backbone.fea_part3(x) # This ends at block7 in the backbone + mask, vec = self.map(x) # Compute the mask here + x = x * mask # Apply the mask + x = self.backbone.fea_part4(x) # Continue with the rest of the backbone + x = self.backbone.fea_part5(x) + return x, mask, vec + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + # label + label = data_dict['label'] + mask_gt = data_dict['mask'] if data_dict['mask'] is not None else None + # pred + pred_cls = pred_dict['cls'] + pred_mask = pred_dict['mask_pred'] if data_dict['mask'] is not None else None + # loss + loss_cls = self.loss_func['cls'](pred_cls, label) + if data_dict['mask'] is not None: + # Move tensors to the same device + mask_gt = mask_gt.to(pred_mask.device) + loss_mask = self.loss_func['mask'](pred_mask, mask_gt) + # follow the original paper, + loss = loss_cls + loss_mask + loss_dict = {'overall': loss, 'mask': loss_mask, 'cls': loss_cls} + else: # mask_gt is none (during the testing or inference) + loss = loss_cls + loss_dict = {'overall': loss, 'cls': loss_cls} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features, mask, vec = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features, 'mask': mask, 'vec': vec} + + return pred_dict + +class RegressionMap(nn.Module): + def __init__(self, c_in): + super(RegressionMap, self).__init__() + self.c = SeparableConv2d(c_in, 1, 3, stride=1, padding=1, bias=False) + self.s = nn.Sigmoid() + + def forward(self, x): + mask = self.c(x) + mask = self.s(mask) + return mask, None + +class TemplateMap(nn.Module): + def __init__(self, c_in, templates): + super(TemplateMap, self).__init__() + self.c = Block(c_in, 364, 2, 2, start_with_relu=True, grow_first=False) + self.l = nn.Linear(364, 10) + self.relu = nn.ReLU(inplace=True) + + self.templates = templates + + def forward(self, x): + v = self.c(x) + v = self.relu(v) + v = F.adaptive_avg_pool2d(v, (1,1)) + v = v.view(v.size(0), -1) + v = self.l(v) + mask = torch.mm(v, self.templates.reshape(10,361)) + mask = mask.reshape(x.shape[0], 1, 19, 19) + + return mask, v + +class PCATemplateMap(nn.Module): + def __init__(self, templates): + super(PCATemplateMap, self).__init__() + self.templates = templates + + def forward(self, x): + fe = x.view(x.shape[0], x.shape[1], x.shape[2]*x.shape[3]) + fe = torch.transpose(fe, 1, 2) + mu = torch.mean(fe, 2, keepdim=True) + fea_diff = fe - mu + + cov_fea = torch.bmm(fea_diff, torch.transpose(fea_diff, 1, 2)) + B = self.templates.reshape(1, 10, 361).repeat(x.shape[0], 1, 1) + D = torch.bmm(torch.bmm(B, cov_fea), torch.transpose(B, 1, 2)) + eigen_value, eigen_vector = D.symeig(eigenvectors=True) + index = torch.tensor([9]).cuda() + eigen = torch.index_select(eigen_vector, 2, index) + + v = eigen.squeeze(-1) + mask = torch.mm(v, self.templates.reshape(10, 361)) + mask = mask.reshape(x.shape[0], 1, 19, 19) + return mask, v + +def get_templates(): + templates_list = [] + for i in range(10): + img = imread('./training/lib/component/MCT/template{:d}.png'.format(i)) + templates_list.append(transforms.functional.to_tensor(img)[0:1,0:19,0:19]) + if torch.cuda.is_available(): + templates = torch.stack(templates_list).cuda() + else: + templates = torch.stack(templates_list) + templates = templates.squeeze(1) + return templates diff --git a/training/detectors/ftcn_detector.py b/training/detectors/ftcn_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..19a3bd2b236214f1aa43b4dfc362e8f65558f85d --- /dev/null +++ b/training/detectors/ftcn_detector.py @@ -0,0 +1,482 @@ +config_text = """ +TRAIN: + ENABLE: True + DATASET: kinetics + BATCH_SIZE: 64 + EVAL_PERIOD: 10 + CHECKPOINT_PERIOD: 1 + AUTO_RESUME: True +DATA: + NUM_FRAMES: 16 + SAMPLING_RATE: 8 + TRAIN_JITTER_SCALES: [256, 320] + TRAIN_CROP_SIZE: 224 + TEST_CROP_SIZE: 256 + INPUT_CHANNEL_NUM: [3] +RESNET: + ZERO_INIT_FINAL_BN: True + WIDTH_PER_GROUP: 64 + NUM_GROUPS: 1 + DEPTH: 50 + TRANS_FUNC: bottleneck_transform + STRIDE_1X1: False + NUM_BLOCK_TEMP_KERNEL: [[3], [4], [6], [3]] +NONLOCAL: + LOCATION: [[[]], [[]], [[]], [[]]] + GROUP: [[1], [1], [1], [1]] + INSTANTIATION: softmax +BN: + USE_PRECISE_STATS: True + NUM_BATCHES_PRECISE: 200 +SOLVER: + BASE_LR: 0.1 + LR_POLICY: cosine + MAX_EPOCH: 196 + MOMENTUM: 0.9 + WEIGHT_DECAY: 1e-4 + WARMUP_EPOCHS: 34.0 + WARMUP_START_LR: 0.01 + OPTIMIZING_METHOD: sgd +MODEL: + NUM_CLASSES: 1 + ARCH: i3d + MODEL_NAME: ResNet + LOSS_FUNC: cross_entropy + DROPOUT_RATE: 0.5 + HEAD_ACT: sigmoid +TEST: + ENABLE: True + DATASET: kinetics + BATCH_SIZE: 64 +DATA_LOADER: + NUM_WORKERS: 8 + PIN_MEMORY: True +NUM_GPUS: 8 +NUM_SHARDS: 1 +RNG_SEED: 0 +OUTPUT_DIR: . +""" + + +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the XceptionDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{rossler2019faceforensics++, + title={Faceforensics++: Learning to detect manipulated facial images}, + author={Rossler, Andreas and Cozzolino, Davide and Verdoliva, Luisa and Riess, Christian and Thies, Justus and Nie{\ss}ner, Matthias}, + booktitle={Proceedings of the IEEE/CVF international conference on computer vision}, + pages={1--11}, + year={2019} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + + +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + +import torch +from .utils.slowfast.models.video_model_builder import ResNet as ResNetOri +from .utils.slowfast.config.defaults import get_cfg +from torch import nn +from inspect import signature +from networks.time_transformer import TimeTransformer +import random + + +random_select = True +no_time_pool = True + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='ftcn') +class FTCNDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + cfg = get_cfg() + cfg.merge_from_str(config_text) + cfg.NUM_GPUS = 1 + cfg.TEST.BATCH_SIZE = 1 + cfg.TRAIN.BATCH_SIZE = 1 + cfg.DATA.NUM_FRAMES = config['clip_size'] + self.resnet = ResNetOri(cfg) + if config['pretrained'] is not None: + print(f"loading pretrained model from {config['pretrained']}") + pretrained_weights = torch.load(config['pretrained'], map_location='cpu', encoding='latin1') + modified_weights = {k.replace("resnet.", ""): v for k, v in pretrained_weights.items()} + # fit from 400 num_classes to 1 + modified_weights["head.projection.weight"] = modified_weights["head.projection.weight"][:1, :] + modified_weights["head.projection.bias"] = modified_weights["head.projection.bias"][:1] + # load final ckpt + self.resnet.load_state_dict(modified_weights, strict=True) + + + temporal_only_conv(self.resnet, "model", 0) + + stop_point = 5 + for i in [5, 4, 3]: + if stop_point <= i: + setattr(self.resnet, f"s{i}", nn.Identity()) + if stop_point==3: + setattr(self.resnet, f"pathway0_pool", nn.Identity()) + + params = { + 6: dict(spatial_size=7, time_size=config['clip_size'], in_channels=2048), + 5: dict(spatial_size=14, time_size=config['clip_size'], in_channels=1024), + 4: dict(spatial_size=28, time_size=config['clip_size'], in_channels=512), + 3: dict(spatial_size=56, time_size=config['clip_size']*2, in_channels=256), + }[stop_point] + + self.resnet.head = TransformerHead(**params) + + self.loss_func = nn.BCELoss() # The output of the model is a probability value between 0 and 1 (haved used sigmoid) + + def build_backbone(self, config): + pass + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + inputs = [data_dict['image'].permute(0,2,1,3,4)] + pred, video_level_features = self.resnet(inputs) + output = {} + output["final_output"] = pred + return output["final_output"], video_level_features + + def classifier(self, features: torch.tensor): + pass + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'].float() + pred = pred_dict['cls'].view(-1) + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features and probability + prob, features = self.features(data_dict) + # build the prediction dict for each output + pred_dict = {'cls': prob, 'prob': prob, 'feat': features} + return pred_dict + + + + +class RandomPatchPool(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + # batch,channel,16,7x7 + b, c, t, h, w = x.shape + x = x.reshape(b, c, t, h * w) + if self.training and random_select: + while True: + idx = random.randint(0, h * w - 1) + i = idx // h + j = idx % h + if j == 0 or i == h - 1 or j == h - 1: + continue + else: + break + else: + idx = h * w // 2 + x = x[..., idx] + return x + + +def valid_idx(idx, h): + i = idx // h + j = idx % h + if j == 0 or i == h - 1 or j == h - 1: + return False + else: + return True + + +class RandomAvgPool(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + # batch,channel,16,7x7 + b, c, t, h, w = x.shape + x = x.reshape(b, c, t, h * w) + candidates = list(range(h * w)) + candidates = [idx for idx in candidates if valid_idx(idx, h)] + max_k = len(candidates) + if self.training and random_select: + k = 8 + else: + k = max_k + candidates = random.sample(candidates, k) + x = x[..., candidates].mean(-1) + return x + + +class TransformerHead(nn.Module): + def __init__(self, spatial_size=7, time_size=16, in_channels=2048): + super().__init__() + # if no_time_pool: + # time_size = time_size * 2 + patch_type = "time" + if patch_type == "time": + self.pool = nn.AvgPool3d((1, spatial_size, spatial_size)) + self.num_patches = time_size + elif patch_type == "spatial": + self.pool = nn.AvgPool3d((time_size, 1, 1)) + self.num_patches = spatial_size ** 2 + elif patch_type == "random": + self.pool = RandomPatchPool() + self.num_patches = time_size + elif patch_type == "random_avg": + self.pool = RandomAvgPool() + self.num_patches = time_size + elif patch_type == "all": + self.pool = nn.Identity() + self.num_patches = time_size * spatial_size * spatial_size + else: + raise NotImplementedError(patch_type) + + self.dim = -1 + if self.dim == -1: + self.dim = in_channels + + self.in_channels = in_channels + + if self.dim != self.in_channels: + self.fc = nn.Linear(self.in_channels, self.dim) + + default_params = dict( + dim=self.dim, depth=1, heads=16, mlp_dim=2048, dropout=0.1, emb_dropout=0.1, + ) + params = dict( + patch_type="time", + stop_point=5, + random_select=True, + k=8, + sigmoid_before=False, + ) + for key in default_params: + if key in params: + default_params[key] = params[key] + print(default_params) + self.time_T = TimeTransformer( + num_patches=self.num_patches, num_classes=1, **default_params + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + feat = self.pool(x[0]) + x = feat.reshape(-1, self.in_channels, self.num_patches) + x = x.permute(0, 2, 1) + if self.dim != self.in_channels: + x = self.fc(x.reshape(-1, self.in_channels)) + x = x.reshape(-1, self.num_patches, self.dim) + x = self.time_T(x) + x = self.sigmoid(x) + return x, feat + + +parameters = [parameter for parameter in signature(nn.Conv3d).parameters] +print(parameters) + +spatial_count = 0 +keep_stride_count = 0 +print(f"spatial_count={spatial_count} keep_stride_count={keep_stride_count}") + + +def temporal_only_conv(module, name, removed, stride_removed=0): + """ + Recursively put desired batch norm in nn.module module. + + set module = net to start code. + """ + # go through all attributes of module nn.module (e.g. network or layer) and put batch norms if present + for attr_str in dir(module): + sub_module = getattr(module, attr_str) + if type(sub_module) == nn.Conv3d: + target_spatial_size = 1 + predefine_padding = {1: 0, 3: 1, 5: 2, 7: 3} + kernel_size = list(sub_module.kernel_size) + assert kernel_size[1] == kernel_size[2] + stride = sub_module.stride + extra = None + if stride[1] == stride[2] == 2: + stride_removed += 1 + if stride_removed > keep_stride_count: + stride = [1, 1, 1] + extra = nn.MaxPool3d((1, 2, 2)) + else: + print(f"stride {stride_removed} keeped") + + if kernel_size[1] == 1 and extra is None: + continue + padding = list(sub_module.padding) + + kernel_size[1] = kernel_size[2] = target_spatial_size + padding[1] = padding[2] = predefine_padding[target_spatial_size] + + # param_dict = {key: getattr(sub_module, key) for key in parameters } + param_dict = {key: getattr(sub_module, key) for key in parameters if key not in ['device', 'dtype']} + + param_dict.update(kernel_size=kernel_size, padding=padding, stride=stride) + + conv = nn.Conv3d(**param_dict) + + new_module = conv + + removed += 1 + if removed > spatial_count: + print( + f"{removed} replace {name}.{attr_str}: {str(sub_module)} with {str(new_module)}" + ) + setattr(module, attr_str, new_module) + if extra is not None: + if attr_str == "conv": + bn_str = "bn" + else: + bn_str = f"{attr_str}_bn" + if hasattr(module, bn_str): + bn_module = getattr(module, bn_str) + assert isinstance(bn_module, nn.BatchNorm3d) + new_bn_module = nn.Sequential(bn_module, extra) + setattr(module, bn_str, new_bn_module) + print(f"stride {stride_removed} replace {name}.{bn_str}: {str(new_bn_module)}") + else: + print(f"Attribute {bn_str} not found in {name}") + else: + print("keep spatial") + elif type(sub_module) == nn.Dropout: + new_module = nn.Dropout(p=0.5) + # print(f"replace {name}.{attr_str}: {str(sub_module)} with {str(new_module)}") + setattr(module, attr_str, new_module) + if no_time_pool: + if type(sub_module) == nn.MaxPool3d: + kernel_size = list(sub_module.kernel_size) + if kernel_size[0] == 2: + kernel_size[0] = 1 + setattr(module, attr_str, nn.MaxPool3d(kernel_size)) + elif type(sub_module) == nn.AvgPool3d: + kernel_size = list(sub_module.kernel_size) + kernel_size[0] = 2 * kernel_size[0] + setattr(module, attr_str, nn.AvgPool3d(kernel_size)) + + # iterate through immediate child modules. Note, the recursion is done by our code no need to use named_modules() + old_name = name + for name, immediate_child_module in module.named_children(): + removed, stride_removed = temporal_only_conv( + immediate_child_module, old_name + "." + name, removed, stride_removed + ) + return removed, stride_removed + + +# class I3D8x8(nn.Module): +# def __init__(self, pretrained_path=None) -> None: +# super(I3D8x8, self).__init__() +# cfg = get_cfg() +# cfg.merge_from_str(config_text) +# cfg.NUM_GPUS = 1 +# cfg.TEST.BATCH_SIZE = 1 +# cfg.TRAIN.BATCH_SIZE = 1 +# cfg.DATA.NUM_FRAMES = 16 +# self.resnet = ResNetOri(cfg) +# temporal_only_conv(self.resnet, "model", 0) + +# stop_point = 5 + +# for i in [5, 4, 3]: +# if stop_point <= i: +# setattr(self.resnet, f"s{i}", nn.Identity()) +# if stop_point==3: +# setattr(self.resnet, f"pathway0_pool", nn.Identity()) + +# params = { +# 6: dict(spatial_size=7, time_size=16, in_channels=2048), +# 5: dict(spatial_size=14, time_size=16, in_channels=1024), +# 4: dict(spatial_size=28, time_size=16, in_channels=512), +# 3: dict(spatial_size=56, time_size=32, in_channels=256), +# }[stop_point] + +# self.resnet.head = TransformerHead(**params) + +# if pretrained_path is not None: +# print(f"loading pretrained model from {pretrained_path}") +# pretrained_weights = torch.load(pretrained_path) +# modified_weights = {k.replace("resnet.", ""): v for k, v in pretrained_weights.items()} +# self.resnet.load_state_dict(modified_weights, strict=True) + +# def forward( +# self, +# images, +# noise=None, +# has_mask=None, +# freeze_backbone=False, +# return_feature_maps=False, +# ): +# assert not freeze_backbone +# inputs = [images] +# pred = self.resnet(inputs) +# output = {} +# output["final_output"] = pred +# return output + + +# if __name__ == '__main__': +# model = I3D8x8() +# inp = torch.randn(1, 3, 16, 224, 224) +# out = model(inp) diff --git a/training/detectors/fwa_detector.py b/training/detectors/fwa_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..e54226e29f60970cbb09e0abc7f262605432d0ac --- /dev/null +++ b/training/detectors/fwa_detector.py @@ -0,0 +1,115 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the FWADetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@article{li2018exposing, + title={Exposing deepfake videos by detecting face warping artifacts}, + author={Li, Yuezun and Lyu, Siwei}, + journal={arXiv preprint arXiv:1811.00656}, + year={2018} +} + +This code is modified from the official implementation repository: +https://github.com/yuezunli/CVPRW2019_Face_Artifacts +''' + +import os +import logging +import datetime +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='fwa') +class FWADetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + # if donot load the pretrained weights, fail to get good results + state_dict = torch.load(config['pretrained']) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model successfully!') + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + return self.backbone.features(data_dict['image']) + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict diff --git a/training/detectors/ganatt_detector.py b/training/detectors/ganatt_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..51af3c55b730fd02e7aeb837f70320cd2e1d88ad --- /dev/null +++ b/training/detectors/ganatt_detector.py @@ -0,0 +1,191 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# The script is adopted from Ke Sun (sunk@mail.ustc.edu.cn) +# ------------------------------------------------------------------------------ + + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict +import numbers, math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +import copy +import torch.nn.init as init + +logger = logging.getLogger(__name__) + + +class FeatureExtractor(nn.Module): + """ + Feature extraction network: progressively downsample a 224×224×3 input to the specified number of channels and size, then apply adaptive pooling + Downsampling pipeline: + 224×224×3 → 224×224×16 → 122×122×32 → 61×61×64 → 30×30×128 → 15×15×256 → 7×7×512 → AdaptivePooling + """ + def __init__(self, adaptive_pool_output_size=(1, 1)): + """ + Parameter description: + adaptive_pool_output_size: Output size of adaptive pooling, defaulting to (1,1) (global pooling), and can be set to (7,7), (4,4), etc. + """ + super(FeatureExtractor, self).__init__() + + # 1. 224×224×3 → 224×224×16 (No downsampling, only channel expansion) + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1), # padding=1 keeps the spatial size unchanged + nn.BatchNorm2d(16), # Batch normalization improves training stability + nn.ReLU(inplace=True) # activation function + ) + + # 2. 224×224×16 → 122×122×32 (downsampling: stride=2, kernel=5, padding=1) + self.conv2 = nn.Sequential( + nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True) + ) + + # 3. 122×122×32 → 61×61×64 (downsampling: stride=2, kernel=3, padding=1) + self.conv3 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True) + ) + + # 4. 61×61×64 → 30×30×128 (downsampling: stride=2, kernel=3, padding=0) + self.conv4 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=0), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True) + ) + + # 5. 30×30×128 → 15×15×256 (downsampling: stride=2, kernel=2, padding=0) + self.conv5 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=256, kernel_size=2, stride=2, padding=0), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True) + ) + + # 6. 15×15×256 → 7×7×512 (downsampling: stride=2, kernel=3, padding=0) + self.conv6 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=0), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True) + ) + + # Adaptive pooling layer (output size can be set flexibly) + self.adaptive_pool = nn.AdaptiveAvgPool2d(adaptive_pool_output_size) + + # Initialize parameters of all layers + self._initialize_weights() + + def forward(self, x): + """Forward pass that returns the flattened feature vector""" + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.conv5(x) + x = self.conv6(x) + x = self.adaptive_pool(x) + x = x.view(x.size(0), -1) # Flatten to (batch_size, feature_dim) + return x + + def _initialize_weights(self): + """ + Initialize parameters of all layers in the network: + - Convolution layers: initialize with Kaiming normal distribution + - Batch normalization layers: initialize weights to 1 and biases to 0 + - All bias terms (for convolution layers): initialize to 0 + """ + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # Convolution layer initialization: Kaiming He initialization (suitable for ReLU activation) + init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + # Initialize bias to 0 + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + # Batch normalization initialization: weight=1, bias=0 + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + # If fully connected layers are added later, initialize them as well (the current network has none) + init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +@DETECTOR.register_module(module_name='ganatt') +class GANAtt_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(512, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone = FeatureExtractor() + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image']) + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict diff --git a/training/detectors/hrnet_detector.py b/training/detectors/hrnet_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..17b51197c1fbd3ee5d70cc243e5a3f1a7be06307 --- /dev/null +++ b/training/detectors/hrnet_detector.py @@ -0,0 +1,785 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# The script is adopted from Ke Sun (sunk@mail.ustc.edu.cn) +# ------------------------------------------------------------------------------ + + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict +import numbers, math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +import copy + +logger = logging.getLogger(__name__) + + +BN_MOMENTUM = 0.01 + +def weights_init(init_type='gaussian'): + def init_fun(m): + classname = m.__class__.__name__ + if (classname.find('Conv') == 0 or classname.find( + 'Linear') == 0) and hasattr(m, 'weight'): + if init_type == 'gaussian': + nn.init.normal_(m.weight, 0.0, 0.02) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight, gain=math.sqrt(2)) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight, gain=math.sqrt(2)) + elif init_type == 'default': + pass + else: + assert 0, "Unsupported initialization: {}".format(init_type) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias, 0.0) + return init_fun + +class GaussianSmoothing(nn.Module): + """ + Apply gaussian smoothing on a + 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + in the input using a depthwise convolution. + Arguments: + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + def __init__(self, channels, kernel_size, sigma, dim=2): + super(GaussianSmoothing, self).__init__() + if isinstance(kernel_size, numbers.Number): + kernel_size = [kernel_size] * dim + if isinstance(sigma, numbers.Number): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ], indexing='ij' + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ + torch.exp(-((mgrid - mean) / std) ** 2 / 2) + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + self.register_buffer('weight', kernel) + self.groups = channels + + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError( + 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) + ) + + def forward(self, input): + """ + Apply gaussian filter to input. + Arguments: + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return self.conv(input, weight=self.weight, groups=self.groups) + +class LaPlacianMs(nn.Module): + def __init__(self,in_c,gauss_ker_size=3,scale=[2],drop_rate=0.2): + super(LaPlacianMs, self).__init__() + self.scale = scale + self.gauss_ker_size = gauss_ker_size + ## apply gaussian smoothing to input feature maps with 3 planes + ## with kernel size K and sigma s + self.smoothing = nn.ModuleDict() + for s in self.scale: + self.smoothing['scale-'+str(s)] = GaussianSmoothing(in_c, self.gauss_ker_size, s) + self.conv_1x1 = nn.Sequential(nn.Conv2d(in_c*len(scale), in_c, + kernel_size=1, stride=1, + bias=False,groups=1), + nn.BatchNorm2d(in_c), + nn.ReLU(inplace=True), + nn.Dropout(p=drop_rate) + ) + # Official init from torch repo. + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.constant_(m.bias, 0) + + def down(self,x,s): + return F.interpolate(x,scale_factor=s, + mode='bilinear', + align_corners=False) + def up (self,x, size): + return F.interpolate(x,size=size,mode='bilinear',align_corners=False) + + def forward(self, x): + for i, s in enumerate(self.scale): + sm = self.smoothing['scale-'+str(s)](x) + sm = self.down(sm,1/s) + sm = self.up(sm,(x.shape[2],x.shape[3])) + if i == 0: + diff = x - sm + else: + diff = torch.cat((diff, x - sm), dim=1) + return self.conv_1x1(diff) + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + +class CatDepth(nn.Module): + def __init__(self): + super(CatDepth, self).__init__() + + def forward(self, x, y): + return torch.cat([x,y],dim=1) + +'''GX: basicblock contains two conv3x3 and two batch norm''' +'''GX: at last, it has a residual connection''' +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=False) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + out = self.relu(out) + + return out + +'''GX: 3 conv + 3 bn then a residual.''' +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=False) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + out = self.relu(out) + + return out + +'''GX: the basic component in the network.''' +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=False) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + ## GX: fuse layer converts feature maps at different resolution branches + ## GX: into the feature map of the new branches' feature map. + ## GX: https://zhuanlan.zhihu.com/p/335333233 + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), + nn.BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM), + nn.ReLU(inplace=False))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=[height_output, width_output], + mode='bilinear', align_corners=True) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + +## GX: the HighResolutionNet has 4 stages. +## GX: each stage has one module which is HighResolutionModule. +## GX: HighResolutionModule has 1,2,3,4 branches. +## GX: each stage has a transitional layers in between. +class HighResolutionNet(nn.Module): + + def __init__(self, config, **kwargs): + super(HighResolutionNet, self).__init__() + + # noise conv + # self.im_conv = nn.Conv2d(3, 10, kernel_size=3, stride=1, padding=1, bias=False) + # self.bayar_conv = nn.Conv2d(3, 3, kernel_size=5, stride=1, padding=2, bias=False) + # self.constraints = BayarConstraint() + + # stem net + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=False) + + # frequency branch + self.conv1fre = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1fre = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2fre = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2fre = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.laplacian = LaPlacianMs(in_c=64,gauss_ker_size=3,scale=[2,4,8]) + + # concat + self.concat_depth = CatDepth() + self.conv_1x1_merge = nn.Sequential(nn.Conv2d(128, 64, + kernel_size=1, stride=1, + bias=False,groups=2), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.Dropout(p=0.2) + ) + # self.module_initializer = module_initializer() + # self.conv_1x1_merge = self.module_initializer(self.conv_1x1_merge) + self.conv_1x1_merge.apply(weights_init('kaiming')) + + self.stage1_cfg = config['STAGE1'] + num_channels = self.stage1_cfg['NUM_CHANNELS'][0] + block = blocks_dict[self.stage1_cfg['BLOCK']] + num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + stage1_out_channel = block.expansion * num_channels + + self.stage2_cfg = config['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer( + [stage1_out_channel], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = config['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = config['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + + last_inp_channels = np.int32(np.sum(pre_stage_channels)) + + # ------------------------------ + # Added downsampling layers (s1 -> 28x28, s2 -> 28x28, s3 -> 28x28) + # ------------------------------ + # s1: 224 -> 28 (8x downsampling): Conv(stride=8) + BN + ReLU + self.downsample_s1 = nn.Sequential( + nn.Conv2d(18, 18, kernel_size=3, stride=8, padding=1, bias=False), + nn.BatchNorm2d(18), + nn.ReLU(inplace=True) + ) + # s2: 112 -> 28 (4x downsampling) + self.downsample_s2 = nn.Sequential( + nn.Conv2d(36, 36, kernel_size=3, stride=4, padding=1, bias=False), + nn.BatchNorm2d(36), + nn.ReLU(inplace=True) + ) + # s3: 56 -> 28 (2x downsampling) + self.downsample_s3 = nn.Sequential( + nn.Conv2d(72, 72, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(72), + nn.ReLU(inplace=True) + ) + + # ------------------------------ + # Added a channel adjustment layer (270 -> 512) as the main up-projection step + # ------------------------------ + self.adjust_channels = nn.Conv2d( + in_channels=270, # 18+36+72+144=270 + out_channels=512, # target channel count + kernel_size=1, + stride=1, + padding=0, + bias=False # no bias when paired with BN + ) + self.bn_after_adjust = nn.BatchNorm2d(512) # improve training stability + self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1)) + + # ------------------------------ + # Weight initialization (reusing the existing `weights_init` function) + # ------------------------------ + self.downsample_s1.apply(weights_init('kaiming')) + self.downsample_s2.apply(weights_init('kaiming')) + self.downsample_s3.apply(weights_init('kaiming')) + self.adjust_channels.apply(weights_init('kaiming')) + + + ## GX: one dimension matrix converts pre to pos. + ## GX: if channel numbers are equal, pass it directly. + ## GX: if channel numbers are different, using conv 3x3. + ## GX: https://zhuanlan.zhihu.com/p/335333233 + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + nn.BatchNorm2d( + num_channels_cur_layer[i], momentum=BN_MOMENTUM), + nn.ReLU(inplace=False))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=False))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + ## GX: _make_layer creates a conv + bn + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): + ## GX: num_modules are all 1 in this work. + ## GX: light-weight architectures: num_blocks are all 0. + ## GX: branch numbers are 2, 3, 4. + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule(num_branches, block, num_blocks, + num_inchannels, num_channels, fuse_method, + reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x_fre = self.conv1fre(x) + x_fre = self.bn1fre(x_fre) + x_fre = self.relu(x_fre) + x_fre = self.laplacian(x_fre) + x_fre = self.conv2fre(x_fre) + x_fre = self.bn2fre(x_fre) + x_fre = self.relu(x_fre) + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.concat_depth(x, x_fre) + x = self.conv_1x1_merge(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + if i < self.stage2_cfg['NUM_BRANCHES']: + x_list.append(self.transition2[i](y_list[i])) + else: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + if i < self.stage3_cfg['NUM_BRANCHES']: + x_list.append(self.transition3[i](y_list[i])) + else: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + s1, s2, s3, s4 = x + # torch.Size([32, 18, 224, 224]) + # torch.Size([32, 36, 112, 112]) + # torch.Size([32, 72, 56, 56]) + # torch.Size([32, 144, 28, 28]) + + # Step 1: Downsample s1, s2, and s3 to 28x28 to match the resolution of s4. + # Key point: use the predefined model layers instead of creating them dynamically in `forward`. + s1_down = self.downsample_s1(s1) # [32,18,28,28] + s2_down = self.downsample_s2(s2) # [32,36,28,28] + s3_down = self.downsample_s3(s3) # [32,72,28,28] + # s4 does not need additional processing: [32,144,28,28] + + # Step 2: Concatenate channels (total channels = 18 + 36 + 72 + 144 = 270) + concat_feat = torch.cat([s1_down, s2_down, s3_down, s4], dim=1) # [32,270,28,28] + + # Step 3: Use a 1x1 convolution to project features to 512 channels (`out_channels=512`). + final_feat = self.adjust_channels(concat_feat) # [32,512,28,28] + + # Optional: apply activation and BN to improve stability (recommended). + final_feat = self.bn_after_adjust(final_feat) + final_feat = F.relu(final_feat) + pooled_feat = self.adaptive_pool(final_feat) + + flattened_feat = pooled_feat.view(pooled_feat.size(0), -1) # [32,512] + + return flattened_feat + + def init_weights(self, pretrained='',): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, std=0.001) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + if os.path.isfile(pretrained): + pretrained_dict = torch.load(pretrained) + ## GX: official pre-trained dict. + print('=> loading HRNet pretrained model {}'.format(pretrained)) + model_dict = self.state_dict() + model_pretrained_lst, model_nopretrained_lst = [], [] + ## GX: model_dict is weights from the current architecture. + pretrained_dict_used = {} + ## GX: gather weights from pretrained_dict to model_dict. + nopretrained_dict = {k: v for k, v in model_dict.items()} + + for k, v in model_dict.items(): + pretrained_key = 'model.' + k + if pretrained_key not in pretrained_dict.keys(): + if 'stage2' in pretrained_key and 'fuse_layers' not in pretrained_key: + if 'branches.2' in pretrained_key: + pretrained_key = pretrained_key.replace('stage2.0.', 'stage3.0.') + elif 'branches.3' in pretrained_key: + pretrained_key = pretrained_key.replace('stage2.0.', 'stage4.0.') + elif 'stage3' in pretrained_key and 'fuse_layers' not in pretrained_key: + pretrained_key = pretrained_key.replace('stage3.0.', 'stage4.0.') + elif 'fre' in pretrained_key: + pretrained_key = pretrained_key.replace('fre', '') + if pretrained_key in pretrained_dict.keys(): + pretrained_dict_used[k] = pretrained_dict[pretrained_key] + nopretrained_dict.pop(k) + print("no pretrain dict length is: ", len(nopretrained_dict)) + print("pretrained dict length is: ", len(pretrained_dict)) + model_dict.update(pretrained_dict_used) + self.load_state_dict(model_dict) + +def get_seg_model(cfg, **kwargs): + model = HighResolutionNet(cfg, **kwargs) + model.init_weights(cfg["PRETRAINED"]) + return model + + +@DETECTOR.register_module(module_name='hrnet') +class HRNet_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(512, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone = get_seg_model(config["HRNET"]) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image']) + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + diff --git a/training/detectors/i3d_detector.py b/training/detectors/i3d_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..0aa2763bea7da2d09eeb0cd5766a95d661e61a22 --- /dev/null +++ b/training/detectors/i3d_detector.py @@ -0,0 +1,176 @@ +config_text = """ +TRAIN: + ENABLE: True + DATASET: kinetics + BATCH_SIZE: 64 + EVAL_PERIOD: 10 + CHECKPOINT_PERIOD: 1 + AUTO_RESUME: True +DATA: + NUM_FRAMES: 16 + SAMPLING_RATE: 8 + TRAIN_JITTER_SCALES: [256, 320] + TRAIN_CROP_SIZE: 224 + TEST_CROP_SIZE: 256 + INPUT_CHANNEL_NUM: [3] +RESNET: + ZERO_INIT_FINAL_BN: True + WIDTH_PER_GROUP: 64 + NUM_GROUPS: 1 + DEPTH: 50 + TRANS_FUNC: bottleneck_transform + STRIDE_1X1: False + NUM_BLOCK_TEMP_KERNEL: [[3], [4], [6], [3]] +NONLOCAL: + LOCATION: [[[]], [[]], [[]], [[]]] + GROUP: [[1], [1], [1], [1]] + INSTANTIATION: softmax +BN: + USE_PRECISE_STATS: True + NUM_BATCHES_PRECISE: 200 +SOLVER: + BASE_LR: 0.1 + LR_POLICY: cosine + MAX_EPOCH: 196 + MOMENTUM: 0.9 + WEIGHT_DECAY: 1e-4 + WARMUP_EPOCHS: 34.0 + WARMUP_START_LR: 0.01 + OPTIMIZING_METHOD: sgd +MODEL: + NUM_CLASSES: 1 + ARCH: i3d + MODEL_NAME: ResNet + LOSS_FUNC: cross_entropy + DROPOUT_RATE: 0.5 + HEAD_ACT: sigmoid +TEST: + ENABLE: True + DATASET: kinetics + BATCH_SIZE: 64 +DATA_LOADER: + NUM_WORKERS: 8 + PIN_MEMORY: True +NUM_GPUS: 8 +NUM_SHARDS: 1 +RNG_SEED: 0 +OUTPUT_DIR: . +""" + +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the I3DDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{carreira2017quo, + title={Quo vadis, action recognition? a new model and the kinetics dataset}, + author={Carreira, Joao and Zisserman, Andrew}, + booktitle={proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, + pages={6299--6308}, + year={2017} +} +''' + +import logging +import os +import sys + +from detectors import DETECTOR +from loss import LOSSFUNC +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector + +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + +import torch +from .utils.slowfast.models.video_model_builder import ResNet as ResNetOri +from .utils.slowfast.config.defaults import get_cfg +from torch import nn + +random_select = True +no_time_pool = True + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='i3d') +class I3DDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + cfg = get_cfg() + cfg.merge_from_str(config_text) + cfg.NUM_GPUS = 1 + cfg.TEST.BATCH_SIZE = 1 + cfg.TRAIN.BATCH_SIZE = 1 + cfg.DATA.NUM_FRAMES = config['clip_size'] + self.resnet = ResNetOri(cfg) + if config['pretrained'] is not None: + print(f"loading pretrained model from {config['pretrained']}") + pretrained_weights = torch.load(config['pretrained'], map_location='cpu', encoding='latin1') + modified_weights = {k.replace("resnet.", ""): v for k, v in pretrained_weights.items()} + # fit from 400 num_classes to 1 + modified_weights["head.projection.weight"] = modified_weights["head.projection.weight"][:1, :] + modified_weights["head.projection.bias"] = modified_weights["head.projection.bias"][:1] + # load final ckpt + self.resnet.load_state_dict(modified_weights, strict=True) + + self.loss_func = nn.BCELoss() # The output of the model is a probability value between 0 and 1 (haved used sigmoid) + + def build_backbone(self, config): + pass + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + inputs = [data_dict['image'].permute(0, 2, 1, 3, 4)] + pred = self.resnet(inputs) + output = {"final_output": pred} + + return output["final_output"] + + def classifier(self, features: torch.tensor): + pass + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'].float() + pred = pred_dict['cls'].view(-1) + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + prob = self.features(data_dict) + pred_dict = {'cls': prob, 'prob': prob, 'feat': prob} + + return pred_dict diff --git a/training/detectors/iid_detector.py b/training/detectors/iid_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..3a803be711eb7dcf364f381698d4cc85995711be --- /dev/null +++ b/training/detectors/iid_detector.py @@ -0,0 +1,232 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the IIDDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{huang2023implicit, + title={Implicit identity driven deepfake face swapping detection}, + author={Huang, Baojin and Wang, Zhongyuan and Yang, Jifan and Ai, Jiaxin and Zou, Qin and Wang, Qian and Ye, Dengpan}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={4490--4499}, + year={2023} +} +''' + +import os +import datetime +import logging +import random + +import numpy as np +import yaml +from sklearn import metrics +from typing import Union +from collections import defaultdict + +from dataset.iid_dataset import IIDDataset +from detectors.utils.iid_api import FC_ddp,FC_ddp2 +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train +from networks.iresnet_iid import iresnet50 + +from detectors.base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from .utils.iid_api import l2_norm + +logger = logging.getLogger(__name__) +torch.autograd.set_detect_anomaly(True) + + +@DETECTOR.register_module(module_name='iid') +class IIDDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + self.explicit_extractor = iresnet50(False, fp16=False) + self.explicit_extractor.load_state_dict(torch.load(config['explicit_extractor_pretrained'])) + self.explicit_extractor.cuda().eval() + self.BCE_LOSS = FC_ddp(config['embedding_size'], config['backbone_config']['num_classes']).cuda() + self.IIE_LOSS = FC_ddp2(config['embedding_size'], 1000, scale=64, margin=0.4, mode='arcface', use_cifp=False, + reduction='mean',ddp=config['ddp']).cuda() + self.IIE_LOSS.train().cuda() + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + if config['pretrained'] != 'None': + # if donot load the pretrained weights, fail to get good results + state_dict = torch.load(config['pretrained']) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model successfully!') + else: + logger.info('No pretrained model.') + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + if config['loss_func']=='center_loss': + loss_func = loss_class(num_classes=2, feat_dim=2048) + else: + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + return self.backbone.features(data_dict['image']) #32,3,256,256 + + def classifier(self, features: torch.tensor,id_f=None) -> torch.tensor: + return self.backbone.classifier(features,id_f) + + def get_train_loss(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + id_index = data_dict['id_index'].cuda() + id_feat = pred_dict['id_feat'] + embed = pred_dict['embed'] + + real_id = (label == 1) + fake_id = (label == 0) + im_embs = l2_norm(embed) + em_embs = l2_norm(id_feat) + loss = 0 + + loss_eic = (im_embs[real_id] * em_embs[real_id]).sum(dim=1).mean() - (im_embs[fake_id] * em_embs[fake_id]).sum( + dim=1).mean() + loss_ce = self.BCE_LOSS(pred, label, return_logits=True).mean() + loss += loss_ce + loss_id, _, _ = self.IIE_LOSS(embed, id_index, return_logits=True) + # loss_id = 0 + loss += 0.05 * loss_id + loss += 0.1 * loss_eic + + loss_dict = {'overall': loss,'loss_bce': loss_ce, 'loss_iie': loss_id, 'loss_eic': loss_eic} + return loss_dict + + def get_test_loss(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + id_feat = pred_dict['id_feat'] + embed = pred_dict['embed'] + + real_id = (label == 1) + fake_id = (label == 0) + im_embs = l2_norm(embed) + em_embs = l2_norm(id_feat) + loss = 0 + + loss_eic = (im_embs[real_id] * em_embs[real_id]).sum(dim=1).mean() - (im_embs[fake_id] * em_embs[fake_id]).sum( + dim=1).mean() + loss_ce = self.BCE_LOSS(pred, label, return_logits=True).mean() + loss += loss_ce + # loss_id = 0 + loss += 0.1 * loss_eic + + loss_dict = {'overall': loss,'loss_bce': loss_ce, 'loss_iie': 0, 'loss_eic': loss_eic} + return loss_dict + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + if 'id_index' in data_dict: # depend on the dataset for io + return self.get_train_loss(data_dict,pred_dict) + else: + return self.get_test_loss(data_dict, pred_dict) + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + resized_images = F.interpolate(data_dict['image'], size=(112, 112), mode='bilinear', align_corners=False) + id_feat = self.explicit_extractor(resized_images) + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features,id_feat) + + embed=self.backbone.last_emb + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features,'id_feat': id_feat,'embed':embed} + return pred_dict + +if __name__ == '__main__': + + with open(r'H:\code\DeepfakeBench\training\config\detector\iid_detector.yaml', 'r') as f: + config = yaml.safe_load(f) + if config['manualSeed'] is None: + config['manualSeed'] = random.randint(1, 10000) + random.seed(config['manualSeed']) + torch.manual_seed(config['manualSeed']) + if config['cuda']: + torch.cuda.manual_seed_all(config['manualSeed']) + + config['data_manner'] = 'lmdb' + config['dataset_json_folder'] = 'preprocessing/dataset_json_v3' + config['sample_size']=256 + config['with_mask']=False + config['with_landmark']=False + config['use_data_augmentation']=True + config['ddp']=False + detector=IIDDetector(config=config).cuda() + train_set = IIDDataset(config=config, mode='train') + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=4, + shuffle=True, + num_workers=0, + collate_fn=train_set.collate_fn, + ) + optimizer = optim.Adam( + params=detector.parameters(), + lr=config['optimizer']['adam']['lr'], + weight_decay=config['optimizer']['adam']['weight_decay'], + betas=(config['optimizer']['adam']['beta1'], config['optimizer']['adam']['beta2']), + eps=config['optimizer']['adam']['eps'], + amsgrad=config['optimizer']['adam']['amsgrad'], + ) + from tqdm import tqdm + for iteration, batch in enumerate(tqdm(train_data_loader)): + batch['image'],batch['label']=batch['image'].cuda(),batch['label'].cuda() + predictions=detector(batch) + losses = detector.get_losses(batch, predictions) + optimizer.zero_grad() + losses['overall'].backward() + optimizer.step() + + if iteration > 10: + break diff --git a/training/detectors/lorax_detector.py b/training/detectors/lorax_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..8d1bf18b63a969626bddb2a9d154296ef6011afe --- /dev/null +++ b/training/detectors/lorax_detector.py @@ -0,0 +1,361 @@ +""" +LoRAX backbone and detector are packaged in the same file +""" + +import copy +from typing import Dict, Union + +import peft +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.models.convit import convit_tiny, convit_small, convit_base +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + + +# ======================= +# Original `convit_lorax` content +# ======================= + +class Identity(nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return x + + +class SimpleLinear(nn.Module): + ''' + Reference: + https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py + ''' + def __init__(self, in_features, out_features, bias=True): + super(SimpleLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_features)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') + nn.init.constant_(self.bias, 0) + + def forward(self, input): + return {'logits': F.linear(input, self.weight, self.bias)} + + +# Do not hardcode `cuda`; choose based on the current environment +_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' + + +class LoRAX(nn.Module): + def __init__(self, base_model, peft_config, out_dim): + super(LoRAX, self).__init__() + self.model = base_model + # Remove the original `base_model` head and use our own head + self.model.head = Identity() + self.peft_config = peft_config + + self.adapter_names = [] + self.add_adapter() + + self.embed_dim = base_model.embed_dim # e.g. 768 + self.out_dim = out_dim + + self.head = self.generate_fc(self.embed_dim, out_dim) + self.div_head = None + + def feature_dim(self): + return self.embed_dim * len(self.adapter_names) + + def generate_fc(self, in_dim, out_dim): + fc = SimpleLinear(in_dim, out_dim) + return fc + + def freeze(self): + # This only counts parameters and does not actually set `requires_grad=False` + npar = sum(p.numel() for p in self.head.parameters() if p.requires_grad) + pytorch_total_params = sum(p.numel() for p in self.model.parameters()) + pytorch_grad_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + # If needed, real freezing logic can be added here + + def end_finetuning(self): + """Start FT mode, usually with backbone freezed and balanced classes.""" + self.in_finetuning = False + + def begin_finetuning(self): + """End FT mode, usually with backbone freezed and balanced classes.""" + self.in_finetuning = True + + def reset_classifier(self, nb_classes): + new_head = self.generate_fc(self.feature_dim(), nb_classes) + + if self.head is not None: + nb_output = self.head.out_features + weight = copy.deepcopy(self.head.weight.data) + bias = copy.deepcopy(self.head.bias.data) + new_head.weight.data[:weight.shape[0], :weight.shape[1]] = weight + new_head.bias.data[:nb_output] = bias + + self.head = new_head + + def add_adapter(self): + i = len(self.adapter_names) + adapter_name = f'task{i}' + + if i == 0: + peft_model = peft.get_peft_model(self.model, self.peft_config, adapter_name=adapter_name).to(_DEVICE) + self.model = peft_model + else: + self.model.add_adapter(adapter_name, self.peft_config) + + self.adapter_names.append(adapter_name) + + print(f'added adapter {adapter_name}') + print(f'all adapters {self.adapter_names}') + + def set_active_adapter(self): + self.model.set_adapter(self.adapter_names[-1]) + print(f'active adapter {self.adapter_names[-1]}') + + def forward_features(self, x): + # Run forward once for each adapter and concatenate on the feature dimension + features = [self.forward_adapter(adapter_name, x) for adapter_name in self.adapter_names] + features = torch.cat(features, 1) + self.model.set_adapter(self.adapter_names[-1]) + return features, None, None + + def forward_adapter(self, adapter_name, x): + self.model.set_adapter(adapter_name) + assert self.model.active_adapters[0] == adapter_name + assert len(self.model.active_adapters) == 1 + # Original comment: `feats=self.model.forward_features(x)[0]` + feats = self.model.forward(x) + return feats + + def forward(self, x): + features = [self.forward_adapter(adapter_name, x) for adapter_name in self.adapter_names] + features = torch.cat(features, 1) + out = self.head(features)['logits'] + + if self.div_head is not None: + f_dim = self.embed_dim + aux_logits = self.div_head(features[:, -f_dim:]) + return {'logits': out, 'div': aux_logits} + + return out + + def get_internal_losses(self, clf_loss): + int_losses = {} + return int_losses + + def epoch_log(self): + log = {} + return log + + def reset_div_head(self, one_real=False): + # Remove the old `div_head` + if hasattr(self, 'div_head'): + del self.div_head + + n_features = self.embed_dim + # TODO: need to handle single real scenario + if one_real: + new_head = nn.Linear(n_features, 2) + else: + new_head = nn.Linear(n_features, 3) + + self.div_head = new_head + print('added diversity head') + + +# ======================= +# Detector based on LoRAX +# ======================= + +def _build_convit_backbone(variant: str, pretrained: bool = True): + """A simple wrapper around timm ConViT model construction.""" + if variant == 'tiny': + model = convit_tiny(pretrained=pretrained) + elif variant == 'small': + model = convit_small(pretrained=pretrained) + elif variant == 'base': + model = convit_base(pretrained=pretrained) + else: + raise ValueError(f'Unknown ConViT variant: {variant}') + return model + + +@DETECTOR.register_module(module_name='lorax_convit') +class LoRAXConvitDetector(AbstractDetector): + """ + Use ConViT + LoRA (LoRAX) as the backbone for the Deepfake detector. + This file directly contains the LoRAX implementation. + """ + + def __init__(self, config: Dict, load_param: Union[bool, str] = False): + super().__init__(config=config, load_param=load_param) + self.config = config + + # Build the LoRAX backbone + self.backbone = self.build_backbone(config) + + # Build the loss function + self.loss_func = self.build_loss(config) + + # If needed, handle `load_param` here to load the full detector state from a saved checkpoint. + # if isinstance(load_param, str): + # state = torch.load(load_param, map_location='cpu') + # self.load_state_dict(state, strict=True) + + # ------------------------------------------------------ + # Backbone construction: ConViT + LoRA -> LoRAX + # ------------------------------------------------------ + def build_backbone(self, config: Dict): + """ + Keep YAML / config field naming consistent with the original LoRAX implementation: + backbone_config: + num_classes: int + model: convit_pretrain | convit_pretrain_small | convit_pretrain_tiny + r_param: int + r_alpha_ratio: float + reg: list[str] # Module names targeted by LoRA + lora_dropout: float + """ + bcfg = config['backbone_config'] + + num_classes = bcfg['num_classes'] + # Corresponds to `args.model` in the original code. + convit_model_name = bcfg.get('model', 'convit_pretrain_tiny') + + # Corresponds to `args.r_param` in the original code. + r_param = bcfg.get('r_param', 8) + # Corresponds to `args.r_alpha_ratio` in the original code. + r_alpha_ratio = bcfg.get('r_alpha_ratio', 2.0) + # Corresponds to `args.reg` in the original code. + target_modules = bcfg.get('reg', ['qkv', 'proj']) + lora_dropout = bcfg.get('lora_dropout', 0.1) + + # --- 1) Build the ConViT backbone, matching `factory.get_backbone` --- + if convit_model_name == 'convit_pretrain': + from timm.models.convit import convit_base + base_model = convit_base(pretrained=True) + elif convit_model_name == 'convit_pretrain_small': + from timm.models.convit import convit_small + base_model = convit_small(pretrained=True) + elif convit_model_name == 'convit_pretrain_tiny': + from timm.models.convit import convit_tiny + base_model = convit_tiny(pretrained=True) + else: + raise ValueError(f'Unknown backbone model: {convit_model_name}') + + # --- 2) Configure LoRA to match the original `main(args)` setup --- + lora_alpha = int(r_param * r_alpha_ratio) + peft_config = peft.LoraConfig( + target_modules=target_modules, + r=r_param, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + bias="none", + ) + + # --- 3) Build LoRAX --- + lorax_model = LoRAX( + base_model=base_model, + peft_config=peft_config, + out_dim=num_classes, + ) + + return lorax_model + + # ------------------------------------------------------ + # Build the loss function + # ------------------------------------------------------ + def build_loss(self, config: Dict): + loss_name = config['loss_func'] + loss_class = LOSSFUNC[loss_name] + loss_func = loss_class() + return loss_func + + # ------------------------------------------------------ + # Feature extraction + # ------------------------------------------------------ + def features(self, data_dict: Dict) -> torch.Tensor: + """ + Extract features with `LoRAX.forward_features`. + Returns shape: [B, embed_dim * len(adapter_names)] + """ + images = data_dict['image'] # [B,C,H,W] + feats, _, _ = self.backbone.forward_features(images) + return feats + + # ------------------------------------------------------ + # Classifier + # ------------------------------------------------------ + def classifier(self, features: torch.Tensor) -> torch.Tensor: + """ + `LoRAX.head` is a `SimpleLinear` module returning `{'logits': Tensor}`. + """ + out = self.backbone.head(features) + logits = out['logits'] + return logits + + # ------------------------------------------------------ + # Forward propagation + # ------------------------------------------------------ + def forward(self, data_dict: Dict, inference: bool = False) -> Dict: + """ + data_dict: + - image: [B, C, H, W] + Returns: + pred_dict = { + 'cls': logits [B, num_classes], + 'prob': prob [B, num_classes], + 'feat': feat [B, feature_dim] + } + """ + feats = self.features(data_dict) + logits = self.classifier(feats) + prob = torch.softmax(logits, dim=1) + + pred_dict = { + 'cls': logits, + 'prob': prob, + 'feat': feats + } + return pred_dict + + # ------------------------------------------------------ + # Loss computation + # ------------------------------------------------------ + def get_losses(self, data_dict: Dict, pred_dict: Dict) -> Dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + return {'overall': loss} + + # ------------------------------------------------------ + # Training metrics + # ------------------------------------------------------ + def get_train_metrics(self, data_dict: Dict, pred_dict: Dict) -> Dict: + label = data_dict['label'] + pred = pred_dict['cls'] + num_classes = self.config['backbone_config']['num_classes'] + + acc, mAP = calculate_acc_for_train( + label.detach(), + pred.detach(), + num_classes + ) + return {'acc': acc, 'mAP': mAP} + diff --git a/training/detectors/lrl_detector.py b/training/detectors/lrl_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..611be669c5dda5f3a814d485a6554c907d0352e4 --- /dev/null +++ b/training/detectors/lrl_detector.py @@ -0,0 +1,342 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the LocalRelationDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{chen2021local, + title={Local relation learning for face forgery detection}, + author={Chen, Shen and Yao, Taiping and Chen, Yang and Ding, Shouhong and Li, Jilin and Ji, Rongrong}, + booktitle={Proceedings of the AAAI conference on artificial intelligence}, + volume={35}, + number={2}, + pages={1081--1088}, + year={2021} +} +''' + +import os +import datetime +import logging +import numpy as np +import yaml +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel, Dropout2d, UpsamplingBilinear2d +from torch.utils.tensorboard import SummaryWriter + +from dataset.lrl_dataset import LRLDataset +from metrics.base_metrics_class import calculate_metrics_for_train + +from detectors.base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +import random + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='lrl') +class LRLDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.encoder_rgb = self.build_backbone(config) + self.encoder_idct = self.build_backbone(config) + self.encoder_idct.efficientnet._conv_stem = nn.Conv2d(1, 48, kernel_size=3, stride=2, bias=False) + self.loss_func = self.build_loss(config) + self.feature_adjust1 = nn.Upsample(scale_factor=0.25) + self.feature_adjust2 = nn.Upsample(scale_factor=0.5) + self.decoder = Decoder(decoder_filters=[64, 128, 256, 256], + filters=[48, 40, 64, 176, 2008]) + self.rfam1 = RFAM(56) + self.rfam2 = RFAM(160) + self.rfam3 = RFAM(1792) + + self.final = nn.Conv2d(64, out_channels=1, kernel_size=1, bias=False) + + self.overall_classifier = nn.Sequential( + nn.Linear(240, 128), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(128, 2), + ) + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + model_config['pretrained'] = self.config['pretrained'] + backbone = backbone_class(model_config) + if config['pretrained'] != 'None': + logger.info('Load pretrained model successfully!') + else: + logger.info('No pretrained model.') + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + self.seg_loss = nn.BCELoss() + self.sim_loss = nn.MSELoss() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + rgb=data_dict['image'] + idct=data_dict['idct'] + + #torch.Size([b, 56, 32, 32]) + rgb1=self.encoder_rgb.block_part1(rgb) + idct1=self.encoder_idct.block_part1(idct) + + rgb1, idct1 = self.rfam1(rgb1, idct1) + featuremap_low = rgb1 + idct1 + + #torch.Size([b, 160, 16, 16]) + rgb2=self.encoder_rgb.block_part2(rgb1) + idct2=self.encoder_idct.block_part2(idct1) + + rgb2, idct2 = self.rfam2(rgb2, idct2) + featuremap_mid = rgb2 + idct2 + + #torch.Size([b, 1792, 8, 8]) + rgb3=self.encoder_rgb.block_part3(rgb2) + idct3=self.encoder_idct.block_part3(idct2) + + rgb3, idct3 = self.rfam3(rgb3, idct3) + featuremap_high = rgb3 + idct3 + + f1 = self.feature_adjust1(featuremap_low) + f2 = self.feature_adjust2(featuremap_mid) + f3 = featuremap_high + featuremap = torch.cat((f1, f2, f3), dim=1) + + return featuremap + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.overall_classifier(features) + + def get_similaritys(self, masks, n=9, m=9): + similaritys = [] + for i in range(len(masks)): + ratios = [y.float().mean() for x in torch.chunk(masks[i], n, dim=0) for y in torch.chunk(x, m, dim=1)] + ratios = torch.tensor(ratios).view(-1, 1) + similarity = 1 - torch.norm(ratios[:, None] - ratios, dim=2, p=2) + similaritys.append(similarity) + similaritys = torch.stack(similaritys) + return similaritys + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + masks = data_dict['mask'] + pred_mask = pred_dict['mask_pred'] + sim = pred_dict['sim'] + sim_gt = self.get_similaritys(masks.squeeze(1), n=4, m=4).cuda() + pred = pred_dict['cls'] + sim_loss = self.sim_loss(sim,sim_gt) + seg_loss = self.seg_loss(pred_mask,masks) + ce_loss = self.loss_func(pred, label) + loss = sim_loss+seg_loss+ce_loss + loss_dict = {'overall': loss,'sim_loss':sim_loss,'seg_loss':seg_loss,'ce_loss':ce_loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = torch.ceil(data_dict['label'].clamp(max=1).float()).long() + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def feature_process(self,feature): + w = F.unfold(feature, kernel_size=2, stride=2, padding=0).permute(0, 2, 1) # (2008,8,8) to (16,8032), that is, 4*4 and flatten + w_normed = w / (w * w).sum(dim=2, keepdim=True).sqrt() + B, K = w.shape[:2] + sim = torch.einsum('bij,bjk->bik', w_normed, w_normed.permute(0, 2, 1)) # cross-similarity (16,16) + sim = (sim + 1) / 2 + mask = (torch.eye(K) != 1).repeat(B, 1).view(B, K, K).cuda() + sim_mask = torch.masked_select(sim, mask).view(B, K, -1) # remove self-similarity + x = sim_mask.view(B, -1) + return x,sim + + def forward(self, data_dict: dict, inference=False) -> dict: + + # get the features by backbone + features = self.features(data_dict) + + features_processed,sim = self.feature_process(features) + # get the prediction by classifier + pred_raw = self.classifier(features_processed) + + encoder_results = [features] + mask = self.final(self.decoder(encoder_results)) + mask = torch.sigmoid(mask) + # get the probability of the pred + if pred_raw.size(1)>2: + pred=torch.stack([pred_raw[:, 0], torch.sum(pred_raw[:, 1:], dim=1)], dim=1) + else: + pred=pred_raw + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred_raw, 'prob': prob, 'feat': features, 'mask_pred': mask, 'sim': sim} + return pred_dict + # else: + # return pred + + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.layer = nn.Sequential( + nn.Upsample(scale_factor=2), + nn.Conv2d(in_channels, out_channels, 3, padding=1), + nn.InstanceNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.layer(x) + + +class ConcatBottleneck(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, padding=1), + nn.InstanceNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, dec, enc=None): + return self.seq(dec) + + +class Decoder(nn.Module): + def __init__(self, decoder_filters, filters, upsample_filters=None, + decoder_block=DecoderBlock, bottleneck=ConcatBottleneck, dropout=0): + super().__init__() + self.decoder_filters = decoder_filters + self.filters = filters + self.decoder_block = decoder_block + self.decoder_stages = nn.ModuleList([self._get_decoder(idx) for idx in range(0, len(decoder_filters))]) + self.bottlenecks = nn.ModuleList([bottleneck(f, f) + for i, f in enumerate(reversed(decoder_filters))]) + self.dropout = Dropout2d(dropout) if dropout > 0 else None + self.last_block = None + if upsample_filters: + self.last_block = decoder_block(decoder_filters[0], out_channels=upsample_filters) + else: + self.last_block = UpsamplingBilinear2d(scale_factor=2) + + def forward(self, encoder_results: list): + x = encoder_results[0] + bottlenecks = self.bottlenecks + for idx, bottleneck in enumerate(bottlenecks): + rev_idx = - (idx + 1) + x = self.decoder_stages[rev_idx](x) + x = bottleneck(x) + if self.last_block: + x = self.last_block(x) + if self.dropout: + x = self.dropout(x) + return x + + def _get_decoder(self, layer): + idx = layer + 1 + if idx == len(self.decoder_filters): + in_channels = self.filters[idx] + else: + in_channels = self.decoder_filters[idx] + return self.decoder_block(in_channels, self.decoder_filters[max(layer, 0)]) + + +class RFAM(nn.Module): + def __init__(self, features): + super(RFAM, self).__init__() + self.attention = nn.Sequential( + nn.Conv2d(features * 2, features, 1), + nn.BatchNorm2d(features), + nn.ReLU(), + nn.Conv2d(features, 2, 3, padding=1), + nn.Sigmoid(), + ) + + def forward(self, x1, x2): + U = torch.cat((x1, x2), dim=1) + A = self.attention(U) + A1 = A[:, 0, ...].unsqueeze(1).contiguous() + A2 = A[:, 1, ...].unsqueeze(1).contiguous() + x1 *= A1 + x2 *= A2 + return x1, x2 + + +if __name__ == '__main__': + + with open(r'H:\code\DeepfakeBench\training\config\detector\lrl.yaml', 'r') as f: + config = yaml.safe_load(f) + with open('./training/config/train_config.yaml', 'r') as f: + config2 = yaml.safe_load(f) + config.update(config2) + if config['manualSeed'] is None: + config['manualSeed'] = random.randint(1, 10000) + random.seed(config['manualSeed']) + torch.manual_seed(config['manualSeed']) + if config['cuda']: + torch.cuda.manual_seed_all(config['manualSeed']) + detector=LRLDetector(config=config).cuda() + config['data_manner'] = 'lmdb' + config['dataset_json_folder'] = 'preprocessing/dataset_json_v3' + config['sample_size']=256 + config['with_mask']=True + config['with_landmark']=True + config['use_data_augmentation']=True + train_set = LRLDataset(config=config, mode='train') + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=2, + shuffle=True, + num_workers=0, + collate_fn=train_set.collate_fn, + ) + optimizer = optim.Adam( + params=detector.parameters(), + lr=config['optimizer']['adam']['lr'], + weight_decay=config['optimizer']['adam']['weight_decay'], + betas=(config['optimizer']['adam']['beta1'], config['optimizer']['adam']['beta2']), + eps=config['optimizer']['adam']['eps'], + amsgrad=config['optimizer']['adam']['amsgrad'], + ) + from tqdm import tqdm + for iteration, batch in enumerate(tqdm(train_data_loader)): + print(iteration) + batch['image'],batch['label'],batch['mask'],batch['idct']=batch['image'].cuda(),batch['label'].cuda(),batch['mask'].cuda(),batch['idct'].cuda() + predictions=detector(batch) + losses = detector.get_losses(batch, predictions) + optimizer.zero_grad() + losses['overall'].backward() + optimizer.step() + + if iteration > 10: + break diff --git a/training/detectors/lsda_detector.py b/training/detectors/lsda_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..53292fed0eeb77273c33aa5bc90275ba66d4d905 --- /dev/null +++ b/training/detectors/lsda_detector.py @@ -0,0 +1,568 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the LSDADetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@article{yan2023transcending, + title={Transcending forgery specificity with latent space augmentation for generalizable deepfake detection}, + author={Yan, Zhiyuan and Luo, Yuhao and Lyu, Siwei and Liu, Qingshan and Wu, Baoyuan}, + journal={arXiv preprint arXiv:2311.11278}, + year={2023} +} +''' + + +import os +import datetime +import numpy as np +from collections import defaultdict +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter +import cv2 +from collections import defaultdict + + +from efficientnet_pytorch import EfficientNet +from networks.iresnet import iresnet100 +from networks.xception import Xception +from detectors import DETECTOR +from sklearn import metrics +from metrics.base_metrics_class import calculate_metrics_for_train +from .base_detector import AbstractDetector + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + + +device = "cuda" if torch.cuda.is_available() else "cpu" + + + +@DETECTOR.register_module(module_name='lsda') +class LSDADetector(AbstractDetector): + def __init__(self, config): + super().__init__() + # model + forgery_num = 4 + self.model = generator( + num_classes=forgery_num+1, encoder_feat_dim=512, + teacher=config['teacher'], student=config['student'], + real_encoder=config['real_encoder'], + ).to(device) + + # loss + self.cls_criterion = nn.CrossEntropyLoss() + self.gan_loss_fn = nn.BCELoss() + self.prob, self.label = [], [] + self.correct, self.total = 0, 0 + + def build_backbone(self, config): + pass # FIXME: will be added into this function + + def build_loss(self, config): + pass # FIXME: will be added into this function + + def features(self, data_dict: dict) -> torch.tensor: + pass # FIXME: will be added into this function + + def classifier(self, features: torch.tensor) -> torch.tensor: + pass # FIXME: will be added into this function + + def get_losses(self, data_dict: dict, predictions: dict) -> dict: + try: + deepfake_loss, total_loss_distillation, domain_loss, loss_real = predictions['pred_loss'] + + loss = \ + 1 * domain_loss + \ + 0.5 * deepfake_loss + \ + 1 * total_loss_distillation + \ + 1 * loss_real + loss_dict = {'overall': loss, 'domain': domain_loss, 'deepfake': deepfake_loss, 'distillation': total_loss_distillation, 'real_loss': loss_real} + + except: + # test time + loss = 0 + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + label = torch.where(label == 0, 0, 1).reshape(-1,1) + prob = pred_dict['prob'].reshape(-1,1) + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), prob.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + + # 1. Forward pass + # pred, data_dict['label'], feat = + model_output = self.model(data_dict['image'], data_dict['label'], inference=inference) + if inference: + pred = model_output + prob = torch.softmax(pred, dim=1)[:, 1] + pred_dict = {'cls': pred, 'prob': prob, 'feat': prob} + else: + pred, deepfake_loss, total_loss_distillation, domain_loss, loss_real, student_feature = model_output + loss = (deepfake_loss, total_loss_distillation, domain_loss, loss_real) + prob = torch.softmax(pred, dim=1)[:, 1] + pred_dict = {'cls': pred, 'prob': prob, 'feat': student_feature, 'pred_loss': loss} + + if inference: + self.prob.append( + pred_dict['prob'] + .detach() + .squeeze() + .cpu() + .numpy() + ) + self.label.append( + data_dict['label'] + .detach() + .squeeze() + .cpu() + .numpy() + ) + # deal with acc + _, prediction_class = torch.max(pred, 1) + correct = (prediction_class == data_dict['label']).sum().item() + self.correct += correct + self.total += data_dict['label'].size(0) + + return pred_dict + + +class efficientnet(nn.Module): + def __init__(self, pretrain='efficientnet-b4', sbi=None): + super(efficientnet, self).__init__() + self.model = EfficientNet.from_pretrained(pretrain,weights_path='./training/pretrained/efficientnet-b4-6ed6700e.pth') + + if pretrain == 'efficientnet-b4': + self.conv = nn.Conv2d(1792, 512, 1) + elif pretrain == 'efficientnet-b1': + self.conv = nn.Conv2d(1280, 512, 1) + elif pretrain == 'efficientnet-b3': + self.conv = nn.Conv2d(1536, 512, 1) + elif pretrain == 'efficientnet-b5': + self.conv = nn.Conv2d(2048, 512, 1) + elif pretrain == 'efficientnet-b6': + self.conv = nn.Conv2d(2304, 512, 1) + else: + raise ValueError('pretrain is not supported') + + # self.channel_adjust_conv = nn.Conv2d(2424, 512, 1) + + def features(self, x): + x = self.model.extract_features(x) + x = self.conv(x) + + return x + + def forward(self, x): + x = self.model.extract_features(x) + x = self.conv(x) + + return x + + +class MLP(nn.Module): + def __init__(self, in_f, hidden_dim, out_f): + super(MLP, self).__init__() + self.pool = nn.AdaptiveAvgPool2d(1) + self.mlp = nn.Sequential(nn.Linear(in_f, hidden_dim), + nn.LeakyReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim), + nn.LeakyReLU(inplace=True), + nn.Linear(hidden_dim, out_f),) + + def forward(self, x): + x = self.pool(x) + x = self.mlp(x) + return x + +class Conv2d1x1(nn.Module): + def __init__(self, in_f, hidden_dim, out_f): + super(Conv2d1x1, self).__init__() + self.conv2d = nn.Sequential(nn.Conv2d(in_f, hidden_dim, 1, 1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 1, 1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(hidden_dim, out_f, 1, 1),) + + def forward(self, x): + x = self.conv2d(x) + return x + +class Head(nn.Module): + def __init__(self, in_f, hidden_dim, out_f): + super(Head, self).__init__() + self.do = nn.Dropout(0.2) + self.pool = nn.AdaptiveAvgPool2d(1) + self.mlp = nn.Sequential(nn.Linear(in_f, hidden_dim), + nn.LeakyReLU(inplace=True), + nn.Linear(hidden_dim, out_f),) + + def forward(self, x): + bs = x.size()[0] + x_feat = self.pool(x).view(bs, -1) + x = self.mlp(x_feat) + x = self.do(x) + return x, x_feat + +def set_requires_grad(model, val): + for p in model.parameters(): + p.requires_grad = val + + +class generator(nn.Module): + def __init__(self, num_classes, + encoder_feat_dim, + num_domains=5, + teacher='efficientnetb4', + student='efficientnetb4', + real_encoder=None, + ) -> None: + + super(generator, self).__init__() + self.num_domains = num_domains + # init variable + self.num_classes = num_classes + self.encoder_feat_dim = encoder_feat_dim + self.half_fingerprint_dim = encoder_feat_dim//2 + + # basic function + self.lr = nn.LeakyReLU(inplace=True) + self.do = nn.Dropout(0.2) + self.pool = nn.AdaptiveAvgPool2d(1) + self.count = 0 + + # Use a `ModuleList` for the 4 fake classes + if teacher == 'xception': + self.encoders_f = nn.ModuleList([self.init_xcep() for _ in range(self.num_domains-1)]) + elif teacher == 'efficientnetb4': + self.encoders_f = nn.ModuleList([self.init_efficient() for _ in range(self.num_domains-1)]) + + if real_encoder is None: + self.encoder_c = iresnet100(pretrained=False, fp16=False) + elif real_encoder == 'efficientnetb4': + print('real encoder: efficient') + self.encoder_c = self.init_efficient() + + + if student == 'xception': + self.student_encoder = self.init_xcep() + elif student == 'efficientnetb4': + self.student_encoder = self.init_efficient() + + self.fc_weights = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=True), + nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=True), + ) + + self.mlp = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(self.half_fingerprint_dim*2, self.half_fingerprint_dim), + nn.LeakyReLU(inplace=True), + nn.Linear(self.half_fingerprint_dim, num_domains), + ) + + self.binary_classifier = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(self.encoder_feat_dim, 512), + nn.LeakyReLU(inplace=True), + nn.Linear(512, 2), + ) + + self.cls_criterion = nn.CrossEntropyLoss() + + def init_xcep(self, pretrained_path='pretrained/xception-b5690688.pth'): + xcep = Xception(self.num_classes) + # load pre-trained Xception + state_dict = torch.load(pretrained_path) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + xcep.load_state_dict(state_dict, False) + return xcep + + def init_efficient(self): + model = efficientnet(pretrain='efficientnet-b4') + return model + + # only for grad cam + def features(self, cat_data): + # Binary classification detector, a student model, to be distilled (real/fake) + student_feature = self.student_encoder.features(cat_data) + return student_feature + + # only for grad cam + def classifier(self, fea): + out = self.binary_classifier(fea) + return out, None + + def real_fake_feature_extract(self, cat_data): + number_of_groups, video_per_group, c, h, w = cat_data.shape + + # Use defaultdict to store tensors for each domain + domain_f_chunks = defaultdict(list) + domain_c_chunks = defaultdict(list) + + for domain_id in range(video_per_group): + # Get the data for the current domain across all groups + domain_data_tensor = cat_data[:, domain_id] + + # Compute self-generation loss + c_chunk = self.encoder_c(domain_data_tensor) + if domain_id>0: # Use the encoder corresponding to the domain; 5 in total + f_chunk = self.encoders_f[domain_id-1].features(domain_data_tensor) + # Store the chunks in the defaultdict + domain_f_chunks[domain_id-1] = f_chunk + domain_c_chunks[domain_id] = c_chunk + + # Reconstruct the tensors based on the label order + all_f_outputs = torch.stack(list(domain_f_chunks.values())).transpose(1, 0) + all_c_outputs = torch.stack(list(domain_c_chunks.values())).transpose(1, 0) + + return all_f_outputs, all_c_outputs + + + def augment_domains(self, groups_feature_maps): + # Helper Functions + def hard_example_interpolation(z_i, hard_example, lambda_1): + return z_i + lambda_1 * (hard_example - z_i) + + def hard_example_extrapolation(z_i, mean_latent, lambda_2): + return z_i + lambda_2 * (z_i - mean_latent) + + def add_gaussian_noise(z_i, sigma, lambda_3): + epsilon = torch.randn_like(z_i) * sigma + return z_i + lambda_3 * epsilon + + def difference_transform(z_i, z_j, z_k, lambda_4): + return z_i + lambda_4 * (z_j - z_k) + + def distance(z_i, z_j): + return torch.norm(z_i - z_j) + + + domain_number = len(groups_feature_maps[0]) + + # Calculate the mean latent vector for each domain across all groups; why 8*8 + domain_means = [] + for domain_idx in range(domain_number): + all_samples_in_domain = torch.cat([group[domain_idx] for group in groups_feature_maps], dim=0) + domain_mean = torch.mean(all_samples_in_domain, dim=0) + domain_means.append(domain_mean) + + # Identify the hard example for each domain across all groups (the farest one) + hard_examples = [] + for domain_idx in range(domain_number): + all_samples_in_domain = torch.cat([group[domain_idx] for group in groups_feature_maps], dim=0) + distances = torch.tensor([distance(z, domain_means[domain_idx]) for z in all_samples_in_domain]) + hard_example = all_samples_in_domain[torch.argmax(distances)] + hard_examples.append(hard_example) + + + augmented_groups = [] + # modify each feature maps + for group_feature_maps in groups_feature_maps: + augmented_domains = [] + + for domain_idx, domain_feature_maps in enumerate(group_feature_maps): + # Choose a random augmentation + augmentations = [ + lambda z: hard_example_interpolation(z, hard_examples[domain_idx], random.random()), + lambda z: hard_example_extrapolation(z, domain_means[domain_idx], random.random()), + lambda z: add_gaussian_noise(z, random.random(), random.random()), + lambda z: difference_transform(z, domain_feature_maps[0], domain_feature_maps[1], random.random()) + ] + chosen_aug = random.choice(augmentations) + augmented = torch.stack([chosen_aug(z) for z in domain_feature_maps]) + augmented_domains.append(augmented) + + augmented_domains = torch.stack(augmented_domains) + augmented_groups.append(augmented_domains) + + return torch.stack(augmented_groups) + + + def mixup_in_latent_space(self, data): + # data shape: [batchsize, num_domains, 3, 256, 256] + bs, num_domains, _, _, _ = data.shape + + # Initialize an empty tensor for mixed data + mixed_data = torch.zeros_like(data) + + # For each sample in the batch + for i in range(bs): + # Step 1: Generate a shuffled index list for the domains + shuffled_idxs = torch.randperm(num_domains) + + # Step 2: Choose random alpha between 0.5 and 2, then sample lambda from beta distribution + alpha = torch.rand(1) * 1.5 + 0.5 # random alpha between 0.5 and 2 + lambda_ = torch.distributions.beta.Beta(alpha, alpha).sample().to(data.device) + + # Step 3: Perform mixup using the shuffled indices + mixed_data[i] = lambda_ * data[i] + (1 - lambda_) * data[i, shuffled_idxs] + + return mixed_data + + + def rotate_trans(self, fake_fs, + rotation_degree_range=(-30, 30)): + + # Convert degrees to radians + rotation_degree = torch.rand(1).to(fake_fs.device) * (rotation_degree_range[1] - rotation_degree_range[0]) + rotation_degree_range[0] + rotation_radians = rotation_degree * (3.141592653589793 / 180.0) + # Create an identity affine transformation (3x4) with the rotation in the top-left 2x2 corner + identity_affine = torch.tensor([ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0] + ], dtype=torch.float32).to(fake_fs.device) + # Fill the rotation into the top-left 2x2 + identity_affine[0, 0:2] = torch.tensor([torch.cos(rotation_radians), -torch.sin(rotation_radians)], dtype=torch.float32).to(fake_fs.device) + identity_affine[1, 0:2] = torch.tensor([torch.sin(rotation_radians), torch.cos(rotation_radians)], dtype=torch.float32).to(fake_fs.device) + # Expand the affine transformation for the batch + theta = identity_affine.unsqueeze(0).repeat(fake_fs.size(0), 1, 1) + grid = F.affine_grid(theta, fake_fs.size()) + fake_fs = F.grid_sample(fake_fs, grid) + + return fake_fs + + + @staticmethod + def cosine_similarity_loss(x, y, dim=1, eps=1e-8): + x_norm = x / (x.norm(dim=dim, keepdim=True) + eps) + y_norm = y / (y.norm(dim=dim, keepdim=True) + eps) + cos_sim = (x_norm * y_norm).sum(dim=dim) + return 1 - cos_sim + + + @staticmethod + def js_loss(inputs, targets): + """ + Computes the Jensen-Shannon divergence loss. + """ + # Compute the probability distributions + inputs_prob = F.softmax(inputs, dim=1) + targets_prob = F.softmax(targets, dim=1) + + # Compute the average probability distribution + avg_prob = (inputs_prob + targets_prob) / 2 + + # Compute the KL divergence component for each distribution + kl_div_loss = nn.KLDivLoss(reduction='batchmean') + kl_inputs = kl_div_loss(inputs_prob.log(), avg_prob) + kl_targets = kl_div_loss(targets_prob.log(), avg_prob) + + # Compute the Jensen-Shannon divergence + loss = 0.5 * (kl_inputs + kl_targets) + return loss + + + def forward(self, cat_data, label=None, inference=False): + if inference: + # Use the common encoder for inference/testing + student_feature = self.student_encoder.features(cat_data) + out_common = self.binary_classifier(student_feature) + return out_common + + # Obtain data + number_of_groups, video_per_group, c, h, w = cat_data.shape + + # Extract the real and fake features separately ; each one is extracted independently by a separate EfficientNet-B4 + f_outputs, c_outputs = self.real_fake_feature_extract(cat_data) + + # p = random.random() + # if p > 0.5: + # f_outputs = self.rotate_trans(f_outputs) + + + + # Perform augmentation in the latent space / f_out contains only fake samples + f_outputs_aug = self.augment_domains(f_outputs) + # Mixup in the latent space for cross-domain + mix_f_outputs = self.mixup_in_latent_space(f_outputs) + aug_fake = torch.cat([f_outputs_aug, mix_f_outputs], dim=2).view(-1, self.encoder_feat_dim*2, 8, 8) + fc = self.fc_weights(aug_fake).view(number_of_groups, video_per_group-1, self.encoder_feat_dim, 8, 8) + + + + # real constrain (optional, for the aim of learning real-features (e.g., ID) better) + real = c_outputs[:, 0, :, :, :] + df = c_outputs[:, 1, :, :, :] + f2f = c_outputs[:, 2, :, :, :] + fs = c_outputs[:, 3, :, :, :] + nt = c_outputs[:, 4, :, :, :] + loss_real = self.cosine_similarity_loss(real, nt).sum() \ + + self.cosine_similarity_loss(real, f2f).sum() \ + - self.cosine_similarity_loss(real, fs).sum() \ + - self.cosine_similarity_loss(real, df).sum() + # loss_real = self.js_loss(real, nt) + self.js_loss(real, f2f) - self.js_loss(real, fs) - self.js_loss(real, df) + loss_real = loss_real.mean() + + + + + # Obtain reshape label + label = label.contiguous().view(-1) + # Obtain the binary label + binary_label = torch.where(label==0, 0, 1) + + + + + # Binary classification detector, a student model, to be distilled (real/fake) + student_feature = self.student_encoder.features(cat_data.view(-1, c, h, w)) + binary_out = self.binary_classifier(student_feature) + deepfake_loss = F.cross_entropy(binary_out, binary_label) + + + # Distillation for the student encoder + real_mask = (label == 0) + fake_mask = (label > 0) + distillation_real_feature = student_feature[real_mask] + distillation_fake_feature = student_feature[fake_mask].reshape((number_of_groups, video_per_group-1, self.encoder_feat_dim, 8, 8)) + loss_distillation_real = F.mse_loss(distillation_real_feature, c_outputs.reshape(((-1, self.encoder_feat_dim, 8, 8)))[real_mask]) + loss_distillation_fake = F.mse_loss(distillation_fake_feature, fc) + total_loss_distillation = loss_distillation_real + loss_distillation_fake + + + + + # Domain classification loss for all domains + all_domain_feat = torch.cat([c_outputs[:, 0, :, :, :].unsqueeze(1), f_outputs], dim=1).reshape((number_of_groups*video_per_group, self.encoder_feat_dim, 8, 8)) + out_spe = self.mlp(all_domain_feat) + domain_loss = self.cls_criterion(out_spe, label) + + return binary_out, deepfake_loss, total_loss_distillation, domain_loss, loss_real, student_feature diff --git a/training/detectors/meso4Inception_detector.py b/training/detectors/meso4Inception_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..e59577eb64df9fc84b652903f09dbc31ae991210 --- /dev/null +++ b/training/detectors/meso4Inception_detector.py @@ -0,0 +1,111 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the Meso4InceptionDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{afchar2018mesonet, + title={Mesonet: a compact facial video forgery detection network}, + author={Afchar, Darius and Nozick, Vincent and Yamagishi, Junichi and Echizen, Isao}, + booktitle={2018 IEEE international workshop on information forensics and security (WIFS)}, + pages={1--7}, + year={2018}, + organization={IEEE} +} + +GitHub Reference: +https://github.com/HongguLiu/MesoNet-Pytorch +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='meso4Inception') +class Meso4InceptionDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + return self.backbone.features(data_dict['image']) + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + + diff --git a/training/detectors/meso4_detector.py b/training/detectors/meso4_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..8d6e548a041d80d1c7b3638225552d5fbf13fd74 --- /dev/null +++ b/training/detectors/meso4_detector.py @@ -0,0 +1,109 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the Meso4Detector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{afchar2018mesonet, + title={Mesonet: a compact facial video forgery detection network}, + author={Afchar, Darius and Nozick, Vincent and Yamagishi, Junichi and Echizen, Isao}, + booktitle={2018 IEEE international workshop on information forensics and security (WIFS)}, + pages={1--7}, + year={2018}, + organization={IEEE} +} + +GitHub Reference: +https://github.com/HongguLiu/MesoNet-Pytorch +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='meso4') +class Meso4Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + return self.backbone.features(data_dict['image']) + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + diff --git a/training/detectors/multi_attention_detector.py b/training/detectors/multi_attention_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..a044d44b04df344e1074b6bec023be0959e0c96d --- /dev/null +++ b/training/detectors/multi_attention_detector.py @@ -0,0 +1,473 @@ +""" +# author: Kangran ZHAO +# email: kangranzhao@link.cuhk.edu.cn +# date: 2024-0401 +# description: Class for the Multi-attention Detector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@INPROCEEDINGS{9577592, + author={Zhao, Hanqing and Wei, Tianyi and Zhou, Wenbo and Zhang, Weiming and Chen, Dongdong and Yu, Nenghai}, + booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + title={Multi-attentional Deepfake Detection}, + year={2021}, + volume={}, + number={}, + pages={2185-2194}, + keywords={Measurement;Semantics;Feature extraction;Forgery;Pattern recognition;Feeds;Task analysis}, + doi={10.1109/CVPR46437.2021.00222} + } + +Codes are modified based on GitHub repo https://github.com/yoctta/multiple-attention +""" + +import random + +import kornia +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from detectors import DETECTOR +from loss import LOSSFUNC +from metrics.base_metrics_class import calculate_metrics_for_train +from networks import BACKBONE +from sklearn import metrics + +from .base_detector import AbstractDetector + + +@DETECTOR.register_module(module_name='multi_attention') +class MultiAttentionDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.block_layer = {"b1": 1, "b2": 5, "b3": 9, "b4": 15, "b5": 21, "b6": 29, "b7": 31} + self.mid_dim = config["mid_dim"] + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + self.batch_cnt = 0 + + with torch.no_grad(): + layer_outputs = self.features({"image": torch.zeros(1, 3, config["resolution"], config["resolution"])}) + + self.feature_layer = config["feature_layer"] + self.attention_layer = config["attention_layer"] + self.num_classes = config["backbone_config"]["num_classes"] + self.num_shallow_features = layer_outputs[self.feature_layer].shape[1] + self.num_attention_features = layer_outputs[self.attention_layer].shape[1] + self.num_final_features = layer_outputs["final"].shape[1] + self.num_attentions = config["num_attentions"] + + self.AGDA = AGDA(kernel_size=config["AGDA"]["kernel_size"], + dilation=config["AGDA"]["dilation"], + sigma=config["AGDA"]["sigma"], + threshold=config["AGDA"]["threshold"], + zoom=config["AGDA"]["zoom"], + scale_factor=config["AGDA"]["scale_factor"], + noise_rate=config["AGDA"]["noise_rate"]) + + self.attention_generation = AttentionMap(self.num_attention_features, self.num_attentions) + self.attention_pooling = AttentionPooling() + self.texture_enhance = TextureEnhanceV1(self.num_shallow_features, self.num_attentions) # Todo + self.num_enhanced_features = self.texture_enhance.output_features + self.num_features_d = self.texture_enhance.output_features_d + self.projection_local = nn.Sequential(nn.Linear(self.num_attentions * self.num_enhanced_features, self.mid_dim), + nn.Hardswish(), + nn.Linear(self.mid_dim, self.mid_dim), + nn.Hardswish()) + self.projection_final = nn.Sequential(nn.Linear(self.num_final_features, self.mid_dim), + nn.Hardswish()) + self.ensemble_classifier_fc = nn.Sequential(nn.Linear(self.mid_dim * 2, self.mid_dim), + nn.Hardswish(), + nn.Linear(self.mid_dim, self.num_classes)) + self.dropout = nn.Dropout(config["dropout_rate"], inplace=True) + self.dropout_final = nn.Dropout(config["dropout_rate_final"], inplace=True) + + def build_backbone(self, config): + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + model_config['pretrained'] = self.config.get('pretrained', None) + backbone = backbone_class(model_config) + + return backbone + + def build_loss(self, config): + cls_loss_class = LOSSFUNC[config["loss_func"]["cls_loss"]] + ril_loss_class = LOSSFUNC[config["loss_func"]["ril_loss"]] + cls_loss_func = cls_loss_class() + ril_loss_func = ril_loss_class(M=config["num_attentions"], + N=config["loss_func"]["ril_params"]["N"], + alpha=config["loss_func"]["ril_params"]["alpha"], + alpha_decay=config["loss_func"]["ril_params"]["alpha_decay"], + decay_batch=config["batch_per_epoch"], + inter_margin=config["loss_func"]["ril_params"]["inter_margin"], + intra_margin=config["loss_func"]["ril_params"]["intra_margin"]) + + return {"cls": cls_loss_func, "ril": ril_loss_func, "weights": config["loss_func"]["weights"]} + + def features(self, data_dict: dict) -> torch.tensor: + x = data_dict["image"] + layer_output = {} + for name, module in self.backbone.efficientnet.named_children(): + if name == "_avg_pooling": + layer_output["final"] = x + break + elif name != "_blocks": + x = module(x) + else: + for i in range(len(module)): + x = module[i](x) + if i == self.block_layer["b1"]: + layer_output["b1"] = x + elif i == self.block_layer["b2"]: + layer_output["b2"] = x + elif i == self.block_layer["b3"]: + layer_output["b3"] = x + elif i == self.block_layer["b4"]: + layer_output["b4"] = x + elif i == self.block_layer["b5"]: + layer_output["b5"] = x + elif i == self.block_layer["b6"]: + layer_output["b6"] = x + elif i == self.block_layer["b7"]: + layer_output["b7"] = x + + x = F.adaptive_avg_pool2d(x, (1, 1)) + x = x.view(x.size(0), -1) + layer_output["logit"] = self.backbone.last_layer(x) + + return layer_output + + def classifier(self, features: torch.tensor) -> torch.tensor: + pass # do not overwrite this, since classifier structure has been written in self.forward() + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + if self.batch_cnt <= self.config["backbone_nEpochs"] * self.config["batch_per_epoch"]: + label = data_dict["label"] + pred = pred_dict["cls"] + ce_loss = self.loss_func["cls"](pred, label) + + return {"overall": ce_loss, "ce_loss": ce_loss} + else: + label = data_dict["label"] + pred = pred_dict["cls"] + feature_maps_d = pred_dict["feature_maps_d"] + attention_maps = pred_dict["attentions"] + + ce_loss = self.loss_func["cls"](pred, label) + ril_loss = self.loss_func["ril"](feature_maps_d, attention_maps, label) + weights = self.loss_func["weights"] + over_all_loss = weights[0] * ce_loss + weights[1] * ril_loss + + return {"overall": over_all_loss, "ce_loss": ce_loss, "ril_loss": ril_loss} + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + return metric_batch_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + self.batch_cnt += 1 + if self.batch_cnt <= self.config["backbone_nEpochs"] * self.config["batch_per_epoch"]: + layer_output = self.features(data_dict) + pred = layer_output["logit"] + prob = torch.softmax(pred, dim=1)[:, 1] + pred_dict = {"cls": pred, + "prob": prob, + "feat": layer_output["final"]} + + else: + if not inference: # use AGDA when training + with torch.no_grad(): + layer_output = self.features(data_dict) + raw_attentions = layer_output[self.attention_layer] + attention_maps = self.attention_generation(raw_attentions) + data_dict["image"], _ = self.AGDA.agda(data_dict["image"], attention_maps) + + # Get Attention Maps + layer_output = self.features(data_dict) + raw_attentions = layer_output[self.attention_layer] + attention_maps = self.attention_generation(raw_attentions) + + # Get Textural Feature Matrix P + shallow_features = layer_output[self.feature_layer] + enhanced_features, feature_maps_d = self.texture_enhance(shallow_features, attention_maps) + textural_feature_matrix_p = self.attention_pooling(enhanced_features, attention_maps) + B, M, N = textural_feature_matrix_p.size() + feature_matrix = self.dropout(textural_feature_matrix_p).view(B, -1) + feature_matrix = self.projection_local(feature_matrix) + + # Get Global Feature G + final = layer_output["final"] + attention_maps2 = attention_maps.sum(dim=1, keepdim=True) # [B, 1, H_A, W_A] + final = self.attention_pooling(final, attention_maps2, norm=1).squeeze(1) # [B, C_F] + final = self.projection_final(final) + final = F.hardswish(final) + + # Get the Prediction by Ensemble Classifier + feature_matrix = torch.cat((feature_matrix, final), dim=1) # [B, 2 * mid_dim] + pred = self.ensemble_classifier_fc(feature_matrix) # [B, 2] + + # Get probability + prob = torch.softmax(pred, dim=1)[:, 1] + + pred_dict = {"cls": pred, + "prob": prob, + "feat": layer_output["final"], + "attentions": attention_maps, + "feature_maps_d": feature_maps_d} + + return pred_dict + + +class AttentionMap(nn.Module): + def __init__(self, in_channels, num_attention): + super(AttentionMap, self).__init__() + self.register_buffer('mask', torch.zeros([1, 1, 24, 24])) + self.mask[0, 0, 2:-2, 2:-2] = 1 + self.num_attentions = num_attention + self.conv_extract = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2d(in_channels) + self.conv2 = nn.Conv2d(in_channels, num_attention, kernel_size=1, bias=False) + self.bn2 = nn.BatchNorm2d(num_attention) + + def forward(self, x): + """ + Convert deep feature to attention map + Args: + x: extracted features + Returns: + attention_maps: conventionally 4 attention maps + """ + if self.num_attentions == 0: + return torch.ones([x.shape[0], 1, 1, 1], device=x.device) + + x = self.conv_extract(x) + x = self.bn1(x) + x = F.relu(x, inplace=True) + x = self.conv2(x) + x = self.bn2(x) + x = F.elu(x) + 1 + mask = F.interpolate(self.mask, (x.shape[2], x.shape[3]), mode='nearest') + + return x * mask + + +class AttentionPooling(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, features, attentions, norm=2): + """ + Bilinear Attention Pooing, when used for + Args: + features: [Tensor in [B, C_F, H_F, W_F]] extracted feature maps, either shallow ones or deep ones ??? + attentions: [Tensor in [B, M, H, W]] attention maps, conventionally 4 attention maps (M = 4) + norm: [int, default=2] 1 for deep features, 2 for shallow features + Returns: + feature_matrix: [Tensor in [B, M, C_F] or [B, M, 1]] P (shallow feature) or G (deep feature) ??? + """ + feature_size = features.size()[-2:] + attention_size = attentions.size()[-2:] + if feature_size != attention_size: + attentions = F.interpolate(attentions, size=feature_size, mode='bilinear', align_corners=True) + + if len(features.shape) == 4: + # In TextureEnhanceV1, in accordance with paper + feature_matrix = torch.einsum('imjk,injk->imn', attentions, features) # [B, M, C_F] + else: + # In TextureEnhanceV2 + feature_matrix = torch.einsum('imjk,imnjk->imn', attentions, features) + + if norm == 1: # Used for deep feature BAP + w = torch.sum(attentions + 1e-8, dim=(2, 3)).unsqueeze(-1) + feature_matrix /= w + elif norm == 2: # Used for shallow feature BAP + feature_matrix = F.normalize(feature_matrix, p=2, dim=-1) + + return feature_matrix + + +class TextureEnhanceV1(nn.Module): + def __init__(self, num_features, num_attentions): + super().__init__() + # self.output_features=num_features + self.output_features = num_features * 4 + self.output_features_d = num_features + self.conv0 = nn.Conv2d(num_features, num_features, 1) + self.conv1 = nn.Conv2d(num_features, num_features, 3, padding=1) + self.bn1 = nn.BatchNorm2d(num_features) + self.conv2 = nn.Conv2d(num_features * 2, num_features, 3, padding=1) + self.bn2 = nn.BatchNorm2d(2 * num_features) + self.conv3 = nn.Conv2d(num_features * 3, num_features, 3, padding=1) + self.bn3 = nn.BatchNorm2d(3 * num_features) + self.conv_last = nn.Conv2d(num_features * 4, num_features * 4, 1) + self.bn4 = nn.BatchNorm2d(4 * num_features) + self.bn_last = nn.BatchNorm2d(num_features * 4) + + def forward(self, feature_maps, attention_maps=(1, 1)): + """ + Texture Enhancement Block V1, in accordance with description in paper + 1. Local average pooling. + 2. Residual local features. + 3. Dense Net + Args: + feature_maps: [Tensor in [B, C', H', W']] extracted shallow features + attention_maps: [Tensor in [B, M, H_A, W_A]] calculated attention maps, or + [Tuple with two float elements] local average grid scale, + used for conduct local average pooling, local patch size is decided by attention map size. + Returns: + feature_maps: [Tensor in [B, C_1, H_1, W_1]] enhanced feature maps + feature_maps_d: [Tensor in [B, C', H_A, W_A]] textural information + + """ + B, N, H, W = feature_maps.shape + if type(attention_maps) == tuple: + attention_size = (int(H * attention_maps[0]), int(W * attention_maps[1])) + else: + attention_size = (attention_maps.shape[2], attention_maps.shape[3]) + feature_maps_d = F.adaptive_avg_pool2d(feature_maps, attention_size) + feature_maps = feature_maps - F.interpolate(feature_maps_d, (feature_maps.shape[2], feature_maps.shape[3]), + mode='nearest') + feature_maps0 = self.conv0(feature_maps) + feature_maps1 = self.conv1(F.relu(self.bn1(feature_maps0), inplace=True)) + feature_maps1_ = torch.cat([feature_maps0, feature_maps1], dim=1) + feature_maps2 = self.conv2(F.relu(self.bn2(feature_maps1_), inplace=True)) + feature_maps2_ = torch.cat([feature_maps1_, feature_maps2], dim=1) + feature_maps3 = self.conv3(F.relu(self.bn3(feature_maps2_), inplace=True)) + feature_maps3_ = torch.cat([feature_maps2_, feature_maps3], dim=1) + feature_maps = self.bn_last(self.conv_last(F.relu(self.bn4(feature_maps3_), inplace=True))) + return feature_maps, feature_maps_d + + +class TextureEnhanceV2(nn.Module): + def __init__(self, num_features, num_attentions): + super().__init__() + self.output_features = num_features + self.output_features_d = num_features + self.conv_extract = nn.Conv2d(num_features, num_features, 3, padding=1) + self.conv0 = nn.Conv2d(num_features * num_attentions, num_features * num_attentions, 5, padding=2, + groups=num_attentions) + self.conv1 = nn.Conv2d(num_features * num_attentions, num_features * num_attentions, 3, padding=1, + groups=num_attentions) + self.bn1 = nn.BatchNorm2d(num_features * num_attentions) + self.conv2 = nn.Conv2d(num_features * 2 * num_attentions, num_features * num_attentions, 3, padding=1, + groups=num_attentions) + self.bn2 = nn.BatchNorm2d(2 * num_features * num_attentions) + self.conv3 = nn.Conv2d(num_features * 3 * num_attentions, num_features * num_attentions, 3, padding=1, + groups=num_attentions) + self.bn3 = nn.BatchNorm2d(3 * num_features * num_attentions) + self.conv_last = nn.Conv2d(num_features * 4 * num_attentions, num_features * num_attentions, 1, + groups=num_attentions) + self.bn4 = nn.BatchNorm2d(4 * num_features * num_attentions) + self.bn_last = nn.BatchNorm2d(num_features * num_attentions) + + self.M = num_attentions + + def cat(self, a, b): + B, C, H, W = a.shape + c = torch.cat([a.reshape(B, self.M, -1, H, W), b.reshape(B, self.M, -1, H, W)], dim=2).reshape(B, -1, H, W) + return c + + def forward(self, feature_maps, attention_maps=(1, 1)): + """ + Args: + feature_maps: [Tensor in [B, N, H, W]] extracted feature maps from shallow layer + attention_maps: [Tensor in [B, M, H_A, W_A] or float of (H_ratio, W_ratio)] either extracted attention maps + or average pooling down-sampling ratio + Returns: + feature_maps, feature_maps_d: [Tensor in [B, M, N, H, W], Tensor in [B, N, H, W]] feature maps after dense + network and non-textural feature map D + """ + B, N, H, W = feature_maps.shape + if type(attention_maps) == tuple: + attention_size = (int(H * attention_maps[0]), int(W * attention_maps[1])) + else: + attention_size = (attention_maps.shape[2], attention_maps.shape[3]) + feature_maps = self.conv_extract(feature_maps) + feature_maps_d = F.adaptive_avg_pool2d(feature_maps, attention_size) + if feature_maps.size(2) > feature_maps_d.size(2): + feature_maps = feature_maps - F.interpolate(feature_maps_d, (feature_maps.shape[2], feature_maps.shape[3]), + mode='nearest') + attention_maps = ( + torch.tanh(F.interpolate(attention_maps.detach(), (H, W), mode='bilinear', align_corners=True))).unsqueeze( + 2) if type(attention_maps) != tuple else 1 + feature_maps = feature_maps.unsqueeze(1) + feature_maps = (feature_maps * attention_maps).reshape(B, -1, H, W) + feature_maps0 = self.conv0(feature_maps) + feature_maps1 = self.conv1(F.relu(self.bn1(feature_maps0), inplace=True)) + feature_maps1_ = self.cat(feature_maps0, feature_maps1) + feature_maps2 = self.conv2(F.relu(self.bn2(feature_maps1_), inplace=True)) + feature_maps2_ = self.cat(feature_maps1_, feature_maps2) + feature_maps3 = self.conv3(F.relu(self.bn3(feature_maps2_), inplace=True)) + feature_maps3_ = self.cat(feature_maps2_, feature_maps3) + feature_maps = F.relu(self.bn_last(self.conv_last(F.relu(self.bn4(feature_maps3_), inplace=True))), + inplace=True) + feature_maps = feature_maps.reshape(B, -1, N, H, W) + return feature_maps, feature_maps_d + + +class AGDA(nn.Module): + def __init__(self, kernel_size, dilation, sigma, threshold, zoom, scale_factor, noise_rate): + super().__init__() + self.kernel_size = kernel_size + self.dilation = dilation + self.sigma = sigma + self.noise_rate = noise_rate + self.scale_factor = scale_factor + self.threshold = threshold + self.zoom = zoom + self.filter = kornia.filters.GaussianBlur2d((self.kernel_size, self.kernel_size), (self.sigma, self.sigma)) + + def mod_func(self, x): + threshold = random.uniform(*self.threshold) if type(self.threshold) == list else self.threshold + zoom = random.uniform(*self.zoom) if type(self.zoom) == list else self.zoom + bottom = torch.sigmoid((torch.tensor(0.) - threshold) * zoom) + + return (torch.sigmoid((x - threshold) * zoom) - bottom) / (1 - bottom) + + def soft_drop2(self, x, attention_map): + with torch.no_grad(): + attention_map = self.mod_func(attention_map) + B, C, H, W = x.size() + xs = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=True) + xs = self.filter(xs) + xs += torch.randn_like(xs) * self.noise_rate + xs = F.interpolate(xs, (H, W), mode='bilinear', align_corners=True) + x = x * (1 - attention_map) + xs * attention_map + return x + + def agda(self, X, attention_map): + with torch.no_grad(): + attention_weight = torch.sum(attention_map, dim=(2, 3)) + attention_map = F.interpolate(attention_map, (X.size(2), X.size(3)), mode="bilinear", align_corners=True) + attention_weight = torch.sqrt(attention_weight + 1) + index = torch.distributions.categorical.Categorical(attention_weight).sample() + index1 = index.view(-1, 1, 1, 1).repeat(1, 1, X.size(2), X.size(3)) + attention_map = torch.gather(attention_map, 1, index1) + atten_max = torch.max(attention_map.view(attention_map.shape[0], 1, -1), 2)[0] + 1e-8 + attention_map = attention_map / atten_max.view(attention_map.shape[0], 1, 1, 1) + + return self.soft_drop2(X, attention_map), index diff --git a/training/detectors/npr_detector.py b/training/detectors/npr_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..4eede93b0d0126a3949897213dcc71fbac9494b8 --- /dev/null +++ b/training/detectors/npr_detector.py @@ -0,0 +1,105 @@ +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='npr') +class NPR(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + # NPR switch and parameters + self.use_npr = config.get('use_npr', False) + self.npr_factor = config.get('npr_factor', 0.5) # downsampling factor + self.npr_scale = config.get('npr_scale', 2.0/3.0) # scaling factor for NPR * 2/3 + + def build_backbone(self, config): + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + return backbone + + def build_loss(self, config): + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + # === NPR preprocessing function === + def npr_preprocess(self, x: torch.Tensor) -> torch.Tensor: + """ + x: [N, C, H, W] input image (already normalized) + NPR(x) = x - upsample(downsample(x)) + Both downsampling and upsampling use Nearest, with `scale_factor = self.npr_factor` + Then multiply by `self.npr_scale`, corresponding to `*2/3` in the paper code. + """ + factor = self.npr_factor + # Downsample first and then upsample using nearest-neighbor interpolation + x_down = F.interpolate( + x, scale_factor=factor, mode='nearest', recompute_scale_factor=True + ) + x_recon = F.interpolate( + x_down, scale_factor=1.0/factor, mode='nearest', recompute_scale_factor=True + ) + npr = x - x_recon + return npr * self.npr_scale + + def features(self, data_dict: dict) -> torch.Tensor: + img = data_dict['image'] + # Apply NPR preprocessing before sending data into the backbone + if self.use_npr: + img = self.npr_preprocess(img) + return self.backbone.features(img) + + def classifier(self, features: torch.Tensor) -> torch.Tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + acc, mAP = calculate_acc_for_train( + label.detach(), pred.detach(), self.config['backbone_config']['num_classes'] + ) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + features = self.features(data_dict) + pred = self.classifier(features) + prob = torch.softmax(pred, dim=1) + pred_dict = { + 'cls': pred, + 'prob': prob, + 'feat': torch.mean(features, dim=[2, 3]) + } + return pred_dict + diff --git a/training/detectors/ooc_detector.py b/training/detectors/ooc_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..1da47343a3276650ff12b6b42a9caa3c0b10f77c --- /dev/null +++ b/training/detectors/ooc_detector.py @@ -0,0 +1,343 @@ +# author: Your Name +# date: 2025-xx-xx +# description: CoOp / OOC Detector wrapped into AbstractDetector + +import abc +import types +import torch +import torch.nn as nn + + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from loss import LOSSFUNC + +from clip import clip +from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer + +_tokenizer = _Tokenizer() + + +######################################## +# 1. TextEncoder (your implementation can be placed separately; here it is simplified and kept in the same file) +######################################## + +class TextEncoder(nn.Module): + def __init__(self, clip_model): + super().__init__() + self.transformer = clip_model.transformer + self.positional_embedding = clip_model.positional_embedding + self.ln_final = clip_model.ln_final + self.text_projection = clip_model.text_projection + self.dtype = clip_model.dtype + + # prompts: [num_prompts, n_ctx_total, d_model] + # tokenized_prompts: [num_prompts, 77],used only to locate the EOT token position + def forward(self, prompts, tokenized_prompts): + x = prompts.to(self.positional_embedding.device) + \ + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # Take the feature at the EOT token of each prompt and project it + x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection + + return x + + +######################################## +# 2. PromptLearner_CoOp (following your implementation) +######################################## + +class PromptLearner_CoOp(nn.Module): + def __init__(self, cfg, classnames, clip_model): + """ + `cfg` is expected to contain at least these fields: + - cfg.n_ctx : number of context tokens + - cfg.backbone : 'RN50' or 'ViT' + """ + super().__init__() + n_cls = len(classnames) + n_ctx = cfg.n_ctx + ctx_init = False + dtype = clip_model.dtype + ctx_dim = clip_model.ln_final.weight.shape[0] + clip_imsize = clip_model.visual.input_resolution + cfg_imsize = 224 + assert cfg_imsize == clip_imsize, \ + f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" + + if ctx_init: + ctx_init = ctx_init.replace("_", " ") + n_ctx = len(ctx_init.split(" ")) + prompt = clip.tokenize(ctx_init) + with torch.no_grad(): + embedding = clip_model.token_embedding(prompt).type(dtype) + ctx_vectors = embedding[0, 1: 1 + n_ctx, :] + prompt_prefix = ctx_init + else: + # generic random initialization for context + print("Initializing a generic context") + ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) + nn.init.normal_(ctx_vectors, std=0.02) + prompt_prefix = " ".join(["X"] * n_ctx) + + print(f'Initial context: "{prompt_prefix}"') + print(f"Number of context words (tokens): {n_ctx}") + + self.ctx = nn.Parameter(ctx_vectors) # learnable parameters + + classnames = [name.replace("_", " ") for name in classnames] + name_lens = [len(_tokenizer.encode(name)) for name in classnames] + prompts = [prompt_prefix + " " + name + "." for name in classnames] + + tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) + with torch.no_grad(): + embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) + + print(f"Prompts are: {prompts}") + + self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS + self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :]) # CLS, EOS + + self.n_cls = n_cls + self.n_ctx = n_ctx + self.tokenized_prompts = tokenized_prompts # torch.Tensor + self.name_lens = name_lens + self.class_token_position = 'end' + + def forward(self): + ctx = self.ctx + if ctx.dim() == 2: + ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) + + prefix = self.token_prefix + suffix = self.token_suffix + + if self.class_token_position == "end": + prompts = torch.cat( + [ + prefix, # (n_cls, 1, dim) + ctx, # (n_cls, n_ctx, dim) + suffix, # (n_cls, *, dim) + ], + dim=1, + ) + + elif self.class_token_position == "middle": + half_n_ctx = self.n_ctx // 2 + prompts = [] + for i in range(self.n_cls): + name_len = self.name_lens[i] + prefix_i = prefix[i: i + 1, :, :] + class_i = suffix[i: i + 1, :name_len, :] + suffix_i = suffix[i: i + 1, name_len:, :] + ctx_i_half1 = ctx[i: i + 1, :half_n_ctx, :] + ctx_i_half2 = ctx[i: i + 1, half_n_ctx:, :] + prompt = torch.cat( + [ + prefix_i, + ctx_i_half1, + class_i, + ctx_i_half2, + suffix_i, + ], + dim=1, + ) + prompts.append(prompt) + prompts = torch.cat(prompts, dim=0) + + elif self.class_token_position == "front": + prompts = [] + for i in range(self.n_cls): + name_len = self.name_lens[i] + prefix_i = prefix[i: i + 1, :, :] + class_i = suffix[i: i + 1, :name_len, :] + suffix_i = suffix[i: i + 1, name_len:, :] + ctx_i = ctx[i: i + 1, :, :] + prompt = torch.cat( + [ + prefix_i, + class_i, + ctx_i, + suffix_i, + ], + dim=1, + ) + prompts.append(prompt) + prompts = torch.cat(prompts, dim=0) + + else: + raise ValueError + + return prompts + + +######################################## +# 3. Build the CoOp model (similar to your CoOp function, but returning only the model) +######################################## + +def build_clip_model_with_coop(cfg): + """ + cfg Required fields: + - cfg.backbone: 'RN50' or 'ViT' + - cfg.n_ctx + - cfg.class_names: list[str],length = num_classes + """ + if cfg.backbone == 'RN50': + backbone_name = 'RN50' + elif cfg.backbone == 'ViT': + backbone_name = "ViT-B/16" + else: + raise ValueError(f"Unsupported backbone: {cfg.backbone}") + + clip_model, preprocess = clip.load(backbone_name, device="cpu", jit=False) + + print("Building custom CLIP with CoOp (multi-class)") + class_names = cfg.class_names + assert len(class_names) >= 2, "class_names must contain at least 2 classes" + + model = CustomCLIP_CoOp(cfg, class_names, clip_model) + return model + + +class CustomCLIP_CoOp(nn.Module): + def __init__(self, cfg, classnames, clip_model): + super().__init__() + self.prompt_learner = PromptLearner_CoOp(cfg, classnames, clip_model) + self.tokenized_prompts = self.prompt_learner.tokenized_prompts + self.image_encoder = clip_model.visual + self.text_encoder = TextEncoder(clip_model) + self.logit_scale = clip_model.logit_scale + self.dtype = clip_model.dtype + self.cfg = cfg + + def encode_image(self, image): + # image: [B, 3, 224, 224] + if self.cfg.backbone == "ViT": + image_features = self.image_encoder(image.type(self.dtype))[:, 0, :] + else: + image_features = self.image_encoder(image.type(self.dtype)) + return image_features + + def encode_text(self): + prompts = self.prompt_learner() + tokenized_prompts = self.tokenized_prompts + text_features = self.text_encoder(prompts, tokenized_prompts) + return text_features + + def forward(self, image): + # Return logits, image_features, and text_features, consistent with the original implementation + image_features = self.encode_image(image) + text_features = self.encode_text() + + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + logit_scale = self.logit_scale.exp() + logits = logit_scale * image_features @ text_features.t() + + return logits, image_features, text_features + + +######################################## +# 4. Detector wrapper actually used in DeepFakeBench +######################################## + +# Assume you have a `DETECTOR` registry (following the style of `ResnetDetector`) + + + +@DETECTOR.register_module(module_name='coop_ooc') +class OOCDetector(AbstractDetector): + def __init__(self, config): + super().__init__(config) + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + """ + config['backbone_config'] needs to include: + - backbone: 'RN50' or 'ViT' + - n_ctx: int + - class_names: list[str],class names for multi-class classification + - num_classes: int, consistent with `len(class_names)` (can be validated) + """ + cfg_dict = config['backbone_config'] + + class_names = cfg_dict.get('class_names', None) + if class_names is None: + raise ValueError("backbone_config must provide a `class_names` list in") + + num_classes = cfg_dict.get('num_classes', None) + if num_classes is not None and num_classes != len(class_names): + raise ValueError( + f"num_classes ({num_classes}) and the length of `class_names` ({len(class_names)}) do not match" + ) + + # Construct a simple config object with attributes for CoOp + import types + cfg = types.SimpleNamespace( + backbone=cfg_dict.get('backbone', 'RN50'), + n_ctx=cfg_dict.get('n_ctx', 8), + class_names=class_names + ) + + model = build_clip_model_with_coop(cfg) + return model + + def build_loss(self, config): + loss_class = LOSSFUNC[config['loss_func']] + return loss_class() + + def features(self, data_dict: dict) -> torch.Tensor: + image = data_dict['image'] # [B, 3, H, W] + _, image_features, _ = self.backbone(image) + return image_features + + def classifier(self, features: torch.Tensor) -> torch.Tensor: + # Re-encode the text for multi-class classification + text_features = self.backbone.encode_text() + + image_features = features / features.norm(dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + logit_scale = self.backbone.logit_scale.exp() + logits = logit_scale * image_features @ text_features.t() # [B, num_classes] + return logits + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] # [B],value range [0, num_classes-1] + pred = pred_dict['cls'] # [B, num_classes] + loss = self.loss_func(pred, label) + return {'overall': loss} + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + num_classes = self.config['backbone_config']['num_classes'] + acc, mAP = calculate_acc_for_train( + label.detach(), pred.detach(), num_classes + ) + return {'acc': acc, 'mAP': mAP} + + def forward(self, data_dict: dict, inference=False) -> dict: + features = self.features(data_dict) + pred = self.classifier(features) + prob = torch.softmax(pred, dim=1) + + pred_dict = { + 'cls': pred, + 'prob': prob, + 'feat': features + } + return pred_dict + + diff --git a/training/detectors/pcl_xception_detector.py b/training/detectors/pcl_xception_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..7257fe7835dcb3d9ffb1ee330726b8f2241ac865 --- /dev/null +++ b/training/detectors/pcl_xception_detector.py @@ -0,0 +1,282 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the PCLDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{zhao2021learning, + title={Learning self-consistency for deepfake detection}, + author={Zhao, Tianchen and Xu, Xiang and Xu, Mingze and Ding, Hui and Xiong, Yuanjun and Xia, Wei}, + booktitle={Proceedings of the IEEE/CVF international conference on computer vision}, + pages={15023--15033}, + year={2021} +} +''' + + +import os +import datetime +import logging +import random + +import numpy as np +import yaml +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from dataset.I2G_dataset import I2GDataset +from metrics.base_metrics_class import calculate_metrics_for_train + +from detectors.base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +import math +from torchvision import transforms + +logger = logging.getLogger(__name__) + + +class Masks4D(object): + def __call__(self, masks): + + first_w = True + first_h = True + first_c = True + + for k, mask in enumerate(masks): + mask=mask.squeeze(0) + h, w = mask.shape + real_mask = torch.unsqueeze(torch.unsqueeze(torch.unsqueeze(mask, 0), 0), 0) + # fake_mask = torch.unsqueeze(torch.unsqueeze(1 - mask, 0), 0) + for i, mask_h in enumerate(mask): + for j, mask_w in enumerate(mask_h): + curr_mask = 1 - torch.abs(mask_w - real_mask) + if first_w: + total_mask_w = real_mask + first_w = False + else: + total_mask_w = torch.cat((total_mask_w, curr_mask), dim=2) + if first_h: + total_mask_h = total_mask_w + first_h = False + else: + total_mask_h = torch.cat((total_mask_h, total_mask_w), dim = 1) + first_w = True + if first_c: + total_mask_c = total_mask_h + first_c = False + else: + total_mask_c = torch.cat((total_mask_c, total_mask_h), dim = 0) + first_h = True + return total_mask_c + + +class NLBlockND(nn.Module): + def __init__(self, in_channels=256): + """Implementation of Non-Local Block with 4 different pairwise functions but doesn't include subsampling trick + args: + in_channels: original channel size (1024 in the paper) + inter_channels: channel size inside the block if not specifed reduced to half (512 in the paper) + mode: supports Gaussian, Embedded Gaussian, Dot Product, and Concatenation + dimension: can be 1 (temporal), 2 (spatial), 3 (spatiotemporal) + bn_layer: whether to add batch norm + """ + super(NLBlockND, self).__init__() + + self.in_channels = in_channels + + # assign appropriate convolutional, max pool, and batch norm layers for different dimensions + + # add BatchNorm layer after the last conv layer + self.sig = nn.Sigmoid() + + # define theta and phi for all operations except Gaussian; why are there two? + self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1) + self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1) + + def forward(self, x, return_nl_map=False): + """ + args + x: (N, C, T, H, W) for dimension=3; (N, C, H, W) for dimension 2; (N, C, T) for dimension 1 + """ + + batch_size = x.size(0) + + # (N, C, THW) + # this reshaping and permutation is from the spacetime_nonlocal function in the original Caffe2 implementation + + theta_x = self.theta(x).view(batch_size, self.in_channels, -1) #flatten operation + #phi_x = self.phi(x).view(batch_size, self.in_channels, -1) + phi_x = self.theta(x).view(batch_size, self.in_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + #channel as vector + f = torch.matmul(theta_x, phi_x) + f_div_C = f / math.sqrt(self.in_channels) + + # contiguous here just allocates contiguous chunk of memory + y = f_div_C.permute(0, 2, 1).contiguous() + + sig_y = self.sig(y) + final_y = sig_y.view(batch_size, *x.size()[2:], *x.size()[2:]) + + if return_nl_map: + return final_y, sig_y + else: + return final_y + + +@DETECTOR.register_module(module_name='pcl_xception') +class PCLXceptionDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + self.PCL = NLBlockND(in_channels=728) + self.Msk_PCL = transforms.Compose([Masks4D()]) + self.mask_down_sampling = nn.UpsamplingBilinear2d( + scale_factor=16 / 256) + self.criterionBCE = nn.BCELoss() + + def build_backbone(self, config): + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + if config['pretrained'] != 'None': + # if donot load the pretrained weights, fail to get good results + state_dict = torch.load(config['pretrained']) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + # backbone.classifier=classifier + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model successfully!') + else: + logger.info('No pretrained model.') + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + if config['loss_func']=='center_loss': + loss_func = loss_class(num_classes=2, feat_dim=2048) + else: + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + return self.backbone.features(data_dict['image']) #32,3,256,256 + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + if pred_dict['pcl_map'] is not None: + pcl_loss = self.criterionBCE(pred_dict['pcl_map'],pred_dict['pcl_gt_map']) + else: + pcl_loss = 0 + det_loss = self.loss_func(pred, label) + loss = det_loss+ self.config['pcl_loss_weight'] * pcl_loss + loss_dict = {'overall': loss,'pcl_loss': pcl_loss, 'det_loss':det_loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + #print(data_dict['image'].device) + # get the features by backbone + features,x3 = self.features(data_dict) + if not inference: + pcl_map=self.PCL(x3) + pcl_gt_map=self.Msk_PCL(self.mask_down_sampling(data_dict['mask'])) + else: + pcl_map,pcl_gt_map = None, None + # get the prediction by classifier + pred,x = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features, 'pcl_map':pcl_map, 'pcl_gt_map': pcl_gt_map} + return pred_dict + + + +if __name__ == '__main__': + with open(r'H:\code\DeepfakeBench\training\config\detector\pcl_xception.yaml', 'r') as f: + config = yaml.safe_load(f) + with open('./training/config/train_config.yaml', 'r') as f: + config2 = yaml.safe_load(f) + config.update(config2) + if config['manualSeed'] is None: + config['manualSeed'] = random.randint(1, 10000) + random.seed(config['manualSeed']) + torch.manual_seed(config['manualSeed']) + if config['cuda']: + torch.cuda.manual_seed_all(config['manualSeed']) + detector=PCLXceptionDetector(config=config).cuda() + config['data_manner'] = 'lmdb' + config['dataset_json_folder'] = 'preprocessing/dataset_json_v3' + config['sample_size']=256 + config['with_mask']=True + config['with_landmark']=True + config['use_data_augmentation']=True + train_set = I2GDataset(config=config, mode='train') + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=8, + shuffle=True, + num_workers=0, + collate_fn=train_set.collate_fn, + ) + optimizer = optim.Adam( + params=detector.parameters(), + lr=config['optimizer']['adam']['lr'], + weight_decay=config['optimizer']['adam']['weight_decay'], + betas=(config['optimizer']['adam']['beta1'], config['optimizer']['adam']['beta2']), + eps=config['optimizer']['adam']['eps'], + amsgrad=config['optimizer']['adam']['amsgrad'], + ) + from tqdm import tqdm + for iteration, batch in enumerate(tqdm(train_data_loader)): + print(iteration) + batch['image'],batch['label'],batch['mask']=batch['image'].cuda(),batch['label'].cuda(),batch['mask'].cuda() + predictions=detector(batch) + losses = detector.get_losses(batch, predictions) + optimizer.zero_grad() + losses['overall'].backward() + optimizer.step() + + if iteration > 10: + break \ No newline at end of file diff --git a/training/detectors/pose_detector.py b/training/detectors/pose_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..6057acd2cbc939ed2abbb98e416ccf6b37651d89 --- /dev/null +++ b/training/detectors/pose_detector.py @@ -0,0 +1,144 @@ + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +class vgg_layer(nn.Module): + def __init__(self, nin, nout): + super(vgg_layer, self).__init__() + self.main = nn.Sequential( + nn.Conv2d(nin, nout, 3, 1, 1), + nn.BatchNorm2d(nout), + nn.LeakyReLU(0.2) + ) + + def forward(self, input): + return self.main(input) + +class dcgan_conv(nn.Module): + def __init__(self, nin, nout): + super(dcgan_conv, self).__init__() + self.main = nn.Sequential( + nn.Conv2d(nin, nout, 4, 2, 1), + nn.BatchNorm2d(nout), + nn.LeakyReLU(0.2), + ) + + def forward(self, input): + return self.main(input) + +class Simple_CNN(nn.Module): + def __init__(self): + super(Simple_CNN, self).__init__() + nf = 64 + nc = 3 + self.main = nn.Sequential( + dcgan_conv(nc, nf), + vgg_layer(nf, nf), + + dcgan_conv(nf, nf * 2), + vgg_layer(nf * 2, nf * 2), + + dcgan_conv(nf * 2, nf * 4), + vgg_layer(nf * 4, nf * 4), + + dcgan_conv(nf * 4, nf * 8), + vgg_layer(nf * 8, nf * 8), + ) + + self.fc = nn.Linear(nf * 8 * 14 * 14, nf * 8, bias=True) + + # self.classification_head = nn.Sequential( + # nn.Dropout(p=0.2, inplace=True), + # nn.Linear(nf * 8, class_num, bias=True) + # ) + + def forward(self, input): + embedding = self.main(input) # [32, 512, 14, 14] + feature = embedding.view(embedding.shape[0], -1) + feature = self.fc(feature) + + return feature + + +@DETECTOR.register_module(module_name='pose') +class POSE_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(512, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone = Simple_CNN() + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image']) + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict diff --git a/training/detectors/recce_detector.py b/training/detectors/recce_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..514218e49f492fc2cae916f82f295a62926ba4da --- /dev/null +++ b/training/detectors/recce_detector.py @@ -0,0 +1,359 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the RECCEDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{cao2022end, + title={End-to-end reconstruction-classification learning for face forgery detection}, + author={Cao, Junyi and Ma, Chao and Yao, Taiping and Chen, Shen and Ding, Shouhong and Yang, Xiaokang}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={4113--4122}, + year={2022} +} +''' + +import os +import datetime +from typing import Union +from sklearn import metrics +from collections import defaultdict +from functools import partial +from timm.models import xception +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter +import numpy as np +import argparse +from metrics.base_metrics_class import calculate_metrics_for_train + +from networks.xception import SeparableConv2d, Block +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +import logging + +logger = logging.getLogger(__name__) + +encoder_params = { + "xception": { + "features": 2048, + "init_op": partial(xception, pretrained=True) + } +} + +@DETECTOR.register_module(module_name='recce') +class RecceDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) # FIXME: do not use the self.backbone in recce + self.loss_func = self.build_loss(config) + self.model = Recce(num_classes=2) + + # FIXME: the above function should be comment or something else + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + return self.model.features(data_dict['image']) + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.model.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +class Recce(nn.Module): + """ End-to-End Reconstruction-Classification Learning for Face Forgery Detection """ + + def __init__(self, num_classes, drop_rate=0.2): + super(Recce, self).__init__() + self.name = "xception" + self.loss_inputs = dict() + self.encoder = encoder_params[self.name]["init_op"]() + self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) + self.dropout = nn.Dropout(drop_rate) + self.fc = nn.Linear(encoder_params[self.name]["features"], num_classes) + + self.attention = GuidedAttention(depth=728, drop_rate=drop_rate) + self.reasoning = GraphReasoning(728, 256, 256, 256, 128, 256, [2, 4], drop_rate) + + self.decoder1 = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor=2), + SeparableConv2d(728, 256, 3, 1, 1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True) + ) + self.decoder2 = Block(256, 256, 3, 1) + self.decoder3 = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor=2), + SeparableConv2d(256, 128, 3, 1, 1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True) + ) + self.decoder4 = Block(128, 128, 3, 1) + self.decoder5 = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor=2), + SeparableConv2d(128, 64, 3, 1, 1, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True) + ) + self.decoder6 = nn.Sequential( + nn.Conv2d(64, 3, 1, 1, bias=False), + nn.Tanh() + ) + + def norm_n_corr(self, x): + norm_embed = F.normalize(self.global_pool(x), p=2, dim=1) + corr = (torch.matmul(norm_embed.squeeze(), norm_embed.squeeze().T) + 1.) / 2. + return norm_embed, corr + + @staticmethod + def add_white_noise(tensor, mean=0., std=1e-6): + rand = torch.rand([tensor.shape[0], 1, 1, 1]) + rand = torch.where(rand > 0.5, 1., 0.).to(tensor.device) + white_noise = torch.normal(mean, std, size=tensor.shape, device=tensor.device) + noise_t = tensor + white_noise * rand + noise_t = torch.clip(noise_t, -1., 1.) + return noise_t + + def features(self, x): + # clear the loss inputs + self.loss_inputs = dict(recons=[], contra=[]) + noise_x = self.add_white_noise(x) if self.training else x + out = self.encoder.conv1(noise_x) + out = self.encoder.bn1(out) + out = self.encoder.act1(out) + out = self.encoder.conv2(out) + out = self.encoder.bn2(out) + out = self.encoder.act2(out) + out = self.encoder.block1(out) + out = self.encoder.block2(out) + out = self.encoder.block3(out) + embedding = self.encoder.block4(out) + + norm_embed, corr = self.norm_n_corr(embedding) + self.loss_inputs['contra'].append(corr) + + out = self.dropout(embedding) + out = self.decoder1(out) + out_d2 = self.decoder2(out) + + norm_embed, corr = self.norm_n_corr(out_d2) + self.loss_inputs['contra'].append(corr) + + out = self.decoder3(out_d2) + out_d4 = self.decoder4(out) + + norm_embed, corr = self.norm_n_corr(out_d4) + self.loss_inputs['contra'].append(corr) + + out = self.decoder5(out_d4) + pred = self.decoder6(out) + + recons_x = F.interpolate(pred, size=x.shape[-2:], mode='bilinear', align_corners=True) + self.loss_inputs['recons'].append(recons_x) + + embedding = self.encoder.block5(embedding) + embedding = self.encoder.block6(embedding) + embedding = self.encoder.block7(embedding) + + fusion = self.reasoning(embedding, out_d2, out_d4) + embedding + + embedding = self.encoder.block8(fusion) + img_att = self.attention(x, recons_x, embedding) + + embedding = self.encoder.block9(img_att) + embedding = self.encoder.block10(embedding) + embedding = self.encoder.block11(embedding) + embedding = self.encoder.block12(embedding) + + embedding = self.encoder.conv3(embedding) + embedding = self.encoder.bn3(embedding) + embedding = self.encoder.act3(embedding) + embedding = self.encoder.conv4(embedding) + embedding = self.encoder.bn4(embedding) + embedding = self.encoder.act4(embedding) + + embedding = self.global_pool(embedding).squeeze(2).squeeze(2) + embedding = self.dropout(embedding) + + return embedding + + def classifier(self, embedding): + return self.fc(embedding) + + def forward(self, x): + embedding = self.features(x) + return self.classifier(embedding) + +class GraphReasoning(nn.Module): + """ Graph Reasoning Module for information aggregation. """ + + def __init__(self, va_in, va_out, vb_in, vb_out, vc_in, vc_out, spatial_ratio, drop_rate): + super(GraphReasoning, self).__init__() + self.ratio = spatial_ratio + self.va_embedding = nn.Sequential( + nn.Conv2d(va_in, va_out, 1, bias=False), + nn.ReLU(True), + nn.Conv2d(va_out, va_out, 1, bias=False), + ) + self.va_gated_b = nn.Sequential( + nn.Conv2d(va_in, va_out, 1, bias=False), + nn.Sigmoid() + ) + self.va_gated_c = nn.Sequential( + nn.Conv2d(va_in, va_out, 1, bias=False), + nn.Sigmoid() + ) + self.vb_embedding = nn.Sequential( + nn.Linear(vb_in, vb_out, bias=False), + nn.ReLU(True), + nn.Linear(vb_out, vb_out, bias=False), + ) + self.vc_embedding = nn.Sequential( + nn.Linear(vc_in, vc_out, bias=False), + nn.ReLU(True), + nn.Linear(vc_out, vc_out, bias=False), + ) + self.unfold_b = nn.Unfold(kernel_size=spatial_ratio[0], stride=spatial_ratio[0]) + self.unfold_c = nn.Unfold(kernel_size=spatial_ratio[1], stride=spatial_ratio[1]) + self.reweight_ab = nn.Sequential( + nn.Linear(va_out + vb_out, 1, bias=False), + nn.ReLU(True), + nn.Softmax(dim=1) + ) + self.reweight_ac = nn.Sequential( + nn.Linear(va_out + vc_out, 1, bias=False), + nn.ReLU(True), + nn.Softmax(dim=1) + ) + self.reproject = nn.Sequential( + nn.Conv2d(va_out + vb_out + vc_out, va_in, kernel_size=1, bias=False), + nn.ReLU(True), + nn.Conv2d(va_in, va_in, kernel_size=1, bias=False), + nn.Dropout(drop_rate) if drop_rate is not None else nn.Identity(), + ) + + def forward(self, vert_a, vert_b, vert_c): + emb_vert_a = self.va_embedding(vert_a) + emb_vert_a = emb_vert_a.reshape([emb_vert_a.shape[0], emb_vert_a.shape[1], -1]) + + gate_vert_b = 1 - self.va_gated_b(vert_a) + gate_vert_b = gate_vert_b.reshape(*emb_vert_a.shape) + gate_vert_c = 1 - self.va_gated_c(vert_a) + gate_vert_c = gate_vert_c.reshape(*emb_vert_a.shape) + + vert_b = self.unfold_b(vert_b).reshape( + [vert_b.shape[0], vert_b.shape[1], self.ratio[0] * self.ratio[0], -1]) + vert_b = vert_b.permute([0, 2, 3, 1]) + emb_vert_b = self.vb_embedding(vert_b) + + vert_c = self.unfold_c(vert_c).reshape( + [vert_c.shape[0], vert_c.shape[1], self.ratio[1] * self.ratio[1], -1]) + vert_c = vert_c.permute([0, 2, 3, 1]) + emb_vert_c = self.vc_embedding(vert_c) + + agg_vb = list() + agg_vc = list() + for j in range(emb_vert_a.shape[-1]): + # ab propagating + emb_v_a = torch.stack([emb_vert_a[:, :, j]] * (self.ratio[0] ** 2), dim=1) + emb_v_b = emb_vert_b[:, :, j, :] + emb_v_ab = torch.cat([emb_v_a, emb_v_b], dim=-1) + w = self.reweight_ab(emb_v_ab) + agg_vb.append(torch.bmm(emb_v_b.transpose(1, 2), w).squeeze() * gate_vert_b[:, :, j]) + + # ac propagating + emb_v_a = torch.stack([emb_vert_a[:, :, j]] * (self.ratio[1] ** 2), dim=1) + emb_v_c = emb_vert_c[:, :, j, :] + emb_v_ac = torch.cat([emb_v_a, emb_v_c], dim=-1) + w = self.reweight_ac(emb_v_ac) + agg_vc.append(torch.bmm(emb_v_c.transpose(1, 2), w).squeeze() * gate_vert_c[:, :, j]) + + agg_vert_b = torch.stack(agg_vb, dim=-1) + agg_vert_c = torch.stack(agg_vc, dim=-1) + agg_vert_bc = torch.cat([agg_vert_b, agg_vert_c], dim=1) + agg_vert_abc = torch.cat([agg_vert_bc, emb_vert_a], dim=1) + agg_vert_abc = torch.sigmoid(agg_vert_abc) + agg_vert_abc = agg_vert_abc.reshape(vert_a.shape[0], -1, vert_a.shape[2], vert_a.shape[3]) + return self.reproject(agg_vert_abc) + + +class GuidedAttention(nn.Module): + """ Reconstruction Guided Attention. """ + + def __init__(self, depth=728, drop_rate=0.2): + super(GuidedAttention, self).__init__() + self.depth = depth + self.gated = nn.Sequential( + nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False), + nn.ReLU(True), + nn.Conv2d(3, 1, 1, bias=False), + nn.Sigmoid() + ) + self.h = nn.Sequential( + nn.Conv2d(depth, depth, 1, 1, bias=False), + nn.BatchNorm2d(depth), + nn.ReLU(True), + ) + self.dropout = nn.Dropout(drop_rate) + + def forward(self, x, pred_x, embedding): + residual_full = torch.abs(x - pred_x) + residual_x = F.interpolate(residual_full, size=embedding.shape[-2:], + mode='bilinear', align_corners=True) + res_map = self.gated(residual_x) + return res_map * self.h(embedding) + self.dropout(embedding) diff --git a/training/detectors/repmix_detector.py b/training/detectors/repmix_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..900967e7ca5b25506fde5e1c057f2b1e6727e498 --- /dev/null +++ b/training/detectors/repmix_detector.py @@ -0,0 +1,144 @@ + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig +import loralib as lora +import copy +import torchvision + + +logger = logging.getLogger(__name__) + + +class MixupLayer(nn.Module): + def __init__(self, nmix=2, beta=0.4): + super().__init__() + self.nmix = nmix + self.beta = beta + self.dist = torch.distributions.Beta(beta, beta) + + def forward(self, x, ratio=None): + # Simplified Mixup logic (in practice this should be extended based on `nmix`; this is only an example) + if ratio is None: + ratio = self.dist.sample((x.size(0), 1, 1, 1)).to(x.device) + # Randomly shuffle the batch to mix samples + idx = torch.randperm(x.size(0)).to(x.device) + x_mix = ratio * x + (1 - ratio) * x[idx] + return x_mix + +class ResnetMixup(nn.Module): + def __init__(self, d_embed, mix_level=0, nmix=2, beta=0.4): + # Change 1: remove the `inference` argument (no longer passed manually) + super().__init__() + + # 1. Load pretrained ResNet50 and replace the fc layer + model = torchvision.models.resnet50(pretrained=True, progress=False) + model.fc = nn.Linear(model.fc.in_features, d_embed) # original 2048 dimensions -> d_embed dimensions + + # 2. Split and reorganize ResNet50 modules to make Mixup insertion easier + model = list(model.children()) + model = [nn.Sequential(*model[:4])] + model[4:-2] + [nn.Sequential(model[-2], nn.Flatten(1), model[-1])] + + # 3. Validate that the Mixup insertion position is legal + assert mix_level <= len(model), f"mix_level({mix_level})exceeds the module list length({len(model)})" + + # 4. Create the Mixup layer and insert it at the specified position + mx_layer = MixupLayer(nmix, beta) + model.insert(mix_level, mx_layer) + + # 5. Save parameters and model structure + self.mix_level = mix_level + self.model = nn.ModuleList(model) + + # Print model information for debugging + print(f'ResnetMixup initialization completed, default training mode(self.training={self.training}):\n', self.model) + + def forward(self, x, ratio=None): + for i, layer in enumerate(self.model): + if i == self.mix_level: + # Change 2: use `self.training` to determine the mode (automatically synchronized with train/eval state) + if self.training: # Training mode (`model.train()`): enable Mixup + x = layer(x, ratio) + # Inference mode (`model.eval()`): automatically skip Mixup and execute the next layer directly + else: + x = layer(x) + return x + + +@DETECTOR.register_module(module_name='repmix') +class RepMix_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone = ResnetMixup(1024) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image']) + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict diff --git a/training/detectors/resnet34_detector.py b/training/detectors/resnet34_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..ab54f555ef5ad915c1bc006f8148eedb5a81b969 --- /dev/null +++ b/training/detectors/resnet34_detector.py @@ -0,0 +1,121 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the ResnetDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{wang2020cnn, + title={CNN-generated images are surprisingly easy to spot... for now}, + author={Wang, Sheng-Yu and Wang, Oliver and Zhang, Richard and Owens, Andrew and Efros, Alexei A}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, + pages={8695--8704}, + year={2020} +} + +Notes: +We chose to use ResNet-34 as the backbone instead of ResNet-50 because the number of parameters in ResNet-34 is relatively similar to that of Xception. This similarity allows us to make a more meaningful and fair comparison between different architectures. +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='resnet34') +class ResnetDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + #FIXME: current load pretrained weights only from the backbone, not here + # # if donot load the pretrained weights, fail to get good results + # state_dict = torch.load(config['pretrained']) + # state_dict = {'resnet.'+k:v for k, v in state_dict.items() if 'fc' not in k} + # backbone.load_state_dict(state_dict, False) + # logger.info('Load pretrained model successfully!') + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + return self.backbone.features(data_dict['image']) + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': torch.mean(features, dim=[2, 3])} + return pred_dict + diff --git a/training/detectors/resnet34_distill_detector.py b/training/detectors/resnet34_distill_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..fe2590eb4f4e8126d7e44fc04693b0747f3968d2 --- /dev/null +++ b/training/detectors/resnet34_distill_detector.py @@ -0,0 +1,93 @@ + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='resnet34_distill') +class DetectorDistill(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + + #FIXME: current load pretrained weights only from the backbone, not here + # # if donot load the pretrained weights, fail to get good results + state_dict = torch.load(config['pretrained']) + state_dict = {'resnet.'+k:v for k, v in state_dict.items() if 'fc' not in k} + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model successfully!') + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + return self.backbone.features(data_dict['image']) + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': torch.mean(features, dim=[2, 3])} + return pred_dict + diff --git a/training/detectors/rfm_detector.py b/training/detectors/rfm_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..3a7411ba330ebd3c9d6f0d3cc925e3928bc0b488 --- /dev/null +++ b/training/detectors/rfm_detector.py @@ -0,0 +1,174 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the RFMDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{wang2021representative, + title={Representative forgery mining for fake face detection}, + author={Wang, Chengrui and Deng, Weihong}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, + pages={14923--14932}, + year={2021} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='rfm') +class RFMDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + # if donot load the pretrained weights, fail to get good results + state_dict = torch.load(config['pretrained']) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model successfully!') + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + return self.backbone.features(data_dict['image']) #32,3,256,256 + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def cal_fam(self, inputs): + self.backbone.zero_grad() + inputs = inputs.detach().clone() + inputs.requires_grad_() + _, output = self.backbone(inputs) + target = output[:, 1]-output[:, 0] + target.backward(torch.ones(target.shape).cuda()) + fam = torch.abs(inputs.grad) + fam = torch.max(fam, dim=1, keepdim=True)[0] + return fam + + def apply_rfm_augmentation(self, data): + device = data.device + self.backbone.eval() + + # Call `self.cal_fam` directly instead of `cal_fam` + mask = self.cal_fam(data) + imgmask = torch.ones_like(mask) + imgh, imgw = 256, 256 + + # Apply the mask based on FAM + for i in range(len(mask)): + maxind = np.argsort(mask[i].cpu().numpy().flatten())[::-1] + pointcnt = 0 + for pointind in maxind: + pointx = pointind // imgw + pointy = pointind % imgw + + if imgmask[i][0][pointx][pointy] == 1: + eH, eW = 120, 120 + maskh = random.randint(1, eH) + maskw = random.randint(1, eW) + + sh = random.randint(1, maskh) + sw = random.randint(1, maskw) + + top = max(pointx - sh, 0) + bot = min(pointx + (maskh - sh), imgh) + lef = max(pointy - sw, 0) + rig = min(pointy + (maskw - sw), imgw) + + imgmask[i][:, top:bot, lef:rig] = torch.zeros_like(imgmask[i][:, top:bot, lef:rig]) + + pointcnt += 1 + if pointcnt >= 3: + break + + # Apply the masked data + data = imgmask * data + (1 - imgmask) * (torch.rand_like(data) * 2 - 1) + + self.backbone.train() + + return data + + + def forward(self, data_dict: dict, inference=False) -> dict: + if not inference: + # Apply RFM augmentation during non-inference stages + data_dict['image'] = self.apply_rfm_augmentation(data_dict['image']) + + + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + + return pred_dict diff --git a/training/detectors/sbi_detector.py b/training/detectors/sbi_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..0c9ef6b29a15679afbc24d009846e0d96606e1de --- /dev/null +++ b/training/detectors/sbi_detector.py @@ -0,0 +1,118 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the SBIDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{shiohara2022detecting, + title={Detecting deepfakes with self-blended images}, + author={Shiohara, Kaede and Yamasaki, Toshihiko}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={18720--18729}, + year={2022} +} +''' + +import os +import logging +import datetime +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='sbi') +class SBIDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + self.prob, self.label = [], [] + self.video_names = [] + self.correct, self.total = 0, 0 + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] # efficientnetb4 + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + # if donot load the pretrained weights, fail to get good results + state_dict = torch.load(config['pretrained']) # xception-b5690688.pth + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model successfully!') + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + return self.backbone.features(data_dict['image']) + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + # we dont compute the video-level metrics for training + self.video_names = [] + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict diff --git a/training/detectors/sia_detector.py b/training/detectors/sia_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..75fe2ac051ad7242d9e5eb830b2f6dab5f6322c8 --- /dev/null +++ b/training/detectors/sia_detector.py @@ -0,0 +1,290 @@ +""" +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the SIADetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{sun2022information, + title={An information theoretic approach for attention-driven face forgery detection}, + author={Sun, Ke and Liu, Hong and Yao, Taiping and Sun, Xiaoshuai and Chen, Shen and Ding, Shouhong and Ji, Rongrong}, + booktitle={European Conference on Computer Vision}, + pages={111--127}, + year={2022}, + organization={Springer} +} +""" + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +from detectors import DETECTOR +from loss import LOSSFUNC +from metrics.base_metrics_class import calculate_metrics_for_train +from networks import BACKBONE + +from .base_detector import AbstractDetector + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='sia') +class SIADetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + self.att0conv = SAIA_conv(24, kernel_size=3, isspace=True, ischannel=True) + self.att1conv = SAIA_conv(32, kernel_size=3, isspace=True, ischannel=True) + self.att2conv = SAIA_conv(56, kernel_size=3, isspace=True, ischannel=True) + self.att3conv = SAIA_conv(112, kernel_size=3, isspace=True, ischannel=True) + self.att4conv = SAIA_conv(160, kernel_size=3, isspace=True, ischannel=True) + self.att5conv = SAIA_conv(272, kernel_size=3, isspace=False, ischannel=True) + self.att6conv = SAIA_conv(448, kernel_size=3, isspace=False, ischannel=True) + + self.avgpool1 = nn.AdaptiveMaxPool2d((32, 32)) + # self.avgpool1 = nn.AdaptiveAvgPool2d((20,20))#[160] + + self.avgpool2 = nn.AdaptiveMaxPool2d((16, 16)) + + self.conv1 = nn.Sequential( + nn.Conv2d(32, 56, 1, 1, 0), + nn.BatchNorm2d(56), + nn.ReLU(inplace=True), + + ) + + self.conv2 = nn.Sequential( + nn.Conv2d(32, 160, 1, 1, 0), + nn.BatchNorm2d(160), + nn.ReLU(inplace=True), + + ) + + self.conv3 = nn.Sequential( + nn.Conv2d(56, 160, 1, 1, 0), + nn.BatchNorm2d(160), + nn.ReLU(inplace=True), + + ) + + num_ftrs = 1792 + num_classes = 1 + + self.linear = nn.Linear(num_ftrs, num_classes) + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + model_config['pretrained'] = self.config.get('pretrained', None) + backbone = backbone_class(model_config) + + # FIXME: current load pretrained weights only from the backbone, not here + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + # Extract features from the EfficientNet-B4 model + x = data_dict['image'] + x = self.extract_features(x) + # if self.mode == 'adjust_channel': + # x = self.adjust_channel(x) + return x + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + def extract_features(self, inputs): + """use convolution layer to extract feature . + + Args: + inputs (tensor): Input tensor. + + Returns: + Output of the final convolution + layer in the efficientnet model. + """ + # Stem + x = self.backbone.efficientnet._conv_stem(inputs) + x = self.backbone.efficientnet._bn0(x) + x = self.backbone.efficientnet._swish(x) + # x = self._swish(self._bn0(self._conv_stem(inputs))) + + x = self.backbone.efficientnet._blocks[0](x) + x = self.backbone.efficientnet._blocks[1](x) + # print("Output shape after block 1:", x.shape) + + x = self.backbone.efficientnet._blocks[2](x) + x = self.backbone.efficientnet._blocks[3](x) + x = self.backbone.efficientnet._blocks[4](x) + x = self.backbone.efficientnet._blocks[5](x) + # print("Output shape after block 5:", x.shape) + + x, att1 = self.att1conv(x) + res12 = self.avgpool1(self.conv1(att1)) + res14 = self.avgpool2(self.conv2(att1)) + + x = self.backbone.efficientnet._blocks[6](x) + x = self.backbone.efficientnet._blocks[7](x) + x = self.backbone.efficientnet._blocks[8](x) + x = self.backbone.efficientnet._blocks[9](x) + # print("Output shape after block 9:", x.shape) + + x, att2 = self.att2conv(x + res12) + res24 = self.avgpool2(self.conv3(att2)) + + x = self.backbone.efficientnet._blocks[10](x) + x = self.backbone.efficientnet._blocks[11](x) + x = self.backbone.efficientnet._blocks[12](x) + x = self.backbone.efficientnet._blocks[13](x) + x = self.backbone.efficientnet._blocks[14](x) + x = self.backbone.efficientnet._blocks[15](x) + # print("Output shape after block 15:", x.shape) + + x = self.backbone.efficientnet._blocks[16](x) + x = self.backbone.efficientnet._blocks[17](x) + x = self.backbone.efficientnet._blocks[18](x) + x = self.backbone.efficientnet._blocks[19](x) + x = self.backbone.efficientnet._blocks[20](x) + x = self.backbone.efficientnet._blocks[21](x) + # print("Output shape after block 21:", x.shape) + + x, att4 = self.att4conv(x + res24 + res14) + + x = self.backbone.efficientnet._blocks[22](x) + x = self.backbone.efficientnet._blocks[23](x) + x = self.backbone.efficientnet._blocks[24](x) + x = self.backbone.efficientnet._blocks[25](x) + x = self.backbone.efficientnet._blocks[26](x) + x = self.backbone.efficientnet._blocks[27](x) + x = self.backbone.efficientnet._blocks[28](x) + x = self.backbone.efficientnet._blocks[29](x) + # print("Output shape after block 29:", x.shape) + + x = self.backbone.efficientnet._blocks[30](x) + x = self.backbone.efficientnet._blocks[31](x) + # print("Output shape after block 31:", x.shape) + + # for idx, block in enumerate(self.backbone.efficientnet._blocks): + # drop_connect_rate = self.backbone.efficientnet._global_params.drop_connect_rate + # if drop_connect_rate: + # drop_connect_rate *= float(idx) / len(self.backbone.efficientnet._blocks) # scale drop connect_rate + # x = block(x, drop_connect_rate=drop_connect_rate) + # print(idx) + + # Head + x = self.backbone.efficientnet._swish(self.backbone.efficientnet._bn1(self.backbone.efficientnet._conv_head(x))) + + return x + + +class SAIA_conv(nn.Module): + def __init__(self, outdim, kernel_size=3, padding=1, isspace=True, ischannel=True): + super(SAIA_conv, self).__init__() + + self.drop_rate = 0.3 + self.temperature = 0.03 + self.band_width = 1.0 + + self.isspace = isspace + self.ischannel = ischannel + self.outdim = outdim + + kernel = torch.ones((outdim, 1, kernel_size, kernel_size)) + self.weight = nn.Parameter(data=kernel, requires_grad=False) + kernel2 = torch.ones((outdim, 1, 1, 1)) * (kernel_size * kernel_size) + self.weight2 = nn.Parameter(data=kernel2, requires_grad=False) + self.pad = padding + self.channel_range = 5 + + def forward(self, x): + with torch.no_grad(): + batch_size = x.shape[0] + num_channel = x.shape[1] + # intra-feature + + x1 = F.conv2d(x, self.weight, padding=self.pad, groups=self.outdim) + x2 = F.conv2d(x, self.weight2, padding=0, groups=self.outdim) + intra_distance = torch.abs(x2 - x1) + + # inter-feature + pad_x = torch.cat([x, x[:, :self.channel_range + 1, :, :]], dim=1) + distances = [] + for i in range(1, self.channel_range + 1): + tmp = (x[:, :, :, :] - pad_x[:, i:num_channel + i, :, :]) + distances.append(tmp.clone()) + + distance = torch.cat(distances, dim=1) + batch_size, _, h_dis, w_dis = distance.shape + distance = distance.view(batch_size, -1, self.channel_range, h_dis, w_dis).sum(dim=2) + inter_distance = torch.abs(distance.view(batch_size, -1, h_dis, w_dis)) + att = intra_distance + 0.5 * inter_distance + + if self.ischannel: + distance_channel = att[:] + distance_channel = torch.exp( + -distance_channel / distance_channel.mean() / 2 / self.band_width ** 2) # using mean of distance to normalize + distance_channel = -torch.log(distance_channel + 0.1) + channel_attention = torch.mean(distance_channel.view(batch_size, self.outdim, -1), dim=2) + channel_attention = channel_attention.view(batch_size, -1, 1, 1) + 1 + + if self.isspace: + distance_space = att + distance_space = distance_space / distance_space.mean() / 2 / self.band_width ** 2 + space_attention = distance_space + batch_size, channels, h, w = x.shape + attention_image = (nn.Sigmoid()(space_attention) + 1) * x + + if self.isspace and self.ischannel: + return attention_image * (channel_attention.expand_as(x)), space_attention + elif self.isspace: + return attention_image, x + elif self.ischannel: + return x * (channel_attention.expand_as(x)), x diff --git a/training/detectors/sladd_detector.py b/training/detectors/sladd_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..87c9919c20c96b29f009ff647264d780eb33cb1a --- /dev/null +++ b/training/detectors/sladd_detector.py @@ -0,0 +1,285 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the SLADDDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{chen2022self, + title={Self-supervised learning of adversarial example: Towards good generalizations for deepfake detection}, + author={Chen, Liang and Zhang, Yong and Song, Yibing and Liu, Lingqiao and Wang, Jue}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, + pages={18710--18719}, + year={2022} +} +''' + +import os +import datetime +import logging +import random + +import numpy as np +import yaml +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from dataset.pair_dataset import pairDataset +from metrics.base_metrics_class import calculate_metrics_for_train + + +from detectors.base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +from .utils.sladd_api import synthesizer + +logger = logging.getLogger(__name__) +device = "cuda" if torch.cuda.is_available() else "cpu" + +@DETECTOR.register_module(module_name='sladd') +class SLADDXceptionDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + self.synthesizer = synthesizer(config=config) + params_synthesizer = ([p for p in self.synthesizer.parameters()]) + # train + self.optimizer_synthesizer = optim.Adam(params_synthesizer, lr=config['optimizer']['adam']['lr']/4, betas=(config['optimizer']['adam']['beta1']/4, 0.999), + weight_decay=config['optimizer']['adam']['weight_decay']) + + # synthesizer should be optimized solely ---> according to the official code. + def parameters(self, recurse=True): + for name, param in self.named_parameters(recurse=recurse): + if 'synthesizer' not in name: + yield param + + def get_test_metrics(self): + pass + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + if config['pretrained'] != 'None': + # if donot load the pretrained weights, fail to get good results + state_dict = torch.load(config['pretrained']) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model successfully!') + else: + logger.info('No pretrained model.') + return backbone + + def build_loss(self, config): + # prepare the loss function + self.l1loss = nn.MSELoss() + self.cls_criterion = LOSSFUNC[config['typeloss_func']]() + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, imgs) -> torch.tensor: + return self.backbone.features(imgs) #32,3,256,256 + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = pred_dict['label'] + pred = pred_dict['cls'] + if 'map' in pred_dict: + map, type, mag, type_label, mag_mask, mag_label, alt_mask\ + = pred_dict['map'],pred_dict['type'],pred_dict['mag'],pred_dict['type_label'],\ + pred_dict['mag_mask'],pred_dict['mag_label'],pred_dict['alt_mask'] + loss_type = self.cls_criterion(type, type_label).mean() + loss_mag = self.l1loss(mag*mag_mask, mag_label*mag_mask).mean() + loss_maps = self.l1loss(map, alt_mask) + + else: + loss_type,loss_mag,loss_maps=0,0,0 + + loss = self.loss_func(pred, label) + overall = loss+0.1*loss_maps + 0.05*loss_type + 0.1*loss_mag + if 'map' in pred_dict: + synthesizer_loss,entropy_penalty=self.get_syn_loss(overall,pred_dict) + else: + synthesizer_loss, entropy_penalty = 0,0 + loss_dict = { + 'overall': overall,'synthesizer_loss':synthesizer_loss,'loss_type':loss_type, + 'loss_mag':loss_mag,'loss_maps':loss_maps,'entropy_penalty':entropy_penalty, + } + return loss_dict + + def get_syn_loss(self, loss,pred_dict): + entropy = pred_dict['entropy'] + log_prob = pred_dict['log_prob'] + normlized_lm=loss.detach() + if log_prob is not None: + self.optimizer_synthesizer.zero_grad() + score_loss = torch.mean(-log_prob * normlized_lm) + entropy_penalty = torch.mean(entropy) + synthesizer_loss = score_loss - (1e-5) * entropy_penalty + if synthesizer_loss.requires_grad: + synthesizer_loss.backward() + self.optimizer_synthesizer.step() + else: + synthesizer_loss=0 + entropy_penalty=0 + return synthesizer_loss,entropy_penalty + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def syn_preprocess(self,batch): + imgs,lmks,msks,lbs=batch['image'].to(device),batch['landmark'].to(device),batch['mask'].to(device),batch['label'].to(device) + half = len(imgs) // 2 + + # imgs, lmks, msks, lbs = imgs[new_idx], lmks[new_idx], msks[new_idx], lbs[new_idx] + + img, fake_img, real_lmk, fake_lmk, real_mask, fake_mask,real_lb,fake_lb = \ + imgs[:half],imgs[half:],lmks[:half],lmks[half:],msks[:half],msks[half:],lbs[:half],lbs[half:] + + # conduct intentional real-fake switching to fit the stupid setting of original code. + # TODO: too little number of 0. considering replacing it by taking real at simple aug. But many issues may raise + switch_mask = torch.randint(0, 2, (img.shape[0],)).bool() + img[switch_mask], fake_img[switch_mask], real_lmk[switch_mask], fake_lmk[switch_mask], real_mask[switch_mask], fake_mask[switch_mask], real_lb[switch_mask], fake_lb[switch_mask] = \ + fake_img[switch_mask],img[switch_mask],fake_lmk[switch_mask],real_lmk[switch_mask],fake_mask[switch_mask],real_mask[switch_mask],fake_lb[switch_mask],real_lb[switch_mask] + + log_prob, entropy, new_img, alt_mask, label, type_label, mag_label, mag_mask = \ + self.synthesizer(img, fake_img, real_lmk, fake_lmk, real_mask, fake_mask,label=lbs[:half]) + + new_img = new_img.to(device) + label = label.to(device) + type_label = type_label.to(device) + mag_label = mag_label.to(device) + mag_mask = mag_mask.to(device) + alt_mask = alt_mask.to(device) + + ################ simple augmentation seems to be useless + img_flip = torch.flip(new_img, (3,)).detach().clone() + mask_flip = torch.flip(alt_mask, (3,)).detach().clone() + new_img = torch.cat((new_img, img_flip)) + alt_mask = torch.cat((alt_mask, mask_flip)) + label = torch.cat((label, label)) + type_label = torch.cat((type_label, type_label)) + mag_label = torch.cat((mag_label, mag_label)) + mag_mask = torch.cat((mag_mask, mag_mask)) + return new_img,alt_mask,label,type_label,mag_label,mag_mask,log_prob, entropy + + + def forward(self, data_dict: dict, inference=False) -> dict: + if inference: + new_img=data_dict['image'] + label=data_dict['label'] + features,map_fea = self.features(new_img) + # get the prediction by classifier + out,x = self.classifier(features) + pred = out + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + pred_dict = { + 'cls': pred, 'prob': prob, 'feat': features,'label':label, + } + else: + #print(data_dict['image'].device) + new_img,alt_mask,label,type_label,mag_label,mag_mask,log_prob, entropy=self.syn_preprocess(data_dict) + # get the features by backbone + features,map_fea = self.features(new_img) + # get the prediction by classifier + out,x = self.classifier(features) + map = self.backbone.estimateMap(map_fea) + type=self.backbone.type_fc(x) + mag=self.backbone.mag_fc(x) + pred = out + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + + # build the prediction dict for each output + pred_dict = { + 'cls': pred, 'prob': prob, 'feat': features,'map':map,'type':type,'mag':mag, 'log_prob':log_prob,'label':label, + 'entropy':entropy,'alt_mask': alt_mask,'type_label':type_label,'mag_label':mag_label,'mag_mask':mag_mask + } + return pred_dict + +if __name__ == '__main__': + with open(r'H:\code\DeepfakeBench\training\config\detector\sladd_xception.yaml', 'r') as f: + config = yaml.safe_load(f) + if config['manualSeed'] is None: + config['manualSeed'] = random.randint(1, 10000) + random.seed(config['manualSeed']) + torch.manual_seed(config['manualSeed']) + if config['cuda']: + torch.cuda.manual_seed_all(config['manualSeed']) + detector=SLADDXceptionDetector(config=config).to(device) + config['data_manner'] = 'lmdb' + config['dataset_json_folder'] = 'preprocessing/dataset_json_v3' + config['sample_size']=256 + config['with_mask']=True + config['with_landmark']=True + config['use_data_augmentation']=True + train_set = pairDataset(config=config, mode='train') + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=32, + shuffle=True, + num_workers=0, + collate_fn=train_set.collate_fn, + ) + optimizer = optim.Adam( + params=detector.parameters(), + lr=config['optimizer']['adam']['lr'], + weight_decay=config['optimizer']['adam']['weight_decay'], + betas=(config['optimizer']['adam']['beta1'], config['optimizer']['adam']['beta2']), + eps=config['optimizer']['adam']['eps'], + amsgrad=config['optimizer']['adam']['amsgrad'], + ) + from tqdm import tqdm + for iteration, batch in enumerate(tqdm(train_data_loader)): + continue + imgs,lmks,msks=batch['image'].to(device),batch['landmark'].to(device),batch['mask'].to(device) + batch['image'],batch['landmark'],batch['mask'], batch['label'] = \ + batch['image'].to(device), batch['landmark'].to(device), batch['mask'].to(device),batch['label'].to(device) + half = len(imgs) // 2 + img, fake_img, real_lmk, fake_lmk, real_mask, fake_mask = imgs[:half],imgs[half:],lmks[:half],lmks[half:],msks[:half],msks[half:] + + predictions=detector(batch) + losses = detector.get_losses(batch, predictions) + optimizer.zero_grad() + losses['overall'].backward() + optimizer.step() + + if iteration > 10: + break diff --git a/training/detectors/spsl_detector.py b/training/detectors/spsl_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e2cc7c0bfc6085ba73056aaa33dce4e15c5728 --- /dev/null +++ b/training/detectors/spsl_detector.py @@ -0,0 +1,148 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the SPSLDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{liu2021spatial, + title={Spatial-phase shallow learning: rethinking face forgery detection in frequency domain}, + author={Liu, Honggu and Li, Xiaodan and Zhou, Wenbo and Chen, Yuefeng and He, Yuan and Xue, Hui and Zhang, Weiming and Yu, Nenghai}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, + pages={772--781}, + year={2021} +} + +Notes: +To ensure consistency in the comparison with other detectors, we have opted not to utilize the shallow Xception architecture. Instead, we are employing the original Xception model. +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +import random + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='spsl') +class SpslDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + + # To get a good performance, use the ImageNet-pretrained Xception model + state_dict = torch.load(config['pretrained']) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + + # remove conv1 from state_dict + conv1_data = state_dict.pop('conv1.weight') + + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model from {}'.format(config['pretrained'])) + + # copy on conv1 + # let new conv1 use old param to balance the network + backbone.conv1 = nn.Conv2d(4, 32, 3, 2, 0, bias=False) + avg_conv1_data = conv1_data.mean(dim=1, keepdim=True) # average across the RGB channels + backbone.conv1.weight.data = avg_conv1_data.repeat(1, 4, 1, 1) # repeat the averaged weights across the 4 new channels + logger.info('Copy conv1 from pretrained model') + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict, phase_fea) -> torch.tensor: + features = torch.cat((data_dict['image'], phase_fea), dim=1) + return self.backbone.features(features) + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + # we dont compute the video-level metrics for training + self.video_names = [] + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the phase features + phase_fea = self.phase_without_amplitude(data_dict['image']) + # bp + features = self.features(data_dict, phase_fea) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + + return pred_dict + + def phase_without_amplitude(self, img): + # Convert to grayscale + gray_img = torch.mean(img, dim=1, keepdim=True) # shape: (batch_size, 1, 256, 256) + # Compute the DFT of the input signal + X = torch.fft.fftn(gray_img,dim=(-1,-2)) + #X = torch.fft.fftn(img) + # Extract the phase information from the DFT + phase_spectrum = torch.angle(X) + # Create a new complex spectrum with the phase information and zero magnitude + reconstructed_X = torch.exp(1j * phase_spectrum) + # Use the IDFT to obtain the reconstructed signal + reconstructed_x = torch.real(torch.fft.ifftn(reconstructed_X,dim=(-1,-2))) + # reconstructed_x = torch.real(torch.fft.ifftn(reconstructed_X)) + return reconstructed_x diff --git a/training/detectors/srm_detector.py b/training/detectors/srm_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..95664e8418a7d5c39c7f8876b5af81f5e8fd53c2 --- /dev/null +++ b/training/detectors/srm_detector.py @@ -0,0 +1,653 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the SRMDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{luo2021generalizing, + title={Generalizing face forgery detection with high-frequency features}, + author={Luo, Yuchen and Zhang, Yong and Yan, Junchi and Liu, Wei}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, + pages={16317--16326}, + year={2021} +} + +Notes: +Other implementation modules are provided by the authors. +''' + +import os +import datetime +import numbers +import math +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC +import random + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='srm') +class SRMDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + # prepare the backbone for rgb and srm branch + self.backbone_rgb = self.build_backbone(config) + self.backbone_srm = self.build_backbone(config) + + # srm specific layers and modules + self.noise = GaussianNoise(clip=1) + self.blur = GaussianSmoothing(channels=3, kernel_size=7, sigma=0.8) + self.srm_conv0 = SRMConv2d_simple(inc=3) + self.srm_conv1 = SRMConv2d_Separate(32, 32) + self.srm_conv2 = SRMConv2d_Separate(64, 64) + self.relu = nn.ReLU(inplace=True) + self.att_map = None + self.srm_sa = SRMPixelAttention(3) + self.srm_sa_post = nn.Sequential( + nn.BatchNorm2d(64), + nn.ReLU(inplace=True) + ) + self.dual_cma0 = DualCrossModalAttention(in_dim=728) + self.dual_cma1 = DualCrossModalAttention(in_dim=728) + self.fusion = FeatureFusionModule() + + # prepare the loss function + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + assert config['backbone_name'] == 'xception', "SRM only supports the xception backbone" + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + # To get a good performance, use the ImageNet-pretrained Xception model + state_dict = torch.load(config['pretrained']) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model from {}'.format(config['pretrained'])) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class(gamma=0., m=0.45, s=30, t=1.) # use am-softmax for srm, params are specified by the author + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + x = data_dict['image'] # get the image as input for srm + srm = self.srm_conv0(x) + + x = self.backbone_rgb.fea_part1_0(x) + y = self.backbone_srm.fea_part1_0(srm) \ + + self.srm_conv1(x) + y = self.relu(y) + + x = self.backbone_rgb.fea_part1_1(x) + y = self.backbone_srm.fea_part1_1(y) \ + + self.srm_conv2(x) + y = self.relu(y) + + # srm guided spatial attention + self.att_map = self.srm_sa(srm) + x = x * self.att_map + x # use the residual + x = self.srm_sa_post(x) + + x = self.backbone_rgb.fea_part2(x) + y = self.backbone_srm.fea_part2(y) + + x, y = self.dual_cma0(x, y) + + x = self.backbone_rgb.fea_part3(x) + y = self.backbone_srm.fea_part3(y) + + x, y = self.dual_cma1(x, y) + + x = self.backbone_rgb.fea_part4(x) + y = self.backbone_srm.fea_part4(y) + + x = self.backbone_rgb.fea_part5(x) + y = self.backbone_srm.fea_part5(y) + + fea = self.fusion(x, y) + return fea + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone_rgb.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + # we dont compute the video-level metrics for training + self.video_names = [] + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +# ===================================== other modules for SRM # ===================================== + + +class SRMConv2d(nn.Module): + + def __init__(self, learnable=False): + super(SRMConv2d, self).__init__() + self.weight = nn.Parameter(torch.Tensor(30, 3, 5, 5), + requires_grad=learnable) + self.bias = nn.Parameter(torch.Tensor(30), \ + requires_grad=learnable) + self.reset_parameters() + + def reset_parameters(self): + SRM_npy = np.load('lib/component/SRM_Kernels.npy') + # print(SRM_npy.shape) + SRM_npy = np.repeat(SRM_npy, 3, axis=1) + # print(SRM_npy.shape) + self.weight.data.numpy()[:] = SRM_npy + self.bias.data.zero_() + + def forward(self, input): + return F.conv2d(input, self.weight, stride=1, padding=2) + + +class SRMConv2d_simple(nn.Module): + + def __init__(self, inc=3, learnable=False): + super(SRMConv2d_simple, self).__init__() + self.truc = nn.Hardtanh(-3, 3) + kernel = self._build_kernel(inc) # (3,3,5,5) + self.kernel = nn.Parameter(data=kernel, requires_grad=learnable) + # self.hor_kernel = self._build_kernel().transpose(0,1,3,2) + + def forward(self, x): + ''' + x: imgs (Batch, H, W, 3) + ''' + out = F.conv2d(x, self.kernel, stride=1, padding=2) + out = self.truc(out) + + return out + + def _build_kernel(self, inc): + # filter1: KB + filter1 = [[0, 0, 0, 0, 0], + [0, -1, 2, -1, 0], + [0, 2, -4, 2, 0], + [0, -1, 2, -1, 0], + [0, 0, 0, 0, 0]] + # filter2:KV + filter2 = [[-1, 2, -2, 2, -1], + [2, -6, 8, -6, 2], + [-2, 8, -12, 8, -2], + [2, -6, 8, -6, 2], + [-1, 2, -2, 2, -1]] + # # filter3:hor 2rd + filter3 = [[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 1, -2, 1, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]] + # filter3:hor 2rd + # filter3 = [[0, 0, 0, 0, 0], + # [0, 0, 1, 0, 0], + # [0, 1, -4, 1, 0], + # [0, 0, 1, 0, 0], + # [0, 0, 0, 0, 0]] + + filter1 = np.asarray(filter1, dtype=float) / 4. + filter2 = np.asarray(filter2, dtype=float) / 12. + filter3 = np.asarray(filter3, dtype=float) / 2. + # statck the filters + filters = [[filter1],#, filter1, filter1], + [filter2],#, filter2, filter2], + [filter3]]#, filter3, filter3]] # (3,3,5,5) + filters = np.array(filters) + filters = np.repeat(filters, inc, axis=1) + filters = torch.FloatTensor(filters) # (3,3,5,5) + return filters + + +class SRMConv2d_Separate(nn.Module): + + def __init__(self, inc, outc, learnable=False): + super(SRMConv2d_Separate, self).__init__() + self.inc = inc + self.truc = nn.Hardtanh(-3, 3) + kernel = self._build_kernel(inc) # (3,3,5,5) + self.kernel = nn.Parameter(data=kernel, requires_grad=learnable) + # self.hor_kernel = self._build_kernel().transpose(0,1,3,2) + self.out_conv = nn.Sequential( + nn.Conv2d(3*inc, outc, 1, 1, 0, 1, 1, bias=False), + nn.BatchNorm2d(outc), + nn.ReLU(inplace=True) + ) + + for ly in self.out_conv.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + + def forward(self, x): + ''' + x: imgs (Batch,inc, H, W) + kernel: (outc,inc,kH,kW) + ''' + out = F.conv2d(x, self.kernel, stride=1, padding=2, groups=self.inc) + out = self.truc(out) + out = self.out_conv(out) + + return out + + def _build_kernel(self, inc): + # filter1: KB + filter1 = [[0, 0, 0, 0, 0], + [0, -1, 2, -1, 0], + [0, 2, -4, 2, 0], + [0, -1, 2, -1, 0], + [0, 0, 0, 0, 0]] + # filter2:KV + filter2 = [[-1, 2, -2, 2, -1], + [2, -6, 8, -6, 2], + [-2, 8, -12, 8, -2], + [2, -6, 8, -6, 2], + [-1, 2, -2, 2, -1]] + # # filter3:hor 2rd + filter3 = [[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 1, -2, 1, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]] + # filter3:hor 2rd + # filter3 = [[0, 0, 0, 0, 0], + # [0, 0, 1, 0, 0], + # [0, 1, -4, 1, 0], + # [0, 0, 1, 0, 0], + # [0, 0, 0, 0, 0]] + + filter1 = np.asarray(filter1, dtype=float) / 4. + filter2 = np.asarray(filter2, dtype=float) / 12. + filter3 = np.asarray(filter3, dtype=float) / 2. + # statck the filters + filters = [[filter1],#, filter1, filter1], + [filter2],#, filter2, filter2], + [filter3]]#, filter3, filter3]] # (3,3,5,5) => (3,1,5,5) + filters = np.array(filters) + # filters = np.repeat(filters, inc, axis=1) + filters = np.repeat(filters, inc, axis=0) + filters = torch.FloatTensor(filters) # (3*inc,1,5,5) + # print(filters.size()) + return filters + + +class GaussianSmoothing(nn.Module): + """ + Apply gaussian smoothing on a + 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + in the input using a depthwise convolution. + Arguments: + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + + def __init__(self, channels, kernel_size, sigma=0.1, dim=2): + super(GaussianSmoothing, self).__init__() + self.kernel_size = kernel_size + if isinstance(kernel_size, numbers.Number): + kernel_size = [kernel_size] * dim + if isinstance(sigma, numbers.Number): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ + torch.exp(-((mgrid - mean) / std) ** 2 / 2) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError( + 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format( + dim) + ) + + def forward(self, input): + """ + Apply gaussian filter to input. + Arguments: + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + if self.training: + return self.conv(input, weight=self.weight, groups=self.groups, padding=self.kernel_size//2) + else: + return input + + +class GaussianNoise(nn.Module): + def __init__(self, mean=0, std=0.1, clip=1): + super(GaussianNoise, self).__init__() + self.mean = mean + self.std = std + self.clip = clip + + def forward(self, x): + if self.training: + noise = x.data.new(x.size()).normal_(self.mean, self.std) + return torch.clamp(x + noise, -self.clip, self.clip) + else: + return x + + +class ChannelAttention(nn.Module): + def __init__(self, in_planes, ratio=8): + super(ChannelAttention, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + + self.sharedMLP = nn.Sequential( + nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), + nn.ReLU(), + nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)) + self.sigmoid = nn.Sigmoid() + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight.data, gain=0.02) + + def forward(self, x): + avgout = self.sharedMLP(self.avg_pool(x)) + maxout = self.sharedMLP(self.max_pool(x)) + return self.sigmoid(avgout + maxout) + + +class SpatialAttention(nn.Module): + def __init__(self, kernel_size=7): + super(SpatialAttention, self).__init__() + assert kernel_size in (3, 7), "kernel size must be 3 or 7" + padding = 3 if kernel_size == 7 else 1 + + self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) + self.sigmoid = nn.Sigmoid() + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight.data, gain=0.02) + + def forward(self, x): + avgout = torch.mean(x, dim=1, keepdim=True) + maxout, _ = torch.max(x, dim=1, keepdim=True) + x = torch.cat([avgout, maxout], dim=1) + x = self.conv(x) + return self.sigmoid(x) + + +class CrossModalAttention(nn.Module): + """ CMA attention Layer""" + + def __init__(self, in_dim, activation=None, ratio=8, cross_value=True): + super(CrossModalAttention, self).__init__() + self.chanel_in = in_dim + self.activation = activation + self.cross_value = cross_value + + self.query_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) + self.key_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) + self.value_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight.data, gain=0.02) + + def forward(self, x, y): + """ + inputs : + x : input feature maps( B X C X W X H) + returns : + out : self attention value + input feature + attention: B X N X N (N is Width*Height) + """ + B, C, H, W = x.size() + + proj_query = self.query_conv(x).view( + B, -1, H*W).permute(0, 2, 1) # B , HW, C + proj_key = self.key_conv(y).view( + B, -1, H*W) # B X C x (*W*H) + energy = torch.bmm(proj_query, proj_key) # B, HW, HW + attention = self.softmax(energy) # BX (N) X (N) + if self.cross_value: + proj_value = self.value_conv(y).view( + B, -1, H*W) # B , C , HW + else: + proj_value = self.value_conv(x).view( + B, -1, H*W) # B , C , HW + + out = torch.bmm(proj_value, attention.permute(0, 2, 1)) + out = out.view(B, C, H, W) + + out = self.gamma*out + x + + if self.activation is not None: + out = self.activation(out) + + return out # , attention + + +class DualCrossModalAttention(nn.Module): + """ Dual CMA attention Layer""" + + def __init__(self, in_dim, activation=None, size=16, ratio=8, ret_att=False): + super(DualCrossModalAttention, self).__init__() + self.chanel_in = in_dim + self.activation = activation + self.ret_att = ret_att + + # query conv + self.key_conv1 = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) + self.key_conv2 = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) + self.key_conv_share = nn.Conv2d( + in_channels=in_dim//ratio, out_channels=in_dim//ratio, kernel_size=1) + + self.linear1 = nn.Linear(size*size, size*size) + self.linear2 = nn.Linear(size*size, size*size) + + # separated value conv + self.value_conv1 = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma1 = nn.Parameter(torch.zeros(1)) + + self.value_conv2 = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma2 = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight.data, gain=0.02) + if isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight.data, gain=0.02) + + def forward(self, x, y): + """ + inputs : + x : input feature maps( B X C X W X H) + returns : + out : self attention value + input feature + attention: B X N X N (N is Width*Height) + """ + B, C, H, W = x.size() + + def _get_att(a, b): + proj_key1 = self.key_conv_share(self.key_conv1(a)).view( + B, -1, H*W).permute(0, 2, 1) # B , HW, C + proj_key2 = self.key_conv_share(self.key_conv2(b)).view( + B, -1, H*W) # B X C x (*W*H) + #print('proj_key1:', proj_key1[0][0][:5].cpu().detach().numpy()) + #print('proj_key2:', proj_key2[0][:5][0:5].cpu().detach().numpy()) + energy = torch.bmm(proj_key1, proj_key2) # B, HW, HW + #print('energy:', energy[0][0][:5].cpu().detach().numpy()) + attention1 = self.softmax(self.linear1(energy)) + attention2 = self.softmax(self.linear2(energy.permute(0,2,1))) # BX (N) X (N) + #print('1:', attention1[0]==attention1[1]) + #print('2:', attention2[0]==attention2[1]) + + return attention1, attention2 + + att_y_on_x, att_x_on_y = _get_att(x, y) + #print('att_y_on_x:', att_y_on_x[0][0][:5].cpu().detach().numpy()) + proj_value_y_on_x = self.value_conv2(y).view( + B, -1, H*W) # B , C , HW + out_y_on_x = torch.bmm(proj_value_y_on_x, att_y_on_x.permute(0, 2, 1)) + out_y_on_x = out_y_on_x.view(B, C, H, W) + out_x = self.gamma1*out_y_on_x + x + + proj_value_x_on_y = self.value_conv1(x).view( + B, -1, H*W) # B , C , HW + out_x_on_y = torch.bmm(proj_value_x_on_y, att_x_on_y.permute(0, 2, 1)) + out_x_on_y = out_x_on_y.view(B, C, H, W) + out_y = self.gamma2*out_x_on_y + y + + if self.ret_att: + return out_x, out_y, att_y_on_x, att_x_on_y + + return out_x, out_y # , attention + + +class SRMPixelAttention(nn.Module): + def __init__(self, in_channels): + super(SRMPixelAttention, self).__init__() + self.srm = SRMConv2d_simple() + self.conv = nn.Sequential( + nn.Conv2d(in_channels, 32, 3, 2, 0, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + nn.Conv2d(32, 64, 3, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + ) + self.pa = SpatialAttention() + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, a=1) + if not m.bias is None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x_srm = self.srm(x) + fea = self.conv(x_srm) + # fea += fea * self.ca(fea) + att_map = self.pa(fea) + # return x * y + return att_map + + +class FeatureFusionModule(nn.Module): + def __init__(self, in_chan=2048*2, out_chan=2048, *args, **kwargs): + super(FeatureFusionModule, self).__init__() + self.convblk = nn.Sequential( + nn.Conv2d(in_chan, out_chan, 1, 1, 0, bias=False), + nn.BatchNorm2d(out_chan), + nn.ReLU() + ) + self.ca = ChannelAttention(out_chan, ratio=16) + self.init_weight() + + def forward(self, x, y): + fuse_fea = self.convblk(torch.cat((x, y), dim=1)) + #fuse_fea = fuse_fea + fuse_fea * self.ca(fuse_fea) # Is it correct? F *(1+a) or F * a? + fuse_fea = fuse_fea * self.ca(fuse_fea) # changed by yong + return fuse_fea + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: + nn.init.constant_(ly.bias, 0) diff --git a/training/detectors/sta_detector.py b/training/detectors/sta_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..bb8bbee15bbefd552057720a97b440a5c0736602 --- /dev/null +++ b/training/detectors/sta_detector.py @@ -0,0 +1,597 @@ +""" +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the XceptionDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{rossler2019faceforensics++, + title={Faceforensics++: Learning to detect manipulated facial images}, + author={Rossler, Andreas and Cozzolino, Davide and Verdoliva, Luisa and Riess, Christian and Thies, Justus and Nie{\ss}ner, Matthias}, + booktitle={Proceedings of the IEEE/CVF international conference on computer vision}, + pages={1--11}, + year={2019} +} +""" + +import logging +from collections import OrderedDict + +import clip +import math +import numpy as np +import torch +import torch.nn as nn +from detectors import DETECTOR +from einops import rearrange +from loss import LOSSFUNC +from metrics.base_metrics_class import calculate_metrics_for_train +from sklearn import metrics + +from .base_detector import AbstractDetector + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='sta_clip') +class StACLIPDetector(AbstractDetector): + def __init__(self, config, demo=None): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.head = I3DHead( + num_classes=2, + in_channels=1024, + spatial_type='avg', + dropout_ratio=0.5 + ) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + assert self.config['resolution'] == 224, 'The resolution of the input image should be 224x224' + # assert self.config['clip_size'] == 8, 'The number of frames should be 8' + # prepare the backbone + backbone = ViT_CLIP( + input_resolution=224, + num_frames=self.config['clip_size'], + patch_size=14, + width=1024, + layers=14, + heads=16, + drop_path_rate=0.1, + num_tadapter=1, + adapter_scale=0.5, + pretrained=True + ) + + # ## freeze some parameters + # for name, param in backbone.named_parameters(): + # if 'temporal_embedding' not in name and 'ln_post' not in name and 'cls_head' not in name and 'Adapter' not in name: + # param.requires_grad = False + + # for name, param in backbone.named_parameters(): + # print('{}: {}'.format(name, param.requires_grad)) + # num_param = sum(p.numel() for p in backbone.parameters() if p.requires_grad) + # num_total_param = sum(p.numel() for p in backbone.parameters()) + # print('Number of total parameters: {}, tunable parameters: {}'.format(num_total_param, num_param)) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + feat = self.backbone(data_dict['image']) + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label.long()) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def get_test_metrics(self): + y_pred, y_true = self.video_calculation(self.video_names, self.prob, self.label) + # auc + fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1) + auc = metrics.auc(fpr, tpr) + # eer + fnr = 1 - tpr + eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + # ap + ap = metrics.average_precision_score(y_true, y_pred) + # acc + acc = self.correct / self.total + # reset the prob and label + + return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap, 'pred': y_pred, 'label': y_true} + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + + return pred_dict + + +class Adapter(nn.Module): + def __init__(self, D_features, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True): + super().__init__() + self.skip_connect = skip_connect + D_hidden_features = int(D_features * mlp_ratio) + self.act = act_layer() + self.D_fc1 = nn.Linear(D_features, D_hidden_features) + self.D_fc2 = nn.Linear(D_hidden_features, D_features) + + def forward(self, x): + # x is (BT, HW+1, D) + xs = self.D_fc1(x) + xs = self.act(xs) + xs = self.D_fc2(xs) + if self.skip_connect: + x = x + xs + else: + x = xs + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerNormProxy(nn.Module): + def __init__(self, dim): + super().__init__() + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + x = rearrange(x, 'b c t h w -> b t h w c') + x = self.norm(x) + x = rearrange(x, 'b t h w c -> b c t h w') + return x + + +class DepthwiseConv3D(nn.Module): + def __init__(self, in_channels, kernel_size): + super().__init__() + self.conv1 = nn.Conv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, + groups=in_channels, + padding=(kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[2] // 2)) + self.bn1 = nn.BatchNorm3d(num_features=in_channels) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, + groups=in_channels, + padding=(kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[2] // 2)) + self.bn2 = nn.BatchNorm3d(num_features=in_channels) + self.relu2 = nn.ReLU(inplace=True) + + self.conv3 = nn.Conv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, + groups=in_channels, + padding=(kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[2] // 2)) + self.bn3 = nn.BatchNorm3d(num_features=in_channels) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu2(x) + + x = self.conv3(x) + x = self.bn3(x) + return x + + +# class ViT_Adapter(nn.Module): +# def __init__(self, num_frames=8, in_channels=1024, out_channels=1024): +# super().__init__() +# self.num_frames=num_frames +# self.in_channels = in_channels +# self.out_channels = out_channels +# self.adapter_channels = int(1024 * 0.5) + +# self.down = nn.Linear(in_features=self.in_channels, out_features=self.adapter_channels) +# self.gelu1 = nn.GELU() + +# self.s_conv = DepthwiseConv3D(in_channels=self.adapter_channels, kernel_size=(1, 3, 3)) +# self.t_conv = DepthwiseConv3D(in_channels=self.adapter_channels, kernel_size=(3, 1, 1)) + +# self.gelu = nn.GELU() + +# self.up = nn.Linear(in_features=self.adapter_channels, out_features=self.out_channels) +# self.gelu2 = nn.GELU() + +# def forward(self, x): +# # hw+1 bt c +# n, bt, c = x.shape +# H = round(math.sqrt(n - 1)) +# x_in = x + +# x = self.down(x) +# x = self.gelu1(x) + +# cls = x[0, :, :].unsqueeze(0) +# x = x[1:, :, :] + +# x = rearrange(x, '(h w) (b t) c -> b c t h w', t=self.num_frames, h=H) + +# # Apply depthwise 3D convolutions +# xs = self.s_conv(x) +# xt = self.t_conv(x) + +# # Fusion of xs and xt +# x = (xs + xt) / 2 +# x = self.gelu(x) +# x = rearrange(x, 'b c t h w -> (h w) (b t) c') + +# x = torch.cat([cls, x], dim=0) + +# x = self.up(x) +# x = self.gelu2(x) + +# # residual +# x += x_in +# return x + + +class ViT_Adapter(nn.Module): + def __init__(self, num_frames=8, in_channels=1024, out_channels=1024): + super().__init__() + self.num_frames = num_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.adapter_channels = int(1024 * 0.5) + + self.down = nn.Linear(in_features=self.in_channels, out_features=self.adapter_channels) + self.gelu1 = nn.GELU() + + self.s_conv1 = DepthwiseConv3D(in_channels=self.adapter_channels, kernel_size=(1, 3, 3)) + self.s_conv2 = DepthwiseConv3D(in_channels=self.adapter_channels, kernel_size=(1, 5, 5)) + self.s_conv3 = DepthwiseConv3D(in_channels=self.adapter_channels, kernel_size=(1, 7, 7)) + + self.t_conv1 = DepthwiseConv3D(in_channels=self.adapter_channels, kernel_size=(3, 1, 1)) + self.t_conv2 = DepthwiseConv3D(in_channels=self.adapter_channels, kernel_size=(5, 1, 1)) + self.t_conv3 = DepthwiseConv3D(in_channels=self.adapter_channels, kernel_size=(7, 1, 1)) + + self.gelu = nn.GELU() + + self.up = nn.Linear(in_features=self.adapter_channels, out_features=self.out_channels) + self.gelu2 = nn.GELU() + + self.cross_attention = CrossAttention(embed_dim=self.adapter_channels, num_heads=2) + + def forward(self, x): + # hw+1 bt c + n, bt, c = x.shape + H = round(math.sqrt(n - 1)) + x_in = x + + x = self.down(x) + x = self.gelu1(x) + + cls = x[0, :, :].unsqueeze(0) + x = x[1:, :, :] + + x = rearrange(x, '(h w) (b t) c -> b c t h w', t=self.num_frames, h=H) + + # Apply depthwise 3D convolutions + xs1 = self.s_conv1(x) + xs2 = self.s_conv2(x) + xs3 = self.s_conv3(x) + + xt1 = self.t_conv1(x) + xt2 = self.t_conv2(x) + xt3 = self.t_conv3(x) + + # Fusion of xs and xt with residual connections + xs = (xs1 + xs2 + xs3) / 3 + x + xt = (xt1 + xt2 + xt3) / 3 + x + + # cross attention + xs, xt = self.cross_attention(xs, xt) + + x = (xs + xt) / 2 + x = self.gelu(x) + x = rearrange(x, 'b c t h w -> (h w) (b t) c') + + x = torch.cat([cls, x], dim=0) + + x = self.up(x) + x = self.gelu2(x) + + # residual + x += x_in + return x + + +class CrossAttention(nn.Module): + def __init__(self, embed_dim, num_heads): + super().__init__() + self.spatial_to_temporal = nn.MultiheadAttention(embed_dim, num_heads) + self.temporal_to_spatial = nn.MultiheadAttention(embed_dim, num_heads) + + def forward(self, spatial, temporal): + # B, C, T, H, W + B, C, T, H, W = spatial.shape + + # Flatten the spatial and temporal dimensions + spatial = spatial.view(B, C, T, H * W) # [B, C, T, H*W] + temporal = temporal.view(B, C, T, H * W) # [B, C, T, H*W] + + # Permute to [T*H*W, B, C] for MultiheadAttention + spatial = spatial.permute(2, 0, 3, 1).reshape(T * H * W, B, C) # [T*H*W, B, C] + temporal = temporal.permute(2, 0, 3, 1).reshape(T * H * W, B, C) # [T*H*W, B, C] + + # Apply cross attention + s2t, _ = self.spatial_to_temporal(temporal, spatial, spatial) + t2s, _ = self.temporal_to_spatial(spatial, temporal, temporal) + + # Reshape back to original dimensions + s2t = s2t.view(T, H * W, B, C).permute(2, 3, 0, 1).reshape(B, C, T, H, W) + t2s = t2s.view(T, H * W, B, C).permute(2, 3, 0, 1).reshape(B, C, T, H, W) + + return s2t, t2s + + +class ResidualAttentionBlock(nn.Module): + def __init__(self): + super().__init__() + d_model = 1024 + n_head = 16 + self.ln_1 = LayerNorm(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_2 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.Adapter = ViT_Adapter() + + def attention(self, x): + return self.attn(x, x, x, need_weights=False)[0] + + def forward(self, x): + # x shape [HW+1, BT, C] + x = x + self.attention(self.ln_1(x)) + # x = self.Adapter(x) + x = x + self.mlp(self.ln_2(x)) + x = self.Adapter(x) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, num_tadapter=1, scale=1., + drop_path=0.1): + super().__init__() + self.width = width + self.layers = layers + dpr = [x.item() for x in torch.linspace(0, drop_path, self.layers)] + self.resblocks = nn.Sequential(*[ResidualAttentionBlock() for i in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class ViT_CLIP(nn.Module): + ## ViT definition in CLIP image encoder + def __init__(self, input_resolution: int, num_frames: int, patch_size: int, width: int, layers: int, heads: int, + drop_path_rate, num_tadapter=1, adapter_scale=0.5, pretrained=None): + super().__init__() + self.input_resolution = input_resolution + self.pretrained = pretrained + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.layers = layers + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.num_frames = num_frames + self.temporal_embedding = nn.Parameter(torch.zeros(1, num_frames, width)) + + self.transformer = Transformer(width, layers, heads, num_tadapter=num_tadapter, scale=adapter_scale, + drop_path=drop_path_rate) + + self.ln_post = LayerNorm(width) + + # self.init_weights() + + def init_weights(self): + logger.info(f'load model from: {self.pretrained}') + # Load OpenAI CLIP pretrained weights + clip_model, preprocess = clip.load("ViT-L/14", device="cpu") + pretrain_dict = clip_model.visual.state_dict() + del clip_model + del pretrain_dict['proj'] + msg = self.load_state_dict(pretrain_dict, strict=False) + logger.info('Missing keys: {}'.format(msg.missing_keys)) + logger.info('Unexpected keys: {}'.format(msg.unexpected_keys)) + logger.info(f"=> loaded successfully '{self.pretrained}'") + torch.cuda.empty_cache() + # zero-initialize Adapters + for n1, m1 in self.named_modules(): + if 'Adapter' in n1: + for n2, m2 in m1.named_modules(): + if 'up' in n2: + logger.info('init: {}.{}'.format(n1, n2)) + nn.init.constant_(m2.weight, 0) + nn.init.constant_(m2.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed', 'temporal_embedding'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table', 'temporal_position_bias_table'} + + def extract_class_indices(self, labels, which_class): + class_mask = torch.eq(labels, which_class) + class_mask_indices = torch.nonzero(class_mask, as_tuple=False) + return torch.reshape(class_mask_indices, (-1,)) + + def get_feat(self, x): + x = rearrange(x, 'b t c h w -> (b t) c h w') + x = self.conv1(x) + x = x.reshape(x.shape[0], x.shape[1], -1) + x = x.permute(0, 2, 1) + x = torch.cat( + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) + x = x + self.positional_embedding.to(x.dtype) + # n = h*w+1 + n = x.shape[1] + + x = rearrange(x, '(b t) n c -> (b n) t c', t=self.num_frames) + x = x + self.temporal_embedding + x = rearrange(x, '(b n) t c -> (b t) n c', n=n) + + x = self.ln_pre(x) + x = x.permute(1, 0, 2) + x = self.transformer(x) + x = x.permute(1, 0, 2) + x = self.ln_post(x) + return x + + def forward(self, x: torch.Tensor): + B, T, C, H, W = x.shape + x = self.get_feat(x) + + x = x[:, 0] + x = rearrange(x, '(b t) d -> b d t', b=B, t=T) + + x = x.unsqueeze(-1).unsqueeze(-1) # BDTHW for I3D head + + return x + + +class I3DHead(nn.Module): + """Classification head for I3D. + + Args: + num_classes (int): Number of classes to be classified. + in_channels (int): Number of channels in input feature. + Default: dict(type='CrossEntropyLoss') + spatial_type (str): Pooling type in spatial dimension. Default: 'avg'. + dropout_ratio (float): Probability of dropout layer. Default: 0.5. + kwargs (dict, optional): Any keyword argument to be used to initialize + the head. + """ + + def __init__(self, + num_classes, + in_channels, + spatial_type='avg', + dropout_ratio=0.5, + **kwargs): + super().__init__() + self.num_classes = num_classes + self.in_channels = in_channels + self.spatial_type = spatial_type + self.dropout_ratio = dropout_ratio + if self.dropout_ratio != 0: + self.dropout = nn.Dropout(p=self.dropout_ratio) + else: + self.dropout = None + self.fc_cls = nn.Linear(self.in_channels, self.num_classes) + + if self.spatial_type == 'avg': + # use `nn.AdaptiveAvgPool3d` to adaptively match the in_channels. + self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) + else: + self.avg_pool = None + + def forward(self, x): + """Defines the computation performed at every call. + + Args: + x (torch.Tensor): The input data. + + Returns: + torch.Tensor: The classification scores for input samples. + """ + # [N, in_channels, 4, 7, 7] + if self.avg_pool is not None: + x = self.avg_pool(x) + # [N, in_channels, 1, 1, 1] + if self.dropout is not None: + x = self.dropout(x) + # [N, in_channels, 1, 1, 1] + x = x.view(x.shape[0], -1) + # [N, in_channels] + cls_score = self.fc_cls(x) + # [N, num_classes] + return cls_score + + +if __name__ == '__main__': + vit_model = ViT_CLIP( + input_resolution=224, + num_frames=8, + patch_size=16, + width=768, + layers=12, + heads=12, + drop_path_rate=0.1, + num_tadapter=1, + adapter_scale=0.5, + pretrained=True + ) + + i3d_head = I3DHead( + num_classes=2, + in_channels=768, + spatial_type='avg', + dropout_ratio=0.5 + ) + + rand_input = torch.rand(2, 8, 3, 224, 224) + feat = vit_model(rand_input) + print(feat.shape) + output = i3d_head(feat) + print(output.shape) diff --git a/training/detectors/stil_detector.py b/training/detectors/stil_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..bba2dcabc0f23c2f551574018920900d57e3e860 --- /dev/null +++ b/training/detectors/stil_detector.py @@ -0,0 +1,747 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the STILDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{gu2021spatiotemporal, + title={Spatiotemporal inconsistency learning for deepfake video detection}, + author={Gu, Zhihao and Chen, Yang and Yao, Taiping and Ding, Shouhong and Li, Jilin and Huang, Feiyue and Ma, Lizhuang}, + booktitle={Proceedings of the 29th ACM international conference on multimedia}, + pages={3473--3481}, + year={2021} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch.utils.model_zoo as model_zoo +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='stil') +class STILDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.model = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + backbone = STIL_Model(num_class=2, num_segment=config['clip_size'], add_softmax=False) + pretrained_path = config['pretrained'] + if pretrained_path: + state_dict = torch.load(pretrained_path) + state_dict = {k.replace("base_", "").replace("model.", ""): v for k, v in state_dict.items()} + state_dict = {"base_model." + k: v for k, v in state_dict.items()} + msg = backbone.load_state_dict(state_dict, False) + print('Missing keys: {}'.format(msg.missing_keys)) + print('Unexpected keys: {}'.format(msg.unexpected_keys)) + print(f"=> loaded successfully '{pretrained_path}'") + torch.cuda.empty_cache() + return backbone + + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + # STIL requires the input with the shape of (n, t*c, h, w), where n is the batch_size, t is num_segment + bs, t, c, h, w = data_dict['image'].shape + inputs = data_dict['image'].view(bs, t*c, h, w) + pred = self.model(inputs) + return pred + + def classifier(self, features: torch.tensor): + pass + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'].long() + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + # we dont compute the video-level metrics for training + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the prediction by backbone + pred = self.features(data_dict) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': prob} + + return pred_dict + + + +class STIL_Model(nn.Module): + def __init__(self, + num_class=2, + num_segment=8, + add_softmax=False, + **kwargs): + """ Model Builder for STIL model. + STIL: Spatiotemporal Inconsistency Learning for DeepFake Video Detection (https://arxiv.org/abs/2109.01860) + + Args: + num_class (int, optional): Number of classes. Defaults to 2. + num_segment (int, optional): Number of segments (frames) fed to the model. Defaults to 8. + add_softmax (bool, optional): Whether to add softmax layer at the end. Defaults to False. + """ + super().__init__() + + self.num_class = num_class + self.num_segment = num_segment + + self.add_softmax = add_softmax + + self.build_model() + + + def build_model(self): + """ + Construct the model. + """ + self.base_model = scnet50_v1d(self.num_segment, pretrained=True) + + fc_feature_dim = self.base_model.fc.in_features + self.base_model.fc = nn.Linear(fc_feature_dim, self.num_class) + + if self.add_softmax: + self.softmax_layer = nn.Softmax(dim=1) + + + def forward(self, x): + """Forward pass of the model. + + Args: + x (torch.tensor): input tensor of shape (n, t*c, h, w). n is the batch_size, t is num_segment + """ + # img channel default to 3 + img_channel = 3 + + # x: [n, tc, h, w] -> [nt, c, h, w] + # out: [nt, num_class] + out = self.base_model( + x.view((-1, img_channel) + x.size()[2:]) + ) + + out = out.view(-1, self.num_segment, self.num_class) # [n, t, num_class] + out = out.mean(1, keepdim=False) # [n, num_class] + + if self.add_softmax: + out = self.softmax_layer(out) + + return out + + + def set_segment(self, num_segment): + """Change num_segment of the model. + Useful when the train and test want to feed different number of frames. + + Args: + num_segment (int): New number of segments. + """ + self.num_segment = num_segment + + + +model_urls = { + 'scnet50_v1d': 'https://backseason.oss-cn-beijing.aliyuncs.com/scnet/scnet50_v1d-4109d1e1.pth', +} + + +class ISM_Module(nn.Module): + def __init__(self, k_size=3): + """The Information Supplement Module (ISM). + + Args: + k_size (int, optional): Conv1d kernel_size . Defaults to 3. + """ + super(ISM_Module, self).__init__() + + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size-1)//2, bias=False) + self.sigmoid = nn.Sigmoid() + + + def forward(self, x): + """ + Args: + x (torch.tensor): Input tensor of shape (nt, c, h, w) + """ + y = self.avg_pool(x) + y = self.conv(y.squeeze(-1).transpose(-1,-2)).transpose(-1,-2).unsqueeze(-1) + y = self.sigmoid(y) + return x * y.expand_as(x) + + +class TIM_Module(nn.Module): + def __init__(self, in_channels, reduction=16, n_segment=8, return_attn=False): + """The Temporal Inconsistency Module (TIM). + + Args: + in_channels (int): Input channel number. + reduction (int, optional): Channel compression ratio r in the split operation.. Defaults to 16. + n_segment (int, optional): Number of input frames.. Defaults to 8. + return_attn (bool, optional): Whether to return the attention part. Defaults to False. + + """ + super(TIM_Module, self).__init__() + self.in_channels = in_channels + self.reduction = reduction + self.n_segment = n_segment + self.return_attn = return_attn + + self.reduced_channels = self.in_channels // self.reduction + + # first conv to shrink input channels + self.conv1 = nn.Conv2d(self.in_channels, self.reduced_channels, kernel_size=1, padding=0, bias=False) + self.bn1 = nn.BatchNorm2d(self.reduced_channels) + + self.conv_ht = nn.Conv2d(self.reduced_channels, self.reduced_channels, + kernel_size=(3, 1), padding=(1, 0), groups=self.reduced_channels, bias=False) + self.conv_tw = nn.Conv2d(self.reduced_channels, self.reduced_channels, + kernel_size=(1, 3), padding=(0, 1), groups=self.reduced_channels, bias=False) + + self.avg_pool_ht = nn.AvgPool2d((2, 1), (2, 1)) + self.avg_pool_tw = nn.AvgPool2d((1, 2), (1, 2)) + + # HTIE in two directions + self.htie_conv1 = nn.Sequential( + nn.Conv2d(self.reduced_channels, self.reduced_channels, kernel_size=(3, 1), padding=(1, 0), bias=False), + nn.BatchNorm2d(self.reduced_channels), + ) + self.vtie_conv1 = nn.Sequential( + nn.Conv2d(self.reduced_channels, self.reduced_channels, kernel_size=(1, 3), padding=(0, 1), bias=False), + nn.BatchNorm2d(self.reduced_channels), + ) + self.htie_conv2 = nn.Sequential( + nn.Conv2d(self.reduced_channels, self.reduced_channels, kernel_size=(3, 1), padding=(1, 0), bias=False), + nn.BatchNorm2d(self.reduced_channels), + ) + self.vtie_conv2 = nn.Sequential( + nn.Conv2d(self.reduced_channels, self.reduced_channels, kernel_size=(1, 3), padding=(0, 1), bias=False), + nn.BatchNorm2d(self.reduced_channels), + ) + self.ht_up_conv = nn.Sequential( + nn.Conv2d(self.reduced_channels, self.in_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(self.in_channels) + ) + self.tw_up_conv = nn.Sequential( + nn.Conv2d(self.reduced_channels, self.in_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(self.in_channels) + ) + + self.sigmoid = nn.Sigmoid() + + + def feat_ht(self, feat): + """The H-T branch in the TIM module. + + Args: + feat (torch.tensor): Input feature with shape [n, t, c, h, w] (c is in_channels // reduction) + + """ + n, t, c, h, w = feat.size() + # [n, t, c, h, w] -> [n, w, c, h, t] -> [nw, c, h, t] + feat_h = feat.permute(0, 4, 2, 3, 1).contiguous().view(-1, c, h, t) + + # [nw, c, h, t-1] + feat_h_fwd, _ = feat_h.split([self.n_segment-1, 1], dim=3) + feat_h_conv = self.conv_ht(feat_h) + _, feat_h_conv_fwd = feat_h_conv.split([1, self.n_segment-1], dim=3) + + diff_feat_fwd = feat_h_conv_fwd - feat_h_fwd + diff_feat_fwd = F.pad(diff_feat_fwd, [0, 1], value=0) # [nw, c, h, t] + + # HTIE, down_up branch + diff_feat_fwd1 = self.avg_pool_ht(diff_feat_fwd) # [nw, c, h//2, t] + diff_feat_fwd1 = self.htie_conv1(diff_feat_fwd1) # [nw, c, h//2, t] + diff_feat_fwd1 = F.interpolate(diff_feat_fwd1, diff_feat_fwd.size()[2:]) # [nw, c, h, t] + # HTIE, direct conv branch + diff_feat_fwd2 = self.htie_conv2(diff_feat_fwd) # [nw, c, h, t] + + # [nw, C, h, t] + feat_ht_out = self.ht_up_conv(1/3. * diff_feat_fwd + 1/3. * diff_feat_fwd1 + 1/3. * diff_feat_fwd2) + feat_ht_out = self.sigmoid(feat_ht_out) - 0.5 + # [nw, C, h, t] -> [n, w, C, h, t] -> [n, t, C, h, w] + feat_ht_out = feat_ht_out.view(n, w, self.in_channels, h, t).permute(0, 4, 2, 3, 1).contiguous() + # [n, t, C, h, w] -> [nt, C, h, w] + feat_ht_out = feat_ht_out.view(-1, self.in_channels, h, w) + + return feat_ht_out + + + def feat_tw(self, feat): + """The T-W branch in the TIM module. + + Args: + feat (torch.tensor): Input feature with shape [n, t, c, h, w] (c is in_channels // reduction) + """ + n, t, c, h, w = feat.size() + # [n, t, c, h, w] -> [n, h, c, t, w] -> [nh, c, t, w] + feat_w = feat.permute(0, 3, 2, 1, 4).contiguous().view(-1, c, t, w) + + # [nh, c, t-1, w] + feat_w_fwd, _ = feat_w.split([self.n_segment-1, 1], dim=2) + feat_w_conv = self.conv_tw(feat_w) + _, feat_w_conv_fwd = feat_w_conv.split([1, self.n_segment-1], dim=2) + + diff_feat_fwd = feat_w_conv_fwd - feat_w_fwd + diff_feat_fwd = F.pad(diff_feat_fwd, [0, 0, 0, 1], value=0) # [nh, c, t, w] + + # VTIE, down_up branch + diff_feat_fwd1 = self.avg_pool_tw(diff_feat_fwd) # [nh, c, t, w//2] + diff_feat_fwd1 = self.vtie_conv1(diff_feat_fwd1) # [nh, c, t, w//2] + diff_feat_fwd1 = F.interpolate(diff_feat_fwd1, diff_feat_fwd.size()[2:]) # [nh, c, t, w] + # VTIE, direct conv branch + diff_feat_fwd2 = self.vtie_conv2(diff_feat_fwd) # [nh, c, t, w] + + # [nh, C, t, w] + feat_tw_out = self.tw_up_conv(1/3. * diff_feat_fwd + 1/3. * diff_feat_fwd1 + 1/3. * diff_feat_fwd2) + feat_tw_out = self.sigmoid(feat_tw_out) - 0.5 + # [nh, C, t, w] -> [n, h, C, t, w] -> [n, t, C, h, W] + feat_tw_out = feat_tw_out.view(n, h, self.in_channels, t, w).permute(0, 3, 2, 1, 4).contiguous() + # [n, t, C, h, w] -> [nt, C, h, w] + feat_tw_out = feat_tw_out.view(-1, self.in_channels, h, w) + + return feat_tw_out + + + def forward(self, x): + """ + Args: + x (torch.tensor): Input with shape [nt, c, h, w] + """ + # [nt, c, h, w] -> [nt, c//r, h, w] + bottleneck = self.conv1(x) + bottleneck = self.bn1(bottleneck) + # [nt, c//r, h, w] -> [n, t, c//r, h, w] + bottleneck = bottleneck.view((-1, self.n_segment) + bottleneck.size()[1:]) + + F_h = self.feat_ht(bottleneck) # [nt, c, h, w] + F_w = self.feat_tw(bottleneck) # [nt, c, h, w] + + att = 0.5 * (F_h + F_w) + + if self.return_attn: + return att + + y2 = x + x * att + + return y2 + + +class ShiftModule(nn.Module): + def __init__(self, input_channels, n_segment=8, n_div=8, mode='shift'): + """A depth-wise conv on the segment level. + + Args: + input_channels (int): Input channel number. + n_segment (int, optional): Number of input frames.. Defaults to 8. + n_div (int, optional): How many channels to group as a fold.. Defaults to 8. + mode (str, optional): One of "shift", "fixed", "norm". Defaults to 'shift'. + """ + super(ShiftModule, self).__init__() + self.input_channels = input_channels + self.n_segment = n_segment + self.fold_div = n_div + self.fold = self.input_channels // self.fold_div + self.conv = nn.Conv1d(self.fold_div*self.fold, self.fold_div*self.fold, + kernel_size=3, padding=1, groups=self.fold_div*self.fold, + bias=False) + + if mode == 'shift': + self.conv.weight.requires_grad = True + self.conv.weight.data.zero_() + # shift left + self.conv.weight.data[:self.fold, 0, 2] = 1 + # shift right + self.conv.weight.data[self.fold: 2 * self.fold, 0, 0] = 1 + if 2*self.fold < self.input_channels: + self.conv.weight.data[2 * self.fold:, 0, 1] = 1 # fixed + elif mode == 'fixed': + self.conv.weight.requires_grad = True + self.conv.weight.data.zero_() + self.conv.weight.data[:, 0, 1] = 1 # fixed + elif mode == 'norm': + self.conv.weight.requires_grad = True + + + def forward(self, x): + """ + Args: + x (torch.tensor): Input with shape [nt, c, h, w] + """ + nt, c, h, w = x.size() + n_batch = nt // self.n_segment + x = x.view(n_batch, self.n_segment, c, h, w) + # (n, h, w, c, t) + x = x.permute(0, 3, 4, 2, 1) + x = x.contiguous().view(n_batch*h*w, c, self.n_segment) + # (n*h*w, c, t) + x = self.conv(x) + x = x.view(n_batch, h, w, c, self.n_segment) + # (n, t, c, h, w) + x = x.permute(0, 4, 3, 1, 2) + x = x.contiguous().view(nt, c, h, w) + return x + + +class SCConv(nn.Module): + """ + The spatial conv in SIM. Used in SCBottleneck + """ + def __init__(self, inplanes, planes, stride, padding, dilation, groups, pooling_r, norm_layer): + super(SCConv, self).__init__() + self.f_w = nn.Sequential( + nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r), + nn.Conv2d(inplanes, planes, kernel_size=(1,3), stride=1, + padding=(0,padding), dilation=(1,dilation), + groups=groups, bias=False), + norm_layer(planes), nn.ReLU(inplace=True)) + self.f_h = nn.Sequential( + # nn.AvgPool2d(kernel_size=(pooling_r,1), stride=(pooling_r,1)), + nn.Conv2d(inplanes, planes, kernel_size=(3,1), stride=1, + padding=(padding,0), dilation=(dilation,1), + groups=groups, bias=False), + norm_layer(planes), + ) + self.k3 = nn.Sequential( + nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, + padding=padding, dilation=dilation, + groups=groups, bias=False), + norm_layer(planes), + ) + self.k4 = nn.Sequential( + nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, + padding=padding, dilation=dilation, + groups=groups, bias=False), + norm_layer(planes), + ) + + + def forward(self, x): + identity = x + + # sigmoid(identity + k2) + out = torch.sigmoid( + torch.add( + identity, + F.interpolate(self.f_h(self.f_w(x)), identity.size()[2:]) + ) + ) + out = torch.mul(self.k3(x), out) # k3 * sigmoid(identity + k2) + s2t_info = out + out = self.k4(out) # k4 + + return out, s2t_info + + +class SCBottleneck(nn.Module): + """ + SCNet SCBottleneck. Variant for ResNet Bottlenect. + """ + expansion = 4 + pooling_r = 4 # down-sampling rate of the avg pooling layer in the K3 path of SC-Conv. + + def __init__(self, num_segments, inplanes, planes, stride=1, downsample=None, + cardinality=1, bottleneck_width=32, + avd=False, dilation=1, is_first=False, + norm_layer=None): + super(SCBottleneck, self).__init__() + group_width = int(planes * (bottleneck_width / 64.)) * cardinality + self.conv1_a = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) + self.bn1_a = norm_layer(group_width) + self.conv1_b = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) + self.bn1_b = norm_layer(group_width) + self.avd = avd and (stride > 1 or is_first) + self.tim = TIM_Module(group_width, n_segment=num_segments) + self.shift = ShiftModule(group_width, n_segment=num_segments, n_div=8, mode='shift') + self.inplanes = inplanes + self.planes = planes + self.ism = ISM_Module() + self.shift = ShiftModule(group_width, n_segment=num_segments, n_div=8, mode='shift') + + if self.avd: + self.avd_layer = nn.AvgPool2d(3, stride, padding=1) + stride = 1 + + self.k1 = nn.Sequential( + nn.Conv2d( + group_width, group_width, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, + groups=cardinality, bias=False), + norm_layer(group_width), + ) + + self.scconv = SCConv( + group_width, group_width, stride=stride, + padding=dilation, dilation=dilation, + groups=cardinality, pooling_r=self.pooling_r, norm_layer=norm_layer) + + self.conv3 = nn.Conv2d( + group_width * 2, planes * 4, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes*4) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.dilation = dilation + + + def forward(self, x): + """Forward func which splits the input into two branchs a and b. + a: trace features + b: spatial features + """ + residual = x + + out_a = self.relu(self.bn1_a(self.conv1_a(x))) + out_b = self.relu(self.bn1_b(self.conv1_b(x))) + + # spatial representations + out_b, s2t_info = self.scconv(out_b) + out_b = self.relu(out_b) + + # trace features + out_a = self.tim(out_a) + out_a = self.shift(out_a + self.ism(s2t_info)) + out_a = self.relu(self.k1(out_a)) + + if self.avd: + out_a = self.avd_layer(out_a) + out_b = self.avd_layer(out_b) + + out = self.conv3(torch.cat([out_a, out_b], dim=1)) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class SCNet(nn.Module): + def __init__(self, num_segments, block, layers, groups=1, bottleneck_width=32, + num_classes=1000, dilated=False, dilation=1, + deep_stem=False, stem_width=64, avg_down=False, + avd=False, norm_layer=nn.BatchNorm2d): + """SCNet, a variant based on ResNet. + + Args: + num_segments (int): + Number of input frames. + block (class): + Class for the residual block. + layers (list): + Number of layers in each block. + num_classes (int, optional): + Number of classification class.. Defaults to 1000. + dilated (bool, optional): + Whether to apply dilation conv. Defaults to False. + dilation (int, optional): + The dilation parameter in dilation conv. Defaults to 1. + deep_stem (bool, optional): + Whether to replace 7x7 conv in input stem with 3 3x3 conv. Defaults to False. + stem_width (int, optional): + Stem width in conv1 stem. Defaults to 64. + avg_down (bool, optional): + Whether to use AvgPool instead of stride conv when downsampling in the bottleneck. Defaults to False. + avd (bool, optional): + The avd parameter for the block Defaults to False. + norm_layer (class, optional): + Normalization layer. Defaults to nn.BatchNorm2d. + """ + self.cardinality = groups + self.bottleneck_width = bottleneck_width + # ResNet-D params + self.inplanes = stem_width*2 if deep_stem else 64 + self.avg_down = avg_down + self.avd = avd + self.num_segments = num_segments + + super(SCNet, self).__init__() + conv_layer = nn.Conv2d + if deep_stem: + self.conv1 = nn.Sequential( + conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(inplace=True), + conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(inplace=True), + conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False), + ) + else: + self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) + if dilated or dilation == 4: + self.layer3 = self._make_layer(block, 256, layers[2], stride=1, + dilation=2, norm_layer=norm_layer) + self.layer4 = self._make_layer(block, 512, layers[3], stride=1, + dilation=4, norm_layer=norm_layer) + elif dilation==2: + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilation=1, norm_layer=norm_layer) + self.layer4 = self._make_layer(block, 512, layers[3], stride=1, + dilation=2, norm_layer=norm_layer) + else: + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + norm_layer=norm_layer) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + norm_layer=norm_layer) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, norm_layer): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None, + is_first=True): + """ + Core function to build layers. + """ + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + down_layers = [] + if self.avg_down: + if dilation == 1: + down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride, + ceil_mode=True, count_include_pad=False)) + else: + down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1, + ceil_mode=True, count_include_pad=False)) + down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=1, bias=False)) + else: + down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False)) + down_layers.append(norm_layer(planes * block.expansion)) + downsample = nn.Sequential(*down_layers) + + layers = [] + if dilation == 1 or dilation == 2: + layers.append(block(self.num_segments, self.inplanes, planes, stride, downsample=downsample, + cardinality=self.cardinality, + bottleneck_width=self.bottleneck_width, + avd=self.avd, dilation=1, is_first=is_first, + norm_layer=norm_layer)) + elif dilation == 4: + layers.append(block(self.num_segments, self.inplanes, planes, stride, downsample=downsample, + cardinality=self.cardinality, + bottleneck_width=self.bottleneck_width, + avd=self.avd, dilation=2, is_first=is_first, + norm_layer=norm_layer)) + else: + raise RuntimeError("=> unknown dilation size: {}".format(dilation)) + + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.num_segments, self.inplanes, planes, + cardinality=self.cardinality, + bottleneck_width=self.bottleneck_width, + avd=self.avd, dilation=dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + + def features(self, input): + x = self.conv1(input) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + + def logits(self, features): + x = self.avgpool(features) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + + def forward(self, input): + x = self.features(input) + x = self.logits(x) + return x + + +def scnet50_v1d(num_segments, pretrained=False, **kwargs): + """ + SCNet backbone, which is based on ResNet-50 + Args: + num_segments (int): + Number of input frames. + pretrained (bool, optional): + Whether to load pretrained weights. + """ + model = SCNet(num_segments, SCBottleneck, [3, 4, 6, 3], + deep_stem=True, stem_width=32, avg_down=True, + avd=True, **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['scnet50_v1d']), strict=False) + + return model diff --git a/training/detectors/tall_detector.py b/training/detectors/tall_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..6b8ea89468e4c490039007158048d1bcabab1154 --- /dev/null +++ b/training/detectors/tall_detector.py @@ -0,0 +1,884 @@ +""" +# author: Kangran Zhao +# email: kangranzhao@link.cuhk.edu.cn +# date: 2023-0822 +# description: Class for the TALLDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{xu2023tall, + title={TALL: Thumbnail Layout for Deepfake Video Detection}, + author={Xu, Yuting and Liang, Jian and Jia, Gengyun and Yang, Ziming and Zhang, Yanhao and He, Ran}, + booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, + pages={22658--22668}, + year={2023} +} +""" + +import logging + +import math +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from detectors import DETECTOR +from einops import rearrange +from loss import LOSSFUNC +from metrics.base_metrics_class import calculate_metrics_for_train +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from torch.hub import load_state_dict_from_url + +from .base_detector import AbstractDetector + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='tall') +class TALLDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.model = self.build_backbone(config) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + model_kwargs = dict(num_classes=config['num_classes'], embed_dim=config['embed_dim'], + mlp_ratio=config['mlp_ratio'], patch_size=config['patch_size'], + window_size=config['window_size'], depths=config['depths'], + num_heads=config['num_heads'], ape=config['ape'], + thumbnail_rows=config['thumbnail_rows'], drop_rate=config['drop_rate'], + drop_path_rate=config['drop_path_rate'], use_checkpoint=False, bottleneck=False, + duration=config['clip_size']) + default_cfg = { + 'url': config['pretrained'], + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', } + backbone = SwinTransformer(img_size=config['resolution'], **model_kwargs) + backbone.default_cfg = default_cfg + load_pretrained(backbone, num_classes=config['num_classes'], in_chans=model_kwargs.get('in_chans', 3), + filter_fn=_conv_filter, img_size=config['resolution'], pretrained_window_size=7, + pretrained_model='') + + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + bs, t, c, h, w = data_dict['image'].shape + inputs = data_dict['image'].view(bs, t * c, h, w) + pred = self.model(inputs) + + return pred + + def classifier(self, features: torch.tensor): + pass + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'].long() + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + pred = self.features(data_dict) + prob = torch.softmax(pred, dim=1)[:, 1] + pred_dict = {'cls': pred, 'prob': prob, 'feat': prob} + + return pred_dict + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, bottleneck=False, use_checkpoint=False): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_checkpoint = use_checkpoint + + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward_attn(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + return x + + def forward_mlp(self, x): + return self.drop_path(self.mlp(self.norm2(x))) + + def forward(self, x): + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_attn, x) + else: + x = self.forward_attn(x) + x = shortcut + self.drop_path(x) + + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_mlp, x) + else: + x = x + self.forward_mlp(x) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + bottleneck=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + bottleneck=bottleneck if i == depth - 1 else False, + use_checkpoint=use_checkpoint) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + # img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, duration=8, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, thumbnail_rows=1, bottleneck=False, **kwargs): + super().__init__() + + self.duration = duration # 4 + self.num_classes = num_classes # 2 + self.num_layers = len(depths) # [2, 2, 18, 2] + self.embed_dim = embed_dim # 128 + self.ape = ape # True + self.patch_norm = patch_norm # False + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio # 4 = default + self.thumbnail_rows = thumbnail_rows # 2 + + self.img_size = img_size # 224 + self.window_size = [window_size for _ in depths] if not isinstance(window_size, list) else window_size + # self.image_mode = True # [14, 14, 14, 7] + + self.frame_padding = self.duration % thumbnail_rows # 0 + if self.frame_padding != 0: + self.frame_padding = self.thumbnail_rows - self.frame_padding + self.duration += self.frame_padding + + # split image into non-overlapping patches + thumbnail_dim = (thumbnail_rows, self.duration // thumbnail_rows) # (2, 2) + thumbnail_size = (img_size * thumbnail_dim[0], img_size * thumbnail_dim[1]) + + self.patch_embed = PatchEmbed( + img_size=(img_size, img_size), patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches # 16 + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution # [56, 56] + + # absolute position embedding + if self.ape: # True + self.frame_pos_embed = nn.Parameter(torch.zeros(1, self.duration, embed_dim)) + trunc_normal_(self.frame_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=self.window_size[i_layer], + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + bottleneck=bottleneck) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed', 'frame_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def create_thumbnail(self, x): + # import pdb;pdb.set_trace() + input_size = x.shape[-2:] + if input_size != to_2tuple(self.img_size): + x = nn.functional.interpolate(x, size=self.img_size, mode='bilinear') + x = rearrange(x, 'b (th tw c) h w -> b c (th h) (tw w)', th=self.thumbnail_rows, c=3) + return x + + def pad_frames(self, x): + frame_num = self.duration - self.frame_padding + x = x.view((-1, 3 * frame_num) + x.size()[2:]) + x_padding = torch.zeros((x.shape[0], 3 * self.frame_padding) + x.size()[2:]).cuda() + x = torch.cat((x, x_padding), dim=1) + assert x.shape[1] == 3 * self.duration, 'frame number %d not the same as adjusted input size %d' % ( + x.shape[1], 3 * self.duration) + + return x + + # need to find a better way to do this, maybe torch.fold? + def create_image_pos_embed(self): + img_rows, img_cols = self.patches_resolution # (56, 56) + _, _, T = self.frame_pos_embed.shape # (1, 4, embed) + rows = img_rows // self.thumbnail_rows # 28 + cols = img_cols // (self.duration // self.thumbnail_rows) # 28 + img_pos_embed = torch.zeros(img_rows, img_cols, T).cuda() # [56, 56, embed] + for i in range(self.duration): + r_indx = (i // self.thumbnail_rows) * rows + c_indx = (i % self.thumbnail_rows) * cols + img_pos_embed[r_indx:r_indx + rows, c_indx:c_indx + cols] = self.frame_pos_embed[0, i] + + return img_pos_embed.reshape(-1, T) # [56*56, embed] + + def forward_features(self, x): + if self.frame_padding > 0: + x = self.pad_frames(x) + else: + x = x.view((-1, 3 * self.duration) + x.size()[2:]) + + x = self.create_thumbnail(x) + x = nn.functional.interpolate(x, size=self.img_size, mode='bilinear') # [B, 3, 224, 224] + + x = self.patch_embed(x) # [B, 56*56, embed] + if self.ape: + img_pos_embed = self.create_image_pos_embed() + x = x + img_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + +def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, img_size=224, num_patches=196, + pretrained_window_size=7, pretrained_model="", strict=True): + if cfg is None: + cfg = getattr(model, 'default_cfg') + if cfg is None or 'url' not in cfg or not cfg['url']: + _logger.warning("Pretrained model URL is invalid, using random initialization.") + return + + if len(pretrained_model) == 0: + # state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu') + state_dict = load_state_dict_from_url(cfg['url'], map_location='cpu') + else: + try: + state_dict = load_state_dict(pretrained_model)['model'] + except: + state_dict = load_state_dict(pretrained_model) + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + if in_chans == 1: + conv1_name = cfg['first_conv'] + _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name) + conv1_weight = state_dict[conv1_name + '.weight'] + conv1_type = conv1_weight.dtype + conv1_weight = conv1_weight.float() + O, I, J, K = conv1_weight.shape + if I > 3: + assert conv1_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K) + conv1_weight = conv1_weight.sum(dim=2, keepdim=False) + else: + conv1_weight = conv1_weight.sum(dim=1, keepdim=True) + conv1_weight = conv1_weight.to(conv1_type) + state_dict[conv1_name + '.weight'] = conv1_weight + elif in_chans != 3: + conv1_name = cfg['first_conv'] + conv1_weight = state_dict[conv1_name + '.weight'] + conv1_type = conv1_weight.dtype + conv1_weight = conv1_weight.float() + O, I, J, K = conv1_weight.shape + if I != 3: + _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name) + del state_dict[conv1_name + '.weight'] + strict = False + else: + _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name) + repeat = int(math.ceil(in_chans / 3)) + conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv1_weight *= (3 / float(in_chans)) + conv1_weight = conv1_weight.to(conv1_type) + state_dict[conv1_name + '.weight'] = conv1_weight + + classifier_name = cfg['classifier'] + if num_classes == 1000 and cfg['num_classes'] == 1001: + # special case for imagenet trained models with extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[1:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[1:] + elif num_classes != cfg['num_classes']: # and len(pretrained_model) == 0: + # completely discard fully connected for all other differences between pretrained and created model + del state_dict['model'][classifier_name + '.weight'] + del state_dict['model'][classifier_name + '.bias'] + strict = False + ''' + ## Resizing the positional embeddings in case they don't match + if img_size != cfg['input_size'][1]: + pos_embed = state_dict['pos_embed'] + cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1) + other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2) + new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest') + new_pos_embed = new_pos_embed.transpose(1, 2) + new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1) + state_dict['pos_embed'] = new_pos_embed + ''' + + # remove window_size related parameters + window_size = (model.window_size)[0] + print(pretrained_window_size, window_size) + + new_state_dict = state_dict['model'].copy() + for key in state_dict['model']: + if 'attn_mask' in key: + del new_state_dict[key] + + if 'relative_position_index' in key: + del new_state_dict[key] + + # resize it + if 'relative_position_bias_table' in key: + pretrained_table = state_dict['model'][key] + pretrained_table_size = int(math.sqrt(pretrained_table.shape[0])) + table_size = int(math.sqrt(model.state_dict()[key].shape[0])) + if pretrained_table_size != table_size: + table = pretrained_table.permute(1, 0).view(1, -1, pretrained_table_size, pretrained_table_size) + table = nn.functional.interpolate(table, size=table_size, mode='bilinear') + table = table.view(-1, table_size * table_size).permute(1, 0) + new_state_dict[key] = table + + for key in model.state_dict(): + if 'bottleneck_norm' in key: + attn_key = key.replace('bottleneck_norm', 'norm1') + # print (key, attn_key) + new_state_dict[key] = new_state_dict[attn_key] + + print('loading weights....') + ## Loading the weights + model.load_state_dict(new_state_dict, strict=False) + + +def _conv_filter(state_dict, patch_size=4): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k: + if v.shape[-1] != patch_size: + patch_size = v.shape[-1] + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + return out_dict diff --git a/training/detectors/temp.py b/training/detectors/temp.py new file mode 100644 index 0000000000000000000000000000000000000000..1a683b0afd6421d87f4e8a3a7fbcd7d7d4fbe0ef --- /dev/null +++ b/training/detectors/temp.py @@ -0,0 +1,142 @@ + +# #### +import torch +import torch.nn as nn +import torch.nn.functional as F +from vit_pytorch import ViT +import clip # Import the CLIP library + +class ViT_VAE_Detector(nn.Module): + def __init__(self, image_size=224, patch_size=16, num_classes=3, clip_model_name="ViT-B/32"): + super().__init__() + # 1. ViT feature extractor (1024 dimensions) + + # 2. Load the pretrained CLIP model as the semantic anchor + + # 3. VAE disentanglement encoder (semantics + fingerprint) + self.feat_dim = 1024 + self.semantic_dim = 512 # semantic feature dimension + self.fingerprint_dim = 512 # fingerprint feature dimension + + # Semantic encoder (outputs 512 dimensions and aligns with CLIP features directly or through a projection layer) + self.semantic_encoder = nn.Sequential( + nn.Linear(self.feat_dim, 768), + nn.ReLU(), + nn.Linear(768, self.semantic_dim) + ) + # Projection layer for semantic features (used for alignment when semantic and CLIP dimensions differ) + if self.semantic_dim != self.clip_feat_dim: + self.sem_proj_to_clip = nn.Linear(self.semantic_dim, self.clip_feat_dim) + else: + self.sem_proj_to_clip = nn.Identity() # No projection is needed when dimensions already match + + # Fingerprint encoder + self.fingerprint_encoder = nn.Sequential( + nn.Linear(self.feat_dim, 768), + nn.ReLU(), + nn.Linear(768, self.fingerprint_dim) + ) + + # 4. Decoder and classifier head + self.decoder = nn.Sequential( # Simplified decoder that keeps the reconstruction in 1024 dimensions + nn.Linear(self.feat_dim, 768), + nn.ReLU(), + nn.Linear(768, self.feat_dim) + ) + self.classifier = nn.Linear(self.fingerprint_dim, num_classes) + + # VAE mean and variance layers (simplified to direct outputs without intermediate layers) + self.semantic_mu = nn.Linear(self.semantic_dim, self.semantic_dim) + self.semantic_logvar = nn.Linear(self.semantic_dim, self.semantic_dim) + self.fingerprint_mu = nn.Linear(self.fingerprint_dim, self.fingerprint_dim) + self.fingerprint_logvar = nn.Linear(self.fingerprint_dim, self.fingerprint_dim) + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, x): + # Step 1: extract ViT and CLIP features as the semantic anchor + vit_feat = self.vit(x) # (batch_size, 1024) + with torch.no_grad(): # do not update CLIP + clip_feat = self.clip_model.encode_image(x) # (batch_size, clip_feat_dim) + clip_feat = F.normalize(clip_feat, dim=-1) # normalize CLIP features + + # Step 2: disentangle semantic and fingerprint features + z_semantic_raw = self.semantic_encoder(vit_feat) # (batch_size, 512) + sem_mu = self.semantic_mu(z_semantic_raw) + sem_logvar = self.semantic_logvar(z_semantic_raw) + z_semantic = self.reparameterize(sem_mu, sem_logvar) # semantic feature + + z_fingerprint_raw = self.fingerprint_encoder(vit_feat) # (batch_size, 512) + fing_mu = self.fingerprint_mu(z_fingerprint_raw) + fing_logvar = self.fingerprint_logvar(z_fingerprint_raw) + z_fingerprint = self.reparameterize(fing_mu, fing_logvar) # fingerprint feature + + # Step 3: project semantic features into the CLIP space for regularization + z_semantic_clip = self.sem_proj_to_clip(z_semantic) # align CLIP dimensions + z_semantic_clip = F.normalize(z_semantic_clip, dim=-1) # normalize to ensure a consistent scale + + # Step 4: reconstruction and classification + recon_feat = self.decoder(torch.cat([z_semantic, z_fingerprint], dim=-1)) # reconstruct from concatenation + logits = self.classifier(z_fingerprint) + + return { + "logits": logits, + "z_semantic": z_semantic, + "z_fingerprint": z_fingerprint, + "z_semantic_clip": z_semantic_clip, # semantic features projected into CLIP space + "clip_feat": clip_feat, # CLIP features (semantic anchor) + "recon_feat": recon_feat, + + "sem_mu": sem_mu, + "sem_logvar": sem_logvar, + "fing_mu": fing_mu, + "fing_logvar": fing_logvar, + "vit_feat": vit_feat + } + +# Loss function: add CLIP semantic regularization +class DeepfakeVAELoss(nn.Module): + def __init__(self, kl_weight=0.001, recon_weight=1.0, clip_sem_weight=0.5): + super().__init__() + self.kl_weight = kl_weight + self.recon_weight = recon_weight + self.clip_sem_weight = clip_sem_weight # weight of the CLIP semantic regularization + self.ce_loss = nn.CrossEntropyLoss() + self.mse_loss = nn.MSELoss() + + def kl_divergence(self, mu, logvar): + return -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(), dim=1).mean() + + def forward(self, model_outputs, targets): + # 1. Original loss + loss_ce = self.ce_loss(model_outputs["logits"], targets) + loss_recon = self.mse_loss(model_outputs["recon_feat"], model_outputs["vit_feat"]) + loss_kl_sem = self.kl_divergence(model_outputs["sem_mu"], model_outputs["sem_logvar"]) + loss_kl_fing = self.kl_divergence(model_outputs["fing_mu"], model_outputs["fing_logvar"]) + loss_kl = loss_kl_sem + loss_kl_fing + + # 2. Added: CLIP semantic regularization (MSE loss) + # Make the projected semantic features as close as possible to CLIP features + loss_clip_sem = self.mse_loss( + model_outputs["z_semantic_clip"], # disentangled semantic features projected into CLIP space + model_outputs["clip_feat"] # semantic features extracted by CLIP (anchor) + ) + + # 3. Total loss + total_loss = ( + loss_ce + + self.recon_weight * loss_recon + + self.kl_weight * loss_kl + + self.clip_sem_weight * loss_clip_sem # add CLIP regularization + ) + return { + "total_loss": total_loss, + "loss_ce": loss_ce, + "loss_recon": loss_recon, + "loss_kl": loss_kl, + "loss_clip_sem": loss_clip_sem # monitor the effect of semantic regularization + } + diff --git a/training/detectors/timesformer_detector.py b/training/detectors/timesformer_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..83441083f023b912a6d79f185a80592ece14c992 --- /dev/null +++ b/training/detectors/timesformer_detector.py @@ -0,0 +1,111 @@ +""" +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the TimesformerDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{bertasius2021space, + title={Is space-time attention all you need for video understanding?}, + author={Bertasius, Gedas and Wang, Heng and Torresani, Lorenzo}, + booktitle={ICML}, + volume={2}, + number={3}, + pages={4}, + year={2021} +} +""" + +import logging +import torch +import torch.nn as nn +from detectors import DETECTOR +from loss import LOSSFUNC +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='timesformer') +class TimeSformerDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + # self.fc_norm = nn.LayerNorm(768) + # self.temporal_module = self.build_temporal_module(config) + self.head = nn.Linear(768, 2) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + from transformers import TimesformerModel + backbone = TimesformerModel.from_pretrained(config['pretrained']) + # for name, param in backbone.named_parameters(): + # print('{}: {}'.format(name, param.requires_grad)) + # num_param = sum(p.numel() for p in backbone.parameters() if p.requires_grad) + # num_total_param = sum(p.numel() for p in backbone.parameters()) + # print('Number of total parameters: {}, tunable parameters: {}'.format(num_total_param, num_param)) + return backbone + + def build_temporal_module(self, config): + return nn.LSTM(input_size=2048, hidden_size=512, num_layers=3, batch_first=True) + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + # b, t, c, h, w = data_dict['image'].shape + # frame_input = data_dict['image'].reshape(-1, c, h, w) + # # get frame-level features + # frame_level_features = self.backbone.features(frame_input) + # frame_level_features = F.adaptive_avg_pool2d(frame_level_features, (1, 1)).reshape(b, t, -1) + # # get video-level features + # video_level_features = self.temporal_module(frame_level_features)[0][:, -1, :] + + outputs = self.backbone(data_dict['image'], output_hidden_states=True) + video_level_features = outputs[0][:, 0] + return video_level_features + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + features = self.features(data_dict) + pred = self.classifier(features) + prob = torch.softmax(pred, dim=1)[:, 1] + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + + return pred_dict diff --git a/training/detectors/ucf_detector.py b/training/detectors/ucf_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..c45ddcf07522f5b62329e54eb04f90ec90251f00 --- /dev/null +++ b/training/detectors/ucf_detector.py @@ -0,0 +1,466 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the UCFDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@article{yan2023ucf, + title={UCF: Uncovering Common Features for Generalizable Deepfake Detection}, + author={Yan, Zhiyuan and Zhang, Yong and Fan, Yanbo and Wu, Baoyuan}, + journal={arXiv preprint arXiv:2304.13949}, + year={2023} +} +''' + +import os +import datetime +import logging +import random +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='ucf') +class UCFDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.num_classes = config['backbone_config']['num_classes'] + self.encoder_feat_dim = config['encoder_feat_dim'] + self.half_fingerprint_dim = self.encoder_feat_dim//2 + + self.encoder_f = self.build_backbone(config) + self.encoder_c = self.build_backbone(config) + + self.loss_func = self.build_loss(config) + self.prob, self.label = [], [] + self.correct, self.total = 0, 0 + + # basic function + self.lr = nn.LeakyReLU(inplace=True) + self.do = nn.Dropout(0.2) + self.pool = nn.AdaptiveAvgPool2d(1) + + # conditional gan + self.con_gan = Conditional_UNet() + + # head + specific_task_number = len(config['train_dataset']) + 1 # default: 5 in FF++ + self.head_spe = Head( + in_f=self.half_fingerprint_dim, + hidden_dim=self.encoder_feat_dim, + out_f=specific_task_number + ) + self.head_sha = Head( + in_f=self.half_fingerprint_dim, + hidden_dim=self.encoder_feat_dim, + out_f=self.num_classes + ) + self.block_spe = Conv2d1x1( + in_f=self.encoder_feat_dim, + hidden_dim=self.half_fingerprint_dim, + out_f=self.half_fingerprint_dim + ) + self.block_sha = Conv2d1x1( + in_f=self.encoder_feat_dim, + hidden_dim=self.half_fingerprint_dim, + out_f=self.half_fingerprint_dim + ) + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + # if donot load the pretrained weights, fail to get good results + state_dict = torch.load(config['pretrained']) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model successfully!') + return backbone + + def build_loss(self, config): + cls_loss_class = LOSSFUNC[config['loss_func']['cls_loss']] + spe_loss_class = LOSSFUNC[config['loss_func']['spe_loss']] + con_loss_class = LOSSFUNC[config['loss_func']['con_loss']] + rec_loss_class = LOSSFUNC[config['loss_func']['rec_loss']] + cls_loss_func = cls_loss_class() + spe_loss_func = spe_loss_class() + con_loss_func = con_loss_class(margin=3.0) + rec_loss_func = rec_loss_class() + loss_func = { + 'cls': cls_loss_func, + 'spe': spe_loss_func, + 'con': con_loss_func, + 'rec': rec_loss_func, + } + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + cat_data = data_dict['image'] + # encoder + f_all = self.encoder_f.features(cat_data) + c_all = self.encoder_c.features(cat_data) + feat_dict = {'forgery': f_all, 'content': c_all} + return feat_dict + + def classifier(self, features: torch.tensor) -> torch.tensor: + # classification, multi-task + # split the features into the specific and common forgery + f_spe = self.block_spe(features) + f_share = self.block_sha(features) + return f_spe, f_share + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + if 'label_spe' in data_dict and 'recontruction_imgs' in pred_dict: + return self.get_train_losses(data_dict, pred_dict) + else: # test mode + return self.get_test_losses(data_dict, pred_dict) + + def get_train_losses(self, data_dict: dict, pred_dict: dict) -> dict: + # get combined, real, fake imgs + cat_data = data_dict['image'] + real_img, fake_img = cat_data.chunk(2, dim=0) + # get the reconstruction imgs + reconstruction_image_1, \ + reconstruction_image_2, \ + self_reconstruction_image_1, \ + self_reconstruction_image_2 \ + = pred_dict['recontruction_imgs'] + # get label + label = data_dict['label'] + label_spe = data_dict['label_spe'] + # get pred + pred = pred_dict['cls'] + pred_spe = pred_dict['cls_spe'] + + # 1. classification loss for common features + loss_sha = self.loss_func['cls'](pred, label) + + # 2. classification loss for specific features + loss_spe = self.loss_func['spe'](pred_spe, label_spe) + + # 3. reconstruction loss + self_loss_reconstruction_1 = self.loss_func['rec'](fake_img, self_reconstruction_image_1) + self_loss_reconstruction_2 = self.loss_func['rec'](real_img, self_reconstruction_image_2) + cross_loss_reconstruction_1 = self.loss_func['rec'](fake_img, reconstruction_image_2) + cross_loss_reconstruction_2 = self.loss_func['rec'](real_img, reconstruction_image_1) + loss_reconstruction = \ + self_loss_reconstruction_1 + self_loss_reconstruction_2 + \ + cross_loss_reconstruction_1 + cross_loss_reconstruction_2 + + # 4. constrative loss + common_features = pred_dict['feat'] + specific_features = pred_dict['feat_spe'] + loss_con = self.loss_func['con'](common_features, specific_features, label_spe) + + # 5. total loss + loss = loss_sha + 0.1*loss_spe + 0.3*loss_reconstruction + 0.05*loss_con + loss_dict = { + 'overall': loss, + 'common': loss_sha, + 'specific': loss_spe, + 'reconstruction': loss_reconstruction, + 'contrastive': loss_con, + } + return loss_dict + + def get_test_losses(self, data_dict: dict, pred_dict: dict) -> dict: + # get label + label = data_dict['label'] + # get pred + pred = pred_dict['cls'] + # for test mode, only classification loss for common features + loss = self.loss_func['cls'](pred, label) + loss_dict = {'common': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + def get_accracy(label, output): + _, prediction = torch.max(output, 1) # argmax + correct = (prediction == label).sum().item() + accuracy = correct / prediction.size(0) + return accuracy + + # get pred and label + label = data_dict['label'] + pred = pred_dict['cls'] + label_spe = data_dict['label_spe'] + pred_spe = pred_dict['cls_spe'] + + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + acc_spe = get_accracy(label_spe.detach(), pred_spe.detach()) + metric_batch_dict = {'acc': acc, 'acc_spe': acc_spe, 'auc': auc, 'eer': eer, 'ap': ap} + # we dont compute the video-level metrics for training + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # split the features into the content and forgery + features = self.features(data_dict) + forgery_features, content_features = features['forgery'], features['content'] + # get the prediction by classifier (split the common and specific forgery) + f_spe, f_share = self.classifier(forgery_features) + + if inference: + # inference only consider share loss + out_sha, sha_feat = self.head_sha(f_share) + out_spe, spe_feat = self.head_spe(f_spe) + prob_sha = torch.softmax(out_sha, dim=1)[:, 1] + self.prob.append( + prob_sha + .detach() + .squeeze() + .cpu() + .numpy() + ) + self.label.append( + data_dict['label'] + .detach() + .squeeze() + .cpu() + .numpy() + ) + # deal with acc + _, prediction_class = torch.max(out_sha, 1) + common_label = (data_dict['label'] >= 1) + correct = (prediction_class == common_label).sum().item() + self.correct += correct + self.total += data_dict['label'].size(0) + + pred_dict = {'cls': out_sha, 'feat': sha_feat} + return pred_dict + + bs = f_share.size(0) + # using idx aug in the training mode + aug_idx = random.random() + if aug_idx < 0.7: + # real + idx_list = list(range(0, bs//2)) + random.shuffle(idx_list) + f_share[0: bs//2] = f_share[idx_list] + # fake + idx_list = list(range(bs//2, bs)) + random.shuffle(idx_list) + f_share[bs//2: bs] = f_share[idx_list] + + # concat spe and share to obtain new_f_all + f_all = torch.cat((f_spe, f_share), dim=1) + + # reconstruction loss + f2, f1 = f_all.chunk(2, dim=0) + c2, c1 = content_features.chunk(2, dim=0) + + # ==== self reconstruction ==== # + # f1 + c1 -> f11, f11 + c1 -> near~I1 + self_reconstruction_image_1 = self.con_gan(f1, c1) + + # f2 + c2 -> f2, f2 + c2 -> near~I2 + self_reconstruction_image_2 = self.con_gan(f2, c2) + + # ==== cross combine ==== # + reconstruction_image_1 = self.con_gan(f1, c2) + reconstruction_image_2 = self.con_gan(f2, c1) + + # head for spe and sha + out_spe, spe_feat = self.head_spe(f_spe) + out_sha, sha_feat = self.head_sha(f_share) + + # get the probability of the pred + prob_sha = torch.softmax(out_sha, dim=1)[:, 1] + prob_spe = torch.softmax(out_spe, dim=1)[:, 1] + + # build the prediction dict for each output + pred_dict = { + 'cls': out_sha, + 'prob': prob_sha, + 'feat': sha_feat, + 'cls_spe': out_spe, + 'prob_spe': prob_spe, + 'feat_spe': spe_feat, + 'feat_content': content_features, + 'recontruction_imgs': ( + reconstruction_image_1, + reconstruction_image_2, + self_reconstruction_image_1, + self_reconstruction_image_2 + ) + } + return pred_dict + +def sn_double_conv(in_channels, out_channels): + return nn.Sequential( + nn.utils.spectral_norm( + nn.Conv2d(in_channels, in_channels, 3, padding=1)), + nn.utils.spectral_norm( + nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=2)), + nn.LeakyReLU(0.2, inplace=True) + ) + +def r_double_conv(in_channels, out_channels): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, 3, padding=1), + nn.ReLU(inplace=True) + ) + +class AdaIN(nn.Module): + def __init__(self, eps=1e-5): + super().__init__() + self.eps = eps + # self.l1 = nn.Linear(num_classes, in_channel*4, bias=True) #bias is good :) + + def c_norm(self, x, bs, ch, eps=1e-7): + # assert isinstance(x, torch.cuda.FloatTensor) + x_var = x.var(dim=-1) + eps + x_std = x_var.sqrt().view(bs, ch, 1, 1) + x_mean = x.mean(dim=-1).view(bs, ch, 1, 1) + return x_std, x_mean + + def forward(self, x, y): + assert x.size(0)==y.size(0) + size = x.size() + bs, ch = size[:2] + x_ = x.view(bs, ch, -1) + y_ = y.reshape(bs, ch, -1) + x_std, x_mean = self.c_norm(x_, bs, ch, eps=self.eps) + y_std, y_mean = self.c_norm(y_, bs, ch, eps=self.eps) + out = ((x - x_mean.expand(size)) / x_std.expand(size)) \ + * y_std.expand(size) + y_mean.expand(size) + return out + +class Conditional_UNet(nn.Module): + + def init_weight(self, std=0.2): + for m in self.modules(): + cn = m.__class__.__name__ + if cn.find('Conv') != -1: + m.weight.data.normal_(0., std) + elif cn.find('Linear') != -1: + m.weight.data.normal_(1., std) + m.bias.data.fill_(0) + + def __init__(self): + super(Conditional_UNet, self).__init__() + + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.maxpool = nn.MaxPool2d(2) + self.dropout = nn.Dropout(p=0.3) + #self.dropout_half = HalfDropout(p=0.3) + + self.adain3 = AdaIN() + self.adain2 = AdaIN() + self.adain1 = AdaIN() + + self.dconv_up3 = r_double_conv(512, 256) + self.dconv_up2 = r_double_conv(256, 128) + self.dconv_up1 = r_double_conv(128, 64) + + self.conv_last = nn.Conv2d(64, 3, 1) + self.up_last = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) + self.activation = nn.Tanh() + #self.init_weight() + + def forward(self, c, x): # c is the style and x is the content + x = self.adain3(x, c) + x = self.upsample(x) + x = self.dropout(x) + x = self.dconv_up3(x) + c = self.upsample(c) + c = self.dropout(c) + c = self.dconv_up3(c) + + x = self.adain2(x, c) + x = self.upsample(x) + x = self.dropout(x) + x = self.dconv_up2(x) + c = self.upsample(c) + c = self.dropout(c) + c = self.dconv_up2(c) + + x = self.adain1(x, c) + x = self.upsample(x) + x = self.dropout(x) + x = self.dconv_up1(x) + + x = self.conv_last(x) + out = self.up_last(x) + + return self.activation(out) + +class MLP(nn.Module): + def __init__(self, in_f, hidden_dim, out_f): + super(MLP, self).__init__() + self.pool = nn.AdaptiveAvgPool2d(1) + self.mlp = nn.Sequential(nn.Linear(in_f, hidden_dim), + nn.LeakyReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim), + nn.LeakyReLU(inplace=True), + nn.Linear(hidden_dim, out_f),) + + def forward(self, x): + x = self.pool(x) + x = self.mlp(x) + return x + +class Conv2d1x1(nn.Module): + def __init__(self, in_f, hidden_dim, out_f): + super(Conv2d1x1, self).__init__() + self.conv2d = nn.Sequential(nn.Conv2d(in_f, hidden_dim, 1, 1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 1, 1), + nn.LeakyReLU(inplace=True), + nn.Conv2d(hidden_dim, out_f, 1, 1),) + + def forward(self, x): + x = self.conv2d(x) + return x + +class Head(nn.Module): + def __init__(self, in_f, hidden_dim, out_f): + super(Head, self).__init__() + self.do = nn.Dropout(0.2) + self.pool = nn.AdaptiveAvgPool2d(1) + self.mlp = nn.Sequential(nn.Linear(in_f, hidden_dim), + nn.LeakyReLU(inplace=True), + nn.Linear(hidden_dim, out_f),) + + def forward(self, x): + bs = x.size()[0] + x_feat = self.pool(x).view(bs, -1) + x = self.mlp(x_feat) + x = self.do(x) + return x, x_feat diff --git a/training/detectors/uia_vit_detector.py b/training/detectors/uia_vit_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..e0d2beaba6b489fe70c2be81aaa9c7c3c91f9f9c --- /dev/null +++ b/training/detectors/uia_vit_detector.py @@ -0,0 +1,418 @@ +""" +# author: Kangran ZHAO +# email: kangranzhao@link.cuhk.edu.cn +# date: 2024-0410 +# description: Class for the UIA-ViT Detector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{zhuang2020UIA, + title={UIA-ViT: Unsupervised Inconsistency-Aware Method based on Vision Transformer for Face Forgery Detection}, + author={Zhuang, Wanyi and Chu, Qi and Tan, Zhentao and Liu, Qiankun and Yuan, Haojie and Miao, Changtao and Luo, Zixiang and Yu, Nenghai}, + booktitle={European Conference on Computer Vision (ECCV)}, + year={2022}, +} + +Codes are modified based on GitHub repo https://github.com/wany0824/UIA-ViT +""" +from functools import partial + +import torch +import torch.nn as nn +from detectors import DETECTOR +from loss import LOSSFUNC +from metrics.base_metrics_class import calculate_metrics_for_train +from sklearn.covariance import LedoitWolf +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from .base_detector import AbstractDetector + + +@DETECTOR.register_module(module_name='uia_vit') +class UIAViTDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + + self.config = config + self.batch_per_epoch = config["batch_per_epoch"] + self.num_epoch = config["nEpochs"] + + self.batch_cnt = 0 + self.real_feature_list, self.fake_feature_list = [], [] + self.real_inv_covariance, self.fake_inv_covariance = None, None + self.real_feature_mean, self.fake_feature_mean = None, None + + self.model = self.build_backbone(config) + self.loss_func = self.build_loss(config) + self.loss_weight = config["loss_func"]["weights"] + + def build_backbone(self, config): + model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes=2) + state_dict = torch.hub.load_state_dict_from_url(config["pretrained"]) + del state_dict["head.bias"], state_dict["head.weight"] + model.load_state_dict(state_dict, strict=False) + + return model + + def build_loss(self, config): + cls_loss_class = LOSSFUNC[config["loss_func"]["cls_loss"]] + pcl_loss_class = LOSSFUNC[config["loss_func"]["pcl_loss"]] + cls_loss_func = cls_loss_class() + pcl_loss_func = pcl_loss_class(c_real=self.model.c_real, c_fake=self.model.c_fake, c_cross=self.model.c_cross) + + return {"cls": cls_loss_func, "pcl": pcl_loss_func} + + def features(self, data_dict: dict) -> torch.tensor: + pass + + def classifier(self, features: torch.tensor) -> torch.tensor: + pass # do not overwrite this, since classifier structure has been written in self.forward() + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict["label"] + pred = pred_dict["cls"] + ce_loss = self.loss_func["cls"](pred, label) + if self.batch_cnt > self.batch_per_epoch and self.model.training: + pcl_loss = self.loss_func["pcl"](pred_dict["attention_map_real"], + pred_dict["attention_map_fake"], + pred_dict["feat"], + self.real_feature_mean, + self.real_inv_covariance, + self.fake_feature_mean, + self.fake_inv_covariance, + data_dict["label"]) + overall_loss = ce_loss + \ + self.loss_weight[0] * pcl_loss + \ + self.loss_weight[1] * (1 / torch.abs(self.model.c_real) + 1 / torch.abs(self.model.c_fake)) + \ + self.loss_weight[2] * torch.abs(self.model.c_cross) + + return {"overall": overall_loss, "ce_loss": ce_loss, "pcl_loss": pcl_loss, + "c1": self.model.c_real, "c2": self.model.c_fake, "c3": self.model.c_cross} + else: + return {"overall": ce_loss} + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # compute MVG + if self.model.training and self.batch_cnt != 0 and self.batch_cnt % (self.config["batch_per_epoch"] // 2) == 0: + real_feature_tensor = torch.cat(self.real_feature_list, dim=0).cuda() + self.real_inv_covariance = fit_inv_covariance(real_feature_tensor).cpu() + self.real_feature_mean = real_feature_tensor.mean(dim=0).cpu() + self.real_feature_list = [] + + fake_feature_tensor = torch.cat(self.fake_feature_list, dim=0).cuda() + self.fake_inv_covariance = fit_inv_covariance(fake_feature_tensor).cpu() + self.fake_feature_mean = fake_feature_tensor.mean(dim=0).cpu() + self.fake_feature_list = [] + + step = self.batch_cnt / (self.batch_per_epoch * self.num_epoch) if self.model.training else 1 + pred, feature_patch, attention_map = self.model(data_dict["image"], step=step) + + # collect features of real patches and inner fake patches + real_indices = torch.where(data_dict["label"] == 0.0)[0] + feature_patch_real = feature_patch[real_indices[:4]] + B, H, W, C = feature_patch_real.size() + self.real_feature_list.append(feature_patch_real.reshape(B * H * W, C).cpu().detach()) + + fake_indices = torch.where(data_dict["label"] == 1.0)[0] + feature_patch_fake = feature_patch[fake_indices[:4], 3:11, 3:11, :] # hard coding, extend config to modify if needed + B, H, W, C = feature_patch_fake.size() + self.fake_feature_list.append(feature_patch_fake.reshape(B * H * W, C).cpu().detach()) + + attention_map_real = torch.sigmoid(torch.mean(attention_map[real_indices, :, 1:, 1:], dim=1)) + attention_map_fake = torch.sigmoid(torch.mean(attention_map[fake_indices, :, 1:, 1:], dim=1)) + + prob = torch.softmax(pred, dim=1)[:, 1] + pred_dict = {"cls": pred, + "prob": prob, + "feat": feature_patch} + + del attention_map, feature_patch + + pred_dict["attention_map_real"] = attention_map_real + pred_dict["attention_map_fake"] = attention_map_fake + self.batch_cnt += 1 + + return pred_dict + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + # [B, 196, 768] -> [B, 196, 768*3] -> [B, 196, 3, 8, 96] -> [3, B, 8, 196, 96] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn_qk = (q @ k.transpose(-2, -1)) * self.scale + attn_s = attn_qk.softmax(dim=-1) + attn = self.attn_drop(attn_s) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn_qk + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x_attn, attn = self.attn(self.norm1(x)) + x = x + self.drop_path(x_attn) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x, attn + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # [B, H*W, C] + return x + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature + # map for all networks, the feature metadata has reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of each stage that isn't captured. + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + self.c_real = nn.Parameter(torch.tensor(0.6)) + self.c_fake = nn.Parameter(torch.tensor(0.6)) + self.c_cross = nn.Parameter(torch.tensor(0.2)) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + self.norm_middle = norm_layer(embed_dim) + + # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here + # self.repr = nn.Linear(embed_dim, representation_size) + # self.repr_act = nn.Tanh() + + # Classifier head + self.head = nn.Linear(embed_dim * 2, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, attn_blk, feat_blk=False): + if feat_blk == False: + feat_blk = attn_blk - 1 + B = x.shape[0] + x = self.patch_embed(x) + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + if isinstance(attn_blk, int): + for i, blk in enumerate(self.blocks): + if i == feat_blk: + x_block = self.norm_middle(x) + if i == attn_blk: + attn_block = attn + x, attn = blk(x) + x = self.norm(x) # for vit_base_patch16_224: x.size() = [B, 14**2+1 (197) , 768] + if i == attn_blk - 1: + attn_block = attn + if i == feat_blk - 1: + x_block = x + elif isinstance(attn_blk, list): + attn_list = [] + for i, blk in enumerate(self.blocks): + if i == feat_blk: + x_block = self.norm_middle(x) + if i in attn_blk: + attn_list.append(attn) + x, attn = blk(x) + x = self.norm(x) # for vit_base_patch16_224: x.size() = [B, 14**2+1 (197) , 768] + if (i + 1) in attn_blk: + attn_list.append(attn) + if i == feat_blk - 1: + x_block = x + attn_block = torch.cat(attn_list, dim=1) + + x_block = x_block[:, 1:].reshape( + (x_block.size(0), int(x_block.size(1) ** 0.5), int(x_block.size(1) ** 0.5), x_block.size(2))) + return x, x_block, attn_block + + def forward(self, x, step=1, attn_blk=[8, 9, 10, 11, 12], feat_blk=6, k=12, thr=0.7, is_progressive=1): + x, feat_block, attn_block = self.forward_features(x, attn_blk, feat_blk) + + x_cls, x_patch = x[:, 0], x[:, 1:] + B, PP, C = x_patch.shape + localization_map = torch.sigmoid(torch.mean(attn_block[:, :, 0, 1:], dim=1)) + + if is_progressive: + if step < 1 / 8.: + localization_map = (torch.ones(B, 1, PP) / PP).to(x_patch.device) + else: + w = torch.sigmoid(torch.tensor(-k * (step - thr))).to(x_patch.device) + localization_map = (w * torch.ones(B, 1, PP).to(x_patch.device) + (1 - w) * localization_map.reshape(B, + 1, + PP).to( + x_patch.device)) / PP + else: + localization_map = localization_map.reshape(B, 1, PP).to(x_patch.device) / PP + x = torch.cat([x_cls, torch.bmm(localization_map, x_patch).squeeze(1)], -1) + x = self.head(x) + return x, feat_block, attn_block + + +def fit_inv_covariance(samples): + return torch.Tensor(LedoitWolf().fit(samples.cpu()).precision_).to( + samples.device + ) diff --git a/training/detectors/universal.py b/training/detectors/universal.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd4e84adf1ba81426bec249d2459adf73a5d040 --- /dev/null +++ b/training/detectors/universal.py @@ -0,0 +1,141 @@ +# detectors/universal_detector.py + +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +import timm +from typing import Union + +from metrics.base_metrics_class import calculate_acc_for_train +from .base_detector import AbstractDetector +from . import DETECTOR + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='universal') +class UniversalDetector(AbstractDetector): + """ + UniversalDetector: use a timm backbone (such as ViT/CLIP) with a linear classification head. + The backbone is frozen here, and only the linear classifier head is trained. + """ + def __init__(self, config, load_param: Union[bool, str] = False): + super().__init__(config=config, load_param=load_param) + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + self.num_classes = config['backbone_config']['num_classes'] + with torch.no_grad(): + inc = config.get('backbone_config', {}).get('inc', 3) + img_size = config.get('resolution', 224) + dummy = torch.randn(1, inc, img_size, img_size) + backbone_feat = self.backbone_forward_features(dummy) # (1, D_backbone) + backbone_feat_dim = backbone_feat.shape[1] + + # ===== Added: feature projection layer to map backbone features to 1024 dimensions ===== + self.feat_dim = 1024 + self.feature_proj = nn.Linear(backbone_feat_dim, self.feat_dim) + + # Linear classifier head (the only component that needs training) + self.classifier_head = nn.Linear(self.feat_dim, self.num_classes) + + logger.info( + f"UniversalDetector initialized with backbone={config['backbone_name']}, " + f"feat_dim={self.feat_dim}, num_classes={self.num_classes}" + ) + + def build_backbone(self, config): + """ + Create the backbone with timm and freeze all of its parameters. + For example: vit_base_patch16_clip_224.openai + """ + backbone_name = config['backbone_name'] + logger.info(f"Building timm backbone (frozen): {backbone_name}") + + backbone = timm.create_model( + backbone_name, + pretrained=True, + num_classes=0, # Drop the final classification head and keep only the features + global_pool='' # Keep the original features (ViT returns CLS in `forward_features`) + ) + + # Freeze backbone parameters + for p in backbone.parameters(): + p.requires_grad = False + + backbone.eval() # Keep the backbone in eval mode during training to avoid changes from BN/Dropout + return backbone + + def build_loss(self, config): + if config.get('loss_func', 'cross_entropy') == 'cross_entropy': + logger.info("Using nn.CrossEntropyLoss as loss function.") + return nn.CrossEntropyLoss() + else: + from . import LOSSFUNC + loss_name = config['loss_func'] + logger.info(f"Using custom loss function: {loss_name}") + loss_class = LOSSFUNC[loss_name] + return loss_class() + + def backbone_forward_features(self, x: torch.Tensor) -> torch.Tensor: + """ + Run forward on the timm model to obtain features of shape B x D. + For `vit_base_patch16_clip_224.openai`, this is equivalent to taking the CLS feature from the final layer, + which stays fully consistent with the intermediate-layer selection in the original code. + """ + with torch.no_grad(): # Make sure no gradients are computed for the backbone + if hasattr(self.backbone, 'forward_features'): + feats = self.backbone.forward_features(x) + else: + feats = self.backbone(x) + + if feats.ndim > 2: + feats = torch.flatten(feats, start_dim=1) # B x D + return feats + + def features(self, data_dict: dict) -> torch.Tensor: + images = data_dict['image'] + feats = self.backbone_forward_features(images) + feats = self.feature_proj(feats) + return feats + + def classifier(self, features: torch.Tensor) -> torch.Tensor: + """ + Linear classifier head: B x D -> B x num_classes. + This is the only part that participates in backpropagation. + """ + logits = self.classifier_head(features) + return logits + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + return {'overall': loss} + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'].detach() + pred_logits = pred_dict['cls'].detach() + + acc, mAP = calculate_acc_for_train(label, pred_logits, self.num_classes) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + """ + Returns: + - 'cls': logits + - 'prob': softmax probability + - 'feat': feature vector from the frozen ViT output (intermediate layer selection matches the original code) + """ + feat = self.features(data_dict) # B x D(from the frozen backbone) + logits = self.classifier(feat) # B x num_classes(trainable) + prob = F.softmax(logits, dim=1) + + pred_dict = { + 'cls': logits, + 'prob': prob, + 'feat': feat + } + return pred_dict + diff --git a/training/detectors/utils/iid_api.py b/training/detectors/utils/iid_api.py new file mode 100644 index 0000000000000000000000000000000000000000..d86c0e5ea9f6a360e704b8ec38c14d01d5c0a17b --- /dev/null +++ b/training/detectors/utils/iid_api.py @@ -0,0 +1,268 @@ +from __future__ import print_function +from __future__ import division +import torch +import torch.nn as nn +from torch.nn import Parameter +import torch.distributed as dist +import math + + +def l2_norm(input, axis=1): + norm = torch.norm(input, p=2, dim=axis, keepdim=True) + output = torch.div(input, norm) + return output + + +def calc_logits(embeddings, kernel): + """ calculate original logits + """ + embeddings = l2_norm(embeddings, axis=1) + kernel_norm = l2_norm(kernel, axis=0) + cos_theta = torch.mm(embeddings, kernel_norm) + cos_theta = cos_theta.clamp(-1, 1) # for numerical stability + with torch.no_grad(): + origin_cos = cos_theta.clone() + return cos_theta, origin_cos + + +@torch.no_grad() +def all_gather_tensor(input_tensor): + """ allgather tensor (difference size in 0-dim) from all workers + """ + world_size = dist.get_world_size() + + tensor_size = torch.tensor([input_tensor.shape[0]], dtype=torch.int64).cuda() + tensor_size_list = [torch.ones_like(tensor_size) for _ in range(world_size)] + dist.all_gather(tensor_list=tensor_size_list, tensor=tensor_size, async_op=False) + max_size = torch.cat(tensor_size_list, dim=0).max() + + padded = torch.empty(max_size.item(), *input_tensor.shape[1:], dtype=input_tensor.dtype).cuda() + padded[:input_tensor.shape[0]] = input_tensor + padded_list = [torch.ones_like(padded) for _ in range(world_size)] + dist.all_gather(tensor_list=padded_list, tensor=padded, async_op=False) + + slices = [] + for ts, t in zip(tensor_size_list, padded_list): + slices.append(t[:ts.item()]) + return torch.cat(slices, dim=0) + + +def calc_top1_acc(original_logits, label,ddp=False): + """ + Compute the top1 accuracy during training + :param original_logits: logits w/o margin, [bs, C] + :param label: labels [bs] + :return: acc in all gpus + """ + assert (original_logits.size()[0] == label.size()[0]) + + with torch.no_grad(): + _, max_index = torch.max(original_logits, dim=1, keepdim=False) # local max logit + count = (max_index == label).sum() + if ddp: + dist.all_reduce(count, dist.ReduceOp.SUM) + + return count.item() / (original_logits.size()[0] * dist.get_world_size()) + else: + return count.item() / (original_logits.size()[0]) + +def l2_norm(input, axis=1): + norm = torch.norm(input, p=2, dim=axis, keepdim=True) + output = torch.div(input, norm) + return output + + +class FC_ddp2(nn.Module): + """ + Implement of (CVPR2021 Consistent Instance False Positive Improves Fairness in Face Recognition) + No model parallel is used + """ + + def __init__(self, + in_features, + out_features, + scale=64.0, + margin=0.4, + mode='cosface', + use_cifp=False, + reduction='mean', + ddp=False): + """ Args: + in_features: size of each input features + out_features: size of each output features + scale: norm of input feature + margin: margin + """ + super(FC_ddp2, self).__init__() + self.in_features = in_features + self.out_features = out_features # num of classes + self.scale = scale + self.margin = margin + self.mode = mode + self.use_cifp = use_cifp + self.kernel = Parameter(torch.Tensor(in_features, out_features)) + self.ddp = ddp + nn.init.normal_(self.kernel, std=0.01) + + self.criteria = torch.nn.CrossEntropyLoss(reduction=reduction) + + def apply_margin(self, target_cos_theta): + assert self.mode in ['cosface', 'arcface'], 'Please check the mode' + if self.mode == 'arcface': + cos_m = math.cos(self.margin) + sin_m = math.sin(self.margin) + theta = math.cos(math.pi - self.margin) + sinmm = math.sin(math.pi - self.margin) * self.margin + sin_theta = torch.sqrt(1.0 - torch.pow(target_cos_theta, 2)) + cos_theta_m = target_cos_theta * cos_m - sin_theta * sin_m + target_cos_theta_m = torch.where( + target_cos_theta > theta, cos_theta_m, target_cos_theta - sinmm) + elif self.mode == 'cosface': + target_cos_theta_m = target_cos_theta - self.margin + + return target_cos_theta_m + + def forward(self, embeddings, label, return_logits=False): + """ + + :param embeddings: local gpu [bs, 512] + :param label: local labels [bs] + :param return_logits: bool + :return: + loss: computed local loss, w/wo CIFP + acc: local accuracy in one gpu + output: local logits with margins, with gradients, scaled, [bs, C]. + """ + sample_num = embeddings.size(0) + + if not self.use_cifp: + cos_theta, origin_cos = calc_logits(embeddings, self.kernel) + target_cos_theta = cos_theta[torch.arange(0, sample_num), label].view(-1, 1) + target_cos_theta_m = self.apply_margin(target_cos_theta) + cos_theta.scatter_(1, label.view(-1, 1).long(), target_cos_theta_m) + else: + cos_theta, origin_cos = calc_logits(embeddings, self.kernel) + cos_theta_, _ = calc_logits(embeddings, self.kernel.detach()) + + mask = torch.zeros_like(cos_theta) # [bs,C] + mask.scatter_(1, label.view(-1, 1).long(), 1.0) # one-hot label / gt mask + + tmp_cos_theta = cos_theta - 2 * mask + tmp_cos_theta_ = cos_theta_ - 2 * mask + + target_cos_theta = cos_theta[torch.arange(0, sample_num), label].view(-1, 1) + target_cos_theta_ = cos_theta_[torch.arange(0, sample_num), label].view(-1, 1) + + target_cos_theta_m = self.apply_margin(target_cos_theta) + + far = 1 / (self.out_features - 1) # ru+ value + # far = 1e-5 + + topk_mask = torch.greater(tmp_cos_theta, target_cos_theta) + topk_sum = torch.sum(topk_mask.to(torch.int32)) + if self.ddp: + dist.all_reduce(topk_sum) + far_rank = math.ceil(far * (sample_num * (self.out_features - 1) * dist.get_world_size() - topk_sum)) + cos_theta_neg_topk = torch.topk((tmp_cos_theta - 2 * topk_mask.to(torch.float32)).flatten(), + k=far_rank)[0] # [far_rank] + cos_theta_neg_topk = all_gather_tensor(cos_theta_neg_topk.contiguous()) # top k across all gpus + cos_theta_neg_th = torch.topk(cos_theta_neg_topk, k=far_rank)[0][-1] + + cond = torch.mul(torch.bitwise_not(topk_mask), torch.greater(tmp_cos_theta, cos_theta_neg_th)) + cos_theta_neg_topk = torch.mul(cond.to(torch.float32), tmp_cos_theta) + cos_theta_neg_topk_ = torch.mul(cond.to(torch.float32), tmp_cos_theta_) + cond = torch.greater(target_cos_theta_m, cos_theta_neg_topk) + + cos_theta_neg_topk = torch.where(cond, cos_theta_neg_topk, cos_theta_neg_topk_) + cos_theta_neg_topk = torch.pow(cos_theta_neg_topk, 2) # F = z^p = cos^2 + times = torch.sum(torch.greater(cos_theta_neg_topk, 0).to(torch.float32), dim=1, keepdim=True) + times = torch.where(torch.greater(times, 0), times, torch.ones_like(times)) + cos_theta_neg_topk = torch.sum(cos_theta_neg_topk, dim=1, keepdim=True) / times # ri+/ru+ + + target_cos_theta_m = target_cos_theta_m - (1 + target_cos_theta_) * cos_theta_neg_topk + cos_theta.scatter_(1, label.view(-1, 1).long(), target_cos_theta_m) + + output = cos_theta * self.scale + loss = self.criteria(output, label) + acc = calc_top1_acc(origin_cos * self.scale, label,self.ddp) + + if return_logits: + return loss, acc, output + + return loss, acc + + +class FC_ddp(nn.Module): + """ + Implement of (CVPR2021 Consistent Instance False Positive Improves Fairness in Face Recognition) + No model parallel is used + """ + + def __init__(self, + in_features, + out_features, + scale=8.0, + margin=0.2, + mode='cosface', + use_cifp=False, + reduction='mean'): + """ Args: + in_features: size of each input features + out_features: size of each output features + scale: norm of input feature + margin: margin + """ + super(FC_ddp, self).__init__() + self.in_features = in_features + self.out_features = out_features # num of classes + self.scale = scale + self.margin = margin + self.mode = mode + self.use_cifp = use_cifp + # self.kernel = Parameter(torch.Tensor(in_features, out_features)) + # nn.init.normal_(self.kernel, std=0.01) + + self.criteria = torch.nn.CrossEntropyLoss(reduction=reduction) + self.sig = torch.nn.Sigmoid() + + def apply_margin(self, target_cos_theta): + assert self.mode in ['cosface', 'arcface'], 'Please check the mode' + if self.mode == 'arcface': + cos_m = math.cos(self.margin) + sin_m = math.sin(self.margin) + theta = math.cos(math.pi - self.margin) + sinmm = math.sin(math.pi - self.margin) * self.margin + sin_theta = torch.sqrt(1.0 - torch.pow(target_cos_theta, 2)) + cos_theta_m = target_cos_theta * cos_m - sin_theta * sin_m + target_cos_theta_m = torch.where( + target_cos_theta > theta, cos_theta_m, target_cos_theta - sinmm) + elif self.mode == 'cosface': + target_cos_theta_m = target_cos_theta - self.margin + + return target_cos_theta_m + + def forward(self, embeddings, label, return_logits=False): + """ + + :param embeddings: local gpu [bs, 512] + :param label: local labels [bs] + :param return_logits: bool + :return: + loss: computed local loss, w/wo CIFP + acc: local accuracy in one gpu + output: local logits with margins, with gradients, scaled, [bs, C]. + """ + sample_num = embeddings.size(0) + cos_theta = self.sig(embeddings) + target_cos_theta = cos_theta[torch.arange(0, sample_num), label].view(-1, 1) + # target_cos_theta_m = target_cos_theta - self.margin + target_cos_theta = target_cos_theta - self.margin + # cos_theta.scatter_(1, label.view(-1, 1).long(), target_cos_theta_m) + out = cos_theta.clone() + out.scatter_(1, label.view(-1, 1).long(), target_cos_theta) + + out = out * self.scale + + loss = self.criteria(out, label) + + return loss diff --git a/training/detectors/utils/lsad_api.py b/training/detectors/utils/lsad_api.py new file mode 100644 index 0000000000000000000000000000000000000000..445698614ba6d2aabc47d8625cc181cec2e1fd93 --- /dev/null +++ b/training/detectors/utils/lsad_api.py @@ -0,0 +1,83 @@ +import random + +import torch + + +def augment_domains(self, groups_feature_maps): + # Helper Functions + def hard_example_interpolation(z_i, hard_example, lambda_1): + return z_i + lambda_1 * (hard_example - z_i) + + def hard_example_extrapolation(z_i, mean_latent, lambda_2): + return z_i + lambda_2 * (z_i - mean_latent) + + def add_gaussian_noise(z_i, sigma, lambda_3): + epsilon = torch.randn_like(z_i) * sigma + return z_i + lambda_3 * epsilon + + def difference_transform(z_i, z_j, z_k, lambda_4): + return z_i + lambda_4 * (z_j - z_k) + + def distance(z_i, z_j): + return torch.norm(z_i - z_j) + + domain_number = len(groups_feature_maps[0]) + + # Calculate the mean latent vector for each domain across all groups + domain_means = [] + for domain_idx in range(domain_number): + all_samples_in_domain = torch.cat([group[domain_idx] for group in groups_feature_maps], dim=0) + domain_mean = torch.mean(all_samples_in_domain, dim=0) + domain_means.append(domain_mean) + + # Identify the hard example for each domain across all groups + hard_examples = [] + for domain_idx in range(domain_number): + all_samples_in_domain = torch.cat([group[domain_idx] for group in groups_feature_maps], dim=0) + distances = torch.tensor([distance(z, domain_means[domain_idx]) for z in all_samples_in_domain]) + hard_example = all_samples_in_domain[torch.argmax(distances)] + hard_examples.append(hard_example) + + augmented_groups = [] + + for group_feature_maps in groups_feature_maps: + augmented_domains = [] + + for domain_idx, domain_feature_maps in enumerate(group_feature_maps): + # Choose a random augmentation + augmentations = [ + lambda z: hard_example_interpolation(z, hard_examples[domain_idx], random.random()), + lambda z: hard_example_extrapolation(z, domain_means[domain_idx], random.random()), + lambda z: add_gaussian_noise(z, random.random(), random.random()), + lambda z: difference_transform(z, domain_feature_maps[0], domain_feature_maps[1], random.random()) + ] + chosen_aug = random.choice(augmentations) + augmented = torch.stack([chosen_aug(z) for z in domain_feature_maps]) + augmented_domains.append(augmented) + + augmented_domains = torch.stack(augmented_domains) + augmented_groups.append(augmented_domains) + + return torch.stack(augmented_groups) + + +def mixup_in_latent_space(self, data): + # data shape: [batchsize, num_domains, 3, 256, 256] + bs, num_domains, _, _, _ = data.shape + + # Initialize an empty tensor for mixed data + mixed_data = torch.zeros_like(data) + + # For each sample in the batch + for i in range(bs): + # Step 1: Generate a shuffled index list for the domains + shuffled_idxs = torch.randperm(num_domains) + + # Step 2: Choose random alpha between 0.5 and 2, then sample lambda from beta distribution + alpha = torch.rand(1) * 1.5 + 0.5 # random alpha between 0.5 and 2 + lambda_ = torch.distributions.beta.Beta(alpha, alpha).sample().to(data.device) + + # Step 3: Perform mixup using the shuffled indices + mixed_data[i] = lambda_ * data[i] + (1 - lambda_) * data[i, shuffled_idxs] + + return mixed_data \ No newline at end of file diff --git a/training/detectors/utils/sladd_api.py b/training/detectors/utils/sladd_api.py new file mode 100644 index 0000000000000000000000000000000000000000..9978dd64b15bf06db870605790c29831da6cb8fc --- /dev/null +++ b/training/detectors/utils/sladd_api.py @@ -0,0 +1,668 @@ +import torch +import math +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import yaml +from PIL import Image +import cv2 +from torchvision import transforms as T +from skimage import measure +from skimage.transform import PiecewiseAffineTransform, warp +from torch.autograd import Variable +from scipy.ndimage import binary_erosion, binary_dilation + +from dataset.pair_dataset import pairDataset +from dataset.utils.color_transfer import color_transfer +from dataset.utils.faceswap_utils_sladd import blendImages as alpha_blend_fea +from dataset.utils import faceswap + + + +class Block(nn.Module): + def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True): + super(Block, self).__init__() + + if out_filters != in_filters or strides != 1: + self.skip = nn.Conv2d(in_filters, out_filters, + 1, stride=strides, bias=False) + self.skipbn = nn.BatchNorm2d(out_filters) + else: + self.skip = None + + self.relu = nn.ReLU(inplace=True) + rep = [] + + filters = in_filters + if grow_first: # whether the number of filters grows first + rep.append(self.relu) + rep.append(SeparableConv2d(in_filters, out_filters, + 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(out_filters)) + filters = out_filters + + for i in range(reps - 1): + rep.append(self.relu) + rep.append(SeparableConv2d(filters, filters, + 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(filters)) + + if not grow_first: + rep.append(self.relu) + rep.append(SeparableConv2d(in_filters, out_filters, + 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(out_filters)) + + if not start_with_relu: + rep = rep[1:] + else: + rep[0] = nn.ReLU(inplace=False) + + if strides != 1: + rep.append(nn.MaxPool2d(3, strides, 1)) + self.rep = nn.Sequential(*rep) + + def forward(self, inp): + x = self.rep(inp) + + if self.skip is not None: + skip = self.skip(inp) + skip = self.skipbn(skip) + else: + skip = inp + + x += skip + return x + +class SeparableConv2d(nn.Module): + def __init__(self, c_in, c_out, ks, stride=1, padding=0, dilation=1, bias=False): + super(SeparableConv2d, self).__init__() + self.c = nn.Conv2d(c_in, c_in, ks, stride, padding, dilation, groups=c_in, bias=bias) + self.pointwise = nn.Conv2d(c_in, c_out, 1, 1, 0, 1, 1, bias=bias) + + def forward(self, x): + x = self.c(x) + x = self.pointwise(x) + return x + +class Xception_SLADDSyn(nn.Module): + """ + Xception optimized for the ImageNet dataset, as specified in + https://arxiv.org/pdf/1610.02357.pdf + """ + + def __init__(self, num_classes=2, num_region=7, num_type=2, num_mag=1, inc=6): + """ Constructor + Args: + num_classes: number of classes + """ + super(Xception_SLADDSyn, self).__init__() + self.num_region = num_region + self.num_type = num_type + self.num_mag = num_mag + dropout = 0.5 + + # Entry flow + self.iniconv = nn.Conv2d(inc, 32, 3, 2, 0, bias=False) + # self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.relu = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32, 64, 3, bias=False) + self.bn2 = nn.BatchNorm2d(64) + # do relu here + + self.block1 = Block( + 64, 128, 2, 2, start_with_relu=False, grow_first=True) + self.block2 = Block( + 128, 256, 2, 2, start_with_relu=True, grow_first=True) + self.block3 = Block( + 256, 728, 2, 2, start_with_relu=True, grow_first=True) + + # middle flow + self.block4 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block5 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block6 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block7 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + + self.block8 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block9 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block10 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block11 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + + # Exit flow + self.block12 = Block( + 728, 1024, 2, 2, start_with_relu=True, grow_first=False) + + self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) + self.bn3 = nn.BatchNorm2d(1536) + + # do relu here + self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) + self.bn4 = nn.BatchNorm2d(2048) + self.fc_region = nn.Sequential(nn.Dropout(p=dropout), nn.Linear(2048, num_region)) + self.fc_type = nn.Sequential(nn.Dropout(p=dropout), nn.Linear(2048, num_type)) + self.fc_mag = nn.Sequential(nn.Dropout(p=dropout), nn.Linear(2048, num_mag)) + + def fea_part1_0(self, x): + x = self.iniconv(x) + x = self.bn1(x) + x = self.relu(x) + + return x + + def fea_part1_1(self, x): + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + return x + + def fea_part1(self, x): + x = self.iniconv(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + return x + + def fea_part2(self, x): + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + + return x + + def fea_part3(self, x): + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + + return x + + def fea_part4(self, x): + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + x = self.block12(x) + + return x + + def fea_part5(self, x): + x = self.conv3(x) + x = self.bn3(x) + x = self.relu(x) + + x = self.conv4(x) + x = self.bn4(x) + + return x + + def features(self, input): + x = self.fea_part1(input) + + x = self.fea_part2(x) + x = self.fea_part3(x) + x = self.fea_part4(x) + + x = self.fea_part5(x) + return x + + def classifier(self, features): + x = self.relu(features) + + x = F.adaptive_avg_pool2d(x, (1, 1)) + x = x.view(x.size(0), -1) + out = self.last_linear(x) + return out, x + + def forward(self, input): + x = self.features(input) + x = self.relu(x) + x = F.adaptive_avg_pool2d(x, (1, 1)) + x = x.view(x.size(0), -1) + + region_num = self.fc_region(x) + type_num = self.fc_type(x) + mag = self.fc_mag(x) + + return region_num, type_num, mag + + +def mask_postprocess(mask): + def blur_mask(mask): + blur_k = 2 * np.random.randint(1, 10) - 1 + + # kernel = np.ones((blur_k+1, blur_k+1), np.uint8) + # mask = cv2.erode(mask, kernel) + + mask = cv2.GaussianBlur(mask, (blur_k, blur_k), 0) + + return mask + + # random erode/dilate + prob = np.random.rand() + if prob < 0.3: + erode_k = 2 * np.random.randint(1, 10) + 1 + kernel = np.ones((erode_k, erode_k), np.uint8) + mask = cv2.erode(mask, kernel) + elif prob < 0.6: + erode_k = 2 * np.random.randint(1, 10) + 1 + kernel = np.ones((erode_k, erode_k), np.uint8) + mask = cv2.dilate(mask, kernel) + + # random blur + if np.random.rand() < 0.9: + mask = blur_mask(mask) + + return mask + +def xception(num_region=7, num_type=2, num_mag=1, pretrained='imagenet', inc=6): + model = Xception_SLADDSyn(num_region=num_region, num_type=num_type, num_mag=num_mag, inc=inc) + return model + + + +class TransferModel(nn.Module): + """ + Simple transfer learning model that takes an imagenet pretrained model with + a fc layer as base model and retrains a new fc layer for num_out_classes + """ + + def __init__(self, config, num_region=7, num_type=2, num_mag=1, return_fea=False, inc=6): + super(TransferModel, self).__init__() + self.return_fea = return_fea + def return_pytorch04_xception(pretrained=True): + # Raises warning "src not broadcastable to dst" but thats fine + model = xception(num_region=num_region, num_type=num_type, num_mag=num_mag, inc=inc, pretrained=False) + if pretrained: + # Load model in torch 0.4+ + # model.fc = model.last_linear + # del model.last_linear + state_dict = torch.load(config['pretrained']) + print('Loaded pretrained model (ImageNet)....') + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze( + -1).unsqueeze(-1) + model.load_state_dict(state_dict, strict=False) + # model.last_linear = model.fc + # del model.fc + return model + + self.model = return_pytorch04_xception() + # Replace fc + + if inc != 3: + self.model.iniconv = nn.Conv2d(inc, 32, 3, 2, 0, bias=False) + nn.init.xavier_normal(self.model.iniconv.weight.data, gain=0.02) + + def set_trainable_up_to(self, boolean=False, layername="Conv2d_4a_3x3"): + """ + Freezes all layers below a specific layer and sets the following layers + to true if boolean else only the fully connected final layer + :param boolean: + :param layername: depends on lib, for inception e.g. Conv2d_4a_3x3 + :return: + """ + # Stage-1: freeze all the layers + if layername is None: + for i, param in self.model.named_parameters(): + param.requires_grad = True + return + else: + for i, param in self.model.named_parameters(): + param.requires_grad = False + if boolean: + # Make all layers following the layername layer trainable + ct = [] + found = False + for name, child in self.model.named_children(): + if layername in ct: + found = True + for params in child.parameters(): + params.requires_grad = True + ct.append(name) + if not found: + raise NotImplementedError('Layer not found, cant finetune!'.format( + layername)) + else: + # Make fc trainable + for param in self.model.last_linear.parameters(): + param.requires_grad = True + + def forward(self, x): + region_num, type_num, mag = self.model(x) + return region_num, type_num, mag + + def features(self, x): + x = self.model.features(x) + return x + + def classifier(self, x): + out, x = self.model.classifier(x) + return out, x + + + +def dist(p1, p2): + return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) + + +def generate_random_mask(mask, res=256): + randwl = np.random.randint(10, 60) + randwr = np.random.randint(10, 60) + randhu = np.random.randint(10, 60) + randhd = np.random.randint(10, 60) + newmask = np.zeros(mask.shape) + mask = np.where(mask > 0.1, 1, 0) + props = measure.regionprops(mask) + if len(props) == 0: + return newmask + center_x, center_y = props[0].centroid + center_x = int(round(center_x)) + center_y = int(round(center_y)) + newmask[max(center_x - randwl, 0):min(center_x + randwr, res - 1), + max(center_y - randhu, 0):min(center_x + randhd, res - 1)] = 1 + newmask *= mask + return newmask + + +def random_deform(mask, nrows, ncols, mean=0, std=10): + h, w = mask.shape[:2] + rows = np.linspace(0, h - 1, nrows).astype(np.int32) + cols = np.linspace(0, w - 1, ncols).astype(np.int32) + rows += np.random.normal(mean, std, size=rows.shape).astype(np.int32) + rows += np.random.normal(mean, std, size=cols.shape).astype(np.int32) + rows, cols = np.meshgrid(rows, cols) + anchors = np.vstack([rows.flat, cols.flat]).T + assert anchors.shape[1] == 2 and anchors.shape[0] == ncols * nrows + deformed = anchors + np.random.normal(mean, std, size=anchors.shape) + np.clip(deformed[:, 0], 0, h - 1, deformed[:, 0]) + np.clip(deformed[:, 1], 0, w - 1, deformed[:, 1]) + + trans = PiecewiseAffineTransform() + trans.estimate(anchors, deformed.astype(np.int32)) + warped = warp(mask, trans) + warped *= mask + blured = cv2.GaussianBlur(warped.astype(float), (5, 5), 3) + return blured + + +def get_five_key(landmarks_68): + # get the five key points by using the landmarks + leye_center = (landmarks_68[36] + landmarks_68[39]) * 0.5 + reye_center = (landmarks_68[42] + landmarks_68[45]) * 0.5 + nose = landmarks_68[33] + lmouth = landmarks_68[48] + rmouth = landmarks_68[54] + leye_left = landmarks_68[36] + leye_right = landmarks_68[39] + reye_left = landmarks_68[42] + reye_right = landmarks_68[45] + out = [tuple(x.astype('int32')) for x in [ + leye_center, reye_center, nose, lmouth, rmouth, leye_left, leye_right, reye_left, reye_right + ]] + return out + + +def remove_eyes(image, landmarks, opt): + ##l: left eye; r: right eye, b: both eye + if opt == 'l': + (x1, y1), (x2, y2) = landmarks[5:7] + elif opt == 'r': + (x1, y1), (x2, y2) = landmarks[7:9] + elif opt == 'b': + (x1, y1), (x2, y2) = landmarks[:2] + else: + print('wrong region') + mask = np.zeros_like(image[..., 0]) + line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2) + w = dist((x1, y1), (x2, y2)) + dilation = int(w // 4) + if opt != 'b': + dilation *= 4 + line = binary_dilation(line, iterations=dilation) + return line + + +def remove_nose(image, landmarks): + (x1, y1), (x2, y2) = landmarks[:2] + x3, y3 = landmarks[2] + mask = np.zeros_like(image[..., 0]) + x4 = int((x1 + x2) / 2) + y4 = int((y1 + y2) / 2) + line = cv2.line(mask, (x3, y3), (x4, y4), color=(1), thickness=2) + w = dist((x1, y1), (x2, y2)) + dilation = int(w // 4) + line = binary_dilation(line, iterations=dilation) + return line + + +def remove_mouth(image, landmarks): + (x1, y1), (x2, y2) = landmarks[3:5] + mask = np.zeros_like(image[..., 0]) + line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2) + w = dist((x1, y1), (x2, y2)) + dilation = int(w // 3) + line = binary_dilation(line, iterations=dilation) + return line + + +def blend_fake_to_real(realimg, real_lmk, fakeimg, fakemask, fake_lmk, deformed_fakemask, type, mag): + # source: fake image + # target: real image + realimg = ((realimg + 1) / 2 * 255).astype(np.uint8) + fakeimg = ((fakeimg + 1) / 2 * 255).astype(np.uint8) + H, W, C = realimg.shape + #Since alignment has already been applied, it can be used directly here. The original code also performed alignment, and this src corresponds to the fake sample. + aligned_src = fakeimg + src_mask = deformed_fakemask + src_mask = src_mask > 0 # (H, W) + + tgt_mask = np.asarray(src_mask, dtype=np.uint8) + tgt_mask = mask_postprocess(tgt_mask) + + ct_modes = ['rct-m', 'rct-fs', 'avg-align', 'faceswap'] + mode_idx = np.random.randint(len(ct_modes)) + mode = ct_modes[mode_idx] + + if mode != 'faceswap': + c_mask = tgt_mask / 255. + c_mask[c_mask > 0] = 1 + if len(c_mask.shape) < 3: + c_mask = np.expand_dims(c_mask, 2) + src_crop = color_transfer(mode, aligned_src, realimg, c_mask) + else: + c_mask = tgt_mask.copy() + c_mask[c_mask > 0] = 255 + masked_tgt = faceswap.apply_mask(realimg, c_mask) + masked_src = faceswap.apply_mask(aligned_src, c_mask) + src_crop = faceswap.correct_colours(masked_tgt, masked_src, np.array(real_lmk)) + + if tgt_mask.mean() < 0.005 or src_crop.max() == 0: + out_blend = realimg + else: + if type == 0: + out_blend, a_mask = alpha_blend_fea(src_crop, realimg, tgt_mask, + featherAmount=0.2 * np.random.rand()) + elif type == 1: + b_mask = (tgt_mask * 255).astype(np.uint8) + l, t, w, h = cv2.boundingRect(b_mask) + center = (int(l + w / 2), int(t + h / 2)) + out_blend = cv2.seamlessClone(src_crop, realimg, b_mask, center, cv2.NORMAL_CLONE) + else: + out_blend = copy_fake_to_real(realimg, src_crop, tgt_mask, mag) + + return out_blend, tgt_mask + + +def copy_fake_to_real(realimg, fakeimg, mask, mag): + mask = np.expand_dims(mask, 2) + newimg = fakeimg * mask * mag + realimg * (1 - mask) + realimg * mask * (1 - mag) + return newimg + + +class synthesizer(nn.Module): + def __init__(self,config): + super(synthesizer, self).__init__() + self.netG = TransferModel(config=config,num_region=10, num_type=4, num_mag=1, inc=6) + normalize = T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + self.transforms = T.Compose([T.ToTensor(), normalize]) + + def parse(self, img, reg, real_lmk, fakemask): + five_key = get_five_key(real_lmk) + if reg == 0: + mask = remove_eyes(img, five_key, 'l') + elif reg == 1: + mask = remove_eyes(img, five_key, 'r') + elif reg == 2: + mask = remove_eyes(img, five_key, 'b') + elif reg == 3: + mask = remove_nose(img, five_key) + elif reg == 4: + mask = remove_mouth(img, five_key) + elif reg == 5: + mask = remove_nose(img, five_key) + remove_eyes(img, five_key, 'l') + elif reg == 6: + mask = remove_nose(img, five_key) + remove_eyes(img, five_key, 'r') + elif reg == 7: + mask = remove_nose(img, five_key) + remove_eyes(img, five_key, 'b') + elif reg == 8: + mask = remove_nose(img, five_key) + remove_mouth(img, five_key) + elif reg == 9: + mask = remove_eyes(img, five_key, 'b') + remove_nose(img, five_key) + remove_mouth(img, five_key) + else: + mask = generate_random_mask(fakemask) + mask = random_deform(mask, 5, 5) + return mask * 1.0 + + def get_variable(self, inputs, cuda=False, **kwargs): + if type(inputs) in [list, np.ndarray]: + inputs = torch.Tensor(inputs) + if cuda: + out = Variable(inputs.cuda(), **kwargs) + else: + out = Variable(inputs, **kwargs) + return out + + def calculate(self, logits): + if logits.shape[1] != 1: + probs = F.softmax(logits, dim=-1) + log_prob = F.log_softmax(logits, dim=-1) + entropy = -(log_prob * probs).sum(1, keepdim=False) + action = probs.multinomial(num_samples=1).data + selected_log_prob = log_prob.gather(1, self.get_variable(action, requires_grad=False)) + else: + probs = torch.sigmoid(logits) + log_prob = torch.log(torch.sigmoid(logits)) + entropy = -(log_prob * probs).sum(1, keepdim=False) + action = probs + selected_log_prob = log_prob + return entropy, selected_log_prob[:, 0], action[:, 0] + + def forward(self, img, fake_img, real_lmk, fake_lmk, real_mask, fake_mask, label=None): + # based on pair_dataset, here, img always is real, fake_img always is fake + region_num, type_num, mag = self.netG(torch.cat((img, fake_img), 1)) + reg_etp, reg_log_prob, reg = self.calculate(region_num) + type_etp, type_log_prob, type = self.calculate(type_num) + mag_etp, mag_log_prob, mag = self.calculate(mag) + entropy = reg_etp + type_etp + mag_etp + log_prob = reg_log_prob + type_log_prob + mag_log_prob + newlabel = [] + typelabel = [] + maglabel = [] + magmask = [] + ##################### + alt_img = torch.ones(img.shape) + alt_mask = np.zeros((img.shape[0], 16, 16)) + if label is None: + label=np.zeros(img.shape[0]) + for i in range(img.shape[0]): + imgcp = np.transpose(img[i].cpu().numpy(), (1, 2, 0)).copy() + fake_imgcp = np.transpose(fake_img[i].cpu().numpy(), (1, 2, 0)).copy() + ##only work for real imgs and not do-nothing choice + if label[i] == 0 and type[i] != 3: + mask = self.parse(fake_imgcp, reg[i], fake_lmk[i].cpu().numpy(), + fake_mask[i].cpu().numpy()) + newimg, newmask = blend_fake_to_real(imgcp, real_lmk[i].cpu().numpy(), + fake_imgcp, fake_mask.cpu().numpy(), + fake_lmk[i].cpu().numpy(), mask, type[i], + mag[i].detach().cpu().numpy()) + newimg = self.transforms(Image.fromarray(np.array(newimg, dtype=np.uint8))) + newlabel.append(int(1)) + typelabel.append(int(type[i].cpu().numpy())) + if type[i] == 2: + magmask.append(int(1)) + else: + magmask.append(int(0)) + else: + newimg = self.transforms(Image.fromarray(np.array((imgcp + 1) / 2 * 255, dtype=np.uint8))) + newmask =real_mask[i].squeeze(2)[:,:,0].cpu().numpy() + newlabel.append(int(label[i])) + if label[i] == 0: + typelabel.append(int(3)) + else: + typelabel.append(int(4)) + magmask.append(int(0)) + if newmask is None: + newmask = np.zeros((16, 16)) + newmask = cv2.resize(newmask, (16, 16), interpolation=cv2.INTER_CUBIC) + alt_img[i] = newimg + alt_mask[i] = newmask + + alt_mask = torch.from_numpy(alt_mask.astype(np.float32)).unsqueeze(1) + newlabel = torch.tensor(newlabel) + typelabel = torch.tensor(typelabel) + maglabel = mag + magmask = torch.tensor(magmask) + return log_prob, entropy, alt_img.detach(), alt_mask.detach(), \ + newlabel.detach(), typelabel.detach(), maglabel.detach(), magmask.detach() + + +if __name__ == '__main__': + + with open(r'H:\code\DeepfakeBench\training\config\detector\sladd_xception.yaml', 'r') as f: + config = yaml.safe_load(f) + syn=synthesizer(config=config).cuda() + config['data_manner'] = 'lmdb' + config['dataset_json_folder'] = 'preprocessing/dataset_json_v3' + config['sample_size']=256 + config['with_mask']=True + config['with_landmark']=True + config['use_data_augmentation']=True + config['data_aug']['rotate_prob']=1 + train_set = pairDataset(config=config, mode='train') + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=config['train_batchSize'], + shuffle=True, + num_workers=0, + collate_fn=train_set.collate_fn, + ) + from tqdm import tqdm + for iteration, batch in enumerate(tqdm(train_data_loader)): + print(iteration) + imgs,lmks,msks=batch['image'].cuda(),batch['landmark'].cuda(),batch['mask'].cuda() + half = len(imgs) // 2 + img, fake_img, real_lmk, fake_lmk, real_mask, fake_mask = imgs[:half],imgs[half:],lmks[:half],lmks[half:],msks[:half],msks[half:] + log_prob, entropy, new_img, alt_mask, label, type_label, mag_label, mag_mask = \ + syn(img, fake_img, real_lmk, fake_lmk, real_mask, fake_mask) + + if iteration > 10: + break + ... \ No newline at end of file diff --git a/training/detectors/utils/slowfast/__init__.py b/training/detectors/utils/slowfast/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e05c811313e73cc666f973abdad760a01ff2244e --- /dev/null +++ b/training/detectors/utils/slowfast/__init__.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + +from slowfast.utils.env import setup_environment + +setup_environment() diff --git a/training/detectors/utils/slowfast/config/__init__.py b/training/detectors/utils/slowfast/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dbe96a785072a24a9bcc4841a1934024f2b06a1 --- /dev/null +++ b/training/detectors/utils/slowfast/config/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. diff --git a/training/detectors/utils/slowfast/config/custom_config.py b/training/detectors/utils/slowfast/config/custom_config.py new file mode 100644 index 0000000000000000000000000000000000000000..8131da2951d8cb629f664b39da4675d0ae5adee5 --- /dev/null +++ b/training/detectors/utils/slowfast/config/custom_config.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Add custom configs and default values""" + + +def add_custom_config(_C): + # Add your own customized configs. + pass diff --git a/training/detectors/utils/slowfast/config/defaults(1).py b/training/detectors/utils/slowfast/config/defaults(1).py new file mode 100644 index 0000000000000000000000000000000000000000..083a3c1f31c822143f7a0a68b374ae8b83430b3c --- /dev/null +++ b/training/detectors/utils/slowfast/config/defaults(1).py @@ -0,0 +1,816 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Configs.""" +import yaml +from fvcore.common.config import CfgNode as CfgNodeOri + +from . import custom_config +def load_yaml_with_base(text: str, allow_unsafe: bool = False): + """ + Just like `yaml.load(open(filename))`, but inherit attributes from its + `_BASE_`. + Args: + text (str): the file name of the current config. Will be used to + find the base config file. + allow_unsafe (bool): whether to allow loading the config file with + `yaml.unsafe_load`. + Returns: + (dict): the loaded yaml + """ + cfg = yaml.load(text, Loader=yaml.FullLoader) + return cfg +class CfgNode(CfgNodeOri): + def merge_from_str(self, text, allow_unsafe=False): + loaded_cfg = load_yaml_with_base(text, allow_unsafe=allow_unsafe) + loaded_cfg = type(self)(loaded_cfg) + self.merge_from_other_cfg(loaded_cfg) + +# ----------------------------------------------------------------------------- +# Config definition +# ----------------------------------------------------------------------------- +_C = CfgNode() + + +# ---------------------------------------------------------------------------- # +# Batch norm options +# ---------------------------------------------------------------------------- # +_C.BN = CfgNode() + +# Precise BN stats. +_C.BN.USE_PRECISE_STATS = False + +# Number of samples use to compute precise bn. +_C.BN.NUM_BATCHES_PRECISE = 200 + +# Weight decay value that applies on BN. +_C.BN.WEIGHT_DECAY = 0.0 + +# Norm type, options include `batchnorm`, `sub_batchnorm`, `sync_batchnorm` +_C.BN.NORM_TYPE = "batchnorm" + +# Parameter for SubBatchNorm, where it splits the batch dimension into +# NUM_SPLITS splits, and run BN on each of them separately independently. +_C.BN.NUM_SPLITS = 1 + +# Parameter for NaiveSyncBatchNorm3d, where the stats across `NUM_SYNC_DEVICES` +# devices will be synchronized. +_C.BN.NUM_SYNC_DEVICES = 1 + + +# ---------------------------------------------------------------------------- # +# Training options. +# ---------------------------------------------------------------------------- # +_C.TRAIN = CfgNode() + +# If True Train the model, else skip training. +_C.TRAIN.ENABLE = True + +# Dataset. +_C.TRAIN.DATASET = "kinetics" + +# Total mini-batch size. +_C.TRAIN.BATCH_SIZE = 64 + +_C.TRAIN.SPLIT = "train_subset2.pth" +# Evaluate model on test data every eval period epochs. +_C.TRAIN.EVAL_PERIOD = 1 + +# Save model checkpoint every checkpoint period epochs. +_C.TRAIN.CHECKPOINT_PERIOD = 1 + +# Save model checkpoint every checkpoint period iters. +_C.TRAIN.CHECKPOINT_PERIOD_BY_ITER = 500 + + +# Resume training from the latest checkpoint in the output directory. +_C.TRAIN.AUTO_RESUME = True + +# Path to the checkpoint to load the initial weight. +_C.TRAIN.CHECKPOINT_FILE_PATH = "" + +# Checkpoint types include `caffe2` or `pytorch`. +_C.TRAIN.CHECKPOINT_TYPE = "pytorch" + +# If True, perform inflation when loading checkpoint. +_C.TRAIN.CHECKPOINT_INFLATE = False + + +# ---------------------------------------------------------------------------- # +# Testing options +# ---------------------------------------------------------------------------- # +_C.TEST = CfgNode() + +# If True test the model, else skip the testing. +_C.TEST.ENABLE = True + +# Dataset for testing. +_C.TEST.DATASET = "kinetics" + +_C.TEST.SPLIT = "test_subset2.pth" +# Total mini-batch size +_C.TEST.BATCH_SIZE = 8 + +# Path to the checkpoint to load the initial weight. +_C.TEST.CHECKPOINT_FILE_PATH = "" + +# Number of clips to sample from a video uniformly for aggregating the +# prediction results. +_C.TEST.NUM_ENSEMBLE_VIEWS = 10 + +# Number of crops to sample from a frame spatially for aggregating the +# prediction results. +_C.TEST.NUM_SPATIAL_CROPS = 3 + +# Checkpoint types include `caffe2` or `pytorch`. +_C.TEST.CHECKPOINT_TYPE = "pytorch" +# Path to saving prediction results file. +_C.TEST.SAVE_RESULTS_PATH = "" +# ----------------------------------------------------------------------------- +# ResNet options +# ----------------------------------------------------------------------------- +_C.RESNET = CfgNode() + +# Transformation function. +_C.RESNET.TRANS_FUNC = "bottleneck_transform" + +# Number of groups. 1 for ResNet, and larger than 1 for ResNeXt). +_C.RESNET.NUM_GROUPS = 1 + +# Width of each group (64 -> ResNet; 4 -> ResNeXt). +_C.RESNET.WIDTH_PER_GROUP = 64 + +# Apply relu in a inplace manner. +_C.RESNET.INPLACE_RELU = True + +# Apply stride to 1x1 conv. +_C.RESNET.STRIDE_1X1 = False + +# If true, initialize the gamma of the final BN of each block to zero. +_C.RESNET.ZERO_INIT_FINAL_BN = False + +# Number of weight layers. +_C.RESNET.DEPTH = 50 + + +# label of branchs +_C.RESNET.LABELS = ["continus","discontinus"] + +# If the current block has more than NUM_BLOCK_TEMP_KERNEL blocks, use temporal +# kernel of 1 for the rest of the blocks. +_C.RESNET.NUM_BLOCK_TEMP_KERNEL = [[3], [4], [6], [3]] + +# Size of stride on different res stages. +_C.RESNET.SPATIAL_STRIDES = [[1], [2], [2], [2]] + +# Size of dilation on different res stages. +_C.RESNET.SPATIAL_DILATIONS = [[1], [1], [1], [1]] + + +# ----------------------------------------------------------------------------- +# Nonlocal options +# ----------------------------------------------------------------------------- +_C.NONLOCAL = CfgNode() + +# Index of each stage and block to add nonlocal layers. +_C.NONLOCAL.LOCATION = [[[]], [[]], [[]], [[]]] + +# Number of group for nonlocal for each stage. +_C.NONLOCAL.GROUP = [[1], [1], [1], [1]] + +# Instatiation to use for non-local layer. +_C.NONLOCAL.INSTANTIATION = "dot_product" + + +# Size of pooling layers used in Non-Local. +_C.NONLOCAL.POOL = [ + # Res2 + [[1, 2, 2], [1, 2, 2]], + # Res3 + [[1, 2, 2], [1, 2, 2]], + # Res4 + [[1, 2, 2], [1, 2, 2]], + # Res5 + [[1, 2, 2], [1, 2, 2]], +] + +# ----------------------------------------------------------------------------- +# Model options +# ----------------------------------------------------------------------------- +_C.MODEL = CfgNode() + +# Model architecture. +_C.MODEL.ARCH = "slowfast" + +# Model name +_C.MODEL.MODEL_NAME = "SlowFast" + +# The number of classes to predict for the model. +_C.MODEL.NUM_CLASSES = 400 + +# Loss function. +_C.MODEL.LOSS_FUNC = "cross_entropy" + +_C.MODEL.MASK_WEIGHT = 100 + +_C.MODEL.CLASS_WEIGHT = 1 + +# Model architectures that has one single pathway. +_C.MODEL.SINGLE_PATHWAY_ARCH = ["c2d", "i3d", "slow"] + +# Model architectures that has multiple pathways. +_C.MODEL.MULTI_PATHWAY_ARCH = ["slowfast"] + +# Dropout rate before final projection in the backbone. +_C.MODEL.DROPOUT_RATE = 0.5 + +# The std to initialize the fc layer(s). +_C.MODEL.FC_INIT_STD = 0.01 + +# Activation layer for the output head. +_C.MODEL.HEAD_ACT = "softmax" + + +# ----------------------------------------------------------------------------- +# SlowFast options +# ----------------------------------------------------------------------------- +_C.SLOWFAST = CfgNode() + +# Corresponds to the inverse of the channel reduction ratio, $\beta$ between +# the Slow and Fast pathways. +_C.SLOWFAST.BETA_INV = 8 + +# Corresponds to the frame rate reduction ratio, $\alpha$ between the Slow and +# Fast pathways. +_C.SLOWFAST.ALPHA = 8 + +# Ratio of channel dimensions between the Slow and Fast pathways. +_C.SLOWFAST.FUSION_CONV_CHANNEL_RATIO = 2 + +# Kernel dimension used for fusing information from Fast pathway to Slow +# pathway. +_C.SLOWFAST.FUSION_KERNEL_SZ = 5 + + +# ----------------------------------------------------------------------------- +# Data options +# ----------------------------------------------------------------------------- +_C.DATA = CfgNode() + +# The path to the data directory. +_C.DATA.PATH_TO_DATA_DIR = "" + +_C.DATA.DATASET = "faceforensics" + +_C.DATA.MODE = "" + +_C.DATA.ADAPTIVE = False + +_C.DATA.SCALE = 1.0 +# The separator used between path and label. +_C.DATA.PATH_LABEL_SEPARATOR = " " + +# Video path prefix if any. +_C.DATA.PATH_PREFIX = "" + +# The spatial crop size of the input clip. +_C.DATA.CROP_SIZE = 224 + +# The number of frames of the input clip. +_C.DATA.NUM_FRAMES = 8 + +_C.DATA.NUM_FRAMES_RANGE = [1,2,3,4,5,6,7,8] + +# The video sampling rate of the input clip. +_C.DATA.SAMPLING_RATE = 8 + +# The mean value of the video raw pixels across the R G B channels. +_C.DATA.MEAN = [0.45, 0.45, 0.45] +# List of input frame channel dimensions. + +_C.DATA.INPUT_CHANNEL_NUM = [3, 3] + +# The std value of the video raw pixels across the R G B channels. +_C.DATA.STD = [0.225, 0.225, 0.225] + +# The spatial augmentation jitter scales for training. +_C.DATA.TRAIN_JITTER_SCALES = [256, 320] + +# The spatial crop size for training. +_C.DATA.TRAIN_CROP_SIZE = 224 + +# The spatial crop size for testing. +_C.DATA.TEST_CROP_SIZE = 256 + +# Input videos may has different fps, convert it to the target video fps before +# frame sampling. +_C.DATA.TARGET_FPS = 30 + +# Decoding backend, options include `pyav` or `torchvision` +_C.DATA.DECODING_BACKEND = "pyav" + +# if True, sample uniformly in [1 / max_scale, 1 / min_scale] and take a +# reciprocal to get the scale. If False, take a uniform sample from +# [min_scale, max_scale]. +_C.DATA.INV_UNIFORM_SAMPLE = False + +# If True, perform random horizontal flip on the video frames during training. +_C.DATA.RANDOM_FLIP = True + +# If True, calculdate the map as metric. +_C.DATA.MULTI_LABEL = False + +# Method to perform the ensemble, options include "sum" and "max". +_C.DATA.ENSEMBLE_METHOD = "sum" + +# If True, revert the default input channel (RBG <-> BGR). +_C.DATA.REVERSE_INPUT_CHANNEL = False + + +# ---------------------------------------------------------------------------- # +# Optimizer options +# ---------------------------------------------------------------------------- # +_C.SOLVER = CfgNode() + +# Base learning rate. +_C.SOLVER.BASE_LR = 0.1 + +# Learning rate policy (see utils/lr_policy.py for options and examples). +_C.SOLVER.LR_POLICY = "cosine" + +# Exponential decay factor. +_C.SOLVER.GAMMA = 0.1 + +# Step size for 'exp' and 'cos' policies (in epochs). +_C.SOLVER.STEP_SIZE = 1 + +# Steps for 'steps_' policies (in epochs). +_C.SOLVER.STEPS = [] + +# Learning rates for 'steps_' policies. +_C.SOLVER.LRS = [] + +# Maximal number of epochs. +_C.SOLVER.MAX_EPOCH = 300 + +# Momentum. +_C.SOLVER.MOMENTUM = 0.9 + +# Momentum dampening. +_C.SOLVER.DAMPENING = 0.0 + +# Nesterov momentum. +_C.SOLVER.NESTEROV = True + +# L2 regularization. +_C.SOLVER.WEIGHT_DECAY = 1e-4 + +# Start the warm up from SOLVER.BASE_LR * SOLVER.WARMUP_FACTOR. +_C.SOLVER.WARMUP_FACTOR = 0.1 + +# Gradually warm up the SOLVER.BASE_LR over this number of epochs. +_C.SOLVER.WARMUP_EPOCHS = 0.0 + +# The start learning rate of the warm up. +_C.SOLVER.WARMUP_START_LR = 0.01 + +# Optimization method. +_C.SOLVER.OPTIMIZING_METHOD = "sgd" + +_C.SOLVER.LR_STEP = 50000 + +_C.SOLVER.TOTAL_STEP = 200000 + +_C.SOLVER.FREEZE_STEP = 10000 + + +# ---------------------------------------------------------------------------- # +# Misc options +# ---------------------------------------------------------------------------- # + +# Number of GPUs to use (applies to both training and testing). +_C.NUM_GPUS = 1 + +# Number of machine to use for the job. +_C.NUM_SHARDS = 1 + +# The index of the current machine. +_C.SHARD_ID = 0 + +# Output basedir. +_C.OUTPUT_DIR = "./tmp" + +# train module +_C.TRAIN_MODULE= "train_unet_by_iter" + +# Note that non-determinism may still be present due to non-deterministic +# operator implementations in GPU operator libraries. +_C.RNG_SEED = 1 + +# Log period in iters. +_C.LOG_PERIOD = 10 + +# If True, log the model info. +_C.LOG_MODEL_INFO = True + +# Distributed backend. +_C.DIST_BACKEND = "nccl" + +# ---------------------------------------------------------------------------- # +# Benchmark options +# ---------------------------------------------------------------------------- # +_C.BENCHMARK = CfgNode() + +# Number of epochs for data loading benchmark. +_C.BENCHMARK.NUM_EPOCHS = 5 + +# Log period in iters for data loading benchmark. +_C.BENCHMARK.LOG_PERIOD = 100 + +# If True, shuffle dataloader for epoch during benchmark. +_C.BENCHMARK.SHUFFLE = True + + +# ---------------------------------------------------------------------------- # +# Common train/test data loader options +# ---------------------------------------------------------------------------- # +_C.DATA_LOADER = CfgNode() + +# Number of data loader workers per training process. +_C.DATA_LOADER.NUM_WORKERS = 8 + +# Load data to pinned host memory. +_C.DATA_LOADER.PIN_MEMORY = True + +# Enable multi thread decoding. +_C.DATA_LOADER.ENABLE_MULTI_THREAD_DECODE = False + + +# ---------------------------------------------------------------------------- # +# Detection options. +# ---------------------------------------------------------------------------- # +_C.DETECTION = CfgNode() + +# Whether enable video detection. +_C.DETECTION.ENABLE = False + +# Aligned version of RoI. More details can be found at slowfast/models/head_helper.py +_C.DETECTION.ALIGNED = True + +# Spatial scale factor. +_C.DETECTION.SPATIAL_SCALE_FACTOR = 16 + +# RoI tranformation resolution. +_C.DETECTION.ROI_XFORM_RESOLUTION = 7 + + +# ----------------------------------------------------------------------------- +# AVA Dataset options +# ----------------------------------------------------------------------------- +_C.AVA = CfgNode() + +# Directory path of frames. +_C.AVA.FRAME_DIR = "/mnt/fair-flash3-east/ava_trainval_frames.img/" + +# Directory path for files of frame lists. +_C.AVA.FRAME_LIST_DIR = ( + "/mnt/vol/gfsai-flash3-east/ai-group/users/haoqifan/ava/frame_list/" +) + +# Directory path for annotation files. +_C.AVA.ANNOTATION_DIR = ( + "/mnt/vol/gfsai-flash3-east/ai-group/users/haoqifan/ava/frame_list/" +) + +# Filenames of training samples list files. +_C.AVA.TRAIN_LISTS = ["train.csv"] + +# Filenames of test samples list files. +_C.AVA.TEST_LISTS = ["val.csv"] + +# Filenames of box list files for training. Note that we assume files which +# contains predicted boxes will have a suffix "predicted_boxes" in the +# filename. +_C.AVA.TRAIN_GT_BOX_LISTS = ["ava_train_v2.2.csv"] +_C.AVA.TRAIN_PREDICT_BOX_LISTS = [] + +# Filenames of box list files for test. +_C.AVA.TEST_PREDICT_BOX_LISTS = ["ava_val_predicted_boxes.csv"] + +# This option controls the score threshold for the predicted boxes to use. +_C.AVA.DETECTION_SCORE_THRESH = 0.9 + +# If use BGR as the format of input frames. +_C.AVA.BGR = False + +# Training augmentation parameters +# Whether to use color augmentation method. +_C.AVA.TRAIN_USE_COLOR_AUGMENTATION = False + +# Whether to only use PCA jitter augmentation when using color augmentation +# method (otherwise combine with color jitter method). +_C.AVA.TRAIN_PCA_JITTER_ONLY = True + +# Eigenvalues for PCA jittering. Note PCA is RGB based. +_C.AVA.TRAIN_PCA_EIGVAL = [0.225, 0.224, 0.229] + +# Eigenvectors for PCA jittering. +_C.AVA.TRAIN_PCA_EIGVEC = [ + [-0.5675, 0.7192, 0.4009], + [-0.5808, -0.0045, -0.8140], + [-0.5836, -0.6948, 0.4203], +] + +# Whether to do horizontal flipping during test. +_C.AVA.TEST_FORCE_FLIP = False + +# Whether to use full test set for validation split. +_C.AVA.FULL_TEST_ON_VAL = False + +# The name of the file to the ava label map. +_C.AVA.LABEL_MAP_FILE = "ava_action_list_v2.2_for_activitynet_2019.pbtxt" + +# The name of the file to the ava exclusion. +_C.AVA.EXCLUSION_FILE = "ava_val_excluded_timestamps_v2.2.csv" + +# The name of the file to the ava groundtruth. +_C.AVA.GROUNDTRUTH_FILE = "ava_val_v2.2.csv" + +# Backend to process image, includes `pytorch` and `cv2`. +_C.AVA.IMG_PROC_BACKEND = "cv2" + +# ---------------------------------------------------------------------------- # +# Multigrid training options +# See https://arxiv.org/abs/1912.00998 for details about multigrid training. +# ---------------------------------------------------------------------------- # +_C.MULTIGRID = CfgNode() + +# Multigrid training allows us to train for more epochs with fewer iterations. +# This hyperparameter specifies how many times more epochs to train. +# The default setting in paper trains for 1.5x more epochs than baseline. +_C.MULTIGRID.EPOCH_FACTOR = 1.5 + +# Enable short cycles. +_C.MULTIGRID.SHORT_CYCLE = False +# Short cycle additional spatial dimensions relative to the default crop size. +_C.MULTIGRID.SHORT_CYCLE_FACTORS = [0.5, 0.5 ** 0.5] + +_C.MULTIGRID.LONG_CYCLE = False +# (Temporal, Spatial) dimensions relative to the default shape. +_C.MULTIGRID.LONG_CYCLE_FACTORS = [ + (0.25, 0.5 ** 0.5), + (0.5, 0.5 ** 0.5), + (0.5, 1), + (1, 1), +] + +# While a standard BN computes stats across all examples in a GPU, +# for multigrid training we fix the number of clips to compute BN stats on. +# See https://arxiv.org/abs/1912.00998 for details. +_C.MULTIGRID.BN_BASE_SIZE = 8 + +# Multigrid training epochs are not proportional to actual training time or +# computations, so _C.TRAIN.EVAL_PERIOD leads to too frequent or rare +# evaluation. We use a multigrid-specific rule to determine when to evaluate: +# This hyperparameter defines how many times to evaluate a model per long +# cycle shape. +_C.MULTIGRID.EVAL_FREQ = 3 + +# No need to specify; Set automatically and used as global variables. +_C.MULTIGRID.LONG_CYCLE_SAMPLING_RATE = 0 +_C.MULTIGRID.DEFAULT_B = 0 +_C.MULTIGRID.DEFAULT_T = 0 +_C.MULTIGRID.DEFAULT_S = 0 + +# ----------------------------------------------------------------------------- +# Tensorboard Visualization Options +# ----------------------------------------------------------------------------- +_C.TENSORBOARD = CfgNode() + +# Log to summary writer, this will automatically. +# log loss, lr and metrics during train/eval. +_C.TENSORBOARD.ENABLE = False +# Provide path to prediction results for visualization. +# This is a pickle file of [prediction_tensor, label_tensor] +_C.TENSORBOARD.PREDICTIONS_PATH = "" +# Path to directory for tensorboard logs. +# Default to to cfg.OUTPUT_DIR/runs-{cfg.TRAIN.DATASET}. +_C.TENSORBOARD.LOG_DIR = "" +# Path to a json file providing class_name - id mapping +# in the format {"class_name1": id1, "class_name2": id2, ...}. +# This file must be provided to enable plotting confusion matrix +# by a subset or parent categories. +_C.TENSORBOARD.CLASS_NAMES_PATH = "" + +# Path to a json file for categories -> classes mapping +# in the format {"parent_class": ["child_class1", "child_class2",...], ...}. +_C.TENSORBOARD.CATEGORIES_PATH = "" + +# Config for confusion matrices visualization. +_C.TENSORBOARD.CONFUSION_MATRIX = CfgNode() +# Visualize confusion matrix. +_C.TENSORBOARD.CONFUSION_MATRIX.ENABLE = False +# Figure size of the confusion matrices plotted. +_C.TENSORBOARD.CONFUSION_MATRIX.FIGSIZE = [8, 8] +# Path to a subset of categories to visualize. +# File contains class names separated by newline characters. +_C.TENSORBOARD.CONFUSION_MATRIX.SUBSET_PATH = "" + +# Config for histogram visualization. +_C.TENSORBOARD.HISTOGRAM = CfgNode() +# Visualize histograms. +_C.TENSORBOARD.HISTOGRAM.ENABLE = False +# Path to a subset of classes to plot histograms. +# Class names must be separated by newline characters. +_C.TENSORBOARD.HISTOGRAM.SUBSET_PATH = "" +# Visualize top-k most predicted classes on histograms for each +# chosen true label. +_C.TENSORBOARD.HISTOGRAM.TOPK = 10 +# Figure size of the histograms plotted. +_C.TENSORBOARD.HISTOGRAM.FIGSIZE = [8, 8] + +# Config for layers' weights and activations visualization. +# _C.TENSORBOARD.ENABLE must be True. +_C.TENSORBOARD.MODEL_VIS = CfgNode() + +# If False, skip model visualization. +_C.TENSORBOARD.MODEL_VIS.ENABLE = False + +# If False, skip visualizing model weights. +_C.TENSORBOARD.MODEL_VIS.MODEL_WEIGHTS = False + +# If False, skip visualizing model activations. +_C.TENSORBOARD.MODEL_VIS.ACTIVATIONS = False + +# If False, skip visualizing input videos. +_C.TENSORBOARD.MODEL_VIS.INPUT_VIDEO = False + + +# List of strings containing data about layer names and their indexing to +# visualize weights and activations for. The indexing is meant for +# choosing a subset of activations outputed by a layer for visualization. +# If indexing is not specified, visualize all activations outputed by the layer. +# For each string, layer name and indexing is separated by whitespaces. +# e.g.: [layer1 1,2;1,2, layer2, layer3 150,151;3,4]; this means for each array `arr` +# along the batch dimension in `layer1`, we take arr[[1, 2], [1, 2]] +_C.TENSORBOARD.MODEL_VIS.LAYER_LIST = [] +# Top-k predictions to plot on videos +_C.TENSORBOARD.MODEL_VIS.TOPK_PREDS = 1 +# Colormap to for text boxes and bounding boxes colors +_C.TENSORBOARD.MODEL_VIS.COLORMAP = "Pastel2" +# Config for visualization video inputs with Grad-CAM. +# _C.TENSORBOARD.ENABLE must be True. +_C.TENSORBOARD.MODEL_VIS.GRAD_CAM = CfgNode() +# Whether to run visualization using Grad-CAM technique. +_C.TENSORBOARD.MODEL_VIS.GRAD_CAM.ENABLE = True +# CNN layers to use for Grad-CAM. The number of layers must be equal to +# number of pathway(s). +_C.TENSORBOARD.MODEL_VIS.GRAD_CAM.LAYER_LIST = [] +# If True, visualize Grad-CAM using true labels for each instances. +# If False, use the highest predicted class. +_C.TENSORBOARD.MODEL_VIS.GRAD_CAM.USE_TRUE_LABEL = False +# Colormap to for text boxes and bounding boxes colors +_C.TENSORBOARD.MODEL_VIS.GRAD_CAM.COLORMAP = "viridis" + +# Config for visualization for wrong prediction visualization. +# _C.TENSORBOARD.ENABLE must be True. +_C.TENSORBOARD.WRONG_PRED_VIS = CfgNode() +_C.TENSORBOARD.WRONG_PRED_VIS.ENABLE = False +# Folder tag to origanize model eval videos under. +_C.TENSORBOARD.WRONG_PRED_VIS.TAG = "Incorrectly classified videos." +# Subset of labels to visualize. Only wrong predictions with true labels +# within this subset is visualized. +_C.TENSORBOARD.WRONG_PRED_VIS.SUBSET_PATH = "" + + + +############### +_C.JITTER = CfgNode() + +_C.JITTER.ENABLE = False + +_C.JITTER.CONTINUS_METHODS=["blend_diff_person","blend_downsampled","blend_same_person"] +_C.JITTER.DISCONTINUS_METHODS=["light", "rotate", "skip"] + +_C.JITTER.STRONG_INNER_CLIP_MASK_JITTER= False + +# ---------------------------------------------------------------------------- # +# Demo options +# ---------------------------------------------------------------------------- # +_C.DEMO = CfgNode() + +# Run model in DEMO mode. +_C.DEMO.ENABLE = False + +# Path to a json file providing class_name - id mapping +# in the format {"class_name1": id1, "class_name2": id2, ...}. +_C.DEMO.LABEL_FILE_PATH = "" + +# Specify a camera device as input. This will be prioritized +# over input video if set. +# If -1, use input video instead. +_C.DEMO.WEBCAM = -1 + +# Path to input video for demo. +_C.DEMO.INPUT_VIDEO = "" +# Custom width for reading input video data. +_C.DEMO.DISPLAY_WIDTH = 0 +# Custom height for reading input video data. +_C.DEMO.DISPLAY_HEIGHT = 0 +# Path to Detectron2 object detection model configuration, +# only used for detection tasks. +_C.DEMO.DETECTRON2_CFG = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml" +# Path to Detectron2 object detection model pre-trained weights. +_C.DEMO.DETECTRON2_WEIGHTS = "detectron2://COCO-Detection/faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl" +# Threshold for choosing predicted bounding boxes by Detectron2. +_C.DEMO.DETECTRON2_THRESH = 0.9 +# Number of overlapping frames between 2 consecutive clips. +# Increase this number for more frequent action predictions. +# The number of overlapping frames cannot be larger than +# half of the sequence length `cfg.DATA.NUM_FRAMES * cfg.DATA.SAMPLING_RATE` +_C.DEMO.BUFFER_SIZE = 0 +# If specified, the visualized outputs will be written this a video file of +# this path. Otherwise, the visualized outputs will be displayed in a window. +_C.DEMO.OUTPUT_FILE = "" +# Frames per second rate for writing to output video file. +# If not set (-1), use fps rate from input file. +_C.DEMO.OUTPUT_FPS = -1 +# Input format from demo video reader ("RGB" or "BGR"). +_C.DEMO.INPUT_FORMAT = "BGR" +# Draw visualization frames in [keyframe_idx - CLIP_VIS_SIZE, keyframe_idx + CLIP_VIS_SIZE] inclusively. +_C.DEMO.CLIP_VIS_SIZE = 10 +# Number of processes to run video visualizer. +_C.DEMO.NUM_VIS_INSTANCES = 2 + +# Path to pre-computed predicted boxes +_C.DEMO.PREDS_BOXES = "" +# Whether to run in with multi-threaded video reader. +_C.DEMO.THREAD_ENABLE = False +# Take one clip for every `DEMO.NUM_CLIPS_SKIP` + 1 for prediction and visualization. +# This is used for fast demo speed by reducing the prediction/visualiztion frequency. +# If -1, take the most recent read clip for visualization. This mode is only supported +# if `DEMO.THREAD_ENABLE` is set to True. +_C.DEMO.NUM_CLIPS_SKIP = 0 +# Path to ground-truth boxes and labels (optional) +_C.DEMO.GT_BOXES = "" +# The starting second of the video w.r.t bounding boxes file. +_C.DEMO.STARTING_SECOND = 900 +# Frames per second of the input video/folder of images. +_C.DEMO.FPS = 30 +# Visualize with top-k predictions or predictions above certain threshold(s). +# Option: {"thres", "top-k"} +_C.DEMO.VIS_MODE = "thres" +# Threshold for common class names. +_C.DEMO.COMMON_CLASS_THRES = 0.7 +# Theshold for uncommon class names. This will not be +# used if `_C.DEMO.COMMON_CLASS_NAMES` is empty. +_C.DEMO.UNCOMMON_CLASS_THRES = 0.3 +# This is chosen based on distribution of examples in +# each classes in AVA dataset. +_C.DEMO.COMMON_CLASS_NAMES = [ + "watch (a person)", + "talk to (e.g., self, a person, a group)", + "listen to (a person)", + "touch (an object)", + "carry/hold (an object)", + "walk", + "sit", + "lie/sleep", + "bend/bow (at the waist)", +] +# Slow-motion rate for the visualization. The visualized portions of the +# video will be played `_C.DEMO.SLOWMO` times slower than usual speed. +_C.DEMO.SLOWMO = 1 + +# Add custom config with default values. +custom_config.add_custom_config(_C) + + +def _assert_and_infer_cfg(cfg): + # BN assertions. + if cfg.BN.USE_PRECISE_STATS: + assert cfg.BN.NUM_BATCHES_PRECISE >= 0 + # TRAIN assertions. + assert cfg.TRAIN.CHECKPOINT_TYPE in ["pytorch", "caffe2"] + assert cfg.TRAIN.BATCH_SIZE % cfg.NUM_GPUS == 0 + + # TEST assertions. + assert cfg.TEST.CHECKPOINT_TYPE in ["pytorch", "caffe2"] + assert cfg.TEST.BATCH_SIZE % cfg.NUM_GPUS == 0 + assert cfg.TEST.NUM_SPATIAL_CROPS == 3 + + # RESNET assertions. + assert cfg.RESNET.NUM_GROUPS > 0 + assert cfg.RESNET.WIDTH_PER_GROUP > 0 + assert cfg.RESNET.WIDTH_PER_GROUP % cfg.RESNET.NUM_GROUPS == 0 + + # General assertions. + assert cfg.SHARD_ID < cfg.NUM_SHARDS + return cfg + + +def get_cfg(): + """ + Get a copy of the default config. + """ + return _assert_and_infer_cfg(_C.clone()) diff --git a/training/detectors/utils/slowfast/config/defaults.py b/training/detectors/utils/slowfast/config/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..083a3c1f31c822143f7a0a68b374ae8b83430b3c --- /dev/null +++ b/training/detectors/utils/slowfast/config/defaults.py @@ -0,0 +1,816 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Configs.""" +import yaml +from fvcore.common.config import CfgNode as CfgNodeOri + +from . import custom_config +def load_yaml_with_base(text: str, allow_unsafe: bool = False): + """ + Just like `yaml.load(open(filename))`, but inherit attributes from its + `_BASE_`. + Args: + text (str): the file name of the current config. Will be used to + find the base config file. + allow_unsafe (bool): whether to allow loading the config file with + `yaml.unsafe_load`. + Returns: + (dict): the loaded yaml + """ + cfg = yaml.load(text, Loader=yaml.FullLoader) + return cfg +class CfgNode(CfgNodeOri): + def merge_from_str(self, text, allow_unsafe=False): + loaded_cfg = load_yaml_with_base(text, allow_unsafe=allow_unsafe) + loaded_cfg = type(self)(loaded_cfg) + self.merge_from_other_cfg(loaded_cfg) + +# ----------------------------------------------------------------------------- +# Config definition +# ----------------------------------------------------------------------------- +_C = CfgNode() + + +# ---------------------------------------------------------------------------- # +# Batch norm options +# ---------------------------------------------------------------------------- # +_C.BN = CfgNode() + +# Precise BN stats. +_C.BN.USE_PRECISE_STATS = False + +# Number of samples use to compute precise bn. +_C.BN.NUM_BATCHES_PRECISE = 200 + +# Weight decay value that applies on BN. +_C.BN.WEIGHT_DECAY = 0.0 + +# Norm type, options include `batchnorm`, `sub_batchnorm`, `sync_batchnorm` +_C.BN.NORM_TYPE = "batchnorm" + +# Parameter for SubBatchNorm, where it splits the batch dimension into +# NUM_SPLITS splits, and run BN on each of them separately independently. +_C.BN.NUM_SPLITS = 1 + +# Parameter for NaiveSyncBatchNorm3d, where the stats across `NUM_SYNC_DEVICES` +# devices will be synchronized. +_C.BN.NUM_SYNC_DEVICES = 1 + + +# ---------------------------------------------------------------------------- # +# Training options. +# ---------------------------------------------------------------------------- # +_C.TRAIN = CfgNode() + +# If True Train the model, else skip training. +_C.TRAIN.ENABLE = True + +# Dataset. +_C.TRAIN.DATASET = "kinetics" + +# Total mini-batch size. +_C.TRAIN.BATCH_SIZE = 64 + +_C.TRAIN.SPLIT = "train_subset2.pth" +# Evaluate model on test data every eval period epochs. +_C.TRAIN.EVAL_PERIOD = 1 + +# Save model checkpoint every checkpoint period epochs. +_C.TRAIN.CHECKPOINT_PERIOD = 1 + +# Save model checkpoint every checkpoint period iters. +_C.TRAIN.CHECKPOINT_PERIOD_BY_ITER = 500 + + +# Resume training from the latest checkpoint in the output directory. +_C.TRAIN.AUTO_RESUME = True + +# Path to the checkpoint to load the initial weight. +_C.TRAIN.CHECKPOINT_FILE_PATH = "" + +# Checkpoint types include `caffe2` or `pytorch`. +_C.TRAIN.CHECKPOINT_TYPE = "pytorch" + +# If True, perform inflation when loading checkpoint. +_C.TRAIN.CHECKPOINT_INFLATE = False + + +# ---------------------------------------------------------------------------- # +# Testing options +# ---------------------------------------------------------------------------- # +_C.TEST = CfgNode() + +# If True test the model, else skip the testing. +_C.TEST.ENABLE = True + +# Dataset for testing. +_C.TEST.DATASET = "kinetics" + +_C.TEST.SPLIT = "test_subset2.pth" +# Total mini-batch size +_C.TEST.BATCH_SIZE = 8 + +# Path to the checkpoint to load the initial weight. +_C.TEST.CHECKPOINT_FILE_PATH = "" + +# Number of clips to sample from a video uniformly for aggregating the +# prediction results. +_C.TEST.NUM_ENSEMBLE_VIEWS = 10 + +# Number of crops to sample from a frame spatially for aggregating the +# prediction results. +_C.TEST.NUM_SPATIAL_CROPS = 3 + +# Checkpoint types include `caffe2` or `pytorch`. +_C.TEST.CHECKPOINT_TYPE = "pytorch" +# Path to saving prediction results file. +_C.TEST.SAVE_RESULTS_PATH = "" +# ----------------------------------------------------------------------------- +# ResNet options +# ----------------------------------------------------------------------------- +_C.RESNET = CfgNode() + +# Transformation function. +_C.RESNET.TRANS_FUNC = "bottleneck_transform" + +# Number of groups. 1 for ResNet, and larger than 1 for ResNeXt). +_C.RESNET.NUM_GROUPS = 1 + +# Width of each group (64 -> ResNet; 4 -> ResNeXt). +_C.RESNET.WIDTH_PER_GROUP = 64 + +# Apply relu in a inplace manner. +_C.RESNET.INPLACE_RELU = True + +# Apply stride to 1x1 conv. +_C.RESNET.STRIDE_1X1 = False + +# If true, initialize the gamma of the final BN of each block to zero. +_C.RESNET.ZERO_INIT_FINAL_BN = False + +# Number of weight layers. +_C.RESNET.DEPTH = 50 + + +# label of branchs +_C.RESNET.LABELS = ["continus","discontinus"] + +# If the current block has more than NUM_BLOCK_TEMP_KERNEL blocks, use temporal +# kernel of 1 for the rest of the blocks. +_C.RESNET.NUM_BLOCK_TEMP_KERNEL = [[3], [4], [6], [3]] + +# Size of stride on different res stages. +_C.RESNET.SPATIAL_STRIDES = [[1], [2], [2], [2]] + +# Size of dilation on different res stages. +_C.RESNET.SPATIAL_DILATIONS = [[1], [1], [1], [1]] + + +# ----------------------------------------------------------------------------- +# Nonlocal options +# ----------------------------------------------------------------------------- +_C.NONLOCAL = CfgNode() + +# Index of each stage and block to add nonlocal layers. +_C.NONLOCAL.LOCATION = [[[]], [[]], [[]], [[]]] + +# Number of group for nonlocal for each stage. +_C.NONLOCAL.GROUP = [[1], [1], [1], [1]] + +# Instatiation to use for non-local layer. +_C.NONLOCAL.INSTANTIATION = "dot_product" + + +# Size of pooling layers used in Non-Local. +_C.NONLOCAL.POOL = [ + # Res2 + [[1, 2, 2], [1, 2, 2]], + # Res3 + [[1, 2, 2], [1, 2, 2]], + # Res4 + [[1, 2, 2], [1, 2, 2]], + # Res5 + [[1, 2, 2], [1, 2, 2]], +] + +# ----------------------------------------------------------------------------- +# Model options +# ----------------------------------------------------------------------------- +_C.MODEL = CfgNode() + +# Model architecture. +_C.MODEL.ARCH = "slowfast" + +# Model name +_C.MODEL.MODEL_NAME = "SlowFast" + +# The number of classes to predict for the model. +_C.MODEL.NUM_CLASSES = 400 + +# Loss function. +_C.MODEL.LOSS_FUNC = "cross_entropy" + +_C.MODEL.MASK_WEIGHT = 100 + +_C.MODEL.CLASS_WEIGHT = 1 + +# Model architectures that has one single pathway. +_C.MODEL.SINGLE_PATHWAY_ARCH = ["c2d", "i3d", "slow"] + +# Model architectures that has multiple pathways. +_C.MODEL.MULTI_PATHWAY_ARCH = ["slowfast"] + +# Dropout rate before final projection in the backbone. +_C.MODEL.DROPOUT_RATE = 0.5 + +# The std to initialize the fc layer(s). +_C.MODEL.FC_INIT_STD = 0.01 + +# Activation layer for the output head. +_C.MODEL.HEAD_ACT = "softmax" + + +# ----------------------------------------------------------------------------- +# SlowFast options +# ----------------------------------------------------------------------------- +_C.SLOWFAST = CfgNode() + +# Corresponds to the inverse of the channel reduction ratio, $\beta$ between +# the Slow and Fast pathways. +_C.SLOWFAST.BETA_INV = 8 + +# Corresponds to the frame rate reduction ratio, $\alpha$ between the Slow and +# Fast pathways. +_C.SLOWFAST.ALPHA = 8 + +# Ratio of channel dimensions between the Slow and Fast pathways. +_C.SLOWFAST.FUSION_CONV_CHANNEL_RATIO = 2 + +# Kernel dimension used for fusing information from Fast pathway to Slow +# pathway. +_C.SLOWFAST.FUSION_KERNEL_SZ = 5 + + +# ----------------------------------------------------------------------------- +# Data options +# ----------------------------------------------------------------------------- +_C.DATA = CfgNode() + +# The path to the data directory. +_C.DATA.PATH_TO_DATA_DIR = "" + +_C.DATA.DATASET = "faceforensics" + +_C.DATA.MODE = "" + +_C.DATA.ADAPTIVE = False + +_C.DATA.SCALE = 1.0 +# The separator used between path and label. +_C.DATA.PATH_LABEL_SEPARATOR = " " + +# Video path prefix if any. +_C.DATA.PATH_PREFIX = "" + +# The spatial crop size of the input clip. +_C.DATA.CROP_SIZE = 224 + +# The number of frames of the input clip. +_C.DATA.NUM_FRAMES = 8 + +_C.DATA.NUM_FRAMES_RANGE = [1,2,3,4,5,6,7,8] + +# The video sampling rate of the input clip. +_C.DATA.SAMPLING_RATE = 8 + +# The mean value of the video raw pixels across the R G B channels. +_C.DATA.MEAN = [0.45, 0.45, 0.45] +# List of input frame channel dimensions. + +_C.DATA.INPUT_CHANNEL_NUM = [3, 3] + +# The std value of the video raw pixels across the R G B channels. +_C.DATA.STD = [0.225, 0.225, 0.225] + +# The spatial augmentation jitter scales for training. +_C.DATA.TRAIN_JITTER_SCALES = [256, 320] + +# The spatial crop size for training. +_C.DATA.TRAIN_CROP_SIZE = 224 + +# The spatial crop size for testing. +_C.DATA.TEST_CROP_SIZE = 256 + +# Input videos may has different fps, convert it to the target video fps before +# frame sampling. +_C.DATA.TARGET_FPS = 30 + +# Decoding backend, options include `pyav` or `torchvision` +_C.DATA.DECODING_BACKEND = "pyav" + +# if True, sample uniformly in [1 / max_scale, 1 / min_scale] and take a +# reciprocal to get the scale. If False, take a uniform sample from +# [min_scale, max_scale]. +_C.DATA.INV_UNIFORM_SAMPLE = False + +# If True, perform random horizontal flip on the video frames during training. +_C.DATA.RANDOM_FLIP = True + +# If True, calculdate the map as metric. +_C.DATA.MULTI_LABEL = False + +# Method to perform the ensemble, options include "sum" and "max". +_C.DATA.ENSEMBLE_METHOD = "sum" + +# If True, revert the default input channel (RBG <-> BGR). +_C.DATA.REVERSE_INPUT_CHANNEL = False + + +# ---------------------------------------------------------------------------- # +# Optimizer options +# ---------------------------------------------------------------------------- # +_C.SOLVER = CfgNode() + +# Base learning rate. +_C.SOLVER.BASE_LR = 0.1 + +# Learning rate policy (see utils/lr_policy.py for options and examples). +_C.SOLVER.LR_POLICY = "cosine" + +# Exponential decay factor. +_C.SOLVER.GAMMA = 0.1 + +# Step size for 'exp' and 'cos' policies (in epochs). +_C.SOLVER.STEP_SIZE = 1 + +# Steps for 'steps_' policies (in epochs). +_C.SOLVER.STEPS = [] + +# Learning rates for 'steps_' policies. +_C.SOLVER.LRS = [] + +# Maximal number of epochs. +_C.SOLVER.MAX_EPOCH = 300 + +# Momentum. +_C.SOLVER.MOMENTUM = 0.9 + +# Momentum dampening. +_C.SOLVER.DAMPENING = 0.0 + +# Nesterov momentum. +_C.SOLVER.NESTEROV = True + +# L2 regularization. +_C.SOLVER.WEIGHT_DECAY = 1e-4 + +# Start the warm up from SOLVER.BASE_LR * SOLVER.WARMUP_FACTOR. +_C.SOLVER.WARMUP_FACTOR = 0.1 + +# Gradually warm up the SOLVER.BASE_LR over this number of epochs. +_C.SOLVER.WARMUP_EPOCHS = 0.0 + +# The start learning rate of the warm up. +_C.SOLVER.WARMUP_START_LR = 0.01 + +# Optimization method. +_C.SOLVER.OPTIMIZING_METHOD = "sgd" + +_C.SOLVER.LR_STEP = 50000 + +_C.SOLVER.TOTAL_STEP = 200000 + +_C.SOLVER.FREEZE_STEP = 10000 + + +# ---------------------------------------------------------------------------- # +# Misc options +# ---------------------------------------------------------------------------- # + +# Number of GPUs to use (applies to both training and testing). +_C.NUM_GPUS = 1 + +# Number of machine to use for the job. +_C.NUM_SHARDS = 1 + +# The index of the current machine. +_C.SHARD_ID = 0 + +# Output basedir. +_C.OUTPUT_DIR = "./tmp" + +# train module +_C.TRAIN_MODULE= "train_unet_by_iter" + +# Note that non-determinism may still be present due to non-deterministic +# operator implementations in GPU operator libraries. +_C.RNG_SEED = 1 + +# Log period in iters. +_C.LOG_PERIOD = 10 + +# If True, log the model info. +_C.LOG_MODEL_INFO = True + +# Distributed backend. +_C.DIST_BACKEND = "nccl" + +# ---------------------------------------------------------------------------- # +# Benchmark options +# ---------------------------------------------------------------------------- # +_C.BENCHMARK = CfgNode() + +# Number of epochs for data loading benchmark. +_C.BENCHMARK.NUM_EPOCHS = 5 + +# Log period in iters for data loading benchmark. +_C.BENCHMARK.LOG_PERIOD = 100 + +# If True, shuffle dataloader for epoch during benchmark. +_C.BENCHMARK.SHUFFLE = True + + +# ---------------------------------------------------------------------------- # +# Common train/test data loader options +# ---------------------------------------------------------------------------- # +_C.DATA_LOADER = CfgNode() + +# Number of data loader workers per training process. +_C.DATA_LOADER.NUM_WORKERS = 8 + +# Load data to pinned host memory. +_C.DATA_LOADER.PIN_MEMORY = True + +# Enable multi thread decoding. +_C.DATA_LOADER.ENABLE_MULTI_THREAD_DECODE = False + + +# ---------------------------------------------------------------------------- # +# Detection options. +# ---------------------------------------------------------------------------- # +_C.DETECTION = CfgNode() + +# Whether enable video detection. +_C.DETECTION.ENABLE = False + +# Aligned version of RoI. More details can be found at slowfast/models/head_helper.py +_C.DETECTION.ALIGNED = True + +# Spatial scale factor. +_C.DETECTION.SPATIAL_SCALE_FACTOR = 16 + +# RoI tranformation resolution. +_C.DETECTION.ROI_XFORM_RESOLUTION = 7 + + +# ----------------------------------------------------------------------------- +# AVA Dataset options +# ----------------------------------------------------------------------------- +_C.AVA = CfgNode() + +# Directory path of frames. +_C.AVA.FRAME_DIR = "/mnt/fair-flash3-east/ava_trainval_frames.img/" + +# Directory path for files of frame lists. +_C.AVA.FRAME_LIST_DIR = ( + "/mnt/vol/gfsai-flash3-east/ai-group/users/haoqifan/ava/frame_list/" +) + +# Directory path for annotation files. +_C.AVA.ANNOTATION_DIR = ( + "/mnt/vol/gfsai-flash3-east/ai-group/users/haoqifan/ava/frame_list/" +) + +# Filenames of training samples list files. +_C.AVA.TRAIN_LISTS = ["train.csv"] + +# Filenames of test samples list files. +_C.AVA.TEST_LISTS = ["val.csv"] + +# Filenames of box list files for training. Note that we assume files which +# contains predicted boxes will have a suffix "predicted_boxes" in the +# filename. +_C.AVA.TRAIN_GT_BOX_LISTS = ["ava_train_v2.2.csv"] +_C.AVA.TRAIN_PREDICT_BOX_LISTS = [] + +# Filenames of box list files for test. +_C.AVA.TEST_PREDICT_BOX_LISTS = ["ava_val_predicted_boxes.csv"] + +# This option controls the score threshold for the predicted boxes to use. +_C.AVA.DETECTION_SCORE_THRESH = 0.9 + +# If use BGR as the format of input frames. +_C.AVA.BGR = False + +# Training augmentation parameters +# Whether to use color augmentation method. +_C.AVA.TRAIN_USE_COLOR_AUGMENTATION = False + +# Whether to only use PCA jitter augmentation when using color augmentation +# method (otherwise combine with color jitter method). +_C.AVA.TRAIN_PCA_JITTER_ONLY = True + +# Eigenvalues for PCA jittering. Note PCA is RGB based. +_C.AVA.TRAIN_PCA_EIGVAL = [0.225, 0.224, 0.229] + +# Eigenvectors for PCA jittering. +_C.AVA.TRAIN_PCA_EIGVEC = [ + [-0.5675, 0.7192, 0.4009], + [-0.5808, -0.0045, -0.8140], + [-0.5836, -0.6948, 0.4203], +] + +# Whether to do horizontal flipping during test. +_C.AVA.TEST_FORCE_FLIP = False + +# Whether to use full test set for validation split. +_C.AVA.FULL_TEST_ON_VAL = False + +# The name of the file to the ava label map. +_C.AVA.LABEL_MAP_FILE = "ava_action_list_v2.2_for_activitynet_2019.pbtxt" + +# The name of the file to the ava exclusion. +_C.AVA.EXCLUSION_FILE = "ava_val_excluded_timestamps_v2.2.csv" + +# The name of the file to the ava groundtruth. +_C.AVA.GROUNDTRUTH_FILE = "ava_val_v2.2.csv" + +# Backend to process image, includes `pytorch` and `cv2`. +_C.AVA.IMG_PROC_BACKEND = "cv2" + +# ---------------------------------------------------------------------------- # +# Multigrid training options +# See https://arxiv.org/abs/1912.00998 for details about multigrid training. +# ---------------------------------------------------------------------------- # +_C.MULTIGRID = CfgNode() + +# Multigrid training allows us to train for more epochs with fewer iterations. +# This hyperparameter specifies how many times more epochs to train. +# The default setting in paper trains for 1.5x more epochs than baseline. +_C.MULTIGRID.EPOCH_FACTOR = 1.5 + +# Enable short cycles. +_C.MULTIGRID.SHORT_CYCLE = False +# Short cycle additional spatial dimensions relative to the default crop size. +_C.MULTIGRID.SHORT_CYCLE_FACTORS = [0.5, 0.5 ** 0.5] + +_C.MULTIGRID.LONG_CYCLE = False +# (Temporal, Spatial) dimensions relative to the default shape. +_C.MULTIGRID.LONG_CYCLE_FACTORS = [ + (0.25, 0.5 ** 0.5), + (0.5, 0.5 ** 0.5), + (0.5, 1), + (1, 1), +] + +# While a standard BN computes stats across all examples in a GPU, +# for multigrid training we fix the number of clips to compute BN stats on. +# See https://arxiv.org/abs/1912.00998 for details. +_C.MULTIGRID.BN_BASE_SIZE = 8 + +# Multigrid training epochs are not proportional to actual training time or +# computations, so _C.TRAIN.EVAL_PERIOD leads to too frequent or rare +# evaluation. We use a multigrid-specific rule to determine when to evaluate: +# This hyperparameter defines how many times to evaluate a model per long +# cycle shape. +_C.MULTIGRID.EVAL_FREQ = 3 + +# No need to specify; Set automatically and used as global variables. +_C.MULTIGRID.LONG_CYCLE_SAMPLING_RATE = 0 +_C.MULTIGRID.DEFAULT_B = 0 +_C.MULTIGRID.DEFAULT_T = 0 +_C.MULTIGRID.DEFAULT_S = 0 + +# ----------------------------------------------------------------------------- +# Tensorboard Visualization Options +# ----------------------------------------------------------------------------- +_C.TENSORBOARD = CfgNode() + +# Log to summary writer, this will automatically. +# log loss, lr and metrics during train/eval. +_C.TENSORBOARD.ENABLE = False +# Provide path to prediction results for visualization. +# This is a pickle file of [prediction_tensor, label_tensor] +_C.TENSORBOARD.PREDICTIONS_PATH = "" +# Path to directory for tensorboard logs. +# Default to to cfg.OUTPUT_DIR/runs-{cfg.TRAIN.DATASET}. +_C.TENSORBOARD.LOG_DIR = "" +# Path to a json file providing class_name - id mapping +# in the format {"class_name1": id1, "class_name2": id2, ...}. +# This file must be provided to enable plotting confusion matrix +# by a subset or parent categories. +_C.TENSORBOARD.CLASS_NAMES_PATH = "" + +# Path to a json file for categories -> classes mapping +# in the format {"parent_class": ["child_class1", "child_class2",...], ...}. +_C.TENSORBOARD.CATEGORIES_PATH = "" + +# Config for confusion matrices visualization. +_C.TENSORBOARD.CONFUSION_MATRIX = CfgNode() +# Visualize confusion matrix. +_C.TENSORBOARD.CONFUSION_MATRIX.ENABLE = False +# Figure size of the confusion matrices plotted. +_C.TENSORBOARD.CONFUSION_MATRIX.FIGSIZE = [8, 8] +# Path to a subset of categories to visualize. +# File contains class names separated by newline characters. +_C.TENSORBOARD.CONFUSION_MATRIX.SUBSET_PATH = "" + +# Config for histogram visualization. +_C.TENSORBOARD.HISTOGRAM = CfgNode() +# Visualize histograms. +_C.TENSORBOARD.HISTOGRAM.ENABLE = False +# Path to a subset of classes to plot histograms. +# Class names must be separated by newline characters. +_C.TENSORBOARD.HISTOGRAM.SUBSET_PATH = "" +# Visualize top-k most predicted classes on histograms for each +# chosen true label. +_C.TENSORBOARD.HISTOGRAM.TOPK = 10 +# Figure size of the histograms plotted. +_C.TENSORBOARD.HISTOGRAM.FIGSIZE = [8, 8] + +# Config for layers' weights and activations visualization. +# _C.TENSORBOARD.ENABLE must be True. +_C.TENSORBOARD.MODEL_VIS = CfgNode() + +# If False, skip model visualization. +_C.TENSORBOARD.MODEL_VIS.ENABLE = False + +# If False, skip visualizing model weights. +_C.TENSORBOARD.MODEL_VIS.MODEL_WEIGHTS = False + +# If False, skip visualizing model activations. +_C.TENSORBOARD.MODEL_VIS.ACTIVATIONS = False + +# If False, skip visualizing input videos. +_C.TENSORBOARD.MODEL_VIS.INPUT_VIDEO = False + + +# List of strings containing data about layer names and their indexing to +# visualize weights and activations for. The indexing is meant for +# choosing a subset of activations outputed by a layer for visualization. +# If indexing is not specified, visualize all activations outputed by the layer. +# For each string, layer name and indexing is separated by whitespaces. +# e.g.: [layer1 1,2;1,2, layer2, layer3 150,151;3,4]; this means for each array `arr` +# along the batch dimension in `layer1`, we take arr[[1, 2], [1, 2]] +_C.TENSORBOARD.MODEL_VIS.LAYER_LIST = [] +# Top-k predictions to plot on videos +_C.TENSORBOARD.MODEL_VIS.TOPK_PREDS = 1 +# Colormap to for text boxes and bounding boxes colors +_C.TENSORBOARD.MODEL_VIS.COLORMAP = "Pastel2" +# Config for visualization video inputs with Grad-CAM. +# _C.TENSORBOARD.ENABLE must be True. +_C.TENSORBOARD.MODEL_VIS.GRAD_CAM = CfgNode() +# Whether to run visualization using Grad-CAM technique. +_C.TENSORBOARD.MODEL_VIS.GRAD_CAM.ENABLE = True +# CNN layers to use for Grad-CAM. The number of layers must be equal to +# number of pathway(s). +_C.TENSORBOARD.MODEL_VIS.GRAD_CAM.LAYER_LIST = [] +# If True, visualize Grad-CAM using true labels for each instances. +# If False, use the highest predicted class. +_C.TENSORBOARD.MODEL_VIS.GRAD_CAM.USE_TRUE_LABEL = False +# Colormap to for text boxes and bounding boxes colors +_C.TENSORBOARD.MODEL_VIS.GRAD_CAM.COLORMAP = "viridis" + +# Config for visualization for wrong prediction visualization. +# _C.TENSORBOARD.ENABLE must be True. +_C.TENSORBOARD.WRONG_PRED_VIS = CfgNode() +_C.TENSORBOARD.WRONG_PRED_VIS.ENABLE = False +# Folder tag to origanize model eval videos under. +_C.TENSORBOARD.WRONG_PRED_VIS.TAG = "Incorrectly classified videos." +# Subset of labels to visualize. Only wrong predictions with true labels +# within this subset is visualized. +_C.TENSORBOARD.WRONG_PRED_VIS.SUBSET_PATH = "" + + + +############### +_C.JITTER = CfgNode() + +_C.JITTER.ENABLE = False + +_C.JITTER.CONTINUS_METHODS=["blend_diff_person","blend_downsampled","blend_same_person"] +_C.JITTER.DISCONTINUS_METHODS=["light", "rotate", "skip"] + +_C.JITTER.STRONG_INNER_CLIP_MASK_JITTER= False + +# ---------------------------------------------------------------------------- # +# Demo options +# ---------------------------------------------------------------------------- # +_C.DEMO = CfgNode() + +# Run model in DEMO mode. +_C.DEMO.ENABLE = False + +# Path to a json file providing class_name - id mapping +# in the format {"class_name1": id1, "class_name2": id2, ...}. +_C.DEMO.LABEL_FILE_PATH = "" + +# Specify a camera device as input. This will be prioritized +# over input video if set. +# If -1, use input video instead. +_C.DEMO.WEBCAM = -1 + +# Path to input video for demo. +_C.DEMO.INPUT_VIDEO = "" +# Custom width for reading input video data. +_C.DEMO.DISPLAY_WIDTH = 0 +# Custom height for reading input video data. +_C.DEMO.DISPLAY_HEIGHT = 0 +# Path to Detectron2 object detection model configuration, +# only used for detection tasks. +_C.DEMO.DETECTRON2_CFG = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml" +# Path to Detectron2 object detection model pre-trained weights. +_C.DEMO.DETECTRON2_WEIGHTS = "detectron2://COCO-Detection/faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl" +# Threshold for choosing predicted bounding boxes by Detectron2. +_C.DEMO.DETECTRON2_THRESH = 0.9 +# Number of overlapping frames between 2 consecutive clips. +# Increase this number for more frequent action predictions. +# The number of overlapping frames cannot be larger than +# half of the sequence length `cfg.DATA.NUM_FRAMES * cfg.DATA.SAMPLING_RATE` +_C.DEMO.BUFFER_SIZE = 0 +# If specified, the visualized outputs will be written this a video file of +# this path. Otherwise, the visualized outputs will be displayed in a window. +_C.DEMO.OUTPUT_FILE = "" +# Frames per second rate for writing to output video file. +# If not set (-1), use fps rate from input file. +_C.DEMO.OUTPUT_FPS = -1 +# Input format from demo video reader ("RGB" or "BGR"). +_C.DEMO.INPUT_FORMAT = "BGR" +# Draw visualization frames in [keyframe_idx - CLIP_VIS_SIZE, keyframe_idx + CLIP_VIS_SIZE] inclusively. +_C.DEMO.CLIP_VIS_SIZE = 10 +# Number of processes to run video visualizer. +_C.DEMO.NUM_VIS_INSTANCES = 2 + +# Path to pre-computed predicted boxes +_C.DEMO.PREDS_BOXES = "" +# Whether to run in with multi-threaded video reader. +_C.DEMO.THREAD_ENABLE = False +# Take one clip for every `DEMO.NUM_CLIPS_SKIP` + 1 for prediction and visualization. +# This is used for fast demo speed by reducing the prediction/visualiztion frequency. +# If -1, take the most recent read clip for visualization. This mode is only supported +# if `DEMO.THREAD_ENABLE` is set to True. +_C.DEMO.NUM_CLIPS_SKIP = 0 +# Path to ground-truth boxes and labels (optional) +_C.DEMO.GT_BOXES = "" +# The starting second of the video w.r.t bounding boxes file. +_C.DEMO.STARTING_SECOND = 900 +# Frames per second of the input video/folder of images. +_C.DEMO.FPS = 30 +# Visualize with top-k predictions or predictions above certain threshold(s). +# Option: {"thres", "top-k"} +_C.DEMO.VIS_MODE = "thres" +# Threshold for common class names. +_C.DEMO.COMMON_CLASS_THRES = 0.7 +# Theshold for uncommon class names. This will not be +# used if `_C.DEMO.COMMON_CLASS_NAMES` is empty. +_C.DEMO.UNCOMMON_CLASS_THRES = 0.3 +# This is chosen based on distribution of examples in +# each classes in AVA dataset. +_C.DEMO.COMMON_CLASS_NAMES = [ + "watch (a person)", + "talk to (e.g., self, a person, a group)", + "listen to (a person)", + "touch (an object)", + "carry/hold (an object)", + "walk", + "sit", + "lie/sleep", + "bend/bow (at the waist)", +] +# Slow-motion rate for the visualization. The visualized portions of the +# video will be played `_C.DEMO.SLOWMO` times slower than usual speed. +_C.DEMO.SLOWMO = 1 + +# Add custom config with default values. +custom_config.add_custom_config(_C) + + +def _assert_and_infer_cfg(cfg): + # BN assertions. + if cfg.BN.USE_PRECISE_STATS: + assert cfg.BN.NUM_BATCHES_PRECISE >= 0 + # TRAIN assertions. + assert cfg.TRAIN.CHECKPOINT_TYPE in ["pytorch", "caffe2"] + assert cfg.TRAIN.BATCH_SIZE % cfg.NUM_GPUS == 0 + + # TEST assertions. + assert cfg.TEST.CHECKPOINT_TYPE in ["pytorch", "caffe2"] + assert cfg.TEST.BATCH_SIZE % cfg.NUM_GPUS == 0 + assert cfg.TEST.NUM_SPATIAL_CROPS == 3 + + # RESNET assertions. + assert cfg.RESNET.NUM_GROUPS > 0 + assert cfg.RESNET.WIDTH_PER_GROUP > 0 + assert cfg.RESNET.WIDTH_PER_GROUP % cfg.RESNET.NUM_GROUPS == 0 + + # General assertions. + assert cfg.SHARD_ID < cfg.NUM_SHARDS + return cfg + + +def get_cfg(): + """ + Get a copy of the default config. + """ + return _assert_and_infer_cfg(_C.clone()) diff --git a/training/detectors/utils/slowfast/models/__init__.py b/training/detectors/utils/slowfast/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f82b3dc1e3af0dbabf4a3a6153e48b977eb1059e --- /dev/null +++ b/training/detectors/utils/slowfast/models/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +from .build import MODEL_REGISTRY, build_model # noqa +from .custom_video_model_builder import * # noqa +from .video_model_builder import ResNet, SlowFast # noqa \ No newline at end of file diff --git a/training/detectors/utils/slowfast/models/batchnorm_helper.py b/training/detectors/utils/slowfast/models/batchnorm_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..4e52d50497d9c0a58e5ace0a2fde94b5418ef563 --- /dev/null +++ b/training/detectors/utils/slowfast/models/batchnorm_helper.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""BatchNorm (BN) utility functions and custom batch-size BN implementations""" + +from functools import partial +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.autograd.function import Function + +import slowfast.utils.distributed as du + + +def get_norm(cfg): + """ + Args: + cfg (CfgNode): model building configs, details are in the comments of + the config file. + Returns: + nn.Module: the normalization layer. + """ + if cfg.BN.NORM_TYPE == "batchnorm": + return nn.BatchNorm3d + elif cfg.BN.NORM_TYPE == "sub_batchnorm": + return partial(SubBatchNorm3d, num_splits=cfg.BN.NUM_SPLITS) + elif cfg.BN.NORM_TYPE == "sync_batchnorm": + return partial( + NaiveSyncBatchNorm3d, num_sync_devices=cfg.BN.NUM_SYNC_DEVICES + ) + else: + raise NotImplementedError( + "Norm type {} is not supported".format(cfg.BN.NORM_TYPE) + ) + + +class SubBatchNorm3d(nn.Module): + """ + The standard BN layer computes stats across all examples in a GPU. In some + cases it is desirable to compute stats across only a subset of examples + (e.g., in multigrid training https://arxiv.org/abs/1912.00998). + SubBatchNorm3d splits the batch dimension into N splits, and run BN on + each of them separately (so that the stats are computed on each subset of + examples (1/N of batch) independently. During evaluation, it aggregates + the stats from all splits into one BN. + """ + + def __init__(self, num_splits, **args): + """ + Args: + num_splits (int): number of splits. + args (list): other arguments. + """ + super(SubBatchNorm3d, self).__init__() + self.num_splits = num_splits + num_features = args["num_features"] + # Keep only one set of weight and bias. + if args.get("affine", True): + self.affine = True + args["affine"] = False + self.weight = torch.nn.Parameter(torch.ones(num_features)) + self.bias = torch.nn.Parameter(torch.zeros(num_features)) + else: + self.affine = False + self.bn = nn.BatchNorm3d(**args) + args["num_features"] = num_features * num_splits + self.split_bn = nn.BatchNorm3d(**args) + + def _get_aggregated_mean_std(self, means, stds, n): + """ + Calculate the aggregated mean and stds. + Args: + means (tensor): mean values. + stds (tensor): standard deviations. + n (int): number of sets of means and stds. + """ + mean = means.view(n, -1).sum(0) / n + std = ( + stds.view(n, -1).sum(0) / n + + ((means.view(n, -1) - mean) ** 2).view(n, -1).sum(0) / n + ) + return mean.detach(), std.detach() + + def aggregate_stats(self): + """ + Synchronize running_mean, and running_var. Call this before eval. + """ + if self.split_bn.track_running_stats: + ( + self.bn.running_mean.data, + self.bn.running_var.data, + ) = self._get_aggregated_mean_std( + self.split_bn.running_mean, + self.split_bn.running_var, + self.num_splits, + ) + + def forward(self, x): + if self.training: + n, c, t, h, w = x.shape + x = x.view(n // self.num_splits, c * self.num_splits, t, h, w) + x = self.split_bn(x) + x = x.view(n, c, t, h, w) + else: + x = self.bn(x) + if self.affine: + x = x * self.weight.view((-1, 1, 1, 1)) + x = x + self.bias.view((-1, 1, 1, 1)) + return x + + +class GroupGather(Function): + """ + GroupGather performs all gather on each of the local process/ GPU groups. + """ + + @staticmethod + def forward(ctx, input, num_sync_devices, num_groups): + """ + Perform forwarding, gathering the stats across different process/ GPU + group. + """ + ctx.num_sync_devices = num_sync_devices + ctx.num_groups = num_groups + + input_list = [ + torch.zeros_like(input) for k in range(du.get_local_size()) + ] + dist.all_gather( + input_list, input, async_op=False, group=du._LOCAL_PROCESS_GROUP + ) + + inputs = torch.stack(input_list, dim=0) + if num_groups > 1: + rank = du.get_local_rank() + group_idx = rank // num_sync_devices + inputs = inputs[ + group_idx + * num_sync_devices : (group_idx + 1) + * num_sync_devices + ] + inputs = torch.sum(inputs, dim=0) + return inputs + + @staticmethod + def backward(ctx, grad_output): + """ + Perform backwarding, gathering the gradients across different process/ GPU + group. + """ + grad_output_list = [ + torch.zeros_like(grad_output) for k in range(du.get_local_size()) + ] + dist.all_gather( + grad_output_list, + grad_output, + async_op=False, + group=du._LOCAL_PROCESS_GROUP, + ) + + grads = torch.stack(grad_output_list, dim=0) + if ctx.num_groups > 1: + rank = du.get_local_rank() + group_idx = rank // ctx.num_sync_devices + grads = grads[ + group_idx + * ctx.num_sync_devices : (group_idx + 1) + * ctx.num_sync_devices + ] + grads = torch.sum(grads, dim=0) + return grads, None, None + + +class NaiveSyncBatchNorm3d(nn.BatchNorm3d): + def __init__(self, num_sync_devices, **args): + """ + Naive version of Synchronized 3D BatchNorm. + Args: + num_sync_devices (int): number of device to sync. + args (list): other arguments. + """ + self.num_sync_devices = num_sync_devices + if self.num_sync_devices > 0: + assert du.get_local_size() % self.num_sync_devices == 0, ( + du.get_local_size(), + self.num_sync_devices, + ) + self.num_groups = du.get_local_size() // self.num_sync_devices + else: + self.num_sync_devices = du.get_local_size() + self.num_groups = 1 + super(NaiveSyncBatchNorm3d, self).__init__(**args) + + def forward(self, input): + if du.get_local_size() == 1 or not self.training: + return super().forward(input) + + assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs" + C = input.shape[1] + mean = torch.mean(input, dim=[0, 2, 3, 4]) + meansqr = torch.mean(input * input, dim=[0, 2, 3, 4]) + + vec = torch.cat([mean, meansqr], dim=0) + vec = GroupGather.apply(vec, self.num_sync_devices, self.num_groups) * ( + 1.0 / self.num_sync_devices + ) + + mean, meansqr = torch.split(vec, C) + var = meansqr - mean * mean + self.running_mean += self.momentum * (mean.detach() - self.running_mean) + self.running_var += self.momentum * (var.detach() - self.running_var) + + invstd = torch.rsqrt(var + self.eps) + scale = self.weight * invstd + bias = self.bias - mean * scale + scale = scale.reshape(1, -1, 1, 1, 1) + bias = bias.reshape(1, -1, 1, 1, 1) + return input * scale + bias diff --git a/training/detectors/utils/slowfast/models/build.py b/training/detectors/utils/slowfast/models/build.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd9cca224b3e77bb8c3cd899489358737d4cd35 --- /dev/null +++ b/training/detectors/utils/slowfast/models/build.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Model construction functions.""" + +import torch +from fvcore.common.registry import Registry + +MODEL_REGISTRY = Registry("MODEL") +MODEL_REGISTRY.__doc__ = """ +Registry for video model. + +The registered object will be called with `obj(cfg)`. +The call should return a `torch.nn.Module` object. +""" + + +def build_model(cfg, gpu_id=None): + """ + Builds the video model. + Args: + cfg (configs): configs that contains the hyper-parameters to build the + backbone. Details can be seen in slowfast/config/defaults.py. + gpu_id (Optional[int]): specify the gpu index to build model. + """ + if torch.cuda.is_available(): + assert ( + cfg.NUM_GPUS <= torch.cuda.device_count() + ), "Cannot use more GPU devices than available" + else: + assert ( + cfg.NUM_GPUS == 0 + ), "Cuda is not available. Please set `NUM_GPUS: 0 for running on CPUs." + + # Construct the model + name = cfg.MODEL.MODEL_NAME + model = MODEL_REGISTRY.get(name)(cfg) + + if cfg.NUM_GPUS: + if gpu_id is None: + # Determine the GPU used by the current process + cur_device = torch.cuda.current_device() + else: + cur_device = gpu_id + # Transfer the model to the current GPU device + model = model.cuda(device=cur_device) + # Use multi-process data parallel model in the multi-gpu setting + if cfg.NUM_GPUS > 1: + # Make model replica operate on the current device + model = torch.nn.parallel.DistributedDataParallel( + module=model, device_ids=[cur_device], output_device=cur_device,find_unused_parameters=True + ) + return model diff --git a/training/detectors/utils/slowfast/models/custom_video_model_builder.py b/training/detectors/utils/slowfast/models/custom_video_model_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..f261f67b95616b8582b10998a290611ee108b2a9 --- /dev/null +++ b/training/detectors/utils/slowfast/models/custom_video_model_builder.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + + +"""A More Flexible Video models.""" diff --git a/training/detectors/utils/slowfast/models/head_helper.py b/training/detectors/utils/slowfast/models/head_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..df04b010430b6000005676d52174243383873d05 --- /dev/null +++ b/training/detectors/utils/slowfast/models/head_helper.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""ResNe(X)t Head helper.""" + +import torch +import torch.nn as nn + +class ResNetBasicHead(nn.Module): + """ + ResNe(X)t 3D head. + This layer performs a fully-connected projection during training, when the + input size is 1x1x1. It performs a convolutional projection during testing + when the input size is larger than 1x1x1. If the inputs are from multiple + different pathways, the inputs will be concatenated after pooling. + """ + + def __init__( + self, + dim_in, + num_classes, + pool_size, + dropout_rate=0.0, + act_func="softmax", + ): + """ + The `__init__` method of any subclass should also contain these + arguments. + ResNetBasicHead takes p pathways as input where p in [1, infty]. + + Args: + dim_in (list): the list of channel dimensions of the p inputs to the + ResNetHead. + num_classes (int): the channel dimensions of the p outputs to the + ResNetHead. + pool_size (list): the list of kernel sizes of p spatial temporal + poolings, temporal pool kernel size, spatial pool kernel size, + spatial pool kernel size in order. + dropout_rate (float): dropout rate. If equal to 0.0, perform no + dropout. + act_func (string): activation function to use. 'softmax': applies + softmax on the output. 'sigmoid': applies sigmoid on the output. + """ + super(ResNetBasicHead, self).__init__() + assert ( + len({len(pool_size), len(dim_in)}) == 1 + ), "pathway dimensions are not consistent." + self.num_pathways = len(pool_size) + + for pathway in range(self.num_pathways): + if pool_size[pathway] is None: + avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) + else: + avg_pool = nn.AvgPool3d(pool_size[pathway], stride=1) + self.add_module("pathway{}_avgpool".format(pathway), avg_pool) + + if dropout_rate > 0.0: + self.dropout = nn.Dropout(dropout_rate) + # Perform FC in a fully convolutional manner. The FC layer will be + # initialized with a different std comparing to convolutional layers. + self.projection = nn.Linear(sum(dim_in), num_classes, bias=True) + + # Softmax for evaluation and testing. + if act_func == "softmax": + self.act = nn.Softmax(dim=4) + elif act_func == "sigmoid": + self.act = nn.Sigmoid() + else: + raise NotImplementedError( + "{} is not supported as an activation" + "function.".format(act_func) + ) + + def forward(self, inputs): + assert ( + len(inputs) == self.num_pathways + ), "Input tensor does not contain {} pathway".format(self.num_pathways) + pool_out = [] + for pathway in range(self.num_pathways): + m = getattr(self, "pathway{}_avgpool".format(pathway)) + pool_out.append(m(inputs[pathway])) + x = torch.cat(pool_out, 1) + # (N, C, T, H, W) -> (N, T, H, W, C). + x = x.permute((0, 2, 3, 4, 1)) + # Perform dropout. + if hasattr(self, "dropout"): + x = self.dropout(x) + x = self.projection(x) + + # Performs fully convlutional inference. + # if not self.training: + # x = x.mean([1, 2, 3]) + x = self.act(x) + x = x.view(x.shape[0], -1) + return x diff --git a/training/detectors/utils/slowfast/models/losses.py b/training/detectors/utils/slowfast/models/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..7dda4eb19b2cf76275ba1778dc5f8730058a6c31 --- /dev/null +++ b/training/detectors/utils/slowfast/models/losses.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Loss functions.""" + +import torch.nn as nn + +_LOSSES = { + "cross_entropy": nn.CrossEntropyLoss, + "bce": nn.BCELoss, + "bce_logit": nn.BCEWithLogitsLoss, +} + + +def get_loss_func(loss_name): + """ + Retrieve the loss given the loss name. + Args (int): + loss_name: the name of the loss to use. + """ + if loss_name not in _LOSSES.keys(): + raise NotImplementedError("Loss {} is not supported".format(loss_name)) + return _LOSSES[loss_name] diff --git a/training/detectors/utils/slowfast/models/nonlocal_helper.py b/training/detectors/utils/slowfast/models/nonlocal_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..6e68d05817256a66d0b6ecf0f96292446cf41270 --- /dev/null +++ b/training/detectors/utils/slowfast/models/nonlocal_helper.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Non-local helper""" + +import torch +import torch.nn as nn + + +class Nonlocal(nn.Module): + """ + Builds Non-local Neural Networks as a generic family of building + blocks for capturing long-range dependencies. Non-local Network + computes the response at a position as a weighted sum of the + features at all positions. This building block can be plugged into + many computer vision architectures. + More details in the paper: https://arxiv.org/pdf/1711.07971.pdf + """ + + def __init__( + self, + dim, + dim_inner, + pool_size=None, + instantiation="softmax", + zero_init_final_conv=False, + zero_init_final_norm=True, + norm_eps=1e-5, + norm_momentum=0.1, + norm_module=nn.BatchNorm3d, + ): + """ + Args: + dim (int): number of dimension for the input. + dim_inner (int): number of dimension inside of the Non-local block. + pool_size (list): the kernel size of spatial temporal pooling, + temporal pool kernel size, spatial pool kernel size, spatial + pool kernel size in order. By default pool_size is None, + then there would be no pooling used. + instantiation (string): supports two different instantiation method: + "dot_product": normalizing correlation matrix with L2. + "softmax": normalizing correlation matrix with Softmax. + zero_init_final_conv (bool): If true, zero initializing the final + convolution of the Non-local block. + zero_init_final_norm (bool): + If true, zero initializing the final batch norm of the Non-local + block. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. + """ + super(Nonlocal, self).__init__() + self.dim = dim + self.dim_inner = dim_inner + self.pool_size = pool_size + self.instantiation = instantiation + self.use_pool = ( + False + if pool_size is None + else any((size > 1 for size in pool_size)) + ) + self.norm_eps = norm_eps + self.norm_momentum = norm_momentum + self._construct_nonlocal( + zero_init_final_conv, zero_init_final_norm, norm_module + ) + + def _construct_nonlocal( + self, zero_init_final_conv, zero_init_final_norm, norm_module + ): + # Three convolution heads: theta, phi, and g. + self.conv_theta = nn.Conv3d( + self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 + ) + self.conv_phi = nn.Conv3d( + self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 + ) + self.conv_g = nn.Conv3d( + self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 + ) + + # Final convolution output. + self.conv_out = nn.Conv3d( + self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0 + ) + # Zero initializing the final convolution output. + self.conv_out.zero_init = zero_init_final_conv + + # TODO: change the name to `norm` + self.bn = norm_module( + num_features=self.dim, + eps=self.norm_eps, + momentum=self.norm_momentum, + ) + # Zero initializing the final bn. + self.bn.transform_final_bn = zero_init_final_norm + + # Optional to add the spatial-temporal pooling. + if self.use_pool: + self.pool = nn.MaxPool3d( + kernel_size=self.pool_size, + stride=self.pool_size, + padding=[0, 0, 0], + ) + + def forward(self, x): + x_identity = x + N, C, T, H, W = x.size() + + theta = self.conv_theta(x) + + # Perform temporal-spatial pooling to reduce the computation. + if self.use_pool: + x = self.pool(x) + + phi = self.conv_phi(x) + g = self.conv_g(x) + + theta = theta.view(N, self.dim_inner, -1) + phi = phi.view(N, self.dim_inner, -1) + g = g.view(N, self.dim_inner, -1) + + # (N, C, TxHxW) * (N, C, TxHxW) => (N, TxHxW, TxHxW). + theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi)) + # For original Non-local paper, there are two main ways to normalize + # the affinity tensor: + # 1) Softmax normalization (norm on exp). + # 2) dot_product normalization. + if self.instantiation == "softmax": + # Normalizing the affinity tensor theta_phi before softmax. + theta_phi = theta_phi * (self.dim_inner ** -0.5) + theta_phi = nn.functional.softmax(theta_phi, dim=2) + elif self.instantiation == "dot_product": + spatial_temporal_dim = theta_phi.shape[2] + theta_phi = theta_phi / spatial_temporal_dim + else: + raise NotImplementedError( + "Unknown norm type {}".format(self.instantiation) + ) + + # (N, TxHxW, TxHxW) * (N, C, TxHxW) => (N, C, TxHxW). + theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g)) + + # (N, C, TxHxW) => (N, C, T, H, W). + theta_phi_g = theta_phi_g.view(N, self.dim_inner, T, H, W) + + p = self.conv_out(theta_phi_g) + p = self.bn(p) + return x_identity + p diff --git a/training/detectors/utils/slowfast/models/optimizer.py b/training/detectors/utils/slowfast/models/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..130f2cebf994741bc45a6519f07c5f0740c106ac --- /dev/null +++ b/training/detectors/utils/slowfast/models/optimizer.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Optimizer.""" + +import torch + +import slowfast.utils.lr_policy as lr_policy + + +def construct_optimizer(model, cfg): + """ + Construct a stochastic gradient descent or ADAM optimizer with momentum. + Details can be found in: + Herbert Robbins, and Sutton Monro. "A stochastic approximation method." + and + Diederik P.Kingma, and Jimmy Ba. + "Adam: A Method for Stochastic Optimization." + + Args: + model (model): model to perform stochastic gradient descent + optimization or ADAM optimization. + cfg (config): configs of hyper-parameters of SGD or ADAM, includes base + learning rate, momentum, weight_decay, dampening, and etc. + """ + # Batchnorm parameters. + bn_params = [] + # Non-batchnorm parameters. + non_bn_parameters = [] + for name, p in model.named_parameters(): + if "bn" in name: + bn_params.append(p) + else: + non_bn_parameters.append(p) + # Apply different weight decay to Batchnorm and non-batchnorm parameters. + # In Caffe2 classification codebase the weight decay for batchnorm is 0.0. + # Having a different weight decay on batchnorm might cause a performance + # drop. + optim_params = [ + {"params": bn_params, "weight_decay": cfg.BN.WEIGHT_DECAY}, + {"params": non_bn_parameters, "weight_decay": cfg.SOLVER.WEIGHT_DECAY}, + ] + # Check all parameters will be passed into optimizer. + assert len(list(model.parameters())) == len(non_bn_parameters) + len( + bn_params + ), "parameter size does not match: {} + {} != {}".format( + len(non_bn_parameters), len(bn_params), len(list(model.parameters())) + ) + + if cfg.SOLVER.OPTIMIZING_METHOD == "sgd": + return torch.optim.SGD( + optim_params, + lr=cfg.SOLVER.BASE_LR, + momentum=cfg.SOLVER.MOMENTUM, + weight_decay=cfg.SOLVER.WEIGHT_DECAY, + dampening=cfg.SOLVER.DAMPENING, + nesterov=cfg.SOLVER.NESTEROV, + ) + elif cfg.SOLVER.OPTIMIZING_METHOD == "adam": + return torch.optim.Adam( + optim_params, + lr=cfg.SOLVER.BASE_LR, + betas=(0.9, 0.999), + weight_decay=cfg.SOLVER.WEIGHT_DECAY, + ) + else: + raise NotImplementedError( + "Does not support {} optimizer".format(cfg.SOLVER.OPTIMIZING_METHOD) + ) + + +def get_epoch_lr(cur_epoch, cfg): + """ + Retrieves the lr for the given epoch (as specified by the lr policy). + Args: + cfg (config): configs of hyper-parameters of ADAM, includes base + learning rate, betas, and weight decays. + cur_epoch (float): the number of epoch of the current training stage. + """ + return lr_policy.get_lr_at_epoch(cfg, cur_epoch) + +def get_iter_lr(cur_iter, cfg): + """ + Retrieves the lr for the given iter (as specified by the lr policy). + Args: + cfg (config): configs of hyper-parameters of ADAM, includes base + learning rate, betas, and weight decays. + cur_epoch (float): the number of epoch of the current training stage. + """ + lr=lr_policy.get_lr_at_iter(cfg, cur_iter) + + return lr + + +def set_lr(optimizer, new_lr): + """ + Sets the optimizer lr to the specified value. + Args: + optimizer (optim): the optimizer using to optimize the current network. + new_lr (float): the new learning rate to set. + """ + for param_group in optimizer.param_groups: + param_group["lr"] = new_lr diff --git a/training/detectors/utils/slowfast/models/resnet_helper.py b/training/detectors/utils/slowfast/models/resnet_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..0005409707586ceaac71a50fdf07fadd6c9ffb6b --- /dev/null +++ b/training/detectors/utils/slowfast/models/resnet_helper.py @@ -0,0 +1,647 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Video models.""" + +import torch.nn as nn + +from slowfast.models.nonlocal_helper import Nonlocal + + +def get_trans_func(name): + """ + Retrieves the transformation module by name. + """ + trans_funcs = { + "bottleneck_transform": BottleneckTransform, + "basic_transform": BasicTransform, + "temporal_transform":TemporalTransform + } + assert ( + name in trans_funcs.keys() + ), "Transformation function '{}' not supported".format(name) + return trans_funcs[name] + + +class BasicTransform(nn.Module): + """ + Basic transformation: Tx3x3, 1x3x3, where T is the size of temporal kernel. + """ + + def __init__( + self, + dim_in, + dim_out, + temp_kernel_size, + stride, + dim_inner=None, + num_groups=1, + stride_1x1=None, + inplace_relu=True, + eps=1e-5, + bn_mmt=0.1, + norm_module=nn.BatchNorm3d, + ): + """ + Args: + dim_in (int): the channel dimensions of the input. + dim_out (int): the channel dimension of the output. + temp_kernel_size (int): the temporal kernel sizes of the first + convolution in the basic block. + stride (int): the stride of the bottleneck. + dim_inner (None): the inner dimension would not be used in + BasicTransform. + num_groups (int): number of groups for the convolution. Number of + group is always 1 for BasicTransform. + stride_1x1 (None): stride_1x1 will not be used in BasicTransform. + inplace_relu (bool): if True, calculate the relu on the original + input without allocating new memory. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. + """ + super(BasicTransform, self).__init__() + self.temp_kernel_size = temp_kernel_size + self._inplace_relu = inplace_relu + self._eps = eps + self._bn_mmt = bn_mmt + self._construct(dim_in, dim_out, stride, norm_module) + + def _construct(self, dim_in, dim_out, stride, norm_module): + # Tx3x3, BN, ReLU. + self.a = nn.Conv3d( + dim_in, + dim_out, + kernel_size=[self.temp_kernel_size, 3, 3], + stride=[1, stride, stride], + padding=[int(self.temp_kernel_size // 2), 1, 1], + bias=False, + ) + self.a_bn = norm_module( + num_features=dim_out, eps=self._eps, momentum=self._bn_mmt + ) + self.a_relu = nn.ReLU(inplace=self._inplace_relu) + # 1x3x3, BN. + self.b = nn.Conv3d( + dim_out, + dim_out, + kernel_size=[1, 3, 3], + stride=[1, 1, 1], + padding=[0, 1, 1], + bias=False, + ) + self.b_bn = norm_module( + num_features=dim_out, eps=self._eps, momentum=self._bn_mmt + ) + + self.b_bn.transform_final_bn = True + + def forward(self, x): + x = self.a(x) + x = self.a_bn(x) + x = self.a_relu(x) + + x = self.b(x) + x = self.b_bn(x) + return x + +class TemporalTransform(nn.Module): + """ + Basic transformation: Tx3x3, 1x3x3, where T is the size of temporal kernel. + """ + + def __init__( + self, + dim_in, + dim_out, + temp_kernel_size, + stride, + dim_inner=None, + num_groups=1, + stride_1x1=None, + inplace_relu=True, + eps=1e-5, + bn_mmt=0.1, + norm_module=nn.BatchNorm3d, + dilation=1 + ): + """ + Args: + dim_in (int): the channel dimensions of the input. + dim_out (int): the channel dimension of the output. + temp_kernel_size (int): the temporal kernel sizes of the first + convolution in the basic block. + stride (int): the stride of the bottleneck. + dim_inner (None): the inner dimension would not be used in + BasicTransform. + num_groups (int): number of groups for the convolution. Number of + group is always 1 for BasicTransform. + stride_1x1 (None): stride_1x1 will not be used in BasicTransform. + inplace_relu (bool): if True, calculate the relu on the original + input without allocating new memory. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. + """ + super(TemporalTransform, self).__init__() + self.temp_kernel_size = temp_kernel_size + self._inplace_relu = inplace_relu + self._eps = eps + self._bn_mmt = bn_mmt + self._construct(dim_in, dim_out, stride, norm_module) + + def _construct(self, dim_in, dim_out, stride, norm_module): + # Tx3x3, BN, ReLU. + self.a = nn.Conv3d( + dim_in, + dim_out, + kernel_size=[self.temp_kernel_size, 3, 3], + stride=[1, stride, stride], + padding=[int(self.temp_kernel_size // 2), 1, 1], + bias=False, + ) + self.a_bn = norm_module( + num_features=dim_out, eps=self._eps, momentum=self._bn_mmt + ) + self.a_relu = nn.ReLU(inplace=self._inplace_relu) + # 1x3x3, BN. + self.b = nn.Conv3d( + dim_out, + dim_out, + kernel_size=[1, 3, 3], + stride=[1, 1, 1], + padding=[0, 1, 1], + bias=False, + ) + self.b_bn = norm_module( + num_features=dim_out, eps=self._eps, momentum=self._bn_mmt + ) + + self.b_bn.transform_final_bn = True + + def forward(self, x): + x = self.a(x) + x = self.a_bn(x) + x = self.a_relu(x) + + x = self.b(x) + x = self.b_bn(x) + return x + + +class BottleneckTransform(nn.Module): + """ + Bottleneck transformation: Tx1x1, 1x3x3, 1x1x1, where T is the size of + temporal kernel. + """ + + def __init__( + self, + dim_in, + dim_out, + temp_kernel_size, + stride, + dim_inner, + num_groups, + stride_1x1=False, + inplace_relu=True, + eps=1e-5, + bn_mmt=0.1, + dilation=1, + norm_module=nn.BatchNorm3d, + ): + """ + Args: + dim_in (int): the channel dimensions of the input. + dim_out (int): the channel dimension of the output. + temp_kernel_size (int): the temporal kernel sizes of the first + convolution in the bottleneck. + stride (int): the stride of the bottleneck. + dim_inner (int): the inner dimension of the block. + num_groups (int): number of groups for the convolution. num_groups=1 + is for standard ResNet like networks, and num_groups>1 is for + ResNeXt like networks. + stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise + apply stride to the 3x3 conv. + inplace_relu (bool): if True, calculate the relu on the original + input without allocating new memory. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + dilation (int): size of dilation. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. + """ + super(BottleneckTransform, self).__init__() + self.temp_kernel_size = temp_kernel_size + self._inplace_relu = inplace_relu + self._eps = eps + self._bn_mmt = bn_mmt + self._stride_1x1 = stride_1x1 + self._construct( + dim_in, + dim_out, + stride, + dim_inner, + num_groups, + dilation, + norm_module, + ) + + def _construct( + self, + dim_in, + dim_out, + stride, + dim_inner, + num_groups, + dilation, + norm_module, + ): + (str1x1, str3x3) = (stride, 1) if self._stride_1x1 else (1, stride) + + # Tx1x1, BN, ReLU. + self.a = nn.Conv3d( + dim_in, + dim_inner, + kernel_size=[self.temp_kernel_size, 1, 1], + stride=[1, str1x1, str1x1], + padding=[int(self.temp_kernel_size // 2), 0, 0], + bias=False, + ) + self.a_bn = norm_module( + num_features=dim_inner, eps=self._eps, momentum=self._bn_mmt + ) + self.a_relu = nn.ReLU(inplace=self._inplace_relu) + + # 1x3x3, BN, ReLU. + self.b = nn.Conv3d( + dim_inner, + dim_inner, + [1, 3, 3], + stride=[1, str3x3, str3x3], + padding=[0, dilation, dilation], + groups=num_groups, + bias=False, + dilation=[1, dilation, dilation], + ) + self.b_bn = norm_module( + num_features=dim_inner, eps=self._eps, momentum=self._bn_mmt + ) + self.b_relu = nn.ReLU(inplace=self._inplace_relu) + + # 1x1x1, BN. + self.c = nn.Conv3d( + dim_inner, + dim_out, + kernel_size=[1, 1, 1], + stride=[1, 1, 1], + padding=[0, 0, 0], + bias=False, + ) + self.c_bn = norm_module( + num_features=dim_out, eps=self._eps, momentum=self._bn_mmt + ) + self.c_bn.transform_final_bn = True + + def forward(self, x): + # Explicitly forward every layer. + # Branch2a. + x = self.a(x) + x = self.a_bn(x) + x = self.a_relu(x) + + # Branch2b. + x = self.b(x) + x = self.b_bn(x) + x = self.b_relu(x) + + # Branch2c + x = self.c(x) + x = self.c_bn(x) + return x + + +class ResBlock(nn.Module): + """ + Residual block. + """ + + def __init__( + self, + dim_in, + dim_out, + temp_kernel_size, + stride, + trans_func, + dim_inner, + num_groups=1, + stride_1x1=False, + inplace_relu=True, + eps=1e-5, + bn_mmt=0.1, + dilation=1, + norm_module=nn.BatchNorm3d, + ): + """ + ResBlock class constructs redisual blocks. More details can be found in: + Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. + "Deep residual learning for image recognition." + https://arxiv.org/abs/1512.03385 + Args: + dim_in (int): the channel dimensions of the input. + dim_out (int): the channel dimension of the output. + temp_kernel_size (int): the temporal kernel sizes of the middle + convolution in the bottleneck. + stride (int): the stride of the bottleneck. + trans_func (string): transform function to be used to construct the + bottleneck. + dim_inner (int): the inner dimension of the block. + num_groups (int): number of groups for the convolution. num_groups=1 + is for standard ResNet like networks, and num_groups>1 is for + ResNeXt like networks. + stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise + apply stride to the 3x3 conv. + inplace_relu (bool): calculate the relu on the original input + without allocating new memory. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + dilation (int): size of dilation. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. + """ + super(ResBlock, self).__init__() + self._inplace_relu = inplace_relu + self._eps = eps + self._bn_mmt = bn_mmt + self._construct( + dim_in, + dim_out, + temp_kernel_size, + stride, + trans_func, + dim_inner, + num_groups, + stride_1x1, + inplace_relu, + dilation, + norm_module, + ) + + def _construct( + self, + dim_in, + dim_out, + temp_kernel_size, + stride, + trans_func, + dim_inner, + num_groups, + stride_1x1, + inplace_relu, + dilation, + norm_module, + ): + # Use skip connection with projection if dim or res change. + if (dim_in != dim_out) or (stride != 1): + self.branch1 = nn.Conv3d( + dim_in, + dim_out, + kernel_size=1, + stride=[1, stride, stride], + padding=0, + bias=False, + dilation=1, + ) + self.branch1_bn = norm_module( + num_features=dim_out, eps=self._eps, momentum=self._bn_mmt + ) + self.branch2 = trans_func( + dim_in, + dim_out, + temp_kernel_size, + stride, + dim_inner, + num_groups, + stride_1x1=stride_1x1, + inplace_relu=inplace_relu, + dilation=dilation, + norm_module=norm_module, + ) + self.relu = nn.ReLU(self._inplace_relu) + + def forward(self, x): + if hasattr(self, "branch1"): + x = self.branch1_bn(self.branch1(x)) + self.branch2(x) + else: + x = x + self.branch2(x) + x = self.relu(x) + return x + + +class ResStage(nn.Module): + """ + Stage of 3D ResNet. It expects to have one or more tensors as input for + single pathway (C2D, I3D, Slow), and multi-pathway (SlowFast) cases. + More details can be found here: + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "SlowFast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + """ + + def __init__( + self, + dim_in, + dim_out, + stride, + temp_kernel_sizes, + num_blocks, + dim_inner, + num_groups, + num_block_temp_kernel, + nonlocal_inds, + nonlocal_group, + nonlocal_pool, + dilation, + instantiation="softmax", + trans_func_name="bottleneck_transform", + stride_1x1=False, + inplace_relu=True, + norm_module=nn.BatchNorm3d, + ): + """ + The `__init__` method of any subclass should also contain these arguments. + ResStage builds p streams, where p can be greater or equal to one. + Args: + dim_in (list): list of p the channel dimensions of the input. + Different channel dimensions control the input dimension of + different pathways. + dim_out (list): list of p the channel dimensions of the output. + Different channel dimensions control the input dimension of + different pathways. + temp_kernel_sizes (list): list of the p temporal kernel sizes of the + convolution in the bottleneck. Different temp_kernel_sizes + control different pathway. + stride (list): list of the p strides of the bottleneck. Different + stride control different pathway. + num_blocks (list): list of p numbers of blocks for each of the + pathway. + dim_inner (list): list of the p inner channel dimensions of the + input. Different channel dimensions control the input dimension + of different pathways. + num_groups (list): list of number of p groups for the convolution. + num_groups=1 is for standard ResNet like networks, and + num_groups>1 is for ResNeXt like networks. + num_block_temp_kernel (list): extent the temp_kernel_sizes to + num_block_temp_kernel blocks, then fill temporal kernel size + of 1 for the rest of the layers. + nonlocal_inds (list): If the tuple is empty, no nonlocal layer will + be added. If the tuple is not empty, add nonlocal layers after + the index-th block. + dilation (list): size of dilation for each pathway. + nonlocal_group (list): list of number of p nonlocal groups. Each + number controls how to fold temporal dimension to batch + dimension before applying nonlocal transformation. + https://github.com/facebookresearch/video-nonlocal-net. + instantiation (string): different instantiation for nonlocal layer. + Supports two different instantiation method: + "dot_product": normalizing correlation matrix with L2. + "softmax": normalizing correlation matrix with Softmax. + trans_func_name (string): name of the the transformation function apply + on the network. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. + """ + super(ResStage, self).__init__() + assert all( + ( + num_block_temp_kernel[i] <= num_blocks[i] + for i in range(len(temp_kernel_sizes)) + ) + ) + self.num_blocks = num_blocks + self.nonlocal_group = nonlocal_group + self.temp_kernel_sizes = [ + (temp_kernel_sizes[i] * num_blocks[i])[: num_block_temp_kernel[i]] + + [1] * (num_blocks[i] - num_block_temp_kernel[i]) + for i in range(len(temp_kernel_sizes)) + ] + assert ( + len( + { + len(dim_in), + len(dim_out), + len(temp_kernel_sizes), + len(stride), + len(num_blocks), + len(dim_inner), + len(num_groups), + len(num_block_temp_kernel), + len(nonlocal_inds), + len(nonlocal_group), + } + ) + == 1 + ) + self.num_pathways = len(self.num_blocks) + self._construct( + dim_in, + dim_out, + stride, + dim_inner, + num_groups, + trans_func_name, + stride_1x1, + inplace_relu, + nonlocal_inds, + nonlocal_pool, + instantiation, + dilation, + norm_module, + ) + + def _construct( + self, + dim_in, + dim_out, + stride, + dim_inner, + num_groups, + trans_func_name, + stride_1x1, + inplace_relu, + nonlocal_inds, + nonlocal_pool, + instantiation, + dilation, + norm_module, + ): + for pathway in range(self.num_pathways): + for i in range(self.num_blocks[pathway]): + # Retrieve the transformation function. + trans_func = get_trans_func(trans_func_name) + # Construct the block. + res_block = ResBlock( + dim_in[pathway] if i == 0 else dim_out[pathway], + dim_out[pathway], + self.temp_kernel_sizes[pathway][i], + stride[pathway] if i == 0 else 1, + trans_func, + dim_inner[pathway], + num_groups[pathway], + stride_1x1=stride_1x1, + inplace_relu=inplace_relu, + dilation=dilation[pathway], + norm_module=norm_module, + ) + self.add_module("pathway{}_res{}".format(pathway, i), res_block) + if i in nonlocal_inds[pathway]: + nln = Nonlocal( + dim_out[pathway], + dim_out[pathway] // 2, + nonlocal_pool[pathway], + instantiation=instantiation, + norm_module=norm_module, + ) + self.add_module( + "pathway{}_nonlocal{}".format(pathway, i), nln + ) + + def forward(self, inputs): + output = [] + for pathway in range(self.num_pathways): + x = inputs[pathway] + for i in range(self.num_blocks[pathway]): + m = getattr(self, "pathway{}_res{}".format(pathway, i)) + x = m(x) + if hasattr(self, "pathway{}_nonlocal{}".format(pathway, i)): + nln = getattr( + self, "pathway{}_nonlocal{}".format(pathway, i) + ) + b, c, t, h, w = x.shape + if self.nonlocal_group[pathway] > 1: + # Fold temporal dimension into batch dimension. + x = x.permute(0, 2, 1, 3, 4) + x = x.reshape( + b * self.nonlocal_group[pathway], + t // self.nonlocal_group[pathway], + c, + h, + w, + ) + x = x.permute(0, 2, 1, 3, 4) + x = nln(x) + if self.nonlocal_group[pathway] > 1: + # Fold back to temporal dimension. + x = x.permute(0, 2, 1, 3, 4) + x = x.reshape(b, t, c, h, w) + x = x.permute(0, 2, 1, 3, 4) + output.append(x) + + return output diff --git a/training/detectors/utils/slowfast/models/stem_helper.py b/training/detectors/utils/slowfast/models/stem_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..481977b15a13edf54bfdb17fd3627b6657d56262 --- /dev/null +++ b/training/detectors/utils/slowfast/models/stem_helper.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""ResNe(X)t 3D stem helper.""" + +import torch.nn as nn + + +class VideoModelStem(nn.Module): + """ + Video 3D stem module. Provides stem operations of Conv, BN, ReLU, MaxPool + on input data tensor for one or multiple pathways. + """ + + def __init__( + self, + dim_in, + dim_out, + kernel, + stride, + padding, + inplace_relu=True, + eps=1e-5, + bn_mmt=0.1, + norm_module=nn.BatchNorm3d, + ): + """ + The `__init__` method of any subclass should also contain these + arguments. List size of 1 for single pathway models (C2D, I3D, Slow + and etc), list size of 2 for two pathway models (SlowFast). + + Args: + dim_in (list): the list of channel dimensions of the inputs. + dim_out (list): the output dimension of the convolution in the stem + layer. + kernel (list): the kernels' size of the convolutions in the stem + layers. Temporal kernel size, height kernel size, width kernel + size in order. + stride (list): the stride sizes of the convolutions in the stem + layer. Temporal kernel stride, height kernel size, width kernel + size in order. + padding (list): the paddings' sizes of the convolutions in the stem + layer. Temporal padding size, height padding size, width padding + size in order. + inplace_relu (bool): calculate the relu on the original input + without allocating new memory. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. + """ + super(VideoModelStem, self).__init__() + + assert ( + len( + { + len(dim_in), + len(dim_out), + len(kernel), + len(stride), + len(padding), + } + ) + == 1 + ), "Input pathway dimensions are not consistent." + self.num_pathways = len(dim_in) + self.kernel = kernel + self.stride = stride + self.padding = padding + self.inplace_relu = inplace_relu + self.eps = eps + self.bn_mmt = bn_mmt + # Construct the stem layer. + self._construct_stem(dim_in, dim_out, norm_module) + + def _construct_stem(self, dim_in, dim_out, norm_module): + for pathway in range(len(dim_in)): + stem = ResNetBasicStem( + dim_in[pathway], + dim_out[pathway], + self.kernel[pathway], + self.stride[pathway], + self.padding[pathway], + self.inplace_relu, + self.eps, + self.bn_mmt, + norm_module, + ) + self.add_module("pathway{}_stem".format(pathway), stem) + + def forward(self, x): + assert ( + len(x) == self.num_pathways + ), "Input tensor does not contain {} pathway".format(self.num_pathways) + for pathway in range(len(x)): + m = getattr(self, "pathway{}_stem".format(pathway)) + x[pathway] = m(x[pathway]) + return x + + +class ResNetBasicStem(nn.Module): + """ + ResNe(X)t 3D stem module. + Performs spatiotemporal Convolution, BN, and Relu following by a + spatiotemporal pooling. + """ + + def __init__( + self, + dim_in, + dim_out, + kernel, + stride, + padding, + inplace_relu=True, + eps=1e-5, + bn_mmt=0.1, + norm_module=nn.BatchNorm3d, + ): + """ + The `__init__` method of any subclass should also contain these arguments. + + Args: + dim_in (int): the channel dimension of the input. Normally 3 is used + for rgb input, and 2 or 3 is used for optical flow input. + dim_out (int): the output dimension of the convolution in the stem + layer. + kernel (list): the kernel size of the convolution in the stem layer. + temporal kernel size, height kernel size, width kernel size in + order. + stride (list): the stride size of the convolution in the stem layer. + temporal kernel stride, height kernel size, width kernel size in + order. + padding (int): the padding size of the convolution in the stem + layer, temporal padding size, height padding size, width + padding size in order. + inplace_relu (bool): calculate the relu on the original input + without allocating new memory. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. + """ + super(ResNetBasicStem, self).__init__() + self.kernel = kernel + self.stride = stride + self.padding = padding + self.inplace_relu = inplace_relu + self.eps = eps + self.bn_mmt = bn_mmt + # Construct the stem layer. + self._construct_stem(dim_in, dim_out, norm_module) + + def _construct_stem(self, dim_in, dim_out, norm_module): + self.conv = nn.Conv3d( + dim_in, + dim_out, + self.kernel, + stride=self.stride, + padding=self.padding, + bias=False, + ) + self.bn = norm_module( + num_features=dim_out, eps=self.eps, momentum=self.bn_mmt + ) + self.relu = nn.ReLU(self.inplace_relu) + self.pool_layer = nn.MaxPool3d( + kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1] + ) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + x = self.pool_layer(x) + return x diff --git a/training/detectors/utils/slowfast/models/unet_helper.py b/training/detectors/utils/slowfast/models/unet_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..36b7202cd1936a433b193017f6c363e5dd317b4f --- /dev/null +++ b/training/detectors/utils/slowfast/models/unet_helper.py @@ -0,0 +1,157 @@ +InPlaceABN = None +from torch import nn +import torch.nn.functional as F + + +class Conv3dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + + if use_batchnorm == "inplace" and InPlaceABN is None: + raise RuntimeError( + "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. " + + "To install see: https://github.com/mapillary/inplace_abn" + ) + + conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + if use_batchnorm == "inplace": + bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) + relu = nn.Identity() + + elif use_batchnorm and use_batchnorm != "inplace": + bn = nn.BatchNorm3d(out_channels) + + else: + bn = nn.Identity() + + super(Conv3dReLU, self).__init__(conv, bn, relu) + + +class DecoderBlock(nn.Module): + def __init__( + self, in_channels, skip_channels, out_channels, use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv3dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + + self.conv2 = Conv3dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class LightDecoderBlock(nn.Module): + def __init__( + self, in_channels, skip_channels, out_channels, use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv3dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + + def forward(self, x): + x = self.conv1(x) + return x + + +def freeze_net(model: nn.Module, freeze_prefixs): + flag = False + for name, param in model.named_parameters(): + items = name.split(".") + if items[0] == "module": + prefix = items[1] + else: + prefix = items[0] + if prefix in freeze_prefixs: + if param.requires_grad is True: + param.requires_grad = False + flag = True + # print("freeze",name) + + assert flag + + +def unfreeze_net(model: nn.Module): + for name, param in model.named_parameters(): + param.requires_grad = True + + +from .resnet_helper import ResBlock, get_trans_func + + +class ResDecoderBlock(nn.Module): + def __init__( + self, in_channels, skip_channels, out_channels, use_batchnorm=True, + ): + super().__init__() + trans_func = get_trans_func("bottleneck_transform") + self.conv1 = ResBlock( + in_channels + skip_channels, + out_channels, + 3, + 1, + trans_func, + out_channels//2, + num_groups=1, + stride_1x1=False, + inplace_relu=True, + eps=1e-5, + bn_mmt=0.1, + dilation=1, + norm_module=nn.BatchNorm3d, + ) + + self.conv2 = ResBlock( + out_channels, + out_channels, + 3, + 1, + trans_func, + out_channels//2, + num_groups=1, + stride_1x1=False, + inplace_relu=True, + eps=1e-5, + bn_mmt=0.1, + dilation=1, + norm_module=nn.BatchNorm3d, + ) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x diff --git a/training/detectors/utils/slowfast/models/video_model_builder.py b/training/detectors/utils/slowfast/models/video_model_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..3c89e8ebb8f51fe79b82c01542f0668345212726 --- /dev/null +++ b/training/detectors/utils/slowfast/models/video_model_builder.py @@ -0,0 +1,2739 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Video models.""" + +import torch +import torch.nn as nn +import copy + +import slowfast.utils.weight_init_helper as init_helper +from slowfast.models.batchnorm_helper import get_norm + +from . import head_helper, resnet_helper, stem_helper +from .build import MODEL_REGISTRY + +# Number of blocks for different stages given the model depth. +_MODEL_STAGE_DEPTH = {18:(2,2,2,2),50: (3, 4, 6, 3), 101: (3, 4, 23, 3)} + +# Basis of temporal kernel sizes for each of the stage. +_TEMPORAL_KERNEL_BASIS = { + "c2d": [ + [[1]], # conv1 temporal kernel. + [[1]], # res2 temporal kernel. + [[1]], # res3 temporal kernel. + [[1]], # res4 temporal kernel. + [[1]], # res5 temporal kernel. + ], + "c2d_nopool": [ + [[1]], # conv1 temporal kernel. + [[1]], # res2 temporal kernel. + [[1]], # res3 temporal kernel. + [[1]], # res4 temporal kernel. + [[1]], # res5 temporal kernel. + ], + "i3d": [ + [[5]], # conv1 temporal kernel. + [[3]], # res2 temporal kernel. + [[3, 1]], # res3 temporal kernel. + [[3, 1]], # res4 temporal kernel. + [[1, 3]], # res5 temporal kernel. + ], + "r3d_18": [ + [[3]], # conv1 temporal kernel. + [[3]], # res2 temporal kernel. + [[3, 1]], # res3 temporal kernel. + [[3, 1]], # res4 temporal kernel. + [[1, 3]], # res5 temporal kernel. + ], + "i3d_nopool": [ + [[5]], # conv1 temporal kernel. + [[3]], # res2 temporal kernel. + [[3, 1]], # res3 temporal kernel. + [[3, 1]], # res4 temporal kernel. + [[1, 3]], # res5 temporal kernel. + ], + "slow": [ + [[1]], # conv1 temporal kernel. + [[1]], # res2 temporal kernel. + [[1]], # res3 temporal kernel. + [[3]], # res4 temporal kernel. + [[3]], # res5 temporal kernel. + ], + "slowfast": [ + [[1], [5]], # conv1 temporal kernel for slow and fast pathway. + [[1], [3]], # res2 temporal kernel for slow and fast pathway. + [[1], [3]], # res3 temporal kernel for slow and fast pathway. + [[3], [3]], # res4 temporal kernel for slow and fast pathway. + [[3], [3]], # res5 temporal kernel for slow and fast pathway. + ], +} + +_POOL1 = { + "c2d": [[2, 1, 1]], + "c2d_nopool": [[1, 1, 1]], + "i3d": [[2, 1, 1]], + "r3d_18": [[2, 1, 1]], + "i3d_nopool": [[1, 1, 1]], + "slow": [[1, 1, 1]], + "slowfast": [[1, 1, 1], [1, 1, 1]], +} + + + + +class FuseFastToSlow(nn.Module): + """ + Fuses the information from the Fast pathway to the Slow pathway. Given the + tensors from Slow pathway and Fast pathway, fuse information from Fast to + Slow, then return the fused tensors from Slow and Fast pathway in order. + """ + + def __init__( + self, + dim_in, + fusion_conv_channel_ratio, + fusion_kernel, + alpha, + eps=1e-5, + bn_mmt=0.1, + inplace_relu=True, + norm_module=nn.BatchNorm3d, + ): + """ + Args: + dim_in (int): the channel dimension of the input. + fusion_conv_channel_ratio (int): channel ratio for the convolution + used to fuse from Fast pathway to Slow pathway. + fusion_kernel (int): kernel size of the convolution used to fuse + from Fast pathway to Slow pathway. + alpha (int): the frame rate ratio between the Fast and Slow pathway. + eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + inplace_relu (bool): if True, calculate the relu on the original + input without allocating new memory. + norm_module (nn.Module): nn.Module for the normalization layer. The + default is nn.BatchNorm3d. + """ + super(FuseFastToSlow, self).__init__() + self.conv_f2s = nn.Conv3d( + dim_in, + dim_in * fusion_conv_channel_ratio, + kernel_size=[fusion_kernel, 1, 1], + stride=[alpha, 1, 1], + padding=[fusion_kernel // 2, 0, 0], + bias=False, + ) + self.bn = norm_module( + num_features=dim_in * fusion_conv_channel_ratio, + eps=eps, + momentum=bn_mmt, + ) + self.relu = nn.ReLU(inplace_relu) + + def forward(self, x): + x_s = x[0] + x_f = x[1] + fuse = self.conv_f2s(x_f) + fuse = self.bn(fuse) + fuse = self.relu(fuse) + x_s_fuse = torch.cat([x_s, fuse], 1) + return [x_s_fuse, x_f] + + +@MODEL_REGISTRY.register() +class SlowFast(nn.Module): + """ + SlowFast model builder for SlowFast network. + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "SlowFast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + """ + + def __init__(self, cfg): + """ + The `__init__` method of any subclass should also contain these + arguments. + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + super(SlowFast, self).__init__() + self.norm_module = get_norm(cfg) + self.enable_detection = cfg.DETECTION.ENABLE + self.num_pathways = 2 + self._construct_network(cfg) + init_helper.init_weights( + self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN + ) + + def _construct_network(self, cfg): + """ + Builds a SlowFast model. The first pathway is the Slow pathway and the + second pathway is the Fast pathway. + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + assert cfg.MODEL.ARCH in _POOL1.keys() + pool_size = _POOL1[cfg.MODEL.ARCH] + assert len({len(pool_size), self.num_pathways}) == 1 + assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys() + + (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH] + + num_groups = cfg.RESNET.NUM_GROUPS + width_per_group = cfg.RESNET.WIDTH_PER_GROUP + dim_inner = num_groups * width_per_group + out_dim_ratio = ( + cfg.SLOWFAST.BETA_INV // cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO + ) + + temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH] + + self.s1 = stem_helper.VideoModelStem( + dim_in=cfg.DATA.INPUT_CHANNEL_NUM, + dim_out=[width_per_group, width_per_group // cfg.SLOWFAST.BETA_INV], + kernel=[temp_kernel[0][0] + [7, 7], temp_kernel[0][1] + [7, 7]], + stride=[[1, 2, 2]] * 2, + padding=[ + [temp_kernel[0][0][0] // 2, 3, 3], + [temp_kernel[0][1][0] // 2, 3, 3], + ], + norm_module=self.norm_module, + ) + self.s1_fuse = FuseFastToSlow( + width_per_group // cfg.SLOWFAST.BETA_INV, + cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO, + cfg.SLOWFAST.FUSION_KERNEL_SZ, + cfg.SLOWFAST.ALPHA, + norm_module=self.norm_module, + ) + + self.s2 = resnet_helper.ResStage( + dim_in=[ + width_per_group + width_per_group // out_dim_ratio, + width_per_group // cfg.SLOWFAST.BETA_INV, + ], + dim_out=[ + width_per_group * 4, + width_per_group * 4 // cfg.SLOWFAST.BETA_INV, + ], + dim_inner=[dim_inner, dim_inner // cfg.SLOWFAST.BETA_INV], + temp_kernel_sizes=temp_kernel[1], + stride=cfg.RESNET.SPATIAL_STRIDES[0], + num_blocks=[d2] * 2, + num_groups=[num_groups] * 2, + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0], + nonlocal_inds=cfg.NONLOCAL.LOCATION[0], + nonlocal_group=cfg.NONLOCAL.GROUP[0], + nonlocal_pool=cfg.NONLOCAL.POOL[0], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + dilation=cfg.RESNET.SPATIAL_DILATIONS[0], + norm_module=self.norm_module, + ) + self.s2_fuse = FuseFastToSlow( + width_per_group * 4 // cfg.SLOWFAST.BETA_INV, + cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO, + cfg.SLOWFAST.FUSION_KERNEL_SZ, + cfg.SLOWFAST.ALPHA, + norm_module=self.norm_module, + ) + + for pathway in range(self.num_pathways): + pool = nn.MaxPool3d( + kernel_size=pool_size[pathway], + stride=pool_size[pathway], + padding=[0, 0, 0], + ) + self.add_module("pathway{}_pool".format(pathway), pool) + + self.s3 = resnet_helper.ResStage( + dim_in=[ + width_per_group * 4 + width_per_group * 4 // out_dim_ratio, + width_per_group * 4 // cfg.SLOWFAST.BETA_INV, + ], + dim_out=[ + width_per_group * 8, + width_per_group * 8 // cfg.SLOWFAST.BETA_INV, + ], + dim_inner=[dim_inner * 2, dim_inner * 2 // cfg.SLOWFAST.BETA_INV], + temp_kernel_sizes=temp_kernel[2], + stride=cfg.RESNET.SPATIAL_STRIDES[1], + num_blocks=[d3] * 2, + num_groups=[num_groups] * 2, + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1], + nonlocal_inds=cfg.NONLOCAL.LOCATION[1], + nonlocal_group=cfg.NONLOCAL.GROUP[1], + nonlocal_pool=cfg.NONLOCAL.POOL[1], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + dilation=cfg.RESNET.SPATIAL_DILATIONS[1], + norm_module=self.norm_module, + ) + self.s3_fuse = FuseFastToSlow( + width_per_group * 8 // cfg.SLOWFAST.BETA_INV, + cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO, + cfg.SLOWFAST.FUSION_KERNEL_SZ, + cfg.SLOWFAST.ALPHA, + norm_module=self.norm_module, + ) + + self.s4 = resnet_helper.ResStage( + dim_in=[ + width_per_group * 8 + width_per_group * 8 // out_dim_ratio, + width_per_group * 8 // cfg.SLOWFAST.BETA_INV, + ], + dim_out=[ + width_per_group * 16, + width_per_group * 16 // cfg.SLOWFAST.BETA_INV, + ], + dim_inner=[dim_inner * 4, dim_inner * 4 // cfg.SLOWFAST.BETA_INV], + temp_kernel_sizes=temp_kernel[3], + stride=cfg.RESNET.SPATIAL_STRIDES[2], + num_blocks=[d4] * 2, + num_groups=[num_groups] * 2, + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2], + nonlocal_inds=cfg.NONLOCAL.LOCATION[2], + nonlocal_group=cfg.NONLOCAL.GROUP[2], + nonlocal_pool=cfg.NONLOCAL.POOL[2], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + dilation=cfg.RESNET.SPATIAL_DILATIONS[2], + norm_module=self.norm_module, + ) + self.s4_fuse = FuseFastToSlow( + width_per_group * 16 // cfg.SLOWFAST.BETA_INV, + cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO, + cfg.SLOWFAST.FUSION_KERNEL_SZ, + cfg.SLOWFAST.ALPHA, + norm_module=self.norm_module, + ) + + self.s5 = resnet_helper.ResStage( + dim_in=[ + width_per_group * 16 + width_per_group * 16 // out_dim_ratio, + width_per_group * 16 // cfg.SLOWFAST.BETA_INV, + ], + dim_out=[ + width_per_group * 32, + width_per_group * 32 // cfg.SLOWFAST.BETA_INV, + ], + dim_inner=[dim_inner * 8, dim_inner * 8 // cfg.SLOWFAST.BETA_INV], + temp_kernel_sizes=temp_kernel[4], + stride=cfg.RESNET.SPATIAL_STRIDES[3], + num_blocks=[d5] * 2, + num_groups=[num_groups] * 2, + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3], + nonlocal_inds=cfg.NONLOCAL.LOCATION[3], + nonlocal_group=cfg.NONLOCAL.GROUP[3], + nonlocal_pool=cfg.NONLOCAL.POOL[3], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + dilation=cfg.RESNET.SPATIAL_DILATIONS[3], + norm_module=self.norm_module, + ) + + if cfg.DETECTION.ENABLE: + raise NotImplementedError + else: + self.head = head_helper.ResNetBasicHead( + dim_in=[ + width_per_group * 32, + width_per_group * 32 // cfg.SLOWFAST.BETA_INV, + ], + num_classes=cfg.MODEL.NUM_CLASSES, + pool_size=[None, None] + if cfg.MULTIGRID.SHORT_CYCLE + else [ + [ + cfg.DATA.NUM_FRAMES + // cfg.SLOWFAST.ALPHA + // pool_size[0][0], + cfg.DATA.CROP_SIZE // 32 // pool_size[0][1], + cfg.DATA.CROP_SIZE // 32 // pool_size[0][2], + ], + [ + cfg.DATA.NUM_FRAMES // pool_size[1][0], + cfg.DATA.CROP_SIZE // 32 // pool_size[1][1], + cfg.DATA.CROP_SIZE // 32 // pool_size[1][2], + ], + ], # None for AdaptiveAvgPool3d((1, 1, 1)) + dropout_rate=cfg.MODEL.DROPOUT_RATE, + act_func=cfg.MODEL.HEAD_ACT, + ) + + def forward(self, x, bboxes=None): + x = self.s1(x) + x = self.s1_fuse(x) + x = self.s2(x) + x = self.s2_fuse(x) + for pathway in range(self.num_pathways): + pool = getattr(self, "pathway{}_pool".format(pathway)) + x[pathway] = pool(x[pathway]) + x = self.s3(x) + x = self.s3_fuse(x) + x = self.s4(x) + x = self.s4_fuse(x) + x = self.s5(x) + if self.enable_detection: + x = self.head(x, bboxes) + else: + x = self.head(x) + return x + + +@MODEL_REGISTRY.register() +class ResNet(nn.Module): + """ + ResNet model builder. It builds a ResNet like network backbone without + lateral connection (C2D, I3D, Slow). + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "SlowFast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + + Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. + "Non-local neural networks." + https://arxiv.org/pdf/1711.07971.pdf + """ + + def __init__(self, cfg): + """ + The `__init__` method of any subclass should also contain these + arguments. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + super(ResNet, self).__init__() + self.norm_module = get_norm(cfg) + self.enable_detection = cfg.DETECTION.ENABLE + self.num_pathways = 1 + self._construct_network(cfg) + init_helper.init_weights( + self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN + ) + + def _construct_network(self, cfg): + """ + Builds a single pathway ResNet model. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + assert cfg.MODEL.ARCH in _POOL1.keys() + pool_size = _POOL1[cfg.MODEL.ARCH] + assert len({len(pool_size), self.num_pathways}) == 1 + assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys() + + (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH] + + num_groups = cfg.RESNET.NUM_GROUPS + width_per_group = cfg.RESNET.WIDTH_PER_GROUP + dim_inner = num_groups * width_per_group + print(dim_inner) + + temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH] + + self.s1 = stem_helper.VideoModelStem( + dim_in=cfg.DATA.INPUT_CHANNEL_NUM, + dim_out=[width_per_group], + kernel=[temp_kernel[0][0] + [7, 7]], + stride=[[1, 2, 2]], + padding=[[temp_kernel[0][0][0] // 2, 3, 3]], + norm_module=self.norm_module, + ) + + self.s2 = resnet_helper.ResStage( + dim_in=[width_per_group], + dim_out=[width_per_group * 4], + dim_inner=[dim_inner], + temp_kernel_sizes=temp_kernel[1], + stride=cfg.RESNET.SPATIAL_STRIDES[0], + num_blocks=[d2], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0], + nonlocal_inds=cfg.NONLOCAL.LOCATION[0], + nonlocal_group=cfg.NONLOCAL.GROUP[0], + nonlocal_pool=cfg.NONLOCAL.POOL[0], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[0], + norm_module=self.norm_module, + ) + + for pathway in range(self.num_pathways): + pool = nn.MaxPool3d( + kernel_size=pool_size[pathway], + stride=pool_size[pathway], + padding=[0, 0, 0], + ) + self.add_module("pathway{}_pool".format(pathway), pool) + + self.s3 = resnet_helper.ResStage( + dim_in=[width_per_group * 4], + dim_out=[width_per_group * 8], + dim_inner=[dim_inner * 2], + temp_kernel_sizes=temp_kernel[2], + stride=cfg.RESNET.SPATIAL_STRIDES[1], + num_blocks=[d3], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1], + nonlocal_inds=cfg.NONLOCAL.LOCATION[1], + nonlocal_group=cfg.NONLOCAL.GROUP[1], + nonlocal_pool=cfg.NONLOCAL.POOL[1], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[1], + norm_module=self.norm_module, + ) + + self.s4 = resnet_helper.ResStage( + dim_in=[width_per_group * 8], + dim_out=[width_per_group * 16], + dim_inner=[dim_inner * 4], + temp_kernel_sizes=temp_kernel[3], + stride=cfg.RESNET.SPATIAL_STRIDES[2], + num_blocks=[d4], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2], + nonlocal_inds=cfg.NONLOCAL.LOCATION[2], + nonlocal_group=cfg.NONLOCAL.GROUP[2], + nonlocal_pool=cfg.NONLOCAL.POOL[2], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[2], + norm_module=self.norm_module, + ) + + self.s5 = resnet_helper.ResStage( + dim_in=[width_per_group * 16], + dim_out=[width_per_group * 32], + dim_inner=[dim_inner * 8], + temp_kernel_sizes=temp_kernel[4], + stride=cfg.RESNET.SPATIAL_STRIDES[3], + num_blocks=[d5], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3], + nonlocal_inds=cfg.NONLOCAL.LOCATION[3], + nonlocal_group=cfg.NONLOCAL.GROUP[3], + nonlocal_pool=cfg.NONLOCAL.POOL[3], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[3], + norm_module=self.norm_module, + ) + + if self.enable_detection: + raise NotImplementedError + else: + self.head = head_helper.ResNetBasicHead( + dim_in=[width_per_group * 32], + num_classes=cfg.MODEL.NUM_CLASSES, + pool_size=[None, None] + if cfg.MULTIGRID.SHORT_CYCLE + else [ + [ + cfg.DATA.NUM_FRAMES // pool_size[0][0], + cfg.DATA.CROP_SIZE // 32 // pool_size[0][1], + cfg.DATA.CROP_SIZE // 32 // pool_size[0][2], + ] + ], # None for AdaptiveAvgPool3d((1, 1, 1)) + dropout_rate=cfg.MODEL.DROPOUT_RATE, + act_func=cfg.MODEL.HEAD_ACT, + ) + + def forward(self, x, return_feat=False, bboxes=None): + x = self.s1(x) + x = self.s2(x) + for pathway in range(self.num_pathways): + pool = getattr(self, "pathway{}_pool".format(pathway)) + x[pathway] = pool(x[pathway]) + x = self.s3(x) + x = self.s4(x) + feat = self.s5(x) + if return_feat: + return feat + if self.enable_detection: + x = self.head(feat, bboxes) + else: + x = self.head(feat) + return x + +@MODEL_REGISTRY.register() +class ResNetVar(nn.Module): + """ + ResNet model builder. It builds a ResNet like network backbone without + lateral connection (C2D, I3D, Slow). + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "SlowFast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + + Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. + "Non-local neural networks." + https://arxiv.org/pdf/1711.07971.pdf + """ + + def __init__(self, cfg): + """ + The `__init__` method of any subclass should also contain these + arguments. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + super(ResNetVar, self).__init__() + self.norm_module = get_norm(cfg) + self.enable_detection = cfg.DETECTION.ENABLE + self.num_pathways = 1 + self._construct_network(cfg) + init_helper.init_weights( + self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN + ) + + def _construct_network(self, cfg): + """ + Builds a single pathway ResNet model. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + assert cfg.MODEL.ARCH in _POOL1.keys() + pool_size = _POOL1[cfg.MODEL.ARCH] + assert len({len(pool_size), self.num_pathways}) == 1 + assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys() + + (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH] + + num_groups = cfg.RESNET.NUM_GROUPS + width_per_group = cfg.RESNET.WIDTH_PER_GROUP + dim_inner = num_groups * width_per_group + + temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH] + + self.s1 = stem_helper.VideoModelStem( + dim_in=cfg.DATA.INPUT_CHANNEL_NUM, + dim_out=[width_per_group], + kernel=[temp_kernel[0][0] + [7, 7]], + stride=[[1, 2, 2]], + padding=[[temp_kernel[0][0][0] // 2, 3, 3]], + norm_module=self.norm_module, + ) + + self.s2 = resnet_helper.ResStage( + dim_in=[width_per_group], + dim_out=[width_per_group * 4], + dim_inner=[dim_inner], + temp_kernel_sizes=temp_kernel[1], + stride=cfg.RESNET.SPATIAL_STRIDES[0], + num_blocks=[d2], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0], + nonlocal_inds=cfg.NONLOCAL.LOCATION[0], + nonlocal_group=cfg.NONLOCAL.GROUP[0], + nonlocal_pool=cfg.NONLOCAL.POOL[0], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[0], + norm_module=self.norm_module, + ) + + for pathway in range(self.num_pathways): + pool = nn.MaxPool3d( + kernel_size=pool_size[pathway], + stride=pool_size[pathway], + padding=[0, 0, 0], + ) + self.add_module("pathway{}_pool".format(pathway), pool) + + self.s3 = resnet_helper.ResStage( + dim_in=[width_per_group * 4], + dim_out=[width_per_group * 8], + dim_inner=[dim_inner * 2], + temp_kernel_sizes=temp_kernel[2], + stride=cfg.RESNET.SPATIAL_STRIDES[1], + num_blocks=[d3], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1], + nonlocal_inds=cfg.NONLOCAL.LOCATION[1], + nonlocal_group=cfg.NONLOCAL.GROUP[1], + nonlocal_pool=cfg.NONLOCAL.POOL[1], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[1], + norm_module=self.norm_module, + ) + + self.s4 = resnet_helper.ResStage( + dim_in=[width_per_group * 8], + dim_out=[width_per_group * 16], + dim_inner=[dim_inner * 4], + temp_kernel_sizes=temp_kernel[3], + stride=cfg.RESNET.SPATIAL_STRIDES[2], + num_blocks=[d4], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2], + nonlocal_inds=cfg.NONLOCAL.LOCATION[2], + nonlocal_group=cfg.NONLOCAL.GROUP[2], + nonlocal_pool=cfg.NONLOCAL.POOL[2], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[2], + norm_module=self.norm_module, + ) + + self.s5 = resnet_helper.ResStage( + dim_in=[width_per_group * 16], + dim_out=[width_per_group * 32], + dim_inner=[dim_inner * 8], + temp_kernel_sizes=temp_kernel[4], + stride=cfg.RESNET.SPATIAL_STRIDES[3], + num_blocks=[d5], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3], + nonlocal_inds=cfg.NONLOCAL.LOCATION[3], + nonlocal_group=cfg.NONLOCAL.GROUP[3], + nonlocal_pool=cfg.NONLOCAL.POOL[3], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[3], + norm_module=self.norm_module, + ) + + if self.enable_detection: + raise NotImplementedError + else: + self.head = head_helper.ResNetBasicHead( + dim_in=[width_per_group * 32], + num_classes=cfg.MODEL.NUM_CLASSES, + pool_size=[None], + dropout_rate=cfg.MODEL.DROPOUT_RATE, + act_func=cfg.MODEL.HEAD_ACT, + ) + + def forward(self, x, bboxes=None): + x = self.s1(x) + x = self.s2(x) + for pathway in range(self.num_pathways): + pool = getattr(self, "pathway{}_pool".format(pathway)) + x[pathway] = pool(x[pathway]) + x = self.s3(x) + x = self.s4(x) + x = self.s5(x) + if self.enable_detection: + x = self.head(x, bboxes) + else: + x = self.head(x) + return x + +@MODEL_REGISTRY.register() +class ResNetBase(nn.Module): + """ + ResNet model builder. It builds a ResNet like network backbone without + lateral connection (C2D, I3D, Slow). + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "SlowFast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + + Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. + "Non-local neural networks." + https://arxiv.org/pdf/1711.07971.pdf + """ + + def __init__(self, cfg): + """ + The `__init__` method of any subclass should also contain these + arguments. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + super(ResNetBase, self).__init__() + self.norm_module = get_norm(cfg) + self.enable_detection = cfg.DETECTION.ENABLE + self.num_pathways = 1 + self._construct_network(cfg) + init_helper.init_weights( + self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN + ) + + def _construct_network(self, cfg): + """ + Builds a single pathway ResNet model. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + assert cfg.MODEL.ARCH in _POOL1.keys() + pool_size = _POOL1[cfg.MODEL.ARCH] + assert len({len(pool_size), self.num_pathways}) == 1 + assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys() + + (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH] + + num_groups = cfg.RESNET.NUM_GROUPS + width_per_group = cfg.RESNET.WIDTH_PER_GROUP + dim_inner = num_groups * width_per_group + + temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH] + + self.s1 = stem_helper.VideoModelStem( + dim_in=cfg.DATA.INPUT_CHANNEL_NUM, + dim_out=[width_per_group], + kernel=[temp_kernel[0][0] + [7, 7]], + stride=[[1, 2, 2]], + padding=[[temp_kernel[0][0][0] // 2, 3, 3]], + norm_module=self.norm_module, + ) + + self.s2 = resnet_helper.ResStage( + dim_in=[width_per_group], + dim_out=[width_per_group * 4], + dim_inner=[dim_inner], + temp_kernel_sizes=temp_kernel[1], + stride=cfg.RESNET.SPATIAL_STRIDES[0], + num_blocks=[d2], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0], + nonlocal_inds=cfg.NONLOCAL.LOCATION[0], + nonlocal_group=cfg.NONLOCAL.GROUP[0], + nonlocal_pool=cfg.NONLOCAL.POOL[0], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[0], + norm_module=self.norm_module, + ) + + for pathway in range(self.num_pathways): + pool = nn.MaxPool3d( + kernel_size=pool_size[pathway], + stride=pool_size[pathway], + padding=[0, 0, 0], + ) + self.add_module("pathway{}_pool".format(pathway), pool) + + self.s3 = resnet_helper.ResStage( + dim_in=[width_per_group * 4], + dim_out=[width_per_group * 8], + dim_inner=[dim_inner * 2], + temp_kernel_sizes=temp_kernel[2], + stride=cfg.RESNET.SPATIAL_STRIDES[1], + num_blocks=[d3], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1], + nonlocal_inds=cfg.NONLOCAL.LOCATION[1], + nonlocal_group=cfg.NONLOCAL.GROUP[1], + nonlocal_pool=cfg.NONLOCAL.POOL[1], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[1], + norm_module=self.norm_module, + ) + + self.s4 = resnet_helper.ResStage( + dim_in=[width_per_group * 8], + dim_out=[width_per_group * 16], + dim_inner=[dim_inner * 4], + temp_kernel_sizes=temp_kernel[3], + stride=cfg.RESNET.SPATIAL_STRIDES[2], + num_blocks=[d4], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2], + nonlocal_inds=cfg.NONLOCAL.LOCATION[2], + nonlocal_group=cfg.NONLOCAL.GROUP[2], + nonlocal_pool=cfg.NONLOCAL.POOL[2], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[2], + norm_module=self.norm_module, + ) + + self.s5 = resnet_helper.ResStage( + dim_in=[width_per_group * 16], + dim_out=[width_per_group * 32], + dim_inner=[dim_inner * 8], + temp_kernel_sizes=temp_kernel[4], + stride=cfg.RESNET.SPATIAL_STRIDES[3], + num_blocks=[d5], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3], + nonlocal_inds=cfg.NONLOCAL.LOCATION[3], + nonlocal_group=cfg.NONLOCAL.GROUP[3], + nonlocal_pool=cfg.NONLOCAL.POOL[3], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[3], + norm_module=self.norm_module, + ) + + if self.enable_detection: + raise NotImplementedError + else: + self.head = head_helper.ResNetBasicHead( + dim_in=[width_per_group * 32], + num_classes=cfg.MODEL.NUM_CLASSES, + pool_size=[None, None] + if cfg.MULTIGRID.SHORT_CYCLE + else [ + None + ], # None for AdaptiveAvgPool3d((1, 1, 1)) + dropout_rate=cfg.MODEL.DROPOUT_RATE, + act_func=cfg.MODEL.HEAD_ACT, + ) + + def forward(self, x, bboxes=None): + x = self.s1(x) + x = self.s2(x) + for pathway in range(self.num_pathways): + pool = getattr(self, "pathway{}_pool".format(pathway)) + x[pathway] = pool(x[pathway]) + x = self.s3(x) + x = self.s4(x) + x = self.s5(x) + if self.enable_detection: + x = self.head(x, bboxes) + else: + x = self.head(x) + return x + + +@MODEL_REGISTRY.register() +class ResNetFreeze(nn.Module): + """ + ResNet model builder. It builds a ResNet like network backbone without + lateral connection (C2D, I3D, Slow). + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "SlowFast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + + Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. + "Non-local neural networks." + https://arxiv.org/pdf/1711.07971.pdf + """ + + def __init__(self, cfg): + """ + The `__init__` method of any subclass should also contain these + arguments. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + super(ResNetFreeze, self).__init__() + self.norm_module = get_norm(cfg) + self.enable_detection = cfg.DETECTION.ENABLE + self.num_pathways = 1 + self._construct_network(cfg) + init_helper.init_weights( + self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN + ) + + def _construct_network(self, cfg): + """ + Builds a single pathway ResNet model. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + assert cfg.MODEL.ARCH in _POOL1.keys() + pool_size = _POOL1[cfg.MODEL.ARCH] + assert len({len(pool_size), self.num_pathways}) == 1 + assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys() + + (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH] + + num_groups = cfg.RESNET.NUM_GROUPS + width_per_group = cfg.RESNET.WIDTH_PER_GROUP + dim_inner = num_groups * width_per_group + + temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH] + + self.s1 = stem_helper.VideoModelStem( + dim_in=cfg.DATA.INPUT_CHANNEL_NUM, + dim_out=[width_per_group], + kernel=[temp_kernel[0][0] + [7, 7]], + stride=[[1, 2, 2]], + padding=[[temp_kernel[0][0][0] // 2, 3, 3]], + norm_module=self.norm_module, + ) + + self.s2 = resnet_helper.ResStage( + dim_in=[width_per_group], + dim_out=[width_per_group * 4], + dim_inner=[dim_inner], + temp_kernel_sizes=temp_kernel[1], + stride=cfg.RESNET.SPATIAL_STRIDES[0], + num_blocks=[d2], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0], + nonlocal_inds=cfg.NONLOCAL.LOCATION[0], + nonlocal_group=cfg.NONLOCAL.GROUP[0], + nonlocal_pool=cfg.NONLOCAL.POOL[0], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[0], + norm_module=self.norm_module, + ) + + for pathway in range(self.num_pathways): + pool = nn.MaxPool3d( + kernel_size=pool_size[pathway], + stride=pool_size[pathway], + padding=[0, 0, 0], + ) + self.add_module("pathway{}_pool".format(pathway), pool) + + self.s3 = resnet_helper.ResStage( + dim_in=[width_per_group * 4], + dim_out=[width_per_group * 8], + dim_inner=[dim_inner * 2], + temp_kernel_sizes=temp_kernel[2], + stride=cfg.RESNET.SPATIAL_STRIDES[1], + num_blocks=[d3], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1], + nonlocal_inds=cfg.NONLOCAL.LOCATION[1], + nonlocal_group=cfg.NONLOCAL.GROUP[1], + nonlocal_pool=cfg.NONLOCAL.POOL[1], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[1], + norm_module=self.norm_module, + ) + + self.s4 = resnet_helper.ResStage( + dim_in=[width_per_group * 8], + dim_out=[width_per_group * 16], + dim_inner=[dim_inner * 4], + temp_kernel_sizes=temp_kernel[3], + stride=cfg.RESNET.SPATIAL_STRIDES[2], + num_blocks=[d4], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2], + nonlocal_inds=cfg.NONLOCAL.LOCATION[2], + nonlocal_group=cfg.NONLOCAL.GROUP[2], + nonlocal_pool=cfg.NONLOCAL.POOL[2], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[2], + norm_module=self.norm_module, + ) + + self.s5 = resnet_helper.ResStage( + dim_in=[width_per_group * 16], + dim_out=[width_per_group * 32], + dim_inner=[dim_inner * 8], + temp_kernel_sizes=temp_kernel[4], + stride=cfg.RESNET.SPATIAL_STRIDES[3], + num_blocks=[d5], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3], + nonlocal_inds=cfg.NONLOCAL.LOCATION[3], + nonlocal_group=cfg.NONLOCAL.GROUP[3], + nonlocal_pool=cfg.NONLOCAL.POOL[3], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[3], + norm_module=self.norm_module, + ) + + if self.enable_detection: + raise NotImplementedError + else: + self.head = head_helper.ResNetBasicHead( + dim_in=[width_per_group * 32], + num_classes=cfg.MODEL.NUM_CLASSES, + pool_size=[None,None] + if cfg.MULTIGRID.SHORT_CYCLE + else [ + None + ], # None for AdaptiveAvgPool3d((1, 1, 1)) + dropout_rate=cfg.MODEL.DROPOUT_RATE, + act_func=cfg.MODEL.HEAD_ACT, + ) + + def forward(self, x, freeze_backbone=False): + assert isinstance(freeze_backbone,bool) + x = self.s1(x) + x = self.s2(x) + # for pathway in range(self.num_pathways): + # pool = getattr(self, "pathway{}_pool".format(pathway)) + # x[pathway] = pool(x[pathway]) + x = self.s3(x) + x = self.s4(x) + x = self.s5(x) + if freeze_backbone: + x=[item.detach() for item in x] + + x = self.head(x) + return x + + + +import torch.nn.functional as F +from .unet_helper import DecoderBlock,LightDecoderBlock,ResDecoderBlock + + +@MODEL_REGISTRY.register() +class ResUNet(nn.Module): + """ + ResNet model builder. It builds a ResNet like network backbone without + lateral connection (C2D, I3D, Slow). + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "SlowFast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + + Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. + "Non-local neural networks." + https://arxiv.org/pdf/1711.07971.pdf + """ + + def __init__(self, cfg): + """ + The `__init__` method of any subclass should also contain these + arguments. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + super(ResUNet, self).__init__() + self.norm_module = get_norm(cfg) + self.enable_detection = cfg.DETECTION.ENABLE + self.enable_jitter = cfg.JITTER.ENABLE + self.num_pathways = 1 + assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE + self.image_size = cfg.DATA.TRAIN_CROP_SIZE + self.clip_size = cfg.DATA.NUM_FRAMES + self._construct_network(cfg) + init_helper.init_weights( + self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN + ) + + def _construct_network(self, cfg): + """ + Builds a single pathway ResNet model. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + assert cfg.MODEL.ARCH in _POOL1.keys() + pool_size = _POOL1[cfg.MODEL.ARCH] + self.cfg = cfg + assert len({len(pool_size), self.num_pathways}) == 1 + assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys() + + (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH] + + num_groups = cfg.RESNET.NUM_GROUPS + width_per_group = cfg.RESNET.WIDTH_PER_GROUP + dim_inner = num_groups * width_per_group + + temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH] + + self.s1 = stem_helper.VideoModelStem( + dim_in=cfg.DATA.INPUT_CHANNEL_NUM, + dim_out=[width_per_group], + kernel=[temp_kernel[0][0] + [7, 7]], + stride=[[1, 2, 2]], + padding=[[temp_kernel[0][0][0] // 2, 3, 3]], + norm_module=self.norm_module, + ) + + self.s2 = resnet_helper.ResStage( + dim_in=[width_per_group], + dim_out=[width_per_group * 4], + dim_inner=[dim_inner], + temp_kernel_sizes=temp_kernel[1], + stride=cfg.RESNET.SPATIAL_STRIDES[0], + num_blocks=[d2], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0], + nonlocal_inds=cfg.NONLOCAL.LOCATION[0], + nonlocal_group=cfg.NONLOCAL.GROUP[0], + nonlocal_pool=cfg.NONLOCAL.POOL[0], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[0], + norm_module=self.norm_module, + ) + + for pathway in range(self.num_pathways): + pool = nn.MaxPool3d( + kernel_size=pool_size[pathway], + stride=pool_size[pathway], + padding=[0, 0, 0], + ) + self.add_module("pathway{}_pool".format(pathway), pool) + + self.s3 = resnet_helper.ResStage( + dim_in=[width_per_group * 4], + dim_out=[width_per_group * 8], + dim_inner=[dim_inner * 2], + temp_kernel_sizes=temp_kernel[2], + stride=cfg.RESNET.SPATIAL_STRIDES[1], + num_blocks=[d3], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1], + nonlocal_inds=cfg.NONLOCAL.LOCATION[1], + nonlocal_group=cfg.NONLOCAL.GROUP[1], + nonlocal_pool=cfg.NONLOCAL.POOL[1], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[1], + norm_module=self.norm_module, + ) + + self.s4 = resnet_helper.ResStage( + dim_in=[width_per_group * 8], + dim_out=[width_per_group * 16], + dim_inner=[dim_inner * 4], + temp_kernel_sizes=temp_kernel[3], + stride=cfg.RESNET.SPATIAL_STRIDES[2], + num_blocks=[d4], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2], + nonlocal_inds=cfg.NONLOCAL.LOCATION[2], + nonlocal_group=cfg.NONLOCAL.GROUP[2], + nonlocal_pool=cfg.NONLOCAL.POOL[2], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[2], + norm_module=self.norm_module, + ) + + # self.s5 = resnet_helper.ResStage( + # dim_in=[width_per_group * 16], + # dim_out=[width_per_group * 32], + # dim_inner=[dim_inner * 8], + # temp_kernel_sizes=temp_kernel[4], + # stride=cfg.RESNET.SPATIAL_STRIDES[3], + # num_blocks=[d5], + # num_groups=[num_groups], + # num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3], + # nonlocal_inds=cfg.NONLOCAL.LOCATION[3], + # nonlocal_group=cfg.NONLOCAL.GROUP[3], + # nonlocal_pool=cfg.NONLOCAL.POOL[3], + # instantiation=cfg.NONLOCAL.INSTANTIATION, + # trans_func_name=cfg.RESNET.TRANS_FUNC, + # stride_1x1=cfg.RESNET.STRIDE_1X1, + # inplace_relu=cfg.RESNET.INPLACE_RELU, + # dilation=cfg.RESNET.SPATIAL_DILATIONS[3], + # norm_module=self.norm_module, + # ) + self.labels=["rotate","light"] + self.dual_define("t4",self.labels,DecoderBlock(width_per_group * 16,width_per_group * 8,width_per_group * 8)) + self.dual_define("t3",self.labels,DecoderBlock(width_per_group * 8,width_per_group * 4, 256)) + self.dual_define("conv1x1",self.labels,nn.Sequential( + nn.Conv3d(width_per_group*4+width_per_group, 1, kernel_size=(1, 1, 1), stride=1, padding=0), nn.Sigmoid() + )) + + self.linear = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid()) + + def forward_plus(self, x, y, net): + return [net(x)[0] + y[0]] + + + def dual_define(self,name,labels,net): + for label in labels: + self.add_module(f"{name}_{label}",copy.deepcopy(net)) + + + + def upsample(self, x, dims=["space"]): + ori_size = x[0].shape[2:5] + t, h, w = ori_size + if "space" in dims: + h = 2 * h + w = 2 * w + if "time" in dims: + t = 2 * t + size = (t, h, w) + return [F.interpolate(x[0], size)] + + def concat(self,x,y): + return [torch.cat([x[0],y[0]],1)] + + + + # @torchsnooper.snoop() + def forward(self, x, bboxes=None): + x1 = self.s1(x) # 1,64,8,56,56 + x2 = self.s2(x1) # 1,256,8,56,56 + x3 = self.s3(x2) # 1,512,8,28, 28 + x = self.s4(x3) # 1,1024,8,14,14 + x = self.upsample(x) # 1,1024, 8, 28, 28 + x = self.concat(x3,x)# 1,1024+512, 8, 28, 28 + x=[self.forward_branch(x,x1,x2,label) for label in self.labels] + x=torch.cat(x,1) + out = x.mean([3, 4]).view(-1, 1)*100 + out = self.linear(out) + out = out.view(x.size(0), -1) + return x,out + + + + def forward_branch(self,x,x1,x2,label): + t4=getattr(self,f"t4_{label}") + x = t4(x[0])# 1,512, 8, 28, 28 + x = self.upsample([x]) # 1,512, 8, 56, 56 + x = self.concat(x2,x)# 1,256+512, 8, 56, 56 + t3= getattr(self,f"t3_{label}") + x = t3(x[0]) # 1,256, 8, 56, 56 + x = self.concat(x1,[x]) # 1,320, 8, 56, 56 + conv1x1=getattr(self,f"conv1x1_{label}") + x = conv1x1(x[0]) # 1,2,8,56,56 + return x + + + +@MODEL_REGISTRY.register() +class ResUNetLight(nn.Module): + """ + ResNet model builder. It builds a ResNet like network backbone without + lateral connection (C2D, I3D, Slow). + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "SlowFast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + + Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. + "Non-local neural networks." + https://arxiv.org/pdf/1711.07971.pdf + """ + + def __init__(self, cfg): + """ + The `__init__` method of any subclass should also contain these + arguments. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + super(ResUNetLight, self).__init__() + self.norm_module = get_norm(cfg) + self.enable_detection = cfg.DETECTION.ENABLE + self.enable_jitter = cfg.JITTER.ENABLE + self.num_pathways = 1 + assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE + self.image_size = cfg.DATA.TRAIN_CROP_SIZE + self.clip_size = cfg.DATA.NUM_FRAMES + self._construct_network(cfg) + init_helper.init_weights( + self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN + ) + + def _construct_network(self, cfg): + """ + Builds a single pathway ResNet model. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + assert cfg.MODEL.ARCH in _POOL1.keys() + pool_size = _POOL1[cfg.MODEL.ARCH] + self.cfg = cfg + assert len({len(pool_size), self.num_pathways}) == 1 + assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys() + + (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH] + + num_groups = cfg.RESNET.NUM_GROUPS + width_per_group = cfg.RESNET.WIDTH_PER_GROUP + dim_inner = num_groups * width_per_group + + temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH] + + self.s1 = stem_helper.VideoModelStem( + dim_in=cfg.DATA.INPUT_CHANNEL_NUM, + dim_out=[width_per_group], + kernel=[temp_kernel[0][0] + [7, 7]], + stride=[[1, 2, 2]], + padding=[[temp_kernel[0][0][0] // 2, 3, 3]], + norm_module=self.norm_module, + ) + + self.s2 = resnet_helper.ResStage( + dim_in=[width_per_group], + dim_out=[width_per_group * 4], + dim_inner=[dim_inner], + temp_kernel_sizes=temp_kernel[1], + stride=cfg.RESNET.SPATIAL_STRIDES[0], + num_blocks=[d2], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0], + nonlocal_inds=cfg.NONLOCAL.LOCATION[0], + nonlocal_group=cfg.NONLOCAL.GROUP[0], + nonlocal_pool=cfg.NONLOCAL.POOL[0], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[0], + norm_module=self.norm_module, + ) + + for pathway in range(self.num_pathways): + pool = nn.MaxPool3d( + kernel_size=pool_size[pathway], + stride=pool_size[pathway], + padding=[0, 0, 0], + ) + self.add_module("pathway{}_pool".format(pathway), pool) + + self.s3 = resnet_helper.ResStage( + dim_in=[width_per_group * 4], + dim_out=[width_per_group * 8], + dim_inner=[dim_inner * 2], + temp_kernel_sizes=temp_kernel[2], + stride=cfg.RESNET.SPATIAL_STRIDES[1], + num_blocks=[d3], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1], + nonlocal_inds=cfg.NONLOCAL.LOCATION[1], + nonlocal_group=cfg.NONLOCAL.GROUP[1], + nonlocal_pool=cfg.NONLOCAL.POOL[1], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[1], + norm_module=self.norm_module, + ) + + self.s4 = resnet_helper.ResStage( + dim_in=[width_per_group * 8], + dim_out=[width_per_group * 16], + dim_inner=[dim_inner * 4], + temp_kernel_sizes=temp_kernel[3], + stride=cfg.RESNET.SPATIAL_STRIDES[2], + num_blocks=[d4], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2], + nonlocal_inds=cfg.NONLOCAL.LOCATION[2], + nonlocal_group=cfg.NONLOCAL.GROUP[2], + nonlocal_pool=cfg.NONLOCAL.POOL[2], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[2], + norm_module=self.norm_module, + ) + + # self.s5 = resnet_helper.ResStage( + # dim_in=[width_per_group * 16], + # dim_out=[width_per_group * 32], + # dim_inner=[dim_inner * 8], + # temp_kernel_sizes=temp_kernel[4], + # stride=cfg.RESNET.SPATIAL_STRIDES[3], + # num_blocks=[d5], + # num_groups=[num_groups], + # num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3], + # nonlocal_inds=cfg.NONLOCAL.LOCATION[3], + # nonlocal_group=cfg.NONLOCAL.GROUP[3], + # nonlocal_pool=cfg.NONLOCAL.POOL[3], + # instantiation=cfg.NONLOCAL.INSTANTIATION, + # trans_func_name=cfg.RESNET.TRANS_FUNC, + # stride_1x1=cfg.RESNET.STRIDE_1X1, + # inplace_relu=cfg.RESNET.INPLACE_RELU, + # dilation=cfg.RESNET.SPATIAL_DILATIONS[3], + # norm_module=self.norm_module, + # ) + self.labels=["rotate","light"] + self.dual_define("t4",self.labels,LightDecoderBlock(width_per_group * 16,width_per_group * 8,width_per_group * 4)) + self.dual_define("t3",self.labels,LightDecoderBlock(width_per_group * 4,width_per_group * 4, 128)) + self.dual_define("conv1x1",self.labels,nn.Sequential( + nn.Conv3d(128+width_per_group, 1, kernel_size=(1, 1, 1), stride=1, padding=0), nn.Sigmoid() + )) + + self.linear = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid()) + + def forward_plus(self, x, y, net): + return [net(x)[0] + y[0]] + + + def dual_define(self,name,labels,net): + for label in labels: + self.add_module(f"{name}_{label}",copy.deepcopy(net)) + + + + def upsample(self, x, dims=["space"]): + ori_size = x[0].shape[2:5] + t, h, w = ori_size + if "space" in dims: + h = 2 * h + w = 2 * w + if "time" in dims: + t = 2 * t + size = (t, h, w) + return [F.interpolate(x[0], size)] + + def concat(self,x,y): + return [torch.cat([x[0],y[0]],1)] + + def get_detach_var(self,x): + return [t.detach() for t in x] + + # @torchsnooper.snoop() + def forward(self, x, freeze_backbone=False): + x1 = self.s1(x) # 1,64,8,56,56 + x2 = self.s2(x1) # 1,256,8,56,56 + x3 = self.s3(x2) # 1,512,8,28, 28 + x = self.s4(x3) # 1,1024,8,14,14 + assert isinstance(freeze_backbone,bool) + if freeze_backbone: + x=self.get_detach_var(x) + x1=self.get_detach_var(x1) + x2=self.get_detach_var(x2) + x3=self.get_detach_var(x3) + + x = self.upsample(x) # 1,1024, 8, 28, 28 + x = self.concat(x3,x)# 1,1024+512, 8, 28, 28 + x=[self.forward_branch(x,x1,x2,label) for label in self.labels] + x=torch.cat(x,1) + out = x.mean([3, 4]).view(-1, 1)*100 # 1,2,8,56,56 + out = self.linear(out) + out = out.view(x.size(0), -1) + return x,out + + + + def forward_branch(self,x,x1,x2,label): + t4=getattr(self,f"t4_{label}") + x = t4(x[0])# 1,256, 8, 28, 28 + x = self.upsample([x]) # 1,256, 8, 56, 56 + x = self.concat(x2,x)# 1,256+256, 8, 56, 56 + t3= getattr(self,f"t3_{label}") + x = t3(x[0]) # 1,128, 8, 56, 56 + x = self.concat(x1,[x]) # 1,192, 8, 56, 56 + conv1x1=getattr(self,f"conv1x1_{label}") + x = conv1x1(x[0]) # 1,2,8,56,56 + return x + + + +@MODEL_REGISTRY.register() +class ResUNetLightFix(nn.Module): + """ + ResNet model builder. It builds a ResNet like network backbone without + lateral connection (C2D, I3D, Slow). + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "SlowFast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + + Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. + "Non-local neural networks." + https://arxiv.org/pdf/1711.07971.pdf + """ + + def __init__(self, cfg): + """ + The `__init__` method of any subclass should also contain these + arguments. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + super(ResUNetLightFix, self).__init__() + self.norm_module = get_norm(cfg) + self.enable_detection = cfg.DETECTION.ENABLE + self.enable_jitter = cfg.JITTER.ENABLE + self.num_pathways = 1 + assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE + self.image_size = cfg.DATA.TRAIN_CROP_SIZE + self.clip_size = cfg.DATA.NUM_FRAMES + self._construct_network(cfg) + init_helper.init_weights( + self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN + ) + + def _construct_network(self, cfg): + """ + Builds a single pathway ResNet model. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + assert cfg.MODEL.ARCH in _POOL1.keys() + pool_size = _POOL1[cfg.MODEL.ARCH] + self.cfg = cfg + assert len({len(pool_size), self.num_pathways}) == 1 + assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys() + + (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH] + + num_groups = cfg.RESNET.NUM_GROUPS + width_per_group = cfg.RESNET.WIDTH_PER_GROUP + dim_inner = num_groups * width_per_group + + temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH] + + self.s1 = stem_helper.VideoModelStem( + dim_in=cfg.DATA.INPUT_CHANNEL_NUM, + dim_out=[width_per_group], + kernel=[temp_kernel[0][0] + [7, 7]], + stride=[[1, 2, 2]], + padding=[[temp_kernel[0][0][0] // 2, 3, 3]], + norm_module=self.norm_module, + ) + + self.s2 = resnet_helper.ResStage( + dim_in=[width_per_group], + dim_out=[width_per_group * 4], + dim_inner=[dim_inner], + temp_kernel_sizes=temp_kernel[1], + stride=cfg.RESNET.SPATIAL_STRIDES[0], + num_blocks=[d2], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0], + nonlocal_inds=cfg.NONLOCAL.LOCATION[0], + nonlocal_group=cfg.NONLOCAL.GROUP[0], + nonlocal_pool=cfg.NONLOCAL.POOL[0], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[0], + norm_module=self.norm_module, + ) + + for pathway in range(self.num_pathways): + pool = nn.MaxPool3d( + kernel_size=pool_size[pathway], + stride=pool_size[pathway], + padding=[0, 0, 0], + ) + self.add_module("pathway{}_pool".format(pathway), pool) + + self.s3 = resnet_helper.ResStage( + dim_in=[width_per_group * 4], + dim_out=[width_per_group * 8], + dim_inner=[dim_inner * 2], + temp_kernel_sizes=temp_kernel[2], + stride=cfg.RESNET.SPATIAL_STRIDES[1], + num_blocks=[d3], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1], + nonlocal_inds=cfg.NONLOCAL.LOCATION[1], + nonlocal_group=cfg.NONLOCAL.GROUP[1], + nonlocal_pool=cfg.NONLOCAL.POOL[1], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[1], + norm_module=self.norm_module, + ) + + self.s4 = resnet_helper.ResStage( + dim_in=[width_per_group * 8], + dim_out=[width_per_group * 16], + dim_inner=[dim_inner * 4], + temp_kernel_sizes=temp_kernel[3], + stride=cfg.RESNET.SPATIAL_STRIDES[2], + num_blocks=[d4], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2], + nonlocal_inds=cfg.NONLOCAL.LOCATION[2], + nonlocal_group=cfg.NONLOCAL.GROUP[2], + nonlocal_pool=cfg.NONLOCAL.POOL[2], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[2], + norm_module=self.norm_module, + ) + + # self.s5 = resnet_helper.ResStage( + # dim_in=[width_per_group * 16], + # dim_out=[width_per_group * 32], + # dim_inner=[dim_inner * 8], + # temp_kernel_sizes=temp_kernel[4], + # stride=cfg.RESNET.SPATIAL_STRIDES[3], + # num_blocks=[d5], + # num_groups=[num_groups], + # num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3], + # nonlocal_inds=cfg.NONLOCAL.LOCATION[3], + # nonlocal_group=cfg.NONLOCAL.GROUP[3], + # nonlocal_pool=cfg.NONLOCAL.POOL[3], + # instantiation=cfg.NONLOCAL.INSTANTIATION, + # trans_func_name=cfg.RESNET.TRANS_FUNC, + # stride_1x1=cfg.RESNET.STRIDE_1X1, + # inplace_relu=cfg.RESNET.INPLACE_RELU, + # dilation=cfg.RESNET.SPATIAL_DILATIONS[3], + # norm_module=self.norm_module, + # ) + self.labels=["rotate","light","skip"] + self.dual_define("t4",self.labels,LightDecoderBlock(width_per_group * 16,width_per_group * 8,width_per_group * 4)) + self.dual_define("t3",self.labels,LightDecoderBlock(width_per_group * 4,width_per_group * 4, 128)) + self.dual_define("conv1x1",self.labels,nn.Sequential( + nn.Conv3d(128+width_per_group, 64, kernel_size=(1, 1, 1), stride=1, padding=0), + nn.BatchNorm3d(64), + nn.ReLU(), + nn.Conv3d(64, 1, kernel_size=(1, 1, 1), stride=1, padding=0), + )) + + self.linear = nn.Sequential(nn.Linear(1, 1)) + + def forward_plus(self, x, y, net): + return [net(x)[0] + y[0]] + + + def dual_define(self,name,labels,net): + for label in labels: + self.add_module(f"{name}_{label}",copy.deepcopy(net)) + + + + def upsample(self, x, dims=["space"]): + ori_size = x[0].shape[2:5] + t, h, w = ori_size + if "space" in dims: + h = 2 * h + w = 2 * w + if "time" in dims: + t = 2 * t + size = (t, h, w) + return [F.interpolate(x[0], size)] + + def concat(self,x,y): + return [torch.cat([x[0],y[0]],1)] + + def get_detach_var(self,x): + return [t.detach() for t in x] + + # @torchsnooper.snoop() + def forward(self, x, freeze_backbone=False): + x1 = self.s1(x) # 1,64,8,56,56 + x2 = self.s2(x1) # 1,256,8,56,56 + x3 = self.s3(x2) # 1,512,8,28, 28 + x = self.s4(x3) # 1,1024,8,14,14 + assert isinstance(freeze_backbone,bool) + if freeze_backbone: + x=self.get_detach_var(x) + x1=self.get_detach_var(x1) + x2=self.get_detach_var(x2) + x3=self.get_detach_var(x3) + + x = self.upsample(x) # 1,1024, 8, 28, 28 + x = self.concat(x3,x)# 1,1024+512, 8, 28, 28 + x=[self.forward_branch(x,x1,x2,label) for label in self.labels] + x=torch.cat(x,1) + x=torch.sigmoid(x) + out = x.mean([3, 4]).view(-1, 1)*100 # 1,2,8,56,56 + out = self.linear(out) + out = out.view(x.size(0), -1) + out = torch.sigmoid(out) + return x,out + + + + def forward_branch(self,x,x1,x2,label): + t4=getattr(self,f"t4_{label}") + x = t4(x[0])# 1,256, 8, 28, 28 + x = self.upsample([x]) # 1,256, 8, 56, 56 + x = self.concat(x2,x)# 1,256+256, 8, 56, 56 + t3= getattr(self,f"t3_{label}") + x = t3(x[0]) # 1,128, 8, 56, 56 + x = self.concat(x1,[x]) # 1,192, 8, 56, 56 + conv1x1=getattr(self,f"conv1x1_{label}") + x = conv1x1(x[0]) # 1,2,8,56,56 + return x + + + +@MODEL_REGISTRY.register() +class ResUNetContinus(nn.Module): + """ + ResNet model builder. It builds a ResNet like network backbone without + lateral connection (C2D, I3D, Slow). + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "SlowFast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + + Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. + "Non-local neural networks." + https://arxiv.org/pdf/1711.07971.pdf + """ + + def __init__(self, cfg): + """ + The `__init__` method of any subclass should also contain these + arguments. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + super(ResUNetContinus, self).__init__() + self.norm_module = get_norm(cfg) + self.enable_detection = cfg.DETECTION.ENABLE + self.enable_jitter = cfg.JITTER.ENABLE + self.num_pathways = 1 + assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE + self.image_size = cfg.DATA.TRAIN_CROP_SIZE + self.clip_size = cfg.DATA.NUM_FRAMES + self._construct_network(cfg) + init_helper.init_weights( + self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN + ) + + def _construct_network(self, cfg): + """ + Builds a single pathway ResNet model. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + assert cfg.MODEL.ARCH in _POOL1.keys() + pool_size = _POOL1[cfg.MODEL.ARCH] + self.cfg = cfg + assert len({len(pool_size), self.num_pathways}) == 1 + assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys() + + (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH] + + num_groups = cfg.RESNET.NUM_GROUPS + width_per_group = cfg.RESNET.WIDTH_PER_GROUP + dim_inner = num_groups * width_per_group + + temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH] + + self.s1 = stem_helper.VideoModelStem( + dim_in=cfg.DATA.INPUT_CHANNEL_NUM, + dim_out=[width_per_group], + kernel=[temp_kernel[0][0] + [7, 7]], + stride=[[1, 2, 2]], + padding=[[temp_kernel[0][0][0] // 2, 3, 3]], + norm_module=self.norm_module, + ) + + self.s2 = resnet_helper.ResStage( + dim_in=[width_per_group], + dim_out=[width_per_group * 4], + dim_inner=[dim_inner], + temp_kernel_sizes=temp_kernel[1], + stride=cfg.RESNET.SPATIAL_STRIDES[0], + num_blocks=[d2], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0], + nonlocal_inds=cfg.NONLOCAL.LOCATION[0], + nonlocal_group=cfg.NONLOCAL.GROUP[0], + nonlocal_pool=cfg.NONLOCAL.POOL[0], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[0], + norm_module=self.norm_module, + ) + + for pathway in range(self.num_pathways): + pool = nn.MaxPool3d( + kernel_size=pool_size[pathway], + stride=pool_size[pathway], + padding=[0, 0, 0], + ) + self.add_module("pathway{}_pool".format(pathway), pool) + + self.s3 = resnet_helper.ResStage( + dim_in=[width_per_group * 4], + dim_out=[width_per_group * 8], + dim_inner=[dim_inner * 2], + temp_kernel_sizes=temp_kernel[2], + stride=cfg.RESNET.SPATIAL_STRIDES[1], + num_blocks=[d3], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1], + nonlocal_inds=cfg.NONLOCAL.LOCATION[1], + nonlocal_group=cfg.NONLOCAL.GROUP[1], + nonlocal_pool=cfg.NONLOCAL.POOL[1], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[1], + norm_module=self.norm_module, + ) + + self.s4 = resnet_helper.ResStage( + dim_in=[width_per_group * 8], + dim_out=[width_per_group * 16], + dim_inner=[dim_inner * 4], + temp_kernel_sizes=temp_kernel[3], + stride=cfg.RESNET.SPATIAL_STRIDES[2], + num_blocks=[d4], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2], + nonlocal_inds=cfg.NONLOCAL.LOCATION[2], + nonlocal_group=cfg.NONLOCAL.GROUP[2], + nonlocal_pool=cfg.NONLOCAL.POOL[2], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[2], + norm_module=self.norm_module, + ) + + # self.s5 = resnet_helper.ResStage( + # dim_in=[width_per_group * 16], + # dim_out=[width_per_group * 32], + # dim_inner=[dim_inner * 8], + # temp_kernel_sizes=temp_kernel[4], + # stride=cfg.RESNET.SPATIAL_STRIDES[3], + # num_blocks=[d5], + # num_groups=[num_groups], + # num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3], + # nonlocal_inds=cfg.NONLOCAL.LOCATION[3], + # nonlocal_group=cfg.NONLOCAL.GROUP[3], + # nonlocal_pool=cfg.NONLOCAL.POOL[3], + # instantiation=cfg.NONLOCAL.INSTANTIATION, + # trans_func_name=cfg.RESNET.TRANS_FUNC, + # stride_1x1=cfg.RESNET.STRIDE_1X1, + # inplace_relu=cfg.RESNET.INPLACE_RELU, + # dilation=cfg.RESNET.SPATIAL_DILATIONS[3], + # norm_module=self.norm_module, + # ) + self.labels=["all"] + self.dual_define("t4",self.labels,LightDecoderBlock(width_per_group * 16,width_per_group * 8,width_per_group * 4)) + self.dual_define("t3",self.labels,LightDecoderBlock(width_per_group * 4,width_per_group * 4, 128)) + self.dual_define("conv1x1",self.labels,nn.Sequential( + nn.Conv3d(128+width_per_group, 64, kernel_size=(1, 1, 1), stride=1, padding=0), + nn.BatchNorm3d(64), + nn.ReLU(), + nn.Conv3d(64, 1, kernel_size=(1, 1, 1), stride=1, padding=0), + )) + + self.linear = nn.Sequential(nn.Linear(1, 1)) + + def forward_plus(self, x, y, net): + return [net(x)[0] + y[0]] + + + def dual_define(self,name,labels,net): + for label in labels: + self.add_module(f"{name}_{label}",copy.deepcopy(net)) + + + + def upsample(self, x, dims=["space"]): + ori_size = x[0].shape[2:5] + t, h, w = ori_size + if "space" in dims: + h = 2 * h + w = 2 * w + if "time" in dims: + t = 2 * t + size = (t, h, w) + return [F.interpolate(x[0], size)] + + def concat(self,x,y): + return [torch.cat([x[0],y[0]],1)] + + def get_detach_var(self,x): + return [t.detach() for t in x] + + # @torchsnooper.snoop() + def forward(self, x, freeze_backbone=False): + x1 = self.s1(x) # 1,64,8,56,56 + x2 = self.s2(x1) # 1,256,8,56,56 + x3 = self.s3(x2) # 1,512,8,28, 28 + x = self.s4(x3) # 1,1024,8,14,14 + assert isinstance(freeze_backbone,bool) + if freeze_backbone: + x=self.get_detach_var(x) + x1=self.get_detach_var(x1) + x2=self.get_detach_var(x2) + x3=self.get_detach_var(x3) + + x = self.upsample(x) # 1,1024, 8, 28, 28 + x = self.concat(x3,x)# 1,1024+512, 8, 28, 28 + x=[self.forward_branch(x,x1,x2,label) for label in self.labels] + x=torch.cat(x,1) + x=torch.sigmoid(x) + out = x.mean([3, 4]).view(-1, 1)*100 # 1,2,8,56,56 + out = self.linear(out) + out = out.view(x.size(0), -1) + out = torch.sigmoid(out) + return x,out + + + def forward_branch(self,x,x1,x2,label): + t4= getattr(self,f"t4_{label}") + x = t4(x[0])# 1,256, 8, 28, 28 + x = self.upsample([x]) # 1,256, 8, 56, 56 + x = self.concat(x2,x)# 1,256+256, 8, 56, 56 + t3= getattr(self,f"t3_{label}") + x = t3(x[0]) # 1,128, 8, 56, 56 + x = self.concat(x1,[x]) # 1,192, 8, 56, 56 + conv1x1=getattr(self,f"conv1x1_{label}") + x = conv1x1(x[0]) # 1,2,8,56,56 + return x + + + + +@MODEL_REGISTRY.register() +class ResUNetCommon(nn.Module): + """ + ResNet model builder. It builds a ResNet like network backbone without + lateral connection (C2D, I3D, Slow). + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "SlowFast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + + Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. + "Non-local neural networks." + https://arxiv.org/pdf/1711.07971.pdf + """ + + def __init__(self, cfg): + """ + The `__init__` method of any subclass should also contain these + arguments. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + super(ResUNetCommon, self).__init__() + self.norm_module = get_norm(cfg) + self.enable_detection = cfg.DETECTION.ENABLE + self.enable_jitter = cfg.JITTER.ENABLE + self.num_pathways = 1 + assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE + self.image_size = cfg.DATA.TRAIN_CROP_SIZE + self.clip_size = cfg.DATA.NUM_FRAMES + self._construct_network(cfg) + init_helper.init_weights( + self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN + ) + + def _construct_network(self, cfg): + """ + Builds a single pathway ResNet model. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + assert cfg.MODEL.ARCH in _POOL1.keys() + pool_size = _POOL1[cfg.MODEL.ARCH] + self.cfg = cfg + assert len({len(pool_size), self.num_pathways}) == 1 + assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys() + + (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH] + + num_groups = cfg.RESNET.NUM_GROUPS + width_per_group = cfg.RESNET.WIDTH_PER_GROUP + dim_inner = num_groups * width_per_group + + temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH] + + self.s1 = stem_helper.VideoModelStem( + dim_in=cfg.DATA.INPUT_CHANNEL_NUM, + dim_out=[width_per_group], + kernel=[temp_kernel[0][0] + [7, 7]], + stride=[[1, 2, 2]], + padding=[[temp_kernel[0][0][0] // 2, 3, 3]], + norm_module=self.norm_module, + ) + + self.s2 = resnet_helper.ResStage( + dim_in=[width_per_group], + dim_out=[width_per_group * 4], + dim_inner=[dim_inner], + temp_kernel_sizes=temp_kernel[1], + stride=cfg.RESNET.SPATIAL_STRIDES[0], + num_blocks=[d2], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0], + nonlocal_inds=cfg.NONLOCAL.LOCATION[0], + nonlocal_group=cfg.NONLOCAL.GROUP[0], + nonlocal_pool=cfg.NONLOCAL.POOL[0], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[0], + norm_module=self.norm_module, + ) + + for pathway in range(self.num_pathways): + pool = nn.MaxPool3d( + kernel_size=pool_size[pathway], + stride=pool_size[pathway], + padding=[0, 0, 0], + ) + self.add_module("pathway{}_pool".format(pathway), pool) + + self.s3 = resnet_helper.ResStage( + dim_in=[width_per_group * 4], + dim_out=[width_per_group * 8], + dim_inner=[dim_inner * 2], + temp_kernel_sizes=temp_kernel[2], + stride=cfg.RESNET.SPATIAL_STRIDES[1], + num_blocks=[d3], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1], + nonlocal_inds=cfg.NONLOCAL.LOCATION[1], + nonlocal_group=cfg.NONLOCAL.GROUP[1], + nonlocal_pool=cfg.NONLOCAL.POOL[1], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[1], + norm_module=self.norm_module, + ) + + self.s4 = resnet_helper.ResStage( + dim_in=[width_per_group * 8], + dim_out=[width_per_group * 16], + dim_inner=[dim_inner * 4], + temp_kernel_sizes=temp_kernel[3], + stride=cfg.RESNET.SPATIAL_STRIDES[2], + num_blocks=[d4], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2], + nonlocal_inds=cfg.NONLOCAL.LOCATION[2], + nonlocal_group=cfg.NONLOCAL.GROUP[2], + nonlocal_pool=cfg.NONLOCAL.POOL[2], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[2], + norm_module=self.norm_module, + ) + + # self.s5 = resnet_helper.ResStage( + # dim_in=[width_per_group * 16], + # dim_out=[width_per_group * 32], + # dim_inner=[dim_inner * 8], + # temp_kernel_sizes=temp_kernel[4], + # stride=cfg.RESNET.SPATIAL_STRIDES[3], + # num_blocks=[d5], + # num_groups=[num_groups], + # num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3], + # nonlocal_inds=cfg.NONLOCAL.LOCATION[3], + # nonlocal_group=cfg.NONLOCAL.GROUP[3], + # nonlocal_pool=cfg.NONLOCAL.POOL[3], + # instantiation=cfg.NONLOCAL.INSTANTIATION, + # trans_func_name=cfg.RESNET.TRANS_FUNC, + # stride_1x1=cfg.RESNET.STRIDE_1X1, + # inplace_relu=cfg.RESNET.INPLACE_RELU, + # dilation=cfg.RESNET.SPATIAL_DILATIONS[3], + # norm_module=self.norm_module, + # ) + self.labels=cfg.RESNET.LABELS + self.dual_define("t4",self.labels,LightDecoderBlock(width_per_group * 16,width_per_group * 8,width_per_group * 4)) + self.dual_define("t3",self.labels,LightDecoderBlock(width_per_group * 4,width_per_group * 4, 128)) + self.dual_define("conv1x1",self.labels,nn.Sequential( + nn.Conv3d(128+width_per_group, 64, kernel_size=(1, 1, 1), stride=1, padding=0), + nn.BatchNorm3d(64), + nn.ReLU(), + nn.Conv3d(64, 1, kernel_size=(1, 1, 1), stride=1, padding=0), + )) + + self.linear = nn.Linear(1, 2) + + def forward_plus(self, x, y, net): + return [net(x)[0] + y[0]] + + + def dual_define(self,name,labels,net): + for label in labels: + self.add_module(f"{name}_{label}",copy.deepcopy(net)) + + + def upsample(self, x, dims=["space"]): + ori_size = x[0].shape[2:5] + t, h, w = ori_size + if "space" in dims: + h = 2 * h + w = 2 * w + if "time" in dims: + t = 2 * t + size = (t, h, w) + return [F.interpolate(x[0], size)] + + def concat(self,x,y): + return [torch.cat([x[0],y[0]],1)] + + def get_detach_var(self,x): + return [t.detach() for t in x] + + # @torchsnooper.snoop() + def forward(self, x, freeze_backbone=False): + x = self.get_detach_var(x) + x1 = self.s1(x) # 1,64,8,56,56 + x2 = self.s2(x1) # 1,256,8,56,56 + x3 = self.s3(x2) # 1,512,8,28, 28 + feat= self.s4(x3) # 1,1024,8,14,14 + assert isinstance(freeze_backbone,bool) + if freeze_backbone: + feat=self.get_detach_var(feat) + x1=self.get_detach_var(x1) + x2=self.get_detach_var(x2) + x3=self.get_detach_var(x3) + + feat = self.upsample(feat) # 1,1024, 8, 28, 28 + feat = self.concat(x3,feat)# 1,1024+512, 8, 28, 28 + reg_out=[self.forward_branch(feat,x1,x2,label) for label in self.labels] + reg_out=torch.cat(reg_out,1) + reg_out=torch.sigmoid(reg_out) + class_out = reg_out.mean([3, 4]).view(-1, 1)*100 # 1,2,8,56,56 + class_out = self.linear(class_out) + class_out = class_out.view(reg_out.size(0),len(self.labels),-1) + class_out = class_out + return reg_out,class_out + + + def forward_branch(self,feat,x1,x2,label): + t4= getattr(self,f"t4_{label}") + feat = t4(feat[0])# 1,256, 8, 28, 28 + feat = self.upsample([feat]) # 1,256, 8, 56, 56 + feat = self.concat(x2,feat)# 1,256+256, 8, 56, 56 + t3= getattr(self,f"t3_{label}") + feat = t3(feat[0]) # 1,128, 8, 56, 56 + feat = self.concat(x1,[feat]) # 1,192, 8, 56, 56 + conv1x1=getattr(self,f"conv1x1_{label}") + feat = conv1x1(feat[0]) # 1,2,8,56,56 + return feat + + + + +@MODEL_REGISTRY.register() +class ResUNetCommon2(nn.Module): + """ + ResNet model builder. It builds a ResNet like network backbone without + lateral connection (C2D, I3D, Slow). + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "SlowFast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + + Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. + "Non-local neural networks." + https://arxiv.org/pdf/1711.07971.pdf + """ + + def __init__(self, cfg): + """ + The `__init__` method of any subclass should also contain these + arguments. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + super(ResUNetCommon2, self).__init__() + self.norm_module = get_norm(cfg) + self.enable_detection = cfg.DETECTION.ENABLE + self.enable_jitter = cfg.JITTER.ENABLE + self.num_pathways = 1 + assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE + self.image_size = cfg.DATA.TRAIN_CROP_SIZE + self.clip_size = cfg.DATA.NUM_FRAMES + self._construct_network(cfg) + init_helper.init_weights( + self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN + ) + + def _construct_network(self, cfg): + """ + Builds a single pathway ResNet model. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + assert cfg.MODEL.ARCH in _POOL1.keys() + pool_size = _POOL1[cfg.MODEL.ARCH] + self.cfg = cfg + assert len({len(pool_size), self.num_pathways}) == 1 + assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys() + + (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH] + + num_groups = cfg.RESNET.NUM_GROUPS + width_per_group = cfg.RESNET.WIDTH_PER_GROUP + dim_inner = num_groups * width_per_group + + temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH] + + self.s1 = stem_helper.VideoModelStem( + dim_in=cfg.DATA.INPUT_CHANNEL_NUM, + dim_out=[width_per_group], + kernel=[temp_kernel[0][0] + [7, 7]], + stride=[[1, 2, 2]], + padding=[[temp_kernel[0][0][0] // 2, 3, 3]], + norm_module=self.norm_module, + ) + + self.s2 = resnet_helper.ResStage( + dim_in=[width_per_group], + dim_out=[width_per_group * 4], + dim_inner=[dim_inner], + temp_kernel_sizes=temp_kernel[1], + stride=cfg.RESNET.SPATIAL_STRIDES[0], + num_blocks=[d2], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0], + nonlocal_inds=cfg.NONLOCAL.LOCATION[0], + nonlocal_group=cfg.NONLOCAL.GROUP[0], + nonlocal_pool=cfg.NONLOCAL.POOL[0], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[0], + norm_module=self.norm_module, + ) + + for pathway in range(self.num_pathways): + pool = nn.MaxPool3d( + kernel_size=pool_size[pathway], + stride=pool_size[pathway], + padding=[0, 0, 0], + ) + self.add_module("pathway{}_pool".format(pathway), pool) + + self.s3 = resnet_helper.ResStage( + dim_in=[width_per_group * 4], + dim_out=[width_per_group * 8], + dim_inner=[dim_inner * 2], + temp_kernel_sizes=temp_kernel[2], + stride=cfg.RESNET.SPATIAL_STRIDES[1], + num_blocks=[d3], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1], + nonlocal_inds=cfg.NONLOCAL.LOCATION[1], + nonlocal_group=cfg.NONLOCAL.GROUP[1], + nonlocal_pool=cfg.NONLOCAL.POOL[1], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[1], + norm_module=self.norm_module, + ) + + self.s4 = resnet_helper.ResStage( + dim_in=[width_per_group * 8], + dim_out=[width_per_group * 16], + dim_inner=[dim_inner * 4], + temp_kernel_sizes=temp_kernel[3], + stride=cfg.RESNET.SPATIAL_STRIDES[2], + num_blocks=[d4], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2], + nonlocal_inds=cfg.NONLOCAL.LOCATION[2], + nonlocal_group=cfg.NONLOCAL.GROUP[2], + nonlocal_pool=cfg.NONLOCAL.POOL[2], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[2], + norm_module=self.norm_module, + ) + + # self.s5 = resnet_helper.ResStage( + # dim_in=[width_per_group * 16], + # dim_out=[width_per_group * 32], + # dim_inner=[dim_inner * 8], + # temp_kernel_sizes=temp_kernel[4], + # stride=cfg.RESNET.SPATIAL_STRIDES[3], + # num_blocks=[d5], + # num_groups=[num_groups], + # num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3], + # nonlocal_inds=cfg.NONLOCAL.LOCATION[3], + # nonlocal_group=cfg.NONLOCAL.GROUP[3], + # nonlocal_pool=cfg.NONLOCAL.POOL[3], + # instantiation=cfg.NONLOCAL.INSTANTIATION, + # trans_func_name=cfg.RESNET.TRANS_FUNC, + # stride_1x1=cfg.RESNET.STRIDE_1X1, + # inplace_relu=cfg.RESNET.INPLACE_RELU, + # dilation=cfg.RESNET.SPATIAL_DILATIONS[3], + # norm_module=self.norm_module, + # ) + self.labels=cfg.RESNET.LABELS + self.dual_define("t4",self.labels,LightDecoderBlock(width_per_group * 16,width_per_group * 8,width_per_group * 4)) + self.dual_define("t3",self.labels,LightDecoderBlock(width_per_group * 4,width_per_group * 4, 128)) + self.dual_define("conv1x1",self.labels,nn.Sequential( + nn.Conv3d(128+width_per_group, 64, kernel_size=(1, 1, 1), stride=1, padding=0), + nn.BatchNorm3d(64), + nn.ReLU(), + nn.Conv3d(64, 1, kernel_size=(1, 1, 1), stride=1, padding=0), + )) + + self.linear = nn.Linear(1, 1) + + def forward_plus(self, x, y, net): + return [net(x)[0] + y[0]] + + + def dual_define(self,name,labels,net): + for label in labels: + self.add_module(f"{name}_{label}",copy.deepcopy(net)) + + + def upsample(self, x, dims=["space"]): + ori_size = x[0].shape[2:5] + t, h, w = ori_size + if "space" in dims: + h = 2 * h + w = 2 * w + if "time" in dims: + t = 2 * t + size = (t, h, w) + return [F.interpolate(x[0], size)] + + def concat(self,x,y): + return [torch.cat([x[0],y[0]],1)] + + def get_detach_var(self,x): + return [t.detach() for t in x] + + # @torchsnooper.snoop() + def forward(self, x, freeze_backbone=False): + x = self.get_detach_var(x) + x1 = self.s1(x) # 1,64,8,56,56 + x2 = self.s2(x1) # 1,256,8,56,56 + x3 = self.s3(x2) # 1,512,8,28, 28 + feat= self.s4(x3) # 1,1024,8,14,14 + assert isinstance(freeze_backbone,bool) + if freeze_backbone: + feat=self.get_detach_var(feat) + x1=self.get_detach_var(x1) + x2=self.get_detach_var(x2) + x3=self.get_detach_var(x3) + + feat = self.upsample(feat) # 1,1024, 8, 28, 28 + feat = self.concat(x3,feat)# 1,1024+512, 8, 28, 28 + reg_out=[self.forward_branch(feat,x1,x2,label) for label in self.labels] + reg_out=torch.cat(reg_out,1) + reg_out=torch.sigmoid(reg_out) + class_out = reg_out.mean([3, 4]).view(-1, 1)*100 # 1,2,8,56,56 + class_out = self.linear(class_out) + class_out = class_out.view(reg_out.size(0),len(self.labels),-1) + class_out = torch.sigmoid(class_out) + return reg_out,class_out + + + def forward_branch(self,feat,x1,x2,label): + t4= getattr(self,f"t4_{label}") + feat = t4(feat[0])# 1,256, 8, 28, 28 + feat = self.upsample([feat]) # 1,256, 8, 56, 56 + feat = self.concat(x2,feat)# 1,256+256, 8, 56, 56 + t3= getattr(self,f"t3_{label}") + feat = t3(feat[0]) # 1,128, 8, 56, 56 + feat = self.concat(x1,[feat]) # 1,192, 8, 56, 56 + conv1x1=getattr(self,f"conv1x1_{label}") + feat = conv1x1(feat[0]) # 1,2,8,56,56 + return feat + + + +@MODEL_REGISTRY.register() +class ResUNetStrong(nn.Module): + """ + ResNet model builder. It builds a ResNet like network backbone without + lateral connection (C2D, I3D, Slow). + + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "SlowFast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + + Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. + "Non-local neural networks." + https://arxiv.org/pdf/1711.07971.pdf + """ + + def __init__(self, cfg): + """ + The `__init__` method of any subclass should also contain these + arguments. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + super(ResUNetStrong, self).__init__() + self.norm_module = get_norm(cfg) + self.enable_detection = cfg.DETECTION.ENABLE + self.enable_jitter = cfg.JITTER.ENABLE + self.num_pathways = 1 + assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE + self.image_size = cfg.DATA.TRAIN_CROP_SIZE + self.clip_size = cfg.DATA.NUM_FRAMES + self._construct_network(cfg) + init_helper.init_weights( + self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN + ) + + def _construct_network(self, cfg): + """ + Builds a single pathway ResNet model. + + Args: + cfg (CfgNode): model building configs, details are in the + comments of the config file. + """ + assert cfg.MODEL.ARCH in _POOL1.keys() + pool_size = _POOL1[cfg.MODEL.ARCH] + self.cfg = cfg + assert len({len(pool_size), self.num_pathways}) == 1 + assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys() + + (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH] + + num_groups = cfg.RESNET.NUM_GROUPS + width_per_group = cfg.RESNET.WIDTH_PER_GROUP + dim_inner = num_groups * width_per_group + + temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH] + + self.s1 = stem_helper.VideoModelStem( + dim_in=cfg.DATA.INPUT_CHANNEL_NUM, + dim_out=[width_per_group], + kernel=[temp_kernel[0][0] + [7, 7]], + stride=[[1, 2, 2]], + padding=[[temp_kernel[0][0][0] // 2, 3, 3]], + norm_module=self.norm_module, + ) + + self.s2 = resnet_helper.ResStage( + dim_in=[width_per_group], + dim_out=[width_per_group * 4], + dim_inner=[dim_inner], + temp_kernel_sizes=temp_kernel[1], + stride=cfg.RESNET.SPATIAL_STRIDES[0], + num_blocks=[d2], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0], + nonlocal_inds=cfg.NONLOCAL.LOCATION[0], + nonlocal_group=cfg.NONLOCAL.GROUP[0], + nonlocal_pool=cfg.NONLOCAL.POOL[0], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[0], + norm_module=self.norm_module, + ) + + for pathway in range(self.num_pathways): + pool = nn.MaxPool3d( + kernel_size=pool_size[pathway], + stride=pool_size[pathway], + padding=[0, 0, 0], + ) + self.add_module("pathway{}_pool".format(pathway), pool) + + self.s3 = resnet_helper.ResStage( + dim_in=[width_per_group * 4], + dim_out=[width_per_group * 8], + dim_inner=[dim_inner * 2], + temp_kernel_sizes=temp_kernel[2], + stride=cfg.RESNET.SPATIAL_STRIDES[1], + num_blocks=[d3], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1], + nonlocal_inds=cfg.NONLOCAL.LOCATION[1], + nonlocal_group=cfg.NONLOCAL.GROUP[1], + nonlocal_pool=cfg.NONLOCAL.POOL[1], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[1], + norm_module=self.norm_module, + ) + + self.s4 = resnet_helper.ResStage( + dim_in=[width_per_group * 8], + dim_out=[width_per_group * 16], + dim_inner=[dim_inner * 4], + temp_kernel_sizes=temp_kernel[3], + stride=cfg.RESNET.SPATIAL_STRIDES[2], + num_blocks=[d4], + num_groups=[num_groups], + num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2], + nonlocal_inds=cfg.NONLOCAL.LOCATION[2], + nonlocal_group=cfg.NONLOCAL.GROUP[2], + nonlocal_pool=cfg.NONLOCAL.POOL[2], + instantiation=cfg.NONLOCAL.INSTANTIATION, + trans_func_name=cfg.RESNET.TRANS_FUNC, + stride_1x1=cfg.RESNET.STRIDE_1X1, + inplace_relu=cfg.RESNET.INPLACE_RELU, + dilation=cfg.RESNET.SPATIAL_DILATIONS[2], + norm_module=self.norm_module, + ) + + # self.s5 = resnet_helper.ResStage( + # dim_in=[width_per_group * 16], + # dim_out=[width_per_group * 32], + # dim_inner=[dim_inner * 8], + # temp_kernel_sizes=temp_kernel[4], + # stride=cfg.RESNET.SPATIAL_STRIDES[3], + # num_blocks=[d5], + # num_groups=[num_groups], + # num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3], + # nonlocal_inds=cfg.NONLOCAL.LOCATION[3], + # nonlocal_group=cfg.NONLOCAL.GROUP[3], + # nonlocal_pool=cfg.NONLOCAL.POOL[3], + # instantiation=cfg.NONLOCAL.INSTANTIATION, + # trans_func_name=cfg.RESNET.TRANS_FUNC, + # stride_1x1=cfg.RESNET.STRIDE_1X1, + # inplace_relu=cfg.RESNET.INPLACE_RELU, + # dilation=cfg.RESNET.SPATIAL_DILATIONS[3], + # norm_module=self.norm_module, + # ) + + self.labels=cfg.RESNET.LABELS + self.dual_define("t4",self.labels,ResDecoderBlock(width_per_group * 16,width_per_group * 8,width_per_group * 8)) + self.dual_define("t3",self.labels,ResDecoderBlock(width_per_group * 8,width_per_group * 4, 256)) + self.dual_define("conv1x1",self.labels,nn.Sequential( + nn.Conv3d(width_per_group*4+width_per_group, 128, kernel_size=(1, 1, 1), stride=1, padding=0), + nn.BatchNorm3d(128), + nn.ReLU(), + nn.Conv3d(128, 1, kernel_size=(1, 1, 1), stride=1, padding=0), + )) + + self.linear = nn.Linear(1, 1) + + def forward_plus(self, x, y, net): + return [net(x)[0] + y[0]] + + + def dual_define(self,name,labels,net): + for label in labels: + self.add_module(f"{name}_{label}",copy.deepcopy(net)) + + + def upsample(self, x, dims=["space"]): + ori_size = x[0].shape[2:5] + t, h, w = ori_size + if "space" in dims: + h = 2 * h + w = 2 * w + if "time" in dims: + t = 2 * t + size = (t, h, w) + return [F.interpolate(x[0], size)] + + def concat(self,x,y): + return [torch.cat([x[0],y[0]],1)] + + def get_detach_var(self,x): + return [t.detach() for t in x] + + # @torchsnooper.snoop() + def forward(self, x, freeze_backbone=False): + x = self.get_detach_var(x) + x1 = self.s1(x) # 1,64,8,56,56 + x2 = self.s2(x1) # 1,256,8,56,56 + x3 = self.s3(x2) # 1,512,8,28, 28 + feat= self.s4(x3) # 1,1024,8,14,14 + assert isinstance(freeze_backbone,bool) + if freeze_backbone: + feat=self.get_detach_var(feat) + x1=self.get_detach_var(x1) + x2=self.get_detach_var(x2) + x3=self.get_detach_var(x3) + + feat = self.upsample(feat) # 1,1024, 8, 28, 28 + feat = self.concat(x3,feat)# 1,1024+512, 8, 28, 28 + reg_out=[self.forward_branch(feat,x1,x2,label) for label in self.labels] + reg_out=torch.cat(reg_out,1) + reg_out=torch.sigmoid(reg_out) + class_out = reg_out.mean([3, 4]).view(-1, 1)*100 # 1,2,8,56,56 + class_out = self.linear(class_out) + class_out = class_out.view(reg_out.size(0),len(self.labels),-1) + class_out = torch.sigmoid(class_out) + return reg_out,class_out + + + def forward_branch(self,feat,x1,x2,label): + t4= getattr(self,f"t4_{label}") + feat = t4(feat[0])# 1,256, 8, 28, 28 + feat = self.upsample([feat]) # 1,256, 8, 56, 56 + feat = self.concat(x2,feat)# 1,256+256, 8, 56, 56 + t3= getattr(self,f"t3_{label}") + feat = t3(feat[0]) # 1,128, 8, 56, 56 + feat = self.concat(x1,[feat]) # 1,192, 8, 56, 56 + conv1x1=getattr(self,f"conv1x1_{label}") + feat = conv1x1(feat[0]) # 1,2,8,56,56 + return feat diff --git a/training/detectors/utils/slowfast/utils/__init__.py b/training/detectors/utils/slowfast/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dbe96a785072a24a9bcc4841a1934024f2b06a1 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. diff --git a/training/detectors/utils/slowfast/utils/ava_eval_helper.py b/training/detectors/utils/slowfast/utils/ava_eval_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..9e8ba5468077053a4dcf1920f25256a1a241f0b9 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/ava_eval_helper.py @@ -0,0 +1,302 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## +# +# Based on: +# -------------------------------------------------------- +# ActivityNet +# Copyright (c) 2015 ActivityNet +# Licensed under The MIT License +# [see https://github.com/activitynet/ActivityNet/blob/master/LICENSE for details] +# -------------------------------------------------------- + +"""Helper functions for AVA evaluation.""" + +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) +import csv +import logging +import numpy as np +import pprint +import time +from collections import defaultdict +from fvcore.common.file_io import PathManager + +from slowfast.utils.ava_evaluation import ( + object_detection_evaluation, + standard_fields, +) + +logger = logging.getLogger(__name__) + + +def make_image_key(video_id, timestamp): + """Returns a unique identifier for a video id & timestamp.""" + return "%s,%04d" % (video_id, int(timestamp)) + + +def read_csv(csv_file, class_whitelist=None, load_score=False): + """Loads boxes and class labels from a CSV file in the AVA format. + CSV file format described at https://research.google.com/ava/download.html. + Args: + csv_file: A file object. + class_whitelist: If provided, boxes corresponding to (integer) class labels + not in this set are skipped. + Returns: + boxes: A dictionary mapping each unique image key (string) to a list of + boxes, given as coordinates [y1, x1, y2, x2]. + labels: A dictionary mapping each unique image key (string) to a list of + integer class lables, matching the corresponding box in `boxes`. + scores: A dictionary mapping each unique image key (string) to a list of + score values lables, matching the corresponding label in `labels`. If + scores are not provided in the csv, then they will default to 1.0. + """ + boxes = defaultdict(list) + labels = defaultdict(list) + scores = defaultdict(list) + with PathManager.open(csv_file, "r") as f: + reader = csv.reader(f) + for row in reader: + assert len(row) in [7, 8], "Wrong number of columns: " + row + image_key = make_image_key(row[0], row[1]) + x1, y1, x2, y2 = [float(n) for n in row[2:6]] + action_id = int(row[6]) + if class_whitelist and action_id not in class_whitelist: + continue + score = 1.0 + if load_score: + score = float(row[7]) + boxes[image_key].append([y1, x1, y2, x2]) + labels[image_key].append(action_id) + scores[image_key].append(score) + return boxes, labels, scores + + +def read_exclusions(exclusions_file): + """Reads a CSV file of excluded timestamps. + Args: + exclusions_file: A file object containing a csv of video-id,timestamp. + Returns: + A set of strings containing excluded image keys, e.g. "aaaaaaaaaaa,0904", + or an empty set if exclusions file is None. + """ + excluded = set() + if exclusions_file: + with PathManager.open(exclusions_file, "r") as f: + reader = csv.reader(f) + for row in reader: + assert len(row) == 2, "Expected only 2 columns, got: " + row + excluded.add(make_image_key(row[0], row[1])) + return excluded + + +def read_labelmap(labelmap_file): + """Read label map and class ids.""" + + labelmap = [] + class_ids = set() + name = "" + class_id = "" + with PathManager.open(labelmap_file, "r") as f: + for line in f: + if line.startswith(" name:"): + name = line.split('"')[1] + elif line.startswith(" id:") or line.startswith(" label_id:"): + class_id = int(line.strip().split(" ")[-1]) + labelmap.append({"id": class_id, "name": name}) + class_ids.add(class_id) + return labelmap, class_ids + + +def evaluate_ava_from_files(labelmap, groundtruth, detections, exclusions): + """Run AVA evaluation given annotation/prediction files.""" + + categories, class_whitelist = read_labelmap(labelmap) + excluded_keys = read_exclusions(exclusions) + groundtruth = read_csv(groundtruth, class_whitelist, load_score=False) + detections = read_csv(detections, class_whitelist, load_score=True) + run_evaluation(categories, groundtruth, detections, excluded_keys) + + +def evaluate_ava( + preds, + original_boxes, + metadata, + excluded_keys, + class_whitelist, + categories, + groundtruth=None, + video_idx_to_name=None, + name="latest", +): + """Run AVA evaluation given numpy arrays.""" + + eval_start = time.time() + + detections = get_ava_eval_data( + preds, + original_boxes, + metadata, + class_whitelist, + video_idx_to_name=video_idx_to_name, + ) + + logger.info("Evaluating with %d unique GT frames." % len(groundtruth[0])) + logger.info( + "Evaluating with %d unique detection frames" % len(detections[0]) + ) + + write_results(detections, "detections_%s.csv" % name) + write_results(groundtruth, "groundtruth_%s.csv" % name) + + results = run_evaluation(categories, groundtruth, detections, excluded_keys) + + logger.info("AVA eval done in %f seconds." % (time.time() - eval_start)) + return results["PascalBoxes_Precision/mAP@0.5IOU"] + + +def run_evaluation( + categories, groundtruth, detections, excluded_keys, verbose=True +): + """AVA evaluation main logic.""" + + pascal_evaluator = object_detection_evaluation.PascalDetectionEvaluator( + categories + ) + + boxes, labels, _ = groundtruth + + gt_keys = [] + pred_keys = [] + + for image_key in boxes: + if image_key in excluded_keys: + logging.info( + ( + "Found excluded timestamp in ground truth: %s. " + "It will be ignored." + ), + image_key, + ) + continue + pascal_evaluator.add_single_ground_truth_image_info( + image_key, + { + standard_fields.InputDataFields.groundtruth_boxes: np.array( + boxes[image_key], dtype=float + ), + standard_fields.InputDataFields.groundtruth_classes: np.array( + labels[image_key], dtype=int + ), + standard_fields.InputDataFields.groundtruth_difficult: np.zeros( + len(boxes[image_key]), dtype=bool + ), + }, + ) + + gt_keys.append(image_key) + + boxes, labels, scores = detections + + for image_key in boxes: + if image_key in excluded_keys: + logging.info( + ( + "Found excluded timestamp in detections: %s. " + "It will be ignored." + ), + image_key, + ) + continue + pascal_evaluator.add_single_detected_image_info( + image_key, + { + standard_fields.DetectionResultFields.detection_boxes: np.array( + boxes[image_key], dtype=float + ), + standard_fields.DetectionResultFields.detection_classes: np.array( + labels[image_key], dtype=int + ), + standard_fields.DetectionResultFields.detection_scores: np.array( + scores[image_key], dtype=float + ), + }, + ) + + pred_keys.append(image_key) + + metrics = pascal_evaluator.evaluate() + + pprint.pprint(metrics, indent=2) + return metrics + + +def get_ava_eval_data( + scores, + boxes, + metadata, + class_whitelist, + verbose=False, + video_idx_to_name=None, +): + """ + Convert our data format into the data format used in official AVA + evaluation. + """ + + out_scores = defaultdict(list) + out_labels = defaultdict(list) + out_boxes = defaultdict(list) + count = 0 + for i in range(scores.shape[0]): + video_idx = int(np.round(metadata[i][0])) + sec = int(np.round(metadata[i][1])) + + video = video_idx_to_name[video_idx] + + key = video + "," + "%04d" % (sec) + batch_box = boxes[i].tolist() + # The first is batch idx. + batch_box = [batch_box[j] for j in [0, 2, 1, 4, 3]] + + one_scores = scores[i].tolist() + for cls_idx, score in enumerate(one_scores): + if cls_idx + 1 in class_whitelist: + out_scores[key].append(score) + out_labels[key].append(cls_idx + 1) + out_boxes[key].append(batch_box[1:]) + count += 1 + + return out_boxes, out_labels, out_scores + + +def write_results(detections, filename): + """Write prediction results into official formats.""" + start = time.time() + + boxes, labels, scores = detections + with PathManager.open(filename, "w") as f: + for key in boxes.keys(): + for box, label, score in zip(boxes[key], labels[key], scores[key]): + f.write( + "%s,%.03f,%.03f,%.03f,%.03f,%d,%.04f\n" + % (key, box[1], box[0], box[3], box[2], label, score) + ) + + logger.info("AVA results wrote to %s" % filename) + logger.info("\ttook %d seconds." % (time.time() - start)) diff --git a/training/detectors/utils/slowfast/utils/benchmark.py b/training/detectors/utils/slowfast/utils/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..33e5fe9073ad61ecec737d6b4a6a2880eec15cb9 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/benchmark.py @@ -0,0 +1,103 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Functions for benchmarks. +""" + +import numpy as np +import pprint +import torch +import tqdm +from fvcore.common.timer import Timer + +import slowfast.utils.logging as logging +import slowfast.utils.misc as misc +from slowfast.datasets import loader +from slowfast.utils.env import setup_environment + +logger = logging.get_logger(__name__) + + +def benchmark_data_loading(cfg): + """ + Benchmark the speed of data loading in PySlowFast. + Args: + + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + """ + # Set up environment. + setup_environment() + # Set random seed from configs. + np.random.seed(cfg.RNG_SEED) + torch.manual_seed(cfg.RNG_SEED) + + # Setup logging format. + logging.setup_logging(cfg.OUTPUT_DIR) + + # Print config. + logger.info("Benchmark data loading with config:") + logger.info(pprint.pformat(cfg)) + + timer = Timer() + dataloader = loader.construct_loader(cfg, "train") + logger.info( + "Initialize loader using {:.2f} seconds.".format(timer.seconds()) + ) + # Total batch size across different machines. + batch_size = cfg.TRAIN.BATCH_SIZE * cfg.NUM_SHARDS + log_period = cfg.BENCHMARK.LOG_PERIOD + epoch_times = [] + # Test for a few epochs. + for cur_epoch in range(cfg.BENCHMARK.NUM_EPOCHS): + timer = Timer() + timer_epoch = Timer() + iter_times = [] + if cfg.BENCHMARK.SHUFFLE: + loader.shuffle_dataset(dataloader, cur_epoch) + for cur_iter, _ in enumerate(tqdm.tqdm(dataloader)): + if cur_iter > 0 and cur_iter % log_period == 0: + iter_times.append(timer.seconds()) + ram_usage, ram_total = misc.cpu_mem_usage() + logger.info( + "Epoch {}: {} iters ({} videos) in {:.2f} seconds. " + "RAM Usage: {:.2f}/{:.2f} GB.".format( + cur_epoch, + log_period, + log_period * batch_size, + iter_times[-1], + ram_usage, + ram_total, + ) + ) + timer.reset() + epoch_times.append(timer_epoch.seconds()) + ram_usage, ram_total = misc.cpu_mem_usage() + logger.info( + "Epoch {}: in total {} iters ({} videos) in {:.2f} seconds. " + "RAM Usage: {:.2f}/{:.2f} GB.".format( + cur_epoch, + len(dataloader), + len(dataloader) * batch_size, + epoch_times[-1], + ram_usage, + ram_total, + ) + ) + logger.info( + "Epoch {}: on average every {} iters ({} videos) take {:.2f}/{:.2f} " + "(avg/std) seconds.".format( + cur_epoch, + log_period, + log_period * batch_size, + np.mean(iter_times), + np.std(iter_times), + ) + ) + logger.info( + "On average every epoch ({} videos) takes {:.2f}/{:.2f} " + "(avg/std) seconds.".format( + len(dataloader) * batch_size, + np.mean(epoch_times), + np.std(epoch_times), + ) + ) diff --git a/training/detectors/utils/slowfast/utils/bn_helper.py b/training/detectors/utils/slowfast/utils/bn_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..b18d8c76c10d7598db61ba8ca192314140f9ba79 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/bn_helper.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""bn helper.""" + +import itertools +import torch + + +@torch.no_grad() +def compute_and_update_bn_stats(model, data_loader, num_batches=200): + """ + Compute and update the batch norm stats to make it more precise. During + training both bn stats and the weight are changing after every iteration, + so the bn can not precisely reflect the latest stats of the current model. + Here the bn stats is recomputed without change of weights, to make the + running mean and running var more precise. + Args: + model (model): the model using to compute and update the bn stats. + data_loader (dataloader): dataloader using to provide inputs. + num_batches (int): running iterations using to compute the stats. + """ + + # Prepares all the bn layers. + bn_layers = [ + m + for m in model.modules() + if any( + ( + isinstance(m, bn_type) + for bn_type in ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ) + ) + ) + ] + + # In order to make the running stats only reflect the current batch, the + # momentum is disabled. + # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean + # Setting the momentum to 1.0 to compute the stats without momentum. + momentum_actual = [bn.momentum for bn in bn_layers] + for bn in bn_layers: + bn.momentum = 1.0 + + # Calculates the running iterations for precise stats computation. + running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers] + running_square_mean = [torch.zeros_like(bn.running_var) for bn in bn_layers] + + for ind, (inputs, _, _) in enumerate( + itertools.islice(data_loader, num_batches) + ): + # Forwards the model to update the bn stats. + if isinstance(inputs, (list,)): + for i in range(len(inputs)): + inputs[i] = inputs[i].float().cuda(non_blocking=True) + else: + inputs = inputs.cuda(non_blocking=True) + model(inputs) + + for i, bn in enumerate(bn_layers): + # Accumulates the bn stats. + running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1) + # $E(x^2) = Var(x) + E(x)^2$. + cur_square_mean = bn.running_var + bn.running_mean ** 2 + running_square_mean[i] += ( + cur_square_mean - running_square_mean[i] + ) / (ind + 1) + + for i, bn in enumerate(bn_layers): + bn.running_mean = running_mean[i] + # Var(x) = $E(x^2) - E(x)^2$. + bn.running_var = running_square_mean[i] - bn.running_mean ** 2 + # Sets the precise bn stats. + bn.momentum = momentum_actual[i] diff --git a/training/detectors/utils/slowfast/utils/c2_model_loading.py b/training/detectors/utils/slowfast/utils/c2_model_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcc0759c484fd321917c55e9967835632c2ac54 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/c2_model_loading.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Caffe2 to PyTorch checkpoint name converting utility.""" + +import re + + +def get_name_convert_func(): + """ + Get the function to convert Caffe2 layer names to PyTorch layer names. + Returns: + (func): function to convert parameter name from Caffe2 format to PyTorch + format. + """ + pairs = [ + # ------------------------------------------------------------ + # 'nonlocal_conv3_1_theta_w' -> 's3.pathway0_nonlocal3.conv_g.weight' + [ + r"^nonlocal_conv([0-9]+)_([0-9]+)_(.*)", + r"s\1.pathway0_nonlocal\2_\3", + ], + # 'theta' -> 'conv_theta' + [r"^(.*)_nonlocal([0-9]+)_(theta)(.*)", r"\1_nonlocal\2.conv_\3\4"], + # 'g' -> 'conv_g' + [r"^(.*)_nonlocal([0-9]+)_(g)(.*)", r"\1_nonlocal\2.conv_\3\4"], + # 'phi' -> 'conv_phi' + [r"^(.*)_nonlocal([0-9]+)_(phi)(.*)", r"\1_nonlocal\2.conv_\3\4"], + # 'out' -> 'conv_out' + [r"^(.*)_nonlocal([0-9]+)_(out)(.*)", r"\1_nonlocal\2.conv_\3\4"], + # 'nonlocal_conv4_5_bn_s' -> 's4.pathway0_nonlocal3.bn.weight' + [r"^(.*)_nonlocal([0-9]+)_(bn)_(.*)", r"\1_nonlocal\2.\3.\4"], + # ------------------------------------------------------------ + # 't_pool1_subsample_bn' -> 's1_fuse.conv_f2s.bn.running_mean' + [r"^t_pool1_subsample_bn_(.*)", r"s1_fuse.bn.\1"], + # 't_pool1_subsample' -> 's1_fuse.conv_f2s' + [r"^t_pool1_subsample_(.*)", r"s1_fuse.conv_f2s.\1"], + # 't_res4_5_branch2c_bn_subsample_bn_rm' -> 's4_fuse.conv_f2s.bias' + [ + r"^t_res([0-9]+)_([0-9]+)_branch2c_bn_subsample_bn_(.*)", + r"s\1_fuse.bn.\3", + ], + # 't_pool1_subsample' -> 's1_fuse.conv_f2s' + [ + r"^t_res([0-9]+)_([0-9]+)_branch2c_bn_subsample_(.*)", + r"s\1_fuse.conv_f2s.\3", + ], + # ------------------------------------------------------------ + # 'res4_4_branch_2c_bn_b' -> 's4.pathway0_res4.branch2.c_bn_b' + [ + r"^res([0-9]+)_([0-9]+)_branch([0-9]+)([a-z])_(.*)", + r"s\1.pathway0_res\2.branch\3.\4_\5", + ], + # 'res_conv1_bn_' -> 's1.pathway0_stem.bn.' + [r"^res_conv1_bn_(.*)", r"s1.pathway0_stem.bn.\1"], + # 'conv1_w_momentum' -> 's1.pathway0_stem.conv.' + [r"^conv1_(.*)", r"s1.pathway0_stem.conv.\1"], + # 'res4_0_branch1_w' -> 'S4.pathway0_res0.branch1.weight' + [ + r"^res([0-9]+)_([0-9]+)_branch([0-9]+)_(.*)", + r"s\1.pathway0_res\2.branch\3_\4", + ], + # 'res_conv1_' -> 's1.pathway0_stem.conv.' + [r"^res_conv1_(.*)", r"s1.pathway0_stem.conv.\1"], + # ------------------------------------------------------------ + # 'res4_4_branch_2c_bn_b' -> 's4.pathway0_res4.branch2.c_bn_b' + [ + r"^t_res([0-9]+)_([0-9]+)_branch([0-9]+)([a-z])_(.*)", + r"s\1.pathway1_res\2.branch\3.\4_\5", + ], + # 'res_conv1_bn_' -> 's1.pathway0_stem.bn.' + [r"^t_res_conv1_bn_(.*)", r"s1.pathway1_stem.bn.\1"], + # 'conv1_w_momentum' -> 's1.pathway0_stem.conv.' + [r"^t_conv1_(.*)", r"s1.pathway1_stem.conv.\1"], + # 'res4_0_branch1_w' -> 'S4.pathway0_res0.branch1.weight' + [ + r"^t_res([0-9]+)_([0-9]+)_branch([0-9]+)_(.*)", + r"s\1.pathway1_res\2.branch\3_\4", + ], + # 'res_conv1_' -> 's1.pathway0_stem.conv.' + [r"^t_res_conv1_(.*)", r"s1.pathway1_stem.conv.\1"], + # ------------------------------------------------------------ + # pred_ -> head.projection. + [r"pred_(.*)", r"head.projection.\1"], + # '.bn_b' -> '.weight' + [r"(.*)bn.b\Z", r"\1bn.bias"], + # '.bn_s' -> '.weight' + [r"(.*)bn.s\Z", r"\1bn.weight"], + # '_bn_rm' -> '.running_mean' + [r"(.*)bn.rm\Z", r"\1bn.running_mean"], + # '_bn_riv' -> '.running_var' + [r"(.*)bn.riv\Z", r"\1bn.running_var"], + # '_b' -> '.bias' + [r"(.*)[\._]b\Z", r"\1.bias"], + # '_w' -> '.weight' + [r"(.*)[\._]w\Z", r"\1.weight"], + ] + + def convert_caffe2_name_to_pytorch(caffe2_layer_name): + """ + Convert the caffe2_layer_name to pytorch format by apply the list of + regular expressions. + Args: + caffe2_layer_name (str): caffe2 layer name. + Returns: + (str): pytorch layer name. + """ + for source, dest in pairs: + caffe2_layer_name = re.sub(source, dest, caffe2_layer_name) + return caffe2_layer_name + + return convert_caffe2_name_to_pytorch diff --git a/training/detectors/utils/slowfast/utils/checkpoint.py b/training/detectors/utils/slowfast/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..05d5ac4624feecbbc1f868682f8ed921ac5fef2c --- /dev/null +++ b/training/detectors/utils/slowfast/utils/checkpoint.py @@ -0,0 +1,530 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Functions that handle saving and loading of checkpoints.""" + +import copy +import numpy as np +import os +import pickle +from collections import OrderedDict +import torch +from fvcore.common.file_io import PathManager + +import slowfast.utils.distributed as du +import slowfast.utils.logging as logging +from slowfast.utils.c2_model_loading import get_name_convert_func + +logger = logging.get_logger(__name__) + + +def make_checkpoint_dir(path_to_job): + """ + Creates the checkpoint directory (if not present already). + Args: + path_to_job (string): the path to the folder of the current job. + """ + checkpoint_dir = os.path.join(path_to_job, "checkpoints") + # Create the checkpoint dir from the master process + if du.is_master_proc() and not PathManager.exists(checkpoint_dir): + try: + PathManager.mkdirs(checkpoint_dir) + except Exception: + pass + return checkpoint_dir + + +def get_checkpoint_dir(path_to_job): + """ + Get path for storing checkpoints. + Args: + path_to_job (string): the path to the folder of the current job. + """ + return os.path.join(path_to_job, "checkpoints") + + +def get_path_to_checkpoint(path_to_job, epoch): + """ + Get the full path to a checkpoint file. + Args: + path_to_job (string): the path to the folder of the current job. + epoch (int): the number of epoch for the checkpoint. + """ + name = "checkpoint_epoch_{:07d}.pyth".format(epoch) + return os.path.join(get_checkpoint_dir(path_to_job), name) + + +def get_last_checkpoint(path_to_job): + """ + Get the last checkpoint from the checkpointing folder. + Args: + path_to_job (string): the path to the folder of the current job. + """ + + d = get_checkpoint_dir(path_to_job) + names = PathManager.ls(d) if PathManager.exists(d) else [] + names = [f for f in names if "checkpoint" in f] + assert len(names), "No checkpoints found in '{}'.".format(d) + # Sort the checkpoints by epoch. + name = sorted(names)[-1] + return os.path.join(d, name) + + +def has_checkpoint(path_to_job): + """ + Determines if the given directory contains a checkpoint. + Args: + path_to_job (string): the path to the folder of the current job. + """ + d = get_checkpoint_dir(path_to_job) + files = PathManager.ls(d) if PathManager.exists(d) else [] + return any("checkpoint" in f for f in files) + + +def is_checkpoint_epoch(cfg, cur_epoch, multigrid_schedule=None): + """ + Determine if a checkpoint should be saved on current epoch. + Args: + cfg (CfgNode): configs to save. + cur_epoch (int): current number of epoch of the model. + multigrid_schedule (List): schedule for multigrid training. + """ + if cur_epoch + 1 == cfg.SOLVER.MAX_EPOCH: + return True + if multigrid_schedule is not None: + prev_epoch = 0 + for s in multigrid_schedule: + if cur_epoch < s[-1]: + period = max( + (s[-1] - prev_epoch) // cfg.MULTIGRID.EVAL_FREQ + 1, 1 + ) + return (s[-1] - 1 - cur_epoch) % period == 0 + prev_epoch = s[-1] + + return (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0 + + +def is_checkpoint_iter(cfg, cur_iter): + """ + Determine if a checkpoint should be saved on current iter. + Args: + cfg (CfgNode): configs to save. + cur_epoch (int): current number of epoch of the model. + multigrid_schedule (List): schedule for multigrid training. + """ + + return (cur_iter+1) % cfg.TRAIN.CHECKPOINT_PERIOD_BY_ITER == 0 + + + +def save_checkpoint_by_iter(path_to_job, model, optimizer, epoch,global_step,cfg): + """ + Save a checkpoint. + Args: + model (model): model to save the weight to the checkpoint. + optimizer (optim): optimizer to save the historical state. + epoch (int): current number of epoch of the model. + cfg (CfgNode): configs to save. + """ + # Save checkpoints only from the master process. + if not du.is_master_proc(cfg.NUM_GPUS * cfg.NUM_SHARDS): + return + # Ensure that the checkpoint dir exists. + PathManager.mkdirs(get_checkpoint_dir(path_to_job)) + # Omit the DDP wrapper in the multi-gpu setting. + sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict() + normalized_sd = sub_to_normal_bn(sd) + + # Record the state. + checkpoint = { + "epoch": epoch, + "model_state": normalized_sd, + "optimizer_state": optimizer.state_dict(), + "global_step": global_step, + "cfg": cfg.dump(), + } + # Write the checkpoint. + path_to_checkpoint = get_path_to_checkpoint(path_to_job,global_step+1) + with PathManager.open(path_to_checkpoint, "wb") as f: + torch.save(checkpoint, f) + return path_to_checkpoint + +def save_checkpoint(path_to_job, model, optimizer, epoch, cfg): + """ + Save a checkpoint. + Args: + model (model): model to save the weight to the checkpoint. + optimizer (optim): optimizer to save the historical state. + epoch (int): current number of epoch of the model. + cfg (CfgNode): configs to save. + """ + # Save checkpoints only from the master process. + if not du.is_master_proc(cfg.NUM_GPUS * cfg.NUM_SHARDS): + return + # Ensure that the checkpoint dir exists. + PathManager.mkdirs(get_checkpoint_dir(path_to_job)) + # Omit the DDP wrapper in the multi-gpu setting. + sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict() + normalized_sd = sub_to_normal_bn(sd) + + # Record the state. + checkpoint = { + "epoch": epoch, + "model_state": normalized_sd, + "optimizer_state": optimizer.state_dict(), + "cfg": cfg.dump(), + } + # Write the checkpoint. + path_to_checkpoint = get_path_to_checkpoint(path_to_job, epoch + 1) + with PathManager.open(path_to_checkpoint, "wb") as f: + torch.save(checkpoint, f) + return path_to_checkpoint + + +def inflate_weight(state_dict_2d, state_dict_3d): + """ + Inflate 2D model weights in state_dict_2d to the 3D model weights in + state_dict_3d. The details can be found in: + Joao Carreira, and Andrew Zisserman. + "Quo vadis, action recognition? a new model and the kinetics dataset." + Args: + state_dict_2d (OrderedDict): a dict of parameters from a 2D model. + state_dict_3d (OrderedDict): a dict of parameters from a 3D model. + Returns: + state_dict_inflated (OrderedDict): a dict of inflated parameters. + """ + state_dict_inflated = OrderedDict() + for k, v2d in state_dict_2d.items(): + assert k in state_dict_3d.keys() + v3d = state_dict_3d[k] + # Inflate the weight of 2D conv to 3D conv. + if len(v2d.shape) == 4 and len(v3d.shape) == 5: + logger.info( + "Inflate {}: {} -> {}: {}".format(k, v2d.shape, k, v3d.shape) + ) + # Dimension need to be match. + assert v2d.shape[-2:] == v3d.shape[-2:] + assert v2d.shape[:2] == v3d.shape[:2] + v3d = ( + v2d.unsqueeze(2).repeat(1, 1, v3d.shape[2], 1, 1) / v3d.shape[2] + ) + elif v2d.shape == v3d.shape: + v3d = v2d + else: + logger.info( + "Unexpected {}: {} -|> {}: {}".format( + k, v2d.shape, k, v3d.shape + ) + ) + state_dict_inflated[k] = v3d.clone() + return state_dict_inflated + + +def load_checkpoint( + path_to_checkpoint, + model, + data_parallel=True, + optimizer=None, + inflation=False, + convert_from_caffe2=False, +): + """ + Load the checkpoint from the given file. If inflation is True, inflate the + 2D Conv weights from the checkpoint to 3D Conv. + Args: + path_to_checkpoint (string): path to the checkpoint to load. + model (model): model to load the weights from the checkpoint. + data_parallel (bool): if true, model is wrapped by + torch.nn.parallel.DistributedDataParallel. + optimizer (optim): optimizer to load the historical state. + inflation (bool): if True, inflate the weights from the checkpoint. + convert_from_caffe2 (bool): if True, load the model from caffe2 and + convert it to pytorch. + Returns: + (int): the number of training epoch of the checkpoint. + """ + assert PathManager.exists( + path_to_checkpoint + ), "Checkpoint '{}' not found".format(path_to_checkpoint) + # Account for the DDP wrapper in the multi-gpu setting. + ms = model.module if data_parallel else model + if convert_from_caffe2: + with PathManager.open(path_to_checkpoint, "rb") as f: + caffe2_checkpoint = pickle.load(f, encoding="latin1") + state_dict = OrderedDict() + name_convert_func = get_name_convert_func() + for key in caffe2_checkpoint["blobs"].keys(): + converted_key = name_convert_func(key) + converted_key = c2_normal_to_sub_bn(converted_key, ms.state_dict()) + if converted_key in ms.state_dict(): + c2_blob_shape = caffe2_checkpoint["blobs"][key].shape + model_blob_shape = ms.state_dict()[converted_key].shape + # Load BN stats to Sub-BN. + if ( + len(model_blob_shape) == 1 + and len(c2_blob_shape) == 1 + and model_blob_shape[0] > c2_blob_shape[0] + and model_blob_shape[0] % c2_blob_shape[0] == 0 + ): + caffe2_checkpoint["blobs"][key] = np.concatenate( + [caffe2_checkpoint["blobs"][key]] + * (model_blob_shape[0] // c2_blob_shape[0]) + ) + c2_blob_shape = caffe2_checkpoint["blobs"][key].shape + + if c2_blob_shape == tuple(model_blob_shape): + state_dict[converted_key] = torch.tensor( + caffe2_checkpoint["blobs"][key] + ).clone() + logger.info( + "{}: {} => {}: {}".format( + key, + c2_blob_shape, + converted_key, + tuple(model_blob_shape), + ) + ) + else: + logger.warn( + "!! {}: {} does not match {}: {}".format( + key, + c2_blob_shape, + converted_key, + tuple(model_blob_shape), + ) + ) + else: + if not any( + prefix in key for prefix in ["momentum", "lr", "model_iter"] + ): + logger.warn( + "!! {}: can not be converted, got {}".format( + key, converted_key + ) + ) + ms.load_state_dict(state_dict, strict=False) + epoch = -1 + global_step=-1 + else: + # Load the checkpoint on CPU to avoid GPU mem spike. + with PathManager.open(path_to_checkpoint, "rb") as f: + checkpoint = torch.load(f, map_location="cpu") + model_state_dict_3d = ( + model.module.state_dict() if data_parallel else model.state_dict() + ) + checkpoint["model_state"] = normal_to_sub_bn( + checkpoint["model_state"], model_state_dict_3d + ) + if inflation: + # Try to inflate the model. + inflated_model_dict = inflate_weight( + checkpoint["model_state"], model_state_dict_3d + ) + ms.load_state_dict(inflated_model_dict, strict=False) + else: + ms.load_state_dict(checkpoint["model_state"]) + # Load the optimizer state (commonly not done when fine-tuning) + if optimizer: + optimizer.load_state_dict(checkpoint["optimizer_state"]) + if "epoch" in checkpoint.keys(): + epoch = checkpoint["epoch"] + else: + epoch = -1 + if "global_step" in checkpoint.keys(): + global_step=checkpoint["global_step"] + else: + global_step=-1 + return epoch,global_step + + +def sub_to_normal_bn(sd): + """ + Convert the Sub-BN paprameters to normal BN parameters in a state dict. + There are two copies of BN layers in a Sub-BN implementation: `bn.bn` and + `bn.split_bn`. `bn.split_bn` is used during training and + "compute_precise_bn". Before saving or evaluation, its stats are copied to + `bn.bn`. We rename `bn.bn` to `bn` and store it to be consistent with normal + BN layers. + Args: + sd (OrderedDict): a dict of parameters whitch might contain Sub-BN + parameters. + Returns: + new_sd (OrderedDict): a dict with Sub-BN parameters reshaped to + normal parameters. + """ + new_sd = copy.deepcopy(sd) + modifications = [ + ("bn.bn.running_mean", "bn.running_mean"), + ("bn.bn.running_var", "bn.running_var"), + ("bn.split_bn.num_batches_tracked", "bn.num_batches_tracked"), + ] + to_remove = ["bn.bn.", ".split_bn."] + for key in sd: + for before, after in modifications: + if key.endswith(before): + new_key = key.split(before)[0] + after + new_sd[new_key] = new_sd.pop(key) + + for rm in to_remove: + if rm in key and key in new_sd: + del new_sd[key] + + for key in new_sd: + if key.endswith("bn.weight") or key.endswith("bn.bias"): + if len(new_sd[key].size()) == 4: + assert all(d == 1 for d in new_sd[key].size()[1:]) + new_sd[key] = new_sd[key][:, 0, 0, 0] + + return new_sd + + +def c2_normal_to_sub_bn(key, model_keys): + """ + Convert BN parameters to Sub-BN parameters if model contains Sub-BNs. + Args: + key (OrderedDict): source dict of parameters. + mdoel_key (OrderedDict): target dict of parameters. + Returns: + new_sd (OrderedDict): converted dict of parameters. + """ + if "bn.running_" in key: + if key in model_keys: + return key + + new_key = key.replace("bn.running_", "bn.split_bn.running_") + if new_key in model_keys: + return new_key + else: + return key + + +def normal_to_sub_bn(checkpoint_sd, model_sd): + """ + Convert BN parameters to Sub-BN parameters if model contains Sub-BNs. + Args: + checkpoint_sd (OrderedDict): source dict of parameters. + model_sd (OrderedDict): target dict of parameters. + Returns: + new_sd (OrderedDict): converted dict of parameters. + """ + for key in model_sd: + if key not in checkpoint_sd: + if "bn.split_bn." in key: + load_key = key.replace("bn.split_bn.", "bn.") + bn_key = key.replace("bn.split_bn.", "bn.bn.") + checkpoint_sd[key] = checkpoint_sd.pop(load_key) + checkpoint_sd[bn_key] = checkpoint_sd[key] + + for key in model_sd: + if key in checkpoint_sd: + model_blob_shape = model_sd[key].shape + c2_blob_shape = checkpoint_sd[key].shape + + if ( + len(model_blob_shape) == 1 + and len(c2_blob_shape) == 1 + and model_blob_shape[0] > c2_blob_shape[0] + and model_blob_shape[0] % c2_blob_shape[0] == 0 + ): + before_shape = checkpoint_sd[key].shape + checkpoint_sd[key] = torch.cat( + [checkpoint_sd[key]] + * (model_blob_shape[0] // c2_blob_shape[0]) + ) + logger.info( + "{} {} -> {}".format( + key, before_shape, checkpoint_sd[key].shape + ) + ) + return checkpoint_sd + + +def load_test_checkpoint(cfg, model): + """ + Loading checkpoint logic for testing. + """ + # Load a checkpoint to test if applicable. + if cfg.TEST.CHECKPOINT_FILE_PATH != "": + # If no checkpoint found in MODEL_VIS.CHECKPOINT_FILE_PATH or in the current + # checkpoint folder, try to load checkpoint from + # TEST.CHECKPOINT_FILE_PATH and test it. + load_checkpoint( + cfg.TEST.CHECKPOINT_FILE_PATH, + model, + cfg.NUM_GPUS > 1, + None, + inflation=False, + convert_from_caffe2=cfg.TEST.CHECKPOINT_TYPE == "caffe2", + ) + elif has_checkpoint(cfg.OUTPUT_DIR): + last_checkpoint = get_last_checkpoint(cfg.OUTPUT_DIR) + load_checkpoint(last_checkpoint, model, cfg.NUM_GPUS > 1) + elif cfg.TRAIN.CHECKPOINT_FILE_PATH != "": + # If no checkpoint found in TEST.CHECKPOINT_FILE_PATH or in the current + # checkpoint folder, try to load checkpoint from + # TRAIN.CHECKPOINT_FILE_PATH and test it. + load_checkpoint( + cfg.TRAIN.CHECKPOINT_FILE_PATH, + model, + cfg.NUM_GPUS > 1, + None, + inflation=False, + convert_from_caffe2=cfg.TRAIN.CHECKPOINT_TYPE == "caffe2", + ) + else: + logger.info( + "Unknown way of loading checkpoint. Using with random initialization, only for debugging." + ) + + +def load_train_checkpoint(cfg, model, optimizer): + """ + Loading checkpoint logic for training. + """ + if cfg.TRAIN.AUTO_RESUME and has_checkpoint(cfg.OUTPUT_DIR): + last_checkpoint = get_last_checkpoint(cfg.OUTPUT_DIR) + logger.info("Load from last checkpoint, {}.".format(last_checkpoint)) + checkpoint_epoch,global_step = load_checkpoint( + last_checkpoint, model, cfg.NUM_GPUS > 1, optimizer + ) + start_epoch = checkpoint_epoch + 1 + global_step = global_step + 1 + elif cfg.TRAIN.CHECKPOINT_FILE_PATH != "": + if cfg.TRAIN.CHECKPOINT_TYPE=="backbone": + logger.info("Load backbone from given checkpoint file.") + load_backbone(model,cfg.TRAIN.CHECKPOINT_FILE_PATH) + start_epoch = 0 + global_step = 0 + else: + logger.info("Load from given checkpoint file.") + checkpoint_epoch, global_step = load_checkpoint( + cfg.TRAIN.CHECKPOINT_FILE_PATH, + model, + cfg.NUM_GPUS > 1, + optimizer, + inflation=cfg.TRAIN.CHECKPOINT_INFLATE, + convert_from_caffe2=cfg.TRAIN.CHECKPOINT_TYPE == "caffe2", + ) + start_epoch = checkpoint_epoch + 1 + global_step = global_step + 1 + else: + start_epoch = 0 + global_step = 0 + + return start_epoch, global_step + + + +def load_backbone(model,file): + current_state=model.state_dict() + checkpoint=torch.load(file) + + for key in checkpoint: + if key in current_state: + assert current_state[key].shape==checkpoint[key].shape + current_state[key]=checkpoint[key] + model.load_state_dict(current_state) + + return model + + diff --git a/training/detectors/utils/slowfast/utils/distributed.py b/training/detectors/utils/slowfast/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..bfbed8e8a4af5fc4b38c1558616fd3f640b587c0 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/distributed.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Distributed helpers.""" + +import functools +import logging +import pickle +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None + + +def all_gather(tensors): + """ + All gathers the provided tensors from all processes across machines. + Args: + tensors (list): tensors to perform all gather across all processes in + all machines. + """ + + gather_list = [] + output_tensor = [] + world_size = dist.get_world_size() + for tensor in tensors: + tensor_placeholder = [ + torch.ones_like(tensor) for _ in range(world_size) + ] + dist.all_gather(tensor_placeholder, tensor, async_op=False) + gather_list.append(tensor_placeholder) + for gathered_tensor in gather_list: + output_tensor.append(torch.cat(gathered_tensor, dim=0)) + return output_tensor + + +def all_reduce(tensors, average=True): + """ + All reduce the provided tensors from all processes across machines. + Args: + tensors (list): tensors to perform all reduce across all processes in + all machines. + average (bool): scales the reduced tensor by the number of overall + processes across all machines. + """ + + for tensor in tensors: + dist.all_reduce(tensor, async_op=False) + if average: + world_size = dist.get_world_size() + for tensor in tensors: + tensor.mul_(1.0 / world_size) + return tensors + + +def init_process_group( + local_rank, + local_world_size, + shard_id, + num_shards, + init_method, + dist_backend="nccl", +): + """ + Initializes the default process group. + Args: + local_rank (int): the rank on the current local machine. + local_world_size (int): the world size (number of processes running) on + the current local machine. + shard_id (int): the shard index (machine rank) of the current machine. + num_shards (int): number of shards for distributed training. + init_method (string): supporting three different methods for + initializing process groups: + "file": use shared file system to initialize the groups across + different processes. + "tcp": use tcp address to initialize the groups across different + dist_backend (string): backend to use for distributed training. Options + includes gloo, mpi and nccl, the details can be found here: + https://pytorch.org/docs/stable/distributed.html + """ + # Sets the GPU to use. + torch.cuda.set_device(local_rank) + # Initialize the process group. + proc_rank = local_rank + shard_id * local_world_size + world_size = local_world_size * num_shards + dist.init_process_group( + backend=dist_backend, + init_method=init_method, + world_size=world_size, + rank=proc_rank, + ) + + +def is_master_proc(num_gpus=8): + """ + Determines if the current process is the master process. + """ + if torch.distributed.is_initialized(): + return dist.get_rank() % num_gpus == 0 + else: + return True + + +def get_world_size(): + """ + Get the size of the world. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + """ + Get the rank of the current process. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + Returns: + (group): pytorch dist group. + """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def _serialize_to_tensor(data, group): + """ + Seriialize the tensor to ByteTensor. Note that only `gloo` and `nccl` + backend is supported. + Args: + data (data): data to be serialized. + group (group): pytorch dist group. + Returns: + tensor (ByteTensor): tensor that serialized. + """ + + backend = dist.get_backend(group) + assert backend in ["gloo", "nccl"] + device = torch.device("cpu" if backend == "gloo" else "cuda") + + buffer = pickle.dumps(data) + if len(buffer) > 1024 ** 3: + logger = logging.getLogger(__name__) + logger.warning( + "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( + get_rank(), len(buffer) / (1024 ** 3), device + ) + ) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device=device) + return tensor + + +def _pad_to_largest_tensor(tensor, group): + """ + Padding all the tensors from different GPUs to the largest ones. + Args: + tensor (tensor): tensor to pad. + group (group): pytorch dist group. + Returns: + list[int]: size of the tensor, on each rank + Tensor: padded tensor that has the max size + """ + world_size = dist.get_world_size(group=group) + assert ( + world_size >= 1 + ), "comm.gather/all_gather must be called from ranks within the given group!" + local_size = torch.tensor( + [tensor.numel()], dtype=torch.int64, device=tensor.device + ) + size_list = [ + torch.zeros([1], dtype=torch.int64, device=tensor.device) + for _ in range(world_size) + ] + dist.all_gather(size_list, local_size, group=group) + size_list = [int(size.item()) for size in size_list] + + max_size = max(size_list) + + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + if local_size != max_size: + padding = torch.zeros( + (max_size - local_size,), dtype=torch.uint8, device=tensor.device + ) + tensor = torch.cat((tensor, padding), dim=0) + return size_list, tensor + + +def all_gather_unaligned(data, group=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group) == 1: + return [data] + + tensor = _serialize_to_tensor(data, group) + + size_list, tensor = _pad_to_largest_tensor(tensor, group) + max_size = max(size_list) + + # receiving Tensor from all ranks + tensor_list = [ + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) + for _ in size_list + ] + dist.all_gather(tensor_list, tensor, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def init_distributed_training(cfg): + """ + Initialize variables needed for distributed training. + """ + if cfg.NUM_GPUS <= 1: + return + num_gpus_per_machine = cfg.NUM_GPUS + num_machines = dist.get_world_size() // num_gpus_per_machine + for i in range(num_machines): + ranks_on_i = list( + range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine) + ) + pg = dist.new_group(ranks_on_i) + if i == cfg.SHARD_ID: + global _LOCAL_PROCESS_GROUP + _LOCAL_PROCESS_GROUP = pg + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) diff --git a/training/detectors/utils/slowfast/utils/env.py b/training/detectors/utils/slowfast/utils/env.py new file mode 100644 index 0000000000000000000000000000000000000000..2554915089a6e20c9ce58bba7fa59136ed65c887 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/env.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Set up Environment.""" + +import slowfast.utils.logging as logging + +_ENV_SETUP_DONE = False + + +def setup_environment(): + global _ENV_SETUP_DONE + if _ENV_SETUP_DONE: + return + _ENV_SETUP_DONE = True diff --git a/training/detectors/utils/slowfast/utils/logging.py b/training/detectors/utils/slowfast/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..f2763b3000c9be4a5499b310a3bc052fd5472d50 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/logging.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Logging.""" + +import builtins +import decimal +import functools +import logging +import os +import sys +import simplejson +from fvcore.common.file_io import PathManager + +import slowfast.utils.distributed as du + + +def _suppress_print(): + """ + Suppresses printing from the current process. + """ + + def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): + pass + + builtins.print = print_pass + + +@functools.lru_cache(maxsize=None) +def _cached_log_stream(filename): + return PathManager.open(filename, "a") + + +def setup_logging(output_dir=None): + """ + Sets up the logging for multiple processes. Only enable the logging for the + master process, and suppress logging for the non-master processes. + """ + # Set up logging format. + _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s" + + if du.is_master_proc(): + # Enable logging for the master process. + logging.root.handlers = [] + else: + # Suppress logging for non-master processes. + _suppress_print() + + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + logger.propagate = False + plain_formatter = logging.Formatter( + "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", + datefmt="%m/%d %H:%M:%S", + ) + + if du.is_master_proc(): + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + ch.setFormatter(plain_formatter) + logger.addHandler(ch) + + if output_dir is not None and du.is_master_proc(du.get_world_size()): + filename = os.path.join(output_dir, "stdout.log") + fh = logging.StreamHandler(_cached_log_stream(filename)) + fh.setLevel(logging.DEBUG) + fh.setFormatter(plain_formatter) + logger.addHandler(fh) + + +def get_logger(name): + """ + Retrieve the logger with the specified name or, if name is None, return a + logger which is the root logger of the hierarchy. + Args: + name (string): name of the logger. + """ + return logging.getLogger(name) + + +def log_json_stats(stats): + """ + Logs json stats. + Args: + stats (dict): a dictionary of statistical information to log. + """ + stats = { + k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v + for k, v in stats.items() + } + json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) + logger = get_logger(__name__) + logger.info("json_stats: {:s}".format(json_stats)) diff --git a/training/detectors/utils/slowfast/utils/lr_policy.py b/training/detectors/utils/slowfast/utils/lr_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..4c67f8e5d9d6d576928986521dbe2eac8b1ca7e5 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/lr_policy.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Learning rate policy.""" + +import math + + +def get_lr_at_epoch(cfg, cur_epoch): + """ + Retrieve the learning rate of the current epoch with the option to perform + warm up in the beginning of the training stage. + Args: + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + cur_epoch (float): the number of epoch of the current training stage. + """ + lr = get_lr_func(cfg.SOLVER.LR_POLICY)(cfg, cur_epoch) + # Perform warm up. + if cur_epoch < cfg.SOLVER.WARMUP_EPOCHS: + lr_start = cfg.SOLVER.WARMUP_START_LR + lr_end = get_lr_func(cfg.SOLVER.LR_POLICY)( + cfg, cfg.SOLVER.WARMUP_EPOCHS + ) + alpha = (lr_end - lr_start) / cfg.SOLVER.WARMUP_EPOCHS + lr = cur_epoch * alpha + lr_start + return lr + +def get_lr_at_iter(cfg,cur_iter): + """LR schedule that should yield 76% converged accuracy with batch size 256""" + start_step = cfg.SOLVER.TOTAL_STEP- cfg.SOLVER.LR_STEP + duration_step = cfg.SOLVER.LR_STEP + base_lr=float(cfg.SOLVER.BASE_LR) + if cur_iter <= start_step: + return base_lr + else: + this_step = cur_iter - start_step + lr = base_lr * ((this_step / duration_step) ** 2.0) + return lr + + +def lr_func_cosine(cfg, cur_epoch): + """ + Retrieve the learning rate to specified values at specified epoch with the + cosine learning rate schedule. Details can be found in: + Ilya Loshchilov, and Frank Hutter + SGDR: Stochastic Gradient Descent With Warm Restarts. + Args: + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + cur_epoch (float): the number of epoch of the current training stage. + """ + return ( + cfg.SOLVER.BASE_LR + * (math.cos(math.pi * cur_epoch / cfg.SOLVER.MAX_EPOCH) + 1.0) + * 0.5 + ) + + +def lr_func_steps_with_relative_lrs(cfg, cur_epoch): + """ + Retrieve the learning rate to specified values at specified epoch with the + steps with relative learning rate schedule. + Args: + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + cur_epoch (float): the number of epoch of the current training stage. + """ + ind = get_step_index(cfg, cur_epoch) + return cfg.SOLVER.LRS[ind] * cfg.SOLVER.BASE_LR + + +def get_step_index(cfg, cur_epoch): + """ + Retrieves the lr step index for the given epoch. + Args: + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + cur_epoch (float): the number of epoch of the current training stage. + """ + steps = cfg.SOLVER.STEPS + [cfg.SOLVER.MAX_EPOCH] + for ind, step in enumerate(steps): # NoQA + if cur_epoch < step: + break + return ind - 1 + + +def get_lr_func(lr_policy): + """ + Given the configs, retrieve the specified lr policy function. + Args: + lr_policy (string): the learning rate policy to use for the job. + """ + policy = "lr_func_" + lr_policy + if policy not in globals(): + raise NotImplementedError("Unknown LR policy: {}".format(lr_policy)) + else: + return globals()[policy] diff --git a/training/detectors/utils/slowfast/utils/meters.py b/training/detectors/utils/slowfast/utils/meters.py new file mode 100644 index 0000000000000000000000000000000000000000..2c4e9582a2bb5f6685987ce1f0ce391ac3950419 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/meters.py @@ -0,0 +1,841 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Meters.""" + +import datetime +import numpy as np +import os +from collections import defaultdict, deque +import torch +from fvcore.common.timer import Timer +from sklearn.metrics import average_precision_score + +import slowfast.datasets.ava_helper as ava_helper +import slowfast.utils.logging as logging +import slowfast.utils.metrics as metrics +import slowfast.utils.misc as misc +from slowfast.utils.ava_eval_helper import ( + evaluate_ava, + read_csv, + read_exclusions, + read_labelmap, +) + +logger = logging.get_logger(__name__) + + +def get_ava_mini_groundtruth(full_groundtruth): + """ + Get the groundtruth annotations corresponding the "subset" of AVA val set. + We define the subset to be the frames such that (second % 4 == 0). + We optionally use subset for faster evaluation during training + (in order to track training progress). + Args: + full_groundtruth(dict): list of groundtruth. + """ + ret = [defaultdict(list), defaultdict(list), defaultdict(list)] + + for i in range(3): + for key in full_groundtruth[i].keys(): + if int(key.split(",")[1]) % 4 == 0: + ret[i][key] = full_groundtruth[i][key] + return ret + + +class AVAMeter(object): + """ + Measure the AVA train, val, and test stats. + """ + + def __init__(self, overall_iters, cfg, mode): + """ + overall_iters (int): the overall number of iterations of one epoch. + cfg (CfgNode): configs. + mode (str): `train`, `val`, or `test` mode. + """ + self.cfg = cfg + self.lr = None + self.loss = ScalarMeter(cfg.LOG_PERIOD) + self.full_ava_test = cfg.AVA.FULL_TEST_ON_VAL + self.mode = mode + self.iter_timer = Timer() + self.all_preds = [] + self.all_ori_boxes = [] + self.all_metadata = [] + self.overall_iters = overall_iters + self.excluded_keys = read_exclusions( + os.path.join(cfg.AVA.ANNOTATION_DIR, cfg.AVA.EXCLUSION_FILE) + ) + self.categories, self.class_whitelist = read_labelmap( + os.path.join(cfg.AVA.ANNOTATION_DIR, cfg.AVA.LABEL_MAP_FILE) + ) + gt_filename = os.path.join( + cfg.AVA.ANNOTATION_DIR, cfg.AVA.GROUNDTRUTH_FILE + ) + self.full_groundtruth = read_csv(gt_filename, self.class_whitelist) + self.mini_groundtruth = get_ava_mini_groundtruth(self.full_groundtruth) + + _, self.video_idx_to_name = ava_helper.load_image_lists( + cfg, mode == "train" + ) + + def log_iter_stats(self, cur_epoch, cur_iter): + """ + Log the stats. + Args: + cur_epoch (int): the current epoch. + cur_iter (int): the current iteration. + """ + + if (cur_iter + 1) % self.cfg.LOG_PERIOD != 0: + return + + eta_sec = self.iter_timer.seconds() * (self.overall_iters - cur_iter) + eta = str(datetime.timedelta(seconds=int(eta_sec))) + if self.mode == "train": + stats = { + "_type": "{}_iter".format(self.mode), + "cur_epoch": "{}".format(cur_epoch + 1), + "cur_iter": "{}".format(cur_iter + 1), + "eta": eta, + "time_diff": self.iter_timer.seconds(), + "mode": self.mode, + "loss": self.loss.get_win_median(), + "lr": self.lr, + } + elif self.mode == "val": + stats = { + "_type": "{}_iter".format(self.mode), + "cur_epoch": "{}".format(cur_epoch + 1), + "cur_iter": "{}".format(cur_iter + 1), + "eta": eta, + "time_diff": self.iter_timer.seconds(), + "mode": self.mode, + } + elif self.mode == "test": + stats = { + "_type": "{}_iter".format(self.mode), + "cur_iter": "{}".format(cur_iter + 1), + "eta": eta, + "time_diff": self.iter_timer.seconds(), + "mode": self.mode, + } + else: + raise NotImplementedError("Unknown mode: {}".format(self.mode)) + + logging.log_json_stats(stats) + + def iter_tic(self): + """ + Start to record time. + """ + self.iter_timer.reset() + + def iter_toc(self): + """ + Stop to record time. + """ + self.iter_timer.pause() + + def reset(self): + """ + Reset the Meter. + """ + self.loss.reset() + + self.all_preds = [] + self.all_ori_boxes = [] + self.all_metadata = [] + + def update_stats(self, preds, ori_boxes, metadata, loss=None, lr=None): + """ + Update the current stats. + Args: + preds (tensor): prediction embedding. + ori_boxes (tensor): original boxes (x1, y1, x2, y2). + metadata (tensor): metadata of the AVA data. + loss (float): loss value. + lr (float): learning rate. + """ + if self.mode in ["val", "test"]: + self.all_preds.append(preds) + self.all_ori_boxes.append(ori_boxes) + self.all_metadata.append(metadata) + if loss is not None: + self.loss.add_value(loss) + if lr is not None: + self.lr = lr + + def finalize_metrics(self, log=True): + """ + Calculate and log the final AVA metrics. + """ + all_preds = torch.cat(self.all_preds, dim=0) + all_ori_boxes = torch.cat(self.all_ori_boxes, dim=0) + all_metadata = torch.cat(self.all_metadata, dim=0) + + if self.mode == "test" or (self.full_ava_test and self.mode == "val"): + groundtruth = self.full_groundtruth + else: + groundtruth = self.mini_groundtruth + + self.full_map = evaluate_ava( + all_preds, + all_ori_boxes, + all_metadata.tolist(), + self.excluded_keys, + self.class_whitelist, + self.categories, + groundtruth=groundtruth, + video_idx_to_name=self.video_idx_to_name, + ) + if log: + stats = {"mode": self.mode, "map": self.full_map} + logging.log_json_stats(stats) + + def log_epoch_stats(self, cur_epoch): + """ + Log the stats of the current epoch. + Args: + cur_epoch (int): the number of current epoch. + """ + if self.mode in ["val", "test"]: + self.finalize_metrics(log=False) + stats = { + "_type": "{}_epoch".format(self.mode), + "cur_epoch": "{}".format(cur_epoch + 1), + "mode": self.mode, + "map": self.full_map, + "gpu_mem": "{:.2f} GB".format(misc.gpu_mem_usage()), + "RAM": "{:.2f}/{:.2f} GB".format(*misc.cpu_mem_usage()), + } + logging.log_json_stats(stats) + + +class TestMeter(object): + """ + Perform the multi-view ensemble for testing: each video with an unique index + will be sampled with multiple clips, and the predictions of the clips will + be aggregated to produce the final prediction for the video. + The accuracy is calculated with the given ground truth labels. + """ + + def __init__( + self, + num_videos, + num_clips, + num_cls, + overall_iters, + multi_label=False, + ensemble_method="sum", + ): + """ + Construct tensors to store the predictions and labels. Expect to get + num_clips predictions from each video, and calculate the metrics on + num_videos videos. + Args: + num_videos (int): number of videos to test. + num_clips (int): number of clips sampled from each video for + aggregating the final prediction for the video. + num_cls (int): number of classes for each prediction. + overall_iters (int): overall iterations for testing. + multi_label (bool): if True, use map as the metric. + ensemble_method (str): method to perform the ensemble, options + include "sum", and "max". + """ + + self.iter_timer = Timer() + self.num_clips = num_clips + self.overall_iters = overall_iters + self.multi_label = multi_label + self.ensemble_method = ensemble_method + # Initialize tensors. + self.video_preds = torch.zeros((num_videos, num_cls)) + if multi_label: + self.video_preds -= 1e10 + + self.video_labels = ( + torch.zeros((num_videos, num_cls)) + if multi_label + else torch.zeros((num_videos)).long() + ) + self.clip_count = torch.zeros((num_videos)).long() + # Reset metric. + self.reset() + + def reset(self): + """ + Reset the metric. + """ + self.clip_count.zero_() + self.video_preds.zero_() + if self.multi_label: + self.video_preds -= 1e10 + self.video_labels.zero_() + + def update_stats(self, preds, labels, clip_ids): + """ + Collect the predictions from the current batch and perform on-the-flight + summation as ensemble. + Args: + preds (tensor): predictions from the current batch. Dimension is + N x C where N is the batch size and C is the channel size + (num_cls). + labels (tensor): the corresponding labels of the current batch. + Dimension is N. + clip_ids (tensor): clip indexes of the current batch, dimension is + N. + """ + for ind in range(preds.shape[0]): + vid_id = int(clip_ids[ind]) // self.num_clips + if self.video_labels[vid_id].sum() > 0: + assert torch.equal( + self.video_labels[vid_id].type(torch.FloatTensor), + labels[ind].type(torch.FloatTensor), + ) + self.video_labels[vid_id] = labels[ind] + if self.ensemble_method == "sum": + self.video_preds[vid_id] += preds[ind] + elif self.ensemble_method == "max": + self.video_preds[vid_id] = torch.max( + self.video_preds[vid_id], preds[ind] + ) + else: + raise NotImplementedError( + "Ensemble Method {} is not supported".format( + self.ensemble_method + ) + ) + self.clip_count[vid_id] += 1 + + def log_iter_stats(self, cur_iter): + """ + Log the stats. + Args: + cur_iter (int): the current iteration of testing. + """ + eta_sec = self.iter_timer.seconds() * (self.overall_iters - cur_iter) + eta = str(datetime.timedelta(seconds=int(eta_sec))) + stats = { + "split": "test_iter", + "cur_iter": "{}".format(cur_iter + 1), + "eta": eta, + "time_diff": self.iter_timer.seconds(), + } + logging.log_json_stats(stats) + + def iter_tic(self): + self.iter_timer.reset() + + def iter_toc(self): + self.iter_timer.pause() + + def finalize_metrics(self, ks=(1, 5)): + """ + Calculate and log the final ensembled metrics. + ks (tuple): list of top-k values for topk_accuracies. For example, + ks = (1, 5) correspods to top-1 and top-5 accuracy. + """ + if not all(self.clip_count == self.num_clips): + logger.warning( + "clip count {} ~= num clips {}".format( + ", ".join( + [ + "{}: {}".format(i, k) + for i, k in enumerate(self.clip_count.tolist()) + ] + ), + self.num_clips, + ) + ) + + stats = {"split": "test_final"} + if self.multi_label: + map = get_map( + self.video_preds.cpu().numpy(), self.video_labels.cpu().numpy() + ) + stats["map"] = map + else: + num_topks_correct = metrics.topks_correct( + self.video_preds, self.video_labels, ks + ) + topks = [ + (x / self.video_preds.size(0)) * 100.0 + for x in num_topks_correct + ] + assert len({len(ks), len(topks)}) == 1 + for k, topk in zip(ks, topks): + stats["top{}_acc".format(k)] = "{:.{prec}f}".format( + topk, prec=2 + ) + logging.log_json_stats(stats) + + +class ScalarMeter(object): + """ + A scalar meter uses a deque to track a series of scaler values with a given + window size. It supports calculating the median and average values of the + window, and also supports calculating the global average. + """ + + def __init__(self, window_size): + """ + Args: + window_size (int): size of the max length of the deque. + """ + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + + def reset(self): + """ + Reset the deque. + """ + self.deque.clear() + self.total = 0.0 + self.count = 0 + + def add_value(self, value): + """ + Add a new scalar value to the deque. + """ + self.deque.append(value) + self.count += 1 + self.total += value + + def get_win_median(self): + """ + Calculate the current median value of the deque. + """ + return np.median(self.deque) + + def get_win_avg(self): + """ + Calculate the current average value of the deque. + """ + return np.mean(self.deque) + + def get_global_avg(self): + """ + Calculate the global mean value. + """ + return self.total / self.count + + +class TrainMeter(object): + """ + Measure training stats. + """ + + def __init__(self, epoch_iters, cfg): + """ + Args: + epoch_iters (int): the overall number of iterations of one epoch. + cfg (CfgNode): configs. + """ + self._cfg = cfg + self.epoch_iters = epoch_iters + self.MAX_EPOCH = cfg.SOLVER.MAX_EPOCH * epoch_iters + self.iter_timer = Timer() + self.loss = ScalarMeter(cfg.LOG_PERIOD) + self.loss_total = 0.0 + self.lr = None + # Current minibatch errors (smoothed over a window). + self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) + self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) + # Number of misclassified examples. + self.num_top1_mis = 0 + self.num_top5_mis = 0 + self.num_samples = 0 + + def reset(self): + """ + Reset the Meter. + """ + self.loss.reset() + self.loss_total = 0.0 + self.lr = None + self.mb_top1_err.reset() + self.mb_top5_err.reset() + self.num_top1_mis = 0 + self.num_top5_mis = 0 + self.num_samples = 0 + + def iter_tic(self): + """ + Start to record time. + """ + self.iter_timer.reset() + + def iter_toc(self): + """ + Stop to record time. + """ + self.iter_timer.pause() + + def update_stats(self, top1_err, top5_err, loss, lr, mb_size): + """ + Update the current stats. + Args: + top1_err (float): top1 error rate. + top5_err (float): top5 error rate. + loss (float): loss value. + lr (float): learning rate. + mb_size (int): mini batch size. + """ + self.loss.add_value(loss) + self.lr = lr + self.loss_total += loss * mb_size + self.num_samples += mb_size + + if not self._cfg.DATA.MULTI_LABEL: + # Current minibatch stats + self.mb_top1_err.add_value(top1_err) + self.mb_top5_err.add_value(top5_err) + # Aggregate stats + self.num_top1_mis += top1_err * mb_size + self.num_top5_mis += top5_err * mb_size + + def log_iter_stats(self, cur_epoch, cur_iter): + """ + log the stats of the current iteration. + Args: + cur_epoch (int): the number of current epoch. + cur_iter (int): the number of current iteration. + """ + if (cur_iter + 1) % self._cfg.LOG_PERIOD != 0: + return + eta_sec = self.iter_timer.seconds() * ( + self.MAX_EPOCH - (cur_epoch * self.epoch_iters + cur_iter + 1) + ) + eta = str(datetime.timedelta(seconds=int(eta_sec))) + stats = { + "_type": "train_iter", + "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH), + "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters), + "time_diff": self.iter_timer.seconds(), + "eta": eta, + "loss": self.loss.get_win_median(), + "lr": self.lr, + "gpu_mem": "{:.2f} GB".format(misc.gpu_mem_usage()), + } + if not self._cfg.DATA.MULTI_LABEL: + stats["top1_err"] = self.mb_top1_err.get_win_median() + stats["top5_err"] = self.mb_top5_err.get_win_median() + logging.log_json_stats(stats) + + def log_epoch_stats(self, cur_epoch): + """ + Log the stats of the current epoch. + Args: + cur_epoch (int): the number of current epoch. + """ + eta_sec = self.iter_timer.seconds() * ( + self.MAX_EPOCH - (cur_epoch + 1) * self.epoch_iters + ) + eta = str(datetime.timedelta(seconds=int(eta_sec))) + stats = { + "_type": "train_epoch", + "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH), + "time_diff": self.iter_timer.seconds(), + "eta": eta, + "lr": self.lr, + "gpu_mem": "{:.2f} GB".format(misc.gpu_mem_usage()), + "RAM": "{:.2f}/{:.2f} GB".format(*misc.cpu_mem_usage()), + } + if not self._cfg.DATA.MULTI_LABEL: + top1_err = self.num_top1_mis / self.num_samples + top5_err = self.num_top5_mis / self.num_samples + avg_loss = self.loss_total / self.num_samples + stats["top1_err"] = top1_err + stats["top5_err"] = top5_err + stats["loss"] = avg_loss + logging.log_json_stats(stats) + + +class TrainIterMeter(object): + """ + Measure training stats. + """ + + def __init__(self, epoch_iters, cfg,extra=[]): + """ + Args: + epoch_iters (int): the overall number of iterations of one epoch. + cfg (CfgNode): configs. + """ + self._cfg = cfg + self.epoch_iters = epoch_iters + self.MAX_EPOCH = cfg.SOLVER.MAX_EPOCH * epoch_iters + self.iter_timer = Timer() + self.loss = ScalarMeter(cfg.LOG_PERIOD) + self.loss_total = 0.0 + self.lr = None + + # Number of misclassified examples. + self.num_samples = 0 + + self.meters={key:ScalarMeter(cfg.LOG_PERIOD) for key in extra} + + def reset(self): + """ + Reset the Meter. + """ + self.loss.reset() + self.loss_total = 0.0 + self.lr = None + + + self.num_samples = 0 + + for meter in self.meters.values(): + meter.reset() + + def iter_tic(self): + """ + Start to record time. + """ + self.iter_timer.reset() + + def iter_toc(self): + """ + Stop to record time. + """ + self.iter_timer.pause() + + def update_stats(self, loss, lr, mb_size,extra={}): + """ + Update the current stats. + Args: + top1_err (float): top1 error rate. + top5_err (float): top5 error rate. + loss (float): loss value. + lr (float): learning rate. + mb_size (int): mini batch size. + """ + self.loss.add_value(loss) + self.lr = lr + self.loss_total += loss * mb_size + self.num_samples += mb_size + + + for key,val in extra.items(): + self.meters[key].add_value(val) + + def log_iter_stats(self, cur_epoch, cur_iter,extra={}): + """ + log the stats of the current iteration. + Args: + cur_epoch (int): the number of current epoch. + cur_iter (int): the number of current iteration. + """ + if (cur_iter + 1) % self._cfg.LOG_PERIOD != 0: + return + eta_sec = self.iter_timer.seconds() * ( + self.MAX_EPOCH - (cur_epoch * self.epoch_iters + cur_iter + 1) + ) + eta = str(datetime.timedelta(seconds=int(eta_sec))) + stats = { + "_type": "train_iter", + "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH), + "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters), + "time_diff": self.iter_timer.seconds(), + "eta": eta, + "loss": self.loss.get_win_median(), + "lr": self.lr, + "gpu_mem": "{:.2f} GB".format(misc.gpu_mem_usage()), + } + + for key,meter in self.meters.items(): + stats[key]=meter.get_win_median() + for key,val in extra.items(): + stats[key]=val + + logging.log_json_stats(stats) + + def log_epoch_stats(self, cur_epoch): + """ + Log the stats of the current epoch. + Args: + cur_epoch (int): the number of current epoch. + """ + eta_sec = self.iter_timer.seconds() * ( + self.MAX_EPOCH - (cur_epoch + 1) * self.epoch_iters + ) + eta = str(datetime.timedelta(seconds=int(eta_sec))) + stats = { + "_type": "train_epoch", + "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH), + "time_diff": self.iter_timer.seconds(), + "eta": eta, + "lr": self.lr, + "gpu_mem": "{:.2f} GB".format(misc.gpu_mem_usage()), + "RAM": "{:.2f}/{:.2f} GB".format(*misc.cpu_mem_usage()), + } + if not self._cfg.DATA.MULTI_LABEL: + avg_loss = self.loss_total / self.num_samples + stats["loss"] = avg_loss + logging.log_json_stats(stats) + + + + +class ValMeter(object): + """ + Measures validation stats. + """ + + def __init__(self, max_iter, cfg): + """ + Args: + max_iter (int): the max number of iteration of the current epoch. + cfg (CfgNode): configs. + """ + self._cfg = cfg + self.max_iter = max_iter + self.iter_timer = Timer() + # Current minibatch errors (smoothed over a window). + self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) + self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) + # Min errors (over the full val set). + self.min_top1_err = 100.0 + self.min_top5_err = 100.0 + # Number of misclassified examples. + self.num_top1_mis = 0 + self.num_top5_mis = 0 + self.num_samples = 0 + self.all_preds = [] + self.all_labels = [] + + def reset(self): + """ + Reset the Meter. + """ + self.iter_timer.reset() + self.mb_top1_err.reset() + self.mb_top5_err.reset() + self.num_top1_mis = 0 + self.num_top5_mis = 0 + self.num_samples = 0 + self.all_preds = [] + self.all_labels = [] + + def iter_tic(self): + """ + Start to record time. + """ + self.iter_timer.reset() + + def iter_toc(self): + """ + Stop to record time. + """ + self.iter_timer.pause() + + def update_stats(self, top1_err, top5_err, mb_size): + """ + Update the current stats. + Args: + top1_err (float): top1 error rate. + top5_err (float): top5 error rate. + mb_size (int): mini batch size. + """ + self.mb_top1_err.add_value(top1_err) + self.mb_top5_err.add_value(top5_err) + self.num_top1_mis += top1_err * mb_size + self.num_top5_mis += top5_err * mb_size + self.num_samples += mb_size + + def update_predictions(self, preds, labels): + """ + Update predictions and labels. + Args: + preds (tensor): model output predictions. + labels (tensor): labels. + """ + # TODO: merge update_prediction with update_stats. + self.all_preds.append(preds) + self.all_labels.append(labels) + + def log_iter_stats(self, cur_epoch, cur_iter): + """ + log the stats of the current iteration. + Args: + cur_epoch (int): the number of current epoch. + cur_iter (int): the number of current iteration. + """ + if (cur_iter + 1) % self._cfg.LOG_PERIOD != 0: + return + eta_sec = self.iter_timer.seconds() * (self.max_iter - cur_iter - 1) + eta = str(datetime.timedelta(seconds=int(eta_sec))) + stats = { + "_type": "val_iter", + "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH), + "iter": "{}/{}".format(cur_iter + 1, self.max_iter), + "time_diff": self.iter_timer.seconds(), + "eta": eta, + "gpu_mem": "{:.2f} GB".format(misc.gpu_mem_usage()), + } + if not self._cfg.DATA.MULTI_LABEL: + stats["top1_err"] = self.mb_top1_err.get_win_median() + stats["top5_err"] = self.mb_top5_err.get_win_median() + logging.log_json_stats(stats) + + def log_epoch_stats(self, cur_epoch): + """ + Log the stats of the current epoch. + Args: + cur_epoch (int): the number of current epoch. + """ + stats = { + "_type": "val_epoch", + "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH), + "time_diff": self.iter_timer.seconds(), + "gpu_mem": "{:.2f} GB".format(misc.gpu_mem_usage()), + "RAM": "{:.2f}/{:.2f} GB".format(*misc.cpu_mem_usage()), + } + if self._cfg.DATA.MULTI_LABEL: + stats["map"] = get_map( + torch.cat(self.all_preds).cpu().numpy(), + torch.cat(self.all_labels).cpu().numpy(), + ) + else: + top1_err = self.num_top1_mis / self.num_samples + top5_err = self.num_top5_mis / self.num_samples + self.min_top1_err = min(self.min_top1_err, top1_err) + self.min_top5_err = min(self.min_top5_err, top5_err) + + stats["top1_err"] = top1_err + stats["top5_err"] = top5_err + stats["min_top1_err"] = self.min_top1_err + stats["min_top5_err"] = self.min_top5_err + + logging.log_json_stats(stats) + + +def get_map(preds, labels): + """ + Compute mAP for multi-label case. + Args: + preds (numpy tensor): num_examples x num_classes. + labels (numpy tensor): num_examples x num_classes. + Returns: + mean_ap (int): final mAP score. + """ + + logger.info("Getting mAP for {} examples".format(preds.shape[0])) + + preds = preds[:, ~(np.all(labels == 0, axis=0))] + labels = labels[:, ~(np.all(labels == 0, axis=0))] + aps = [0] + try: + aps = average_precision_score(labels, preds, average=None) + except ValueError: + print( + "Average precision requires a sufficient number of samples \ + in a batch which are missing in this sample." + ) + + mean_ap = np.mean(aps) + return mean_ap diff --git a/training/detectors/utils/slowfast/utils/metrics.py b/training/detectors/utils/slowfast/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..0ef01b174aa5c3d54da77923f515f244327c4e80 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/metrics.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Functions for computing metrics.""" + +import torch + + +def topks_correct(preds, labels, ks): + """ + Given the predictions, labels, and a list of top-k values, compute the + number of correct predictions for each top-k value. + + Args: + preds (array): array of predictions. Dimension is batchsize + N x ClassNum. + labels (array): array of labels. Dimension is batchsize N. + ks (list): list of top-k values. For example, ks = [1, 5] correspods + to top-1 and top-5. + + Returns: + topks_correct (list): list of numbers, where the `i`-th entry + corresponds to the number of top-`ks[i]` correct predictions. + """ + assert preds.size(0) == labels.size( + 0 + ), "Batch dim of predictions and labels must match" + # Find the top max_k predictions for each sample + _top_max_k_vals, top_max_k_inds = torch.topk( + preds, max(ks), dim=1, largest=True, sorted=True + ) + # (batch_size, max_k) -> (max_k, batch_size). + top_max_k_inds = top_max_k_inds.t() + # (batch_size, ) -> (max_k, batch_size). + rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) + # (i, j) = 1 if top i-th prediction for the j-th sample is correct. + top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) + # Compute the number of topk correct predictions for each k. + topks_correct = [ + top_max_k_correct[:k, :].view(-1).float().sum() for k in ks + ] + return topks_correct + + +def topk_errors(preds, labels, ks): + """ + Computes the top-k error for each k. + Args: + preds (array): array of predictions. Dimension is N. + labels (array): array of labels. Dimension is N. + ks (list): list of ks to calculate the top accuracies. + """ + num_topks_correct = topks_correct(preds, labels, ks) + return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct] + + +def topk_accuracies(preds, labels, ks): + """ + Computes the top-k accuracy for each k. + Args: + preds (array): array of predictions. Dimension is N. + labels (array): array of labels. Dimension is N. + ks (list): list of ks to calculate the top accuracies. + """ + num_topks_correct = topks_correct(preds, labels, ks) + return [(x / preds.size(0)) * 100.0 for x in num_topks_correct] diff --git a/training/detectors/utils/slowfast/utils/misc.py b/training/detectors/utils/slowfast/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..13fea609bca56845d61819707a66cbd807b956d8 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/misc.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +import json +import logging +import math +import numpy as np +import os +from datetime import datetime +import psutil +import torch +from fvcore.common.file_io import PathManager +from fvcore.nn.activation_count import activation_count +from fvcore.nn.flop_count import flop_count +from matplotlib import pyplot as plt +from torch import nn + +import slowfast.utils.logging as logging +import slowfast.utils.multiprocessing as mpu +from slowfast.datasets.utils import pack_pathway_output +from slowfast.models.batchnorm_helper import SubBatchNorm3d + +logger = logging.get_logger(__name__) + + +def check_nan_losses(loss): + """ + Determine whether the loss is NaN (not a number). + Args: + loss (loss): loss to check whether is NaN. + """ + if math.isnan(loss): + raise RuntimeError("ERROR: Got NaN losses {}".format(datetime.now())) + + +def params_count(model): + """ + Compute the number of parameters. + Args: + model (model): model to count the number of parameters. + """ + return np.sum([p.numel() for p in model.parameters()]).item() + + +def gpu_mem_usage(): + """ + Compute the GPU memory usage for the current device (GB). + """ + if torch.cuda.is_available(): + mem_usage_bytes = torch.cuda.max_memory_allocated() + else: + mem_usage_bytes = 0 + return mem_usage_bytes / 1024 ** 3 + + +def cpu_mem_usage(): + """ + Compute the system memory (RAM) usage for the current device (GB). + Returns: + usage (float): used memory (GB). + total (float): total memory (GB). + """ + vram = psutil.virtual_memory() + usage = (vram.total - vram.available) / 1024 ** 3 + total = vram.total / 1024 ** 3 + + return usage, total + + +def _get_model_analysis_input(cfg, use_train_input): + """ + Return a dummy input for model analysis with batch size 1. The input is + used for analyzing the model (counting flops and activations etc.). + Args: + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + use_train_input (bool): if True, return the input for training. Otherwise, + return the input for testing. + + Returns: + inputs: the input for model analysis. + """ + rgb_dimension = 3 + if use_train_input: + input_tensors = torch.rand( + rgb_dimension, + cfg.DATA.NUM_FRAMES, + cfg.DATA.TRAIN_CROP_SIZE, + cfg.DATA.TRAIN_CROP_SIZE, + ) + else: + input_tensors = torch.rand( + rgb_dimension, + cfg.DATA.NUM_FRAMES, + cfg.DATA.TEST_CROP_SIZE, + cfg.DATA.TEST_CROP_SIZE, + ) + model_inputs = pack_pathway_output(cfg, input_tensors) + for i in range(len(model_inputs)): + model_inputs[i] = model_inputs[i].unsqueeze(0) + if cfg.NUM_GPUS: + model_inputs[i] = model_inputs[i].cuda(non_blocking=True) + + # If detection is enabled, count flops for one proposal. + if cfg.DETECTION.ENABLE: + bbox = torch.tensor([[0, 0, 1.0, 0, 1.0]]) + if cfg.NUM_GPUS: + bbox = bbox.cuda() + inputs = (model_inputs, bbox) + else: + inputs = (model_inputs,) + return inputs + + +def get_model_stats(model, cfg, mode, use_train_input): + """ + Compute statistics for the current model given the config. + Args: + model (model): model to perform analysis. + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + mode (str): Options include `flop` or `activation`. Compute either flop + (gflops) or activation count (mega). + use_train_input (bool): if True, compute statistics for training. Otherwise, + compute statistics for testing. + + Returns: + float: the total number of count of the given model. + """ + assert mode in [ + "flop", + "activation", + ], "'{}' not supported for model analysis".format(mode) + if mode == "flop": + model_stats_fun = flop_count + elif mode == "activation": + model_stats_fun = activation_count + + # Set model to evaluation mode for analysis. + # Evaluation mode can avoid getting stuck with sync batchnorm. + model_mode = model.training + model.eval() + inputs = _get_model_analysis_input(cfg, use_train_input) + count_dict, _ = model_stats_fun(model, inputs) + count = sum(count_dict.values()) + model.train(model_mode) + return count + + +def log_model_info(model, cfg, use_train_input=True): + """ + Log info, includes number of parameters, gpu usage, gflops and activation count. + The model info is computed when the model is in validation mode. + Args: + model (model): model to log the info. + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + use_train_input (bool): if True, log info for training. Otherwise, + log info for testing. + """ + print("Model:\n{}".format(model)) + print("Params: {:,}".format(params_count(model))) + print("Mem: {:,} MB".format(gpu_mem_usage())) + print( + "Flops: {:,} G".format( + get_model_stats(model, cfg, "flop", use_train_input) + ) + ) + print( + "Activations: {:,} M".format( + get_model_stats(model, cfg, "activation", use_train_input) + ) + ) + logger.info("nvidia-smi") + os.system("nvidia-smi") + +def is_eval_epoch(cfg, cur_epoch, multigrid_schedule): + """ + Determine if the model should be evaluated at the current epoch. + Args: + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + cur_epoch (int): current epoch. + multigrid_schedule (List): schedule for multigrid training. + """ + if cur_epoch + 1 == cfg.SOLVER.MAX_EPOCH: + return True + if multigrid_schedule is not None: + prev_epoch = 0 + for s in multigrid_schedule: + if cur_epoch < s[-1]: + period = max( + (s[-1] - prev_epoch) // cfg.MULTIGRID.EVAL_FREQ + 1, 1 + ) + return (s[-1] - 1 - cur_epoch) % period == 0 + prev_epoch = s[-1] + + return (cur_epoch + 1) % cfg.TRAIN.EVAL_PERIOD == 0 + + +def plot_input(tensor, bboxes=(), texts=(), path="./tmp_vis.png"): + """ + Plot the input tensor with the optional bounding box and save it to disk. + Args: + tensor (tensor): a tensor with shape of `NxCxHxW`. + bboxes (tuple): bounding boxes with format of [[x, y, h, w]]. + texts (tuple): a tuple of string to plot. + path (str): path to the image to save to. + """ + tensor = tensor - tensor.min() + tensor = tensor / tensor.max() + f, ax = plt.subplots(nrows=1, ncols=tensor.shape[0], figsize=(50, 20)) + for i in range(tensor.shape[0]): + ax[i].axis("off") + ax[i].imshow(tensor[i].permute(1, 2, 0)) + # ax[1][0].axis('off') + if bboxes is not None and len(bboxes) > i: + for box in bboxes[i]: + x1, y1, x2, y2 = box + ax[i].vlines(x1, y1, y2, colors="g", linestyles="solid") + ax[i].vlines(x2, y1, y2, colors="g", linestyles="solid") + ax[i].hlines(y1, x1, x2, colors="g", linestyles="solid") + ax[i].hlines(y2, x1, x2, colors="g", linestyles="solid") + + if texts is not None and len(texts) > i: + ax[i].text(0, 0, texts[i]) + f.savefig(path) + + +def frozen_bn_stats(model): + """ + Set all the bn layers to eval mode. + Args: + model (model): model to set bn layers to eval mode. + """ + for m in model.modules(): + if isinstance(m, nn.BatchNorm3d): + m.eval() + + +def aggregate_sub_bn_stats(module): + """ + Recursively find all SubBN modules and aggregate sub-BN stats. + Args: + module (nn.Module) + Returns: + count (int): number of SubBN module found. + """ + count = 0 + for child in module.children(): + if isinstance(child, SubBatchNorm3d): + child.aggregate_stats() + count += 1 + else: + count += aggregate_sub_bn_stats(child) + return count + + +def launch_job(cfg, init_method, func, daemon=False): + """ + Run 'func' on one or more GPUs, specified in cfg + Args: + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + init_method (str): initialization method to launch the job with multiple + devices. + func (function): job to run on GPU(s) + daemon (bool): The spawned processes’ daemon flag. If set to True, + daemonic processes will be created + """ + if cfg.NUM_GPUS > 1: + torch.multiprocessing.spawn( + mpu.run, + nprocs=cfg.NUM_GPUS, + args=( + cfg.NUM_GPUS, + func, + init_method, + cfg.SHARD_ID, + cfg.NUM_SHARDS, + cfg.DIST_BACKEND, + cfg, + ), + daemon=daemon, + ) + else: + func(cfg=cfg) + + +def get_class_names(path, parent_path=None, subset_path=None): + """ + Read json file with entries {classname: index} and return + an array of class names in order. + If parent_path is provided, load and map all children to their ids. + Args: + path (str): path to class ids json file. + File must be in the format {"class1": id1, "class2": id2, ...} + parent_path (Optional[str]): path to parent-child json file. + File must be in the format {"parent1": ["child1", "child2", ...], ...} + subset_path (Optional[str]): path to text file containing a subset + of class names, separated by newline characters. + Returns: + class_names (list of strs): list of class names. + class_parents (dict): a dictionary where key is the name of the parent class + and value is a list of ids of the children classes. + subset_ids (list of ints): list of ids of the classes provided in the + subset file. + """ + try: + with PathManager.open(path, "r") as f: + class2idx = json.load(f) + except Exception as err: + print("Fail to load file from {} with error {}".format(path, err)) + return + + max_key = max(class2idx.values()) + class_names = [None] * (max_key + 1) + + for k, i in class2idx.items(): + class_names[i] = k + + class_parent = None + if parent_path is not None and parent_path != "": + try: + with PathManager.open(parent_path, "r") as f: + d_parent = json.load(f) + except EnvironmentError as err: + print( + "Fail to load file from {} with error {}".format( + parent_path, err + ) + ) + return + class_parent = {} + for parent, children in d_parent.items(): + indices = [ + class2idx[c] for c in children if class2idx.get(c) is not None + ] + class_parent[parent] = indices + + subset_ids = None + if subset_path is not None and subset_path != "": + try: + with PathManager.open(subset_path, "r") as f: + subset = f.read().split("\n") + subset_ids = [ + class2idx[name] + for name in subset + if class2idx.get(name) is not None + ] + except EnvironmentError as err: + print( + "Fail to load file from {} with error {}".format( + subset_path, err + ) + ) + return + + return class_names, class_parent, subset_ids diff --git a/training/detectors/utils/slowfast/utils/multigrid.py b/training/detectors/utils/slowfast/utils/multigrid.py new file mode 100644 index 0000000000000000000000000000000000000000..4aed24bb4889d30960cec5ca94ab20a73b40b9e1 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/multigrid.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Helper functions for multigrid training.""" + +import numpy as np + +import slowfast.utils.logging as logging + +logger = logging.get_logger(__name__) + + +class MultigridSchedule(object): + """ + This class defines multigrid training schedule and update cfg accordingly. + """ + + def init_multigrid(self, cfg): + """ + Update cfg based on multigrid settings. + Args: + cfg (configs): configs that contains training and multigrid specific + hyperparameters. Details can be seen in + slowfast/config/defaults.py. + Returns: + cfg (configs): the updated cfg. + """ + self.schedule = None + # We may modify cfg.TRAIN.BATCH_SIZE, cfg.DATA.NUM_FRAMES, and + # cfg.DATA.TRAIN_CROP_SIZE during training, so we store their original + # value in cfg and use them as global variables. + cfg.MULTIGRID.DEFAULT_B = cfg.TRAIN.BATCH_SIZE + cfg.MULTIGRID.DEFAULT_T = cfg.DATA.NUM_FRAMES + cfg.MULTIGRID.DEFAULT_S = cfg.DATA.TRAIN_CROP_SIZE + + if cfg.MULTIGRID.LONG_CYCLE: + self.schedule = self.get_long_cycle_schedule(cfg) + cfg.SOLVER.STEPS = [0] + [s[-1] for s in self.schedule] + # Fine-tuning phase. + cfg.SOLVER.STEPS[-1] = ( + cfg.SOLVER.STEPS[-2] + cfg.SOLVER.STEPS[-1] + ) // 2 + cfg.SOLVER.LRS = [ + cfg.SOLVER.GAMMA ** s[0] * s[1][0] for s in self.schedule + ] + # Fine-tuning phase. + cfg.SOLVER.LRS = cfg.SOLVER.LRS[:-1] + [ + cfg.SOLVER.LRS[-2], + cfg.SOLVER.LRS[-1], + ] + + cfg.SOLVER.MAX_EPOCH = self.schedule[-1][-1] + + elif cfg.MULTIGRID.SHORT_CYCLE: + cfg.SOLVER.STEPS = [ + int(s * cfg.MULTIGRID.EPOCH_FACTOR) for s in cfg.SOLVER.STEPS + ] + cfg.SOLVER.MAX_EPOCH = int( + cfg.SOLVER.MAX_EPOCH * cfg.MULTIGRID.EPOCH_FACTOR + ) + return cfg + + def update_long_cycle(self, cfg, cur_epoch): + """ + Before every epoch, check if long cycle shape should change. If it + should, update cfg accordingly. + Args: + cfg (configs): configs that contains training and multigrid specific + hyperparameters. Details can be seen in + slowfast/config/defaults.py. + cur_epoch (int): current epoch index. + Returns: + cfg (configs): the updated cfg. + changed (bool): do we change long cycle shape at this epoch? + """ + base_b, base_t, base_s = get_current_long_cycle_shape( + self.schedule, cur_epoch + ) + if base_s != cfg.DATA.TRAIN_CROP_SIZE or base_t != cfg.DATA.NUM_FRAMES: + + cfg.DATA.NUM_FRAMES = base_t + cfg.DATA.TRAIN_CROP_SIZE = base_s + cfg.TRAIN.BATCH_SIZE = base_b * cfg.MULTIGRID.DEFAULT_B + + bs_factor = ( + float(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS) + / cfg.MULTIGRID.BN_BASE_SIZE + ) + + if bs_factor < 1: + cfg.BN.NORM_TYPE = "sync_batchnorm" + cfg.BN.NUM_SYNC_DEVICES = int(1.0 / bs_factor) + elif bs_factor > 1: + cfg.BN.NORM_TYPE = "sub_batchnorm" + cfg.BN.NUM_SPLITS = int(bs_factor) + else: + cfg.BN.NORM_TYPE = "batchnorm" + + cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE = cfg.DATA.SAMPLING_RATE * ( + cfg.MULTIGRID.DEFAULT_T // cfg.DATA.NUM_FRAMES + ) + logger.info("Long cycle updates:") + logger.info("\tBN.NORM_TYPE: {}".format(cfg.BN.NORM_TYPE)) + if cfg.BN.NORM_TYPE == "sync_batchnorm": + logger.info( + "\tBN.NUM_SYNC_DEVICES: {}".format(cfg.BN.NUM_SYNC_DEVICES) + ) + elif cfg.BN.NORM_TYPE == "sub_batchnorm": + logger.info("\tBN.NUM_SPLITS: {}".format(cfg.BN.NUM_SPLITS)) + logger.info("\tTRAIN.BATCH_SIZE: {}".format(cfg.TRAIN.BATCH_SIZE)) + logger.info( + "\tDATA.NUM_FRAMES x LONG_CYCLE_SAMPLING_RATE: {}x{}".format( + cfg.DATA.NUM_FRAMES, cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE + ) + ) + logger.info( + "\tDATA.TRAIN_CROP_SIZE: {}".format(cfg.DATA.TRAIN_CROP_SIZE) + ) + return cfg, True + else: + return cfg, False + + def get_long_cycle_schedule(self, cfg): + """ + Based on multigrid hyperparameters, define the schedule of a long cycle. + Args: + cfg (configs): configs that contains training and multigrid specific + hyperparameters. Details can be seen in + slowfast/config/defaults.py. + Returns: + schedule (list): Specifies a list long cycle base shapes and their + corresponding training epochs. + """ + + steps = cfg.SOLVER.STEPS + + default_size = float( + cfg.DATA.NUM_FRAMES * cfg.DATA.TRAIN_CROP_SIZE ** 2 + ) + default_iters = steps[-1] + + # Get shapes and average batch size for each long cycle shape. + avg_bs = [] + all_shapes = [] + for t_factor, s_factor in cfg.MULTIGRID.LONG_CYCLE_FACTORS: + base_t = int(round(cfg.DATA.NUM_FRAMES * t_factor)) + base_s = int(round(cfg.DATA.TRAIN_CROP_SIZE * s_factor)) + if cfg.MULTIGRID.SHORT_CYCLE: + shapes = [ + [ + base_t, + cfg.MULTIGRID.DEFAULT_S + * cfg.MULTIGRID.SHORT_CYCLE_FACTORS[0], + ], + [ + base_t, + cfg.MULTIGRID.DEFAULT_S + * cfg.MULTIGRID.SHORT_CYCLE_FACTORS[1], + ], + [base_t, base_s], + ] + else: + shapes = [[base_t, base_s]] + + # (T, S) -> (B, T, S) + shapes = [ + [int(round(default_size / (s[0] * s[1] * s[1]))), s[0], s[1]] + for s in shapes + ] + avg_bs.append(np.mean([s[0] for s in shapes])) + all_shapes.append(shapes) + + # Get schedule regardless of cfg.MULTIGRID.EPOCH_FACTOR. + total_iters = 0 + schedule = [] + for step_index in range(len(steps) - 1): + step_epochs = steps[step_index + 1] - steps[step_index] + + for long_cycle_index, shapes in enumerate(all_shapes): + cur_epochs = ( + step_epochs * avg_bs[long_cycle_index] / sum(avg_bs) + ) + + cur_iters = cur_epochs / avg_bs[long_cycle_index] + total_iters += cur_iters + schedule.append((step_index, shapes[-1], cur_epochs)) + + iter_saving = default_iters / total_iters + + final_step_epochs = cfg.SOLVER.MAX_EPOCH - steps[-1] + + # We define the fine-tuning phase to have the same amount of iteration + # saving as the rest of the training. + ft_epochs = final_step_epochs / iter_saving * avg_bs[-1] + + schedule.append((step_index + 1, all_shapes[-1][2], ft_epochs)) + + # Obtrain final schedule given desired cfg.MULTIGRID.EPOCH_FACTOR. + x = ( + cfg.SOLVER.MAX_EPOCH + * cfg.MULTIGRID.EPOCH_FACTOR + / sum(s[-1] for s in schedule) + ) + + final_schedule = [] + total_epochs = 0 + for s in schedule: + epochs = s[2] * x + total_epochs += epochs + final_schedule.append((s[0], s[1], int(round(total_epochs)))) + print_schedule(final_schedule) + return final_schedule + + +def print_schedule(schedule): + """ + Log schedule. + """ + logger.info("Long cycle index\tBase shape\tEpochs") + for s in schedule: + logger.info("{}\t{}\t{}".format(s[0], s[1], s[2])) + + +def get_current_long_cycle_shape(schedule, epoch): + """ + Given a schedule and epoch index, return the long cycle base shape. + Args: + schedule (configs): configs that contains training and multigrid specific + hyperparameters. Details can be seen in + slowfast/config/defaults.py. + cur_epoch (int): current epoch index. + Returns: + shapes (list): A list describing the base shape in a long cycle: + [batch size relative to default, + number of frames, spatial dimension]. + """ + for s in schedule: + if epoch < s[-1]: + return s[1] + return schedule[-1][1] diff --git a/training/detectors/utils/slowfast/utils/multiprocessing.py b/training/detectors/utils/slowfast/utils/multiprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..a56aa603697dd3a6a37871ddd4407523aeebbb8c --- /dev/null +++ b/training/detectors/utils/slowfast/utils/multiprocessing.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Multiprocessing helpers.""" + +import torch + + +def run( + local_rank, num_proc, func, init_method, shard_id, num_shards, backend, cfg +): + """ + Runs a function from a child process. + Args: + local_rank (int): rank of the current process on the current machine. + num_proc (int): number of processes per machine. + func (function): function to execute on each of the process. + init_method (string): method to initialize the distributed training. + TCP initialization: equiring a network address reachable from all + processes followed by the port. + Shared file-system initialization: makes use of a file system that + is shared and visible from all machines. The URL should start with + file:// and contain a path to a non-existent file on a shared file + system. + shard_id (int): the rank of the current machine. + num_shards (int): number of overall machines for the distributed + training job. + backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are + supports, each with different capabilities. Details can be found + here: + https://pytorch.org/docs/stable/distributed.html + cfg (CfgNode): configs. Details can be found in + slowfast/config/defaults.py + """ + # Initialize the process group. + world_size = num_proc * num_shards + rank = shard_id * num_proc + local_rank + + try: + torch.distributed.init_process_group( + backend=backend, + init_method=init_method, + world_size=world_size, + rank=rank, + ) + except Exception as e: + raise e + + torch.cuda.set_device(local_rank) + func(cfg) diff --git a/training/detectors/utils/slowfast/utils/parser.py b/training/detectors/utils/slowfast/utils/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..06b4373e3b3736ceb310eed465fbc75e3bae1eb5 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/parser.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Argument parser functions.""" + +import argparse +import sys + +import slowfast.utils.checkpoint as cu +from slowfast.config.defaults import get_cfg + + +def parse_args(): + """ + Parse the following arguments for a default parser for PySlowFast users. + Args: + shard_id (int): shard id for the current machine. Starts from 0 to + num_shards - 1. If single machine is used, then set shard id to 0. + num_shards (int): number of shards using by the job. + init_method (str): initialization method to launch the job with multiple + devices. Options includes TCP or shared file-system for + initialization. details can be find in + https://pytorch.org/docs/stable/distributed.html#tcp-initialization + cfg (str): path to the config file. + opts (argument): provide addtional options from the command line, it + overwrites the config loaded from file. + """ + parser = argparse.ArgumentParser( + description="Provide SlowFast video training and testing pipeline." + ) + parser.add_argument( + "--shard_id", + help="The shard id of current node, Starts from 0 to num_shards - 1", + default=0, + type=int, + ) + parser.add_argument( + "--num_shards", + help="Number of shards using by the job", + default=1, + type=int, + ) + parser.add_argument( + "--init_method", + help="Initialization method, includes TCP or shared file-system", + default="tcp://localhost:9999", + type=str, + ) + parser.add_argument( + "--cfg", + dest="cfg_file", + help="Path to the config file", + default="configs/Kinetics/SLOWFAST_4x16_R50.yaml", + type=str, + ) + parser.add_argument( + "opts", + help="See slowfast/config/defaults.py for all options", + default=None, + nargs=argparse.REMAINDER, + ) + if len(sys.argv) == 1: + parser.print_help() + return parser.parse_args() + + +def load_config(args): + """ + Given the arguemnts, load and initialize the configs. + Args: + args (argument): arguments includes `shard_id`, `num_shards`, + `init_method`, `cfg_file`, and `opts`. + """ + # Setup cfg. + cfg = get_cfg() + # Load config from cfg. + if args.cfg_file is not None: + cfg.merge_from_file(args.cfg_file) + # Load config from command line, overwrite config from opts. + if args.opts is not None: + cfg.merge_from_list(args.opts) + + # Inherit parameters from args. + if hasattr(args, "num_shards") and hasattr(args, "shard_id"): + cfg.NUM_SHARDS = args.num_shards + cfg.SHARD_ID = args.shard_id + if hasattr(args, "rng_seed"): + cfg.RNG_SEED = args.rng_seed + if hasattr(args, "output_dir"): + cfg.OUTPUT_DIR = args.output_dir + + # Create the checkpoint dir. + cu.make_checkpoint_dir(cfg.OUTPUT_DIR) + return cfg diff --git a/training/detectors/utils/slowfast/utils/weight_init_helper.py b/training/detectors/utils/slowfast/utils/weight_init_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..0b5544a70529f5dd1b06ba05a6aca4c7f508bdf3 --- /dev/null +++ b/training/detectors/utils/slowfast/utils/weight_init_helper.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Utility function for weight initialization""" + +import torch.nn as nn +from fvcore.nn.weight_init import c2_msra_fill + + +def init_weights(model, fc_init_std=0.01, zero_init_final_bn=True): + """ + Performs ResNet style weight initialization. + Args: + fc_init_std (float): the expected standard deviation for fc layer. + zero_init_final_bn (bool): if True, zero initialize the final bn for + every bottleneck. + """ + for m in model.modules(): + if isinstance(m, nn.Conv3d): + """ + Follow the initialization method proposed in: + {He, Kaiming, et al. + "Delving deep into rectifiers: Surpassing human-level + performance on imagenet classification." + arXiv preprint arXiv:1502.01852 (2015)} + """ + c2_msra_fill(m) + elif isinstance(m, nn.BatchNorm3d): + if ( + hasattr(m, "transform_final_bn") + and m.transform_final_bn + and zero_init_final_bn + ): + batchnorm_weight = 0.0 + else: + batchnorm_weight = 1.0 + if m.weight is not None: + m.weight.data.fill_(batchnorm_weight) + if m.bias is not None: + m.bias.data.zero_() + if isinstance(m, nn.Linear): + m.weight.data.normal_(mean=0.0, std=fc_init_std) + m.bias.data.zero_() diff --git a/training/detectors/videomae_detector.py b/training/detectors/videomae_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..ca470d911c7035c90aa4b3bc9579cf667db6fc02 --- /dev/null +++ b/training/detectors/videomae_detector.py @@ -0,0 +1,124 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the XceptionDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{rossler2019faceforensics++, + title={Faceforensics++: Learning to detect manipulated facial images}, + author={Rossler, Andreas and Cozzolino, Davide and Verdoliva, Luisa and Riess, Christian and Thies, Justus and Nie{\ss}ner, Matthias}, + booktitle={Proceedings of the IEEE/CVF international conference on computer vision}, + pages={1--11}, + year={2019} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from loss import LOSSFUNC + +import loralib as lora + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='videomae') +class VideoMAEDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.fc_norm = nn.LayerNorm(768) + self.head = nn.Linear(768, 2) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + from transformers import VideoMAEModel, VideoMAEConfig + configuration = VideoMAEConfig( + num_frames=self.config['clip_size'], + image_size=self.config['resolution'], + ) + backbone = VideoMAEModel.from_pretrained( + config['pretrained'], + config=configuration + ) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + # b, t, c, h, w = data_dict['image'].shape + # frame_input = data_dict['image'].reshape(-1, c, h, w) + # # get frame-level features + # frame_level_features = self.backbone.features(frame_input) + # frame_level_features = F.adaptive_avg_pool2d(frame_level_features, (1, 1)).reshape(b, t, -1) + # # get video-level features + # video_level_features = self.temporal_module(frame_level_features)[0][:, -1, :] + outputs = self.backbone(data_dict['image'], output_hidden_states=True) # torch.Size([8, 16, 3, 224, 224]) + sequence_output = outputs[0] # torch.Size([8, 1568, 768]) + video_level_features = self.fc_norm(sequence_output.mean(1)) # torch.Size([8, 768]) + return video_level_features + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + # we dont compute the video-level metrics for training + self.video_names = [] + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + + return pred_dict diff --git a/training/detectors/videomae_large_detector.py b/training/detectors/videomae_large_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..5c71cda4487b86342811f83c41463be7d0a2cab7 --- /dev/null +++ b/training/detectors/videomae_large_detector.py @@ -0,0 +1,117 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the XceptionDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{rossler2019faceforensics++, + title={Faceforensics++: Learning to detect manipulated facial images}, + author={Rossler, Andreas and Cozzolino, Davide and Verdoliva, Luisa and Riess, Christian and Thies, Justus and Nie{\ss}ner, Matthias}, + booktitle={Proceedings of the IEEE/CVF international conference on computer vision}, + pages={1--11}, + year={2019} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from loss import LOSSFUNC + +import loralib as lora + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='videomae_large') +class VideoMAELargeDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.fc_norm = nn.LayerNorm(1024) + self.head = nn.Linear(1024, 2) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + from transformers import VideoMAEModel + backbone = VideoMAEModel.from_pretrained(config['pretrained']) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + # b, t, c, h, w = data_dict['image'].shape + # frame_input = data_dict['image'].reshape(-1, c, h, w) + # # get frame-level features + # frame_level_features = self.backbone.features(frame_input) + # frame_level_features = F.adaptive_avg_pool2d(frame_level_features, (1, 1)).reshape(b, t, -1) + # # get video-level features + # video_level_features = self.temporal_module(frame_level_features)[0][:, -1, :] + outputs = self.backbone(data_dict['image'], output_hidden_states=True) # torch.Size([8, 16, 3, 224, 224]) + sequence_output = outputs[0] # torch.Size([8, 1568, 1024]) + video_level_features = self.fc_norm(sequence_output.mean(1)) # torch.Size([8, 1024]) + return video_level_features + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + # we dont compute the video-level metrics for training + self.video_names = [] + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + + return pred_dict diff --git a/training/detectors/videomae_lora_detector.py b/training/detectors/videomae_lora_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..0b96ee768edf5749c34a9f78174176138c21d1a8 --- /dev/null +++ b/training/detectors/videomae_lora_detector.py @@ -0,0 +1,175 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the XceptionDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{rossler2019faceforensics++, + title={Faceforensics++: Learning to detect manipulated facial images}, + author={Rossler, Andreas and Cozzolino, Davide and Verdoliva, Luisa and Riess, Christian and Thies, Justus and Nie{\ss}ner, Matthias}, + booktitle={Proceedings of the IEEE/CVF international conference on computer vision}, + pages={1--11}, + year={2019} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from loss import LOSSFUNC + +import loralib as lora + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='videomae_lora') +class VideoMAELoRADetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.fc_norm = nn.LayerNorm(768) + self.head = nn.Linear(768, 2) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + from transformers import VideoMAEModel, VideoMAEConfig + configuration = VideoMAEConfig( + num_frames=self.config['clip_size'], + image_size=self.config['resolution'], + ) + backbone = VideoMAEModel.from_pretrained( + config['pretrained'], + config=configuration + ) + backbone = to_lora(backbone, r=16) + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + # b, t, c, h, w = data_dict['image'].shape + # frame_input = data_dict['image'].reshape(-1, c, h, w) + # # get frame-level features + # frame_level_features = self.backbone.features(frame_input) + # frame_level_features = F.adaptive_avg_pool2d(frame_level_features, (1, 1)).reshape(b, t, -1) + # # get video-level features + # video_level_features = self.temporal_module(frame_level_features)[0][:, -1, :] + outputs = self.backbone(data_dict['image'], output_hidden_states=True) # torch.Size([8, 16, 3, 224, 224]) + sequence_output = outputs[0] # torch.Size([8, 1568, 768]) + video_level_features = self.fc_norm(sequence_output.mean(1)) # torch.Size([8, 768]) + return video_level_features + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + # compute metrics for batch data + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + # we dont compute the video-level metrics for training + self.video_names = [] + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1)[:, 1] + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + + return pred_dict + + +def to_lora(model, target=[nn.Linear, nn.Conv2d, nn.Embedding], r=16, f_class=None, layers=None, names=None): + + for n, m in model.named_modules(): + if f_class is not None and not isinstance(m, f_class): + continue + if isinstance(m, nn.Sequential) or isinstance(m, nn.ModuleList): + + for name, mod in m.named_children(): + + # print(name, mod) + if isinstance(mod, nn.Linear) and not isinstance(mod, lora.Linear): + mod = change_mod(mod, r=r) + m._modules[name] = mod + else: + if layers is None or any(['layers.' + str(i) in n for i in layers]): + for name, mod in m.named_children(): + # if 'self_attn' in f_name: + # print(name, mod) + if isinstance(mod, nn.Linear) and not isinstance(mod, lora.Linear): + if names is None or any(na in name for na in names): + mod = change_mod(mod, r=r) + setattr(m, name, mod) + + lora.mark_only_lora_as_trainable(model) + return model + + +def change_mod(m, targets=[nn.Linear, nn.Conv2d, nn.Embedding], r=16): + st_dict = m.state_dict() + + if nn.Linear in targets and isinstance(m, nn.Linear): + dtype = m.weight.dtype + new_m = lora.Linear(m.in_features, m.out_features, bias=m.bias is not None, r=r, dtype=dtype) + new_m.load_state_dict(st_dict, strict=False) + # print(new_m) + m = new_m + elif nn.Conv2d in targets and isinstance(m, nn.Conv2d): + new_m = lora.Conv2d(m.in_channels, m.out_channels, m.kernel_size, stride=m.stride, padding=m.padding, \ + dilation=m.dilation, transposed=m.transposed, output_padding=m.output_padding, groups=m.groups, bias=m.bias, r=r) + new_m.load_state_dict(st_dict, strict=False) + m = new_m + elif nn.Embedding in targets and isinstance(m, nn.Embedding): + new_m = lora.Embedding(m.num_embeddings, m.embedding_dim, padding_idx=m.padding_idx, max_norm=m.max_norm, norm_type=m.norm_type, \ + scale_grad_by_freq=m.scale_grad_by_freq, freeze=m.freeze, sparse=m.sparse, r=r) + new_m.load_state_dict(st_dict, strict=False) + m = new_m + + return m \ No newline at end of file diff --git a/training/detectors/vit_detector.py b/training/detectors/vit_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce02b0f1e2abb2389b23fc5490c636b801659be --- /dev/null +++ b/training/detectors/vit_detector.py @@ -0,0 +1,112 @@ + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +from transformers import AutoImageProcessor, ViTModel, ViTConfig +import loralib as lora +import copy + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='vit_large_fft') +class ViT_Large_FFT_Detector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + + self.head = nn.Linear(1024, config['backbone_config']['num_classes']) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + # prepare the backbone + + _, backbone = get_standard_vit(model_name=config['pretrained']) + return backbone + + def build_loss(self, config): + + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + + feat = self.backbone(data_dict['image'])['pooler_output'] + return feat + + def classifier(self, features: torch.tensor) -> torch.tensor: + + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + + label = data_dict['label'] + pred = pred_dict['cls'] + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + + # get the features by backbone + features = self.features(data_dict) + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + return pred_dict + + +# Key change 7: add `get_standard_vit` to replace the original `get_clip_visual` +def get_standard_vit(model_name = "google/vit-large-patch16-224"): + """ + Load the standard Google ViT model as a replacement for the CLIP visual branch + Args: + model_name: ViT model name, supported options include: + - google/vit-base-patch16-224 + - google/vit-large-patch16-224 + - google/vit-huge-patch14-224 + Returns: + image_processor: Image preprocessor + vit_model: ViT backbone (visual component only, without the classification head) + """ + + image_processor = AutoImageProcessor.from_pretrained(model_name) + + vit_model = ViTModel.from_pretrained( + model_name, + ) + return image_processor, vit_model \ No newline at end of file diff --git a/training/detectors/xception_detector.py b/training/detectors/xception_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..0d3854b0c9975125d4f73789090aa0d8691aec9e --- /dev/null +++ b/training/detectors/xception_detector.py @@ -0,0 +1,123 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the XceptionDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{rossler2019faceforensics++, + title={Faceforensics++: Learning to detect manipulated facial images}, + author={Rossler, Andreas and Cozzolino, Davide and Verdoliva, Luisa and Riess, Christian and Thies, Justus and Nie{\ss}ner, Matthias}, + booktitle={Proceedings of the IEEE/CVF international conference on computer vision}, + pages={1--11}, + year={2019} +} +''' + +import os +import datetime +import logging +import numpy as np +from sklearn import metrics +from typing import Union +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter + +from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train + +from .base_detector import AbstractDetector +from detectors import DETECTOR +from networks import BACKBONE +from loss import LOSSFUNC + +logger = logging.getLogger(__name__) + +@DETECTOR.register_module(module_name='xception') +class XceptionDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.loss_func = self.build_loss(config) + self.prob, self.label = [], [] + self.video_names = [] + self.correct, self.total = 0, 0 + + def build_backbone(self, config): + # prepare the backbone + backbone_class = BACKBONE[config['backbone_name']] + model_config = config['backbone_config'] + backbone = backbone_class(model_config) + # if donot load the pretrained weights, fail to get good results + state_dict = torch.load(config['pretrained']) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model successfully!') + return backbone + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + return self.backbone.features(data_dict['image']) #32,3,256,256 + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.backbone.classifier(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + overall_loss = loss + loss_dict = {'overall': overall_loss, 'cls': loss,} + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + + # compute metrics for batch data + # auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + # metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + acc, mAP = calculate_acc_for_train(label.detach(), pred.detach(), self.config['backbone_config']['num_classes']) + metric_batch_dict = {'acc': acc, 'mAP': mAP} + + # we dont compute the video-level metrics for training + self.video_names = [] + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + # get the features by backbone + features = self.features(data_dict) # features [64, 2048, 8, 8] + # get the prediction by classifier + pred = self.classifier(features) + # get the probability of the pred + # prob = torch.softmax(pred, dim=1)[:, 1] + prob = torch.softmax(pred, dim=1) + # build the prediction dict for each output + pred_dict = {'cls': pred, 'prob': prob, 'feat': torch.mean(features, dim=[2, 3])} + return pred_dict diff --git a/training/detectors/xclip_detector.py b/training/detectors/xclip_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..91fc29695ad091b7e33bb5cd639a289377732872 --- /dev/null +++ b/training/detectors/xclip_detector.py @@ -0,0 +1,111 @@ +""" +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 +# description: Class for the XCLIPDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{ma2022x, + title={X-clip: End-to-end multi-grained contrastive learning for video-text retrieval}, + author={Ma, Yiwei and Xu, Guohai and Sun, Xiaoshuai and Yan, Ming and Zhang, Ji and Ji, Rongrong}, + booktitle={Proceedings of the 30th ACM International Conference on Multimedia}, + pages={638--647}, + year={2022} +} +""" + +import logging + +import torch +import torch.nn as nn +from detectors import DETECTOR +from loss import LOSSFUNC +from metrics.base_metrics_class import calculate_metrics_for_train + +from .base_detector import AbstractDetector + +logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='xclip') +class XCLIPDetector(AbstractDetector): + def __init__(self, config): + super().__init__() + self.config = config + self.backbone = self.build_backbone(config) + self.fc_norm = nn.LayerNorm(768) + # self.temporal_module = self.build_temporal_module(config) + self.head = nn.Linear(768, 2) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + from transformers import XCLIPVisionModel + backbone = XCLIPVisionModel.from_pretrained(config['pretrained']) + + return backbone + + def build_temporal_module(self, config): + return nn.LSTM(input_size=2048, hidden_size=512, num_layers=3, batch_first=True) + + def build_loss(self, config): + # prepare the loss function + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + # b, t, c, h, w = data_dict['image'].shape + # frame_input = data_dict['image'].reshape(-1, c, h, w) + # # get frame-level features + # frame_level_features = self.backbone.features(frame_input) + # frame_level_features = F.adaptive_avg_pool2d(frame_level_features, (1, 1)).reshape(b, t, -1) + # # get video-level features + # video_level_features = self.temporal_module(frame_level_features)[0][:, -1, :] + + batch_size, num_frames, num_channels, height, width = data_dict['image'].shape + pixel_values = data_dict['image'].reshape(-1, num_channels, height, width) + outputs = self.backbone(pixel_values, output_hidden_states=True) + sequence_output = outputs['pooler_output'].reshape(batch_size, num_frames, -1) + video_level_features = self.fc_norm(sequence_output.mean(1)) + + return video_level_features + + def classifier(self, features: torch.tensor) -> torch.tensor: + return self.head(features) + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + + return loss_dict + + def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + features = self.features(data_dict) + pred = self.classifier(features) + prob = torch.softmax(pred, dim=1)[:, 1] + pred_dict = {'cls': pred, 'prob': prob, 'feat': features} + + + return pred_dict diff --git a/training/lib/component/MCT/template0.png b/training/lib/component/MCT/template0.png new file mode 100644 index 0000000000000000000000000000000000000000..bc2b450c2d80980c30aebac10201bda91e9d0e02 Binary files /dev/null and b/training/lib/component/MCT/template0.png differ diff --git a/training/lib/component/MCT/template1.png b/training/lib/component/MCT/template1.png new file mode 100644 index 0000000000000000000000000000000000000000..ce86a7f4b5b045f3222d43414c8a3a2c0d189a11 Binary files /dev/null and b/training/lib/component/MCT/template1.png differ diff --git a/training/lib/component/MCT/template2.png b/training/lib/component/MCT/template2.png new file mode 100644 index 0000000000000000000000000000000000000000..4917a2ea0b39dec9ac5f1233b0c7c58d04b681ca Binary files /dev/null and b/training/lib/component/MCT/template2.png differ diff --git a/training/lib/component/MCT/template3.png b/training/lib/component/MCT/template3.png new file mode 100644 index 0000000000000000000000000000000000000000..03d3fd0c66ee60474397592feea9835ffe51270a Binary files /dev/null and b/training/lib/component/MCT/template3.png differ diff --git a/training/lib/component/MCT/template4.png b/training/lib/component/MCT/template4.png new file mode 100644 index 0000000000000000000000000000000000000000..44701c0a389cf0fa24c3ad770fdecac87317ce96 Binary files /dev/null and b/training/lib/component/MCT/template4.png differ diff --git a/training/lib/component/MCT/template5.png b/training/lib/component/MCT/template5.png new file mode 100644 index 0000000000000000000000000000000000000000..13e3ee65ad94c455dcadc1f0c75414ee8da9cd6c Binary files /dev/null and b/training/lib/component/MCT/template5.png differ diff --git a/training/lib/component/MCT/template6.png b/training/lib/component/MCT/template6.png new file mode 100644 index 0000000000000000000000000000000000000000..dd239d79c99e43be0cdbe0c58b75b03079ee47e8 Binary files /dev/null and b/training/lib/component/MCT/template6.png differ diff --git a/training/lib/component/MCT/template7.png b/training/lib/component/MCT/template7.png new file mode 100644 index 0000000000000000000000000000000000000000..ee4d7f7d8007319ed756b540af15247d801042c6 Binary files /dev/null and b/training/lib/component/MCT/template7.png differ diff --git a/training/lib/component/MCT/template8.png b/training/lib/component/MCT/template8.png new file mode 100644 index 0000000000000000000000000000000000000000..0421aed95af01c9dd6d4cd6c2ff58ac70dba9db1 Binary files /dev/null and b/training/lib/component/MCT/template8.png differ diff --git a/training/lib/component/MCT/template9.png b/training/lib/component/MCT/template9.png new file mode 100644 index 0000000000000000000000000000000000000000..92554ac8cc1eb0bd6be5ea6a85f29ab43febe5d5 Binary files /dev/null and b/training/lib/component/MCT/template9.png differ diff --git a/training/lib/component/SRM_Kernels.npy b/training/lib/component/SRM_Kernels.npy new file mode 100644 index 0000000000000000000000000000000000000000..29d7cde8456f9c357c73292c0efa3f1b9f5f877c Binary files /dev/null and b/training/lib/component/SRM_Kernels.npy differ diff --git a/training/lib/component/__init__.py b/training/lib/component/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..676145d777810e4a51bdaf59fdec4f5358aae349 --- /dev/null +++ b/training/lib/component/__init__.py @@ -0,0 +1,7 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) diff --git a/training/lib/component/attention.py b/training/lib/component/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..60f3c39c3137dd209b4743710c6e0674467752fc --- /dev/null +++ b/training/lib/component/attention.py @@ -0,0 +1,350 @@ + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ChannelAttention(nn.Module): + def __init__(self, in_planes, ratio=8): + super(ChannelAttention, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + + self.sharedMLP = nn.Sequential( + nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), + nn.ReLU(), + nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)) + self.sigmoid = nn.Sigmoid() + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight.data, gain=0.02) + + def forward(self, x): + avgout = self.sharedMLP(self.avg_pool(x)) + maxout = self.sharedMLP(self.max_pool(x)) + return self.sigmoid(avgout + maxout) + + +class SpatialAttention(nn.Module): + def __init__(self, kernel_size=7): + super(SpatialAttention, self).__init__() + assert kernel_size in (3, 7), "kernel size must be 3 or 7" + padding = 3 if kernel_size == 7 else 1 + + self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) + self.sigmoid = nn.Sigmoid() + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight.data, gain=0.02) + + def forward(self, x): + avgout = torch.mean(x, dim=1, keepdim=True) + maxout, _ = torch.max(x, dim=1, keepdim=True) + x = torch.cat([avgout, maxout], dim=1) + x = self.conv(x) + return self.sigmoid(x) + + +class Self_Attn(nn.Module): + """ Self attention Layer""" + + def __init__(self, in_dim, out_dim=None, add=False, ratio=8): + super(Self_Attn, self).__init__() + self.chanel_in = in_dim + self.add = add + if out_dim is None: + out_dim = in_dim + self.out_dim = out_dim + # self.activation = activation + + self.query_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) + self.key_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) + self.value_conv = nn.Conv2d( + in_channels=in_dim, out_channels=out_dim, kernel_size=1) + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x): + """ + inputs : + x : input feature maps( B X C X W X H) + returns : + out : self attention value + input feature + attention: B X N X N (N is Width*Height) + """ + m_batchsize, C, width, height = x.size() + proj_query = self.query_conv(x).view( + m_batchsize, -1, width*height).permute(0, 2, 1) # B X C X(N) + proj_key = self.key_conv(x).view( + m_batchsize, -1, width*height) # B X C x (*W*H) + energy = torch.bmm(proj_query, proj_key) # transpose check + attention = self.softmax(energy) # BX (N) X (N) + proj_value = self.value_conv(x).view( + m_batchsize, -1, width*height) # B X C X N + + out = torch.bmm(proj_value, attention.permute(0, 2, 1)) + out = out.view(m_batchsize, self.out_dim, width, height) + + if self.add: + out = self.gamma*out + x + else: + out = self.gamma*out + return out # , attention + + +class CrossModalAttention(nn.Module): + """ CMA attention Layer""" + + def __init__(self, in_dim, activation=None, ratio=8, cross_value=True): + super(CrossModalAttention, self).__init__() + self.chanel_in = in_dim + self.activation = activation + self.cross_value = cross_value + + self.query_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) + self.key_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) + self.value_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight.data, gain=0.02) + + def forward(self, x, y): + """ + inputs : + x : input feature maps( B X C X W X H) + returns : + out : self attention value + input feature + attention: B X N X N (N is Width*Height) + """ + B, C, H, W = x.size() + + proj_query = self.query_conv(x).view( + B, -1, H*W).permute(0, 2, 1) # B , HW, C + proj_key = self.key_conv(y).view( + B, -1, H*W) # B X C x (*W*H) + energy = torch.bmm(proj_query, proj_key) # B, HW, HW + attention = self.softmax(energy) # BX (N) X (N) + if self.cross_value: + proj_value = self.value_conv(y).view( + B, -1, H*W) # B , C , HW + else: + proj_value = self.value_conv(x).view( + B, -1, H*W) # B , C , HW + + out = torch.bmm(proj_value, attention.permute(0, 2, 1)) + out = out.view(B, C, H, W) + + out = self.gamma*out + x + + if self.activation is not None: + out = self.activation(out) + + return out # , attention + +class DualCrossModalAttention(nn.Module): + """ Dual CMA attention Layer""" + + def __init__(self, in_dim, activation=None, size=16, ratio=8, ret_att=False): + super(DualCrossModalAttention, self).__init__() + self.chanel_in = in_dim + self.activation = activation + self.ret_att = ret_att + + # query conv + self.key_conv1 = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) + self.key_conv2 = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) + self.key_conv_share = nn.Conv2d( + in_channels=in_dim//ratio, out_channels=in_dim//ratio, kernel_size=1) + + self.linear1 = nn.Linear(size*size, size*size) + self.linear2 = nn.Linear(size*size, size*size) + + # separated value conv + self.value_conv1 = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma1 = nn.Parameter(torch.zeros(1)) + + self.value_conv2 = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma2 = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight.data, gain=0.02) + if isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight.data, gain=0.02) + + def forward(self, x, y): + """ + inputs : + x : input feature maps( B X C X W X H) + returns : + out : self attention value + input feature + attention: B X N X N (N is Width*Height) + """ + B, C, H, W = x.size() + + def _get_att(a, b): + proj_key1 = self.key_conv_share(self.key_conv1(a)).view( + B, -1, H*W).permute(0, 2, 1) # B , HW, C + proj_key2 = self.key_conv_share(self.key_conv2(b)).view( + B, -1, H*W) # B X C x (*W*H) + #print('proj_key1:', proj_key1[0][0][:5].cpu().detach().numpy()) + #print('proj_key2:', proj_key2[0][:5][0:5].cpu().detach().numpy()) + energy = torch.bmm(proj_key1, proj_key2) # B, HW, HW + #print('energy:', energy[0][0][:5].cpu().detach().numpy()) + attention1 = self.softmax(self.linear1(energy)) + attention2 = self.softmax(self.linear2(energy.permute(0,2,1))) # BX (N) X (N) + #print('1:', attention1[0]==attention1[1]) + #print('2:', attention2[0]==attention2[1]) + + return attention1, attention2 + + att_y_on_x, att_x_on_y = _get_att(x, y) + #print('att_y_on_x:', att_y_on_x[0][0][:5].cpu().detach().numpy()) + proj_value_y_on_x = self.value_conv2(y).view( + B, -1, H*W) # B , C , HW + out_y_on_x = torch.bmm(proj_value_y_on_x, att_y_on_x.permute(0, 2, 1)) + out_y_on_x = out_y_on_x.view(B, C, H, W) + out_x = self.gamma1*out_y_on_x + x + + proj_value_x_on_y = self.value_conv1(x).view( + B, -1, H*W) # B , C , HW + out_x_on_y = torch.bmm(proj_value_x_on_y, att_x_on_y.permute(0, 2, 1)) + out_x_on_y = out_x_on_y.view(B, C, H, W) + out_y = self.gamma2*out_x_on_y + y + + if self.ret_att: + return out_x, out_y, att_y_on_x, att_x_on_y + + return out_x, out_y # , attention + +class DualCrossModalAttention_old(nn.Module): + """ Dual CMA attention Layer""" + + def __init__(self, in_dim, activation=None, ratio=8, ret_att=False): + super(DualCrossModalAttention_old, self).__init__() + self.chanel_in = in_dim + self.activation = activation + self.ret_att = ret_att + + # shared query & key conv + self.query_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) + self.key_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1) + + # separated value conv + self.value_conv1 = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma1 = nn.Parameter(torch.zeros(1)) + + self.value_conv2 = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma2 = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight.data, gain=0.02) + + def forward(self, x, y): + """ + inputs : + x : input feature maps( B X C X W X H) + returns : + out : self attention value + input feature + attention: B X N X N (N is Width*Height) + """ + B, C, H, W = x.size() + + def _get_att(q, k): + proj_query = self.query_conv(q).view( + B, -1, H*W).permute(0, 2, 1) # B , HW, C + proj_key = self.key_conv(k).view( + B, -1, H*W) # B X C x (*W*H) + #print('proj_key:', proj_key[0][0][:5].cpu().detach().numpy()) + energy = torch.bmm(proj_query, proj_key) # B, HW, HW + #print('energy:', energy[0][0][:5].cpu().detach().numpy()) + attention = self.softmax(energy) # BX (N) X (N) + + return attention + + att_y_on_x = _get_att(x, y) + #print('att_y_on_x:', att_y_on_x[0][0][:5].cpu().detach().numpy()) + proj_value_y_on_x = self.value_conv2(y).view( + B, -1, H*W) # B , C , HW + out_y_on_x = torch.bmm(proj_value_y_on_x, att_y_on_x.permute(0, 2, 1)) + out_y_on_x = out_y_on_x.view(B, C, H, W) + out_x = self.gamma1*out_y_on_x + x + + att_x_on_y = _get_att(y, x) + proj_value_x_on_y = self.value_conv1(x).view( + B, -1, H*W) # B , C , HW + out_x_on_y = torch.bmm(proj_value_x_on_y, att_x_on_y.permute(0, 2, 1)) + out_x_on_y = out_x_on_y.view(B, C, H, W) + out_y = self.gamma2*out_x_on_y + y + + if self.ret_att: + return out_x, out_y, att_y_on_x, att_x_on_y + + return out_x, out_y # , attention + + + + +''' +class BasicBlock(nn.Module): + expansion = 1 + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.ca = ChannelAttention(planes) + self.sa = SpatialAttention() + self.downsample = downsample + self.stride = stride + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.ca(out) * out # broadcasting mechanism + out = self.sa(out) * out # broadcasting mechanism + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out +''' + +if __name__ == "__main__": + x = torch.rand(10, 768, 16, 16) + y = torch.rand(10, 768, 16, 16) + dcma = DualCrossModalAttention(768, ret_att=True) + out_x, out_y, att_y_on_x, att_x_on_y = dcma(x, y) + print(out_y.size()) + print(att_x_on_y.size()) diff --git a/training/lib/component/gaussian_ops.py b/training/lib/component/gaussian_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..57a7dda967810210c2f2ffc2c9ae1b54dca82bcd --- /dev/null +++ b/training/lib/component/gaussian_ops.py @@ -0,0 +1,117 @@ +import cv2 +import numpy as np +import math +import numbers +import torch +from torch import nn +from torch.nn import functional as F + + +class GaussianSmoothing(nn.Module): + """ + Apply gaussian smoothing on a + 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + in the input using a depthwise convolution. + Arguments: + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + + def __init__(self, channels, kernel_size, sigma=0.1, dim=2): + super(GaussianSmoothing, self).__init__() + self.kernel_size = kernel_size + if isinstance(kernel_size, numbers.Number): + kernel_size = [kernel_size] * dim + if isinstance(sigma, numbers.Number): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ + torch.exp(-((mgrid - mean) / std) ** 2 / 2) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError( + 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format( + dim) + ) + + def forward(self, input): + """ + Apply gaussian filter to input. + Arguments: + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + if self.training: + return self.conv(input, weight=self.weight, groups=self.groups, padding=self.kernel_size//2) + else: + return input + +class GaussianNoise(nn.Module): + def __init__(self, mean=0, std=0.1, clip=1): + super(GaussianNoise, self).__init__() + self.mean = mean + self.std = std + self.clip = clip + + def forward(self, x): + if self.training: + noise = x.data.new(x.size()).normal_(self.mean, self.std) + return torch.clamp(x + noise, -self.clip, self.clip) + else: + return x + + +if __name__ == "__main__": + im = cv2.imread('E:\SRM\component\FF-F2F_0.png') + im_ten = im/255*2-1 + im_ten = torch.from_numpy(im_ten).unsqueeze(0).permute(0, 3, 1, 2).float() + blur = GaussianSmoothing(channels=3, kernel_size=7, sigma=0.8) + noise = GaussianNoise() + + noise_im = torch.clamp(noise(im_ten), -1, 1) + blur_im = blur(im_ten) + print(blur_im.size()) + + def t2im(t): + + t = (t+1)/2*255 + im = t.squeeze().cpu().numpy().transpose(1, 2, 0).astype(np.uint8) + return im + + cv2.imshow('ori', im) + cv2.imshow('blur', t2im(blur_im)) + cv2.imshow('noise', t2im(noise_im)) + + cv2.waitKey() diff --git a/training/lib/component/srm_conv.py b/training/lib/component/srm_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..2d60a41fd63fca643fae525ec2f6d3ed001b75d2 --- /dev/null +++ b/training/lib/component/srm_conv.py @@ -0,0 +1,194 @@ +# -------------------------------------------------------- +# Two Stream Faster R-CNN +# Licensed under The MIT License [see LICENSE for details] +# Written by Hangyan Jiang +# -------------------------------------------------------- + +# Testing part +import torch +import torch.nn as nn +import torch.nn.functional as F +import cv2 +from PIL import Image +import numpy as np +import matplotlib.pyplot as plt + +import argparse + + +class SRMConv2d(nn.Module): + + def __init__(self, learnable=False): + super(SRMConv2d, self).__init__() + self.weight = nn.Parameter(torch.Tensor(30, 3, 5, 5), + requires_grad=learnable) + self.bias = nn.Parameter(torch.Tensor(30), \ + requires_grad=learnable) + self.reset_parameters() + + def reset_parameters(self): + SRM_npy = np.load('lib/component/SRM_Kernels.npy') + # print(SRM_npy.shape) + SRM_npy = np.repeat(SRM_npy, 3, axis=1) + # print(SRM_npy.shape) + self.weight.data.numpy()[:] = SRM_npy + self.bias.data.zero_() + + def forward(self, input): + return F.conv2d(input, self.weight, stride=1, padding=2) + + + +class SRMConv2d_simple(nn.Module): + + def __init__(self, inc=3, learnable=False): + super(SRMConv2d_simple, self).__init__() + self.truc = nn.Hardtanh(-3, 3) + kernel = self._build_kernel(inc) # (3,3,5,5) + self.kernel = nn.Parameter(data=kernel, requires_grad=learnable) + # self.hor_kernel = self._build_kernel().transpose(0,1,3,2) + + def forward(self, x): + ''' + x: imgs (Batch, H, W, 3) + ''' + out = F.conv2d(x, self.kernel, stride=1, padding=2) + out = self.truc(out) + + return out + + def _build_kernel(self, inc): + # filter1: KB + filter1 = [[0, 0, 0, 0, 0], + [0, -1, 2, -1, 0], + [0, 2, -4, 2, 0], + [0, -1, 2, -1, 0], + [0, 0, 0, 0, 0]] + # filter2:KV + filter2 = [[-1, 2, -2, 2, -1], + [2, -6, 8, -6, 2], + [-2, 8, -12, 8, -2], + [2, -6, 8, -6, 2], + [-1, 2, -2, 2, -1]] + # # filter3:hor 2rd + filter3 = [[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 1, -2, 1, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]] + # filter3:hor 2rd + # filter3 = [[0, 0, 0, 0, 0], + # [0, 0, 1, 0, 0], + # [0, 1, -4, 1, 0], + # [0, 0, 1, 0, 0], + # [0, 0, 0, 0, 0]] + + filter1 = np.asarray(filter1, dtype=float) / 4. + filter2 = np.asarray(filter2, dtype=float) / 12. + filter3 = np.asarray(filter3, dtype=float) / 2. + # statck the filters + filters = [[filter1],#, filter1, filter1], + [filter2],#, filter2, filter2], + [filter3]]#, filter3, filter3]] # (3,3,5,5) + filters = np.array(filters) + filters = np.repeat(filters, inc, axis=1) + filters = torch.FloatTensor(filters) # (3,3,5,5) + return filters + +class SRMConv2d_Separate(nn.Module): + + def __init__(self, inc, outc, learnable=False): + super(SRMConv2d_Separate, self).__init__() + self.inc = inc + self.truc = nn.Hardtanh(-3, 3) + kernel = self._build_kernel(inc) # (3,3,5,5) + self.kernel = nn.Parameter(data=kernel, requires_grad=learnable) + # self.hor_kernel = self._build_kernel().transpose(0,1,3,2) + self.out_conv = nn.Sequential( + nn.Conv2d(3*inc, outc, 1, 1, 0, 1, 1, bias=False), + nn.BatchNorm2d(outc), + nn.ReLU(inplace=True) + ) + + for ly in self.out_conv.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + + def forward(self, x): + ''' + x: imgs (Batch,inc, H, W) + kernel: (outc,inc,kH,kW) + ''' + out = F.conv2d(x, self.kernel, stride=1, padding=2, groups=self.inc) + out = self.truc(out) + out = self.out_conv(out) + + return out + + def _build_kernel(self, inc): + # filter1: KB + filter1 = [[0, 0, 0, 0, 0], + [0, -1, 2, -1, 0], + [0, 2, -4, 2, 0], + [0, -1, 2, -1, 0], + [0, 0, 0, 0, 0]] + # filter2:KV + filter2 = [[-1, 2, -2, 2, -1], + [2, -6, 8, -6, 2], + [-2, 8, -12, 8, -2], + [2, -6, 8, -6, 2], + [-1, 2, -2, 2, -1]] + # # filter3:hor 2rd + filter3 = [[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 1, -2, 1, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]] + # filter3:hor 2rd + # filter3 = [[0, 0, 0, 0, 0], + # [0, 0, 1, 0, 0], + # [0, 1, -4, 1, 0], + # [0, 0, 1, 0, 0], + # [0, 0, 0, 0, 0]] + + filter1 = np.asarray(filter1, dtype=float) / 4. + filter2 = np.asarray(filter2, dtype=float) / 12. + filter3 = np.asarray(filter3, dtype=float) / 2. + # statck the filters + filters = [[filter1],#, filter1, filter1], + [filter2],#, filter2, filter2], + [filter3]]#, filter3, filter3]] # (3,3,5,5) => (3,1,5,5) + filters = np.array(filters) + # filters = np.repeat(filters, inc, axis=1) + filters = np.repeat(filters, inc, axis=0) + filters = torch.FloatTensor(filters) # (3*inc,1,5,5) + # print(filters.size()) + return filters + + +if __name__ == "__main__": + im = cv2.imread('E:\SRM\component\FF-F2F_0.png') + im_ten = im/255*2-1 + im_ten = torch.from_numpy(im_ten).unsqueeze(0).permute(0, 3, 1, 2).float() + # im_ten = torch.cat((im_ten, im_ten), dim=1) + srm_conv = SRMConv2d_simple(inc=3) + srm_conv1 = SRMConv2d_Separate(inc=3, outc=3) + + srm = srm_conv(im_ten) + print(srm.size()) + + def t2im(t): + + # t = (t+1)/2*255 + t = t*255 + im = t.squeeze().detach().cpu().numpy().transpose(1, 2, 0).astype(np.uint8) + return im + + cv2.imshow('ori', im) + cv2.imshow('srm', t2im(srm)) + cv2.imshow('srm1', t2im(srm_conv1(im_ten))) + # cv2.imshow('srm2', t2im(srm_conv(srm))) + + cv2.waitKey() + + diff --git a/training/logger.py b/training/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..af39158b34ff5026964520ce6c34475d262ba251 --- /dev/null +++ b/training/logger.py @@ -0,0 +1,48 @@ +import os +import logging +import torch.distributed as dist + +class Rank0Filter(logging.Filter): + """Only allow the rank-0 process to output logs""" + def filter(self, record): + # Check the current process rank and only allow rank 0 to pass + return dist.get_rank() == 0 + +def create_logger(log_path): + # Ensure the distributed environment has been initialized to avoid get_rank() errors + if not dist.is_initialized(): + raise RuntimeError("torch.distributed must be initialized before creating the logger!") + + # Current process rank + rank = dist.get_rank() + + # Create the logger + logger = logging.getLogger() + logger.setLevel(logging.INFO) + # Clear existing handlers to avoid duplicate outputs + logger.handlers = [] + + logger.log_path = log_path + + # Only rank 0 creates the log file and console output + if rank == 0: + # Create the log directory + log_dir = os.path.dirname(log_path) + if log_dir and not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + + # File handler (writes to the log file) + fh = logging.FileHandler(log_path) + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + fh.setFormatter(formatter) + logger.addHandler(fh) + + # Console handler (prints to the terminal) + sh = logging.StreamHandler() + sh.setFormatter(formatter) + logger.addHandler(sh) + else: + # Non-rank-0 processes: add a null handler and do not output any logs + logger.addHandler(logging.NullHandler()) + + return logger \ No newline at end of file diff --git a/training/loss/__init__.py b/training/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b01ca1903d5bd7ab4a7b4c076037669baffa7f14 --- /dev/null +++ b/training/loss/__init__.py @@ -0,0 +1,32 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + +from metrics.registry import LOSSFUNC + +from .cross_entropy_loss import CrossEntropyLoss +from .consistency_loss import ConsistencyCos +from .capsule_loss import CapsuleLoss +from .bce_loss import BCELoss +from .am_softmax import AMSoftmaxLoss +from .am_softmax import AMSoftmax_OHEM +from .contrastive_regularization import ContrastiveLoss +from .l1_loss import L1Loss +from .id_loss import IDLoss +from .vgg_loss import VGGLoss +from .js_loss import JS_Loss +from .patch_consistency_loss import PatchConsistencyLoss +from .region_independent_loss import RegionIndependentLoss +from .supercontrast_loss import SupConLoss +from .supercontrast_cls_loss import SupConClsLoss +from .cross_entropy_orth_loss import CrossEntropyOrthLoss + +from .cross_entropy_orth1_loss import CrossEntropyOrth1Loss # CE + Orth + MSE +from .cross_entropy_orth2_loss import CrossEntropyOrth2Loss # CE + Orth + CosineSim +from .cross_entropy_orth3_loss import CrossEntropyOrth3Loss # CE + + MSE +from .cross_entropy_orth4_loss import CrossEntropyOrth4Loss # CE + Orth + RelaxedCosineSim(th=0.9) + diff --git a/training/loss/abstract_loss_func.py b/training/loss/abstract_loss_func.py new file mode 100644 index 0000000000000000000000000000000000000000..45d3324ed53be4310867b326e9eaabd265634138 --- /dev/null +++ b/training/loss/abstract_loss_func.py @@ -0,0 +1,17 @@ +import torch.nn as nn + +class AbstractLossClass(nn.Module): + """Abstract class for loss functions.""" + def __init__(self): + super(AbstractLossClass, self).__init__() + + def forward(self, pred, label): + """ + Args: + pred: prediction of the model + label: ground truth label + + Return: + loss: loss value + """ + raise NotImplementedError('Each subclass should implement the forward method.') diff --git a/training/loss/am_softmax.py b/training/loss/am_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..88b1df5236d19ecfddea8b1bc377733ff3aa6195 --- /dev/null +++ b/training/loss/am_softmax.py @@ -0,0 +1,145 @@ +""" + Copyright (c) 2018 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter +import torch as th + +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC + + +#------------ AMSoftmax Loss ---------------------- + +def focal_loss(input_values, gamma): + """Computes the focal loss""" + p = torch.exp(-input_values) + loss = (1 - p) ** gamma * input_values + return loss.mean() + + +@LOSSFUNC.register_module(module_name="am_softmax") +class AMSoftmaxLoss(AbstractLossClass): + """Computes the AM-Softmax loss with cos or arc margin""" + margin_types = ['cos', 'arc'] + + def __init__(self, margin_type='cos', gamma=0., m=0.5, s=30, t=1.): + super().__init__() + assert margin_type in AMSoftmaxLoss.margin_types + self.margin_type = margin_type + assert gamma >= 0 + self.gamma = gamma + assert m > 0 + self.m = m + assert s > 0 + self.s = s + self.cos_m = math.cos(m) + self.sin_m = math.sin(m) + self.th = math.cos(math.pi - m) + assert t >= 1 + self.t = t + + def forward(self, cos_theta, target): + if self.margin_type == 'cos': + phi_theta = cos_theta - self.m + else: + sine = torch.sqrt(1.0 - torch.pow(cos_theta, 2)) + phi_theta = cos_theta * self.cos_m - sine * self.sin_m #cos(theta+m) + phi_theta = torch.where(cos_theta > self.th, phi_theta, cos_theta - self.sin_m * self.m) + + index = torch.zeros_like(cos_theta, dtype=torch.uint8) + index.scatter_(1, target.data.view(-1, 1), 1) + output = torch.where(index, phi_theta, cos_theta) + + if self.gamma == 0 and self.t == 1.: + return F.cross_entropy(self.s*output, target) + + if self.t > 1: + h_theta = self.t - 1 + self.t*cos_theta + support_vecs_mask = (1 - index) * \ + torch.lt(torch.masked_select(phi_theta, index).view(-1, 1).repeat(1, h_theta.shape[1]) - cos_theta, 0) + output = torch.where(support_vecs_mask, h_theta, output) + return F.cross_entropy(self.s*output, target) + + return focal_loss(F.cross_entropy(self.s*output, target, reduction='none'), self.gamma) + + +@LOSSFUNC.register_module(module_name="am_softmax_ohem") +class AMSoftmax_OHEM(AbstractLossClass): + """Computes the AM-Softmax loss with cos or arc margin""" + margin_types = ['cos', 'arc'] + + def __init__(self, margin_type='cos', gamma=0., m=0.5, s=30, t=1., ratio=1.): + super(self).__init__() + assert margin_type in AMSoftmaxLoss.margin_types + self.margin_type = margin_type + assert gamma >= 0 + self.gamma = gamma + assert m > 0 + self.m = m + assert s > 0 + self.s = s + self.cos_m = math.cos(m) + self.sin_m = math.sin(m) + self.th = math.cos(math.pi - m) + assert t >= 1 + self.t = t + self.ratio = ratio + + + # ------- online hard example mining -------------------- + def get_subidx(self,x,y,ratio): + num_inst = x.size(0) + num_hns = int(ratio * num_inst) + x_ = x.clone() + inst_losses = th.autograd.Variable(th.zeros(num_inst)).cuda() + + for idx, label in enumerate(y.data): + inst_losses[idx] = -x_.data[idx, label] + + _, idxs = inst_losses.topk(num_hns) + return idxs + + + def forward(self, cos_theta, target): + if self.margin_type == 'cos': + phi_theta = cos_theta - self.m + else: + sine = torch.sqrt(1.0 - torch.pow(cos_theta, 2)) + phi_theta = cos_theta * self.cos_m - sine * self.sin_m #cos(theta+m) + phi_theta = torch.where(cos_theta > self.th, phi_theta, cos_theta - self.sin_m * self.m) + + index = torch.zeros_like(cos_theta, dtype=torch.uint8) + index.scatter_(1, target.data.view(-1, 1), 1) + output = torch.where(index, phi_theta, cos_theta) + + out = F.log_softmax(output,dim=1) + idxs = self.get_subidx(out,target,self.ratio) # select hard examples + + output2 = output.index_select(0, idxs) + target2 = target.index_select(0, idxs) + + if self.gamma == 0 and self.t == 1.: + return F.cross_entropy(self.s*output2, target2) + + if self.t > 1: + h_theta = self.t - 1 + self.t*cos_theta + support_vecs_mask = (1 - index) * \ + torch.lt(torch.masked_select(phi_theta, index).view(-1, 1).repeat(1, h_theta.shape[1]) - cos_theta, 0) + output2 = torch.where(support_vecs_mask, h_theta, output2) + return F.cross_entropy(self.s*output2, target2) + + return focal_loss(F.cross_entropy(self.s*output2, target2, reduction='none'), self.gamma) \ No newline at end of file diff --git a/training/loss/bce_loss.py b/training/loss/bce_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..3641878fbe109fb4247f60d683401ce76a797bf7 --- /dev/null +++ b/training/loss/bce_loss.py @@ -0,0 +1,26 @@ +import torch.nn as nn +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC + + +@LOSSFUNC.register_module(module_name="bce") +class BCELoss(AbstractLossClass): + def __init__(self): + super().__init__() + self.loss_fn = nn.BCELoss() + + def forward(self, inputs, targets): + """ + Computes the bce loss. + + Args: + inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores. + targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices. + + Returns: + A scalar tensor representing the bce loss. + """ + # Compute the bce loss + loss = self.loss_fn(inputs, targets.float()) + + return loss \ No newline at end of file diff --git a/training/loss/capsule_loss.py b/training/loss/capsule_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..df13bb3b768ec56ae4117fe58199461f295618b5 --- /dev/null +++ b/training/loss/capsule_loss.py @@ -0,0 +1,28 @@ +import torch.nn as nn +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC + + +@LOSSFUNC.register_module(module_name="capsule_loss") +class CapsuleLoss(AbstractLossClass): + def __init__(self): + super().__init__() + self.cross_entropy_loss = nn.CrossEntropyLoss() + + def forward(self, inputs, targets): + """ + Computes the capsule loss. + + Args: + inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores. + targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices. + + Returns: + A scalar tensor representing the capsule loss. + """ + # Compute the capsule loss + loss_t = self.cross_entropy_loss(inputs[:,0,:], targets) + + for i in range(inputs.size(1) - 1): + loss_t = loss_t + self.cross_entropy_loss(inputs[:,i+1,:], targets) + return loss_t diff --git a/training/loss/consistency_loss.py b/training/loss/consistency_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..89a7ce69ef5075f6260d5371cf280577a55c2691 --- /dev/null +++ b/training/loss/consistency_loss.py @@ -0,0 +1,54 @@ +import torch.nn as nn +import torch +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC + + +@LOSSFUNC.register_module(module_name="consistency_loss") +class ConsistencyCos(nn.Module): + def __init__(self): + super(ConsistencyCos, self).__init__() + # # CrossEntropy Loss + # weight=torch.Tensor([4.0, 1.0]) + # if torch.cuda.is_available(): + # weight = weight.cuda() + # self.loss_fn = nn.CrossEntropyLoss(weight) + self.loss_fn = nn.CrossEntropyLoss() + self.mse_fn = nn.MSELoss() + + def forward(self, feat, inputs, targets): + feat = nn.functional.normalize(feat, dim=1) + feat_0 = feat[:int(feat.size(0)/2),:] + feat_1 = feat[int(feat.size(0)/2): 2*int(feat.size(0)/2),:] + + cos = torch.einsum('nc,nc->n', [feat_0, feat_1]).unsqueeze(-1) + labels = torch.ones((cos.shape[0],1), dtype=torch.float, requires_grad=False) + if torch.cuda.is_available(): + labels = labels.cuda() + self.consistency_rate = 1.0 + loss = self.consistency_rate * self.mse_fn(cos, labels) + self.loss_fn(inputs, targets) + return loss + +# +##FIXME to be implemented +class ConsistencyL2(nn.Module): + def __init__(self): + super(ConsistencyL2, self).__init__() + self.mse_fn = nn.MSELoss() + + def forward(self, feat): + feat_0 = feat[:int(feat.size(0)/2),:] + feat_1 = feat[int(feat.size(0)/2):,:] + loss = self.mse_fn(feat_0, feat_1) + return loss + +class ConsistencyL1(nn.Module): + def __init__(self): + super(ConsistencyL1, self).__init__() + self.L1_fn = nn.L1Loss() + + def forward(self, feat): + feat_0 = feat[:int(feat.size(0)/2),:] + feat_1 = feat[int(feat.size(0)/2):,:] + loss = self.L1_fn(feat_0, feat_1) + return loss \ No newline at end of file diff --git a/training/loss/contrastive_regularization.py b/training/loss/contrastive_regularization.py new file mode 100644 index 0000000000000000000000000000000000000000..8e5bb7c3fee3dc66f6c2028ea0a1dbffbe25476c --- /dev/null +++ b/training/loss/contrastive_regularization.py @@ -0,0 +1,78 @@ +import random +from collections import defaultdict +import torch +import torch.nn as nn +import torch.nn.functional as F +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC + + +def swap_spe_features(type_list, value_list): + type_list = type_list.cpu().numpy().tolist() + # get index + index_list = list(range(len(type_list))) + + # init a dict, where its key is the type and value is the index + spe_dict = defaultdict(list) + + # do for-loop to get spe dict + for i, one_type in enumerate(type_list): + spe_dict[one_type].append(index_list[i]) + + # shuffle the value list of each key + for keys in spe_dict.keys(): + random.shuffle(spe_dict[keys]) + + # generate a new index list for the value list + new_index_list = [] + for one_type in type_list: + value = spe_dict[one_type].pop() + new_index_list.append(value) + + # swap the value_list by new_index_list + value_list_new = value_list[new_index_list] + + return value_list_new + + +@LOSSFUNC.register_module(module_name="contrastive_regularization") +class ContrastiveLoss(AbstractLossClass): + def __init__(self, margin=1.0): + super().__init__() + self.margin = margin + + def contrastive_loss(self, anchor, positive, negative): + dist_pos = F.pairwise_distance(anchor, positive) + dist_neg = F.pairwise_distance(anchor, negative) + # Compute loss as the distance between anchor and negative minus the distance between anchor and positive + loss = torch.mean(torch.clamp(dist_pos - dist_neg + self.margin, min=0.0)) + return loss + + def forward(self, common, specific, spe_label): + # prepare + bs = common.shape[0] + real_common, fake_common = common.chunk(2) + ### common real + idx_list = list(range(0, bs//2)) + random.shuffle(idx_list) + real_common_anchor = common[idx_list] + ### common fake + idx_list = list(range(bs//2, bs)) + random.shuffle(idx_list) + fake_common_anchor = common[idx_list] + ### specific + specific_anchor = swap_spe_features(spe_label, specific) + real_specific_anchor, fake_specific_anchor = specific_anchor.chunk(2) + real_specific, fake_specific = specific.chunk(2) + + # Compute the contrastive loss of common between real and fake + loss_realcommon = self.contrastive_loss(real_common, real_common_anchor, fake_common_anchor) + loss_fakecommon = self.contrastive_loss(fake_common, fake_common_anchor, real_common_anchor) + + # Comupte the constrastive loss of specific between real and fake + loss_realspecific = self.contrastive_loss(real_specific, real_specific_anchor, fake_specific_anchor) + loss_fakespecific = self.contrastive_loss(fake_specific, fake_specific_anchor, real_specific_anchor) + + # Compute the final loss as the sum of all contrastive losses + loss = loss_realcommon + loss_fakecommon + loss_fakespecific + loss_realspecific + return loss \ No newline at end of file diff --git a/training/loss/cross_entropy_loss.py b/training/loss/cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..efa7123ed0ee0516743fa41d43b53e063c21a460 --- /dev/null +++ b/training/loss/cross_entropy_loss.py @@ -0,0 +1,26 @@ +import torch.nn as nn +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC + + +@LOSSFUNC.register_module(module_name="cross_entropy") +class CrossEntropyLoss(AbstractLossClass): + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, inputs, targets): + """ + Computes the cross-entropy loss. + + Args: + inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores. + targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices. + + Returns: + A scalar tensor representing the cross-entropy loss. + """ + # Compute the cross-entropy loss + loss = self.loss_fn(inputs, targets) + + return loss \ No newline at end of file diff --git a/training/loss/cross_entropy_orth1_loss.py b/training/loss/cross_entropy_orth1_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..eee4c0b28171fb0db0bed74a6d058e6db300d462 --- /dev/null +++ b/training/loss/cross_entropy_orth1_loss.py @@ -0,0 +1,48 @@ +import torch.nn as nn +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC +import torch + +class OrthogonalLoss(nn.Module): + """Orthogonal loss: minimize the squared normalized dot product of two features""" + def __init__(self, eps=1e-8): + super().__init__() + self.eps = eps + + def forward(self, feat1, feat2): + assert feat1.shape == feat2.shape, "Feature shapes must match" + dot_product = torch.sum(feat1 * feat2, dim=1, keepdim=True) # dot product + norm1 = torch.norm(feat1, dim=1, keepdim=True) + self.eps # norm of feature 1 + norm2 = torch.norm(feat2, dim=1, keepdim=True) + self.eps # norm of feature 2 + normalized_dot = dot_product / (norm1 * norm2) # normalized dot product + return torch.mean(normalized_dot **2) # minimize the squared value (target is 0) + + +@LOSSFUNC.register_module(module_name="cross_entropy_orth1") +class CrossEntropyOrth1Loss(AbstractLossClass): + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + self.loss_mse = nn.MSELoss() + self.loss_orth = OrthogonalLoss() + self.eps = 1e-8 + + def forward(self, inputs, targets, feat_clip, feat_s, feat_f): + """ + Computes the cross-entropy loss. + + Args: + inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores. + targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices. + + Returns: + A scalar tensor representing the cross-entropy loss. + """ + # Compute the cross-entropy loss + loss_ce = self.loss_fn(inputs, targets) + # + loss_orth = self.loss_orth(feat_s, feat_f) + # + loss_mse = self.loss_mse(feat_clip, feat_s) * 0.1 + + return loss_ce, loss_orth, loss_mse diff --git a/training/loss/cross_entropy_orth2_loss.py b/training/loss/cross_entropy_orth2_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..ab934c829d66ce05faa87a762ac604a371c9d35c --- /dev/null +++ b/training/loss/cross_entropy_orth2_loss.py @@ -0,0 +1,77 @@ +import torch.nn as nn +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC +import torch + +class OrthogonalLoss(nn.Module): + """Orthogonal loss: minimize the squared normalized dot product of two features""" + def __init__(self, eps=1e-8): + super().__init__() + self.eps = eps + + def forward(self, feat1, feat2): + assert feat1.shape == feat2.shape, "Feature shapes must match" + dot_product = torch.sum(feat1 * feat2, dim=1, keepdim=True) # dot product + norm1 = torch.norm(feat1, dim=1, keepdim=True) + self.eps # norm of feature 1 + norm2 = torch.norm(feat2, dim=1, keepdim=True) + self.eps # norm of feature 2 + normalized_dot = dot_product / (norm1 * norm2) # normalized dot product + return torch.mean(normalized_dot **2) # minimize the squared value (target is 0) + + +# Custom cosine-similarity loss (constrains direction only) +class CosineSimilarityLoss(nn.Module): + def __init__(self, eps=1e-8): + super().__init__() + self.eps = eps # prevent division by zero + + def forward(self, a, b): + """ + Compute directional loss between two feature vectors (1 - cosine similarity) + Args: + a: feature vectors of shape (batch_size, feat_dim) + b: feature vectors of shape (batch_size, feat_dim) + Returns: + Directional loss (scalar); smaller values indicate more consistent directions + """ + # Compute vector dot products + dot_product = torch.sum(a * b, dim=-1) # (batch_size,) + + # Compute vector norms + norm_a = torch.norm(a, dim=-1) # (batch_size,) + norm_b = torch.norm(b, dim=-1) # (batch_size,) + + # Compute cosine similarity (add eps to prevent division by zero) + cos_sim = dot_product / (norm_a * norm_b + self.eps) # (batch_size,) + + # Loss = 1 - mean cosine similarity (target: cos_sim → 1, loss → 0) + return 1 - torch.mean(cos_sim) + + +@LOSSFUNC.register_module(module_name="cross_entropy_orth2") +class CrossEntropyOrth2Loss(AbstractLossClass): + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + self.loss_dir = CosineSimilarityLoss() + self.loss_orth = OrthogonalLoss() + self.eps = 1e-8 + + def forward(self, inputs, targets, feat_clip, feat_s, feat_f): + """ + Computes the cross-entropy loss. + + Args: + inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores. + targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices. + + Returns: + A scalar tensor representing the cross-entropy loss. + """ + # Compute the cross-entropy loss + loss_ce = self.loss_fn(inputs, targets) + # + loss_orth = self.loss_orth(feat_s, feat_f) + # + loss_dir = self.loss_dir(feat_clip, feat_s) * 0.1 + + return loss_ce, loss_orth, loss_dir diff --git a/training/loss/cross_entropy_orth3_loss.py b/training/loss/cross_entropy_orth3_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a045093688a9f2b975309cb6d34c9507f6c6724f --- /dev/null +++ b/training/loss/cross_entropy_orth3_loss.py @@ -0,0 +1,77 @@ +import torch.nn as nn +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC +import torch + +class OrthogonalLoss(nn.Module): + """Orthogonal loss: minimize the squared normalized dot product of two features""" + def __init__(self, eps=1e-8): + super().__init__() + self.eps = eps + + def forward(self, feat1, feat2): + assert feat1.shape == feat2.shape, "Feature shapes must match" + dot_product = torch.sum(feat1 * feat2, dim=1, keepdim=True) # dot product + norm1 = torch.norm(feat1, dim=1, keepdim=True) + self.eps # norm of feature 1 + norm2 = torch.norm(feat2, dim=1, keepdim=True) + self.eps # norm of feature 2 + normalized_dot = dot_product / (norm1 * norm2) # normalized dot product + return torch.mean(normalized_dot **2) # minimize the squared value (target is 0) + + +# Custom cosine-similarity loss (constrains direction only) +class CosineSimilarityLoss(nn.Module): + def __init__(self, eps=1e-8): + super().__init__() + self.eps = eps # prevent division by zero + + def forward(self, a, b): + """ + Compute directional loss between two feature vectors (1 - cosine similarity) + Args: + a: feature vectors of shape (batch_size, feat_dim) + b: feature vectors of shape (batch_size, feat_dim) + Returns: + Directional loss (scalar); smaller values indicate more consistent directions + """ + # Compute vector dot products + dot_product = torch.sum(a * b, dim=-1) # (batch_size,) + + # Compute vector norms + norm_a = torch.norm(a, dim=-1) # (batch_size,) + norm_b = torch.norm(b, dim=-1) # (batch_size,) + + # Compute cosine similarity (add eps to prevent division by zero) + cos_sim = dot_product / (norm_a * norm_b + self.eps) # (batch_size,) + + # Loss = 1 - mean cosine similarity (target: cos_sim → 1, loss → 0) + return 1 - torch.mean(cos_sim) + + +@LOSSFUNC.register_module(module_name="cross_entropy_orth3") +class CrossEntropyOrth3Loss(AbstractLossClass): + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + self.loss_dir = CosineSimilarityLoss() + self.loss_orth = OrthogonalLoss() + self.eps = 1e-8 + + def forward(self, inputs, targets, feat_clip, feat_s, feat_f): + """ + Computes the cross-entropy loss. + + Args: + inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores. + targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices. + + Returns: + A scalar tensor representing the cross-entropy loss. + """ + # Compute the cross-entropy loss + loss_ce = self.loss_fn(inputs, targets) + # + loss_orth = 0 # self.loss_orth(feat_s, feat_f) + # + loss_dir = self.loss_dir(feat_clip, feat_s) + + return loss_ce, loss_orth, loss_dir diff --git a/training/loss/cross_entropy_orth4_loss.py b/training/loss/cross_entropy_orth4_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb8decf6b684bd234689660f8ffea9913cd522a --- /dev/null +++ b/training/loss/cross_entropy_orth4_loss.py @@ -0,0 +1,79 @@ +import torch.nn as nn +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC +import torch + +class OrthogonalLoss(nn.Module): + """Orthogonal loss: minimize the squared normalized dot product of two features""" + def __init__(self, eps=1e-8): + super().__init__() + self.eps = eps + + def forward(self, feat1, feat2): + assert feat1.shape == feat2.shape, "Feature shapes must match" + dot_product = torch.sum(feat1 * feat2, dim=1, keepdim=True) # dot product + norm1 = torch.norm(feat1, dim=1, keepdim=True) + self.eps # norm of feature 1 + norm2 = torch.norm(feat2, dim=1, keepdim=True) + self.eps # norm of feature 2 + normalized_dot = dot_product / (norm1 * norm2) # normalized dot product + return torch.mean(normalized_dot **2) # minimize the squared value (target is 0) + + +# Custom cosine-similarity loss (constrains direction only) +class RelaxedCosineSimilarityLoss(nn.Module): + def __init__(self, tolerance=0.9, eps=1e-8): + super().__init__() + self.tolerance = tolerance # tolerance threshold (between 0 and 1; larger values make the constraint looser) + self.eps = eps + + def forward(self, a, b): + """ + Relaxed directional loss: no penalty when cosine similarity ≥ tolerance; otherwise compute the loss + Args: + a: feature vector (batch_size, feat_dim) + b: feature vector (batch_size, feat_dim) + Returns: + Relaxed directional loss (scalar) + """ + # Compute cosine similarity (consistent with the original implementation) + dot_product = torch.sum(a * b, dim=-1) + norm_a = torch.norm(a, dim=-1) + norm_b = torch.norm(b, dim=-1) + cos_sim = dot_product / (norm_a * norm_b + self.eps) # (batch_size,) + + # Only compute loss when the similarity is below the threshold: max(0, tolerance - cos_sim) + # When cos_sim ≥ tolerance, the loss is 0 (the difference is accepted) + # When cos_sim < tolerance, the loss increases as the difference grows + loss_per_sample = torch.max(torch.zeros_like(cos_sim), self.tolerance - cos_sim) + + # Return the average loss + return torch.mean(loss_per_sample) + + +@LOSSFUNC.register_module(module_name="cross_entropy_orth4") +class CrossEntropyOrth4Loss(AbstractLossClass): + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + self.loss_dir = RelaxedCosineSimilarityLoss(tolerance=0.9) + self.loss_orth = OrthogonalLoss() + self.eps = 1e-8 + + def forward(self, inputs, targets, feat_clip, feat_s, feat_f): + """ + Computes the cross-entropy loss. + + Args: + inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores. + targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices. + + Returns: + A scalar tensor representing the cross-entropy loss. + """ + # Compute the cross-entropy loss + loss_ce = self.loss_fn(inputs, targets) + # + loss_orth = self.loss_orth(feat_s, feat_f) + # + loss_dir = self.loss_dir(feat_clip, feat_s) + + return loss_ce, loss_orth, loss_dir diff --git a/training/loss/cross_entropy_orth_loss.py b/training/loss/cross_entropy_orth_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b040131dcf6aa44f2a73db6402cc511ade1272b5 --- /dev/null +++ b/training/loss/cross_entropy_orth_loss.py @@ -0,0 +1,39 @@ +import torch.nn as nn +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC +import torch + +@LOSSFUNC.register_module(module_name="cross_entropy_orth") +class CrossEntropyOrthLoss(AbstractLossClass): + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + self.eps = 1e-8 + + def forward(self, inputs, targets, feat_s, feat_f): + """ + Computes the cross-entropy loss. + + Args: + inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores. + targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices. + + Returns: + A scalar tensor representing the cross-entropy loss. + """ + # Compute the cross-entropy loss + loss_ce = self.loss_fn(inputs, targets) + + # Orthogonal loss + # Compute the feature dot product for each sample (batch_size, 1) + dot_product = torch.sum(feat_s * feat_f, dim=1, keepdim=True) + + # For stable training, divide by the product of feature norms (normalized dot product) + norm1 = torch.norm(feat_s, dim=1, keepdim=True) + self.eps # (batch_size, 1) + norm2 = torch.norm(feat_f, dim=1, keepdim=True) + self.eps # (batch_size, 1) + normalized_dot = dot_product / (norm1 * norm2) # normalized range [-1, 1] + + # Minimize the squared normalized dot product (target is 0) + loss_orth = torch.mean(normalized_dot **2) + + return loss_ce, loss_orth \ No newline at end of file diff --git a/training/loss/id_loss.py b/training/loss/id_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..43cc0a12ed3c7b5aad53bb7caec82c01d4fa29aa --- /dev/null +++ b/training/loss/id_loss.py @@ -0,0 +1,16 @@ +import torch +import torch.nn as nn +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC + +@LOSSFUNC.register_module(module_name="id_loss") +class IDLoss(AbstractLossClass): + def __init__(self, margin=0.5): + super().__init__() + self.cosine_similarity = nn.CosineSimilarity(dim=1, eps=1e-6) + self.margin = margin + + def forward(self, x1, x2): + cosine_similarity = self.cosine_similarity(x1, x2) + theta = torch.acos(cosine_similarity) + return 1 - torch.cos(theta + self.margin) \ No newline at end of file diff --git a/training/loss/js_loss.py b/training/loss/js_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..ccdce81ef8eb4e8362f13e305ccb2255fae42bc4 --- /dev/null +++ b/training/loss/js_loss.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC + + +@LOSSFUNC.register_module(module_name="jsloss") +class JS_Loss(AbstractLossClass): + def __init__(self): + super().__init__() + + def forward(self, inputs, targets): + """ + Computes the Jensen-Shannon divergence loss. + """ + # Compute the probability distributions + inputs_prob = F.softmax(inputs, dim=1) + targets_prob = F.softmax(targets, dim=1) + + # Compute the average probability distribution + avg_prob = (inputs_prob + targets_prob) / 2 + + # Compute the KL divergence component for each distribution + kl_div_loss = nn.KLDivLoss(reduction='batchmean') + kl_inputs = kl_div_loss(inputs_prob.log(), avg_prob) + kl_targets = kl_div_loss(targets_prob.log(), avg_prob) + + # Compute the Jensen-Shannon divergence + loss = 0.5 * (kl_inputs + kl_targets) + + return loss \ No newline at end of file diff --git a/training/loss/l1_loss.py b/training/loss/l1_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f2bfdedb628c802ef8ef8ebe977ad5c9a3ce1a37 --- /dev/null +++ b/training/loss/l1_loss.py @@ -0,0 +1,19 @@ +import torch.nn as nn +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC + + +@LOSSFUNC.register_module(module_name="l1loss") +class L1Loss(AbstractLossClass): + def __init__(self): + super().__init__() + self.loss_fn = nn.L1Loss() + + def forward(self, inputs, targets): + """ + Computes the l1 loss. + """ + # Compute the l1 loss + loss = self.loss_fn(inputs, targets) + + return loss \ No newline at end of file diff --git a/training/loss/loss_supcon.py b/training/loss/loss_supcon.py new file mode 100644 index 0000000000000000000000000000000000000000..41be243e2a7a37d5872fc3727d872877ac885506 --- /dev/null +++ b/training/loss/loss_supcon.py @@ -0,0 +1,89 @@ + +class SupConLoss(AbstractLossClass): + """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. + It also supports the unsupervised contrastive loss in SimCLR""" + def __init__(self, temperature=0.07, contrast_mode='all', + base_temperature=0.07): + super().__init__() + self.temperature = temperature + self.contrast_mode = contrast_mode + self.base_temperature = base_temperature + + def forward(self, features, labels=None, mask=None): + """Compute loss for model. If both `labels` and `mask` are None, + it degenerates to SimCLR unsupervised loss: + https://arxiv.org/pdf/2002.05709.pdf + + Args: + features: hidden vector of shape [bsz, n_views, ...]. + labels: ground truth of shape [bsz]. + mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j + has the same class as sample i. Can be asymmetric. + Returns: + A loss scalar. + """ + device = (torch.device("cuda") if features.is_cuda else torch.device("cpu")) + + if len(features.shape) == 2: + features = torch.unsqueeze(features, dim=1) + # raise ValueError('`features` needs to be [bsz, n_views, ...],' + # 'at least 3 dimensions are required') + if len(features.shape) > 3: + features = features.view(features.shape[0], features.shape[1], -1) + + batch_size = features.shape[0] + if labels is not None and mask is not None: + raise ValueError('Cannot define both `labels` and `mask`') + elif labels is None and mask is None: + mask = torch.eye(batch_size, dtype=torch.float32).to(device) + elif labels is not None: + labels = labels.contiguous().view(-1, 1) + if labels.shape[0] != batch_size: + raise ValueError('Num of labels does not match num of features') + mask = torch.eq(labels, labels.T).to(device) + else: + mask = mask.to(device) + + contrast_count = features.shape[1] + contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) + if self.contrast_mode == 'one': + anchor_feature = features[:, 0] + anchor_count = 1 + elif self.contrast_mode == 'all': + anchor_feature = contrast_feature + anchor_count = contrast_count + else: + raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) + + # compute logits + anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature) # [B,D] * [D,B] / 0.07 ==> [B,B] + # for numerical stability + logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) + logits = anchor_dot_contrast - logits_max.detach() + + # tile mask + mask = mask.repeat(anchor_count, contrast_count) + # mask-out self-contrast cases + logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 0) + mask = mask * logits_mask + + # compute log_prob + exp_logits = torch.exp(logits) * logits_mask + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) + + # compute mean of log-likelihood over positive + # modified to handle edge cases when there is no positive pair + # for an anchor point. + # Edge case e.g.:- + # features of shape: [4,1,...] + # labels: [0,1,1,2] + # loss before mean: [nan, ..., ..., nan] + mask_pos_pairs = mask.sum(1) + mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs) + mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs + + # loss + loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos + loss = loss.view(anchor_count, batch_size).mean() + + return loss \ No newline at end of file diff --git a/training/loss/patch_consistency_loss.py b/training/loss/patch_consistency_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..5b18016e5de81ddfc034b003db5e60b29470739b --- /dev/null +++ b/training/loss/patch_consistency_loss.py @@ -0,0 +1,76 @@ +import torch +from metrics.registry import LOSSFUNC +from .abstract_loss_func import AbstractLossClass + + +def mahalanobis_distance(values: torch.Tensor, mean: torch.Tensor, inv_covariance: torch.Tensor) -> torch.Tensor: + """Compute the batched mahalanobis distance. + + values is a batch of feature vectors. + mean is either the mean of the distribution to compare, or a second + batch of feature vectors. + inv_covariance is the inverse covariance of the target distribution. + """ + assert values.dim() == 2 + assert 1 <= mean.dim() <= 2 + assert inv_covariance.dim() == 2 + assert values.shape[1] == mean.shape[-1] + assert mean.shape[-1] == inv_covariance.shape[0] + assert inv_covariance.shape[0] == inv_covariance.shape[1] + + if mean.dim() == 1: # Distribution mean. + mean = mean.unsqueeze(0) + x_mu = values - mean # batch x features + # Same as dist = x_mu.t() * inv_covariance * x_mu batch wise + dist = torch.einsum("im,mn,in->i", x_mu, inv_covariance, x_mu) + + return dist.sqrt() + + +@LOSSFUNC.register_module(module_name="patch_consistency_loss") +class PatchConsistencyLoss(AbstractLossClass): + def __init__(self, c_real, c_fake, c_cross): + super().__init__() + self.c_real = c_real + self.c_fake = c_fake + self.c_cross = c_cross + + def forward(self, attention_map_real, attention_map_fake, feature_patch, real_feature_mean, real_inv_covariance, + fake_feature_mean, fake_inv_covariance, labels): + # calculate mahalanobis distance + B, H, W, C = feature_patch.size() + dist_real = mahalanobis_distance(feature_patch.reshape(B * H * W, C), real_feature_mean.cuda(), + real_inv_covariance.cuda()) + dist_fake = mahalanobis_distance(feature_patch.reshape(B * H * W, C), fake_feature_mean.cuda(), + fake_inv_covariance.cuda()) + fake_indices = torch.where(labels == 1.0)[0] + index_map = torch.relu(dist_real - dist_fake).reshape((B, H, W))[fake_indices, :] + + # loss for real samples + if attention_map_real.shape[0] == 0: + loss_real = 0 + else: + B, PP, PP = attention_map_real.shape + c_matrix = (1 - self.c_real) * torch.eye(PP).cuda() + self.c_real * torch.ones(PP).cuda() + c_matrix = c_matrix.expand(B, -1, -1) + loss_real = torch.sum(torch.abs(attention_map_real - c_matrix)) / (B * (PP * PP - PP)) + + if attention_map_fake.shape[0] == 0: + loss_fake = 0 + else: + B, PP, PP = attention_map_fake.shape + c_matrix = [] + for b in range(B): + fake_indices = torch.where(index_map[b].reshape(-1) > 0)[0] + real_indices = torch.where(index_map[b].reshape(-1) <= 0)[0] + tmp = torch.zeros((PP, PP)).cuda() + self.c_cross + for i in fake_indices: + tmp[i, fake_indices] = self.c_fake + for i in real_indices: + tmp[i, real_indices] = self.c_real + c_matrix.append(tmp) + + c_matrix = torch.stack(c_matrix).cuda() + loss_fake = torch.sum(torch.abs(attention_map_fake - c_matrix)) / (B * (PP * PP - PP)) + + return loss_real + loss_fake diff --git a/training/loss/region_independent_loss.py b/training/loss/region_independent_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..71905d9aa493b9e7829eda287aed4a7e0bf52d5d --- /dev/null +++ b/training/loss/region_independent_loss.py @@ -0,0 +1,56 @@ +import torch +import torch.nn.functional as F +from detectors.multi_attention_detector import AttentionPooling +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC + + +@LOSSFUNC.register_module(module_name="region_independent_loss") +class RegionIndependentLoss(AbstractLossClass): + def __init__(self, M, N, alpha, alpha_decay, decay_batch, inter_margin, intra_margin): + super().__init__() + feature_centers = torch.zeros(M, N) + self.register_buffer("feature_centers", + feature_centers.cuda() if torch.cuda.is_available() else feature_centers) + self.alpha = alpha + self.alpha_decay = alpha_decay + self.decay_batch = decay_batch + self.batch_cnt = 0 + self.inter_margin = inter_margin + intra_margin = torch.Tensor(intra_margin) + self.register_buffer("intra_margin", intra_margin.cuda() if torch.cuda.is_available() else intra_margin) + self.atp = AttentionPooling() + + def forward(self, feature_maps_d, attention_maps, labels): + B, N, H, W = feature_maps_d.size() + B, M, AH, AW = attention_maps.size() + if AH != H or AW != W: + attention_maps = F.interpolate(attention_maps, (H, W), mode='bilinear', align_corners=True) + feature_matrix = self.atp(feature_maps_d, attention_maps) + + # Calculate new feature centers. P.s., I don't know why to use no_grad() and detach() for so many times. + feature_centers = self.feature_centers.detach() + new_feature_centers = feature_centers + self.alpha * torch.mean(feature_matrix - feature_centers, dim=0) + new_feature_centers = new_feature_centers.detach() + with torch.no_grad(): + self.feature_centers = new_feature_centers + + # Calculate intra-class loss + intra_margins = torch.gather(self.intra_margin.repeat(B, 1), dim=1, index=labels.unsqueeze(1)) + intra_class_loss = torch.mean(F.relu(torch.norm(feature_matrix - new_feature_centers, dim=-1) - intra_margins)) + + # Calculate inter-class loss + inter_class_loss = 0 + for i in range(M): + for j in range(i + 1, M): + inter_class_loss += F.relu( + self.inter_margin - torch.dist(new_feature_centers[i], new_feature_centers[j]), inplace=False) + inter_class_loss = inter_class_loss / M / self.alpha + + # Count batch, this is used to simulate epoch, since alpha cannot be modified based on epoch due to code + # structure. self.alpha should be modified every N batch. + self.batch_cnt += 1 + if self.batch_cnt % self.decay_batch == 0: + self.alpha *= self.alpha_decay + + return inter_class_loss + intra_class_loss diff --git a/training/loss/supercontrast_cls_loss.py b/training/loss/supercontrast_cls_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..60c42476032055c2d1f1b90a406a6672fd051ffb --- /dev/null +++ b/training/loss/supercontrast_cls_loss.py @@ -0,0 +1,108 @@ +""" +Author: Yonglong Tian (yonglong@mit.edu) +Date: May 07, 2020 +""" +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC + + +@LOSSFUNC.register_module(module_name="supcon_cls") +class SupConClsLoss(AbstractLossClass): + """ + Refer to https://github.com/HobbitLong/SupContrast + Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. + It also supports the unsupervised contrastive loss in SimCLR + """ + def __init__(self, temperature=0.07, contrast_mode="all", base_temperature=0.07): + super(SupConClsLoss, self).__init__() + self.temperature = temperature + self.contrast_mode = contrast_mode + self.base_temperature = base_temperature + + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, features, pred, labels=None): + """ + Compute loss for model. If `labels` is None, + it degenerates to SimCLR unsupervised loss: + https://arxiv.org/pdf/2002.05709.pdf + + Args: + features: hidden vector of shape [bsz, n_views, ...]. + labels: ground truth of shape [bsz]. + Returns: + A loss scalar. + """ + device = (torch.device("cuda") if features.is_cuda else torch.device("cpu")) + + # cls loss + loss_cls = self.loss_fn(pred, labels) + + if len(features.shape) == 2: + features = torch.unsqueeze(features, dim=1) + # raise ValueError('`features` needs to be [bsz, n_views, ...],' + # 'at least 3 dimensions are required') + if len(features.shape) > 3: + features = features.view(features.shape[0], features.shape[1], -1)# [B, N_view, D] + + features = F.normalize(features, dim=-1) # Normalize each sample individually so that each sample has unit norm + + batch_size = features.shape[0] + if labels is not None: + labels = labels.contiguous().view(-1, 1) # [B, 1] + if labels.shape[0] != batch_size: + print("labels:", labels.shape) + print("batch_size:", batch_size) + raise ValueError("Num of labels does not match num of features") + mask = torch.eq(labels, labels.T).float().to(device) # Construct a B×B mask matrix and mark positions with the same label as 1 + else: + mask = torch.eye(batch_size, dtype=torch.float32).to(device) + + contrast_count = features.shape[1] # 1 + contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) # [B, D] + if self.contrast_mode == "one": + anchor_feature = features[:, 0] + anchor_count = 1 + elif self.contrast_mode == "all": + anchor_feature = contrast_feature # [B, D] + anchor_count = contrast_count # 1 + else: + raise ValueError(f"Unknown mode: {self.contrast_mode}") + + # compute logits + anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature) # [B,D] * [D,B] / 0.07 ==> [B,B] + # for numerical stability + logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) + logits = anchor_dot_contrast - logits_max.detach() + + # tile mask + mask = mask.repeat(anchor_count, contrast_count) + # mask-out self-contrast cases + logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 0) # Mask out diagonal elements + mask = mask * logits_mask + + # compute log_prob + exp_logits = torch.exp(logits) * logits_mask + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) + + # compute mean of log-likelihood over positive + # modified to handle edge cases when there is no positive pair + # for an anchor point. + # Edge case e.g.:- + # features of shape: [4,1,...] + # labels: [0,1,1,2] + # loss before mean: [nan, ..., ..., nan] + mask_pos_pairs = mask.sum(1) + mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs) + mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs + + # loss + loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos + loss = loss.view(anchor_count, batch_size).mean() + + return 0.01 * loss + loss_cls \ No newline at end of file diff --git a/training/loss/supercontrast_loss.py b/training/loss/supercontrast_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..fb719a025a5077f3b032a31cc307626ca84c220d --- /dev/null +++ b/training/loss/supercontrast_loss.py @@ -0,0 +1,101 @@ +""" +Author: Yonglong Tian (yonglong@mit.edu) +Date: May 07, 2020 +""" +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC + + +@LOSSFUNC.register_module(module_name="supcon") +class SupConLoss(AbstractLossClass): + """ + Refer to https://github.com/HobbitLong/SupContrast + Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. + It also supports the unsupervised contrastive loss in SimCLR + """ + def __init__(self, temperature=0.07, contrast_mode="all", base_temperature=0.07): + super(SupConLoss, self).__init__() + self.temperature = temperature + self.contrast_mode = contrast_mode + self.base_temperature = base_temperature + + def forward(self, features, labels=None): + """ + Compute loss for model. If `labels` is None, + it degenerates to SimCLR unsupervised loss: + https://arxiv.org/pdf/2002.05709.pdf + + Args: + features: hidden vector of shape [bsz, n_views, ...]. + labels: ground truth of shape [bsz]. + Returns: + A loss scalar. + """ + device = (torch.device("cuda") if features.is_cuda else torch.device("cpu")) + + if len(features.shape) == 2: + features = torch.unsqueeze(features, dim=1) + # raise ValueError('`features` needs to be [bsz, n_views, ...],' + # 'at least 3 dimensions are required') + if len(features.shape) > 3: + features = features.view(features.shape[0], features.shape[1], -1)# [B, N_view, D] + + features = F.normalize(features, dim=-1) # Normalize each sample individually so that each sample has unit norm + + batch_size = features.shape[0] + if labels is not None: + labels = labels.contiguous().view(-1, 1) # [B, 1] + if labels.shape[0] != batch_size: + raise ValueError("Num of labels does not match num of features") + mask = torch.eq(labels, labels.T).float().to(device) # Construct a B×B mask matrix and mark positions with the same label as 1 + else: + mask = torch.eye(batch_size, dtype=torch.float32).to(device) + + contrast_count = features.shape[1] # 1 + contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) # [B, D] + if self.contrast_mode == "one": + anchor_feature = features[:, 0] + anchor_count = 1 + elif self.contrast_mode == "all": + anchor_feature = contrast_feature # [B, D] + anchor_count = contrast_count # 1 + else: + raise ValueError(f"Unknown mode: {self.contrast_mode}") + + # compute logits + anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature) # [B,D] * [D,B] / 0.07 ==> [B,B] + # for numerical stability + logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) + logits = anchor_dot_contrast - logits_max.detach() + + # tile mask + mask = mask.repeat(anchor_count, contrast_count) + # mask-out self-contrast cases + logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 0) # Mask out diagonal elements + mask = mask * logits_mask + + # compute log_prob + exp_logits = torch.exp(logits) * logits_mask + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) + + # compute mean of log-likelihood over positive + # modified to handle edge cases when there is no positive pair + # for an anchor point. + # Edge case e.g.:- + # features of shape: [4,1,...] + # labels: [0,1,1,2] + # loss before mean: [nan, ..., ..., nan] + mask_pos_pairs = mask.sum(1) + mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs) + mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs + + # loss + loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos + loss = loss.view(anchor_count, batch_size).mean() + + return loss \ No newline at end of file diff --git a/training/loss/vgg_loss.py b/training/loss/vgg_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce09b72e11bb9474fe7e466547a1486045925b8 --- /dev/null +++ b/training/loss/vgg_loss.py @@ -0,0 +1,152 @@ +"""A VGG-based perceptual loss function for PyTorch.""" + +import torch +from torch import nn +from torch.nn import functional as F +from torchvision import models, transforms +import torch +import torch.nn as nn +from .abstract_loss_func import AbstractLossClass +from metrics.registry import LOSSFUNC + + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +class Lambda(nn.Module): + """Wraps a callable in an :class:`nn.Module` without registering it.""" + + def __init__(self, func): + super().__init__() + object.__setattr__(self, 'forward', func) + + def extra_repr(self): + return getattr(self.forward, '__name__', type(self.forward).__name__) + '()' + + +class WeightedLoss(nn.ModuleList): + """A weighted combination of multiple loss functions.""" + + def __init__(self, losses, weights, verbose=False): + super().__init__() + for loss in losses: + self.append(loss if isinstance(loss, nn.Module) else Lambda(loss)) + self.weights = weights + self.verbose = verbose + + def _print_losses(self, losses): + for i, loss in enumerate(losses): + print(f'({i}) {type(self[i]).__name__}: {loss.item()}') + + def forward(self, *args, **kwargs): + losses = [] + for loss, weight in zip(self, self.weights): + losses.append(loss(*args, **kwargs) * weight) + if self.verbose: + self._print_losses(losses) + return sum(losses) + + +class TVLoss(nn.Module): + """Total variation loss (Lp penalty on image gradient magnitude). + The input must be 4D. If a target (second parameter) is passed in, it is + ignored. + ``p=1`` yields the vectorial total variation norm. It is a generalization + of the originally proposed (isotropic) 2D total variation norm (see + (see https://en.wikipedia.org/wiki/Total_variation_denoising) for color + images. On images with a single channel it is equal to the 2D TV norm. + ``p=2`` yields a variant that is often used for smoothing out noise in + reconstructions of images from neural network feature maps (see Mahendran + and Vevaldi, "Understanding Deep Image Representations by Inverting + Them", https://arxiv.org/abs/1412.0035) + :attr:`reduction` can be set to ``'mean'``, ``'sum'``, or ``'none'`` + similarly to the loss functions in :mod:`torch.nn`. The default is + ``'mean'``. + """ + + def __init__(self, p, reduction='mean', eps=1e-8): + super().__init__() + if p not in {1, 2}: + raise ValueError('p must be 1 or 2') + if reduction not in {'mean', 'sum', 'none'}: + raise ValueError("reduction must be 'mean', 'sum', or 'none'") + self.p = p + self.reduction = reduction + self.eps = eps + + def forward(self, input, target=None): + input = F.pad(input, (0, 1, 0, 1), 'replicate') + x_diff = input[..., :-1, :-1] - input[..., :-1, 1:] + y_diff = input[..., :-1, :-1] - input[..., 1:, :-1] + diff = x_diff**2 + y_diff**2 + if self.p == 1: + diff = (diff + self.eps).mean(dim=1, keepdims=True).sqrt() + if self.reduction == 'mean': + return diff.mean() + if self.reduction == 'sum': + return diff.sum() + return diff + + +@LOSSFUNC.register_module(module_name="vgg_loss") +class VGGLoss(AbstractLossClass): + """Computes the VGG perceptual loss between two batches of images. + The input and target must be 4D tensors with three channels + ``(B, 3, H, W)`` and must have equivalent shapes. Pixel values should be + normalized to the range 0–1. + The VGG perceptual loss is the mean squared difference between the features + computed for the input and target at layer :attr:`layer` (default 8, or + ``relu2_2``) of the pretrained model specified by :attr:`model` (either + ``'vgg16'`` (default) or ``'vgg19'``). + If :attr:`shift` is nonzero, a random shift of at most :attr:`shift` + pixels in both height and width will be applied to all images in the input + and target. The shift will only be applied when the loss function is in + training mode, and will not be applied if a precomputed feature map is + supplied as the target. + :attr:`reduction` can be set to ``'mean'``, ``'sum'``, or ``'none'`` + similarly to the loss functions in :mod:`torch.nn`. The default is + ``'mean'``. + :meth:`get_features()` may be used to precompute the features for the + target, to speed up the case where inputs are compared against the same + target over and over. To use the precomputed features, pass them in as + :attr:`target` and set :attr:`target_is_features` to :code:`True`. + Instances of :class:`VGGLoss` must be manually converted to the same + device and dtype as their inputs. + """ + + models = {'vgg16': models.vgg16, 'vgg19': models.vgg19} + + def __init__(self, model='vgg16', layer=8, shift=0, reduction='mean'): + super().__init__() + self.instancenorm = nn.InstanceNorm2d(512, affine=False) + self.shift = shift + self.reduction = reduction + self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + self.model = self.models[model](pretrained=True).features[:layer+1] + self.model.eval() + self.model.requires_grad_(False) + self.model.to(device) + + def get_features(self, input): + return self.model(self.normalize(input)) + + def train(self, mode=True): + self.training = mode + + def forward(self, input, target, target_is_features=False): + if target_is_features: + input_feats = self.get_features(input) + target_feats = target + else: + sep = input.shape[0] + batch = torch.cat([input, target]) + if self.shift and self.training: + padded = F.pad(batch, [self.shift] * 4, mode='replicate') + batch = transforms.RandomCrop(batch.shape[2:])(padded) + feats = self.get_features(batch) + input_feats, target_feats = feats[:sep], feats[sep:] + # input_feats, target_feats = \ + # self.instancenorm(input_feats), \ + # self.instancenorm(target_feats) + return F.mse_loss(input_feats, target_feats, reduction=self.reduction) \ No newline at end of file diff --git a/training/metrics/__init__.py b/training/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..676145d777810e4a51bdaf59fdec4f5358aae349 --- /dev/null +++ b/training/metrics/__init__.py @@ -0,0 +1,7 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) diff --git a/training/metrics/base_metrics_class.py b/training/metrics/base_metrics_class.py new file mode 100644 index 0000000000000000000000000000000000000000..4ae66c86483962b3106a51466bbb17a4341605f0 --- /dev/null +++ b/training/metrics/base_metrics_class.py @@ -0,0 +1,360 @@ +import numpy as np +from sklearn import metrics +from collections import defaultdict +import torch +import torch.nn as nn +from sklearn.metrics import average_precision_score + +def get_accracy(output, label): + _, prediction = torch.max(output, 1) # argmax + correct = (prediction == label).sum().item() + accuracy = correct / prediction.size(0) + return accuracy + + +def get_prediction(output, label): + prob = nn.functional.softmax(output, dim=1)[:, 1] + prob = prob.view(prob.size(0), 1) + label = label.view(label.size(0), 1) + #print(prob.size(), label.size()) + datas = torch.cat((prob, label.float()), dim=1) + return datas + + +def calculate_acc_for_train(label, output, num_classes): + """ + Compute Accuracy and mAP for a multi-class classification task. + + Args: + label: Ground-truth labels with shape [batch_size], where each element is a class index from 0 to num_classes - 1. + output: Model outputs with shape [batch_size, num_classes], usually logits. + num_classes: Total number of classes + + Returns: + accuracy: Accuracy score + map_score: Mean Average Precision (mAP) + """ + # Compute accuracy score + _, prediction = torch.max(output, 1) # Use the class with the highest probability as the prediction + correct = (prediction == label).sum().item() + accuracy = correct / prediction.size(0) + + # Compute mAP + # Convert outputs to a probability distribution + probs = torch.softmax(output, dim=1) # Apply softmax across the class dimension for multi-class classification + + # Convert to NumPy arrays for sklearn utilities + probs_np = probs.cpu().detach().numpy() + labels_np = label.cpu().detach().numpy() + + # Compute AP for each class + aps = [] + for class_idx in range(num_classes): + # Build binary labels: current class as 1, all others as 0 + binary_labels = (labels_np == class_idx).astype(int) + # Predicted probability for the current class + class_probs = probs_np[:, class_idx] + + # Compute AP for this class + try: + ap = average_precision_score(binary_labels, class_probs) + aps.append(ap) + except ValueError: + print("Error") + aps.append(0.0) + + map_score = np.mean(aps) + + return accuracy, map_score + + +def to_numpy(x): + if isinstance(x, torch.Tensor): + # If the input is a Tensor, detach it first to avoid gradient tracking, then move it to CPU and convert it to NumPy. + return x.detach().cpu().numpy() if x.is_cuda else x.numpy() + elif isinstance(x, np.ndarray): + # If it is already a NumPy array, return it directly. + return x + else: + raise TypeError(f"Unsupported data type: {type(x)},Only torch.Tensor and numpy.ndarray are supported") + + + + +def calculate_acc_for_test(label, output, num_classes): + """ + Compute Accuracy and mAP for a multi-class classification task. + Note: this version assumes `output` is already a probability distribution + (i.e. softmax has already been applied). + + Args: + label: Ground-truth labels with shape [batch_size], where each element is + a class index from 0 to num_classes - 1. + output: Model outputs with shape [batch_size, num_classes], already in + probability form. + num_classes: Total number of classes. + + Returns: + accuracy: Accuracy score. + map_score: Mean Average Precision (mAP). + """ + + # -------------------------- + # 1. Compute accuracy + # -------------------------- + label = to_numpy(label) + output = to_numpy(output) + + prediction = np.argmax(output, axis=1) # Take the index of the class with the highest probability + + # for i,j in zip(label, prediction): + # print(i, j) + + correct = np.sum(prediction == label) + accuracy = correct / len(label) # len(label) is the total number of samples + + # -------------------------- + # 2. Compute mAP + # -------------------------- + aps = [] + for class_idx in range(num_classes): + # Build binary labels: current class as 1, all others as 0 + binary_labels = (label == class_idx).astype(int) + # Predicted probability for the current class + class_probs = output[:, class_idx] + + # Check whether this class has both positive and negative samples to avoid meaningless computation + has_positive = np.any(binary_labels == 1) + has_negative = np.any(binary_labels == 0) + + if not (has_positive and has_negative): + # If only positive or only negative samples exist, AP cannot be computed, so skip this class + if(has_positive): + print(f"Warning: class {class_idx} is missing negative samples, skipping AP computation") + else: + print(f"Warning: class {class_idx} is missing positive samples, skipping AP computation") + continue # Skip directly and exclude it from mAP computation + + # Compute AP while handling possible numerical issues + try: + # Clamp the probability range to improve stability and avoid extreme values + class_probs_clamped = np.clip(class_probs, 1e-8, 1 - 1e-8) + ap = average_precision_score(binary_labels, class_probs_clamped) + aps.append(ap) + except Exception as e: + print(f"Class {class_idx} failed to compute AP: {e}") + continue # Skip if the computation fails + + # Compute mAP; if all classes are skipped, set mAP to 0 + if len(aps) == 0: + map_score = 0.0 + print("Warning: AP cannot be computed for any class, setting mAP to 0") + else: + map_score = np.mean(aps) + + # Compute binary accuracy and mAP + bin_pridiction=np.asarray(prediction, dtype=bool) + bin_lable=np.asarray(label, dtype=bool) + correct=np.sum(bin_pridiction==bin_lable) + bin_acc=correct/len(label) + + bin_class_probs_true=output[:,0] + bin_class_probs_false=1-bin_class_probs_true + has_positive = np.any(bin_lable == True) + has_negative = np.any(bin_lable == False) + if not (has_positive and has_negative): + bin_mAP=0.0 + else: + true_clamped = np.clip(bin_class_probs_true, 1e-8, 1 - 1e-8) + false_clamped=np.clip(bin_class_probs_false, 1e-8, 1 - 1e-8) + true_ap=average_precision_score(~bin_lable, true_clamped) + false_ap=average_precision_score(bin_lable, false_clamped) + #print('True:{t}, False:{f}'.format(t=true_ap,f=false_ap)) + bin_mAP=(true_ap+false_ap)/2 + + return {'acc': accuracy, 'mAP': map_score, 'pred': output, 'label': label, 'bin_acc':bin_acc,'bin_mAP':bin_mAP} + + + +def calculate_metrics_for_train(label, output): + if output.size(1) != 1: + prob = torch.softmax(output, dim=1)[:, 1] + else: + prob = output + + # Accuracy + _, prediction = torch.max(output, 1) + correct = (prediction == label).sum().item() + accuracy = correct / prediction.size(0) + + # Average Precision + y_true = label.cpu().detach().numpy() + y_pred = prob.cpu().detach().numpy() + ap = metrics.average_precision_score(y_true, y_pred) + + # AUC and EER + try: + fpr, tpr, thresholds = metrics.roc_curve(label.squeeze().cpu().numpy(), + prob.squeeze().cpu().numpy(), + pos_label=1) + except: + # for the case when we only have one sample + return None, None, accuracy, ap + + if np.isnan(fpr[0]) or np.isnan(tpr[0]): + # for the case when all the samples within a batch is fake/real + auc, eer = None, None + else: + auc = metrics.auc(fpr, tpr) + fnr = 1 - tpr + eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + + return auc, eer, accuracy, ap + + +# ------------ compute average metrics of batches--------------------- +class Metrics_batch(): + def __init__(self): + self.tprs = [] + self.mean_fpr = np.linspace(0, 1, 100) + self.aucs = [] + self.eers = [] + self.aps = [] + + self.correct = 0 + self.total = 0 + self.losses = [] + + def update(self, label, output): + acc = self._update_acc(label, output) + if output.size(1) == 2: + prob = torch.softmax(output, dim=1)[:, 1] + else: + prob = output + #label = 1-label + #prob = torch.softmax(output, dim=1)[:, 1] + auc, eer = self._update_auc(label, prob) + ap = self._update_ap(label, prob) + + return acc, auc, eer, ap + + def _update_auc(self, lab, prob): + fpr, tpr, thresholds = metrics.roc_curve(lab.squeeze().cpu().numpy(), + prob.squeeze().cpu().numpy(), + pos_label=1) + if np.isnan(fpr[0]) or np.isnan(tpr[0]): + return -1, -1 + + auc = metrics.auc(fpr, tpr) + interp_tpr = np.interp(self.mean_fpr, fpr, tpr) + interp_tpr[0] = 0.0 + self.tprs.append(interp_tpr) + self.aucs.append(auc) + + # return auc + + # EER + fnr = 1 - tpr + eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + self.eers.append(eer) + + return auc, eer + + def _update_acc(self, lab, output): + _, prediction = torch.max(output, 1) # argmax + correct = (prediction == lab).sum().item() + accuracy = correct / prediction.size(0) + # self.accs.append(accuracy) + self.correct = self.correct+correct + self.total = self.total+lab.size(0) + return accuracy + + def _update_ap(self, label, prob): + y_true = label.cpu().detach().numpy() + y_pred = prob.cpu().detach().numpy() + ap = metrics.average_precision_score(y_true,y_pred) + self.aps.append(ap) + + return np.mean(ap) + + def get_mean_metrics(self): + mean_acc, std_acc = self.correct/self.total, 0 + mean_auc, std_auc = self._mean_auc() + mean_err, std_err = np.mean(self.eers), np.std(self.eers) + mean_ap, std_ap = np.mean(self.aps), np.std(self.aps) + + return {'acc':mean_acc, 'auc':mean_auc, 'eer':mean_err, 'ap':mean_ap} + + def _mean_auc(self): + mean_tpr = np.mean(self.tprs, axis=0) + mean_tpr[-1] = 1.0 + mean_auc = metrics.auc(self.mean_fpr, mean_tpr) + std_auc = np.std(self.aucs) + return mean_auc, std_auc + + def clear(self): + self.tprs.clear() + self.aucs.clear() + # self.accs.clear() + self.correct=0 + self.total=0 + self.eers.clear() + self.aps.clear() + self.losses.clear() + + +# ------------ compute average metrics of all data --------------------- +class Metrics_all(): + def __init__(self): + self.probs = [] + self.labels = [] + self.correct = 0 + self.total = 0 + + def store(self, label, output): + prob = torch.softmax(output, dim=1)[:, 1] + _, prediction = torch.max(output, 1) # argmax + correct = (prediction == label).sum().item() + self.correct += correct + self.total += label.size(0) + self.labels.append(label.squeeze().cpu().numpy()) + self.probs.append(prob.squeeze().cpu().numpy()) + + def get_metrics(self): + y_pred = np.concatenate(self.probs) + y_true = np.concatenate(self.labels) + # auc + fpr, tpr, thresholds = metrics.roc_curve(y_true,y_pred,pos_label=1) + auc = metrics.auc(fpr, tpr) + # eer + fnr = 1 - tpr + eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + # ap + ap = metrics.average_precision_score(y_true,y_pred) + # acc + acc = self.correct / self.total + return {'acc':acc, 'auc':auc, 'eer':eer, 'ap':ap} + + def clear(self): + self.probs.clear() + self.labels.clear() + self.correct = 0 + self.total = 0 + + +# only used to record a series of scalar value +class Recorder: + def __init__(self): + self.sum = 0 + self.num = 0 + def update(self, item, num=1): + if item is not None: + self.sum += item * num + self.num += num + def average(self): + if self.num == 0: + return None + return self.sum/self.num + def clear(self): + self.sum = 0 + self.num = 0 diff --git a/training/metrics/registry.py b/training/metrics/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..86e256c18d0ad522de79149a154f676fd0bdb414 --- /dev/null +++ b/training/metrics/registry.py @@ -0,0 +1,20 @@ +class Registry(object): + def __init__(self): + self.data = {} + + def register_module(self, module_name=None): + def _register(cls): + name = module_name + if module_name is None: + name = cls.__name__ + self.data[name] = cls + return cls + return _register + + def __getitem__(self, key): + return self.data[key] + +BACKBONE = Registry() +DETECTOR = Registry() +TRAINER = Registry() +LOSSFUNC = Registry() diff --git a/training/metrics/utils.py b/training/metrics/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c5ea3559e0371d213d229ee07f214bf63e4dde --- /dev/null +++ b/training/metrics/utils.py @@ -0,0 +1,95 @@ +from sklearn import metrics +import numpy as np + + +def parse_metric_for_print(metric_dict): + if metric_dict is None: + return "\n" + str = "\n" + str += "================================ Each dataset best metric ================================ \n" + for key, value in metric_dict.items(): + if key != 'avg': + str= str+ f"| {key}: " + for k,v in value.items(): + str = str + f" {k}={v} " + str= str+ "| \n" + else: + str += "============================================================================================= \n" + str += "================================== Average best metric ====================================== \n" + avg_dict = value + for avg_key, avg_value in avg_dict.items(): + if avg_key == 'dataset_dict': + for key,value in avg_value.items(): + str = str + f"| {key}: {value} | \n" + else: + str = str + f"| avg {avg_key}: {avg_value} | \n" + str += "=============================================================================================" + return str + + +def get_test_metrics(y_pred, y_true, img_names): + + def get_video_metrics(image, pred, label): + result_dict = {} + new_label = [] + new_pred = [] + + # print(len(image)) + # print(pred.shape) + # print(label.shape) + + for item in np.transpose(np.stack((image, pred, label)), (1, 0)): + + s = item[0] + if '\\' in s: + parts = s.split('\\') + else: + parts = s.split('/') + vid_name = parts[-2] # video name + b = parts[-1] + + if vid_name not in result_dict: + result_dict[vid_name] = [] + + result_dict[vid_name].append(item) + image_arr = list(result_dict.values()) + + for video in image_arr: + pred_sum = 0 + label_sum = 0 + leng = 0 + for frame in video: + pred_sum += float(frame[1]) + label_sum += int(frame[2]) + leng += 1 + new_pred.append(pred_sum / leng) + new_label.append(int(label_sum / leng)) + fpr, tpr, thresholds = metrics.roc_curve(new_label, new_pred) + v_auc = metrics.auc(fpr, tpr) + fnr = 1 - tpr + v_eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + return v_auc, v_eer + + y_pred = y_pred.squeeze() + # For UCF, where labels for different manipulations are not consistent. + y_true[y_true >= 1] = 1 + # auc + fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1) + auc = metrics.auc(fpr, tpr) + # eer + fnr = 1 - tpr + eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + # ap + ap = metrics.average_precision_score(y_true, y_pred) + # acc + prediction_class = (y_pred > 0.5).astype(int) + correct = (prediction_class == np.clip(y_true, a_min=0, a_max=1)).sum().item() + acc = correct / len(prediction_class) + if type(img_names[0]) is not list: + # calculate video-level auc for the frame-level methods. + v_auc, _ = get_video_metrics(img_names, y_pred, y_true) + else: + # video-level methods + v_auc=auc + + return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap, 'pred': y_pred, 'video_auc': v_auc, 'label': y_true} diff --git a/training/metrics_retrieval/get_metric.py b/training/metrics_retrieval/get_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..68761b0284fcfb553e04c2c9f2b69cd4ff923a6e --- /dev/null +++ b/training/metrics_retrieval/get_metric.py @@ -0,0 +1,211 @@ +# Get retrieval mAP metrics + +import pickle +import os +import numpy as np +from .utils import * + +def load_pkl_results(pkl_path): + """ + Load a pkl file and return its contents + + Args: + pkl_path: Path to the pkl file + + Returns: + data: A dictionary containing all_predictions, all_labels, all_feats, and metrics + """ + # Check whether the file exists + if not os.path.exists(pkl_path): + raise FileNotFoundError(f"Cannot find pkl file: {pkl_path}") + + # Load the pkl file + try: + with open(pkl_path, 'rb') as f: + data = pickle.load(f) + return data + except Exception as e: + raise RuntimeError(f"Failed to load the pkl file: {str(e)}") + +def print_pkl_info(data): + """Print information about the data stored in the pkl file""" + # Check the required fields in the data + required_keys = ['all_predictions', 'all_labels', 'all_feats', 'metrics'] + for key in required_keys: + if key not in data: + print(f"Warning: missing field in pkl file '{key}'") + + # Print prediction information + if 'all_predictions' in data: + preds = data['all_predictions'] + print(f"\nPredictions (all_predictions):") + print(f" Type: {type(preds)}") + print(f" Shape: {preds.shape}") + print(f" Data type: {preds.dtype}") + print(f" First 3 samples: \n{preds[:3]}") + + # Print label information + if 'all_labels' in data: + labels = data['all_labels'] + print(f"\nLabels (all_labels):") + print(f" Type: {type(labels)}") + print(f" Shape: {labels.shape}") + print(f" Data type: {labels.dtype}") + print(f" First 3 samples: \n{labels[:3]}") + + # Print feature information + if 'all_feats' in data: + feats = data['all_feats'] + print(f"\nFeatures (all_feats):") + print(f" Type: {type(feats)}") + print(f" Shape: {feats.shape}") + print(f" Data type: {feats.dtype}") + print(f" First 5 feature values of the first sample: \n{feats[0, :5]}") + + # Print metric information + if 'metrics' in data: + metrics = data['metrics'] + print(f"\nEvaluation metrics (metrics):") + for k, v in metrics.items(): + print(f" {k}: {v:.6f}" if isinstance(v, float) else f" {k}: {v}") + +def run_retrieval_evaluation(pkl_file_path, query_mode='10_sample_avg', rank_max=1000000, random_seed=42): + """ + Main retrieval evaluation function (supports switching between two query modes) + + Args: + pkl_file_path (str): Path to the pkl result file + query_mode (str): Query mode. Options: '10_sample_avg' (average of 10 random samples per class) or 'all_sample_mean' (average of all samples per class) + rank_max (int): Maximum rank used for CMC computation (defaults to large gallery settings) + random_seed (int): Random seed (keeps results reproducible in 10-sample mode) + """ + # -------------------------- 1. Load data (shared by both code paths)-------------------------- + results_data = load_pkl_results(pkl_file_path) + predictions = results_data['all_predictions'] # (N, M):predicted probabilities (currently unused) + labels = results_data['all_labels'] # (N,):ground-truth labels for all samples + features = results_data['all_feats'] # (N, D):features for all samples (N is the total number of samples and D is the feature dimension) + features=features.reshape(features.shape[0],-1) + metrics = results_data['metrics'] # original evaluation metrics (currently unused) + unique_labels = np.unique(labels) # all unique classes (shared by both modes) + num_classes = len(unique_labels) # Total number of classes + D = features.shape[-1] # feature dimension + query_features = np.zeros((num_classes, D)) # Initialize query features (stores one query vector per class) + query_labels = [] # Initialize query labels (one-to-one with query features) + + + # -------------------------- 2. Generate query features by mode and prepare the gallery -------------------------- + if query_mode == '10_sample_avg': + # Mode 1: randomly select 10 samples from each class, average them as the query, and remove the selected samples from the gallery + exclude_indices = [] # Record sample indices that should be removed from the gallery to avoid self-retrieval + np.random.seed(random_seed) # Fix the random seed for reproducible results + + for i, label in enumerate(unique_labels): + # Find all sample indices for the current class + label_indices = np.where(labels == label)[0] + # If the current class has fewer than 10 samples, use all of them; otherwise use 10 + sample_count = min(10, len(label_indices)) + # Randomly select samples without replacement to avoid duplicates + selected_indices = np.random.choice(label_indices, size=sample_count, replace=False) + + # Record the sample indices to be removed + exclude_indices.extend(selected_indices) + # Average the selected sample features and use them as the query for the current class + selected_feats = features[selected_indices] + query_features[i] = np.mean(selected_feats, axis=0) + # Record the query label for the current class + query_labels.append(label) + + # Prepare the gallery by removing samples used to build the query + mask = np.ones(len(labels), dtype=bool) # Initialize as all True (keep) + mask[exclude_indices] = False # Set selected samples to False (remove) + gallery_features = features[mask] # final gallery features + gallery_labels = labels[mask] # final gallery labels + + print(f"=== Mode: {query_mode} ===") + # print(f"Original sample count: {len(labels)}, gallery size after removing query-related samples: {len(gallery_features)}") + + + elif query_mode == 'all_sample_mean': + # Mode 2: average all samples in each class as the query, and keep all samples in the gallery + # Optional: sort by label to reproduce the ordering logic of the original second code path; this only standardizes order and does not affect results + sorted_indices = np.argsort(labels) + sorted_labels = labels[sorted_indices] + sorted_features = features[sorted_indices] + + for i, label in enumerate(unique_labels): + # Find all indices of the current class in the sorted features + label_indices = np.where(sorted_labels == label)[0] + # Compute the mean feature of all samples in the current class and use it as the query + query_features[i] = np.mean(sorted_features[label_indices], axis=0) + # Record the query label for the current class + query_labels.append(label) + + # Use all sorted samples directly as the gallery without removal to reproduce the original logic + gallery_features = sorted_features + gallery_labels = sorted_labels + + print(f"=== Mode: {query_mode} ===") + # print(f"Total number of classes: {num_classes}; the gallery uses all samples ({len(gallery_features)} total)") + + + else: + raise ValueError(f"query_mode only supports '10_sample_avg' or 'all_sample_mean'. Current input: {query_mode}") + + + # -------------------------- 3. Run retrieval and compute metrics in a unified way (shared by both code paths)-------------------------- + # Convert query labels to a NumPy array for downstream functions + query_labels = np.array(query_labels) + + # Compute the cosine-similarity matrix between queries and gallery samples ((num_classes, gallery_size)) + dist_mat = optimized_cosine_matrix(query_features, gallery_features) + print(f"Similarity matrix shape: {dist_mat.shape} (num_queries: {dist_mat.shape[0]}, gallery_size: {dist_mat.shape[1]})") + + # Compute the CMC curve and AP values + cmc_all, ap_all = retrieval_cmc_ap( + dist_mat=dist_mat, + labels_query=query_labels, + labels_gallery=gallery_labels, + dist_type="cosine", # Always use cosine distance (both code paths use cosine) + rank_max=rank_max + ) + + real_ret_res_h = retrieval_real(dist_mat, query_labels, gallery_labels, dist_type="cosine", reject_real_ratio=0.01, paths=None) + real_ret_res_k = retrieval_real(dist_mat, query_labels, gallery_labels, dist_type="cosine", reject_real_ratio=0.001, paths=None) + + # Print final results + print(f"=== Retrieval evaluation results ===") + # print(f"Query labels list: {query_labels}") + # print(f"AP values for each query: {ap_all}") + print(f"1. Retrieval mAP: {np.mean(ap_all):.4f}") + print(f"2. Recall at a fixed 0.01 false-rejection rate Real:{real_ret_res_h[0]:.6f} Fake:{np.mean(real_ret_res_h[1:]):.6f}") + print(f"3. Recall at a fixed 0.001 false-rejection rate Real:{real_ret_res_k[0]:.6f} Fake:{np.mean(real_ret_res_k[1:]):.6f}") + acc = metrics['acc'] + mAP = metrics['mAP'] + bin_acc=metrics['bin_acc'] + bin_mAP=metrics['bin_mAP'] + ap_all_num = np.mean(ap_all) + print(real_ret_res_h) + ret_1 = np.mean(real_ret_res_h[1:]) + ret_2 = np.mean(real_ret_res_k[1:]) + all_avg = (acc + mAP + ap_all_num + ret_1 + ret_2) / 5 + #all_avg = (acc + mAP + ap_all_num + ret_1) / 4 + # print(f"\033[32m{acc:.4f} {mAP:.4f} {ap_all_num:.4f} {ret_1:.4f} {ret_2:.4f} {all_avg:.4f}\033[0m") # + print(f"\033[32m bin_acc:{bin_acc:.4f} acc:{acc:.4f} bin_maP:{bin_mAP:.4f} mAp:{mAP:.4f} Ret-mAP:{ap_all_num:.4f} 0.01 false-rejection recall:{ret_1:.4f} 0.001 false-rejection recall:{ret_2:.4f} Average:{all_avg:.4f}\033[0m") # + print(f"{bin_acc:.4f} {acc:.4f} {bin_mAP:.4f} {mAP:.4f} {ap_all_num:.4f} {ret_1:.4f} {ret_2:.4f} {all_avg:.4f}") + print() + +# -------------------------- Main function call (switch modes as needed)-------------------------- +if __name__ == "__main__": + RANK_MAX = 10 # Adjust as needed (the original second code path used 10, while the first used 1e6) + + + for PKL_FILE_PATH in [ + + "/your/protocol_2_test.pkl", + "/your/protocol_3_test.pkl", + + ]: + seed = 42 + run_retrieval_evaluation(pkl_file_path=PKL_FILE_PATH, query_mode='10_sample_avg', rank_max=RANK_MAX,random_seed=seed) # all_sample_mean + + diff --git a/training/metrics_retrieval/get_metric_pro4.py b/training/metrics_retrieval/get_metric_pro4.py new file mode 100644 index 0000000000000000000000000000000000000000..41ac520dcdaff7ab217254f72ddd9375aceba562 --- /dev/null +++ b/training/metrics_retrieval/get_metric_pro4.py @@ -0,0 +1,291 @@ +# Get retrieval mAP metrics + +import pickle +import os +import numpy as np +from .utils import * +import yaml +from typing import List, Tuple +import torch +def generate_cls2real(yaml_path: str) -> torch.Tensor: + """ + Minimal version: generate cls2real following the order of label_dict in the YAML file + Index = label index (0,1,2...), value = the corresponding group value in the YAML file (0/1/2/3) + """ + # 1. Read the YAML file and extract label_dict values in order + with open(yaml_path, 'r', encoding='utf-8') as f: + data = yaml.safe_load(f) + label_dict = data.get("label2real_dict", {}) + # 2. Extract group values in dictionary insertion order (Python 3.7+ preserves insertion order by default) + group_values = list(label_dict.values()) + # 3. Convert to a cls2real tensor + cls2real = torch.tensor(group_values, dtype=torch.int64) + return cls2real + +def generate_fake_info(yaml_path: str) -> List[Tuple[str, int]]: + """ + Extract fake-class information from label2real_dict in the YAML file and build fake_info + Return format: [(fake class name, corresponding real label), ...] + """ + # 1. Read the YAML file + with open(yaml_path, 'r', encoding='utf-8') as f: + data = yaml.safe_load(f) + label2real_dict = data.get("label2real_dict", {}) + + # 2. Filter fake classes and build fake_info + fake_info = [] + for name, real_label in label2real_dict.items(): + if "Fake" in name: # Filter fake classes + fake_info.append((name, real_label)) + + return fake_info + +def trans_lable(lables): + tmp_lables=lables.copy() + # 0-3 ->0 4-68->1-65 + mask_0_3 = (lables >= 0) & (lables <= 3) + tmp_lables[mask_0_3] = 0 + + # 3. Rule 2: labels greater than 4 are reduced by 3 (4→1, 5→2, ..., 68→65) + mask_gt4 = lables >= 4 + tmp_lables[mask_gt4] = lables[mask_gt4] - 3 + return tmp_lables + +def load_pkl_results(pkl_path): + """ + Load a pkl file and return its contents + + Args: + pkl_path: Path to the pkl file + + Returns: + data: A dictionary containing all_predictions, all_labels, all_feats, and metrics + """ + # Check whether the file exists + if not os.path.exists(pkl_path): + print(pkl_path) + raise FileNotFoundError(f"Cannot find pkl file: {pkl_path}") + + # Load the pkl file + try: + with open(pkl_path, 'rb') as f: + data = pickle.load(f) + return data + except Exception as e: + raise RuntimeError(f"Failed to load the pkl file: {str(e)}") + +def print_pkl_info(data): + """Print information about the data stored in the pkl file""" + # Check the required fields in the data + required_keys = ['all_predictions', 'all_labels', 'all_feats', 'metrics'] + for key in required_keys: + if key not in data: + print(f"Warning: missing field in pkl file '{key}'") + + # Print prediction information + if 'all_predictions' in data: + preds = data['all_predictions'] + print(f"\nPredictions (all_predictions):") + print(f" Type: {type(preds)}") + print(f" Shape: {preds.shape}") + print(f" Data type: {preds.dtype}") + print(f" First 3 samples: \n{preds[:3]}") + + # Print label information + if 'all_labels' in data: + labels = data['all_labels'] + print(f"\nLabels (all_labels):") + print(f" Type: {type(labels)}") + print(f" Shape: {labels.shape}") + print(f" Data type: {labels.dtype}") + print(f" First 3 samples: \n{labels[:3]}") + + # Print feature information + if 'all_feats' in data: + feats = data['all_feats'] + print(f"\nFeatures (all_feats):") + print(f" Type: {type(feats)}") + print(f" Shape: {feats.shape}") + print(f" Data type: {feats.dtype}") + print(f" First 5 feature values of the first sample: \n{feats[0, :5]}") + + # Print metric information + if 'metrics' in data: + metrics = data['metrics'] + print(f"\nEvaluation metrics (metrics):") + for k, v in metrics.items(): + print(f" {k}: {v:.6f}" if isinstance(v, float) else f" {k}: {v}") + +def run_retrieval_evaluation_p4(pkl_file_path, query_mode='10_sample_avg', rank_max=1000000, random_seed=42,yaml_path=None): + """ + Main retrieval evaluation function (supports switching between two query modes) + + Args: + pkl_file_path (str): Path to the pkl result file + query_mode (str): Query mode. Options: '10_sample_avg' (average of 10 random samples per class) or 'all_sample_mean' (average of all samples per class) + rank_max (int): Maximum rank used for CMC computation (defaults to large gallery settings) + random_seed (int): Random seed (keeps results reproducible in 10-sample mode) + """ + # -------------------------- 1. Load data (shared by both code paths)-------------------------- + results_data = load_pkl_results(pkl_file_path) + predictions = results_data['all_predictions'] # (N, M):predicted probabilities (currently unused) + labels = results_data['all_labels'] # (N,):ground-truth labels for all samples + + + features = results_data['all_feats'] # (N, D):features for all samples (N is the total number of samples and D is the feature dimension) + features=features.reshape(features.shape[0],-1) + metrics = results_data['metrics'] # original evaluation metrics (currently unused) + + print_pkl_info(results_data) + print(labels) + + unique_labels = np.unique(labels) # all unique classes (shared by both modes) + num_classes = len(unique_labels) # Total number of classes + D = features.shape[-1] # feature dimension + query_features = np.zeros((num_classes, D)) # Initialize query features (stores one query vector per class) + query_labels = [] # Initialize query labels (one-to-one with query features) + + + # -------------------------- 2. Generate query features by mode and prepare the gallery -------------------------- + if query_mode == '10_sample_avg': + # Mode 1: randomly select 10 samples from each class, average them as the query, and remove the selected samples from the gallery + exclude_indices = [] # Record sample indices that should be removed from the gallery to avoid self-retrieval + np.random.seed(random_seed) # Fix the random seed for reproducible results + + for i, label in enumerate(unique_labels): + # Find all sample indices for the current class + label_indices = np.where(labels == label)[0] + # If the current class has fewer than 10 samples, use all of them; otherwise use 10 + sample_count = min(10, len(label_indices)) + # Randomly select samples without replacement to avoid duplicates + selected_indices = np.random.choice(label_indices, size=sample_count, replace=False) + + # Record the sample indices to be removed + exclude_indices.extend(selected_indices) + # Average the selected sample features and use them as the query for the current class + selected_feats = features[selected_indices] + query_features[i] = np.mean(selected_feats, axis=0) + # Record the query label for the current class + query_labels.append(label) + + # Prepare the gallery by removing samples used to build the query + mask = np.ones(len(labels), dtype=bool) # Initialize as all True (keep) + mask[exclude_indices] = False # Set selected samples to False (remove) + gallery_features = features[mask] # final gallery features + gallery_labels = labels[mask] # final gallery labels + + print(f"=== Mode: {query_mode} ===") + # print(f"Original sample count: {len(labels)}, gallery size after removing query-related samples: {len(gallery_features)}") + + + elif query_mode == 'all_sample_mean': + # Mode 2: average all samples in each class as the query, and keep all samples in the gallery + # Optional: sort by label to reproduce the ordering logic of the original second code path; this only standardizes order and does not affect results + sorted_indices = np.argsort(labels) + sorted_labels = labels[sorted_indices] + sorted_features = features[sorted_indices] + + for i, label in enumerate(unique_labels): + # Find all indices of the current class in the sorted features + label_indices = np.where(sorted_labels == label)[0] + # Compute the mean feature of all samples in the current class and use it as the query + query_features[i] = np.mean(sorted_features[label_indices], axis=0) + # Record the query label for the current class + query_labels.append(label) + + # Use all sorted samples directly as the gallery without removal to reproduce the original logic + gallery_features = sorted_features + gallery_labels = sorted_labels + + print(f"=== Mode: {query_mode} ===") + # print(f"Total number of classes: {num_classes}; the gallery uses all samples ({len(gallery_features)} total)") + + + else: + raise ValueError(f"query_mode only supports '10_sample_avg' or 'all_sample_mean'. Current input: {query_mode}") + + + # -------------------------- 3. Run retrieval and compute metrics in a unified way (shared by both code paths)-------------------------- + # Convert query labels to a NumPy array for downstream functions + query_labels = np.array(query_labels) + #Process query_features and gallery_labels so that real labels 0, 1, 2, and 3 are merged into one class + D = query_features.shape[1] + # 2. Initialize new_query_features (66×D) + new_query_features = np.zeros((66, D), dtype=query_features.dtype) + # 3. Item 0: mean of the original items 0/1/2/3 + new_query_features[0] = np.mean(query_features[0:4], axis=0) # Take the first four items (0,1,2,3) and average them + # 4. Original items 4–68 (65 items in total) → items 1–65 in the new query + # Original indices: 4,5,...,68 → new indices: 1,2,...,65 + new_query_features[1:] = query_features[4:] # + #Convert labels + new_query_labels=trans_lable(query_labels) + new_query_labels= np.unique(new_query_labels) + new_gallery_labels=trans_lable(gallery_labels) + + # Compute the cosine-similarity matrix between queries and gallery samples ((num_classes, gallery_size)) + dist_mat = optimized_cosine_matrix(new_query_features, gallery_features) + print(f"Similarity matrix shape: {dist_mat.shape} (num_queries: {dist_mat.shape[0]}, gallery_size: {dist_mat.shape[1]})") + + + # Compute the CMC curve and AP values + cmc_all, ap_all = retrieval_cmc_ap( + dist_mat=dist_mat, + labels_query=new_query_labels, + labels_gallery=new_gallery_labels, + dist_type="cosine", # Always use cosine distance (both code paths use cosine) + rank_max=rank_max + ) + dist_mat = optimized_cosine_matrix(query_features, gallery_features) + cls2real=generate_cls2real(yaml_path) + real_ret_res_h = retrieval_real_p4(dist_mat, query_labels, gallery_labels, dist_type="cosine", reject_real_ratio=0.01, paths=None,cls2real=cls2real) + real_ret_res_k = retrieval_real_p4(dist_mat, query_labels, gallery_labels, dist_type="cosine", reject_real_ratio=0.001, paths=None,cls2real=cls2real) + + # Print final results + print(f"=== Retrieval evaluation results ===") + # print(f"Query labels list: {query_labels}") + # print(f"AP values for each query: {ap_all}") + print(f"1. Retrieval mAP: {np.mean(ap_all):.4f}") + print(f"2. Recall at a fixed 0.01 false-rejection rate Real:{real_ret_res_h[0]:.6f} Fake:{np.mean(real_ret_res_h[1:]):.6f}") + print(f"3. Recall at a fixed 0.001 false-rejection rate Real:{real_ret_res_k[0]:.6f} Fake:{np.mean(real_ret_res_k[1:]):.6f}") + acc = metrics['acc'] + mAP = metrics['mAP'] + bin_acc=metrics['bin_acc'] + bin_mAP=metrics['bin_mAP'] + ap_all_num = np.mean(ap_all) + ret_1 = np.mean(real_ret_res_h[4:]) + ret_2 = np.mean(real_ret_res_k[4:]) + all_avg = (acc + mAP + ap_all_num + ret_1 + ret_2) / 5 + #all_avg = (acc + mAP + ap_all_num + ret_1) / 4 + # print(f"\033[32m{acc:.4f} {mAP:.4f} {ap_all_num:.4f} {ret_1:.4f} {ret_2:.4f} {all_avg:.4f}\033[0m") # + print(f"\033[32m bin_acc:{bin_acc:.4f} acc:{acc:.4f} bin_maP:{bin_mAP:.4f} mAp:{mAP:.4f} Ret-mAP:{ap_all_num:.4f} 0.01 false-rejection recall:{ret_1:.4f} 0.001 false-rejection recall:{ret_2:.4f} Average:{all_avg:.4f}\033[0m") # + print(f"{bin_acc:.4f} {acc:.4f} {bin_mAP:.4f} {mAP:.4f} {ap_all_num:.4f} {ret_1:.4f} {ret_2:.4f} {all_avg:.4f}") + print() + fake_info =generate_fake_info(yaml_path) + df = pd.DataFrame({ + "Fake class name": [name for name, _ in fake_info], + "Corresponding real label": [real_label for _, real_label in fake_info], + "ret-map":ap_all[1:], + "ret1 metric": real_ret_res_h[4:], + "ret2 metric": real_ret_res_k[4:] + }) + + # 3.2 Print the formatted table to the console + print("=== Aligned metric table for 65 fake classes ===") + pd.set_option('display.max_rows', None) # show all rows + pd.set_option('display.max_columns', None) # show all columns + pd.set_option('display.width', 1000) # table width + pd.set_option('display.unicode.ambiguous_as_wide', True) + pd.set_option('display.unicode.east_asian_width', True) + print(df) +# -------------------------- Main function call (switch modes as needed)-------------------------- +if __name__ == "__main__": + RANK_MAX = 10 # Adjust as needed (the original second code path used 10, while the first used 1e6) + yaml_path="your.yaml" + for PKL_FILE_PATH in [ + "your.pkl" + ]: + seed = 42 + run_retrieval_evaluation_p4(pkl_file_path=PKL_FILE_PATH, query_mode='10_sample_avg', rank_max=RANK_MAX,random_seed=seed,yaml_path=yaml_path) # all_sample_mean + + + diff --git a/training/metrics_retrieval/utils.py b/training/metrics_retrieval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5bfe97a280e661dbf2c1b4511880a3b9dd17f279 --- /dev/null +++ b/training/metrics_retrieval/utils.py @@ -0,0 +1,652 @@ +import os, pickle +import torch +import numpy as np +import pandas as pd +from tqdm import tqdm +from copy import deepcopy +import matplotlib.pyplot as plt +np.random.seed(1) + +def cosine_similarity(a, b): + dot_product = np.dot(a, b.T) + norm_a = np.linalg.norm(a) + norm_b = np.linalg.norm(b, axis=1) + return dot_product / (norm_a * norm_b) + +def optimized_cosine_matrix(query_features, gallery_features, chunk_size=50000, device='cuda'): + """ Revised cosine similarity matrix computation """ + # Ensure the inputs use floating point types (important!) + query = torch.as_tensor(query_features, dtype=torch.float32) + gallery = torch.as_tensor(gallery_features, dtype=torch.float32) + + # Move tensors to the target device synchronously to avoid asynchronous errors + query = query.to(device) + gallery = gallery.to(device) + + # Safely normalize inputs to avoid zero vectors + def safe_normalize(x): + norm = torch.norm(x, p=2, dim=1, keepdim=True) + return x / torch.where(norm == 0, torch.ones_like(norm), norm) + + query_norm = safe_normalize(query) + gallery_norm = safe_normalize(gallery) + + # Compute in chunks to optimize GPU memory usage + dist_mat = [] + with torch.no_grad(), torch.amp.autocast('cuda'): + for i in range(0, gallery_norm.size(0), chunk_size): + chunk = gallery_norm[i:i+chunk_size] + sim = torch.mm(query_norm, chunk.T) # (n_query, chunk_size) + dist_mat.append(sim.cpu().to(torch.float32)) # Preserve precision + + return torch.cat(dist_mat, dim=1) + +def retrieval_cmc_ap(dist_mat, labels_query, labels_gallery, dist_type="cosine", rank_max=10): + """ + Optimized CMC and AP computation for very large-scale settings (supports 14 × 800k matrices) + + Args: + dist_mat (Tensor): Distance matrix (num_query, num_gallery) + labels_query (Tensor): Query labels (num_query,) + labels_gallery (Tensor): Gallery labels (num_gallery,) + dist_type (str): Distance type ["cosine"|"l2"] + rank_max (int): Maximum ranking depth + + Returns: + cmc (Tensor): CMC curve (rank_max,) + ap_all (np.ndarray): AP values for each query (num_valid_query,) + """ + labels_query = torch.tensor(labels_query) + labels_gallery = torch.tensor(labels_gallery) + + # Basic parameter validation + num_query, num_gallery = dist_mat.shape + rank_max = min(rank_max, num_gallery) + device = dist_mat.device + + if dist_type == "l2": + sorted_indices = torch.argsort(dist_mat, dim=1, descending=False) + else: # cosine + sorted_indices = torch.argsort(dist_mat, dim=1, descending=True) # for cosine distance, larger means closer, so sort in descending order + + # Build the sorted label matrix in batch (num_query, num_gallery) + sorted_labels = labels_gallery[sorted_indices] # Get retrieval result labels ordered by distance + + # Build the match matrix (num_query, num_gallery) + matches = (sorted_labels == labels_query.view(-1, 1)).long() # Set positions with correct labels in the gallery to 1 + + # Filter out invalid queries with no matches + valid_mask = matches.sum(dim=1) > 0 + valid_matches = matches[valid_mask] + if valid_matches.size(0) == 0: + exit() + + # Vectorized CMC computation + cmc = valid_matches.cumsum(dim=1) + cmc = (cmc > 0).float() # Binarize the result + cmc_final = cmc[:, :rank_max].mean(dim=0) # (rank_max,) + + # Vectorized AP computation without Python loops + cum_correct = valid_matches.cumsum(dim=1) # Cumulative number of matches [q_num, g_num] + positions = torch.arange(1, num_gallery+1, device=device).view(1, -1) # [1, g_num] + precisions = cum_correct / positions # Compute precision at every position + # First collect precision values at all matched positions, then sum them and divide by the total number of matches + ap_values = (precisions * valid_matches).sum(dim=1) / valid_matches.sum(dim=1).clamp(min=1e-6) + + return cmc_final.cpu(), ap_values.cpu().numpy() + +def retrieval_real(dist_mat, labels_query, labels_gallery, dist_type="cosine", reject_real_ratio=0.001, paths=None): + + labels_query = torch.tensor(labels_query) + labels_gallery = torch.tensor(labels_gallery) + + # Basic parameter validation + num_query, num_gallery = dist_mat.shape + device = dist_mat.device + # Build the ordered index list + indexs = torch.tensor([i for i in range(len(labels_gallery))]) + + if dist_type == "l2": + sorted_indices = torch.argsort(dist_mat, dim=1, descending=False) + else: # cosine + sorted_indices = torch.argsort(dist_mat, dim=1, descending=True) # for cosine distance, larger means closer, so sort in descending order + + # Build the sorted label matrix in batch (num_query, num_gallery) + sorted_labels = labels_gallery[sorted_indices] # Get retrieval result labels ordered by similarity + sorted_indexs = indexs[sorted_indices] # Get index list ordered by similarity + + # Build the match matrix (num_query, num_gallery), Find indices of all FakeX and Real entries in each row + matches_fake = sorted_labels == labels_query.view(-1, 1) # Set positions with correct labels in the gallery to 1 + matches_real = sorted_labels == torch.zeros(len(labels_query)).view(-1, 1) + + # Plot similarity curves for each algorithm class + # for i in range(num_query): + # if i == 0: continue # Skip real retrieval results + # tensor_fake = dist_mat[i][labels_gallery == i] + # tensor_real = dist_mat[i][labels_gallery == 0] + # print(len(tensor_fake), len(tensor_real)) + # plt.figure(figsize=(12, 4)) + # plt.plot(torch.concat([tensor_fake, tensor_real], axis=0)) + # # tensor_sort_fake, _ = torch.sort(tensor_fake, descending=True) + # # tensor_sort_real, _ = torch.sort(tensor_real, descending=True) + # # plt.plot(torch.concat([tensor_sort_fake, tensor_sort_real], axis=0)) + # plt.grid(True) + # plt.savefig(f'{i}_unsort.png') + # plt.close() + + # Find the position where the 0.001 false rejection threshold is reached + real_num = torch.sum(torch.eq(labels_gallery, 0)) # Number of real samples in each row + cum_real = matches_real.cumsum(dim=1) + reject_k = cum_real == (real_num * reject_real_ratio).view(-1, 1).int() # Apply the 0.001 false rejection constraint + first_occurrence = torch.argmax(reject_k.long(), dim=1) # Position where false rejection is tolerated + + # Find sequence positions where false rejection occurs (handled per fake class) + reject_real_union = set() + for i in range(len(labels_query)): # skip real retrieval + cut_pos = first_occurrence[i]+1 # Include the final false rejection position + cut_labels = sorted_labels[i][:cut_pos] # All labels before the false rejection cutoff + cut_indexs = sorted_indexs[i][:cut_pos] # All indices before the false rejection cutoff + reject_real_union.update(cut_indexs[cut_labels == 0]) # Update the set + # print(i, len(reject_real_union), cut_indexs[cut_labels == 0]) + # print("Class label", i.numpy(), "descending similarity under false rejection:", dist_mat[i][cut_indexs[cut_labels == 0]]) + # print("cut_pos", i, cut_pos) + # print(f"Total false rejection after union across all fake classes at {reject_real_ratio}:", len(reject_real_union) / real_num) + + # Compute recall based on this position + cum_fake = matches_fake.cumsum(dim=1) + recall_nums = cum_fake[torch.arange(cum_fake.size(0)), first_occurrence] + + # print("recall_nums", recall_nums) + # print("cum_fake", cum_fake[:, -1]) + + return (recall_nums / cum_fake[:, -1]).numpy() + +def merge_pkls(pkl_algo): + features_all, paths_all = [], [] + # filter_flag = False + for pkl in pkl_algo: + # if '20240412_weixinkaiping_norm_lmdb.pkl' in pkl: + # filter_flag = True + with open(pkl, 'rb') as f: + data = pickle.load(f) # 'features' 'paths' + features_all.append(data['features']) + paths_all.append(data['paths']) + + # Apply padding if some algorithms have fewer than 10 samples + if sum([len(i) for i in features_all]) < 10: + print("Padding:", pkl_algo) + features_all = features_all + features_all + paths_all = paths_all + paths_all + + features_all = np.concatenate(features_all, axis=0) + paths_all = [element for sublist in paths_all for element in sublist] + + # Apply blacklist filtering to LMDB information for norm data + # if filter_flag: + # mask = np.array([vid not in all_drop_vids for vid in paths_all]) + # features_all = features_all[mask] + # paths_all = [img for img, keep in zip(paths_all, mask) if keep] + + return features_all, paths_all + +def group_average(arr, x, discard_remainder=True): + """ + Group the array into chunks of x elements and compute the average of each group. + + Args: + arr (np.ndarray): Input NumPy array + x (int): Number of elements in each group + discard_remainder (bool): Whether to discard the remaining elements when fewer than x + + Returns: + np.ndarray: Array of group averages + """ + if not isinstance(arr, np.ndarray): + arr = np.array(arr) + + if arr.size == 0: + return np.array([]) + + if x <= 0: + raise ValueError("Group size x must be a positive integer") + + if x > arr.size: + if discard_remainder: + return np.array([]) + else: + return np.array([arr.mean()]) + + # Compute the number of complete groups + n_groups = arr.size // x + full_groups = arr[:n_groups * x].reshape(-1, x) + averages = full_groups.mean(axis=1) + + # Handle the remaining elements + if not discard_remainder and arr.size % x != 0: + remainder = arr[n_groups * x:] + avg_remainder = remainder.mean() + averages = np.append(averages, avg_remainder) + + return averages + +def calculated_final_result(all_pkl): + ''' + Directly compute the final metrics from the pkl dictionary: global metrics + local metrics under 0.0001 false rejection + ''' + result_array = [] + + eval_modes = [2, 3] + for eval_mode in eval_modes: # range(1, 4): # 1:'10Q' 2:'10Q-Avg' 3:'All-Mean' + #### Settings + NUM_PER_CLASS = 10 + + features_lists, paths_lists = [], [] # Store features from each pkl + query_lists, labels_query = [], [] # Store the initial video features + labels_gallery = [] + + query_lenth = [] + #### S1.Load all pkl files and build the query set + for idx, pkl_key in tqdm(enumerate(all_pkl.keys())): + features_all, paths_all = merge_pkls(all_pkl[pkl_key]) + query_lenth.append(len(features_all)) + + if eval_mode == 1: + # Method 1: randomly select X samples as queries + sel_idxs = np.random.choice(len(paths_all), NUM_PER_CLASS, replace=False) + query_lists.extend(deepcopy(features_all[sel_idxs])) # Use deep copy here + # labels_query.extend([0 for _ in range(NUM_PER_CLASS)]) # Set to 0 to perform reverse recall for Real + labels_query.extend([idx for _ in range(NUM_PER_CLASS)]) + features_all = np.delete(features_all, sel_idxs, axis=0) + paths_all = [path for i, path in enumerate(paths_all) if i not in sel_idxs] + + elif eval_mode == 2: + # Method 2: use the center of X queries within a class as the query + sel_idxs = np.random.choice(len(paths_all), NUM_PER_CLASS, replace=False) + query_lists.append(np.mean(deepcopy(features_all[sel_idxs]), axis=0)) # Use deep copy here + labels_query.append(idx) + features_all = np.delete(features_all, sel_idxs, axis=0) + paths_all = [path for i, path in enumerate(paths_all) if i not in sel_idxs] + + elif eval_mode == 3: + # Method 3: use the class centroid as the query (cheating) + labels_query.append(idx) + query_lists.append(np.mean(np.array(features_all), axis=0)) + + # Generate ground-truth labels for the gallery set + labels_gallery.extend([idx for _ in range(len(paths_all))]) + # Save + features_lists.append(features_all) + paths_lists.append(paths_all) + + if eval_mode in [2, 3]: + NUM_PER_CLASS = 1 + + # Merge into one large feature list + features = np.concatenate(features_lists, axis=0) + paths = [element for sublist in paths_lists for element in sublist] + print('Len(gallery):', len(features)) + + #### S2.Generate the distance matrix [Q_num, G_num] + dist_mat = optimized_cosine_matrix(np.array(query_lists), features) + print('Shape(matrix):', dist_mat.shape) + + #### S3.[Global metric] compute CMC and AP + cmc_all, ap_all = retrieval_cmc_ap(dist_mat, labels_query, labels_gallery, dist_type="cosine", rank_max=10) + print("Global metric AP:") + result_array.append(group_average(ap_all, NUM_PER_CLASS)) + + #### S4.[Local metric] compute the relationship between a single fake class and all real samples (recall under 0.001 false rejection) + # recall = retrieval_real(dist_mat, labels_query, labels_gallery, reject_real_ratio=0.001) + # print("Local metric: recall under 0.001 false rejection:") + # result_array.append(recall) + + ### S5.[Local metric] compute the relationship between a single fake class and all real samples (recall under 0.0001 false rejection) + recall = retrieval_real(dist_mat, labels_query, labels_gallery, reject_real_ratio=0.0001) + print("Local metric: recall under 0.0001 false rejection:") + result_array.append(group_average(recall, NUM_PER_CLASS)) + print() + + # Reorder as [global metric][local metric 0.001][local metric 0.0001] for 10Q, 10Q-Avg, and All-Mean + if len(eval_modes) == 2: + rearranged = np.array(result_array)[[i + j * 2 for i in range(2) for j in range(len(result_array) // 2)]] + elif len(eval_modes) == 3: + rearranged = np.array(result_array)[[0,2,4,1,3,5]] + + result_array = np.transpose(rearranged) + + df = pd.DataFrame({ + 'Algorithm': list(all_pkl.keys()), + **{f'Value_{i+1}': result_array[:, i] for i in range(result_array.shape[1])} + }) + + return df, query_lenth + +# Support multiple query strategies +def calculated_final_result_multi_query_(all_pkl, upper_len=10): + ''' + Directly compute the final metrics from the pkl dictionary: global metrics + local metrics under 0.0001 false rejection + ''' + result_array = [] + + for num_query in range(1, upper_len + 1): # Evaluate every query count once + eval_modes = [2] + for eval_mode in eval_modes: # range(1, 4): # 1:'10Q' 2:'10Q-Avg' 3:'All-Mean' + #### Settings + NUM_PER_CLASS = num_query + + features_lists, paths_lists = [], [] # Store features from each pkl + query_lists, labels_query = [], [] # Store the initial video features + labels_gallery = [] + + #### S1.Load all pkl files and build the query set + for idx, pkl_key in tqdm(enumerate(all_pkl.keys())): + features_all, paths_all = merge_pkls(all_pkl[pkl_key]) + + if eval_mode == 1: + # Method 1: randomly select X samples as queries + sel_idxs = np.random.choice(len(paths_all), NUM_PER_CLASS, replace=False) + query_lists.extend(deepcopy(features_all[sel_idxs])) # Use deep copy here + # labels_query.extend([0 for _ in range(NUM_PER_CLASS)]) # Set to 0 to perform reverse recall for Real + labels_query.extend([idx for _ in range(NUM_PER_CLASS)]) + features_all = np.delete(features_all, sel_idxs, axis=0) + paths_all = [path for i, path in enumerate(paths_all) if i not in sel_idxs] + + elif eval_mode == 2: + # Method 2: use the center of X queries within a class as the query + sel_idxs = np.random.choice(len(paths_all), NUM_PER_CLASS, replace=False) + query_lists.append(np.mean(deepcopy(features_all[sel_idxs]), axis=0)) # Use deep copy here + labels_query.append(idx) + features_all = np.delete(features_all, sel_idxs, axis=0) + print("Selected videos:", [paths_all[i] for i in sel_idxs]) + paths_all = [path for i, path in enumerate(paths_all) if i not in sel_idxs] + + elif eval_mode == 3: + # Method 3: use the class centroid as the query (cheating) + labels_query.append(idx) + query_lists.append(np.mean(np.array(features_all), axis=0)) + + # Generate ground-truth labels for the gallery set + labels_gallery.extend([idx for _ in range(len(paths_all))]) + # Save + features_lists.append(features_all) + paths_lists.append(paths_all) + + if eval_mode in [2, 3]: + NUM_PER_CLASS = 1 + + # Merge into one large feature list + features = np.concatenate(features_lists, axis=0) + paths = [element for sublist in paths_lists for element in sublist] + print('Len(gallery):', len(features)) + + #### S2.Generate the distance matrix [Q_num, G_num] + dist_mat = optimized_cosine_matrix(np.array(query_lists), features) + print('Shape(matrix):', dist_mat.shape) + + #### S3.[Global metric] compute CMC and AP + cmc_all, ap_all = retrieval_cmc_ap(dist_mat, labels_query, labels_gallery, dist_type="cosine", rank_max=10) + print("Global metric AP:") + result_array.append(group_average(ap_all, NUM_PER_CLASS)) + + #### S4.[Local metric] compute the relationship between a single fake class and all real samples (recall under 0.001 false rejection) + # recall = retrieval_real(dist_mat, labels_query, labels_gallery, reject_real_ratio=0.001) + # print("Local metric: recall under 0.001 false rejection:") + # result_array.append(recall) + + ### S5.[Local metric] compute the relationship between a single fake class and all real samples (recall under 0.0001 false rejection) + recall = retrieval_real(dist_mat, labels_query, labels_gallery, reject_real_ratio=0.0001) + print("Local metric: recall under 0.0001 false rejection:") + result_array.append(group_average(recall, NUM_PER_CLASS)) + print() + + # Reorder as [global metric][local metric 0.001][local metric 0.0001] for 10Q, 10Q-Avg, and All-Mean + rearranged = np.array(result_array)[1::2] # np.concatenate([np.array(result_array)[::2], np.array(result_array)[1::2]]) + result_array = np.transpose(rearranged) + + df = pd.DataFrame({ + 'Algorithm': list(all_pkl.keys()), + **{f'Value_{i+1}': result_array[:, i] for i in range(result_array.shape[1])} + }) + + return df + +# Visual analysis of bad cases +def calculated_final_result_multi_query(all_pkl, upper_len=10): + ''' + Directly compute the final metrics from the pkl dictionary: global metrics + local metrics under 0.0001 false rejection + ''' + result_array = [] + + for num_query in range(1, upper_len + 1): # Evaluate every query count once + eval_modes = [2] + for eval_mode in eval_modes: # range(1, 4): # 1:'10Q' 2:'10Q-Avg' 3:'All-Mean' + #### Settings + NUM_PER_CLASS = num_query + + features_lists, paths_lists = [], [] # Store features from each pkl + query_lists, labels_query = [], [] # Store the initial video features + labels_gallery = [] + + #### S1.Load all pkl files and build the query set + for idx, pkl_key in tqdm(enumerate(all_pkl.keys())): + features_all, paths_all = merge_pkls(all_pkl[pkl_key]) + + if eval_mode == 1: + # Method 1: randomly select X samples as queries + sel_idxs = np.random.choice(len(paths_all), NUM_PER_CLASS, replace=False) + query_lists.extend(deepcopy(features_all[sel_idxs])) # Use deep copy here + # labels_query.extend([0 for _ in range(NUM_PER_CLASS)]) # Set to 0 to perform reverse recall for Real + labels_query.extend([idx for _ in range(NUM_PER_CLASS)]) + features_all = np.delete(features_all, sel_idxs, axis=0) + paths_all = [path for i, path in enumerate(paths_all) if i not in sel_idxs] + + elif eval_mode == 2: + # Method 2: use the center of X queries within a class as the query + sel_idxs = np.random.choice(len(paths_all), NUM_PER_CLASS, replace=False) + query_lists.append(np.mean(deepcopy(features_all[sel_idxs]), axis=0)) # Use deep copy here + labels_query.append(idx) + features_all = np.delete(features_all, sel_idxs, axis=0) + print("Selected videos:", [paths_all[i] for i in sel_idxs]) + paths_all = [path for i, path in enumerate(paths_all) if i not in sel_idxs] + + elif eval_mode == 3: + # Method 3: use the class centroid as the query (cheating) + labels_query.append(idx) + query_lists.append(np.mean(np.array(features_all), axis=0)) + + # Generate ground-truth labels for the gallery set + labels_gallery.extend([idx for _ in range(len(paths_all))]) + # Save + features_lists.append(features_all) + paths_lists.append(paths_all) + + if eval_mode in [2, 3]: + NUM_PER_CLASS = 1 + + # Merge into one large feature list + features = np.concatenate(features_lists, axis=0) + paths = [element for sublist in paths_lists for element in sublist] + print('Len(gallery):', len(features)) + + #### S2.Generate the distance matrix [Q_num, G_num] + dist_mat = optimized_cosine_matrix(np.array(query_lists), features) + print('Shape(matrix):', dist_mat.shape) + + #### S3.[Global metric] compute CMC and AP + cmc_all, ap_all = retrieval_cmc_ap(dist_mat, labels_query, labels_gallery, dist_type="cosine", rank_max=10) + print("Global metric AP:") + result_array.append(group_average(ap_all, NUM_PER_CLASS)) + + #### S4.[Local metric] compute the relationship between a single fake class and all real samples (recall under 0.001 false rejection) + # recall = retrieval_real(dist_mat, labels_query, labels_gallery, reject_real_ratio=0.001) + # print("Local metric: recall under 0.001 false rejection:") + # result_array.append(recall) + + ### S5.[Local metric] compute the relationship between a single fake class and all real samples (recall under 0.0001 false rejection) + recall = retrieval_real(dist_mat, labels_query, labels_gallery, reject_real_ratio=0.00001) # , paths=paths + print("Local metric: recall under 0.0001 false rejection:") + result_array.append(group_average(recall, NUM_PER_CLASS)) + print() + + # Reorder as [global metric][local metric 0.001][local metric 0.0001] for 10Q, 10Q-Avg, and All-Mean + rearranged = np.array(result_array)[1::2] # np.concatenate([np.array(result_array)[::2], np.array(result_array)[1::2]]) + result_array = np.transpose(rearranged) + + df = pd.DataFrame({ + 'Algorithm': list(all_pkl.keys()), + **{f'Value_{i+1}': result_array[:, i] for i in range(result_array.shape[1])} + }) + + return df +""" +def retrieval_real_p4(dist_mat, labels_query, labels_gallery, dist_type="cosine", reject_real_ratio=0.001, paths=None,cls2real=None): + #Exclude real classes from this computation + labels_query = torch.tensor(labels_query) + labels_gallery = torch.tensor(labels_gallery) + + # Basic parameter validation + num_query, num_gallery = dist_mat.shape + device = dist_mat.device + #Build real labels + if cls2real is None: + raise ValueError("The `cls2real` mapping must be provided (length = num_classes, with each class mapped to its corresponding real label).") + cls2real = cls2real.to(device) + query_real_labels = cls2real[labels_query] + #Filter out real classes + is_fake_query = (labels_query != query_real_labels) + fake_query_mask = is_fake_query + fake_labels_query = labels_query[fake_query_mask] + # n,fake_classes + fake_query_real_labels = query_real_labels[fake_query_mask] + dist_mat = dist_mat[fake_query_mask] + # Build the ordered index list + indexs = torch.tensor([i for i in range(len(labels_gallery))]) + # dist_mat : num_classes,n + if dist_type == "l2": + sorted_indices = torch.argsort(dist_mat, dim=1, descending=False) + else: # cosine + sorted_indices = torch.argsort(dist_mat, dim=1, descending=True) # for cosine distance, larger means closer, so sort in descending order num,n indices in ascending order + + # Build the sorted label matrix in batch (num_query, num_gallery) + sorted_labels = labels_gallery[sorted_indices] # Get retrieval result labels ordered by similarity + sorted_indexs = indexs[sorted_indices] # Get index list ordered by similarity + + # Build the match matrix (num_query, num_gallery), Find indices of all FakeX and Real entries in each row + matches_fake = sorted_labels == fake_labels_query.view(-1, 1) # Set positions with correct labels in the gallery to 1 + matches_real = sorted_labels == fake_query_real_labels.view(-1, 1) #here being equal to 0 means + + # Plot similarity curves for each algorithm class + # for i in range(num_query): + # if i == 0: continue # Skip real retrieval results + # tensor_fake = dist_mat[i][labels_gallery == i] + # tensor_real = dist_mat[i][labels_gallery == 0] + # print(len(tensor_fake), len(tensor_real)) + # plt.figure(figsize=(12, 4)) + # plt.plot(torch.concat([tensor_fake, tensor_real], axis=0)) + # # tensor_sort_fake, _ = torch.sort(tensor_fake, descending=True) + # # tensor_sort_real, _ = torch.sort(tensor_real, descending=True) + # # plt.plot(torch.concat([tensor_sort_fake, tensor_sort_real], axis=0)) + # plt.grid(True) + # plt.savefig(f'{i}_unsort.png') + # plt.close() + + # Find the position where the 0.001 false rejection threshold is reached + real_num = (labels_gallery.view(1, -1) == fake_query_real_labels.view(-1, 1)).sum(dim=1) # Number of real samples in each row + cum_real = matches_real.cumsum(dim=1) + reject_k = cum_real == (real_num * reject_real_ratio).view(-1, 1).int() # Apply the 0.001 false rejection constraint + first_occurrence = torch.argmax(reject_k.long(), dim=1) # Position where false rejection is tolerated + + # Find sequence positions where false rejection occurs (handled per fake class) + reject_real_union = set() + for i in range(len(labels_query)-4): # skip real retrieval + cut_pos = first_occurrence[i]+1 # Include the final false rejection position + cut_labels = sorted_labels[i][:cut_pos] # All labels before the false rejection cutoff + cut_indexs = sorted_indexs[i][:cut_pos] # All indices before the false rejection cutoff + reject_real_union.update(cut_indexs[cut_labels == query_real_labels[i]]) # Update the set + # print(i, len(reject_real_union), cut_indexs[cut_labels == 0]) + # print("Class label", i.numpy(), "descending similarity under false rejection:", dist_mat[i][cut_indexs[cut_labels == 0]]) + # print("cut_pos", i, cut_pos) + # print(f"Total false rejection after union across all fake classes at {reject_real_ratio}:", len(reject_real_union) / real_num) + + # Compute recall based on this position + cum_fake = matches_fake.cumsum(dim=1) + recall_nums = cum_fake[torch.arange(cum_fake.size(0)), first_occurrence] + + # print("recall_nums", recall_nums) + # print("cum_fake", cum_fake[:, -1]) + + return (recall_nums / cum_fake[:, -1]).numpy() + + +""" +def retrieval_real_p4(dist_mat, labels_query, labels_gallery, dist_type="cosine", reject_real_ratio=0.001, paths=None,cls2real=None): + + labels_query = torch.tensor(labels_query) + labels_gallery = torch.tensor(labels_gallery) + + # Basic parameter validation + num_query, num_gallery = dist_mat.shape + device = dist_mat.device + #Build real labels + if cls2real is None: + raise ValueError("The `cls2real` mapping must be provided") + cls2real = cls2real.to(device) + query_real_labels = cls2real[labels_query] + # Build the ordered index list + indexs = torch.tensor([i for i in range(len(labels_gallery))]) + # dist_mat : num_classes,n + if dist_type == "l2": + sorted_indices = torch.argsort(dist_mat, dim=1, descending=False) + else: # cosine + sorted_indices = torch.argsort(dist_mat, dim=1, descending=True) # for cosine distance, larger means closer, so sort in descending order num,n indices in ascending order + + # Build the sorted label matrix in batch (num_query, num_gallery) + sorted_labels = labels_gallery[sorted_indices] # Get retrieval result labels ordered by similarity + sorted_indexs = indexs[sorted_indices] # Get index list ordered by similarity + + # Build the match matrix (num_query, num_gallery), Find indices of all FakeX and Real entries in each row + matches_fake = sorted_labels == labels_query.view(-1, 1) # Set positions with correct labels in the gallery to 1 + matches_real = sorted_labels == query_real_labels.view(-1, 1) #here being equal to 0 means + + # Plot similarity curves for each algorithm class + # for i in range(num_query): + # if i == 0: continue # Skip real retrieval results + # tensor_fake = dist_mat[i][labels_gallery == i] + # tensor_real = dist_mat[i][labels_gallery == 0] + # print(len(tensor_fake), len(tensor_real)) + # plt.figure(figsize=(12, 4)) + # plt.plot(torch.concat([tensor_fake, tensor_real], axis=0)) + # # tensor_sort_fake, _ = torch.sort(tensor_fake, descending=True) + # # tensor_sort_real, _ = torch.sort(tensor_real, descending=True) + # # plt.plot(torch.concat([tensor_sort_fake, tensor_sort_real], axis=0)) + # plt.grid(True) + # plt.savefig(f'{i}_unsort.png') + # plt.close() + + # Find the position where the 0.001 false rejection threshold is reached + real_num = (labels_gallery.view(1, -1) == query_real_labels.view(-1, 1)).sum(dim=1) # Number of real samples in each row + cum_real = matches_real.cumsum(dim=1) + reject_k = cum_real == (torch.clamp_min((real_num * reject_real_ratio).int(), 1)).view(-1, 1).int() # Apply the 0.001 false rejection constraint + first_occurrence = torch.argmax(reject_k.long(), dim=1) # Position where false rejection is tolerated + + # Find sequence positions where false rejection occurs (handled per fake class) + reject_real_union = set() + for i in range(len(labels_query)): # skip real retrieval + cut_pos = first_occurrence[i]+1 # Include the final false rejection position + cut_labels = sorted_labels[i][:cut_pos] # All labels before the false rejection cutoff + cut_indexs = sorted_indexs[i][:cut_pos] # All indices before the false rejection cutoff + reject_real_union.update(cut_indexs[cut_labels == query_real_labels[i]]) # Update the set + # print(i, len(reject_real_union), cut_indexs[cut_labels == 0]) + # print("Class label", i.numpy(), "descending similarity under false rejection:", dist_mat[i][cut_indexs[cut_labels == 0]]) + # print("cut_pos", i, cut_pos) + # print(f"Total false rejection after union across all fake classes at {reject_real_ratio}:", len(reject_real_union) / real_num) + + # Compute recall based on this position + cum_fake = matches_fake.cumsum(dim=1) + recall_nums = cum_fake[torch.arange(cum_fake.size(0)), first_occurrence] + + # print("recall_nums", recall_nums) + # print("cum_fake", cum_fake[:, -1]) + + return (recall_nums / cum_fake[:, -1]).numpy() + diff --git a/training/networks/__init__.py b/training/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0fa30adf6859738c4cbf0be601107f8664446f5e --- /dev/null +++ b/training/networks/__init__.py @@ -0,0 +1,16 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + +from metrics.registry import BACKBONE + +from .xception import Xception +from .mesonet import Meso4, MesoInception4 +from .resnet50 import ResNet50 +from .resnet34 import ResNet34 +from .efficientnetb4 import EfficientNetB4 +from .xception_sladd import Xception_SLADD diff --git a/training/networks/adaface.py b/training/networks/adaface.py new file mode 100644 index 0000000000000000000000000000000000000000..21730fdac4bed733ba2cde876884f5b46d8ea230 --- /dev/null +++ b/training/networks/adaface.py @@ -0,0 +1,414 @@ +from collections import namedtuple +import torch +import torch.nn as nn +from torch.nn import Dropout +from torch.nn import MaxPool2d +from torch.nn import Sequential +from torch.nn import Conv2d, Linear +from torch.nn import BatchNorm1d, BatchNorm2d +from torch.nn import ReLU, Sigmoid +from torch.nn import Module +from torch.nn import PReLU +import os + +def build_model(model_name='ir_50'): + if model_name == 'ir_101': + return IR_101(input_size=(112,112)) + elif model_name == 'ir_50': + return IR_50(input_size=(112,112)) + elif model_name == 'ir_se_50': + return IR_SE_50(input_size=(112,112)) + elif model_name == 'ir_34': + return IR_34(input_size=(112,112)) + elif model_name == 'ir_18': + return IR_18(input_size=(112,112)) + else: + raise ValueError('not a correct model name', model_name) + +def initialize_weights(modules): + """ Weight initilize, conv2d and linear is initialized with kaiming_normal + """ + for m in modules: + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, + mode='fan_out', + nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, + mode='fan_out', + nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + + +class Flatten(Module): + """ Flat tensor + """ + def forward(self, input): + return input.view(input.size(0), -1) + + +class LinearBlock(Module): + """ Convolution block without no-linear activation layer + """ + def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): + super(LinearBlock, self).__init__() + self.conv = Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False) + self.bn = BatchNorm2d(out_c) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class GNAP(Module): + """ Global Norm-Aware Pooling block + """ + def __init__(self, in_c): + super(GNAP, self).__init__() + self.bn1 = BatchNorm2d(in_c, affine=False) + self.pool = nn.AdaptiveAvgPool2d((1, 1)) + self.bn2 = BatchNorm1d(in_c, affine=False) + + def forward(self, x): + x = self.bn1(x) + x_norm = torch.norm(x, 2, 1, True) + x_norm_mean = torch.mean(x_norm) + weight = x_norm_mean / x_norm + x = x * weight + x = self.pool(x) + x = x.view(x.shape[0], -1) + feature = self.bn2(x) + return feature + + +class GDC(Module): + """ Global Depthwise Convolution block + """ + def __init__(self, in_c, embedding_size): + super(GDC, self).__init__() + self.conv_6_dw = LinearBlock(in_c, in_c, + groups=in_c, + kernel=(7, 7), + stride=(1, 1), + padding=(0, 0)) + self.conv_6_flatten = Flatten() + self.linear = Linear(in_c, embedding_size, bias=False) + self.bn = BatchNorm1d(embedding_size, affine=False) + + def forward(self, x): + x = self.conv_6_dw(x) + x = self.conv_6_flatten(x) + x = self.linear(x) + x = self.bn(x) + return x + + +class SEModule(Module): + """ SE block + """ + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc1 = Conv2d(channels, channels // reduction, + kernel_size=1, padding=0, bias=False) + + nn.init.xavier_uniform_(self.fc1.weight.data) + + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d(channels // reduction, channels, + kernel_size=1, padding=0, bias=False) + + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + + return module_input * x + + + +class BasicBlockIR(Module): + """ BasicBlock for IRNet + """ + def __init__(self, in_channel, depth, stride): + super(BasicBlockIR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + BatchNorm2d(depth), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + + return res + shortcut + + +class BottleneckIR(Module): + """ BasicBlock with bottleneck for IRNet + """ + def __init__(self, in_channel, depth, stride): + super(BottleneckIR, self).__init__() + reduction_channel = depth // 4 + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False), + BatchNorm2d(reduction_channel), + PReLU(reduction_channel), + Conv2d(reduction_channel, reduction_channel, (3, 3), (1, 1), 1, bias=False), + BatchNorm2d(reduction_channel), + PReLU(reduction_channel), + Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False), + BatchNorm2d(depth)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + + return res + shortcut + + +class BasicBlockIRSE(BasicBlockIR): + def __init__(self, in_channel, depth, stride): + super(BasicBlockIRSE, self).__init__(in_channel, depth, stride) + self.res_layer.add_module("se_block", SEModule(depth, 16)) + + +class BottleneckIRSE(BottleneckIR): + def __init__(self, in_channel, depth, stride): + super(BottleneckIRSE, self).__init__(in_channel, depth, stride) + self.res_layer.add_module("se_block", SEModule(depth, 16)) + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + '''A named tuple describing a ResNet block.''' + + +def get_block(in_channel, depth, num_units, stride=2): + + return [Bottleneck(in_channel, depth, stride)] +\ + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 18: + blocks = [ + get_block(in_channel=64, depth=64, num_units=2), + get_block(in_channel=64, depth=128, num_units=2), + get_block(in_channel=128, depth=256, num_units=2), + get_block(in_channel=256, depth=512, num_units=2) + ] + elif num_layers == 34: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=6), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=256, num_units=3), + get_block(in_channel=256, depth=512, num_units=8), + get_block(in_channel=512, depth=1024, num_units=36), + get_block(in_channel=1024, depth=2048, num_units=3) + ] + elif num_layers == 200: + blocks = [ + get_block(in_channel=64, depth=256, num_units=3), + get_block(in_channel=256, depth=512, num_units=24), + get_block(in_channel=512, depth=1024, num_units=36), + get_block(in_channel=1024, depth=2048, num_units=3) + ] + + return blocks + + +class Backbone(Module): + def __init__(self, input_size, num_layers, mode='ir'): + """ Args: + input_size: input_size of backbone + num_layers: num_layers of backbone + mode: support ir or irse + """ + super(Backbone, self).__init__() + assert input_size[0] in [112, 224], \ + "input_size should be [112, 112] or [224, 224]" + assert num_layers in [18, 34, 50, 100, 152, 200], \ + "num_layers should be 18, 34, 50, 100 or 152" + assert mode in ['ir', 'ir_se'], \ + "mode should be ir or ir_se" + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), PReLU(64)) + blocks = get_blocks(num_layers) + if num_layers <= 100: + if mode == 'ir': + unit_module = BasicBlockIR + elif mode == 'ir_se': + unit_module = BasicBlockIRSE + output_channel = 512 + else: + if mode == 'ir': + unit_module = BottleneckIR + elif mode == 'ir_se': + unit_module = BottleneckIRSE + output_channel = 2048 + + if input_size[0] == 112: + self.output_layer = Sequential(BatchNorm2d(output_channel), + Dropout(0.4), Flatten(), + Linear(output_channel * 7 * 7, 512), + BatchNorm1d(512, affine=False)) + else: + self.output_layer = Sequential( + BatchNorm2d(output_channel), Dropout(0.4), Flatten(), + Linear(output_channel * 14 * 14, 512), + BatchNorm1d(512, affine=False)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append( + unit_module(bottleneck.in_channel, bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + initialize_weights(self.modules()) + + + def forward(self, x): + + # current code only supports one extra image + # it comes with a extra dimension for number of extra image. We will just squeeze it out for now + x = self.input_layer(x) + + for idx, module in enumerate(self.body): + x = module(x) + + x = self.output_layer(x) + norm = torch.norm(x, 2, 1, True) + output = torch.div(x, norm) + + return output, norm + + + +def IR_18(input_size): + """ Constructs a ir-18 model. + """ + model = Backbone(input_size, 18, 'ir') + + return model + + +def IR_34(input_size): + """ Constructs a ir-34 model. + """ + model = Backbone(input_size, 34, 'ir') + + return model + + +def IR_50(input_size): + """ Constructs a ir-50 model. + """ + model = Backbone(input_size, 50, 'ir') + + return model + + +def IR_101(input_size): + """ Constructs a ir-101 model. + """ + model = Backbone(input_size, 100, 'ir') + + return model + + +def IR_152(input_size): + """ Constructs a ir-152 model. + """ + model = Backbone(input_size, 152, 'ir') + + return model + + +def IR_200(input_size): + """ Constructs a ir-200 model. + """ + model = Backbone(input_size, 200, 'ir') + + return model + + +def IR_SE_50(input_size): + """ Constructs a ir_se-50 model. + """ + model = Backbone(input_size, 50, 'ir_se') + + return model + + +def IR_SE_101(input_size): + """ Constructs a ir_se-101 model. + """ + model = Backbone(input_size, 100, 'ir_se') + + return model + + +def IR_SE_152(input_size): + """ Constructs a ir_se-152 model. + """ + model = Backbone(input_size, 152, 'ir_se') + + return model + + +def IR_SE_200(input_size): + """ Constructs a ir_se-200 model. + """ + model = Backbone(input_size, 200, 'ir_se') + + return model + diff --git a/training/networks/base_backbone.py b/training/networks/base_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..8cbb14439c4c8da93d12edd0f0af442d8f460698 --- /dev/null +++ b/training/networks/base_backbone.py @@ -0,0 +1,32 @@ +import abc +import torch +from typing import Union + +class AbstractBackbone(abc.ABC): + """ + All backbones for detectors should subclass this class. + """ + def __init__(self, config, load_param: Union[bool, str] = False): + """ + config: (dict) + configurations for the model + load_param: (False | True | Path(str)) + False Do not read; True Read the default path; Path Read the required path + """ + pass + + @abc.abstractmethod + def features(self, data_dict: dict) -> torch.tensor: + """ + """ + + @abc.abstractmethod + def classifier(self, features: torch.tensor) -> torch.tensor: + """ + """ + + def init_weights(self, pretrained_path: Union[bool, str]): + """ + This method can be optionally implemented by subclasses. + """ + pass \ No newline at end of file diff --git a/training/networks/cls_hrnet.py b/training/networks/cls_hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2c04a0e05817e7a3d91dd2af895aae7c5dc67403 --- /dev/null +++ b/training/networks/cls_hrnet.py @@ -0,0 +1,569 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 + +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# Modified by Ke Sun (sunk@mail.ustc.edu.cn) +# ------------------------------------------------------------------------------ + +The code is mainly modified from the below link: +https://github.com/HRNet/HRNet-Image-Classification/tree/master +''' + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import logging +import functools + +import numpy as np +from typing import Union + +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F + +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(False) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + nn.BatchNorm2d(num_inchannels[i], + momentum=BN_MOMENTUM), + nn.Upsample(scale_factor=2**(j-i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i-j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM), + nn.ReLU(False))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HighResolutionNet(nn.Module): + + def __init__(self, cfg): + super(HighResolutionNet, self).__init__() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1'] + num_channels = self.stage1_cfg['NUM_CHANNELS'][0] + block = blocks_dict[self.stage1_cfg['BLOCK']] + num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + stage1_out_channel = block.expansion*num_channels + + self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer( + [stage1_out_channel], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + + # Classification Head + self.incre_modules, self.downsamp_modules, \ + self.final_layer = self._make_head(pre_stage_channels) + + self.fc = nn.Linear(2048, 1000) + + + def _make_head(self, pre_stage_channels): + head_block = Bottleneck + head_channels = [32, 64, 128, 256] + + # Increasing the #channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + incre_modules = [] + for i, channels in enumerate(pre_stage_channels): + incre_module = self._make_layer(head_block, + channels, + head_channels[i], + 1, + stride=1) + incre_modules.append(incre_module) + incre_modules = nn.ModuleList(incre_modules) + + # downsampling modules + downsamp_modules = [] + for i in range(len(pre_stage_channels)-1): + in_channels = head_channels[i] * head_block.expansion + out_channels = head_channels[i+1] * head_block.expansion + + downsamp_module = nn.Sequential( + nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1), + nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + downsamp_modules.append(downsamp_module) + downsamp_modules = nn.ModuleList(downsamp_modules) + + final_layer = nn.Sequential( + nn.Conv2d( + in_channels=head_channels[3] * head_block.expansion, + out_channels=2048, + kernel_size=1, + stride=1, + padding=0 + ), + nn.BatchNorm2d(2048, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + return incre_modules, downsamp_modules, final_layer + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + nn.BatchNorm2d( + num_channels_cur_layer[i], momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i+1-num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + # Classification Head + y = self.incre_modules[0](y_list[0]) + for i in range(len(self.downsamp_modules)): + y = self.incre_modules[i+1](y_list[i+1]) + \ + self.downsamp_modules[i](y) + + y = self.final_layer(y) + + if torch._C._get_tracing_state(): + y = y.flatten(start_dim=2).mean(dim=2) + else: + y = F.avg_pool2d(y, kernel_size=y.size() + [2:]).view(y.size(0), -1) + + y = self.fc(y) + + return y + + def features(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + # Upsampling + x0, x1, x2, x3 = y_list + x0_h, x0_w = x0.size(2), x0.size(3) + x1 = F.upsample(x1, size=(x0_h, x0_w), mode='bilinear') + x2 = F.upsample(x2, size=(x0_h, x0_w), mode='bilinear') + x3 = F.upsample(x3, size=(x0_h, x0_w), mode='bilinear') + + x_out = torch.cat([x0, x1, x2, x3], 1) + + #print(x_out.size()) + + return x_out + + def classifier(self, x): + # Classification Head + y = self.incre_modules[0](x[0]) + for i in range(len(self.downsamp_modules)): + y = self.incre_modules[i+1](x[i+1]) + \ + self.downsamp_modules[i](y) + + y = self.final_layer(y) + + if torch._C._get_tracing_state(): + y = y.flatten(start_dim=2).mean(dim=2) + else: + y = F.avg_pool2d(y, kernel_size=y.size() + [2:]).view(y.size(0), -1) + + y = self.fc(y) + +def get_cls_net(config, **kwargs): + model = HighResolutionNet(config, **kwargs) + return model diff --git a/training/networks/efficientnetb4.py b/training/networks/efficientnetb4.py new file mode 100644 index 0000000000000000000000000000000000000000..e2c2ea71ffa9d6c1aa259742d0f10345aff470b6 --- /dev/null +++ b/training/networks/efficientnetb4.py @@ -0,0 +1,112 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 + +The code is for EfficientNetB4 backbone. +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Union +from efficientnet_pytorch import EfficientNet +from metrics.registry import BACKBONE +import os + +@BACKBONE.register_module(module_name="efficientnetb4") +class EfficientNetB4(nn.Module): + def __init__(self, efficientnetb4_config): + super(EfficientNetB4, self).__init__() + """ Constructor + Args: + efficientnetb4_config: configuration file with the dict format + """ + self.num_classes = efficientnetb4_config["num_classes"] + inc = efficientnetb4_config["inc"] + self.dropout = efficientnetb4_config["dropout"] + self.mode = efficientnetb4_config["mode"] + + # Load the EfficientNet-B4 model without pre-trained weights + if efficientnetb4_config['pretrained']: + self.efficientnet = EfficientNet.from_pretrained('efficientnet-b4',weights_path=efficientnetb4_config['pretrained']) # FIXME: load the pretrained weights from online + # self.efficientnet = EfficientNet.from_name('efficientnet-b4') + else: + self.efficientnet = EfficientNet.from_name('efficientnet-b4') + # Modify the first convolutional layer to accept input tensors with 'inc' channels + self.efficientnet._conv_stem = nn.Conv2d(inc, 48, kernel_size=3, stride=2, bias=False) + + # Remove the last layer (the classifier) from the EfficientNet-B4 model + self.efficientnet._fc = nn.Identity() + + if self.dropout: + # Add dropout layer if specified + self.dropout_layer = nn.Dropout(p=self.dropout) + + # Initialize the last_layer layer + self.last_layer = nn.Linear(1792, self.num_classes) + + if self.mode == 'adjust_channel': + self.adjust_channel = nn.Sequential( + nn.Conv2d(1792, 512, 1, 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True), + ) + + def block_part1(self,x): + x = self.efficientnet._swish(self.efficientnet._bn0(self.efficientnet._conv_stem(x))) + # x = self.efficientnet._blocks[0:10](x) + for idx, block in enumerate(self.efficientnet._blocks[:10]): + drop_connect_rate = self.efficientnet._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx+0) / len(self.efficientnet._blocks) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + return x + + def block_part2(self,x): + for idx, block in enumerate(self.efficientnet._blocks[10:22]): + drop_connect_rate = self.efficientnet._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx+10) / len(self.efficientnet._blocks) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + return x + + def block_part3(self,x): + for idx, block in enumerate(self.efficientnet._blocks[22:]): + drop_connect_rate = self.efficientnet._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx+22) / len(self.efficientnet._blocks) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + x = self.efficientnet._swish(self.efficientnet._bn1(self.efficientnet._conv_head(x))) + return x + + + def features(self, x): + # Extract features from the EfficientNet-B4 model + x = self.efficientnet.extract_features(x) + if self.mode == 'adjust_channel': + x = self.adjust_channel(x) + return x + def end_points(self,x): + return self.efficientnet.extract_endpoints(x) + def classifier(self, x): + x = F.adaptive_avg_pool2d(x, (1, 1)) + x = x.view(x.size(0), -1) + + # Apply dropout if specified + if self.dropout: + x = self.dropout_layer(x) + + # Apply last_layer layer + self.last_emb = x + y = self.last_layer(x) + return y + + def forward(self, x): + # Extract features and apply classifier layer + x = self.features(x) + # if False: + # x = F.adaptive_avg_pool2d(x, (1, 1)) + # x = x.view(x.size(0), -1) + x = self.classifier(x) + return x diff --git a/training/networks/iresnet.py b/training/networks/iresnet.py new file mode 100644 index 0000000000000000000000000000000000000000..12a454cb9acc227b8968806cf3c47c4145e6304b --- /dev/null +++ b/training/networks/iresnet.py @@ -0,0 +1,191 @@ +import torch +from torch import nn +import torch.nn.functional as F + +__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] + +def set_requires_grad(model, val): + for p in model.parameters(): + p.requires_grad = val + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + +class IResNet(nn.Module): + def __init__(self, + block, layers, dropout=0, num_features=512, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False, fc_scale=7*7): + super(IResNet, self).__init__() + self.fp16 = fp16 + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + set_requires_grad(self.layer1, False) + set_requires_grad(self.layer2, False) + set_requires_grad(self.layer3, False) + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) + self.dropout = nn.Dropout(p=dropout, inplace=True) + self.fc = nn.Linear(512 * block.expansion * fc_scale, num_features) + self.features = nn.BatchNorm1d(num_features, eps=1e-05) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0, 0.1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), + ) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation)) + + return nn.Sequential(*layers) + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn2(x) + x = self.dropout(x) + x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0) + return x + + +def _iresnet(arch, block, layers, pretrained, progress, **kwargs): + model = IResNet(block, layers, **kwargs) + if pretrained: + raise ValueError() + return model + + +def iresnet18(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, + progress, **kwargs) + + +def iresnet34(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, + progress, **kwargs) + + +def iresnet50(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, + progress, **kwargs) + + +def iresnet100(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, + progress, **kwargs) + + +def iresnet200(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, + progress, **kwargs) diff --git a/training/networks/iresnet_iid.py b/training/networks/iresnet_iid.py new file mode 100644 index 0000000000000000000000000000000000000000..5776ad2f64a6c19c376adde5e5a685d8241c5011 --- /dev/null +++ b/training/networks/iresnet_iid.py @@ -0,0 +1,196 @@ +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + + +__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] +using_ckpt = False + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) + self.downsample = downsample + self.stride = stride + + def forward_impl(self, x): + identity = x + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + def forward(self, x): + if self.training and using_ckpt: + return checkpoint(self.forward_impl, x) + else: + return self.forward_impl(x) + + +class IResNet(nn.Module): + fc_scale = 7 * 7 + def __init__(self, + block, layers, dropout=0, num_features=512, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): + super(IResNet, self).__init__() + self.extra_gflops = 0.0 + self.fp16 = fp16 + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) + self.dropout = nn.Dropout(p=dropout, inplace=True) + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) + self.features = nn.BatchNorm1d(num_features, eps=1e-05) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0, 0.1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), + ) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation)) + + return nn.Sequential(*layers) + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = self.dropout(x) + x = self.fc(x.float() if self.fp16 else x) + if x.size(0)>1: + x = self.features(x) + return x + + +def _iresnet(arch, block, layers, pretrained, progress, **kwargs): + model = IResNet(block, layers, **kwargs) + if pretrained: + raise ValueError() + return model + + +def iresnet18(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, + progress, **kwargs) + + +def iresnet34(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, + progress, **kwargs) + + +def iresnet50(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, + progress, **kwargs) + + +def iresnet100(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, + progress, **kwargs) + + +def iresnet200(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, + progress, **kwargs) \ No newline at end of file diff --git a/training/networks/mesonet.py b/training/networks/mesonet.py new file mode 100644 index 0000000000000000000000000000000000000000..07429325acfaad62cc11825dfcc998b52f538b60 --- /dev/null +++ b/training/networks/mesonet.py @@ -0,0 +1,189 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 + +The code is mainly modified from the below link: +https://github.com/HongguLiu/MesoNet-Pytorch +''' + +import os +import argparse +import logging + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torch.utils.model_zoo as model_zoo +from torch.nn import init +from typing import Union +from metrics.registry import BACKBONE + +logger = logging.getLogger(__name__) + +@BACKBONE.register_module(module_name="meso4") +class Meso4(nn.Module): + def __init__(self, meso4_config): + super(Meso4, self).__init__() + self.num_classes = meso4_config["num_classes"] + inc = meso4_config["inc"] + self.conv1 = nn.Conv2d(inc, 8, 3, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(8) + self.relu = nn.ReLU(inplace=True) + self.leakyrelu = nn.LeakyReLU(0.1) + + self.conv2 = nn.Conv2d(8, 8, 5, padding=2, bias=False) + self.bn2 = nn.BatchNorm2d(16) + self.conv3 = nn.Conv2d(8, 16, 5, padding=2, bias=False) + self.conv4 = nn.Conv2d(16, 16, 5, padding=2, bias=False) + self.maxpooling1 = nn.MaxPool2d(kernel_size=(2, 2)) + self.maxpooling2 = nn.MaxPool2d(kernel_size=(4, 4)) + #flatten: x = x.view(x.size(0), -1) + self.dropout = nn.Dropout2d(0.5) + self.fc1 = nn.Linear(16*8*8, 16) + self.fc2 = nn.Linear(16, self.num_classes) + + + def features(self, input): + x = self.conv1(input) #(8, 256, 256) + x = self.relu(x) + x = self.bn1(x) + x = self.maxpooling1(x) #(8, 128, 128) + + x = self.conv2(x) #(8, 128, 128) + x = self.relu(x) + x = self.bn1(x) + x = self.maxpooling1(x) #(8, 64, 64) + + x = self.conv3(x) #(16, 64, 64) + x = self.relu(x) + x = self.bn2(x) + x = self.maxpooling1(x) #(16, 32, 32) + + x = self.conv4(x) #(16, 32, 32) + x = self.relu(x) + x = self.bn2(x) + x = self.maxpooling2(x) #(16, 8, 8) + x = x.view(x.size(0), -1) #(Batch, 16*8*8) + + return x + + def classifier(self, feature): + out = self.dropout(feature) + out = self.fc1(out) #(Batch, 16) + out = self.leakyrelu(out) + out = self.dropout(out) + out = self.fc2(out) + return out + + def forward(self, input): + x = self.features(input) + out = self.classifier(x) + return out, x + + +@BACKBONE.register_module(module_name="meso4Inception") +class MesoInception4(nn.Module): + def __init__(self, mesoInception4_config): + super(MesoInception4, self).__init__() + self.num_classes = mesoInception4_config["num_classes"] + inc = mesoInception4_config["inc"] + #InceptionLayer1 + self.Incption1_conv1 = nn.Conv2d(3, 1, 1, padding=0, bias=False) + self.Incption1_conv2_1 = nn.Conv2d(3, 4, 1, padding=0, bias=False) + self.Incption1_conv2_2 = nn.Conv2d(4, 4, 3, padding=1, bias=False) + self.Incption1_conv3_1 = nn.Conv2d(3, 4, 1, padding=0, bias=False) + self.Incption1_conv3_2 = nn.Conv2d(4, 4, 3, padding=2, dilation=2, bias=False) + self.Incption1_conv4_1 = nn.Conv2d(3, 2, 1, padding=0, bias=False) + self.Incption1_conv4_2 = nn.Conv2d(2, 2, 3, padding=3, dilation=3, bias=False) + self.Incption1_bn = nn.BatchNorm2d(11) + + + #InceptionLayer2 + self.Incption2_conv1 = nn.Conv2d(11, 2, 1, padding=0, bias=False) + self.Incption2_conv2_1 = nn.Conv2d(11, 4, 1, padding=0, bias=False) + self.Incption2_conv2_2 = nn.Conv2d(4, 4, 3, padding=1, bias=False) + self.Incption2_conv3_1 = nn.Conv2d(11, 4, 1, padding=0, bias=False) + self.Incption2_conv3_2 = nn.Conv2d(4, 4, 3, padding=2, dilation=2, bias=False) + self.Incption2_conv4_1 = nn.Conv2d(11, 2, 1, padding=0, bias=False) + self.Incption2_conv4_2 = nn.Conv2d(2, 2, 3, padding=3, dilation=3, bias=False) + self.Incption2_bn = nn.BatchNorm2d(12) + + #Normal Layer + self.conv1 = nn.Conv2d(12, 16, 5, padding=2, bias=False) + self.relu = nn.ReLU(inplace=True) + self.leakyrelu = nn.LeakyReLU(0.1) + self.bn1 = nn.BatchNorm2d(16) + self.maxpooling1 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv2 = nn.Conv2d(16, 16, 5, padding=2, bias=False) + self.maxpooling2 = nn.MaxPool2d(kernel_size=(4, 4)) + + self.dropout = nn.Dropout2d(0.5) + self.fc1 = nn.Linear(16*8*8, 16) + self.fc2 = nn.Linear(16, self.num_classes) + + + #InceptionLayer + def InceptionLayer1(self, input): + x1 = self.Incption1_conv1(input) + x2 = self.Incption1_conv2_1(input) + x2 = self.Incption1_conv2_2(x2) + x3 = self.Incption1_conv3_1(input) + x3 = self.Incption1_conv3_2(x3) + x4 = self.Incption1_conv4_1(input) + x4 = self.Incption1_conv4_2(x4) + y = torch.cat((x1, x2, x3, x4), 1) + y = self.Incption1_bn(y) + y = self.maxpooling1(y) + + return y + + def InceptionLayer2(self, input): + x1 = self.Incption2_conv1(input) + x2 = self.Incption2_conv2_1(input) + x2 = self.Incption2_conv2_2(x2) + x3 = self.Incption2_conv3_1(input) + x3 = self.Incption2_conv3_2(x3) + x4 = self.Incption2_conv4_1(input) + x4 = self.Incption2_conv4_2(x4) + y = torch.cat((x1, x2, x3, x4), 1) + y = self.Incption2_bn(y) + y = self.maxpooling1(y) + + return y + + + def features(self, input): + x = self.InceptionLayer1(input) #(Batch, 11, 128, 128) + x = self.InceptionLayer2(x) #(Batch, 12, 64, 64) + + x = self.conv1(x) #(Batch, 16, 64 ,64) + x = self.relu(x) + x = self.bn1(x) + x = self.maxpooling1(x) #(Batch, 16, 32, 32) + + x = self.conv2(x) #(Batch, 16, 32, 32) + x = self.relu(x) + x = self.bn1(x) + x = self.maxpooling2(x) #(Batch, 16, 8, 8) + + x = x.view(x.size(0), -1) #(Batch, 16*8*8) + + return x + + def classifier(self, feature): + + out = self.dropout(feature) + out = self.fc1(out) #(Batch, 16) + out = self.leakyrelu(out) + out = self.dropout(out) + out = self.fc2(out) + return out + + def forward(self, input): + x = self.features(input) + out = self.classifier(x) + return out, x diff --git a/training/networks/resnet.py b/training/networks/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..191c6426b2899cd2680890cc531d4f825ca193f0 --- /dev/null +++ b/training/networks/resnet.py @@ -0,0 +1,501 @@ +# -*- coding: utf-8 -*- +""" +Created on 18-5-21 5:26 PM + +@author: ronghuaiyang +""" +import torch +import torch.nn as nn +import math +import torch.utils.model_zoo as model_zoo +import torch.nn.utils.weight_norm as weight_norm +import torch.nn.functional as F + + +# __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', +# 'resnet152'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class AdaIN(nn.Module): + def __init__(self, eps=1e-5): + super().__init__() + self.eps = eps + # self.l1 = nn.Linear(num_classes, in_channel*4, bias=True) #bias is good :) + + def c_norm(self, x, bs, ch, eps=1e-7): + # assert isinstance(x, torch.cuda.FloatTensor) + x_var = x.var(dim=-1) + eps + x_std = x_var.sqrt().view(bs, ch, 1, 1) + x_mean = x.mean(dim=-1).view(bs, ch, 1, 1) + return x_std, x_mean + + def forward(self, x, y): + assert x.size(0)==y.size(0) + size = x.size() + bs, ch = size[:2] + x_ = x.view(bs, ch, -1) + y_ = y.reshape(bs, ch, -1) + x_std, x_mean = self.c_norm(x_, bs, ch, eps=self.eps) + y_std, y_mean = self.c_norm(y_, bs, ch, eps=self.eps) + out = ((x - x_mean.expand(size)) / x_std.expand(size)) \ + * y_std.expand(size) + y_mean.expand(size) + return out + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class BasicBlock_adain(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock_adain, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.adain1 = AdaIN() + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.adain2 = AdaIN() + self.downsample = downsample + self.stride = stride + + def forward(self, feat): # x is content, c is style + x, c = feat + residual = x + + x = self.conv1(x) + out = self.adain1(x, c) + out = self.relu(out) + + out = self.conv2(out) + out = self.adain2(out, c) + + if self.downsample is not None: + residual = self.downsample(residual) + + out += residual + out = self.relu(out) + + return (out, c) + + +class IRBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): + super(IRBlock, self).__init__() + self.bn0 = nn.BatchNorm2d(inplanes) + self.conv1 = conv3x3(inplanes, inplanes) + self.bn1 = nn.BatchNorm2d(inplanes) + self.prelu = nn.PReLU() + self.conv2 = conv3x3(inplanes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.use_se = use_se + if self.use_se: + self.se = SEBlock(planes) + + def forward(self, x): + residual = x + out = self.bn0(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.prelu(out) + + out = self.conv2(out) + out = self.bn2(out) + if self.use_se: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.prelu(out) + + return out + + +class IRBlock_3conv(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): + super(IRBlock_3conv, self).__init__() + self.bn0 = nn.BatchNorm2d(inplanes) + self.conv1 = conv3x3(inplanes, inplanes) + self.bn1 = nn.BatchNorm2d(inplanes) + self.prelu1 = nn.PReLU() + self.conv2 = conv3x3(inplanes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.prelu2 = nn.PReLU() + self.conv3 = conv3x3(planes, planes) + self.bn3 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.use_se = use_se + if self.use_se: + self.se = SEBlock(planes) + self.prelu = nn.PReLU() + + def forward(self, x): + residual = x + out = self.bn0(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.prelu1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.prelu2(out) + + out = self.conv3(out) + out = self.bn3(out) + if self.use_se: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.prelu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class SEBlock(nn.Module): + def __init__(self, channel, reduction=16): + super(SEBlock, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.PReLU(), + nn.Linear(channel // reduction, channel), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class ResNetFace(nn.Module): + def __init__(self, block, layers, use_se=True, inc=3): + self.inplanes = 64 + self.use_se = use_se + super(ResNetFace, self).__init__() + self.conv1 = nn.Conv2d(inc, 64, kernel_size=3, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.prelu = nn.PReLU() + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.bn4 = nn.BatchNorm2d(512) + #self.dropout = nn.Dropout() + self.fc5 = nn.Linear(512 * 8 * 8, 512) + #self.bn5 = nn.BatchNorm1d(512) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, + downsample, use_se=self.use_se)) + self.inplanes = planes + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, use_se=self.use_se)) + + return nn.Sequential(*layers) + + def features(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn4(x) + + return x + + def classifier(self, x): + x = x.view(x.size(0), -1) + x = self.fc5(x) + + return x + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn4(x) + #x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.fc5(x) + #x = self.bn5(x) + + return x + + +class ResNet(nn.Module): + + def __init__(self, block, layers, basedim=32, inc=1): + self.inplanes = basedim + super(ResNet, self).__init__() + # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + # bias=False) + self.conv1 = nn.Conv2d(inc, self.inplanes, kernel_size=3, stride=1, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU(inplace=True) + # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, basedim, layers[0], stride=2) + self.layer2 = self._make_layer(block, 2*basedim, layers[1], stride=2) + self.layer3 = self._make_layer(block, 4*basedim, layers[2], stride=2) + self.layer4 = self._make_layer(block, 8*basedim, layers[3], stride=2) + # self.avgpool = nn.AvgPool2d(8, stride=1) + # self.fc = nn.Linear(512 * block.expansion, num_classes) + self.fc5 = nn.Linear(512 * 8 * 8, 512) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def features(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + # x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + return x + + def classifier(self, x): + x = x.view(x.size(0), -1) + x = self.fc5(x) + + return x + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + # x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + # x = nn.AvgPool2d(kernel_size=x.size()[2:])(x) + # x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc5(x) + + return x + + +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) + return model + + +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) + return model + + +def resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) + return model + + +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) + return model + + +def resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) + return model + + +def resnet_face18(use_se=True, **kwargs): + model = ResNetFace(IRBlock, [2, 2, 2, 2], use_se=use_se, **kwargs) + return model + + +def resnet_face62(use_se=True, **kwargs): + model = ResNetFace(IRBlock_3conv, [3, 4, 10, 3], use_se=use_se, **kwargs) + return model + +if __name__ == "__main__": + net = HR_resnet() + dummy = torch.rand(10,3,256,256) + x = net(dummy) + print('output:', x.size()) diff --git a/training/networks/resnet34.py b/training/networks/resnet34.py new file mode 100644 index 0000000000000000000000000000000000000000..0b6d1d00b49d6c4fbabe832f9d683e71e8bf898e --- /dev/null +++ b/training/networks/resnet34.py @@ -0,0 +1,60 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 + +The code is for ResNet34 backbone. +''' + +import os +import logging +from typing import Union +import torch +import torchvision +import torch.nn as nn +import torch.nn.functional as F +from metrics.registry import BACKBONE + +logger = logging.getLogger(__name__) + +@BACKBONE.register_module(module_name="resnet34") +class ResNet34(nn.Module): + def __init__(self, resnet_config): + super(ResNet34, self).__init__() + """ Constructor + Args: + resnet_config: configuration file with the dict format + """ + self.num_classes = resnet_config["num_classes"] + inc = resnet_config["inc"] + self.mode = resnet_config["mode"] + + # Define layers of the backbone + resnet = torchvision.models.resnet34(pretrained=True) # FIXME: download the pretrained weights from online + # resnet.conv1 = nn.Conv2d(inc, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.resnet = torch.nn.Sequential(*list(resnet.children())[:-2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512, self.num_classes) + + if self.mode == 'adjust_channel': + self.adjust_channel = nn.Sequential( + nn.Conv2d(512, 512, 1, 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True), + ) + + + def features(self, inp): + x = self.resnet(inp) + return x + + def classifier(self, features): + x = self.avgpool(features) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + def forward(self, inp): + x = self.features(inp) + out = self.classifier(x) + return out diff --git a/training/networks/resnet50.py b/training/networks/resnet50.py new file mode 100644 index 0000000000000000000000000000000000000000..732bf8ad6835bde08151891b14f66e075d6c6643 --- /dev/null +++ b/training/networks/resnet50.py @@ -0,0 +1,68 @@ +# author: Xinyuzhou +# email: Xinyuzhou@sjtu.edu.cn +# date: 2025-12-24 +# +# The code is for ResNet50 backbone. + +import os +import logging +from typing import Union +import torch +import torchvision +import torch.nn as nn +import torch.nn.functional as F +from metrics.registry import BACKBONE + +logger = logging.getLogger(__name__) + +@BACKBONE.register_module(module_name="resnet50") +class ResNet50(nn.Module): + def __init__(self, resnet_config): + super(ResNet50, self).__init__() + """ Constructor + Args: + resnet_config: configuration file with the dict format + """ + self.num_classes = resnet_config["num_classes"] + inc = resnet_config["inc"] + self.mode = resnet_config["mode"] + + # Load the pretrained ResNet50 weights from torchvision + # New API for torchvision >= 0.13: + # resnet = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1) + # The legacy API is still available: + resnet = torchvision.models.resnet50(pretrained=True) + + # If you want to support inc != 3, uncomment this and adjust accordingly + # resnet.conv1 = nn.Conv2d(inc, 64, kernel_size=7, stride=2, padding=3, bias=False) + + # Remove the final avgpool and fc layers, keeping feature maps up to layer4 + self.resnet = torch.nn.Sequential(*list(resnet.children())[:-2]) + + # The final feature dimension of ResNet50 is 2048 + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(2048, self.num_classes) + + if self.mode == 'adjust_channel': + self.adjust_channel = nn.Sequential( + nn.Conv2d(2048, 2048, 1, 1), + nn.BatchNorm2d(2048), + nn.ReLU(inplace=True), + ) + + def features(self, inp): + x = self.resnet(inp) + if self.mode == 'adjust_channel': + x = self.adjust_channel(x) + return x + + def classifier(self, features): + x = self.avgpool(features) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + def forward(self, inp): + x = self.features(inp) + out = self.classifier(x) + return out \ No newline at end of file diff --git a/training/networks/time_transformer.py b/training/networks/time_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..df641b91c65febd9b3400847574faed974c92851 --- /dev/null +++ b/training/networks/time_transformer.py @@ -0,0 +1,252 @@ +import torch +from torch import nn, einsum +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +random_select = True + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) + x + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x, mask = None): + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = -torch.finfo(dots.dtype).max + + if mask is not None: + mask = F.pad(mask.flatten(1), (1, 0), value = True) + assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' + mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j') + dots.masked_fill_(~mask, mask_value) + del mask + + attn = dots.softmax(dim=-1) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))), + Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) + ])) + def forward(self, x, mask = None): + for attn, ff in self.layers: + x = attn(x, mask = mask) + x = ff(x) + return x + +class ViT(nn.Module): + def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): + super().__init__() + assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' + num_patches = (image_size // patch_size) ** 2 + patch_dim = channels * patch_size ** 2 + assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' + + self.to_patch_embedding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), + nn.Linear(patch_dim, dim), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) + + self.pool = pool + self.to_latent = nn.Identity() + + self.mlp_head = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, num_classes) + ) + + def forward(self, img, mask = None): + x = self.to_patch_embedding(img) + b, n, _ = x.shape #batch,num_patches,channels # + + cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) + x = torch.cat((cls_tokens, x), dim=1) + x += self.pos_embedding[:, :(n + 1)] + x = self.dropout(x) + + x = self.transformer(x, mask) + + x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] + + x = self.to_latent(x) + return self.mlp_head(x) + + + +def valid_idx(idx, h): + i = idx // h + j = idx % h + pad = h // 7 + if j < pad or i >= h - pad or j >= h - pad: + return False + else: + return True + +import random +from math import sqrt +class RandomSelect(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + # batch,7x7 + size=x.shape[1] + h=int(sqrt(size)) + candidates = list(range(size)) + candidates = [idx for idx in candidates if valid_idx(idx, h)] + max_k = len(candidates) + if self.training and random_select: + k = 8 + if k==-1: + k=max_k + else: + k = max_k + candidates = random.sample(candidates, k) + x = x[:,candidates] + return x + +class VideoiT(nn.Module): + def __init__(self, *, image_size, patch_size, num_patches, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): + super().__init__() + assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' + patch_dim = channels * patch_size ** 2 + assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' + + self.to_patch = Rearrange('b c t (h p1) (w p2) -> b (h w) t (p1 p2 c)', p1 = patch_size, p2 = patch_size) + self.patch_to_embedding=nn.Linear(patch_dim, dim) + self.num_patches=num_patches + + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) + + self.pool = pool + self.random_select=RandomSelect() + self.to_latent = nn.Identity() + + self.mlp_head = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, num_classes) + ) + + def forward(self, img, mask = None): + real_b=img.shape[0] + x = self.to_patch(img) + x = self.random_select(x) + n=x.shape[1] + x=x.reshape(real_b*n,self.num_patches,-1) + x = self.patch_to_embedding(x) + b, n, _ = x.shape #batch,num_patches,channels # + + cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) + x = torch.cat((cls_tokens, x), dim=1) + x += self.pos_embedding[:, :(n + 1)] + x = self.dropout(x) + + x = self.transformer(x, mask) + + x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] + + x = self.to_latent(x) + x = self.mlp_head(x) + x = x.reshape(real_b,-1) + return x + + +class TimeTransformer(nn.Module): + def __init__(self,num_patches, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', dim_head = 64, dropout = 0., emb_dropout = 0.): + super().__init__() + assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' + + self.num_patches=num_patches + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) + + self.pool = pool + self.to_latent = nn.Identity() + + self.mlp_head = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, num_classes) + ) + + def forward(self, x): + b, n, _ = x.shape #batch,num_patches,channels # + + cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) + x = torch.cat((cls_tokens, x), dim=1) + x += self.pos_embedding[:, :(n + 1)] + x = self.dropout(x) + + x = self.transformer(x, mask=None) + + x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] + + x = self.to_latent(x) + return self.mlp_head(x) \ No newline at end of file diff --git a/training/networks/vgg.py b/training/networks/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..da74f9c6a1ca1a93470cefeb68781ddee05d3460 --- /dev/null +++ b/training/networks/vgg.py @@ -0,0 +1,143 @@ +"""A VGG-based perceptual loss function for PyTorch.""" + +import torch +from torch import nn +from torch.nn import functional as F +from torchvision import models, transforms + + +class Lambda(nn.Module): + """Wraps a callable in an :class:`nn.Module` without registering it.""" + + def __init__(self, func): + super().__init__() + object.__setattr__(self, 'forward', func) + + def extra_repr(self): + return getattr(self.forward, '__name__', type(self.forward).__name__) + '()' + + +class WeightedLoss(nn.ModuleList): + """A weighted combination of multiple loss functions.""" + + def __init__(self, losses, weights, verbose=False): + super().__init__() + for loss in losses: + self.append(loss if isinstance(loss, nn.Module) else Lambda(loss)) + self.weights = weights + self.verbose = verbose + + def _print_losses(self, losses): + for i, loss in enumerate(losses): + print(f'({i}) {type(self[i]).__name__}: {loss.item()}') + + def forward(self, *args, **kwargs): + losses = [] + for loss, weight in zip(self, self.weights): + losses.append(loss(*args, **kwargs) * weight) + if self.verbose: + self._print_losses(losses) + return sum(losses) + + +class TVLoss(nn.Module): + """Total variation loss (Lp penalty on image gradient magnitude). + The input must be 4D. If a target (second parameter) is passed in, it is + ignored. + ``p=1`` yields the vectorial total variation norm. It is a generalization + of the originally proposed (isotropic) 2D total variation norm (see + (see https://en.wikipedia.org/wiki/Total_variation_denoising) for color + images. On images with a single channel it is equal to the 2D TV norm. + ``p=2`` yields a variant that is often used for smoothing out noise in + reconstructions of images from neural network feature maps (see Mahendran + and Vevaldi, "Understanding Deep Image Representations by Inverting + Them", https://arxiv.org/abs/1412.0035) + :attr:`reduction` can be set to ``'mean'``, ``'sum'``, or ``'none'`` + similarly to the loss functions in :mod:`torch.nn`. The default is + ``'mean'``. + """ + + def __init__(self, p, reduction='mean', eps=1e-8): + super().__init__() + if p not in {1, 2}: + raise ValueError('p must be 1 or 2') + if reduction not in {'mean', 'sum', 'none'}: + raise ValueError("reduction must be 'mean', 'sum', or 'none'") + self.p = p + self.reduction = reduction + self.eps = eps + + def forward(self, input, target=None): + input = F.pad(input, (0, 1, 0, 1), 'replicate') + x_diff = input[..., :-1, :-1] - input[..., :-1, 1:] + y_diff = input[..., :-1, :-1] - input[..., 1:, :-1] + diff = x_diff**2 + y_diff**2 + if self.p == 1: + diff = (diff + self.eps).mean(dim=1, keepdims=True).sqrt() + if self.reduction == 'mean': + return diff.mean() + if self.reduction == 'sum': + return diff.sum() + return diff + + +class VGGLoss(nn.Module): + """Computes the VGG perceptual loss between two batches of images. + The input and target must be 4D tensors with three channels + ``(B, 3, H, W)`` and must have equivalent shapes. Pixel values should be + normalized to the range 0–1. + The VGG perceptual loss is the mean squared difference between the features + computed for the input and target at layer :attr:`layer` (default 8, or + ``relu2_2``) of the pretrained model specified by :attr:`model` (either + ``'vgg16'`` (default) or ``'vgg19'``). + If :attr:`shift` is nonzero, a random shift of at most :attr:`shift` + pixels in both height and width will be applied to all images in the input + and target. The shift will only be applied when the loss function is in + training mode, and will not be applied if a precomputed feature map is + supplied as the target. + :attr:`reduction` can be set to ``'mean'``, ``'sum'``, or ``'none'`` + similarly to the loss functions in :mod:`torch.nn`. The default is + ``'mean'``. + :meth:`get_features()` may be used to precompute the features for the + target, to speed up the case where inputs are compared against the same + target over and over. To use the precomputed features, pass them in as + :attr:`target` and set :attr:`target_is_features` to :code:`True`. + Instances of :class:`VGGLoss` must be manually converted to the same + device and dtype as their inputs. + """ + + models = {'vgg16': models.vgg16, 'vgg19': models.vgg19} + + def __init__(self, model='vgg16', layer=8, shift=0, reduction='mean'): + super().__init__() + self.instancenorm = nn.InstanceNorm2d(512, affine=False) + self.shift = shift + self.reduction = reduction + self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + self.model = self.models[model](pretrained=True).features[:layer+1] + self.model.eval() + self.model.requires_grad_(False) + + def get_features(self, input): + return self.model(self.normalize(input)) + + def train(self, mode=True): + self.training = mode + + def forward(self, input, target, target_is_features=False): + if target_is_features: + input_feats = self.get_features(input) + target_feats = target + else: + sep = input.shape[0] + batch = torch.cat([input, target]) + if self.shift and self.training: + padded = F.pad(batch, [self.shift] * 4, mode='replicate') + batch = transforms.RandomCrop(batch.shape[2:])(padded) + feats = self.get_features(batch) + input_feats, target_feats = feats[:sep], feats[sep:] + # input_feats, target_feats = \ + # self.instancenorm(input_feats), \ + # self.instancenorm(target_feats) + return F.mse_loss(input_feats, target_feats, reduction=self.reduction) \ No newline at end of file diff --git a/training/networks/xception.py b/training/networks/xception.py new file mode 100644 index 0000000000000000000000000000000000000000..410345c5e15af8aee77a7ff4e3910967bf2d4fce --- /dev/null +++ b/training/networks/xception.py @@ -0,0 +1,285 @@ +''' +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-0706 + +The code is mainly modified from GitHub link below: +https://github.com/ondyari/FaceForensics/blob/master/classification/network/xception.py +''' + +import os +import argparse +import logging + +import math +import torch +# import pretrainedmodels +import torch.nn as nn +import torch.nn.functional as F + +import torch.utils.model_zoo as model_zoo +from torch.nn import init +from typing import Union +from metrics.registry import BACKBONE + +logger = logging.getLogger(__name__) + + + +class SeparableConv2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False): + super(SeparableConv2d, self).__init__() + + self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, + stride, padding, dilation, groups=in_channels, bias=bias) + self.pointwise = nn.Conv2d( + in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) + + def forward(self, x): + x = self.conv1(x) + x = self.pointwise(x) + return x + + +class Block(nn.Module): + def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True): + super(Block, self).__init__() + + if out_filters != in_filters or strides != 1: + self.skip = nn.Conv2d(in_filters, out_filters, + 1, stride=strides, bias=False) + self.skipbn = nn.BatchNorm2d(out_filters) + else: + self.skip = None + + self.relu = nn.ReLU(inplace=True) + rep = [] + + filters = in_filters + if grow_first: # whether the number of filters grows first + rep.append(self.relu) + rep.append(SeparableConv2d(in_filters, out_filters, + 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(out_filters)) + filters = out_filters + + for i in range(reps-1): + rep.append(self.relu) + rep.append(SeparableConv2d(filters, filters, + 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(filters)) + + if not grow_first: + rep.append(self.relu) + rep.append(SeparableConv2d(in_filters, out_filters, + 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(out_filters)) + + if not start_with_relu: + rep = rep[1:] + else: + rep[0] = nn.ReLU(inplace=False) + + if strides != 1: + rep.append(nn.MaxPool2d(3, strides, 1)) + self.rep = nn.Sequential(*rep) + + def forward(self, inp): + x = self.rep(inp) + + if self.skip is not None: + skip = self.skip(inp) + skip = self.skipbn(skip) + else: + skip = inp + + x += skip + return x + +def add_gaussian_noise(ins, mean=0, stddev=0.2): + noise = ins.data.new(ins.size()).normal_(mean, stddev) + return ins + noise + + +@BACKBONE.register_module(module_name="xception") +class Xception(nn.Module): + """ + Xception optimized for the ImageNet dataset, as specified in + https://arxiv.org/pdf/1610.02357.pdf + """ + + def __init__(self, xception_config): + """ Constructor + Args: + xception_config: configuration file with the dict format + """ + super(Xception, self).__init__() + self.num_classes = xception_config["num_classes"] + self.mode = xception_config["mode"] + inc = xception_config["inc"] + dropout = xception_config["dropout"] + + # Entry flow + self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False) + + self.bn1 = nn.BatchNorm2d(32) + self.relu = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32, 64, 3, bias=False) + self.bn2 = nn.BatchNorm2d(64) + # do relu here + + self.block1 = Block( + 64, 128, 2, 2, start_with_relu=False, grow_first=True) + self.block2 = Block( + 128, 256, 2, 2, start_with_relu=True, grow_first=True) + self.block3 = Block( + 256, 728, 2, 2, start_with_relu=True, grow_first=True) + + # middle flow + self.block4 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block5 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block6 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block7 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + + self.block8 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block9 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block10 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block11 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + + # Exit flow + self.block12 = Block( + 728, 1024, 2, 2, start_with_relu=True, grow_first=False) + + self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) + self.bn3 = nn.BatchNorm2d(1536) + + # do relu here + self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) + self.bn4 = nn.BatchNorm2d(2048) + # used for iid + final_channel = 2048 + if self.mode == 'adjust_channel_iid': + final_channel = 512 + self.mode = 'adjust_channel' + self.last_linear = nn.Linear(final_channel, self.num_classes) + if dropout: + self.last_linear = nn.Sequential( + nn.Dropout(p=dropout), + nn.Linear(final_channel, self.num_classes) + ) + + self.adjust_channel = nn.Sequential( + nn.Conv2d(2048, 512, 1, 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace=False), + ) + + def fea_part1_0(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + return x + + def fea_part1_1(self, x): + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + return x + + def fea_part1(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + return x + + def fea_part2(self, x): + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + + return x + + def fea_part3(self, x): + if self.mode == "shallow_xception": + return x + else: + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + return x + + def fea_part4(self, x): + if self.mode == "shallow_xception": + x = self.block12(x) + else: + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + x = self.block12(x) + return x + + def fea_part5(self, x): + x = self.conv3(x) + x = self.bn3(x) + x = self.relu(x) + + x = self.conv4(x) + x = self.bn4(x) + + return x + + def features(self, input): + x = self.fea_part1(input) + + x = self.fea_part2(x) + x = self.fea_part3(x) + x = self.fea_part4(x) + + x = self.fea_part5(x) + + if self.mode == 'adjust_channel': + x = self.adjust_channel(x) + + return x + + def classifier(self, features,id_feat=None): + # for iid + if self.mode == 'adjust_channel': + x = features + else: + x = self.relu(features) + + if len(x.shape) == 4: + x = F.adaptive_avg_pool2d(x, (1, 1)) + x = x.view(x.size(0), -1) + self.last_emb = x + # for iid + if id_feat!=None: + out = self.last_linear(x-id_feat) + else: + out = self.last_linear(x) + return out + + def forward(self, input): + x = self.features(input) + out = self.classifier(x) + return out, x diff --git a/training/networks/xception_ffd.py b/training/networks/xception_ffd.py new file mode 100644 index 0000000000000000000000000000000000000000..5f23ddae50da43390081167dd199d31a1118c889 --- /dev/null +++ b/training/networks/xception_ffd.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import os +import sys + +class SeparableConv2d(nn.Module): + def __init__(self, c_in, c_out, ks, stride=1, padding=0, dilation=1, bias=False): + super(SeparableConv2d, self).__init__() + self.c = nn.Conv2d(c_in, c_in, ks, stride, padding, dilation, groups=c_in, bias=bias) + self.pointwise = nn.Conv2d(c_in, c_out, 1, 1, 0, 1, 1, bias=bias) + + def forward(self, x): + x = self.c(x) + x = self.pointwise(x) + return x + +class Block(nn.Module): + def __init__(self, c_in, c_out, reps, stride=1, start_with_relu=True, grow_first=True): + super(Block, self).__init__() + + self.skip = None + self.skip_bn = None + if c_out != c_in or stride!= 1: + self.skip = nn.Conv2d(c_in, c_out, 1, stride=stride, bias=False) + self.skip_bn = nn.BatchNorm2d(c_out) + + self.relu = nn.ReLU(inplace=True) + + rep = [] + c = c_in + if grow_first: + rep.append(self.relu) + rep.append(SeparableConv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(c_out)) + c = c_out + + for i in range(reps - 1): + rep.append(self.relu) + rep.append(SeparableConv2d(c, c, 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(c)) + + if not grow_first: + rep.append(self.relu) + rep.append(SeparableConv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(c_out)) + + if not start_with_relu: + rep = rep[1:] + else: + rep[0] = nn.ReLU(inplace=False) + + if stride != 1: + rep.append(nn.MaxPool2d(3, stride, 1)) + self.rep = nn.Sequential(*rep) + + def forward(self, inp): + x = self.rep(inp) + + if self.skip is not None: + y = self.skip(inp) + y = self.skip_bn(y) + else: + y = inp + + x += y + return x + +class RegressionMap(nn.Module): + def __init__(self, c_in): + super(RegressionMap, self).__init__() + self.c = SeparableConv2d(c_in, 1, 3, stride=1, padding=1, bias=False) + self.s = nn.Sigmoid() + + def forward(self, x): + mask = self.c(x) + mask = self.s(mask) + return mask, None + +class TemplateMap(nn.Module): + def __init__(self, c_in, templates): + super(TemplateMap, self).__init__() + self.c = Block(c_in, 364, 2, 2, start_with_relu=True, grow_first=False) + self.l = nn.Linear(364, 10) + self.relu = nn.ReLU(inplace=True) + + self.templates = templates + + def forward(self, x): + v = self.c(x) + v = self.relu(v) + v = F.adaptive_avg_pool2d(v, (1,1)) + v = v.view(v.size(0), -1) + v = self.l(v) + mask = torch.mm(v, self.templates.reshape(10,361)) + mask = mask.reshape(x.shape[0], 1, 19, 19) + + return mask, v + +class PCATemplateMap(nn.Module): + def __init__(self, templates): + super(PCATemplateMap, self).__init__() + self.templates = templates + + def forward(self, x): + fe = x.view(x.shape[0], x.shape[1], x.shape[2]*x.shape[3]) + fe = torch.transpose(fe, 1, 2) + mu = torch.mean(fe, 2, keepdim=True) + fea_diff = fe - mu + + cov_fea = torch.bmm(fea_diff, torch.transpose(fea_diff, 1, 2)) + B = self.templates.reshape(1, 10, 361).repeat(x.shape[0], 1, 1) + D = torch.bmm(torch.bmm(B, cov_fea), torch.transpose(B, 1, 2)) + eigen_value, eigen_vector = D.symeig(eigenvectors=True) + index = torch.tensor([9]).cuda() + eigen = torch.index_select(eigen_vector, 2, index) + + v = eigen.squeeze(-1) + mask = torch.mm(v, self.templates.reshape(10, 361)) + mask = mask.reshape(x.shape[0], 1, 19, 19) + return mask, v + +class Xception(nn.Module): + """ + Xception optimized for the ImageNet dataset, as specified in + https://arxiv.org/pdf/1610.02357.pdf + """ + def __init__(self, maptype, templates, num_classes=1000): + super(Xception, self).__init__() + self.num_classes = num_classes + + self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.relu = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32,64,3,bias=False) + self.bn2 = nn.BatchNorm2d(64) + + self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True) + self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True) + self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True) + self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True) + self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True) + self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True) + self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True) + self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True) + self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True) + self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True) + self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True) + self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False) + + self.conv3 = SeparableConv2d(1024,1536,3,1,1) + self.bn3 = nn.BatchNorm2d(1536) + + self.conv4 = SeparableConv2d(1536,2048,3,1,1) + self.bn4 = nn.BatchNorm2d(2048) + + self.last_linear = nn.Linear(2048, num_classes) + + if maptype == 'none': + self.map = [1, None] + elif maptype == 'reg': + self.map = RegressionMap(728) + elif maptype == 'tmp': + self.map = TemplateMap(728, templates) + elif maptype == 'pca_tmp': + self.map = PCATemplateMap(728) + else: + print('Unknown map type: `{0}`'.format(maptype)) + sys.exit() + + def features(self, input): + x = self.conv1(input) + x = self.bn1(x) + x = self.relu(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + mask, vec = self.map(x) + x = x * mask + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + x = self.block12(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.relu(x) + + x = self.conv4(x) + x = self.bn4(x) + return x, mask, vec + + def logits(self, features): + x = self.relu(features) + x = F.adaptive_avg_pool2d(x, (1, 1)) + x = x.view(x.size(0), -1) + x = self.last_linear(x) + return x + + def forward(self, input): + x, mask, vec = self.features(input) + x = self.logits(x) + return x, mask, vec + +def init_weights(m): + classname = m.__class__.__name__ + if classname.find('SeparableConv2d') != -1: + m.c.weight.data.normal_(0.0, 0.01) + if m.c.bias is not None: + m.c.bias.data.fill_(0) + m.pointwise.weight.data.normal_(0.0, 0.01) + if m.pointwise.bias is not None: + m.pointwise.bias.data.fill_(0) + elif classname.find('Conv') != -1 or classname.find('Linear') != -1: + m.weight.data.normal_(0.0, 0.01) + if m.bias is not None: + m.bias.data.fill_(0) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_(1.0, 0.01) + m.bias.data.fill_(0) + elif classname.find('LSTM') != -1: + for i in m._parameters: + if i.__class__.__name__.find('weight') != -1: + i.data.normal_(0.0, 0.01) + elif i.__class__.__name__.find('bias') != -1: + i.bias.data.fill_(0) + +class Model: + def __init__(self, maptype='None', templates=None, num_classes=2, load_pretrain=True): + model = Xception(maptype, templates, num_classes=num_classes) + if load_pretrain: + state_dict = torch.load('./xception-b5690688.pth') + for name, weights in state_dict: + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + del state_dict['fc.weight'] + del state_dict['fc.bias'] + model.load_state_dict(state_dict, False) + else: + model.apply(init_weights) + self.model = model + + def save(self, epoch, optim, model_dir): + state = {'net': self.model.state_dict(), 'optim': optim.state_dict()} + torch.save(state, '{0}/{1:06d}.tar'.format(model_dir, epoch)) + print('Saved model `{0}`'.format(epoch)) + + def load(self, epoch, model_dir): + filename = '{0}{1:06d}.tar'.format(model_dir, epoch) + print('Loading model from {0}'.format(filename)) + if os.path.exists(filename): + state = torch.load(filename) + self.model.load_state_dict(state['net']) + else: + print('Failed to load model from {0}'.format(filename)) + diff --git a/training/networks/xception_sladd.py b/training/networks/xception_sladd.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b949d7918fab93983d915857ea6f60a22c9927 --- /dev/null +++ b/training/networks/xception_sladd.py @@ -0,0 +1,272 @@ +""" + +Author: Andreas Rössler +""" +import torchvision +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo + +from metrics.registry import BACKBONE + +pretrained_settings = { + 'xception': { + 'imagenet': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth', + 'input_space': 'RGB', + 'input_size': [3, 299, 299], + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1000, + 'scale': 0.8975 + # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 + } + } +} + + +class SeparableConv2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False): + super(SeparableConv2d, self).__init__() + + self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, + stride, padding, dilation, groups=in_channels, bias=bias) + self.pointwise = nn.Conv2d( + in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) + + def forward(self, x): + x = self.conv1(x) + x = self.pointwise(x) + return x + + +class RegressionMap(nn.Module): + def __init__(self, c_in): + super(RegressionMap, self).__init__() + self.c = SeparableConv2d(c_in, 1, 3, stride=1, padding=1, bias=False) + self.s = nn.Sigmoid() + + def forward(self, x): + mask = self.c(x) + mask = self.s(mask) + return mask + + +class Block(nn.Module): + def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True): + super(Block, self).__init__() + + if out_filters != in_filters or strides != 1: + self.skip = nn.Conv2d(in_filters, out_filters, + 1, stride=strides, bias=False) + self.skipbn = nn.BatchNorm2d(out_filters) + else: + self.skip = None + + self.relu = nn.ReLU(inplace=False) + rep = [] + + filters = in_filters + if grow_first: # whether the number of filters grows first + rep.append(self.relu) + rep.append(SeparableConv2d(in_filters, out_filters, + 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(out_filters)) + filters = out_filters + + for i in range(reps - 1): + rep.append(self.relu) + rep.append(SeparableConv2d(filters, filters, + 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(filters)) + + if not grow_first: + rep.append(self.relu) + rep.append(SeparableConv2d(in_filters, out_filters, + 3, stride=1, padding=1, bias=False)) + rep.append(nn.BatchNorm2d(out_filters)) + + if not start_with_relu: + rep = rep[1:] + else: + rep[0] = nn.ReLU(inplace=False) + + if strides != 1: + rep.append(nn.MaxPool2d(3, strides, 1)) + self.rep = nn.Sequential(*rep) + + def forward(self, inp): + x = self.rep(inp) + + if self.skip is not None: + skip = self.skip(inp) + skip = self.skipbn(skip) + else: + skip = inp + + x += skip + return x + +@BACKBONE.register_module(module_name="xception_sladd") +class Xception_SLADD(nn.Module): + """ + Xception optimized for the ImageNet dataset, as specified in + https://arxiv.org/pdf/1610.02357.pdf + """ + + def __init__(self, config): + """ Constructor + Args: + num_classes: number of classes + """ + super(Xception_SLADD, self).__init__() + num_classes = config["num_classes"] + inc = config["inc"] + dropout = config["dropout"] + + # Entry flow + self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.relu = nn.ReLU(inplace=False) + + self.conv2 = nn.Conv2d(32, 64, 3, bias=False) + self.bn2 = nn.BatchNorm2d(64) + # do relu here + + self.block1 = Block( + 64, 128, 2, 2, start_with_relu=False, grow_first=True) + self.block2 = Block( + 128, 256, 2, 2, start_with_relu=True, grow_first=True) + self.block3 = Block( + 256, 728, 2, 2, start_with_relu=True, grow_first=True) + + # middle flow + self.block4 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block5 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block6 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block7 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + + self.block8 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block9 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block10 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block11 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True) + + # Exit flow + self.block12 = Block( + 728, 1024, 2, 2, start_with_relu=True, grow_first=False) + + self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) + self.bn3 = nn.BatchNorm2d(1536) + + # do relu here + self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) + self.bn4 = nn.BatchNorm2d(2048) + final_channel = 2048 + self.last_linear = nn.Linear(final_channel, num_classes) + if dropout: + self.last_linear = nn.Sequential( + nn.Dropout(p=dropout), + nn.Linear(final_channel, num_classes) + ) + self.type_fc = nn.Linear(2048, 5) + self.mag_fc = nn.Linear(2048, 1) + self.map = RegressionMap(728) + self.pecent = 1.0 / 1.5 + + def fea_part1_0(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + return x + + def fea_part1_1(self, x): + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + return x + + def fea_part1(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + return x + + def fea_part2(self, x): + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + + return x + + def fea_part3(self, x): + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + + return x + + def fea_part4(self, x): + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + x = self.block12(x) + + return x + + def fea_part5(self, x): + x = self.conv3(x) + x = self.bn3(x) + x = self.relu(x) + + x = self.conv4(x) + x = self.bn4(x) + + return x + + def features(self, input): + x = self.fea_part1(input) + + x = self.fea_part2(x) + x3 = self.fea_part3(x) + x = self.fea_part4(x3) + + x = self.fea_part5(x) + return x,x3 + + # def classifier(self, features): + def classifier(self, x): + x = self.relu(x) + x = F.adaptive_avg_pool2d(x, (1, 1)) + x = x.view(x.size(0), -1) + out = self.last_linear(x) + return out, x + + def estimateMap(self, x): + map = self.map(x) + return map + + # def forward(self, input): + def forward(self, x): + x,x3=self.features(x) + out, fea, type, mag = self.classifier(x) + map = self.estimateMap(x3) + return out, fea, map, type, mag diff --git a/training/optimizor/LinearLR.py b/training/optimizor/LinearLR.py new file mode 100644 index 0000000000000000000000000000000000000000..80bc70dbae46bb9f76aa65afe6f4a1b95dd25619 --- /dev/null +++ b/training/optimizor/LinearLR.py @@ -0,0 +1,20 @@ +import torch +from torch.optim import SGD +from torch.optim.lr_scheduler import _LRScheduler + +class LinearDecayLR(_LRScheduler): + def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1): + self.start_decay=start_decay + self.n_epoch=n_epoch + super(LinearDecayLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + last_epoch = self.last_epoch + n_epoch=self.n_epoch + b_lr=self.base_lrs[0] + start_decay=self.start_decay + if last_epoch>start_decay: + lr=b_lr-b_lr/(n_epoch-start_decay)*(last_epoch-start_decay) + else: + lr=b_lr + return [lr] \ No newline at end of file diff --git a/training/optimizor/SAM.py b/training/optimizor/SAM.py new file mode 100644 index 0000000000000000000000000000000000000000..7b8d1dc52726ffea22553ce96a6e0d37a902fbff --- /dev/null +++ b/training/optimizor/SAM.py @@ -0,0 +1,77 @@ +# borrowed from + +import torch + +import torch +import torch.nn as nn + +def disable_running_stats(model): + def _disable(module): + if isinstance(module, nn.BatchNorm2d): + module.backup_momentum = module.momentum + module.momentum = 0 + + model.apply(_disable) + +def enable_running_stats(model): + def _enable(module): + if isinstance(module, nn.BatchNorm2d) and hasattr(module, "backup_momentum"): + module.momentum = module.backup_momentum + + model.apply(_enable) + +class SAM(torch.optim.Optimizer): + def __init__(self, params, base_optimizer, rho=0.05, **kwargs): + assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" + + defaults = dict(rho=rho, **kwargs) + super(SAM, self).__init__(params, defaults) + + self.base_optimizer = base_optimizer(self.param_groups, **kwargs) + self.param_groups = self.base_optimizer.param_groups + + @torch.no_grad() + def first_step(self, zero_grad=False): + grad_norm = self._grad_norm() + for group in self.param_groups: + scale = group["rho"] / (grad_norm + 1e-12) + + for p in group["params"]: + if p.grad is None: continue + e_w = p.grad * scale.to(p) + p.add_(e_w) # climb to the local maximum "w + e(w)" + self.state[p]["e_w"] = e_w + + if zero_grad: self.zero_grad() + + @torch.no_grad() + def second_step(self, zero_grad=False): + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: continue + p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)" + + self.base_optimizer.step() # do the actual "sharpness-aware" update + + if zero_grad: self.zero_grad() + + @torch.no_grad() + def step(self, closure=None): + assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" + closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass + + self.first_step(zero_grad=True) + closure() + self.second_step() + + def _grad_norm(self): + shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism + norm = torch.norm( + torch.stack([ + p.grad.norm(p=2).to(shared_device) + for group in self.param_groups for p in group["params"] + if p.grad is not None + ]), + p=2 + ) + return norm \ No newline at end of file diff --git a/training/optimizor/__init__.py b/training/optimizor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..676145d777810e4a51bdaf59fdec4f5358aae349 --- /dev/null +++ b/training/optimizor/__init__.py @@ -0,0 +1,7 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) diff --git a/training/test_pall.py b/training/test_pall.py new file mode 100644 index 0000000000000000000000000000000000000000..4826ee5cb194e03db6a63ead7cdd3572e3f8c9dc --- /dev/null +++ b/training/test_pall.py @@ -0,0 +1,407 @@ +""" +eval pretained model with multi-GPU support. +""" +import os +import numpy as np +from os.path import join +import cv2 +import random +import datetime +import time +import yaml +import pickle +from tqdm import tqdm +from copy import deepcopy +from PIL import Image as pil_image +from metrics.utils import get_test_metrics +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.nn.functional as F +import torch.utils.data +import torch.optim as optim +import torch.distributed as dist +from torch.utils.data.distributed import DistributedSampler + +from dataset.abstract_dataset import DeepfakeAbstractBaseDataset +from dataset.ff_blend import FFBlendDataset +from dataset.fwa_blend import FWABlendDataset +from dataset.pair_dataset import pairDataset + +from trainer.trainer import Trainer +from detectors import DETECTOR +from metrics.base_metrics_class import Recorder, calculate_acc_for_test +import metrics_retrieval.utils +from metrics_retrieval.get_metric_pro4 import * +from metrics_retrieval.get_metric import * + +from collections import defaultdict + +import argparse +from logger import create_logger + +parser = argparse.ArgumentParser(description='Process some paths.') +parser.add_argument('--detector_path', type=str, default='/PATH/TO/resnet34.yaml', help='path to detector YAML file') +parser.add_argument("--test_dataset", nargs="+") +parser.add_argument('--weights_path', type=str, default='') +parser.add_argument('--ddp', action='store_true', help='Use DistributedDataParallel') +parser.add_argument('--use_latest', action='store_true', help='Use Latest Ckpt') +parser.add_argument('--local_rank', '--local-rank', type=int, default=-1, help='Local rank for DDP') +parser.add_argument('--test_config', type=str, default='test_config_p2.yaml', help='test_config_p2.yaml / test_config_p4.yaml') +args = parser.parse_args() + + +def init_seed(config, seed=None): + if seed is None: + if config['manualSeed'] is None: + config['manualSeed'] = random.randint(1, 10000) + seed = config['manualSeed'] + random.seed(seed) + torch.manual_seed(seed) + if config['cuda']: + torch.cuda.manual_seed_all(seed) + return seed + + +def prepare_testing_data(config, ddp=False): + def get_test_data_loader(config, test_name): + # update the config dictionary with the specific testing dataset + config = config.copy() # create a copy of config to avoid altering the original one + config['test_dataset'] = test_name # specify the current test dataset + test_set = DeepfakeAbstractBaseDataset( + config=config, + mode='test', + ) + + # Use DistributedSampler to distribute the data + sampler = DistributedSampler(test_set, shuffle=False) if ddp else None + + test_data_loader = \ + torch.utils.data.DataLoader( + dataset=test_set, + batch_size=config['test_batchSize'], + shuffle=(sampler is None), + num_workers=int(config['workers']), + collate_fn=test_set.collate_fn, + drop_last=False, + pin_memory=True, + sampler=sampler # add sampler + ) + return test_data_loader, test_set.data_dict + + test_data_loaders = {} + test_data_dicts = {} + for one_test_name in config['test_dataset']: + loader, data_dict = get_test_data_loader(config, one_test_name) + test_data_loaders[one_test_name] = loader + test_data_dicts[one_test_name] = data_dict + return test_data_loaders, test_data_dicts + + +def choose_metric(config): + metric_scoring = config['metric_scoring'] + if metric_scoring not in ['eer', 'auc', 'acc', 'ap']: + raise NotImplementedError('metric {} is not implemented'.format(metric_scoring)) + return metric_scoring + + +def test_one_dataset(model, data_loader, device, local_rank): + # Initialize empty lists to store tensors + prediction_lists = [] + feature_lists = [] + label_lists = [] + img_name_lists = [] + + # Only the main process shows the progress bar + pbar = tqdm(enumerate(data_loader), total=len(data_loader), disable=(local_rank != 0)) + + for i, data_dict in pbar: + # get data + data, label, mask, landmark = data_dict['image'], data_dict['label'], data_dict['mask'], data_dict['landmark'] + img_names =[] # data_dict['image'] # Image names are still strings and are stored separately + + # Move data to GPU (keeping the original logic) + data_dict['image'], data_dict['label'] = data.to(device), label.to(device) + if mask is not None: + data_dict['mask'] = mask.to(device) + if landmark is not None: + data_dict['landmark'] = landmark.to(device) + + # Model forward pass (no gradients, original logic unchanged) + predictions = inference(model, data_dict) + + # Use append instead of extend, and concatenate later with torch.cat + label_lists.append(data_dict['label']) # label is a tensor, so append it directly + prediction_lists.append(predictions['prob']) # prob is the tensor output by the model + feature_lists.append(predictions['feat']) # the same applies to feat + img_name_lists.extend(img_names) # String lists still use extend + + # If the current process has no data (an extreme case), return empty tensors to avoid errors + predictions_tensor = torch.cat(prediction_lists, dim=0) if prediction_lists else torch.tensor([], device=device) + labels_tensor = torch.cat(label_lists, dim=0) if label_lists else torch.tensor([], device=device) + feats_tensor = torch.cat(feature_lists, dim=0) if feature_lists else torch.tensor([], device=device) + + print("feats_tensor", feats_tensor.shape) + + # Return results in tensor form (image names remain a list) + return predictions_tensor, labels_tensor, feats_tensor, img_name_lists + + +def test_epoch(model, test_data_loaders, test_data_dicts, device, local_rank, ddp, config, logger): + # set model to eval mode + model.eval() + + # define test recorder + metrics_all_datasets = {} + + # testing for all test data + keys = test_data_loaders.keys() + for key in keys: + + # 1.Dataset Name + print("Run Dataset:", key) + # if args.local_rank == 0: + logger.info(f"--------------- Run Dataset: {key} ---------------") + logger.info(f"--------------- Run Dataset: {logger.log_path} ---------------") + + data_loader = test_data_loaders[key] + data_dict = test_data_dicts[key] + + # Set the sampler epoch in DDP mode + if ddp and hasattr(data_loader.sampler, 'set_epoch'): + data_loader.sampler.set_epoch(0) + + # Each process computes its own portion (the return values are tensors at this point) + predictions_tensor, labels_tensor, feats_tensor, img_names = test_one_dataset( + model, data_loader, device, local_rank) + + # Gather results from all processes (only the main process needs the full results) + if ddp: + world_size = dist.get_world_size() + + # 1. Gather predictions + all_predictions = [torch.zeros_like(predictions_tensor) for _ in range(world_size)] + dist.all_gather(all_predictions, predictions_tensor) + + # 2. Gather labels + all_labels = [torch.zeros_like(labels_tensor) for _ in range(world_size)] + dist.all_gather(all_labels, labels_tensor) + + # 3. Gather features (optional) + all_feats = [torch.zeros_like(feats_tensor) for _ in range(world_size)] + dist.all_gather(all_feats, feats_tensor) + + all_predictions = torch.cat(all_predictions, dim=0) + all_labels = torch.cat(all_labels, dim=0) + all_feats = torch.cat(all_feats, dim=0) + else: + # In non-DDP mode, convert directly to NumPy (only once) + all_predictions = predictions_tensor.cpu().numpy() + all_labels = labels_tensor.cpu().numpy() + all_feats = feats_tensor.cpu().numpy() + #all_img_names = img_names + + # Only the main process computes metrics and outputs results + if local_rank == 0: + # compute metric for each dataset + metric_one_dataset = calculate_acc_for_test(all_labels, all_predictions, config['backbone_config']['num_classes']) + metrics_all_datasets[key] = metric_one_dataset + + # Information for each dataset + tqdm.write(f"dataset: {key}") + for k, v in metric_one_dataset.items(): + tqdm.write(f"{k}: {v}") + logger.info(f"{k}: {v}") + + # save info + pkl_save_path = os.path.join(os.path.dirname(logger.log_path), f"{key}.pkl") + save_data = { + "all_predictions": all_predictions.cpu().numpy(), + "all_labels": all_labels.cpu().numpy(), + "all_feats": all_feats.cpu().numpy(), + "metrics": metric_one_dataset, # Additionally save metrics for the current dataset to facilitate later analysis + "all_names": img_names + } + with open(pkl_save_path, "wb") as f: + pickle.dump(save_data, f, protocol=pickle.HIGHEST_PROTOCOL) # Using the highest protocol is more efficient + + return metrics_all_datasets if local_rank == 0 else None + + +@torch.no_grad() +def inference(model, data_dict): + from torch.cuda.amp import autocast + with autocast(dtype=torch.float16): + predictions = model(data_dict, inference=True) + return predictions + + +def main(): + # Initialize DDP + ddp = args.ddp + local_rank = args.local_rank + + if ddp: + # Initialize the process group + torch.cuda.set_device(local_rank) + dist.init_process_group( + backend='nccl', + init_method='env://', # Read rendezvous information from environment variables (set automatically by torchrun) + world_size=int(os.environ.get("WORLD_SIZE", 1)), # total number of GPUs + rank=int(os.environ.get("RANK", 0)) + ) + device = torch.device("cuda", local_rank) + else: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # parse options and load config + # Model-specific configuration + with open(args.detector_path, 'r') as f: + config = yaml.safe_load(f) + # Unified base configuration + with open(f'./training/config/{args.test_config}', 'r') as f: + config_base = yaml.safe_load(f) + + # Label dictionary shared by all datasets + if 'label_dict' in config: + config_base['label_dict']=config['label_dict'] # The base configuration has the highest priority + config.update(config_base) + + weights_path = None + # If arguments are provided, they will overwrite the yaml settings + if args.test_dataset: + config['test_dataset'] = args.test_dataset + if args.weights_path: + config['weights_path'] = args.weights_path + weights_path = args.weights_path + + # Set the same seed for DDP + seed = init_seed(config) + if ddp: + # Use a different seed offset for each process to ensure data augmentation diversity + seed += dist.get_rank() + init_seed(config, seed) + + # set cudnn benchmark if needed + if config['cudnn']: + cudnn.benchmark = True + + # Log information + logs_test_dir = weights_path.replace("logs", "logs_test") + # if local_rank == 0: + # creat log + os.makedirs(logs_test_dir, exist_ok=True) + logger = create_logger(os.path.join(logs_test_dir, 'testing.log')) + logger.info('Save log to {}'.format(logs_test_dir)) + # print configuration + logger.info("--------------- Configuration ---------------") + params_string = "Parameters: \n" + for key, value in config.items(): + params_string += "{}: {}".format(key, value) + "\n" + logger.info(params_string) + + # prepare the testing data loader + test_data_loaders, test_data_dicts = prepare_testing_data(config, ddp) + + # prepare the model (detector) + model_class = DETECTOR[config['model_name']] + model = model_class(config).to(device) + epoch = 0 + + # Only print model parameter information on the main process + if local_rank == 0: + for name, param in model.named_parameters(): + print(f"{name}: {param.shape}") + + if weights_path: + # For models containing LoRA, switch to eval mode first to avoid repeatedly stacking weights + if 'lora' in config['model_name'].lower() or "pmoe" in config['model_name'].lower(): + model.eval() + + if weights_path: + try: + epoch = int(weights_path.split('/')[-1].split('.')[0].split('_')[2]) + except: + epoch = 0 + + # Automatically find the best checkpoint + if args.use_latest: + ckpt_path = os.path.join(weights_path, "test/protocol_2_test/ckpt_latest.pth") + else: + if weights_path[-3:] == "pth": + ckpt_path = weights_path + else: + ckpt_path = os.path.join(weights_path, "test/protocol_2_test/ckpt_best.pth") + # ckpt_path = os.path.join(weights_path, "test/protocol_2_test/ckpt_best.pth") + ckpt = torch.load(ckpt_path, map_location=f"cuda:{local_rank}") + logger.info(f"Load ckpt: {ckpt_path}") + + # Remove the "module." prefix from the weights (if DDP was used during training) + new_state_dict = {k.replace('module.', ''): v for k, v in ckpt.items()} + + model.load_state_dict(new_state_dict, strict=False) + + if local_rank == 0: + print('===> Load checkpoint done!') + else: + if local_rank == 0: + print('Fail to load the pre-trained weights') + + if ddp: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) + + # start testing + best_metric = test_epoch(model, test_data_loaders, test_data_dicts, device, local_rank, ddp, config, logger) + + if local_rank == 0: + print('===> Test Done!') + + # Clean up the DDP process group + if ddp: + dist.barrier() + dist.destroy_process_group() + + if local_rank==0: + #Metric test + for prot in config['test_dataset']: + prefix=config["weights_path"].replace("logs/", "logs_test/") + prefix=os.path.join("/Youtu_Pangu_Security_Public_cq11/shunliwang/DeepFakeBench_DFG",prefix) + if "protocol_2" in prot: + pkl_file="protocol_2_test.pkl" + RANK_MAX = 10 + seed = 42 + PKL_FILE_PATH = os.path.join(prefix, pkl_file) + run_retrieval_evaluation(pkl_file_path=PKL_FILE_PATH, query_mode='10_sample_avg', rank_max=RANK_MAX,random_seed=seed) + elif "protocol_3" in prot: + pkl_file="protocol_3_test.pkl" + RANK_MAX = 10 + seed = 42 + PKL_FILE_PATH = os.path.join(prefix, pkl_file) + run_retrieval_evaluation(pkl_file_path=PKL_FILE_PATH, query_mode='10_sample_avg', rank_max=RANK_MAX,random_seed=seed) + elif "protocol_4" in prot: + pkl_file="protocol_4_test.pkl" + RANK_MAX = 10 + seed = 42 + PKL_FILE_PATH = os.path.join(prefix, pkl_file) + yaml_path="config/test_config_p4.yaml" + run_retrieval_evaluation_p4(pkl_file_path=PKL_FILE_PATH, query_mode='10_sample_avg', rank_max=RANK_MAX,random_seed=seed,yaml_path=yaml_path) + + + + + + +if __name__ == '__main__': + main() + + +# 1.Useful information in the log + +# 2.Create the log_test directory + +# 3.Create logger text output + +# 4.Save features and labels + + diff --git a/training/train.py b/training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..6de759e4bf4ed04228c8097294fd7b83f3cc71ef --- /dev/null +++ b/training/train.py @@ -0,0 +1,342 @@ +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-03-30 +# description: training code. + +import os +import argparse +from os.path import join +import cv2 +import random +import datetime +import time +import yaml +from tqdm import tqdm +import numpy as np +from datetime import timedelta +from copy import deepcopy +from PIL import Image as pil_image + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.utils.data +import torch.optim as optim +from torch.utils.data.distributed import DistributedSampler +import torch.distributed as dist + +from optimizor.SAM import SAM +from optimizor.LinearLR import LinearDecayLR + +from trainer.trainer import Trainer +from detectors import DETECTOR +from dataset import * +from metrics.utils import parse_metric_for_print +from logger import create_logger + +# torch.hub.set_dir("training/pretrained") + +parser = argparse.ArgumentParser(description='Process some paths.') +parser.add_argument('--detector_path', type=str, + default='/data/home/zhiyuanyan/DeepfakeBenchv2/training/config/detector/sbi.yaml', + help='path to detector YAML file') +parser.add_argument("--train_dataset", nargs="+") +parser.add_argument("--test_dataset", nargs="+") +parser.add_argument('--no-save_ckpt', dest='save_ckpt', action='store_false', default=True) +parser.add_argument('--no-save_feat', dest='save_feat', action='store_false', default=True) +parser.add_argument("--ddp", action='store_true', default=False) +parser.add_argument('--local_rank', '--local-rank', type=int, default=0) +parser.add_argument('--task_target', type=str, default="", help='specify the target of current training task') +args = parser.parse_args() +torch.cuda.set_device(args.local_rank) + + +def init_seed(config): + if config['manualSeed'] is None: + config['manualSeed'] = random.randint(1, 10000) + random.seed(config['manualSeed']) + if config['cuda']: + torch.manual_seed(config['manualSeed']) + torch.cuda.manual_seed_all(config['manualSeed']) + + +def prepare_training_data(config): + #### Prepare Dataset + # Only use the blending dataset class in training + if 'dataset_type' in config and config['dataset_type'] == 'blend': + if config['model_name'] == 'facexray': + train_set = FFBlendDataset(config) + elif config['model_name'] == 'fwa': + train_set = FWABlendDataset(config) + elif config['model_name'] == 'sbi': + train_set = SBIDataset(config, mode='train') + elif config['model_name'] == 'lsda': + train_set = LSDADataset(config, mode='train') + else: + raise NotImplementedError('Only facexray, fwa, sbi, and lsda are currently supported for blending dataset') + elif 'dataset_type' in config and config['dataset_type'] == 'pair': + train_set = pairDataset(config, mode='train') # Only use the pair dataset class in training + elif 'dataset_type' in config and config['dataset_type'] == 'iid': + train_set = IIDDataset(config, mode='train') + elif 'dataset_type' in config and config['dataset_type'] == 'I2G': + train_set = I2GDataset(config, mode='train') + elif 'dataset_type' in config and config['dataset_type'] == 'lrl': + train_set = LRLDataset(config, mode='train') + else: + train_set = DeepfakeAbstractBaseDataset(config=config, mode='train') + + #### Prepare DataLoader + # Use a customized `CustomSampler` when the model is LSDA + if config['model_name'] == 'lsda': + from dataset.lsda_dataset import CustomSampler + custom_sampler = CustomSampler(num_groups=2*360, n_frame_per_vid=config['frame_num']['train'], batch_size=config['train_batchSize'], videos_per_group=5) + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=config['train_batchSize'], + num_workers=int(config['workers']), + sampler=custom_sampler, + collate_fn=train_set.collate_fn, + pin_memory=True + ) + # Configure a distributed sampler when DDP is enabled + elif config['ddp']: + sampler = DistributedSampler(train_set) + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=config['train_batchSize'], + num_workers=int(config['workers']), + collate_fn=train_set.collate_fn, + sampler=sampler, + pin_memory=True + ) + # Otherwise use the standard sampler + else: + + train_data_loader = \ + torch.utils.data.DataLoader( + dataset=train_set, + batch_size=config['train_batchSize'], + shuffle=True, + num_workers=int(config['workers']), + collate_fn=train_set.collate_fn, + pin_memory=True + ) + + return train_data_loader + + +def prepare_testing_data(config): + def get_test_data_loader(config, test_name): + # update the config dictionary with the specific testing dataset + config = config.copy() # create a copy of config to avoid altering the original one + config['test_dataset'] = test_name # specify the current test dataset + if not config.get('dataset_type', None) == 'lrl': + test_set = DeepfakeAbstractBaseDataset( + config=config, + mode='test', + ) + else: + test_set = LRLDataset( + config=config, + mode='test', + ) + + test_data_loader = \ + torch.utils.data.DataLoader( + dataset=test_set, + batch_size=config['test_batchSize'], + shuffle=False, + num_workers=int(config['workers']), + collate_fn=test_set.collate_fn, + drop_last=False, + pin_memory=True + ) + + return test_data_loader + + test_data_loaders = {} + for one_test_name in config['test_dataset']: + test_data_loaders[one_test_name] = get_test_data_loader(config, one_test_name) + return test_data_loaders + + +def choose_optimizer(model, config): + opt_name = config['optimizer']['type'] + if opt_name == 'sgd': + optimizer = optim.SGD( + params=model.parameters(), + lr=config['optimizer'][opt_name]['lr'], + momentum=config['optimizer'][opt_name]['momentum'], + weight_decay=config['optimizer'][opt_name]['weight_decay'] + ) + return optimizer + elif opt_name == 'adam': + optimizer = optim.Adam( + params=model.parameters(), + lr=config['optimizer'][opt_name]['lr'], + weight_decay=config['optimizer'][opt_name]['weight_decay'], + betas=(config['optimizer'][opt_name]['beta1'], config['optimizer'][opt_name]['beta2']), + eps=config['optimizer'][opt_name]['eps'], + amsgrad=config['optimizer'][opt_name]['amsgrad'], + ) + return optimizer + elif opt_name == 'sam': + optimizer = SAM( + model.parameters(), + optim.SGD, + lr=config['optimizer'][opt_name]['lr'], + momentum=config['optimizer'][opt_name]['momentum'], + ) + else: + raise NotImplementedError('Optimizer {} is not implemented'.format(config['optimizer'])) + return optimizer + + +def choose_scheduler(config, optimizer): + if config['lr_scheduler'] is None: + return None + elif config['lr_scheduler'] == 'step': + scheduler = optim.lr_scheduler.StepLR( + optimizer, + step_size=config['lr_step'], + gamma=config['lr_gamma'], + ) + return scheduler + elif config['lr_scheduler'] == 'cosine': + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=config['lr_T_max'], + eta_min=config['lr_eta_min'], + ) + return scheduler + elif config['lr_scheduler'] == 'linear': + scheduler = LinearDecayLR( + optimizer, + config['nEpochs'], + int(config['nEpochs']/4), + ) + else: + raise NotImplementedError('Scheduler {} is not implemented'.format(config['lr_scheduler'])) + + +def choose_metric(config): + metric_scoring = config['metric_scoring'] + if metric_scoring not in ['eer', 'auc', 'acc', 'ap']: + raise NotImplementedError('metric {} is not implemented'.format(metric_scoring)) + return metric_scoring + + +def main(): + # parse options and load config + + # Model-specific configuration + with open(args.detector_path, 'r') as f: + config = yaml.safe_load(f) + # Unified base configuration + with open('./training/config/train_config_p2.yaml', 'r') as f: + config_base = yaml.safe_load(f) + + # Label dictionary shared by all datasets + if 'label_dict' in config: + config_base['label_dict']=config['label_dict'] # The base configuration has the highest priority + + config.update(config_base) + + config['local_rank']=args.local_rank + if config['dry_run']: + config['nEpochs'] = 0 + config['save_feat']=False + + # If arguments are provided, they will overwrite the yaml settings + if args.train_dataset: + config['train_dataset'] = args.train_dataset + if args.test_dataset: + config['test_dataset'] = args.test_dataset + config['save_ckpt'] = args.save_ckpt + config['save_feat'] = args.save_feat + if config['lmdb']: + config['dataset_json_folder'] = 'preprocessing/dataset_json' # dataset_json_v3 + + # init seed + init_seed(config) + + # set cudnn benchmark if needed + if config['cudnn']: + cudnn.benchmark = True + config['ddp']= args.ddp + if config['ddp']: + # dist.init_process_group(backend='gloo') + dist.init_process_group(backend='nccl', timeout=timedelta(minutes=30)) + + # create logger + timenow=datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + task_str = f"_{config['task_target']}" if config.get('task_target', None) is not None else "" + logger_path = os.path.join( + config['log_dir'], + config['model_name'] + task_str + '_' + timenow + ) + os.makedirs(logger_path, exist_ok=True) + logger = create_logger(os.path.join(logger_path, 'training.log')) + logger.info('Save log to {}'.format(logger_path)) + + # print configuration + logger.info("--------------- Configuration ---------------") + params_string = "Parameters: \n" + for key, value in config.items(): + params_string += "{}: {}".format(key, value) + "\n" + logger.info(params_string) + + # prepare the training data loader + train_data_loader = prepare_training_data(config) + + # prepare the testing data loader + test_data_loaders = prepare_testing_data(config) + + # prepare the model (detector) + model_class = DETECTOR[config['model_name']] + model = model_class(config) + + print(model) + + # prepare the optimizer + optimizer = choose_optimizer(model, config) + + # prepare the scheduler + scheduler = choose_scheduler(config, optimizer) + + # prepare the metric + metric_scoring = choose_metric(config) + + # prepare the trainer + trainer = Trainer(config, model, optimizer, scheduler, logger, metric_scoring, time_now=timenow) + + # start training + for epoch in range(config['start_epoch'], config['nEpochs'] + 1): + trainer.model.epoch = epoch + if config['ddp']: + train_data_loader.sampler.set_epoch(epoch) + best_metric = trainer.train_epoch( + epoch=epoch, + train_data_loader=train_data_loader, + test_data_loaders=test_data_loaders, + ) + if best_metric is not None: + logger.info(f"===> Epoch[{epoch}] end with testing {metric_scoring}: {parse_metric_for_print(best_metric)}!") + logger.info("Stop Training on best Testing metric {}".format(parse_metric_for_print(best_metric))) + + # update + if 'svdd' in config['model_name']: + model.update_R(epoch) + if scheduler is not None: + scheduler.step() + + # close the tensorboard writers + for writer in trainer.writers.values(): + writer.close() + + +if __name__ == '__main__': + main() diff --git a/training/trainer/__init__.py b/training/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53267828bbd6fcf8377bd4881398ea8550c5168e --- /dev/null +++ b/training/trainer/__init__.py @@ -0,0 +1,9 @@ +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + +from metrics.registry import TRAINER \ No newline at end of file diff --git a/training/trainer/base_trainer.py b/training/trainer/base_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e1402752b994f82702349e5c7c8d98b1fc7d1519 --- /dev/null +++ b/training/trainer/base_trainer.py @@ -0,0 +1,50 @@ +import datetime +from copy import deepcopy +from abc import ABC, abstractmethod + + +class BaseTrainer(ABC): + """ + """ + + def __init__( + self, + config, + model, + optimizer, + scheduler, + writer, + ): + # check if all the necessary components are implemented + if config is None or model is None or optimizer is None or scheduler is None or writer is None: + raise NotImplementedError("config, model, optimizier, scheduler, and tensorboard writer must be implemented") + + self.config = config + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.writer = writer + + @abstractmethod + def speed_up(self): + pass + + @abstractmethod + def setTrain(self): + pass + + @abstractmethod + def setEval(self): + pass + + @abstractmethod + def load_ckpt(self, model_path): + pass + + @abstractmethod + def save_ckpt(self, dataset, epoch, iters, best=False): + pass + + @abstractmethod + def inference(self, data_dict): + pass diff --git a/training/trainer/trainer.py b/training/trainer/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3faf5f755c7c2a85b9c50d76aa5243f773998de6 --- /dev/null +++ b/training/trainer/trainer.py @@ -0,0 +1,501 @@ +# author: Zhiyuan Yan +# email: zhiyuanyan@link.cuhk.edu.cn +# date: 2023-03-30 +# description: trainer +import os +import sys +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_file_path)) +project_root_dir = os.path.dirname(parent_dir) +sys.path.append(parent_dir) +sys.path.append(project_root_dir) + +import pickle +import datetime +import logging +import numpy as np +from copy import deepcopy +from collections import defaultdict +from tqdm import tqdm +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.nn import DataParallel +from torch.utils.tensorboard import SummaryWriter +from metrics.base_metrics_class import Recorder +from torch.optim.swa_utils import AveragedModel, SWALR +from torch import distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from sklearn import metrics +from metrics.utils import get_test_metrics +from metrics.base_metrics_class import calculate_acc_for_test + +FFpp_pool=['FaceForensics++','FF-DF','FF-F2F','FF-FS','FF-NT']# +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Trainer(object): + def __init__( + self, + config, + model, + optimizer, + scheduler, + logger, + metric_scoring='auc', + time_now = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S'), + swa_model=None + ): + # check if all the necessary components are implemented + if config is None or model is None or optimizer is None or logger is None: + raise ValueError("config, model, optimizier, logger, and tensorboard writer must be implemented") + + self.config = config + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.swa_model = swa_model + self.writers = {} # dict to maintain different tensorboard writers for each dataset and metric + self.logger = logger + self.metric_scoring = metric_scoring + # maintain the best metric of all epochs + self.best_metrics_all_time = defaultdict( + lambda: defaultdict(lambda: float('-inf') + if self.metric_scoring != 'eer' else float('inf')) + ) + self.speed_up() # move model to GPU + + # get current time + self.timenow = time_now + # create directory path + if 'task_target' not in config: + self.log_dir = os.path.join( + self.config['log_dir'], + self.config['model_name'] + '_' + self.timenow + ) + else: + task_str = f"_{config['task_target']}" if config['task_target'] is not None else "" + self.log_dir = os.path.join( + self.config['log_dir'], + self.config['model_name'] + task_str + '_' + self.timenow + ) + os.makedirs(self.log_dir, exist_ok=True) + + def get_writer(self, phase, dataset_key, metric_key): + writer_key = f"{phase}-{dataset_key}-{metric_key}" + if writer_key not in self.writers: + # update directory path + writer_path = os.path.join( + self.log_dir, + phase, + dataset_key, + metric_key, + "metric_board" + ) + os.makedirs(writer_path, exist_ok=True) + # update writers dictionary + self.writers[writer_key] = SummaryWriter(writer_path) + return self.writers[writer_key] + + + def speed_up(self): + self.model.to(device) + self.model.device = device + if self.config['ddp'] == True: + num_gpus = torch.cuda.device_count() + print(f'avai gpus: {num_gpus}') + # local_rank=[i for i in range(0,num_gpus)] + self.model = DDP(self.model, device_ids=[self.config['local_rank']],find_unused_parameters=True, output_device=self.config['local_rank']) + #self.optimizer = nn.DataParallel(self.optimizer, device_ids=[int(os.environ['LOCAL_RANK'])]) + + def setTrain(self): + self.model.train() + self.train = True + + def setEval(self): + self.model.eval() + self.train = False + + def load_ckpt(self, model_path): + if os.path.isfile(model_path): + saved = torch.load(model_path, map_location='cpu') + suffix = model_path.split('.')[-1] + if suffix == 'p': + self.model.load_state_dict(saved.state_dict()) + else: + self.model.load_state_dict(saved) + self.logger.info('Model found in {}'.format(model_path)) + else: + raise NotImplementedError( + "=> no model found at '{}'".format(model_path)) + + def save_ckpt(self, phase, dataset_key,ckpt_info=None,ckpt_name="ckpt_best.pth"): + save_dir = os.path.join(self.log_dir, phase, dataset_key) + os.makedirs(save_dir, exist_ok=True) + # ckpt_name = f"ckpt_best.pth" + save_path = os.path.join(save_dir, ckpt_name) + if self.config['ddp'] == True: + torch.save(self.model.state_dict(), save_path) + else: + if 'svdd' in self.config['model_name']: + torch.save({'R': self.model.R, + 'c': self.model.c, + 'state_dict': self.model.state_dict(),}, save_path) + else: + torch.save(self.model.state_dict(), save_path) + self.logger.info(f"Checkpoint saved to {save_path}, current ckpt is {ckpt_info}") + + def save_swa_ckpt(self): + save_dir = self.log_dir + os.makedirs(save_dir, exist_ok=True) + ckpt_name = f"swa.pth" + save_path = os.path.join(save_dir, ckpt_name) + torch.save(self.swa_model.state_dict(), save_path) + self.logger.info(f"SWA Checkpoint saved to {save_path}") + + + def save_feat(self, phase, fea, dataset_key): + save_dir = os.path.join(self.log_dir, phase, dataset_key) + os.makedirs(save_dir, exist_ok=True) + features = fea + feat_name = f"feat_best.npy" + save_path = os.path.join(save_dir, feat_name) + np.save(save_path, features) + self.logger.info(f"Feature saved to {save_path}") + + def save_data_dict(self, phase, data_dict, dataset_key): + if self.config['local_rank'] != 0: + return + save_dir = os.path.join(self.log_dir, phase, dataset_key) + os.makedirs(save_dir, exist_ok=True) + file_path = os.path.join(save_dir, f'data_dict_{phase}.pickle') + with open(file_path, 'wb') as file: + pickle.dump(data_dict, file) + self.logger.info(f"data_dict saved to {file_path}") + + def save_metrics(self, phase, metric_one_dataset, dataset_key): + save_dir = os.path.join(self.log_dir, phase, dataset_key) + os.makedirs(save_dir, exist_ok=True) + file_path = os.path.join(save_dir, 'metric_dict_best.pickle') + with open(file_path, 'wb') as file: + pickle.dump(metric_one_dataset, file) + self.logger.info(f"Metrics saved to {file_path}") + + def train_step(self,data_dict): + if self.config['optimizer']['type']=='sam': + for i in range(2): + predictions = self.model(data_dict) + losses = self.model.get_losses(data_dict, predictions) + if i == 0: + pred_first = predictions + losses_first = losses + self.optimizer.zero_grad() + losses['overall'].backward() + if i == 0: + self.optimizer.first_step(zero_grad=True) + else: + self.optimizer.second_step(zero_grad=True) + return losses_first, pred_first + else: + + predictions = self.model(data_dict) + if type(self.model) is DDP: + losses = self.model.module.get_losses(data_dict, predictions) + else: + losses = self.model.get_losses(data_dict, predictions) + self.optimizer.zero_grad() + losses['overall'].backward() + self.optimizer.step() + + + return losses,predictions + + + def train_epoch(self, epoch, train_data_loader, test_data_loaders=None): + self.logger.info("===> Epoch[{}] start!".format(epoch)) + + # test 1 time per epoch + times_per_epoch = 1 + test_step = len(train_data_loader) // times_per_epoch # test 10 times per epoch + step_cnt = epoch * len(train_data_loader) + + # save the training data_dict + data_dict = train_data_loader.dataset.data_dict + self.save_data_dict('train', data_dict, ','.join(self.config['train_dataset'])) + + # define training recorder + train_recorder_loss = defaultdict(Recorder) + train_recorder_metric = defaultdict(Recorder) + + self.logger.info("===> Start For Loop!".format(epoch)) + for iteration, data_dict in tqdm(enumerate(train_data_loader),total=len(train_data_loader)): + self.setTrain() + + # more elegant and more scalable way of moving data to GPU + for key in data_dict.keys(): + if data_dict[key]!=None and key!='name': + data_dict[key]=data_dict[key].cuda() + + losses,predictions=self.train_step(data_dict) + + # update learning rate + if 'SWA' in self.config and self.config['SWA'] and epoch>self.config['swa_start']: + self.swa_model.update_parameters(self.model) + + # compute training metric for each batch data + if type(self.model) is DDP: + batch_metrics = self.model.module.get_train_metrics(data_dict, predictions) + else: + batch_metrics = self.model.get_train_metrics(data_dict, predictions) + + # store data by recorder + ## store metric + for name, value in batch_metrics.items(): + train_recorder_metric[name].update(value) + ## store loss + for name, value in losses.items(): + train_recorder_loss[name].update(value) + + # run tensorboard to visualize the training process + if iteration % 100 == 0 and self.config['local_rank']==0: + if self.config['SWA'] and (epoch>self.config['swa_start'] or self.config['dry_run']): + self.scheduler.step() + # info for loss + loss_str = f"Iter: {step_cnt} " + for k, v in train_recorder_loss.items(): + v_avg = v.average() + if v_avg == None: + loss_str += f"training-loss, {k}: not calculated" + continue + loss_str += f"training-loss, {k}: {v_avg:.6f} " + # tensorboard-1. loss + writer = self.get_writer('train', ','.join(self.config['train_dataset']), k) + writer.add_scalar(f'train_loss/{k}', v_avg, global_step=step_cnt) + self.logger.info(loss_str) + # info for metric + metric_str = f"Iter: {step_cnt} " + for k, v in train_recorder_metric.items(): + v_avg = v.average() + if v_avg == None: + metric_str += f"training-metric, {k}: not calculated " + continue + metric_str += f"training-metric, {k}: {v_avg:.6f} " + # tensorboard-2. metric + writer = self.get_writer('train', ','.join(self.config['train_dataset']), k) + writer.add_scalar(f'train_metric/{k}', v_avg, global_step=step_cnt) + self.logger.info(metric_str) + + # clear recorder. + # Note we only consider the current 300 samples for computing batch-level loss/metric + for name, recorder in train_recorder_loss.items(): # clear loss recorder + recorder.clear() + for name, recorder in train_recorder_metric.items(): # clear metric recorder + recorder.clear() + + # run test + test_best_metric = None + if (step_cnt+1) % test_step == 0: + if test_data_loaders is not None and (not self.config['ddp'] ): + self.logger.info("===> Test start!") + test_best_metric = self.test_epoch(epoch, iteration, test_data_loaders, step_cnt,) + elif test_data_loaders is not None and (self.config['ddp'] and dist.get_rank() == 0): + self.logger.info("===> Test start!") + test_best_metric = self.test_epoch(epoch, iteration, test_data_loaders, step_cnt,) + else: + test_best_metric = None + + # total_end_time = time.time() + # total_elapsed_time = total_end_time - total_start_time + # print("total cost time: {:.2f} seconds".format(total_elapsed_time)) + step_cnt += 1 + + torch.cuda.empty_cache() + + return test_best_metric + + def get_respect_acc_bin(self,prob,label): + pred = np.where(prob > 0.5, 1, 0) + judge = (pred == label) + zero_num = len(label) - np.count_nonzero(label) + acc_fake = np.count_nonzero(judge[zero_num:]) / len(judge[zero_num:]) + acc_real = np.count_nonzero(judge[:zero_num]) / len(judge[:zero_num]) + return acc_real,acc_fake + + def get_respect_acc(self, pred_probs, labels): + pred_probs = torch.tensor(pred_probs) + labels = torch.tensor(labels) + + _, preds = torch.max(pred_probs, dim=1) # shape[N] + + classes = torch.unique(labels) + num_classes = len(classes) + + class_correct = torch.zeros(num_classes, dtype=torch.int64) + class_total = torch.zeros(num_classes, dtype=torch.int64) + + for label, pred in zip(labels, preds): + class_idx = label.item() + class_total[class_idx] += 1 + if label == pred: + class_correct[class_idx] += 1 + + + class_acc = {} + for i, cls in enumerate(classes): + if class_total[i] == 0: + class_acc[cls.item()] = 0.0 + else: + class_acc[cls.item()] = (class_correct[i] / class_total[i]).item() + + return class_acc + + def test_one_dataset(self, data_loader): + # define test recorder + test_recorder_loss = defaultdict(Recorder) + prediction_lists, feature_lists, label_lists = [], [], [] + for i, data_dict in tqdm(enumerate(data_loader),total=len(data_loader)): + if 'label_spe' in data_dict: + data_dict.pop('label_spe') # remove the specific label + # data_dict['label'] = torch.where(data_dict['label']!=0, 1, 0) # fix the label to 0 and 1 only + # move data to GPU elegantly + + for key in data_dict.keys(): + if data_dict[key] != None: + data_dict[key] = data_dict[key].cuda() + + # model forward without considering gradient computation + predictions = self.inference(data_dict) + + label_lists += list(data_dict['label'].cpu().detach().numpy()) + prediction_lists += list(predictions['prob'].cpu().detach().numpy()) + feature_lists += list(predictions['feat'].cpu().detach().numpy()) + + if type(self.model) is not AveragedModel: + # compute all losses for each batch data + if type(self.model) is DDP: + losses = self.model.module.get_losses(data_dict, predictions) + else: + losses = self.model.get_losses(data_dict, predictions) + + # store data by recorder + for name, value in losses.items(): + test_recorder_loss[name].update(value) + + return test_recorder_loss, np.array(prediction_lists), np.array(label_lists), np.array(feature_lists) + + def save_best(self,epoch,iteration,step,losses_one_dataset_recorder,key,metric_one_dataset): + best_metric = self.best_metrics_all_time[key].get(self.metric_scoring, + float('-inf') if self.metric_scoring != 'eer' else float( + 'inf')) + # Check if the current score is an improvement + improved = (metric_one_dataset[self.metric_scoring] > best_metric) if self.metric_scoring != 'eer' else ( + metric_one_dataset[self.metric_scoring] < best_metric) + if improved: + # Update the best metric + self.best_metrics_all_time[key][self.metric_scoring] = metric_one_dataset[self.metric_scoring] + if key == 'avg': + self.best_metrics_all_time[key]['dataset_dict'] = metric_one_dataset['dataset_dict'] + # Save checkpoint, feature, and metrics if specified in config + if self.config['save_ckpt'] and key not in FFpp_pool: + self.save_ckpt('test', key, f"{epoch}+{iteration}") + self.save_metrics('test', metric_one_dataset, key) + + # Save the latested ckpt + if self.config['save_latest_ckpt']: + self.save_ckpt('test', key, f"{epoch}+{iteration}", ckpt_name="ckpt_latest.pth") + + # loss + if losses_one_dataset_recorder is not None: + # Information for each dataset + loss_str = f"dataset: {key} step: {step} " + for k, v in losses_one_dataset_recorder.items(): + writer = self.get_writer('test', key, k) + v_avg = v.average() + if v_avg == None: + print(f'{k} is not calculated') + continue + # tensorboard-1. loss + writer.add_scalar(f'test_losses/{k}', v_avg, global_step=step) + loss_str += f"testing-loss, {k}: {v_avg:.6f} " + self.logger.info(loss_str) + + # metric + metric_str = f"dataset: {key} step: {step} " + for k, v in metric_one_dataset.items(): + if k == 'pred' or k == 'label' or k=='dataset_dict': + continue + metric_str += f"testing-metric, {k}: {v:.6f} " + # tensorboard-2. metric + writer = self.get_writer('test', key, k) + writer.add_scalar(f'test_metrics/{k}', v, global_step=step) + self.logger.info(metric_str) + + if self.config['local_rank'] == 0: + if 'pred' in metric_one_dataset: + # get acc for each class + self.logger.info("Start get_respect_acc()") + get_respect_acc = self.get_respect_acc(metric_one_dataset['pred'], metric_one_dataset['label']) + self.logger.info("End get_respect_acc()") + for cls_name, cls_val in get_respect_acc.items(): + metric_str = f"testing-metric, {cls_name}-acc:{cls_val:.4f}" + self.logger.info(metric_str) + writer.add_scalar(f'test_metrics/acc_{cls_name}', cls_val, global_step=step) + + def test_epoch(self, epoch, iteration, test_data_loaders, step): + self.setEval() + + losses_all_datasets, metrics_all_datasets = {}, {} + best_metrics_per_dataset = defaultdict(dict) # best metric for each dataset, for each metric + avg_metric = {'acc': 0, 'mAP': 0, 'dataset_dict':{}} + + # All Datasets + keys = test_data_loaders.keys() + for key in keys: + # save the testing data_dict + data_dict = test_data_loaders[key].dataset.data_dict + self.save_data_dict('test', data_dict, key) + + # compute loss for each dataset + losses_one_dataset_recorder, predictions_nps, label_nps, feature_nps = self.test_one_dataset(test_data_loaders[key]) + print(f'stack len:{predictions_nps.shape};{label_nps.shape};{len(data_dict["image"])}') + + losses_all_datasets[key] = losses_one_dataset_recorder + + # metric_one_dataset = get_test_metrics(y_pred=predictions_nps, y_true=label_nps, img_names=data_dict['image']) + metric_one_dataset = calculate_acc_for_test(label_nps, predictions_nps, self.config['backbone_config']['num_classes']) + + for metric_name, value in metric_one_dataset.items(): + if metric_name in avg_metric: + avg_metric[metric_name]+=value + avg_metric['dataset_dict'][key] = metric_one_dataset[self.metric_scoring] + if type(self.model) is AveragedModel: + metric_str = f"Iter Final for SWA: " + for k, v in metric_one_dataset.items(): + metric_str += f"testing-metric, {k}: {v:.6f} " + self.logger.info(metric_str) + continue + self.save_best(epoch,iteration,step,losses_one_dataset_recorder,key,metric_one_dataset) + + if len(keys)>0 and self.config.get('save_avg',False): + # calculate avg value + for key in avg_metric: + if key != 'dataset_dict': + avg_metric[key] /= len(keys) + self.save_best(epoch, iteration, step, None, 'avg', avg_metric) + + self.logger.info('===> Test Done!') + return self.best_metrics_all_time # return all types of mean metrics for determining the best ckpt + + @torch.no_grad() + def inference(self, data_dict): + predictions = self.model(data_dict, inference=True) + return predictions + + # @torch.no_grad() + # def inference(self, data_dict): + # from torch.cuda.amp import autocast + # with autocast(dtype=torch.float16): + # predictions = self.model(data_dict, inference=True) + # return predictions