File size: 8,305 Bytes
6d1cce5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 | #!/usr/bin/env python3
"""
Flux LoRA evaluation script using kohya sd-scripts inference pipeline.
Generates 5 images per prompt across multiple epoch checkpoints.
"""
import sys
import os
import math
import torch
import einops
import numpy as np
from pathlib import Path
from PIL import Image
from safetensors.torch import load_file
from tqdm import tqdm
from typing import Optional
sys.path.insert(0, "/workspace/sd-scripts")
from library import flux_utils, strategy_flux
from networks import lora_flux
# ── Config ────────────────────────────────────────────────────────────────────
# Base model weights: Flux transformer, CLIP-L and T5-XXL text encoders, autoencoder.
CKPT_PATH = "/workspace/flux-project/models/flux/flux1-dev.safetensors"
CLIP_L_PATH = "/workspace/flux-project/models/flux/text_encoder/model.safetensors"
T5XXL_PATH = "/workspace/flux-project/models/clip/t5-xxl/t5xxl_fp16.safetensors"
AE_PATH = "/workspace/flux-project/models/vae/ae.safetensors"

# Images are written to OUTPUT_BASE/<checkpoint label>/<test name>/seed_XXXX.png.
OUTPUT_BASE = "/workspace/flux-project/eval"

# LoRA checkpoints to compare, keyed by the label used as their output sub-directory.
CHECKPOINTS = {
    "epoch_07": "/workspace/flux-project/output/identity_lora-000007.safetensors",
    "epoch_08": "/workspace/flux-project/output/identity_lora-000008.safetensors",
    "epoch_10": "/workspace/flux-project/output/identity_lora.safetensors",
}

# Test prompts keyed by output sub-directory. Every prompt starts with "bluej"
# (presumably the LoRA's trigger word — confirm against the training captions).
PROMPTS = {
    "test_01_basic_portrait": "bluej, professional portrait photography, natural lighting, highly detailed",
    "test_02_different_clothing": "bluej wearing a navy business suit, corporate office, professional photography",
    "test_03_different_environment": "bluej standing in Times Square at night, cinematic photography",
    "test_04_different_expression": "bluej smiling broadly, candid photography",
    "test_05_different_angle": "bluej, side profile portrait, studio lighting",
    "test_06_different_age_styling": "bluej, mature professional appearance, magazine photography",
    "test_07_different_hairstyle": "bluej with short hair, professional portrait",
    "test_08_artistic_style": "bluej, oil painting, masterpiece",
    "test_09_fantasy_style": "bluej, fantasy warrior, epic cinematic",
}

SEEDS = [42, 43, 44, 45, 46]  # one image per seed for every (checkpoint, prompt) pair
STEPS = 25  # sampler steps per image (get_schedule returns STEPS + 1 timesteps)
GUIDANCE = 3.5  # guidance value handed to denoise()
WIDTH = 1024
HEIGHT = 1024
LORA_SCALE = 0.9  # multiplier passed to lora_flux.create_network_from_weights
DTYPE = torch.bfloat16
DEVICE = torch.device("cuda")

# generate() builds a (HEIGHT // 16) x (WIDTH // 16) token grid; floor division would
# silently drop any remainder, so reject sizes that are not multiples of 16 up front.
if WIDTH % 16 or HEIGHT % 16:
    raise ValueError(f"WIDTH and HEIGHT must be multiples of 16, got {WIDTH}x{HEIGHT}")
# ── Scheduler ─────────────────────────────────────────────────────────────────
def time_shift(mu, sigma, t: torch.Tensor) -> torch.Tensor:
    """Warp timesteps ``t`` by ``exp(mu) / (exp(mu) + (1/t - 1) ** sigma)``.

    Evaluated element-wise on the whole tensor. At t == 0, 1/t becomes inf and the
    result becomes 0; at t == 1 the result is 1, so both endpoints are preserved.
    """
    e_mu = math.exp(mu)
    odds = (1 / t - 1) ** sigma  # inf where t == 0 -> result 0 (handled by PyTorch)
    return e_mu / (e_mu + odds)
