set-solver / src /train /classifier.py
Tian Wang
Deploy Set Solver web app
8a34385
"""
Train a card attribute classifier on the existing labeled images.
Uses MobileNetV3-Small for iPhone compatibility.
Multi-head output: predicts all 4 attributes simultaneously.
"""
import os
import json
from pathlib import Path
from typing import Tuple, Dict, List
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from torchvision.io import read_image, ImageReadMode
from PIL import Image
import numpy as np
from tqdm import tqdm
# === Config ===
# All data and weight locations are resolved against the project root, i.e. the
# directory reached by taking .parent three times from this file's path.
_PROJECT_ROOT = Path(__file__).parent.parent.parent
DATA_DIR = _PROJECT_ROOT / "training_images"
SYNTHETIC_DATA_DIR = _PROJECT_ROOT / "training_images_synthetic"
WEIGHTS_DIR = _PROJECT_ROOT / "weights"
WEIGHTS_DIR.mkdir(exist_ok=True)  # created at import time so checkpoints can be written

# Attribute mappings: dataset folder name -> class index predicted by each head.
# Keys are listed in index order, so the reverse lists below come straight
# from the dict keys.
NUMBER_MAP = {"one": 0, "two": 1, "three": 2}
COLOR_MAP = {"red": 0, "green": 1, "blue": 2}  # "blue" is the purple color of standard Set
SHAPE_MAP = {"diamond": 0, "oval": 1, "squiggle": 2}
FILL_MAP = {"empty": 0, "full": 1, "partial": 2}  # "partial" is striped

# Reverse mappings (class index -> folder name) for inference.
NUMBER_NAMES = list(NUMBER_MAP)
COLOR_NAMES = list(COLOR_MAP)
SHAPE_NAMES = list(SHAPE_MAP)
FILL_NAMES = list(FILL_MAP)
# === Dataset ===
class SetCardDataset(Dataset):
    """Dataset of labeled Set card images.

    Images are discovered at ``<data_dir>/<number>/<color>/<shape>/<fill>/*.png``,
    where each path component is a key of NUMBER_MAP / COLOR_MAP / SHAPE_MAP /
    FILL_MAP. An image's labels are taken from its folder path.

    Args:
        data_dirs: One directory or an iterable of directories. ``str`` and
            ``os.PathLike`` values are accepted. Directories that do not exist
            are skipped.
        transform: Optional callable applied to the RGB PIL image in
            ``__getitem__`` (``get_raw`` never applies it).
    """

    def __init__(self, data_dirs, transform=None):
        # A single str/PathLike would otherwise be iterated character by character.
        if isinstance(data_dirs, (str, os.PathLike)):
            data_dirs = [data_dirs]
        self.transform = transform
        self.samples: List[Tuple[Path, Dict[str, int]]] = []

        # Walk the <number>/<color>/<shape>/<fill> folder hierarchy of each root.
        for data_dir in map(Path, data_dirs):
            if not data_dir.exists():
                continue
            count_before = len(self.samples)
            for number, number_idx in NUMBER_MAP.items():
                for color, color_idx in COLOR_MAP.items():
                    for shape, shape_idx in SHAPE_MAP.items():
                        for fill, fill_idx in FILL_MAP.items():
                            folder = data_dir / number / color / shape / fill
                            if not folder.exists():
                                continue
                            # glob() order is filesystem-dependent; sorting keeps
                            # sample indices (and any seeded split over them)
                            # reproducible across machines.
                            for img_path in sorted(folder.glob("*.png")):
                                labels = {
                                    "number": number_idx,
                                    "color": color_idx,
                                    "shape": shape_idx,
                                    "fill": fill_idx,
                                }
                                self.samples.append((img_path, labels))
            print(f"Loaded {len(self.samples) - count_before} samples from {data_dir}")
        print(f"Total: {len(self.samples)} samples")

    @staticmethod
    def _load_rgb(img_path: Path) -> "Image.Image":
        """Decode ``img_path`` into an RGB PIL image, closing the file handle."""
        # convert() returns a fully loaded copy, so it stays valid after the
        # file is closed by the context manager.
        with Image.open(img_path) as img:
            return img.convert("RGB")

    @staticmethod
    def _labels_to_tensor(labels: Dict[str, int]) -> torch.Tensor:
        """Pack a label dict into a length-4 long tensor: number, color, shape, fill."""
        return torch.tensor(
            [labels["number"], labels["color"], labels["shape"], labels["fill"]],
            dtype=torch.long,
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, labels)``; ``image`` has ``self.transform`` applied if set."""
        img_path, labels = self.samples[idx]
        image = self._load_rgb(img_path)
        if self.transform:
            image = self.transform(image)
        return image, self._labels_to_tensor(labels)

    def get_raw(self, idx):
        """Get raw PIL image and labels (no transform)."""
        img_path, labels = self.samples[idx]
        return self._load_rgb(img_path), self._labels_to_tensor(labels)
# === Model ===
class SetCardClassifier(nn.Module):
    """
    Four-headed attribute classifier for Set cards.

    A MobileNetV3-Small backbone (chosen with mobile deployment in mind) is
    shared by four independent linear heads: number, color, shape and fill,
    each producing 3 logits. ``forward`` returns those logits in a dict keyed
    by attribute name.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()
        if pretrained:
            backbone_weights = models.MobileNet_V3_Small_Weights.DEFAULT
        else:
            backbone_weights = None
        self.backbone = models.mobilenet_v3_small(weights=backbone_weights)

        # The stock classifier's first layer reads the backbone's feature
        # vector, so its input width is the feature size our heads consume.
        feature_dim = self.backbone.classifier[0].in_features
        # Drop the stock classifier so the backbone emits those features directly.
        self.backbone.classifier = nn.Identity()

        self.heads = nn.ModuleDict(
            {attr: nn.Linear(feature_dim, 3) for attr in ("number", "color", "shape", "fill")}
        )

    def forward(self, x):
        shared = self.backbone(x)
        # ModuleDict preserves insertion order: number, color, shape, fill.
        return {attr: head(shared) for attr, head in self.heads.items()}
