| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import logging |
| import os |
| import re |
|
|
| import yaml |
| import torch |
| from collections import OrderedDict |
|
|
| import datetime |
|
|
|
|
| def load_checkpoint(model: torch.nn.Module, path: str) -> dict: |
| rank = int(os.environ.get('RANK', 0)) |
| logging.info('[Rank {}] Checkpoint: loading from checkpoint {}'.format( |
| rank, path)) |
| checkpoint = torch.load(path, map_location='cpu', mmap=True) |
| missing_keys, unexpected_keys = model.load_state_dict(checkpoint, |
| strict=False) |
| if rank == 0: |
| for key in missing_keys: |
| logging.info("missing tensor: {}".format(key)) |
| for key in unexpected_keys: |
| logging.info("unexpected tensor: {}".format(key)) |
| info_path = re.sub('.pt$', '.yaml', path) |
| configs = {} |
| if os.path.exists(info_path): |
| with open(info_path, 'r') as fin: |
| configs = yaml.load(fin, Loader=yaml.FullLoader) |
| return configs |
|
|
|
|
| def save_state_dict_and_infos(state_dict, path: str, infos=None): |
| rank = int(os.environ.get('RANK', 0)) |
| logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format( |
| rank, path)) |
| torch.save(state_dict, path) |
| info_path = re.sub('.pt$', '.yaml', path) |
| if infos is None: |
| infos = {} |
| infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') |
| with open(info_path, 'w') as fout: |
| data = yaml.dump(infos) |
| fout.write(data) |
|
|
|
|
| def save_checkpoint(model: torch.nn.Module, path: str, infos=None): |
| ''' |
| Args: |
| infos (dict or None): any info you want to save. |
| ''' |
| if isinstance(model, torch.nn.DataParallel): |
| state_dict = model.module.state_dict() |
| elif isinstance(model, torch.nn.parallel.DistributedDataParallel): |
| state_dict = model.module.state_dict() |
| else: |
| state_dict = model.state_dict() |
| save_state_dict_and_infos(state_dict, path, infos) |
|
|
|
|
| def filter_modules(model_state_dict, modules): |
| rank = int(os.environ.get('RANK', 0)) |
| new_mods = [] |
| incorrect_mods = [] |
| mods_model = model_state_dict.keys() |
| for mod in modules: |
| if any(key.startswith(mod) for key in mods_model): |
| new_mods += [mod] |
| else: |
| incorrect_mods += [mod] |
| if incorrect_mods and rank == 0: |
| logging.warning( |
| "module(s) %s don't match or (partially match) " |
| "available modules in model.", |
| incorrect_mods, |
| ) |
| logging.warning("for information, the existing modules in model are:") |
| logging.warning("%s", mods_model) |
|
|
| return new_mods |
|
|
|
|
| def load_trained_modules(model: torch.nn.Module, args: None): |
| |
| enc_model_path = args.enc_init |
| enc_modules = args.enc_init_mods |
| main_state_dict = model.state_dict() |
| logging.warning("model(s) found for pre-initialization") |
| if os.path.isfile(enc_model_path): |
| logging.info('Checkpoint: loading from checkpoint %s for CPU' % |
| enc_model_path) |
| model_state_dict = torch.load(enc_model_path, map_location='cpu') |
| modules = filter_modules(model_state_dict, enc_modules) |
| partial_state_dict = OrderedDict() |
| for key, value in model_state_dict.items(): |
| if any(key.startswith(m) for m in modules): |
| partial_state_dict[key] = value |
| main_state_dict.update(partial_state_dict) |
| else: |
| logging.warning("model was not found : %s", enc_model_path) |
|
|
| model.load_state_dict(main_state_dict) |
| configs = {} |
| return configs |
|
|