diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..8a800070aa0dfe29177be4c4bc0a591c324ec2c9 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +*.png filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d6fbfdaeecf824cea79b62ba212f81e4fcbe9e8a --- /dev/null +++ b/.gitignore @@ -0,0 +1,14 @@ +__pycache__ + +/.vscode +/config +/datasets +/outputs +/runs +/weights +/logs +/tmp + +x.py +y.py +z.py diff --git a/.project-root b/.project-root new file mode 100644 index 0000000000000000000000000000000000000000..29575ae1d00627ae11eb501608ce64b5b16a510c --- /dev/null +++ b/.project-root @@ -0,0 +1 @@ +# Do not remove, this file is used by the project to determine the root of the project diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2a7c37c99286ffa4a0fd09c7b823e33696496b68 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Andy + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d1721a17fcf45bcc120f7d1294be3f2265e3252e --- /dev/null +++ b/README.md @@ -0,0 +1,157 @@ +# Deepfake Detection that Generalizes Across Benchmarks (WACV 2026) + +[![arXiv Badge](https://img.shields.io/badge/arXiv-B31B1B?logo=arxiv&logoColor=FFF)](https://arxiv.org/abs/2508.06248) +[![Hugging Face Badge](https://img.shields.io/badge/Hugging%20Face-FFD21E?logo=huggingface&logoColor=000)](https://huggingface.co/collections/yermandy/gend) + +This is the official repository for the paper: + +**[Deepfake Detection that Generalizes Across Benchmarks](https://arxiv.org/abs/2508.06248)**. + +### Abstract + +> The generalization of deepfake detectors to unseen manipulation techniques remains a challenge for practical deployment. Although many approaches adapt foundation models by introducing significant architectural complexity, this work demonstrates that robust generalization is achievable through a parameter-efficient adaptation of one of the foundational pre-trained vision encoders. The proposed method, GenD, fine-tunes only the Layer Normalization parameters (0.03% of the total) and enhances generalization by enforcing a hyperspherical feature manifold using L2 normalization and metric learning on it. +> +> We conducted an extensive evaluation on 14 benchmark datasets spanning from 2019 to 2025. The proposed method achieves state-of-the-art performance, outperforming more complex, recent approaches in average cross-dataset AUROC. 
Our analysis yields two primary findings for the field: 1) training on paired real-fake data from the same source video is essential for mitigating shortcut learning and improving generalization, and 2) detection difficulty on academic datasets has not strictly increased over time, with models trained on older, diverse datasets showing strong generalization capabilities. +> +> This work delivers a computationally efficient and reproducible method, proving that state-of-the-art generalization is attainable by making targeted, minimal changes to a pre-trained foundational image encoder model. + +## Inference using Hugging Face transformers + +This example shows how to run inference with the pretrained GenD model from Hugging Face with no dependencies other than `torch` and `transformers`. It expects input images that have already been preprocessed by the face detector (`detector.py`). + +### Minimal dependencies + +``` bash +conda create --name GenD python=3.12 uv -y +conda activate GenD +uv pip install torch==2.8.0 +uv pip install torchvision==0.23.0 +uv pip install transformers==4.56.2 +``` + +### Inference with transformers + +``` python +import requests +import torch +from PIL import Image + +from src.hf.modeling_gend import GenD + +# Other models can be found in https://huggingface.co/collections/yermandy/gend: +# - yermandy/GenD_CLIP_L_14 +# - yermandy/GenD_PE_L +# - yermandy/GenD_DINOv3_L +model = GenD.from_pretrained("yermandy/GenD_CLIP_L_14") + +urls = [ + "https://github.com/yermandy/deepfake-detection/blob/main/datasets/FF/DF/000_003/000.png?raw=true", + "https://github.com/yermandy/deepfake-detection/blob/main/datasets/FF/real/000/000.png?raw=true", +] +images = [Image.open(requests.get(url, stream=True).raw) for url in urls] +tensors = torch.stack([model.feature_extractor.preprocess(img) for img in images]) +logits = model(tensors) +probs = logits.softmax(dim=-1) + +print(probs) +``` + +## Training + +### Set up environment + +``` bash +conda create --name GenD python=3.12 uv -y 
+conda activate GenD +uv pip install -r requirements.txt +``` + +### Minimal example without external data + +#### Training example + +Examine `src/exp/examples.py`: each experiment name is defined as a key, and its value overrides the default configuration of the `Config` object from `src/config.py`. For example, try running the `example-training` experiment: + +``` bash +python run_exp.py example-training +``` + +#### Test example after the model is trained + +``` bash +python run_exp.py example-test --from_exp example-training --test +``` + +Alternatively, you can try inference using one of our released models from Hugging Face: + +``` bash +python run_exp.py GenD_CLIP--CDFv2-example --test +python run_exp.py GenD_PE--CDFv2-example --test +python run_exp.py GenD_DINO--CDFv2-example --test +``` + +### Full training + +To fully train the model, you need to download datasets, preprocess them, and create files with paths to the images. + +The training entry will be similar to the minimal example above. + +All experiments (configs) from the paper are stored in the `src/exp` folder. + +#### Prepare the dataset + +Taking the [FaceForensics++](https://github.com/ondyari/FaceForensics) dataset as an example, follow these steps: + +1. Download the dataset first from the [official source](https://github.com/ondyari/FaceForensics). The root of this dataset is `./FaceForensics` + +2. Preprocess the dataset using the `detector.py` script: + +``` bash +python detector.py -i FaceForensics/manipulated_sequences/Deepfakes/c23/videos/ --mask_folder FaceForensics/masks/manipulated_sequences/Deepfakes/masks/videos/ -m at_least -n 32 -o datasets/FF/DF/ --det_thres 0.1 -s 1.3 --target_size none +``` + +Repeat the process for other manipulation methods and real videos. 
After processing everything, you will get a similar structure: + +``` bash +datasets +└── FF + ├── DF + │ └── 000_003 + │ ├── 025.png + │ └── 038.png + ├── F2F + │ └── 000_003 + │ ├── 019.png + │ └── 029.png + ├── FS + │ └── 000_003 + │ ├── 019.png + │ └── 029.png + ├── NT + │ └── 000_003 + │ ├── 019.png + │ └── 029.png + └── real + └── 000 + ├── 025.png + └── 038.png +``` + +3. Create files with paths to images similar to the ones in `config/datasets` directory. It can be done using: + +``` bash +find datasets/FF/DF/* -type f | sort > config/datasets/FF/DF.txt +``` + +We manage links to files using `src/utils/files.py`. + +### Cite + +``` bibtex +@article{yermakov2025deepfake, + title={Deepfake Detection that Generalizes Across Benchmarks}, + author={Yermakov, Andrii and Cech, Jan and Matas, Jiri and Fritz, Mario}, + journal={arXiv preprint arXiv:2508.06248}, + year={2025} +} +``` diff --git a/config/datasets/CDFv2/test/Celeb-real.txt b/config/datasets/CDFv2/test/Celeb-real.txt new file mode 100644 index 0000000000000000000000000000000000000000..40d120755cf15e4c9f1a45ead8da53aee95f23b0 --- /dev/null +++ b/config/datasets/CDFv2/test/Celeb-real.txt @@ -0,0 +1,4 @@ +datasets/CDFv2/Celeb-real/id0_0000/045.png +datasets/CDFv2/Celeb-real/id0_0000/030.png +datasets/CDFv2/Celeb-real/id0_0000/015.png +datasets/CDFv2/Celeb-real/id0_0000/000.png \ No newline at end of file diff --git a/config/datasets/CDFv2/test/Celeb-synthesis.txt b/config/datasets/CDFv2/test/Celeb-synthesis.txt new file mode 100644 index 0000000000000000000000000000000000000000..fdf5540f9cd82dc0747f2db5e72a9532a996a94f --- /dev/null +++ b/config/datasets/CDFv2/test/Celeb-synthesis.txt @@ -0,0 +1,4 @@ +datasets/CDFv2/Celeb-synthesis/id0_id1_0000/000.png +datasets/CDFv2/Celeb-synthesis/id0_id1_0000/045.png +datasets/CDFv2/Celeb-synthesis/id0_id1_0000/030.png +datasets/CDFv2/Celeb-synthesis/id0_id1_0000/015.png \ No newline at end of file diff --git a/config/datasets/CDFv2/test/YouTube-real.txt 
b/config/datasets/CDFv2/test/YouTube-real.txt new file mode 100644 index 0000000000000000000000000000000000000000..805c7b42a274fd0d3401a13fdd23d5ebc3afe863 --- /dev/null +++ b/config/datasets/CDFv2/test/YouTube-real.txt @@ -0,0 +1,4 @@ +datasets/CDFv2/YouTube-real/00000/000.png +datasets/CDFv2/YouTube-real/00000/014.png +datasets/CDFv2/YouTube-real/00000/028.png +datasets/CDFv2/YouTube-real/00000/043.png \ No newline at end of file diff --git a/config/datasets/FF/test/DF.txt b/config/datasets/FF/test/DF.txt new file mode 100644 index 0000000000000000000000000000000000000000..6681b096fdc337e82817400312e3cc9381f3705f --- /dev/null +++ b/config/datasets/FF/test/DF.txt @@ -0,0 +1,4 @@ +datasets/FF/DF/000_003/000.png +datasets/FF/DF/000_003/012.png +datasets/FF/DF/000_003/025.png +datasets/FF/DF/000_003/038.png diff --git a/config/datasets/FF/test/F2F.txt b/config/datasets/FF/test/F2F.txt new file mode 100644 index 0000000000000000000000000000000000000000..9f88877c3c1fa3cfa3e4103bbfbf7ba89735f288 --- /dev/null +++ b/config/datasets/FF/test/F2F.txt @@ -0,0 +1,4 @@ +datasets/FF/F2F/000_003/000.png +datasets/FF/F2F/000_003/009.png +datasets/FF/F2F/000_003/019.png +datasets/FF/F2F/000_003/029.png diff --git a/config/datasets/FF/test/FS.txt b/config/datasets/FF/test/FS.txt new file mode 100644 index 0000000000000000000000000000000000000000..992d26e9c95ad8c567d34909e8233f8427820cd5 --- /dev/null +++ b/config/datasets/FF/test/FS.txt @@ -0,0 +1,4 @@ +datasets/FF/FS/000_003/000.png +datasets/FF/FS/000_003/009.png +datasets/FF/FS/000_003/019.png +datasets/FF/FS/000_003/029.png diff --git a/config/datasets/FF/test/NT.txt b/config/datasets/FF/test/NT.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca05e64910dbaee70ab6783d976fb6e80176588b --- /dev/null +++ b/config/datasets/FF/test/NT.txt @@ -0,0 +1,4 @@ +datasets/FF/NT/000_003/000.png +datasets/FF/NT/000_003/009.png +datasets/FF/NT/000_003/019.png +datasets/FF/NT/000_003/029.png diff --git 
a/config/datasets/FF/test/real.txt b/config/datasets/FF/test/real.txt new file mode 100644 index 0000000000000000000000000000000000000000..2e3edd961751760da1b365df339b626aa89c35dd --- /dev/null +++ b/config/datasets/FF/test/real.txt @@ -0,0 +1,4 @@ +datasets/FF/real/000/000.png +datasets/FF/real/000/012.png +datasets/FF/real/000/025.png +datasets/FF/real/000/038.png diff --git a/datasets/CDFv2/Celeb-real/id0_0000/000.png b/datasets/CDFv2/Celeb-real/id0_0000/000.png new file mode 100644 index 0000000000000000000000000000000000000000..14e0711a356cbbd35e0e6a892ca8c89357cede15 --- /dev/null +++ b/datasets/CDFv2/Celeb-real/id0_0000/000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33a652cb6ad545d41465a978ab4bb02137380db1747a1a460d193a3e0ecd4db6 +size 51807 diff --git a/datasets/CDFv2/Celeb-real/id0_0000/015.png b/datasets/CDFv2/Celeb-real/id0_0000/015.png new file mode 100644 index 0000000000000000000000000000000000000000..6395cefa38741d424a9ec50d55487f8370599758 --- /dev/null +++ b/datasets/CDFv2/Celeb-real/id0_0000/015.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cca6e95d080ccafbd35709c0e6ce60a12b415c7f97cf40ac4c1edb7fba441e4f +size 55150 diff --git a/datasets/CDFv2/Celeb-real/id0_0000/030.png b/datasets/CDFv2/Celeb-real/id0_0000/030.png new file mode 100644 index 0000000000000000000000000000000000000000..4cd77079722a7c94019e88e8c7e0d9f1ad4d00cd --- /dev/null +++ b/datasets/CDFv2/Celeb-real/id0_0000/030.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:20d6775e831ef3bab80cd928c930b0916f09b933814fe0d2b1d6e51b0106c77a +size 56317 diff --git a/datasets/CDFv2/Celeb-real/id0_0000/045.png b/datasets/CDFv2/Celeb-real/id0_0000/045.png new file mode 100644 index 0000000000000000000000000000000000000000..90460f2ad1ec54795b568d95607485b2dfda585f --- /dev/null +++ b/datasets/CDFv2/Celeb-real/id0_0000/045.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:6ca1683cfe93a01ac7800de0efa5f82c38abb5b17582c5708671de567958e08e +size 57774 diff --git a/datasets/CDFv2/Celeb-synthesis/id0_id1_0000/000.png b/datasets/CDFv2/Celeb-synthesis/id0_id1_0000/000.png new file mode 100644 index 0000000000000000000000000000000000000000..f3abcd7156ddf0ab1cb3a241c0271d36888d7449 --- /dev/null +++ b/datasets/CDFv2/Celeb-synthesis/id0_id1_0000/000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cae26283688b2e1855b75b922f75e0945dd29f1669e5c399b9e0f5bc75a4700c +size 51542 diff --git a/datasets/CDFv2/Celeb-synthesis/id0_id1_0000/015.png b/datasets/CDFv2/Celeb-synthesis/id0_id1_0000/015.png new file mode 100644 index 0000000000000000000000000000000000000000..b76b4146181b9924938e30819a1d592f56af8728 --- /dev/null +++ b/datasets/CDFv2/Celeb-synthesis/id0_id1_0000/015.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:708f825b0fa8403e5db2966d96e4be4b18f4d1f55ab7352acdc2bd2d7542ee5b +size 54590 diff --git a/datasets/CDFv2/Celeb-synthesis/id0_id1_0000/030.png b/datasets/CDFv2/Celeb-synthesis/id0_id1_0000/030.png new file mode 100644 index 0000000000000000000000000000000000000000..a6f8a81521e789f1460a3a4215db6f8d4f2dfe2f --- /dev/null +++ b/datasets/CDFv2/Celeb-synthesis/id0_id1_0000/030.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5693e8c8469630a8699f6cc4b8a51edd4297d31a956a14deba1a1b7796230ec4 +size 54016 diff --git a/datasets/CDFv2/Celeb-synthesis/id0_id1_0000/045.png b/datasets/CDFv2/Celeb-synthesis/id0_id1_0000/045.png new file mode 100644 index 0000000000000000000000000000000000000000..70a62837a79780d5f022026c16a57cf9df900ee0 --- /dev/null +++ b/datasets/CDFv2/Celeb-synthesis/id0_id1_0000/045.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90c2250095d7331e64caa42d04d25cac5881288bbfa7013e68302cbe758ade06 +size 56115 diff --git a/datasets/CDFv2/YouTube-real/00000/000.png b/datasets/CDFv2/YouTube-real/00000/000.png new file mode 
100644 index 0000000000000000000000000000000000000000..de7707be77a4a29b665c5896f9e9ae47b73e021b --- /dev/null +++ b/datasets/CDFv2/YouTube-real/00000/000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c959f785348832e39685a4918904b3a145ff475ce69eef55624708b423473218 +size 52062 diff --git a/datasets/CDFv2/YouTube-real/00000/014.png b/datasets/CDFv2/YouTube-real/00000/014.png new file mode 100644 index 0000000000000000000000000000000000000000..5eaab1191aac0a36b1946b38f769b7fe5bcbdd5f --- /dev/null +++ b/datasets/CDFv2/YouTube-real/00000/014.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:775bf1f48de7319a4557071844ad3fb587bc048f0ca1d8a52b696f4167476996 +size 58779 diff --git a/datasets/CDFv2/YouTube-real/00000/028.png b/datasets/CDFv2/YouTube-real/00000/028.png new file mode 100644 index 0000000000000000000000000000000000000000..83a7c42aace0b9b52cfa894bd75df79bc47ae874 --- /dev/null +++ b/datasets/CDFv2/YouTube-real/00000/028.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7eb0e9af30cbe4e2cab7c6df54a434405470eca81b416e6f835e65fac8e1fc6 +size 58978 diff --git a/datasets/CDFv2/YouTube-real/00000/043.png b/datasets/CDFv2/YouTube-real/00000/043.png new file mode 100644 index 0000000000000000000000000000000000000000..de3fff6de12ab4b4a392d150cdb91bf752284778 --- /dev/null +++ b/datasets/CDFv2/YouTube-real/00000/043.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37adad8dd2df6a8d82f5ea530ab9db464a81037125180e7130ee3fe1bc1ac567 +size 59221 diff --git a/datasets/FF/DF/000_003/000.png b/datasets/FF/DF/000_003/000.png new file mode 100644 index 0000000000000000000000000000000000000000..a5b2fce9ead65a4075c469f4334267590e46a689 --- /dev/null +++ b/datasets/FF/DF/000_003/000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2c605732c6b2152320a986173c1a3dc7544938e948aa46444340831b8018060 +size 82576 diff --git a/datasets/FF/DF/000_003/012.png 
b/datasets/FF/DF/000_003/012.png new file mode 100644 index 0000000000000000000000000000000000000000..b2baddbadb252e78ece367171f55a477a8c7f16c --- /dev/null +++ b/datasets/FF/DF/000_003/012.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91b2bf24d9c3685c28559a5ac8c91b94bbe5b841a563b961f4a7679f40e8f4b8 +size 83752 diff --git a/datasets/FF/DF/000_003/025.png b/datasets/FF/DF/000_003/025.png new file mode 100644 index 0000000000000000000000000000000000000000..fc62ff04e03217af96597d12fa695d1bfacf03ed --- /dev/null +++ b/datasets/FF/DF/000_003/025.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca387094072caa412a0f683d909084580896f2c40dd669baf41c04166efffa95 +size 81980 diff --git a/datasets/FF/DF/000_003/038.png b/datasets/FF/DF/000_003/038.png new file mode 100644 index 0000000000000000000000000000000000000000..00189bc6edb88b7b2dbdd0d46d9083304cdc2aa4 --- /dev/null +++ b/datasets/FF/DF/000_003/038.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:781d4ab6041c8b37c75b7d831f02944cae11b41fb848d5e34c577a142cb3a1a0 +size 82167 diff --git a/datasets/FF/F2F/000_003/000.png b/datasets/FF/F2F/000_003/000.png new file mode 100644 index 0000000000000000000000000000000000000000..0b31206b72bb1b8360d50f898419942adb10506d --- /dev/null +++ b/datasets/FF/F2F/000_003/000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8e8c25ddc42909f3b82aedf939ca5fcf3924d1aeb54ad9dbb507ed97d050692 +size 82653 diff --git a/datasets/FF/F2F/000_003/009.png b/datasets/FF/F2F/000_003/009.png new file mode 100644 index 0000000000000000000000000000000000000000..1aca648b510437bb12c2c38301e4d3d64cee0e31 --- /dev/null +++ b/datasets/FF/F2F/000_003/009.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8ddccb64c403bb4cd0245f6826699bd53b02c3bc9b5435d3134b4d9aba019d85 +size 82061 diff --git a/datasets/FF/F2F/000_003/019.png b/datasets/FF/F2F/000_003/019.png new file mode 100644 
index 0000000000000000000000000000000000000000..2b89e5246e689cbc003b25f0672b45150e9d09fb --- /dev/null +++ b/datasets/FF/F2F/000_003/019.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1a08d87e07a821d0972e893e5b7fd7777e901d7b0433d1a1acb21862f349f5f +size 82045 diff --git a/datasets/FF/F2F/000_003/029.png b/datasets/FF/F2F/000_003/029.png new file mode 100644 index 0000000000000000000000000000000000000000..bae6a9922cace1a8a76510ca605e883577fe3e58 --- /dev/null +++ b/datasets/FF/F2F/000_003/029.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da6f3aa04c3e0d6155dacdb4003cfb507f5b600b087e57d8184387c1fbbc76eb +size 81992 diff --git a/datasets/FF/FS/000_003/000.png b/datasets/FF/FS/000_003/000.png new file mode 100644 index 0000000000000000000000000000000000000000..b6b67e94d27d034aa5e51fdae22208464ce97e7a --- /dev/null +++ b/datasets/FF/FS/000_003/000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57b282095be875d9b161360d43cbd4b04d0536c093bb967e6bc6661baf7db361 +size 82267 diff --git a/datasets/FF/FS/000_003/009.png b/datasets/FF/FS/000_003/009.png new file mode 100644 index 0000000000000000000000000000000000000000..51c3e1d2e2a8bfbfe4d124948bd715ad41997828 --- /dev/null +++ b/datasets/FF/FS/000_003/009.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8481e3a732f811a832598bd39ffb426cd966f4b16478e603c5ec9f5351b732ff +size 81058 diff --git a/datasets/FF/FS/000_003/019.png b/datasets/FF/FS/000_003/019.png new file mode 100644 index 0000000000000000000000000000000000000000..6e4a7c5163ce62a67d0da8ae8abee0e9fd079dba --- /dev/null +++ b/datasets/FF/FS/000_003/019.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bc24297d7825d45f663960f9bc0b7e215e7b67312ea9b00f495d289df432b32b +size 81307 diff --git a/datasets/FF/FS/000_003/029.png b/datasets/FF/FS/000_003/029.png new file mode 100644 index 
0000000000000000000000000000000000000000..67521039dbd541736ac7868cb910c8b064e6017a --- /dev/null +++ b/datasets/FF/FS/000_003/029.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:caae8b99b9f0016e7e3c2bbe9500422687bbdcd4ae2305f4de8ff7a05c5bf587 +size 80480 diff --git a/datasets/FF/NT/000_003/000.png b/datasets/FF/NT/000_003/000.png new file mode 100644 index 0000000000000000000000000000000000000000..49a41d7e057a258c4517e3e9d05b79b262677f1d --- /dev/null +++ b/datasets/FF/NT/000_003/000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca2b6ce7ea30df50d86855adc4c2dd11b1ba0848153746303e73cb58d7035290 +size 81520 diff --git a/datasets/FF/NT/000_003/009.png b/datasets/FF/NT/000_003/009.png new file mode 100644 index 0000000000000000000000000000000000000000..bd4ae1a9a87db5b52faee3262e60c165bc9893d8 --- /dev/null +++ b/datasets/FF/NT/000_003/009.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:616d27b48e463eeadd9f10a3b6294b20d28b71393cbca66c7e5d91c626644f01 +size 80488 diff --git a/datasets/FF/NT/000_003/019.png b/datasets/FF/NT/000_003/019.png new file mode 100644 index 0000000000000000000000000000000000000000..243355e0901f3828517601664553e5354153e6d9 --- /dev/null +++ b/datasets/FF/NT/000_003/019.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec4ac46bf3f06d5de47af5bb7e7af691fc3dd46199f8d5216cdbabe65f311faf +size 80175 diff --git a/datasets/FF/NT/000_003/029.png b/datasets/FF/NT/000_003/029.png new file mode 100644 index 0000000000000000000000000000000000000000..a98fc5631d4502fadf802d2bce53b2896f501f52 --- /dev/null +++ b/datasets/FF/NT/000_003/029.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37a3fd9de8ea981445ca036dfc4ba44d8d1ff42163e14972563e6c5485b57af2 +size 80145 diff --git a/datasets/FF/real/000/000.png b/datasets/FF/real/000/000.png new file mode 100644 index 
def max_spread_permutation_pq(N, start=0):
    """
    Order the integers 0..N-1 greedily so that they spread out as fast as possible.

    The permutation opens with `start`. Every following element is the
    not-yet-selected index whose distance to its nearest already-selected
    index is largest. Candidates are kept in a heap; entries that became
    outdated are discarded lazily when they are popped.

    Args:
        N (int): Length of the permutation.
        start (int): The first element in the permutation (default 0).

    Returns:
        List[int]: A list representing the permutation.

    Raises:
        ValueError: If `start` is outside [0, N-1] (always the case for N == 0).
    """
    if not (0 <= start < N):
        raise ValueError("`start` must be in the range [0, N-1]")

    order = [start]

    # nearest[i]: distance from the unselected index i to the closest selected index
    nearest = {idx: abs(idx - start) for idx in range(N) if idx != start}

    # heapq is a min-heap, so distances are stored negated. Tuples compare
    # element-wise, hence equal distances are resolved in favour of the smaller index.
    queue = [(-gap, idx) for idx, gap in nearest.items()]
    heapq.heapify(queue)

    while queue:
        neg_gap, pick = heapq.heappop(queue)

        # Skip outdated entries: a smaller distance was recorded for `pick`
        # after this entry had been pushed.
        if nearest.get(pick, -1) != -neg_gap:
            continue

        order.append(pick)
        del nearest[pick]

        # The new selection can only shrink the nearest-distance of the others
        for idx in list(nearest):
            gap = abs(idx - pick)
            if gap < nearest[idx]:
                nearest[idx] = gap
                heapq.heappush(queue, (-gap, idx))

    return order
def get_video_frames_generator(
    source_path: str,
    mask_path: str | None,
    stride: int = 1,
    num_frames: int = 32,
    mode: str = "at_least",
):
    """
    Lazily yield selected frames of a video, optionally paired with mask frames.

    Args:
        source_path: Path of the video to read.
        mask_path: Path of a mask video that is read at the same frame indices
            as the source, or None when there is no mask.
        stride: Step between frame indices; only used when mode is "fixed_stride".
        num_frames: Number of frame indices to select; only used when mode is
            "fixed_num_frames".
        mode: Frame selection strategy:
            - "fixed_num_frames": `num_frames` indices evenly spaced over the video.
            - "fixed_stride": every `stride`-th index, starting at 0.
            - "at_least": every index, ordered by `max_spread_permutation_pq`
              starting from the middle of the video, so the consumer decides
              when it has collected enough frames and stops iterating.

    Yields:
        Tuples `(frame, frame_id, mask)`. `mask` is None when `mask_path` is None.
        Frames (or frame/mask pairs) that cannot be read are skipped with a warning.

    Raises:
        ValueError: If `mode` is not one of the supported values. It is raised
            when iteration starts, and only if the video(s) could be opened.
    """
    video = cv2.VideoCapture(source_path)
    mask_video = None

    # try/finally guarantees both captures are released on every exit path:
    # early returns, exceptions, and a consumer that stops iterating early
    # (closing a generator raises GeneratorExit at the `yield`).
    try:
        if not video.isOpened():
            print(f"Warning: Video {source_path} cannot be opened!")
            return

        video_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        total_frames = video_frames

        if mask_path is not None:
            mask_video = cv2.VideoCapture(mask_path)

            if not mask_video.isOpened():
                print(f"Warning: Mask video {mask_path} cannot be opened!")
                return

            mask_frames = int(mask_video.get(cv2.CAP_PROP_FRAME_COUNT))

            if video_frames != mask_frames:
                print(
                    f"Warning: {source_path} and {mask_path} have different number of frames {video_frames} vs {mask_frames}!"
                )

            # Only indices that exist in both videos can be paired
            total_frames = min(video_frames, mask_frames)

        if mode not in ("fixed_num_frames", "fixed_stride", "at_least"):
            raise ValueError(f"Invalid mode: {mode}. Choose 'fixed_num_frames', 'fixed_stride', or 'at_least'.")

        if total_frames <= 0:
            # Nothing to read. Without this guard "at_least" would call
            # max_spread_permutation_pq(0, start=0), which raises ValueError,
            # and "fixed_num_frames" would produce negative frame indices.
            print(f"Warning: No frames to read in {source_path}. Skipping.")
            return

        if mode == "fixed_num_frames":
            # Evenly spaced indices covering the first and the last frame
            frame_ids = np.linspace(0, total_frames - 1, num_frames, endpoint=True, dtype=int)
        elif mode == "fixed_stride":
            # Every `stride`-th index
            frame_ids = np.arange(0, total_frames, stride, dtype=int)
        else:
            # "at_least": all indices, most spread-out ones first
            frame_ids = max_spread_permutation_pq(total_frames, start=total_frames // 2)

        for frame_id in frame_ids:
            # Seek to the requested frame and read it
            video.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
            success, frame = video.read()
            if not success:
                print(f"Warning: Failed to read frame {frame_id} of {source_path}. Skipping.")
                continue

            mask = None
            if mask_video is not None:
                mask_video.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
                success_mask, mask = mask_video.read()
                if not success_mask:
                    print(f"Warning: Failed to read mask frame {frame_id} of {mask_path}. Skipping.")
                    continue

            yield frame, frame_id, mask
    finally:
        video.release()
        if mask_video is not None:
            mask_video.release()
def align_face(
    img: np.ndarray,
    landmarks: np.ndarray,
    target_size: None | tuple = None,
    scale: float = 1.3,
    mask: np.ndarray | None = None,
):
    """
    Aligns a face based on 5-point facial landmarks (eyes, nose, mouth corners).

    The landmarks are mapped onto a fixed template with a partial affine
    transform estimated by `cv2.estimateAffinePartial2D`, and the image (and
    the mask, if given) is warped with that transform.

    Args:
        img: Input image containing the face.
        landmarks: 5-point facial landmarks array with shape (5, 2).
        target_size: Desired output size as (width, height) tuple. If None, a
            square size is estimated from the pairwise landmark distances,
            relative to the template, multiplied by `scale`.
        scale: Controls how much context around the face is included: the
            template is shrunk towards the image centre by a factor of
            1 / scale, so larger values leave a wider margin around the face.
        mask: Optional mask that is warped the same way as `img`
            (nearest-neighbour interpolation, so mask values are not blended).

    Returns:
        Tuple `(aligned_img, aligned_mask)`, both of size `target_size`.
        `aligned_mask` is None when no mask was passed.

    Raises:
        ValueError: If no transform can be estimated from the landmarks.
    """
    # Template landmark positions in a unit square (x, y)
    dst = np.array(
        [
            [0.34, 0.46],
            [0.66, 0.46],
            [0.5, 0.64],
            [0.37, 0.82],
            [0.63, 0.82],
        ],
        dtype=np.float32,
    )

    if target_size is None:
        # Pairwise distances between the detected landmarks (pixels)
        desired_dists = np.linalg.norm(landmarks[:, None, :] - landmarks[None, :, :], axis=-1)

        # Pairwise distances between the template landmarks (unit square)
        dst_dists = np.linalg.norm(dst[:, None, :] - dst[None, :, :], axis=-1)

        # Each pair once: upper triangle without the (zero) diagonal
        upper_triangle_indices = np.triu_indices(len(dst), k=1)
        dst_dists = dst_dists[upper_triangle_indices]
        desired_dists = desired_dists[upper_triangle_indices]

        # The mean pixel-per-unit ratio approximates the template size in pixels.
        # Plain ints are used so the size is a regular (width, height) tuple.
        approx_size = int(np.round(np.mean(desired_dists / dst_dists) * scale))
        target_size = (approx_size, approx_size)

    # Template in output pixel coordinates
    dst[:, 0] = dst[:, 0] * target_size[0]
    dst[:, 1] = dst[:, 1] * target_size[1]

    margin_rate = scale - 1
    x_margin = target_size[0] * margin_rate / 2.0
    y_margin = target_size[1] * margin_rate / 2.0

    # Shift the template into a canvas enlarged by the margins ...
    dst[:, 0] += x_margin
    dst[:, 1] += y_margin

    # ... and scale that canvas back to target_size. The net effect keeps the
    # centre fixed and scales the template by 1 / scale.
    dst[:, 0] *= target_size[0] / (target_size[0] + 2 * x_margin)
    dst[:, 1] *= target_size[1] / (target_size[1] + 2 * y_margin)

    src = landmarks.astype(np.float32)

    M = cv2.estimateAffinePartial2D(src, dst, method=cv2.LMEDS)[0]
    if M is None:
        # Fail with a clear message instead of an opaque error inside warpAffine
        raise ValueError("Could not estimate an alignment transform from the given landmarks")

    img = cv2.warpAffine(img, M, target_size, flags=cv2.INTER_LINEAR)

    if mask is not None:
        mask = cv2.warpAffine(mask, M, target_size, flags=cv2.INTER_NEAREST)

    return img, mask
Skipping.") + continue + + # Convert mask to grayscale if it's not already + mask_img = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) if len(mask.shape) == 3 else mask + + # Threshold the mask to create a binary mask + mask_img = cv2.threshold(mask_img, 1, 255, cv2.THRESH_BINARY)[1] + + # Find the face that intersects the most with the mask + best_landmarks = None + max_intersection = 0 + for i in range(len(xyxy)): + # Get the bounding box coordinates + x1, y1, x2, y2 = xyxy[i, :4].astype(int) + + # Create a mask for the face + face_mask = np.zeros_like(mask_img) + face_mask[y1:y2, x1:x2] = 255 + + # Calculate the intersection between the face mask and the provided mask + intersection = np.sum(np.logical_and(face_mask, mask_img)) + + # Update the best face if the intersection is greater than the current maximum + if intersection > max_intersection: + max_intersection = intersection + best_landmarks = landmarks[i] + + # If a face was found, use it; otherwise, skip this frame + if best_landmarks is not None: + selected_landmarks = best_landmarks + else: + print(f"No suitable face found in frame {frame_id} of {source_path} with the provided mask.") + continue + + # """ + # Select landmarks of the largest face if not using mask + if selected_landmarks is None: + areas = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1]) + idx = np.argmax(areas) + selected_landmarks = landmarks[idx] + + # Show all landmarks + # for L, B in zip(landmarks, xyxy): + # for point in L: + # cv2.circle(frame, tuple(point.astype(int)), 2, (0, 255, 0), -1) + # cv2.rectangle(frame, tuple(B[0:2].astype(int)), tuple(B[2:4].astype(int)), (0, 255, 0), 2) + + # Align the face + aligned_face, _ = align_face(frame, selected_landmarks, target_size=target_size, scale=scale) + + # Save the aligned face + os.makedirs(frame_save_path, exist_ok=True) + cv2.imwrite(frame_filename, aligned_face) + # """ + + num_saved += 1 + + if mode in ["fixed_stride", "at_least"] and num_saved >= num_frames and num_frames != -1: 
def get_output_path(source_path, input_folder, output_folder):
    """Map an input media path to its destination inside ``output_folder``.

    The first occurrence of ``input_folder`` in ``source_path`` is replaced by
    the base name of ``input_folder``, a trailing ``.mp4`` is dropped (videos
    get a directory named after the file), and the result is joined onto
    ``output_folder``.

    Args:
        source_path: Path of the video or image being processed.
        input_folder: The folder or file that was given as the input.
        output_folder: Root folder that receives the preprocessed outputs.

    Returns:
        The output path, always located below ``output_folder``.
    """
    # Replace only the first occurrence: a deeper directory that happens to
    # repeat the input folder's text must not be rewritten as well.
    relative = source_path.replace(input_folder, os.path.basename(input_folder), 1)

    # Strip ".mp4" only as the file suffix so that directory names containing
    # ".mp4" are preserved.
    if relative.endswith(".mp4"):
        relative = relative[: -len(".mp4")]

    # os.path.join throws away everything before an absolute component, which
    # would place the output outside output_folder (for an image the target
    # would even equal the source). Make the path relative first.
    separators = os.sep + (os.altsep or "")
    relative = os.path.splitdrive(relative)[1].lstrip(separators)

    return os.path.join(output_folder, relative)
def process_mixed_types(
    input_folder_or_file: str | list[str],
    input_mask_folder: None | str,
    model: RetinaFace,
    num_workers=1,
    scale=1.3,
    target_size=(256, 256),
    stride=1,
    num_frames=32,
    mode: str = "fixed_num_frames",
    output_folder: str = "outputs",
    possible_extensions: tuple[str, ...] = ("mp4", "jpg", "png", "jpeg"),
    skip_processed_videos: bool = False,
    skip_processed_frames: bool = False,
):
    """Detect, align and save faces for every video/image reachable from the input.

    The input may be a single media file, a text file listing one path per
    line, or a folder that is searched with ``find_files``. Paths ending in
    ``.mp4`` go through ``process_video``; everything else goes through
    ``process_image``. Work is distributed over a thread pool.

    NOTE(review): the annotation also allows ``list[str]``, but the value is
    passed straight to ``os.path.isfile`` and ``str.endswith``, so only a
    single path string works here.

    Args:
        input_folder_or_file: Media file, ``txt`` list of paths, or folder.
        input_mask_folder: Folder with mask videos, or None to run without masks.
        model: Face detector forwarded to the per-file workers.
        num_workers: Number of worker threads.
        scale: Forwarded to the per-file workers for face alignment.
        target_size: Output size forwarded to the per-file workers.
        stride: Frame stride forwarded to ``process_video``.
        num_frames: Frame budget forwarded to ``process_video``.
        mode: Frame-sampling mode forwarded to ``process_video``.
        output_folder: Root folder for the outputs.
        possible_extensions: File suffixes (without dot) treated as media.
        skip_processed_videos: Forwarded to ``process_video``.
        skip_processed_frames: Forwarded to both workers.
    """
    # str.endswith only accepts a str or a tuple of str.
    possible_extensions = tuple(possible_extensions)

    if os.path.isfile(input_folder_or_file):
        if input_folder_or_file.endswith(possible_extensions):
            # A single media file
            files = [input_folder_or_file]
        elif input_folder_or_file.endswith("txt"):
            # A text file with one path per line; blank lines are ignored so
            # they do not turn into empty paths.
            with open(input_folder_or_file, "r") as f:
                files = [line.strip() for line in f if line.strip()]
        else:
            # Previously `files` stayed unbound here and the next statement
            # crashed with UnboundLocalError.
            print(f"Unsupported input file type: {input_folder_or_file}")
            return
    else:
        # A folder
        files = find_files(input_folder_or_file, possible_extensions)

    if not files:
        print(f"No files found in {input_folder_or_file}")
        return

    def process(source_path):
        """Process one file; errors are reported and do not stop the batch."""
        output_path = get_output_path(source_path, input_folder_or_file, output_folder)

        if source_path.endswith(".mp4"):
            mask_path = get_mask_path(input_folder_or_file, input_mask_folder, source_path)
            try:
                return process_video(
                    source_path,
                    output_path,
                    mask_path,
                    model,
                    scale=scale,
                    target_size=target_size,
                    stride=stride,
                    num_frames=num_frames,
                    mode=mode,
                    skip_processed_videos=skip_processed_videos,
                    skip_processed_frames=skip_processed_frames,
                )
            except Exception as e:
                print(f"Error processing video {source_path}: {e}")
        else:
            try:
                return process_image(
                    source_path,
                    output_path,
                    model,
                    scale=scale,
                    target_size=target_size,
                    skip_processed_frames=skip_processed_frames,
                )
            except Exception as e:
                print(f"Error processing image {source_path}: {e}")

    files = sorted(files)  # Sort files for consistent processing
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(process, file) for file in files]
        for future in tqdm(futures, desc=f"Processing videos in {input_folder_or_file}", leave=True):
            future.result()

    print("Processing complete.")
def find_files_glob(start_dir, extensions):
    """
    Finds files with given extensions recursively using glob.

    Args:
        start_dir (str): The directory to start searching from.
        extensions (list): A list of file extensions without the leading dot (e.g., ['png', 'jpg']).
            A leading dot is tolerated and ignored.

    Returns:
        list: A sorted list of path strings for each found file, without duplicates.
    """
    matches = set()
    for ext in extensions:
        # Require the dot before the extension; "*png" would also match a
        # file that merely ends in those letters (e.g. "notapng").
        suffix = ext.lstrip(".")
        matches.update(glob(f"{start_dir}/**/*.{suffix}", recursive=True))

    # glob can also return directories whose name matches the pattern.
    return sorted(path for path in matches if os.path.isfile(path))
