File size: 1,446 Bytes
524eab3 60df24b 524eab3 60df24b 524eab3 60df24b 524eab3 f73efa3 524eab3 793b152 524eab3 f73efa3 524eab3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
""" Open every file in the checkpoints directory and change the keys. keys are currently a nested dict. keep only keys that are in checkpoint['model'] and rename those keys such that, e.g. 'model.pos_embed.W_pos' becomes 'pos_embed.W_pos'.
"""
import os
import json
import torch
# Directory that is scanned for ``.pt`` checkpoint files.
checkpoints_dir = "checkpoints/"

# Prefix stripped from the parameter names stored under checkpoint['model'].
MODEL_PREFIX = "model."


def strip_model_prefix(state_dict, prefix=MODEL_PREFIX):
    """Return a new dict with *prefix* removed from the start of each key.

    Keys that do not start with *prefix* are copied unchanged, and only one
    leading occurrence of the prefix is removed per key. Values are not
    copied or modified.

    Args:
        state_dict: Mapping from parameter name (str) to value.
        prefix: Leading text to strip from each key.

    Returns:
        A new ``dict`` holding the renamed keys in their original order.

    Raises:
        ValueError: If two different keys map to the same new name, since
            one of the values would otherwise be dropped silently.
    """
    converted = {}
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        if new_key in converted:
            raise ValueError(
                f"Key collision: {key!r} would overwrite existing key {new_key!r}"
            )
        converted[new_key] = value
    return converted


def convert_checkpoint(file_path):
    """Rewrite the checkpoint at *file_path* in place as a flat dict.

    Only the entries under ``checkpoint['model']`` are kept (with the
    ``'model.'`` key prefix stripped); every other top-level entry of the
    checkpoint is dropped from the rewritten file.

    Args:
        file_path: Path of the checkpoint file to convert.

    Returns:
        True if the file was rewritten, False if the loaded object is not a
        dict or has no ``'model'`` entry (the file is then left untouched).

    Raises:
        ValueError: If stripping the prefix makes two keys collide; the file
            is left untouched in that case.
    """
    # NOTE(security): torch.load unpickles the file, so only run this on
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(file_path, map_location='cpu')
    if not isinstance(checkpoint, dict) or 'model' not in checkpoint:
        return False

    converted = strip_model_prefix(checkpoint['model'])

    # Write next to the target and swap it in afterwards, so a failure while
    # saving never leaves a truncated checkpoint in place of the original.
    # The ".tmp" suffix keeps the scratch file out of the ".pt" scan.
    tmp_path = f"{file_path}.tmp"
    try:
        torch.save(converted, tmp_path)
        os.replace(tmp_path, file_path)
    finally:
        # After a successful replace the scratch file no longer exists; it is
        # only still present when saving or replacing failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return True


def main(directory=checkpoints_dir):
    """Convert every ``.pt`` file found directly inside *directory*."""
    # Sorted so the processing order (and the log) is deterministic.
    for file_name in sorted(os.listdir(directory)):
        if not file_name.endswith(".pt"):
            continue
        file_path = os.path.join(directory, file_name)
        print(f"Processing {file_name}...")
        if convert_checkpoint(file_path):
            print(f"Saved converted checkpoint to {file_path}")
        else:
            print(f"Warning: No 'model' key found in {file_name}")


if __name__ == "__main__":
    main()