---
license: mit
tags:
- image-classification
- pytorch
- convnext
- birds
- computer-vision
datasets:
- CUB-200-2011
metrics:
- accuracy
library_name: pytorch
---

# Bird Species Classification - ConvNeXt-Base

## Model Description

This model classifies images into the 200 bird species of the CUB-200-2011 dataset. It uses a ConvNeXt-Base backbone pretrained on ImageNet-1K and adapted to birds with transfer learning.

## Performance

- **Test Accuracy**: 83.64%
- **Average Per-Class Accuracy**: 83.29% (mean of the 200 per-species accuracies)
- **Architecture**: ConvNeXt-Base (~89M parameters)
- **Dataset**: CUB-200-2011 (200 bird species)

## Training Details

### Model Architecture

- **Base Model**: ConvNeXt-Base pretrained on ImageNet-1K
- **Classifier**: Custom 2-layer head (1024-unit hidden layer with ReLU, then a 200-way output layer) with dropout, replacing the final linear layer
- **Input Size**: 224x224 RGB images

### Training Strategy

- **Phase 1** (40 epochs): Frozen backbone, train classifier only
  - Learning Rate: 0.001
  - Batch Size: 32
- **Phase 2** (20 epochs): Full fine-tuning
  - Learning Rate: 0.0001
  - Batch Size: 32

### Regularization

- Dropout: 0.6 (before the hidden layer) and 0.5 (before the output layer)
- Label Smoothing: 0.2
- Weight Decay: 0.005
- Data Augmentation: rotation, flip, color jitter, random erasing

## Usage

```python
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

# Build the architecture: ConvNeXt-Base with the custom 2-layer head
model = models.convnext_base(weights=None)
num_features = model.classifier[2].in_features  # 1024 for ConvNeXt-Base
model.classifier[2] = nn.Sequential(
    nn.Dropout(0.6),
    nn.Linear(num_features, 1024),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(1024, 200)
)

# Load the fine-tuned weights and switch to inference mode (disables dropout)
model.load_state_dict(torch.load('final_model.pth', map_location='cpu'))
model.eval()

# Preprocessing: resize, convert to tensor, normalize with ImageNet mean/std
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Predict
image = Image.open('bird.jpg').convert('RGB')
image_tensor = transform(image).unsqueeze(0)  # add batch dimension: [1, 3, 224, 224]

with torch.no_grad():
    outputs = model(image_tensor)
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

# Class indices are 0-based (0-199)
top5_prob, top5_indices = torch.topk(probabilities, 5)
print("Top 5 Predictions:")
for prob, idx in zip(top5_prob, top5_indices):
    print(f"Class {idx.item()}: {prob.item()*100:.2f}%")
```

## Try it out!

Try the live demo: [Bird Species Classifier](https://huggingface.co/spaces/AshProg/AppliedMachineLearning_BirdClassifierInterface)

## Model Files

- `final_model.pth` (1.06 GB): Full model weights

## Citation

If you use this model, please cite the CUB-200-2011 dataset:

```
@techreport{WahCUB_200_2011,
  Title = {{The Caltech-UCSD Birds-200-2011 Dataset}},
  Author = {Wah, C. and Branson, S. and Welinder, P. and Perona, P. and Belongie, S.},
  Year = {2011},
  Institution = {California Institute of Technology},
  Number = {CNS-TR-2011-001}
}
```

## Contact

For questions or issues, please open a discussion on the Space repository.
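The Performance section reports two metrics. Test accuracy is the fraction of all test images classified correctly. Average per-class accuracy is each species' accuracy averaged with equal weight, so the two differ when species have different numbers of test images. Below is a minimal plain-Python sketch of both metrics. It assumes `y_true` and `y_pred` are lists of integer class indices; these are hypothetical inputs, not this model's evaluation script.

```python
from collections import defaultdict


def overall_accuracy(y_true, y_pred):
    """Fraction of all samples whose prediction matches the label."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def mean_per_class_accuracy(y_true, y_pred):
    """Accuracy computed separately for each true class, then averaged equally."""
    totals = defaultdict(int)
    hits = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        totals[t] += 1
        hits[t] += int(t == p)
    per_class = [hits[c] / totals[c] for c in totals]
    return sum(per_class) / len(per_class)


# Toy example: class 0 has four test images, class 1 only two
y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 1, 0]
print(overall_accuracy(y_true, y_pred))         # 5 of 6 correct -> 0.8333333333333334
print(mean_per_class_accuracy(y_true, y_pred))  # (1.0 + 0.5) / 2 -> 0.75
```

The single mistake costs the small class half its accuracy. That is why the per-class average (0.75) falls below the overall figure (0.83) in this toy case.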
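The two-phase schedule under Training Strategy freezes the backbone, trains the head, and then unfreezes everything at a lower learning rate. Below is a minimal sketch of the phase switch. It assumes a torchvision-style ConvNeXt with a `.features` backbone and a `.classifier` head, as built in the Usage section, and it treats the whole `.classifier` module as trainable in Phase 1. The card does not name the optimizer, so AdamW is an assumption. The function name `configure_phase` is hypothetical.

```python
import torch
import torch.nn as nn


def configure_phase(model: nn.Module, phase: int,
                    weight_decay: float = 0.005) -> torch.optim.Optimizer:
    """Set which parameters train in the given phase and return a fresh optimizer.

    Assumes `model.features` is the backbone and `model.classifier` is the head,
    as in torchvision's ConvNeXt.
    """
    if phase not in (1, 2):
        raise ValueError("phase must be 1 (frozen backbone) or 2 (full fine-tuning)")
    # Phase 1: backbone frozen; Phase 2: backbone trainable
    for param in model.features.parameters():
        param.requires_grad = (phase == 2)
    # The head is trainable in both phases
    for param in model.classifier.parameters():
        param.requires_grad = True
    lr = 1e-3 if phase == 1 else 1e-4  # learning rates from Training Strategy
    trainable = [p for p in model.parameters() if p.requires_grad]
    # AdamW is an assumption: the card states the weight decay but not the optimizer
    return torch.optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)
```

In practice you would build the model as in Usage, but with `weights=models.ConvNeXt_Base_Weights.IMAGENET1K_V1`. You would then call `configure_phase(model, 1)` for the first 40 epochs and `configure_phase(model, 2)` for the remaining 20. Creating a new optimizer at the switch gives Phase 2 fresh optimizer state. That is one reasonable choice, not necessarily the one used to train this model.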
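The Regularization section lists label smoothing of 0.2. With PyTorch's `nn.CrossEntropyLoss(label_smoothing=eps)`, the target becomes a mix of the one-hot label and a uniform distribution: (1 - eps) * one_hot + eps / num_classes. Over 200 classes, the true species gets 0.8 + 0.001 = 0.801 and every other species gets 0.001. The sketch below uses random logits (hypothetical data) to check that a manual cross-entropy against these smoothed targets matches the built-in loss. It assumes the model was trained with the built-in option, which the card does not state.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

NUM_CLASSES = 200
EPSILON = 0.2  # label smoothing value from the Regularization section


def smoothed_targets(labels: torch.Tensor, num_classes: int, eps: float) -> torch.Tensor:
    """Mix one-hot labels with a uniform distribution: (1 - eps) * one_hot + eps / C."""
    one_hot = F.one_hot(labels, num_classes).float()
    return (1.0 - eps) * one_hot + eps / num_classes


def manual_smoothed_ce(logits, labels, num_classes, eps):
    """Cross-entropy against the smoothed targets, averaged over the batch."""
    targets = smoothed_targets(labels, num_classes, eps)
    return -(targets * F.log_softmax(logits, dim=1)).sum(dim=1).mean()


torch.manual_seed(0)
logits = torch.randn(4, NUM_CLASSES)      # stand-in for model outputs
labels = torch.tensor([0, 5, 17, 199])    # stand-in ground-truth species indices

builtin = nn.CrossEntropyLoss(label_smoothing=EPSILON)(logits, labels)
manual = manual_smoothed_ce(logits, labels, NUM_CLASSES, EPSILON)
print(torch.allclose(builtin, manual))  # expected: True
```

Because no target is ever exactly 1.0, the loss keeps penalizing overconfident logits. This acts as a regularizer alongside dropout and weight decay.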
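The Usage example prints numeric class indices only. The CUB-200-2011 archive ships a `classes.txt` file with lines such as `1 001.Black_footed_Albatross` (1-based ID, then the folder name). Below is a minimal sketch that turns this file into a readable name list. It assumes model index `i` corresponds to CUB class ID `i + 1`, which holds if training used the dataset's own class order (for example, sorted class folders). The card does not confirm this. The helper name `load_cub_class_names` is hypothetical.

```python
from pathlib import Path


def load_cub_class_names(classes_txt):
    """Parse CUB-200-2011 classes.txt into a 0-indexed list of readable names.

    Assumes model output index i corresponds to CUB class_id i + 1.
    """
    names = {}
    for line in Path(classes_txt).read_text().splitlines():
        if not line.strip():
            continue  # tolerate blank lines
        class_id, raw_name = line.split(maxsplit=1)
        # Drop the "001." prefix and turn underscores into spaces
        names[int(class_id) - 1] = raw_name.split(".", 1)[-1].replace("_", " ")
    return [names[i] for i in range(len(names))]
```

With the list loaded, for example `names = load_cub_class_names("CUB_200_2011/classes.txt")`, the prediction loop can print `names[idx.item()]` instead of the bare index.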