""" Gradio App for Bird Classification - Hugging Face Deployment Enhanced model with architecture auto-detection and error handling. """ import gradio as gr import torch import torch.nn.functional as F from PIL import Image import json import numpy as np from torchvision import transforms import os import logging # Import our model architecture from models import create_model # Optional: Hugging Face imports (used only when evaluating HF-format checkpoints) try: import transformers from transformers import AutoConfig, AutoModelForImageClassification HF_AVAILABLE = True except Exception: transformers = None HF_AVAILABLE = False # HF image processor (AutoImageProcessor or AutoFeatureExtractor) will be stored here when available hf_processor = None # Configuration # Default to the moved fine-tuned checkpoint if present MODEL_PATH = os.environ.get('MODEL_PATH', os.path.join('best_model_finetuned.pth')) # Optional: if your HF model id is known (e.g. Emiel/cub-200-bird-classifier-swin), set HF_MODEL_ID env var HF_MODEL_ID = os.environ.get('HF_MODEL_ID', None) CLASS_NAMES_PATH = os.environ.get('CLASS_NAMES_PATH', 'class_names.json') FORCE_HF_LOAD = os.environ.get('FORCE_HF_LOAD', '0').lower() in ('1', 'true', 'yes') DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Default HF model id to try when checkpoint looks HF-like and HF_MODEL_ID not set DEFAULT_HF_ID = 'Emiel/cub-200-bird-classifier-swin' # Setup file logger for traceability in Spaces LOG_FILE = os.environ.get('APP_LOG_PATH', 'app.log') logging.basicConfig(level=logging.INFO, filename=LOG_FILE, filemode='a', format='%(asctime)s %(levelname)s: %(message)s') logger = logging.getLogger(__name__) # Load class names if os.path.exists(CLASS_NAMES_PATH): try: with open(CLASS_NAMES_PATH, 'r') as f: class_names = json.load(f) except Exception: class_names = [] else: class_names = [] NUM_CLASSES = len(class_names) def load_checkpoint_model(model_path, device): """Attempt to load a checkpoint. Supports local create_model-based checkpoints and heuristic handling for Hugging Face (Swin) checkpoints when HF_MODEL_ID is set. Returns (model, actual_num_classes) or (None, None) on failure. """ # Allow writing to module-level hf_processor global hf_processor # If user wants to force HF loading from hub, try that first (useful in Spaces) if FORCE_HF_LOAD and HF_MODEL_ID and HF_AVAILABLE: try: print(f"FORCE_HF_LOAD enabled: loading HF model from hub: {HF_MODEL_ID}") hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID) hf_model.to(device) hf_model.eval() # Try to load a matching image processor from the hub for preprocessing try: if transformers is not None: hf_processor = transformers.AutoImageProcessor.from_pretrained(HF_MODEL_ID) print("Loaded HF image processor from hub (force load)") except Exception: print("Warning: failed to load HF image processor for forced hub model") num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES) print(f"Loaded HF model from hub with {num_labels} labels (force)") return hf_model, num_labels except Exception as e: print("Forced HF hub load failed:", e) if not os.path.exists(model_path): msg = f"Model file not found at {model_path}" print(msg) logger.info(msg) # If HF_MODEL_ID is set and transformers are available, try to load from hub if HF_MODEL_ID and HF_AVAILABLE: try: print(f"Attempting to load model from Hugging Face Hub: {HF_MODEL_ID}") hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID) hf_model.to(device) hf_model.eval() # Try to load image processor for preprocessing try: if transformers is not None: hf_processor = transformers.AutoImageProcessor.from_pretrained(HF_MODEL_ID) print("Loaded HF image processor from hub") except Exception: print("Warning: failed to load HF image processor from hub") num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES) print(f"Loaded HF model from hub with {num_labels} labels") return hf_model, num_labels except Exception as e: print("Failed to load HF model from hub:", e) return None, None print(f"Loading checkpoint from: {model_path}") logger.info(f"Loading checkpoint from: {model_path}") try: ckpt = torch.load(model_path, map_location='cpu') except Exception as e: print("Failed to load checkpoint file:", e) logger.exception("Failed to load checkpoint file:") ckpt = {} # unwrap common dict wrapper (support both 'model_state_dict' and 'state_dict') state_dict = {} if isinstance(ckpt, dict): if 'model_state_dict' in ckpt and isinstance(ckpt['model_state_dict'], dict): state_dict = ckpt['model_state_dict'] elif 'state_dict' in ckpt and isinstance(ckpt['state_dict'], dict): state_dict = ckpt['state_dict'] else: # fallback: ckpt may already be a state dict state_dict = ckpt # If the state_dict is a single-key wrapper (e.g., {'state_dict': {...}} or {'model': {...}}), unwrap one more level if isinstance(state_dict, dict) and len(state_dict) == 1: sole_val = next(iter(state_dict.values())) if isinstance(sole_val, dict): # adopt inner dict as state_dict if it looks like parameters inner_keys = list(sole_val.keys())[:8] # Heuristic: keys with '.' and numeric shapes indicate a param dict if any('.' in k for k in inner_keys): logger.info(f"Unwrapping single-key checkpoint wrapper, inner keys sample: {inner_keys}") state_dict = sole_val # Diagnostic: print a few state_dict keys so we can tell checkpoint format try: sample_keys = list(state_dict.keys())[:16] print("Checkpoint sample keys:", sample_keys) logger.info(f"Checkpoint sample keys: {sample_keys}") except Exception: print("No state_dict keys to sample") logger.info("No state_dict keys to sample") # Heuristic: detect HF-style Swin checkpoint by looking for keys that start with 'swin.' # Detect HF-like keys; strip common 'module.' prefix before checking def key_is_hf_like(k: str) -> bool: kk = k.replace('module.', '') kk = kk.lower() return kk.startswith('swin.') or 'swin.embeddings' in kk or 'swin.patch_embeddings' in kk hf_like = any(key_is_hf_like(k) for k in state_dict.keys()) if state_dict else False hf_msg = f"hf_like_checkpoint_detected={hf_like} HF_AVAILABLE={HF_AVAILABLE} HF_MODEL_ID={'set' if HF_MODEL_ID else 'not-set'}" print(hf_msg) logger.info(hf_msg) if hf_like and HF_AVAILABLE: # choose which HF id to use: env var or default hf_id_to_use = HF_MODEL_ID or DEFAULT_HF_ID if HF_MODEL_ID is None: info_msg = f"HF_MODEL_ID not set; using DEFAULT_HF_ID='{DEFAULT_HF_ID}' to attempt hub load" print(info_msg) logger.info(info_msg) try: msg = f"Attempting to load Hugging Face model '{hf_id_to_use}' and apply checkpoint weights..." print(msg) logger.info(msg) # prefer using the hub config to instantiate exact architecture config = AutoConfig.from_pretrained(hf_id_to_use) hf_model = AutoModelForImageClassification.from_config(config) # load weights non-strictly: match shapes missing, unexpected = hf_model.load_state_dict(state_dict, strict=False) hf_model.to(device) hf_model.eval() # Try to fetch image processor for this hf id so we can preprocess in predict try: if transformers is not None: hf_processor = transformers.AutoImageProcessor.from_pretrained(hf_id_to_use) print(f"Loaded HF image processor for {hf_id_to_use}") except Exception: print(f"Warning: failed to load HF image processor for {hf_id_to_use}") ok_msg = f"Loaded HF model with non-strict state_dict (missing {len(missing)} keys, unexpected {len(unexpected)} keys)" print(ok_msg) logger.info(ok_msg) num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES) return hf_model, num_labels except Exception as e: print("HF load failed:", e) logger.exception("HF load failed") print("Falling back to local model loader...") logger.info("Falling back to local model loader") # Fallback: try to detect EfficientNet-like shapes and create local model # Determine actual num classes by inspecting a likely classifier weight key actual_classes = NUM_CLASSES for k, v in state_dict.items(): if k.endswith('classifier.9.weight') or k.endswith('classifier.weight'): try: actual_classes = v.shape[0] break except Exception: pass # Heuristic to choose an EfficientNet variant based on conv head size model_type = 'efficientnet_b2' if state_dict: if 'backbone._conv_head.weight' in state_dict: try: conv_head_shape = state_dict['backbone._conv_head.weight'].shape if conv_head_shape[0] == 1536: model_type = 'efficientnet_b3' elif conv_head_shape[0] == 1408: model_type = 'efficientnet_b2' elif conv_head_shape[0] == 1280: model_type = 'efficientnet_b1' except Exception: pass print(f"Creating local model {model_type} with {actual_classes} classes (fallback)") model = create_model(num_classes=actual_classes, model_type=model_type, pretrained=False, dropout_rate=0.3) # Try to load state dict try: # if ckpt was a dict without model_state_dict, attempt to load directly to_load = state_dict if state_dict else ckpt model.load_state_dict(to_load, strict=False) model.to(device) model.eval() print("✅ Local model loaded (non-strict).") return model, actual_classes except Exception as e: print("Failed to load local model:", e) return None, None # Load model print("Loading model...", MODEL_PATH) model, actual_classes = load_checkpoint_model(MODEL_PATH, DEVICE) if model is None: print("No model available. The app will still launch but predictions will fail.") else: print(f"Model ready. Classes={actual_classes}") # If this is a Hugging Face model with id2label, prefer that mapping try: hf_config = getattr(model, 'config', None) if hf_config is not None: id2label = getattr(hf_config, 'id2label', None) if id2label: # id2label keys may be strings or ints # Build ordered class_names list by index max_idx = max(int(k) for k in id2label.keys()) hf_class_names = [""] * (max_idx + 1) for k, v in id2label.items(): hf_class_names[int(k)] = v.replace(' ', '_') if isinstance(v, str) else str(v) # Filter out empty entries hf_class_names = [c for c in hf_class_names if c] if len(hf_class_names) > 0: class_names = hf_class_names NUM_CLASSES = len(class_names) print(f"Using Hugging Face id2label mapping with {NUM_CLASSES} classes") except Exception as e: print("Warning: failed to extract id2label from HF model config:", e) # Warn if class_names.json doesn't match model classes if class_names and actual_classes and len(class_names) != actual_classes: print(f"Warning: class_names.json has {len(class_names)} entries but model expects {actual_classes} classes.") # If HF labels exist and match expected size, prefer them if len(class_names) < actual_classes: print("Note: consider updating class_names.json to match the model's label order or set HF_MODEL_ID to use id2label mapping.") def predict_bird(image): """ Predict bird species from uploaded image. """ try: # Preprocess image if isinstance(image, np.ndarray): image = Image.fromarray(image.astype('uint8')) # Convert to RGB if needed if image.mode != 'RGB': image = image.convert('RGB') # If an HF image processor was loaded alongside a HF model, prefer it for preprocessing if hf_processor is not None: try: # AutoImageProcessor expects PIL images or numpy arrays; return_tensors='pt' gives PyTorch tensors proc = hf_processor(images=image, return_tensors='pt') # Some processors return 'pixel_values', others return 'pixel_values' key if 'pixel_values' in proc: input_tensor = proc['pixel_values'].to(DEVICE) else: # Fall back to first tensor-like value val = next(iter(proc.values())) input_tensor = val.to(DEVICE) except Exception: logger.exception('HF processor failed; falling back to torchvision preprocessing') hf_local_fallback = True else: hf_local_fallback = False else: hf_local_fallback = True if hf_local_fallback: # Define preprocessing step by step to avoid namespace issues resize = transforms.Resize((320, 320)) to_tensor = transforms.ToTensor() normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Apply transformations step by step resized_image = resize(image) tensor_image = to_tensor(resized_image) normalized_tensor = normalize(tensor_image) input_tensor = normalized_tensor.unsqueeze(0).to(DEVICE) # Prediction with torch.no_grad(): outputs = model(input_tensor) # Handle Hugging Face ModelOutput objects try: # HF ModelOutput may be dict-like with a 'logits' attribute if hasattr(outputs, 'logits'): logits = outputs.logits elif isinstance(outputs, (tuple, list)): logits = outputs[0] else: logits = outputs except Exception: logits = outputs # Ensure logits is a tensor if not isinstance(logits, torch.Tensor): logits = torch.tensor(np.asarray(logits)).to(DEVICE) # Handle unexpected tensor shapes: # - if logits has spatial dims (e.g., 4D), average them # - if logits is 1D, unsqueeze batch dim try: if logits.dim() > 2: # average over all dims after channel dim reduce_dims = tuple(range(2, logits.dim())) logits = logits.mean(dim=reduce_dims) if logits.dim() == 1: logits = logits.unsqueeze(0) except Exception: # if shape ops fail, log and return safe error logger.exception('Failed to normalize logits shape') return {"Error": 0.0} # If single-logit output, treat as sigmoid probability if logits.size(1) == 1: probs = torch.sigmoid(logits) # return single-label prob mapped to first class or generic prob = float(probs[0, 0].item()) label = class_names[0].replace('_', ' ') if class_names else 'Class_0' return {label: prob} probabilities = F.softmax(logits, dim=1) # Get top 5 predictions top5_prob, top5_indices = torch.topk(probabilities, min(5, probabilities.shape[1]), dim=1) # Format results results = {} for i in range(top5_indices.shape[1]): class_idx = int(top5_indices[0][i].item()) prob = float(top5_prob[0][i].item()) # Handle potential class index mismatch if class_idx < len(class_names): class_name = class_names[class_idx].replace('_', ' ') else: class_name = "Class_" + str(class_idx) results[class_name] = prob return results except Exception as e: # Log exception and return a numeric-friendly error response for Gradio logger.exception('Prediction failed') print('Prediction failed:', e) return {"Error": 0.0} # Create Gradio interface title = "🐦 Bird Species Classifier" description = """ ## Advanced Bird Classification Model This model can classify **199 different bird species** using advanced deep learning techniques: ### Model Details: - **Architecture**: Auto-detected EfficientNet (B4) with enhanced regularization & fine-tuned with Swin Transformer (Emiel/cub-200-bird-classifier-swin) - **Training Strategy**: Progressive training with advanced augmentation - **Performance**: Optimized for accuracy and reliability - **Dataset**: CUB-200-2011 (200 bird species) ### How to use: 1. Upload a clear image of a bird 2. The model will predict the top 5 most likely species 3. Confidence scores show the model's certainty ### Best Results Tips: - Use high-quality, well-lit images - Ensure the bird is clearly visible - Close-up shots work better than distant ones - Natural lighting produces better results **Note**: This model was trained on the CUB-200-2011 dataset and works best with North American bird species. """ article = """ ### Technical Implementation: - **Framework**: PyTorch with auto-detected EfficientNet backbone - **Training**: Progressive training with advanced augmentation strategies - **Regularization**: Optimized dropout rates and comprehensive validation - **Image Size**: 320x320 pixels for optimal detail capture ### About the Model: This bird classifier was developed using advanced machine learning techniques including: - Transfer learning from ImageNet-pretrained EfficientNet - Progressive training strategy across multiple stages - Advanced data augmentation for improved generalization - Comprehensive evaluation and optimization The model automatically detects the correct architecture (EfficientNet-B2 or B3) from the saved weights, ensuring compatibility and optimal performance. For more details about the training process and methodology, please refer to the repository documentation. """ # Create the interface iface = gr.Interface( fn=predict_bird, inputs=gr.Image(type="pil", label="Upload Bird Image"), outputs=gr.Label(num_top_classes=5, label="Predictions"), title=title, description=description, article=article, examples=[ # You can add example images here if you have them ], allow_flagging="never", theme=gr.themes.Soft() ) if __name__ == "__main__": iface.launch(debug=True)