File size: 2,284 Bytes
9fae6ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bbce66
9fae6ff
 
 
 
 
 
 
 
5bbce66
 
9fae6ff
5bbce66
9fae6ff
 
 
 
 
5bbce66
9fae6ff
 
 
 
5bbce66
9fae6ff
 
 
 
 
5bbce66
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
import torch
from torchvision import transforms
import timm
from PIL import Image
import os

# 1. Load Labels
# One class name per line. predict() indexes this list with the model's output
# indices, so line order matters and blank lines are deliberately not filtered
# out (dropping one would shift every later label).
# The encoding is explicit so non-ASCII names decode the same on every platform.
with open('labels.txt', 'r', encoding='utf-8') as f:
    labels = [line.strip() for line in f]

# 2. Model Definition
def get_model(num_classes=200, model_path='models/final_model_best.pth'):
    """Build the ConvNeXt V2 Large classifier and load its saved weights.

    Args:
        num_classes: Number of output classes for the classification head.
        model_path: Path to the checkpoint passed to ``torch.load``; its
            content is handed to ``load_state_dict``.

    Returns:
        The model in eval mode with weights mapped to CPU, or ``None`` when
        the checkpoint file does not exist or fails to load. A message is
        printed in either failure case.
    """
    # Check for the checkpoint before building the network: constructing the
    # model is wasted work when there are no weights to load into it.
    if not os.path.exists(model_path):
        print(f"Error: Model file not found at {model_path}.")
        return None

    model = timm.create_model(
        'convnextv2_large.fcmae_ft_in22k_in1k',
        pretrained=False,  # weights come from model_path, not a download
        num_classes=num_classes,
        drop_path_rate=0.2
    )

    try:
        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should be placed at model_path. Consider
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
    except Exception as e:
        # Deliberate best-effort: the app still starts and predict() reports
        # that the model is unavailable.
        print(f"An error occurred while loading the model: {e}")
        return None

    model.eval()
    print("Model loaded successfully.")
    return model

# Load the classifier once at import time; get_model() returns None on failure.
model = get_model()

# 3. Image Transformations
# Applied in order: resize to 224x224, convert to a tensor, then normalize
# each of the three channels with the given mean and standard deviation.
_preprocess_steps = [
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

# 4. Prediction Function
def predict(image):
    """Classify an image and return its most likely labels.

    Args:
        image: A PIL image, or ``None`` when no image was provided.

    Returns:
        A dict mapping label to probability (a Python float) for up to three
        classes, in descending order of probability. Returns an empty dict
        when ``image`` is ``None``, and a single ``"Error"`` entry when the
        model is not loaded.
    """
    if model is None:
        return {"Error": "Model is not loaded. Please check the logs for errors."}

    # Nothing was uploaded: there is nothing to classify.
    if image is None:
        return {}

    # Normalize() is configured with three channel statistics, so force RGB;
    # grayscale or RGBA input would otherwise fail.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    # Never ask topk for more entries than there are classes.
    k = min(3, probabilities.numel())
    top_prob, top_indices = torch.topk(probabilities, k)

    # tolist() yields plain ints/floats, so labels are indexed with real ints
    # and the returned values are ordinary Python floats.
    return {
        labels[index]: prob
        for index, prob in zip(top_indices.tolist(), top_prob.tolist())
    }

# 5. Gradio Interface
title = "Bird Species Classifier"
description = (
    "Upload an image of a bird to classify it into one of 200 species. "
    "This model is a ConvNeXt V2 Large, fine-tuned on a dataset of 200 bird species."
)

# Build the input and output widgets first, then wire them to predict().
_image_input = gr.Image(type="pil", label="Upload Bird Image")
_label_output = gr.Label(num_top_classes=3, label="Predictions")

iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title=title,
    description=description,
)

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()