def parse_target_size(target_size_str):
    """Parse the ``--target_size`` command-line value.

    Args:
        target_size_str: Either ``"width,height"`` (two integers) or ``"none"``
            (case-insensitive, surrounding whitespace ignored).

    Returns:
        A ``(width, height)`` tuple of positive ints, or None for ``"none"``.

    Raises:
        ValueError: If the value is neither ``"none"`` nor two positive integers.
    """
    text = target_size_str.strip()

    # Exact match only: a substring test would silently turn malformed input
    # such as "256,none" into None.
    if text.lower() == "none":
        return None

    try:
        width, height = map(int, text.split(","))
    except ValueError:
        raise ValueError("Invalid target_size format. Use 'width,height' or 'none'.") from None

    # Reject sizes that cannot describe an output image.
    if width <= 0 or height <= 0:
        raise ValueError(f"Invalid target_size {width},{height}: width and height must be positive.")

    return (width, height)
"logs", +] +typeCheckingMode = "off" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6231b9f9435dc50c254b3e102c0a6df36d43fddd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,34 @@ +torch==2.8.0 +torchaudio==2.8.0 +torchvision==0.23.0 +lightning==2.5.5 +transformers==4.56.2 +tqdm==4.67.1 # progress bar +timm==1.0.20 # torch models +matplotlib==3.10.6 # visualization +seaborn==0.13.2 # visualization +scikit-learn==1.6.1 # metrics +rich==14.1.0 # logging +wandb==0.22.0 # logging +pydantic==2.11.9 # config +# albumentations==1.4.17 # augmentations +ruff==0.13.2 # formatting +fire==0.7.0 # CLI +pytorch-metric-learning==2.8.1 # losses +peft==0.15.2 # parameter-efficient fine-tuning +ipykernel==6.30.1 # jupyter +autoroot==1.0.1 # root utils +autorootcwd==1.0.1 # root utils +xformers==0.0.32.post2 # RADIOv2.5/3 +einops==0.8.1 # RADIOv2.5/3 +open-clip-torch==2.32.0 # RADIOv2.5/3 +grad-cam==1.5.5 # for Grad-CAM visualization +mediapipe==0.10.21 # Face landmark detection + +# --- for detector.py --- +# opencv-python==4.11.0.86 # mainly only for detector.py +opencv-python-headless==4.12.0.88 # mainly for `detector.py` +onnxruntime-gpu==1.21.0 # for ONNX model inference + +# --- for app/run.py --- +gradio==5.49.1 \ No newline at end of file diff --git a/run.py b/run.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb520d9183eb0c24e8ddd8788d078c10f1b5c89 --- /dev/null +++ b/run.py @@ -0,0 +1,174 @@ +import os +import traceback + +import torch +from lightning import Trainer +from lightning.pytorch import callbacks as pl_callbacks +from lightning.pytorch import loggers as pl_loggers +from rich import traceback as rich_traceback + +from src import dataset as datasets +from src.config import Config +from src.model.base import BaseDeepakeDetectionModel +from src.utils import logger +from src.utils.checks import checks +from src.utils.model_checkpoint import 
def load_model(config: Config) -> BaseDeepakeDetectionModel:
    """Instantiate the model that matches ``config.checkpoint``.

    When a checkpoint path is set, ``load_third_party_model`` is tried first.
    GenD is the fallback, both when no checkpoint is given and when a
    ValueError comes out of the third-party lookup.
    """
    checkpoint_given = config.checkpoint is not None and config.checkpoint != ""

    if checkpoint_given:
        try:
            return load_third_party_model(config)
        except ValueError:
            # Not a third-party checkpoint: fall through to GenD
            pass

    from src.model.GenD import GenD

    return GenD(config, verbose=True)
def finish_wandb_run(trainer, config: Config):
    """Finalize and close the first W&B logger attached to ``trainer``.

    Does nothing when ``config.wandb`` is off or the trainer has no
    ``WandbLogger``.
    """
    if not config.wandb:
        return

    wandb_loggers = [lg for lg in trainer.loggers if isinstance(lg, pl_loggers.WandbLogger)]
    if not wandb_loggers:
        return

    first_logger = wandb_loggers[0]
    first_logger.finalize("success")
    first_logger.experiment.finish()
{e}") + # Save the exception traceback to a file + with open(f"{save_dir}/failed.log", "a") as f: + f.write(f"Training failed: {e}\n") + f.write(traceback.format_exc()) + finally: + logger.print_info("Training finished. Starting testing") + ckpt_path = f"{save_dir}/checkpoints/{config.checkpoint_name}.ckpt" + if not os.path.exists(ckpt_path): + logger.print_error(f"Checkpoint {ckpt_path} does not exist. Cannot proceed with testing.") + else: + model.load_checkpoint(ckpt_path) + trainer.test(model, data_module) + + else: + assert config.checkpoint is not None, "Checkpoint is required for testing" + trainer.test(model, data_module) + + # Finish wandb run + finish_wandb_run(trainer, config) diff --git a/run_exp.py b/run_exp.py new file mode 100644 index 0000000000000000000000000000000000000000..3f484adb653ce800d306d5fe4bc21288ffd97073 --- /dev/null +++ b/run_exp.py @@ -0,0 +1,209 @@ +import traceback +from copy import deepcopy + +import fire + +from run import main +from src import config as C +from src.config import Config +from src.exp import experiments +from src.utils import files, logger + + +def get_val_files(): + return [ + *files.DeepSpeak_v2.my_val, + *files.DeepSpeak_v1_1.my_val, + *files.CDFv2.val, + *files.FFIW.val, + ] + + +def get_test_files(): + return { + "FF": files.FF.test, + "FF-DF": files.FF.DF.test, + "FF-F2F": files.FF.F2F.test, + "FF-FS": files.FF.FS.test, + "FF-NT": files.FF.NT.test, + "CDF": files.CDFv2.test, + "FaceFusion": files.FaceFusion.CDF.test, + "DFD": files.DFD.test, + "DFDC": files.DFDC.test, + "FSh": files.FSh.test, + "UADFD": files.UADFV.test, + "DFDM": files.DFDM.test, + "FFIW": files.FFIW.test, + "DeepSpeak-1.1": files.DeepSpeak_v1_1.test, + "DeepSpeak-2.0": files.DeepSpeak_v2.test, + "KoDF": files.KoDF.test, + "KoDF-adv": files.KoDF.adversarial, + "FakeAVCeleb": files.FakeAVCeleb.test, + "FAVC-FV-RA-WL": files.FakeAVCeleb.FV_RA_WL.test, + "FAVC-FV-FA-FS": files.FakeAVCeleb.FV_FA_FS.test, + "FAVC-FV-FA-GAN": 
def get_default_test_config(orig_run_name, new_run_name) -> Config:
    """Build the config for re-testing an existing run under a new name.

    The original run directory is located with ``files.find_run_dir``; its
    ``hparams.yaml`` is loaded and then overridden with the test defaults
    below (test run dir, ``best_mAP.ckpt`` checkpoint, W&B tag, loader
    settings and the full test file set).
    """
    source_run_dir = files.find_run_dir(orig_run_name)

    # Start from the hyper-parameters the run was trained with
    config = C.load_config(f"{source_run_dir}/hparams.yaml")

    # Identity of the new test run
    config.run_name = new_run_name
    config.run_dir = "runs/test"
    config.checkpoint = f"{source_run_dir}/checkpoints/best_mAP.ckpt"

    # Logging
    config.wandb = True
    config.wandb_tags.extend(["test"])

    # Data loading
    config.num_workers = 12
    config.batch_size = 1024
    config.mini_batch_size = 1024
    config.devices = "auto"

    config.tst_files = get_test_files()

    return config
def entry(
    exp_names: str | list[str],
    debug: bool = False,
    test: bool = False,
    from_exp: str | None = None,
    **kwargs,
):
    """Run one or more named experiments from the ``experiments`` registry.

    For every name, a copy of the base config is modified by the experiment's
    modifiers, optionally by the debug overrides and by ``kwargs``, revalidated
    and passed to ``main``. A failing experiment is logged to ``failed.log``
    in its run directory and the loop continues with the next name.

    Args:
        exp_names: Experiment name or names; unknown names are reported and skipped.
        debug: Apply the overrides from ``get_debug_config``.
        test: Call ``main`` with training disabled.
        from_exp: In test mode, the run whose config is loaded through
            ``get_default_test_config``; requires exactly one experiment name.
        **kwargs: Extra config values applied with ``set_values_from_dict``.

    Raises:
        Exception: In test mode with ``from_exp`` when more than one
            experiment name is given.
    """
    import os  # local import: only needed for the failure log below

    # Normalise first so a bare string is never indexed character-wise when
    # the first experiment name is looked up.
    if isinstance(exp_names, str):
        exp_names = [exp_names]
    else:
        exp_names = list(exp_names)

    if test:
        if from_exp is not None:
            if len(exp_names) != 1:
                raise Exception("When running in test mode, you can provide only one experiment name.")
            config = get_default_test_config(from_exp, exp_names[0])
        else:
            logger.print_warning("Running in test mode, but 'from_exp' is not provided. Using default test config.")
            config = C.Config()
    else:
        config = get_default_train_config()

    for exp_name in exp_names:
        exp_name = exp_name.strip()

        if exp_name not in experiments:
            logger.print_error(f"Experiment '{exp_name}' is not defined in 'src/exp/__init__.py:1'")
            logger.print(f"Available experiments: {list(experiments.keys())}")
            continue

        modifiers = experiments[exp_name]
        config_exp = deepcopy(config)

        config_exp.run_name = exp_name
        for modify in modifiers:
            if isinstance(modify, Config):
                # If the modifier is a Config object, change only different values
                difference = modify.model_dump(exclude_unset=True)
                # TODO: maybe set_values_from_dict(difference)?
                config_exp = Config(**config_exp.model_copy(update=difference).model_dump())
            else:
                config_exp = modify(config_exp)

        config_exp = Config(**config_exp.model_dump())  # Parse and validate config

        if debug:
            config_exp = config_exp.model_copy(update=get_debug_config(config_exp).model_dump())

        # Update config with kwargs
        config_exp.set_values_from_dict(kwargs)

        # Revalidate the config - checks if user provided valid values
        config_exp = Config(**config_exp.model_dump())

        try:
            main(config_exp, not test)

        except Exception as e:
            traceback.print_exc()  # Print complete exception traceback
            logger.print_error(f"Error occurred while running experiment '{exp_name}':")
            logger.print(e)

            save_dir = f"{config_exp.run_dir}/{config_exp.run_name}"
            # The run directory may not exist yet if main failed early; without
            # it the open() below would raise and abort the remaining experiments.
            os.makedirs(save_dir, exist_ok=True)
            # Save the exception traceback to a file
            with open(f"{save_dir}/failed.log", "a") as f:
                f.write(f"\nTraining failed: {e}\n")
                f.write(traceback.format_exc())
def validate(cls, value: str) -> str: + values = cls.get_all_values() + if value not in values: + raise ValueError(f"\n\nInvalid value: '{value}'\n\nPossible values are: {values}\n\nSee {__file__}\n\n") + return value + + +class Optimizer(ValidateEnum): + AdamW = "AdamW" + SGD = "SGD" + + +class InferenceStrategy(ValidateEnum): + SOFTMAX = "softmax" + + +class Head(ValidateEnum): + Linear = "linear" + NLinear = "LinearNorm" + + +class Backbone(ValidateEnum): + # https://hf.co/docs/transformers/en/model_doc/clip + # https://hf.co/openai/models?search=clip + CLIP_B_16 = "openai/clip-vit-base-patch16" + CLIP_B_32 = "openai/clip-vit-base-patch32" + CLIP_L_14 = "openai/clip-vit-large-patch14" + CLIP_L_14_336 = "openai/clip-vit-large-patch14-336" + + # https://hf.co/collections/facebook/perception-encoder-67f977c9a65ca5895a7f6ba1 + PerceptionEncoder_B_p16_224 = "vit_pe_core_base_patch16_224" # (from timm) + PerceptionEncoder_L_p14_336 = "vit_pe_core_large_patch14_336" # (from timm) + PerceptionEncoder_G_p14_448 = "vit_pe_core_gigantic_patch14_448" # (from timm) + + # https://hf.co/models?search=facebook/dinov3 + DINOv3_ViT_B = "facebook/dinov3-vitb16-pretrain-lvd1689m" + DINOv3_ViT_L = "facebook/dinov3-vitl16-pretrain-lvd1689m" + + +class BackboneArgs(Validation, validate_assignment=True): + img_size: None | int = 224 # Image size for the backbone + merge_cls_token_with_patches: None | Literal["cat", "mean"] = None # Concatenate CLS token with patches + + +class Loss(Validation, validate_assignment=True): + # Cross-entropy loss (multi-class classification) + ce_labels: float = 0.0 # Loss weight + label_smoothing: float = 0.0 # Loss weight + # Uniformity and alignment loss + uniformity: float = 0.0 # Loss weight + alignment_labels: float = 0.0 # Loss weight + + +class LoRA(Validation, validate_assignment=True): + target_modules: list[str] | str = ["out_proj"] # Target modules + rank: int = 1 # Rank of the decomposition + alpha: int = 32 # Scaling factor + dropout: float = 
class Augmentations(Validation, validate_assignment=True):
    # Image augmentation settings. Each field documents the value that disables
    # the corresponding augmentation; `get_empty()` returns a configuration with
    # all of them disabled.

    random_horizontal_flip: float = 0.5  # Probability of random horizontal flip, 0 - no augmentations
    random_affine_degrees: int = 10  # Random affine rotation degrees, 0 - no rotation
    random_affine_translate: None | list[float] = [0.1, 0.1]  # Random affine translation, None - no translation
    random_affine_scale: None | list[float] = [0.9, 1.1]  # Random affine scale, None - no scaling
    gaussian_blur_prob: float = 0.1  # Probability of applying Gaussian blur, 0 - no blur
    gaussian_blur_kernel_size: int | list[int] = 7  # Gaussian blur kernel size, 0 - no blur
    gaussian_blur_sigma: float | list[float] = [0.1, 2.0]  # Gaussian blur sigma
    color_jitter_brightness: float = 0.1  # Brightness jitter factor, 0 - no brightness jitter
    color_jitter_contrast: float = 0.1  # Contrast jitter factor, 0 - no contrast jitter
    jpeg_quality: int | list[int] = [40, 100]  # JPEG quality range, 100 - no JPEG compression
    # Target size forwarded to torchvision's Resize, which reads a pair as
    # (height, width) and a single int as the shorter-edge size. None - no resizing
    resize: None | int | list[int] = None
    # 0:nearest, 1:lanczos, 2:bilinear, 3:bicubic, 4:box, 5:hamming
    resize_interpolation: int = 2  # Interpolation method for resizing, see InterpolationMode or Pillow integer constant
    gaussian_noise_sigma: float = 0.0  # Standard deviation of Gaussian noise to add, 0 - no noise

    @classmethod
    def get_empty(cls) -> Self:
        """
        Return a configuration in which every augmentation is switched off.

        This is a classmethod (not a staticmethod) so that the ``Self`` return
        annotation is valid and subclasses get an instance of their own type.
        Calling ``Augmentations.get_empty()`` works exactly as before.
        """
        return cls(
            random_horizontal_flip=0.0,
            random_affine_degrees=0,
            random_affine_translate=None,
            random_affine_scale=None,
            gaussian_blur_prob=0.0,
            gaussian_blur_kernel_size=0,
            gaussian_blur_sigma=0.0,
            color_jitter_brightness=0.0,
            color_jitter_contrast=0.0,
            jpeg_quality=100,
            resize=None,
            # Stated explicitly so "empty" does not silently depend on the field default
            gaussian_noise_sigma=0.0,
        )
