Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| import copy | |
| import gc | |
| import json | |
| from pathlib import Path | |
| from types import SimpleNamespace | |
| from typing import Dict, List | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .bila_layers import Bilateral_Grid_Joint_Flux, Biliteral_Grid_Joint | |
# Comma-separated names of the modules that receive LoRA adapters.
# `_build_lora_config` splits this string on "," to fill
# `LoraConfig.target_modules`. The tail enumerates the attention output
# projection of single-stream blocks 0 through 19 explicitly.
LORA_TARGET_MODULES = ",".join(
    [
        "to_q",
        "to_k",
        "to_v",
        "to_out.0",
        "add_q_proj",
        "add_k_proj",
        "add_v_proj",
        "to_add_out",
        "linear_in",
        "linear_out",
        "to_qkv_mlp_proj",
        *(
            f"single_transformer_blocks.{index}.attn.to_out"
            for index in range(20)
        ),
    ]
)
def _device() -> torch.device:
    """Return the CUDA device used by the backends in this module.

    Returns:
        ``torch.device("cuda")``.

    Raises:
        RuntimeError: If ``torch.cuda.is_available()`` is false.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    raise RuntimeError("This demo requires a CUDA GPU Space.")
def _checkpoint_state(path: Path, required: List[str]) -> Dict:
    """Load a checkpoint file and return its top-level ``state_dict`` entry.

    Args:
        path: Checkpoint file, read with ``torch.load`` and mapped to CPU.
        required: Dotted key names. Only names that start with
            ``"state_dict."`` are checked; the text after that first dot must
            be a key of the loaded ``state_dict``. Other names are ignored.

    Returns:
        The object stored under the checkpoint's ``"state_dict"`` key.

    Raises:
        ValueError: If ``"state_dict"`` is absent from the checkpoint, or if
            any checked entry is missing from it.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version's `weights_only` default; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location="cpu")
    if "state_dict" not in checkpoint:
        raise ValueError(f"{path} is missing top-level state_dict")
    entries = checkpoint["state_dict"]

    prefix = "state_dict."
    absent = []
    for name in required:
        if not name.startswith(prefix):
            continue
        # Everything after the first dot, i.e. after the "state_dict." prefix.
        entry = name[len(prefix):]
        if entry not in entries:
            absent.append(entry)

    if absent:
        raise ValueError(f"{path} is missing checkpoint entries: {absent}")
    return entries
