| |
| """ |
| Flux LoRA evaluation script using kohya sd-scripts inference pipeline. |
| Generates 5 images per prompt across multiple epoch checkpoints. |
| """ |
|
|
| import sys |
| import os |
| import math |
| import torch |
| import einops |
| import numpy as np |
| from pathlib import Path |
| from PIL import Image |
| from safetensors.torch import load_file |
| from tqdm import tqdm |
| from typing import Optional |
|
|
| sys.path.insert(0, "/workspace/sd-scripts") |
|
|
| from library import flux_utils, strategy_flux |
| from networks import lora_flux |
|
|
# ---- Configuration: model paths, LoRA checkpoints, prompts, sampling settings ----
|
|
# Root of the project tree; every path below is derived from it.
_PROJECT_DIR = "/workspace/flux-project"

# Base model components (transformer, text encoders, autoencoder).
CKPT_PATH = f"{_PROJECT_DIR}/models/flux/flux1-dev.safetensors"
CLIP_L_PATH = f"{_PROJECT_DIR}/models/flux/text_encoder/model.safetensors"
T5XXL_PATH = f"{_PROJECT_DIR}/models/clip/t5-xxl/t5xxl_fp16.safetensors"
AE_PATH = f"{_PROJECT_DIR}/models/vae/ae.safetensors"

# Images are written to OUTPUT_BASE/<checkpoint label>/<test name>/.
OUTPUT_BASE = f"{_PROJECT_DIR}/eval"

# LoRA checkpoints to compare, keyed by the label used as the output folder.
CHECKPOINTS = {
    "epoch_07": f"{_PROJECT_DIR}/output/identity_lora-000007.safetensors",
    "epoch_08": f"{_PROJECT_DIR}/output/identity_lora-000008.safetensors",
    "epoch_10": f"{_PROJECT_DIR}/output/identity_lora.safetensors",
}

# Evaluation prompts, keyed by test name (also used as the output sub-folder).
PROMPTS = {
    "test_01_basic_portrait": "bluej, professional portrait photography, natural lighting, highly detailed",
    "test_02_different_clothing": "bluej wearing a navy business suit, corporate office, professional photography",
    "test_03_different_environment": "bluej standing in Times Square at night, cinematic photography",
    "test_04_different_expression": "bluej smiling broadly, candid photography",
    "test_05_different_angle": "bluej, side profile portrait, studio lighting",
    "test_06_different_age_styling": "bluej, mature professional appearance, magazine photography",
    "test_07_different_hairstyle": "bluej with short hair, professional portrait",
    "test_08_artistic_style": "bluej, oil painting, masterpiece",
    "test_09_fantasy_style": "bluej, fantasy warrior, epic cinematic",
}

# Sampling settings: one image per seed, per prompt, per checkpoint.
SEEDS = list(range(42, 47))
STEPS = 25
GUIDANCE = 3.5
WIDTH = HEIGHT = 1024
LORA_SCALE = 0.9
DTYPE = torch.bfloat16
DEVICE = torch.device("cuda")
|
|
# ---- Sampling schedule ----
|
|
def time_shift(mu, sigma, t: torch.Tensor) -> torch.Tensor:
    """Warp timesteps ``t`` with shift ``mu`` and exponent ``sigma``.

    Computes ``exp(mu) / (exp(mu) + (1/t - 1) ** sigma)`` element-wise.
    The value is 1 at ``t == 1``. For a floating-point tensor, ``t == 0``
    gives 0, because ``1/0`` evaluates to ``inf``.
    """
    scale = math.exp(mu)
    return scale / (scale + (1 / t - 1) ** sigma)
