| |
| |
|
|
| |
|
|
| |
| |
| |
| |
|
|
| import os |
| import sys |
| import argparse |
| from os.path import join |
| import random |
| import datetime |
| import time |
| import yaml |
| from tqdm import tqdm |
| import numpy as np |
| from datetime import timedelta |
| from copy import deepcopy |
| from PIL import Image as pil_image |
| from pathlib import Path |
| import gc |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.parallel |
| import torch.backends.cudnn as cudnn |
| import torch.utils.data |
| import torch.optim as optim |
| from torch.utils.data.distributed import DistributedSampler |
| import torch.distributed as dist |
| from torch.utils.data import DataLoader |
|
|
| from optimizor.SAM import SAM |
| from optimizor.LinearLR import LinearDecayLR |
|
|
| from trainer.trainer import Trainer |
| from arena.detectors.UCF.detectors import DETECTOR |
| from metrics.utils import parse_metric_for_print |
| from logger import create_logger, RankFilter |
|
|
| from huggingface_hub import hf_hub_download |
|
|
| |
| from bitmind.dataset_processing.load_split_data import load_datasets, create_real_fake_datasets |
| from bitmind.image_transforms import base_transforms, random_aug_transforms |
| from bitmind.constants import DATASET_META, FACE_TRAINING_DATASET_META |
| from config.constants import ( |
| CONFIG_PATH, |
| WEIGHTS_DIR, |
| HF_REPO, |
| BACKBONE_CKPT |
| ) |
|
|
# ---------------------------------------------------------------------------
# Command-line interface. Parsing runs at import time, so importing this
# module as a library also consumes sys.argv — NOTE(review): consider moving
# this into main() if library use is ever needed.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Process some paths.')
parser.add_argument('--detector_path', type=str, default=CONFIG_PATH, help='path to detector YAML file')
# Restrict training to the face-cropped dataset metadata (FACE_TRAINING_DATASET_META).
parser.add_argument('--faces_only', dest='faces_only', action='store_true', default=False)
# Checkpoints and features are saved by default; these flags opt out.
parser.add_argument('--no-save_ckpt', dest='save_ckpt', action='store_false', default=True)
parser.add_argument('--no-save_feat', dest='save_feat', action='store_false', default=True)
# Enable distributed data parallel (NCCL process group is created in main()).
parser.add_argument("--ddp", action='store_true', default=False)
parser.add_argument('--local_rank', type=int, default=0)
# NOTE(review): os.cpu_count() - 1 evaluates to 0 on a single-core machine,
# which DataLoader treats as "load in the main process" — confirm intended.
parser.add_argument('--workers', type=int, default=os.cpu_count() - 1,
                    help='number of workers for data loading')
parser.add_argument('--epochs', type=int, default=None, help='number of training epochs')

args = parser.parse_args()
# Bind this process to the GPU selected by --local_rank (relevant under DDP).
torch.cuda.set_device(args.local_rank)
# NOTE(review): torch.cuda.device(0) is a context-manager object, so this
# prints its repr rather than the active device; torch.cuda.current_device()
# may have been intended — confirm.
print(f"torch.cuda.device(0): {torch.cuda.device(0)}")
print(f"torch.cuda.get_device_name(0): {torch.cuda.get_device_name(0)}")
|
|
def ensure_backbone_is_available(logger,
                                 weights_dir=WEIGHTS_DIR,
                                 model_filename=BACKBONE_CKPT,
                                 hugging_face_repo_name=HF_REPO):
    """Make sure the pretrained backbone checkpoint exists locally.

    Downloads the checkpoint from the given Hugging Face repo if it is not
    already present under ``weights_dir``; otherwise leaves it untouched.

    Args:
        logger: logger used for progress messages.
        weights_dir (str | Path): directory the checkpoint should live in.
        model_filename (str): checkpoint filename inside the HF repo.
        hugging_face_repo_name (str): HF repo id to download from.
    """
    import shutil  # local import: only needed on the download path

    destination_path = Path(weights_dir) / Path(model_filename)
    if not destination_path.parent.exists():
        destination_path.parent.mkdir(parents=True, exist_ok=True)
        logger.info(f"Created directory {destination_path.parent}.")
    if not destination_path.exists():
        model_path = hf_hub_download(hugging_face_repo_name, model_filename)
        # BUG FIX: the original torch.load()-then-torch.save() round trip
        # deserialized the whole checkpoint (onto the GPU when available)
        # just to re-serialize it. Copying the cached download byte-for-byte
        # is equivalent for later torch.load() use, avoids unpickling
        # untrusted data here, and uses no GPU memory.
        shutil.copyfile(model_path, destination_path)
        logger.info(f"Downloaded backbone {model_filename} to {destination_path}.")
    else:
        logger.info(f"{model_filename} backbone already present at {destination_path}.")
