Spaces:
Running
Running
| """ | |
| Ensemble: Combines ResNet50, EfficientNet-B0, and ViT using soft voting. | |
| Downloads checkpoints from HuggingFace, runs all three on the test set, | |
| averages softmax probabilities, and reports the ensemble accuracy. | |
| Usage: | |
| python -m src.ensemble | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| from torchvision import models | |
| from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights | |
| from transformers import ViTForImageClassification | |
| from huggingface_hub import hf_hub_download | |
| from sklearn.metrics import accuracy_score, classification_report, confusion_matrix | |
| from tqdm import tqdm | |
| from src.dataset import DateFruitDataset, get_val_transforms | |
| from src.utils import load_config, get_device, seed_everything | |
# HuggingFace Hub repository that hosts the trained checkpoints.
HF_REPO_ID = "Rashidbm/saudi-date-classifier"
# Checkpoint filenames inside HF_REPO_ID, keyed by model nickname.
CHECKPOINTS = {
    "resnet": "arabic_dates_resnet50_best_V2.pth",
    "efficientnet": "efficientnet_best.pth",
    "vit": "vit_best_model.pth",
}
# Date-variety class names, in label-index order (9 classes).
CLASS_NAMES = [
    "Ajwa", "Galaxy", "Medjool", "Meneifi", "Nabtat Ali",
    "Rutab", "Shaishe", "Sokari", "Sugaey",
]
def build_resnet50(num_classes=9, dropout=0.3):
    """Build an ImageNet-pretrained ResNet-50 with a dropout + linear head.

    Args:
        num_classes: Output size of the replacement classifier head.
        dropout: Dropout probability applied before the final linear layer.

    Returns:
        The torchvision ResNet-50 with its ``fc`` layer swapped out.
    """
    backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    head_in = backbone.fc.in_features
    backbone.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(head_in, num_classes))
    return backbone
def build_efficientnet(num_classes=9, dropout=0.3):
    """Build an ImageNet-pretrained EfficientNet-B0 with a fresh classifier.

    Args:
        num_classes: Output size of the replacement classifier head.
        dropout: Dropout probability applied before the final linear layer.

    Returns:
        The torchvision EfficientNet-B0 with its ``classifier`` replaced.
    """
    net = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
    # The stock classifier is Sequential(Dropout, Linear); reuse its input width.
    feature_dim = net.classifier[1].in_features
    net.classifier = nn.Sequential(
        nn.Dropout(p=dropout),
        nn.Linear(feature_dim, num_classes),
    )
    return net
class PretrainedViTClassifier(nn.Module):
    """ViT-Base (patch16, 224, ImageNet-21k) wrapped to return raw logits."""

    def __init__(self, num_classes=9):
        super().__init__()
        # The pretrained head is size-mismatched for our label count;
        # ignore_mismatched_sizes lets HF re-initialize it instead of raising.
        self.backbone = ViTForImageClassification.from_pretrained(
            "google/vit-base-patch16-224-in21k",
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        outputs = self.backbone(x)
        return outputs.logits
def load_checkpoint(model, path, device):
    """Load weights from *path* into *model*, move it to *device*, set eval mode.

    Accepts either a raw state dict or a training checkpoint dict that
    stores the weights under the ``"model_state_dict"`` key.

    Returns:
        The same *model*, now on *device* and in eval mode.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects; these
    # files are downloaded from a remote repo, so only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def load_all_models(device):
    """Download checkpoints from HuggingFace and load all three models.

    Args:
        device: Device the models are moved to (all are set to eval mode).

    Returns:
        Dict mapping "resnet"/"efficientnet"/"vit" to the loaded models.
    """
    print("Downloading checkpoints from HuggingFace...")
    paths = {
        name: hf_hub_download(repo_id=HF_REPO_ID, filename=fname)
        for name, fname in CHECKPOINTS.items()
    }
    print("Loading models...")
    # Derive the class count from CLASS_NAMES (9) instead of repeating the
    # literal, so the label list stays the single source of truth.
    num_classes = len(CLASS_NAMES)
    models_dict = {}
    models_dict["resnet"] = load_checkpoint(
        build_resnet50(num_classes=num_classes), paths["resnet"], device
    )
    models_dict["efficientnet"] = load_checkpoint(
        build_efficientnet(num_classes=num_classes), paths["efficientnet"], device
    )
    models_dict["vit"] = load_checkpoint(
        PretrainedViTClassifier(num_classes=num_classes), paths["vit"], device
    )
    print("All models loaded.")
    return models_dict
def evaluate_single(model, loader, device, name):
    """Evaluate a single model on *loader*.

    Args:
        model: Model (already in eval mode) mapping an image batch to logits.
        loader: Yields ``(images, labels, _)`` batches, iterated in order.
        device: Device the image batches are moved to for the forward pass.
        name: Display name used in the progress bar.

    Returns:
        Tuple ``(accuracy_percent, all_probs, all_labels)`` where
        ``all_probs`` is an (N, C) CPU tensor of softmax probabilities and
        ``all_labels`` the matching (N,) label tensor.
    """
    all_probs = []
    all_labels = []
    # Fix: inference must run under no_grad(); otherwise autograd graphs are
    # built and retained for every batch, wasting memory over the test set.
    with torch.no_grad():
        for images, labels, _ in tqdm(loader, desc=f"Evaluating {name}"):
            images = images.to(device)
            logits = model(images)
            probs = F.softmax(logits, dim=1)
            all_probs.append(probs.cpu())
            all_labels.append(labels)
    all_probs = torch.cat(all_probs)
    all_labels = torch.cat(all_labels)
    preds = all_probs.argmax(dim=1)
    acc = accuracy_score(all_labels.numpy(), preds.numpy()) * 100
    return acc, all_probs, all_labels
def main():
    """Run the soft-voting ensemble evaluation on the held-out test set."""
    config = load_config()
    seed_everything(42)
    device = get_device()
    print(f"Device: {device}")
    # Load all models
    models_dict = load_all_models(device)
    # Load test set; shuffle=False keeps label order identical across the
    # three evaluations so per-model probability rows line up.
    transform = get_val_transforms(config)
    test_dataset = DateFruitDataset("data/test.csv", transform=transform)
    test_loader = DataLoader(
        test_dataset, batch_size=16, shuffle=False, num_workers=0
    )
    print(f"\nTest set: {len(test_dataset)} images")
    # Evaluate each model independently, collecting per-model softmax outputs.
    results = {}
    for name, model in models_dict.items():
        acc, probs, labels = evaluate_single(model, test_loader, device, name)
        results[name] = {"accuracy": acc, "probs": probs, "labels": labels}
    # Soft voting: the ensemble probability is the mean of the model softmaxes.
    ensemble_probs = sum(r["probs"] for r in results.values()) / len(results)
    ensemble_preds = ensemble_probs.argmax(dim=1).numpy()
    true_labels = results["vit"]["labels"].numpy()
    ensemble_acc = accuracy_score(true_labels, ensemble_preds) * 100
    separator = "=" * 50
    print(f"\n{separator}")
    print("INDIVIDUAL vs ENSEMBLE")
    print(separator)
    for name, r in results.items():
        print(f" {name.upper():<15} {r['accuracy']:>6.2f}%")
    print(f" {'ENSEMBLE':<15} {ensemble_acc:>6.2f}%")
    print("\nEnsemble Classification Report:")
    print(classification_report(true_labels, ensemble_preds, target_names=CLASS_NAMES))
    print("Confusion Matrix:")
    print(confusion_matrix(true_labels, ensemble_preds))
# Entry point for `python -m src.ensemble` or direct execution.
if __name__ == "__main__":
    main()