File size: 8,305 Bytes
6d1cce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
#!/usr/bin/env python3
"""
Flux LoRA evaluation script using kohya sd-scripts inference pipeline.
Generates 5 images per prompt across multiple epoch checkpoints.
"""

import sys
import os
import math
import torch
import einops
import numpy as np
from pathlib import Path
from PIL import Image
from safetensors.torch import load_file
from tqdm import tqdm
from typing import Optional

sys.path.insert(0, "/workspace/sd-scripts")

from library import flux_utils, strategy_flux
from networks import lora_flux

# ── Config ────────────────────────────────────────────────────────────────────

# Base model weights.
CKPT_PATH = "/workspace/flux-project/models/flux/flux1-dev.safetensors"
CLIP_L_PATH = "/workspace/flux-project/models/flux/text_encoder/model.safetensors"
T5XXL_PATH = "/workspace/flux-project/models/clip/t5-xxl/t5xxl_fp16.safetensors"
AE_PATH = "/workspace/flux-project/models/vae/ae.safetensors"

# Root directory for generated images (one sub-directory per checkpoint label,
# then one per test name).
OUTPUT_BASE = "/workspace/flux-project/eval"

# Checkpoint label -> LoRA weights file.
# NOTE(review): the un-suffixed file is labelled "epoch_10"; presumably it is the
# final save of a 10-epoch run. Confirm against the training config.
CHECKPOINTS = {
    "epoch_07": "/workspace/flux-project/output/identity_lora-000007.safetensors",
    "epoch_08": "/workspace/flux-project/output/identity_lora-000008.safetensors",
    "epoch_10": "/workspace/flux-project/output/identity_lora.safetensors",
}

# Test name -> prompt. Every prompt starts with "bluej" (presumably the LoRA's
# trigger word).
PROMPTS = {
    "test_01_basic_portrait": "bluej, professional portrait photography, natural lighting, highly detailed",
    "test_02_different_clothing": "bluej wearing a navy business suit, corporate office, professional photography",
    "test_03_different_environment": "bluej standing in Times Square at night, cinematic photography",
    "test_04_different_expression": "bluej smiling broadly, candid photography",
    "test_05_different_angle": "bluej, side profile portrait, studio lighting",
    "test_06_different_age_styling": "bluej, mature professional appearance, magazine photography",
    "test_07_different_hairstyle": "bluej with short hair, professional portrait",
    "test_08_artistic_style": "bluej, oil painting, masterpiece",
    "test_09_fantasy_style": "bluej, fantasy warrior, epic cinematic",
}

# Sampling settings.
SEEDS = list(range(42, 47))  # five seeds: 42, 43, 44, 45, 46
STEPS = 25
GUIDANCE = 3.5
WIDTH, HEIGHT = 1024, 1024
LORA_SCALE = 0.9
DTYPE = torch.bfloat16
DEVICE = torch.device("cuda")

# ── Scheduler ─────────────────────────────────────────────────────────────────

def time_shift(mu, sigma, t: torch.Tensor) -> torch.Tensor:
    """Warp timesteps with the shift ``exp(mu) / (exp(mu) + (1/t - 1) ** sigma)``.

    Args:
        mu: Shift amount; a larger ``mu`` pushes every interior timestep
            closer to 1.
        sigma: Exponent applied to ``1/t - 1``.
        t: Timesteps, expected in [0, 1]. Must be a tensor (see note below).

    Returns:
        Tensor of shifted timesteps with the same shape as ``t``. For
        ``sigma > 0``, t == 1 maps to 1 and t == 0 maps to 0.
    """
    # Operate on the full tensor so t == 0 gives 1/0 -> inf -> result 0
    # (PyTorch follows IEEE float semantics). A plain Python float 0.0 would
    # raise ZeroDivisionError instead.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def get_schedule(num_steps, image_seq_len, base_shift=0.5, max_shift=1.15, shift=True):
    """Return ``num_steps + 1`` timestep boundaries running from 1.0 down to 0.0.

    When ``shift`` is true, the evenly spaced boundaries are warped with
    ``time_shift``. Its ``mu`` is the line through (256, base_shift) and
    (4096, max_shift), evaluated at ``image_seq_len``, so longer token
    sequences get a larger shift.

    Returns:
        A plain Python list of floats.
    """
    ts = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return ts.tolist()

    slope = (max_shift - base_shift) / (4096 - 256)
    intercept = base_shift - slope * 256
    mu = slope * image_seq_len + intercept
    return time_shift(mu, 1.0, ts).tolist()

