| """OmniTry (NeurIPS 2025) inference pipeline. |
| |
| Uses OmniTry's custom FluxTransformer2DModel + FluxFillPipeline with |
| dual PEFT LoRA adapters (vtryon_lora + garment_lora) loaded via |
| safetensors. Standard diffusers load_lora_weights() does NOT work |
| because OmniTry uses custom key prefixes and split-batch LoRA routing. |
| """ |
|
|
| import logging |
| import math |
| import time |
| from typing import Optional |
|
|
| import peft |
| import torch |
| import torchvision.transforms as T |
| from huggingface_hub import hf_hub_download |
| from peft import LoraConfig |
| from PIL import Image |
| from safetensors import safe_open |
|
|
| from omnitry.models.transformer_flux import FluxTransformer2DModel |
| from omnitry.pipelines.pipeline_flux_fill import FluxFillPipeline |
|
|
| logger = logging.getLogger("omnitry") |
|
|
# Module-level singletons: the heavy pipeline should be loaded once per process.
_pipe = None
_transformer = None



# Maps user-facing garment category names to the text prompt OmniTry expects.
# Several aliases map to the same canonical prompt.
CATEGORY_PROMPTS = {
    "upper_body": "replacing the top cloth",
    "top clothes": "replacing the top cloth",
    "top": "replacing the top cloth",
    "outerwear": "replacing the top cloth",
    "lower_body": "replacing the bottom cloth",
    "bottom clothes": "replacing the bottom cloth",
    "bottom": "replacing the bottom cloth",
    "dresses": "replacing the dress",
    "dress": "replacing the dress",
}


# Submodule name suffixes inside FluxTransformer2DModel that receive LoRA
# adapters: attention projections, feed-forward nets, norm linears, embedders.
LORA_TARGET_MODULES = [
    "x_embedder",
    "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
    "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
    "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
    "norm1_context.linear", "norm1.linear", "norm.linear", "proj_mlp", "proj_out",
]


# LoRA hyperparameters matching the released OmniTry checkpoints
# (effective scaling = alpha / rank = 1.0).
LORA_RANK = 16
LORA_ALPHA = 16
|
|
|
|
| def _create_hacked_forward(module): |
| """Hack LoRA forward to route different batch elements to different adapters. |
| |
| Batch element 0 (person) -> vtryon_lora adapter. |
| Batch element 1 (garment) -> garment_lora adapter. |
| """ |
| def lora_forward(self, active_adapter, x, *args, **kwargs): |
| result = self.base_layer(x, *args, **kwargs) |
| if active_adapter is not None: |
| lora_A = self.lora_A[active_adapter] |
| lora_B = self.lora_B[active_adapter] |
| dropout = self.lora_dropout[active_adapter] |
| scaling = self.scaling[active_adapter] |
| x = x.to(lora_A.weight.dtype) |
| result = result + lora_B(lora_A(dropout(x))) * scaling |
| return result |
|
|
| def hacked_lora_forward(self, x, *args, **kwargs): |
| return torch.cat(( |
| lora_forward(self, "vtryon_lora", x[:1], *args, **kwargs), |
| lora_forward(self, "garment_lora", x[1:], *args, **kwargs), |
| ), dim=0) |
|
|
| return hacked_lora_forward.__get__(module, type(module)) |
|
|
|
|
def load_pipeline(
    hf_token: Optional[str] = None,
    weight_name: str = "omnitry_v1_clothes.safetensors",
):
    """Load the OmniTry pipeline: custom FLUX transformer + dual LoRA adapters.

    The loaded pipeline is cached in module globals; repeated calls with the
    same ``weight_name`` return the cached instance instead of re-downloading
    and re-initializing the multi-GB model.

    Args:
        hf_token: Optional Hugging Face token for gated model access.
        weight_name: Filename of the OmniTry LoRA checkpoint in the
            ``Kunbyte/OmniTry`` repo.

    Returns:
        A ready-to-use FluxFillPipeline on ``cuda:0``.
    """
    global _pipe, _transformer, _loaded_weight_name

    # Reuse the cached pipeline when the same checkpoint is already loaded.
    # (globals().get avoids a NameError before the first successful load.)
    if _pipe is not None and globals().get("_loaded_weight_name") == weight_name:
        logger.info("Reusing cached pipeline (weights=%s)", weight_name)
        return _pipe

    device = torch.device("cuda:0")
    dtype = torch.bfloat16

    logger.info("Loading custom FluxTransformer2DModel...")
    t0 = time.time()

    # OmniTry's custom transformer class reads the standard FLUX.1-Fill-dev
    # weights; inference only, so gradients are disabled up front.
    _transformer = (
        FluxTransformer2DModel.from_pretrained(
            "black-forest-labs/FLUX.1-Fill-dev",
            subfolder="transformer",
            token=hf_token,
        )
        .requires_grad_(False)
        .to(device, dtype=dtype)
    )

    _pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev",
        transformer=_transformer,
        torch_dtype=dtype,
        token=hf_token,
    ).to(device)

    logger.info("Base pipeline loaded (%.1fs). Injecting LoRA adapters...", time.time() - t0)

    # Two independent adapters: one for the person branch, one for the garment
    # branch. The hacked forward routes batch elements between them.
    lora_config = LoraConfig(
        r=LORA_RANK,
        lora_alpha=LORA_ALPHA,
        init_lora_weights="gaussian",
        target_modules=LORA_TARGET_MODULES,
    )
    _transformer.add_adapter(lora_config, adapter_name="vtryon_lora")
    _transformer.add_adapter(lora_config, adapter_name="garment_lora")

    # Load OmniTry's checkpoint directly via safetensors: its custom key
    # prefixes are not compatible with diffusers' load_lora_weights().
    lora_path = hf_hub_download(
        repo_id="Kunbyte/OmniTry",
        filename=weight_name,
        token=hf_token,
    )
    with safe_open(lora_path, framework="pt") as f:
        lora_weights = {k: f.get_tensor(k) for k in f.keys()}
    _transformer.load_state_dict(lora_weights, strict=False)

    logger.info("LoRA weights loaded (%.1fs). Hacking forwards...", time.time() - t0)

    # Replace forward() on every PEFT LoRA Linear so each batch element goes
    # through its designated adapter.
    for _name, m in _transformer.named_modules():
        if isinstance(m, peft.tuners.lora.layer.Linear):
            m.forward = _create_hacked_forward(m)

    _loaded_weight_name = weight_name
    logger.info("Pipeline ready on GPU (%.1fs total)", time.time() - t0)
    return _pipe
