# Lexa
# Converted .pt files to safetensors, then (dirtily) patched fairseq to enable loading of safetensors files
# b5a0bec
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#
from dataclasses import dataclass, field
from typing import Union
from lcm.models.two_tower_diffusion_lcm.builder import TwoTowerDiffusionLCModelConfig
from lcm.models.two_tower_diffusion_lcm.loader import (
load_two_tower_diffusion_lcm_model,
)
from lcm.train.lcm.trainer import LCMTrainer, LCMTrainerBuilder, LCMTrainingConfig
from lcm.train.two_tower_diffusion_lcm.criterion import (
TowerDiffusionLCMCriterionConfig,
)
@dataclass
class TwoTowerDiffusionLCMTrainingConfig(LCMTrainingConfig):
    """Training configuration for the two-tower diffusion LCM.

    Subclass of ``LCMTrainingConfig`` that declares which model to train and
    which criterion configuration to use.
    """

    model_config_or_name: Union[TwoTowerDiffusionLCModelConfig, str, None] = None
    """The model configuration or name to train."""

    # Each instance gets its own criterion config, built by calling the
    # config class with no arguments.
    criterion: TowerDiffusionLCMCriterionConfig = field(  # type: ignore
        default_factory=TowerDiffusionLCMCriterionConfig
    )
class DiffusionLCMTrainerBuilder(LCMTrainerBuilder):
    """Trainer builder specialised for the two-tower diffusion LCM.

    Narrows the ``config`` attribute to ``TwoTowerDiffusionLCMTrainingConfig``
    and supplies the matching model loader.
    """

    config: TwoTowerDiffusionLCMTrainingConfig

    def __init__(self, config: TwoTowerDiffusionLCMTrainingConfig) -> None:
        """Forward ``config`` unchanged to the parent builder.

        :param config: The training configuration.
        """
        super().__init__(config)

    @property
    def model_loader(self):
        """A fairseq2 ModelLoader"""
        loader = load_two_tower_diffusion_lcm_model
        return loader
def prepare_two_tower_diffusion_lcm_trainer(
    config: TwoTowerDiffusionLCMTrainingConfig,
) -> LCMTrainer:
    """Build an LCM trainer from a two-tower diffusion training config.

    Wraps ``config`` in a ``DiffusionLCMTrainerBuilder`` and returns the
    result of its ``build_trainer()`` call.

    :param config: The training configuration.
    :return: The trainer produced by the builder.
    """
    trainer_builder = DiffusionLCMTrainerBuilder(config)
    return trainer_builder.build_trainer()