""" Open every file in the checkpoints directory and change the keys. keys are currently a nested dict. keep only keys that are in checkpoint['model'] and rename those keys such that, e.g. 'model.pos_embed.W_pos' becomes 'pos_embed.W_pos'. """ import os import json import torch checkpoints_dir = "checkpoints/" for file in os.listdir(checkpoints_dir): if file.endswith(".pt"): file_path = os.path.join(checkpoints_dir, file) print(f"Processing {file}...") # Load the checkpoint checkpoint = torch.load(file_path, map_location='cpu') # Extract model keys and rename them if 'model' in checkpoint: model_state_dict = checkpoint['model'] converted_state_dict = {} for key, value in model_state_dict.items(): # Remove 'model.' prefix if it exists if key.startswith('model.'): new_key = key[6:] # Remove 'model.' prefix else: new_key = key converted_state_dict[new_key] = value # Save the converted checkpoint as a flat dictionary output_path = os.path.join(checkpoints_dir, f"{file}") torch.save(converted_state_dict, output_path) print(f"Saved converted checkpoint to {output_path}") else: print(f"Warning: No 'model' key found in {file}")