File size: 2,240 Bytes
9835abb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python3
"""
Extract the weights from a PyTorch checkpoint and save them in a binary format.
"""
import torch
import numpy as np
import struct
import sys
from pathlib import Path

def _select_state_dict(checkpoint):
    """Return the name -> value mapping holding the model parameters.

    Accepts a bare ``state_dict`` or a dict wrapping it under
    ``'model_state_dict'`` or ``'state_dict'`` (checked in that order).

    Raises:
        TypeError: if the loaded checkpoint is not a dict.
    """
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Expected a dict-like checkpoint, got {type(checkpoint).__name__}"
        )
    if 'model_state_dict' in checkpoint:
        return checkpoint['model_state_dict']
    if 'state_dict' in checkpoint:
        return checkpoint['state_dict']
    return checkpoint


def _write_tensor(f, name, tensor):
    """Append one name / shape / float32-data record for *tensor* to file *f*."""
    # Enforce the 256-byte name limit, then drop any multi-byte UTF-8
    # sequence the cut split in half so the stored name stays valid UTF-8.
    name_bytes = name.encode('utf-8')[:256].decode('utf-8', 'ignore').encode('utf-8')
    f.write(struct.pack('I', len(name_bytes)))
    f.write(name_bytes)

    # Convert with torch rather than numpy: numpy has no bfloat16, so
    # tensor.numpy() fails on bf16 weights. detach() makes numpy() safe for
    # tensors that require grad.
    data = tensor.detach().to(device='cpu', dtype=torch.float32).numpy()

    # Shape: rank followed by each dimension.
    f.write(struct.pack('I', data.ndim))
    for dim in data.shape:
        f.write(struct.pack('I', dim))

    # Values in C (row-major) order, native-endian float32. This produces the
    # same bytes as struct.pack(f'{n}f', *values) but in one bulk copy instead
    # of boxing every element as a Python float.
    f.write(data.tobytes(order='C'))


def extract_weights(checkpoint_path, output_path):
    """Extract the tensors of a PyTorch checkpoint into an ``LCNN`` binary file.

    Every tensor is converted to float32. Integers are native-endian unsigned
    32-bit values, as written by ``struct.pack('I', ...)``. Layout::

        b'LCNN'                     magic number
        uint32  version             always 1
        uint32  count               number of tensor records that follow
        count x {
            uint32  name_len        length of the UTF-8 name (<= 256 bytes)
            bytes   name
            uint32  ndim
            uint32  dims[ndim]
            float32 data[prod(dims)]   C (row-major) order
        }

    Non-tensor entries of the state dict (e.g. an ``'epoch'`` counter stored
    next to the weights) are skipped with a message instead of aborting.

    Args:
        checkpoint_path: File passed to ``torch.load`` (loaded onto the CPU).
        output_path: Binary file to create; overwritten if it exists.

    Raises:
        TypeError: if the checkpoint is not a dict.
    """
    print(f"Loading checkpoint: {checkpoint_path}")
    # SECURITY: torch.load unpickles the file (without restriction on
    # torch < 2.6, where weights_only defaults to False). Only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = _select_state_dict(checkpoint)

    # Keep only real tensors so the header count matches the records written.
    tensors = {}
    for name, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            tensors[name] = value
        else:
            print(f"  skipping non-tensor entry: {name} ({type(value).__name__})")

    print(f"Found {len(tensors)} parameters")

    with open(output_path, 'wb') as f:
        f.write(b'LCNN')                         # magic number
        f.write(struct.pack('I', 1))             # format version
        f.write(struct.pack('I', len(tensors)))  # number of tensor records

        for name, tensor in tensors.items():
            print(f"  {name}: {tensor.shape}")
            _write_tensor(f, name, tensor)

    print(f"\nWeights saved to: {output_path}")
    print(f"File size: {Path(output_path).stat().st_size / 1024 / 1024:.2f} MB")

if __name__ == '__main__':
    # Usage: python <this script> [CHECKPOINT_PATH] [OUTPUT_PATH]
    # Both arguments are optional; the defaults below apply when omitted.
    checkpoint_path = sys.argv[1] if len(sys.argv) > 1 else '~/mycnn/checkpoints/LiteCNNPro_best.pth'
    output_path = sys.argv[2] if len(sys.argv) > 2 else './model_weights.bin'

    # Expand "~" in BOTH paths: neither torch.load() nor open() does it, so
    # an unexpanded "~/weights.bin" output argument would fail to open.
    checkpoint_path = Path(checkpoint_path).expanduser()
    output_path = Path(output_path).expanduser()
    extract_weights(checkpoint_path, output_path)