| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from abc import ABC, abstractmethod |
| | from typing import Dict, Optional, Sequence |
| |
|
| | import safetensors |
| | import torch |
| |
|
| | from mergekit.io.lazy_unpickle import DeferredLoad, TorchArchiveReader, torch_lazy_load |
| |
|
| |
|
class TensorLoader(ABC):
    """Abstract interface for objects that serve tensors out of one shard file.

    Implementations may read tensors eagerly or defer the work until a
    tensor is requested. Use :meth:`get` to obtain a loader suited to a
    given shard path.
    """

    @abstractmethod
    def get_tensor(self, key: str) -> torch.Tensor:
        """Return the tensor stored under ``key``."""

    @abstractmethod
    def keys(self) -> Sequence[str]:
        """Return the names of all tensors available from this loader."""

    @classmethod
    def get(
        cls,
        shard_path: str,
        use_lazy_unpickle: bool = False,
        device: Optional[str] = None,
    ) -> "TensorLoader":
        """Build a loader for the shard at ``shard_path``.

        Args:
            shard_path: Path of the shard file. A case-insensitive
                ``.safetensors`` suffix selects the safetensors reader.
            use_lazy_unpickle: For any other file, choose
                ``LazyPickleLoader`` when true and ``DumbPytorchLoader``
                otherwise.
            device: Device string handed to the selected loader. The
                safetensors reader is given ``"cpu"`` when this is ``None``
                or empty.

        Returns:
            An object exposing ``get_tensor`` and ``keys``.
        """
        is_safetensors = shard_path.lower().endswith(".safetensors")
        if not is_safetensors:
            loader_type = LazyPickleLoader if use_lazy_unpickle else DumbPytorchLoader
            return loader_type(shard_path, device=device)

        # The safetensors handle is returned as-is: it is not a TensorLoader
        # subclass, it is only used through its get_tensor()/keys() methods.
        target_device = device or "cpu"
        return safetensors.safe_open(shard_path, framework="pt", device=target_device)
| |
|
| |
|
class LazyPickleLoader(TensorLoader):
    """Serves tensors from a pickled PyTorch file, one tensor at a time.

    The file is passed through ``torch.load`` inside the ``torch_lazy_load``
    context; the result is kept as an index of ``DeferredLoad`` entries
    (per the attribute annotation below) that are only executed when a
    tensor is requested.
    """

    zip_reader: TorchArchiveReader
    index: Dict[str, DeferredLoad]
    device: Optional[str] = None

    def __init__(self, path: str, device: Optional[str] = None):
        """Open the archive at ``path`` and build the key index.

        Args:
            path: Location of the PyTorch checkpoint file.
            device: Forwarded as ``map_location`` whenever a tensor is
                loaded; ``None`` leaves the choice to the deferred load.
        """
        self.device = device
        self.zip_reader = TorchArchiveReader(path)
        # torch.load runs under torch_lazy_load(), so what comes back is
        # stored as the index rather than treated as loaded tensors.
        with torch_lazy_load():
            self.index = torch.load(path)

    def get_tensor(self, key: str) -> torch.Tensor:
        """Execute the deferred load registered under ``key``.

        Raises:
            KeyError: If ``key`` is not present in the index.
        """
        if key in self.index:
            deferred = self.index[key]
            return deferred.execute(self.zip_reader, map_location=self.device)
        raise KeyError(key)

    def keys(self) -> Sequence[str]:
        """Return the key view of the index."""
        return self.index.keys()
| |
|
| |
|
class DumbPytorchLoader(TensorLoader):
    """Naive `torch.load` shard loading.

    The whole shard is read in the constructor and kept in ``tensors``;
    lookups afterwards are plain dictionary accesses.
    """

    # Mapping of tensor name -> tensor, as returned by torch.load.
    tensors: Dict[str, torch.Tensor]

    def __init__(self, path: str, device: Optional[str] = None):
        """Load every tensor in the file at ``path``.

        Args:
            path: Location of the PyTorch checkpoint file.
            device: Passed to ``torch.load`` as ``map_location``. Optional,
                matching ``LazyPickleLoader`` and ``TensorLoader.get``, which
                may pass ``None``.
        """
        # weights_only=True is kept from the original call; see the
        # torch.load documentation for what it permits during unpickling.
        self.tensors = torch.load(path, map_location=device, weights_only=True)

    def get_tensor(self, key: str) -> torch.Tensor:
        """Return the tensor stored under ``key``.

        Raises:
            KeyError: If ``key`` is not among the loaded tensors.
        """
        return self.tensors[key]

    def keys(self) -> Sequence[str]:
        """Return the key view of the loaded tensor mapping."""
        return self.tensors.keys()
| |
|