File size: 1,446 Bytes
524eab3
60df24b
524eab3
60df24b
524eab3
 
 
60df24b
524eab3
f73efa3
524eab3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
793b152
524eab3
 
f73efa3
524eab3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
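"""Convert every checkpoint file in the checkpoints directory to a flat dict.

Each checkpoint is currently a nested dict. Only the entries stored under
checkpoint['model'] are kept, and their keys are renamed so that, e.g.,
'model.pos_embed.W_pos' becomes 'pos_embed.W_pos'. Every converted file is
written back to the same path, replacing the original checkpoint.
"""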

import os
import json
import torch

checkpoints_dir = "checkpoints/"

# Prefix stripped from parameter names, e.g. 'model.pos_embed.W_pos'.
MODEL_PREFIX = "model."


def strip_model_prefix(state_dict):
    """Return a new dict with a leading 'model.' removed from every key.

    Keys that do not start with the prefix are kept unchanged. Only one
    leading occurrence is removed ('model.model.x' becomes 'model.x').
    Values are passed through untouched.

    Args:
        state_dict: Mapping from parameter name (str) to value.

    Returns:
        A new dict with the renamed keys, in the original iteration order.

    Raises:
        ValueError: If two source keys map to the same renamed key (for
            example both 'model.w' and 'w' are present); otherwise one of
            the two values would be dropped silently.
    """
    converted = {}
    for key, value in state_dict.items():
        if key.startswith(MODEL_PREFIX):
            new_key = key[len(MODEL_PREFIX):]
        else:
            new_key = key
        if new_key in converted:
            raise ValueError(
                f"Key collision after removing '{MODEL_PREFIX}' prefix: "
                f"'{new_key}'"
            )
        converted[new_key] = value
    return converted


def convert_checkpoint(file_path):
    """Rewrite the checkpoint at *file_path* in place as a flat dict.

    The file is loaded on CPU, the entries under its 'model' key are renamed
    with strip_model_prefix(), and the result replaces the original file.
    Checkpoints without a 'model' key are left untouched and a warning is
    printed.

    Args:
        file_path: Path of the checkpoint file to convert.

    Returns:
        True if the file was converted, False if it was skipped.
    """
    file = os.path.basename(file_path)
    print(f"Processing {file}...")

    # NOTE: torch.load unpickles the file; only run this script on
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(file_path, map_location='cpu')

    # Anything that is not a dict cannot carry a 'model' entry; treat it the
    # same as a checkpoint where the key is missing.
    if not isinstance(checkpoint, dict) or 'model' not in checkpoint:
        print(f"Warning: No 'model' key found in {file}")
        return False

    converted_state_dict = strip_model_prefix(checkpoint['model'])

    # Write to a temporary sibling first and swap it in afterwards, so that a
    # failed or interrupted save never destroys the original checkpoint. The
    # '.tmp' suffix keeps a stale temp file out of the '.pt' filter in main().
    output_path = file_path
    tmp_path = output_path + ".tmp"
    try:
        torch.save(converted_state_dict, tmp_path)
        os.replace(tmp_path, output_path)
    finally:
        # Only still present when the save or the swap failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"Saved converted checkpoint to {output_path}")
    return True


def main():
    """Convert every '.pt' file found directly inside checkpoints_dir."""
    # Sorted so the processing order (and the log output) is deterministic.
    for file in sorted(os.listdir(checkpoints_dir)):
        if file.endswith(".pt"):
            convert_checkpoint(os.path.join(checkpoints_dir, file))


if __name__ == "__main__":
    main()