Spaces:
Sleeping
Sleeping
| """ | |
| Train a card attribute classifier on the existing labeled images. | |
| Uses MobileNetV3-Small for iPhone compatibility. | |
| Multi-head output: predicts all 4 attributes simultaneously. | |
| """ | |
| import os | |
| import json | |
| from pathlib import Path | |
| from typing import Tuple, Dict, List | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader, random_split | |
| from torchvision import transforms, models | |
| from torchvision.io import read_image, ImageReadMode | |
| from PIL import Image | |
| import numpy as np | |
| from tqdm import tqdm | |
# === Config ===
# Project layout assumption: this file lives three levels below the repo root,
# next to sibling dirs training_images/, training_images_synthetic/, weights/.
DATA_DIR = Path(__file__).parent.parent.parent / "training_images"
SYNTHETIC_DATA_DIR = Path(__file__).parent.parent.parent / "training_images_synthetic"
WEIGHTS_DIR = Path(__file__).parent.parent.parent / "weights"
WEIGHTS_DIR.mkdir(exist_ok=True)  # module-level side effect: weights dir is created on import

# Attribute mappings (folder names → indices). Each Set attribute has exactly 3 values.
NUMBER_MAP = {"one": 0, "two": 1, "three": 2}
COLOR_MAP = {"red": 0, "green": 1, "blue": 2}  # blue = purple in standard Set
SHAPE_MAP = {"diamond": 0, "oval": 1, "squiggle": 2}
FILL_MAP = {"empty": 0, "full": 1, "partial": 2}  # partial = striped

