# BETA VERSION - NEEDS FURTHER DEVELOPMENT
# Read these to catch up on what is being attempted here:
# https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
# https://pytorch.org/docs/stable/quantization.html#model-preparation-for-eager-mode-static-quantization
# Torch implementations of these models - mine is heavily based on them, with some minor adjustments:
#
# I've added squeeze-and-excitation layers to MobileNetV2 (a feature of MobileNetV3), but I did not put in
# NAS (unnecessary since we're not optimising for mobile) or hardswish (because I prefer ReLU / think it is better).
# https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py#L117
# https://github.com/pytorch/vision/blob/11bf27e37190b320216c349e39b085fb33aefed1/torchvision/models/mobilenetv3.py#L56
# This is an adapted version of MobileNet, somewhere between versions 2 and 3, as some features of V3 were not required.
# There are also some additions for our particular use case from miscellaneous sources.
from torchvision import transforms
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset
import ClassUtils
from torch.ao.quantization import QuantStub, DeQuantStub
from torchvision.models.mobilenetv2 import _make_divisible
import time
import random
import os
import matplotlib.pyplot as plt
# Squeeze: summarising global context by pooling each feature map into a single value
# Excitation: learning attention weights for each channel to prioritise the most relevant ones
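# Sketch of the data flow (illustrative, assuming input_channels=64 and squeeze_factor=4):
#   (N, 64, H, W) --adaptive_avg_pool2d--> (N, 64, 1, 1)            (squeeze)
#   --1x1 conv--> (N, 16, 1, 1) --ReLU--> --1x1 conv--> (N, 64, 1, 1) (excite)
#   --hardsigmoid--> per-channel weights in [0, 1], multiplied back onto the input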
class SqueezeExcitation(nn.Module):
    def __init__(self, input_channels: int, squeeze_factor: int = 4):
        super().__init__()
        # Keep the squeezed channel count a multiple of 8 so it stays hardware-friendly
        squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
        self.squeeze = nn.Conv2d(input_channels, squeeze_channels, 1)
        self.relu = nn.ReLU(inplace=True)
        self.unsqueeze = nn.Conv2d(squeeze_channels, input_channels, 1)
        self.quant = nn.quantized.FloatFunctional()

    # _scale returns the channel attention map, i.e. how much attention should be paid to each input channel, in range [0, 1]
    # inplace is used to save memory on operations - it might not be necessary in our case since we aren't targeting edge devices
    def _scale(self, input: Tensor, inplace: bool) -> Tensor:
        # Squeeze
        scale = F.adaptive_avg_pool2d(input, 1)
        scale = self.squeeze(scale)
        # Excite
        scale = self.relu(scale)
        scale = self.unsqueeze(scale)
        return F.hardsigmoid(scale, inplace=inplace)

    def forward(self, input: Tensor) -> Tensor:
        # print(self._scale(input, True))
        # print(input)
        return self.quant.mul(self._scale(input, True), input)
# The basic building block of our convolutional neural network
# - qconfig should automatically insert fake-quantisation operations during training, so there is no need to place them manually here
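# (Background note: during quantisation-aware training, the inserted FakeQuantize modules simulate int8
# rounding and clamping in the forward pass while gradients still flow in floating point, so the weights
# learn to tolerate the quantisation error before the model is converted.)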
class ConvBNReLu(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super().__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            # No point applying a bias (constant additive term) if the next layer is a batch normalisation layer
            nn.BatchNorm2d(out_planes, momentum=0.1),
            nn.ReLU(inplace=True)
        )
# Like typical residual blocks, but inverted (narrow -> wide -> narrow) and using depthwise convolutions
# instead of standard ones, which reduces the number of parameters compared to the usual residual blocks.
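# For example, a block configured with inpt=24, oupt=24, stride=1, expnd_ratio=6 expands
# 24 channels -> 144 (1x1 pointwise) -> 144 (3x3 depthwise) -> 24 (1x1 linear pointwise),
# and since stride == 1 and inpt == oupt the input is added back as a residual connection.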
class InvertedResidual(nn.Module):
    def __init__(self, inpt, oupt, stride, expnd_ratio, kernel_size=3, se_layer=None):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]
        intermediate_channels = int(round(inpt * expnd_ratio))
        # If stride != 1, downsampling occurs, so the residual connection cannot be used
        self.use_residual = (stride == 1) and (inpt == oupt)
        # Squeeze-and-excitation layer - applied after the depthwise and pointwise convolutions, but before the residual
        self.se_layer = se_layer if se_layer else None
        layers = []
        if expnd_ratio != 1:
            # Pointwise convolution to increase the number of channels
            layers.append(ConvBNReLu(inpt, intermediate_channels, kernel_size=1))
        layers.extend([
            # Depthwise convolution - each channel is convolved independently
            ConvBNReLu(intermediate_channels, intermediate_channels, kernel_size=kernel_size, stride=stride, groups=intermediate_channels),
            # Pointwise convolution - linear combination to reduce the channels back to the expected number
            nn.Conv2d(intermediate_channels, oupt, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oupt, momentum=0.25)
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        outpt = self.conv(x)
        if self.se_layer is not None:
            outpt = self.se_layer(outpt)
        if self.use_residual:
            return x + outpt
        else:
            return outpt
# Same as the inverted residual, but replaces the addition with a quantization-friendly operation
class QuantizableInvertedResidual(InvertedResidual):
    def __init__(self, inpt, outpt, stride, expnd_ratio, se_layer=None):
        super().__init__(inpt, outpt, stride, expnd_ratio, se_layer=se_layer)
        self.skip_add = nn.quantized.FloatFunctional()

    # Overrides forward to use a quantization-friendly version of the addition
    def forward(self, x):
        outpt = self.conv(x)
        if self.se_layer is not None:
            outpt = self.se_layer(outpt)
        if self.use_residual:
            return self.skip_add.add(x, outpt)
        else:
            return outpt
# The MobileNetV2 architecture plus some features from V3 (squeeze and excitation), but without NAS since we
# aren't running this on mobile, and with ReLU kept in place of hardswish.
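# Overall flow (as built below): a 3x3 stride-2 stem conv, then 17 inverted-residual blocks as listed in
# residual_params, then a 1x1 conv up to last_channel, global average pooling, dropout, and a linear classifier.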
class MobileNetV2_5(nn.Module):
    def __init__(self, class_num=2, width_mult=1.0, round_nearest=8):
        super().__init__()
        layers = []
        input_channel = 32
        last_channel = 1280
        # Just straight up copying this from the torchvision implementation
        self.residual_params = [
            # expnd_ratio, outpt_channels, num_blocks, stride
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
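        # Each row expands into num_blocks blocks: e.g. [6, 24, 2, 2] becomes two inverted-residual
        # blocks with expansion 6 and 24 output channels, the first with stride 2 (downsampling) and
        # the second with stride 1.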
        first_conv_output_channels = _make_divisible(self.residual_params[0][1] * width_mult, round_nearest)
        layers.append(
            ConvBNReLu(3, first_conv_output_channels, kernel_size=3, stride=2)
        )
        prev_input_channels = first_conv_output_channels
        # Main body of feature extraction
        for expnd, oupt_c, num_blocks, strd in self.residual_params:
            # Output channels must be a multiple of 8 for hardware optimisation
            output_channel = _make_divisible(oupt_c * width_mult, round_nearest)
            for i in range(num_blocks):
                stride = strd if i == 0 else 1
                # The SE layer sees the block's actual output, so it is sized with output_channel (not oupt_c)
                se_layer = SqueezeExcitation(output_channel) if i == 0 else None
                layers.append(QuantizableInvertedResidual(prev_input_channels, output_channel, stride, expnd_ratio=expnd, se_layer=se_layer))
                prev_input_channels = output_channel
        self.last_channel = _make_divisible(last_channel * max(width_mult, 1.0), round_nearest)
        # We could put this in the classifier, but I want that part to stay lightweight so that transfer learning
        # can be done on the head separately from the feature-extraction backbone.
        layers.append(
            ConvBNReLu(prev_input_channels, self.last_channel, kernel_size=1)
        )
        self.feature_extraction = nn.Sequential(*layers)
        self.avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(0.125),
            # Use self.last_channel so the width matches the final ConvBNReLu output
            nn.Linear(self.last_channel, class_num)
        )
        # This bit is also just straight up copied from torch's implementation - I'm not touching it in case it gets messed up
        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)
    def forward(self, x: Tensor) -> Tensor:
        x = self.feature_extraction(x)
        x = self.avg_pooling(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
class QuantizableMobileNetV2_5(MobileNetV2_5):
    def __init__(self, class_num=2, width_mult=1.0, round_nearest=8):
        super().__init__(class_num=class_num, width_mult=width_mult, round_nearest=round_nearest)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.feature_extraction(x)
        # This was for debugging errors in the shape of feature maps as they pass through - not deleting in case it's useful later
        # for idx, layer in enumerate(self.feature_extraction):
        #     x = layer(x)
        #     print(f"Feature extraction layer {idx}, output shape: {x.shape}")
        x = self.avg_pooling(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x
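# Quick sanity check (illustrative, not part of the training script below): with the default settings the
# unprepared model maps a (N, 3, H, W) float batch to (N, 2) logits, e.g.
#   QuantizableMobileNetV2_5()(torch.randn(1, 3, 224, 224)).shape  # -> torch.Size([1, 2])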
def train_single_epoch(model, loss_fnc, optimiser, data_loader, device):
    model.train()
    running_loss = 0
    running_time = 0.0
    for images, labels in data_loader:
        start_time = time.time()
        print(".", end=" ")
        images, labels = images.to(device), labels.to(device)
        # Clear gradients accumulated from the previous batch before backpropagating
        optimiser.zero_grad()
        preds = model(images)
        loss = loss_fnc(preds, labels)
        loss.backward()
        optimiser.step()
        running_loss += loss.item()
        running_time += time.time() - start_time
        print(f"{(time.time() - start_time):.2f}, {(running_time):.2f}", end=" ")
    print(f"loss of {running_loss}")
    return
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p") / 1e6)
    os.remove("temp.p")
def adjust_quantisation_engine():
    # Adjust according to what your device supports: 'fbgemm' targets x86 CPUs, while 'qnnpack' targets ARM
    # (qnnpack also runs on x86, which makes it convenient for testing)
    print(torch.backends.quantized.supported_engines)
    torch.backends.quantized.engine = 'qnnpack'
def train_model(model, dataloader, loss_function, optimiser, epoch_number=25, const_save=False, save=True):
    for epoch in range(epoch_number):
        print("IT IS EPOCH", epoch)
        train_single_epoch(model, loss_function, optimiser, dataloader, torch.device('cpu'))
        # After a few epochs, gradually freeze the observer parameters used for quantisation and the batch-norm statistics
        if epoch > 3:
            # Freeze quantizer (observer) parameters
            model.apply(torch.ao.quantization.disable_observer)
        if epoch > 2:
            # Freeze batch norm mean and variance estimates
            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
        if const_save:
            quantized_model = torch.ao.quantization.convert(model.eval(), inplace=False)
            quantized_model.eval()
            # Saving each intermediary model since they're so small, and this lets us load any of them later to compare performance
            torch.save(quantized_model.state_dict(), "quantStateDict" + str(epoch + 1) + ".pth")
            print(f"The above was epoch {epoch + 1} of {epoch_number} \nThe model has a size of", end=" ")
            print_size_of_model(quantized_model)
        else:
            print(f"The above was epoch {epoch + 1} of {epoch_number}")
    if save:
        # Convert once more at the end so the final save also works when const_save is False
        quantized_model = torch.ao.quantization.convert(model.eval(), inplace=False)
        torch.save(quantized_model.state_dict(), "full_quantStateDict.pth")
    return model
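# --- Training / evaluation script ---
# The quantisation-aware-training recipe below follows the eager-mode flow from the PyTorch tutorials linked
# at the top of this file: pick a quantised engine, attach a QAT qconfig to the model, call prepare_qat to
# insert fake-quantise/observer modules, train as usual, then convert to the actual int8 model for inference.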
learning_rate = 1e-3
batch_size = 64
data_size = 2560
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Update to whatever you call your model
modelName = "quantStateDict8.pth"
load = False

model = QuantizableMobileNetV2_5()
# Adjust according to what your device supports
torch.backends.quantized.engine = 'qnnpack'
# QAT needs a qconfig with fake-quantise modules; the plain default_qconfig is intended for post-training static quantisation
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
optimiser = torch.optim.SGD(model.parameters(), lr=learning_rate)
torch.ao.quantization.prepare_qat(model, inplace=True)
dataset = ClassUtils.CrosswalkDataset("zebra_annotations/classification_data")
train_loader = DataLoader(
    Subset(dataset, random.sample(list(range(0, int(len(dataset) * 0.95))), data_size)),
    batch_size=batch_size, shuffle=True)
test_loader = DataLoader(
    Subset(dataset, random.sample(list(range(int(len(dataset) * 0.95), len(dataset))), 256)),
    batch_size=batch_size, shuffle=False)
loss_function = nn.BCEWithLogitsLoss()
model_updated = train_model(model, train_loader, loss_function, optimiser, epoch_number=8, const_save=True)
quantized_model = torch.ao.quantization.convert(model_updated.eval(), inplace=True)
if load:
    model_loaded_state_dict = torch.load(modelName)
    quantized_model.load_state_dict(model_loaded_state_dict)
for images, labels in test_loader:
    preds = torch.sigmoid(quantized_model(images))
    for i in range(len(preds)):
        print(preds[i])
        # plt.imshow(torch.permute(images[i], (1, 2, 0)).detach().numpy())
        # plt.title(f"Prediction: {preds[i]}, Actual: {labels[i][0] == 1}")
        # plt.axis("off")
        # plt.show()