""" train.py – optimize pixel values via gradient descent. Dataset: dataset/ folder containing pairs: dataset/cat.png + dataset/cat.txt (image + its prompt) dataset/sun.png + dataset/sun.txt ... Each .txt holds one line: the prompt for that image. Target images are resized to OUT_SIZE×OUT_SIZE automatically. Usage: python train.py # train with defaults python train.py --epochs 500 --lr 0.05 """ import argparse import os import sys import torch import torch.nn.functional as F import numpy as np from PIL import Image from model import ( MODEL_H, MODEL_W, OUT_SIZE, forward, load_model, save_model, init_model, pixels_to_weights, prompt_to_embedding, ) DATASET_DIR = "dataset" MODEL_PATH = "model.png" def load_dataset(): pairs = [] for fname in sorted(os.listdir(DATASET_DIR)): if not fname.lower().endswith((".png", ".jpg", ".jpeg")): continue stem = os.path.splitext(fname)[0] txt_path = os.path.join(DATASET_DIR, stem + ".txt") img_path = os.path.join(DATASET_DIR, fname) if not os.path.exists(txt_path): print(f" [skip] no .txt for {fname}") continue with open(txt_path) as f: prompt = f.read().strip() img = Image.open(img_path).convert("RGB").resize( (OUT_SIZE, OUT_SIZE), Image.BILINEAR ) target = torch.tensor( np.array(img, dtype=np.float32) / 255.0 ) # (32,32,3) pairs.append((prompt, target)) print(f" loaded: '{prompt}' ← {fname}") return pairs def train(epochs=300, lr=0.03, save_every=50): # init model if missing if not os.path.exists(MODEL_PATH): init_model(MODEL_PATH) raw = load_model(MODEL_PATH) # (H, W, 3) # pixels are the parameters — wrap in a leaf tensor pixels = raw.clone().requires_grad_(True) optimizer = torch.optim.Adam([pixels], lr=lr) print(f"\nLoading dataset from '{DATASET_DIR}/'…") dataset = load_dataset() if not dataset: sys.exit("No dataset pairs found. Add image+txt pairs to dataset/") print(f"\nTraining epochs={epochs} lr={lr} pairs={len(dataset)}\n") for epoch in range(1, epochs + 1): optimizer.zero_grad() total_loss = torch.tensor(0.0) for prompt, target in dataset: # pixels must stay in [0,1] for the sign/magnitude encoding p_clamped = pixels.clamp(0, 1) pred = forward(p_clamped, prompt) # (32,32,3) loss = F.mse_loss(pred, target) total_loss = total_loss + loss total_loss.backward() optimizer.step() # clamp pixels back to valid range after update with torch.no_grad(): pixels.clamp_(0, 1) if epoch % save_every == 0 or epoch == 1: save_model(pixels, MODEL_PATH) avg = total_loss.item() / len(dataset) print(f" epoch {epoch:>5}/{epochs} loss={avg:.5f} → saved model.png") save_model(pixels, MODEL_PATH) print(f"\nDone. Final model saved to {MODEL_PATH}") if __name__ == "__main__": p = argparse.ArgumentParser() p.add_argument("--epochs", type=int, default=300) p.add_argument("--lr", type=float, default=0.03) p.add_argument("--save-every", type=int, default=50) args = p.parse_args() train(args.epochs, args.lr, args.save_every)