# dailydrop-omnitry / inference.py
# Author: Iliya Belous
# fix: use OmniTry custom pipeline with dual PEFT LoRA loading (commit f7a3af8)
"""OmniTry (NeurIPS 2025) inference pipeline.
Uses OmniTry's custom FluxTransformer2DModel + FluxFillPipeline with
dual PEFT LoRA adapters (vtryon_lora + garment_lora) loaded via
safetensors. Standard diffusers load_lora_weights() does NOT work
because OmniTry uses custom key prefixes and split-batch LoRA routing.
"""
import logging
import math
import time
from typing import Optional
import peft
import torch
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from peft import LoraConfig
from PIL import Image
from safetensors import safe_open
from omnitry.models.transformer_flux import FluxTransformer2DModel
from omnitry.pipelines.pipeline_flux_fill import FluxFillPipeline
logger = logging.getLogger("omnitry")

# Module-level handles assigned by load_pipeline(); hold the loaded pipeline
# and transformer for the lifetime of the process.
_pipe = None
_transformer = None

# Category -> prompt mapping (from OmniTry config).
# Several aliases map to the same prompt so callers can pass whichever
# category label their upstream classifier produces.
CATEGORY_PROMPTS = {
"upper_body": "replacing the top cloth",
"top clothes": "replacing the top cloth",
"top": "replacing the top cloth",
"outerwear": "replacing the top cloth",
"lower_body": "replacing the bottom cloth",
"bottom clothes": "replacing the bottom cloth",
"bottom": "replacing the bottom cloth",
"dresses": "replacing the dress",
"dress": "replacing the dress",
}
# Submodule name suffixes inside the FLUX transformer that receive LoRA
# adapters (attention projections, feed-forward layers, norms, embedders).
LORA_TARGET_MODULES = [
"x_embedder",
"attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
"attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
"ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
"norm1_context.linear", "norm1.linear", "norm.linear", "proj_mlp", "proj_out",
]
# Rank/alpha of the adapters; must match the shipped OmniTry checkpoint
# layout or load_state_dict() below will not find matching tensor shapes.
LORA_RANK = 16
LORA_ALPHA = 16
def _create_hacked_forward(module):
"""Hack LoRA forward to route different batch elements to different adapters.
Batch element 0 (person) -> vtryon_lora adapter.
Batch element 1 (garment) -> garment_lora adapter.
"""
def lora_forward(self, active_adapter, x, *args, **kwargs):
result = self.base_layer(x, *args, **kwargs)
if active_adapter is not None:
lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
result = result + lora_B(lora_A(dropout(x))) * scaling
return result
def hacked_lora_forward(self, x, *args, **kwargs):
return torch.cat((
lora_forward(self, "vtryon_lora", x[:1], *args, **kwargs),
lora_forward(self, "garment_lora", x[1:], *args, **kwargs),
), dim=0)
return hacked_lora_forward.__get__(module, type(module))
def load_pipeline(
    hf_token: Optional[str] = None,
    weight_name: str = "omnitry_v1_clothes.safetensors",
    device: str = "cuda:0",
):
    """Load OmniTry pipeline: custom FLUX transformer + dual LoRA adapters.

    Args:
        hf_token: Hugging Face token for the gated FLUX.1-Fill-dev weights.
        weight_name: LoRA checkpoint filename in the Kunbyte/OmniTry repo.
        device: Torch device string to place the pipeline on (default "cuda:0"),
            kept backward-compatible with the previous hard-coded value.

    Returns:
        The ready FluxFillPipeline (also stored in module globals _pipe /
        _transformer).
    """
    global _pipe, _transformer
    target_device = torch.device(device)
    dtype = torch.bfloat16
    logger.info("Loading custom FluxTransformer2DModel...")
    t0 = time.time()

    # Step 1: Load OmniTry's custom transformer (NOT stock diffusers).
    _transformer = (
        FluxTransformer2DModel.from_pretrained(
            "black-forest-labs/FLUX.1-Fill-dev",
            subfolder="transformer",
            token=hf_token,
        )
        .requires_grad_(False)
        .to(target_device, dtype=dtype)
    )

    # Step 2: Build pipeline around the custom transformer.
    _pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev",
        transformer=_transformer,
        torch_dtype=dtype,
        token=hf_token,
    ).to(target_device)
    logger.info("Base pipeline loaded (%.1fs). Injecting LoRA adapters...", time.time() - t0)

    # Step 3: Two freshly-initialized adapters; real weights are loaded below.
    lora_config = LoraConfig(
        r=LORA_RANK,
        lora_alpha=LORA_ALPHA,
        init_lora_weights="gaussian",
        target_modules=LORA_TARGET_MODULES,
    )
    _transformer.add_adapter(lora_config, adapter_name="vtryon_lora")
    _transformer.add_adapter(lora_config, adapter_name="garment_lora")

    # Step 4: Download and load safetensors weights. strict=False is required
    # because the file holds only LoRA tensors (base weights are legitimately
    # "missing"), but *unexpected* keys mean the checkpoint doesn't match the
    # adapter layout — surface that instead of silently dropping them.
    lora_path = hf_hub_download(
        repo_id="Kunbyte/OmniTry",
        filename=weight_name,
        token=hf_token,
    )
    with safe_open(lora_path, framework="pt") as f:
        lora_weights = {k: f.get_tensor(k) for k in f.keys()}
    load_result = _transformer.load_state_dict(lora_weights, strict=False)
    if load_result.unexpected_keys:
        logger.warning(
            "LoRA checkpoint has %d unexpected keys (e.g. %s) — wrong weight file?",
            len(load_result.unexpected_keys),
            load_result.unexpected_keys[0],
        )
    logger.info("LoRA weights loaded (%.1fs). Hacking forwards...", time.time() - t0)

    # Step 5: Hack LoRA forwards so batch element 0 -> vtryon_lora and
    # batch element 1 -> garment_lora (see _create_hacked_forward).
    for _, mod in _transformer.named_modules():
        if isinstance(mod, peft.tuners.lora.layer.Linear):
            mod.forward = _create_hacked_forward(mod)
    logger.info("Pipeline ready on GPU (%.1fs total)", time.time() - t0)
    return _pipe
