|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader |
|
|
from datasets import load_dataset |
|
|
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer |
|
|
from huggingface_hub import HfApi, hf_hub_download |
|
|
from safetensors.torch import save_file, load_file |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from tqdm.auto import tqdm |
|
|
import numpy as np |
|
|
import math |
|
|
import os |
|
|
import json |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Optimisation / schedule hyperparameters ---
BATCH_SIZE = 4  # samples per DataLoader batch (one micro-batch)
GRAD_ACCUM = 2  # micro-batches accumulated before each optimizer step
LR = 1e-4  # base learning rate handed to AdamW
EPOCHS = 10  # upper bound of the epoch loop
MAX_SEQ = 128  # max_length used when tokenizing prompts for T5
MIN_SNR = 5.0  # default gamma of min_snr_weight()
SHIFT = 3.0  # default shift factor of flux_shift() / flux_shift_inverse()
DEVICE = "cuda"  # hard-coded: the script assumes a CUDA device is available
# Prefer bfloat16 when the GPU reports support for it, otherwise use float16.
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

# --- Hub destination and cadence of periodic actions (in optimizer steps) ---
HF_REPO = "AbstractPhil/tiny-flux"  # Hub repo receiving checkpoints, logs and samples
SAVE_EVERY = 1000  # write a local checkpoint every N steps
# NOTE(review): the upload check appears to sit inside the save branch of the
# training loop, so uploads only happen on steps that also save - confirm.
UPLOAD_EVERY = 1000
SAMPLE_EVERY = 500  # generate and log sample images every N steps
LOG_EVERY = 10  # write TensorBoard scalars every N steps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Which checkpoint load_checkpoint() should restore. Accepted forms:
# None / "latest", "best", an int step, "hub:<name>", "local:<path>", "none".
LOAD_TARGET = "latest"

# When not None, replaces the step counter returned by load_checkpoint().
RESUME_STEP = None

# Local output locations (relative to the working directory).
CHECKPOINT_DIR, LOG_DIR, SAMPLE_DIR = (
    "./tiny_flux_checkpoints",
    "./tiny_flux_logs",
    "./tiny_flux_samples",
)

