# File size: 5,598 Bytes
# 0b86da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import sys
from pathlib import Path

# Add project root to sys.path
sys.path.append(str(Path(__file__).parent.parent))

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
import yaml  # noqa: E402
from torch.utils.data import DataLoader  # noqa: E402
from torchvision import transforms  # noqa: E402

from src.dataset import TrashDataset  # noqa: E402
from src.evaluate import evaluate  # noqa: E402
from src.model import DeepCNN, ResNet18Transfer, SimpleCNN  # noqa: E402


def load_config(config_path="config.yaml"):
    """Load the YAML configuration file at ``config_path``.

    The file is parsed with ``yaml.safe_load``, so it cannot instantiate
    arbitrary Python objects. The parsed value is returned as-is (``None``
    for an empty file).

    Args:
        config_path: Path to the YAML file; relative paths resolve against
            the current working directory.
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def _predict(model, loader, device):
    """Return the argmax class index for every sample in ``loader``, in loader order."""
    model.to(device)
    model.eval()
    batches = []
    with torch.no_grad():
        for images, _ in loader:
            outputs = model(images.to(device))
            batches.append(outputs.argmax(dim=1).cpu().numpy())
    # Keep an integer dtype even when the loader yields no batches.
    return np.concatenate(batches) if batches else np.array([], dtype=np.int64)


def _evaluate_models(models, checkpoint_map, test_loader, device, save_dir):
    """Load each model's checkpoint (if present), evaluate it, and collect predictions.

    Returns:
        A ``(results, all_preds)`` pair: ``results`` maps model name to the
        accuracy returned by ``evaluate``; ``all_preds`` maps model name to a
        1-D array of predicted class indices in ``test_loader`` order.
    """
    results = {}
    all_preds = {}
    for name, model in models.items():
        print(f"\nEvaluating {name}...")
        model_path = checkpoint_map[name]

        if os.path.exists(model_path):
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            model.load_state_dict(torch.load(model_path, map_location=device))
            print(f"Loaded weights from {model_path}")
        else:
            print(f"Warning: {model_path} not found. Using untrained weights for {name}.")

        _, acc = evaluate(model, test_loader, device=device, save_dir=str(save_dir / name))
        results[name] = acc

        # Second pass: keep every prediction for the cross-model failure analysis.
        all_preds[name] = _predict(model, test_loader, device)
    return results, all_preds


def _plot_accuracy(results, save_dir):
    """Save a bar chart of per-model test accuracy to ``save_dir/accuracy_comparison.png``."""
    names = list(results)
    accuracies = list(results.values())

    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.bar(
        names,
        accuracies,
        color=["#4C9BE8", "#5DBB63", "#E8714C"],
        width=0.5,
        edgecolor="white",
        linewidth=1.2,
    )

    # Write value directly onto the bars
    for bar, val in zip(bars, accuracies):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.01,
            f"{val*100:.1f}%",
            ha="center",
            va="bottom",
            fontsize=12,
            fontweight="bold",
        )

    ax.set_ylabel("Test Accuracy", fontsize=12)
    ax.set_title("Model Accuracy Comparison — TrashNet", fontsize=14, fontweight="bold")
    ax.set_ylim(0, 1.05)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y*100:.0f}%"))
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    # Operate on this figure explicitly; evaluate() may have created other figures.
    fig.tight_layout()
    fig.savefig(save_dir / "accuracy_comparison.png", dpi=150)
    plt.close(fig)
    print(f"\nComparison plot saved to {save_dir}/accuracy_comparison.png")


def _plot_failure_analysis(all_preds, classes, processed_dir, save_dir, max_show=5):
    """Plot up to ``max_show`` test samples that ResNet18 gets right but both CNNs get wrong.

    Relies on ``all_preds`` being in the same order as ``y_test.npy``, which
    holds because the test loader is built with ``shuffle=False``.
    """
    y_test = np.load(processed_dir / "y_test.npy")
    resnet_correct = all_preds["ResNet18"] == y_test
    deep_wrong = all_preds["DeepCNN"] != y_test
    simple_wrong = all_preds["SimpleCNN"] != y_test

    interesting_indices = np.flatnonzero(resnet_correct & deep_wrong & simple_wrong)
    if interesting_indices.size == 0:
        print("Failure Analysis: No samples where only ResNet18 is correct.")
        return

    print(f"Failure Analysis: Found {len(interesting_indices)} samples.")
    # NOTE(review): assumes X_test holds displayable images (e.g. HWC uint8) — confirm.
    X_test = np.load(processed_dir / "X_test.npy")

    num_show = min(max_show, len(interesting_indices))
    # squeeze=False always yields a 2-D axes array, so a single panel needs no special case.
    fig, axes = plt.subplots(1, num_show, figsize=(4 * num_show, 4), squeeze=False)

    for ax, idx in zip(axes[0], interesting_indices[:num_show]):
        ax.imshow(X_test[idx])
        true_label = classes[y_test[idx]]
        deep_pred = classes[all_preds["DeepCNN"][idx]]
        simple_pred = classes[all_preds["SimpleCNN"][idx]]
        ax.set_title(
            f"True: {true_label}\nResNet: ✓\nDeep: {deep_pred}\nSimple: {simple_pred}",
            fontsize=10,
            pad=4,
        )
        ax.axis("off")

    fig.subplots_adjust(top=0.85, wspace=0.1)
    fig.savefig(save_dir / "failure_analysis_comparison.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Failure analysis plot saved to {save_dir}/failure_analysis_comparison.png")


def run_comparison():
    """Evaluate SimpleCNN, DeepCNN and ResNet18 on the test split and save comparison plots.

    Inputs:
        - ``config.yaml`` (reads ``batch_size`` and ``classes``),
        - ``data/processed/X_test.npy`` and ``data/processed/y_test.npy``,
        - checkpoints under ``models/``; a missing checkpoint triggers a warning
          and the model is evaluated with untrained weights.

    Outputs are written under ``models/comparison``:
        - a per-model ``save_dir`` is passed to ``evaluate``,
        - ``accuracy_comparison.png``,
        - ``failure_analysis_comparison.png``, only when at least one sample is
          classified correctly by ResNet18 and incorrectly by both CNNs.
    """
    config = load_config()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processed_dir = Path("data/processed")
    save_dir = Path("models/comparison")
    save_dir.mkdir(parents=True, exist_ok=True)

    # Normalization MUST match training
    test_transform = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Load test data with transforms; shuffle=False keeps predictions aligned with y_test.
    test_ds = TrashDataset(
        processed_dir / "X_test.npy", processed_dir / "y_test.npy", transform=test_transform
    )
    test_loader = DataLoader(test_ds, batch_size=config["batch_size"], shuffle=False)

    num_classes = len(config["classes"])
    models = {
        "SimpleCNN": SimpleCNN(num_classes=num_classes),
        "DeepCNN": DeepCNN(num_classes=num_classes),
        "ResNet18": ResNet18Transfer(num_classes=num_classes, pretrained=False),
    }

    # Map model names to their best saved checkpoints
    checkpoint_map = {
        "SimpleCNN": "models/simplecnn_best.pth",
        "DeepCNN": "models/deepcnn_best.pth",
        "ResNet18": "models/resnet18_best.pth",
    }

    results, all_preds = _evaluate_models(models, checkpoint_map, test_loader, device, save_dir)

    # 1. Accuracy Comparison Plot
    _plot_accuracy(results, save_dir)

    # 2. Failure Analysis
    _plot_failure_analysis(all_preds, config["classes"], processed_dir, save_dir)


# Run the comparison only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    run_comparison()