|
|
import torch |
|
|
from torch.utils.data import DataLoader |
|
|
from torchvision import transforms, datasets |
|
|
from transformers import ViTModel, ViTConfig, ViTForImageClassification |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# Run on GPU when one is available; the model and every batch are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# --- Hyperparameters ---
# MNIST images are 28x28; a 7-pixel patch gives a 4x4 = 16-patch grid.
IMAGE_SIZE = 28

PATCH_SIZE = 7

# Digits 0-9.
NUM_CLASSES = 10

BATCH_SIZE = 128

EPOCHS = 5

# AdamW learning rate.
LR = 2e-4
|
|
|
|
|
|
|
|
# Preprocessing: resize (a no-op for native 28x28 MNIST), convert to a
# float tensor in [0, 1], then normalize the single channel to [-1, 1].
transform = transforms.Compose([

    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),

    transforms.ToTensor(),

    transforms.Normalize((0.5,), (0.5,))

])


# MNIST is downloaded into ./data on the first run.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)


# Shuffle only the training split; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
|
|
|
|
|
|
|
|
# A small Vision Transformer trained from scratch (random init, no
# pretrained weights).
# NOTE(review): num_channels is left at its ViTConfig default (3), which is
# why train()/evaluate() replicate the 1-channel MNIST images to 3 channels
# before each forward pass — confirm this is intentional rather than setting
# num_channels=1 here.
configuration = ViTConfig(

    image_size=IMAGE_SIZE,

    patch_size=PATCH_SIZE,

    # Setting num_labels adds a 10-way classification head on top of the
    # encoder and lets the model compute its own cross-entropy loss.
    num_labels=NUM_CLASSES,

    hidden_size=128,

    num_hidden_layers=4,

    # Per-head dimension = 128 / 4 = 32.
    num_attention_heads=4,

    # Feed-forward (MLP) width inside each transformer layer.
    intermediate_size=256,

    hidden_act="gelu",

    hidden_dropout_prob=0.1,

    attention_probs_dropout_prob=0.1,

    initializer_range=0.02

)


model = ViTForImageClassification(configuration).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
optimizer = optim.AdamW(model.parameters(), lr=LR)

# NOTE(review): criterion is never used below — the HF model computes its own
# cross-entropy loss when `labels=` is passed; this is kept for compatibility.
criterion = nn.CrossEntropyLoss()
|
|
|
|
|
|
|
|
def train():
    """Train the ViT on MNIST for EPOCHS epochs, logging loss and accuracy.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``. The HF model computes its own cross-entropy loss when
    ``labels=`` is supplied.
    """
    model.train()
    for epoch in range(EPOCHS):
        running_loss = 0.0
        n_correct = 0
        n_seen = 0

        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            # The ViT config expects 3 input channels; tile the single
            # grayscale MNIST channel across all three.
            batch_images = batch_images.repeat(1, 3, 1, 1)

            outputs = model(batch_images, labels=batch_labels)

            optimizer.zero_grad()
            outputs.loss.backward()
            optimizer.step()

            # Accumulate per-batch loss and running accuracy.
            running_loss += outputs.loss.item()
            predictions = torch.argmax(outputs.logits, dim=-1)
            n_correct += (predictions == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

        print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}, Accuracy: {n_correct/n_seen:.4f}")
|
|
|
|
|
|
|
|
def evaluate():
    """Report classification accuracy on the MNIST test split.

    Uses the module-level ``model``, ``test_loader`` and ``device``; runs
    under ``torch.no_grad()`` with the model in eval mode.
    """
    model.eval()
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            # Match training: tile grayscale to the 3 channels ViT expects.
            batch_images = batch_images.repeat(1, 3, 1, 1)

            logits = model(batch_images).logits
            predictions = torch.argmax(logits, dim=-1)
            n_correct += (predictions == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    print(f"Test Accuracy: {n_correct / n_seen:.4f}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    train()
    evaluate()

    # Persist in the HF format (config.json + weights), reloadable via
    # ViTForImageClassification.from_pretrained(".").
    model.save_pretrained(".")

    # Save only the state_dict. The original torch.save(model, ...) pickled
    # the whole module object, which is fragile across code refactors and
    # PyTorch versions; a state_dict is the documented, portable format.
    torch.save(model.state_dict(), "vit_mnist.pth")
|
|
|