| """Train Sage-T2I locally on STL-10 (5000 real photos).""" |
import gc
import math
import os
import ssl
import sys
import time
import warnings

# NOTE(review): this swaps the stdlib's default HTTPS context for one that does
# NOT verify certificates, for the whole process (anything using the default
# context, e.g. urllib). Presumably it exists so the STL-10 download works on a
# machine with a broken CA bundle — consider fixing the local certificates and
# removing it, since it makes downloads vulnerable to interception.
ssl._create_default_https_context = ssl._create_unverified_context

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from PIL import Image

# Suppress all Python warnings for the remainder of the run.
warnings.filterwarnings("ignore")

# Make the project-local ``model`` package importable no matter what the
# current working directory is.
_script_path = os.path.abspath(__file__)
BASE = os.path.dirname(_script_path)
sys.path.insert(0, BASE)
from model.config import DiTConfig
from model.dit import DiT

# Cap CPU parallelism: 3 intra-op threads, 1 inter-op thread.
torch.set_num_threads(3)
torch.set_num_interop_threads(1)
|
|
# Checkpoints, preview samples and the encoding cache are all written here.
OUTPUT_DIR = os.path.join(BASE, "checkpoints")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# DiT backbone. image_size appears to be in pixels: the sampling code uses
# image_size // 8 as the latent side length.
config = DiTConfig(
    hidden_size=384,
    num_layers=12,
    num_heads=6,
    image_size=128,
)

# Intended batch size, gradient-accumulation micro-steps, learning rate.
BATCH, GACC, LR = 1, 4, 1e-4
# Training length (epochs) and number of diffusion timesteps.
EPOCHS, NUM_TIMESTEPS = 100, 100
# Checkpoint / preview cadence in epochs, and log cadence in batches.
SAVE_EVERY = SAMPLE_EVERY = 10
LOG_EVERY = 20
# Upper bound on how many STL-10 images are encoded and trained on.
MAX_IMAGES = 500
|
|
from diffusers import AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

# Frozen pretrained encoders: the SD VAE maps pixels <-> latents and the CLIP
# text model embeds captions. Each is switched to eval mode and has gradients
# disabled on every parameter (Module.requires_grad_ returns the module).
print("Loading VAE + CLIP...")
vae = (
    AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32)
    .eval()
    .requires_grad_(False)
)
text_encoder = (
    CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float32)
    .eval()
    .requires_grad_(False)
)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
print(f"Loaded (CPU mode, {MAX_IMAGES} images)")
|
|
| |
print(f"Loading STL-10 (using up to {MAX_IMAGES} of 5000 photos)...")
stl = datasets.STL10(
    root=os.path.join(BASE, "training_data"),
    split="train",
    download=True,
)
print(f"STL-10: {len(stl)} images, classes={stl.classes}")

# Pixel preprocessing before VAE encoding: Lanczos-resize the short side to
# 128, centre-crop to 128x128, convert to a tensor, map [0, 1] -> [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(128, interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
|
|
| |
# Encode each training image once (VAE latent + CLIP caption embedding) and
# cache the tensors on disk, since both encoders are frozen and slow on CPU.
cache_path = os.path.join(OUTPUT_DIR, "encoded.pt")
n_encode = min(len(stl), MAX_IMAGES)
latent_hw = config.image_size // 8  # latent side length, as used by the sampler

all_latents = all_embeds = None
if os.path.exists(cache_path):
    print("Loading cached encodings...")
    data = torch.load(cache_path, weights_only=True)
    cached_n = len(data["latents"])
    cached_hw = data["latents"].shape[-1]
    # A cache written with a different MAX_IMAGES or image size is stale;
    # silently reusing it would train on the wrong data.
    if cached_n == n_encode and cached_hw == latent_hw:
        all_latents, all_embeds = data["latents"], data["embeds"]
    else:
        print(f"Cached encodings ({cached_n} samples, latent {cached_hw}) do not match "
              f"current settings ({n_encode} samples, latent {latent_hw}); re-encoding...")
    del data

