# MYai / scripts / run_eval.py
# montignyp's picture
# Add epoch 8 & 10 LoRA weights, training config, eval script
# 6d1cce5
#!/usr/bin/env python3
"""
Flux LoRA evaluation script using kohya sd-scripts inference pipeline.
Generates 5 images per prompt across multiple epoch checkpoints.
"""
import sys
import os
import math
import torch
import einops
import numpy as np
from pathlib import Path
from PIL import Image
from safetensors.torch import load_file
from tqdm import tqdm
from typing import Optional
sys.path.insert(0, "/workspace/sd-scripts")
from library import flux_utils, strategy_flux
from networks import lora_flux
# ── Config ────────────────────────────────────────────────────────────────────
# Base-model weight files. Each path is passed to the matching flux_utils
# loader in the "Load base models" section (flow model, CLIP-L, T5-XXL, AE).
CKPT_PATH = "/workspace/flux-project/models/flux/flux1-dev.safetensors"
CLIP_L_PATH = "/workspace/flux-project/models/flux/text_encoder/model.safetensors"
T5XXL_PATH = "/workspace/flux-project/models/clip/t5-xxl/t5xxl_fp16.safetensors"
AE_PATH = "/workspace/flux-project/models/vae/ae.safetensors"
# Root of the output tree; images land in OUTPUT_BASE/<checkpoint label>/<test name>/.
OUTPUT_BASE = "/workspace/flux-project/eval"
# Checkpoint label -> LoRA weight file. The label doubles as the first-level
# output directory name.
# NOTE(review): "epoch_10" points at the un-suffixed file, presumably the final
# save of the training run — confirm.
CHECKPOINTS = {
"epoch_07": "/workspace/flux-project/output/identity_lora-000007.safetensors",
"epoch_08": "/workspace/flux-project/output/identity_lora-000008.safetensors",
"epoch_10": "/workspace/flux-project/output/identity_lora.safetensors",
}
# Test name -> prompt text. The test name doubles as the second-level output
# directory name.
# NOTE(review): every prompt starts with "bluej", presumably the LoRA trigger
# token — confirm against the training config.
PROMPTS = {
"test_01_basic_portrait": "bluej, professional portrait photography, natural lighting, highly detailed",
"test_02_different_clothing": "bluej wearing a navy business suit, corporate office, professional photography",
"test_03_different_environment": "bluej standing in Times Square at night, cinematic photography",
"test_04_different_expression": "bluej smiling broadly, candid photography",
"test_05_different_angle": "bluej, side profile portrait, studio lighting",
"test_06_different_age_styling": "bluej, mature professional appearance, magazine photography",
"test_07_different_hairstyle": "bluej with short hair, professional portrait",
"test_08_artistic_style": "bluej, oil painting, masterpiece",
"test_09_fantasy_style": "bluej, fantasy warrior, epic cinematic",
}
# One image is rendered per seed for every (checkpoint, prompt) pair.
SEEDS = [42, 43, 44, 45, 46]
# Number of denoising steps requested from get_schedule().
STEPS = 25
# Guidance value broadcast over the batch inside denoise().
GUIDANCE = 3.5
# Output size in pixels; generate() divides both by 16 to size the packed latent grid.
WIDTH = 1024
HEIGHT = 1024
# First positional argument to lora_flux.create_network_from_weights().
# NOTE(review): presumably the LoRA multiplier — confirm against sd-scripts.
LORA_SCALE = 0.9
# dtype and device used for every model load, tensor allocation and autocast region.
DTYPE = torch.bfloat16
DEVICE = torch.device("cuda")
# ── Scheduler ─────────────────────────────────────────────────────────────────
def time_shift(mu, sigma, t: torch.Tensor) -> torch.Tensor:
    """Warp timesteps in [0, 1] with the exponential time-shift curve.

    Computes ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`` element-wise.

    Args:
        mu: Shift strength (Python number); only ``exp(mu)`` is used.
        sigma: Exponent applied to ``(1 / t - 1)``.
        t: Timesteps to warp. A tensor is used as-is; a Python scalar is
            wrapped into a 0-dim tensor so the ``t == 0`` case below works
            for scalars too.

    Returns:
        A tensor with the same shape as ``t`` holding the shifted timesteps.
    """
    # Do the math on a tensor: with tensor arithmetic 1 / 0 -> inf, so t == 0
    # maps cleanly to 0 instead of raising ZeroDivisionError.
    t = torch.as_tensor(t)
    exp_mu = math.exp(mu)
    return exp_mu / (exp_mu + (1 / t - 1) ** sigma)