# Create all output directories up front; existing ones are left untouched.
for _output_dir in (CHECKPOINT_DIR, LOG_DIR, SAMPLE_DIR):
    os.makedirs(_output_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hub client used for every upload/listing later in the script.
print("Setting up HuggingFace Hub...")
api = HfApi()

# Best effort: make sure the target repo exists. A failure is reported but
# does not stop the run.
try:
    api.create_repo(repo_id=HF_REPO, exist_ok=True, repo_type="model")
except Exception as e:
    print(f"Note: {e}")
else:
    print(f"✓ Repo ready: {HF_REPO}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Each run logs into its own sub-directory of LOG_DIR, named after the
# local wall-clock time at start-up (format YYYYMMDD_HHMMSS).
run_name = datetime.now().strftime("%Y%m%d_%H%M%S")
# Module-level writer; save_samples() and the training loop log through it.
writer = SummaryWriter(log_dir=os.path.join(LOG_DIR, run_name))
print(f"✓ Tensorboard: {LOG_DIR}/{run_name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Training data: the "train" split of the teacher-latents dataset on the Hub.
# collate() reads the "latent" and "prompt" fields of each record.
print("\nLoading dataset...")
ds = load_dataset("AbstractPhil/flux-schnell-teacher-latents", split="train")
print(f"Samples: {len(ds)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Text encoders used by encode_prompt(). Both are loaded in DTYPE, moved to
# DEVICE and put in eval mode; they are never trained.
print("\nLoading flan-t5-base (768 dim)...")
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE)
t5_enc = t5_enc.to(DEVICE).eval()

print("Loading CLIP-L...")
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE)
clip_enc = clip_enc.to(DEVICE).eval()

# Freeze every parameter of both encoders.
t5_enc.requires_grad_(False)
clip_enc.requires_grad_(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# VAE used only to decode latents into images inside generate_samples().
print("Loading Flux VAE for samples...")
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=DTYPE,
)
vae = vae.to(DEVICE).eval()

# Decoder is inference-only: freeze all of its parameters.
vae.requires_grad_(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def encode_prompt(prompt):
    """Encode ``prompt`` with both frozen text encoders.

    The prompt is tokenized twice, padded/truncated to ``MAX_SEQ`` tokens for
    T5 and to 77 tokens for CLIP, and run through ``t5_enc`` and ``clip_enc``
    on ``DEVICE`` without tracking gradients.

    Returns:
        A tuple ``(t5_hidden, clip_pooled)``: the T5 ``last_hidden_state`` and
        the CLIP ``pooler_output``. Callers in this file ``squeeze(0)`` both,
        so a leading batch axis of size 1 is presumably present for a single
        string prompt.
    """
    shared_tok_args = dict(padding="max_length", truncation=True, return_tensors="pt")

    # T5 branch: sequence of token embeddings.
    t5_tokens = t5_tok(prompt, max_length=MAX_SEQ, **shared_tok_args).to(DEVICE)
    t5_hidden = t5_enc(
        input_ids=t5_tokens.input_ids,
        attention_mask=t5_tokens.attention_mask,
    ).last_hidden_state

    # CLIP branch: only the pooled vector is used.
    clip_tokens = clip_tok(prompt, max_length=77, **shared_tok_args).to(DEVICE)
    clip_result = clip_enc(
        input_ids=clip_tokens.input_ids,
        attention_mask=clip_tokens.attention_mask,
    )
    return t5_hidden, clip_result.pooler_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def flux_shift(t, s=SHIFT):
    """Apply the Flux timestep shift ``s * t / (1 + (s - 1) * t)``.

    For ``s > 1`` and ``t`` in ``[0, 1]`` the result is never smaller than
    ``t``, i.e. timesteps are pushed towards 1 (the data end in this file's
    convention), while 0 and 1 map to themselves.

    Example: with ``s=3.0``, ``flux_shift(0.5) == 0.75``.
    """
    scaled = s * t
    normaliser = 1 + (s - 1) * t
    return scaled / normaliser
|
|
|
|
|
def flux_shift_inverse(t_shifted, s=SHIFT):
    """Undo :func:`flux_shift`: map a shifted timestep back to the original.

    Solving ``t_shifted = s * t / (1 + (s - 1) * t)`` for ``t`` gives
    ``t_shifted / (s - (s - 1) * t_shifted)``.
    """
    denominator = s - (s - 1) * t_shifted
    return t_shifted / denominator
|
|
|
|
|
def min_snr_weight(t, gamma=MIN_SNR):
    """Min-SNR loss weight ``min(snr, gamma) / snr`` for timesteps ``t``.

    The signal-to-noise ratio is computed as ``(t / (1 - t)) ** 2`` with the
    denominator clamped to at least 1e-5. The weight is 1 while
    ``1e-5 <= snr <= gamma`` and decays as ``gamma / snr`` above that, so
    timesteps close to 1 contribute less to the loss. The final division
    clamps ``snr`` to at least 1e-5 to avoid dividing by zero.

    Args:
        t: tensor of timesteps.
        gamma: SNR cap (5.0 is a typical choice).
    """
    ratio = t / (1 - t).clamp(min=1e-5)
    snr = ratio.pow(2)
    capped_snr = torch.clamp(snr, max=gamma)
    safe_snr = snr.clamp(min=1e-5)
    return capped_snr / safe_snr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=64):
    """Generate sample images with fixed-step Euler integration.

    Flow-matching convention (same as the training loop):
    ``x_t = (1 - t) * noise + t * data`` and ``v = data - noise``, so
    t=0 is pure noise and t=1 is pure data. Integration runs from t=0 to t=1
    in ``num_steps`` equal steps.

    Args:
        model: velocity model, called with the same keyword arguments as in
            the training loop.
        prompts: sequence of prompt strings; one image is produced per prompt.
        num_steps: number of Euler steps.
        guidance_scale: value fed to the model's ``guidance`` input for every
            sample (no classifier-free guidance mixing is done here).
        H, W: latent grid height and width.

    Returns:
        The decoded images from ``vae.decode(...).sample``, rescaled with
        ``x / 2 + 0.5`` and clamped to ``[0, 1]``.

    The model is switched to eval mode for sampling and restored to whatever
    mode it was in before the call, even if sampling raises.
    """
    was_training = model.training
    model.eval()
    try:
        B = len(prompts)
        C = 16  # latent channel count; presumably that of the Flux VAE - TODO confirm

        # Encode prompts one at a time and stack them into batch tensors.
        t5_embeds, clip_pooleds = [], []
        for p in prompts:
            t5_out, clip_pooled = encode_prompt(p)
            t5_embeds.append(t5_out.squeeze(0))
            clip_pooleds.append(clip_pooled.squeeze(0))
        t5_embeds = torch.stack(t5_embeds)
        clip_pooleds = torch.stack(clip_pooleds)

        # Start from pure noise in the flattened (B, H*W, C) token layout.
        x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)

        img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)

        timesteps = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)

        # Constant across steps, so build it once.
        guidance = torch.full((B,), guidance_scale, device=DEVICE, dtype=DTYPE)

        for i in range(num_steps):
            t_curr = timesteps[i]
            t_next = timesteps[i + 1]
            dt = t_next - t_curr

            t_batch = t_curr.expand(B)

            v_cond = model(
                hidden_states=x,
                encoder_hidden_states=t5_embeds,
                pooled_projections=clip_pooleds,
                timestep=t_batch,
                img_ids=img_ids,
                guidance=guidance,
            )

            # Euler update along the predicted velocity.
            x = x + v_cond * dt

        # Back to the (B, C, H, W) layout the VAE expects.
        latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)

        # NOTE(review): only scaling_factor is undone here. If the dataset
        # latents were also shifted by the VAE's shift_factor, that would need
        # to be reverted as well - confirm against the latent-export code.
        latents = latents / vae.config.scaling_factor
        images = vae.decode(latents.to(vae.dtype)).sample
        images = (images / 2 + 0.5).clamp(0, 1)

        return images
    finally:
        model.train(was_training)
