Spaces:
Runtime error
Runtime error
| import argparse | |
| import json | |
| import os | |
| from collections import defaultdict | |
| from sklearn.metrics import log_loss | |
| from torch import topk | |
| import sys | |
| print('@@@@@@@@@@@@@@@@@@') | |
| sys.path.append('..') | |
| from training import losses | |
| from training.datasets.classifier_dataset import DeepFakeClassifierDataset | |
| from training.losses import WeightedLosses | |
| from training.tools.config import load_config | |
| from training.tools.utils import create_optimizer, AverageMeter | |
| from training.transforms.albu import IsotropicResize | |
| from training.zoo import classifiers | |
| os.environ["MKL_NUM_THREADS"] = "1" | |
| os.environ["NUMEXPR_NUM_THREADS"] = "1" | |
| os.environ["OMP_NUM_THREADS"] = "1" | |
| import cv2 | |
| cv2.ocl.setUseOpenCL(False) | |
| cv2.setNumThreads(0) | |
| import numpy as np | |
| from albumentations import Compose, RandomBrightnessContrast, \ | |
| HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, \ | |
| ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, GaussianBlur | |
| from apex.parallel import DistributedDataParallel, convert_syncbn_model | |
| from tensorboardX import SummaryWriter | |
| from apex import amp | |
| import torch | |
| from torch.backends import cudnn | |
| from torch.nn import DataParallel | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| import torch.distributed as dist | |
| torch.backends.cudnn.benchmark = True | |
def create_train_transforms(size=300):
    """Assemble the augmentation pipeline used for training images.

    The pipeline applies, in order: JPEG-style compression, Gaussian noise,
    Gaussian blur, horizontal flip, one of three isotropic resizes (differing
    only in the interpolation used), padding to at least ``size`` x ``size``,
    one of three colour perturbations, grayscale conversion and a
    shift/scale/rotate transform.

    Args:
        size: value used for ``max_side`` of the isotropic resize and for the
            minimum height/width of the padding step.

    Returns:
        An albumentations ``Compose`` wrapping the transforms listed above.
    """
    border = cv2.BORDER_CONSTANT

    # Same resize, three (downscale, upscale) interpolation pairings.
    interpolation_pairs = (
        (cv2.INTER_AREA, cv2.INTER_CUBIC),
        (cv2.INTER_AREA, cv2.INTER_LINEAR),
        (cv2.INTER_LINEAR, cv2.INTER_LINEAR),
    )
    resize_variants = [
        IsotropicResize(max_side=size, interpolation_down=down, interpolation_up=up)
        for down, up in interpolation_pairs
    ]
    colour_variants = [RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()]

    pipeline = [
        ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
        GaussNoise(p=0.1),
        GaussianBlur(blur_limit=3, p=0.05),
        HorizontalFlip(),
        OneOf(resize_variants, p=1),
        PadIfNeeded(min_height=size, min_width=size, border_mode=border),
        OneOf(colour_variants, p=0.7),
        ToGray(p=0.2),
        ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, border_mode=border, p=0.5),
    ]
    return Compose(pipeline)
def create_val_transforms(size=300):
    """Assemble the transform pipeline used for validation images.

    Args:
        size: value used for ``max_side`` of the isotropic resize and for the
            minimum height/width of the padding step.

    Returns:
        An albumentations ``Compose`` of an isotropic resize followed by
        constant-border padding.
    """
    resize = IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC)
    pad = PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT)
    return Compose([resize, pad])
