File size: 4,170 Bytes
41b512d
 
9304821
16b06ec
8fa7ba8
3e3df2b
 
 
 
 
 
 
f607f30
41b512d
 
 
f607f30
41b512d
 
 
 
3e3df2b
 
6f634d4
 
 
 
3e3df2b
 
 
 
c4b8156
c6bf71c
3e3df2b
 
 
 
c6bf71c
3e3df2b
 
 
 
c4b8156
3e3df2b
 
 
 
 
 
 
 
0222108
 
ca576c2
6f634d4
 
 
14e759a
 
 
ace2d98
 
 
14e759a
 
57372b9
 
 
8fa7ba8
 
 
 
ca576c2
57372b9
14e759a
 
 
 
41b512d
 
0c103b7
41b512d
830888a
d780c23
 
 
ca576c2
 
 
6f634d4
 
 
ace2d98
 
6f634d4
41b512d
6f634d4
 
ca576c2
 
 
830888a
 
ca576c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f607f30
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import warnings
import logging
import os
import shutil
import tempfile
from pathlib import Path
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
import prediction

# Silence two specific, known-noisy warnings (matched by category and message
# prefix) so they do not clutter the application log.
warnings.filterwarnings("ignore", category=UserWarning, message="No writable cache directories")
warnings.filterwarnings("ignore", category=FutureWarning, message="`resume_download` is deprecated")

# Configure root logging at INFO level so this module's logger.info messages
# are emitted, and create the module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The environment variables are set in the Dockerfile, no need to set them here.
# WWW_DIR is the "www" directory that sits next to this file.
WWW_DIR = Path(__file__).parent.resolve() / "www"

# Module-level placeholders for the uploaded image path and the predicted
# mask; both start out as None.
current_image_path = None
current_mask = None

def create_ui():
    """Build and return the page layout for the sidewalk-segmentation app.

    The page is a fillable layout made of a header row (title plus dark-mode
    switch) and a sidebar layout. The sidebar holds the PNG upload control;
    the main area holds a full-screen-capable card with the
    ``plot_image_and_mask`` plot output.

    The ``side_menu_controls``, ``overlay`` and ``compute`` UI outputs are
    declared here as well.
    NOTE(review): no renderer for those three IDs is visible in this file's
    server function -- confirm whether they are still needed.
    """
    flex_row = "d-flex justify-content-between align-items-center"

    # Header row: application title and the dark-mode toggle.
    title = ui.panel_title("Segment Sidewalks")
    dark_mode_switch = ui.input_dark_mode(mode="dark")
    header = ui.tags.div(title, dark_mode_switch, class_=flex_row)

    # Sidebar: single-file PNG upload plus a slot for server-rendered controls.
    upload_control = ui.input_file(
        "input_image", "Upload .png Image", accept=[".png"], multiple=False
    )
    side_controls = ui.output_ui("side_menu_controls")
    sidebar = ui.sidebar(upload_control, side_controls)

    # Main card: header text with an extra UI slot, then the plot itself.
    card_header = ui.card_header(
        "See the sidewalks",
        ui.output_ui("overlay"),
        class_=flex_row,
    )
    plot_output = ui.output_plot("plot_image_and_mask", fill=True)
    plot_card = ui.card(card_header, plot_output, full_screen=True)

    compute_slot = ui.output_ui("compute")
    body = ui.layout_sidebar(sidebar, plot_card, compute_slot)

    return ui.page_fillable(header, body)

def server(input: Inputs, output: Outputs, session: Session):
    """Define the per-session reactive logic: run inference on upload, then plot.

    When a file arrives on ``input.input_image`` the image is decoded, the
    sidewalk model is loaded and run on it, and the results are stored in
    session-local reactive values. The ``plot_image_and_mask`` output reads
    those values, so it re-renders automatically whenever they change.

    Args:
        input: Shiny inputs; ``input.input_image`` carries the upload info.
        output: Shiny outputs; ``plot_image_and_mask`` is registered on it.
        session: The Shiny session this function is invoked for (not used
            directly here).
    """
    # Session-local reactive state. Using reactive values instead of module
    # globals keeps one user's image from leaking into another user's session
    # and lets the plot depend on the data directly, with no manual
    # invalidation call.
    uploaded_image = reactive.Value(None)
    predicted_mask = reactive.Value(None)

    @reactive.Effect
    @reactive.event(input.input_image)
    def update_image():
        """Decode the uploaded image, run the model, and publish both results."""
        uploads = input.input_image()
        if not uploads:
            # Nothing uploaded (None or an empty list): nothing to do.
            return

        upload_path = uploads[0]["datapath"]
        # convert() returns a fully loaded copy, so the file handle can be
        # closed immediately and no extra on-disk copy of the upload is needed.
        with Image.open(upload_path) as raw_image:
            image = raw_image.convert("RGB")
        logger.info("Image uploaded: %s", upload_path)

        # NOTE(review): the model is loaded again on every upload; consider
        # caching it if prediction.load_model_and_processor is expensive.
        model, processor, device = prediction.load_model_and_processor(
            WWW_DIR / "sidewalkSAM.pth", "facebook/sam-vit-base"
        )
        logger.info("Model and processor loaded successfully.")

        mask = prediction.get_sidewalk_prediction(image, model, processor, device)
        logger.info("Inference completed.")

        # Publish the results only after inference succeeded, so the plot never
        # sees a new image paired with a stale mask.
        uploaded_image.set(image)
        predicted_mask.set(mask)

    @output
    @render.plot
    def plot_image_and_mask():
        """Render the uploaded image and its predicted mask side by side.

        Returns:
            The matplotlib figure, or ``None`` when no image/mask is available
            yet.
        """
        # Reading the reactive values registers this output as dependent on them.
        image = uploaded_image()
        mask = predicted_mask()

        logger.info(
            "Plotting with image available: %s and mask available: %s",
            image is not None,
            mask is not None,
        )

        if image is None or mask is None:
            logger.warning("Image or mask is None. Skipping plot.")
            return None

        fig, axes = plt.subplots(1, 2, figsize=(15, 5))

        # Left panel: the uploaded image.
        axes[0].imshow(image)
        axes[0].set_title("Original Image")

        # Right panel: the model output (drawn with matplotlib's default colormap).
        axes[1].imshow(mask)
        axes[1].set_title("Prediction")

        # Hide axis ticks; with no ticks there are no tick labels left to clear.
        for ax in axes:
            ax.set_xticks([])
            ax.set_yticks([])

        # Hand the figure to the renderer explicitly instead of calling
        # plt.show(), which has nothing to display to in a server process.
        return fig

# Assemble the Shiny application from the UI and the server function, passing
# WWW_DIR as the static assets directory.
app = App(create_ui(), server, static_assets=str(WWW_DIR))