pixelmodel / train.py
wop's picture
Upload 23 files
38bfe91 verified
Raw
History Blame Contribute Delete
3.46 kB
"""
train.py – optimize pixel values via gradient descent.
Dataset: dataset/ folder containing pairs:
dataset/cat.png + dataset/cat.txt (image + its prompt)
dataset/sun.png + dataset/sun.txt
...
Each .txt holds one line: the prompt for that image.
Target images are resized to OUT_SIZE×OUT_SIZE automatically.
Usage:
python train.py # train with defaults
python train.py --epochs 500 --lr 0.05
"""
import argparse
import os
import sys
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from model import (
MODEL_H, MODEL_W, OUT_SIZE,
forward, load_model, save_model, init_model,
pixels_to_weights, prompt_to_embedding,
)
# Folder scanned by load_dataset() for image + .txt prompt pairs.
# Relative path, so it is resolved against the current working directory.
DATASET_DIR = "dataset"
# File the pixel model is loaded from and saved back to by train().
MODEL_PATH = "model.png"
def load_dataset():
    """Load (prompt, target) training pairs from DATASET_DIR.

    Every file whose name ends in .png/.jpg/.jpeg (case-insensitive) is paired
    with the .txt file sharing its stem. Images without a matching .txt are
    reported and skipped. Files are visited in sorted name order.

    Returns:
        A list of (prompt, target) tuples. ``prompt`` is the stripped text of
        the .txt file; ``target`` is a float32 tensor of shape
        (OUT_SIZE, OUT_SIZE, 3) holding RGB values scaled to [0, 1].
        The list is empty when DATASET_DIR is missing or holds no usable pair.
    """
    try:
        fnames = sorted(os.listdir(DATASET_DIR))
    except (FileNotFoundError, NotADirectoryError):
        # Treat a missing folder like an empty dataset so the caller can
        # report it through its normal "no pairs" path instead of a traceback.
        print(f" [skip] dataset folder '{DATASET_DIR}/' not found")
        return []

    pairs = []
    for fname in fnames:
        if not fname.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        stem = os.path.splitext(fname)[0]
        txt_path = os.path.join(DATASET_DIR, stem + ".txt")
        img_path = os.path.join(DATASET_DIR, fname)
        if not os.path.exists(txt_path):
            print(f" [skip] no .txt for {fname}")
            continue
        # Explicit UTF-8 so prompts decode the same way on every platform
        # instead of depending on the locale's default encoding.
        with open(txt_path, encoding="utf-8") as f:
            prompt = f.read().strip()
        # Context manager closes the underlying file handle; convert() and
        # resize() return new, fully loaded images that outlive it.
        with Image.open(img_path) as src:
            img = src.convert("RGB").resize(
                (OUT_SIZE, OUT_SIZE), Image.BILINEAR
            )
        target = torch.tensor(
            np.array(img, dtype=np.float32) / 255.0
        )  # (OUT_SIZE, OUT_SIZE, 3), values in [0, 1]
        pairs.append((prompt, target))
        print(f" loaded: '{prompt}' ← {fname}")
    return pairs
def train(epochs=300, lr=0.03, save_every=50):
    """Optimize the model's pixel values by gradient descent on the dataset.

    The model stored at MODEL_PATH is created with init_model() if it does
    not exist yet, loaded, and its pixels are optimized directly with Adam.
    Each epoch sums the MSE loss between forward(pixels, prompt) and the
    target image over every dataset pair and takes one optimizer step.

    Args:
        epochs: Number of epochs (one optimizer step per epoch).
        lr: Adam learning rate.
        save_every: Checkpoint and log interval in epochs. 0 disables the
            periodic checkpoints; epoch 1 and the final model are still saved.

    Exits the process via sys.exit() when load_dataset() returns no pairs.
    """
    # init model if missing
    if not os.path.exists(MODEL_PATH):
        init_model(MODEL_PATH)
    raw = load_model(MODEL_PATH)  # (H, W, 3)
    # pixels are the parameters — detach first so the clone is guaranteed to
    # be a leaf tensor, which the optimizer requires
    pixels = raw.detach().clone().requires_grad_(True)
    optimizer = torch.optim.Adam([pixels], lr=lr)

    print(f"\nLoading dataset from '{DATASET_DIR}/'…")
    dataset = load_dataset()
    if not dataset:
        sys.exit("No dataset pairs found. Add image+txt pairs to dataset/")
    print(f"\nTraining epochs={epochs} lr={lr} pairs={len(dataset)}\n")

    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        # pixels must stay in [0,1] for the sign/magnitude encoding; the
        # clamp does not depend on the pair, so compute it once per epoch
        p_clamped = pixels.clamp(0, 1)
        total_loss = torch.tensor(0.0)
        for prompt, target in dataset:
            pred = forward(p_clamped, prompt)
            total_loss = total_loss + F.mse_loss(pred, target)
        total_loss.backward()
        optimizer.step()
        # clamp pixels back to valid range after update
        with torch.no_grad():
            pixels.clamp_(0, 1)
        # `save_every and ...` short-circuits on 0, avoiding a modulo-by-zero
        if epoch == 1 or (save_every and epoch % save_every == 0):
            save_model(pixels, MODEL_PATH)
            # loss reported is the pre-step loss, averaged over the pairs
            avg = total_loss.item() / len(dataset)
            print(f" epoch {epoch:>5}/{epochs} loss={avg:.5f} → saved model.png")

    save_model(pixels, MODEL_PATH)
    print(f"\nDone. Final model saved to {MODEL_PATH}")
if __name__ == "__main__":
    # Command-line entry point: read the training hyper-parameters and run.
    parser = argparse.ArgumentParser()
    for flag, kind, default in (
        ("--epochs", int, 300),
        ("--lr", float, 0.03),
        ("--save-every", int, 50),
    ):
        parser.add_argument(flag, type=kind, default=default)
    cli = parser.parse_args()
    train(epochs=cli.epochs, lr=cli.lr, save_every=cli.save_every)