|
|
def init_seed(config):
    """Seed all RNGs used during training for reproducibility.

    If ``config['manualSeed']`` is None a random seed in [1, 10000] is drawn
    and written back into the config (so it gets logged/saved with the run).

    Args:
        config (dict): needs keys 'manualSeed' (int | None) and 'cuda' (bool).
    """
    if config['manualSeed'] is None:
        config['manualSeed'] = random.randint(1, 10000)
    seed = config['manualSeed']
    random.seed(seed)
    # BUG FIX: numpy was never seeded although it is used throughout the
    # project, and torch.manual_seed only ran when CUDA was enabled, leaving
    # CPU runs unseeded.
    np.random.seed(seed)
    torch.manual_seed(seed)
    if config['cuda']:
        torch.cuda.manual_seed_all(seed)
|
|
def custom_collate_fn(batch):
    """Collate (image, label, source_label) triples into the dict layout
    expected by the detector's data pipeline.

    Args:
        batch: iterable of (image tensor, label int, source-label int) tuples.

    Returns:
        dict with stacked 'image' tensor, LongTensor 'label' and
        'label_spe' (source label), and None placeholders for 'landmark'
        and 'mask'.
    """
    imgs, lbls, src_lbls = [], [], []
    for img, lbl, src in batch:
        imgs.append(img)
        lbls.append(lbl)
        src_lbls.append(src)

    return {
        'image': torch.stack(imgs, dim=0),
        'label': torch.LongTensor(lbls),
        'label_spe': torch.LongTensor(src_lbls),
        'landmark': None,
        'mask': None,
    }
|
|
def _make_loader(dataset, config, shuffle, drop_last):
    """Wrap *dataset* in a DataLoader using the run's shared batch/worker settings."""
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=config['train_batchSize'],
                                       shuffle=shuffle,
                                       num_workers=config['workers'],
                                       drop_last=drop_last,
                                       collate_fn=custom_collate_fn)


def prepare_datasets(config, logger):
    """Load source datasets, build train/val/test splits, and wrap them in loaders.

    Args:
        config (dict): needs 'dataset_meta', 'faces_only', 'split_transforms',
            'train_batchSize' and 'workers'.
        logger: logger used for timing messages.

    Returns:
        (train_loader, val_loader, test_loader) DataLoaders.
    """
    start_time = log_start_time(logger, "Loading and splitting individual datasets")
    real_datasets, fake_datasets = load_datasets(dataset_meta=config['dataset_meta'],
                                                 expert=config['faces_only'],
                                                 split_transforms=config['split_transforms'])
    log_finish_time(logger, "Loading and splitting individual datasets", start_time)

    start_time = log_start_time(logger, "Creating real fake dataset splits")
    train_dataset, val_dataset, test_dataset = \
        create_real_fake_datasets(real_datasets,
                                  fake_datasets,
                                  config['split_transforms']['train']['transform'],
                                  config['split_transforms']['validation']['transform'],
                                  config['split_transforms']['test']['transform'],
                                  source_labels=True)
    log_finish_time(logger, "Creating real fake dataset splits", start_time)

    # BUG FIX: only training should shuffle and drop the last partial batch.
    # The original loaders used shuffle=True, drop_last=True for val/test as
    # well, which silently dropped evaluation samples and made reported
    # metrics depend on shuffle order.
    train_loader = _make_loader(train_dataset, config, shuffle=True, drop_last=True)
    val_loader = _make_loader(val_dataset, config, shuffle=False, drop_last=False)
    test_loader = _make_loader(test_dataset, config, shuffle=False, drop_last=False)

    print(f"Train size: {len(train_loader.dataset)}")
    print(f"Validation size: {len(val_loader.dataset)}")
    print(f"Test size: {len(test_loader.dataset)}")

    return train_loader, val_loader, test_loader
