|
|
import argparse |
|
|
import math |
|
|
import sys |
|
|
import os |
|
|
import time |
|
|
import logging |
|
|
from datetime import datetime |
|
|
from model_vq import Model_VQ |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from omegaconf import OmegaConf |
|
|
|
|
|
import yaml |
|
|
from pytorch_msssim import ms_ssim |
|
|
from DISTS_pytorch import DISTS |
|
|
import lpips |
|
|
from torch.nn import functional as F |
|
|
from torchvision import utils as vutils |
|
|
import numpy as np |
|
|
import glob |
|
|
|
|
|
import util.misc as misc |
|
|
import PIL.Image as Image |
|
|
import torch.backends.cudnn as cudnn |
|
|
from pathlib import Path |
|
|
import os |
|
|
|
|
|
# Default to GPU 3 for this script, but keep a CUDA_VISIBLE_DEVICES value the
# caller has already exported instead of silently overwriting it.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '3')
|
|
|
|
|
|
|
|
class CalMetrics(nn.Module):
    """Compute PSNR, LPIPS (VGG and AlexNet) and DISTS for a reconstruction.

    ``forward`` clamps the reconstruction to ``[0, 1]`` and PSNR uses a peak
    value of 1, so both inputs are expected to be image batches scaled to
    ``[0, 1]``.  The perceptual networks are created the first time they are
    needed and reused afterwards, instead of being rebuilt for every image.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        # Plain dict on purpose: the cached metric networks are not
        # registered as sub-modules, so ``state_dict()`` is unaffected.
        self._metric_nets = {}

    def _metric_net(self, key, factory, device):
        """Return the cached network for ``key`` (built via ``factory``) on ``device``."""
        net = self._metric_nets.get(key)
        if net is None:
            net = factory()
            self._metric_nets[key] = net
        return net.to(device)

    def psnr(self, rec, ori):
        """Return the PSNR in dB (peak value 1.0) as a 0-dim tensor.

        Identical inputs would give an infinite PSNR; 100 dB is returned
        in that case.
        """
        mse = torch.mean((rec - ori) ** 2)
        if mse == 0:
            # Keep the return type consistent with the non-degenerate branch.
            return mse.new_tensor(100.0)
        max_pixel = 1.
        return 10 * torch.log10(max_pixel / mse)

    def lpips_vgg(self, rec, ori):
        """Return the LPIPS distance computed with the VGG backbone."""
        net = self._metric_net(
            'lpips_vgg', lambda: lpips.LPIPS(net='vgg'), rec.device)
        # normalize=True makes LPIPS rescale [0, 1] inputs to the [-1, 1]
        # range it works in.
        return net(rec, ori, normalize=True)

    def lpips_alex(self, rec, ori):
        """Return the LPIPS distance computed with the AlexNet backbone."""
        net = self._metric_net(
            'lpips_alex', lambda: lpips.LPIPS(net='alex'), rec.device)
        return net(rec, ori, normalize=True)

    def dists(self, rec, ori):
        """Return the DISTS score between ``rec`` and ``ori``."""
        net = self._metric_net('dists', lambda: DISTS(), rec.device)
        return net(rec, ori)

    def forward(self, ori, rec):
        """Return a dict with ``psnr``, ``lpips_vgg``, ``lpips_alex`` and ``dists``.

        An empty dict is returned when ``rec`` is ``None``.
        """
        out = {}
        if rec is None:
            return out

        # Clamp once and share the result between all metrics.
        rec = torch.clamp(rec, 0, 1)
        out["psnr"] = self.psnr(rec, ori)
        out["lpips_vgg"] = self.lpips_vgg(rec, ori)
        out["lpips_alex"] = self.lpips_alex(rec, ori)
        out["dists"] = self.dists(rec, ori)
        return out
|
|
|
|
|
|
|
|
class AverageMeter:
    """Track the most recent value and the running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(val, n=0) on a fresh meter) would
        # divide by zero; leave the average untouched in that case.
        if self.count:
            self.avg = self.sum / self.count
|
|
|
|
|
|
|
|
class CustomDataParallel(nn.DataParallel):
    """DataParallel wrapper that exposes the wrapped module's attributes."""

    def __getattr__(self, key):
        # Look on the wrapper first; anything it does not have is resolved
        # on the wrapped module.
        try:
            value = super().__getattr__(key)
        except AttributeError:
            value = getattr(self.module, key)
        return value
|
|
|
|
|
|
|
|
def init(args):
    """Create ``<args.root>/<args.exp_name>/`` if needed and return that path."""
    experiment_dir = '{}/{}/'.format(args.root, args.exp_name)
    os.makedirs(experiment_dir, exist_ok=True)
    return experiment_dir
