Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import argparse | |
| import os | |
| import sys | |
| import subprocess | |
| from sklearn.metrics import confusion_matrix, classification_report | |
| import seaborn as sns | |
| from src.DataLoader.dataloader import create_dataloader | |
| from src.models.resnet18_finetune import make_resnet18 | |
| from src.models.cnn_model import PlantCNN | |
| from src.utils.config import load_config | |
| from datasets import load_from_disk | |
| from clearml import Task, InputModel | |
def load_model_from_clearml(task_id: str, model_type: str, num_classes: int, device: str):
    """Build a model and load the ``best_model`` artifact of a ClearML task into it.

    Args:
        task_id: ID of the ClearML task whose ``best_model`` artifact holds the weights.
        model_type: ``'resnet18'`` or ``'cnn'`` (case-insensitive); selects the architecture.
        num_classes: Number of output classes passed to the model constructor.
        device: Device used for ``torch.load(map_location=...)`` and ``model.to``.

    Returns:
        The model with the loaded weights, moved to *device* and switched to eval mode.

    Raises:
        ValueError: If *model_type* is not a supported architecture (checked before
            any download is attempted).
        SystemExit: With status 1 if the artifact cannot be retrieved.
    """
    # Resolve the architecture first so an unsupported type fails before any download.
    kind = model_type.lower()
    if kind == 'resnet18':
        model = make_resnet18(num_classes=num_classes)
    elif kind == 'cnn':
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    print(f"[INFO] Loading 'best_model' artifact from Task ID '{task_id}'...")
    try:
        source_task = Task.get_task(task_id=task_id)
        model_path = source_task.artifacts['best_model'].get_local_copy()
        # get_local_copy() returns None instead of raising when the download fails
        # (raise_on_error defaults to False), so treat that as a retrieval failure too.
        if model_path is None:
            raise RuntimeError("artifact download returned no local path")
        print(f"[INFO] Model downloaded to: {model_path}")
    except Exception as e:
        print(f"[FATAL] Could not retrieve artifact from Task {task_id}. Error: {e}")
        sys.exit(1)

    # NOTE(review): torch.load unpickles the file; only load artifacts from trusted tasks.
    checkpoint = torch.load(model_path, map_location=device)
    # Accept either a bare state_dict or a checkpoint mapping that wraps it under
    # 'state_dict'; the isinstance guard avoids a TypeError on non-mapping payloads.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()
    print("[SUCCESS] Model loaded and ready.")
    return model
def evaluate_model(model, loader, device):
    """Run inference over every batch of *loader* and collect labels and predictions.

    The model is put in eval mode for the duration of the pass, so dropout and
    batch-norm layers behave as at inference time. Its previous train/eval mode is
    restored afterwards, even if an exception is raised.

    Args:
        model: A ``torch.nn.Module``; the prediction for each sample is the
            ``argmax`` over dim 1 of its output.
        loader: Iterable of ``(inputs, labels)`` batches. Labels with more than one
            dimension (e.g. one-hot rows) are reduced with ``argmax(dim=1)``.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(y_true, y_pred)``: two 1-D NumPy arrays of equal length, in loader order.
        ``y_pred`` is int64 (from ``argmax``). ``y_true`` keeps the label dtype.
        Both are empty int64 arrays when the loader yields no batches.
    """
    was_training = model.training
    model.eval()
    true_chunks, pred_chunks = [], []
    try:
        with torch.no_grad():
            for inputs, labels in loader:
                if labels.ndim > 1:
                    labels = labels.argmax(dim=1)
                outputs = model(inputs.to(device))
                pred_chunks.append(outputs.argmax(dim=1).cpu().numpy())
                true_chunks.append(labels.cpu().numpy())
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    if not true_chunks:
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
    return np.concatenate(true_chunks), np.concatenate(pred_chunks)
