2c6829 commited on
Commit
9835abb
ยท
verified ยท
1 Parent(s): c7b0486

Upload extract_weights.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. extract_weights.py +64 -0
extract_weights.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ PyTorch ์ฒดํฌํฌ์ธํŠธ์—์„œ ๊ฐ€์ค‘์น˜๋ฅผ ์ถ”์ถœํ•˜์—ฌ binary ํ˜•์‹์œผ๋กœ ์ €์žฅ
4
+ """
5
+ import torch
6
+ import numpy as np
7
+ import struct
8
+ import sys
9
+ from pathlib import Path
10
+
11
+ def extract_weights(checkpoint_path, output_path):
12
+ """์ฒดํฌํฌ์ธํŠธ์—์„œ ๊ฐ€์ค‘์น˜ ์ถ”์ถœ"""
13
+ print(f"Loading checkpoint: {checkpoint_path}")
14
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
15
+
16
+ # state_dict ์ถ”์ถœ
17
+ if 'model_state_dict' in checkpoint:
18
+ state_dict = checkpoint['model_state_dict']
19
+ elif 'state_dict' in checkpoint:
20
+ state_dict = checkpoint['state_dict']
21
+ else:
22
+ state_dict = checkpoint
23
+
24
+ print(f"Found {len(state_dict)} parameters")
25
+
26
+ # Binary ํŒŒ์ผ๋กœ ์ €์žฅ
27
+ with open(output_path, 'wb') as f:
28
+ # ๋งค์ง ๋„˜๋ฒ„์™€ ๋ฒ„์ „ ์ •๋ณด
29
+ f.write(b'LCNN') # Magic number
30
+ f.write(struct.pack('I', 1)) # Version
31
+
32
+ # ํŒŒ๋ผ๋ฏธํ„ฐ ๊ฐœ์ˆ˜
33
+ f.write(struct.pack('I', len(state_dict)))
34
+
35
+ for name, param in state_dict.items():
36
+ print(f" {name}: {param.shape}")
37
+
38
+ # ํŒŒ๋ผ๋ฏธํ„ฐ ์ด๋ฆ„ (์ตœ๋Œ€ 256 bytes)
39
+ name_bytes = name.encode('utf-8')[:256]
40
+ f.write(struct.pack('I', len(name_bytes)))
41
+ f.write(name_bytes)
42
+
43
+ # ํ…์„œ ๋ฐ์ดํ„ฐ
44
+ data = param.cpu().numpy().astype(np.float32)
45
+
46
+ # Shape ์ •๋ณด
47
+ ndim = len(data.shape)
48
+ f.write(struct.pack('I', ndim))
49
+ for dim in data.shape:
50
+ f.write(struct.pack('I', dim))
51
+
52
+ # ๋ฐ์ดํ„ฐ (C-contiguous order)
53
+ data_flat = data.flatten('C')
54
+ f.write(struct.pack(f'{len(data_flat)}f', *data_flat))
55
+
56
+ print(f"\nWeights saved to: {output_path}")
57
+ print(f"File size: {Path(output_path).stat().st_size / 1024 / 1024:.2f} MB")
58
+
59
+ if __name__ == '__main__':
60
+ checkpoint_path = sys.argv[1] if len(sys.argv) > 1 else '~/mycnn/checkpoints/LiteCNNPro_best.pth'
61
+ output_path = sys.argv[2] if len(sys.argv) > 2 else './model_weights.bin'
62
+
63
+ checkpoint_path = Path(checkpoint_path).expanduser()
64
+ extract_weights(checkpoint_path, output_path)