# madtune's picture
# Add ControlNet + IP-Adapter weights, HED detector, training scripts
# 1d1bd5d
# Raw
# History Blame Contribute Delete
# 10.3 kB
import json
import os
import random
import sys
import time
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
# Settings --------------------------------------------------------------------
# All knobs for this trainer live here; main() reads them as module globals.

# Root folder that is walked recursively for .jpg/.jpeg/.png/.webp training images.
DATA_DIR = "/home/nobus/Raid0/DataSet/Images1"
# Folder of precomputed HED maps, looked up as "<image stem>.jpg".
HED_DIR = "/home/nobus/Raid0/DataSet/hed_maps_768"
EMBEDDINGS_DIR = "/home/nobus/Raid0/DataSet/embeddings" # index.json + gemma_embs.npy + siglip_pools.npy
# Checkpoint output path; a relative path is resolved against this script's directory.
OUT = "checkpoints/controlnet_scribble_ip_768.pt"
# Optional checkpoint to resume from (same path resolution as OUT); None starts fresh.
RESUME = None
# IP-adapter checkpoint that is loaded before training and kept frozen.
IP_CKPT = "checkpoints/ip_adapter_v2.pt"
# DEVICE hosts the ControlNet wrapper, the batches and the schedule tables;
# TRANSFORMER_DEVICE hosts the base transformer.
DEVICE = "cuda:1"
TRANSFORMER_DEVICE = "cuda:1"
# Forwarded to PixelDiTControlNet as copy_blocks_num.
COPY_BLOCKS_NUM = 7
BATCH_SIZE = 1
# Number of optimizer steps to run in this invocation (counted from 0 even when resuming).
STEPS = 10000
# AdamW learning rate and weight decay.
LR = 5e-5
WEIGHT_DECAY = 1e-4
# Side length of the square training crops; main() refuses values below 768.
IMG_SIZE = 768
# Shift factor handed to build_schedule().
FLOW_SHIFT = 4.0
# Seed applied to the random, numpy and torch generators.
SEED = 4243
# Std of the normal sample that is squashed by a sigmoid to pick the training timestep.
TIMESTEP_STD = 1.0
# Per-step probabilities of dropping a conditioning signal:
#   TEXT_DROP - replace the text embedding with an all-zero embedding
#   IP_DROP   - zero out the IP-adapter features
#   CTRL_DROP - run the step with ctrl_scale=0.0
TEXT_DROP = 0.10
IP_DROP = 0.05
CTRL_DROP = 0.05
# Max gradient norm passed to clip_grad_norm_.
GRAD_CLIP = 0.25
# Progress-bar refresh interval and checkpoint interval, both in steps.
LOG_EVERY = 25
SAVE_EVERY = 500
# DataLoader worker processes; values > 0 also enable persistent workers and prefetching.
NUM_WORKERS = 4
# ----------------------------------------------------------------------------
# Directory containing this script; relative checkpoint paths are resolved against it.
_D = os.path.dirname(os.path.abspath(__file__))
# Two directory levels above the script; not referenced elsewhere in the visible code.
_ROOT = os.path.abspath(os.path.join(_D, "../.."))
# Put the script directory first on sys.path so the sibling modules below can be
# imported regardless of the current working directory. This must run before the
# imports that follow.
sys.path.insert(0, _D)
from control_maps import control_to_tensor, hed_to_scribble
from controlnet import (
    PixelDiTControlNet,
    load_checkpoint,
    load_ip_adapter_checkpoint,
    load_pixeldit,
    save_checkpoint,
)
# Number of discrete steps in the noise schedule tables.
_T = 1000


def build_schedule(shift, T=_T):
    """Build the shifted noise/signal weight tables used during training.

    A linear ramp ``1 - linspace(1.0, 0.001, T)`` (0.0 up to 0.999) is warped
    by ``shift * s / (1 + (shift - 1) * s)`` to obtain the noise weights; the
    signal weights are their complement.

    Args:
        shift: Warp factor; ``1.0`` leaves the linear ramp unchanged.
        T: Number of entries in each table.

    Returns:
        ``(sigmas, alphas)``: two float32 tensors of length ``T``, where
        ``alphas`` is computed as ``1 - sigmas`` in float64 before the cast.
    """
    ramp = 1.0 - np.linspace(1.0, 0.001, T, dtype=np.float64)
    numerator = shift * ramp
    denominator = 1.0 + (shift - 1.0) * ramp
    noise_weight = numerator / denominator
    signal_weight = 1.0 - noise_weight
    return tuple(
        torch.from_numpy(table).float() for table in (noise_weight, signal_weight)
    )
