VisualSemSeg / app.py
Nunzio
added prefixed images
79d4472
raw
history blame
4.22 kB
import os
from typing import Optional

import gradio as gr
import torch
from PIL import Image

from model.modelLoading import loadModel
from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing
# %% CONSTANTS
# Directories holding the example images offered in the two galleries.
gta_image_dir = "./preloadedImages/GTAV"
city_image_dir = "./preloadedImages/cityScapes"

# Torch device name: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# %% prediction on an image
def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
    """
    Run the model on one image and return the best-scoring index per position.

    The input tensor is cloned before it is handed to `preprocessing`, and the
    forward pass runs with gradient tracking disabled.

    Args:
        inputImage (torch.Tensor): The image tensor to segment.
        model: Callable segmentation network. It may return a tensor, or a
            tuple/list whose first element is the tensor to use.
    Returns:
        prediction (torch.Tensor): Argmax over dim 0 of the first entry of the
            model output, with that dim kept at size 1, on `device`.
            NOTE(review): presumably the output is (batch, classes, H, W), so
            this is a (1, H, W) map of class indices - confirm.
    """
    with torch.no_grad():
        batch = preprocessing(inputImage.clone()).to(device)
        logits = model(batch)
        # Some models return several outputs; only the first one is used.
        if isinstance(logits, (tuple, list)):
            logits = logits[0]
        labels = logits[0].argmax(dim=0, keepdim=True)
        return labels.to(device)
# %% Gradio interface
def _prediction_error(message: str) -> tuple[dict, dict]:
    """Build the output pair that hides the result image and shows `message`."""
    return (
        gr.update(value=None, visible=False),
        gr.update(value=message, visible=True),
    )


def run_prediction(image: Optional[Image.Image], selected_model: Optional[str]) -> tuple[dict, dict]:
    """
    Segment the uploaded image with the chosen model and build the UI updates.

    Args:
        image: The uploaded picture as delivered by the image input (a PIL
            image, since that component is created with ``type="pil"``), or
            None when nothing was uploaded.
        selected_model: Name of the model to load, or None when none is chosen.
    Returns:
        A pair of `gr.update` results, in the order the click handler lists
        its outputs: first the result image, then the error text. On success
        the image is shown and the error text hidden; on any failure the image
        is hidden and the error text shows the reason.
    """
    if image is None:
        return _prediction_error("❌ No image provided for prediction.")
    if selected_model is None:
        return _prediction_error("❌ No model selected for prediction.")

    try:
        model = loadModel(selected_model, device)
        tensor = hfImageToTensor(image, width=1024, height=512)
        prediction = postprocessing(predict(tensor, model))
    except Exception as e:
        # UI boundary: whatever went wrong while loading or running the model
        # is reported in the error text instead of being raised.
        return _prediction_error(f"❌ {e}.")

    return (
        gr.update(value=prediction, visible=True),
        gr.update(value="", visible=False),
    )
# Gradio UI
def _load_example_images(image_dir: str) -> list:
    """
    Load every ``.png`` file in `image_dir` as an RGB PIL image, ordered by file name.

    The file names are sorted rather than the images: PIL images define no
    ordering, so sorting a list that holds two or more of them raises TypeError.
    """
    images = []
    for file_name in sorted(os.listdir(image_dir)):
        if not file_name.endswith(".png"):
            continue
        # convert() returns an independent copy, so the file can be closed right away.
        with Image.open(os.path.join(image_dir, file_name)) as source:
            images.append(source.convert("RGB"))
    return images


# NOTE(review): the nesting of rows/columns below was reconstructed from the
# order of the components - confirm it against the intended layout.
with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
    gr.Markdown("# Semantic Segmentation with Real-Time Networks")
    gr.Markdown('A small user interface created to run semantic segmentation on images using city scapes like predictions and real time segmentation networks.')
    gr.Markdown("Upload an image and choose your preferred model for segmentation, or otherwise use one of the preloaded images.")

    # Left column: model choice, image upload and the run button.
    # Right column: the predicted mask, or the error message when prediction fails.
    with gr.Row():
        with gr.Column():
            model_selector = gr.Radio(
                choices=["BiSeNet", "BiSeNetV2"],
                value="BiSeNet",
                label="Select the real time segmentation model",
            )
            image_input = gr.Image(type="pil", label="Upload image")
            submit_btn = gr.Button("Run prediction")
        with gr.Column():
            result_display = gr.Image(label="Model prediction", visible=True)
            error_text = gr.Markdown("", visible=False)

    with gr.Row():
        gr.Markdown("## Preloaded GTA V images to be used for testing the model")
    with gr.Row():
        gta_gallery = gr.Gallery(
            value=_load_example_images(gta_image_dir),
            label="GTA V Examples",
            show_label=False,
            columns=5,
            rows=1,
            height=200,
            type="pil",
        )

    with gr.Row():
        gr.Markdown("## Preloaded Cityscapes images to be used for testing the model")
    with gr.Row():
        city_gallery = gr.Gallery(
            value=_load_example_images(city_image_dir),
            label="Cityscapes Examples",
            show_label=False,
            columns=5,
            rows=1,
            height=256,
            type="pil",
        )

    submit_btn.click(
        fn=run_prediction,
        inputs=[image_input, model_selector],
        outputs=[result_display, error_text],
    )

    gr.Markdown("Made by group 21 semantic segmentation project. ")

    def load_example(gallery_items, evt: gr.SelectData):
        """
        Return an update that puts the clicked gallery image into the image input.

        Args:
            gallery_items: Current value of the gallery that was clicked (all
                of its items, not only the selected one).
            evt: Selection event data supplied by Gradio; `evt.index` is the
                position of the clicked item.
        """
        selected = gallery_items[evt.index]
        # NOTE(review): gallery items are expected to arrive as (image, caption)
        # pairs; a bare image is passed through unchanged - confirm against the
        # installed Gradio version.
        if isinstance(selected, (tuple, list)):
            selected = selected[0]
        return gr.update(value=selected)

    # On click: update image_input with selected example
    gta_gallery.select(fn=load_example, inputs=[gta_gallery], outputs=[image_input])
    city_gallery.select(fn=load_example, inputs=[city_gallery], outputs=[image_input])

demo.launch()