|
|
def choose_optimizer(model, config):
    """Build the optimizer selected by ``config['optimizer']['type']``.

    Args:
        model: model whose parameters will be optimized.
        config (dict): expects config['optimizer']['type'] in
            {'sgd', 'adam', 'sam'}, with that optimizer's hyperparameters
            under config['optimizer'][<type>].

    Returns:
        A torch.optim.Optimizer (or SAM wrapper) instance.

    Raises:
        NotImplementedError: if the optimizer type is not recognized.
    """
    opt_name = config['optimizer']['type']
    if opt_name == 'sgd':
        opt_cfg = config['optimizer'][opt_name]
        return optim.SGD(
            params=model.parameters(),
            lr=opt_cfg['lr'],
            momentum=opt_cfg['momentum'],
            weight_decay=opt_cfg['weight_decay'],
        )
    elif opt_name == 'adam':
        opt_cfg = config['optimizer'][opt_name]
        return optim.Adam(
            params=model.parameters(),
            lr=opt_cfg['lr'],
            weight_decay=opt_cfg['weight_decay'],
            betas=(opt_cfg['beta1'], opt_cfg['beta2']),
            eps=opt_cfg['eps'],
            amsgrad=opt_cfg['amsgrad'],
        )
    elif opt_name == 'sam':
        opt_cfg = config['optimizer'][opt_name]
        # SAM wraps a base optimizer (SGD here) for sharpness-aware updates.
        return SAM(
            model.parameters(),
            optim.SGD,
            lr=opt_cfg['lr'],
            momentum=opt_cfg['momentum'],
        )
    # BUG FIX: the original formatted the whole config['optimizer'] dict into
    # the error message; report the unrecognized type name instead.
    raise NotImplementedError('Optimizer {} is not implemented'.format(opt_name))
|
|
|
|
def choose_scheduler(config, optimizer):
    """Build the LR scheduler named by ``config['lr_scheduler']``.

    Args:
        config (dict): 'lr_scheduler' is one of None, 'step', 'cosine',
            'linear', plus that scheduler's hyperparameters ('lr_step' /
            'lr_gamma', 'lr_T_max' / 'lr_eta_min', or 'nEpochs').
        optimizer: the optimizer to schedule.

    Returns:
        A scheduler instance, or None when no scheduler is configured.

    Raises:
        NotImplementedError: if the scheduler name is not recognized.
    """
    name = config['lr_scheduler']
    if name is None:
        return None
    elif name == 'step':
        return optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config['lr_step'],
            gamma=config['lr_gamma'],
        )
    elif name == 'cosine':
        return optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=config['lr_T_max'],
            eta_min=config['lr_eta_min'],
        )
    elif name == 'linear':
        # BUG FIX: the original built this scheduler but fell through without
        # returning it, so 'linear' silently yielded None (no LR decay).
        return LinearDecayLR(
            optimizer,
            config['nEpochs'],
            int(config['nEpochs'] / 4),
        )
    raise NotImplementedError('Scheduler {} is not implemented'.format(name))
|
|
def choose_metric(config):
    """Validate and return the configured validation metric name.

    Raises:
        NotImplementedError: if the metric is not one of eer/auc/acc/ap.
    """
    metric_scoring = config['metric_scoring']
    supported = ('eer', 'auc', 'acc', 'ap')
    if metric_scoring in supported:
        return metric_scoring
    raise NotImplementedError('metric {} is not implemented'.format(metric_scoring))
|
|
def log_start_time(logger, process_name):
    """Log the wall-clock start of *process_name* and return the timestamp."""
    started = time.time()
    stamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(started))
    logger.info(f"{process_name} Start Time: {stamp}")
    return started
|
|
def log_finish_time(logger, process_name, start_time):
    """Log the finish timestamp of *process_name* and the elapsed
    time since *start_time*, broken into hours/minutes/seconds."""
    finished = time.time()
    elapsed = finished - start_time

    hours, remainder = divmod(elapsed, 3600)
    minutes, seconds = divmod(remainder, 60)

    stamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(finished))
    logger.info(f"{process_name} Finish Time: {stamp}")
    logger.info(f"{process_name} Elapsed Time: {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds")
