#!/usr/bin/env python3
"""Train a cat classifier to distinguish Lucy from Madelaine."""

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms, models
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import shutil
from pathlib import Path

# Config
DATA_DIR = "cats"
MODEL_PATH = "cat_classifier.pth"
BATCH_SIZE = 16
EPOCHS = 20
LEARNING_RATE = 1e-4
IMAGE_SIZE = 224
DEVICE = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
# Suffixes of files treated as images in each class directory. They are
# compared case-insensitively, so camera exports such as IMG_0001.JPG count.
IMAGE_EXTENSIONS = {".jpeg", ".jpg", ".png"}


def create_train_val_split(data_dir: str, val_ratio: float = 0.2):
    """Split each class folder of ``data_dir`` into train/val copies.

    The split is done per class, so every class keeps roughly the same
    train/val ratio. Any existing ``data/train`` and ``data/val`` trees are
    deleted first; images are then copied (not moved) into them.

    Args:
        data_dir: Directory containing one sub-directory per class.
        val_ratio: Fraction of each class's images placed in the val split.

    Returns:
        ``(train_dir, val_dir)`` as strings.

    Raises:
        ValueError: If a class directory holds fewer than two images, since it
            could not contribute to both splits.
    """
    train_dir = Path("data/train")
    val_dir = Path("data/val")

    # Remove old splits so images from earlier runs cannot leak into this one.
    for d in (train_dir, val_dir):
        if d.exists():
            shutil.rmtree(d)

    # Class names come from the directory structure; hidden dirs are skipped.
    classes = sorted(
        d.name for d in Path(data_dir).iterdir()
        if d.is_dir() and not d.name.startswith(".")
    )

    for cls in classes:
        cls_dir = Path(data_dir) / cls
        # Sort the listing so the seeded split below is reproducible regardless
        # of filesystem order. Skip hidden files (e.g. macOS "._" AppleDouble
        # files), which match image suffixes but are not loadable images.
        images = sorted(
            p for p in cls_dir.iterdir()
            if p.is_file()
            and not p.name.startswith(".")
            and p.suffix.lower() in IMAGE_EXTENSIONS
        )
        if len(images) < 2:
            raise ValueError(
                f"Class '{cls}' in {cls_dir} has {len(images)} image(s); "
                "at least 2 are needed to populate both train and val splits."
            )

        train_imgs, val_imgs = train_test_split(images, test_size=val_ratio, random_state=42)

        # Copy to train/val directories
        for split_dir, split_imgs in ((train_dir, train_imgs), (val_dir, val_imgs)):
            dest_dir = split_dir / cls
            dest_dir.mkdir(parents=True, exist_ok=True)
            for img in split_imgs:
                shutil.copy2(img, dest_dir / img.name)

        print(f"{cls}: {len(train_imgs)} train, {len(val_imgs)} val")

    return str(train_dir), str(val_dir)


def get_data_loaders(train_dir: str, val_dir: str):
    """Create data loaders with augmentation and class balancing.

    Training images are augmented (random crop, flip, rotation, colour
    jitter). Validation images are only resized. Both are normalized with the
    ImageNet mean/std the pretrained backbone expects. The training loader
    draws samples with inverse-class-frequency weights, so each class is seen
    about equally often.

    Args:
        train_dir: ImageFolder-style directory for training images.
        val_dir: ImageFolder-style directory for validation images.

    Returns:
        ``(train_loader, val_loader, class_to_idx)``.

    Raises:
        ValueError: If the train and val directories yield different
            class-to-index mappings.
    """
    train_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE + 32, IMAGE_SIZE + 32)),
        transforms.RandomCrop(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
    val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)

    # Labels are indices into class_to_idx. A mismatch (e.g. a class folder
    # missing from one split) would silently mislabel the validation data.
    if val_dataset.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f"Train/val class mismatch: {train_dataset.class_to_idx} "
            f"vs {val_dataset.class_to_idx}"
        )

    # Compute class weights for balanced sampling; sized from the discovered
    # classes rather than assuming exactly two.
    class_counts = [0] * len(train_dataset.classes)
    for _, label in train_dataset.samples:
        class_counts[label] += 1

    weights = [1.0 / class_counts[label] for _, label in train_dataset.samples]
    # Draws len(weights) samples per epoch, with replacement (sampler default).
    sampler = WeightedRandomSampler(weights, len(weights))

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    print(f"\nClass mapping: {train_dataset.class_to_idx}")
    print(f"Class counts: {dict(zip(train_dataset.classes, class_counts))}")

    return train_loader, val_loader, train_dataset.class_to_idx


def create_model(num_classes: int = 2):
    """Create an EfficientNet-B0 model with a custom classifier head.

    The backbone starts from torchvision's default pretrained weights. The
    first five feature stages (``model.features[:5]``) are frozen, and the
    classifier is replaced by dropout plus a linear layer with
    ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits.

    Returns:
        The model, moved to ``DEVICE``.
    """
    model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)

    # Freeze early layers: they hold generic low-level features, and
    # freezing them limits overfitting on a small dataset.
    for param in model.features[:5].parameters():
        param.requires_grad = False

    # Replace classifier
    num_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(num_features, num_classes)
    )

    return model.to(DEVICE)


def train_epoch(model, loader, criterion, optimizer):
    """Train for one epoch.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen in the epoch.
    """
    model.train()
    total_loss, correct, total = 0, 0, 0

    for images, labels in tqdm(loader, desc="Training", leave=False):
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns a batch mean; scale by batch size so the final
        # average is per sample even when the last batch is smaller.
        total_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    return total_loss / total, correct / total


def validate(model, loader, criterion):
    """Evaluate the model without gradient tracking.

    Returns:
        ``(mean_loss, accuracy)`` over the whole loader.
    """
    model.eval()
    total_loss, correct, total = 0, 0, 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validating", leave=False):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

    return total_loss / total, correct / total


def main():
    """Split the data, train the model, and checkpoint the best epoch."""
    print(f"Using device: {DEVICE}")

    # Prepare data
    print("\nSplitting data into train/val sets...")
    train_dir, val_dir = create_train_val_split(DATA_DIR)
    train_loader, val_loader, class_to_idx = get_data_loaders(train_dir, val_dir)

    # Create model sized to the classes actually found on disk.
    print("\nCreating model...")
    model = create_model(num_classes=len(class_to_idx))

    criterion = nn.CrossEntropyLoss()
    # Only optimize parameters that were left trainable by create_model.
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

    # Training loop. best_val_acc starts below any real accuracy so the first
    # epoch is always checkpointed; ties in accuracy are broken by val loss.
    best_val_acc = -1.0
    best_val_loss = float("inf")
    print(f"\nTraining for {EPOCHS} epochs...\n")

    for epoch in range(EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc = validate(model, val_loader, criterion)
        scheduler.step(val_loss)

        print(f"Epoch {epoch+1:2d}/{EPOCHS} | "
              f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

        improved = val_acc > best_val_acc or (
            val_acc == best_val_acc and val_loss < best_val_loss
        )
        if improved:
            best_val_acc = val_acc
            best_val_loss = val_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'class_to_idx': class_to_idx,
            }, MODEL_PATH)
            print(f"  -> Saved best model (val_acc: {val_acc:.4f})")

    print(f"\nTraining complete! Best validation accuracy: {best_val_acc:.4f}")
    print(f"Model saved to: {MODEL_PATH}")


if __name__ == "__main__":
    main()