# CCS229_ALA / src/train.py
# Author: Gillie2004 — "Upload 4 files" (commit a82dfe3, verified)
import os

import torch
import torch.nn as nn
import torch.optim as optim

from models.cnn_model import CatBreedCNN
from utils.data_loader import get_dataloaders
from utils.evaluate import evaluate_model
# Training script for the cat-breed CNN classifier.
#
# Loads the train/validation loaders, trains CatBreedCNN with Adam and
# cross-entropy loss, prints a validation report after every epoch, and saves
# the final weights (state_dict only) to MODEL_PATH.

# Hyperparameters and paths (paths are relative to the current working directory).
DATA_DIR = "data/cat_breed_dataset"
MODEL_PATH = "models/cat_cnn.pth"
NUM_EPOCHS = 20
LEARNING_RATE = 0.001

# Load data
train_loader, val_loader, classes = get_dataloaders(DATA_DIR)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model: one output per class returned by the data loader.
model = CatBreedCNN(len(classes)).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop
for epoch in range(NUM_EPOCHS):
    # Re-enter training mode every epoch: the evaluation step below may switch
    # the model to eval mode (NOTE(review): confirm inside evaluate_model).
    model.train()
    running_loss = 0.0
    num_batches = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
        # .item() yields a plain Python float, so the autograd graph of each
        # batch is not kept alive by the running sum.
        running_loss += loss.item()
        num_batches += 1

    # Guard against an empty loader so the average never divides by zero.
    avg_loss = running_loss / num_batches if num_batches else float("nan")
    print(
        f"Epoch {epoch + 1}/{NUM_EPOCHS} complete. "
        f"Avg train loss: {avg_loss:.4f}. Evaluating..."
    )

    # Evaluate on the validation loader; only the report is used here.
    report, _ = evaluate_model(model, val_loader, device)
    print(report)

# Save model weights. Create the target directory first so a finished training
# run is not lost to FileNotFoundError when the script is launched from a
# directory without a models/ folder.
os.makedirs(os.path.dirname(MODEL_PATH) or ".", exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)