"""Evaluate the saved model checkpoints on the test split and plot a comparison.

Running this script produces two images under ``models/comparison``:

* ``accuracy_comparison.png`` - bar chart of test accuracy per model.
* ``failure_analysis_comparison.png`` - up to five test images that ResNet18
  classifies correctly while both DeepCNN and SimpleCNN get them wrong.

All data, config and checkpoint paths are relative to the current working
directory.
"""

import os
import sys
from pathlib import Path

# Add project root to sys.path so the ``src`` package resolves when this file
# is executed directly as a script.
sys.path.append(str(Path(__file__).parent.parent))

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
import yaml  # noqa: E402
from torch.utils.data import DataLoader  # noqa: E402
from torchvision import transforms  # noqa: E402

from src.dataset import TrashDataset  # noqa: E402
from src.evaluate import evaluate  # noqa: E402
from src.model import DeepCNN, ResNet18Transfer, SimpleCNN  # noqa: E402


def load_config(config_path="config.yaml"):
    """Parse the YAML configuration file and return its contents.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file. ``run_comparison``
        reads the ``"batch_size"`` and ``"classes"`` keys from it.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def _collect_predictions(model, loader, device):
    """Return the arg-max class index for every sample yielded by ``loader``."""
    model.to(device)
    model.eval()
    batches = []
    with torch.no_grad():
        for images, _ in loader:
            outputs = model(images.to(device))
            batches.append(outputs.argmax(dim=1).cpu().numpy())
    if not batches:
        return np.array([])
    return np.concatenate(batches)


def _plot_accuracy(results, save_dir):
    """Save a bar chart of ``results`` (model name -> accuracy) into ``save_dir``."""
    names = list(results)
    accuracies = list(results.values())

    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.bar(
        names,
        accuracies,
        color=["#4C9BE8", "#5DBB63", "#E8714C"],
        width=0.5,
        edgecolor="white",
        linewidth=1.2,
    )

    # Write value directly onto the bars
    for bar, val in zip(bars, accuracies):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.01,
            f"{val*100:.1f}%",
            ha="center",
            va="bottom",
            fontsize=12,
            fontweight="bold",
        )

    ax.set_ylabel("Test Accuracy", fontsize=12)
    ax.set_title("Model Accuracy Comparison — TrashNet", fontsize=14, fontweight="bold")
    # Headroom above 1.0 so the label on a near-perfect bar is not clipped.
    ax.set_ylim(0, 1.05)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y*100:.0f}%"))
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    fig.tight_layout()
    fig.savefig(save_dir / "accuracy_comparison.png", dpi=150)
    # Release the figure; pyplot otherwise keeps it alive until process exit.
    plt.close(fig)
    print(f"\nComparison plot saved to {save_dir}/accuracy_comparison.png")


def _plot_failures(class_names, all_preds, processed_dir, save_dir):
    """Plot test images that only ResNet18 classifies correctly (at most five)."""
    y_test = np.load(processed_dir / "y_test.npy")

    resnet_correct = all_preds["ResNet18"] == y_test
    deep_wrong = all_preds["DeepCNN"] != y_test
    simple_wrong = all_preds["SimpleCNN"] != y_test
    interesting_indices = np.where(resnet_correct & deep_wrong & simple_wrong)[0]

    num_interesting = len(interesting_indices)
    if num_interesting == 0:
        print("Failure Analysis: Found 0 samples, skipping plot.")
        return

    print(f"Failure Analysis: Found {num_interesting} samples.")
    # The image array is only needed (and only loaded) when there is something to show.
    X_test = np.load(processed_dir / "X_test.npy")
    num_show = min(5, num_interesting)

    # squeeze=False always yields a 2-D array of axes, so a single-image
    # figure needs no special case.
    fig, axes = plt.subplots(1, num_show, figsize=(4 * num_show, 4), squeeze=False)
    for ax, idx in zip(axes[0], interesting_indices[:num_show]):
        ax.imshow(X_test[idx])
        # int() so labels saved with a non-integer dtype can still index the list.
        true_label = class_names[int(y_test[idx])]
        deep_pred = class_names[int(all_preds["DeepCNN"][idx])]
        simple_pred = class_names[int(all_preds["SimpleCNN"][idx])]
        ax.set_title(
            f"True: {true_label}\nResNet: ✓\nDeep: {deep_pred}\nSimple: {simple_pred}",
            fontsize=10,
            pad=4,
        )
        ax.axis("off")

    fig.subplots_adjust(top=0.85, wspace=0.1)
    fig.savefig(save_dir / "failure_analysis_comparison.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Failure analysis plot saved to {save_dir}/failure_analysis_comparison.png")


def run_comparison():
    """Evaluate every model on the test split and write the comparison plots.

    Reads ``config.yaml``, ``data/processed/X_test.npy`` and
    ``data/processed/y_test.npy``; loads each model's checkpoint from
    ``models/`` when it exists (otherwise the model is evaluated with its
    freshly initialised weights and a warning is printed); then writes the
    accuracy bar chart and the failure-analysis figure into
    ``models/comparison``.
    """
    config = load_config()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processed_dir = Path("data/processed")
    save_dir = Path("models/comparison")
    save_dir.mkdir(parents=True, exist_ok=True)

    # Normalization MUST match training
    test_transform = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Load test data with transforms
    test_ds = TrashDataset(
        processed_dir / "X_test.npy", processed_dir / "y_test.npy", transform=test_transform
    )
    # shuffle=False keeps prediction order aligned with the rows of y_test.npy.
    test_loader = DataLoader(test_ds, batch_size=config["batch_size"], shuffle=False)

    class_names = config["classes"]
    num_classes = len(class_names)
    models = {
        "SimpleCNN": SimpleCNN(num_classes=num_classes),
        "DeepCNN": DeepCNN(num_classes=num_classes),
        "ResNet18": ResNet18Transfer(num_classes=num_classes, pretrained=False),
    }

    # Map model names to their best saved checkpoints
    checkpoint_map = {
        "SimpleCNN": "models/simplecnn_best.pth",
        "DeepCNN": "models/deepcnn_best.pth",
        "ResNet18": "models/resnet18_best.pth",
    }

    results = {}
    all_preds = {}

    for name, model in models.items():
        print(f"\nEvaluating {name}...")
        model_path = checkpoint_map[name]
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
            print(f"Loaded weights from {model_path}")
        else:
            print(f"Warning: {model_path} not found. Using untrained weights for {name}.")

        _loss, acc = evaluate(model, test_loader, device=device, save_dir=str(save_dir / name))
        results[name] = acc

        # Collect all predictions for failure analysis
        all_preds[name] = _collect_predictions(model, test_loader, device)

    # 1. Accuracy Comparison Plot
    _plot_accuracy(results, save_dir)

    # 2. Failure Analysis
    _plot_failures(class_names, all_preds, processed_dir, save_dir)


if __name__ == "__main__":
    run_comparison()