# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from contextlib import nullcontext
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional
import lightning.pytorch as L
import numpy as np
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.models.common.vision_module.vision_module import VisionModule
from megatron.core.optimizer import OptimizerConfig
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import openai_gelu, sharded_state_dict_default
from safetensors.torch import load_file as load_safetensors
from safetensors.torch import save_file as save_safetensors
from torch import nn
from torch.nn import functional as F
from nemo.collections.diffusion.encoders.conditioner import FrozenCLIPEmbedder, FrozenT5Embedder
from nemo.collections.diffusion.models.dit.dit_layer_spec import (
AdaLNContinuous,
FluxSingleTransformerBlock,
MMDiTLayer,
get_flux_double_transformer_engine_spec,
get_flux_single_transformer_engine_spec,
)
from nemo.collections.diffusion.models.flux.layers import EmbedND, MLPEmbedder, TimeStepEmbedder
from nemo.collections.diffusion.sampler.flow_matching.flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from nemo.collections.diffusion.utils.flux_ckpt_converter import (
_import_qkv,
_import_qkv_bias,
flux_transformer_converter,
)
from nemo.collections.diffusion.vae.autoencoder import AutoEncoder, AutoEncoderConfig
from nemo.collections.llm import fn
from nemo.lightning import io, teardown
from nemo.lightning.megatron_parallel import MaskedTokenLossReduction
from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule
from nemo.utils import logging
# pylint: disable=C0116
def flux_data_step(dataloader_iter):
batch = next(dataloader_iter)
if isinstance(batch, tuple) and len(batch) == 3:
_batch = batch[0]
else:
_batch = batch
_batch['loss_mask'] = torch.Tensor([1.0]).cuda(non_blocking=True)
return _batch
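# A sketch of the batch contract assumed by `flux_data_step`: the iterator may
# yield either a dict or a tuple whose first element is the dict; a scalar
# `loss_mask` is attached before the batch reaches the model. The keys consumed
# later in `forward_step` are (depending on precaching):
#   {'images': ..., 'txt': [...]}                                  # raw inputs
#   {'latents': ..., 'prompt_embeds': ..., 'pooled_prompt_embeds': ..., 'text_ids': ...}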
@dataclass
class FluxConfig(TransformerConfig, io.IOMixin):
"""
    Transformer-related Flux config.
"""
num_layers: int = 1 # dummy setting
num_joint_layers: int = 19
num_single_layers: int = 38
hidden_size: int = 3072
num_attention_heads: int = 24
activation_func: Callable = openai_gelu
add_qkv_bias: bool = True
in_channels: int = 64
context_dim: int = 4096
model_channels: int = 256
patch_size: int = 1
guidance_embed: bool = False
vec_in_dim: int = 768
rotary_interleaved: bool = True
apply_rope_fusion: bool = False
layernorm_epsilon: float = 1e-06
hidden_dropout: float = 0
attention_dropout: float = 0
use_cpu_initialization: bool = True
gradient_accumulation_fusion: bool = False
enable_cuda_graph: bool = False
use_te_rng_tracker: bool = False
cuda_graph_warmup_steps: int = 2
guidance_scale: float = 3.5
data_step_fn: Callable = flux_data_step
ckpt_path: Optional[str] = None
load_dist_ckpt: bool = False
do_convert_from_hf: bool = False
    save_converted_model_to: Optional[str] = None
def configure_model(self):
model = Flux(config=self)
return model
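# A minimal usage sketch (values shown are just the dataclass defaults above):
#   config = FluxConfig(num_joint_layers=19, num_single_layers=38)
#   flux = config.configure_model()   # builds the Flux VisionModule defined below
#   # set config.ckpt_path to restore pretrained weights during construction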
@dataclass
class T5Config:
"""
T5 Config
"""
version: Optional[str] = field(default_factory=lambda: "google/t5-v1_1-xxl")
max_length: Optional[int] = field(default_factory=lambda: 512)
load_config_only: bool = False
@dataclass
class ClipConfig:
"""
Clip Config
"""
version: Optional[str] = field(default_factory=lambda: "openai/clip-vit-large-patch14")
max_length: Optional[int] = field(default_factory=lambda: 77)
always_return_pooled: Optional[bool] = field(default_factory=lambda: True)
@dataclass
class FluxModelParams:
"""
Flux Model Params
"""
flux_config: FluxConfig = field(default_factory=FluxConfig)
vae_config: AutoEncoderConfig = field(
default_factory=lambda: AutoEncoderConfig(ch_mult=[1, 2, 4, 4], attn_resolutions=[])
)
clip_params: ClipConfig = field(default_factory=ClipConfig)
t5_params: T5Config = field(default_factory=T5Config)
scheduler_steps: int = 1000
device: str = 'cuda'
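# A minimal sketch of assembling the parameter bundle; setting an encoder/VAE
# config to None signals that the corresponding inputs are precached:
#   params = FluxModelParams()          # default flux/vae/clip/t5 configs
#   params.t5_params = None             # e.g. text embeddings precached
#   model = MegatronFluxModel(params)   # Lightning wrapper defined below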
# pylint: disable=C0116
class Flux(VisionModule):
"""
    NeMo implementation of the Flux model, with the joint (double) and single flux transformer blocks
    implemented with Megatron Core.
Args:
config (FluxConfig): Configuration object containing the necessary parameters for setting up the model,
such as the number of channels, hidden size, attention heads, and more.
Attributes:
out_channels (int): The number of output channels for the model.
hidden_size (int): The size of the hidden layers.
num_attention_heads (int): The number of attention heads for the transformer.
patch_size (int): The size of the image patches.
in_channels (int): The number of input channels for the image.
guidance_embed (bool): A flag to indicate if guidance embedding should be used.
pos_embed (EmbedND): Position embedding layer for the model.
img_embed (nn.Linear): Linear layer to embed image input into the hidden space.
txt_embed (nn.Linear): Linear layer to embed text input into the hidden space.
timestep_embedding (TimeStepEmbedder): Embedding layer for timesteps, used in generative models.
vector_embedding (MLPEmbedder): MLP embedding for vector inputs.
guidance_embedding (nn.Module or nn.Identity): Optional MLP embedding for guidance, or identity if not used.
double_blocks (nn.ModuleList): A list of transformer blocks for the double block layers.
single_blocks (nn.ModuleList): A list of transformer blocks for the single block layers.
norm_out (AdaLNContinuous): Normalization layer for the output.
proj_out (nn.Linear): Final linear layer for output projection.
Methods:
forward: Performs a forward pass through the network, processing images, text, timesteps, and guidance.
        load_from_pretrained: Loads model weights from a pretrained checkpoint, with optional support for
            distributed checkpoints and conversion from Hugging Face format.
"""
def __init__(self, config: FluxConfig):
# pylint: disable=C0116
super().__init__(config)
self.out_channels = config.in_channels
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.patch_size = config.patch_size
self.in_channels = config.in_channels
self.guidance_embed = config.guidance_embed
self.pos_embed = EmbedND(dim=self.hidden_size, theta=10000, axes_dim=[16, 56, 56])
self.img_embed = nn.Linear(config.in_channels, self.hidden_size)
self.txt_embed = nn.Linear(config.context_dim, self.hidden_size)
self.timestep_embedding = TimeStepEmbedder(config.model_channels, self.hidden_size)
self.vector_embedding = MLPEmbedder(in_dim=config.vec_in_dim, hidden_dim=self.hidden_size)
        self.guidance_embedding = (
            MLPEmbedder(in_dim=config.model_channels, hidden_dim=self.hidden_size)
            if config.guidance_embed
            else nn.Identity()
        )
self.double_blocks = nn.ModuleList(
[
MMDiTLayer(
config=config,
submodules=get_flux_double_transformer_engine_spec().submodules,
layer_number=i,
context_pre_only=False,
)
for i in range(config.num_joint_layers)
]
)
self.single_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
config=config,
submodules=get_flux_single_transformer_engine_spec().submodules,
layer_number=i,
)
for i in range(config.num_single_layers)
]
)
self.norm_out = AdaLNContinuous(config=config, conditioning_embedding_dim=self.hidden_size)
self.proj_out = nn.Linear(self.hidden_size, self.patch_size * self.patch_size * self.out_channels, bias=True)
if self.config.ckpt_path is not None:
self.load_from_pretrained(
self.config.ckpt_path,
do_convert_from_hf=self.config.do_convert_from_hf,
load_dist_ckpt=self.config.load_dist_ckpt,
save_converted_model_to=self.config.save_converted_model_to,
)
def get_fp8_context(self):
        # Create an FP8 autocast context when FP8 is enabled; otherwise a nullcontext.
if not self.config.fp8:
fp8_context = nullcontext()
else:
            import transformer_engine  # imported here to avoid a hard TE dependency when not training in fp8
if self.config.fp8 == "e4m3":
fp8_format = transformer_engine.common.recipe.Format.E4M3
elif self.config.fp8 == "hybrid":
fp8_format = transformer_engine.common.recipe.Format.HYBRID
else:
raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
margin=self.config.fp8_margin,
interval=self.config.fp8_interval,
fp8_format=fp8_format,
amax_compute_algo=self.config.fp8_amax_compute_algo,
amax_history_len=self.config.fp8_amax_history_len,
override_linear_precision=(False, False, not self.config.fp8_wgrad),
)
fp8_group = None
if parallel_state.model_parallel_is_initialized():
fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True)
fp8_context = transformer_engine.pytorch.fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
)
return fp8_context
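    # Usage sketch (see `forward` below): each transformer block call is wrapped as
    #   with self.get_fp8_context():
    #       ... = block(...)
    # so FP8 autocast covers only block compute, not the embeddings or output head.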
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor = None,
y: torch.Tensor = None,
timesteps: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
controlnet_double_block_samples: torch.Tensor = None,
controlnet_single_block_samples: torch.Tensor = None,
):
"""
Forward pass through the model, processing image, text, and additional inputs like guidance and timesteps.
Args:
img (torch.Tensor):
The image input tensor.
txt (torch.Tensor, optional):
The text input tensor (default is None).
y (torch.Tensor, optional):
The vector input for embedding (default is None).
timesteps (torch.LongTensor, optional):
The timestep input, typically used in generative models (default is None).
img_ids (torch.Tensor, optional):
Image IDs for positional encoding (default is None).
txt_ids (torch.Tensor, optional):
Text IDs for positional encoding (default is None).
guidance (torch.Tensor, optional):
Guidance input for conditioning (default is None).
controlnet_double_block_samples (torch.Tensor, optional):
Optional controlnet samples for double blocks (default is None).
controlnet_single_block_samples (torch.Tensor, optional):
Optional controlnet samples for single blocks (default is None).
Returns:
torch.Tensor: The final output tensor from the model after processing all inputs.
"""
hidden_states = self.img_embed(img)
encoder_hidden_states = self.txt_embed(txt)
timesteps = timesteps.to(img.dtype) * 1000
vec_emb = self.timestep_embedding(timesteps)
if guidance is not None:
vec_emb = vec_emb + self.guidance_embedding(self.timestep_embedding.time_proj(guidance * 1000))
vec_emb = vec_emb + self.vector_embedding(y)
ids = torch.cat((txt_ids, img_ids), dim=1)
rotary_pos_emb = self.pos_embed(ids)
for id_block, block in enumerate(self.double_blocks):
with self.get_fp8_context():
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
rotary_pos_emb=rotary_pos_emb,
emb=vec_emb,
)
if controlnet_double_block_samples is not None:
interval_control = len(self.double_blocks) / len(controlnet_double_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_double_block_samples[id_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0)
for id_block, block in enumerate(self.single_blocks):
with self.get_fp8_context():
hidden_states, _ = block(
hidden_states=hidden_states,
rotary_pos_emb=rotary_pos_emb,
emb=vec_emb,
)
if controlnet_single_block_samples is not None:
interval_control = len(self.single_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states = torch.cat(
[
hidden_states[: encoder_hidden_states.shape[0]],
hidden_states[encoder_hidden_states.shape[0] :]
+ controlnet_single_block_samples[id_block // interval_control],
]
)
hidden_states = hidden_states[encoder_hidden_states.shape[0] :, ...]
hidden_states = self.norm_out(hidden_states, vec_emb)
output = self.proj_out(hidden_states)
return output
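    # Shape sketch for `forward`, inferred from `forward_step` below (sequence-first
    # layout; treat the exact dims as an assumption):
    #   img: (img_seq, batch, in_channels), txt: (txt_seq, batch, context_dim),
    #   y: (batch, vec_in_dim), img_ids/txt_ids: (batch, seq, 3),
    #   output: (img_seq, batch, patch_size**2 * out_channels).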
def load_from_pretrained(
self, ckpt_path, do_convert_from_hf=False, save_converted_model_to=None, load_dist_ckpt=False
):
# pylint: disable=C0116
if load_dist_ckpt:
from megatron.core import dist_checkpointing
sharded_state_dict = dict(state_dict=self.sharded_state_dict(prefix="module."))
loaded_state_dict = dist_checkpointing.load(
sharded_state_dict=sharded_state_dict, checkpoint_dir=ckpt_path
)
ckpt = {k.removeprefix("module."): v for k, v in loaded_state_dict["state_dict"].items()}
else:
if do_convert_from_hf:
ckpt = flux_transformer_converter(ckpt_path, self.config)
if save_converted_model_to is not None:
os.makedirs(save_converted_model_to, exist_ok=True)
save_path = os.path.join(save_converted_model_to, 'nemo_flux_transformer.safetensors')
save_safetensors(ckpt, save_path)
logging.info(f'saving converted transformer checkpoint to {save_path}')
else:
ckpt = load_safetensors(ckpt_path)
missing, unexpected = self.load_state_dict(ckpt, strict=False)
missing = [k for k in missing if not k.endswith('_extra_state')]
        # Keys ending with '_extra_state' are mcore-specific and do not affect model quality.
        if len(missing) > 0:
            logging.info(
                f"The following keys were missing during checkpoint loading; "
                f"please check the provided ckpt, otherwise image quality may be compromised.\n {missing}"
            )
        logging.info(f"Found unexpected keys: \n {unexpected}")
logging.info(f"Restored flux model weights from {ckpt_path}")
def sharded_state_dict(self, prefix='', sharded_offsets: tuple = (), metadata: dict = None) -> ShardedStateDict:
sharded_state_dict = {}
layer_prefix = f'{prefix}double_blocks.'
for layer in self.double_blocks:
offset = layer._get_layer_offset(self.config)
global_layer_offset = layer.layer_number
state_dict_prefix = f'{layer_prefix}{global_layer_offset - offset}.'
sharded_prefix = f'{layer_prefix}{global_layer_offset}.'
sharded_pp_offset = []
layer_sharded_state_dict = layer.sharded_state_dict(state_dict_prefix, sharded_pp_offset, metadata)
replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix)
sharded_state_dict.update(layer_sharded_state_dict)
layer_prefix = f'{prefix}single_blocks.'
for layer in self.single_blocks:
offset = layer._get_layer_offset(self.config)
global_layer_offset = layer.layer_number
state_dict_prefix = f'{layer_prefix}{global_layer_offset - offset}.'
sharded_prefix = f'{layer_prefix}{global_layer_offset}.'
sharded_pp_offset = []
layer_sharded_state_dict = layer.sharded_state_dict(state_dict_prefix, sharded_pp_offset, metadata)
replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix)
sharded_state_dict.update(layer_sharded_state_dict)
for name, module in self.named_children():
if not (module is self.single_blocks or module is self.double_blocks):
sharded_state_dict.update(
sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata)
)
return sharded_state_dict
class MegatronFluxModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin):
    '''
    Megatron Lightning wrapper for the Flux model.
    Args:
        flux_params (FluxModelParams): Parameters to configure the Flux model.
        optim (OptimizerModule, optional): Optimizer module to use; defaults to a MegatronOptimizerModule.
    '''
def __init__(
self,
flux_params: FluxModelParams,
optim: Optional[OptimizerModule] = None,
):
# pylint: disable=C0116
self.params = flux_params
self.config = flux_params.flux_config
super().__init__()
self._training_loss_reduction = None
self._validation_loss_reduction = None
self.vae_config = self.params.vae_config
self.clip_params = self.params.clip_params
self.t5_params = self.params.t5_params
self.optim = optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=False))
self.optim.connect(self)
self.model_type = ModelType.encoder_or_decoder
self.text_precached = self.t5_params is None or self.clip_params is None
self.image_precached = self.vae_config is None
def configure_model(self):
# pylint: disable=C0116
if not hasattr(self, "module"):
self.module = self.config.configure_model()
self.configure_vae(self.vae_config)
self.configure_scheduler()
self.configure_text_encoders(self.clip_params, self.t5_params)
for name, param in self.module.named_parameters():
if self.config.num_single_layers == 0:
if 'context' in name or 'added' in name:
param.requires_grad = False
            # In single blocks the attention projection bias duplicates the mlp bias
            # once the concat is removed, so it is skipped and not included in the
            # computation graph.
if 'single_blocks' in name and 'self_attention.linear_proj.bias' in name:
param.requires_grad = False
def configure_scheduler(self):
# pylint: disable=C0116
self.scheduler = FlowMatchEulerDiscreteScheduler(
num_train_timesteps=self.params.scheduler_steps,
)
def configure_vae(self, vae):
# pylint: disable=C0116
if isinstance(vae, nn.Module):
self.vae = vae.eval().cuda()
self.vae_scale_factor = 2 ** (len(self.vae.params.ch_mult))
for param in self.vae.parameters():
param.requires_grad = False
elif isinstance(vae, AutoEncoderConfig):
self.vae = AutoEncoder(vae).eval().cuda()
self.vae_scale_factor = 2 ** (len(vae.ch_mult))
for param in self.vae.parameters():
param.requires_grad = False
else:
logging.info("Vae not provided, assuming the image input is precached...")
self.vae = None
self.vae_scale_factor = 16
def configure_text_encoders(self, clip, t5):
# pylint: disable=C0116
if isinstance(clip, nn.Module):
self.clip = clip
elif isinstance(clip, ClipConfig):
self.clip = FrozenCLIPEmbedder(
version=self.clip_params.version,
max_length=self.clip_params.max_length,
always_return_pooled=self.clip_params.always_return_pooled,
device=torch.cuda.current_device(),
)
else:
logging.info("CLIP encoder not provided, assuming the text embeddings is precached...")
self.clip = None
if isinstance(t5, nn.Module):
self.t5 = t5
elif isinstance(t5, T5Config):
self.t5 = FrozenT5Embedder(
self.t5_params.version,
max_length=self.t5_params.max_length,
device=torch.cuda.current_device(),
load_config_only=self.t5_params.load_config_only,
)
else:
logging.info("T5 encoder not provided, assuming the text embeddings is precached...")
self.t5 = None
# pylint: disable=C0116
def data_step(self, dataloader_iter):
return self.config.data_step_fn(dataloader_iter)
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
def training_step(self, batch, batch_idx=None) -> torch.Tensor:
# In mcore the loss-function is part of the forward-pass (when labels are provided)
return self.forward_step(batch)
def validation_step(self, batch, batch_idx=None) -> torch.Tensor:
# In mcore the loss-function is part of the forward-pass (when labels are provided)
return self.forward_step(batch)
def forward_step(self, batch) -> torch.Tensor:
# pylint: disable=C0116
        if self.optim.config.bf16:
            self.autocast_dtype = torch.bfloat16
        elif self.optim.config.fp16:
            self.autocast_dtype = torch.float16
        else:
            self.autocast_dtype = torch.float32
if self.image_precached:
latents = batch['latents'].cuda(non_blocking=True)
else:
img = batch['images'].cuda(non_blocking=True)
latents = self.vae.encode(img).to(dtype=self.autocast_dtype)
latents, noise, packed_noisy_model_input, latent_image_ids, guidance_vec, timesteps = (
self.prepare_image_latent(latents)
)
if self.text_precached:
prompt_embeds = batch['prompt_embeds'].cuda(non_blocking=True).transpose(0, 1)
pooled_prompt_embeds = batch['pooled_prompt_embeds'].cuda(non_blocking=True)
text_ids = batch['text_ids'].cuda(non_blocking=True)
else:
txt = batch['txt']
prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
txt, device=latents.device, dtype=latents.dtype
)
with torch.cuda.amp.autocast(
self.autocast_dtype in (torch.half, torch.bfloat16),
dtype=self.autocast_dtype,
):
noise_pred = self.forward(
img=packed_noisy_model_input,
txt=prompt_embeds,
y=pooled_prompt_embeds,
timesteps=timesteps / 1000,
img_ids=latent_image_ids,
txt_ids=text_ids,
guidance=guidance_vec,
)
noise_pred = self._unpack_latents(
noise_pred.transpose(0, 1),
int(latents.shape[2] * self.vae_scale_factor // 2),
int(latents.shape[3] * self.vae_scale_factor // 2),
vae_scale_factor=self.vae_scale_factor,
).transpose(0, 1)
target = noise - latents
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
return loss
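    # Objective sketch: prepare_image_latent forms x_t = (1 - sigma) * x0 + sigma * noise,
    # so the flow-matching velocity target d x_t / d sigma = noise - x0 is exactly
    # the `target = noise - latents` regressed above with an MSE loss.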
def encode_prompt(self, prompt, device='cuda', dtype=torch.float32):
# pylint: disable=C0116
prompt_embeds = self.t5(prompt).transpose(0, 1)
_, pooled_prompt_embeds = self.clip(prompt)
text_ids = torch.zeros(prompt_embeds.shape[1], prompt_embeds.shape[0], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds.to(dtype=dtype), text_ids
def compute_density_for_timestep_sampling(
self,
weighting_scheme: str,
batch_size: int,
logit_mean: float = 0.0,
logit_std: float = 1.0,
mode_scale: float = 1.29,
):
"""
Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
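    # Worked sketch of the "logit_normal" branch used by prepare_image_latent below:
    #   u = torch.sigmoid(torch.normal(0.0, 1.0, size=(batch_size,)))  # u in (0, 1)
    #   indices = (u * num_train_timesteps).long()  # concentrates sampling at mid timesteps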
def prepare_image_latent(self, latents):
# pylint: disable=C0116
latent_image_ids = self._prepare_latent_image_ids(
latents.shape[0],
latents.shape[2],
latents.shape[3],
latents.device,
latents.dtype,
)
noise = torch.randn_like(latents, device=latents.device, dtype=latents.dtype)
batch_size = latents.shape[0]
u = self.compute_density_for_timestep_sampling(
"logit_normal",
batch_size,
)
indices = (u * self.scheduler.num_train_timesteps).long()
timesteps = self.scheduler.timesteps[indices].to(device=latents.device)
sigmas = self.scheduler.sigmas.to(device=latents.device, dtype=latents.dtype)
        scheduler_timesteps = self.scheduler.timesteps.to(device=latents.device)
        step_indices = [(scheduler_timesteps == t).nonzero().item() for t in timesteps]
timesteps = timesteps.to(dtype=latents.dtype)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < latents.ndim:
sigma = sigma.unsqueeze(-1)
noisy_model_input = (1.0 - sigma) * latents + sigma * noise
packed_noisy_model_input = self._pack_latents(
noisy_model_input,
batch_size=latents.shape[0],
num_channels_latents=latents.shape[1],
height=latents.shape[2],
width=latents.shape[3],
)
if self.config.guidance_embed:
guidance_vec = torch.full(
(noisy_model_input.shape[0],),
self.config.guidance_scale,
device=latents.device,
dtype=latents.dtype,
)
else:
guidance_vec = None
return (
latents.transpose(0, 1),
noise.transpose(0, 1),
packed_noisy_model_input.transpose(0, 1),
latent_image_ids,
guidance_vec,
timesteps,
)
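    # Note: the trailing .transpose(0, 1) calls above convert batch-first (B, S, ...)
    # tensors to the sequence-first (S, B, ...) layout expected by Megatron Core attention.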
def _unpack_latents(self, latents, height, width, vae_scale_factor):
# pylint: disable=C0116
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
return latents
def _prepare_latent_image_ids(
self, batch_size: int, height: int, width: int, device: torch.device, dtype: torch.dtype
):
# pylint: disable=C0116
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
def _pack_latents(self, latents, batch_size, num_channels_latents, height, width):
# pylint: disable=C0116
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
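    # Round-trip sketch (illustration only): for latents of shape (B, C, H, W) with
    # s = self.vae_scale_factor,
    #   packed = self._pack_latents(latents, B, C, H, W)                    # (B, (H//2)*(W//2), 4*C)
    #   restored = self._unpack_latents(packed, H * s // 2, W * s // 2, s)  # back to (B, C, H, W)
    # matching the pack/unpack pair used in `forward_step`.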
def set_input_tensor(self, tensor):
# pylint: disable=C0116
pass
@property
def training_loss_reduction(self) -> MaskedTokenLossReduction:
# pylint: disable=C0116
if not self._training_loss_reduction:
self._training_loss_reduction = MaskedTokenLossReduction()
return self._training_loss_reduction
@property
def validation_loss_reduction(self) -> MaskedTokenLossReduction:
        # pylint: disable=C0116
if not self._validation_loss_reduction:
self._validation_loss_reduction = MaskedTokenLossReduction(validation_step=True)
return self._validation_loss_reduction
@io.model_importer(MegatronFluxModel, "hf")
class HFFluxImporter(io.ModelConnector["FluxTransformer2DModel", MegatronFluxModel]):
'''
Convert a HF ckpt into NeMo dist-ckpt compatible format.
'''
# pylint: disable=C0116
def init(self) -> MegatronFluxModel:
return MegatronFluxModel(self.config)
def apply(self, output_path: Path) -> Path:
from diffusers import FluxTransformer2DModel
source = FluxTransformer2DModel.from_pretrained(str(self), subfolder="transformer")
target = self.init()
trainer = self.nemo_setup(target)
self.convert_state(source, target)
print(f"Converted flux transformer to Nemo, saving to {output_path}")
self.nemo_save(output_path, trainer)
print(f"Converted flux transformer saved to {output_path}")
teardown(trainer, target)
return output_path
@property
def config(self) -> FluxConfig:
from diffusers import FluxTransformer2DModel
source = FluxTransformer2DModel.from_pretrained(str(self), subfolder="transformer")
source_config = source.config
flux_config = FluxConfig(
num_layers=1, # dummy setting
num_joint_layers=source_config.num_layers,
num_single_layers=source_config.num_single_layers,
hidden_size=source_config.num_attention_heads * source_config.attention_head_dim,
num_attention_heads=source_config.num_attention_heads,
activation_func=openai_gelu,
add_qkv_bias=True,
in_channels=source_config.in_channels,
context_dim=source_config.joint_attention_dim,
model_channels=256,
patch_size=source_config.patch_size,
guidance_embed=source_config.guidance_embeds,
vec_in_dim=source_config.pooled_projection_dim,
rotary_interleaved=True,
layernorm_epsilon=1e-06,
hidden_dropout=0,
attention_dropout=0,
use_cpu_initialization=True,
)
output = FluxModelParams(
flux_config=flux_config,
vae_config=None,
clip_params=None,
t5_params=None,
scheduler_steps=1000,
device='cuda',
)
return output
def convert_state(self, source, target):
# pylint: disable=C0301
mapping = {
'transformer_blocks.*.norm1.linear.weight': 'double_blocks.*.adaln.adaLN_modulation.1.weight',
'transformer_blocks.*.norm1.linear.bias': 'double_blocks.*.adaln.adaLN_modulation.1.bias',
'transformer_blocks.*.norm1_context.linear.weight': 'double_blocks.*.adaln_context.adaLN_modulation.1.weight',
'transformer_blocks.*.norm1_context.linear.bias': 'double_blocks.*.adaln_context.adaLN_modulation.1.bias',
'transformer_blocks.*.attn.norm_q.weight': 'double_blocks.*.self_attention.q_layernorm.weight',
'transformer_blocks.*.attn.norm_k.weight': 'double_blocks.*.self_attention.k_layernorm.weight',
'transformer_blocks.*.attn.norm_added_q.weight': 'double_blocks.*.self_attention.added_q_layernorm.weight',
'transformer_blocks.*.attn.norm_added_k.weight': 'double_blocks.*.self_attention.added_k_layernorm.weight',
'transformer_blocks.*.attn.to_out.0.weight': 'double_blocks.*.self_attention.linear_proj.weight',
'transformer_blocks.*.attn.to_out.0.bias': 'double_blocks.*.self_attention.linear_proj.bias',
'transformer_blocks.*.attn.to_add_out.weight': 'double_blocks.*.self_attention.added_linear_proj.weight',
'transformer_blocks.*.attn.to_add_out.bias': 'double_blocks.*.self_attention.added_linear_proj.bias',
'transformer_blocks.*.ff.net.0.proj.weight': 'double_blocks.*.mlp.linear_fc1.weight',
'transformer_blocks.*.ff.net.0.proj.bias': 'double_blocks.*.mlp.linear_fc1.bias',
'transformer_blocks.*.ff.net.2.weight': 'double_blocks.*.mlp.linear_fc2.weight',
'transformer_blocks.*.ff.net.2.bias': 'double_blocks.*.mlp.linear_fc2.bias',
'transformer_blocks.*.ff_context.net.0.proj.weight': 'double_blocks.*.context_mlp.linear_fc1.weight',
'transformer_blocks.*.ff_context.net.0.proj.bias': 'double_blocks.*.context_mlp.linear_fc1.bias',
'transformer_blocks.*.ff_context.net.2.weight': 'double_blocks.*.context_mlp.linear_fc2.weight',
'transformer_blocks.*.ff_context.net.2.bias': 'double_blocks.*.context_mlp.linear_fc2.bias',
'single_transformer_blocks.*.norm.linear.weight': 'single_blocks.*.adaln.adaLN_modulation.1.weight',
'single_transformer_blocks.*.norm.linear.bias': 'single_blocks.*.adaln.adaLN_modulation.1.bias',
'single_transformer_blocks.*.proj_mlp.weight': 'single_blocks.*.mlp.linear_fc1.weight',
'single_transformer_blocks.*.proj_mlp.bias': 'single_blocks.*.mlp.linear_fc1.bias',
'single_transformer_blocks.*.attn.norm_q.weight': 'single_blocks.*.self_attention.q_layernorm.weight',
'single_transformer_blocks.*.attn.norm_k.weight': 'single_blocks.*.self_attention.k_layernorm.weight',
'single_transformer_blocks.*.proj_out.bias': 'single_blocks.*.mlp.linear_fc2.bias',
'norm_out.linear.bias': 'norm_out.adaLN_modulation.1.bias',
'norm_out.linear.weight': 'norm_out.adaLN_modulation.1.weight',
'proj_out.bias': 'proj_out.bias',
'proj_out.weight': 'proj_out.weight',
'time_text_embed.guidance_embedder.linear_1.bias': 'guidance_embedding.in_layer.bias',
'time_text_embed.guidance_embedder.linear_1.weight': 'guidance_embedding.in_layer.weight',
'time_text_embed.guidance_embedder.linear_2.bias': 'guidance_embedding.out_layer.bias',
'time_text_embed.guidance_embedder.linear_2.weight': 'guidance_embedding.out_layer.weight',
'x_embedder.bias': 'img_embed.bias',
'x_embedder.weight': 'img_embed.weight',
'time_text_embed.timestep_embedder.linear_1.bias': 'timestep_embedding.time_embedder.in_layer.bias',
'time_text_embed.timestep_embedder.linear_1.weight': 'timestep_embedding.time_embedder.in_layer.weight',
'time_text_embed.timestep_embedder.linear_2.bias': 'timestep_embedding.time_embedder.out_layer.bias',
'time_text_embed.timestep_embedder.linear_2.weight': 'timestep_embedding.time_embedder.out_layer.weight',
'context_embedder.bias': 'txt_embed.bias',
'context_embedder.weight': 'txt_embed.weight',
'time_text_embed.text_embedder.linear_1.bias': 'vector_embedding.in_layer.bias',
'time_text_embed.text_embedder.linear_1.weight': 'vector_embedding.in_layer.weight',
'time_text_embed.text_embedder.linear_2.bias': 'vector_embedding.out_layer.bias',
'time_text_embed.text_embedder.linear_2.weight': 'vector_embedding.out_layer.weight',
}
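        # The '*' wildcards pair HF diffusers layer indices (left) with NeMo module
        # indices (right) one-to-one; fused qkv tensors are not listed here because
        # they are rebuilt by the state transforms registered below.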
return io.apply_transforms(
source,
target,
mapping=mapping,
transforms=[
import_double_block_qkv,
import_double_block_qkv_bias,
import_added_qkv,
import_added_qkv_bias,
import_single_block_qkv,
import_single_block_qkv_bias,
transform_single_proj_out,
],
)
@io.state_transform(
source_key=(
"transformer_blocks.*.attn.to_q.weight",
"transformer_blocks.*.attn.to_k.weight",
"transformer_blocks.*.attn.to_v.weight",
),
target_key=("double_blocks.*.self_attention.linear_qkv.weight"),
)
def import_double_block_qkv(ctx: io.TransformCTX, q, k, v):
transformer_config = ctx.target.config
return _import_qkv(transformer_config, q, k, v)
@io.state_transform(
source_key=(
"transformer_blocks.*.attn.to_q.bias",
"transformer_blocks.*.attn.to_k.bias",
"transformer_blocks.*.attn.to_v.bias",
),
target_key=("double_blocks.*.self_attention.linear_qkv.bias"),
)
def import_double_block_qkv_bias(ctx: io.TransformCTX, qb, kb, vb):
transformer_config = ctx.target.config
return _import_qkv_bias(transformer_config, qb, kb, vb)
@io.state_transform(
source_key=(
"transformer_blocks.*.attn.add_q_proj.weight",
"transformer_blocks.*.attn.add_k_proj.weight",
"transformer_blocks.*.attn.add_v_proj.weight",
),
target_key=("double_blocks.*.self_attention.added_linear_qkv.weight"),
)
def import_added_qkv(ctx: io.TransformCTX, q, k, v):
transformer_config = ctx.target.config
return _import_qkv(transformer_config, q, k, v)
@io.state_transform(
source_key=(
"transformer_blocks.*.attn.add_q_proj.bias",
"transformer_blocks.*.attn.add_k_proj.bias",
"transformer_blocks.*.attn.add_v_proj.bias",
),
target_key=("double_blocks.*.self_attention.added_linear_qkv.bias"),
)
def import_added_qkv_bias(ctx: io.TransformCTX, qb, kb, vb):
transformer_config = ctx.target.config
return _import_qkv_bias(transformer_config, qb, kb, vb)
@io.state_transform(
source_key=(
"single_transformer_blocks.*.attn.to_q.weight",
"single_transformer_blocks.*.attn.to_k.weight",
"single_transformer_blocks.*.attn.to_v.weight",
),
target_key=("single_blocks.*.self_attention.linear_qkv.weight"),
)
def import_single_block_qkv(ctx: io.TransformCTX, q, k, v):
transformer_config = ctx.target.config
return _import_qkv(transformer_config, q, k, v)
@io.state_transform(
source_key=(
"single_transformer_blocks.*.attn.to_q.bias",
"single_transformer_blocks.*.attn.to_k.bias",
"single_transformer_blocks.*.attn.to_v.bias",
),
target_key=("single_blocks.*.self_attention.linear_qkv.bias"),
)
def import_single_block_qkv_bias(ctx: io.TransformCTX, qb, kb, vb):
transformer_config = ctx.target.config
return _import_qkv_bias(transformer_config, qb, kb, vb)
@io.state_transform(
source_key=('single_transformer_blocks.*.proj_out.weight'),
target_key=('single_blocks.*.mlp.linear_fc2.weight', 'single_blocks.*.self_attention.linear_proj.weight'),
)
def transform_single_proj_out(proj_weight):
    # HF single-block proj_out fuses the attention output projection with the mlp fc2;
    # the first hidden_size (3072) input columns are the attention part, the rest the mlp part.
    linear_fc2 = proj_weight.detach()[:, 3072:].clone()
    linear_proj = proj_weight.detach()[:, :3072].clone()
    return linear_fc2, linear_proj
# pylint: disable=C0116