| |
| """Train a cat classifier to distinguish Lucy from Madelaine.""" |
|
|
| import os |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader, WeightedRandomSampler |
| from torchvision import datasets, transforms, models |
| from sklearn.model_selection import train_test_split |
| from tqdm import tqdm |
| import shutil |
| from pathlib import Path |
|
|
| |
# Dataset location, checkpoint path, and training hyperparameters.
DATA_DIR = "cats"
MODEL_PATH = "cat_classifier.pth"
BATCH_SIZE = 16
EPOCHS = 20
LEARNING_RATE = 1e-4
IMAGE_SIZE = 224

# Pick the compute backend: MPS is tried first, then CUDA, with CPU as the
# fallback when neither accelerator reports itself available.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
def create_train_val_split(data_dir: str, val_ratio: float = 0.2, output_dir: str = "data"):
    """Copy each class's images into freshly built train/ and val/ trees.

    Every immediate, non-hidden subdirectory of ``data_dir`` is treated as one
    class. Each class is split independently, so both output sets keep roughly
    the same per-class proportions as the input.

    Args:
        data_dir: Directory containing one subdirectory of images per class.
        val_ratio: Passed to ``train_test_split`` as ``test_size``.
        output_dir: Parent directory of the generated ``train`` and ``val``
            trees. Any existing trees there are deleted and rebuilt.

    Returns:
        A ``(train_dir, val_dir)`` tuple of path strings.

    Raises:
        ValueError: If ``data_dir`` holds no class directories, or a class has
            fewer than two images (it could not appear in both splits).
    """
    image_suffixes = {".jpeg", ".jpg", ".png"}
    source_root = Path(data_dir)
    train_dir = Path(output_dir) / "train"
    val_dir = Path(output_dir) / "val"

    classes = sorted(
        d.name for d in source_root.iterdir()
        if d.is_dir() and not d.name.startswith(".")
    )
    if not classes:
        raise ValueError(f"No class directories found in {data_dir!r}")

    # Collect and validate every class before touching the output trees, so a
    # bad input never destroys a previously generated split.
    images_by_class = {}
    for cls in classes:
        # Sorting makes the fixed random_state below yield the same split on
        # every run; directory listing order is filesystem-dependent.
        # Suffixes are compared case-insensitively so e.g. ".JPG" is kept.
        images = sorted(
            p for p in (source_root / cls).iterdir()
            if p.is_file() and p.suffix.lower() in image_suffixes
        )
        if len(images) < 2:
            raise ValueError(
                f"Class {cls!r} has {len(images)} image(s); "
                "at least 2 are needed to build a train/val split"
            )
        images_by_class[cls] = images

    # Start from a clean slate so files from an earlier split cannot leak in.
    for stale_dir in (train_dir, val_dir):
        if stale_dir.exists():
            shutil.rmtree(stale_dir)

    for cls, images in images_by_class.items():
        train_imgs, val_imgs = train_test_split(
            images, test_size=val_ratio, random_state=42
        )

        for split_dir, split_imgs in ((train_dir, train_imgs), (val_dir, val_imgs)):
            dest_dir = split_dir / cls
            dest_dir.mkdir(parents=True, exist_ok=True)
            for img in split_imgs:
                shutil.copy2(img, dest_dir / img.name)

        print(f"{cls}: {len(train_imgs)} train, {len(val_imgs)} val")

    return str(train_dir), str(val_dir)
