# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#
from fairseq2.models.config_loader import StandardModelConfigLoader
from fairseq2.models.loader import StandardModelLoader, load_model
from Patches import Patch_TorchLoader
from lcm.models.base_lcm.loader import convert_lcm_checkpoint
from lcm.models.two_tower_diffusion_lcm.builder import (
TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE,
TwoTowerDiffusionLCModelConfig,
create_two_tower_diffusion_lcm_model,
lcm_archs,
)
from lcm.utils.model_type_registry import ModelTypeConfig, lcm_model_type_registry
# Config loader for the two-tower diffusion LCM family: resolves named
# architectures registered in ``lcm_archs`` into a
# ``TwoTowerDiffusionLCModelConfig`` instance.
_config_loader_kwargs = dict(
    family=TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE,
    config_kls=TwoTowerDiffusionLCModelConfig,
    arch_configs=lcm_archs,
)
load_two_tower_diffusion_lcm_config = StandardModelConfigLoader(
    **_config_loader_kwargs
)
# Model loader for the two-tower diffusion LCM family.  The stock tensor
# loader is swapped for ``Patch_TorchLoader.load_tensors`` — presumably a
# workaround for a ``torch.load`` restriction; see the Patches module.
# NOTE(review): ``restrict_checkpoints=False`` relaxes checkpoint loading;
# confirm the checkpoints consumed here come from a trusted source.
_model_loader_kwargs = dict(
    config_loader=load_two_tower_diffusion_lcm_config,
    factory=create_two_tower_diffusion_lcm_model,
    checkpoint_converter=convert_lcm_checkpoint,
    restrict_checkpoints=False,
    tensor_loader=Patch_TorchLoader.load_tensors,  # the key patch
)
load_two_tower_diffusion_lcm_model = StandardModelLoader(**_model_loader_kwargs)
# Wire the two-tower diffusion LCM into both lookup points so that the
# generic ``load_model`` entry point and the LCM model-type registry can
# both resolve this family by its model-type string.
load_model.register(
    TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE, load_two_tower_diffusion_lcm_model
)

_model_type_config = ModelTypeConfig(
    model_type=TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE,
    config_loader=load_two_tower_diffusion_lcm_config,
    model_factory=create_two_tower_diffusion_lcm_model,
    model_loader=load_two_tower_diffusion_lcm_model,
)
lcm_model_type_registry.register(_model_type_config)
|