# sage-t2i / train_local.py
# Uploaded by itriedcoding via huggingface_hub (commit 2d7087a, 7.76 kB).
"""Train Sage-T2I locally on STL-10 (5000 real photos)."""
import os, sys, math, time, gc, ssl
# SECURITY: never disable TLS certificate verification by default. Doing so
# process-wide lets a man-in-the-middle tamper with any HTTPS download that uses
# the stdlib default context. Opt in explicitly (SAGE_INSECURE_SSL=1) only on a
# machine whose CA bundle is known to be broken.
if os.environ.get("SAGE_INSECURE_SSL") == "1":
    ssl._create_default_https_context = ssl._create_unverified_context
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from PIL import Image
import warnings
# Silence all Python warnings for cleaner training logs.
warnings.filterwarnings("ignore")

# Make the local `model` package (next to this script) importable regardless of
# the current working directory.
BASE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, BASE)
from model.config import DiTConfig
from model.dit import DiT

# Keep CPU usage modest: 3 intra-op threads, a single inter-op thread.
torch.set_num_threads(3)
torch.set_num_interop_threads(1)

# Checkpoints, preview samples and the encoding cache are all written here.
OUTPUT_DIR = os.path.join(BASE, "checkpoints")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# CPU-optimized config: smaller image, fewer images, fewer timesteps.
config = DiTConfig(
    hidden_size=384,
    num_layers=12,
    num_heads=6,
    image_size=128,
)
BATCH = 1            # intended samples per forward pass
GACC = 4             # micro-batches accumulated per optimizer step
LR = 1e-4            # optimizer learning rate
EPOCHS = 100
NUM_TIMESTEPS = 100  # diffusion steps for both training and sampling
SAVE_EVERY = 10      # epochs between numbered checkpoints
SAMPLE_EVERY = 10    # epochs between preview samples
LOG_EVERY = 20       # micro-batches between progress lines
MAX_IMAGES = 500     # Use subset for CPU speed
from diffusers import AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

# Pretrained VAE (pixels <-> latents) and CLIP text encoder (captions ->
# embeddings); both are used only as frozen feature extractors.
print("Loading VAE + CLIP...")
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float32)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Inference mode and no gradients for every parameter of both networks.
for frozen in (vae, text_encoder):
    frozen.eval()
    frozen.requires_grad_(False)
print(f"Loaded (CPU mode, {MAX_IMAGES} images)")
# Download STL-10 (5000 real labeled images) into BASE/training_data.
print(f"Loading STL-10 (using up to {MAX_IMAGES} of 5000 photos)...")
stl = datasets.STL10(
    root=os.path.join(BASE, "training_data"),
    split="train",
    download=True,
)
print(f"STL-10: {len(stl)} images, classes={stl.classes}")

