|
|
import tqdm |
|
|
import argparse |
|
|
import math |
|
|
|
|
|
import sys |
|
|
import os |
|
|
import time |
|
|
import logging |
|
|
from datetime import datetime |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
|
|
|
import torchvision |
|
|
from torch.utils.data import DataLoader |
|
|
from torchvision import transforms |
|
|
from torchvision.models import resnet50 |
|
|
|
|
|
import yaml |
|
|
from pytorch_msssim import ms_ssim |
|
|
from DISTS_pytorch import DISTS |
|
|
from util.lpips import LPIPS |
|
|
from torch.nn import functional as F |
|
|
from torchvision import utils as vutils |
|
|
import numpy as np |
|
|
|
|
|
import util.misc as misc |
|
|
import util.lr_sched as lr_sched |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
import models_mage_codec_high_resolu |
|
|
import timm.optim.optim_factory as optim_factory |
|
|
from util.misc import NativeScalerWithGradNormCount as NativeScaler |
|
|
from collections import OrderedDict |
|
|
import pickle |
|
|
import torch.backends.cudnn as cudnn |
|
|
from pathlib import Path |
|
|
import random |
|
|
import torch.distributed as dist |
|
|
from util.dataloader import MSCOCO, Kodak, prepadding |
|
|
from util.utils import adaptively_split_and_pad, crop_and_reconstruct |
|
|
from util.alignment import Alignment |
|
|
|
|
|
|
|
|
from detectron2.config import get_cfg |
|
|
from detectron2.layers import ShapeSpec |
|
|
from detectron2.modeling.backbone.fpn import build_resnet_fpn_backbone |
|
|
|
|
|
|
|
|
from detectron2.evaluation import COCOEvaluator |
|
|
from detectron2.data.datasets import register_coco_instances |
|
|
from detectron2.data import build_detection_test_loader |
|
|
from detectron2.data.detection_utils import read_image |
|
|
|
|
|
from contextlib import ExitStack, contextmanager |
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def inference_context(model):
    """Temporarily switch ``model`` to eval mode, restoring its previous mode on exit.

    The original ``training`` flag is restored even if the body of the
    ``with`` statement raises.

    Args:
        model: an ``nn.Module`` whose ``training`` flag is toggled.
    """
    training_mode = model.training
    model.eval()
    try:
        yield
    finally:
        # Restore the caller's mode whether or not the with-body raised.
        model.train(training_mode)