|
|
|
|
|
def save_samples(images, prompts, step, save_dir):
    """Save sample images to disk and log them to TensorBoard.

    Args:
        images: batch of images in ``[0, 1]`` (as returned by
            ``generate_samples``), paired positionally with ``prompts``.
        prompts: prompt strings; the first 50 characters of each are embedded
            in the file name.
        step: training step used in file names and as the TensorBoard step.
        save_dir: directory that receives one PNG per image.

    Side effects: writes ``step{step}_{i}_{prompt}.png`` files, and logs an
    image grid ("samples") and the prompts ("sample_prompts") through the
    module-level ``writer``.
    """
    import re

    from torchvision.utils import make_grid, save_image

    # Work in float32 on the CPU: half-precision arithmetic loses 8-bit pixel
    # accuracy, and bfloat16 tensors cannot be converted to NumPy directly.
    if isinstance(images, torch.Tensor):
        images = images.detach().float().cpu()

    for i, (img, prompt) in enumerate(zip(images, prompts)):
        safe_prompt = prompt[:50].replace(" ", "_").replace("/", "-")
        # Also neutralise characters that are invalid or troublesome in file
        # names (backslash, colon, wildcards, quotes, pipes, control chars).
        safe_prompt = re.sub(r'[\\:*?"<>|\x00-\x1f]', "_", safe_prompt)
        path = os.path.join(save_dir, f"step{step}_{i}_{safe_prompt}.png")
        save_image(img, path)

    # One grid image per step, two samples per row.
    grid = make_grid(images, nrow=2, normalize=False)
    writer.add_image("samples", grid, step)

    writer.add_text("sample_prompts", "\n".join(prompts), step)

    print(f" ✓ Saved {len(images)} samples")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def collate(batch):
    """Turn a list of dataset records into one training batch.

    Each record must provide ``"latent"`` and ``"prompt"``. Latents are
    converted to ``DTYPE`` tensors, stacked and moved to ``DEVICE``; prompts
    are encoded one by one with ``encode_prompt`` (which already places its
    outputs on ``DEVICE``).

    Returns:
        dict with keys ``"latents"``, ``"t5_embeds"``, ``"clip_pooled"``
        (stacked tensors) and ``"prompts"`` (list of the raw strings).
    """
    prompts = [record["prompt"] for record in batch]

    latent_tensors = [
        torch.tensor(np.array(record["latent"]), dtype=DTYPE) for record in batch
    ]

    # encode_prompt returns (t5_hidden, clip_pooled), each with a leading
    # axis that is squeezed away before stacking.
    encoded = [encode_prompt(text) for text in prompts]
    t5_list = [t5_out.squeeze(0) for t5_out, _ in encoded]
    clip_list = [clip_pooled.squeeze(0) for _, clip_pooled in encoded]

    return {
        "latents": torch.stack(latent_tensors).to(DEVICE),
        "t5_embeds": torch.stack(t5_list),
        "clip_pooled": torch.stack(clip_list),
        "prompts": prompts,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _unwrap_state_dict(ckpt):
    """Return the model state dict nested under "model"/"state_dict", else ``ckpt`` itself."""
    if isinstance(ckpt, dict):
        for wrapper_key in ("model", "state_dict"):
            if wrapper_key in ckpt:
                return ckpt[wrapper_key]
    return ckpt


def load_weights(path):
    """Load model weights from a ``.safetensors`` or ``.pt`` file.

    Args:
        path: file to read. ``.safetensors`` files are read with
            ``load_file``; ``.pt`` files with ``torch.load``. Any other
            extension is tried as safetensors first, then as a torch pickle.

    Returns:
        The loaded state dict. For torch pickles that wrap the weights under
        a ``"model"`` or ``"state_dict"`` key, the wrapped dict is returned;
        anything else is returned unchanged.

    Raises:
        Whatever ``load_file`` / ``torch.load`` raise for unreadable files.
    """
    if path.endswith(".safetensors"):
        return load_file(path)

    # SECURITY NOTE: weights_only=False unpickles arbitrary objects. Only use
    # this on checkpoint files from a trusted source.
    if path.endswith(".pt"):
        return _unwrap_state_dict(
            torch.load(path, map_location=DEVICE, weights_only=False)
        )

    # Unknown extension: try safetensors, then fall back to a torch pickle.
    try:
        return load_file(path)
    except Exception:
        return _unwrap_state_dict(
            torch.load(path, map_location=DEVICE, weights_only=False)
        )
|
|
|
|
|
def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path):
    """Save a checkpoint locally as a weights file plus a training-state file.

    Two files are written next to each other:

    * ``<stem>.safetensors`` - ``model.state_dict()``.
    * the training state (step, epoch, loss, optimizer and scheduler state
      dicts) via ``torch.save`` - at ``path`` itself, or at ``<stem>.pt`` if
      ``path`` already ends in ``.safetensors`` so the two never collide.

    Only the file extension of ``path`` is swapped; a ``.pt`` substring
    elsewhere in the path is left alone.

    Args:
        model, optimizer, scheduler: objects whose ``state_dict()`` is saved.
        step, epoch, loss: bookkeeping values stored in the state file.
        path: destination of the state file (normally ``*.pt``).

    Returns:
        Path of the written ``.safetensors`` weights file.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    stem, ext = os.path.splitext(path)
    weights_path = stem + ".safetensors"
    state_path = stem + ".pt" if ext == ".safetensors" else path

    # NOTE(review): if `model` is a torch.compile wrapper, its state_dict keys
    # presumably carry an "_orig_mod." prefix - confirm before loading these
    # weights into an uncompiled model.
    save_file(model.state_dict(), weights_path)

    state = {
        "step": step,
        "epoch": epoch,
        "loss": loss,
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }
    torch.save(state, state_path)
    print(f" ✓ Saved checkpoint: step {step}")
    return weights_path
|
|
|
|
|
def upload_checkpoint(weights_path, step, config, include_logs=True):
    """Push a checkpoint and its companion artefacts to the Hub repo.

    In order: the weights file (as ``checkpoints/step_{step}.safetensors``),
    a freshly written ``config.json`` built from ``config.__dict__``, the
    TensorBoard log folder (when ``include_logs`` is true and it exists) and
    the samples folder (when it exists and is non-empty).

    The whole operation is best effort: any exception is reported with a
    warning and swallowed so training can continue.
    """

    def _push_folder(local_dir, remote_dir, message):
        # Upload one local directory tree into the repo.
        api.upload_folder(
            folder_path=local_dir,
            path_in_repo=remote_dir,
            repo_id=HF_REPO,
            commit_message=message,
        )

    try:
        # 1) model weights
        api.upload_file(
            path_or_fileobj=weights_path,
            path_in_repo=f"checkpoints/step_{step}.safetensors",
            repo_id=HF_REPO,
            commit_message=f"Checkpoint step {step}",
        )

        # 2) model config, re-written locally on every upload
        config_path = os.path.join(CHECKPOINT_DIR, "config.json")
        with open(config_path, "w") as f:
            json.dump(config.__dict__, f, indent=2)
        api.upload_file(
            path_or_fileobj=config_path,
            path_in_repo="config.json",
            repo_id=HF_REPO,
        )

        # 3) TensorBoard logs (optional)
        if include_logs and os.path.exists(LOG_DIR):
            _push_folder(LOG_DIR, "logs", f"Logs at step {step}")

        # 4) sample images, only if there is something to send
        if os.path.exists(SAMPLE_DIR) and os.listdir(SAMPLE_DIR):
            _push_folder(SAMPLE_DIR, "samples", f"Samples at step {step}")

        print(f" ✓ Uploaded to {HF_REPO}")
    except Exception as e:
        print(f" ⚠ Upload failed: {e}")
|
|
|
|
|
def _checkpoint_step(name):
    """Return the step number in a ``...step_<N>[.safetensors|.pt]`` name, or None if it does not parse."""
    stem = name.split("step_")[-1].replace(".safetensors", "").replace(".pt", "")
    try:
        return int(stem)
    except ValueError:
        return None


def load_checkpoint(model, optimizer, scheduler, target):
    """Load a checkpoint selected by ``target`` into ``model`` (and, where a
    local training-state file is available, into ``optimizer``/``scheduler``).

    Args:
        model: module receiving the weights via ``load_state_dict``.
        optimizer, scheduler: restored from the ``.pt`` state file when one
            is found locally; Hub downloads restore weights only.
        target:
            None, "latest" - most recent checkpoint (Hub first, then local)
            "best" - best model (Hub ``model.*`` first, then local ``best.*``)
            int (1500) - specific step (Hub first, then local)
            "hub:step_1000" - specific hub checkpoint
            "local:/path/to/file.safetensors" or "local:/path/to/file.pt" - specific local file
            "none" - skip loading, start fresh
            Anything else is tried as a step number, then treated as "latest".

    Returns:
        ``(start_step, start_epoch)``; ``(0, 0)`` when nothing was loaded.

    If the requested checkpoint cannot be found, the function falls back to
    the "latest" search before giving up.

    Raises:
        Errors from ``model.load_state_dict`` / ``load_weights`` on local
        files are not caught and propagate to the caller.
    """
    if target == "none":
        print("Starting fresh (no checkpoint loading)")
        return 0, 0

    start_step, start_epoch = 0, 0

    # ---- Parse the target into (load_mode, load_path) ----
    if target is None or target == "latest":
        load_mode = "latest"
        load_path = None
    elif target == "best":
        load_mode = "best"
        load_path = None
    elif isinstance(target, int):
        load_mode = "step"
        load_path = target
    elif isinstance(target, str) and target.startswith("hub:"):
        load_mode = "hub"
        load_path = target[4:]
    elif isinstance(target, str) and target.startswith("local:"):
        load_mode = "local"
        load_path = target[6:]
    else:
        print(f"Unknown target format: {target}, trying as step number")
        try:
            load_mode = "step"
            load_path = int(target)
        except (TypeError, ValueError, OverflowError):
            load_mode = "latest"
            load_path = None

    # ---- Mode-specific loading ----
    if load_mode == "local":
        if os.path.exists(load_path):
            weights = load_weights(load_path)
            model.load_state_dict(weights)

            # Locate the companion training state.
            if load_path.endswith(".safetensors"):
                # Swap only the extension for the sibling state file.
                state_path = load_path[: -len(".safetensors")] + ".pt"
            elif load_path.endswith(".pt"):
                # A .pt file may carry step/epoch/optimizer state itself.
                # SECURITY NOTE: weights_only=False unpickles arbitrary
                # objects; only load trusted files.
                ckpt = torch.load(load_path, map_location=DEVICE, weights_only=False)
                if isinstance(ckpt, dict):
                    non_tensor_keys = [k for k in ckpt.keys() if not isinstance(ckpt.get(k), torch.Tensor)]
                    if non_tensor_keys:
                        print(f" Checkpoint keys: {non_tensor_keys}")

                    # Accept several common names for the step counter.
                    start_step = ckpt.get("step", ckpt.get("global_step", ckpt.get("iteration", 0)))
                    start_epoch = ckpt.get("epoch", 0)

                    # Some formats nest the counters under "state".
                    if "state" in ckpt and isinstance(ckpt["state"], dict):
                        start_step = ckpt["state"].get("step", start_step)
                        start_epoch = ckpt["state"].get("epoch", start_epoch)

                    if "optimizer" in ckpt:
                        try:
                            optimizer.load_state_dict(ckpt["optimizer"])
                            if "scheduler" in ckpt:
                                scheduler.load_state_dict(ckpt["scheduler"])
                        except Exception as e:
                            print(f" Note: Could not load optimizer state: {e}")
                state_path = None
            else:
                state_path = load_path + ".pt"

            if state_path and os.path.exists(state_path):
                state = torch.load(state_path, map_location=DEVICE, weights_only=False)
                try:
                    start_step = state.get("step", start_step)
                    start_epoch = state.get("epoch", start_epoch)
                    if "optimizer" in state:
                        optimizer.load_state_dict(state["optimizer"])
                    if "scheduler" in state:
                        scheduler.load_state_dict(state["scheduler"])
                except Exception as e:
                    print(f" Note: Could not load optimizer state: {e}")

            print(f"✓ Loaded local: {load_path} (step {start_step})")
            return start_step, start_epoch
        else:
            print(f"⚠ Local file not found: {load_path}")

    elif load_mode == "hub":
        # Build each candidate file name once. A name that already has an
        # extension is used as given (under checkpoints/ unless it contains
        # a folder); otherwise try the known extensions in order.
        if load_path.endswith((".safetensors", ".pt")):
            candidates = [load_path if "/" in load_path else f"checkpoints/{load_path}"]
        else:
            candidates = [f"checkpoints/{load_path}{ext}" for ext in (".safetensors", ".pt", "")]

        last_error = None
        for filename in candidates:
            try:
                local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
                weights = load_weights(local_path)
                model.load_state_dict(weights)
            except Exception as e:
                last_error = e
                continue

            # Recover the step from the name when it encodes one.
            if "step_" in load_path:
                parsed_step = _checkpoint_step(load_path)
                if parsed_step is not None:
                    start_step = parsed_step
            print(f"✓ Loaded from Hub: {filename} (step {start_step})")
            return start_step, start_epoch

        print(f"⚠ Could not load from hub: {load_path}")
        if last_error is not None:
            print(f" Last error: {last_error}")

    elif load_mode == "best":
        # Hub first: the best model is published as model.<ext>.
        for ext in [".safetensors", ".pt"]:
            try:
                filename = f"model{ext}"
                local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
                weights = load_weights(local_path)
                model.load_state_dict(weights)
                print(f"✓ Loaded best model from Hub")
                return start_step, start_epoch
            except Exception:
                continue

        # Then the local best.<ext> written by the training loop.
        for ext in [".safetensors", ".pt"]:
            best_path = os.path.join(CHECKPOINT_DIR, f"best{ext}")
            if os.path.exists(best_path):
                weights = load_weights(best_path)
                model.load_state_dict(weights)

                state_path = best_path.replace(ext, ".pt") if ext == ".safetensors" else best_path
                if os.path.exists(state_path):
                    state = torch.load(state_path, map_location=DEVICE, weights_only=False)
                    if isinstance(state, dict) and "step" in state:
                        start_step = state.get("step", 0)
                        start_epoch = state.get("epoch", 0)
                print(f"✓ Loaded local best (step {start_step})")
                return start_step, start_epoch

    elif load_mode == "step":
        step_num = load_path

        # Hub first (weights only).
        for ext in [".safetensors", ".pt"]:
            try:
                filename = f"checkpoints/step_{step_num}{ext}"
                local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
                weights = load_weights(local_path)
                model.load_state_dict(weights)
                start_step = step_num
                print(f"✓ Loaded step {step_num} from Hub")
                return start_step, start_epoch
            except Exception:
                continue

        # Then local files, which may also restore optimizer/scheduler.
        for ext in [".safetensors", ".pt"]:
            local_path = os.path.join(CHECKPOINT_DIR, f"step_{step_num}{ext}")
            if os.path.exists(local_path):
                weights = load_weights(local_path)
                model.load_state_dict(weights)
                state_path = local_path.replace(".safetensors", ".pt") if ext == ".safetensors" else local_path
                if os.path.exists(state_path):
                    state = torch.load(state_path, map_location=DEVICE, weights_only=False)
                    if isinstance(state, dict):
                        try:
                            if "optimizer" in state:
                                optimizer.load_state_dict(state["optimizer"])
                            if "scheduler" in state:
                                scheduler.load_state_dict(state["scheduler"])
                            start_epoch = state.get("epoch", 0)
                        except Exception as e:
                            print(f" Note: Could not load optimizer state: {e}")
                start_step = step_num
                print(f"✓ Loaded local step {step_num}")
                return start_step, start_epoch
        print(f"⚠ Step {step_num} not found")

    # ---- "latest" search; also the fallback when the modes above fail ----
    # Hub first: pick the checkpoint with the highest step number.
    try:
        files = api.list_repo_files(repo_id=HF_REPO)
        checkpoints = [
            f for f in files
            if f.startswith("checkpoints/step_")
            and f.endswith((".safetensors", ".pt"))
            and _checkpoint_step(f) is not None  # skip names without a numeric step
        ]
        if checkpoints:
            checkpoints.sort(key=_checkpoint_step)
            latest = checkpoints[-1]
            step = _checkpoint_step(latest)
            local_path = hf_hub_download(repo_id=HF_REPO, filename=latest)
            weights = load_weights(local_path)
            model.load_state_dict(weights)
            start_step = step
            print(f"✓ Loaded latest from Hub: step {step}")
            return start_step, start_epoch
    except Exception as e:
        print(f"Hub check: {e}")

    # Then the local checkpoint directory.
    if os.path.exists(CHECKPOINT_DIR):
        local_ckpts = [
            f for f in os.listdir(CHECKPOINT_DIR)
            if f.startswith("step_")
            and f.endswith((".safetensors", ".pt"))
            and _checkpoint_step(f) is not None
        ]

        # A .pt that has a .safetensors sibling is the training-state file,
        # not a weights file: drop it from the candidates.
        local_ckpts = [f for f in local_ckpts if not (f.endswith(".pt") and f.replace(".pt", ".safetensors") in local_ckpts)]
        if local_ckpts:
            local_ckpts.sort(key=_checkpoint_step)
            latest = local_ckpts[-1]
            step = _checkpoint_step(latest)
            weights_path = os.path.join(CHECKPOINT_DIR, latest)
            weights = load_weights(weights_path)
            model.load_state_dict(weights)

            state_path = weights_path.replace(".safetensors", ".pt") if weights_path.endswith(".safetensors") else weights_path
            if os.path.exists(state_path):
                state = torch.load(state_path, map_location=DEVICE, weights_only=False)
                if isinstance(state, dict):
                    try:
                        if "optimizer" in state:
                            optimizer.load_state_dict(state["optimizer"])
                        if "scheduler" in state:
                            scheduler.load_state_dict(state["scheduler"])
                        start_epoch = state.get("epoch", 0)
                    except Exception as e:
                        print(f" Note: Could not load optimizer state: {e}")
            start_step = step
            print(f"✓ Loaded latest local: step {step}")
            return start_step, start_epoch

    print("No checkpoint found, starting fresh")
    return 0, 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# num_workers=0 keeps collate() in the main process; it calls
# encode_prompt(), which runs the text encoders on DEVICE.
loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate, num_workers=0)

# Build the model with its default config and keep the parameters in DTYPE.
config = TinyFluxConfig()
model = TinyFlux(config).to(DEVICE).to(DTYPE)
print(f"\nParams: {sum(p.numel() for p in model.parameters()):,}")
# NOTE(review): from here on `model` is the torch.compile wrapper, so
# state_dict keys presumably carry an "_orig_mod." prefix - confirm before
# loading saved weights into an uncompiled model.
model = torch.compile(model, mode="default")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
opt = torch.optim.AdamW(model.parameters(), lr=LR, betas=(0.9, 0.99), weight_decay=0.01)

# Upper bound on the number of optimizer steps over the whole run.
total_steps = len(loader) * EPOCHS // GRAD_ACCUM
# Linear warm-up for at most 500 steps, capped at 10% of the run.
warmup = min(500, total_steps // 10)


def lr_fn(step):
    """Return the LR multiplier for optimizer step ``step``.

    Linear warm-up from 0 to 1 over ``warmup`` steps, then cosine decay from
    1 to 0 over the remaining ``total_steps - warmup`` steps. Past
    ``total_steps`` the multiplier stays at 0 instead of following the cosine
    back up, which can happen when a resumed run replays part of an epoch.
    """
    if step < warmup:
        return step / warmup
    # Guard against a zero-length decay phase (tiny or empty datasets).
    decay_steps = max(total_steps - warmup, 1)
    progress = min((step - warmup) / decay_steps, 1.0)
    return 0.5 * (1 + math.cos(math.pi * progress))


sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Restore weights (and optimizer/scheduler state when available) according
# to LOAD_TARGET; returns the step/epoch to continue from.
print(f"\nLoad target: {LOAD_TARGET}")
start_step, start_epoch = load_checkpoint(model, opt, sched, LOAD_TARGET)

# Optional manual override of the step counter.
# NOTE(review): only the counter is changed here; `sched` is not advanced to
# match, so the LR schedule position may differ from `start_step` - confirm
# this is intended.
if RESUME_STEP is not None:
    print(f"Overriding start_step: {start_step} -> {RESUME_STEP}")
    start_step = RESUME_STEP

# Record the model config and the training hyperparameters as TensorBoard
# text entries at step 0 of this run.
writer.add_text("config", json.dumps(config.__dict__, indent=2), 0)
writer.add_text("training_config", json.dumps({
    "batch_size": BATCH_SIZE,
    "grad_accum": GRAD_ACCUM,
    "lr": LR,
    "epochs": EPOCHS,
    "min_snr": MIN_SNR,
    "shift": SHIFT,
}, indent=2), 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fixed prompts rendered at every sampling step, so sample images can be
# compared across training steps. save_samples() lays them out two per row.
SAMPLE_PROMPTS = [
    "a photo of a cat sitting on a windowsill",
    "a beautiful sunset over mountains",
    "a portrait of a woman with red hair",
    "a futuristic cityscape at night",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Main training loop: flow-matching objective with Min-SNR weighting,
# gradient accumulation, periodic logging / sampling / checkpointing, and a
# "best epoch" checkpoint pushed to the Hub.
print(f"\nTraining {EPOCHS} epochs, {total_steps} total steps")
print(f"Resuming from step {start_step}, epoch {start_epoch}")
print(f"Save: {SAVE_EVERY}, Upload: {UPLOAD_EVERY}, Sample: {SAMPLE_EVERY}, Log: {LOG_EVERY}")

model.train()
step = start_step  # optimizer-step counter, continues across epochs
best = float("inf")  # best epoch-average loss seen in THIS run (not restored on resume)

for ep in range(start_epoch, EPOCHS):
    ep_loss = 0
    ep_batches = 0
    pbar = tqdm(loader, desc=f"E{ep+1}")

    for i, batch in enumerate(pbar):
        latents = batch["latents"]
        t5 = batch["t5_embeds"]
        clip = batch["clip_pooled"]

        B, C, H, W = latents.shape

        # Flatten the spatial grid into a token sequence: (B, C, H, W) -> (B, H*W, C).
        data = latents.permute(0, 2, 3, 1).reshape(B, H*W, C)
        noise = torch.randn_like(data)

        # Timesteps: sigmoid of a standard normal, pushed towards 1 by
        # flux_shift, then clamped away from the exact endpoints.
        t = torch.sigmoid(torch.randn(B, device=DEVICE))
        t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1-1e-4)

        # Interpolate between noise (t=0) and data (t=1).
        t_expanded = t.view(B, 1, 1)
        x_t = (1 - t_expanded) * noise + t_expanded * data

        # Regression target: the constant velocity from noise to data.
        v_target = data - noise

        img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)

        # Random guidance value per sample, uniform in [1, 5).
        guidance = torch.rand(B, device=DEVICE, dtype=DTYPE) * 4 + 1

        with torch.autocast("cuda", dtype=DTYPE):
            v_pred = model(
                hidden_states=x_t,
                encoder_hidden_states=t5,
                pooled_projections=clip,
                timestep=t,
                img_ids=img_ids,
                guidance=guidance,
            )

        # Per-sample MSE, averaged over tokens and channels.
        loss_raw = F.mse_loss(v_pred, v_target, reduction="none").mean(dim=[1, 2])

        # Min-SNR weighting; dividing by GRAD_ACCUM makes the accumulated
        # gradient an average over the micro-batches.
        snr_weights = min_snr_weight(t)
        loss = (loss_raw * snr_weights).mean() / GRAD_ACCUM
        loss.backward()

        # NOTE(review): `i` restarts every epoch and gradients are only
        # zeroed inside this branch, so if the number of batches is not a
        # multiple of GRAD_ACCUM the leftover gradients carry into the next
        # epoch's first optimizer step.
        if (i + 1) % GRAD_ACCUM == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            sched.step()
            opt.zero_grad()
            step += 1

            # Scalars are taken from the last micro-batch only; the loss is
            # multiplied back by GRAD_ACCUM to undo the scaling above.
            if step % LOG_EVERY == 0:
                writer.add_scalar("train/loss", loss.item() * GRAD_ACCUM, step)
                writer.add_scalar("train/lr", sched.get_last_lr()[0], step)
                writer.add_scalar("train/grad_norm", grad_norm.item(), step)
                writer.add_scalar("train/t_mean", t.mean().item(), step)
                writer.add_scalar("train/snr_weight_mean", snr_weights.mean().item(), step)

            # Periodic sample images for visual inspection.
            if step % SAMPLE_EVERY == 0:
                print(f"\n Generating samples at step {step}...")
                images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
                save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)

            # Periodic local checkpoint.
            if step % SAVE_EVERY == 0:
                ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{step}.pt")
                weights_path = save_checkpoint(model, opt, sched, step, ep, loss.item(), ckpt_path)

                # Upload is only considered on steps that also saved, since
                # it needs the weights file written just above.
                if step % UPLOAD_EVERY == 0:
                    upload_checkpoint(weights_path, step, config, include_logs=True)

        # Running statistics use the unscaled micro-batch loss.
        ep_loss += loss.item() * GRAD_ACCUM
        ep_batches += 1
        pbar.set_postfix(loss=f"{loss.item()*GRAD_ACCUM:.4f}", lr=f"{sched.get_last_lr()[0]:.1e}", step=step)

    # Epoch summary: mean micro-batch loss.
    avg = ep_loss / max(ep_batches, 1)
    print(f"Epoch {ep+1} loss: {avg:.4f}")
    writer.add_scalar("train/epoch_loss", avg, ep + 1)

    # New best epoch: save locally and publish as model.safetensors.
    if avg < best:
        best = avg
        best_path = os.path.join(CHECKPOINT_DIR, "best.pt")
        weights_path = save_checkpoint(model, opt, sched, step, ep, avg, best_path)

        # Best-effort upload; a failure is reported and training continues.
        try:
            api.upload_file(
                path_or_fileobj=weights_path,
                path_in_repo="model.safetensors",
                repo_id=HF_REPO,
                commit_message=f"Best model (epoch {ep+1}, loss {avg:.4f})",
            )
            print(f" ✓ Uploaded best to {HF_REPO}")
        except Exception as e:
            print(f" ⚠ Upload failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# End of training: write the final checkpoint, render one last set of
# samples, and push everything to the Hub.
print("\nSaving final model...")
final_path = os.path.join(CHECKPOINT_DIR, "final.pt")
# The stored "loss" is the best epoch average of this run, not the last loss.
weights_path = save_checkpoint(model, opt, sched, step, EPOCHS, best, final_path)

print("Generating final samples...")
images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)

# Best-effort final upload; any failure is reported and otherwise ignored.
try:
    # NOTE(review): this writes the FINAL weights to "model.safetensors", the
    # same repo path the training loop uses for the best-epoch weights, so
    # the best model on the Hub is replaced - confirm this is intended.
    api.upload_file(path_or_fileobj=weights_path, path_in_repo="model.safetensors", repo_id=HF_REPO)
    config_path = os.path.join(CHECKPOINT_DIR, "config.json")
    with open(config_path, "w") as f:
        json.dump(config.__dict__, f, indent=2)
    api.upload_file(path_or_fileobj=config_path, path_in_repo="config.json", repo_id=HF_REPO)
    api.upload_folder(folder_path=LOG_DIR, path_in_repo="logs", repo_id=HF_REPO)
    api.upload_folder(folder_path=SAMPLE_DIR, path_in_repo="samples", repo_id=HF_REPO)
    print(f"\n✓ Training complete! https://huggingface.co/{HF_REPO}")
except Exception as e:
    print(f"\n⚠ Final upload failed: {e}")

# Flush and close the TensorBoard writer before reporting the result.
writer.close()
print(f"Best loss: {best:.4f}")