Create new_approach/spa_ensemble.py

#1
by FrAnKu34t23 - opened
Files changed (1) hide show
  1. new_approach/spa_ensemble.py +323 -0
new_approach/spa_ensemble.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import cv2
6
+ from PIL import Image
7
+ import os
8
+ from pathlib import Path
9
+ from scipy import stats
10
+ from scipy.fftpack import dct
11
+ from sklearn.preprocessing import StandardScaler
12
+ import torchvision.transforms as transforms
13
+ import open_clip
14
+
15
# --- CONFIGURATION ---
# Minimum ensemble confidence; below this threshold the BioCLIP zero-shot
# probabilities are soft-fused with the ensemble output (see ModelManager.predict).
CONFIDENCE_THRESHOLD = 0.99

# PATH UPDATES:
# 1. Models are now INSIDE this folder structure: new_approach/ensemble_models/
#    Using relative path from where app.py is executed (root)
MODELS_DIR = Path("new_approach/ensemble_models")

# 2. The species list is in the 'list' folder in the root
LIST_DIR = Path("list")
25
+
26
# ==============================================================================
# 1. FEATURE EXTRACTOR
# ==============================================================================
class FeatureExtractor:
    """Stateless handcrafted image-feature extraction.

    extract_all_features() returns a dict of exactly 49 scalar features
    (color 27, texture 4, shape 4, brightness 2, frequency 10, statistics 2).
    The downstream EnsembleClassifier / StandardScaler expect this exact count,
    and ModelManager feeds the values in sorted-key order.
    """

    @staticmethod
    def extract_color_features(img):
        """Per-channel RGB stats, 3-bin histograms, and HSV means (27 values)."""
        img_np = np.array(img)
        features = {}
        for i, channel in enumerate(['R', 'G', 'B']):
            ch = img_np[:, :, i].flatten()
            if len(ch) > 0:
                features.update({
                    f'color_{channel}_mean': float(np.mean(ch)),
                    f'color_{channel}_std': float(np.std(ch)),
                    f'color_{channel}_skew': float(stats.skew(ch)),
                    f'color_{channel}_min': float(np.min(ch)),
                    f'color_{channel}_max': float(np.max(ch)),
                })
            else:
                features.update({
                    f'color_{channel}_mean': 0.0, f'color_{channel}_std': 0.0,
                    f'color_{channel}_skew': 0.0, f'color_{channel}_min': 0.0,
                    f'color_{channel}_max': 0.0,
                })
            # Coarse 3-bin intensity histogram per channel, normalized to sum 1.
            hist, _ = np.histogram(ch, bins=3, range=(0, 256))
            hist = hist / (hist.sum() + 1e-8)
            for j, v in enumerate(hist):
                features[f'color_{channel}_hist_bin{j}'] = float(v)
        try:
            hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)
            features.update({
                'color_hue_mean': float(np.mean(hsv[:, :, 0])),
                'color_saturation_mean': float(np.mean(hsv[:, :, 1])),
                'color_value_mean': float(np.mean(hsv[:, :, 2])),
            })
        except Exception:  # FIX: was a bare 'except:' (also caught KeyboardInterrupt)
            features.update({'color_hue_mean': 0.0, 'color_saturation_mean': 0.0, 'color_value_mean': 0.0})
        return features

    @staticmethod
    def extract_texture_features(img):
        """Edge density, gradient magnitude stats and Laplacian variance (4 values)."""
        img_np = np.array(img)
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        features = {}
        # Optimization: Canny/Sobel run on the resized image (see extract_all_features).
        edges = cv2.Canny(gray, 50, 150)
        gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        grad_mag = np.sqrt(gx ** 2 + gy ** 2)  # hoisted: computed once, used twice
        features.update({
            'texture_edge_density': float(np.sum(edges > 0) / edges.size) if edges.size > 0 else 0.0,
            'texture_gradient_mean': float(np.mean(grad_mag)),
            'texture_gradient_std': float(np.std(grad_mag)),
            'texture_laplacian_var': float(np.var(cv2.Laplacian(gray, cv2.CV_64F)))
        })
        return features

    @staticmethod
    def extract_shape_features(img):
        """Width/height geometry of the (possibly downscaled) image (4 values)."""
        w, h = img.size
        return {
            'shape_height': h,
            'shape_width': w,
            'shape_aspect_ratio': w / h if h > 0 else 0.0,
            'shape_area': w * h,
        }

    @staticmethod
    def extract_brightness_features(img):
        """Grayscale mean/std brightness (2 values)."""
        img_np = np.array(img)
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        return {'brightness_mean': float(np.mean(gray)), 'brightness_std': float(np.std(gray))}

    @staticmethod
    def extract_frequency_features(img):
        """First 10 coefficients of a 2-D DCT over a 64x64 grayscale thumbnail."""
        img_np = np.array(img)
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        gray_small = cv2.resize(gray, (64, 64))
        dct_coeffs = dct(dct(gray_small.T, norm='ortho').T, norm='ortho')
        features = {}
        # BUGFIX: the original one-liner put 'return features' inside the loop
        # body, so only freq_dct_0 was ever emitted (40 total features instead
        # of the 49 the scaler and classifiers expect). Return after the loop.
        for i, v in enumerate(dct_coeffs.flatten()[:10]):
            features[f'freq_dct_{i}'] = float(v)
        return features

    @staticmethod
    def extract_statistical_features(img):
        """Shannon entropy and uniformity of the grayscale histogram (2 values)."""
        img_np = np.array(img)
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        hist, _ = np.histogram(gray.flatten(), bins=256, range=(0, 256))
        hist = hist / (hist.sum() + 1e-8)
        hist_nonzero = hist[hist > 0]
        entropy = -np.sum(hist_nonzero * np.log2(hist_nonzero)) if hist_nonzero.size > 0 else 0.0
        return {'stat_entropy': entropy, 'stat_uniformity': float(np.sum(hist ** 2))}

    @staticmethod
    def extract_all_features(img):
        """Run all extractors on an RGB-converted copy; returns 49 scalar features."""
        img = img.convert('RGB')
        # OPTIMIZATION: resize before handcrafted extraction to speed up Canny/Sobel.
        max_size = 1024
        if max(img.size) > max_size:
            img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

        features = {}
        features.update(FeatureExtractor.extract_color_features(img))
        features.update(FeatureExtractor.extract_texture_features(img))
        features.update(FeatureExtractor.extract_shape_features(img))
        features.update(FeatureExtractor.extract_brightness_features(img))
        features.update(FeatureExtractor.extract_frequency_features(img))
        features.update(FeatureExtractor.extract_statistical_features(img))
        return features
