FrAnKu34t23 committed on
Commit
b82867f
·
verified ·
1 Parent(s): f9fef0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -11
app.py CHANGED
@@ -17,11 +17,16 @@ from models import create_model
17
 
18
  # Optional: Hugging Face imports (used only when evaluating HF-format checkpoints)
19
  try:
 
20
  from transformers import AutoConfig, AutoModelForImageClassification
21
  HF_AVAILABLE = True
22
  except Exception:
 
23
  HF_AVAILABLE = False
24
 
 
 
 
25
  # Configuration
26
  # Default to the moved fine-tuned checkpoint if present
27
  MODEL_PATH = os.environ.get('MODEL_PATH', os.path.join('best_model_finetuned.pth'))
@@ -58,6 +63,9 @@ def load_checkpoint_model(model_path, device):
58
  heuristic handling for Hugging Face (Swin) checkpoints when HF_MODEL_ID is set.
59
  Returns (model, actual_num_classes) or (None, None) on failure.
60
  """
 
 
 
61
  # If user wants to force HF loading from hub, try that first (useful in Spaces)
62
  if FORCE_HF_LOAD and HF_MODEL_ID and HF_AVAILABLE:
63
  try:
@@ -65,6 +73,13 @@ def load_checkpoint_model(model_path, device):
65
  hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID)
66
  hf_model.to(device)
67
  hf_model.eval()
 
 
 
 
 
 
 
68
  num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
69
  print(f"Loaded HF model from hub with {num_labels} labels (force)")
70
  return hf_model, num_labels
@@ -82,6 +97,13 @@ def load_checkpoint_model(model_path, device):
82
  hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID)
83
  hf_model.to(device)
84
  hf_model.eval()
 
 
 
 
 
 
 
85
  num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
86
  print(f"Loaded HF model from hub with {num_labels} labels")
87
  return hf_model, num_labels
@@ -159,6 +181,13 @@ def load_checkpoint_model(model_path, device):
159
  missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
160
  hf_model.to(device)
161
  hf_model.eval()
 
 
 
 
 
 
 
162
  ok_msg = f"Loaded HF model with non-strict state_dict (missing {len(missing)} keys, unexpected {len(unexpected)} keys)"
163
  print(ok_msg)
164
  logger.info(ok_msg)
@@ -260,16 +289,37 @@ def predict_bird(image):
260
  if image.mode != 'RGB':
261
  image = image.convert('RGB')
262
 
263
- # Define preprocessing step by step to avoid namespace issues
264
- resize = transforms.Resize((320, 320))
265
- to_tensor = transforms.ToTensor()
266
- normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
267
-
268
- # Apply transformations step by step
269
- resized_image = resize(image)
270
- tensor_image = to_tensor(resized_image)
271
- normalized_tensor = normalize(tensor_image)
272
- input_tensor = normalized_tensor.unsqueeze(0).to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
  # Prediction
275
  with torch.no_grad():
@@ -287,10 +337,34 @@ def predict_bird(image):
287
  except Exception:
288
  logits = outputs
289
 
 
290
  # Ensure logits is a tensor
291
  if not isinstance(logits, torch.Tensor):
292
  logits = torch.tensor(np.asarray(logits)).to(DEVICE)
293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  probabilities = F.softmax(logits, dim=1)
295
  # Get top 5 predictions
296
  top5_prob, top5_indices = torch.topk(probabilities, min(5, probabilities.shape[1]), dim=1)
@@ -310,7 +384,10 @@ def predict_bird(image):
310
  return results
311
 
312
  except Exception as e:
313
- return {"Error": "Prediction failed: " + str(e)}
 
 
 
314
 
315
  # Create Gradio interface
316
  title = "🐦 Bird Species Classifier"
 
17
 
18
  # Optional: Hugging Face imports (used only when evaluating HF-format checkpoints)
19
  try:
20
+ import transformers
21
  from transformers import AutoConfig, AutoModelForImageClassification
22
  HF_AVAILABLE = True
23
  except Exception:
24
+ transformers = None
25
  HF_AVAILABLE = False
26
 
27
+ # HF image processor (AutoImageProcessor or AutoFeatureExtractor) will be stored here when available
28
+ hf_processor = None
29
+
30
  # Configuration
31
  # Default to the moved fine-tuned checkpoint if present
32
  MODEL_PATH = os.environ.get('MODEL_PATH', os.path.join('best_model_finetuned.pth'))
 
63
  heuristic handling for Hugging Face (Swin) checkpoints when HF_MODEL_ID is set.
64
  Returns (model, actual_num_classes) or (None, None) on failure.
65
  """
66
+ # Allow writing to module-level hf_processor
67
+ global hf_processor
68
+
69
  # If user wants to force HF loading from hub, try that first (useful in Spaces)
70
  if FORCE_HF_LOAD and HF_MODEL_ID and HF_AVAILABLE:
71
  try:
 
73
  hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID)
74
  hf_model.to(device)
75
  hf_model.eval()
