# vit_mnist / train.py
# Uploaded by xcx0902 via huggingface_hub (commit f4b1740, verified)
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from transformers import ViTModel, ViTConfig, ViTForImageClassification
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
# ---------------------------------------------------------------------------
# Runtime setup: prefer a CUDA GPU when available, otherwise fall back to CPU.
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------------------------------------------------------
# Hyperparameters.
# ---------------------------------------------------------------------------
IMAGE_SIZE = 28   # native MNIST resolution (28x28)
PATCH_SIZE = 7    # 28 / 7 = 4 -> each image becomes a 4x4 grid of patches
NUM_CLASSES = 10  # digits 0-9
BATCH_SIZE = 128
EPOCHS = 5
LR = 2e-4

# ---------------------------------------------------------------------------
# Input pipeline: resize (a no-op for native-resolution MNIST), convert to a
# tensor in [0, 1], then normalize the single channel into [-1, 1].
# ---------------------------------------------------------------------------
transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST train/test splits (downloaded to ./data on first run).
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# ---------------------------------------------------------------------------
# A small ViT trained from scratch. NOTE: `num_channels` keeps its default of
# 3, which is why the training/eval loops tile the grayscale channel x3
# before each forward pass.
# ---------------------------------------------------------------------------
configuration = ViTConfig(
    image_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    num_labels=NUM_CLASSES,
    hidden_size=128,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=256,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    initializer_range=0.02,
)
model = ViTForImageClassification(configuration).to(device)

# Alternatively, you can also load a pretrained ViT and fine-tune it:
# model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=10)

# Optimizer. `criterion` is kept for compatibility but is not used below:
# ViTForImageClassification computes cross-entropy internally whenever
# `labels` are passed to its forward().
optimizer = optim.AdamW(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
# Training loop
def train():
    """Run the full training loop over `train_loader` for EPOCHS epochs.

    Reads the module-level `model`, `optimizer`, `device`, `train_loader`,
    and `EPOCHS`. Prints the mean loss and accuracy after each epoch.
    """
    model.train()
    for epoch in range(EPOCHS):
        running_loss = 0.0
        n_correct = 0
        n_seen = 0
        for batch_images, batch_labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            # The ViT patch embedding expects 3 input channels; tile the
            # single grayscale channel three times.
            batch_images = batch_images.repeat(1, 3, 1, 1)

            # Passing `labels` makes the model compute its own CE loss.
            outputs = model(batch_images, labels=batch_labels)

            optimizer.zero_grad()
            outputs.loss.backward()
            optimizer.step()

            running_loss += outputs.loss.item()
            predictions = outputs.logits.argmax(dim=-1)
            n_correct += (predictions == batch_labels).sum().item()
            n_seen += batch_labels.size(0)
        print(
            f"Epoch {epoch+1}, "
            f"Loss: {running_loss/len(train_loader):.4f}, "
            f"Accuracy: {n_correct/n_seen:.4f}"
        )
# Evaluation loop
def evaluate():
    """Measure classification accuracy on the MNIST test split.

    Reads the module-level `model`, `test_loader`, and `device`; prints the
    overall test accuracy.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            # Move to device and tile grayscale -> 3 channels for the ViT.
            batch_images = batch_images.to(device).repeat(1, 3, 1, 1)
            batch_labels = batch_labels.to(device)
            predictions = model(batch_images).logits.argmax(dim=-1)
            n_correct += (predictions == batch_labels).sum().item()
            n_seen += batch_labels.size(0)
    print(f"Test Accuracy: {n_correct / n_seen:.4f}")
# Run training and evaluation
def _main():
    """Train, evaluate, then persist the model in two formats."""
    train()
    evaluate()
    # HF format (config + weights) into the current directory, plus a pickled
    # full-model checkpoint.
    # NOTE(review): torch.save of the whole module pickles the class itself —
    # loading requires identical code layout; a state_dict is usually safer.
    model.save_pretrained(".")
    torch.save(model, "vit_mnist.pth")


if __name__ == "__main__":
    _main()