def get_schedule(num_steps, image_seq_len, base_shift=0.5, max_shift=1.15, shift=True):
    """Build the descending list of timesteps used by the sampler.

    Args:
        num_steps: Number of denoising steps; the result has ``num_steps + 1``
            entries running from 1 down to 0.
        image_seq_len: Length of the packed image token sequence. Used only
            when ``shift`` is true, to pick the shift strength.
        base_shift: Shift strength at a sequence length of 256.
        max_shift: Shift strength at a sequence length of 4096.
        shift: When true, warp the uniform grid with ``time_shift``.

    Returns:
        A list of ``num_steps + 1`` floats.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1. Such a schedule has no
            (current, next) pair, so sampling would silently do nothing.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    timesteps = torch.linspace(1, 0, num_steps + 1)
    if shift:
        # mu is linear in the sequence length: the line passes through
        # (256, base_shift) and (4096, max_shift).
        slope = (max_shift - base_shift) / (4096 - 256)
        intercept = base_shift - slope * 256
        mu = slope * image_seq_len + intercept
        timesteps = time_shift(mu, 1.0, timesteps)
    return timesteps.tolist()
# ── Denoise (matches kohya's denoise exactly) ─────────────────────────────────
def denoise(model, img, img_ids, txt, txt_ids, vec, timesteps, guidance):
    """Integrate the model's prediction along ``timesteps`` and return the result.

    Args:
        model: Callable exposing ``get_mod_vectors(...)`` and accepting the
            keyword arguments used below.
        img: Starting tensor; its first dimension is used as the batch size.
        img_ids: Forwarded to the model as ``img_ids``.
        txt: Forwarded to the model as ``txt``.
        txt_ids: Forwarded to the model as ``txt_ids``.
        vec: Forwarded to the model as ``y``.
        timesteps: Sequence of floats; consecutive pairs form the steps.
        guidance: Scalar broadcast to one value per batch element.

    Returns:
        ``img`` after one update per consecutive timestep pair.
    """
    # Built once, from the incoming tensor's batch size, device and dtype.
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    step_starts = tqdm(timesteps[:-1])  # progress bar follows the step starts
    step_ends = timesteps[1:]
    for t_now, t_next in zip(step_starts, step_ends):
        # img is rebound every iteration, so read its metadata afresh.
        batch = img.shape[0]
        t_vec = torch.full((batch,), t_now, dtype=img.dtype, device=img.device)
        mod_vectors = model.get_mod_vectors(
            timesteps=t_vec,
            guidance=guidance_vec,
            batch_size=batch,
        )
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            mod_vectors=mod_vectors,
        )
        # Explicit Euler update over the interval [t_now, t_next].
        step = t_next - t_now
        img = img + step * pred
    return img
# ── Load base models (once) ───────────────────────────────────────────────────
# Everything below runs once at import time and stays resident for the whole
# run; generate() and the generation loop read these module-level names.

# load_flow_model returns a pair; only the second element (the model) is kept.
print("Loading Flux transformer...")
_, model = flux_utils.load_flow_model(CKPT_PATH, DTYPE, DEVICE)
model.eval()
# Text encoders: CLIP-L supplies the pooled vector, T5-XXL the token sequence
# (see how generate() unpacks encode_tokens).
print("Loading CLIP-L...")
clip_l = flux_utils.load_clip_l(CLIP_L_PATH, DTYPE, DEVICE)
clip_l.eval()
print("Loading T5-XXL...")
t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, DTYPE, DEVICE)
t5xxl.eval()
# Autoencoder used by generate() to decode latents into pixels.
print("Loading VAE...")
ae = flux_utils.load_ae(AE_PATH, DTYPE, DEVICE)
ae.eval()
# NOTE(review): 512 is presumably the T5 max token length — confirm against
# strategy_flux.FluxTokenizeStrategy.
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
# ── Generation function ───────────────────────────────────────────────────────
def generate(prompt, seed):
    """Render one image for ``prompt`` using a fixed ``seed``.

    Reads the module-level models (``model``, ``clip_l``, ``t5xxl``, ``ae``),
    the tokenize/encoding strategies and the sampling constants.

    Args:
        prompt: Text prompt to encode.
        seed: Integer seed for the initial noise generator on ``DEVICE``.

    Returns:
        A ``PIL.Image.Image`` built from the first (only) batch element.
    """
    # Ensure text encoders are fully on device (position_ids buffer may be on CPU after load)
    clip_l.to(DEVICE)
    t5xxl.to(DEVICE)

    tokens = tokenize_strategy.tokenize(prompt)
    with torch.no_grad(), torch.autocast(device_type=DEVICE.type, dtype=DTYPE):
        # A single call with both encoders yields the CLIP-L pooled vector and
        # the T5-XXL sequence together; encoding CLIP-L in a separate call
        # first would just run it twice under the same autocast settings.
        l_pooled, t5_out, txt_ids, _ = encoding_strategy.encode_tokens(
            tokenize_strategy, [clip_l, t5xxl], tokens
        )

    # Packed latent grid: one token per 16x16 pixel tile, 16 channels x (2x2) patch.
    packed_h = HEIGHT // 16
    packed_w = WIDTH // 16
    noise = torch.randn(
        1, packed_h * packed_w, 16 * 4,
        device=DEVICE, dtype=DTYPE,
        generator=torch.Generator(DEVICE).manual_seed(seed),
    )
    img_ids = flux_utils.prepare_img_ids(1, packed_h, packed_w).to(DEVICE, dtype=DTYPE)
    txt_ids = txt_ids.to(DEVICE, dtype=DTYPE)
    t5_out = t5_out.to(DEVICE, dtype=DTYPE)
    l_pooled = l_pooled.to(DEVICE, dtype=DTYPE)

    timesteps = get_schedule(STEPS, noise.shape[1])
    with torch.autocast(device_type=DEVICE.type, dtype=DTYPE), torch.no_grad():
        x = denoise(model, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps, GUIDANCE)

    # Unpack tokens back into a (batch, channels, height, width) latent map.
    x = x.float()
    x = einops.rearrange(
        x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_h, w=packed_w, ph=2, pw=2
    )
    with torch.no_grad(), torch.autocast(device_type=DEVICE.type, dtype=DTYPE):
        x = ae.decode(x)

    # Decoder output is clamped to [-1, 1], moved to channels-last, then mapped
    # to 0..255 bytes.
    x = x.clamp(-1, 1).permute(0, 2, 3, 1)
    img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
    return img
# ── Generation loop ───────────────────────────────────────────────────────────
# Progress bookkeeping: one image per (checkpoint, prompt, seed) triple.
total = len(CHECKPOINTS) * len(PROMPTS) * len(SEEDS)
done = 0
# The LoRA network is created and hooked into the models once; later
# checkpoints only swap its weights.
lora_model = None
for epoch_label, lora_path in CHECKPOINTS.items():
    print(f"\n{'='*60}")
    print(f"Checkpoint: {epoch_label} ({lora_path})")
    print(f"{'='*60}")

    weights_sd = load_file(lora_path)
    if lora_model is None:
        lora_model, _ = lora_flux.create_network_from_weights(
            LORA_SCALE, None, ae, [clip_l, t5xxl], model, weights_sd, True
        )
        lora_model.apply_to([clip_l, t5xxl], model)
    # Runs for every checkpoint (not just the first) so each one's weights are
    # actually used. strict=True makes a key mismatch between checkpoints fail
    # loudly instead of silently keeping stale weights.
    lora_model.load_state_dict(weights_sd, strict=True)
    lora_model.eval()

    lora_model.to(DEVICE, dtype=DTYPE)
    clip_l.to(DEVICE, dtype=DTYPE)
    t5xxl.to(DEVICE, dtype=DTYPE)
    model.to(DEVICE, dtype=DTYPE)

    for test_name, prompt in PROMPTS.items():
        out_dir = Path(OUTPUT_BASE) / epoch_label / test_name
        out_dir.mkdir(parents=True, exist_ok=True)
        print(f"\n [{test_name}]")
        print(f" Prompt: {prompt}")

        for seed in SEEDS:
            img_path = out_dir / f"seed_{seed:04d}.png"
            # Resume support: anything already on disk is counted and skipped.
            if img_path.exists():
                print(f" seed {seed} — already exists, skipping")
                done += 1
                continue

            image = generate(prompt, seed)
            # Save under a temporary name and rename afterwards, so a run
            # interrupted mid-write cannot leave a truncated PNG that the
            # exists() check above would skip on the next run.
            tmp_path = img_path.with_name(img_path.name + ".tmp")
            image.save(tmp_path, format="PNG")
            tmp_path.replace(img_path)
            done += 1
            print(f" seed {seed} — saved ({done}/{total})")

print(f"\nDone. All images saved to {OUTPUT_BASE}/")