| """ |
| Training Utilities for Qwen-SDXL |
| Qwen-SDXL 训练工具 - 包含损失函数、训练步骤等 |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Dict, Any, Optional, Tuple |
| import math |
|
|
|
|
class DiffusionLoss(nn.Module):
    """
    Diffusion training loss for SDXL with Qwen embeddings.

    Supports MSE / L1 / Huber per-element losses, optional element masking,
    and Min-SNR-gamma timestep weighting (Hang et al., "Efficient Diffusion
    Training via Min-SNR Weighting Strategy").
    """

    def __init__(
        self,
        noise_scheduler,
        loss_type: str = "mse",
        snr_gamma: Optional[float] = None,
        use_v_parameterization: bool = False
    ):
        """
        Args:
            noise_scheduler: Scheduler exposing ``alphas_cumprod`` as a 1-D
                tensor indexed by timestep.
            loss_type: One of ``"mse"``, ``"l1"``, ``"huber"``.
            snr_gamma: If set, apply SNR-based per-sample loss weighting
                (see :meth:`forward`).
            use_v_parameterization: Whether the model predicts v instead of
                noise; selects the Min-SNR weight denominator.

        Raises:
            ValueError: If ``loss_type`` is not one of the supported values.
        """
        super().__init__()
        self.noise_scheduler = noise_scheduler
        self.loss_type = loss_type
        self.snr_gamma = snr_gamma
        self.use_v_parameterization = use_v_parameterization

        # reduction="none" so per-sample weights can be applied before the
        # final mean.
        if loss_type == "mse":
            self.loss_fn = nn.MSELoss(reduction="none")
        elif loss_type == "l1":
            self.loss_fn = nn.L1Loss(reduction="none")
        elif loss_type == "huber":
            self.loss_fn = nn.HuberLoss(reduction="none", delta=0.1)
        else:
            raise ValueError(f"Unsupported loss type: {loss_type}")

    def compute_snr(self, timesteps):
        """
        Compute the signal-to-noise ratio for the given timesteps.

        SNR(t) = alpha_bar_t / (1 - alpha_bar_t); returned with the same
        shape and device as ``timesteps``.
        """
        alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps]
        sqrt_alphas_cumprod = alphas_cumprod**0.5
        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

        snr = (sqrt_alphas_cumprod / sqrt_one_minus_alphas_cumprod) ** 2
        return snr

    def forward(
        self,
        model_pred: torch.Tensor,
        target: torch.Tensor,
        timesteps: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute the (optionally SNR-weighted) diffusion loss.

        Args:
            model_pred: Model prediction (noise or v-parameterization).
            target: Target (noise or v-parameterization), same shape.
            timesteps: Per-sample diffusion timesteps, shape ``(batch,)``.
            mask: Optional mask broadcastable to the loss for selective
                loss computation.

        Returns:
            Scalar loss tensor.
        """
        loss = self.loss_fn(model_pred, target)

        if mask is not None:
            # NOTE(review): masked-out elements contribute zeros to the mean
            # below, so the loss scales with mask density rather than being
            # averaged over unmasked elements only — confirm this is intended.
            loss = loss * mask

        # Reduce to one scalar per sample so per-timestep weights can apply.
        loss = loss.mean(dim=list(range(1, len(loss.shape))))

        if self.snr_gamma is not None:
            snr = self.compute_snr(timesteps)

            if self.snr_gamma >= 1.0:
                # Min-SNR-gamma weighting: clamp SNR at gamma, then divide by
                # the parameterization's natural weight (SNR for epsilon
                # prediction, SNR + 1 for v-prediction), matching the Min-SNR
                # paper and the diffusers reference training scripts.
                clamped_snr = torch.minimum(snr, torch.full_like(snr, self.snr_gamma))
                if self.use_v_parameterization:
                    snr_weight = clamped_snr / (snr + 1.0)
                else:
                    snr_weight = clamped_snr / snr
            else:
                # Soft power-law weighting for gamma < 1 (kept as-is).
                snr_weight = snr ** self.snr_gamma

            loss = loss * snr_weight

        return loss.mean()
|
|
|
|
class AdapterTrainingStep:
    """
    Training step for adapter-only training: the UNet, VAE and text encoder
    stay frozen while only the adapter receives gradients.
    """

    def __init__(
        self,
        unet,
        vae,
        text_encoder,
        adapter,
        noise_scheduler,
        loss_fn: "DiffusionLoss",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16
    ):
        """
        Args:
            unet: SDXL UNet (frozen; only used for the forward pass).
            vae: VAE used to encode images to latents (frozen).
            text_encoder: Object exposing ``encode_prompts(prompts,
                negative_prompts, do_classifier_free_guidance)`` returning
                ``(text_embeddings, pooled_embeddings)``.
            adapter: Trainable adapter with ``forward_text_embeddings`` and
                ``forward_pooled_embeddings``.
            noise_scheduler: Scheduler with ``config.num_train_timesteps``,
                ``alphas_cumprod`` and ``add_noise``.
            loss_fn: DiffusionLoss instance.
            device: Device on which noise timesteps / time ids are created.
            dtype: Compute dtype fed to the frozen backbones and adapter.
        """
        self.unet = unet
        self.vae = vae
        self.text_encoder = text_encoder
        self.adapter = adapter
        self.noise_scheduler = noise_scheduler
        self.loss_fn = loss_fn
        self.device = device
        self.dtype = dtype

        self._freeze_components()

    def _freeze_components(self):
        """Freeze every component except the adapter."""
        if self.unet is not None:
            for param in self.unet.parameters():
                param.requires_grad = False

        if self.vae is not None:
            for param in self.vae.parameters():
                param.requires_grad = False

        # The text encoder must be frozen as well; guarded with hasattr since
        # it may be a wrapper object that does not expose nn.Module parameters.
        if self.text_encoder is not None and hasattr(self.text_encoder, "parameters"):
            for param in self.text_encoder.parameters():
                param.requires_grad = False

        for param in self.adapter.parameters():
            param.requires_grad = True

    def prepare_inputs(
        self,
        images: torch.Tensor,
        prompts: list,
        negative_prompts: Optional[list] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Prepare all tensors needed by one training step.

        Returns a dict with clean latents, noisy latents, sampled noise and
        timesteps, adapted text conditioning and SDXL time ids.
        """
        batch_size = images.shape[0]

        # Encode images into scaled VAE latents; no gradients through the VAE.
        with torch.no_grad():
            latents = self.vae.encode(images.to(self.dtype)).latent_dist.sample()
            latents = latents * self.vae.config.scaling_factor

        # Sample per-sample noise and random timesteps, then diffuse.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (batch_size,), device=self.device
        ).long()

        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # Text conditioning (classifier-free guidance disabled for training).
        text_embeddings, pooled_embeddings = self.text_encoder.encode_prompts(
            prompts, negative_prompts, do_classifier_free_guidance=False
        )

        # Broadcast the embedding to a fixed 77-token sequence to mimic
        # CLIP's sequence length.
        # NOTE(review): assumes text_embeddings is 2-D (batch, dim) — confirm
        # against the text encoder's actual output shape.
        seq_len = 77
        text_embeddings_seq = text_embeddings.unsqueeze(1).expand(-1, seq_len, -1)

        encoder_hidden_states = self.adapter.forward_text_embeddings(text_embeddings_seq.to(self.dtype))
        pooled_prompt_embeds = self.adapter.forward_pooled_embeddings(pooled_embeddings.to(self.dtype))

        # SDXL micro-conditioning: original size + crop offset + target size.
        height, width = images.shape[-2:]
        original_size = (height, width)
        target_size = (height, width)
        crops_coords_top_left = (0, 0)

        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids] * batch_size, dtype=self.dtype, device=self.device)

        return {
            # Clean latents are needed to build the v-prediction target.
            "latents": latents,
            "noisy_latents": noisy_latents,
            "timesteps": timesteps,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_prompt_embeds": pooled_prompt_embeds,
            "add_time_ids": add_time_ids,
            "noise": noise
        }

    def training_step(
        self,
        images: torch.Tensor,
        prompts: list,
        negative_prompts: Optional[list] = None
    ) -> Dict[str, Any]:
        """
        Execute one training step and return the loss plus detached
        diagnostics (prediction, target, timesteps).
        """
        inputs = self.prepare_inputs(images, prompts, negative_prompts)

        added_cond_kwargs = {
            "text_embeds": inputs["pooled_prompt_embeds"],
            "time_ids": inputs["add_time_ids"]
        }

        model_pred = self.unet(
            inputs["noisy_latents"],
            inputs["timesteps"],
            encoder_hidden_states=inputs["encoder_hidden_states"],
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]

        if self.loss_fn.use_v_parameterization:
            # v-prediction target: v = sqrt(a_bar)*eps - sqrt(1-a_bar)*x0
            # (Salimans & Ho, "Progressive Distillation"). The *clean*
            # latents x0 are used here — the noisy latents are the model
            # input, not part of the target. alphas_cumprod is moved to the
            # timesteps' device so CUDA timesteps index correctly.
            alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(
                inputs["timesteps"].device
            )[inputs["timesteps"]]
            sqrt_alphas_cumprod = alphas_cumprod**0.5
            sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

            target = sqrt_alphas_cumprod.view(-1, 1, 1, 1) * inputs["noise"] - \
                sqrt_one_minus_alphas_cumprod.view(-1, 1, 1, 1) * inputs["latents"]
        else:
            # epsilon-prediction: the target is simply the added noise.
            target = inputs["noise"]

        loss = self.loss_fn(model_pred, target, inputs["timesteps"])

        return {
            "loss": loss,
            "model_pred": model_pred.detach(),
            "target": target.detach(),
            "timesteps": inputs["timesteps"]
        }
|
|
|
|
def get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Build a LambdaLR schedule with linear warmup followed by cosine decay.

    The LR ramps linearly from 0 to the base rate over ``num_warmup_steps``,
    then follows a cosine curve (clipped at 0) over the remaining
    ``num_training_steps - num_warmup_steps`` steps; ``num_cycles`` sets how
    many half-waves of the cosine are traversed.
    """
    from torch.optim.lr_scheduler import LambdaLR

    warmup_span = max(1, num_warmup_steps)
    decay_span = max(1, num_training_steps - num_warmup_steps)

    def lr_lambda(current_step):
        # Warmup phase: linear ramp from 0 toward 1.
        if current_step < num_warmup_steps:
            return current_step / warmup_span
        # Decay phase: cosine from 1 down to (at most) 0.
        progress = (current_step - num_warmup_steps) / decay_span
        cosine = math.cos(2.0 * math.pi * num_cycles * progress)
        return max(0.0, 0.5 * (1.0 + cosine))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
|
|
|
class EMAModel:
    """
    Exponential Moving Average (EMA) of a model's trainable parameters.

    Call ``update()`` after each optimizer step; use ``apply_shadow()`` /
    ``restore()`` to temporarily evaluate the model with the averaged weights.
    """

    def __init__(self, model, decay: float = 0.9999):
        """
        Args:
            model: Model whose trainable parameters are tracked.
            decay: EMA decay factor; closer to 1.0 = slower-moving average.
        """
        self.model = model
        self.decay = decay
        self.shadow = {}  # name -> EMA tensor
        self.backup = {}  # name -> live weights saved by apply_shadow()

        # Seed the EMA with a copy of the current trainable parameters.
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """Blend current parameters into the EMA: s = d*s + (1-d)*p."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * param.data

    def apply_shadow(self):
        """Swap EMA weights into the model, backing up the live weights."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.backup[name] = param.data.clone()
                # copy_ into the existing storage rather than rebinding
                # param.data to the shadow tensor: with the original rebind,
                # an optimizer step or update() while the shadow was applied
                # would silently corrupt the EMA state via aliasing.
                param.data.copy_(self.shadow[name])

    def restore(self):
        """Restore the live weights saved by apply_shadow()."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}
|
|