Spaces:
Sleeping
Sleeping
import functools
import logging
import os
import shutil
import tempfile
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui

import prediction
# Suppress two specific, known-noisy warnings.
warnings.filterwarnings(
    "ignore", category=UserWarning, message="No writable cache directories"
)
warnings.filterwarnings(
    "ignore", category=FutureWarning, message="`resume_download` is deprecated"
)

# Log at INFO so the app's own progress messages (upload, model load,
# inference, plotting) are visible in the container logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The environment variables are set in the Dockerfile, no need to set them here.
# Directory served as static assets; it also holds the model weights file.
WWW_DIR = Path(__file__).parent.resolve() / "www"

# Module-level slots for the current image path and predicted mask (initially
# unset). NOTE: module globals are shared by every session in this process.
current_image_path = None
current_mask = None
def create_ui():
    """Build the page: a title bar, an upload sidebar, and a full-screenable plot card.

    Returns:
        The fillable Shiny page holding every input and output of the app.
    """
    centered_row = "d-flex justify-content-between align-items-center"

    title_bar = ui.tags.div(
        ui.panel_title("Segment Sidewalks"),
        ui.input_dark_mode(mode="dark"),
        class_=centered_row,
    )

    upload_control = ui.input_file(
        "input_image",
        "Upload .png Image",
        accept=[".png"],
        multiple=False,
    )
    # NOTE(review): no render function for "side_menu_controls", "overlay" or
    # "compute" is defined in server(); these look like placeholders.
    sidebar = ui.sidebar(upload_control, ui.output_ui("side_menu_controls"))

    header = ui.card_header(
        "See the sidewalks",
        ui.output_ui("overlay"),
        class_=centered_row,
    )
    plot_card = ui.card(
        header,
        ui.output_plot("plot_image_and_mask", fill=True),
        full_screen=True,
    )

    body = ui.layout_sidebar(sidebar, plot_card, ui.output_ui("compute"))
    return ui.page_fillable(title_bar, body)
@functools.lru_cache(maxsize=1)
def _load_sidewalk_model():
    """Load the segmentation model once per process and reuse it afterwards.

    Returns:
        The ``(model, processor, device)`` tuple returned by
        ``prediction.load_model_and_processor``.
    """
    bundle = prediction.load_model_and_processor(
        WWW_DIR / "sidewalkSAM.pth", "facebook/sam-vit-base"
    )
    logger.info("Model and processor loaded successfully.")
    return bundle


def server(input: Inputs, output: Outputs, session: Session):
    """Register the per-session reactive logic.

    Uploading a file runs sidewalk segmentation on it. The
    ``plot_image_and_mask`` output then shows the image next to the
    prediction. State is held in per-session reactive values rather than
    module globals, so concurrent users never see each other's uploads.

    Args:
        input: Session inputs; ``input.input_image`` holds the uploaded file info.
        output: Session outputs (unused; renderers register via decorators).
        session: The Shiny session (unused).
    """
    # Setting either value invalidates the plot renderer, which then re-draws.
    image_value = reactive.value(None)
    mask_value = reactive.value(None)

    @reactive.effect
    @reactive.event(input.input_image)
    def update_image():
        """Segment the newly uploaded image and store image + mask for plotting."""
        file_infos = input.input_image()
        if not file_infos:
            return
        datapath = file_infos[0]["datapath"]
        logger.info("Image uploaded: %s", datapath)
        try:
            # Decode the upload fully into memory so no file handle or temp
            # copy outlives this call.
            with Image.open(datapath) as src:
                image = src.convert("RGB")
            model, processor, device = _load_sidewalk_model()
            mask = prediction.get_sidewalk_prediction(image, model, processor, device)
        except Exception:
            # Session boundary: report a bad upload / failed inference to the
            # user instead of letting the error take down the session.
            logger.exception("Failed to segment uploaded image: %s", datapath)
            ui.notification_show("Could not process the uploaded image.", type="error")
            image_value.set(None)
            mask_value.set(None)
            return
        logger.info("Inference completed.")
        image_value.set(image)
        mask_value.set(mask)

    @render.plot
    def plot_image_and_mask():
        """Draw the uploaded image (left) and its predicted mask (right)."""
        image = image_value()
        mask = mask_value()
        logger.info(
            "Plotting images (image loaded: %s, mask loaded: %s)",
            image is not None,
            mask is not None,
        )
        # req() quietly stops rendering (blank output) until both are available.
        req(image is not None, mask is not None)

        fig, axes = plt.subplots(1, 2, figsize=(15, 5))
        panels = ((image, "Original Image"), (mask, "Prediction"))
        for ax, (data, title) in zip(axes, panels):
            # NOTE(review): the mask is drawn with matplotlib's default colormap.
            ax.imshow(data)
            ax.set_title(title)
            # Clearing the ticks also removes their labels.
            ax.set_xticks([])
            ax.set_yticks([])
        # render.plot draws the Figure returned by this function.
        return fig
# The Shiny application object: page layout, per-session server logic, and
# WWW_DIR served as static assets.
app = App(
    create_ui(),
    server,
    static_assets=str(WWW_DIR),
)