File size: 2,773 Bytes
15dcafa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
Extract the language model (text-only) weights from Gemma 4 multimodal safetensors.

- Filters keys containing 'language_model'
- Renames: model.language_model.X -> model.X
- Saves as sharded safetensors (10GB per shard)
- Generates model.safetensors.index.json
"""

import glob
import json
import os

import torch
from safetensors import safe_open
from safetensors.torch import save_file

SRC_DIR = "/workspace/llm/gemma-4-31B-it"
DST_DIR = "/workspace/llm/gemma-4-31B-Text"
MAX_SHARD_SIZE = 10 * 1024 * 1024 * 1024  # 10GB


def main():
    """Extract the text-only language-model weights from the multimodal checkpoint.

    Reads every ``*.safetensors`` file in ``SRC_DIR``, keeps only tensors whose
    key contains ``"language_model"``, renames ``model.language_model.X`` to
    ``model.X``, writes the result to ``DST_DIR`` as shards of at most
    ``MAX_SHARD_SIZE`` bytes, and emits a ``model.safetensors.index.json``
    mapping each tensor key to its shard file.
    """
    os.makedirs(DST_DIR, exist_ok=True)

    src_files = sorted(glob.glob(os.path.join(SRC_DIR, "*.safetensors")))
    print(f"Source files: {len(src_files)}")
    if not src_files:
        # Nothing to do — avoid silently writing an empty index.
        print(f"WARNING: no .safetensors files found in {SRC_DIR}")

    # Step 1: Collect all language_model tensors with renamed keys.
    # NOTE: all tensors are held in CPU RAM at once; the source checkpoint
    # must fit in memory.
    all_tensors = {}
    for path in src_files:
        print(f"Reading {os.path.basename(path)}...")
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                if "language_model" in key:
                    new_key = key.replace("model.language_model.", "model.")
                    all_tensors[new_key] = f.get_tensor(key)

    print(f"Extracted {len(all_tensors)} tensors")

    # Step 2: Split into shards by size (greedy first-fit over sorted keys,
    # so shard assignment is deterministic across runs).
    shards = []
    current_shard = {}
    current_size = 0

    for key in sorted(all_tensors.keys()):
        tensor = all_tensors[key]
        tensor_size = tensor.nelement() * tensor.element_size()

        # Start a new shard when adding this tensor would exceed the cap.
        # A single tensor larger than MAX_SHARD_SIZE still gets its own shard.
        if current_shard and current_size + tensor_size > MAX_SHARD_SIZE:
            shards.append(current_shard)
            current_shard = {}
            current_size = 0

        current_shard[key] = tensor
        current_size += tensor_size

    if current_shard:
        shards.append(current_shard)

    print(f"Splitting into {len(shards)} shards")

    # Step 3: Save each shard and build the key -> shard-file map.
    total_shards = len(shards)
    weight_map = {}

    for i, shard in enumerate(shards):
        filename = f"model-{i+1:05d}-of-{total_shards:05d}.safetensors"
        filepath = os.path.join(DST_DIR, filename)
        shard_size = sum(t.nelement() * t.element_size() for t in shard.values())
        # BUG FIX: the progress message previously printed the literal text
        # "(unknown)" instead of the shard filename.
        print(f"Saving {filename} ({shard_size / 1e9:.2f} GB, {len(shard)} tensors)...")
        save_file(shard, filepath)

        for key in shard:
            weight_map[key] = filename

    # Step 4: Write the index file (HF transformers sharded-checkpoint format).
    index = {
        "metadata": {"total_size": sum(t.nelement() * t.element_size() for t in all_tensors.values())},
        "weight_map": weight_map,
    }
    index_path = os.path.join(DST_DIR, "model.safetensors.index.json")
    with open(index_path, "w") as f:
        json.dump(index, f, indent=2)

    print(f"Done! Index written to {index_path}")


# Script entry point: run the extraction only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()