# triangle-40k-og / convert_checkpoints.py
# algo2217's picture
# Upload folder using huggingface_hub
# 793b152 verified
"""Open every `.pt` file in the checkpoints directory and flatten its keys.

Each checkpoint is currently a nested dict. Keep only the entries stored under
checkpoint['model'] and rename their keys so that, e.g.,
'model.pos_embed.W_pos' becomes 'pos_embed.W_pos'. The converted checkpoint
overwrites the original file.
"""
import os
import json
import torch
checkpoints_dir = "checkpoints/"
MODEL_PREFIX = "model."


def strip_model_prefix(state_dict, prefix=MODEL_PREFIX):
    """Return a new dict with *prefix* removed from the start of each key.

    Only one leading occurrence is removed, so 'model.model.x' becomes
    'model.x'. Keys that do not start with *prefix* are kept unchanged.
    Values are passed through untouched and insertion order is preserved.

    Args:
        state_dict: Mapping of parameter names to values.
        prefix: Leading string to strip from each key.

    Returns:
        A new dict with the renamed keys.

    Raises:
        ValueError: If two source keys would collapse onto the same new key
            (e.g. 'model.x' and 'x'); continuing would silently drop a value.
    """
    converted = {}
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        if new_key in converted:
            raise ValueError(
                f"Key collision after stripping {prefix!r}: "
                f"{key!r} maps to already-present key {new_key!r}"
            )
        converted[new_key] = value
    return converted


def convert_checkpoint(file_path):
    """Rewrite one checkpoint file in place as a flat, prefix-free state dict.

    Args:
        file_path: Path of the checkpoint file to convert.

    Returns:
        True if the file was converted, False if the loaded object is not a
        dict containing a 'model' entry (the file is then left untouched).

    Raises:
        ValueError: Propagated from strip_model_prefix on a key collision.
    """
    # NOTE(review): torch.load unpickles the file; only run this on
    # checkpoints from a trusted source.
    checkpoint = torch.load(file_path, map_location='cpu')
    if not isinstance(checkpoint, dict) or 'model' not in checkpoint:
        return False

    converted_state_dict = strip_model_prefix(checkpoint['model'])

    # Write to a sibling temp file first and swap it in afterwards, so an
    # interrupted save cannot destroy the original checkpoint. The '.tmp'
    # suffix also keeps the temp file out of the '.pt' filter.
    tmp_path = file_path + ".tmp"
    try:
        torch.save(converted_state_dict, tmp_path)
        os.replace(tmp_path, file_path)
    finally:
        # Only still present if the save or the replace failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return True


def convert_checkpoints(directory=checkpoints_dir):
    """Convert every '.pt' file directly inside *directory*, in sorted order.

    Args:
        directory: Directory to scan (not recursive).

    Returns:
        The number of files that were converted.

    Raises:
        FileNotFoundError: If *directory* does not exist (from os.listdir).
    """
    converted_count = 0
    for file_name in sorted(os.listdir(directory)):
        if not file_name.endswith(".pt"):
            continue
        file_path = os.path.join(directory, file_name)
        print(f"Processing {file_name}...")
        if convert_checkpoint(file_path):
            converted_count += 1
            print(f"Saved converted checkpoint to {file_path}")
        else:
            print(f"Warning: No 'model' key found in {file_name}")
    return converted_count


if __name__ == "__main__":
    convert_checkpoints(checkpoints_dir)