|
|
import json
from pathlib import Path

from safetensors import safe_open
|
|
|
|
|
def generate_safetensors_index(model_path=".", safetensors_files=None):
    """Generate ``model.safetensors.index.json`` from existing safetensors shards.

    Reads ``pytorch_model.bin.index.json`` from *model_path* and copies its
    ``metadata`` section. It then builds a ``weight_map`` by listing the
    tensor names stored in each safetensors shard. The resulting index is
    written to ``<model_path>/model.safetensors.index.json``.

    Processing is best-effort: a shard that cannot be opened is reported and
    skipped, so the written index may be incomplete. A warning is printed in
    that case, and whenever tensors listed in the PyTorch index are absent
    from the new one.

    Args:
        model_path: Directory containing the checkpoint files. All input and
            output files are resolved relative to it.
        safetensors_files: Optional explicit list of shard file names,
            relative to *model_path*. When ``None``, the names are derived
            from the ``.bin`` shards referenced by the PyTorch index, with
            the suffix swapped to ``.safetensors``. If that index references
            no shards, every ``*.safetensors`` file in *model_path* is used,
            in sorted order.

    Returns:
        The index dict that was written. It has the keys ``"metadata"`` and
        ``"weight_map"``, where ``weight_map`` maps each tensor name to the
        shard file name that holds it.

    Raises:
        FileNotFoundError: If ``pytorch_model.bin.index.json`` does not exist
            in *model_path*.
        json.JSONDecodeError: If that file is not valid JSON.
    """
    root = Path(model_path)

    with open(root / "pytorch_model.bin.index.json", "r", encoding="utf-8") as f:
        bin_index = json.load(f)
    bin_weight_map = bin_index.get("weight_map", {})

    if safetensors_files is None:
        # Mirror the sharding of the PyTorch checkpoint: each distinct .bin
        # shard is expected to have a same-named .safetensors counterpart.
        safetensors_files = sorted(
            {
                Path(name).with_suffix(".safetensors").as_posix()
                for name in bin_weight_map.values()
            }
        )
        if not safetensors_files:
            safetensors_files = sorted(p.name for p in root.glob("*.safetensors"))

    weight_map = {}
    failed = []
    for shard in safetensors_files:
        try:
            with safe_open(str(root / shard), framework="pt") as f:
                tensor_names = list(f.keys())
        except Exception as e:  # best-effort: report and move on to the next shard
            print(f"✗ Error processing {shard}: {e}")
            failed.append(shard)
            continue
        for tensor_name in tensor_names:
            # Index values are file names relative to the index file itself.
            weight_map[tensor_name] = shard
        print(f"✓ Processed {shard}")

    if failed:
        print(f"⚠ {len(failed)} shard(s) could not be read; index is incomplete: {failed}")
    missing = set(bin_weight_map) - weight_map.keys()
    if missing:
        print(
            f"⚠ {len(missing)} tensor(s) listed in pytorch_model.bin.index.json "
            f"are missing from the safetensors index"
        )

    safetensors_index = {
        "metadata": bin_index.get("metadata", {}),
        "weight_map": weight_map,
    }

    out_path = root / "model.safetensors.index.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(safetensors_index, f, indent=2)

    print(f"✓ Generated {out_path} with {len(weight_map)} tensors")
    return safetensors_index
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow the model directory to be given on the command line; the previous
    # hard-coded directory remains the default.
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate model.safetensors.index.json from existing safetensors shards."
    )
    parser.add_argument(
        "model_path",
        nargs="?",
        default="./Finance-Llama-8B",
        help=(
            "Directory containing pytorch_model.bin.index.json and the "
            ".safetensors shards (default: %(default)s)"
        ),
    )
    args = parser.parse_args()
    generate_safetensors_index(args.model_path)