kgrabko committed on
Commit
0ba2ceb
·
verified ·
1 Parent(s): eabe75d

Upload save_jirack_as_safe_tensors.py

Browse files
Files changed (1) hide show
  1. save_jirack_as_safe_tensors.py +73 -0
save_jirack_as_safe_tensors.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 CMS Manhattan
2
+ # All rights reserved.
3
+ # Author: Konstantin Vladimirovich Grabko
4
+ # Email: grabko@cmsmanhattan.com
5
+ # Phone: +1(516)777-0945
6
+ #
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU General Public License as published by
9
+ # the Free Software Foundation, version 3 of the License.
10
+ #
11
+ # This program is distributed in the hope that it will be useful,
12
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
13
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
+ # GNU General Public License for more details.
15
+ #
16
+ # You should have received a copy of the GNU General Public License
17
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
18
+ #
19
+ # Additional terms:
20
+ # Any commercial use or distribution of this software or derivative works
21
+ # requires explicit written permission from the copyright holder.
22
+
23
+ import torch
24
+ from safetensors.torch import save_file
25
+ import os
26
+
27
def save_jirack_sharded(model, directory="jirack_weights"):
    """Export *model*'s weights as two sharded safetensors files.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose ``state_dict()`` is exported. It is also expected to
        provide a ``get_author_info()`` method (project convention) used
        in the final status message — TODO confirm on the model class.
    directory : str
        Output directory; created if it does not already exist.

    Side effects: writes ``model-00001-of-00002.safetensors`` and
    ``model-00002-of-00002.safetensors`` into *directory* and prints
    progress to stdout.
    """
    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(directory, exist_ok=True)

    # Sort keys so the key -> shard assignment is deterministic across runs.
    state_dict = model.state_dict()
    keys = sorted(state_dict.keys())

    # Split into two roughly equal halves. safetensors' save_file requires
    # contiguous tensors; .contiguous() is a no-op when a tensor already is,
    # so this only pays when a transposed/viewed weight would otherwise fail.
    # NOTE(review): save_file also rejects tensors that share storage —
    # if the model ties weights, they must be untied before saving.
    mid = len(keys) // 2
    shard1 = {k: state_dict[k].contiguous() for k in keys[:mid]}
    shard2 = {k: state_dict[k].contiguous() for k in keys[mid:]}

    # Shard names follow the Hugging Face "model-XXXXX-of-YYYYY" convention.
    path1 = os.path.join(directory, "model-00001-of-00002.safetensors")
    path2 = os.path.join(directory, "model-00002-of-00002.safetensors")

    print(f"Saving Shard 1 ({len(shard1)} keys) -> {path1}")
    save_file(shard1, path1)

    print(f"Saving Shard 2 ({len(shard2)} keys) -> {path2}")
    save_file(shard2, path2)

    # Authorship string comes from the project's model class.
    print(f"Done. Model by {model.get_author_info()} is now sharded.")
55
+
56
# Run as a standalone script to convert an existing .pt checkpoint:
if __name__ == "__main__":
    # Import the project-specific model class (1b variant by default).
    from JiRackPyTorch_GPT5_class_1b import JiRackPyTorch
    # from JiRackPyTorch_GPT5_class_3b import JiRackPyTorch

    # 1. Initialize the new architecture (with RoPE, SWA, etc.)
    model = JiRackPyTorch()

    # 2. Load existing .pt weights if they exist (optional).
    #    strict=False tolerates keys absent from the old checkpoint
    #    (e.g. RoPE parameters added in the new architecture).
    old_weights = "old_model_1b.pt"
    if os.path.exists(old_weights):
        print(f"Merging old weights from {old_weights}...")
        # weights_only=True restricts torch.load to tensor data and blocks
        # arbitrary-code execution via pickle from an untrusted checkpoint.
        state = torch.load(old_weights, map_location="cpu", weights_only=True)
        model.load_state_dict(state, strict=False)

    # 3. Save to the new sharded safetensors format.
    save_jirack_sharded(model)