# fatimaxa's picture
# Upload 46 files
# 83be575 verified
import torch
import torch.nn as nn
from pathlib import Path
from data_prep import test_loader, device
from models.model import PlantCNN
from utils.config import load_config
from clearml import Task
import numpy as np
from utils.vis import visualize_preds, plot_cfm
from tqdm.auto import tqdm
def evaluate_on_test(model, loader, loss_fn, device, num_imgs):
    """Evaluate ``model`` on every batch of ``loader`` with gradients disabled.

    Args:
        model: Module that is put into eval mode and called on each image batch.
        loader: Iterable of batches; each batch is indexed with the keys
            ``"pixel_values"`` and ``"labels"``.
        loss_fn: Callable invoked as ``loss_fn(output, labels)``; its result
            must support ``.item()``.
        device: Device the batch tensors are moved to before the forward pass.
        num_imgs: Maximum number of samples to collect for later display.

    Returns:
        A tuple ``(test_loss, test_acc, all_labels, all_preds,
        imgs_to_display, lbls_to_display, prs_to_display)`` where
        ``test_loss`` and ``test_acc`` are divided by the number of samples
        seen, ``all_labels`` / ``all_preds`` hold one entry per sample, and
        the three ``*_to_display`` lists hold the first ``num_imgs`` images
        (moved to CPU), their true labels and their predicted classes.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()

    all_labels = []
    all_preds = []
    running_loss = 0.0
    correct = 0
    total = 0

    imgs_to_display = []
    lbls_to_display = []
    prs_to_display = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Val", leave=False):
            images = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            output = model(images)
            loss = loss_fn(output, labels)

            # Weight the batch loss by batch size so the final division by
            # ``total`` gives a per-sample value (assumes loss_fn returns a
            # batch mean -- TODO confirm for non-default reductions).
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size

            # Predicted class = index of the largest score along dim 1.
            _, preds = torch.max(output, dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

            if len(imgs_to_display) < num_imgs:
                remaining = num_imgs - len(imgs_to_display)
                # Pair each image with its TRUE label and its prediction.
                # (Previously ``preds`` was zipped twice, so the stored
                # "labels" were just copies of the predictions.)
                for img, lbl, pr in zip(
                    images[:remaining], labels[:remaining], preds[:remaining]
                ):
                    imgs_to_display.append(img.cpu())
                    lbls_to_display.append(lbl.item())
                    prs_to_display.append(pr.item())

    test_loss = running_loss / total
    test_acc = correct / total
    return (
        test_loss,
        test_acc,
        all_labels,
        all_preds,
        imgs_to_display,
        lbls_to_display,
        prs_to_display,
    )
def main():
    """Evaluate the saved PlantCNN checkpoint on the test loader and log to ClearML.

    Reads the configuration, starts a ClearML task tagged for testing,
    restores the weights from ``saved_models/plant_cnn.pt`` located next to
    this file, runs ``evaluate_on_test`` over ``test_loader``, prints the
    resulting loss and accuracy, reports both as scalars, and passes the
    collected samples and predictions to ``visualize_preds`` and ``plot_cfm``.
    """
    cfg = load_config()
    n_classes = cfg["num_classes"]
    conv_channels = cfg["channels"]
    drop_rate = cfg["dropout"]
    # Not used below; the lookup is kept so a missing key still raises.
    _unused_lr = cfg["lr"]
    norm_mean = cfg["normalize_mean"]
    norm_std = cfg["normalize_std"]

    project = "GAP_plant_disease_classification"
    arch = "PlantCNN"
    # Number of samples collected during evaluation and shown afterwards.
    n_display = 24

    # ClearML bookkeeping: one task per test run, carrying the config and tags.
    task = Task.init(project_name=project, task_name=f"{arch}_test")
    task.connect(cfg)
    task.add_tags([arch, "test"])
    logger = task.get_logger()

    class_names = test_loader.dataset.features["label"].names

    # Rebuild the network and restore the stored weights onto `device`.
    model = PlantCNN(num_classes=n_classes, channels=conv_channels, dropout=drop_rate)
    model = model.to(device)
    checkpoint = Path(__file__).resolve().parent / "saved_models" / "plant_cnn.pt"
    model.load_state_dict(torch.load(checkpoint, map_location=device))

    criterion = nn.CrossEntropyLoss()
    results = evaluate_on_test(model, test_loader, criterion, device, num_imgs=n_display)
    test_loss, test_acc, all_labels, all_preds, shown_imgs, shown_lbls, shown_prs = results

    print("\nTest results:")
    print(f"Test loss: {test_loss:.3f} | Test accuracy: {test_acc:.3f}")

    for title, value in (("loss", test_loss), ("accuracy", test_acc)):
        logger.report_scalar(title, "test", value, 0)

    visualize_preds(
        shown_imgs,
        shown_lbls,
        shown_prs,
        logger,
        class_names,
        norm_mean,
        norm_std,
        num_images=n_display,
    )
    plot_cfm(all_labels, all_preds, logger, class_names, n_classes)


if __name__ == "__main__":
    main()