"""
Gradio App for Bird Classification - Hugging Face Deployment
Enhanced model with architecture auto-detection and error handling.
"""
import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
import json
import numpy as np
from torchvision import transforms
import os
import logging

# Import our model architecture
from models import create_model

# Optional: Hugging Face imports (used only when evaluating HF-format checkpoints)
try:
    import transformers
    from transformers import AutoConfig, AutoModelForImageClassification
    HF_AVAILABLE = True
except Exception:
    transformers = None
    HF_AVAILABLE = False

# HF image processor (AutoImageProcessor or AutoFeatureExtractor) will be stored here when available
hf_processor = None

# Configuration
# Default to the moved fine-tuned checkpoint if present
MODEL_PATH = os.environ.get('MODEL_PATH', 'best_model_finetuned.pth')
# Optional: if your HF model id is known (e.g. Emiel/cub-200-bird-classifier-swin), set HF_MODEL_ID env var
HF_MODEL_ID = os.environ.get('HF_MODEL_ID', None)
CLASS_NAMES_PATH = os.environ.get('CLASS_NAMES_PATH', 'class_names.json')
FORCE_HF_LOAD = os.environ.get('FORCE_HF_LOAD', '0').lower() in ('1', 'true', 'yes')
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Default HF model id to try when checkpoint looks HF-like and HF_MODEL_ID not set
DEFAULT_HF_ID = 'Emiel/cub-200-bird-classifier-swin'

# Setup file logger for traceability in Spaces
LOG_FILE = os.environ.get('APP_LOG_PATH', 'app.log')
logging.basicConfig(level=logging.INFO, filename=LOG_FILE, filemode='a',
                    format='%(asctime)s %(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

# Load class names
if os.path.exists(CLASS_NAMES_PATH):
    try:
        with open(CLASS_NAMES_PATH, 'r') as f:
            class_names = json.load(f)
    except Exception:
        class_names = []
else:
    class_names = []

NUM_CLASSES = len(class_names)


def load_checkpoint_model(model_path, device):
    """Attempt to load a checkpoint. Supports local create_model-based checkpoints and
    heuristic handling for Hugging Face (Swin) checkpoints when HF_MODEL_ID is set.
    Returns (model, actual_num_classes) or (None, None) on failure.
    """
    # Allow writing to module-level hf_processor
    global hf_processor

    # If user wants to force HF loading from hub, try that first (useful in Spaces)
    if FORCE_HF_LOAD and HF_MODEL_ID and HF_AVAILABLE:
        try:
            print(f"FORCE_HF_LOAD enabled: loading HF model from hub: {HF_MODEL_ID}")
            hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID)
            hf_model.to(device)
            hf_model.eval()
            # Try to load a matching image processor from the hub for preprocessing
            try:
                if transformers is not None:
                    hf_processor = transformers.AutoImageProcessor.from_pretrained(HF_MODEL_ID)
                    print("Loaded HF image processor from hub (force load)")
            except Exception:
                print("Warning: failed to load HF image processor for forced hub model")
            num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
            print(f"Loaded HF model from hub with {num_labels} labels (force)")
            return hf_model, num_labels
        except Exception as e:
            print("Forced HF hub load failed:", e)

    if not os.path.exists(model_path):
        msg = f"Model file not found at {model_path}"
        print(msg)
        logger.info(msg)
        # If HF_MODEL_ID is set and transformers are available, try to load from hub
        if HF_MODEL_ID and HF_AVAILABLE:
            try:
                print(f"Attempting to load model from Hugging Face Hub: {HF_MODEL_ID}")
                hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID)
                hf_model.to(device)
                hf_model.eval()
                # Try to load image processor for preprocessing
                try:
                    if transformers is not None:
                        hf_processor = transformers.AutoImageProcessor.from_pretrained(HF_MODEL_ID)
                        print("Loaded HF image processor from hub")
                except Exception:
                    print("Warning: failed to load HF image processor from hub")
                num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
                print(f"Loaded HF model from hub with {num_labels} labels")
                return hf_model, num_labels
            except Exception as e:
                print("Failed to load HF model from hub:", e)
        return None, None

    print(f"Loading checkpoint from: {model_path}")
    logger.info(f"Loading checkpoint from: {model_path}")
    try:
        ckpt = torch.load(model_path, map_location='cpu')
    except Exception as e:
        print("Failed to load checkpoint file:", e)
        logger.exception("Failed to load checkpoint file:")
        ckpt = {}
    # unwrap common dict wrapper (support both 'model_state_dict' and 'state_dict')
    state_dict = {}
    if isinstance(ckpt, dict):
        if 'model_state_dict' in ckpt and isinstance(ckpt['model_state_dict'], dict):
            state_dict = ckpt['model_state_dict']
        elif 'state_dict' in ckpt and isinstance(ckpt['state_dict'], dict):
            state_dict = ckpt['state_dict']
        else:
            # fallback: ckpt may already be a state dict
            state_dict = ckpt

    # If the state_dict is a single-key wrapper (e.g., {'state_dict': {...}} or {'model': {...}}), unwrap one more level
    if isinstance(state_dict, dict) and len(state_dict) == 1:
        sole_val = next(iter(state_dict.values()))
        if isinstance(sole_val, dict):
            # adopt inner dict as state_dict if it looks like parameters
            inner_keys = list(sole_val.keys())[:8]
            # Heuristic: keys with '.' and numeric shapes indicate a param dict
            if any('.' in k for k in inner_keys):
                logger.info(f"Unwrapping single-key checkpoint wrapper, inner keys sample: {inner_keys}")
                state_dict = sole_val

    # Diagnostic: print a few state_dict keys so we can tell checkpoint format
    try:
        sample_keys = list(state_dict.keys())[:16]
        print("Checkpoint sample keys:", sample_keys)
        logger.info(f"Checkpoint sample keys: {sample_keys}")
    except Exception:
        print("No state_dict keys to sample")
        logger.info("No state_dict keys to sample")

    # Heuristic: detect HF-style Swin checkpoint by looking for keys that start with 'swin.'
    # Detect HF-like keys; strip common 'module.' prefix before checking
    def key_is_hf_like(k: str) -> bool:
        kk = k.replace('module.', '')
        kk = kk.lower()
        return kk.startswith('swin.') or 'swin.embeddings' in kk or 'swin.patch_embeddings' in kk

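    # For reference, transformers' Swin image classifiers store weights under key
    # names such as 'swin.embeddings.patch_embeddings.projection.weight', which is
    # the pattern the heuristic above matches on.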
    hf_like = any(key_is_hf_like(k) for k in state_dict.keys()) if state_dict else False
    hf_msg = f"hf_like_checkpoint_detected={hf_like} HF_AVAILABLE={HF_AVAILABLE} HF_MODEL_ID={'set' if HF_MODEL_ID else 'not-set'}"
    print(hf_msg)
    logger.info(hf_msg)

    if hf_like and HF_AVAILABLE:
        # choose which HF id to use: env var or default
        hf_id_to_use = HF_MODEL_ID or DEFAULT_HF_ID
        if HF_MODEL_ID is None:
            info_msg = f"HF_MODEL_ID not set; using DEFAULT_HF_ID='{DEFAULT_HF_ID}' to attempt hub load"
            print(info_msg)
            logger.info(info_msg)

        try:
            msg = f"Attempting to load Hugging Face model '{hf_id_to_use}' and apply checkpoint weights..."
            print(msg)
            logger.info(msg)
            # prefer using the hub config to instantiate exact architecture
            config = AutoConfig.from_pretrained(hf_id_to_use)
            hf_model = AutoModelForImageClassification.from_config(config)
            # load weights non-strictly: match shapes
            missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
            hf_model.to(device)
            hf_model.eval()
            # Try to fetch image processor for this hf id so we can preprocess in predict
            try:
                if transformers is not None:
                    hf_processor = transformers.AutoImageProcessor.from_pretrained(hf_id_to_use)
                    print(f"Loaded HF image processor for {hf_id_to_use}")
            except Exception:
                print(f"Warning: failed to load HF image processor for {hf_id_to_use}")
            ok_msg = f"Loaded HF model with non-strict state_dict (missing {len(missing)} keys, unexpected {len(unexpected)} keys)"
            print(ok_msg)
            logger.info(ok_msg)
            num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
            return hf_model, num_labels
        except Exception as e:
            print("HF load failed:", e)
            logger.exception("HF load failed")
            print("Falling back to local model loader...")
            logger.info("Falling back to local model loader")

    # Fallback: try to detect EfficientNet-like shapes and create local model
    # Determine actual num classes by inspecting a likely classifier weight key
    actual_classes = NUM_CLASSES
    for k, v in state_dict.items():
        if k.endswith('classifier.9.weight') or k.endswith('classifier.weight'):
            try:
                actual_classes = v.shape[0]
                break
            except Exception:
                pass

    # Heuristic to choose an EfficientNet variant based on conv head size
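    # The width of the final 1x1 conv head identifies the variant:
    # 1280 channels -> B0/B1, 1408 -> B2, 1536 -> B3.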
    model_type = 'efficientnet_b2'
    if state_dict:
        if 'backbone._conv_head.weight' in state_dict:
            try:
                conv_head_shape = state_dict['backbone._conv_head.weight'].shape
                if conv_head_shape[0] == 1536:
                    model_type = 'efficientnet_b3'
                elif conv_head_shape[0] == 1408:
                    model_type = 'efficientnet_b2'
                elif conv_head_shape[0] == 1280:
                    model_type = 'efficientnet_b1'
            except Exception:
                pass

    print(f"Creating local model {model_type} with {actual_classes} classes (fallback)")
    model = create_model(num_classes=actual_classes, model_type=model_type, pretrained=False, dropout_rate=0.3)
    # Try to load state dict
    try:
        # if ckpt was a dict without model_state_dict, attempt to load directly
        to_load = state_dict if state_dict else ckpt
        model.load_state_dict(to_load, strict=False)
        model.to(device)
        model.eval()
        print("✅ Local model loaded (non-strict).")
        return model, actual_classes
    except Exception as e:
        print("Failed to load local model:", e)
        return None, None


# Load model
print("Loading model...", MODEL_PATH)
model, actual_classes = load_checkpoint_model(MODEL_PATH, DEVICE)
if model is None:
    print("No model available. The app will still launch but predictions will fail.")
else:
    print(f"Model ready. Classes={actual_classes}")
    # If this is a Hugging Face model with id2label, prefer that mapping
    try:
        hf_config = getattr(model, 'config', None)
        if hf_config is not None:
            id2label = getattr(hf_config, 'id2label', None)
            if id2label:
                # id2label keys may be strings or ints
                # Build ordered class_names list by index
                max_idx = max(int(k) for k in id2label.keys())
                hf_class_names = [""] * (max_idx + 1)
                for k, v in id2label.items():
                    hf_class_names[int(k)] = v.replace(' ', '_') if isinstance(v, str) else str(v)
                # Filter out empty entries
                hf_class_names = [c for c in hf_class_names if c]
                if len(hf_class_names) > 0:
                    class_names = hf_class_names
                    NUM_CLASSES = len(class_names)
                    print(f"Using Hugging Face id2label mapping with {NUM_CLASSES} classes")
    except Exception as e:
        print("Warning: failed to extract id2label from HF model config:", e)

    # Warn if class_names.json doesn't match model classes
    if class_names and actual_classes and len(class_names) != actual_classes:
        print(f"Warning: class_names.json has {len(class_names)} entries but model expects {actual_classes} classes.")
        # If HF labels exist and match expected size, prefer them
        if len(class_names) < actual_classes:
            print("Note: consider updating class_names.json to match the model's label order or set HF_MODEL_ID to use id2label mapping.")

def predict_bird(image):
    """
    Predict bird species from uploaded image.
    """
    try:
        # Preprocess image
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype('uint8'))
        
        # Convert to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')
        
        # If an HF image processor was loaded alongside a HF model, prefer it for preprocessing
        if hf_processor is not None:
            try:
                # AutoImageProcessor expects PIL images or numpy arrays; return_tensors='pt' gives PyTorch tensors
                proc = hf_processor(images=image, return_tensors='pt')
                # Most processors return the tensor under the 'pixel_values' key
                if 'pixel_values' in proc:
                    input_tensor = proc['pixel_values'].to(DEVICE)
                else:
                    # Fall back to first tensor-like value
                    val = next(iter(proc.values()))
                    input_tensor = val.to(DEVICE)
            except Exception:
                logger.exception('HF processor failed; falling back to torchvision preprocessing')
                hf_local_fallback = True
            else:
                hf_local_fallback = False
        else:
            hf_local_fallback = True

        if hf_local_fallback:
            # Define preprocessing step by step to avoid namespace issues
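            # 320x320 matches the documented training resolution; the mean/std
            # values below are the standard ImageNet normalization statistics.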
            resize = transforms.Resize((320, 320))
            to_tensor = transforms.ToTensor()
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            
            # Apply transformations step by step
            resized_image = resize(image)
            tensor_image = to_tensor(resized_image)
            normalized_tensor = normalize(tensor_image)
            input_tensor = normalized_tensor.unsqueeze(0).to(DEVICE)
        
        # Prediction
        with torch.no_grad():
            outputs = model(input_tensor)

            # Handle Hugging Face ModelOutput objects
            try:
                # HF ModelOutput may be dict-like with a 'logits' attribute
                if hasattr(outputs, 'logits'):
                    logits = outputs.logits
                elif isinstance(outputs, (tuple, list)):
                    logits = outputs[0]
                else:
                    logits = outputs
            except Exception:
                logits = outputs


            # Ensure logits is a tensor
            if not isinstance(logits, torch.Tensor):
                logits = torch.tensor(np.asarray(logits)).to(DEVICE)

            # Handle unexpected tensor shapes:
            # - if logits has spatial dims (e.g., 4D), average them
            # - if logits is 1D, unsqueeze batch dim
            try:
                if logits.dim() > 2:
                    # average over all dims after channel dim
                    reduce_dims = tuple(range(2, logits.dim()))
                    logits = logits.mean(dim=reduce_dims)
                if logits.dim() == 1:
                    logits = logits.unsqueeze(0)
            except Exception:
                # if shape ops fail, log and return safe error
                logger.exception('Failed to normalize logits shape')
                return {"Error": 0.0}

            # If single-logit output, treat as sigmoid probability
            if logits.size(1) == 1:
                probs = torch.sigmoid(logits)
                # return single-label prob mapped to first class or generic
                prob = float(probs[0, 0].item())
                label = class_names[0].replace('_', ' ') if class_names else 'Class_0'
                return {label: prob}

            probabilities = F.softmax(logits, dim=1)
            # Get top 5 predictions
            top5_prob, top5_indices = torch.topk(probabilities, min(5, probabilities.shape[1]), dim=1)

            # Format results
            results = {}
            for i in range(top5_indices.shape[1]):
                class_idx = int(top5_indices[0][i].item())
                prob = float(top5_prob[0][i].item())
                # Handle potential class index mismatch
                if class_idx < len(class_names):
                    class_name = class_names[class_idx].replace('_', ' ')
                else:
                    class_name = "Class_" + str(class_idx)
                results[class_name] = prob

        return results
        
    except Exception as e:
        # Log exception and return a numeric-friendly error response for Gradio
        logger.exception('Prediction failed')
        print('Prediction failed:', e)
        return {"Error": 0.0}

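# A minimal local sanity check (the image path is hypothetical; run manually if desired):
#   from PIL import Image
#   print(predict_bird(Image.open("sample_bird.jpg")))
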
# Create Gradio interface
title = "🐦 Bird Species Classifier"
description = """
## Advanced Bird Classification Model

This model can classify **199 different bird species** using advanced deep learning techniques:

### Model Details:
- **Architecture**: Auto-detected EfficientNet backbone with enhanced regularization, with an optional fine-tuned Swin Transformer checkpoint (Emiel/cub-200-bird-classifier-swin)
- **Training Strategy**: Progressive training with advanced augmentation
- **Performance**: Optimized for accuracy and reliability
- **Dataset**: CUB-200-2011 (200 bird species)

### How to use:
1. Upload a clear image of a bird
2. The model will predict the top 5 most likely species
3. Confidence scores show the model's certainty

### Best Results Tips:
- Use high-quality, well-lit images
- Ensure the bird is clearly visible
- Close-up shots work better than distant ones
- Natural lighting produces better results

**Note**: This model was trained on the CUB-200-2011 dataset and works best with North American bird species.
"""

article = """
### Technical Implementation:
- **Framework**: PyTorch with auto-detected EfficientNet backbone
- **Training**: Progressive training with advanced augmentation strategies
- **Regularization**: Optimized dropout rates and comprehensive validation
- **Image Size**: 320x320 pixels for optimal detail capture

### About the Model:
This bird classifier was developed using advanced machine learning techniques including:
- Transfer learning from ImageNet-pretrained EfficientNet
- Progressive training strategy across multiple stages
- Advanced data augmentation for improved generalization
- Comprehensive evaluation and optimization

The model automatically detects the correct architecture (EfficientNet-B1, B2, or B3, or a Hugging Face Swin checkpoint) from the saved weights,
ensuring compatibility and optimal performance.

For more details about the training process and methodology, please refer to the repository documentation.
"""

# Create the interface
iface = gr.Interface(
    fn=predict_bird,
    inputs=gr.Image(type="pil", label="Upload Bird Image"),
    outputs=gr.Label(num_top_classes=5, label="Predictions"),
    title=title,
    description=description,
    article=article,
    examples=[
        # You can add example images here if you have them
    ],
    allow_flagging="never",
    theme=gr.themes.Soft()
)

if __name__ == "__main__":
    iface.launch(debug=True)
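
# For reference, a remote-call sketch against a deployed Space (the space id and
# image path are hypothetical; assumes a recent gradio_client is installed):
#   from gradio_client import Client, handle_file
#   client = Client("your-username/bird-classifier")
#   print(client.predict(handle_file("sample_bird.jpg"), api_name="/predict"))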