Spaces:
Runtime error
Runtime error
| ''' | |
| This is to save and load the model. | |
| ''' | |
def check_model_checkpoint_consistency(ckpt_state_dict, model_state_dict, special_strs=None):
    """
    Verify that the checkpoint provides every key the model expects.

    A model key that is missing from the checkpoint is tolerated only when it
    contains at least one of the substrings in ``special_strs``. Checkpoint
    keys that the model does not have are not inspected.

    Args:
        ckpt_state_dict: The state dictionary of the checkpoint.
        model_state_dict: The state dictionary of the model.
        special_strs: Iterable of substrings; a model key containing any of
            them may be absent from the checkpoint. ``None`` (the default)
            means no key is exempt.

    Returns:
        None. The function only validates; it signals problems by raising.

    Raises:
        KeyError: If a model key is absent from the checkpoint and contains
            none of the substrings in ``special_strs``.
    """
    # Normalise the default: iterating ``None`` raised TypeError and masked
    # the intended KeyError. ``tuple`` also makes one-shot iterables reusable
    # across the loop below.
    exempt_markers = () if special_strs is None else tuple(special_strs)

    for key in model_state_dict:
        if key in ckpt_state_dict:
            continue
        if any(marker in key for marker in exempt_markers):
            # Deliberately allowed to be missing from the checkpoint.
            continue
        raise KeyError(f"Key '{key}' not found in checkpoint and does not match any special endings.")
def remove_module_prefix(state_dict):
    """
    Return a new dict in which a leading 'module.' is stripped from each key.

    Only one leading occurrence is removed per key; keys that do not start
    with the prefix are copied unchanged. Values are carried over as-is and
    the input mapping is not modified. If two keys collapse to the same name,
    the one that appears later in ``state_dict`` wins.
    """
    prefix = 'module.'
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in state_dict.items()
    }
| # This is for reducing impact at the beginning of training. | |
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the module.

    The zeroing is applied through a detached view of each parameter, so the
    parameter's own storage is overwritten while the write itself is not
    recorded on the parameter.
    """
    detached_views = (weight.detach() for weight in module.parameters())
    for view in detached_views:
        view.zero_()
    return module
def filter_model_checkpoint(ckpt_state_dict, model_state_dict, need_strs=None):
    """
    Select the checkpoint entries that the model has and that are wanted.

    An entry is kept when its key exists in both ``model_state_dict`` and
    ``ckpt_state_dict`` and the key contains at least one substring from
    ``need_strs``. The result follows the key order of ``model_state_dict``
    and holds the checkpoint's values.

    Args:
        ckpt_state_dict: The state dictionary of the checkpoint.
        model_state_dict: The state dictionary of the model.
        need_strs: Iterable of substrings a key must contain (any one of
            them) to be kept. ``None`` (the default) applies no substring
            restriction, so every key shared by model and checkpoint is kept.
            An empty iterable keeps nothing.

    Returns:
        A new dict mapping the selected keys to their checkpoint values.
    """
    # Iterating the ``None`` default raised TypeError as soon as a model key
    # was present in the checkpoint; treat it as "no restriction" instead.
    # ``tuple`` keeps one-shot iterables usable for every key.
    wanted = None if need_strs is None else tuple(need_strs)

    return {
        key: ckpt_state_dict[key]
        for key in model_state_dict
        if key in ckpt_state_dict
        and (wanted is None or any(fragment in key for fragment in wanted))
    }