class Ip2pBilaBackend(nn.Module):
    """Single-step UNet image-editing backend with a bilateral-grid head.

    The constructor loads a CLIP tokenizer/text encoder, a VAE and a
    conditional UNet from ``paths["base"]``, overrides the UNet weights and
    loads the bilateral-grid weights from ``paths["checkpoint"]``, and hooks
    the output of ``unet.up_blocks[3]`` so ``forward`` can feed it to the
    bilateral grid. All weights are frozen and the module is put in eval mode.
    """

    def __init__(self, model_cfg: Dict, paths: Dict[str, Path]):
        """Build every sub-model on the CUDA device and load the checkpoint.

        Args:
            model_cfg: Must provide ``"config"`` (with ``"bila_grid_res"`` and
                ``"bila_grid_bins"``) and ``"expected_checkpoint_keys"``.
            paths: Must provide ``"base"`` (pretrained model location) and
                ``"checkpoint"`` (file read by ``_checkpoint_state``).

        Raises:
            RuntimeError: If CUDA is unavailable (raised by ``_device``).
            ValueError: If the checkpoint lacks required entries.
        """
        super().__init__()
        self.model_cfg = model_cfg
        self.paths = paths
        self.config = model_cfg["config"]
        self.device = _device()
        self.weight_dtype = torch.float32
        # Filled by `_forward_hook` on every UNet forward pass.
        self.bila_feat = None

        # diffusers / transformers are imported here rather than at module level.
        from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
        from transformers import CLIPTextModel, CLIPTokenizer

        base_dir = paths["base"]
        self.tokenizer = CLIPTokenizer.from_pretrained(base_dir, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(base_dir, subfolder="text_encoder").to(self.device)
        self.vae = AutoencoderKL.from_pretrained(base_dir, subfolder="vae").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(base_dir, subfolder="unet").to(self.device)

        # The scheduler configuration ships as a JSON file next to this module.
        scheduler_file = Path(__file__).with_name("ip2p_scheduler.json")
        with scheduler_file.open("r", encoding="utf-8") as stream:
            self.noise_scheduler = DDPMScheduler.from_config(json.load(stream))
        self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device)
        self.timesteps = torch.tensor([999], device=self.device).long()

        # Capture the output of the fourth up block for the bilateral grid.
        self.unet.up_blocks[3].register_forward_hook(self._forward_hook)

        self.bila_grid = Biliteral_Grid_Joint(
            grid_res=self.config["bila_grid_res"],
            grid_bins=self.config["bila_grid_bins"],
        ).to(self.device)

        checkpoint = _checkpoint_state(paths["checkpoint"], model_cfg["expected_checkpoint_keys"])
        self.unet.load_state_dict(checkpoint["unet"])
        self.bila_grid.load_state_dict(checkpoint["bila"])

        # Inference only: eval mode, no gradients on any parameter.
        self.eval()
        self.requires_grad_(False)

    def _forward_hook(self, module, inputs, output):
        """Forward hook: store the hooked block's output as a one-item list."""
        del module, inputs
        self.bila_feat = [output]

    def _encode_prompt(self, prompt_batch):
        """Tokenize ``prompt_batch`` and return the text encoder's hidden states.

        Prompts are padded/truncated to 77 tokens; the returned
        ``last_hidden_state`` is detached.
        """
        encoded = self.tokenizer(
            prompt_batch,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
            return_overflowing_tokens=False,
        )
        token_ids = encoded.input_ids.to(self.device)
        return self.text_encoder(token_ids).last_hidden_state.detach()

    def forward(self, input_imgs, input_prompts, input_fullres):
        """Run one denoising step and the bilateral-grid branch.

        Args:
            input_imgs: Image tensor fed to the VAE after ``x * 2 - 1``
                (presumably values in [0, 1] -- confirm against the caller).
            input_prompts: Prompt batch passed to the tokenizer.
            input_fullres: Tensor handed to the bilateral grid together with
                the hooked UNet features.

        Returns:
            Dict with ``"diff"`` (decoded image clamped to [0, 1]) and
            ``"bila"`` (bilateral-grid output), both detached and on CPU.
        """
        input_imgs = input_imgs.to(self.device, dtype=self.weight_dtype)
        input_fullres = input_fullres.to(self.device, dtype=self.weight_dtype)

        # Condition latents: the mode of the VAE posterior of the rescaled image.
        image_latents = self.vae.encode(input_imgs * 2 - 1).latent_dist.mode()

        noise = torch.randn(
            image_latents.shape,
            device=image_latents.device,
            dtype=image_latents.dtype,
        )
        noisy_latents = noise * self.noise_scheduler.init_noise_sigma

        encoder_hidden_states = self._encode_prompt(input_prompts)

        # Every sample uses the fixed timestep 999.
        batch_size = image_latents.shape[0]
        timesteps = torch.ones((batch_size,), device=self.device).long() * 999

        # Noise and condition latents are stacked along the channel axis.
        unet_input = torch.cat([noisy_latents, image_latents], dim=1)
        model_pred = self.unet(unet_input, timesteps, encoder_hidden_states).sample

        # Recover the denoised latents from the prediction in a single step.
        alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(
            image_latents.device, dtype=model_pred.dtype
        )
        betas_cumprod = 1 - alphas_cumprod
        alpha_prod_t = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
        beta_prod_t = betas_cumprod[timesteps].view(-1, 1, 1, 1)
        x_denoised = (noisy_latents - beta_prod_t.sqrt() * model_pred) / alpha_prod_t.sqrt()

        decoded = self.vae.decode(
            (1 / 0.18215) * x_denoised.to(self.weight_dtype),
            return_dict=False,
        )[0]
        diff_out_img = (decoded / 2 + 0.5).clamp(0, 1)

        # Features captured by the hook during the UNet call above.
        hooked_features = [feat.float() for feat in self.bila_feat]
        bila_out_img, _ = self.bila_grid(hooked_features, input_fullres)

        return {"diff": diff_out_img.detach().cpu(), "bila": bila_out_img.detach().cpu()}
