#!/usr/bin/env python3
"""
Evaluate Flux LoRA checkpoints with the kohya sd-scripts inference pipeline.

For every checkpoint in CHECKPOINTS and every prompt in PROMPTS, one image is
rendered per seed in SEEDS (five seeds -> five images per prompt).
"""
import sys
import os
import math
import torch
import einops
import numpy as np
from pathlib import Path
from PIL import Image
from safetensors.torch import load_file
from tqdm import tqdm
from typing import Optional

# sd-scripts is used from a checkout, so put it on the import path first.
sys.path.insert(0, "/workspace/sd-scripts")
from library import flux_utils, strategy_flux
from networks import lora_flux

# ── Config ────────────────────────────────────────────────────────────────────
CKPT_PATH = "/workspace/flux-project/models/flux/flux1-dev.safetensors"
CLIP_L_PATH = "/workspace/flux-project/models/flux/text_encoder/model.safetensors"
T5XXL_PATH = "/workspace/flux-project/models/clip/t5-xxl/t5xxl_fp16.safetensors"
AE_PATH = "/workspace/flux-project/models/vae/ae.safetensors"
OUTPUT_BASE = "/workspace/flux-project/eval"

# label -> LoRA file; the label becomes the first output sub-directory.
CHECKPOINTS = {
    "epoch_07": "/workspace/flux-project/output/identity_lora-000007.safetensors",
    "epoch_08": "/workspace/flux-project/output/identity_lora-000008.safetensors",
    "epoch_10": "/workspace/flux-project/output/identity_lora.safetensors",
}

# test name -> prompt; the test name becomes the second output sub-directory.
PROMPTS = {
    "test_01_basic_portrait": "bluej, professional portrait photography, natural lighting, highly detailed",
    "test_02_different_clothing": "bluej wearing a navy business suit, corporate office, professional photography",
    "test_03_different_environment": "bluej standing in Times Square at night, cinematic photography",
    "test_04_different_expression": "bluej smiling broadly, candid photography",
    "test_05_different_angle": "bluej, side profile portrait, studio lighting",
    "test_06_different_age_styling": "bluej, mature professional appearance, magazine photography",
    "test_07_different_hairstyle": "bluej with short hair, professional portrait",
    "test_08_artistic_style": "bluej, oil painting, masterpiece",
    "test_09_fantasy_style": "bluej, fantasy warrior, epic cinematic",
}

SEEDS = [42, 43, 44, 45, 46]
STEPS = 25
GUIDANCE = 3.5
WIDTH, HEIGHT = 1024, 1024
LORA_SCALE = 0.9
DTYPE = torch.bfloat16
DEVICE = torch.device("cuda")


# ── Scheduler ─────────────────────────────────────────────────────────────────
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    """Warp timesteps ``t`` by exp(mu) / (exp(mu) + (1/t - 1) ** sigma).

    Evaluated on the whole tensor: at t == 0 the term 1/t is inf, the
    denominator becomes inf and PyTorch yields 0 for that entry.
    """
    scale = math.exp(mu)
    return scale / (scale + (1 / t - 1) ** sigma)


def get_schedule(num_steps, image_seq_len, base_shift=0.5, max_shift=1.15, shift=True):
    """Return ``num_steps + 1`` timesteps running from 1 down to 0 as a list.

    With ``shift`` enabled, ``mu`` is interpolated linearly in
    ``image_seq_len`` (base_shift at 256 tokens, max_shift at 4096 tokens)
    and passed through :func:`time_shift` with sigma = 1.0.
    """
    grid = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return grid.tolist()
    slope = (max_shift - base_shift) / (4096 - 256)
    intercept = base_shift - slope * 256
    mu = slope * image_seq_len + intercept
    return time_shift(mu, 1.0, grid).tolist()


# ── Denoise (matches kohya's denoise exactly) ─────────────────────────────────
def denoise(model, img, img_ids, txt, txt_ids, vec, timesteps, guidance):
    """Integrate ``img`` along ``timesteps`` with explicit Euler steps.

    Each step evaluates the transformer at the current timestep and moves
    ``img`` by ``(t_prev - t_curr) * prediction``. ``vec`` is passed to the
    model as ``y``. Returns the final ``img`` tensor.
    """
    batch = img.shape[0]
    guidance_vec = torch.full((batch,), guidance, device=img.device, dtype=img.dtype)
    progress = tqdm(timesteps[:-1])
    for t_curr, t_prev in zip(progress, timesteps[1:]):
        # dtype/device are re-read from img each step, exactly as kohya does.
        t_vec = torch.full((batch,), t_curr, dtype=img.dtype, device=img.device)
        mod_vectors = model.get_mod_vectors(
            timesteps=t_vec, guidance=guidance_vec, batch_size=batch
        )
        velocity = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            mod_vectors=mod_vectors,
        )
        step_size = t_prev - t_curr
        img = img + step_size * velocity
    return img


# ── Load base models (once) ───────────────────────────────────────────────────
print("Loading Flux transformer...")
_, model = flux_utils.load_flow_model(CKPT_PATH, DTYPE, DEVICE)
model = model.eval()

print("Loading CLIP-L...")
clip_l = flux_utils.load_clip_l(CLIP_L_PATH, DTYPE, DEVICE).eval()

print("Loading T5-XXL...")
t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, DTYPE, DEVICE).eval()

