File size: 2,663 Bytes
9f50beb
 
 
 
 
 
 
3509b15
9f50beb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09dc5fa
 
 
9f50beb
 
 
 
 
 
0696ad6
9f50beb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import numpy as np
import os

# 1. Define class names
# Human-readable labels: the predicted index is looked up in this list, and
# its length sizes the replacement classifier head in load_model().
# NOTE(review): the order presumably matches the label indices used when the
# checkpoint was trained — confirm against the training script.
class_names = ['ants', 'bees']

# 2. Function to load the pre-trained model
@st.cache_resource
def load_model():
    """Build the MobileNetV3-Small classifier and load the fine-tuned weights.

    The backbone is created without pretrained weights: the checkpoint loaded
    below is applied with the default strict ``load_state_dict``, which
    replaces every parameter and buffer, so fetching the ImageNet weights
    first would only add a pointless network download at startup.

    Returns:
        The model in evaluation mode with the checkpoint mapped to the CPU.

    If the checkpoint is missing or cannot be loaded, an error is shown in the
    Streamlit UI and the script run is halted with ``st.stop()``.
    """
    model = models.mobilenet_v3_small(weights=None)

    # Freeze the backbone parameters; the app only runs inference.
    for param in model.parameters():
        param.requires_grad = False

    # Replace the final linear layer with one sized to our classes.
    num_ftrs = model.classifier[3].in_features
    model.classifier[3] = nn.Linear(num_ftrs, len(class_names))

    # Resolve the checkpoint next to this file so the app works regardless of
    # the current working directory.
    file_dir = os.path.dirname(os.path.abspath(__file__))
    model_save_path = os.path.join(file_dir, 'mobilenetv3_hymenoptera.pth')

    try:
        # NOTE: torch.load unpickles the file, so only load checkpoints from a
        # trusted source (consider weights_only=True where PyTorch supports it).
        state_dict = torch.load(model_save_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
    except FileNotFoundError:
        st.error(f"Error: Model file '{model_save_path}' not found.")
        st.error("Please ensure 'mobilenetv3_hymenoptera.pth' is in the same directory as 'app.py' app.")
        st.stop()
    except Exception as e:
        # Top-level boundary: surface any load failure in the UI and halt.
        st.error(f"Error loading model: {e}")
        st.stop()
    else:
        model.eval()
        return model

# 3. Inference-time preprocessing: resize, center-crop to 224, convert to a
#    tensor, then normalize each channel with the given mean / std.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Load the model once; st.cache_resource on load_model() reuses the same
# instance across Streamlit script reruns.
model = load_model()

# Streamlit app interface
st.title("Ant vs. Bee Classifier")
st.write("Upload an image to classify whether it's an ant or a bee.")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # The extension filter above does not guarantee the bytes are a valid
    # image, so decode defensively and report failures instead of crashing.
    try:
        image = Image.open(uploaded_file).convert('RGB')
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        image = None
        st.error(f"Could not read the uploaded file as an image: {exc}")

    if image is not None:
        st.image(image, caption='Uploaded Image.', use_column_width=True)
        st.write("")
        st.write("Classifying...")

        # Preprocess the image and add a leading batch dimension for the model
        input_tensor = preprocess(image)
        input_batch = input_tensor.unsqueeze(0)

        # Inference only: no gradients needed
        with torch.no_grad():
            output = model(input_batch)

        # Convert the logits of the single batch item into probabilities
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # Pick the highest-scoring class and its probability
        _, predicted_idx = torch.max(output, 1)
        predicted_class = class_names[predicted_idx.item()]
        confidence = probabilities[predicted_idx.item()].item()

        st.write(f"Prediction: **{predicted_class}**")
        st.write(f"Confidence: **{confidence:.2f}**")