Spaces:
Sleeping
Sleeping
File size: 5,598 Bytes
0b86da8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 | import os
import sys
from pathlib import Path
# Add project root to sys.path
sys.path.append(str(Path(__file__).parent.parent))
import matplotlib.pyplot as plt # noqa: E402
import numpy as np # noqa: E402
import torch # noqa: E402
import yaml # noqa: E402
from torch.utils.data import DataLoader # noqa: E402
from torchvision import transforms # noqa: E402
from src.dataset import TrashDataset # noqa: E402
from src.evaluate import evaluate # noqa: E402
from src.model import DeepCNN, ResNet18Transfer, SimpleCNN # noqa: E402
def load_config(config_path="config.yaml"):
    """Parse a YAML configuration file and return its contents.

    Args:
        config_path: Path of the YAML file to read. Defaults to
            ``config.yaml`` resolved against the current working directory.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (``None`` for
        an empty file).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding (which is not UTF-8 on every system).
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
def _collect_predictions(model, loader, device):
    """Return the arg-max output index of every sample in *loader*, in loader order.

    The model is moved to *device* and switched to eval mode before inference.
    """
    model.to(device)
    model.eval()
    preds = []
    with torch.no_grad():
        for images, _ in loader:
            outputs = model(images.to(device))
            # assumes outputs is (batch, num_classes) — TODO confirm against src.model
            preds.extend(outputs.argmax(dim=1).cpu().numpy())
    return np.array(preds)


def _plot_accuracy_comparison(results, save_dir):
    """Save a bar chart of per-model accuracy to ``save_dir/accuracy_comparison.png``."""
    # Materialize the dict views: plain lists are accepted by every matplotlib version.
    names = list(results)
    accuracies = list(results.values())

    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.bar(
        names,
        accuracies,
        color=["#4C9BE8", "#5DBB63", "#E8714C"],
        width=0.5,
        edgecolor="white",
        linewidth=1.2,
    )
    # Write value directly onto the bars
    for bar, val in zip(bars, accuracies):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.01,
            f"{val*100:.1f}%",
            ha="center",
            va="bottom",
            fontsize=12,
            fontweight="bold",
        )
    ax.set_ylabel("Test Accuracy", fontsize=12)
    ax.set_title("Model Accuracy Comparison — TrashNet", fontsize=14, fontweight="bold")
    ax.set_ylim(0, 1.05)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y*100:.0f}%"))
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    fig.tight_layout()
    # Save through the figure handle (not pyplot's implicit "current figure")
    # and close it afterwards so the figure does not stay alive in pyplot.
    fig.savefig(save_dir / "accuracy_comparison.png", dpi=150)
    plt.close(fig)
    print(f"\nComparison plot saved to {save_dir}/accuracy_comparison.png")


def _plot_failure_analysis(all_preds, class_names, processed_dir, save_dir):
    """Plot up to five test samples that only ResNet18 predicted correctly.

    Reads ``y_test.npy`` (and ``X_test.npy`` when there is something to show)
    from *processed_dir* and writes ``failure_analysis_comparison.png`` into
    *save_dir*.
    """
    y_test = np.load(processed_dir / "y_test.npy")
    resnet_correct = all_preds["ResNet18"] == y_test
    deep_wrong = all_preds["DeepCNN"] != y_test
    simple_wrong = all_preds["SimpleCNN"] != y_test
    interesting_indices = np.where(resnet_correct & deep_wrong & simple_wrong)[0]

    if interesting_indices.size == 0:
        # Say so explicitly instead of silently skipping the plot.
        print("Failure Analysis: Found 0 samples.")
        return

    print(f"Failure Analysis: Found {len(interesting_indices)} samples.")
    X_test = np.load(processed_dir / "X_test.npy")
    shown = interesting_indices[:5]
    # squeeze=False always returns a 2-D axes array, so a single panel needs
    # no special case.
    fig, axes = plt.subplots(1, len(shown), figsize=(4 * len(shown), 4), squeeze=False)
    for ax, idx in zip(axes[0], shown):
        ax.imshow(X_test[idx])
        true_label = class_names[y_test[idx]]
        deep_pred = class_names[all_preds["DeepCNN"][idx]]
        simple_pred = class_names[all_preds["SimpleCNN"][idx]]
        ax.set_title(
            f"True: {true_label}\nResNet: ✓\nDeep: {deep_pred}\nSimple: {simple_pred}",
            fontsize=10,
            pad=4,
        )
        ax.axis("off")
    fig.subplots_adjust(top=0.85, wspace=0.1)
    fig.savefig(save_dir / "failure_analysis_comparison.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Failure analysis plot saved to {save_dir}/failure_analysis_comparison.png")


def run_comparison():
    """Evaluate SimpleCNN, DeepCNN and ResNet18 on the test split and plot the results.

    Reads ``batch_size`` and ``classes`` from the configuration returned by
    ``load_config()``, loads ``X_test.npy`` / ``y_test.npy`` from
    ``data/processed``, and restores each model's checkpoint from ``models/``
    when the file exists (otherwise the model is evaluated with its freshly
    initialized weights and a warning is printed).

    Outputs are written below ``models/comparison``: one sub-directory per
    model (passed to ``evaluate`` as ``save_dir``), ``accuracy_comparison.png``
    and, when applicable, ``failure_analysis_comparison.png``.
    """
    config = load_config()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processed_dir = Path("data/processed")
    save_dir = Path("models/comparison")
    save_dir.mkdir(parents=True, exist_ok=True)

    # Normalization MUST match training
    # NOTE(review): ToPILImage -> ToTensor presumably converts HWC arrays from
    # TrashDataset into CHW float tensors — confirm against src.dataset.
    test_transform = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Load test data with transforms
    test_ds = TrashDataset(
        processed_dir / "X_test.npy", processed_dir / "y_test.npy", transform=test_transform
    )
    # shuffle=False keeps predictions aligned with the order of y_test.npy,
    # which the failure analysis relies on.
    test_loader = DataLoader(test_ds, batch_size=config["batch_size"], shuffle=False)

    num_classes = len(config["classes"])
    models = {
        "SimpleCNN": SimpleCNN(num_classes=num_classes),
        "DeepCNN": DeepCNN(num_classes=num_classes),
        "ResNet18": ResNet18Transfer(num_classes=num_classes, pretrained=False),
    }
    # Map model names to their best saved checkpoints
    checkpoint_map = {
        "SimpleCNN": "models/simplecnn_best.pth",
        "DeepCNN": "models/deepcnn_best.pth",
        "ResNet18": "models/resnet18_best.pth",
    }

    results = {}
    all_preds = {}
    for name, model in models.items():
        print(f"\nEvaluating {name}...")
        model_path = checkpoint_map[name]
        if os.path.exists(model_path):
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            model.load_state_dict(torch.load(model_path, map_location=device))
            print(f"Loaded weights from {model_path}")
        else:
            print(f"Warning: {model_path} not found. Using untrained weights for {name}.")

        _loss, acc = evaluate(model, test_loader, device=device, save_dir=str(save_dir / name))
        results[name] = acc

        # Collect all predictions for failure analysis. This is a second pass
        # over the test set; evaluate() only hands back (loss, accuracy) here.
        all_preds[name] = _collect_predictions(model, test_loader, device)

    # 1. Accuracy Comparison Plot
    _plot_accuracy_comparison(results, save_dir)

    # 2. Failure Analysis
    _plot_failure_analysis(all_preds, config["classes"], processed_dir, save_dir)
# Run the comparison only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_comparison()
|