print("Loading VAE...")
ae = flux_utils.load_ae(AE_PATH, DTYPE, DEVICE).eval()

# 512 is the T5 token length passed to the tokenizer strategy.
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
# ── Generation function ───────────────────────────────────────────────────────
def generate(prompt, seed):
    """Render one image for ``prompt`` with the current transformer/LoRA state.

    Args:
        prompt: Text prompt, tokenized with the module-level ``tokenize_strategy``.
        seed: Seed for the initial latent noise. The noise comes from a seeded
            generator, so a given seed starts from the same noise for every
            checkpoint.

    Returns:
        The decoded image as a ``PIL.Image.Image``.
    """
    # Ensure text encoders are fully on device (position_ids buffer may be on CPU after load)
    clip_l.to(DEVICE)
    t5xxl.to(DEVICE)

    tokens = tokenize_strategy.tokenize(prompt)
    # Both encoders run under the same autocast dtype, so one encode pass
    # returns the CLIP-L pooled vector and the T5-XXL sequence together.
    # Splitting this into two calls would run CLIP-L twice per image.
    with torch.no_grad(), torch.autocast(device_type=DEVICE.type, dtype=DTYPE):
        l_pooled, t5_out, txt_ids, _ = encoding_strategy.encode_tokens(
            tokenize_strategy, [clip_l, t5xxl], tokens
        )

    packed_h = HEIGHT // 16
    packed_w = WIDTH // 16
    # 16 * 4 = latent channels times a 2x2 patch, matching the
    # "(c ph pw)" unpacking with ph = pw = 2 below.
    noise = torch.randn(
        1,
        packed_h * packed_w,
        16 * 4,
        device=DEVICE,
        dtype=DTYPE,
        generator=torch.Generator(DEVICE).manual_seed(seed),
    )

    img_ids = flux_utils.prepare_img_ids(1, packed_h, packed_w).to(DEVICE, dtype=DTYPE)
    txt_ids = txt_ids.to(DEVICE, dtype=DTYPE)
    t5_out = t5_out.to(DEVICE, dtype=DTYPE)
    l_pooled = l_pooled.to(DEVICE, dtype=DTYPE)

    # The schedule shift depends on the image token count (packed_h * packed_w).
    timesteps = get_schedule(STEPS, noise.shape[1])

    with torch.autocast(device_type=DEVICE.type, dtype=DTYPE), torch.no_grad():
        x = denoise(model, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps, GUIDANCE)

    # Unpack the 2x2 patch tokens back into a (b, c, H, W) latent before decoding.
    x = x.float()
    x = einops.rearrange(
        x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_h, w=packed_w, ph=2, pw=2
    )

    with torch.no_grad(), torch.autocast(device_type=DEVICE.type, dtype=DTYPE):
        x = ae.decode(x)

    # Map [-1, 1] to [0, 255], channels-last, and take the single batch element.
    x = x.clamp(-1, 1).permute(0, 2, 3, 1)
    img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
    return img


# ── Generation loop ───────────────────────────────────────────────────────────
# Images already on disk are skipped, so re-running the script resumes work.
total = len(CHECKPOINTS) * len(PROMPTS) * len(SEEDS)
done = 0
lora_model = None

for epoch_label, lora_path in CHECKPOINTS.items():
    print(f"\n{'='*60}")
    print(f"Checkpoint: {epoch_label} ({lora_path})")
    print(f"{'='*60}")

    weights_sd = load_file(lora_path)
    if lora_model is None:
        # Build and hook the LoRA network once, shaped from the first checkpoint.
        # Later checkpoints only swap weights. load_state_dict(strict=True)
        # raises if their key layout or tensor shapes differ.
        # NOTE(review): the trailing positional True is presumably
        # lora_flux's for_inference flag; confirm against its signature.
        lora_model, _ = lora_flux.create_network_from_weights(
            LORA_SCALE, None, ae, [clip_l, t5xxl], model, weights_sd, True
        )
        lora_model.apply_to([clip_l, t5xxl], model)
    lora_model.load_state_dict(weights_sd, strict=True)
    lora_model.eval()
    lora_model.to(DEVICE, dtype=DTYPE)
    clip_l.to(DEVICE, dtype=DTYPE)
    t5xxl.to(DEVICE, dtype=DTYPE)
    model.to(DEVICE, dtype=DTYPE)

    for test_name, prompt in PROMPTS.items():
        out_dir = Path(OUTPUT_BASE) / epoch_label / test_name
        out_dir.mkdir(parents=True, exist_ok=True)
        print(f"\n  [{test_name}]")
        print(f"  Prompt: {prompt}")

        for seed in SEEDS:
            img_path = out_dir / f"seed_{seed:04d}.png"
            if img_path.exists():
                print(f"    seed {seed} — already exists, skipping")
                done += 1
                continue

            image = generate(prompt, seed)
            # Write to a temp file, then rename into place. An interrupted save
            # therefore never leaves a truncated PNG that the exists() check
            # above would skip on every later run.
            tmp_path = img_path.with_name(img_path.name + ".tmp")
            image.save(tmp_path, format="PNG")
            tmp_path.replace(img_path)
            done += 1
            print(f"    seed {seed} — saved ({done}/{total})")

print(f"\nDone. All images saved to {OUTPUT_BASE}/")