---
license: apache-2.0
---

The PiT_MNIST_V1.0.ipynb notebook below is a direct implementation of the PiT (Pixel Transformer) described in the 2024 paper "An Image is Worth More Than 16x16 Patches: Exploring Transformers on Individual Pixels" (https://arxiv.org/html/2406.09415v1). The paper's approach is "directly treating each individual pixel as a token", with which the authors "achieve highly performant results". This script applies the PiT architecture, without modifications, to the standard MNIST handwritten-digit classification dataset provided in the Google Colab `sample_data` folder.

The script was run for 25 epochs on an A100 GPU. It reached 92.30% validation accuracy by epoch 15 (Train Loss: 0.2800 | Val Loss: 0.2441 | Val Acc: 92.30%) and its best validation accuracy, 94.75%, at epoch 24. Training loss dropped at every epoch. Validation loss and accuracy improved almost monotonically, with small dips in accuracy between epochs 13 and 14, 18 and 19, 20 and 21, and 22 and 23 (validation loss rose at epochs 14, 19, and 23).

Final test results, using the best-validation checkpoint (saved at epoch 24): Test Accuracy 95.01%, Test Loss 0.1662.

Colab session (Python 3, Google Compute Engine backend, GPU), resources shown from 7:40 PM to 8:01 PM:

- System RAM: 2.8 / 83.5 GB
- GPU RAM: 6.5 / 40.0 GB
- Disk: 37.7 / 112.6 GB

```python
# ==============================================================================
# PiT_MNIST_V1.0.py [in colab: PiT_MNIST_V1.0.ipynb]
#
# ML-Engineer LLM Agent Implementation
#
# Description:
# This script implements a Pixel Transformer (PiT) for MNIST classification,
# based on the paper "An Image is Worth More Than 16x16 Patches"
# (arXiv:2406.09415). It treats each pixel as an individual token, forgoing
# the patch-based approach of traditional Vision Transformers.
#
# Designed for Google Colab using the sample_data/mnist_*.csv files.
# ==============================================================================

import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# --- 1. Configuration & Hyperparameters ---
# These parameters are chosen to be reasonable for the MNIST task and
# inspired by the "Tiny" or "Small" variants in the paper.
CONFIG = {
    "train_file": "/content/sample_data/mnist_train_small.csv",
    "test_file": "/content/sample_data/mnist_test.csv",
    "image_size": 28,
    "num_classes": 10,
    "embed_dim": 128,   # 'd' in the paper. Dimension for each pixel embedding.
    "num_layers": 6,    # Number of Transformer Encoder layers.
    "num_heads": 8,     # Number of heads in Multi-Head Self-Attention. Must be a divisor of embed_dim.
    "mlp_dim": 512,     # Hidden dimension of the MLP block inside the Transformer (4 * embed_dim is common).
    "dropout": 0.1,
    "batch_size": 128,
    "epochs": 25,       # Increased epochs for better convergence on the small dataset.
    "learning_rate": 1e-4,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}
CONFIG["sequence_length"] = CONFIG["image_size"] * CONFIG["image_size"]  # 784 for MNIST

print("--- Configuration ---")
for key, value in CONFIG.items():
    print(f"{key}: {value}")
print("---------------------\n")


# --- 2. Data Loading and Preprocessing ---
class MNIST_CSV_Dataset(Dataset):
    """Custom PyTorch Dataset for loading MNIST data from CSV files."""

    def __init__(self, file_path):
        # The Colab sample_data MNIST CSVs have no header row, so header=None keeps
        # the first image. (The run logged below used the pandas default header=0,
        # which treats that row as column names: hence 19,999 train+val and 9,999 test samples.)
        df = pd.read_csv(file_path, header=None)
        self.labels = torch.tensor(df.iloc[:, 0].values, dtype=torch.long)
        # Normalize pixel values to [0, 1] and keep as float
        self.pixels = torch.tensor(df.iloc[:, 1:].values, dtype=torch.float32) / 255.0

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # The PiT's projection layer expects input of shape (in_features),
        # so for each pixel, we need a tensor of shape (1).
        # We reshape the 784 pixels to (784, 1).
        return self.pixels[idx].unsqueeze(-1), self.labels[idx]


# --- 3. Pixel Transformer (PiT) Model Architecture ---
class PixelTransformer(nn.Module):
    """
    Pixel Transformer (PiT) model.
    Treats each pixel as a token and uses a Transformer Encoder for classification.
    """

    def __init__(self, seq_len, num_classes, embed_dim, num_layers, num_heads, mlp_dim, dropout):
        super().__init__()

        # 1. Pixel Projection: Each pixel (a single value) is projected to embed_dim.
        #    This is the core "pixels-as-tokens" step.
        self.pixel_projection = nn.Linear(1, embed_dim)

        # 2. CLS Token: A learnable parameter that is prepended to the sequence of
        #    pixel embeddings. Its output state is used for classification.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # 3. Position Embedding: Learnable embeddings to encode spatial information.
        #    Size is (seq_len + 1) to account for the CLS token.
        #    This removes the inductive bias of fixed positional encodings.
        self.position_embedding = nn.Parameter(torch.randn(1, seq_len + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)

        # 4. Transformer Encoder: The main workhorse of the model.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True,  # Important for (batch, seq, feature) input format
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 5. Classification Head: LayerNorm followed by a linear layer on the CLS token's output.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, x):
        # Input x shape: (batch_size, seq_len, 1) -> (B, 784, 1)

        # Project pixels to embedding dimension
        x = self.pixel_projection(x)  # (B, 784, 1) -> (B, 784, embed_dim)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, 785, embed_dim)

        # Add position embedding
        x = x + self.position_embedding  # (B, 785, embed_dim)
        x = self.dropout(x)

        # Pass through Transformer Encoder
        x = self.transformer_encoder(x)  # (B, 785, embed_dim)

        # Extract the CLS token's output (at position 0)
        cls_output = x[:, 0]  # (B, embed_dim)

        # Pass through the head to get logits
        logits = self.mlp_head(cls_output)  # (B, num_classes)
        return logits


# --- 4. Training and Evaluation Functions ---
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for pixels, labels in progress_bar:
        pixels, labels = pixels.to(device), labels.to(device)

        # Forward pass
        logits = model(pixels)
        loss = criterion(logits, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    return total_loss / len(dataloader)


def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Evaluating", leave=False)
        for pixels, labels in progress_bar:
            pixels, labels = pixels.to(device), labels.to(device)

            logits = model(pixels)
            loss = criterion(logits, labels)
            total_loss += loss.item()

            _, predicted = torch.max(logits, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            progress_bar.set_postfix(acc=100. * correct / total)

    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy


# --- 5. Main Execution Block ---
if __name__ == "__main__":
    device = CONFIG["device"]

    # Load full training data and split into train/validation sets.
    # This helps monitor overfitting, as mnist_train_small is quite small.
    full_train_dataset = MNIST_CSV_Dataset(CONFIG["train_file"])
    train_indices, val_indices = train_test_split(
        range(len(full_train_dataset)),
        test_size=0.1,  # 10% for validation
        random_state=42,
    )
    train_dataset = torch.utils.data.Subset(full_train_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(full_train_dataset, val_indices)
    test_dataset = MNIST_CSV_Dataset(CONFIG["test_file"])

    train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

    print("\nData loaded.")
    print(f" Training samples: {len(train_dataset)}")
    print(f" Validation samples: {len(val_dataset)}")
    print(f" Test samples: {len(test_dataset)}\n")

    # Initialize model, loss function, and optimizer
    model = PixelTransformer(
        seq_len=CONFIG["sequence_length"],
        num_classes=CONFIG["num_classes"],
        embed_dim=CONFIG["embed_dim"],
        num_layers=CONFIG["num_layers"],
        num_heads=CONFIG["num_heads"],
        mlp_dim=CONFIG["mlp_dim"],
        dropout=CONFIG["dropout"],
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model initialized on {device}.")
    print(f"Total trainable parameters: {total_params:,}\n")

    criterion = nn.CrossEntropyLoss()
    # AdamW is often preferred for Transformers
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"])

    # Training loop
    best_val_acc = 0
    print("--- Starting Training ---")
    for epoch in range(CONFIG["epochs"]):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        print(
            f"Epoch {epoch+1:02}/{CONFIG['epochs']} | "
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"Val Acc: {val_acc:.2f}%"
        )

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            print(" -> New best validation accuracy! Saving model state.")
            torch.save(model.state_dict(), "PiT_MNIST_best.pth")

    print("--- Training Finished ---\n")

    # Final evaluation on the test set using the best model
    print("--- Evaluating on Test Set ---")
    model.load_state_dict(torch.load("PiT_MNIST_best.pth", map_location=device))
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Final Test Loss: {test_loss:.4f}")
    print(f"Final Test Accuracy: {test_acc:.2f}%")
    print("----------------------------\n")
```

PiT_MNIST_V1.0.ipynb ran out of memory on the CPU runtime, but ran and trained quickly in A100 GPU mode. Output of the A100 run:

```text
--- Configuration ---
train_file: /content/sample_data/mnist_train_small.csv
test_file: /content/sample_data/mnist_test.csv
image_size: 28
num_classes: 10
embed_dim: 128
num_layers: 6
num_heads: 8
mlp_dim: 512
dropout: 0.1
batch_size: 128
epochs: 25
learning_rate: 0.0001
device: cuda
sequence_length: 784
---------------------

Data loaded.
 Training samples: 17999
 Validation samples: 2000
 Test samples: 9999

Model initialized on cuda.
Total trainable parameters: 1,292,042

--- Starting Training ---
Epoch 01/25 | Train Loss: 2.2063 | Val Loss: 2.0610 | Val Acc: 22.75%
 -> New best validation accuracy! Saving model state.
Epoch 02/25 | Train Loss: 1.9907 | Val Loss: 1.7945 | Val Acc: 32.00%
 -> New best validation accuracy! Saving model state.
Epoch 03/25 | Train Loss: 1.5767 | Val Loss: 1.1938 | Val Acc: 58.35%
 -> New best validation accuracy! Saving model state.
Epoch 04/25 | Train Loss: 1.0441 | Val Loss: 0.7131 | Val Acc: 77.10%
 -> New best validation accuracy! Saving model state.
Epoch 05/25 | Train Loss: 0.7299 | Val Loss: 0.5490 | Val Acc: 82.95%
 -> New best validation accuracy! Saving model state.
Epoch 06/25 | Train Loss: 0.5935 | Val Loss: 0.4821 | Val Acc: 84.60%
 -> New best validation accuracy! Saving model state.
Epoch 07/25 | Train Loss: 0.5311 | Val Loss: 0.4021 | Val Acc: 86.95%
 -> New best validation accuracy! Saving model state.
Epoch 08/25 | Train Loss: 0.4682 | Val Loss: 0.3680 | Val Acc: 88.05%
 -> New best validation accuracy! Saving model state.
Epoch 09/25 | Train Loss: 0.4264 | Val Loss: 0.3446 | Val Acc: 89.20%
 -> New best validation accuracy! Saving model state.
Epoch 10/25 | Train Loss: 0.4038 | Val Loss: 0.3163 | Val Acc: 89.95%
 -> New best validation accuracy! Saving model state.
Epoch 11/25 | Train Loss: 0.3641 | Val Loss: 0.2941 | Val Acc: 90.80%
 -> New best validation accuracy! Saving model state.
Epoch 12/25 | Train Loss: 0.3447 | Val Loss: 0.2759 | Val Acc: 91.45%
 -> New best validation accuracy! Saving model state.
Epoch 13/25 | Train Loss: 0.3181 | Val Loss: 0.2603 | Val Acc: 92.05%
 -> New best validation accuracy! Saving model state.
Epoch 14/25 | Train Loss: 0.3023 | Val Loss: 0.2695 | Val Acc: 91.90%
Epoch 15/25 | Train Loss: 0.2800 | Val Loss: 0.2441 | Val Acc: 92.30%
 -> New best validation accuracy! Saving model state.
Epoch 16/25 | Train Loss: 0.2677 | Val Loss: 0.2377 | Val Acc: 92.65%
 -> New best validation accuracy! Saving model state.
Epoch 17/25 | Train Loss: 0.2535 | Val Loss: 0.2143 | Val Acc: 93.80%
 -> New best validation accuracy! Saving model state.
Epoch 18/25 | Train Loss: 0.2395 | Val Loss: 0.2059 | Val Acc: 94.05%
 -> New best validation accuracy! Saving model state.
Epoch 19/25 | Train Loss: 0.2276 | Val Loss: 0.2126 | Val Acc: 93.75%
Epoch 20/25 | Train Loss: 0.2189 | Val Loss: 0.1907 | Val Acc: 94.40%
 -> New best validation accuracy! Saving model state.
Epoch 21/25 | Train Loss: 0.2113 | Val Loss: 0.1892 | Val Acc: 94.35%
Epoch 22/25 | Train Loss: 0.2004 | Val Loss: 0.1775 | Val Acc: 94.50%
 -> New best validation accuracy! Saving model state.
Epoch 23/25 | Train Loss: 0.1927 | Val Loss: 0.1912 | Val Acc: 94.15%
Epoch 24/25 | Train Loss: 0.1836 | Val Loss: 0.1746 | Val Acc: 94.75%
 -> New best validation accuracy! Saving model state.
Epoch 25/25 | Train Loss: 0.1804 | Val Loss: 0.1642 | Val Acc: 94.75%
--- Training Finished ---

--- Evaluating on Test Set ---
Final Test Loss: 0.1662
Final Test Accuracy: 95.01%
----------------------------
```
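The "Total trainable parameters: 1,292,042" line in the log can be reproduced by hand from CONFIG. The sketch below is plain Python (no PyTorch needed) and assumes PyTorch's default layer shapes: biases everywhere, Q/K/V packed into a single `in_proj` inside `nn.MultiheadAttention`, and two LayerNorms per encoder layer. The helper names are hypothetical, for illustration only.

```python
# Hand count of trainable parameters for the PixelTransformer above.
# Hypothetical helper names, for illustration only. Assumes PyTorch's default
# shapes for nn.Linear, nn.LayerNorm and nn.TransformerEncoderLayer (biases on).

def linear_params(n_in: int, n_out: int) -> int:
    """Weight matrix plus bias vector of an nn.Linear layer."""
    return n_in * n_out + n_out

def layernorm_params(d: int) -> int:
    """Elementwise affine weight and bias of an nn.LayerNorm."""
    return 2 * d

def pit_param_count(seq_len: int, num_classes: int, d: int, layers: int, mlp_dim: int) -> dict:
    encoder_layer = (
        3 * d * d + 3 * d              # self-attention in_proj (packed Q, K, V) weight + bias
        + linear_params(d, d)          # self-attention out_proj
        + linear_params(d, mlp_dim)    # feed-forward linear1
        + linear_params(mlp_dim, d)    # feed-forward linear2
        + 2 * layernorm_params(d)      # norm1 and norm2
    )
    parts = {
        "pixel_projection": linear_params(1, d),
        "cls_token": d,
        "position_embedding": (seq_len + 1) * d,  # +1 slot for the CLS token
        "transformer_encoder": layers * encoder_layer,
        "mlp_head": layernorm_params(d) + linear_params(d, num_classes),
    }
    parts["total"] = sum(parts.values())  # summed before the "total" key exists
    return parts

counts = pit_param_count(seq_len=28 * 28, num_classes=10, d=128, layers=6, mlp_dim=512)
for name, n in counts.items():
    print(f"{name:>20}: {n:,}")
```

With these settings the 785 x 128 position embedding accounts for 100,480 parameters (about 8% of the total) and the six encoder layers for 1,189,632 (about 92%). `num_heads` does not appear in the count, because the heads split `embed_dim` rather than adding weights.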
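Because `forward()` prepends the CLS token, each pixel's learned position embedding sits one slot after the pixel's index in the flattened CSV row. A minimal sketch of that bookkeeping, assuming the CSV stores each 28x28 image in row-major order (as MNIST CSV exports conventionally do); the function names are hypothetical:

```python
# Pixel <-> sequence-position bookkeeping for the PiT input.
# Hypothetical helper names. Assumes row-major pixel order in the CSV row
# (after the label column). Position 0 is the CLS token, so pixel k
# (0-based) uses position_embedding[:, k + 1].

IMAGE_SIZE = 28

def pixel_to_sequence_position(row: int, col: int, image_size: int = IMAGE_SIZE) -> int:
    """Index into position_embedding[:, idx] used by the pixel at (row, col)."""
    if not (0 <= row < image_size and 0 <= col < image_size):
        raise ValueError("pixel coordinates out of range")
    return 1 + row * image_size + col

def sequence_position_to_pixel(pos: int, image_size: int = IMAGE_SIZE):
    """Inverse mapping; returns None for the CLS token at position 0."""
    if pos == 0:
        return None
    if not (1 <= pos <= image_size * image_size):
        raise ValueError("position out of range")
    return divmod(pos - 1, image_size)  # (row, col)

print(pixel_to_sequence_position(0, 0))    # first pixel -> position 1
print(pixel_to_sequence_position(27, 27))  # last pixel -> position 784
print(sequence_position_to_pixel(29))      # -> (1, 0), first pixel of the second row
```

The model itself is never given these (row, col) coordinates: any notion of 2D neighbourhood has to be learned through `position_embedding`, which is the locality-free setting the paper studies.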
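Treating every pixel as a token makes the sequence 785 tokens long, and self-attention cost grows with the square of that length. A rough illustration, assuming a naive attention implementation that stores one float32 score per (query, key) pair. PyTorch may instead dispatch to fused scaled-dot-product kernels that never materialize this tensor, so these figures are not a measurement and do not by themselves explain the CPU out-of-memory error or the 6.5 GB of GPU RAM in the resource monitor. The 4x4-patch comparison is hypothetical; the script uses no patches, and the helper names are illustrative:

```python
# Size of ONE layer's attention-score tensor (batch x heads x seq x seq) if it
# were materialized in float32. Assumption: naive attention; fused kernels may
# avoid storing it. Hypothetical helper names, for illustration only.

BYTES_PER_FLOAT32 = 4

def tokens_per_image(image_size: int, patch_size: int, cls_token: bool = True) -> int:
    """Number of tokens when an image is cut into square patches (+1 for CLS)."""
    if image_size % patch_size:
        raise ValueError("patch_size must divide image_size")
    n = (image_size // patch_size) ** 2
    return n + (1 if cls_token else 0)

def attention_scores_bytes(batch: int, heads: int, seq_len: int) -> int:
    return batch * heads * seq_len * seq_len * BYTES_PER_FLOAT32

pit_tokens = tokens_per_image(28, 1)  # pixels as tokens: 784 + CLS = 785
vit_tokens = tokens_per_image(28, 4)  # hypothetical 4x4 patches: 49 + CLS = 50

for name, seq in [("PiT (1x1)", pit_tokens), ("ViT-style (4x4)", vit_tokens)]:
    gib = attention_scores_bytes(batch=128, heads=8, seq_len=seq) / 2**30
    print(f"{name:>16}: {seq:4d} tokens -> {gib:8.3f} GiB per layer")

print(f"ratio: {(pit_tokens / vit_tokens) ** 2:.0f}x")
```

With batch size 128 and 8 heads, a fully materialized score tensor would be about 2.35 GiB per layer for PiT versus about 0.01 GiB with 4x4 patches, roughly 246 times larger, before counting any softmax or dropout outputs saved for the backward pass.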
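The sample counts in the log (17,999 training, 2,000 validation, 9,999 test) follow from the CSV sizes and the 10% split. A small sketch, assuming the Colab files hold 20,000 (`mnist_train_small.csv`) and 10,000 (`mnist_test.csv`) image rows with no header, and mirroring scikit-learn's rule for a float `test_size` (the held-out count is rounded up, the rest goes to training). The helper names are hypothetical:

```python
import math

# Assumed row counts of the Colab sample files (no header row in either file).
TRAIN_SMALL_ROWS = 20_000
TEST_ROWS = 10_000

def rows_loaded(n_rows: int, header_consumed: bool) -> int:
    """pandas.read_csv's default header=0 turns the first data row into column names."""
    return n_rows - 1 if header_consumed else n_rows

def split_sizes(n_samples: int, test_size: float = 0.1):
    """(n_train, n_val) for a float test_size: like scikit-learn's
    train_test_split, the held-out count is rounded up and the rest is train."""
    n_val = math.ceil(test_size * n_samples)
    return n_samples - n_val, n_val

for header_consumed in (True, False):
    n = rows_loaded(TRAIN_SMALL_ROWS, header_consumed)
    n_train, n_val = split_sizes(n)
    n_test = rows_loaded(TEST_ROWS, header_consumed)
    label = "header=0 (default)" if header_consumed else "header=None"
    print(f"{label:>18}: train={n_train}, val={n_val}, test={n_test}")
```

The `header=0` row reproduces the logged counts exactly; reading with `header=None` recovers the one dropped image per file (18,000 / 2,000 / 10,000).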
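The per-epoch numbers in the training log can be checked mechanically: whether training loss fell every epoch, where validation accuracy dipped, and which epoch's weights the test evaluation reloads. The values below are copied from the log; the helper names are hypothetical, and `saved_epochs` mirrors the script's strict `val_acc > best_val_acc` test:

```python
# Values copied from the training log above (epochs 1..25).
# Hypothetical helper names, for illustration only.
train_loss = [2.2063, 1.9907, 1.5767, 1.0441, 0.7299, 0.5935, 0.5311, 0.4682, 0.4264, 0.4038,
              0.3641, 0.3447, 0.3181, 0.3023, 0.2800, 0.2677, 0.2535, 0.2395, 0.2276, 0.2189,
              0.2113, 0.2004, 0.1927, 0.1836, 0.1804]
val_acc = [22.75, 32.00, 58.35, 77.10, 82.95, 84.60, 86.95, 88.05, 89.20, 89.95,
           90.80, 91.45, 92.05, 91.90, 92.30, 92.65, 93.80, 94.05, 93.75, 94.40,
           94.35, 94.50, 94.15, 94.75, 94.75]

def accuracy_dips(acc):
    """1-based (previous_epoch, epoch) pairs where validation accuracy went down."""
    return [(e, e + 1) for e in range(1, len(acc)) if acc[e] < acc[e - 1]]

def saved_epochs(acc):
    """Epochs that trigger torch.save in the training loop. It only saves on a
    strict improvement, so a tie keeps the earlier checkpoint file."""
    best, saved = 0.0, []
    for epoch, a in enumerate(acc, start=1):
        if a > best:
            best = a
            saved.append(epoch)
    return saved

print("train loss fell every epoch:", all(b < a for a, b in zip(train_loss, train_loss[1:])))
print("val-accuracy dips:", accuracy_dips(val_acc))
print("checkpoint epochs:", saved_epochs(val_acc))
print("weights used for the test set: epoch", saved_epochs(val_acc)[-1])
```

Epoch 25 ties epoch 24 at 94.75%, so the reported 95.01% test accuracy comes from the epoch-24 weights.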