# AlexNet / train.py
# Author: DrujZ-cmd
# AI417 A5 AlexNet model
# Revision: 2dfdcd4
import time
import torch as tr
import torchvision as tv
import torchvision.transforms as transforms
from model import model, loss_fn, optimizer
def _train_one_epoch(net, loader, criterion, opt, device):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    The mean is weighted by batch size, so a short final batch does not skew it.
    This assumes ``criterion`` uses mean reduction (PyTorch's default); TODO
    confirm against ``model.loss_fn``.

    Returns ``nan`` if ``loader`` yields no batches.
    """
    net.train()
    # Accumulate on-device so we do not force a host sync on every step.
    loss_sum = tr.zeros((), device=device)
    seen = 0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        opt.zero_grad()
        output = net(images)
        # Taking the log of clamped outputs suggests the model emits
        # probabilities (softmax) and ``criterion`` expects log-probabilities
        # (e.g. NLLLoss). NOTE(review): confirm against model.py.
        loss = criterion(tr.log(tr.clamp(output, min=1e-9)), labels)
        loss.backward()
        opt.step()
        loss_sum += loss.detach() * labels.size(0)
        seen += labels.size(0)
    return loss_sum.item() / seen if seen else float("nan")


def _evaluate(net, loader, device):
    """Return top-1 accuracy of ``net`` on ``loader`` as a fraction in [0, 1].

    Returns ``nan`` if ``loader`` yields no samples.
    """
    net.eval()
    correct = 0
    total = 0
    with tr.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            output = net(images)
            preds = tr.argmax(output, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else float("nan")


def main():
    """Train the shared ``model`` on Tiny ImageNet and report validation accuracy.

    Uses ``model``, ``loss_fn`` and ``optimizer`` imported from ``model.py``.
    After each epoch it prints the mean training loss and the validation
    accuracy. At the end it prints the final values and the wall-clock time
    spent in the training loop.

    Raises:
        ValueError: if the train and val folders do not define the same classes
            (in the same order). The integer labels would then not line up, and
            the reported accuracy would be meaningless.
    """
    batch_size = 128
    epochs = 10

    device = tr.device("cuda" if tr.cuda.is_available() else "cpu")
    print("Running on:", device)

    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    trn_dataset = tv.datasets.ImageFolder(
        root='data/tiny-imagenet-200/train',
        transform=transform)
    # NOTE(review): the stock Tiny ImageNet archive ships val/ as a single
    # images/ folder plus val_annotations.txt, which ImageFolder reads as one
    # class. This script assumes val/ has been reorganised into one sub-folder
    # per class, matching train/. The check below enforces that.
    evl_dataset = tv.datasets.ImageFolder(
        root='data/tiny-imagenet-200/val',
        transform=transform)
    if evl_dataset.class_to_idx != trn_dataset.class_to_idx:
        raise ValueError(
            "train and val folders define different classes; "
            "val/ must contain one sub-folder per training class")

    # Pinned host memory only helps (and only makes sense) for GPU transfers.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=8,
        pin_memory=device.type == "cuda")
    trn_loader = tr.utils.data.DataLoader(trn_dataset, shuffle=True, **loader_kwargs)
    evl_loader = tr.utils.data.DataLoader(evl_dataset, shuffle=False, **loader_kwargs)

    model.to(device)

    # Pre-set so the final report is well-defined even if no epoch runs.
    epoch_loss = accuracy = float("nan")
    start_time = time.perf_counter()
    for epoch in range(1, epochs + 1):
        epoch_loss = _train_one_epoch(model, trn_loader, loss_fn, optimizer, device)
        accuracy = _evaluate(model, evl_loader, device)
        print(f"Epoch {epoch}/{epochs} - Loss: {epoch_loss:.4f} - Accuracy: {accuracy*100:.2f}%")
    end_time = time.perf_counter()

    print("Training finished")
    print(f"Final Loss: {epoch_loss:.4f}")
    print(f"Final Accuracy: {accuracy*100:.2f}%")
    print(f"Total Time: {end_time - start_time:.2f} seconds")


# The guard keeps DataLoader worker processes (num_workers > 0, which may
# re-import this module under the "spawn" start method) from re-running training.
if __name__ == "__main__":
    main()