# NOTE(review): the three lines below are HuggingFace file-listing page residue,
# not part of the script — kept as a comment so the file remains valid Python.
# Ateshh's picture / Upload 5 files / 626b231 verified
"""
train.py β€” Train your mini-style-transfer model
Usage:
python train.py --style starry_night.jpg --output starry_night.pth
What this script does:
1. Loads your style image (the painting)
2. Loops over MS-COCO images (content images β€” everyday photos)
3. For each photo: runs it through StyleNet, compares result to style
4. Updates model weights so outputs look more like the style painting
5. Saves your trained model as a .pth file
Beginner tip: Think of training as teaching the model by example.
You show it thousands of photos and say "make them look like Van Gogh".
After enough examples, it learns to do it on its own.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import argparse
from model import StyleNet
# ── Settings ──────────────────────────────────────────────────────────────────
IMAGE_SIZE: int = 256 # train on 256x256 (faster); can run inference at any size
BATCH_SIZE: int = 4  # content images per optimizer step
EPOCHS: int = 2 # 2 epochs is enough for a recognisable style
LR: float = 1e-3  # Adam learning rate
CONTENT_W: float = 1.0 # how much to preserve original content
STYLE_W: float = 1e5 # how strongly to apply the style (very high is normal)
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
# ── Dataset ───────────────────────────────────────────────────────────────────
class ImageFolderDataset(Dataset):
    """Loads all images from a folder. Use MS-COCO train2017 images.

    Paths are sorted so dataset order is deterministic: ``os.listdir``
    returns entries in arbitrary, platform-dependent order, which would
    otherwise make runs non-reproducible even with a fixed seed.
    """

    # File extensions accepted as training images (case-insensitive).
    _EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, folder, transform):
        """folder: directory containing images; transform: PIL.Image -> tensor."""
        self.paths = sorted(
            os.path.join(folder, f) for f in os.listdir(folder)
            if f.lower().endswith(self._EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        # Number of image files discovered at construction time.
        return len(self.paths)

    def __getitem__(self, idx):
        # convert("RGB") guards against grayscale/CMYK images in COCO.
        img = Image.open(self.paths[idx]).convert("RGB")
        return self.transform(img)
# ── Perceptual Loss (VGG16) ───────────────────────────────────────────────────
# Instead of comparing pixels directly, we compare how images "feel"
# using a pretrained VGG network. This is what makes the style look good.
class VGGLoss(nn.Module):
    """Frozen, pretrained VGG16 used as a perceptual-loss feature extractor.

    The VGG16 feature stack is cut into three consecutive chunks whose
    outputs correspond to the relu1_2, relu2_2 (content) and relu3_3
    activations. The network itself is never trained: gradients flow
    *through* it into StyleNet, but its own weights stay fixed.
    """

    def __init__(self):
        super().__init__()
        layers = list(models.vgg16(weights=models.VGG16_Weights.DEFAULT).features)
        # Chunk boundaries: [0:4] -> relu1_2, [4:9] -> relu2_2, [9:16] -> relu3_3
        self.slice1 = nn.Sequential(*layers[:4]).eval()
        self.slice2 = nn.Sequential(*layers[4:9]).eval()
        self.slice3 = nn.Sequential(*layers[9:16]).eval()
        # Freeze the loss network — it is a fixed measuring stick.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return (relu1_2, relu2_2, relu3_3) activations for batch ``x``."""
        feat_shallow = self.slice1(x)
        feat_mid = self.slice2(feat_shallow)
        feat_deep = self.slice3(feat_mid)
        return feat_shallow, feat_mid, feat_deep
def gram_matrix(feat):
    """Return the batched Gram matrix of a (B, C, H, W) feature tensor.

    Style is represented as channel-to-channel correlations: each sample's
    feature maps are flattened over spatial positions and multiplied with
    their own transpose, then normalized by C*H*W so the loss scale does
    not depend on feature-map size.
    """
    batch, channels, height, width = feat.shape
    flattened = feat.view(batch, channels, height * width)
    correlations = flattened @ flattened.transpose(1, 2)
    return correlations / (channels * height * width)
