from __future__ import annotations import copy import gc import json from pathlib import Path from types import SimpleNamespace from typing import Dict, List import torch import torch.nn as nn import torch.nn.functional as F from .bila_layers import Bilateral_Grid_Joint_Flux, Biliteral_Grid_Joint LORA_TARGET_MODULES = ( "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out," "linear_in,linear_out,to_qkv_mlp_proj," "single_transformer_blocks.0.attn.to_out," "single_transformer_blocks.1.attn.to_out," "single_transformer_blocks.2.attn.to_out," "single_transformer_blocks.3.attn.to_out," "single_transformer_blocks.4.attn.to_out," "single_transformer_blocks.5.attn.to_out," "single_transformer_blocks.6.attn.to_out," "single_transformer_blocks.7.attn.to_out," "single_transformer_blocks.8.attn.to_out," "single_transformer_blocks.9.attn.to_out," "single_transformer_blocks.10.attn.to_out," "single_transformer_blocks.11.attn.to_out," "single_transformer_blocks.12.attn.to_out," "single_transformer_blocks.13.attn.to_out," "single_transformer_blocks.14.attn.to_out," "single_transformer_blocks.15.attn.to_out," "single_transformer_blocks.16.attn.to_out," "single_transformer_blocks.17.attn.to_out," "single_transformer_blocks.18.attn.to_out," "single_transformer_blocks.19.attn.to_out" ) def _device() -> torch.device: if not torch.cuda.is_available(): raise RuntimeError("This demo requires a CUDA GPU Space.") return torch.device("cuda") def _checkpoint_state(path: Path, required: List[str]) -> Dict: state = torch.load(path, map_location="cpu") if "state_dict" not in state: raise ValueError(f"{path} is missing top-level state_dict") state_dict = state["state_dict"] missing = [key.split(".", 1)[1] for key in required if key.startswith("state_dict.") and key.split(".", 1)[1] not in state_dict] if missing: raise ValueError(f"{path} is missing checkpoint entries: {missing}") return state_dict class Ip2pBilaBackend(nn.Module): def __init__(self, model_cfg: Dict, paths: Dict[str, Path]): super().__init__() self.model_cfg = model_cfg self.paths = paths self.config = model_cfg["config"] self.device = _device() self.weight_dtype = torch.float32 self.bila_feat = None from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel from transformers import CLIPTextModel, CLIPTokenizer base = paths["base"] self.tokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer") self.text_encoder = CLIPTextModel.from_pretrained(base, subfolder="text_encoder").to(self.device) self.vae = AutoencoderKL.from_pretrained(base, subfolder="vae").to(self.device) self.unet = UNet2DConditionModel.from_pretrained(base, subfolder="unet").to(self.device) scheduler_path = Path(__file__).with_name("ip2p_scheduler.json") with scheduler_path.open("r", encoding="utf-8") as handle: self.noise_scheduler = DDPMScheduler.from_config(json.load(handle)) self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device) self.timesteps = torch.tensor([999], device=self.device).long() self.unet.up_blocks[3].register_forward_hook(self._forward_hook) self.bila_grid = Biliteral_Grid_Joint( grid_res=self.config["bila_grid_res"], grid_bins=self.config["bila_grid_bins"], ).to(self.device) state_dict = _checkpoint_state(paths["checkpoint"], model_cfg["expected_checkpoint_keys"]) self.unet.load_state_dict(state_dict["unet"]) self.bila_grid.load_state_dict(state_dict["bila"]) self.eval() self.requires_grad_(False) def _forward_hook(self, module, inputs, output): del module, inputs self.bila_feat = [output] @torch.no_grad() def _encode_prompt(self, prompt_batch): tokens = self.tokenizer( prompt_batch, padding="max_length", max_length=77, truncation=True, return_tensors="pt", return_overflowing_tokens=False, ).input_ids.to(self.device) return self.text_encoder(tokens).last_hidden_state.detach() @torch.inference_mode() def forward(self, input_imgs, input_prompts, input_fullres): input_imgs = input_imgs.to(self.device, dtype=self.weight_dtype) input_fullres = input_fullres.to(self.device, dtype=self.weight_dtype) vae_input = input_imgs * 2 - 1 image_latents = self.vae.encode(vae_input).latent_dist.mode() noisy_latents = torch.randn( image_latents.shape, device=image_latents.device, dtype=image_latents.dtype, ) noisy_latents = noisy_latents * self.noise_scheduler.init_noise_sigma encoder_hidden_states = self._encode_prompt(input_prompts) timesteps = torch.ones((image_latents.shape[0],), device=self.device).long() * 999 concatenated_noisy_latents = torch.cat([noisy_latents, image_latents], dim=1) model_pred = self.unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample alpha_prod = self.noise_scheduler.alphas_cumprod.to(image_latents.device, dtype=model_pred.dtype) beta_prod = 1 - alpha_prod alpha_prod_t = alpha_prod[timesteps].view(-1, 1, 1, 1) beta_prod_t = beta_prod[timesteps].view(-1, 1, 1, 1) x_denoised = (noisy_latents - beta_prod_t.sqrt() * model_pred) / alpha_prod_t.sqrt() pred_images = self.vae.decode((1 / 0.18215) * x_denoised.to(self.weight_dtype), return_dict=False)[0] diff_out_img = (pred_images / 2 + 0.5).clamp(0, 1) bila_feat = [feat.float() for feat in self.bila_feat] bila_out_img, _ = self.bila_grid(bila_feat, input_fullres) return {"diff": diff_out_img.detach().cpu(), "bila": bila_out_img.detach().cpu()} def _patch_attention_for_gqa(): original = F.scaled_dot_product_attention if getattr(original, "_bila_gqa_patched", False): return def patched_scaled_dot_product_attention(*args, **kwargs): kwargs.pop("enable_gqa", None) return original(*args, **kwargs) patched_scaled_dot_product_attention._bila_gqa_patched = True F.scaled_dot_product_attention = patched_scaled_dot_product_attention def _load_task_lora_state_dict(task_lora_path): from diffusers import Flux2KleinPipeline from diffusers.utils import convert_unet_state_dict_to_peft lora_state_dict = Flux2KleinPipeline.lora_state_dict(task_lora_path) transformer_lora_sd = { key.replace("transformer.", ""): value for key, value in lora_state_dict.items() if key.startswith("transformer.") } return convert_unet_state_dict_to_peft(transformer_lora_sd) def _build_lora_config(rank, alpha, dropout=0.0): from peft import LoraConfig return LoraConfig( r=rank, lora_alpha=alpha, lora_dropout=dropout, init_lora_weights="gaussian", target_modules=[module.strip() for module in LORA_TARGET_MODULES.split(",")], ) def _load_flux_transformer(args): from diffusers import Flux2Transformer2DModel from peft import set_peft_model_state_dict from peft.tuners.lora.layer import LoraLayer transformer = Flux2Transformer2DModel.from_pretrained(args.pipeline_path, subfolder="transformer") if args.task_lora_path: task_lora_sd = _load_task_lora_state_dict(args.task_lora_path) task_lora_config = _build_lora_config(args.task_lora_rank, args.task_lora_alpha) transformer.add_adapter(task_lora_config, adapter_name="task") set_peft_model_state_dict(transformer, task_lora_sd, adapter_name="task") for module in transformer.modules(): if isinstance(module, LoraLayer): module.merge(adapter_names=["task"]) transformer.delete_adapters("task") distill_lora_config = _build_lora_config( args.distill_lora_rank, args.distill_lora_alpha, args.distill_lora_dropout, ) transformer.add_adapter(distill_lora_config, adapter_name="distill") return transformer class FluxBilaBackend(nn.Module): def __init__(self, model_cfg: Dict, paths: Dict[str, Path]): super().__init__() _patch_attention_for_gqa() self.model_cfg = model_cfg self.paths = paths self.config = model_cfg["config"] self.device = _device() self.args = SimpleNamespace( pipeline_path=str(paths["base"]), task_lora_path=str(paths["task_lora"]), use_t2i=False, cfg=False, bila_use_flux_rgb=False, fix_guide_map=False, bila_grid_res=self.config["bila_grid_res"], bila_grid_bins=self.config["bila_grid_bins"], mixed_precision=self.config["mixed_precision"], max_sequence_length=self.config["max_sequence_length"], distill_strategy=self.config["distill_strategy"], distill_lora_rank=self.config["distill_lora_rank"], distill_lora_alpha=self.config["distill_lora_alpha"], distill_lora_dropout=self.config["distill_lora_dropout"], task_lora_rank=self.config["task_lora_rank"], task_lora_alpha=self.config["task_lora_alpha"], ) from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2KleinPipeline from peft import set_peft_model_state_dict from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM self.Flux2KleinPipeline = Flux2KleinPipeline self.tokenizer = Qwen2TokenizerFast.from_pretrained(paths["base"], subfolder="tokenizer") self.text_encoder = Qwen3ForCausalLM.from_pretrained(paths["base"], subfolder="text_encoder").to(self.device) self.vae = AutoencoderKLFlux2.from_pretrained(paths["base"], subfolder="vae").to(self.device) self.transformer = _load_flux_transformer(self.args).to(self.device) self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(paths["base"], subfolder="scheduler") self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler) self.latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(self.device) self.latents_bn_std = torch.sqrt( self.vae.bn.running_var.view(1, -1, 1, 1).to(self.device) + self.vae.config.batch_norm_eps ) self.text_encoding_pipeline = Flux2KleinPipeline.from_pretrained( paths["base"], vae=None, transformer=None, tokenizer=self.tokenizer, text_encoder=self.text_encoder, scheduler=None, ) self.one_step_sigma = 1.0 self.weight_dtype = torch.bfloat16 if self.args.mixed_precision == "fp16": self.weight_dtype = torch.float16 elif self.args.mixed_precision == "no": self.weight_dtype = torch.float32 self.vae.to(dtype=self.weight_dtype) self.transformer.to(dtype=self.weight_dtype) self.text_encoder.to(dtype=self.weight_dtype) self.bila_feat = None if hasattr(self.transformer, "single_transformer_blocks") and len(self.transformer.single_transformer_blocks) > 0: self.transformer.single_transformer_blocks[-1].register_forward_hook(self._forward_hook) else: self.transformer.transformer_blocks[-1].register_forward_hook(self._forward_hook) self.bila_grid = Bilateral_Grid_Joint_Flux( grid_res=self.config["bila_grid_res"], grid_bins=self.config["bila_grid_bins"], ).to(self.device) state_dict = _checkpoint_state(paths["checkpoint"], model_cfg["expected_checkpoint_keys"]) adapter_name = state_dict.get("active_adapter_name", "distill") set_peft_model_state_dict(self.transformer, state_dict["transformer_lora"], adapter_name=adapter_name) if hasattr(self.transformer, "set_adapter"): self.transformer.set_adapter(adapter_name) self.bila_grid.load_state_dict(state_dict["bila"]) self.eval() self.requires_grad_(False) def _forward_hook(self, module, inputs, output): del module, inputs self.bila_feat = [output[1] if isinstance(output, tuple) else output] @torch.no_grad() def _encode_prompt(self, prompt_batch): prompt_embeds, text_ids = self.text_encoding_pipeline.encode_prompt( prompt=prompt_batch, max_sequence_length=self.args.max_sequence_length, ) return prompt_embeds.detach(), text_ids.detach() def _prepare_latent_ids_and_cond_ids(self, model_input, cond_model_input): model_input_ids = self.Flux2KleinPipeline._prepare_latent_ids(model_input).to(device=model_input.device) cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])] cond_model_input_ids = self.Flux2KleinPipeline._prepare_image_ids(cond_model_input_list).to( device=cond_model_input.device ) cond_model_input_ids = cond_model_input_ids.view( cond_model_input.shape[0], -1, model_input_ids.shape[-1] ) return model_input_ids, cond_model_input_ids @torch.inference_mode() def forward(self, input_imgs, input_prompts, input_fullres): input_imgs = input_imgs.to(self.device, dtype=self.weight_dtype) input_fullres = input_fullres.to(self.device, dtype=self.weight_dtype) vae_input = input_imgs * 2 - 1 image_latents = self.vae.encode(vae_input).latent_dist.mode() image_latents_patched = self.Flux2KleinPipeline._patchify_latents(image_latents) cond_model_input = (image_latents_patched - self.latents_bn_mean) / self.latents_bn_std noisy_latents = torch.randn_like(cond_model_input) model_input_ids, cond_model_input_ids = self._prepare_latent_ids_and_cond_ids( noisy_latents, cond_model_input ) prompt_embeds, text_ids = self._encode_prompt(input_prompts) bsz = noisy_latents.shape[0] timestep_input = (torch.ones((bsz,), device=self.device) * 999.0) / 1000.0 packed_noisy = self.Flux2KleinPipeline._pack_latents(noisy_latents) packed_cond = self.Flux2KleinPipeline._pack_latents(cond_model_input) orig_shape = packed_noisy.shape orig_ids_shape = model_input_ids.shape packed_input = torch.cat([packed_noisy, packed_cond], dim=1) ids_input = torch.cat([model_input_ids, cond_model_input_ids], dim=1) model_pred = self.transformer( hidden_states=packed_input, timestep=timestep_input, guidance=None, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=ids_input, return_dict=False, )[0] model_pred = model_pred[:, : orig_shape[1], :] model_pred = self.Flux2KleinPipeline._unpack_latents_with_ids( model_pred, model_input_ids[:, : orig_ids_shape[1], :], ) x0_pred_normalized = noisy_latents - self.one_step_sigma * model_pred x0_pred_patched = x0_pred_normalized * self.latents_bn_std + self.latents_bn_mean x_denoised = self.Flux2KleinPipeline._unpatchify_latents(x0_pred_patched) pred_images = self.vae.decode(x_denoised.to(self.weight_dtype), return_dict=False)[0] diff_out_img = (pred_images / 2 + 0.5).clamp(0, 1) num_txt_tokens = prompt_embeds.shape[1] bila_feat = [feat.float() for feat in self.bila_feat] bila_feat = [feat[:, num_txt_tokens:, ...] for feat in bila_feat] bila_feat = [feat[:, : orig_shape[1], :] for feat in bila_feat] bila_out_img, _ = self.bila_grid( bila_feat, input_fullres, latent_h=cond_model_input.shape[-2], latent_w=cond_model_input.shape[-1], ) return {"diff": diff_out_img.detach().cpu(), "bila": bila_out_img.detach().cpu()} def release_cuda(): gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache()