"""Train a small Vision Transformer (ViT) from scratch on MNIST.

MNIST images are 28x28 grayscale; the ViT is configured with 7x7 patches
(a 4x4 patch grid) and the single grayscale channel is repeated to three
channels to match the model's expected input.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
IMAGE_SIZE = 28   # MNIST image size
PATCH_SIZE = 7    # 28x28 image -> 4x4 grid of 7x7 patches
NUM_CLASSES = 10
BATCH_SIZE = 128
EPOCHS = 5
LR = 2e-4

# Resize (a no-op at 28x28) and normalize pixel values to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST train/test splits (downloaded to ./data on first run).
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# Small ViT sized for MNIST, trained from scratch (not pretrained).
configuration = ViTConfig(
    image_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    num_labels=NUM_CLASSES,
    hidden_size=128,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=256,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    initializer_range=0.02,
)
model = ViTForImageClassification(configuration).to(device)
# Alternatively, you can also load a pretrained ViT and fine-tune it:
# model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=10)

optimizer = optim.AdamW(model.parameters(), lr=LR)
# NOTE: `criterion` is never used below -- ViTForImageClassification computes
# cross-entropy internally when `labels=` is passed. Kept (documented) so any
# external code referencing this module-level name keeps working.
criterion = nn.CrossEntropyLoss()


def train():
    """Train the model for EPOCHS epochs, printing per-epoch loss and accuracy."""
    model.train()
    for epoch in range(EPOCHS):
        total_loss = 0.0
        correct = 0
        total = 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            images, labels = images.to(device), labels.to(device)
            # Repeat grayscale channel to match expected input shape
            # (ViT expects 3 channels).
            images = images.repeat(1, 3, 1, 1)

            # Passing `labels` makes the model return the loss alongside logits.
            outputs = model(images, labels=labels)
            loss = outputs.loss
            logits = outputs.logits

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            preds = torch.argmax(logits, dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}")


def evaluate():
    """Report classification accuracy on the MNIST test split."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            images = images.repeat(1, 3, 1, 1)  # grayscale -> 3 channels
            outputs = model(images)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    print(f"Test Accuracy: {correct / total:.4f}")


# Run training and evaluation
if __name__ == "__main__":
    train()
    evaluate()
    model.save_pretrained(".")
    # Save only the state_dict: pickling the whole module object
    # (torch.save(model, ...)) is fragile across code and library versions.
    torch.save(model.state_dict(), "vit_mnist.pth")