Spaces:
Sleeping
Sleeping
File size: 4,221 Bytes
ed3e09d 6a0b93e 79d4472 dcb04c4 5a17bb3 6a0b93e 79d4472 6a0b93e 5a17bb3 6a0b93e 049f834 6a0b93e c00c6a4 6a0b93e 94c4671 dcb04c4 37570fa dcb04c4 37570fa 60fd570 37570fa 94c4671 2eadc64 79d4472 5a17bb3 ff83735 94c4671 79d4472 94c4671 049f834 94c4671 ff83735 2eadc64 79d4472 2eadc64 79d4472 94c4671 37570fa 94c4671 2eadc64 5a17bb3 94c4671 79d4472 94c4671 79d4472 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 | import os, torch
import gradio as gr
from PIL import Image
from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing
from model.modelLoading import loadModel
## %% CONSTANTS
gta_image_dir = "./preloadedImages/GTAV"
city_image_dir = "./preloadedImages/cityScapes"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# %% prediction on an image
def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
    """
    Run the segmentation model on one image and return the winning class per position.

    Args:
        inputImage (torch.Tensor): The input image tensor. It is cloned before
            being handed to ``preprocessing`` (presumably so the caller's tensor
            is never modified in place -- confirm against ``preprocessing``).
        model: Callable segmentation network (e.g. BiSeNet). It may return either
            a single tensor or a tuple/list of tensors; in the latter case only
            the first entry is used.

    Returns:
        torch.Tensor: Arg-max over the leading dimension of the first sample of
            the model output, with that dimension kept as size 1, placed on the
            module-level ``device``.
    """
    with torch.no_grad():
        # Inference only: no autograd graph is recorded for the forward pass.
        batch = preprocessing(inputImage.clone())
        logits = model(batch.to(device))

    # Some networks return several outputs; keep only the first one.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]

    # NOTE(review): assumes logits are laid out as (batch, classes, H, W), so that
    # logits[0] is (classes, H, W) and dim 0 is the class axis -- TODO confirm.
    first_sample = logits[0]
    class_map = first_sample.argmax(dim=0, keepdim=True)
    return class_map.to(device)
# %% Gradio interface
def _prediction_failure(message: str) -> tuple[dict, dict]:
    """Build the output pair that hides the result image and shows ``message``."""
    return (
        gr.update(value=None, visible=False),
        gr.update(value=message, visible=True),
    )


def run_prediction(image: "Image.Image | None", selected_model: str) -> tuple[dict, dict]:
    """
    Gradio click handler: segment ``image`` with the chosen model.

    Args:
        image: The image taken from the upload component, or ``None`` when the
            user has not provided one.
        selected_model (str): Name of the model to load, forwarded unchanged to
            ``loadModel``. ``None`` when nothing is selected.

    Returns:
        A pair of Gradio updates ``(result_image_update, error_text_update)``.
        On success the first carries the post-processed prediction and is
        visible while the second is cleared and hidden. On any failure the
        first is cleared and hidden while the second shows the error message.
    """
    # Guard clauses: report missing inputs instead of failing inside the model code.
    if image is None:
        return _prediction_failure("❌ No image provided for prediction.")
    if selected_model is None:
        return _prediction_failure("❌ No model selected for prediction.")

    try:
        model = loadModel(selected_model, device)
        image_tensor = hfImageToTensor(image, width=1024, height=512)
        prediction = postprocessing(predict(image_tensor, model))
    except Exception as e:
        # Deliberately broad: this is the UI boundary, so every failure is shown
        # to the user as text rather than crashing the event handler.
        return _prediction_failure(f"❌ {e}.")

    return (
        gr.update(value=prediction, visible=True),
        gr.update(value="", visible=False),
    )
# Gradio UI
def _load_example_images(directory):
    """Return the ``.png`` files found in ``directory`` as RGB images, ordered by filename."""
    # Sort the filenames, not the images: PIL images define no ordering, so
    # sorting a list of two or more of them raises TypeError.
    filenames = sorted(name for name in os.listdir(directory) if name.endswith(".png"))
    images = []
    for name in filenames:
        # convert() returns an independent in-memory image, so the underlying
        # file can be closed as soon as the conversion is done.
        with Image.open(os.path.join(directory, name)) as source:
            images.append(source.convert("RGB"))
    return images


def load_example(example_img):
    """Return an update that places ``example_img`` into the image input."""
    return gr.update(value=example_img)


def _load_selected_example(gallery_items, evt: gr.SelectData):
    """Gallery ``select`` handler: load only the clicked example into the image input."""
    selected = gallery_items[evt.index]
    # NOTE(review): the gallery value is expected to arrive as (image, caption)
    # pairs; plain images are accepted as well -- confirm against the installed
    # Gradio version.
    if isinstance(selected, (tuple, list)):
        selected = selected[0]
    return load_example(selected)


with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
    # Page header and usage instructions.
    gr.Markdown("# Semantic Segmentation with Real-Time Networks")
    gr.Markdown('A small user interface created to run semantic segmentation on images using city scapes like predictions and real time segmentation networks.')
    gr.Markdown("Upload an image and choose your preferred model for segmentation, or otherwise use one of the preloaded images.")

    with gr.Row():
        # Left column: model choice, image upload and the trigger button.
        with gr.Column():
            model_selector = gr.Radio(
                choices=["BiSeNet", "BiSeNetV2"],
                value="BiSeNet",
                label="Select the real time segmentation model"
            )
            image_input = gr.Image(type="pil", label="Upload image")
            submit_btn = gr.Button("Run prediction")
        # Right column: either the prediction or an error message is shown.
        with gr.Column():
            result_display = gr.Image(label="Model prediction", visible=True)
            error_text = gr.Markdown("", visible=False)

    with gr.Row():
        gr.Markdown("## Preloaded GTA V images to be used for testing the model")
    with gr.Row():
        gta_gallery = gr.Gallery(
            value=_load_example_images(gta_image_dir),
            label="GTA V Examples",
            show_label=False,
            columns=5,
            rows=1,
            height=200,
            type="pil"
        )

    with gr.Row():
        gr.Markdown("## Preloaded Cityscapes images to be used for testing the model")
    with gr.Row():
        city_gallery = gr.Gallery(
            value=_load_example_images(city_image_dir),
            label="Cityscapes Examples",
            show_label=False,
            columns=5,
            rows=1,
            height=256,
            type="pil"
        )

    submit_btn.click(
        fn=run_prediction,
        inputs=[image_input, model_selector],
        outputs=[result_display, error_text],
    )

    gr.Markdown("Made by group 21 semantic segmentation project. ")

    # On click: update image_input with the selected example only. The gallery
    # value holds every example, so the handler uses the event's index to pick
    # the one that was clicked.
    gta_gallery.select(fn=_load_selected_example, inputs=[gta_gallery], outputs=[image_input])
    city_gallery.select(fn=_load_selected_example, inputs=[city_gallery], outputs=[image_input])

demo.launch()
|