File size: 15,710 Bytes
8bc3305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
"""
eval pretained model with multi-GPU support.
"""
import os
import numpy as np
from os.path import join
import cv2
import random
import datetime
import time
import yaml
import pickle
from tqdm import tqdm
from copy import deepcopy
from PIL import Image as pil_image
from metrics.utils import get_test_metrics
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.utils.data
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

from dataset.abstract_dataset import DeepfakeAbstractBaseDataset
from dataset.ff_blend import FFBlendDataset
from dataset.fwa_blend import FWABlendDataset
from dataset.pair_dataset import pairDataset

from trainer.trainer import Trainer
from detectors import DETECTOR
from metrics.base_metrics_class import Recorder, calculate_acc_for_test
import metrics_retrieval.utils
from metrics_retrieval.get_metric_pro4 import *
from metrics_retrieval.get_metric import *

from collections import defaultdict

import argparse
from logger import create_logger

def _build_arg_parser():
    """Create the command-line parser for this evaluation script.

    Command-line values override the corresponding YAML config entries in main().
    """
    cli = argparse.ArgumentParser(description='Process some paths.')
    cli.add_argument('--detector_path', type=str, default='/PATH/TO/resnet34.yaml', help='path to detector YAML file')
    cli.add_argument("--test_dataset", nargs="+")
    cli.add_argument('--weights_path', type=str, default='')
    cli.add_argument('--ddp', action='store_true', help='Use DistributedDataParallel')
    cli.add_argument('--use_latest', action='store_true', help='Use Latest Ckpt')
    # Both spellings are accepted because different torch launchers pass different ones.
    cli.add_argument('--local_rank', '--local-rank', type=int, default=-1, help='Local rank for DDP')
    cli.add_argument('--test_config', type=str, default='test_config_p2.yaml', help='test_config_p2.yaml / test_config_p4.yaml')
    return cli


# Parsed once at import time; main() reads the module-level `args`.
parser = _build_arg_parser()
args = parser.parse_args()


def init_seed(config, seed=None):
    """Seed the Python, NumPy and torch RNGs and return the seed that was used.

    Args:
        config: Configuration dict. When ``seed`` is None, ``config['manualSeed']``
            supplies the seed. If that entry is missing or None, a random seed in
            [1, 10000] is drawn and written back into ``config``. A truthy
            ``config['cuda']`` additionally seeds all CUDA devices.
        seed: Explicit seed; overrides ``config['manualSeed']`` without modifying it.

    Returns:
        The integer seed applied to the RNGs.
    """
    if seed is None:
        if config.get('manualSeed') is None:
            config['manualSeed'] = random.randint(1, 10000)
        seed = config['manualSeed']
    random.seed(seed)
    # NumPy is seeded too: the original only seeded `random` and torch, which left
    # any NumPy-based randomness (e.g. in data pipelines) unseeded. NumPy accepts
    # only [0, 2**32), hence the modulo.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    if config.get('cuda'):
        torch.cuda.manual_seed_all(seed)
    return seed


def prepare_testing_data(config, ddp=False):
    """Build one test DataLoader per dataset name in ``config['test_dataset']``.

    Args:
        config: Configuration dict; reads 'test_dataset' (list of names),
            'test_batchSize' and 'workers', and is forwarded to the dataset.
        ddp: When True, each loader uses a DistributedSampler so every rank
            evaluates its own shard.

    Returns:
        (test_data_loaders, test_data_dicts): dicts keyed by dataset name holding
        the DataLoader and the dataset's ``data_dict`` respectively.
    """
    def get_test_data_loader(config, test_name):
        # Shallow copy so setting 'test_dataset' to a single name does not
        # overwrite the list in the caller's config.
        config = config.copy()
        config['test_dataset'] = test_name
        test_set = DeepfakeAbstractBaseDataset(
                config=config,
                mode='test',
            )

        # With shuffle=False the sampler gives rank r the indices r, r + world_size, ...
        # and pads (by repeating leading indices) so every rank gets the same count.
        sampler = DistributedSampler(test_set, shuffle=False) if ddp else None

        test_data_loader = \
            torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=config['test_batchSize'],
                # Never shuffle at evaluation time: previously the non-DDP loader
                # shuffled, making the order of saved predictions/features random.
                shuffle=False,
                num_workers=int(config['workers']),
                collate_fn=test_set.collate_fn,
                drop_last=False,
                pin_memory=True,
                sampler=sampler
            )
        return test_data_loader, test_set.data_dict

    test_data_loaders = {}
    test_data_dicts = {}
    for one_test_name in config['test_dataset']:
        loader, data_dict = get_test_data_loader(config, one_test_name)
        test_data_loaders[one_test_name] = loader
        test_data_dicts[one_test_name] = data_dict
    return test_data_loaders, test_data_dicts


def choose_metric(config):
    """Return ``config['metric_scoring']`` once it is confirmed to be supported.

    Raises:
        KeyError: if 'metric_scoring' is absent from ``config``.
        NotImplementedError: if the metric is not one of eer / auc / acc / ap.
    """
    scoring = config['metric_scoring']
    supported = ('eer', 'auc', 'acc', 'ap')
    if scoring in supported:
        return scoring
    raise NotImplementedError('metric {} is not implemented'.format(scoring))


def test_one_dataset(model, data_loader, device, local_rank):
    """Run inference over one DataLoader in this process and collect the outputs.

    Args:
        model: Detector (possibly DDP-wrapped), called through ``inference``; its
            output dict must contain 'prob' and 'feat' tensors.
        data_loader: Yields batch dicts with 'image', 'label', 'mask' and
            'landmark' entries ('mask'/'landmark' may be None).
        device: Device the batch tensors are moved to.
        local_rank: Node-local process rank. -1 (single process) and 0 both count
            as the main process, which shows the progress bar and prints the
            feature shape.

    Returns:
        (predictions, labels, features, image_names): the per-batch 'prob',
        'label' and 'feat' tensors concatenated along dim 0 (empty 1-D tensors if
        this process received no batches), plus a list of image names that is
        currently always empty.
    """
    # The previous `local_rank != 0` check hid the progress bar in single-process
    # runs, where local_rank keeps its default of -1.
    is_main = local_rank <= 0

    prediction_lists = []
    feature_lists = []
    label_lists = []
    # NOTE(review): nothing populates this list (data_dict['image'] holds tensors,
    # not names), so the returned names are always empty. TODO: fill from the batch.
    img_name_lists = []

    for data_dict in tqdm(data_loader, total=len(data_loader), disable=not is_main):
        data, label = data_dict['image'], data_dict['label']
        mask, landmark = data_dict['mask'], data_dict['landmark']

        # Move inputs onto the target device, updating the batch dict in place.
        data_dict['image'], data_dict['label'] = data.to(device), label.to(device)
        if mask is not None:
            data_dict['mask'] = mask.to(device)
        if landmark is not None:
            data_dict['landmark'] = landmark.to(device)

        predictions = inference(model, data_dict)

        # Collect per-batch tensors; they are concatenated once after the loop.
        label_lists.append(data_dict['label'])
        prediction_lists.append(predictions['prob'])
        feature_lists.append(predictions['feat'])

    # A process that received no batches returns empty tensors instead of failing in torch.cat.
    predictions_tensor = torch.cat(prediction_lists, dim=0) if prediction_lists else torch.tensor([], device=device)
    labels_tensor = torch.cat(label_lists, dim=0) if label_lists else torch.tensor([], device=device)
    feats_tensor = torch.cat(feature_lists, dim=0) if feature_lists else torch.tensor([], device=device)

    if is_main:
        print("feats_tensor", feats_tensor.shape)

    return predictions_tensor, labels_tensor, feats_tensor, img_name_lists


def _gather_in_dataset_order(local_tensor, num_samples):
    """All-gather a per-rank result tensor and return its rows in dataset order.

    Relies on the loaders built by prepare_testing_data:
    DistributedSampler(shuffle=False, drop_last=False) gives rank r the padded
    index list [r, r + W, r + 2W, ...], so every rank holds the same number of
    rows. Interleaving the gathered shards restores index order, and truncating
    to ``num_samples`` drops the duplicates the sampler added as padding.
    """
    world_size = dist.get_world_size()
    shards = [torch.zeros_like(local_tensor) for _ in range(world_size)]
    dist.all_gather(shards, local_tensor)
    # (world_size, rows_per_rank, ...) -> (rows_per_rank, world_size, ...) -> flat rows
    stacked = torch.stack(shards, dim=0)
    ordered = stacked.transpose(0, 1).reshape(-1, *local_tensor.shape[1:])
    return ordered[:num_samples]


def test_epoch(model, test_data_loaders, test_data_dicts, device, local_rank, ddp, config, logger):
    """Evaluate ``model`` on every test loader, then log and save the results.

    For each dataset, the main process computes metrics with
    calculate_acc_for_test and writes predictions, labels, features, metrics and
    names to ``<log dir>/<dataset>.pkl``.

    Args:
        model: Detector to evaluate (switched to eval mode here).
        test_data_loaders: Dict of dataset name -> DataLoader.
        test_data_dicts: Dict of dataset name -> dataset data_dict (kept for
            interface compatibility; not used here).
        device: Device for inference.
        local_rank: Node-local rank, forwarded to test_one_dataset.
        ddp: Whether results must be gathered across processes.
        config: Configuration dict; reads config['backbone_config']['num_classes'].
        logger: Logger exposing ``log_path``; the .pkl files go next to it.

    Returns:
        On the main process, a dict of dataset name -> metrics dict; None otherwise.
    """
    model.eval()

    # A single process is always the main one; under DDP only global rank 0 is.
    # The previous `local_rank == 0` test skipped all metrics in single-process
    # runs, where local_rank defaults to -1.
    is_main = (not ddp) or dist.get_rank() == 0

    metrics_all_datasets = {}

    for key, data_loader in test_data_loaders.items():
        print("Run Dataset:", key)
        logger.info(f"--------------- Run Dataset: {key} ---------------")
        logger.info(f"--------------- Run Dataset: {logger.log_path} ---------------")

        if ddp and hasattr(data_loader.sampler, 'set_epoch'):
            data_loader.sampler.set_epoch(0)

        # Each process evaluates its own shard.
        predictions_tensor, labels_tensor, feats_tensor, img_names = test_one_dataset(
            model, data_loader, device, local_rank)

        if ddp:
            num_samples = len(data_loader.dataset)
            all_predictions = _gather_in_dataset_order(predictions_tensor, num_samples)
            all_labels = _gather_in_dataset_order(labels_tensor, num_samples)
            all_feats = _gather_in_dataset_order(feats_tensor, num_samples)
        else:
            # Keep tensors, as in the DDP branch, so the metric call and the
            # .cpu().numpy() conversion below see the same types in both modes.
            # Converting to NumPy here made the save step fail with AttributeError.
            all_predictions, all_labels, all_feats = predictions_tensor, labels_tensor, feats_tensor

        if not is_main:
            continue

        metric_one_dataset = calculate_acc_for_test(all_labels, all_predictions, config['backbone_config']['num_classes'])
        metrics_all_datasets[key] = metric_one_dataset

        tqdm.write(f"dataset: {key}")
        for k, v in metric_one_dataset.items():
            tqdm.write(f"{k}: {v}")
            logger.info(f"{k}: {v}")

        pkl_save_path = os.path.join(os.path.dirname(logger.log_path), f"{key}.pkl")
        save_data = {
            "all_predictions": all_predictions.cpu().numpy(),
            "all_labels": all_labels.cpu().numpy(),
            "all_feats": all_feats.cpu().numpy(),
            "metrics": metric_one_dataset,
            "all_names": img_names
        }
        with open(pkl_save_path, "wb") as f:
            pickle.dump(save_data, f, protocol=pickle.HIGHEST_PROTOCOL)

    return metrics_all_datasets if is_main else None


@torch.no_grad()
def inference(model, data_dict):
    """Run ``model`` in inference mode without gradients under float16 CUDA autocast.

    Args:
        model: Callable invoked as ``model(data_dict, inference=True)``.
        data_dict: Batch dict passed straight through to the model.

    Returns:
        The model's output unchanged (callers read its 'prob' and 'feat' entries).
    """
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast. CUDA autocast
    # cannot run without a GPU, so it is enabled only when CUDA is available
    # instead of warning and disabling itself on CPU-only hosts.
    with torch.autocast(device_type='cuda', dtype=torch.float16,
                        enabled=torch.cuda.is_available()):
        return model(data_dict, inference=True)


def _resolve_checkpoint_path(weights_path, use_latest):
    """Return the checkpoint file to load for ``weights_path``.

    ``use_latest`` selects ckpt_latest.pth inside the run directory. Otherwise a
    path ending in "pth" is used as-is, and any other path is treated as a run
    directory containing ckpt_best.pth.
    """
    if use_latest:
        return os.path.join(weights_path, "test/protocol_2_test/ckpt_latest.pth")
    if weights_path.endswith("pth"):
        return weights_path
    return os.path.join(weights_path, "test/protocol_2_test/ckpt_best.pth")


def _strip_ddp_prefix(state_dict):
    """Remove the leading 'module.' that DistributedDataParallel adds to parameter names.

    Only a leading prefix is stripped; str.replace also corrupted keys that merely
    contain 'module.' (e.g. 'backbone.submodule.weight').
    """
    prefix = 'module.'
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}


def _run_retrieval_evaluations(test_datasets, results_dir):
    """Run the retrieval metrics on the per-dataset .pkl files written by test_epoch.

    Args:
        test_datasets: Dataset names; a name containing protocol_2/3 uses the
            standard evaluation and protocol_4 the p4 variant. Others are skipped.
        results_dir: Directory holding the ``<dataset>.pkl`` files.
    """
    rank_max = 10
    seed = 42
    for prot in test_datasets:
        # Read exactly the file test_epoch wrote, instead of a fixed file name
        # under a hard-coded, user-specific absolute root.
        pkl_file_path = os.path.join(results_dir, f"{prot}.pkl")
        if "protocol_2" in prot or "protocol_3" in prot:
            run_retrieval_evaluation(pkl_file_path=pkl_file_path, query_mode='10_sample_avg',
                                     rank_max=rank_max, random_seed=seed)
        elif "protocol_4" in prot:
            run_retrieval_evaluation_p4(pkl_file_path=pkl_file_path, query_mode='10_sample_avg',
                                        rank_max=rank_max, random_seed=seed,
                                        yaml_path="config/test_config_p4.yaml")


def main():
    """Evaluate a trained detector on every configured test dataset.

    Steps: set up DDP when --ddp is given, merge the detector and base test
    configs, seed the RNGs, create the logs_test directory and logger, load the
    checkpoint, run test_epoch, and finally run the retrieval evaluation on the
    main process.

    Raises:
        ValueError: if no checkpoint path is given on the command line or in the config.
    """
    ddp = args.ddp
    local_rank = args.local_rank

    if ddp:
        # torchrun passes the local rank via the LOCAL_RANK env var rather than
        # --local-rank; without this fallback the device index would be -1.
        if local_rank < 0:
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend='nccl',
            init_method='env://',  # rendezvous info from env vars (set by torchrun)
            world_size=int(os.environ.get("WORLD_SIZE", 1)),
            rank=int(os.environ.get("RANK", 0))
            )
        device = torch.device("cuda", local_rank)
        is_main = dist.get_rank() == 0
    else:
        # A single process acts as rank 0 so that rank-0-only paths (progress bar,
        # metrics, saving) run; the CLI default of -1 disabled all of them.
        local_rank = 0
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        is_main = True

    # Detector-specific config, then the shared base test config.
    with open(args.detector_path, 'r') as f:
        config = yaml.safe_load(f)
    with open(f'./training/config/{args.test_config}', 'r') as f:
        config_base = yaml.safe_load(f)

    # The base config overrides the detector config, except for label_dict,
    # where the detector's own value is kept.
    if 'label_dict' in config:
        config_base['label_dict'] = config['label_dict']
    config.update(config_base)

    # Command-line arguments override the YAML settings.
    if args.test_dataset:
        config['test_dataset'] = args.test_dataset
    if args.weights_path:
        config['weights_path'] = args.weights_path
    # Previously a weights_path set only in the config crashed on None.replace below.
    weights_path = config.get('weights_path')
    if not weights_path:
        raise ValueError("No checkpoint given: pass --weights_path or set 'weights_path' in the config.")

    seed = init_seed(config)
    if ddp:
        # Offset the seed per process so ranks do not share random streams.
        seed += dist.get_rank()
        init_seed(config, seed)

    if config['cudnn']:
        cudnn.benchmark = True

    # NOTE(review): every rank creates the logger on the same testing.log file.
    logs_test_dir = weights_path.replace("logs", "logs_test")
    os.makedirs(logs_test_dir, exist_ok=True)
    logger = create_logger(os.path.join(logs_test_dir, 'testing.log'))
    logger.info('Save log to {}'.format(logs_test_dir))
    logger.info("--------------- Configuration ---------------")
    params_string = "Parameters: \n" + "".join(f"{key}: {value}\n" for key, value in config.items())
    logger.info(params_string)

    test_data_loaders, test_data_dicts = prepare_testing_data(config, ddp)

    model_class = DETECTOR[config['model_name']]
    model = model_class(config).to(device)

    if is_main:
        for name, param in model.named_parameters():
            print(f"{name}: {param.shape}")

    # For models containing LoRA, switch to eval mode before loading to avoid
    # repeatedly stacking weights.
    model_name = config['model_name'].lower()
    if 'lora' in model_name or 'pmoe' in model_name:
        model.eval()

    ckpt_path = _resolve_checkpoint_path(weights_path, args.use_latest)
    # map_location=device also works for single-process and CPU runs, where the
    # previous f"cuda:{local_rank}" produced an invalid device.
    ckpt = torch.load(ckpt_path, map_location=device)
    logger.info(f"Load ckpt: {ckpt_path}")

    load_result = model.load_state_dict(_strip_ddp_prefix(ckpt), strict=False)
    # strict=False silently skips mismatches; log them so a wrong checkpoint is noticed.
    missing = getattr(load_result, 'missing_keys', [])
    unexpected = getattr(load_result, 'unexpected_keys', [])
    if missing:
        logger.info(f"Missing keys when loading checkpoint: {missing}")
    if unexpected:
        logger.info(f"Unexpected keys when loading checkpoint: {unexpected}")

    if is_main:
        print('===> Load checkpoint done!')

    if ddp:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    test_epoch(model, test_data_loaders, test_data_dicts, device, local_rank, ddp, config, logger)

    if is_main:
        print('===> Test Done!')

    if ddp:
        dist.barrier()
        dist.destroy_process_group()

    if is_main:
        # test_epoch writes its .pkl files next to the logger's log file.
        _run_retrieval_evaluations(config['test_dataset'], os.path.dirname(logger.log_path))






if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()


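# Outputs of this script:
# 1. Useful run information (configuration and per-dataset metrics) in the log.
#
# 2. A logs_test directory derived from the weights path.
#
# 3. Text output from the logger (testing.log).
#
# 4. Saved predictions, features, and labels (one .pkl file per test dataset).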