Boyun7 commited on
Commit
aca7c2f
Β·
1 Parent(s): 1d24cbd
app.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  Simple Demo for Pest and Disease Classification
3
  For Hugging Face Space Deployment
 
4
  """
5
 
6
  import torch
@@ -8,6 +9,8 @@ from PIL import Image
8
  import json
9
  import gradio as gr
10
  from torchvision import transforms
 
 
11
 
12
  from model import create_model
13
 
@@ -62,8 +65,94 @@ class PestDiseasePredictor:
62
  return dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
63
 
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  # ========== For Hugging Face Space ==========
66
- checkpoint_path = "checkpoints/best_model_993.pth"
67
  label_mapping_path = "label_mapping.json"
68
  backbone = 'efficientnet_b3'
69
  device = "cuda"
 
1
  """
2
  Simple Demo for Pest and Disease Classification
3
  For Hugging Face Space Deployment
4
+ Supports both single model and ensemble prediction
5
  """
6
 
7
  import torch
 
9
  import json
10
  import gradio as gr
11
  from torchvision import transforms
12
+ import numpy as np
13
+ from pathlib import Path
14
 
15
  from model import create_model
16
 
 
65
  return dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
66
 
67
 
68
+ class EnsemblePredictor:
69
+ """Ensemble predictor using weighted soft voting"""
70
+
71
+ def __init__(self, checkpoint_paths, weights, label_mapping_path, backbone='efficientnet_b3', device='cuda'):
72
+ self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
73
+
74
+ # Normalize weights to sum to 1
75
+ weights = np.array(weights)
76
+ self.weights = weights / weights.sum()
77
+
78
+ # Load label mapping
79
+ with open(label_mapping_path, 'r', encoding='utf-8') as f:
80
+ mapping = json.load(f)
81
+ self.id_to_label = {int(k): v for k, v in mapping['id_to_label'].items()}
82
+ self.num_classes = mapping['num_classes']
83
+
84
+ # Load all models
85
+ self.models = []
86
+ print(f"\n{'='*80}")
87
+ print("Loading Ensemble Models")
88
+ print(f"{'='*80}")
89
+
90
+ for i, checkpoint_path in enumerate(checkpoint_paths):
91
+ print(f"\nModel {i+1}/{len(checkpoint_paths)}")
92
+ print(f" Checkpoint: {checkpoint_path}")
93
+ print(f" Weight: {self.weights[i]:.4f}")
94
+
95
+ # Create model
96
+ model = create_model(
97
+ num_classes=self.num_classes,
98
+ backbone=backbone,
99
+ pretrained=False
100
+ )
101
+
102
+ # Load checkpoint
103
+ if Path(checkpoint_path).exists():
104
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
105
+ model.load_state_dict(checkpoint['model_state_dict'])
106
+ model = model.to(self.device)
107
+ model.eval()
108
+ self.models.append(model)
109
+ print(f" βœ… Loaded successfully")
110
+ else:
111
+ print(f" ❌ Checkpoint not found: {checkpoint_path}")
112
+ raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
113
+
114
+ print(f"\n{'='*80}")
115
+ print(f"βœ… Ensemble loaded: {len(self.models)} models")
116
+ print(f"πŸ’» Device: {self.device}")
117
+ print(f"πŸ“š Classes: {self.num_classes}")
118
+ print(f"{'='*80}\n")
119
+
120
+ # Image transforms
121
+ self.transform = transforms.Compose([
122
+ transforms.Resize((224, 224)),
123
+ transforms.ToTensor(),
124
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
125
+ std=[0.229, 0.224, 0.225])
126
+ ])
127
+
128
+ def predict(self, image):
129
+ """Predict using weighted soft voting"""
130
+ if image.mode != 'RGB':
131
+ image = image.convert('RGB')
132
+
133
+ img_tensor = self.transform(image).unsqueeze(0).to(self.device)
134
+
135
+ # Get predictions from all models
136
+ ensemble_probs = np.zeros(self.num_classes)
137
+
138
+ with torch.no_grad():
139
+ for model, weight in zip(self.models, self.weights):
140
+ outputs = model(img_tensor)
141
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
142
+ probs = probabilities[0].cpu().numpy()
143
+ ensemble_probs += weight * probs
144
+
145
+ # Create results dictionary
146
+ results = {}
147
+ for idx, prob in enumerate(ensemble_probs):
148
+ class_name = self.id_to_label[idx]
149
+ results[class_name] = float(prob)
150
+
151
+ return dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
152
+
153
+
154
  # ========== For Hugging Face Space ==========
155
+ checkpoint_path = "checkpoints/best_model_fold1.pth"
156
  label_mapping_path = "label_mapping.json"
157
  backbone = 'efficientnet_b3'
158
  device = "cuda"
checkpoints/{best_model_993.pth β†’ best_model_fold1.pth} RENAMED
File without changes
checkpoints/best_model_fold2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:caeb0e81ac149d12a107acfa87268d90ac5f7711f9a8cb531c50f3ea04748ae3
3
+ size 138717293
checkpoints/best_model_fold3.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ac3889dde1936a9c9c3ea18b41802a352adf8849df1b5b568d694e8684c2586
3
+ size 138717293
checkpoints/best_model_fold4.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4de4227b21d8470070432370a27bb99551813dd77d729b2a44f11c64ef209e43
3
+ size 138717293
checkpoints/best_model_fold5.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf4e72ccb78e6d2849f67ab552abc28d7ff07fab87957cb34d290fe00b8831fc
3
+ size 138717293