File size: 2,470 Bytes
b5a0bec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
# Patch for fairseq2.utils.file.load_tensors
#
# This patch allows for loading safetensors files
#
# It is used in the two_tower_diffusion_lcm model loader:
# ./lcm/models/two_tower_diffusion_lcm/loader.py
from __future__ import annotations
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, Mapping, Optional, Protocol, Union
from warnings import catch_warnings
import torch
from torch import Tensor
from typing_extensions import TypeAlias
from fairseq2.typing import Device
from safetensors.torch import load_file
MapLocation: TypeAlias = Optional[
Union[Callable[[Tensor, str], Tensor], Device, str, Dict[str, str]]
]
class TensorLoader(Protocol):
"""Loads tensors from files."""
def __call__(
self,
path: Path,
*,
map_location: MapLocation = None,
restrict: bool = False,
) -> Dict[str, Any]:
"""
:param path:
The path to the file.
:param map_location:
Same as the ``map_location`` parametload_two_tower_diffusion_lcm_model = StandardModelLoader( # type: ignore # FIXME
config_loader=load_two_tower_diffusion_lcm_config,
factory=create_two_tower_diffusion_lcm_model,
checkpoint_converter=convert_lcm_checkpoint,
restrict_checkpoints=False,
)
"""
class TensorDumper(Protocol):
"""Dumps tensors to files."""
def __call__(self, data: Mapping[str, Any], path: Path) -> None:
"""
:param data:
The dictionary containing tensors and other auxiliary data.
:param path:
The path to the file.
"""
def load_tensors(
path: Path,
*,
map_location=None,
restrict: bool = False,
) -> Dict[str, Any]:
"""Load a checkpoint in .pt or .safetensors format."""
if str(path).endswith(".safetensors"):
tensors = load_file(str(path), device=str(map_location) if map_location else "cpu")
return {"model": tensors} # ✅ Wrap it like a .pt file
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return torch.load(
str(path), map_location, weights_only=restrict # type: ignore[arg-type]
)
def dump_tensors(data: Mapping[str, Any], path: Path) -> None:
"""Dump ``data`` to a PyTorch tensor file under ``path``."""
with catch_warnings():
warnings.simplefilter("ignore") # Suppress noisy FSDP warnings.
torch.save(data, path)
|