yermandy commited on
Commit
c29babb
·
0 Parent(s):
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. .gitignore +14 -0
  3. .project-root +1 -0
  4. LICENSE +21 -0
  5. README.md +157 -0
  6. config/datasets/CDFv2/test/Celeb-real.txt +4 -0
  7. config/datasets/CDFv2/test/Celeb-synthesis.txt +4 -0
  8. config/datasets/CDFv2/test/YouTube-real.txt +4 -0
  9. config/datasets/FF/test/DF.txt +4 -0
  10. config/datasets/FF/test/F2F.txt +4 -0
  11. config/datasets/FF/test/FS.txt +4 -0
  12. config/datasets/FF/test/NT.txt +4 -0
  13. config/datasets/FF/test/real.txt +4 -0
  14. datasets/CDFv2/Celeb-real/id0_0000/000.png +3 -0
  15. datasets/CDFv2/Celeb-real/id0_0000/015.png +3 -0
  16. datasets/CDFv2/Celeb-real/id0_0000/030.png +3 -0
  17. datasets/CDFv2/Celeb-real/id0_0000/045.png +3 -0
  18. datasets/CDFv2/Celeb-synthesis/id0_id1_0000/000.png +3 -0
  19. datasets/CDFv2/Celeb-synthesis/id0_id1_0000/015.png +3 -0
  20. datasets/CDFv2/Celeb-synthesis/id0_id1_0000/030.png +3 -0
  21. datasets/CDFv2/Celeb-synthesis/id0_id1_0000/045.png +3 -0
  22. datasets/CDFv2/YouTube-real/00000/000.png +3 -0
  23. datasets/CDFv2/YouTube-real/00000/014.png +3 -0
  24. datasets/CDFv2/YouTube-real/00000/028.png +3 -0
  25. datasets/CDFv2/YouTube-real/00000/043.png +3 -0
  26. datasets/FF/DF/000_003/000.png +3 -0
  27. datasets/FF/DF/000_003/012.png +3 -0
  28. datasets/FF/DF/000_003/025.png +3 -0
  29. datasets/FF/DF/000_003/038.png +3 -0
  30. datasets/FF/F2F/000_003/000.png +3 -0
  31. datasets/FF/F2F/000_003/009.png +3 -0
  32. datasets/FF/F2F/000_003/019.png +3 -0
  33. datasets/FF/F2F/000_003/029.png +3 -0
  34. datasets/FF/FS/000_003/000.png +3 -0
  35. datasets/FF/FS/000_003/009.png +3 -0
  36. datasets/FF/FS/000_003/019.png +3 -0
  37. datasets/FF/FS/000_003/029.png +3 -0
  38. datasets/FF/NT/000_003/000.png +3 -0
  39. datasets/FF/NT/000_003/009.png +3 -0
  40. datasets/FF/NT/000_003/019.png +3 -0
  41. datasets/FF/NT/000_003/029.png +3 -0
  42. datasets/FF/real/000/000.png +3 -0
  43. datasets/FF/real/000/012.png +3 -0
  44. datasets/FF/real/000/025.png +3 -0
  45. datasets/FF/real/000/038.png +3 -0
  46. detector.py +701 -0
  47. pyproject.toml +37 -0
  48. requirements.txt +34 -0
  49. run.py +174 -0
  50. run_exp.py +209 -0
