File size: 1,460 Bytes
b5a0bec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#
from dataclasses import dataclass, field
from typing import Union
from lcm.models.two_tower_diffusion_lcm.builder import TwoTowerDiffusionLCModelConfig
from lcm.models.two_tower_diffusion_lcm.loader import (
load_two_tower_diffusion_lcm_model,
)
from lcm.train.lcm.trainer import LCMTrainer, LCMTrainerBuilder, LCMTrainingConfig
from lcm.train.two_tower_diffusion_lcm.criterion import (
TowerDiffusionLCMCriterionConfig,
)
@dataclass
class TwoTowerDiffusionLCMTrainingConfig(LCMTrainingConfig):
    """Training configuration for the two-tower diffusion LCM.

    Builds on ``LCMTrainingConfig`` and re-declares the model and criterion
    fields with two-tower-diffusion-specific types.
    """

    model_config_or_name: Union[TwoTowerDiffusionLCModelConfig, str, None] = None
    """The model configuration or name to train."""

    # Criterion configuration used during training. The ``type: ignore``
    # suggests this re-declares a base-class field with a different type
    # (NOTE(review): confirm against LCMTrainingConfig). The class itself is
    # passed as the factory, so each config instance gets its own freshly
    # constructed default rather than a shared one.
    criterion: TowerDiffusionLCMCriterionConfig = field(  # type: ignore
        default_factory=TowerDiffusionLCMCriterionConfig
    )
class DiffusionLCMTrainerBuilder(LCMTrainerBuilder):
    """Trainer builder for the two-tower diffusion LCM.

    Inherits everything from ``LCMTrainerBuilder`` and, within this class,
    only replaces the callable used to load the model.
    """

    # Class-level annotation so that ``self.config`` is presumably seen by
    # type checkers as the two-tower training config, not the base type.
    config: TwoTowerDiffusionLCMTrainingConfig

    def __init__(self, config: TwoTowerDiffusionLCMTrainingConfig):
        # All initialisation is delegated to the base builder; the override
        # only narrows the annotated type of the ``config`` argument.
        super().__init__(config)

    @property
    def model_loader(self):
        """A fairseq2 ModelLoader for two-tower diffusion LCM models."""
        loader = load_two_tower_diffusion_lcm_model
        return loader
def prepare_two_tower_diffusion_lcm_trainer(
    config: TwoTowerDiffusionLCMTrainingConfig,
) -> LCMTrainer:
    """Build an ``LCMTrainer`` for the two-tower diffusion LCM.

    :param config: The training configuration.
    :returns: The trainer produced by ``DiffusionLCMTrainerBuilder.build_trainer``.
    """
    builder = DiffusionLCMTrainerBuilder(config)
    trainer = builder.build_trainer()
    return trainer
|