MartialTerran commited on
Commit
fb188bc
·
verified ·
1 Parent(s): 7b5622f

Create PiT_MNIST_V1.0.ipynb

Browse files
Files changed (1) hide show
  1. PiT_MNIST_V1.0.ipynb +251 -0
PiT_MNIST_V1.0.ipynb ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==============================================================================
2
+ # PiT_MNIST_V1.0.py
3
+ #
4
+ # ML-Engineer LLM Agent Implementation
5
+ #
6
+ # Description:
7
+ # This script implements a Pixel Transformer (PiT) for MNIST classification,
8
+ # based on the paper "An Image is Worth More Than 16x16 Patches"
9
+ # (arXiv:2406.09415). It treats each pixel as an individual token, forgoing
10
+ # the patch-based approach of traditional Vision Transformers.
11
+ #
12
+ # Designed for Google Colab using the sample_data/mnist_*.csv files.
13
+ # ==============================================================================
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import pandas as pd
18
+ from torch.utils.data import Dataset, DataLoader
19
+ from sklearn.model_selection import train_test_split
20
+ from tqdm import tqdm
21
+ import math
22
+
23
# --- 1. Configuration & Hyperparameters ---
# Sized roughly like the "Tiny"/"Small" PiT variants from the paper;
# comfortable defaults for the MNIST task.
CONFIG = {
    "train_file": "/content/sample_data/mnist_train_small.csv",
    "test_file": "/content/sample_data/mnist_test.csv",
    "image_size": 28,
    "num_classes": 10,
    "embed_dim": 128,    # 'd' in the paper: dimension of each pixel embedding.
    "num_layers": 6,     # Depth of the Transformer encoder stack.
    "num_heads": 8,      # Attention heads; must divide embed_dim evenly.
    "mlp_dim": 512,      # Feed-forward width inside each encoder layer (4x embed_dim is typical).
    "dropout": 0.1,
    "batch_size": 128,
    "epochs": 25,        # Extra epochs help convergence on the small CSV split.
    "learning_rate": 1e-4,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}
# One token per pixel: 28 * 28 = 784 for MNIST.
CONFIG["sequence_length"] = CONFIG["image_size"] ** 2

# Echo the effective configuration at startup.
print("--- Configuration ---")
for name, setting in CONFIG.items():
    print(f"{name}: {setting}")
print("---------------------\n")
47
+
48
+
49
+ # --- 2. Data Loading and Preprocessing ---
50
+ class MNIST_CSV_Dataset(Dataset):
51
+ """Custom PyTorch Dataset for loading MNIST data from CSV files."""
52
+ def __init__(self, file_path):
53
+ df = pd.read_csv(file_path)
54
+ self.labels = torch.tensor(df.iloc[:, 0].values, dtype=torch.long)
55
+ # Normalize pixel values to [0, 1] and keep as float
56
+ self.pixels = torch.tensor(df.iloc[:, 1:].values, dtype=torch.float32) / 255.0
57
+
58
+ def __len__(self):
59
+ return len(self.labels)
60
+
61
+ def __getitem__(self, idx):
62
+ # The PiT's projection layer expects input of shape (in_features),
63
+ # so for each pixel, we need a tensor of shape (1).
64
+ # We reshape the 784 pixels to (784, 1).
65
+ return self.pixels[idx].unsqueeze(-1), self.labels[idx]
66
+
67
# --- 3. Pixel Transformer (PiT) Model Architecture ---
class PixelTransformer(nn.Module):
    """
    Pixel Transformer (PiT) model.
    Treats each pixel as a token and uses a Transformer Encoder for classification.
    """

    def __init__(self, seq_len, num_classes, embed_dim, num_layers, num_heads, mlp_dim, dropout):
        super().__init__()

        # Each scalar pixel value is lifted into an embed_dim-dimensional
        # token embedding -- the defining "pixels-as-tokens" step of PiT.
        self.pixel_projection = nn.Linear(1, embed_dim)

        # Learnable classification token, prepended to every sequence;
        # its final hidden state feeds the classifier head.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # Learnable position embeddings for seq_len pixel slots plus the
        # CLS slot; no fixed (sinusoidal) positional bias is imposed.
        self.position_embedding = nn.Parameter(torch.randn(1, seq_len + 1, embed_dim))

        self.dropout = nn.Dropout(dropout)

        # Stack of standard Transformer encoder layers, batch-first layout.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True,  # (batch, seq, feature) input format
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # LayerNorm + linear classifier applied to the CLS representation.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, x):
        """Map a batch of pixel sequences (B, seq_len, 1) to logits (B, num_classes)."""
        batch = x.shape[0]

        tokens = self.pixel_projection(x)           # (B, seq_len, d)
        cls = self.cls_token.expand(batch, -1, -1)  # (B, 1, d)
        tokens = torch.cat((cls, tokens), dim=1)    # (B, seq_len + 1, d)

        # Inject spatial information, then run the encoder stack.
        tokens = self.dropout(tokens + self.position_embedding)
        tokens = self.transformer_encoder(tokens)

        # Classification is read off the CLS position only.
        return self.mlp_head(tokens[:, 0])