.gitattributes ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *.png filter=lfs diff=lfs merge=lfs -text
2
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
3
+ *.gz filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+
3
+ /.vscode
4
+ /config
5
+ /datasets
6
+ /outputs
7
+ /runs
8
+ /weights
9
+ /logs
10
+ /tmp
11
+
12
+ x.py
13
+ y.py
14
+ z.py
.project-root ADDED
@@ -0,0 +1 @@
 
 
1
+ # Do not remove, this file is used by the project to determine the root of the project
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Andy
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Deepfake Detection that Generalizes Across Benchmarks (WACV 2026)
2
+
3
+ [![arXiv Badge](https://img.shields.io/badge/arXiv-B31B1B?logo=arxiv&logoColor=FFF)](https://arxiv.org/abs/2508.06248)
4
+ [![Hugging Face Badge](https://img.shields.io/badge/Hugging%20Face-FFD21E?logo=huggingface&logoColor=000)](https://huggingface.co/collections/yermandy/gend)
5
+
6
+ This is the official repository for the paper:
7
+
8
+ **[Deepfake Detection that Generalizes Across Benchmarks](https://arxiv.org/abs/2508.06248)**.
9
+
10
+ ### Abstract
11
+
12
+ > The generalization of deepfake detectors to unseen manipulation techniques remains a challenge for practical deployment. Although many approaches adapt foundation models by introducing significant architectural complexity, this work demonstrates that robust generalization is achievable through a parameter-efficient adaptation of one of the foundational pre-trained vision encoders. The proposed method, GenD, fine-tunes only the Layer Normalization parameters (0.03% of the total) and enhances generalization by enforcing a hyperspherical feature manifold using L2 normalization and metric learning on it.
13
+ >
14
+ > We conducted an extensive evaluation on 14 benchmark datasets spanning from 2019 to 2025. The proposed method achieves state-of-the-art performance, outperforming more complex, recent approaches in average cross-dataset AUROC. Our analysis yields two primary findings for the field: 1) training on paired real-fake data from the same source video is essential for mitigating shortcut learning and improving generalization, and 2) detection difficulty on academic datasets has not strictly increased over time, with models trained on older, diverse datasets showing strong generalization capabilities.
15
+ >
16
+ > This work delivers a computationally efficient and reproducible method, proving that state-of-the-art generalization is attainable by making targeted, minimal changes to a pre-trained foundational image encoder model.
17
+
18
+ ## Inference using Hugging Face transformers
19
+
20
+ This example shows how to run inference with the pretrained GenD model from Hugging Face with no dependencies other than `torch` and `transformers`. It expects input images that have already been preprocessed (face-cropped and aligned) by the detector script, `detector.py`.
21
+
22
+ ### Minimal dependencies
23
+
24
+ ``` bash
25
+ conda create --name GenD python=3.12 uv -y
26
+ conda activate GenD
27
+ uv pip install torch==2.8.0
28
+ uv pip install torchvision==0.23.0
29
+ uv pip install transformers==4.56.2
30
+ ```
31
+
32
+ ### Inference with transformers
33
+
34
+ ``` python
35
+ import requests
36
+ import torch
37
+ from PIL import Image
38
+
39
+ from src.hf.modeling_gend import GenD
40
+
41
+ # Other models can be found in https://huggingface.co/collections/yermandy/gend:
42
+ # - yermandy/GenD_CLIP_L_14
43
+ # - yermandy/GenD_PE_L
44
+ # - yermandy/GenD_DINOv3_L
45
+ model = GenD.from_pretrained("yermandy/GenD_CLIP_L_14")
46
+
47
+ urls = [
48
+ "https://github.com/yermandy/deepfake-detection/blob/main/datasets/FF/DF/000_003/000.png?raw=true",
49
+ "https://github.com/yermandy/deepfake-detection/blob/main/datasets/FF/real/000/000.png?raw=true",
50
+ ]
51
+ images = [Image.open(requests.get(url, stream=True).raw) for url in urls]
52
+ tensors = torch.stack([model.feature_extractor.preprocess(img) for img in images])
53
+ logits = model(tensors)
54
+ probs = logits.softmax(dim=-1)
55
+
56
+ print(probs)
57
+ ```
58
+
59
+ ## Training
60
+
61
+ ### Set up environment
62
+
63
+ ``` bash
64
+ conda create --name GenD python=3.12 uv -y
65
+ conda activate GenD
66
+ uv pip install -r requirements.txt
67
+ ```
68
+
69
+ ### Minimal example without external data
70
+
71
+ #### Training example
72
+
73
+ Examine `src/exp/examples.py`: each experiment name is defined as a key, and its value overrides the default configuration of the `Config` object from `src/config.py`. For example, try running the `example-training` experiment:
74
+
75
+ ``` bash
76
+ python run_exp.py example-training
77
+ ```
78
+
79
+ #### Test example after the model is trained
80
+
81
+ ``` bash
82
+ python run_exp.py example-test --from_exp example-training --test
83
+ ```
84
+
85
+ Alternatively, you can try inference using one of our released models from Hugging Face:
86
+
87
+ ``` bash
88
+ python run_exp.py GenD_CLIP--CDFv2-example --test
89
+ python run_exp.py GenD_PE--CDFv2-example --test
90
+ python run_exp.py GenD_DINO--CDFv2-example --test
91
+ ```
92
+
93
+ ### Full training
94
+
95
+ To fully train the model, you need to download datasets, preprocess them, and create files with paths to the images.
96
+
97
+ The training entry will be similar to the minimal example above.
98
+
99
+ All experiments (configs) from the paper are stored in the `src/exp` folder.
100
+
101
+ #### Prepare the dataset
102
+
103
+ Taking the [FaceForensics++](https://github.com/ondyari/FaceForensics) dataset as an example, follow these steps:
104
+
105
+ 1. Download the dataset first from the [official source](https://github.com/ondyari/FaceForensics). The root of this dataset is `./FaceForensics`
106
+
107
+ 2. Preprocess the dataset using `detector.py` script:
108
+
109
+ ``` bash
110
+ python detector.py -i FaceForensics/manipulated_sequences/Deepfakes/c23/videos/ --mask_folder FaceForensics/masks/manipulated_sequences/Deepfakes/masks/videos/ -m at_least -n 32 -o datasets/FF/DF/ --det_thres 0.1 -s 1.3 --target_size none
111
+ ```
112
+
113
+ Repeat the process for other manipulation methods and real videos. After processing everything, you will get a similar structure:
114
+
115
+ ``` bash
116
+ datasets
117
+ └── FF
118
+ ├── DF
119
+ │ └── 000_003
120
+ │ ├── 025.png
121
+ │ └── 038.png
122
+ ├── F2F
123
+ │ └── 000_003
124
+ │ ├── 019.png
125
+ │ └── 029.png
126
+ ├── FS
127
+ │ └── 000_003
128
+ │ ├── 019.png
129
+ │ └── 029.png
130
+ ├── NT
131
+ │ └── 000_003
132
+ │ ├── 019.png
133
+ │ └── 029.png
134
+ └── real
135
+ └── 000
136
+ ├── 025.png
137
+ └── 038.png
138
+ ```
139
+
140
+ 3. Create files with paths to the images, similar to the ones in the `config/datasets` directory. This can be done using:
141
+
142
+ ``` bash
143
+ find datasets/FF/DF/* -type f | sort > config/datasets/FF/DF.txt
144
+ ```
145
+
146
+ We manage links to files using `src/utils/files.py`.
147
+
148
+ ### Cite
149
+
150
+ ``` bibtex
151
+ @article{yermakov2025deepfake,
152
+ title={Deepfake Detection that Generalizes Across Benchmarks},
153
+ author={Yermakov, Andrii and Cech, Jan and Matas, Jiri and Fritz, Mario},
154
+ journal={arXiv preprint arXiv:2508.06248},
155
+ year={2025}
156
+ }
157
+ ```
config/datasets/CDFv2/test/Celeb-real.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ datasets/CDFv2/Celeb-real/id0_0000/045.png
2
+ datasets/CDFv2/Celeb-real/id0_0000/030.png
3
+ datasets/CDFv2/Celeb-real/id0_0000/015.png
4
+ datasets/CDFv2/Celeb-real/id0_0000/000.png
config/datasets/CDFv2/test/Celeb-synthesis.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ datasets/CDFv2/Celeb-synthesis/id0_id1_0000/000.png
2
+ datasets/CDFv2/Celeb-synthesis/id0_id1_0000/045.png
3
+ datasets/CDFv2/Celeb-synthesis/id0_id1_0000/030.png
4
+ datasets/CDFv2/Celeb-synthesis/id0_id1_0000/015.png
config/datasets/CDFv2/test/YouTube-real.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ datasets/CDFv2/YouTube-real/00000/000.png
2
+ datasets/CDFv2/YouTube-real/00000/014.png
3
+ datasets/CDFv2/YouTube-real/00000/028.png
4
+ datasets/CDFv2/YouTube-real/00000/043.png
config/datasets/FF/test/DF.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ datasets/FF/DF/000_003/000.png
2
+ datasets/FF/DF/000_003/012.png
3
+ datasets/FF/DF/000_003/025.png
4
+ datasets/FF/DF/000_003/038.png
config/datasets/FF/test/F2F.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ datasets/FF/F2F/000_003/000.png
2
+ datasets/FF/F2F/000_003/009.png
3
+ datasets/FF/F2F/000_003/019.png
4
+ datasets/FF/F2F/000_003/029.png
config/datasets/FF/test/FS.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ datasets/FF/FS/000_003/000.png
2
+ datasets/FF/FS/000_003/009.png
3
+ datasets/FF/FS/000_003/019.png
4
+ datasets/FF/FS/000_003/029.png
config/datasets/FF/test/NT.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ datasets/FF/NT/000_003/000.png
2
+ datasets/FF/NT/000_003/009.png
3
+ datasets/FF/NT/000_003/019.png
4
+ datasets/FF/NT/000_003/029.png
config/datasets/FF/test/real.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ datasets/FF/real/000/000.png
2
+ datasets/FF/real/000/012.png
3
+ datasets/FF/real/000/025.png
4
+ datasets/FF/real/000/038.png
datasets/CDFv2/Celeb-real/id0_0000/000.png ADDED

Git LFS Details

  • SHA256: 33a652cb6ad545d41465a978ab4bb02137380db1747a1a460d193a3e0ecd4db6
  • Pointer size: 130 Bytes
  • Size of remote file: 51.8 kB
datasets/CDFv2/Celeb-real/id0_0000/015.png ADDED

Git LFS Details

  • SHA256: cca6e95d080ccafbd35709c0e6ce60a12b415c7f97cf40ac4c1edb7fba441e4f
  • Pointer size: 130 Bytes
  • Size of remote file: 55.2 kB
datasets/CDFv2/Celeb-real/id0_0000/030.png ADDED

Git LFS Details

  • SHA256: 20d6775e831ef3bab80cd928c930b0916f09b933814fe0d2b1d6e51b0106c77a
  • Pointer size: 130 Bytes
  • Size of remote file: 56.3 kB
datasets/CDFv2/Celeb-real/id0_0000/045.png ADDED

Git LFS Details

  • SHA256: 6ca1683cfe93a01ac7800de0efa5f82c38abb5b17582c5708671de567958e08e
  • Pointer size: 130 Bytes
  • Size of remote file: 57.8 kB
datasets/CDFv2/Celeb-synthesis/id0_id1_0000/000.png ADDED

Git LFS Details

  • SHA256: cae26283688b2e1855b75b922f75e0945dd29f1669e5c399b9e0f5bc75a4700c
  • Pointer size: 130 Bytes
  • Size of remote file: 51.5 kB
datasets/CDFv2/Celeb-synthesis/id0_id1_0000/015.png ADDED

Git LFS Details

  • SHA256: 708f825b0fa8403e5db2966d96e4be4b18f4d1f55ab7352acdc2bd2d7542ee5b
  • Pointer size: 130 Bytes
  • Size of remote file: 54.6 kB
datasets/CDFv2/Celeb-synthesis/id0_id1_0000/030.png ADDED

Git LFS Details

  • SHA256: 5693e8c8469630a8699f6cc4b8a51edd4297d31a956a14deba1a1b7796230ec4
  • Pointer size: 130 Bytes
  • Size of remote file: 54 kB
datasets/CDFv2/Celeb-synthesis/id0_id1_0000/045.png ADDED

Git LFS Details

  • SHA256: 90c2250095d7331e64caa42d04d25cac5881288bbfa7013e68302cbe758ade06
  • Pointer size: 130 Bytes
  • Size of remote file: 56.1 kB
datasets/CDFv2/YouTube-real/00000/000.png ADDED

Git LFS Details

  • SHA256: c959f785348832e39685a4918904b3a145ff475ce69eef55624708b423473218
  • Pointer size: 130 Bytes
  • Size of remote file: 52.1 kB
datasets/CDFv2/YouTube-real/00000/014.png ADDED

Git LFS Details

  • SHA256: 775bf1f48de7319a4557071844ad3fb587bc048f0ca1d8a52b696f4167476996
  • Pointer size: 130 Bytes
  • Size of remote file: 58.8 kB
datasets/CDFv2/YouTube-real/00000/028.png ADDED

Git LFS Details

  • SHA256: c7eb0e9af30cbe4e2cab7c6df54a434405470eca81b416e6f835e65fac8e1fc6
  • Pointer size: 130 Bytes
  • Size of remote file: 59 kB
datasets/CDFv2/YouTube-real/00000/043.png ADDED

Git LFS Details

  • SHA256: 37adad8dd2df6a8d82f5ea530ab9db464a81037125180e7130ee3fe1bc1ac567
  • Pointer size: 130 Bytes
  • Size of remote file: 59.2 kB
datasets/FF/DF/000_003/000.png ADDED

Git LFS Details

  • SHA256: b2c605732c6b2152320a986173c1a3dc7544938e948aa46444340831b8018060
  • Pointer size: 130 Bytes
  • Size of remote file: 82.6 kB
datasets/FF/DF/000_003/012.png ADDED

Git LFS Details

  • SHA256: 91b2bf24d9c3685c28559a5ac8c91b94bbe5b841a563b961f4a7679f40e8f4b8
  • Pointer size: 130 Bytes
  • Size of remote file: 83.8 kB
datasets/FF/DF/000_003/025.png ADDED

Git LFS Details

  • SHA256: ca387094072caa412a0f683d909084580896f2c40dd669baf41c04166efffa95
  • Pointer size: 130 Bytes
  • Size of remote file: 82 kB
datasets/FF/DF/000_003/038.png ADDED

Git LFS Details

  • SHA256: 781d4ab6041c8b37c75b7d831f02944cae11b41fb848d5e34c577a142cb3a1a0
  • Pointer size: 130 Bytes
  • Size of remote file: 82.2 kB
datasets/FF/F2F/000_003/000.png ADDED

Git LFS Details

  • SHA256: a8e8c25ddc42909f3b82aedf939ca5fcf3924d1aeb54ad9dbb507ed97d050692
  • Pointer size: 130 Bytes
  • Size of remote file: 82.7 kB
datasets/FF/F2F/000_003/009.png ADDED

Git LFS Details

  • SHA256: 8ddccb64c403bb4cd0245f6826699bd53b02c3bc9b5435d3134b4d9aba019d85
  • Pointer size: 130 Bytes
  • Size of remote file: 82.1 kB
datasets/FF/F2F/000_003/019.png ADDED

Git LFS Details

  • SHA256: e1a08d87e07a821d0972e893e5b7fd7777e901d7b0433d1a1acb21862f349f5f
  • Pointer size: 130 Bytes
  • Size of remote file: 82 kB
datasets/FF/F2F/000_003/029.png ADDED

Git LFS Details

  • SHA256: da6f3aa04c3e0d6155dacdb4003cfb507f5b600b087e57d8184387c1fbbc76eb
  • Pointer size: 130 Bytes
  • Size of remote file: 82 kB
datasets/FF/FS/000_003/000.png ADDED

Git LFS Details

  • SHA256: 57b282095be875d9b161360d43cbd4b04d0536c093bb967e6bc6661baf7db361
  • Pointer size: 130 Bytes
  • Size of remote file: 82.3 kB
datasets/FF/FS/000_003/009.png ADDED

Git LFS Details

  • SHA256: 8481e3a732f811a832598bd39ffb426cd966f4b16478e603c5ec9f5351b732ff
  • Pointer size: 130 Bytes
  • Size of remote file: 81.1 kB
datasets/FF/FS/000_003/019.png ADDED

Git LFS Details

  • SHA256: bc24297d7825d45f663960f9bc0b7e215e7b67312ea9b00f495d289df432b32b
  • Pointer size: 130 Bytes
  • Size of remote file: 81.3 kB
datasets/FF/FS/000_003/029.png ADDED

Git LFS Details

  • SHA256: caae8b99b9f0016e7e3c2bbe9500422687bbdcd4ae2305f4de8ff7a05c5bf587
  • Pointer size: 130 Bytes
  • Size of remote file: 80.5 kB
datasets/FF/NT/000_003/000.png ADDED

Git LFS Details

  • SHA256: ca2b6ce7ea30df50d86855adc4c2dd11b1ba0848153746303e73cb58d7035290
  • Pointer size: 130 Bytes
  • Size of remote file: 81.5 kB
datasets/FF/NT/000_003/009.png ADDED

Git LFS Details

  • SHA256: 616d27b48e463eeadd9f10a3b6294b20d28b71393cbca66c7e5d91c626644f01
  • Pointer size: 130 Bytes
  • Size of remote file: 80.5 kB
datasets/FF/NT/000_003/019.png ADDED

Git LFS Details

  • SHA256: ec4ac46bf3f06d5de47af5bb7e7af691fc3dd46199f8d5216cdbabe65f311faf
  • Pointer size: 130 Bytes
  • Size of remote file: 80.2 kB
datasets/FF/NT/000_003/029.png ADDED

Git LFS Details

  • SHA256: 37a3fd9de8ea981445ca036dfc4ba44d8d1ff42163e14972563e6c5485b57af2
  • Pointer size: 130 Bytes
  • Size of remote file: 80.1 kB
datasets/FF/real/000/000.png ADDED

Git LFS Details

  • SHA256: 33813fa2f7a716f20f27f11bf7e4126c53136108bec3db92dd6001e9453b185a
  • Pointer size: 130 Bytes
  • Size of remote file: 84 kB
datasets/FF/real/000/012.png ADDED

Git LFS Details

  • SHA256: d3af45388e764ae25e9a176a9ed148004d4e155146c9364199852a96815aafaf
  • Pointer size: 130 Bytes
  • Size of remote file: 84.8 kB
datasets/FF/real/000/025.png ADDED

Git LFS Details

  • SHA256: 2b027fe48cf3468ad596acf3f8f9db59374dce243f57d4abdb89ab557e6681b6
  • Pointer size: 130 Bytes
  • Size of remote file: 82.4 kB
datasets/FF/real/000/038.png ADDED

Git LFS Details

  • SHA256: fc521c3b01c76428d353d734fde85969613e450666135881029c1e4e5cd38ce1
  • Pointer size: 130 Bytes
  • Size of remote file: 83.7 kB
detector.py ADDED
@@ -0,0 +1,701 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import heapq
3
+ import os
4
+ import subprocess
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from glob import glob
7
+
8
+ import cv2
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ from src.retinaface import RetinaFace, prepare_model
13
+
14
+
15
def max_spread_permutation_pq(N, start=0):
    """
    Generate a permutation of 0..N-1 such that at each step
    the next element is the one whose minimum distance to
    all previously chosen elements is maximized (ties are broken
    in favour of the smallest index), using a priority queue to
    speed up selection.

    The unchosen indices always form contiguous runs that are delimited by
    chosen indices or by the ends of the range. The best candidate of a run
    can be computed in O(1), so the heap holds a single entry per run and the
    whole permutation is built in O(N log N) time, instead of rescanning every
    remaining index after each pick.

    Args:
        N (int): Length of the permutation.
        start (int): The first element in the permutation (default 0).

    Returns:
        List[int]: A list representing the permutation.

    Raises:
        ValueError: If `start` is not in the range [0, N-1] (this includes N <= 0).
    """
    if not (0 <= start < N):
        raise ValueError("`start` must be in the range [0, N-1]")

    chosen = [start]
    heap = []

    def push_run(lo, hi):
        # Queue the best candidate of the unchosen run lo..hi (inclusive).
        # Because `start` is always chosen, a run touches at most one end of
        # the range, i.e. it has a chosen neighbour on at least one side.
        if lo > hi:
            return
        if lo == 0:
            # Only neighbour is the chosen index hi + 1: index 0 is farthest.
            distance, index = hi + 1, 0
        elif hi == N - 1:
            # Only neighbour is the chosen index lo - 1: index N - 1 is farthest.
            distance, index = N - lo, N - 1
        else:
            # Chosen neighbours lo - 1 and hi + 1: take the midpoint, preferring
            # the smaller index when two midpoints are equally far away.
            distance = (hi - lo + 2) // 2
            index = lo - 1 + distance
        # Negated distance turns heapq's min-heap into a max-heap; the index
        # as second key reproduces the smallest-index tie-break.
        heapq.heappush(heap, (-distance, index, lo, hi))

    push_run(0, start - 1)
    push_run(start + 1, N - 1)

    # Greedily pick the farthest index and split its run in two
    while heap:
        _, index, lo, hi = heapq.heappop(heap)
        chosen.append(index)
        push_run(lo, index - 1)
        push_run(index + 1, hi)

    return chosen
63
+
64
+
65
def get_video_frames_generator(
    source_path: str,
    mask_path: str | None,
    stride: int = 1,
    num_frames=32,
    mode="at_least",
):
    """
    Yield selected frames of a video, optionally paired with mask frames.

    Args:
        source_path: Path of the video to read.
        mask_path: Path of a mask video read at the same frame indices, or
            None to read no masks.
        stride: Step between frame indices; used only by mode "fixed_stride".
        num_frames: Number of evenly spaced frames; used only by mode
            "fixed_num_frames".
        mode: Frame selection strategy:
            - "fixed_num_frames": `num_frames` indices evenly spaced over the video.
            - "fixed_stride": every `stride`-th index.
            - "at_least": all indices, ordered by `max_spread_permutation_pq`
              starting from the middle frame; the consumer decides when to stop.

    Yields:
        Tuples `(frame, frame_id, mask)`; `mask` is None when `mask_path` is None.
        Frames (or masks) that cannot be read are skipped with a warning.

    Raises:
        ValueError: If `mode` is not one of the supported values.
    """
    video = cv2.VideoCapture(source_path)
    mask_video = None

    # try/finally guarantees the captures are released on every exit path,
    # including early returns and the consumer abandoning the generator
    # before it is exhausted (e.g. `break` in the consuming loop).
    try:
        if not video.isOpened():
            print(f"Warning: Video {source_path} cannot be opened!")
            return

        video_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

        if mask_path is not None:
            mask_video = cv2.VideoCapture(mask_path)

            if not mask_video.isOpened():
                print(f"Warning: Mask video {mask_path} cannot be opened!")
                return

            mask_frames = int(mask_video.get(cv2.CAP_PROP_FRAME_COUNT))

            if video_frames != mask_frames:
                print(
                    f"Warning: {source_path} and {mask_path} have different number of frames {video_frames} vs {mask_frames}!"
                )

            # Only indices present in both videos can be paired
            total_frames = min(video_frames, mask_frames)
        else:
            total_frames = video_frames

        if mode not in ("fixed_num_frames", "fixed_stride", "at_least"):
            raise ValueError(f"Invalid mode: {mode}. Choose 'fixed_num_frames', 'fixed_stride', or 'at_least'.")

        # A zero frame count would otherwise produce negative indices
        # ("fixed_num_frames") or an invalid start index ("at_least").
        if total_frames <= 0:
            print(f"Warning: {source_path} reports no frames. Skipping.")
            return

        if mode == "fixed_num_frames":
            # `num_frames` indices evenly spaced from the first to the last frame
            frame_ids = np.linspace(0, total_frames - 1, num_frames, endpoint=True, dtype=int)
        elif mode == "fixed_stride":
            # Every `stride`-th frame index
            frame_ids = np.arange(0, total_frames, stride, dtype=int)
        else:
            # All indices, most spread-out first, starting from the middle frame
            frame_ids = max_spread_permutation_pq(total_frames, start=total_frames // 2)

        # Iterate through the selected frame IDs
        for frame_id in frame_ids:
            # Set the video capture position to the desired frame
            video.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
            success, frame = video.read()

            # Check the frame itself regardless of whether a mask is used,
            # so that consumers never receive `frame=None`.
            if not success:
                print(f"Warning: Failed to read frame {frame_id} of {source_path}. Skipping.")
                continue

            mask = None
            if mask_video is not None:
                mask_video.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
                success_mask, mask = mask_video.read()
                if not success_mask:
                    print(f"Warning: Failed to read mask frame {frame_id} of {mask_path}. Skipping.")
                    continue

            yield frame, frame_id, mask
    finally:
        # Release the video capture objects
        video.release()
        if mask_video is not None:
            mask_video.release()
141
+
142
+
143
def align_face(
    img: np.ndarray,
    landmarks: np.ndarray,
    target_size: None | tuple = None,
    scale: float = 1.3,
    mask: np.ndarray = None,
):
    """
    Warp a face into a canonical pose using its 5-point facial landmarks.

    A partial affine transform is estimated that maps `landmarks` onto a fixed
    landmark template placed inside the output image, and `img` (and `mask`,
    if given) are warped with it.

    Args:
        img: Input image containing the face.
        landmarks: 5-point facial landmarks array with shape (5, 2).
        target_size: Output size passed to `cv2.warpAffine` as `dsize`. When
            None, a square size is derived from the ratio between the pairwise
            landmark distances and the template distances, multiplied by `scale`.
        scale: Controls how much context around the face is kept: the template
            is shrunk by `1 / scale` around the centre of the output image.
        mask: Optional mask warped with the same transform as `img`
            (nearest-neighbour interpolation).

    Returns:
        Tuple `(aligned_img, aligned_mask)`; `aligned_mask` is None when no
        mask was given.
    """
    # Landmark template in normalized [0, 1] output coordinates
    template = np.array(
        [
            [0.34, 0.46],
            [0.66, 0.46],
            [0.5, 0.64],
            [0.37, 0.82],
            [0.63, 0.82],
        ],
        dtype=np.float32,
    )

    if target_size is None:
        # Distances between every unordered pair of points (upper triangle)
        pair_idx = np.triu_indices(len(template), k=1)
        landmark_dists = np.linalg.norm(landmarks[:, None, :] - landmarks[None, :, :], axis=-1)[pair_idx]
        template_dists = np.linalg.norm(template[:, None, :] - template[None, :, :], axis=-1)[pair_idx]

        # Average scale between input landmarks and template gives the side length
        side = np.round(np.mean(landmark_dists / template_dists) * scale).astype(int)
        target_size = (side, side)

    margin_rate = scale - 1
    for axis in (0, 1):
        extent = target_size[axis]
        margin = extent * margin_rate / 2.0

        # Template to pixel coordinates
        template[:, axis] = template[:, axis] * extent
        # Shift by the margin, then squeeze the enlarged canvas back to `extent`
        template[:, axis] += margin
        template[:, axis] *= extent / (extent + 2 * margin)

    matrix, _ = cv2.estimateAffinePartial2D(landmarks.astype(np.float32), template, method=cv2.LMEDS)

    img = cv2.warpAffine(img, matrix, target_size, flags=cv2.INTER_LINEAR)

    if mask is not None:
        mask = cv2.warpAffine(mask, matrix, target_size, flags=cv2.INTER_NEAREST)

    return img, mask
223
+
224
+
225
def _select_landmarks_by_mask(xyxy, landmarks, mask):
    """Return the landmarks of the detected face that overlaps `mask` the most, or None if none overlaps."""
    # Convert mask to grayscale if it's not already
    mask_img = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) if mask.ndim == 3 else mask

    # Threshold the mask to create a binary mask
    mask_img = cv2.threshold(mask_img, 1, 255, cv2.THRESH_BINARY)[1]

    best_landmarks = None
    max_intersection = 0
    for box, face_landmarks in zip(xyxy, landmarks):
        # Clamp to the image origin: a negative coordinate would otherwise be
        # interpreted as an index counted from the end of the array.
        x1, y1, x2, y2 = np.clip(box[:4].astype(int), 0, None)

        # Number of mask pixels inside the bounding box
        intersection = np.count_nonzero(mask_img[y1:y2, x1:x2])

        if intersection > max_intersection:
            max_intersection = intersection
            best_landmarks = face_landmarks

    return best_landmarks


def process_video(
    source_path,
    target_path,
    mask_path,
    model: RetinaFace,
    scale=1.3,
    target_size=(256, 256),
    stride=1,
    num_frames=32,
    mode="at_least",
    skip_processed_videos=False,
    skip_processed_frames=False,
):
    """
    Detect, align and save one face per selected frame of a video.

    Frames are selected by `get_video_frames_generator`. For every frame the
    face is chosen as the detection overlapping the mask the most (when a mask
    video is given) or as the largest detection (otherwise), aligned with
    `align_face` and written to `<frame_save_path>/frame_<frame_id>.png`.

    Args:
        source_path: Path of the input video.
        target_path: Output location; ".mp4" in it is replaced by "/frames"
            to obtain the directory the frames are written to.
        mask_path: Path of the mask video, or None.
        model: Face detector; `model.detect(frame)` must return `(xyxy, landmarks)`.
        scale: Forwarded to `align_face`.
        target_size: Forwarded to `align_face`.
        stride: Forwarded to the frame generator.
        num_frames: Forwarded to the frame generator. In modes "fixed_stride"
            and "at_least" it is also the number of saved frames after which
            processing stops; -1 disables that limit.
        mode: Frame selection mode, forwarded to the frame generator.
        skip_processed_videos: Return immediately if the output directory exists.
        skip_processed_frames: Do not recompute frames whose output file
            exists (they still count towards `num_frames`).

    Returns:
        The directory the frames are written to, or None when the video was
        skipped because of `skip_processed_videos`.
    """
    frame_save_path = target_path.replace(".mp4", "/frames")

    # Skip if frame_save_path exists
    if skip_processed_videos and os.path.exists(frame_save_path):
        print(f"Frames for {source_path} already processed.")
        return

    print(f"Processing {source_path}")

    # Create a frame generator from video path for iteration of frames
    frame_generator = get_video_frames_generator(
        source_path,
        mask_path,
        stride=stride,
        num_frames=num_frames,
        mode=mode,
    )

    # Only the open-ended modes stop early once enough frames were saved
    limit_saved = mode in ("fixed_stride", "at_least") and num_frames != -1

    num_saved = 0
    for frame, frame_id, mask in frame_generator:
        frame_filename = os.path.join(frame_save_path, f"frame_{frame_id:04d}.png")

        if skip_processed_frames and os.path.exists(frame_filename):
            print(f"Frame {frame_id} of {source_path} already processed.")
            num_saved += 1
            if limit_saved and num_saved >= num_frames:
                break
            continue

        if frame is None:
            print(f"Warning: Failed to read frame {frame_id} of {source_path}. Skipping.")
            continue

        # Best effort: a detector failure on one frame must not abort the video
        try:
            preds = model.detect(frame)
        except Exception as e:
            print(f"Error during detection: {e}")
            continue

        xyxy, landmarks = preds

        if len(xyxy) == 0:
            print(f"No faces detected in frame {frame_id} of {source_path}")
            continue

        if mask is not None:
            # It is possible that the mask is empty -> skip this frame
            if mask.sum() == 0:
                print(f"Warning: Mask is empty for frame {frame_id} of {source_path}. Skipping.")
                continue

            # Find the face that intersects the most with the mask
            selected_landmarks = _select_landmarks_by_mask(xyxy, landmarks, mask)

            if selected_landmarks is None:
                print(f"No suitable face found in frame {frame_id} of {source_path} with the provided mask.")
                continue
        else:
            # Select landmarks of the largest face if not using mask
            areas = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
            selected_landmarks = landmarks[np.argmax(areas)]

        # Align the face
        aligned_face, _ = align_face(frame, selected_landmarks, target_size=target_size, scale=scale)

        # Save the aligned face. The directory is created lazily so that it
        # only exists once there is a frame to write into it (its existence
        # is what `skip_processed_videos` tests).
        os.makedirs(frame_save_path, exist_ok=True)
        if not cv2.imwrite(frame_filename, aligned_face):
            print(f"Warning: Failed to write {frame_filename}. Skipping.")
            continue

        num_saved += 1

        if limit_saved and num_saved >= num_frames:
            break

    if num_saved == 0:
        print(f"No faces were saved from {source_path}. Check the detection threshold or input video.")

    return frame_save_path
350
+
351
+
352
def process_image(
    source_path,
    target_path,
    model: RetinaFace,
    scale=1.3,
    target_size=(256, 256),
    skip_processed_frames=False,
):
    """Detect, align and save the largest face found in a single image.

    Args:
        source_path: Path of the image to read.
        target_path: Path the aligned face is written to.
        model: Detector whose ``detect(img)`` returns ``(xyxy, landmarks)``.
        scale: Scale factor forwarded to ``align_face``.
        target_size: Output size forwarded to ``align_face``.
        skip_processed_frames: When true, return immediately if ``target_path``
            already exists.

    Returns:
        ``target_path`` on success (or when skipped); ``None`` when the image
        cannot be read, detection raises, no face is detected, or the aligned
        face cannot be written.
    """
    if skip_processed_frames and os.path.exists(target_path):
        print(f"Image {source_path} already processed.")
        return target_path

    print(f"Processing {source_path}")

    img = cv2.imread(source_path)
    if img is None:
        print(f"Failed to read image {source_path}")
        return None

    try:
        preds = model.detect(img)
    except Exception as e:
        print(f"Error during detection: {e}")
        return None

    xyxy, landmarks = preds

    if len(xyxy) == 0:
        print(f"No faces detected in {source_path}")
        return None

    # Select landmarks of the largest face (by bounding-box area)
    areas = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
    largest_landmarks = landmarks[np.argmax(areas)]

    # Align the face
    aligned_face, _ = align_face(img, largest_landmarks, target_size=target_size, scale=scale)

    # Save the aligned face. A bare file name has an empty dirname, and
    # os.makedirs("") raises, so only create the directory when there is one.
    target_dir = os.path.dirname(target_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(target_path, aligned_face):
        print(f"Failed to write image {target_path}")
        return None
    return target_path
396
+
397
+
398
def get_output_path(source_path, input_folder, output_folder):
    """Map an input media path to its destination inside ``output_folder``.

    The leading ``input_folder`` prefix of ``source_path`` is replaced by the
    base name of ``input_folder``, a trailing ``.mp4`` extension is dropped
    (each video gets its own directory), and the result is placed under
    ``output_folder``.

    Args:
        source_path: Path of the media file being processed.
        input_folder: The input folder or file that ``source_path`` came from.
        output_folder: Root folder for the outputs.

    Returns:
        The output path as a string. It is always located under
        ``output_folder``, even when ``source_path`` is absolute.
    """
    # Example: source_path = input_folder + relative_part
    # Only the leading prefix is rewritten; a later repetition of the same
    # text deeper in the path must stay untouched.
    if source_path.startswith(input_folder):
        relative = os.path.basename(input_folder) + source_path[len(input_folder):]
    else:
        relative = source_path

    # Create a directory for each video: strip the extension only, not every
    # ".mp4" that may appear inside directory names.
    if relative.endswith(".mp4"):
        relative = relative[: -len(".mp4")]

    # os.path.join discards output_folder when the second part is absolute
    # (e.g. absolute paths listed in a txt file), which would write next to --
    # or for images, over -- the source file. Make the path relative first.
    _, relative = os.path.splitdrive(relative)
    relative = relative.lstrip(os.sep + (os.altsep or ""))

    # Place it in the output folder
    return os.path.join(output_folder, relative)
406
+
407
+
408
def get_mask_path(input_folder, input_mask_folder, source_path):
    """Derive the mask video path that corresponds to ``source_path``.

    Returns ``None`` when no mask folder is configured. Otherwise the input
    folder is swapped for the mask folder; paths containing "FaceForensics"
    or "FF++" keep the original file name, all others get a ``_mask`` suffix
    before ``.mp4``.
    """
    if input_mask_folder is None:
        return None

    # Change the input folder to the mask folder
    mask_path = source_path.replace(input_folder, input_mask_folder)

    #! FF++ has masks named the same way as original videos
    same_name_as_video = any(tag in mask_path for tag in ("FaceForensics", "FF++"))
    if same_name_as_video:
        return mask_path

    #! Else assume masks are named with _mask suffix
    return mask_path.replace(".mp4", "_mask.mp4")
421
+
422
+
423
def process_mixed_types(
    input_folder_or_file: str | list[str],
    input_mask_folder: None | str,
    model: RetinaFace,
    num_workers=1,
    scale=1.3,
    target_size=(256, 256),
    stride=1,
    num_frames=32,
    mode: str = "fixed_num_frames",
    output_folder: str = "outputs",
    possible_extensions: tuple[str] = ("mp4", "jpg", "png", "jpeg"),
    skip_processed_videos: bool = False,
    skip_processed_frames: bool = False,
):
    """Preprocess every video/image reachable from ``input_folder_or_file``.

    The input may be a single media file, a ``.txt`` file listing one media
    path per line, or a folder that is searched recursively with
    ``find_files``. ``.mp4`` files go through ``process_video`` (with an
    optional mask path from ``get_mask_path``); everything else goes through
    ``process_image``. Work is distributed over a thread pool of
    ``num_workers`` threads; per-file errors are printed and do not stop the
    remaining files.

    Returns:
        None. Returns early (after printing a message) when no files are found
        or when the input file has an unsupported type.
    """
    # str.endswith only accepts a str or a tuple of str.
    media_suffixes = tuple(possible_extensions)

    if os.path.isfile(input_folder_or_file):
        # If input is a file
        if input_folder_or_file.endswith(media_suffixes):
            # If input is a media file
            files = [input_folder_or_file]
        elif input_folder_or_file.endswith("txt"):
            # If input is a txt file: one path per line. Blank lines (e.g. a
            # trailing newline block) would otherwise be processed as the
            # empty path "".
            with open(input_folder_or_file, "r") as f:
                files = [line.strip() for line in f if line.strip()]
        else:
            # Previously `files` stayed unbound here and the function crashed
            # with UnboundLocalError a few lines below.
            print(f"Unsupported input file type: {input_folder_or_file}")
            return
    else:
        # If input is a folder
        files = find_files(input_folder_or_file, possible_extensions)

    if not files:
        print(f"No files found in {input_folder_or_file}")
        return

    def process(source_path):
        """Process one file; returns the worker's result or None on error."""
        output_path = get_output_path(source_path, input_folder_or_file, output_folder)

        if source_path.endswith(".mp4"):
            mask_path = get_mask_path(input_folder_or_file, input_mask_folder, source_path)
            try:
                return process_video(
                    source_path,
                    output_path,
                    mask_path,
                    model,
                    scale=scale,
                    target_size=target_size,
                    stride=stride,
                    num_frames=num_frames,
                    mode=mode,
                    skip_processed_videos=skip_processed_videos,
                    skip_processed_frames=skip_processed_frames,
                )
            except Exception as e:
                print(f"Error processing video {source_path}: {e}")
        else:
            try:
                return process_image(
                    source_path,
                    output_path,
                    model,
                    scale=scale,
                    target_size=target_size,
                    skip_processed_frames=skip_processed_frames,
                )
            except Exception as e:
                print(f"Error processing image {source_path}: {e}")

    files = sorted(files)  # Sort files for consistent processing
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(process, file) for file in files]
        # Results are awaited in submission order so the progress bar advances
        # deterministically.
        for future in tqdm(futures, desc=f"Processing videos in {input_folder_or_file}", leave=True):
            future.result()

    print("Processing complete.")
497
+
498
+
499
def find_files_fd(start_dir, extensions):
    """
    Recursively list files with the given extensions via the 'fd' command-line tool.

    Args:
        start_dir (str): Directory the search starts from.
        extensions (list): File extensions without the leading dot (e.g., ['png', 'jpg']).

    Returns:
        list: One path string per line printed by fd. An empty list is returned
            when the directory does not exist, fd exits with a non-zero code,
            or an unexpected error occurs.

    Raises:
        FileNotFoundError: If the 'fd' executable cannot be found in PATH.
    """
    if not os.path.isdir(start_dir):
        print(f"Error: Start directory not found: {start_dir}")
        return []

    try:
        # One "--extension <ext>" pair per requested extension (no leading dot).
        extension_flags = [token for ext in extensions for token in ("--extension", ext)]

        # "--type f --type l": files or links to files. The "." pattern matches
        # everything, so filtering is done by extension only.
        command = ["fd", "--type", "f", "--type", "l", *extension_flags, ".", start_dir]

        completed = subprocess.run(
            command,
            capture_output=True,  # keep stdout and stderr
            text=True,  # decode output as text
            check=False,  # exit code is inspected manually below
            encoding="utf-8",
        )

        if completed.returncode == 0:
            # fd prints one path per line
            return completed.stdout.strip().splitlines()

        # Non-zero exit: report stderr if there is any, and treat the run as
        # "nothing found" either way.
        if completed.stderr:
            print(f"Error running fd (code {completed.returncode}): {completed.stderr.strip()}")
        return []

    except FileNotFoundError:
        # Propagate so the caller knows fd is missing
        raise

    except Exception as e:
        print(f"An unexpected error occurred while running fd: {e}")
        return []
558
+
559
+
560
def find_files_glob(start_dir, extensions):
    """
    Finds files with given extensions recursively using glob.

    Args:
        start_dir (str): The directory to start searching from.
        extensions (list): A list of file extensions, with or without the
            leading dot (e.g., ['png', 'jpg']).

    Returns:
        list: A sorted list of unique path strings for each found file.
    """
    from glob import escape  # only `glob` itself is known to be imported at file level

    # Characters such as "[" in the directory name must be matched literally.
    root = escape(str(start_dir))

    matches = set()
    for ext in extensions:
        # Require the dot so "png" does not match names like "notpng" or
        # "x.apng".
        suffix = ext.lstrip(".")
        matches.update(glob(f"{root}/**/*.{suffix}", recursive=True))

    return sorted(path for path in matches if os.path.isfile(path))
575
+
576
+
577
def find_files(start_dir, extensions):
    """List files under ``start_dir`` with the given extensions.

    Tries the fd-based search first and falls back to the glob-based search
    whenever the fd-based one raises.
    """
    try:
        found = find_files_fd(start_dir, extensions)
    except Exception:
        found = find_files_glob(start_dir, extensions)
    return found
582
+
583
+
584
def get_args():
    """Parse the command-line arguments of the preprocessing script.

    Returns:
        The argparse namespace, with ``target_size`` already converted by
        ``parse_target_size``.
    """
    # (flags, add_argument keyword options), in the order they are registered.
    argument_specs = [
        (
            ("-i", "--input_folder_or_file"),
            dict(type=str, required=True, help="Path to the input folder containing videos or images."),
        ),
        (
            ("--mask_folder",),
            dict(type=str, default=None, help="Path to the input folder containing masks (optional)."),
        ),
        (
            ("--num_workers",),
            dict(type=int, default=8, help="Number of worker threads."),
        ),
        (
            ("-s", "--scale"),
            dict(type=float, default=1.3, help="Scale factor for face alignment."),
        ),
        (
            ("--target_size",),
            dict(
                type=str,
                default="256,256",
                help="Target size for aligned faces as width, height (e.g., 256,256) or 'none'.",
            ),
        ),
        (
            ("--det_thres",),
            dict(type=float, default=0.4, help="Detection threshold for RetinaFace."),
        ),
        (
            ("-m", "--mode"),
            dict(
                type=str,
                default="at_least",
                choices=["fixed_num_frames", "fixed_stride", "at_least"],
                help="Mode for frame extraction from videos ('fixed_num_frames', 'fixed_stride', or 'at_least').",
            ),
        ),
        (
            ("--stride",),
            dict(
                type=int,
                default=1,
                help="Stride for frame extraction from videos (only used in 'fixed_stride' mode).",
            ),
        ),
        (
            ("-n", "--num_frames"),
            dict(
                type=int,
                default=32,
                help="Maximum number of frames to extract from each video, -1 for all frames.",
            ),
        ),
        (
            ("-o", "--output_folder"),
            dict(type=str, default="outputs", help="Output folder for the preprocessed images."),
        ),
        (
            ("--skip_processed_videos",),
            dict(action="store_true", help="Skip videos that have already been processed."),
        ),
        (
            ("--skip_processed_frames",),
            dict(action="store_true", help="Skip frames that have already been processed."),
        ),
    ]

    parser = argparse.ArgumentParser()
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)

    parsed = parser.parse_args()
    # The size arrives as text ("256,256" or "none"); convert it once here.
    parsed.target_size = parse_target_size(parsed.target_size)
    return parsed
665
+
666
+
667
def parse_target_size(target_size_str):
    """Convert a ``"width,height"`` string into a ``(width, height)`` tuple.

    Args:
        target_size_str: Text such as ``"256,256"``, or any text containing
            ``"none"`` (case-insensitive) to request no target size.

    Returns:
        A tuple of two positive ints, or ``None`` for the ``"none"`` case.

    Raises:
        ValueError: If the text is not two comma-separated integers, or if
            either dimension is not positive.
    """
    # A string containing "none" can never parse as two integers, so checking
    # it first gives the same result as checking it after a failed parse.
    if "none" in target_size_str.lower():
        return None

    try:
        width, height = map(int, target_size_str.split(","))
    except ValueError:
        # `from None` drops the low-level int()/unpacking error from the trace.
        raise ValueError("Invalid target_size format. Use 'width,height' or 'none'.") from None

    # Fail at argument parsing rather than once per frame during resizing.
    if width <= 0 or height <= 0:
        raise ValueError(f"Invalid target_size {target_size_str!r}: width and height must be positive.")

    return (width, height)
675
+
676
+
677
def main():
    """Command-line entry point: parse arguments, build the detector, preprocess."""
    args = get_args()
    detector = prepare_model(args.det_thres)

    # Forward the parsed command-line options one-to-one.
    options = {
        "input_folder_or_file": args.input_folder_or_file,
        "input_mask_folder": args.mask_folder,
        "model": detector,
        "num_workers": args.num_workers,
        "scale": args.scale,
        "target_size": args.target_size,
        "stride": args.stride,
        "num_frames": args.num_frames,
        "mode": args.mode,
        "output_folder": args.output_folder,
        "skip_processed_videos": args.skip_processed_videos,
        "skip_processed_frames": args.skip_processed_frames,
    }
    process_mixed_types(**options)

    exit(0)
698
+
699
+
700
+ if __name__ == "__main__":
701
+ main()
pyproject.toml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.ruff]
2
+ line-length = 120
3
+
4
+ [tool.ruff.lint]
5
+ ignore = [
6
+ "C901", # function is too complex (mccabe)
7
+ "E501", # line too long
8
+ "F401", # imported but unused
9
+ "F403", # from module import * used; unable to detect undefined names
10
+ "F405", # name may be undefined, or defined from star imports: module
11
+ "E741", # ambiguous variable name
12
+ ]
13
+
14
+ select = [
15
+ "C", # flake8-comprehensions (C4) and mccabe (C90)
16
+ "E", "W", # pycodestyle
17
+ "F", # pyflakes
18
+ "I", # isort
19
+ ]
20
+
21
+ [tool.ruff.lint.isort]
22
+ force-to-top = ["autoroot", "autorootcwd"]
23
+
24
+ [tool.ruff.lint.per-file-ignores]
25
+ "**/__init__.py" = ["E402"]
26
+
27
+ [tool.pyright]
28
+ exclude = [
29
+ "**/__pycache__",
30
+ "wandb",
31
+ "datasets",
32
+ "outputs",
33
+ "runs",
34
+ "tmp",
35
+ "logs",
36
+ ]
37
+ typeCheckingMode = "off"
requirements.txt ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.8.0
2
+ torchaudio==2.8.0
3
+ torchvision==0.23.0
4
+ lightning==2.5.5
5
+ transformers==4.56.2
6
+ tqdm==4.67.1 # progress bar
7
+ timm==1.0.20 # torch models
8
+ matplotlib==3.10.6 # visualization
9
+ seaborn==0.13.2 # visualization
10
+ scikit-learn==1.6.1 # metrics
11
+ rich==14.1.0 # logging
12
+ wandb==0.22.0 # logging
13
+ pydantic==2.11.9 # config
14
+ # albumentations==1.4.17 # augmentations
15
+ ruff==0.13.2 # formatting
16
+ fire==0.7.0 # CLI
17
+ pytorch-metric-learning==2.8.1 # losses
18
+ peft==0.15.2 # parameter-efficient fine-tuning
19
+ ipykernel==6.30.1 # jupyter
20
+ autoroot==1.0.1 # root utils
21
+ autorootcwd==1.0.1 # root utils
22
+ xformers==0.0.32.post2 # RADIOv2.5/3
23
+ einops==0.8.1 # RADIOv2.5/3
24
+ open-clip-torch==2.32.0 # RADIOv2.5/3
25
+ grad-cam==1.5.5 # for Grad-CAM visualization
26
+ mediapipe==0.10.21 # Face landmark detection
27
+
28
+ # --- for detector.py ---
29
+ # opencv-python==4.11.0.86 # mainly only for detector.py
30
+ opencv-python-headless==4.12.0.88 # mainly for `detector.py`
31
+ onnxruntime-gpu==1.21.0 # for ONNX model inference
32
+
33
+ # --- for app/run.py ---
34
+ gradio==5.49.1
run.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import traceback
3
+
4
+ import torch
5
+ from lightning import Trainer
6
+ from lightning.pytorch import callbacks as pl_callbacks
7
+ from lightning.pytorch import loggers as pl_loggers
8
+ from rich import traceback as rich_traceback
9
+
10
+ from src import dataset as datasets
11
+ from src.config import Config
12
+ from src.model.base import BaseDeepakeDetectionModel
13
+ from src.utils import logger
14
+ from src.utils.checks import checks
15
+ from src.utils.model_checkpoint import ModelCheckpointParallel
16
+
17
+ rich_traceback.install()
18
+
19
+
20
def load_third_party_model(config: Config) -> BaseDeepakeDetectionModel:
    """Instantiate the third-party model selected by the checkpoint path.

    The first marker found as a substring of ``config.checkpoint`` decides
    which model class is imported (lazily) and constructed with ``config``.

    Raises:
        ValueError: If no known marker occurs in ``config.checkpoint``.
    """

    def _effort():
        # Download: https://drive.google.com/drive/folders/19kQwGDjF18uk78EnnypxxOLaG4Aa4v1h
        from src.model.Effort import Effort

        return Effort

    def _forada():
        # Download: https://drive.usercontent.google.com/download?id=1UlaAUTtsX87ofIibf38TtfAKIsnA7WVm&export=download&authuser=0
        from src.model.ForAda import ForAda

        return ForAda

    def _fsfm():
        from src.model.FSFM import FSFM

        return FSFM

    def _gendhf():
        # https://huggingface.co/yermandy/models
        from src.model.GenDHF import GenDHF

        return GenDHF

    # Checked in order; each import happens only for the matching marker.
    marker_to_loader = (
        ("weights/Effort", _effort),
        ("weights/ForAda", _forada),
        ("weights/FS-VFM/", _fsfm),
        ("yermandy/", _gendhf),
    )

    for marker, get_model_class in marker_to_loader:
        if marker in config.checkpoint:
            return get_model_class()(config)

    raise ValueError(f"Unknown third party model in checkpoint path: {config.checkpoint}")
46
+
47
+
48
def load_model(config: Config) -> BaseDeepakeDetectionModel:
    """Build the model for ``config``: a third-party model if one matches, else GenD.

    GenD is used when no checkpoint is given (``None`` or empty string) and
    whenever ``load_third_party_model`` raises ``ValueError``.
    """
    checkpoint_given = not (config.checkpoint is None or config.checkpoint == "")

    if checkpoint_given:
        # Try to load third party model
        try:
            return load_third_party_model(config)
        except ValueError:
            # NOTE(review): this also catches a ValueError raised while
            # constructing a third-party model, not only the "unknown model"
            # case -- confirm that silent fallback is intended.
            pass

    # No checkpoint, or not a third party model: use GenD as default
    from src.model.GenD import GenD

    return GenD(config, verbose=True)
63
+
64
+
65
def init_loggers(config: Config) -> list:
    """Create the Lightning loggers for a run.

    A CSV logger is always returned; a Weights & Biases logger is added when
    ``config.wandb`` is truthy.
    """
    csv_logger = pl_loggers.CSVLogger(config.run_dir, name=config.run_name, version="")
    if not config.wandb:
        return [csv_logger]

    wandb_logger = pl_loggers.WandbLogger(
        project="deepfake",
        name=config.run_name,
        save_dir=f"{config.run_dir}/{config.run_name}",
        tags=set(config.wandb_tags),  # duplicates collapse in the set
        group=config.wandb_group,
    )
    return [csv_logger, wandb_logger]
81
+
82
+
83
def init_callbacks(config: Config) -> list:
    """Create the Lightning callbacks for a run.

    Always includes a rich progress bar and the checkpoint callback; early
    stopping is added only when ``config.early_stopping_patience`` is positive.
    """
    progress_bar = pl_callbacks.RichProgressBar(leave=True)
    checkpointing = ModelCheckpointParallel(
        filename=config.checkpoint_name,
        monitor=config.monitor_metric,
        mode=config.monitor_metric_mode,
    )
    callbacks = [progress_bar, checkpointing]
    # Optional, currently disabled: pl_callbacks.LearningRateFinder(1e-5, 1e-2)

    if config.early_stopping_patience <= 0:
        return callbacks

    early_stopping = pl_callbacks.EarlyStopping(
        monitor=config.monitor_metric,
        patience=config.early_stopping_patience,
        mode=config.monitor_metric_mode,
        verbose=True,
    )
    callbacks.append(early_stopping)
    return callbacks
103
+
104
+
105
def finish_wandb_run(trainer, config: Config):
    """Finalize and close the trainer's first W&B logger, if wandb is enabled."""
    if not config.wandb:
        return

    wandb_logger = next(
        (candidate for candidate in trainer.loggers if isinstance(candidate, pl_loggers.WandbLogger)),
        None,
    )
    if wandb_logger is None:
        return

    wandb_logger.finalize("success")
    wandb_logger.experiment.finish()
111
+
112
+
113
def main(config: Config, train: bool):
    """Run one experiment: train then test, or test only.

    Args:
        config: Run configuration (model, data, trainer and logging options).
        train: When true, fit the model and afterwards test the checkpoint
            saved under ``<run_dir>/<run_name>/checkpoints``; when false, only
            test the model loaded from ``config.checkpoint``.

    Training errors other than ``KeyboardInterrupt`` are printed and appended
    to ``<run_dir>/<run_name>/failed.log``; testing is still attempted
    afterwards if the checkpoint file exists.
    """
    # Performs initial checks
    checks(config)

    # Set the precision for matmul operations
    torch.set_float32_matmul_precision("high")

    # Instantiates the model
    model = load_model(config)

    # Loads the checkpoint if provided
    model.load_checkpoint(config.checkpoint)

    data_module = datasets.DeepfakeDataModule(config, model.get_preprocessing())

    save_dir = f"{config.run_dir}/{config.run_name}"

    trainer = Trainer(
        devices=config.devices,
        max_epochs=config.max_epochs,
        precision=config.precision,
        # Gradient accumulation emulates batch_size with mini_batch_size steps
        accumulate_grad_batches=config.batch_size // config.mini_batch_size,
        fast_dev_run=config.fast_dev_run,
        log_every_n_steps=100,
        overfit_batches=config.overfit_batches,
        limit_train_batches=config.limit_train_batches,
        limit_val_batches=config.limit_val_batches,
        limit_test_batches=config.limit_test_batches,
        deterministic=config.deterministic,
        detect_anomaly=config.detect_anomaly,
        logger=init_loggers(config),
        callbacks=init_callbacks(config),
        default_root_dir=config.run_dir,
    )

    if train:
        try:
            trainer.fit(model, data_module)
        except KeyboardInterrupt:
            logger.print_warning("Training interrupted")
        except Exception as e:
            traceback.print_exc()  # Print complete exception traceback
            logger.print_error(f"Training failed: {e}")
            # Save the exception traceback to a file. The run directory may
            # not exist yet if training failed early; without creating it the
            # open() below would raise and hide the original error.
            os.makedirs(save_dir, exist_ok=True)
            with open(f"{save_dir}/failed.log", "a") as f:
                f.write(f"Training failed: {e}\n")
                f.write(traceback.format_exc())
        finally:
            # Runs after success, interruption and failure alike.
            logger.print_info("Training finished. Starting testing")
            ckpt_path = f"{save_dir}/checkpoints/{config.checkpoint_name}.ckpt"
            if not os.path.exists(ckpt_path):
                logger.print_error(f"Checkpoint {ckpt_path} does not exist. Cannot proceed with testing.")
            else:
                model.load_checkpoint(ckpt_path)
                trainer.test(model, data_module)

    else:
        assert config.checkpoint is not None, "Checkpoint is required for testing"
        trainer.test(model, data_module)

    # Finish wandb run
    finish_wandb_run(trainer, config)
run_exp.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import traceback
2
+ from copy import deepcopy
3
+
4
+ import fire
5
+
6
+ from run import main
7
+ from src import config as C
8
+ from src.config import Config
9
+ from src.exp import experiments
10
+ from src.utils import files, logger
11
+
12
+
13
def get_val_files():
    """Return one flat list with the validation files of all validation sets."""
    validation_sets = (
        files.DeepSpeak_v2.my_val,
        files.DeepSpeak_v1_1.my_val,
        files.CDFv2.val,
        files.FFIW.val,
    )
    return [path for subset in validation_sets for path in subset]
20
+
21
+
22
def get_test_files():
    """Return a dict mapping test-set names to their file collections.

    The fixed benchmark sets come first; the CDFv3 test sets from
    ``files.CDFv3.get_test_dict()`` are added afterwards with their paths
    redirected to the uniform 32-frame subset directory.
    """
    test_sets = {
        "FF": files.FF.test,
        "FF-DF": files.FF.DF.test,
        "FF-F2F": files.FF.F2F.test,
        "FF-FS": files.FF.FS.test,
        "FF-NT": files.FF.NT.test,
        "CDF": files.CDFv2.test,
        "FaceFusion": files.FaceFusion.CDF.test,
        "DFD": files.DFD.test,
        "DFDC": files.DFDC.test,
        "FSh": files.FSh.test,
        "UADFD": files.UADFV.test,
        "DFDM": files.DFDM.test,
        "FFIW": files.FFIW.test,
        "DeepSpeak-1.1": files.DeepSpeak_v1_1.test,
        "DeepSpeak-2.0": files.DeepSpeak_v2.test,
        "KoDF": files.KoDF.test,
        "KoDF-adv": files.KoDF.adversarial,
        "FakeAVCeleb": files.FakeAVCeleb.test,
        "FAVC-FV-RA-WL": files.FakeAVCeleb.FV_RA_WL.test,
        "FAVC-FV-FA-FS": files.FakeAVCeleb.FV_FA_FS.test,
        "FAVC-FV-FA-GAN": files.FakeAVCeleb.FV_FA_GAN.test,
        "FAVC-FV-FA-WL": files.FakeAVCeleb.FV_FA_WL.test,
        "PolyGlotFake": files.PolyGlotFake.test,
        "IDForge-v1": files.IDForge_v1.test,
    }

    # Point every CDFv3 path at the preprocessed subset directory.
    cdfv3_sets = {
        name: subset.map(lambda path: path.replace("/CDFv3/", "/CDFv3-x1.3-th0.5-all/subset/uniform-32-frames/"))
        for name, subset in files.CDFv3.get_test_dict().items()
    }

    # Same result as `test_sets | cdfv3_sets`: CDFv3 entries win on key clashes.
    test_sets.update(cdfv3_sets)
    return test_sets
52
+
53
+
54
def get_default_train_config() -> Config:
    """Build the default training configuration used as the base of every experiment."""
    cfg = Config()

    # Run bookkeeping and logging
    cfg.run_dir = "runs/rebuttal"
    cfg.wandb = True
    cfg.wandb_tags.append("rebuttal")
    cfg.throw_exception_if_run_exists = True

    # Hardware / data loading
    cfg.num_workers = 12
    cfg.devices = "auto"

    # Model
    cfg.backbone = C.Backbone.CLIP_L_14
    cfg.freeze_feature_extractor = True
    cfg.num_classes = 2

    # Optimization: batch and mini-batch sizes are equal
    cfg.batch_size = 128
    cfg.mini_batch_size = 128
    cfg.lr_scheduler = "cosine"
    cfg.lr = 3e-4
    cfg.min_lr = 1e-5
    cfg.weight_decay = 0
    cfg.max_epochs = 1 + 50
    cfg.warmup_epochs = 1

    # Data splits
    cfg.trn_files = files.FF.train
    cfg.val_files = get_val_files()
    cfg.tst_files = get_test_files()

    return cfg
82
+
83
+
84
def get_default_test_config(orig_run_name, new_run_name) -> Config:
    """Build a test configuration from the saved config of an earlier run.

    The ``hparams.yaml`` of the run found by ``files.find_run_dir`` is loaded,
    then renamed, pointed at that run's ``best_mAP.ckpt`` checkpoint, and given
    the default test settings and test files.
    """
    source_run_dir = files.find_run_dir(orig_run_name)
    checkpoint_file = "best_mAP.ckpt"  # Default checkpoint name

    # Load run specific config
    cfg = C.load_config(f"{source_run_dir}/hparams.yaml")

    # Identity of the new (test) run
    cfg.run_name = new_run_name
    cfg.run_dir = "runs/test"
    cfg.checkpoint = f"{source_run_dir}/checkpoints/{checkpoint_file}"

    # Logging
    cfg.wandb = True
    cfg.wandb_tags.extend(["test"])

    # Evaluation settings: batch and mini-batch sizes are equal
    cfg.num_workers = 12
    cfg.batch_size = 1024
    cfg.mini_batch_size = 1024
    cfg.devices = "auto"

    cfg.tst_files = get_test_files()

    return cfg
106
+
107
+
108
def get_debug_config(config: Config) -> Config:
    """Apply the debug overrides to ``config`` in place and return it.

    The overrides use a temporary run directory, a single epoch, at most 12
    batches per stage, and FF files only.
    """
    #! Debug
    debug_overrides = {
        "run_dir": "runs/tmp",
        "run_name": "tmp",
        "max_epochs": 1,
        "limit_train_batches": 12,
        "limit_val_batches": 12,
        "limit_test_batches": 12,
        "trn_files": files.FF.train,
        "val_files": files.FF.val,
        "tst_files": files.FF.val,
    }
    # Further knobs that can be enabled by hand when debugging:
    # num_workers = 0, batch_size = mini_batch_size = 2,
    # deterministic = True, detect_anomaly = True

    for field_name, value in debug_overrides.items():
        setattr(config, field_name, value)

    return config
127
+
128
+
129
# Module-local copy of the experiment registry: starts with every experiment
# defined in src.exp; additional entries can be registered here.
experiments = dict(experiments)
132
+
133
+
134
def entry(
    exp_names: str | list[str],
    debug: bool = False,
    test: bool = False,
    from_exp: str | None = None,
    **kwargs,
):
    """Run one or more named experiments.

    Args:
        exp_names: Experiment name or a sequence of names; each must be a key
            of ``experiments``. Unknown names are reported and skipped.
        debug: Apply the overrides from ``get_debug_config`` to each run.
        test: Run testing only instead of training.
        from_exp: In test mode, name of the run whose saved config and
            checkpoint are used; exactly one experiment name is allowed then.
        **kwargs: Extra config values applied last via
            ``Config.set_values_from_dict``.

    Raises:
        ValueError: In test mode with ``from_exp`` when more than one
            experiment name is given.

    Errors raised while running an experiment are printed and appended to
    ``<run_dir>/<run_name>/failed.log``; the remaining experiments still run.
    """
    import os  # function-scope import: `os` is not among this module's visible imports

    # Parse names to a list first, so that `exp_names[0]` below is a whole
    # name and never the first character of a plain string.
    if isinstance(exp_names, str):
        exp_names = [exp_names]
    else:
        exp_names = list(exp_names)

    if test:
        if from_exp is not None:
            if len(exp_names) != 1:
                raise ValueError("When running in test mode, you can provide only one experiment name.")
            config = get_default_test_config(from_exp, exp_names[0])
        else:
            logger.print_warning("Running in test mode, but 'from_exp' is not provided. Using default test config.")
            config = C.Config()
    else:
        config = get_default_train_config()

    for exp_name in exp_names:
        exp_name = exp_name.strip()

        if exp_name not in experiments:
            logger.print_error(f"Experiment '{exp_name}' is not defined in 'src/exp/__init__.py:1'")
            logger.print(f"Available experiments: {list(experiments.keys())}")
            continue

        modifiers = experiments[exp_name]
        # Each experiment starts from its own copy of the base config
        config_exp = deepcopy(config)

        config_exp.run_name = exp_name
        for modify in modifiers:
            if isinstance(modify, Config):
                # If the modifier is a Config object, change only different values
                difference = modify.model_dump(exclude_unset=True)
                # TODO: maybe set_values_from_dict(difference)?
                config_exp = Config(**config_exp.model_copy(update=difference).model_dump())
            else:
                config_exp = modify(config_exp)

        config_exp = Config(**config_exp.model_dump())  # Parse and validate config

        if debug:
            config_exp = config_exp.model_copy(update=get_debug_config(config_exp).model_dump())

        # Update config with kwargs
        config_exp.set_values_from_dict(kwargs)

        # Revalidate the config - checks if user provided valid values
        config_exp = Config(**config_exp.model_dump())

        try:
            main(config_exp, not test)

        except Exception as e:
            traceback.print_exc()  # Print complete exception traceback
            logger.print_error(f"Error occurred while running experiment '{exp_name}':")
            logger.print(e)

            save_dir = f"{config_exp.run_dir}/{config_exp.run_name}"
            # Save the exception traceback to a file. The run directory may
            # not exist if the failure happened early; without creating it the
            # open() below would raise inside this handler and abort the
            # remaining experiments.
            os.makedirs(save_dir, exist_ok=True)
            with open(f"{save_dir}/failed.log", "a") as f:
                f.write(f"\nTraining failed: {e}\n")
                f.write(traceback.format_exc())
206
+
207
+
208
+ if __name__ == "__main__":
209
+ fire.Fire(entry)