# === Training ===
def train_epoch(model, loader, optimizer, criterion, device, fill_weight: float = 2.0):
    """Run one optimization pass over ``loader``.

    For each batch, every head's loss is computed with ``criterion``. The
    "fill" head's loss is multiplied by ``fill_weight`` to penalize fill
    mistakes more heavily, and the per-head losses are summed before
    backpropagation.

    Args:
        model: Module whose forward returns a dict of logits keyed by
            "number", "color", "shape", "fill".
        loader: Iterable of ``(images, labels)`` batches, where
            ``labels[:, i]`` is the target for the i-th head in the key order
            above.
        optimizer: Optimizer zeroed and stepped once per batch.
        criterion: Loss callable taking ``(logits, targets)``.
        device: Device each batch is moved to.
        fill_weight: Multiplier for the fill head's loss (default 2.0).

    Returns:
        ``(avg_loss, accuracies)``: the mean of the per-batch summed losses and
        a dict of per-head accuracy. Both are 0.0 if ``loader`` yields nothing.
    """
    head_keys = ("number", "color", "shape", "fill")
    model.train()
    total_loss = 0.0
    num_batches = 0
    correct = {k: 0 for k in head_keys}
    total = 0
    for images, labels in tqdm(loader, desc="Training", leave=False):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)

        loss = 0
        for i, key in enumerate(head_keys):
            head_loss = criterion(outputs[key], labels[:, i])
            loss += (fill_weight * head_loss) if key == "fill" else head_loss
            preds = outputs[key].argmax(dim=1)
            correct[key] += (preds == labels[:, i]).sum().item()

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
        total += labels.size(0)

    # Count batches ourselves: works for loaders without __len__ and avoids
    # dividing by zero on an empty loader.
    avg_loss = total_loss / num_batches if num_batches else 0.0
    accuracies = {k: (v / total if total else 0.0) for k, v in correct.items()}
    return avg_loss, accuracies
