#!/usr/bin/env python3
# Coded with help from Grok, after OpenGPT and Gemini failed several times.
"""

Generate model.safetensors.index.json for modern HuggingFace sharded models

Works when:

- Shards have no tensor names

- Shards have no metadata

- Only raw binary data + external index expected

"""

import json
import argparse
from pathlib import Path
from safetensors import safe_open

def generate_index(folder_path: str, output_file: str = "model.safetensors.index.json") -> None:
    """Generate a HuggingFace ``model.safetensors.index.json`` for sharded weights.

    Reads each shard's raw safetensors header (an 8-byte little-endian length
    prefix followed by a JSON table of ``name -> {dtype, shape, data_offsets}``),
    so it works even when the safetensors library reports no keys or metadata.
    Tensor byte sizes come straight from ``data_offsets`` — no tensor data is
    ever loaded into memory.

    Args:
        folder_path: Directory containing model-XXXXX-of-YYYYY.safetensors files.
        output_file: Name of the index file written into that directory.

    Raises:
        ValueError: If the folder is missing or holds no sharded files.
        RuntimeError: If no tensors could be mapped from any shard.
    """
    folder = Path(folder_path)
    if not folder.is_dir():
        raise ValueError(f"Folder not found: {folder_path}")

    # Find all shards: model-00001-of-00004.safetensors style
    shards = sorted(
        f for f in folder.glob("*.safetensors")
        if f.name.startswith("model-") and "-of-" in f.name
    )
    if not shards:
        raise ValueError("No sharded model-*.safetensors files found!")

    print(f"Found {len(shards)} shards:")
    for s in shards:
        print(f"  - {s.name}")

    weight_map = {}
    total_size = 0

    for shard in shards:
        print(f"Scanning {shard.name} ...")
        for tensor_name, desc in _read_shard_header(shard).items():
            # "__metadata__" (note trailing underscores) is format-level
            # metadata, not a tensor; the original compared against
            # "__metadata" and so mapped it as a bogus weight.
            if tensor_name == "__metadata__":
                continue
            if tensor_name in weight_map:
                print(f"  Warning: duplicate tensor {tensor_name}")
            weight_map[tensor_name] = shard.name
            # Exact byte length from the header's data offsets.
            start, end = desc["data_offsets"]
            total_size += end - start

    if not weight_map:
        raise RuntimeError("No tensors found in any shard! The files might be corrupted.")

    # Final index in the standard HF layout: metadata.total_size + weight_map.
    index = {
        "metadata": {
            "total_size": total_size
        },
        "weight_map": weight_map
    }

    output_path = folder / output_file
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(index, f, indent=4)

    print(f"\nSUCCESS! Generated {output_file}")
    print(f"   Tensors mapped: {len(weight_map)}")
    # True division: the original used // which truncated to whole GB
    # despite the .2f format (always printed X.00).
    print(f"   Total size: {total_size / 1_073_741_824:.2f} GB")
    print(f"   Saved to: {output_path}\n")


def _read_shard_header(path: Path) -> dict:
    """Return the parsed JSON header of a safetensors file (no tensor data read)."""
    with open(path, "rb") as f:
        header_size = int.from_bytes(f.read(8), "little")
        return json.loads(f.read(header_size))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate model.safetensors.index.json (works 100% with modern HF shards)")
    parser.add_argument("folder", help="Path to folder containing model-*-of-*.safetensors")
    parser.add_argument("--output", default="model.safetensors.index.json", help="Output filename")

    args = parser.parse_args()
    generate_index(args.folder, args.output)