|
|
|
|
|
|
|
|
import argparse |
|
|
import datetime |
|
|
import numpy as np |
|
|
import time |
|
|
import torch |
|
|
import torch.backends.cudnn as cudnn |
|
|
import json |
|
|
|
|
|
from pathlib import Path |
|
|
|
|
|
from timm.models import create_model |
|
|
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy |
|
|
from timm.scheduler import create_scheduler |
|
|
from timm.optim import create_optimizer |
|
|
from timm.utils import NativeScaler, get_state_dict, ModelEma |
|
|
from augmentations import collate_data_and_cast_aug |
|
|
from datasets import build_dataset |
|
|
|
|
|
from losses_hint import DistillationLoss |
|
|
from samplers import RASampler |
|
|
from functools import partial |
|
|
|
|
|
import importlib |
|
|
import utils |
|
|
import random |
|
|
import math |
|
|
from multiprocessing import Value |
|
|
from abc import ABC |
|
|
|
|
|
import sys |
|
|
from typing import Iterable, Optional |
|
|
from timm.data import Mixup |
|
|
from timm.utils import accuracy, ModelEma |
|
|
import utils |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from torchvision.utils import make_grid |
|
|
|
|
|
class MaskingGenerator(ABC):
    """Base class for objects that produce boolean masks over a 2-D grid.

    The grid is ``height`` x ``width`` cells. ``__repr__`` and ``_mask`` are
    placeholders that raise ``NotImplementedError``; concrete generators are
    expected to supply their own.
    """

    def __init__(self, input_size):
        """Store the grid dimensions.

        Args:
            input_size: a ``(height, width)`` tuple, or a single value that
                is used for both dimensions.
        """
        if isinstance(input_size, tuple):
            grid = input_size
        else:
            # Any non-tuple value is treated as the side of a square grid.
            grid = (input_size, input_size)
        self.height, self.width = grid
        self.num_patches = self.height * self.width

    def __repr__(self):
        raise NotImplementedError

    def get_shape(self):
        """Return the grid as a ``(height, width)`` tuple."""
        return (self.height, self.width)

    def _mask(self, mask, max_mask_patches):
        raise NotImplementedError

    def get_none_mask(self):
        """Return an all-``False`` boolean array with the grid's shape."""
        return np.zeros(self.get_shape(), dtype=bool)
|
|
|
|
|
|
|
|
|
|
|
class RandomMaskingGenerator(MaskingGenerator):
    """Mask generator that marks a requested number of grid cells at random."""

    def __init__(self, input_size):
        """
        Args:
            input_size: the size of the token map, e.g., 14x14
        """
        super().__init__(input_size)

    def __repr__(self):
        return f"Random Generator({self.height}, {self.width})"

    def _mask(self, mask, max_mask_patches):
        # Simply forwards to the parent class implementation.
        return super()._mask(mask, max_mask_patches)

    def __call__(self, num_masking_patches=0):
        """Return a boolean grid with ``num_masking_patches`` cells set.

        Args:
            num_masking_patches: how many cells to set to ``True``. Values
                ``<= 0`` give an all-``False`` grid.

        Returns:
            Boolean ``numpy`` array of shape ``(height, width)``. The
            positions of the ``True`` cells are drawn with NumPy's global
            random state (``np.random.shuffle``).
        """
        grid_shape = self.get_shape()
        if num_masking_patches <= 0:
            return np.zeros(shape=grid_shape, dtype=bool)

        # Lay out the masked cells first, then the unmasked remainder, and
        # shuffle the flat vector in place before folding it into the grid.
        masked = np.ones(num_masking_patches, dtype=bool)
        unmasked = np.zeros(self.num_patches - num_masking_patches, dtype=bool)
        flat = np.concatenate([masked, unmasked])
        np.random.shuffle(flat)
        return flat.reshape(grid_shape)