def evaluate(model, loader, criterion, device, fill_weight: float = 1.0):
    """Compute loss and per-head accuracy over ``loader`` without gradients.

    Args:
        model: Module whose forward returns a dict of logits keyed by
            "number", "color", "shape", "fill".
        loader: Iterable of ``(images, labels)`` batches, where
            ``labels[:, i]`` is the target for the i-th head in the key order
            above.
        criterion: Loss callable taking ``(logits, targets)``.
        device: Device each batch is moved to.
        fill_weight: Multiplier for the fill head's loss. The default of 1.0
            sums the heads unweighted, so the result is not directly
            comparable to a training loss computed with a larger fill weight.

    Returns:
        ``(avg_loss, accuracies)``: the mean of the per-batch summed losses and
        a dict of per-head accuracy. Both are 0.0 if ``loader`` yields nothing
        (e.g. an empty validation split).
    """
    head_keys = ("number", "color", "shape", "fill")
    model.eval()
    total_loss = 0.0
    num_batches = 0
    correct = {k: 0 for k in head_keys}
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluating", leave=False):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            loss = 0
            for i, key in enumerate(head_keys):
                head_loss = criterion(outputs[key], labels[:, i])
                loss += (fill_weight * head_loss) if key == "fill" else head_loss
                preds = outputs[key].argmax(dim=1)
                correct[key] += (preds == labels[:, i]).sum().item()

            total_loss += loss.item()
            num_batches += 1
            total += labels.size(0)

    # Guard against empty splits instead of raising ZeroDivisionError.
    avg_loss = total_loss / num_batches if num_batches else 0.0
    accuracies = {k: (v / total if total else 0.0) for k, v in correct.items()}
    return avg_loss, accuracies
class _TransformDataset(Dataset):
    """Apply ``transform`` to the image of each ``(image, label)`` item of ``subset``.

    ``random_split`` returns Subsets that share one parent dataset, so each
    split gets its own transform by wrapping. Defined at module level (not
    inside ``main``) so DataLoader workers started with the "spawn" method can
    pickle it.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, label = self.subset[idx]
        if self.transform:
            image = self.transform(image)
        return image, label


def _select_device() -> torch.device:
    """Pick CUDA if available, else Apple MPS if available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _build_transforms(img_size: int):
    """Return ``(train_transform, eval_transform)`` producing normalized ``img_size`` x ``img_size`` tensors."""
    # ImageNet channel mean/std.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)),  # Simulate imperfect detector crops
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(180),  # Cards can be any orientation
        transforms.RandomPerspective(distortion_scale=0.15, p=0.5),  # Perspective warp from detection
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
        transforms.RandomGrayscale(p=0.05),  # Force model to not rely solely on color for fill
        # Applied to every image; sigma is drawn from [0.1, 2.0] each call, so
        # low draws barely blur.
        transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform, eval_transform