|
|
def save_config(config, outputs_dir):
    """
    Save a config dictionary as a YAML file, keeping only entries whose
    values are basic types (int, float, str, bool, list, dict, None).
    Lists under the keys 'mean' and 'std' are written in flow style
    (on a single line).

    Args:
        config (dict): The configuration dictionary to save.
        outputs_dir (str): The directory path where the file will be saved.
    """

    def is_basic_type(value):
        """
        Check if a value is a basic data type that can be saved in YAML.
        Basic types include int, float, str, bool, list, dict, and None.
        """
        return isinstance(value, (int, float, str, bool, list, dict, type(None)))

    def filter_dict(data_dict):
        """
        Recursively filter out any keys from the dictionary whose values contain non-basic types (e.g., objects).
        """
        if not isinstance(data_dict, dict):
            return data_dict

        filtered_dict = {}
        for key, value in data_dict.items():
            if isinstance(value, dict):
                # Recurse into nested dicts.
                # NOTE(review): a nested dict that filters down to {} is
                # falsy and is therefore dropped entirely — confirm that
                # losing empty sections is acceptable.
                nested_dict = filter_dict(value)
                if nested_dict:
                    filtered_dict[key] = nested_dict
            elif is_basic_type(value):
                # Plain scalar/list/None: keep as-is.
                filtered_dict[key] = value
            else:
                # Non-serializable value (e.g. a transform object): skip it.
                print(f"Skipping key '{key}' because its value is of type {type(value)}")

        return filtered_dict

    def save_dict_to_yaml(data_dict, file_path):
        """
        Saves a dictionary to a YAML file, excluding any keys where the value is an object or contains an object.
        Additionally, ensures that specific lists (like 'mean' and 'std') are saved in flow style.

        Args:
            data_dict (dict): The dictionary to save.
            file_path (str): The local file path where the YAML file will be saved.
        """

        # Marker subclass so that only the explicitly tagged lists are dumped
        # in flow style; ordinary lists keep block style.
        class FlowStyleList(list):
            pass

        def flow_style_list_representer(dumper, data):
            return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=True)

        # NOTE(review): this registers the representer on the global yaml
        # module state on every call — harmless for repeated calls here, but
        # the registration leaks out of this function.
        yaml.add_representer(FlowStyleList, flow_style_list_representer)

        # Tag the normalization stats so they dump on a single line.
        if 'mean' in data_dict:
            data_dict['mean'] = FlowStyleList(data_dict['mean'])
        if 'std' in data_dict:
            data_dict['std'] = FlowStyleList(data_dict['std'])

        try:
            # Drop anything YAML cannot represent before dumping.
            filtered_dict = filter_dict(data_dict)

            with open(file_path, 'w') as f:
                yaml.dump(filtered_dict, f, default_flow_style=False)
            print(f"Filtered dictionary successfully saved to {file_path}")
        except Exception as e:
            # Best-effort save: report the failure and continue training
            # rather than aborting the run over a config snapshot.
            print(f"Error saving dictionary to YAML: {e}")

    save_dict_to_yaml(config, outputs_dir + '/config.yaml')