|
|
|
|
|
|
|
|
def get_args_parser():
    """Build the command-line option set for this script.

    The parser is created with ``add_help=False`` so it can be used as a
    ``parents=[...]`` entry of another ``argparse.ArgumentParser`` (which then
    contributes ``-h/--help``).

    Returns:
        argparse.ArgumentParser: parser with every option registered.
    """
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--bce-loss', action='store_true')
    parser.add_argument('--unscale-lr', action='store_true')

    # --- model ---
    parser.add_argument('--model', default='deit_base_patch16_224', type=str)
    parser.add_argument('--target_model', default='deit_base_patch16_224', type=str)
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    # --- model EMA (on by default; --no-model-ema switches it off) ---
    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # --- optimizer ---
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # --- learning-rate schedule ---
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # --- augmentation ---
    parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
                        help='Color jitter factor (default: 0.3)')
    # The help text used to carry a literal '" + \' line continuation inside
    # the string, and the call ended in a stray comma; both are removed.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)

    parser.add_argument('--ThreeAugment', action='store_true')

    parser.add_argument('--src', action='store_true')

    # --- crop / patch geometry ---
    parser.add_argument('--global_crops_size', '--img_size', default=224, type=int,
                        help="this should be equal to image size")
    parser.add_argument('--patch_size', default=16, type=int,
                        help="patch size for vit patch embedding")

    # --- masking ---
    # Note: the default is a tuple, while values given on the command line
    # arrive as a list (nargs='+').
    parser.add_argument('--mask_ratio', default=(0.1, 0.5), type=float, nargs='+',
                        help="mask ratio can be either a value or a range")
    parser.add_argument('--mask_probability', default=0., type=float,
                        help="how many samples will be applied with masking")
    parser.add_argument('--mask_first_n', action='store_true',
                        help="mask the first n sample to avoid shuffling. Needed for MAE-style encoder")
    parser.add_argument('--clone_batch', default=1, type=int,
                        help="how many times to clone the batch for masking (default: 1, not cloning)")

    # --- random erase ---
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # --- mixup / cutmix ---
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # --- distillation ---
    parser.add_argument('--teacher-model', default='base', type=str)
    parser.add_argument('--teacher-path', type=str, default='')
    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
    parser.add_argument('--lambda_token', type=float, default=1.0)
    parser.add_argument('--lambda_fea', type=float, default=1.0)
    parser.add_argument('--lambda_patch', type=float, default=1.0)

    parser.add_argument('--cosub', action='store_true')

    # --- fine-tuning ---
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--attn-only', action='store_true')
    parser.add_argument('--weight_inherit', default='')

    # --- dataset ---
    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET',
                        choices=['CIFAR', 'IMNET', 'IMNET_ibot', 'IMNET_ibot_aug', 'IMNET_ibot_fast_aug',
                                 'INAT', 'INAT19', 'IMNET_L', 'IMNET_L_ibot'],
                        type=str, help='dataset to use')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    # --- runtime ---
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # --- distributed ---
    parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser
|
|
|
|
|
import torchvision |
|
|
import matplotlib.pyplot as plt |
|
|
def visualize_features(features, output_path='./feature_visualization_settings_all.png'):
    """Write a viridis-coloured picture of channel-averaged feature maps.

    NOTE: this module defines ``visualize_features`` a second time further
    down; in Python the later definition replaces this one once the module
    has been executed.

    Args:
        features: 4-D tensor, unpacked as (batch, channels, height, width).
        output_path: file the image grid is written to via
            ``torchvision.utils.save_image``.

    Raises:
        ValueError: if ``features`` is not 4-dimensional.
    """
    if features.dim() != 4:
        raise ValueError(
            f"features must be 4-D (batch, channels, height, width), got {tuple(features.shape)}")

    # Average over the channel axis, then min-max scale across the whole batch.
    vis = features.detach().float().mean(dim=1)
    vis = vis - vis.min()
    peak = vis.max()
    # A constant feature map has peak == 0; dividing would fill the picture
    # with NaNs, so the (all-zero) map is left as is in that case.
    if peak > 0:
        vis = vis / peak

    # The colormap accepts an array of any shape and appends an RGBA axis;
    # keep RGB only and move channels first: (B, H, W, 3) -> (B, 3, H, W).
    rgb = plt.cm.viridis(vis.cpu().numpy())[..., :3]
    vis_colored = torch.from_numpy(rgb).permute(0, 3, 1, 2)

    torchvision.utils.save_image(vis_colored, output_path, normalize=True)
