File size: 9,302 Bytes
5357119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# Michael Peres ~ 09/01/2024
# Bert Based Transformer Model for Image Classification
# ----------------------------------------------------------------------------------------------------------------------
# Import Modules
# pip install transformers torchvision
from transformers import BertModel, BertTokenizer, BertConfig
from transformers import get_linear_schedule_with_warmup
from transformers import BertForSequenceClassification
from torchvision.utils import make_grid, save_image
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST, CIFAR10
from torchvision import datasets, transforms
from tqdm.notebook import tqdm, trange
from torch.optim import AdamW, Adam
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math, os, torch
import torch.nn as nn

# ----------------------------------------------------------------------------------------------------------------------
# This is a simple implementation where the first hidden state,
# i.e. the encoded class (CLS) token, is used as the input to an MLP head for classification.
# The model is trained on CIFAR-10 dataset, which is a dataset of 60,000 32x32 color images in 10 classes,
# with 6,000 images per class.
# This model will only contain the encoder part of the BERT model, and the classification head.


# ----------------------------------------------------------------------------------------------------------------------
# Some understanding of the BERT model is required to understand this code, here are the dimensions and documentation.
# From documentation, https://huggingface.co/transformers/v3.0.2/model_doc/bert.html

# BERT Parameters include:

# - hidden size: 256
# - intermediate size: 1024
# - number of hidden_layers:  12
# - num of attention heads: 8
# - max position embeddings: 256
# - vocab size: 100
# - bos_token_id: 101
# - eos_token_id: 102
# - cls_token_id: 103

# But what do all of these mean in terms of the question.

# Hidden size, this represents the dimensionality of the input embeddings D.

# Intermediate size is the number of neurons in the hidden layer of the feedforward,
# the feed forward would have dims, Hidden Size D -> Intermediate Size -> Hidden Size D

# Num of hidden layers, means the number of hidden layers in the transformer encoder,
# layers refer to transformer blocks, so more transformer blocks in the model.

# Num of attention heads refers to the number of attention heads in the multi-head attention module of each hidden layer.

# Max position embeddings is the maximum input sequence length (number of tokens) the model can handle; it must be larger for models that handle longer inputs.

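# vocab size refers to the set of tokens the model is trained on, which has a specific length.
# In our case it is 100, which looks confusing because pixel intensities range over 0-255; however, the model
# is fed patch embeddings directly (inputs_embeds), so the token vocabulary is never used for lookups.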

# bos token id is the id of the beginning-of-sequence token, useful for marking sequence boundaries in text generation tasks.

# eos token id is the id of the end-of-sequence token, which I don't see in the documentation for BertConfig.

# cls token id is the id of the token inserted at the beginning of each input instance.

# output_hidden_states = True, means to output all the hidden states for us to view.


# ----------------------------------------------------------------------------------------------------------------------

# Preparing CIFAR10 Image Dataset, and DataLoaders for Training and Testing
# Per-channel normalization statistics shared by the training and validation pipelines.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random horizontal flips and padded random crops augment the data.
# Augmentation is important here; without it the model overfits quickly and generalizes poorly.
train_transform = transforms.Compose([
	transforms.RandomHorizontalFlip(),
	transforms.RandomCrop(32, padding=4),
	transforms.ToTensor(),
	transforms.Normalize(cifar_mean, cifar_std),
])
dataset = CIFAR10(root='./data/', train=True, download=True, transform=train_transform)

# Validation pipeline: no augmentation, only tensor conversion and the same normalization.
val_transform = transforms.Compose([
	transforms.ToTensor(),
	transforms.Normalize(cifar_mean, cifar_std),
])
val_dataset = CIFAR10(root='./data/', train=False, download=True, transform=val_transform)

# Model Configuration and Hyperparameters
# NOTE(review): bos/eos/cls token ids exceed vocab_size; they appear unused here because the
# encoder is only ever fed pre-computed embeddings (inputs_embeds) — confirm before relying on them.
config = BertConfig(hidden_size=256, intermediate_size=1024, num_hidden_layers=12, num_attention_heads=8,
                    max_position_embeddings=256, vocab_size=100, bos_token_id=101, eos_token_id=102,
                    cls_token_id=103, output_hidden_states=False)

# BERT encoder stack; max_position_embeddings (256) comfortably covers 1 CLS + 64 patch tokens.
model = BertModel(config).cuda()
# Non-overlapping 4x4 patches: a 32x32 RGB image becomes an 8x8 grid of hidden_size-dim embeddings.
patch_embed = nn.Conv2d(3, config.hidden_size, kernel_size=4, stride=4).cuda()
# Learnable classification token. It is created directly on the GPU so it remains a leaf
# Parameter (calling .cuda() on a CPU Parameter would return a copy, not move it in place).
CLS_token = nn.Parameter(torch.randn(1, 1, config.hidden_size, device="cuda") / math.sqrt(config.hidden_size))
# MLP head mapping the final CLS representation to the 10 CIFAR-10 class logits.
readout = nn.Sequential(
	nn.Linear(config.hidden_size, config.hidden_size),
	nn.GELU(),
	nn.Linear(config.hidden_size, 10),
).cuda()

# All trainable components are already on the GPU; optimize them jointly.
optimizer = AdamW([*model.parameters(),
                   *patch_embed.parameters(),
                   *readout.parameters(),
                   CLS_token], lr=5e-4)

# DataLoaders
# Mini-batch size shared by both loaders (96 was an earlier alternative).
batch_size = 192
# Training batches come in a fresh random order each epoch; validation order stays fixed.
train_loader, val_loader = (
	DataLoader(split, batch_size=batch_size, shuffle=is_train)
	for split, is_train in ((dataset, True), (val_dataset, False))
)

# ----------------------------------------------------------------------------------------------------------------------
# Understanding ClS Token:
# print("CLASS TOKEN shape:")
# print(CLS_token.shape)
#
# reshaped_cls = CLS_token.expand(192, 1, -1)
# print("CLS Reshaped shape", reshaped_cls.shape)  # 192, 1, 256
# # We are telling the CLS to have the same shape as patch embeddings.
#
# imgs, labels = next(iter(train_loader))
# patch_embs = patch_embed(imgs.cuda()).flatten(2).permute(0, 2, 1)
#
# input_embs = torch.cat([reshaped_cls, patch_embs], dim=1)
# print("Patch Embeddings Shape", patch_embs.shape)
#
# print("Input Embedding Shape", input_embs.shape)

# ----------------------------------------------------------------------------------------------------------------------
# Understanding Output of Model Transformer:

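# Hidden states (only returned when output_hidden_states=True): a tuple of 13 tensors
# (the embedding output plus one per each of the 12 layers), each of shape 192, 65, 256
# Last hidden state dimension: 192, 65, 256
# Pooler output: 192, 256

# The pooler output takes the final hidden state of the first (CLS) token and passes it through a dense
# layer with tanh, so there is one vector per complete sample rather than one per token.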

#
# # We should understand output of a model,
# representations = output.last_hidden_state[:, 0, :]
# print(output.last_hidden_state.shape)  # Out of memory.
# print(representations.shape)

# ----------------------------------------------------------------------------------------------------------------------

# Training Loop
# Each epoch trains on the full training set and then evaluates on the validation set.
# Per-epoch training loss/accuracy are recorded in loss_list / acc_list for plotting.
EPOCHS = 30


def _classify(imgs):
	"""Return class logits for a batch of images that is already on the GPU.

	The images are cut into patch embeddings by ``patch_embed``, the learnable CLS token
	is prepended, the sequence is encoded by ``model`` via ``inputs_embeds``, and
	``readout`` maps the final hidden state at the CLS position to logits.
	"""
	# (batch, hidden, H', W') -> (batch, H'*W', hidden); 8x8 = 64 patches for 32x32 inputs
	patch_embs = patch_embed(imgs).flatten(2).permute(0, 2, 1)
	# Prepend the CLS token to every sequence -> (batch, 1 + H'*W', hidden)
	input_embs = torch.cat([CLS_token.expand(imgs.shape[0], 1, -1), patch_embs], dim=1)
	output = model(inputs_embeds=input_embs)
	return readout(output.last_hidden_state[:, 0, :])


loss_list = []
acc_list = []
for epoch in trange(EPOCHS, leave=False):
	# Switch back to training mode every epoch: the validation pass below calls
	# model.eval(), which would otherwise leave dropout disabled for all later epochs.
	model.train()
	# Reset running metrics so this epoch's training statistics do not include
	# totals left over from the previous epoch's validation pass.
	correct_cnt = 0
	total_loss = 0
	pbar = tqdm(train_loader, leave=False)
	for imgs, labels in pbar:
		imgs, labels = imgs.cuda(), labels.cuda()
		logit = _classify(imgs)
		loss = F.cross_entropy(logit, labels)
		loss.backward()
		optimizer.step()
		optimizer.zero_grad()
		batch_loss = loss.item()
		pbar.set_description(f"loss: {batch_loss:.4f}")
		total_loss += batch_loss * imgs.shape[0]
		correct_cnt += (logit.argmax(dim=1) == labels).sum().item()

	loss_list.append(round(total_loss / len(dataset), 4))
	acc_list.append(round(correct_cnt / len(dataset), 4))

	# test on validation set
	model.eval()
	correct_cnt = 0
	total_loss = 0
	# No gradients are needed for evaluation; skipping autograd bookkeeping saves memory and time.
	with torch.no_grad():
		for imgs, labels in val_loader:
			imgs, labels = imgs.cuda(), labels.cuda()
			logit = _classify(imgs)
			loss = F.cross_entropy(logit, labels)
			total_loss += loss.item() * imgs.shape[0]
			correct_cnt += (logit.argmax(dim=1) == labels).sum().item()

	print(f"val loss: {total_loss / len(val_dataset):.4f}, val acc: {correct_cnt / len(val_dataset):.4f}")

# Plotting Loss and Accuracy
# Draw the per-epoch training loss and training accuracy curves on a shared axis.
plt.figure()
for series, series_label in ((loss_list, "loss"), (acc_list, "accuracy")):
	plt.plot(series, label=series_label)
plt.legend()
plt.show()
# ----------------------------------------------------------------------------------------------------------------------

# Saving Model Parameters
# bert.pth keeps its original content: the BERT encoder's state_dict only.
# The patch embedding, CLS token and readout head are trained jointly with the encoder
# and are required to reproduce its predictions, so they are saved alongside it.
torch.save(model.state_dict(), "bert.pth")
torch.save(
	{
		"patch_embed": patch_embed.state_dict(),
		"readout": readout.state_dict(),
		"CLS_token": CLS_token.detach(),
	},
	"bert_image_head.pth",
)

# ----------------------------------------------------------------------------------------------------------------------
# Reference: Tutorial for Harvard Medical School ML from Scratch Series: Transformer from Scratch
# ----------------------------------------------------------------------------------------------------------------------