Spaces:
Sleeping
Sleeping
File size: 2,663 Bytes
9f50beb 3509b15 9f50beb 09dc5fa 9f50beb 0696ad6 9f50beb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 | import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import numpy as np
import os
# 1. Output labels. Position i in this list is the label for output unit i of
# the classifier head (index 0 -> 'ants', index 1 -> 'bees').
class_names = [
    'ants',
    'bees',
]
# 2. Function to load the fine-tuned model
@st.cache_resource
def load_model():
    """Build MobileNetV3-Small with a per-class head and load the saved weights.

    Wrapped in ``st.cache_resource`` so the model is constructed and the
    checkpoint read once, then reused across script reruns and sessions.

    Returns:
        The model in evaluation mode, with all tensors mapped to the CPU.

    If the checkpoint is missing or cannot be loaded, an error is shown in the
    app and the script run is halted with ``st.stop()``.
    """
    # No pretrained download: load_state_dict below is strict by default, so
    # every parameter is overwritten by the checkpoint anyway. Fetching the
    # ImageNet weights first would only add start-up time and a network
    # dependency.
    model = models.mobilenet_v3_small(weights=None)
    # Freeze the backbone parameters (presumably carried over from the
    # fine-tuning script). This is done before the head is swapped, so the new
    # head's parameters keep requires_grad=True.
    for param in model.parameters():
        param.requires_grad = False
    # Replace the final classifier layer so it emits one logit per class.
    num_ftrs = model.classifier[3].in_features
    model.classifier[3] = nn.Linear(num_ftrs, len(class_names))
    # Resolve the checkpoint next to this script, independent of the CWD.
    checkpoint_name = 'mobilenetv3_hymenoptera.pth'
    file_dir = os.path.dirname(os.path.abspath(__file__))
    model_save_path = os.path.join(file_dir, checkpoint_name)
    try:
        # weights_only=True refuses arbitrary pickled objects. A state dict
        # needs nothing beyond tensors and plain containers.
        state_dict = torch.load(
            model_save_path,
            map_location=torch.device('cpu'),
            weights_only=True,
        )
        model.load_state_dict(state_dict)
    except FileNotFoundError:
        st.error(f"Error: Model file '{model_save_path}' not found.")
        st.error(
            f"Please ensure '{checkpoint_name}' is in the same directory as "
            f"'{os.path.basename(__file__)}'."
        )
        st.stop()
    except Exception as e:
        # Top-level boundary: surface any other load failure (corrupt file,
        # architecture/key mismatch) in the UI instead of a raw traceback.
        st.error(f"Error loading model: {e}")
        st.stop()
    # Inference mode: fixes BatchNorm statistics and disables dropout.
    model.eval()
    return model
# 3. Inference-time preprocessing. Scale the shorter side to 256 px and take the
# central 224x224 crop. Then convert to a tensor and normalize each channel
# with the standard ImageNet mean/std used by torchvision's pretrained models.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Load the model once; st.cache_resource reuses it across reruns and sessions.
model = load_model()

# Streamlit app interface
st.title("Ant vs. Bee Classifier")
st.write("Upload an image to classify whether it's an ant or a bee.")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Image.open reads only the header; convert() forces a full decode, so
        # both unrecognized and truncated uploads fail here. Pillow signals
        # these with UnidentifiedImageError (an OSError subclass) or OSError.
        image = Image.open(uploaded_file).convert('RGB')
    except OSError as e:
        st.error(f"Could not read the uploaded file as an image: {e}")
        st.stop()

    # NOTE(review): use_column_width is deprecated in recent Streamlit releases
    # in favour of use_container_width; kept here for older-version support.
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    st.write("")
    st.write("Classifying...")

    # Preprocess and add a leading batch dimension of size 1 for the model.
    input_batch = preprocess(image).unsqueeze(0)
    with torch.no_grad():
        output = model(input_batch)

    # Softmax over the single sample's logits. Taking the max of the
    # probabilities yields both the confidence and the predicted index; it is
    # the same class as the logit arg-max because softmax is monotonic.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    confidence, predicted_idx = torch.max(probabilities, dim=0)
    predicted_class = class_names[predicted_idx.item()]

    st.write(f"Prediction: **{predicted_class}**")
    st.write(f"Confidence: **{confidence.item():.2f}**")
|