File size: 6,810 Bytes
626b231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""
train.py — Train your mini-style-transfer model

Usage:
    python train.py --style starry_night.jpg --output starry_night.pth

What this script does:
    1. Loads your style image (the painting)
    2. Loops over MS-COCO images (content images — everyday photos)
    3. For each photo: runs it through StyleNet, then compares the result
       with the original photo (content) and with the painting (style)
    4. Updates model weights so outputs look more like the style painting
    5. Saves your trained model as a .pth file

Beginner tip: Think of training as teaching the model by example.
You show it thousands of photos and say "make them look like Van Gogh".
After enough examples, it learns to do it on its own.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import argparse
from model import StyleNet


# ── Settings ──────────────────────────────────────────────────────────────────

# Hyper-parameters shared by the whole script.
IMAGE_SIZE = 256        # side of the square training crop; inference works at any size
BATCH_SIZE = 4          # photos per optimiser step
EPOCHS = 2              # passes over the content folder; 2 is enough for a recognisable style
LR = 0.001              # Adam learning rate
CONTENT_W = 1.0         # weight of the content loss (how much of the photo to preserve)
STYLE_W = 100_000.0     # weight of the style loss (very high values are normal here)

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


# ── Dataset ───────────────────────────────────────────────────────────────────

class ImageFolderDataset(Dataset):
    """Load every image file directly inside ``folder`` (e.g. MS-COCO train2017).

    Args:
        folder: Directory to scan. Only its top level is read; subdirectories
            are not searched, and directory entries are ignored.
        transform: Callable applied to each decoded RGB ``PIL.Image`` before
            it is returned by ``__getitem__``.
        extensions: Filename suffixes (matched case-insensitively) that are
            treated as images. Defaults to ``.jpg``, ``.jpeg`` and ``.png``.
    """

    def __init__(self, folder, transform, extensions=(".jpg", ".jpeg", ".png")):
        suffixes = tuple(ext.lower() for ext in extensions)
        # Sorted so the index -> file mapping is stable across runs and
        # platforms; os.listdir() returns entries in arbitrary order.
        self.paths = sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(folder, name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # The context manager closes the file handle promptly; convert()
        # returns a new, fully loaded image that outlives the open file.
        with Image.open(self.paths[idx]) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)


# ── Perceptual Loss (VGG16) ───────────────────────────────────────────────────
# Instead of comparing pixels directly, we compare how images "feel"
# using a pretrained VGG network. This is what makes the style look good.

class VGGLoss(nn.Module):
    """Frozen feature extractor built from the pretrained torchvision VGG16.

    The ``features`` stack is cut into three consecutive slices whose outputs
    are the activations after relu1_2, relu2_2 and relu3_3. All three are
    used for the style loss; relu2_2 is also used for the content loss.
    Every slice is put in eval mode and no VGG parameter receives gradients.
    Presumably expects ImageNet-normalised input, matching the training
    transform (TODO confirm for StyleNet outputs).
    """

    def __init__(self):
        super().__init__()
        layers = list(models.vgg16(weights=models.VGG16_Weights.DEFAULT).features)
        self.slice1 = nn.Sequential(*layers[0:4]).eval()    # up to relu1_2
        self.slice2 = nn.Sequential(*layers[4:9]).eval()    # up to relu2_2
        self.slice3 = nn.Sequential(*layers[9:16]).eval()   # up to relu3_3
        # Freeze every VGG weight; only the style network is trained.
        self.requires_grad_(False)

    def forward(self, x):
        """Return the (relu1_2, relu2_2, relu3_3) activations for ``x``."""
        relu1_2 = self.slice1(x)
        relu2_2 = self.slice2(relu1_2)
        relu3_3 = self.slice3(relu2_2)
        return relu1_2, relu2_2, relu3_3