|
|
|
|
|
import torch.nn as nn |
|
|
|
|
|
class ResBlock(nn.Module):
    """Residual unit: conv3x3-BN-ReLU, conv3x3-BN, add shortcut, ReLU.

    The shortcut is a 1x1 convolution followed by batch-norm whenever the
    stride or the channel count changes; otherwise it is an empty
    ``nn.Sequential`` that returns its input unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        """
        Args:
            in_channels: channels of the input tensor.
            out_channels: channels produced by the block.
            stride: stride of the first convolution (and of the shortcut
                projection, when one is needed).
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_preserved = stride == 1 and in_channels == out_channels
        projection = []
        if not shape_preserved:
            # Match the main branch's channel count and spatial size.
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum."""
        branch = self.conv1(x)
        branch = self.bn1(branch)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        branch = self.bn2(branch)
        summed = branch + self.shortcut(x)
        return self.relu(summed)
|
|
|
|
|
|
|
|
class Decoder(nn.Module):
    """Convolutional decoder that maps a feature map to an image-like tensor.

    ``forward`` first resizes the features to ``(h // 8, w // 8)``. The three
    stride-2 transposed convolutions (``initial``, ``up1``, ``up2``; kernel 4,
    padding 1) each double the spatial size, so the output is
    ``8 * (h // 8)`` by ``8 * (w // 8)``. ``up3`` is a plain 3x3 convolution
    and does not change the resolution.
    """

    def __init__(self, feature_dim, output_channels=3):
        """
        Args:
            feature_dim: channel count of the incoming feature map.
            output_channels: channel count of the decoded output.
        """
        super().__init__()

        self.initial = nn.Sequential(
            nn.ConvTranspose2d(feature_dim, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        # Channel-reducing residual stages interleaved with upsampling.
        self.layer1 = ResBlock(512, 256, stride=1)
        self.up1 = nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1)
        self.layer2 = ResBlock(256, 128, stride=1)
        self.up2 = nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1)
        self.layer3 = ResBlock(128, 64, stride=1)

        self.up3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.final = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, output_channels, kernel_size=3, stride=1, padding=1),
        )

        # Convolutions applied before and after the initial resize.
        self.pre_conv = nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1)
        self.relu = nn.LeakyReLU(0.05, inplace=True)
        self.post_conv = nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x, h, w):
        """Decode ``x`` for a target of height ``h`` and width ``w``.

        Args:
            x: feature map with ``feature_dim`` channels.
            h: target height; the features are resized to ``h // 8`` rows.
            w: target width; the features are resized to ``w // 8`` columns.

        Returns:
            Tensor with ``output_channels`` channels.
        """
        x = self.relu(self.pre_conv(x))
        x = torch.nn.functional.interpolate(
            x, size=(h // 8, w // 8), mode='bicubic', align_corners=False)
        x = self.relu(self.post_conv(x))

        pipeline = (
            self.initial,
            self.layer1, self.up1,
            self.layer2, self.up2,
            self.layer3, self.up3,
            self.final,
        )
        for stage in pipeline:
            x = stage(x)
        return x
|
|
|
|
|
def cal_psnr(output, target, max_pixel=1.0):
    """Compute the peak signal-to-noise ratio between two tensors, in dB.

    Uses ``10 * log10(max_pixel ** 2 / mse)`` with the mean squared error
    taken over every element of ``output - target``.

    Args:
        output: predicted tensor.
        target: reference tensor, broadcastable against ``output``.
        max_pixel: largest possible signal value (default ``1.0``).

    Returns:
        0-dim tensor holding the PSNR. When the two inputs are identical
        (``mse == 0``) the ratio is unbounded and a tensor holding ``100`` is
        returned instead.
    """
    mse = torch.mean((output - target) ** 2)
    if mse == 0:
        # Return a tensor (not a bare int) so callers can always use
        # ``.item()`` / tensor methods on the result.
        return mse.new_tensor(100.0)
    return 10 * torch.log10(max_pixel ** 2 / mse)
|
|
|
|
|
import glob |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
from PIL import Image |
|
|
class MSCOCO(Dataset):
    """Image-only dataset over the files of one directory.

    The image paths are either every ``*.jpg`` directly inside ``root``
    (sorted), or ``root`` joined with each non-blank line of ``img_list``.
    """

    def __init__(self, root, transform, img_list=None):
        """
        Args:
            root: directory holding the images. A trailing ``/`` is added
                when missing.
            transform: callable applied to each loaded PIL image, or ``None``
                to return the image untouched.
            img_list: optional path of a text file naming one image per line
                (relative to ``root``). Blank lines are ignored.

        Raises:
            ValueError: if ``root`` is empty.
        """
        if not root:
            raise ValueError("root to COCO dataset must be a non-empty path")
        if not root.endswith('/'):
            # Paths are built by plain concatenation below, so make sure the
            # separator is there instead of rejecting the argument.
            root = root + '/'

        if img_list:
            with open(img_list, 'r') as listing:
                names = listing.read().splitlines()
            # A blank line would otherwise turn into the bare directory path.
            self.image_paths = [root + name for name in names if name.strip()]
        else:
            self.image_paths = sorted(glob.glob(root + "*.jpg"))
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            object: image.
        """
        img_path = self.image_paths[index]

        # ``convert`` yields a fully loaded image, so the underlying file can
        # be closed as soon as the ``with`` block ends.
        with Image.open(img_path) as handle:
            img = handle.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        return len(self.image_paths)
|
|
|
|
|
def visualize_features(features, output_path='./feature_visualization_fast.png'):
    """Write a viridis-coloured picture of channel-averaged feature maps.

    Args:
        features: 4-D tensor, unpacked as (batch, channels, height, width).
        output_path: file the image grid is written to via
            ``torchvision.utils.save_image``.

    Raises:
        ValueError: if ``features`` is not 4-dimensional.
    """
    if features.dim() != 4:
        raise ValueError(
            f"features must be 4-D (batch, channels, height, width), got {tuple(features.shape)}")

    # Channel mean, then min-max scaling computed over the whole batch.
    vis = features.detach().float().mean(dim=1)
    vis = vis - vis.min()
    peak = vis.max()
    # Guard the division: a constant map has peak == 0 and would otherwise
    # turn into NaNs.
    if peak > 0:
        vis = vis / peak

    # One colormap call handles the whole (B, H, W) array and appends an RGBA
    # axis; drop alpha and move channels first for save_image.
    rgb = plt.cm.viridis(vis.cpu().numpy())[..., :3]
    vis_colored = torch.from_numpy(rgb).permute(0, 3, 1, 2)

    torchvision.utils.save_image(vis_colored, output_path, normalize=True)
|
|
|
|
|
|
|
|
def _load_finetune_weights(model, checkpoint_path):
    """Load weights from ``checkpoint_path`` into ``model`` non-strictly."""
    # NOTE(review): torch.load unpickles the file; only point --finetune at
    # checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # The weights may be nested under 'state_dict' or 'model', or be the
    # checkpoint object itself.
    for key in ('state_dict', 'model'):
        if key in checkpoint:
            pretrained_dict = checkpoint[key]
            break
    else:
        pretrained_dict = checkpoint

    missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
    print('missing_keys: ', missing_keys)
    print('unexpected_keys: ', unexpected_keys)


def _freeze_all_but_attention(model):
    """Leave only '.attn.' parameters, the head and ``pos_embed`` trainable."""
    for name_p, p in model.named_parameters():
        p.requires_grad = '.attn.' in name_p
    # Missing attributes surface as AttributeError, so catch exactly that
    # rather than using a bare ``except``.
    try:
        model.head.weight.requires_grad = True
        model.head.bias.requires_grad = True
    except AttributeError:
        model.fc.weight.requires_grad = True
        model.fc.bias.requires_grad = True
    try:
        model.pos_embed.requires_grad = True
    except AttributeError:
        print('no position encoding')
    try:
        for p in model.patch_embed.parameters():
            p.requires_grad = False
    except AttributeError:
        print('no patch embed')


def _encode_features(model, images):
    """Return normalised backbone features for ``images`` without gradients.

    ``model`` must be the unwrapped module (not the DDP wrapper) because
    ``student`` and ``info_bottleneck`` are read as attributes.
    """
    with torch.no_grad():
        tokens = model.student.backbone(images, is_training=True)['x_norm_patchtokens']
        tokens, _ = model.info_bottleneck(tokens, is_training=False)
        # Hard-coded 40x40 token grid — presumably 560-pixel crops with
        # 14-pixel patches; TODO confirm against the backbone in use.
        tokens = tokens.view(-1, 40, 40, tokens.shape[2])
        tokens = tokens.permute(0, 3, 1, 2)
        # Standardise with statistics of the whole tensor and clip outliers.
        tokens = (tokens - tokens.mean()) / tokens.std()
        return torch.clamp(tokens, -5, 5)


def main(args):
    """Fit a ``Decoder`` that reconstructs images from backbone features.

    The model named by ``args.model`` is imported, optionally initialised
    from ``args.finetune``, and put in eval mode; only the decoder's
    parameters are handed to the optimizer. Scalars and images are logged to
    TensorBoard under ``vis_logs`` and pictures/checkpoints are written below
    ``./vis_maefeats_decode_new`` (which is deleted and recreated on start).

    Args:
        args: namespace produced by ``get_args_parser``. ``args.lr`` is
            rescaled in place unless ``args.unscale_lr`` is set.
    """
    import os
    import shutil

    import torch.optim as optim
    from torchvision import transforms

    utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)

    # Offset the seed by rank so each process draws different random numbers.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    print(f"Creating model: {args.model}")
    meta_arch_module = importlib.import_module(args.model)
    MetaArch = meta_arch_module.MetaArch

    model = MetaArch(args)

    if args.finetune:
        _load_finetune_weights(model, args.finetune)

    if args.attn_only:
        _freeze_all_but_attention(model)

    model.to(device)

    # NOTE(review): the EMA copy is created but never read again in this
    # function; kept as is.
    model_ema = None
    if args.model_ema:
        model_ema = ModelEma(
            model.student.backbone,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    if not args.unscale_lr:
        args.lr = args.lr * args.batch_size * utils.get_world_size() / 512.0

    data_transforms = transforms.Compose([
        # interpolation=3 is the integer code for bicubic resampling.
        transforms.RandomResizedCrop(560, scale=(0.8, 1.0), interpolation=3),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    dataset_path = '/home/t2vg-a100-G4-1/projects/dataset/'
    dataset = MSCOCO(dataset_path+"/train2017/",
                     data_transforms)
    print(f"Loaded dataset with {len(dataset)} samples")

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
    print('Dataloader created')

    decoder = Decoder(feature_dim=768).to(device)
    optimizer = optim.Adam(decoder.parameters(), lr=5e-5)
    criterion = nn.MSELoss()

    writer = SummaryWriter(log_dir='vis_logs', flush_secs=30)

    saveroot = "./vis_maefeats_decode_new"
    shutil.rmtree(saveroot, ignore_errors=True)
    os.makedirs(saveroot, exist_ok=True)

    model.eval()
    num_epochs = 100   # fixed here; args.epochs is not consulted
    save_iter = 50000  # checkpoint the decoder every N iterations
    eval_iter = 100    # dump an input/reconstruction comparison every N iterations
    iteration = 0      # running count across all epochs

    try:
        for epoch in range(num_epochs):
            decoder.train()
            for idx, images in enumerate(dataloader):
                images = images.to(device)
                _, _, h, w = images.shape

                # Attribute access must go through the unwrapped module: the
                # DDP wrapper does not expose ``student``.
                features = _encode_features(model_without_ddp, images)

                reconstructed_images = decoder(features, h, w)

                loss = criterion(reconstructed_images, images)
                rec_detached = reconstructed_images.detach()
                # float() copes with both a tensor and a plain number.
                psnr_value = float(cal_psnr(rec_detached, images))
                loss_value = loss.item()

                # ``iteration`` already runs across epochs, so it is the
                # global step on its own; adding epoch * len(dataloader) on
                # top would count every finished epoch twice.
                writer.add_scalar('Train_loss', loss_value, iteration)
                writer.add_scalar('Train_psnr', psnr_value, iteration)
                writer.add_image("input", make_grid(images, nrow=4), iteration)
                writer.add_image("rec", make_grid(rec_detached, nrow=4), iteration)

                # NOTE(review): these two pictures are written on every
                # iteration, which produces a very large number of files.
                visualize_features(features, output_path=f'{saveroot}/features_{iteration}.png')
                # Store the reconstruction under the name that says so (the
                # input batch used to be written to this file).
                torchvision.utils.save_image(
                    rec_detached, f"{saveroot}/reconstructed_images_{iteration}.png", normalize=True)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                print(f"EPOCH [{epoch + 1}/{num_epochs}], ITERATION [{idx + 1}/{len(dataloader)}], "
                      f"LOSS: {loss_value}, PSNR: {psnr_value}", end='\r')
                iteration += 1

                if iteration % eval_iter == 0 or iteration == 1:
                    savedir_vis = os.path.join(saveroot, "vis")
                    os.makedirs(savedir_vis, exist_ok=True)
                    # Stack input over reconstruction along the height axis.
                    vis = torch.cat([images, rec_detached], dim=2)
                    torchvision.utils.save_image(
                        vis, f"{savedir_vis}/vis_{iteration}.png", normalize=True)

                if iteration % save_iter == 0:
                    savedir_model = os.path.join(saveroot, "model")
                    os.makedirs(savedir_model, exist_ok=True)
                    torch.save(decoder.state_dict(), f"{savedir_model}/iteration_{iteration}.pth")
    finally:
        # Flush and close the event file even when training is interrupted.
        writer.close()
|
|
|
|
|
if __name__ == '__main__':
    # The shared option set is built without -h/--help, so wrap it in a
    # top-level parser that supplies it.
    parser = argparse.ArgumentParser(
        'DeiT training and evaluation script',
        parents=[get_args_parser()],
    )
    args = parser.parse_args()

    # An empty --output_dir means nothing is created.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args)
|
|
|