def get_schedule(num_steps, image_seq_len, base_shift=0.5, max_shift=1.15, shift=True):
    """Return ``num_steps + 1`` timesteps going from 1.0 down to 0.0, as a Python list.

    When ``shift`` is true, the uniform grid is warped by ``time_shift``. Its ``mu`` is
    a linear function of ``image_seq_len`` that passes through (256, base_shift) and
    (4096, max_shift), and is extrapolated outside that range.
    """
    grid = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return grid.tolist()
    slope = (max_shift - base_shift) / (4096 - 256)
    intercept = base_shift - slope * 256
    mu = slope * image_seq_len + intercept
    return time_shift(mu, 1.0, grid).tolist()
# ── Denoise (matches kohya's denoise exactly) ─────────────────────────────────
def denoise(model, img, img_ids, txt, txt_ids, vec, timesteps, guidance):
    """Run Euler steps over ``timesteps`` and return the final latents.

    Args:
        model: Flux transformer; must provide ``get_mod_vectors`` and accept the
            keyword arguments used below.
        img: packed noisy latents; the batch size is read from ``img.shape[0]``.
        img_ids: image token ids forwarded to the model.
        txt: T5-XXL sequence embedding.
        txt_ids: text token ids forwarded to the model.
        vec: CLIP-L pooled embedding, passed to the model as ``y``.
        timesteps: schedule such as the output of ``get_schedule``; one model call is
            made for each consecutive pair.
        guidance: guidance value, broadcast to one entry per batch item.

    Returns:
        The latents after the last step.
    """
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_now, t_next in zip(tqdm(timesteps[:-1]), timesteps[1:]):
        # Rebuilt from the current img every step, so it follows img's dtype/device
        # even if the update below changes them.
        t_vec = torch.full((img.shape[0],), t_now, dtype=img.dtype, device=img.device)
        mods = model.get_mod_vectors(timesteps=t_vec, guidance=guidance_vec, batch_size=img.shape[0])
        prediction = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            mod_vectors=mods,
        )
        step = t_next - t_now  # negative for a decreasing schedule like get_schedule's
        img = img + step * prediction
    return img
# ── Load base models (once) ───────────────────────────────────────────────────
# Each network is loaded with the script-wide DTYPE/DEVICE and put in eval mode, since
# this script only runs inference. nn.Module.eval() returns the module itself, which
# is what makes the chained assignments below equivalent to a separate .eval() call.
print("Loading Flux transformer...")
_, model = flux_utils.load_flow_model(CKPT_PATH, DTYPE, DEVICE)
model.eval()

print("Loading CLIP-L...")
clip_l = flux_utils.load_clip_l(CLIP_L_PATH, DTYPE, DEVICE).eval()

print("Loading T5-XXL...")
t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, DTYPE, DEVICE).eval()

print("Loading VAE...")
ae = flux_utils.load_ae(AE_PATH, DTYPE, DEVICE).eval()

# Tokenizer/encoder strategies shared by every generate() call. 512 is presumably the
# T5-XXL max token length — confirm against strategy_flux.FluxTokenizeStrategy.
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
# ── Generation function ───────────────────────────────────────────────────────
def _encode_prompt(prompt):
    """Encode *prompt* with CLIP-L and T5-XXL; return (l_pooled, t5_out, txt_ids) on DEVICE/DTYPE."""
    # Ensure text encoders are fully on device (position_ids buffer may be on CPU after load)
    clip_l.to(DEVICE)
    t5xxl.to(DEVICE)
    tokens = tokenize_strategy.tokenize(prompt)
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=DTYPE):
        # A single call with both encoders returns the pooled CLIP-L output in slot 0
        # together with the T5 sequence and its ids, so CLIP-L runs once per prompt
        # instead of once for the pooled vector and again alongside T5.
        l_pooled, t5_out, txt_ids, _ = encoding_strategy.encode_tokens(
            tokenize_strategy, [clip_l, t5xxl], tokens
        )
    return (
        l_pooled.to(DEVICE, dtype=DTYPE),
        t5_out.to(DEVICE, dtype=DTYPE),
        txt_ids.to(DEVICE, dtype=DTYPE),
    )