|
|
|
|
|
def setup_logger(log_dir):
    """Send INFO-level root-logger output to the file ``log_dir`` and to stdout.

    Despite its name, ``log_dir`` is the path of the log file itself.
    """
    formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # File handler first, then the console handler, sharing one formatter.
    handlers = (
        logging.FileHandler(log_dir, encoding='utf-8'),
        logging.StreamHandler(sys.stdout),
    )
    for handler in handlers:
        handler.setFormatter(formatter)
        root.addHandler(handler)

    logging.info('Logging file is %s' % log_dir)
|
|
|
|
|
|
|
|
def load_img(p, padding=True, factor=64):
    """Load an image file as a ``(1, 3, H, W)`` float tensor scaled to ``[0, 1]``.

    Args:
        p: Path of the image file.
        padding: When true, zero-pad height and width up to the next multiple
            of ``factor``.  The padding is split between both sides, with the
            extra pixel (odd padding) going to the bottom/right.
        factor: Size multiple used for padding.

    Returns:
        ``(x, h, w)`` where ``x`` is the (possibly padded) tensor and ``h``,
        ``w`` are the height and width of the image before padding.
    """
    # convert('RGB') gives three channels for every input mode (grayscale,
    # palette, RGBA, ...); the with-block closes the file handle, and
    # np.array copies the pixels so torch gets a writable buffer.
    with Image.open(p) as im:
        pixels = np.array(im.convert('RGB'))

    x = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).float().div(255)
    h, w = x.shape[2:4]

    if padding:
        # Amount missing to reach the next multiple of `factor`.
        dh = -h % factor
        dw = -w % factor
        top, left = dh // 2, dw // 2
        x = F.pad(x, (left, dw - left, top, dh - top))
    return x, h, w
|
|
|
|
|
|
|
|
|
|
|
def save_img(img: torch.Tensor, vis_path, input_p, rec=False):
    """Write ``img`` to disk under the file name of ``input_p``.

    Args:
        img: Image tensor to save.
        vis_path: Output directory; created if it does not exist.
        input_p: Path of the source image; only its file name is reused.
        rec: When true, the image is stored in the ``rec`` sub-directory of
            ``vis_path`` instead of in ``vis_path`` itself.
    """
    img = img.detach().cpu()

    out_dir = os.path.join(vis_path, 'rec') if rec else vis_path
    os.makedirs(out_dir, exist_ok=True)

    # basename() also works when input_p contains no directory part, where
    # slicing from rfind('/') would keep only the last character.
    file_name = os.path.basename(input_p)
    vutils.save_image(img, os.path.join(out_dir, file_name), nrow=8)
|
|
|
|
|
|
|
|
def _meter_value(value):
    """Reduce a meter average (tensor or plain number) to a Python float."""
    if torch.is_tensor(value):
        return value.mean().item()
    return float(value)


def inference(epoch, eval_path, model, metrics_criterion, device, stage='test'):
    """Reconstruct every image in ``eval_path`` and log the averaged metrics.

    Args:
        epoch: Label written at the start of the log line.
        eval_path: Iterable of image file paths.
        model: Callable that maps a padded image batch to its reconstruction.
        metrics_criterion: Callable ``(original, reconstruction) -> dict`` with
            the keys ``psnr``, ``lpips_vgg``, ``lpips_alex`` and ``dists``.
        device: Device the inputs are moved to.
        stage: Sub-directory of ``./VQGAN/`` that receives the reconstructions.

    Returns:
        The ``AverageMeter`` holding the PSNR statistics.
    """
    model.eval()
    psnr = AverageMeter()
    lpips_vgg = AverageMeter()
    lpips_alex = AverageMeter()
    dists = AverageMeter()

    vis_path = os.path.join("./VQGAN/", stage)
    os.makedirs(vis_path, exist_ok=True)

    with torch.no_grad():
        for input_p in eval_path:
            x, hx, wx = load_img(input_p, padding=True, factor=64)
            x = x.to(device)
            rec = model(x)

            # load_img pads symmetrically, so the real image starts at an
            # offset inside the padded tensor; crop that window instead of
            # the top-left corner.
            # NOTE(review): assumes rec has the same spatial size as the
            # padded input - confirm against the model.
            top = (x.shape[2] - hx) // 2
            left = (x.shape[3] - wx) // 2
            x = x[:, :, top:top + hx, left:left + wx]
            rec = rec[:, :, top:top + hx, left:left + wx]
            rec = rec.to(device)
            out_criterion = metrics_criterion(x, rec)

            psnr.update(out_criterion['psnr'])
            lpips_vgg.update(out_criterion['lpips_vgg'])
            lpips_alex.update(out_criterion['lpips_alex'])
            dists.update(out_criterion['dists'])

            save_img(rec, vis_path, input_p, rec=True)

    model.train()

    if psnr.count == 0:
        logging.warning('inference received no images; all averages are 0')

    # _meter_value copes with both tensors and the plain 0 of an unused
    # meter, so an empty image list no longer crashes here.
    log_txt = (
        f"{epoch}|psnr:{_meter_value(psnr.avg):.5f}"
        f"|lpips_vgg:{_meter_value(lpips_vgg.avg):.5f}"
        f"|lpips_alex:{_meter_value(lpips_alex.avg):.5f}"
        f"|dists:{_meter_value(dists.avg):.5f}"
    )
    logging.info(log_txt)
    return psnr
