# Mini-Vision-V2 / train.py
# Uploaded by LWWZH - commit 5b6d90c ("Upload Mini-Vision-V2", verified)
import os
import torch
import sys
from torch import nn
import torchvision
from datasets import load_dataset
from torch.utils.data import DataLoader
from model import MiniVisionV2
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# ---- training configuration ----
save_path = "minivisionv2_model"  # directory that receives one checkpoint per epoch
batchsize = 256                   # mini-batch size, shared by the train and test loaders
learningrate = 0.01               # initial SGD learning rate
epoch = 50                        # number of full passes over the training split
# Make sure the checkpoint directory exists. makedirs(exist_ok=True) avoids the
# check-then-create race of exists()+mkdir() and also creates missing parents.
os.makedirs(save_path, exist_ok=True)
# TensorBoard event writer for the scalars logged during training.
writer = SummaryWriter("minivisionv2_logs")
# MNIST from the Hugging Face Hub (its "train" and "test" splits are used below).
dataset = load_dataset("ylecun/mnist")

# Training pipeline with light augmentation: pad by 2 px and take a random
# 28x28 crop (small random shifts), apply a random rotation of up to 10
# degrees, then convert to a tensor.
transform_train = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomCrop(size=28, padding=2),
        torchvision.transforms.RandomRotation(degrees=10),
        torchvision.transforms.ToTensor(),
    ]
)

# Evaluation pipeline: no augmentation, tensor conversion only.
transform_test = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
def transforms_train(data):
    """Add augmented training tensors to a batch of examples.

    Intended as a ``Dataset.with_transform`` callback. Every entry of
    ``data["image"]`` is passed through ``transform_train``; the results are
    stored, in order, as a list under the new key ``"tensor"``. The input
    dict is modified in place and also returned.
    """
    data["tensor"] = list(map(transform_train, data["image"]))
    return data
def transforms_test(data):
    """Add evaluation tensors (no augmentation) to a batch of examples.

    Intended as a ``Dataset.with_transform`` callback. Every entry of
    ``data["image"]`` is passed through ``transform_test``; the results are
    stored, in order, as a list under the new key ``"tensor"``. The input
    dict is modified in place and also returned.
    """
    data["tensor"] = list(map(transform_test, data["image"]))
    return data
# Attach the per-split transform callbacks; each split keeps its own pipeline.
train_dataset, test_dataset = (
    dataset["train"].with_transform(transforms_train),
    dataset["test"].with_transform(transforms_test),
)
def collate_fn(batch):
    """Merge a list of per-example dicts into a single batch dict.

    Args:
        batch: sequence of dicts, each with a ``"tensor"`` entry (all of the
            same shape, as required by ``torch.stack``) and a ``"label"``
            entry.

    Returns:
        ``{"tensor": stacked_tensors, "label": label_tensor}``; the images are
        stacked along a new leading dimension and the labels become a 1-D
        tensor in the same order.
    """
    images = [sample["tensor"] for sample in batch]
    targets = [sample["label"] for sample in batch]
    return {"tensor": torch.stack(images), "label": torch.tensor(targets)}
# Mini-batch iterators that share collate_fn. Only the training loader
# shuffles; evaluation order is fixed.
train_loader = DataLoader(
    train_dataset, batch_size=batchsize, shuffle=True, collate_fn=collate_fn
)
test_loader = DataLoader(
    test_dataset, batch_size=batchsize, shuffle=False, collate_fn=collate_fn
)
# Model, objective and optimisation schedule: SGD with momentum 0.8, and a
# StepLR scheduler that multiplies the learning rate by 0.5 every 3 steps.
minivisionv2 = MiniVisionV2()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    minivisionv2.parameters(), lr=learningrate, momentum=0.8
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)
# Train for `epoch` epochs. Each epoch runs:
#   1. one optimisation pass over the training set,
#   2. a full evaluation pass on the test set,
#   3. TensorBoard logging and a checkpoint of the whole model object,
#   4. one learning-rate scheduler step.
# The writer is closed in `finally` so buffered TensorBoard events are flushed
# even if training raises or is interrupted (e.g. Ctrl-C).
try:
    for i in range(epoch):
        # Read the LR outside the f-string. Nesting the f-string's own quote
        # character inside it is a SyntaxError before Python 3.12 (PEP 701).
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"=============== Epoch {i} Start | LR: {current_lr} ===============")

        # ---- training pass ----
        minivisionv2.train()
        total_train_loss = 0
        for data in tqdm(train_loader, file=sys.stdout):
            optimizer.zero_grad()
            imgs = data["tensor"]
            labels = data["label"]
            output = minivisionv2(imgs)
            loss = loss_fn(output, labels)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        # Mean of the per-batch losses.
        total_avg_train_loss = total_train_loss / len(train_loader)
        print(f"Train loss: {total_avg_train_loss}")
        writer.add_scalar("Train Loss", total_avg_train_loss, i)

        # ---- evaluation pass ----
        minivisionv2.eval()
        total_accuracy = 0
        total_test_loss = 0
        with torch.no_grad():
            for data in tqdm(test_loader, file=sys.stdout):
                imgs = data["tensor"]
                labels = data["label"]
                output = minivisionv2(imgs)
                loss = loss_fn(output, labels)
                # .item() keeps the running sum a Python float, as in the
                # training pass. The printed/logged test loss is then a plain
                # number rather than a tensor repr.
                total_test_loss += loss.item()
                # Number of correct top-1 predictions in this batch.
                total_accuracy += (output.argmax(1) == labels).sum().item()
        total_avg_test_loss = total_test_loss / len(test_loader)
        total_accuracy_percentage = round(total_accuracy / len(test_dataset) * 100, 2)
        print(f"Test loss: {total_avg_test_loss}")
        print(f"Test Accuracy Percentage: {total_accuracy_percentage}%")
        writer.add_scalar("Test Loss", total_avg_test_loss, i)
        writer.add_scalar("Test Accuracy Percentage", total_accuracy_percentage, i)

        # Checkpoint the whole pickled module (not just state_dict()). Loading
        # it requires MiniVisionV2 to be importable. os.path.join writes into
        # the same directory that was created from save_path, including when
        # save_path is absolute.
        torch.save(
            minivisionv2,
            os.path.join(save_path, f"Mini-Vision-V2-Epoch-{i}.pth"),
        )
        print("Model Saved!")
        scheduler.step()
finally:
    writer.close()