|
|
def get_schedule(num_steps, image_seq_len, base_shift=0.5, max_shift=1.15, shift=True):
    """Build the list of ``num_steps + 1`` timesteps, running from 1 down to 0.

    When ``shift`` is true, the timesteps are warped by ``time_shift``. The
    shift ``mu`` is a linear function of ``image_seq_len``: it equals
    ``base_shift`` at 256 tokens and ``max_shift`` at 4096 tokens.
    Returns a plain Python list of floats.
    """
    ts = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return ts.tolist()

    # Line through (256, base_shift) and (4096, max_shift), evaluated at image_seq_len.
    slope = (max_shift - base_shift) / (4096 - 256)
    intercept = base_shift - slope * 256
    mu = slope * image_seq_len + intercept
    return time_shift(mu, 1.0, ts).tolist()
|
|
# ---- Denoising loop ----
|
|
def denoise(model, img, img_ids, txt, txt_ids, vec, timesteps, guidance):
    """Run the sampler over consecutive pairs of ``timesteps``.

    On each step the model predicts ``pred`` for the current latent tokens.
    The latent is then advanced by ``(t_prev - t_curr) * pred``, an explicit
    Euler-style update.

    Args:
        model: Transformer exposing ``get_mod_vectors`` and a keyword-call forward.
        img: Packed latent tokens; dim 0 is used as the batch size.
        img_ids, txt, txt_ids, vec: Conditioning tensors passed straight to the model.
        timesteps: Sequence of floats; pairs (timesteps[i], timesteps[i+1]) form the steps.
        guidance: Scalar guidance value, broadcast to one entry per batch item.

    Returns:
        The latent after the final step.
    """
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
        n = img.shape[0]
        t_vec = torch.full((n,), t_curr, dtype=img.dtype, device=img.device)
        # NOTE(review): get_mod_vectors presumably precomputes the per-step
        # modulation inputs from timestep + guidance; confirm in sd-scripts.
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            mod_vectors=model.get_mod_vectors(timesteps=t_vec, guidance=guidance_vec, batch_size=n),
        )
        dt = t_prev - t_curr
        img = img + dt * pred

    return img
|
|
# ---- Model loading ----
|
|
# Load every network once, with the shared DTYPE/DEVICE, and switch each one to
# eval mode. nn.Module.eval() returns the module itself, so the call can be chained.
print("Loading Flux transformer...")
_, model = flux_utils.load_flow_model(CKPT_PATH, DTYPE, DEVICE)
model = model.eval()

print("Loading CLIP-L...")
clip_l = flux_utils.load_clip_l(CLIP_L_PATH, DTYPE, DEVICE).eval()

print("Loading T5-XXL...")
t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, DTYPE, DEVICE).eval()

print("Loading VAE...")
ae = flux_utils.load_ae(AE_PATH, DTYPE, DEVICE).eval()

