| """ |
| train.py – optimize pixel values via gradient descent. |
| |
| Dataset: dataset/ folder containing pairs: |
| dataset/cat.png + dataset/cat.txt (image + its prompt) |
| dataset/sun.png + dataset/sun.txt |
| ... |
| |
| Each .txt holds one line: the prompt for that image. |
| Target images are resized to OUT_SIZE×OUT_SIZE automatically. |
| |
| Usage: |
| python train.py # train with defaults |
| python train.py --epochs 500 --lr 0.05 |
| """ |
|
|
| import argparse |
| import os |
| import sys |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from PIL import Image |
|
|
| from model import ( |
| MODEL_H, MODEL_W, OUT_SIZE, |
| forward, load_model, save_model, init_model, |
| pixels_to_weights, prompt_to_embedding, |
| ) |
|
|
# Folder scanned for <stem>.<png|jpg|jpeg> + <stem>.txt training pairs.
DATASET_DIR: str = "dataset"
# File the trainable pixels are loaded from and saved back to.
MODEL_PATH: str = "model.png"
|
|
|
|
def load_dataset(dataset_dir=None, size=None):
    """Load every (prompt, target image) training pair from the dataset folder.

    Each image file (``.png``, ``.jpg`` or ``.jpeg``, extension matched
    case-insensitively) is paired with the ``.txt`` file that has the same
    stem in the same folder. Images without a matching ``.txt`` are reported
    and skipped. Files are visited in sorted name order, so the returned
    order is deterministic.

    Args:
        dataset_dir: Folder to scan. Defaults to the module-level
            ``DATASET_DIR`` (looked up at call time).
        size: Side length in pixels that every image is resized to.
            Defaults to ``OUT_SIZE``.

    Returns:
        A list of ``(prompt, target)`` tuples. ``prompt`` is the stripped
        text of the ``.txt`` file. ``target`` is a float32 tensor of shape
        ``(size, size, 3)`` holding RGB values scaled to ``[0, 1]``. The list
        is empty when the folder does not exist or has no usable pairs.
    """
    if dataset_dir is None:
        dataset_dir = DATASET_DIR
    if size is None:
        size = OUT_SIZE

    try:
        names = sorted(os.listdir(dataset_dir))
    except (FileNotFoundError, NotADirectoryError):
        # Return "no pairs" so the caller can print its friendly message
        # instead of crashing with a traceback.
        print(f" [warn] dataset folder '{dataset_dir}' not found")
        return []

    pairs = []
    for fname in names:
        if not fname.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        img_path = os.path.join(dataset_dir, fname)
        if not os.path.isfile(img_path):
            continue  # e.g. a sub-folder whose name ends in ".png"
        stem = os.path.splitext(fname)[0]
        txt_path = os.path.join(dataset_dir, stem + ".txt")
        if not os.path.exists(txt_path):
            print(f" [skip] no .txt for {fname}")
            continue
        # Decode explicitly instead of relying on the platform locale;
        # "utf-8-sig" also drops a BOM that some editors prepend.
        with open(txt_path, encoding="utf-8-sig") as f:
            prompt = f.read().strip()
        # The context manager closes the source file immediately;
        # convert()/resize() return new in-memory images.
        with Image.open(img_path) as src:
            img = src.convert("RGB").resize((size, size), Image.BILINEAR)
        target = torch.tensor(np.array(img, dtype=np.float32) / 255.0)
        pairs.append((prompt, target))
        print(f" loaded: '{prompt}' ← {fname}")
    return pairs
|
|
|
|
def train(epochs=300, lr=0.03, save_every=50):
    """Fit the model's pixel values to the dataset by gradient descent.

    The pixels stored at ``MODEL_PATH`` are the trainable parameters; the
    file is created with ``init_model`` if it does not exist yet. Each epoch:
    1. Run ``forward`` once per (prompt, target) pair on the pixels clamped
       to ``[0, 1]``.
    2. Sum the per-pair MSE losses.
    3. Take one Adam step.
    4. Clamp the pixels back into ``[0, 1]``.

    Args:
        epochs: Number of full passes over the dataset.
        lr: Adam learning rate.
        save_every: Save the model and print the mean per-pair loss every
            this many epochs; epoch 1 is always saved and reported. A value
            ``<= 0`` disables the periodic saves. The final save after the
            last epoch always happens.

    Exits the process via ``sys.exit`` when the dataset has no usable pairs.
    """
    # Validate the dataset before touching the model file, so a run with
    # nothing to train on does not create model.png as a side effect.
    print(f"\nLoading dataset from '{DATASET_DIR}/'…")
    dataset = load_dataset()
    if not dataset:
        sys.exit("No dataset pairs found. Add image+txt pairs to dataset/")

    if not os.path.exists(MODEL_PATH):
        init_model(MODEL_PATH)
    raw = load_model(MODEL_PATH)

    # Optimise a copy of the stored pixels directly as the parameter tensor.
    pixels = raw.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([pixels], lr=lr)

    print(f"\nTraining epochs={epochs} lr={lr} pairs={len(dataset)}\n")

    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        total_loss = torch.tensor(0.0)

        for prompt, target in dataset:
            p_clamped = pixels.clamp(0, 1)
            pred = forward(p_clamped, prompt)
            total_loss = total_loss + F.mse_loss(pred, target)

        total_loss.backward()
        optimizer.step()

        # The optimizer step can leave values outside [0, 1]; project them
        # back in place without recording the operation in autograd.
        with torch.no_grad():
            pixels.clamp_(0, 1)

        # Guard the modulo: save_every == 0 would raise ZeroDivisionError.
        periodic = save_every > 0 and epoch % save_every == 0
        if periodic or epoch == 1:
            save_model(pixels, MODEL_PATH)
            # The reported loss was measured before this epoch's update.
            avg = total_loss.item() / len(dataset)
            print(f" epoch {epoch:>5}/{epochs} loss={avg:.5f} → saved {MODEL_PATH}")

    save_model(pixels, MODEL_PATH)
    print(f"\nDone. Final model saved to {MODEL_PATH}")
|
|
|
|
if __name__ == "__main__":
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--epochs", type=int, default=300,
                   help="number of passes over the dataset (default: %(default)s)")
    p.add_argument("--lr", type=float, default=0.03,
                   help="Adam learning rate (default: %(default)s)")
    p.add_argument("--save-every", type=int, default=50,
                   help="save the model and log the loss every N epochs "
                        "(default: %(default)s)")
    args = p.parse_args()
    # Fail fast with a usage message rather than deep inside training;
    # --save-every 0 would otherwise reach `epoch % 0`.
    if args.epochs < 0:
        p.error("--epochs must be >= 0")
    if args.save_every < 1:
        p.error("--save-every must be >= 1")
    train(args.epochs, args.lr, args.save_every)
|
|