AshProg's picture
Upload app.py
fd62084 verified
raw
history blame
4.49 kB
"""
Gradio App for Bird Species Classification
Deployed on Hugging Face Spaces
"""
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import convnext_base
from PIL import Image
import json
# Load the class-index -> species-name mapping. predict() looks names up with
# str(class_id), so the JSON is expected to be an object keyed by stringified
# class indices. An explicit encoding keeps species names containing
# non-ASCII characters from depending on the host's locale default.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)

# Device configuration: use the GPU when one is visible, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Model architecture (mirrors the training script).
def create_model(num_classes=200):
    """Build the ConvNeXt-Base network used for inference.

    The backbone is created without pretrained weights (the trained
    checkpoint supplies them later), and its stock classifier is replaced by
    the custom head used during training. Layer order matters: the entries of
    ``nn.Sequential`` are keyed by index, and those ``classifier.<index>``
    keys must match the saved state dict.

    Args:
        num_classes: Number of outputs of the final Linear layer.

    Returns:
        The randomly initialised model; weights must be loaded separately.
    """
    backbone = convnext_base(weights=None)
    # Feature width that fed the original final Linear layer of the head.
    in_features = backbone.classifier[2].in_features
    head_layers = [
        nn.Flatten(1),
        nn.LayerNorm((in_features,)),
        nn.Dropout(0.6),
        nn.Linear(in_features, 512),
        nn.GELU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes),
    ]
    backbone.classifier = nn.Sequential(*head_layers)
    return backbone
# Load the trained model
print("Loading model...")
model = create_model(num_classes=200)

# The checkpoint is either a training-checkpoint dict holding
# 'model_state_dict' (plus optional metadata such as 'val_acc') or a bare
# state dict. map_location lets a checkpoint saved on GPU load on a CPU host.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints. Recent torch releases default to weights_only=True, which
# rejects pickled objects outside an allowlist (e.g. numpy scalars stored as
# metadata) — confirm this file loads on the deployed torch version.
checkpoint = torch.load('models/final_model.pth', map_location=device)
is_training_checkpoint = (
    isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
)
state_dict = (
    checkpoint['model_state_dict'] if is_training_checkpoint else checkpoint
)
model.load_state_dict(state_dict)

# Report success in every case (previously a training checkpoint without
# 'val_acc' loaded silently); include the accuracy when one was recorded.
if is_training_checkpoint and 'val_acc' in checkpoint:
    val_acc = checkpoint['val_acc']
    print(f"Model loaded! Validation accuracy: {val_acc:.2f}%")
else:
    print("Model loaded!")

model = model.to(device)
# Inference mode for layers such as Dropout.
model.eval()
# Inference preprocessing, matching the validation transforms used in
# training: a fixed 224x224 resize (aspect ratio is not preserved), conversion
# to a float tensor, then per-channel normalisation with the ImageNet
# mean/std statistics.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict(image):
    """
    Classify a bird image and return the most likely species.

    Args:
        image: PIL Image from the Gradio ``Image`` input, or ``None`` when the
            user submits without uploading anything.

    Returns:
        dict: Species name -> softmax probability for the top 5 predictions
        (fewer if the model has fewer than 5 classes). An empty dict is
        returned when no image was provided.
    """
    if image is None:
        return {}

    # The network expects 3-channel input; grayscale ("L"), palette ("P") or
    # RGBA images would otherwise fail in Normalize or the first conv layer.
    img_tensor = transform(image.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

    # topk raises when k exceeds the number of classes, so clamp it.
    k = min(5, probabilities.size(1))
    top_prob, top_idx = torch.topk(probabilities, k)

    # Index 0 selects the single image in the batch.
    results = {}
    for class_id, prob in zip(top_idx[0].tolist(), top_prob[0].tolist()):
        species_name = class_names.get(str(class_id), f"Class {class_id}")
        results[species_name] = float(prob)
    return results
# Create Gradio interface
# Markdown shown above (description) and below (article) the interface.
# Blank lines separate paragraphs and lists so the headings do not merge into
# the preceding paragraph and the closing sentences are not absorbed into the
# last list item.
title = "🐦 Bird Species Classification"
description = """
Upload an image of a bird and the model will predict the species!

**Model Details:**
- Architecture: ConvNeXt-Base (87M parameters)
- Dataset: CUB-200-2011 (200 bird species)
- Test Accuracy: 83.64%
- Average Per-Class Accuracy: 83.29%

**Training Strategy:**
- Transfer Learning with ImageNet pretrained weights
- Two-phase training: Frozen backbone (40 epochs) → Fine-tuning (20 epochs)
- Strong regularization: Dropout (0.6, 0.5), Label smoothing (0.2)
- Data augmentation: Rotation, flip, color jitter, random erasing

Upload a clear image of a bird to get started!
"""
article = """
### About This Model

This bird classifier was trained on the CUB-200-2011 dataset containing 200 North American bird species.
The model uses ConvNeXt-Base architecture with modern training techniques to achieve high accuracy while
preventing overfitting.

**Key Features:**
- ✅ 200 bird species classification
- ✅ State-of-the-art ConvNeXt architecture
- ✅ 83.64% test accuracy
- ✅ Real-time inference

**Best Results:** Upload high-quality images with the bird clearly visible and centered.
"""
# No bundled example images yet. To add some, give one entry per input, e.g.
#     examples = [["examples/bird1.jpg"], ["examples/bird2.jpg"]]
examples = []

# Single image in, top-5 label out.
# NOTE(review): `allow_flagging` is the Gradio 4.x parameter name; newer
# Gradio releases use `flagging_mode` instead — confirm against the pinned
# gradio version.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Bird Image"),
    outputs=gr.Label(num_top_classes=5, label="Top 5 Predictions"),
    title=title,
    description=description,
    article=article,
    # Pass None rather than an empty list when no examples are configured.
    examples=examples or None,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)

# Launch the app only when run as a script.
if __name__ == "__main__":
    iface.launch()