# shinySidewalks / app.py
# olatte's picture
# Update app.py
# d780c23 verified
import warnings
import logging
import os
import shutil
import tempfile
from pathlib import Path
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
import prediction
# Silence two known, non-actionable warnings. `message` is a regular
# expression matched against the start of the warning text.
warnings.filterwarnings("ignore", category=UserWarning, message="No writable cache directories")
warnings.filterwarnings("ignore", category=FutureWarning, message="`resume_download` is deprecated")
# Configure root logging at INFO level so the logger.info() progress messages
# emitted by this module are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# The environment variables are set in the Dockerfile, no need to set them here
# Absolute path of the "www" directory that sits next to this file.
WWW_DIR = Path(__file__).parent.resolve() / "www"
# Module-level state, initialised empty (no image path and no mask yet).
current_image_path = None
current_mask = None
def create_ui():
    """Build the page layout and return it.

    The page consists of a title bar (title plus dark-mode switch) above a
    sidebar layout: the sidebar holds the PNG upload control, the main area
    holds a full-screen-capable card with the result plot.
    """
    # Bootstrap utility classes shared by the title bar and the card header.
    spread_row = "d-flex justify-content-between align-items-center"

    title_bar = ui.tags.div(
        ui.panel_title("Segment Sidewalks"),
        ui.input_dark_mode(mode="dark"),
        class_=spread_row,
    )

    upload_sidebar = ui.sidebar(
        ui.input_file("input_image", "Upload .png Image", accept=[".png"], multiple=False),
        ui.output_ui("side_menu_controls"),
    )

    result_header = ui.card_header(
        "See the sidewalks",
        ui.output_ui("overlay"),
        class_=spread_row,
    )
    result_card = ui.card(
        result_header,
        ui.output_plot("plot_image_and_mask", fill=True),
        full_screen=True,
    )

    main_layout = ui.layout_sidebar(
        upload_sidebar,
        result_card,
        ui.output_ui("compute"),
    )

    return ui.page_fillable(title_bar, main_layout)
def server(input: Inputs, output: Outputs, session: Session):
    """Define the reactive logic for a single browser session.

    When a PNG is uploaded through the ``input_image`` control, the image is
    segmented with the sidewalk SAM model and the original image and the
    predicted mask are drawn side by side in the ``plot_image_and_mask``
    output.

    Args:
        input: Shiny inputs of the session; ``input.input_image`` is read.
        output: Shiny outputs of the session; ``plot_image_and_mask`` is
            registered on it.
        session: The Shiny session object (not used directly).
    """
    # Result of the most recent upload: None until an image has been
    # processed, then an ``(image_path, image, mask)`` tuple. Holding it in a
    # reactive value created inside ``server`` keeps the state private to
    # this session and makes the plot re-render whenever it is replaced.
    latest_result = reactive.Value(None)

    @reactive.Effect
    @reactive.event(input.input_image)
    def update_image():
        """Run inference on the newly uploaded image and publish the result."""
        uploads = input.input_image()
        if not uploads:
            # Nothing uploaded (None) or an empty file list: nothing to do.
            return

        image_path = uploads[0]["datapath"]
        # convert() returns a new, fully loaded image, so the uploaded file
        # can be closed immediately and no extra on-disk copy is required.
        with Image.open(image_path) as uploaded:
            image = uploaded.convert("RGB")
        logger.info(f"Image uploaded: {image_path}")

        # NOTE(review): the model is loaded again on every upload; consider
        # caching it if prediction.load_model_and_processor is expensive.
        model, processor, device = prediction.load_model_and_processor(
            WWW_DIR / "sidewalkSAM.pth", "facebook/sam-vit-base"
        )
        logger.info("Model and processor loaded successfully.")

        mask = prediction.get_sidewalk_prediction(image, model, processor, device)
        logger.info("Inference completed.")

        # Publish image and mask together so the plot never pairs a new image
        # with a stale mask. Setting the value invalidates the plot output.
        latest_result.set((image_path, image, mask))

    @output
    @render.plot
    def plot_image_and_mask():
        """Draw the uploaded image next to the predicted mask.

        Returns:
            The matplotlib figure, or None when no result is available yet.
        """
        # Reading the reactive value registers the dependency that triggers
        # a re-render after each completed inference.
        result = latest_result()
        image_path, image, mask = result if result is not None else (None, None, None)
        logger.info(f"Plotting images with current_image_path: {image_path} and current_mask: {mask is not None}")
        if image is None or mask is None:
            logger.warning("Image or mask is None. Skipping plot.")
            return None

        fig, axes = plt.subplots(1, 2, figsize=(15, 5))

        # Left panel: the uploaded image.
        axes[0].imshow(image)
        axes[0].set_title("Original Image")

        # Right panel: the predicted mask.
        # NOTE(review): drawn with matplotlib's default colormap; the mask is
        # presumably a single-channel array — confirm against
        # prediction.get_sidewalk_prediction.
        axes[1].imshow(mask)
        axes[1].set_title("Prediction")

        # Hide axis ticks and labels on both panels.
        for ax in axes:
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xticklabels([])
            ax.set_yticklabels([])

        # Hand the figure to Shiny for rendering instead of calling
        # plt.show(), which is meant for interactive backends.
        return fig
# Application object: combines the UI built by create_ui() with the server
# function and passes the "www" directory as static assets.
# NOTE(review): server() loads "sidewalkSAM.pth" from this same directory, so
# the weights file appears to be reachable as a static asset — confirm that
# this is intended.
app = App(create_ui(), server, static_assets=str(WWW_DIR))