|
|
|
|
def get_data_loaders(train_dir: str, val_dir: str):
    """Build the training and validation loaders.

    Training images are augmented and drawn through a class-balanced sampler;
    validation images are only resized and normalised.

    Args:
        train_dir: Root of the training ``ImageFolder`` tree.
        val_dir: Root of the validation ``ImageFolder`` tree.

    Returns:
        ``(train_loader, val_loader, class_to_idx)`` where ``class_to_idx`` is
        the training dataset's class-name-to-label mapping.
    """
    # Both pipelines share the same per-channel mean/std normalisation.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    train_transform = transforms.Compose([
        # Resize slightly larger than the target, then take a random crop.
        transforms.Resize((IMAGE_SIZE + 32, IMAGE_SIZE + 32)),
        transforms.RandomCrop(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    val_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
    val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)

    # Size the tally from the classes actually discovered rather than
    # assuming exactly two, so datasets with more classes do not overflow it.
    class_counts = [0] * len(train_dataset.classes)
    for _, label in train_dataset.samples:
        class_counts[label] += 1

    # Weight each sample by the inverse of its class frequency so that every
    # class is drawn with roughly equal probability.
    weights = [1.0 / class_counts[label] for _, label in train_dataset.samples]
    sampler = WeightedRandomSampler(weights, len(weights))

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0
    )
    val_loader = DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0
    )

    print(f"\nClass mapping: {train_dataset.class_to_idx}")
    print(f"Class counts: {dict(zip(train_dataset.classes, class_counts))}")

    return train_loader, val_loader, train_dataset.class_to_idx
|
|
|
|
def create_model(num_classes: int = 2):
    """Build a pretrained EfficientNet-B0 with a new ``num_classes``-way head.

    Parameters in the first five entries of ``features`` are frozen, and the
    stock classifier is swapped for dropout followed by a linear layer. The
    model is moved to ``DEVICE`` before being returned.
    """
    net = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)

    # Read the head's input width before the head is replaced.
    head_in = net.classifier[1].in_features

    # Keep the early feature stages fixed during fine-tuning.
    for frozen in net.features[:5].parameters():
        frozen.requires_grad = False

    head = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(head_in, num_classes),
    )
    net.classifier = head

    return net.to(DEVICE)
|
|
|
|
def train_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean_loss, accuracy)`` computed over every sample seen.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc="Training", leave=False)
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Scale by batch size so the final division gives a per-sample figure
        # (assumes criterion returns a per-batch mean).
        loss_sum += batch_loss.item() * batch_images.size(0)
        predicted = logits.max(1)[1]
        n_correct += (predicted == batch_labels).sum().item()
        n_seen += batch_labels.size(0)

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy
|
|
|
|
def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` with gradient tracking disabled.

    Returns:
        ``(mean_loss, accuracy)`` computed over every sample seen.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        progress = tqdm(loader, desc="Validating", leave=False)
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)

            # Scale by batch size so the final division gives a per-sample
            # figure (assumes criterion returns a per-batch mean).
            loss_sum += batch_loss.item() * batch_images.size(0)
            predicted = logits.max(1)[1]
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy
|
|
|
|
def main():
    """Split the data, fine-tune the classifier, and checkpoint the best epoch.

    The checkpoint written to ``MODEL_PATH`` holds the model's state dict and
    the class-name-to-label mapping. It is rewritten whenever validation
    accuracy improves on the best value seen so far.
    """
    print(f"Using device: {DEVICE}")

    print("\nSplitting data into train/val sets...")
    train_dir, val_dir = create_train_val_split(DATA_DIR)
    train_loader, val_loader, class_to_idx = get_data_loaders(train_dir, val_dir)

    print("\nCreating model...")
    # Size the head from the classes actually found on disk instead of
    # relying on create_model's default.
    model = create_model(num_classes=len(class_to_idx))
    criterion = nn.CrossEntropyLoss()
    # Only hand the optimizer the parameters that are not frozen.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=3, factor=0.5
    )

    best_val_acc = 0.0
    checkpoint_saved = False
    print(f"\nTraining for {EPOCHS} epochs...\n")

    for epoch in range(EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc = validate(model, val_loader, criterion)
        scheduler.step(val_loss)

        print(f"Epoch {epoch+1:2d}/{EPOCHS} | "
              f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

        # Always save the first epoch so a checkpoint exists even if the
        # validation accuracy never rises above zero.
        if not checkpoint_saved or val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'class_to_idx': class_to_idx,
            }, MODEL_PATH)
            checkpoint_saved = True
            print(f"  -> Saved best model (val_acc: {val_acc:.4f})")

    print(f"\nTraining complete! Best validation accuracy: {best_val_acc:.4f}")
    if checkpoint_saved:
        print(f"Model saved to: {MODEL_PATH}")
    else:
        # Only reachable when no epoch ran (EPOCHS == 0).
        print("No checkpoint was saved.")


if __name__ == "__main__":
    main()
|
|