def _patch_attention_for_gqa():
    """Monkey-patch ``F.scaled_dot_product_attention`` to ignore ``enable_gqa``.

    The replacement forwards all positional and keyword arguments to the
    function that was installed when this was called, after discarding an
    ``enable_gqa`` keyword if one is present. The wrapper is tagged with
    ``_bila_gqa_patched`` so that calling this function again is a no-op.
    """
    wrapped = F.scaled_dot_product_attention
    already_patched = getattr(wrapped, "_bila_gqa_patched", False)
    if already_patched:
        return

    def patched_scaled_dot_product_attention(*args, **kwargs):
        # Drop the keyword before delegating; every other argument passes through.
        kwargs.pop("enable_gqa", None)
        return wrapped(*args, **kwargs)

    patched_scaled_dot_product_attention._bila_gqa_patched = True
    F.scaled_dot_product_attention = patched_scaled_dot_product_attention
def _load_task_lora_state_dict(task_lora_path):
    """Load a task LoRA and return its transformer weights in PEFT key format.

    Args:
        task_lora_path: Location passed to ``Flux2KleinPipeline.lora_state_dict``.

    Returns:
        The result of ``convert_unet_state_dict_to_peft`` applied to the
        entries whose key starts with ``"transformer."``, with that leading
        prefix removed. Entries with any other prefix are dropped.
    """
    from diffusers import Flux2KleinPipeline
    from diffusers.utils import convert_unet_state_dict_to_peft

    prefix = "transformer."
    lora_state_dict = Flux2KleinPipeline.lora_state_dict(task_lora_path)
    # Strip only the leading prefix. `str.replace` would also delete any later
    # occurrence of "transformer." inside the key and corrupt the module path.
    transformer_lora_sd = {
        key[len(prefix):]: value
        for key, value in lora_state_dict.items()
        if key.startswith(prefix)
    }
    return convert_unet_state_dict_to_peft(transformer_lora_sd)
def _build_lora_config(rank, alpha, dropout=0.0):
    """Create a PEFT ``LoraConfig`` targeting ``LORA_TARGET_MODULES``.

    Args:
        rank: Passed as ``r``.
        alpha: Passed as ``lora_alpha``.
        dropout: Passed as ``lora_dropout``; defaults to ``0.0``.

    Returns:
        A ``LoraConfig`` with Gaussian weight initialisation whose
        ``target_modules`` are the comma-separated, whitespace-stripped names
        from ``LORA_TARGET_MODULES``.
    """
    from peft import LoraConfig

    target_modules = [name.strip() for name in LORA_TARGET_MODULES.split(",")]
    return LoraConfig(
        r=rank,
        lora_alpha=alpha,
        lora_dropout=dropout,
        init_lora_weights="gaussian",
        target_modules=target_modules,
    )