+ test_augmentations: None | Augmentations = None # Test-time augmentations + load_pairs: bool = False # Whether to load csv files with paired videos + + # Optimization configuration + lr: float = 0.0003 # Learning rate (initial / base) + min_lr: float = 1e-6 # Minimum learning rate + lr_scheduler: None | Scheduler = "cosine" # Learning rate scheduler + warmup_epochs: float = 0 # Number of warmup epochs (can be a fraction) + num_epochs_in_cycle: float = 1 # Number of epochs in a cycle (for cyclic schedulers) + optimizer: str = "AdamW" # Optimizer to use + weight_decay: float = 0.0 # AdamW weight decay + betas: list[float] = [0.9, 0.999] # First and second moment coefficients for SGD and AdamW + loss: Loss = Loss() # Loss function to use + + # Training configuration (managed by Lightning Trainer) + max_epochs: int = 1 # Number of epochs to train + batch_size: int = 512 # Required batch size to perform one step + mini_batch_size: int = 512 # Mini batch size per device + num_workers: int = 12 # Number of workers for the DataLoader + devices: list[int] | str | int = "auto" # Devices to use for training + precision: Precision = "bf16-mixed" # Precision for the model + fast_dev_run: int | bool = False # Run a fast development run + overfit_batches: int | float = 0.0 # Overfit on a subset of the data + limit_train_batches: None | int | float = None # Limit the number of training batches + limit_test_batches: None | int | float = None # Limit the number of test batches + limit_val_batches: None | int | float = None # Limit the number of validation batches + deterministic: None | bool = None # Set random seed for reproducibility + detect_anomaly: bool = False # Detect anomalies in the model + early_stopping_patience: int = -1 # Early stopping patience, -1 to disable + checkpoint_name: str = "best_mAP" # Checkpoint to use for testing + monitor_metric: str = "val/mAP_video" # Metric to monitor for early stopping and checkpointing + monitor_metric_mode: str = "max" # Mode for 
monitoring metric ("max" or "min") + + # Logging + wandb: bool = False # Log metrics to Weights & Biases + wandb_tags: list[str] = [] # Tags to use for Weights & Biases + wandb_group: None | str = None # Group to use for Weights & Biases + + # Post-processing + make_binary_before_video_aggregation: bool = True # Make binary labels before video aggregation + reduce_video_predictions: Literal["mean", "median"] = "mean" # Reduce strategy for frame to video probs + + # Validation + @field_validator("head") + @classmethod + def validate_head(cls, head: str) -> str: + return Head.validate(head) + + @field_validator("backbone") + @classmethod + def validate_backbone(cls, backbone: str) -> str: + return Backbone.validate(backbone) + + @field_validator("inference_strategy") + @classmethod + def validate_inference_strategy(cls, inference_strategy: str) -> str: + return InferenceStrategy.validate(inference_strategy) + + @field_validator("optimizer") + @classmethod + def validate_optimizer(cls, optimizer: str) -> str: + return Optimizer.validate(optimizer) + + def set_values_from_dict(self, dict: dict) -> Self: + """ + Set values in the config from a dictionary. The keys of the dictionary can be + either the names of the attributes in the config or a dot-separated path to the + attribute. For example, if the config has an attribute `a.b.c`, you can set its + value by passing a dictionary with the key `a.b.c`. + """ + # Iterate over the dictionary and set the values in the config + for key, value in dict.items(): + # If key contains a dot, traverse the config to the last key + if "." 
in key: + keys = key.split(".") + # Traverse the config to the last key + last_dict = self + for next_key in keys[:-1]: + last_dict = getattr(last_dict, next_key) + setattr(last_dict, keys[-1], value) + else: + setattr(self, key, value) + return self + + +def load_config(path: str) -> Config: + import yaml + + # read yaml config + with open(path, "r") as f: + config = yaml.safe_load(f) + + # overwrite config + config = Config(**config) + return config diff --git a/src/dataset/__init__.py b/src/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34e8ec5a3fcbde4eb6e4ef0fe0d735873aaab572 --- /dev/null +++ b/src/dataset/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseDataModule +from .data_module import DeepfakeDataModule diff --git a/src/dataset/augmentations.py b/src/dataset/augmentations.py new file mode 100644 index 0000000000000000000000000000000000000000..85bbfb4c5955f5f1f89a004fdad0df3731253ec6 --- /dev/null +++ b/src/dataset/augmentations.py @@ -0,0 +1,69 @@ +from torchvision.transforms import v2 as T + +from src.config import Augmentations + + +def init_augmentations(augs: Augmentations): + # TODO: for each augmentation, add a probability parameter to the config + if augs is None: + return None + + composed_transforms = [] + + if augs.random_horizontal_flip != 0.0: + composed_transforms.append(T.RandomHorizontalFlip(p=augs.random_horizontal_flip)) + + if ( + augs.random_affine_degrees != 0 + or augs.random_affine_translate is not None + or augs.random_affine_scale is not None + ): + composed_transforms.append( + T.RandomAffine( + degrees=augs.random_affine_degrees, + translate=augs.random_affine_translate, + scale=augs.random_affine_scale, + ) + ) + + if augs.gaussian_blur_prob != 0.0: + ks = augs.gaussian_blur_kernel_size + if (isinstance(ks, int) and ks != 0) or (isinstance(ks, list) and sum(ks) != 0): + composed_transforms.append( + T.RandomApply( + [T.GaussianBlur(kernel_size=ks, sigma=augs.gaussian_blur_sigma)], + 
p=augs.gaussian_blur_prob, + ) + ) + + if augs.color_jitter_brightness != 0.0 or augs.color_jitter_contrast != 0.0: + composed_transforms.append( + T.ColorJitter( + brightness=augs.color_jitter_brightness, + contrast=augs.color_jitter_contrast, + ) + ) + + if (isinstance(augs.jpeg_quality, int) and augs.jpeg_quality != 100) or ( + isinstance(augs.jpeg_quality, list) and augs.jpeg_quality[0] != 100 + ): + composed_transforms.append(T.JPEG(augs.jpeg_quality)) + + if augs.resize is not None: + composed_transforms.append(T.Resize(augs.resize, augs.resize_interpolation)) + + if augs.gaussian_noise_sigma != 0.0: + composed_transforms.append( + T.Compose( + [ + T.ToTensor(), + T.GaussianNoise(0.0, augs.gaussian_noise_sigma), + T.ToPILImage(), + ] + ) + ) + + if len(composed_transforms) == 0: + return None + + return T.Compose(composed_transforms) diff --git a/src/dataset/base.py b/src/dataset/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a64fe023157a6d6800c13177779d118207fd92d6 --- /dev/null +++ b/src/dataset/base.py @@ -0,0 +1,98 @@ +from abc import abstractmethod +from typing import Callable, Optional + +import lightning as pl +import numpy as np +from PIL import Image +from torch.utils.data import DataLoader, Dataset + +from src.config import Config +from src.utils.logger import print + + +class BaseDataset(Dataset): + def __init__( + self, + files: list[str], + labels: list[int], + preprocess: None | Callable = None, + augmentations: None | Callable = None, + shuffle: bool = False, # Shuffles the dataset once + dataset2files: Optional[dict[str, list[str]]] = None, + ): + self.files = files + self.labels = labels + + self.preprocess = preprocess + self.augmentations = augmentations + + self.dataset2files = dataset2files + + if shuffle: + self.shuffle() + + def shuffle(self): + # create fixed seed for reproducibility + idx = np.random.RandomState(42).permutation(len(self.files)) + self.files = [self.files[i] for i in idx] + self.labels = 
[self.labels[i] for i in idx] + + def __len__(self): + return len(self.files) + + def __getitem__(self, idx): + path = self.files[idx] + image = Image.open(path) + if self.augmentations is not None: + image = self.augmentations(image) + if self.preprocess is not None: + image = self.preprocess(image) + return { + "image": image, + "label": self.labels[idx], + "path": path, + } + + def print_statistics(self): + print(f"Number of samples: {len(self.files)}") + unique, counts = np.unique(self.labels, return_counts=True) + print("Class distribution") + names = self.get_class_names() + for u, c in zip(unique, counts): + print(f"Class {u} ({names[u]}): {c}") + + @abstractmethod + def get_class_names(self) -> dict[int, str]: + raise NotImplementedError + + +class BaseDataModule(pl.LightningDataModule): + def __init__(self, config: Config, preprocess: None | Callable = None): + super().__init__() + self.config = config + self.preprocess = preprocess + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.config.mini_batch_size, + num_workers=self.config.num_workers, + pin_memory=True, + shuffle=True, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.config.mini_batch_size, + num_workers=self.config.num_workers, + pin_memory=True, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.config.mini_batch_size, + num_workers=self.config.num_workers, + pin_memory=True, + ) diff --git a/src/dataset/data_module.py b/src/dataset/data_module.py new file mode 100644 index 0000000000000000000000000000000000000000..749606ae79ec47c48828c2de8e3f18a200ae498a --- /dev/null +++ b/src/dataset/data_module.py @@ -0,0 +1,76 @@ +from typing import Callable + +from torch.utils.data import DataLoader + +from src.config import Config +from src.utils import logger + +from .augmentations import init_augmentations +from .base import BaseDataModule +from .dataset import 
class DeepfakeDataModule(BaseDataModule):
    """Data module that builds `DeepfakeDataset` splits from the file lists in the config."""

    def __init__(self, config: Config, preprocess: None | Callable = None):
        super().__init__(config, preprocess)

    def setup(self, stage: str):
        """
        Create the datasets needed for `stage`.

        "fit" and "validate" build the training and validation datasets,
        "test" builds the test dataset. Statistics are printed for each one.
        """
        cfg = self.config

        if stage in ("fit", "validate"):
            logger.print("\n[blue]Creating training dataset")
            train_augmentations = init_augmentations(cfg.augmentations)
            self.train_dataset = DeepfakeDataset(
                cfg.trn_files,
                self.preprocess,
                augmentations=train_augmentations,
                binary=cfg.binary_labels,
                limit_files=cfg.limit_trn_files,
                load_pairs=cfg.load_pairs,
            )
            self.train_dataset.print_statistics()

            logger.print("\n[blue]Creating validation dataset")
            # No augmentations for validation; the dataset is shuffled once at creation
            self.val_dataset = DeepfakeDataset(
                cfg.val_files,
                self.preprocess,
                shuffle=True,
                binary=cfg.binary_labels,
                limit_files=cfg.limit_val_files,
            )
            self.val_dataset.print_statistics()

        if stage == "test":
            logger.print("\nCreating test dataset")
            test_augmentations = init_augmentations(cfg.test_augmentations)
            self.test_dataset = DeepfakeDataset(
                cfg.tst_files,
                self.preprocess,
                augmentations=test_augmentations,
                binary=cfg.binary_labels,
                limit_files=cfg.limit_tst_files,
            )
            self.test_dataset.print_statistics()

    def _make_loader(self, dataset, **extra) -> DataLoader:
        """Build a DataLoader with the batch size and worker count from the config."""
        return DataLoader(
            dataset,
            batch_size=self.config.mini_batch_size,
            num_workers=self.config.num_workers,
            pin_memory=True,
            **extra,
        )

    def train_dataloader(self):
        # Training batches are reshuffled and the last incomplete batch is dropped
        return self._make_loader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)
def __init__(
    self,
    files_with_paths: list[str] | dict[str, list[str]],
    preprocess: None | Callable = None,
    augmentations: None | Callable = None,
    shuffle: bool = False,  # Shuffles the dataset once
    binary: bool = False,
    limit_files: None | int = None,
    load_pairs: bool = False,
):
    """
    Read frame paths from list files and derive a label for every frame.

    Args:
        files_with_paths: Text files with one frame path per line, either as a
            flat list or as a mapping from dataset name to such a list.
        preprocess: Callable applied to every image after augmentations.
        augmentations: Callable applied to every image before preprocessing.
        shuffle: Shuffle the dataset once after loading.
        binary: Map every source containing "real" to label 0 and everything
            else to label 1. Must be True, other modes are not implemented.
        limit_files: If set, keep at most this many frames (see `limit_files`).
        load_pairs: Accepted for interface compatibility; not used by this
            constructor.

    Raises:
        NotImplementedError: if `binary` is False.
        FileNotFoundError: if a listed path exists neither as written nor
            relative to the directory of its list file.
    """
    files = []
    labels = []
    logger.print_info("Loading files")

    if binary:
        label2name = {0: "real", 1: "fake"}
    else:
        raise NotImplementedError("Only binary classification is supported now")

    source2label = {v: k for k, v in label2name.items()}

    self.label2name = label2name

    dataset2files = None

    if isinstance(files_with_paths, dict):
        dataset2files_with_paths = files_with_paths.copy()
        dataset2files = {dataset_name: [] for dataset_name in dataset2files_with_paths}
        files_with_paths = [item for sublist in files_with_paths.values() for item in sublist]

    # os.cpu_count() may return None when the count cannot be determined
    max_workers = min(64, os.cpu_count() or 1)

    def process_path(root, path):
        # If the path does not exist as written, resolve it relative to the list file
        if not os.path.exists(path):
            path = f"{root}/{path}"
        if not os.path.exists(path):
            raise FileNotFoundError(f"File not found: {path}")
        return path

    for file_with_paths in sorted(set(files_with_paths)):
        with open(file_with_paths, "r") as f:
            paths = [line.strip() for line in f]

        # Skip blank lines: an empty entry would otherwise resolve to the
        # directory of the list file and be added as if it were a frame
        paths = [path for path in paths if path]

        # Directory of the list file; "." when the list file path has no "/"
        root, sep, _ = file_with_paths.rpartition("/")
        if not sep:
            root = "."

        with ThreadPoolExecutor(max_workers) as executor:
            process_with_root = partial(process_path, root)
            paths = list(
                tqdm(
                    executor.map(process_with_root, paths),
                    total=len(paths),
                    desc=f"Processing paths in {file_with_paths}",
                    leave=True,
                )
            )

        files.extend(paths)

        if dataset2files is not None:
            # A list file may belong to several datasets
            for dataset_name, dataset_list_files in dataset2files_with_paths.items():
                if file_with_paths in dataset_list_files:
                    dataset2files[dataset_name].extend(paths)

    # Remove duplicate paths (this also sorts them)
    files = np.unique(files).tolist()

    # Limit the number of files
    if limit_files is not None:
        files = self.limit_files(files, limit_files)

    # Get labels from paths
    for path in files:
        source = self.get_source_from_file(path)

        if binary:
            source = "real" if "real" in source else "fake"

        labels.append(source2label[source])

    logger.print_info("Files loaded")

    # NOTE: dataset2files holds the paths as read per list file, i.e. before
    # de-duplication and before `limit_files` is applied
    super().__init__(files, labels, preprocess, augmentations, shuffle, dataset2files)

    self.source2uid = self._source2uid()
    self.video_path2uid = self._video_path2uid()

    self.file2index = {f: i for i, f in enumerate(self.files)}
files from shuffled videos + selected_files = [] + for video in unique_videos: + selected_files.extend(video2files[video]) + + if len(selected_files) >= limit: + break + + return selected_files[:limit] + + def _source2uid(self) -> dict[str, int]: + sources = [self.get_source_from_file(file) for file in self.files] + sources = np.unique(sources) + + assert any("real" in g for g in sources), "No real source found" + sources = [str(g) for g in sources] + + # Map all real sources to 0 and fake sources to 1, 2, 3, ... + real_sources = [g for g in sources if "real" in g] + fake_sources = [g for g in sources if "real" not in g] + + source2uid = dict.fromkeys(real_sources, 0) + for i, s in enumerate(fake_sources, start=1): + source2uid[s] = i + + return source2uid + + def _video_path2uid(self) -> dict[str, int]: + video_paths = [self.get_video_path(file) for file in self.files] + unique_videos = list(np.unique(video_paths)) + return {video: i for i, video in enumerate(unique_videos)} + + @staticmethod + def get_frame_from_file(file_path: str) -> str: + # ... / / / / + # returns + return file_path.split("/")[-1] + + @staticmethod + def get_video_from_file(file_path: str) -> str: + # ... / / / / + # returns + return file_path.split("/")[-2] + + @staticmethod + def get_source_from_file(file_path: str) -> str: + # ... / / / / + # returns + return file_path.split("/")[-3] + + @staticmethod + def get_dataset_from_file(file_path: str) -> str: + # ... / / / / + # returns + return file_path.split("/")[-4] + + @staticmethod + def get_video_path(file_path: str) -> str: + # ... 
/ / / / + # file_path[::-1].find("/") finds the last occurrence of "/" + # returns ...//// + return file_path[: -file_path[::-1].find("/")] + + def get_class_names(self) -> dict[int, str]: + return self.label2name + + def print_statistics(self): + super().print_statistics() + + video_paths = [self.get_video_path(file) for file in self.files] + + files_by_dataset = [self.get_dataset_from_file(file) for file in self.files] + + print(f"Total number of frames: {len(self.files)}") + print(f"Total number of videos: {len(set(video_paths))}") + + # For each dataset, print number of frames and videos + df = pd.DataFrame({"dataset": files_by_dataset, "video": video_paths}) + + for dataset in df["dataset"].unique(): + dataset_df = df[df["dataset"] == dataset] + videos_count = dataset_df["video"].nunique() + frames_count = len(dataset_df) + print(f"Dataset: {dataset}, videos: {videos_count}, frames: {frames_count}") + + def __getitem__(self, idx): + path = self.files[idx] + image = Image.open(path) + source = self.get_source_from_file(path) + video_path = self.get_video_path(path) + label = self.labels[idx] + + # Apply augmentations defined in from config.Augmentations + if self.augmentations is not None: + image = self.augmentations(image) + + # Apply preprocessing defined by the model input requirements + if self.preprocess is not None: + image = self.preprocess(image) + + output = { + "idx": idx, + "image": image, + "label": label, + "path": path, + "video": self.get_video_from_file(path), + "source_uid": self.source2uid[source], + "frame": self.get_frame_from_file(path), + "video_uid": self.video_path2uid[video_path], + } + + return output diff --git a/src/encoders/_common.py b/src/encoders/_common.py new file mode 100644 index 0000000000000000000000000000000000000000..10f40df13eda7622baaefdfde31540338e3c3c5b --- /dev/null +++ b/src/encoders/_common.py @@ -0,0 +1,21 @@ +import requests +import torch +from PIL import Image + + +def inference(model): + for name, param in 
class CLIPEncoder(nn.Module):
    """Image encoder that wraps the vision model of a pretrained CLIP checkpoint."""

    def __init__(self, model_name="openai/clip-vit-large-patch14"):
        """
        Models:
            1. openai/clip-vit-base-patch16 | 768 features
            2. openai/clip-vit-base-patch32 | 768 features
            3. openai/clip-vit-large-patch14 | 1024 features

        See more in src/config.py
        """

        super().__init__()

        self.model_name = model_name

        # NOTE(review): if the processor for `model_name` cannot be loaded, the
        # processor of openai/clip-vit-base-patch16 is used instead without any
        # message. Presumably its preprocessing is compatible - confirm for
        # checkpoints with a different input resolution.
        try:
            processor = CLIPProcessor.from_pretrained(model_name)
        except Exception:
            processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
        self._preprocess = processor

        clip: CLIPModel = CLIPModel.from_pretrained(model_name)

        # vision model from CLIP, maps image to vision_embed_dim
        self.vision_model = clip.vision_model

        # visual_projection, maps vision_embed_dim to projection_dim
        # (kept on the module, not applied in `forward`)
        self.visual_projection = clip.visual_projection

        self.features_dim = self.vision_model.config.hidden_size

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        """Run the CLIP processor on one image and return its `pixel_values` tensor."""
        batch = self._preprocess(images=image, return_tensors="pt")
        return batch["pixel_values"][0]

    def forward(self, preprocessed_images: torch.Tensor) -> torch.Tensor:
        """Return the pooled output of the vision model for a batch of preprocessed images."""
        vision_outputs = self.vision_model(preprocessed_images)
        return vision_outputs.pooler_output

    def get_features_dim(self):
        """Return the hidden size of the vision model."""
        return self.features_dim
def __init__(
    self, model_name="facebook/dinov2-with-registers-base", merge_cls_token_with_patches: None | str = None
):
    """
    Load the image processor and backbone for `model_name`.

    `merge_cls_token_with_patches` selects how `forward` builds the embedding:
    None uses the CLS token only, "cat" concatenates CLS and mean patch token
    (doubling the feature size), "mean" averages them.

    See models in src/config.py
    """

    super().__init__()

    self._preprocess = AutoImageProcessor.from_pretrained(model_name)
    self.backbone: Dinov2Model | Dinov2WithRegistersModel = AutoModel.from_pretrained(model_name)
    self.merge_cls_token_with_patches = merge_cls_token_with_patches

    hidden_size = self.backbone.config.hidden_size
    # Concatenating CLS and patch embeddings doubles the output dimension
    self.features_dim = 2 * hidden_size if merge_cls_token_with_patches == "cat" else hidden_size
class PerceptionEncoder(nn.Module):
    """Image encoder that wraps a pretrained timm backbone with its classification head removed."""

    def __init__(
        self,
        model_name="vit_pe_core_large_patch14_336",
        img_size: None | int = None,
    ):
        """
        Args:
            model_name: Name of the timm model to create (with pretrained weights).
            img_size: If given, inputs are preprocessed to `img_size` x `img_size`
                and the backbone is created with dynamic image size enabled.
                If None, the model's own data configuration is used unchanged.
        """
        super().__init__()

        # Previously this flag was only assigned when img_size was given, so the
        # default img_size=None failed with an unbound local variable
        dynamic_img_size = img_size is not None

        self.backbone: Eva = timm.create_model(
            model_name,
            pretrained=True,
            dynamic_img_size=dynamic_img_size,
        )

        # Get model specific transforms (normalization, resize)
        data_config = timm.data.resolve_model_data_config(self.backbone)

        if img_size is not None:
            data_config["input_size"] = (3, img_size, img_size)

        self._preprocess = timm.data.create_transform(**data_config, is_training=False)

        # Remove head
        self.backbone.head = nn.Identity()

        self.features_dim = self.backbone.num_features

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        """Apply the model's evaluation transform to one image."""
        return self._preprocess(image)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return the backbone output for a batch of preprocessed images."""
        return self.backbone(inputs)

    def get_features_dim(self) -> int:
        """Return the backbone's `num_features`."""
        return self.features_dim
# File lists shared by the example experiments below
_FF_TEST_FILES = [
    "config/datasets/FF/test/DF.txt",
    "config/datasets/FF/test/F2F.txt",
    "config/datasets/FF/test/FS.txt",
    "config/datasets/FF/test/NT.txt",
    "config/datasets/FF/test/real.txt",
]

_CDFV2_TEST_FILES = [
    "config/datasets/CDFv2/test/Celeb-real.txt",
    "config/datasets/CDFv2/test/Celeb-synthesis.txt",
    "config/datasets/CDFv2/test/YouTube-real.txt",
]


def _cdfv2_test_config(**overrides) -> Config:
    """Build a config that tests on CDFv2 on device 0 and saves to runs/test; `overrides` go to Config."""
    return Config(
        run_dir="runs/test",
        tst_files=list(_CDFV2_TEST_FILES),
        wandb=False,
        devices=[0],
        **overrides,
    )


# Maps experiment name to a list holding its config
experiments = {
    "example-training": [
        Config(
            backbone=C.Backbone.CLIP_L_14,
            head=C.Head.Linear,
            # NOTE(review): "pre_layrnorm" is spelled this way in the original;
            # presumably it matches the layer name inside the CLIP vision model -
            # confirm before "fixing" the spelling
            unfreeze_layers=["pre_layrnorm", "layer_norm1", "layer_norm2", "post_layernorm"],
            loss=C.Loss(ce_labels=1.0),
            run_dir="runs/example",
            # The example trains and validates on the same FF list files
            trn_files=list(_FF_TEST_FILES),
            val_files=list(_FF_TEST_FILES),
            tst_files=list(_CDFV2_TEST_FILES),
            batch_size=2,
            mini_batch_size=2,
            max_epochs=1,
            wandb=False,
            devices=[0],
        )
    ],
    "example-test": [
        _cdfv2_test_config(batch_size=128, mini_batch_size=128),
    ],
    "GenD_CLIP--CDFv2-example": [
        _cdfv2_test_config(checkpoint="yermandy/GenD_CLIP_L_14", max_epochs=1),
    ],
    "GenD_PE--CDFv2-example": [
        _cdfv2_test_config(checkpoint="yermandy/GenD_PE_L", max_epochs=1),
    ],
    "GenD_DINO--CDFv2-example": [
        _cdfv2_test_config(checkpoint="yermandy/GenD_DINOv3_L", max_epochs=1),
    ],
}
def set_common_settings(experiments):
    """Rebuild the first config of every experiment on top of the shared defaults.

    For each entry of ``experiments`` (run name -> list of configs) the first
    config is replaced, in place, by a new ``Config`` created from two dumps
    taken with ``exclude_unset=True``: the result of ``get_common()`` and the
    experiment's own config. On a key conflict the experiment's value wins.
    Only index 0 of each list is touched.
    """
    for configs in experiments.values():
        # A fresh set of defaults is built for every experiment.
        defaults = get_common().model_dump(exclude_unset=True)
        overrides = configs[0].model_dump(exclude_unset=True)

        # Right-hand operand takes precedence: experiment settings override defaults.
        configs[0] = Config(**(defaults | overrides))
experimental + # sh run_job.sh -p $P wacv-PE_L-LN+L2+UA-U1.0-A0.5-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p $P wacv-PE_L-LN+L2+UA-U1.0-A0.1-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p $P wacv-PE_L-LN+L2+UA-U0.5-A0.0-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p $P wacv-PE_L-LN+L2+UA-U1.0-A0.0-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# #! noaug +# sh run_job.sh -p $P wacv-PE_L-LN-noaug-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# done + +#! DINOv3 +# seeds=(0 1 2 3 4) + +# for seed in ${seeds[@]}; do +# sh run_job.sh -p $P wacv-DINOv3L-Baseline-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p $P wacv-DINOv3L-LN-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p $P wacv-DINOv3L-LN+L2-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p $P wacv-DINOv3L-LN+L2+UnAl-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True + # sh run_job.sh -p $P wacv-DINOv3L-LN+L2+UA-U0.5-A0.5-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p $P wacv-DINOv3L-LN+L2+UA-U0.1-A0.5-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True + #! noaug +# sh run_job.sh -p $P wacv-DINOv3L-LN-noaug-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# done + +#! 
Paired / Unpaired +# Seeds from 00 to 19 +# seeds=(00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15 16 17 18 19) + +# for seed in ${seeds[@]}; do +# sh run_job.sh -p $P wacv-paired-LN-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p $P wacv-unpaired-LN-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p $P wacv-unpaired-PE_L-LN+L2-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p $P wacv-paired-PE_L-LN+L2-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# done + +# for seed in ${seeds[@]}; do +# sh run_job.sh -p $P wacv-paired-CDFv2-LN-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p $P wacv-unpaired-CDFv2-LN-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# done + +#? For now only seeds 0-5 +# seeds=(00 01 02 03 04 05) +# for seed in ${seeds[@]}; do +# sh run_job.sh -p $P wacv-paired-FAVC-LN-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p $P wacv-unpaired-FAVC-LN-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# done + +# seeds=(00 01 02 03 04 05) +# for seed in ${seeds[@]}; do +# sh run_job.sh -p $P wacv-paired-FAVC-PE_L-LN-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p $P wacv-unpaired-FAVC-PE_L-LN-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# done + +#! 
JPEG compression robustness +# jpeg_qualities=(100 80 60 40 20 10) +# jpeg_qualities=(100) + +# for q in ${jpeg_qualities[@]}; do +# sh run_test_effort.sh Effort-jpeg${q} --test_augmentations.jpeg_quality "[${q},${q}]" --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_test_forada.sh ForAda-jpeg${q} --test_augmentations.jpeg_quality "[${q},${q}]" --throw_exception_if_run_exists False --remove_if_run_exists True +# done + +# seeds=(0 1 2 3 4) +# jpeg_qualities=(100 80 60 40 20 10) + +# for seed in ${seeds[@]}; do +# for q in ${jpeg_qualities[@]}; do +# sh run_job.sh -p amdgpufast,amdgpudeadline wacv-LN+L2+UnAl-seed${seed}-jpeg${q} --test --from_exp wacv-LN+L2+UnAl-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p amdgpufast,amdgpudeadline wacv-PE_L-LN+L2-seed${seed}-jpeg${q} --test --from_exp wacv-PE_L-LN+L2-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# done +# done + +#! Gaussian blur tests +# blur_kernel_sizes=(0 5 7 11 13 19) +# blur_sigmas=(0.0 0.5 1.0 1.5 2.0 3.0) +# blur_kernel_sizes=(0) +# blur_sigmas=(0.0) + +# for i in ${!blur_kernel_sizes[@]}; do +# k=${blur_kernel_sizes[$i]} +# s=${blur_sigmas[$i]} +# sh run_test_effort.sh Effort-blur-${k}-${s} --test_augmentations.gaussian_blur_kernel_size ${k} --test_augmentations.gaussian_blur_sigma "[${s},${s}]" --test_augmentations.gaussian_blur_prob 1.0 --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_test_forada.sh ForAda-blur-${k}-${s} --test_augmentations.gaussian_blur_kernel_size ${k} --test_augmentations.gaussian_blur_sigma "[${s},${s}]" --test_augmentations.gaussian_blur_prob 1.0 --throw_exception_if_run_exists False --remove_if_run_exists True +# done + +# seeds=(0 1 2 3 4) +# for seed in ${seeds[@]}; do +# for i in ${!blur_kernel_sizes[@]}; do +# k=${blur_kernel_sizes[$i]} +# s=${blur_sigmas[$i]} +# sh run_job.sh -p amdgpufast,amdgpudeadline 
wacv-PE_L-LN+L2-seed${seed}-blur-${k}-${s} --test --from_exp wacv-PE_L-LN+L2-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# done +# done + +#! Resize tests +# resize=(224 112 64) +# interpolation=(0 1 2 3 4 5) # 0=NEAREST, 1=BILINEAR, 2=BICUBIC, 3=BOX, 4=HAMMING, 5=LANCZOS + +# for r in ${resize[@]}; do +# for i in ${interpolation[@]}; do +# sh run_test_effort.sh Effort-resize-${r}-${i} --test_augmentations.resize ${r} --test_augmentations.resize_interpolation ${i} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_test_forada.sh ForAda-resize-${r}-${i} --test_augmentations.resize ${r} --test_augmentations.resize_interpolation ${i} --throw_exception_if_run_exists False --remove_if_run_exists True +# done +# done + +# seeds=(0 1 2 3 4) + +# for r in ${resize[@]}; do +# for i in ${interpolation[@]}; do +# for seed in ${seeds[@]}; do +# sh run_job.sh -p amdgpufast,amdgpudeadline wacv-PE_L-LN+L2-seed${seed}-resize-${r}-${i} --test --from_exp wacv-PE_L-LN+L2-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# done +# done +# done + + +#! DEBUG + +# sh run_test_effort.sh Effort-blur-0-0.0-tmp --test_augmentations.gaussian_blur_kernel_size 0 --test_augmentations.gaussian_blur_sigma "[0.0, 0.0]" --test_augmentations.gaussian_blur_prob 1.0 --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_test_forada.sh ForAda-blur-0-0.0-tmp --test_augmentations.gaussian_blur_kernel_size 0 --test_augmentations.gaussian_blur_sigma "[0.0, 0.0]" --test_augmentations.gaussian_blur_prob 1.0 --throw_exception_if_run_exists False --remove_if_run_exists True + +# sh run_test_effort.sh Effort-jpeg100-tmp --test_augmentations.jpeg_quality "[100,100]" --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_test_forada.sh ForAda-jpeg100-tmp --test_augmentations.jpeg_quality "[100,100]" --throw_exception_if_run_exists False --remove_if_run_exists True + + +#! 
Uniformity-alignment α, β hyperparameter sweep +# seeds=(0) +# alphas=(0.0 0.1 0.5 1.0 5.0) +# betas=(0.0 0.1 0.5 1.0 5.0) + +# for seed in ${seeds[@]}; do +# for alpha in ${alphas[@]}; do +# for beta in ${betas[@]}; do +# sh run_job.sh -p $P wacv-PE_L-LN+L2+UnAl-sweep-A${alpha}-U${beta}-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# sh run_job.sh -p $P wacv-DINOv3L-LN+L2+UnAl-sweep-A${alpha}-U${beta}-seed${seed} --throw_exception_if_run_exists False --remove_if_run_exists True +# done +# done +# done \ No newline at end of file diff --git a/src/exp/wacv_rebuttal.py b/src/exp/wacv_rebuttal.py new file mode 100644 index 0000000000000000000000000000000000000000..3d988f317f963b03e849ea5c551f1ea6e1614d6b --- /dev/null +++ b/src/exp/wacv_rebuttal.py @@ -0,0 +1,315 @@ +from copy import deepcopy + +from .. import config as C +from ..config import Config +from ..utils import files + +experiments = { + #! CLIP + "wacv-Baseline": [ + Config( + backbone=C.Backbone.CLIP_L_14, + head=C.Head.Linear, + loss=C.Loss(ce_labels=1.0), + ) + ], + "wacv-LN": [ + Config( + backbone=C.Backbone.CLIP_L_14, + head=C.Head.Linear, + unfreeze_layers=["pre_layrnorm", "layer_norm1", "layer_norm2", "post_layernorm"], + loss=C.Loss(ce_labels=1.0), + ), + ], + "wacv-LN-noaug": [ + Config( + backbone=C.Backbone.CLIP_L_14, + head=C.Head.Linear, + unfreeze_layers=["pre_layrnorm", "layer_norm1", "layer_norm2", "post_layernorm"], + loss=C.Loss(ce_labels=1.0), + augmentations=None, + ), + ], + "wacv-LN+L2": [ + Config( + backbone=C.Backbone.CLIP_L_14, + head=C.Head.NLinear, + unfreeze_layers=["pre_layrnorm", "layer_norm1", "layer_norm2", "post_layernorm"], + loss=C.Loss(ce_labels=1.0), + ), + ], + "wacv-LN+L2+UnAl": [ + Config( + backbone=C.Backbone.CLIP_L_14, + head=C.Head.NLinear, + unfreeze_layers=["pre_layrnorm", "layer_norm1", "layer_norm2", "post_layernorm"], + loss=C.Loss(ce_labels=1.0, uniformity=0.5, alignment_labels=0.1), + ), + ], + #! 
CLIP+components without LN + "wacv-NoLN-L2": [ + Config( + backbone=C.Backbone.CLIP_L_14, + head=C.Head.NLinear, + loss=C.Loss(ce_labels=1.0), + ), + ], + "wacv-NoLN-L2+UnAl": [ + Config( + backbone=C.Backbone.CLIP_L_14, + head=C.Head.NLinear, + loss=C.Loss(ce_labels=1.0, uniformity=0.5, alignment_labels=0.1), + ), + ], + #! PE + "wacv-PE_L-Baseline": [ + Config( + backbone=C.Backbone.PerceptionEncoder_L_p14_336, + backbone_args=C.BackboneArgs(image_size=224), + head=C.Head.Linear, + loss=C.Loss(ce_labels=1.0), + ) + ], + "wacv-PE_L-LN": [ + Config( + backbone=C.Backbone.PerceptionEncoder_L_p14_336, + backbone_args=C.BackboneArgs(img_size=224), + head=C.Head.Linear, + unfreeze_layers=["norm_pre", "norm1", "norm2", "norm"], + loss=C.Loss(ce_labels=1.0), + ), + ], + "wacv-PE_L-LN-noaug": [ + Config( + backbone=C.Backbone.PerceptionEncoder_L_p14_336, + backbone_args=C.BackboneArgs(img_size=224), + head=C.Head.Linear, + unfreeze_layers=["norm_pre", "norm1", "norm2", "norm"], + loss=C.Loss(ce_labels=1.0), + augmentations=None, + ), + ], + "wacv-PE_L-LN+L2": [ + Config( + backbone=C.Backbone.PerceptionEncoder_L_p14_336, + backbone_args=C.BackboneArgs(img_size=224), + head=C.Head.NLinear, + unfreeze_layers=["norm_pre", "norm1", "norm2", "norm"], + loss=C.Loss(ce_labels=1.0), + ), + ], + "wacv-PE_L-LN+L2+UnAl": [ + Config( + backbone=C.Backbone.PerceptionEncoder_L_p14_336, + backbone_args=C.BackboneArgs(img_size=224), + head=C.Head.NLinear, + unfreeze_layers=["norm_pre", "norm1", "norm2", "norm"], + loss=C.Loss(ce_labels=1.0, uniformity=0.5, alignment_labels=0.1), + ), + ], + "wacv-PE_L-LN+L2+UA-U1.0-A0.5": [ + Config( + backbone=C.Backbone.PerceptionEncoder_L_p14_336, + backbone_args=C.BackboneArgs(img_size=224), + head=C.Head.NLinear, + unfreeze_layers=["norm_pre", "norm1", "norm2", "norm"], + loss=C.Loss(ce_labels=1.0, uniformity=1.0, alignment_labels=0.5), + ), + ], + "wacv-PE_L-LN+L2+UA-U1.0-A0.1": [ + Config( + backbone=C.Backbone.PerceptionEncoder_L_p14_336, + 
backbone_args=C.BackboneArgs(img_size=224), + head=C.Head.NLinear, + unfreeze_layers=["norm_pre", "norm1", "norm2", "norm"], + loss=C.Loss(ce_labels=1.0, uniformity=1.0, alignment_labels=0.1), + ), + ], + "wacv-PE_L-LN+L2+UA-U0.5-A0.0": [ + Config( + backbone=C.Backbone.PerceptionEncoder_L_p14_336, + backbone_args=C.BackboneArgs(img_size=224), + head=C.Head.NLinear, + unfreeze_layers=["norm_pre", "norm1", "norm2", "norm"], + loss=C.Loss(ce_labels=1.0, uniformity=0.5, alignment_labels=0.0), + ), + ], + "wacv-PE_L-LN+L2+UA-U1.0-A0.0": [ + Config( + backbone=C.Backbone.PerceptionEncoder_L_p14_336, + backbone_args=C.BackboneArgs(img_size=224), + head=C.Head.NLinear, + unfreeze_layers=["norm_pre", "norm1", "norm2", "norm"], + loss=C.Loss(ce_labels=1.0, uniformity=1.0, alignment_labels=0.0), + ), + ], + #! DINOv3 + "wacv-DINOv3L-Baseline": [ + Config( + backbone=C.Backbone.DINOv3_ViT_L, + head=C.Head.Linear, + loss=C.Loss(ce_labels=1.0), + ) + ], + "wacv-DINOv3L-LN": [ + Config( + backbone=C.Backbone.DINOv3_ViT_L, + head=C.Head.Linear, + unfreeze_layers=["norm1", "norm2", "norm"], + loss=C.Loss(ce_labels=1.0), + ), + ], + "wacv-DINOv3L-LN-noaug": [ + Config( + backbone=C.Backbone.DINOv3_ViT_L, + head=C.Head.Linear, + unfreeze_layers=["norm1", "norm2", "norm"], + loss=C.Loss(ce_labels=1.0), + augmentations=None, + ), + ], + "wacv-DINOv3L-LN+L2": [ + Config( + backbone=C.Backbone.DINOv3_ViT_L, + head=C.Head.NLinear, + unfreeze_layers=["norm1", "norm2", "norm"], + loss=C.Loss(ce_labels=1.0), + ), + ], + "wacv-DINOv3L-LN+L2+UnAl": [ + Config( + backbone=C.Backbone.DINOv3_ViT_L, + head=C.Head.NLinear, + unfreeze_layers=["norm1", "norm2", "norm"], + loss=C.Loss(ce_labels=1.0, uniformity=0.5, alignment_labels=0.1), + ), + ], + "wacv-DINOv3L-LN+L2+UA-U0.5-A0.5": [ + Config( + backbone=C.Backbone.DINOv3_ViT_L, + head=C.Head.NLinear, + unfreeze_layers=["norm1", "norm2", "norm"], + loss=C.Loss(ce_labels=1.0, uniformity=0.5, alignment_labels=0.5), + ), + ], + 
def get_common():
    """Return a new ``Config`` holding the settings shared by the rebuttal runs.

    It sets the run directory, W&B tag, cyclic LR schedule, epoch limits, the
    validation file list and the named test sets. Training files are not set
    here.
    """

    def to_uniform_subset(path):
        # Redirect a CDFv3 path to its "uniform-32-frames" subset directory.
        return path.replace("/CDFv3/", "/CDFv3-x1.3-th0.5-all/subset/uniform-32-frames/")

    common = Config()
    common.run_dir = "runs/rebuttal"
    common.wandb_tags = ["rebuttal"]
    common.lr_scheduler = "cyclic"
    common.num_epochs_in_cycle = 10
    common.early_stopping_patience = 30
    common.max_epochs = 30

    common.val_files = [
        *files.DeepSpeak_v2.my_val,
        *files.DeepSpeak_v1_1.my_val,
        *files.CDFv3.my_val.map(to_uniform_subset),
        *files.FFIW.val,
    ]

    # Named test sets; the keys are used as-is.
    # NOTE(review): the key "UADFD" points at files.UADFV - possibly a typo, confirm before renaming.
    tst_files = {
        "FF": files.FF.test,
        "FF-DF": files.FF.DF.test,
        "FF-F2F": files.FF.F2F.test,
        "FF-FS": files.FF.FS.test,
        "FF-NT": files.FF.NT.test,
        "CDF": files.CDFv2.test,
        "FaceFusion": files.FaceFusion.CDF.test,
        "DFD": files.DFD.test,
        "DFDC": files.DFDC.test,
        "FSh": files.FSh.test,
        "UADFD": files.UADFV.test,
        "DFDM": files.DFDM.test,
        "FFIW": files.FFIW.test,
        "DeepSpeak-1.1": files.DeepSpeak_v1_1.test,
        "DeepSpeak-2.0": files.DeepSpeak_v2.test,
        "KoDF": files.KoDF.test,
        "KoDF-adv": files.KoDF.adversarial,
        "FakeAVCeleb": files.FakeAVCeleb.test,
        "FAVC-FV-RA-WL": files.FakeAVCeleb.FV_RA_WL.test,
        "FAVC-FV-FA-FS": files.FakeAVCeleb.FV_FA_FS.test,
        "FAVC-FV-FA-GAN": files.FakeAVCeleb.FV_FA_GAN.test,
        "FAVC-FV-FA-WL": files.FakeAVCeleb.FV_FA_WL.test,
        "PolyGlotFake": files.PolyGlotFake.test,
        "IDForge-v1": files.IDForge_v1.test,
    }

    # The CDFv3 test sets are added last, each redirected to the 32-frame subset.
    for name, entries in files.CDFv3.get_test_dict().items():
        tst_files[name] = entries.map(to_uniform_subset)

    common.tst_files = tst_files

    return common
# Put the shared rebuttal defaults underneath every base experiment.
set_common_settings(experiments)

# Snapshot of the base runs. The loops below add entries to `experiments`
# while iterating over this copy.
registered_experiments = deepcopy(experiments)

#! Add 5 splits with different seeds and different trn splits
for seed in range(5):
    for base_name, base_configs in registered_experiments.items():
        split_config = deepcopy(base_configs[0])

        # Each seed trains on its own FF split directory.
        split_config.trn_files = files.FF.train.map(
            lambda path: path.replace("/FF/", f"/FF-x1.3-th0.5-all/subset/random-32-frames/split-{seed}/")
        )
        split_config.seed = seed

        # All seeds of one base run share a W&B group.
        split_config.wandb_group = f"{base_name}"

        experiments[f"{base_name}-seed{seed}"] = [split_config]


#! Add Uniformity-Alignment α, β hyperparameter sweep
seeds = [0]
alphas = [0.0, 0.1, 0.5, 1.0, 5.0]
betas = [0.0, 0.1, 0.5, 1.0, 5.0]

# Only base runs whose name contains "UnAl" take part in the sweep.
unal_experiments = {name: configs for name, configs in registered_experiments.items() if "UnAl" in name}

for seed in seeds:
    for alpha in alphas:
        for beta in betas:
            for base_name, base_configs in unal_experiments.items():
                sweep_config = deepcopy(base_configs[0])

                sweep_config.trn_files = files.FF.train.map(
                    lambda path: path.replace("/FF/", f"/FF-x1.3-th0.5-all/subset/random-32-frames/split-{seed}/")
                )
                sweep_config.seed = seed

                # alpha weights the uniformity term, beta the alignment term.
                sweep_config.loss.uniformity = alpha
                sweep_config.loss.alignment_labels = beta

                sweep_config.wandb_group = f"{base_name}-sweep"

                # NOTE(review): the label puts alpha (the uniformity weight) after "A" and
                # beta (the alignment weight) after "U", while the "UA-U<..>-A<..>" runs in
                # this file use U for uniformity and A for alignment - confirm which
                # mapping is intended before reading results off these names.
                experiments[f"{base_name}-sweep-A{alpha}-U{beta}-seed{seed}"] = [sweep_config]
