Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| import torchvision.models as models | |
| from PIL import Image | |
| import numpy as np | |
| import os | |
# 1. Define class names
# One label per output unit of the classifier, in index order: the head is
# sized from len(class_names) and the predicted index is looked up here.
# NOTE(review): order presumably matches the (alphabetical) class-folder order
# used when the checkpoint was trained — confirm against the training script.
class_names = ["ants", "bees"]
# 2. Function to load the pre-trained model
@st.cache_resource
def load_model():
    """Build the ant/bee classifier and load its fine-tuned weights.

    Constructs a MobileNetV3-Small, replaces the last classifier layer with a
    ``nn.Linear`` that has ``len(class_names)`` outputs, loads
    ``mobilenetv3_hymenoptera.pth`` from the directory containing this file
    (mapped onto the CPU), and returns the model in eval mode.

    Cached with ``st.cache_resource`` because Streamlit re-executes the whole
    script on every user interaction; without caching, the network would be
    rebuilt and the checkpoint re-read on every rerun.

    If the checkpoint is missing or cannot be loaded, an error is shown in the
    app and the current script run is halted with ``st.stop()``.
    """
    # No pretrained weights are requested: the checkpoint below is loaded with
    # strict key matching (the load_state_dict default), so every parameter and
    # buffer is overwritten anyway. Requesting 'DEFAULT' weights would only add
    # a network download at startup.
    model = models.mobilenet_v3_small(weights=None)

    # Freeze all parameters; the app only runs inference.
    for param in model.parameters():
        param.requires_grad = False

    # Swap the final Linear of the classifier head so it emits one logit per
    # entry in class_names (the checkpoint was saved with this head shape).
    num_ftrs = model.classifier[3].in_features
    model.classifier[3] = nn.Linear(num_ftrs, len(class_names))

    # Resolve the checkpoint relative to this file, not the working directory.
    file_dir = os.path.dirname(os.path.abspath(__file__))
    model_save_path = os.path.join(file_dir, 'mobilenetv3_hymenoptera.pth')

    try:
        state_dict = torch.load(model_save_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
    except FileNotFoundError:
        st.error(f"Error: Model file '{model_save_path}' not found.")
        st.error("Please ensure 'mobilenetv3_hymenoptera.pth' is in the same directory as 'app.py'.")
        st.stop()
    except Exception as e:
        # Top-level boundary: surface any other loading failure in the UI.
        st.error(f"Error loading model: {e}")
        st.stop()
    else:
        # eval() disables dropout and uses BatchNorm running statistics.
        model.eval()
        return model
# 3. Define image transformation pipeline for inference
# Resize so the shorter side is 256 px, take the central 224x224 crop, convert
# the PIL image to a float tensor, then normalize each RGB channel with the
# ImageNet mean/std values conventionally used with torchvision models.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Build/load the classifier. Streamlit re-executes this script on every user
# interaction, so this call happens on every rerun.
model = load_model()

# Streamlit app interface
st.title("Ant vs. Bee Classifier")
st.write("Upload an image to classify whether it's an ant or a bee.")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file).convert('RGB')
    except OSError as e:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for data
        # it cannot recognise, and OSError for truncated/corrupt files. Report
        # it in the UI instead of crashing the app.
        st.error(f"Could not read the uploaded file as an image: {e}")
        st.stop()

    st.image(image, caption='Uploaded Image.', use_column_width=True)
    st.write("")
    st.write("Classifying...")

    # Preprocess the image and add a leading batch dimension of size 1.
    input_batch = preprocess(image).unsqueeze(0)

    with torch.no_grad():
        logits = model(input_batch)

    # Softmax over the class scores of the single image in the batch; the
    # top probability and its index give both the confidence and the label.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    confidence, predicted_idx = torch.max(probabilities, dim=0)
    predicted_class = class_names[predicted_idx.item()]

    st.write(f"Prediction: **{predicted_class}**")
    st.write(f"Confidence: **{confidence.item():.2f}**")