def main():
    """Command-line entry point: configure and run classifier training.

    Steps performed:
      1. Parse CLI arguments and load the configuration file given by ``--config``.
      2. Build the model, the weighted loss, the optimizer/scheduler and the
         train/validation datasets.
      3. Optionally restore model weights (and epoch / best score) from ``--resume``.
      4. Wrap the model for (distributed) data-parallel execution.
      5. For every epoch: train, save ``_last`` and per-epoch snapshots on rank 0,
         and evaluate every ``--test_every`` epochs.

    Snapshots are written inside ``--output-dir`` and TensorBoard logs inside
    ``--logdir``; both directories are created when missing.
    """
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='classifier_')
    arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--crops-dir', type=str, default='crops')
    arg('--label-smoothing', type=float, default=0.01)
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=0)
    arg("--local_rank", default=0, type=int)
    arg("--seed", default=777, type=int)
    arg("--padding-part", default=3, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--test_every", type=int, default=1)
    arg("--no-oversample", action="store_true")
    arg("--no-hardcore", action="store_true")
    arg("--only-changed-frames", action="store_true")

    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])
    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)

    # With OHEM enabled the losses must stay per-sample so that the hardest
    # samples can be selected later in train_epoch.
    ohem = conf.get("ohem_samples", None)
    reduction = "none" if ohem else "mean"
    loss_fn = []
    weights = []
    for loss_name, weight in conf["losses"].items():
        loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
        weights.append(weight)
    loss = WeightedLosses(loss_fn, weights)
    loss_functions = {"classifier_loss": loss}

    optimizer, scheduler = create_optimizer(conf['optimizer'], model)
    bce_best = 100
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    data_train = DeepFakeClassifierDataset(mode="train",
                                           oversample_real=not args.no_oversample,
                                           fold=args.fold,
                                           padding_part=args.padding_part,
                                           hardcore=not args.no_hardcore,
                                           crops_dir=args.crops_dir,
                                           data_path=args.data_dir,
                                           label_smoothing=args.label_smoothing,
                                           folds_csv=args.folds_csv,
                                           transforms=create_train_transforms(conf["size"]),
                                           normalize=conf.get("normalize", None))
    data_val = DeepFakeClassifierDataset(mode="val",
                                         fold=args.fold,
                                         padding_part=args.padding_part,
                                         crops_dir=args.crops_dir,
                                         data_path=args.data_dir,
                                         folds_csv=args.folds_csv,
                                         transforms=create_val_transforms(conf["size"]),
                                         normalize=conf.get("normalize", None))
    val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers, shuffle=False,
                                 pin_memory=False)

    os.makedirs(args.logdir, exist_ok=True)
    run_name = conf.get("prefix", args.prefix) + conf['encoder'] + "_" + str(args.fold)
    summary_writer = SummaryWriter(os.path.join(args.logdir, run_name))

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # Snapshots written by this script come from a (Distributed)DataParallel
            # wrapper, whose keys carry a "module." prefix. Strip it only when it is
            # actually present so that unwrapped checkpoints are not mangled.
            state_dict = {
                (k[len("module."):] if k.startswith("module.") else k): w
                for k, w in checkpoint['state_dict'].items()
            }
            model.load_state_dict(state_dict, strict=False)
            # The score is a loss to be minimised: fall back to the initial
            # sentinel rather than 0, which no epoch could ever beat.
            resumed_best = checkpoint.get('bce_best', bce_best)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    bce_best = resumed_best
            print("=> loaded checkpoint '{}' (epoch {}, bce_best {})"
                  .format(args.resume, checkpoint['epoch'], resumed_best))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(conf.get("prefix", args.prefix), conf['network'], conf['encoder'], args.fold)
    # Single base path so every snapshot lands inside output_dir, with or
    # without a trailing separator on the CLI value.
    snapshot_path = os.path.join(args.output_dir, snapshot_name)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()
    data_val.reset(1, args.seed)
    max_epochs = conf['optimizer']['schedule']['epochs']
    for epoch in range(start_epoch, max_epochs):
        data_train.reset(epoch, args.seed)
        train_sampler = None
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)
            train_sampler.set_epoch(epoch)

        # Keep the encoder frozen (eval mode, no gradients) for the first
        # --freeze-epochs epochs, then release it.
        freeze_encoder = epoch < args.freeze_epochs
        if freeze_encoder:
            print("Freezing encoder!!!")
            model.module.encoder.eval()
        else:
            model.module.encoder.train()
        for p in model.module.encoder.parameters():
            p.requires_grad = not freeze_encoder

        train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
                                       shuffle=train_sampler is None, sampler=train_sampler, pin_memory=False,
                                       drop_last=True)

        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
                    args.local_rank, args.only_changed_frames)
        model = model.eval()

        if args.local_rank == 0:
            snapshot = {
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'bce_best': bce_best,
            }
            torch.save(snapshot, snapshot_path + "_last")
            torch.save(snapshot, snapshot_path + "_{}".format(current_epoch))
            if (epoch + 1) % args.test_every == 0:
                bce_best = evaluate_val(args, val_data_loader, bce_best, model,
                                        snapshot_name=snapshot_name,
                                        current_epoch=current_epoch,
                                        summary_writer=summary_writer)
        current_epoch += 1
