"""Train the HED/scribble ControlNet branch of PixelDiT with a frozen IP adapter.

The script reads RGB images, precomputed HED maps, and precomputed Gemma /
SigLIP embeddings, then optimizes only ``model.controlnet_blocks`` with a
flow-matching MSE objective. All configuration lives in the module-level
settings block below.
"""

import json
import os
import random
import sys
import time

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Settings --------------------------------------------------------------------
DATA_DIR = "/home/nobus/Raid0/DataSet/Images1"
HED_DIR = "/home/nobus/Raid0/DataSet/hed_maps_768"
EMBEDDINGS_DIR = "/home/nobus/Raid0/DataSet/embeddings"  # index.json + gemma_embs.npy + siglip_pools.npy
OUT = "checkpoints/controlnet_scribble_ip_768.pt"
RESUME = None
IP_CKPT = "checkpoints/ip_adapter_v2.pt"
DEVICE = "cuda:1"
TRANSFORMER_DEVICE = "cuda:1"
COPY_BLOCKS_NUM = 7
BATCH_SIZE = 1
STEPS = 10000
LR = 5e-5
WEIGHT_DECAY = 1e-4
IMG_SIZE = 768
FLOW_SHIFT = 4.0
SEED = 4243
TIMESTEP_STD = 1.0
TEXT_DROP = 0.10
IP_DROP = 0.05
CTRL_DROP = 0.05
GRAD_CLIP = 0.25
LOG_EVERY = 25
SAVE_EVERY = 500
NUM_WORKERS = 4
# ----------------------------------------------------------------------------

_D = os.path.dirname(os.path.abspath(__file__))
_ROOT = os.path.abspath(os.path.join(_D, "../.."))
# Make the sibling modules (control_maps, controlnet) importable when the
# script is launched from another working directory.
sys.path.insert(0, _D)

from control_maps import control_to_tensor, hed_to_scribble
from controlnet import (
    PixelDiTControlNet,
    load_checkpoint,
    load_ip_adapter_checkpoint,
    load_pixeldit,
    save_checkpoint,
)

# Number of discrete timesteps in the noise schedule.
_T = 1000


def _resolve_path(path):
    """Return *path* as-is when absolute, otherwise relative to this script's directory."""
    return path if os.path.isabs(path) else os.path.join(_D, path)


def build_schedule(shift, T=_T):
    """Build the shifted sigma/alpha schedule.

    Args:
        shift: Shift factor applied as ``shift * s / (1 + (shift - 1) * s)``.
        T: Number of schedule entries.

    Returns:
        ``(sigmas, alphas)`` as float32 tensors of length ``T`` with
        ``alphas == 1 - sigmas``. Index 0 has sigma 0 (the raw sigma runs
        linearly from 0 to 0.999 before the shift is applied).
    """
    betas = np.linspace(1.0, 0.001, T, dtype=np.float64)
    sigmas_raw = 1.0 - betas
    sigmas = shift * sigmas_raw / (1.0 + (shift - 1.0) * sigmas_raw)
    alphas = 1.0 - sigmas
    return torch.from_numpy(sigmas).float(), torch.from_numpy(alphas).float()


def q_sample(x0, t_idx, noise, alphas, sigmas):
    """Return ``alphas[t] * x0 + sigmas[t] * noise`` for per-sample indices ``t_idx``.

    The gathered coefficients are reshaped to ``(-1, 1, 1, 1)`` so they
    broadcast over a 4-D ``x0`` / ``noise``.
    """
    a = alphas[t_idx].view(-1, 1, 1, 1)
    s = sigmas[t_idx].view(-1, 1, 1, 1)
    return a * x0 + s * noise


