Lexa
Converted .pt files to safetensors, then (dirtily) patched fairseq2 to enable loading of safetensors files
b5a0bec
| # Patch for fairseq2.utils.file.load_tensors | |
| # | |
| # This patch allows for loading safetensors files | |
| # | |
| # It is used in the two_tower_diffusion_lcm model loader: | |
| # ./lcm/models/two_tower_diffusion_lcm/loader.py | |
| from __future__ import annotations | |
| import warnings | |
| from pathlib import Path | |
| from typing import Any, Callable, Dict, Mapping, Optional, Protocol, Union | |
| from warnings import catch_warnings | |
| import torch | |
| from torch import Tensor | |
| from typing_extensions import TypeAlias | |
| from fairseq2.typing import Device | |
| from safetensors.torch import load_file | |
# Accepted forms for a ``map_location`` argument in this module: ``None``, a
# callable taking a tensor and a string and returning a tensor, a device, a
# device string, or a string-to-string dictionary.
MapLocation: TypeAlias = Optional[
    Union[Callable[[Tensor, str], Tensor], Device, str, Dict[str, str]]
]
class TensorLoader(Protocol):
    """Loads tensors from files."""

    def __call__(
        self,
        path: Path,
        *,
        map_location: MapLocation = None,
        restrict: bool = False,
    ) -> Dict[str, Any]:
        """
        :param path:
            The path to the file.
        :param map_location:
            Same as the ``map_location`` parameter of :func:`torch.load`.
        :param restrict:
            If ``True``, asks the loader to restrict what may be
            deserialized from the file. :func:`load_tensors` in this module
            forwards the flag to :func:`torch.load` as ``weights_only``.

        :returns:
            A dictionary holding the loaded tensors and any auxiliary data.
        """
class TensorDumper(Protocol):
    """Dumps tensors to files."""

    def __call__(
        self,
        data: Mapping[str, Any],
        path: Path,
    ) -> None:
        """Write ``data`` to the file located at ``path``.

        :param data:
            The dictionary containing tensors and other auxiliary data.
        :param path:
            The path to the file.
        """
        ...
def load_tensors(
    path: Path,
    *,
    map_location: MapLocation = None,
    restrict: bool = False,
) -> Dict[str, Any]:
    """Load a checkpoint in .pt or .safetensors format.

    :param path:
        The path to the checkpoint file. A name ending in ``.safetensors``
        is read with ``safetensors.torch.load_file``; anything else is read
        with :func:`torch.load`.
    :param map_location:
        For ``.pt`` files, passed unchanged to :func:`torch.load`. For
        ``.safetensors`` files only ``None`` (meaning ``"cpu"``), a device
        string, or a :class:`torch.device` is accepted, because the value is
        handed to ``load_file`` as a single device specifier.
    :param restrict:
        Forwarded to :func:`torch.load` as ``weights_only``. Not used for
        ``.safetensors`` files.

    :returns:
        The loaded checkpoint dictionary. Safetensors content is returned as
        ``{"model": tensors}`` so that it has the same top-level layout the
        callers expect from a ``.pt`` checkpoint.

    :raises TypeError:
        If ``path`` is a ``.safetensors`` file and ``map_location`` is a
        callable or a dictionary.
    """
    if str(path).endswith(".safetensors"):
        if map_location is None:
            device = "cpu"
        elif isinstance(map_location, (str, torch.device)):
            # An empty string falls back to the CPU, like ``None`` does.
            device = str(map_location) or "cpu"
        else:
            # Stringifying a callable or a dict would yield a meaningless
            # device name and an obscure failure inside safetensors.
            raise TypeError(
                "`map_location` must be `None`, a device, or a device string "
                "when loading a .safetensors file, but is of type "
                f"`{type(map_location).__name__}` instead."
            )

        # Wrap the flat tensor mapping like a .pt checkpoint.
        return {"model": load_file(str(path), device=device)}

    # NOTE(review): with ``restrict=False`` this unpickles arbitrary Python
    # objects; only load .pt files from trusted sources.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # Silence warnings raised while loading.

        return torch.load(
            str(path), map_location=map_location, weights_only=restrict  # type: ignore[arg-type]
        )
def dump_tensors(data: Mapping[str, Any], path: Path) -> None:
    """Save ``data`` as a PyTorch tensor file at ``path``.

    :param data:
        The dictionary containing tensors and other auxiliary data.
    :param path:
        The path of the file to write.
    """
    with warnings.catch_warnings():
        # Suppress noisy FSDP warnings.
        warnings.simplefilter("ignore")

        torch.save(data, path)