"""Convert checkpoint files to flat state dicts.

Open every ``.pt`` file in the checkpoints directory. Each file is expected to
hold a nested dict with a ``'model'`` key: keep only ``checkpoint['model']``,
rename its keys so that e.g. ``'model.pos_embed.W_pos'`` becomes
``'pos_embed.W_pos'``, and save the flat dict back over the original file.
"""
import os
import json  # NOTE(review): unused in this script; kept in case other tooling relies on it
import torch

# Directory scanned for *.pt checkpoint files; converted files overwrite the originals.
checkpoints_dir = "checkpoints/"

# Prefix stripped from every key of checkpoint['model'].
_MODEL_PREFIX = "model."


def convert_state_dict(state_dict):
    """Return a copy of *state_dict* with the leading 'model.' stripped from keys.

    Keys that do not start with the prefix are kept unchanged; values are
    passed through as-is. An empty dict yields an empty dict.
    """
    return {
        (key[len(_MODEL_PREFIX):] if key.startswith(_MODEL_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def convert_checkpoints(directory=checkpoints_dir):
    """Convert every ``.pt`` checkpoint in *directory* in place.

    Files without a top-level ``'model'`` key are left untouched (a warning
    is printed instead). The converted flat state dict overwrites the
    original file.
    """
    for file in os.listdir(directory):
        if not file.endswith(".pt"):
            continue
        file_path = os.path.join(directory, file)
        print(f"Processing {file}...")
        # NOTE(review): torch.load unpickles arbitrary objects — only run this
        # script on trusted checkpoint files.
        checkpoint = torch.load(file_path, map_location='cpu')
        if 'model' not in checkpoint:
            print(f"Warning: No 'model' key found in {file}")
            continue
        converted_state_dict = convert_state_dict(checkpoint['model'])
        # Overwrite the original checkpoint with the flat dictionary.
        torch.save(converted_state_dict, file_path)
        print(f"Saved converted checkpoint to {file_path}")


if __name__ == "__main__":
    convert_checkpoints()