Instructions to use madtune/pixeldit-controlnet with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use madtune/pixeldit-controlnet with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```
```python
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained("madtune/pixeldit-controlnet")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "madtune/pixeldit-diffusers", controlnet=controlnet
)
```
- Notebooks
- Google Colab
- Kaggle
| import json | |
| import os | |
| import random | |
| import sys | |
| import time | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image, ImageOps | |
| from torch.utils.data import DataLoader, Dataset | |
| from tqdm import tqdm | |
# Settings --------------------------------------------------------------------

# Data locations.
DATA_DIR = "/home/nobus/Raid0/DataSet/Images1"
HED_DIR = "/home/nobus/Raid0/DataSet/hed_maps_768"
EMBEDDINGS_DIR = "/home/nobus/Raid0/DataSet/embeddings"  # index.json + gemma_embs.npy + siglip_pools.npy

# Checkpoints. main() joins non-absolute paths onto this script's directory.
OUT = "checkpoints/controlnet_scribble_ip_768.pt"
RESUME = None  # optional full adapter checkpoint to resume from
IP_CKPT = "checkpoints/ip_adapter_v2.pt"

# Devices.
DEVICE = "cuda:1"
TRANSFORMER_DEVICE = "cuda:1"

# Model / data.
COPY_BLOCKS_NUM = 7
IMG_SIZE = 768
NUM_WORKERS = 4

# Optimisation.
BATCH_SIZE = 1
STEPS = 10000
LR = 5e-5
WEIGHT_DECAY = 1e-4
GRAD_CLIP = 0.25
SEED = 4243

# Noise schedule / timestep sampling.
FLOW_SHIFT = 4.0
TIMESTEP_STD = 1.0

# Per-step conditioning dropout probabilities: text, IP image features, control map.
TEXT_DROP = 0.10
IP_DROP = 0.05
CTRL_DROP = 0.05

# Logging / checkpointing cadence, in optimiser steps.
LOG_EVERY = 25
SAVE_EVERY = 500
# ----------------------------------------------------------------------------
# Directory holding this script. It is put at the front of sys.path so that the
# local modules imported below (control_maps, controlnet) resolve regardless of
# the launch directory. The insert must therefore run before those imports.
_D = os.path.dirname(os.path.abspath(__file__))
# Two directory levels above the script; not referenced elsewhere in this script.
_ROOT = os.path.abspath(os.path.join(_D, os.pardir, os.pardir))
sys.path.insert(0, _D)

from control_maps import control_to_tensor, hed_to_scribble  # noqa: E402
from controlnet import (  # noqa: E402
    PixelDiTControlNet,
    load_checkpoint,
    load_ip_adapter_checkpoint,
    load_pixeldit,
    save_checkpoint,
)
_T = 1000  # number of discrete training timesteps


def build_schedule(shift, T=_T):
    """Return ``(sigmas, alphas)`` as float32 tensors of length ``T``.

    The unshifted noise level rises linearly from 0.0 (index 0) to 0.999
    (index ``T - 1``). It is then warped by the timestep shift
    ``sigma' = shift * s / (1 + (shift - 1) * s)``. ``alphas`` is the
    complementary signal weight ``1 - sigma'``. All arithmetic is done in
    float64 and only cast to float32 at the end.
    """
    ramp = 1.0 - np.linspace(1.0, 0.001, T, dtype=np.float64)
    warped = (shift * ramp) / (1.0 + (shift - 1.0) * ramp)
    complement = 1.0 - warped
    return torch.from_numpy(warped).float(), torch.from_numpy(complement).float()
def q_sample(x0, t_idx, noise, alphas, sigmas):
    """Mix clean data and noise at the schedule positions chosen by ``t_idx``.

    Computes ``alphas[t_idx] * x0 + sigmas[t_idx] * noise`` with one
    coefficient pair per sample.

    Args:
        x0: clean batch; dimension 0 is the batch dimension.
        t_idx: integer index tensor into the schedules, one entry per sample
            (or a scalar shared by the whole batch).
        noise: tensor with the same shape as ``x0``.
        alphas: 1-D signal-weight schedule (e.g. from ``build_schedule``).
        sigmas: 1-D noise-weight schedule (e.g. from ``build_schedule``).

    Returns:
        Tensor with the shape of ``x0``.
    """
    # Shape coefficients as (B, 1, ..., 1) so they broadcast over every
    # non-batch dimension of x0, whatever its rank. A fixed 4-D view would
    # silently produce a wrong-shaped result for e.g. (B, D) or (B, C, L) input.
    coeff_shape = (-1,) + (1,) * (x0.dim() - 1)
    a = alphas[t_idx].reshape(coeff_shape)
    s = sigmas[t_idx].reshape(coeff_shape)
    return a * x0 + s * noise
class ScribbleIPDataset(Dataset):
    """(image, scribble control, Gemma embedding, SigLIP pool) samples for ControlNet training.

    An image found under ``data_dir`` (searched recursively; extensions
    .jpg/.jpeg/.png/.webp) is kept only if both of these hold:
      * its path string is a key in ``<emb_dir>/index.json``. The value is used
        as the row index into ``gemma_embs.npy`` and ``siglip_pools.npy``.
      * ``<hed_dir>/<image stem>.jpg`` exists.

    Args:
        data_dir: root directory of training images.
        emb_dir: directory containing index.json, gemma_embs.npy and siglip_pools.npy.
        hed_dir: directory of HED edge maps named ``<image stem>.jpg``.
        size: square output resolution for both the image and the control map.
        max_samples: optional cap on the number of kept samples, applied after
            filtering. A falsy value (None or 0) means no cap.

    Raises:
        RuntimeError: if no image has both an embedding row and a HED map.
    """

    def __init__(self, data_dir, emb_dir, hed_dir, size, max_samples=None):
        exts = {".jpg", ".jpeg", ".png", ".webp"}
        all_paths = sorted(
            os.path.join(root, name)
            for root, _, files in os.walk(data_dir)
            for name in files
            if os.path.splitext(name)[1].lower() in exts
        )
        with open(os.path.join(emb_dir, "index.json"), encoding="utf-8") as fh:
            index = json.load(fh)
        # Opened read-only memory-mapped: rows are read lazily per sample rather
        # than loading the full tables into memory.
        self.gemma_mm = np.load(os.path.join(emb_dir, "gemma_embs.npy"), mmap_mode="r")
        self.siglip_mm = np.load(os.path.join(emb_dir, "siglip_pools.npy"), mmap_mode="r")

        self.paths = []
        self.rows = []
        self.hed_paths = []  # resolved once here instead of on every __getitem__
        for p in all_paths:
            row = index.get(str(p))
            if row is None:
                continue
            # HED maps are matched by file stem only, so images sharing a file
            # name in different sub-directories map to the same HED file.
            # NOTE(review): confirm stems are unique across data_dir.
            stem = os.path.splitext(os.path.basename(p))[0]
            hed_path = os.path.join(hed_dir, f"{stem}.jpg")
            if not os.path.exists(hed_path):
                continue
            self.paths.append(p)
            self.rows.append(row)
            self.hed_paths.append(hed_path)

        if max_samples:
            self.paths = self.paths[:max_samples]
            self.rows = self.rows[:max_samples]
            self.hed_paths = self.hed_paths[:max_samples]
        if not self.paths:
            raise RuntimeError(
                f"No images matched embeddings in {emb_dir} and HED maps in {hed_dir}. "
                "Run experiments/SANAInSANE/precompute_hed.py first."
            )
        self.size = size
        self.hed_dir = hed_dir
        print(f"Dataset: {len(self.paths)} images with Gemma, SigLIP, and 768 HED maps")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(img, control, gemma, siglip)`` for sample ``idx``.

        ``img`` is a CHW float32 tensor scaled from uint8 [0, 255] to [-1, 1].
        ``control`` is whatever ``control_to_tensor`` produces for the scribble
        derived from the HED map. ``gemma`` and ``siglip`` are this sample's
        rows of the embedding tables.
        """
        row = self.rows[idx]
        target = (self.size, self.size)

        # Context managers close the underlying file handles deterministically
        # (important with persistent DataLoader workers).
        with Image.open(self.paths[idx]) as src:
            image = ImageOps.fit(src.convert("RGB"), target, method=Image.LANCZOS)
        arr = np.asarray(image, dtype=np.uint8).copy()
        # HWC uint8 -> CHW float32 in [-1, 1].
        img = torch.from_numpy(arr).permute(2, 0, 1).float().div_(127.5).sub_(1.0)

        with Image.open(self.hed_paths[idx]) as src:
            hed_img = ImageOps.fit(src.convert("L"), target, method=Image.LANCZOS)
        hed_arr = np.asarray(hed_img, dtype=np.uint8).copy()
        control = control_to_tensor(hed_to_scribble(hed_arr))

        # .copy() detaches the rows from the read-only memory map.
        gemma = torch.from_numpy(self.gemma_mm[row].copy())
        siglip = torch.from_numpy(self.siglip_mm[row].copy())
        return img, control, gemma, siglip
def collate(batch):
    """Turn a list of ``(img, control, gemma, siglip)`` samples into four batched tensors.

    Each field is stacked along a new leading batch dimension. The result is
    returned as a 4-tuple in the same field order.
    """
    stacked = tuple(map(torch.stack, zip(*batch)))
    imgs, controls, gemmas, siglips = stacked
    return imgs, controls, gemmas, siglips
def _resolve_path(path):
    """Return ``path`` unchanged if absolute, otherwise joined onto the script directory ``_D``."""
    return path if os.path.isabs(path) else os.path.join(_D, path)


def main():
    """Train the HED-scribble ControlNet branch on top of a frozen PixelDiT + IP adapter.

    1. Load the PixelDiT transformer and wrap it in ``PixelDiTControlNet`` with
       ``COPY_BLOCKS_NUM`` copied blocks.
    2. Load the pretrained IP adapter from ``IP_CKPT`` and optionally resume a
       full adapter checkpoint from ``RESUME``.
    3. Freeze every parameter except ``model.controlnet_blocks``. Optimise those
       with AdamW on an MSE objective: the target is ``noise - x0`` and the input
       is ``alphas[t] * x0 + sigmas[t] * noise`` from ``build_schedule``.
    4. Save to ``OUT`` every ``SAVE_EVERY`` steps and once at the end.

    Raises:
        ValueError: if ``IMG_SIZE < 768``.
        RuntimeError: if the data loader would yield no batches.
    """
    if IMG_SIZE < 768:
        raise ValueError("Train at 768+ for PixelDiT; 512 teaches the adapter a bad base-prior regime.")

    # Seed every RNG used below; Python `random` drives the per-step drop decisions.
    torch.manual_seed(SEED)
    random.seed(SEED)
    np.random.seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    print("Loading PixelDiT transformer...")
    transformer = load_pixeldit().to(TRANSFORMER_DEVICE)

    print(f"Building ControlNet ({COPY_BLOCKS_NUM} copied blocks) + frozen IP adapter...")
    model = PixelDiTControlNet(transformer, copy_blocks_num=COPY_BLOCKS_NUM).to(DEVICE)
    # Re-pin the base transformer in case TRANSFORMER_DEVICE differs from DEVICE.
    model.transformer.to(TRANSFORMER_DEVICE)

    ip_path = _resolve_path(IP_CKPT)
    print(f"Loading IP adapter: {ip_path}")
    ip_step = load_ip_adapter_checkpoint(model, ip_path)
    print(f" IP step {ip_step}")

    resume_path = _resolve_path(RESUME) if RESUME else None
    base_step = 0
    if resume_path:
        print(f"Resuming full adapter checkpoint: {resume_path}")
        base_step = load_checkpoint(model, resume_path)
        print(f" step {base_step}")

    # Freeze everything except the HED ControlNet branch. The pretrained IP adapter
    # stays active as conditioning, but it is not updated by this trainer.
    for p in model.parameters():
        p.requires_grad_(False)
    for p in model.controlnet_blocks.parameters():
        p.requires_grad_(True)
    trainable = list(model.controlnet_blocks.parameters())
    n_params = sum(p.numel() for p in trainable) / 1e6
    print(f"Trainable ControlNet params: {n_params:.1f}M")

    dataset = ScribbleIPDataset(DATA_DIR, EMBEDDINGS_DIR, HED_DIR, IMG_SIZE)
    loader_kw = dict(
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        drop_last=True,
        pin_memory=True,
        collate_fn=collate,
    )
    if NUM_WORKERS > 0:
        loader_kw.update(persistent_workers=True, prefetch_factor=2)
    loader = DataLoader(dataset, **loader_kw)
    if len(loader) == 0:
        # With drop_last=True a dataset smaller than one batch yields nothing, and
        # the `while step < STEPS` loop below would spin forever.
        raise RuntimeError(
            f"DataLoader yields no batches: {len(dataset)} samples with "
            f"BATCH_SIZE={BATCH_SIZE} and drop_last=True."
        )

    opt = torch.optim.AdamW(trainable, lr=LR, weight_decay=WEIGHT_DECAY)
    # No GradScaler: BF16 autocast does not require loss scaling, and CUDA
    # GradScaler cannot unscale BF16 grads on this build. Gradients are
    # backpropagated, clipped and stepped directly.

    sigmas_s, alphas_s = build_schedule(FLOW_SHIFT)
    sigmas_s = sigmas_s.to(DEVICE)
    alphas_s = alphas_s.to(DEVICE)
    out_path = _resolve_path(OUT)

    model.train()
    # train() flips every submodule; put the frozen base transformer back in eval mode.
    model.transformer.eval()

    step = 0
    last_saved_step = None
    total_loss = 0.0
    t0 = time.time()
    print(f"\nTraining {STEPS} steps batch={BATCH_SIZE} lr={LR} img={IMG_SIZE} flow_shift={FLOW_SHIFT}")
    pbar = tqdm(total=STEPS, unit="step")
    while step < STEPS:
        for imgs, controls, gemma_embs, siglip_pools in loader:
            if step >= STEPS:
                break
            imgs = imgs.to(DEVICE, non_blocking=True)
            B = imgs.shape[0]

            # Text conditioning; with probability TEXT_DROP it is replaced by an
            # all-zero embedding of the same shape.
            y = gemma_embs.to(DEVICE, dtype=torch.bfloat16, non_blocking=True)
            if TEXT_DROP > 0 and random.random() < TEXT_DROP:
                y = torch.zeros_like(y)

            # Control map; with probability CTRL_DROP its contribution is scaled to 0.
            ref_x = controls.to(DEVICE, dtype=torch.bfloat16, non_blocking=True)
            if CTRL_DROP > 0 and random.random() < CTRL_DROP:
                ctrl_scale = 0.0
            else:
                ctrl_scale = 1.0

            # IP-adapter image features; zeroed with probability IP_DROP.
            siglip_pools = siglip_pools.to(DEVICE, dtype=torch.bfloat16, non_blocking=True)
            ip_features = model.encode_siglip(siglip_pools)
            if IP_DROP > 0 and random.random() < IP_DROP:
                ip_features = torch.zeros_like(ip_features)

            # Logit-normal timestep sampling: u = sigmoid(N(0, TIMESTEP_STD^2)).
            u = torch.sigmoid(torch.randn(B, device=DEVICE) * TIMESTEP_STD)
            t_idx = (u * _T).long().clamp(0, _T - 1)
            noise = torch.randn_like(imgs)
            x_t = q_sample(imgs, t_idx, noise, alphas_s, sigmas_s).to(torch.bfloat16)
            target = (noise - imgs).float()
            t_model = sigmas_s[t_idx] * 1000.0

            opt.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                pred = model(
                    x_t,
                    t_model,
                    y,
                    ref_x,
                    ctrl_scale=ctrl_scale,
                    ip_features=ip_features,
                    ip_strength=1.0,
                ).to(DEVICE)
            loss = F.mse_loss(pred.float(), target)
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(trainable, GRAD_CLIP)
            opt.step()

            total_loss += loss.item()
            step += 1
            global_step = base_step + step
            pbar.update(1)

            if step % LOG_EVERY == 0:
                avg = total_loss / LOG_EVERY
                elapsed = max(time.time() - t0, 1e-6)
                pbar.set_postfix(loss=f"{avg:.4f}", grad=f"{float(grad_norm):.3e}", sps=f"{step / elapsed:.2f}")
                total_loss = 0.0

            if step % SAVE_EVERY == 0 or step == STEPS:
                save_checkpoint(model, out_path, global_step)
                last_saved_step = step
                tqdm.write(f" saved -> {out_path}")
    pbar.close()

    # The in-loop save already fires at step == STEPS; only save here if the
    # latest state has not been written (e.g. STEPS == 0).
    if last_saved_step != step:
        save_checkpoint(model, out_path, base_step + step)
    print(f"\nDone -> {out_path}")


if __name__ == "__main__":
    main()