# Source: Segmentation / app.py
# Last commit: "Update app.py" (199937f, verified); uploader avatar: riha55.
import streamlit as st
import torch
from PIL import Image
import torchvision.transforms as transforms
from train5 import deeplabv3_encoder_decoder
import numpy as np
def load_model(model_path):
    """Build the segmentation network and load its weights from ``model_path``.

    The weights are mapped onto the CPU and the network is switched to
    evaluation mode before being returned.

    Args:
        model_path: Location of the saved state dict, passed to ``torch.load``.

    Returns:
        The network in eval mode, or ``None`` if reading or applying the
        weights raised; in that case the error is shown with ``st.error``.
    """
    net = deeplabv3_encoder_decoder()
    try:
        cpu = torch.device('cpu')
        state_dict = torch.load(model_path, map_location=cpu)
        net.load_state_dict(state_dict)
        net.eval()
    except Exception as e:
        # Surface the failure in the UI and let the caller decide what to do
        # with a missing model.
        st.error(f"Error loading model: {e}")
        return None
    return net
# Path to the saved model weights.
model_path = 'model.pth'

# Load the trained model; load_model returns None (after reporting the error
# in the UI) when loading fails.
model = load_model(model_path)

# Test against None explicitly instead of relying on the truthiness of the
# model object.
if model is not None:
    # Create a Streamlit app
    st.title('Aerial Image Segmentation')

    # Add a file uploader to the app. Every accepted format is converted to
    # RGB below, so JPEG (both common extensions) and PNG are all safe.
    uploaded_file = st.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"]
    )

    if uploaded_file is not None:
        # Force three channels: grayscale, CMYK or RGBA uploads would
        # otherwise produce a tensor with 1 or 4 channels.
        image = Image.open(uploaded_file).convert('RGB')

        # Display the original image
        st.image(image, caption='Uploaded Image.', use_column_width=True)

        # Preprocess the image: resize to 512x512 and convert to a tensor.
        data_transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
        ])
        batch = data_transform(image).unsqueeze(0)  # add a batch dimension

        # Pass the image through the model without tracking gradients.
        with torch.no_grad():
            output = model(batch)

        # Define the color map and class labels
        color_map = {
            0: np.array([255, 34, 133]),  # Unlabeled
            1: np.array([0, 252, 199]),   # Early Blight
            2: np.array([86, 0, 254]),    # Late Blight
            3: np.array([0, 0, 0]),       # Leaf Minor
        }
        class_labels = {
            0: 'Unlabeled',
            1: 'Early Blight',
            2: 'Late Blight',
            3: 'Leaf Minor',
        }

        # Sidebar legend: one label per class, drawn in that class's color.
        # Components are cast to plain ints so the CSS reads "rgb(255, 34, 133)"
        # regardless of how NumPy prints its scalar types.
        for class_id, label in class_labels.items():
            red, green, blue = (int(c) for c in color_map[class_id])
            st.sidebar.markdown(
                f'<div style="color:rgb({red}, {green}, {blue});">{label}</div>',
                unsafe_allow_html=True,
            )

        # Assumes the model returns a (1, num_classes, H, W) tensor of class
        # scores -- TODO confirm against deeplabv3_encoder_decoder.
        # Only the batch axis is dropped, so other size-1 axes survive.
        mask = torch.argmax(output.squeeze(0), dim=0).detach().cpu().numpy()

        # Paint each class index with its color; indices missing from
        # color_map stay black.
        output_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
        for class_id, color in color_map.items():
            output_rgb[mask == class_id] = color

        st.image(output_rgb, caption='Segmented Image.', use_column_width=True)