File size: 1,044 Bytes
8f0e25c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
import torch.optim as optim
from models.cnn_model import CatBreedCNN
from utils.data_loader import get_dataloaders
from utils.evaluate import evaluate_model

# --- Run configuration --------------------------------------------------------
# Dataset location, checkpoint destination and training hyperparameters for
# this run, collected in one place so they are easy to find and change.
DATA_DIR = "data/cat_breed_dataset"
CHECKPOINT_PATH = "models/cat_cnn.pth"
NUM_EPOCHS = 20
LEARNING_RATE = 0.001

# Build the train/validation loaders. In this script `classes` is used only
# for its length, which sets the size of the model's output layer.
train_loader, val_loader, classes = get_dataloaders(DATA_DIR)

# Use CUDA when torch reports it is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_classes = len(classes)
model = CatBreedCNN(num_classes).to(device)

# Cross-entropy loss on the model outputs, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    # Re-enter training mode at the start of every epoch. NOTE(review):
    # presumably evaluate_model switches the model to eval mode; confirm.
    model.train()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Standard step: clear stale gradients, forward, backprop, update.
        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1} complete. Evaluating...")

    # evaluate_model returns a pair. Only its first element is printed here;
    # the second is discarded.
    report, _ = evaluate_model(model, val_loader, device)
    print(report)

# Persist only the learned parameters (the state_dict), not the module object.
torch.save(model.state_dict(), CHECKPOINT_PATH)