File size: 2,160 Bytes
d86a963
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import os
from pathlib import Path

def optimize_checkpoint(path, key_mapping):
    print(f"Processing {path}...")
    try:
        # Load
        file_path = Path(path)
        if not file_path.exists():
            print(f"❌ File not found: {path}")
            return
            
        original_size = file_path.stat().st_size / (1024*1024)
        print(f"  Original size: {original_size:.2f} MB")
        
        # Load on CPU
        ckpt = torch.load(file_path, map_location='cpu')
        
        # Create new dict
        new_ckpt = {}
        
        # Copy only essential keys
        for keep_key in key_mapping['keep']:
            if keep_key in ckpt:
                new_ckpt[keep_key] = ckpt[keep_key]
        
        # Handle special cases / remapping if needed
        if 'rename' in key_mapping:
            for old_k, new_k in key_mapping['rename'].items():
                if old_k in ckpt:
                    new_ckpt[new_k] = ckpt[old_k]
                    
        # Save to new path
        new_path = file_path.parent / "model_optimized.pt"
        torch.save(new_ckpt, new_path)
        
        new_size = new_path.stat().st_size / (1024*1024)
        print(f"  ✅ Saved to {new_path}")
        print(f"  New size: {new_size:.2f} MB")
        print(f"  Reduction: {100 * (1 - new_size/original_size):.1f}%")
        
    except Exception as e:
        print(f"❌ Error processing {path}: {e}")

# Defnitions of what to keep for each model type
configs = [
    {
        'path': 'encoder_transformer/sentiment/checkpoints/best.pt',
        'key': {
            'keep': ['model_state_dict', 'config', 'best_acc', 'tokenizer_path']
        }
    },
    {
        'path': 'decoder_transformer/checkpoints/best.pt',
        'key': {
            'keep': ['model_state_dict', 'meta', 'ema_state_dict'] 
        }
    },
    {
        'path': 'machine_translation/checkpoints/best.pt',
        'key': {
            'keep': ['config', 'model'] # 'optimizer' is dropped
        }
    }
]

print("Starting optimization...")
for config in configs:
    optimize_checkpoint(config['path'], config['key'])
print("Done!")