Commit 5c88c7d (verified) by MartialTerran · parent: ab02030 · Update README.md · files changed: README.md (+314 -3)

The following PiT_MNIST_V1.0.ipynb is a direct implementation of the PiT pixel transformer described in the 2024 paper
"An Image is Worth More Than 16 x 16 Patches: Exploring Transformers on Individual Pixels"
(https://arxiv.org/html/2406.09415v1),
which explores "directly treating each individual pixel as a token" and reports "highly performant results".
This script applies the PiT model architecture, without any modifications, to the standard MNIST numeral-images-classification dataset provided in the Google Colab sample_data folder.
The script was run for 25 epochs and reached 92.30% accuracy on the validation set (Train Loss: 0.2800 | Val Loss: 0.2441 | Val Acc: 92.30%) by epoch 15.
Loss fell and accuracy increased monotonically with each epoch.

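The pixels-as-tokens step can be sketched in isolation (a hypothetical illustration with made-up tensors, not a snippet from the script below):

```python
import torch

# One 28x28 grayscale image becomes 784 single-value tokens; after the
# learned projection and a prepended CLS token, the sequence is 785 long.
image = torch.rand(28, 28)
tokens = image.reshape(-1, 1)            # (784, 1): one scalar "token" per pixel
projection = torch.nn.Linear(1, 128)     # per-pixel projection to embed_dim=128
x = projection(tokens.unsqueeze(0))      # (1, 784, 128)
cls = torch.zeros(1, 1, 128)             # stand-in for the learnable CLS token
x = torch.cat([cls, x], dim=1)           # (1, 785, 128): what the encoder sees
print(tuple(x.shape))                    # (1, 785, 128)
```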
```python
# ==============================================================================
# PiT_MNIST_V1.0.py [in colab: PiT_MNIST_V1.0.ipynb]
#
# ML-Engineer LLM Agent Implementation
#
# Description:
# This script implements a Pixel Transformer (PiT) for MNIST classification,
# based on the paper "An Image is Worth More Than 16x16 Patches"
# (arXiv:2406.09415). It treats each pixel as an individual token, forgoing
# the patch-based approach of traditional Vision Transformers.
#
# Designed for Google Colab using the sample_data/mnist_*.csv files.
# ==============================================================================

import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# --- 1. Configuration & Hyperparameters ---
# These parameters are chosen to be reasonable for the MNIST task and
# inspired by the "Tiny" or "Small" variants in the paper.
CONFIG = {
    "train_file": "/content/sample_data/mnist_train_small.csv",
    "test_file": "/content/sample_data/mnist_test.csv",
    "image_size": 28,
    "num_classes": 10,
    "embed_dim": 128,      # 'd' in the paper. Dimension for each pixel embedding.
    "num_layers": 6,       # Number of Transformer Encoder layers.
    "num_heads": 8,        # Number of heads in Multi-Head Self-Attention. Must be a divisor of embed_dim.
    "mlp_dim": 512,        # Hidden dimension of the MLP block inside the Transformer. (4 * embed_dim is common)
    "dropout": 0.1,
    "batch_size": 128,
    "epochs": 25,          # Increased epochs for better convergence on the small dataset.
    "learning_rate": 1e-4,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}
CONFIG["sequence_length"] = CONFIG["image_size"] * CONFIG["image_size"]  # 784 for MNIST

print("--- Configuration ---")
for key, value in CONFIG.items():
    print(f"{key}: {value}")
print("---------------------\n")


# --- 2. Data Loading and Preprocessing ---
class MNIST_CSV_Dataset(Dataset):
    """Custom PyTorch Dataset for loading MNIST data from CSV files."""
    def __init__(self, file_path):
        df = pd.read_csv(file_path)
        self.labels = torch.tensor(df.iloc[:, 0].values, dtype=torch.long)
        # Normalize pixel values to [0, 1] and keep as float
        self.pixels = torch.tensor(df.iloc[:, 1:].values, dtype=torch.float32) / 255.0

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # The PiT's projection layer expects one scalar feature per pixel token,
        # so the 784 pixels are reshaped to (784, 1).
        return self.pixels[idx].unsqueeze(-1), self.labels[idx]


# --- 3. Pixel Transformer (PiT) Model Architecture ---
class PixelTransformer(nn.Module):
    """
    Pixel Transformer (PiT) model.
    Treats each pixel as a token and uses a Transformer Encoder for classification.
    """
    def __init__(self, seq_len, num_classes, embed_dim, num_layers, num_heads, mlp_dim, dropout):
        super().__init__()

        # 1. Pixel Projection: Each pixel (a single value) is projected to embed_dim.
        # This is the core "pixels-as-tokens" step.
        self.pixel_projection = nn.Linear(1, embed_dim)

        # 2. CLS Token: A learnable parameter that is prepended to the sequence of
        # pixel embeddings. Its output state is used for classification.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # 3. Position Embedding: Learnable embeddings to encode spatial information.
        # Size is (seq_len + 1) to account for the CLS token.
        # This removes the inductive bias of fixed positional encodings.
        self.position_embedding = nn.Parameter(torch.randn(1, seq_len + 1, embed_dim))

        self.dropout = nn.Dropout(dropout)

        # 4. Transformer Encoder: The main workhorse of the model.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True  # Important for (batch, seq, feature) input format
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 5. Classification Head: A simple MLP head on top of the CLS token's output.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x):
        # Input x shape: (batch_size, seq_len, 1) -> (B, 784, 1)

        # Project pixels to embedding dimension
        x = self.pixel_projection(x)  # (B, 784, 1) -> (B, 784, embed_dim)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, 785, embed_dim)

        # Add position embedding
        x = x + self.position_embedding  # (B, 785, embed_dim)
        x = self.dropout(x)

        # Pass through Transformer Encoder
        x = self.transformer_encoder(x)  # (B, 785, embed_dim)

        # Extract the CLS token's output (at position 0)
        cls_output = x[:, 0]  # (B, embed_dim)

        # Pass through MLP head to get logits
        logits = self.mlp_head(cls_output)  # (B, num_classes)

        return logits


# --- 4. Training and Evaluation Functions ---
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for pixels, labels in progress_bar:
        pixels, labels = pixels.to(device), labels.to(device)

        # Forward pass
        logits = model(pixels)
        loss = criterion(logits, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    return total_loss / len(dataloader)


def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Evaluating", leave=False)
        for pixels, labels in progress_bar:
            pixels, labels = pixels.to(device), labels.to(device)

            logits = model(pixels)
            loss = criterion(logits, labels)

            total_loss += loss.item()
            _, predicted = torch.max(logits.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            progress_bar.set_postfix(acc=100. * correct / total)

    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy


# --- 5. Main Execution Block ---
if __name__ == "__main__":
    device = CONFIG["device"]

    # Load full training data and split into train/validation sets.
    # This helps monitor overfitting, as mnist_train_small is quite small.
    full_train_dataset = MNIST_CSV_Dataset(CONFIG["train_file"])
    train_indices, val_indices = train_test_split(
        range(len(full_train_dataset)),
        test_size=0.1,  # 10% for validation
        random_state=42
    )
    train_dataset = torch.utils.data.Subset(full_train_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(full_train_dataset, val_indices)
    test_dataset = MNIST_CSV_Dataset(CONFIG["test_file"])

    train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

    print("\nData loaded.")
    print(f"  Training samples: {len(train_dataset)}")
    print(f"  Validation samples: {len(val_dataset)}")
    print(f"  Test samples: {len(test_dataset)}\n")

    # Initialize model, loss function, and optimizer
    model = PixelTransformer(
        seq_len=CONFIG["sequence_length"],
        num_classes=CONFIG["num_classes"],
        embed_dim=CONFIG["embed_dim"],
        num_layers=CONFIG["num_layers"],
        num_heads=CONFIG["num_heads"],
        mlp_dim=CONFIG["mlp_dim"],
        dropout=CONFIG["dropout"]
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model initialized on {device}.")
    print(f"Total trainable parameters: {total_params:,}\n")

    criterion = nn.CrossEntropyLoss()
    # AdamW is often preferred for Transformers
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"])

    # Training loop
    best_val_acc = 0
    print("--- Starting Training ---")
    for epoch in range(CONFIG["epochs"]):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        print(
            f"Epoch {epoch+1:02}/{CONFIG['epochs']} | "
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"Val Acc: {val_acc:.2f}%"
        )

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            print("  -> New best validation accuracy! Saving model state.")
            torch.save(model.state_dict(), "PiT_MNIST_best.pth")

    print("--- Training Finished ---\n")

    # Final evaluation on the test set using the best model
    print("--- Evaluating on Test Set ---")
    model.load_state_dict(torch.load("PiT_MNIST_best.pth"))
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Final Test Loss: {test_loss:.4f}")
    print(f"Final Test Accuracy: {test_acc:.2f}%")
    print("----------------------------\n")
```
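Treating every pixel as a token makes self-attention expensive: each head scores every token pair, so one encoder layer builds batch × heads × 785 × 785 attention entries. A rough back-of-the-envelope estimate (plain Python; assumes dense float32 scores, this script's batch size of 128, and no attention optimizations such as chunking or fused kernels):

```python
# Rough estimate of attention-score memory for one encoder layer.
# Real kernels may use far less via fusion or chunking.
seq_len = 28 * 28 + 1                        # 784 pixel tokens + 1 CLS token
heads, batch = 8, 128
scores = batch * heads * seq_len * seq_len   # score-matrix elements in one layer
print(scores)                                # 631014400
print(f"{scores * 4 / 1e9:.1f} GB")          # ~2.5 GB of float32 scores per layer
```

This is consistent with the observation below that the CPU runtime ran out of memory while an A100 trained comfortably.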
[The PiT_MNIST_V1.0.ipynb script ran out of memory in the Colab CPU runtime, but ran and trained quickly in A100 GPU mode.]

```text
--- Configuration ---
train_file: /content/sample_data/mnist_train_small.csv
test_file: /content/sample_data/mnist_test.csv
image_size: 28
num_classes: 10
embed_dim: 128
num_layers: 6
num_heads: 8
mlp_dim: 512
dropout: 0.1
batch_size: 128
epochs: 25
learning_rate: 0.0001
device: cuda
sequence_length: 784
---------------------

Data loaded.
  Training samples: 17999
  Validation samples: 2000
  Test samples: 9999

Model initialized on cuda.
Total trainable parameters: 1,292,042

--- Starting Training ---
Epoch 01/25 | Train Loss: 2.2063 | Val Loss: 2.0610 | Val Acc: 22.75%
  -> New best validation accuracy! Saving model state.
Epoch 02/25 | Train Loss: 1.9907 | Val Loss: 1.7945 | Val Acc: 32.00%
  -> New best validation accuracy! Saving model state.
Epoch 03/25 | Train Loss: 1.5767 | Val Loss: 1.1938 | Val Acc: 58.35%
  -> New best validation accuracy! Saving model state.
Epoch 04/25 | Train Loss: 1.0441 | Val Loss: 0.7131 | Val Acc: 77.10%
  -> New best validation accuracy! Saving model state.
Epoch 05/25 | Train Loss: 0.7299 | Val Loss: 0.5490 | Val Acc: 82.95%
  -> New best validation accuracy! Saving model state.
Epoch 06/25 | Train Loss: 0.5935 | Val Loss: 0.4821 | Val Acc: 84.60%
  -> New best validation accuracy! Saving model state.
Epoch 07/25 | Train Loss: 0.5311 | Val Loss: 0.4021 | Val Acc: 86.95%
  -> New best validation accuracy! Saving model state.
Epoch 08/25 | Train Loss: 0.4682 | Val Loss: 0.3680 | Val Acc: 88.05%
  -> New best validation accuracy! Saving model state.
Epoch 09/25 | Train Loss: 0.4264 | Val Loss: 0.3446 | Val Acc: 89.20%
  -> New best validation accuracy! Saving model state.
```

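The logged figure of 1,292,042 trainable parameters can be reproduced by hand from the configuration (plain arithmetic; the per-layer breakdown assumes PyTorch's `nn.TransformerEncoderLayer` layout: a fused QKV in-projection, an output projection, two feed-forward linears, and two LayerNorms per layer):

```python
# Back-of-the-envelope check that the reported 1,292,042 trainable
# parameters match the configuration (embed_dim=128, 6 layers, mlp_dim=512).
d, seq, mlp, classes, layers = 128, 784, 512, 10, 6

pixel_projection = 1 * d + d                # Linear(1, d): weight + bias
cls_token = d
position_embedding = (seq + 1) * d          # one slot per token, plus CLS

attn = (3 * d * d + 3 * d) + (d * d + d)    # fused QKV in-proj + out-proj
ffn = (d * mlp + mlp) + (mlp * d + d)       # two Linear layers of the MLP block
norms = 2 * (2 * d)                         # two LayerNorms (weight + bias each)
per_layer = attn + ffn + norms

head = 2 * d + (d * classes + classes)      # LayerNorm + Linear classification head

total = pixel_projection + cls_token + position_embedding + layers * per_layer + head
print(total)  # 1292042
```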
---
license: apache-2.0
---