"""Evaluate the saved PlantCNN checkpoint on the test split and log results to ClearML."""
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from clearml import Task
from tqdm.auto import tqdm

from data_prep import test_loader, device
from models.model import PlantCNN
from utils.config import load_config
from utils.vis import visualize_preds, plot_cfm


def evaluate_on_test(model, loader, loss_fn, device, num_imgs):
    """Run ``model`` over every batch of ``loader`` and gather metrics and samples.

    Args:
        model: Classifier to evaluate; it is switched to ``eval()`` mode here.
        loader: Iterable of batch dicts with ``"pixel_values"`` and ``"labels"``
            entries (both tensors).
        loss_fn: Called as ``loss_fn(output, labels)``. Its result is multiplied
            by the batch size before accumulating, i.e. it is assumed to be a
            per-sample mean (true for the default ``nn.CrossEntropyLoss``) —
            confirm if another reduction is passed.
        device: Device the input batches are moved to.
        num_imgs: Maximum number of images to collect for visualization; they
            are taken from the first batches in loader order.

    Returns:
        A tuple ``(test_loss, test_acc, all_labels, all_preds, imgs_to_display,
        lbls_to_display, prs_to_display)`` where ``test_loss`` is the mean loss
        per sample, ``test_acc`` the fraction of correct predictions,
        ``all_labels``/``all_preds`` lists of per-sample numpy scalars, and the
        display lists hold CPU image tensors with their ground-truth label and
        predicted class as Python ints.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()

    running_loss = 0.0
    correct = 0
    total = 0
    all_labels = []
    all_preds = []
    imgs_to_display = []
    lbls_to_display = []
    prs_to_display = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Val", leave=False):
            images = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            output = model(images)
            loss = loss_fn(output, labels)

            batch_size = labels.size(0)
            # Re-weight the (mean) batch loss by batch size so the final
            # average is per sample even when the last batch is smaller.
            running_loss += loss.item() * batch_size
            preds = output.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

            remaining = num_imgs - len(imgs_to_display)
            if remaining > 0:
                # Keep each image paired with its ground-truth label and its
                # prediction so the visualization can show mistakes.
                imgs_to_display.extend(images[:remaining].cpu())
                lbls_to_display.extend(labels[:remaining].tolist())
                prs_to_display.extend(preds[:remaining].tolist())

    if total == 0:
        raise ValueError("evaluate_on_test: loader yielded no samples")

    test_loss = running_loss / total
    test_acc = correct / total
    return (
        test_loss,
        test_acc,
        all_labels,
        all_preds,
        imgs_to_display,
        lbls_to_display,
        prs_to_display,
    )


def main():
    """Evaluate the trained PlantCNN on the test split and report to ClearML.

    Reads hyper-parameters from ``load_config()``, rebuilds ``PlantCNN``,
    loads ``saved_models/plant_cnn.pt`` (relative to this file), prints the
    test loss and accuracy, and reports the scalars, a grid of sample
    predictions and a confusion matrix through the ClearML task logger.
    """
    config = load_config()
    num_classes = config["num_classes"]
    channels = config["channels"]
    dropout = config["dropout"]
    mean_nm = config["normalize_mean"]
    std_nm = config["normalize_std"]
    # Number of sample images collected during evaluation and then plotted;
    # a single value keeps both calls below in sync.
    num_display = 24

    project_name = "GAP_plant_disease_classification"
    model_name = "PlantCNN"

    task = Task.init(project_name=project_name, task_name=f"{model_name}_test")
    task.connect(config)
    task.add_tags([model_name, "test"])
    logger = task.get_logger()

    # NOTE(review): assumes test_loader.dataset exposes a Hugging Face style
    # `features["label"].names` list of class names — confirm in data_prep.
    class_names = test_loader.dataset.features["label"].names

    model = PlantCNN(num_classes=num_classes, channels=channels, dropout=dropout).to(device)
    model_path = Path(__file__).resolve().parent / "saved_models" / "plant_cnn.pt"
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints (torch >= 1.13 supports weights_only=True for state dicts).
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    loss_fn = nn.CrossEntropyLoss()
    (
        test_loss,
        test_acc,
        all_labels,
        all_preds,
        display_images,
        display_labels,
        display_preds,
    ) = evaluate_on_test(model, test_loader, loss_fn, device, num_imgs=num_display)

    print("\nTest results:")
    print(f"Test loss: {test_loss:.3f} | Test accuracy: {test_acc:.3f}")

    logger.report_scalar("loss", "test", test_loss, 0)
    logger.report_scalar("accuracy", "test", test_acc, 0)

    visualize_preds(
        display_images,
        display_labels,
        display_preds,
        logger,
        class_names,
        mean_nm,
        std_nm,
        num_images=num_display,
    )
    plot_cfm(all_labels, all_preds, logger, class_names, num_classes)


if __name__ == "__main__":
    main()