# Spaces: Sleeping
# File size: 4,170 Bytes
import functools
import logging
import os
import shutil
import tempfile
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, ui

import prediction
# Suppress specific warnings whose messages are matched verbatim below.
warnings.filterwarnings("ignore", category=UserWarning, message="No writable cache directories")
warnings.filterwarnings("ignore", category=FutureWarning, message="`resume_download` is deprecated")

# Log at INFO so the upload / model-loading / inference progress messages
# emitted by this module are visible. Change to logging.WARNING to hide them.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The environment variables are set in the Dockerfile, no need to set them here
# Directory served as static assets; it also holds the model weights file.
WWW_DIR = Path(__file__).parent.resolve() / "www"

# Module-level state. NOTE: module globals are shared by every Shiny session
# running in this process, not scoped to a single user.
current_image_path = None  # path of the most recently uploaded image
current_mask = None  # prediction computed for that image
def create_ui():
    """Build and return the page layout for the application.

    The page consists of a title bar with a dark-mode toggle, followed by a
    sidebar layout: the sidebar holds the PNG upload control and a dynamic
    ``side_menu_controls`` slot, and the main area holds a full-screen-capable
    card containing the image/prediction plot, plus a ``compute`` slot.

    NOTE(review): ``side_menu_controls``, ``overlay`` and ``compute`` have no
    renderer in this file's ``server``; they currently render nothing.
    """
    title_bar = ui.tags.div(
        ui.panel_title("Segment Sidewalks"),
        ui.input_dark_mode(mode="dark"),
        class_="d-flex justify-content-between align-items-center",
    )

    sidebar = ui.sidebar(
        ui.input_file("input_image", "Upload .png Image", accept=[".png"], multiple=False),
        ui.output_ui("side_menu_controls"),
    )

    card_header = ui.card_header(
        "See the sidewalks",
        ui.output_ui("overlay"),
        class_="d-flex justify-content-between align-items-center",
    )
    plot_card = ui.card(
        card_header,
        ui.output_plot("plot_image_and_mask", fill=True),
        full_screen=True,
    )

    main_layout = ui.layout_sidebar(sidebar, plot_card, ui.output_ui("compute"))
    return ui.page_fillable(title_bar, main_layout)
@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the segmentation model once per process and reuse it afterwards.

    Returns:
        The ``(model, processor, device)`` tuple produced by
        ``prediction.load_model_and_processor`` for the ``sidewalkSAM.pth``
        weights in ``WWW_DIR`` and the ``facebook/sam-vit-base`` processor.
    """
    model, processor, device = prediction.load_model_and_processor(
        WWW_DIR / "sidewalkSAM.pth", "facebook/sam-vit-base"
    )
    logger.info("Model and processor loaded successfully.")
    return model, processor, device


def server(input: Inputs, output: Outputs, session: Session):
    """Wire the image upload to inference and to the side-by-side plot.

    Each session keeps its own image and mask in reactive values. The plot
    reads those values, so Shiny re-renders it automatically whenever a new
    prediction is stored, and concurrent users do not overwrite each other's
    state. The module-level ``current_image_path`` / ``current_mask`` globals
    are still updated for backward compatibility but are not read here.
    """
    image_value = reactive.Value(None)  # decoded RGB PIL image for this session
    mask_value = reactive.Value(None)  # prediction for image_value

    @reactive.Effect
    @reactive.event(input.input_image)
    def update_image():
        """Decode the uploaded file, run inference, and store the results."""
        global current_image_path
        global current_mask

        file_infos = input.input_image()
        if not file_infos:
            return
        image_path = file_infos[0]["datapath"]

        try:
            # Decode fully inside the context manager so the file handle is
            # released; the in-memory copy removes any need for a temp file.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
        except OSError:  # covers PIL.UnidentifiedImageError
            logger.exception("Could not read uploaded image: %s", image_path)
            ui.notification_show(
                "Could not read the uploaded file as an image.", type="error"
            )
            return
        logger.info("Image uploaded: %s", image_path)

        model, processor, device = _load_model()
        mask = prediction.get_sidewalk_prediction(image, model, processor, device)
        logger.info("Inference completed.")

        # Mirror into the legacy module globals (process-wide, not per session).
        current_image_path = image_path
        current_mask = mask

        # Setting these invalidates plot_image_and_mask, which reads them.
        image_value.set(image)
        mask_value.set(mask)

    @output
    @render.plot
    def plot_image_and_mask():
        """Render the uploaded image and its prediction side by side."""
        image = image_value()
        mask = mask_value()
        logger.info(
            "Plotting with image loaded: %s and mask loaded: %s",
            image is not None,
            mask is not None,
        )
        if image is None or mask is None:
            logger.warning("Image or mask is None. Skipping plot.")
            return None

        fig, axes = plt.subplots(1, 2, figsize=(15, 5))
        axes[0].imshow(image)
        axes[0].set_title("Original Image")
        # NOTE(review): the mask's format comes from
        # prediction.get_sidewalk_prediction; imshow uses its default colormap.
        axes[1].imshow(mask)
        axes[1].set_title("Prediction")
        # Empty tick lists also remove the tick labels.
        for ax in axes:
            ax.set_xticks([])
            ax.set_yticks([])
        # Return the figure to Shiny instead of calling plt.show() on the server.
        return fig
# Shiny application object: the UI layout, the per-session server function,
# and WWW_DIR served as static assets.
app = App(
    ui=create_ui(),
    server=server,
    static_assets=str(WWW_DIR),
)
|