|
|
|
|
|
def parse_args(argv):
    """Parse command-line options, using the YAML config file for defaults.

    Precedence, lowest to highest: defaults declared here, values from the
    YAML file named by ``--config``, options given in ``argv``.  Keys of the
    YAML file that have no matching option still end up as attributes of the
    returned namespace.

    Args:
        argv: Argument list without the program name.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-c",
        "--config",
        default="/home/t2vg-a100-G4-10/project/qyp/mimc_rope/config/cal_upper_bound.yaml",
        help="Path to config file",
    )
    parser.add_argument(
        '--name',
        default=datetime.now().strftime('%Y-%m-%d_%H_%M_%S'),
        type=str,
        help='Result dir name',
    )
    parser.add_argument(
        '--eval_path',
        default='/home/t2vg-a100-G4-10/project/qyp/datasets/COCO/val2017',
        type=str,
        help='path to the evaluation dataset',
    )
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local-rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument(
        "-T",
        "--TEST",
        action='store_true',
        help='Testing'
    )

    # First pass: only needed to learn which config file to read.
    given_configs, _ = parser.parse_known_args(argv)
    with open(given_configs.config) as file:
        # An empty YAML document loads as None; treat it as "no overrides".
        yaml_data = yaml.safe_load(file) or {}
    parser.set_defaults(**yaml_data)

    # Second pass over the FULL argument list.  Parsing only the leftovers of
    # the first pass would drop --config/--name/--eval_path/--lr given on the
    # command line and silently fall back to their defaults.
    return parser.parse_args(argv)
|
|
|
|
|
def load_eval_ps(eval_path):
    """Return the sorted paths of all ``*.jpg`` files directly inside ``eval_path``."""
    pattern = os.path.join(eval_path, '*.jpg')
    matches = glob.glob(pattern)
    matches.sort()
    return matches
|
|
|
|
|
def main(argv):
    """Evaluate the pretrained VQ model on the images in ``args.eval_path``.

    Parses the arguments, prepares the experiment directory and logging,
    seeds the random generators, loads the model and runs ``inference`` on
    every ``*.jpg`` file of the evaluation directory.

    The checkpoint and model-config locations can be overridden with the
    optional ``vqgan_ckpt_path`` / ``vqgan_config`` entries of the parsed
    arguments (e.g. from the YAML config); the previous hard-coded paths are
    used when they are absent.

    Args:
        argv: Command-line arguments without the program name.

    Raises:
        FileNotFoundError: If the evaluation directory has no ``*.jpg`` file.
    """
    args = parse_args(argv)
    base_dir = init(args)

    # output_dir is optional: it only exists when the config provides it.
    output_dir = getattr(args, 'output_dir', None)
    if output_dir:
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        args.log_dir = output_dir

    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    setup_logger(os.path.join(base_dir, time.strftime('%Y%m%d_%H%M%S') + '.log'))
    msg = f'======================= {args.name} ======================='
    logging.info(msg)
    for key, value in vars(args).items():
        logging.info(key + ':' + str(value))
    logging.info('=' * len(msg))

    eval_path = load_eval_ps(args.eval_path)
    if not eval_path:
        # Fail early with a clear message instead of deep inside inference.
        raise FileNotFoundError(f"No '*.jpg' images found in {args.eval_path}")

    # The device is chosen from the `cuda` flag and CUDA availability; the
    # earlier torch.device(args.device) assignment was always overwritten by
    # this line and has been dropped.
    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    vqgan_ckpt_path = getattr(
        args, 'vqgan_ckpt_path',
        '/home/t2vg-a100-G4-10/project/qyp/mage/vqgan_jax_strongaug.ckpt')
    vqgan_config = getattr(args, 'vqgan_config', 'config/vqgan.yaml')
    config = OmegaConf.load(vqgan_config).model
    model = Model_VQ(ddconfig=config.params.ddconfig,
                     n_embed=config.params.n_embed,
                     embed_dim=config.params.embed_dim,
                     ckpt_path=vqgan_ckpt_path)

    model.to(device)

    metrics_criterion = CalMetrics()

    inference(-1, eval_path, model, metrics_criterion, device, 'val')
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: pass the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])