Commit 891e05c (verified) by xxwyyds · 1 Parent(s): abffab7

Upload 86 files

This view is limited to 50 files because it contains too many changes.
.gitattributes CHANGED
@@ -49,3 +49,19 @@ loupe/ffhq/ffhq-0142.png filter=lfs diff=lfs merge=lfs -text
  loupe/ffhq/ffhq-0154.png filter=lfs diff=lfs merge=lfs -text
  loupe/ffhq/ffhq-0155.png filter=lfs diff=lfs merge=lfs -text
  loupe/ffhq/ffhq-0505.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0001.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0009.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0032.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0041.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0055.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0062.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0085.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0096.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0100.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0112.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0136.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0138.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0142.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0154.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0155.png filter=lfs diff=lfs merge=lfs -text
+ ffhq/ffhq-0505.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 kamichanw
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,100 @@
- ---
- title: Loupe
- emoji: 🌖
- colorFrom: pink
- colorTo: blue
- sdk: gradio
- sdk_version: 5.36.2
- app_file: app.py
- pinned: false
- license: mit
- short_description: deepfake image detection and localization.
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Loupe
+
+ The first-place solution to IJCAI 2025 Challenge Track 1: Forgery Image Detection and Localization. The top five of the final leaderboard:
+
+ | User | Overall Score |
+ |:---:|:---:|
+ | Loupe (ours) | **0.846** |
+ | Rank 2 | 0.8161 |
+ | Rank 3 | 0.8151 |
+ | Rank 4 | 0.815 |
+ | Rank 5 | 0.815 |
+
+ ## Setup
+ ### 1. Create environment
+ ```bash
+ conda create -y -n loupe python=3.11
+ conda activate loupe
+ pip install -r requirements.txt
+ mkdir -p ./pretrained_weights/PE-Core-L14-336
+ ```
+
+ ### 2. Prepare pretrained weights
+ Download [Perception Encoder](https://github.com/facebookresearch/perception_models) following its original instructions, and place `PE-Core-L14-336.pt` in `./pretrained_weights/PE-Core-L14-336`. This can be done with `huggingface-cli`:
+ ```bash
+ export HF_ENDPOINT=https://hf-mirror.com  # optional: only needed if huggingface.co is unreachable
+ huggingface-cli download facebook/PE-Core-L14-336 PE-Core-L14-336.pt --local-dir ./pretrained_weights/PE-Core-L14-336
+ ```
+
+ ### 3. Prepare datasets
+ Download the dataset to any location of your choice, then use the [`dataset_preprocess.ipynb`](./dataset_preprocess.ipynb) notebook to preprocess it. This converts the dataset into a directly loadable `DatasetDict` and saves it in `parquet` format.
+
+ After preprocessing, you will obtain a dataset with three splits: `train`, `valid`, and `test`. Each item in these splits has the following structure:
+
+ ```python
+ {
+     "image": "path/to/image",  # loaded as an actual PIL.Image.Image object
+     "mask": "path/to/mask",    # set to None for real images without masks
+     "name": "basename_of_image.png"
+ }
+ ```
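Given this schema, the binary real/fake label falls out of the `mask` field directly. A minimal sketch of that convention (the `forgery_label` helper is illustrative, not part of the repo):

```python
# Hypothetical helper: derive a binary forgery label from a dataset item,
# relying on the convention above that real images have mask=None.
def forgery_label(item: dict) -> int:
    return 0 if item["mask"] is None else 1

real = {"image": "a.png", "mask": None, "name": "a.png"}
fake = {"image": "b.png", "mask": "b_mask.png", "name": "b.png"}
print(forgery_label(real), forgery_label(fake))  # 0 1
```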
+
+ > [!NOTE]
+ > You can also adapt Loupe to a new dataset. In that case, modify [`dataset_preprocess.ipynb`](./dataset_preprocess.ipynb) for your own use.
+
+ After preparation, the last step is to specify `path/to/your/dataset` in [dataset.yaml](configs/dataset.yaml).
+
+ ## How to train
+ Loupe employs a two- or three-stage training process. The first stage trains the classifier and can be run with:
+
+ ```bash
+ python src/train.py stage=cls
+ ```
+
+ During training, two directories will be created automatically:
+
+ * `./results/checkpoints` — contains the DeepSpeed-format checkpoint with the highest AUC on the validation set (when using the default training strategy, which can be configured in `./configs/base.yaml`).
+ * `./results/{stage.name}` — contains logs in TensorBoard format. You can monitor training progress by running:
+
+ ```bash
+ tensorboard --logdir=./results/cls
+ # or `tensorboard --logdir=./results/seg`, etc.
+ ```
+
+ After training completes, the best checkpoint is saved in the directory `./checkpoints/cls-auc=xxx.ckpt`. This directory contains several config files as well as `model.safetensors`, which stores the best weights in the safetensors format.
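A `model.safetensors` file follows the standard safetensors container layout: an 8-byte little-endian header length, then a JSON header describing each tensor. That means the tensor inventory of a checkpoint can be inspected with the stdlib alone; a sketch on a tiny in-memory file (the reader function is illustrative only; for real use, prefer the `safetensors` library):

```python
import io
import json
import struct

# Read the safetensors JSON header: 8-byte little-endian length N,
# followed by N bytes of JSON metadata, then the raw tensor data.
def read_safetensors_header(buf):
    (n,) = struct.unpack("<Q", buf.read(8))
    return json.loads(buf.read(n))

# Build a tiny in-memory "file" holding one fp32 tensor named "w".
header = {"w": {"dtype": "F32", "shape": [1], "data_offsets": [0, 4]}}
header_bytes = json.dumps(header).encode()
blob = struct.pack("<Q", len(header_bytes)) + header_bytes + struct.pack("<f", 1.0)

meta = read_safetensors_header(io.BytesIO(blob))
print(meta)  # {'w': {'dtype': 'F32', 'shape': [1], 'data_offsets': [0, 4]}}
```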
+
+ The second stage trains the segmentation head. To run it, simply replace `stage=cls` with `stage=seg` in the stage 1 command.
+
+ The third stage is optional: it jointly trains the backbone, classifier head, and segmentation head. By default, a portion of the validation set is used as training data, while the remainder is reserved for validation. The validation set serves as extra training data because the competition's test set is slightly out-of-distribution (OOD); continuing to train on the original training set resulted in overfitting. However, if you prefer to train the whole network from scratch directly on the training set, you can do so with:
+ ```bash
+ python src/train.py stage=cls_seg \
+     ckpt.checkpoint_paths=[] \
+     model.freeze_backbone=true \
+     stage.train_on_trainset=true
+ ```
+
+ All training configurations can be adjusted in the `configs/` directory; detailed comments are provided there to make configuration quick and clear.
+
+ ## How to test or predict
+ By default, testing is performed on the full validation set. This makes it unsuitable for evaluating a model trained in the third stage, since that stage trains on part of the validation set itself (see above). Alternatively, if you make a slight modification to the [data loading process](./src/data_module.py) so that Loupe trains on the training set instead, this limitation disappears.
+
+ To evaluate a trained model, run:
+ ```bash
+ python src/infer.py stage=test ckpt.checkpoint_paths=["checkpoints/cls/model.safetensors","checkpoints/seg/model.safetensors"]
+ ```
+
+ The `ckpt.checkpoint_paths` configuration is defined under `configs/ckpt`. It is a list specifying the checkpoints to load sequentially during execution.
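Presumably, sequential loading means later checkpoints in `ckpt.checkpoint_paths` override any overlapping keys from earlier ones; that assumed semantics can be sketched with plain dicts standing in for state dicts (not the repo's actual loader):

```python
# Assumption: checkpoints are applied in list order, later entries
# overriding earlier ones that share a key. Plain dicts stand in for
# the real state dicts here.
cls_ckpt = {"backbone.w": 1, "cls_head.w": 2}
seg_ckpt = {"backbone.w": 10, "seg_head.w": 3}

merged = {}
for ckpt in [cls_ckpt, seg_ckpt]:  # order follows ckpt.checkpoint_paths
    merged.update(ckpt)

print(merged)  # {'backbone.w': 10, 'cls_head.w': 2, 'seg_head.w': 3}
```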
+
+ The prediction step is essentially the same as the test step; you only need one additional parameter specifying the output directory for predictions. For example:
+
+ ```bash
+ python src/infer.py stage=test \
+     ckpt.checkpoint_paths=["checkpoints/cls/model.safetensors","checkpoints/seg/model.safetensors"] \
+     stage.pred_output_dir=./pred_outputs
+ ```
+
+ The classification predictions will be saved in `./pred_outputs/predictions.txt`, and the mask outputs will be stored in `./pred_outputs/masks`. For more details on the available parameters, see `configs/stage/test.yaml`.
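If you need to post-process the classification output, a tiny parser helps; this assumes a hypothetical one-`name probability`-pair-per-line layout for `predictions.txt`, so check the actual format written by `src/infer.py` before relying on it:

```python
# Hypothetical parser for predictions.txt, assuming one
# "<image_name> <forgery_probability>" pair per line. Adjust the split
# logic to match whatever src/infer.py actually writes.
def parse_predictions(text):
    preds = {}
    for line in text.strip().splitlines():
        name, prob = line.split()
        preds[name] = float(prob)
    return preds

sample = "img_0001.png 0.93\nimg_0002.png 0.04"
print(parse_predictions(sample))  # {'img_0001.png': 0.93, 'img_0002.png': 0.04}
```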
+
+ ## Code reading guide
+ Nobody seems to care about this work yet, so this section is left blank for now.
app.py ADDED
@@ -0,0 +1,341 @@
+ import sys
+ sys.path.insert(0, "./src")  # make the src/ modules importable before importing from them
+
+ import gradio as gr
+ import os
+ import tempfile
+ import numpy as np
+ from PIL import Image
+ from src.predict import process_single_image
+
+ def get_example_images(folder_path="/home/xxw/Loupe/ffhq"):
+     return [os.path.join(folder_path, f) for f in os.listdir(folder_path)
+             if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+
+ def safe_extract_prob(cls_probs):
+     """Safely extract a probability value from cls_probs."""
+     try:
+         if cls_probs is None:
+             return 0.0
+         elif isinstance(cls_probs, (list, np.ndarray)) and len(cls_probs) > 0:
+             return float(cls_probs[0])
+         elif hasattr(cls_probs, '__getitem__'):
+             return float(cls_probs[0])
+         else:
+             return float(cls_probs)
+     except (TypeError, IndexError, ValueError) as e:
+         print(f"Error extracting probability: {e}")
+         return 0.0
+
+ with gr.Blocks(title="Loupe Image Forgery Detection System", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     # Loupe 🕵️‍♂️ Image Forgery Detection System
+     ### Upload an image or pick an example; the system will detect forged regions in it
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             with gr.Tab("Upload Image"):
+                 image_input = gr.Image(type="pil", label="Original image")
+                 upload_button = gr.Button("Detect Forgery", variant="primary")
+
+             with gr.Tab("Choose Example"):
+                 example_images = get_example_images()
+                 example_dropdown = gr.Dropdown(
+                     choices=example_images,
+                     label="Select an example image",
+                     value=example_images[0] if example_images else None
+                 )
+                 example_button = gr.Button("Detect Example", variant="secondary")
+
+             with gr.Accordion("Advanced Options", open=False):
+                 threshold = gr.Slider(0, 1, value=0.5, label="Detection threshold")
+
+         with gr.Column(scale=1):
+             gr.Markdown("### Detection Results")
+             with gr.Tabs():
+                 with gr.Tab("Processed Image"):
+                     output_image = gr.Image(label="Forgery detection result", interactive=False)
+
+                 with gr.Tab("Comparison"):
+                     with gr.Row():
+                         original_display = gr.Image(label="Original image", interactive=False)
+                         processed_display = gr.Image(label="Processed image", interactive=False)
+
+             with gr.Group():
+                 with gr.Row():
+                     fake_prob = gr.Number(label="Forgery probability", precision=4)
+                     # Simplified to just show the probability as text
+                     result_text = gr.Textbox(label="Verdict", interactive=False)
+
+             save_button = gr.Button("Save Result", variant="secondary")
+
+     gr.Markdown("""
+     ---
+     ### About
+     - **Technology**: Forgery Image Detection and Localization.
+     - **Version**: 1.0.0
+     """)
+
+     def process_image(image, threshold_value):
+         # note: threshold_value is currently unused by the inference path
+         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
+             image_path = tmp_file.name
+             image.save(image_path)
+
+         try:
+             processed_img, cls_probs = process_single_image(image_path)
+
+             # Safely extract the probability value
+             prob = safe_extract_prob(cls_probs)
+             print(f"Classification probability: {prob:.4f}" if prob is not None else "No cls output")
+
+             return {
+                 output_image: processed_img,
+                 original_display: image,
+                 processed_display: processed_img,
+                 fake_prob: prob,
+                 result_text: f"Forgery probability: {prob:.4f}"
+             }
+         except Exception as e:
+             print(f"Error in processing: {e}")
+             return {
+                 output_image: None,
+                 original_display: image,
+                 processed_display: None,
+                 fake_prob: 0.0,
+                 result_text: "Processing error"
+             }
+         finally:
+             if os.path.exists(image_path):
+                 os.unlink(image_path)
+
+     def process_example(image_path, threshold_value):
+         try:
+             processed_img, cls_probs = process_single_image(image_path)
+
+             # Safely extract the probability value
+             prob = safe_extract_prob(cls_probs)
+             print(f"Classification probability: {prob:.4f}" if prob is not None else "No cls output")
+
+             original_img = Image.open(image_path)
+
+             return {
+                 image_input: original_img,
+                 output_image: processed_img,
+                 original_display: original_img,
+                 processed_display: processed_img,
+                 fake_prob: prob,
+                 result_text: f"Forgery probability: {prob:.4f}",
+                 threshold: threshold_value
+             }
+         except Exception as e:
+             print(f"Error in processing example: {e}")
+             return {
+                 image_input: None,
+                 output_image: None,
+                 original_display: None,
+                 processed_display: None,
+                 fake_prob: 0.0,
+                 result_text: "Processing error",
+                 threshold: threshold_value
+             }
+
+     upload_button.click(
+         process_image,
+         [image_input, threshold],
+         [output_image, original_display, processed_display, fake_prob, result_text]
+     )
+
+     example_button.click(
+         process_example,
+         [example_dropdown, threshold],
+         [image_input, output_image, original_display, processed_display, fake_prob, result_text, threshold]
+     )
+
+     save_button.click(
+         lambda img: img.save("result.jpg") if img else None,
+         [output_image],
+         None
+     )
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
app3.py ADDED
@@ -0,0 +1,591 @@
+ import sys
+ sys.path.insert(0, "./src")  # make the src/ modules importable before importing from them
+
+ import gradio as gr
+ import os
+ import tempfile
+ import numpy as np
+ from PIL import Image
+ from src.predict import process_single_image
+
+ # Custom theme: colorful and modern
+ custom_theme = gr.themes.Default(
+     primary_hue="purple",
+     secondary_hue="pink",
+     neutral_hue="slate",
+     font=[gr.themes.GoogleFont("Poppins"), gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
+ ).set(
+     button_primary_background_fill="linear-gradient(45deg, #667eea 0%, #764ba2 100%)",
+     button_primary_background_fill_hover="linear-gradient(45deg, #764ba2 0%, #667eea 100%)",
+     button_primary_text_color="white",
+     button_secondary_background_fill="linear-gradient(45deg, #f093fb 0%, #f5576c 100%)",
+     button_secondary_background_fill_hover="linear-gradient(45deg, #f5576c 0%, #f093fb 100%)",
+     button_secondary_text_color="white"
+ )
+
+ def get_example_images(folder_path="ffhq"):
+     """Get the list of example images."""
+     return sorted([os.path.join(folder_path, f) for f in os.listdir(folder_path)
+                    if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
+
+ def safe_extract_prob(cls_probs):
+     """Safely extract a probability value from cls_probs."""
+     try:
+         if cls_probs is None:
+             return 0.0
+         elif isinstance(cls_probs, (list, np.ndarray)) and len(cls_probs) > 0:
+             return float(cls_probs[0])
+         elif hasattr(cls_probs, '__getitem__'):
+             return float(cls_probs[0])
+         else:
+             return float(cls_probs)
+     except (TypeError, IndexError, ValueError) as e:
+         print(f"Error extracting probability: {e}")
+         return 0.0
+
+ # Create the main interface
+ with gr.Blocks(
+     title="Loupe - AI Image Forgery Detection System",
+     theme=custom_theme,
+     css="""
+     /* Global styles */
+     body {
+         background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
+         background-size: 400% 400%;
+         animation: gradientBG 15s ease infinite;
+         min-height: 100vh;
+     }
+
+     @keyframes gradientBG {
+         0% { background-position: 0% 50%; }
+         50% { background-position: 100% 50%; }
+         100% { background-position: 0% 50%; }
+     }
+
+     /* Main container */
+     .gradio-container {
+         background: rgba(255, 255, 255, 0.95);
+         backdrop-filter: blur(10px);
+         border-radius: 20px;
+         box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
+         margin: 20px;
+         padding: 20px;
+     }
+
+     /* Title */
+     .title-box {
+         background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+         padding: 30px;
+         border-radius: 15px;
+         margin-bottom: 30px;
+         box-shadow: 0 15px 35px rgba(102, 126, 234, 0.3);
+         position: relative;
+         overflow: hidden;
+     }
+
+     .title-box::before {
+         content: '';
+         position: absolute;
+         top: -50%;
+         left: -50%;
+         width: 200%;
+         height: 200%;
+         background: linear-gradient(45deg, transparent, rgba(255, 255, 255, 0.1), transparent);
+         animation: shine 3s infinite;
+     }
+
+     @keyframes shine {
+         0% { transform: translateX(-100%) translateY(-100%) rotate(45deg); }
+         100% { transform: translateX(100%) translateY(100%) rotate(45deg); }
+     }
+
+     .title-text {
+         font-weight: 700;
+         font-size: 32px;
+         color: white;
+         margin-bottom: 8px;
+         text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.3);
+         background: linear-gradient(45deg, #fff, #f0f8ff);
+         -webkit-background-clip: text;
+         -webkit-text-fill-color: transparent;
+         background-clip: text;
+     }
+
+     .subtitle-text {
+         color: rgba(255, 255, 255, 0.9);
+         font-size: 18px;
+         font-weight: 300;
+         text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.2);
+     }
+
+     /* Input and result boxes */
+     .input-box, .result-box {
+         background: linear-gradient(145deg, rgba(255, 255, 255, 0.9), rgba(248, 250, 252, 0.9));
+         padding: 25px;
+         border-radius: 15px;
+         margin-bottom: 20px;
+         border: 1px solid rgba(255, 255, 255, 0.3);
+         box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1);
+         backdrop-filter: blur(10px);
+         transition: all 0.3s ease;
+     }
+
+     .input-box:hover, .result-box:hover {
+         transform: translateY(-5px);
+         box-shadow: 0 20px 40px rgba(0, 0, 0, 0.15);
+     }
+
+     .input-title, .result-title {
+         font-weight: 700;
+         background: linear-gradient(45deg, #667eea, #764ba2);
+         -webkit-background-clip: text;
+         -webkit-text-fill-color: transparent;
+         background-clip: text;
+         margin-bottom: 15px;
+         font-size: 20px;
+         text-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
+     }
+
+     /* Buttons */
+     .btn-primary {
+         background: linear-gradient(45deg, #667eea 0%, #764ba2 100%);
+         border: none;
+         border-radius: 25px;
+         padding: 12px 30px;
+         font-weight: 600;
+         text-transform: uppercase;
+         letter-spacing: 1px;
+         box-shadow: 0 10px 20px rgba(102, 126, 234, 0.3);
+         transition: all 0.3s ease;
+     }
+
+     .btn-primary:hover {
+         transform: translateY(-3px);
+         box-shadow: 0 15px 30px rgba(102, 126, 234, 0.4);
+         background: linear-gradient(45deg, #764ba2 0%, #667eea 100%);
+     }
+
+     .btn-secondary {
+         background: linear-gradient(45deg, #f093fb 0%, #f5576c 100%);
+         border: none;
+         border-radius: 25px;
+         padding: 10px 25px;
+         font-weight: 600;
+         box-shadow: 0 8px 16px rgba(240, 147, 251, 0.3);
+         transition: all 0.3s ease;
+     }
+
+     .btn-secondary:hover {
+         transform: translateY(-2px);
+         box-shadow: 0 12px 24px rgba(240, 147, 251, 0.4);
+     }
+
+     /* Image upload area */
+     #upload_image {
+         min-height: 350px;
+         border: 3px dashed rgba(102, 126, 234, 0.3);
+         border-radius: 15px;
+         background: linear-gradient(45deg, rgba(102, 126, 234, 0.05), rgba(118, 75, 162, 0.05));
+         transition: all 0.3s ease;
+     }
+
+     #upload_image:hover {
+         border-color: rgba(102, 126, 234, 0.6);
+         background: linear-gradient(45deg, rgba(102, 126, 234, 0.1), rgba(118, 75, 162, 0.1));
+         transform: scale(1.02);
+     }
+
+     /* Probability display */
+     #probability input {
+         font-weight: bold;
+         background: linear-gradient(45deg, #667eea, #764ba2);
+         -webkit-background-clip: text;
+         -webkit-text-fill-color: transparent;
+         background-clip: text;
+         font-size: 1.2em;
+     }
+
+     #result_text input {
+         font-size: 1.1em;
+         font-weight: 600;
+         background: linear-gradient(45deg, rgba(102, 126, 234, 0.1), rgba(118, 75, 162, 0.1));
+         border-radius: 10px;
+         border: 2px solid rgba(102, 126, 234, 0.2);
+     }
+
+     /* Gallery */
+     .gallery-item {
+         border-radius: 12px !important;
+         transition: all 0.3s ease;
+         box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1);
+     }
+
+     .gallery-item:hover {
+         transform: scale(1.05);
+         box-shadow: 0 10px 25px rgba(0, 0, 0, 0.2);
+     }
+
+     /* Example button */
+     .example-btn {
+         margin-top: 15px;
+         width: 100%;
+         background: linear-gradient(45deg, #23a6d5 0%, #23d5ab 100%);
+         border-radius: 20px;
+         font-weight: 600;
+         box-shadow: 0 8px 16px rgba(35, 166, 213, 0.3);
+         transition: all 0.3s ease;
+     }
+
+     .example-btn:hover {
+         transform: translateY(-2px);
+         box-shadow: 0 12px 24px rgba(35, 166, 213, 0.4);
+     }
+
+     /* Tabs */
+     .tab-nav button {
+         border-radius: 15px 15px 0 0;
+         background: linear-gradient(45deg, rgba(102, 126, 234, 0.8), rgba(118, 75, 162, 0.8));
+         color: white;
+         font-weight: 600;
+         transition: all 0.3s ease;
+     }
+
+     .tab-nav button:hover {
+         background: linear-gradient(45deg, rgba(118, 75, 162, 0.9), rgba(102, 126, 234, 0.9));
+         transform: translateY(-2px);
+     }
+
+     /* Slider */
+     .gr-slider input[type="range"] {
+         background: linear-gradient(45deg, #667eea, #764ba2);
+         border-radius: 10px;
+     }
+
+     /* Accordion */
+     .gr-accordion {
+         background: linear-gradient(145deg, rgba(255, 255, 255, 0.8), rgba(248, 250, 252, 0.8));
+         border-radius: 15px;
+         border: 1px solid rgba(102, 126, 234, 0.2);
+         box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1);
+     }
+
+     /* Colorful loading animation */
+     @keyframes rainbow {
+         0% { background-position: 0% 50%; }
+         50% { background-position: 100% 50%; }
+         100% { background-position: 0% 50%; }
+     }
+
+     .processing {
+         background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
+         background-size: 400% 400%;
+         animation: rainbow 2s ease infinite;
+     }
+
+     /* Responsive design */
+     @media (max-width: 768px) {
+         .title-text { font-size: 24px; }
+         .subtitle-text { font-size: 16px; }
+         .input-box, .result-box { padding: 20px; }
+     }
+     """
+ ) as demo:
292
+
293
+ # 标题部分 - 炫彩渐变设计
294
+ with gr.Column(elem_classes="title-box"):
295
+ gr.Markdown("""
296
+ <div class="title-text">🔍 Loupe 图像伪造检测系统</div>
297
+ <div class="subtitle-text">✨ 基于深度学习的图像伪造检测与定位技术</div>
298
+ """)
299
+
300
+ # 添加装饰性分割线
301
+ gr.HTML("""
302
+ <div style="height: 4px; background: linear-gradient(90deg, #667eea, #764ba2, #f093fb, #f5576c, #23a6d5, #23d5ab);
303
+ border-radius: 2px; margin: 20px 0; box-shadow: 0 2px 10px rgba(0,0,0,0.2);"></div>
304
+ """)
305
+
306
+ # 主界面组件
307
+ with gr.Row(equal_height=True):
308
+ with gr.Column(scale=1, min_width=300):
309
+ # 输入图像区域 - 炫彩设计
310
+ with gr.Column(elem_classes="input-box"):
311
+ gr.Markdown("""<div class="input-title">🎨 输入图像</div>""")
312
+ with gr.Tabs():
313
+ with gr.Tab("📤 上传图片", id="upload_tab"):
314
+ image_input = gr.Image(type="pil", label="", elem_id="upload_image")
315
+ upload_button = gr.Button("🚀 开始检测", variant="primary", size="lg", elem_classes="btn-primary")
316
+
317
+ with gr.Tab("🖼️ 示例图片", id="example_tab"):
318
+ example_images = get_example_images()
319
+ example_gallery = gr.Gallery(
320
+ value=example_images,
321
+ label="",
322
+ columns=4,
323
+ rows=None,
324
+ height="auto",
325
+ object_fit="contain",
326
+ allow_preview=True,
327
+ selected_index=None
328
+ )
329
+ # 添加炫彩检测按钮
330
+ example_button = gr.Button(
331
+ "✨ 检测选中的示例图片",
332
+ variant="primary",
333
+ elem_classes="example-btn"
334
+ )
335
+ # 隐藏组件用于存储选中索引
336
+ selected_index = gr.Number(visible=False)
337
+
338
+ with gr.Accordion("⚙️ 高级设置", open=False):
339
+ threshold = gr.Slider(0, 1, value=0.5, step=0.01, label="🎯 检测敏感度")
340
+ gr.HTML("""
341
+ <div style="background: linear-gradient(45deg, rgba(102,126,234,0.1), rgba(118,75,162,0.1));
342
+ padding: 10px; border-radius: 8px; margin-top: 10px;">
343
+ <small style="color: #667eea; font-weight: 500;">💡 调整数值可改变检测的严格程度</small>
344
+ </div>
345
+ """)
346
+
347
+ with gr.Column(scale=2, min_width=500):
348
+ # Detection result area - colorful design
349
+ with gr.Column(elem_classes="result-box"):
350
+ gr.Markdown("""<div class="result-title">🎯 检测结果</div>""")
351
+ with gr.Tabs():
352
+ with gr.Tab("🔍 检测效果", id="result_tab"):
353
+ output_image = gr.Image(label="伪造区域标记", interactive=False)
354
+
355
+ with gr.Tab("⚖️ 对比视图", id="compare_tab"):
356
+ with gr.Row():
357
+ original_display = gr.Image(label="原始图像", interactive=False)
358
+ processed_display = gr.Image(label="检测结果", interactive=False)
359
+
360
+ with gr.Group():
361
+ with gr.Row():
362
+ fake_prob = gr.Number(label="🎲 伪造概率", precision=2, elem_id="probability")
363
+ result_text = gr.Textbox(label="📝 检测结论", interactive=False, elem_id="result_text")
364
+
365
+ with gr.Row():
366
+ save_button = gr.Button("💾 保存结果", variant="secondary", elem_classes="btn-secondary")
367
+ clear_button = gr.Button("🧹 清除", variant="secondary", elem_classes="btn-secondary")
368
+
369
+ # About section - colorful design
370
+ with gr.Accordion("🌟 关于系统", open=False):
371
+ gr.HTML("""
372
+ <div style="background: linear-gradient(135deg, rgba(102,126,234,0.1), rgba(118,75,162,0.1), rgba(240,147,251,0.1));
373
+ padding: 20px; border-radius: 15px; border: 1px solid rgba(102,126,234,0.2);">
374
+ <h3 style="background: linear-gradient(45deg, #667eea, #764ba2); -webkit-background-clip: text;
375
+ -webkit-text-fill-color: transparent; margin-bottom: 15px;">
376
+ ✨ Loupe 伪造图像检测系统
377
+ </h3>
378
+ <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px;">
379
+ <div style="background: rgba(102,126,234,0.1); padding: 15px; border-radius: 10px;">
380
+ <strong style="color: #667eea;">🚀 技术</strong><br>
381
+ 基于深度学习的图像伪造检测与定位
382
+ </div>
383
+ <div style="background: rgba(118,75,162,0.1); padding: 15px; border-radius: 10px;">
384
+ <strong style="color: #764ba2;">⭐ 特点</strong><br>
385
+ 高精度、实时处理、可解释性强
386
+ </div>
387
+ <div style="background: rgba(240,147,251,0.1); padding: 15px; border-radius: 10px;">
388
+ <strong style="color: #f093fb;">📱 版本</strong><br>
389
+ v2.0.0 炫彩版
390
+ </div>
391
+ <div style="background: rgba(245,87,108,0.1); padding: 15px; border-radius: 10px;">
392
+ <strong style="color: #f5576c;">👥 开发者</strong><br>
393
+ EVOL Lab (jyc, xxw)
394
+ </div>
395
+ </div>
396
+ <div style="margin-top: 20px; padding: 15px; background: linear-gradient(45deg, rgba(35,166,213,0.1), rgba(35,213,171,0.1));
397
+ border-radius: 10px; border-left: 4px solid #23a6d5;">
398
+ <strong style="color: #23a6d5;">💡 系统介绍</strong><br>
399
+ 本系统可检测多种图像篡改痕迹,包括复制-移动、拼接、擦除等操作。采用最新的深度学习算法,提供高精度的检测结果和直观的可视化分析。
400
+ </div>
401
+ </div>
402
+ """)
403
+
404
+ # Footer - colorful design
405
+ gr.HTML("""
406
+ <div style="margin-top: 40px; padding: 20px; text-align: center;
407
+ background: linear-gradient(135deg, rgba(102,126,234,0.1), rgba(118,75,162,0.1));
408
+ border-radius: 15px; border-top: 2px solid rgba(102,126,234,0.3);">
409
+ <div style="background: linear-gradient(45deg, #667eea, #764ba2); -webkit-background-clip: text;
410
+ -webkit-text-fill-color: transparent; font-weight: 600; margin-bottom: 10px;">
411
+ ✨ 感谢使用 Loupe 图像伪造检测系统 ✨
412
+ </div>
413
+ <div style="color: #64748b; font-size: 14px;">
414
+ © 2025 EVOL Lab | 让AI守护图像真实性 🛡️
415
+ </div>
416
+ <div style="margin-top: 10px;">
417
+ <span style="background: linear-gradient(45deg, #f093fb, #f5576c); -webkit-background-clip: text;
418
+ -webkit-text-fill-color: transparent; font-weight: 500;">
419
+ 🌟 科技点亮未来,智能守护真实 🌟
420
+ </span>
421
+ </div>
422
+ </div>
423
+ """)
424
+
425
+ def process_image(image, threshold_value):
426
+ """Process an uploaded image."""
427
+ if image is None:
428
+ return {
429
+ output_image: None,
430
+ fake_prob: 0.0,
431
+ result_text: "❌ 请上传有效图像"
432
+ }
433
+
434
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
435
+ image_path = tmp_file.name
436
+ image.save(image_path)
437
+
438
+ try:
439
+ processed_img, cls_probs = process_single_image(image_path)
440
+ prob = safe_extract_prob(cls_probs)
441
+
442
+ # Generate a verdict based on the probability
443
+ if prob > threshold_value + 0.2:
444
+ conclusion = "🚨 高度疑似伪造"
445
+ emoji = "🔴"
446
+ elif prob > threshold_value:
447
+ conclusion = "⚠️ 可能伪造"
448
+ emoji = "🟡"
449
+ else:
450
+ conclusion = "✅ 未检测到伪造"
451
+ emoji = "🟢"
452
+
453
+ return {
454
+ output_image: processed_img,
455
+ original_display: image,
456
+ processed_display: processed_img,
457
+ fake_prob: prob,
458
+ result_text: f"{emoji} {conclusion} (概率: {prob:.2f})"
459
+ }
460
+ except Exception as e:
461
+ print(f"Error in processing: {e}")
462
+ return {
463
+ output_image: None,
464
+ original_display: image,
465
+ processed_display: None,
466
+ fake_prob: 0.0,
467
+ result_text: f"❌ 处理错误: {str(e)}"
468
+ }
469
+ finally:
470
+ if os.path.exists(image_path):
471
+ os.unlink(image_path)
472
+
473
+ def process_example(example_data, selected_idx, threshold_value):
474
+ """Process a selected example image."""
475
+ if not example_data or selected_idx is None:
476
+ return {
477
+ image_input: None,
478
+ output_image: None,
479
+ original_display: None,
480
+ processed_display: None,
481
+ fake_prob: 0.0,
482
+ result_text: "⚠️ 请先选择示例图片",
483
+ threshold: threshold_value
484
+ }
485
+
486
+ try:
487
+ selected_idx = int(selected_idx)
488
+ image_info = example_data[selected_idx]
489
+
490
+ # Handle different gallery data formats
491
+ if isinstance(image_info, (tuple, list)):
492
+ image_path = image_info[0] # (path, caption) format
493
+ elif isinstance(image_info, dict):
494
+ image_path = image_info.get("name", image_info.get("path"))
495
+ else:
496
+ image_path = image_info
497
+
498
+ print(f"Processing selected image (index {selected_idx}): {image_path}") # debug log
499
+
500
+ # Run detection on the image
501
+ processed_img, cls_probs = process_single_image(image_path)
502
+ prob = safe_extract_prob(cls_probs)
503
+ original_img = Image.open(image_path)
504
+
505
+ # Generate a verdict based on the probability
506
+ if prob > threshold_value + 0.2:
507
+ conclusion = "🚨 高度疑似伪造"
508
+ emoji = "🔴"
509
+ elif prob > threshold_value:
510
+ conclusion = "⚠️ 可能伪造"
511
+ emoji = "🟡"
512
+ else:
513
+ conclusion = "✅ 未检测到伪造"
514
+ emoji = "🟢"
515
+
516
+ return {
517
+ image_input: original_img,
518
+ output_image: processed_img,
519
+ original_display: original_img,
520
+ processed_display: processed_img,
521
+ fake_prob: prob,
522
+ result_text: f"{emoji} {conclusion} (概率: {prob:.2f})",
523
+ threshold: threshold_value
524
+ }
525
+ except Exception as e:
526
+ print(f"Error in processing example: {e}")
527
+ return {
528
+ image_input: None,
529
+ output_image: None,
530
+ original_display: None,
531
+ processed_display: None,
532
+ fake_prob: 0.0,
533
+ result_text: f"❌ 示例处理错误: {str(e)}",
534
+ threshold: threshold_value
535
+ }
536
+
537
+ def clear_all():
538
+ """Clear all inputs and outputs."""
539
+ return {
540
+ image_input: None,
541
+ output_image: None,
542
+ original_display: None,
543
+ processed_display: None,
544
+ fake_prob: 0.0,
545
+ result_text: "🧹 已清除所有数据"
546
+ }
547
+
548
+ def update_selected_index(evt: gr.SelectData):
549
+ """Update the index of the selected gallery image."""
550
+ return evt.index
551
+
552
+ # Event wiring
553
+ upload_button.click(
554
+ process_image,
555
+ [image_input, threshold],
556
+ [output_image, original_display, processed_display, fake_prob, result_text]
557
+ )
558
+
559
+ # Gallery selection event
560
+ example_gallery.select(
561
+ update_selected_index,
562
+ None,
563
+ selected_index
564
+ )
565
+
566
+ # Example detection button click event
567
+ example_button.click(
568
+ process_example,
569
+ [example_gallery, selected_index, threshold],
570
+ [image_input, output_image, original_display, processed_display, fake_prob, result_text, threshold]
571
+ )
572
+
573
+ save_button.click(
574
+ lambda img: (img.save("result.jpg"), "💾 结果已保存为 result.jpg")[1] if img else "❌ 没有图像可保存",
575
+ [output_image],
576
+ None,
577
+ api_name="save_result"
578
+ )
579
+
580
+ clear_button.click(
581
+ clear_all,
582
+ [],
583
+ [image_input, output_image, original_display, processed_display, fake_prob, result_text]
584
+ )
585
+
586
+ if __name__ == "__main__":
587
+ demo.launch(
588
+ server_name="0.0.0.0",
589
+ server_port=7864,
590
+ favicon_path="./favicon.ico" if os.path.exists("./favicon.ico") else None
591
+ )
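The verdict bands used in `process_image` and `process_example` above are duplicated; the logic can be sketched as one helper (the name `classify_verdict` is illustrative, not from the repo):

```python
def classify_verdict(prob: float, threshold: float) -> str:
    """Map a forgery probability to a verdict using the app's threshold bands."""
    if prob > threshold + 0.2:
        return "highly suspicious"   # red band in the UI
    if prob > threshold:
        return "possibly forged"     # yellow band
    return "no forgery detected"     # green band
```

With the default sensitivity of 0.5, probabilities above 0.7 fall in the highest band.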
configs/base.yaml ADDED
@@ -0,0 +1,6 @@
1
+ seed: 42
2
+ trainer:
3
+ precision: 16-mixed
4
+ strategy: deepspeed_stage_2_offload # options: "deepspeed_stage_2_offload" / "ddp"
5
+ fast_dev_run: false # set to true for debugging
6
+ hydra.output_subdir: null
configs/ckpt/base.yaml ADDED
@@ -0,0 +1,12 @@
1
+ # list of checkpoints to load, the latter will override the former
2
+ # must end with .safetensors or .pt/.pth
3
+ checkpoint_paths: []
4
+
5
+ saver:
6
+ _target_: "pytorch_lightning.callbacks.ModelCheckpoint"
7
+ dirpath: null # it will be set during runtime
8
+ filename: "loupe-{val_loss:.4f}"
9
+ monitor: val_loss
10
+ mode: min
11
+ save_top_k: 1
12
+ save_last: False
configs/ckpt/cls.yaml ADDED
@@ -0,0 +1,8 @@
1
+ defaults:
2
+ - base
3
+ - _self_
4
+
5
+ saver:
6
+ filename: "cls-{auc:.4f}"
7
+ monitor: auc
8
+ mode: max
configs/ckpt/cls_seg.yaml ADDED
@@ -0,0 +1,9 @@
1
+ defaults:
2
+ - base
3
+ - _self_
4
+
5
+ checkpoint_paths: ["checkpoints/cls/model.safetensors", "checkpoints/seg-f1=0.8804-iou=0.8866.ckpt/model.safetensors"]
6
+ saver:
7
+ filename: "cls_seg-{auc:.4f}-{f1:.4f}-{iou:.4f}"
8
+ monitor: overall # overall is mean of auc, f1, and iou, based on challenge requirements
9
+ mode: max
configs/ckpt/seg.yaml ADDED
@@ -0,0 +1,8 @@
1
+ defaults:
2
+ - base
3
+ - _self_
4
+
5
+ saver:
6
+ filename: "seg-{f1:.4f}-{iou:.4f}"
7
+ monitor: iou
8
+ mode: max
configs/ckpt/test.yaml ADDED
@@ -0,0 +1,7 @@
1
+ defaults:
2
+ - cls_seg
3
+ - _self_
4
+
5
+ # list of checkpoints to load, the latter will override the former
6
+ # must end with .safetensors or .pt/.pth
7
+ checkpoint_paths: ["checkpoints/cls/model.safetensors", "checkpoints/seg/model.safetensors"]
configs/dataset/base.yaml ADDED
@@ -0,0 +1,4 @@
1
+ data_dir: null
2
+
3
+ num_workers: 8
4
+ valid_size: 0.1 # float between 0 and 1 or int representing the number of samples
configs/dataset/custom.yaml ADDED
@@ -0,0 +1,5 @@
1
+ defaults:
2
+ - base
3
+ - _self_
4
+
5
+ data_dir: "/gemini/space/jyc/casia2/"
configs/dataset/ddl.yaml ADDED
@@ -0,0 +1,6 @@
1
+ defaults:
2
+ - base
3
+ - _self_
4
+
5
+ valid_size: 5000
6
+ data_dir: "/gemini/space/jyc/track1/"
configs/hparams/base.yaml ADDED
@@ -0,0 +1,6 @@
1
+ weight_decay: 1e-3
2
+ warmup_step: 0.1
3
+ decay_step: 0.1
4
+ grad_clip_val: 1.0
5
+ scheduler: "wsd" # Options: "cosine" / "wsd"
6
+ epoch: 1
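`wsd` stands for a warmup-stable-decay schedule. A minimal sketch, assuming `warmup_step` and `decay_step` are fractions of total training steps (which the 0.1 values above suggest), not the repo's exact scheduler:

```python
def wsd_lr(step: int, total_steps: int, base_lr: float,
           warmup_frac: float = 0.1, decay_frac: float = 0.1) -> float:
    """Warmup-Stable-Decay: linear warmup, constant plateau, linear decay to zero."""
    warmup_end = int(total_steps * warmup_frac)
    decay_start = int(total_steps * (1 - decay_frac))
    if step < warmup_end:
        return base_lr * step / max(warmup_end, 1)
    if step < decay_start:
        return base_lr
    return base_lr * max(total_steps - step, 0) / max(total_steps - decay_start, 1)
```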
configs/hparams/cls.yaml ADDED
@@ -0,0 +1,10 @@
1
+ defaults:
2
+ - base
3
+ - _self_
4
+
5
+ cls_lr: null # lr for cls head
6
+ backbone_lr: 1e-5 # sometimes we want to finetune backbone with a smaller learning rate
7
+ lr: 5e-4 # default lr for other not specified params
8
+
9
+ batch_size: 48
10
+ accumulate_grad_batches: 8
configs/hparams/cls_seg.yaml ADDED
@@ -0,0 +1,6 @@
1
+ defaults:
2
+ - cls
3
+ - seg
4
+ - _self_
5
+
6
+ batch_size: 32
configs/hparams/seg.yaml ADDED
@@ -0,0 +1,13 @@
1
+ defaults:
2
+ - base
3
+ - _self_
4
+
5
+ weight_decay: 5e-2
6
+
7
+ seg_lr: null # lr for seg head
8
+ backbone_lr: 1e-5 # sometimes we want to finetune backbone with a smaller learning rate
9
+ lr: 3e-4 # default lr for other not specified params
10
+
11
+ epoch: 1
12
+ batch_size: 40
13
+ accumulate_grad_batches: 3
configs/hparams/test.yaml ADDED
@@ -0,0 +1,7 @@
1
+ defaults:
2
+ - base
3
+ - _self_
4
+
5
+ batch_size: 96
6
+ lr: 1e-4
7
+ accumulate_grad_batches: 1
configs/infer.yaml ADDED
@@ -0,0 +1,11 @@
1
+ defaults:
2
+ - base
3
+ - dataset: ddl
4
+ - stage: test
5
+ - model: test
6
+ - hparams: test
7
+ - ckpt: test
8
+ - _self_
9
+
10
+ trainer:
11
+ enable_checkpointing: false
configs/model/base.yaml ADDED
@@ -0,0 +1,19 @@
1
+ # basic configs
2
+ hidden_act: "gelu"
3
+ hidden_dropout_prob: 0.1
4
+ initializer_range: 0.02
5
+
6
+ # backbone configs
7
+ backbone_name: PE-Core-L14-336
8
+ backbone_path: ./pretrained_weights/pe/PE-Core-L14-336.pt
9
+ freeze_backbone: False
10
+
11
+ # backbone overrides, you can set attr to '-' to use default value
12
+ # visit https://github.com/facebookresearch/perception_models/blob/main/core/vision_encoder/config.py for available overrides
13
+ backbone_overrides:
14
+ output_dim: null # set to null to use our own proj mlp
15
+ # NOTE: pool_type of PE-Spatial-G14-448 is none
16
+ # but loupe requires a pool_type, specify it to "attn" / "tok"
17
+ # pool_type: "attn" / "tok"
18
+ pool_type: "-"
19
+ use_cls_token: "-"
configs/model/cls.yaml ADDED
@@ -0,0 +1,15 @@
1
+ defaults:
2
+ - base
3
+ - _self_
4
+
5
+ # loupe configs
6
+ cls_mlp_ratio: 2 # 2 times of Perception Encoder output dim
7
+ cls_mlp_layers: 2
8
+ enable_patch_cls: True
9
+ enable_cls_fusion: True
10
+
11
+ freeze_cls: False
12
+ freeze_backbone: True
13
+
14
+ cls_forge_weight: 0.2
15
+ patch_forge_weight: 0.85
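`cls_forge_weight` and `patch_forge_weight` act as positive-class weights against class imbalance (see `dataset_preprocess.ipynb`, which estimates them from mask statistics). A minimal weighted binary cross-entropy sketch of the standard formulation, not necessarily the repo's exact loss:

```python
import math

def weighted_bce(p: float, y: int, forge_weight: float) -> float:
    """BCE with weight `forge_weight` on the forged (positive) class
    and `1 - forge_weight` on the real (negative) class."""
    eps = 1e-7
    p = min(max(p, eps), 1 - eps)
    return -(forge_weight * y * math.log(p)
             + (1 - forge_weight) * (1 - y) * math.log(1 - p))
```

A higher `forge_weight` penalizes missed forgeries more heavily, which compensates for forged samples being rarer.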
configs/model/cls_seg.yaml ADDED
@@ -0,0 +1,16 @@
1
+ defaults:
2
+ - cls
3
+ - seg
4
+ - _self_
5
+
6
+ freeze_backbone: False
7
+ freeze_cls: False
8
+ freeze_seg: False
9
+
10
+ cls_loss_weight: 2.0
11
+ seg_loss_weight: 1.0
12
+
13
+ # conditional_queries is typically used for test-time adaptation
14
+ # during segmentation training, conditional_queries will serve as an
15
+ # extra condition provided for pixel decoder.
16
+ enable_conditional_queries: True
configs/model/seg.yaml ADDED
@@ -0,0 +1,32 @@
1
+ defaults:
2
+ - base
3
+ - _self_
4
+
5
+ fpn_scales: [0.5, 2, 4] # rescale the last hidden states of backbone. for PE-Core-L14-336, rescale to 12x12, 48x48, 96x96
6
+ freeze_backbone: True
7
+ freeze_seg: False
8
+ # tversky alpha and beta control the weight of false positive and false negative, respectively
9
+ # the tversky beta is set to 1 - alpha
10
+ tversky_alpha: 0.3
11
+ # weight for forged pixels, set between 0 and 1
12
+ pixel_forge_weight: 0.8
13
+ # epsilon for poly1 focal loss
14
+ pixel_poly_epsilon: 1.0
15
+
16
+ # conditional_queries is typically used for test-time adaptation
17
+ # during segmentation training, conditional_queries will serve as an
18
+ # extra condition provided for pixel decoder.
19
+ enable_conditional_queries: True
20
+
21
+ # mask2former overrides, you can set attr to '-' to use default value
22
+ # visit https://huggingface.co/docs/transformers/main/model_doc/mask2former#transformers.Mask2FormerConfig for available overrides
23
+ mask2former_overrides:
24
+ num_queries: 20
25
+ mask_weight: 5
26
+ class_weight: 2
27
+ dice_weight: 5
28
+ id2label:
29
+ 0: "forgery"
30
+ label2id:
31
+ forgery: 0
32
+
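`tversky_alpha` weights false positives and `beta = 1 - alpha` weights false negatives in the Tversky loss. A sketch of the standard formulation on flat probability/label lists (not necessarily the repo's exact implementation):

```python
def tversky_loss(pred, target, alpha=0.3, eps=1e-6):
    """Tversky loss: alpha weights false positives, (1 - alpha) false negatives.
    alpha < 0.5 tolerates false positives to reduce missed forged pixels."""
    beta = 1.0 - alpha
    tp = sum(p * t for p, t in zip(pred, target))
    fp = sum(p * (1 - t) for p, t in zip(pred, target))
    fn = sum((1 - p) * t for p, t in zip(pred, target))
    return 1.0 - tp / (tp + alpha * fp + beta * fn + eps)
```

With `alpha = 0.3`, a missed forged pixel (false negative) costs more than a spurious one (false positive).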
configs/model/test.yaml ADDED
@@ -0,0 +1,7 @@
1
+ defaults:
2
+ - cls_seg
3
+ - _self_
4
+
5
+ freeze_backbone: True
6
+ freeze_cls: True
7
+ freeze_seg: True
configs/stage/cls.yaml ADDED
@@ -0,0 +1 @@
1
+ name: cls
configs/stage/cls_seg.yaml ADDED
@@ -0,0 +1,8 @@
1
+ name: cls_seg
2
+
3
+ # whether to train on the training set.
4
+ # since the classifier and segmenter are sometimes trained separately,
5
+ # we default to training on most of the validation set to avoid overfitting.
6
+ # However, if you prefer to train from scratch directly on the training set,
7
+ # you can do so by setting this variable to true.
8
+ train_on_trainset: false
configs/stage/seg.yaml ADDED
@@ -0,0 +1 @@
1
+ name: seg
configs/stage/test.yaml ADDED
@@ -0,0 +1,10 @@
1
+ name: test
2
+
3
+ # if set to a specific path, model predictions will be saved to this path
4
+ # with the filename format:
5
+ # - predicted masks: {predict_output_dir}/{ckpt-name}/masks/{image-name}.png
6
+ # - predicted labels: {predict_output_dir}/{ckpt-name}/predictions.txt
7
+ # in predictions.txt, each line is {image-name}.png,{probability of being forged:.4f}
8
+ pred_output_dir: null
9
+
10
+ enable_tta: false
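The `predictions.txt` format described above (image name, a comma, then the forgery probability) can be parsed line by line; a small sketch (the helper name is illustrative):

```python
def parse_prediction_line(line: str) -> tuple[str, float]:
    """Parse one predictions.txt line of the form '{image-name}.png,{prob:.4f}'.
    rsplit guards against commas inside the image name."""
    name, prob = line.strip().rsplit(",", 1)
    return name, float(prob)
```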
configs/train.yaml ADDED
@@ -0,0 +1,11 @@
1
+ defaults:
2
+ - base
3
+ - dataset: ddl
4
+ - stage: null # options: "cls" / "seg" / "cls_seg"
5
+ - model: ${stage}
6
+ - hparams: ${stage}
7
+ - ckpt: ${stage}
8
+ - _self_
9
+
10
+ trainer:
11
+ enable_checkpointing: true
dataset_preprocess.ipynb ADDED
@@ -0,0 +1,334 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "531bfd29",
6
+ "metadata": {},
7
+ "source": [
8
+ "In this notebook, we convert the raw datasets to Parquet format for faster loading during training and evaluation.\n",
9
+ "\n",
10
+ "The raw format of released datasets is as follows:\n",
11
+ "```python\n",
12
+ "# train set\n",
13
+ "/train/real/...\n",
14
+ "/train/fake/...\n",
15
+ "/train/masks/...\n",
16
+ "# valid set\n",
17
+ "/valid/real/...\n",
18
+ "/valid/fake/...\n",
19
+ "/valid/masks/...\n",
20
+ "```"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 2,
26
+ "id": "8bd7e9d5",
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "import os\n",
31
+ "from datasets import Dataset, DatasetDict\n",
32
+ "from datasets import Features, Image, Value\n",
33
+ "from typing import List, Optional\n",
34
+ "\n",
35
+ "\n",
36
+ "def load_images_from_dir(directory: str) -> List[str]:\n",
37
+ " return [\n",
38
+ " os.path.join(directory, fname)\n",
39
+ " for fname in os.listdir(directory)\n",
40
+ " if fname.endswith((\"jpg\", \"jpeg\", \"png\", \"tif\"))\n",
41
+ " ]\n",
42
+ "\n",
43
+ "\n",
44
+ "def create_split(root_dir: str, split: str) -> Optional[Dataset]:\n",
45
+ " fake_dir = os.path.join(root_dir, split, \"fake\")\n",
46
+ " masks_dir = os.path.join(root_dir, split, \"masks\")\n",
47
+ " real_dir = os.path.join(root_dir, split, \"real\")\n",
48
+ "\n",
49
+ " if all(not os.path.isdir(p) for p in [fake_dir, masks_dir, real_dir]):\n",
50
+ " return None\n",
51
+ "\n",
52
+ " print(f\"Split: {split},\", end=\" \")\n",
53
+ " fake_images, real_images, mask_images = [], [], []\n",
54
+ " if os.path.isdir(fake_dir):\n",
55
+ " fake_images = load_images_from_dir(fake_dir)\n",
56
+ " print(f\"Fake images: {len(fake_images)}\", end=\"\")\n",
57
+ " if os.path.isdir(masks_dir):\n",
58
+ " mask_images = load_images_from_dir(masks_dir)\n",
59
+ " print(f\", Masks: {len(mask_images)}\", end=\"\")\n",
60
+ " assert len(fake_images) == len(mask_images)\n",
61
+ " if os.path.isdir(real_dir):\n",
62
+ " real_images = load_images_from_dir(real_dir)\n",
63
+ " print(f\", Real images: {len(real_images)}\", end=\"\")\n",
64
+ " print()\n",
65
+ "\n",
66
+ " return Dataset.from_dict(\n",
67
+ " {\n",
68
+ " \"path\": fake_images + real_images,\n",
69
+ " \"image\": fake_images + real_images,\n",
70
+ " \"mask\": mask_images + [None] * len(real_images),\n",
71
+ " },\n",
72
+ " features=Features(\n",
73
+ " {\"path\": Value(dtype=\"string\"), \"image\": Image(), \"mask\": Image()}\n",
74
+ " ),\n",
75
+ " )\n",
76
+ "\n",
77
+ "\n",
78
+ "def create_dataset(root_dir: str) -> DatasetDict:\n",
79
+ " return DatasetDict(\n",
80
+ " {\n",
81
+ " split: d\n",
82
+ " for split in [\"train\", \"valid\", \"test\"]\n",
83
+ " if (d := create_split(root_dir, split)) is not None\n",
84
+ " }\n",
85
+ " )\n",
86
+ "\n",
87
+ "\n",
88
+ "# replace with your own dataset path\n",
89
+ "root_dir = \"/gemini/space/lye/track1\"\n",
90
+ "save_dir = \"/gemini/space/jyc/track1\""
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "markdown",
95
+ "id": "a1d6f1c7",
96
+ "metadata": {},
97
+ "source": [
98
+ "We merge `real/` and `fake/` into a single `image` column for simplicity. An image is real if it has no corresponding mask."
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": 14,
104
+ "id": "07009f1e",
105
+ "metadata": {},
106
+ "outputs": [
107
+ {
108
+ "name": "stdout",
109
+ "output_type": "stream",
110
+ "text": [
111
+ "Split: train, Fake images: 798831, Masks: 798831, Real images: 156100\n",
112
+ "Split: valid, Fake images: 199708, Masks: 199708, Real images: 39025\n",
113
+ "Split: test, Images: 222847\n"
114
+ ]
115
+ },
116
+ {
117
+ "data": {
118
+ "text/plain": [
119
+ "DatasetDict({\n",
120
+ " train: Dataset({\n",
121
+ " features: ['path', 'image', 'mask'],\n",
122
+ " num_rows: 954931\n",
123
+ " })\n",
124
+ " valid: Dataset({\n",
125
+ " features: ['path', 'image', 'mask'],\n",
126
+ " num_rows: 238733\n",
127
+ " })\n",
128
+ " test: Dataset({\n",
129
+ " features: ['path', 'image'],\n",
130
+ " num_rows: 222847\n",
131
+ " })\n",
132
+ "})"
133
+ ]
134
+ },
135
+ "execution_count": 14,
136
+ "metadata": {},
137
+ "output_type": "execute_result"
138
+ }
139
+ ],
140
+ "source": [
141
+ "dataset = create_dataset(root_dir)\n",
142
+ "dataset"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "markdown",
147
+ "id": "3aa7de84",
148
+ "metadata": {},
149
+ "source": [
150
+ "Then save processed datasets to parquet."
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": null,
156
+ "id": "cd6b20bc",
157
+ "metadata": {},
158
+ "outputs": [],
159
+ "source": [
160
+ "os.makedirs(save_dir, exist_ok=True)\n",
161
+ "for split in dataset:\n",
162
+ " dataset[split].to_parquet(os.path.join(save_dir, f\"{split}.parquet\"))\n",
163
+ " print(f\"Saved {split} split to {save_dir}/{split}.parquet\")"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "markdown",
168
+ "id": "f63933c8",
169
+ "metadata": {},
170
+ "source": [
171
+ "Load from processed datasets to do whatever you want."
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": 10,
177
+ "id": "4af7f346",
178
+ "metadata": {},
179
+ "outputs": [
180
+ {
181
+ "data": {
182
+ "text/plain": [
183
+ "Dataset({\n",
184
+ " features: ['path', 'image', 'mask'],\n",
185
+ " num_rows: 954931\n",
186
+ "})"
187
+ ]
188
+ },
189
+ "execution_count": 10,
190
+ "metadata": {},
191
+ "output_type": "execute_result"
192
+ }
193
+ ],
194
+ "source": [
195
+ "import os\n",
196
+ "from datasets import load_dataset\n",
197
+ "\n",
198
+ "trainset = load_dataset(\"parquet\", data_dir=save_dir, split=\"train\")\n",
199
+ "trainset"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "markdown",
204
+ "id": "b3c84f0a",
205
+ "metadata": {},
206
+ "source": [
207
+ "Since the forged components are usually smaller in proportion compared to the real ones, this leads to class imbalance.\n",
208
+ "For optimal training performance, hyperparameters such as `pixel_forge_weight` and `cls_forge_weight` in `src.loupe.configuration_loupe.LoupeConfig` must be appropriately configured. These parameters control the weights of forged pixels and forged images.\n",
209
+ "\n",
210
+ "Once suitable parameters are found using the following code snippet, you can set them in `configs/model/cls.yaml` or `configs/model/seg.yaml`.\n"
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "code",
215
+ "execution_count": null,
216
+ "id": "40a5ec91",
217
+ "metadata": {},
218
+ "outputs": [
219
+ {
220
+ "data": {
221
+ "application/vnd.jupyter.widget-view+json": {
222
+ "model_id": "19d416f59f20464692ee95bddefdaded",
223
+ "version_major": 2,
224
+ "version_minor": 0
225
+ },
226
+ "text/plain": [
227
+ "Computing mask stats (num_proc=8): 0%| | 0/5000 [00:00<?, ? examples/s]"
228
+ ]
229
+ },
230
+ "metadata": {},
231
+ "output_type": "display_data"
232
+ },
233
+ {
234
+ "name": "stdout",
235
+ "output_type": "stream",
236
+ "text": [
237
+ "cls_forge_weight: 0.16920000000000002\n",
238
+ "patch_forge_weight: 0.9294853830073696\n",
239
+ "pixel_forge_weight: 0.9160308902282281\n"
240
+ ]
241
+ }
242
+ ],
243
+ "source": [
244
+ "import numpy as np\n",
245
+ "from PIL import Image\n",
246
+ "from tqdm.notebook import tqdm\n",
247
+ "\n",
248
+ "cls_forge_weight: float # the ratio of forged images to total images.\n",
249
+ "# the ratio of forged patches to total patches across all images.\n",
250
+ "patch_forge_weight: float\n",
251
+ "# the ratio of forged pixels to total pixels across fake images.\n",
252
+ "pixel_forge_weight: float\n",
253
+ "\n",
254
+ "num_subset_samples = min(5000, len(trainset))\n",
255
+ "subset = trainset.shuffle().select(range(num_subset_samples))\n",
256
+ "image_size, patch_size = 336, 14\n",
257
+ "\n",
258
+ "\n",
259
+ "def compute_mask_stats(example):\n",
260
+ "\n",
261
+ " if example[\"mask\"] is None:\n",
262
+ " return {\n",
263
+ " \"is_forge\": 0,\n",
264
+ " \"forge_pixel_sum\": 0.0,\n",
265
+ " \"total_pixel_count\": 0,\n",
266
+ " \"forge_patch_sum\": 0.0,\n",
267
+ " }\n",
268
+ "\n",
269
+ " mask = example[\"mask\"].convert(\"L\").resize((image_size, image_size), Image.NEAREST)\n",
270
+ " mask_np = np.array(mask, dtype=np.float32)\n",
271
+ "\n",
272
+ " if mask_np.max() != mask_np.min():\n",
273
+ " mask_np = (mask_np - mask_np.min()) / (mask_np.max() - mask_np.min())\n",
274
+ " else:\n",
275
+ " mask_np[:] = 0.0\n",
276
+ "\n",
277
+ " forged_pixel_sum = mask_np.sum()\n",
278
+ " total_pixels = mask_np.size\n",
279
+ "\n",
280
+ " reshaped = mask_np.reshape(\n",
281
+ " image_size // patch_size, patch_size, image_size // patch_size, patch_size\n",
282
+ " )\n",
283
+ " patches = reshaped.transpose(0, 2, 1, 3)\n",
284
+ " forged_patch_sum = (patches != 0).sum(axis=(2, 3)) / (patch_size * patch_size)\n",
285
+ " forged_patch_sum = forged_patch_sum.sum()\n",
286
+ "\n",
287
+ " return {\n",
288
+ " \"is_forge\": 1,\n",
289
+ " \"forge_pixel_sum\": forged_pixel_sum,\n",
290
+ " \"total_pixel_count\": total_pixels,\n",
291
+ " \"forge_patch_sum\": forged_patch_sum,\n",
292
+ " }\n",
293
+ "\n",
294
+ "\n",
295
+ "processed = subset.map(compute_mask_stats, num_proc=8, desc=\"Computing mask stats\")\n",
296
+ "\n",
297
+ "num_forge_images = sum(processed[\"is_forge\"])\n",
298
+ "num_forge_pixels = sum(processed[\"forge_pixel_sum\"])\n",
299
+ "num_total_pixels = sum(processed[\"total_pixel_count\"])\n",
300
+ "num_forge_patches = sum(processed[\"forge_patch_sum\"])\n",
301
+ "num_total_patches = len(processed) * (image_size // patch_size) ** 2\n",
302
+ "\n",
303
+ "cls_forge_weight = 1 - num_forge_images / len(processed)\n",
304
+ "patch_forge_weight = 1 - num_forge_patches / num_total_patches\n",
305
+ "pixel_forge_weight = 1 - num_forge_pixels / num_total_pixels\n",
306
+ "\n",
307
+ "print(\"cls_forge_weight:\", cls_forge_weight)\n",
308
+ "print(\"patch_forge_weight:\", patch_forge_weight)\n",
309
+ "print(\"pixel_forge_weight:\", pixel_forge_weight)"
310
+ ]
311
+ }
312
+ ],
313
+ "metadata": {
314
+ "kernelspec": {
315
+ "display_name": "loupe2",
316
+ "language": "python",
317
+ "name": "python3"
318
+ },
319
+ "language_info": {
320
+ "codemirror_mode": {
321
+ "name": "ipython",
322
+ "version": 3
323
+ },
324
+ "file_extension": ".py",
325
+ "mimetype": "text/x-python",
326
+ "name": "python",
327
+ "nbconvert_exporter": "python",
328
+ "pygments_lexer": "ipython3",
329
+ "version": "3.12.9"
330
+ }
331
+ },
332
+ "nbformat": 4,
333
+ "nbformat_minor": 5
334
+ }
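The reshape/transpose patchification inside `compute_mask_stats` is easy to get wrong; it can be checked on a toy mask (standalone NumPy sketch):

```python
import numpy as np

def forged_patch_fraction(mask: np.ndarray, patch_size: int) -> np.ndarray:
    """Fraction of nonzero pixels in each (patch_size x patch_size) patch,
    using the same reshape/transpose trick as compute_mask_stats."""
    h, w = mask.shape
    patches = mask.reshape(h // patch_size, patch_size,
                           w // patch_size, patch_size).transpose(0, 2, 1, 3)
    return (patches != 0).sum(axis=(2, 3)) / (patch_size * patch_size)

mask = np.zeros((4, 4), dtype=np.float32)
mask[:2, :2] = 1.0  # forge exactly the top-left 2x2 patch
per_patch = forged_patch_fraction(mask, 2)
```

Only the top-left patch should come out fully forged; every other patch fraction stays zero.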
ffhq/ffhq-0001.png ADDED

Git LFS Details

  • SHA256: 33ea1e008ec7b310b254b7dbe2275586f22ba6424de05590093aa06f8a785c2d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.26 MB
ffhq/ffhq-0009.png ADDED

Git LFS Details

  • SHA256: 4b9c680fb4726a47c2b31a1acce1173de2e52ded122bd4b5195c051d83b84f43
  • Pointer size: 132 Bytes
  • Size of remote file: 1.37 MB
ffhq/ffhq-0032.png ADDED

Git LFS Details

  • SHA256: 1358253a3d3e1f79d276a9f94e61eadb9c57422f0ed5ba8cc26aec111f74022e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
ffhq/ffhq-0041.png ADDED

Git LFS Details

  • SHA256: 74f687e40e8dc0c0284aba29b96af877606c22ea023fde1e1fb131a199e905e3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.38 MB
ffhq/ffhq-0055.png ADDED

Git LFS Details

  • SHA256: 392c9579f25a345ec4e0f9e3e930e2c3d3ccc1f5af555fa426c83dacbada619d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
ffhq/ffhq-0062.png ADDED

Git LFS Details

  • SHA256: 4b784d66a1fde7d381910751dbbf6b2a1d5842909dbd31242cbd4f2329a49e39
  • Pointer size: 132 Bytes
  • Size of remote file: 1.52 MB
ffhq/ffhq-0085.png ADDED

Git LFS Details

  • SHA256: 9cfd4523806cab324c6cd5cc71a8c7d734209e9fa463b9967548c9ad36123fc0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.26 MB
ffhq/ffhq-0096.png ADDED

Git LFS Details

  • SHA256: 73fe06affd7fe171b23f39434f4b2addf6e8e8d5fb71764e7ca499f3e7b39028
  • Pointer size: 132 Bytes
  • Size of remote file: 1.3 MB
ffhq/ffhq-0100.png ADDED

Git LFS Details

  • SHA256: 3620d6c5a3280dea0cf9e98b91ce5b505a6c6fee94a515c9cbff250f9d13f9c5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
ffhq/ffhq-0112.png ADDED

Git LFS Details

  • SHA256: b53b33b183caa169e0e64d2051e336f9bd454eac0eca962503558e0f28212bf5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.35 MB
ffhq/ffhq-0136.png ADDED

Git LFS Details

  • SHA256: 16f24d2ef0fa499cdc18775682537d2c8976835253962d333567b62cc58b299b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
ffhq/ffhq-0138.png ADDED

Git LFS Details

  • SHA256: 5e5633a6902d67fb156059ec20a0a534b8ca923e9c5ffc189d34229830f90b85
  • Pointer size: 132 Bytes
  • Size of remote file: 1.33 MB
ffhq/ffhq-0142.png ADDED

Git LFS Details

  • SHA256: 68cdd9f28187c41dce823804b7f88da999aa361b7196a80b2903af2b827ee3e4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.38 MB
ffhq/ffhq-0154.png ADDED

Git LFS Details

  • SHA256: 7b98c4e8325ff7df1a4531224d66c9fb094a4984602faff341de7f64bf88a0eb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.38 MB
ffhq/ffhq-0155.png ADDED

Git LFS Details

  • SHA256: 0da2b6c1ce5492f88e51fb5e9b560c3f3b9d0aeb925ede3454d81db5f0814b95
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
ffhq/ffhq-0505.png ADDED

Git LFS Details

  • SHA256: 5b217a10af9855e4d4fa1dfabc39aa84eacb49a38a3ec6fd805e688f39303430
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
predict_case.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ lightning
2
+ lightning[extra]
3
+ torch==2.5.1
4
+ hydra-core
5
+ deepspeed
6
+ scikit-learn
7
+ evaluate
8
+ timm
9
+ transformers
runtime.txt ADDED
@@ -0,0 +1 @@
1
+ python-3.11