Add ensemble evaluation code (.py and .ipynb)

#4
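
This PR adds a soft-voting ensemble of the three fine-tuned classifiers (ResNet50, EfficientNet-B0, and ViT): each model's softmax output is averaged and the argmax of the mean distribution is taken as the prediction. On the held-out 250-image test split this reaches 97.60% accuracy (see the notebook output below). The same averaging rule can be reused for single-image inference. A minimal sketch, assuming the models, `val_transforms`, `class_names`, and `device` are set up as in `ensemble.py`; `predict_image` is a hypothetical helper, not part of this PR:

```python
import torch
import torch.nn.functional as F
from PIL import Image

@torch.no_grad()
def predict_image(path, model_list, transform, class_names, device):
    # Preprocess one image exactly like the test set.
    x = transform(Image.open(path).convert("RGB")).unsqueeze(0).to(device)

    # Soft voting: average the softmax distributions of all models.
    probs = torch.stack([F.softmax(m(x), dim=1) for m in model_list]).mean(dim=0)
    conf, idx = probs.max(dim=1)
    return class_names[idx.item()], conf.item()

# Example call (model and transform variables come from ensemble.py):
# label, p = predict_image("some_date.jpg",
#                          [resnet_model, efficientnet_model, vit_model],
#                          val_transforms, class_names, device)
```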
Files changed (2)
  1. dates-ensemble.ipynb +1 -0
  2. ensemble.py +195 -0
dates-ensemble.ipynb ADDED
@@ -0,0 +1 @@
+ {"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.12.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[{"sourceType":"datasetVersion","sourceId":1432643,"datasetId":839239,"databundleVersionId":1466035}],"dockerImageVersionId":31329,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"!pip install -q huggingface_hub transformers scikit-learn","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","execution":{"iopub.status.busy":"2026-04-04T18:47:36.671112Z","iopub.execute_input":"2026-04-04T18:47:36.671853Z","iopub.status.idle":"2026-04-04T18:47:42.065147Z","shell.execute_reply.started":"2026-04-04T18:47:36.671820Z","shell.execute_reply":"2026-04-04T18:47:42.064388Z"},"trusted":true},"outputs":[],"execution_count":1},{"cell_type":"code","source":"import os\n\nprint(os.listdir(\"/kaggle/input\"))","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-04T18:47:42.066921Z","iopub.execute_input":"2026-04-04T18:47:42.067165Z","iopub.status.idle":"2026-04-04T18:47:42.071974Z","shell.execute_reply.started":"2026-04-04T18:47:42.067138Z","shell.execute_reply":"2026-04-04T18:47:42.071309Z"}},"outputs":[{"name":"stdout","text":"['datasets']\n","output_type":"stream"}],"execution_count":2},{"cell_type":"code","source":"DATASET_DIR = \"/kaggle/input/datasets/wadhasnalhamdan/date-fruit-image-dataset-in-controlled-environment\" # change if needed\nprint(os.listdir(DATASET_DIR)[:20])","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-04T18:47:42.072870Z","iopub.execute_input":"2026-04-04T18:47:42.073270Z","iopub.status.idle":"2026-04-04T18:47:42.109076Z","shell.execute_reply.started":"2026-04-04T18:47:42.073237Z","shell.execute_reply":"2026-04-04T18:47:42.108383Z"}},"outputs":[{"name":"stdout","text":"['Galaxy', 'Rutab', 'Sugaey', 'Medjool', 'Nabtat Ali', 'Ajwa', 'Sokari', 'Shaishe', 'Meneifi']\n","output_type":"stream"}],"execution_count":3},{"cell_type":"code","source":"import os\nimport random\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torch.utils.data import DataLoader, Subset\nfrom torchvision import datasets, transforms, models\nfrom torchvision.models import efficientnet_b0, EfficientNet_B0_Weights\nfrom transformers import ViTForImageClassification\nfrom huggingface_hub import hf_hub_download\n\nfrom sklearn.metrics import accuracy_score, classification_report, confusion_matrix\n\nSEED = 42\nIMAGE_SIZE = 224\nBATCH_SIZE = 32\nNUM_WORKERS = 2\n\nHF_REPO_ID = \"Rashidbm/saudi-date-classifier\"\n\nRESNET_FILE = \"arabic_dates_resnet50_best_V2.pth\"\nEFFICIENTNET_FILE = \"efficientnet_best.pth\"\nVIT_FILE = \"vit_best_model.pth\"\n\nEXPECTED_CLASSES = [\n \"Ajwa\", \"Galaxy\", \"Medjool\", \"Meneifi\", \"Nabtat Ali\",\n \"Rutab\", \"Shaishe\", \"Sokari\", \"Sugaey\"\n]\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(\"Using device:\", 
device)\n\nrandom.seed(SEED)\nnp.random.seed(SEED)\ntorch.manual_seed(SEED)\ntorch.cuda.manual_seed_all(SEED)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-04T18:47:42.110340Z","iopub.execute_input":"2026-04-04T18:47:42.110560Z","iopub.status.idle":"2026-04-04T18:48:13.404912Z","shell.execute_reply.started":"2026-04-04T18:47:42.110540Z","shell.execute_reply":"2026-04-04T18:48:13.404307Z"}},"outputs":[{"name":"stdout","text":"Using device: cuda\n","output_type":"stream"}],"execution_count":4},{"cell_type":"code","source":"val_transforms = transforms.Compose([\n transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),\n transforms.ToTensor(),\n transforms.Normalize([0.485, 0.456, 0.406],\n [0.229, 0.224, 0.225]),\n])\n\nfull_dataset = datasets.ImageFolder(DATASET_DIR, transform=val_transforms)\n\nclass_names = full_dataset.classes\nprint(\"Classes found:\", class_names)\n\nif class_names != EXPECTED_CLASSES:\n raise ValueError(f\"Class order mismatch.\\nExpected: {EXPECTED_CLASSES}\\nFound: {class_names}\")\n\nn = len(full_dataset)\nn_train = int(0.70 * n)\nn_val = int(0.15 * n)\nn_test = n - n_train - n_val\n\ngenerator = torch.Generator().manual_seed(SEED)\nindices = torch.randperm(n, generator=generator).tolist()\n\ntrain_idx = indices[:n_train]\nval_idx = indices[n_train:n_train + n_val]\ntest_idx = indices[n_train + n_val:]\n\ntest_dataset = Subset(full_dataset, test_idx)\n\ntest_loader = DataLoader(\n test_dataset,\n batch_size=BATCH_SIZE,\n shuffle=False,\n num_workers=NUM_WORKERS,\n pin_memory=True\n)\n\nprint(\"Total images:\", n)\nprint(\"Test images:\", len(test_dataset))","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-04T18:48:28.042898Z","iopub.execute_input":"2026-04-04T18:48:28.043430Z","iopub.status.idle":"2026-04-04T18:48:34.030824Z","shell.execute_reply.started":"2026-04-04T18:48:28.043391Z","shell.execute_reply":"2026-04-04T18:48:34.030214Z"}},"outputs":[{"name":"stdout","text":"Classes found: ['Ajwa', 'Galaxy', 'Medjool', 'Meneifi', 'Nabtat Ali', 'Rutab', 'Shaishe', 'Sokari', 'Sugaey']\nTotal images: 1658\nTest images: 250\n","output_type":"stream"}],"execution_count":5},{"cell_type":"code","source":"# ResNet50\ndef build_resnet50(num_classes=9, dropout=0.3):\n model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)\n model.fc = nn.Sequential(\n nn.Dropout(dropout),\n nn.Linear(model.fc.in_features, num_classes)\n )\n return model\n\n\n# EfficientNet-B0\ndef build_efficientnet(num_classes=9, pretrained=True, dropout=0.3):\n weights = EfficientNet_B0_Weights.DEFAULT if pretrained else None\n model = efficientnet_b0(weights=weights)\n in_features = model.classifier[1].in_features\n model.classifier = nn.Sequential(\n nn.Dropout(p=dropout),\n nn.Linear(in_features, num_classes)\n )\n return model\n\n\n# ViT wrapper\nclass PretrainedViTClassifier(nn.Module):\n def __init__(self, model_name=\"google/vit-base-patch16-224-in21k\", num_classes=9, dropout=0.1):\n super().__init__()\n self.backbone = ViTForImageClassification.from_pretrained(\n model_name,\n num_labels=num_classes,\n ignore_mismatched_sizes=True,\n )\n\n def forward(self, x):\n outputs = self.backbone(x)\n return 
outputs.logits","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-04T18:48:34.032014Z","iopub.execute_input":"2026-04-04T18:48:34.032466Z","iopub.status.idle":"2026-04-04T18:48:34.038844Z","shell.execute_reply.started":"2026-04-04T18:48:34.032440Z","shell.execute_reply":"2026-04-04T18:48:34.038275Z"}},"outputs":[],"execution_count":6},{"cell_type":"code","source":"resnet_path = hf_hub_download(repo_id=HF_REPO_ID, filename=RESNET_FILE)\nefficientnet_path = hf_hub_download(repo_id=HF_REPO_ID, filename=EFFICIENTNET_FILE)\nvit_path = hf_hub_download(repo_id=HF_REPO_ID, filename=VIT_FILE)\n\nprint(\"Downloaded:\")\nprint(resnet_path)\nprint(efficientnet_path)\nprint(vit_path)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-04T18:48:38.083104Z","iopub.execute_input":"2026-04-04T18:48:38.083401Z","iopub.status.idle":"2026-04-04T18:48:48.450459Z","shell.execute_reply.started":"2026-04-04T18:48:38.083374Z","shell.execute_reply":"2026-04-04T18:48:48.449746Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"arabic_dates_resnet50_best_V2.pth: 0%| | 0.00/94.4M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"27ebeafe729a4edfb7144192c2a8d16b"}},"metadata":{}},{"name":"stderr","text":"Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"efficientnet_best.pth: 0%| | 0.00/16.4M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"dad4ce1b503640c0929067b313b39ff7"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"vit_best_model.pth: 0%| | 0.00/1.03G [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"63eb1ed959d84808b44d8a5474b1d3d0"}},"metadata":{}},{"name":"stdout","text":"Downloaded:\n/root/.cache/huggingface/hub/models--Rashidbm--saudi-date-classifier/snapshots/f2814ae770f1331d58759c309b2e13baf2440e43/arabic_dates_resnet50_best_V2.pth\n/root/.cache/huggingface/hub/models--Rashidbm--saudi-date-classifier/snapshots/f2814ae770f1331d58759c309b2e13baf2440e43/efficientnet_best.pth\n/root/.cache/huggingface/hub/models--Rashidbm--saudi-date-classifier/snapshots/f2814ae770f1331d58759c309b2e13baf2440e43/vit_best_model.pth\n","output_type":"stream"}],"execution_count":7},{"cell_type":"code","source":"def load_checkpoint_flex(model, path):\n ckpt = torch.load(path, map_location=device)\n\n if isinstance(ckpt, dict) and \"model_state_dict\" in ckpt:\n model.load_state_dict(ckpt[\"model_state_dict\"])\n else:\n model.load_state_dict(ckpt)\n\n model.to(device)\n model.eval()\n return model\n\n\nresnet_model = build_resnet50(num_classes=len(class_names), dropout=0.3)\nresnet_model = load_checkpoint_flex(resnet_model, resnet_path)\n\nefficientnet_model = build_efficientnet(num_classes=len(class_names), pretrained=True, dropout=0.3)\nefficientnet_model = load_checkpoint_flex(efficientnet_model, efficientnet_path)\n\nvit_model = PretrainedViTClassifier(\n model_name=\"google/vit-base-patch16-224-in21k\",\n num_classes=len(class_names),\n dropout=0.1,\n)\nvit_model = load_checkpoint_flex(vit_model, vit_path)\n\nprint(\"All models loaded 
successfully.\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-04T18:48:48.451812Z","iopub.execute_input":"2026-04-04T18:48:48.452116Z","iopub.status.idle":"2026-04-04T18:48:54.602232Z","shell.execute_reply.started":"2026-04-04T18:48:48.452091Z","shell.execute_reply":"2026-04-04T18:48:54.601453Z"}},"outputs":[{"name":"stdout","text":"Downloading: \"https://download.pytorch.org/models/resnet50-11ad3fa6.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth\n","output_type":"stream"},{"name":"stderr","text":"100%|██████████| 97.8M/97.8M [00:00<00:00, 191MB/s] \n","output_type":"stream"},{"name":"stdout","text":"Downloading: \"https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth\" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth\n","output_type":"stream"},{"name":"stderr","text":"100%|██████████| 20.5M/20.5M [00:00<00:00, 138MB/s] \n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"config.json: 0%| | 0.00/502 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"f487a0690ff44081a2ea4b03c273bed2"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model.safetensors: 0%| | 0.00/346M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"aef8ff50a9ec438897c3a2347d16c6e6"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Loading weights: 0%| | 0/198 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"21123476749c495daf62a2bbbfeef92a"}},"metadata":{}},{"name":"stderr","text":"ViTForImageClassification LOAD REPORT from: google/vit-base-patch16-224-in21k\nKey | Status | \n--------------------+------------+-\npooler.dense.bias | UNEXPECTED | \npooler.dense.weight | UNEXPECTED | \nclassifier.bias | MISSING | \nclassifier.weight | MISSING | \n\nNotes:\n- UNEXPECTED\t:can be ignored when loading from different task/architecture; not ok if you expect identical arch.\n- MISSING\t:those params were newly initialized because missing from the checkpoint. 
Consider training on your downstream task.\n","output_type":"stream"},{"name":"stdout","text":"All models loaded successfully.\n","output_type":"stream"}],"execution_count":8},{"cell_type":"code","source":"@torch.no_grad()\ndef evaluate_ensemble(resnet_model, efficientnet_model, vit_model, loader, class_names):\n all_preds = []\n all_labels = []\n\n for images, labels in loader:\n images = images.to(device)\n labels = labels.to(device)\n\n logits_resnet = resnet_model(images)\n logits_eff = efficientnet_model(images)\n logits_vit = vit_model(images)\n\n probs_resnet = F.softmax(logits_resnet, dim=1)\n probs_eff = F.softmax(logits_eff, dim=1)\n probs_vit = F.softmax(logits_vit, dim=1)\n\n ensemble_probs = (probs_resnet + probs_eff + probs_vit) / 3.0\n preds = ensemble_probs.argmax(dim=1)\n\n all_preds.extend(preds.cpu().tolist())\n all_labels.extend(labels.cpu().tolist())\n\n acc = accuracy_score(all_labels, all_preds)\n print(f\"\\nEnsemble Test Accuracy: {acc * 100:.2f}%\\n\")\n\n print(\"Classification Report:\")\n print(classification_report(all_labels, all_preds, target_names=class_names))\n\n print(\"Confusion Matrix:\")\n print(confusion_matrix(all_labels, all_preds))\n\n\nevaluate_ensemble(resnet_model, efficientnet_model, vit_model, test_loader, class_names)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-04T18:49:32.033994Z","iopub.execute_input":"2026-04-04T18:49:32.034745Z","iopub.status.idle":"2026-04-04T18:49:51.150293Z","shell.execute_reply.started":"2026-04-04T18:49:32.034716Z","shell.execute_reply":"2026-04-04T18:49:51.149462Z"}},"outputs":[{"name":"stdout","text":"\nEnsemble Test Accuracy: 97.60%\n\nClassification Report:\n precision recall f1-score support\n\n Ajwa 1.00 1.00 1.00 33\n Galaxy 1.00 0.88 0.93 24\n Medjool 1.00 0.91 0.95 23\n Meneifi 0.88 1.00 0.94 44\n Nabtat Ali 1.00 1.00 1.00 25\n Rutab 1.00 1.00 1.00 20\n Shaishe 1.00 1.00 1.00 16\n Sokari 1.00 1.00 1.00 42\n Sugaey 1.00 0.96 0.98 23\n\n accuracy 0.98 250\n macro avg 0.99 0.97 0.98 250\nweighted avg 0.98 0.98 0.98 250\n\nConfusion Matrix:\n[[33 0 0 0 0 0 0 0 0]\n [ 0 21 0 3 0 0 0 0 0]\n [ 0 0 21 2 0 0 0 0 0]\n [ 0 0 0 44 0 0 0 0 0]\n [ 0 0 0 0 25 0 0 0 0]\n [ 0 0 0 0 0 20 0 0 0]\n [ 0 0 0 0 0 0 16 0 0]\n [ 0 0 0 0 0 0 0 42 0]\n [ 0 0 0 1 0 0 0 0 22]]\n","output_type":"stream"}],"execution_count":9},{"cell_type":"code","source":"","metadata":{"trusted":true},"outputs":[],"execution_count":null}]}
ensemble.py ADDED
@@ -0,0 +1,195 @@
+ !pip install -q huggingface_hub transformers scikit-learn
+
+ import os
+
+ print(os.listdir("/kaggle/input"))
+
+ DATASET_DIR = "/kaggle/input/datasets/wadhasnalhamdan/date-fruit-image-dataset-in-controlled-environment"  # change if needed
+ print(os.listdir(DATASET_DIR)[:20])
+
+ import random
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from torch.utils.data import DataLoader, Subset
+ from torchvision import datasets, transforms, models
+ from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
+ from transformers import ViTForImageClassification
+ from huggingface_hub import hf_hub_download
+
+ from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
+
+ SEED = 42
+ IMAGE_SIZE = 224
+ BATCH_SIZE = 32
+ NUM_WORKERS = 2
+
+ HF_REPO_ID = "Rashidbm/saudi-date-classifier"
+
+ RESNET_FILE = "arabic_dates_resnet50_best_V2.pth"
+ EFFICIENTNET_FILE = "efficientnet_best.pth"
+ VIT_FILE = "vit_best_model.pth"
+
+ EXPECTED_CLASSES = [
+     "Ajwa", "Galaxy", "Medjool", "Meneifi", "Nabtat Ali",
+     "Rutab", "Shaishe", "Sokari", "Sugaey"
+ ]
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print("Using device:", device)
+
+ random.seed(SEED)
+ np.random.seed(SEED)
+ torch.manual_seed(SEED)
+ torch.cuda.manual_seed_all(SEED)
+
+ # Deterministic test-time preprocessing only (no augmentation).
+ val_transforms = transforms.Compose([
+     transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406],
+                          [0.229, 0.224, 0.225]),
+ ])
+
+ full_dataset = datasets.ImageFolder(DATASET_DIR, transform=val_transforms)
+
+ class_names = full_dataset.classes
+ print("Classes found:", class_names)
+
+ if class_names != EXPECTED_CLASSES:
+     raise ValueError(f"Class order mismatch.\nExpected: {EXPECTED_CLASSES}\nFound: {class_names}")
+
+ # Recreate the seeded 70/15/15 split so the test indices match training time.
+ n = len(full_dataset)
+ n_train = int(0.70 * n)
+ n_val = int(0.15 * n)
+ n_test = n - n_train - n_val
+
+ generator = torch.Generator().manual_seed(SEED)
+ indices = torch.randperm(n, generator=generator).tolist()
+
+ train_idx = indices[:n_train]
+ val_idx = indices[n_train:n_train + n_val]
+ test_idx = indices[n_train + n_val:]
+
+ test_dataset = Subset(full_dataset, test_idx)
+
+ test_loader = DataLoader(
+     test_dataset,
+     batch_size=BATCH_SIZE,
+     shuffle=False,
+     num_workers=NUM_WORKERS,
+     pin_memory=True
+ )
+
+ print("Total images:", n)
+ print("Test images:", len(test_dataset))
+
+ # ResNet50
+ def build_resnet50(num_classes=9, dropout=0.3):
+     model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
+     model.fc = nn.Sequential(
+         nn.Dropout(dropout),
+         nn.Linear(model.fc.in_features, num_classes)
+     )
+     return model
+
+
+ # EfficientNet-B0
+ def build_efficientnet(num_classes=9, pretrained=True, dropout=0.3):
+     weights = EfficientNet_B0_Weights.DEFAULT if pretrained else None
+     model = efficientnet_b0(weights=weights)
+     in_features = model.classifier[1].in_features
+     model.classifier = nn.Sequential(
+         nn.Dropout(p=dropout),
+         nn.Linear(in_features, num_classes)
+     )
+     return model
+
+
+ # ViT wrapper
+ class PretrainedViTClassifier(nn.Module):
+     # Note: the dropout argument is not used by this wrapper; it is kept for a
+     # uniform interface with the other builders.
+     def __init__(self, model_name="google/vit-base-patch16-224-in21k", num_classes=9, dropout=0.1):
+         super().__init__()
+         self.backbone = ViTForImageClassification.from_pretrained(
+             model_name,
+             num_labels=num_classes,
+             ignore_mismatched_sizes=True,
+         )
+
+     def forward(self, x):
+         outputs = self.backbone(x)
+         return outputs.logits
+
+ resnet_path = hf_hub_download(repo_id=HF_REPO_ID, filename=RESNET_FILE)
+ efficientnet_path = hf_hub_download(repo_id=HF_REPO_ID, filename=EFFICIENTNET_FILE)
+ vit_path = hf_hub_download(repo_id=HF_REPO_ID, filename=VIT_FILE)
+
+ print("Downloaded:")
+ print(resnet_path)
+ print(efficientnet_path)
+ print(vit_path)
+
+ # Accept both raw state_dicts and {"model_state_dict": ...} checkpoints.
+ def load_checkpoint_flex(model, path):
+     ckpt = torch.load(path, map_location=device)
+
+     if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
+         model.load_state_dict(ckpt["model_state_dict"])
+     else:
+         model.load_state_dict(ckpt)
+
+     model.to(device)
+     model.eval()
+     return model
+
+
+ resnet_model = build_resnet50(num_classes=len(class_names), dropout=0.3)
+ resnet_model = load_checkpoint_flex(resnet_model, resnet_path)
+
+ efficientnet_model = build_efficientnet(num_classes=len(class_names), pretrained=True, dropout=0.3)
+ efficientnet_model = load_checkpoint_flex(efficientnet_model, efficientnet_path)
+
+ vit_model = PretrainedViTClassifier(
+     model_name="google/vit-base-patch16-224-in21k",
+     num_classes=len(class_names),
+     dropout=0.1,
+ )
+ vit_model = load_checkpoint_flex(vit_model, vit_path)
+
+ print("All models loaded successfully.")
+
+ @torch.no_grad()
+ def evaluate_ensemble(resnet_model, efficientnet_model, vit_model, loader, class_names):
+     all_preds = []
+     all_labels = []
+
+     for images, labels in loader:
+         images = images.to(device)
+         labels = labels.to(device)
+
+         logits_resnet = resnet_model(images)
+         logits_eff = efficientnet_model(images)
+         logits_vit = vit_model(images)
+
+         probs_resnet = F.softmax(logits_resnet, dim=1)
+         probs_eff = F.softmax(logits_eff, dim=1)
+         probs_vit = F.softmax(logits_vit, dim=1)
+
+         # Soft voting: average the three softmax distributions, then take the argmax.
+         ensemble_probs = (probs_resnet + probs_eff + probs_vit) / 3.0
+         preds = ensemble_probs.argmax(dim=1)
+
+         all_preds.extend(preds.cpu().tolist())
+         all_labels.extend(labels.cpu().tolist())
+
+     acc = accuracy_score(all_labels, all_preds)
+     print(f"\nEnsemble Test Accuracy: {acc * 100:.2f}%\n")
+
+     print("Classification Report:")
+     print(classification_report(all_labels, all_preds, target_names=class_names))
+
+     print("Confusion Matrix:")
+     print(confusion_matrix(all_labels, all_preds))
+
+
+ evaluate_ensemble(resnet_model, efficientnet_model, vit_model, test_loader, class_names)
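
A possible follow-up (not part of this PR): the uniform average gives each model equal say; if per-model validation accuracies differ, a weighted soft vote is a drop-in change. A minimal sketch, where the weights are placeholders, not measured values:

```python
import torch
import torch.nn.functional as F

def weighted_ensemble_probs(logits_list, weights):
    # Stack per-model softmax outputs: (num_models, batch, num_classes).
    probs = torch.stack([F.softmax(l, dim=1) for l in logits_list])
    w = torch.as_tensor(weights, dtype=probs.dtype, device=probs.device)
    w = (w / w.sum()).view(-1, 1, 1)   # normalize and broadcast over batch/classes
    return (probs * w).sum(dim=0)      # (batch, num_classes)

# e.g. weighted_ensemble_probs([logits_resnet, logits_eff, logits_vit], [0.3, 0.3, 0.4])
```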