File size: 1,054 Bytes
63e9186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import os
import argparse

def save_only_ema_weights(checkpoint_file):
    """Extract the EMA weights from a checkpoint and save them next to it.

    The checkpoint is loaded onto the CPU. It must be a dict with an
    ``'ema'`` entry, and that entry must hold a ``'module'`` entry. The
    ``'module'`` object is written as ``{'model': ...}`` to
    ``<name>_converted<ext>`` in the same directory as ``checkpoint_file``.

    Args:
        checkpoint_file: Path to the input checkpoint.

    Returns:
        The path of the written ``*_converted*`` file.

    Raises:
        ValueError: If the checkpoint has no ``'ema'`` entry, or its
            ``'ema'`` entry has no ``'module'`` entry.
    """
    # NOTE(review): torch.load unpickles the file; only run this on
    # checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')

    if not isinstance(checkpoint, dict) or 'ema' not in checkpoint:
        raise ValueError("The checkpoint does not contain 'ema'.")
    try:
        ema_module = checkpoint['ema']['module']
    except (KeyError, TypeError) as err:
        # Surface a malformed 'ema' entry as the same error type as a missing
        # one, instead of a bare KeyError/TypeError.
        raise ValueError(
            f"The 'ema' entry in {checkpoint_file} does not contain 'module'."
        ) from err

    weights = {'model': ema_module}

    dir_name, base_name = os.path.split(checkpoint_file)
    name, ext = os.path.splitext(base_name)
    output_file = os.path.join(dir_name, f"{name}_converted{ext}")

    torch.save(weights, output_file)
    print(f"EMA weights saved to {output_file}")
    return output_file

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Extract and save only EMA weights.")
    parser.add_argument('checkpoint_dir', type=str,
                        help="Directory containing the input .pth checkpoint files.")
    args = parser.parse_args()

    # Report a bad path as a usage error rather than a FileNotFoundError traceback.
    if not os.path.isdir(args.checkpoint_dir):
        parser.error(f"not a directory: {args.checkpoint_dir}")

    failures = []
    # Sorted so runs process files in a stable order.
    for entry in sorted(os.listdir(args.checkpoint_dir)):
        path = os.path.join(args.checkpoint_dir, entry)
        # Substring match on '.pth' (as before); skip outputs of a previous run,
        # which carry the '_converted' marker, and anything that is not a file.
        if '.pth' not in entry or '_converted' in entry or not os.path.isfile(path):
            continue
        try:
            save_only_ema_weights(path)
        except ValueError as err:
            # One checkpoint without EMA weights should not abort the whole batch.
            failures.append(f"{path}: {err}")

    if failures:
        # Prints to stderr and exits with a non-zero status.
        parser.exit(1, "Failed to convert:\n" + "\n".join(failures) + "\n")