{ "cells": [ { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "ename": "ModuleNotFoundError", "evalue": "No module named 'torch'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m/Users/johnnydevriese/projects/jupyter/m1_gpu_pytorch.ipynb Cell 1'\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mtorch\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mtorch\u001b[39;00m \u001b[39mimport\u001b[39;00m nn\n\u001b[1;32m 3\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mtorch\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mutils\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mdata\u001b[39;00m \u001b[39mimport\u001b[39;00m DataLoader\n", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch'" ] } ], "source": [ "import torch\n", "from torch import nn\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision.transforms import ToTensor" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/nightly/cpu\n", "Collecting torch\n", " Downloading https://download.pytorch.org/whl/nightly/cpu/torch-1.13.0.dev20220703-cp310-none-macosx_11_0_arm64.whl (50.1 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.1/50.1 MB\u001b[0m \u001b[31m30.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting torchvision\n", " Downloading https://download.pytorch.org/whl/nightly/cpu/torchvision-0.14.0.dev20220703-cp310-cp310-macosx_11_0_arm64.whl (1.4 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.4/1.4 MB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hCollecting torchaudio\n", " Downloading https://download.pytorch.org/whl/nightly/cpu/torchaudio-0.14.0.dev20220603-cp310-cp310-macosx_11_0_arm64.whl (2.7 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.7/2.7 MB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m0:01\u001b[0m\n", "\u001b[?25hCollecting typing-extensions\n", " Using cached typing_extensions-4.3.0-py3-none-any.whl (25 kB)\n", "Collecting pillow!=8.3.*,>=5.3.0\n", " Downloading Pillow-9.2.0-cp310-cp310-macosx_11_0_arm64.whl (2.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.8/2.8 MB\u001b[0m \u001b[31m14.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting requests\n", " Using cached requests-2.28.1-py3-none-any.whl (62 kB)\n", "Collecting numpy\n", " Downloading numpy-1.23.0-cp310-cp310-macosx_11_0_arm64.whl (13.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.3/13.3 MB\u001b[0m \u001b[31m44.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting certifi>=2017.4.17\n", " Downloading certifi-2022.6.15-py3-none-any.whl (160 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m160.2/160.2 kB\u001b[0m \u001b[31m6.9 
MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting idna<4,>=2.5\n", " Using cached idna-3.3-py3-none-any.whl (61 kB)\n", "Collecting charset-normalizer<3,>=2\n", " Using cached charset_normalizer-2.1.0-py3-none-any.whl (39 kB)\n", "Collecting urllib3<1.27,>=1.21.1\n", " Using cached urllib3-1.26.9-py2.py3-none-any.whl (138 kB)\n", "Installing collected packages: urllib3, typing-extensions, pillow, numpy, idna, charset-normalizer, certifi, torch, requests, torchvision, torchaudio\n", "Successfully installed certifi-2022.6.15 charset-normalizer-2.1.0 idna-3.3 numpy-1.23.0 pillow-9.2.0 requests-2.28.1 torch-1.13.0.dev20220703 torchaudio-0.14.0.dev20220603 torchvision-0.14.0.dev20220703 typing-extensions-4.3.0 urllib3-1.26.9\n" ] } ], "source": [ "! pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch 1.13.0.dev20220703\n", "device mps\n", "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100.0%\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting data/cifar-10-python.tar.gz to data\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Downloading: \"https://github.com/pytorch/vision/zipball/v0.11.0\" to /Users/johnnydevriese/.cache/torch/hub/v0.11.0.zip\n", "/Users/johnnydevriese/miniforge3/envs/pytorch-nightly/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.\n", " warnings.warn(\n", "/Users/johnnydevriese/miniforge3/envs/pytorch-nightly/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. 
The current behavior is equivalent to passing `weights=None`.\n", " warnings.warn(msg)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/001 | Batch 0000/1406 | Loss: 2.5887\n", "Epoch: 001/001 | Batch 0100/1406 | Loss: 2.4339\n", "Epoch: 001/001 | Batch 0200/1406 | Loss: 2.0386\n", "Epoch: 001/001 | Batch 0300/1406 | Loss: 2.0561\n", "Epoch: 001/001 | Batch 0400/1406 | Loss: 2.1730\n", "Epoch: 001/001 | Batch 0500/1406 | Loss: 2.1067\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[1;32m/Users/johnnydevriese/projects/jupyter/m1_gpu_pytorch.ipynb Cell 3'\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 257\u001b[0m model \u001b[39m=\u001b[39m model\u001b[39m.\u001b[39mto(DEVICE)\n\u001b[1;32m 259\u001b[0m optimizer \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39moptim\u001b[39m.\u001b[39mAdam(model\u001b[39m.\u001b[39mparameters(), lr\u001b[39m=\u001b[39m\u001b[39m0.0005\u001b[39m)\n\u001b[0;32m--> 261\u001b[0m minibatch_loss_list, train_acc_list, valid_acc_list \u001b[39m=\u001b[39m train_classifier_simple_v2(\n\u001b[1;32m 262\u001b[0m model\u001b[39m=\u001b[39;49mmodel,\n\u001b[1;32m 263\u001b[0m num_epochs\u001b[39m=\u001b[39;49mNUM_EPOCHS,\n\u001b[1;32m 264\u001b[0m train_loader\u001b[39m=\u001b[39;49mtrain_loader,\n\u001b[1;32m 265\u001b[0m valid_loader\u001b[39m=\u001b[39;49mvalid_loader,\n\u001b[1;32m 266\u001b[0m test_loader\u001b[39m=\u001b[39;49mtest_loader,\n\u001b[1;32m 267\u001b[0m optimizer\u001b[39m=\u001b[39;49moptimizer,\n\u001b[1;32m 268\u001b[0m best_model_save_path\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 269\u001b[0m device\u001b[39m=\u001b[39;49mDEVICE,\n\u001b[1;32m 270\u001b[0m scheduler_on\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mvalid_acc\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 271\u001b[0m logging_interval\u001b[39m=\u001b[39;49m\u001b[39m100\u001b[39;49m,\n\u001b[1;32m 272\u001b[0m )\n", "\u001b[1;32m/Users/johnnydevriese/projects/jupyter/m1_gpu_pytorch.ipynb Cell 3'\u001b[0m in \u001b[0;36mtrain_classifier_simple_v2\u001b[0;34m(model, num_epochs, train_loader, valid_loader, test_loader, optimizer, device, logging_interval, best_model_save_path, scheduler, skip_train_acc, scheduler_on)\u001b[0m\n\u001b[1;32m 67\u001b[0m targets \u001b[39m=\u001b[39m targets\u001b[39m.\u001b[39mto(device)\n\u001b[1;32m 69\u001b[0m \u001b[39m# ## FORWARD AND BACK PROP\u001b[39;00m\n\u001b[0;32m---> 70\u001b[0m logits \u001b[39m=\u001b[39m model(features)\n\u001b[1;32m 71\u001b[0m loss \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mnn\u001b[39m.\u001b[39mfunctional\u001b[39m.\u001b[39mcross_entropy(logits, targets)\n\u001b[1;32m 72\u001b[0m optimizer\u001b[39m.\u001b[39mzero_grad()\n", "File \u001b[0;32m~/miniforge3/envs/pytorch-nightly/lib/python3.10/site-packages/torch/nn/modules/module.py:1186\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1183\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m 
\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1185\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1186\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1187\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", "File \u001b[0;32m~/miniforge3/envs/pytorch-nightly/lib/python3.10/site-packages/torchvision/models/vgg.py:66\u001b[0m, in \u001b[0;36mVGG.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, x: torch\u001b[39m.\u001b[39mTensor) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m torch\u001b[39m.\u001b[39mTensor:\n\u001b[0;32m---> 66\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfeatures(x)\n\u001b[1;32m 67\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mavgpool(x)\n\u001b[1;32m 68\u001b[0m x \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mflatten(x, \u001b[39m1\u001b[39m)\n", "File \u001b[0;32m~/miniforge3/envs/pytorch-nightly/lib/python3.10/site-packages/torch/nn/modules/module.py:1186\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1183\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1185\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1186\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1187\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", "File \u001b[0;32m~/miniforge3/envs/pytorch-nightly/lib/python3.10/site-packages/torch/nn/modules/container.py:141\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m):\n\u001b[1;32m 140\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m:\n\u001b[0;32m--> 141\u001b[0m \u001b[39minput\u001b[39m \u001b[39m=\u001b[39m module(\u001b[39minput\u001b[39;49m)\n\u001b[1;32m 142\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39minput\u001b[39m\n", "File \u001b[0;32m~/miniforge3/envs/pytorch-nightly/lib/python3.10/site-packages/torch/nn/modules/module.py:1186\u001b[0m, in 
\u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1183\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1185\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1186\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1187\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", "File \u001b[0;32m~/miniforge3/envs/pytorch-nightly/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:150\u001b[0m, in \u001b[0;36m_BatchNorm.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtraining \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtrack_running_stats:\n\u001b[1;32m 148\u001b[0m \u001b[39m# TODO: if statement only here to tell the jit to skip emitting this when it is None\u001b[39;00m\n\u001b[1;32m 149\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_batches_tracked \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m: \u001b[39m# type: ignore[has-type]\u001b[39;00m\n\u001b[0;32m--> 150\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mnum_batches_tracked\u001b[39m.\u001b[39;49madd_(\u001b[39m1\u001b[39;49m) \u001b[39m# type: ignore[has-type]\u001b[39;00m\n\u001b[1;32m 151\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmomentum \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m: \u001b[39m# use cumulative moving average\u001b[39;00m\n\u001b[1;32m 152\u001b[0m exponential_average_factor \u001b[39m=\u001b[39m \u001b[39m1.0\u001b[39m \u001b[39m/\u001b[39m \u001b[39mfloat\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_batches_tracked)\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "#!/usr/bin/env python\n", "# coding: utf-8\n", "\n", "import argparse\n", "import os\n", "import random\n", "import time\n", "\n", "import numpy as np\n", "import torch\n", "import torchvision\n", "from torch.utils.data import DataLoader\n", "from torch.utils.data import SubsetRandomSampler\n", "from torchvision import datasets, transforms\n", "\n", "\n", "def set_all_seeds(seed):\n", " os.environ[\"PL_GLOBAL_SEED\"] = str(seed)\n", " random.seed(seed)\n", " np.random.seed(seed)\n", " torch.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)\n", "\n", "\n", "def compute_accuracy(model, data_loader, device):\n", " model.eval()\n", " with torch.no_grad():\n", " correct_pred, num_examples = 0, 0\n", " for i, (features, targets) in enumerate(data_loader):\n", "\n", " features = features.to(device)\n", " targets = targets.to(device)\n", "\n", " logits = model(features)\n", " _, predicted_labels = 
torch.max(logits, 1)\n", " num_examples += targets.size(0)\n", " correct_pred += (predicted_labels.cpu() == targets.cpu()).sum()\n", " return correct_pred.float() / num_examples * 100\n", "\n", "\n", "def train_classifier_simple_v2(\n", " model,\n", " num_epochs,\n", " train_loader,\n", " valid_loader,\n", " test_loader,\n", " optimizer,\n", " device,\n", " logging_interval=50,\n", " best_model_save_path=None,\n", " scheduler=None,\n", " skip_train_acc=False,\n", " scheduler_on=\"valid_acc\",\n", "):\n", "\n", " start_time = time.time()\n", " minibatch_loss_list, train_acc_list, valid_acc_list = [], [], []\n", " best_valid_acc, best_epoch = -float(\"inf\"), 0\n", "\n", " for epoch in range(num_epochs):\n", "\n", " epoch_start_time = time.time()\n", " model.train()\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", "\n", " features = features.to(device)\n", " targets = targets.to(device)\n", "\n", " # ## FORWARD AND BACK PROP\n", " logits = model(features)\n", " loss = torch.nn.functional.cross_entropy(logits, targets)\n", " optimizer.zero_grad()\n", "\n", " loss.backward()\n", "\n", " # ## UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", "\n", " # ## LOGGING\n", " minibatch_loss_list.append(loss.item())\n", " if not batch_idx % logging_interval:\n", " print(\n", " f\"Epoch: {epoch+1:03d}/{num_epochs:03d} \"\n", " f\"| Batch {batch_idx:04d}/{len(train_loader):04d} \"\n", " f\"| Loss: {loss:.4f}\"\n", " )\n", "\n", " model.eval()\n", "\n", " elapsed = (time.time() - epoch_start_time) / 60\n", " print(f\"Time / epoch without evaluation: {elapsed:.2f} min\")\n", " with torch.no_grad(): # save memory during inference\n", " if not skip_train_acc:\n", " train_acc = compute_accuracy(model, train_loader, device=device).item()\n", " else:\n", " train_acc = float(\"nan\")\n", " valid_acc = compute_accuracy(model, valid_loader, device=device).item()\n", " train_acc_list.append(train_acc)\n", " valid_acc_list.append(valid_acc)\n", "\n", " if valid_acc > best_valid_acc:\n", " best_valid_acc, best_epoch = valid_acc, epoch + 1\n", " if best_model_save_path:\n", " torch.save(model.state_dict(), best_model_save_path)\n", "\n", " print(\n", " f\"Epoch: {epoch+1:03d}/{num_epochs:03d} \"\n", " f\"| Train: {train_acc :.2f}% \"\n", " f\"| Validation: {valid_acc :.2f}% \"\n", " f\"| Best Validation \"\n", " f\"(Ep. 
{best_epoch:03d}): {best_valid_acc :.2f}%\"\n", "        )\n", "\n", "        elapsed = (time.time() - start_time) / 60\n", "        print(f\"Time elapsed: {elapsed:.2f} min\")\n", "\n", "        if scheduler is not None:\n", "\n", "            if scheduler_on == \"valid_acc\":\n", "                scheduler.step(valid_acc_list[-1])\n", "            elif scheduler_on == \"minibatch_loss\":\n", "                scheduler.step(minibatch_loss_list[-1])\n", "            else:\n", "                raise ValueError(\"Invalid `scheduler_on` choice.\")\n", "\n", "    elapsed = (time.time() - start_time) / 60\n", "    print(f\"Total Training Time: {elapsed:.2f} min\")\n", "\n", "    test_acc = compute_accuracy(model, test_loader, device=device)\n", "    print(f\"Test accuracy {test_acc :.2f}%\")\n", "\n", "    elapsed = (time.time() - start_time) / 60\n", "    print(f\"Total Time: {elapsed:.2f} min\")\n", "\n", "    return minibatch_loss_list, train_acc_list, valid_acc_list\n", "\n", "\n", "def get_dataloaders_cifar10(\n", "    batch_size,\n", "    num_workers=0,\n", "    validation_fraction=None,\n", "    train_transforms=None,\n", "    test_transforms=None,\n", "):\n", "\n", "    if train_transforms is None:\n", "        train_transforms = transforms.ToTensor()\n", "\n", "    if test_transforms is None:\n", "        test_transforms = transforms.ToTensor()\n", "\n", "    train_dataset = datasets.CIFAR10(\n", "        root=\"data\", train=True, transform=train_transforms, download=True\n", "    )\n", "\n", "    valid_dataset = datasets.CIFAR10(root=\"data\", train=True, transform=test_transforms)\n", "\n", "    test_dataset = datasets.CIFAR10(root=\"data\", train=False, transform=test_transforms)\n", "\n", "    if validation_fraction is not None:\n", "        num = int(validation_fraction * 50000)\n", "        train_indices = torch.arange(0, 50000 - num)\n", "        valid_indices = torch.arange(50000 - num, 50000)\n", "\n", "        train_sampler = SubsetRandomSampler(train_indices)\n", "        valid_sampler = SubsetRandomSampler(valid_indices)\n", "\n", "        valid_loader = DataLoader(\n", "            dataset=valid_dataset,\n", "            batch_size=batch_size,\n", "            num_workers=num_workers,\n", "            sampler=valid_sampler,\n", "        )\n", "\n", "        train_loader = DataLoader(\n", "            dataset=train_dataset,\n", "            batch_size=batch_size,\n", "            num_workers=num_workers,\n", "            drop_last=True,\n", "            sampler=train_sampler,\n", "        )\n", "\n", "    else:\n", "        train_loader = DataLoader(\n", "            dataset=train_dataset,\n", "            batch_size=batch_size,\n", "            num_workers=num_workers,\n", "            drop_last=True,\n", "            shuffle=True,\n", "        )\n", "\n", "    test_loader = DataLoader(\n", "        dataset=test_dataset,\n", "        batch_size=batch_size,\n", "        num_workers=num_workers,\n", "        shuffle=False,\n", "    )\n", "\n", "    if validation_fraction is None:\n", "        return train_loader, test_loader\n", "    else:\n", "        return train_loader, valid_loader, test_loader\n", "\n", "\n", "if __name__ == \"__main__\":\n", "\n", "    # parser = argparse.ArgumentParser()\n", "    # parser.add_argument(\n", "    #     \"--device\", type=str, required=True, help=\"Which GPU device to use.\"\n", "    # )\n", "\n", "    # args = parser.parse_args()\n", "\n", "    RANDOM_SEED = 123\n", "    BATCH_SIZE = 32\n", "    NUM_EPOCHS = 1\n", "    set_all_seeds(RANDOM_SEED)  # the seed was previously defined but never applied\n", "\n", "    # DEVICE = torch.device(args.device)\n", "    # Apple’s Metal Performance Shaders (MPS) backend; fall back to the CPU\n", "    # if this build was not compiled with MPS support.\n", "    DEVICE = \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", "\n", "    print(\"torch\", torch.__version__)\n", "    print(\"device\", DEVICE)\n", "\n", "    train_transforms = torchvision.transforms.Compose(\n", "        [\n", "            torchvision.transforms.Resize((256, 256)),\n", "            torchvision.transforms.RandomCrop((224, 224)),\n", "            torchvision.transforms.ToTensor(),\n", "            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n", "        ]\n", "    )\n", "\n", "    test_transforms = torchvision.transforms.Compose(\n", "        [\n", "            torchvision.transforms.Resize((256, 256)),\n", "            torchvision.transforms.CenterCrop((224, 224)),\n", "            torchvision.transforms.ToTensor(),\n", "            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n", "        ]\n", "    )\n", "\n", "    train_loader, valid_loader, test_loader = get_dataloaders_cifar10(\n", "        batch_size=BATCH_SIZE,\n", "        validation_fraction=0.1,\n", "        train_transforms=train_transforms,\n", "        test_transforms=test_transforms,\n", "        num_workers=2,\n", "    )\n", "\n", "    # `pretrained` is deprecated since torchvision 0.13 (see the warnings in the\n", "    # output above); `weights=None` builds the same untrained VGG16-BN directly\n", "    # from the installed torchvision, which is what the hub call resolved to anyway.\n", "    model = torchvision.models.vgg16_bn(weights=None)\n", "\n", "    # Replace the final classifier layer: in_features=4096 as in the original\n", "    # VGG16 classifier, out_features=10 for the ten CIFAR-10 class labels.\n", "    model.classifier[-1] = torch.nn.Linear(in_features=4096, out_features=10)\n", "\n", "    model = model.to(DEVICE)\n", "\n", "    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)\n", "\n", "    minibatch_loss_list, train_acc_list, valid_acc_list = train_classifier_simple_v2(\n", "        model=model,\n", "        num_epochs=NUM_EPOCHS,\n", "        train_loader=train_loader,\n", "        valid_loader=valid_loader,\n", "        test_loader=test_loader,\n", "        optimizer=optimizer,\n", "        best_model_save_path=None,\n", "        device=DEVICE,\n", "        scheduler_on=\"valid_acc\",\n", "        logging_interval=100,\n", "    )" ] }, 
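{ "cell_type": "markdown", "metadata": {}, "source": [ "The training cell above exercises the MPS backend end to end, but it takes minutes per epoch. As a quicker, self-contained check that work really is dispatched to the M1 GPU, the sketch below times a large matrix multiply on the CPU versus the `mps` device. The `time_matmul` helper is illustrative and not part of the original script; timings will vary by machine, and calling `.item()` on the result forces a device synchronization so the asynchronous GPU work is included in the measurement." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "import torch\n", "\n", "\n", "def time_matmul(device, n=2048, reps=10):\n", "    # Allocate once on the target device, then time `reps` matmuls.\n", "    x = torch.randn(n, n, device=device)\n", "    (x @ x).sum().item()  # warm-up; excludes one-time setup cost\n", "    start = time.time()\n", "    for _ in range(reps):\n", "        y = x @ x\n", "    y.sum().item()  # forces a sync so pending GPU work is counted\n", "    return time.time() - start\n", "\n", "\n", "device = \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", "print(f\"cpu: {time_matmul('cpu'):.3f} s | {device}: {time_matmul(device):.3f} s\")" ] }, 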
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "\n", "# `torch.has_mps` is a legacy attribute; the documented check is:\n", "torch.backends.mps.is_available()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.10.5 ('pytorch-nightly')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.5" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "8a8bcccfb183d1298694efece6cf41240378bc61621e95c864629a40c5876542" } } }, "nbformat": 4, "nbformat_minor": 2 }