Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| from pathlib import Path | |
| # Add project root to sys.path | |
| sys.path.append(str(Path(__file__).parent.parent)) | |
| import matplotlib.pyplot as plt # noqa: E402 | |
| import numpy as np # noqa: E402 | |
| import torch # noqa: E402 | |
| import yaml # noqa: E402 | |
| from torch.utils.data import DataLoader # noqa: E402 | |
| from torchvision import transforms # noqa: E402 | |
| from src.dataset import TrashDataset # noqa: E402 | |
| from src.evaluate import evaluate # noqa: E402 | |
| from src.model import DeepCNN, ResNet18Transfer, SimpleCNN # noqa: E402 | |
def load_config(config_path="config.yaml"):
    """Load and return the parsed YAML configuration stored at ``config_path``.

    The file is decoded as UTF-8 explicitly so that parsing does not depend on
    the platform's locale encoding. ``yaml.safe_load`` is used, so the file
    cannot instantiate arbitrary Python objects.

    Args:
        config_path: Path to the YAML file (relative paths resolve against the
            current working directory).

    Returns:
        Whatever ``yaml.safe_load`` produces for the file (a dict for a
        mapping-style config).
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
def _collect_predictions(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and return predicted class ids.

    The loader must not shuffle, so that the returned predictions line up
    index-for-index with the stored test labels.
    """
    model.to(device)
    model.eval()
    batches = []
    with torch.no_grad():
        for images, _ in loader:
            logits = model(images.to(device))
            # argmax returns the first maximal index, same as torch.max(...)[1].
            batches.append(logits.argmax(dim=1).cpu().numpy())
    # np.concatenate([]) raises, so handle an empty loader explicitly.
    if not batches:
        return np.array([], dtype=np.int64)
    return np.concatenate(batches)


def _plot_accuracy(results, out_path):
    """Save a bar chart of per-model test accuracy (fractions in [0, 1]) to ``out_path``."""
    names = list(results)
    accuracies = [results[name] for name in names]

    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.bar(
        names,
        accuracies,
        color=["#4C9BE8", "#5DBB63", "#E8714C"],
        width=0.5,
        edgecolor="white",
        linewidth=1.2,
    )
    # Write each value directly above its bar.
    for bar, val in zip(bars, accuracies):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.01,
            f"{val*100:.1f}%",
            ha="center",
            va="bottom",
            fontsize=12,
            fontweight="bold",
        )
    ax.set_ylabel("Test Accuracy", fontsize=12)
    ax.set_title("Model Accuracy Comparison — TrashNet", fontsize=14, fontweight="bold")
    ax.set_ylim(0, 1.05)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y*100:.0f}%"))
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    # Close explicitly so repeated runs in one process do not accumulate figures.
    plt.close(fig)


def _plot_failure_analysis(images, labels, all_preds, classes, indices, out_path, max_show=5):
    """Save up to ``max_show`` test images that ResNet18 classified correctly but both CNNs missed."""
    num_show = min(max_show, len(indices))
    # squeeze=False always yields a 2-D array of axes, even when num_show == 1.
    fig, axes = plt.subplots(1, num_show, figsize=(4 * num_show, 4), squeeze=False)
    for ax, idx in zip(axes[0], indices[:num_show]):
        ax.imshow(images[idx])
        true_label = classes[labels[idx]]
        deep_pred = classes[all_preds["DeepCNN"][idx]]
        simple_pred = classes[all_preds["SimpleCNN"][idx]]
        ax.set_title(
            f"True: {true_label}\nResNet: ✓\nDeep: {deep_pred}\nSimple: {simple_pred}",
            fontsize=10,
            pad=4,
        )
        ax.axis("off")
    fig.subplots_adjust(top=0.85, wspace=0.1)
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def run_comparison(
    config_path="config.yaml",
    processed_dir="data/processed",
    save_dir="models/comparison",
):
    """Evaluate SimpleCNN, DeepCNN and ResNet18 on the test split and plot comparisons.

    Steps:
      1. Load the config and build a non-shuffled test loader from
         ``X_test.npy`` / ``y_test.npy`` in ``processed_dir``.
      2. For each model, load its best checkpoint if present (otherwise warn and
         keep untrained weights), call ``evaluate`` (passing
         ``save_dir/<model name>`` as its output dir), and collect predictions.
      3. Save ``accuracy_comparison.png`` and, when such samples exist,
         ``failure_analysis_comparison.png`` (ResNet18 right, both CNNs wrong)
         into ``save_dir``.

    Args:
        config_path: YAML config; must provide ``classes`` and ``batch_size``.
        processed_dir: Directory holding ``X_test.npy`` and ``y_test.npy``.
        save_dir: Output directory for plots; created if missing.
    """
    config = load_config(config_path)
    classes = config["classes"]
    num_classes = len(classes)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processed_dir = Path(processed_dir)
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Normalization MUST match training.
    test_transform = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    test_ds = TrashDataset(
        processed_dir / "X_test.npy", processed_dir / "y_test.npy", transform=test_transform
    )
    # shuffle=False keeps predictions aligned with y_test for the failure analysis.
    test_loader = DataLoader(test_ds, batch_size=config["batch_size"], shuffle=False)

    models = {
        "SimpleCNN": SimpleCNN(num_classes=num_classes),
        "DeepCNN": DeepCNN(num_classes=num_classes),
        "ResNet18": ResNet18Transfer(num_classes=num_classes, pretrained=False),
    }
    # Map model names to their best saved checkpoints.
    checkpoint_map = {
        "SimpleCNN": "models/simplecnn_best.pth",
        "DeepCNN": "models/deepcnn_best.pth",
        "ResNet18": "models/resnet18_best.pth",
    }

    results = {}
    all_preds = {}
    for name, model in models.items():
        print(f"\nEvaluating {name}...")
        model_path = checkpoint_map[name]
        if os.path.exists(model_path):
            # NOTE: torch.load unpickles the file; only load trusted checkpoints
            # (weights_only=True is available on torch >= 1.13).
            model.load_state_dict(torch.load(model_path, map_location=device))
            print(f"Loaded weights from {model_path}")
        else:
            print(f"Warning: {model_path} not found. Using untrained weights for {name}.")
        _, acc = evaluate(model, test_loader, device=device, save_dir=str(save_dir / name))
        results[name] = acc
        all_preds[name] = _collect_predictions(model, test_loader, device)

    # 1. Accuracy comparison plot.
    _plot_accuracy(results, save_dir / "accuracy_comparison.png")
    print(f"\nComparison plot saved to {save_dir}/accuracy_comparison.png")

    # 2. Failure analysis: samples ResNet18 gets right but both CNNs get wrong.
    y_test = np.load(processed_dir / "y_test.npy")
    resnet_correct = all_preds["ResNet18"] == y_test
    deep_wrong = all_preds["DeepCNN"] != y_test
    simple_wrong = all_preds["SimpleCNN"] != y_test
    interesting_indices = np.where(resnet_correct & deep_wrong & simple_wrong)[0]
    if len(interesting_indices) == 0:
        print("Failure Analysis: no samples where ResNet18 is right and both CNNs are wrong.")
        return

    print(f"Failure Analysis: Found {len(interesting_indices)} samples.")
    X_test = np.load(processed_dir / "X_test.npy")
    _plot_failure_analysis(
        X_test,
        y_test,
        all_preds,
        classes,
        interesting_indices,
        save_dir / "failure_analysis_comparison.png",
    )
    print(f"Failure analysis plot saved to {save_dir}/failure_analysis_comparison.png")
if __name__ == "__main__":
    # Script entry point: the comparison runs only when executed directly, not on import.
    run_comparison()