def _decode_latents(x, packed_h, packed_w):
    """Unpack 2x2-packed latent tokens, VAE-decode them and return the first image as PIL."""
    x = einops.rearrange(
        x.float(),
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=packed_h, w=packed_w, ph=2, pw=2,
    )
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=DTYPE):
        x = ae.decode(x)
    x = x.clamp(-1, 1).permute(0, 2, 3, 1)  # NCHW -> NHWC for PIL
    # Map [-1, 1] to [0, 255]; astype(np.uint8) truncates rather than rounds.
    pixels = (127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)
    return Image.fromarray(pixels[0])


def generate(prompt, seed):
    """Render one HEIGHT x WIDTH image for *prompt* from the RNG *seed*.

    Uses whatever LoRA is currently applied to the global ``model``/text encoders.

    Returns:
        A ``PIL.Image.Image``.
    """
    l_pooled, t5_out, txt_ids = _encode_prompt(prompt)

    # Token grid is (HEIGHT // 16) x (WIDTH // 16); each token carries 16 channels of a
    # 2x2 latent patch = 64 values, matching the unpack in _decode_latents.
    packed_h = HEIGHT // 16
    packed_w = WIDTH // 16
    noise = torch.randn(
        1, packed_h * packed_w, 16 * 4,
        device=DEVICE, dtype=DTYPE,
        generator=torch.Generator(DEVICE).manual_seed(seed),
    )
    img_ids = flux_utils.prepare_img_ids(1, packed_h, packed_w).to(DEVICE, dtype=DTYPE)
    # The schedule's shift depends on the number of image tokens.
    timesteps = get_schedule(STEPS, noise.shape[1])

    with torch.autocast(device_type="cuda", dtype=DTYPE), torch.no_grad():
        x = denoise(model, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps, GUIDANCE)
    return _decode_latents(x, packed_h, packed_w)
# ── Generation loop ───────────────────────────────────────────────────────────
total = len(CHECKPOINTS) * len(PROMPTS) * len(SEEDS)
done = 0
# Built from the first checkpoint that is actually loaded, then reused for the others.
lora_model = None


def _image_path(epoch_label, test_name, seed):
    """Return the output PNG path for one (checkpoint, prompt, seed) combination."""
    return Path(OUTPUT_BASE) / epoch_label / test_name / f"seed_{seed:04d}.png"


for epoch_label, lora_path in CHECKPOINTS.items():
    print(f"\n{'='*60}")
    print(f"Checkpoint: {epoch_label} ({lora_path})")
    print(f"{'='*60}")

    # Resume support: when every image for this checkpoint already exists, skip it
    # before touching the LoRA file, so finished checkpoints cost no load time and
    # their .safetensors may even have been removed.
    if all(_image_path(epoch_label, t, s).exists() for t in PROMPTS for s in SEEDS):
        done += len(PROMPTS) * len(SEEDS)
        print(f" all images already exist, skipping checkpoint ({done}/{total})")
        continue

    weights_sd = load_file(lora_path)
    if lora_model is None:
        # The network layout comes from this first loaded checkpoint; later checkpoints
        # are loaded with strict=True, so they must share the same layout.
        lora_model, _ = lora_flux.create_network_from_weights(
            LORA_SCALE, None, ae, [clip_l, t5xxl], model, weights_sd, True
        )
        lora_model.apply_to([clip_l, t5xxl], model)
    lora_model.load_state_dict(weights_sd, strict=True)
    lora_model.eval()
    lora_model.to(DEVICE, dtype=DTYPE)
    clip_l.to(DEVICE, dtype=DTYPE)
    t5xxl.to(DEVICE, dtype=DTYPE)
    model.to(DEVICE, dtype=DTYPE)

    for test_name, prompt in PROMPTS.items():
        out_dir = Path(OUTPUT_BASE) / epoch_label / test_name
        out_dir.mkdir(parents=True, exist_ok=True)
        print(f"\n [{test_name}]")
        print(f" Prompt: {prompt}")
        for seed in SEEDS:
            img_path = _image_path(epoch_label, test_name, seed)
            if img_path.exists():
                print(f" seed {seed} β already exists, skipping")
                done += 1
                continue
            image = generate(prompt, seed)
            image.save(img_path)
            done += 1
            print(f" seed {seed} β saved ({done}/{total})")

print(f"\nDone. All images saved to {OUTPUT_BASE}/")
|