|
|
|
|
def run_tryon(
    pipeline,
    human_img: Image.Image,
    garment_img: Image.Image,
    category: str = "upper_body",
    steps: int = 20,
    seed: int = 42,
    guidance_scale: float = 30.0,
) -> Image.Image:
    """Run OmniTry virtual try-on."""
    start = time.time()
    device = torch.device("cuda:0")
    dtype = torch.bfloat16

    # Cap the working resolution at ~1MP (never upscale) and snap both sides
    # down to multiples of 16.
    max_area = 1024 * 1024
    src_w, src_h = human_img.size
    scale = min(1, math.sqrt(max_area / (src_w * src_h)))
    target_w = int(src_w * scale) // 16 * 16
    target_h = int(src_h * scale) // 16 * 16

    to_person_tensor = T.Compose([T.Resize((target_h, target_w)), T.ToTensor()])
    person_tensor = to_person_tensor(human_img.convert("RGB"))

    # Fit the garment inside the target canvas (aspect ratio preserved) and
    # center it on a white background matching the person tensor's size.
    raw_w, raw_h = garment_img.size
    fit = min(target_w / raw_w, target_h / raw_h)
    to_garment_tensor = T.Compose([
        T.Resize((int(raw_h * fit), int(raw_w * fit))),
        T.ToTensor(),
    ])
    garment_tensor = to_garment_tensor(garment_img.convert("RGB"))
    fit_h, fit_w = garment_tensor.shape[1], garment_tensor.shape[2]
    off_x = (target_w - fit_w) // 2
    off_y = (target_h - fit_h) // 2
    garment_canvas = torch.ones_like(person_tensor)
    garment_canvas[:, off_y:off_y + fit_h, off_x:off_x + fit_w] = garment_tensor

    # One prompt per batch element; unknown categories fall back to tops.
    prompt_text = CATEGORY_PROMPTS.get(category, "replacing the top cloth")
    prompts = [prompt_text] * 2
    img_cond = torch.stack([person_tensor, garment_canvas]).to(dtype=dtype, device=device)
    mask = torch.zeros_like(img_cond).to(img_cond)

    logger.info(
        "Running OmniTry: category=%s prompt=%r size=%dx%d steps=%d",
        category, prompt_text, target_w, target_h, steps,
    )

    with torch.no_grad():
        output = pipeline(
            prompt=prompts,
            height=target_h,
            width=target_w,
            img_cond=img_cond,
            mask=mask,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=torch.Generator(device).manual_seed(seed),
        )
        result = output.images[0]

    logger.info("Inference done (%.1fs)", time.time() - start)
    return result
|
|