"""Evaluate a trained plant-disease classifier pulled from a ClearML task artifact.

Downloads the 'best_model' artifact of a given ClearML task, runs it over the
test split of the processed dataset, and logs a classification report plus a
confusion-matrix figure back to a new ClearML evaluation task.
"""

import argparse
import os
import subprocess
import sys

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from clearml import InputModel, Task
from datasets import load_from_disk
from sklearn.metrics import classification_report, confusion_matrix

from src.DataLoader.dataloader import create_dataloader
from src.models.cnn_model import PlantCNN
from src.models.resnet18_finetune import make_resnet18
from src.utils.config import load_config


def load_model_from_clearml(task_id: str, model_type: str, num_classes: int, device: str):
    """Download the 'best_model' artifact from a ClearML task and load it.

    Args:
        task_id: ClearML task ID that produced the checkpoint artifact.
        model_type: 'resnet18' or 'cnn' (case-insensitive).
        num_classes: Number of output classes for the classifier head.
        device: Torch device string ('cuda' or 'cpu') to place the model on.

    Returns:
        The loaded model in eval mode, moved to `device`.

    Raises:
        ValueError: If `model_type` is not a known architecture.
        SystemExit: If the artifact cannot be retrieved from ClearML.
    """
    print(f"[INFO] Loading 'best_model' artifact from Task ID '{task_id}'...")
    try:
        source_task = Task.get_task(task_id=task_id)
        model_path = source_task.artifacts['best_model'].get_local_copy()
        print(f"[INFO] Model downloaded to: {model_path}")
    except Exception as e:
        print(f"[FATAL] Could not retrieve artifact from Task {task_id}. Error: {e}")
        # Fix: use sys.exit() rather than the interactive-only builtin exit().
        sys.exit(1)

    if model_type.lower() == 'resnet18':
        model = make_resnet18(num_classes=num_classes)
    elif model_type.lower() == 'cnn':
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # NOTE(review): torch.load unpickles arbitrary objects. The artifact here
    # comes from our own ClearML task, but consider weights_only=True
    # (torch >= 2.0) if the artifact source is ever untrusted.
    state_dict = torch.load(model_path, map_location=device)
    if 'state_dict' in state_dict:
        # Unwrap checkpoint dicts that nest the weights under 'state_dict'.
        state_dict = state_dict['state_dict']
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    print("[SUCCESS] Model loaded and ready.")
    return model


def evaluate_model(model, loader, device):
    """Run inference over an entire dataloader.

    Args:
        model: A torch module already in eval mode on `device`.
        loader: Iterable yielding (inputs, labels) batches.
        device: Torch device string the inputs are moved to.

    Returns:
        Tuple of (y_true, y_pred) as 1-D numpy arrays of class indices.
    """
    all_preds, all_labels = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            # One-hot / soft labels are collapsed to class indices.
            if labels.ndim > 1:
                labels = labels.argmax(dim=1)
            outputs = model(inputs)
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
    return np.array(all_labels), np.array(all_preds)


def main():
    """Entry point: set up the ClearML eval task, run evaluation, log artifacts."""
    task = Task.init(project_name="PlantDisease",
                     task_name="model_evaluation",
                     task_type=Task.TaskTypes.testing)
    task.set_packages("./requirements.txt")
    task.execute_remotely(queue_name="default")

    parser = argparse.ArgumentParser(description="Evaluate a trained model from ClearML.")
    parser.add_argument('--task_id', type=str, required=True,
                        help="ClearML Task ID that produced the model.")
    parser.add_argument('--model_type', type=str, required=True,
                        choices=['resnet18', 'cnn'])
    args = parser.parse_args()
    task.connect(args)
    logger = task.get_logger()

    print(f"--- Evaluating Model from Task ID: {args.task_id} ({args.model_type.upper()}) ---")

    cfg = load_config()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    data_path = cfg['data_path']
    if not os.path.exists(data_path):
        # Fix: the original warning string was broken by a stray newline.
        print(f"[WARN] Data path '{data_path}' not found. Running processing script...")
        subprocess.check_call([sys.executable, "process_dataset.py"])

    ds_dict = load_from_disk(data_path)
    test_loader = create_dataloader(ds_dict['test'],
                                    batch_size=32,
                                    samples_per_epoch=len(ds_dict['test']),
                                    is_training_set=False)
    class_names = ds_dict['test'].features['label'].names
    num_classes = len(class_names)

    model = load_model_from_clearml(args.task_id, args.model_type, num_classes, device)
    y_true, y_pred = evaluate_model(model, test_loader, device)

    print("\n--- Generating Reports and Plots ---")
    report_dict = classification_report(y_true, y_pred, target_names=class_names,
                                        zero_division=0, output_dict=True)
    report_text = classification_report(y_true, y_pred, target_names=class_names,
                                        zero_division=0)
    print(report_text)
    task.upload_artifact(name="classification_report", artifact_object=report_dict)

    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(22, 22))
    sns.heatmap(cm, annot=False, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.ylabel('True Label', fontsize=14)
    plt.xlabel('Predicted Label', fontsize=14)
    plt.title(f'Confusion Matrix - {args.model_type.upper()}', fontsize=16)
    plt.tight_layout()
    logger.report_matplotlib_figure(title="Confusion Matrix",
                                    series=args.model_type,
                                    figure=plt,
                                    report_image=True)
    # Free the (large) figure once it has been reported.
    plt.close()

    print("[SUCCESS] Evaluation complete. Artifacts logged to ClearML.")
    task.close()


if __name__ == "__main__":
    main()