def evaluate_val(args, data_val, bce_best, model, snapshot_name, current_epoch, summary_writer):
    """Run validation, log the score and persist snapshots on rank 0.

    Args:
        args: parsed CLI namespace; ``local_rank``, ``output_dir`` and ``fold`` are read.
        data_val: validation data loader handed to ``validate``.
        bce_best: best (lowest) validation score seen so far.
        model: model to evaluate; it is switched to eval mode.
        snapshot_name: base file name for the snapshots.
        current_epoch: epoch index used as logging step and stored in snapshots.
        summary_writer: writer receiving the ``val/bce`` scalar.

    Returns:
        The updated best score: the new score if it improved on ``bce_best``
        (rank 0 only), otherwise ``bce_best`` unchanged.
    """
    print("Test phase")
    model = model.eval()

    bce, probs, targets = validate(model, data_loader=data_val)
    if args.local_rank != 0:
        return bce_best

    summary_writer.add_scalar('val/bce', float(bce), global_step=current_epoch)

    # Join with a real path separator so snapshots land inside output_dir even
    # when it was given without a trailing slash; skip saving when it is unset.
    snapshot_path = None
    if args.output_dir is not None:
        snapshot_path = os.path.join(args.output_dir, snapshot_name)

    if bce < bce_best:
        print("Epoch {} improved from {} to {}".format(current_epoch, bce_best, bce))
        if snapshot_path is not None:
            torch.save({
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'bce_best': bce,
            }, snapshot_path + "_best_dice")
        bce_best = bce
        # Dump the per-video predictions behind the new best score.
        with open("predictions_{}.json".format(args.fold), "w") as f:
            json.dump({"probs": probs, "targets": targets}, f)

    if snapshot_path is not None:
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'bce_best': bce_best,
        }, snapshot_path + "_last")
    print("Epoch: {} bce: {}, bce_best: {}".format(current_epoch, bce, bce_best))
    return bce_best
def validate(net, data_loader, prefix=""):
    """Score ``net`` on ``data_loader`` with a class-balanced, per-video log-loss.

    Frame-level sigmoid outputs and labels are grouped by the video part of
    ``img_name`` (the text before the ``/``), averaged per video, and the
    log-loss is computed separately for videos whose mean label is above 0.1
    (reported as fake) and below 0.1 (reported as real).

    Args:
        net: model called on each image batch.
        data_loader: iterable of samples providing ``"image"``, ``"img_name"``
            and ``"labels"``.
        prefix: text prepended to the printed loss names.

    Returns:
        A tuple ``(score, probs, targets)`` where ``score`` is the mean of the
        fake and real log-losses and ``probs`` / ``targets`` map each video to
        its per-frame predictions / labels.
    """
    probs = defaultdict(list)
    targets = defaultdict(list)

    with torch.no_grad():
        for batch in tqdm(data_loader):
            images = batch["image"].cuda()
            names = batch["img_name"]
            batch_labels = batch["labels"].cuda().float()
            logits = net(images)
            batch_labels = batch_labels.cpu().numpy()
            batch_preds = torch.sigmoid(logits).cpu().numpy()
            for idx in range(logits.shape[0]):
                # Names are expected to be exactly "<video>/<frame>".
                video, _ = names[idx].split("/")
                probs[video].append(batch_preds[idx].tolist())
                targets[video].append(batch_labels[idx].tolist())

    # One averaged prediction and one averaged label per video.
    x = np.array([np.mean(np.array(frame_scores)) for frame_scores in probs.values()])
    y = np.array([np.mean(targets[video]) for video in probs])

    is_fake = y > 0.1
    is_real = y < 0.1
    fake_loss = log_loss(y[is_fake], x[is_fake], labels=[0, 1])
    real_loss = log_loss(y[is_real], x[is_real], labels=[0, 1])
    print("{}fake_loss".format(prefix), fake_loss)
    print("{}real_loss".format(prefix), real_loss)

    return (fake_loss + real_loss) / 2, probs, targets
