alfred / models /utils /helper_utils.py
ch-min-ys's picture
delete ff_planner
483f96c
import torch
def delete_keys_from_dict(dict_del, lst_keys):
"""
Delete the keys present in lst_keys from the dictionary.
Loops recursively over nested dictionaries.
"""
dict_foo = dict_del.copy() #Used as iterator to avoid the 'DictionaryHasChanged' error
for field in dict_foo.keys():
if field in lst_keys:
del dict_del[field]
if type(dict_foo[field]) == dict:
delete_keys_from_dict(dict_del[field], lst_keys)
return dict_del
def optimizer_to(optim, device):
for param in optim.state.values():
# Not sure there are any global tensors in the state dict
if isinstance(param, torch.Tensor):
param.data = param.data.to(device)
if param._grad is not None:
param._grad.data = param._grad.data.to(device)
elif isinstance(param, dict):
for subparam in param.values():
if isinstance(subparam, torch.Tensor):
subparam.data = subparam.data.to(device)
if subparam._grad is not None:
subparam._grad.data = subparam._grad.data.to(device)
def load_partial_model(pretrained_dict, model):
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)