def _load_flux_transformer(args):
    """Load the Flux2 transformer, fold in the task LoRA, add a "distill" adapter.

    Args:
        args: Object exposing ``pipeline_path``, ``task_lora_path``,
            ``task_lora_rank``, ``task_lora_alpha``, ``distill_lora_rank``,
            ``distill_lora_alpha`` and ``distill_lora_dropout``.

    Returns:
        The transformer with a freshly initialised adapter named ``"distill"``.
        When ``args.task_lora_path`` is truthy, the task LoRA has been merged
        into the LoRA layers and its ``"task"`` adapter deleted beforehand.
    """
    from diffusers import Flux2Transformer2DModel
    from peft import set_peft_model_state_dict
    from peft.tuners.lora.layer import LoraLayer

    transformer = Flux2Transformer2DModel.from_pretrained(
        args.pipeline_path,
        subfolder="transformer",
    )

    if args.task_lora_path:
        task_weights = _load_task_lora_state_dict(args.task_lora_path)
        task_config = _build_lora_config(args.task_lora_rank, args.task_lora_alpha)
        transformer.add_adapter(task_config, adapter_name="task")
        set_peft_model_state_dict(transformer, task_weights, adapter_name="task")

        # Merge the task adapter into every LoRA layer, then drop the adapter.
        for submodule in transformer.modules():
            if not isinstance(submodule, LoraLayer):
                continue
            submodule.merge(adapter_names=["task"])
        transformer.delete_adapters("task")

    distill_config = _build_lora_config(
        args.distill_lora_rank,
        args.distill_lora_alpha,
        args.distill_lora_dropout,
    )
    transformer.add_adapter(distill_config, adapter_name="distill")
    return transformer
