Spaces:
Configuration error
Configuration error
Commit ·
c29babb
0
Parent(s):
init
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +3 -0
- .gitignore +14 -0
- .project-root +1 -0
- LICENSE +21 -0
- README.md +157 -0
- config/datasets/CDFv2/test/Celeb-real.txt +4 -0
- config/datasets/CDFv2/test/Celeb-synthesis.txt +4 -0
- config/datasets/CDFv2/test/YouTube-real.txt +4 -0
- config/datasets/FF/test/DF.txt +4 -0
- config/datasets/FF/test/F2F.txt +4 -0
- config/datasets/FF/test/FS.txt +4 -0
- config/datasets/FF/test/NT.txt +4 -0
- config/datasets/FF/test/real.txt +4 -0
- datasets/CDFv2/Celeb-real/id0_0000/000.png +3 -0
- datasets/CDFv2/Celeb-real/id0_0000/015.png +3 -0
- datasets/CDFv2/Celeb-real/id0_0000/030.png +3 -0
- datasets/CDFv2/Celeb-real/id0_0000/045.png +3 -0
- datasets/CDFv2/Celeb-synthesis/id0_id1_0000/000.png +3 -0
- datasets/CDFv2/Celeb-synthesis/id0_id1_0000/015.png +3 -0
- datasets/CDFv2/Celeb-synthesis/id0_id1_0000/030.png +3 -0
- datasets/CDFv2/Celeb-synthesis/id0_id1_0000/045.png +3 -0
- datasets/CDFv2/YouTube-real/00000/000.png +3 -0
- datasets/CDFv2/YouTube-real/00000/014.png +3 -0
- datasets/CDFv2/YouTube-real/00000/028.png +3 -0
- datasets/CDFv2/YouTube-real/00000/043.png +3 -0
- datasets/FF/DF/000_003/000.png +3 -0
- datasets/FF/DF/000_003/012.png +3 -0
- datasets/FF/DF/000_003/025.png +3 -0
- datasets/FF/DF/000_003/038.png +3 -0
- datasets/FF/F2F/000_003/000.png +3 -0
- datasets/FF/F2F/000_003/009.png +3 -0
- datasets/FF/F2F/000_003/019.png +3 -0
- datasets/FF/F2F/000_003/029.png +3 -0
- datasets/FF/FS/000_003/000.png +3 -0
- datasets/FF/FS/000_003/009.png +3 -0
- datasets/FF/FS/000_003/019.png +3 -0
- datasets/FF/FS/000_003/029.png +3 -0
- datasets/FF/NT/000_003/000.png +3 -0
- datasets/FF/NT/000_003/009.png +3 -0
- datasets/FF/NT/000_003/019.png +3 -0
- datasets/FF/NT/000_003/029.png +3 -0
- datasets/FF/real/000/000.png +3 -0
- datasets/FF/real/000/012.png +3 -0
- datasets/FF/real/000/025.png +3 -0
- datasets/FF/real/000/038.png +3 -0
- detector.py +701 -0
- pyproject.toml +37 -0
- requirements.txt +34 -0
- run.py +174 -0
- run_exp.py +209 -0
.gitattributes
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
|
| 3 |
+
/.vscode
|
| 4 |
+
/config
|
| 5 |
+
/datasets
|
| 6 |
+
/outputs
|
| 7 |
+
/runs
|
| 8 |
+
/weights
|
| 9 |
+
/logs
|
| 10 |
+
/tmp
|
| 11 |
+
|
| 12 |
+
x.py
|
| 13 |
+
y.py
|
| 14 |
+
z.py
|
.project-root
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Do not remove, this file is used by the project to determine the root of the project
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Andy
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Deepfake Detection that Generalizes Across Benchmarks (WACV 2026)
|
| 2 |
+
|
| 3 |
+
[arXiv:2508.06248](https://arxiv.org/abs/2508.06248)
|
| 4 |
+
[Hugging Face model collection](https://huggingface.co/collections/yermandy/gend)
|
| 5 |
+
|
| 6 |
+
This is the official repository for the paper:
|
| 7 |
+
|
| 8 |
+
**[Deepfake Detection that Generalizes Across Benchmarks](https://arxiv.org/abs/2508.06248)**.
|
| 9 |
+
|
| 10 |
+
### Abstract
|
| 11 |
+
|
| 12 |
+
> The generalization of deepfake detectors to unseen manipulation techniques remains a challenge for practical deployment. Although many approaches adapt foundation models by introducing significant architectural complexity, this work demonstrates that robust generalization is achievable through a parameter-efficient adaptation of one of the foundational pre-trained vision encoders. The proposed method, GenD, fine-tunes only the Layer Normalization parameters (0.03% of the total) and enhances generalization by enforcing a hyperspherical feature manifold using L2 normalization and metric learning on it.
|
| 13 |
+
>
|
| 14 |
+
> We conducted an extensive evaluation on 14 benchmark datasets spanning from 2019 to 2025. The proposed method achieves state-of-the-art performance, outperforming more complex, recent approaches in average cross-dataset AUROC. Our analysis yields two primary findings for the field: 1) training on paired real-fake data from the same source video is essential for mitigating shortcut learning and improving generalization, and 2) detection difficulty on academic datasets has not strictly increased over time, with models trained on older, diverse datasets showing strong generalization capabilities.
|
| 15 |
+
>
|
| 16 |
+
> This work delivers a computationally efficient and reproducible method, proving that state-of-the-art generalization is attainable by making targeted, minimal changes to a pre-trained foundational image encoder model.
|
| 17 |
+
|
| 18 |
+
## Inference using Hugging Face transformers
|
| 19 |
+
|
| 20 |
+
This example shows how to run inference with the pretrained GenD model from Hugging Face with no dependencies other than `torch` and `transformers`. It expects the input images to have already been preprocessed by the face detector (see `detector.py`).
|
| 21 |
+
|
| 22 |
+
### Minimal dependencies
|
| 23 |
+
|
| 24 |
+
``` bash
|
| 25 |
+
conda create --name GenD python=3.12 uv -y
|
| 26 |
+
conda activate GenD
|
| 27 |
+
uv pip install torch==2.8.0
|
| 28 |
+
uv pip install torchvision==0.23.0
|
| 29 |
+
uv pip install transformers==4.56.2
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
### Inference with transformers
|
| 33 |
+
|
| 34 |
+
``` python
|
| 35 |
+
import requests
|
| 36 |
+
import torch
|
| 37 |
+
from PIL import Image
|
| 38 |
+
|
| 39 |
+
from src.hf.modeling_gend import GenD
|
| 40 |
+
|
| 41 |
+
# Other models can be found in https://huggingface.co/collections/yermandy/gend:
|
| 42 |
+
# - yermandy/GenD_CLIP_L_14
|
| 43 |
+
# - yermandy/GenD_PE_L
|
| 44 |
+
# - yermandy/GenD_DINOv3_L
|
| 45 |
+
model = GenD.from_pretrained("yermandy/GenD_CLIP_L_14")
|
| 46 |
+
|
| 47 |
+
urls = [
|
| 48 |
+
"https://github.com/yermandy/deepfake-detection/blob/main/datasets/FF/DF/000_003/000.png?raw=true",
|
| 49 |
+
"https://github.com/yermandy/deepfake-detection/blob/main/datasets/FF/real/000/000.png?raw=true",
|
| 50 |
+
]
|
| 51 |
+
images = [Image.open(requests.get(url, stream=True).raw) for url in urls]
|
| 52 |
+
tensors = torch.stack([model.feature_extractor.preprocess(img) for img in images])
|
| 53 |
+
logits = model(tensors)
|
| 54 |
+
probs = logits.softmax(dim=-1)
|
| 55 |
+
|
| 56 |
+
print(probs)
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
## Training
|
| 60 |
+
|
| 61 |
+
### Set up environment
|
| 62 |
+
|
| 63 |
+
``` bash
|
| 64 |
+
conda create --name GenD python=3.12 uv -y
|
| 65 |
+
conda activate GenD
|
| 66 |
+
uv pip install -r requirements.txt
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### Minimal example without external data
|
| 70 |
+
|
| 71 |
+
#### Training example
|
| 72 |
+
|
| 73 |
+
Examine `src/exp/examples.py`: each experiment name is defined as a key, and its value overrides the default configuration of the `Config` object from `src/config.py`. For example, try running the `example-training` experiment:
|
| 74 |
+
|
| 75 |
+
``` bash
|
| 76 |
+
python run_exp.py example-training
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
#### Test example after the model is trained
|
| 80 |
+
|
| 81 |
+
``` bash
|
| 82 |
+
python run_exp.py example-test --from_exp example-training --test
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
Alternatively, you can try inference using one of our released models from Hugging Face:
|
| 86 |
+
|
| 87 |
+
``` bash
|
| 88 |
+
python run_exp.py GenD_CLIP--CDFv2-example --test
|
| 89 |
+
python run_exp.py GenD_PE--CDFv2-example --test
|
| 90 |
+
python run_exp.py GenD_DINO--CDFv2-example --test
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
### Full training
|
| 94 |
+
|
| 95 |
+
To fully train the model, you need to download datasets, preprocess them, and create files with paths to the images.
|
| 96 |
+
|
| 97 |
+
The training entry will be similar to the minimal example above.
|
| 98 |
+
|
| 99 |
+
All experiments (configs) from the paper are stored in the `src/exp` folder.
|
| 100 |
+
|
| 101 |
+
#### Prepare the dataset
|
| 102 |
+
|
| 103 |
+
Taking the [FaceForensics++](https://github.com/ondyari/FaceForensics) dataset as an example, follow these steps:
|
| 104 |
+
|
| 105 |
+
1. Download the dataset first from the [official source](https://github.com/ondyari/FaceForensics). The root of this dataset is `./FaceForensics`
|
| 106 |
+
|
| 107 |
+
2. Preprocess the dataset using the `detector.py` script:
|
| 108 |
+
|
| 109 |
+
``` bash
|
| 110 |
+
python detector.py -i FaceForensics/manipulated_sequences/Deepfakes/c23/videos/ --mask_folder FaceForensics/masks/manipulated_sequences/Deepfakes/masks/videos/ -m at_least -n 32 -o datasets/FF/DF/ --det_thres 0.1 -s 1.3 --target_size none
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
Repeat the process for other manipulation methods and real videos. After processing everything, you will get a similar structure:
|
| 114 |
+
|
| 115 |
+
``` bash
|
| 116 |
+
datasets
|
| 117 |
+
└── FF
|
| 118 |
+
├── DF
|
| 119 |
+
│ └── 000_003
|
| 120 |
+
│ ├── 025.png
|
| 121 |
+
│ └── 038.png
|
| 122 |
+
├── F2F
|
| 123 |
+
│ └── 000_003
|
| 124 |
+
│ ├── 019.png
|
| 125 |
+
│ └── 029.png
|
| 126 |
+
├── FS
|
| 127 |
+
│ └── 000_003
|
| 128 |
+
│ ├── 019.png
|
| 129 |
+
│ └── 029.png
|
| 130 |
+
├── NT
|
| 131 |
+
│ └── 000_003
|
| 132 |
+
│ ├── 019.png
|
| 133 |
+
│ └── 029.png
|
| 134 |
+
└── real
|
| 135 |
+
└── 000
|
| 136 |
+
├── 025.png
|
| 137 |
+
└── 038.png
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
3. Create files with paths to the images, similar to the ones in the `config/datasets` directory. This can be done using:
|
| 141 |
+
|
| 142 |
+
``` bash
|
| 143 |
+
find datasets/FF/DF/* -type f | sort > config/datasets/FF/DF.txt
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
We manage links to files using `src/utils/files.py`.
|
| 147 |
+
|
| 148 |
+
### Cite
|
| 149 |
+
|
| 150 |
+
``` bibtex
|
| 151 |
+
@article{yermakov2025deepfake,
|
| 152 |
+
title={Deepfake Detection that Generalizes Across Benchmarks},
|
| 153 |
+
author={Yermakov, Andrii and Cech, Jan and Matas, Jiri and Fritz, Mario},
|
| 154 |
+
journal={arXiv preprint arXiv:2508.06248},
|
| 155 |
+
year={2025}
|
| 156 |
+
}
|
| 157 |
+
```
|
config/datasets/CDFv2/test/Celeb-real.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets/CDFv2/Celeb-real/id0_0000/045.png
|
| 2 |
+
datasets/CDFv2/Celeb-real/id0_0000/030.png
|
| 3 |
+
datasets/CDFv2/Celeb-real/id0_0000/015.png
|
| 4 |
+
datasets/CDFv2/Celeb-real/id0_0000/000.png
|
config/datasets/CDFv2/test/Celeb-synthesis.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets/CDFv2/Celeb-synthesis/id0_id1_0000/000.png
|
| 2 |
+
datasets/CDFv2/Celeb-synthesis/id0_id1_0000/045.png
|
| 3 |
+
datasets/CDFv2/Celeb-synthesis/id0_id1_0000/030.png
|
| 4 |
+
datasets/CDFv2/Celeb-synthesis/id0_id1_0000/015.png
|
config/datasets/CDFv2/test/YouTube-real.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets/CDFv2/YouTube-real/00000/000.png
|
| 2 |
+
datasets/CDFv2/YouTube-real/00000/014.png
|
| 3 |
+
datasets/CDFv2/YouTube-real/00000/028.png
|
| 4 |
+
datasets/CDFv2/YouTube-real/00000/043.png
|
config/datasets/FF/test/DF.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets/FF/DF/000_003/000.png
|
| 2 |
+
datasets/FF/DF/000_003/012.png
|
| 3 |
+
datasets/FF/DF/000_003/025.png
|
| 4 |
+
datasets/FF/DF/000_003/038.png
|
config/datasets/FF/test/F2F.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets/FF/F2F/000_003/000.png
|
| 2 |
+
datasets/FF/F2F/000_003/009.png
|
| 3 |
+
datasets/FF/F2F/000_003/019.png
|
| 4 |
+
datasets/FF/F2F/000_003/029.png
|
config/datasets/FF/test/FS.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets/FF/FS/000_003/000.png
|
| 2 |
+
datasets/FF/FS/000_003/009.png
|
| 3 |
+
datasets/FF/FS/000_003/019.png
|
| 4 |
+
datasets/FF/FS/000_003/029.png
|
config/datasets/FF/test/NT.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets/FF/NT/000_003/000.png
|
| 2 |
+
datasets/FF/NT/000_003/009.png
|
| 3 |
+
datasets/FF/NT/000_003/019.png
|
| 4 |
+
datasets/FF/NT/000_003/029.png
|
config/datasets/FF/test/real.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets/FF/real/000/000.png
|
| 2 |
+
datasets/FF/real/000/012.png
|
| 3 |
+
datasets/FF/real/000/025.png
|
| 4 |
+
datasets/FF/real/000/038.png
|
datasets/CDFv2/Celeb-real/id0_0000/000.png
ADDED
|
Git LFS Details
|
datasets/CDFv2/Celeb-real/id0_0000/015.png
ADDED
|
Git LFS Details
|
datasets/CDFv2/Celeb-real/id0_0000/030.png
ADDED
|
Git LFS Details
|
datasets/CDFv2/Celeb-real/id0_0000/045.png
ADDED
|
Git LFS Details
|
datasets/CDFv2/Celeb-synthesis/id0_id1_0000/000.png
ADDED
|
Git LFS Details
|
datasets/CDFv2/Celeb-synthesis/id0_id1_0000/015.png
ADDED
|
Git LFS Details
|
datasets/CDFv2/Celeb-synthesis/id0_id1_0000/030.png
ADDED
|
Git LFS Details
|
datasets/CDFv2/Celeb-synthesis/id0_id1_0000/045.png
ADDED
|
Git LFS Details
|
datasets/CDFv2/YouTube-real/00000/000.png
ADDED
|
Git LFS Details
|
datasets/CDFv2/YouTube-real/00000/014.png
ADDED
|
Git LFS Details
|
datasets/CDFv2/YouTube-real/00000/028.png
ADDED
|
Git LFS Details
|
datasets/CDFv2/YouTube-real/00000/043.png
ADDED
|
Git LFS Details
|
datasets/FF/DF/000_003/000.png
ADDED
|
Git LFS Details
|
datasets/FF/DF/000_003/012.png
ADDED
|
Git LFS Details
|
datasets/FF/DF/000_003/025.png
ADDED
|
Git LFS Details
|
datasets/FF/DF/000_003/038.png
ADDED
|
Git LFS Details
|
datasets/FF/F2F/000_003/000.png
ADDED
|
Git LFS Details
|
datasets/FF/F2F/000_003/009.png
ADDED
|
Git LFS Details
|
datasets/FF/F2F/000_003/019.png
ADDED
|
Git LFS Details
|
datasets/FF/F2F/000_003/029.png
ADDED
|
Git LFS Details
|
datasets/FF/FS/000_003/000.png
ADDED
|
Git LFS Details
|
datasets/FF/FS/000_003/009.png
ADDED
|
Git LFS Details
|
datasets/FF/FS/000_003/019.png
ADDED
|
Git LFS Details
|
datasets/FF/FS/000_003/029.png
ADDED
|
Git LFS Details
|
datasets/FF/NT/000_003/000.png
ADDED
|
Git LFS Details
|
datasets/FF/NT/000_003/009.png
ADDED
|
Git LFS Details
|
datasets/FF/NT/000_003/019.png
ADDED
|
Git LFS Details
|
datasets/FF/NT/000_003/029.png
ADDED
|
Git LFS Details
|
datasets/FF/real/000/000.png
ADDED
|
Git LFS Details
|
datasets/FF/real/000/012.png
ADDED
|
Git LFS Details
|
datasets/FF/real/000/025.png
ADDED
|
Git LFS Details
|
datasets/FF/real/000/038.png
ADDED
|
Git LFS Details
|
detector.py
ADDED
|
@@ -0,0 +1,701 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import heapq
|
| 3 |
+
import os
|
| 4 |
+
import subprocess
|
| 5 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 6 |
+
from glob import glob
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from src.retinaface import RetinaFace, prepare_model
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def max_spread_permutation_pq(N, start=0):
    """
    Order the indices 0..N-1 so that they spread out as early as possible.

    The sequence begins with `start`; every following index is the one whose
    distance to its nearest already-emitted index is currently the largest.
    Ties go to the smaller index. A heap with lazily discarded stale entries
    is used to find the best remaining candidate.

    Args:
        N (int): Number of indices to permute.
        start (int): Index emitted first (default 0).

    Returns:
        List[int]: The indices 0..N-1 in farthest-first order.

    Raises:
        ValueError: If `start` is not in [0, N-1] (which includes N == 0).
    """
    if not (0 <= start < N):
        raise ValueError("`start` must be in the range [0, N-1]")

    order = [start]

    # Distance from every not-yet-emitted index to its nearest emitted index.
    nearest = {}
    for idx in range(N):
        if idx != start:
            nearest[idx] = abs(idx - start)

    # heapq is a min-heap, so distances are stored negated to pop the largest first.
    queue = [(-gap, idx) for idx, gap in nearest.items()]
    heapq.heapify(queue)

    while queue:
        # Drop entries whose recorded distance is no longer the current one.
        neg_gap, pick = heapq.heappop(queue)
        while nearest.get(pick, -1) != -neg_gap:
            neg_gap, pick = heapq.heappop(queue)

        order.append(pick)
        nearest.pop(pick)

        # The new pick can only shrink the distances of the remaining indices.
        for idx in tuple(nearest):
            gap = abs(idx - pick)
            if gap < nearest[idx]:
                nearest[idx] = gap
                heapq.heappush(queue, (-gap, idx))

    return order
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_video_frames_generator(
    source_path: str,
    mask_path: str | None,
    stride: int = 1,
    num_frames=32,
    mode="at_least",
):
    """
    Lazily yield selected frames of a video, optionally paired with mask frames.

    Args:
        source_path: Path of the video to read.
        mask_path: Path of a mask video that is read at the same frame indices,
            or None when there is no mask.
        stride: Step between frame indices; only used by "fixed_stride".
        num_frames: Number of frame indices to sample; only used by "fixed_num_frames".
        mode: Frame selection strategy:
            "fixed_num_frames" - `num_frames` evenly spaced indices,
            "fixed_stride"     - every `stride`-th index,
            "at_least"         - all indices, ordered by `max_spread_permutation_pq`
                                 starting from the middle frame, so the consumer
                                 can stop as soon as it has collected enough frames.

    Yields:
        Tuples `(frame, frame_id, mask)`; `mask` is None when `mask_path` is None.
        Frames (or mask frames) that cannot be read are skipped with a warning.

    Raises:
        ValueError: If `mode` is not one of the supported strategies
            (raised once iteration starts, after the videos were opened).
    """
    video = cv2.VideoCapture(source_path)
    mask_video = None

    # try/finally releases both captures on every exit path, including early
    # returns and a consumer that stops iterating before the generator is exhausted.
    try:
        if not video.isOpened():
            print(f"Warning: Video {source_path} cannot be opened!")
            return

        video_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

        if mask_path is not None:
            mask_video = cv2.VideoCapture(mask_path)

            if not mask_video.isOpened():
                print(f"Warning: Mask video {mask_path} cannot be opened!")
                return

            mask_frames = int(mask_video.get(cv2.CAP_PROP_FRAME_COUNT))

            if video_frames != mask_frames:
                print(
                    f"Warning: {source_path} and {mask_path} have different number of frames {video_frames} vs {mask_frames}!"
                )

            # Only indices present in both videos can be paired
            total_frames = min(video_frames, mask_frames)
        else:
            total_frames = video_frames

        if mode not in ("fixed_num_frames", "fixed_stride", "at_least"):
            raise ValueError(f"Invalid mode: {mode}. Choose 'fixed_num_frames', 'fixed_stride', or 'at_least'.")

        # Without frames there is nothing to sample; "at_least" would otherwise
        # fail inside max_spread_permutation_pq and "fixed_num_frames" would
        # produce negative indices.
        if total_frames <= 0:
            print(f"Warning: Video {source_path} has no frames!")
            return

        if mode == "fixed_num_frames":
            # Evenly spaced indices covering the first to the last frame
            frame_ids = np.linspace(0, total_frames - 1, num_frames, endpoint=True, dtype=int)
        elif mode == "fixed_stride":
            # Every `stride`-th index
            frame_ids = np.arange(0, total_frames, stride, dtype=int)
        else:
            # All indices, most spread-out first
            frame_ids = max_spread_permutation_pq(total_frames, start=total_frames // 2)

        for frame_id in frame_ids:
            # Seek to the requested frame before reading it
            video.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
            success, frame = video.read()

            # Checked regardless of whether a mask is used, so `frame` is never yielded as None
            if not success:
                print(f"Warning: Failed to read frame {frame_id} of {source_path}. Skipping.")
                continue

            mask = None
            if mask_video is not None:
                mask_video.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
                success_mask, mask = mask_video.read()
                if not success_mask:
                    print(f"Warning: Failed to read mask frame {frame_id} of {mask_path}. Skipping.")
                    continue

            yield frame, frame_id, mask
    finally:
        video.release()

        if mask_video is not None:
            mask_video.release()
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def align_face(
    img: np.ndarray,
    landmarks: np.ndarray,
    target_size: None | tuple = None,
    scale: float = 1.3,
    mask: np.ndarray | None = None,
):
    """
    Aligns a face based on 5-point facial landmarks (eyes, nose, mouth corners).

    The landmarks are mapped onto a fixed template (given in relative
    coordinates) with a partial affine transform, and the image is warped
    with that transform.

    Args:
        img: Input image containing the face
        landmarks: 5-point facial landmarks array with shape (5, 2)
        target_size: Desired output size as (width, height) tuple. When None, a
            square size is estimated from the landmark distances and `scale`
        scale: Scaling factor to control how much context around the face to include
        mask: Optional mask, warped with the same transform as img
            (nearest-neighbour interpolation)

    Returns:
        Tuple `(aligned_img, aligned_mask)`; `aligned_mask` is None when no mask was given
    """
    # Template landmark positions relative to a unit square
    dst = np.array(
        [
            [0.34, 0.46],
            [0.66, 0.46],
            [0.5, 0.64],
            [0.37, 0.82],
            [0.63, 0.82],
        ],
        dtype=np.float32,
    )

    if target_size is None:
        # Pairwise distances between the detected landmarks
        desired_dists = np.linalg.norm(landmarks[:, None, :] - landmarks[None, :, :], axis=-1)

        # Pairwise distances between the template landmarks
        dst_dists = np.linalg.norm(dst[:, None, :] - dst[None, :, :], axis=-1)

        # Each pair is needed only once: take the upper triangle without the diagonal
        upper_triangle_indices = np.triu_indices(len(dst), k=1)
        dst_dists = dst_dists[upper_triangle_indices]
        desired_dists = desired_dists[upper_triangle_indices]

        # Mean ratio between detected and template distances approximates the face
        # size in pixels; converted to a plain int for use as the OpenCV output size
        approx_size = int(np.round(np.mean(desired_dists / dst_dists) * scale))
        target_size = (approx_size, approx_size)

    # Template in pixel coordinates of the output image
    dst[:, 0] = dst[:, 0] * target_size[0]
    dst[:, 1] = dst[:, 1] * target_size[1]

    margin_rate = scale - 1
    x_margin = target_size[0] * margin_rate / 2.0
    y_margin = target_size[1] * margin_rate / 2.0

    # Shift the template by the margin ...
    dst[:, 0] += x_margin
    dst[:, 1] += y_margin

    # ... and shrink it back so that template plus margins fit into target_size
    dst[:, 0] *= target_size[0] / (target_size[0] + 2 * x_margin)
    dst[:, 1] *= target_size[1] / (target_size[1] + 2 * y_margin)

    src = landmarks.astype(np.float32)

    # NOTE(review): per the OpenCV documentation the estimated matrix is empty
    # (None in Python) when no transform can be found; warpAffine then raises.
    # Confirm whether callers rely on that error before handling it here.
    M = cv2.estimateAffinePartial2D(src, dst, method=cv2.LMEDS)[0]

    img = cv2.warpAffine(img, M, target_size, flags=cv2.INTER_LINEAR)

    if mask is not None:
        # Nearest-neighbour keeps the original mask values (no interpolated values)
        mask = cv2.warpAffine(mask, M, target_size, flags=cv2.INTER_NEAREST)

    return img, mask
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def _landmarks_with_most_mask_overlap(xyxy, landmarks, mask):
    """Return the landmarks of the detection whose box covers the most mask pixels.

    Returns None when no box overlaps the binarised mask at all.
    """
    # Convert mask to grayscale if it's not already, then binarise it.
    mask_img = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) if len(mask.shape) == 3 else mask
    mask_img = cv2.threshold(mask_img, 1, 255, cv2.THRESH_BINARY)[1]
    height, width = mask_img.shape[:2]

    best_landmarks = None
    max_intersection = 0
    for box, face_landmarks in zip(xyxy, landmarks):
        x1, y1, x2, y2 = box[:4].astype(int)

        # Clamp to the image: a negative coordinate would otherwise be read as
        # "count from the end" by the slice below and select the wrong region.
        x1, x2 = min(max(x1, 0), width), min(max(x2, 0), width)
        y1, y2 = min(max(y1, 0), height), min(max(y2, 0), height)

        # Number of mask pixels inside the box (no per-face full-frame buffer needed).
        intersection = np.count_nonzero(mask_img[y1:y2, x1:x2])

        if intersection > max_intersection:
            max_intersection = intersection
            best_landmarks = face_landmarks

    return best_landmarks


def process_video(
    source_path,
    target_path,
    mask_path,
    model: RetinaFace,
    scale=1.3,
    target_size=(256, 256),
    stride=1,
    num_frames=32,
    mode="at_least",
    skip_processed_videos=False,
    skip_processed_frames=False,
):
    """Detect, align and save one face crop per sampled frame of a video.

    Frames come from `get_video_frames_generator`, which yields
    `(frame, frame_id, mask)` triples. For each frame the detector is run; when
    a mask is available the face overlapping it most is used, otherwise the
    face with the largest bounding box. The aligned crop is written to
    `<frame_save_path>/frame_<id>.png`.

    Args:
        source_path: Path of the input video.
        target_path: Output location; a ".mp4" substring is replaced by "/frames".
        mask_path: Path of the mask video, or None.
        model: Detector exposing `detect(frame) -> (xyxy, landmarks)`.
        scale: Forwarded to `align_face`.
        target_size: Forwarded to `align_face`.
        stride: Forwarded to the frame generator.
        num_frames: Frame quota; -1 disables the quota.
        mode: Sampling mode; the quota stops the loop only for "fixed_stride"
            and "at_least".
        skip_processed_videos: Return immediately if the frame folder exists.
        skip_processed_frames: Do not recompute frames whose file exists
            (they still count towards the quota).

    Returns:
        The directory the frames are written to. It is returned even when no
        frame could be saved, and also when the video is skipped as already
        processed.
    """
    frame_save_path = target_path.replace(".mp4", "/frames")

    if skip_processed_videos and os.path.exists(frame_save_path):
        print(f"Frames for {source_path} already processed.")
        return frame_save_path

    print(f"Processing {source_path}")

    # Create a frame generator from video path for iteration of frames
    frame_generator = get_video_frames_generator(
        source_path,
        mask_path,
        stride=stride,
        num_frames=num_frames,
        mode=mode,
    )

    def quota_reached(count):
        # Only these two modes stop early; -1 means "no limit".
        return mode in ("fixed_stride", "at_least") and num_frames != -1 and count >= num_frames

    num_saved = 0
    for frame, frame_id, mask in frame_generator:
        frame_filename = os.path.join(frame_save_path, f"frame_{frame_id:04d}.png")

        if skip_processed_frames and os.path.exists(frame_filename):
            print(f"Frame {frame_id} of {source_path} already processed.")
            num_saved += 1
            if quota_reached(num_saved):
                break
            continue

        try:
            preds = model.detect(frame)
        except Exception as e:
            # Best effort: a failing frame must not abort the whole video.
            print(f"Error during detection: {e}")
            continue

        xyxy, landmarks = preds

        if len(xyxy) == 0:
            print(f"No faces detected in frame {frame_id} of {source_path}")
            continue

        if mask is not None:
            # It is possible that the mask is empty -> skip this frame
            if mask.sum() == 0:
                print(f"Warning: Mask is empty for frame {frame_id} of {source_path}. Skipping.")
                continue

            selected_landmarks = _landmarks_with_most_mask_overlap(xyxy, landmarks, mask)
            if selected_landmarks is None:
                print(f"No suitable face found in frame {frame_id} of {source_path} with the provided mask.")
                continue
        else:
            # Select landmarks of the largest face if not using mask
            areas = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
            selected_landmarks = landmarks[np.argmax(areas)]

        # Align the face
        aligned_face, _ = align_face(frame, selected_landmarks, target_size=target_size, scale=scale)

        # Save the aligned face; a failed write must not count as a saved frame.
        os.makedirs(frame_save_path, exist_ok=True)
        if not cv2.imwrite(frame_filename, aligned_face):
            print(f"Failed to write {frame_filename}")
            continue

        num_saved += 1
        if quota_reached(num_saved):
            break

    if num_saved == 0:
        print(f"No faces were saved from {source_path}. Check the detection threshold or input video.")

    return frame_save_path
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def process_image(
    source_path,
    target_path,
    model: RetinaFace,
    scale=1.3,
    target_size=(256, 256),
    skip_processed_frames=False,
):
    """Detect the largest face in an image, align it and save the crop.

    Args:
        source_path: Path of the input image.
        target_path: Path the aligned face is written to.
        model: Detector exposing `detect(img) -> (xyxy, landmarks)`.
        scale: Forwarded to `align_face`.
        target_size: Forwarded to `align_face`.
        skip_processed_frames: Return early when `target_path` already exists.

    Returns:
        `target_path` on success (or when skipped as already processed);
        None when the image cannot be read, detection fails, no face is found,
        or the result cannot be written.
    """
    if skip_processed_frames and os.path.exists(target_path):
        print(f"Image {source_path} already processed.")
        return target_path

    print(f"Processing {source_path}")

    img = cv2.imread(source_path)
    if img is None:
        print(f"Failed to read image {source_path}")
        return None

    try:
        preds = model.detect(img)
    except Exception as e:
        print(f"Error during detection: {e}")
        return None

    xyxy, landmarks = preds

    if len(xyxy) == 0:
        print(f"No faces detected in {source_path}")
        return None

    # Select landmarks of the largest face
    areas = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
    landmarks = landmarks[np.argmax(areas)]

    # Align the face
    aligned_face, _ = align_face(img, landmarks, target_size=target_size, scale=scale)

    # Save the aligned face. A bare file name has an empty dirname, and
    # os.makedirs("") raises, so only create the directory when there is one.
    target_dir = os.path.dirname(target_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    if not cv2.imwrite(target_path, aligned_face):
        print(f"Failed to write image {target_path}")
        return None

    return target_path
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def get_output_path(source_path, input_folder, output_folder):
    """Map an input media path to its location under `output_folder`.

    The leading `input_folder` prefix of `source_path` is replaced by the last
    path component of `input_folder`, a trailing ".mp4" is dropped (each video
    gets its own directory), and the result is placed under `output_folder`.

    Example: ("data/videos/a/b.mp4", "data/videos", "out") -> "out/videos/a/b".

    Only the prefix and the suffix are rewritten; other occurrences of
    `input_folder` or ".mp4" inside the path are left untouched. A path that is
    still absolute after the rewrite (e.g. an entry of a txt file list) is made
    relative so that `os.path.join` cannot discard `output_folder`.
    """
    if source_path.startswith(input_folder):
        relative_path = os.path.basename(input_folder) + source_path[len(input_folder):]
    else:
        relative_path = source_path

    # Create directory for each video
    relative_path = relative_path.removesuffix(".mp4")

    # os.path.join drops everything before an absolute component, so strip the
    # drive and leading separators to keep the result inside output_folder.
    separators = os.sep + (os.altsep or "")
    relative_path = os.path.splitdrive(relative_path)[1].lstrip(separators)

    # Place it in the output folder
    return os.path.join(output_folder, relative_path)
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def get_mask_path(input_folder, input_mask_folder, source_path):
    """Derive the mask video path that corresponds to `source_path`.

    Returns None when no mask folder is configured. Otherwise `input_folder` is
    swapped for `input_mask_folder` inside the path; paths that then contain
    "FaceForensics" or "FF++" are returned as is (masks share the video name),
    all others get ".mp4" replaced by "_mask.mp4".
    """
    if input_mask_folder is None:
        return None

    # Change the input folder to the mask folder
    mask_path = source_path.replace(input_folder, input_mask_folder)

    #! FF++ has masks named the same way as original videos
    same_name_as_video = any(tag in mask_path for tag in ("FaceForensics", "FF++"))

    #! Else assume masks are named with _mask suffix
    return mask_path if same_name_as_video else mask_path.replace(".mp4", "_mask.mp4")
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def process_mixed_types(
    input_folder_or_file: str | list[str],
    input_mask_folder: None | str,
    model: RetinaFace,
    num_workers=1,
    scale=1.3,
    target_size=(256, 256),
    stride=1,
    num_frames=32,
    mode: str = "fixed_num_frames",
    output_folder: str = "outputs",
    possible_extensions: tuple[str] = ("mp4", "jpg", "png", "jpeg"),
    skip_processed_videos: bool = False,
    skip_processed_frames: bool = False,
):
    """Preprocess every video and image reachable from the given input.

    The input may be a single media file, a ".txt" file listing one path per
    line, or a folder that is searched recursively via `find_files`. Paths
    ending in ".mp4" go through `process_video`, everything else through
    `process_image`; jobs run on a thread pool of `num_workers` threads.

    NOTE(review): the annotation advertises `list[str]`, but the body passes
    the value to `os.path.isfile` and to string path helpers, so only a single
    path string is actually handled here.

    Args:
        input_folder_or_file: Media file, txt list, or folder.
        input_mask_folder: Folder with mask videos, or None.
        model: Face detector forwarded to the per-file functions.
        num_workers: Size of the thread pool.
        scale, target_size: Forwarded to the per-file functions.
        stride, num_frames, mode: Forwarded to `process_video`.
        output_folder: Root folder for the outputs.
        possible_extensions: Extensions (without dot) treated as media.
        skip_processed_videos, skip_processed_frames: Forwarded skip flags.

    Returns:
        None. Per-file errors are printed and do not abort the run.
    """
    # str.endswith only accepts a tuple of suffixes; tolerate lists.
    possible_extensions = tuple(possible_extensions)

    if os.path.isfile(input_folder_or_file):
        if input_folder_or_file.endswith(possible_extensions):
            # If input is a media file
            files = [input_folder_or_file]
        elif input_folder_or_file.endswith("txt"):
            # If input is a txt file: one path per line, blank lines ignored
            with open(input_folder_or_file, "r") as f:
                files = [line.strip() for line in f if line.strip()]
        else:
            # Previously `files` stayed unbound here and the next statement crashed.
            print(f"Unsupported input file type: {input_folder_or_file}")
            return
    else:
        # If input is a folder
        files = find_files(input_folder_or_file, possible_extensions)

    if not files:
        print(f"No files found in {input_folder_or_file}")
        return

    def process(source_path):
        output_path = get_output_path(source_path, input_folder_or_file, output_folder)

        if source_path.endswith(".mp4"):
            mask_path = get_mask_path(input_folder_or_file, input_mask_folder, source_path)
            try:
                return process_video(
                    source_path,
                    output_path,
                    mask_path,
                    model,
                    scale=scale,
                    target_size=target_size,
                    stride=stride,
                    num_frames=num_frames,
                    mode=mode,
                    skip_processed_videos=skip_processed_videos,
                    skip_processed_frames=skip_processed_frames,
                )
            except Exception as e:
                # One broken file must not stop the batch.
                print(f"Error processing video {source_path}: {e}")
        else:
            try:
                return process_image(
                    source_path,
                    output_path,
                    model,
                    scale=scale,
                    target_size=target_size,
                    skip_processed_frames=skip_processed_frames,
                )
            except Exception as e:
                print(f"Error processing image {source_path}: {e}")

    files = sorted(files)  # Sort files for consistent processing
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(process, file) for file in files]
        for future in tqdm(futures, desc=f"Processing videos in {input_folder_or_file}", leave=True):
            future.result()

    print("Processing complete.")
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def find_files_fd(start_dir, extensions):
    """
    Recursively list files with the given extensions via the 'fd' command-line tool.

    Args:
        start_dir (str): The directory to start searching from.
        extensions (list): File extensions without the leading dot (e.g., ['png', 'jpg']).

    Returns:
        list: Full path strings, one per found file. An empty list is returned
        when `start_dir` is not a directory, when fd exits with a non-zero
        code, or when an unexpected error occurs.

    Raises:
        FileNotFoundError: If the 'fd' command is not found in the system's PATH.
    """
    if not os.path.isdir(start_dir):
        print(f"Error: Start directory not found: {start_dir}")
        return []

    try:
        # One "--extension <ext>" pair per extension (fd expects no leading dot).
        extension_flags = [token for ext in extensions for token in ("--extension", ext)]

        # "--type f --type l": files or links to files. The "." pattern matches
        # everything, so the extension flags do the actual filtering.
        command = ["fd", "--type", "f", "--type", "l", *extension_flags, ".", start_dir]

        completed = subprocess.run(
            command,
            capture_output=True,  # Capture stdout and stderr
            text=True,  # Decode output as text
            check=False,  # Non-zero exit codes are handled below
            encoding="utf-8",  # Be explicit about encoding
        )

        if completed.returncode == 0:
            # fd prints one path per line.
            return completed.stdout.strip().splitlines()

        # Non-zero exit: report stderr if there is any, then give up quietly.
        if completed.stderr:
            print(f"Error running fd (code {completed.returncode}): {completed.stderr.strip()}")
        return []

    except FileNotFoundError:
        raise  # Let the caller know fd is missing

    except Exception as e:
        print(f"An unexpected error occurred while running fd: {e}")
        return []
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def find_files_glob(start_dir, extensions):
    """
    Finds files with given extensions recursively using glob.

    Args:
        start_dir (str): The directory to start searching from.
        extensions (list): A list of file extensions without the leading dot (e.g., ['png', 'jpg']).
            A leading dot, if present, is ignored.

    Returns:
        list: A sorted list of path strings for each found regular file.
    """
    files = []
    for ext in extensions:
        # Match on the real extension ("*.png"), not on any name that merely
        # ends with the letters ("*png" would also pick up e.g. "notpng").
        suffix = ext.lstrip(".")
        files.extend(glob(f"{start_dir}/**/*.{suffix}", recursive=True))
    return sorted(f for f in files if os.path.isfile(f))
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def find_files(start_dir, extensions):
    """List media files under `start_dir`, preferring `fd` and falling back to glob.

    The glob-based search is used only when `find_files_fd` raises (for
    example because the `fd` executable is missing).
    """
    try:
        found = find_files_fd(start_dir, extensions)
    except Exception:
        found = find_files_glob(start_dir, extensions)
    return found
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
def get_args():
    """Parse the command-line options of the preprocessing script.

    Returns the argparse namespace; `target_size` is converted from its string
    form by `parse_target_size` before returning.
    """
    # (flags, add_argument keyword options), registered in this order.
    option_specs = [
        (
            ("-i", "--input_folder_or_file"),
            {"type": str, "required": True, "help": "Path to the input folder containing videos or images."},
        ),
        (
            ("--mask_folder",),
            {"type": str, "default": None, "help": "Path to the input folder containing masks (optional)."},
        ),
        (
            ("--num_workers",),
            {"type": int, "default": 8, "help": "Number of worker threads."},
        ),
        (
            ("-s", "--scale"),
            {"type": float, "default": 1.3, "help": "Scale factor for face alignment."},
        ),
        (
            ("--target_size",),
            {
                "type": str,
                "default": "256,256",
                "help": "Target size for aligned faces as width, height (e.g., 256,256) or 'none'.",
            },
        ),
        (
            ("--det_thres",),
            {"type": float, "default": 0.4, "help": "Detection threshold for RetinaFace."},
        ),
        (
            ("-m", "--mode"),
            {
                "type": str,
                "default": "at_least",
                "choices": ["fixed_num_frames", "fixed_stride", "at_least"],
                "help": "Mode for frame extraction from videos ('fixed_num_frames', 'fixed_stride', or 'at_least').",
            },
        ),
        (
            ("--stride",),
            {
                "type": int,
                "default": 1,
                "help": "Stride for frame extraction from videos (only used in 'fixed_stride' mode).",
            },
        ),
        (
            ("-n", "--num_frames"),
            {
                "type": int,
                "default": 32,
                "help": "Maximum number of frames to extract from each video, -1 for all frames.",
            },
        ),
        (
            ("-o", "--output_folder"),
            {"type": str, "default": "outputs", "help": "Output folder for the preprocessed images."},
        ),
        (
            ("--skip_processed_videos",),
            {"action": "store_true", "help": "Skip videos that have already been processed."},
        ),
        (
            ("--skip_processed_frames",),
            {"action": "store_true", "help": "Skip frames that have already been processed."},
        ),
    ]

    parser = argparse.ArgumentParser()
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()
    args.target_size = parse_target_size(args.target_size)
    return args
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def parse_target_size(target_size_str):
    """Convert a "width,height" string into an `(int, int)` tuple.

    The literal "none" (any casing, surrounding whitespace ignored) yields None.

    Raises:
        ValueError: If the string is neither "none" nor two positive integers
            separated by a comma.
    """
    text = target_size_str.strip()

    # Exact match only: a substring test would also accept e.g. "nonexistent".
    if text.lower() == "none":
        return None

    try:
        width, height = map(int, text.split(","))
    except ValueError:
        # Covers non-numeric parts as well as too few / too many components.
        raise ValueError("Invalid target_size format. Use 'width,height' or 'none'.") from None

    if width <= 0 or height <= 0:
        raise ValueError("Invalid target_size format. Use 'width,height' or 'none'.")

    return (width, height)
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
def main():
    """CLI entry point: parse arguments, build the detector and preprocess the input."""
    args = get_args()

    model = prepare_model(args.det_thres)

    process_mixed_types(
        input_folder_or_file=args.input_folder_or_file,
        input_mask_folder=args.mask_folder,
        model=model,
        num_workers=args.num_workers,
        scale=args.scale,
        target_size=args.target_size,
        stride=args.stride,
        num_frames=args.num_frames,
        mode=args.mode,
        output_folder=args.output_folder,
        skip_processed_videos=args.skip_processed_videos,
        skip_processed_frames=args.skip_processed_frames,
    )

    # `exit` is a helper injected by the `site` module for interactive use and
    # is not guaranteed to exist (e.g. under `python -S`). Raising SystemExit
    # is what it does underneath and needs no import.
    raise SystemExit(0)


if __name__ == "__main__":
    main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.ruff]
|
| 2 |
+
line-length = 120
|
| 3 |
+
|
| 4 |
+
[tool.ruff.lint]
|
| 5 |
+
ignore = [
|
| 6 |
+
"C901", # complex condition
|
| 7 |
+
"E501", # line too long
|
| 8 |
+
"F401", # imported but unused
|
| 9 |
+
"F403", # from module import * used; unable to detect undefined names
|
| 10 |
+
"F405", # name may be undefined, or defined from star imports: module
|
| 11 |
+
"E741", # ambiguous variable name
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
select = [
|
| 15 |
+
"C", # flake8-comprehensions
|
| 16 |
+
"E", "W", # pycodestyle
|
| 17 |
+
"F", # pyflakes
|
| 18 |
+
"I", # isort
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
[tool.ruff.lint.isort]
|
| 22 |
+
force-to-top = ["autoroot", "autorootcwd"]
|
| 23 |
+
|
| 24 |
+
[tool.ruff.lint.per-file-ignores]
|
| 25 |
+
"**/__init__.py" = ["E402"]
|
| 26 |
+
|
| 27 |
+
[tool.pyright]
|
| 28 |
+
exclude = [
|
| 29 |
+
"**/__pycache__",
|
| 30 |
+
"wandb",
|
| 31 |
+
"datasets",
|
| 32 |
+
"outputs",
|
| 33 |
+
"runs",
|
| 34 |
+
"tmp",
|
| 35 |
+
"logs",
|
| 36 |
+
]
|
| 37 |
+
typeCheckingMode = "off"
|
requirements.txt
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.8.0
|
| 2 |
+
torchaudio==2.8.0
|
| 3 |
+
torchvision==0.23.0
|
| 4 |
+
lightning==2.5.5
|
| 5 |
+
transformers==4.56.2
|
| 6 |
+
tqdm==4.67.1 # progress bar
|
| 7 |
+
timm==1.0.20 # torch models
|
| 8 |
+
matplotlib==3.10.6 # visualization
|
| 9 |
+
seaborn==0.13.2 # visualization
|
| 10 |
+
scikit-learn==1.6.1 # metrics
|
| 11 |
+
rich==14.1.0 # logging
|
| 12 |
+
wandb==0.22.0 # logging
|
| 13 |
+
pydantic==2.11.9 # config
|
| 14 |
+
# albumentations==1.4.17 # augmentations
|
| 15 |
+
ruff==0.13.2 # formatting
|
| 16 |
+
fire==0.7.0 # CLI
|
| 17 |
+
pytorch-metric-learning==2.8.1 # losses
|
| 18 |
+
peft==0.15.2 # parameter-efficient fine-tuning
|
| 19 |
+
ipykernel==6.30.1 # jupyter
|
| 20 |
+
autoroot==1.0.1 # root utils
|
| 21 |
+
autorootcwd==1.0.1 # root utils
|
| 22 |
+
xformers==0.0.32.post2 # RADIOv2.5/3
|
| 23 |
+
einops==0.8.1 # RADIOv2.5/3
|
| 24 |
+
open-clip-torch==2.32.0 # RADIOv2.5/3
|
| 25 |
+
grad-cam==1.5.5 # for Grad-CAM visualization
|
| 26 |
+
mediapipe==0.10.21 # Face landmark detection
|
| 27 |
+
|
| 28 |
+
# --- for detector.py ---
|
| 29 |
+
# opencv-python==4.11.0.86 # mainly only for detector.py
|
| 30 |
+
opencv-python-headless==4.12.0.88 # mainly for `detector.py`
|
| 31 |
+
onnxruntime-gpu==1.21.0 # for ONNX model inference
|
| 32 |
+
|
| 33 |
+
# --- for app/run.py ---
|
| 34 |
+
gradio==5.49.1
|
run.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import traceback
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from lightning import Trainer
|
| 6 |
+
from lightning.pytorch import callbacks as pl_callbacks
|
| 7 |
+
from lightning.pytorch import loggers as pl_loggers
|
| 8 |
+
from rich import traceback as rich_traceback
|
| 9 |
+
|
| 10 |
+
from src import dataset as datasets
|
| 11 |
+
from src.config import Config
|
| 12 |
+
from src.model.base import BaseDeepakeDetectionModel
|
| 13 |
+
from src.utils import logger
|
| 14 |
+
from src.utils.checks import checks
|
| 15 |
+
from src.utils.model_checkpoint import ModelCheckpointParallel
|
| 16 |
+
|
| 17 |
+
rich_traceback.install()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_third_party_model(config: Config) -> BaseDeepakeDetectionModel:
    """Instantiate the third-party detector identified by the checkpoint path.

    The model class is chosen by a substring of `config.checkpoint` and is
    imported lazily, so only the selected model's module is loaded.

    Raises:
        ValueError: If the checkpoint path matches none of the known models.
    """
    checkpoint = config.checkpoint

    if "weights/Effort" in checkpoint:
        # Download: https://drive.google.com/drive/folders/19kQwGDjF18uk78EnnypxxOLaG4Aa4v1h
        from src.model.Effort import Effort as model_cls
    elif "weights/ForAda" in checkpoint:
        # Download: https://drive.usercontent.google.com/download?id=1UlaAUTtsX87ofIibf38TtfAKIsnA7WVm&export=download&authuser=0
        from src.model.ForAda import ForAda as model_cls
    elif "weights/FS-VFM/" in checkpoint:
        from src.model.FSFM import FSFM as model_cls
    elif "yermandy/" in checkpoint:
        # https://huggingface.co/yermandy/models
        from src.model.GenDHF import GenDHF as model_cls
    else:
        raise ValueError(f"Unknown third party model in checkpoint path: {config.checkpoint}")

    return model_cls(config)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def load_model(config: Config) -> BaseDeepakeDetectionModel:
    """Build the model for `config`, defaulting to GenD.

    A third-party model is tried first when a checkpoint is given; GenD is used
    when there is no checkpoint or when `load_third_party_model` raises
    ValueError.

    NOTE(review): the ValueError handler also swallows a ValueError raised
    inside a third-party model's constructor, not only the "unknown model"
    case — confirm this fallback is intended.
    """
    has_checkpoint = not (config.checkpoint is None or config.checkpoint == "")

    if has_checkpoint:
        # Try to load third party model
        try:
            return load_third_party_model(config)
        except ValueError:
            # Not a third party model: fall through to the GenD default below
            pass

    from src.model.GenD import GenD

    return GenD(config, verbose=True)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def init_loggers(config: Config) -> list:
    """Create the Lightning loggers for a run.

    A CSV logger is always returned; a Weights & Biases logger is appended
    when `config.wandb` is truthy.
    """
    csv_logger = pl_loggers.CSVLogger(config.run_dir, name=config.run_name, version="")

    if not config.wandb:
        return [csv_logger]

    wandb_logger = pl_loggers.WandbLogger(
        project="deepfake",
        name=config.run_name,
        save_dir=f"{config.run_dir}/{config.run_name}",
        tags=set(config.wandb_tags),
        group=config.wandb_group,
    )
    return [csv_logger, wandb_logger]
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def init_callbacks(config: Config) -> list:
    """Create the Lightning callbacks for a run.

    Always includes a rich progress bar and the checkpoint callback; early
    stopping is added only when `config.early_stopping_patience` is positive.
    """
    progress_bar = pl_callbacks.RichProgressBar(leave=True)
    checkpointing = ModelCheckpointParallel(
        filename=config.checkpoint_name,
        monitor=config.monitor_metric,
        mode=config.monitor_metric_mode,
    )
    callbacks = [progress_bar, checkpointing]
    # pl_callbacks.LearningRateFinder(1e-5, 1e-2),

    if config.early_stopping_patience > 0:
        early_stopping = pl_callbacks.EarlyStopping(
            monitor=config.monitor_metric,
            patience=config.early_stopping_patience,
            mode=config.monitor_metric_mode,
            verbose=True,
        )
        callbacks.append(early_stopping)

    return callbacks
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def finish_wandb_run(trainer, config: Config):
    """Finalize and close the first W&B logger attached to `trainer`, if any.

    Does nothing when `config.wandb` is falsy or no WandbLogger is attached.
    The run is always finalized with status "success".
    """
    if not config.wandb:
        return

    wandb_logger = next(
        (candidate for candidate in trainer.loggers if isinstance(candidate, pl_loggers.WandbLogger)),
        None,
    )
    if wandb_logger is None:
        return

    wandb_logger.finalize("success")
    wandb_logger.experiment.finish()
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def main(config: Config, train: bool):
    """Train and/or test a deepfake detection model described by `config`.

    With `train=True` the model is fitted; afterwards (also after an
    interruption or a failure) the checkpoint named `config.checkpoint_name`
    is loaded from the run directory and tested, if that file exists. With
    `train=False` the model is only tested and a checkpoint is required.

    Args:
        config: Run configuration.
        train: Whether to fit before testing.
    """
    # Performs initial checks
    checks(config)

    # Set the precision for matmul operations
    torch.set_float32_matmul_precision("high")

    # Instantiates the model
    model = load_model(config)

    # Loads the checkpoint if provided
    model.load_checkpoint(config.checkpoint)

    data_module = datasets.DeepfakeDataModule(config, model.get_preprocessing())

    save_dir = f"{config.run_dir}/{config.run_name}"

    trainer = Trainer(
        devices=config.devices,
        max_epochs=config.max_epochs,
        precision=config.precision,
        accumulate_grad_batches=config.batch_size // config.mini_batch_size,
        fast_dev_run=config.fast_dev_run,
        log_every_n_steps=100,
        overfit_batches=config.overfit_batches,
        limit_train_batches=config.limit_train_batches,
        limit_val_batches=config.limit_val_batches,
        limit_test_batches=config.limit_test_batches,
        deterministic=config.deterministic,
        detect_anomaly=config.detect_anomaly,
        logger=init_loggers(config),
        callbacks=init_callbacks(config),
        default_root_dir=config.run_dir,
    )

    if train:
        try:
            trainer.fit(model, data_module)
        except KeyboardInterrupt:
            logger.print_warning("Training interrupted")
        except Exception as e:
            traceback.print_exc()  # Print complete exception traceback
            logger.print_error(f"Training failed: {e}")
            # Save the exception traceback to a file. The run directory may not
            # exist yet if the failure happened early; create it so that
            # writing the log cannot itself raise and hide the original error.
            os.makedirs(save_dir, exist_ok=True)
            with open(f"{save_dir}/failed.log", "a") as f:
                f.write(f"Training failed: {e}\n")
                f.write(traceback.format_exc())
        finally:
            # Runs after success, interruption and failure alike.
            logger.print_info("Training finished. Starting testing")
            ckpt_path = f"{save_dir}/checkpoints/{config.checkpoint_name}.ckpt"
            if not os.path.exists(ckpt_path):
                logger.print_error(f"Checkpoint {ckpt_path} does not exist. Cannot proceed with testing.")
            else:
                model.load_checkpoint(ckpt_path)
                trainer.test(model, data_module)

    else:
        assert config.checkpoint is not None, "Checkpoint is required for testing"
        trainer.test(model, data_module)

    # Finish wandb run
    finish_wandb_run(trainer, config)
|
run_exp.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import traceback
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
|
| 4 |
+
import fire
|
| 5 |
+
|
| 6 |
+
from run import main
|
| 7 |
+
from src import config as C
|
| 8 |
+
from src.config import Config
|
| 9 |
+
from src.exp import experiments
|
| 10 |
+
from src.utils import files, logger
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_val_files():
    """Return the validation file entries of all validation datasets as one list."""
    validation_sets = (
        files.DeepSpeak_v2.my_val,
        files.DeepSpeak_v1_1.my_val,
        files.CDFv2.val,
        files.FFIW.val,
    )

    merged = []
    for validation_set in validation_sets:
        merged.extend(validation_set)
    return merged
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_test_files():
    """Return a mapping from test-set name to its file collection.

    The fixed benchmark sets are listed first; the CDFv3 test sets from
    `files.CDFv3.get_test_dict()` are added afterwards with their paths
    redirected to the "uniform-32-frames" subset folder.
    """
    benchmark_sets = {
        "FF": files.FF.test,
        "FF-DF": files.FF.DF.test,
        "FF-F2F": files.FF.F2F.test,
        "FF-FS": files.FF.FS.test,
        "FF-NT": files.FF.NT.test,
        "CDF": files.CDFv2.test,
        "FaceFusion": files.FaceFusion.CDF.test,
        "DFD": files.DFD.test,
        "DFDC": files.DFDC.test,
        "FSh": files.FSh.test,
        # NOTE(review): key says "UADFD" while the dataset attribute is UADFV — confirm the key is intended.
        "UADFD": files.UADFV.test,
        "DFDM": files.DFDM.test,
        "FFIW": files.FFIW.test,
        "DeepSpeak-1.1": files.DeepSpeak_v1_1.test,
        "DeepSpeak-2.0": files.DeepSpeak_v2.test,
        "KoDF": files.KoDF.test,
        "KoDF-adv": files.KoDF.adversarial,
        "FakeAVCeleb": files.FakeAVCeleb.test,
        "FAVC-FV-RA-WL": files.FakeAVCeleb.FV_RA_WL.test,
        "FAVC-FV-FA-FS": files.FakeAVCeleb.FV_FA_FS.test,
        "FAVC-FV-FA-GAN": files.FakeAVCeleb.FV_FA_GAN.test,
        "FAVC-FV-FA-WL": files.FakeAVCeleb.FV_FA_WL.test,
        "PolyGlotFake": files.PolyGlotFake.test,
        "IDForge-v1": files.IDForge_v1.test,
    }

    def to_subset_path(path):
        return path.replace("/CDFv3/", "/CDFv3-x1.3-th0.5-all/subset/uniform-32-frames/")

    cdfv3_sets = {name: entries.map(to_subset_path) for name, entries in files.CDFv3.get_test_dict().items()}

    return {**benchmark_sets, **cdfv3_sets}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def get_default_train_config() -> Config:
    """Build the baseline training configuration for the "rebuttal" runs.

    Starts from a default ``Config`` and overrides run bookkeeping, hardware,
    model, optimisation and dataset fields.

    Returns:
        The populated ``Config`` instance.
    """
    cfg = Config()

    # --- run bookkeeping / experiment tracking ---
    cfg.run_dir = "runs/rebuttal"
    cfg.wandb = True
    cfg.wandb_tags.append("rebuttal")
    cfg.throw_exception_if_run_exists = True

    # --- data loading and hardware ---
    cfg.num_workers = 12
    cfg.devices = "auto"

    # --- model ---
    cfg.backbone = C.Backbone.CLIP_L_14
    cfg.freeze_feature_extractor = True
    cfg.num_classes = 2

    # --- optimisation ---
    cfg.batch_size = 128
    cfg.mini_batch_size = 128
    cfg.lr_scheduler = "cosine"
    cfg.lr = 3e-4
    cfg.min_lr = 1e-5
    cfg.weight_decay = 0
    # NOTE(review): presumably 1 warm-up epoch + 50 regular epochs — confirm.
    cfg.max_epochs = 1 + 50
    cfg.warmup_epochs = 1

    # --- datasets ---
    cfg.trn_files = files.FF.train
    cfg.val_files = get_val_files()
    cfg.tst_files = get_test_files()

    return cfg
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_default_test_config(orig_run_name, new_run_name) -> Config:
    """Build a test configuration from a previously saved run.

    Args:
        orig_run_name: Name of the existing run; its directory is resolved
            with ``files.find_run_dir``.
        new_run_name: Name given to the new (test) run.

    Returns:
        The config loaded from the original run's ``hparams.yaml``, renamed,
        pointed at that run's ``best_mAP.ckpt`` checkpoint and adjusted for
        evaluation (run directory, W&B tag, loader settings, test files).
    """
    source_dir = files.find_run_dir(orig_run_name)
    hparams_path = f"{source_dir}/hparams.yaml"
    checkpoint_name = "best_mAP.ckpt"  # default checkpoint file name

    # Start from the hyper-parameters stored with the original run.
    cfg = C.load_config(hparams_path)

    # Identity of the new run and the weights it evaluates.
    cfg.run_name = new_run_name
    cfg.run_dir = "runs/test"
    cfg.checkpoint = f"{source_dir}/checkpoints/{checkpoint_name}"

    # Experiment tracking.
    cfg.wandb = True
    cfg.wandb_tags.extend(["test"])

    # Loader / hardware settings used for evaluation.
    cfg.num_workers = 12
    cfg.batch_size = 1024
    cfg.mini_batch_size = 1024
    cfg.devices = "auto"

    cfg.tst_files = get_test_files()

    return cfg
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_debug_config(config: Config) -> Config:
    """Turn ``config`` into a short smoke-test configuration.

    The object passed in is modified in place and is also returned.

    Args:
        config: Configuration to adjust.

    Returns:
        The same ``config`` object, redirected to ``runs/tmp`` / ``tmp``,
        limited to one epoch with 12 batches per stage, and pointed at the FF
        train/val splits (the val split doubles as the test split).
    """
    #! Debug
    batches_per_stage = 12

    config.run_dir, config.run_name = "runs/tmp", "tmp"
    config.max_epochs = 1
    for stage in ("train", "val", "test"):
        setattr(config, f"limit_{stage}_batches", batches_per_stage)

    # Optional toggles, left disabled:
    #   config.num_workers = 0
    #   config.batch_size = config.mini_batch_size = 2
    #   config.deterministic = True
    #   config.detect_anomaly = True

    config.trn_files = files.FF.train
    config.val_files = files.FF.val
    config.tst_files = files.FF.val

    return config
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# Script-local experiment registry: a shallow copy of every experiment defined
# in src.exp, so entries can be added here without mutating the imported mapping.
experiments = dict(experiments)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def entry(
    exp_names: str | list[str],
    debug: bool = False,
    test: bool = False,
    from_exp: str | None = None,
    **kwargs,
):
    """Run one or more named experiments (command-line entry point).

    For every requested name, the base config is deep-copied, the modifiers
    registered under that name in ``experiments`` are applied, optional debug
    settings and keyword overrides are layered on top, the result is
    re-validated and handed to ``main``. A failure in one experiment is
    printed and logged to ``<run_dir>/<run_name>/failed.log``; the remaining
    experiments still run.

    Args:
        exp_names: One experiment name or a sequence of names. Unknown names
            are reported and skipped.
        debug: When true, overlay the settings from ``get_debug_config``.
        test: When true, ``main`` is called with ``False`` as its second
            argument (presumably "do not train" — confirm against ``run.main``).
        from_exp: Only used in test mode: name of an existing run whose saved
            config and checkpoint are loaded via ``get_default_test_config``.
            Exactly one experiment name must then be given.
        **kwargs: Extra overrides applied with ``Config.set_values_from_dict``.

    Raises:
        Exception: In test mode with ``from_exp`` set, when the number of
            experiment names is not exactly one.
    """
    import os  # local import: only needed to create the failure-log directory

    # Normalise up front so that every branch below sees a sequence of names.
    # Doing this after the test branch would make `exp_names[0]` pick the first
    # character of a plain string.
    if isinstance(exp_names, str):
        exp_names = [exp_names]

    if test:
        if from_exp is not None:
            if len(exp_names) != 1:
                raise Exception("When running in test mode, you can provide only one experiment name.")
            config = get_default_test_config(from_exp, exp_names[0])
        else:
            logger.print_warning("Running in test mode, but 'from_exp' is not provided. Using default test config.")
            config = C.Config()
    else:
        config = get_default_train_config()

    for exp_name in exp_names:
        exp_name = exp_name.strip()

        if exp_name not in experiments:
            logger.print_error(f"Experiment '{exp_name}' is not defined in 'src/exp/__init__.py:1'")
            logger.print(f"Available experiments: {list(experiments.keys())}")
            continue

        modifiers = experiments[exp_name]
        # Each experiment starts from its own copy of the base config.
        config_exp = deepcopy(config)

        config_exp.run_name = exp_name
        for modify in modifiers:
            if isinstance(modify, Config):
                # A Config modifier contributes only the fields that were
                # explicitly set on it; the merged result is re-parsed.
                difference = modify.model_dump(exclude_unset=True)
                # TODO: maybe set_values_from_dict(difference)?
                config_exp = Config(**config_exp.model_copy(update=difference).model_dump())
            else:
                # Any other modifier is called with the config and must return one.
                config_exp = modify(config_exp)

        config_exp = Config(**config_exp.model_dump())  # Parse and validate config

        if debug:
            config_exp = config_exp.model_copy(update=get_debug_config(config_exp).model_dump())

        # Update config with kwargs
        config_exp.set_values_from_dict(kwargs)

        # Revalidate the config - checks if user provided valid values
        config_exp = Config(**config_exp.model_dump())

        try:
            main(config_exp, not test)

        except Exception as e:
            # Deliberately broad: one failing experiment must not stop the batch.
            traceback.print_exc()  # Print complete exception traceback
            logger.print_error(f"Error occurred while running experiment '{exp_name}':")
            logger.print(e)

            save_dir = f"{config_exp.run_dir}/{config_exp.run_name}"
            # The run directory may not exist yet if `main` failed early;
            # create it so that writing the log cannot raise from this handler.
            os.makedirs(save_dir, exist_ok=True)
            # Save the exception traceback to a file
            with open(f"{save_dir}/failed.log", "a") as f:
                f.write(f"\nTraining failed: {e}\n")
                f.write(traceback.format_exc())
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
if __name__ == "__main__":
    # Expose `entry` as the command-line interface; python-fire maps the CLI
    # arguments and flags onto its parameters.
    fire.Fire(entry)
|