Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import timm | |
| import json | |
| from PIL import Image | |
| import numpy as np | |
# Architecture identifier, shared by model construction and preprocessing lookup.
MODEL_NAME = 'swin_base_patch4_window7_224'

# Load class names; the number of entries sets the size of the classifier head.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the architecture without ImageNet weights; the trained checkpoint
# loaded below is expected to supply the parameters.
model = timm.create_model(
    MODEL_NAME,
    pretrained=False,
    num_classes=len(class_names),
)

# Load the trained weights.
# NOTE(review): torch.load unpickles the file. Only load checkpoints from a
# trusted source; on PyTorch versions that support it, consider passing
# weights_only=True.
model.load_state_dict(torch.load('best_swin_birds.pth', map_location=device))
model = model.to(device)
model.eval()  # evaluation mode for inference

# Resolve the preprocessing config from the model that is already built,
# rather than instantiating a second copy of the network just to read it.
data_config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_config, is_training=False)
def predict_bird(image):
    """
    Predict the bird species from an input image.

    Args:
        image: Image object exposing ``mode`` and ``convert`` (the interface
            in this file requests PIL images), or None when no image was
            supplied.

    Returns:
        A dict mapping class display names to softmax probabilities for the
        highest-scoring classes (at most five), or None when ``image`` is
        None so the output component is simply left empty.

    Raises:
        gr.Error: If preprocessing or inference fails; the original
            exception is chained as the cause.
    """
    # Nothing to classify (e.g. the form was submitted without an image).
    if image is None:
        return None

    try:
        # Ensure image is in RGB format before preprocessing.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Apply the module-level transform and add a leading batch dimension.
        input_tensor = transform(image).unsqueeze(0).to(device)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

        # Never ask topk for more entries than there are classes.
        top_k = min(5, probabilities.shape[1])
        top_probs, top_indices = torch.topk(probabilities, top_k)

        results = {}
        for class_idx, probability in zip(top_indices[0].tolist(),
                                          top_probs[0].tolist()):
            # Drops the first four characters of the stored name --
            # presumably a numeric prefix such as "001."; confirm against
            # class_names.json.
            class_name = class_names[class_idx][4:]
            results[class_name] = float(probability)
        return results
    except Exception as e:
        # A {"Error": "<text>"} dict is not a valid label->confidence mapping
        # for the Label output, so report the failure through Gradio's own
        # error mechanism instead.
        raise gr.Error(f"Prediction failed: {e}") from e
# Text shown on the Gradio page.
title = "🐦 Bird Species Classifier"
description = (
    "\n"
    "Upload an image of a bird and I'll identify the species!\n"
    "This model can classify 200 different bird species using a Swin Transformer.\n"
)

# Bundled sample images: one single-element list per image file.
examples = [[filename] for filename in ("image1153.jpg", "image1465.jpg")]

# Input and output components, built up front so the Interface call stays short.
image_input = gr.Image(type="pil", label="Upload a bird image")
label_output = gr.Label(num_top_classes=5, label="Top 5 Predictions")

# Wire the prediction function to the components.
iface = gr.Interface(
    fn=predict_bird,
    inputs=image_input,
    outputs=label_output,
    examples=examples,
    title=title,
    description=description,
    allow_flagging="never",
    theme=gr.themes.Soft(),
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()