"""Train Sage-T2I locally on STL-10 (5000 real photos)."""
# NOTE(review): `gc` is not referenced in the visible portion of this script.
import os, sys, math, time, gc, ssl
# SECURITY NOTE(review): this swaps the default HTTPS context factory for the
# unverified one, so every stdlib HTTPS request made by this process skips
# certificate verification. Presumably a workaround for a broken CA bundle
# during the dataset download -- confirm, and prefer fixing the CA store.
ssl._create_default_https_context = ssl._create_unverified_context
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from PIL import Image
import warnings
# Silences every warning category for the whole process, including
# deprecation warnings raised by the libraries imported here.
warnings.filterwarnings("ignore")
# Directory containing this script; the sibling `model` package is imported
# from here, so it has to be on sys.path before the imports below.
BASE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, BASE)

from model.config import DiTConfig
from model.dit import DiT

# Cap PyTorch's CPU thread pools (3 intra-op threads, 1 inter-op thread).
torch.set_num_threads(3)
torch.set_num_interop_threads(1)

# Checkpoints, sample images and the encoding cache are all written here.
OUTPUT_DIR = os.path.join(BASE, "checkpoints")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# CPU-optimized config: smaller image, fewer images, fewer timesteps
config = DiTConfig(
    hidden_size=384,
    num_layers=12,
    num_heads=6,
    image_size=128,
)

# Optimisation hyperparameters.
BATCH = 1
GACC = 4  # micro-batches accumulated per optimizer step
LR = 1e-4
EPOCHS = 100
NUM_TIMESTEPS = 100

# Cadence (in epochs) for checkpoints and samples; LOG_EVERY is in batches.
SAVE_EVERY = 10
SAMPLE_EVERY = 10
LOG_EVERY = 20

MAX_IMAGES = 500  # Use subset for CPU speed
from diffusers import AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer


def _freeze(module):
    """Switch *module* to eval mode, disable grads on its parameters, return it."""
    module.eval()
    for param in module.parameters():
        param.requires_grad_(False)
    return module


# The VAE and the CLIP text encoder are used only as fixed feature
# extractors, so both are frozen right after loading.
print("Loading VAE + CLIP...")
vae = _freeze(
    AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32)
)
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=torch.float32
)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
_freeze(text_encoder)
print(f"Loaded (CPU mode, {MAX_IMAGES} images)")
# Download STL-10 (5000 real labeled images)
print(f"Loading STL-10 (using up to {MAX_IMAGES} of 5000 photos)...")
stl = datasets.STL10(root=os.path.join(BASE, "training_data"), split="train", download=True)
print(f"STL-10: {len(stl)} images, classes={stl.classes}")

