# InstantRetouch BILA backends - from "Deploy InstantRetouch BILA ZeroGPU Space" by iimmortall (commit bc275c2).
from __future__ import annotations
import copy
import gc
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from .bila_layers import Bilateral_Grid_Joint_Flux, Biliteral_Grid_Joint
# Comma-separated module-name list used as LoRA targets; ``_build_lora_config``
# splits it on "," and strips each entry. The first group is named by leaf
# module name; the second adds ``attn.to_out`` of single_transformer_blocks 0-19.
LORA_TARGET_MODULES = ",".join(
    [
        "to_q",
        "to_k",
        "to_v",
        "to_out.0",
        "add_q_proj",
        "add_k_proj",
        "add_v_proj",
        "to_add_out",
        "linear_in",
        "linear_out",
        "to_qkv_mlp_proj",
    ]
    + [f"single_transformer_blocks.{block_idx}.attn.to_out" for block_idx in range(20)]
)
def _device() -> torch.device:
    """Return the CUDA device; this module cannot run without a GPU.

    Raises:
        RuntimeError: when ``torch.cuda.is_available()`` reports no GPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    raise RuntimeError("This demo requires a CUDA GPU Space.")
def _checkpoint_state(path: Path, required: List[str]) -> Dict:
    """Load a checkpoint file and return its top-level ``state_dict`` entry.

    Args:
        path: file readable by ``torch.load``; loaded onto the CPU.
        required: dotted names that must exist. Only names starting with
            ``"state_dict."`` are checked: the remainder after that prefix is
            looked up directly as a key of the ``state_dict`` mapping. Names
            without the prefix are ignored.

    Returns:
        The ``state_dict`` mapping stored in the checkpoint.

    Raises:
        ValueError: if the file does not contain a dict with a ``state_dict``
            key, if that entry is not a dict, or if any required key is absent.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state = torch.load(path, map_location="cpu")
    # Guard the type first: ``in`` on a tensor raises RuntimeError rather than
    # answering the membership question.
    if not isinstance(state, dict) or "state_dict" not in state:
        raise ValueError(f"{path} is missing top-level state_dict")
    state_dict = state["state_dict"]
    if not isinstance(state_dict, dict):
        raise ValueError(
            f"{path} has a non-mapping state_dict ({type(state_dict).__name__})"
        )

    prefix = "state_dict."
    missing = []
    for key in required:
        if not key.startswith(prefix):
            continue
        inner_key = key[len(prefix):]
        if inner_key not in state_dict:
            missing.append(inner_key)
    if missing:
        raise ValueError(f"{path} is missing checkpoint entries: {missing}")
    return state_dict
