import argparse
import math
import sys
import os
import time
import logging
from datetime import datetime
from model_vq import Model_VQ
import torch
import torch.nn as nn
from omegaconf import OmegaConf
import yaml
from pytorch_msssim import ms_ssim
from DISTS_pytorch import DISTS
import lpips
from torch.nn import functional as F
from torchvision import utils as vutils
import numpy as np
import glob
import util.misc as misc
import PIL.Image as Image
import torch.backends.cudnn as cudnn
from pathlib import Path

# NOTE: hard-codes the visible GPU for this script; it only takes effect if
# CUDA has not been initialised before this line runs.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'


class CalMetrics(nn.Module):
    """Calculate PSNR, LPIPS (VGG and AlexNet) and DISTS for a reconstructed image.

    ``ms_ssim`` is imported by this script but is not computed here, and no
    bit-rate (BPP) is measured.

    The perceptual networks are built lazily on first use and then reused:
    constructing them loads pretrained weights, which is far too expensive to
    repeat for every image. Each network is moved to the device of the inputs
    it is called with.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        # Lazily-created perceptual networks (None until first use).
        self._lpips_vgg_net = None
        self._lpips_alex_net = None
        self._dists_net = None

    def _net(self, attr, factory, device):
        """Return the network cached in ``attr`` (building it with ``factory`` once) on ``device``."""
        net = getattr(self, attr)
        if net is None:
            net = factory()
            setattr(self, attr, net)  # registered as a submodule by nn.Module
        return net.to(device)

    def psnr(self, rec, ori):
        """Return PSNR in dB for images in [0, 1]; returns 100 for identical images."""
        mse = torch.mean((rec - ori) ** 2)
        if mse == 0:
            return 100
        max_pixel = 1.
        psnr = 10 * torch.log10(max_pixel / mse)
        return torch.mean(psnr)

    def lpips_vgg(self, rec, ori):
        """Return the LPIPS distance (VGG backbone) between ``rec`` and ``ori`` given in [0, 1]."""
        net = self._net('_lpips_vgg_net', lambda: lpips.LPIPS(net='vgg'), ori.device)
        # LPIPS expects inputs in [-1, 1]; normalize=True rescales [0, 1] inputs.
        return net(rec, ori, normalize=True)

    def lpips_alex(self, rec, ori):
        """Return the LPIPS distance (AlexNet backbone) between ``rec`` and ``ori`` given in [0, 1]."""
        net = self._net('_lpips_alex_net', lambda: lpips.LPIPS(net='alex'), ori.device)
        # LPIPS expects inputs in [-1, 1]; normalize=True rescales [0, 1] inputs.
        return net(rec, ori, normalize=True)

    def dists(self, rec, ori):
        """Return the DISTS score between ``rec`` and ``ori``."""
        net = self._net('_dists_net', DISTS, ori.device)
        return net(rec, ori)

    def forward(self, ori, rec):
        """Compute all metrics for ``rec`` against ``ori``.

        ``rec`` is clamped to [0, 1] first. Returns an empty dict when ``rec``
        is None, otherwise a dict with keys ``psnr``, ``lpips_vgg``,
        ``lpips_alex`` and ``dists``.
        """
        out = {}
        if rec is not None:
            rec = torch.clamp(rec, 0, 1)
            out["psnr"] = self.psnr(rec, ori)
            out["lpips_vgg"] = self.lpips_vgg(rec, ori)
            out["lpips_alex"] = self.lpips_alex(rec, ori)
            out["dists"] = self.dists(rec, ori)
        return out


class AverageMeter:
    """Compute running average."""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class CustomDataParallel(nn.DataParallel):
    """Custom DataParallel to access the module methods."""

    def __getattr__(self, key):
        try:
            return super().__getattr__(key)
        except AttributeError:
            # Fall back to attributes of the wrapped module.
            return getattr(self.module, key)


def init(args):
    """Create (if needed) and return the result directory ``<root>/<exp_name>/``."""
    base_dir = f'{args.root}/{args.exp_name}/'
    os.makedirs(base_dir, exist_ok=True)
    return base_dir


def setup_logger(log_dir):
    """Send root-logger INFO output both to the file ``log_dir`` and to stdout."""
    log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    log_file_handler = logging.FileHandler(log_dir, encoding='utf-8')
    log_file_handler.setFormatter(log_formatter)
    root_logger.addHandler(log_file_handler)

    log_stream_handler = logging.StreamHandler(sys.stdout)
    log_stream_handler.setFormatter(log_formatter)
    root_logger.addHandler(log_stream_handler)

    logging.info('Logging file is %s', log_dir)


def load_img(p, padding=True, factor=64):
    """Load image ``p`` as a float tensor of shape (1, 3, H, W) with values in [0, 1].

    Any PIL mode (grayscale, palette, RGBA, CMYK, ...) is converted to RGB.
    When ``padding`` is true, the image is zero-padded so that both spatial
    sizes are multiples of ``factor``. The padding is split evenly: top/left
    receive ``pad // 2`` and bottom/right the remainder.

    Returns:
        ``(x, h, w)`` where ``h``/``w`` are the unpadded height and width.
    """
    with Image.open(p) as img:
        # np.array (not asarray) gives a writable copy, so torch does not warn.
        arr = np.array(img.convert('RGB'))
    x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float().div(255)
    h, w = x.shape[2:4]
    if padding:
        dh = factor * math.ceil(h / factor) - h
        dw = factor * math.ceil(w / factor) - w
        # Pad evenly on both sides; any odd extra pixel goes to bottom/right.
        dh_half = dh // 2
        dw_half = dw // 2
        dh_extra = dh % 2
        dw_extra = dw % 2
        x = F.pad(x, (dw_half, dw_half + dw_extra, dh_half, dh_half + dh_extra))
    return x, h, w


def save_img(img: torch.Tensor, vis_path, input_p, rec=False):
    """Save ``img`` under ``vis_path`` (or ``vis_path/rec`` when ``rec``) using the file name of ``input_p``."""
    img = img.detach().cpu().clone()
    if rec:
        vis_path = os.path.join(vis_path, 'rec')
    os.makedirs(vis_path, exist_ok=True)
    img_name = os.path.join(vis_path, os.path.basename(input_p))
    vutils.save_image(img, img_name, nrow=8)


def _as_float(value):
    """Reduce a running-average value (tensor of any shape, or a plain number) to a float."""
    if torch.is_tensor(value):
        return value.float().mean().item()
    return float(value)


def inference(epoch, eval_path, model, metrics_criterion, device, stage='test'):
    """Evaluate ``model`` on each image path in ``eval_path`` and log the average metrics.

    Each image is padded by ``load_img``, reconstructed, and cropped back to its
    original region. It is then scored with ``metrics_criterion``, and the
    reconstruction is saved under ``./VQGAN/<stage>/rec``. The model is switched
    back to train mode afterwards.

    Returns:
        The PSNR ``AverageMeter``.
    """
    model.eval()
    psnr = AverageMeter()
    lpips_vgg = AverageMeter()
    lpips_alex = AverageMeter()
    dists = AverageMeter()
    vis_path = os.path.join("./VQGAN/", stage)
    os.makedirs(vis_path, exist_ok=True)
    if not eval_path:
        logging.warning('No evaluation images were given; logged metrics are placeholders.')

    with torch.no_grad():
        for input_p in eval_path:
            x, hx, wx = load_img(input_p, padding=True, factor=64)
            x = x.to(device)
            rec = model(x)
            # load_img pads symmetrically (top/left get pad // 2), so the
            # original pixels start at this offset, not at (0, 0).
            top = (x.shape[2] - hx) // 2
            left = (x.shape[3] - wx) // 2
            x = x[:, :, top:top + hx, left:left + wx]
            rec = rec[:, :, top:top + hx, left:left + wx]

            out_criterion = metrics_criterion(x, rec)
            psnr.update(out_criterion['psnr'])
            lpips_vgg.update(out_criterion['lpips_vgg'])
            lpips_alex.update(out_criterion['lpips_alex'])
            dists.update(out_criterion['dists'])

            ## ======================= visualization ======================= ##
            save_img(rec, vis_path, input_p, rec=True)

    model.train()
    log_txt = (
        f"{epoch}|psnr:{_as_float(psnr.avg):.5f}"
        f"|lpips_vgg:{_as_float(lpips_vgg.avg):.5f}"
        f"|lpips_alex:{_as_float(lpips_alex.avg):.5f}"
        f"|dists:{_as_float(dists.avg):.5f}"
    )
    logging.info(log_txt)
    return psnr


def parse_args(argv):
    """Parse ``argv`` with defaults overridden by the YAML config, and CLI values overriding both."""
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-c",
        "--config",
        default="/home/t2vg-a100-G4-10/project/qyp/mimc_rope/config/cal_upper_bound.yaml",
        help="Path to config file",
    )
    parser.add_argument(
        '--name',
        default=datetime.now().strftime('%Y-%m-%d_%H_%M_%S'),
        type=str,
        help='Result dir name',
    )
    parser.add_argument(
        '--eval_path',
        default='/home/t2vg-a100-G4-10/project/qyp/datasets/COCO/val2017',
        type=str,
        help='path to the evaluation dataset',
    )
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    # First pass: only needed to find the config file.
    given_configs, _ = parser.parse_known_args(argv)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local-rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    with open(given_configs.config, encoding='utf-8') as file:
        # An empty YAML file loads as None.
        yaml_data = yaml.safe_load(file) or {}
    parser.set_defaults(**yaml_data)

    parser.add_argument(
        "-T",
        "--TEST",
        action='store_true',
        help='Testing'
    )
    # Re-parse the FULL argv. The first pass consumed --config/--name/
    # --eval_path/--lr, so parsing only the leftovers would drop their CLI values.
    args = parser.parse_args(argv)
    return args


def load_eval_ps(eval_path, pattern='*.jpg'):
    """Return the sorted paths in directory ``eval_path`` matching ``pattern`` (default ``*.jpg``)."""
    eval_ps = sorted(glob.glob(os.path.join(eval_path, pattern)))
    return eval_ps


def main(argv):
    """Parse arguments, set up logging and the VQGAN model, then run one evaluation pass."""
    args = parse_args(argv)
    base_dir = init(args)  # create the base dir for saving the results
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        args.log_dir = args.output_dir

    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    setup_logger(os.path.join(base_dir, time.strftime('%Y%m%d_%H%M%S') + '.log'))
    msg = f'======================= {args.name} ======================='
    logging.info(msg)
    for k in args.__dict__:
        logging.info(k + ':' + str(args.__dict__[k]))
    logging.info('=' * len(msg))

    ## ======================= prepare dataset ======================= ##
    eval_path = load_eval_ps(args.eval_path)
    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    ## ======================= prepare model ======================= ##
    vqgan_ckpt_path = '/home/t2vg-a100-G4-10/project/qyp/mage/vqgan_jax_strongaug.ckpt'
    config = OmegaConf.load('config/vqgan.yaml').model
    model = Model_VQ(ddconfig=config.params.ddconfig,
                     n_embed=config.params.n_embed,  # 1024
                     embed_dim=config.params.embed_dim,  # 256
                     ckpt_path=vqgan_ckpt_path)
    model.to(device)
    metrics_criterion = CalMetrics()

    ## ======================= pre validation ======================= ##
    inference(-1, eval_path, model, metrics_criterion, device, 'val')


if __name__ == "__main__":
    main(sys.argv[1:])