def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
                local_rank, only_valid):
    """Train ``model`` for one epoch of at most ``conf["batches_per_epoch"]`` batches.

    The loss is computed separately on fake (label > 0.5) and real
    (label <= 0.5) samples and the two halves are averaged. When
    ``conf["ohem_samples"]`` is set, only the hardest samples (largest
    per-sample losses) of each half contribute.

    Args:
        current_epoch: epoch index, used for scheduling and logging.
        loss_functions: mapping that provides the ``"classifier_loss"`` callable.
        model: network to train; it is switched to train mode.
        optimizer: optimizer stepped once per processed batch.
        scheduler: learning-rate scheduler, stepped per epoch or per iteration
            depending on ``conf["optimizer"]["schedule"]["mode"]``.
        train_data_loader: iterable of samples providing ``"image"`` and
            ``"labels"`` (and ``"valid"`` when ``only_valid`` is set).
        summary_writer: writer receiving learning-rate and loss scalars on rank 0.
        conf: configuration mapping.
        local_rank: process rank; only rank 0 writes summaries.
        only_valid: when truthy, restrict the loss to samples whose ``"valid"``
            flag is greater than zero.
    """
    # Named *_meter so the imported ``losses`` module is not shadowed.
    loss_meter = AverageMeter()
    fake_meter = AverageMeter()
    real_meter = AverageMeter()
    max_iters = conf["batches_per_epoch"]
    schedule_mode = conf["optimizer"]["schedule"]["mode"]
    ohem = conf.get("ohem_samples", None)
    classifier_loss = loss_functions["classifier_loss"]

    print("training epoch {}".format(current_epoch))
    model.train()
    pbar = tqdm(enumerate(train_data_loader), total=max_iters, desc="Epoch {}".format(current_epoch), ncols=0)
    if schedule_mode == "epoch":
        scheduler.step(current_epoch)
    for i, sample in pbar:
        # ">=" rather than "==" so the epoch still ends on time if the batch
        # at index max_iters - 1 is skipped below.
        is_last_iter = i >= max_iters - 1
        imgs = sample["image"].cuda()
        labels = sample["labels"].cuda().float()
        out_labels = model(imgs)
        if only_valid:
            valid_idx = sample["valid"].cuda().float() > 0
            out_labels = out_labels[valid_idx]
            labels = labels[valid_idx]
            if labels.size(0) == 0:
                if is_last_iter:
                    break
                continue

        # A half stays the plain int 0 when its class is absent from the batch.
        fake_loss = 0
        real_loss = 0
        fake_idx = labels > 0.5
        real_idx = labels <= 0.5
        if torch.sum(fake_idx * 1) > 0:
            fake_loss = classifier_loss(out_labels[fake_idx], labels[fake_idx])
        if torch.sum(real_idx * 1) > 0:
            real_loss = classifier_loss(out_labels[real_idx], labels[real_idx])
        if ohem:
            # Hard-example mining only applies to halves that actually produced
            # a per-sample loss tensor; an absent class keeps its 0 placeholder.
            if isinstance(fake_loss, torch.Tensor):
                fake_loss = topk(fake_loss, k=min(ohem, fake_loss.size(0)), sorted=False)[0].mean()
            if isinstance(real_loss, torch.Tensor):
                real_loss = topk(real_loss, k=min(ohem, real_loss.size(0)), sorted=False)[0].mean()

        loss = (fake_loss + real_loss) / 2
        loss_meter.update(loss.item(), imgs.size(0))
        fake_meter.update(fake_loss.item() if isinstance(fake_loss, torch.Tensor) else 0, imgs.size(0))
        real_meter.update(real_loss.item() if isinstance(real_loss, torch.Tensor) else 0, imgs.size(0))

        optimizer.zero_grad()
        pbar.set_postfix({"lr": float(scheduler.get_lr()[-1]), "epoch": current_epoch, "loss": loss_meter.avg,
                          "fake_loss": fake_meter.avg, "real_loss": real_meter.avg})

        if conf['fp16']:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
        optimizer.step()
        torch.cuda.synchronize()
        if schedule_mode in ("step", "poly"):
            scheduler.step(i + current_epoch * max_iters)
        if is_last_iter:
            break
    pbar.close()
    if local_rank == 0:
        for idx, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            summary_writer.add_scalar('group{}/lr'.format(idx), float(lr), global_step=current_epoch)
        summary_writer.add_scalar('train/loss', float(loss_meter.avg), global_step=current_epoch)
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()