class Ip2pBilaBackend(nn.Module):
    """InstructPix2Pix-style one-step editor with a bilateral-grid output head.

    Construction loads the CLIP tokenizer/text encoder, VAE and UNet from
    ``paths["base"]``, a ``DDPMScheduler`` from ``ip2p_scheduler.json`` next to
    this file, and the ``unet`` / ``bila`` weights from ``paths["checkpoint"]``.
    A forward hook on ``unet.up_blocks[3]`` captures the feature map that the
    bilateral grid consumes. A CUDA device is required (``_device`` raises
    otherwise).
    """

    def __init__(self, model_cfg: Dict, paths: Dict[str, Path]):
        """Load every sub-model and the fine-tuned weights onto the GPU.

        Args:
            model_cfg: mapping with ``"config"`` (must provide
                ``bila_grid_res`` and ``bila_grid_bins``) and
                ``"expected_checkpoint_keys"`` (forwarded to
                ``_checkpoint_state``).
            paths: mapping with ``"base"`` (location passed to the
                ``from_pretrained`` calls) and ``"checkpoint"`` (weights file).
        """
        super().__init__()
        self.model_cfg = model_cfg
        self.paths = paths
        self.config = model_cfg["config"]
        self.device = _device()
        self.weight_dtype = torch.float32
        # Written by ``_forward_hook`` during the UNet call, consumed (and
        # cleared) by ``_take_bila_feat``.
        self.bila_feat = None

        from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
        from transformers import CLIPTextModel, CLIPTokenizer

        base = paths["base"]
        self.tokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(base, subfolder="text_encoder").to(self.device)
        self.vae = AutoencoderKL.from_pretrained(base, subfolder="vae").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(base, subfolder="unet").to(self.device)

        scheduler_path = Path(__file__).with_name("ip2p_scheduler.json")
        with scheduler_path.open("r", encoding="utf-8") as handle:
            self.noise_scheduler = DDPMScheduler.from_config(json.load(handle))
        self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device)
        self.timesteps = torch.tensor([999], device=self.device).long()

        self.unet.up_blocks[3].register_forward_hook(self._forward_hook)
        self.bila_grid = Biliteral_Grid_Joint(
            grid_res=self.config["bila_grid_res"],
            grid_bins=self.config["bila_grid_bins"],
        ).to(self.device)

        state_dict = _checkpoint_state(paths["checkpoint"], model_cfg["expected_checkpoint_keys"])
        self.unet.load_state_dict(state_dict["unet"])
        self.bila_grid.load_state_dict(state_dict["bila"])
        self.eval()
        self.requires_grad_(False)

    def _forward_hook(self, module, inputs, output):
        """Forward hook on ``unet.up_blocks[3]``: store its output as a one-item list."""
        del module, inputs
        self.bila_feat = [output]

    def _take_bila_feat(self):
        """Return the features captured by ``_forward_hook`` and clear the attribute.

        Clearing drops the module's reference to the captured tensor so it does
        not stay alive between calls.

        Raises:
            RuntimeError: if the hook has not fired since the last call.
        """
        feats = self.bila_feat
        self.bila_feat = None
        if feats is None:
            raise RuntimeError("bilateral-grid feature hook did not fire during the forward pass")
        return feats

    @torch.no_grad()
    def _encode_prompt(self, prompt_batch):
        """Tokenize to a fixed length of 77 (truncating) and return CLIP ``last_hidden_state``."""
        tokens = self.tokenizer(
            prompt_batch,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
            return_overflowing_tokens=False,
        ).input_ids.to(self.device)
        return self.text_encoder(tokens).last_hidden_state.detach()

    @torch.inference_mode()
    def forward(self, input_imgs, input_prompts, input_fullres):
        """Run the single-step edit and the bilateral-grid refinement.

        Args:
            input_imgs: image batch fed to the VAE after ``x -> 2x - 1``
                (presumably values in ``[0, 1]`` - TODO confirm with caller).
            input_prompts: prompt(s) accepted by the CLIP tokenizer.
            input_fullres: image tensor passed to the bilateral grid.

        Returns:
            dict with ``"diff"`` (decoded UNet prediction clamped to
            ``[0, 1]``) and ``"bila"`` (bilateral-grid output), both detached
            CPU tensors.

        Raises:
            RuntimeError: if the UNet feature hook did not fire.
        """
        input_imgs = input_imgs.to(self.device, dtype=self.weight_dtype)
        input_fullres = input_fullres.to(self.device, dtype=self.weight_dtype)
        vae_input = input_imgs * 2 - 1
        image_latents = self.vae.encode(vae_input).latent_dist.mode()
        noisy_latents = torch.randn(
            image_latents.shape,
            device=image_latents.device,
            dtype=image_latents.dtype,
        )
        noisy_latents = noisy_latents * self.noise_scheduler.init_noise_sigma
        encoder_hidden_states = self._encode_prompt(input_prompts)
        timesteps = torch.ones((image_latents.shape[0],), device=self.device).long() * 999
        # Noise and conditioning latents are stacked along the channel axis.
        concatenated_noisy_latents = torch.cat([noisy_latents, image_latents], dim=1)

        # Drop any feature from a previous call so a missing hook is detected
        # instead of silently reusing stale data.
        self.bila_feat = None
        model_pred = self.unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample

        # One-step x0 estimate, assuming the UNet predicts noise (epsilon):
        # x0 = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t).
        alpha_prod = self.noise_scheduler.alphas_cumprod.to(image_latents.device, dtype=model_pred.dtype)
        beta_prod = 1 - alpha_prod
        alpha_prod_t = alpha_prod[timesteps].view(-1, 1, 1, 1)
        beta_prod_t = beta_prod[timesteps].view(-1, 1, 1, 1)
        x_denoised = (noisy_latents - beta_prod_t.sqrt() * model_pred) / alpha_prod_t.sqrt()
        # 0.18215 is presumably the SD VAE latent scaling factor.
        pred_images = self.vae.decode((1 / 0.18215) * x_denoised.to(self.weight_dtype), return_dict=False)[0]
        diff_out_img = (pred_images / 2 + 0.5).clamp(0, 1)

        bila_feat = [feat.float() for feat in self._take_bila_feat()]
        bila_out_img, _ = self.bila_grid(bila_feat, input_fullres)
        return {"diff": diff_out_img.detach().cpu(), "bila": bila_out_img.detach().cpu()}
