File size: 1,460 Bytes
b5a0bec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

from dataclasses import dataclass, field
from typing import Union

from lcm.models.two_tower_diffusion_lcm.builder import TwoTowerDiffusionLCModelConfig
from lcm.models.two_tower_diffusion_lcm.loader import (
    load_two_tower_diffusion_lcm_model,
)
from lcm.train.lcm.trainer import LCMTrainer, LCMTrainerBuilder, LCMTrainingConfig
from lcm.train.two_tower_diffusion_lcm.criterion import (
    TowerDiffusionLCMCriterionConfig,
)


@dataclass
class TwoTowerDiffusionLCMTrainingConfig(LCMTrainingConfig):
    """Training configuration for a two-tower diffusion LCM.

    Extends :class:`LCMTrainingConfig` by narrowing the model config to the
    two-tower diffusion variant and pairing it with the matching criterion
    config.
    """

    model_config_or_name: Union[TwoTowerDiffusionLCModelConfig, str, None] = None
    """The model configuration or name to train."""

    # The config class is itself a zero-argument callable, so pass it
    # directly as the factory instead of wrapping it in a redundant lambda.
    criterion: TowerDiffusionLCMCriterionConfig = field(  # type: ignore
        default_factory=TowerDiffusionLCMCriterionConfig
    )


class DiffusionLCMTrainerBuilder(LCMTrainerBuilder):
    """Trainer builder specialized for two-tower diffusion LCM models."""

    # Narrow the config type for static checkers; the base builder holds it.
    config: TwoTowerDiffusionLCMTrainingConfig

    def __init__(self, config: TwoTowerDiffusionLCMTrainingConfig):
        # Delegate all construction to the base builder; this override only
        # pins the accepted config type to the two-tower diffusion variant.
        super().__init__(config)

    @property
    def model_loader(self):
        """A fairseq2 ModelLoader"""
        return load_two_tower_diffusion_lcm_model


def prepare_two_tower_diffusion_lcm_trainer(
    config: TwoTowerDiffusionLCMTrainingConfig,
) -> LCMTrainer:
    """Create an LCM Trainer.
    :param config: The training configuration.
    """
    builder = DiffusionLCMTrainerBuilder(config)
    return builder.build_trainer()