Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| import timm | |
| from PIL import Image | |
| import os | |
# 1. Load Labels
def load_labels(path="labels.txt"):
    """Read class names from *path*, one label per line.

    Leading/trailing whitespace is stripped and blank lines are skipped so a
    trailing empty line cannot become a spurious class. The list index of each
    label is used as its class index by the prediction code.

    Args:
        path: Text file with one class name per line.

    Returns:
        list[str] of class names in file order.
    """
    # Explicit UTF-8 so species names with non-ASCII characters decode the
    # same way on every platform (the default encoding is locale-dependent).
    with open(path, "r", encoding="utf-8") as f:
        return [name for name in (line.strip() for line in f) if name]


labels = load_labels()
# 2. Model Definition
def get_model(num_classes=200, model_path='models/final_model_best.pth'):
    """Build the ConvNeXt V2 Large classifier and load fine-tuned weights.

    Args:
        num_classes: Size of the classification head; must match the
            checkpoint being loaded.
        model_path: Path to a state-dict checkpoint produced by ``torch.save``.

    Returns:
        The model in eval mode on CPU, or ``None`` if the checkpoint file is
        missing or cannot be loaded (the reason is printed to the log).
    """
    # Check for the checkpoint first: there is no point building the (large)
    # network if there are no weights to put in it.
    if not os.path.isfile(model_path):
        print(f"Error: Model file not found at {model_path}.")
        return None

    model = timm.create_model(
        'convnextv2_large.fcmae_ft_in22k_in1k',
        pretrained=False,  # weights come from the local checkpoint below
        num_classes=num_classes,
        drop_path_rate=0.2,  # stochastic depth; timm applies it only in train mode
    )

    try:
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True (torch>=1.13) to restrict
        # loading to tensors.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
    except Exception as e:
        # Startup boundary: report and fall back to None so the UI still comes
        # up and predict() can tell the user the model is unavailable.
        print(f"An error occurred while loading the model: {e}")
        return None

    model.eval()
    print("Model loaded successfully.")
    return model


model = get_model()
# 3. Image Transformations
# Per-channel RGB mean/std (the widely used ImageNet statistics).
# NOTE(review): assumes the checkpoint was trained with this same
# preprocessing — confirm against the training pipeline.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    # Exact 224x224 output; a (h, w) tuple does not preserve aspect ratio.
    transforms.Resize((224, 224)),
    # PIL image -> float tensor (C, H, W) scaled to [0, 1].
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
]
transform = transforms.Compose(_preprocessing_steps)
# 4. Prediction Function
def predict(image):
    """Classify a bird image and return its three most likely species.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        dict mapping label -> probability (Python float) for the top classes
        (at most three), the ``{label: confidence}`` format ``gr.Label``
        expects.

    Raises:
        gr.Error: if the model failed to load or no image was provided; Gradio
            shows the message to the user.
    """
    if model is None:
        # gr.Label renders {label: float} confidences, so returning a dict with
        # a string value cannot display this message; gr.Error shows it.
        raise gr.Error("Model is not loaded. Please check the logs for errors.")
    if image is None:
        raise gr.Error("Please upload an image.")

    # Force three channels: RGBA / grayscale / palette images would otherwise
    # fail in the 3-channel Normalize step of the transform.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    # Get top predictions (never ask topk for more classes than exist).
    k = min(3, probabilities.numel())
    top_prob, top_idx = torch.topk(probabilities, k)
    # Assumes labels.txt lists one name per model class, in class-index order.
    return {labels[i]: p for i, p in zip(top_idx.tolist(), top_prob.tolist())}
# 5. Gradio Interface
title = "Bird Species Classifier"
description = (
    "Upload an image of a bird to classify it into one of 200 species. "
    "This model is a ConvNeXt V2 Large, fine-tuned on a dataset of 200 bird species."
)

# Input: an uploaded image handed to predict() as a PIL image.
image_input = gr.Image(type="pil", label="Upload Bird Image")
# Output: the {label: probability} dict from predict(), showing the top 3.
label_output = gr.Label(num_top_classes=3, label="Predictions")

iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title=title,
    description=description,
)


def _main():
    """Start the Gradio web server for the classifier UI."""
    iface.launch()


if __name__ == "__main__":
    _main()