# Uploaded by fjurie — commit 3509b15 ("Upload 2 files"), verified.
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import numpy as np
import os
# 1. Human-readable labels. Position i is the label for output index i of the
#    classifier, and the list length sets how many outputs the new head has,
#    so the order must match the one used when the weights were trained.
class_names = [
    "ants",  # output index 0
    "bees",  # output index 1
]
# 2. Function to load the pre-trained model
@st.cache_resource
def load_model():
model = models.mobilenet_v3_small(weights='DEFAULT')
# Freeze all parameters in the feature extractor
for param in model.parameters():
param.requires_grad = False
# Replace the classifier head
num_ftrs = model.classifier[3].in_features
model.classifier[3] = nn.Linear(num_ftrs, len(class_names))
# Load the state dictionary
# model_save_path = 'mobilenetv3_hymenoptera.pth'
file_dir = os.path.dirname(os.path.abspath(__file__))
model_save_path = os.path.join(file_dir, 'mobilenetv3_hymenoptera.pth')
try:
model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
model.eval()
return model
except FileNotFoundError:
st.error(f"Error: Model file '{model_save_path}' not found.")
st.error("Please ensure 'mobilenetv3_hymenoptera.pth' is in the same directory as 'app.py' app.")
st.stop()
except Exception as e:
st.error(f"Error loading model: {e}")
st.stop()
# 3. Inference-time image transform. Steps:
#    - resize to 256 (an int size scales the shorter side);
#    - take the central 224x224 crop;
#    - convert the PIL image to a float tensor;
#    - normalize each RGB channel with the given mean/std.
#    NOTE(review): mean/std are presumably the statistics used at training
#    time; confirm they match the training pipeline.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Load the model once. load_model is wrapped in st.cache_resource, so reruns
# reuse the cached instance instead of rebuilding it.
model = load_model()

# Streamlit app interface
st.title("Ant vs. Bee Classifier")
st.write("Upload an image to classify whether it's an ant or a bee.")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Pillow raises UnidentifiedImageError (an OSError subclass) for files it
    # cannot decode, and OSError for truncated or corrupt data. Report the
    # problem instead of crashing the page with a traceback.
    try:
        image = Image.open(uploaded_file).convert('RGB')
    except OSError as e:
        st.error(f"Could not read the uploaded file as an image: {e}")
        st.stop()

    # NOTE(review): `use_column_width` is deprecated in recent Streamlit
    # releases in favour of `use_container_width`. Switch once the deployed
    # Streamlit version supports the new parameter.
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    st.write("")
    st.write("Classifying...")

    # Add a leading batch dimension, since the model expects a mini-batch.
    input_batch = preprocess(image).unsqueeze(0)
    with torch.no_grad():
        output = model(input_batch)

    # Softmax over the scores of the single sample in the batch, then take the
    # most probable class and its probability in one step.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    confidence, predicted_idx = torch.max(probabilities, dim=0)
    predicted_class = class_names[predicted_idx.item()]

    st.write(f"Prediction: **{predicted_class}**")
    st.write(f"Confidence: **{confidence.item():.2f}**")