class FluxBilaBackend(nn.Module):
    """Single-step Flux2 image-editing backend with a bilateral-grid head.

    The constructor patches ``F.scaled_dot_product_attention`` (see
    ``_patch_attention_for_gqa``), loads tokenizer, text encoder, VAE,
    transformer and scheduler from ``paths["base"]``, loads the LoRA and
    bilateral-grid weights from ``paths["checkpoint"]``, and hooks the last
    transformer block so ``forward`` can feed its output to the bilateral
    grid. All weights are frozen and the module is put in eval mode.
    """

    def __init__(self, model_cfg: Dict, paths: Dict[str, Path]):
        """Build every sub-model on the CUDA device and load the checkpoint.

        Args:
            model_cfg: Must provide ``"config"`` (the keys read below) and
                ``"expected_checkpoint_keys"``.
            paths: Must provide ``"base"``, ``"task_lora"`` and ``"checkpoint"``.

        Raises:
            RuntimeError: If CUDA is unavailable (raised by ``_device``).
            ValueError: If the checkpoint lacks required entries.
        """
        super().__init__()
        _patch_attention_for_gqa()

        self.model_cfg = model_cfg
        self.paths = paths
        self.config = model_cfg["config"]
        self.device = _device()

        cfg = self.config
        self.args = SimpleNamespace(
            pipeline_path=str(paths["base"]),
            task_lora_path=str(paths["task_lora"]),
            use_t2i=False,
            cfg=False,
            bila_use_flux_rgb=False,
            fix_guide_map=False,
            bila_grid_res=cfg["bila_grid_res"],
            bila_grid_bins=cfg["bila_grid_bins"],
            mixed_precision=cfg["mixed_precision"],
            max_sequence_length=cfg["max_sequence_length"],
            distill_strategy=cfg["distill_strategy"],
            distill_lora_rank=cfg["distill_lora_rank"],
            distill_lora_alpha=cfg["distill_lora_alpha"],
            distill_lora_dropout=cfg["distill_lora_dropout"],
            task_lora_rank=cfg["task_lora_rank"],
            task_lora_alpha=cfg["task_lora_alpha"],
        )

        # diffusers / peft / transformers are imported here rather than at module level.
        from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2KleinPipeline
        from peft import set_peft_model_state_dict
        from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM

        # Kept on the instance so the other methods can call the pipeline's helpers.
        self.Flux2KleinPipeline = Flux2KleinPipeline

        base_dir = paths["base"]
        self.tokenizer = Qwen2TokenizerFast.from_pretrained(base_dir, subfolder="tokenizer")
        self.text_encoder = Qwen3ForCausalLM.from_pretrained(base_dir, subfolder="text_encoder").to(self.device)
        self.vae = AutoencoderKLFlux2.from_pretrained(base_dir, subfolder="vae").to(self.device)
        self.transformer = _load_flux_transformer(self.args).to(self.device)
        self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_dir, subfolder="scheduler")
        self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)

        # Latent normalisation statistics taken from the VAE's batch-norm layer.
        # NOTE(review): these are captured before the dtype cast further down,
        # so they presumably keep the VAE's load-time dtype -- confirm that
        # mixing them with `weight_dtype` tensors in `forward` is intended.
        bn_layer = self.vae.bn
        self.latents_bn_mean = bn_layer.running_mean.view(1, -1, 1, 1).to(self.device)
        self.latents_bn_std = torch.sqrt(
            bn_layer.running_var.view(1, -1, 1, 1).to(self.device) + self.vae.config.batch_norm_eps
        )

        # Pipeline instance carrying only tokenizer and text encoder; used for prompts.
        self.text_encoding_pipeline = Flux2KleinPipeline.from_pretrained(
            base_dir,
            vae=None,
            transformer=None,
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
            scheduler=None,
        )
        self.one_step_sigma = 1.0

        # "fp16" -> float16, "no" -> float32, anything else -> bfloat16.
        precision = self.args.mixed_precision
        if precision == "fp16":
            selected_dtype = torch.float16
        elif precision == "no":
            selected_dtype = torch.float32
        else:
            selected_dtype = torch.bfloat16
        self.weight_dtype = selected_dtype
        self.vae.to(dtype=self.weight_dtype)
        self.transformer.to(dtype=self.weight_dtype)
        self.text_encoder.to(dtype=self.weight_dtype)

        # Filled by `_forward_hook` on every transformer forward pass.
        self.bila_feat = None
        # Hook the last single-stream block when there is one, otherwise the
        # last double-stream block.
        if hasattr(self.transformer, "single_transformer_blocks") and len(self.transformer.single_transformer_blocks) > 0:
            hooked_block = self.transformer.single_transformer_blocks[-1]
        else:
            hooked_block = self.transformer.transformer_blocks[-1]
        hooked_block.register_forward_hook(self._forward_hook)

        self.bila_grid = Bilateral_Grid_Joint_Flux(
            grid_res=self.config["bila_grid_res"],
            grid_bins=self.config["bila_grid_bins"],
        ).to(self.device)

        checkpoint = _checkpoint_state(paths["checkpoint"], model_cfg["expected_checkpoint_keys"])
        adapter_name = checkpoint.get("active_adapter_name", "distill")
        set_peft_model_state_dict(self.transformer, checkpoint["transformer_lora"], adapter_name=adapter_name)
        if hasattr(self.transformer, "set_adapter"):
            self.transformer.set_adapter(adapter_name)
        self.bila_grid.load_state_dict(checkpoint["bila"])

        # Inference only: eval mode, no gradients on any parameter.
        self.eval()
        self.requires_grad_(False)

    def _forward_hook(self, module, inputs, output):
        """Forward hook: keep the block output (element 1 when it is a tuple)."""
        del module, inputs
        if isinstance(output, tuple):
            captured = output[1]
        else:
            captured = output
        self.bila_feat = [captured]

    def _encode_prompt(self, prompt_batch):
        """Return detached ``(prompt_embeds, text_ids)`` from the text pipeline."""
        prompt_embeds, text_ids = self.text_encoding_pipeline.encode_prompt(
            prompt=prompt_batch,
            max_sequence_length=self.args.max_sequence_length,
        )
        return prompt_embeds.detach(), text_ids.detach()

    def _prepare_latent_ids_and_cond_ids(self, model_input, cond_model_input):
        """Build position ids for the noise latents and the condition latents.

        Returns:
            ``(model_input_ids, cond_model_input_ids)``; the condition ids are
            reshaped to ``(batch, -1, model_input_ids.shape[-1])``.
        """
        pipeline_cls = self.Flux2KleinPipeline
        model_input_ids = pipeline_cls._prepare_latent_ids(model_input).to(device=model_input.device)

        # `_prepare_image_ids` receives one single-sample tensor per batch item.
        batch_size = cond_model_input.shape[0]
        per_sample = [cond_model_input[i].unsqueeze(0) for i in range(batch_size)]
        cond_model_input_ids = pipeline_cls._prepare_image_ids(per_sample).to(
            device=cond_model_input.device
        )
        cond_model_input_ids = cond_model_input_ids.view(
            batch_size, -1, model_input_ids.shape[-1]
        )
        return model_input_ids, cond_model_input_ids

    def forward(self, input_imgs, input_prompts, input_fullres):
        """Run one transformer step and the bilateral-grid branch.

        Args:
            input_imgs: Image tensor fed to the VAE after ``x * 2 - 1``
                (presumably values in [0, 1] -- confirm against the caller).
            input_prompts: Prompt batch passed to ``_encode_prompt``.
            input_fullres: Tensor handed to the bilateral grid together with
                the hooked transformer features.

        Returns:
            Dict with ``"diff"`` (decoded image clamped to [0, 1]) and
            ``"bila"`` (bilateral-grid output), both detached and on CPU.
        """
        pipeline_cls = self.Flux2KleinPipeline
        input_imgs = input_imgs.to(self.device, dtype=self.weight_dtype)
        input_fullres = input_fullres.to(self.device, dtype=self.weight_dtype)

        # Condition latents: VAE posterior mode, patchified, then normalised
        # with the VAE batch-norm statistics.
        image_latents = self.vae.encode(input_imgs * 2 - 1).latent_dist.mode()
        image_latents_patched = pipeline_cls._patchify_latents(image_latents)
        cond_model_input = (image_latents_patched - self.latents_bn_mean) / self.latents_bn_std

        noisy_latents = torch.randn_like(cond_model_input)
        model_input_ids, cond_model_input_ids = self._prepare_latent_ids_and_cond_ids(
            noisy_latents, cond_model_input
        )
        prompt_embeds, text_ids = self._encode_prompt(input_prompts)

        batch_size = noisy_latents.shape[0]
        timestep_input = (torch.ones((batch_size,), device=self.device) * 999.0) / 1000.0

        packed_noisy = pipeline_cls._pack_latents(noisy_latents)
        packed_cond = pipeline_cls._pack_latents(cond_model_input)
        num_noise_tokens = packed_noisy.shape[1]
        num_noise_ids = model_input_ids.shape[1]

        # Noise tokens come first, condition tokens second, along the sequence axis.
        model_pred = self.transformer(
            hidden_states=torch.cat([packed_noisy, packed_cond], dim=1),
            timestep=timestep_input,
            guidance=None,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=torch.cat([model_input_ids, cond_model_input_ids], dim=1),
            return_dict=False,
        )[0]

        # Keep only the prediction for the noise tokens and restore the latent layout.
        model_pred = model_pred[:, :num_noise_tokens, :]
        model_pred = pipeline_cls._unpack_latents_with_ids(
            model_pred,
            model_input_ids[:, :num_noise_ids, :],
        )

        # One-step estimate, then undo the normalisation and the patchification.
        x0_pred_normalized = noisy_latents - self.one_step_sigma * model_pred
        x0_pred_patched = x0_pred_normalized * self.latents_bn_std + self.latents_bn_mean
        x_denoised = pipeline_cls._unpatchify_latents(x0_pred_patched)

        decoded = self.vae.decode(x_denoised.to(self.weight_dtype), return_dict=False)[0]
        diff_out_img = (decoded / 2 + 0.5).clamp(0, 1)

        # Hooked features: skip the first `num_txt_tokens` positions, then keep
        # as many positions as there are noise tokens.
        num_txt_tokens = prompt_embeds.shape[1]
        hooked_features = [
            feat.float()[:, num_txt_tokens:, ...][:, :num_noise_tokens, :]
            for feat in self.bila_feat
        ]
        bila_out_img, _ = self.bila_grid(
            hooked_features,
            input_fullres,
            latent_h=cond_model_input.shape[-2],
            latent_w=cond_model_input.shape[-1],
        )

        return {"diff": diff_out_img.detach().cpu(), "bila": bila_out_img.detach().cpu()}
def release_cuda():
    """Run the garbage collector, then empty the CUDA cache if CUDA is available."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()