import config as C +from ..config import Config +from ..utils import files + +_experiments = { + # "wacv-Baseline": [Config(test_augmentations=get_empty_augmentations())], + # "wacv-LN": [Config(test_augmentations=get_empty_augmentations())], + # "wacv-LN+L2": [Config(test_augmentations=get_empty_augmentations())], + "wacv-LN+L2+UnAl": [Config(test_augmentations=C.Augmentations.get_empty())], #! Select this group of runs + # "wacv-PE_L-Baseline": [Config(test_augmentations=get_empty_augmentations())], + # "wacv-PE_L-LN": [Config(test_augmentations=get_empty_augmentations())], + "wacv-PE_L-LN+L2": [Config(test_augmentations=C.Augmentations.get_empty())], #! Select this group of runs + # "wacv-PE_L-LN+L2+UnAl": [Config(test_augmentations=get_empty_augmentations())], +} + + +# Add common settings +for run_name, modifieres in _experiments.items(): + config = modifieres[0] + config.run_dir = "runs/test-aug-robustness" + config.wandb_tags = ["rebuttal", "test"] + config.seed = 0 # We want augmentations to be the same for all runs + + config.tst_files = { + "FF": files.FF.test, + "FF-DF": files.FF.DF.test, + "FF-F2F": files.FF.F2F.test, + "FF-FS": files.FF.FS.test, + "FF-NT": files.FF.NT.test, + "CDF": files.CDFv2.test, + "FaceFusion": files.FaceFusion.CDF.test, + "DFD": files.DFD.test, + "DFDC": files.DFDC.test, + "FSh": files.FSh.test, + "UADFD": files.UADFV.test, + "DFDM": files.DFDM.test, + "FFIW": files.FFIW.test, + "DeepSpeak-1.1": files.DeepSpeak_v1_1.test, + "DeepSpeak-2.0": files.DeepSpeak_v2.test, + "KoDF": files.KoDF.test, + "KoDF-adv": files.KoDF.adversarial, + "FakeAVCeleb": files.FakeAVCeleb.test, + "FAVC-FV-RA-WL": files.FakeAVCeleb.FV_RA_WL.test, + "FAVC-FV-FA-FS": files.FakeAVCeleb.FV_FA_FS.test, + "FAVC-FV-FA-GAN": files.FakeAVCeleb.FV_FA_GAN.test, + "FAVC-FV-FA-WL": files.FakeAVCeleb.FV_FA_WL.test, + "PolyGlotFake": files.PolyGlotFake.test, + "IDForge-v1": files.IDForge_v1.test, + } | { + k: v.map(lambda x: x.replace("/CDFv3/", 
jpeg_quality_levels = [100, 80, 60, 40, 20, 10]


# Registry filled by the sweeps below: one entry per base run and corruption setting.
# The seed only appears in the run name; these loops never write `config.seed`.
experiments = {}

# JPEG compression. The same value is stored twice (presumably a [min, max] range - confirm).
# Test for all 5 seeds
for seed in range(5):
    for quality in jpeg_quality_levels:
        for base_name, base_configs in _experiments.items():
            jpeg_config = deepcopy(base_configs[0])
            jpeg_config.test_augmentations.jpeg_quality = [quality, quality]

            experiments[f"{base_name}-seed{seed}-jpeg{quality}"] = [jpeg_config]


# Gaussian blur as (kernel size, sigma) pairs.
# NOTE(review): `gaussian_blur_prob` is not set here although the shell commands for the
# third-party detectors pass 1.0 - confirm the empty augmentations already apply the blur.
blur_levels = [(0, 0.0), (5, 0.5), (7, 1.0), (11, 1.5), (13, 2.0), (19, 3.0)]

for seed in range(5):
    for kernel_size, sigma in blur_levels:
        for base_name, base_configs in _experiments.items():
            blur_config = deepcopy(base_configs[0])
            blur_config.test_augmentations.gaussian_blur_kernel_size = kernel_size
            blur_config.test_augmentations.gaussian_blur_sigma = (sigma, sigma)

            experiments[f"{base_name}-seed{seed}-blur-{kernel_size}-{sigma}"] = [blur_config]


# Resize: every target size is combined with every interpolation code.
resize_levels = [224, 112, 64]
interpolations = [0, 1, 2, 3, 4, 5]

for seed in range(5):
    for target_size in resize_levels:
        for interpolation in interpolations:
            for base_name, base_configs in _experiments.items():
                resize_config = deepcopy(base_configs[0])
                resize_config.test_augmentations.resize = target_size
                resize_config.test_augmentations.resize_interpolation = interpolation

                experiments[f"{base_name}-seed{seed}-resize-{target_size}-{interpolation}"] = [resize_config]


# Gaussian noise: registered once per level under a fixed "seed0" label.
gaussian_noise_levels = [0.0, 0.01, 0.03, 0.05, 0.1, 0.2]
for noise_sigma in gaussian_noise_levels:
    for base_name, base_configs in _experiments.items():
        noise_config = deepcopy(base_configs[0])
        noise_config.test_augmentations.gaussian_noise_sigma = noise_sigma

        experiments[f"{base_name}-seed0-gaussian_noise-{noise_sigma}"] = [noise_config]
b/src/exp/wacv_rebuttal_paired_unpaired.py @@ -0,0 +1,154 @@ +from copy import deepcopy + +from .. import config as C +from ..config import Config +from ..utils import files + +experiments = { + "wacv-LN": [ + Config( + backbone=C.Backbone.CLIP_L_14, + head=C.Head.Linear, + unfreeze_layers=["pre_layrnorm", "layer_norm1", "layer_norm2", "post_layernorm"], + loss=C.Loss(ce_labels=1.0), + ), + ], + "wacv-PE_L-LN": [ + Config( + backbone=C.Backbone.PerceptionEncoder_L_p14_336, + backbone_args=C.BackboneArgs(img_size=224), + head=C.Head.Linear, + unfreeze_layers=["norm_pre", "norm1", "norm2", "norm"], + loss=C.Loss(ce_labels=1.0), + ), + ], +} + + +# Add common settings +for run_name, modifieres in experiments.items(): + config = modifieres[0] + config.run_dir = "runs/rebuttal" + config.wandb_tags = ["rebuttal"] + config.lr_scheduler = "cyclic" + config.num_epochs_in_cycle = 10 + config.max_epochs = 30 + config.early_stopping_patience = 30 + + config.val_files = [ + *files.DeepSpeak_v2.my_val, + *files.DeepSpeak_v1_1.my_val, + *files.CDFv3.my_val.map(lambda x: x.replace("/CDFv3/", "/CDFv3-x1.3-th0.5-all/subset/uniform-32-frames/")), + *files.FFIW.val, + ] + + config.tst_files = { + "FF": files.FF.test, + "FF-DF": files.FF.DF.test, + "FF-F2F": files.FF.F2F.test, + "FF-FS": files.FF.FS.test, + "FF-NT": files.FF.NT.test, + "CDF": files.CDFv2.test, + "FaceFusion": files.FaceFusion.CDF.test, + "DFD": files.DFD.test, + "DFDC": files.DFDC.test, + "FSh": files.FSh.test, + "UADFD": files.UADFV.test, + "DFDM": files.DFDM.test, + "FFIW": files.FFIW.test, + "DeepSpeak-1.1": files.DeepSpeak_v1_1.test, + "DeepSpeak-2.0": files.DeepSpeak_v2.test, + "KoDF": files.KoDF.test, + "KoDF-adv": files.KoDF.adversarial, + "FakeAVCeleb": files.FakeAVCeleb.test, + "FAVC-FV-RA-WL": files.FakeAVCeleb.FV_RA_WL.test, + "FAVC-FV-FA-FS": files.FakeAVCeleb.FV_FA_FS.test, + "FAVC-FV-FA-GAN": files.FakeAVCeleb.FV_FA_GAN.test, + "FAVC-FV-FA-WL": files.FakeAVCeleb.FV_FA_WL.test, + "PolyGlotFake": 
# Set the configured base runs aside and restart the registry empty, so the
# base runs themselves are not part of `experiments` from here on.
registered_experiments = deepcopy(experiments)
experiments = {}

# add 20 splits with different seeds and different trn splits
for paired_or_unpaired in ("paired", "unpaired"):
    for seed in range(20):
        for base_name, base_configs in registered_experiments.items():
            ff_config = deepcopy(base_configs[0])

            # Train on the FF split that belongs to this seed and pairing mode.
            ff_config.trn_files = files.FF.train.map(
                lambda path: path.replace(
                    "/FF/", f"/FF-x1.3-th0.5-all/subset/paired-unpaired/split-{seed}/{paired_or_unpaired}/"
                )
            )
            ff_config.seed = seed

            variant_name = base_name.replace("wacv-", f"wacv-{paired_or_unpaired}-")
            experiments[f"{variant_name}-seed{seed:02d}"] = [ff_config]


#! Add 20 splits training on CDFv2, remove CDFv2 from val_files, add FF++ validation set
for paired_or_unpaired in ("paired", "unpaired"):
    for seed in range(20):
        for base_name, base_configs in registered_experiments.items():
            cdf_config = deepcopy(base_configs[0])

            cdf_config.trn_files = [
                f"config/datasets/CDFv3-x1.3-th0.5-all/subset/paired-unpaired/split-{seed}/{paired_or_unpaired}/train/Celeb-DF-v2.txt",
                f"config/datasets/CDFv3-x1.3-th0.5-all/subset/paired-unpaired/split-{seed}/{paired_or_unpaired}/train/Celeb-real.txt",
            ]

            # Validation set for these runs: FF validation files in, CDFv3 validation files out.
            cdf_config.val_files = [
                *files.FF.val,
                *files.DeepSpeak_v2.my_val,
                *files.DeepSpeak_v1_1.my_val,
                *files.FFIW.val,
            ]
            cdf_config.seed = seed

            variant_name = base_name.replace("wacv-", f"wacv-{paired_or_unpaired}-CDFv2-")
            experiments[f"{variant_name}-seed{seed:02d}"] = [cdf_config]
Add 20 splits training on FAVC +for paired_or_unpaired in ["paired", "unpaired"]: + for seed in range(20): + for run_name, modifieres in registered_experiments.items(): + config = modifieres[0] + config = deepcopy(config) + + config.trn_files = [ + f"config/datasets/FakeAVCeleb/subset/paired-unpaired/split-{seed}/{paired_or_unpaired}/train/fake.txt", + f"config/datasets/FakeAVCeleb/subset/paired-unpaired/split-{seed}/{paired_or_unpaired}/train/real.txt", + ] + + config.seed = seed + + run_name = run_name.replace("wacv-", f"wacv-{paired_or_unpaired}-FAVC-") + + # Create a group + config.wandb_group = run_name + + # Add seed to name + run_name = f"{run_name}-seed{seed:02d}" + + experiments[run_name] = [config] + + +#! List all experiments +# for run_name in sorted(experiments.keys()): +# print(run_name) diff --git a/src/heads/head.py b/src/heads/head.py new file mode 100644 index 0000000000000000000000000000000000000000..e8294147fecd939a7ad10e1845bee0f8572d2eea --- /dev/null +++ b/src/heads/head.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +@dataclass +class HeadOutput: + logits_labels: None | torch.Tensor = None + l2_embeddings: torch.Tensor = None + + +class LinearProbe(nn.Module): + """ + x - input tensor of shape (B, D) + y - output tensor of shape (B, C), logits + z - output tensor of shape (B, D), embeddings + f - classifier that maps D -> C + + Pseudocode: + if normalized: + x = x / ||x|| # normalized inputs + + y = f(x) # logits + z = x / ||x|| # normalized embeddings + + return y, z + """ + + def __init__(self, input_dim, num_classes, normalize_inputs=False, detach_classifier_inputs=False): + super().__init__() + self.linear = nn.Linear(input_dim, num_classes) + self.normalize_inputs = normalize_inputs + self.detach_classifier_inputs = detach_classifier_inputs + + def forward(self, x: torch.Tensor, **kwargs) -> HeadOutput: + l2_embeddings = F.normalize(x, p=2, dim=1) + + if 
self.normalize_inputs: + x = l2_embeddings + + logits = self.linear(x if not self.detach_classifier_inputs else x.detach()) + + return HeadOutput(logits_labels=logits, l2_embeddings=l2_embeddings) diff --git a/src/hf/modeling_gend.py b/src/hf/modeling_gend.py new file mode 100644 index 0000000000000000000000000000000000000000..d08d5d0217d92d3d592c563f70ae1cb5b9f8ccdc --- /dev/null +++ b/src/hf/modeling_gend.py @@ -0,0 +1,163 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from transformers import PretrainedConfig, PreTrainedModel + + +class LinearProbe(nn.Module): + def __init__(self, input_dim, num_classes, normalize_inputs=False): + super().__init__() + self.linear = nn.Linear(input_dim, num_classes) + self.normalize_inputs = normalize_inputs + + def forward(self, x: torch.Tensor, **kwargs): + if self.normalize_inputs: + x = F.normalize(x, p=2, dim=1) + + return self.linear(x) + + +class CLIPEncoder(nn.Module): + def __init__(self, model_name="openai/clip-vit-large-patch14"): + super().__init__() + + from transformers import CLIPModel, CLIPProcessor + + try: + self._preprocess = CLIPProcessor.from_pretrained(model_name) + except Exception: + self._preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16") + + clip: CLIPModel = CLIPModel.from_pretrained(model_name) + + # take vision model from CLIP, maps image to vision_embed_dim + self.vision_model = clip.vision_model + + self.model_name = model_name + + self.features_dim = self.vision_model.config.hidden_size + + # take visual_projection, maps vision_embed_dim to projection_dim + self.visual_projection = clip.visual_projection + + def preprocess(self, image: Image) -> torch.Tensor: + return self._preprocess(images=image, return_tensors="pt")["pixel_values"][0] + + def forward(self, preprocessed_images: torch.Tensor) -> torch.Tensor: + return self.vision_model(preprocessed_images).pooler_output + + def get_features_dim(self): + return 
class DINOEncoder(nn.Module):
    """Image encoder backed by a Hugging Face DINOv2 checkpoint.

    The feature vector is the first token of the backbone's last hidden state.
    """

    def __init__(self, model_name="facebook/dinov2-with-registers-base"):
        super().__init__()

        # Imported lazily so the module can be imported without loading transformers' model code.
        from transformers import AutoImageProcessor, AutoModel, Dinov2Model, Dinov2WithRegistersModel

        self._preprocess = AutoImageProcessor.from_pretrained(model_name)
        self.backbone: Dinov2Model | Dinov2WithRegistersModel = AutoModel.from_pretrained(model_name)

        self.features_dim = self.backbone.config.hidden_size

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        """Convert one PIL image into the processor's ``pixel_values`` tensor (batch dim removed)."""
        return self._preprocess(images=image, return_tensors="pt")["pixel_values"][0]

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return the first-token feature of the last hidden state for each input image."""
        return self.backbone(inputs).last_hidden_state[:, 0]

    def get_features_dim(self) -> int:
        """Return the backbone's hidden size, i.e. the length of each feature vector."""
        return self.features_dim
class Loss(nn.Module):
    """Weighted sum of cross-entropy, alignment and uniformity terms.

    Each term is enabled by a non-zero coefficient in the loss config and is
    only evaluated when the tensors it needs are present in the inputs.
    """

    def __init__(self, config: LossConfig):
        super().__init__()
        self.config = config

    def forward(
        self,
        inputs: LossInputs,
    ) -> LossOutputs:
        """Compute all enabled loss terms.

        Args:
            inputs: Logits, labels and L2-normalized embeddings; any may be None.

        Returns:
            LossOutputs with every evaluated term as a Python float and their
            sum in ``total``. ``total`` stays the integer 0 when no term was
            evaluated.
        """
        loss_outputs = LossOutputs()
        config = self.config

        if inputs.logits_labels is not None and config.ce_labels:
            L = config.ce_labels * F.cross_entropy(
                inputs.logits_labels, inputs.labels, label_smoothing=config.label_smoothing
            )
            loss_outputs.ce_labels = L.item()
            loss_outputs.total += L

        if inputs.l2_embeddings is not None:
            # Embeddings are expected to arrive already L2-normalized
            # (see 3.1 of https://arxiv.org/pdf/2004.11362); only verify it here.
            l2_embeddings = inputs.l2_embeddings

            norms = l2_embeddings.norm(p=2, dim=1)
            if not torch.allclose(norms, torch.ones_like(norms)):
                logger.print_warning_once("[yellow]Embeddings are not normalized")

            if inputs.labels is not None and config.alignment_labels:
                L = config.alignment_labels * alignment(l2_embeddings, inputs.labels)
                loss_outputs.alignment_labels = L.item()
                loss_outputs.total += L

            if config.uniformity:
                L = config.uniformity * uniformity(l2_embeddings)
                loss_outputs.uniformity = L.item()
                loss_outputs.total += L

        if isinstance(loss_outputs.total, int):
            logger.print_warning_once("[yellow]Total loss is 0. Check if loss coefficients are set correctly.")

        if isinstance(loss_outputs.total, torch.Tensor) and loss_outputs.total.isnan():
            logger.print_warning("[yellow]Total loss is nan")
            # Replace the NaN with a zero that is still attached to the graph.
            # A tensor total implies at least one of logits / embeddings exists;
            # fall back to the embeddings when the logits are absent so that an
            # embedding-only loss does not fail on a None dereference.
            anchor = inputs.logits_labels if inputs.logits_labels is not None else inputs.l2_embeddings
            loss_outputs.total = anchor.sum() * 0

        return loss_outputs

    def __call__(self, inputs: LossInputs) -> LossOutputs:
        # Overridden only to give callers a precise signature for type checking.
        return super().__call__(inputs)
def uniformity(
    x: torch.Tensor,
    t: float = 2,
    clip_value: float = 1e-6,
):
    """
    https://arxiv.org/pdf/2005.10242

    Calculates the Uniformity loss: the log of the mean Gaussian potential
    ``exp(-t * ||x_i - x_j||^2)`` over all pairs in the batch.

    Args:
        x: [N, D] - Batch of feature embeddings.
        t: Temperature parameter (hyperparameter).
        clip_value: Lower bound applied to the mean potential before the log,
            so the result never falls below ``log(clip_value)``.

    Returns:
        Tensor: Uniformity loss value (scalar). With fewer than two samples
        there are no pairs, and a zero attached to ``x``'s graph is returned.
    """
    if x.size(0) < 2:
        # pdist would be empty and its mean NaN; contribute a neutral zero instead,
        # as the alignment loss does for batches without pairs.
        return x.sum() * 0

    return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().clamp(min=clip_value).log()
def ovr_roc(labels: np.ndarray, probs: np.ndarray):
    """
    Compute One-vs-Rest (OvR) ROC curves per class and the macro-averaged AUROC.

    Parameters:
        labels (np.ndarray): Array of true class labels. Shape should be (n_samples,).
        probs (np.ndarray): Array of predicted probabilities for each class. Shape should be (n_samples, n_classes).

    Returns:
        tuple: A tuple containing:
            - fprs (list): Per-class false positive rates, padded with a leading 0 and a trailing 1.
            - tprs (list): Per-class true positive rates, padded with a leading 0 and a trailing 1.
            - ths (list): Per-class thresholds, padded with a leading 1 and a trailing 0.
            - ovr_macro_auroc (float): Macro-averaged AUROC for the OvR setting.
    """
    n_classes = probs.shape[1]
    one_hot = np.eye(n_classes)[labels]

    # Why OvR with macro avg: https://chatgpt.com/share/677e448d-5bc0-8006-b9b5-081427b02857
    macro_auroc = M.roc_auc_score(one_hot, probs, multi_class="ovr", average="macro")

    def padded_curve(class_index):
        # ROC of "this class vs. everything else".
        fpr, tpr, thresholds = M.roc_curve(one_hot[:, class_index], probs[:, class_index])
        thresholds = np.nan_to_num(thresholds, posinf=1.0)  # an infinite threshold becomes 1.0
        return (
            np.concatenate(([0], fpr, [1])),
            np.concatenate(([0], tpr, [1])),
            np.concatenate(([1], thresholds, [0])),
        )

    curves = [padded_curve(class_index) for class_index in range(n_classes)]
    fprs = [curve[0] for curve in curves]
    tprs = [curve[1] for curve in curves]
    ths = [curve[2] for curve in curves]

    return fprs, tprs, ths, macro_auroc
+ + Returns: + tuple: A tuple containing: + - precs (list of np.ndarray): List of precision values for each class. + - recs (list of np.ndarray): List of recall values for each class. + - ths (list of np.ndarray): List of threshold values for each class. + - ovr_macro_ap (float): The mean Average Precision (mAP) score. + """ + num_classes = probs.shape[1] + labels_one_hot = np.eye(num_classes)[labels] + precs, recs, ths = [], [], [] + + # The same as mAP (mean Average Precision) + ovr_macro_ap = M.average_precision_score(labels_one_hot, probs, average="macro") + + # Calculate OvR PRC for each class + for i in range(num_classes): + prec_class, rec_class, ths_class = M.precision_recall_curve(labels_one_hot[:, i], probs[:, i]) + ths_class = np.nan_to_num(ths_class, posinf=1.0) # replace inf with max value + ths_class = np.concatenate(([1], ths_class, [0])) # add 0 and 1 thresholds + prec_class = np.concatenate(([0], prec_class, [1])) # add 0 and 1 precision + rec_class = np.concatenate(([1], rec_class, [0])) # add 0 and 1 recall + precs.append(prec_class) + recs.append(rec_class) + ths.append(ths_class) + + return precs, recs, ths, ovr_macro_ap + + +def calculate_eer(y_true: np.ndarray, y_score: np.ndarray, return_threshold: bool = False): + """ + Returns the equal error rate (EER) and the threshold at which EER occurs + for a binary classifier output. + + Args: + y_true (np.ndarray): True binary labels. + y_score (np.ndarray): Target scores, can either be probability estimates of the positive class, + confidence values, or non-thresholded measure of decisions. + Assumes shape (n_samples, 2) where column 1 is the positive class score. + + Returns: + tuple: A tuple containing: + - eer (float): The Equal Error Rate. + - threshold (float): The threshold at which EER occurs. Returns NaN if EER calculation fails. 
def compute_wasserstein1_metrics(probs: np.ndarray, labels: np.ndarray):
    """
    Compute four Wasserstein-1 distances that summarize a binary classifier's outputs.

    Column 0 of ``probs`` is read as p(y=0|x) and column 1 as p(y=1|x);
    label 0 marks real samples and label 1 marks fake samples.

    Args:
        probs: Array-like of shape (n_samples, 2) with predicted probabilities.
        labels: Array-like of shape (n_samples,) with labels in {0, 1}.

    Returns:
        tuple: (W1_sep_real, W1_sep_fake, W1_conf_real, W1_conf_fake).
        The ``sep`` values compare the same probability column between real and
        fake samples (inter-class separability); the ``conf`` values compare
        the two probability columns within one class (confidence margin).
        Returns (-1, -1, -1, -1) when either class is missing.
    """
    # Accept plain sequences as well as arrays: comparing a list with an int
    # would otherwise yield a single bool instead of a mask.
    probs = np.asarray(probs)
    labels = np.asarray(labels)

    is_real = labels == 0
    is_fake = labels == 1

    if not (is_real.any() and is_fake.any()):
        return -1, -1, -1, -1

    #! Inter-class separability: how well the model separates the two classes
    # u ~ P(p(y=0|x) | y=0),  v ~ P(p(y=0|x) | y=1)
    W1_sep_real = wasserstein_distance(probs[is_real, 0], probs[is_fake, 0])

    # u ~ P(p(y=1|x) | y=0),  v ~ P(p(y=1|x) | y=1)
    W1_sep_fake = wasserstein_distance(probs[is_real, 1], probs[is_fake, 1])

    #! Intra-sample confidence margin: how confident the model is about its predictions
    # u ~ P(p(y=0|x) | y=0),  v ~ P(p(y=1|x) | y=0)
    W1_conf_real = wasserstein_distance(probs[is_real, 0], probs[is_real, 1])

    # u ~ P(p(y=0|x) | y=1),  v ~ P(p(y=1|x) | y=1)
    W1_conf_fake = wasserstein_distance(probs[is_fake, 0], probs[is_fake, 1])

    return W1_sep_real, W1_sep_fake, W1_conf_real, W1_conf_fake
def download_model_if_needed(checkpoint_path: str, link: str):
    """Download ``link`` to ``checkpoint_path`` with ``wget`` unless the file already exists.

    The download goes to a temporary ``.part`` file that is renamed only after
    wget succeeds, so a failed or interrupted download never leaves a file
    that a later call would mistake for a valid checkpoint.

    Args:
        checkpoint_path: Destination path of the checkpoint.
        link: URL to download from.

    Raises:
        subprocess.CalledProcessError: If wget exits with a non-zero status.
        FileNotFoundError: If the ``wget`` executable is not available.
    """
    if os.path.exists(checkpoint_path):
        return

    import subprocess  # local import: only needed on the download path

    logger.print_warning_once(f"Checkpoint '{checkpoint_path}' not found, downloading...")

    parent_dir = os.path.dirname(checkpoint_path)
    if parent_dir:  # a bare filename has no directory to create
        os.makedirs(parent_dir, exist_ok=True)

    partial_path = f"{checkpoint_path}.part"
    try:
        # Argument list instead of a shell string: the URL contains characters
        # such as '?' and '%' that a shell may interpret.
        subprocess.run(["wget", link, "-O", partial_path], check=True)
        os.replace(partial_path, checkpoint_path)
    finally:
        # wget creates the -O target even when the download fails.
        if os.path.exists(partial_path):
            os.remove(partial_path)
def initialize_model(self, checkpoint_path: str):
    """Build the FS-VFM ViT-L variant that belongs to ``checkpoint_path``.

    The checkpoint is downloaded first when it is missing on disk. Only the two
    checkpoint paths listed below are recognized.

    Raises:
        ValueError: If ``checkpoint_path`` is not one of the known checkpoints.
    """
    # (checkpoint path, download link, model factory)
    variants = (
        (
            "weights/FS-VFM/FS-VFM-ViT-L-Adapter.pth",
            "https://hf.co/Wolowolo/fsfm-3c/resolve/main/finetuned_models/FS-VFM_extensions/finetune_fs-adapter/cross_dataset_DfD_and_DiFF/ViT-L_VF2_600e/FT_on_FF%2B%2B_c23_32frames/checkpoint-min_val_loss.pth?download=true",
            lambda: models_vit_fs_adapter.vit_large_patch16(num_classes=2, drop_path_rate=0.1, global_pool=True),
        ),
        (
            "weights/FS-VFM/FS-VFM-ViT-L.pth",
            "https://hf.co/Wolowolo/fsfm-3c/resolve/main/finetuned_models/FS-VFM_extensions/cross_dataset_DFD_and_DiFF/ViT-L_VF2_600e/FT_on_FF%2B%2B_c23_32frames/checkpoint-min_val_loss.pth?download=true",
            lambda: models_vit.vit_large_patch16(num_classes=2, drop_path_rate=0.1, global_pool=True),
        ),
    )

    for known_path, link, build_model in variants:
        if checkpoint_path == known_path:
            download_model_if_needed(checkpoint_path, link)
            self.model = build_model()
            return

    raise ValueError(f"Unknown FS-VFM checkpoint path: {checkpoint_path}")
class ForAda(BaseDeepakeDetectionModel):
    """Runs the Forensics Adapter ``DS`` network through the shared test loop."""

    def __init__(self, config: Config):
        super().__init__(config, verbose=True)

        from pathlib import Path  # local import: only used to locate the YAML file

        module_path = Path(__file__)

        # The previous lookup string-replaced "forensics_adapter.py" in __file__.
        # When this module has a different filename the replace is a no-op and
        # the module's own source would be opened as YAML, so that result is
        # only kept when the replacement actually changed the path.
        legacy_path = Path(__file__.replace("forensics_adapter.py", "forensics_adapter_model/config.yaml"))
        candidates = []
        if legacy_path != module_path:
            candidates.append(legacy_path)
        # NOTE(review): the real location of config.yaml is not visible from this
        # file; these candidates are guesses based on the `forada` package import
        # and the old directory name - confirm and trim.
        candidates.append(module_path.parent / "forada" / "config.yaml")
        candidates.append(module_path.parent / "forensics_adapter_model" / "config.yaml")

        config_path = next((path for path in candidates if path.is_file()), None)
        if config_path is None:
            searched = ", ".join(str(path) for path in candidates)
            raise FileNotFoundError(f"ForAda model config not found; searched: {searched}")

        with open(config_path, "r") as f:
            # Separate name so the `config: Config` argument is not shadowed.
            model_config = yaml.safe_load(f)

        self.model = DS(
            clip_name=model_config["clip_model_name"],
            adapter_vit_name=model_config["vit_name"],
            num_quires=model_config["num_quires"],
            fusion_map=model_config["fusion_map"],
            mlp_dim=model_config["mlp_dim"],
            mlp_out_dim=model_config["mlp_out_dim"],
            head_num=model_config["head_num"],
        )
        self.eval()

    @override
    def forward(self, inputs: torch.Tensor) -> HeadOutput:
        """Run the detector in inference mode and wrap its ``"logits"`` output."""
        outputs = self.model({"image": inputs}, inference=True)
        return HeadOutput(logits_labels=outputs["logits"])

    @override
    def on_test_epoch_start(self):
        """Reset the metric accumulators and move the detector to the trainer's device."""
        self.test_step_outputs = OutputsForMetrics()
        # move model to the device
        self.model.to(self.trainer.strategy.root_device)

    @override
    def test_step(self, batch, batch_idx):
        """Predict one batch and accumulate labels, softmax probabilities and indices."""
        batch = self.get_batch(batch)
        outputs = self.forward(batch.images)
        probs = outputs.logits_labels.softmax(dim=1)

        # Save outputs for metrics calculation
        self.test_step_outputs.labels.update(batch.labels)
        self.test_step_outputs.probs.update(probs.detach())
        self.test_step_outputs.idx.update(batch.idx)

    @override
    def load_checkpoint(self, checkpoint_path: str):
        """Load the model checkpoint."""
        logger.print_info(f"Loading checkpoint from {checkpoint_path}")
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        # Non-strict load; missing and unexpected keys are reported below.
        incompatible_keys = self.model.load_state_dict(state_dict, strict=False)
        self.print_checkpoint_keys(incompatible_keys)

    @override
    def get_preprocessing(self):
        """Return the callable that turns a PIL image into a model input tensor."""

        def preprocess(image: Image.Image) -> torch.Tensor:
            return preprocessing(image)

        return preprocess
if __name__ == "__main__":
    # Manual smoke test: load the released weights and classify one image.
    #! Run as module:
    #! python -m src.model.ForAda

    from PIL import Image

    from src.config import Config
    from src.model.ForAda import ForAda

    config = Config()
    model = ForAda(config)

    model.load_checkpoint("weights/forensics_adapter/ForensicsAdapter.pth")

    path = "datasets/FF/real/000/000.png"
    image = Image.open(path)  # Load image
    preprocessed_image = model.get_preprocessing()(image)  # Convert to tensor
    batch = preprocessed_image.unsqueeze(0)  # Add batch dimension
    outputs = model(batch)

    print(outputs.logits_labels)  # Print logits labels
    print(outputs.logits_labels.softmax(dim=1))  # Print probabilities
def _init_feature_extractor(self):
    """Create ``self.feature_extractor`` for the configured backbone.

    The encoder family is chosen by substring match on the lower-cased backbone
    name: "clip", then "vit_pe", then "dino". Encoder modules are imported
    lazily so only the selected family's dependencies are loaded.

    Raises:
        ValueError: If the backbone name matches no known family, or a
            "vit_pe" backbone is requested without ``backbone_args``.
    """
    logger.print("\n[blue]Initializing image encoder...")

    backbone = self.config.backbone
    backbone_lowercase = backbone.lower()
    backbone_args = self.config.backbone_args

    if "clip" in backbone_lowercase:
        from src.encoders.clip_encoder import CLIPEncoder

        self.feature_extractor = CLIPEncoder(backbone)

    elif "vit_pe" in backbone_lowercase:
        from src.encoders.perception_encoder import PerceptionEncoder

        # The image size is read from backbone_args; fail with a clear message
        # instead of an AttributeError on None when it was not configured.
        if backbone_args is None:
            raise ValueError(f"Backbone '{backbone}' requires `backbone_args.img_size` to be set")

        self.feature_extractor = PerceptionEncoder(backbone, backbone_args.img_size)

    elif "dino" in backbone_lowercase:
        from src.encoders.dino_encoder import DINOEncoder

        if backbone_args is not None:
            merge_cls_token_with_patches = backbone_args.merge_cls_token_with_patches
        else:
            merge_cls_token_with_patches = None

        self.feature_extractor = DINOEncoder(backbone, merge_cls_token_with_patches)

    else:
        raise ValueError(f"Unknown backbone: {backbone}")

    logger.print(self.feature_extractor)
def _freeze_parameters(self):
    """Set which parameters are trainable.

    The whole feature extractor is frozen or unfrozen according to
    ``config.freeze_feature_extractor``. Afterwards every parameter of the
    module whose name contains one of ``config.unfreeze_layers`` is made
    trainable again.
    """
    extractor_trainable = not self.config.freeze_feature_extractor
    self.feature_extractor.requires_grad_(extractor_trainable)

    if len(self.config.unfreeze_layers) == 0:
        return

    for param_name, parameter in self.named_parameters():
        for pattern in self.config.unfreeze_layers:
            if pattern in param_name:
                parameter.requires_grad = True
                break
def log_aliunif(self, outputs: head.HeadOutput, labels: torch.Tensor, stage: str, batch_size: int):
    """Log the alignment and uniformity of the batch embeddings as epoch-level metrics.

    Args:
        outputs: Head output whose ``l2_embeddings`` are evaluated.
        labels: Labels used to form same-label pairs for the alignment metric.
        stage: Prefix of the logged keys ("<stage>/alignment", "<stage>/uniformity").
        batch_size: Batch size passed to ``self.log``.
    """
    # These values are only logged, never back-propagated: detach so no
    # autograd graph is built for the pairwise computations.
    embeddings = outputs.l2_embeddings.detach()

    alignment = unifalign.alignment(embeddings, labels)
    uniformity = unifalign.uniformity(embeddings)

    common = {"prog_bar": self.is_debug_mode, "on_epoch": True, "on_step": False, "batch_size": batch_size}
    self.log(f"{stage}/alignment", alignment, **common)
    self.log(f"{stage}/uniformity", uniformity, **common)
def validation_step(self, batch, batch_idx):
    """Process one validation batch.

    Runs the forward pass, evaluates the criterion, logs the loss and the
    alignment / uniformity statistics under the ``val`` stage, and stores
    labels, detached probabilities and sample indices for the epoch-end
    metric computation. Returns nothing.
    """
    data = self.get_batch(batch)
    outputs = self.forward(data.images)

    loss = self.criterion(
        LossInputs(
            logits_labels=outputs.logits_labels,
            labels=data.labels,
            l2_embeddings=outputs.l2_embeddings,
        )
    )
    probs = self.get_probs(outputs)

    stage = "val"
    self.log_loss(loss, stage, len(data.images))
    self.log_aliunif(outputs, data.labels, stage, len(data.images))

    # Keep what the epoch-end metrics need
    collector = self.val_step_outputs
    collector.labels.update(data.labels)
    collector.probs.update(probs.detach())
    collector.idx.update(data.idx)
def configure_optimizers(self):
    """Build the optimizer and the optional per-step learning-rate schedule.

    Trainable parameters are split into two groups: those whose name contains
    ``"bias"`` or ``"norm"`` get no weight decay, all others use
    ``config.weight_decay``. The optimizer is AdamW or SGD depending on
    ``config.optimizer`` (SGD takes its momentum from ``config.betas[0]``).

    The main schedule is selected by ``config.lr_scheduler`` (``"cosine"`` or
    ``"cyclic"``; any other value leaves it unset). When
    ``config.warmup_epochs > 0`` a linear warmup is placed in front of it, or
    used alone if there is no main schedule.

    Returns:
        A dict with the key ``"optimizer"`` and, when a schedule exists,
        ``"lr_scheduler"`` configured to step after every optimizer step.

    Raises:
        ValueError: if ``config.optimizer`` is neither AdamW nor SGD.
    """
    self.trainer.fit_loop.setup_data()  # because we need an access to the dataloader
    config = self.config

    # Separate parameters for weight decay and no weight decay
    decay_params = []
    no_decay_params = []
    for name, param in self.named_parameters():
        # Frozen parameters are not handed to the optimizer at all
        if not param.requires_grad:
            continue
        # Substring match on the parameter name
        if "bias" in name or "norm" in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_grouped_parameters = [
        {"params": decay_params, "weight_decay": config.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    # Configure optimizer
    # The top-level weight_decay is only a default; both groups above set their own value.
    if config.optimizer == C.Optimizer.AdamW:
        optimizer = optim.AdamW(
            optimizer_grouped_parameters,
            lr=config.lr,
            weight_decay=config.weight_decay,
            betas=config.betas,
        )
    elif config.optimizer == C.Optimizer.SGD:
        optimizer = optim.SGD(
            optimizer_grouped_parameters,
            lr=config.lr,
            momentum=config.betas[0],
            weight_decay=config.weight_decay,
        )
    else:
        raise ValueError(f"Unknown optimizer: {config.optimizer}")

    optimizers = {"optimizer": optimizer}

    scheduler = None

    # Configure LR scheduler
    # All lengths below are measured in optimizer steps: epochs * len(train_dataloader).
    if config.lr_scheduler == "cosine":
        #! be careful when running experiments with limit_train_batches
        if config.limit_train_batches is not None:
            logger.print_warning_once("lr scheduling and limit_train_batches are not compatible")
        # NOTE(review): T_max spans all training steps and is not reduced by the
        # warmup length when warmup is enabled below - confirm this is intended.
        T_max = config.max_epochs * len(self.trainer.train_dataloader)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=config.min_lr)

    elif config.lr_scheduler == "cyclic":
        # Despite its name, this value is a number of steps, not epochs.
        cycle_length_in_epochs = int(config.num_epochs_in_cycle * len(self.trainer.train_dataloader))
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=cycle_length_in_epochs, T_mult=1, eta_min=config.min_lr
        )

    # Configure warmup
    if config.warmup_epochs > 0:
        total_warmup_steps = int(config.warmup_epochs * len(self.trainer.train_dataloader))
        # Warmup ramps the learning rate from min_lr up to lr.
        # NOTE(review): LinearLR appears to require 0 < start_factor <= 1, so
        # min_lr == 0 or min_lr > lr would be rejected - confirm the config rules this out.
        warmup = optim.lr_scheduler.LinearLR(
            optimizer, start_factor=config.min_lr / config.lr, total_iters=total_warmup_steps
        )

        if scheduler is not None:
            # Warmup first, then hand over to the main schedule at the milestone
            scheduler = optim.lr_scheduler.SequentialLR(
                optimizer, [warmup, scheduler], milestones=[total_warmup_steps]
            )
        else:
            scheduler = warmup

    if scheduler is not None:
        optimizers["lr_scheduler"] = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
        }

    return optimizers
+ def forward(self, inputs: torch.Tensor) -> HeadOutput: + return HeadOutput(logits_labels=self.model(inputs)) + + @override + def test_step(self, batch, batch_idx): + batch = self.get_batch(batch) + outputs = self.forward(batch.images) + probs = outputs.logits_labels.softmax(dim=1) + + # Save outputs for metrics calculation + self.test_step_outputs.labels.update(batch.labels) + self.test_step_outputs.probs.update(probs.detach()) + self.test_step_outputs.idx.update(batch.idx) + + @override + def load_checkpoint(self, checkpoint_path: str): + """Load the model checkpoint.""" + pass # Handled by from_pretrained + + @override + def get_preprocessing(self): + return self.model.feature_extractor.preprocess + + +if __name__ == "__main__": + config = Config( + checkpoint="yermandy/GenD_CLIP_L_14", + ) + model = GenDHF(config) + model.load_checkpoint(config.checkpoint) + + image = Image.open("datasets/FF/DF/001_870/000.png") + # image = Image.open("datasets/FF/real/001/000.png") + preprocessed_image = model.get_preprocessing()(image) # Convert to tensor + batch = preprocessed_image.unsqueeze(0) # Add batch dimension + outputs = model.forward(batch) + print(outputs.logits_labels.softmax(dim=-1)) diff --git a/src/model/__init__.py b/src/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/model/base.py b/src/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef94973c0949b64b4dd1d87c285074d597285de --- /dev/null +++ b/src/model/base.py @@ -0,0 +1,446 @@ +from dataclasses import dataclass +from typing import Callable, Literal + +import lightning as pl +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import wandb +from lightning import seed_everything +from lightning.pytorch.loggers import WandbLogger +from PIL import Image +from sklearn import metrics as M +from torchmetrics import CatMetric + +from src import metrics, plots 
def compute_across_videos(files: list, probs: np.ndarray, labels: np.ndarray, reduce: Literal["mean", "median"]):
    """Aggregate frame-level probabilities into one prediction per video.

    Frames are grouped by the directory part of their path, i.e. everything
    before the last ``/`` (``a/b/c/d`` belongs to video ``a/b/c``). Paths
    without any ``/`` all fall into a single group with an empty directory.

    Args:
        files: Frame paths; ``files[i]`` corresponds to ``probs[i]`` and ``labels[i]``.
        probs: Per-frame probabilities, indexed by frame along the first axis.
        labels: Per-frame labels, indexed by frame.
        reduce: How the frames of one video are combined, ``"mean"`` or ``"median"``.

    Returns:
        A tuple ``(video_probs, video_labels)`` with one entry per video, in
        order of each video's first appearance in ``files``. The label of a
        video is the label of its first frame.

    Raises:
        ValueError: if ``reduce`` is not ``"mean"`` or ``"median"``.
    """
    if reduce == "mean":
        aggregate = np.mean
    elif reduce == "median":
        aggregate = np.median
    else:
        raise ValueError(f"Unknown reduce method: {reduce}")

    # Group frame indices by video; dict insertion order keeps first-seen order.
    # rpartition handles paths with no "/" or with a trailing "/" without special cases.
    video2idx: dict[str, list[int]] = {}
    for i, f in enumerate(files):
        video = f.rpartition("/")[0]
        video2idx.setdefault(video, []).append(i)

    # Reduce over the frame axis only, so the class axis is preserved
    video_probs = np.array([aggregate(probs[idxs], axis=0) for idxs in video2idx.values()])
    # Assume all frames of a video share the same label
    video_labels = np.array([labels[idxs[0]] for idxs in video2idx.values()])

    return video_probs, video_labels
= metrics.ovr_roc(labels, probs) + precs, recs, pr_ths, ovr_macro_ap = metrics.ovr_prc(labels, probs) + + if self.config.num_classes == 2: + # Compute EER (Equal Error Rate) + eer, eer_th = metrics.calculate_eer(labels, probs, True) + self.log(f"{prefix}/eer_{level}", eer) + self.log(f"{prefix}/eer_th_{level}", eer_th) + + # Compute TPR at selected FPRs, e.g., 0.1%, 1%, 5% + selected_fprs = [0.001, 0.01, 0.05] + tpr_at_fprs = metrics.calculate_tpr_at_fpr(labels, probs, selected_fprs) + for target_fpr, tpr in zip(selected_fprs, tpr_at_fprs): + self.log(f"{prefix}/TPR@FPR={target_fpr}_{level}", tpr) + + plots.plot_fpr_fnr_curve( + fprs, + tprs, + roc_ths, + title=f"{Stage} FPR vs FNR ({level}-level)", + path=f"{log_dir}/{prefix}/{level}_metrics/{stage}_fpr_fnr_curve.png", + eer=eer, + ) + + W1_sep_real, W1_sep_fake, W1_conf_real, W1_conf_fake = metrics.compute_wasserstein1_metrics(probs, labels) + + if W1_sep_real is not None: + self.log(f"{prefix}/W1-sep-real_{level}", W1_sep_real) + self.log(f"{prefix}/W1-sep-fake_{level}", W1_sep_fake) + + # A mean of Wasserstein distances + self.log(f"{prefix}/W1-sep_{level}", (W1_sep_real + W1_sep_fake) / 2) + + self.log(f"{prefix}/W1-conf-real_{level}", W1_conf_real) + self.log(f"{prefix}/W1-conf-fake_{level}", W1_conf_fake) + + # A mean of Wasserstein distances + self.log(f"{prefix}/W1-conf_{level}", (W1_conf_real + W1_conf_fake) / 2) + + # Compute predictions by EER threshold + preds = np.where(probs[:, 1] > eer_th, 1, 0) + + else: + # Compute predictions by argmax rule + preds = probs.argmax(1) + + # Log metrics + self.log(f"{prefix}/auroc_{level}", ovr_macro_auroc) + self.log(f"{prefix}/acc_{level}", M.accuracy_score(labels, preds)) + self.log(f"{prefix}/balanced_acc_{level}", M.balanced_accuracy_score(labels, preds)) + self.log(f"{prefix}/f1_score_{level}", M.f1_score(labels, preds, average="macro")) + self.log(f"{prefix}/mAP_{level}", ovr_macro_ap) + + class_names = dataset.get_class_names() + + 
def sources_probs_to_binary(self, probs: np.ndarray) -> np.ndarray:
    """Collapse per-source probabilities into two columns: real and fake.

    Column 0 of ``probs`` is taken as the "real" probability; the "fake"
    probability is the maximum over all remaining columns (one per generator).
    """
    real_probs = probs[:, 0]
    fake_probs = probs[:, 1:].max(axis=1)
    return np.stack((real_probs, fake_probs), axis=1)
def custom_preprocessing(self, image: Image.Image) -> Image.Image:
    """Apply the optional zoom / resize / flip steps configured for this model.

    Returns the image unchanged when ``config.custom_preprocessing`` is None.
    Otherwise, in this order:

    1. if ``zoom_factor`` differs from 1.0, take a centered crop of
       ``size // zoom_factor`` and scale it to ``image_size`` (or back to the
       original size when ``image_size`` is None);
    2. if ``image_size`` is set, resize to it;
    3. if ``flip_left_right`` is set, mirror the image horizontally.
    """
    options = self.config.custom_preprocessing
    if options is None:
        return image

    zoom = options.zoom_factor
    if zoom != 1.0:
        width, height = image.size

        # Smaller centered window simulates zooming in
        crop_w, crop_h = width // zoom, height // zoom
        left, top = (width - crop_w) // 2, (height - crop_h) // 2
        zoomed = image.crop((left, top, left + crop_w, top + crop_h))

        # Use bilinear interpolation to preserve artifacts
        target_size = options.image_size if options.image_size is not None else (width, height)
        image = zoomed.resize(target_size, Image.BILINEAR)

    if options.image_size is not None:
        image = image.resize(options.image_size, Image.BILINEAR)

    if options.flip_left_right:
        image = image.transpose(Image.FLIP_LEFT_RIGHT)

    return image
def print_checkpoint_keys(self, incompatible_keys):
    """Print how the checkpoint keys matched this model's state dict.

    Args:
        incompatible_keys: The result of ``load_state_dict(..., strict=False)``,
            exposing ``missing_keys`` and ``unexpected_keys``.

    Every key of the model's own state dict is listed as missing (absent from
    the checkpoint) or matched. Unexpected keys are listed afterwards, in
    checkpoint order.
    """
    missing_keys = set(incompatible_keys.missing_keys)

    logger.print("\n[blue bold]Keys in checkpoint:")
    logger.print("[red bold]- Missing")
    logger.print("[yellow bold]? Unexpected")
    logger.print("[green bold]+ Matched\n")

    for key in self.state_dict().keys():
        if key in missing_keys:
            logger.print(f"[red]- {key}")
        else:
            logger.print(f"[green]+ {key}")

    # Unexpected keys exist only in the checkpoint, never in self.state_dict(),
    # so they have to be reported separately or they would never be shown.
    # The tag matches the "[yellow bold]? Unexpected" legend printed above.
    for key in incompatible_keys.unexpected_keys:
        logger.print(f"[yellow]? {key}")