class ScribbleIPDataset(Dataset):
    """Images paired with a HED map and precomputed Gemma / SigLIP embeddings.

    An image under ``data_dir`` is kept only if its path string is a key of
    ``<emb_dir>/index.json`` and ``<hed_dir>/<stem>.jpg`` exists. The value
    stored in ``index.json`` is used as the row index into both
    ``gemma_embs.npy`` and ``siglip_pools.npy`` (opened memory-mapped).

    NOTE(review): HED maps are matched by file stem only, so two images with
    the same basename in different sub-directories share one HED map.
    """

    def __init__(self, data_dir, emb_dir, hed_dir, size, max_samples=None):
        """Scan ``data_dir`` and keep the samples that have embeddings and a HED map.

        Args:
            data_dir: Root directory walked recursively for images.
            emb_dir: Directory with ``index.json``, ``gemma_embs.npy``, ``siglip_pools.npy``.
            hed_dir: Directory with ``<stem>.jpg`` HED maps.
            size: Side length images and HED maps are fitted to.
            max_samples: If truthy, keep only the first ``max_samples`` matches.

        Raises:
            RuntimeError: If no image has both an index entry and a HED map.
        """
        exts = {".jpg", ".jpeg", ".png", ".webp"}
        all_paths = sorted(
            os.path.join(r, f)
            for r, _, files in os.walk(data_dir)
            for f in files
            if os.path.splitext(f)[1].lower() in exts
        )
        with open(os.path.join(emb_dir, "index.json"), encoding="utf-8") as fh:
            index = json.load(fh)
        self.gemma_mm = np.load(os.path.join(emb_dir, "gemma_embs.npy"), mmap_mode="r")
        self.siglip_mm = np.load(os.path.join(emb_dir, "siglip_pools.npy"), mmap_mode="r")

        self.paths = []
        self.rows = []
        # Resolved once here so __getitem__ does not rebuild the path per sample.
        self.hed_paths = []
        for p in all_paths:
            row = index.get(str(p))
            if row is None:
                continue
            stem = os.path.splitext(os.path.basename(p))[0]
            hed_path = os.path.join(hed_dir, f"{stem}.jpg")
            if not os.path.exists(hed_path):
                continue
            self.paths.append(p)
            self.rows.append(row)
            self.hed_paths.append(hed_path)

        if max_samples:
            self.paths = self.paths[:max_samples]
            self.rows = self.rows[:max_samples]
            self.hed_paths = self.hed_paths[:max_samples]

        if not self.paths:
            raise RuntimeError(
                f"No images matched embeddings in {emb_dir} and HED maps in {hed_dir}. "
                "Run experiments/SANAInSANE/precompute_hed.py first."
            )

        self.size = size
        self.hed_dir = hed_dir
        print(f"Dataset: {len(self.paths)} images with Gemma, SigLIP, and 768 HED maps")

    def __len__(self):
        """Return the number of usable samples."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(img, control, gemma, siglip)`` where ``img`` is a float CHW
            tensor scaled from uint8 to [-1, 1], ``control`` is the result of
            ``control_to_tensor(hed_to_scribble(hed))``, and ``gemma`` /
            ``siglip`` are copies of the memory-mapped embedding rows.
        """
        row = self.rows[idx]
        target_size = (self.size, self.size)

        # Context managers close the underlying file handles deterministically
        # instead of leaving them to garbage collection in the worker process.
        with Image.open(self.paths[idx]) as src:
            image = ImageOps.fit(src.convert("RGB"), target_size, method=Image.LANCZOS)
        arr = np.asarray(image, dtype=np.uint8).copy()
        img = torch.from_numpy(arr).permute(2, 0, 1).float().div_(127.5).sub_(1.0)

        with Image.open(self.hed_paths[idx]) as src:
            hed_img = ImageOps.fit(src.convert("L"), target_size, method=Image.LANCZOS)
        hed_arr = np.asarray(hed_img, dtype=np.uint8).copy()
        scribble = hed_to_scribble(hed_arr)
        control = control_to_tensor(scribble)

        # .copy() detaches the rows from the read-only memory map.
        gemma = torch.from_numpy(self.gemma_mm[row].copy())
        siglip = torch.from_numpy(self.siglip_mm[row].copy())
        return img, control, gemma, siglip


def collate(batch):
    """Stack a list of ``(img, control, gemma, siglip)`` samples into four batched tensors."""
    imgs, controls, gemmas, siglips = zip(*batch)
    return torch.stack(imgs), torch.stack(controls), torch.stack(gemmas), torch.stack(siglips)


