Upload save_jirack_as_safe_tensors.py
Browse files
save_jirack_as_safe_tensors.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 CMS Manhattan
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
# Author: Konstantin Vladimirovich Grabko
|
| 4 |
+
# Email: grabko@cmsmanhattan.com
|
| 5 |
+
# Phone: +1(516)777-0945
|
| 6 |
+
#
|
| 7 |
+
# This program is free software: you can redistribute it and/or modify
|
| 8 |
+
# it under the terms of the GNU General Public License as published by
|
| 9 |
+
# the Free Software Foundation, version 3 of the License.
|
| 10 |
+
#
|
| 11 |
+
# This program is distributed in the hope that it will be useful,
|
| 12 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 13 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 14 |
+
# GNU General Public License for more details.
|
| 15 |
+
#
|
| 16 |
+
# You should have received a copy of the GNU General Public License
|
| 17 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 18 |
+
#
|
| 19 |
+
# Additional terms:
|
| 20 |
+
# Any commercial use or distribution of this software or derivative works
|
| 21 |
+
# requires explicit written permission from the copyright holder.
|
| 22 |
+
|
| 23 |
+
import json
import os

import torch
from safetensors.torch import save_file
|
| 26 |
+
|
| 27 |
+
def save_jirack_sharded(model, directory="jirack_weights"):
    """
    Convert the model's current weights into a two-shard safetensors layout.

    Writes into *directory*:
      - model-00001-of-00002.safetensors  (first half of the sorted keys)
      - model-00002-of-00002.safetensors  (second half)
      - model.safetensors.index.json      (weight_map + total_size, so standard
        Hugging Face loaders can resolve which shard holds each tensor)

    Args:
        model: a torch.nn.Module; its state_dict() is serialized.
        directory: output directory, created if it does not exist.
    """
    # exist_ok=True is race-free, unlike a separate exists() check + makedirs().
    os.makedirs(directory, exist_ok=True)

    # save_file rejects tensors that share storage or are non-contiguous,
    # so take detached contiguous copies of every entry up front.
    state_dict = {k: v.detach().contiguous() for k, v in model.state_dict().items()}
    keys = sorted(state_dict)

    # Split keys into two halves for sharding (keeps the original 2-shard layout).
    mid = len(keys) // 2
    shards = {
        "model-00001-of-00002.safetensors": {k: state_dict[k] for k in keys[:mid]},
        "model-00002-of-00002.safetensors": {k: state_dict[k] for k in keys[mid:]},
    }

    weight_map = {}
    total_size = 0
    for shard_no, (fname, shard) in enumerate(shards.items(), start=1):
        path = os.path.join(directory, fname)
        print(f"Saving Shard {shard_no} ({len(shard)} keys) -> {path}")
        save_file(shard, path)
        for key, tensor in shard.items():
            weight_map[key] = fname
            total_size += tensor.numel() * tensor.element_size()

    # Without this index file, sharded safetensors checkpoints cannot be
    # located/loaded by the standard transformers/safetensors tooling.
    index_path = os.path.join(directory, "model.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(
            {"metadata": {"total_size": total_size}, "weight_map": weight_map},
            f,
            indent=2,
        )

    # Authorship method comes from the project's model class; guard it so a
    # model without it doesn't crash after the shards were already written.
    author = getattr(model, "get_author_info", lambda: "unknown")()
    print(f"Done. Model by {author} is now sharded.")
+
# Run as a standalone script to convert an existing .pt checkpoint:
if __name__ == "__main__":
    # Import your specific model class
    from JiRackPyTorch_GPT5_class_1b import JiRackPyTorch
    # from JiRackPyTorch_GPT5_class_3b import JiRackPyTorch

    # 1. Initialize the new architecture (with RoPE, SWA, etc.)
    model = JiRackPyTorch()

    # 2. Load existing .pt weights if they exist (optional).
    # strict=False lets old checkpoints load even if parameters added in the
    # new architecture (e.g. RoPE buffers) are absent from the old file.
    old_weights = "old_model_1b.pt"
    if os.path.exists(old_weights):
        print(f"Merging old weights from {old_weights}...")
        # weights_only=True: a plain state dict needs no arbitrary pickled
        # code, so refuse it — avoids code execution from a tampered file.
        state = torch.load(old_weights, map_location="cpu", weights_only=True)
        result = model.load_state_dict(state, strict=False)
        # Surface what strict=False silently tolerated, so a bad merge is visible.
        if result.missing_keys:
            print(f"Missing keys (left at init values): {result.missing_keys}")
        if result.unexpected_keys:
            print(f"Unexpected keys (ignored): {result.unexpected_keys}")

    # 3. Save to the new sharded safetensors format
    save_jirack_sharded(model)