76
+ # Try to load a matching image processor from the hub for preprocessing
77
+ try:
78
+ if transformers is not None:
79
+ hf_processor = transformers.AutoImageProcessor.from_pretrained(HF_MODEL_ID)
80
+ print("Loaded HF image processor from hub (force load)")
81
+ except Exception:
82
+ print("Warning: failed to load HF image processor for forced hub model")
83
  num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
84
  print(f"Loaded HF model from hub with {num_labels} labels (force)")
85
  return hf_model, num_labels
 
97
  hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID)
98
  hf_model.to(device)
99
  hf_model.eval()
100
+ # Try to load image processor for preprocessing
101
+ try:
102
+ if transformers is not None:
103
+ hf_processor = transformers.AutoImageProcessor.from_pretrained(HF_MODEL_ID)
104
+ print("Loaded HF image processor from hub")
105
+ except Exception:
106
+ print("Warning: failed to load HF image processor from hub")
107
  num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
108
  print(f"Loaded HF model from hub with {num_labels} labels")
109
  return hf_model, num_labels
 
181
  missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
182
  hf_model.to(device)
183
  hf_model.eval()
184
+ # Try to fetch image processor for this hf id so we can preprocess in predict
185
+ try:
186
+ if transformers is not None:
187
+ hf_processor = transformers.AutoImageProcessor.from_pretrained(hf_id_to_use)
188
+ print(f"Loaded HF image processor for {hf_id_to_use}")
189
+ except Exception:
190
+ print(f"Warning: failed to load HF image processor for {hf_id_to_use}")
191
  ok_msg = f"Loaded HF model with non-strict state_dict (missing {len(missing)} keys, unexpected {len(unexpected)} keys)"
192
  print(ok_msg)
193
  logger.info(ok_msg)
 
289
  if image.mode != 'RGB':
290
  image = image.convert('RGB')
291
 
292
+ # If an HF image processor was loaded alongside a HF model, prefer it for preprocessing
293
+ if hf_processor is not None:
294
+ try:
295
+ # AutoImageProcessor expects PIL images or numpy arrays; return_tensors='pt' gives PyTorch tensors
296
+ proc = hf_processor(images=image, return_tensors='pt')
297
 + # Most processors return a 'pixel_values' key; otherwise fall back to the first tensor-like value
298
+ if 'pixel_values' in proc:
299
+ input_tensor = proc['pixel_values'].to(DEVICE)
300
+ else:
301
+ # Fall back to first tensor-like value
302
+ val = next(iter(proc.values()))
303
+ input_tensor = val.to(DEVICE)
304
+ except Exception:
305
+ logger.exception('HF processor failed; falling back to torchvision preprocessing')
306
+ hf_local_fallback = True
307
+ else:
308
+ hf_local_fallback = False
309
+ else:
310
+ hf_local_fallback = True
311
+
312
+ if hf_local_fallback:
313
+ # Define preprocessing step by step to avoid namespace issues
314
+ resize = transforms.Resize((320, 320))
315
+ to_tensor = transforms.ToTensor()
316
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
317
+
318
+ # Apply transformations step by step
319
+ resized_image = resize(image)
320
+ tensor_image = to_tensor(resized_image)
321
+ normalized_tensor = normalize(tensor_image)
322
+ input_tensor = normalized_tensor.unsqueeze(0).to(DEVICE)
323
 
324
  # Prediction
325
  with torch.no_grad():
 
337
  except Exception:
338
  logits = outputs
339
 
340
+
341
  # Ensure logits is a tensor
342
  if not isinstance(logits, torch.Tensor):
343
  logits = torch.tensor(np.asarray(logits)).to(DEVICE)
344
 
345
+ # Handle unexpected tensor shapes:
346
+ # - if logits has spatial dims (e.g., 4D), average them
347
+ # - if logits is 1D, unsqueeze batch dim
348
+ try:
349
+ if logits.dim() > 2:
350
+ # average over all dims after channel dim
351
+ reduce_dims = tuple(range(2, logits.dim()))
352
+ logits = logits.mean(dim=reduce_dims)
353
+ if logits.dim() == 1:
354
+ logits = logits.unsqueeze(0)
355
+ except Exception:
356
+ # if shape ops fail, log and return safe error
357
+ logger.exception('Failed to normalize logits shape')
358
+ return {"Error": 0.0}
359
+
360
+ # If single-logit output, treat as sigmoid probability
361
+ if logits.size(1) == 1:
362
+ probs = torch.sigmoid(logits)
363
+ # return single-label prob mapped to first class or generic
364
+ prob = float(probs[0, 0].item())
365
+ label = class_names[0].replace('_', ' ') if class_names else 'Class_0'
366
+ return {label: prob}
367
+
368
  probabilities = F.softmax(logits, dim=1)
369
  # Get top 5 predictions
370
  top5_prob, top5_indices = torch.topk(probabilities, min(5, probabilities.shape[1]), dim=1)
 
384
  return results
385
 
386
  except Exception as e:
387
+ # Log exception and return a numeric-friendly error response for Gradio
388
+ logger.exception('Prediction failed')
389
+ print('Prediction failed:', e)
390
+ return {"Error": 0.0}
391
 
392
  # Create Gradio interface
393
  title = "🐦 Bird Species Classifier"