| import os | |
| import tempfile | |
| import torch | |
| from mergekit.io import TensorWriter | |
class TestTensorWriter:
    """Smoke tests for ``TensorWriter``.

    Each test writes one or more tensors into a temporary directory,
    finalizes the writer, and asserts that the expected output files exist.
    """

    def test_safetensors(self):
        """With ``safe_serialization=True`` the safetensors shard and index exist."""
        expected_files = (
            "model-00001-of-00001.safetensors",
            "model.safetensors.index.json",
        )
        with tempfile.TemporaryDirectory() as out_dir:
            tensor_writer = TensorWriter(out_dir, safe_serialization=True)
            tensor_writer.save_tensor("steve", torch.randn(4))
            tensor_writer.finalize()

            for filename in expected_files:
                assert os.path.exists(os.path.join(out_dir, filename))

    def test_pickle(self):
        """With ``safe_serialization=False`` the ``.bin`` shard and index exist."""
        expected_files = (
            "pytorch_model-00001-of-00001.bin",
            "pytorch_model.bin.index.json",
        )
        with tempfile.TemporaryDirectory() as out_dir:
            tensor_writer = TensorWriter(out_dir, safe_serialization=False)
            tensor_writer.save_tensor("timothan", torch.randn(4))
            tensor_writer.finalize()

            for filename in expected_files:
                assert os.path.exists(os.path.join(out_dir, filename))

    def test_duplicate_tensor(self):
        """Saving the same tensor object under two names still yields both files."""
        expected_files = (
            "model-00001-of-00001.safetensors",
            "model.safetensors.index.json",
        )
        with tempfile.TemporaryDirectory() as out_dir:
            tensor_writer = TensorWriter(out_dir, safe_serialization=True)
            # One tensor object, registered under two distinct names.
            shared_tensor = torch.randn(4)
            for tensor_name in ("jim", "jimbo"):
                tensor_writer.save_tensor(tensor_name, shared_tensor)
            tensor_writer.finalize()

            for filename in expected_files:
                assert os.path.exists(os.path.join(out_dir, filename))