| | import math |
| | import os |
| | from contextlib import nullcontext |
| | from dataclasses import dataclass, field |
| | from pathlib import Path |
| | from typing import Callable, Optional |
| |
|
| | import lightning.pytorch as L |
| | import numpy as np |
| | import torch |
| | from megatron.core import parallel_state |
| | from megatron.core.dist_checkpointing.mapping import ShardedStateDict |
| | from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding |
| | from megatron.core.models.common.vision_module.vision_module import VisionModule |
| | from megatron.core.optimizer import OptimizerConfig |
| | from megatron.core.transformer.enums import ModelType |
| | from megatron.core.transformer.transformer_config import TransformerConfig |
| | from megatron.core.transformer.utils import openai_gelu, sharded_state_dict_default |
| | from safetensors.torch import load_file as load_safetensors |
| | from safetensors.torch import save_file as save_safetensors |
| | from torch import nn |
| | from torch.nn import functional as F |
| |
|
| | from nemo.collections.diffusion.encoders.conditioner import FrozenCLIPEmbedder, FrozenT5Embedder |
| | from nemo.collections.diffusion.models.dit.dit_layer_spec import ( |
| | AdaLNContinuous, |
| | FluxSingleTransformerBlock, |
| | MMDiTLayer, |
| | get_flux_double_transformer_engine_spec, |
| | get_flux_single_transformer_engine_spec, |
| | ) |
| | from nemo.collections.diffusion.models.flux.layers import EmbedND, MLPEmbedder, TimeStepEmbedder |
| | from nemo.collections.diffusion.sampler.flow_matching.flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler |
| | from nemo.collections.diffusion.utils.flux_ckpt_converter import ( |
| | _import_qkv, |
| | _import_qkv_bias, |
| | flux_transformer_converter, |
| | ) |
| | from nemo.collections.diffusion.vae.autoencoder import AutoEncoder, AutoEncoderConfig |
| | from nemo.collections.llm import fn |
| | from nemo.lightning import io, teardown |
| | from nemo.lightning.megatron_parallel import MaskedTokenLossReduction |
| | from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule |
| | from nemo.utils import logging |
| |
|
| |
|
| | |
| | def flux_data_step(dataloader_iter): |
| | """Fetch the next batch from the dataloader iterator and attach a unit loss mask.""" |
| | batch = next(dataloader_iter) |
| | if isinstance(batch, tuple) and len(batch) == 3: |
| | _batch = batch[0] |
| | else: |
| | _batch = batch |
| |
|
| | _batch['loss_mask'] = torch.Tensor([1.0]).cuda(non_blocking=True) |
| | return _batch |
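| | 
| | # Illustrative sketch (shapes are assumptions for the default config, not taken from the source): the dicts |
| | # this step passes through are the ones MegatronFluxModel.forward_step consumes, either precached or raw. |
| | #   precached: {'latents': (B, 16, h, w), 'prompt_embeds': (B, 512, 4096), |
| | #               'pooled_prompt_embeds': (B, 768), 'text_ids': (B, 512, 3)} |
| | #   raw:       {'images': (B, 3, H, W), 'txt': list[str]} |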
| |
|
| |
|
| | @dataclass |
| | class FluxConfig(TransformerConfig, io.IOMixin): |
| | """ |
| | Transformer configuration for the Flux model (double and single DiT blocks). |
| | """ |
| |
|
| | num_layers: int = 1 |
| | num_joint_layers: int = 19 |
| | num_single_layers: int = 38 |
| | hidden_size: int = 3072 |
| | num_attention_heads: int = 24 |
| | activation_func: Callable = openai_gelu |
| | add_qkv_bias: bool = True |
| | in_channels: int = 64 |
| | context_dim: int = 4096 |
| | model_channels: int = 256 |
| | patch_size: int = 1 |
| | guidance_embed: bool = False |
| | vec_in_dim: int = 768 |
| | rotary_interleaved: bool = True |
| | apply_rope_fusion: bool = False |
| | layernorm_epsilon: float = 1e-06 |
| | hidden_dropout: float = 0 |
| | attention_dropout: float = 0 |
| | use_cpu_initialization: bool = True |
| | gradient_accumulation_fusion: bool = False |
| | enable_cuda_graph: bool = False |
| | use_te_rng_tracker: bool = False |
| | cuda_graph_warmup_steps: int = 2 |
| |
|
| | guidance_scale: float = 3.5 |
| | data_step_fn: Callable = flux_data_step |
| | ckpt_path: Optional[str] = None |
| | load_dist_ckpt: bool = False |
| | do_convert_from_hf: bool = False |
| | save_converted_model_to: Optional[str] = None |
| |
|
| | def configure_model(self): |
| | """Instantiate and return the Flux transformer described by this config.""" |
| | model = Flux(config=self) |
| | return model |
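| | 
| | # Hedged example: a reduced config for smoke testing; the defaults above reproduce the full |
| | # 19 double / 38 single block FLUX transformer. hidden_size and num_attention_heads are chosen |
| | # so the per-head size stays 128, matching the fixed rotary axes_dim (16 + 56 + 56) in Flux.__init__. |
| | #   cfg = FluxConfig(num_joint_layers=1, num_single_layers=2, hidden_size=768, num_attention_heads=6) |
| | #   model = cfg.configure_model()  # returns a Flux (VisionModule) instance |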
| |
|
| |
|
| | @dataclass |
| | class T5Config: |
| | """ |
| | Configuration for the frozen T5 text encoder. |
| | """ |
| |
|
| | version: Optional[str] = field(default_factory=lambda: "google/t5-v1_1-xxl") |
| | max_length: Optional[int] = field(default_factory=lambda: 512) |
| | load_config_only: bool = False |
| |
|
| |
|
| | @dataclass |
| | class ClipConfig: |
| | """ |
| | Configuration for the frozen CLIP text encoder. |
| | """ |
| |
|
| | version: Optional[str] = field(default_factory=lambda: "openai/clip-vit-large-patch14") |
| | max_length: Optional[int] = field(default_factory=lambda: 77) |
| | always_return_pooled: Optional[bool] = field(default_factory=lambda: True) |
| |
|
| |
|
| | @dataclass |
| | class FluxModelParams: |
| | """ |
| | Bundle of sub-model configurations used to build a MegatronFluxModel. |
| | """ |
| |
|
| | flux_config: FluxConfig = field(default_factory=FluxConfig) |
| | vae_config: AutoEncoderConfig = field( |
| | default_factory=lambda: AutoEncoderConfig(ch_mult=[1, 2, 4, 4], attn_resolutions=[]) |
| | ) |
| | clip_params: ClipConfig = field(default_factory=ClipConfig) |
| | t5_params: T5Config = field(default_factory=T5Config) |
| |
|
| | scheduler_steps: int = 1000 |
| | device: str = 'cuda' |
| |
|
| |
|
| | |
| | class Flux(VisionModule): |
| | """ |
| | NeMo implementation of the Flux model, with the double (joint image-text) and single transformer |
| | blocks implemented with Megatron Core. |
| | |
| | Args: |
| | config (FluxConfig): Configuration object containing the necessary parameters for setting up the model, |
| | such as the number of channels, hidden size, attention heads, and more. |
| | |
| | Attributes: |
| | out_channels (int): The number of output channels for the model. |
| | hidden_size (int): The size of the hidden layers. |
| | num_attention_heads (int): The number of attention heads for the transformer. |
| | patch_size (int): The size of the image patches. |
| | in_channels (int): The number of input channels for the image. |
| | guidance_embed (bool): A flag to indicate if guidance embedding should be used. |
| | pos_embed (EmbedND): Position embedding layer for the model. |
| | img_embed (nn.Linear): Linear layer to embed image input into the hidden space. |
| | txt_embed (nn.Linear): Linear layer to embed text input into the hidden space. |
| | timestep_embedding (TimeStepEmbedder): Embedding layer for timesteps, used in generative models. |
| | vector_embedding (MLPEmbedder): MLP embedding for vector inputs. |
| | guidance_embedding (nn.Module or nn.Identity): Optional MLP embedding for guidance, or identity if not used. |
| | double_blocks (nn.ModuleList): A list of transformer blocks for the double block layers. |
| | single_blocks (nn.ModuleList): A list of transformer blocks for the single block layers. |
| | norm_out (AdaLNContinuous): Normalization layer for the output. |
| | proj_out (nn.Linear): Final linear layer for output projection. |
| | |
| | Methods: |
| | forward: Performs a forward pass through the network, processing images, text, timesteps, and guidance. |
| | load_from_pretrained: Loads model weights from a pretrained checkpoint, with optional support for Megatron |
| | distributed checkpoints and conversion from the Hugging Face format. |
| | |
| | """ |
| |
|
| | def __init__(self, config: FluxConfig): |
| | |
| | super().__init__(config) |
| | self.out_channels = config.in_channels |
| | self.hidden_size = config.hidden_size |
| | self.num_attention_heads = config.num_attention_heads |
| | self.patch_size = config.patch_size |
| | self.in_channels = config.in_channels |
| | self.guidance_embed = config.guidance_embed |
| |
|
| | self.pos_embed = EmbedND(dim=self.hidden_size, theta=10000, axes_dim=[16, 56, 56]) |
| | self.img_embed = nn.Linear(config.in_channels, self.hidden_size) |
| | self.txt_embed = nn.Linear(config.context_dim, self.hidden_size) |
| | self.timestep_embedding = TimeStepEmbedder(config.model_channels, self.hidden_size) |
| | self.vector_embedding = MLPEmbedder(in_dim=config.vec_in_dim, hidden_dim=self.hidden_size) |
| | # Always define guidance_embedding (Identity when guidance embedding is disabled), as documented above. |
| | self.guidance_embedding = ( |
| | MLPEmbedder(in_dim=config.model_channels, hidden_dim=self.hidden_size) |
| | if config.guidance_embed |
| | else nn.Identity() |
| | ) |
| |
|
| | self.double_blocks = nn.ModuleList( |
| | [ |
| | MMDiTLayer( |
| | config=config, |
| | submodules=get_flux_double_transformer_engine_spec().submodules, |
| | layer_number=i, |
| | context_pre_only=False, |
| | ) |
| | for i in range(config.num_joint_layers) |
| | ] |
| | ) |
| |
|
| | self.single_blocks = nn.ModuleList( |
| | [ |
| | FluxSingleTransformerBlock( |
| | config=config, |
| | submodules=get_flux_single_transformer_engine_spec().submodules, |
| | layer_number=i, |
| | ) |
| | for i in range(config.num_single_layers) |
| | ] |
| | ) |
| |
|
| | self.norm_out = AdaLNContinuous(config=config, conditioning_embedding_dim=self.hidden_size) |
| | self.proj_out = nn.Linear(self.hidden_size, self.patch_size * self.patch_size * self.out_channels, bias=True) |
| | if self.config.ckpt_path is not None: |
| | self.load_from_pretrained( |
| | self.config.ckpt_path, |
| | do_convert_from_hf=self.config.do_convert_from_hf, |
| | load_dist_ckpt=self.config.load_dist_ckpt, |
| | save_converted_model_to=self.config.save_converted_model_to, |
| | ) |
| |
|
| | def get_fp8_context(self): |
| | """Return a Transformer Engine fp8_autocast context if FP8 is enabled in the config, else a nullcontext.""" |
| | if not self.config.fp8: |
| | fp8_context = nullcontext() |
| | else: |
| | import transformer_engine |
| |
|
| | if self.config.fp8 == "e4m3": |
| | fp8_format = transformer_engine.common.recipe.Format.E4M3 |
| | elif self.config.fp8 == "hybrid": |
| | fp8_format = transformer_engine.common.recipe.Format.HYBRID |
| | else: |
| | raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") |
| |
|
| | fp8_recipe = transformer_engine.common.recipe.DelayedScaling( |
| | margin=self.config.fp8_margin, |
| | interval=self.config.fp8_interval, |
| | fp8_format=fp8_format, |
| | amax_compute_algo=self.config.fp8_amax_compute_algo, |
| | amax_history_len=self.config.fp8_amax_history_len, |
| | override_linear_precision=(False, False, not self.config.fp8_wgrad), |
| | ) |
| | fp8_group = None |
| | if parallel_state.model_parallel_is_initialized(): |
| | fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True) |
| | fp8_context = transformer_engine.pytorch.fp8_autocast( |
| | enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group |
| | ) |
| | return fp8_context |
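| | 
| | # Hedged example: FP8 execution is controlled entirely by fields inherited from Megatron's TransformerConfig |
| | # and read above; the values here are illustrative, not recommendations: |
| | #   cfg = FluxConfig(fp8="hybrid", fp8_margin=0, fp8_amax_history_len=1024, fp8_amax_compute_algo="max") |
| | # Any value of cfg.fp8 other than "e4m3" or "hybrid" raises the ValueError above. |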
| |
|
| | def forward( |
| | self, |
| | img: torch.Tensor, |
| | txt: torch.Tensor = None, |
| | y: torch.Tensor = None, |
| | timesteps: torch.LongTensor = None, |
| | img_ids: torch.Tensor = None, |
| | txt_ids: torch.Tensor = None, |
| | guidance: torch.Tensor = None, |
| | controlnet_double_block_samples: torch.Tensor = None, |
| | controlnet_single_block_samples: torch.Tensor = None, |
| | ): |
| | """ |
| | Forward pass through the model, processing image, text, and additional inputs like guidance and timesteps. |
| | |
| | Args: |
| | img (torch.Tensor): |
| | The image input tensor. |
| | txt (torch.Tensor, optional): |
| | The text input tensor (default is None). |
| | y (torch.Tensor, optional): |
| | The vector input for embedding (default is None). |
| | timesteps (torch.LongTensor, optional): |
| | The timestep input, typically used in generative models (default is None). |
| | img_ids (torch.Tensor, optional): |
| | Image IDs for positional encoding (default is None). |
| | txt_ids (torch.Tensor, optional): |
| | Text IDs for positional encoding (default is None). |
| | guidance (torch.Tensor, optional): |
| | Guidance input for conditioning (default is None). |
| | controlnet_double_block_samples (torch.Tensor, optional): |
| | Optional controlnet samples for double blocks (default is None). |
| | controlnet_single_block_samples (torch.Tensor, optional): |
| | Optional controlnet samples for single blocks (default is None). |
| | |
| | Returns: |
| | torch.Tensor: The final output tensor from the model after processing all inputs. |
| | """ |
| | hidden_states = self.img_embed(img) |
| | encoder_hidden_states = self.txt_embed(txt) |
| |
|
| | timesteps = timesteps.to(img.dtype) * 1000 |
| | vec_emb = self.timestep_embedding(timesteps) |
| |
|
| | if guidance is not None: |
| | vec_emb = vec_emb + self.guidance_embedding(self.timestep_embedding.time_proj(guidance * 1000)) |
| | vec_emb = vec_emb + self.vector_embedding(y) |
| |
|
| | ids = torch.cat((txt_ids, img_ids), dim=1) |
| | rotary_pos_emb = self.pos_embed(ids) |
| | for id_block, block in enumerate(self.double_blocks): |
| | with self.get_fp8_context(): |
| | hidden_states, encoder_hidden_states = block( |
| | hidden_states=hidden_states, |
| | encoder_hidden_states=encoder_hidden_states, |
| | rotary_pos_emb=rotary_pos_emb, |
| | emb=vec_emb, |
| | ) |
| |
|
| | if controlnet_double_block_samples is not None: |
| | interval_control = len(self.double_blocks) / len(controlnet_double_block_samples) |
| | interval_control = int(np.ceil(interval_control)) |
| | hidden_states = hidden_states + controlnet_double_block_samples[id_block // interval_control] |
| |
|
| | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0) |
| |
|
| | for id_block, block in enumerate(self.single_blocks): |
| | with self.get_fp8_context(): |
| | hidden_states, _ = block( |
| | hidden_states=hidden_states, |
| | rotary_pos_emb=rotary_pos_emb, |
| | emb=vec_emb, |
| | ) |
| |
|
| | if controlnet_single_block_samples is not None: |
| | interval_control = len(self.single_blocks) / len(controlnet_single_block_samples) |
| | interval_control = int(np.ceil(interval_control)) |
| | hidden_states = torch.cat( |
| | [ |
| | hidden_states[: encoder_hidden_states.shape[0]], |
| | hidden_states[encoder_hidden_states.shape[0] :] |
| | + controlnet_single_block_samples[id_block // interval_control], |
| | ] |
| | ) |
| |
|
| | hidden_states = hidden_states[encoder_hidden_states.shape[0] :, ...] |
| |
|
| | hidden_states = self.norm_out(hidden_states, vec_emb) |
| | output = self.proj_out(hidden_states) |
| |
|
| | return output |
| |
|
| | def load_from_pretrained( |
| | self, ckpt_path, do_convert_from_hf=False, save_converted_model_to=None, load_dist_ckpt=False |
| | ): |
| | """Load transformer weights from a HF checkpoint, a converted safetensors file, or a Megatron dist ckpt.""" |
| | if load_dist_ckpt: |
| | from megatron.core import dist_checkpointing |
| |
|
| | sharded_state_dict = dict(state_dict=self.sharded_state_dict(prefix="module.")) |
| | loaded_state_dict = dist_checkpointing.load( |
| | sharded_state_dict=sharded_state_dict, checkpoint_dir=ckpt_path |
| | ) |
| | ckpt = {k.removeprefix("module."): v for k, v in loaded_state_dict["state_dict"].items()} |
| | else: |
| | if do_convert_from_hf: |
| | ckpt = flux_transformer_converter(ckpt_path, self.config) |
| | if save_converted_model_to is not None: |
| | os.makedirs(save_converted_model_to, exist_ok=True) |
| | save_path = os.path.join(save_converted_model_to, 'nemo_flux_transformer.safetensors') |
| | save_safetensors(ckpt, save_path) |
| | logging.info(f'saving converted transformer checkpoint to {save_path}') |
| | else: |
| | ckpt = load_safetensors(ckpt_path) |
| | missing, unexpected = self.load_state_dict(ckpt, strict=False) |
| | missing = [k for k in missing if not k.endswith('_extra_state')] |
| | |
| | if len(missing) > 0: |
| | logging.info( |
| | f"The following keys are missing during checkpoint loading, " |
| | f"please check the ckpt provided or the image quality may be compromised.\n {missing}" |
| | ) |
| | logging.info(f"Found unexepected keys: \n {unexpected}") |
| | logging.info(f"Restored flux model weights from {ckpt_path}") |
| |
|
| | def sharded_state_dict(self, prefix='', sharded_offsets: tuple = (), metadata: dict = None) -> ShardedStateDict: |
| | """Build the sharded state dict, re-prefixing double/single block layers to their global layer numbers.""" |
| | sharded_state_dict = {} |
| | layer_prefix = f'{prefix}double_blocks.' |
| | for layer in self.double_blocks: |
| | offset = layer._get_layer_offset(self.config) |
| |
|
| | global_layer_offset = layer.layer_number |
| | state_dict_prefix = f'{layer_prefix}{global_layer_offset - offset}.' |
| | sharded_prefix = f'{layer_prefix}{global_layer_offset}.' |
| | sharded_pp_offset = [] |
| |
|
| | layer_sharded_state_dict = layer.sharded_state_dict(state_dict_prefix, sharded_pp_offset, metadata) |
| | replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) |
| |
|
| | sharded_state_dict.update(layer_sharded_state_dict) |
| |
|
| | layer_prefix = f'{prefix}single_blocks.' |
| | for layer in self.single_blocks: |
| | offset = layer._get_layer_offset(self.config) |
| |
|
| | global_layer_offset = layer.layer_number |
| | state_dict_prefix = f'{layer_prefix}{global_layer_offset - offset}.' |
| | sharded_prefix = f'{layer_prefix}{global_layer_offset}.' |
| | sharded_pp_offset = [] |
| |
|
| | layer_sharded_state_dict = layer.sharded_state_dict(state_dict_prefix, sharded_pp_offset, metadata) |
| | replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) |
| |
|
| | sharded_state_dict.update(layer_sharded_state_dict) |
| |
|
| | for name, module in self.named_children(): |
| | if not (module is self.single_blocks or module is self.double_blocks): |
| | sharded_state_dict.update( |
| | sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata) |
| | ) |
| | return sharded_state_dict |
| |
|
| |
|
| | class MegatronFluxModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin): |
| | ''' |
| | Lightning/Megatron wrapper around the Flux transformer, its VAE, and its text encoders. |
| | 
| | Args: |
| | flux_params (FluxModelParams): Parameters used to configure the Flux model and its sub-modules. |
| | optim (OptimizerModule, optional): Optimizer module; defaults to a MegatronOptimizerModule with lr=1e-4. |
| | ''' |
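| | 
| | # Usage sketch (assumes a NeMo/Lightning trainer is configured elsewhere; not taken from the source): |
| | #   params = FluxModelParams()            # default FluxConfig, AutoEncoderConfig, CLIP and T5 configs above |
| | #   model = MegatronFluxModel(params)     # optimizer defaults to MegatronOptimizerModule(OptimizerConfig(lr=1e-4)) |
| | #   # trainer.fit(model, datamodule=...)  # hypothetical call; the training loop lives outside this module |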
| |
|
| | def __init__( |
| | self, |
| | flux_params: FluxModelParams, |
| | optim: Optional[OptimizerModule] = None, |
| | ): |
| | |
| | self.params = flux_params |
| | self.config = flux_params.flux_config |
| | super().__init__() |
| | self._training_loss_reduction = None |
| | self._validation_loss_reduction = None |
| |
|
| | self.vae_config = self.params.vae_config |
| | self.clip_params = self.params.clip_params |
| | self.t5_params = self.params.t5_params |
| | self.optim = optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=False)) |
| | self.optim.connect(self) |
| | self.model_type = ModelType.encoder_or_decoder |
| | self.text_precached = self.t5_params is None or self.clip_params is None |
| | self.image_precached = self.vae_config is None |
| |
|
| | def configure_model(self): |
| | """Build the Flux transformer, VAE, scheduler, and text encoders, then freeze selected parameters.""" |
| | if not hasattr(self, "module"): |
| | self.module = self.config.configure_model() |
| | self.configure_vae(self.vae_config) |
| | self.configure_scheduler() |
| | self.configure_text_encoders(self.clip_params, self.t5_params) |
| | for name, param in self.module.named_parameters(): |
| | if self.config.num_single_layers == 0: |
| | if 'context' in name or 'added' in name: |
| | param.requires_grad = False |
| | |
| | |
| | if 'single_blocks' in name and 'self_attention.linear_proj.bias' in name: |
| | param.requires_grad = False |
| |
|
| | def configure_scheduler(self): |
| | |
| | self.scheduler = FlowMatchEulerDiscreteScheduler( |
| | num_train_timesteps=self.params.scheduler_steps, |
| | ) |
| |
|
| | def configure_vae(self, vae): |
| | |
| | if isinstance(vae, nn.Module): |
| | self.vae = vae.eval().cuda() |
| | self.vae_scale_factor = 2 ** (len(self.vae.params.ch_mult)) |
| | for param in self.vae.parameters(): |
| | param.requires_grad = False |
| | elif isinstance(vae, AutoEncoderConfig): |
| | self.vae = AutoEncoder(vae).eval().cuda() |
| | self.vae_scale_factor = 2 ** (len(vae.ch_mult)) |
| | for param in self.vae.parameters(): |
| | param.requires_grad = False |
| | else: |
| | logging.info("Vae not provided, assuming the image input is precached...") |
| | self.vae = None |
| | self.vae_scale_factor = 16 |
| |
|
| | def configure_text_encoders(self, clip, t5): |
| | |
| | if isinstance(clip, nn.Module): |
| | self.clip = clip |
| | elif isinstance(clip, ClipConfig): |
| | self.clip = FrozenCLIPEmbedder( |
| | version=self.clip_params.version, |
| | max_length=self.clip_params.max_length, |
| | always_return_pooled=self.clip_params.always_return_pooled, |
| | device=torch.cuda.current_device(), |
| | ) |
| | else: |
| | logging.info("CLIP encoder not provided, assuming the text embeddings is precached...") |
| | self.clip = None |
| |
|
| | if isinstance(t5, nn.Module): |
| | self.t5 = t5 |
| | elif isinstance(t5, T5Config): |
| | self.t5 = FrozenT5Embedder( |
| | self.t5_params.version, |
| | max_length=self.t5_params.max_length, |
| | device=torch.cuda.current_device(), |
| | load_config_only=self.t5_params.load_config_only, |
| | ) |
| | else: |
| | logging.info("T5 encoder not provided, assuming the text embeddings is precached...") |
| | self.t5 = None |
| |
|
| | |
| | def data_step(self, dataloader_iter): |
| | return self.config.data_step_fn(dataloader_iter) |
| |
|
| | def forward(self, *args, **kwargs): |
| | return self.module(*args, **kwargs) |
| |
|
| | def training_step(self, batch, batch_idx=None) -> torch.Tensor: |
| | """Run a single training step by delegating to forward_step.""" |
| | return self.forward_step(batch) |
| | 
| | def validation_step(self, batch, batch_idx=None) -> torch.Tensor: |
| | """Run a single validation step by delegating to forward_step.""" |
| | return self.forward_step(batch) |
| |
|
| | |
| | def forward_step(self, batch) -> torch.Tensor: |
| | """Prepare noisy latents and text conditioning, run the transformer, and return the flow-matching MSE loss.""" |
| | if self.optim.config.bf16: |
| | self.autocast_dtype = torch.bfloat16 |
| | elif self.optim.config.fp16: |
| | self.autocast_dtype = torch.half |
| | else: |
| | self.autocast_dtype = torch.float32 |
| |
|
| | if self.image_precached: |
| | latents = batch['latents'].cuda(non_blocking=True) |
| | else: |
| | img = batch['images'].cuda(non_blocking=True) |
| | latents = self.vae.encode(img).to(dtype=self.autocast_dtype) |
| | latents, noise, packed_noisy_model_input, latent_image_ids, guidance_vec, timesteps = ( |
| | self.prepare_image_latent(latents) |
| | ) |
| | if self.text_precached: |
| | prompt_embeds = batch['prompt_embeds'].cuda(non_blocking=True).transpose(0, 1) |
| | pooled_prompt_embeds = batch['pooled_prompt_embeds'].cuda(non_blocking=True) |
| | text_ids = batch['text_ids'].cuda(non_blocking=True) |
| | else: |
| | txt = batch['txt'] |
| | prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( |
| | txt, device=latents.device, dtype=latents.dtype |
| | ) |
| | with torch.cuda.amp.autocast( |
| | self.autocast_dtype in (torch.half, torch.bfloat16), |
| | dtype=self.autocast_dtype, |
| | ): |
| | noise_pred = self.forward( |
| | img=packed_noisy_model_input, |
| | txt=prompt_embeds, |
| | y=pooled_prompt_embeds, |
| | timesteps=timesteps / 1000, |
| | img_ids=latent_image_ids, |
| | txt_ids=text_ids, |
| | guidance=guidance_vec, |
| | ) |
| |
|
| | noise_pred = self._unpack_latents( |
| | noise_pred.transpose(0, 1), |
| | int(latents.shape[2] * self.vae_scale_factor // 2), |
| | int(latents.shape[3] * self.vae_scale_factor // 2), |
| | vae_scale_factor=self.vae_scale_factor, |
| | ).transpose(0, 1) |
| |
|
| | target = noise - latents |
| | loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean") |
| | return loss |
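| | 
| | # Objective sketch: prepare_image_latent mixes noise into the latents as x_sigma = (1 - sigma) * x0 + sigma * eps, |
| | # so the flow-matching regression target used here is (eps - x0) and the loss is a plain MSE against the |
| | # transformer's unpacked prediction. |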
| |
|
| | def encode_prompt(self, prompt, device='cuda', dtype=torch.float32): |
| | """Encode prompts with T5 (sequence embeddings) and CLIP (pooled embedding), and build zero text ids.""" |
| | prompt_embeds = self.t5(prompt).transpose(0, 1) |
| | _, pooled_prompt_embeds = self.clip(prompt) |
| | text_ids = torch.zeros(prompt_embeds.shape[1], prompt_embeds.shape[0], 3).to(device=device, dtype=dtype) |
| |
|
| | return prompt_embeds, pooled_prompt_embeds.to(dtype=dtype), text_ids |
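| | 
| | # With the default T5Config/ClipConfig above, the returned shapes are (assumed, batch size B): |
| | # prompt_embeds (512, B, 4096) after the transpose, pooled_prompt_embeds (B, 768), text_ids (B, 512, 3). |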
| |
|
| | def compute_density_for_timestep_sampling( |
| | self, |
| | weighting_scheme: str, |
| | batch_size: int, |
| | logit_mean: float = 0.0, |
| | logit_std: float = 1.0, |
| | mode_scale: float = 1.29, |
| | ): |
| | """ |
| | Compute the density for sampling the timesteps when doing SD3 training. |
| | |
| | Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. |
| | |
| | SD3 paper reference: https://arxiv.org/abs/2403.03206v1. |
| | """ |
| | if weighting_scheme == "logit_normal": |
| | |
| | u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") |
| | u = torch.nn.functional.sigmoid(u) |
| | elif weighting_scheme == "mode": |
| | u = torch.rand(size=(batch_size,), device="cpu") |
| | u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) |
| | else: |
| | u = torch.rand(size=(batch_size,), device="cpu") |
| | return u |
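| | 
| | # Example: with the "logit_normal" scheme used in prepare_image_latent and the defaults |
| | # (logit_mean=0.0, logit_std=1.0), u = sigmoid(n) with n ~ N(0, 1), so sampled values cluster around 0.5 |
| | # and the corresponding indices floor(u * num_train_timesteps) favor mid-schedule timesteps. |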
| |
|
| | def prepare_image_latent(self, latents): |
| | """Sample timesteps and noise, mix them into the latents via flow matching, and pack latents into patches.""" |
| | latent_image_ids = self._prepare_latent_image_ids( |
| | latents.shape[0], |
| | latents.shape[2], |
| | latents.shape[3], |
| | latents.device, |
| | latents.dtype, |
| | ) |
| |
|
| | noise = torch.randn_like(latents, device=latents.device, dtype=latents.dtype) |
| | batch_size = latents.shape[0] |
| | u = self.compute_density_for_timestep_sampling( |
| | "logit_normal", |
| | batch_size, |
| | ) |
| | indices = (u * self.scheduler.num_train_timesteps).long() |
| | timesteps = self.scheduler.timesteps[indices].to(device=latents.device) |
| |
|
| | sigmas = self.scheduler.sigmas.to(device=latents.device, dtype=latents.dtype) |
| | scheduler_timesteps = self.scheduler.timesteps.to(device=latents.device) |
| | step_indices = [(scheduler_timesteps == t).nonzero().item() for t in timesteps] |
| | timesteps = timesteps.to(dtype=latents.dtype) |
| | sigma = sigmas[step_indices].flatten() |
| |
|
| | while len(sigma.shape) < latents.ndim: |
| | sigma = sigma.unsqueeze(-1) |
| |
|
| | noisy_model_input = (1.0 - sigma) * latents + sigma * noise |
| | packed_noisy_model_input = self._pack_latents( |
| | noisy_model_input, |
| | batch_size=latents.shape[0], |
| | num_channels_latents=latents.shape[1], |
| | height=latents.shape[2], |
| | width=latents.shape[3], |
| | ) |
| |
|
| | if self.config.guidance_embed: |
| | guidance_vec = torch.full( |
| | (noisy_model_input.shape[0],), |
| | self.config.guidance_scale, |
| | device=latents.device, |
| | dtype=latents.dtype, |
| | ) |
| | else: |
| | guidance_vec = None |
| |
|
| | return ( |
| | latents.transpose(0, 1), |
| | noise.transpose(0, 1), |
| | packed_noisy_model_input.transpose(0, 1), |
| | latent_image_ids, |
| | guidance_vec, |
| | timesteps, |
| | ) |
| |
|
| | def _unpack_latents(self, latents, height, width, vae_scale_factor): |
| | |
| | batch_size, num_patches, channels = latents.shape |
| |
|
| | height = height // vae_scale_factor |
| | width = width // vae_scale_factor |
| |
|
| | latents = latents.view(batch_size, height, width, channels // 4, 2, 2) |
| | latents = latents.permute(0, 3, 1, 4, 2, 5) |
| |
|
| | latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) |
| |
|
| | return latents |
| |
|
| | def _prepare_latent_image_ids( |
| | self, batch_size: int, height: int, width: int, device: torch.device, dtype: torch.dtype |
| | ): |
| | |
| | latent_image_ids = torch.zeros(height // 2, width // 2, 3) |
| | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] |
| | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] |
| |
|
| | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape |
| |
|
| | latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) |
| | latent_image_ids = latent_image_ids.reshape( |
| | batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels |
| | ) |
| |
|
| | return latent_image_ids.to(device=device, dtype=dtype) |
| |
|
| | def _pack_latents(self, latents, batch_size, num_channels_latents, height, width): |
| | |
| | latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) |
| | latents = latents.permute(0, 2, 4, 1, 3, 5) |
| | latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) |
| |
|
| | return latents |
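| | 
| | # Shape round trip (assuming even latent height/width): _pack_latents maps (B, C, H, W) -> (B, (H/2)*(W/2), C*4) |
| | # by folding each 2x2 spatial patch into channels; _unpack_latents inverts it as called from forward_step, |
| | # e.g. (1, 16, 64, 64) -> packed (1, 1024, 64) -> unpacked (1, 16, 64, 64). |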
| |
|
| | def set_input_tensor(self, tensor): |
| | |
| | pass |
| |
|
| | @property |
| | def training_loss_reduction(self) -> MaskedTokenLossReduction: |
| | |
| | if not self._training_loss_reduction: |
| | self._training_loss_reduction = MaskedTokenLossReduction() |
| |
|
| | return self._training_loss_reduction |
| |
|
| | @property |
| | def validation_loss_reduction(self) -> MaskedTokenLossReduction: |
| | |
| | |
| | if not self._validation_loss_reduction: |
| | self._validation_loss_reduction = MaskedTokenLossReduction(validation_step=True) |
| |
|
| | return self._validation_loss_reduction |
| |
|
| |
|
| | @io.model_importer(MegatronFluxModel, "hf") |
| | class HFFluxImporter(io.ModelConnector["FluxTransformer2DModel", MegatronFluxModel]): |
| | ''' |
| | Convert a Hugging Face FluxTransformer2DModel checkpoint into NeMo's distributed-checkpoint format. |
| | ''' |
| |
|
| | |
| | def init(self) -> MegatronFluxModel: |
| | return MegatronFluxModel(self.config) |
| |
|
| | def apply(self, output_path: Path) -> Path: |
| | from diffusers import FluxTransformer2DModel |
| |
|
| | source = FluxTransformer2DModel.from_pretrained(str(self), subfolder="transformer") |
| | target = self.init() |
| | trainer = self.nemo_setup(target) |
| | self.convert_state(source, target) |
| | print(f"Converted flux transformer to Nemo, saving to {output_path}") |
| |
|
| | self.nemo_save(output_path, trainer) |
| |
|
| | print(f"Converted flux transformer saved to {output_path}") |
| |
|
| | teardown(trainer, target) |
| |
|
| | return output_path |
| |
|
| | @property |
| | def config(self) -> FluxModelParams: |
| | """Derive FluxModelParams (with a matching FluxConfig) from the source Hugging Face transformer config.""" |
| | from diffusers import FluxTransformer2DModel |
| |
|
| | source = FluxTransformer2DModel.from_pretrained(str(self), subfolder="transformer") |
| | source_config = source.config |
| | flux_config = FluxConfig( |
| | num_layers=1, |
| | num_joint_layers=source_config.num_layers, |
| | num_single_layers=source_config.num_single_layers, |
| | hidden_size=source_config.num_attention_heads * source_config.attention_head_dim, |
| | num_attention_heads=source_config.num_attention_heads, |
| | activation_func=openai_gelu, |
| | add_qkv_bias=True, |
| | in_channels=source_config.in_channels, |
| | context_dim=source_config.joint_attention_dim, |
| | model_channels=256, |
| | patch_size=source_config.patch_size, |
| | guidance_embed=source_config.guidance_embeds, |
| | vec_in_dim=source_config.pooled_projection_dim, |
| | rotary_interleaved=True, |
| | layernorm_epsilon=1e-06, |
| | hidden_dropout=0, |
| | attention_dropout=0, |
| | use_cpu_initialization=True, |
| | ) |
| |
|
| | output = FluxModelParams( |
| | flux_config=flux_config, |
| | vae_config=None, |
| | clip_params=None, |
| | t5_params=None, |
| | scheduler_steps=1000, |
| | device='cuda', |
| | ) |
| | return output |
| |
|
| | def convert_state(self, source, target): |
| | """Map Hugging Face parameter names onto the NeMo Flux module and apply the QKV/proj transforms below.""" |
| | mapping = { |
| | 'transformer_blocks.*.norm1.linear.weight': 'double_blocks.*.adaln.adaLN_modulation.1.weight', |
| | 'transformer_blocks.*.norm1.linear.bias': 'double_blocks.*.adaln.adaLN_modulation.1.bias', |
| | 'transformer_blocks.*.norm1_context.linear.weight': 'double_blocks.*.adaln_context.adaLN_modulation.1.weight', |
| | 'transformer_blocks.*.norm1_context.linear.bias': 'double_blocks.*.adaln_context.adaLN_modulation.1.bias', |
| | 'transformer_blocks.*.attn.norm_q.weight': 'double_blocks.*.self_attention.q_layernorm.weight', |
| | 'transformer_blocks.*.attn.norm_k.weight': 'double_blocks.*.self_attention.k_layernorm.weight', |
| | 'transformer_blocks.*.attn.norm_added_q.weight': 'double_blocks.*.self_attention.added_q_layernorm.weight', |
| | 'transformer_blocks.*.attn.norm_added_k.weight': 'double_blocks.*.self_attention.added_k_layernorm.weight', |
| | 'transformer_blocks.*.attn.to_out.0.weight': 'double_blocks.*.self_attention.linear_proj.weight', |
| | 'transformer_blocks.*.attn.to_out.0.bias': 'double_blocks.*.self_attention.linear_proj.bias', |
| | 'transformer_blocks.*.attn.to_add_out.weight': 'double_blocks.*.self_attention.added_linear_proj.weight', |
| | 'transformer_blocks.*.attn.to_add_out.bias': 'double_blocks.*.self_attention.added_linear_proj.bias', |
| | 'transformer_blocks.*.ff.net.0.proj.weight': 'double_blocks.*.mlp.linear_fc1.weight', |
| | 'transformer_blocks.*.ff.net.0.proj.bias': 'double_blocks.*.mlp.linear_fc1.bias', |
| | 'transformer_blocks.*.ff.net.2.weight': 'double_blocks.*.mlp.linear_fc2.weight', |
| | 'transformer_blocks.*.ff.net.2.bias': 'double_blocks.*.mlp.linear_fc2.bias', |
| | 'transformer_blocks.*.ff_context.net.0.proj.weight': 'double_blocks.*.context_mlp.linear_fc1.weight', |
| | 'transformer_blocks.*.ff_context.net.0.proj.bias': 'double_blocks.*.context_mlp.linear_fc1.bias', |
| | 'transformer_blocks.*.ff_context.net.2.weight': 'double_blocks.*.context_mlp.linear_fc2.weight', |
| | 'transformer_blocks.*.ff_context.net.2.bias': 'double_blocks.*.context_mlp.linear_fc2.bias', |
| | 'single_transformer_blocks.*.norm.linear.weight': 'single_blocks.*.adaln.adaLN_modulation.1.weight', |
| | 'single_transformer_blocks.*.norm.linear.bias': 'single_blocks.*.adaln.adaLN_modulation.1.bias', |
| | 'single_transformer_blocks.*.proj_mlp.weight': 'single_blocks.*.mlp.linear_fc1.weight', |
| | 'single_transformer_blocks.*.proj_mlp.bias': 'single_blocks.*.mlp.linear_fc1.bias', |
| | 'single_transformer_blocks.*.attn.norm_q.weight': 'single_blocks.*.self_attention.q_layernorm.weight', |
| | 'single_transformer_blocks.*.attn.norm_k.weight': 'single_blocks.*.self_attention.k_layernorm.weight', |
| | 'single_transformer_blocks.*.proj_out.bias': 'single_blocks.*.mlp.linear_fc2.bias', |
| | 'norm_out.linear.bias': 'norm_out.adaLN_modulation.1.bias', |
| | 'norm_out.linear.weight': 'norm_out.adaLN_modulation.1.weight', |
| | 'proj_out.bias': 'proj_out.bias', |
| | 'proj_out.weight': 'proj_out.weight', |
| | 'time_text_embed.guidance_embedder.linear_1.bias': 'guidance_embedding.in_layer.bias', |
| | 'time_text_embed.guidance_embedder.linear_1.weight': 'guidance_embedding.in_layer.weight', |
| | 'time_text_embed.guidance_embedder.linear_2.bias': 'guidance_embedding.out_layer.bias', |
| | 'time_text_embed.guidance_embedder.linear_2.weight': 'guidance_embedding.out_layer.weight', |
| | 'x_embedder.bias': 'img_embed.bias', |
| | 'x_embedder.weight': 'img_embed.weight', |
| | 'time_text_embed.timestep_embedder.linear_1.bias': 'timestep_embedding.time_embedder.in_layer.bias', |
| | 'time_text_embed.timestep_embedder.linear_1.weight': 'timestep_embedding.time_embedder.in_layer.weight', |
| | 'time_text_embed.timestep_embedder.linear_2.bias': 'timestep_embedding.time_embedder.out_layer.bias', |
| | 'time_text_embed.timestep_embedder.linear_2.weight': 'timestep_embedding.time_embedder.out_layer.weight', |
| | 'context_embedder.bias': 'txt_embed.bias', |
| | 'context_embedder.weight': 'txt_embed.weight', |
| | 'time_text_embed.text_embedder.linear_1.bias': 'vector_embedding.in_layer.bias', |
| | 'time_text_embed.text_embedder.linear_1.weight': 'vector_embedding.in_layer.weight', |
| | 'time_text_embed.text_embedder.linear_2.bias': 'vector_embedding.out_layer.bias', |
| | 'time_text_embed.text_embedder.linear_2.weight': 'vector_embedding.out_layer.weight', |
| | } |
| | return io.apply_transforms( |
| | source, |
| | target, |
| | mapping=mapping, |
| | transforms=[ |
| | import_double_block_qkv, |
| | import_double_block_qkv_bias, |
| | import_added_qkv, |
| | import_added_qkv_bias, |
| | import_single_block_qkv, |
| | import_single_block_qkv_bias, |
| | transform_single_proj_out, |
| | ], |
| | ) |
| |
|
| |
|
| | @io.state_transform( |
| | source_key=( |
| | "transformer_blocks.*.attn.to_q.weight", |
| | "transformer_blocks.*.attn.to_k.weight", |
| | "transformer_blocks.*.attn.to_v.weight", |
| | ), |
| | target_key=("double_blocks.*.self_attention.linear_qkv.weight"), |
| | ) |
| | def import_double_block_qkv(ctx: io.TransformCTX, q, k, v): |
| | transformer_config = ctx.target.config |
| | return _import_qkv(transformer_config, q, k, v) |
| |
|
| |
|
| | @io.state_transform( |
| | source_key=( |
| | "transformer_blocks.*.attn.to_q.bias", |
| | "transformer_blocks.*.attn.to_k.bias", |
| | "transformer_blocks.*.attn.to_v.bias", |
| | ), |
| | target_key=("double_blocks.*.self_attention.linear_qkv.bias"), |
| | ) |
| | def import_double_block_qkv_bias(ctx: io.TransformCTX, qb, kb, vb): |
| | transformer_config = ctx.target.config |
| | return _import_qkv_bias(transformer_config, qb, kb, vb) |
| |
|
| |
|
| | @io.state_transform( |
| | source_key=( |
| | "transformer_blocks.*.attn.add_q_proj.weight", |
| | "transformer_blocks.*.attn.add_k_proj.weight", |
| | "transformer_blocks.*.attn.add_v_proj.weight", |
| | ), |
| | target_key=("double_blocks.*.self_attention.added_linear_qkv.weight"), |
| | ) |
| | def import_added_qkv(ctx: io.TransformCTX, q, k, v): |
| | transformer_config = ctx.target.config |
| | return _import_qkv(transformer_config, q, k, v) |
| |
|
| |
|
| | @io.state_transform( |
| | source_key=( |
| | "transformer_blocks.*.attn.add_q_proj.bias", |
| | "transformer_blocks.*.attn.add_k_proj.bias", |
| | "transformer_blocks.*.attn.add_v_proj.bias", |
| | ), |
| | target_key=("double_blocks.*.self_attention.added_linear_qkv.bias"), |
| | ) |
| | def import_added_qkv_bias(ctx: io.TransformCTX, qb, kb, vb): |
| | transformer_config = ctx.target.config |
| | return _import_qkv_bias(transformer_config, qb, kb, vb) |
| |
|
| |
|
| | @io.state_transform( |
| | source_key=( |
| | "single_transformer_blocks.*.attn.to_q.weight", |
| | "single_transformer_blocks.*.attn.to_k.weight", |
| | "single_transformer_blocks.*.attn.to_v.weight", |
| | ), |
| | target_key=("single_blocks.*.self_attention.linear_qkv.weight"), |
| | ) |
| | def import_single_block_qkv(ctx: io.TransformCTX, q, k, v): |
| | transformer_config = ctx.target.config |
| | return _import_qkv(transformer_config, q, k, v) |
| |
|
| |
|
| | @io.state_transform( |
| | source_key=( |
| | "single_transformer_blocks.*.attn.to_q.bias", |
| | "single_transformer_blocks.*.attn.to_k.bias", |
| | "single_transformer_blocks.*.attn.to_v.bias", |
| | ), |
| | target_key=("single_blocks.*.self_attention.linear_qkv.bias"), |
| | ) |
| | def import_single_block_qkv_bias(ctx: io.TransformCTX, qb, kb, vb): |
| | transformer_config = ctx.target.config |
| | return _import_qkv_bias(transformer_config, qb, kb, vb) |
| |
|
| |
|
| | @io.state_transform( |
| | source_key=('single_transformer_blocks.*.proj_out.weight'), |
| | target_key=('single_blocks.*.mlp.linear_fc2.weight', 'single_blocks.*.self_attention.linear_proj.weight'), |
| | ) |
| | def transform_single_proj_out(proj_weight): |
| | """Split the fused single-block proj_out weight into MLP fc2 and attention projection parts.""" |
| | # HF single blocks fuse the attention output and MLP output into one proj_out; the first |
| | # hidden_size (3072) input columns belong to the attention projection, the rest to the MLP. |
| | linear_fc2 = proj_weight.detach()[:, 3072:].clone() |
| | linear_proj = proj_weight.detach()[:, :3072].clone() |
| | return linear_fc2, linear_proj |
| |
|
| |
|