# ── Denoise (matches kohya's denoise exactly) ─────────────────────────────────

def denoise(model, img, img_ids, txt, txt_ids, vec, timesteps, guidance):
    """Run explicit Euler updates ``x += (t_prev - t_curr) * pred`` over the schedule.

    Args:
        model: Transformer exposing ``get_mod_vectors`` and accepting the
            keyword arguments used below.
        img: Starting latents; ``img.shape[0]`` is taken as the batch size.
        img_ids: Image positional ids, forwarded to the model unchanged.
        txt: Text sequence embedding, forwarded unchanged.
        txt_ids: Text positional ids, forwarded unchanged.
        vec: Pooled text vector, forwarded as ``y``.
        timesteps: Sequence of timestep boundaries. One model call is made
            per consecutive pair, with a tqdm progress bar.
        guidance: Guidance value broadcast to one entry per batch item.

    Returns:
        The latents after the last update.
    """
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    latents = img
    for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
        # Read dtype/device from the current latents each step, in case the
        # update promoted them.
        batch = latents.shape[0]
        t_vec = torch.full((batch,), t_curr, dtype=latents.dtype, device=latents.device)
        mods = model.get_mod_vectors(timesteps=t_vec, guidance=guidance_vec, batch_size=batch)
        velocity = model(
            img=latents,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            mod_vectors=mods,
        )
        step = t_prev - t_curr
        latents = latents + step * velocity
    return latents

# ── Load base models (once) ───────────────────────────────────────────────────

# Load every base model once, at import time, directly onto DEVICE in DTYPE,
# and put each in eval mode. These module-level globals are read by generate()
# and by the checkpoint loop that applies the LoRA to them.
print("Loading Flux transformer...")
# load_flow_model returns a 2-tuple; only the second element (the model) is kept.
_, model = flux_utils.load_flow_model(CKPT_PATH, DTYPE, DEVICE)
model.eval()

print("Loading CLIP-L...")
clip_l = flux_utils.load_clip_l(CLIP_L_PATH, DTYPE, DEVICE)
clip_l.eval()

print("Loading T5-XXL...")
t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, DTYPE, DEVICE)
t5xxl.eval()

print("Loading VAE...")
ae = flux_utils.load_ae(AE_PATH, DTYPE, DEVICE)
ae.eval()

# Tokenizer / text-encoding strategies used by generate().
# NOTE(review): 512 is presumably the T5 max token length; verify against
# strategy_flux.FluxTokenizeStrategy.
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()

# ── Generation function ───────────────────────────────────────────────────────

