{"cells":[{"cell_type":"markdown","metadata":{"id":"pmi4V4VaYR6i"},"source":["# **MIT 6.5940 EfficientML.ai Fall 2023 Lab 2: Quantization**"]},{"cell_type":"markdown","metadata":{"id":"_PC9ifUhHwRT"},"source":["This Colab notebook provides code and a framework for Lab 2 quantization. You can work out your solutions here."]},{"cell_type":"markdown","metadata":{"id":"A-antxp8SSyb"},"source":["Please fill out this [feedback form](https://forms.gle/ZeCH5anNPrkd5wpp7) when you have finished this lab. We would love to hear your thoughts or feedback on how we can improve this lab!"]},{"cell_type":"markdown","metadata":{"id":"vhVOMmbAaHRd"},"source":["## Goals\n","\n","In this assignment, you will practice quantizing a classical neural network model to reduce both model size and latency. The goals of this assignment are as follows:\n","\n","- Understand the basic concept of **quantization**\n","- Implement and apply **k-means quantization**\n","- Implement and apply **quantization-aware training** for k-means quantization\n","- Implement and apply **linear quantization**\n","- Implement and apply **integer-only inference** for linear quantization\n","- Get a basic understanding of performance improvement (such as speedup) from quantization\n","- Understand the differences and tradeoffs between these quantization approaches"]},{"cell_type":"markdown","metadata":{"id":"W6HPdGZ7aHZo"},"source":["## Contents\n","\n","\n","There are 2 main sections: ***K-Means Quantization*** and ***Linear Quantization***.\n","\n","There are ***10*** questions in total:\n","- For *K-Means Quantization*, there are ***3*** questions (Questions 1-3).\n","- For *Linear Quantization*, there are ***6*** questions (Questions 4-9).\n","- Question 10 compares k-means quantization and linear quantization."]},{"cell_type":"markdown","metadata":{"id":"2tFjnZZVlIFL"},"source":["# Setup"]},{"cell_type":"markdown","metadata":{"id":"Bz16rxaSH_7f"},"source":["First, install the required packages and 
download the datasets and pretrained model. Here we use the CIFAR-10 dataset and the same VGG network as in the Lab 0 tutorial.\n","\n"]},{"cell_type":"code","execution_count":1,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":14111,"status":"ok","timestamp":1696471915810,"user":{"displayName":"Ji Lin","userId":"00458298044061695840"},"user_tz":420},"id":"nyngBRTXQG2n","outputId":"75c65dd0-4d9e-40b9-cd5c-b691229c8c5c"},"outputs":[{"name":"stdout","output_type":"stream","text":["Installing torchprofile...\n","Installing fast-pytorch-kmeans...\n","All required packages have been successfully installed!\n"]}],"source":["print('Installing torchprofile...')\n","!pip install torchprofile 1>/dev/null\n","print('Installing fast-pytorch-kmeans...')\n","!pip install fast-pytorch-kmeans 1>/dev/null\n","print('All required packages have been successfully installed!')"]},{"cell_type":"code","execution_count":2,"metadata":{"executionInfo":{"elapsed":4631,"status":"ok","timestamp":1696471920437,"user":{"displayName":"Ji Lin","userId":"00458298044061695840"},"user_tz":420},"id":"zDWoVhv_wGmA"},"outputs":[],"source":["import copy\n","import math\n","import random\n","from collections import OrderedDict, defaultdict\n","\n","from matplotlib import pyplot as plt\n","from matplotlib.colors import ListedColormap\n","import numpy as np\n","from tqdm.auto import tqdm\n","\n","import torch\n","from torch import nn\n","from torch.optim import *\n","from torch.optim.lr_scheduler import *\n","from torch.utils.data import DataLoader\n","from torchprofile import profile_macs\n","from torchvision.datasets import *\n","from torchvision.transforms import *\n","\n","assert torch.cuda.is_available(), \\\n","\"The current runtime does not have CUDA support. \" \\\n","\"Please go to menu bar (Runtime - Change runtime type) and select 
GPU\""]},{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":7,"status":"ok","timestamp":1696471920437,"user":{"displayName":"Ji Lin","userId":"00458298044061695840"},"user_tz":420},"id":"nLcJUofTDKud","outputId":"0ba542e8-7599-492f-a61c-923f66328f17"},"outputs":[{"data":{"text/plain":["<torch._C.Generator at 0x7ef9740f31b0>"]},"execution_count":3,"metadata":{},"output_type":"execute_result"}],"source":["random.seed(0)\n","np.random.seed(0)\n","torch.manual_seed(0)"]},{"cell_type":"code","execution_count":4,"metadata":{"executionInfo":{"elapsed":6,"status":"ok","timestamp":1696471920437,"user":{"displayName":"Ji Lin","userId":"00458298044061695840"},"user_tz":420},"id":"6kpu78GyHGA-"},"outputs":[],"source":["def download_url(url, model_dir='.', overwrite=False):\n"," import os, sys\n"," from urllib.request import urlretrieve\n"," target_dir = url.split('/')[-1]\n"," model_dir = os.path.expanduser(model_dir)\n"," try:\n"," if not os.path.exists(model_dir):\n"," os.makedirs(model_dir)\n"," model_dir = os.path.join(model_dir, target_dir)\n"," cached_file = model_dir\n"," if not os.path.exists(cached_file) or overwrite:\n"," sys.stderr.write('Downloading: \"{}\" to {}\\n'.format(url, cached_file))\n"," urlretrieve(url, cached_file)\n"," return cached_file\n"," except Exception as e:\n"," # remove lock file so download can be executed next time.\n"," os.remove(os.path.join(model_dir, 'download.lock'))\n"," sys.stderr.write('Failed to download from url %s' % url + '\\n' + str(e) + '\\n')\n"," return None"]},{"cell_type":"code","execution_count":5,"metadata":{"executionInfo":{"elapsed":5,"status":"ok","timestamp":1696471920437,"user":{"displayName":"Ji Lin","userId":"00458298044061695840"},"user_tz":420},"id":"qqInscyoifYN"},"outputs":[],"source":["class VGG(nn.Module):\n"," ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']\n","\n"," def __init__(self) -> None:\n"," 
super().__init__()\n","\n"," layers = []\n"," counts = defaultdict(int)\n","\n"," def add(name: str, layer: nn.Module) -> None:\n"," layers.append((f\"{name}{counts[name]}\", layer))\n"," counts[name] += 1\n","\n"," in_channels = 3\n"," for x in self.ARCH:\n"," if x != 'M':\n"," # conv-bn-relu\n"," add(\"conv\", nn.Conv2d(in_channels, x, 3, padding=1, bias=False))\n"," add(\"bn\", nn.BatchNorm2d(x))\n"," add(\"relu\", nn.ReLU(True))\n"," in_channels = x\n"," else:\n"," # maxpool\n"," add(\"pool\", nn.MaxPool2d(2))\n"," add(\"avgpool\", nn.AvgPool2d(2))\n"," self.backbone = nn.Sequential(OrderedDict(layers))\n"," self.classifier = nn.Linear(512, 10)\n","\n"," def forward(self, x: torch.Tensor) -> torch.Tensor:\n"," # backbone (including the final avgpool): [N, 3, 32, 32] => [N, 512, 1, 1]\n"," x = self.backbone(x)\n","\n"," # flatten: [N, 512, 1, 1] => [N, 512]\n"," # x = x.mean([2, 3])\n"," x = x.view(x.shape[0], -1)\n","\n"," # classifier: [N, 512] => [N, 10]\n"," x = self.classifier(x)\n"," return x"]},{"cell_type":"code","execution_count":6,"metadata":{"executionInfo":{"elapsed":5,"status":"ok","timestamp":1696471920437,"user":{"displayName":"Ji Lin","userId":"00458298044061695840"},"user_tz":420},"id":"WqnPt0LUEaWi"},"outputs":[],"source":["def train(\n"," model: nn.Module,\n"," dataloader: DataLoader,\n"," criterion: nn.Module,\n"," optimizer: Optimizer,\n"," scheduler: LambdaLR,\n"," callbacks = None\n",") -> None:\n"," model.train()\n","\n"," for inputs, targets in tqdm(dataloader, desc='train', leave=False):\n"," # Move the data from CPU to GPU\n"," inputs = inputs.cuda()\n"," targets = targets.cuda()\n","\n"," # Reset the gradients (from the last iteration)\n"," optimizer.zero_grad()\n","\n"," # Forward inference\n"," outputs = model(inputs)\n"," loss = criterion(outputs, targets)\n","\n"," # Backward propagation\n"," loss.backward()\n","\n"," # Update optimizer and LR scheduler\n"," optimizer.step()\n"," scheduler.step()\n","\n"," if callbacks is not None:\n"," for callback in 
callbacks:\n"," callback()"]},{"cell_type":"code","execution_count":7,"metadata":{"executionInfo":{"elapsed":5,"status":"ok","timestamp":1696471920437,"user":{"displayName":"Ji Lin","userId":"00458298044061695840"},"user_tz":420},"id":"wVA1_oeUEUf6"},"outputs":[],"source":["@torch.inference_mode()\n","def evaluate(\n"," model: nn.Module,\n"," dataloader: DataLoader,\n"," extra_preprocess = None\n",") -> float:\n"," model.eval()\n","\n"," num_samples = 0\n"," num_correct = 0\n","\n"," for inputs, targets in tqdm(dataloader, desc=\"eval\", leave=False):\n"," # Move the data from CPU to GPU\n"," inputs = inputs.cuda()\n"," if extra_preprocess is not None:\n"," for preprocess in extra_preprocess:\n"," inputs = preprocess(inputs)\n","\n"," targets = targets.cuda()\n","\n"," # Inference\n"," outputs = model(inputs)\n","\n"," # Convert logits to class indices\n"," outputs = outputs.argmax(dim=1)\n","\n"," # Update metrics\n"," num_samples += targets.size(0)\n"," num_correct += (outputs == targets).sum()\n","\n"," return (num_correct / num_samples * 100).item()"]},{"cell_type":"markdown","metadata":{"id":"QBBKNhKNlAwE"},"source":["Helper Functions (FLOPs, model size calculation, etc.)"]},{"cell_type":"code","execution_count":8,"metadata":{"executionInfo":{"elapsed":4,"status":"ok","timestamp":1696471920437,"user":{"displayName":"Ji Lin","userId":"00458298044061695840"},"user_tz":420},"id":"mRdK_ThzlMxL"},"outputs":[],"source":["def get_model_flops(model, inputs):\n"," # note: torchprofile's profile_macs counts multiply-accumulate operations (MACs)\n"," num_macs = profile_macs(model, inputs)\n"," return num_macs"]},{"cell_type":"code","execution_count":9,"metadata":{"executionInfo":{"elapsed":4,"status":"ok","timestamp":1696471920437,"user":{"displayName":"Ji Lin","userId":"00458298044061695840"},"user_tz":420},"id":"cepv4SUalU79"},"outputs":[],"source":["def get_model_size(model: nn.Module, data_width=32):\n"," \"\"\"\n"," calculate the model size in bits\n"," :param data_width: #bits per element\n"," \"\"\"\n"," num_elements = 0\n"," for param in 
model.parameters():\n"," num_elements += param.numel()\n"," return num_elements * data_width\n","\n","Byte = 8\n","KiB = 1024 * Byte\n","MiB = 1024 * KiB\n","GiB = 1024 * MiB"]},{"cell_type":"markdown","metadata":{"id":"CGhomDjsaDB5"},"source":["Define miscellaneous helper functions for verification."]},{"cell_type":"code","execution_count":15,"metadata":{"executionInfo":{"elapsed":1,"status":"ok","timestamp":1696471978643,"user":{"displayName":"Ji Lin","userId":"00458298044061695840"},"user_tz":420},"id":"WOJBXVXQeCXF"},"outputs":[],"source":["def test_k_means_quantize(\n"," test_tensor=torch.tensor([\n"," [-0.3747, 0.0874, 0.3200, -0.4868, 0.4404],\n"," [-0.0402, 0.2322, -0.2024, -0.4986, 0.1814],\n"," [ 0.3102, -0.3942, -0.2030, 0.0883, -0.4741],\n"," [-0.1592, -0.0777, -0.3946, -0.2128, 0.2675],\n"," [ 0.0611, -0.1933, -0.4350, 0.2928, -0.1087]]),\n"," bitwidth=2):\n"," def plot_matrix(tensor, ax, title, cmap=ListedColormap(['white'])):\n"," ax.imshow(tensor.cpu().numpy(), vmin=-0.5, vmax=0.5, cmap=cmap)\n"," ax.set_title(title)\n"," ax.set_yticklabels([])\n"," ax.set_xticklabels([])\n"," for i in range(tensor.shape[0]):\n"," for j in range(tensor.shape[1]):\n"," text = ax.text(j, i, f'{tensor[i, j].item():.2f}',\n"," ha=\"center\", va=\"center\", color=\"k\")\n","\n"," fig, axes = plt.subplots(1,2, figsize=(8, 12))\n"," ax_left, ax_right = axes.ravel()\n","\n"," print(test_tensor)\n"," plot_matrix(test_tensor, ax_left, 'original tensor')\n","\n"," num_unique_values_before_quantization = test_tensor.unique().numel()\n"," k_means_quantize(test_tensor, bitwidth=bitwidth)\n"," num_unique_values_after_quantization = test_tensor.unique().numel()\n"," print('* Test k_means_quantize()')\n"," print(f' target bitwidth: {bitwidth} bits')\n"," print(f' num unique values before k-means quantization: {num_unique_values_before_quantization}')\n"," print(f' num unique values after k-means quantization: {num_unique_values_after_quantization}')\n"," assert num_unique_values_after_quantization 
== min((1 << bitwidth), num_unique_values_before_quantization)\n"," print('* Test passed.')\n","\n"," plot_matrix(test_tensor, ax_right, f'{bitwidth}-bit k-means quantized tensor', cmap='tab20c')\n"," fig.tight_layout()\n"," plt.show()"]},{"cell_type":"code","execution_count":11,"metadata":{"executionInfo":{"elapsed":4,"status":"ok","timestamp":1696471920850,"user":{"displayName":"Ji Lin","userId":"00458298044061695840"},"user_tz":420},"id":"jI92NMafaCk7"},"outputs":[],"source":["def test_linear_quantize(\n"," test_tensor=torch.tensor([\n"," [ 0.0523, 0.6364, -0.0968, -0.0020, 0.1940],\n"," [ 0.7500, 0.5507, 0.6188, -0.1734, 0.4677],\n"," [-0.0669, 0.3836, 0.4297, 0.6267, -0.0695],\n"," [ 0.1536, -0.0038, 0.6075, 0.6817, 0.0601],\n"," [ 0.6446, -0.2500, 0.5376, -0.2226, 0.2333]]),\n"," quantized_test_tensor=torch.tensor([\n"," [-1, 1, -1, -1, 0],\n"," [ 1, 1, 1, -2, 0],\n"," [-1, 0, 0, 1, -1],\n"," [-1, -1, 1, 1, -1],\n"," [ 1, -2, 1, -2, 0]], dtype=torch.int8),\n"," real_min=-0.25, real_max=0.75, bitwidth=2, scale=1/3, zero_point=-1):\n"," def plot_matrix(tensor, ax, title, vmin=0, vmax=1, cmap=ListedColormap(['white'])):\n"," ax.imshow(tensor.cpu().numpy(), vmin=vmin, vmax=vmax, cmap=cmap)\n"," ax.set_title(title)\n"," ax.set_yticklabels([])\n"," ax.set_xticklabels([])\n"," for i in range(tensor.shape[0]):\n"," for j in range(tensor.shape[1]):\n"," datum = tensor[i, j].item()\n"," if isinstance(datum, float):\n"," text = ax.text(j, i, f'{datum:.2f}',\n"," ha=\"center\", va=\"center\", color=\"k\")\n"," else:\n"," text = ax.text(j, i, f'{datum}',\n"," ha=\"center\", va=\"center\", color=\"k\")\n"," quantized_min, quantized_max = get_quantized_range(bitwidth)\n"," fig, axes = plt.subplots(1,3, figsize=(10, 32))\n"," plot_matrix(test_tensor, axes[0], 'original tensor', vmin=real_min, vmax=real_max)\n"," _quantized_test_tensor = linear_quantize(\n"," test_tensor, bitwidth=bitwidth, scale=scale, zero_point=zero_point)\n"," _reconstructed_test_tensor = scale * 
(_quantized_test_tensor.float() - zero_point)\n"," print('* Test linear_quantize()')\n"," print(f' target bitwidth: {bitwidth} bits')\n"," print(f' scale: {scale}')\n"," print(f' zero point: {zero_point}')\n"," assert _quantized_test_tensor.equal(quantized_test_tensor)\n"," print('* Test passed.')\n"," plot_matrix(_quantized_test_tensor, axes[1], f'2-bit linear quantized tensor',\n"," vmin=quantized_min, vmax=quantized_max, cmap='tab20c')\n"," plot_matrix(_reconstructed_test_tensor, axes[2], f'reconstructed tensor',\n"," vmin=real_min, vmax=real_max, cmap='tab20c')\n"," fig.tight_layout()\n"," plt.show()\n"]},{"cell_type":"code","execution_count":12,"metadata":{"executionInfo":{"elapsed":4,"status":"ok","timestamp":1696471920850,"user":{"displayName":"Ji Lin","userId":"00458298044061695840"},"user_tz":420},"id":"vK-56lHteduz"},"outputs":[],"source":["def test_quantized_fc(\n"," input=torch.tensor([\n"," [0.6118, 0.7288, 0.8511, 0.2849, 0.8427, 0.7435, 0.4014, 0.2794],\n"," [0.3676, 0.2426, 0.1612, 0.7684, 0.6038, 0.0400, 0.2240, 0.4237],\n"," [0.6565, 0.6878, 0.4670, 0.3470, 0.2281, 0.8074, 0.0178, 0.3999],\n"," [0.1863, 0.3567, 0.6104, 0.0497, 0.0577, 0.2990, 0.6687, 0.8626]]),\n"," weight=torch.tensor([\n"," [ 1.2626e-01, -1.4752e-01, 8.1910e-02, 2.4982e-01, -1.0495e-01,\n"," -1.9227e-01, -1.8550e-01, -1.5700e-01],\n"," [ 2.7624e-01, -4.3835e-01, 5.1010e-02, -1.2020e-01, -2.0344e-01,\n"," 1.0202e-01, -2.0799e-01, 2.4112e-01],\n"," [-3.8216e-01, -2.8047e-01, 8.5238e-02, -4.2504e-01, -2.0952e-01,\n"," 3.2018e-01, -3.3619e-01, 2.0219e-01],\n"," [ 8.9233e-02, -1.0124e-01, 1.1467e-01, 2.0091e-01, 1.1438e-01,\n"," -4.2427e-01, 1.0178e-01, -3.0941e-04],\n"," [-1.8837e-02, -2.1256e-01, -4.5285e-01, 2.0949e-01, -3.8684e-01,\n"," -1.7100e-01, -4.5331e-01, -2.0433e-01],\n"," [-2.0038e-01, -5.3757e-02, 1.8997e-01, -3.6866e-01, 5.5484e-02,\n"," 1.5643e-01, -2.3538e-01, 2.1103e-01],\n"," [-2.6875e-01, 2.4984e-01, -2.3514e-01, 2.5527e-01, 2.0322e-01,\n"," 3.7675e-01, 
6.1563e-02, 1.7201e-01],\n"," [ 3.3541e-01, -3.3555e-01, -4.3349e-01, 4.3043e-01, -2.0498e-01,\n"," -1.8366e-01, -9.1553e-02, -4.1168e-01]]),\n"," bias=torch.tensor([ 0.1954, -0.2756, 0.3113, 0.1149, 0.4274, 0.2429, -0.1721, -0.2502]),\n"," quantized_bias=torch.tensor([ 3, -2, 3, 1, 3, 2, -2, -2], dtype=torch.int32),\n"," shifted_quantized_bias=torch.tensor([-1, 0, -3, -1, -3, 0, 2, -4], dtype=torch.int32),\n"," calc_quantized_output=torch.tensor([\n"," [ 0, -1, 0, -1, -1, 0, 1, -2],\n"," [ 0, 0, -1, 0, 0, 0, 0, -1],\n"," [ 0, 0, 0, -1, 0, 0, 0, -1],\n"," [ 0, 0, 0, 0, 0, 1, -1, -2]], dtype=torch.int8),\n"," bitwidth=2, batch_size=4, in_channels=8, out_channels=8):\n"," def plot_matrix(tensor, ax, title, vmin=0, vmax=1, cmap=ListedColormap(['white'])):\n"," ax.imshow(tensor.cpu().numpy(), vmin=vmin, vmax=vmax, cmap=cmap)\n"," ax.set_title(title)\n"," ax.set_yticklabels([])\n"," ax.set_xticklabels([])\n"," for i in range(tensor.shape[0]):\n"," for j in range(tensor.shape[1]):\n"," datum = tensor[i, j].item()\n"," if isinstance(datum, float):\n"," text = ax.text(j, i, f'{datum:.2f}',\n"," ha=\"center\", va=\"center\", color=\"k\")\n"," else:\n"," text = ax.text(j, i, f'{datum}',\n"," ha=\"center\", va=\"center\", color=\"k\")\n","\n"," output = torch.nn.functional.linear(input, weight, bias)\n","\n"," quantized_weight, weight_scale, weight_zero_point = \\\n"," linear_quantize_weight_per_channel(weight, bitwidth)\n"," quantized_input, input_scale, input_zero_point = \\\n"," linear_quantize_feature(input, bitwidth)\n"," _quantized_bias, bias_scale, bias_zero_point = \\\n"," linear_quantize_bias_per_output_channel(bias, weight_scale, input_scale)\n"," assert _quantized_bias.equal(quantized_bias)\n"," _shifted_quantized_bias = \\\n"," shift_quantized_linear_bias(quantized_bias, quantized_weight, input_zero_point)\n"," assert _shifted_quantized_bias.equal(shifted_quantized_bias)\n"," quantized_output, output_scale, output_zero_point = \\\n"," 
linear_quantize_feature(output, bitwidth)\n","\n"," _calc_quantized_output = quantized_linear(\n"," quantized_input, quantized_weight, shifted_quantized_bias,\n"," bitwidth, bitwidth,\n"," input_zero_point, output_zero_point,\n"," input_scale, weight_scale, output_scale)\n"," assert _calc_quantized_output.equal(calc_quantized_output)\n","\n"," reconstructed_weight = weight_scale * (quantized_weight.float() - weight_zero_point)\n"," reconstructed_input = input_scale * (quantized_input.float() - input_zero_point)\n"," reconstructed_bias = bias_scale * (quantized_bias.float() - bias_zero_point)\n"," reconstructed_calc_output = output_scale * (calc_quantized_output.float() - output_zero_point)\n","\n"," fig, axes = plt.subplots(3,3, figsize=(15, 12))\n"," quantized_min, quantized_max = get_quantized_range(bitwidth)\n"," plot_matrix(weight, axes[0, 0], 'original weight', vmin=-0.5, vmax=0.5)\n"," plot_matrix(input.t(), axes[1, 0], 'original input', vmin=0, vmax=1)\n"," plot_matrix(output.t(), axes[2, 0], 'original output', vmin=-1.5, vmax=1.5)\n"," plot_matrix(quantized_weight, axes[0, 1], f'{bitwidth}-bit linear quantized weight',\n"," vmin=quantized_min, vmax=quantized_max, cmap='tab20c')\n"," plot_matrix(quantized_input.t(), axes[1, 1], f'{bitwidth}-bit linear quantized input',\n"," vmin=quantized_min, vmax=quantized_max, cmap='tab20c')\n"," plot_matrix(calc_quantized_output.t(), axes[2, 1], f'quantized output from quantized_linear()',\n"," vmin=quantized_min, vmax=quantized_max, cmap='tab20c')\n"," plot_matrix(reconstructed_weight, axes[0, 2], f'reconstructed weight',\n"," vmin=-0.5, vmax=0.5, cmap='tab20c')\n"," plot_matrix(reconstructed_input.t(), axes[1, 2], f'reconstructed input',\n"," vmin=0, vmax=1, cmap='tab20c')\n"," plot_matrix(reconstructed_calc_output.t(), axes[2, 2], f'reconstructed output',\n"," vmin=-1.5, vmax=1.5, cmap='tab20c')\n","\n"," print('* Test quantized_fc()')\n"," print(f' target bitwidth: {bitwidth} bits')\n"," print(f' batch size: 
{batch_size}')\n"," print(f' input channels: {in_channels}')\n"," print(f' output channels: {out_channels}')\n"," print('* Test passed.')\n"," fig.tight_layout()\n"," plt.show()"]},{"cell_type":"markdown","metadata":{"id":"mDDp0OWLh6vK"},"source":["Load Pretrained Model"]},{"cell_type":"code","execution_count":13,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":12631,"status":"ok","timestamp":1696471933477,"user":{"displayName":"Ji Lin","userId":"00458298044061695840"},"user_tz":420},"id":"oNEYAZ_TQf7d","outputId":"426d48fc-7d79-49ef-a564-2f8701603ce1"},"outputs":[{"name":"stderr","output_type":"stream","text":["Downloading: \"https://hanlab18.mit.edu/files/course/labs/vgg.cifar.pretrained.pth\" to ./vgg.cifar.pretrained.pth\n"]},{"name":"stdout","output_type":"stream","text":["=> loading checkpoint 'https://hanlab18.mit.edu/files/course/labs/vgg.cifar.pretrained.pth'\n"]}],"source":["checkpoint_url = \"https://hanlab18.mit.edu/files/course/labs/vgg.cifar.pretrained.pth\"\n","checkpoint = torch.load(download_url(checkpoint_url), map_location=\"cpu\")\n","model = VGG().cuda()\n","print(f\"=> loading checkpoint '{checkpoint_url}'\")\n","model.load_state_dict(checkpoint['state_dict'])\n","recover_model = lambda : model.load_state_dict(checkpoint['state_dict'])"]},{"cell_type":"code","execution_count":14,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":27440,"status":"ok","timestamp":1696471960915,"user":{"displayName":"Ji Lin","userId":"00458298044061695840"},"user_tz":420},"id":"jQb3_6zlQfib","outputId":"9f4a17ef-4d3d-487e-ec19-fde5201b20a4"},"outputs":[{"name":"stdout","output_type":"stream","text":["Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar10/cifar-10-python.tar.gz\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 170498071/170498071 [00:20<00:00, 8311603.99it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Extracting 
data/cifar10/cifar-10-python.tar.gz to data/cifar10\n","Files already downloaded and verified\n"]}],"source":["image_size = 32\n","transforms = {\n"," \"train\": Compose([\n"," RandomCrop(image_size, padding=4),\n"," RandomHorizontalFlip(),\n"," ToTensor(),\n"," ]),\n"," \"test\": ToTensor(),\n","}\n","dataset = {}\n","for split in [\"train\", \"test\"]:\n"," dataset[split] = CIFAR10(\n"," root=\"data/cifar10\",\n"," train=(split == \"train\"),\n"," download=True,\n"," transform=transforms[split],\n"," )\n","dataloader = {}\n","for split in ['train', 'test']:\n"," dataloader[split] = DataLoader(\n"," dataset[split],\n"," batch_size=512,\n"," shuffle=(split == 'train'),\n"," num_workers=0,\n"," pin_memory=True,\n"," )"]},{"cell_type":"markdown","metadata":{"id":"rjESsg5nkdBG"},"source":["# Let's First Evaluate the Accuracy and Model Size of the FP32 Model"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"DTiA8hxMkkbU"},"outputs":[],"source":["fp32_model_accuracy = evaluate(model, dataloader['test'])\n","fp32_model_size = get_model_size(model)\n","print(f\"fp32 model has accuracy={fp32_model_accuracy:.2f}%\")\n","print(f\"fp32 model has size={fp32_model_size/MiB:.2f} MiB\")"]},{"cell_type":"markdown","metadata":{"id":"6QdGiddu87p2"},"source":["# K-Means Quantization"]},{"cell_type":"markdown","metadata":{"id":"Gr0MvPoVTirU"},"source":["Network quantization compresses the network by reducing the bits per weight required to represent the deep network. 
With hardware support, the quantized network can also run inference faster."]},{"cell_type":"markdown","metadata":{"id":"lzlKpiNvY3lc"},"source":["In this section, we will explore k-means quantization for neural networks, as proposed in [Deep Compression: Compressing Deep Neural Networks With Pruning, Trained Quantization\n","And Huffman Coding](https://arxiv.org/pdf/1510.00149.pdf).\n"]},{"cell_type":"markdown","metadata":{"id":"KDoR8iXP0VDG"},"source":[""]},{"cell_type":"markdown","metadata":{"id":"mGkEk4CX0VIe"},"source":["An $n$-bit k-means quantization divides the synapses (weights) into $2^n$ clusters, and synapses in the same cluster share the same weight value.\n","\n","Therefore, k-means quantization creates a codebook, including\n","* `centroids`: $2^n$ fp32 cluster centers.\n","* `labels`: an $n$-bit integer tensor with the same number of elements as the original fp32 weight tensor. Each integer indicates which cluster the corresponding weight belongs to.\n","\n","During inference, an fp32 tensor is reconstructed from the codebook:\n","\n","> ***quantized_weight* = *codebook.centroids*\\[*codebook.labels*\\].view_as(weight)**"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Z2x1DbpfT3SD"},"outputs":[],"source":["from collections import namedtuple\n","\n","Codebook = namedtuple('Codebook', ['centroids', 'labels'])"]},{"cell_type":"markdown","metadata":{"id":"f3BczeDEN0y6"},"source":["## Question 1 (10 pts)\n","\n","Please complete the following k-means quantization function."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"bnZ2b4hgOnoC"},"outputs":[],"source":["from fast_pytorch_kmeans import KMeans\n","\n","def k_means_quantize(fp32_tensor: torch.Tensor, bitwidth=4, codebook=None):\n"," \"\"\"\n"," quantize tensor using k-means clustering\n"," :param fp32_tensor: [torch.(cuda.)Tensor] the fp32 tensor to quantize (modified in place)\n"," :param bitwidth: [int] quantization bit width, default=4\n"," :param codebook: [Codebook] (the cluster centroids, the cluster label tensor)\n"," :return:\n"," [Codebook = 
(centroids, labels)]\n"," centroids: [torch.(cuda.)FloatTensor] the cluster centroids\n"," labels: [torch.(cuda.)LongTensor] cluster label tensor\n"," \"\"\"\n"," if codebook is None:\n"," ############### YOUR CODE STARTS HERE ###############\n"," # get number of clusters based on the quantization precision\n"," # hint: one line of code\n"," n_clusters = 0\n"," ############### YOUR CODE ENDS HERE #################\n"," # use k-means to get the quantization centroids\n"," kmeans = KMeans(n_clusters=n_clusters, mode='euclidean', verbose=0)\n"," labels = kmeans.fit_predict(fp32_tensor.view(-1, 1)).to(torch.long)\n"," centroids = kmeans.centroids.to(torch.float).view(-1)\n"," codebook = Codebook(centroids, labels)\n"," ############### YOUR CODE STARTS HERE ###############\n"," # decode the codebook into k-means quantized tensor for inference\n"," # hint: one line of code\n"," quantized_tensor = 0\n"," ############### YOUR CODE ENDS HERE #################\n"," fp32_tensor.set_(quantized_tensor.view_as(fp32_tensor))\n"," return codebook"]},{"cell_type":"markdown","metadata":{"id":"XlVfIFRYPbfc"},"source":["Let's verify the k-means quantization function defined above by applying it to a dummy tensor."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"WZd8z1vGPdG8"},"outputs":[],"source":["test_k_means_quantize()"]},{"cell_type":"markdown","metadata":{"id":"O4FIfKk8JQX5"},"source":["## Question 2 (10 pts)"]},{"cell_type":"markdown","metadata":{"id":"o5mT6vV_JS5w"},"source":["The last code cell performs 2-bit k-means quantization and plots the tensor before and after quantization. Each cluster is rendered with a unique color. 
There are 4 unique colors rendered in the quantized tensor.\n","\n","Given this observation, please answer the following questions."]},{"cell_type":"markdown","metadata":{"id":"DwERfTxNKFqa"},"source":["### Question 2.1 (5 pts)\n","\n","If 4-bit k-means quantization is performed, how many unique colors will be rendered in the quantized tensor?"]},{"cell_type":"markdown","metadata":{"id":"0PGiy4rGKPbe"},"source":["**Your Answer:**"]},{"cell_type":"markdown","metadata":{"id":"Dt1ZGParKRWD"},"source":["### Question 2.2 (5 pts)\n","\n","If *n*-bit k-means quantization is performed, how many unique colors will be rendered in the quantized tensor?"]},{"cell_type":"markdown","metadata":{"id":"gK8v80tnKZHa"},"source":["**Your Answer:**"]},{"cell_type":"markdown","metadata":{"id":"Aa7KVrLrSlNb"},"source":["## K-Means Quantization on the Whole Model\n","\n","Similar to what we did in Lab 1, we now wrap the k-means quantization function into a class for quantizing the whole model. In class `KMeansQuantizer`, we keep a record of the codebooks (i.e., `centroids` and `labels`) so that we can apply or update the codebooks whenever the model weights change."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"eaMAP465Te-p"},"outputs":[],"source":["class KMeansQuantizer:\n"," def __init__(self, model: nn.Module, bitwidth=4):\n"," self.codebook = KMeansQuantizer.quantize(model, bitwidth)\n","\n"," @torch.no_grad()\n"," def apply(self, model, update_centroids):\n"," # re-quantize every tracked parameter with its codebook, optionally refreshing the centroids first\n"," for name, param in model.named_parameters():\n"," if name in self.codebook:\n"," if update_centroids:\n"," update_codebook(param, codebook=self.codebook[name])\n"," self.codebook[name] = k_means_quantize(\n"," param, codebook=self.codebook[name])\n","\n"," @staticmethod\n"," @torch.no_grad()\n"," def quantize(model: nn.Module, bitwidth=4):\n"," codebook = dict()\n"," if isinstance(bitwidth, dict):\n"," for name, param in model.named_parameters():\n"," if name in 
bitwidth:\n"," codebook[name] = k_means_quantize(param, bitwidth=bitwidth[name])\n"," else:\n"," for name, param in model.named_parameters():\n"," if param.dim() > 1:\n"," codebook[name] = k_means_quantize(param, bitwidth=bitwidth)\n"," return codebook"]},{"cell_type":"markdown","metadata":{"id":"od8iy9KjVhlT"},"source":["Now let's quantize the model into 8 bits, 4 bits, and 2 bits using k-means quantization. *Note that when calculating the model size, we ignore the storage for codebooks and count every parameter (including the unquantized 1-D BatchNorm and bias parameters) at the target bitwidth.*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"CCI9sCUmVre-"},"outputs":[],"source":["print('Note that the storage for codebooks is ignored when calculating the model size.')\n","quantizers = dict()\n","for bitwidth in [8, 4, 2]:\n"," recover_model()\n"," print(f'k-means quantizing model into {bitwidth} bits')\n"," quantizer = KMeansQuantizer(model, bitwidth)\n"," quantized_model_size = get_model_size(model, bitwidth)\n"," print(f\" {bitwidth}-bit k-means quantized model has size={quantized_model_size/MiB:.2f} MiB\")\n"," quantized_model_accuracy = evaluate(model, dataloader['test'])\n"," print(f\" {bitwidth}-bit k-means quantized model has accuracy={quantized_model_accuracy:.2f}%\")\n"," quantizers[bitwidth] = quantizer"]},{"cell_type":"markdown","metadata":{"id":"4jj3FnIIWVq0"},"source":["## Trained K-Means Quantization\n","\n","As we can see from the results of the last cell, the accuracy drops significantly when quantizing the model to lower bitwidths. 
Therefore, we have to perform quantization-aware training to recover the accuracy.\n","\n","During k-means quantization-aware training, the centroids are also updated, as proposed in [Deep Compression: Compressing Deep Neural Networks With Pruning, Trained Quantization\n","And Huffman Coding](https://arxiv.org/pdf/1510.00149.pdf).\n","\n","The gradient for the centroids is calculated as\n","\n","> $\\frac{\\partial \\mathcal{L} }{\\partial C_k} = \\sum_{j} \\frac{\\partial \\mathcal{L} }{\\partial W_{j}} \\frac{\\partial W_{j} }{\\partial C_k} = \\sum_{j} \\frac{\\partial \\mathcal{L} }{\\partial W_{j}} \\mathbf{1}(I_{j}=k)$\n","\n","where $\\mathcal{L}$ is the loss, $C_k$ is the *k*-th centroid, and $I_{j}$ is the label for weight $W_{j}$. $\\mathbf{1}()$ is the indicator function: $\\mathbf{1}(I_{j}=k)$ is $1\\;\\mathrm{if}\\;I_{j}=k\\;\\mathrm{else}\\;0$, *i.e.*, $I_{j}==k$.\n","\n","In this lab, **for simplicity**, we directly update the centroids according to the latest weights:\n","\n","> $C_k = \\frac{\\sum_{j}W_{j}\\mathbf{1}(I_{j}=k)}{\\sum_{j}\\mathbf{1}(I_{j}=k)}$"]},{"cell_type":"markdown","metadata":{"id":"Q_hfG16fLVda"},"source":["### Question 3 (10 pts)\n","\n","Please complete the following codebook update function.\n","\n","**Hint**:\n","\n","The above equation sets each updated centroid to the `mean` of the weights in its cluster."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"t8_FfiUf_1ZM"},"outputs":[],"source":["def update_codebook(fp32_tensor: torch.Tensor, codebook: Codebook):\n"," \"\"\"\n"," update the centroids in the codebook using the updated fp32_tensor\n"," :param fp32_tensor: [torch.(cuda.)Tensor]\n"," :param codebook: [Codebook] (the cluster centroids, the cluster label tensor)\n"," \"\"\"\n"," n_clusters = codebook.centroids.numel()\n"," fp32_tensor = fp32_tensor.view(-1)\n"," for k in range(n_clusters):\n"," ############### YOUR CODE STARTS HERE 
###############\n"," # hint: one line of code\n"," codebook.centroids[k] = 0\n"," ############### YOUR CODE ENDS HERE #################"]},{"cell_type":"markdown","metadata":{"id":"Ho4RDW0ccw4f"},"source":["Now let's run the following code cell to finetune the k-means quantized model to recover the accuracy. We will stop finetuning if accuracy drop is less than 0.5."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"DR1Fu6Arfgnk"},"outputs":[],"source":["accuracy_drop_threshold = 0.5\n","quantizers_before_finetune = copy.deepcopy(quantizers)\n","quantizers_after_finetune = quantizers\n","\n","for bitwidth in [8, 4, 2]:\n"," recover_model()\n"," quantizer = quantizers[bitwidth]\n"," print(f'k-means quantizing model into {bitwidth} bits')\n"," quantizer.apply(model, update_centroids=False)\n"," quantized_model_size = get_model_size(model, bitwidth)\n"," print(f\" {bitwidth}-bit k-means quantized model has size={quantized_model_size/MiB:.2f} MiB\")\n"," quantized_model_accuracy = evaluate(model, dataloader['test'])\n"," print(f\" {bitwidth}-bit k-means quantized model has accuracy={quantized_model_accuracy:.2f}% before quantization-aware training \")\n"," accuracy_drop = fp32_model_accuracy - quantized_model_accuracy\n"," if accuracy_drop > accuracy_drop_threshold:\n"," print(f\" Quantization-aware training due to accuracy drop={accuracy_drop:.2f}% is larger than threshold={accuracy_drop_threshold:.2f}%\")\n"," num_finetune_epochs = 5\n"," optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n"," scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_finetune_epochs)\n"," criterion = nn.CrossEntropyLoss()\n"," best_accuracy = 0\n"," epoch = num_finetune_epochs\n"," while accuracy_drop > accuracy_drop_threshold and epoch > 0:\n"," train(model, dataloader['train'], criterion, optimizer, scheduler,\n"," callbacks=[lambda: quantizer.apply(model, update_centroids=True)])\n"," model_accuracy = evaluate(model, 
dataloader['test'])\n"," is_best = model_accuracy > best_accuracy\n"," best_accuracy = max(model_accuracy, best_accuracy)\n"," print(f' Epoch {num_finetune_epochs-epoch} Accuracy {model_accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%')\n"," accuracy_drop = fp32_model_accuracy - best_accuracy\n"," epoch -= 1\n"," else:\n"," print(f\" No need for quantization-aware training since accuracy drop={accuracy_drop:.2f}% is smaller than threshold={accuracy_drop_threshold:.2f}%\")"]},{"cell_type":"markdown","metadata":{"id":"USisw8QcnPNP"},"source":["# Linear Quantization"]},{"cell_type":"markdown","metadata":{"id":"sEzwgl9mZ6F2"},"source":["In this section, we will implement and perform linear quantization.\n","\n","Linear quantization directly rounds the floating-point value into the nearest quantized integer after range truncation and scaling.\n","\n","[Linear quantization](https://arxiv.org/pdf/1712.05877.pdf) can be represented as\n","\n","$r = S(q-Z)$\n","\n","where $r$ is a floating point real number, $q$ is a *n*-bit integer, $Z$ is a *n*-bit integer, and $S$ is a floating point real number. $Z$ is quantization zero point and $S$ is quantization scaling factor. Both constant $Z$ and $S$ are quantization parameters."]},{"cell_type":"markdown","metadata":{"id":"c6vO1_VBdTzs"},"source":["## *n*-bit Integer\n","\n","A *n*-bit signed integer is usually represented in [two's complement](https://en.wikipedia.org/wiki/Two%27s_complement) notation.\n","\n","A *n*-bit signed integer can enode integers in the range $[-2^{n-1}, 2^{n-1}-1]$. 
For example, a 8-bit integer falls in the range [-128, 127]."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"uz51HcD8dTIb"},"outputs":[],"source":["def get_quantized_range(bitwidth):\n"," quantized_max = (1 << (bitwidth - 1)) - 1\n"," quantized_min = -(1 << (bitwidth - 1))\n"," return quantized_min, quantized_max"]},{"cell_type":"markdown","metadata":{"id":"FSRwyEWrnA7o"},"source":["## **Question 4** (15 pts)\n","\n","Please complete the following linear quantization function.\n","\n","**Hint**:\n","* From $r=S(q-Z)$, we have $q = r/S + Z$.\n","* Both $r$ and $S$ are floating numbers, and thus we cannot directly add integer $Z$ to $r/S$. Therefore $q = \\mathrm{int}(\\mathrm{round}(r/S)) + Z$.\n","* To convert [`torch.FloatTensor`](https://pytorch.org/docs/stable/tensors.html) to [`torch.IntTensor`](https://pytorch.org/docs/stable/tensors.html), we could use [`torch.round()`](https://pytorch.org/docs/stable/generated/torch.round.html#torch.round), [`torch.Tensor.round()`](https://pytorch.org/docs/stable/generated/torch.Tensor.round.html#torch.Tensor.round), [`torch.Tensor.round_()`](https://pytorch.org/docs/stable/generated/torch.Tensor.round_) to first convert all values to floating integer, and then use [`torch.Tensor.to(torch.int8)`](https://pytorch.org/docs/stable/generated/torch.Tensor.to.html#torch.Tensor.to) to convert the data type from [`torch.float`](https://pytorch.org/docs/stable/tensors.html) to [`torch.int8`](https://pytorch.org/docs/stable/tensors.html).\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"1vg4SfIJYFUq"},"outputs":[],"source":["def linear_quantize(fp_tensor, bitwidth, scale, zero_point, dtype=torch.int8) -> torch.Tensor:\n"," \"\"\"\n"," linear quantization for single fp_tensor\n"," from\n"," fp_tensor = (quantized_tensor - zero_point) * scale\n"," we have,\n"," quantized_tensor = int(round(fp_tensor / scale)) + zero_point\n"," :param tensor: [torch.(cuda.)FloatTensor] floating tensor to be 
quantized\n"," :param bitwidth: [int] quantization bit width\n"," :param scale: [torch.(cuda.)FloatTensor] scaling factor\n"," :param zero_point: [torch.(cuda.)IntTensor] the desired centroid of tensor values\n"," :return:\n"," [torch.(cuda.)FloatTensor] quantized tensor whose values are integers\n"," \"\"\"\n"," assert(fp_tensor.dtype == torch.float)\n"," assert(isinstance(scale, float) or\n"," (scale.dtype == torch.float and scale.dim() == fp_tensor.dim()))\n"," assert(isinstance(zero_point, int) or\n"," (zero_point.dtype == dtype and zero_point.dim() == fp_tensor.dim()))\n","\n"," ############### YOUR CODE STARTS HERE ###############\n"," # Step 1: scale the fp_tensor\n"," scaled_tensor = 0\n"," # Step 2: round the floating value to integer value\n"," rounded_tensor = 0\n"," ############### YOUR CODE ENDS HERE #################\n","\n"," rounded_tensor = rounded_tensor.to(dtype)\n","\n"," ############### YOUR CODE STARTS HERE ###############\n"," # Step 3: shift the rounded_tensor to make zero_point 0\n"," shifted_tensor = 0\n"," ############### YOUR CODE ENDS HERE #################\n","\n"," # Step 4: clamp the shifted_tensor to lie in bitwidth-bit range\n"," quantized_min, quantized_max = get_quantized_range(bitwidth)\n"," quantized_tensor = shifted_tensor.clamp_(quantized_min, quantized_max)\n"," return quantized_tensor"]},{"cell_type":"markdown","metadata":{"id":"Qn-OU5z9IAtr"},"source":["Let's verify the functionality of defined linear quantization by applying the function above on a dummy tensor."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"10sN7WRNAWCc"},"outputs":[],"source":["test_linear_quantize()"]},{"cell_type":"markdown","metadata":{"id":"y9WBvG1jnhEz"},"source":["## Question 5 (10 pts)\n","\n","Now we have to determine the scaling factor $S$ and zero point $Z$ for linear quantization.\n","\n","Recall that [linear quantization](https://arxiv.org/pdf/1712.05877.pdf) can be represented as\n","\n","$r = 
S(q-Z)$"]},{"cell_type":"markdown","metadata":{"id":"lLgyax_fgrS1"},"source":["### Scale\n","\n","Linear quantization maps the floating-point range [*fp_min*, *fp_max*] onto the quantized range [*quantized_min*, *quantized_max*]. That is,\n","\n","> $r_{\\mathrm{max}} = S(q_{\\mathrm{max}}-Z)$\n",">\n","> $r_{\\mathrm{min}} = S(q_{\\mathrm{min}}-Z)$\n","\n","Subtracting these two equations, we have,\n"]},{"cell_type":"markdown","metadata":{"id":"PkSalSPzSwN2"},"source":["#### Question 5.1 (1 pt)\n","\n","Please select the correct answer and delete the wrong answers in the next text cell."]},{"cell_type":"markdown","metadata":{"id":"dctWTxHaTnUy"},"source":["> $S=r_{\\mathrm{max}} / q_{\\mathrm{max}}$\n","\n","> $S=(r_{\\mathrm{max}} + r_{\\mathrm{min}}) / (q_{\\mathrm{max}} + q_{\\mathrm{min}})$\n","\n","> $S=(r_{\\mathrm{max}} - r_{\\mathrm{min}}) / (q_{\\mathrm{max}} - q_{\\mathrm{min}})$\n","\n","> $S=r_{\\mathrm{max}} / q_{\\mathrm{max}} - r_{\\mathrm{min}} / q_{\\mathrm{min}}$"]},{"cell_type":"markdown","metadata":{"id":"WWfff4eISuES"},"source":["\n","There are different approaches to determining $r_{\\mathrm{min}}$ and $r_{\\mathrm{max}}$ for a floating-point tensor `fp_tensor`.\n","\n","* The most common method is to directly use the minimum and maximum values of `fp_tensor`.\n","* Another widely used method is to determine *fp_max* by minimizing the Kullback-Leibler (KL) divergence."]},{"cell_type":"markdown","metadata":{"id":"SSb42B8Gjlpa"},"source":["### Zero Point\n","\n","Once we determine the scaling factor $S$, we can directly use the relationship between $r_{\\mathrm{min}}$ and $q_{\\mathrm{min}}$ to calculate the zero point $Z$."]},{"cell_type":"markdown","metadata":{"id":"u13Wmw_8TdDA"},"source":["#### Question 5.2 (1 pt)\n","\n","Please select the correct answer and delete the wrong answers in the next text cell."]},{"cell_type":"markdown","metadata":{"id":"uD3Q_FfSTwIL"},"source":["> $Z = \\mathrm{int}(\\mathrm{round}(r_{\\mathrm{min}} / S - 
q_{\\mathrm{min}})$\n","\n","> $Z = \\mathrm{int}(\\mathrm{round}(q_{\\mathrm{min}} - r_{\\mathrm{min}} / S))$\n","\n","> $Z = q_{\\mathrm{min}} - r_{\\mathrm{min}} / S$\n","\n","> $Z = r_{\\mathrm{min}} / S - q_{\\mathrm{min}}$"]},{"cell_type":"markdown","metadata":{"id":"tk4aPTKdAikt"},"source":["### Question 5.3 (8 pts)\n","\n","Please complete the following function for calculating the scale $S$ and zero point $Z$ from floating point tensor $r$.\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"LfAjS4KhfwDY"},"outputs":[],"source":["def get_quantization_scale_and_zero_point(fp_tensor, bitwidth):\n"," \"\"\"\n"," get quantization scale for single tensor\n"," :param fp_tensor: [torch.(cuda.)Tensor] floating tensor to be quantized\n"," :param bitwidth: [int] quantization bit width\n"," :return:\n"," [float] scale\n"," [int] zero_point\n"," \"\"\"\n"," quantized_min, quantized_max = get_quantized_range(bitwidth)\n"," fp_max = fp_tensor.max().item()\n"," fp_min = fp_tensor.min().item()\n","\n"," ############### YOUR CODE STARTS HERE ###############\n"," # hint: one line of code for calculating scale\n"," scale = 0\n"," # hint: one line of code for calculating zero_point\n"," zero_point = 0\n"," ############### YOUR CODE ENDS HERE #################\n","\n"," # clip the zero_point to fall in [quantized_min, quantized_max]\n"," if zero_point < quantized_min:\n"," zero_point = quantized_min\n"," elif zero_point > quantized_max:\n"," zero_point = quantized_max\n"," else: # convert from float to int using round()\n"," zero_point = round(zero_point)\n"," return scale, int(zero_point)"]},{"cell_type":"markdown","metadata":{"id":"OW1AOZboSK0e"},"source":["We now wrap `linear_quantize()` in Question 4 and `get_quantization_scale_and_zero_point()` in Question 5 into one function."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"FWeOwwXgoPLc"},"outputs":[],"source":["def linear_quantize_feature(fp_tensor, bitwidth):\n"," \"\"\"\n"," linear 
quantization for feature tensor\n"," :param fp_tensor: [torch.(cuda.)Tensor] floating feature to be quantized\n"," :param bitwidth: [int] quantization bit width\n"," :return:\n"," [torch.(cuda.)Tensor] quantized tensor\n"," [float] scale tensor\n"," [int] zero point\n"," \"\"\"\n"," scale, zero_point = get_quantization_scale_and_zero_point(fp_tensor, bitwidth)\n"," quantized_tensor = linear_quantize(fp_tensor, bitwidth, scale, zero_point)\n"," return quantized_tensor, scale, zero_point"]},{"cell_type":"markdown","metadata":{"id":"a0YWp1YtfY-a"},"source":["## Special case: linear quantization on weight tensor\n","\n","Let's first see the distribution of weight values."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"-HdIc5kKk89j"},"outputs":[],"source":["def plot_weight_distribution(model, bitwidth=32):\n"," # bins = (1 << bitwidth) if bitwidth <= 8 else 256\n"," if bitwidth <= 8:\n"," qmin, qmax = get_quantized_range(bitwidth)\n"," bins = np.arange(qmin, qmax + 2)\n"," align = 'left'\n"," else:\n"," bins = 256\n"," align = 'mid'\n"," fig, axes = plt.subplots(3,3, figsize=(10, 6))\n"," axes = axes.ravel()\n"," plot_index = 0\n"," for name, param in model.named_parameters():\n"," if param.dim() > 1:\n"," ax = axes[plot_index]\n"," ax.hist(param.detach().view(-1).cpu(), bins=bins, density=True,\n"," align=align, color = 'blue', alpha = 0.5,\n"," edgecolor='black' if bitwidth <= 4 else None)\n"," if bitwidth <= 4:\n"," quantized_min, quantized_max = get_quantized_range(bitwidth)\n"," ax.set_xticks(np.arange(start=quantized_min, stop=quantized_max+1))\n"," ax.set_xlabel(name)\n"," ax.set_ylabel('density')\n"," plot_index += 1\n"," fig.suptitle(f'Histogram of Weights (bitwidth={bitwidth} bits)')\n"," fig.tight_layout()\n"," fig.subplots_adjust(top=0.925)\n"," plt.show()\n","\n","recover_model()\n","plot_weight_distribution(model)"]},{"cell_type":"markdown","metadata":{"id":"WpAmSr7IlDn9"},"source":["As we can see from the histograms above, the distribution 
of weight values are nearly symmetric about 0 (except for the classifier in this case). Therefore, we usually make zero point $Z=0$ when quantizating the weights.\n","\n","From $r = S(q-Z)$, we have\n","\n","> $r_{\\mathrm{max}} = S \\cdot q_{\\mathrm{max}}$\n","\n","and then\n","\n","> $S = r_{\\mathrm{max}} / q_{\\mathrm{max}}$\n","\n","We directly use the maximum magnitude of weight values as $r_{\\mathrm{max}}$."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"v3d-Y-G9enEA"},"outputs":[],"source":["def get_quantization_scale_for_weight(weight, bitwidth):\n"," \"\"\"\n"," get quantization scale for single tensor of weight\n"," :param weight: [torch.(cuda.)Tensor] floating weight to be quantized\n"," :param bitwidth: [integer] quantization bit width\n"," :return:\n"," [floating scalar] scale\n"," \"\"\"\n"," # we just assume values in weight are symmetric\n"," # we also always make zero_point 0 for weight\n"," fp_max = max(weight.abs().max().item(), 5e-7)\n"," _, quantized_max = get_quantized_range(bitwidth)\n"," return fp_max / quantized_max"]},{"cell_type":"markdown","metadata":{"id":"X71SQhfDnwS7"},"source":["### Per-channel Linear Quantization\n","\n","Recall that for 2D convolution, the weight tensor is a 4-D tensor in the shape of (num_output_channels, num_input_channels, kernel_height, kernel_width).\n","\n","Intensive experiments show that using the different scaling factors $S$ and zero points $Z$ for different output channels will perform better. 
Therefore, we have to determine scaling factor $S$ and zero point $Z$ for the subtensor of each output channel independently."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"tFkA5JlgoiLs"},"outputs":[],"source":["def linear_quantize_weight_per_channel(tensor, bitwidth):\n"," \"\"\"\n"," linear quantization for weight tensor\n"," using different scales and zero_points for different output channels\n"," :param tensor: [torch.(cuda.)Tensor] floating weight to be quantized\n"," :param bitwidth: [int] quantization bit width\n"," :return:\n"," [torch.(cuda.)Tensor] quantized tensor\n"," [torch.(cuda.)Tensor] scale tensor\n"," [int] zero point (which is always 0)\n"," \"\"\"\n"," dim_output_channels = 0\n"," num_output_channels = tensor.shape[dim_output_channels]\n"," scale = torch.zeros(num_output_channels, device=tensor.device)\n"," for oc in range(num_output_channels):\n"," _subtensor = tensor.select(dim_output_channels, oc)\n"," _scale = get_quantization_scale_for_weight(_subtensor, bitwidth)\n"," scale[oc] = _scale\n"," scale_shape = [1] * tensor.dim()\n"," scale_shape[dim_output_channels] = -1\n"," scale = scale.view(scale_shape)\n"," quantized_tensor = linear_quantize(tensor, bitwidth, scale, zero_point=0)\n"," return quantized_tensor, scale, 0"]},{"cell_type":"markdown","metadata":{"id":"bqKnjcUJpq2R"},"source":["### A Quick Peek at Linear Quantization on Weights\n","\n","Now let's have a peek on the weight distribution and model size when applying linear quantization on weights with different bitwidths."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"6PzZe2DGvWap"},"outputs":[],"source":["@torch.no_grad()\n","def peek_linear_quantization():\n"," for bitwidth in [4, 2]:\n"," for name, param in model.named_parameters():\n"," if param.dim() > 1:\n"," quantized_param, scale, zero_point = \\\n"," linear_quantize_weight_per_channel(param, bitwidth)\n"," param.copy_(quantized_param)\n"," plot_weight_distribution(model, bitwidth)\n"," 
recover_model()\n","\n","peek_linear_quantization()"]},{"cell_type":"markdown","metadata":{"id":"5Dw14QcxxItm"},"source":["## Quantized Inference"]},{"cell_type":"markdown","metadata":{"id":"zsHy-Bx-UfpL"},"source":["After quantization, the inference computation of convolution and fully-connected layers changes as well.\n","\n","Recall that $r = S(q-Z)$, so we have\n","\n","> $r_{\\mathrm{input}} = S_{\\mathrm{input}}(q_{\\mathrm{input}}-Z_{\\mathrm{input}})$\n",">\n","> $r_{\\mathrm{weight}} = S_{\\mathrm{weight}}(q_{\\mathrm{weight}}-Z_{\\mathrm{weight}})$\n",">\n","> $r_{\\mathrm{bias}} = S_{\\mathrm{bias}}(q_{\\mathrm{bias}}-Z_{\\mathrm{bias}})$\n","\n","Since $Z_{\\mathrm{weight}}=0$, $r_{\\mathrm{weight}} = S_{\\mathrm{weight}}q_{\\mathrm{weight}}$.\n","\n","The floating-point convolution can be written as\n","\n","> $r_{\\mathrm{output}} = \\mathrm{CONV}[r_{\\mathrm{input}}, r_{\\mathrm{weight}}] + r_{\\mathrm{bias}}\\\\\n","\\;\\;\\;\\;\\;\\;\\;\\;= \\mathrm{CONV}[S_{\\mathrm{input}}(q_{\\mathrm{input}}-Z_{\\mathrm{input}}), S_{\\mathrm{weight}}q_{\\mathrm{weight}}] + S_{\\mathrm{bias}}(q_{\\mathrm{bias}}-Z_{\\mathrm{bias}})\\\\\n","\\;\\;\\;\\;\\;\\;\\;\\;= \\mathrm{CONV}[q_{\\mathrm{input}}-Z_{\\mathrm{input}}, q_{\\mathrm{weight}}]\\cdot (S_{\\mathrm{input}} \\cdot S_{\\mathrm{weight}}) + S_{\\mathrm{bias}}(q_{\\mathrm{bias}}-Z_{\\mathrm{bias}})$\n","\n","To further simplify the computation, we set\n","\n","> $Z_{\\mathrm{bias}} = 0$\n",">\n","> $S_{\\mathrm{bias}} = S_{\\mathrm{input}} \\cdot S_{\\mathrm{weight}}$\n","\n","so that\n","\n","> $r_{\\mathrm{output}} = (\\mathrm{CONV}[q_{\\mathrm{input}}-Z_{\\mathrm{input}}, q_{\\mathrm{weight}}] + q_{\\mathrm{bias}})\\cdot (S_{\\mathrm{input}} \\cdot S_{\\mathrm{weight}})$\n","> $\\;\\;\\;\\;\\;\\;\\;\\;= (\\mathrm{CONV}[q_{\\mathrm{input}}, q_{\\mathrm{weight}}] - \\mathrm{CONV}[Z_{\\mathrm{input}}, q_{\\mathrm{weight}}] + q_{\\mathrm{bias}})\\cdot (S_{\\mathrm{input}}S_{\\mathrm{weight}})$\n","\n","Since\n","> 
$r_{\\mathrm{output}} = S_{\\mathrm{output}}(q_{\\mathrm{output}}-Z_{\\mathrm{output}})$\n","\n","we have\n","> $S_{\\mathrm{output}}(q_{\\mathrm{output}}-Z_{\\mathrm{output}}) = (\\mathrm{CONV}[q_{\\mathrm{input}}, q_{\\mathrm{weight}}] - \\mathrm{CONV}[Z_{\\mathrm{input}}, q_{\\mathrm{weight}}] + q_{\\mathrm{bias}})\\cdot (S_{\\mathrm{input}} S_{\\mathrm{weight}})$\n","\n","and thus\n","> $q_{\\mathrm{output}} = (\\mathrm{CONV}[q_{\\mathrm{input}}, q_{\\mathrm{weight}}] - \\mathrm{CONV}[Z_{\\mathrm{input}}, q_{\\mathrm{weight}}] + q_{\\mathrm{bias}})\\cdot (S_{\\mathrm{input}}S_{\\mathrm{weight}} / S_{\\mathrm{output}}) + Z_{\\mathrm{output}}$\n","\n","Since $Z_{\\mathrm{input}}$, $q_{\\mathrm{weight}}$, and $q_{\\mathrm{bias}}$ are all known before inference, we can precompute\n","\n","> $Q_{\\mathrm{bias}} = q_{\\mathrm{bias}} - \\mathrm{CONV}[Z_{\\mathrm{input}}, q_{\\mathrm{weight}}]$\n","\n","and obtain\n","\n","> $q_{\\mathrm{output}} = (\\mathrm{CONV}[q_{\\mathrm{input}}, q_{\\mathrm{weight}}] + Q_{\\mathrm{bias}}) \\cdot (S_{\\mathrm{input}}S_{\\mathrm{weight}} / S_{\\mathrm{output}}) + Z_{\\mathrm{output}}$\n","\n","Similarly, for a fully-connected layer, we have\n","\n","> $q_{\\mathrm{output}} = (\\mathrm{Linear}[q_{\\mathrm{input}}, q_{\\mathrm{weight}}] + Q_{\\mathrm{bias}})\\cdot (S_{\\mathrm{input}} \\cdot S_{\\mathrm{weight}} / S_{\\mathrm{output}}) + Z_{\\mathrm{output}}$\n","\n","where\n","\n","> $Q_{\\mathrm{bias}} = q_{\\mathrm{bias}} - \\mathrm{Linear}[Z_{\\mathrm{input}}, q_{\\mathrm{weight}}]$"]},{"cell_type":"markdown","metadata":{"id":"XlH0zg7M2J_L"},"source":["### Question 6 (5 pts)\n","\n","Please complete the following function for linearly quantizing the bias.\n","\n","**Hint**:\n","\n","From the derivation above, we know that\n","\n","> $Z_{\\mathrm{bias}} = 0$\n",">\n","> $S_{\\mathrm{bias}} = S_{\\mathrm{input}} \\cdot S_{\\mathrm{weight}}$"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"0JZiyAms2G1J"},"outputs":[],"source":["def 
linear_quantize_bias_per_output_channel(bias, weight_scale, input_scale):\n"," \"\"\"\n"," linear quantization for single bias tensor\n"," quantized_bias = fp_bias / bias_scale\n"," :param bias: [torch.FloatTensor] bias weight to be quantized\n"," :param weight_scale: [float or torch.FloatTensor] weight scale tensor\n"," :param input_scale: [float] input scale\n"," :return:\n"," [torch.IntTensor] quantized bias tensor\n"," \"\"\"\n"," assert(bias.dim() == 1)\n"," assert(bias.dtype == torch.float)\n"," assert(isinstance(input_scale, float))\n"," if isinstance(weight_scale, torch.Tensor):\n"," assert(weight_scale.dtype == torch.float)\n"," weight_scale = weight_scale.view(-1)\n"," assert(bias.numel() == weight_scale.numel())\n","\n"," ############### YOUR CODE STARTS HERE ###############\n"," # hint: one line of code\n"," bias_scale = 0\n"," ############### YOUR CODE ENDS HERE #################\n","\n"," quantized_bias = linear_quantize(bias, 32, bias_scale,\n"," zero_point=0, dtype=torch.int32)\n"," return quantized_bias, bias_scale, 0"]},{"cell_type":"markdown","metadata":{"id":"mMM7uYX25rFM"},"source":["### Quantized Fully-Connected Layer"]},{"cell_type":"markdown","metadata":{"id":"znsT4EWL5tA5"},"source":["For quantized fully-connected layer, we first precompute $Q_{\\mathrm{bias}}$. 
Recall that $Q_{\\mathrm{bias}} = q_{\\mathrm{bias}} - \\mathrm{Linear}[Z_{\\mathrm{input}}, q_{\\mathrm{weight}}]$."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"4rnNs4MN5tgF"},"outputs":[],"source":["def shift_quantized_linear_bias(quantized_bias, quantized_weight, input_zero_point):\n"," \"\"\"\n"," shift quantized bias to incorporate input_zero_point for nn.Linear\n"," shifted_quantized_bias = quantized_bias - Linear(input_zero_point, quantized_weight)\n"," :param quantized_bias: [torch.IntTensor] quantized bias (torch.int32)\n"," :param quantized_weight: [torch.CharTensor] quantized weight (torch.int8)\n"," :param input_zero_point: [int] input zero point\n"," :return:\n"," [torch.IntTensor] shifted quantized bias tensor\n"," \"\"\"\n"," assert(quantized_bias.dtype == torch.int32)\n"," assert(isinstance(input_zero_point, int))\n"," return quantized_bias - quantized_weight.sum(1).to(torch.int32) * input_zero_point"]},{"cell_type":"markdown","metadata":{"id":"DqNMxMzk3cDz"},"source":["#### Question 7 (15 pts)\n","\n","Please complete the following quantized fully-connected layer inference function.\n","\n","**Hint**:\n","\n","> $q_{\\mathrm{output}} = (\\mathrm{Linear}[q_{\\mathrm{input}}, q_{\\mathrm{weight}}] + Q_{\\mathrm{bias}})\\cdot (S_{\\mathrm{input}} S_{\\mathrm{weight}} / S_{\\mathrm{output}}) + Z_{\\mathrm{output}}$"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"3PVvI7jP3cMo"},"outputs":[],"source":["def quantized_linear(input, weight, bias, feature_bitwidth, weight_bitwidth,\n"," input_zero_point, output_zero_point,\n"," input_scale, weight_scale, output_scale):\n"," \"\"\"\n"," quantized fully-connected layer\n"," :param input: [torch.CharTensor] quantized input (torch.int8)\n"," :param weight: [torch.CharTensor] quantized weight (torch.int8)\n"," :param bias: [torch.IntTensor] shifted quantized bias or None (torch.int32)\n"," :param feature_bitwidth: [int] quantization bit width of input and output\n"," :param 
weight_bitwidth: [int] quantization bit width of weight\n"," :param input_zero_point: [int] input zero point\n"," :param output_zero_point: [int] output zero point\n"," :param input_scale: [float] input feature scale\n"," :param weight_scale: [torch.FloatTensor] weight per-channel scale\n"," :param output_scale: [float] output feature scale\n"," :return:\n"," [torch.CharIntTensor] quantized output feature (torch.int8)\n"," \"\"\"\n"," assert(input.dtype == torch.int8)\n"," assert(weight.dtype == input.dtype)\n"," assert(bias is None or bias.dtype == torch.int32)\n"," assert(isinstance(input_zero_point, int))\n"," assert(isinstance(output_zero_point, int))\n"," assert(isinstance(input_scale, float))\n"," assert(isinstance(output_scale, float))\n"," assert(weight_scale.dtype == torch.float)\n","\n"," # Step 1: integer-based fully-connected (8-bit multiplication with 32-bit accumulation)\n"," if 'cpu' in input.device.type:\n"," # use 32-b MAC for simplicity\n"," output = torch.nn.functional.linear(input.to(torch.int32), weight.to(torch.int32), bias)\n"," else:\n"," # current version pytorch does not yet support integer-based linear() on GPUs\n"," output = torch.nn.functional.linear(input.float(), weight.float(), bias.float())\n","\n"," ############### YOUR CODE STARTS HERE ###############\n"," # Step 2: scale the output\n"," # hint: 1. scales are floating numbers, we need to convert output to float as well\n"," # 2. 
the shape of weight scale is [oc, 1, 1, 1] while the shape of output is [batch_size, oc]\n"," output = 0\n","\n"," # Step 3: shift output by output_zero_point\n"," # hint: one line of code\n"," output = 0\n"," ############### YOUR CODE ENDS HERE #################\n","\n"," # Make sure all value lies in the bitwidth-bit range\n"," output = output.round().clamp(*get_quantized_range(feature_bitwidth)).to(torch.int8)\n"," return output"]},{"cell_type":"markdown","metadata":{"id":"115enVamIG_e"},"source":["Let's verify the functionality of defined quantized fully connected layer."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"HWmsLxgHH_B3"},"outputs":[],"source":["test_quantized_fc()"]},{"cell_type":"markdown","metadata":{"id":"ATooyRrH50ls"},"source":["### Quantized Convolution"]},{"cell_type":"markdown","metadata":{"id":"8mk1Ziof51JG"},"source":["For quantized convolution layer, we first precompute $Q_{\\mathrm{bias}}$. Recall that $Q_{\\mathrm{bias}} = q_{\\mathrm{bias}} - \\mathrm{CONV}[Z_{\\mathrm{input}}, q_{\\mathrm{weight}}]$."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"wEeANE_I53hz"},"outputs":[],"source":["def shift_quantized_conv2d_bias(quantized_bias, quantized_weight, input_zero_point):\n"," \"\"\"\n"," shift quantized bias to incorporate input_zero_point for nn.Conv2d\n"," shifted_quantized_bias = quantized_bias - Conv(input_zero_point, quantized_weight)\n"," :param quantized_bias: [torch.IntTensor] quantized bias (torch.int32)\n"," :param quantized_weight: [torch.CharTensor] quantized weight (torch.int8)\n"," :param input_zero_point: [int] input zero point\n"," :return:\n"," [torch.IntTensor] shifted quantized bias tensor\n"," \"\"\"\n"," assert(quantized_bias.dtype == torch.int32)\n"," assert(isinstance(input_zero_point, int))\n"," return quantized_bias - quantized_weight.sum((1,2,3)).to(torch.int32) * input_zero_point"]},{"cell_type":"markdown","metadata":{"id":"0x2SqxOp23cw"},"source":["#### Question 8 (15 
pts)\n","\n","Please complete the following quantized convolution function.\n","\n","**Hint**:\n","> $q_{\\mathrm{output}} = (\\mathrm{CONV}[q_{\\mathrm{input}}, q_{\\mathrm{weight}}] + Q_{\\mathrm{bias}}) \\cdot (S_{\\mathrm{input}}S_{\\mathrm{weight}} / S_{\\mathrm{output}}) + Z_{\\mathrm{output}}$\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"LVRqhiUno65x"},"outputs":[],"source":["def quantized_conv2d(input, weight, bias, feature_bitwidth, weight_bitwidth,\n"," input_zero_point, output_zero_point,\n"," input_scale, weight_scale, output_scale,\n"," stride, padding, dilation, groups):\n"," \"\"\"\n"," quantized 2d convolution\n"," :param input: [torch.CharTensor] quantized input (torch.int8)\n"," :param weight: [torch.CharTensor] quantized weight (torch.int8)\n"," :param bias: [torch.IntTensor] shifted quantized bias or None (torch.int32)\n"," :param feature_bitwidth: [int] quantization bit width of input and output\n"," :param weight_bitwidth: [int] quantization bit width of weight\n"," :param input_zero_point: [int] input zero point\n"," :param output_zero_point: [int] output zero point\n"," :param input_scale: [float] input feature scale\n"," :param weight_scale: [torch.FloatTensor] weight per-channel scale\n"," :param output_scale: [float] output feature scale\n"," :return:\n"," [torch.(cuda.)CharTensor] quantized output feature\n"," \"\"\"\n"," assert(len(padding) == 4)\n"," assert(input.dtype == torch.int8)\n"," assert(weight.dtype == input.dtype)\n"," assert(bias is None or bias.dtype == torch.int32)\n"," assert(isinstance(input_zero_point, int))\n"," assert(isinstance(output_zero_point, int))\n"," assert(isinstance(input_scale, float))\n"," assert(isinstance(output_scale, float))\n"," assert(weight_scale.dtype == torch.float)\n","\n"," # Step 1: calculate integer-based 2d convolution (8-bit multiplication with 32-bit accumulation)\n"," input = torch.nn.functional.pad(input, padding, 'constant', input_zero_point)\n"," if 'cpu' in 
input.device.type:\n"," # use 32-b MAC for simplicity\n"," output = torch.nn.functional.conv2d(input.to(torch.int32), weight.to(torch.int32), None, stride, 0, dilation, groups)\n"," else:\n"," # current version pytorch does not yet support integer-based conv2d() on GPUs\n"," output = torch.nn.functional.conv2d(input.float(), weight.float(), None, stride, 0, dilation, groups)\n"," output = output.round().to(torch.int32)\n"," if bias is not None:\n"," output = output + bias.view(1, -1, 1, 1)\n","\n"," ############### YOUR CODE STARTS HERE ###############\n"," # hint: this code block should be the very similar to quantized_linear()\n","\n"," # Step 2: scale the output\n"," # hint: 1. scales are floating numbers, we need to convert output to float as well\n"," # 2. the shape of weight scale is [oc, 1, 1, 1] while the shape of output is [batch_size, oc, height, width]\n"," output = 0\n","\n"," # Step 3: shift output by output_zero_point\n"," # hint: one line of code\n"," output = 0\n"," ############### YOUR CODE ENDS HERE #################\n","\n"," # Make sure all value lies in the bitwidth-bit range\n"," output = output.round().clamp(*get_quantized_range(feature_bitwidth)).to(torch.int8)\n"," return output"]},{"cell_type":"markdown","metadata":{"id":"32vvQ4WJHVlA"},"source":["## Question 9 (10 pts)\n","\n","Finally, we are putting everything together and perform post-training `int8` quantization for the model. We will convert the convolutional and linear layers in the model to a quantized version one-by-one."]},{"cell_type":"markdown","metadata":{"id":"2E5EDR2sINVS"},"source":["1. Firstly, we will fuse a BatchNorm layer into its previous convolutional layer, which is a standard practice before quantization. 
Fusing batchnorm reduces the extra multiplication during inference.\n","\n","We will also verify that the fused model `model_fused` has the same accuracy as the original model (BN fusion is an equivalent transform that does not change network functionality)."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"V2K8KSl7IE4D"},"outputs":[],"source":["def fuse_conv_bn(conv, bn):\n"," # modified from https://mmcv.readthedocs.io/en/latest/_modules/mmcv/cnn/utils/fuse_conv_bn.html\n"," assert conv.bias is None\n","\n"," factor = bn.weight.data / torch.sqrt(bn.running_var.data + bn.eps)\n"," conv.weight.data = conv.weight.data * factor.reshape(-1, 1, 1, 1)\n"," conv.bias = nn.Parameter(- bn.running_mean.data * factor + bn.bias.data)\n","\n"," return conv\n","\n","print('Before conv-bn fusion: backbone length', len(model.backbone))\n","# fuse the batchnorm into conv layers\n","recover_model()\n","model_fused = copy.deepcopy(model)\n","fused_backbone = []\n","ptr = 0\n","while ptr < len(model_fused.backbone):\n"," if isinstance(model_fused.backbone[ptr], nn.Conv2d) and \\\n"," isinstance(model_fused.backbone[ptr + 1], nn.BatchNorm2d):\n"," fused_backbone.append(fuse_conv_bn(\n"," model_fused.backbone[ptr], model_fused.backbone[ptr+ 1]))\n"," ptr += 2\n"," else:\n"," fused_backbone.append(model_fused.backbone[ptr])\n"," ptr += 1\n","model_fused.backbone = nn.Sequential(*fused_backbone)\n","\n","print('After conv-bn fusion: backbone length', len(model_fused.backbone))\n","# sanity check, no BN anymore\n","for m in model_fused.modules():\n"," assert not isinstance(m, nn.BatchNorm2d)\n","\n","# the accuracy will remain the same after fusion\n","fused_acc = evaluate(model_fused, dataloader['test'])\n","print(f'Accuracy of the fused model={fused_acc:.2f}%')"]},{"cell_type":"markdown","metadata":{"id":"oCOXprerMY5t"},"source":["2. 
We run the model on a batch of sample data and record each feature map, so that we can obtain its range and compute the corresponding scaling factor and zero point."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"CjKrC7L-UrxZ"},"outputs":[],"source":["# add hooks to record the input and output activations of each layer\n","# (their min/max values later determine the quantization ranges)\n","input_activation = {}\n","output_activation = {}\n","\n","def add_range_recoder_hook(model):\n"," import functools\n"," def _record_range(self, x, y, module_name):\n"," # forward hooks receive the module inputs as a tuple\n"," x = x[0]\n"," input_activation[module_name] = x.detach()\n"," output_activation[module_name] = y.detach()\n","\n"," all_hooks = []\n"," for name, m in model.named_modules():\n"," if isinstance(m, (nn.Conv2d, nn.Linear, nn.ReLU)):\n"," all_hooks.append(m.register_forward_hook(\n"," functools.partial(_record_range, module_name=name)))\n"," return all_hooks\n","\n","hooks = add_range_recoder_hook(model_fused)\n","sample_data = next(iter(dataloader['train']))[0]\n","model_fused(sample_data.cuda())\n","\n","# remove hooks\n","for h in hooks:\n"," h.remove()\n"]},{"cell_type":"markdown","metadata":{"id":"uq3fvyx_6bLW"},"source":["3. Finally, let's quantize the model. 
We will convert the model's layers according to the following mapping:\n","```python\n","nn.Conv2d: QuantizedConv2d,\n","nn.Linear: QuantizedLinear,\n","# the following two are just wrappers, as current\n","# torch modules do not support the int8 data format;\n","# they temporarily convert inputs to fp32 for computation\n","nn.MaxPool2d: QuantizedMaxPool2d,\n","nn.AvgPool2d: QuantizedAvgPool2d,\n","```"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"iFD5W0H1Zban"},"outputs":[],"source":["class QuantizedConv2d(nn.Module):\n"," def __init__(self, weight, bias,\n"," input_zero_point, output_zero_point,\n"," input_scale, weight_scale, output_scale,\n"," stride, padding, dilation, groups,\n"," feature_bitwidth=8, weight_bitwidth=8):\n"," super().__init__()\n"," # current version PyTorch does not support IntTensor as nn.Parameter\n"," self.register_buffer('weight', weight)\n"," self.register_buffer('bias', bias)\n","\n"," self.input_zero_point = input_zero_point\n"," self.output_zero_point = output_zero_point\n","\n"," self.input_scale = input_scale\n"," self.register_buffer('weight_scale', weight_scale)\n"," self.output_scale = output_scale\n","\n"," self.stride = stride\n"," # reorder (pad_h, pad_w) into (left, right, top, bottom), the order used by torch.nn.functional.pad\n"," self.padding = (padding[1], padding[1], padding[0], padding[0])\n"," self.dilation = dilation\n"," self.groups = groups\n","\n"," self.feature_bitwidth = feature_bitwidth\n"," self.weight_bitwidth = weight_bitwidth\n","\n","\n"," def forward(self, x):\n"," return quantized_conv2d(\n"," x, self.weight, self.bias,\n"," self.feature_bitwidth, self.weight_bitwidth,\n"," self.input_zero_point, self.output_zero_point,\n"," self.input_scale, self.weight_scale, self.output_scale,\n"," self.stride, self.padding, self.dilation, self.groups\n"," )\n","\n","class QuantizedLinear(nn.Module):\n"," def __init__(self, weight, bias,\n"," input_zero_point, output_zero_point,\n"," input_scale, 
weight_scale, output_scale,\n"," feature_bitwidth=8, weight_bitwidth=8):\n"," super().__init__()\n"," # current version PyTorch 
does not support IntTensor as nn.Parameter\n"," self.register_buffer('weight', weight)\n"," self.register_buffer('bias', bias)\n","\n"," self.input_zero_point = input_zero_point\n"," self.output_zero_point = output_zero_point\n","\n"," self.input_scale = input_scale\n"," self.register_buffer('weight_scale', weight_scale)\n"," self.output_scale = output_scale\n","\n"," self.feature_bitwidth = feature_bitwidth\n"," self.weight_bitwidth = weight_bitwidth\n","\n"," def forward(self, x):\n"," return quantized_linear(\n"," x, self.weight, self.bias,\n"," self.feature_bitwidth, self.weight_bitwidth,\n"," self.input_zero_point, self.output_zero_point,\n"," self.input_scale, self.weight_scale, self.output_scale\n"," )\n","\n","class QuantizedMaxPool2d(nn.MaxPool2d):\n"," def forward(self, x):\n"," # current version PyTorch does not support integer-based MaxPool\n"," return super().forward(x.float()).to(torch.int8)\n","\n","class QuantizedAvgPool2d(nn.AvgPool2d):\n"," def forward(self, x):\n"," # current version PyTorch does not support integer-based AvgPool\n"," return super().forward(x.float()).to(torch.int8)\n","\n","# we use int8 quantization, which is quite popular\n","feature_bitwidth = weight_bitwidth = 8\n","quantized_model = copy.deepcopy(model_fused)\n","quantized_backbone = []\n","ptr = 0\n","while ptr < len(quantized_model.backbone):\n"," if isinstance(quantized_model.backbone[ptr], nn.Conv2d) and \\\n"," isinstance(quantized_model.backbone[ptr + 1], nn.ReLU):\n"," conv = quantized_model.backbone[ptr]\n"," conv_name = f'backbone.{ptr}'\n"," relu = quantized_model.backbone[ptr + 1]\n"," relu_name = f'backbone.{ptr + 1}'\n","\n"," input_scale, input_zero_point = \\\n"," get_quantization_scale_and_zero_point(\n"," input_activation[conv_name], feature_bitwidth)\n","\n"," output_scale, output_zero_point = \\\n"," get_quantization_scale_and_zero_point(\n"," output_activation[relu_name], feature_bitwidth)\n","\n"," quantized_weight, weight_scale, weight_zero_point = 
\\\n"," linear_quantize_weight_per_channel(conv.weight.data, weight_bitwidth)\n"," quantized_bias, bias_scale, bias_zero_point = \\\n"," linear_quantize_bias_per_output_channel(\n"," conv.bias.data, weight_scale, input_scale)\n"," shifted_quantized_bias = \\\n"," shift_quantized_conv2d_bias(quantized_bias, quantized_weight,\n"," input_zero_point)\n","\n"," quantized_conv = QuantizedConv2d(\n"," quantized_weight, shifted_quantized_bias,\n"," input_zero_point, output_zero_point,\n"," input_scale, weight_scale, output_scale,\n"," conv.stride, conv.padding, conv.dilation, conv.groups,\n"," feature_bitwidth=feature_bitwidth, weight_bitwidth=weight_bitwidth\n"," )\n","\n"," quantized_backbone.append(quantized_conv)\n"," ptr += 2\n"," elif isinstance(quantized_model.backbone[ptr], nn.MaxPool2d):\n"," quantized_backbone.append(QuantizedMaxPool2d(\n"," kernel_size=quantized_model.backbone[ptr].kernel_size,\n"," stride=quantized_model.backbone[ptr].stride\n"," ))\n"," ptr += 1\n"," elif isinstance(quantized_model.backbone[ptr], nn.AvgPool2d):\n"," quantized_backbone.append(QuantizedAvgPool2d(\n"," kernel_size=quantized_model.backbone[ptr].kernel_size,\n"," stride=quantized_model.backbone[ptr].stride\n"," ))\n"," ptr += 1\n"," else:\n"," raise NotImplementedError(type(quantized_model.backbone[ptr])) # should not happen\n","quantized_model.backbone = nn.Sequential(*quantized_backbone)\n","\n","# finally, quantize the classifier\n","fc_name = 'classifier'\n","fc = model.classifier\n","input_scale, input_zero_point = \\\n"," get_quantization_scale_and_zero_point(\n"," input_activation[fc_name], feature_bitwidth)\n","\n","output_scale, output_zero_point = \\\n"," get_quantization_scale_and_zero_point(\n"," output_activation[fc_name], feature_bitwidth)\n","\n","quantized_weight, weight_scale, weight_zero_point = \\\n"," linear_quantize_weight_per_channel(fc.weight.data, weight_bitwidth)\n","quantized_bias, bias_scale, bias_zero_point = \\\n"," 
linear_quantize_bias_per_output_channel(\n"," fc.bias.data, weight_scale, input_scale)\n","shifted_quantized_bias = \\\n"," shift_quantized_linear_bias(quantized_bias, quantized_weight,\n"," input_zero_point)\n","\n","quantized_model.classifier = QuantizedLinear(\n"," quantized_weight, shifted_quantized_bias,\n"," input_zero_point, output_zero_point,\n"," input_scale, weight_scale, output_scale,\n"," feature_bitwidth=feature_bitwidth, weight_bitwidth=weight_bitwidth\n",")"]},{"cell_type":"markdown","metadata":{"id":"YIPgF84IygCy"},"source":["The quantization process is done! Let's print the model architecture and verify the accuracy of the quantized model.\n"]},{"cell_type":"markdown","metadata":{"id":"vGkf1sf27uFj"},"source":["### Question 9.1 (5 pts)\n","\n","To run the quantized model, we need an extra preprocessing step that maps the input data from the range [0, 1] into the `int8` range [-128, 127]. Fill in the code below to finish the extra preprocessing.\n","\n","**Hint**: you should find that the quantized model has roughly the same accuracy as its `fp32` counterpart."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"PEZT47BE6dXe"},"outputs":[],"source":["print(quantized_model)\n","\n","def extra_preprocess(x):\n"," # hint: you need to convert the original fp32 input of range [0, 1]\n"," # into int8 format of range [-128, 127]\n"," ############### YOUR CODE STARTS HERE ###############\n"," return 0.clamp(-128, 127).to(torch.int8)\n"," ############### YOUR CODE ENDS HERE #################\n","\n","int8_model_accuracy = evaluate(quantized_model, dataloader['test'],\n"," extra_preprocess=[extra_preprocess])\n","print(f\"int8 model has accuracy={int8_model_accuracy:.2f}%\")"]},{"cell_type":"markdown","metadata":{"id":"61iwKRAA8A7c"},"source":["### Question 9.2 (Bonus Question; 5 pts)\n","\n","Explain why there is no ReLU layer in the linear quantized model."]},{"cell_type":"markdown","metadata":{"id":"8Ewbt01T8F1b"},"source":["**Your 
Answer:**"]},{"cell_type":"markdown","metadata":{"id":"HhSnPZpD7ye9"},"source":["# Question 10 (5 pts)\n","\n","Please compare the advantages and disadvantages of k-means-based quantization and linear quantization. You may discuss them from the perspectives of accuracy, latency, hardware support, etc."]},{"cell_type":"markdown","metadata":{"id":"mhlf55UM73yA"},"source":["**Your Answer:**"]},{"cell_type":"markdown","metadata":{"id":"xxOBqoXoSUfE"},"source":["# Feedback"]},{"cell_type":"markdown","metadata":{"id":"ajqZJes3SVc-"},"source":["Please fill out this [feedback form](https://forms.gle/ZeCH5anNPrkd5wpp7) when you have finished this lab. We would love to hear your thoughts or feedback on how we can improve this lab!"]}],"metadata":{"accelerator":"GPU","colab":{"collapsed_sections":["2tFjnZZVlIFL"],"provenance":[],"toc_visible":true},"gpuClass":"standard","kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.9.13"}},"nbformat":4,"nbformat_minor":0}