# Pixel pipeline: resize the shorter side to 128 (Lanczos), centre-crop to
# 128x128, convert to a [0, 1] tensor, then shift/scale each channel to [-1, 1].
_pixel_steps = [
    transforms.Resize(128, interpolation=transforms.InterpolationMode.LANCZOS),
    transforms.CenterCrop(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_pixel_steps)
# Pre-encode images + auto-caption with class names.
# VAE/CLIP encoding dominates CPU time, so results are cached in OUTPUT_DIR and
# reused on later runs as long as the cached sample count still matches.
cache_path = os.path.join(OUTPUT_DIR, "encoded.pt")
num_to_encode = min(len(stl), MAX_IMAGES)
cached = None
if os.path.exists(cache_path):
    print("Loading cached encodings...")
    cached = torch.load(cache_path, weights_only=True)
    # The cache is keyed only by its path: if MAX_IMAGES changed since it was
    # written, reusing it would silently train on a different subset.
    if len(cached["latents"]) != num_to_encode:
        print(f"Cached encodings hold {len(cached['latents'])} samples, "
              f"expected {num_to_encode}; re-encoding...")
        cached = None
if cached is not None:
    all_latents, all_embeds = cached["latents"], cached["embeds"]
else:
    print(f"Encoding {num_to_encode} images...")
    enc_start = time.time()
    all_latents, all_embeds = [], []
    # Captions depend only on the class label, so each distinct caption goes
    # through CLIP once and its embedding is reused for every image of that class.
    caption_embeds = {}
    with torch.no_grad():
        for i in range(num_to_encode):
            img, label = stl[i]
            if not isinstance(img, Image.Image):
                img = transforms.ToPILImage()(img)
            img_t = transform(img).unsqueeze(0)
            # 0.18215 is the SD VAE latent scaling factor (vae.config.scaling_factor);
            # the sampler divides by the same constant before decoding.
            latent = vae.encode(img_t).latent_dist.sample() * 0.18215
            all_latents.append(latent.squeeze(0))
            if label not in caption_embeds:
                caption = f"a photo of a {stl.classes[label]}"
                tokens = tokenizer(caption, padding="max_length", max_length=77,
                                   truncation=True, return_tensors="pt")
                caption_embeds[label] = text_encoder(tokens.input_ids)[0].squeeze(0)
            all_embeds.append(caption_embeds[label])
            if (i + 1) % 50 == 0:
                print(f" {i+1}/{num_to_encode} ({time.time()-enc_start:.0f}s)", flush=True)
    all_latents = torch.stack(all_latents)
    all_embeds = torch.stack(all_embeds)
    torch.save({"latents": all_latents, "embeds": all_embeds}, cache_path)
    print(f"Encoded {len(all_latents)} samples in {time.time()-enc_start:.0f}s")
# Build the diffusion transformer from the CPU config and report its size.
model = DiT(config)
total = 0
for param in model.parameters():
    total += param.numel()
millions = total / 1e6
print(f"DiT: {total:,} ({millions:.1f}M)")
class ImgDS(Dataset):
    """Dataset pairing each pre-encoded latent with its caption embedding by index."""

    def __init__(self, l, e):
        # l: latents, e: caption embeddings; item i is (l[i], e[i]).
        self.l, self.e = l, e

    def __len__(self):
        # Dataset length follows the latent sequence.
        return len(self.l)

    def __getitem__(self, i):
        latent = self.l[i]
        embed = self.e[i]
        return latent, embed
# Latents and embeddings are already in memory, so a plain shuffled loader is
# enough. batch_size is wired to BATCH explicitly: previously the constant was
# ignored and only the DataLoader default (1) happened to match it.
loader = DataLoader(ImgDS(all_latents, all_embeds), batch_size=BATCH,
                    shuffle=True, drop_last=True)
# NOTE(review): SGD+momentum at lr=1e-4 typically converges much more slowly on
# transformer models than AdamW; kept as-is to preserve training behavior.
opt = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)

# Cosine noise schedule: alpha_bar[t] for t = 0..NUM_TIMESTEPS, normalised so
# alpha_bar[0] == 1 (clean signal) and decaying to ~0 at t = NUM_TIMESTEPS.
ts = torch.linspace(0, 1, NUM_TIMESTEPS + 1)
s = 0.008  # offset s from the improved-DDPM cosine schedule formula
alpha_bar = torch.cos((ts + s) / (1 + s) * math.pi / 2) ** 2
alpha_bar = alpha_bar / alpha_bar[0]
print(f"Training: {len(loader)} batches/epoch, {EPOCHS} epochs")
# Running training state: best epoch loss, optimizer-step count, wall-clock start.
best_loss = float('inf')
step = 0
train_start = time.time()
# Prompt for the periodic preview samples. Training only ever sees
# "a photo of a <class>" captions (never an empty prompt), so a "null"
# embedding would be out of distribution; use a real training caption instead.
SAMPLE_PROMPT = f"a photo of a {stl.classes[0]}"


def _encode_prompt(prompt):
    """Return CLIP's last hidden state for *prompt*, keeping the batch dim of 1.

    Tokenisation matches the cached training embeddings (padded/truncated to 77).
    """
    tokens = tokenizer(prompt, padding="max_length", max_length=77,
                       truncation=True, return_tensors="pt")
    with torch.no_grad():
        return text_encoder(tokens.input_ids)[0]


def _resized_pos_embed(pos_embed, old_grid, new_grid):
    """Bicubically resample a (1, old_grid**2, D) positional table to (1, new_grid**2, D)."""
    pe = pos_embed.float().reshape(1, old_grid, old_grid, -1).permute(0, 3, 1, 2)
    pe = F.interpolate(pe, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
    pe = pe.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, -1)
    return pe.to(pos_embed.dtype)


def _ddim_sample(embed, latent_size):
    """Run deterministic DDIM (eta=0) from Gaussian noise to a clean latent.

    Uses the training parameterisation x_t = sqrt(ab_t)*x0 + sqrt(1-ab_t)*eps,
    with the model predicting eps. Steps visit t = NUM_TIMESTEPS-1 .. 1 (the
    training range); the final step targets alpha_bar[0] == 1, i.e. x0 itself.
    """
    latent = torch.randn(1, 4, latent_size, latent_size)
    for ti in range(NUM_TIMESTEPS - 1, 0, -1):
        ab_t = alpha_bar[ti].clamp(min=1e-8)
        ab_prev = alpha_bar[ti - 1]
        eps = model(latent, torch.tensor([ti / NUM_TIMESTEPS]), embed)
        x0 = (latent - (1 - ab_t).sqrt() * eps) / ab_t.sqrt()
        latent = ab_prev.sqrt() * x0 + (1 - ab_prev).clamp(min=0).sqrt() * eps
    return latent


