File size: 4,761 Bytes
97fcc90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import numpy as np
import matplotlib.pyplot as plt
import argparse
import os
import sys
import subprocess
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

from src.DataLoader.dataloader import create_dataloader
from src.models.resnet18_finetune import make_resnet18
from src.models.cnn_model import PlantCNN
from src.utils.config import load_config
from datasets import load_from_disk
from clearml import Task, InputModel

def load_model_from_clearml(task_id: str, model_type: str, num_classes: int, device: str):
    """Download the 'best_model' artifact of a ClearML task and load it.

    Args:
        task_id: ClearML task ID of the training run that produced the model.
        model_type: ``'resnet18'`` or ``'cnn'`` (case-insensitive); selects the
            architecture to instantiate before the weights are loaded.
        num_classes: Number of output classes for the classifier head.
        device: torch device string the model is moved to ('cpu' / 'cuda').

    Returns:
        The loaded model, on *device* and in eval mode.

    Raises:
        SystemExit: if the artifact cannot be retrieved from ClearML.
        ValueError: if *model_type* is not one of the supported names.
    """
    print(f"[INFO] Loading 'best_model' artifact from Task ID '{task_id}'...")
    try:
        source_task = Task.get_task(task_id=task_id)
        model_path = source_task.artifacts['best_model'].get_local_copy()
        print(f"[INFO] Model downloaded to: {model_path}")
    except Exception as e:
        print(f"[FATAL] Could not retrieve artifact from Task {task_id}. Error: {e}")
        # Use sys.exit rather than the site-provided exit() builtin, which is
        # not guaranteed to exist in all interpreter setups (e.g. `python -S`).
        sys.exit(1)

    # Instantiate the architecture matching the checkpoint.
    kind = model_type.lower()
    if kind == 'resnet18':
        model = make_resnet18(num_classes=num_classes)
    elif kind == 'cnn':
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    state_dict = torch.load(model_path, map_location=device)
    # Training may have saved a full checkpoint dict; unwrap the raw weights.
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    print("[SUCCESS] Model loaded and ready.")
    return model

def evaluate_model(model, loader, device):
    """Run *model* over every batch in *loader* without tracking gradients.

    Returns:
        A ``(y_true, y_pred)`` pair of 1-D numpy arrays of class indices,
        accumulated across all batches in iteration order.
    """
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch, targets in loader:
            batch = batch.to(device)
            # Collapse one-hot / soft targets down to class indices.
            if targets.ndim > 1:
                targets = targets.argmax(dim=1)

            logits = model(batch)
            y_true.extend(targets.cpu().numpy())
            y_pred.extend(logits.argmax(dim=1).cpu().numpy())
    return np.array(y_true), np.array(y_pred)

def main():
    """Entry point: evaluate a ClearML-trained model on the held-out test split.

    Workflow: initialise a ClearML evaluation task, hand execution off to a
    remote agent, download the 'best_model' artifact of a previous training
    task, run it over the test set, and log a classification report plus a
    confusion-matrix figure back to ClearML.
    """
    task = Task.init(project_name="PlantDisease", task_name="model_evaluation", task_type=Task.TaskTypes.testing)
    
    # Pin the remote execution environment to this repo's requirements file.
    task.set_packages("./requirements.txt")
    
    # NOTE(review): on the first (local) run this enqueues the task and stops
    # the local process; everything below executes only on the remote agent.
    task.execute_remotely(queue_name="default") 
    
    parser = argparse.ArgumentParser(description="Evaluate a trained model from ClearML.")
    parser.add_argument('--task_id', type=str, required=True, help="ClearML Task ID that produced the model.")
    parser.add_argument('--model_type', type=str, required=True, choices=['resnet18', 'cnn'])
    args = parser.parse_args()
    
    # Register CLI args on the task so they are editable when cloned in the UI.
    task.connect(args)
    logger = task.get_logger()
    
    print(f"--- Evaluating Model from Task ID: {args.task_id} ({args.model_type.upper()}) ---")
    
    cfg = load_config()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    data_path = cfg['data_path']
    if not os.path.exists(data_path):
        # Lazily materialise the processed dataset if it is missing on disk.
        print(f"[WARN] Data path '{data_path}' not found. Running processing script...")
        subprocess.check_call([sys.executable, "process_dataset.py"])
    
    ds_dict = load_from_disk(data_path)
    # samples_per_epoch=len(test) so one pass covers the whole test split.
    test_loader = create_dataloader(ds_dict['test'], batch_size=32, samples_per_epoch=len(ds_dict['test']), is_training_set=False)
    
    class_names = ds_dict['test'].features['label'].names
    num_classes = len(class_names)
    
    model = load_model_from_clearml(args.task_id, args.model_type, num_classes, device)
    
    y_true, y_pred = evaluate_model(model, test_loader, device)
    
    print("\n--- Generating Reports and Plots ---")
    
    # Dict form is uploaded as an artifact; text form is printed for the log.
    report_dict = classification_report(y_true, y_pred, target_names=class_names, zero_division=0, output_dict=True)
    report_text = classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
    print(report_text)
    task.upload_artifact(name="classification_report", artifact_object=report_dict)
    
    cm = confusion_matrix(y_true, y_pred)
    # Large canvas because the class list can be big; per-cell annotations off.
    plt.figure(figsize=(22, 22))
    sns.heatmap(cm, annot=False, cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.ylabel('True Label', fontsize=14)
    plt.xlabel('Predicted Label', fontsize=14)
    plt.title(f'Confusion Matrix - {args.model_type.upper()}', fontsize=16)
    plt.tight_layout()
    
    # Passing the pyplot module lets ClearML capture the current figure.
    logger.report_matplotlib_figure(title="Confusion Matrix", series=args.model_type, figure=plt, report_image=True)
    
    print("[SUCCESS] Evaluation complete. Artifacts logged to ClearML.")
    
    task.close()

if __name__ == "__main__":
    main()