# Patch for fairseq2.utils.file.load_tensors # # This patch allows for loading safetensors files # # It is used in the two_tower_diffusion_lcm model loader: # ./lcm/models/two_tower_diffusion_lcm/loader.py from __future__ import annotations import warnings from pathlib import Path from typing import Any, Callable, Dict, Mapping, Optional, Protocol, Union from warnings import catch_warnings import torch from torch import Tensor from typing_extensions import TypeAlias from fairseq2.typing import Device from safetensors.torch import load_file MapLocation: TypeAlias = Optional[ Union[Callable[[Tensor, str], Tensor], Device, str, Dict[str, str]] ] class TensorLoader(Protocol): """Loads tensors from files.""" def __call__( self, path: Path, *, map_location: MapLocation = None, restrict: bool = False, ) -> Dict[str, Any]: """ :param path: The path to the file. :param map_location: Same as the ``map_location`` parametload_two_tower_diffusion_lcm_model = StandardModelLoader( # type: ignore # FIXME config_loader=load_two_tower_diffusion_lcm_config, factory=create_two_tower_diffusion_lcm_model, checkpoint_converter=convert_lcm_checkpoint, restrict_checkpoints=False, ) """ class TensorDumper(Protocol): """Dumps tensors to files.""" def __call__(self, data: Mapping[str, Any], path: Path) -> None: """ :param data: The dictionary containing tensors and other auxiliary data. :param path: The path to the file. """ def load_tensors( path: Path, *, map_location=None, restrict: bool = False, ) -> Dict[str, Any]: """Load a checkpoint in .pt or .safetensors format.""" if str(path).endswith(".safetensors"): tensors = load_file(str(path), device=str(map_location) if map_location else "cpu") return {"model": tensors} # ✅ Wrap it like a .pt file with warnings.catch_warnings(): warnings.simplefilter("ignore") return torch.load( str(path), map_location, weights_only=restrict # type: ignore[arg-type] ) def dump_tensors(data: Mapping[str, Any], path: Path) -> None: """Dump ``data`` to a PyTorch tensor file under ``path``.""" with catch_warnings(): warnings.simplefilter("ignore") # Suppress noisy FSDP warnings. torch.save(data, path)