# JAMM032's picture
# Upload github repo files
# 97fcc90 verified
import torch
import numpy as np
import matplotlib.pyplot as plt
import argparse
import os
import sys
import subprocess
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
from src.DataLoader.dataloader import create_dataloader
from src.models.resnet18_finetune import make_resnet18
from src.models.cnn_model import PlantCNN
from src.utils.config import load_config
from datasets import load_from_disk
from clearml import Task, InputModel
def load_model_from_clearml(task_id: str, model_type: str, num_classes: int, device: str):
    """Download the ``best_model`` artifact of a ClearML task and rebuild the model.

    Args:
        task_id: ID of the ClearML task whose ``best_model`` artifact holds the weights.
        model_type: Architecture to instantiate, ``'resnet18'`` or ``'cnn'``
            (case-insensitive).
        num_classes: Number of output classes passed to the model constructor.
        device: Torch device string the weights are mapped to and the model is
            moved to.

    Returns:
        The model with the downloaded weights loaded, on ``device`` and in eval mode.

    Raises:
        ValueError: If ``model_type`` is not a supported architecture (checked
            before anything is downloaded).
        SystemExit: With status 1 if the artifact cannot be retrieved from ClearML.
    """
    # Validate the architecture first so a bad value fails fast, before any
    # network access or artifact download.
    kind = model_type.lower()
    if kind not in ('resnet18', 'cnn'):
        raise ValueError(f"Unknown model type: {model_type}")

    print(f"[INFO] Loading 'best_model' artifact from Task ID '{task_id}'...")
    try:
        source_task = Task.get_task(task_id=task_id)
        model_path = source_task.artifacts['best_model'].get_local_copy()
        print(f"[INFO] Model downloaded to: {model_path}")
    except Exception as e:
        # Top-level boundary: report and abort. sys.exit is always available,
        # unlike the site-injected exit() builtin.
        print(f"[FATAL] Could not retrieve artifact from Task {task_id}. Error: {e}")
        sys.exit(1)

    if kind == 'resnet18':
        model = make_resnet18(num_classes=num_classes)
    else:
        model = PlantCNN(num_classes=num_classes)

    checkpoint = torch.load(model_path, map_location=device)
    # Accept a pickled module, a checkpoint dict wrapping the weights under
    # 'state_dict', or a bare state dict.
    if isinstance(checkpoint, torch.nn.Module):
        state_dict = checkpoint.state_dict()
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    print("[SUCCESS] Model loaded and ready.")
    return model
def evaluate_model(model, loader, device):
    """Run inference over every batch of ``loader`` and collect class indices.

    Args:
        model: Classifier whose output's ``argmax`` over dim 1 is taken as the
            predicted class. It is called as-is, so the caller is responsible for
            putting it in eval mode.
        loader: Iterable of ``(inputs, labels)`` tensor batches. Labels are either
            class indices or 2-D targets (e.g. one-hot), which are reduced with
            ``argmax`` over dim 1.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Tuple ``(y_true, y_pred)`` of 1-D NumPy arrays: true labels first, then
        predictions. Both are empty int64 arrays when ``loader`` yields nothing.
    """
    true_batches, pred_batches = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            # Multi-column targets (e.g. one-hot) -> class index per sample.
            if labels.ndim > 1:
                labels = labels.argmax(dim=1)
            logits = model(inputs.to(device))
            pred_batches.append(logits.argmax(dim=1).cpu().numpy())
            true_batches.append(labels.cpu().numpy())

    if not true_batches:
        # np.concatenate rejects an empty list; return typed empty arrays instead.
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
    return np.concatenate(true_batches), np.concatenate(pred_batches)
def _load_test_split(cfg):
    """Return the ``test`` split of the processed dataset, building it first if missing."""
    data_path = cfg['data_path']
    if not os.path.exists(data_path):
        print(f"[WARN] Data path '{data_path}' not found. Running processing script...")
        subprocess.check_call([sys.executable, "process_dataset.py"])
    ds_dict = load_from_disk(data_path)
    return ds_dict['test']


def _log_confusion_matrix(logger, y_true, y_pred, class_names, model_type):
    """Plot the confusion matrix over all classes and report it to the ClearML logger."""
    # labels= pins the matrix to len(class_names) x len(class_names) so it always
    # lines up with the tick labels, even when some classes never occur in
    # y_true or y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    fig = plt.figure(figsize=(22, 22))
    sns.heatmap(cm, annot=False, cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.ylabel('True Label', fontsize=14)
    plt.xlabel('Predicted Label', fontsize=14)
    plt.title(f'Confusion Matrix - {model_type.upper()}', fontsize=16)
    plt.tight_layout()
    logger.report_matplotlib_figure(title="Confusion Matrix", series=model_type, figure=fig, report_image=True)


def main():
    """Evaluate a model produced by another ClearML task on the test split.

    Registers a ClearML testing task, parses ``--task_id`` / ``--model_type``,
    hands execution to the ``default`` queue, and then (where the task executes)
    loads the test split and the trained weights, prints and uploads a
    classification report, and logs a confusion-matrix figure.
    """
    task = Task.init(project_name="PlantDisease", task_name="model_evaluation", task_type=Task.TaskTypes.testing)
    task.set_packages("./requirements.txt")

    # Parse and connect the CLI arguments BEFORE execute_remotely(): that call
    # terminates the local process, so arguments parsed afterwards would never be
    # recorded on the task, and the remote run would then fail on the required
    # --task_id / --model_type arguments.
    parser = argparse.ArgumentParser(description="Evaluate a trained model from ClearML.")
    parser.add_argument('--task_id', type=str, required=True, help="ClearML Task ID that produced the model.")
    parser.add_argument('--model_type', type=str, required=True, choices=['resnet18', 'cnn'])
    args = parser.parse_args()
    task.connect(args)

    task.execute_remotely(queue_name="default")

    logger = task.get_logger()
    print(f"--- Evaluating Model from Task ID: {args.task_id} ({args.model_type.upper()}) ---")
    cfg = load_config()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    test_split = _load_test_split(cfg)
    test_loader = create_dataloader(test_split, batch_size=32, samples_per_epoch=len(test_split), is_training_set=False)
    class_names = test_split.features['label'].names
    num_classes = len(class_names)

    model = load_model_from_clearml(args.task_id, args.model_type, num_classes, device)
    y_true, y_pred = evaluate_model(model, test_loader, device)

    print("\n--- Generating Reports and Plots ---")
    # Without labels=, sklearn raises when a class is missing from both y_true
    # and y_pred, because target_names would then outnumber the observed labels.
    report_kwargs = dict(labels=list(range(num_classes)), target_names=class_names, zero_division=0)
    report_dict = classification_report(y_true, y_pred, output_dict=True, **report_kwargs)
    report_text = classification_report(y_true, y_pred, **report_kwargs)
    print(report_text)
    task.upload_artifact(name="classification_report", artifact_object=report_dict)

    _log_confusion_matrix(logger, y_true, y_pred, class_names, args.model_type)
    print("[SUCCESS] Evaluation complete. Artifacts logged to ClearML.")
    task.close()
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import.
    main()