132
+
133
+
134
# --- 4. Training and Evaluation Functions ---
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over `dataloader`; return the mean batch loss."""
    model.train()
    running_loss = 0.0
    batches = tqdm(dataloader, desc="Training", leave=False)
    for pixels, labels in batches:
        pixels = pixels.to(device)
        labels = labels.to(device)

        # Forward: logits and loss for this batch.
        loss = criterion(model(pixels), labels)

        # Backward: clear stale gradients, backprop, and step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batches.set_postfix(loss=loss.item())

    return running_loss / len(dataloader)
155
+
156
def evaluate(model, dataloader, criterion, device):
    """Evaluate `model` on `dataloader`; return (mean batch loss, accuracy %)."""
    model.eval()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        batches = tqdm(dataloader, desc="Evaluating", leave=False)
        for pixels, labels in batches:
            pixels = pixels.to(device)
            labels = labels.to(device)

            logits = model(pixels)
            running_loss += criterion(logits, labels).item()

            # Count top-1 hits for the running accuracy readout.
            _, predicted = torch.max(logits.data, 1)
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
            batches.set_postfix(acc=100. * n_correct / n_seen)

    return running_loss / len(dataloader), 100. * n_correct / n_seen
178
+
179
+
180
# --- 5. Main Execution Block ---
if __name__ == "__main__":
    device = CONFIG["device"]

    # Load full training data and split into train/validation sets.
    # This helps monitor overfitting, as mnist_train_small is quite small.
    full_train_dataset = MNIST_CSV_Dataset(CONFIG["train_file"])
    train_indices, val_indices = train_test_split(
        range(len(full_train_dataset)),
        test_size=0.1,  # 10% for validation
        random_state=42
    )
    train_dataset = torch.utils.data.Subset(full_train_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(full_train_dataset, val_indices)
    test_dataset = MNIST_CSV_Dataset(CONFIG["test_file"])

    train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

    # (Idiom fix: dropped f-string prefixes on literals with no placeholders.)
    print("\nData loaded.")
    print(f" Training samples: {len(train_dataset)}")
    print(f" Validation samples: {len(val_dataset)}")
    print(f" Test samples: {len(test_dataset)}\n")

    # Initialize model, loss function, and optimizer.
    model = PixelTransformer(
        seq_len=CONFIG["sequence_length"],
        num_classes=CONFIG["num_classes"],
        embed_dim=CONFIG["embed_dim"],
        num_layers=CONFIG["num_layers"],
        num_heads=CONFIG["num_heads"],
        mlp_dim=CONFIG["mlp_dim"],
        dropout=CONFIG["dropout"]
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model initialized on {device}.")
    print(f"Total trainable parameters: {total_params:,}\n")

    criterion = nn.CrossEntropyLoss()
    # AdamW is often preferred for Transformers.
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"])

    # Training loop: keep the checkpoint with the best validation accuracy.
    best_val_acc = 0
    print("--- Starting Training ---")
    for epoch in range(CONFIG["epochs"]):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        print(
            f"Epoch {epoch+1:02}/{CONFIG['epochs']} | "
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"Val Acc: {val_acc:.2f}%"
        )

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            print(" -> New best validation accuracy! Saving model state.")
            torch.save(model.state_dict(), "PiT_MNIST_best.pth")

    print("--- Training Finished ---\n")

    # Final evaluation on the test set using the best model.
    print("--- Evaluating on Test Set ---")
    # Fix: map_location guarantees the checkpoint tensors materialize on the
    # active device, even if the file was saved from a different one.
    model.load_state_dict(torch.load("PiT_MNIST_best.pth", map_location=device))
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Final Test Loss: {test_loss:.4f}")
    print(f"Final Test Accuracy: {test_acc:.2f}%")
    print("----------------------------\n")