def _patch_attention_for_gqa():
    """Make ``F.scaled_dot_product_attention`` tolerate the ``enable_gqa`` kwarg.

    The installed wrapper removes ``enable_gqa`` before delegating to the
    original function. When the flag is truthy, key and value are first
    expanded along the head dimension (``dim=-3``) with ``repeat_interleave``,
    mirroring torch's reference GQA behaviour, so grouped-query callers still
    get matching head counts instead of a shape error. The patch is installed
    at most once; repeated calls are no-ops.
    """
    original = F.scaled_dot_product_attention
    if getattr(original, "_bila_gqa_patched", False):
        return

    def expand_kv_heads(args, kwargs):
        """Repeat key/value heads (positional or keyword) to match the query head count."""
        args = list(args)
        query = args[0] if args else kwargs["query"]
        num_q_heads = query.shape[-3]
        for position, name in ((1, "key"), (2, "value")):
            positional = len(args) > position
            tensor = args[position] if positional else kwargs[name]
            repeats = num_q_heads // tensor.shape[-3]
            if repeats > 1:
                tensor = tensor.repeat_interleave(repeats, dim=-3)
            if positional:
                args[position] = tensor
            else:
                kwargs[name] = tensor
        return tuple(args), kwargs

    def patched_scaled_dot_product_attention(*args, **kwargs):
        enable_gqa = kwargs.pop("enable_gqa", False)
        if enable_gqa:
            args, kwargs = expand_kv_heads(args, kwargs)
        return original(*args, **kwargs)

    patched_scaled_dot_product_attention._bila_gqa_patched = True
    F.scaled_dot_product_attention = patched_scaled_dot_product_attention
def _load_task_lora_state_dict(task_lora_path):
    """Load a task LoRA file and return its transformer weights in PEFT key format.

    Only entries whose key begins with ``"transformer."`` are kept. That
    leading prefix (and only the leading occurrence) is removed before the
    keys are converted with ``convert_unet_state_dict_to_peft``.

    Args:
        task_lora_path: path accepted by ``Flux2KleinPipeline.lora_state_dict``.

    Returns:
        The converted transformer LoRA state dict.
    """
    from diffusers import Flux2KleinPipeline
    from diffusers.utils import convert_unet_state_dict_to_peft

    prefix = "transformer."
    lora_state_dict = Flux2KleinPipeline.lora_state_dict(task_lora_path)
    # Slice instead of str.replace so a "transformer." substring deeper in a
    # key is not stripped as well.
    transformer_lora_sd = {
        key[len(prefix):]: value
        for key, value in lora_state_dict.items()
        if key.startswith(prefix)
    }
    return convert_unet_state_dict_to_peft(transformer_lora_sd)
def _build_lora_config(rank, alpha, dropout=0.0):
    """Create a peft ``LoraConfig`` whose targets come from ``LORA_TARGET_MODULES``.

    Args:
        rank: value for ``r``.
        alpha: value for ``lora_alpha``.
        dropout: value for ``lora_dropout`` (default 0.0).

    Returns:
        A ``peft.LoraConfig`` using ``init_lora_weights="gaussian"``.
    """
    from peft import LoraConfig

    # Split the comma-separated constant and trim whitespace around each name.
    target_names = list(map(str.strip, LORA_TARGET_MODULES.split(",")))
    config_kwargs = dict(
        r=rank,
        lora_alpha=alpha,
        lora_dropout=dropout,
        init_lora_weights="gaussian",
        target_modules=target_names,
    )
    return LoraConfig(**config_kwargs)
