| | import tqdm |
| | import argparse |
| | import math |
| | |
| | import sys |
| | import os |
| | import time |
| | import logging |
| | from datetime import datetime |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| |
|
| | import torchvision |
| | from torch.utils.data import DataLoader |
| | from torchvision import transforms |
| | from torchvision.models import resnet50 |
| |
|
| | import yaml |
| | from pytorch_msssim import ms_ssim |
| | from DISTS_pytorch import DISTS |
| | from util.lpips import LPIPS |
| | from torch.nn import functional as F |
| | from torchvision import utils as vutils |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import glob |
| |
|
| | import util.misc as misc |
| | import util.lr_sched as lr_sched |
| | from torch.utils.tensorboard import SummaryWriter |
| | import models_mage_codec |
| | import models_mage_codec_high_resolu |
| | import timm.optim.optim_factory as optim_factory |
| | from util.misc import NativeScalerWithGradNormCount as NativeScaler |
| | import json |
| | import PIL.Image as Image |
| | import torch.backends.cudnn as cudnn |
| | from pathlib import Path |
| | import random |
| | import torch.distributed as dist |
| | from util.dataloader import MSCOCO, Kodak, prepadding, crop_to_original_shape |
| |
|
class CalMetrics(nn.Module):
    """Calculate BPP, PSNR, MS-SSIM, LPIPS and DISTS for the reconstructed image."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        # LPIPS / DISTS networks are built on first use and cached per device.
        # A plain dict (not attribute assignment of a module) keeps them out of
        # this module's parameters() and state_dict().
        self._perceptual_nets = {}

    def _perceptual_net(self, name, factory, device):
        """Return the cached network for ``(name, device)``, building it once."""
        key = (name, str(device))
        net = self._perceptual_nets.get(key)
        if net is None:
            net = factory().to(device)
            self._perceptual_nets[key] = net
        return net

    def bpp_loss(self, ori, out_net):
        """Return ``(bpp, bpp_mask)`` in bits per pixel.

        ``bpp`` is ``-sum(log2(likelihoods))`` divided by the pixel count of
        ``ori`` (batch * height * width). ``bpp_mask`` is ``8 * len(bs_mask_token)``
        divided by the same pixel count.
        """
        b, _, h, w = ori.shape
        num_pixels = b * h * w
        bpp = torch.log(out_net["likelihoods"]).sum() / (-math.log(2) * num_pixels)
        # presumably a bytes-like bitstream for the mask tokens -- only len() is used
        total_bits = len(out_net['bs_mask_token']) * 8
        bpp_mask = total_bits / num_pixels
        return bpp, bpp_mask

    def psnr(self, rec, ori):
        """Return the PSNR (peak value 1.0) as a 0-dim tensor; 100 when identical."""
        mse = torch.mean((rec - ori) ** 2)
        if mse == 0:
            # Keep the return type a tensor so callers can always use .item().
            return mse.new_tensor(100.0)
        max_pixel = 1.
        return 10 * torch.log10(max_pixel / mse)

    def lpips(self, rec, ori):
        """Return the mean LPIPS distance between ``rec`` and ``ori``."""
        lpips_func = self._perceptual_net('lpips', lambda: LPIPS().eval(), rec.device)
        return lpips_func(rec, ori).mean()

    def dists(self, rec, ori):
        """Return the mean DISTS distance between ``rec`` and ``ori``."""
        # Follow the input's device (as lpips does) instead of forcing CUDA.
        dists_func = self._perceptual_net('dists', DISTS, rec.device)
        return dists_func(rec, ori).mean()

    def cal_total_loss(self, lpips, bpp, out_net):
        """Return ``bpp + out_net['lambda'] * out_net['task_loss']``.

        ``lpips`` is accepted for interface compatibility and is not used.
        """
        task_loss = out_net['task_loss']
        return bpp + out_net['lambda'] * task_loss

    def forward(self, ori, out_net, rec=None):
        """Collect rate terms, optional distortion metrics and the total loss.

        Always fills 'bpp', 'bpp_mask', 'bpp_loss' and 'total_loss'. When ``rec``
        is given it is clamped to [0, 1] once and 'psnr', 'msssim', 'lpips' and
        'dists' are added.
        """
        out = {}
        out["bpp"], out["bpp_mask"] = self.bpp_loss(ori, out_net)
        out["bpp_loss"] = out["bpp"] + out["bpp_mask"]

        if rec is not None:
            rec = torch.clamp(rec, 0, 1)
            out["psnr"] = self.psnr(rec, ori)
            out["msssim"] = ms_ssim(rec, ori, data_range=1, size_average=True)
            out["lpips"] = self.lpips(rec, ori)
            out["dists"] = self.dists(rec, ori)
        # .get(): the LPIPS value is absent when no reconstruction was supplied.
        out["total_loss"] = self.cal_total_loss(out.get("lpips"), out["bpp_loss"], out_net)
        return out
| |
|
def save_patches(blocks, save_path='/home/t2vg-a100-G4-10/project/qyp/patches'):
    """Save each patch as ``patch_<i>.png`` inside ``save_path``.

    Args:
        blocks: iterable of 3-D tensors; each is permuted with ``(1, 2, 0)``
            before conversion (assumes channel-first layout with values in
            [0, 1] -- TODO confirm against callers).
        save_path: output directory, created if missing.
    """
    os.makedirs(save_path, exist_ok=True)
    for i, block in enumerate(blocks):
        # detach(): .numpy() raises on tensors that still require grad.
        block_image = block.detach().permute(1, 2, 0).cpu().numpy().squeeze()
        # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
        block_image = (np.clip(block_image, 0.0, 1.0) * 255).astype(np.uint8)

        img = Image.fromarray(block_image)
        img.save(os.path.join(save_path, f"patch_{i}.png"))
class AverageMeter:
    """Keep the latest value, the weighted sum, the sample count and their mean."""

    def __init__(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.count += n
        self.sum += val * n
        self.val = val
        self.avg = self.sum / self.count
| |
|
class CustomDataParallel(nn.DataParallel):
    """DataParallel that forwards unknown attribute lookups to the wrapped module."""

    def __getattr__(self, key):
        # Regular nn.Module resolution first; only on failure ask self.module.
        try:
            found = super().__getattr__(key)
        except AttributeError:
            found = getattr(self.module, key)
        return found
| |
|
| |
|
def init(args):
    """Create ``<args.root>/<args.exp_name>/`` if needed and return that path.

    The returned string keeps its trailing slash.
    """
    experiment_dir = '{}/{}/'.format(args.root, args.exp_name)
    os.makedirs(experiment_dir, exist_ok=True)
    return experiment_dir
| |
|
def setup_logger(log_dir):
    """Attach a file handler and a stdout handler to the root logger at INFO level.

    Args:
        log_dir: path of the log *file* (despite the name) opened as UTF-8.
    """
    formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    handlers = (
        logging.FileHandler(log_dir, encoding='utf-8'),
        logging.StreamHandler(sys.stdout),
    )
    for handler in handlers:
        handler.setFormatter(formatter)
        root.addHandler(handler)

    logging.info('Logging file is %s' % log_dir)
| |
|
def save_img(img: torch.Tensor, vis_path, input_p, mask=False):
    """Save ``img`` into ``vis_path`` under the file name of ``input_p``.

    Args:
        img: tensor handed to ``vutils.save_image`` (a detached CPU copy is saved).
        vis_path: output directory, created if missing.
        input_p: path of the source image; only its final component is used.
        mask: when true the file name is prefixed with ``mask_``.
    """
    img = img.clone().detach().to(torch.device('cpu'))
    os.makedirs(vis_path, exist_ok=True)

    # basename() drops the directory part *and* the separator. The previous
    # slicing kept the leading '/', which produced '<vis_path>mask_/<name>' and
    # joined vis_path onto itself for relative paths.
    file_name = os.path.basename(input_p)
    if mask:
        file_name = 'mask_' + file_name
    vutils.save_image(img, os.path.join(vis_path, file_name), nrow=8)
| |
|
def train_one_epoch(model, data_loader, metrics_criterion, device,
                    optimizer, epoch, loss_scaler, log_writer, args, val_dataloader=None, stage='train'):
    """Run one training epoch with gradient accumulation and return averaged stats.

    Args:
        model: network called as ``model(samples, is_training=True,
            manual_mask_rate=None)``; may be bare or wrapped (exposing ``.module``).
        data_loader: iterable of image batches (tensors with ``.to``).
        metrics_criterion: callable ``(ori, out_net, rec)`` returning a dict with
            'total_loss', 'bpp_loss', 'bpp_mask', 'lpips' and 'dists'.
        device: device the batches and reconstructions are moved to.
        optimizer: optimizer stepped through ``loss_scaler``.
        epoch: current epoch index (used for the LR schedule and file names).
        loss_scaler: callable performing backward / optional optimizer step.
        log_writer: optional writer with ``add_scalar``; ``None`` disables it.
        args: reads ``accum_iter`` and ``grad_clip``; also passed to the LR schedule.
        val_dataloader: accepted for interface compatibility; not used here.
        stage: sub-directory of ``./MIM_test_new/`` for visualisations.

    Returns:
        dict mapping each meter name to its ``global_avg``.
    """
    model.train(True)
    # gen_img is defined on the underlying network. main() only wraps the model
    # when args.distributed is set, so do not assume a `.module` attribute.
    core_model = getattr(model, 'module', model)

    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20
    accum_iter = args.accum_iter
    optimizer.zero_grad()
    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    vis_path = os.path.join("./MIM_test_new/", stage)
    os.makedirs(vis_path, exist_ok=True)

    for data_iter_step, samples in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        samples = samples.to(device, non_blocking=True)
        samples, h_ori, w_ori = prepadding(samples)

        # Per-iteration (not per-epoch) LR schedule, updated at the start of
        # every accumulation window.
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        is_update_step = (data_iter_step + 1) % accum_iter == 0

        # NOTE(review): the flattened source does not show where the autocast
        # block ended; forward, reconstruction and metrics are kept inside it
        # here -- confirm against the original file.
        with torch.cuda.amp.autocast():
            out_net = model(samples, is_training=True, manual_mask_rate=None)
            rec = core_model.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'],
                                     out_net['z_H'], out_net['z_W'])
            rec = rec.to(device)
            # Drop the padding added by prepadding() before computing metrics.
            samples = samples[:, :, :h_ori, :w_ori]
            rec = rec[:, :, :h_ori, :w_ori]
            out_criterion = metrics_criterion(samples, out_net, rec)
        loss_value = out_criterion['total_loss'].item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        out_criterion['total_loss'] /= accum_iter
        loss_scaler(out_criterion['total_loss'], optimizer, clip_grad=args.grad_clip,
                    parameters=model.parameters(), update_grad=is_update_step)
        if is_update_step:
            optimizer.zero_grad()

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=lr)
        metric_logger.update(bpp=out_criterion['bpp_loss'])
        metric_logger.update(bpp_mask=out_criterion['bpp_mask'])
        metric_logger.update(task_loss=out_net['task_loss'].item())
        metric_logger.update(lmbda=out_net['lambda'])
        metric_logger.update(mask_ratio=out_net['mask_ratio'])
        metric_logger.update(lpips=out_criterion['lpips'].item())
        metric_logger.update(dists=out_criterion['dists'].item())

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and is_update_step:
            # We use epoch_1000x as the x-axis in tensorboard.
            # This calibrates different curves when batch size changes.
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

        # Periodically dump originals stacked with reconstructions, plus the mask.
        if data_iter_step % 1000 == 0:
            with torch.no_grad():
                real_fake_images = torch.cat((samples, rec), dim=0)
                vutils.save_image(real_fake_images,
                                  os.path.join(vis_path, f"{epoch}_{data_iter_step}.jpg"), nrow=8)
                vutils.save_image(out_net['mask_vis'],
                                  os.path.join(vis_path, f"{epoch}_{data_iter_step}_mask.jpg"), nrow=8)

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
| |
|
def _metric_to_float(value):
    """Return ``value`` as a Python float whether it is a tensor or a plain number."""
    return value.item() if torch.is_tensor(value) else float(value)


def inference(epoch, test_loader, model, metrics_criterion, device, manual_mask_ratio, args, stage='test'):
    """Evaluate ``model`` on ``test_loader`` at a fixed mask ratio and log the averages.

    Args:
        epoch: label used in the log line and in saved file names.
        test_loader: iterable of image batches; must support ``len``.
        model: network called as ``model(d, is_training=False,
            manual_mask_rate=manual_mask_ratio)``; may be bare or wrapped.
        metrics_criterion: callable ``(ori, out_net, rec)`` returning a dict with
            'bpp_loss', 'bpp_mask', 'psnr', 'msssim', 'lpips', 'dists', 'total_loss'.
        device: device the batches and reconstructions are moved to.
        manual_mask_ratio: mask ratio forwarded to the model; also names the
            output sub-directory.
        args: accepted for interface compatibility; not used here.
        stage: ``'val'`` additionally saves reconstructions and mask images.

    Returns:
        The running average of 'total_loss' over the loader.

    The model is switched back to training mode before returning.
    """
    model.eval()
    # gen_img is defined on the underlying network; unwrap only when wrapped.
    core_model = getattr(model, 'module', model)

    bpp_loss = AverageMeter()
    bpp_mask = AverageMeter()
    psnr = AverageMeter()
    msssim = AverageMeter()
    lpips = AverageMeter()
    dists = AverageMeter()
    test_loss = AverageMeter()

    vis_path = os.path.join("./MIM_vbr_kodak/", stage, str(manual_mask_ratio))
    os.makedirs(vis_path, exist_ok=True)
    if stage == 'test':
        # str(): os.path.join raises TypeError on a float path component.
        test_vis_path = os.path.join("/home/v-ruoyufeng/v-ruoyufeng/qyp/rec_fid", str(manual_mask_ratio))
        os.makedirs(test_vis_path, exist_ok=True)

    with torch.no_grad():
        tqdm_meter = tqdm.tqdm(enumerate(test_loader), leave=False, total=len(test_loader))
        for i, d in tqdm_meter:
            d = d.to(device)
            d, h_ori, w_ori = prepadding(d, factor=256)
            out_net = model(d, is_training=False, manual_mask_rate=manual_mask_ratio)

            rec = core_model.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'],
                                     out_net['ori_shape'], out_net['patch_sizes'],
                                     out_net['num_blocks_h'], out_net['num_blocks_w'])
            rec = rec.to(device)

            # Metrics are computed on the un-padded region only.
            d = crop_to_original_shape(d, h_ori, w_ori)
            rec = crop_to_original_shape(rec, h_ori, w_ori)

            out_criterion = metrics_criterion(d, out_net, rec)

            bpp_loss.update(out_criterion["bpp_loss"])
            bpp_mask.update(out_criterion["bpp_mask"])
            psnr.update(out_criterion['psnr'])
            msssim.update(out_criterion['msssim'])
            lpips.update(out_criterion['lpips'])
            dists.update(out_criterion['dists'])
            test_loss.update(out_criterion['total_loss'])

            if stage == 'val':
                vutils.save_image(rec, os.path.join(vis_path, f"{epoch}_{i}.jpg"))
                vutils.save_image(out_net['mask_vis'], os.path.join(vis_path, f"{epoch}_{i}_mask.jpg"))

    model.train()

    rank = dist.get_rank() if torch.distributed.is_initialized() else 0

    if rank == 0:
        # Averages may be tensors or plain numbers (e.g. an empty loader leaves
        # them at 0), so convert uniformly before formatting.
        log_txt = (
            f"{epoch}|bpp:{_metric_to_float(bpp_loss.avg):.5f}"
            f"|mask:{_metric_to_float(bpp_mask.avg):.5f}"
            f"|mask_ratio:{manual_mask_ratio}"
            f"|psnr:{_metric_to_float(psnr.avg):.5f}"
            f"|msssim:{_metric_to_float(msssim.avg):.5f}"
            f"|lpips:{_metric_to_float(lpips.avg):.5f}"
            f"|dists:{_metric_to_float(dists.avg):.5f}"
            f"|Test loss:{_metric_to_float(test_loss.avg):.5f}"
        )
        logging.info(log_txt)
    return test_loss.avg
| |
|
| |
|
def save_checkpoint(state, is_best, base_dir, filename="checkpoint.pth.tar"):
    """Serialize ``state`` to ``base_dir/filename`` and, if best, to ``checkpoint_best.pth.tar``.

    Args:
        state: object passed to ``torch.save``.
        is_best: when true a second copy named ``checkpoint_best.pth.tar`` is written.
        base_dir: target directory, with or without a trailing separator.
        filename: name of the regular checkpoint file.
    """
    # os.path.join instead of '+': plain concatenation only produced the right
    # path when base_dir happened to end with a separator.
    torch.save(state, os.path.join(base_dir, filename))
    if is_best:
        torch.save(state, os.path.join(base_dir, "checkpoint_best.pth.tar"))
| |
|
def parse_args(argv):
    """Parse command-line options, using the YAML config file for defaults.

    The config file named by ``-c/--config`` is loaded with ``yaml.safe_load``
    and its top-level keys become parser defaults. Options given explicitly on
    the command line take precedence over those defaults.

    Args:
        argv: list of argument strings (without the program name).

    Returns:
        argparse.Namespace holding the declared options plus every config key.
    """
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-c",
        "--config",
        default="config/vpt_default.yaml",
        help="Path to config file",
    )
    parser.add_argument(
        '--name',
        default=datetime.now().strftime('%Y-%m-%d_%H_%M_%S'),
        type=str,
        help='Result dir name',
    )
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    # First pass: only needed to find out which config file to load.
    given_configs, _ = parser.parse_known_args(argv)

    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    # Both spellings are accepted; the destination stays `local_rank`.
    parser.add_argument('--local-rank', '--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    with open(given_configs.config, encoding="utf-8") as file:
        yaml_data = yaml.safe_load(file)
    # An empty YAML document loads as None.
    parser.set_defaults(**(yaml_data or {}))

    # NOTE(review): without an action/type this option takes a value and any
    # non-empty string (including "False") is truthy -- confirm intended usage.
    parser.add_argument(
        "-T",
        "--TEST",
        default=False,
        help='Testing'
    )
    # Second pass over the *full* argv. Parsing only the leftovers of the first
    # pass silently dropped --config/--name/--lr given on the command line.
    return parser.parse_args(argv)
| |
|
| |
|
def main(argv):
    """Entry point: configure the run, validate at several mask ratios, then train.

    Steps, in order: parse arguments, create the experiment directory, initialise
    distributed mode, seed the RNGs, set up logging, build datasets/loaders,
    build the model and optimizer, restore a checkpoint via ``misc.load_model``,
    run a validation sweep over a fixed list of mask ratios, then train for
    ``args.start_epoch .. args.epochs`` with validation after each epoch.

    Args:
        argv: list of command-line argument strings (without the program name).
    """
    # The file binds `datetime` to the class (`from datetime import datetime`),
    # which has no `timedelta` attribute; import it explicitly for the summary.
    from datetime import timedelta

    args = parse_args(argv)
    base_dir = init(args)

    # NOTE(review): the flattened source does not show whether the log_dir
    # assignment was nested under this `if` -- confirm.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        args.log_dir = args.output_dir
    elif not hasattr(args, 'log_dir'):
        args.log_dir = None

    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    # Overridden further down once the datasets are built.
    device = torch.device(args.device)

    # Different seed per process so ranks do not draw identical random numbers.
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    cudnn.benchmark = True

    setup_logger(base_dir + '/' + time.strftime('%Y%m%d_%H%M%S') + '.log')
    msg = f'======================= {args.name} ======================='
    logging.info(msg)
    for key, value in vars(args).items():
        logging.info(key + ':' + str(value))
    logging.info('=' * len(msg))

    transform_det = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])
    transform_val = transforms.Compose([
        transforms.ToTensor()
    ])

    # NOTE(review): only 'coco' defines a training set in the visible source;
    # the nesting of the validation set under this `if` could not be recovered.
    train_dataset = None
    if args.dataset == 'coco':
        train_dataset = MSCOCO(args.dataset_path + "/train2017/",
                               transform_det,
                               "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/util/img_list.txt")
    val_dataset = Kodak(args.kodak_path, transform_val)

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_val = torch.utils.data.DistributedSampler(
        val_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    val_dataloader = DataLoader(val_dataset, sampler=sampler_val, batch_size=1,
                                num_workers=args.num_workers, shuffle=False,
                                pin_memory=args.pin_mem, drop_last=True)

    vqgan_ckpt_path = '/home/t2vg-a100-G4-10/project/qyp/mage/vqgan_jax_strongaug.ckpt'
    model = models_mage_codec_high_resolu.__dict__[args.model](
        mask_ratio_mu=args.mask_ratio_mu, mask_ratio_std=args.mask_ratio_std,
        mask_ratio_min=args.mask_ratio_min, mask_ratio_max=args.mask_ratio_max,
        vqgan_ckpt_path=vqgan_ckpt_path)

    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    if args.lr is None:
        # Absolute LR derived from the base LR, scaled linearly with batch size.
        args.lr = args.blr * eff_batch_size / 256
    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],
                                                          find_unused_parameters=True)
        model_without_ddp = model.module

    param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    misc.load_model(args=args, model_without_ddp=model_without_ddp,
                    optimizer=optimizer, loss_scaler=loss_scaler, strict=False)

    metrics_criterion = CalMetrics()

    last_epoch = args.start_epoch

    print("############## pre validation ##############")
    best_loss = float("inf")
    tqrange = tqdm.trange(last_epoch, args.epochs)
    val_mask_ratios = (0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5, 0.45, 0.4,
                       0.35, 0.3, 0.25, 0.2, 0.15, 0.1, 0.05, 0.01)
    for ratio in val_mask_ratios:
        inference(-1, val_dataloader, model, metrics_criterion, device, ratio, args, 'val')
    # Per-epoch validation uses the last ratio of the sweep. This was previously
    # an implicit leftover of the loop variable; it is now explicit.
    val_mask_ratio = val_mask_ratios[-1]

    # Built once. Explicit num_replicas/rank keep the sampler usable when no
    # process group has been initialised (the defaults query torch.distributed).
    data_loader_train = None
    if train_dataset is not None:
        sampler_train = torch.utils.data.DistributedSampler(
            train_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        data_loader_train = DataLoader(
            train_dataset, sampler=sampler_train,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=True,
        )

    print(f"############## Start training for {args.epochs} epochs ##############")
    start_time = time.time()
    for epoch in tqrange:
        if data_loader_train is None:
            raise ValueError(f"No training dataset is defined for dataset {args.dataset!r}")
        # Reseed the shuffle so every epoch sees a different order.
        data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch(model, data_loader_train, metrics_criterion, device,
                                      optimizer, epoch, loss_scaler, log_writer=log_writer, args=args,
                                      val_dataloader=val_dataloader, stage='train')

        test_loss = inference(epoch, val_dataloader, model, metrics_criterion, device,
                              val_mask_ratio, args, 'val')

        is_best = test_loss < best_loss
        best_loss = min(test_loss, best_loss)

        if args.output_dir and (epoch % 10 == 0 or epoch + 1 == args.epochs):
            misc.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)
        # NOTE(review): treated as independent of the periodic save above; the
        # original nesting could not be recovered -- confirm.
        if is_best:
            misc.save_model_last(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch, is_best=is_best)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch}
        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
| |
|
| |
|
| |
|
# Script entry point: forward the command-line arguments (program name removed).
if __name__ == "__main__":
    cli_arguments = sys.argv[1:]
    main(cli_arguments)
| |
|