{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "id": "o_xNUk10GCIa" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "CEfUc-G5GmJm" }, "outputs": [], "source": [ "CONFIG = {\n", " \"batch_size\": 64,\n", " \"epochs\": 50,\n", " \"lr\": 0.003,\n", " \"weight_decay\": 0.0001,\n", " \"label_smoothing\": 0.1,\n", " \"num_workers\": 2,\n", " \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n", " \"seed\": 23,\n", "}" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "SaIhfZfCG0Wn", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "b59d824d-81be-4463-8fd7-0910b78acee2" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "100%|██████████| 9.91M/9.91M [00:00<00:00, 20.3MB/s]\n", "100%|██████████| 28.9k/28.9k [00:00<00:00, 508kB/s]\n", "100%|██████████| 1.65M/1.65M [00:00<00:00, 4.58MB/s]\n", "100%|██████████| 4.54k/4.54k [00:00<00:00, 9.38MB/s]\n" ] } ], "source": [ "train_transform = transforms.Compose([\n", " transforms.RandomRotation(10),\n", " transforms.RandomAffine(\n", " degrees=0,\n", " translate=(0.1, 0.1),\n", " scale=(0.9, 1.1),\n", " shear=5\n", " ),\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,)),\n", "])\n", "\n", "test_transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,)),\n", "])\n", "\n", "train_dataset = datasets.MNIST(root=\"./data\", train=True, download=True, transform=train_transform)\n", "test_dataset = datasets.MNIST(root=\"./data\", train=False, download=True, transform=test_transform)\n", "\n", "train_loader = DataLoader(train_dataset, batch_size=CONFIG[\"batch_size\"], shuffle=True, num_workers=CONFIG[\"num_workers\"])\n", "test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=CONFIG[\"num_workers\"])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "3SHyrmaMHCIJ" }, "outputs": [], "source": [ "class Model(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", "\n", " self.conv_layers = nn.Sequential(\n", " # Block 1: 1 -> 32 channels, 28x28 -> 14x14\n", " nn.Conv2d(1, 32, kernel_size=3, padding=1),\n", " nn.BatchNorm2d(32),\n", " nn.ReLU(),\n", " nn.MaxPool2d(2),\n", " nn.Dropout2d(0.25),\n", "\n", " # Block 2: 32 -> 64 channels, 14x14 -> 7x7\n", " nn.Conv2d(32, 64, kernel_size=3, padding=1),\n", " nn.BatchNorm2d(64),\n", " nn.ReLU(),\n", " nn.MaxPool2d(2),\n", " nn.Dropout2d(0.25),\n", "\n", " # Block 3: 64 -> 128 channels, 7x7 -> 3x3\n", " nn.Conv2d(64, 128, kernel_size=3, padding=1),\n", " nn.BatchNorm2d(128),\n", " nn.ReLU(),\n", " nn.MaxPool2d(2),\n", " nn.Dropout2d(0.25),\n", "\n", " # Block 3: 128 -> 256 channels, 3x3 -> 1x1\n", " nn.Conv2d(128, 256, kernel_size=1),\n", " nn.BatchNorm2d(256),\n", " nn.ReLU(),\n", " nn.MaxPool2d(2),\n", " nn.Dropout2d(0.25),\n", " )\n", "\n", " self.fc_layers = nn.Sequential(\n", " nn.Flatten(), # 256 * 1 * 1 = 256\n", " nn.Linear(256 * 1 * 1, 128),\n", " nn.ReLU(),\n", " nn.Dropout(0.25),\n", " nn.Linear(128, 10)\n", " )\n", "\n", " def forward(self, x):\n", " x = self.conv_layers(x)\n", " x = self.fc_layers(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "rEp8D2U8Ke6d", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "92157bae-28fb-4aa0-cb5a-75dc935adbe9" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Model parameters: 160,842\n" ] } ], "source": [ "model = Model().to(CONFIG[\"device\"])\n", "total_params = sum(p.numel() for p in model.parameters())\n", "print(f\"Model parameters: {total_params:,}\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "lL1TZN8MJoun" }, "outputs": [], "source": [ "optimizer = optim.AdamW(\n", " model.parameters(),\n", " lr=CONFIG[\"lr\"],\n", " weight_decay=CONFIG[\"weight_decay\"],\n", ")\n", "\n", "# Warmup for 5 epochs, then cosine decay\n", "scheduler = optim.lr_scheduler.OneCycleLR(\n", " optimizer,\n", " max_lr=CONFIG[\"lr\"],\n", " steps_per_epoch=len(train_loader),\n", " epochs=CONFIG[\"epochs\"],\n", " pct_start=0.1, # 10% warmup\n", " anneal_strategy=\"cos\",\n", ")\n", "\n", "criterion = nn.CrossEntropyLoss(label_smoothing=CONFIG[\"label_smoothing\"])" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "k_RjpCkLLXGj" }, "outputs": [], "source": [ "def train_epoch(model, loader, optimizer, scheduler, criterion, device):\n", " model.train()\n", " total_loss, correct, total = 0.0, 0, 0\n", "\n", " for images, labels in loader:\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " optimizer.zero_grad()\n", " outputs = model(images)\n", " loss = criterion(outputs, labels)\n", " loss.backward()\n", " nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n", " optimizer.step()\n", " scheduler.step()\n", "\n", " total_loss += loss.item() * images.size(0)\n", " correct += (outputs.argmax(1) == labels).sum().item()\n", " total += images.size(0)\n", "\n", " return total_loss / total, correct / total\n", "\n", "\n", "def evaluate(model, loader, device, tta=False):\n", " \"\"\"Evaluate with optional Test-Time Augmentation.\"\"\"\n", " model.eval()\n", " correct, total = 0, 0\n", "\n", " tta_transforms = [\n", " transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),\n", " transforms.Compose([transforms.RandomRotation(5), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),\n", " transforms.Compose([transforms.RandomAffine(0, translate=(0.05, 0.05)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),\n", " ]\n", "\n", " with torch.no_grad():\n", " for images, labels in loader:\n", " images, labels = images.to(device), labels.to(device)\n", " outputs = model(images)\n", " correct += (outputs.argmax(1) == labels).sum().item()\n", " total += images.size(0)\n", "\n", " return correct / total" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "aC2r9yTUO6l9", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "7c804914-86c9-4863-aa04-f758d7b879c3" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "============================================================\n", " Epoch Train Loss Train Acc Test Acc LR\n", "============================================================\n", " 1 1.5776 52.87% 95.11% 0.000395 ✓ BEST\n", " 2 0.8644 87.52% 97.11% 0.001115 ✓ BEST\n", " 3 0.7545 92.05% 98.13% 0.002006 ✓ BEST\n", " 4 0.7104 93.52% 98.53% 0.002725 ✓ BEST\n", " 5 0.6858 94.29% 98.75% 0.003000 ✓ BEST\n", " 6 0.6660 95.03% 98.78% 0.002996 ✓ BEST\n", " 7 0.6530 95.43% 98.84% 0.002985 ✓ BEST\n", " 8 0.6437 95.54% 98.91% 0.002967 ✓ BEST\n", " 9 0.6410 95.56% 99.14% 0.002942 ✓ BEST\n", " 10 0.6323 95.84% 99.07% 0.002910\n", " 11 0.6307 95.84% 99.10% 0.002870\n", " 12 0.6261 95.96% 98.97% 0.002824\n", " 13 0.6232 96.08% 99.05% 0.002772\n", " 14 0.6203 96.12% 99.06% 0.002713\n", " 15 0.6145 96.34% 99.02% 0.002649\n", " 16 0.6124 96.50% 99.22% 0.002579 ✓ BEST\n", " 17 0.6103 96.47% 99.09% 0.002504\n", " 18 0.6075 96.63% 99.12% 0.002423\n", " 19 0.6043 96.70% 99.09% 0.002339\n", " 20 0.6038 96.70% 99.17% 0.002250\n", " 21 0.6021 96.78% 99.27% 0.002157 ✓ BEST\n", " 22 0.6010 96.78% 99.20% 0.002062\n", " 23 0.5994 96.89% 99.24% 0.001963\n", " 24 0.5955 97.02% 99.35% 0.001863 ✓ BEST\n", " 25 0.5961 96.92% 99.21% 0.001760\n", " 26 0.5919 97.17% 99.31% 0.001657\n", " 27 0.5901 97.25% 99.30% 0.001552\n", " 28 0.5897 97.18% 99.27% 0.001448\n", " 29 0.5879 97.22% 99.27% 0.001343\n", " 30 0.5891 97.17% 99.26% 0.001239\n", " 31 0.5841 97.32% 99.25% 0.001137\n", " 32 0.5839 97.36% 99.29% 0.001036\n", " 33 0.5829 97.32% 99.37% 0.000938 ✓ BEST\n", " 34 0.5800 97.46% 99.36% 0.000842\n", " 35 0.5815 97.42% 99.38% 0.000750 ✓ BEST\n", " 36 0.5778 97.52% 99.40% 0.000661 ✓ BEST\n", " 37 0.5776 97.51% 99.37% 0.000576\n", " 38 0.5770 97.60% 99.41% 0.000496 ✓ BEST\n", " 39 0.5765 97.57% 99.38% 0.000421\n", " 40 0.5758 97.57% 99.43% 0.000351 ✓ BEST\n", " 41 0.5741 97.67% 99.41% 0.000286\n", " 42 0.5741 97.61% 99.38% 0.000228\n", " 43 0.5728 97.69% 99.40% 0.000176\n", " 44 0.5731 97.71% 99.39% 0.000130\n", " 45 0.5710 97.75% 99.38% 0.000090\n", " 46 0.5700 97.79% 99.40% 0.000058\n", " 47 0.5718 97.70% 99.38% 0.000033\n", " 48 0.5712 97.77% 99.38% 0.000015\n", " 49 0.5699 97.77% 99.38% 0.000004\n", " 50 0.5717 97.70% 99.39% 0.000000\n", "============================================================\n", "\n", "Best test accuracy: 99.43%\n" ] } ], "source": [ "best_acc = 0.0\n", "history = {\"train_loss\": [], \"train_acc\": [], \"test_acc\": []}\n", "\n", "print(\"\\n\" + \"=\"*60)\n", "print(f\"{'Epoch':>6} {'Train Loss':>10} {'Train Acc':>10} {'Test Acc':>10} {'LR':>10}\")\n", "print(\"=\"*60)\n", "\n", "for epoch in range(1, CONFIG[\"epochs\"] + 1):\n", " train_loss, train_acc = train_epoch(\n", " model, train_loader, optimizer, scheduler, criterion, CONFIG[\"device\"]\n", " )\n", " test_acc = evaluate(model, test_loader, CONFIG[\"device\"])\n", "\n", " history[\"train_loss\"].append(train_loss)\n", " history[\"train_acc\"].append(train_acc)\n", " history[\"test_acc\"].append(test_acc)\n", "\n", " current_lr = scheduler.get_last_lr()[0]\n", "\n", " if test_acc > best_acc:\n", " best_acc = test_acc\n", " torch.save(model.state_dict(), \"mnist_best.pth\")\n", " marker = \" ✓ BEST\"\n", " else:\n", " marker = \"\"\n", "\n", " print(f\"{epoch:>6} {train_loss:>10.4f} {train_acc*100:>9.2f}% {test_acc*100:>9.2f}% {current_lr:>10.6f}{marker}\")\n", "\n", "print(\"=\"*60)\n", "print(f\"\\nBest test accuracy: {best_acc*100:.2f}%\")" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "5QwlbG2YQ8Q4", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "ae219648-1b25-4062-97f2-6dec61a96b17" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Loading best model for final evaluation...\n", "Final test accuracy: 99.43%\n" ] } ], "source": [ "print(\"\\nLoading best model for final evaluation...\")\n", "model.load_state_dict(torch.load(\"mnist_best.pth\", map_location=CONFIG[\"device\"]))\n", "final_acc = evaluate(model, test_loader, CONFIG[\"device\"])\n", "print(f\"Final test accuracy: {final_acc*100:.2f}%\")" ] }, { "cell_type": "code", "source": [ "def confusion_matrix(model, loader, device, num_classes=10):\n", " model.eval()\n", " matrix = np.zeros((num_classes, num_classes), dtype=int)\n", " with torch.no_grad():\n", " for images, labels in loader:\n", " images = images.to(device)\n", " preds = model(images).argmax(1).cpu().numpy()\n", " for true, pred in zip(labels.numpy(), preds):\n", " matrix[true][pred] += 1\n", " return matrix\n", "\n", "cm = confusion_matrix(model, test_loader, CONFIG[\"device\"])\n", "print(\"\\nConfusion Matrix (rows=true, cols=predicted):\")\n", "print(\" \" + \" \".join(f\"{i:4}\" for i in range(10)))\n", "for i, row in enumerate(cm):\n", " errors = sum(row) - row[i]\n", " print(f\"{i}: \" + \" \".join(f\"{v:4}\" for v in row) + f\" [{errors} errors]\")\n", "\n", "per_class_acc = cm.diagonal() / cm.sum(axis=1)\n", "print(\"\\nPer-class accuracy:\")\n", "for i, acc in enumerate(per_class_acc):\n", " print(f\" Digit {i}: {acc*100:.1f}%\")" ], "metadata": { "id": "tv567D7c8tT8", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "9e585bd1-8862-428a-aa26-8886e087541b" }, "execution_count": 10, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Confusion Matrix (rows=true, cols=predicted):\n", " 0 1 2 3 4 5 6 7 8 9\n", "0: 980 0 0 0 0 0 0 0 0 0 [0 errors]\n", "1: 0 1132 0 1 0 1 0 1 0 0 [3 errors]\n", "2: 1 0 1025 2 0 0 1 3 0 0 [7 errors]\n", "3: 0 0 0 1008 0 1 0 0 1 0 [2 errors]\n", "4: 0 0 0 0 976 0 2 0 0 4 [6 errors]\n", "5: 1 0 0 3 0 885 2 1 0 0 [7 errors]\n", "6: 2 1 0 0 2 3 949 0 1 0 [9 errors]\n", "7: 0 4 2 0 0 1 0 1020 0 1 [8 errors]\n", "8: 0 0 2 1 0 1 0 0 968 2 [6 errors]\n", "9: 0 0 0 0 4 1 0 3 1 1000 [9 errors]\n", "\n", "Per-class accuracy:\n", " Digit 0: 100.0%\n", " Digit 1: 99.7%\n", " Digit 2: 99.3%\n", " Digit 3: 99.8%\n", " Digit 4: 99.4%\n", " Digit 5: 99.2%\n", " Digit 6: 99.1%\n", " Digit 7: 99.2%\n", " Digit 8: 99.4%\n", " Digit 9: 99.1%\n" ] } ] } ], "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 0 }