def compute_weight_loss(self):
    """Return the negated mean singular value of the summed effective weights.

    The current weights of all ``SVDResidualLinear`` modules in the backbone
    are summed per weight shape. For each shape the singular values of the sum
    are computed and their negated mean is taken; the result is the average of
    these values over the shapes.

    Returns:
        A scalar tensor, or ``0.0`` when the backbone contains no
        ``SVDResidualLinear`` module (the same convention the per-module loss
        helpers use when there is nothing to regularize).
    """
    weight_sums = {}
    for module in self.backbone.modules():
        if not isinstance(module, SVDResidualLinear):
            continue
        weight = module.compute_current_weight()
        shape = tuple(weight.shape)
        if shape in weight_sums:
            # Accumulate out of place: the first tensor stored for a shape may be
            # a module's own `weight_main` parameter (returned as-is when the
            # module has no residual), and `+=` would silently overwrite it.
            weight_sums[shape] = weight_sums[shape] + weight
        else:
            weight_sums[shape] = weight

    if not weight_sums:
        return 0.0

    loss2 = 0.0
    for summed_weight in weight_sums.values():
        _, S_sum, _ = torch.linalg.svd(summed_weight, full_matrices=False)
        loss2 = loss2 - torch.mean(S_sum)
    return loss2 / len(weight_sums)
# Custom module to represent the residual using SVD components
class SVDResidualLinear(nn.Module):
    """Linear layer whose weight is a fixed main part plus an optional SVD residual.

    The effective weight is ``weight_main + U_residual @ diag(S_residual) @ V_residual``.
    ``weight_main`` is created here as a non-trainable parameter. The residual
    factors (``S_residual``, ``U_residual``, ``V_residual``), the kept factors
    (``S_r``, ``U_r``, ``V_r``) and ``weight_original_fnorm`` are not created
    by ``__init__``; they are attached afterwards (see
    ``replace_with_svd_residual``). Until they are attached, or when they are
    set to None, the layer behaves like a plain linear layer on ``weight_main``
    and every loss helper returns ``0.0``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        r: Number of top singular values to exclude from the residual.
        bias: Whether the layer has a (zero-initialised) bias.
        init_weight: Optional tensor copied into ``weight_main``; when None the
            weight is Kaiming-uniform initialised.
    """

    def __init__(self, in_features, out_features, r, bias=True, init_weight=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.r = r  # Number of top singular values to exclude

        # Original weights (fixed)
        self.weight_main = nn.Parameter(torch.Tensor(out_features, in_features), requires_grad=False)
        if init_weight is not None:
            self.weight_main.data.copy_(init_weight)
        else:
            nn.init.kaiming_uniform_(self.weight_main, a=math.sqrt(5))

        # Bias
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
            nn.init.zeros_(self.bias)
        else:
            self.register_parameter("bias", None)

    def _has_residual(self):
        """Return True when all three residual factors are attached and not None."""
        # getattr with a default: the factors are attached after construction,
        # so they may be entirely absent on a freshly built module.
        return all(
            getattr(self, name, None) is not None for name in ("S_residual", "U_residual", "V_residual")
        )

    def _residual_weight(self):
        """Reconstruct the residual weight; only valid when `_has_residual()` is True."""
        return self.U_residual @ torch.diag(self.S_residual) @ self.V_residual

    def compute_current_weight(self):
        """Return the effective weight (``weight_main`` itself when there is no residual)."""
        if self._has_residual():
            return self.weight_main + self._residual_weight()
        return self.weight_main

    def forward(self, x):
        """Apply the linear map with the effective weight and the bias."""
        # Total weight is the fixed main weight plus the residual, if one is set
        return F.linear(x, self.compute_current_weight(), self.bias)

    def compute_orthogonal_loss(self):
        """Penalise deviation of the stacked kept + residual factors from orthogonality.

        Returns ``0.0`` when no residual is attached.
        """
        if not self._has_residual():
            return 0.0

        # According to the properties of orthogonal matrices: A^TA = I
        U_full = torch.cat((self.U_r, self.U_residual), dim=1)
        V_full = torch.cat((self.V_r, self.V_residual), dim=0)
        UUT = U_full @ U_full.t()
        VVT = V_full @ V_full.t()

        # Construct an identity matrix
        UUT_identity = torch.eye(UUT.size(0), device=UUT.device)
        VVT_identity = torch.eye(VVT.size(0), device=VVT.device)

        # Using frobenius norm to compute loss
        return 0.5 * torch.norm(UUT - UUT_identity, p="fro") + 0.5 * torch.norm(VVT - VVT_identity, p="fro")

    def compute_keepsv_loss(self):
        """Return ``| ||W_current||_F^2 - ||W_original||_F^2 |``.

        Returns ``0.0`` when no residual is attached or when
        ``weight_original_fnorm`` has not been recorded.
        """
        original_fnorm = getattr(self, "weight_original_fnorm", None)
        if not self._has_residual() or original_fnorm is None:
            return 0.0

        # Frobenius norm of current weight
        weight_current_fnorm = torch.norm(self.compute_current_weight(), p="fro")
        return torch.abs(weight_current_fnorm**2 - original_fnorm**2)

    def compute_fn_loss(self):
        """Return the squared Frobenius norm of the effective weight.

        Returns ``0.0`` when no residual is attached.
        """
        if not self._has_residual():
            return 0.0

        weight_current_fnorm = torch.norm(self.compute_current_weight(), p="fro")
        return weight_current_fnorm**2
module.in_features + out_features = module.out_features + bias = module.bias is not None + + # Create SVDResidualLinear module + new_module = SVDResidualLinear(in_features, out_features, r, bias=bias, init_weight=module.weight.data.clone()) + + if bias and module.bias is not None: + new_module.bias.data.copy_(module.bias.data) + + new_module.weight_original_fnorm = torch.norm(module.weight.data, p="fro") + + # Perform SVD on the original weight + U, S, Vh = torch.linalg.svd(module.weight.data, full_matrices=False) + + # Determine r based on the rank of the weight matrix + r = min(r, len(S)) # Ensure r does not exceed the number of singular values + + # Keep top r singular components (main weight) + U_r = U[:, :r] # Shape: (out_features, r) + S_r = S[:r] # Shape: (r,) + Vh_r = Vh[:r, :] # Shape: (r, in_features) + + # Reconstruct the main weight (fixed) + weight_main = U_r @ torch.diag(S_r) @ Vh_r + + # Calculate the frobenius norm of main weight + new_module.weight_main_fnorm = torch.norm(weight_main.data, p="fro") + + # Set the main weight + new_module.weight_main.data.copy_(weight_main) + + # Residual components (trainable) + U_residual = U[:, r:] # Shape: (out_features, n - r) + S_residual = S[r:] # Shape: (n - r,) + Vh_residual = Vh[r:, :] # Shape: (n - r, in_features) + + if len(S_residual) > 0: + new_module.S_residual = nn.Parameter(S_residual.clone()) + new_module.U_residual = nn.Parameter(U_residual.clone()) + new_module.V_residual = nn.Parameter(Vh_residual.clone()) + + new_module.S_r = nn.Parameter(S_r.clone(), requires_grad=False) + new_module.U_r = nn.Parameter(U_r.clone(), requires_grad=False) + new_module.V_r = nn.Parameter(Vh_r.clone(), requires_grad=False) + else: + new_module.S_residual = None + new_module.U_residual = None + new_module.V_residual = None + + new_module.S_r = None + new_module.U_r = None + new_module.V_r = None + + return new_module + else: + return module + + +# This is the original preprocessing used in Effort paper +# Gives 
def preprocessing_original(image: Image) -> torch.Tensor:
    """Resize ``image`` to 224x224 (bilinear) and apply ``_preprocessing_original``.

    The image is first converted to a NumPy array, resized with OpenCV using
    ``cv2.INTER_LINEAR`` and then passed through the module-level
    ``_preprocessing_original`` transform pipeline.
    """
    resized = cv2.resize(np.array(image), (224, 224), interpolation=cv2.INTER_LINEAR)
    return _preprocessing_original(resized)
class Adapter(nn.Module):
    """ViT-based adapter that turns an image plus CLIP features into attention biases and x-ray maps.

    A timm ViT (built without pretrained weights, class token or final norm) is
    run on patch embeddings produced by ``patch_conv`` with learnable query
    tokens prepended.  The tokens after the selected block are fed to a
    ``Mask_Decoder`` that returns an x-ray prediction and an attention bias.
    """

    def __init__(self, vit_name, num_quires, fusion_map, mlp_dim, mlp_out_dim, head_num):
        """Build the backbone, query embeddings and mask decoder.

        Args:
            vit_name: model name passed to ``timm.create_model``.
            num_quires: number of learnable query tokens prepended to the patch tokens.
            fusion_map: mapping looked up by block index in ``fuse``; values index ``clip_features``.
            mlp_dim: hidden width passed to ``Mask_Decoder``.
            mlp_out_dim: output width passed to ``Mask_Decoder``.
            head_num: number of heads passed to ``Mask_Decoder``.
        """
        super().__init__()
        self.vit_model = create_model(
            vit_name,
            pretrained=False,
            fc_norm=False,
            num_classes=0,
            embed_layer=PatchEmbed,
        )

        if self.vit_model.cls_token is not None:
            self.vit_model.pos_embed = nn.Parameter(self.vit_model.pos_embed[:, 1:, ...])  # drop the positional slot of the cls token
        # The class token and the final norm of the backbone are not used by this adapter.
        del self.vit_model.cls_token
        self.vit_model.cls_token = None
        del self.vit_model.norm
        self.vit_model.norm = nn.Identity()
        self.num_quires = num_quires
        self.num_features = self.vit_model.num_features
        self.query_embed = nn.Parameter(torch.zeros(1, self.num_quires, self.num_features))  # (1,Q_L,D)
        self.query_pos_embed = nn.Parameter(torch.zeros(1, self.num_quires, self.num_features))
        self.fusion_map = fusion_map
        nn.init.normal_(self.query_embed, std=0.02)
        nn.init.normal_(self.query_pos_embed, std=0.02)
        self.mask_decoder = Mask_Decoder(
            in_dim=self.num_features, mlp_dim=mlp_dim, out_dim=mlp_out_dim, mlp_num_layers=3, head_num=head_num
        )
        self.ln_pre = VT_LN(self.num_features)
        # Non-overlapping 16x16 patch embedding applied directly to the input image.
        self.patch_conv = nn.Conv2d(in_channels=3, out_channels=self.num_features, kernel_size=16, stride=16, bias=True)

    def fuse(self, block_idx, x, clip_features, spatial_shape):
        """Return ``x`` with its last ``h*w`` tokens fused with the mapped CLIP features.

        When ``block_idx`` is not a key of ``fusion_map`` the input is returned unchanged.
        """
        if block_idx in self.fusion_map.keys():
            clip_layer = self.fusion_map[block_idx]
            # adapter_layer = block_idx
            clip_dim = clip_features[clip_layer].shape[2]  # clip features NLD

            # NOTE(review): a new Fusion module is constructed on every call and is never
            # registered on this Adapter, so it presumably has fresh, untrained weights
            # each time - confirm this is intended.
            fusion = Fusion(clip_dim, self.num_features).to(x.device)
            L = spatial_shape[0] * spatial_shape[1]
            x = torch.cat(
                [
                    x[:, :-L, ...],  # query
                    # fuse vision token x(N,a_L,D) clip_f[i] (N,c_L,D)
                    fusion(x[:, -L:, ...], clip_features[clip_layer], spatial_shape),
                ],
                dim=1,
            )
        return x

    def forward(self, data_dict, clip_features, inference):
        """Run the adapter and decode attention biases and x-ray predictions.

        Args:
            data_dict: mapping holding the input batch under the ``"image"`` key.
            clip_features: CLIP features indexed by layer; only forwarded to ``fuse``.
            inference: not used inside this method.

        Returns:
            ``(attn_biases, xray_preds, loss_intra)`` where the first two are lists with
            one entry per output layer and ``loss_intra`` is the constant ``0``.
        """
        image = data_dict["image"]
        x = self.patch_conv(image)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)

        # Resample the backbone positional embedding from a 14x14 grid to 16x16.
        # assumes pos_embed holds exactly 196 patch positions and that the input
        # produces a 16x16 patch grid (e.g. 256x256 image) - TODO confirm
        pos_embed = self.vit_model.pos_embed  # (N L D)
        pos_embed = pos_embed.permute(0, 2, 1)  # (NDL)
        pos_embed = (
            F.interpolate(
                pos_embed.reshape(pos_embed.shape[0], pos_embed.shape[1], 14, 14),
                size=(16, 16),
                mode="bilinear",
                align_corners=False,
            )
            .reshape(pos_embed.shape[0], pos_embed.shape[1], 256)
            .permute(0, 2, 1)
        )  # NDL->NLD
        pos_embed = torch.cat([self.query_pos_embed.expand(pos_embed.shape[0], -1, -1), pos_embed], dim=1)
        v_L = x.shape[1]  # number of vision tokens; the reshape to (n, d, h, w) below needs h * w = 256
        (h, w) = 16, 16  # h w 16,16

        x = torch.cat([self.query_embed.expand(x.shape[0], -1, -1), x], dim=1)  # (N ,Q_L+L,D)
        x = x + pos_embed
        x = self.ln_pre(x)
        outs = []
        out_layers = [8]
        loss_intra = 0
        # self.fuse(0, x, clip_features, (h, w))
        for i, block in enumerate(self.vit_model.blocks, start=1):  # total 1-12 ,only use 1-8
            x = block(x)  # (N, Q_L+L, D)
            # NOTE(review): the value returned by fuse() is discarded, so unless Fusion
            # mutates its input in place this call leaves x unchanged - confirm whether
            # `x = self.fuse(...)` was intended.
            self.fuse(i, x, clip_features, (h, w))

            if i in out_layers:
                n, _, d = x.shape
                outs.append(
                    {
                        "query": x[:, :-v_L, ...],
                        "x": x[:, -v_L:, ...].permute(0, 2, 1).reshape(n, d, h, w),
                    }
                )
            # The positional embedding is re-added after every block.
            x = x + pos_embed
            if i == max(out_layers):
                break
        xray_preds = []
        attn_biases = []

        for feature in outs:
            xray_pred, attn_bias = self.mask_decoder(feature["query"], feature["x"])
            xray_preds.append(xray_pred)
            attn_biases.append(attn_bias)

        return attn_biases, xray_preds, loss_intra
class RecAttnClip(nn.Module):
    """Re-run the CLIP vision transformer with extra query tokens and an additive attention bias.

    The wrapped ``vit`` supplies the residual blocks, ``ln_post`` and ``proj``.
    ``num_quires`` copies of the class token are prepended as query tokens, and
    every block receives the mask produced by ``build_attn_mask``.  After each
    block listed in ``intra_map`` a ``ClipIntraBlock`` adds a small residual to
    the vision tokens.  Only the ``clip_intra_blocks`` parameters are trainable.
    """

    def __init__(self, vit, num_quires):
        super().__init__()
        self.vit = vit
        self.resblocks = self.vit.transformer.resblocks
        self.first_layer = 0
        self.clss_nums = num_quires
        self.ln_post = self.vit.ln_post
        self.proj = self.vit.proj
        self.num_features = self.vit.width
        # Not read in forward(); kept so the parameter set (and state_dict keys) stay unchanged.
        self.intra_scale = nn.Parameter(torch.zeros(1))
        # Maps a block index (0-based) to the ClipIntraBlock applied after that block.
        self.intra_map = {6: 0}
        self.clip_intra_blocks = nn.ModuleList([ClipIntraBlock(self.num_features) for _ in range(1)])
        self._freeze()

    def build_attn_mask(self, attn_bias):
        """Expand ``attn_bias`` of shape (n, Head, q, h, w) into per-block additive masks.

        The mask covers ``q`` query tokens, one class token and ``l = h*w`` vision
        tokens.  No token may attend to a query token other than itself (-100),
        query tokens may not attend to the class token (-100), and the
        query-to-vision entries are filled with ``attn_bias``.

        Returns:
            A list with one entry per residual block; every entry is the same
            ``(n*Head, q+1+l, q+1+l)`` tensor object.
        """
        num_heads = self.resblocks[0].attn.num_heads
        n, Head, q, h, w = attn_bias.shape

        assert Head == num_heads, f"num_head={Head} is not supported. Modify to {num_heads}"
        attn_bias = attn_bias.reshape(n * Head, q, -1)
        l = attn_bias.shape[-1]
        attn_mask = attn_bias.new_zeros(q + 1 + l, q + 1 + l)
        attn_mask[:, :q] = -100  # block attention towards every query token ...
        attn_mask[torch.arange(q), torch.arange(q)] = 0  # ... except from a query to itself
        attn_mask[:q, q] = -100  # queries do not look at the class token
        attn_mask = attn_mask[None, ...].expand(n * Head, -1, -1).clone()
        attn_mask[:, :q, -l:] = attn_bias
        # attn_mask (n*head,1+q+l,1+q+l)
        attn_biases = [attn_mask for _ in self.resblocks.children()]
        return attn_biases

    def _freeze(self):
        """Leave only the ``clip_intra_blocks`` parameters trainable."""
        for name, param in self.named_parameters():
            param.requires_grad = "clip_intra_blocks" in name

    def forward(self, data_dict, clip_features, attn_bias, inference=False, normalize=False):
        """Return the processed query tokens and the accumulated intra-block loss.

        Args:
            data_dict: forwarded unchanged to the intra blocks.
            clip_features: mapping with the layer-``first_layer`` vision tokens (NLD)
                and the matching ``"layer_{first_layer}_cls"`` class token (ND).
            attn_bias: tensor of shape (n, Head, q, h, w) used by ``build_attn_mask``.
            inference: forwarded unchanged to the intra blocks.
            normalize: L2-normalise the returned tokens along the last dimension.

        Returns:
            ``(clss_token, loss_clip)``; ``clss_token`` holds the first
            ``num_quires`` tokens after ``ln_post`` and the optional projection.
        """
        cls_token = clip_features[f"layer_{self.first_layer}_cls"].unsqueeze(1).permute(1, 0, 2).clone()  # ND->N1D->1ND
        vision_tokens = clip_features[self.first_layer].permute(1, 0, 2).clone()  # NLD->LND
        clss_token = cls_token.repeat(self.clss_nums, 1, 1)  # 1ND -> clss_nums, N,D

        x = torch.cat([clss_token, cls_token, vision_tokens], dim=0)  # (Q+1+L,N,D)
        x.requires_grad = True
        clip_L = vision_tokens.shape[0]  # number of vision tokens

        attn_biases = self.build_attn_mask(attn_bias)

        loss_clip = 0
        for i, blocks in enumerate(self.resblocks.children()):
            x = blocks(x, attn_biases[i])
            # Driven by intra_map so the branch and the lookup below cannot drift apart.
            if i in self.intra_map:
                intra_x, loss_clip_tmp = self.clip_intra_blocks[self.intra_map[i]](x, data_dict, clip_L, inference)
                loss_clip = loss_clip_tmp + loss_clip
                # Small residual update applied to the vision tokens only.
                x[-clip_L:, ...] = intra_x * 0.05 + x[-clip_L:, ...]

        x = x.permute(1, 0, 2)  # LND -> NLD
        clss_token = x[:, : self.clss_nums, :]
        clss_token = self.ln_post(clss_token)
        if self.proj is not None:
            clss_token = clss_token @ self.proj
        if normalize:
            clss_token = F.normalize(clss_token, dim=-1)

        return clss_token, loss_clip
def _sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of the file at ``path``, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download(url: str, root: str):
    """Download ``url`` into ``root`` unless a verified copy already exists.

    The expected SHA256 checksum is taken from the second-to-last path
    component of ``url`` and the local file name from the last one.

    Args:
        url: location of the checkpoint.
        root: directory for the downloaded file; created when missing.

    Returns:
        Path of the local file whose checksum matches.

    Raises:
        RuntimeError: if the target path exists but is not a regular file, or
            if the downloaded file does not match the expected checksum.
    """
    # The module-level `import urllib` alone does not guarantee the submodule is loaded.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if _sha256_of(download_target) == expected_sha256:
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # Servers may omit Content-Length; tqdm accepts total=None in that case.
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80, unit="iB", unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if _sha256_of(download_target) != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target
+ + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, "rb") as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict 加载已训练好的state_dict权重到模型中 + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + # build model + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on 
def tokenize(
    texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False
) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Tokenize one string or a list of strings into a zero-padded tensor.

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single input string or a list of input strings

    context_length : int
        Number of token slots per row (default 77)

    truncate: bool
        When True, encodings longer than ``context_length`` are cut and their last
        token is replaced by the end-of-text token; when False a RuntimeError is raised

    Returns
    -------
    A tensor of shape [number of input strings, context_length]. The dtype is long
    for torch versions below 1.8.0 and int otherwise.
    """
    if isinstance(texts, str):
        texts = [texts]

    vocab = _tokenizer.encoder
    sot_token = vocab["<|startoftext|>"]
    eot_token = vocab["<|endoftext|>"]
    all_tokens = [[sot_token, *_tokenizer.encode(text), eot_token] for text in texts]

    # Older torch releases need long indices.
    legacy_torch = packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0")
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long if legacy_torch else torch.int)

    for row, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {texts[row]} is too long for context length {context_length}")
            tokens = tokens[:context_length]
            tokens[-1] = eot_token
        result[row, : len(tokens)] = torch.tensor(tokens)

    return result
class AttentionPool2d(nn.Module):
    """Pool an NCHW feature map into one vector per sample with multi-head attention.

    The spatial mean is prepended as an extra token, a learned positional
    embedding is added, and that mean token is used as the only query over all
    tokens.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        token_count = spacial_dim**2 + 1  # one slot per spatial position plus the mean token
        self.positional_embedding = nn.Parameter(torch.randn(token_count, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        # NCHW -> (HW)NC
        tokens = x.flatten(start_dim=2).permute(2, 0, 1)
        # Prepend the spatial mean as the query token -> (HW+1)NC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)

        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1],
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        # Drop the length-1 query dimension -> NC
        return pooled.squeeze(0)
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``.

    Args:
        d_model: token width.
        n_head: number of attention heads.
        attn_mask: optional default attention mask, used whenever ``forward``
            is called without an explicit mask (e.g. the causal mask of the
            text transformer, whose blocks are invoked as ``block(x)``).
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)

        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, d_model * 4)),
                    ("gelu", QuickGELU()),
                    ("c_proj", nn.Linear(d_model * 4, d_model)),
                ]
            )
        )
        self.ln_2 = LayerNorm(d_model)
        # Kept as a plain attribute (not a buffer) so state_dict keys are unchanged.
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor, attn_mask=None):
        """Self-attention over ``x``; a per-call ``attn_mask`` overrides the constructor mask."""
        if attn_mask is None:
            # getattr: instances created before this attribute existed have no stored mask.
            attn_mask = getattr(self, "attn_mask", None)
        attn_mask = attn_mask.to(dtype=x.dtype, device=x.device) if attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]

    def forward(self, x: torch.Tensor, attn_mask=None):
        x = x + self.attention(self.ln_1(x), attn_mask)
        x = x + self.mlp(self.ln_2(x))
        return x
class VisionTransformer(nn.Module):
    """CLIP vision tower: patch embedding, class token, transformer and projection.

    ``forward`` either returns the projected class-token feature or, when
    ``extract`` is given, whatever the transformer returns for that request.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width**-0.5
        grid_tokens = (input_resolution // patch_size) ** 2
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(grid_tokens + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = VTransformer(width, layers, heads)
        self.width = width
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
        self.extract = False

    def forward(self, x: torch.Tensor, extract=None):
        # Patchify: NCHW -> N, D, L -> N, L, D
        patches = self.conv1(x)
        patches = patches.reshape(patches.shape[0], patches.shape[1], -1).permute(0, 2, 1)

        # One class token per sample, prepended to the patch tokens.
        batch, _, dim = patches.shape
        cls_tokens = self.class_embedding.to(patches.dtype) + torch.zeros(
            batch, 1, dim, dtype=patches.dtype, device=patches.device
        )
        tokens = torch.cat([cls_tokens, patches], dim=1)
        tokens = tokens + self.positional_embedding.to(tokens.dtype)
        tokens = self.ln_pre(tokens)

        tokens = tokens.permute(1, 0, 2)  # NLD -> LND
        if extract is not None:
            return self.transformer(tokens, extract)

        out, tokens = self.transformer(tokens)
        tokens = tokens.permute(1, 0, 2)  # LND -> NLD
        feature = self.ln_post(tokens[:, 0, :])
        out["before_projection"] = feature

        if self.proj is not None:
            feature = feature @ self.proj
        out["after_projection"] = feature

        # `out` (the intermediate features) is collected but only the CLIP feature is returned.
        return feature
context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + ): + super().__init__() + + self.context_length = context_length + # Resnet + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width, + ) + # VIT + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + ) + # text + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in 
resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width**-0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + # get_image_embedding + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def extract_features(self, image, extract): + return self.visual(image.type(self.dtype), extract) # extract_features,x,cls + + # get_text_embedding + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) # LayerNorm + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / 
def convert_weights(model: nn.Module):
    """Cast the applicable parameters of ``model`` to fp16, in place.

    Every sub-module is visited via ``model.apply``. Converted are:

    * ``weight`` / ``bias`` of ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Linear``;
    * the projection weights and biases of ``nn.MultiheadAttention``;
    * any non-``None`` ``text_projection`` or ``proj`` attribute of a module.

    Other parameters (for example normalisation layers) keep their dtype.
    """
    attention_attrs = (
        "in_proj_weight",
        "q_proj_weight",
        "k_proj_weight",
        "v_proj_weight",
        "in_proj_bias",
        "bias_k",
        "bias_v",
    )
    projection_attrs = ("text_projection", "proj")

    def _convert_weights_to_fp16(module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        if isinstance(module, nn.MultiheadAttention):
            for attr_name in attention_attrs:
                tensor = getattr(module, attr_name)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        # Bare projection parameters stored directly on a module.
        for attr_name in projection_attrs:
            candidate = getattr(module, attr_name, None)
            if candidate is not None:
                candidate.data = candidate.data.half()

    model.apply(_convert_weights_to_fp16)
@lru_cache()
def default_bpe():
    """Return the absolute path of the BPE vocabulary file expected next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "bpe_simple_vocab_16e6.txt.gz")
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer driven by a gzip-compressed merges file.

    Words are split into byte-encoded symbols and the last symbol of every word
    carries the end-of-word marker ``"</w>"``. The marker is what distinguishes
    the two halves of the base vocabulary (``v`` and ``v + "</w>"``) and is
    turned back into a space by :meth:`decode`.

    Args:
        bpe_path: path to the gzip-compressed merges file. The default is
            evaluated once, when the class is defined.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Close the archive deterministically instead of leaking the handle.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split("\n")
        # Skip the first line and keep 49152 - 256 - 2 merge rules.
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        # Second half of the base vocabulary: the same symbols at end of word.
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )

    def bpe(self, token):
        """Apply the learned merges to ``token``.

        Returns the resulting symbols joined by single spaces; the last symbol
        ends with ``"</w>"``. Results are memoised in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            # Merge the adjacent pair with the lowest rank first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # No further occurrence of `first`: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lower-case ``text`` and return its list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
        """Map token ids back to text, turning each ``"</w>"`` marker into a space."""
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("</w>", " ")
        return text
clip_model_name: "ViT-L/14" +# dataset +all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV] +train_dataset: [FaceForensics++] +#test_dataset: [FaceForensics++] +test_dataset: [FaceForensics++] + +#test_dataset: [FaceForensics++] + + +#test_dataset: [FF-F2F] + +#dataset_json_folder: '/media/ouc/新加卷/DS_CLIP_V2/dataset/dataset_json' +dataset_json_folder: '/data/cuixinjie/dataset/dataset_json' + +compression: c23 # compression-level for videos +train_batchSize: 20 # training batch size +test_batchSize: 64 # test batch size +workers: 8 # number of data loading workers +frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing +resolution: 256 # resolution of output image to network +with_mask: true # whether to include mask information in the input +with_xray: true +with_patch_labels: true +with_landmark: false # whether to include facial landmark information in the input +# label settings +label_dict: + # DFD + DFD_fake: 1 + DFD_real: 0 + # FF++ + FaceShifter(FF-real+FF-FH) + FF-SH: 1 + FF-F2F: 1 + FF-DF: 1 + FF-FS: 1 + FF-NT: 1 + FF-FH: 1 + FF-real: 0 + # CelebDF + CelebDFv1_real: 0 + CelebDFv1_fake: 1 + CelebDFv2_real: 0 + CelebDFv2_fake: 1 + # DFDCP + DFDCP_Real: 0 + DFDCP_FakeA: 1 + DFDCP_FakeB: 1 + # DFDC + DFDC_Fake: 1 + DFDC_Real: 0 + # DeeperForensics-1.0 + DF_fake: 1 + DF_real: 0 + # UADFV + UADFV_Fake: 1 + UADFV_Real: 0 + + + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.48145466, 0.4578275, 0.40821073] +std: [0.26862954, 0.26130258, 0.27577711] + +# optimizer config +optimizer: + # 
choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.0002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 60 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 2 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +#logdir: /media/ouc/新加卷/DS_CLIP_V2/log # folder to output images and logs + +manualSeed: 1020 # manual seed for random number generation +save_ckpt: true # whether to save checkpoiccnt +save_feat: false + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda +ngpu: 1 # number of GPUs to use +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations diff --git a/src/model/forada/ds.py b/src/model/forada/ds.py new file mode 100644 index 0000000000000000000000000000000000000000..52ec6076cd1a4dbe90a1f65f8501859360e49862 --- /dev/null +++ b/src/model/forada/ds.py @@ -0,0 +1,87 @@ +import torch.nn.functional as F +from torch import nn + +from .adapters.adapter import Adapter +from .attn import RecAttnClip +from .clip.clip import load +from .layer import MaskPostXrayProcess, PostClipProcess + + +class DS(nn.Module): + def __init__( + self, clip_name, adapter_vit_name, num_quires, fusion_map, mlp_dim, mlp_out_dim, head_num, mode="video" + ): + super().__init__() + self.clip_model, self.processor = load(clip_name, download_root="weights/forensics_adapter") + self.adapter = Adapter( + vit_name=adapter_vit_name, + num_quires=num_quires, + fusion_map=fusion_map, + mlp_dim=mlp_dim, + mlp_out_dim=mlp_out_dim, + 
def get_losses(self, data_dict, pred_dict):
    """Compute the training losses for one batch.

    Args:
        data_dict: batch dictionary. ``"label"`` is required; ``"xray"`` is
            optional and may be ``None``.
        pred_dict: model outputs. The classification logits are read from
            ``"cls"`` or, when that key is absent, from ``"logits"`` (the key
            ``forward`` produces). ``"xray_pred"``, ``"loss_intra"`` and
            ``"loss_clip"`` are only needed for the auxiliary terms.

    Returns:
        A dict that always holds ``"cls"`` and ``"overall"``. When an x-ray
        target and all three auxiliary predictions are available it also holds
        ``"xray"``, ``"intra"`` and ``"loss_clip"``, and ``"overall"`` is the
        weighted sum ``10*cls + 200*xray + 20*intra + 10*loss_clip``.
        Otherwise ``"overall"`` equals the classification loss.
    """
    label = data_dict["label"]  # N
    pred = pred_dict["cls"] if "cls" in pred_dict else pred_dict["logits"]  # N2
    criterion = nn.CrossEntropyLoss()
    loss1 = criterion(pred.float(), label)

    xray = data_dict.get("xray")
    aux_keys = ("xray_pred", "loss_intra", "loss_clip")
    # Only touch the auxiliary outputs when they can actually be used, so a
    # classification-only prediction dict does not raise KeyError.
    if xray is None or not all(key in pred_dict for key in aux_keys):
        return {"cls": loss1, "overall": loss1}

    xray_pred = pred_dict["xray_pred"]
    loss_intra = pred_dict["loss_intra"]
    loss_clip = pred_dict["loss_clip"]
    # squeeze() drops the singleton channel axis before the pixel-wise MSE.
    loss_mse = F.mse_loss(xray_pred.squeeze().float(), xray.squeeze().float())

    loss = 10 * loss1 + 200 * loss_mse + 20 * loss_intra + 10 * loss_clip
    return {"cls": loss1, "xray": loss_mse, "intra": loss_intra, "loss_clip": loss_clip, "overall": loss}
class LayerNorm(nn.Module):
    """
    Layer normalization over the channel axis of channels-first feature maps.

    For an input of shape (batch_size, channels, height, width) every spatial
    position is normalized to zero mean and unit variance across its channels,
    then scaled and shifted by the learnable per-channel ``weight`` and ``bias``.
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.normalized_shape = (normalized_shape,)
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x: torch.Tensor):
        # Statistics are taken over dim 1 (channels), separately per position.
        mean = x.mean(1, keepdim=True)
        variance = (x - mean).pow(2).mean(1, keepdim=True)
        normalized = (x - mean) / torch.sqrt(variance + self.eps)
        # Append two singleton axes so the per-channel affine broadcasts over H and W.
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift
class Fusion(nn.Module):
    """Project CLIP patch tokens to the adapter width and add them to the adapter tokens.

    Args:
        clip_dim: channel width of the incoming CLIP tokens.
        adapter_dim: channel width of the adapter tokens the result is added to.
    """

    def __init__(self, clip_dim, adapter_dim):
        super().__init__()
        self.clip_dim = clip_dim
        self.adapter_dim = adapter_dim
        # Channel-wise normalisation followed by a 1x1 convolution to adapter_dim.
        self.proj = nn.Sequential(
            LayerNorm(clip_dim),
            nn.Conv2d(clip_dim, adapter_dim, kernel_size=1),
        )

    def forward(self, x, clip_x, spatial_shape):
        """Fuse ``clip_x`` into ``x``.

        Args:
            x: adapter tokens; must be broadcast-compatible with a
                ``(N, h * w, adapter_dim)`` tensor.
            clip_x: CLIP tokens of shape ``(N, L, D)``.
            spatial_shape: ``(h, w)`` grid of the adapter tokens.

        Returns:
            ``x`` plus the projected (and, if needed, resized) CLIP tokens.

        Raises:
            ValueError: if ``L != h * w`` and ``L`` is not a perfect square, so
                the CLIP tokens cannot be laid out on a square grid.
        """
        h, w = spatial_shape
        n, l, d = clip_x.shape

        if l == h * w:
            clip_x = clip_x.permute(0, 2, 1).view(n, d, h, w)  # NLD->NDL->NDhw
        else:
            # Lay the tokens out on their own square grid (14x14 for 196 tokens)
            # and resize to the adapter grid instead of a fixed 16x16, so the
            # reshape to h * w below is always consistent.
            side = int(round(l**0.5))
            if side * side != l:
                raise ValueError(f"cannot arrange {l} CLIP tokens on a square grid")
            clip_x = clip_x.permute(0, 2, 1).view(n, d, side, side)  # NLD->NDL->ND side side
            clip_x = F.interpolate(
                clip_x.contiguous(),
                size=(h, w),
                mode="bilinear",
                align_corners=False,
            )  # ND side side => NDhw
        clip_x = self.proj(clip_x).view(n, self.adapter_dim, h * w).permute(0, 2, 1)
        x = x + clip_x  # NLD

        return x