|
|
def main():
    """End-to-end training entry point: build config, model, loaders and
    trainer, run the epoch loop, then evaluate on the test split."""
    # Start from a clean slate in case a previous run left GPU memory around.
    torch.cuda.empty_cache()
    gc.collect()

    # --- Configuration: detector YAML merged with the shared train config.
    with open(args.detector_path, 'r') as f:
        config = yaml.safe_load(f)
    with open(os.getcwd() + '/config/train_config.yaml', 'r') as f:
        config2 = yaml.safe_load(f)
    # Preserve the detector's label_dict: config.update(config2) below would
    # otherwise overwrite it with the train config's value.
    if 'label_dict' in config:
        config2['label_dict'] = config['label_dict']
    # NOTE(review): train_config values take precedence over detector values
    # for every other duplicated key — confirm that is intended.
    config.update(config2)

    config['workers'] = args.workers
    config['local_rank'] = args.local_rank
    # Dry run: skip training epochs and feature saving.
    if config['dry_run']:
        config['nEpochs'] = 0
        config['save_feat'] = False

    # CLI --epochs overrides the configured epoch count (0 is falsy and is
    # therefore ignored here).
    if args.epochs: config['nEpochs'] = args.epochs

    # All three splits currently use the same base transforms; the names are
    # recorded so they end up in the saved config snapshot.
    config['split_transforms'] = {'train': {'name': 'base_transforms',
                                            'transform': base_transforms},
                                  'validation': {'name': 'base_transforms',
                                                 'transform': base_transforms},
                                  'test': {'name': 'base_transforms',
                                           'transform': base_transforms}}
    config['faces_only'] = args.faces_only
    config['dataset_meta'] = FACE_TRAINING_DATASET_META if config['faces_only'] else DATASET_META
    dataset_names = [item["path"] for datasets in config['dataset_meta'].values() for item in datasets]
    config['train_dataset'] = dataset_names
    config['save_ckpt'] = args.save_ckpt
    config['save_feat'] = args.save_feat

    # One source class per fake dataset plus one for real images —
    # NOTE(review): presumably consumed by the detector's specific-forgery
    # head; confirm against the detector implementation.
    config['specific_task_number'] = len(config['dataset_meta']["fake"]) + 1

    if config['lmdb']:
        config['dataset_json_folder'] = 'preprocessing/dataset_json_v3'

    # --- Output directory and logging, one timestamped folder per run.
    timenow = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    outputs_dir = os.path.join(
        config['log_dir'],
        config['model_name'] + '_' + timenow
    )
    os.makedirs(outputs_dir, exist_ok=True)
    logger = create_logger(os.path.join(outputs_dir, 'training.log'))
    config['log_dir'] = outputs_dir
    logger.info('Save log to {}'.format(outputs_dir))
    config['ddp'] = args.ddp

    # Seed RNGs (writes the drawn seed back into config when manualSeed is None).
    init_seed(config)

    if config['cudnn']:
        cudnn.benchmark = True
    if config['ddp']:
        # NCCL process group for distributed training; only rank 0 logs.
        dist.init_process_group(
            backend='nccl',
            timeout=timedelta(minutes=30)
        )
        logger.addFilter(RankFilter(0))

    # Fetch the pretrained backbone checkpoint if it is not cached locally.
    ensure_backbone_is_available(logger=logger,
                                 model_filename=config['pretrained'].split('/')[-1],
                                 hugging_face_repo_name='bitmind/' + config['model_name'])

    # --- Model, optimizer, scheduler, metric, trainer, data.
    model_class = DETECTOR[config['model_name']]
    model = model_class(config)

    optimizer = choose_optimizer(model, config)
    scheduler = choose_scheduler(config, optimizer)
    metric_scoring = choose_metric(config)
    trainer = Trainer(config, model, optimizer, scheduler, logger, metric_scoring)
    train_loader, val_loader, test_loader = prepare_datasets(config, logger)

    # Dump the full effective configuration into the run log.
    logger.info("--------------- Configuration ---------------")
    params_string = "Parameters: \n"
    for key, value in config.items():
        params_string += "{}: {}".format(key, value) + "\n"
    logger.info(params_string)

    # Persist a YAML snapshot of the (basic-typed) config next to the logs.
    save_config(config, outputs_dir)

    # --- Training loop with per-epoch validation.
    start_time = log_start_time(logger, "Training")
    for epoch in range(config['start_epoch'], config['nEpochs'] + 1):
        trainer.model.epoch = epoch
        best_metric = trainer.train_epoch(
            epoch,
            train_data_loader=train_loader,
            validation_data_loaders={'val': val_loader}
        )
        if best_metric is not None:
            logger.info(f"===> Epoch[{epoch}] end with validation {metric_scoring}: {parse_metric_for_print(best_metric)}!")
    # NOTE(review): best_metric (and epoch below) are the loop variables from
    # the final iteration; this raises NameError if the range is empty
    # (start_epoch > nEpochs) — confirm that configuration cannot occur.
    logger.info("Stop Training on best Validation metric {}".format(parse_metric_for_print(best_metric)))
    log_finish_time(logger, "Training", start_time)

    # --- Final evaluation on the held-out test split.
    start_time = log_start_time(logger, "Test")
    trainer.eval(eval_data_loaders={'test': test_loader}, eval_stage="test")
    log_finish_time(logger, "Test", start_time)

    # NOTE(review): both of these run once, after training, using the last
    # epoch value — update_R and scheduler.step() look like they belong
    # inside the epoch loop (per-epoch radius update / LR step); confirm.
    if 'svdd' in config['model_name']:
        model.update_R(epoch)
    if scheduler is not None:
        scheduler.step()

    # Close TensorBoard-style writers owned by the trainer.
    for writer in trainer.writers.values():
        writer.close()

    # Release GPU memory before the process exits.
    torch.cuda.empty_cache()
    gc.collect()
|
|
# Script entry point. Argument parsing and CUDA device binding already ran at
# import time (top of file); main() performs configuration, training and test.
if __name__ == '__main__':
    main()