Spaces:
Sleeping
Sleeping
Sync from GitHub via hub-sync
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .python-version +1 -0
- LICENSE +6 -0
- README.md +181 -6
- app.py +81 -0
- detector_codes/AIDE-main/LICENSE +21 -0
- detector_codes/AIDE-main/data/__init__.py +0 -0
- detector_codes/AIDE-main/data/datasets.py +199 -0
- detector_codes/AIDE-main/data/dct.py +151 -0
- detector_codes/AIDE-main/engine_finetune.py +202 -0
- detector_codes/AIDE-main/main_finetune.py +657 -0
- detector_codes/AIDE-main/models/AIDE.py +298 -0
- detector_codes/AIDE-main/models/__init__.py +0 -0
- detector_codes/AIDE-main/models/clip/__init__.py +1 -0
- detector_codes/AIDE-main/models/clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- detector_codes/AIDE-main/models/clip/clip.py +237 -0
- detector_codes/AIDE-main/models/clip/lora_clip.py +151 -0
- detector_codes/AIDE-main/models/clip/model.py +452 -0
- detector_codes/AIDE-main/models/clip/simple_tokenizer.py +132 -0
- detector_codes/AIDE-main/models/srm_filter_kernel.py +220 -0
- detector_codes/AIDE-main/models/utils.py +116 -0
- detector_codes/AIDE-main/optim_factory.py +222 -0
- detector_codes/AIDE-main/requirements.txt +36 -0
- detector_codes/AIDE-main/scripts/eval.sh +21 -0
- detector_codes/AIDE-main/scripts/train.sh +24 -0
- detector_codes/AIDE-main/utils.py +700 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/Word_Frequency_Analysis.py +147 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/data/__init__.py +43 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/data/datasets.py +213 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/decode_clipfeature_dataset.py +113 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/draw_tsne_kmean.py +301 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/inference.py +164 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/networks/__init__.py +0 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/networks/base_model.py +108 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/networks/c2p_clip.py +59 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/networks/decode_clipfeature_image.py +260 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/networks/trainer.py +193 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/options/__init__.py +0 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/options/base_options.py +131 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/options/test_options.py +17 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/options/train_options.py +26 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/requirements.txt +93 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/train.py +150 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/train_UniversalFakeDetect.sh +21 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/train_aigibench.py +231 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/train_aigibench.sh +24 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/train_genimage.sh +19 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/utils/logger.py +214 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/utils/util.py +52 -0
- detector_codes/C2P-CLIP-DeepfakeDetection-main/validate.py +23 -0
- detector_codes/C2P-DINOv2-main/dataset.py +45 -0
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12
|
LICENSE
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Licenses
|
| 2 |
+
|
| 3 |
+
Unless specifically labeled otherwise, these Datasets are provided to You under the terms of the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License (“CC BY-NC-SA 4.0”), with the additional terms included herein.
|
| 4 |
+
The CC BY-NC-SA 4.0 may be accessed at https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. When You download or use the Datasets from the Website or elsewhere, You are agreeing to comply with the terms of CC BY-NC-SA 4.0, and also agreeing to the Dataset Terms.
|
| 5 |
+
Where these Dataset Terms conflict with the terms of CC BY-NC-SA 4.0, these Dataset Terms shall prevail. We reiterate once again that this dataset is used only for non-commercial purposes such as academic research, teaching, or scientific publications.
|
| 6 |
+
We prohibits You from using the dataset or any derivative works for commercial purposes, such as selling data or using it for commercial gain.
|
README.md
CHANGED
|
@@ -1,14 +1,189 @@
|
|
| 1 |
---
|
| 2 |
-
title: DeForge
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: gray
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.14.0
|
| 8 |
-
python_version: '3.
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
-
license: cc-by-nc-sa-4.0
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: DeForge AI
|
| 3 |
+
emoji: 📉
|
| 4 |
+
colorFrom: yellow
|
| 5 |
colorTo: gray
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.14.0
|
| 8 |
+
python_version: '3.12'
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
+
<div align="center">
|
| 14 |
+
<br>
|
| 15 |
+
<h1>Is Artificial Intelligence Generated Image Detection a Solved Problem?</h1>
|
| 16 |
+
|
| 17 |
+
[Ziqiang Li](https://scholar.google.com/citations?user=mj5a8WgAAAAJ&hl=zh-CN)<sup>1</sup>, [Jiazhen Yan](https://scholar.google.com/citations?user=QkURh8EAAAAJ&hl=zh-CN)<sup>1</sup>, [Ziwen He](https://scholar.google.com/citations?user=PjkDK9cAAAAJ&hl=zh-CN)<sup>1</sup>, [Kai Zeng](https://scholar.google.com.hk/citations?user=TsI93SIAAAAJ&hl=zh-CN)<sup>2</sup>, [Weiwei Jiang](https://scholar.google.co.jp/citations?user=mbPN0hgAAAAJ&hl=zh-CN)<sup>1</sup>, [Lizhi Xiong](https://scholar.google.com/citations?user=-FzrEP4AAAAJ&hl=zh-CN)<sup>1</sup>, [Zhangjie Fu](https://scholar.google.com/citations?user=fO9NmagAAAAJ&hl=zh-CN)<sup>1‡</sup>
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
<div class="is-size-6 publication-authors">
|
| 21 |
+
<p class="footnote">
|
| 22 |
+
<span class="footnote-symbol"><sup>‡</sup></span>Corresponding author
|
| 23 |
+
</p>
|
| 24 |
+
</div>
|
| 25 |
+
|
| 26 |
+
<sup>1</sup>Nanjing University of Information Science and Technology <sup>2</sup>University of Siena
|
| 27 |
+
<p align="center">
|
| 28 |
+
<a href='https://github.com/HorizonTEL/AIGIBench'>
|
| 29 |
+
<img src='https://img.shields.io/badge/Project-Page-pink?style=flat&logo=Google%20chrome&logoColor=pink'>
|
| 30 |
+
</a>
|
| 31 |
+
<a href='https://arxiv.org/abs/2505.12335'>
|
| 32 |
+
<img src='https://img.shields.io/badge/Arxiv-2406.19435-A42C25?style=flat&logo=arXiv&logoColor=A42C25'>
|
| 33 |
+
</a>
|
| 34 |
+
<a href='https://arxiv.org/pdf/2505.12335'>
|
| 35 |
+
<img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'>
|
| 36 |
+
</a>
|
| 37 |
+
</p>
|
| 38 |
+
</div>
|
| 39 |
+
|
| 40 |
+
## 🔥 News
|
| 41 |
+
* [2025-09-19]🎉🎉🎉 AIGIBench is accepted by NeurIPS 2025 Datasets and Benchmarks.
|
| 42 |
+
|
| 43 |
+
##
|
| 44 |
+
|
| 45 |
+
**This repository is the official repository of the AIGIBench.**
|
| 46 |
+
|
| 47 |
+
> [!NOTE]
|
| 48 |
+
> This is a **modified version** of the original [AIGIBench](https://github.com/HorizonTEL/AIGIBench) repository. In addition to the original dataset and methods, it includes my custom detection solutions: **DeForge-AI** and **C2P-DINOv2** (intermediary solution).
|
| 49 |
+
|
| 50 |
+
**This repository contains the AIGIBench dataset and the evaluated methods.**
|
| 51 |
+
|
| 52 |
+
**AIGIBench** dataset contains two types of training and 25 test subsets. This dataset has the following advantages:
|
| 53 |
+
- Comprehensive generate types: including GAN-based Noise-to-Image Generation, Diffusion for Text-to-Image Generation, GANs for Deepfake, Diffusion for Personalized Generation, and Open-source Platforms.
|
| 54 |
+
- State-of-the-art Generators: MidjourneyV6, Stable Diffusion 3, Imagen, DALLE3, InstantID, FaceSwap, StyleGAN-XL and so on.
|
| 55 |
+
- Completely unknown generation method: Crawl pictures from communities and social media to build datasets CommunityAI & SocialRF, making detection more challenging.
|
| 56 |
+
|
| 57 |
+

|
| 58 |
+
|
| 59 |
+
If this project helps you, please fork, watch, and give a star to this repository.
|
| 60 |
+
|
| 61 |
+
## 📚Dataset
|
| 62 |
+
The training set and testing set used in the paper can be downloaded on [Huggingface](https://huggingface.co/datasets/HorizonTEL/AIGIBench)/[Baidu Netdisk](https://pan.baidu.com/s/1XTwfXlfqkGxAwYLxXuZbfA?pwd=sm6v).
|
| 63 |
+
|
| 64 |
+
Each folder contains compressed files. After unzip the file, files under the data root directory can be organized as follows.
|
| 65 |
+
### Train
|
| 66 |
+
AIGIBench introduces two training dataset settings: **(i) Setting-I:** Training on 144K images generated by ProGAN across four object categories—car, cat, chair, and horse. **(ii) Setting-II:** Training on 144K images generated by both SD-v1.4 and ProGAN, covering the same four object categories. The data of ProGAN comes from ForenSynths, and the data of sdv1.4 comes from GenImage. In order to maintain the fairness of the training data, we randomly select the sdv1.4 training images of GenImage to keep the same number as ProGAN, and then merge the data. The file directory is as follows:
|
| 67 |
+
|
| 68 |
+
```
|
| 69 |
+
├── train
|
| 70 |
+
│ ├── car
|
| 71 |
+
│ │ ├── 0_real
|
| 72 |
+
│ │ ├── 1_fake
|
| 73 |
+
│ ├── cat
|
| 74 |
+
│ │ ├── ...
|
| 75 |
+
│ ├── chair
|
| 76 |
+
│ │ ├── ...
|
| 77 |
+
│ ├── horse
|
| 78 |
+
│ │ ├── ...
|
| 79 |
+
│ ├── sdv1.4
|
| 80 |
+
│ │ ├── 0_real
|
| 81 |
+
│ │ ├── 1_fake
|
| 82 |
+
├── val
|
| 83 |
+
│ ├── ...
|
| 84 |
+
│ │ ├── 0_real
|
| 85 |
+
│ │ ├── 1_fake
|
| 86 |
+
│ │ ...
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
### Test
|
| 90 |
+
AIGIBench comprehensively tests the performance of the detector and builds a test dataset from five perspectives: GAN-based Noise-to-Image Generation, Diffusion for Text-to-Image Generation, GANs for Deepfake, Diffusion for Personalized Generation, and Open-source Platforms. The file directory is as follows:
|
| 91 |
+
```
|
| 92 |
+
├── test
|
| 93 |
+
│ ├── ProGAN
|
| 94 |
+
│ │ ├── 0_real
|
| 95 |
+
│ │ ├── 1_fake
|
| 96 |
+
│ ├── R3GAN
|
| 97 |
+
│ │ ├── ...
|
| 98 |
+
│ │ ...
|
| 99 |
+
│ ├── BlendFace
|
| 100 |
+
│ │ ├── 0_real
|
| 101 |
+
│ │ ├── 1_fake
|
| 102 |
+
│ ├── InSwap
|
| 103 |
+
│ │ ├── ...
|
| 104 |
+
│ │ ...
|
| 105 |
+
│ ├── FLUX1-dev
|
| 106 |
+
│ │ ├── 0_real
|
| 107 |
+
│ │ ├── 1_fake
|
| 108 |
+
│ ├── Midjourney-V6
|
| 109 |
+
│ │ ├── ...
|
| 110 |
+
│ │ ...
|
| 111 |
+
│ ├── BLIP
|
| 112 |
+
│ │ ├── 0_real
|
| 113 |
+
│ │ ├── 1_fake
|
| 114 |
+
│ ├── Infinite-ID
|
| 115 |
+
│ │ ├── ...
|
| 116 |
+
│ │ ...
|
| 117 |
+
│ ├── CommunityAI
|
| 118 |
+
│ │ ├── 0_real
|
| 119 |
+
│ │ ├── 1_fake
|
| 120 |
+
│ ├── SocialRF
|
| 121 |
+
│ │ ├── ...
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
*Note: The test set count in the paper contained some errors, which we are correcting here. Please note that the number of real images and generated images are consistent; only the number of generated images is listed below.*
|
| 125 |
+
| Generator | Number |
|
| 126 |
+
|:------: |:---------:|
|
| 127 |
+
| CommunityAI | 6000 |
|
| 128 |
+
| SocialRF | 3000 |
|
| 129 |
+
| FaceSwap | 4000 |
|
| 130 |
+
| ImSwap | 4000 |
|
| 131 |
+
| WFIR | 1000 |
|
| 132 |
+
|
| 133 |
+
## 🔍Detection Methods
|
| 134 |
+
We use the official code for all detection codes and make unified modifications to the input and output. The code we use for training in Setting-II is publicly available above, the corresponding pre-trained checkpoints are publicly available on [Huggingface](https://huggingface.co/HorizonTEL/AIGIBench). Of course, if you need the code from the original paper, the following is the corresponding detection code in the paper:
|
| 135 |
+
- [ResNet-50](https://github.com/huggingface/pytorch-image-models/tree/v0.6.12/timm): Deep Residual Learning for Image Recognition
|
| 136 |
+
- [CNNDetection](https://github.com/PeterWang512/CNNDetection): CNN-generated images are surprisingly easy to spot...for now
|
| 137 |
+
- [GramNet](https://github.com/liuzhengzhe/Global_Texture_Enhancement_for_Fake_Face_Detection_in_the-Wild): Global Texture Enhancement for Fake Face Detection in the Wild
|
| 138 |
+
- [LGrad](https://github.com/chuangchuangtan/LGrad): Learning on Gradients: Generalized Artifacts Representation for GAN-Generated Images Detection
|
| 139 |
+
- [CLIPDetection](https://github.com/WisconsinAIVision/UniversalFakeDetect): Towards Universal Fake Image Detectors that Generalize Across Generative Models
|
| 140 |
+
- [FreqNet](https://github.com/chuangchuangtan/FreqNet-DeepfakeDetection): FreqNet: A Frequency-domain Image Super-Resolution Network with Dicrete Cosine Transform
|
| 141 |
+
- [NPR](https://github.com/chuangchuangtan/NPR-DeepfakeDetection): Rethinking the Up-Sampling Operations in CNN-based Generative Network for Generalizable Deepfake Detection
|
| 142 |
+
- [DFFreq](https://github.com/HorizonTEL/DFFreq-main): Dual Frequency Branch Framework with Reconstructed Sliding Windows Attention for AI-Generated Image Detection
|
| 143 |
+
- [LaDeDa](https://github.com/barcavia/RealTime-DeepfakeDetection-in-the-RealWorld): Real-Time Deepfake Detection in the Real-World
|
| 144 |
+
- [AIDE](https://github.com/shilinyan99/AIDE): A Sanity Check for AI-generated Image Detection
|
| 145 |
+
- [SAFE](https://github.com/Ouxiang-Li/SAFE): Improving Synthetic Image Detection Towards Generalization: An Image Transformation Perspectives
|
| 146 |
+
- [Effort](https://github.com/YZY-stack/Effort-AIGI-Detection): Orthogonal Subspace Decomposition for Generalizable AI-Generated Image Detection
|
| 147 |
+
|
| 148 |
+
## ⏳Detection Results (Continuously updating)
|
| 149 |
+
**To ensure a fair comparison, we retrain all baseline methods on the Setting-II of AIGIBench.**
|
| 150 |
+
|
| 151 |
+
_If your retrained results differ significantly from those shown, please contact us._
|
| 152 |
+
| Method | Paper | Ref | R.Acc. | F.Acc. | Acc. | A.P. |
|
| 153 |
+
|:------: |:---------: |:---------:|:------:|:------:|:----:|:----:|
|
| 154 |
+
| CNNDetection | CNN-generated images are surprisingly easy to spot... for now | CVPR 2020 |**98.2**| 11.6 | 54.9 | 67.0 |
|
| 155 |
+
| Gram-Net | Global Texture Enhancement for Fake Face Detection In the Wild | CVPR 2020 | 90.5 | 26.6 | 58.6 | 62.4 |
|
| 156 |
+
| LGrad | Learning on Gradients: Generalized Artifacts Representation for GAN-Generated Images Detection | CVPR 2023 | 85.8 | 39.6 | 62.9 | 66.6 |
|
| 157 |
+
| UniFD | Towards Universal Fake Image Detectors that Generalize Across Generative Models | CVPR 2023 | 73.3 | 71.5 | 72.5 | 75.6 |
|
| 158 |
+
| FreqNet | Frequency-Aware Deepfake Detection: Improving Generalizability through Frequency Space Learning | AAAI 2024 | 65.9 | 66.4 | 66.2 | 70.1 |
|
| 159 |
+
| NPR | Rethinking the Up-Sampling Operations in CNN-based Generative Network for Generalizable Deepfake Detection | CVPR 2024 | 93.8 | 41.9 | 67.9 | 73.9 |
|
| 160 |
+
| Ladeda | Real-Time Deepfake Detection in the Real-World | Arxiv 2024| 91.7 | 54.9 | 73.4 | 79.3 |
|
| 161 |
+
| DFFreq | Dual Frequency Branch Framework with Reconstructed Sliding Windows Attention for AI-Generated Image Detection | TIFS 2026 | 91.8 | 58.0 | 75.1 | 82.2 |
|
| 162 |
+
| C2P-CLIP* | C2P-CLIP: Injecting Category Common Prompt in CLIP to Enhance Generalization in Deepfake Detection | AAAI 2025 | 93.8 | 49.8 | 71.8 | 82.2 |
|
| 163 |
+
| AIDE | A Sanity Check for AI-generated Image Detection | ICLR 2025 | 88.1 | 67.0 | 77.6 | 82.7 |
|
| 164 |
+
| SAFE | Improving Synthetic Image Detection Towards Generalization: An Image Transformation Perspectives | KDD 2025 | 89.0 | 66.6 | 78.1 | 83.6 |
|
| 165 |
+
| VIB-Net | Towards Universal AI-Generated Image Detection by Variational Information Bottleneck Network | CVPR 2025 | 60.6 |**78.1**| 69.3 | 70.9 |
|
| 166 |
+
| $D^3$ | $D^3$: Scaling Up Deepfake Detection by Learning from Discrepancy | CVPR 2025 | 81.0 | 46.4 | 63.7 | 68.9 |
|
| 167 |
+
| Effort | Orthogonal Subspace Decomposition for Generalizable AI-Generated Image Detection | ICML 2025 | 96.9 | 57.1 | 77.1 |**87.2**|
|
| 168 |
+
| FerretNet | FerretNet: Efficient Synthetic Image Detection via Local Pixel Dependencies | NIPS 2025 | 96.6 | 61.8 |**79.4**| 85.8 |
|
| 169 |
+
| LOTA | LOTA: Bit-Planes Guided AI-Generated Image Detection | ICCV 2025 | 89.3 | 65.1 | 77.4 | 83.1 |
|
| 170 |
+
| BSF | Beyond Semantic Features: Pixel-level Mapping for Generalized AI-Generated Image Detection | AAAI 2026 | 91.5 | 65.6 | 78.8 | 81.1 |
|
| 171 |
+
| LTD | Layer Consistency Matters: Elegant Latent Transition Discrepancy for Generalizable Synthetic Image Detection | CVPR 2026 | 82.0 | 67.7 | 74.9 | 77.6 |
|
| 172 |
+
|
| 173 |
+
**For specific reasons, in the following method, we directly utilize the official pre-trained weights for inference.**
|
| 174 |
+
| Method | Paper | Ref | R.Acc. | F.Acc. | Acc. | A.P. |
|
| 175 |
+
|:------: |:---------: |:---------:|:------:|:------:|:----:|:----:|
|
| 176 |
+
| DDA | Dual Data Alignment Makes AI-Generated Image Detector Easier Generalizable | NIPS 2025 | 93.9 | 69.3 | 81.6 | 90.2 |
|
| 177 |
+
|
| 178 |
+
## Citation
|
| 179 |
+
```
|
| 180 |
+
@inproceedings{li2025artificial,
|
| 181 |
+
title={Is Artificial Intelligence Generated Image Detection a Solved Problem?},
|
| 182 |
+
author={Li, Ziqiang and Yan, Jiazhen and He, Ziwen and Zeng, Kai and Jiang, Weiwei and Xiong, Lizhi and Fu, Zhangjie},
|
| 183 |
+
booktitle={Advances in Neural Information Processing Systems},
|
| 184 |
+
year={2025}
|
| 185 |
+
}
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
## Contact
|
| 189 |
+
If you have any question about this project, please feel free to contact 247918horizon@gmail.com
|
app.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
|
| 3 |
+
from detector_codes import DEVICE, detector_classes, weight_mapping
|
| 4 |
+
|
| 5 |
+
# Model cache to avoid reloading
|
| 6 |
+
model_cache = {'name': None, 'instance': None}
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def predict(model_name, input_image):
|
| 10 |
+
if input_image is None:
|
| 11 |
+
return 'Please upload an image.'
|
| 12 |
+
|
| 13 |
+
global model_cache
|
| 14 |
+
|
| 15 |
+
# Load model if not in cache or if model changed
|
| 16 |
+
if model_cache['name'] != model_name:
|
| 17 |
+
print(f'Loading model: {model_name}...')
|
| 18 |
+
try:
|
| 19 |
+
detector_class = detector_classes[model_name]
|
| 20 |
+
weights = weight_mapping[model_name]
|
| 21 |
+
model_cache['instance'] = detector_class(weights)
|
| 22 |
+
model_cache['name'] = model_name
|
| 23 |
+
except Exception as e:
|
| 24 |
+
return f'Error loading model {model_name}: {str(e)}'
|
| 25 |
+
|
| 26 |
+
detector = model_cache['instance']
|
| 27 |
+
|
| 28 |
+
# Preprocess image
|
| 29 |
+
try:
|
| 30 |
+
img_tensor = detector.transform(input_image).unsqueeze(0).to(DEVICE)
|
| 31 |
+
|
| 32 |
+
# Inference
|
| 33 |
+
p_fake = detector.detect(img_tensor).item()
|
| 34 |
+
p_real = 1.0 - p_fake
|
| 35 |
+
|
| 36 |
+
return {'Real Image': p_real, 'AI Generated': p_fake}
|
| 37 |
+
except Exception as e:
|
| 38 |
+
return f'Error during inference: {str(e)}'
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Define the Gradio interface
|
| 42 |
+
with gr.Blocks(title='AIGI Detector Bench') as demo:
|
| 43 |
+
gr.Markdown('# AIGI Detector Benchmark')
|
| 44 |
+
gr.Markdown(
|
| 45 |
+
"Select a model and upload an image to check if it's AI-generated or real."
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
with gr.Row():
|
| 49 |
+
with gr.Column():
|
| 50 |
+
model_dropdown = gr.Dropdown(
|
| 51 |
+
choices=sorted(list(detector_classes.keys())),
|
| 52 |
+
value='DeForge-AI',
|
| 53 |
+
label='Select Detection Model',
|
| 54 |
+
)
|
| 55 |
+
input_img = gr.Image(type='pil', label='Upload Image')
|
| 56 |
+
btn = gr.Button('Detect', variant='primary')
|
| 57 |
+
|
| 58 |
+
with gr.Column():
|
| 59 |
+
output_label = gr.Label(num_top_classes=2, label='Prediction')
|
| 60 |
+
|
| 61 |
+
btn.click(fn=predict, inputs=[model_dropdown, input_img], outputs=output_label)
|
| 62 |
+
|
| 63 |
+
gr.Markdown("""
|
| 64 |
+
### About
|
| 65 |
+
This tool is a **modified version** of the official [AIGIBench](https://github.com/HorizonTEL/AIGIBench) repository, featuring state-of-the-art AI-Generated Image (AIGI) detectors.
|
| 66 |
+
|
| 67 |
+
In this version, I have integrated the original baselines along with my own proposed solutions: **DeForge-AI** and **C2P-DINOv2**.
|
| 68 |
+
|
| 69 |
+
- **Project Page**: [NeurIPS 2025] Is Artificial Intelligence Generated Image Detection a Solved Problem?
|
| 70 |
+
- **Original Repository**: [HorizonTEL/AIGIBench](https://github.com/HorizonTEL/AIGIBench)
|
| 71 |
+
|
| 72 |
+
#### Featured Models:
|
| 73 |
+
- **DeForge-AI**: My proposed multi-modal forensic detector (optimized for diverse generators).
|
| 74 |
+
- **C2P-DINOv2**: My solution leveraging DINOv2 features (intermediary solution).
|
| 75 |
+
- **RIGID, AIDE, SAFE, Effort, NPR, LaDeDa, etc.**: Original SOTA baselines.
|
| 76 |
+
|
| 77 |
+
Each model has different strengths. DeForge-AI generally provides the best performance across diverse generators.
|
| 78 |
+
""")
|
| 79 |
+
|
| 80 |
+
if __name__ == '__main__':
|
| 81 |
+
demo.launch()
|
detector_codes/AIDE-main/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Shilin Yan
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
detector_codes/AIDE-main/data/__init__.py
ADDED
|
File without changes
|
detector_codes/AIDE-main/data/datasets.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
from torch.utils.data import Dataset
|
| 12 |
+
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import io
|
| 15 |
+
import torch
|
| 16 |
+
from .dct import DCT_base_Rec_Module
|
| 17 |
+
import random
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from torchvision.transforms import InterpolationMode
|
| 21 |
+
BICUBIC = InterpolationMode.BICUBIC
|
| 22 |
+
except ImportError:
|
| 23 |
+
BICUBIC = Image.BICUBIC
|
| 24 |
+
|
| 25 |
+
from PIL import ImageFile
|
| 26 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 27 |
+
import kornia.augmentation as K
|
| 28 |
+
|
| 29 |
+
Perturbations = K.container.ImageSequential(
|
| 30 |
+
K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 3.0), p=0.1),
|
| 31 |
+
K.RandomJPEG(jpeg_quality=(30, 100), p=0.1)
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
transform_before = transforms.Compose([
|
| 35 |
+
transforms.ToTensor(),
|
| 36 |
+
transforms.Lambda(lambda x: Perturbations(x)[0])
|
| 37 |
+
]
|
| 38 |
+
)
|
| 39 |
+
transform_before_test = transforms.Compose([
|
| 40 |
+
transforms.ToTensor(),
|
| 41 |
+
]
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
transform_train = transforms.Compose([
|
| 45 |
+
transforms.Resize([256, 256]),
|
| 46 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
transform_test_normalize = transforms.Compose([
|
| 50 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class TrainDataset(Dataset):
|
| 55 |
+
def __init__(self, is_train, args):
|
| 56 |
+
|
| 57 |
+
root = args.data_path if is_train else args.eval_data_path
|
| 58 |
+
|
| 59 |
+
self.data_list = []
|
| 60 |
+
|
| 61 |
+
if'GenImage' in root and root.split('/')[-1] != 'train':
|
| 62 |
+
file_path = root
|
| 63 |
+
|
| 64 |
+
if '0_real' not in os.listdir(file_path):
|
| 65 |
+
for folder_name in os.listdir(file_path):
|
| 66 |
+
|
| 67 |
+
assert (os.listdir(os.path.join(file_path, folder_name)) == ['0_real', '1_fake']) or (os.listdir(os.path.join(file_path, folder_name)) == ['1_fake', '0_real'])
|
| 68 |
+
|
| 69 |
+
for image_path in os.listdir(os.path.join(file_path, folder_name, '0_real')):
|
| 70 |
+
self.data_list.append({"image_path": os.path.join(file_path, folder_name, '0_real', image_path), "label" : 0})
|
| 71 |
+
|
| 72 |
+
for image_path in os.listdir(os.path.join(file_path, folder_name, '1_fake')):
|
| 73 |
+
self.data_list.append({"image_path": os.path.join(file_path, folder_name, '1_fake', image_path), "label" : 1})
|
| 74 |
+
|
| 75 |
+
else:
|
| 76 |
+
for image_path in os.listdir(os.path.join(file_path, '0_real')):
|
| 77 |
+
self.data_list.append({"image_path": os.path.join(file_path, '0_real', image_path), "label" : 0})
|
| 78 |
+
for image_path in os.listdir(os.path.join(file_path, '1_fake')):
|
| 79 |
+
self.data_list.append({"image_path": os.path.join(file_path, '1_fake', image_path), "label" : 1})
|
| 80 |
+
else:
|
| 81 |
+
|
| 82 |
+
for filename in os.listdir(root):
|
| 83 |
+
|
| 84 |
+
file_path = os.path.join(root, filename)
|
| 85 |
+
|
| 86 |
+
if '0_real' not in os.listdir(file_path):
|
| 87 |
+
for folder_name in os.listdir(file_path):
|
| 88 |
+
|
| 89 |
+
assert (os.listdir(os.path.join(file_path, folder_name)) == ['0_real', '1_fake']) or (os.listdir(os.path.join(file_path, folder_name)) == ['1_fake', '0_real'])
|
| 90 |
+
|
| 91 |
+
for image_path in os.listdir(os.path.join(file_path, folder_name, '0_real')):
|
| 92 |
+
self.data_list.append({"image_path": os.path.join(file_path, folder_name, '0_real', image_path), "label" : 0})
|
| 93 |
+
|
| 94 |
+
for image_path in os.listdir(os.path.join(file_path, folder_name, '1_fake')):
|
| 95 |
+
self.data_list.append({"image_path": os.path.join(file_path, folder_name, '1_fake', image_path), "label" : 1})
|
| 96 |
+
|
| 97 |
+
else:
|
| 98 |
+
for image_path in os.listdir(os.path.join(file_path, '0_real')):
|
| 99 |
+
self.data_list.append({"image_path": os.path.join(file_path, '0_real', image_path), "label" : 0})
|
| 100 |
+
for image_path in os.listdir(os.path.join(file_path, '1_fake')):
|
| 101 |
+
self.data_list.append({"image_path": os.path.join(file_path, '1_fake', image_path), "label" : 1})
|
| 102 |
+
|
| 103 |
+
self.dct = DCT_base_Rec_Module()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def __len__(self):
|
| 107 |
+
return len(self.data_list)
|
| 108 |
+
|
| 109 |
+
def __getitem__(self, index):
|
| 110 |
+
|
| 111 |
+
sample = self.data_list[index]
|
| 112 |
+
|
| 113 |
+
image_path, targets = sample['image_path'], sample['label']
|
| 114 |
+
|
| 115 |
+
try:
|
| 116 |
+
image = Image.open(image_path).convert('RGB')
|
| 117 |
+
except:
|
| 118 |
+
print(f'image error: {image_path}')
|
| 119 |
+
return self.__getitem__(random.randint(0, len(self.data_list) - 1))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
image = transform_before(image)
|
| 123 |
+
|
| 124 |
+
try:
|
| 125 |
+
x_minmin, x_maxmax, x_minmin1, x_maxmax1 = self.dct(image)
|
| 126 |
+
except:
|
| 127 |
+
print(f'image error: {image_path}, c, h, w: {image.shape}')
|
| 128 |
+
return self.__getitem__(random.randint(0, len(self.data_list) - 1))
|
| 129 |
+
|
| 130 |
+
x_0 = transform_train(image)
|
| 131 |
+
x_minmin = transform_train(x_minmin)
|
| 132 |
+
x_maxmax = transform_train(x_maxmax)
|
| 133 |
+
|
| 134 |
+
x_minmin1 = transform_train(x_minmin1)
|
| 135 |
+
x_maxmax1 = transform_train(x_maxmax1)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
return torch.stack([x_minmin, x_maxmax, x_minmin1, x_maxmax1, x_0], dim=0), torch.tensor(int(targets))
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class TestDataset(Dataset):
|
| 144 |
+
def __init__(self, is_train, args):
|
| 145 |
+
|
| 146 |
+
root = args.data_path if is_train else args.eval_data_path
|
| 147 |
+
|
| 148 |
+
self.data_list = []
|
| 149 |
+
|
| 150 |
+
file_path = root
|
| 151 |
+
|
| 152 |
+
if '0_real' not in os.listdir(file_path):
|
| 153 |
+
for folder_name in os.listdir(file_path):
|
| 154 |
+
|
| 155 |
+
assert (os.listdir(os.path.join(file_path, folder_name)) == ['0_real', '1_fake']) or (os.listdir(os.path.join(file_path, folder_name)) == ['1_fake', '0_real'])
|
| 156 |
+
|
| 157 |
+
for image_path in os.listdir(os.path.join(file_path, folder_name, '0_real')):
|
| 158 |
+
self.data_list.append({"image_path": os.path.join(file_path, folder_name, '0_real', image_path), "label" : 0})
|
| 159 |
+
|
| 160 |
+
for image_path in os.listdir(os.path.join(file_path, folder_name, '1_fake')):
|
| 161 |
+
self.data_list.append({"image_path": os.path.join(file_path, folder_name, '1_fake', image_path), "label" : 1})
|
| 162 |
+
|
| 163 |
+
else:
|
| 164 |
+
for image_path in os.listdir(os.path.join(file_path, '0_real')):
|
| 165 |
+
self.data_list.append({"image_path": os.path.join(file_path, '0_real', image_path), "label" : 0})
|
| 166 |
+
for image_path in os.listdir(os.path.join(file_path, '1_fake')):
|
| 167 |
+
self.data_list.append({"image_path": os.path.join(file_path, '1_fake', image_path), "label" : 1})
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
self.dct = DCT_base_Rec_Module()
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def __len__(self):
|
| 174 |
+
return len(self.data_list)
|
| 175 |
+
|
| 176 |
+
def __getitem__(self, index):
|
| 177 |
+
|
| 178 |
+
sample = self.data_list[index]
|
| 179 |
+
|
| 180 |
+
image_path, targets = sample['image_path'], sample['label']
|
| 181 |
+
|
| 182 |
+
image = Image.open(image_path).convert('RGB')
|
| 183 |
+
|
| 184 |
+
image = transform_before_test(image)
|
| 185 |
+
|
| 186 |
+
# x_max, x_min, x_max_min, x_minmin = self.dct(image)
|
| 187 |
+
|
| 188 |
+
x_minmin, x_maxmax, x_minmin1, x_maxmax1 = self.dct(image)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
x_0 = transform_train(image) # 上采样到256*256
|
| 192 |
+
x_minmin = transform_train(x_minmin)
|
| 193 |
+
x_maxmax = transform_train(x_maxmax)
|
| 194 |
+
|
| 195 |
+
x_minmin1 = transform_train(x_minmin1)
|
| 196 |
+
x_maxmax1 = transform_train(x_maxmax1)
|
| 197 |
+
|
| 198 |
+
return torch.stack([x_minmin, x_maxmax, x_minmin1, x_maxmax1, x_0], dim=0), torch.tensor(int(targets))
|
| 199 |
+
|
detector_codes/AIDE-main/data/dct.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def DCT_mat(size):
|
| 9 |
+
m = [[ (np.sqrt(1./size) if i == 0 else np.sqrt(2./size)) * np.cos((j + 0.5) * np.pi * i / size) for j in range(size)] for i in range(size)]
|
| 10 |
+
return m
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def generate_filter(start, end, size):
|
| 14 |
+
return [[0. if i + j > end or i + j < start else 1. for j in range(size)] for i in range(size)]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def norm_sigma(x):
|
| 18 |
+
return 2. * torch.sigmoid(x) - 1.
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Filter(nn.Module):
|
| 22 |
+
def __init__(self, size, band_start, band_end, use_learnable=False, norm=False):
|
| 23 |
+
super(Filter, self).__init__()
|
| 24 |
+
self.use_learnable = use_learnable
|
| 25 |
+
self.base = nn.Parameter(torch.tensor(generate_filter(band_start, band_end, size)), requires_grad=False)
|
| 26 |
+
if self.use_learnable:
|
| 27 |
+
self.learnable = nn.Parameter(torch.randn(size, size), requires_grad=True)
|
| 28 |
+
self.learnable.data.normal_(0., 0.1)
|
| 29 |
+
self.norm = norm
|
| 30 |
+
if norm:
|
| 31 |
+
self.ft_num = nn.Parameter(torch.sum(torch.tensor(generate_filter(band_start, band_end, size))), requires_grad=False)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def forward(self, x):
|
| 35 |
+
if self.use_learnable:
|
| 36 |
+
filt = self.base + norm_sigma(self.learnable)
|
| 37 |
+
else:
|
| 38 |
+
filt = self.base
|
| 39 |
+
|
| 40 |
+
if self.norm:
|
| 41 |
+
y = x * filt / self.ft_num
|
| 42 |
+
else:
|
| 43 |
+
y = x * filt
|
| 44 |
+
return y
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class DCT_base_Rec_Module(nn.Module):
|
| 48 |
+
"""_summary_
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
x: [C, H, W] -> [C*level, output, output]
|
| 52 |
+
"""
|
| 53 |
+
def __init__(self, window_size=32, stride=16, output=256, grade_N=6, level_fliter=[0]):
|
| 54 |
+
super().__init__()
|
| 55 |
+
|
| 56 |
+
assert output % window_size == 0
|
| 57 |
+
assert len(level_fliter) > 0
|
| 58 |
+
|
| 59 |
+
self.window_size = window_size
|
| 60 |
+
self.grade_N = grade_N
|
| 61 |
+
self.level_N = len(level_fliter)
|
| 62 |
+
self.N = (output // window_size) * (output // window_size)
|
| 63 |
+
|
| 64 |
+
self._DCT_patch = nn.Parameter(torch.tensor(DCT_mat(window_size)).float(), requires_grad=False)
|
| 65 |
+
self._DCT_patch_T = nn.Parameter(torch.transpose(torch.tensor(DCT_mat(window_size)).float(), 0, 1), requires_grad=False)
|
| 66 |
+
|
| 67 |
+
self.unfold = nn.Unfold(
|
| 68 |
+
kernel_size=(window_size, window_size), stride=stride
|
| 69 |
+
)
|
| 70 |
+
self.fold0 = nn.Fold(
|
| 71 |
+
output_size=(window_size, window_size),
|
| 72 |
+
kernel_size=(window_size, window_size),
|
| 73 |
+
stride=window_size
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
lm, mh = 2.82, 2
|
| 77 |
+
level_f = [
|
| 78 |
+
Filter(window_size, 0, window_size * 2)
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
self.level_filters = nn.ModuleList([level_f[i] for i in level_fliter])
|
| 82 |
+
self.grade_filters = nn.ModuleList([Filter(window_size, window_size * 2. / grade_N * i, window_size * 2. / grade_N * (i+1), norm=True) for i in range(grade_N)])
|
| 83 |
+
|
| 84 |
+
def forward(self, x):
|
| 85 |
+
|
| 86 |
+
N = self.N
|
| 87 |
+
grade_N = self.grade_N
|
| 88 |
+
level_N = self.level_N
|
| 89 |
+
window_size = self.window_size
|
| 90 |
+
C, W, H = x.shape
|
| 91 |
+
x_unfold = self.unfold(x.unsqueeze(0)).squeeze(0)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
_, L = x_unfold.shape
|
| 95 |
+
x_unfold = x_unfold.transpose(0, 1).reshape(L, C, window_size, window_size)
|
| 96 |
+
x_dct = self._DCT_patch @ x_unfold @ self._DCT_patch_T
|
| 97 |
+
|
| 98 |
+
y_list = []
|
| 99 |
+
for i in range(self.level_N):
|
| 100 |
+
x_pass = self.level_filters[i](x_dct)
|
| 101 |
+
y = self._DCT_patch_T @ x_pass @ self._DCT_patch
|
| 102 |
+
y_list.append(y)
|
| 103 |
+
level_x_unfold = torch.cat(y_list, dim=1)
|
| 104 |
+
|
| 105 |
+
grade = torch.zeros(L).to(x.device)
|
| 106 |
+
w, k = 1, 2
|
| 107 |
+
for _ in range(grade_N):
|
| 108 |
+
_x = torch.abs(x_dct)
|
| 109 |
+
_x = torch.log(_x + 1)
|
| 110 |
+
_x = self.grade_filters[_](_x)
|
| 111 |
+
_x = torch.sum(_x, dim=[1,2,3])
|
| 112 |
+
grade += w * _x
|
| 113 |
+
w *= k
|
| 114 |
+
|
| 115 |
+
_, idx = torch.sort(grade)
|
| 116 |
+
max_idx = torch.flip(idx, dims=[0])[:N]
|
| 117 |
+
maxmax_idx = max_idx[0]
|
| 118 |
+
if len(max_idx) == 1:
|
| 119 |
+
maxmax_idx1 = max_idx[0]
|
| 120 |
+
else:
|
| 121 |
+
maxmax_idx1 = max_idx[1]
|
| 122 |
+
|
| 123 |
+
min_idx = idx[:N]
|
| 124 |
+
minmin_idx = idx[0]
|
| 125 |
+
if len(min_idx) == 1:
|
| 126 |
+
minmin_idx1 = idx[0]
|
| 127 |
+
else:
|
| 128 |
+
minmin_idx1 = idx[1]
|
| 129 |
+
|
| 130 |
+
x_minmin = torch.index_select(level_x_unfold, 0, minmin_idx)
|
| 131 |
+
x_maxmax = torch.index_select(level_x_unfold, 0, maxmax_idx)
|
| 132 |
+
x_minmin1 = torch.index_select(level_x_unfold, 0, minmin_idx1)
|
| 133 |
+
x_maxmax1 = torch.index_select(level_x_unfold, 0, maxmax_idx1)
|
| 134 |
+
|
| 135 |
+
x_minmin = x_minmin.reshape(1, level_N*C*window_size* window_size).transpose(0, 1)
|
| 136 |
+
x_maxmax = x_maxmax.reshape(1, level_N*C*window_size* window_size).transpose(0, 1)
|
| 137 |
+
x_minmin1 = x_minmin1.reshape(1, level_N*C*window_size* window_size).transpose(0, 1)
|
| 138 |
+
x_maxmax1 = x_maxmax1.reshape(1, level_N*C*window_size* window_size).transpose(0, 1)
|
| 139 |
+
|
| 140 |
+
x_minmin = self.fold0(x_minmin)
|
| 141 |
+
x_maxmax = self.fold0(x_maxmax)
|
| 142 |
+
x_minmin1 = self.fold0(x_minmin1)
|
| 143 |
+
x_maxmax1 = self.fold0(x_maxmax1)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
return x_minmin, x_maxmax, x_minmin1, x_maxmax1
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
|
detector_codes/AIDE-main/engine_finetune.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Iterable, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import utils
|
| 6 |
+
from scipy.special import softmax
|
| 7 |
+
from sklearn.metrics import accuracy_score, average_precision_score
|
| 8 |
+
from timm.data import Mixup
|
| 9 |
+
from timm.utils import ModelEma, accuracy
|
| 10 |
+
from utils import adjust_learning_rate
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def train_one_epoch(
|
| 14 |
+
model: torch.nn.Module,
|
| 15 |
+
criterion: torch.nn.Module,
|
| 16 |
+
data_loader: Iterable,
|
| 17 |
+
optimizer: torch.optim.Optimizer,
|
| 18 |
+
device: torch.device,
|
| 19 |
+
epoch: int,
|
| 20 |
+
loss_scaler,
|
| 21 |
+
max_norm: float = 0,
|
| 22 |
+
model_ema: Optional[ModelEma] = None,
|
| 23 |
+
mixup_fn: Optional[Mixup] = None,
|
| 24 |
+
log_writer=None,
|
| 25 |
+
args=None,
|
| 26 |
+
):
|
| 27 |
+
model.train(True)
|
| 28 |
+
metric_logger = utils.MetricLogger(delimiter=' ')
|
| 29 |
+
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
| 30 |
+
header = 'Epoch: [{}]'.format(epoch)
|
| 31 |
+
print_freq = 100
|
| 32 |
+
|
| 33 |
+
update_freq = args.update_freq
|
| 34 |
+
use_amp = args.use_amp
|
| 35 |
+
optimizer.zero_grad()
|
| 36 |
+
|
| 37 |
+
for data_iter_step, (samples, targets) in enumerate(
|
| 38 |
+
metric_logger.log_every(data_loader, print_freq, header)
|
| 39 |
+
):
|
| 40 |
+
# we use a per iteration (instead of per epoch) lr scheduler
|
| 41 |
+
if data_iter_step % update_freq == 0:
|
| 42 |
+
adjust_learning_rate(
|
| 43 |
+
optimizer, data_iter_step / len(data_loader) + epoch, args
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
samples = samples.to(device, non_blocking=True)
|
| 47 |
+
targets = targets.to(device, non_blocking=True)
|
| 48 |
+
|
| 49 |
+
if mixup_fn is not None:
|
| 50 |
+
samples, targets = mixup_fn(samples, targets)
|
| 51 |
+
|
| 52 |
+
if use_amp:
|
| 53 |
+
with torch.cuda.amp.autocast():
|
| 54 |
+
output = model(samples)
|
| 55 |
+
loss = criterion(output, targets)
|
| 56 |
+
else: # full precision
|
| 57 |
+
output = model(samples)
|
| 58 |
+
loss = criterion(output, targets)
|
| 59 |
+
|
| 60 |
+
loss_value = loss.item()
|
| 61 |
+
|
| 62 |
+
if not math.isfinite(loss_value):
|
| 63 |
+
print('Loss is {}, stopping training'.format(loss_value))
|
| 64 |
+
assert math.isfinite(loss_value)
|
| 65 |
+
|
| 66 |
+
if use_amp:
|
| 67 |
+
# this attribute is added by timm on one optimizer (adahessian)
|
| 68 |
+
is_second_order = (
|
| 69 |
+
hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
|
| 70 |
+
)
|
| 71 |
+
loss /= update_freq
|
| 72 |
+
grad_norm = loss_scaler(
|
| 73 |
+
loss,
|
| 74 |
+
optimizer,
|
| 75 |
+
clip_grad=max_norm,
|
| 76 |
+
parameters=model.parameters(),
|
| 77 |
+
create_graph=is_second_order,
|
| 78 |
+
update_grad=(data_iter_step + 1) % update_freq == 0,
|
| 79 |
+
)
|
| 80 |
+
if (data_iter_step + 1) % update_freq == 0:
|
| 81 |
+
optimizer.zero_grad()
|
| 82 |
+
if model_ema is not None:
|
| 83 |
+
model_ema.update(model)
|
| 84 |
+
else: # full precision
|
| 85 |
+
loss /= update_freq
|
| 86 |
+
loss.backward()
|
| 87 |
+
if (data_iter_step + 1) % update_freq == 0:
|
| 88 |
+
optimizer.step()
|
| 89 |
+
optimizer.zero_grad()
|
| 90 |
+
if model_ema is not None:
|
| 91 |
+
model_ema.update(model)
|
| 92 |
+
|
| 93 |
+
torch.cuda.synchronize()
|
| 94 |
+
|
| 95 |
+
if mixup_fn is None:
|
| 96 |
+
class_acc = (output.max(-1)[-1] == targets).float().mean()
|
| 97 |
+
else:
|
| 98 |
+
class_acc = None
|
| 99 |
+
|
| 100 |
+
metric_logger.update(loss=loss_value)
|
| 101 |
+
metric_logger.update(class_acc=class_acc)
|
| 102 |
+
min_lr = 10.0
|
| 103 |
+
max_lr = 0.0
|
| 104 |
+
for group in optimizer.param_groups:
|
| 105 |
+
min_lr = min(min_lr, group['lr'])
|
| 106 |
+
max_lr = max(max_lr, group['lr'])
|
| 107 |
+
|
| 108 |
+
metric_logger.update(lr=max_lr)
|
| 109 |
+
metric_logger.update(min_lr=min_lr)
|
| 110 |
+
weight_decay_value = None
|
| 111 |
+
for group in optimizer.param_groups:
|
| 112 |
+
if group['weight_decay'] > 0:
|
| 113 |
+
weight_decay_value = group['weight_decay']
|
| 114 |
+
metric_logger.update(weight_decay=weight_decay_value)
|
| 115 |
+
if use_amp:
|
| 116 |
+
metric_logger.update(grad_norm=grad_norm)
|
| 117 |
+
if log_writer is not None:
|
| 118 |
+
log_writer.update(loss=loss_value, head='loss')
|
| 119 |
+
log_writer.update(class_acc=class_acc, head='loss')
|
| 120 |
+
log_writer.update(lr=max_lr, head='opt')
|
| 121 |
+
log_writer.update(min_lr=min_lr, head='opt')
|
| 122 |
+
log_writer.update(weight_decay=weight_decay_value, head='opt')
|
| 123 |
+
if use_amp:
|
| 124 |
+
log_writer.update(grad_norm=grad_norm, head='opt')
|
| 125 |
+
log_writer.set_step()
|
| 126 |
+
|
| 127 |
+
print('Averaged stats:', metric_logger)
|
| 128 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@torch.no_grad()
|
| 132 |
+
def evaluate(data_loader, model, device, use_amp=False):
|
| 133 |
+
criterion = torch.nn.CrossEntropyLoss()
|
| 134 |
+
|
| 135 |
+
metric_logger = utils.MetricLogger(delimiter=' ')
|
| 136 |
+
header = 'Test:'
|
| 137 |
+
|
| 138 |
+
# switch to evaluation mode
|
| 139 |
+
model.eval()
|
| 140 |
+
|
| 141 |
+
predictions = []
|
| 142 |
+
labels = []
|
| 143 |
+
|
| 144 |
+
for index, batch in enumerate(metric_logger.log_every(data_loader, 10, header)):
|
| 145 |
+
images = batch[0]
|
| 146 |
+
target = batch[-1]
|
| 147 |
+
|
| 148 |
+
images = images.to(device, non_blocking=True)
|
| 149 |
+
target = target.to(device, non_blocking=True)
|
| 150 |
+
|
| 151 |
+
# compute output
|
| 152 |
+
if use_amp:
|
| 153 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 154 |
+
output = model(images)
|
| 155 |
+
if isinstance(output, dict):
|
| 156 |
+
output = output['logits']
|
| 157 |
+
loss = criterion(output, target)
|
| 158 |
+
else:
|
| 159 |
+
output = model(images) # [bs, num_cls]
|
| 160 |
+
if isinstance(output, dict):
|
| 161 |
+
output = output['logits']
|
| 162 |
+
|
| 163 |
+
loss = criterion(output, target)
|
| 164 |
+
|
| 165 |
+
predictions.append(output)
|
| 166 |
+
labels.append(target)
|
| 167 |
+
|
| 168 |
+
torch.cuda.synchronize()
|
| 169 |
+
|
| 170 |
+
acc1, acc5 = accuracy(output, target, topk=(1, 2))
|
| 171 |
+
|
| 172 |
+
batch_size = images.shape[0]
|
| 173 |
+
metric_logger.update(loss=loss.item())
|
| 174 |
+
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
|
| 175 |
+
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
|
| 176 |
+
|
| 177 |
+
print(
|
| 178 |
+
'* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'.format(
|
| 179 |
+
top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss
|
| 180 |
+
)
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
# Concatenate predictions and labels
|
| 184 |
+
predictions = torch.cat(predictions, dim=0)
|
| 185 |
+
labels = torch.cat(labels, dim=0)
|
| 186 |
+
|
| 187 |
+
y_pred = softmax(predictions.detach().cpu().numpy(), axis=1)[:, 1]
|
| 188 |
+
y_true = labels.detach().cpu().numpy()
|
| 189 |
+
y_true = y_true.astype(int)
|
| 190 |
+
|
| 191 |
+
acc = accuracy_score(y_true, y_pred > 0.5)
|
| 192 |
+
r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > 0.5)
|
| 193 |
+
f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > 0.5)
|
| 194 |
+
ap = average_precision_score(y_true, y_pred)
|
| 195 |
+
|
| 196 |
+
return (
|
| 197 |
+
{k: meter.global_avg for k, meter in metric_logger.meters.items()},
|
| 198 |
+
acc,
|
| 199 |
+
ap,
|
| 200 |
+
r_acc,
|
| 201 |
+
f_acc,
|
| 202 |
+
)
|
detector_codes/AIDE-main/main_finetune.py
ADDED
|
@@ -0,0 +1,657 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import csv
|
| 10 |
+
import datetime
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
import time
|
| 14 |
+
import warnings
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
import models.AIDE as AIDE
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torch.backends.cudnn as cudnn
|
| 21 |
+
import utils
|
| 22 |
+
from data.datasets import TestDataset, TrainDataset
|
| 23 |
+
from engine_finetune import evaluate, train_one_epoch
|
| 24 |
+
from optim_factory import create_optimizer
|
| 25 |
+
from timm.data.mixup import Mixup
|
| 26 |
+
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
| 27 |
+
from timm.utils import ModelEma
|
| 28 |
+
from utils import NativeScalerWithGradNormCount as NativeScaler
|
| 29 |
+
from utils import str2bool
|
| 30 |
+
|
| 31 |
+
warnings.filterwarnings('ignore')
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_args_parser():
|
| 35 |
+
parser = argparse.ArgumentParser('Resnet fine-tuning', add_help=False)
|
| 36 |
+
parser.add_argument('--batch_size', default=32, type=int, help='Per GPU batch size')
|
| 37 |
+
parser.add_argument('--epochs', default=100, type=int)
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
'--update_freq', default=1, type=int, help='gradient accumulation steps'
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Model parameters
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
'--model',
|
| 45 |
+
default='AIDE',
|
| 46 |
+
type=str,
|
| 47 |
+
metavar='MODEL',
|
| 48 |
+
help='Name of model to train',
|
| 49 |
+
)
|
| 50 |
+
parser.add_argument(
|
| 51 |
+
'--resnet_path',
|
| 52 |
+
default=None,
|
| 53 |
+
type=str,
|
| 54 |
+
metavar='MODEL',
|
| 55 |
+
help='Path of resnet model',
|
| 56 |
+
)
|
| 57 |
+
parser.add_argument(
|
| 58 |
+
'--convnext_path',
|
| 59 |
+
default=None,
|
| 60 |
+
type=str,
|
| 61 |
+
metavar='MODEL',
|
| 62 |
+
help='Path of ConvNeXt of model',
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# EMA related parameters
|
| 66 |
+
parser.add_argument('--model_ema', type=str2bool, default=False)
|
| 67 |
+
parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
|
| 68 |
+
parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='')
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
'--model_ema_eval',
|
| 71 |
+
type=str2bool,
|
| 72 |
+
default=False,
|
| 73 |
+
help='Using ema to eval during training.',
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# Optimization parameters
|
| 77 |
+
parser.add_argument(
|
| 78 |
+
'--clip_grad',
|
| 79 |
+
type=float,
|
| 80 |
+
default=None,
|
| 81 |
+
metavar='NORM',
|
| 82 |
+
help='Clip gradient norm (default: None, no clipping)',
|
| 83 |
+
)
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
'--weight_decay', type=float, default=0.0, help='weight decay (default: 0.05)'
|
| 86 |
+
)
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
'--lr',
|
| 89 |
+
type=float,
|
| 90 |
+
default=None,
|
| 91 |
+
metavar='LR',
|
| 92 |
+
help='learning rate (absolute lr)',
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
'--blr',
|
| 96 |
+
type=float,
|
| 97 |
+
default=5e-4,
|
| 98 |
+
metavar='LR',
|
| 99 |
+
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256',
|
| 100 |
+
)
|
| 101 |
+
parser.add_argument('--layer_decay', type=float, default=1.0)
|
| 102 |
+
parser.add_argument(
|
| 103 |
+
'--min_lr',
|
| 104 |
+
type=float,
|
| 105 |
+
default=1e-6,
|
| 106 |
+
metavar='LR',
|
| 107 |
+
help='lower lr bound for cyclic schedulers that hit 0 (1e-6)',
|
| 108 |
+
)
|
| 109 |
+
parser.add_argument(
|
| 110 |
+
'--warmup_epochs',
|
| 111 |
+
type=int,
|
| 112 |
+
default=0,
|
| 113 |
+
metavar='N',
|
| 114 |
+
help='epochs to warmup LR, if scheduler supports',
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
parser.add_argument(
|
| 118 |
+
'--warmup_steps',
|
| 119 |
+
type=int,
|
| 120 |
+
default=-1,
|
| 121 |
+
metavar='N',
|
| 122 |
+
help='num of steps to warmup LR, will overload warmup_epochs if set > 0',
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
'--opt',
|
| 126 |
+
default='adamw',
|
| 127 |
+
type=str,
|
| 128 |
+
metavar='OPTIMIZER',
|
| 129 |
+
help='Optimizer (default: "adamw"',
|
| 130 |
+
)
|
| 131 |
+
parser.add_argument(
|
| 132 |
+
'--opt_eps',
|
| 133 |
+
default=1e-8,
|
| 134 |
+
type=float,
|
| 135 |
+
metavar='EPSILON',
|
| 136 |
+
help='Optimizer Epsilon (default: 1e-8)',
|
| 137 |
+
)
|
| 138 |
+
parser.add_argument(
|
| 139 |
+
'--opt_betas',
|
| 140 |
+
default=None,
|
| 141 |
+
type=float,
|
| 142 |
+
nargs='+',
|
| 143 |
+
metavar='BETA',
|
| 144 |
+
help='Optimizer Betas (default: None, use opt default)',
|
| 145 |
+
)
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
'--momentum',
|
| 148 |
+
type=float,
|
| 149 |
+
default=0.9,
|
| 150 |
+
metavar='M',
|
| 151 |
+
help='SGD momentum (default: 0.9)',
|
| 152 |
+
)
|
| 153 |
+
parser.add_argument(
|
| 154 |
+
'--weight_decay_end',
|
| 155 |
+
type=float,
|
| 156 |
+
default=None,
|
| 157 |
+
help="""Final value of the
|
| 158 |
+
weight decay. We use a cosine schedule for WD and using a larger decay by
|
| 159 |
+
the end of training improves performance for ViTs.""",
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# Augmentation parameters
|
| 163 |
+
parser.add_argument(
|
| 164 |
+
'--color_jitter',
|
| 165 |
+
type=float,
|
| 166 |
+
default=None,
|
| 167 |
+
metavar='PCT',
|
| 168 |
+
help='Color jitter factor (enabled only when not using Auto/RandAug)',
|
| 169 |
+
)
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
'--aa',
|
| 172 |
+
type=str,
|
| 173 |
+
default='rand-m9-mstd0.5-inc1',
|
| 174 |
+
metavar='NAME',
|
| 175 |
+
help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)',
|
| 176 |
+
)
|
| 177 |
+
parser.add_argument(
|
| 178 |
+
'--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)'
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
'--train_interpolation',
|
| 183 |
+
type=str,
|
| 184 |
+
default='bicubic',
|
| 185 |
+
help='Training interpolation (random, bilinear, bicubic default: "bicubic")',
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# * Random Erase params
|
| 189 |
+
parser.add_argument(
|
| 190 |
+
'--reprob',
|
| 191 |
+
type=float,
|
| 192 |
+
default=0.25,
|
| 193 |
+
metavar='PCT',
|
| 194 |
+
help='Random erase prob (default: 0.25)',
|
| 195 |
+
)
|
| 196 |
+
parser.add_argument(
|
| 197 |
+
'--remode',
|
| 198 |
+
type=str,
|
| 199 |
+
default='pixel',
|
| 200 |
+
help='Random erase mode (default: "pixel")',
|
| 201 |
+
)
|
| 202 |
+
parser.add_argument(
|
| 203 |
+
'--recount', type=int, default=1, help='Random erase count (default: 1)'
|
| 204 |
+
)
|
| 205 |
+
parser.add_argument(
|
| 206 |
+
'--resplit',
|
| 207 |
+
type=str2bool,
|
| 208 |
+
default=False,
|
| 209 |
+
help='Do not random erase first (clean) augmentation split',
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# * Mixup params
|
| 213 |
+
parser.add_argument(
|
| 214 |
+
'--mixup', type=float, default=0.0, help='mixup alpha, mixup enabled if > 0.'
|
| 215 |
+
)
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
'--cutmix', type=float, default=0.0, help='cutmix alpha, cutmix enabled if > 0.'
|
| 218 |
+
)
|
| 219 |
+
parser.add_argument(
|
| 220 |
+
'--cutmix_minmax',
|
| 221 |
+
type=float,
|
| 222 |
+
nargs='+',
|
| 223 |
+
default=None,
|
| 224 |
+
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)',
|
| 225 |
+
)
|
| 226 |
+
parser.add_argument(
|
| 227 |
+
'--mixup_prob',
|
| 228 |
+
type=float,
|
| 229 |
+
default=1.0,
|
| 230 |
+
help='Probability of performing mixup or cutmix when either/both is enabled',
|
| 231 |
+
)
|
| 232 |
+
parser.add_argument(
|
| 233 |
+
'--mixup_switch_prob',
|
| 234 |
+
type=float,
|
| 235 |
+
default=0.5,
|
| 236 |
+
help='Probability of switching to cutmix when both mixup and cutmix enabled',
|
| 237 |
+
)
|
| 238 |
+
parser.add_argument(
|
| 239 |
+
'--mixup_mode',
|
| 240 |
+
type=str,
|
| 241 |
+
default='batch',
|
| 242 |
+
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"',
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
# * Finetuning params
|
| 246 |
+
parser.add_argument('--finetune', default='', help='finetune from checkpoint')
|
| 247 |
+
parser.add_argument(
|
| 248 |
+
'--head_init_scale',
|
| 249 |
+
default=0.001,
|
| 250 |
+
type=float,
|
| 251 |
+
help='classifier head initial scale, typically adjusted in fine-tuning',
|
| 252 |
+
)
|
| 253 |
+
parser.add_argument(
|
| 254 |
+
'--model_key',
|
| 255 |
+
default='model|module',
|
| 256 |
+
type=str,
|
| 257 |
+
help='which key to load from saved state dict, usually model or model_ema',
|
| 258 |
+
)
|
| 259 |
+
parser.add_argument('--model_prefix', default='', type=str)
|
| 260 |
+
|
| 261 |
+
# Dataset parameters
|
| 262 |
+
parser.add_argument(
|
| 263 |
+
'--data_path', default='path/dataset', type=str, help='dataset path'
|
| 264 |
+
)
|
| 265 |
+
parser.add_argument(
|
| 266 |
+
'--nb_classes', default=2, type=int, help='number of the classification types'
|
| 267 |
+
)
|
| 268 |
+
parser.add_argument(
|
| 269 |
+
'--output_dir',
|
| 270 |
+
default='./5class-checkpoints',
|
| 271 |
+
help='path where to save, empty for no saving',
|
| 272 |
+
)
|
| 273 |
+
parser.add_argument('--log_dir', default=None, help='path where to tensorboard log')
|
| 274 |
+
parser.add_argument(
|
| 275 |
+
'--device', default='cuda:0', help='device to use for training / testing'
|
| 276 |
+
)
|
| 277 |
+
parser.add_argument('--seed', default=0, type=int)
|
| 278 |
+
parser.add_argument('--resume', default='', help='resume from checkpoint')
|
| 279 |
+
|
| 280 |
+
parser.add_argument(
|
| 281 |
+
'--eval_data_path', default=None, type=str, help='dataset path for evaluation'
|
| 282 |
+
)
|
| 283 |
+
parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True)
|
| 284 |
+
parser.add_argument(
|
| 285 |
+
'--data_set',
|
| 286 |
+
default='IMNET',
|
| 287 |
+
choices=['CIFAR', 'IMNET', 'image_folder'],
|
| 288 |
+
type=str,
|
| 289 |
+
help='ImageNet dataset path',
|
| 290 |
+
)
|
| 291 |
+
parser.add_argument('--auto_resume', type=str2bool, default=True)
|
| 292 |
+
parser.add_argument('--save_ckpt', type=str2bool, default=True)
|
| 293 |
+
parser.add_argument('--save_ckpt_freq', default=5, type=int)
|
| 294 |
+
parser.add_argument('--save_ckpt_num', default=100, type=int)
|
| 295 |
+
|
| 296 |
+
parser.add_argument(
|
| 297 |
+
'--start_epoch', default=0, type=int, metavar='N', help='start epoch'
|
| 298 |
+
)
|
| 299 |
+
parser.add_argument(
|
| 300 |
+
'--eval', type=str2bool, default=False, help='Perform evaluation only'
|
| 301 |
+
)
|
| 302 |
+
parser.add_argument(
|
| 303 |
+
'--dist_eval',
|
| 304 |
+
type=str2bool,
|
| 305 |
+
default=True,
|
| 306 |
+
help='Enabling distributed evaluation',
|
| 307 |
+
)
|
| 308 |
+
parser.add_argument(
|
| 309 |
+
'--disable_eval',
|
| 310 |
+
type=str2bool,
|
| 311 |
+
default=False,
|
| 312 |
+
help='Disabling evaluation during training',
|
| 313 |
+
)
|
| 314 |
+
parser.add_argument('--num_workers', default=8, type=int)
|
| 315 |
+
parser.add_argument(
|
| 316 |
+
'--pin_mem',
|
| 317 |
+
type=str2bool,
|
| 318 |
+
default=True,
|
| 319 |
+
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.',
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# Evaluation parameters
|
| 323 |
+
parser.add_argument('--crop_pct', type=float, default=None)
|
| 324 |
+
|
| 325 |
+
# GPU selection parameter
|
| 326 |
+
parser.add_argument(
|
| 327 |
+
'--gpu_ids',
|
| 328 |
+
type=str,
|
| 329 |
+
default='2',
|
| 330 |
+
help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU',
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
parser.add_argument(
|
| 334 |
+
'--use_amp',
|
| 335 |
+
type=str2bool,
|
| 336 |
+
default=False,
|
| 337 |
+
help='Use apex AMP (Automatic Mixed Precision) or not',
|
| 338 |
+
)
|
| 339 |
+
return parser
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def main(args):
|
| 343 |
+
print(args)
|
| 344 |
+
|
| 345 |
+
# Set device based on gpu_ids
|
| 346 |
+
if args.gpu_ids == '-1':
|
| 347 |
+
device = torch.device('cpu')
|
| 348 |
+
else:
|
| 349 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids
|
| 350 |
+
device = torch.device('cuda')
|
| 351 |
+
|
| 352 |
+
# Fix the seed for reproducibility
|
| 353 |
+
seed = args.seed
|
| 354 |
+
torch.manual_seed(seed)
|
| 355 |
+
np.random.seed(seed)
|
| 356 |
+
cudnn.benchmark = True
|
| 357 |
+
|
| 358 |
+
dataset_train = TrainDataset(is_train=True, args=args)
|
| 359 |
+
|
| 360 |
+
if args.disable_eval:
|
| 361 |
+
dataset_val = None
|
| 362 |
+
else:
|
| 363 |
+
dataset_val = TrainDataset(is_train=False, args=args)
|
| 364 |
+
|
| 365 |
+
sampler_train = torch.utils.data.RandomSampler(dataset_train)
|
| 366 |
+
|
| 367 |
+
if dataset_val is not None:
|
| 368 |
+
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
| 369 |
+
else:
|
| 370 |
+
sampler_val = None
|
| 371 |
+
|
| 372 |
+
if args.log_dir is not None:
|
| 373 |
+
os.makedirs(args.log_dir, exist_ok=True)
|
| 374 |
+
log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
|
| 375 |
+
else:
|
| 376 |
+
log_writer = None
|
| 377 |
+
|
| 378 |
+
data_loader_train = torch.utils.data.DataLoader(
|
| 379 |
+
dataset_train,
|
| 380 |
+
sampler=sampler_train,
|
| 381 |
+
batch_size=args.batch_size,
|
| 382 |
+
num_workers=args.num_workers,
|
| 383 |
+
pin_memory=args.pin_mem,
|
| 384 |
+
drop_last=True,
|
| 385 |
+
)
|
| 386 |
+
if dataset_val is not None:
|
| 387 |
+
data_loader_val = torch.utils.data.DataLoader(
|
| 388 |
+
dataset_val,
|
| 389 |
+
sampler=sampler_val,
|
| 390 |
+
batch_size=args.batch_size,
|
| 391 |
+
num_workers=args.num_workers,
|
| 392 |
+
pin_memory=args.pin_mem,
|
| 393 |
+
drop_last=False,
|
| 394 |
+
)
|
| 395 |
+
else:
|
| 396 |
+
data_loader_val = None
|
| 397 |
+
|
| 398 |
+
mixup_fn = None
|
| 399 |
+
mixup_active = args.mixup > 0 or args.cutmix > 0.0 or args.cutmix_minmax is not None
|
| 400 |
+
if mixup_active:
|
| 401 |
+
print('Mixup is activated!')
|
| 402 |
+
mixup_fn = Mixup(
|
| 403 |
+
mixup_alpha=args.mixup,
|
| 404 |
+
cutmix_alpha=args.cutmix,
|
| 405 |
+
cutmix_minmax=args.cutmix_minmax,
|
| 406 |
+
prob=args.mixup_prob,
|
| 407 |
+
switch_prob=args.mixup_switch_prob,
|
| 408 |
+
mode=args.mixup_mode,
|
| 409 |
+
label_smoothing=args.smoothing,
|
| 410 |
+
num_classes=args.nb_classes,
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
model = AIDE.__dict__[args.model](
|
| 414 |
+
resnet_path=args.resnet_path, convnext_path=args.convnext_path
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
model.to(device)
|
| 418 |
+
|
| 419 |
+
model_ema = None
|
| 420 |
+
if args.model_ema:
|
| 421 |
+
model_ema = ModelEma(
|
| 422 |
+
model,
|
| 423 |
+
decay=args.model_ema_decay,
|
| 424 |
+
device='cpu' if args.model_ema_force_cpu else '',
|
| 425 |
+
resume='',
|
| 426 |
+
)
|
| 427 |
+
print('Using EMA with decay = %.8f' % args.model_ema_decay)
|
| 428 |
+
|
| 429 |
+
model_without_ddp = model
|
| 430 |
+
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 431 |
+
|
| 432 |
+
print('Model = %s' % str(model_without_ddp))
|
| 433 |
+
print('number of params:', n_parameters)
|
| 434 |
+
|
| 435 |
+
eff_batch_size = args.batch_size * args.update_freq
|
| 436 |
+
num_training_steps_per_epoch = len(dataset_train) // eff_batch_size
|
| 437 |
+
|
| 438 |
+
if args.lr is None:
|
| 439 |
+
args.lr = args.blr * eff_batch_size / 256
|
| 440 |
+
|
| 441 |
+
print('base lr: %.2e' % (args.lr * 256 / eff_batch_size))
|
| 442 |
+
print('actual lr: %.2e' % args.lr)
|
| 443 |
+
|
| 444 |
+
print('accumulate grad iterations: %d' % args.update_freq)
|
| 445 |
+
print('effective batch size: %d' % eff_batch_size)
|
| 446 |
+
|
| 447 |
+
assigner = None
|
| 448 |
+
|
| 449 |
+
optimizer = create_optimizer(
|
| 450 |
+
args,
|
| 451 |
+
model_without_ddp,
|
| 452 |
+
skip_list=None,
|
| 453 |
+
get_num_layer=assigner.get_layer_id if assigner is not None else None,
|
| 454 |
+
get_layer_scale=assigner.get_scale if assigner is not None else None,
|
| 455 |
+
)
|
| 456 |
+
loss_scaler = NativeScaler()
|
| 457 |
+
|
| 458 |
+
if mixup_fn is not None:
|
| 459 |
+
criterion = SoftTargetCrossEntropy()
|
| 460 |
+
elif args.smoothing > 0.0:
|
| 461 |
+
criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
|
| 462 |
+
else:
|
| 463 |
+
criterion = torch.nn.CrossEntropyLoss()
|
| 464 |
+
|
| 465 |
+
print('criterion = %s' % str(criterion))
|
| 466 |
+
|
| 467 |
+
utils.auto_load_model(
|
| 468 |
+
args=args,
|
| 469 |
+
model=model,
|
| 470 |
+
model_without_ddp=model_without_ddp,
|
| 471 |
+
optimizer=optimizer,
|
| 472 |
+
loss_scaler=loss_scaler,
|
| 473 |
+
model_ema=model_ema,
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
if args.eval:
|
| 477 |
+
print('Eval only mode')
|
| 478 |
+
|
| 479 |
+
vals = os.listdir(args.eval_data_path)
|
| 480 |
+
eval_data_path = args.eval_data_path
|
| 481 |
+
|
| 482 |
+
rows = [
|
| 483 |
+
['{} model testing on...'.format(args.resume)],
|
| 484 |
+
['testset', 'accuracy', 'avg precision', 'r_acc', 'f_acc'],
|
| 485 |
+
]
|
| 486 |
+
|
| 487 |
+
for v_id, val in enumerate(vals):
|
| 488 |
+
args.eval_data_path = os.path.join(args.eval_data_path, val)
|
| 489 |
+
dataset_val = TestDataset(is_train=False, args=args)
|
| 490 |
+
args.eval_data_path = eval_data_path
|
| 491 |
+
|
| 492 |
+
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
| 493 |
+
|
| 494 |
+
data_loader_val = torch.utils.data.DataLoader(
|
| 495 |
+
dataset_val,
|
| 496 |
+
sampler=sampler_val,
|
| 497 |
+
batch_size=args.batch_size,
|
| 498 |
+
num_workers=args.num_workers,
|
| 499 |
+
pin_memory=args.pin_mem,
|
| 500 |
+
drop_last=False,
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
test_stats, acc, ap, r_acc, f_acc = evaluate(data_loader_val, model, device)
|
| 504 |
+
print(
|
| 505 |
+
f'Accuracy of the network on {len(dataset_val)} test images: {test_stats["acc1"]:.5f}%'
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
print(
|
| 509 |
+
f'test dataset is {val} acc: {acc}, ap: {ap}, r_acc: {r_acc}, f_acc: {f_acc}'
|
| 510 |
+
)
|
| 511 |
+
print('***********************************')
|
| 512 |
+
|
| 513 |
+
rows.append([val, acc, ap, r_acc, f_acc])
|
| 514 |
+
|
| 515 |
+
def calculate_column_means(rows):
|
| 516 |
+
if not rows or len(rows[0]) < 2:
|
| 517 |
+
raise ValueError(
|
| 518 |
+
'The input rows list is empty or lacks numeric columns.'
|
| 519 |
+
)
|
| 520 |
+
num_columns = len(rows[0]) - 1
|
| 521 |
+
means = ['mean'] + [
|
| 522 |
+
sum(row[i] for row in rows) / len(rows)
|
| 523 |
+
for i in range(1, num_columns + 1)
|
| 524 |
+
]
|
| 525 |
+
return means
|
| 526 |
+
|
| 527 |
+
rows.append(calculate_column_means(rows[2:]))
|
| 528 |
+
test_dataset_name = args.eval_data_path.split('/')[-2]
|
| 529 |
+
|
| 530 |
+
csv_name = os.path.join(
|
| 531 |
+
args.output_dir, f'{os.path.basename(args.resume)}_{test_dataset_name}.csv'
|
| 532 |
+
)
|
| 533 |
+
with open(csv_name, 'w') as f:
|
| 534 |
+
csv_writer = csv.writer(f, delimiter=',')
|
| 535 |
+
csv_writer.writerows(rows)
|
| 536 |
+
return
|
| 537 |
+
|
| 538 |
+
max_accuracy = 0.0
|
| 539 |
+
if args.model_ema and args.model_ema_eval:
|
| 540 |
+
max_accuracy_ema = 0.0
|
| 541 |
+
|
| 542 |
+
print('Start training for %d epochs' % args.epochs)
|
| 543 |
+
start_time = time.time()
|
| 544 |
+
for epoch in range(args.start_epoch, args.epochs):
|
| 545 |
+
if log_writer is not None:
|
| 546 |
+
log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
|
| 547 |
+
train_stats = train_one_epoch(
|
| 548 |
+
model,
|
| 549 |
+
criterion,
|
| 550 |
+
data_loader_train,
|
| 551 |
+
optimizer,
|
| 552 |
+
device,
|
| 553 |
+
epoch,
|
| 554 |
+
loss_scaler,
|
| 555 |
+
args.clip_grad,
|
| 556 |
+
model_ema,
|
| 557 |
+
mixup_fn,
|
| 558 |
+
log_writer=log_writer,
|
| 559 |
+
args=args,
|
| 560 |
+
)
|
| 561 |
+
if args.output_dir and args.save_ckpt:
|
| 562 |
+
if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
|
| 563 |
+
utils.save_model(
|
| 564 |
+
args=args,
|
| 565 |
+
model=model,
|
| 566 |
+
model_without_ddp=model_without_ddp,
|
| 567 |
+
optimizer=optimizer,
|
| 568 |
+
loss_scaler=loss_scaler,
|
| 569 |
+
epoch=epoch,
|
| 570 |
+
model_ema=model_ema,
|
| 571 |
+
)
|
| 572 |
+
if data_loader_val is not None:
|
| 573 |
+
test_stats, acc, ap, r_acc, f_acc = evaluate(
|
| 574 |
+
data_loader_val, model, device, use_amp=args.use_amp
|
| 575 |
+
)
|
| 576 |
+
print(
|
| 577 |
+
f'Accuracy of the model on the {len(dataset_val)} test images: {test_stats["acc1"]:.1f}%, ap: {ap}.'
|
| 578 |
+
)
|
| 579 |
+
if max_accuracy < test_stats['acc1']:
|
| 580 |
+
max_accuracy = test_stats['acc1']
|
| 581 |
+
if args.output_dir and args.save_ckpt:
|
| 582 |
+
utils.save_model(
|
| 583 |
+
args=args,
|
| 584 |
+
model=model,
|
| 585 |
+
model_without_ddp=model_without_ddp,
|
| 586 |
+
optimizer=optimizer,
|
| 587 |
+
loss_scaler=loss_scaler,
|
| 588 |
+
epoch='best',
|
| 589 |
+
model_ema=model_ema,
|
| 590 |
+
)
|
| 591 |
+
print(f'Max accuracy: {max_accuracy:.2f}%')
|
| 592 |
+
|
| 593 |
+
if log_writer is not None:
|
| 594 |
+
log_writer.update(test_acc1=test_stats['acc1'], head='perf', step=epoch)
|
| 595 |
+
log_writer.update(test_acc5=test_stats['acc5'], head='perf', step=epoch)
|
| 596 |
+
log_writer.update(test_loss=test_stats['loss'], head='perf', step=epoch)
|
| 597 |
+
|
| 598 |
+
log_stats = {
|
| 599 |
+
**{f'train_{k}': v for k, v in train_stats.items()},
|
| 600 |
+
**{f'test_{k}': v for k, v in test_stats.items()},
|
| 601 |
+
'epoch': epoch,
|
| 602 |
+
'n_parameters': n_parameters,
|
| 603 |
+
}
|
| 604 |
+
|
| 605 |
+
if args.model_ema and args.model_ema_eval:
|
| 606 |
+
test_stats_ema, acc, ap, r_acc, f_acc = evaluate(
|
| 607 |
+
data_loader_val, model_ema.ema, device, use_amp=args.use_amp
|
| 608 |
+
)
|
| 609 |
+
print(
|
| 610 |
+
f'Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema["acc1"]:.1f}%, ap: {ap}'
|
| 611 |
+
)
|
| 612 |
+
if max_accuracy_ema < test_stats_ema['acc1']:
|
| 613 |
+
max_accuracy_ema = test_stats_ema['acc1']
|
| 614 |
+
if args.output_dir and args.save_ckpt:
|
| 615 |
+
utils.save_model(
|
| 616 |
+
args=args,
|
| 617 |
+
model=model,
|
| 618 |
+
model_without_ddp=model_without_ddp,
|
| 619 |
+
optimizer=optimizer,
|
| 620 |
+
loss_scaler=loss_scaler,
|
| 621 |
+
epoch='best-ema',
|
| 622 |
+
model_ema=model_ema,
|
| 623 |
+
)
|
| 624 |
+
print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%')
|
| 625 |
+
if log_writer is not None:
|
| 626 |
+
log_writer.update(
|
| 627 |
+
test_acc1_ema=test_stats_ema['acc1'], head='perf', step=epoch
|
| 628 |
+
)
|
| 629 |
+
log_stats.update(
|
| 630 |
+
{**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}}
|
| 631 |
+
)
|
| 632 |
+
else:
|
| 633 |
+
log_stats = {
|
| 634 |
+
**{f'train_{k}': v for k, v in train_stats.items()},
|
| 635 |
+
'epoch': epoch,
|
| 636 |
+
'n_parameters': n_parameters,
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
if args.output_dir:
|
| 640 |
+
if log_writer is not None:
|
| 641 |
+
log_writer.flush()
|
| 642 |
+
with open(
|
| 643 |
+
os.path.join(args.output_dir, 'log.txt'), mode='a', encoding='utf-8'
|
| 644 |
+
) as f:
|
| 645 |
+
f.write(json.dumps(log_stats) + '\n')
|
| 646 |
+
|
| 647 |
+
total_time = time.time() - start_time
|
| 648 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 649 |
+
print('Training time {}'.format(total_time_str))
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
if __name__ == '__main__':
|
| 653 |
+
parser = argparse.ArgumentParser('AIDE training', parents=[get_args_parser()])
|
| 654 |
+
args = parser.parse_args()
|
| 655 |
+
if args.output_dir:
|
| 656 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
| 657 |
+
main(args)
|
detector_codes/AIDE-main/models/AIDE.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.utils.model_zoo as model_zoo
|
| 3 |
+
import torch
|
| 4 |
+
import clip
|
| 5 |
+
import open_clip
|
| 6 |
+
from .srm_filter_kernel import all_normalized_hpf_list
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
class HPF(nn.Module):
|
| 10 |
+
def __init__(self):
|
| 11 |
+
super(HPF, self).__init__()
|
| 12 |
+
|
| 13 |
+
#Load 30 SRM Filters
|
| 14 |
+
all_hpf_list_5x5 = []
|
| 15 |
+
|
| 16 |
+
for hpf_item in all_normalized_hpf_list:
|
| 17 |
+
if hpf_item.shape[0] == 3:
|
| 18 |
+
hpf_item = np.pad(hpf_item, pad_width=((1, 1), (1, 1)), mode='constant')
|
| 19 |
+
|
| 20 |
+
all_hpf_list_5x5.append(hpf_item)
|
| 21 |
+
|
| 22 |
+
hpf_weight = torch.Tensor(all_hpf_list_5x5).view(30, 1, 5, 5).contiguous()
|
| 23 |
+
hpf_weight = torch.nn.Parameter(hpf_weight.repeat(1, 3, 1, 1), requires_grad=False)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
self.hpf = nn.Conv2d(3, 30, kernel_size=5, padding=2, bias=False)
|
| 27 |
+
self.hpf.weight = hpf_weight
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def forward(self, input):
|
| 31 |
+
|
| 32 |
+
output = self.hpf(input)
|
| 33 |
+
|
| 34 |
+
return output
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
| 39 |
+
"""3x3 convolution with padding"""
|
| 40 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
| 41 |
+
padding=1, bias=False)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
| 45 |
+
"""1x1 convolution"""
|
| 46 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class BasicBlock(nn.Module):
|
| 50 |
+
expansion = 1
|
| 51 |
+
|
| 52 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 53 |
+
super(BasicBlock, self).__init__()
|
| 54 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
| 55 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 56 |
+
self.relu = nn.ReLU(inplace=True)
|
| 57 |
+
self.conv2 = conv3x3(planes, planes)
|
| 58 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 59 |
+
self.downsample = downsample
|
| 60 |
+
self.stride = stride
|
| 61 |
+
|
| 62 |
+
def forward(self, x):
|
| 63 |
+
identity = x
|
| 64 |
+
|
| 65 |
+
out = self.conv1(x)
|
| 66 |
+
out = self.bn1(out)
|
| 67 |
+
out = self.relu(out)
|
| 68 |
+
|
| 69 |
+
out = self.conv2(out)
|
| 70 |
+
out = self.bn2(out)
|
| 71 |
+
|
| 72 |
+
if self.downsample is not None:
|
| 73 |
+
identity = self.downsample(x)
|
| 74 |
+
|
| 75 |
+
out += identity
|
| 76 |
+
out = self.relu(out)
|
| 77 |
+
|
| 78 |
+
return out
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class Bottleneck(nn.Module):
|
| 82 |
+
expansion = 4
|
| 83 |
+
|
| 84 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 85 |
+
super(Bottleneck, self).__init__()
|
| 86 |
+
self.conv1 = conv1x1(inplanes, planes)
|
| 87 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 88 |
+
self.conv2 = conv3x3(planes, planes, stride)
|
| 89 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 90 |
+
self.conv3 = conv1x1(planes, planes * self.expansion)
|
| 91 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
| 92 |
+
self.relu = nn.ReLU(inplace=True)
|
| 93 |
+
self.downsample = downsample
|
| 94 |
+
self.stride = stride
|
| 95 |
+
|
| 96 |
+
def forward(self, x):
|
| 97 |
+
identity = x
|
| 98 |
+
|
| 99 |
+
out = self.conv1(x)
|
| 100 |
+
out = self.bn1(out)
|
| 101 |
+
out = self.relu(out)
|
| 102 |
+
|
| 103 |
+
out = self.conv2(out)
|
| 104 |
+
out = self.bn2(out)
|
| 105 |
+
out = self.relu(out)
|
| 106 |
+
|
| 107 |
+
out = self.conv3(out)
|
| 108 |
+
out = self.bn3(out)
|
| 109 |
+
|
| 110 |
+
if self.downsample is not None:
|
| 111 |
+
identity = self.downsample(x)
|
| 112 |
+
|
| 113 |
+
out += identity
|
| 114 |
+
out = self.relu(out)
|
| 115 |
+
|
| 116 |
+
return out
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class ResNet(nn.Module):
|
| 120 |
+
|
| 121 |
+
def __init__(self, block, layers, num_classes=1000, zero_init_residual=True):
|
| 122 |
+
super(ResNet, self).__init__()
|
| 123 |
+
|
| 124 |
+
self.inplanes = 64
|
| 125 |
+
self.conv1 = nn.Conv2d(30, 64, kernel_size=7, stride=2, padding=3,
|
| 126 |
+
bias=False)
|
| 127 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 128 |
+
self.relu = nn.ReLU(inplace=True)
|
| 129 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 130 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
| 131 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
| 132 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
| 133 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
| 134 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
| 135 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
| 136 |
+
|
| 137 |
+
for m in self.modules():
|
| 138 |
+
if isinstance(m, nn.Conv2d):
|
| 139 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 140 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 141 |
+
nn.init.constant_(m.weight, 1)
|
| 142 |
+
nn.init.constant_(m.bias, 0)
|
| 143 |
+
|
| 144 |
+
# Zero-initialize the last BN in each residual branch,
|
| 145 |
+
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
| 146 |
+
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
| 147 |
+
if zero_init_residual:
|
| 148 |
+
for m in self.modules():
|
| 149 |
+
if isinstance(m, Bottleneck):
|
| 150 |
+
nn.init.constant_(m.bn3.weight, 0)
|
| 151 |
+
elif isinstance(m, BasicBlock):
|
| 152 |
+
nn.init.constant_(m.bn2.weight, 0)
|
| 153 |
+
|
| 154 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
| 155 |
+
downsample = None
|
| 156 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 157 |
+
downsample = nn.Sequential(
|
| 158 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
| 159 |
+
nn.BatchNorm2d(planes * block.expansion),
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
layers = []
|
| 163 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
| 164 |
+
self.inplanes = planes * block.expansion
|
| 165 |
+
for _ in range(1, blocks):
|
| 166 |
+
layers.append(block(self.inplanes, planes))
|
| 167 |
+
|
| 168 |
+
return nn.Sequential(*layers)
|
| 169 |
+
|
| 170 |
+
def forward(self, x):
|
| 171 |
+
|
| 172 |
+
x = self.conv1(x)
|
| 173 |
+
x = self.bn1(x)
|
| 174 |
+
x = self.relu(x)
|
| 175 |
+
x = self.maxpool(x)
|
| 176 |
+
|
| 177 |
+
x = self.layer1(x)
|
| 178 |
+
x = self.layer2(x)
|
| 179 |
+
x = self.layer3(x)
|
| 180 |
+
x = self.layer4(x)
|
| 181 |
+
|
| 182 |
+
x = self.avgpool(x)
|
| 183 |
+
x = x.view(x.size(0), -1)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
return x
|
| 187 |
+
|
| 188 |
+
class Mlp(nn.Module):
|
| 189 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
|
| 193 |
+
super().__init__()
|
| 194 |
+
out_features = out_features or in_features
|
| 195 |
+
hidden_features = hidden_features or in_features
|
| 196 |
+
|
| 197 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 198 |
+
self.act = act_layer()
|
| 199 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 200 |
+
|
| 201 |
+
def forward(self, x):
|
| 202 |
+
x = self.fc1(x)
|
| 203 |
+
x = self.act(x)
|
| 204 |
+
x = self.fc2(x)
|
| 205 |
+
return x
|
| 206 |
+
|
| 207 |
+
class AIDE_Model(nn.Module):
|
| 208 |
+
|
| 209 |
+
def __init__(self, resnet_path, convnext_path):
|
| 210 |
+
super(AIDE_Model, self).__init__()
|
| 211 |
+
self.hpf = HPF()
|
| 212 |
+
self.model_min = ResNet(Bottleneck, [3, 4, 6, 3])
|
| 213 |
+
self.model_max = ResNet(Bottleneck, [3, 4, 6, 3])
|
| 214 |
+
|
| 215 |
+
if resnet_path is not None:
|
| 216 |
+
pretrained_dict = torch.load(resnet_path, map_location='cpu')
|
| 217 |
+
|
| 218 |
+
model_min_dict = self.model_min.state_dict()
|
| 219 |
+
model_max_dict = self.model_max.state_dict()
|
| 220 |
+
|
| 221 |
+
for k in pretrained_dict.keys():
|
| 222 |
+
if k in model_min_dict and pretrained_dict[k].size() == model_min_dict[k].size():
|
| 223 |
+
model_min_dict[k] = pretrained_dict[k]
|
| 224 |
+
model_max_dict[k] = pretrained_dict[k]
|
| 225 |
+
else:
|
| 226 |
+
print(f"Skipping layer {k} because of size mismatch")
|
| 227 |
+
|
| 228 |
+
self.fc = Mlp(2048 + 256 , 1024, 2)
|
| 229 |
+
|
| 230 |
+
print("build model with convnext_xxl")
|
| 231 |
+
self.openclip_convnext_xxl, _, _ = open_clip.create_model_and_transforms(
|
| 232 |
+
"convnext_xxlarge", pretrained=convnext_path
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
self.openclip_convnext_xxl = self.openclip_convnext_xxl.visual.trunk
|
| 236 |
+
self.openclip_convnext_xxl.head.global_pool = nn.Identity()
|
| 237 |
+
self.openclip_convnext_xxl.head.flatten = nn.Identity()
|
| 238 |
+
|
| 239 |
+
self.openclip_convnext_xxl.eval()
|
| 240 |
+
|
| 241 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
| 242 |
+
self.convnext_proj = nn.Sequential(
|
| 243 |
+
nn.Linear(3072, 256),
|
| 244 |
+
|
| 245 |
+
)
|
| 246 |
+
for param in self.openclip_convnext_xxl.parameters():
|
| 247 |
+
param.requires_grad = False
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def forward(self, x):
|
| 252 |
+
|
| 253 |
+
b, t, c, h, w = x.shape
|
| 254 |
+
|
| 255 |
+
x_minmin = x[:, 0] #[b, c, h, w]
|
| 256 |
+
x_maxmax = x[:, 1]
|
| 257 |
+
x_minmin1 = x[:, 2]
|
| 258 |
+
x_maxmax1 = x[:, 3]
|
| 259 |
+
tokens = x[:, 4]
|
| 260 |
+
|
| 261 |
+
x_minmin = self.hpf(x_minmin)
|
| 262 |
+
x_maxmax = self.hpf(x_maxmax)
|
| 263 |
+
x_minmin1 = self.hpf(x_minmin1)
|
| 264 |
+
x_maxmax1 = self.hpf(x_maxmax1)
|
| 265 |
+
|
| 266 |
+
with torch.no_grad():
|
| 267 |
+
|
| 268 |
+
clip_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073])
|
| 269 |
+
clip_mean = clip_mean.to(tokens, non_blocking=True).view(3, 1, 1)
|
| 270 |
+
clip_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711])
|
| 271 |
+
clip_std = clip_std.to(tokens, non_blocking=True).view(3, 1, 1)
|
| 272 |
+
dinov2_mean = torch.Tensor([0.485, 0.456, 0.406]).to(tokens, non_blocking=True).view(3, 1, 1)
|
| 273 |
+
dinov2_std = torch.Tensor([0.229, 0.224, 0.225]).to(tokens, non_blocking=True).view(3, 1, 1)
|
| 274 |
+
|
| 275 |
+
local_convnext_image_feats = self.openclip_convnext_xxl(
|
| 276 |
+
tokens * (dinov2_std / clip_std) + (dinov2_mean - clip_mean) / clip_std
|
| 277 |
+
) #[b, 3072, 8, 8]
|
| 278 |
+
assert local_convnext_image_feats.size()[1:] == (3072, 8, 8)
|
| 279 |
+
local_convnext_image_feats = self.avgpool(local_convnext_image_feats).view(tokens.size(0), -1)
|
| 280 |
+
x_0 = self.convnext_proj(local_convnext_image_feats)
|
| 281 |
+
|
| 282 |
+
x_min = self.model_min(x_minmin)
|
| 283 |
+
x_max = self.model_max(x_maxmax)
|
| 284 |
+
x_min1 = self.model_min(x_minmin1)
|
| 285 |
+
x_max1 = self.model_max(x_maxmax1)
|
| 286 |
+
|
| 287 |
+
x_1 = (x_min + x_max + x_min1 + x_max1) / 4
|
| 288 |
+
|
| 289 |
+
x = torch.cat([x_0, x_1], dim=1)
|
| 290 |
+
|
| 291 |
+
x = self.fc(x)
|
| 292 |
+
|
| 293 |
+
return x
|
| 294 |
+
|
| 295 |
+
def AIDE(resnet_path, convnext_path):
|
| 296 |
+
model = AIDE_Model(resnet_path, convnext_path)
|
| 297 |
+
return model
|
| 298 |
+
|
detector_codes/AIDE-main/models/__init__.py
ADDED
|
File without changes
|
detector_codes/AIDE-main/models/clip/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .clip import *
|
detector_codes/AIDE-main/models/clip/bpe_simple_vocab_16e6.txt.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
| 3 |
+
size 1356917
|
detector_codes/AIDE-main/models/clip/clip.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import os
|
| 3 |
+
import urllib
|
| 4 |
+
import warnings
|
| 5 |
+
from typing import Any, Union, List
|
| 6 |
+
from pkg_resources import packaging
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from .model import build_model
|
| 14 |
+
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from torchvision.transforms import InterpolationMode
|
| 18 |
+
BICUBIC = InterpolationMode.BICUBIC
|
| 19 |
+
except ImportError:
|
| 20 |
+
BICUBIC = Image.BICUBIC
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
|
| 24 |
+
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
__all__ = ["available_models", "load", "tokenize"]
|
| 28 |
+
_tokenizer = _Tokenizer()
|
| 29 |
+
|
| 30 |
+
_MODELS = {
|
| 31 |
+
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
| 32 |
+
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
| 33 |
+
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
| 34 |
+
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
| 35 |
+
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
|
| 36 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
| 37 |
+
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
| 38 |
+
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
| 39 |
+
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _download(url: str, root: str):
|
| 44 |
+
os.makedirs(root, exist_ok=True)
|
| 45 |
+
filename = os.path.basename(url)
|
| 46 |
+
|
| 47 |
+
expected_sha256 = url.split("/")[-2]
|
| 48 |
+
download_target = os.path.join(root, filename)
|
| 49 |
+
|
| 50 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
| 51 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
| 52 |
+
|
| 53 |
+
if os.path.isfile(download_target):
|
| 54 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
| 55 |
+
return download_target
|
| 56 |
+
else:
|
| 57 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
| 58 |
+
|
| 59 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
| 60 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
| 61 |
+
while True:
|
| 62 |
+
buffer = source.read(8192)
|
| 63 |
+
if not buffer:
|
| 64 |
+
break
|
| 65 |
+
|
| 66 |
+
output.write(buffer)
|
| 67 |
+
loop.update(len(buffer))
|
| 68 |
+
|
| 69 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
| 70 |
+
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
|
| 71 |
+
|
| 72 |
+
return download_target
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _convert_image_to_rgb(image):
|
| 76 |
+
return image.convert("RGB")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _transform(n_px):
|
| 80 |
+
return Compose([
|
| 81 |
+
Resize(n_px, interpolation=BICUBIC),
|
| 82 |
+
CenterCrop(n_px),
|
| 83 |
+
_convert_image_to_rgb,
|
| 84 |
+
ToTensor(),
|
| 85 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
| 86 |
+
])
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def available_models() -> List[str]:
|
| 90 |
+
"""Returns the names of available CLIP models"""
|
| 91 |
+
return list(_MODELS.keys())
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
| 95 |
+
"""Load a CLIP model
|
| 96 |
+
|
| 97 |
+
Parameters
|
| 98 |
+
----------
|
| 99 |
+
name : str
|
| 100 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
| 101 |
+
|
| 102 |
+
device : Union[str, torch.device]
|
| 103 |
+
The device to put the loaded model
|
| 104 |
+
|
| 105 |
+
jit : bool
|
| 106 |
+
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
| 107 |
+
|
| 108 |
+
download_root: str
|
| 109 |
+
path to download the model files; by default, it uses "~/.cache/clip"
|
| 110 |
+
|
| 111 |
+
Returns
|
| 112 |
+
-------
|
| 113 |
+
model : torch.nn.Module
|
| 114 |
+
The CLIP model
|
| 115 |
+
|
| 116 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
| 117 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
| 118 |
+
"""
|
| 119 |
+
if name in _MODELS:
|
| 120 |
+
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
| 121 |
+
elif os.path.isfile(name):
|
| 122 |
+
model_path = name
|
| 123 |
+
else:
|
| 124 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
| 125 |
+
|
| 126 |
+
with open(model_path, 'rb') as opened_file:
|
| 127 |
+
try:
|
| 128 |
+
# loading JIT archive
|
| 129 |
+
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
|
| 130 |
+
state_dict = None
|
| 131 |
+
except RuntimeError:
|
| 132 |
+
# loading saved state dict
|
| 133 |
+
if jit:
|
| 134 |
+
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
| 135 |
+
jit = False
|
| 136 |
+
state_dict = torch.load(opened_file, map_location="cpu")
|
| 137 |
+
|
| 138 |
+
if not jit:
|
| 139 |
+
model = build_model(state_dict or model.state_dict()).to(device)
|
| 140 |
+
if str(device) == "cpu":
|
| 141 |
+
model.float()
|
| 142 |
+
return model, _transform(model.visual.input_resolution)
|
| 143 |
+
|
| 144 |
+
# patch the device names
|
| 145 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
| 146 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
| 147 |
+
|
| 148 |
+
def patch_device(module):
|
| 149 |
+
try:
|
| 150 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
| 151 |
+
except RuntimeError:
|
| 152 |
+
graphs = []
|
| 153 |
+
|
| 154 |
+
if hasattr(module, "forward1"):
|
| 155 |
+
graphs.append(module.forward1.graph)
|
| 156 |
+
|
| 157 |
+
for graph in graphs:
|
| 158 |
+
for node in graph.findAllNodes("prim::Constant"):
|
| 159 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
| 160 |
+
node.copyAttributes(device_node)
|
| 161 |
+
|
| 162 |
+
model.apply(patch_device)
|
| 163 |
+
patch_device(model.encode_image)
|
| 164 |
+
patch_device(model.encode_text)
|
| 165 |
+
|
| 166 |
+
# patch dtype to float32 on CPU
|
| 167 |
+
if str(device) == "cpu":
|
| 168 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
| 169 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
| 170 |
+
float_node = float_input.node()
|
| 171 |
+
|
| 172 |
+
def patch_float(module):
|
| 173 |
+
try:
|
| 174 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
| 175 |
+
except RuntimeError:
|
| 176 |
+
graphs = []
|
| 177 |
+
|
| 178 |
+
if hasattr(module, "forward1"):
|
| 179 |
+
graphs.append(module.forward1.graph)
|
| 180 |
+
|
| 181 |
+
for graph in graphs:
|
| 182 |
+
for node in graph.findAllNodes("aten::to"):
|
| 183 |
+
inputs = list(node.inputs())
|
| 184 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
| 185 |
+
if inputs[i].node()["value"] == 5:
|
| 186 |
+
inputs[i].node().copyAttributes(float_node)
|
| 187 |
+
|
| 188 |
+
model.apply(patch_float)
|
| 189 |
+
patch_float(model.encode_image)
|
| 190 |
+
patch_float(model.encode_text)
|
| 191 |
+
|
| 192 |
+
model.float()
|
| 193 |
+
|
| 194 |
+
return model, _transform(model.input_resolution.item())
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
|
| 198 |
+
"""
|
| 199 |
+
Returns the tokenized representation of given input string(s)
|
| 200 |
+
|
| 201 |
+
Parameters
|
| 202 |
+
----------
|
| 203 |
+
texts : Union[str, List[str]]
|
| 204 |
+
An input string or a list of input strings to tokenize
|
| 205 |
+
|
| 206 |
+
context_length : int
|
| 207 |
+
The context length to use; all CLIP models use 77 as the context length
|
| 208 |
+
|
| 209 |
+
truncate: bool
|
| 210 |
+
Whether to truncate the text in case its encoding is longer than the context length
|
| 211 |
+
|
| 212 |
+
Returns
|
| 213 |
+
-------
|
| 214 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
|
| 215 |
+
We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
|
| 216 |
+
"""
|
| 217 |
+
if isinstance(texts, str):
|
| 218 |
+
texts = [texts]
|
| 219 |
+
|
| 220 |
+
sot_token = _tokenizer.encoder["<|startoftext|>"]
|
| 221 |
+
eot_token = _tokenizer.encoder["<|endoftext|>"]
|
| 222 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
| 223 |
+
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
|
| 224 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
| 225 |
+
else:
|
| 226 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
|
| 227 |
+
|
| 228 |
+
for i, tokens in enumerate(all_tokens):
|
| 229 |
+
if len(tokens) > context_length:
|
| 230 |
+
if truncate:
|
| 231 |
+
tokens = tokens[:context_length]
|
| 232 |
+
tokens[-1] = eot_token
|
| 233 |
+
else:
|
| 234 |
+
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
| 235 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
| 236 |
+
|
| 237 |
+
return result
|
detector_codes/AIDE-main/models/clip/lora_clip.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from models.clip import clip
|
| 5 |
+
import math
|
| 6 |
+
import copy
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
import numpy as np
|
| 9 |
+
from torch.utils.data import Dataset, DataLoader
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LoRALayer(nn.Module):
|
| 15 |
+
"""
|
| 16 |
+
LoRA层的实现:添加低秩矩阵以高效微调大型模型
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, in_dim, out_dim, rank=4, alpha=1.0):
|
| 20 |
+
super().__init__()
|
| 21 |
+
# 缩放系数,决定了LoRA部分的权重
|
| 22 |
+
self.scale = alpha / rank
|
| 23 |
+
|
| 24 |
+
# 创建LoRA的低秩矩阵A和B
|
| 25 |
+
self.lora_down = nn.Linear(in_dim, rank, bias=False)
|
| 26 |
+
self.lora_up = nn.Linear(rank, out_dim, bias=False)
|
| 27 |
+
|
| 28 |
+
# 初始化:下投影随机初始化,上投影初始化为0
|
| 29 |
+
nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
| 30 |
+
nn.init.zeros_(self.lora_up.weight)
|
| 31 |
+
|
| 32 |
+
def forward(self, x):
|
| 33 |
+
# 计算LoRA的贡献并返回
|
| 34 |
+
return self.lora_up(self.lora_down(x)) * self.scale
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ---------------------- 2. CLIP的LoRA包装器 ----------------------
|
| 38 |
+
|
| 39 |
+
class CLIPWithLoRA(nn.Module):
|
| 40 |
+
"""
|
| 41 |
+
为CLIP的图像编码器添加LoRA适配层的完整实现
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(self, base_clip_model, lora_rank=8, lora_alpha=16):
|
| 45 |
+
super().__init__()
|
| 46 |
+
# 保存原始的CLIP模型
|
| 47 |
+
self.clip = base_clip_model
|
| 48 |
+
|
| 49 |
+
# 冻结原始CLIP的所有参数
|
| 50 |
+
for param in self.clip.parameters():
|
| 51 |
+
param.requires_grad = False
|
| 52 |
+
|
| 53 |
+
# 分析CLIP的图像编码器结构
|
| 54 |
+
# 对于ViT-L/14,每个Transformer块有三个投影矩阵
|
| 55 |
+
# 我们将为每个投影矩阵添加LoRA层
|
| 56 |
+
self.image_encoder = self.clip.visual
|
| 57 |
+
|
| 58 |
+
# 检查并确定这是一个Vision Transformer结构
|
| 59 |
+
assert hasattr(self.image_encoder, 'transformer'), "不支持的CLIP视觉编码器类型"
|
| 60 |
+
|
| 61 |
+
# 创建LoRA层字典
|
| 62 |
+
self.lora_layers = nn.ModuleDict()
|
| 63 |
+
|
| 64 |
+
# 获取隐藏层维度(对于ViT-L/14,通常是1024)
|
| 65 |
+
if hasattr(self.image_encoder, 'width'):
|
| 66 |
+
hidden_dim = self.image_encoder.width
|
| 67 |
+
else:
|
| 68 |
+
hidden_dim = self.image_encoder.transformer.width
|
| 69 |
+
|
| 70 |
+
# 为每个Transformer块的注意力层添加LoRA
|
| 71 |
+
for block_idx, block in enumerate(self.image_encoder.transformer.resblocks):
|
| 72 |
+
# 为QKV矩阵添加LoRA
|
| 73 |
+
# 注意:根据OpenAI的CLIP实现,通常QKV是在一个单一的投影中
|
| 74 |
+
attn = block.attn
|
| 75 |
+
|
| 76 |
+
# 添加LoRA到Query投影
|
| 77 |
+
self.lora_layers[f"block_{block_idx}_q"] = LoRALayer(
|
| 78 |
+
hidden_dim, hidden_dim,
|
| 79 |
+
rank=lora_rank, alpha=lora_alpha
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# 添加LoRA到Key投影
|
| 83 |
+
self.lora_layers[f"block_{block_idx}_k"] = LoRALayer(
|
| 84 |
+
hidden_dim, hidden_dim,
|
| 85 |
+
rank=lora_rank, alpha=lora_alpha
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# 添加LoRA到Value投影
|
| 89 |
+
self.lora_layers[f"block_{block_idx}_v"] = LoRALayer(
|
| 90 |
+
hidden_dim, hidden_dim,
|
| 91 |
+
rank=lora_rank, alpha=lora_alpha
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# 修改注意力的前向传播方法
|
| 95 |
+
self._patch_attention_forward(block.attn, block_idx)
|
| 96 |
+
|
| 97 |
+
def _patch_attention_forward(self, attn_module, block_idx):
|
| 98 |
+
"""
|
| 99 |
+
通过钩子技术修改注意力层的前向传播,插入LoRA计算
|
| 100 |
+
"""
|
| 101 |
+
# 保存原始的前向传播方法
|
| 102 |
+
original_forward = attn_module.forward
|
| 103 |
+
|
| 104 |
+
# 保存self引用以在闭包中使用
|
| 105 |
+
_self = self
|
| 106 |
+
|
| 107 |
+
# 定义新的前向传播方法
|
| 108 |
+
def lora_forward(self, query, key, value, need_weights=False, **kwargs):
|
| 109 |
+
# 获取模型尺寸
|
| 110 |
+
B, N, C = query.shape
|
| 111 |
+
|
| 112 |
+
# 原始的QKV投影
|
| 113 |
+
original_qkv = self.in_proj_weight
|
| 114 |
+
|
| 115 |
+
# 将输入分成三部分用于q、k、v计算
|
| 116 |
+
qkv = F.linear(query, original_qkv, self.in_proj_bias)
|
| 117 |
+
qkv = qkv.reshape(B, N, 3, C).permute(2, 0, 1, 3) # [3, B, N, C]
|
| 118 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 119 |
+
|
| 120 |
+
# 应用LoRA修改到q、k、v
|
| 121 |
+
q = q + _self.lora_layers[f"block_{block_idx}_q"](query)
|
| 122 |
+
k = k + _self.lora_layers[f"block_{block_idx}_k"](query)
|
| 123 |
+
v = v + _self.lora_layers[f"block_{block_idx}_v"](query)
|
| 124 |
+
|
| 125 |
+
# 继续原始注意力计算
|
| 126 |
+
self.scale = q.size(-1) ** -0.5
|
| 127 |
+
q = q * self.scale
|
| 128 |
+
attn = (q @ k.transpose(-2, -1))
|
| 129 |
+
attn = attn.softmax(dim=-1)
|
| 130 |
+
# attn = self.attn_dropout(attn)
|
| 131 |
+
|
| 132 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 133 |
+
x = self.out_proj(x)
|
| 134 |
+
# x = self.out_dropout(x)
|
| 135 |
+
|
| 136 |
+
if need_weights:
|
| 137 |
+
return x, attn
|
| 138 |
+
return x
|
| 139 |
+
|
| 140 |
+
# 替换原始的前向传播方法
|
| 141 |
+
attn_module.forward = lora_forward.__get__(attn_module, type(attn_module))
|
| 142 |
+
|
| 143 |
+
def forward(self, image):
|
| 144 |
+
"""
|
| 145 |
+
模型的前向传播,使用LoRA修改的CLIP进行图像编码
|
| 146 |
+
"""
|
| 147 |
+
# 直接使用修改后的CLIP模型
|
| 148 |
+
image_features = self.clip.encode_image(image)
|
| 149 |
+
|
| 150 |
+
# 返回归一化的图像特征
|
| 151 |
+
return image_features
|
detector_codes/AIDE-main/models/clip/model.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from typing import Tuple, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Bottleneck(nn.Module):
|
| 11 |
+
expansion = 4
|
| 12 |
+
|
| 13 |
+
def __init__(self, inplanes, planes, stride=1):
|
| 14 |
+
super().__init__()
|
| 15 |
+
|
| 16 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
| 17 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
| 18 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 19 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 20 |
+
|
| 21 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
| 22 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 23 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 24 |
+
|
| 25 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
| 26 |
+
|
| 27 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
| 28 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
| 29 |
+
self.relu3 = nn.ReLU(inplace=True)
|
| 30 |
+
|
| 31 |
+
self.downsample = None
|
| 32 |
+
self.stride = stride
|
| 33 |
+
|
| 34 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
| 35 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
| 36 |
+
self.downsample = nn.Sequential(OrderedDict([
|
| 37 |
+
("-1", nn.AvgPool2d(stride)),
|
| 38 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
| 39 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
| 40 |
+
]))
|
| 41 |
+
|
| 42 |
+
def forward(self, x: torch.Tensor):
|
| 43 |
+
identity = x
|
| 44 |
+
|
| 45 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
| 46 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
| 47 |
+
out = self.avgpool(out)
|
| 48 |
+
out = self.bn3(self.conv3(out))
|
| 49 |
+
|
| 50 |
+
if self.downsample is not None:
|
| 51 |
+
identity = self.downsample(x)
|
| 52 |
+
|
| 53 |
+
out += identity
|
| 54 |
+
out = self.relu3(out)
|
| 55 |
+
return out
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class AttentionPool2d(nn.Module):
|
| 59 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
| 62 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
| 63 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
| 64 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
| 65 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
| 66 |
+
self.num_heads = num_heads
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
|
| 70 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
| 71 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
| 72 |
+
x, _ = F.multi_head_attention_forward(
|
| 73 |
+
query=x[:1], key=x, value=x,
|
| 74 |
+
embed_dim_to_check=x.shape[-1],
|
| 75 |
+
num_heads=self.num_heads,
|
| 76 |
+
q_proj_weight=self.q_proj.weight,
|
| 77 |
+
k_proj_weight=self.k_proj.weight,
|
| 78 |
+
v_proj_weight=self.v_proj.weight,
|
| 79 |
+
in_proj_weight=None,
|
| 80 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
| 81 |
+
bias_k=None,
|
| 82 |
+
bias_v=None,
|
| 83 |
+
add_zero_attn=False,
|
| 84 |
+
dropout_p=0,
|
| 85 |
+
out_proj_weight=self.c_proj.weight,
|
| 86 |
+
out_proj_bias=self.c_proj.bias,
|
| 87 |
+
use_separate_proj_weight=True,
|
| 88 |
+
training=self.training,
|
| 89 |
+
need_weights=False
|
| 90 |
+
)
|
| 91 |
+
return x.squeeze(0)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class ModifiedResNet(nn.Module):
|
| 95 |
+
"""
|
| 96 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
| 97 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
| 98 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
| 99 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
| 103 |
+
super().__init__()
|
| 104 |
+
self.output_dim = output_dim
|
| 105 |
+
self.input_resolution = input_resolution
|
| 106 |
+
|
| 107 |
+
# the 3-layer stem
|
| 108 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
| 109 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
| 110 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 111 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
| 112 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
| 113 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 114 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
| 115 |
+
self.bn3 = nn.BatchNorm2d(width)
|
| 116 |
+
self.relu3 = nn.ReLU(inplace=True)
|
| 117 |
+
self.avgpool = nn.AvgPool2d(2)
|
| 118 |
+
|
| 119 |
+
# residual layers
|
| 120 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
| 121 |
+
self.layer1 = self._make_layer(width, layers[0])
|
| 122 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
| 123 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
| 124 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
| 125 |
+
|
| 126 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
| 127 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
| 128 |
+
|
| 129 |
+
def _make_layer(self, planes, blocks, stride=1):
|
| 130 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
| 131 |
+
|
| 132 |
+
self._inplanes = planes * Bottleneck.expansion
|
| 133 |
+
for _ in range(1, blocks):
|
| 134 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
| 135 |
+
|
| 136 |
+
return nn.Sequential(*layers)
|
| 137 |
+
|
| 138 |
+
def forward(self, x):
|
| 139 |
+
def stem(x):
|
| 140 |
+
x = self.relu1(self.bn1(self.conv1(x)))
|
| 141 |
+
x = self.relu2(self.bn2(self.conv2(x)))
|
| 142 |
+
x = self.relu3(self.bn3(self.conv3(x)))
|
| 143 |
+
x = self.avgpool(x)
|
| 144 |
+
return x
|
| 145 |
+
|
| 146 |
+
x = x.type(self.conv1.weight.dtype)
|
| 147 |
+
x = stem(x)
|
| 148 |
+
x = self.layer1(x)
|
| 149 |
+
x = self.layer2(x)
|
| 150 |
+
x = self.layer3(x)
|
| 151 |
+
x = self.layer4(x)
|
| 152 |
+
x = self.attnpool(x)
|
| 153 |
+
|
| 154 |
+
return x
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class LayerNorm(nn.LayerNorm):
|
| 158 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
| 159 |
+
|
| 160 |
+
def forward(self, x: torch.Tensor):
|
| 161 |
+
orig_type = x.dtype
|
| 162 |
+
ret = super().forward(x.type(torch.float32))
|
| 163 |
+
return ret.type(orig_type)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class QuickGELU(nn.Module):
|
| 167 |
+
def forward(self, x: torch.Tensor):
|
| 168 |
+
return x * torch.sigmoid(1.702 * x)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class ResidualAttentionBlock(nn.Module):
|
| 172 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
| 173 |
+
super().__init__()
|
| 174 |
+
|
| 175 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
| 176 |
+
self.ln_1 = LayerNorm(d_model)
|
| 177 |
+
self.mlp = nn.Sequential(OrderedDict([
|
| 178 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
| 179 |
+
("gelu", QuickGELU()),
|
| 180 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
| 181 |
+
]))
|
| 182 |
+
self.ln_2 = LayerNorm(d_model)
|
| 183 |
+
self.attn_mask = attn_mask
|
| 184 |
+
|
| 185 |
+
def attention(self, x: torch.Tensor):
|
| 186 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
| 187 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
| 188 |
+
|
| 189 |
+
def forward(self, x: torch.Tensor):
|
| 190 |
+
x = x + self.attention(self.ln_1(x))
|
| 191 |
+
x = x + self.mlp(self.ln_2(x))
|
| 192 |
+
return x
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class Transformer(nn.Module):
|
| 196 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
| 197 |
+
super().__init__()
|
| 198 |
+
self.width = width
|
| 199 |
+
self.layers = layers
|
| 200 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
| 201 |
+
|
| 202 |
+
def forward(self, x: torch.Tensor):
|
| 203 |
+
out = {}
|
| 204 |
+
for idx, layer in enumerate(self.resblocks.children()):
|
| 205 |
+
x = layer(x)
|
| 206 |
+
out['layer'+str(idx)] = x[0] # shape:LND. choose cls token feature
|
| 207 |
+
return out, x
|
| 208 |
+
|
| 209 |
+
# return self.resblocks(x) # This is the original code
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class VisionTransformer(nn.Module):
|
| 213 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
| 214 |
+
super().__init__()
|
| 215 |
+
self.input_resolution = input_resolution
|
| 216 |
+
self.output_dim = output_dim
|
| 217 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
| 218 |
+
|
| 219 |
+
scale = width ** -0.5
|
| 220 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
| 221 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
| 222 |
+
self.ln_pre = LayerNorm(width)
|
| 223 |
+
|
| 224 |
+
self.transformer = Transformer(width, layers, heads)
|
| 225 |
+
|
| 226 |
+
self.ln_post = LayerNorm(width)
|
| 227 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def forward(self, x: torch.Tensor):
|
| 232 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
| 233 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
| 234 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
| 235 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
| 236 |
+
x = x + self.positional_embedding.to(x.dtype)
|
| 237 |
+
x = self.ln_pre(x)
|
| 238 |
+
|
| 239 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 240 |
+
out, x = self.transformer(x)
|
| 241 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 242 |
+
|
| 243 |
+
x = self.ln_post(x[:, 0, :])
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
out['before_projection'] = x
|
| 247 |
+
|
| 248 |
+
if self.proj is not None:
|
| 249 |
+
x = x @ self.proj
|
| 250 |
+
out['after_projection'] = x
|
| 251 |
+
|
| 252 |
+
# Return both intermediate features and final clip feature
|
| 253 |
+
# return out
|
| 254 |
+
|
| 255 |
+
# This only returns CLIP features
|
| 256 |
+
return x
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class CLIP(nn.Module):
|
| 260 |
+
def __init__(self,
|
| 261 |
+
embed_dim: int,
|
| 262 |
+
# vision
|
| 263 |
+
image_resolution: int,
|
| 264 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
| 265 |
+
vision_width: int,
|
| 266 |
+
vision_patch_size: int,
|
| 267 |
+
# text
|
| 268 |
+
context_length: int,
|
| 269 |
+
vocab_size: int,
|
| 270 |
+
transformer_width: int,
|
| 271 |
+
transformer_heads: int,
|
| 272 |
+
transformer_layers: int
|
| 273 |
+
):
|
| 274 |
+
super().__init__()
|
| 275 |
+
|
| 276 |
+
self.context_length = context_length
|
| 277 |
+
|
| 278 |
+
if isinstance(vision_layers, (tuple, list)):
|
| 279 |
+
vision_heads = vision_width * 32 // 64
|
| 280 |
+
self.visual = ModifiedResNet(
|
| 281 |
+
layers=vision_layers,
|
| 282 |
+
output_dim=embed_dim,
|
| 283 |
+
heads=vision_heads,
|
| 284 |
+
input_resolution=image_resolution,
|
| 285 |
+
width=vision_width
|
| 286 |
+
)
|
| 287 |
+
else:
|
| 288 |
+
vision_heads = vision_width // 64
|
| 289 |
+
self.visual = VisionTransformer(
|
| 290 |
+
input_resolution=image_resolution,
|
| 291 |
+
patch_size=vision_patch_size,
|
| 292 |
+
width=vision_width,
|
| 293 |
+
layers=vision_layers,
|
| 294 |
+
heads=vision_heads,
|
| 295 |
+
output_dim=embed_dim
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
self.transformer = Transformer(
|
| 299 |
+
width=transformer_width,
|
| 300 |
+
layers=transformer_layers,
|
| 301 |
+
heads=transformer_heads,
|
| 302 |
+
attn_mask=self.build_attention_mask()
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
self.vocab_size = vocab_size
|
| 306 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
| 307 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
| 308 |
+
self.ln_final = LayerNorm(transformer_width)
|
| 309 |
+
|
| 310 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
| 311 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 312 |
+
|
| 313 |
+
self.initialize_parameters()
|
| 314 |
+
|
| 315 |
+
def initialize_parameters(self):
|
| 316 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
| 317 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
| 318 |
+
|
| 319 |
+
if isinstance(self.visual, ModifiedResNet):
|
| 320 |
+
if self.visual.attnpool is not None:
|
| 321 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
| 322 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
| 323 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
| 324 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
| 325 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
| 326 |
+
|
| 327 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
| 328 |
+
for name, param in resnet_block.named_parameters():
|
| 329 |
+
if name.endswith("bn3.weight"):
|
| 330 |
+
nn.init.zeros_(param)
|
| 331 |
+
|
| 332 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
| 333 |
+
attn_std = self.transformer.width ** -0.5
|
| 334 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
| 335 |
+
for block in self.transformer.resblocks:
|
| 336 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
| 337 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
| 338 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
| 339 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
| 340 |
+
|
| 341 |
+
if self.text_projection is not None:
|
| 342 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
| 343 |
+
|
| 344 |
+
def build_attention_mask(self):
|
| 345 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
| 346 |
+
# pytorch uses additive attention mask; fill with -inf
|
| 347 |
+
mask = torch.empty(self.context_length, self.context_length)
|
| 348 |
+
mask.fill_(float("-inf"))
|
| 349 |
+
mask.triu_(1) # zero out the lower diagonal
|
| 350 |
+
return mask
|
| 351 |
+
|
| 352 |
+
@property
|
| 353 |
+
def dtype(self):
|
| 354 |
+
return self.visual.conv1.weight.dtype
|
| 355 |
+
|
| 356 |
+
def encode_image(self, image):
|
| 357 |
+
return self.visual(image.type(self.dtype))
|
| 358 |
+
|
| 359 |
+
def encode_text(self, text):
|
| 360 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
| 361 |
+
|
| 362 |
+
x = x + self.positional_embedding.type(self.dtype)
|
| 363 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 364 |
+
x = self.transformer(x)
|
| 365 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 366 |
+
x = self.ln_final(x).type(self.dtype)
|
| 367 |
+
|
| 368 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
| 369 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
| 370 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
| 371 |
+
|
| 372 |
+
return x
|
| 373 |
+
|
| 374 |
+
def forward(self, image, text):
|
| 375 |
+
image_features = self.encode_image(image)
|
| 376 |
+
text_features = self.encode_text(text)
|
| 377 |
+
|
| 378 |
+
# normalized features
|
| 379 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
| 380 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
| 381 |
+
|
| 382 |
+
# cosine similarity as logits
|
| 383 |
+
logit_scale = self.logit_scale.exp()
|
| 384 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
| 385 |
+
logits_per_text = logits_per_image.t()
|
| 386 |
+
|
| 387 |
+
# shape = [global_batch_size, global_batch_size]
|
| 388 |
+
return logits_per_image, logits_per_text
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def convert_weights(model: nn.Module):
|
| 392 |
+
"""Convert applicable model parameters to fp16"""
|
| 393 |
+
|
| 394 |
+
def _convert_weights_to_fp16(l):
|
| 395 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
| 396 |
+
l.weight.data = l.weight.data.half()
|
| 397 |
+
if l.bias is not None:
|
| 398 |
+
l.bias.data = l.bias.data.half()
|
| 399 |
+
|
| 400 |
+
if isinstance(l, nn.MultiheadAttention):
|
| 401 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
| 402 |
+
tensor = getattr(l, attr)
|
| 403 |
+
if tensor is not None:
|
| 404 |
+
tensor.data = tensor.data.half()
|
| 405 |
+
|
| 406 |
+
for name in ["text_projection", "proj"]:
|
| 407 |
+
if hasattr(l, name):
|
| 408 |
+
attr = getattr(l, name)
|
| 409 |
+
if attr is not None:
|
| 410 |
+
attr.data = attr.data.half()
|
| 411 |
+
|
| 412 |
+
model.apply(_convert_weights_to_fp16)
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def build_model(state_dict: dict):
|
| 416 |
+
vit = "visual.proj" in state_dict
|
| 417 |
+
|
| 418 |
+
if vit:
|
| 419 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
| 420 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
| 421 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
| 422 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
| 423 |
+
image_resolution = vision_patch_size * grid_size
|
| 424 |
+
else:
|
| 425 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
| 426 |
+
vision_layers = tuple(counts)
|
| 427 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
| 428 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
| 429 |
+
vision_patch_size = None
|
| 430 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
| 431 |
+
image_resolution = output_width * 32
|
| 432 |
+
|
| 433 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
| 434 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
| 435 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
| 436 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
| 437 |
+
transformer_heads = transformer_width // 64
|
| 438 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
| 439 |
+
|
| 440 |
+
model = CLIP(
|
| 441 |
+
embed_dim,
|
| 442 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
| 443 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
| 447 |
+
if key in state_dict:
|
| 448 |
+
del state_dict[key]
|
| 449 |
+
|
| 450 |
+
convert_weights(model)
|
| 451 |
+
model.load_state_dict(state_dict)
|
| 452 |
+
return model.eval()
|
detector_codes/AIDE-main/models/clip/simple_tokenizer.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gzip
|
| 2 |
+
import html
|
| 3 |
+
import os
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
|
| 6 |
+
import ftfy
|
| 7 |
+
import regex as re
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@lru_cache()
|
| 11 |
+
def default_bpe():
|
| 12 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@lru_cache()
|
| 16 |
+
def bytes_to_unicode():
|
| 17 |
+
"""
|
| 18 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
| 19 |
+
The reversible bpe codes work on unicode strings.
|
| 20 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
| 21 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
| 22 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
| 23 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
| 24 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
| 25 |
+
"""
|
| 26 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
| 27 |
+
cs = bs[:]
|
| 28 |
+
n = 0
|
| 29 |
+
for b in range(2**8):
|
| 30 |
+
if b not in bs:
|
| 31 |
+
bs.append(b)
|
| 32 |
+
cs.append(2**8+n)
|
| 33 |
+
n += 1
|
| 34 |
+
cs = [chr(n) for n in cs]
|
| 35 |
+
return dict(zip(bs, cs))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_pairs(word):
|
| 39 |
+
"""Return set of symbol pairs in a word.
|
| 40 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
| 41 |
+
"""
|
| 42 |
+
pairs = set()
|
| 43 |
+
prev_char = word[0]
|
| 44 |
+
for char in word[1:]:
|
| 45 |
+
pairs.add((prev_char, char))
|
| 46 |
+
prev_char = char
|
| 47 |
+
return pairs
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def basic_clean(text):
|
| 51 |
+
text = ftfy.fix_text(text)
|
| 52 |
+
text = html.unescape(html.unescape(text))
|
| 53 |
+
return text.strip()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def whitespace_clean(text):
|
| 57 |
+
text = re.sub(r'\s+', ' ', text)
|
| 58 |
+
text = text.strip()
|
| 59 |
+
return text
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class SimpleTokenizer(object):
|
| 63 |
+
def __init__(self, bpe_path: str = default_bpe()):
|
| 64 |
+
self.byte_encoder = bytes_to_unicode()
|
| 65 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
| 66 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
| 67 |
+
merges = merges[1:49152-256-2+1]
|
| 68 |
+
merges = [tuple(merge.split()) for merge in merges]
|
| 69 |
+
vocab = list(bytes_to_unicode().values())
|
| 70 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
| 71 |
+
for merge in merges:
|
| 72 |
+
vocab.append(''.join(merge))
|
| 73 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
| 74 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
| 75 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
| 76 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
| 77 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
| 78 |
+
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
| 79 |
+
|
| 80 |
+
def bpe(self, token):
|
| 81 |
+
if token in self.cache:
|
| 82 |
+
return self.cache[token]
|
| 83 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
| 84 |
+
pairs = get_pairs(word)
|
| 85 |
+
|
| 86 |
+
if not pairs:
|
| 87 |
+
return token+'</w>'
|
| 88 |
+
|
| 89 |
+
while True:
|
| 90 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
| 91 |
+
if bigram not in self.bpe_ranks:
|
| 92 |
+
break
|
| 93 |
+
first, second = bigram
|
| 94 |
+
new_word = []
|
| 95 |
+
i = 0
|
| 96 |
+
while i < len(word):
|
| 97 |
+
try:
|
| 98 |
+
j = word.index(first, i)
|
| 99 |
+
new_word.extend(word[i:j])
|
| 100 |
+
i = j
|
| 101 |
+
except:
|
| 102 |
+
new_word.extend(word[i:])
|
| 103 |
+
break
|
| 104 |
+
|
| 105 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
| 106 |
+
new_word.append(first+second)
|
| 107 |
+
i += 2
|
| 108 |
+
else:
|
| 109 |
+
new_word.append(word[i])
|
| 110 |
+
i += 1
|
| 111 |
+
new_word = tuple(new_word)
|
| 112 |
+
word = new_word
|
| 113 |
+
if len(word) == 1:
|
| 114 |
+
break
|
| 115 |
+
else:
|
| 116 |
+
pairs = get_pairs(word)
|
| 117 |
+
word = ' '.join(word)
|
| 118 |
+
self.cache[token] = word
|
| 119 |
+
return word
|
| 120 |
+
|
| 121 |
+
def encode(self, text):
|
| 122 |
+
bpe_tokens = []
|
| 123 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
| 124 |
+
for token in re.findall(self.pat, text):
|
| 125 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
| 126 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
| 127 |
+
return bpe_tokens
|
| 128 |
+
|
| 129 |
+
def decode(self, tokens):
|
| 130 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
| 131 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
| 132 |
+
return text
|
detector_codes/AIDE-main/models/srm_filter_kernel.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
filter_class_1 = [
|
| 5 |
+
np.array([
|
| 6 |
+
[1, 0, 0],
|
| 7 |
+
[0, -1, 0],
|
| 8 |
+
[0, 0, 0]
|
| 9 |
+
], dtype=np.float32),
|
| 10 |
+
np.array([
|
| 11 |
+
[0, 1, 0],
|
| 12 |
+
[0, -1, 0],
|
| 13 |
+
[0, 0, 0]
|
| 14 |
+
], dtype=np.float32),
|
| 15 |
+
np.array([
|
| 16 |
+
[0, 0, 1],
|
| 17 |
+
[0, -1, 0],
|
| 18 |
+
[0, 0, 0]
|
| 19 |
+
], dtype=np.float32),
|
| 20 |
+
np.array([
|
| 21 |
+
[0, 0, 0],
|
| 22 |
+
[1, -1, 0],
|
| 23 |
+
[0, 0, 0]
|
| 24 |
+
], dtype=np.float32),
|
| 25 |
+
np.array([
|
| 26 |
+
[0, 0, 0],
|
| 27 |
+
[0, -1, 1],
|
| 28 |
+
[0, 0, 0]
|
| 29 |
+
], dtype=np.float32),
|
| 30 |
+
np.array([
|
| 31 |
+
[0, 0, 0],
|
| 32 |
+
[0, -1, 0],
|
| 33 |
+
[1, 0, 0]
|
| 34 |
+
], dtype=np.float32),
|
| 35 |
+
np.array([
|
| 36 |
+
[0, 0, 0],
|
| 37 |
+
[0, -1, 0],
|
| 38 |
+
[0, 1, 0]
|
| 39 |
+
], dtype=np.float32),
|
| 40 |
+
np.array([
|
| 41 |
+
[0, 0, 0],
|
| 42 |
+
[0, -1, 0],
|
| 43 |
+
[0, 0, 1]
|
| 44 |
+
], dtype=np.float32)
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
filter_class_2 = [
|
| 49 |
+
np.array([
|
| 50 |
+
[1, 0, 0],
|
| 51 |
+
[0, -2, 0],
|
| 52 |
+
[0, 0, 1]
|
| 53 |
+
], dtype=np.float32),
|
| 54 |
+
np.array([
|
| 55 |
+
[0, 1, 0],
|
| 56 |
+
[0, -2, 0],
|
| 57 |
+
[0, 1, 0]
|
| 58 |
+
], dtype=np.float32),
|
| 59 |
+
np.array([
|
| 60 |
+
[0, 0, 1],
|
| 61 |
+
[0, -2, 0],
|
| 62 |
+
[1, 0, 0]
|
| 63 |
+
], dtype=np.float32),
|
| 64 |
+
np.array([
|
| 65 |
+
[0, 0, 0],
|
| 66 |
+
[1, -2, 1],
|
| 67 |
+
[0, 0, 0]
|
| 68 |
+
], dtype=np.float32),
|
| 69 |
+
]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
filter_class_3 = [
|
| 73 |
+
np.array([
|
| 74 |
+
[-1, 0, 0, 0, 0],
|
| 75 |
+
[0, 3, 0, 0, 0],
|
| 76 |
+
[0, 0, -3, 0, 0],
|
| 77 |
+
[0, 0, 0, 1, 0],
|
| 78 |
+
[0, 0, 0, 0, 0]
|
| 79 |
+
], dtype=np.float32),
|
| 80 |
+
np.array([
|
| 81 |
+
[0, 0, -1, 0, 0],
|
| 82 |
+
[0, 0, 3, 0, 0],
|
| 83 |
+
[0, 0, -3, 0, 0],
|
| 84 |
+
[0, 0, 1, 0, 0],
|
| 85 |
+
[0, 0, 0, 0, 0]
|
| 86 |
+
], dtype=np.float32),
|
| 87 |
+
np.array([
|
| 88 |
+
[0, 0, 0, 0, -1],
|
| 89 |
+
[0, 0, 0, 3, 0],
|
| 90 |
+
[0, 0, -3, 0, 0],
|
| 91 |
+
[0, 1, 0, 0, 0],
|
| 92 |
+
[0, 0, 0, 0, 0]
|
| 93 |
+
], dtype=np.float32),
|
| 94 |
+
np.array([
|
| 95 |
+
[0, 0, 0, 0, 0],
|
| 96 |
+
[0, 0, 0, 0, 0],
|
| 97 |
+
[0, 1, -3, 3, -1],
|
| 98 |
+
[0, 0, 0, 0, 0],
|
| 99 |
+
[0, 0, 0, 0, 0]
|
| 100 |
+
], dtype=np.float32),
|
| 101 |
+
np.array([
|
| 102 |
+
[0, 0, 0, 0, 0],
|
| 103 |
+
[0, 1, 0, 0, 0],
|
| 104 |
+
[0, 0, -3, 0, 0],
|
| 105 |
+
[0, 0, 0, 3, 0],
|
| 106 |
+
[0, 0, 0, 0, -1]
|
| 107 |
+
], dtype=np.float32),
|
| 108 |
+
np.array([
|
| 109 |
+
[0, 0, 0, 0, 0],
|
| 110 |
+
[0, 0, 1, 0, 0],
|
| 111 |
+
[0, 0, -3, 0, 0],
|
| 112 |
+
[0, 0, 3, 0, 0],
|
| 113 |
+
[0, 0, -1, 0, 0]
|
| 114 |
+
], dtype=np.float32),
|
| 115 |
+
np.array([
|
| 116 |
+
[0, 0, 0, 0, 0],
|
| 117 |
+
[0, 0, 0, 1, 0],
|
| 118 |
+
[0, 0, -3, 0, 0],
|
| 119 |
+
[0, 3, 0, 0, 0],
|
| 120 |
+
[-1, 0, 0, 0, 0]
|
| 121 |
+
], dtype=np.float32),
|
| 122 |
+
np.array([
|
| 123 |
+
[0, 0, 0, 0, 0],
|
| 124 |
+
[0, 0, 0, 0, 0],
|
| 125 |
+
[-1, 3, -3, 1, 0],
|
| 126 |
+
[0, 0, 0, 0, 0],
|
| 127 |
+
[0, 0, 0, 0, 0]
|
| 128 |
+
], dtype=np.float32)
|
| 129 |
+
]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
filter_edge_3x3 = [
|
| 133 |
+
np.array([
|
| 134 |
+
[-1, 2, -1],
|
| 135 |
+
[2, -4, 2],
|
| 136 |
+
[0, 0, 0]
|
| 137 |
+
], dtype=np.float32),
|
| 138 |
+
np.array([
|
| 139 |
+
[0, 2, -1],
|
| 140 |
+
[0, -4, 2],
|
| 141 |
+
[0, 2, -1]
|
| 142 |
+
], dtype=np.float32),
|
| 143 |
+
np.array([
|
| 144 |
+
[0, 0, 0],
|
| 145 |
+
[2, -4, 2],
|
| 146 |
+
[-1, 2, -1]
|
| 147 |
+
], dtype=np.float32),
|
| 148 |
+
np.array([
|
| 149 |
+
[-1, 2, 0],
|
| 150 |
+
[2, -4, 0],
|
| 151 |
+
[-1, 2, 0]
|
| 152 |
+
], dtype=np.float32),
|
| 153 |
+
]
|
| 154 |
+
|
| 155 |
+
filter_edge_5x5 = [
|
| 156 |
+
np.array([
|
| 157 |
+
[-1, 2, -2, 2, -1],
|
| 158 |
+
[2, -6, 8, -6, 2],
|
| 159 |
+
[-2, 8, -12, 8, -2],
|
| 160 |
+
[0, 0, 0, 0, 0],
|
| 161 |
+
[0, 0, 0, 0, 0]
|
| 162 |
+
], dtype=np.float32),
|
| 163 |
+
np.array([
|
| 164 |
+
[0, 0, -2, 2, -1],
|
| 165 |
+
[0, 0, 8, -6, 2],
|
| 166 |
+
[0, 0, -12, 8, -2],
|
| 167 |
+
[0, 0, 8, -6, 2],
|
| 168 |
+
[0, 0, -2, 2, -1]
|
| 169 |
+
], dtype=np.float32),
|
| 170 |
+
np.array([
|
| 171 |
+
[0, 0, 0, 0, 0],
|
| 172 |
+
[0, 0, 0, 0, 0],
|
| 173 |
+
[-2, 8, -12, 8, -2],
|
| 174 |
+
[2, -6, 8, -6, 2],
|
| 175 |
+
[-1, 2, -2, 2, -1]
|
| 176 |
+
], dtype=np.float32),
|
| 177 |
+
np.array([
|
| 178 |
+
[-1, 2, -2, 0, 0],
|
| 179 |
+
[2, -6, 8, 0, 0],
|
| 180 |
+
[-2, 8, -12, 0, 0],
|
| 181 |
+
[2, -6, 8, 0, 0],
|
| 182 |
+
[-1, 2, -2, 0, 0]
|
| 183 |
+
], dtype=np.float32),
|
| 184 |
+
]
|
| 185 |
+
|
| 186 |
+
square_3x3 = np.array([
|
| 187 |
+
[-1, 2, -1],
|
| 188 |
+
[2, -4, 2],
|
| 189 |
+
[-1, 2, -1]
|
| 190 |
+
], dtype=np.float32)
|
| 191 |
+
|
| 192 |
+
square_5x5 = np.array([
|
| 193 |
+
[-1, 2, -2, 2, -1],
|
| 194 |
+
[2, -6, 8, -6, 2],
|
| 195 |
+
[-2, 8, -12, 8, -2],
|
| 196 |
+
[2, -6, 8, -6, 2],
|
| 197 |
+
[-1, 2, -2, 2, -1]
|
| 198 |
+
], dtype=np.float32)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
all_hpf_list = filter_class_1 + filter_class_2 + filter_class_3 + filter_edge_3x3 + filter_edge_5x5 + [square_3x3, square_5x5]
|
| 202 |
+
|
| 203 |
+
hpf_3x3_list = filter_class_1 + filter_class_2 + filter_edge_3x3 + [square_3x3]
|
| 204 |
+
hpf_5x5_list = filter_class_3 + filter_edge_5x5 + [square_5x5]
|
| 205 |
+
|
| 206 |
+
normalized_filter_class_2 = [hpf / 2 for hpf in filter_class_2]
|
| 207 |
+
normalized_filter_class_3 = [hpf / 3 for hpf in filter_class_3]
|
| 208 |
+
normalized_filter_edge_3x3 = [hpf / 4 for hpf in filter_edge_3x3]
|
| 209 |
+
normalized_square_3x3 = square_3x3 / 4
|
| 210 |
+
normalized_filter_edge_5x5 = [hpf / 12 for hpf in filter_edge_5x5]
|
| 211 |
+
normalized_square_5x5 = square_5x5 / 12
|
| 212 |
+
|
| 213 |
+
all_normalized_hpf_list = filter_class_1 + normalized_filter_class_2 + normalized_filter_class_3 + \
|
| 214 |
+
normalized_filter_edge_3x3 + normalized_filter_edge_5x5 + [normalized_square_3x3, normalized_square_5x5]
|
| 215 |
+
|
| 216 |
+
normalized_hpf_3x3_list = filter_class_1 + normalized_filter_class_2 + normalized_filter_edge_3x3 + [normalized_square_3x3]
|
| 217 |
+
normalized_hpf_5x5_list = normalized_filter_class_3 + normalized_filter_edge_5x5 + [normalized_square_5x5]
|
| 218 |
+
|
| 219 |
+
normalized_3x3_list = normalized_filter_edge_3x3 + [normalized_square_3x3]
|
| 220 |
+
normalized_5x5_list = normalized_filter_edge_5x5 + [normalized_square_5x5]
|
detector_codes/AIDE-main/models/utils.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
import numpy.random as random
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
# from MinkowskiEngine import SparseTensor
|
| 15 |
+
|
| 16 |
+
# class MinkowskiGRN(nn.Module):
|
| 17 |
+
# """ GRN layer for sparse tensors.
|
| 18 |
+
# """
|
| 19 |
+
# def __init__(self, dim):
|
| 20 |
+
# super().__init__()
|
| 21 |
+
# self.gamma = nn.Parameter(torch.zeros(1, dim))
|
| 22 |
+
# self.beta = nn.Parameter(torch.zeros(1, dim))
|
| 23 |
+
|
| 24 |
+
# def forward(self, x):
|
| 25 |
+
# cm = x.coordinate_manager
|
| 26 |
+
# in_key = x.coordinate_map_key
|
| 27 |
+
|
| 28 |
+
# Gx = torch.norm(x.F, p=2, dim=0, keepdim=True)
|
| 29 |
+
# Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
| 30 |
+
# return SparseTensor(
|
| 31 |
+
# self.gamma * (x.F * Nx) + self.beta + x.F,
|
| 32 |
+
# coordinate_map_key=in_key,
|
| 33 |
+
# coordinate_manager=cm)
|
| 34 |
+
|
| 35 |
+
# class MinkowskiDropPath(nn.Module):
|
| 36 |
+
# """ Drop Path for sparse tensors.
|
| 37 |
+
# """
|
| 38 |
+
|
| 39 |
+
# def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
|
| 40 |
+
# super(MinkowskiDropPath, self).__init__()
|
| 41 |
+
# self.drop_prob = drop_prob
|
| 42 |
+
# self.scale_by_keep = scale_by_keep
|
| 43 |
+
|
| 44 |
+
# def forward(self, x):
|
| 45 |
+
# if self.drop_prob == 0. or not self.training:
|
| 46 |
+
# return x
|
| 47 |
+
# cm = x.coordinate_manager
|
| 48 |
+
# in_key = x.coordinate_map_key
|
| 49 |
+
# keep_prob = 1 - self.drop_prob
|
| 50 |
+
# mask = torch.cat([
|
| 51 |
+
# torch.ones(len(_)) if random.uniform(0, 1) > self.drop_prob
|
| 52 |
+
# else torch.zeros(len(_)) for _ in x.decomposed_coordinates
|
| 53 |
+
# ]).view(-1, 1).to(x.device)
|
| 54 |
+
# if keep_prob > 0.0 and self.scale_by_keep:
|
| 55 |
+
# mask.div_(keep_prob)
|
| 56 |
+
# return SparseTensor(
|
| 57 |
+
# x.F * mask,
|
| 58 |
+
# coordinate_map_key=in_key,
|
| 59 |
+
# coordinate_manager=cm)
|
| 60 |
+
|
| 61 |
+
# class MinkowskiLayerNorm(nn.Module):
|
| 62 |
+
# """ Channel-wise layer normalization for sparse tensors.
|
| 63 |
+
# """
|
| 64 |
+
|
| 65 |
+
# def __init__(
|
| 66 |
+
# self,
|
| 67 |
+
# normalized_shape,
|
| 68 |
+
# eps=1e-6,
|
| 69 |
+
# ):
|
| 70 |
+
# super(MinkowskiLayerNorm, self).__init__()
|
| 71 |
+
# self.ln = nn.LayerNorm(normalized_shape, eps=eps)
|
| 72 |
+
# def forward(self, input):
|
| 73 |
+
# output = self.ln(input.F)
|
| 74 |
+
# return SparseTensor(
|
| 75 |
+
# output,
|
| 76 |
+
# coordinate_map_key=input.coordinate_map_key,
|
| 77 |
+
# coordinate_manager=input.coordinate_manager)
|
| 78 |
+
|
| 79 |
+
class LayerNorm(nn.Module):
|
| 80 |
+
""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
| 81 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
| 82 |
+
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
| 83 |
+
with shape (batch_size, channels, height, width).
|
| 84 |
+
"""
|
| 85 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 88 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
| 89 |
+
self.eps = eps
|
| 90 |
+
self.data_format = data_format
|
| 91 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
| 92 |
+
raise NotImplementedError
|
| 93 |
+
self.normalized_shape = (normalized_shape, )
|
| 94 |
+
|
| 95 |
+
def forward(self, x):
|
| 96 |
+
if self.data_format == "channels_last":
|
| 97 |
+
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
| 98 |
+
elif self.data_format == "channels_first":
|
| 99 |
+
u = x.mean(1, keepdim=True)
|
| 100 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
| 101 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
| 102 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
| 103 |
+
return x
|
| 104 |
+
|
| 105 |
+
class GRN(nn.Module):
|
| 106 |
+
""" GRN (Global Response Normalization) layer
|
| 107 |
+
"""
|
| 108 |
+
def __init__(self, dim):
|
| 109 |
+
super().__init__()
|
| 110 |
+
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
| 111 |
+
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
| 112 |
+
|
| 113 |
+
def forward(self, x):
|
| 114 |
+
Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True)
|
| 115 |
+
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
| 116 |
+
return self.gamma * (x * Nx) + self.beta + x
|
detector_codes/AIDE-main/optim_factory.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import optim as optim
|
| 11 |
+
|
| 12 |
+
from timm.optim.adafactor import Adafactor
|
| 13 |
+
from timm.optim.adahessian import Adahessian
|
| 14 |
+
from timm.optim.adamp import AdamP
|
| 15 |
+
from timm.optim.lookahead import Lookahead
|
| 16 |
+
# from timm.optim.nadam import Nadam
|
| 17 |
+
# from timm.optim.novograd import NovoGrad
|
| 18 |
+
# from timm.optim.nvnovograd import NvNovoGrad
|
| 19 |
+
# from timm.optim.radam import RAdam
|
| 20 |
+
from timm.optim.rmsprop_tf import RMSpropTF
|
| 21 |
+
from timm.optim.sgdp import SGDP
|
| 22 |
+
|
| 23 |
+
import json
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
|
| 27 |
+
has_apex = True
|
| 28 |
+
except ImportError:
|
| 29 |
+
has_apex = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_num_layer_for_convnext_single(var_name, depths):
|
| 33 |
+
"""
|
| 34 |
+
Each layer is assigned distinctive layer ids
|
| 35 |
+
"""
|
| 36 |
+
if var_name.startswith("downsample_layers"):
|
| 37 |
+
stage_id = int(var_name.split('.')[1])
|
| 38 |
+
layer_id = sum(depths[:stage_id]) + 1
|
| 39 |
+
return layer_id
|
| 40 |
+
|
| 41 |
+
elif var_name.startswith("stages"):
|
| 42 |
+
stage_id = int(var_name.split('.')[1])
|
| 43 |
+
block_id = int(var_name.split('.')[2])
|
| 44 |
+
layer_id = sum(depths[:stage_id]) + block_id + 1
|
| 45 |
+
return layer_id
|
| 46 |
+
|
| 47 |
+
else:
|
| 48 |
+
return sum(depths) + 1
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_num_layer_for_convnext(var_name):
|
| 52 |
+
"""
|
| 53 |
+
Divide [3, 3, 27, 3] layers into 12 groups; each group is three
|
| 54 |
+
consecutive blocks, including possible neighboring downsample layers;
|
| 55 |
+
adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
|
| 56 |
+
"""
|
| 57 |
+
num_max_layer = 12
|
| 58 |
+
if var_name.startswith("downsample_layers"):
|
| 59 |
+
stage_id = int(var_name.split('.')[1])
|
| 60 |
+
if stage_id == 0:
|
| 61 |
+
layer_id = 0
|
| 62 |
+
elif stage_id == 1 or stage_id == 2:
|
| 63 |
+
layer_id = stage_id + 1
|
| 64 |
+
elif stage_id == 3:
|
| 65 |
+
layer_id = 12
|
| 66 |
+
return layer_id
|
| 67 |
+
|
| 68 |
+
elif var_name.startswith("stages"):
|
| 69 |
+
stage_id = int(var_name.split('.')[1])
|
| 70 |
+
block_id = int(var_name.split('.')[2])
|
| 71 |
+
if stage_id == 0 or stage_id == 1:
|
| 72 |
+
layer_id = stage_id + 1
|
| 73 |
+
elif stage_id == 2:
|
| 74 |
+
layer_id = 3 + block_id // 3
|
| 75 |
+
elif stage_id == 3:
|
| 76 |
+
layer_id = 12
|
| 77 |
+
return layer_id
|
| 78 |
+
else:
|
| 79 |
+
return num_max_layer + 1
|
| 80 |
+
|
| 81 |
+
class LayerDecayValueAssigner(object):
|
| 82 |
+
def __init__(self, values, depths=[3,3,27,3], layer_decay_type='single'):
|
| 83 |
+
self.values = values
|
| 84 |
+
self.depths = depths
|
| 85 |
+
self.layer_decay_type = layer_decay_type
|
| 86 |
+
|
| 87 |
+
def get_scale(self, layer_id):
|
| 88 |
+
return self.values[layer_id]
|
| 89 |
+
|
| 90 |
+
def get_layer_id(self, var_name):
|
| 91 |
+
if self.layer_decay_type == 'single':
|
| 92 |
+
return get_num_layer_for_convnext_single(var_name, self.depths)
|
| 93 |
+
else:
|
| 94 |
+
return get_num_layer_for_convnext(var_name)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
|
| 98 |
+
parameter_group_names = {}
|
| 99 |
+
parameter_group_vars = {}
|
| 100 |
+
|
| 101 |
+
for name, param in model.named_parameters():
|
| 102 |
+
if not param.requires_grad:
|
| 103 |
+
continue # frozen weights
|
| 104 |
+
if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or \
|
| 105 |
+
name.endswith(".gamma") or name.endswith(".beta"):
|
| 106 |
+
group_name = "no_decay"
|
| 107 |
+
this_weight_decay = 0.
|
| 108 |
+
else:
|
| 109 |
+
group_name = "decay"
|
| 110 |
+
this_weight_decay = weight_decay
|
| 111 |
+
if get_num_layer is not None:
|
| 112 |
+
layer_id = get_num_layer(name)
|
| 113 |
+
group_name = "layer_%d_%s" % (layer_id, group_name)
|
| 114 |
+
else:
|
| 115 |
+
layer_id = None
|
| 116 |
+
|
| 117 |
+
if group_name not in parameter_group_names:
|
| 118 |
+
if get_layer_scale is not None:
|
| 119 |
+
scale = get_layer_scale(layer_id)
|
| 120 |
+
else:
|
| 121 |
+
scale = 1.
|
| 122 |
+
|
| 123 |
+
parameter_group_names[group_name] = {
|
| 124 |
+
"weight_decay": this_weight_decay,
|
| 125 |
+
"params": [],
|
| 126 |
+
"lr_scale": scale
|
| 127 |
+
}
|
| 128 |
+
parameter_group_vars[group_name] = {
|
| 129 |
+
"weight_decay": this_weight_decay,
|
| 130 |
+
"params": [],
|
| 131 |
+
"lr_scale": scale
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
parameter_group_vars[group_name]["params"].append(param)
|
| 135 |
+
parameter_group_names[group_name]["params"].append(name)
|
| 136 |
+
print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
|
| 137 |
+
return list(parameter_group_vars.values())
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
|
| 141 |
+
opt_lower = args.opt.lower()
|
| 142 |
+
weight_decay = args.weight_decay
|
| 143 |
+
# if weight_decay and filter_bias_and_bn:
|
| 144 |
+
if filter_bias_and_bn:
|
| 145 |
+
skip = {}
|
| 146 |
+
if skip_list is not None:
|
| 147 |
+
skip = skip_list
|
| 148 |
+
elif hasattr(model, 'no_weight_decay'):
|
| 149 |
+
skip = model.no_weight_decay()
|
| 150 |
+
parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
|
| 151 |
+
weight_decay = 0.
|
| 152 |
+
else:
|
| 153 |
+
parameters = model.parameters()
|
| 154 |
+
|
| 155 |
+
if 'fused' in opt_lower:
|
| 156 |
+
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
|
| 157 |
+
|
| 158 |
+
opt_args = dict(lr=args.lr, weight_decay=weight_decay)
|
| 159 |
+
if hasattr(args, 'opt_eps') and args.opt_eps is not None:
|
| 160 |
+
opt_args['eps'] = args.opt_eps
|
| 161 |
+
if hasattr(args, 'opt_betas') and args.opt_betas is not None:
|
| 162 |
+
opt_args['betas'] = args.opt_betas
|
| 163 |
+
|
| 164 |
+
opt_split = opt_lower.split('_')
|
| 165 |
+
opt_lower = opt_split[-1]
|
| 166 |
+
if opt_lower == 'sgd' or opt_lower == 'nesterov':
|
| 167 |
+
opt_args.pop('eps', None)
|
| 168 |
+
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
|
| 169 |
+
elif opt_lower == 'momentum':
|
| 170 |
+
opt_args.pop('eps', None)
|
| 171 |
+
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
|
| 172 |
+
elif opt_lower == 'adam':
|
| 173 |
+
optimizer = optim.Adam(parameters, **opt_args)
|
| 174 |
+
elif opt_lower == 'adamw':
|
| 175 |
+
optimizer = optim.AdamW(parameters, **opt_args)
|
| 176 |
+
# elif opt_lower == 'nadam':
|
| 177 |
+
# optimizer = Nadam(parameters, **opt_args)
|
| 178 |
+
# elif opt_lower == 'radam':
|
| 179 |
+
# optimizer = RAdam(parameters, **opt_args)
|
| 180 |
+
elif opt_lower == 'adamp':
|
| 181 |
+
optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
|
| 182 |
+
elif opt_lower == 'sgdp':
|
| 183 |
+
optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
|
| 184 |
+
elif opt_lower == 'adadelta':
|
| 185 |
+
optimizer = optim.Adadelta(parameters, **opt_args)
|
| 186 |
+
elif opt_lower == 'adafactor':
|
| 187 |
+
if not args.lr:
|
| 188 |
+
opt_args['lr'] = None
|
| 189 |
+
optimizer = Adafactor(parameters, **opt_args)
|
| 190 |
+
elif opt_lower == 'adahessian':
|
| 191 |
+
optimizer = Adahessian(parameters, **opt_args)
|
| 192 |
+
elif opt_lower == 'rmsprop':
|
| 193 |
+
optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
|
| 194 |
+
elif opt_lower == 'rmsproptf':
|
| 195 |
+
optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
|
| 196 |
+
elif opt_lower == 'novograd':
|
| 197 |
+
optimizer = NovoGrad(parameters, **opt_args)
|
| 198 |
+
elif opt_lower == 'nvnovograd':
|
| 199 |
+
optimizer = NvNovoGrad(parameters, **opt_args)
|
| 200 |
+
elif opt_lower == 'fusedsgd':
|
| 201 |
+
opt_args.pop('eps', None)
|
| 202 |
+
optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
|
| 203 |
+
elif opt_lower == 'fusedmomentum':
|
| 204 |
+
opt_args.pop('eps', None)
|
| 205 |
+
optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
|
| 206 |
+
elif opt_lower == 'fusedadam':
|
| 207 |
+
optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
|
| 208 |
+
elif opt_lower == 'fusedadamw':
|
| 209 |
+
optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
|
| 210 |
+
elif opt_lower == 'fusedlamb':
|
| 211 |
+
optimizer = FusedLAMB(parameters, **opt_args)
|
| 212 |
+
elif opt_lower == 'fusednovograd':
|
| 213 |
+
opt_args.setdefault('betas', (0.95, 0.98))
|
| 214 |
+
optimizer = FusedNovoGrad(parameters, **opt_args)
|
| 215 |
+
else:
|
| 216 |
+
assert False and "Invalid optimizer"
|
| 217 |
+
|
| 218 |
+
if len(opt_split) > 1:
|
| 219 |
+
if opt_split[0] == 'lookahead':
|
| 220 |
+
optimizer = Lookahead(optimizer)
|
| 221 |
+
|
| 222 |
+
return optimizer
|
detector_codes/AIDE-main/requirements.txt
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
einops==0.6.1
|
| 2 |
+
fairscale==0.4.13
|
| 3 |
+
filelock==3.13.1
|
| 4 |
+
ftfy==6.1.3
|
| 5 |
+
h5py==3.10.0
|
| 6 |
+
imgaug==0.2.6
|
| 7 |
+
keras==2.11.0
|
| 8 |
+
kornia==0.7.2
|
| 9 |
+
kornia_rs==0.1.2
|
| 10 |
+
lmdb==1.4.1
|
| 11 |
+
matplotlib==3.7.4
|
| 12 |
+
matplotlib-inline==0.1.6
|
| 13 |
+
numpy==1.24.3
|
| 14 |
+
omegaconf==2.3.0
|
| 15 |
+
open-clip-torch==2.24.0
|
| 16 |
+
openai-clip==1.0.1
|
| 17 |
+
openpyxl==3.1.2
|
| 18 |
+
pandas==2.0.3
|
| 19 |
+
Pillow==9.5.0
|
| 20 |
+
safetensors==0.4.1
|
| 21 |
+
scikit-image==0.20.0
|
| 22 |
+
scikit-learn==1.3.2
|
| 23 |
+
scipy==1.9.1
|
| 24 |
+
sentencepiece==0.2.0
|
| 25 |
+
streamlit==1.30.0
|
| 26 |
+
tenacity==8.2.3
|
| 27 |
+
tensorboard==2.11.2
|
| 28 |
+
tensorboard-data-server==0.6.1
|
| 29 |
+
tensorboard-plugin-wit==1.8.1
|
| 30 |
+
tensorboardX==2.6.2.2
|
| 31 |
+
torch==1.11.0
|
| 32 |
+
torch-fidelity==0.3.0
|
| 33 |
+
torchmetrics==0.6.0
|
| 34 |
+
torchsummary==1.5.1
|
| 35 |
+
torchvision==0.12.0
|
| 36 |
+
tqdm==4.66.1
|
detector_codes/AIDE-main/scripts/eval.sh
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL="AIDE"
|
| 2 |
+
RESUME_PATH="./4class-checkpoints"
|
| 3 |
+
|
| 4 |
+
eval_datasets=(
|
| 5 |
+
# "/data/ziqiang/Benchmark" \
|
| 6 |
+
# "/data/ziqiang/jpeg" \
|
| 7 |
+
# "/data/ziqiang/noise" \
|
| 8 |
+
"/data/ziqiang/sample" \
|
| 9 |
+
)
|
| 10 |
+
for eval_dataset in "${eval_datasets[@]}"
|
| 11 |
+
do
|
| 12 |
+
python main_finetune.py \
|
| 13 |
+
--model $MODEL \
|
| 14 |
+
--data_path /data/ziqiang/yjz/dataset/ForenSynths/train \
|
| 15 |
+
--eval_data_path $eval_dataset \
|
| 16 |
+
--batch_size 64 \
|
| 17 |
+
--output_dir $RESUME_PATH \
|
| 18 |
+
--resume $RESUME_PATH/checkpoint-best.pth \
|
| 19 |
+
--eval True
|
| 20 |
+
done
|
| 21 |
+
|
detector_codes/AIDE-main/scripts/train.sh
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
GPU_NUM=8
|
| 3 |
+
WORLD_SIZE=1
|
| 4 |
+
RANK=0
|
| 5 |
+
MASTER_ADDR=localhost
|
| 6 |
+
MASTER_PORT=29512
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
DISTRIBUTED_ARGS="
|
| 10 |
+
--nproc_per_node $GPU_NUM \
|
| 11 |
+
--nnodes $WORLD_SIZE \
|
| 12 |
+
--node_rank $RANK \
|
| 13 |
+
--master_addr $MASTER_ADDR \
|
| 14 |
+
--master_port $MASTER_PORT
|
| 15 |
+
"
|
| 16 |
+
|
| 17 |
+
PY_ARGS=${@:1} # Any other arguments
|
| 18 |
+
|
| 19 |
+
python -m torch.distributed.launch $DISTRIBUTED_ARGS main_finetune.py \
|
| 20 |
+
--model AIDE \
|
| 21 |
+
--batch_size 32 \
|
| 22 |
+
--blr 1e-4 \
|
| 23 |
+
--epochs 20 \
|
| 24 |
+
${PY_ARGS}
|
detector_codes/AIDE-main/utils.py
ADDED
|
@@ -0,0 +1,700 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
import datetime
|
| 10 |
+
import math
|
| 11 |
+
import os
|
| 12 |
+
import time
|
| 13 |
+
from collections import OrderedDict, defaultdict, deque
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
from tensorboardX import SummaryWriter
|
| 20 |
+
from timm.utils import get_state_dict
|
| 21 |
+
from torch import inf
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def str2bool(v):
|
| 25 |
+
"""
|
| 26 |
+
Converts string to bool type; enables command line
|
| 27 |
+
arguments in the format of '--arg1 true --arg2 false'
|
| 28 |
+
"""
|
| 29 |
+
if isinstance(v, bool):
|
| 30 |
+
return v
|
| 31 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
| 32 |
+
return True
|
| 33 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
| 34 |
+
return False
|
| 35 |
+
else:
|
| 36 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class SmoothedValue(object):
|
| 40 |
+
"""Track a series of values and provide access to smoothed values over a
|
| 41 |
+
window or the global series average.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(self, window_size=20, fmt=None):
|
| 45 |
+
if fmt is None:
|
| 46 |
+
fmt = '{median:.4f} ({global_avg:.4f})'
|
| 47 |
+
self.deque = deque(maxlen=window_size)
|
| 48 |
+
self.total = 0.0
|
| 49 |
+
self.count = 0
|
| 50 |
+
self.fmt = fmt
|
| 51 |
+
|
| 52 |
+
def update(self, value, n=1):
|
| 53 |
+
self.deque.append(value)
|
| 54 |
+
self.count += n
|
| 55 |
+
self.total += value * n
|
| 56 |
+
|
| 57 |
+
def synchronize_between_processes(self):
|
| 58 |
+
"""
|
| 59 |
+
Warning: does not synchronize the deque!
|
| 60 |
+
"""
|
| 61 |
+
if not is_dist_avail_and_initialized():
|
| 62 |
+
return
|
| 63 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
| 64 |
+
dist.barrier()
|
| 65 |
+
dist.all_reduce(t)
|
| 66 |
+
t = t.tolist()
|
| 67 |
+
self.count = int(t[0])
|
| 68 |
+
self.total = t[1]
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def median(self):
|
| 72 |
+
d = torch.tensor(list(self.deque))
|
| 73 |
+
return d.median().item()
|
| 74 |
+
|
| 75 |
+
@property
|
| 76 |
+
def avg(self):
|
| 77 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
| 78 |
+
return d.mean().item()
|
| 79 |
+
|
| 80 |
+
@property
|
| 81 |
+
def global_avg(self):
|
| 82 |
+
return self.total / self.count
|
| 83 |
+
|
| 84 |
+
@property
|
| 85 |
+
def max(self):
|
| 86 |
+
return max(self.deque)
|
| 87 |
+
|
| 88 |
+
@property
|
| 89 |
+
def value(self):
|
| 90 |
+
return self.deque[-1]
|
| 91 |
+
|
| 92 |
+
def __str__(self):
|
| 93 |
+
return self.fmt.format(
|
| 94 |
+
median=self.median,
|
| 95 |
+
avg=self.avg,
|
| 96 |
+
global_avg=self.global_avg,
|
| 97 |
+
max=self.max,
|
| 98 |
+
value=self.value,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class MetricLogger(object):
|
| 103 |
+
def __init__(self, delimiter='\t'):
|
| 104 |
+
self.meters = defaultdict(SmoothedValue)
|
| 105 |
+
self.delimiter = delimiter
|
| 106 |
+
|
| 107 |
+
def update(self, **kwargs):
|
| 108 |
+
for k, v in kwargs.items():
|
| 109 |
+
if v is None:
|
| 110 |
+
continue
|
| 111 |
+
if isinstance(v, torch.Tensor):
|
| 112 |
+
v = v.item()
|
| 113 |
+
assert isinstance(v, (float, int))
|
| 114 |
+
self.meters[k].update(v)
|
| 115 |
+
|
| 116 |
+
def __getattr__(self, attr):
|
| 117 |
+
if attr in self.meters:
|
| 118 |
+
return self.meters[attr]
|
| 119 |
+
if attr in self.__dict__:
|
| 120 |
+
return self.__dict__[attr]
|
| 121 |
+
raise AttributeError(
|
| 122 |
+
"'{}' object has no attribute '{}'".format(type(self).__name__, attr)
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
def __str__(self):
|
| 126 |
+
loss_str = []
|
| 127 |
+
for name, meter in self.meters.items():
|
| 128 |
+
loss_str.append('{}: {}'.format(name, str(meter)))
|
| 129 |
+
return self.delimiter.join(loss_str)
|
| 130 |
+
|
| 131 |
+
def synchronize_between_processes(self):
|
| 132 |
+
for meter in self.meters.values():
|
| 133 |
+
meter.synchronize_between_processes()
|
| 134 |
+
|
| 135 |
+
def add_meter(self, name, meter):
|
| 136 |
+
self.meters[name] = meter
|
| 137 |
+
|
| 138 |
+
def log_every(self, iterable, print_freq, header=None):
|
| 139 |
+
i = 0
|
| 140 |
+
if not header:
|
| 141 |
+
header = ''
|
| 142 |
+
start_time = time.time()
|
| 143 |
+
end = time.time()
|
| 144 |
+
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
| 145 |
+
data_time = SmoothedValue(fmt='{avg:.4f}')
|
| 146 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
| 147 |
+
log_msg = [
|
| 148 |
+
header,
|
| 149 |
+
'[{0' + space_fmt + '}/{1}]',
|
| 150 |
+
'eta: {eta}',
|
| 151 |
+
'{meters}',
|
| 152 |
+
'time: {time}',
|
| 153 |
+
'data: {data}',
|
| 154 |
+
]
|
| 155 |
+
if torch.cuda.is_available():
|
| 156 |
+
log_msg.append('max mem: {memory:.0f}')
|
| 157 |
+
log_msg = self.delimiter.join(log_msg)
|
| 158 |
+
MB = 1024.0 * 1024.0
|
| 159 |
+
for obj in iterable:
|
| 160 |
+
data_time.update(time.time() - end)
|
| 161 |
+
yield obj
|
| 162 |
+
iter_time.update(time.time() - end)
|
| 163 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
| 164 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
| 165 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
| 166 |
+
if torch.cuda.is_available():
|
| 167 |
+
print(
|
| 168 |
+
log_msg.format(
|
| 169 |
+
i,
|
| 170 |
+
len(iterable),
|
| 171 |
+
eta=eta_string,
|
| 172 |
+
meters=str(self),
|
| 173 |
+
time=str(iter_time),
|
| 174 |
+
data=str(data_time),
|
| 175 |
+
memory=torch.cuda.max_memory_allocated() / MB,
|
| 176 |
+
)
|
| 177 |
+
)
|
| 178 |
+
else:
|
| 179 |
+
print(
|
| 180 |
+
log_msg.format(
|
| 181 |
+
i,
|
| 182 |
+
len(iterable),
|
| 183 |
+
eta=eta_string,
|
| 184 |
+
meters=str(self),
|
| 185 |
+
time=str(iter_time),
|
| 186 |
+
data=str(data_time),
|
| 187 |
+
)
|
| 188 |
+
)
|
| 189 |
+
i += 1
|
| 190 |
+
end = time.time()
|
| 191 |
+
total_time = time.time() - start_time
|
| 192 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 193 |
+
print(
|
| 194 |
+
'{} Total time: {} ({:.4f} s / it)'.format(
|
| 195 |
+
header, total_time_str, total_time / len(iterable)
|
| 196 |
+
)
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class TensorboardLogger(object):
|
| 201 |
+
def __init__(self, log_dir):
|
| 202 |
+
self.writer = SummaryWriter(logdir=log_dir)
|
| 203 |
+
self.step = 0
|
| 204 |
+
|
| 205 |
+
def set_step(self, step=None):
|
| 206 |
+
if step is not None:
|
| 207 |
+
self.step = step
|
| 208 |
+
else:
|
| 209 |
+
self.step += 1
|
| 210 |
+
|
| 211 |
+
def update(self, head='scalar', step=None, **kwargs):
|
| 212 |
+
for k, v in kwargs.items():
|
| 213 |
+
if v is None:
|
| 214 |
+
continue
|
| 215 |
+
if isinstance(v, torch.Tensor):
|
| 216 |
+
v = v.item()
|
| 217 |
+
assert isinstance(v, (float, int))
|
| 218 |
+
self.writer.add_scalar(
|
| 219 |
+
head + '/' + k, v, self.step if step is None else step
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
def flush(self):
|
| 223 |
+
self.writer.flush()
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class WandbLogger(object):
|
| 227 |
+
def __init__(self, args):
|
| 228 |
+
self.args = args
|
| 229 |
+
|
| 230 |
+
try:
|
| 231 |
+
import wandb
|
| 232 |
+
|
| 233 |
+
self._wandb = wandb
|
| 234 |
+
except ImportError:
|
| 235 |
+
raise ImportError(
|
| 236 |
+
'To use the Weights and Biases Logger please install wandb.'
|
| 237 |
+
'Run `pip install wandb` to install it.'
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# Initialize a W&B run
|
| 241 |
+
if self._wandb.run is None:
|
| 242 |
+
self._wandb.init(project=args.project, config=args)
|
| 243 |
+
|
| 244 |
+
def log_epoch_metrics(self, metrics, commit=True):
|
| 245 |
+
"""
|
| 246 |
+
Log train/test metrics onto W&B.
|
| 247 |
+
"""
|
| 248 |
+
# Log number of model parameters as W&B summary
|
| 249 |
+
self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None)
|
| 250 |
+
metrics.pop('n_parameters', None)
|
| 251 |
+
|
| 252 |
+
# Log current epoch
|
| 253 |
+
self._wandb.log({'epoch': metrics.get('epoch')}, commit=False)
|
| 254 |
+
metrics.pop('epoch')
|
| 255 |
+
|
| 256 |
+
for k, v in metrics.items():
|
| 257 |
+
if 'train' in k:
|
| 258 |
+
self._wandb.log({f'Global Train/{k}': v}, commit=False)
|
| 259 |
+
elif 'test' in k:
|
| 260 |
+
self._wandb.log({f'Global Test/{k}': v}, commit=False)
|
| 261 |
+
|
| 262 |
+
self._wandb.log({})
|
| 263 |
+
|
| 264 |
+
def log_checkpoints(self):
|
| 265 |
+
output_dir = self.args.output_dir
|
| 266 |
+
model_artifact = self._wandb.Artifact(
|
| 267 |
+
self._wandb.run.id + '_model', type='model'
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
model_artifact.add_dir(output_dir)
|
| 271 |
+
self._wandb.log_artifact(model_artifact, aliases=['latest', 'best'])
|
| 272 |
+
|
| 273 |
+
def set_steps(self):
|
| 274 |
+
# Set global training step
|
| 275 |
+
self._wandb.define_metric(
|
| 276 |
+
'Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step'
|
| 277 |
+
)
|
| 278 |
+
# Set epoch-wise step
|
| 279 |
+
self._wandb.define_metric('Global Train/*', step_metric='epoch')
|
| 280 |
+
self._wandb.define_metric('Global Test/*', step_metric='epoch')
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def setup_for_distributed(is_master):
|
| 284 |
+
"""
|
| 285 |
+
This function disables printing when not in master process
|
| 286 |
+
"""
|
| 287 |
+
import builtins as __builtin__
|
| 288 |
+
|
| 289 |
+
builtin_print = __builtin__.print
|
| 290 |
+
|
| 291 |
+
def print(*args, **kwargs):
|
| 292 |
+
force = kwargs.pop('force', False)
|
| 293 |
+
if is_master or force:
|
| 294 |
+
builtin_print(*args, **kwargs)
|
| 295 |
+
|
| 296 |
+
__builtin__.print = print
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def is_dist_avail_and_initialized():
|
| 300 |
+
if not dist.is_available():
|
| 301 |
+
return False
|
| 302 |
+
if not dist.is_initialized():
|
| 303 |
+
return False
|
| 304 |
+
return True
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def get_world_size():
|
| 308 |
+
if not is_dist_avail_and_initialized():
|
| 309 |
+
return 1
|
| 310 |
+
return dist.get_world_size()
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def get_rank():
|
| 314 |
+
if not is_dist_avail_and_initialized():
|
| 315 |
+
return 0
|
| 316 |
+
return dist.get_rank()
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def is_main_process():
|
| 320 |
+
return get_rank() == 0
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def save_on_master(*args, **kwargs):
|
| 324 |
+
if is_main_process():
|
| 325 |
+
torch.save(*args, **kwargs)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def init_distributed_mode(args):
|
| 329 |
+
|
| 330 |
+
if args.dist_on_itp:
|
| 331 |
+
args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
|
| 332 |
+
args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
|
| 333 |
+
args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
|
| 334 |
+
args.dist_url = 'tcp://%s:%s' % (
|
| 335 |
+
os.environ['MASTER_ADDR'],
|
| 336 |
+
os.environ['MASTER_PORT'],
|
| 337 |
+
)
|
| 338 |
+
os.environ['LOCAL_RANK'] = str(args.gpu)
|
| 339 |
+
os.environ['RANK'] = str(args.rank)
|
| 340 |
+
os.environ['WORLD_SIZE'] = str(args.world_size)
|
| 341 |
+
# ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
|
| 342 |
+
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
| 343 |
+
args.rank = int(os.environ['RANK'])
|
| 344 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
| 345 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
| 346 |
+
elif 'SLURM_PROCID' in os.environ:
|
| 347 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
| 348 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
| 349 |
+
|
| 350 |
+
os.environ['RANK'] = str(args.rank)
|
| 351 |
+
os.environ['LOCAL_RANK'] = str(args.gpu)
|
| 352 |
+
os.environ['WORLD_SIZE'] = str(args.world_size)
|
| 353 |
+
else:
|
| 354 |
+
print('Not using distributed mode')
|
| 355 |
+
args.distributed = False
|
| 356 |
+
return
|
| 357 |
+
|
| 358 |
+
args.distributed = True
|
| 359 |
+
|
| 360 |
+
torch.cuda.set_device(args.gpu)
|
| 361 |
+
args.dist_backend = 'nccl'
|
| 362 |
+
print(
|
| 363 |
+
'| distributed init (rank {}): {}, gpu {}'.format(
|
| 364 |
+
args.rank, args.dist_url, args.gpu
|
| 365 |
+
),
|
| 366 |
+
flush=True,
|
| 367 |
+
)
|
| 368 |
+
torch.distributed.init_process_group(
|
| 369 |
+
backend=args.dist_backend,
|
| 370 |
+
init_method=args.dist_url,
|
| 371 |
+
world_size=args.world_size,
|
| 372 |
+
rank=args.rank,
|
| 373 |
+
)
|
| 374 |
+
torch.distributed.barrier()
|
| 375 |
+
setup_for_distributed(args.rank == 0)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def all_reduce_mean(x):
|
| 379 |
+
world_size = get_world_size()
|
| 380 |
+
if world_size > 1:
|
| 381 |
+
x_reduce = torch.tensor(x).cuda()
|
| 382 |
+
dist.all_reduce(x_reduce)
|
| 383 |
+
x_reduce /= world_size
|
| 384 |
+
return x_reduce.item()
|
| 385 |
+
else:
|
| 386 |
+
return x
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def load_state_dict(
|
| 390 |
+
model, state_dict, prefix='', ignore_missing='relative_position_index'
|
| 391 |
+
):
|
| 392 |
+
missing_keys = []
|
| 393 |
+
unexpected_keys = []
|
| 394 |
+
error_msgs = []
|
| 395 |
+
# copy state_dict so _load_from_state_dict can modify it
|
| 396 |
+
metadata = getattr(state_dict, '_metadata', None)
|
| 397 |
+
state_dict = state_dict.copy()
|
| 398 |
+
if metadata is not None:
|
| 399 |
+
state_dict._metadata = metadata
|
| 400 |
+
|
| 401 |
+
def load(module, prefix=''):
|
| 402 |
+
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
| 403 |
+
module._load_from_state_dict(
|
| 404 |
+
state_dict,
|
| 405 |
+
prefix,
|
| 406 |
+
local_metadata,
|
| 407 |
+
True,
|
| 408 |
+
missing_keys,
|
| 409 |
+
unexpected_keys,
|
| 410 |
+
error_msgs,
|
| 411 |
+
)
|
| 412 |
+
for name, child in module._modules.items():
|
| 413 |
+
if child is not None:
|
| 414 |
+
load(child, prefix + name + '.')
|
| 415 |
+
|
| 416 |
+
load(model, prefix=prefix)
|
| 417 |
+
|
| 418 |
+
warn_missing_keys = []
|
| 419 |
+
ignore_missing_keys = []
|
| 420 |
+
for key in missing_keys:
|
| 421 |
+
keep_flag = True
|
| 422 |
+
for ignore_key in ignore_missing.split('|'):
|
| 423 |
+
if ignore_key in key:
|
| 424 |
+
keep_flag = False
|
| 425 |
+
break
|
| 426 |
+
if keep_flag:
|
| 427 |
+
warn_missing_keys.append(key)
|
| 428 |
+
else:
|
| 429 |
+
ignore_missing_keys.append(key)
|
| 430 |
+
|
| 431 |
+
missing_keys = warn_missing_keys
|
| 432 |
+
|
| 433 |
+
if len(missing_keys) > 0:
|
| 434 |
+
print(
|
| 435 |
+
'Weights of {} not initialized from pretrained model: {}'.format(
|
| 436 |
+
model.__class__.__name__, missing_keys
|
| 437 |
+
)
|
| 438 |
+
)
|
| 439 |
+
if len(unexpected_keys) > 0:
|
| 440 |
+
print(
|
| 441 |
+
'Weights from pretrained model not used in {}: {}'.format(
|
| 442 |
+
model.__class__.__name__, unexpected_keys
|
| 443 |
+
)
|
| 444 |
+
)
|
| 445 |
+
if len(ignore_missing_keys) > 0:
|
| 446 |
+
print(
|
| 447 |
+
'Ignored weights of {} not initialized from pretrained model: {}'.format(
|
| 448 |
+
model.__class__.__name__, ignore_missing_keys
|
| 449 |
+
)
|
| 450 |
+
)
|
| 451 |
+
if len(error_msgs) > 0:
|
| 452 |
+
print('\n'.join(error_msgs))
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
class NativeScalerWithGradNormCount:
|
| 456 |
+
state_dict_key = 'amp_scaler'
|
| 457 |
+
|
| 458 |
+
def __init__(self):
|
| 459 |
+
self._scaler = torch.cuda.amp.GradScaler()
|
| 460 |
+
|
| 461 |
+
def __call__(
|
| 462 |
+
self,
|
| 463 |
+
loss,
|
| 464 |
+
optimizer,
|
| 465 |
+
clip_grad=None,
|
| 466 |
+
parameters=None,
|
| 467 |
+
create_graph=False,
|
| 468 |
+
update_grad=True,
|
| 469 |
+
):
|
| 470 |
+
self._scaler.scale(loss).backward(create_graph=create_graph)
|
| 471 |
+
if update_grad:
|
| 472 |
+
if clip_grad is not None:
|
| 473 |
+
assert parameters is not None
|
| 474 |
+
self._scaler.unscale_(
|
| 475 |
+
optimizer
|
| 476 |
+
) # unscale the gradients of optimizer's assigned params in-place
|
| 477 |
+
norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
|
| 478 |
+
else:
|
| 479 |
+
self._scaler.unscale_(optimizer)
|
| 480 |
+
norm = get_grad_norm_(parameters)
|
| 481 |
+
self._scaler.step(optimizer)
|
| 482 |
+
self._scaler.update()
|
| 483 |
+
else:
|
| 484 |
+
norm = None
|
| 485 |
+
return norm
|
| 486 |
+
|
| 487 |
+
def state_dict(self):
|
| 488 |
+
return self._scaler.state_dict()
|
| 489 |
+
|
| 490 |
+
def load_state_dict(self, state_dict):
|
| 491 |
+
self._scaler.load_state_dict(state_dict)
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
|
| 495 |
+
if isinstance(parameters, torch.Tensor):
|
| 496 |
+
parameters = [parameters]
|
| 497 |
+
parameters = [p for p in parameters if p.grad is not None]
|
| 498 |
+
norm_type = float(norm_type)
|
| 499 |
+
if len(parameters) == 0:
|
| 500 |
+
return torch.tensor(0.0)
|
| 501 |
+
device = parameters[0].grad.device
|
| 502 |
+
if norm_type == inf:
|
| 503 |
+
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
|
| 504 |
+
else:
|
| 505 |
+
total_norm = torch.norm(
|
| 506 |
+
torch.stack(
|
| 507 |
+
[torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]
|
| 508 |
+
),
|
| 509 |
+
norm_type,
|
| 510 |
+
)
|
| 511 |
+
return total_norm
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def save_model(
|
| 515 |
+
args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None
|
| 516 |
+
):
|
| 517 |
+
output_dir = Path(args.output_dir)
|
| 518 |
+
epoch_name = str(epoch)
|
| 519 |
+
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
|
| 520 |
+
for checkpoint_path in checkpoint_paths:
|
| 521 |
+
to_save = {
|
| 522 |
+
'model': model_without_ddp.state_dict(),
|
| 523 |
+
'optimizer': optimizer.state_dict(),
|
| 524 |
+
'epoch': epoch,
|
| 525 |
+
'scaler': loss_scaler.state_dict(),
|
| 526 |
+
'args': args,
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
if model_ema is not None:
|
| 530 |
+
to_save['model_ema'] = get_state_dict(model_ema)
|
| 531 |
+
|
| 532 |
+
save_on_master(to_save, checkpoint_path)
|
| 533 |
+
|
| 534 |
+
if is_main_process() and isinstance(epoch, int):
|
| 535 |
+
to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq
|
| 536 |
+
old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del)
|
| 537 |
+
if os.path.exists(old_ckpt):
|
| 538 |
+
os.remove(old_ckpt)
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def auto_load_model(
|
| 542 |
+
args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None
|
| 543 |
+
):
|
| 544 |
+
output_dir = Path(args.output_dir)
|
| 545 |
+
if args.auto_resume and len(args.resume) == 0:
|
| 546 |
+
import glob
|
| 547 |
+
|
| 548 |
+
all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
|
| 549 |
+
latest_ckpt = -1
|
| 550 |
+
for ckpt in all_checkpoints:
|
| 551 |
+
t = ckpt.split('-')[-1].split('.')[0]
|
| 552 |
+
if t.isdigit():
|
| 553 |
+
latest_ckpt = max(int(t), latest_ckpt)
|
| 554 |
+
if latest_ckpt >= 0:
|
| 555 |
+
args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
|
| 556 |
+
print('Auto resume checkpoint: %s' % args.resume)
|
| 557 |
+
|
| 558 |
+
if args.resume:
|
| 559 |
+
if args.resume.startswith('https'):
|
| 560 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 561 |
+
args.resume, map_location='cpu', check_hash=True
|
| 562 |
+
)
|
| 563 |
+
else:
|
| 564 |
+
checkpoint = torch.load(args.resume, map_location='cpu')
|
| 565 |
+
|
| 566 |
+
model_without_ddp.load_state_dict(checkpoint['model'])
|
| 567 |
+
print('Resume checkpoint %s' % args.resume)
|
| 568 |
+
if 'optimizer' in checkpoint and 'epoch' in checkpoint:
|
| 569 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
| 570 |
+
if not isinstance(
|
| 571 |
+
checkpoint['epoch'], str
|
| 572 |
+
): # does not support resuming with 'best', 'best-ema'
|
| 573 |
+
args.start_epoch = checkpoint['epoch'] + 1
|
| 574 |
+
else:
|
| 575 |
+
assert args.eval, 'Does not support resuming with checkpoint-best'
|
| 576 |
+
if hasattr(args, 'model_ema') and args.model_ema:
|
| 577 |
+
if 'model_ema' in checkpoint.keys():
|
| 578 |
+
model_ema.ema.load_state_dict(checkpoint['model_ema'])
|
| 579 |
+
else:
|
| 580 |
+
model_ema.ema.load_state_dict(checkpoint['model'])
|
| 581 |
+
if 'scaler' in checkpoint:
|
| 582 |
+
loss_scaler.load_state_dict(checkpoint['scaler'])
|
| 583 |
+
print('With optim & sched!')
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def cosine_scheduler(
|
| 587 |
+
base_value,
|
| 588 |
+
final_value,
|
| 589 |
+
epochs,
|
| 590 |
+
niter_per_ep,
|
| 591 |
+
warmup_epochs=0,
|
| 592 |
+
start_warmup_value=0,
|
| 593 |
+
warmup_steps=-1,
|
| 594 |
+
):
|
| 595 |
+
warmup_schedule = np.array([])
|
| 596 |
+
warmup_iters = warmup_epochs * niter_per_ep
|
| 597 |
+
if warmup_steps > 0:
|
| 598 |
+
warmup_iters = warmup_steps
|
| 599 |
+
print('Set warmup steps = %d' % warmup_iters)
|
| 600 |
+
if warmup_epochs > 0:
|
| 601 |
+
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
|
| 602 |
+
|
| 603 |
+
iters = np.arange(epochs * niter_per_ep - warmup_iters)
|
| 604 |
+
schedule = np.array(
|
| 605 |
+
[
|
| 606 |
+
final_value
|
| 607 |
+
+ 0.5
|
| 608 |
+
* (base_value - final_value)
|
| 609 |
+
* (1 + math.cos(math.pi * i / (len(iters))))
|
| 610 |
+
for i in iters
|
| 611 |
+
]
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
schedule = np.concatenate((warmup_schedule, schedule))
|
| 615 |
+
|
| 616 |
+
assert len(schedule) == epochs * niter_per_ep
|
| 617 |
+
return schedule
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def adjust_learning_rate(optimizer, epoch, args):
|
| 621 |
+
"""Decay the learning rate with half-cycle cosine after warmup"""
|
| 622 |
+
if epoch < args.warmup_epochs:
|
| 623 |
+
lr = args.lr * epoch / args.warmup_epochs
|
| 624 |
+
else:
|
| 625 |
+
lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (
|
| 626 |
+
1.0
|
| 627 |
+
+ math.cos(
|
| 628 |
+
math.pi
|
| 629 |
+
* (epoch - args.warmup_epochs)
|
| 630 |
+
/ (args.epochs - args.warmup_epochs)
|
| 631 |
+
)
|
| 632 |
+
)
|
| 633 |
+
for param_group in optimizer.param_groups:
|
| 634 |
+
if 'lr_scale' in param_group:
|
| 635 |
+
param_group['lr'] = lr * param_group['lr_scale']
|
| 636 |
+
else:
|
| 637 |
+
param_group['lr'] = lr
|
| 638 |
+
return lr
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def remap_checkpoint_keys(ckpt):
|
| 642 |
+
new_ckpt = OrderedDict()
|
| 643 |
+
for k, v in ckpt.items():
|
| 644 |
+
if k.startswith('encoder'):
|
| 645 |
+
k = '.'.join(k.split('.')[1:]) # remove encoder in the name
|
| 646 |
+
if k.endswith('kernel'):
|
| 647 |
+
k = '.'.join(k.split('.')[:-1]) # remove kernel in the name
|
| 648 |
+
new_k = k + '.weight'
|
| 649 |
+
if len(v.shape) == 3: # resahpe standard convolution
|
| 650 |
+
kv, in_dim, out_dim = v.shape
|
| 651 |
+
ks = int(math.sqrt(kv))
|
| 652 |
+
new_ckpt[new_k] = (
|
| 653 |
+
v.permute(2, 1, 0).reshape(out_dim, in_dim, ks, ks).transpose(3, 2)
|
| 654 |
+
)
|
| 655 |
+
elif len(v.shape) == 2: # reshape depthwise convolution
|
| 656 |
+
kv, dim = v.shape
|
| 657 |
+
ks = int(math.sqrt(kv))
|
| 658 |
+
new_ckpt[new_k] = (
|
| 659 |
+
v.permute(1, 0).reshape(dim, 1, ks, ks).transpose(3, 2)
|
| 660 |
+
)
|
| 661 |
+
continue
|
| 662 |
+
elif 'ln' in k or 'linear' in k:
|
| 663 |
+
k = k.split('.')
|
| 664 |
+
k.pop(-2) # remove ln and linear in the name
|
| 665 |
+
new_k = '.'.join(k)
|
| 666 |
+
else:
|
| 667 |
+
new_k = k
|
| 668 |
+
new_ckpt[new_k] = v
|
| 669 |
+
|
| 670 |
+
# reshape grn affine parameters and biases
|
| 671 |
+
for k, v in new_ckpt.items():
|
| 672 |
+
if k.endswith('bias') and len(v.shape) != 1:
|
| 673 |
+
new_ckpt[k] = v.reshape(-1)
|
| 674 |
+
# elif 'grn' in k:
|
| 675 |
+
# new_ckpt[k] = v.unsqueeze(0).unsqueeze(1)
|
| 676 |
+
return new_ckpt
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
class Logger(object):
|
| 680 |
+
"""Log stdout messages."""
|
| 681 |
+
|
| 682 |
+
def __init__(self, outfile):
|
| 683 |
+
self.terminal = sys.stdout
|
| 684 |
+
self.log = open(outfile, 'a')
|
| 685 |
+
sys.stdout = self
|
| 686 |
+
|
| 687 |
+
def write(self, message):
|
| 688 |
+
self.terminal.write(message)
|
| 689 |
+
self.log.write(message)
|
| 690 |
+
|
| 691 |
+
def flush(self):
|
| 692 |
+
self.terminal.flush()
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
def printSet(set_str):
|
| 696 |
+
set_str = str(set_str)
|
| 697 |
+
num = len(set_str)
|
| 698 |
+
print('=' * num * 3)
|
| 699 |
+
print(' ' * num + set_str)
|
| 700 |
+
print('=' * num * 3)
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/Word_Frequency_Analysis.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
from collections import Counter, OrderedDict
|
| 5 |
+
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import nltk
|
| 8 |
+
import numpy as np
|
| 9 |
+
from nltk.corpus import stopwords
|
| 10 |
+
|
| 11 |
+
# os.environ['CURL_CA_BUNDLE'] = ''
|
| 12 |
+
# os.environ['HTTP_PROXY'] = "http://*:7890"
|
| 13 |
+
# os.environ['HTTPS_PROXY'] = "http://*:7890"
|
| 14 |
+
# os.environ['ALL_PROXY'] = "socks://*:7891"
|
| 15 |
+
nltk.download('stopwords')
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def parse_args():
|
| 19 |
+
parser = argparse.ArgumentParser(description='Word_Frequency_Analysis')
|
| 20 |
+
parser.add_argument('--root_path', type=str, help='')
|
| 21 |
+
parser.add_argument('--save_path', type=str, help='save_path', default='')
|
| 22 |
+
args = parser.parse_args()
|
| 23 |
+
|
| 24 |
+
def print_options(parser, args):
|
| 25 |
+
message = ''
|
| 26 |
+
message += '----------------- Options ---------------\n'
|
| 27 |
+
for k, v in sorted(vars(args).items()):
|
| 28 |
+
comment = ''
|
| 29 |
+
default = parser.get_default(k)
|
| 30 |
+
if v != default:
|
| 31 |
+
comment = '\t[default: %s]' % str(default)
|
| 32 |
+
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
| 33 |
+
message += '----------------- End -------------------'
|
| 34 |
+
print(message)
|
| 35 |
+
|
| 36 |
+
print_options(parser, args)
|
| 37 |
+
return args
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_list(folder_path):
|
| 41 |
+
image_paths_real = []
|
| 42 |
+
image_paths_fake = []
|
| 43 |
+
for root, dirs, files in os.walk(folder_path):
|
| 44 |
+
for file in files:
|
| 45 |
+
if file.lower().endswith(('.txt')):
|
| 46 |
+
abspath_tmp = os.path.abspath(os.path.join(root, file))
|
| 47 |
+
if '/0_real/' in abspath_tmp:
|
| 48 |
+
image_paths_real.append(abspath_tmp)
|
| 49 |
+
if '/1_fake/' in abspath_tmp:
|
| 50 |
+
image_paths_fake.append(abspath_tmp)
|
| 51 |
+
return image_paths_real, image_paths_fake
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def get_words_counts(image_paths):
|
| 55 |
+
all_text = []
|
| 56 |
+
for tpath in image_paths:
|
| 57 |
+
with open(tpath, 'r') as file:
|
| 58 |
+
all_text.append(file.read())
|
| 59 |
+
content = ' '.join(all_text)
|
| 60 |
+
words = re.findall(r'\b\w+\b', content.lower())
|
| 61 |
+
|
| 62 |
+
stop_words = set(stopwords.words('english'))
|
| 63 |
+
filtered_words = [word for word in words if word not in stop_words]
|
| 64 |
+
|
| 65 |
+
word_counts = Counter(filtered_words)
|
| 66 |
+
|
| 67 |
+
common_words = word_counts.most_common(20)
|
| 68 |
+
common_words_dict = {}
|
| 69 |
+
for common_word in common_words:
|
| 70 |
+
common_words_dict[common_word[0]] = common_word[1]
|
| 71 |
+
return common_words_dict, dict(word_counts)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
if __name__ == '__main__':
|
| 75 |
+
opt = parse_args()
|
| 76 |
+
os.makedirs(os.path.dirname(opt.save_path), mode=0o777, exist_ok=True)
|
| 77 |
+
if opt.root_path[-1] == '/':
|
| 78 |
+
opt.root_path = opt.root_path[:-1]
|
| 79 |
+
if isinstance(opt.root_path, list):
|
| 80 |
+
image_paths_real, image_paths_fake = [], []
|
| 81 |
+
for textp in opt.root_path:
|
| 82 |
+
tmp_real, tmp_fake = get_list(textp)
|
| 83 |
+
image_paths_real.extend(tmp_real)
|
| 84 |
+
image_paths_fake.extend(tmp_fake)
|
| 85 |
+
else:
|
| 86 |
+
image_paths_real, image_paths_fake = get_list(opt.root_path)
|
| 87 |
+
|
| 88 |
+
words_counts_real, words_counts_real_all = get_words_counts(image_paths_real)
|
| 89 |
+
words_counts_fake, words_counts_fake_all = get_words_counts(image_paths_fake)
|
| 90 |
+
|
| 91 |
+
all_words = set(list(words_counts_real.keys()) + list(words_counts_fake.keys()))
|
| 92 |
+
for word in all_words:
|
| 93 |
+
if word not in words_counts_real.keys():
|
| 94 |
+
words_counts_real[word] = (
|
| 95 |
+
words_counts_real_all[word]
|
| 96 |
+
if word in list(words_counts_real_all.keys())
|
| 97 |
+
else 0
|
| 98 |
+
)
|
| 99 |
+
if word not in words_counts_fake.keys():
|
| 100 |
+
words_counts_fake[word] = (
|
| 101 |
+
words_counts_fake_all[word]
|
| 102 |
+
if word in list(words_counts_fake_all.keys())
|
| 103 |
+
else 0
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
words_counts_fake_sorted = OrderedDict(
|
| 107 |
+
(key, words_counts_fake[key])
|
| 108 |
+
for key in words_counts_real
|
| 109 |
+
if key in words_counts_fake
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
words_counts_real = [(k, v) for k, v in words_counts_real.items()]
|
| 113 |
+
words_counts_fake_sorted = [(k, v) for k, v in words_counts_fake_sorted.items()]
|
| 114 |
+
|
| 115 |
+
words_real, counts_real = zip(*words_counts_real)
|
| 116 |
+
words_fake, counts_fake = zip(*words_counts_fake_sorted)
|
| 117 |
+
assert words_real == words_fake
|
| 118 |
+
|
| 119 |
+
print(f'words_real: {words_real}')
|
| 120 |
+
print(f'counts_real: {counts_real}')
|
| 121 |
+
|
| 122 |
+
print(f'words_fake: {words_fake}')
|
| 123 |
+
print(f'counts_fake: {counts_fake}')
|
| 124 |
+
# exit()
|
| 125 |
+
words_real, counts_real = words_real[:15], counts_real[:15]
|
| 126 |
+
words_fake, counts_fake = words_fake[:15], counts_fake[:15]
|
| 127 |
+
plt.figure(figsize=(16, 8))
|
| 128 |
+
width = 0.35
|
| 129 |
+
x = np.arange(len(words_real))
|
| 130 |
+
plt.bar(
|
| 131 |
+
x - width / 2,
|
| 132 |
+
counts_real,
|
| 133 |
+
width,
|
| 134 |
+
label=f'{" ".join(opt.root_path.split("/")[-2:])} real',
|
| 135 |
+
)
|
| 136 |
+
plt.bar(
|
| 137 |
+
x + width / 2,
|
| 138 |
+
counts_fake,
|
| 139 |
+
width,
|
| 140 |
+
label=f'{" ".join(opt.root_path.split("/")[-2:])} fake',
|
| 141 |
+
)
|
| 142 |
+
# plt.ylabel('Frequency')
|
| 143 |
+
# plt.title('Top 15 Most Common Words')
|
| 144 |
+
plt.xticks(x, labels=words_fake, rotation=90, fontsize=35)
|
| 145 |
+
plt.legend(prop={'size': 20})
|
| 146 |
+
plt.show()
|
| 147 |
+
plt.savefig(f'{opt.save_path}', bbox_inches='tight', pad_inches=0.1)
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/data/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from torch.utils.data.sampler import WeightedRandomSampler
|
| 4 |
+
|
| 5 |
+
from .datasets import dataset_folder
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
def get_dataset(opt):
|
| 9 |
+
classes = os.listdir(opt.dataroot) if len(opt.classes) == 0 else opt.classes
|
| 10 |
+
if '0_real' not in classes or '1_fake' not in classes:
|
| 11 |
+
dset_lst = []
|
| 12 |
+
for cls in classes:
|
| 13 |
+
root = opt.dataroot + '/' + cls
|
| 14 |
+
dset = dataset_folder(opt, root)
|
| 15 |
+
dset_lst.append(dset)
|
| 16 |
+
return torch.utils.data.ConcatDataset(dset_lst)
|
| 17 |
+
return dataset_folder(opt, opt.dataroot)
|
| 18 |
+
|
| 19 |
+
def get_bal_sampler(dataset):
|
| 20 |
+
targets = []
|
| 21 |
+
for d in dataset.datasets:
|
| 22 |
+
targets.extend(d.targets)
|
| 23 |
+
|
| 24 |
+
ratio = np.bincount(targets)
|
| 25 |
+
w = 1. / torch.tensor(ratio, dtype=torch.float)
|
| 26 |
+
sample_weights = w[targets]
|
| 27 |
+
sampler = WeightedRandomSampler(weights=sample_weights,
|
| 28 |
+
num_samples=len(sample_weights))
|
| 29 |
+
return sampler
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def create_dataloader(opt):
|
| 33 |
+
shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False
|
| 34 |
+
dataset = get_dataset(opt)
|
| 35 |
+
sampler = get_bal_sampler(dataset) if opt.class_bal else None
|
| 36 |
+
|
| 37 |
+
data_loader = torch.utils.data.DataLoader(dataset,
|
| 38 |
+
batch_size=opt.batch_size,
|
| 39 |
+
shuffle=shuffle,
|
| 40 |
+
sampler=sampler,
|
| 41 |
+
drop_last=True if opt.isTrain else False,
|
| 42 |
+
num_workers=int(opt.num_threads))
|
| 43 |
+
return data_loader
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/data/datasets.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torchvision.datasets as datasets
|
| 4 |
+
import torchvision.transforms as transforms
|
| 5 |
+
import torchvision.transforms.functional as TF
|
| 6 |
+
from random import random, choice
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from PIL import ImageFile
|
| 10 |
+
from scipy.ndimage.filters import gaussian_filter
|
| 11 |
+
from torchvision.transforms import InterpolationMode
|
| 12 |
+
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
|
| 13 |
+
import os
|
| 14 |
+
from transformers import AutoTokenizer
|
| 15 |
+
|
| 16 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 17 |
+
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
|
| 18 |
+
|
| 19 |
+
def pil_loader(path: str) -> Image.Image:
|
| 20 |
+
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
|
| 21 |
+
with open(path, "rb") as f:
|
| 22 |
+
img = Image.open(f)
|
| 23 |
+
return img.convert("RGB")
|
| 24 |
+
|
| 25 |
+
class ImageFolder2(datasets.DatasetFolder):
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
root: str,
|
| 29 |
+
opt,
|
| 30 |
+
transform: Optional[Callable] = None,
|
| 31 |
+
):
|
| 32 |
+
super().__init__(
|
| 33 |
+
root,
|
| 34 |
+
transform=transform,
|
| 35 |
+
extensions=IMG_EXTENSIONS,
|
| 36 |
+
loader = pil_loader
|
| 37 |
+
)
|
| 38 |
+
self.opt = opt
|
| 39 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.opt.clip, model_max_length=77, padding_side="right", use_fast=False)
|
| 40 |
+
self.tokenizer.pad_token_id = 0
|
| 41 |
+
|
| 42 |
+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
| 43 |
+
"""
|
| 44 |
+
Args:
|
| 45 |
+
index (int): Index
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
tuple: (sample, target) where target is class_index of the target class.
|
| 49 |
+
"""
|
| 50 |
+
path, target = self.samples[index]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
textpath = path.replace(self.opt.imgroot, self.opt.textroot)
|
| 54 |
+
textpath = os.path.splitext(textpath)[0] + '.txt'
|
| 55 |
+
|
| 56 |
+
sample = self.loader(path)
|
| 57 |
+
try:
|
| 58 |
+
with open(textpath, 'r') as file:
|
| 59 |
+
text = file.read()
|
| 60 |
+
cates_len = len(self.opt.cates)//2
|
| 61 |
+
if target == 1: text = f'{" ".join(self.opt.cates[:cates_len])}. {text} {" ".join(self.opt.cates[:cates_len])}.'
|
| 62 |
+
if target == 0: text = f'{" ".join(self.opt.cates[cates_len:])}. {text} {" ".join(self.opt.cates[cates_len:])}.'
|
| 63 |
+
inputs = self.tokenizer([text], padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
| 64 |
+
input_ids=inputs['input_ids'][0]
|
| 65 |
+
attention_mask=inputs['attention_mask'][0]
|
| 66 |
+
except:
|
| 67 |
+
text, input_ids, attention_mask = ' ', ' ', ' '
|
| 68 |
+
|
| 69 |
+
if self.transform is not None:
|
| 70 |
+
sample = self.transform(sample)
|
| 71 |
+
if self.target_transform is not None:
|
| 72 |
+
target = self.target_transform(target)
|
| 73 |
+
|
| 74 |
+
return path, sample, text, input_ids, attention_mask, target
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def dataset_folder(opt, root):
|
| 78 |
+
if opt.mode == 'binary':
|
| 79 |
+
return binary_dataset(opt, root)
|
| 80 |
+
if opt.mode == 'filename':
|
| 81 |
+
return FileNameDataset(opt, root)
|
| 82 |
+
raise ValueError('opt.mode needs to be binary or filename.')
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def binary_dataset(opt, root):
|
| 86 |
+
if opt.isTrain:
|
| 87 |
+
crop_func = transforms.RandomCrop(opt.cropSize)
|
| 88 |
+
elif opt.no_crop:
|
| 89 |
+
crop_func = transforms.Lambda(lambda img: img)
|
| 90 |
+
else:
|
| 91 |
+
crop_func = transforms.CenterCrop(opt.cropSize)
|
| 92 |
+
|
| 93 |
+
if opt.isTrain and not opt.no_flip:
|
| 94 |
+
flip_func = transforms.RandomHorizontalFlip()
|
| 95 |
+
else:
|
| 96 |
+
flip_func = transforms.Lambda(lambda img: img)
|
| 97 |
+
|
| 98 |
+
if not opt.isTrain and opt.no_resize:
|
| 99 |
+
rz_func = transforms.Lambda(lambda img: img)
|
| 100 |
+
else:
|
| 101 |
+
rz_func = transforms.Lambda(lambda img: translate_duplicate(img, opt.cropSize))
|
| 102 |
+
|
| 103 |
+
dset = ImageFolder2(
|
| 104 |
+
root,
|
| 105 |
+
opt,
|
| 106 |
+
transforms.Compose([
|
| 107 |
+
rz_func,
|
| 108 |
+
crop_func,
|
| 109 |
+
flip_func,
|
| 110 |
+
transforms.ToTensor(),
|
| 111 |
+
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
|
| 112 |
+
]))
|
| 113 |
+
return dset
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class FileNameDataset(datasets.ImageFolder):
|
| 117 |
+
def name(self):
|
| 118 |
+
return 'FileNameDataset'
|
| 119 |
+
|
| 120 |
+
def __init__(self, opt, root):
|
| 121 |
+
self.opt = opt
|
| 122 |
+
super().__init__(root)
|
| 123 |
+
|
| 124 |
+
def __getitem__(self, index):
|
| 125 |
+
# Loading sample
|
| 126 |
+
path, target = self.samples[index]
|
| 127 |
+
return path
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
import math
|
| 131 |
+
def translate_duplicate(img, cropSize):
|
| 132 |
+
if min(img.size) < cropSize:
|
| 133 |
+
width, height = img.size
|
| 134 |
+
|
| 135 |
+
new_width = width * math.ceil(cropSize/width)
|
| 136 |
+
new_height = height * math.ceil(cropSize/height)
|
| 137 |
+
|
| 138 |
+
new_img = Image.new('RGB', (new_width, new_height))
|
| 139 |
+
for i in range(0, new_width, width):
|
| 140 |
+
for j in range(0, new_height, height):
|
| 141 |
+
new_img.paste(img, (i, j))
|
| 142 |
+
return new_img
|
| 143 |
+
else:
|
| 144 |
+
return img
|
| 145 |
+
|
| 146 |
+
def data_augment(img, opt):
|
| 147 |
+
img = np.array(img)
|
| 148 |
+
|
| 149 |
+
if random() < opt.blur_prob:
|
| 150 |
+
sig = sample_continuous(opt.blur_sig)
|
| 151 |
+
gaussian_blur(img, sig)
|
| 152 |
+
|
| 153 |
+
if random() < opt.jpg_prob:
|
| 154 |
+
method = sample_discrete(opt.jpg_method)
|
| 155 |
+
qual = sample_discrete(opt.jpg_qual)
|
| 156 |
+
img = jpeg_from_key(img, qual, method)
|
| 157 |
+
|
| 158 |
+
return Image.fromarray(img)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def sample_continuous(s):
|
| 162 |
+
if len(s) == 1:
|
| 163 |
+
return s[0]
|
| 164 |
+
if len(s) == 2:
|
| 165 |
+
rg = s[1] - s[0]
|
| 166 |
+
return random() * rg + s[0]
|
| 167 |
+
raise ValueError("Length of iterable s should be 1 or 2.")
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def sample_discrete(s):
|
| 171 |
+
if len(s) == 1:
|
| 172 |
+
return s[0]
|
| 173 |
+
return choice(s)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def gaussian_blur(img, sigma):
|
| 177 |
+
gaussian_filter(img[:,:,0], output=img[:,:,0], sigma=sigma)
|
| 178 |
+
gaussian_filter(img[:,:,1], output=img[:,:,1], sigma=sigma)
|
| 179 |
+
gaussian_filter(img[:,:,2], output=img[:,:,2], sigma=sigma)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def cv2_jpg(img, compress_val):
|
| 183 |
+
img_cv2 = img[:,:,::-1]
|
| 184 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
|
| 185 |
+
result, encimg = cv2.imencode('.jpg', img_cv2, encode_param)
|
| 186 |
+
decimg = cv2.imdecode(encimg, 1)
|
| 187 |
+
return decimg[:,:,::-1]
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def pil_jpg(img, compress_val):
|
| 191 |
+
out = BytesIO()
|
| 192 |
+
img = Image.fromarray(img)
|
| 193 |
+
img.save(out, format='jpeg', quality=compress_val)
|
| 194 |
+
img = Image.open(out)
|
| 195 |
+
# load from memory before ByteIO closes
|
| 196 |
+
img = np.array(img)
|
| 197 |
+
out.close()
|
| 198 |
+
return img
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
jpeg_dict = {'cv2': cv2_jpg, 'pil': pil_jpg}
|
| 202 |
+
def jpeg_from_key(img, compress_val, key):
|
| 203 |
+
method = jpeg_dict[key]
|
| 204 |
+
return method(img, compress_val)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
rz_dict = {'bilinear': InterpolationMode.BILINEAR,
|
| 208 |
+
'bicubic': InterpolationMode.BICUBIC,
|
| 209 |
+
'lanczos': InterpolationMode.LANCZOS,
|
| 210 |
+
'nearest': InterpolationMode.NEAREST}
|
| 211 |
+
def custom_resize(img, opt):
|
| 212 |
+
interp = sample_discrete(opt.rz_interp)
|
| 213 |
+
return TF.resize(img, opt.loadSize, interpolation=rz_dict[interp])
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/decode_clipfeature_dataset.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
warnings.filterwarnings('ignore')
|
| 13 |
+
from networks.decode_clipfeature_oneImage import (
|
| 14 |
+
get_clip_model,
|
| 15 |
+
get_clipcap_model,
|
| 16 |
+
get_image_features,
|
| 17 |
+
get_text,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def parse_args():
|
| 22 |
+
parser = argparse.ArgumentParser(description='decode detection feature to text')
|
| 23 |
+
parser.add_argument('--prefix_length', type=int, default=10)
|
| 24 |
+
parser.add_argument(
|
| 25 |
+
'--model_path', type=str, default='./ClipCaption_COCO.pt', help='model_path'
|
| 26 |
+
)
|
| 27 |
+
parser.add_argument(
|
| 28 |
+
'--images_root', type=str, default='', help='image_path', required=True
|
| 29 |
+
)
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
'--save_path', type=str, default='', help='image_path', required=True
|
| 32 |
+
)
|
| 33 |
+
parser.add_argument(
|
| 34 |
+
'--fc_path', type=str, default='./fc_parameters.pth', help='fc_path'
|
| 35 |
+
)
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
'--cal_detection_feat', action='store_true', help='cal_detection_feat'
|
| 38 |
+
)
|
| 39 |
+
parser.add_argument('--device', type=str, default='cuda:0', help='cuda:n or cpu')
|
| 40 |
+
args = parser.parse_args()
|
| 41 |
+
|
| 42 |
+
def print_options(parser, args):
|
| 43 |
+
message = ''
|
| 44 |
+
message += '----------------- Options ---------------\n'
|
| 45 |
+
for k, v in sorted(vars(args).items()):
|
| 46 |
+
comment = ''
|
| 47 |
+
default = parser.get_default(k)
|
| 48 |
+
if v != default:
|
| 49 |
+
comment = '\t[default: %s]' % str(default)
|
| 50 |
+
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
| 51 |
+
message += '----------------- End -------------------'
|
| 52 |
+
print(message)
|
| 53 |
+
|
| 54 |
+
print_options(parser, args)
|
| 55 |
+
return args
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_image_files_in_directory(directory):
|
| 59 |
+
image_extensions = [
|
| 60 |
+
'.jpg',
|
| 61 |
+
'.jpeg',
|
| 62 |
+
'.png',
|
| 63 |
+
'.ppm',
|
| 64 |
+
'.bmp',
|
| 65 |
+
'.pgm',
|
| 66 |
+
'.tif',
|
| 67 |
+
'.tiff',
|
| 68 |
+
'.webp',
|
| 69 |
+
]
|
| 70 |
+
image_files = []
|
| 71 |
+
for root, dirs, files in os.walk(directory, followlinks=True):
|
| 72 |
+
for file in files:
|
| 73 |
+
if os.path.splitext(file)[1].lower() in image_extensions:
|
| 74 |
+
absolute_path = os.path.join(root, file)
|
| 75 |
+
image_files.append(absolute_path)
|
| 76 |
+
return image_files
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
if __name__ == '__main__':
|
| 80 |
+
opt = parse_args()
|
| 81 |
+
device = torch.device(opt.device)
|
| 82 |
+
assert os.path.exists(opt.images_root)
|
| 83 |
+
|
| 84 |
+
opt.images_root = os.path.abspath(opt.images_root)
|
| 85 |
+
opt.save_path = os.path.abspath(opt.save_path)
|
| 86 |
+
os.makedirs(opt.save_path, mode=0o777, exist_ok=True)
|
| 87 |
+
|
| 88 |
+
image_files = get_image_files_in_directory(opt.images_root)
|
| 89 |
+
print(f'len(image_files): {len(image_files)}')
|
| 90 |
+
|
| 91 |
+
clipmodel, processor = get_clip_model(
|
| 92 |
+
clip_name='openai/clip-vit-large-patch14', device=device
|
| 93 |
+
)
|
| 94 |
+
model, tokenizer = get_clipcap_model(opt.model_path, device=device)
|
| 95 |
+
|
| 96 |
+
for image_file in tqdm(image_files):
|
| 97 |
+
image_features = get_image_features(
|
| 98 |
+
image_file, clipmodel, processor, device=device
|
| 99 |
+
)
|
| 100 |
+
text = get_text(
|
| 101 |
+
image_features,
|
| 102 |
+
tokenizer,
|
| 103 |
+
model,
|
| 104 |
+
opt.fc_path,
|
| 105 |
+
opt.cal_detection_feat,
|
| 106 |
+
device=device,
|
| 107 |
+
)
|
| 108 |
+
text_save_path = os.path.splitext(
|
| 109 |
+
image_file.replace(opt.images_root, opt.save_path)
|
| 110 |
+
)[0]
|
| 111 |
+
os.makedirs(os.path.dirname(text_save_path), mode=0o777, exist_ok=True)
|
| 112 |
+
with open(f'{text_save_path}.txt', 'w') as file:
|
| 113 |
+
file.write(text)
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/draw_tsne_kmean.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 5 |
+
import argparse
|
| 6 |
+
from random import shuffle
|
| 7 |
+
from typing import Any, Callable, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import PIL.Image
|
| 11 |
+
import skimage.io as io
|
| 12 |
+
import torch
|
| 13 |
+
import torchvision.datasets as datasets
|
| 14 |
+
from matplotlib import pyplot as plt
|
| 15 |
+
from MulticoreTSNE import MulticoreTSNE as TSNE
|
| 16 |
+
from PIL import ImageFile
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
+
from transformers import CLIPProcessor
|
| 19 |
+
from utils.logger import Progbar
|
| 20 |
+
|
| 21 |
+
np.random.seed(123)
|
| 22 |
+
# from clipcap import gg_text, get_model
|
| 23 |
+
from decode_clipfeature_image import get_clip_model, get_clipcap_model, get_text
|
| 24 |
+
from kmeans_pytorch import kmeans
|
| 25 |
+
|
| 26 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 27 |
+
IMG_EXTENSIONS = (
|
| 28 |
+
'.jpg',
|
| 29 |
+
'.jpeg',
|
| 30 |
+
'.png',
|
| 31 |
+
'.ppm',
|
| 32 |
+
'.bmp',
|
| 33 |
+
'.pgm',
|
| 34 |
+
'.tif',
|
| 35 |
+
'.tiff',
|
| 36 |
+
'.webp',
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class dataset_folder(datasets.DatasetFolder):
|
| 41 |
+
def __init__(self, root: str, transform: Optional[Callable] = None):
|
| 42 |
+
super().__init__(root, transform=None, extensions=IMG_EXTENSIONS, loader=None)
|
| 43 |
+
self.processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')
|
| 44 |
+
|
| 45 |
+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
| 46 |
+
path, target = self.samples[index]
|
| 47 |
+
image = PIL.Image.fromarray(io.imread(path)) # .to(device)
|
| 48 |
+
sample = self.processor(images=image, return_tensors='pt')['pixel_values'][0]
|
| 49 |
+
return sample, target, path
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def collate_fn(batch):
|
| 53 |
+
batch = list(filter(lambda x: x is not None, batch))
|
| 54 |
+
return torch.utils.data.dataloader.default_collate(batch)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def binary_dataset(root):
|
| 58 |
+
return dataset_folder(root)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def generate_colors(num_colors):
|
| 62 |
+
# cmap = plt.cm.get_cmap('tab20')
|
| 63 |
+
cmap = plt.colormaps.get_cmap('tab20')
|
| 64 |
+
colors = cmap(range(num_colors))
|
| 65 |
+
return colors
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def tsne_vis(features, labels, draw_dir, opt, device):
|
| 69 |
+
|
| 70 |
+
clipcap_model, tokenizer = get_clipcap_model(
|
| 71 |
+
model_path='https://www.now61.com/f/Xljmi0/coco_prefix_latest.pt', device=device
|
| 72 |
+
)
|
| 73 |
+
all_labels = list(set(labels.reshape([-1])))
|
| 74 |
+
num_clusters = 3
|
| 75 |
+
all_cluster_centers = []
|
| 76 |
+
all_text = {}
|
| 77 |
+
for tmp_label in all_labels:
|
| 78 |
+
feature_ = features[labels == tmp_label]
|
| 79 |
+
cluster_ids_x, cluster_centers = kmeans(
|
| 80 |
+
X=torch.from_numpy(feature_).cuda(),
|
| 81 |
+
num_clusters=num_clusters,
|
| 82 |
+
distance='euclidean',
|
| 83 |
+
device=device,
|
| 84 |
+
)
|
| 85 |
+
all_cluster_centers.append(cluster_centers)
|
| 86 |
+
print('=' * 100)
|
| 87 |
+
print(opt.legend[tmp_label])
|
| 88 |
+
tmp_text = []
|
| 89 |
+
for index in range(num_clusters):
|
| 90 |
+
tmp_text.append(
|
| 91 |
+
get_text(
|
| 92 |
+
cluster_centers[index].to(device),
|
| 93 |
+
tokenizer,
|
| 94 |
+
clipcap_model,
|
| 95 |
+
fc_path='https://www.now61.com/f/qwvoH5/fc_parameters.pth',
|
| 96 |
+
cal_detection_feat=False,
|
| 97 |
+
device=device,
|
| 98 |
+
)
|
| 99 |
+
)
|
| 100 |
+
print(f'\n{tmp_text[-1]}')
|
| 101 |
+
all_text[str(tmp_label + len(all_labels))] = tmp_text
|
| 102 |
+
print('=' * 100)
|
| 103 |
+
|
| 104 |
+
all_cluster_centers = torch.cat(all_cluster_centers, 0).cpu().numpy()
|
| 105 |
+
features = np.concatenate((features, all_cluster_centers), axis=0)
|
| 106 |
+
new_label = []
|
| 107 |
+
for index, all_label in enumerate(all_labels):
|
| 108 |
+
new_label.extend([len(all_labels) + index] * num_clusters)
|
| 109 |
+
labels = np.concatenate((labels, np.array(new_label)), axis=0)
|
| 110 |
+
|
| 111 |
+
embedding_path = os.path.join(draw_dir, '{}_embedding.npy'.format(opt.save_name))
|
| 112 |
+
|
| 113 |
+
if opt.do_fit or not os.path.exists(embedding_path):
|
| 114 |
+
print('>>> t-SNE fitting')
|
| 115 |
+
tsne_model = TSNE(
|
| 116 |
+
n_jobs=64, perplexity=opt.perplexity, random_state=1024, learning_rate=1000
|
| 117 |
+
)
|
| 118 |
+
embeddings = tsne_model.fit_transform(features)
|
| 119 |
+
print('<<< fitting over')
|
| 120 |
+
np.save(embedding_path, embeddings)
|
| 121 |
+
else:
|
| 122 |
+
embeddings = np.load(embedding_path)
|
| 123 |
+
|
| 124 |
+
index = [i for i in range(len(embeddings))]
|
| 125 |
+
shuffle(index)
|
| 126 |
+
embeddings = [embeddings[index[i]] for i in range(len(index))]
|
| 127 |
+
labels = [labels[index[i]] for i in range(len(index))]
|
| 128 |
+
embeddings = np.array(embeddings)
|
| 129 |
+
|
| 130 |
+
print('>>> draw image')
|
| 131 |
+
vis_x = embeddings[:, 0]
|
| 132 |
+
vis_y = embeddings[:, 1]
|
| 133 |
+
plt.figure(figsize=(35, 35))
|
| 134 |
+
plt.rcParams['figure.dpi'] = 1000
|
| 135 |
+
colors = generate_colors(20)
|
| 136 |
+
num_classes = len(set(labels))
|
| 137 |
+
|
| 138 |
+
for i in range(num_classes):
|
| 139 |
+
s = 1000 if i > 5 else 20
|
| 140 |
+
marker = '*' if i > 5 else 'o'
|
| 141 |
+
color = colors[i]
|
| 142 |
+
class_index = [j for j, v in enumerate(labels) if v == i]
|
| 143 |
+
plt.scatter(
|
| 144 |
+
vis_x[class_index], vis_y[class_index], s=s, color=color, marker=marker
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
if i > 5 and opt.draw_text:
|
| 148 |
+
texthighs = [0] * len(vis_y[class_index])
|
| 149 |
+
sorted_A = sorted(
|
| 150 |
+
enumerate(vis_y[class_index]), key=lambda x: x[1], reverse=False
|
| 151 |
+
)
|
| 152 |
+
for rank, (index, value) in enumerate(sorted_A):
|
| 153 |
+
texthighs[index] = (rank + 1) * 40
|
| 154 |
+
for tx, ty, ttext, texthigh in zip(
|
| 155 |
+
vis_x[class_index], vis_y[class_index], all_text[str(i)], texthighs
|
| 156 |
+
):
|
| 157 |
+
fc = colors[i - num_classes // 2]
|
| 158 |
+
plt.annotate(
|
| 159 |
+
f' {ttext} ',
|
| 160 |
+
xy=(tx, ty),
|
| 161 |
+
xycoords='data',
|
| 162 |
+
xytext=(100, texthigh),
|
| 163 |
+
fontsize=30,
|
| 164 |
+
textcoords='offset points',
|
| 165 |
+
color='white',
|
| 166 |
+
va='center',
|
| 167 |
+
ha='center',
|
| 168 |
+
weight='bold',
|
| 169 |
+
bbox=dict(boxstyle='round', fc=fc, ec='none'),
|
| 170 |
+
arrowprops=dict(
|
| 171 |
+
connectionstyle='angle3,angleA=0,angleB=90',
|
| 172 |
+
arrowstyle='wedge, tail_width=1.',
|
| 173 |
+
fc=fc,
|
| 174 |
+
ec='none',
|
| 175 |
+
patchA=None,
|
| 176 |
+
),
|
| 177 |
+
)
|
| 178 |
+
if opt.draw_text:
|
| 179 |
+
img_path = os.path.join(draw_dir, '{}_draw-text_tsne.png'.format(opt.save_name))
|
| 180 |
+
else:
|
| 181 |
+
img_path = os.path.join(draw_dir, '{}_tsne.png'.format(opt.save_name))
|
| 182 |
+
plt.xticks([])
|
| 183 |
+
plt.yticks([])
|
| 184 |
+
legend = plt.legend(opt.legend, prop={'size': 35})
|
| 185 |
+
for handle in legend.legend_handles:
|
| 186 |
+
handle.set_sizes([300])
|
| 187 |
+
plt.show()
|
| 188 |
+
plt.savefig(img_path, bbox_inches='tight', pad_inches=0.1)
|
| 189 |
+
print('<<<save image')
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def extract_feature(model, draw_loader, device):
|
| 193 |
+
|
| 194 |
+
fc_path = 'https://www.now61.com/f/qwvoH5/fc_parameters.pth'
|
| 195 |
+
mod = (
|
| 196 |
+
torch.hub.load_state_dict_from_url(fc_path, map_location='cpu', progress=True)
|
| 197 |
+
if fc_path.startswith('http')
|
| 198 |
+
else torch.load(fc_path, map_location='cpu')
|
| 199 |
+
)
|
| 200 |
+
weight, bias = mod['fc.weight'].to(device), mod['fc.bias'].to(device)
|
| 201 |
+
|
| 202 |
+
features = None
|
| 203 |
+
model.eval()
|
| 204 |
+
progbar = Progbar(len(draw_loader), stateful_metrics=['run-type'])
|
| 205 |
+
with torch.no_grad():
|
| 206 |
+
for _, batch in enumerate(draw_loader):
|
| 207 |
+
input_img_batch, label_batch, path_batch = batch
|
| 208 |
+
input_img = input_img_batch.to(device)
|
| 209 |
+
label = label_batch.reshape((-1)).to(device)
|
| 210 |
+
|
| 211 |
+
feature = model.get_image_features(input_img)
|
| 212 |
+
feature = torch.mul(feature, weight) + bias
|
| 213 |
+
feature /= feature.norm(2, dim=-1, keepdim=True)
|
| 214 |
+
if features is None:
|
| 215 |
+
features = feature.cpu().numpy()
|
| 216 |
+
gt_labels = label
|
| 217 |
+
else:
|
| 218 |
+
gt_labels = torch.cat([gt_labels, label])
|
| 219 |
+
features = np.vstack((features, feature.cpu().numpy()))
|
| 220 |
+
|
| 221 |
+
progbar.add(1, values=[('run-type', 'extract feature')])
|
| 222 |
+
|
| 223 |
+
gt_labels = gt_labels.cpu().numpy()
|
| 224 |
+
|
| 225 |
+
return features, gt_labels
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def parse_args():
|
| 229 |
+
parser = argparse.ArgumentParser(description='draw tsne')
|
| 230 |
+
parser.add_argument('--draw_data_path', type=str, required=True)
|
| 231 |
+
parser.add_argument('--image_path', type=str, help='image_path', default='')
|
| 232 |
+
parser.add_argument('--device', default='cuda:0', type=str, help='cuda:n or cpu')
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
'--do_extract',
|
| 235 |
+
action='store_true',
|
| 236 |
+
default=False,
|
| 237 |
+
help='whether to extract features',
|
| 238 |
+
)
|
| 239 |
+
parser.add_argument(
|
| 240 |
+
'--do_fit', action='store_true', default=False, help='whether to fit tsne model'
|
| 241 |
+
)
|
| 242 |
+
parser.add_argument('--save_name', default='cross_all', type=str)
|
| 243 |
+
parser.add_argument('--legend', nargs='+', help='legend')
|
| 244 |
+
parser.add_argument('--draw_text', default=0, type=int)
|
| 245 |
+
parser.add_argument('--perplexity', default=20, type=int)
|
| 246 |
+
args = parser.parse_args()
|
| 247 |
+
|
| 248 |
+
def print_options(parser, args):
|
| 249 |
+
message = ''
|
| 250 |
+
message += '----------------- Options ---------------\n'
|
| 251 |
+
for k, v in sorted(vars(args).items()):
|
| 252 |
+
comment = ''
|
| 253 |
+
default = parser.get_default(k)
|
| 254 |
+
if v != default:
|
| 255 |
+
comment = '\t[default: %s]' % str(default)
|
| 256 |
+
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
| 257 |
+
message += '----------------- End -------------------'
|
| 258 |
+
print(message)
|
| 259 |
+
|
| 260 |
+
print_options(parser, args)
|
| 261 |
+
return args
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
if __name__ == '__main__':
|
| 265 |
+
opt = parse_args()
|
| 266 |
+
device = torch.device(opt.device)
|
| 267 |
+
|
| 268 |
+
draw_dir = os.path.join(
|
| 269 |
+
os.path.splitext(opt.draw_data_path)[0], 'tsne-' + opt.save_name
|
| 270 |
+
)
|
| 271 |
+
os.makedirs(draw_dir, exist_ok=True)
|
| 272 |
+
feature_path = os.path.join(draw_dir, '{}_features.npy'.format(opt.save_name))
|
| 273 |
+
label_path = os.path.join(draw_dir, '{}_labels.npy'.format(opt.save_name))
|
| 274 |
+
print('draw dir: %s' % draw_dir)
|
| 275 |
+
|
| 276 |
+
clipmodel, processor = get_clip_model(
|
| 277 |
+
clip_name='openai/clip-vit-large-patch14', device=device
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
draw_loader = DataLoader(
|
| 281 |
+
dataset=binary_dataset(opt.image_path),
|
| 282 |
+
num_workers=8,
|
| 283 |
+
batch_size=1,
|
| 284 |
+
pin_memory=True,
|
| 285 |
+
shuffle=False,
|
| 286 |
+
drop_last=False,
|
| 287 |
+
collate_fn=collate_fn,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
if opt.do_extract or not os.path.exists(feature_path):
|
| 291 |
+
features, gt_labels = extract_feature(clipmodel, draw_loader, device)
|
| 292 |
+
np.save(feature_path, features)
|
| 293 |
+
np.save(label_path, gt_labels)
|
| 294 |
+
else:
|
| 295 |
+
features = np.load(feature_path)
|
| 296 |
+
gt_labels = np.load(label_path)
|
| 297 |
+
|
| 298 |
+
print('labels:', gt_labels.shape, 'features:', features.shape)
|
| 299 |
+
tsne_vis(features, gt_labels, draw_dir, opt, device)
|
| 300 |
+
|
| 301 |
+
# CUDA_VISIBLE_DEVICES=1 python draw_tsne_kmean.py --draw_data_path A_tsne_png_20240812 --image_path ../stylegan_tsne_data --save_name stylegan_test --legend stylegan-bedroom-real stylegan-bedroom-fake stylegan-car-real stylegan-car-fake stylegan-cat-real stylegan-cat-fake --do_extract --do_fit --draw_text 0
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/inference.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import time
|
| 5 |
+
import warnings
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from data import create_dataloader
|
| 11 |
+
from sklearn.metrics import (
|
| 12 |
+
accuracy_score,
|
| 13 |
+
average_precision_score,
|
| 14 |
+
)
|
| 15 |
+
from transformers import CLIPModel
|
| 16 |
+
|
| 17 |
+
warnings.filterwarnings('ignore')
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def seed_torch(seed=1029):
|
| 21 |
+
random.seed(seed)
|
| 22 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
| 23 |
+
np.random.seed(seed)
|
| 24 |
+
torch.manual_seed(seed)
|
| 25 |
+
torch.cuda.manual_seed(seed)
|
| 26 |
+
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
|
| 27 |
+
torch.backends.cudnn.benchmark = False
|
| 28 |
+
torch.backends.cudnn.deterministic = True
|
| 29 |
+
torch.backends.cudnn.enabled = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
seed_torch(123)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class C2P_CLIP(nn.Module):
|
| 36 |
+
def __init__(self, name='openai/clip-vit-large-patch14', num_classes=1):
|
| 37 |
+
super(C2P_CLIP, self).__init__()
|
| 38 |
+
self.model = CLIPModel.from_pretrained(name)
|
| 39 |
+
del self.model.text_model
|
| 40 |
+
del self.model.text_projection
|
| 41 |
+
del self.model.logit_scale
|
| 42 |
+
|
| 43 |
+
self.model.vision_model.requires_grad_(False)
|
| 44 |
+
self.model.visual_projection.requires_grad_(False)
|
| 45 |
+
self.model.fc = nn.Linear(768, num_classes)
|
| 46 |
+
torch.nn.init.normal_(self.model.fc.weight.data, 0.0, 0.02)
|
| 47 |
+
|
| 48 |
+
def encode_image(self, img):
|
| 49 |
+
vision_outputs = self.model.vision_model(
|
| 50 |
+
pixel_values=img,
|
| 51 |
+
output_attentions=self.model.config.output_attentions,
|
| 52 |
+
output_hidden_states=self.model.config.output_hidden_states,
|
| 53 |
+
return_dict=self.model.config.return_dict,
|
| 54 |
+
)
|
| 55 |
+
pooled_output = vision_outputs[1] # pooled_output
|
| 56 |
+
image_features = self.model.visual_projection(pooled_output)
|
| 57 |
+
return image_features
|
| 58 |
+
|
| 59 |
+
def forward(self, img):
|
| 60 |
+
# tmp = x; print(f'x: {tmp.shape}, max: {tmp.max()}, min: {tmp.min()}, mean: {tmp.mean()}')
|
| 61 |
+
image_embeds = self.encode_image(img)
|
| 62 |
+
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
| 63 |
+
return self.model.fc(image_embeds)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def printSet(set_str):
|
| 67 |
+
set_str = str(set_str)
|
| 68 |
+
num = len(set_str)
|
| 69 |
+
print('=' * num * 3)
|
| 70 |
+
print(' ' * num + set_str)
|
| 71 |
+
print('=' * num * 3)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def parse_args():
|
| 75 |
+
parser = argparse.ArgumentParser(description='test C2P-CLIP')
|
| 76 |
+
parser.add_argument('--loadSize', type=int, default=224)
|
| 77 |
+
parser.add_argument('--cropSize', type=int, default=224)
|
| 78 |
+
parser.add_argument('--batch_size', type=int, default=64)
|
| 79 |
+
parser.add_argument(
|
| 80 |
+
'--dataroot',
|
| 81 |
+
type=str,
|
| 82 |
+
default='/opt/data/private/tcc/data/data/DeepfakeDetection/datasets/ForenSynths_train_val_19test/test/',
|
| 83 |
+
)
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
'--model_path',
|
| 86 |
+
type=str,
|
| 87 |
+
default='https://www.now61.com/f/95OefW/C2P_CLIP_release_20240901.zip',
|
| 88 |
+
)
|
| 89 |
+
args = parser.parse_args()
|
| 90 |
+
|
| 91 |
+
def print_options(parser, args):
|
| 92 |
+
message = ''
|
| 93 |
+
message += '----------------- Options ---------------\n'
|
| 94 |
+
for k, v in sorted(vars(args).items()):
|
| 95 |
+
comment = ''
|
| 96 |
+
default = parser.get_default(k)
|
| 97 |
+
if v != default:
|
| 98 |
+
comment = '\t[default: %s]' % str(default)
|
| 99 |
+
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
| 100 |
+
message += '----------------- End -------------------'
|
| 101 |
+
print(message)
|
| 102 |
+
|
| 103 |
+
print_options(parser, args)
|
| 104 |
+
return args
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
if __name__ == '__main__':
|
| 108 |
+
opt = parse_args()
|
| 109 |
+
DetectionTests = {
|
| 110 |
+
'19Test': {
|
| 111 |
+
'dataroot': opt.dataroot,
|
| 112 |
+
'no_resize': False,
|
| 113 |
+
'no_crop': False,
|
| 114 |
+
},
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
# state_dict = torch.hub._legacy_zip_load( 'C2P_CLIP_release_20240901.zip', './', map_location = "cpu", weights_only= False)
|
| 118 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 119 |
+
opt.model_path, map_location='cpu', progress=True
|
| 120 |
+
)
|
| 121 |
+
model = C2P_CLIP(name='openai/clip-vit-large-patch14', num_classes=1)
|
| 122 |
+
model.load_state_dict(state_dict, strict=True)
|
| 123 |
+
model.cuda()
|
| 124 |
+
model.eval()
|
| 125 |
+
|
| 126 |
+
for testSet in DetectionTests.keys():
|
| 127 |
+
dataroot = DetectionTests[testSet]['dataroot']
|
| 128 |
+
printSet(testSet)
|
| 129 |
+
accs = []
|
| 130 |
+
aps = []
|
| 131 |
+
print(time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime()))
|
| 132 |
+
for v_id, val in enumerate(os.listdir(dataroot)):
|
| 133 |
+
opt.dataroot = '{}/{}'.format(dataroot, val)
|
| 134 |
+
opt.classes = '' # os.listdir(opt.dataroot) if multiclass[v_id] else ['']
|
| 135 |
+
opt.no_resize = DetectionTests[testSet]['no_resize']
|
| 136 |
+
opt.no_crop = DetectionTests[testSet]['no_crop']
|
| 137 |
+
data_loader = create_dataloader(opt)
|
| 138 |
+
with torch.no_grad():
|
| 139 |
+
y_true, y_pred = [], []
|
| 140 |
+
for img, label, path in data_loader:
|
| 141 |
+
y_pred.extend(model(img.cuda()).sigmoid().flatten().tolist())
|
| 142 |
+
y_true.extend(label.flatten().tolist())
|
| 143 |
+
y_true, y_pred = np.array(y_true), np.array(y_pred)
|
| 144 |
+
r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > 0.5)
|
| 145 |
+
f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > 0.5)
|
| 146 |
+
acc = accuracy_score(y_true, y_pred > 0.5)
|
| 147 |
+
ap = average_precision_score(y_true, y_pred)
|
| 148 |
+
|
| 149 |
+
accs.append(acc)
|
| 150 |
+
aps.append(ap)
|
| 151 |
+
print(
|
| 152 |
+
'({} {:12}) acc: {:.2f}; ap: {:.2f}'.format(
|
| 153 |
+
v_id, val, acc * 100, ap * 100
|
| 154 |
+
)
|
| 155 |
+
)
|
| 156 |
+
print(
|
| 157 |
+
'({} {:10}) acc: {:.2f}; ap: {:.2f}'.format(
|
| 158 |
+
v_id + 1,
|
| 159 |
+
'Mean',
|
| 160 |
+
np.array(accs).mean() * 100,
|
| 161 |
+
np.array(aps).mean() * 100,
|
| 162 |
+
)
|
| 163 |
+
)
|
| 164 |
+
print('*' * 25)
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/networks/__init__.py
ADDED
|
File without changes
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/networks/base_model.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from pix2pix
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch.nn import init
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BaseModel(nn.Module):
|
| 10 |
+
def __init__(self, opt):
|
| 11 |
+
super(BaseModel, self).__init__()
|
| 12 |
+
self.opt = opt
|
| 13 |
+
self.total_steps = 0
|
| 14 |
+
self.isTrain = opt.isTrain
|
| 15 |
+
self.lr = opt.lr
|
| 16 |
+
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
| 17 |
+
self.device = (
|
| 18 |
+
torch.device('cuda:{}'.format(opt.gpu_ids[0]))
|
| 19 |
+
if opt.gpu_ids
|
| 20 |
+
else torch.device('cpu')
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
def save_networks(self, epoch):
|
| 24 |
+
save_filename = 'model_epoch_%s.pth' % epoch
|
| 25 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
| 26 |
+
|
| 27 |
+
# serialize model and optimizer to dict
|
| 28 |
+
state_dict = {
|
| 29 |
+
'model': self.model.state_dict(),
|
| 30 |
+
# 'optimizer' : self.optimizer.state_dict(),
|
| 31 |
+
'total_steps': self.total_steps,
|
| 32 |
+
}
|
| 33 |
+
torch.save(state_dict, save_path)
|
| 34 |
+
try:
|
| 35 |
+
try:
|
| 36 |
+
savemodel = self.model.module
|
| 37 |
+
except:
|
| 38 |
+
savemodel = self.model
|
| 39 |
+
# savemodel.model.vision_model = savemodel.vision_tower_lora.merge_and_unload()
|
| 40 |
+
save_path2 = os.path.join(self.save_dir, f'model_epoch_{epoch}')
|
| 41 |
+
os.makedirs(save_path2, mode=0o777, exist_ok=True)
|
| 42 |
+
savemodel.model.save_pretrained(save_path2, safe_serialization=False)
|
| 43 |
+
except:
|
| 44 |
+
pass
|
| 45 |
+
|
| 46 |
+
# load models from the disk
|
| 47 |
+
def load_networks(self, epoch):
|
| 48 |
+
load_filename = 'model_epoch_%s.pth' % epoch
|
| 49 |
+
load_path = os.path.join(self.save_dir, load_filename)
|
| 50 |
+
|
| 51 |
+
print('loading the model from %s' % load_path)
|
| 52 |
+
# if you are using PyTorch newer than 0.4 (e.g., built from
|
| 53 |
+
# GitHub source), you can remove str() on self.device
|
| 54 |
+
state_dict = torch.load(load_path, map_location=self.device)
|
| 55 |
+
if hasattr(state_dict, '_metadata'):
|
| 56 |
+
del state_dict._metadata
|
| 57 |
+
|
| 58 |
+
self.model.load_state_dict(state_dict['model'])
|
| 59 |
+
self.total_steps = state_dict['total_steps']
|
| 60 |
+
|
| 61 |
+
if self.isTrain and not self.opt.new_optim:
|
| 62 |
+
self.optimizer.load_state_dict(state_dict['optimizer'])
|
| 63 |
+
### move optimizer state to GPU
|
| 64 |
+
for state in self.optimizer.state.values():
|
| 65 |
+
for k, v in state.items():
|
| 66 |
+
if torch.is_tensor(v):
|
| 67 |
+
state[k] = v.to(self.device)
|
| 68 |
+
|
| 69 |
+
for g in self.optimizer.param_groups:
|
| 70 |
+
g['lr'] = self.opt.lr
|
| 71 |
+
|
| 72 |
+
def eval(self):
|
| 73 |
+
self.model.eval()
|
| 74 |
+
|
| 75 |
+
def train(self):
|
| 76 |
+
self.model.train()
|
| 77 |
+
|
| 78 |
+
def test(self):
|
| 79 |
+
with torch.no_grad():
|
| 80 |
+
self.forward()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def init_weights(net, init_type='normal', gain=0.02):
|
| 84 |
+
def init_func(m):
|
| 85 |
+
classname = m.__class__.__name__
|
| 86 |
+
if hasattr(m, 'weight') and (
|
| 87 |
+
classname.find('Conv') != -1 or classname.find('Linear') != -1
|
| 88 |
+
):
|
| 89 |
+
if init_type == 'normal':
|
| 90 |
+
init.normal_(m.weight.data, 0.0, gain)
|
| 91 |
+
elif init_type == 'xavier':
|
| 92 |
+
init.xavier_normal_(m.weight.data, gain=gain)
|
| 93 |
+
elif init_type == 'kaiming':
|
| 94 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
| 95 |
+
elif init_type == 'orthogonal':
|
| 96 |
+
init.orthogonal_(m.weight.data, gain=gain)
|
| 97 |
+
else:
|
| 98 |
+
raise NotImplementedError(
|
| 99 |
+
'initialization method [%s] is not implemented' % init_type
|
| 100 |
+
)
|
| 101 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
| 102 |
+
init.constant_(m.bias.data, 0.0)
|
| 103 |
+
elif classname.find('BatchNorm2d') != -1:
|
| 104 |
+
init.normal_(m.weight.data, 1.0, gain)
|
| 105 |
+
init.constant_(m.bias.data, 0.0)
|
| 106 |
+
|
| 107 |
+
print('initialize network with %s' % init_type)
|
| 108 |
+
net.apply(init_func)
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/networks/c2p_clip.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from peft import LoraConfig, get_peft_model
|
| 4 |
+
from transformers import CLIPModel
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class C2P_CLIP_Model(nn.Module):
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
name='openai/clip-vit-large-patch14',
|
| 11 |
+
num_classes=1,
|
| 12 |
+
lora_r=16,
|
| 13 |
+
lora_alpha=32,
|
| 14 |
+
lora_dropout=0.05,
|
| 15 |
+
hf_token=None,
|
| 16 |
+
):
|
| 17 |
+
super(C2P_CLIP_Model, self).__init__()
|
| 18 |
+
|
| 19 |
+
self.model = CLIPModel.from_pretrained(name, token=hf_token)
|
| 20 |
+
del self.model.text_model
|
| 21 |
+
del self.model.text_projection
|
| 22 |
+
del self.model.logit_scale
|
| 23 |
+
|
| 24 |
+
self.vision_tower = self.model.vision_model
|
| 25 |
+
self.vision_tower.requires_grad_(False)
|
| 26 |
+
self.model.visual_projection.requires_grad_(False)
|
| 27 |
+
|
| 28 |
+
lora_config = LoraConfig(
|
| 29 |
+
r=lora_r,
|
| 30 |
+
lora_alpha=lora_alpha,
|
| 31 |
+
target_modules=['q_proj', 'k_proj', 'v_proj'],
|
| 32 |
+
lora_dropout=lora_dropout,
|
| 33 |
+
bias='none',
|
| 34 |
+
)
|
| 35 |
+
self.vision_tower_lora = get_peft_model(self.vision_tower, lora_config)
|
| 36 |
+
|
| 37 |
+
self.model.fc = nn.Linear(768, num_classes)
|
| 38 |
+
torch.nn.init.normal_(self.model.fc.weight.data, 0.0, 0.02)
|
| 39 |
+
|
| 40 |
+
def encode_image(self, img):
|
| 41 |
+
vision_outputs = self.vision_tower_lora(
|
| 42 |
+
pixel_values=img,
|
| 43 |
+
output_attentions=self.model.config.output_attentions,
|
| 44 |
+
output_hidden_states=self.model.config.output_hidden_states,
|
| 45 |
+
return_dict=self.model.config.return_dict,
|
| 46 |
+
)
|
| 47 |
+
pooled_output = vision_outputs[1] # pooled_output
|
| 48 |
+
image_features = self.model.visual_projection(pooled_output)
|
| 49 |
+
return image_features
|
| 50 |
+
|
| 51 |
+
def forward(self, img):
|
| 52 |
+
image_embeds = self.encode_image(img)
|
| 53 |
+
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
| 54 |
+
return self.model.fc(image_embeds)
|
| 55 |
+
|
| 56 |
+
def detect(self, img):
|
| 57 |
+
with torch.no_grad():
|
| 58 |
+
output = self.forward(img)
|
| 59 |
+
return torch.sigmoid(output).squeeze(1)
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/networks/decode_clipfeature_image.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# conda -n create clip-text-decoder python=3.8.5
|
| 2 |
+
import argparse
|
| 3 |
+
import warnings
|
| 4 |
+
from typing import Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import PIL.Image
|
| 7 |
+
import skimage.io as io
|
| 8 |
+
import torch # 2.3.1+cu118 conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=11.8 -c pytorch -c nvidia
|
| 9 |
+
import torch.nn.functional as nnf
|
| 10 |
+
from torch import nn
|
| 11 |
+
from transformers import ( # 4.25.0
|
| 12 |
+
CLIPModel,
|
| 13 |
+
CLIPProcessor,
|
| 14 |
+
GPT2LMHeadModel,
|
| 15 |
+
GPT2Tokenizer,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
warnings.filterwarnings('ignore')
|
| 19 |
+
|
| 20 |
+
N = type(None)
|
| 21 |
+
T = torch.Tensor
|
| 22 |
+
D = torch.device
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class MLP(nn.Module):
|
| 26 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 27 |
+
return self.model(x)
|
| 28 |
+
|
| 29 |
+
def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
|
| 30 |
+
super(MLP, self).__init__()
|
| 31 |
+
layers = []
|
| 32 |
+
for i in range(len(sizes) - 1):
|
| 33 |
+
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
|
| 34 |
+
if i < len(sizes) - 2:
|
| 35 |
+
layers.append(act())
|
| 36 |
+
self.model = nn.Sequential(*layers)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ClipCaptionModel(nn.Module):
|
| 40 |
+
# @functools.lru_cache #FIXME
|
| 41 |
+
def get_dummy_token(self, batch_size: int, device: D) -> T:
|
| 42 |
+
return torch.zeros(
|
| 43 |
+
batch_size, self.prefix_length, dtype=torch.int64, device=device
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def forward(
|
| 47 |
+
self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None
|
| 48 |
+
):
|
| 49 |
+
embedding_text = self.gpt.transformer.wte(tokens)
|
| 50 |
+
prefix_projections = self.clip_project(prefix).view(
|
| 51 |
+
-1, self.prefix_length, self.gpt_embedding_size
|
| 52 |
+
)
|
| 53 |
+
embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
|
| 54 |
+
if labels is not None:
|
| 55 |
+
dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
|
| 56 |
+
labels = torch.cat((dummy_token, tokens), dim=1)
|
| 57 |
+
out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
|
| 58 |
+
return out
|
| 59 |
+
|
| 60 |
+
def __init__(self, prefix_length: int, prefix_size: int = 512):
|
| 61 |
+
super(ClipCaptionModel, self).__init__()
|
| 62 |
+
self.prefix_length = prefix_length
|
| 63 |
+
self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
|
| 64 |
+
self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
|
| 65 |
+
self.clip_project = MLP(
|
| 66 |
+
(
|
| 67 |
+
prefix_size,
|
| 68 |
+
(self.gpt_embedding_size * prefix_length) // 2,
|
| 69 |
+
self.gpt_embedding_size * prefix_length,
|
| 70 |
+
)
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def generate2(
|
| 75 |
+
model,
|
| 76 |
+
tokenizer,
|
| 77 |
+
tokens=None,
|
| 78 |
+
prompt=None,
|
| 79 |
+
embed=None,
|
| 80 |
+
entry_count=1,
|
| 81 |
+
entry_length=67,
|
| 82 |
+
top_p=0.8,
|
| 83 |
+
temperature=1.0,
|
| 84 |
+
stop_token: str = '.',
|
| 85 |
+
):
|
| 86 |
+
model.eval()
|
| 87 |
+
generated_num = 0
|
| 88 |
+
generated_list = []
|
| 89 |
+
stop_token_index = tokenizer.encode(stop_token)[0]
|
| 90 |
+
filter_value = -float('Inf')
|
| 91 |
+
device = next(model.parameters()).device
|
| 92 |
+
|
| 93 |
+
with torch.no_grad():
|
| 94 |
+
for entry_idx in range(entry_count):
|
| 95 |
+
if embed is not None:
|
| 96 |
+
generated = embed
|
| 97 |
+
else:
|
| 98 |
+
if tokens is None:
|
| 99 |
+
tokens = torch.tensor(tokenizer.encode(prompt))
|
| 100 |
+
tokens = tokens.unsqueeze(0).to(device)
|
| 101 |
+
generated = model.gpt.transformer.wte(tokens)
|
| 102 |
+
|
| 103 |
+
for i in range(entry_length):
|
| 104 |
+
outputs = model.gpt(inputs_embeds=generated)
|
| 105 |
+
logits = outputs.logits
|
| 106 |
+
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
|
| 107 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 108 |
+
cumulative_probs = torch.cumsum(
|
| 109 |
+
nnf.softmax(sorted_logits, dim=-1), dim=-1
|
| 110 |
+
)
|
| 111 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 112 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
| 113 |
+
..., :-1
|
| 114 |
+
].clone()
|
| 115 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 116 |
+
|
| 117 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
| 118 |
+
logits[:, indices_to_remove] = filter_value
|
| 119 |
+
next_token = torch.argmax(logits, -1).unsqueeze(0)
|
| 120 |
+
next_token_embed = model.gpt.transformer.wte(next_token)
|
| 121 |
+
if tokens is None:
|
| 122 |
+
tokens = next_token
|
| 123 |
+
else:
|
| 124 |
+
tokens = torch.cat((tokens, next_token), dim=1)
|
| 125 |
+
generated = torch.cat((generated, next_token_embed), dim=1)
|
| 126 |
+
if stop_token_index == next_token.item():
|
| 127 |
+
break
|
| 128 |
+
|
| 129 |
+
output_list = list(tokens.squeeze().cpu().numpy())
|
| 130 |
+
output_text = tokenizer.decode(output_list)
|
| 131 |
+
generated_list.append(output_text)
|
| 132 |
+
|
| 133 |
+
return generated_list[0]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def parse_args():
|
| 137 |
+
parser = argparse.ArgumentParser(description='decode detection feature to text')
|
| 138 |
+
parser.add_argument('--prefix_length', type=int, default=10)
|
| 139 |
+
parser.add_argument(
|
| 140 |
+
'--model_path',
|
| 141 |
+
type=str,
|
| 142 |
+
default='https://www.now61.com/f/Xljmi0/coco_prefix_latest.pt',
|
| 143 |
+
help='model_path',
|
| 144 |
+
)
|
| 145 |
+
parser.add_argument(
|
| 146 |
+
'--image_path', type=str, default='', help='image_path', required=True
|
| 147 |
+
)
|
| 148 |
+
parser.add_argument(
|
| 149 |
+
'--fc_path',
|
| 150 |
+
type=str,
|
| 151 |
+
default='https://www.now61.com/f/qwvoH5/fc_parameters.pth',
|
| 152 |
+
help='fc_path',
|
| 153 |
+
)
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
'--cal_detection_feat', action='store_true', help='cal_detection_feat'
|
| 156 |
+
)
|
| 157 |
+
parser.add_argument('--device', type=str, default='cuda:0', help='cuda:n or cpu')
|
| 158 |
+
args = parser.parse_args()
|
| 159 |
+
|
| 160 |
+
def print_options(parser, args):
|
| 161 |
+
message = ''
|
| 162 |
+
message += '----------------- Options ---------------\n'
|
| 163 |
+
for k, v in sorted(vars(args).items()):
|
| 164 |
+
comment = ''
|
| 165 |
+
default = parser.get_default(k)
|
| 166 |
+
if v != default:
|
| 167 |
+
comment = '\t[default: %s]' % str(default)
|
| 168 |
+
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
| 169 |
+
message += '----------------- End -------------------'
|
| 170 |
+
print(message)
|
| 171 |
+
|
| 172 |
+
print_options(parser, args)
|
| 173 |
+
return args
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def get_text(
|
| 177 |
+
image_features,
|
| 178 |
+
tokenizer,
|
| 179 |
+
model,
|
| 180 |
+
fc_path,
|
| 181 |
+
cal_detection_feat=True,
|
| 182 |
+
prefix_length=10,
|
| 183 |
+
device='cpu',
|
| 184 |
+
):
|
| 185 |
+
mod = (
|
| 186 |
+
torch.hub.load_state_dict_from_url(fc_path, map_location='cpu', progress=True)
|
| 187 |
+
if fc_path.startswith('http')
|
| 188 |
+
else torch.load(fc_path, map_location='cpu')
|
| 189 |
+
)
|
| 190 |
+
weight, bias = mod['fc.weight'].to(device), mod['fc.bias'].to(device)
|
| 191 |
+
with torch.no_grad():
|
| 192 |
+
prob = nnf.linear(image_features, weight, bias).sigmoid().cpu().numpy()[0][0]
|
| 193 |
+
dict_prob = {False: 'Fake', True: 'Real'}
|
| 194 |
+
# print( f'\nPredicted prob: {prob}, {dict_prob[prob<0.5]}' )
|
| 195 |
+
# tmp=image_features;print(f'image_features: {tmp.shape}, max: {tmp.max()}, min: {tmp.min()}, mean: {tmp.mean()}')
|
| 196 |
+
if cal_detection_feat:
|
| 197 |
+
image_features = torch.mul(image_features, weight) + bias
|
| 198 |
+
image_features /= image_features.norm(2, dim=-1, keepdim=True)
|
| 199 |
+
prefix_embed = model.clip_project(image_features).reshape(1, prefix_length, -1)
|
| 200 |
+
generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
|
| 201 |
+
return generated_text_prefix
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def get_clip_model(clip_name='openai/clip-vit-large-patch14', device='cpu'):
|
| 205 |
+
clipmodel = CLIPModel.from_pretrained(clip_name)
|
| 206 |
+
processor = CLIPProcessor.from_pretrained(clip_name)
|
| 207 |
+
del clipmodel.text_model
|
| 208 |
+
del clipmodel.text_projection
|
| 209 |
+
del clipmodel.logit_scale
|
| 210 |
+
clipmodel = clipmodel.to(device)
|
| 211 |
+
return clipmodel, processor
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def get_clipcap_model(model_path, prefix_length=10, device='cpu'):
|
| 215 |
+
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
| 216 |
+
model = ClipCaptionModel(prefix_length, prefix_size=768)
|
| 217 |
+
pretrained = (
|
| 218 |
+
torch.hub.load_state_dict_from_url(
|
| 219 |
+
model_path, map_location='cpu', progress=True
|
| 220 |
+
)
|
| 221 |
+
if model_path.startswith('http')
|
| 222 |
+
else torch.load(model_path, map_location='cpu')
|
| 223 |
+
)
|
| 224 |
+
model.load_state_dict(pretrained)
|
| 225 |
+
assert pretrained.keys() == model.state_dict().keys()
|
| 226 |
+
model = model.eval()
|
| 227 |
+
model = model.to(device)
|
| 228 |
+
return model, tokenizer
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def get_image_features(image_path, clipmodel, processor, device='cpu'):
|
| 232 |
+
with torch.no_grad():
|
| 233 |
+
image = PIL.Image.fromarray(io.imread(image_path))
|
| 234 |
+
inputs = processor(images=image, return_tensors='pt')
|
| 235 |
+
inputs['pixel_values'] = inputs['pixel_values'].to(device)
|
| 236 |
+
image_features = clipmodel.get_image_features(**inputs)
|
| 237 |
+
return image_features
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
if __name__ == '__main__':
|
| 241 |
+
opt = parse_args()
|
| 242 |
+
device = torch.device(opt.device)
|
| 243 |
+
|
| 244 |
+
clipmodel, processor = get_clip_model(
|
| 245 |
+
clip_name='openai/clip-vit-large-patch14', device=device
|
| 246 |
+
)
|
| 247 |
+
model, tokenizer = get_clipcap_model(opt.model_path, device=device)
|
| 248 |
+
|
| 249 |
+
image_features = get_image_features(
|
| 250 |
+
opt.image_path, clipmodel, processor, device=device
|
| 251 |
+
)
|
| 252 |
+
text = get_text(
|
| 253 |
+
image_features,
|
| 254 |
+
tokenizer,
|
| 255 |
+
model,
|
| 256 |
+
opt.fc_path,
|
| 257 |
+
opt.cal_detection_feat,
|
| 258 |
+
device=device,
|
| 259 |
+
)
|
| 260 |
+
print(text)
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/networks/trainer.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from peft import LoraConfig, get_peft_model
|
| 4 |
+
from transformers import CLIPModel, CLIPProcessor
|
| 5 |
+
|
| 6 |
+
from networks.base_model import BaseModel
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CLIPModel_lora(nn.Module):
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
name='openai/clip-vit-large-patch14-336',
|
| 13 |
+
num_classes=1,
|
| 14 |
+
lora_r=16,
|
| 15 |
+
lora_alpha=32,
|
| 16 |
+
lora_dropout=0.05,
|
| 17 |
+
):
|
| 18 |
+
super(CLIPModel_lora, self).__init__()
|
| 19 |
+
self.model = CLIPModel.from_pretrained(name)
|
| 20 |
+
self.processor = CLIPProcessor.from_pretrained(name)
|
| 21 |
+
self.vision_tower = self.model.vision_model
|
| 22 |
+
self.vision_tower.requires_grad_(False)
|
| 23 |
+
self.model.text_model.requires_grad_(False)
|
| 24 |
+
self.model.visual_projection.requires_grad_(False)
|
| 25 |
+
self.model.text_projection.requires_grad_(False)
|
| 26 |
+
self.contrastive_loss = nn.CrossEntropyLoss()
|
| 27 |
+
self.model.logit_scale.requires_grad_(False)
|
| 28 |
+
lora_config = LoraConfig(
|
| 29 |
+
r=lora_r,
|
| 30 |
+
lora_alpha=lora_alpha,
|
| 31 |
+
target_modules=['q_proj', 'k_proj', 'v_proj'],
|
| 32 |
+
lora_dropout=lora_dropout,
|
| 33 |
+
bias='none',
|
| 34 |
+
)
|
| 35 |
+
self.vision_tower_lora = get_peft_model(self.vision_tower, lora_config)
|
| 36 |
+
self.model.fc = nn.Linear(768, num_classes)
|
| 37 |
+
torch.nn.init.normal_(self.model.fc.weight.data, 0.0, 0.02)
|
| 38 |
+
|
| 39 |
+
def encode_text(self, input_ids, attention_mask):
|
| 40 |
+
text_outputs = self.model.text_model(
|
| 41 |
+
input_ids=input_ids,
|
| 42 |
+
attention_mask=attention_mask,
|
| 43 |
+
position_ids=None,
|
| 44 |
+
output_attentions=self.model.config.output_attentions,
|
| 45 |
+
output_hidden_states=self.model.config.output_hidden_states,
|
| 46 |
+
return_dict=self.model.config.return_dict,
|
| 47 |
+
)
|
| 48 |
+
text_embeds = text_outputs[1]
|
| 49 |
+
text_embeds = self.model.text_projection(text_embeds)
|
| 50 |
+
return text_embeds
|
| 51 |
+
|
| 52 |
+
def encode_image(self, img):
|
| 53 |
+
vision_outputs = self.vision_tower_lora(
|
| 54 |
+
pixel_values=img,
|
| 55 |
+
output_attentions=self.model.config.output_attentions,
|
| 56 |
+
output_hidden_states=self.model.config.output_hidden_states,
|
| 57 |
+
return_dict=self.model.config.return_dict,
|
| 58 |
+
)
|
| 59 |
+
pooled_output = vision_outputs[1] # pooled_output
|
| 60 |
+
image_features = self.model.visual_projection(pooled_output)
|
| 61 |
+
return image_features
|
| 62 |
+
|
| 63 |
+
def forward(self, img, input_ids, attention_mask, cla=False):
|
| 64 |
+
# tmp = x; print(f'x: {tmp.shape}, max: {tmp.max()}, min: {tmp.min()}, mean: {tmp.mean()}')
|
| 65 |
+
|
| 66 |
+
image_embeds = self.encode_image(img)
|
| 67 |
+
|
| 68 |
+
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
| 69 |
+
classhead = self.model.fc(image_embeds)
|
| 70 |
+
if cla:
|
| 71 |
+
return classhead
|
| 72 |
+
|
| 73 |
+
text_embeds = self.encode_text(input_ids, attention_mask)
|
| 74 |
+
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
| 75 |
+
|
| 76 |
+
logits_per_text = (
|
| 77 |
+
torch.matmul(text_embeds, image_embeds.t()) * self.model.logit_scale.exp()
|
| 78 |
+
)
|
| 79 |
+
logits_per_image = logits_per_text.t()
|
| 80 |
+
|
| 81 |
+
return logits_per_image, classhead.squeeze(1)
|
| 82 |
+
|
| 83 |
+
def forward_eval(self, img):
|
| 84 |
+
image_embeds = self.encode_image(img)
|
| 85 |
+
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
| 86 |
+
classhead = self.model.fc(image_embeds)
|
| 87 |
+
return classhead
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class Trainer(BaseModel):
|
| 91 |
+
def name(self):
|
| 92 |
+
return 'Trainer'
|
| 93 |
+
|
| 94 |
+
def __init__(self, opt):
|
| 95 |
+
super(Trainer, self).__init__(opt)
|
| 96 |
+
|
| 97 |
+
self.delr = opt.delr
|
| 98 |
+
self.claloss = opt.claloss
|
| 99 |
+
|
| 100 |
+
self.printOne = 1
|
| 101 |
+
self.model = CLIPModel_lora(
|
| 102 |
+
name=opt.clip,
|
| 103 |
+
lora_r=opt.lora_r,
|
| 104 |
+
lora_alpha=opt.lora_alpha,
|
| 105 |
+
lora_dropout=opt.lora_dropout,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
net_params = sum(map(lambda x: x.numel(), self.model.model.parameters()))
|
| 109 |
+
|
| 110 |
+
print(f'Model parameters {net_params:,d}')
|
| 111 |
+
|
| 112 |
+
if self.isTrain:
|
| 113 |
+
self.loss_fn = nn.BCEWithLogitsLoss()
|
| 114 |
+
if opt.optim == 'adam':
|
| 115 |
+
self.optimizer = torch.optim.Adam(
|
| 116 |
+
filter(lambda p: p.requires_grad, self.model.parameters()),
|
| 117 |
+
lr=opt.lr,
|
| 118 |
+
betas=(opt.beta1, 0.999),
|
| 119 |
+
)
|
| 120 |
+
elif opt.optim == 'sgd':
|
| 121 |
+
self.optimizer = torch.optim.SGD(
|
| 122 |
+
self.model.parameters(), lr=opt.lr, momentum=0.0, weight_decay=0
|
| 123 |
+
)
|
| 124 |
+
elif opt.optim == 'adamw':
|
| 125 |
+
self.optimizer = torch.optim.AdamW(
|
| 126 |
+
filter(lambda p: p.requires_grad, self.model.parameters()),
|
| 127 |
+
lr=opt.lr,
|
| 128 |
+
weight_decay=0.05,
|
| 129 |
+
betas=(opt.beta1, 0.999),
|
| 130 |
+
eps=1e-8,
|
| 131 |
+
)
|
| 132 |
+
else:
|
| 133 |
+
raise ValueError('optim should be [adam, sgd]')
|
| 134 |
+
|
| 135 |
+
if not self.isTrain or opt.continue_train:
|
| 136 |
+
self.load_networks(opt.epoch)
|
| 137 |
+
|
| 138 |
+
self.model = nn.DataParallel(self.model).cuda()
|
| 139 |
+
|
| 140 |
+
def adjust_learning_rate(self, min_lr=1e-6):
|
| 141 |
+
for param_group in self.optimizer.param_groups:
|
| 142 |
+
param_group['lr'] *= self.delr
|
| 143 |
+
if param_group['lr'] < min_lr:
|
| 144 |
+
return False
|
| 145 |
+
self.lr = param_group['lr']
|
| 146 |
+
print('*' * 25)
|
| 147 |
+
print(
|
| 148 |
+
f'Changing lr from {param_group["lr"] / self.delr} to {param_group["lr"]} with delr {self.delr}'
|
| 149 |
+
)
|
| 150 |
+
print('*' * 25)
|
| 151 |
+
return True
|
| 152 |
+
|
| 153 |
+
def set_input(self, input):
|
| 154 |
+
self.input = input[1].cuda()
|
| 155 |
+
self.text = input[2]
|
| 156 |
+
self.input_ids = input[3].cuda()
|
| 157 |
+
self.attention_mask = input[4].cuda()
|
| 158 |
+
self.label = input[5].cuda().float()
|
| 159 |
+
|
| 160 |
+
def forward(self):
|
| 161 |
+
self.output, self.classhead = self.model(
|
| 162 |
+
self.input, self.input_ids, self.attention_mask
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
def contrastive_loss(self, logits: torch.Tensor) -> torch.Tensor:
|
| 166 |
+
|
| 167 |
+
caption_loss = nn.functional.cross_entropy(
|
| 168 |
+
logits, torch.arange(len(logits), device=logits.device)
|
| 169 |
+
)
|
| 170 |
+
image_loss = nn.functional.cross_entropy(
|
| 171 |
+
logits.t(), torch.arange(len(logits), device=logits.device)
|
| 172 |
+
)
|
| 173 |
+
return (caption_loss + image_loss) / 2.0
|
| 174 |
+
|
| 175 |
+
def get_loss(self):
|
| 176 |
+
return self.model.clip_loss_input(self.input, self.text)
|
| 177 |
+
|
| 178 |
+
def optimize_parameters(self):
|
| 179 |
+
self.forward()
|
| 180 |
+
|
| 181 |
+
self.loss1 = sum(
|
| 182 |
+
[
|
| 183 |
+
self.contrastive_loss(output)
|
| 184 |
+
for output in torch.split(self.output, self.output.shape[1], dim=0)
|
| 185 |
+
]
|
| 186 |
+
)
|
| 187 |
+
self.loss2 = self.claloss * self.loss_fn(self.classhead, self.label)
|
| 188 |
+
self.loss = self.loss1 + self.loss2
|
| 189 |
+
# self.loss1, self.loss2 = 0.0, 0.0
|
| 190 |
+
# self.loss = self.loss_fn(self.classhead, self.label)
|
| 191 |
+
self.optimizer.zero_grad()
|
| 192 |
+
self.loss.backward()
|
| 193 |
+
self.optimizer.step()
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/options/__init__.py
ADDED
|
File without changes
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/options/base_options.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
import utils.util as util
|
| 5 |
+
import torch
|
| 6 |
+
#import models
|
| 7 |
+
#import data
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BaseOptions():
|
| 11 |
+
def __init__(self):
|
| 12 |
+
self.initialized = False
|
| 13 |
+
|
| 14 |
+
def initialize(self, parser):
|
| 15 |
+
parser.add_argument('--mode', default='binary')
|
| 16 |
+
parser.add_argument('--arch', type=str, default='res50', help='architecture for binary classification')
|
| 17 |
+
|
| 18 |
+
# data augmentation
|
| 19 |
+
parser.add_argument('--rz_interp', default='bilinear')
|
| 20 |
+
parser.add_argument('--blur_prob', type=float, default=0)
|
| 21 |
+
parser.add_argument('--blur_sig', default='0.5')
|
| 22 |
+
parser.add_argument('--jpg_prob', type=float, default=0)
|
| 23 |
+
parser.add_argument('--jpg_method', default='cv2')
|
| 24 |
+
parser.add_argument('--jpg_qual', default='75')
|
| 25 |
+
|
| 26 |
+
parser.add_argument('--dataroot', default='./dataset/', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
|
| 27 |
+
parser.add_argument('--textroot', default='./Genimage_CNNDetection_CLIP_prefix_caption/', help='path to texts')
|
| 28 |
+
|
| 29 |
+
parser.add_argument('--classes', default='', help='which classes to use, separated by comma. If empty, use all subfolders of dataroot')
|
| 30 |
+
parser.add_argument('--class_bal', action='store_true')
|
| 31 |
+
parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
|
| 32 |
+
parser.add_argument('--loadSize', type=int, default=256, help='scale images to this size')
|
| 33 |
+
parser.add_argument('--cropSize', type=int, default=224, help='then crop to this size')
|
| 34 |
+
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
|
| 35 |
+
parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
|
| 36 |
+
parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
|
| 37 |
+
parser.add_argument('--num_threads', type=int, default=8, help='# threads for loading data')
|
| 38 |
+
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
|
| 39 |
+
parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
|
| 40 |
+
parser.add_argument('--resize_or_crop', type=str, default='scale_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none]')
|
| 41 |
+
parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
|
| 42 |
+
parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
|
| 43 |
+
parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
|
| 44 |
+
parser.add_argument('--suffix', type=str, default='', help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}')
|
| 45 |
+
parser.add_argument('--delr_freq', type=int, default=20, help='frequency of change lr')
|
| 46 |
+
parser.add_argument('--delr', type=float, default=0.8, help='delr')
|
| 47 |
+
parser.add_argument('--seed', type=int, default=123, help='seed')
|
| 48 |
+
parser.add_argument('--clip', type=str, default='./clip-vit-large-patch14/', help='clip path')
|
| 49 |
+
parser.add_argument('--claloss', type=float, default=0.5, help='fixed num layer')
|
| 50 |
+
parser.add_argument('--cates', nargs='+', default=['Deepfake', 'Camera'])
|
| 51 |
+
parser.add_argument('--eval_freq', type=int, default=200, help='eval frequency')
|
| 52 |
+
parser.add_argument('--lora_r', type=int, default=16, help='eval frequency')
|
| 53 |
+
parser.add_argument('--lora_alpha', type=int, default=32, help='eval frequency')
|
| 54 |
+
parser.add_argument('--lora_dropout', type=float, default=0.1, help='eval frequency')
|
| 55 |
+
parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
|
| 56 |
+
|
| 57 |
+
self.initialized = True
|
| 58 |
+
return parser
|
| 59 |
+
|
| 60 |
+
def gather_options(self):
|
| 61 |
+
# initialize parser with basic options
|
| 62 |
+
if not self.initialized:
|
| 63 |
+
parser = argparse.ArgumentParser(
|
| 64 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
| 65 |
+
parser = self.initialize(parser)
|
| 66 |
+
|
| 67 |
+
# get the basic options
|
| 68 |
+
opt, _ = parser.parse_known_args()
|
| 69 |
+
self.parser = parser
|
| 70 |
+
|
| 71 |
+
return opt
|
| 72 |
+
# return parser.parse_args()
|
| 73 |
+
|
| 74 |
+
def print_options(self, opt):
|
| 75 |
+
message = ''
|
| 76 |
+
message += '----------------- Options ---------------\n'
|
| 77 |
+
for k, v in sorted(vars(opt).items()):
|
| 78 |
+
comment = ''
|
| 79 |
+
default = self.parser.get_default(k)
|
| 80 |
+
if v != default:
|
| 81 |
+
comment = '\t[default: %s]' % str(default)
|
| 82 |
+
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
| 83 |
+
message += '----------------- End -------------------'
|
| 84 |
+
print(message)
|
| 85 |
+
|
| 86 |
+
# save to the disk
|
| 87 |
+
|
| 88 |
+
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
| 89 |
+
util.mkdirs(expr_dir)
|
| 90 |
+
file_name = os.path.join(expr_dir, 'opt.txt')
|
| 91 |
+
with open(file_name, 'wt') as opt_file:
|
| 92 |
+
opt_file.write(message)
|
| 93 |
+
opt_file.write('\n')
|
| 94 |
+
|
| 95 |
+
def parse(self, print_options=True):
|
| 96 |
+
|
| 97 |
+
opt = self.gather_options()
|
| 98 |
+
opt.isTrain = self.isTrain # train or test
|
| 99 |
+
opt.imgroot = opt.dataroot
|
| 100 |
+
opt.name = '__'.join([opt.name, time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()), 'Seed_'+str(opt.seed), 'cates_'+'-'.join(opt.cates), 'claloss_'+str(opt.claloss), 'lora_r_'+str(opt.lora_r), 'lora_alpha_'+str(opt.lora_alpha), 'lora_dropout_'+str(opt.lora_dropout), 'lr_'+str(opt.lr)])
|
| 101 |
+
|
| 102 |
+
if opt.suffix:
|
| 103 |
+
suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
|
| 104 |
+
opt.name = opt.name + suffix
|
| 105 |
+
|
| 106 |
+
if print_options:
|
| 107 |
+
self.print_options(opt)
|
| 108 |
+
|
| 109 |
+
# set gpu ids
|
| 110 |
+
str_ids = opt.gpu_ids.split(',')
|
| 111 |
+
opt.gpu_ids = []
|
| 112 |
+
for str_id in str_ids:
|
| 113 |
+
id = int(str_id)
|
| 114 |
+
if id >= 0:
|
| 115 |
+
opt.gpu_ids.append(id)
|
| 116 |
+
if len(opt.gpu_ids) > 0:
|
| 117 |
+
torch.cuda.set_device(opt.gpu_ids[0])
|
| 118 |
+
|
| 119 |
+
# additional
|
| 120 |
+
opt.classes = opt.classes.split(',')
|
| 121 |
+
opt.rz_interp = opt.rz_interp.split(',')
|
| 122 |
+
opt.blur_sig = [float(s) for s in opt.blur_sig.split(',')]
|
| 123 |
+
opt.jpg_method = opt.jpg_method.split(',')
|
| 124 |
+
opt.jpg_qual = [int(s) for s in opt.jpg_qual.split(',')]
|
| 125 |
+
if len(opt.jpg_qual) == 2:
|
| 126 |
+
opt.jpg_qual = list(range(opt.jpg_qual[0], opt.jpg_qual[1] + 1))
|
| 127 |
+
elif len(opt.jpg_qual) > 2:
|
| 128 |
+
raise ValueError("Shouldn't have more than 2 values for --jpg_qual.")
|
| 129 |
+
|
| 130 |
+
self.opt = opt
|
| 131 |
+
return self.opt
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/options/test_options.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_options import BaseOptions
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class TestOptions(BaseOptions):
|
| 5 |
+
def initialize(self, parser):
|
| 6 |
+
parser = BaseOptions.initialize(self, parser)
|
| 7 |
+
# parser.add_argument('--dataroot')
|
| 8 |
+
parser.add_argument('--model_path')
|
| 9 |
+
parser.add_argument('--no_resize', action='store_true')
|
| 10 |
+
parser.add_argument('--no_crop', action='store_true')
|
| 11 |
+
parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
|
| 12 |
+
parser.add_argument('--earlystop_epoch', type=int, default=15)
|
| 13 |
+
# parser.add_argument('--lr', type=float, default=0.00002, help='initial learning rate for adam')
|
| 14 |
+
parser.add_argument('--niter', type=int, default=0, help='# of iter at starting learning rate')
|
| 15 |
+
|
| 16 |
+
self.isTrain = False
|
| 17 |
+
return parser
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/options/train_options.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_options import BaseOptions
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class TrainOptions(BaseOptions):
|
| 5 |
+
def initialize(self, parser):
|
| 6 |
+
parser = BaseOptions.initialize(self, parser)
|
| 7 |
+
parser.add_argument('--earlystop_epoch', type=int, default=15)
|
| 8 |
+
parser.add_argument('--data_aug', action='store_true', help='if specified, perform additional data augmentation (photometric, blurring, jpegging)')
|
| 9 |
+
parser.add_argument('--optim', type=str, default='adam', help='optim to use [sgd, adam]')
|
| 10 |
+
parser.add_argument('--new_optim', action='store_true', help='new optimizer instead of loading the optim state')
|
| 11 |
+
parser.add_argument('--loss_freq', type=int, default=400, help='frequency of showing loss on tensorboard')
|
| 12 |
+
parser.add_argument('--save_latest_freq', type=int, default=2000, help='frequency of saving the latest results')
|
| 13 |
+
parser.add_argument('--save_epoch_freq', type=int, default=20, help='frequency of saving checkpoints at the end of epochs')
|
| 14 |
+
parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
|
| 15 |
+
parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
|
| 16 |
+
parser.add_argument('--last_epoch', type=int, default=-1, help='starting epoch count for scheduler intialization')
|
| 17 |
+
parser.add_argument('--train_split', type=str, default='train', help='train, val, test, etc')
|
| 18 |
+
parser.add_argument('--val_split', type=str, default='val', help='train, val, test, etc')
|
| 19 |
+
parser.add_argument('--niter', type=int, default=1000, help='number of epochs')
|
| 20 |
+
parser.add_argument('--total_steps', type=int, default=1000, help='total_steps to train')
|
| 21 |
+
parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam')
|
| 22 |
+
# parser.add_argument('--model_path')
|
| 23 |
+
# parser.add_argument('--no_resize', action='store_true')
|
| 24 |
+
# parser.add_argument('--no_crop', action='store_true')
|
| 25 |
+
self.isTrain = True
|
| 26 |
+
return parser
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/requirements.txt
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu118
|
| 2 |
+
accelerate==0.32.1
|
| 3 |
+
aiosignal==1.3.1
|
| 4 |
+
annotated-types==0.7.0
|
| 5 |
+
anyio==4.8.0
|
| 6 |
+
attrs==23.2.0
|
| 7 |
+
certifi==2024.7.4
|
| 8 |
+
charset-normalizer==3.3.2
|
| 9 |
+
click==8.1.7
|
| 10 |
+
clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
|
| 11 |
+
contourpy==1.2.1
|
| 12 |
+
cycler==0.12.1
|
| 13 |
+
DingtalkChatbot==1.5.7
|
| 14 |
+
distro==1.9.0
|
| 15 |
+
dlib==19.24.6
|
| 16 |
+
einops==0.8.0
|
| 17 |
+
exceptiongroup==1.2.2
|
| 18 |
+
filelock==3.13.1
|
| 19 |
+
fonttools==4.53.1
|
| 20 |
+
frozenlist==1.4.1
|
| 21 |
+
fsspec==2024.2.0
|
| 22 |
+
ftfy==6.2.0
|
| 23 |
+
h11==0.14.0
|
| 24 |
+
httpcore==1.0.7
|
| 25 |
+
httpx==0.28.1
|
| 26 |
+
idna==3.7
|
| 27 |
+
imageio==2.35.1
|
| 28 |
+
Jinja2==3.1.3
|
| 29 |
+
jiter==0.8.2
|
| 30 |
+
joblib==1.4.2
|
| 31 |
+
jsonschema==4.23.0
|
| 32 |
+
jsonschema-specifications==2023.12.1
|
| 33 |
+
kiwisolver==1.4.5
|
| 34 |
+
lazy_loader==0.4
|
| 35 |
+
lora-pytorch==0.2.0
|
| 36 |
+
MarkupSafe==2.1.5
|
| 37 |
+
matplotlib==3.9.1
|
| 38 |
+
mpmath==1.3.0
|
| 39 |
+
msgpack==1.0.8
|
| 40 |
+
munch==4.0.0
|
| 41 |
+
networkx==3.2.1
|
| 42 |
+
numpy==1.26.3
|
| 43 |
+
nvidia-cublas-cu11==11.11.3.6
|
| 44 |
+
nvidia-cuda-cupti-cu11==11.8.87
|
| 45 |
+
nvidia-cuda-nvrtc-cu11==11.8.89
|
| 46 |
+
nvidia-cuda-runtime-cu11==11.8.89
|
| 47 |
+
nvidia-cudnn-cu11==8.7.0.84
|
| 48 |
+
nvidia-cufft-cu11==10.9.0.58
|
| 49 |
+
nvidia-curand-cu11==10.3.0.86
|
| 50 |
+
nvidia-cusolver-cu11==11.4.1.48
|
| 51 |
+
nvidia-cusparse-cu11==11.7.5.86
|
| 52 |
+
nvidia-nccl-cu11==2.20.5
|
| 53 |
+
nvidia-nvtx-cu11==11.8.86
|
| 54 |
+
openai==1.60.2
|
| 55 |
+
opencv-contrib-python==4.10.0.84
|
| 56 |
+
opencv-python==4.10.0.84
|
| 57 |
+
packaging==24.1
|
| 58 |
+
peft==0.11.1
|
| 59 |
+
pillow==10.2.0
|
| 60 |
+
protobuf==5.27.2
|
| 61 |
+
psutil==6.0.0
|
| 62 |
+
pydantic==2.10.6
|
| 63 |
+
pydantic_core==2.27.2
|
| 64 |
+
pyparsing==3.1.2
|
| 65 |
+
python-dateutil==2.9.0.post0
|
| 66 |
+
pytorch_wavelets @ file:///opt/data/private/tcc/GANS_BS1_project6_LLM/A_FatFormer/FatFormer/pytorch_wavelets
|
| 67 |
+
PyYAML==6.0.1
|
| 68 |
+
ray==2.32.0
|
| 69 |
+
referencing==0.35.1
|
| 70 |
+
regex==2024.5.15
|
| 71 |
+
requests==2.32.3
|
| 72 |
+
rpds-py==0.19.0
|
| 73 |
+
safetensors==0.4.3
|
| 74 |
+
scikit-image==0.24.0
|
| 75 |
+
scikit-learn==1.5.1
|
| 76 |
+
scipy==1.14.0
|
| 77 |
+
setproctitle==1.3.3
|
| 78 |
+
six==1.16.0
|
| 79 |
+
sniffio==1.3.1
|
| 80 |
+
sympy==1.12
|
| 81 |
+
tensorboardX==2.6.2.2
|
| 82 |
+
threadpoolctl==3.5.0
|
| 83 |
+
tifffile==2024.8.28
|
| 84 |
+
tokenizers==0.19.1
|
| 85 |
+
torch==2.3.0
|
| 86 |
+
torchaudio==2.3.0
|
| 87 |
+
torchvision==0.18.0
|
| 88 |
+
tqdm==4.66.4
|
| 89 |
+
transformers==4.42.3
|
| 90 |
+
triton==2.3.0
|
| 91 |
+
typing_extensions==4.12.2
|
| 92 |
+
urllib3==2.2.2
|
| 93 |
+
wcwidth==0.2.13
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/train.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 5 |
+
import random
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn
|
| 11 |
+
from data import create_dataloader
|
| 12 |
+
from networks.trainer import Trainer
|
| 13 |
+
from options.test_options import TestOptions
|
| 14 |
+
from options.train_options import TrainOptions
|
| 15 |
+
from utils.util import Logger
|
| 16 |
+
from validate import validate
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def seed_torch(seed=1029):
|
| 20 |
+
random.seed(seed)
|
| 21 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
| 22 |
+
np.random.seed(seed)
|
| 23 |
+
torch.manual_seed(seed)
|
| 24 |
+
torch.cuda.manual_seed(seed)
|
| 25 |
+
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
|
| 26 |
+
torch.backends.cudnn.benchmark = False
|
| 27 |
+
torch.backends.cudnn.deterministic = True
|
| 28 |
+
torch.backends.cudnn.enabled = False
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
"""Currently assumes jpg_prob, blur_prob 0 or 1"""
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_val_opt():
|
| 35 |
+
val_opt = TrainOptions().parse(print_options=False)
|
| 36 |
+
val_opt.dataroot = '{}/{}/'.format(val_opt.dataroot, val_opt.val_split)
|
| 37 |
+
val_opt.isTrain = False
|
| 38 |
+
val_opt.no_resize = False
|
| 39 |
+
val_opt.no_crop = False
|
| 40 |
+
val_opt.serial_batches = True
|
| 41 |
+
val_opt.jpg_method = ['pil']
|
| 42 |
+
if len(val_opt.blur_sig) == 2:
|
| 43 |
+
b_sig = val_opt.blur_sig
|
| 44 |
+
val_opt.blur_sig = [(b_sig[0] + b_sig[1]) / 2]
|
| 45 |
+
if len(val_opt.jpg_qual) != 1:
|
| 46 |
+
j_qual = val_opt.jpg_qual
|
| 47 |
+
val_opt.jpg_qual = [int((j_qual[0] + j_qual[-1]) / 2)]
|
| 48 |
+
|
| 49 |
+
return val_opt
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
if __name__ == '__main__':
|
| 53 |
+
opt = TrainOptions().parse()
|
| 54 |
+
seed_torch(opt.seed)
|
| 55 |
+
|
| 56 |
+
Test_dataroot = os.path.join(opt.dataroot, 'test')
|
| 57 |
+
opt.dataroot = '{}/{}/'.format(opt.dataroot, opt.train_split)
|
| 58 |
+
Logger(os.path.join(opt.checkpoints_dir, opt.name, 'log.log'))
|
| 59 |
+
Testopt = TestOptions().parse(print_options=False)
|
| 60 |
+
Test_vals = os.listdir(Test_dataroot)
|
| 61 |
+
data_loader = create_dataloader(opt)
|
| 62 |
+
model = Trainer(opt)
|
| 63 |
+
|
| 64 |
+
def testmodel(epoch=0):
|
| 65 |
+
print('*' * 25)
|
| 66 |
+
accs = []
|
| 67 |
+
aps = []
|
| 68 |
+
logs = [f'Testing end of {epoch}']
|
| 69 |
+
print(time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime()))
|
| 70 |
+
for v_id, val in enumerate(Test_vals):
|
| 71 |
+
Testopt.dataroot = f'{Test_dataroot}/{val}'
|
| 72 |
+
# Testopt.classes = os.listdir(Testopt.dataroot) if multiclass[v_id] else ['']
|
| 73 |
+
Testopt.loadSize = opt.cropSize
|
| 74 |
+
Testopt.cropSize = opt.cropSize
|
| 75 |
+
Testopt.no_resize = False
|
| 76 |
+
Testopt.no_crop = False
|
| 77 |
+
Testopt.classes = ''
|
| 78 |
+
acc, ap, _, _, _, _ = validate(model.model, Testopt)
|
| 79 |
+
accs.append(acc)
|
| 80 |
+
aps.append(ap)
|
| 81 |
+
logs.append(
|
| 82 |
+
'({} {:10}) acc: {:.1f}; ap: {:.1f}'.format(
|
| 83 |
+
v_id, val, acc * 100, ap * 100
|
| 84 |
+
)
|
| 85 |
+
)
|
| 86 |
+
print(logs[-1])
|
| 87 |
+
logs.append(
|
| 88 |
+
'({} {:10}) acc: {:.1f}; ap: {:.1f}'.format(
|
| 89 |
+
v_id + 1,
|
| 90 |
+
'Mean',
|
| 91 |
+
np.array(accs).mean() * 100,
|
| 92 |
+
np.array(aps).mean() * 100,
|
| 93 |
+
)
|
| 94 |
+
)
|
| 95 |
+
print(logs[-1])
|
| 96 |
+
print('*' * 25)
|
| 97 |
+
print(time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime()))
|
| 98 |
+
return round(np.array(accs).mean() * 100, 4)
|
| 99 |
+
|
| 100 |
+
model.train()
|
| 101 |
+
for epoch in range(opt.niter):
|
| 102 |
+
epoch_start_time = time.time()
|
| 103 |
+
iter_data_time = time.time()
|
| 104 |
+
epoch_iter = 0
|
| 105 |
+
|
| 106 |
+
for i, data in enumerate(data_loader):
|
| 107 |
+
model.total_steps += 1
|
| 108 |
+
if model.total_steps > opt.total_steps:
|
| 109 |
+
break
|
| 110 |
+
epoch_iter += opt.batch_size
|
| 111 |
+
model.set_input(data)
|
| 112 |
+
model.optimize_parameters()
|
| 113 |
+
if model.total_steps % opt.loss_freq == 0:
|
| 114 |
+
print(
|
| 115 |
+
time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime()),
|
| 116 |
+
'Train loss: {} loss1: {} loss2-cla: {} at step: {} lr {}'.format(
|
| 117 |
+
model.loss,
|
| 118 |
+
model.loss1,
|
| 119 |
+
model.loss2,
|
| 120 |
+
model.total_steps,
|
| 121 |
+
model.lr,
|
| 122 |
+
),
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# if model.total_steps % opt.eval_freq == 0:
|
| 126 |
+
# print(os.getcwd())
|
| 127 |
+
# print(f'==========total_steps {model.total_steps}=================')
|
| 128 |
+
# model.eval()
|
| 129 |
+
# testacc = testmodel(epoch)
|
| 130 |
+
# model.save_networks( f'{str(epoch)}_total_steps_{str(model.total_steps)}_testacc_{str(testacc)}')
|
| 131 |
+
# model.train()
|
| 132 |
+
|
| 133 |
+
if epoch % opt.delr_freq == 0 and epoch != 0:
|
| 134 |
+
print(
|
| 135 |
+
time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime()),
|
| 136 |
+
'changing lr at the end of epoch %d, iters %d'
|
| 137 |
+
% (epoch, model.total_steps),
|
| 138 |
+
)
|
| 139 |
+
model.adjust_learning_rate()
|
| 140 |
+
|
| 141 |
+
model.eval()
|
| 142 |
+
testacc = testmodel(epoch)
|
| 143 |
+
model.save_networks(
|
| 144 |
+
f'{str(epoch)}_total_steps_{str(model.total_steps)}_testacc_{str(testacc)}'
|
| 145 |
+
)
|
| 146 |
+
print(
|
| 147 |
+
'saving the latest model %s (epoch %d, model.total_steps %d)'
|
| 148 |
+
% (opt.name, epoch, model.total_steps)
|
| 149 |
+
)
|
| 150 |
+
model.train()
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/train_UniversalFakeDetect.sh
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/train.py \
|
| 2 |
+
--name ForenSynths_train_val_19test_dataset \
|
| 3 |
+
--dataroot ./ForenSynths_train_val_19test/ \
|
| 4 |
+
--classes car,cat,chair,horse \
|
| 5 |
+
--batch_size 128 \
|
| 6 |
+
--delr_freq 10 \
|
| 7 |
+
--lr 0.0004 \
|
| 8 |
+
--niter 1 \
|
| 9 |
+
--total_steps 800 \
|
| 10 |
+
--delr 0.9 \
|
| 11 |
+
--loadSize 256 \
|
| 12 |
+
--cropSize 224 \
|
| 13 |
+
--seed 123 \
|
| 14 |
+
--clip ./clip-vit-large-patch14/ \
|
| 15 |
+
--claloss 8.0 \
|
| 16 |
+
--cates Deepfake Camera \
|
| 17 |
+
--eval_freq 800 \
|
| 18 |
+
--loss_freq 50 \
|
| 19 |
+
--lora_r 6 \
|
| 20 |
+
--lora_alpha 6 \
|
| 21 |
+
--lora_dropout 0.8
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/train_aigibench.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
from torch.utils.data import DataLoader, Dataset
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
from transformers import AutoTokenizer
|
| 13 |
+
|
| 14 |
+
# Add the current directory to sys.path so it can find networks, options, etc.
|
| 15 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 16 |
+
if current_dir not in sys.path:
|
| 17 |
+
sys.path.insert(0, current_dir)
|
| 18 |
+
|
| 19 |
+
from networks.trainer import Trainer
|
| 20 |
+
from options.train_options import TrainOptions
|
| 21 |
+
from utils.util import Logger
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def seed_torch(seed=1029):
|
| 25 |
+
random.seed(seed)
|
| 26 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
| 27 |
+
np.random.seed(seed)
|
| 28 |
+
torch.manual_seed(seed)
|
| 29 |
+
torch.cuda.manual_seed(seed)
|
| 30 |
+
torch.cuda.manual_seed_all(seed)
|
| 31 |
+
torch.backends.cudnn.benchmark = False
|
| 32 |
+
torch.backends.cudnn.deterministic = True
|
| 33 |
+
torch.backends.cudnn.enabled = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class AIGIBenchDataset(Dataset):
|
| 37 |
+
def __init__(self, hf_data, opt, transform=None):
|
| 38 |
+
self.hf_data = hf_data
|
| 39 |
+
self.opt = opt
|
| 40 |
+
self.transform = transform
|
| 41 |
+
|
| 42 |
+
# Use CLIP model path or HF name for tokenizer
|
| 43 |
+
tokenizer_name = (
|
| 44 |
+
opt.clip if os.path.exists(opt.clip) else 'openai/clip-vit-large-patch14'
|
| 45 |
+
)
|
| 46 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 47 |
+
tokenizer_name, model_max_length=77, padding_side='right', use_fast=False
|
| 48 |
+
)
|
| 49 |
+
self.tokenizer.pad_token_id = 0
|
| 50 |
+
|
| 51 |
+
def __len__(self):
|
| 52 |
+
return len(self.hf_data)
|
| 53 |
+
|
| 54 |
+
def __getitem__(self, idx):
|
| 55 |
+
item = self.hf_data[idx]
|
| 56 |
+
image = item['image'].convert('RGB')
|
| 57 |
+
label = item['label']
|
| 58 |
+
generator_id = item['generator']
|
| 59 |
+
|
| 60 |
+
# Generator mapping based on AIGIBench spec
|
| 61 |
+
gen_names = {0: 'Real', 1: 'ProGAN', 2: 'SD14'}
|
| 62 |
+
gen_name = gen_names.get(generator_id, 'Unknown')
|
| 63 |
+
|
| 64 |
+
cates = self.opt.cates
|
| 65 |
+
cates_len = len(cates) // 2
|
| 66 |
+
|
| 67 |
+
# Generic image description to guide the contrastive learning
|
| 68 |
+
text = f'An image produced by {gen_name}'
|
| 69 |
+
|
| 70 |
+
if label == 1: # Fake (AI-generated)
|
| 71 |
+
text = (
|
| 72 |
+
f'{" ".join(cates[:cates_len])}. {text} {" ".join(cates[:cates_len])}.'
|
| 73 |
+
)
|
| 74 |
+
else: # Real (Authentic)
|
| 75 |
+
text = (
|
| 76 |
+
f'{" ".join(cates[cates_len:])}. {text} {" ".join(cates[cates_len:])}.'
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
inputs = self.tokenizer(
|
| 80 |
+
[text],
|
| 81 |
+
padding='max_length',
|
| 82 |
+
max_length=self.tokenizer.model_max_length,
|
| 83 |
+
truncation=True,
|
| 84 |
+
return_tensors='pt',
|
| 85 |
+
)
|
| 86 |
+
input_ids = inputs['input_ids'][0]
|
| 87 |
+
attention_mask = inputs['attention_mask'][0]
|
| 88 |
+
|
| 89 |
+
if self.transform:
|
| 90 |
+
image = self.transform(image)
|
| 91 |
+
|
| 92 |
+
# Format: path, image, text, input_ids, attention_mask, label
|
| 93 |
+
return 'hf_dataset', image, text, input_ids, attention_mask, label
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_train_transforms(opt):
|
| 97 |
+
return transforms.Compose(
|
| 98 |
+
[
|
| 99 |
+
transforms.Resize(opt.loadSize),
|
| 100 |
+
transforms.RandomCrop(opt.cropSize),
|
| 101 |
+
transforms.RandomHorizontalFlip(),
|
| 102 |
+
transforms.ToTensor(),
|
| 103 |
+
transforms.Normalize(
|
| 104 |
+
mean=[0.48145466, 0.4578275, 0.40821073],
|
| 105 |
+
std=[0.26862954, 0.26130258, 0.27577711],
|
| 106 |
+
),
|
| 107 |
+
]
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def get_val_transforms(opt):
|
| 112 |
+
return transforms.Compose(
|
| 113 |
+
[
|
| 114 |
+
transforms.Resize(opt.loadSize),
|
| 115 |
+
transforms.CenterCrop(opt.cropSize),
|
| 116 |
+
transforms.ToTensor(),
|
| 117 |
+
transforms.Normalize(
|
| 118 |
+
mean=[0.48145466, 0.4578275, 0.40821073],
|
| 119 |
+
std=[0.26862954, 0.26130258, 0.27577711],
|
| 120 |
+
),
|
| 121 |
+
]
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
if __name__ == '__main__':
|
| 126 |
+
# Parse options
|
| 127 |
+
opt = TrainOptions().parse()
|
| 128 |
+
seed_torch(opt.seed)
|
| 129 |
+
|
| 130 |
+
# Handle CLIP path if folder doesn't exist locally
|
| 131 |
+
if not os.path.exists(opt.clip):
|
| 132 |
+
print(
|
| 133 |
+
f"CLIP path {opt.clip} not found locally. Using 'openai/clip-vit-large-patch14' from HuggingFace."
|
| 134 |
+
)
|
| 135 |
+
opt.clip = 'openai/clip-vit-large-patch14'
|
| 136 |
+
|
| 137 |
+
# Initialize logger
|
| 138 |
+
log_path = os.path.join(opt.checkpoints_dir, opt.name)
|
| 139 |
+
if not os.path.exists(log_path):
|
| 140 |
+
os.makedirs(log_path)
|
| 141 |
+
Logger(os.path.join(log_path, 'log.log'))
|
| 142 |
+
|
| 143 |
+
print('Loading AIGIBench dataset from HuggingFace...')
|
| 144 |
+
ds = load_dataset('TheKernel01/AIGIBench')
|
| 145 |
+
train_data = ds['train']
|
| 146 |
+
|
| 147 |
+
val_data = ds['validation']
|
| 148 |
+
|
| 149 |
+
train_dataset = AIGIBenchDataset(
|
| 150 |
+
train_data, opt, transform=get_train_transforms(opt)
|
| 151 |
+
)
|
| 152 |
+
val_dataset = AIGIBenchDataset(val_data, opt, transform=get_val_transforms(opt))
|
| 153 |
+
|
| 154 |
+
train_loader = DataLoader(
|
| 155 |
+
train_dataset,
|
| 156 |
+
batch_size=opt.batch_size,
|
| 157 |
+
shuffle=True,
|
| 158 |
+
num_workers=opt.num_threads,
|
| 159 |
+
)
|
| 160 |
+
val_loader = DataLoader(
|
| 161 |
+
val_dataset,
|
| 162 |
+
batch_size=opt.batch_size,
|
| 163 |
+
shuffle=False,
|
| 164 |
+
num_workers=opt.num_threads,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Initialize Trainer
|
| 168 |
+
model = Trainer(opt)
|
| 169 |
+
|
| 170 |
+
print(f'Starting training on {len(train_data)} samples...')
|
| 171 |
+
for epoch in range(opt.niter):
|
| 172 |
+
model.train()
|
| 173 |
+
epoch_start_time = time.time()
|
| 174 |
+
|
| 175 |
+
for i, data in enumerate(train_loader):
|
| 176 |
+
model.total_steps += 1
|
| 177 |
+
|
| 178 |
+
model.set_input(data)
|
| 179 |
+
model.optimize_parameters()
|
| 180 |
+
|
| 181 |
+
if model.total_steps % opt.loss_freq == 0:
|
| 182 |
+
print(
|
| 183 |
+
f'{time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())} Train loss: {model.loss:.4f} loss1: {model.loss1:.4f} loss2-cla: {model.loss2:.4f} at step: {model.total_steps} lr {model.lr}'
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
if model.total_steps >= opt.total_steps:
|
| 187 |
+
break
|
| 188 |
+
|
| 189 |
+
# Validation at the end of epoch
|
| 190 |
+
print('Running validation...')
|
| 191 |
+
model.eval()
|
| 192 |
+
with torch.no_grad():
|
| 193 |
+
val_loss = 0
|
| 194 |
+
val_acc = 0
|
| 195 |
+
count = 0
|
| 196 |
+
# Limit validation samples for speed during training
|
| 197 |
+
max_val_samples = 2000
|
| 198 |
+
|
| 199 |
+
for i, data in enumerate(val_loader):
|
| 200 |
+
model.set_input(data)
|
| 201 |
+
# Trainer's model is DataParallel, access it directly for inference
|
| 202 |
+
_, classhead = model.model(
|
| 203 |
+
model.input, model.input_ids, model.attention_mask
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
loss = nn.functional.binary_cross_entropy_with_logits(
|
| 207 |
+
classhead, model.label
|
| 208 |
+
)
|
| 209 |
+
val_loss += loss.item() * len(model.label)
|
| 210 |
+
|
| 211 |
+
preds = (torch.sigmoid(classhead) > 0.5).float()
|
| 212 |
+
val_acc += (preds == model.label).sum().item()
|
| 213 |
+
count += len(model.label)
|
| 214 |
+
|
| 215 |
+
if count >= max_val_samples:
|
| 216 |
+
break
|
| 217 |
+
|
| 218 |
+
avg_val_loss = val_loss / count
|
| 219 |
+
avg_val_acc = val_acc / count
|
| 220 |
+
print(
|
| 221 |
+
f'Epoch {epoch} | Val Loss: {avg_val_loss:.4f} | Val Acc: {avg_val_acc:.4f}'
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
# Save checkpoint
|
| 225 |
+
model.save_networks(f'epoch_{epoch}_acc_{avg_val_acc:.4f}')
|
| 226 |
+
|
| 227 |
+
if model.total_steps >= opt.total_steps:
|
| 228 |
+
print('Reached total_steps limit. Ending training.')
|
| 229 |
+
break
|
| 230 |
+
|
| 231 |
+
print('Training complete.')
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/train_aigibench.sh
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Training script for AIGIBench using C2P-CLIP
|
| 4 |
+
# This command reproduces the configuration found in the latest checkpoint.
|
| 5 |
+
|
| 6 |
+
python ./train_aigibench.py \
|
| 7 |
+
--name aigibench_full_train \
|
| 8 |
+
--arch res50 \
|
| 9 |
+
--batch_size 8 \
|
| 10 |
+
--lr 0.0001 \
|
| 11 |
+
--niter 1 \
|
| 12 |
+
--total_steps 5000 \
|
| 13 |
+
--loadSize 256 \
|
| 14 |
+
--cropSize 224 \
|
| 15 |
+
--seed 123 \
|
| 16 |
+
--clip openai/clip-vit-large-patch14 \
|
| 17 |
+
--claloss 0.5 \
|
| 18 |
+
--cates Deepfake Camera \
|
| 19 |
+
--eval_freq 200 \
|
| 20 |
+
--loss_freq 50 \
|
| 21 |
+
--lora_r 16 \
|
| 22 |
+
--lora_alpha 32 \
|
| 23 |
+
--lora_dropout 0.1 \
|
| 24 |
+
--num_threads 4
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/train_genimage.sh
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/train.py \
|
| 2 |
+
--name genimage \
|
| 3 |
+
--dataroot ./GenImage_Dataset/ \
|
| 4 |
+
--classes sdv4 \
|
| 5 |
+
--batch_size 128 \
|
| 6 |
+
--lr 0.0002 \
|
| 7 |
+
--niter 1 \
|
| 8 |
+
--total_steps 100 \
|
| 9 |
+
--loadSize 256 \
|
| 10 |
+
--cropSize 224 \
|
| 11 |
+
--seed 123 \
|
| 12 |
+
--clip ./clip-vit-large-patch14/ \
|
| 13 |
+
--claloss 8 \
|
| 14 |
+
--cates Deepfake Camera \
|
| 15 |
+
--eval_freq 100 \
|
| 16 |
+
--loss_freq 10 \
|
| 17 |
+
--lora_r 6 \
|
| 18 |
+
--lora_alpha 6 \
|
| 19 |
+
--lora_dropout 0.5
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/utils/logger.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def create_logger(log_dir, phase='train'):
|
| 11 |
+
time_str = time.strftime('%Y-%m-%d-%H-%M')
|
| 12 |
+
log_file = '{}_{}.log'.format(time_str, phase)
|
| 13 |
+
final_log_file = os.path.join(log_dir, log_file)
|
| 14 |
+
logging.basicConfig(
|
| 15 |
+
filename=str(final_log_file),
|
| 16 |
+
format='%(asctime)s %(levelname)s: %(message)s',
|
| 17 |
+
level=logging.INFO,
|
| 18 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
| 19 |
+
)
|
| 20 |
+
logger = logging.getLogger()
|
| 21 |
+
logger.setLevel(logging.INFO)
|
| 22 |
+
console = logging.StreamHandler()
|
| 23 |
+
logging.getLogger('').addHandler(console)
|
| 24 |
+
return logger
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Progbar(object):
|
| 28 |
+
"""Displays a progress bar.
|
| 29 |
+
# Arguments
|
| 30 |
+
target: Total number of steps expected, None if unknown.
|
| 31 |
+
width: Progress bar width on screen.
|
| 32 |
+
verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
|
| 33 |
+
stateful_metrics: Iterable of string names of metrics that
|
| 34 |
+
should *not* be averaged over time. Metrics in this list
|
| 35 |
+
will be displayed as-is. All others will be averaged
|
| 36 |
+
by the progbar before display.
|
| 37 |
+
interval: Minimum visual progress update interval (in seconds).
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self, target, width=30, verbose=1, interval=0.05, stateful_metrics=None
|
| 42 |
+
):
|
| 43 |
+
self.target = target
|
| 44 |
+
self.width = width
|
| 45 |
+
self.verbose = verbose
|
| 46 |
+
self.interval = interval
|
| 47 |
+
if stateful_metrics:
|
| 48 |
+
self.stateful_metrics = set(stateful_metrics)
|
| 49 |
+
else:
|
| 50 |
+
self.stateful_metrics = set()
|
| 51 |
+
|
| 52 |
+
self._dynamic_display = (
|
| 53 |
+
hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
|
| 54 |
+
) or 'ipykernel' in sys.modules
|
| 55 |
+
self._total_width = 0
|
| 56 |
+
self._seen_so_far = 0
|
| 57 |
+
self._values = collections.OrderedDict()
|
| 58 |
+
self._start = time.time()
|
| 59 |
+
self._last_update = 0
|
| 60 |
+
|
| 61 |
+
def update(self, current, values=None):
|
| 62 |
+
"""Updates the progress bar.
|
| 63 |
+
# Arguments
|
| 64 |
+
current: Index of current step.
|
| 65 |
+
values: List of tuples:
|
| 66 |
+
`(name, value_for_last_step)`.
|
| 67 |
+
If `name` is in `stateful_metrics`,
|
| 68 |
+
`value_for_last_step` will be displayed as-is.
|
| 69 |
+
Else, an average of the metric over time will be displayed.
|
| 70 |
+
"""
|
| 71 |
+
values = values or []
|
| 72 |
+
for k, v in values:
|
| 73 |
+
if k not in self.stateful_metrics:
|
| 74 |
+
if k not in self._values:
|
| 75 |
+
self._values[k] = [
|
| 76 |
+
v * (current - self._seen_so_far),
|
| 77 |
+
current - self._seen_so_far,
|
| 78 |
+
]
|
| 79 |
+
else:
|
| 80 |
+
self._values[k][0] += v * (current - self._seen_so_far)
|
| 81 |
+
self._values[k][1] += current - self._seen_so_far
|
| 82 |
+
else:
|
| 83 |
+
self._values[k] = v
|
| 84 |
+
self._seen_so_far = current
|
| 85 |
+
|
| 86 |
+
now = time.time()
|
| 87 |
+
info = ' - %.0fs' % (now - self._start)
|
| 88 |
+
if self.verbose == 1:
|
| 89 |
+
if (
|
| 90 |
+
now - self._last_update < self.interval
|
| 91 |
+
and self.target is not None
|
| 92 |
+
and current < self.target
|
| 93 |
+
):
|
| 94 |
+
return
|
| 95 |
+
|
| 96 |
+
prev_total_width = self._total_width
|
| 97 |
+
if self._dynamic_display:
|
| 98 |
+
sys.stdout.write('\b' * prev_total_width)
|
| 99 |
+
sys.stdout.write('\r')
|
| 100 |
+
else:
|
| 101 |
+
sys.stdout.write('\n')
|
| 102 |
+
|
| 103 |
+
if self.target is not None:
|
| 104 |
+
numdigits = int(np.floor(np.log10(self.target))) + 1
|
| 105 |
+
barstr = '%%%dd/%d [' % (numdigits, self.target)
|
| 106 |
+
bar = barstr % current
|
| 107 |
+
prog = float(current) / self.target
|
| 108 |
+
prog_width = int(self.width * prog)
|
| 109 |
+
if prog_width > 0:
|
| 110 |
+
bar += '=' * (prog_width - 1)
|
| 111 |
+
if current < self.target:
|
| 112 |
+
bar += '>'
|
| 113 |
+
else:
|
| 114 |
+
bar += '='
|
| 115 |
+
bar += '.' * (self.width - prog_width)
|
| 116 |
+
bar += ']'
|
| 117 |
+
else:
|
| 118 |
+
bar = '%7d/Unknown' % current
|
| 119 |
+
|
| 120 |
+
self._total_width = len(bar)
|
| 121 |
+
sys.stdout.write(bar)
|
| 122 |
+
|
| 123 |
+
if current:
|
| 124 |
+
time_per_unit = (now - self._start) / current
|
| 125 |
+
else:
|
| 126 |
+
time_per_unit = 0
|
| 127 |
+
if self.target is not None and current < self.target:
|
| 128 |
+
eta = time_per_unit * (self.target - current)
|
| 129 |
+
if eta > 3600:
|
| 130 |
+
eta_format = '%d:%02d:%02d' % (
|
| 131 |
+
eta // 3600,
|
| 132 |
+
(eta % 3600) // 60,
|
| 133 |
+
eta % 60,
|
| 134 |
+
)
|
| 135 |
+
elif eta > 60:
|
| 136 |
+
eta_format = '%d:%02d' % (eta // 60, eta % 60)
|
| 137 |
+
else:
|
| 138 |
+
eta_format = '%ds' % eta
|
| 139 |
+
|
| 140 |
+
info = ' - ETA: %s' % eta_format
|
| 141 |
+
else:
|
| 142 |
+
if time_per_unit >= 1:
|
| 143 |
+
info += ' %.0fs/step' % time_per_unit
|
| 144 |
+
elif time_per_unit >= 1e-3:
|
| 145 |
+
info += ' %.0fms/step' % (time_per_unit * 1e3)
|
| 146 |
+
else:
|
| 147 |
+
info += ' %.0fus/step' % (time_per_unit * 1e6)
|
| 148 |
+
|
| 149 |
+
for k in self._values:
|
| 150 |
+
info += ' - %s:' % k
|
| 151 |
+
if isinstance(self._values[k], list):
|
| 152 |
+
avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
|
| 153 |
+
if abs(avg) > 1e-3:
|
| 154 |
+
info += ' %.4f' % avg
|
| 155 |
+
else:
|
| 156 |
+
info += ' %.4e' % avg
|
| 157 |
+
else:
|
| 158 |
+
info += ' %s' % self._values[k]
|
| 159 |
+
|
| 160 |
+
self._total_width += len(info)
|
| 161 |
+
if prev_total_width > self._total_width:
|
| 162 |
+
info += ' ' * (prev_total_width - self._total_width)
|
| 163 |
+
|
| 164 |
+
if self.target is not None and current >= self.target:
|
| 165 |
+
info += '\n'
|
| 166 |
+
|
| 167 |
+
sys.stdout.write(info)
|
| 168 |
+
sys.stdout.flush()
|
| 169 |
+
|
| 170 |
+
elif self.verbose == 2:
|
| 171 |
+
if self.target is None or current >= self.target:
|
| 172 |
+
for k in self._values:
|
| 173 |
+
info += ' - %s:' % k
|
| 174 |
+
avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
|
| 175 |
+
if avg > 1e-3:
|
| 176 |
+
info += ' %.4f' % avg
|
| 177 |
+
else:
|
| 178 |
+
info += ' %.4e' % avg
|
| 179 |
+
info += '\n'
|
| 180 |
+
|
| 181 |
+
sys.stdout.write(info)
|
| 182 |
+
sys.stdout.flush()
|
| 183 |
+
|
| 184 |
+
self._last_update = now
|
| 185 |
+
|
| 186 |
+
def add(self, n, values=None):
|
| 187 |
+
self.update(self._seen_so_far + n, values)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class AverageMeter(object):
|
| 191 |
+
"""Computes and stores the average and current value"""
|
| 192 |
+
|
| 193 |
+
def __init__(self):
|
| 194 |
+
self.reset()
|
| 195 |
+
|
| 196 |
+
def reset(self):
|
| 197 |
+
self.val = 0
|
| 198 |
+
self.avg = 0
|
| 199 |
+
self.sum = 0
|
| 200 |
+
self.count = 0
|
| 201 |
+
|
| 202 |
+
def update(self, val, n=1):
|
| 203 |
+
self.val = val
|
| 204 |
+
self.sum += val * n
|
| 205 |
+
self.count += n
|
| 206 |
+
self.avg = self.sum / (0.0001 + self.count)
|
| 207 |
+
|
| 208 |
+
def __str__(self):
|
| 209 |
+
"""String representation for logging"""
|
| 210 |
+
# for values that should be recorded exactly e.g. iteration number
|
| 211 |
+
if self.count == 0:
|
| 212 |
+
return str(self.val)
|
| 213 |
+
# for stats
|
| 214 |
+
return '%.4f (%.4f)' % (self.val, self.avg)
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/utils/util.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def mkdirs(paths):
|
| 8 |
+
if isinstance(paths, list) and not isinstance(paths, str):
|
| 9 |
+
for path in paths:
|
| 10 |
+
mkdir(path)
|
| 11 |
+
else:
|
| 12 |
+
mkdir(paths)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def mkdir(path):
|
| 16 |
+
if not os.path.exists(path):
|
| 17 |
+
os.makedirs(path)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def unnormalize(tens, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
|
| 21 |
+
# assume tensor of shape NxCxHxW
|
| 22 |
+
return (
|
| 23 |
+
tens * torch.Tensor(std)[None, :, None, None]
|
| 24 |
+
+ torch.Tensor(mean)[None, :, None, None]
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Logger(object):
|
| 29 |
+
"""Log stdout messages."""
|
| 30 |
+
|
| 31 |
+
def __init__(self, outfile):
|
| 32 |
+
self.terminal = sys.stdout
|
| 33 |
+
self.log = open(outfile, 'a')
|
| 34 |
+
sys.stdout = self
|
| 35 |
+
|
| 36 |
+
def write(self, message):
|
| 37 |
+
self.terminal.write(message)
|
| 38 |
+
self.log.write(message)
|
| 39 |
+
|
| 40 |
+
def flush(self):
|
| 41 |
+
self.terminal.flush()
|
| 42 |
+
|
| 43 |
+
def isatty(self):
|
| 44 |
+
return self.terminal.isatty()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def printSet(set_str):
|
| 48 |
+
set_str = str(set_str)
|
| 49 |
+
num = len(set_str)
|
| 50 |
+
print('=' * num * 3)
|
| 51 |
+
print(' ' * num + set_str)
|
| 52 |
+
print('=' * num * 3)
|
detector_codes/C2P-CLIP-DeepfakeDetection-main/validate.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from data import create_dataloader
|
| 4 |
+
from sklearn.metrics import accuracy_score, average_precision_score
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def validate(model, opt):
|
| 8 |
+
data_loader = create_dataloader(opt)
|
| 9 |
+
|
| 10 |
+
with torch.no_grad():
|
| 11 |
+
y_true, y_pred = [], []
|
| 12 |
+
for path, img, text, input_ids, attention_mask, label in data_loader:
|
| 13 |
+
y_pred.extend(
|
| 14 |
+
model(img.cuda(), None, None, cla=True).sigmoid().flatten().tolist()
|
| 15 |
+
)
|
| 16 |
+
y_true.extend(label.flatten().tolist())
|
| 17 |
+
|
| 18 |
+
y_true, y_pred = np.array(y_true), np.array(y_pred)
|
| 19 |
+
r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > 0.5)
|
| 20 |
+
f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > 0.5)
|
| 21 |
+
acc = accuracy_score(y_true, y_pred > 0.5)
|
| 22 |
+
ap = average_precision_score(y_true, y_pred)
|
| 23 |
+
return acc, ap, r_acc, f_acc, y_true, y_pred
|
detector_codes/C2P-DINOv2-main/dataset.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
from torchvision import transforms
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class AIGIBenchDataset(Dataset):
|
| 7 |
+
def __init__(self, hf_data, transform=None):
|
| 8 |
+
self.hf_data = hf_data
|
| 9 |
+
self.transform = transform
|
| 10 |
+
|
| 11 |
+
def __len__(self):
|
| 12 |
+
return len(self.hf_data)
|
| 13 |
+
|
| 14 |
+
def __getitem__(self, idx):
|
| 15 |
+
item = self.hf_data[idx]
|
| 16 |
+
image = item['image'].convert('RGB')
|
| 17 |
+
label = item['label'] # 0 for Real, 1 for Fake
|
| 18 |
+
|
| 19 |
+
if self.transform:
|
| 20 |
+
image = self.transform(image)
|
| 21 |
+
|
| 22 |
+
return image, torch.tensor(label, dtype=torch.float32)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_train_transforms(size=224):
|
| 26 |
+
return transforms.Compose(
|
| 27 |
+
[
|
| 28 |
+
transforms.Resize(size + 32),
|
| 29 |
+
transforms.RandomCrop(size),
|
| 30 |
+
transforms.RandomHorizontalFlip(),
|
| 31 |
+
transforms.ToTensor(),
|
| 32 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 33 |
+
]
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_val_transforms(size=224):
|
| 38 |
+
return transforms.Compose(
|
| 39 |
+
[
|
| 40 |
+
transforms.Resize(size + 32),
|
| 41 |
+
transforms.CenterCrop(size),
|
| 42 |
+
transforms.ToTensor(),
|
| 43 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 44 |
+
]
|
| 45 |
+
)
|