Spaces:
Sleeping
Sleeping
def load_matched_state_dict(model, state_dict, print_stats=True):
    """
    Load only the entries of ``state_dict`` that match the model in key and shape.

    Every key of ``model.state_dict()`` is considered. When ``state_dict``
    contains the same key and its value has the same ``.shape``, the incoming
    value replaces the model's current one; otherwise the current value is
    kept. Keys that appear only in ``state_dict`` are ignored. The merged
    mapping is then handed to ``model.load_state_dict``.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``.
        state_dict: Mapping from parameter name to a value with a ``.shape``
            attribute.
        print_stats: When true, print how many of the model's entries were
            matched.
    """
    num_matched, num_total = 0, 0
    curr_state_dict = model.state_dict()
    # Only values of existing keys are reassigned, so the mapping's size does
    # not change while it is being iterated.
    for key, current in curr_state_dict.items():
        num_total += 1
        if key not in state_dict:
            continue
        incoming = state_dict[key]
        if current.shape == incoming.shape:
            curr_state_dict[key] = incoming
            num_matched += 1
    # The merged mapping has exactly the model's own keys, so the default
    # load_state_dict call is used, as in the original.
    model.load_state_dict(curr_state_dict)
    if print_stats:
        print(f'Loaded state_dict: {num_matched}/{num_total} matched')