# trash-classifier/src/comparison.py
# Hugging Face file-viewer header (not part of the program): uploaded by alshami-dev,
# commit 0b86da8 ("First Update to the App"), verified, file size 5.6 kB.
import os
import sys
from pathlib import Path
# Add project root to sys.path
sys.path.append(str(Path(__file__).parent.parent))
import matplotlib.pyplot as plt # noqa: E402
import numpy as np # noqa: E402
import torch # noqa: E402
import yaml # noqa: E402
from torch.utils.data import DataLoader # noqa: E402
from torchvision import transforms # noqa: E402
from src.dataset import TrashDataset # noqa: E402
from src.evaluate import evaluate # noqa: E402
from src.model import DeepCNN, ResNet18Transfer, SimpleCNN # noqa: E402
def load_config(config_path="config.yaml"):
    """Read a YAML configuration file and return its parsed contents.

    Args:
        config_path: Path of the YAML file to read. The default,
            ``config.yaml``, is resolved against the current working directory.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        Parse errors raised by ``yaml.safe_load`` propagate unchanged.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, so non-ASCII text in the config reads the same on every OS.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
def run_comparison():
    """Evaluate the three classifiers on the test split and save comparison plots.

    Steps, in order:

    1. Read the configuration via ``load_config`` (``batch_size`` and
       ``classes`` are used).
    2. Build the test ``DataLoader`` from ``data/processed/X_test.npy`` and
       ``data/processed/y_test.npy``.
    3. For SimpleCNN, DeepCNN and ResNet18: load the matching
       ``models/*_best.pth`` checkpoint when it exists (otherwise print a
       warning and keep the freshly constructed weights), run ``evaluate`` and
       collect the predicted class index of every test sample.
    4. Save ``accuracy_comparison.png`` in ``models/comparison``.
    5. Save ``failure_analysis_comparison.png`` with up to five test images
       that ResNet18 gets right while both DeepCNN and SimpleCNN get them wrong.

    All paths are relative to the current working directory. Returns ``None``.
    """
    config = load_config()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processed_dir = Path("data/processed")
    save_dir = Path("models/comparison")
    save_dir.mkdir(parents=True, exist_ok=True)

    # Normalization MUST match training
    test_transform = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Load test data with transforms
    test_ds = TrashDataset(
        processed_dir / "X_test.npy", processed_dir / "y_test.npy", transform=test_transform
    )
    test_loader = DataLoader(test_ds, batch_size=config["batch_size"], shuffle=False)

    num_classes = len(config["classes"])
    models = {
        "SimpleCNN": SimpleCNN(num_classes=num_classes),
        "DeepCNN": DeepCNN(num_classes=num_classes),
        "ResNet18": ResNet18Transfer(num_classes=num_classes, pretrained=False),
    }

    # Map model names to their best saved checkpoints
    checkpoint_map = {
        "SimpleCNN": "models/simplecnn_best.pth",
        "DeepCNN": "models/deepcnn_best.pth",
        "ResNet18": "models/resnet18_best.pth",
    }

    results = {}
    all_preds = {}
    for name, model in models.items():
        print(f"\nEvaluating {name}...")
        model_path = checkpoint_map[name]
        if os.path.exists(model_path):
            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from a trusted source. If the installed torch version supports it,
            # consider passing weights_only=True here.
            model.load_state_dict(torch.load(model_path, map_location=device))
            print(f"Loaded weights from {model_path}")
        else:
            print(f"Warning: {model_path} not found. Using untrained weights for {name}.")

        # Only the accuracy is needed for the comparison chart.
        _, acc = evaluate(model, test_loader, device=device, save_dir=str(save_dir / name))
        results[name] = acc

        # Collect all predictions for failure analysis
        all_preds[name] = _collect_predictions(model, test_loader, device)

    # 1. Accuracy Comparison Plot
    _plot_accuracy_comparison(results, save_dir)

    # 2. Failure Analysis
    _plot_failure_analysis(config["classes"], all_preds, processed_dir, save_dir)


def _collect_predictions(model, loader, device):
    """Return the predicted class index of every sample in ``loader`` as an array."""
    model.to(device)
    model.eval()
    preds = []
    with torch.no_grad():
        for images, _ in loader:
            outputs = model(images.to(device))
            # Index of the largest output along dim 1 is the predicted class.
            preds.extend(torch.max(outputs, 1)[1].cpu().numpy())
    return np.array(preds)


def _plot_accuracy_comparison(results, save_dir):
    """Save a bar chart of ``results`` (model name -> accuracy) into ``save_dir``."""
    names = list(results.keys())
    accuracies = list(results.values())

    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.bar(
        names,
        accuracies,
        color=["#4C9BE8", "#5DBB63", "#E8714C"],
        width=0.5,
        edgecolor="white",
        linewidth=1.2,
    )

    # Write value directly onto the bars
    for bar, val in zip(bars, accuracies):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.01,
            f"{val*100:.1f}%",
            ha="center",
            va="bottom",
            fontsize=12,
            fontweight="bold",
        )

    ax.set_ylabel("Test Accuracy", fontsize=12)
    ax.set_title("Model Accuracy Comparison — TrashNet", fontsize=14, fontweight="bold")
    ax.set_ylim(0, 1.05)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y*100:.0f}%"))
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    fig.tight_layout()
    fig.savefig(save_dir / "accuracy_comparison.png", dpi=150)
    # Release the figure; pyplot otherwise keeps it registered after saving.
    plt.close(fig)
    print(f"\nComparison plot saved to {save_dir}/accuracy_comparison.png")


def _plot_failure_analysis(classes, all_preds, processed_dir, save_dir, max_show=5):
    """Save a grid of test images that only ResNet18 classifies correctly.

    At most ``max_show`` samples are drawn. Nothing is saved when no sample
    satisfies "ResNet18 right, DeepCNN wrong, SimpleCNN wrong".
    """
    y_test = np.load(processed_dir / "y_test.npy")
    resnet_correct = all_preds["ResNet18"] == y_test
    deep_wrong = all_preds["DeepCNN"] != y_test
    simple_wrong = all_preds["SimpleCNN"] != y_test
    interesting_indices = np.where(resnet_correct & deep_wrong & simple_wrong)[0]

    if len(interesting_indices) == 0:
        # Say so explicitly instead of ending silently without a plot.
        print("Failure Analysis: Found 0 samples.")
        return

    print(f"Failure Analysis: Found {len(interesting_indices)} samples.")
    X_test = np.load(processed_dir / "X_test.npy")
    shown = interesting_indices[:max_show]
    num_show = len(shown)

    # squeeze=False always yields a 2-D axes array, so a single-sample grid
    # needs no special case.
    fig, axes = plt.subplots(1, num_show, figsize=(4 * num_show, 4), squeeze=False)
    for ax, idx in zip(axes[0], shown):
        ax.imshow(X_test[idx])
        true_label = classes[y_test[idx]]
        deep_pred = classes[all_preds["DeepCNN"][idx]]
        simple_pred = classes[all_preds["SimpleCNN"][idx]]
        ax.set_title(
            f"True: {true_label}\nResNet: ✓\nDeep: {deep_pred}\nSimple: {simple_pred}",
            fontsize=10,
            pad=4,
        )
        ax.axis("off")

    fig.subplots_adjust(top=0.85, wspace=0.1)
    fig.savefig(save_dir / "failure_analysis_comparison.png", dpi=150, bbox_inches="tight")
    # Release the figure; pyplot otherwise keeps it registered after saving.
    plt.close(fig)
    print(f"Failure analysis plot saved to {save_dir}/failure_analysis_comparison.png")
# Run the comparison only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    run_comparison()