def q_sample(x0, t_idx, noise, alphas, sigmas):
    """Blend clean samples with noise using per-sample schedule weights.

    Args:
        x0: Clean batch. The looked-up weights are reshaped to
            ``(-1, 1, 1, 1)``, so they broadcast over a 4-D batch.
        t_idx: Integer tensor of schedule indices, one per sample.
        noise: Noise tensor broadcastable against ``x0``.
        alphas: 1-D table of signal weights.
        sigmas: 1-D table of noise weights.

    Returns:
        ``alphas[t_idx] * x0 + sigmas[t_idx] * noise``.
    """
    per_sample = (-1, 1, 1, 1)
    signal_w = alphas[t_idx].view(*per_sample)
    noise_w = sigmas[t_idx].view(*per_sample)
    return signal_w * x0 + noise_w * noise
class ScribbleIPDataset(Dataset):
    """Images paired with a HED-derived scribble map and precomputed embeddings.

    An image under ``data_dir`` is kept only when both conditions hold:

    * its path (exactly as produced by ``os.walk``/``os.path.join``) is a key
      of ``<emb_dir>/index.json``; the mapped value is used to index the
      memory-mapped ``gemma_embs.npy`` and ``siglip_pools.npy`` arrays;
    * ``<hed_dir>/<image stem>.jpg`` exists.

    Args:
        data_dir: Root folder searched recursively for images.
        emb_dir: Folder holding ``index.json``, ``gemma_embs.npy`` and
            ``siglip_pools.npy``.
        hed_dir: Folder holding the precomputed HED maps.
        size: Side length of the square crops returned by ``__getitem__``.
        max_samples: If truthy, keep only the first ``max_samples`` matches
            (in sorted path order).

    Raises:
        RuntimeError: If no image satisfies both conditions.
    """

    _IMAGE_EXTS = frozenset({".jpg", ".jpeg", ".png", ".webp"})

    def __init__(self, data_dir, emb_dir, hed_dir, size, max_samples=None):
        all_paths = sorted(
            os.path.join(root, name)
            for root, _, files in os.walk(data_dir)
            for name in files
            if os.path.splitext(name)[1].lower() in self._IMAGE_EXTS
        )
        with open(os.path.join(emb_dir, "index.json"), encoding="utf-8") as fh:
            index = json.load(fh)
        # Memory-map the embedding tables so only the requested rows are read.
        self.gemma_mm = np.load(os.path.join(emb_dir, "gemma_embs.npy"), mmap_mode="r")
        self.siglip_mm = np.load(os.path.join(emb_dir, "siglip_pools.npy"), mmap_mode="r")
        self.size = size
        self.hed_dir = hed_dir
        self.paths = []
        self.rows = []
        for p in all_paths:
            row = index.get(str(p))
            if row is None:
                continue
            if not os.path.exists(self._hed_path(p)):
                continue
            self.paths.append(p)
            self.rows.append(row)
        if max_samples:
            self.paths = self.paths[:max_samples]
            self.rows = self.rows[:max_samples]
        if not self.paths:
            raise RuntimeError(
                f"No images matched embeddings in {emb_dir} and HED maps in {hed_dir}. "
                "Run experiments/SANAInSANE/precompute_hed.py first."
            )
        print(f"Dataset: {len(self.paths)} images with Gemma, SigLIP, and 768 HED maps")

    def _hed_path(self, image_path):
        """Return the HED map path for *image_path*: ``<hed_dir>/<stem>.jpg``."""
        stem = os.path.splitext(os.path.basename(image_path))[0]
        return os.path.join(self.hed_dir, f"{stem}.jpg")

    def _load_square(self, path, mode):
        """Open *path*, convert it to *mode* and fit it to ``size`` x ``size``.

        The file is opened in a ``with`` block so its handle is released as
        soon as the pixel data has been converted, instead of waiting for
        garbage collection inside long-lived DataLoader workers.
        """
        with Image.open(path) as src:
            converted = src.convert(mode)
        return ImageOps.fit(converted, (self.size, self.size), method=Image.LANCZOS)

    def __len__(self):
        """Return the number of usable samples."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(img, control, gemma, siglip)`` for sample *idx*.

        ``img`` is a float CHW tensor scaled from uint8 to ``[-1, 1]``.
        ``control`` is whatever ``control_to_tensor(hed_to_scribble(...))``
        produces from the grayscale HED map. ``gemma`` and ``siglip`` are
        copies of the indexed rows of the memory-mapped embedding arrays.
        """
        path = self.paths[idx]
        row = self.rows[idx]

        image = self._load_square(path, "RGB")
        arr = np.asarray(image, dtype=np.uint8).copy()
        # HWC uint8 -> CHW float in [-1, 1].
        img = torch.from_numpy(arr).permute(2, 0, 1).float().div_(127.5).sub_(1.0)

        hed_img = self._load_square(self._hed_path(path), "L")
        hed_arr = np.asarray(hed_img, dtype=np.uint8).copy()
        scribble = hed_to_scribble(hed_arr)
        control = control_to_tensor(scribble)

        # .copy() detaches the rows from the read-only memory map.
        gemma = torch.from_numpy(self.gemma_mm[row].copy())
        siglip = torch.from_numpy(self.siglip_mm[row].copy())
        return img, control, gemma, siglip