def _load_flux_transformer(args):
    """Load the Flux2 transformer, bake in the task LoRA, then attach a "distill" adapter.

    Steps:
      1. Load ``Flux2Transformer2DModel`` from ``args.pipeline_path``
         (``subfolder="transformer"``).
      2. If ``args.task_lora_path`` is truthy: add a "task" adapter, load its
         weights, merge it into every ``LoraLayer``, then delete the adapter.
      3. Add a new "distill" adapter built from ``args.distill_lora_*``.

    Args:
        args: namespace providing ``pipeline_path``, ``task_lora_path``,
            ``task_lora_rank``, ``task_lora_alpha``, ``distill_lora_rank``,
            ``distill_lora_alpha`` and ``distill_lora_dropout``.

    Returns:
        The transformer with the "distill" adapter attached.
    """
    from diffusers import Flux2Transformer2DModel
    from peft import set_peft_model_state_dict
    from peft.tuners.lora.layer import LoraLayer

    transformer = Flux2Transformer2DModel.from_pretrained(args.pipeline_path, subfolder="transformer")

    if args.task_lora_path:
        task_weights = _load_task_lora_state_dict(args.task_lora_path)
        task_config = _build_lora_config(args.task_lora_rank, args.task_lora_alpha)
        transformer.add_adapter(task_config, adapter_name="task")
        set_peft_model_state_dict(transformer, task_weights, adapter_name="task")
        # Fold the task adapter into the base weights, then discard it.
        lora_layers = (layer for layer in transformer.modules() if isinstance(layer, LoraLayer))
        for layer in lora_layers:
            layer.merge(adapter_names=["task"])
        transformer.delete_adapters("task")

    transformer.add_adapter(
        _build_lora_config(
            args.distill_lora_rank,
            args.distill_lora_alpha,
            args.distill_lora_dropout,
        ),
        adapter_name="distill",
    )
    return transformer
