| | import torch |
| |
|
def delete_keys_from_dict(dict_del, lst_keys):
    """Remove every key in *lst_keys* from *dict_del*, recursing into nested dicts.

    The dictionary is modified in place and also returned for convenience.

    Args:
        dict_del: Dictionary to prune (mutated in place).
        lst_keys: Collection of keys to remove at every nesting level.

    Returns:
        The same (mutated) ``dict_del`` object.
    """
    # Iterate over a snapshot of the keys so we can delete while looping.
    for field in list(dict_del.keys()):
        if field in lst_keys:
            # Delete the whole entry; its subtree (dict or not) goes with it.
            # NOTE: the original code recursed into the value *after* deleting
            # it, raising KeyError whenever a matched key held a dict.
            del dict_del[field]
        elif isinstance(dict_del[field], dict):
            # Key survives but its value is a nested dict: prune recursively.
            delete_keys_from_dict(dict_del[field], lst_keys)
    return dict_del
| |
|
| |
|
def optimizer_to(optim, device):
    """Move all tensors held in an optimizer's state to *device*, in place.

    Optimizer state entries are either tensors themselves or dicts whose
    values may be tensors (e.g. momentum buffers); both layouts are handled.

    Args:
        optim: A ``torch.optim.Optimizer`` whose state should be relocated.
        device: Target device (``torch.device`` or device string).
    """

    def _relocate(tensor):
        # Rebind .data (and the gradient's .data, when present) on the target
        # device so the optimizer keeps using the same tensor objects.
        tensor.data = tensor.data.to(device)
        if tensor._grad is not None:
            tensor._grad.data = tensor._grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            _relocate(entry)
        elif isinstance(entry, dict):
            for value in entry.values():
                if isinstance(value, torch.Tensor):
                    _relocate(value)
| |
|
| |
|
def load_partial_model(pretrained_dict, model):
    """Load into *model* only the pretrained entries whose keys the model has.

    Entries of *pretrained_dict* with no matching key in the model's
    ``state_dict`` are silently skipped; all other model parameters keep
    their current values.

    Args:
        pretrained_dict: Mapping of parameter names to tensors (e.g. a
            checkpoint's state dict, possibly from a different architecture).
        model: The ``torch.nn.Module`` to update in place.
    """
    current = model.state_dict()
    # Overwrite only the names the model actually knows about.
    for name, tensor in pretrained_dict.items():
        if name in current:
            current[name] = tensor
    model.load_state_dict(current)