def gram_matrix(feat):
    """Return the normalised Gram matrix of a batch of feature maps.

    Style is captured as correlations between feature channels: entry
    ``[b, i, j]`` is the dot product of channels ``i`` and ``j`` of sample
    ``b`` over all spatial positions, divided by ``C * H * W``.

    Args:
        feat: Tensor of shape ``(B, C, H, W)``.

    Returns:
        Tensor of shape ``(B, C, C)``.
    """
    B, C, H, W = feat.shape
    # reshape (not view): feature maps may be non-contiguous, e.g. in
    # channels_last memory format, where view() raises a RuntimeError.
    flat = feat.reshape(B, C, H * W)
    return torch.bmm(flat, flat.transpose(1, 2)) / (C * H * W)


# ── Training loop ─────────────────────────────────────────────────────────────

def train(style_image_path, content_folder, output_path):
    """Train a StyleNet to apply one painting's style, then save its weights.

    Args:
        style_image_path: Path to the style image (the painting).
        content_folder: Folder of content photos, loaded via ``ImageFolderDataset``.
        output_path: File the trained ``state_dict`` is written to with ``torch.save``.

    Raises:
        ValueError: If ``content_folder`` contains no usable image files.
    """
    print(f"Device: {DEVICE}")
    print(f"Style:  {style_image_path}")
    print(f"Output: {output_path}\n")

    # Same preprocessing for the painting and the photos: square crop, then
    # ImageNet mean/std normalisation to match the pretrained VGG16.
    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Load the style image; `with` closes the file handle once it is decoded.
    with Image.open(style_image_path) as img:
        style_img = transform(img.convert("RGB"))
    style_img = style_img.unsqueeze(0).to(DEVICE)

    dataset = ImageFolderDataset(content_folder, transform)
    if len(dataset) == 0:
        # Fail early with an actionable message instead of DataLoader's
        # generic "num_samples should be a positive integer" error.
        raise ValueError(f"No training images found in {content_folder!r}")
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

    model = StyleNet().to(DEVICE)
    vgg = VGGLoss().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    mse = nn.MSELoss()

    # The painting never changes, so its Gram matrices are computed once.
    with torch.no_grad():
        style_grams = [gram_matrix(feat) for feat in vgg(style_img)]

    print(f"Training on {len(dataset)} images for {EPOCHS} epochs...")
    print("─" * 50)

    for epoch in range(EPOCHS):
        for i, content in enumerate(loader):
            content = content.to(DEVICE)
            optimizer.zero_grad()

            # Forward pass
            styled = model(content)

            # Target features of the original photo need no gradient.
            with torch.no_grad():
                _, c_feat, _ = vgg(content)
            # A single VGG pass over the stylised output feeds both losses.
            o1, o2, o3 = vgg(styled)

            # Content loss — styled image should still look like the photo
            content_loss = mse(o2, c_feat)

            # Style loss — styled image should look like the painting.
            # expand() matches the style Grams to this batch's size, since the
            # last batch may hold fewer than BATCH_SIZE images.
            n = content.size(0)
            style_loss = sum(
                mse(gram_matrix(out), target.expand(n, -1, -1))
                for out, target in zip((o1, o2, o3), style_grams)
            )

            loss = CONTENT_W * content_loss + STYLE_W * style_loss
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(f"Epoch {epoch+1}/{EPOCHS}  Batch {i:4d}/{len(loader)}"
                      f"  Loss: {loss.item():.2f}"
                      f"  (content {content_loss.item():.3f}"
                      f"  style {style_loss.item():.2f})")

    torch.save(model.state_dict(), output_path)
    print(f"\nDone! Model saved to: {output_path}")
    print(f"Upload to HuggingFace: huggingface-cli upload your-username/mini-style-transfer {output_path}")


if __name__ == "__main__":
    # Command-line entry point: collect the three paths and hand them to train().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--style",
        required=True,
        help="Path to your style painting image",
    )
    cli.add_argument(
        "--content",
        default="coco/",
        help="Folder of training photos (MS-COCO)",
    )
    cli.add_argument(
        "--output",
        default="style_model.pth",
        help="Output .pth file name",
    )
    opts = cli.parse_args()
    train(
        style_image_path=opts.style,
        content_folder=opts.content,
        output_path=opts.output,
    )