# ── Training loop ─────────────────────────────────────────────────────────────
def train(style_image_path, content_folder, output_path):
    """Train StyleNet to repaint arbitrary photos in one painting's style.

    Args:
        style_image_path: path to the style painting image (jpg/png).
        content_folder: folder of content photos (e.g. MS-COCO train2017).
        output_path: filename for the trained weights (.pth).

    Fix vs. original: the loss network was run on the styled batch twice per
    step (once for content features, once for style features) even though
    both passes produce identical activations. A single VGG pass is now
    reused for both losses — same gradients, roughly 1/3 less loss-network
    compute per step.
    """
    print(f"Device: {DEVICE}")
    print(f"Style: {style_image_path}")
    print(f"Output: {output_path}\n")
    # VGG expects ImageNet-normalized inputs, so content photos and the
    # style image both go through the same normalization.
    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Load style image and precompute its Gram matrices (done once)
    style_img = transform(Image.open(style_image_path).convert("RGB"))
    style_img = style_img.unsqueeze(0).to(DEVICE)
    dataset = ImageFolderDataset(content_folder, transform)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    model = StyleNet().to(DEVICE)
    model.train()  # explicit: StyleNet learns; the VGG loss network stays frozen
    vgg = VGGLoss().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    mse = nn.MSELoss()
    # Precompute style Gram matrices — the style target never changes.
    with torch.no_grad():
        s1, s2, s3 = vgg(style_img)
        style_grams = [gram_matrix(s1), gram_matrix(s2), gram_matrix(s3)]
    print(f"Training on {len(dataset)} images for {EPOCHS} epochs...")
    print("─" * 50)
    for epoch in range(EPOCHS):
        for i, content in enumerate(loader):
            content = content.to(DEVICE)
            optimizer.zero_grad()
            # Forward pass through the trainable network
            styled = model(content)
            # ONE loss-network pass per tensor: content features as target,
            # styled features reused for both content and style terms.
            _, c_feat, _ = vgg(content)
            o1, o2, o3 = vgg(styled)
            # Content loss — styled image should still look like the photo
            # (relu2_2; detach() so no gradient flows into the target).
            content_loss = mse(o2, c_feat.detach())
            # Style loss — styled image should look like the painting.
            # expand() broadcasts the single style Gram over the batch.
            style_loss = (
                mse(gram_matrix(o1), style_grams[0].expand(content.size(0), -1, -1)) +
                mse(gram_matrix(o2), style_grams[1].expand(content.size(0), -1, -1)) +
                mse(gram_matrix(o3), style_grams[2].expand(content.size(0), -1, -1))
            )
            loss = CONTENT_W * content_loss + STYLE_W * style_loss
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                print(f"Epoch {epoch+1}/{EPOCHS} Batch {i:4d}/{len(loader)}"
                      f" Loss: {loss.item():.2f}"
                      f" (content {content_loss.item():.3f}"
                      f" style {style_loss.item():.2f})")
    torch.save(model.state_dict(), output_path)
    print(f"\nDone! Model saved to: {output_path}")
    print(f"Upload to HuggingFace: huggingface-cli upload your-username/mini-style-transfer {output_path}")
if __name__ == "__main__":
    # Command-line entry point: the style image is required; the content
    # folder and output filename have sensible tutorial defaults.
    cli = argparse.ArgumentParser()
    cli.add_argument("--style", required=True, help="Path to your style painting image")
    cli.add_argument("--content", default="coco/", help="Folder of training photos (MS-COCO)")
    cli.add_argument("--output", default="style_model.pth", help="Output .pth file name")
    opts = cli.parse_args()
    train(opts.style, opts.content, opts.output)