| import lightning as L |
| from diffusers.pipelines import FluxPipeline |
| import torch |
| from peft import LoraConfig, get_peft_model_state_dict |
|
|
| import prodigyopt |
|
|
| from ..flux.transformer import tranformer_forward |
| from ..flux.condition import Condition |
| from ..flux.pipeline_tools import encode_images, prepare_text_input |
|
|
|
|
class OminiModel(L.LightningModule):
    """Lightning module that trains LoRA adapters on a Flux transformer.

    The module loads a ``FluxPipeline``, freezes its text encoders and VAE,
    injects a LoRA adapter into the transformer and optimizes only the
    parameters that are trainable after the adapter has been added.
    """

    def __init__(
        self,
        flux_pipe_id: str,
        lora_path: str = None,
        lora_config: dict = None,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        model_config: dict = None,
        optimizer_config: dict = None,
        gradient_checkpointing: bool = False,
    ):
        """Load the pipeline, freeze the non-trained parts and set up LoRA.

        Args:
            flux_pipe_id: Identifier passed to ``FluxPipeline.from_pretrained``.
            lora_path: Path to existing LoRA weights (loading is not
                implemented yet; see ``init_lora``).
            lora_config: Keyword arguments used to build a ``LoraConfig``.
            device: Device the pipeline and this module are moved to.
            dtype: Dtype the pipeline and this module are cast to.
            model_config: Extra options forwarded to ``tranformer_forward``.
                ``None`` (the default) means an empty configuration.
            optimizer_config: Dict with a ``"type"`` key (``"AdamW"``,
                ``"Prodigy"`` or ``"SGD"``) and a ``"params"`` dict of
                optimizer keyword arguments.
            gradient_checkpointing: Value assigned to the transformer's
                ``gradient_checkpointing`` attribute.
        """
        super().__init__()
        # Build a fresh dict per instance: a ``{}`` default in the signature
        # would be one object shared by every instance created without an
        # explicit config.
        self.model_config = {} if model_config is None else model_config
        self.optimizer_config = optimizer_config

        # Load the Flux pipeline and keep a handle on its transformer.
        self.flux_pipe: FluxPipeline = (
            FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
        )
        self.transformer = self.flux_pipe.transformer
        self.transformer.gradient_checkpointing = gradient_checkpointing
        self.transformer.train()

        # Freeze everything in the pipeline that is not being trained.
        self.flux_pipe.text_encoder.requires_grad_(False).eval()
        self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
        self.flux_pipe.vae.requires_grad_(False).eval()

        # Inject the LoRA adapter and remember its trainable parameters.
        self.lora_layers = self.init_lora(lora_path, lora_config)

        self.to(device).to(dtype)

    def init_lora(self, lora_path: str, lora_config: dict):
        """Add a LoRA adapter to the transformer and return trainable params.

        Args:
            lora_path: Path to pretrained LoRA weights (not supported yet).
            lora_config: Keyword arguments for ``LoraConfig``.

        Returns:
            A list of the transformer parameters whose ``requires_grad`` is
            true after the adapter has been added.

        Raises:
            ValueError: If neither ``lora_path`` nor ``lora_config`` is given.
            NotImplementedError: If ``lora_path`` is given.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not (lora_path or lora_config):
            raise ValueError("Either lora_path or lora_config must be provided.")
        if lora_path:
            # TODO: implement loading of existing LoRA weights.
            raise NotImplementedError(
                "Loading LoRA weights from lora_path is not implemented."
            )

        self.transformer.add_adapter(LoraConfig(**lora_config))
        # NOTE(review): this relies on add_adapter leaving only the adapter
        # weights with requires_grad=True -- confirm against the PEFT docs.
        return [p for p in self.transformer.parameters() if p.requires_grad]

    def save_lora(self, path: str):
        """Write the transformer's LoRA weights to ``path`` as safetensors."""
        FluxPipeline.save_lora_weights(
            save_directory=path,
            transformer_lora_layers=get_peft_model_state_dict(self.transformer),
            safe_serialization=True,
        )

    def configure_optimizers(self):
        """Freeze the transformer, re-enable the LoRA params, build optimizer.

        Returns:
            The optimizer selected by ``optimizer_config["type"]``, built
            over ``self.trainable_params`` with ``optimizer_config["params"]``
            as keyword arguments.

        Raises:
            ValueError: If no ``optimizer_config`` was supplied.
            NotImplementedError: If the optimizer type is not supported.
        """
        opt_config = self.optimizer_config
        if opt_config is None:
            raise ValueError(
                "optimizer_config must be provided to configure the optimizer."
            )

        # Freeze the whole transformer, then unfreeze only the LoRA params.
        self.transformer.requires_grad_(False)
        self.trainable_params = self.lora_layers
        for param in self.trainable_params:
            param.requires_grad_(True)

        optimizer_classes = {
            "AdamW": torch.optim.AdamW,
            "Prodigy": prodigyopt.Prodigy,
            "SGD": torch.optim.SGD,
        }
        opt_type = opt_config["type"]
        if opt_type not in optimizer_classes:
            raise NotImplementedError(f"Unsupported optimizer type: {opt_type!r}")

        return optimizer_classes[opt_type](
            self.trainable_params, **opt_config["params"]
        )

    def training_step(self, batch, batch_idx):
        """Run one training step and update the smoothed loss ``log_loss``.

        ``log_loss`` is initialized with the first step's loss and afterwards
        updated as ``0.95 * previous + 0.05 * current``.
        """
        step_loss = self.step(batch)
        loss_value = step_loss.item()
        if hasattr(self, "log_loss"):
            self.log_loss = self.log_loss * 0.95 + loss_value * 0.05
        else:
            self.log_loss = loss_value
        return step_loss

    def step(self, batch):
        """Compute the training loss for one batch.

        Args:
            batch: Mapping with the keys ``"image"``, ``"condition"``,
                ``"condition_type"``, ``"description"`` and
                ``"position_delta"``; ``"position_scale"`` is optional and
                defaults to ``1.0``.

        Returns:
            Mean squared error between the transformer prediction and
            ``x_1 - x_0`` (noise minus clean latents).
        """
        imgs = batch["image"]
        conditions = batch["condition"]
        condition_types = batch["condition_type"]
        prompts = batch["description"]
        # Only the first entry of the batch is used for both position values.
        position_delta = batch["position_delta"][0]
        position_scale = float(batch.get("position_scale", [1.0])[0])

        # Input preparation needs no gradients.
        with torch.no_grad():
            # Encode the target images into latents and their position ids.
            x_0, img_ids = encode_images(self.flux_pipe, imgs)

            # Encode the prompts.
            prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
                self.flux_pipe, prompts
            )

            # One timestep per sample: sigmoid of a standard normal draw,
            # so every value lies in (0, 1).
            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
            x_1 = torch.randn_like(x_0).to(self.device)
            # Two trailing singleton dims so t broadcasts against x_0
            # (assumes x_0 is 3-dimensional -- TODO confirm).
            t_ = t.unsqueeze(1).unsqueeze(1)
            # Linear interpolation between clean latents and noise.
            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)

            # Encode the condition images.
            condition_latents, condition_ids = encode_images(
                self.flux_pipe, conditions
            )

            # Shift columns 1 and 2 of the condition ids by the delta.
            condition_ids[:, 1] += position_delta[0]
            condition_ids[:, 2] += position_delta[1]

            # Optionally rescale columns 1 and 2, adding half of
            # (scale - 1) as an offset.
            if position_scale != 1.0:
                scale_bias = (position_scale - 1.0) / 2
                condition_ids[:, 1] *= position_scale
                condition_ids[:, 2] *= position_scale
                condition_ids[:, 1] += scale_bias
                condition_ids[:, 2] += scale_bias

            # Every condition token gets the type id of the FIRST sample;
            # presumably all samples in a batch share one condition type.
            condition_type_ids = torch.tensor(
                [
                    Condition.get_type_id(condition_type)
                    for condition_type in condition_types
                ]
            ).to(self.device)
            condition_type_ids = (
                torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0]
            ).unsqueeze(1)

            # Constant guidance of 1 when the transformer config enables
            # guidance embeddings, otherwise no guidance input.
            guidance = (
                torch.ones_like(t).to(self.device)
                if self.transformer.config.guidance_embeds
                else None
            )

        # Forward pass (outside no_grad so the loss is differentiable).
        transformer_out = tranformer_forward(
            self.transformer,
            # Model config
            model_config=self.model_config,
            # Inputs of the condition (new feature)
            condition_latents=condition_latents,
            condition_ids=condition_ids,
            condition_type_ids=condition_type_ids,
            # Inputs to the original transformer
            hidden_states=x_t,
            timestep=t,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=img_ids,
            joint_attention_kwargs=None,
            return_dict=False,
        )
        pred = transformer_out[0]

        # Regress the prediction onto the noise-minus-data direction.
        loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
        # Keep the mean sampled timestep around for logging.
        self.last_t = t.mean().item()
        return loss
|
|