# BETA VERSION - NEEDS FURTHER DEVELOPMENT
# Read these to catch up on what this file is doing (or at least trying to do):
# https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
# https://pytorch.org/docs/stable/quantization.html#model-preparation-for-eager-mode-static-quantization
# Torch implementations of these models - mine is heavily based on them, with some minor adjustments:
#
# I've added squeeze-and-excitation layers to the MobileNetV2, a feature of MobileNetV3, but I did not put in
# NAS (unnecessary since we're not optimising for mobile) or hardswish (because I prefer ReLU and think it works better here)
# https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py#L117
# https://github.com/pytorch/vision/blob/11bf27e37190b320216c349e39b085fb33aefed1/torchvision/models/mobilenetv3.py#L56
# This is an adapted version of MobileNet, somewhere between versions 2 and 3, as some features of 3 were not required. There are
# also some additions for our particular use case from miscellaneous sources.
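#
# Rough shape of the eager-mode QAT flow implemented at the bottom of this file (a sketch, mirroring the tutorials above):
#   model = QuantizableMobileNetV2_5()
#   model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
#   torch.ao.quantization.prepare_qat(model, inplace=True)                   # insert observer / fake-quant modules
#   train_model(model, train_loader, loss_function, optimiser)               # train with quantisation effects simulated
#   quantized = torch.ao.quantization.convert(model.eval(), inplace=False)   # produce the actual int8 model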
from torchvision import transforms
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset
import ClassUtils
from torch.ao.quantization import QuantStub, DeQuantStub
from torchvision.models.mobilenetv2 import _make_divisible
import time
import random
import os
import matplotlib.pyplot as plt
# Squeeze: summarising global context by pooling feature maps into a single value
# Excitation: Learning attention weights for each channel to prioritise the most relevant ones
class SqueezeExcitation(nn.Module):
    def __init__(self, input_channels: int, squeeze_factor: int = 4):
super().__init__()
# If channels are a multiple of 8, they're optimised by the hardware
squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
self.squeeze = nn.Conv2d(input_channels, squeeze_channels, 1)
self.relu = nn.ReLU(inplace=True)
self.unsqueeze = nn.Conv2d(squeeze_channels, input_channels, 1)
self.quant = nn.quantized.FloatFunctional()
    # Scale returns the feature attention map - how much attention should be paid to each input channel, in range [0, 1]
    # Inplace is used to save memory on operations - it might not be necessary in our case since we aren't using edge devices
    def _scale(self, input: Tensor, inplace: bool = True) -> Tensor:
# Squeeze
scale = F.adaptive_avg_pool2d(input, 1)
scale = self.squeeze(scale)
# Excite
scale = self.relu(scale)
scale = self.unsqueeze(scale)
return F.hardsigmoid(scale, inplace=inplace)
def forward(self, input: Tensor) -> Tensor:
# print(self._scale(input, True))
# print(input)
return self.quant.mul(self._scale(input, True), input)
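# Minimal usage sketch (shapes are illustrative only): the block preserves the input shape and scales each channel
# by its learned attention weight, e.g.
#   se = SqueezeExcitation(64)
#   y = se(torch.randn(1, 64, 32, 32))   # y.shape == (1, 64, 32, 32)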
# The basic building block of our convolutional neural network
# - qconfig should automatically insert fakeQuantisation operations during training, so there is no need to manually place them now
class ConvBNReLu(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super().__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            # No point applying a bias (constant additive term) if the next layer is a batch normalisation layer
nn.BatchNorm2d(out_planes, momentum=0.1),
nn.ReLU(inplace=True)
)
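# Worked example of the padding: with kernel_size=3 the padding is 1, so a 112x112 map stays 112x112 at stride 1
# and is halved to 56x56 at stride 2.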
# Like typical residual blocks, but inverted (narrow->wide->narrow instead of wide->narrow->wide) and with depthwise
# convolutions in the middle, to reduce the number of parameters required compared to the usual residual blocks
class InvertedResidual(nn.Module):
def __init__(self, inpt, oupt, stride, expnd_ratio, kernel_size=3, se_layer=None):
super().__init__()
self.stride = stride
assert stride in [1, 2]
intermediate_channels = int(round(inpt * expnd_ratio))
        # If stride != 1, downsampling changes the spatial size, so a residual connection cannot be used.
self.use_residual = (stride==1) and (inpt==oupt)
# Squeeze and excitation layer - applied after the dw and pw convolutions, but before the residual
        self.se_layer = se_layer
layers = []
if expnd_ratio != 1:
# Pointwise convolution to increase the channels
layers.append(ConvBNReLu(inpt, intermediate_channels, kernel_size=1))
layers.extend([
            # Depthwise convolution - each channel is convolved independently (groups == channels)
            ConvBNReLu(intermediate_channels, intermediate_channels, kernel_size=kernel_size, stride=stride, groups=intermediate_channels),
# point-wise convolution - linear combination to reduce layers back to the expected number
nn.Conv2d(intermediate_channels, oupt, 1, 1, 0, bias=False),
nn.BatchNorm2d(oupt, momentum=0.25)
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
outpt = self.conv(x)
if self.se_layer is not None:
outpt = self.se_layer(outpt)
if self.use_residual:
return x + outpt
else:
return outpt
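# Worked example (hypothetical numbers): with inpt=24, oupt=32, stride=2, expnd_ratio=6 the block runs
#   24 -> 144 channels (1x1 expand) -> 144 (3x3 depthwise, stride 2) -> 32 (1x1 project + BN),
# and no residual connection is added because stride != 1 and inpt != oupt.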
# Same as the inverted residual, but replaces the addition with a quantization-friendly operation
class QuantizableInvertedResidual(InvertedResidual):
def __init__(self, inpt, outpt, stride, expnd_ratio, se_layer=None):
super().__init__(inpt, outpt, stride, expnd_ratio, se_layer=se_layer)
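        # FloatFunctional lets the convert step attach observers and swap in a quantised add;
        # a plain "+" between tensors would not be picked up by the quantisation passes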
self.skip_add = nn.quantized.FloatFunctional()
    # Overrides forward to use a quantization-friendly version of the addition
def forward(self, x):
outpt = self.conv(x)
if self.se_layer is not None:
outpt = self.se_layer(outpt)
if self.use_residual:
return self.skip_add.add(x, outpt)
else:
return outpt
# The MobileNetV2 Architecture + some features from V3 (squeeze and excitation) but I didn't add NAS since we aren't running this on mobile
# And I prefer ReLU over hardswish
class MobileNetV2_5(nn.Module):
def __init__(self, class_num=2, width_mult=1.0, round_nearest=8):
super().__init__()
layers = []
input_channel = 32
last_channel = 1280
# Just straight up copying this from the torchvision implementation
self.residual_params = [
            # expnd_ratio, outpt_channels, num_blocks, stride - see the worked example after this table
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
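        # Example row: [6, 24, 2, 2] = two blocks producing 24 output channels with expansion factor 6,
        # where only the first block uses stride 2 (the loop below resets stride to 1 for the rest)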
        first_conv_output_channels = _make_divisible(self.residual_params[0][1] * width_mult, round_nearest)
layers.append(
ConvBNReLu(3,
first_conv_output_channels,
kernel_size=3,
stride=2,
)
)
prev_input_channels = first_conv_output_channels
# Main body of feature extraction
for expnd, oupt_c, num_blocks, strd in self.residual_params:
# output channels must be a multiple of 8 for hardware optimisation
output_channel = _make_divisible(oupt_c * width_mult, round_nearest)
for i in range(num_blocks):
stride = strd if i == 0 else 1
                # The SE layer must match the block's output channel count (after width_mult rounding)
                se_layer = SqueezeExcitation(output_channel) if i == 0 else None
layers.append(QuantizableInvertedResidual(prev_input_channels, output_channel, stride, expnd_ratio=expnd, se_layer=se_layer))
prev_input_channels = output_channel
self.last_channel = _make_divisible(last_channel * max(width_mult, 1.0), round_nearest)
        # We could put this in the classifier, but I want that to stay lightweight so that transfer learning could later be
        # done on the head alone, with the feature-extraction part of the model frozen.
layers.append(
ConvBNReLu(prev_input_channels, self.last_channel, kernel_size=1)
)
self.feature_extraction = nn.Sequential(*layers)
self.avg_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Sequential(
nn.Dropout(0.125),
            nn.Linear(self.last_channel, class_num)
)
# This bit is also just straight up copied from torch's implementation - I'm not touching it in case it gets messed up
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def forward(self, x: Tensor) -> Tensor:
x = self.feature_extraction(x)
x = self.avg_pooling(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
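# Rough usage sketch for the float (non-quantised) model, assuming 224x224 RGB inputs:
#   net = MobileNetV2_5(class_num=2)
#   logits = net(torch.randn(4, 3, 224, 224))   # expected shape: (4, 2)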
class QuantizableMobileNetV2_5(MobileNetV2_5):
def __init__(self, class_num=2, width_mult=1.0, round_nearest=8):
super().__init__(class_num=class_num, width_mult=width_mult, round_nearest=round_nearest)
self.quant = QuantStub()
self.dequant = DeQuantStub()
def _forward_impl(self, x: Tensor) -> Tensor:
x = self.feature_extraction(x)
        # This was for debugging errors in the shape of the feature maps as they pass through - not deleting in case it's useful later
# for idx, layer in enumerate(self.feature_extraction):
# x = layer(x)
# print(f"Feature extraction layer {idx}, output shape: {x.shape}")
x = self.avg_pooling(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
def forward(self, x: Tensor) -> Tensor:
x = self.quant(x)
x = self._forward_impl(x)
        x = self.dequant(x)
return x
def train_single_epoch(model, loss_fnc, optimiser, data_loader, device):
model.train()
running_loss = 0
running_time = 0.0
for images, labels in data_loader:
start_time = time.time()
print(".", end=" ")
        images, labels = images.to(device), labels.to(device)
        # Clear gradients from the previous batch before backpropagating this one
        optimiser.zero_grad()
        preds = model(images)
        loss = loss_fnc(preds, labels)
        loss.backward()
        optimiser.step()
running_loss += loss.item()
running_time += time.time() - start_time
print(f"{(time.time() - start_time):.2f}, {(running_time):.2f}", end=" ")
print(f"loss of {running_loss}")
return
def print_size_of_model(model):
torch.save(model.state_dict(), "temp.p")
print('Size (MB):', os.path.getsize("temp.p")/1e6)
os.remove('temp.p')
def adjust_quantisation_engine():
# Adjust according to what your device supports
print(torch.backends.quantized.supported_engines)
torch.backends.quantized.engine = 'qnnpack'
def train_model(model, dataloader, loss_function, optimiser, epoch_number=25, const_save=False, save=True):
for epoch in range(epoch_number):
print("IT IS EPOCH", epoch)
train_single_epoch(model, loss_function, optimiser, dataloader, torch.device('cpu'))
        # After a few epochs, freeze the quantisation observers and the batch norm statistics so the quantisation parameters can stabilise
if epoch > 3:
# Freeze quantizer parameters
model.apply(torch.ao.quantization.disable_observer)
if epoch > 2:
# Freeze batch norm mean and variance estimates
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
if const_save:
quantized_model = torch.ao.quantization.convert(model.eval(), inplace=False)
quantized_model.eval()
            # Saving each intermediary model since they're so small, and this lets us load any of them later to demonstrate performance differences
torch.save(quantized_model.state_dict(), "quantStateDict"+str(epoch+1)+".pth")
print(f"the above was Epoch {epoch+1} of {epoch_number} \nThe model has a size of", end=" ")
print_size_of_model(quantized_model)
else:
            print(f"the above was Epoch {epoch+1} of {epoch_number}")
    if save:
        # Convert here so a final quantised model is saved even when const_save is False
        quantized_model = torch.ao.quantization.convert(model.eval(), inplace=False)
        torch.save(quantized_model.state_dict(), "full_quantStateDict.pth")
return model
learning_rate = 1e-3
batch_size = 64
data_size = 2560
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Update to whatever you call your model
modelName = "quantStateDict8.pth"
load = False
model = QuantizableMobileNetV2_5()
# Adjust according to what your device supports
torch.backends.quantized.engine = 'qnnpack'
# QAT needs a QAT-specific qconfig (with fake-quant modules), rather than the post-training default_qconfig
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
optimiser = torch.optim.SGD(model.parameters(), lr=learning_rate)
torch.ao.quantization.prepare_qat(model, inplace=True)
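# At this point every supported layer has been swapped (in place) for its fake-quantised QAT counterpart,
# so the training below optimises float weights while simulating int8 rounding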
dataset = ClassUtils.CrosswalkDataset("zebra_annotations/classification_data")
train_loader = DataLoader(
Subset(dataset, random.sample(list(range(0, int(len(dataset) * 0.95))), data_size)),
batch_size=batch_size, shuffle=True)
test_loader = DataLoader(
Subset(dataset, random.sample(list(range(int(len(dataset) * 0.95), len(dataset))), 256)),
batch_size=batch_size, shuffle=False)
loss_function = nn.BCEWithLogitsLoss()
model_updated = train_model(model, train_loader, loss_function, optimiser, epoch_number=8, const_save=True)
quantized_model = torch.ao.quantization.convert(model_updated.eval(), inplace=True)
if load:
model_loaded_state_dict = torch.load(modelName)
quantized_model.load_state_dict(model_loaded_state_dict)
for images, labels in test_loader:
preds = torch.sigmoid(quantized_model(images))
for i in range(len(preds)):
        print(preds[i])
# plt.imshow(torch.permute(images[i], (1, 2, 0)).detach().numpy())
# plt.title(f"Prediction: {preds[i]}, Actual: {labels[i][0] == 1}")
# plt.axis("off")
# plt.show()
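# A quick accuracy check could replace the per-sample prints above - a sketch, assuming the one-hot style
# labels implied by the plotting code (labels[i][0] == 1 marking the positive class):
# correct, total = 0, 0
# for images, labels in test_loader:
#     preds = torch.sigmoid(quantized_model(images))
#     correct += ((preds[:, 0] > 0.5) == (labels[:, 0] == 1)).sum().item()
#     total += len(labels)
# print(f"Test accuracy: {correct / total:.3f}")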