| |
|
|
| import glob |
| import json |
| import os |
|
|
| from safetensors import safe_open |
|
|
|
|
def _read_safetensors_header(path):
    """Return the parsed JSON header of the ``.safetensors`` file at *path*.

    A safetensors file begins with an 8-byte little-endian unsigned integer
    giving the byte length of the JSON header that immediately follows it.
    Only that header is read; tensor data is never loaded.

    Raises:
        ValueError: if the file is too short, the declared header length
            exceeds the file, or the header is not a JSON object.
    """
    import struct

    with open(path, "rb") as fh:
        file_size = os.fstat(fh.fileno()).st_size
        prefix = fh.read(8)
        if len(prefix) < 8:
            raise ValueError(f"{path}: too short to be a safetensors file")
        (header_len,) = struct.unpack("<Q", prefix)
        # Guard against a corrupt length field before allocating a huge buffer.
        if header_len > file_size - 8:
            raise ValueError(f"{path}: declared header length exceeds file size")
        header = json.loads(fh.read(header_len))
    if not isinstance(header, dict):
        raise ValueError(f"{path}: safetensors header is not a JSON object")
    return header


def main(directory="."):
    """Write ``model.safetensors.index.json`` for the shards in *directory*.

    Shards are the files matching ``model-*-of-*.safetensors``. They are
    processed in sorted order, and the tensor names in each shard are sorted.
    The index maps every tensor name to the basename of the shard holding it.

    ``metadata.total_size`` is the number of tensor data bytes: the sum of
    every tensor's ``data_offsets`` span. Per-file headers are excluded, so it
    is not the on-disk file size.

    Args:
        directory: Folder containing the shards; the index is written there too.
            Defaults to the current working directory.

    Returns:
        The index dict that was written to disk.

    Raises:
        FileNotFoundError: if no shard files match in *directory*.
        ValueError: if a shard is malformed or a tensor name appears in
            more than one shard.
    """
    pattern = os.path.join(glob.escape(directory), "model-*-of-*.safetensors")
    shard_files = sorted(glob.glob(pattern))
    if not shard_files:
        # Writing an empty index would silently produce an unloadable checkpoint.
        raise FileNotFoundError(
            f"no files matching 'model-*-of-*.safetensors' in {directory!r}"
        )

    total_size = 0
    weight_map = {}
    for shard_file in shard_files:
        shard_name = os.path.basename(shard_file)
        header = _read_safetensors_header(shard_file)
        for tensor_name in sorted(header):
            # "__metadata__" is the optional free-form metadata entry, not a tensor.
            if tensor_name == "__metadata__":
                continue
            if tensor_name in weight_map:
                raise ValueError(
                    f"tensor {tensor_name!r} appears in both "
                    f"{weight_map[tensor_name]} and {shard_name}"
                )
            begin, end = header[tensor_name]["data_offsets"]
            total_size += end - begin
            weight_map[tensor_name] = shard_name

    output_dict = {"metadata": {"total_size": total_size}, "weight_map": weight_map}

    index_path = os.path.join(directory, "model.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as out_file:
        json.dump(output_dict, out_file, indent=2)

    print("Created model.safetensors.index.json with total size =", total_size, "bytes.")
    return output_dict


if __name__ == "__main__":
    main()
|
|