"""Strip optimizer state from a checkpoint file to make it smaller.

Usage::

    python <this script> --input path/to/model.ckpt [--output out.ckpt]

The checkpoint is loaded with ``torch.load``. Every top-level entry except
``'optimizer_states'`` is kept, and the result is written with ``torch.save``.
By default the output goes next to the input, with ``pruned-`` prefixed to the
file name.
"""
import argparse
import os

import torch

# Top-level checkpoint key removed by the pruner. Presumably this is where
# PyTorch Lightning stores optimizer state — TODO confirm for other formats.
OPTIMIZER_KEY = 'optimizer_states'


def prune_checkpoint(checkpoint):
    """Return a shallow copy of *checkpoint* without the optimizer-state entry.

    Args:
        checkpoint: Mapping loaded from a checkpoint file.

    Returns:
        A new ``dict`` holding every key/value pair of *checkpoint* except
        ``OPTIMIZER_KEY``. Values are not copied.
    """
    return {key: value for key, value in checkpoint.items() if key != OPTIMIZER_KEY}


def pruned_output_path(path):
    """Return the default output path for *path*.

    Only the file name gets the ``pruned-`` prefix; the directory is kept. So
    ``runs/last.ckpt`` maps to ``runs/pruned-last.ckpt``, not to a file inside
    a (likely nonexistent) ``pruned-runs`` directory.
    """
    directory, name = os.path.split(path)
    return os.path.join(directory, f'pruned-{name}')


def main(argv=None):
    """Parse command-line arguments, prune the input checkpoint, and save it.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='Remove optimizer state from a checkpoint.')
    parser.add_argument('--input', '-I', type=str, help='Input file to prune', required=True)
    parser.add_argument(
        '--output', '-o', type=str, default=None,
        help="Output file (default: input file name prefixed with 'pruned-', same directory)",
    )
    args = parser.parse_args(argv)

    output = args.output or pruned_output_path(args.input)

    # NOTE: torch.load unpickles the file and can execute arbitrary code; only
    # run this on checkpoints you trust.
    # map_location='cpu' lets checkpoints saved from a GPU be pruned on a
    # CPU-only machine, without allocating device memory for the weights.
    checkpoint = torch.load(args.input, map_location='cpu')
    torch.save(prune_checkpoint(checkpoint), output)


if __name__ == '__main__':
    main()