def _latent_to_image(latent):
    """Decode a scaled VAE latent (batch of 1) into an 8-bit RGB PIL image."""
    img = vae.decode(latent / 0.18215).sample
    img = (img.clamp(-1, 1) + 1) / 2
    img = img.squeeze(0).permute(1, 2, 0).numpy()
    return Image.fromarray((img * 255).clip(0, 255).astype("uint8"))


def _save_samples(epoch, embed):
    """Save one preview image per resolution for *epoch* (1-based) into OUTPUT_DIR.

    Off-native resolutions swap in a bicubically resized pos_embed; the original
    table is always restored, and the model is returned to train mode.
    """
    model.eval()
    native_grid = config.image_size // 8 // model.patch_size
    original_pe = model.pos_embed.data
    try:
        for sample_size, label in [(128, "native"), (256, "256"), (512, "512")]:
            latent_size = sample_size // 8
            grid = latent_size // model.patch_size
            try:
                if grid != native_grid:
                    # Rebind .data rather than copy_(): the resized table has a
                    # different token count, which copy_() cannot broadcast into.
                    model.pos_embed.data = _resized_pos_embed(original_pe, native_grid, grid)
                with torch.no_grad():
                    img = _latent_to_image(_ddim_sample(embed, latent_size))
                img.save(os.path.join(OUTPUT_DIR, f"sample_ep{epoch}_{label}.png"))
            except RuntimeError as err:
                # Off-native sizes assume DiT accepts a different token grid
                # (not visible from this script); a failing preview must not
                # abort a long training run. Native-size failures are real bugs.
                if grid == native_grid:
                    raise
                print(f"  Sample {label} skipped: {err}", flush=True)
            finally:
                model.pos_embed.data = original_pe
    finally:
        model.train()


sample_emb = _encode_prompt(SAMPLE_PROMPT)

for ep in range(EPOCHS):
    model.train()
    total_loss = 0.0
    nb = 0
    opt.zero_grad()
    ep_start = time.time()
    for bidx, (latent, embed) in enumerate(loader):
        # Skip t == 0: alpha_bar[0] == 1 leaves no noise in the input, so the
        # eps target is unrecoverable there. 1..T-1 is exactly the sampler's range.
        t = torch.randint(1, NUM_TIMESTEPS, (1,)).item()
        t_float = torch.tensor([t / NUM_TIMESTEPS])
        noise = torch.randn_like(latent)
        ab = alpha_bar[t].item()
        noisy = math.sqrt(max(ab, 1e-8)) * latent + math.sqrt(max(1 - ab, 1e-8)) * noise
        pred = model(noisy, t_float, embed)
        # Scale by 1/GACC so accumulated gradients average over the micro-batches.
        loss = F.mse_loss(pred, noise) / GACC
        loss.backward()
        if (bidx + 1) % GACC == 0:
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            opt.zero_grad()
            step += 1
        total_loss += loss.item() * GACC
        nb += 1
        if (bidx + 1) % LOG_EVERY == 0:
            print(f" E{ep+1} {bidx+1}/{len(loader)} | Loss: {total_loss/nb:.6f} | {time.time()-ep_start:.0f}s", flush=True)
    # An empty loader (e.g. drop_last with fewer samples than a batch) would
    # otherwise divide by zero; inf also keeps it from becoming the "best" epoch.
    avg = total_loss / nb if nb else float("inf")
    et = time.time() - ep_start
    print(f"E{ep+1}/{EPOCHS} | Loss: {avg:.6f} | {et:.0f}s | {time.time()-train_start:.0f}s total", flush=True)
    if avg < best_loss:
        best_loss = avg
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "dit_best.pt"))
    if (ep + 1) % SAVE_EVERY == 0:
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f"dit_ep{ep+1}.pt"))
    if (ep + 1) % SAMPLE_EVERY == 0:
        _save_samples(ep + 1, sample_emb)
# Persist the last-epoch weights and report overall wall-clock time.
final_path = os.path.join(OUTPUT_DIR, "dit_final.pt")
torch.save(model.state_dict(), final_path)
elapsed = time.time() - train_start
print(f"\nDone! {elapsed:.0f}s total | Best loss: {best_loss:.6f}", flush=True)