import json
import os

import torch


def shard_state_dict(state_dict, output_dir="."):
    """Split *state_dict* into one ``.bin`` file per tensor and build an index.

    Each tensor is saved as ``pytorch_model-NNNNN-of-TTTTT.bin`` inside
    *output_dir*, mirroring the Hugging Face sharded-checkpoint layout.

    Args:
        state_dict: mapping of parameter name -> tensor (as returned by
            ``model.state_dict()``).
        output_dir: directory the chunk files are written to (default: cwd).

    Returns:
        dict with the HF index structure: ``metadata.total_size`` is the sum
        of tensor byte sizes, ``weight_map`` maps each parameter name to the
        chunk file that contains it.
    """
    index = {"metadata": {"total_size": 0}, "weight_map": {}}
    total_files = len(state_dict)
    for i, (key, tensor) in enumerate(state_dict.items(), start=1):
        chunk_file = f"pytorch_model-{i:05d}-of-{total_files:05d}.bin"
        torch.save({key: tensor}, os.path.join(output_dir, chunk_file))
        index["weight_map"][key] = chunk_file
        # nelement * element_size = tensor size in bytes
        index["metadata"]["total_size"] += tensor.nelement() * tensor.element_size()
    return index


def main():
    """Load the merged checkpoint and re-save it as per-tensor shards."""
    # Project-local model class; path is specific to this machine.
    model = SomeLargeModel('/mnt/e/ai_cache/output/wizardcoder_mmlu_2/merged')
    # map_location='cpu' so the script works regardless of where the
    # checkpoint's tensors originally lived (e.g. GPU).
    model.load_state_dict(torch.load('pytorch_model.bin', map_location='cpu'))

    index = shard_state_dict(model.state_dict())

    # HF loaders look for 'pytorch_model.bin.index.json' — the original
    # wrote 'pytorch_model.bin.index', which transformers would not find.
    with open('pytorch_model.bin.index.json', 'w') as f:
        json.dump(index, f)


if __name__ == '__main__':
    main()