"""Streamlit app that classifies an uploaded image as an ant or a bee.

Expects a fine-tuned MobileNetV3-Small state dict
('mobilenetv3_hymenoptera.pth') located next to this script.
"""

import os

import streamlit as st
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# 1. Class labels, in the index order used during training.
class_names = ['ants', 'bees']


# 2. Function to load the pre-trained model (cached across Streamlit reruns).
@st.cache_resource
def load_model():
    """Build MobileNetV3-Small with a 2-class head and load fine-tuned weights.

    Returns the model in eval mode. Halts the app with an error message if
    the weights file is missing or cannot be loaded.
    """
    # weights=None: any pretrained ImageNet weights would be fully
    # overwritten by load_state_dict below, so skip the download entirely.
    model = models.mobilenet_v3_small(weights=None)

    # Freezing is not required for inference, but it is harmless and
    # mirrors the training-time setup.
    for param in model.parameters():
        param.requires_grad = False

    # Replace the classifier head to match the 2-class fine-tuned model.
    num_ftrs = model.classifier[3].in_features
    model.classifier[3] = nn.Linear(num_ftrs, len(class_names))

    # Resolve the weights file relative to this script so the app works
    # regardless of the current working directory.
    file_dir = os.path.dirname(os.path.abspath(__file__))
    model_save_path = os.path.join(file_dir, 'mobilenetv3_hymenoptera.pth')
    try:
        model.load_state_dict(
            torch.load(model_save_path, map_location=torch.device('cpu')))
        model.eval()
        return model
    except FileNotFoundError:
        st.error(f"Error: Model file '{model_save_path}' not found.")
        st.error("Please ensure 'mobilenetv3_hymenoptera.pth' is in the "
                 "same directory as this app.")
        st.stop()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.stop()


# 3. Image transformation pipeline for inference — standard ImageNet
# preprocessing, matching what was used at training time.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Load the model once (st.cache_resource avoids reloading on each rerun).
model = load_model()

# --- Streamlit interface ---
st.title("Ant vs. Bee Classifier")
st.write("Upload an image to classify whether it's an ant or a bee.")

uploaded_file = st.file_uploader("Choose an image...",
                                 type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Force RGB so grayscale/RGBA uploads get the 3 channels the model expects.
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    st.write("")
    st.write("Classifying...")

    # Preprocess the image: single PIL image -> (1, 3, 224, 224) mini-batch.
    input_tensor = preprocess(image)
    input_batch = input_tensor.unsqueeze(0)

    # Inference without gradient tracking.
    with torch.no_grad():
        output = model(input_batch)

    # Softmax over the logits of the single batch element -> probabilities.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # Predicted class = argmax over the class dimension.
    _, predicted_idx = torch.max(output, 1)
    predicted_class = class_names[predicted_idx.item()]
    confidence = probabilities[predicted_idx.item()].item()

    st.write(f"Prediction: **{predicted_class}**")
    st.write(f"Confidence: **{confidence:.2f}**")