Lexa
Converted .pt files to safetensors, then (dirtily) patched fairseq to enable loading of safetensor files
b5a0bec
| # Copyright (c) Meta Platforms, Inc. and affiliates | |
| # All rights reserved. | |
| # | |
| # | |
| from dataclasses import dataclass, field | |
| from typing import Union | |
| from lcm.models.two_tower_diffusion_lcm.builder import TwoTowerDiffusionLCModelConfig | |
| from lcm.models.two_tower_diffusion_lcm.loader import ( | |
| load_two_tower_diffusion_lcm_model, | |
| ) | |
| from lcm.train.lcm.trainer import LCMTrainer, LCMTrainerBuilder, LCMTrainingConfig | |
| from lcm.train.two_tower_diffusion_lcm.criterion import ( | |
| TowerDiffusionLCMCriterionConfig, | |
| ) | |
@dataclass
class TwoTowerDiffusionLCMTrainingConfig(LCMTrainingConfig):
    """Training configuration for the two-tower diffusion LCM.

    Extends :class:`LCMTrainingConfig` with the model to train and the
    configuration of the training criterion.
    """

    # NOTE: the ``@dataclass`` decorator is required here. Without it the
    # ``field(...)`` declaration below is never processed, so ``criterion``
    # would be a raw ``dataclasses.Field`` object instead of a criterion
    # config instance, and neither attribute would be a constructor argument.

    model_config_or_name: Union[TwoTowerDiffusionLCModelConfig, str, None] = None
    """The model configuration or name to train."""

    # A factory (rather than a shared default instance) gives every training
    # config its own criterion config object.
    criterion: TowerDiffusionLCMCriterionConfig = field(  # type: ignore
        default_factory=TowerDiffusionLCMCriterionConfig
    )
class DiffusionLCMTrainerBuilder(LCMTrainerBuilder):
    """Trainer builder specialised for the two-tower diffusion LCM."""

    # Narrows the annotated type of the builder's config attribute.
    config: TwoTowerDiffusionLCMTrainingConfig

    def __init__(self, config: TwoTowerDiffusionLCMTrainingConfig) -> None:
        """Initialise the builder by forwarding ``config`` to the base class.

        :param config: The two-tower diffusion LCM training configuration.
        """
        super().__init__(config)

    # NOTE(review): the docstring describes a value rather than an action; if
    # the base class exposes ``model_loader`` as a property, this override
    # needs ``@property`` as well -- confirm against ``LCMTrainerBuilder``.
    def model_loader(self):
        """A fairseq2 ModelLoader for the two-tower diffusion LCM."""
        loader = load_two_tower_diffusion_lcm_model
        return loader
def prepare_two_tower_diffusion_lcm_trainer(
    config: TwoTowerDiffusionLCMTrainingConfig,
) -> LCMTrainer:
    """Build an LCM trainer from a two-tower diffusion training config.

    :param config: The training configuration.
    :return: The trainer produced by
        ``DiffusionLCMTrainerBuilder(config).build_trainer()``.
    """
    builder = DiffusionLCMTrainerBuilder(config)
    trainer = builder.build_trainer()
    return trainer