# Resize/crop to the resolution the model is configured for, instead of a
# second hard-coded 128 that could drift away from `config.image_size`.
# Pixels are then mapped from [0, 1] to [-1, 1] by the Normalize step.
_image_size = config.image_size
transform = transforms.Compose([
    transforms.Resize(_image_size, interpolation=transforms.InterpolationMode.LANCZOS),
    transforms.CenterCrop(_image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
# Pre-encode images + auto-caption with class names
#
# Every training image is turned into a scaled VAE latent and every caption
# into a CLIP hidden-state sequence exactly once; the result is cached on disk
# so later runs skip the encoders.
cache_path = os.path.join(OUTPUT_DIR, "encoded.pt")
num_images = min(len(stl), MAX_IMAGES)

all_latents = all_embeds = None
if os.path.exists(cache_path):
    print("Loading cached encodings...")
    data = torch.load(cache_path, weights_only=True)
    all_latents, all_embeds = data["latents"], data["embeds"]
    # A cache written with a different MAX_IMAGES would otherwise be reused
    # silently; rebuild it when the sample count does not match.
    if len(all_latents) != num_images or len(all_embeds) != num_images:
        print(f"Cache holds {len(all_latents)} samples but {num_images} are requested; re-encoding...")
        all_latents = all_embeds = None

if all_latents is None:
    print(f"Encoding {num_images} images...")
    enc_start = time.time()
    to_pil = transforms.ToPILImage()
    # The caption depends only on the class label, so each class is pushed
    # through the text encoder once and the embedding is reused.
    class_embeds = {}
    latent_list, embed_list = [], []
    with torch.no_grad():
        for i in range(num_images):
            img, label = stl[i]
            if not isinstance(img, Image.Image):
                img = to_pil(img)
            img_t = transform(img).unsqueeze(0)
            # 0.18215 is undone again (latent / 0.18215) before VAE decoding
            # in the sampling code.
            latent = vae.encode(img_t).latent_dist.sample() * 0.18215
            latent_list.append(latent.squeeze(0))

            if label not in class_embeds:
                caption = f"a photo of a {stl.classes[label]}"
                tokens = tokenizer(caption, padding="max_length", max_length=77,
                                   truncation=True, return_tensors="pt")
                class_embeds[label] = text_encoder(tokens.input_ids)[0].squeeze(0)
            embed_list.append(class_embeds[label])

            if (i + 1) % 50 == 0:
                print(f" {i+1}/{num_images} ({time.time()-enc_start:.0f}s)", flush=True)

    # torch.stack copies, so rows that shared one per-class tensor become
    # independent entries in the saved array.
    all_latents = torch.stack(latent_list)
    all_embeds = torch.stack(embed_list)
    torch.save({"latents": all_latents, "embeds": all_embeds}, cache_path)
    print(f"Encoded {len(all_latents)} samples in {time.time()-enc_start:.0f}s")
# Instantiate the diffusion transformer and report how many parameters it has.
model = DiT(config)
total = sum(map(torch.numel, model.parameters()))
print(f"DiT: {total:,} ({total/1e6:.1f}M)")
class ImgDS(Dataset):
    """Paired dataset over two index-aligned sequences.

    Item ``i`` is the tuple ``(l[i], e[i])``. In this script ``l`` holds the
    pre-encoded image latents and ``e`` the matching text embeddings.

    Raises:
        ValueError: if ``l`` and ``e`` do not have the same length. Without
            this check the mismatch only surfaces later as an IndexError
            (or silently ignored trailing items) while iterating.
    """

    def __init__(self, l, e):
        if len(l) != len(e):
            raise ValueError(
                f"latents and embeds must have the same length, got {len(l)} and {len(e)}"
            )
        self.l = l
        self.e = e

    def __len__(self):
        return len(self.l)

    def __getitem__(self, i):
        return self.l[i], self.e[i]
# BATCH was defined but never handed to the loader, which therefore always
# used DataLoader's default batch size of 1. Passing it explicitly keeps the
# constant meaningful (behaviour is unchanged while BATCH == 1).
loader = DataLoader(
    ImgDS(all_latents, all_embeds),
    batch_size=BATCH,
    shuffle=True,
    drop_last=True,
)
opt = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)

# Cosine schedule for the cumulative signal level alpha_bar, sampled at
# NUM_TIMESTEPS + 1 points and normalised so that alpha_bar[0] == 1
# (index 0 means "no noise").
ts = torch.linspace(0, 1, NUM_TIMESTEPS + 1)
s = 0.008
alpha_bar = torch.cos((ts + s) / (1 + s) * math.pi / 2) ** 2
alpha_bar = alpha_bar / alpha_bar[0]

print(f"Training: {len(loader)} batches/epoch, {EPOCHS} epochs")

# Bookkeeping shared with the training loop.
best_loss = float('inf')
step = 0  # number of optimizer steps taken so far
train_start = time.time()
# Conditioning used while sampling. The original code referenced `null_emb`
# without ever defining it, so the first sampling epoch raised NameError.
# It is built here as the text-encoder output for the empty caption, padded to
# the same 77 tokens as the training captions.
# NOTE(review): training below always conditions on a real caption (there is
# no caption dropout), so the model never sees this embedding during training;
# confirm whether sampling should use a class caption instead.
with torch.no_grad():
    _null_tokens = tokenizer("", padding="max_length", max_length=77,
                             truncation=True, return_tensors="pt")
    null_emb = text_encoder(_null_tokens.input_ids)[0]

if len(loader) == 0:
    # Otherwise the per-epoch average below would divide by zero.
    raise RuntimeError("No training batches: the encoded dataset is empty.")

for ep in range(EPOCHS):
    model.train()
    total_loss = 0.0
    nb = 0
    pending = 0  # micro-batches accumulated since the last optimizer step
    opt.zero_grad()
    ep_start = time.time()

    for bidx, (latent, embed) in enumerate(loader):
        # One timestep per sample, so the loop stays correct for BATCH > 1
        # (with a batch of one this matches the previous single draw).
        t = torch.randint(0, NUM_TIMESTEPS, (latent.shape[0],))
        t_float = t.float() / NUM_TIMESTEPS
        noise = torch.randn_like(latent)

        # Forward diffusion: noisy = sqrt(ab) * x0 + sqrt(1 - ab) * eps, with
        # both coefficients clamped away from zero as before.
        coef_shape = (-1,) + (1,) * (latent.dim() - 1)
        ab = alpha_bar[t]
        sqrt_ab = ab.clamp(min=1e-8).sqrt().view(coef_shape)
        sqrt_1ab = (1 - ab).clamp(min=1e-8).sqrt().view(coef_shape)
        noisy = sqrt_ab * latent + sqrt_1ab * noise

        # The model is trained to predict the noise that was added.
        pred = model(noisy, t_float, embed)
        loss = F.mse_loss(pred, noise) / GACC
        loss.backward()
        pending += 1

        if pending == GACC:
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            opt.zero_grad()
            step += 1
            pending = 0

        total_loss += loss.item() * GACC  # undo the 1/GACC scaling for logging
        nb += 1
        if (bidx + 1) % LOG_EVERY == 0:
            print(f" E{ep+1} {bidx+1}/{len(loader)} | Loss: {total_loss/nb:.6f} | {time.time()-ep_start:.0f}s", flush=True)

    # Apply gradients from a trailing partial accumulation window; previously
    # they were thrown away by the zero_grad() at the start of the next epoch.
    if pending:
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        opt.zero_grad()
        step += 1

    avg = total_loss / nb
    et = time.time() - ep_start
    print(f"E{ep+1}/{EPOCHS} | Loss: {avg:.6f} | {et:.0f}s | {time.time()-train_start:.0f}s total", flush=True)

    if avg < best_loss:
        best_loss = avg
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "dit_best.pt"))
    if (ep + 1) % SAVE_EVERY == 0:
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f"dit_ep{ep+1}.pt"))

    if (ep + 1) % SAMPLE_EVERY == 0:
        model.eval()
        # Generate at multiple resolutions via pos_embed interpolation
        base_latent_size = config.image_size // 8
        base_grid = base_latent_size // model.patch_size
        trained_pe = model.pos_embed.data  # kept so it can be restored below
        try:
            for sample_size, label in [(config.image_size, "native"), (256, "256"), (512, "512")]:
                sample_latent_size = sample_size // 8
                if sample_latent_size != base_latent_size:
                    # Resample the trained position grid to the new token grid.
                    # The resized table has a different number of tokens, so it
                    # must be swapped in by rebinding `.data`; the former
                    # in-place `copy_` raises on the shape mismatch.
                    # Assumes DiT.forward works with a different token count
                    # once pos_embed matches -- TODO confirm.
                    grid = sample_latent_size // model.patch_size
                    pe = trained_pe.float().reshape(1, base_grid, base_grid, -1)
                    pe = pe.permute(0, 3, 1, 2)
                    pe = F.interpolate(pe, size=(grid, grid), mode="bicubic", align_corners=False)
                    pe = pe.permute(0, 2, 3, 1).reshape(1, -1, model.hidden_size)
                    model.pos_embed.data = pe.to(trained_pe.dtype)
                else:
                    model.pos_embed.data = trained_pe

                with torch.no_grad():
                    z = torch.randn(1, 4, sample_latent_size, sample_latent_size)
                    for ti in range(NUM_TIMESTEPS - 1, -1, -1):
                        t_float = torch.tensor([ti / NUM_TIMESTEPS])
                        pred = model(z, t_float, null_emb)
                        if ti > 0:
                            ratio = alpha_bar[ti].item() / max(alpha_bar[ti - 1].item(), 1e-8)
                            beta = max(1 - ratio, 0.001)
                        else:
                            beta = 0.01
                        # NOTE(review): deterministic update that scales the
                        # predicted noise by beta alone. The usual DDPM mean
                        # uses beta / sqrt(1 - alpha_bar[ti]); left unchanged
                        # until it is confirmed whether this is intentional.
                        z = (z - beta * pred) / math.sqrt(max(1 - beta, 1e-8))
                    z = z / 0.18215  # undo the latent scaling applied at encode time
                    img = vae.decode(z).sample
                    img = (img.clamp(-1, 1) + 1) / 2
                    img = img.squeeze(0).permute(1, 2, 0).numpy()
                    img = (img * 255).clip(0, 255).astype("uint8")
                    Image.fromarray(img).save(os.path.join(OUTPUT_DIR, f"sample_ep{ep+1}_{label}.png"))
        finally:
            # Always hand the trained embedding back and return to train mode,
            # even if sampling at some resolution fails.
            model.pos_embed.data = trained_pe
            model.train()

torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "dit_final.pt"))
print(f"\nDone! {time.time()-train_start:.0f}s total | Best loss: {best_loss:.6f}", flush=True)
|