Spaces:
Sleeping
Sleeping
"""Train the CatBreedCNN classifier, print a validation report, and save weights.

Run this file as a script. Hyperparameters and paths default to the values the
script has always used; pass different ones to ``main()`` to override them.
"""
from collections.abc import Iterable

import torch
import torch.nn as nn
import torch.optim as optim

from models.cnn_model import CatBreedCNN
from utils.data_loader import get_dataloaders
from utils.evaluate import evaluate_model

# Paths are relative to the current working directory (as before).
DATA_DIR = "data/cat_breed_dataset"
MODEL_PATH = "models/cat_cnn.pth"
NUM_EPOCHS = 20
LEARNING_RATE = 0.001


def _train_one_epoch(
    model: nn.Module,
    loader: Iterable,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
) -> float:
    """Run one optimization pass over *loader* and return the mean per-sample loss.

    Each item of *loader* is unpacked as ``(inputs, targets)`` and both are
    moved to *device*. Returns ``float("nan")`` if the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        # Assumes criterion averages over the batch (the default 'mean'
        # reduction of nn.CrossEntropyLoss) and that y's first dim is the
        # batch size; rescale to a sum so batches of unequal size weigh right.
        batch_size = y.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size
    return total_loss / total_samples if total_samples else float("nan")


def main(
    *,
    data_dir: str = DATA_DIR,
    model_path: str = MODEL_PATH,
    num_epochs: int = NUM_EPOCHS,
    lr: float = LEARNING_RATE,
) -> None:
    """Train CatBreedCNN on *data_dir*, print a validation report, save weights.

    Args:
        data_dir: Dataset root passed to ``get_dataloaders``.
        model_path: Destination of the saved ``state_dict``.
        num_epochs: Number of full passes over the training loader.
        lr: Adam learning rate.
    """
    # Load data
    train_loader, val_loader, classes = get_dataloaders(data_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model, loss and optimizer
    model = CatBreedCNN(len(classes)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Training loop
    for epoch in range(num_epochs):
        avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Epoch {epoch + 1}/{num_epochs} complete. Train loss: {avg_loss:.4f}")

    # Evaluate once after training.
    # NOTE(review): the original indentation was lost in the source. This
    # follows its top-level "# Evaluate" section; confirm per-epoch
    # validation was not intended.
    model.eval()  # harmless if evaluate_model already switches modes
    report, _ = evaluate_model(model, val_loader, device)
    print(report)

    # Save model
    torch.save(model.state_dict(), model_path)


# Guard the entry point: DataLoader workers started with the "spawn" method
# re-import the main module, which would otherwise re-run training.
if __name__ == "__main__":
    main()