def main():
    """Train the card classifier, keep the best checkpoint, and report test metrics.

    Images come from DATA_DIR plus SYNTHETIC_DATA_DIR (when it exists). They are
    split with a fixed seed into train/val/test; val and test sizes are floored
    fractions of the total (15% and 10%). After each epoch, the model with the
    highest mean validation accuracy across the four heads is saved to
    ``WEIGHTS_DIR / "classifier_best.pt"``. That checkpoint is reloaded for the
    test pass, and the metrics are written to
    ``WEIGHTS_DIR / "training_results.json"``.

    Raises:
        RuntimeError: if no labeled images are found.
    """
    # === Hyperparameters ===
    BATCH_SIZE = 32
    EPOCHS = 50
    LR = 1e-3
    VAL_SPLIT = 0.15
    TEST_SPLIT = 0.10
    IMG_SIZE = 224

    device = _select_device()
    print(f"Using device: {device}")

    # === Data transforms ===
    train_transform, val_transform = _build_transforms(IMG_SIZE)

    # === Load dataset (clean + synthetic crops) ===
    data_dirs = [DATA_DIR]
    if SYNTHETIC_DATA_DIR.exists():
        data_dirs.append(SYNTHETIC_DATA_DIR)
    full_dataset = SetCardDataset(data_dirs, transform=None)  # Transforms are attached per split below
    if len(full_dataset) == 0:
        # Fail early with a clear message instead of training on nothing.
        raise RuntimeError(
            "No training images found under: " + ", ".join(str(d) for d in data_dirs)
        )

    # Split into train/val/test with a fixed seed for reproducibility.
    total = len(full_dataset)
    test_size = int(total * TEST_SPLIT)
    val_size = int(total * VAL_SPLIT)
    train_size = total - val_size - test_size
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )
    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    train_dataset = _TransformDataset(train_dataset, train_transform)
    val_dataset = _TransformDataset(val_dataset, val_transform)
    test_dataset = _TransformDataset(test_dataset, val_transform)

    # Use num_workers=0 on macOS to avoid shared memory issues
    import platform
    num_workers = 0 if platform.system() == "Darwin" else 4
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)

    # === Model ===
    model = SetCardClassifier(pretrained=True).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # === Training loop ===
    # Start below any achievable accuracy so the first epoch always writes a
    # checkpoint. With 0 and a strict ">", a run that never beats 0 would
    # reload a stale checkpoint from a previous run, or crash on a missing file.
    best_val_acc = float("-inf")
    for epoch in range(EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        scheduler.step()

        # Average accuracy across all heads
        avg_train_acc = sum(train_acc.values()) / 4
        avg_val_acc = sum(val_acc.values()) / 4

        print(f"Epoch {epoch+1}/{EPOCHS}")
        print(f"  Train Loss: {train_loss:.4f}, Acc: {avg_train_acc:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Acc: {avg_val_acc:.4f}")
        print(f"  Val per-head: num={val_acc['number']:.3f} col={val_acc['color']:.3f} "
              f"shp={val_acc['shape']:.3f} fil={val_acc['fill']:.3f}")

        # Save best model
        if avg_val_acc > best_val_acc:
            best_val_acc = avg_val_acc
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_acc": val_acc,
            }, WEIGHTS_DIR / "classifier_best.pt")
            print(f"  Saved new best model (val_acc={avg_val_acc:.4f})")

    # === Final evaluation on test set ===
    print("\n" + "="*50)
    print("Final Test Evaluation")
    print("="*50)

    # Load the best model; map_location keeps the reload independent of the
    # device the checkpoint tensors were saved from.
    checkpoint = torch.load(WEIGHTS_DIR / "classifier_best.pt", map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])

    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    avg_test_acc = sum(test_acc.values()) / 4

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy (avg): {avg_test_acc:.4f}")
    print(f"  Number: {test_acc['number']:.4f}")
    print(f"  Color:  {test_acc['color']:.4f}")
    print(f"  Shape:  {test_acc['shape']:.4f}")
    print(f"  Fill:   {test_acc['fill']:.4f}")

    # Save final results
    results = {
        "test_loss": test_loss,
        "test_accuracy": test_acc,
        "avg_test_accuracy": avg_test_acc,
        "train_size": train_size,
        "val_size": val_size,
        "test_size": test_size,
    }
    with open(WEIGHTS_DIR / "training_results.json", "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nModel saved to {WEIGHTS_DIR / 'classifier_best.pt'}")
    print(f"Results saved to {WEIGHTS_DIR / 'training_results.json'}")
# Run training only when this file is executed as a script; importing the
# module (e.g. to reuse SetCardClassifier or the *_NAMES mappings) does not
# start training.
if __name__ == "__main__":
    main()