if all_latents is None:
    print(f"Encoding {n_encode} images...")
    enc_start = time.time()
    to_pil = transforms.ToPILImage()
    # Captions depend only on the class label, so each distinct caption is
    # run through CLIP once and reused.
    caption_embeds = {}
    latents, embeds = [], []
    with torch.no_grad():
        for i in range(n_encode):
            img, label = stl[i]
            if not isinstance(img, Image.Image):
                img = to_pil(img)
            img_t = transform(img).unsqueeze(0)
            # 0.18215 is the SD VAE latent scaling factor (undone before decode).
            latent = vae.encode(img_t).latent_dist.sample() * 0.18215
            latents.append(latent.squeeze(0))
            if label not in caption_embeds:
                caption = f"a photo of a {stl.classes[label]}"
                tokens = tokenizer(caption, padding="max_length", max_length=77,
                                   truncation=True, return_tensors="pt")
                caption_embeds[label] = text_encoder(tokens.input_ids)[0].squeeze(0)
            embeds.append(caption_embeds[label])
            if (i + 1) % 50 == 0:
                print(f" {i+1}/{n_encode} ({time.time()-enc_start:.0f}s)", flush=True)
    all_latents = torch.stack(latents)
    all_embeds = torch.stack(embeds)
    torch.save({"latents": all_latents, "embeds": all_embeds}, cache_path)
    print(f"Encoded {len(all_latents)} samples in {time.time()-enc_start:.0f}s")
|
|
# Diffusion transformer trained from scratch; report its parameter count.
model = DiT(config)
total = sum(map(torch.numel, model.parameters()))
print(f"DiT: {total:,} ({total/1e6:.1f}M)")
|
|
class ImgDS(Dataset):
    """Map-style dataset over pre-encoded (latent, text-embedding) pairs.

    Args:
        l: indexable of latents, one entry per training image.
        e: indexable of text embeddings aligned index-for-index with ``l``.

    Raises:
        ValueError: if ``l`` and ``e`` differ in length. Without this check a
            shorter ``e`` fails with IndexError mid-epoch and a longer ``e``
            has its extra entries silently ignored.
    """

    def __init__(self, l, e):
        if len(l) != len(e):
            raise ValueError(f"latents/embeds length mismatch: {len(l)} != {len(e)}")
        self.l = l
        self.e = e

    def __len__(self):
        return len(self.l)

    def __getitem__(self, i):
        return self.l[i], self.e[i]
|
|
# Shuffled loader over the cached encodings, honouring the BATCH setting
# (previously DataLoader's default batch size of 1 was used regardless).
loader = DataLoader(ImgDS(all_latents, all_embeds), batch_size=BATCH,
                    shuffle=True, drop_last=True)
# AdamW is the standard optimizer for DiT at LR=1e-4; plain SGD with momentum
# at that learning rate makes very little progress on a from-scratch transformer.
opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.0)
|
|
# ---------------------------------------------------------------------------
# Noise schedule
# ---------------------------------------------------------------------------
# Cosine schedule. alpha_bar[k] is the signal-variance fraction after k forward
# steps: alpha_bar[0] == 1 (clean latent), alpha_bar[NUM_TIMESTEPS] ~ 0.
# Per-step betas are clipped at 0.999 and alpha_bar is rebuilt from them so the
# first reverse step does not divide by ~0.
ts = torch.linspace(0, 1, NUM_TIMESTEPS + 1)
s = 0.008
alpha_bar = torch.cos((ts + s) / (1 + s) * math.pi / 2) ** 2
alpha_bar = alpha_bar / alpha_bar[0]
betas = (1 - alpha_bar[1:] / alpha_bar[:-1]).clamp(max=0.999)  # betas[k-1] belongs to step k
alpha_bar = torch.cat([torch.ones(1), torch.cumprod(1 - betas, dim=0)])

# Conditioning for the periodic preview samples. It uses the same
# "a photo of a <class>" template as every training caption, so the preview is
# in-distribution (no empty prompt is ever seen during training).
SAMPLE_CAPTION = f"a photo of a {stl.classes[0]}"
with torch.no_grad():
    _sample_tokens = tokenizer(SAMPLE_CAPTION, padding="max_length", max_length=77,
                               truncation=True, return_tensors="pt")
    sample_emb = text_encoder(_sample_tokens.input_ids)[0]


def _add_noise(latent):
    """Forward-diffuse a batch of clean latents, one random timestep per sample.

    Timesteps are drawn from [1, NUM_TIMESTEPS]. At index 0 alpha_bar == 1, so
    the target noise is essentially absent from the input and cannot be
    predicted. NUM_TIMESTEPS is the near-pure-noise level the sampler starts from.

    Returns:
        (noisy, noise, t_float) with t_float = t / NUM_TIMESTEPS, shape (B,).
    """
    b = latent.shape[0]
    t = torch.randint(1, NUM_TIMESTEPS + 1, (b,))
    ab = alpha_bar[t].view(b, *([1] * (latent.dim() - 1)))
    noise = torch.randn_like(latent)
    noisy = ab.sqrt() * latent + (1 - ab).sqrt() * noise
    return noisy, noise, t.float() / NUM_TIMESTEPS


