HIT / utils /model_utils.py
yssszzzzzzzzy's picture
Upload 21 files
4a55589 verified
import torch
import torch.nn as nn
import os
from collections import OrderedDict
def freeze(model):
    """Turn off gradient tracking for every parameter of ``model``.

    Each tensor yielded by ``model.parameters()`` has its ``requires_grad``
    flag set to False in place; nothing is returned.
    """
    for weight in model.parameters():
        # In-place setter; equivalent to assigning ``requires_grad = False``.
        weight.requires_grad_(False)
def unfreeze(model):
    """Turn gradient tracking back on for every parameter of ``model``.

    Each tensor yielded by ``model.parameters()`` has its ``requires_grad``
    flag set to True in place; nothing is returned.
    """
    for weight in model.parameters():
        # In-place setter; equivalent to assigning ``requires_grad = True``.
        weight.requires_grad_(True)
def is_frozen(model):
    """Report whether at least one parameter of ``model`` is frozen.

    Returns True as soon as any parameter has ``requires_grad == False``.
    A partially frozen model therefore counts as frozen. A model with no
    parameters at all returns False.
    """
    return any(not weight.requires_grad for weight in model.parameters())
def save_checkpoint(model_dir, state, session):
    """Serialize ``state`` with ``torch.save`` into ``model_dir``.

    The file is named ``model_epoch_<epoch>_<session>.pth``, where
    ``<epoch>`` is read from ``state['epoch']`` (a KeyError propagates if
    it is missing).
    """
    filename = f"model_epoch_{state['epoch']}_{session}.pth"
    torch.save(state, os.path.join(model_dir, filename))
def load_checkpoint(model, weights):
    """Load ``checkpoint["state_dict"]`` from the file ``weights`` into ``model``.

    The state dict is first loaded unchanged. If ``load_state_dict`` rejects
    it with a ``RuntimeError`` (missing/unexpected keys or shape mismatch),
    the load is retried with a leading ``"module."`` prefix removed from
    every key that has one. That is the prefix present when a model wrapped
    in ``nn.DataParallel`` was saved.

    ``torch.load`` is called without ``map_location``.

    Raises:
        KeyError: the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: the keys still do not match after prefix stripping.
    """
    checkpoint = torch.load(weights)
    # Look up the state dict outside the try so a malformed checkpoint is
    # reported directly instead of being retried.
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # Only strip a *leading* prefix. A substring test would also match
            # keys such as "conv_module.weight" and then corrupt them.
            name = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
def load_checkpoint_multigpu(model, weights):
    """Load a possibly ``"module."``-prefixed state dict into an unwrapped ``model``.

    Reads ``checkpoint["state_dict"]`` from the file ``weights``. It removes
    a leading ``"module."`` from every key that has one, which is the prefix
    present when an ``nn.DataParallel``-wrapped model was saved. It then
    calls ``model.load_state_dict`` on the result. Keys without the prefix
    are kept as-is instead of having their first 7 characters cut off.

    Raises:
        KeyError: the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: the (stripped) keys do not match ``model``.
    """
    checkpoint = torch.load(weights)
    state_dict = checkpoint["state_dict"]
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
def load_start_epoch(weights):
    """Return the value stored under ``"epoch"`` in the checkpoint file ``weights``.

    The whole checkpoint is deserialized with ``torch.load`` just to read
    this one entry. A KeyError propagates if it is absent.
    """
    return torch.load(weights)["epoch"]
def load_optim(optimizer, weights):
    """Restore ``optimizer`` from ``checkpoint['optimizer']`` and report its learning rate.

    Args:
        optimizer: object exposing ``load_state_dict`` and ``param_groups``.
        weights: path passed to ``torch.load``.

    Returns:
        The ``'lr'`` of the *last* entry in ``optimizer.param_groups`` after
        the state has been loaded.
    """
    saved = torch.load(weights)
    optimizer.load_state_dict(saved['optimizer'])
    # Visit every group; the name is left bound to the final group's rate.
    for group in optimizer.param_groups:
        lr = group['lr']
    return lr
def get_arch(opt):
    """Construct the HIT restoration model variant selected by ``opt.arch``.

    Reads ``opt.arch`` and ``opt.train_ps``. HIT_S and HIT_B also read
    ``opt.dd_in``. Supported variants:

    - ``'HIT_T'``: ``embed_dim=16`` with HIT's default depths.
    - ``'HIT_S'``: ``embed_dim=32``, ``depths=[2]*9``.
    - ``'HIT_B'``: ``embed_dim=32``, ``depths=[1, 2, 8, 8, 2, 8, 8, 2, 1]``.

    All variants use ``win_size=8``, ``token_projection='linear'`` and
    ``token_mlp='leff'``.

    Raises:
        ValueError: ``opt.arch`` is not one of the supported variants
            (a subclass of ``Exception``, so existing handlers still work).
    """
    arch = opt.arch
    print('You choose '+arch+'...')
    supported = ('HIT_T', 'HIT_S', 'HIT_B')
    if arch not in supported:
        # Fail before importing the (project-local) network definition.
        raise ValueError("Arch error! Unsupported arch {!r}; expected one of {}".format(
            arch, ", ".join(supported)))

    from model import HIT

    # Hyper-parameters shared by every HIT variant.
    common = dict(img_size=opt.train_ps, win_size=8,
                  token_projection='linear', token_mlp='leff')
    if arch == 'HIT_T':
        # NOTE(review): unlike HIT_S/HIT_B, this variant does not forward
        # opt.dd_in, so HIT's default dd_in is used. Confirm this is intended.
        model_restoration = HIT(embed_dim=16, **common)
    elif arch == 'HIT_S':
        model_restoration = HIT(embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2, 2, 2],
                                dd_in=opt.dd_in, **common)
    else:  # 'HIT_B'
        model_restoration = HIT(embed_dim=32, depths=[1, 2, 8, 8, 2, 8, 8, 2, 1],
                                dd_in=opt.dd_in, **common)
    return model_restoration