Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -17,11 +17,16 @@ from models import create_model
|
|
| 17 |
|
| 18 |
# Optional: Hugging Face imports (used only when evaluating HF-format checkpoints)
|
| 19 |
try:
|
|
|
|
| 20 |
from transformers import AutoConfig, AutoModelForImageClassification
|
| 21 |
HF_AVAILABLE = True
|
| 22 |
except Exception:
|
|
|
|
| 23 |
HF_AVAILABLE = False
|
| 24 |
|
|
|
|
|
|
|
|
|
|
| 25 |
# Configuration
|
| 26 |
# Default to the moved fine-tuned checkpoint if present
|
| 27 |
MODEL_PATH = os.environ.get('MODEL_PATH', os.path.join('best_model_finetuned.pth'))
|
|
@@ -58,6 +63,9 @@ def load_checkpoint_model(model_path, device):
|
|
| 58 |
heuristic handling for Hugging Face (Swin) checkpoints when HF_MODEL_ID is set.
|
| 59 |
Returns (model, actual_num_classes) or (None, None) on failure.
|
| 60 |
"""
|
|
|
|
|
|
|
|
|
|
| 61 |
# If user wants to force HF loading from hub, try that first (useful in Spaces)
|
| 62 |
if FORCE_HF_LOAD and HF_MODEL_ID and HF_AVAILABLE:
|
| 63 |
try:
|
|
@@ -65,6 +73,13 @@ def load_checkpoint_model(model_path, device):
|
|
| 65 |
hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID)
|
| 66 |
hf_model.to(device)
|
| 67 |
hf_model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
|
| 69 |
print(f"Loaded HF model from hub with {num_labels} labels (force)")
|
| 70 |
return hf_model, num_labels
|
|
@@ -82,6 +97,13 @@ def load_checkpoint_model(model_path, device):
|
|
| 82 |
hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID)
|
| 83 |
hf_model.to(device)
|
| 84 |
hf_model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
|
| 86 |
print(f"Loaded HF model from hub with {num_labels} labels")
|
| 87 |
return hf_model, num_labels
|
|
@@ -159,6 +181,13 @@ def load_checkpoint_model(model_path, device):
|
|
| 159 |
missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
|
| 160 |
hf_model.to(device)
|
| 161 |
hf_model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
ok_msg = f"Loaded HF model with non-strict state_dict (missing {len(missing)} keys, unexpected {len(unexpected)} keys)"
|
| 163 |
print(ok_msg)
|
| 164 |
logger.info(ok_msg)
|
|
@@ -260,16 +289,37 @@ def predict_bird(image):
|
|
| 260 |
if image.mode != 'RGB':
|
| 261 |
image = image.convert('RGB')
|
| 262 |
|
| 263 |
-
#
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
# Prediction
|
| 275 |
with torch.no_grad():
|
|
@@ -287,10 +337,34 @@ def predict_bird(image):
|
|
| 287 |
except Exception:
|
| 288 |
logits = outputs
|
| 289 |
|
|
|
|
| 290 |
# Ensure logits is a tensor
|
| 291 |
if not isinstance(logits, torch.Tensor):
|
| 292 |
logits = torch.tensor(np.asarray(logits)).to(DEVICE)
|
| 293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
probabilities = F.softmax(logits, dim=1)
|
| 295 |
# Get top 5 predictions
|
| 296 |
top5_prob, top5_indices = torch.topk(probabilities, min(5, probabilities.shape[1]), dim=1)
|
|
@@ -310,7 +384,10 @@ def predict_bird(image):
|
|
| 310 |
return results
|
| 311 |
|
| 312 |
except Exception as e:
|
| 313 |
-
return
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
# Create Gradio interface
|
| 316 |
title = "🐦 Bird Species Classifier"
|
|
|
|
| 17 |
|
| 18 |
# Optional: Hugging Face imports (used only when evaluating HF-format checkpoints)
|
| 19 |
try:
|
| 20 |
+
import transformers
|
| 21 |
from transformers import AutoConfig, AutoModelForImageClassification
|
| 22 |
HF_AVAILABLE = True
|
| 23 |
except Exception:
|
| 24 |
+
transformers = None
|
| 25 |
HF_AVAILABLE = False
|
| 26 |
|
| 27 |
+
# HF image processor (AutoImageProcessor or AutoFeatureExtractor) will be stored here when available
|
| 28 |
+
hf_processor = None
|
| 29 |
+
|
| 30 |
# Configuration
|
| 31 |
# Default to the moved fine-tuned checkpoint if present
|
| 32 |
MODEL_PATH = os.environ.get('MODEL_PATH', os.path.join('best_model_finetuned.pth'))
|
|
|
|
| 63 |
heuristic handling for Hugging Face (Swin) checkpoints when HF_MODEL_ID is set.
|
| 64 |
Returns (model, actual_num_classes) or (None, None) on failure.
|
| 65 |
"""
|
| 66 |
+
# Allow writing to module-level hf_processor
|
| 67 |
+
global hf_processor
|
| 68 |
+
|
| 69 |
# If user wants to force HF loading from hub, try that first (useful in Spaces)
|
| 70 |
if FORCE_HF_LOAD and HF_MODEL_ID and HF_AVAILABLE:
|
| 71 |
try:
|
|
|
|
| 73 |
hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID)
|
| 74 |
hf_model.to(device)
|
| 75 |
hf_model.eval()
|
| 76 |
+
# Try to load a matching image processor from the hub for preprocessing
|
| 77 |
+
try:
|
| 78 |
+
if transformers is not None:
|
| 79 |
+
hf_processor = transformers.AutoImageProcessor.from_pretrained(HF_MODEL_ID)
|
| 80 |
+
print("Loaded HF image processor from hub (force load)")
|
| 81 |
+
except Exception:
|
| 82 |
+
print("Warning: failed to load HF image processor for forced hub model")
|
| 83 |
num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
|
| 84 |
print(f"Loaded HF model from hub with {num_labels} labels (force)")
|
| 85 |
return hf_model, num_labels
|
|
|
|
| 97 |
hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID)
|
| 98 |
hf_model.to(device)
|
| 99 |
hf_model.eval()
|
| 100 |
+
# Try to load image processor for preprocessing
|
| 101 |
+
try:
|
| 102 |
+
if transformers is not None:
|
| 103 |
+
hf_processor = transformers.AutoImageProcessor.from_pretrained(HF_MODEL_ID)
|
| 104 |
+
print("Loaded HF image processor from hub")
|
| 105 |
+
except Exception:
|
| 106 |
+
print("Warning: failed to load HF image processor from hub")
|
| 107 |
num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
|
| 108 |
print(f"Loaded HF model from hub with {num_labels} labels")
|
| 109 |
return hf_model, num_labels
|
|
|
|
| 181 |
missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
|
| 182 |
hf_model.to(device)
|
| 183 |
hf_model.eval()
|
| 184 |
+
# Try to fetch image processor for this hf id so we can preprocess in predict
|
| 185 |
+
try:
|
| 186 |
+
if transformers is not None:
|
| 187 |
+
hf_processor = transformers.AutoImageProcessor.from_pretrained(hf_id_to_use)
|
| 188 |
+
print(f"Loaded HF image processor for {hf_id_to_use}")
|
| 189 |
+
except Exception:
|
| 190 |
+
print(f"Warning: failed to load HF image processor for {hf_id_to_use}")
|
| 191 |
ok_msg = f"Loaded HF model with non-strict state_dict (missing {len(missing)} keys, unexpected {len(unexpected)} keys)"
|
| 192 |
print(ok_msg)
|
| 193 |
logger.info(ok_msg)
|
|
|
|
| 289 |
if image.mode != 'RGB':
|
| 290 |
image = image.convert('RGB')
|
| 291 |
|
| 292 |
+
# If an HF image processor was loaded alongside a HF model, prefer it for preprocessing
|
| 293 |
+
if hf_processor is not None:
|
| 294 |
+
try:
|
| 295 |
+
# AutoImageProcessor expects PIL images or numpy arrays; return_tensors='pt' gives PyTorch tensors
|
| 296 |
+
proc = hf_processor(images=image, return_tensors='pt')
|
| 297 |
+
# Processors normally return the tensor under the 'pixel_values' key; if it is missing, fall back to the first value below
|
| 298 |
+
if 'pixel_values' in proc:
|
| 299 |
+
input_tensor = proc['pixel_values'].to(DEVICE)
|
| 300 |
+
else:
|
| 301 |
+
# Fall back to first tensor-like value
|
| 302 |
+
val = next(iter(proc.values()))
|
| 303 |
+
input_tensor = val.to(DEVICE)
|
| 304 |
+
except Exception:
|
| 305 |
+
logger.exception('HF processor failed; falling back to torchvision preprocessing')
|
| 306 |
+
hf_local_fallback = True
|
| 307 |
+
else:
|
| 308 |
+
hf_local_fallback = False
|
| 309 |
+
else:
|
| 310 |
+
hf_local_fallback = True
|
| 311 |
+
|
| 312 |
+
if hf_local_fallback:
|
| 313 |
+
# Define preprocessing step by step to avoid namespace issues
|
| 314 |
+
resize = transforms.Resize((320, 320))
|
| 315 |
+
to_tensor = transforms.ToTensor()
|
| 316 |
+
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 317 |
+
|
| 318 |
+
# Apply transformations step by step
|
| 319 |
+
resized_image = resize(image)
|
| 320 |
+
tensor_image = to_tensor(resized_image)
|
| 321 |
+
normalized_tensor = normalize(tensor_image)
|
| 322 |
+
input_tensor = normalized_tensor.unsqueeze(0).to(DEVICE)
|
| 323 |
|
| 324 |
# Prediction
|
| 325 |
with torch.no_grad():
|
|
|
|
| 337 |
except Exception:
|
| 338 |
logits = outputs
|
| 339 |
|
| 340 |
+
|
| 341 |
# Ensure logits is a tensor
|
| 342 |
if not isinstance(logits, torch.Tensor):
|
| 343 |
logits = torch.tensor(np.asarray(logits)).to(DEVICE)
|
| 344 |
|
| 345 |
+
# Handle unexpected tensor shapes:
|
| 346 |
+
# - if logits has spatial dims (e.g., 4D), average them
|
| 347 |
+
# - if logits is 1D, unsqueeze batch dim
|
| 348 |
+
try:
|
| 349 |
+
if logits.dim() > 2:
|
| 350 |
+
# Average over every dimension from index 2 onward, keeping the first two dims
|
| 351 |
+
reduce_dims = tuple(range(2, logits.dim()))
|
| 352 |
+
logits = logits.mean(dim=reduce_dims)
|
| 353 |
+
if logits.dim() == 1:
|
| 354 |
+
logits = logits.unsqueeze(0)
|
| 355 |
+
except Exception:
|
| 356 |
+
# If the shape operations fail, log the exception and return a safe error result
|
| 357 |
+
logger.exception('Failed to normalize logits shape')
|
| 358 |
+
return {"Error": 0.0}
|
| 359 |
+
|
| 360 |
+
# If single-logit output, treat as sigmoid probability
|
| 361 |
+
if logits.size(1) == 1:
|
| 362 |
+
probs = torch.sigmoid(logits)
|
| 363 |
+
# Return the single probability keyed by the first class name, or by the generic 'Class_0' label when no class names are available
|
| 364 |
+
prob = float(probs[0, 0].item())
|
| 365 |
+
label = class_names[0].replace('_', ' ') if class_names else 'Class_0'
|
| 366 |
+
return {label: prob}
|
| 367 |
+
|
| 368 |
probabilities = F.softmax(logits, dim=1)
|
| 369 |
# Get top 5 predictions
|
| 370 |
top5_prob, top5_indices = torch.topk(probabilities, min(5, probabilities.shape[1]), dim=1)
|
|
|
|
| 384 |
return results
|
| 385 |
|
| 386 |
except Exception as e:
|
| 387 |
+
# Log exception and return a numeric-friendly error response for Gradio
|
| 388 |
+
logger.exception('Prediction failed')
|
| 389 |
+
print('Prediction failed:', e)
|
| 390 |
+
return {"Error": 0.0}
|
| 391 |
|
| 392 |
# Create Gradio interface
|
| 393 |
title = "🐦 Bird Species Classifier"
|