def main():
    """Run ControlNet training as configured by the module-level settings.

    Raises:
        ValueError: If ``IMG_SIZE`` is below 768.
        RuntimeError: If the dataset is too small to produce a single batch.
    """
    if IMG_SIZE < 768:
        raise ValueError("Train at 768+ for PixelDiT; 512 teaches the adapter a bad base-prior regime.")

    torch.manual_seed(SEED)
    random.seed(SEED)
    np.random.seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    print("Loading PixelDiT transformer...")
    transformer = load_pixeldit().to(TRANSFORMER_DEVICE)

    print(f"Building ControlNet ({COPY_BLOCKS_NUM} copied blocks) + frozen IP adapter...")
    model = PixelDiTControlNet(transformer, copy_blocks_num=COPY_BLOCKS_NUM).to(DEVICE)
    # Moving the wrapper to DEVICE also moves the transformer; put it back on
    # its own device in case the two settings differ.
    model.transformer.to(TRANSFORMER_DEVICE)

    ip_path = _resolve_path(IP_CKPT)
    print(f"Loading IP adapter: {ip_path}")
    ip_step = load_ip_adapter_checkpoint(model, ip_path)
    print(f" IP step {ip_step}")

    resume_path = _resolve_path(RESUME) if RESUME else None
    base_step = 0
    if resume_path:
        print(f"Resuming full adapter checkpoint: {resume_path}")
        base_step = load_checkpoint(model, resume_path)
        print(f" step {base_step}")

    # Freeze everything except the HED ControlNet branch. The pretrained IP adapter
    # stays active as conditioning, but it is not updated by this trainer.
    for p in model.parameters():
        p.requires_grad_(False)
    for p in model.controlnet_blocks.parameters():
        p.requires_grad_(True)
    trainable = list(model.controlnet_blocks.parameters())
    n_params = sum(p.numel() for p in trainable) / 1e6
    print(f"Trainable ControlNet params: {n_params:.1f}M")

    dataset = ScribbleIPDataset(DATA_DIR, EMBEDDINGS_DIR, HED_DIR, IMG_SIZE)
    loader_kw = dict(
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        drop_last=True,
        pin_memory=True,
        collate_fn=collate,
    )
    if NUM_WORKERS > 0:
        loader_kw.update(persistent_workers=True, prefetch_factor=2)
    loader = DataLoader(dataset, **loader_kw)

    # With drop_last=True a dataset smaller than BATCH_SIZE yields no batches,
    # which would make the `while step < STEPS` loop below spin forever.
    if STEPS > 0 and len(loader) == 0:
        raise RuntimeError(
            f"Dataset has {len(dataset)} samples but BATCH_SIZE={BATCH_SIZE} with drop_last=True "
            "produces no batches; add data or lower BATCH_SIZE."
        )

    opt = torch.optim.AdamW(trainable, lr=LR, weight_decay=WEIGHT_DECAY)
    # BF16 autocast does not require gradient scaling, and CUDA GradScaler
    # cannot unscale BF16 grads on this build.
    scaler = torch.amp.GradScaler("cuda", enabled=False)

    sigmas_s, alphas_s = build_schedule(FLOW_SHIFT)
    sigmas_s = sigmas_s.to(DEVICE)
    alphas_s = alphas_s.to(DEVICE)

    # All-zero text conditioning used when the prompt is dropped.
    # NOTE(review): the (300, 2304) shape presumably matches the stored Gemma
    # embeddings - confirm against gemma_embs.npy.
    null_y = torch.zeros(1, 300, 2304, dtype=torch.bfloat16, device=DEVICE)
    out_path = _resolve_path(OUT)

    model.train()
    model.transformer.eval()

    step = 0
    last_saved_step = None
    total_loss = 0.0
    t0 = time.time()
    print(f"\nTraining {STEPS} steps batch={BATCH_SIZE} lr={LR} img={IMG_SIZE} flow_shift={FLOW_SHIFT}")
    pbar = tqdm(total=STEPS, unit="step")

    while step < STEPS:
        for imgs, controls, gemma_embs, siglip_pools in loader:
            if step >= STEPS:
                break

            imgs = imgs.to(DEVICE, non_blocking=True)
            B = imgs.shape[0]

            # Conditioning dropout: each of text / control / IP is dropped
            # independently, once per batch.
            y = gemma_embs.to(DEVICE, dtype=torch.bfloat16, non_blocking=True)
            if TEXT_DROP > 0 and random.random() < TEXT_DROP:
                y = null_y.expand(B, -1, -1)

            ref_x = controls.to(DEVICE, dtype=torch.bfloat16, non_blocking=True)
            if CTRL_DROP > 0 and random.random() < CTRL_DROP:
                ctrl_scale = 0.0
            else:
                ctrl_scale = 1.0

            siglip_pools = siglip_pools.to(DEVICE, dtype=torch.bfloat16, non_blocking=True)
            ip_features = model.encode_siglip(siglip_pools)
            if IP_DROP > 0 and random.random() < IP_DROP:
                ip_features = torch.zeros_like(ip_features)

            # Logit-normal timestep sampling: sigmoid of a scaled Gaussian,
            # mapped onto the discrete schedule indices.
            u = torch.sigmoid(torch.randn(B, device=DEVICE) * TIMESTEP_STD)
            t_idx = (u * _T).long().clamp(0, _T - 1)
            noise = torch.randn_like(imgs)
            x_t = q_sample(imgs, t_idx, noise, alphas_s, sigmas_s).to(torch.bfloat16)
            # Since x_t = (1 - sigma) * x0 + sigma * noise, the regression
            # target is its derivative with respect to sigma.
            target = (noise - imgs).float()
            t_model = sigmas_s[t_idx] * 1000.0

            opt.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                pred = model(
                    x_t,
                    t_model,
                    y,
                    ref_x,
                    ctrl_scale=ctrl_scale,
                    ip_features=ip_features,
                    ip_strength=1.0,
                ).to(DEVICE)
            loss = F.mse_loss(pred.float(), target)

            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            grad_norm = torch.nn.utils.clip_grad_norm_(trainable, GRAD_CLIP)
            scaler.step(opt)
            scaler.update()

            total_loss += loss.item()
            step += 1
            global_step = base_step + step
            pbar.update(1)

            if step % LOG_EVERY == 0:
                avg = total_loss / LOG_EVERY
                elapsed = max(time.time() - t0, 1e-6)
                pbar.set_postfix(loss=f"{avg:.4f}", grad=f"{float(grad_norm):.3e}", sps=f"{step / elapsed:.2f}")
                total_loss = 0.0

            if step % SAVE_EVERY == 0 or step == STEPS:
                save_checkpoint(model, out_path, global_step)
                tqdm.write(f" saved -> {out_path}")
                last_saved_step = step

    pbar.close()
    # The in-loop save already covers the final step; only write again when
    # it did not (e.g. STEPS <= 0), so the same checkpoint is not saved twice.
    if last_saved_step != step:
        save_checkpoint(model, out_path, base_step + step)
    print(f"\nDone -> {out_path}")


if __name__ == "__main__":
    main()