Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from pathlib import Path | |
| from data_prep import test_loader, device | |
| from models.model import PlantCNN | |
| from utils.config import load_config | |
| from clearml import Task | |
| import numpy as np | |
| from utils.vis import visualize_preds, plot_cfm | |
| from tqdm.auto import tqdm | |
def evaluate_on_test(model, loader, loss_fn, device, num_imgs):
    """Evaluate ``model`` over every batch of ``loader`` and collect display samples.

    Args:
        model: Classifier called as ``model(images)``; it is put in eval mode here.
        loader: Iterable of dict batches, each with ``"pixel_values"`` and
            ``"labels"`` tensors (both are moved to ``device``).
        loss_fn: Callable ``loss_fn(output, labels)`` returning a scalar tensor.
            Assumed to average over the batch (e.g. ``reduction="mean"``) — the
            running total re-weights it by batch size.
        device: Device the images and labels are moved to before the forward pass.
        num_imgs: Maximum number of (image, true label, prediction) triples to
            keep for visualization; they are taken from the first batches.

    Returns:
        Tuple ``(test_loss, test_acc, all_labels, all_preds, imgs_to_display,
        lbls_to_display, prs_to_display)``:
        ``test_loss`` is the per-sample average loss, ``test_acc`` the fraction
        of correct argmax predictions, ``all_labels``/``all_preds`` the
        per-sample ground truth and predictions (numpy scalars), and the last
        three lists hold at most ``num_imgs`` CPU image tensors together with
        their ground-truth labels and predictions as Python ints.

    Raises:
        ValueError: If ``loader`` yields no samples, so no average can be formed.
    """
    model.eval()
    all_labels = []
    all_preds = []
    running_loss = 0.0
    correct = 0
    total = 0
    imgs_to_display = []
    lbls_to_display = []
    prs_to_display = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Val", leave=False):
            images = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            output = model(images)
            loss = loss_fn(output, labels)
            batch_size = labels.size(0)
            # Weight the batch-mean loss by batch size so the final average is
            # over samples rather than over batches (the last batch may be smaller).
            running_loss += loss.item() * batch_size
            _, preds = torch.max(output, dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            remaining = num_imgs - len(imgs_to_display)
            if remaining > 0:
                # Pair each image with its ground-truth label and its prediction
                # so the visualization can show both.
                for img, lbl, pr in zip(images[:remaining], labels[:remaining], preds[:remaining]):
                    imgs_to_display.append(img.cpu())
                    lbls_to_display.append(lbl.item())
                    prs_to_display.append(pr.item())
    if total == 0:
        raise ValueError("evaluate_on_test: loader yielded no samples")
    test_loss = running_loss / total
    test_acc = correct / total
    return test_loss, test_acc, all_labels, all_preds, imgs_to_display, lbls_to_display, prs_to_display
def main():
    """Evaluate the saved PlantCNN checkpoint on the test split and log to ClearML.

    Reads hyperparameters via ``load_config()``, registers a ClearML task,
    restores ``saved_models/plant_cnn.pt`` (located next to this file), runs
    ``evaluate_on_test`` on ``test_loader``, prints loss/accuracy, and reports
    the scalars, a grid of sample predictions and a confusion matrix through
    the task's logger.
    """
    # Number of sample images collected and shown in the prediction grid;
    # a single constant keeps collection and visualization in sync.
    num_display = 24

    config = load_config()
    num_classes = config["num_classes"]
    channels = config["channels"]
    dropout = config["dropout"]
    mean_nm = config["normalize_mean"]
    std_nm = config["normalize_std"]

    project_name = "GAP_plant_disease_classification"
    model_name = "PlantCNN"
    task = Task.init(project_name=project_name, task_name=f"{model_name}_test")
    task.connect(config)
    task.add_tags([model_name, "test"])
    logger = task.get_logger()

    # NOTE(review): presumably a Hugging Face `datasets` split whose "label"
    # feature is a ClassLabel — confirm in data_prep.
    class_names = test_loader.dataset.features["label"].names

    model = PlantCNN(num_classes=num_classes, channels=channels, dropout=dropout).to(device)
    model_path = Path(__file__).resolve().parent / "saved_models" / "plant_cnn.pt"
    # The file holds a state_dict (it is passed straight to load_state_dict),
    # so restrict unpickling to tensors/primitive containers.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    loss_fn = nn.CrossEntropyLoss()
    (
        test_loss,
        test_acc,
        all_labels,
        all_preds,
        display_images,
        display_labels,
        display_preds,
    ) = evaluate_on_test(model, test_loader, loss_fn, device, num_imgs=num_display)

    print("\nTest results:")
    print(f"Test loss: {test_loss:.3f} | Test accuracy: {test_acc:.3f}")
    # Single-shot evaluation, so both scalars are reported at iteration 0.
    logger.report_scalar("loss", "test", test_loss, 0)
    logger.report_scalar("accuracy", "test", test_acc, 0)
    visualize_preds(display_images, display_labels, display_preds, logger, class_names, mean_nm, std_nm, num_images=num_display)
    plot_cfm(all_labels, all_preds, logger, class_names, num_classes)
# Run the test-set evaluation only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()