Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import os | |
| from collections import OrderedDict | |
def freeze(model):
    """Turn off gradient tracking for every parameter of *model*, in place.

    Args:
        model: object exposing ``parameters()`` (typically an ``nn.Module``).
    """
    params = model.parameters()
    for tensor in params:
        tensor.requires_grad = False
def unfreeze(model):
    """Turn gradient tracking back on for every parameter of *model*, in place.

    Args:
        model: object exposing ``parameters()`` (typically an ``nn.Module``).
    """
    params = model.parameters()
    for tensor in params:
        tensor.requires_grad = True
def is_frozen(model):
    """Report whether at least one parameter of *model* has gradients disabled.

    A partially frozen model counts as frozen. A model with no parameters
    counts as not frozen.

    Args:
        model: object exposing ``parameters()`` (typically an ``nn.Module``).

    Returns:
        bool: True if any parameter has ``requires_grad == False``.
    """
    return any(not tensor.requires_grad for tensor in model.parameters())
def save_checkpoint(model_dir, state, session):
    """Serialize *state* with ``torch.save`` into *model_dir*.

    The file is named ``model_epoch_<epoch>_<session>.pth``, where ``<epoch>``
    is ``state['epoch']``.

    Args:
        model_dir: directory to write into.
        state: dict to save; must contain an ``'epoch'`` key.
        session: tag embedded in the file name.
    """
    filename = "model_epoch_{}_{}.pth".format(state['epoch'], session)
    destination = os.path.join(model_dir, filename)
    torch.save(state, destination)
def load_checkpoint(model, weights):
    """Load ``checkpoint["state_dict"]`` from *weights* into *model*.

    The state dict is first loaded as-is. If ``load_state_dict`` rejects it
    with a ``RuntimeError`` (e.g. the keys carry the ``module.`` prefix added
    by ``nn.DataParallel``), the load is retried after stripping a leading
    ``module.`` from every key that has one. Other keys are left unchanged.

    Args:
        model: ``nn.Module`` to populate.
        weights: path or file-like object accepted by ``torch.load``.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: if the state dict still does not match *model* after
            the prefix has been stripped.
    """
    checkpoint = torch.load(weights)
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Catch only the key/shape-mismatch error. A bare ``except:`` would
        # also hide unrelated failures.
        prefix = "module."
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # Strip only a genuine *leading* prefix. A substring test plus a
            # fixed 7-char slice would mangle keys like "b.module.weight".
            name = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
def load_checkpoint_multigpu(model, weights):
    """Load a checkpoint whose keys may carry a ``module.`` prefix into *model*.

    ``nn.DataParallel`` prefixes every state-dict key with ``module.``. Each
    key of ``checkpoint["state_dict"]`` that starts with that prefix has it
    removed before ``model.load_state_dict`` is called. Keys without the
    prefix are passed through unchanged.

    Args:
        model: ``nn.Module`` to populate.
        weights: path or file-like object accepted by ``torch.load``.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: propagated from ``load_state_dict`` on a mismatch.
    """
    prefix = "module."
    checkpoint = torch.load(weights)
    state_dict = checkpoint["state_dict"]
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Only strip a genuine leading prefix. Slicing 7 characters
        # unconditionally corrupts keys of checkpoints saved without it.
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
def load_start_epoch(weights):
    """Return the ``"epoch"`` value stored in the checkpoint at *weights*.

    Args:
        weights: path or file-like object accepted by ``torch.load``.

    Raises:
        KeyError: if the checkpoint has no ``"epoch"`` entry.
    """
    return torch.load(weights)["epoch"]
def load_optim(optimizer, weights):
    """Restore *optimizer* from ``checkpoint['optimizer']`` and report its lr.

    Args:
        optimizer: ``torch.optim.Optimizer`` to load state into.
        weights: path or file-like object accepted by ``torch.load``.

    Returns:
        The ``'lr'`` of the optimizer's last param group after loading.
    """
    saved_state = torch.load(weights)['optimizer']
    optimizer.load_state_dict(saved_state)
    # With several param groups, the last group's learning rate is reported.
    return optimizer.param_groups[-1]['lr']
def get_arch(opt):
    """Build the HIT restoration network selected by ``opt.arch``.

    Supported values are ``'HIT_T'``, ``'HIT_S'`` and ``'HIT_B'``. All three
    use ``img_size=opt.train_ps``, ``win_size=8``, ``token_projection='linear'``
    and ``token_mlp='leff'``. They differ as follows:

    * ``'HIT_T'``: ``embed_dim=16``, with no ``depths``/``dd_in`` passed.
    * ``'HIT_S'``: ``embed_dim=32``, ``depths=[2] * 9``, ``dd_in=opt.dd_in``.
    * ``'HIT_B'``: ``embed_dim=32``, ``depths=[1, 2, 8, 8, 2, 8, 8, 2, 1]``,
      ``dd_in=opt.dd_in``.

    Args:
        opt: options object providing ``arch``, ``train_ps`` and, for the
            S/B variants, ``dd_in``.

    Returns:
        The constructed ``HIT`` model.

    Raises:
        Exception: ``"Arch error!"`` for any other ``opt.arch`` value.
    """
    from model import HIT  # project-local module, imported inside the function

    arch = opt.arch
    print('You choose ' + arch + '...')

    if arch not in ('HIT_T', 'HIT_S', 'HIT_B'):
        raise Exception("Arch error!")

    shared = {
        'img_size': opt.train_ps,
        'win_size': 8,
        'token_projection': 'linear',
        'token_mlp': 'leff',
    }

    if arch == 'HIT_T':
        # NOTE(review): unlike S/B, this variant forwards neither `depths` nor
        # `dd_in`, so HIT's own defaults apply. Confirm this is intended.
        return HIT(embed_dim=16, **shared)

    if arch == 'HIT_B':
        depths = [1, 2, 8, 8, 2, 8, 8, 2, 1]
    else:
        depths = [2] * 9
    return HIT(embed_dim=32, depths=depths, dd_in=opt.dd_in, **shared)