|
|
"""
|
|
|
Gradio App for Bird Species Classification
|
|
|
Deployed on Hugging Face Spaces
|
|
|
"""
|
|
|
|
|
|
import gradio as gr
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from torchvision import transforms
|
|
|
from torchvision.models import convnext_base
|
|
|
from PIL import Image
|
|
|
import json
|
|
|
|
|
|
|
|
|
# Species-name lookup. predict() reads it with class_names.get(str(class_id)),
# so the file is expected to be a JSON object keyed by the class index as a
# string. JSON text is UTF-8 by definition; decode it explicitly rather than
# relying on the platform's locale encoding.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
|
|
|
|
def create_model(num_classes=200):
    """
    Build a ConvNeXt-Base network with the custom classifier head used in training.

    The stock torchvision head is swapped for a deeper MLP head. The layer
    order must match training exactly, because nn.Sequential names its
    children by position and the checkpoint's state_dict keys depend on that.

    Args:
        num_classes: Number of output logits produced by the final layer.

    Returns:
        nn.Module: Randomly initialised model (weights=None); trained weights
        are loaded separately.
    """
    backbone = convnext_base(weights=None)

    # Input width of the stock head's final Linear layer.
    feature_dim = backbone.classifier[2].in_features

    head_layers = [
        nn.Flatten(1),
        nn.LayerNorm((feature_dim,)),
        nn.Dropout(0.6),
        nn.Linear(feature_dim, 512),
        nn.GELU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes),
    ]
    backbone.classifier = nn.Sequential(*head_layers)

    return backbone
|
|
|
|
|
|
|
|
|
print("Loading model...")
model = create_model(num_classes=200)

# NOTE: torch.load unpickles the file, so only point it at trusted checkpoints.
# Tensors are loaded onto the CPU and copied into the model's (CPU) parameters;
# the model is then moved to `device` once, below. This avoids materialising an
# extra copy of every weight on the GPU.
checkpoint = torch.load('models/final_model.pth', map_location='cpu')

if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    # Full training checkpoint: weights plus optional metadata.
    model.load_state_dict(checkpoint['model_state_dict'])
    val_acc = checkpoint.get('val_acc')
    if val_acc is not None:
        print(f"Model loaded! Validation accuracy: {val_acc:.2f}%")
    else:
        # Report success even when no accuracy was stored with the weights.
        print("Model loaded!")
else:
    # Bare state_dict saved directly.
    model.load_state_dict(checkpoint)
    print("Model loaded!")

# load_state_dict copied the weights into the model. Drop the checkpoint so a
# second full copy of the weights is not kept alive as a module global.
del checkpoint

model = model.to(device)
model.eval()
|
|
|
|
|
|
|
|
|
# Inference preprocessing: squash to a fixed 224x224, convert to a float
# tensor, then normalise each channel with the standard ImageNet mean/std.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)
|
|
|
|
|
|
def predict(image):
    """
    Classify an uploaded bird image.

    Args:
        image: PIL Image, as supplied by ``gr.Image(type="pil")``. Gradio
            passes None when the user submits without uploading anything.

    Returns:
        dict: Species name -> probability for the top 5 classes (fewer if the
        model has fewer than 5 outputs), in descending order of probability.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of a transform traceback.
        raise gr.Error("Please upload a bird image first.")

    # The 3-channel Normalize and the model need RGB input. Without this,
    # RGBA, greyscale and palette images produce a tensor with the wrong
    # number of channels.
    image = image.convert("RGB")
    img_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

    # Never request more entries than the model has classes.
    k = min(5, probabilities.size(1))
    top_prob, top_idx = torch.topk(probabilities, k)

    results = {}
    for class_id, prob in zip(top_idx[0].tolist(), top_prob[0].tolist()):
        species_name = class_names.get(str(class_id), f"Class {class_id}")
        results[species_name] = float(prob)

    return results
|
|
|
|
|
|
|
|
|
# Markdown text shown in the Gradio interface. These strings previously held
# mis-decoded UTF-8 (e.g. "π¦" for the bird emoji, "β" for the arrow and
# check mark); they are restored to the intended characters here.
title = "🐦 Bird Species Classification"

description = """
Upload an image of a bird and the model will predict the species!

**Model Details:**
- Architecture: ConvNeXt-Base (87M parameters)
- Dataset: CUB-200-2011 (200 bird species)
- Test Accuracy: 83.64%
- Average Per-Class Accuracy: 83.29%

**Training Strategy:**
- Transfer Learning with ImageNet pretrained weights
- Two-phase training: Frozen backbone (40 epochs) → Fine-tuning (20 epochs)
- Strong regularization: Dropout (0.6, 0.5), Label smoothing (0.2)
- Data augmentation: Rotation, flip, color jitter, random erasing

Upload a clear image of a bird to get started!
"""

article = """
### About This Model

This bird classifier was trained on the CUB-200-2011 dataset containing 200 North American bird species.
The model uses ConvNeXt-Base architecture with modern training techniques to achieve high accuracy while
preventing overfitting.

**Key Features:**
- ✅ 200 bird species classification
- ✅ State-of-the-art ConvNeXt architecture
- ✅ 83.64% test accuracy
- ✅ Real-time inference

**Best Results:** Upload high-quality images with the bird clearly visible and centered.
"""
|
|
|
|
|
|
# Example images offered in the UI. The list is currently empty, so the
# interface is built without an examples panel.
examples = []

# Input/output components, created in the same order as before
# (image input first, then the label output).
bird_image_input = gr.Image(type="pil", label="Upload Bird Image")
top5_label_output = gr.Label(num_top_classes=5, label="Top 5 Predictions")

iface = gr.Interface(
    fn=predict,
    inputs=bird_image_input,
    outputs=top5_label_output,
    title=title,
    description=description,
    article=article,
    # Pass None rather than an empty list when no examples are configured.
    examples=examples or None,
    theme=gr.themes.Soft(),
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the Gradio version pinned for this Space.
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()
|
|
|
|