Gillie2004 commited on
Commit
8f0e25c
·
verified ·
1 Parent(s): a82dfe3

Upload 4 files

Browse files
script/__pycache__/predict.cpython-310.pyc ADDED
Binary file (1.19 kB). View file
 
script/__pycache__/train.cpython-310.pyc ADDED
Binary file (997 Bytes). View file
 
script/predict.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ from models.cnn_model import CatBreedCNN
9
+
10
+
11
+ classes = ['Bengal', 'Domestic_Shorthair', 'Maine_Coon','Ragdoll','Siamese',] # Update as needed
12
+
13
+ def predict(image_path):
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ model = CatBreedCNN(num_classes=len(classes))
17
+ model.load_state_dict(torch.load("models/cat_cnn.pth", map_location=device))
18
+ model.eval()
19
+
20
+ transform = transforms.Compose([
21
+ transforms.Resize((128, 128)),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize([0.5]*3, [0.5]*3)
24
+ ])
25
+
26
+ image = Image.open(image_path).convert("RGB")
27
+ image = transform(image).unsqueeze(0).to(device)
28
+
29
+ with torch.no_grad():
30
+ output = model(image)
31
+ predicted_index = output.argmax(dim=1).item()
32
+
33
+ return classes[predicted_index]
script/train.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from models.cnn_model import CatBreedCNN
5
+ from utils.data_loader import get_dataloaders
6
+ from utils.evaluate import evaluate_model
7
+
8
+ # Load data
9
+ train_loader, val_loader, classes = get_dataloaders("data/cat_breed_dataset")
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+ # Initialize model
13
+ model = CatBreedCNN(len(classes)).to(device)
14
+
15
+ # Loss and optimizer
16
+ criterion = nn.CrossEntropyLoss()
17
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
18
+
19
+ # Training loop
20
+ for epoch in range(20):
21
+ model.train()
22
+ for x, y in train_loader:
23
+ x, y = x.to(device), y.to(device)
24
+ optimizer.zero_grad()
25
+ outputs = model(x)
26
+ loss = criterion(outputs, y)
27
+ loss.backward()
28
+ optimizer.step()
29
+
30
+ print(f"Epoch {epoch+1} complete. Evaluating...") # ✅ Corrected f-string
31
+
32
+ # Evaluate
33
+ report, _ = evaluate_model(model, val_loader, device)
34
+ print(report)
35
+
36
+ # Save model
37
+ torch.save(model.state_dict(), "models/cat_cnn.pth")