File size: 3,460 Bytes
38bfe91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
train.py  –  optimize pixel values via gradient descent.

Dataset: dataset/ folder containing pairs:
  dataset/cat.png + dataset/cat.txt   (image + its prompt)
  dataset/sun.png + dataset/sun.txt
  ...

Each .txt holds one line: the prompt for that image.
Target images are resized to OUT_SIZE×OUT_SIZE automatically.

Usage:
  python train.py                     # train with defaults
  python train.py --epochs 500 --lr 0.05
"""

import argparse
import os
import sys
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

from model import (
    MODEL_H, MODEL_W, OUT_SIZE,
    forward, load_model, save_model, init_model,
    pixels_to_weights, prompt_to_embedding,
)

# Folder scanned for <stem>.<png|jpg|jpeg> + <stem>.txt training pairs.
DATASET_DIR = "dataset"
# Image file holding the trainable pixels (read by load_model, written by save_model).
MODEL_PATH = "model.png"


def load_dataset(dataset_dir=None):
    """Load (prompt, target) training pairs from the dataset folder.

    Every ``<stem>.png`` / ``.jpg`` / ``.jpeg`` file (extension matched
    case-insensitively) is paired with ``<stem>.txt`` in the same folder.
    Images without a matching prompt file are skipped with a message.

    Args:
        dataset_dir: folder to scan. Defaults to the module-level
            ``DATASET_DIR``, resolved at call time.

    Returns:
        List of ``(prompt, target)`` tuples in sorted image-filename order.
        ``prompt`` is the stripped text of the .txt file. ``target`` is a
        float32 tensor of shape (OUT_SIZE, OUT_SIZE, 3) with values in
        [0, 1]. The list is empty when the folder does not exist or holds
        no usable pairs.
    """
    if dataset_dir is None:
        dataset_dir = DATASET_DIR
    if not os.path.isdir(dataset_dir):
        # Return an empty list instead of letting os.listdir raise, so the
        # caller can report "no dataset" in its own way.
        print(f"  [warn] dataset folder '{dataset_dir}' not found")
        return []

    pairs = []
    for fname in sorted(os.listdir(dataset_dir)):
        if not fname.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        stem = os.path.splitext(fname)[0]
        txt_path = os.path.join(dataset_dir, stem + ".txt")
        img_path = os.path.join(dataset_dir, fname)
        if not os.path.exists(txt_path):
            print(f"  [skip] no .txt for {fname}")
            continue
        # Decode prompts the same way on every platform. "utf-8-sig" also
        # drops a leading BOM, which str.strip() would otherwise keep.
        with open(txt_path, encoding="utf-8-sig") as f:
            prompt = f.read().strip()
        # PIL opens files lazily. The context manager closes the handle
        # instead of leaving it open until garbage collection.
        with Image.open(img_path) as im:
            img = im.convert("RGB").resize(
                (OUT_SIZE, OUT_SIZE), Image.BILINEAR
            )
        target = torch.tensor(
            np.array(img, dtype=np.float32) / 255.0
        )  # (OUT_SIZE, OUT_SIZE, 3), values in [0, 1]
        pairs.append((prompt, target))
        print(f"  loaded: '{prompt}'  ← {fname}")
    return pairs


def train(epochs=300, lr=0.03, save_every=50):
    """Fit the pixels stored in MODEL_PATH to the dataset by gradient descent.

    The model image's pixel values are the trainable parameters. Each epoch:

    1. Run ``forward`` on every (prompt, target) pair.
    2. Sum the per-pair MSE losses.
    3. Take one Adam step on the summed loss.
    4. Clamp the pixels back into [0, 1].

    MODEL_PATH is created with ``init_model`` if it does not exist yet.

    Args:
        epochs: number of full passes over the dataset.
        lr: Adam learning rate.
        save_every: save MODEL_PATH and log the mean per-pair loss every
            this many epochs, and after epoch 1. Must be >= 1.

    Raises:
        ValueError: if ``save_every`` is less than 1.
        SystemExit: if no (image, prompt) pairs are found.
    """
    if save_every < 1:
        # `epoch % save_every` below would raise ZeroDivisionError for 0
        # after training had already started; negative values are meaningless.
        raise ValueError(f"save_every must be >= 1, got {save_every}")

    # Load the data first so an empty or missing dataset exits without
    # creating MODEL_PATH as a side effect.
    print(f"\nLoading dataset from '{DATASET_DIR}/'…")
    dataset = load_dataset()
    if not dataset:
        sys.exit(f"No dataset pairs found. Add image+txt pairs to {DATASET_DIR}/")

    # init model if missing
    if not os.path.exists(MODEL_PATH):
        init_model(MODEL_PATH)

    raw = load_model(MODEL_PATH)               # (H, W, 3)

    # The pixels are the parameters. detach() guarantees a fresh leaf tensor
    # even if load_model hands back something attached to an autograd graph.
    pixels = raw.detach().clone().requires_grad_(True)
    optimizer = torch.optim.Adam([pixels], lr=lr)

    print(f"\nTraining  epochs={epochs}  lr={lr}  pairs={len(dataset)}\n")

    last_saved = None
    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        total_loss = torch.tensor(0.0)

        for prompt, target in dataset:
            # pixels must stay in [0,1] for the sign/magnitude encoding
            p_clamped = pixels.clamp(0, 1)
            pred = forward(p_clamped, prompt)
            total_loss = total_loss + F.mse_loss(pred, target)

        total_loss.backward()
        optimizer.step()

        # The Adam step can push values outside [0, 1]. Clamp back in place
        # without recording the operation in the autograd graph.
        with torch.no_grad():
            pixels.clamp_(0, 1)

        if epoch % save_every == 0 or epoch == 1:
            save_model(pixels, MODEL_PATH)
            last_saved = epoch
            # This loss was computed before this epoch's optimizer step.
            avg = total_loss.item() / len(dataset)
            print(f"  epoch {epoch:>5}/{epochs}  loss={avg:.5f}  → saved {MODEL_PATH}")

    # Skip a duplicate write when the last epoch was just checkpointed.
    if last_saved != epochs:
        save_model(pixels, MODEL_PATH)
    print(f"\nDone. Final model saved to {MODEL_PATH}")


if __name__ == "__main__":
    # Reuse the module docstring (dataset layout + usage) as --help text.
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--epochs",     type=int,   default=300,
                   help="number of passes over the dataset (default: %(default)s)")
    p.add_argument("--lr",         type=float, default=0.03,
                   help="Adam learning rate (default: %(default)s)")
    p.add_argument("--save-every", type=int,   default=50,
                   help="save the model and log the loss every N epochs "
                        "(default: %(default)s)")
    args = p.parse_args()

    # Reject bad values up front with a usage message instead of failing
    # deep inside training (e.g. `epoch % 0` raises ZeroDivisionError).
    if args.epochs < 0:
        p.error("--epochs must be >= 0")
    if args.lr <= 0:
        p.error("--lr must be > 0")
    if args.save_every < 1:
        p.error("--save-every must be >= 1")

    train(args.epochs, args.lr, args.save_every)