"""Train Sage-T2I locally on STL-10 (5000 real photos).""" import os, sys, math, time, gc, ssl ssl._create_default_https_context = ssl._create_unverified_context import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from torchvision import transforms, datasets from PIL import Image import warnings warnings.filterwarnings("ignore") BASE = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, BASE) from model.config import DiTConfig from model.dit import DiT torch.set_num_threads(3) torch.set_num_interop_threads(1) OUTPUT_DIR = os.path.join(BASE, "checkpoints") os.makedirs(OUTPUT_DIR, exist_ok=True) # CPU-optimized config: smaller image, fewer images, fewer timesteps config = DiTConfig(hidden_size=384, num_layers=12, num_heads=6, image_size=128) BATCH = 1; GACC = 4; LR = 1e-4; EPOCHS = 100; NUM_TIMESTEPS = 100 SAVE_EVERY = 10; SAMPLE_EVERY = 10; LOG_EVERY = 20 MAX_IMAGES = 500 # Use subset for CPU speed from diffusers import AutoencoderKL from transformers import CLIPTextModel, CLIPTokenizer print("Loading VAE + CLIP...") vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32) vae.eval() for p in vae.parameters(): p.requires_grad_(False) text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float32) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") text_encoder.eval() for p in text_encoder.parameters(): p.requires_grad_(False) print(f"Loaded (CPU mode, {MAX_IMAGES} images)") # Download STL-10 (5000 real labeled images) print(f"Loading STL-10 (using up to {MAX_IMAGES} of 5000 photos)...") stl = datasets.STL10(root=os.path.join(BASE, "training_data"), split="train", download=True) print(f"STL-10: {len(stl)} images, classes={stl.classes}") transform = transforms.Compose([ transforms.Resize(128, interpolation=transforms.InterpolationMode.LANCZOS), transforms.CenterCrop(128), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ]) # Pre-encode images + auto-caption with class names cache_path = os.path.join(OUTPUT_DIR, "encoded.pt") if os.path.exists(cache_path): print("Loading cached encodings...") data = torch.load(cache_path, weights_only=True) all_latents, all_embeds = data["latents"], data["embeds"] else: print(f"Encoding {MAX_IMAGES} images...") enc_start = time.time() all_latents, all_embeds = [], [] for i in range(min(len(stl), MAX_IMAGES)): img, label = stl[i] img = transforms.ToPILImage()(img) if not isinstance(img, Image.Image) else img img_t = transform(img).unsqueeze(0) with torch.no_grad(): latent = vae.encode(img_t).latent_dist.sample() * 0.18215 all_latents.append(latent.squeeze(0)) caption = f"a photo of a {stl.classes[label]}" tokens = tokenizer(caption, padding="max_length", max_length=77, truncation=True, return_tensors="pt") with torch.no_grad(): emb = text_encoder(tokens.input_ids)[0] all_embeds.append(emb.squeeze(0)) if (i+1) % 50 == 0: print(f" {i+1}/{MAX_IMAGES} ({time.time()-enc_start:.0f}s)", flush=True) all_latents = torch.stack(all_latents) all_embeds = torch.stack(all_embeds) torch.save({"latents": all_latents, "embeds": all_embeds}, cache_path) print(f"Encoded {len(all_latents)} samples in {time.time()-enc_start:.0f}s") model = DiT(config) total = sum(p.numel() for p in model.parameters()) print(f"DiT: {total:,} ({total/1e6:.1f}M)") class ImgDS(Dataset): def __init__(self, l, e): self.l = l; self.e = e def __len__(self): return len(self.l) def __getitem__(self, i): return self.l[i], self.e[i] loader = DataLoader(ImgDS(all_latents, all_embeds), shuffle=True, drop_last=True) opt = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9) ts = torch.linspace(0, 1, NUM_TIMESTEPS + 1) s = 0.008 alpha_bar = torch.cos((ts + s) / (1 + s) * math.pi / 2) ** 2 alpha_bar = alpha_bar / alpha_bar[0] print(f"Training: {len(loader)} batches/epoch, {EPOCHS} epochs") best_loss = float('inf'); step = 0; train_start = time.time() for ep in range(EPOCHS): model.train(); total_loss = 0.0; nb = 0 opt.zero_grad(); ep_start = time.time() for bidx, (latent, embed) in enumerate(loader): t = torch.randint(0, NUM_TIMESTEPS, (1,)).item() t_float = torch.tensor([t / NUM_TIMESTEPS]) noise = torch.randn_like(latent) sqrt_ab = math.sqrt(max(alpha_bar[t].item(), 1e-8)) sqrt_1ab = math.sqrt(max(1 - alpha_bar[t].item(), 1e-8)) noisy = sqrt_ab * latent + sqrt_1ab * noise pred = model(noisy, t_float, embed) loss = F.mse_loss(pred, noise) / GACC loss.backward() if (bidx + 1) % GACC == 0: nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step(); opt.zero_grad(); step += 1 total_loss += loss.item() * GACC; nb += 1 if (bidx + 1) % LOG_EVERY == 0: print(f" E{ep+1} {bidx+1}/{len(loader)} | Loss: {total_loss/nb:.6f} | {time.time()-ep_start:.0f}s", flush=True) avg = total_loss / nb et = time.time() - ep_start print(f"E{ep+1}/{EPOCHS} | Loss: {avg:.6f} | {et:.0f}s | {time.time()-train_start:.0f}s total", flush=True) if avg < best_loss: best_loss = avg torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "dit_best.pt")) if (ep + 1) % SAVE_EVERY == 0: torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f"dit_ep{ep+1}.pt")) if (ep + 1) % SAMPLE_EVERY == 0: model.eval() # Generate at multiple resolutions via pos_embed interpolation for sample_size, label in [(128, "native"), (256, "256"), (512, "512")]: sample_latent_size = sample_size // 8 old_latent_size = config.image_size // 8 if sample_latent_size != old_latent_size: old_pe = model.pos_embed.data.clone() pe = old_pe.float().reshape(1, old_latent_size // model.patch_size, old_latent_size // model.patch_size, -1) pe = pe.permute(0, 3, 1, 2) new_ps = sample_latent_size // model.patch_size pe = F.interpolate(pe, size=(new_ps, new_ps), mode="bicubic", align_corners=False) pe = pe.permute(0, 2, 3, 1).reshape(1, -1, model.hidden_size) model.pos_embed.data.copy_(pe.to(model.pos_embed.dtype)) with torch.no_grad(): latent = torch.randn(1, 4, sample_latent_size, sample_latent_size) for ti in range(NUM_TIMESTEPS - 1, -1, -1): t_float = torch.tensor([ti / NUM_TIMESTEPS]) pred = model(latent, t_float, null_emb) beta = max(1 - alpha_bar[ti].item() / max(alpha_bar[max(ti-1,0)].item(), 1e-8), 0.001) if ti > 0 else 0.01 latent = (latent - beta * pred) / math.sqrt(max(1 - beta, 1e-8)) latent = latent / 0.18215 img = vae.decode(latent).sample img = (img.clamp(-1, 1) + 1) / 2 img = img.squeeze(0).permute(1, 2, 0).numpy() img = (img * 255).clip(0, 255).astype("uint8") Image.fromarray(img).save(os.path.join(OUTPUT_DIR, f"sample_ep{ep+1}_{label}.png")) if sample_latent_size != old_latent_size: model.pos_embed.data.copy_(old_pe) model.train() torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "dit_final.pt")) print(f"\nDone! {time.time()-train_start:.0f}s total | Best loss: {best_loss:.6f}", flush=True)