# LexaLCM_Pre0/Patches/Patch_TorchLoader.py
# Author: Lexa (commit b5a0bec)
# Converted .pt files to safetensors, then (dirtily) patched fairseq2 to
# enable loading of safetensors files.
# Patch for fairseq2.utils.file.load_tensors
#
# This patch allows for loading safetensors files
#
# It is used in the two_tower_diffusion_lcm model loader:
# ./lcm/models/two_tower_diffusion_lcm/loader.py
from __future__ import annotations
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, Mapping, Optional, Protocol, Union
from warnings import catch_warnings
import torch
from torch import Tensor
from typing_extensions import TypeAlias
from fairseq2.typing import Device
from safetensors.torch import load_file
# Accepted forms of the ``map_location`` argument for tensor loaders: a
# per-tensor remapping callable, a device, a device string, or a
# location-remap dict — mirroring the ``map_location`` parameter of
# ``torch.load``.
MapLocation: TypeAlias = Optional[
    Union[Callable[[Tensor, str], Tensor], Device, str, Dict[str, str]]
]
class TensorLoader(Protocol):
    """Loads tensors from files."""

    def __call__(
        self,
        path: Path,
        *,
        map_location: MapLocation = None,
        restrict: bool = False,
    ) -> Dict[str, Any]:
        """
        :param path:
            The path to the file.
        :param map_location:
            Same as the ``map_location`` parameter of :func:`torch.load`.
        :param restrict:
            If ``True``, restricts the unpickler to load only tensors,
            primitive types, and dictionaries (``weights_only`` loading).
        """
        ...
class TensorDumper(Protocol):
    """Writes a mapping of tensors (plus auxiliary data) out to a file."""

    def __call__(self, data: Mapping[str, Any], path: Path) -> None:
        """Serialize ``data`` to the file at ``path``.

        :param data:
            Mapping of names to tensors and other auxiliary entries.
        :param path:
            Destination file path.
        """
        ...
def load_tensors(
    path: Path,
    *,
    map_location: MapLocation = None,
    restrict: bool = False,
) -> Dict[str, Any]:
    """Load a checkpoint in ``.pt`` or ``.safetensors`` format.

    Drop-in replacement for ``fairseq2.utils.file.load_tensors`` that adds
    safetensors support; the signature matches the :class:`TensorLoader`
    protocol above.

    :param path:
        The path to the checkpoint file.
    :param map_location:
        Same as the ``map_location`` parameter of :func:`torch.load`.
    :param restrict:
        If ``True``, passed to :func:`torch.load` as ``weights_only`` to
        restrict unpickling to tensors and primitive types (``.pt`` only).
    :returns:
        The loaded checkpoint dictionary.
    """
    if str(path).endswith(".safetensors"):
        # NOTE(review): str(map_location) only yields a valid device string
        # for a device/str map_location; a dict or callable (both allowed by
        # MapLocation) would produce garbage here — confirm callers only
        # pass devices.
        device = str(map_location) if map_location else "cpu"
        tensors = load_file(str(path), device=device)
        # safetensors stores a flat name->tensor mapping; wrap it under
        # "model" so callers see the same structure as a .pt checkpoint.
        return {"model": tensors}

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # Suppress torch.load pickle warnings.
        return torch.load(
            str(path), map_location, weights_only=restrict  # type: ignore[arg-type]
        )
def dump_tensors(data: Mapping[str, Any], path: Path) -> None:
    """Serialize ``data`` to ``path`` as a PyTorch tensor file.

    :param data:
        The dictionary containing tensors and other auxiliary data.
    :param path:
        The path to the file.
    """
    # torch.save can emit noisy FSDP warnings; silence them for this call.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        torch.save(data, path)