# Bird-Classifier / app.py
# Last commit by Akshit-77: "Update app.py" (6725ad3, verified)
import gradio as gr
import torch
import timm
import json
from PIL import Image
import numpy as np
# Load the ordered list of class labels; index i names the model's output
# logit i. JSON text is UTF-8, so decode it explicitly instead of relying on
# the platform's locale encoding (which can mangle non-ASCII label names).
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)
# Select the inference device: a CUDA GPU when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rebuild the architecture used in training. pretrained=False skips the
# ImageNet download because load_state_dict below (strict by default)
# overwrites every parameter; the classifier head is sized to the label list.
model = timm.create_model(
    'swin_base_patch4_window7_224',
    pretrained=False,
    num_classes=len(class_names),
)

# Load the fine-tuned weights. map_location lets a checkpoint saved on a GPU
# load on a CPU-only host. weights_only=True restricts unpickling to tensors
# and basic containers rather than arbitrary Python objects; a plain state
# dict needs nothing more.
state_dict = torch.load('best_swin_birds.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # switch layers such as dropout to inference behaviour
# Build the eval-time preprocessing pipeline (for timm's default eval config:
# resize, center crop, tensor conversion, normalization) from the model's own
# pretrained config. The config depends only on the architecture, so reusing
# the already-built ``model`` gives the same result without instantiating a
# second Swin-B just to read it.
data_config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_config, is_training=False)
def predict_bird(image, top_k=5):
    """
    Predict the bird species shown in an input image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.
        top_k: Maximum number of predictions to return (default 5). Capped at
            the number of classes so ``torch.topk`` cannot fail on a small
            label set.

    Returns:
        dict mapping display class name -> probability (float), ordered from
        most to least likely, in the format ``gr.Label`` expects.

    Raises:
        gr.Error: if no image was supplied or inference fails. Gradio shows
            the message to the user. (Returning a dict with a string value
            would be rejected by ``gr.Label``, which needs numeric confidences.)
    """
    if image is None:
        raise gr.Error("Please upload an image of a bird.")

    try:
        # The model takes 3-channel input; convert modes such as RGBA, L or P.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Apply the eval preprocessing and add a batch dimension of 1.
        input_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(input_tensor)
            probabilities = torch.nn.functional.softmax(logits, dim=1)

        k = min(top_k, probabilities.shape[1])
        top_probs, top_indices = torch.topk(probabilities, k)

        results = {}
        for prob, class_idx in zip(top_probs[0].tolist(), top_indices[0].tolist()):
            # Drop the first 4 characters of each label -- presumably a
            # numeric prefix such as "001." (CUB-200 style); confirm against
            # class_names.json.
            results[class_names[class_idx][4:]] = float(prob)
        return results
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        raise gr.Error(f"Prediction failed: {e}") from e
# Text shown at the top of the Gradio page.
title = "🐦 Bird Species Classifier"
# The species count is derived from the label file so the text stays accurate
# if class_names.json changes.
description = f"""
Upload an image of a bird and I'll identify the species!
This model can classify {len(class_names)} different bird species using a Swin Transformer.
"""
# Sample images offered as clickable examples in the UI. The paths are
# relative, so these files must be present in the working directory when the
# app starts.
examples = [
    ["image1153.jpg"],
    ["image1465.jpg"],
]
def _build_interface():
    """Assemble the Gradio UI that wraps ``predict_bird``.

    A PIL image goes in; the top predictions come out as a ``gr.Label``.
    Flagging is disabled.
    """
    image_input = gr.Image(type="pil", label="Upload a bird image")
    label_output = gr.Label(num_top_classes=5, label="Top 5 Predictions")
    return gr.Interface(
        fn=predict_bird,
        inputs=image_input,
        outputs=label_output,
        title=title,
        description=description,
        examples=examples,
        theme=gr.themes.Soft(),
        allow_flagging="never",
    )


# Built at module level so ``iface`` is importable as well as launchable.
iface = _build_interface()

# Start the web server only when run as a script.
if __name__ == "__main__":
    iface.launch()