File size: 4,490 Bytes
fd62084 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
"""
Gradio App for Bird Species Classification
Deployed on Hugging Face Spaces
"""
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import convnext_base
from PIL import Image
import json
# Load the class-index -> species-name mapping used to label predictions.
# predict() looks names up as class_names.get(str(class_id)), so the JSON is
# expected to be an object keyed by stringified class indices. A plain JSON
# list (index == position) is converted into that same form so lookups work.
# Explicit UTF-8 keeps decoding independent of the host's locale encoding.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)
if isinstance(class_names, list):
    class_names = {str(idx): name for idx, name in enumerate(class_names)}

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Model architecture (must match the training-time definition exactly).
def create_model(num_classes=200):
    """Build the ConvNeXt-Base network with the training-time classifier head.

    The backbone is created with ``weights=None`` (random initialisation);
    the fine-tuned parameters are expected to be loaded afterwards with
    ``load_state_dict``. The replacement head must have exactly the same
    layer order and sizes as during training, otherwise the checkpoint keys
    and shapes will not line up.

    Args:
        num_classes: Number of output logits (200 for CUB-200-2011).

    Returns:
        The randomly initialised ``nn.Module``.
    """
    backbone = convnext_base(weights=None)
    # Index 2 of the stock classifier is the final Linear layer; reuse its
    # input width as the feature size for the custom head.
    feature_dim = backbone.classifier[2].in_features
    head_layers = [
        nn.Flatten(1),
        nn.LayerNorm((feature_dim,)),
        nn.Dropout(0.6),
        nn.Linear(feature_dim, 512),
        nn.GELU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes),
    ]
    backbone.classifier = nn.Sequential(*head_layers)
    return backbone
# Build the network and load the fine-tuned weights onto `device`.
print("Loading model...")
model = create_model(num_classes=200)

# NOTE: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source (presumably the file bundled
# with this app).
checkpoint = torch.load('models/final_model.pth', map_location=device)
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    # Training checkpoint: weights plus optional metadata such as val_acc.
    model.load_state_dict(checkpoint['model_state_dict'])
    val_acc = checkpoint.get('val_acc')
else:
    # Presumably a bare state dict (torch.save(model.state_dict(), ...)).
    model.load_state_dict(checkpoint)
    val_acc = None

# Always confirm the load; include the validation accuracy when recorded.
if val_acc is not None:
    print(f"Model loaded! Validation accuracy: {val_acc:.2f}%")
else:
    print("Model loaded!")

model = model.to(device)
# Inference mode: disables the Dropout layers in the classifier head.
model.eval()
# Inference preprocessing, intended to mirror the validation transforms used
# during training: a fixed 224x224 resize (aspect ratio is not preserved),
# conversion to a float tensor, then normalisation with the ImageNet
# per-channel mean and standard deviation.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict(image, top_k=5):
    """Classify a bird image and return the most likely species.

    Args:
        image: PIL Image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.
        top_k: How many of the highest-probability classes to return.
            Clamped to the number of classes the model outputs.

    Returns:
        dict mapping species name -> probability (float) for the ``top_k``
        most likely classes, or ``None`` when no image was provided (which
        leaves the Gradio label output empty).
    """
    if image is None:
        return None

    # ToTensor keeps the source channel count. RGBA, grayscale or palette
    # uploads would not match the 3-channel Normalize / ConvNeXt input, so
    # force RGB first.
    image = image.convert("RGB")
    img_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

    # torch.topk raises if k exceeds the size of the class dimension.
    k = min(top_k, probabilities.size(1))
    top_prob, top_idx = torch.topk(probabilities, k)

    results = {}
    for prob, class_id in zip(top_prob[0].tolist(), top_idx[0].tolist()):
        species_name = class_names.get(str(class_id), f"Class {class_id}")
        results[species_name] = float(prob)
    return results
# Gradio interface text. `description` and `article` are rendered as Markdown,
# so paragraphs, bold labels and bullet lists are separated by blank lines.
# Without them, a line that directly follows a bullet is folded into that
# bullet, and a bold label merges into the preceding paragraph.
title = "🐦 Bird Species Classification"

description = """
Upload an image of a bird and the model will predict the species!

**Model Details:**

- Architecture: ConvNeXt-Base (87M parameters)
- Dataset: CUB-200-2011 (200 bird species)
- Test Accuracy: 83.64%
- Average Per-Class Accuracy: 83.29%

**Training Strategy:**

- Transfer Learning with ImageNet pretrained weights
- Two-phase training: Frozen backbone (40 epochs) → Fine-tuning (20 epochs)
- Strong regularization: Dropout (0.6, 0.5), Label smoothing (0.2)
- Data augmentation: Rotation, flip, color jitter, random erasing

Upload a clear image of a bird to get started!
"""

article = """
### About This Model

This bird classifier was trained on the CUB-200-2011 dataset containing 200 North American bird species.
The model uses ConvNeXt-Base architecture with modern training techniques to achieve high accuracy while
preventing overfitting.

**Key Features:**

- ✅ 200 bird species classification
- ✅ State-of-the-art ConvNeXt architecture
- ✅ 83.64% test accuracy
- ✅ Real-time inference

**Best Results:** Upload high-quality images with the bird clearly visible and centered.
"""

# Optional example inputs, one inner list per example, e.g.
# ["examples/bird1.jpg"]. Add entries here if example images are available.
examples = []

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Bird Image"),
    outputs=gr.Label(num_top_classes=5, label="Top 5 Predictions"),
    title=title,
    description=description,
    article=article,
    # An empty examples list is passed as None.
    examples=examples or None,
    theme=gr.themes.Soft(),
    # NOTE(review): newer Gradio releases replace `allow_flagging` with
    # `flagging_mode`; confirm against the Gradio version pinned for the Space.
    allow_flagging="never",
)

# Launch the web server when this file is run as a script.
if __name__ == "__main__":
    iface.launch()
|