class VT_LN(nn.LayerNorm):
    """``nn.LayerNorm`` that normalizes in float32 and returns the input's dtype."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        # Upcast for the normalisation, then cast back to the caller's dtype.
        normalized = super().forward(x.float())
        return normalized.to(input_dtype)
b/src/model/fsfm/models_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..fa720ddd7206d2ecd0749d75756210c836294e67 --- /dev/null +++ b/src/model/fsfm/models_vit.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- +# Author: Gaojian Wang@ZJUICSR +# -------------------------------------------------------- +# This source code is licensed under the Attribution-NonCommercial 4.0 International License. +# You can find the license in the LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +from functools import partial + +import timm.models.vision_transformer +import torch +import torch.nn as nn +from timm.models.helpers import load_pretrained +from timm.models.vision_transformer import default_cfgs + + +class VisionTransformer(timm.models.vision_transformer.VisionTransformer): + """Vision Transformer with support for global average pooling""" + + def __init__(self, global_pool=False, **kwargs): + super(VisionTransformer, self).__init__(**kwargs) + + self.global_pool = global_pool + if self.global_pool: + norm_layer = kwargs["norm_layer"] + embed_dim = kwargs["embed_dim"] + self.fc_norm = norm_layer(embed_dim) + + del self.norm # remove the original norm + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + # if self.global_pool: + # x = x[:, 1:, :].mean(dim=1) # global pool without cls token + # outcome = self.fc_norm(x) + if self.global_pool: + x_gp = x[:, 1:, :].mean(dim=1) # global pool without cls token + outcome = self.fc_norm(x_gp) + 
+ # x_new = torch.zeros_like(x) + # x_new[:, 0, :] = x_gp + # x_new[:, 1:, :] = x[:, 1:, :] + # outcome = x_new + else: + x = self.norm(x) + # outcome = x[:, 0] + outcome = x # for fas code + + return outcome + + # def reset_classifier(self, num_classes, global_pool=''): + # self.num_classes = num_classes + # self.head = nn.ModuleList([ + # nn.Linear(self.embed_dim, 512), + # nn.Linear(512, num_classes) if num_classes > 0 else nn.Identity() + # ]) + # + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _conv_filter(state_dict, patch_size=16): + """convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + if "patch_embed.proj.weight" in k: + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + return out_dict + + +def vit_small_patch16(pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) # ViT-small config in MOCO_V3 + # model = VisionTransformer( + # patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, qkv_bias=True, + # norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) # ViT-small config in timm + model.default_cfg = default_cfgs["vit_small_patch16_224"] + # if pretrained: + # load_pretrained( + # model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) + # for timm version 0.6.12: + # pretrained_cfg = resolve_pretrained_cfg('vit_base_patch16_224', + # pretrained_cfg=kwargs.pop('pretrained_cfg', None)) + # load_pretrained( + # model, pretrained_cfg=pretrained_cfg, + # num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) + return model + + +def vit_base_patch16(pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, 
def vit_huge_patch14(pretrained=False, **kwargs):
    """Build a ``VisionTransformer`` with patch size 14, width 1280, depth 32 and 16 heads.

    ``pretrained`` is accepted for signature parity with the sibling factories
    but is not used here: no weights are loaded. Remaining keyword arguments
    are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
class Adapter(nn.Module):
    """Bottleneck adapter: bias-free Linear -> ReLU -> bias-free Linear -> ReLU.

    Args:
        c_in: input (and output) feature width.
        reduction: the bottleneck width is ``c_in // reduction``.
    """

    def __init__(self, c_in, reduction=4):
        super().__init__()
        hidden = c_in // reduction
        self.linear1 = nn.Linear(c_in, hidden, bias=False)
        self.activation1 = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(hidden, c_in, bias=False)
        self.activation2 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``(output, bottleneck)``: the adapter output and the activated bottleneck features."""
        x_bottleneck = self.activation1(self.linear1(x))
        x = self.activation2(self.linear2(x_bottleneck))
        return x, x_bottleneck
def forward(self, x, training=False):
    """Classify ``x`` from backbone features concatenated with adapter features.

    Args:
        x: input passed to ``forward_features``.
        training: when true, also return the L2-normalized projection of the
            adapter bottleneck (for a contrastive loss).

    Returns:
        The head output, or ``(head output, normalized projection)`` when
        ``training`` is true.
    """
    features = self.forward_features(x)

    adapted, bottleneck = self.adapter(features)
    projected = self.projector(bottleneck)

    # The head sees backbone and adapter features side by side on the last axis.
    logits = self.head(torch.cat((features, adapted), dim=-1))

    if not training:
        return logits
    # Logits for the classification loss, projected features for the contrastive loss.
    return logits, F.normalize(projected, dim=-1)
@TryExcept("plot_curve")
def plot_curve(
    xs: list[ndarray],
    ys: list[ndarray],
    auc_threshold: float = 0.01,
    class_names: None | dict[int, str] = None,
    ax_plot=None,
    interpolate: int = 200,
    mean: bool = True,
    linestyles=("-", "--", "-.", ":"),
    palette: None | list | dict = None,
):
    """Plot one curve per class on a square axis, with the legend in a side panel.

    Args:
        xs: x values, one array per class.
        ys: y values, one array per class.
        auc_threshold: curves whose area under the curve is below this value
            are neither drawn nor listed in the legend.
        class_names: optional ``index -> name`` mapping used in legend labels.
        ax_plot: accepted for interface compatibility but ignored; a new
            figure and a new pair of axes are always created.
        interpolate: number of points on ``[0, 1]`` every curve is resampled
            to; ``-1`` disables resampling (and the mean curve).
        mean: whether to add the point-wise mean of the resampled curves.
        linestyles: line styles cycled over the classes; only the first one is
            used when there are at most four curves.
        palette: colours indexed by class position; defaults to a husl palette.

    Returns:
        The axis holding the curves, for further customisation by the caller.
    """
    # Use only one linestyle if up to 4 classes
    if len(xs) <= 4:
        linestyles = linestyles[:1]

    # Create figure with larger size and better aspect ratio
    plt.figure(figsize=(10, 8), tight_layout=True)

    # Create two subplots - one for the plot, one for the legend
    gs = plt.GridSpec(1, 2, width_ratios=[4, 1])
    ax_plot = plt.subplot(gs[0])
    ax_legend = plt.subplot(gs[1])

    if palette is None:
        palette = sns.husl_palette(len(xs))

    if interpolate != -1:
        # NOTE: interp1d raises ValueError for query points outside the range
        # of `x`, so every input curve has to span [0, 1].
        x_new = np.linspace(0, 1, interpolate)
        ys = [interp1d(x, y)(x_new) for x, y in zip(xs, ys)]
        xs = [x_new] * len(xs)

    # Plot curves on the main axis
    legend_entries = []
    for c, (x, y) in enumerate(zip(xs, ys)):
        auc = M.auc(x, y)
        if auc < auc_threshold:
            # Only curves with AUC >= threshold are drawn and listed.
            continue
        # Fall back to the bare index when no name is known for this class.
        class_name = f"{c}: {class_names[c]}" if class_names and c in class_names else c
        label = f"{class_name} (AUC: {auc:.2f})"
        linestyle = linestyles[c % len(linestyles)]
        line = ax_plot.plot(x, y, label=label, linewidth=1.5, color=palette[c], linestyle=linestyle)
        legend_entries.append((line[0], label))

    if mean and interpolate != -1 and len(ys) > 0:
        ys_mean = np.mean(ys, axis=0)
        xs_mean = np.mean(xs, axis=0)

        # Plot mean curve; the line carries the same label as its legend entry.
        auc = M.auc(xs_mean, ys_mean)
        label = f"avg (AUC: {auc:.2f})"
        line = ax_plot.plot(xs_mean, ys_mean, label=label, linewidth=1.5, color="black", linestyle="-")
        legend_entries.append((line[0], label))

    # Set square aspect ratio
    ax_plot.set_aspect("equal")

    # Set limits explicitly to ensure square plot
    ax_plot.set_xlim(-0.02, 1.02)  # Slight padding for better visibility
    ax_plot.set_ylim(-0.02, 1.02)

    # Customize the main plot
    ax_plot.grid(True, linestyle="--", alpha=0.3)

    # Create legend in the second subplot
    ax_legend.axis("off")  # Hide the axis
    if legend_entries:
        lines, labels = zip(*legend_entries)
        ax_legend.legend(lines, labels, loc="center left", fontsize=10, borderaxespad=0)

    return ax_plot
@TryExcept("plot_f1_curve")
def plot_f1_curve(
    prcs: list[ndarray],
    recs: list[ndarray],
    ths: list[ndarray],
    title: str = "F1",
    path: str = "f1_curve.png",
    auc_threshold: float = 0.01,
    class_names: None | dict[int, str] = None,
):
    """
    Plot F1 curve for multiple classes and save it to ``path``.

    Args:
        prcs: Precision values, one array per class.
        recs: Recall values, one array per class (same length as the matching ``prcs`` entry).
        ths: Thresholds, one array per class (x-axis).
        title: Title of the plot.
        path: Output file. A bare file name is written to the current directory.
        auc_threshold: Forwarded to ``plot_curve``; curves with a smaller AUC are not drawn.
        class_names: Optional ``{class index: name}`` mapping used in the legend.
    """
    f1s = []
    for prc, rec in zip(prcs, recs):
        # 0 / 0 occurs where precision and recall are both zero: silence the
        # warning and define F1 as 0 at those points.
        with np.errstate(divide="ignore", invalid="ignore"):
            f1 = np.where((prc + rec) == 0, 0, 2 * prc * rec / (prc + rec))
        # Drop the last point so F1 is one element shorter than precision/recall.
        # NOTE(review): presumably because the threshold arrays are one element
        # shorter than precision/recall -- confirm against the caller.
        f1s.append(f1[:-1])

    ax_plot = plot_curve(ths, f1s, auc_threshold, class_names)

    # Customize the main plot
    ax_plot.set_title(title, fontsize=14)
    ax_plot.set_xlabel("Threshold", fontsize=12)
    ax_plot.set_ylabel("F1 Score", fontsize=12)

    # os.path.dirname() is "" for a bare file name (such as the default path) and
    # os.makedirs("") raises FileNotFoundError, so only create real directories.
    dirname = os.path.dirname(path)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    # Save with high quality
    plt.savefig(path, dpi=300, bbox_inches="tight")
    plt.close()