# Tokenizer (512 passed as its first positional argument) and the matching
# text-encoding strategy, both used when encoding prompts.
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
|
|
# ---- Image generation ----
|
|
def generate(prompt, seed):
    """Generate one image for ``prompt``, starting from noise seeded with ``seed``.

    Uses the module-level ``model``, ``clip_l``, ``t5xxl``, ``ae``,
    ``tokenize_strategy`` and ``encoding_strategy``, including any LoRA the
    caller has attached to them. Also uses the sampling constants ``STEPS``,
    ``GUIDANCE``, ``WIDTH``, ``HEIGHT``, ``DTYPE`` and ``DEVICE``.

    Args:
        prompt: Text prompt to condition on.
        seed: Seed for the CUDA noise generator; the same seed gives the same
            starting noise.

    Returns:
        A ``PIL.Image.Image`` built from the first decoded sample.
    """
    # Defensive: keep the text encoders on the GPU (a no-op if they already are).
    clip_l.to(DEVICE)
    t5xxl.to(DEVICE)

    # Encode once with both encoders. The first three results are the CLIP-L
    # pooled vector, the T5 hidden states and the text position ids; the fourth
    # is unused here. A separate CLIP-only pass would only recompute the same
    # pooled vector.
    tokens = tokenize_strategy.tokenize(prompt)
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=DTYPE):
        l_pooled, t5_out, txt_ids, _ = encoding_strategy.encode_tokens(
            tokenize_strategy, [clip_l, t5xxl], tokens
        )

    # Packed latent: one token per 2x2 patch of a 16-channel latent grid, so
    # (HEIGHT // 16) * (WIDTH // 16) tokens of 16 * 2 * 2 = 64 values each.
    packed_h = HEIGHT // 16
    packed_w = WIDTH // 16
    noise = torch.randn(
        1, packed_h * packed_w, 16 * 4,
        device=DEVICE, dtype=DTYPE,
        generator=torch.Generator(DEVICE).manual_seed(seed),
    )
    img_ids = flux_utils.prepare_img_ids(1, packed_h, packed_w).to(DEVICE, dtype=DTYPE)
    txt_ids = txt_ids.to(DEVICE, dtype=DTYPE)
    t5_out = t5_out.to(DEVICE, dtype=DTYPE)
    l_pooled = l_pooled.to(DEVICE, dtype=DTYPE)

    # The schedule's shift depends on the number of image tokens.
    timesteps = get_schedule(STEPS, noise.shape[1])

    with torch.autocast(device_type="cuda", dtype=DTYPE), torch.no_grad():
        x = denoise(model, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps, GUIDANCE)

    # Unpack tokens into a (b, 16, 2 * packed_h, 2 * packed_w) latent grid.
    x = x.float()
    x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_h, w=packed_w, ph=2, pw=2)

    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=DTYPE):
        x = ae.decode(x)

    # Clamp to [-1, 1], move channels last, map to [0, 255] (truncating to
    # uint8), and keep the first batch item.
    x = x.clamp(-1, 1).permute(0, 2, 3, 1)
    img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
    return img
|
|
# ---- Evaluation sweep over checkpoints x prompts x seeds ----
|
|
total = len(CHECKPOINTS) * len(PROMPTS) * len(SEEDS)  # images in the whole sweep
done = 0  # images saved or found on disk so far
lora_model = None  # built from the first checkpoint, then reused for the rest

for epoch_label, lora_path in CHECKPOINTS.items():
    print(f"\n{'='*60}")
    print(f"Checkpoint: {epoch_label} ({lora_path})")
    print(f"{'='*60}")

    weights_sd = load_file(lora_path)

    if lora_model is None:
        # Create the LoRA network from the first checkpoint's state dict and
        # attach it to the text encoders and the transformer. Later checkpoints
        # reuse this network and only swap in their weights below. That assumes
        # all checkpoints share one LoRA layout; strict=True raises otherwise.
        # NOTE(review): the trailing positional True is presumably the
        # for_inference flag of create_network_from_weights - confirm.
        lora_model, _ = lora_flux.create_network_from_weights(
            LORA_SCALE, None, ae, [clip_l, t5xxl], model, weights_sd, True
        )
        lora_model.apply_to([clip_l, t5xxl], model)

    lora_model.load_state_dict(weights_sd, strict=True)
    # Drop our reference to the loaded dict so it can be freed before the next
    # checkpoint file is read, instead of two checkpoints being held at once.
    del weights_sd
    lora_model.eval()
    lora_model.to(DEVICE, dtype=DTYPE)
    clip_l.to(DEVICE, dtype=DTYPE)
    t5xxl.to(DEVICE, dtype=DTYPE)
    model.to(DEVICE, dtype=DTYPE)

    for test_name, prompt in PROMPTS.items():
        out_dir = Path(OUTPUT_BASE) / epoch_label / test_name
        out_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n [{test_name}]")
        print(f" Prompt: {prompt}")

        for seed in SEEDS:
            img_path = out_dir / f"seed_{seed:04d}.png"
            # Resume support: images that are already on disk count as done.
            if img_path.exists():
                print(f" seed {seed} -> already exists, skipping")
                done += 1
                continue

            image = generate(prompt, seed)
            image.save(img_path)
            done += 1
            print(f" seed {seed} -> saved ({done}/{total})")

print(f"\nDone. All images saved to {OUTPUT_BASE}/")
|
|