def generate(prompt, seed):
    """Render one WIDTH x HEIGHT image for ``prompt`` with the currently loaded models.

    Uses the module-level transformer, text encoders, VAE and tokenize/encode
    strategies, so whatever LoRA is currently applied to those models shapes
    the result.

    Args:
        prompt: Text prompt to condition on.
        seed: Seed for the ``torch.Generator`` on DEVICE that draws the
            initial noise. The same seed reproduces the same starting noise.

    Returns:
        A ``PIL.Image.Image`` built from the first (and only) batch item.
    """
    # Ensure text encoders are fully on device (position_ids buffer may be on CPU after load)
    clip_l.to(DEVICE)
    t5xxl.to(DEVICE)

    # Encode the prompt in a single pass. With both encoders supplied, one call
    # returns the CLIP-L pooled vector and the T5-XXL sequence (plus txt_ids),
    # so CLIP-L no longer runs twice per image.
    tokens = tokenize_strategy.tokenize(prompt)
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=DTYPE):
        l_pooled, t5_out, txt_ids, _ = encoding_strategy.encode_tokens(
            tokenize_strategy, [clip_l, t5xxl], tokens
        )

    packed_h = HEIGHT // 16
    packed_w = WIDTH // 16
    # 64 channels per token = 16 latent channels x a 2x2 patch, matching the
    # "(c ph pw)" unpacking below.
    noise = torch.randn(
        1, packed_h * packed_w, 16 * 4,
        device=DEVICE, dtype=DTYPE,
        generator=torch.Generator(DEVICE).manual_seed(seed),
    )
    img_ids = flux_utils.prepare_img_ids(1, packed_h, packed_w).to(DEVICE, dtype=DTYPE)
    txt_ids = txt_ids.to(DEVICE, dtype=DTYPE)
    t5_out = t5_out.to(DEVICE, dtype=DTYPE)
    l_pooled = l_pooled.to(DEVICE, dtype=DTYPE)

    # The schedule's shift depends on the number of image tokens.
    timesteps = get_schedule(STEPS, noise.shape[1])

    with torch.autocast(device_type="cuda", dtype=DTYPE), torch.no_grad():
        x = denoise(model, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps, GUIDANCE)

    # Unpack tokens (b, h*w, c*2*2) back into a latent image (b, c, 2h, 2w).
    x = x.float()
    x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_h, w=packed_w, ph=2, pw=2)

    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=DTYPE):
        x = ae.decode(x)

    # Map [-1, 1] to [0, 255] channels-last; astype(np.uint8) truncates rather than rounds.
    x = x.clamp(-1, 1).permute(0, 2, 3, 1)
    img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
    return img

# ── Generation loop ───────────────────────────────────────────────────────────

# Progress counts every (checkpoint, prompt, seed) combination, including
# images skipped because they already exist on disk.
total = len(CHECKPOINTS) * len(PROMPTS) * len(SEEDS)
done = 0
# The LoRA network is built from the first checkpoint and hooked into the base
# models once; later checkpoints only swap its weights in.
lora_model = None

for epoch_label, lora_path in CHECKPOINTS.items():
    print(f"\n{'='*60}")
    print(f"Checkpoint: {epoch_label}  ({lora_path})")
    print(f"{'='*60}")

    weights_sd = load_file(lora_path)

    if lora_model is None:
        # The network's module layout is derived from the first checkpoint's
        # weights. Later checkpoints must match it because of strict=True below.
        lora_model, _ = lora_flux.create_network_from_weights(
            LORA_SCALE, None, ae, [clip_l, t5xxl], model, weights_sd, True
        )
        lora_model.apply_to([clip_l, t5xxl], model)

    lora_model.load_state_dict(weights_sd, strict=True)
    # load_state_dict copies the tensors into the network, so the loaded dict
    # is no longer needed; release it before generating.
    del weights_sd
    lora_model.eval()
    lora_model.to(DEVICE, dtype=DTYPE)
    clip_l.to(DEVICE, dtype=DTYPE)
    t5xxl.to(DEVICE, dtype=DTYPE)
    model.to(DEVICE, dtype=DTYPE)

    for test_name, prompt in PROMPTS.items():
        out_dir = Path(OUTPUT_BASE) / epoch_label / test_name
        out_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n  [{test_name}]")
        print(f"  Prompt: {prompt}")

        for seed in SEEDS:
            img_path = out_dir / f"seed_{seed:04d}.png"
            # Resumable: images from an earlier run are kept, not regenerated.
            if img_path.exists():
                done += 1
                print(f"    seed {seed} — already exists, skipping ({done}/{total})")
                continue

            image = generate(prompt, seed)
            image.save(img_path)
            done += 1
            print(f"    seed {seed} — saved ({done}/{total})")

print(f"\nDone. All images saved to {OUTPUT_BASE}/")