File size: 4,221 Bytes
ed3e09d
6a0b93e
79d4472
 
dcb04c4
5a17bb3
6a0b93e
79d4472
 
 
 
 
 
 
 
6a0b93e
 
5a17bb3
6a0b93e
 
 
 
 
 
 
 
 
 
 
049f834
6a0b93e
c00c6a4
6a0b93e
 
94c4671
 
dcb04c4
37570fa
dcb04c4
 
37570fa
60fd570
 
 
 
 
 
 
 
37570fa
94c4671
 
2eadc64
79d4472
5a17bb3
ff83735
94c4671
 
 
 
 
 
79d4472
94c4671
 
049f834
94c4671
ff83735
2eadc64
79d4472
 
 
 
 
 
 
 
 
 
 
 
 
2eadc64
79d4472
 
 
 
 
 
 
 
94c4671
 
 
37570fa
94c4671
2eadc64
5a17bb3
94c4671
79d4472
 
 
 
 
 
 
94c4671
79d4472
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os, torch
import gradio as gr
from PIL import Image

from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing
from model.modelLoading import loadModel


# %% Constants
# Folders containing the example images shown in the two galleries.
gta_image_dir = "./preloadedImages/GTAV"
city_image_dir = "./preloadedImages/cityScapes"

# Torch device used for inference: CUDA when it is available, CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'



# %% prediction on an image

def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
    """
    Run the segmentation model on one image and return the class-index mask.

    Args:
        inputImage (torch.Tensor): The input image tensor. It is cloned before
            being handed to `preprocessing`.
        model: The segmentation network. It is called on the preprocessed
            tensor and may return either a single tensor or a tuple/list of
            tensors; in the latter case only the first entry is used.

    Returns:
        torch.Tensor: Argmax over dim 0 (kept as a size-1 dimension) of the
            first element along the output's leading dimension (presumably
            the batch axis), placed on `device`.
    """
    # Inference only: no autograd bookkeeping for preprocessing or the forward pass.
    with torch.no_grad():
        model_input = preprocessing(inputImage.clone()).to(device)
        raw_output = model(model_input)

    # Some networks return several heads; the first one is the one used here.
    if isinstance(raw_output, (tuple, list)):
        logits = raw_output[0]
    else:
        logits = raw_output

    first_sample = logits[0]
    mask = first_sample.argmax(dim=0, keepdim=True)
    return mask.to(device)


# %% Gradio interface
def _prediction_error(message: str) -> tuple[dict, dict]:
    """Build the (result image, error text) updates that hide the image and show *message*."""
    return (
        gr.update(value=None, visible=False),
        gr.update(value=message, visible=True),
    )


def run_prediction(image: "Image.Image | None", selected_model: "str | None") -> tuple[dict, dict]:
    """
    Gradio callback: segment the given image with the selected model.

    Args:
        image: The PIL image delivered by the image input component, or None
            when nothing has been uploaded/selected.
        selected_model: Name of the model to load (passed to `loadModel`), or
            None when no model is selected.

    Returns:
        A pair of `gr.update` payloads, in the order (result image, error
        text). On success the result image is shown and the error text is
        cleared and hidden; on any failure the result image is hidden and the
        error text shows a message starting with "❌".
    """
    if image is None:
        return _prediction_error("❌ No image provided for prediction.")

    if selected_model is None:
        return _prediction_error("❌ No model selected for prediction.")

    try:
        # NOTE(review): the model is loaded again on every request; consider
        # caching it if loadModel turns out to be expensive.
        model = loadModel(selected_model, device)
        image = hfImageToTensor(image, width=1024, height=512)
        prediction = predict(image, model)
        prediction = postprocessing(prediction)
    except Exception as e:
        # UI boundary: report the problem to the user instead of crashing the
        # app, but keep the traceback in the server log for debugging.
        import logging

        logging.getLogger(__name__).exception("Prediction failed")
        return _prediction_error(f"❌ {e}.")

    return (
        gr.update(value=prediction, visible=True),
        gr.update(value="", visible=False),
    )

# Gradio UI
def _load_gallery_images(image_dir: str) -> list:
    """Load every ``.png`` file in *image_dir* as an RGB PIL image, ordered by file name."""
    images = []
    # Sort the file names rather than the images: PIL images define no
    # ordering, so sorted() on a list of them raises TypeError as soon as the
    # folder contains two or more pictures.
    for file_name in sorted(os.listdir(image_dir)):
        if not file_name.endswith(".png"):
            continue
        # convert() returns an independent, fully loaded copy, so the file
        # handle opened by Image.open can be closed immediately.
        with Image.open(os.path.join(image_dir, file_name)) as img:
            images.append(img.convert("RGB"))
    return images


def load_example(example_img, evt: gr.SelectData = None):
    """
    Put the gallery item the user clicked into the image input.

    Args:
        example_img: When called from a gallery ``select`` event this is the
            whole gallery value (a list of items); when called without event
            data it is used as the image itself.
        evt: Selection event data injected by Gradio. The annotation must be
            the plain ``gr.SelectData`` class for Gradio to recognise and
            fill it, hence the ``None`` default without an Optional hint.

    Returns:
        A `gr.update` payload setting the image input's value.
    """
    if evt is None:
        # Direct call without event data: keep the original behaviour.
        return gr.update(value=example_img)

    selected = example_img[evt.index]
    # Gallery items may arrive as (image, caption) pairs; keep only the image.
    if isinstance(selected, (tuple, list)):
        selected = selected[0]
    return gr.update(value=selected)


with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
    gr.Markdown("# Semantic Segmentation with Real-Time Networks")
    gr.Markdown('A small user interface created to run semantic segmentation on images using city scapes like predictions and real time segmentation networks.')
    gr.Markdown("Upload an image and choose your preferred model for segmentation, or otherwise use one of the preloaded images.")

    with gr.Row():
        with gr.Column():
            model_selector = gr.Radio(
                choices=["BiSeNet", "BiSeNetV2"],
                value="BiSeNet",
                label="Select the real time segmentation model"
            )
            image_input = gr.Image(type="pil", label="Upload image")
            submit_btn = gr.Button("Run prediction")
        with gr.Column():
            result_display = gr.Image(label="Model prediction", visible=True)
            error_text = gr.Markdown("", visible=False)

    with gr.Row():
        gr.Markdown("## Preloaded GTA V images to be used for testing the model")
    with gr.Row():
        gta_gallery = gr.Gallery(
            value=_load_gallery_images(gta_image_dir),
            label="GTA V Examples",
            show_label=False,
            columns=5,
            rows=1,
            height=200,
            type="pil"
        )

    with gr.Row():
        gr.Markdown("## Preloaded Cityscapes images to be used for testing the model")
    with gr.Row():
        city_gallery = gr.Gallery(
            value=_load_gallery_images(city_image_dir),
            label="Cityscapes Examples",
            show_label=False,
            columns=5,
            rows=1,
            height=256,
            type="pil"
        )

    submit_btn.click(
        fn=run_prediction,
        inputs=[image_input, model_selector],
        outputs=[result_display, error_text],
    )

    # On click: update image_input with the selected example. Event listeners
    # have to be registered while the Blocks context is still open.
    gta_gallery.select(fn=load_example, inputs=[gta_gallery], outputs=[image_input])
    city_gallery.select(fn=load_example, inputs=[city_gallery], outputs=[image_input])

    gr.Markdown("Made by group 21 semantic segmentation project. ")

demo.launch()