"""Gradio demo that classifies bird species with a fine-tuned Swin Transformer.

Expects ``class_names.json`` (a JSON array whose i-th entry labels output
logit i) and ``best_swin_birds.pth`` (a ``state_dict`` for MODEL_NAME with
``len(class_names)`` outputs) in the working directory.
"""
import json
from pathlib import Path

import gradio as gr
import numpy as np
import timm
import torch
from PIL import Image

# timm architecture name the checkpoint was trained with.
MODEL_NAME = 'swin_base_patch4_window7_224'
# Fine-tuned weights produced by training.
CHECKPOINT_PATH = 'best_swin_birds.pth'
# Number of predictions shown in the UI.
TOP_K = 5

# Load class names; position i labels model output i.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rebuild the training architecture without ImageNet weights; the fine-tuned
# weights come from the checkpoint loaded right after.
model = timm.create_model(
    MODEL_NAME,
    pretrained=False,
    num_classes=len(class_names),
)
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model = model.to(device)
model.eval()

# Resolve the eval-time preprocessing from the model we already built rather
# than instantiating a second full Swin-B just to read its data config.
data_config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_config, is_training=False)


def predict_bird(image):
    """Predict the bird species shown in *image*.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        Dict mapping display class name -> softmax probability for the top
        ``TOP_K`` classes (fewer if the model has fewer classes), ordered from
        most to least likely, in the shape ``gr.Label`` expects.

    Raises:
        gr.Error: if no image was provided or inference fails. Gradio shows
            the message to the user; returning a string-valued dict instead
            would not be a valid ``gr.Label`` value.
    """
    if image is None:
        raise gr.Error("Please upload a bird image first.")

    try:
        # The model was trained on 3-channel input; normalise grayscale/RGBA.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Same preprocessing as during training, plus a batch dimension.
        input_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

        # torch.topk raises if k exceeds the number of classes.
        k = min(TOP_K, probabilities.shape[1])
        top_probs, top_indices = torch.topk(probabilities, k)
    except Exception as e:
        raise gr.Error(f"Prediction failed: {e}") from e

    # Transfer once to the CPU instead of syncing per element with .item().
    top_probs = top_probs[0].cpu().tolist()
    top_indices = top_indices[0].cpu().tolist()

    # Each label drops its first 4 characters for display; presumably an
    # index prefix such as "001." — TODO confirm against class_names.json.
    return {
        class_names[idx][4:]: float(prob)
        for idx, prob in zip(top_indices, top_probs)
    }


# Create Gradio interface
title = "🐦 Bird Species Classifier"
description = f"""
Upload an image of a bird and I'll identify the species!
This model can classify {len(class_names)} different bird species using a Swin Transformer.
"""

# Sample images shown below the input. Only files present on disk are offered
# so a missing sample cannot break app startup; users can still upload their own.
examples = [
    [path]
    for path in ("image1153.jpg", "image1465.jpg")
    if Path(path).is_file()
] or None

iface = gr.Interface(
    fn=predict_bird,
    inputs=gr.Image(type="pil", label="Upload a bird image"),
    outputs=gr.Label(num_top_classes=TOP_K, label="Top 5 Predictions"),
    title=title,
    description=description,
    examples=examples,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)

# Launch the app
if __name__ == "__main__":
    iface.launch()