# Michael Peres ~ 09/01/2024
# BERT-Based Transformer Model for Image Classification
# ----------------------------------------------------------------------------------------------------------------------
# Import Modules
# pip install transformers torchvision
from transformers import BertModel, BertConfig
from transformers import get_linear_schedule_with_warmup
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
from tqdm.notebook import tqdm, trange
from torch.optim import AdamW
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math, torch
import torch.nn as nn
# ----------------------------------------------------------------------------------------------------------------------
# This is a simple implementation in which the first token of the final hidden state,
# i.e. the encoded [CLS] token, is used as the input to an MLP head for classification.
# The model is trained on the CIFAR-10 dataset: 60,000 32x32 colour images in 10 classes,
# with 6,000 images per class (50,000 for training, 10,000 for testing).
# This model only uses the encoder part of the BERT architecture, plus the classification head.
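# At a glance, the shape flow through the pipeline defined below (batch size B):
#   image (B, 3, 32, 32) --Conv2d(3, 256, k=4, s=4)--> (B, 256, 8, 8)
#   flatten + permute --------------------------------> (B, 64, 256)
#   prepend learned [CLS] token ----------------------> (B, 65, 256)
#   BERT encoder (12 layers) -------------------------> (B, 65, 256)
#   take token 0 through the MLP head ----------------> (B, 10) class logits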
# ----------------------------------------------------------------------------------------------------------------------
# Some understanding of the BERT model is required to follow this code; here are the dimensions and documentation.
# From the documentation, https://huggingface.co/transformers/v3.0.2/model_doc/bert.html
# BERT parameters used here include:
# - hidden_size: 256
# - intermediate_size: 1024
# - num_hidden_layers: 12
# - num_attention_heads: 8
# - max_position_embeddings: 256
# - vocab_size: 100
# - bos_token_id: 101
# - eos_token_id: 102
# - cls_token_id: 103
# But what do all of these mean for our setup?
# hidden_size is the dimensionality D of the token embeddings flowing through the encoder.
# intermediate_size is the width of the hidden layer of each feed-forward block;
# the feed-forward maps hidden_size D -> intermediate_size -> hidden_size D.
# num_hidden_layers is the number of hidden layers in the transformer encoder;
# layers refer to transformer blocks, so more layers means a deeper model.
# num_attention_heads is the number of parallel heads inside the multi-head attention module of each layer.
# max_position_embeddings is the maximum input sequence length the model can handle;
# it should be larger for models that take longer inputs (here 65 tokens <= 256, so we are fine).
# vocab_size is the number of distinct tokens the embedding table supports; 100 looks confusing
# given pixel intensities of 0-255, but it is irrelevant here: we feed continuous patch embeddings
# via inputs_embeds, so the token embedding table is bypassed entirely.
# bos_token_id is the beginning-of-sentence token id, useful for marking sentence boundaries in text generation tasks.
# eos_token_id is the end-of-sentence token id (not listed in the BertConfig documentation, but accepted as a kwarg).
# cls_token_id is the id of the classification token that is prepended to every input instance.
# output_hidden_states=True would make the model return the hidden states of every layer for inspection.
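# A quick sanity check of these numbers (a minimal sketch; the patch size of 4 matches
# the patch_embed Conv2d defined below):
# patches_per_side = 32 // 4               # 8
# num_patches = patches_per_side ** 2      # 64 patch tokens per image
# seq_len = num_patches + 1                # +1 for [CLS] -> 65 <= max_position_embeddings (256)
# # And the feed-forward block described above, illustratively (BertModel builds this internally):
# feed_forward = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256))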
# ----------------------------------------------------------------------------------------------------------------------
# Preparing the CIFAR-10 Image Dataset, and DataLoaders for Training and Testing
dataset = CIFAR10(root='./data/', train=True, download=True, transform=transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]))
# Augmentations are very important when training vision models from scratch; without them
# the model overfits very fast without achieving good generalization accuracy.
val_dataset = CIFAR10(root='./data/', train=False, download=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]))
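# To eyeball the augmented training data (a minimal sketch using make_grid; the colours look
# washed out because the tensors are already normalized, and clamp just keeps imshow happy):
# sample = torch.stack([dataset[i][0] for i in range(16)])
# plt.imshow(make_grid(sample, nrow=4).permute(1, 2, 0).clamp(0, 1))
# plt.show()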
# Model Configuration and Hyperparameters
config = BertConfig(hidden_size=256, intermediate_size=1024, num_hidden_layers=12,
                    num_attention_heads=8, max_position_embeddings=256, vocab_size=100,
                    bos_token_id=101, eos_token_id=102, cls_token_id=103,
                    output_hidden_states=False)
model = BertModel(config).cuda()
# 4x4 patches with stride 4: a 32x32 image becomes an 8x8 grid of 64 patch embeddings.
patch_embed = nn.Conv2d(3, config.hidden_size, kernel_size=4, stride=4).cuda()
# Learned [CLS] token, scaled at initialization by 1/sqrt(D).
CLS_token = nn.Parameter(torch.randn(1, 1, config.hidden_size, device="cuda") / math.sqrt(config.hidden_size))
readout = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size),
                        nn.GELU(),
                        nn.Linear(config.hidden_size, 10)
                        ).cuda()
# (Everything above is constructed directly on the GPU, so no extra .cuda() pass is needed.)
# All trainable components (encoder, patch embedding, readout head and the CLS token)
# share one optimizer so they are updated together.
optimizer = AdamW([*model.parameters(),
                   *patch_embed.parameters(),
                   *readout.parameters(),
                   CLS_token], lr=5e-4)
# DataLoaders
batch_size = 192  # 96
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
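# get_linear_schedule_with_warmup is imported above but never wired in; a minimal sketch of how
# it could be attached (the warmup length is an assumption, not something the original run used):
# steps_per_epoch = len(train_loader)
# scheduler = get_linear_schedule_with_warmup(optimizer,
#                                             num_warmup_steps=steps_per_epoch,         # ~1 epoch of warmup (assumption)
#                                             num_training_steps=steps_per_epoch * 30)  # EPOCHS * steps per epoch
# # scheduler.step() would then be called after optimizer.step() in the training loop below.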
# ----------------------------------------------------------------------------------------------------------------------
# Understanding the CLS Token:
# print("CLASS TOKEN shape:")
# print(CLS_token.shape)  # (1, 1, 256)
#
# reshaped_cls = CLS_token.expand(192, 1, -1)  # 192 is the batch size here
# print("CLS Reshaped shape", reshaped_cls.shape)  # (192, 1, 256)
# # expand broadcasts the single learned CLS vector across the batch dimension,
# # so every sample in the batch shares the same CLS token.
#
# imgs, labels = next(iter(train_loader))
# patch_embs = patch_embed(imgs.cuda()).flatten(2).permute(0, 2, 1)
#
# input_embs = torch.cat([reshaped_cls, patch_embs], dim=1)
# print("Patch Embeddings Shape", patch_embs.shape)  # (192, 64, 256)
#
# print("Input Embedding Shape", input_embs.shape)  # (192, 65, 256)
# ----------------------------------------------------------------------------------------------------------------------
# Understanding the Output of the Model Transformer:
# hidden_states (when output_hidden_states=True): a tuple of num_hidden_layers + 1 = 13 tensors,
#     one per layer plus the input embeddings, each of shape (192, 65, 256)
# last_hidden_state dimension: (192, 65, 256)
# pooler_output: (192, 256)
# In essence, the pooler takes the last hidden state of the first ([CLS]) token and passes it
# through a Linear layer with tanh activation, giving one vector per sample;
# the per-token information of the remaining 64 tokens is discarded.
#
# # We should understand the output of the model:
# representations = output.last_hidden_state[:, 0, :]
# print(output.last_hidden_state.shape)  # (192, 65, 256)
# print(representations.shape)  # (192, 256)
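# Continuing the commented example, a quick way to verify what pooler_output is
# (a sketch, assuming the standard BertPooler module with .dense and .activation):
# cls_hidden = output.last_hidden_state[:, 0, :]
# manual_pooled = model.pooler.activation(model.pooler.dense(cls_hidden))
# print(torch.allclose(manual_pooled, output.pooler_output))  # expected: True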
# ----------------------------------------------------------------------------------------------------------------------
# Training Loop
EPOCHS = 30
loss_list = []
acc_list = []
for epoch in trange(EPOCHS, leave=False):
    # Reset to train mode and zero the epoch statistics at the start of every epoch
    # (eval mode is switched on below for validation, and the counters are reused there).
    model.train()
    correct_cnt = 0
    total_loss = 0
    pbar = tqdm(train_loader, leave=False)
    for i, (imgs, labels) in enumerate(pbar):
        patch_embs = patch_embed(imgs.cuda())  # patch embeddings, (192, 256, 8, 8): batch, hidden, H, W
        patch_embs = patch_embs.flatten(2).permute(0, 2, 1)  # (batch_size, HW=64, hidden=256)
        input_embs = torch.cat([CLS_token.expand(imgs.shape[0], 1, -1), patch_embs], dim=1)  # (batch_size, 65, 256)
        output = model(inputs_embeds=input_embs)
        # output.last_hidden_state has shape (192, 65, 256); token 0 is the [CLS] token.
        logit = readout(output.last_hidden_state[:, 0, :])
        loss = F.cross_entropy(logit, labels.cuda())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        pbar.set_description(f"loss: {loss.item():.4f}")
        total_loss += loss.item() * imgs.shape[0]
        correct_cnt += (logit.argmax(dim=1) == labels.cuda()).sum().item()
    loss_list.append(round(total_loss / len(dataset), 4))
    acc_list.append(round(correct_cnt / len(dataset), 4))
    # Test on the validation set (no gradients needed, so wrap it in torch.no_grad()).
    model.eval()
    correct_cnt = 0
    total_loss = 0
    with torch.no_grad():
        for i, (imgs, labels) in enumerate(val_loader):
            patch_embs = patch_embed(imgs.cuda())
            patch_embs = patch_embs.flatten(2).permute(0, 2, 1)  # (batch_size, HW, hidden)
            input_embs = torch.cat([CLS_token.expand(imgs.shape[0], 1, -1), patch_embs], dim=1)
            output = model(inputs_embeds=input_embs)
            logit = readout(output.last_hidden_state[:, 0, :])
            loss = F.cross_entropy(logit, labels.cuda())
            total_loss += loss.item() * imgs.shape[0]
            correct_cnt += (logit.argmax(dim=1) == labels.cuda()).sum().item()
    print(f"val loss: {total_loss / len(val_dataset):.4f}, val acc: {correct_cnt / len(val_dataset):.4f}")
# Plotting Loss and Accuracy
plt.figure()
plt.plot(loss_list, label="loss")
plt.plot(acc_list, label="accuracy")
plt.legend()
plt.show()
# ----------------------------------------------------------------------------------------------------------------------
# Saving Model Parameters
torch.save(model.state_dict(), "bert.pth")
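# Note: only the encoder weights are saved above; patch_embed, CLS_token and readout are also
# needed to reproduce predictions. A fuller checkpoint might look like this
# (a sketch, with a hypothetical filename):
# torch.save({"model": model.state_dict(),
#             "patch_embed": patch_embed.state_dict(),
#             "readout": readout.state_dict(),
#             "CLS_token": CLS_token}, "bert_cifar10_full.pth")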
# ----------------------------------------------------------------------------------------------------------------------
# Reference: Tutorial for Harvard Medical School ML from Scratch Series: Transformer from Scratch
# ----------------------------------------------------------------------------------------------------------------------