Spaces:
Build error
Build error
| import streamlit as st | |
| import torch | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from train5 import deeplabv3_encoder_decoder | |
| import numpy as np | |
def load_model(model_path):
    """Build the segmentation network and load its trained weights on CPU.

    Args:
        model_path: Path to the saved ``state_dict`` file.

    Returns:
        The model switched to evaluation mode, or ``None`` when loading the
        weights fails (the error is reported in the Streamlit UI).
    """
    net = deeplabv3_encoder_decoder()
    try:
        # map_location keeps every tensor on the CPU while deserializing.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        net.load_state_dict(state_dict)
        net.eval()
    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing the app.
        st.error(f"Error loading model: {e}")
        return None
    return net
# Path to the trained weights, relative to the working directory.
model_path = 'model.pth'

# Load the trained model. load_model returns None on failure and has already
# reported the error in the UI by then.
model = load_model(model_path)

# Compare with None explicitly: load_model signals failure with None, and
# relying on object truthiness would misfire for any module defining __len__.
if model is not None:
    # Create a Streamlit app
    st.title('Aerial Image Segmentation')

    # Add a file uploader to the app
    uploaded_file = st.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"]
    )

    if uploaded_file is not None:
        # Force three channels: grayscale, CMYK or RGBA uploads would otherwise
        # produce a tensor whose channel count differs from an RGB image.
        image = Image.open(uploaded_file).convert('RGB')

        # Display the original image
        st.image(image, caption='Uploaded Image.', use_column_width=True)

        # Preprocess the image: resize to 512x512 and convert to a float tensor.
        data_transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
        ])
        image = data_transform(image)
        image = image.unsqueeze(0)  # add a batch dimension

        # Pass the image through the model without tracking gradients.
        with torch.no_grad():
            output = model(image)

        # Define the color map (RGB) and class labels, keyed by class index.
        color_map = {
            0: np.array([255, 34, 133]),  # Unlabeled
            1: np.array([0, 252, 199]),   # Early Blight
            2: np.array([86, 0, 254]),    # Late Blight
            3: np.array([0, 0, 0]),       # Leaf Minor
        }
        # NOTE(review): 'Leaf Minor' may be intended as 'Leaf Miner' - confirm
        # before changing the user-facing label.
        class_labels = {
            0: 'Unlabeled',
            1: 'Early Blight',
            2: 'Late Blight',
            3: 'Leaf Minor',
        }

        # Sidebar legend: one label per class, drawn in that class's color.
        for k, v in class_labels.items():
            # Format plain ints explicitly. Interpolating tuple(ndarray) embeds
            # the NumPy scalar repr (e.g. "np.int64(255)" on NumPy >= 2),
            # which is not valid CSS.
            r, g, b = (int(c) for c in color_map[k])
            st.sidebar.markdown(
                f'<div style="color:rgb({r}, {g}, {b});">{v}</div>',
                unsafe_allow_html=True,
            )

        # Per-pixel class index.
        # Assumes the model returns (1, num_classes, H, W) scores - TODO confirm.
        output = torch.argmax(output.squeeze(), dim=0).detach().cpu().numpy()

        # Paint every pixel with the color of its predicted class.
        output_rgb = np.zeros((output.shape[0], output.shape[1], 3), dtype=np.uint8)
        for k, v in color_map.items():
            output_rgb[output == k] = v

        st.image(output_rgb, caption='Segmented Image.', use_column_width=True)