class FluxBilaBackend(nn.Module):
    """FLUX.2 Klein one-step editor with a bilateral-grid output head.

    Construction loads the Qwen tokenizer/text encoder, Flux2 VAE and
    transformer from ``paths["base"]``. The task LoRA from ``paths["task_lora"]``
    is merged into the transformer and a "distill" adapter is added. The
    adapter and bilateral-grid weights are then loaded from
    ``paths["checkpoint"]``. A forward hook on the last transformer block
    captures the token features the bilateral grid consumes. A CUDA device is
    required (``_device`` raises otherwise).
    """

    def __init__(self, model_cfg: Dict, paths: Dict[str, Path]):
        """Load every sub-model and the fine-tuned weights onto the GPU.

        Args:
            model_cfg: mapping with ``"config"`` (grid, precision, sequence
                length and LoRA hyper-parameters read below) and
                ``"expected_checkpoint_keys"`` (forwarded to
                ``_checkpoint_state``).
            paths: mapping with ``"base"``, ``"task_lora"`` and
                ``"checkpoint"``.
        """
        super().__init__()
        # Install the ``enable_gqa`` compatibility wrapper on SDPA.
        _patch_attention_for_gqa()
        self.model_cfg = model_cfg
        self.paths = paths
        self.config = model_cfg["config"]
        self.device = _device()
        self.args = SimpleNamespace(
            pipeline_path=str(paths["base"]),
            task_lora_path=str(paths["task_lora"]),
            use_t2i=False,
            cfg=False,
            bila_use_flux_rgb=False,
            fix_guide_map=False,
            bila_grid_res=self.config["bila_grid_res"],
            bila_grid_bins=self.config["bila_grid_bins"],
            mixed_precision=self.config["mixed_precision"],
            max_sequence_length=self.config["max_sequence_length"],
            distill_strategy=self.config["distill_strategy"],
            distill_lora_rank=self.config["distill_lora_rank"],
            distill_lora_alpha=self.config["distill_lora_alpha"],
            distill_lora_dropout=self.config["distill_lora_dropout"],
            task_lora_rank=self.config["task_lora_rank"],
            task_lora_alpha=self.config["task_lora_alpha"],
        )

        from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2KleinPipeline
        from peft import set_peft_model_state_dict
        from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM

        self.Flux2KleinPipeline = Flux2KleinPipeline
        self.tokenizer = Qwen2TokenizerFast.from_pretrained(paths["base"], subfolder="tokenizer")
        self.text_encoder = Qwen3ForCausalLM.from_pretrained(paths["base"], subfolder="text_encoder").to(self.device)
        self.vae = AutoencoderKLFlux2.from_pretrained(paths["base"], subfolder="vae").to(self.device)
        self.transformer = _load_flux_transformer(self.args).to(self.device)
        self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(paths["base"], subfolder="scheduler")
        self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
        # Statistics of the VAE's ``bn`` layer, used in ``forward`` to
        # normalise patchified latents. They are taken before the dtype cast
        # below, so they keep the VAE's load-time dtype.
        self.latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(self.device)
        self.latents_bn_std = torch.sqrt(
            self.vae.bn.running_var.view(1, -1, 1, 1).to(self.device) + self.vae.config.batch_norm_eps
        )
        self.text_encoding_pipeline = Flux2KleinPipeline.from_pretrained(
            paths["base"],
            vae=None,
            transformer=None,
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
            scheduler=None,
        )
        self.one_step_sigma = 1.0

        self.weight_dtype = torch.bfloat16
        if self.args.mixed_precision == "fp16":
            self.weight_dtype = torch.float16
        elif self.args.mixed_precision == "no":
            self.weight_dtype = torch.float32
        self.vae.to(dtype=self.weight_dtype)
        self.transformer.to(dtype=self.weight_dtype)
        self.text_encoder.to(dtype=self.weight_dtype)

        # Written by ``_forward_hook`` during the transformer call, consumed
        # (and cleared) by ``_take_bila_feat``.
        self.bila_feat = None
        if hasattr(self.transformer, "single_transformer_blocks") and len(self.transformer.single_transformer_blocks) > 0:
            self.transformer.single_transformer_blocks[-1].register_forward_hook(self._forward_hook)
        else:
            self.transformer.transformer_blocks[-1].register_forward_hook(self._forward_hook)
        self.bila_grid = Bilateral_Grid_Joint_Flux(
            grid_res=self.config["bila_grid_res"],
            grid_bins=self.config["bila_grid_bins"],
        ).to(self.device)

        state_dict = _checkpoint_state(paths["checkpoint"], model_cfg["expected_checkpoint_keys"])
        adapter_name = state_dict.get("active_adapter_name", "distill")
        set_peft_model_state_dict(self.transformer, state_dict["transformer_lora"], adapter_name=adapter_name)
        if hasattr(self.transformer, "set_adapter"):
            self.transformer.set_adapter(adapter_name)
        self.bila_grid.load_state_dict(state_dict["bila"])
        self.eval()
        self.requires_grad_(False)

    def _forward_hook(self, module, inputs, output):
        """Forward hook on the last block: keep ``output[1]`` for tuples, else ``output``."""
        del module, inputs
        self.bila_feat = [output[1] if isinstance(output, tuple) else output]

    def _take_bila_feat(self):
        """Return the features captured by ``_forward_hook`` and clear the attribute.

        Clearing drops the module's reference to the captured tensor so it does
        not stay alive between calls.

        Raises:
            RuntimeError: if the hook has not fired since the last call.
        """
        feats = self.bila_feat
        self.bila_feat = None
        if feats is None:
            raise RuntimeError("bilateral-grid feature hook did not fire during the forward pass")
        return feats

    @torch.no_grad()
    def _encode_prompt(self, prompt_batch):
        """Return detached ``(prompt_embeds, text_ids)`` from the pipeline's ``encode_prompt``."""
        prompt_embeds, text_ids = self.text_encoding_pipeline.encode_prompt(
            prompt=prompt_batch,
            max_sequence_length=self.args.max_sequence_length,
        )
        return prompt_embeds.detach(), text_ids.detach()

    def _prepare_latent_ids_and_cond_ids(self, model_input, cond_model_input):
        """Build position ids for the noisy latents and, per sample, for the condition latents."""
        model_input_ids = self.Flux2KleinPipeline._prepare_latent_ids(model_input).to(device=model_input.device)
        cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])]
        cond_model_input_ids = self.Flux2KleinPipeline._prepare_image_ids(cond_model_input_list).to(
            device=cond_model_input.device
        )
        cond_model_input_ids = cond_model_input_ids.view(
            cond_model_input.shape[0], -1, model_input_ids.shape[-1]
        )
        return model_input_ids, cond_model_input_ids

    @torch.inference_mode()
    def forward(self, input_imgs, input_prompts, input_fullres):
        """Run the single-step edit and the bilateral-grid refinement.

        Args:
            input_imgs: image batch fed to the VAE after ``x -> 2x - 1``
                (presumably values in ``[0, 1]`` - TODO confirm with caller).
            input_prompts: prompt(s) accepted by the text-encoding pipeline.
            input_fullres: image tensor passed to the bilateral grid.

        Returns:
            dict with ``"diff"`` (decoded transformer prediction clamped to
            ``[0, 1]``) and ``"bila"`` (bilateral-grid output), both detached
            CPU tensors.

        Raises:
            RuntimeError: if the transformer feature hook did not fire.
        """
        input_imgs = input_imgs.to(self.device, dtype=self.weight_dtype)
        # NOTE(review): bila_grid is never cast to weight_dtype, yet this
        # tensor is - confirm Bilateral_Grid_Joint_Flux handles mixed dtypes.
        input_fullres = input_fullres.to(self.device, dtype=self.weight_dtype)
        vae_input = input_imgs * 2 - 1
        image_latents = self.vae.encode(vae_input).latent_dist.mode()
        image_latents_patched = self.Flux2KleinPipeline._patchify_latents(image_latents)
        cond_model_input = (image_latents_patched - self.latents_bn_mean) / self.latents_bn_std
        noisy_latents = torch.randn_like(cond_model_input)
        model_input_ids, cond_model_input_ids = self._prepare_latent_ids_and_cond_ids(
            noisy_latents, cond_model_input
        )
        prompt_embeds, text_ids = self._encode_prompt(input_prompts)
        bsz = noisy_latents.shape[0]
        timestep_input = (torch.ones((bsz,), device=self.device) * 999.0) / 1000.0

        packed_noisy = self.Flux2KleinPipeline._pack_latents(noisy_latents)
        packed_cond = self.Flux2KleinPipeline._pack_latents(cond_model_input)
        orig_shape = packed_noisy.shape
        orig_ids_shape = model_input_ids.shape
        # Noisy tokens first, condition tokens appended along the sequence.
        # The bn statistics may keep a wider dtype than the cast transformer,
        # so align the input with it (a no-op when mixed_precision == "no").
        packed_input = torch.cat([packed_noisy, packed_cond], dim=1).to(self.weight_dtype)
        ids_input = torch.cat([model_input_ids, cond_model_input_ids], dim=1)

        # Drop any feature from a previous call so a missing hook is detected
        # instead of silently reusing stale data.
        self.bila_feat = None
        model_pred = self.transformer(
            hidden_states=packed_input,
            timestep=timestep_input,
            guidance=None,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=ids_input,
            return_dict=False,
        )[0]
        # Keep only the predictions for the noisy-latent tokens.
        model_pred = model_pred[:, : orig_shape[1], :]
        model_pred = self.Flux2KleinPipeline._unpack_latents_with_ids(
            model_pred,
            model_input_ids[:, : orig_ids_shape[1], :],
        )

        # One-step estimate x0 = x_t - sigma * pred, assuming the transformer
        # predicts a flow-matching velocity; then undo the bn normalisation.
        x0_pred_normalized = noisy_latents - self.one_step_sigma * model_pred
        x0_pred_patched = x0_pred_normalized * self.latents_bn_std + self.latents_bn_mean
        x_denoised = self.Flux2KleinPipeline._unpatchify_latents(x0_pred_patched)
        pred_images = self.vae.decode(x_denoised.to(self.weight_dtype), return_dict=False)[0]
        diff_out_img = (pred_images / 2 + 0.5).clamp(0, 1)

        # Hooked features: drop the text tokens, then keep the noisy-latent tokens.
        num_txt_tokens = prompt_embeds.shape[1]
        bila_feat = [feat.float() for feat in self._take_bila_feat()]
        bila_feat = [feat[:, num_txt_tokens:, ...] for feat in bila_feat]
        bila_feat = [feat[:, : orig_shape[1], :] for feat in bila_feat]
        bila_out_img, _ = self.bila_grid(
            bila_feat,
            input_fullres,
            latent_h=cond_model_input.shape[-2],
            latent_w=cond_model_input.shape[-1],
        )
        return {"diff": diff_out_img.detach().cpu(), "bila": bila_out_img.detach().cpu()}
def release_cuda():
    """Run Python garbage collection, then empty the CUDA allocator cache when a GPU exists."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()