"""
train.py - Train your mini-style-transfer model
Usage:
python train.py --style starry_night.jpg --output starry_night.pth
What this script does:
1. Loads your style image (the painting)
2. Loops over MS-COCO images (content images, i.e. everyday photos)
3. For each photo: runs it through StyleNet, compares result to style
4. Updates model weights so outputs look more like the style painting
5. Saves your trained model as a .pth file
Beginner tip: Think of training as teaching the model by example.
You show it thousands of photos and say "make them look like Van Gogh".
After enough examples, it learns to do it on its own.
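After training, loading the weights for inference looks roughly like this
(a minimal sketch; it assumes the same StyleNet class from model.py):
    model = StyleNet()
    model.load_state_dict(torch.load("starry_night.pth", map_location="cpu"))
    model.eval()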
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import argparse
from model import StyleNet
# ── Settings ──────────────────────────────────────────────────────────────────
IMAGE_SIZE = 256 # train on 256x256 (faster); can run inference at any size
BATCH_SIZE = 4
EPOCHS = 2 # 2 epochs is enough for a recognisable style
LR = 1e-3
CONTENT_W = 1.0 # how much to preserve original content
STYLE_W = 1e5 # how strongly to apply the style (very high is normal)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ── Dataset ───────────────────────────────────────────────────────────────────
class ImageFolderDataset(Dataset):
"""Loads all images from a folder. Use MS-COCO train2017 images."""
def __init__(self, folder, transform):
self.paths = [
os.path.join(folder, f) for f in os.listdir(folder)
if f.lower().endswith(('.jpg', '.jpeg', '.png'))
]
self.transform = transform
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
img = Image.open(self.paths[idx]).convert("RGB")
return self.transform(img)
# ── Perceptual Loss (VGG16) ───────────────────────────────────────────────────
# Instead of comparing pixels directly, we compare how images "feel"
# using a pretrained VGG network. This is what makes the style look good.
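# Roughly, the two losses computed in the training loop are:
#   content_loss = || relu2_2(styled) - relu2_2(photo) ||^2            (MSE)
#   style_loss   = sum over relu1_2, relu2_2, relu3_3 of
#                  || Gram(layer(styled)) - Gram(layer(style)) ||^2    (MSE)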
class VGGLoss(nn.Module):
def __init__(self):
super().__init__()
vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features
# relu2_2 for content, relu1_2 + relu2_2 + relu3_3 for style
self.slice1 = nn.Sequential(*list(vgg)[:4]).eval() # relu1_2
        self.slice2 = nn.Sequential(*list(vgg)[4:9]).eval()   # relu2_2 → content
self.slice3 = nn.Sequential(*list(vgg)[9:16]).eval() # relu3_3
for p in self.parameters():
p.requires_grad = False
def forward(self, x):
h1 = self.slice1(x)
h2 = self.slice2(h1)
h3 = self.slice3(h2)
return h1, h2, h3
def gram_matrix(feat):
"""Style is captured as correlations between feature maps (Gram matrix)."""
B, C, H, W = feat.shape
feat = feat.view(B, C, H * W)
return torch.bmm(feat, feat.transpose(1, 2)) / (C * H * W)
# ── Training loop ─────────────────────────────────────────────────────────────
def train(style_image_path, content_folder, output_path):
print(f"Device: {DEVICE}")
print(f"Style: {style_image_path}")
print(f"Output: {output_path}\n")
transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.CenterCrop(IMAGE_SIZE),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
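    # Note: the mean/std above are the ImageNet statistics the pretrained VGG16
    # expects; the style image and every content photo use this same transform.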
# Load style image and precompute its Gram matrices (done once)
style_img = transform(Image.open(style_image_path).convert("RGB"))
style_img = style_img.unsqueeze(0).to(DEVICE)
dataset = ImageFolderDataset(content_folder, transform)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
model = StyleNet().to(DEVICE)
vgg = VGGLoss().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR)
mse = nn.MSELoss()
# Precompute style Gram matrices
with torch.no_grad():
s1, s2, s3 = vgg(style_img)
style_grams = [gram_matrix(s1), gram_matrix(s2), gram_matrix(s3)]
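    # Each style Gram has shape (1, C, C); expand() in the loop below broadcasts
    # it across the batch as a view, without copying memory.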
print(f"Training on {len(dataset)} images for {EPOCHS} epochs...")
print("β" * 50)
for epoch in range(EPOCHS):
for i, content in enumerate(loader):
content = content.to(DEVICE)
optimizer.zero_grad()
# Forward pass
styled = model(content)
            # Content loss - styled image should still look like the photo
            _, c_feat, _ = vgg(content)
            # Run the styled output through VGG once and reuse the features
            # for both the content loss (relu2_2) and the style loss below
            o1, o2, o3 = vgg(styled)
            content_loss = mse(o2, c_feat.detach())
            # Style loss - styled image should look like the painting
style_loss = (
mse(gram_matrix(o1), style_grams[0].expand(content.size(0), -1, -1)) +
mse(gram_matrix(o2), style_grams[1].expand(content.size(0), -1, -1)) +
mse(gram_matrix(o3), style_grams[2].expand(content.size(0), -1, -1))
)
loss = CONTENT_W * content_loss + STYLE_W * style_loss
loss.backward()
optimizer.step()
if i % 100 == 0:
print(f"Epoch {epoch+1}/{EPOCHS} Batch {i:4d}/{len(loader)}"
f" Loss: {loss.item():.2f}"
f" (content {content_loss.item():.3f}"
f" style {style_loss.item():.2f})")
torch.save(model.state_dict(), output_path)
print(f"\nDone! Model saved to: {output_path}")
print(f"Upload to HuggingFace: huggingface-cli upload your-username/mini-style-transfer {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--style", required=True, help="Path to your style painting image")
parser.add_argument("--content", default="coco/", help="Folder of training photos (MS-COCO)")
parser.add_argument("--output", default="style_model.pth", help="Output .pth file name")
args = parser.parse_args()
train(args.style, args.content, args.output)