# NOTE: web file-viewer residue removed from the top of this file (the header
# "File size: 5,230 Bytes", revision ids a2731d5 / 948f69d, and a line-number
# gutter running 1-154); none of it was part of the Python source.
import gradio as gr
import torch
import pandas as pd
import os
from torchvision import transforms
from PIL import Image
from transformers import ConvNextV2ForImageClassification

# --- Configuration ---
# Every path below is built from DATA_DIR, i.e. relative to the app's root
# directory in the Hugging Face Space.
DATA_DIR = '.'
LIST_DIR = os.path.join(DATA_DIR, 'list')
SPECIES_LIST_TXT = os.path.join(LIST_DIR, 'species_list.txt')
MODEL_PATH_HERBARIUM = os.path.join(DATA_DIR, 'herbarium_convnext_v2_base.pth')

# Use CUDA when PyTorch reports it as available, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# --- Load Species Information ---
# The species list is read as a headerless, ';'-separated table with two
# columns (class id, species name). The names, in file order, become the
# class labels and their count becomes the number of model outputs.
try:
    species_df = pd.read_csv(
        SPECIES_LIST_TXT,
        sep=';',
        header=None,
        names=['class_id', 'species_name'],
    )
except FileNotFoundError:
    # No species list on disk: fall back to 100 generic labels so the app
    # can still start.
    num_labels = 100
    class_names = [f"Class {i}" for i in range(100)]
    print(f"Warning: '{SPECIES_LIST_TXT}' not found. Using generic class names.")
else:
    class_names = list(species_df['species_name'])
    num_labels = len(class_names)


# --- Image Transformations ---
# Preprocessing applied to every uploaded image before inference:
# resize to 256, center-crop to 224, convert to a tensor, then normalize
# each of the three channels.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics - confirm they match what the model was trained with.
data_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# --- Model Loading ---
def load_herbarium_model():
    """Build the ConvNextV2 classifier and try to load the herbarium weights.

    The architecture is created from the
    ``facebook/convnextv2-base-22k-224`` checkpoint with ``num_labels``
    outputs (size mismatches are ignored). The state dict stored at
    ``MODEL_PATH_HERBARIUM`` is then loaded on top of it. Loading is
    best-effort: if the file is missing or loading fails for any other
    reason, a message is printed and the model is returned as created.

    Returns:
        The model, moved to ``DEVICE`` and switched to evaluation mode.
    """
    net = ConvNextV2ForImageClassification.from_pretrained(
        "facebook/convnextv2-base-22k-224",
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
    )

    try:
        # Read the saved state dict, mapping its tensors onto DEVICE, and
        # apply it to the freshly built network.
        state_dict = torch.load(MODEL_PATH_HERBARIUM, map_location=DEVICE)
        net.load_state_dict(state_dict)
    except FileNotFoundError:
        print(f"Warning: Model weights not found at '{MODEL_PATH_HERBARIUM}'. The model is using pre-trained weights, not fine-tuned ones.")
    except Exception as e:
        # Deliberately broad: any loading problem is reported and the app
        # continues with the weights it already has.
        print(f"Error loading model weights: {e}. The model is using pre-trained weights.")

    net = net.to(DEVICE)
    net.eval()
    return net

# Load the primary model once, at import time; predict_herbarium() reads this
# module-level instance on every call instead of building its own.
herbarium_model = load_herbarium_model()

# --- Prediction Functions ---
def predict_herbarium(image):
    """Classify a plant image with the herbarium model.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None``
            when nothing has been uploaded.

    Returns:
        ``"Please upload an image."`` when ``image`` is ``None``. Otherwise a
        dict mapping species name to its softmax probability as a ``float``,
        holding the highest-scoring classes (at most five) in descending
        order of probability.
    """
    if image is None:
        return "Please upload an image."

    # The normalization step in data_transforms is defined for exactly three
    # channels, so grayscale / RGBA / palette images are converted first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess, add the batch dimension and move to the inference device.
    batch = data_transforms(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = herbarium_model(batch).logits
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

        # torch.topk raises when k exceeds the number of classes, so cap it.
        k = min(5, probabilities.numel())
        top_prob, top_indices = torch.topk(probabilities, k)

    # Return plain Python floats (not formatted strings) so the values can be
    # used directly as confidences by the label output.
    return {
        class_names[idx]: prob
        for idx, prob in zip(top_indices.tolist(), top_prob.tolist())
    }

def predict_placeholder_1(image):
    """Stand-in for the second model, which does not exist yet.

    Returns the upload prompt when ``image`` is ``None`` and a fixed
    "not available" notice for any other input.
    """
    return (
        "Please upload an image."
        if image is None
        else "Model 2 is not available yet. Please check back later."
    )

def predict_placeholder_2(image):
    """Stand-in for the third model, which does not exist yet.

    Returns the upload prompt when ``image`` is ``None`` and a fixed
    "not available" notice for any other input.
    """
    return (
        "Please upload an image."
        if image is None
        else "Model 3 is not available yet. Please check back later."
    )

# --- Main Prediction Logic ---
def predict(model_choice, image):
    """Dispatch ``image`` to the predictor matching ``model_choice``.

    ``model_choice`` is compared against the three model names offered in
    the UI dropdown; any other value yields ``"Invalid model selected."``.
    The return value is whatever the selected predictor returns.
    """
    if model_choice == "Herbarium Species Classifier":
        return predict_herbarium(image)

    if model_choice == "Future Model 1 (Placeholder)":
        return predict_placeholder_1(image)

    if model_choice == "Future Model 2 (Placeholder)":
        return predict_placeholder_2(image)

    # Nothing matched.
    return "Invalid model selected."

# --- Gradio Interface ---
# Build the UI. Components are created in order inside the nested context
# managers: a Markdown header, then a row with two equally scaled columns
# (inputs first, prediction output second), then the event wiring.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page title and short usage instructions.
    gr.Markdown(
        """
        # 🌿 Plant Species Classification
        ## AML Group Project - PsychicFireSong
        Upload an image of a plant to classify it. Select a model from the dropdown below.
        """
    )
    
    with gr.Row():
        # First column: model choice, image upload and the submit button.
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                label="Select Model",
                # These strings are the exact values predict() compares
                # against, so they must be kept in sync with it.
                choices=[
                    "Herbarium Species Classifier", 
                    "Future Model 1 (Placeholder)", 
                    "Future Model 2 (Placeholder)"
                ],
                value="Herbarium Species Classifier"
            )
            # type="pil" so the prediction functions receive a PIL image.
            image_input = gr.Image(type="pil", label="Upload Plant Image")
            submit_button = gr.Button("Classify", variant="primary")

        # Second column: where the prediction result is displayed.
        with gr.Column(scale=1):
            output_label = gr.Label(label="Top 5 Predictions", num_top_classes=5)

    # Clicking the button calls predict(model_choice, image) and shows the
    # result in the label component.
    submit_button.click(
        fn=predict,
        inputs=[model_selector, image_input],
        outputs=output_label
    )
    
    # Example images. The list is currently empty; entries added here take
    # only an image, so they are always routed to the herbarium classifier.
    gr.Examples(
        examples=[
            # Add paths to example images if you have any in your project
            # e.g., os.path.join("examples", "example1.jpg")
        ],
        inputs=image_input,
        outputs=output_label,
        fn=lambda img: predict("Herbarium Species Classifier", img),
        cache_examples=False
    )

if __name__ == "__main__":
    # Start the web server only when this file is executed as a script.
    # A second, unguarded demo.launch() used to follow this block; it
    # launched the app on plain import and launched it again after this
    # call returned, so it was removed.
    demo.launch()