96
+
97
# ==============================================================================
# 2. MODEL ARCHITECTURE
# ==============================================================================
class BioCLIP2ZeroShot:
    """Zero-shot classifier head built from BioCLIP text prototypes.

    One L2-normalized prototype per class is pre-computed by averaging the
    text embeddings of several prompt templates; inference is then a single
    image encoding followed by a scaled similarity matrix product.
    """

    def __init__(self, device, class_to_idx, id_to_name):
        self.device = device
        self.num_classes = len(class_to_idx)
        self.idx_to_class = {v: k for k, v in class_to_idx.items()}
        self.id_to_name = id_to_name
        print("Loading BioCLIP-2 model...")
        try:
            self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip-2')
            self.tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip-2')
        except Exception:  # FIX: bare 'except:' also swallowed KeyboardInterrupt/SystemExit
            print("Warning: BioCLIP-2 load failed, trying base BioCLIP...")
            self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
            self.tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip')
        self.model.to(self.device).eval()
        self.text_features_prototypes = self._precompute_text_features()

    def _precompute_text_features(self):
        """Return a (num_classes, dim) tensor of L2-normalized text prototypes."""
        templates = ["a photo of {}", "a herbarium specimen of {}", "a botanical photograph of {}",
                     "{} plant species", "leaves and flowers of {}"]
        class_ids = [self.idx_to_class[i] for i in range(self.num_classes)]
        class_names = [self.id_to_name.get(str(cid), str(cid)) for cid in class_ids]
        # Flat prompt list ordered class-major: all templates of class 0, then class 1, ...
        text_inputs = [t.format(name) for name in class_names for t in templates]
        all_emb = []
        bs = 64  # encode in batches to bound memory
        with torch.no_grad():
            for i in range(0, len(text_inputs), bs):
                tokens = self.tokenizer(text_inputs[i:i + bs]).to(self.device)
                all_emb.append(self.model.encode_text(tokens))
        all_text_embs = torch.cat(all_emb, dim=0).cpu().numpy()
        prototypes = np.zeros((self.num_classes, all_text_embs.shape[1]), dtype=np.float32)
        for idx in range(self.num_classes):
            start = idx * len(templates)
            # NOTE(review): template embeddings are averaged BEFORE normalization;
            # normalizing each embedding first may have been intended -- confirm
            # against the training pipeline before changing.
            avg = np.mean(all_text_embs[start:start + len(templates)], axis=0)
            norm = np.linalg.norm(avg)
            prototypes[idx] = avg / norm if norm > 0 else avg
        return torch.from_numpy(prototypes).to(self.device)

    def predict_zero_shot_logits(self, img):
        """Return a 1-D numpy array of scaled similarity logits, one per class."""
        processed = self.preprocess(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(processed)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            prototypes = self.text_features_prototypes
            try:
                logit_scale = self.model.logit_scale.exp()
            except Exception:  # FIX: was a bare 'except:'; 100.0 is the usual CLIP scale
                logit_scale = torch.tensor(100.0).to(self.device)
            logits = (logit_scale * image_features @ prototypes.T).cpu().numpy().squeeze()
        return logits
144
+
145
class EnsembleClassifier(nn.Module):
    """Late-fusion classifier over three feature branches.

    Branches: a projection of DINOv2 embeddings, an MLP over handcrafted
    features, and a small MLP over BioCLIP zero-shot logits. The fused
    representation feeds both a class head and a prototype-projection head.
    """

    def __init__(self, num_handcrafted_features=49, dinov2_dim=1024, bioclip2_dim=100,
                 num_classes=100, hidden_dim=512, dropout_rate=0.3, prototype_dim=512):
        super().__init__()

        def block(n_in, n_out, p):
            # Linear -> BatchNorm -> ReLU -> Dropout, returned as a flat list so
            # Sequential indices (and thus state_dict keys) match the checkpoints.
            return [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU(), nn.Dropout(p)]

        half = hidden_dim // 2
        quarter = hidden_dim // 4

        self.dinov2_proj = nn.Sequential(
            nn.Linear(dinov2_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout_rate))
        self.handcraft_branch = nn.Sequential(
            *block(num_handcrafted_features, 128, dropout_rate),
            *block(128, half, dropout_rate),
            *block(half, half, dropout_rate))
        self.bioclip2_branch = nn.Sequential(
            *block(bioclip2_dim, quarter, dropout_rate * 0.5))
        self.fusion = nn.Sequential(
            *block(hidden_dim + half + quarter, hidden_dim, dropout_rate))
        self.classifier = nn.Linear(hidden_dim, num_classes)
        self.prototype_proj = nn.Linear(hidden_dim, prototype_dim)

    def forward(self, handcrafted_features, dinov2_features, bioclip2_logits):
        """Return (class logits, prototype-space projection) for a batch."""
        fused = self.fusion(torch.cat([
            self.dinov2_proj(dinov2_features),
            self.handcraft_branch(handcrafted_features),
            self.bioclip2_branch(bioclip2_logits),
        ], dim=1))
        return self.classifier(fused), self.prototype_proj(fused)
170
+
171
# ==============================================================================
# 3. MANAGER CLASS & EXPORTED FUNCTION
# ==============================================================================
class ModelManager:
    """Loads every SPA ensemble component once and serves predictions.

    Pipeline per image: handcrafted features + DINOv2 embedding + BioCLIP
    zero-shot logits -> up to five fusion classifiers -> averaged softmax
    probabilities, soft-fused with BioCLIP when ensemble confidence is low.
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Initializing SPA Ensemble on {self.device}...")

        self.class_to_idx, self.idx_to_class, self.id_to_name = self.load_class_info()
        self.num_classes = len(self.class_to_idx)
        print(f"SPA Ensemble: Loaded {self.num_classes} classes.")

        print("SPA Ensemble: Loading DINOv2...")
        self.dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
        self.dinov2.to(self.device).eval()
        # Standard ImageNet preprocessing used by DINOv2 checkpoints.
        self.dinov2_transform = transforms.Compose([
            transforms.Resize(256), transforms.CenterCrop(224),
            transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        self.bioclip = BioCLIP2ZeroShot(self.device, self.class_to_idx, self.id_to_name)

        self.models = []
        # Dummy scaler (zero mean, unit scale == identity transform).
        # If you have 'scaler.joblib' from training, load it here!
        self.scaler = StandardScaler()
        self.scaler.fit(np.zeros((1, 49)))

        print("SPA Ensemble: Loading Weights from 'new_approach/ensemble_models'...")
        # Each ensemble member was trained with its own capacity / regularization.
        hidden_dims = [384, 448, 512, 576, 640]
        dropout_rates = [0.2, 0.25, 0.3, 0.35, 0.4]

        for i in range(5):
            model_path = MODELS_DIR / f"ensemble_model_{i}.pth"
            if model_path.exists():
                model = EnsembleClassifier(
                    num_handcrafted_features=49, dinov2_dim=1024, bioclip2_dim=self.num_classes,
                    num_classes=self.num_classes, hidden_dim=hidden_dims[i], dropout_rate=dropout_rates[i]
                )
                try:
                    state_dict = torch.load(model_path, map_location=self.device)
                    model.load_state_dict(state_dict)
                    model.to(self.device).eval()
                    self.models.append(model)
                except Exception as e:
                    print(f"Failed to load SPA model {i}: {e}")
            else:
                print(f"SPA model file {model_path} not found.")

    def load_class_info(self):
        """Build (class_to_idx, idx_to_class, id_to_name) from the 'list' folder.

        Class IDs come from train.txt (second whitespace column) when present,
        otherwise from species_list.txt ('id;name' lines); falls back to 100
        dummy classes so the rest of the pipeline can still initialize.
        """
        class_to_idx = {}
        id_to_name = {}

        # Checking LIST_DIR which is set to 'list' folder in root.
        species_path = LIST_DIR / "species_list.txt"
        train_path = LIST_DIR / "train.txt"

        classes_set = set()

        if train_path.exists():
            with open(train_path, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 2:
                        classes_set.add(parts[1])
        elif species_path.exists():
            with open(species_path, 'r') as f:
                for line in f:
                    parts = line.strip().split(";", 1)
                    classes_set.add(parts[0].strip())
        else:
            # Fallback: synthetic class IDs "0".."99".
            classes_set = {str(i) for i in range(100)}

        # Sorted order makes index assignment deterministic across runs.
        sorted_classes = sorted(classes_set)
        class_to_idx = {cls: idx for idx, cls in enumerate(sorted_classes)}
        idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}

        if species_path.exists():
            with open(species_path, 'r') as f:
                for line in f:
                    if ";" in line:
                        parts = line.strip().split(";", 1)
                        id_to_name[parts[0].strip()] = parts[1].strip()
        return class_to_idx, idx_to_class, id_to_name

    def predict(self, image):
        """Return a {label: probability} dict for the top-5 classes of a PIL image.

        Returns {} for a None image and an {"Error": ...} dict when no
        ensemble weights were loaded.
        """
        if image is None:
            return {}
        img_pil = image.convert("RGB")

        # 1. Handcrafted features, vectorized in sorted-key order.
        # NOTE(review): assumes training used the same sorted ordering of the
        # 49 feature keys -- verify against the training pipeline.
        hc_feats = FeatureExtractor.extract_all_features(img_pil)
        hc_vector = np.array([hc_feats[k] for k in sorted(hc_feats.keys())]).reshape(1, -1)
        hc_vector = self.scaler.transform(hc_vector)
        hc_tensor = torch.FloatTensor(hc_vector).to(self.device)

        # 2. L2-normalized DINOv2 embedding.
        dino_input = self.dinov2_transform(img_pil).unsqueeze(0).to(self.device)
        with torch.no_grad():
            dino_feats = self.dinov2(dino_input)
            dino_feats = dino_feats / (dino_feats.norm(dim=-1, keepdim=True) + 1e-8)

        # 3. BioCLIP zero-shot logits.
        bioclip_logits = self.bioclip.predict_zero_shot_logits(img_pil)
        bioclip_tensor = torch.FloatTensor(bioclip_logits).unsqueeze(0).to(self.device)

        # 4. Average softmax probabilities over all loaded ensemble members.
        all_probs = []
        if not self.models:
            return {"Error": "SPA Models not loaded"}

        for model in self.models:
            with torch.no_grad():
                logits, _ = model(hc_tensor, dino_feats, bioclip_tensor)
                all_probs.append(F.softmax(logits, dim=1).cpu().numpy()[0])

        final_ens_probs = np.mean(all_probs, axis=0)

        # 5. Hybrid routing (ensemble + BioCLIP fallback).
        # BUGFIX: subtract the max before exponentiating -- BioCLIP logits are
        # multiplied by logit_scale (~100), so a raw np.exp overflows to inf
        # and yields NaN probabilities. Shifting by the max is mathematically
        # equivalent and numerically stable.
        exp_logits = np.exp(bioclip_logits - np.max(bioclip_logits))
        bioclip_probs = exp_logits / np.sum(exp_logits)

        ens_pred_idx = np.argmax(final_ens_probs)
        ens_conf = final_ens_probs[ens_pred_idx]

        if ens_conf < CONFIDENCE_THRESHOLD:
            # Low confidence: soft-fuse ensemble and zero-shot distributions.
            final_probs = (final_ens_probs + bioclip_probs) / 2
        else:
            final_probs = final_ens_probs

        # 6. Format the top-5 as "Species Name (class_id)" -> score.
        top_k = 5
        top_indices = np.argsort(final_probs)[::-1][:top_k]
        results = {}
        for idx in top_indices:
            class_id = self.idx_to_class[idx]
            name = self.id_to_name.get(class_id, class_id)
            results[f"{name} ({class_id})"] = float(final_probs[idx])

        return results
311
+
312
# Initialize Singleton
def _init_spa_manager():
    """Build the ModelManager once at import time; return None on any failure."""
    try:
        return ModelManager()
    except Exception as e:
        print(f"CRITICAL ERROR initializing SPA: {e}")
        return None


spa_manager = _init_spa_manager()


# Exported Function
def predict_spa(image):
    """Run SPA ensemble inference; returns an error dict if init failed."""
    if spa_manager is not None:
        return spa_manager.predict(image)
    return {"Error": "SPA System failed to initialize."}