def run_tryon(
    pipeline,
    human_img: Image.Image,
    garment_img: Image.Image,
    category: str = "upper_body",
    steps: int = 20,
    seed: int = 42,
    guidance_scale: float = 30.0,
) -> Image.Image:
    """Run OmniTry virtual try-on for one person/garment pair.

    Args:
        pipeline: Loaded FluxFillPipeline (see load_pipeline).
        human_img: Person photo (any mode; converted to RGB).
        garment_img: Garment photo, letterboxed onto a white canvas.
        category: Garment category; mapped to a prompt via CATEGORY_PROMPTS
            (now case/whitespace-insensitive; unknown values fall back to the
            top-cloth prompt).
        steps: Number of denoising steps.
        seed: RNG seed for reproducible generation.
        guidance_scale: Classifier-free guidance strength.

    Returns:
        The generated try-on image (first image of the pipeline output).
    """
    t0 = time.time()
    device = torch.device("cuda:0")
    dtype = torch.bfloat16

    # Resize person image: cap area at 1024x1024, never upscale, and snap
    # both sides to multiples of 16 (FLUX requires /16-divisible sizes).
    # Clamp to >= 16 so tiny inputs cannot collapse to a 0-sized canvas.
    max_area = 1024 * 1024
    oW, oH = human_img.size
    ratio = min(1.0, math.sqrt(max_area / (oW * oH)))
    tW = max(16, int(oW * ratio) // 16 * 16)
    tH = max(16, int(oH * ratio) // 16 * 16)
    transform = T.Compose([T.Resize((tH, tW)), T.ToTensor()])
    person_tensor = transform(human_img.convert("RGB"))

    # Resize garment to fit inside (tW, tH) preserving aspect ratio, then
    # center it on a white canvas the same size as the person image.
    gW, gH = garment_img.size
    g_ratio = min(tW / gW, tH / gH)
    g_transform = T.Compose([
        T.Resize((max(1, int(gH * g_ratio)), max(1, int(gW * g_ratio)))),
        T.ToTensor(),
    ])
    garment_padded = torch.ones_like(person_tensor)  # white background
    garment_tensor = g_transform(garment_img.convert("RGB"))
    new_h, new_w = garment_tensor.shape[1], garment_tensor.shape[2]
    min_x = (tW - new_w) // 2
    min_y = (tH - new_h) // 2
    garment_padded[:, min_y:min_y + new_h, min_x:min_x + new_w] = garment_tensor

    # Conditions: batch 0 = person (vtryon_lora), batch 1 = garment
    # (garment_lora) — order must match the split-batch LoRA routing.
    prompt_text = CATEGORY_PROMPTS.get(category.strip().lower(), "replacing the top cloth")
    prompts = [prompt_text] * 2
    img_cond = torch.stack([person_tensor, garment_padded]).to(dtype=dtype, device=device)
    # zeros_like already matches img_cond's dtype/device; the former extra
    # .to(img_cond) was a no-op. All-zero mask matches the OmniTry reference
    # inference setup — NOTE(review): confirm mask semantics in the custom
    # FluxFillPipeline.
    mask = torch.zeros_like(img_cond)
    logger.info(
        "Running OmniTry: category=%s prompt=%r size=%dx%d steps=%d",
        category, prompt_text, tW, tH, steps,
    )
    with torch.no_grad():
        result = pipeline(
            prompt=prompts,
            height=tH,
            width=tW,
            img_cond=img_cond,
            mask=mask,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=torch.Generator(device).manual_seed(seed),
        ).images[0]
    logger.info("Inference done (%.1fs)", time.time() - t0)
    return result