# Fine-tune a small Vision Transformer (ViT) on MNIST using HuggingFace transformers.
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from transformers import ViTModel, ViTConfig, ViTForImageClassification
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
# Set device: prefer GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hyperparameters
IMAGE_SIZE = 28 # MNIST image size (28x28 pixels)
PATCH_SIZE = 7 # Patch size to divide 28x28 image -> (28/7)^2 = 16 patches per image
NUM_CLASSES = 10  # digits 0-9
BATCH_SIZE = 128
EPOCHS = 5
LR = 2e-4  # AdamW learning rate
# Resize and normalize: map pixel values from [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
# MNIST is single-channel, hence the one-element mean/std tuples.
transform = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# Load MNIST dataset (downloaded to ./data on first run).
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Shuffle only the training loader; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
# Use a pre-configured ViT for image classification.
# This is a deliberately small ViT trained from scratch (no pretrained weights):
# 4 layers, 4 heads, hidden size 128.
configuration = ViTConfig(
image_size=IMAGE_SIZE,
patch_size=PATCH_SIZE,
num_labels=NUM_CLASSES,
hidden_size=128,
num_hidden_layers=4,
num_attention_heads=4,
intermediate_size=256,  # feed-forward inner dimension (2x hidden_size)
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
initializer_range=0.02
)
# NOTE: config default num_channels is 3, which is why inputs are repeated
# to 3 channels in train()/evaluate() below.
model = ViTForImageClassification(configuration).to(device)
# Alternatively, you can also load a pretrained ViT and fine-tune it:
# model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=10)
# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=LR)
# NOTE(review): criterion appears unused — train() relies on the loss computed
# internally by ViTForImageClassification when labels= is passed. Kept in case
# unseen code references it; confirm before removing.
criterion = nn.CrossEntropyLoss()
# Training loop
def train():
    """Fine-tune the ViT on MNIST for EPOCHS epochs, printing per-epoch loss and accuracy.

    Uses the module-level model, optimizer, train_loader, and device. The HF
    model computes cross-entropy internally when labels= is supplied.
    """
    model.train()
    for epoch in range(EPOCHS):
        total_loss = 0.0
        correct = 0
        total = 0
        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        for batch_images, batch_labels in progress:
            # Move to device and repeat the grayscale channel: the ViT config
            # defaults to 3 input channels.
            batch_images = batch_images.to(device).repeat(1, 3, 1, 1)
            batch_labels = batch_labels.to(device)
            optimizer.zero_grad()
            outputs = model(batch_images, labels=batch_labels)
            outputs.loss.backward()
            optimizer.step()
            # Track running loss and accuracy for the epoch summary.
            total_loss += outputs.loss.item()
            predicted = outputs.logits.argmax(dim=-1)
            correct += (predicted == batch_labels).sum().item()
            total += batch_labels.size(0)
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}")
# Evaluation loop
def evaluate():
    """Print classification accuracy of the module-level model on the test set."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        for batch_images, batch_labels in test_loader:
            # Same 1->3 channel expansion as in train(): the ViT expects 3 channels.
            batch_images = batch_images.to(device).repeat(1, 3, 1, 1)
            batch_labels = batch_labels.to(device)
            predicted = model(batch_images).logits.argmax(dim=-1)
            correct += (predicted == batch_labels).sum().item()
            total += batch_labels.size(0)
    print(f"Test Accuracy: {correct / total:.4f}")
# Run training and evaluation when executed as a script.
if __name__ == "__main__":
    train()
    evaluate()
    # Save in HF format (config.json + weights) so the model can be reloaded
    # with ViTForImageClassification.from_pretrained(".").
    model.save_pretrained(".")
    # Save only the state_dict, not the whole module: torch.save(model, ...)
    # pickles the full object graph, tying the checkpoint to this exact
    # class/module layout and library version. Reload with
    # model.load_state_dict(torch.load("vit_mnist.pth")).
    torch.save(model.state_dict(), "vit_mnist.pth")
|