def _ensure_processed_dataset(data_path):
    """Run ``process_dataset.py`` when *data_path* does not exist yet.

    Raises:
        subprocess.CalledProcessError: If the processing script exits non-zero.
        FileNotFoundError: If the script finishes but *data_path* still does not exist.
    """
    if os.path.exists(data_path):
        return
    print(f"[WARN] Data path '{data_path}' not found. Running processing script...")
    # NOTE(review): the script path is relative to the current working directory;
    # confirm this is always launched from the repository root.
    subprocess.check_call([sys.executable, "process_dataset.py"])
    if not os.path.exists(data_path):
        raise FileNotFoundError(
            f"process_dataset.py finished but '{data_path}' still does not exist."
        )


def _log_confusion_matrix(logger, y_true, y_pred, class_names, model_type):
    """Plot the confusion matrix as a heatmap and report it to the ClearML *logger*."""
    # Fix the label set so the matrix is always len(class_names) square and lines up
    # with the tick labels, even when a class never appears in y_true or y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    fig, ax = plt.subplots(figsize=(22, 22))
    try:
        sns.heatmap(cm, annot=False, cmap='Blues', xticklabels=class_names,
                    yticklabels=class_names, ax=ax)
        ax.set_ylabel('True Label', fontsize=14)
        ax.set_xlabel('Predicted Label', fontsize=14)
        ax.set_title(f'Confusion Matrix - {model_type.upper()}', fontsize=16)
        fig.tight_layout()
        logger.report_matplotlib_figure(title="Confusion Matrix", series=model_type,
                                        figure=fig, report_image=True)
    finally:
        # pyplot keeps figures alive until closed; release this large one explicitly.
        plt.close(fig)


def main():
    """Evaluate a model produced by a ClearML task on the dataset's test split.

    Parses ``--task_id`` and ``--model_type``, then enqueues this task on the
    ``default`` queue (when running locally, execute_remotely stops the local
    process). On the executing side it loads the test split and the model, and
    logs a classification report artifact and a confusion-matrix figure to ClearML.
    """
    task = Task.init(project_name="PlantDisease", task_name="model_evaluation",
                     task_type=Task.TaskTypes.testing)
    task.set_packages("./requirements.txt")

    parser = argparse.ArgumentParser(description="Evaluate a trained model from ClearML.")
    parser.add_argument('--task_id', type=str, required=True,
                        help="ClearML Task ID that produced the model.")
    parser.add_argument('--model_type', type=str, required=True, choices=['resnet18', 'cnn'])
    args = parser.parse_args()
    task.connect(args)

    # Enqueue only AFTER the CLI arguments are parsed and connected. execute_remotely()
    # ends the local process, so arguments parsed later are never stored on the task,
    # and the remote run would start without the required --task_id/--model_type.
    task.execute_remotely(queue_name="default")

    logger = task.get_logger()
    print(f"--- Evaluating Model from Task ID: {args.task_id} ({args.model_type.upper()}) ---")

    cfg = load_config()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_path = cfg['data_path']
    _ensure_processed_dataset(data_path)

    ds_dict = load_from_disk(data_path)
    test_split = ds_dict['test']
    test_loader = create_dataloader(test_split, batch_size=32,
                                    samples_per_epoch=len(test_split),
                                    is_training_set=False)
    class_names = test_split.features['label'].names
    num_classes = len(class_names)

    model = load_model_from_clearml(args.task_id, args.model_type, num_classes, device)
    y_true, y_pred = evaluate_model(model, test_loader, device)

    print("\n--- Generating Reports and Plots ---")
    # Pass the full label set so target_names always matches. Without it sklearn
    # raises ValueError when some class is absent from both y_true and y_pred.
    report_kwargs = {
        "labels": list(range(num_classes)),
        "target_names": class_names,
        "zero_division": 0,
    }
    report_dict = classification_report(y_true, y_pred, output_dict=True, **report_kwargs)
    report_text = classification_report(y_true, y_pred, **report_kwargs)
    print(report_text)
    task.upload_artifact(name="classification_report", artifact_object=report_dict)

    _log_confusion_matrix(logger, y_true, y_pred, class_names, args.model_type)

    print("[SUCCESS] Evaluation complete. Artifacts logged to ClearML.")
    task.close()
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import.
    main()