""" + if len(fprs) != 2: + logger.print_warning_once("FPR vs FNR curve is only plotted for 2 classes") + return + + # Calculate FNR from TPR + fpr = fprs[1] + fnr = 1 - tprs[1] + + xs = [ths[1], ths[1]] + ys = [fpr, fnr] + + class_names = {0: "FPR", 1: "FNR"} + + ax_plot = plot_curve(xs, ys, auc_threshold, class_names, mean=False, linestyles=["-"]) + + if eer is not None: + ax_plot.axhline(y=eer, color="black", linestyle="--") + ax_plot.text(0, eer + 0.02, f"EER: {eer:.2f}", color="black", fontsize=10) + + ax_plot.set_title(title, fontsize=14) + ax_plot.set_xlabel("Threshold", fontsize=12) + ax_plot.set_ylabel("FPR vs FNR", fontsize=12) + + os.makedirs(os.path.dirname(path), exist_ok=True) + plt.savefig(path, dpi=300, bbox_inches="tight") + plt.close() + + +@TryExcept("plot_confusion_matrix") +def plot_confusion_matrix( + confusion_matrix: ndarray, + class_names: None | dict[int, str] = None, + title: str = "Confusion Matrix", + path: str = "confusion_matrix.png", + normalize: bool = False, +): + """ + Plot confusion matrix + """ + N = len(confusion_matrix) + size = max(10, N / 2) + plt.figure(figsize=(size, size), tight_layout=True) + fmt = "d" + if normalize: + confusion_matrix = confusion_matrix / confusion_matrix.sum(axis=1, keepdims=True) * 100 + confusion_matrix[np.isnan(confusion_matrix)] = 0 + fmt = ".2f" + + labels = [f"{k}: {v}" for k, v in class_names.items()] if class_names else None + sns.heatmap( + confusion_matrix, + annot=True, + fmt=fmt, + cmap="Blues", + xticklabels=labels, + yticklabels=labels, + annot_kws={"fontsize": 8}, + ) + plt.xlabel("Predicted", fontsize=12) + plt.ylabel("Actual", fontsize=12) + plt.title(title, fontsize=14, pad=20) + dirname = os.path.dirname(path) + if dirname: + os.makedirs(dirname, exist_ok=True) + plt.savefig(path, dpi=100, bbox_inches="tight") + plt.close() + + +@TryExcept("plot_features_2d") +def plot_features_2d( + features_2d: np.ndarray, # (N, 2) + set_ids: np.ndarray, # (N,) + id2label: dict[int, str], # dict 
{id: label} + output_path: str, +): + assert isinstance(features_2d, np.ndarray) + assert isinstance(set_ids, np.ndarray) + assert isinstance(id2label, dict) + + plt.figure(figsize=(25, 25)) + + palette = sns.husl_palette(len(id2label)) + id2color = {id: palette[i] for i, id in enumerate(id2label)} + + for id, label in id2label.items(): + mask = set_ids == id + + if not np.any(mask): + continue + + xs = features_2d[mask, 0] + ys = features_2d[mask, 1] + + if "real" in label: + marker = "." + else: + marker = "x" + + plt.scatter(xs, ys, c=[id2color[id]] * len(xs), marker=marker, label=label) + + for x, y, label in zip(xs, ys, set_ids[mask]): + plt.text(x, y, label, c=id2color[id], fontsize=9) + + plt.legend(loc="best", title="Models") + + plt.tight_layout() + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + plt.savefig(output_path) + plt.savefig(output_path.replace(".png", ".svg")) + + +@TryExcept("plot_probs_distribution") +def plot_probs_distribution( + probs: np.ndarray, # (N, C) + labels: np.ndarray, # (N,) + class_names: dict[int, str], # dict {id: label} + output_path: str, +): + n_classes = len(class_names) + fig, axes = plt.subplots(n_classes, 1, figsize=(10, 4 * n_classes)) + palette = sns.husl_palette(n_classes) + + # Find global min and max for x-axis limits + x_min = probs.min() + x_max = probs.max() + x_min, x_max = -0.005, 1.005 + + for idx, (class_idx, class_name) in enumerate(class_names.items()): + ax = axes[idx] + + # Get probabilities for current class + class_mask = labels == class_idx + class_probs = probs[class_mask] + + # Plot probability distribution for each possible class prediction + for pred_idx, pred_name in class_names.items(): + pred_probs = class_probs[:, pred_idx] + sns.histplot( + data=pred_probs, + label=f"ŷ={pred_name}", + color=palette[pred_idx], + alpha=0.2, + bins=100, + stat="probability", + kde=True, + element="step", + ax=ax, + ) + + ax.set_xlabel("Scores") + ax.set_ylabel("Probability") + 
def distance2bbox(points, distance, max_shape=None):
    """Decode distance prediction to bounding box.

    Args:
        points (ndarray): Shape (n, 2), [x, y].
        distance (ndarray): Shape (n, 4), distance from the given point to the 4
            boundaries (left, top, right, bottom).
        max_shape (tuple): Optional (height, width) of the image. When given,
            x coordinates are clipped to [0, width] and y coordinates to [0, height].

    Returns:
        ndarray: Decoded bboxes of shape (n, 4) as [x1, y1, x2, y2].
    """
    x1 = points[:, 0] - distance[:, 0]
    y1 = points[:, 1] - distance[:, 1]
    x2 = points[:, 0] + distance[:, 2]
    y2 = points[:, 1] + distance[:, 3]
    if max_shape is not None:
        # np.clip instead of the torch-only ``.clamp``: the inputs are NumPy
        # arrays (the result is assembled with np.stack below).
        x1 = np.clip(x1, 0, max_shape[1])
        y1 = np.clip(y1, 0, max_shape[0])
        x2 = np.clip(x2, 0, max_shape[1])
        y2 = np.clip(y2, 0, max_shape[0])
    return np.stack([x1, y1, x2, y2], axis=-1)
class RetinaFace:
    """Face detector that runs an ONNX model through onnxruntime.

    For every feature-map stride the model is expected to provide a score
    output, a box-distance output and, optionally, a keypoint-distance output.
    ``forward`` decodes those outputs, ``detect`` adds resizing, rescaling to
    the original image and non-maximum suppression.
    """

    def __init__(self, model_file=None, session=None, providers=None):
        """Create the detector from a model file or an existing session.

        Args:
            model_file: Path to the ONNX model. Required when ``session`` is None.
            session: Optional ready-made inference session; used as is.
            providers: Execution providers for a newly created session.
                Defaults to ``["CPUExecutionProvider"]``.

        Raises:
            AssertionError: If no session is given and ``model_file`` is None
                or does not exist.
        """
        if providers is None:
            providers = ["CPUExecutionProvider"]
        self.model_file = model_file
        self.session = session
        self.taskname = "detection"
        if self.session is None:
            assert self.model_file is not None
            assert os.path.exists(self.model_file)
            num_threads = int(os.environ.get("OMP_NUM_THREADS", 1))
            sess_options = onnxruntime.SessionOptions()
            sess_options.intra_op_num_threads = num_threads
            sess_options.inter_op_num_threads = num_threads
            # Pass the providers at construction time instead of creating the
            # session first and re-targeting it with set_providers() afterwards.
            self.session = onnxruntime.InferenceSession(
                self.model_file, sess_options, providers=list(providers)
            )
        self.center_cache = {}
        self.nms_thresh = 0.4
        self.det_thresh = 0.5
        self._init_vars()

    def _init_vars(self):
        """Read input/output metadata from the session and derive the head layout."""
        input_cfg = self.session.get_inputs()[0]
        input_shape = input_cfg.shape
        # A string in the spatial dimensions marks a dynamic input size.
        if isinstance(input_shape[2], str):
            self.input_size = None
        else:
            self.input_size = tuple(input_shape[2:4][::-1])
        self.input_shape = input_shape
        self.input_name = input_cfg.name
        outputs = self.session.get_outputs()
        self.output_names = [o.name for o in outputs]
        self.input_mean = 127.5
        self.input_std = 128.0
        self.use_kps = False
        self._anchor_ratio = 1.0
        self._num_anchors = 1
        # number of outputs -> (fmc, strides, anchors per location, has keypoints)
        layouts = {
            6: (3, [8, 16, 32], 2, False),
            9: (3, [8, 16, 32], 2, True),
            10: (5, [8, 16, 32, 64, 128], 1, False),
            15: (5, [8, 16, 32, 64, 128], 1, True),
        }
        layout = layouts.get(len(outputs))
        # For any other output count fmc / _feat_stride_fpn stay unset, exactly
        # as before; forward() will then fail on the missing attribute.
        if layout is not None:
            self.fmc, self._feat_stride_fpn, self._num_anchors, self.use_kps = layout

    def prepare(self, ctx_id, **kwargs):
        """Adjust providers, thresholds and input size.

        Args:
            ctx_id: A negative value switches the session to the CPU provider.
            **kwargs: Optional ``nms_thresh``, ``det_thresh`` and ``input_size``.
                ``input_size`` is ignored (with a printed warning) when the
                model already defines a fixed input size.
        """
        if ctx_id < 0:
            self.session.set_providers(["CPUExecutionProvider"])
        nms_thresh = kwargs.get("nms_thresh", None)
        if nms_thresh is not None:
            self.nms_thresh = nms_thresh
        det_thresh = kwargs.get("det_thresh", None)
        if det_thresh is not None:
            self.det_thresh = det_thresh
        input_size = kwargs.get("input_size", None)
        if input_size is not None:
            if self.input_size is not None:
                print("warning: det_size is already set in detection model, ignore")
            else:
                self.input_size = input_size

    def forward(self, img, threshold):
        """Run the network on ``img`` and decode the detections of every stride.

        Args:
            img: Image array; its first two dimensions are used as (height, width).
            threshold: Minimum score for a detection to be kept.

        Returns:
            Tuple ``(scores_list, bboxes_list, kpss_list)`` with one entry per
            stride. ``kpss_list`` stays empty when the model has no keypoints.
        """
        scores_list = []
        bboxes_list = []
        kpss_list = []
        input_size = tuple(img.shape[0:2][::-1])
        blob = cv2.dnn.blobFromImage(
            img, 1.0 / self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True
        )
        net_outs = self.session.run(self.output_names, {self.input_name: blob})

        input_height = blob.shape[2]
        input_width = blob.shape[3]
        fmc = self.fmc
        for idx, stride in enumerate(self._feat_stride_fpn):
            # Output layout: scores at idx, boxes at idx + fmc, keypoints at idx + 2 * fmc.
            scores = net_outs[idx]
            bbox_preds = net_outs[idx + fmc] * stride
            if self.use_kps:
                kps_preds = net_outs[idx + fmc * 2] * stride
            height = input_height // stride
            width = input_width // stride
            key = (height, width, stride)
            if key in self.center_cache:
                anchor_centers = self.center_cache[key]
            else:
                # (x, y) grid of feature-map cells, scaled to input pixels.
                anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
                anchor_centers = (anchor_centers * stride).reshape((-1, 2))
                if self._num_anchors > 1:
                    # Repeat every center once per anchor at that location.
                    anchor_centers = np.stack([anchor_centers] * self._num_anchors, axis=1).reshape((-1, 2))
                # Keep the cache bounded.
                if len(self.center_cache) < 100:
                    self.center_cache[key] = anchor_centers

            pos_inds = np.where(scores >= threshold)[0]
            bboxes = distance2bbox(anchor_centers, bbox_preds)
            pos_scores = scores[pos_inds]
            pos_bboxes = bboxes[pos_inds]
            scores_list.append(pos_scores)
            bboxes_list.append(pos_bboxes)
            if self.use_kps:
                kpss = distance2kps(anchor_centers, kps_preds)
                kpss = kpss.reshape((kpss.shape[0], -1, 2))
                pos_kpss = kpss[pos_inds]
                kpss_list.append(pos_kpss)
        return scores_list, bboxes_list, kpss_list

    def detect(self, img, input_size=None, max_num=0, metric="default"):
        """Detect faces in ``img``.

        The image is resized with preserved aspect ratio into the top-left
        corner of a zero-filled canvas of ``input_size``, passed through
        ``forward`` and the results are scaled back to the original image.

        Args:
            img: Image array of shape (height, width, 3).
            input_size: ``(width, height)`` of the network input. Falls back to
                ``self.input_size``.
            max_num: When positive, keep at most this many detections.
            metric: ``"max"`` ranks the kept detections by box area only; any
                other value also penalises the distance to the image center.

        Returns:
            Tuple ``(det, kpss)``: ``det`` has one row ``[x1, y1, x2, y2, score]``
            per face; ``kpss`` holds the matching keypoints or is None when the
            model has no keypoint outputs.

        Raises:
            AssertionError: If neither ``input_size`` nor ``self.input_size`` is set.
        """
        assert input_size is not None or self.input_size is not None
        input_size = self.input_size if input_size is None else input_size

        im_ratio = float(img.shape[0]) / img.shape[1]
        model_ratio = float(input_size[1]) / input_size[0]
        if im_ratio > model_ratio:
            new_height = input_size[1]
            new_width = int(new_height / im_ratio)
        else:
            new_width = input_size[0]
            new_height = int(new_width * im_ratio)
        det_scale = float(new_height) / img.shape[0]
        resized_img = cv2.resize(img, (new_width, new_height))
        det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8)
        det_img[:new_height, :new_width, :] = resized_img

        scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh)

        scores = np.vstack(scores_list)
        scores_ravel = scores.ravel()
        order = scores_ravel.argsort()[::-1]
        bboxes = np.vstack(bboxes_list) / det_scale
        if self.use_kps:
            kpss = np.vstack(kpss_list) / det_scale
        pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
        pre_det = pre_det[order, :]
        keep = self.nms(pre_det)
        det = pre_det[keep, :]
        if self.use_kps:
            kpss = kpss[order, :, :]
            kpss = kpss[keep, :, :]
        else:
            kpss = None
        if max_num > 0 and det.shape[0] > max_num:
            area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
            img_center = img.shape[0] // 2, img.shape[1] // 2
            offsets = np.vstack(
                [(det[:, 0] + det[:, 2]) / 2 - img_center[1], (det[:, 1] + det[:, 3]) / 2 - img_center[0]]
            )
            offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
            if metric == "max":
                values = area
            else:
                values = area - offset_dist_squared * 2.0  # some extra weight on the centering
            bindex = np.argsort(values)[::-1]
            bindex = bindex[0:max_num]
            det = det[bindex, :]
            if kpss is not None:
                kpss = kpss[bindex, :]
        return det, kpss

    def nms(self, dets):
        """Greedy non-maximum suppression.

        Args:
            dets: Array with rows ``[x1, y1, x2, y2, score]``.

        Returns:
            List of row indices to keep, in order of decreasing score. A box is
            dropped when its IoU with an already kept box exceeds ``self.nms_thresh``.
        """
        thresh = self.nms_thresh
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = dets[:, 2]
        y2 = dets[:, 3]
        scores = dets[:, 4]

        # "+ 1": box extents are treated as inclusive pixel coordinates.
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            rest = order[1:]
            xx1 = np.maximum(x1[i], x1[rest])
            yy1 = np.maximum(y1[i], y1[rest])
            xx2 = np.minimum(x2[i], x2[rest])
            yy2 = np.minimum(y2[i], y2[rest])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[rest] - inter)

            # Indices are relative to ``rest``, hence the + 1 into ``order``.
            inds = np.where(ovr <= thresh)[0]
            order = order[inds + 1]

        return keep
def checks(config: Config):
    """Validate a run configuration before training starts.

    Steps, in order:
      1. Disable wandb when the run name contains 'tmp'.
      2. If the run folder already exists (and its path does not contain 'tmp'):
         raise, remove it, or interactively ask for permission to remove it,
         depending on the config flags.
      3. Reject binary labels combined with a class count other than 2.
      4. Verify that every listed train / val / test file exists.

    Raises:
        FileExistsError: Run folder exists and ``throw_exception_if_run_exists`` is set.
        ValueError: ``binary_labels`` is set but ``num_classes`` is not 2.
        FileNotFoundError: A listed train, val or test file is missing.
    """
    save_dir = f"{config.run_dir}/{config.run_name}"

    if "tmp" in config.run_name:
        logger.print_warning("Using 'tmp' in run name. Wandb will not be used.")
        config.wandb = False

    run_exists = os.path.exists(save_dir) and "tmp" not in save_dir
    if run_exists:
        if config.throw_exception_if_run_exists:
            raise FileExistsError(f"Folder {save_dir} exists, remove it or include 'tmp' in run name")
        logger.print()
        logger.print_warning(f"folder [magenta]{save_dir}[/] exists, remove it or include 'tmp' in run name")
        if not config.remove_if_run_exists:
            # Interactively ask; anything other than "R" aborts the program
            logger.print("Enter [green bold]R[/] to replace")
            if input() != "R":
                logger.print_error("Aborted")
                exit()
        # Reached either with remove_if_run_exists set or after the user typed "R"
        logger.print_warning(f"Folder [magenta]{save_dir}[/] is removed")
        shutil.rmtree(str(save_dir))

    if config.binary_labels and config.num_classes != 2:
        raise ValueError("Binary labels is only supported for 2 classes")

    def flatten(entry: list[str] | dict[str, list[str]]):
        # A plain list is returned as is; a dict contributes all of its value lists
        if isinstance(entry, list):
            return entry
        return [path for group in entry.values() for path in group]

    # Checked one split at a time so the first split with missing files raises
    for split, attr in (("train", "trn_files"), ("val", "val_files"), ("test", "tst_files")):
        missing = [path for path in flatten(getattr(config, attr)) if not os.path.exists(path)]
        if missing:
            raise FileNotFoundError(f"Some {split} files are not found: {missing}")
# See more: https://pytorch.org/docs/stable/elastic/run.html#environment-variables
# Rank of this process within its node.
LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))
# Rank of this process across all nodes. Previously this read LOCAL_RANK, which
# made IS_GLOBAL_ZERO true on the first process of every node. LOCAL_RANK is
# kept as the fallback so launchers that do not export RANK behave as before.
RANK = int(os.getenv("RANK", LOCAL_RANK))
IS_GLOBAL_ZERO = RANK == 0
NODE_RANK = int(os.getenv("NODE_RANK", 0))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
def find_run_dir(run_name: str) -> str:
    """Return the single directory ``runs/<group>/<run_name>``.

    ``<group>`` may be any directory directly below ``runs`` (relative to the
    current working directory). ``run_name`` is matched literally: characters
    that are special in glob patterns (``*``, ``?``, ``[``) are escaped.

    Raises:
        FileNotFoundError: No matching directory exists.
        FileExistsError: More than one directory matches.
    """
    # Local import: the module only imports the ``glob`` function itself.
    from glob import escape

    runs = glob(f"runs/*/{escape(run_name)}")
    if not runs:
        raise FileNotFoundError(f"Directory for run '{run_name}' is not found")
    if len(runs) > 1:
        raise FileExistsError(f"Multiple directories found for run '{run_name}': {runs}")
    return runs[0]
class CDFv2:
    """https://arxiv.org/abs/1909.12962"""

    # Test split
    test = Files(
        [
            "config/datasets/CDFv2/test/Celeb-synthesis.txt",
            "config/datasets/CDFv2/test/YouTube-real.txt",
            "config/datasets/CDFv2/test/Celeb-real.txt",
        ]
    )

    # Not an official validation set: it is generated from the {all}\{test} files
    # using scripts/datasets/create_validation_set.py
    val = Files(
        [
            "config/datasets/CDFv2/val/Celeb-synthesis.txt",
            "config/datasets/CDFv2/val/YouTube-real.txt",
            "config/datasets/CDFv2/val/Celeb-real.txt",
        ]
    )

    # Lists under the my-train folder
    my_train = Files(
        [
            "config/datasets/CDFv2/my-train/Celeb-synthesis.txt",
            "config/datasets/CDFv2/my-train/YouTube-real.txt",
            "config/datasets/CDFv2/my-train/Celeb-real.txt",
        ]
    )
"config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class MobileFaceSwap: + test = Files( + "config/datasets/CDFv3/test/MobileFaceSwap.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class SimSwap: + test = Files( + "config/datasets/CDFv3/test/SimSwap.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class UniFace: + test = Files( + "config/datasets/CDFv3/test/UniFace.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + test = Files( + CDFv2.test + + BlendFace.test + + GHOST.test + + HifiFace.test + + InSwapper.test + + MobileFaceSwap.test + + SimSwap.test + + UniFace.test + ).unique() + + class FR: + """Face Reenectment""" + + class DaGAN: + test = Files( + "config/datasets/CDFv3/test/DaGAN.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class FSRT: + test = Files( + "config/datasets/CDFv3/test/FSRT.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class HyperReenact: + test = Files( + "config/datasets/CDFv3/test/HyperReenact.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class LIA: + test = Files( + "config/datasets/CDFv3/test/LIA.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class LivePortrait: + test = Files( + "config/datasets/CDFv3/test/LivePortrait.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class MCNET: + test = Files( + "config/datasets/CDFv3/test/MCNET.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class TPSMM: + test = Files( + 
"config/datasets/CDFv3/test/TPSMM.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + test = Files( + DaGAN.test + FSRT.test + HyperReenact.test + LIA.test + LivePortrait.test + MCNET.test + TPSMM.test + ).unique() + + class TF: + """Talking Face""" + + class AniTalker: + test = Files( + "config/datasets/CDFv3/test/AniTalker.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class EchoMimic: + test = Files( + "config/datasets/CDFv3/test/EchoMimic.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class EDTalk: + test = Files( + "config/datasets/CDFv3/test/EDTalk.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class FLOAT: + test = Files( + "config/datasets/CDFv3/test/FLOAT.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class IP_LAP: + test = Files( + "config/datasets/CDFv3/test/IP_LAP.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class Real3DPortrait: + test = Files( + "config/datasets/CDFv3/test/Real3DPortrait.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + class SadTalker: + test = Files( + "config/datasets/CDFv3/test/SadTalker.txt", + "config/datasets/CDFv3/test/Celeb-real.txt", + "config/datasets/CDFv3/test/YouTube-real.txt", + ) + + test = Files( + AniTalker.test + + EchoMimic.test + + EDTalk.test + + FLOAT.test + + IP_LAP.test + + Real3DPortrait.test + + SadTalker.test + ).unique() + + def to_train(f) -> str: + return f.replace("/test/", "/train/") + + def to_my_val(f) -> str: + return f.replace("/test/", "/my-val/") + + def to_x1_5(f) -> str: + return f.replace("/CDFv3/", "/CDFv3-x1.5/") + + def to_x2(f) -> str: + return 
f.replace("/CDFv3/", "/CDFv3-x2.0/") + + def to_rmbg_x1_5(f) -> str: + return f.replace("/CDFv3/", "/CDFv3-rmbg-x1.5/") + + def to_x1_3_th0_5_all(f) -> str: + return f.replace("/CDFv3/", "/CDFv3-x1.3-th0.5-all/") + + test = Files(FS.test + FR.test + TF.test).unique() + train = test.map(to_train) + my_val = test.map(to_my_val) + + @classmethod + def get_test_dict(cls) -> dict[str, list[str]]: + return { + "CDFv3": cls.test, + "CDFv3-FS": cls.FS.test, + "CDFv3-FR": cls.FR.test, + "CDFv3-TF": cls.TF.test, + "CDFv3-FS-CDFv2": cls.FS.CDFv2.test, + "CDFv3-FS-BlendFace": cls.FS.BlendFace.test, + "CDFv3-FS-GHOST": cls.FS.GHOST.test, + "CDFv3-FS-HifiFace": cls.FS.HifiFace.test, + "CDFv3-FS-InSwapper": cls.FS.InSwapper.test, + "CDFv3-FS-MobileFaceSwap": cls.FS.MobileFaceSwap.test, + "CDFv3-FS-SimSwap": cls.FS.SimSwap.test, + "CDFv3-FS-UniFace": cls.FS.UniFace.test, + "CDFv3-FR-DaGAN": cls.FR.DaGAN.test, + "CDFv3-FR-FSRT": cls.FR.FSRT.test, + "CDFv3-FR-HyperReenact": cls.FR.HyperReenact.test, + "CDFv3-FR-LIA": cls.FR.LIA.test, + "CDFv3-FR-LivePortrait": cls.FR.LivePortrait.test, + "CDFv3-FR-MCNET": cls.FR.MCNET.test, + "CDFv3-FR-TPSMM": cls.FR.TPSMM.test, + "CDFv3-TF-AniTalker": cls.TF.AniTalker.test, + "CDFv3-TF-EchoMimic": cls.TF.EchoMimic.test, + "CDFv3-TF-EDTalk": cls.TF.EDTalk.test, + "CDFv3-TF-FLOAT": cls.TF.FLOAT.test, + "CDFv3-TF-IP_LAP": cls.TF.IP_LAP.test, + "CDFv3-TF-Real3DPortrait": cls.TF.Real3DPortrait.test, + "CDFv3-TF-SadTalker": cls.TF.SadTalker.test, + } + + +class DFD: + test = Files( + "config/datasets/DFD/fake.txt", + "config/datasets/DFD/real.txt", + ) + + +class DFDC: + test = Files( + "config/datasets/DFDC/test/fake.txt", + "config/datasets/DFDC/test/real.txt", + ) + + +class FSh: + """ + FSh: https://github.com/maum-ai/faceshifter + FF++: https://github.com/ondyari/FaceForensics + """ + + test = Files( + "config/datasets/FSh/test/fake.txt", + "config/datasets/FSh/test/real.txt", + ) + + +class UADFV: + """https://arxiv.org/abs/1806.02877""" + + test = 
class FFIW:
    """https://arxiv.org/abs/2103.16076"""

    test = Files(
        [
            "config/datasets/FFIW/test-fake.txt",
            "config/datasets/FFIW/test-real.txt",
        ]
    )

    val = Files(
        [
            "config/datasets/FFIW/val-fake.txt",
            "config/datasets/FFIW/val-real.txt",
        ]
    )

    train = Files(
        [
            "config/datasets/FFIW/train-fake.txt",
            "config/datasets/FFIW/train-real.txt",
        ]
    )

    # Custom subset of the FFIW train lists (files named *-subset-1024)
    train_subset_1024 = Files(
        [
            "config/datasets/FFIW/subsets/train-fake-subset-1024.txt",
            "config/datasets/FFIW/subsets/train-real-subset-1024.txt",
        ]
    )

    # Custom subset of FFIW created using scripts/datasets/FFIW/create_FFIW_subset.py
    train_subset_2048 = Files(
        [
            "config/datasets/FFIW/subset-2048/train-fake.txt",
            "config/datasets/FFIW/subset-2048/train-real.txt",
        ]
    )
class DeepSpeak_v2:
    """https://arxiv.org/abs/2408.05366"""

    # List files of the test split: one file per generator plus one for real videos.
    test = Files(
        "config/datasets/DeepSpeak-2.0/test/test-diff2lip.txt",
        "config/datasets/DeepSpeak-2.0/test/test-facefusion.txt",
        "config/datasets/DeepSpeak-2.0/test/test-hellomeme.txt",
        "config/datasets/DeepSpeak-2.0/test/test-latentsync.txt",
        "config/datasets/DeepSpeak-2.0/test/test-liveportrait.txt",
        "config/datasets/DeepSpeak-2.0/test/test-memo.txt",
        "config/datasets/DeepSpeak-2.0/test/test-real.txt",
    )

    # The remaining splits are derived from `test` by rewriting the
    # "/test/test-" part of every path, so they always cover the same
    # generators as the test split.
    # NOTE(review): presumably Files.map applies the callable to each path and
    # returns a new Files object — confirm against the Files implementation.
    train = test.map(lambda x: x.replace("/test/test-", "/train/train-"))

    # DeepSpeak-2.0 has a train folder. my-val is sampled from train
    my_val = test.map(lambda x: x.replace("/test/test-", "/my-val/val-"))

    # DeepSpeak-2.0 has a train folder. my-val is sampled from train, my-train is train \ my-val
    my_train = test.map(lambda x: x.replace("/test/test-", "/my-train/train-"))
class PolyGlotFake:
    """https://arxiv.org/abs/2405.08838"""

    # The test set is the concatenation of:
    #   * seven "real-<lang>" lists (ar, en, es, fr, ja, ru, zh), and
    #   * fake lists named "<lang>2<lang>_<method>", where <method> is either
    #     "Wav2Lip" or "video_retalking".
    # NOTE(review): the names look like "<source language>2<target language>";
    # confirm against the dataset description.
    # NOTE(review): every language pair has both methods except "ru2ja" and
    # "zh2ja", which only have a Wav2Lip list — verify that the missing
    # "*_video_retalking.txt" files are absent from the dataset and were not
    # simply forgotten here.
    test = Files(
        "config/datasets/PolyGlotFake/real-ar.txt",
        "config/datasets/PolyGlotFake/real-en.txt",
        "config/datasets/PolyGlotFake/real-es.txt",
        "config/datasets/PolyGlotFake/real-fr.txt",
        "config/datasets/PolyGlotFake/real-ja.txt",
        "config/datasets/PolyGlotFake/real-ru.txt",
        "config/datasets/PolyGlotFake/real-zh.txt",
        "config/datasets/PolyGlotFake/ar2en_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ar2en_video_retalking.txt",
        "config/datasets/PolyGlotFake/ar2es_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ar2es_video_retalking.txt",
        "config/datasets/PolyGlotFake/ar2fr_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ar2fr_video_retalking.txt",
        "config/datasets/PolyGlotFake/ar2ja_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ar2ja_video_retalking.txt",
        "config/datasets/PolyGlotFake/ar2ru_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ar2ru_video_retalking.txt",
        "config/datasets/PolyGlotFake/ar2zh_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ar2zh_video_retalking.txt",
        "config/datasets/PolyGlotFake/en2ar_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/en2ar_video_retalking.txt",
        "config/datasets/PolyGlotFake/en2es_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/en2es_video_retalking.txt",
        "config/datasets/PolyGlotFake/en2fr_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/en2fr_video_retalking.txt",
        "config/datasets/PolyGlotFake/en2ja_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/en2ja_video_retalking.txt",
        "config/datasets/PolyGlotFake/en2ru_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/en2ru_video_retalking.txt",
        "config/datasets/PolyGlotFake/en2zh_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/en2zh_video_retalking.txt",
        "config/datasets/PolyGlotFake/es2ar_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/es2ar_video_retalking.txt",
        "config/datasets/PolyGlotFake/es2en_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/es2en_video_retalking.txt",
        "config/datasets/PolyGlotFake/es2fr_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/es2fr_video_retalking.txt",
        "config/datasets/PolyGlotFake/es2ja_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/es2ja_video_retalking.txt",
        "config/datasets/PolyGlotFake/es2ru_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/es2ru_video_retalking.txt",
        "config/datasets/PolyGlotFake/es2zh_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/es2zh_video_retalking.txt",
        "config/datasets/PolyGlotFake/fr2ar_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/fr2ar_video_retalking.txt",
        "config/datasets/PolyGlotFake/fr2en_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/fr2en_video_retalking.txt",
        "config/datasets/PolyGlotFake/fr2es_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/fr2es_video_retalking.txt",
        "config/datasets/PolyGlotFake/fr2ja_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/fr2ja_video_retalking.txt",
        "config/datasets/PolyGlotFake/fr2ru_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/fr2ru_video_retalking.txt",
        "config/datasets/PolyGlotFake/fr2zh_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/fr2zh_video_retalking.txt",
        "config/datasets/PolyGlotFake/ja2ar_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ja2ar_video_retalking.txt",
        "config/datasets/PolyGlotFake/ja2en_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ja2en_video_retalking.txt",
        "config/datasets/PolyGlotFake/ja2es_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ja2es_video_retalking.txt",
        "config/datasets/PolyGlotFake/ja2fr_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ja2fr_video_retalking.txt",
        "config/datasets/PolyGlotFake/ja2ru_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ja2ru_video_retalking.txt",
        "config/datasets/PolyGlotFake/ja2zh_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ja2zh_video_retalking.txt",
        "config/datasets/PolyGlotFake/ru2ar_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ru2ar_video_retalking.txt",
        "config/datasets/PolyGlotFake/ru2en_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ru2en_video_retalking.txt",
        "config/datasets/PolyGlotFake/ru2es_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ru2es_video_retalking.txt",
        "config/datasets/PolyGlotFake/ru2fr_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ru2fr_video_retalking.txt",
        "config/datasets/PolyGlotFake/ru2ja_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ru2zh_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/ru2zh_video_retalking.txt",
        "config/datasets/PolyGlotFake/zh2ar_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/zh2ar_video_retalking.txt",
        "config/datasets/PolyGlotFake/zh2en_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/zh2en_video_retalking.txt",
        "config/datasets/PolyGlotFake/zh2es_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/zh2es_video_retalking.txt",
        "config/datasets/PolyGlotFake/zh2fr_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/zh2fr_video_retalking.txt",
        "config/datasets/PolyGlotFake/zh2ja_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/zh2ru_Wav2Lip.txt",
        "config/datasets/PolyGlotFake/zh2ru_video_retalking.txt",
    )
"config/datasets/CDFv2/test/YouTube-real.txt", + "config/datasets/CDFv2/test/Celeb-real.txt", + ) + + class InSwapper: + test = Files( + "config/datasets/DF40/test/test_cdf_inswap.txt", + "config/datasets/CDFv2/test/YouTube-real.txt", + "config/datasets/CDFv2/test/Celeb-real.txt", + ) + + class Uniface: + test = Files( + "config/datasets/DF40/test/test_cdf_uniface.txt", + "config/datasets/CDFv2/test/YouTube-real.txt", + "config/datasets/CDFv2/test/Celeb-real.txt", + ) + + +class FFv2: + """ + FaceFusion v2 dataset created by VRG group + """ + + class FF: + train = Files( + "config/datasets/FF/train/real.txt", + "config/datasets/FFv2/train/FF_blendswap_256.txt", + "config/datasets/FFv2/train/FF_ghost_1_256.txt", + # "config/datasets/FFv2/train/FF_ghost_2_256.txt", + # "config/datasets/FFv2/train/FF_ghost_3_256.txt", + "config/datasets/FFv2/train/FF_hififace_unofficial_256.txt", + "config/datasets/FFv2/train/FF_hyperswap_1a_256.txt", + # "config/datasets/FFv2/train/FF_hyperswap_1c_256.txt", + "config/datasets/FFv2/train/FF_inswapper_128_fp16.txt", + # "config/datasets/FFv2/train/FF_inswapper_128.txt", + "config/datasets/FFv2/train/FF_simswap_256.txt", + # "config/datasets/FFv2/train/FF_simswap_unofficial_512.txt", + "config/datasets/FFv2/train/FF_uniface_256.txt", + ) + + class SS: + train = Files( + "config/datasets/FF/train/real.txt", + "config/datasets/FFv2/train/SS_blendswap_256.txt", + "config/datasets/FFv2/train/SS_ghost_1_256.txt", + # "config/datasets/FFv2/train/SS_ghost_2_256.txt", + # "config/datasets/FFv2/train/SS_ghost_3_256.txt", + "config/datasets/FFv2/train/SS_hififace_unofficial_256.txt", + "config/datasets/FFv2/train/SS_hyperswap_1a_256.txt", + # "config/datasets/FFv2/train/SS_hyperswap_1c_256.txt", + "config/datasets/FFv2/train/SS_inswapper_128_fp16.txt", + # "config/datasets/FFv2/train/SS_inswapper_128.txt", + "config/datasets/FFv2/train/SS_simswap_256.txt", + # "config/datasets/FFv2/train/SS_simswap_unofficial_512.txt", + 
"config/datasets/FFv2/train/SS_uniface_256.txt", + ) + + +if __name__ == "__main__": + import pandas as pd + + def get_video(file_path: str) -> str: + return file_path.split("/")[-2] + + val_files = [ + # *CDFv3.test.map(CDFv3.to_train).map(CDFv3.to_x1_5), + # *FF.train.map(FF.to_x1_5), + # *FF.train.map(FF.to_rmbg_x1_5), + # *CDFv3.train.map(CDFv3.to_x1_3_th0_5_all), + # *DeepSpeak_v1_1.train + # *DeepSpeak_v2.train.cat(DeepSpeak_v1_1.train) + *FF.train.map(lambda x: x.replace("/FF/", "/FF-x1.3-th0.5-all/subset/1st-frame/")), + *DeepSpeak_v1_1.train.map(lambda x: x.replace("/DeepSpeak-1.1/", "/DeepSpeak-1.1/subset/1st-frame/")), + *DeepSpeak_v2.train.map(lambda x: x.replace("/DeepSpeak-2.0/", "/DeepSpeak-2.0/subset/1st-frame/")), + *FFIW.train.map(lambda x: x.replace("/FFIW/", "/FFIW/subset/1st-frame/")), + ] + + total_videos = 0 + for file in val_files: + # read with pandas + df = pd.read_csv(file, names=["files"]) + + df["video"] = df["files"].apply(lambda x: get_video(x)) + + # unique values + unique_videos = df["video"].unique() + + print(f"Unique videos in {file} : {len(unique_videos)}") + + total_videos += len(unique_videos) + + print(f"Total unique videos: {total_videos}") diff --git a/src/utils/logger.py b/src/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..89973b049a778a223ad3fec0f77c15f785db5f4f --- /dev/null +++ b/src/utils/logger.py @@ -0,0 +1,39 @@ +from rich import print as rprint + +from .constants import IS_GLOBAL_ZERO + +__all__ = ["print_error", "print_info", "print_warning", "print", "print_warning_once"] + +printed_warnings = set() + + +def print_error(text="", only_zero_rank=False): + if only_zero_rank and not IS_GLOBAL_ZERO: + return + rprint(f"[red bold]ERROR: [/red bold]{text}") + + +def print_warning(text="", only_zero_rank=False): + if only_zero_rank and not IS_GLOBAL_ZERO: + return + rprint(f"[yellow bold]WARNING: [/yellow bold]{text}") + + +def print_warning_once(text="", only_zero_rank=False): + 
def print_info(text="", only_zero_rank=True):
    """Print ``text`` with a blue, bold ``INFO:`` prefix.

    Args:
        text: message to print.
        only_zero_rank: when true, the message is printed only if
            ``IS_GLOBAL_ZERO`` is true; when false it is always printed.
    """
    should_emit = IS_GLOBAL_ZERO or not only_zero_rank
    if should_emit:
        rprint(f"[blue bold]INFO: [/blue bold]{text}")
@contextmanager
def silenced_output():
    """
    A context manager to suppress all stdout and stderr,
    including from C-level libraries.

    The file descriptors behind ``sys.stdout`` and ``sys.stderr`` are pointed
    at the null device for the duration of the ``with`` block and restored
    afterwards, also when the block raises. Text still sitting in Python's
    stream buffers when the block ends is flushed into the null device, so it
    is discarded as well.

    Raises:
        OSError / ValueError: if ``sys.stdout`` or ``sys.stderr`` has no usable
            ``fileno()`` (e.g. it was replaced by an in-memory stream), or if
            duplicating a descriptor fails. No descriptor is leaked in that case.
    """
    # Resolve the target descriptors before opening anything: fileno() can
    # raise for replaced streams, and at this point there is nothing to clean up.
    stdout_fd = sys.stdout.fileno()
    stderr_fd = sys.stderr.fileno()

    # Every descriptor this function opens is recorded here so the outer
    # ``finally`` closes it no matter where a failure happens.
    owned_fds = []
    try:
        devnull_fd = os.open(os.devnull, os.O_RDWR)
        owned_fds.append(devnull_fd)

        # Duplicate the original descriptors so they can be restored later
        saved_stdout_fd = os.dup(stdout_fd)
        owned_fds.append(saved_stdout_fd)
        saved_stderr_fd = os.dup(stderr_fd)
        owned_fds.append(saved_stderr_fd)

        try:
            # Push out anything buffered *before* the block so it is still
            # visible, then redirect both descriptors to the null device.
            sys.stdout.flush()
            sys.stderr.flush()
            os.dup2(devnull_fd, stdout_fd)
            os.dup2(devnull_fd, stderr_fd)

            # Yield control back to the 'with' block
            yield
        finally:
            try:
                # Flush while still redirected: output buffered inside the
                # block must end up in the null device, not on the terminal.
                sys.stdout.flush()
                sys.stderr.flush()
            finally:
                # Restore even if flushing failed, otherwise the process
                # would stay silenced for the rest of its life.
                try:
                    os.dup2(saved_stdout_fd, stdout_fd)
                finally:
                    os.dup2(saved_stderr_fd, stderr_fd)
    finally:
        # Close the temporary FDs
        for fd in owned_fds:
            os.close(fd)
@TryExcept()
def create_custom_wandb_metric(
    xs: list,
    ys: list,
    classes: list,
    title: str = "Precision Recall Curve",
    x_axis_title: str = "Recall",
    y_axis_title: str = "Precision",
):
    """Creates a custom wandb metric similar to default wandb.plot.pr_curve

    Args:
        xs: list of N values to plot on the x-axis
        ys: list of N values to plot on the y-axis
        classes: class labels for each point (list of N values)
        title: plot title
        x_axis_title: label shown on the x-axis
        y_axis_title: label shown on the y-axis

    Returns:
        wandb object to log
    """
    # All values are rounded to 3 decimals before they are put into the table.
    df = pd.DataFrame(
        {
            "class": classes,
            "y": ys,
            "x": xs,
        }
    ).round(3)

    return wandb.plot_table(
        "wandb/area-under-curve/v0",
        wandb.Table(dataframe=df),
        {"x": "x", "y": "y", "class": "class"},
        {
            "title": title,
            "x-axis-title": x_axis_title,
            "y-axis-title": y_axis_title,
        },
    )


@TryExcept()
def plot_curve_wandb(
    xs: np.ndarray,
    ys: np.ndarray,
    names: "list | None" = None,
    id: str = "precision-recall",
    title: str = "Precision Recall Curve",
    x_axis_title: str = "Recall",
    y_axis_title: str = "Precision",
    num_xs: int = 100,
    only_mean: bool = True,
):
    """adds a metric curve to wandb

    The curve is resampled to ``num_xs`` evenly spaced x values between
    ``xs[0]`` and ``xs[-1]`` and logged with ``commit=False``.

    Args:
        xs: np.array of N values
        ys: np.array of C by N values where C is the number of classes
        names: class names, looked up as ``names[i]`` for the i-th row of ``ys``;
            ``None`` (the default) means "no names". Per-class curves are only
            added when ``len(names) == len(ys)``.
        id: log id in wandb (the parameter name shadows the builtin ``id``; it is
            kept because callers may pass it by keyword)
        title: plot title in wandb
        x_axis_title: label shown on the x-axis
        y_axis_title: label shown on the y-axis
        num_xs: number of points to interpolate to
        only_mean: if True, only the mean curve is plotted
    """
    # None instead of a shared mutable ``[]`` default
    if names is None:
        names = []

    # create new xs
    xs_new = np.linspace(xs[0], xs[-1], num_xs)
    xs_new_list = xs_new.tolist()

    # create arrays for logging; the first series is the mean over all classes
    xs_log = list(xs_new_list)
    ys_log = np.interp(xs_new, xs, np.mean(ys, axis=0)).tolist()
    classes = ["mean"] * len(xs_log)

    if not only_mean and len(names) == len(ys):
        for i, y in enumerate(ys):
            # add new xs (plain Python floats, same as for the mean curve)
            xs_log.extend(xs_new_list)
            # interpolate y to new xs
            ys_log.extend(np.interp(xs_new, xs, y).tolist())
            # add class names; indexing (not iteration) keeps index-keyed
            # mappings working as well as lists
            classes.extend([names[i]] * len(xs_new_list))

    wandb.log(
        {
            id: create_custom_wandb_metric(
                xs_log,
                ys_log,
                classes,
                title,
                x_axis_title,
                y_axis_title,
            )
        },
        commit=False,
    )