def collate(batch):
    """Stack a list of ``(img, control, gemma, siglip)`` samples into batch tensors.

    Args:
        batch: Sequence of 4-tuples of tensors, one tuple per sample.

    Returns:
        A 4-tuple of tensors, each with a new leading batch dimension, in the
        same field order as the samples.
    """
    imgs, controls, gemmas, siglips = zip(*batch)
    return tuple(torch.stack(field) for field in (imgs, controls, gemmas, siglips))
def _resolve_path(path):
    """Return *path* unchanged if absolute, else joined onto this script's directory."""
    return path if os.path.isabs(path) else os.path.join(_D, path)


def main():
    """Train the ControlNet branch on HED scribbles with a frozen IP adapter.

    Steps:
      1. Seed all RNGs and load the base transformer, the ControlNet wrapper,
         the IP-adapter weights and (optionally) a resume checkpoint.
      2. Freeze every parameter except ``model.controlnet_blocks``.
      3. Run ``STEPS`` optimizer steps of MSE regression under BF16 autocast,
         randomly dropping text / control / IP conditioning per step.
      4. Save a checkpoint every ``SAVE_EVERY`` steps and once more at the end.

    Raises:
        ValueError: If ``IMG_SIZE`` is below 768.
        RuntimeError: If the DataLoader would yield no batches (dataset
            smaller than ``BATCH_SIZE`` with ``drop_last=True``); without
            this check the outer training loop would never terminate.
    """
    if IMG_SIZE < 768:
        raise ValueError("Train at 768+ for PixelDiT; 512 teaches the adapter a bad base-prior regime.")

    torch.manual_seed(SEED)
    random.seed(SEED)
    np.random.seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision("high")

    print("Loading PixelDiT transformer...")
    transformer = load_pixeldit().to(TRANSFORMER_DEVICE)
    print(f"Building ControlNet ({COPY_BLOCKS_NUM} copied blocks) + frozen IP adapter...")
    model = PixelDiTControlNet(transformer, copy_blocks_num=COPY_BLOCKS_NUM).to(DEVICE)
    # The .to(DEVICE) above also moved the wrapped transformer; move it back.
    model.transformer.to(TRANSFORMER_DEVICE)

    ip_path = _resolve_path(IP_CKPT)
    print(f"Loading IP adapter: {ip_path}")
    ip_step = load_ip_adapter_checkpoint(model, ip_path)
    print(f" IP step {ip_step}")

    resume_path = _resolve_path(RESUME) if RESUME else None
    base_step = 0
    if resume_path:
        print(f"Resuming full adapter checkpoint: {resume_path}")
        base_step = load_checkpoint(model, resume_path)
        print(f" step {base_step}")

    # Freeze everything except the HED ControlNet branch. The pretrained IP adapter
    # stays active as conditioning, but it is not updated by this trainer.
    for p in model.parameters():
        p.requires_grad_(False)
    for p in model.controlnet_blocks.parameters():
        p.requires_grad_(True)
    trainable = list(model.controlnet_blocks.parameters())
    n_params = sum(p.numel() for p in trainable) / 1e6
    print(f"Trainable ControlNet params: {n_params:.1f}M")

    dataset = ScribbleIPDataset(DATA_DIR, EMBEDDINGS_DIR, HED_DIR, IMG_SIZE)
    loader_kw = dict(
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        drop_last=True,
        pin_memory=True,
        collate_fn=collate,
    )
    if NUM_WORKERS > 0:
        # These two options are only valid with worker processes.
        loader_kw.update(persistent_workers=True, prefetch_factor=2)
    loader = DataLoader(dataset, **loader_kw)
    if len(loader) == 0:
        # drop_last=True discards the only (partial) batch, so the
        # `while step < STEPS` loop below would spin forever.
        raise RuntimeError(
            f"DataLoader yields no batches: {len(dataset)} samples with "
            f"BATCH_SIZE={BATCH_SIZE} and drop_last=True."
        )

    opt = torch.optim.AdamW(trainable, lr=LR, weight_decay=WEIGHT_DECAY)
    # BF16 autocast does not require gradient scaling, and CUDA GradScaler
    # cannot unscale BF16 grads on this build. The disabled scaler keeps the
    # usual scale/unscale/step/update call sequence as pass-throughs.
    scaler = torch.amp.GradScaler("cuda", enabled=False)

    sigmas_s, alphas_s = build_schedule(FLOW_SHIFT)
    sigmas_s = sigmas_s.to(DEVICE)
    alphas_s = alphas_s.to(DEVICE)
    # All-zero stand-in for a dropped text embedding.
    # NOTE(review): assumes text embeddings are (B, 300, 2304) - confirm against gemma_embs.npy.
    null_y = torch.zeros(1, 300, 2304, dtype=torch.bfloat16, device=DEVICE)

    out_path = _resolve_path(OUT)
    # Make sure the checkpoint folder exists before the first save, in case
    # save_checkpoint does not create it itself.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    model.train()
    model.transformer.eval()  # the base transformer stays in eval mode
    step = 0
    total_loss = 0.0
    t0 = time.time()
    print(f"\nTraining {STEPS} steps batch={BATCH_SIZE} lr={LR} img={IMG_SIZE} flow_shift={FLOW_SHIFT}")

    # Context manager closes the progress bar even if a step raises.
    with tqdm(total=STEPS, unit="step") as pbar:
        # The outer loop restarts the DataLoader (a new epoch) until STEPS is reached.
        while step < STEPS:
            for imgs, controls, gemma_embs, siglip_pools in loader:
                if step >= STEPS:
                    break

                imgs = imgs.to(DEVICE, non_blocking=True)
                B = imgs.shape[0]

                # Conditioning dropout: each decision is made once per step
                # and applies to the whole batch.
                y = gemma_embs.to(DEVICE, dtype=torch.bfloat16, non_blocking=True)
                if TEXT_DROP > 0 and random.random() < TEXT_DROP:
                    y = null_y.expand(B, -1, -1)

                ref_x = controls.to(DEVICE, dtype=torch.bfloat16, non_blocking=True)
                if CTRL_DROP > 0 and random.random() < CTRL_DROP:
                    ctrl_scale = 0.0
                else:
                    ctrl_scale = 1.0

                siglip_pools = siglip_pools.to(DEVICE, dtype=torch.bfloat16, non_blocking=True)
                ip_features = model.encode_siglip(siglip_pools)
                if IP_DROP > 0 and random.random() < IP_DROP:
                    ip_features = torch.zeros_like(ip_features)

                # Timestep: sigmoid of a scaled normal sample, mapped to a table index.
                u = torch.sigmoid(torch.randn(B, device=DEVICE) * TIMESTEP_STD)
                t_idx = (u * _T).long().clamp(0, _T - 1)
                noise = torch.randn_like(imgs)
                x_t = q_sample(imgs, t_idx, noise, alphas_s, sigmas_s).to(torch.bfloat16)
                # Regression target is noise minus the clean image, kept in float32.
                target = (noise - imgs).float()
                # The model receives the noise weight scaled by 1000 as its timestep input.
                t_model = sigmas_s[t_idx] * 1000.0

                opt.zero_grad(set_to_none=True)
                with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                    pred = model(
                        x_t,
                        t_model,
                        y,
                        ref_x,
                        ctrl_scale=ctrl_scale,
                        ip_features=ip_features,
                        ip_strength=1.0,
                    ).to(DEVICE)
                loss = F.mse_loss(pred.float(), target)

                scaler.scale(loss).backward()
                scaler.unscale_(opt)
                grad_norm = torch.nn.utils.clip_grad_norm_(trainable, GRAD_CLIP)
                scaler.step(opt)
                scaler.update()

                total_loss += loss.item()
                step += 1
                global_step = base_step + step
                pbar.update(1)

                if step % LOG_EVERY == 0:
                    avg = total_loss / LOG_EVERY
                    elapsed = max(time.time() - t0, 1e-6)
                    pbar.set_postfix(loss=f"{avg:.4f}", grad=f"{float(grad_norm):.3e}", sps=f"{step / elapsed:.2f}")
                    total_loss = 0.0

                if step % SAVE_EVERY == 0 or step == STEPS:
                    save_checkpoint(model, out_path, global_step)
                    tqdm.write(f" saved -> {out_path}")

    # Final save; also covers STEPS == 0, where the loop body never runs.
    save_checkpoint(model, out_path, base_step + step)
    print(f"\nDone -> {out_path}")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()