def _resize_pos_embed(pe, old_grid, new_grid):
    """Bicubically resample a positional table from old_grid**2 to new_grid**2 tokens.

    Expects ``pe`` laid out as (1, old_grid * old_grid, D), matching how the
    table is reshaped into a square grid here.
    """
    dim = pe.shape[-1]
    grid = pe.float().reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim).to(pe.dtype)


@torch.no_grad()
def _sample_latent(latent_size, emb):
    """Draw one latent by ancestral DDPM sampling over t = NUM_TIMESTEPS .. 1.

    Uses the same schedule and t / NUM_TIMESTEPS conditioning as training.
    """
    x = torch.randn(1, 4, latent_size, latent_size)
    for t in range(NUM_TIMESTEPS, 0, -1):
        eps = model(x, torch.tensor([t / NUM_TIMESTEPS]), emb)
        beta, ab_t, ab_prev = betas[t - 1], alpha_bar[t], alpha_bar[t - 1]
        # Posterior mean: (x_t - beta_t / sqrt(1 - ab_t) * eps) / sqrt(1 - beta_t)
        x = (x - beta / (1 - ab_t).sqrt() * eps) / (1 - beta).sqrt()
        if t > 1:
            # Posterior variance: beta_t * (1 - ab_{t-1}) / (1 - ab_t)
            x = x + (beta * (1 - ab_prev) / (1 - ab_t)).sqrt() * torch.randn_like(x)
    return x


def _save_samples(ep):
    """Decode preview images at 128/256/512 px into OUTPUT_DIR.

    Larger sizes temporarily swap in a resized positional table. The trained
    table and train mode are always restored, even if sampling fails.
    """
    train_grid = (config.image_size // 8) // model.patch_size
    trained_pe = model.pos_embed.data
    model.eval()
    try:
        for sample_size, tag in [(128, "native"), (256, "256"), (512, "512")]:
            latent_size = sample_size // 8
            grid = latent_size // model.patch_size
            # Rebind .data instead of copy_(): the token count changes, and an
            # in-place copy into the trained-size tensor would raise a shape error.
            # NOTE(review): assumes DiT.forward derives its token grid from the
            # input latent rather than from config.image_size — confirm.
            model.pos_embed.data = (trained_pe if grid == train_grid
                                    else _resize_pos_embed(trained_pe, train_grid, grid))
            latent = _sample_latent(latent_size, sample_emb)
            with torch.no_grad():
                img = vae.decode(latent / 0.18215).sample
            img = ((img.clamp(-1, 1) + 1) / 2).squeeze(0).permute(1, 2, 0).numpy()
            img = (img * 255).clip(0, 255).astype("uint8")
            Image.fromarray(img).save(os.path.join(OUTPUT_DIR, f"sample_ep{ep+1}_{tag}.png"))
    finally:
        model.pos_embed.data = trained_pe
        model.train()


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
print(f"Training: {len(loader)} batches/epoch, {EPOCHS} epochs")
if len(loader) == 0:
    raise SystemExit("No training batches: dataset is smaller than one batch (drop_last=True).")
best_loss = float('inf'); step = 0; train_start = time.time()

for ep in range(EPOCHS):
    model.train()
    total_loss, nb = 0.0, 0
    opt.zero_grad()
    ep_start = time.time()

    for bidx, (latent, embed) in enumerate(loader):
        noisy, noise, t_float = _add_noise(latent)
        pred = model(noisy, t_float, embed)
        loss = F.mse_loss(pred, noise) / GACC  # scaled for gradient accumulation
        loss.backward()
        # Step every GACC micro-batches, and on the final batch so a trailing
        # partial accumulation is not thrown away by the next zero_grad().
        if (bidx + 1) % GACC == 0 or bidx + 1 == len(loader):
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step(); opt.zero_grad(); step += 1
        total_loss += loss.item() * GACC; nb += 1
        if (bidx + 1) % LOG_EVERY == 0:
            print(f" E{ep+1} {bidx+1}/{len(loader)} | Loss: {total_loss/nb:.6f} | {time.time()-ep_start:.0f}s", flush=True)

    avg = total_loss / nb
    et = time.time() - ep_start
    print(f"E{ep+1}/{EPOCHS} | Loss: {avg:.6f} | {et:.0f}s | {time.time()-train_start:.0f}s total", flush=True)

    if avg < best_loss:
        best_loss = avg
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "dit_best.pt"))
    if (ep + 1) % SAVE_EVERY == 0:
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f"dit_ep{ep+1}.pt"))
    if (ep + 1) % SAMPLE_EVERY == 0:
        _save_samples(ep)

torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "dit_final.pt"))
print(f"\nDone! {time.time()-train_start:.0f}s total | Best loss: {best_loss:.6f}", flush=True)
|
|