|
|
|
|
|
class CalMetrics(nn.Module):
    """Calculate BPP, PSNR, MS-SSIM, LPIPS and DISTS for the reconstructed image."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        # Perceptual-metric networks are built lazily on first use and cached
        # per (metric, device), so their pretrained weights are loaded once
        # instead of on every call. A plain dict keeps them out of this
        # module's registered submodules / state_dict.
        self._metric_nets = {}

    def _metric_net(self, name, factory, device):
        """Return the cached network ``name`` on ``device``, building it via ``factory`` if needed."""
        key = (name, str(device))
        net = self._metric_nets.get(key)
        if net is None:
            net = factory().to(device)
            self._metric_nets[key] = net
        return net

    def bpp_loss(self, ori, out_net):
        """Return ``(bpp, bpp_mask)``.

        ``bpp`` is ``-sum(log2(out_net["likelihoods"]))`` divided by
        ``b * h * w`` taken from ``ori.shape``. ``bpp_mask`` is
        ``8 * len(out_net['bs_mask_token'])`` over the same pixel count
        (assumes ``bs_mask_token`` is a byte string - TODO confirm).
        """
        b, _, h, w = ori.shape
        num_pixels = b * h * w
        bpp = torch.log(out_net["likelihoods"]).sum() / (-math.log(2) * num_pixels)
        total_bits = len(out_net['bs_mask_token']) * 8
        bpp_mask = total_bits / num_pixels
        return bpp, bpp_mask

    def psnr(self, rec, ori):
        """PSNR in dB for images in ``[0, 1]``; 100 dB when the images are identical.

        Always returns a tensor so callers can use ``.item()`` uniformly.
        """
        mse = torch.mean((rec - ori) ** 2)
        if mse == 0:
            # PSNR is unbounded for identical images; report a fixed ceiling,
            # as a tensor matching mse's dtype/device (previously a bare int).
            return torch.full_like(mse, 100.0)
        max_pixel = 1.0
        return 10 * torch.log10(max_pixel / mse)

    def lpips(self, rec, ori):
        """Mean LPIPS distance between ``rec`` and ``ori`` (network cached per device)."""
        lpips_func = self._metric_net("lpips", lambda: LPIPS().eval(), rec.device)
        return lpips_func(rec, ori).mean()

    def dists(self, rec, ori):
        """Mean DISTS distance between ``rec`` and ``ori``.

        The network is placed on ``rec.device`` (previously hard-coded to CUDA)
        and cached per device.
        """
        dists_func = self._metric_net("dists", DISTS, rec.device)
        return dists_func(rec, ori).mean()

    def cal_total_loss(self, lpips, bpp, out_net):
        """Return ``bpp + out_net['lambda'] * out_net['task_loss']``.

        ``lpips`` is accepted for interface compatibility but is not used.
        """
        task_loss = out_net['task_loss']
        return bpp + out_net['lambda'] * task_loss

    def forward(self, ori, out_net, rec=None):
        """Compute the rate terms and, when ``rec`` is given, the quality metrics.

        Returns a dict with ``bpp``, ``bpp_mask`` and ``bpp_loss`` (their sum).
        When ``rec`` is given it also contains ``psnr``, ``msssim``, ``lpips``,
        ``dists`` (all computed on ``rec`` clamped to ``[0, 1]``) and
        ``total_loss``.
        """
        out = {}
        out["bpp"], out["bpp_mask"] = self.bpp_loss(ori, out_net)
        out["bpp_loss"] = out["bpp"] + out["bpp_mask"]

        if rec is not None:
            # Clamp once and reuse for every metric.
            rec_clamped = torch.clamp(rec, 0, 1)
            out["psnr"] = self.psnr(rec_clamped, ori)
            out["msssim"] = ms_ssim(rec_clamped, ori, data_range=1, size_average=True)
            out["lpips"] = self.lpips(rec_clamped, ori)
            out["dists"] = self.dists(rec_clamped, ori)
            out["total_loss"] = self.cal_total_loss(out["lpips"], out["bpp_loss"], out_net)
        return out
|
|
|
|
|
|
|
|
class TaskLoss(nn.Module):
    """Feature-space distortion measured with a frozen detectron2 ResNet-FPN backbone.

    ``forward`` returns the mean, over FPN levels p2..p6, of the MSE between
    backbone features of the original image and of the clamped reconstruction.
    """

    def __init__(self, cfg, device) -> None:
        super().__init__()
        self.ce = nn.CrossEntropyLoss()  # not used by forward()
        self.task_net = build_resnet_fpn_backbone(cfg, ShapeSpec(channels=3))

        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file; cfg.MODEL.WEIGHTS must point at a trusted checkpoint.
        with open(cfg.MODEL.WEIGHTS, 'rb') as fh:
            raw_ckpt = pickle.load(fh)
        # Keep only entries whose key mentions 'backbone' and drop the first
        # dotted component of each key (e.g. 'backbone.x.y' -> 'x.y').
        backbone_state = OrderedDict(
            (name.partition('.')[2], torch.from_numpy(array))
            for name, array in raw_ckpt['model'].items()
            if 'backbone' in name
        )
        self.task_net.load_state_dict(backbone_state, strict=True)
        self.task_net = self.task_net.to(device)

        # Frozen feature extractor: no parameter updates, eval-mode layers.
        self.task_net.requires_grad_(False)
        self.task_net.eval()

        self.align = Alignment(divisor=32).to(device)
        self.pixel_mean = torch.Tensor([103.530, 116.280, 123.675]).view(-1, 1, 1).to(device)

    def _to_backbone_input(self, img, train_mode):
        """Flip dim 1, scale by 255, subtract ``pixel_mean``; align unless ``train_mode``.

        Presumably converts RGB in [0, 1] to the BGR 0-255 layout the backbone
        weights expect - TODO confirm against the checkpoint's preprocessing.
        """
        shifted = img.flip(1).mul(255) - self.pixel_mean
        return shifted if train_mode else self.align.align(shifted)

    def forward(self, output, d, train_mode=False):
        """Return the FPN feature distortion between ``d`` and ``output["x_hat"]``."""
        # Reference features from the original image do not need gradients.
        with torch.no_grad():
            gt_feats = self.task_net(self._to_backbone_input(d, train_mode))

        x_hat = torch.clamp(output["x_hat"], 0, 1)
        rec_feats = self.task_net(self._to_backbone_input(x_hat, train_mode))

        levels = ("p2", "p3", "p4", "p5", "p6")
        summed = sum(
            nn.MSELoss(reduction='none')(gt_feats[lvl], rec_feats[lvl]).mean()
            for lvl in levels
        )
        # 0.2 == 1 / len(levels): average over the five pyramid levels.
        return 0.2 * summed
|
|
|
|
|
|
|
|
class AverageMeter:
    """Keep the latest value plus a count-weighted running mean."""

    def __init__(self):
        # latest value, running mean, weighted sum, total weight
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh ``avg``."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
|
|
|
|
|
|
|
|
def init(args):
    """Create (if missing) and return the experiment directory ``<root>/<exp_name>/``.

    The returned path always ends with ``/`` so file names can be appended directly.
    """
    experiment_dir = '{}/{}/'.format(args.root, args.exp_name)
    os.makedirs(experiment_dir, exist_ok=True)
    return experiment_dir
|
|
|
|
|
def setup_logger(log_dir):
    """Route INFO-level root-logger output to the file ``log_dir`` and to stdout.

    Args:
        log_dir: path of the log *file* (despite the name, not a directory).

    Each call adds two new handlers to the root logger, so calling it twice
    duplicates every log line.
    """
    fmt = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    handlers = (
        logging.FileHandler(log_dir, encoding='utf-8'),
        logging.StreamHandler(sys.stdout),
    )
    for handler in handlers:
        handler.setFormatter(fmt)
        root.addHandler(handler)

    logging.info('Logging file is %s' % log_dir)
|
|
|
|
|
def save_img(img: torch.Tensor, vis_path, input_p, mask=False):
    """Save ``img`` into ``vis_path`` under the file name of ``input_p``.

    Args:
        img: image tensor (or batch) accepted by ``torchvision.utils.save_image``.
        vis_path: output directory; created if it does not exist.
        input_p: path of the source image; only its final component is used.
        mask: if True, the output file name is prefixed with ``mask_``.
    """
    img = img.clone().detach().to(torch.device('cpu'))
    os.makedirs(vis_path, exist_ok=True)
    # Use only the last path component. The previous slice from rfind('/')
    # kept the leading '/', which wrote masks into a non-existent 'mask_/'
    # sub-directory, doubled a relative vis_path via os.path.join, and kept
    # only the last character when input_p contained no '/'.
    base_name = os.path.basename(str(input_p))
    if mask:
        base_name = 'mask_' + base_name
    vutils.save_image(img, os.path.join(vis_path, base_name), nrow=8)
|
|
|
|
|
def train_one_epoch(model, data_loader, metrics_criterion, device,
                    optimizer, epoch, loss_scaler, log_writer, args, val_dataloader=None, stage='train'):
    """Run one training epoch of the masked-token codec and return averaged meters.

    NOTE(review): a second ``train_one_epoch`` with a different signature is
    defined later in this module and replaces this one at import time.

    Args:
        model: the codec, optionally wrapped in DistributedDataParallel.
        data_loader: yields ``(samples, _)`` batches.
        metrics_criterion: a ``CalMetrics``-like callable returning ``total_loss``,
            ``bpp_loss``, ``bpp_mask``, ``lpips`` and ``dists``.
        device: device the samples are moved to.
        optimizer, loss_scaler: optimizer and AMP gradient scaler/clipper.
        epoch: current epoch index (used by the LR schedule and for logging).
        log_writer: TensorBoard ``SummaryWriter`` or ``None``.
        args: needs ``accum_iter`` and ``grad_clip`` plus whatever
            ``lr_sched.adjust_learning_rate`` reads.
        val_dataloader: unused.
        stage: sub-directory of ``./MIM_vbr/`` for visualisations.

    Returns:
        ``{meter_name: global_avg}`` from the metric logger.
    """
    model.train(True)
    # DDP exposes the wrapped network as `.module`; fall back to the model
    # itself so single-process (non-DDP) training also works.
    core_model = getattr(model, 'module', model)

    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20
    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    vis_path = os.path.join("./MIM_vbr/", stage)
    os.makedirs(vis_path, exist_ok=True)

    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        samples = samples.to(device, non_blocking=True)

        # Per-iteration learning-rate schedule, updated once per accumulation cycle.
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        with torch.cuda.amp.autocast():
            out_net = model(samples, is_training=True, manual_mask_rate=None)

        rec = core_model.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'])
        rec = rec.to(device)
        out_criterion = metrics_criterion(samples, out_net, rec)
        loss_value = out_criterion['total_loss'].item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Gradient accumulation: scale the loss and only step every accum_iter batches.
        out_criterion['total_loss'] /= accum_iter
        update_grad = (data_iter_step + 1) % accum_iter == 0
        loss_scaler(out_criterion['total_loss'], optimizer, clip_grad=args.grad_clip,
                    parameters=model.parameters(), update_grad=update_grad)
        if update_grad:
            optimizer.zero_grad()

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)
        metric_logger.update(bpp=out_criterion['bpp_loss'])
        metric_logger.update(bpp_mask=out_criterion['bpp_mask'])
        metric_logger.update(task_loss=out_net['task_loss'].item())
        metric_logger.update(lmbda=out_net['lambda'])
        metric_logger.update(mask_ratio=out_net['mask_ratio'])
        metric_logger.update(lpips=out_criterion['lpips'].item())
        metric_logger.update(dists=out_criterion['dists'].item())

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and update_grad:
            # epoch_1000x is the TensorBoard x-axis; it keeps curves aligned
            # when the batch size (and thus iterations per epoch) changes.
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

        if data_iter_step % 1000 == 0:
            with torch.no_grad():
                real_fake_images = torch.cat((samples, rec), dim=0)
                vutils.save_image(real_fake_images, os.path.join(vis_path, f"{epoch}_{data_iter_step}.jpg"), nrow=8)
                vutils.save_image(out_net['mask_vis'], os.path.join(vis_path, f"{epoch}_{data_iter_step}_mask.jpg"), nrow=8)

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
|
|
|
|
def train_one_epoch(train_dataloader, optimizer, model, criterion_rd, criterion_task, lmbda):
    """Legacy training loop minimising ``task_loss + lmbda * bpp_loss``.

    NOTE(review): this definition shadows the earlier ``train_one_epoch`` in
    this module, so the module-level name refers to this version.

    Args:
        train_dataloader: yields image batches (moved to the model's device).
        optimizer: optimizer stepped once per batch.
        model: codec returning an ``out_net`` dict.
        criterion_rd: returns a dict containing ``'bpp_loss'``.
        criterion_task: returns the task/perceptual distortion tensor.
        lmbda: weight of the rate term.
    """
    model.train()
    device = next(model.parameters()).device
    progress = tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False)

    for step, batch in progress:
        batch = batch.to(device)
        optimizer.zero_grad()

        out_net = model(batch)
        rd_terms = criterion_rd(out_net, batch)
        perc_loss = criterion_task(out_net, batch)
        total_loss = perc_loss + lmbda * rd_terms['bpp_loss']

        total_loss.backward()
        optimizer.step()

        status = f'[{step*len(batch)}/{len(train_dataloader.dataset)}] | Loss: {total_loss.item():.3f} | Distortion loss: {perc_loss.item():.5f} | Bpp loss: {rd_terms["bpp_loss"].item():.4f}'
        progress.set_postfix_str(status, refresh=True)
|
|
|
|
|
|
|
|
def validation_epoch(epoch, val_dataloader, model, criterion_rd, criterion_task, lmbda):
    """Evaluate rate, PSNR and task distortion on ``val_dataloader``.

    Each image is resize-aligned to a multiple of 256, compressed, resized
    back and clamped to ``[0, 1]``. The model is left in training mode on return.

    Args:
        criterion_rd: returns a dict with ``bpp_loss``, ``mse_loss`` and ``psnr``.
        criterion_task: returns the task/perceptual distortion tensor.
        lmbda: weight of the rate term in the total loss.

    Returns:
        Mean of ``perc_loss + lmbda * bpp_loss`` over all batches.
    """
    model.eval()
    device = next(model.parameters()).device

    bpp_loss = AverageMeter()
    mse_loss = AverageMeter()
    psnr = AverageMeter()
    percloss = AverageMeter()
    totalloss = AverageMeter()

    # The alignment helper is loop-invariant: build it once rather than per batch.
    align = Alignment(divisor=256, mode='resize').to(device)

    with torch.no_grad():
        tqdm_meter = tqdm.tqdm(enumerate(val_dataloader), leave=False, total=len(val_dataloader))
        for i, d in tqdm_meter:
            d = d.to(device)
            align_d = align.align(d)

            out_net = model(align_d)
            out_net['x_hat'] = align.resume(out_net['x_hat']).clamp_(0, 1)
            out_criterion = criterion_rd(out_net, d)
            perc_loss = criterion_task(out_net, d)
            total_loss = perc_loss + lmbda * out_criterion['bpp_loss']

            bpp_loss.update(out_criterion["bpp_loss"])
            mse_loss.update(out_criterion["mse_loss"])
            psnr.update(out_criterion['psnr'])
            percloss.update(perc_loss)
            totalloss.update(total_loss)

            txt = f"Loss: {totalloss.avg:.3f} | MSE loss: {mse_loss.avg:.5f} | Perception loss: {percloss.avg:.4f} | Bpp loss: {bpp_loss.avg:.4f}"
            tqdm_meter.set_postfix_str(txt)

    model.train()
    print(f"Epoch: {epoch} | bpp loss: {bpp_loss.avg:.5f} | psnr: {psnr.avg:.5f}")
    return totalloss.avg
|
|
|
|
|
|
|
|
def test_epoch(test_dataloader, model, criterion_rd, predictor, evaluator):
    """Compress each test image, run ``predictor`` on the reconstruction and evaluate.

    Each batch is assumed to hold one detectron2-style record with
    ``file_name`` and ``image`` entries (only ``batch[0]`` is used - TODO
    confirm batch size 1). The model is left in training mode on return.

    Args:
        criterion_rd: returns a dict with ``bpp_loss`` and ``psnr``.
        predictor: called as ``predictor(img, features)``; its ``model`` is put
            in eval mode for the duration of each batch.
        evaluator: detectron2-style evaluator (``process`` / ``evaluate``).

    Returns:
        The result of ``evaluator.evaluate()``. It was previously computed and
        discarded, and ``None`` was returned.
    """
    model.eval()
    device = next(model.parameters()).device
    pixel_mean = torch.Tensor([103.530, 116.280, 123.675]).view(-1, 1, 1).to(device)

    bpp_loss = AverageMeter()
    psnr = AverageMeter()

    # Loop-invariant alignment helpers, built once instead of per batch.
    align = Alignment(divisor=256, mode='resize').to(device)
    rcnn_align = Alignment(divisor=32).to(device)

    with torch.no_grad():
        tqdm_meter = tqdm.tqdm(enumerate(test_dataloader), leave=False, total=len(test_dataloader))
        for i, batch in tqdm_meter:
            with ExitStack() as stack:
                if isinstance(predictor.model, nn.Module):
                    stack.enter_context(inference_context(predictor.model))
                stack.enter_context(torch.no_grad())

                img = read_image(batch[0]["file_name"], format="BGR")
                # Stored image -> float in [0, 1] with the channel axis flipped.
                d = torch.stack([batch[0]['image'].float().div(255)]).flip(1).to(device)
                align_d = align.align(d)

                out_net = model(align_d)
                out_net['x_hat'] = align.resume(out_net['x_hat']).clamp_(0, 1)
                out_criterion = criterion_rd(out_net, d)

                # Reconstruction in the detector's input layout: flipped channels,
                # 0-255 scale, mean-subtracted, padded to a multiple of 32.
                trand_y_tilde = out_net['x_hat'].flip(1).mul(255)
                trand_y_tilde = rcnn_align.align(trand_y_tilde - pixel_mean)

                bpp_loss.update(out_criterion["bpp_loss"])
                psnr.update(out_criterion['psnr'])

                predictions = predictor(img, trand_y_tilde)
                evaluator.process(batch, [predictions])
                txt = f"Bpp loss: {bpp_loss.avg:.4f} | PSNR loss: {psnr.avg:.4f}"
                tqdm_meter.set_postfix_str(txt)

    results = evaluator.evaluate()
    model.train()
    print(f"bpp loss: {bpp_loss.avg:.5f} | psnr: {psnr.avg:.5f}")
    return results
|
|
|
|
|
|
|
|
def inference(epoch, test_loader, model, metrics_criterion, device, manual_mask_ratio, args, stage='test'):
    """Encode/decode every image in ``test_loader`` at a fixed mask ratio and log metrics.

    Images are split into padded patches, coded, reconstructed with
    ``gen_img(num_iter=20)`` and stitched back.

    Args:
        epoch: label used in file names and in the log line.
        test_loader: yields image batches (presumably batch size 1, since the
            reconstruction is re-batched with ``unsqueeze(0)`` - TODO confirm).
        model: the codec, optionally wrapped in DistributedDataParallel.
        metrics_criterion: a ``CalMetrics``-like callable.
        device: device images are moved to.
        manual_mask_ratio: mask ratio passed to the model.
        args: optional ``rec_save_dir`` attribute overrides the root directory
            for reconstructions saved when ``stage == 'test'``.
        stage: ``'val'`` saves side-by-side and mask visualisations under
            ``./MIM_test_high_resolu/val``. ``'test'`` saves reconstructions
            under ``<rec_save_dir>/<manual_mask_ratio>/``.

    Returns:
        Mean ``total_loss`` over the loader. The model is left in training mode.
    """
    model.eval()
    # gen_img lives on the wrapped network under DDP; fall back for plain models.
    core_model = getattr(model, 'module', model)

    bpp_loss = AverageMeter()
    bpp_mask = AverageMeter()
    psnr = AverageMeter()
    msssim = AverageMeter()
    lpips = AverageMeter()
    dists = AverageMeter()
    test_loss = AverageMeter()

    vis_path = os.path.join("./MIM_test_high_resolu/", stage)
    os.makedirs(vis_path, exist_ok=True)
    if stage == 'test':
        rec_root = getattr(args, 'rec_save_dir', "/home/v-ruoyufeng/v-ruoyufeng/qyp/rec_fid")
        # os.path.join requires str; the ratio is typically a float such as 0.5.
        test_vis_path = os.path.join(rec_root, str(manual_mask_ratio))
        os.makedirs(test_vis_path, exist_ok=True)

    with torch.no_grad():
        tqdm_meter = tqdm.tqdm(enumerate(test_loader), leave=False, total=len(test_loader))
        for i, d in tqdm_meter:
            d = d.to(device)
            d0 = d  # full-size original, used for metrics and visualisation
            d, patch_sizes, num_blocks_h, num_blocks_w = adaptively_split_and_pad(d)

            out_net = model(d, is_training=False, manual_mask_rate=manual_mask_ratio)

            rec = core_model.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'], num_iter=20)
            rec = crop_and_reconstruct(rec, patch_sizes, num_blocks_h, num_blocks_w)
            rec = rec.unsqueeze(0).to(device)

            out_criterion = metrics_criterion(d0, out_net, rec)

            bpp_loss.update(out_criterion["bpp_loss"])
            bpp_mask.update(out_criterion["bpp_mask"])
            psnr.update(out_criterion['psnr'])
            msssim.update(out_criterion['msssim'])
            lpips.update(out_criterion['lpips'])
            dists.update(out_criterion['dists'])
            test_loss.update(out_criterion['total_loss'])

            if stage == 'val':
                real_fake_images = torch.cat((d0, rec), dim=0)
                vutils.save_image(real_fake_images, os.path.join(vis_path, f"{epoch}_{i}.jpg"))
                vutils.save_image(out_net['mask_vis'], os.path.join(vis_path, f"{epoch}_{i}_mask.jpg"))
            if stage == 'test':
                vutils.save_image(rec, os.path.join(test_vis_path, f"{i}.jpg"), nrow=8)

    model.train()

    rank = dist.get_rank() if torch.distributed.is_initialized() else 0
    if rank == 0:
        # float() accepts 0-dim tensors and plain numbers alike (.item() does not).
        log_txt = f"{epoch}|bpp:{float(bpp_loss.avg):.5f}|mask:{bpp_mask.avg:.5f}|mask_ratio:{manual_mask_ratio}|psnr:{float(psnr.avg):.5f}|msssim:{float(msssim.avg):.5f}|lpips:{float(lpips.avg):.5f}|dists:{float(dists.avg):.5f}|Test loss:{float(test_loss.avg):.5f}"
        logging.info(log_txt)

    return test_loss.avg
|
|
|
|
|
|
|
|
def save_checkpoint(state, is_best, base_dir, filename="checkpoint.pth.tar"):
    """Write ``state`` to ``base_dir + filename``, plus ``checkpoint_best.pth.tar`` if ``is_best``.

    ``base_dir`` is concatenated as-is, so it should end with a path separator.
    """
    targets = [filename]
    if is_best:
        targets.append("checkpoint_best.pth.tar")
    for name in targets:
        torch.save(state, base_dir + name)
|
|
|
|
|
def parse_args(argv):
    """Parse command-line arguments layered on top of a YAML config file.

    Precedence, highest first:
      1. flags given explicitly on the command line;
      2. keys of the YAML file named by ``-c/--config``;
      3. the hard-coded defaults below.
    YAML keys without a matching flag become plain attributes of the result.

    Args:
        argv: argument list without the program name (e.g. ``sys.argv[1:]``).

    Returns:
        The populated ``argparse.Namespace``.
    """
    def _str2bool(value):
        # Interpret common textual booleans; any other string is an error
        # rather than silently truthy.
        text = value.strip().lower()
        if text in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if text in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')

    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-c",
        "--config",
        default="config/vpt_default.yaml",
        help="Path to config file",
    )
    parser.add_argument(
        '--name',
        default=datetime.now().strftime('%Y-%m-%d_%H_%M_%S'),
        type=str,
        help='Result dir name',
    )
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    # First pass: only needed to locate the config file.
    given_configs, _ = parser.parse_known_args(argv)

    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local-rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    with open(given_configs.config) as file:
        yaml_data = yaml.safe_load(file)
    # An empty YAML file loads as None.
    parser.set_defaults(**(yaml_data or {}))

    # `-T` alone means True. `-T true/false` is still accepted for backward
    # compatibility; previously any string, including "False", was truthy.
    parser.add_argument(
        "-T",
        "--TEST",
        nargs='?',
        const=True,
        default=False,
        type=_str2bool,
        help='Testing'
    )

    # Second pass over the *full* argv. Parsing only the leftover arguments
    # would drop --config/--name/--lr values consumed by the first pass,
    # letting YAML or built-in defaults override explicit CLI flags.
    args = parser.parse_args(argv)
    return args
|
|
|
|
|
|
|
|
def main(argv):
    """Entry point: set up distributed state, logging, data and model, then run a pre-validation pass.

    Despite the training setup (optimizer, loss scaler, checkpoint loading),
    the visible body only runs ``inference`` once on the validation split
    (``stage='val'``, mask ratio 0.5); no training loop is executed here.

    Args:
        argv: command-line arguments without the program name.
    """
    args = parse_args(argv)
    base_dir = init(args)

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        args.log_dir = args.output_dir

    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    # NOTE(review): this device is overwritten further down by the
    # args.cuda-based expression; the later value is the one actually used.
    device = torch.device(args.device)

    # Per-rank seed so each distributed process draws different random numbers.
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    cudnn.benchmark = True

    setup_logger(base_dir + '/' + time.strftime('%Y%m%d_%H%M%S') + '.log')
    msg = f'======================= {args.name} ======================='
    logging.info(msg)
    for k in args.__dict__:
        logging.info(k + ':' + str(args.__dict__[k]))
    logging.info('=' * len(msg))

    transform_det = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])
    transform_val = transforms.Compose([
        transforms.ToTensor()
    ])

    # NOTE(review): datasets are only built for args.dataset == 'coco'; any
    # other value leaves val_dataset undefined and fails below with NameError.
    # The image-list path is hard-coded to one machine.
    if args.dataset=='coco':
        train_dataset = MSCOCO(args.dataset_path + "/train2017/",
                               transform_det,
                               "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/util/img_list.txt")

        val_dataset = MSCOCO(args.kodak_path, transform_val)

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    # NOTE(review): `if True` makes the else branch dead code; sampler_train is
    # never used and global_rank is only defined in the taken branch.
    if True:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_val = torch.utils.data.DistributedSampler(
            val_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    else:
        sampler_train = torch.utils.data.RandomSampler(train_dataset)

    # Only rank 0 writes TensorBoard logs.
    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    val_dataloader = DataLoader(val_dataset, sampler=sampler_val, batch_size=1,
                                num_workers=args.num_workers, shuffle=False, pin_memory=args.pin_mem, drop_last=True)

    # NOTE(review): hard-coded, machine-specific VQGAN checkpoint path.
    vqgan_ckpt_path = '/home/t2vg-a100-G4-10/project/qyp/mage/vqgan_jax_strongaug.ckpt'
    model = models_mage_codec_high_resolu.__dict__[args.model](mask_ratio_mu=args.mask_ratio_mu, mask_ratio_std=args.mask_ratio_std,
                                                              mask_ratio_min=args.mask_ratio_min, mask_ratio_max=args.mask_ratio_max,
                                                              vqgan_ckpt_path=vqgan_ckpt_path)

    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    # Linear LR scaling: absolute lr = blr * effective_batch / 256 unless --lr/YAML sets it.
    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    if args.lr is None:
        args.lr = args.blr * eff_batch_size / 256
    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    # Weight decay is applied via timm's parameter grouping.
    param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    # Resume model/optimizer/scaler state if args points at a checkpoint (non-strict).
    misc.load_model(args=args, model_without_ddp=model_without_ddp,
                    optimizer=optimizer, loss_scaler=loss_scaler, strict=False)

    metrics_criterion = CalMetrics()

    last_epoch = args.start_epoch

    # NOTE(review): best_loss, tqrange and test_loss are assigned but unused here.
    print("############## pre validation ##############")
    best_loss = float("inf")
    tqrange = tqdm.trange(last_epoch, args.epochs)
    val_mask_ratio = 0.5
    test_loss = inference(-1, val_dataloader, model, metrics_criterion, device, val_mask_ratio, args, 'val')
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Forward the CLI arguments (without the program name) to main().
    cli_args = sys.argv[1:]
    main(cli_args)