Spaces:
Sleeping
Sleeping
update
Browse files
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
"""
|
| 2 |
Simple Demo for Pest and Disease Classification
|
| 3 |
For Hugging Face Space Deployment
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import torch
|
|
@@ -8,6 +9,8 @@ from PIL import Image
|
|
| 8 |
import json
|
| 9 |
import gradio as gr
|
| 10 |
from torchvision import transforms
|
|
|
|
|
|
|
| 11 |
|
| 12 |
from model import create_model
|
| 13 |
|
|
@@ -62,8 +65,94 @@ class PestDiseasePredictor:
|
|
| 62 |
return dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
|
| 63 |
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
# ========== For Hugging Face Space ==========
|
| 66 |
-
checkpoint_path = "checkpoints/
|
| 67 |
label_mapping_path = "label_mapping.json"
|
| 68 |
backbone = 'efficientnet_b3'
|
| 69 |
device = "cuda"
|
|
|
|
| 1 |
"""
|
| 2 |
Simple Demo for Pest and Disease Classification
|
| 3 |
For Hugging Face Space Deployment
|
| 4 |
+
Supports both single model and ensemble prediction
|
| 5 |
"""
|
| 6 |
|
| 7 |
import torch
|
|
|
|
| 9 |
import json
|
| 10 |
import gradio as gr
|
| 11 |
from torchvision import transforms
|
| 12 |
+
import numpy as np
|
| 13 |
+
from pathlib import Path
|
| 14 |
|
| 15 |
from model import create_model
|
| 16 |
|
|
|
|
| 65 |
return dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
|
| 66 |
|
| 67 |
|
| 68 |
class EnsemblePredictor:
    """Ensemble predictor using weighted soft voting.

    Each member model's softmax probabilities are scaled by its normalized
    weight and summed, so the ensemble output is itself a probability
    distribution over the classes.
    """

    def __init__(self, checkpoint_paths, weights, label_mapping_path,
                 backbone='efficientnet_b3', device='cuda'):
        """Load the ensemble members and prepare preprocessing.

        Args:
            checkpoint_paths: paths to model checkpoints, one per member.
            weights: per-model voting weights (normalized to sum to 1).
            label_mapping_path: JSON file with 'id_to_label' and 'num_classes'.
            backbone: backbone architecture name passed to create_model.
            device: preferred device; falls back to CPU if CUDA is unavailable.

        Raises:
            ValueError: if checkpoint_paths and weights differ in length,
                or the weights do not sum to a positive value.
            FileNotFoundError: if any checkpoint file is missing.
        """
        # Fail fast on a paths/weights mismatch: zip() truncation later would
        # otherwise silently drop models or mis-weight the vote.
        if len(checkpoint_paths) != len(weights):
            raise ValueError(
                f"Got {len(checkpoint_paths)} checkpoints but {len(weights)} weights"
            )

        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Normalize weights to sum to 1 (guard against a degenerate sum,
        # which would produce NaN/inf weights).
        weights = np.asarray(weights, dtype=np.float64)
        total = weights.sum()
        if total <= 0:
            raise ValueError("Ensemble weights must sum to a positive value")
        self.weights = weights / total

        # Load label mapping (JSON keys are strings; class ids must be ints)
        with open(label_mapping_path, 'r', encoding='utf-8') as f:
            mapping = json.load(f)
        self.id_to_label = {int(k): v for k, v in mapping['id_to_label'].items()}
        self.num_classes = mapping['num_classes']

        # Load all models
        self.models = []
        print(f"\n{'='*80}")
        print("Loading Ensemble Models")
        print(f"{'='*80}")

        for i, checkpoint_path in enumerate(checkpoint_paths):
            print(f"\nModel {i+1}/{len(checkpoint_paths)}")
            print(f"  Checkpoint: {checkpoint_path}")
            print(f"  Weight: {self.weights[i]:.4f}")

            # Create the bare architecture; weights come from the checkpoint,
            # so ImageNet pretraining is not needed.
            model = create_model(
                num_classes=self.num_classes,
                backbone=backbone,
                pretrained=False
            )

            if not Path(checkpoint_path).exists():
                print(f"  ❌ Checkpoint not found: {checkpoint_path}")
                raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # trusted checkpoints (consider weights_only=True on torch>=2.0).
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            model.load_state_dict(checkpoint['model_state_dict'])
            model = model.to(self.device)
            model.eval()
            self.models.append(model)
            print(f"  ✅ Loaded successfully")

        print(f"\n{'='*80}")
        print(f"✅ Ensemble loaded: {len(self.models)} models")
        print(f"💻 Device: {self.device}")
        print(f"📊 Classes: {self.num_classes}")
        print(f"{'='*80}\n")

        # Image transforms (ImageNet normalization statistics)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def predict(self, image):
        """Predict class probabilities for a PIL image via weighted soft voting.

        Returns:
            dict mapping class name -> probability, sorted descending by
            probability (values sum to 1 since weights are normalized).
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')

        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Accumulate the weighted softmax output of every member model
        ensemble_probs = np.zeros(self.num_classes)
        with torch.no_grad():
            for model, weight in zip(self.models, self.weights):
                outputs = model(img_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                ensemble_probs += weight * probabilities[0].cpu().numpy()

        results = {self.id_to_label[idx]: float(prob)
                   for idx, prob in enumerate(ensemble_probs)}
        return dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
|
| 152 |
+
|
| 153 |
+
|
# ========== For Hugging Face Space ==========
# Module-level configuration consumed when building the demo app.
checkpoint_path = "checkpoints/best_model_fold1.pth"  # single-model checkpoint (fold 1)
label_mapping_path = "label_mapping.json"  # class-id <-> class-name mapping (JSON)
backbone = 'efficientnet_b3'  # must match the architecture the checkpoint was trained with
device = "cuda"  # predictor falls back to CPU when CUDA is unavailable
|
checkpoints/{best_model_993.pth β best_model_fold1.pth}
RENAMED
|
File without changes
|
checkpoints/best_model_fold2.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:caeb0e81ac149d12a107acfa87268d90ac5f7711f9a8cb531c50f3ea04748ae3
|
| 3 |
+
size 138717293
|
checkpoints/best_model_fold3.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3ac3889dde1936a9c9c3ea18b41802a352adf8849df1b5b568d694e8684c2586
|
| 3 |
+
size 138717293
|
checkpoints/best_model_fold4.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4de4227b21d8470070432370a27bb99551813dd77d729b2a44f11c64ef209e43
|
| 3 |
+
size 138717293
|
checkpoints/best_model_fold5.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cf4e72ccb78e6d2849f67ab552abc28d7ff07fab87957cb34d290fe00b8831fc
|
| 3 |
+
size 138717293
|