# Reverse mappings (index → name) for inference.
# Order must stay in sync with the *_MAP indices above.
NUMBER_NAMES = ["one", "two", "three"]
COLOR_NAMES = ["red", "green", "blue"]
SHAPE_NAMES = ["diamond", "oval", "squiggle"]
FILL_NAMES = ["empty", "full", "partial"]
| # === Dataset === | |
class SetCardDataset(Dataset):
    """Dataset of labeled Set card images.

    Expects each data dir laid out as <dir>/<number>/<color>/<shape>/<fill>/*.png,
    where the folder names are keys of NUMBER_MAP / COLOR_MAP / SHAPE_MAP / FILL_MAP.

    Args:
        data_dirs: a single directory (str or Path) or a list of directories to scan.
            Non-existent directories are silently skipped.
        transform: optional callable applied to the PIL image in __getitem__.
    """

    def __init__(self, data_dirs, transform=None):
        # Accept a single path (str or Path) as well as a list of them.
        if isinstance(data_dirs, (str, Path)):
            data_dirs = [data_dirs]
        self.transform = transform
        self.samples: List[Tuple[Path, Dict[str, int]]] = []
        # Walk the fixed 4-level directory structure to find all labeled images.
        for data_dir in data_dirs:
            data_dir = Path(data_dir)
            if not data_dir.exists():
                continue
            count_before = len(self.samples)
            for number in NUMBER_MAP:
                for color in COLOR_MAP:
                    for shape in SHAPE_MAP:
                        for fill in FILL_MAP:
                            folder = data_dir / number / color / shape / fill
                            if not folder.exists():
                                continue
                            # sorted(): glob() order is filesystem-dependent, and a
                            # stable sample order is required for the seeded
                            # random_split() in main() to be reproducible across
                            # machines and runs.
                            for img_path in sorted(folder.glob("*.png")):
                                labels = {
                                    "number": NUMBER_MAP[number],
                                    "color": COLOR_MAP[color],
                                    "shape": SHAPE_MAP[shape],
                                    "fill": FILL_MAP[fill],
                                }
                                self.samples.append((img_path, labels))
            print(f"Loaded {len(self.samples) - count_before} samples from {data_dir}")
        print(f"Total: {len(self.samples)} samples")

    @staticmethod
    def _label_tensor(labels: Dict[str, int]) -> torch.Tensor:
        """Pack the per-attribute label dict into a (4,) long tensor
        ordered (number, color, shape, fill)."""
        return torch.tensor(
            [labels["number"], labels["color"], labels["shape"], labels["fill"]],
            dtype=torch.long,
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (image, label_tensor); image has `transform` applied if set."""
        img_path, labels = self.samples[idx]
        # Load image; convert to RGB so PNGs with alpha/palette are uniform.
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self._label_tensor(labels)

    def get_raw(self, idx):
        """Get raw PIL image and labels (no transform)."""
        img_path, labels = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        return image, self._label_tensor(labels)
| # === Model === | |
class SetCardClassifier(nn.Module):
    """
    Multi-head classifier for Set card attributes.

    A MobileNetV3-Small backbone (chosen for mobile deployment) feeds four
    parallel linear heads — one per attribute — each producing 3-way logits.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()
        # MobileNetV3-Small backbone, optionally initialized with ImageNet weights.
        backbone_weights = models.MobileNet_V3_Small_Weights.DEFAULT if pretrained else None
        self.backbone = models.mobilenet_v3_small(weights=backbone_weights)
        # Feature width taken from the stock classifier's first linear layer.
        feat_dim = self.backbone.classifier[0].in_features
        # Drop the stock ImageNet classifier so the backbone emits raw features.
        self.backbone.classifier = nn.Identity()
        # One 3-class linear head per attribute.
        self.heads = nn.ModuleDict(
            {attr: nn.Linear(feat_dim, 3) for attr in ("number", "color", "shape", "fill")}
        )

    def forward(self, x):
        """Return a dict mapping attribute name -> (batch, 3) logits."""
        feats = self.backbone(x)
        # ModuleDict preserves insertion order, so keys come out
        # as number, color, shape, fill.
        return {attr: head(feats) for attr, head in self.heads.items()}
| # === Training === | |
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one training pass over `loader`.

    Returns:
        (avg_loss, accuracies): mean per-batch loss, and a dict of per-head
        accuracy over all samples seen this epoch.
    """
    model.train()
    head_keys = ["number", "color", "shape", "fill"]
    fill_weight = 2.0  # 2x weight on the fill head to penalize fill mistakes harder
    running_loss = 0.0
    hits = dict.fromkeys(head_keys, 0)
    seen = 0
    for batch_images, batch_labels in tqdm(loader, desc="Training", leave=False):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        logits = model(batch_images)
        # Sum the per-head cross-entropy losses (fill head up-weighted).
        batch_loss = 0
        for col, key in enumerate(head_keys):
            per_head = criterion(logits[key], batch_labels[:, col])
            batch_loss = batch_loss + (fill_weight * per_head if key == "fill" else per_head)
            hits[key] += (logits[key].argmax(dim=1) == batch_labels[:, col]).sum().item()
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
        seen += batch_labels.size(0)
    return running_loss / len(loader), {key: n_right / seen for key, n_right in hits.items()}
def evaluate(model, loader, criterion, device):
    """Evaluate `model` on `loader` without gradient tracking.

    Returns:
        (avg_loss, accuracies): mean per-batch loss (all heads weighted equally),
        and a dict of per-head accuracy over the whole loader.
    """
    model.eval()
    head_keys = ["number", "color", "shape", "fill"]
    running_loss = 0.0
    hits = dict.fromkeys(head_keys, 0)
    seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in tqdm(loader, desc="Evaluating", leave=False):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)
            batch_loss = 0
            for col, key in enumerate(head_keys):
                batch_loss = batch_loss + criterion(logits[key], batch_labels[:, col])
                hits[key] += (logits[key].argmax(dim=1) == batch_labels[:, col]).sum().item()
            running_loss += batch_loss.item()
            seen += batch_labels.size(0)
    return running_loss / len(loader), {key: n_right / seen for key, n_right in hits.items()}
def main():
    """Train the multi-head Set card classifier and evaluate on a held-out test set.

    Side effects: saves the best checkpoint to WEIGHTS_DIR/"classifier_best.pt",
    writes WEIGHTS_DIR/"training_results.json", and prints progress to stdout.
    """
    # === Hyperparameters ===
    BATCH_SIZE = 32
    EPOCHS = 50
    LR = 1e-3
    VAL_SPLIT = 0.15   # fraction of the full dataset held out for validation
    TEST_SPLIT = 0.10  # fraction of the full dataset held out for final testing
    IMG_SIZE = 224     # input resolution fed to the backbone
    # Device preference: CUDA, then Apple-silicon MPS, then CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")
    # === Data transforms ===
    # Heavy train-time augmentation: the classifier will consume crops from an
    # upstream detector, so simulate crop jitter, arbitrary orientation,
    # perspective warp, and lighting variation.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)),  # Simulate imperfect detector crops
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(180),  # Cards can be any orientation
        transforms.RandomPerspective(distortion_scale=0.15, p=0.5),  # Perspective warp from detection
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
        transforms.RandomGrayscale(p=0.05),  # Force model to not rely solely on color for fill
        # NOTE(review): GaussianBlur here runs on EVERY sample — only sigma is
        # random, not whether it applies; the "~30% effective" claim looks
        # optimistic — confirm intent.
        transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),  # ~30% effective via random sigma
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet stats
    ])
    # Validation/test: deterministic resize + normalize only.
    val_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # === Load dataset (clean + synthetic crops) ===
    data_dirs = [DATA_DIR]
    if SYNTHETIC_DATA_DIR.exists():
        data_dirs.append(SYNTHETIC_DATA_DIR)
    full_dataset = SetCardDataset(data_dirs, transform=None)  # No transform yet; applied per-split below
    # Split into train/val/test. Seeded generator makes the split deterministic
    # for a given ordering of full_dataset.samples.
    total = len(full_dataset)
    test_size = int(total * TEST_SPLIT)
    val_size = int(total * VAL_SPLIT)
    train_size = total - val_size - test_size
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )
    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
    # Wrap with transform (can't change transform on Subset, so we wrap) —
    # this lets train and val/test use different pipelines over the same base dataset.
    class TransformDataset(torch.utils.data.Dataset):
        """Applies `transform` to images yielded by a wrapped Subset."""
        def __init__(self, subset, transform):
            self.subset = subset
            self.transform = transform
        def __len__(self):
            return len(self.subset)
        def __getitem__(self, idx):
            image, label = self.subset[idx]
            if self.transform:
                image = self.transform(image)
            return image, label
    train_dataset = TransformDataset(train_dataset, train_transform)
    val_dataset = TransformDataset(val_dataset, val_transform)
    test_dataset = TransformDataset(test_dataset, val_transform)
    # Use num_workers=0 on macOS to avoid shared memory issues
    import platform
    num_workers = 0 if platform.system() == "Darwin" else 4
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)
    # === Model ===
    model = SetCardClassifier(pretrained=True).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    # === Training loop ===
    best_val_acc = 0
    for epoch in range(EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        scheduler.step()
        # Average accuracy across all heads (each head weighted equally)
        avg_train_acc = sum(train_acc.values()) / 4
        avg_val_acc = sum(val_acc.values()) / 4
        print(f"Epoch {epoch+1}/{EPOCHS}")
        print(f" Train Loss: {train_loss:.4f}, Acc: {avg_train_acc:.4f}")
        print(f" Val Loss: {val_loss:.4f}, Acc: {avg_val_acc:.4f}")
        print(f" Val per-head: num={val_acc['number']:.3f} col={val_acc['color']:.3f} "
              f"shp={val_acc['shape']:.3f} fil={val_acc['fill']:.3f}")
        # Save best model (checkpoint keeps optimizer state so training can resume)
        if avg_val_acc > best_val_acc:
            best_val_acc = avg_val_acc
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_acc": val_acc,
            }, WEIGHTS_DIR / "classifier_best.pt")
            print(f" Saved new best model (val_acc={avg_val_acc:.4f})")
    # === Final evaluation on test set ===
    print("\n" + "="*50)
    print("Final Test Evaluation")
    print("="*50)
    # Load best model
    # NOTE(review): torch.load without map_location assumes the checkpoint was
    # written on the current device type — true here since training just ran
    # in-process, but worth a map_location=device if this is ever split out.
    checkpoint = torch.load(WEIGHTS_DIR / "classifier_best.pt")
    model.load_state_dict(checkpoint["model_state_dict"])
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    avg_test_acc = sum(test_acc.values()) / 4
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy (avg): {avg_test_acc:.4f}")
    print(f" Number: {test_acc['number']:.4f}")
    print(f" Color: {test_acc['color']:.4f}")
    print(f" Shape: {test_acc['shape']:.4f}")
    print(f" Fill: {test_acc['fill']:.4f}")
    # Save final results (floats/dicts only, so plain json.dump is safe)
    results = {
        "test_loss": test_loss,
        "test_accuracy": test_acc,
        "avg_test_accuracy": avg_test_acc,
        "train_size": train_size,
        "val_size": val_size,
        "test_size": test_size,
    }
    with open(WEIGHTS_DIR / "training_results.json", "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nModel saved to {WEIGHTS_DIR / 'classifier_best.pt'}")
    print(f"Results saved to {WEIGHTS_DIR / 'training_results.json'}")


if __name__ == "__main__":
    main()