import logging
import os
from typing import Optional

import torch
import gradio as gr
from PIL import Image

from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing
from model.modelLoading import loadModel

logger = logging.getLogger(__name__)

# %% CONSTANTS
gta_image_dir = "./preloadedImages/GTAV"
city_image_dir = "./preloadedImages/cityScapes"
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# %% prediction on an image
def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
    """
    Predict the segmentation mask for the input image using the provided model.

    Args:
        inputImage (torch.Tensor): The input image tensor. It is cloned before
            preprocessing, so preprocessing cannot modify the caller's tensor.
        model: The segmentation network returned by ``loadModel``
            ("BiSeNet" or "BiSeNetV2").

    Returns:
        prediction (torch.Tensor): Per-pixel argmax class indices for the first
            element of the model output, keeping a leading singleton dimension
            (``keepdim=True``), placed on ``device``.
    """
    with torch.no_grad():
        output = model(preprocessing(inputImage.clone()).to(device))
        # When the model returns a tuple/list, only its first element is used
        # (presumably the main head; any further entries are ignored).
        output = output[0] if isinstance(output, (tuple, list)) else output
        # NOTE(review): assumes output is (batch, classes, H, W) -- output[0] picks the
        # first image and argmax over dim 0 selects the most likely class per pixel.
        return output[0].argmax(dim=0, keepdim=True).to(device)


# %% Gradio interface
def _show_error(message: str) -> tuple:
    """Return updates that hide the result image and show ``message`` as the error text."""
    return (gr.update(value=None, visible=False),
            gr.update(value=message, visible=True))


def run_prediction(image: Optional[Image.Image], selected_model: Optional[str]) -> tuple:
    """
    Gradio callback: segment ``image`` with the model named ``selected_model``.

    Args:
        image: The PIL image from the upload component (``type="pil"``), or None
            when nothing has been provided.
        selected_model: The name chosen in the model radio ("BiSeNet" or
            "BiSeNetV2"), or None.

    Returns:
        A pair of ``gr.update`` values for (result image, error markdown). On
        success the prediction is shown and the error text hidden; on any
        failure the result is hidden and an error message is shown instead.
    """
    if image is None:
        return _show_error("❌ No image provided for prediction.")
    if selected_model is None:
        return _show_error("❌ No model selected for prediction.")

    try:
        # NOTE(review): the model is reloaded on every request; consider caching it
        # once it is confirmed that loadModel returns a model in eval mode.
        model = loadModel(selected_model, device)
        image_tensor = hfImageToTensor(image, width=1024, height=512)
        prediction = postprocessing(predict(image_tensor, model))
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app,
        # but keep the traceback in the server log so it can be debugged.
        logger.exception("Prediction failed for model %r", selected_model)
        return _show_error(f"❌ {str(e)}.")

    return (gr.update(value=prediction, visible=True),
            gr.update(value="", visible=False))


def _load_preloaded_images(directory: str) -> list:
    """
    Load every ``.png`` file in ``directory`` as an RGB PIL image, ordered by file name.

    File names are sorted rather than the images themselves: PIL images define no
    ordering, so ``sorted`` over a list of two or more images raises TypeError.
    """
    images = []
    for file_name in sorted(os.listdir(directory)):
        if not file_name.endswith(".png"):
            continue
        # The context manager closes the file handle; convert() returns an
        # independent in-memory image, so it stays usable after the file closes.
        with Image.open(os.path.join(directory, file_name)) as img:
            images.append(img.convert("RGB"))
    return images


def _example_loader(examples: list):
    """Build a gallery ``select`` handler that copies the clicked example into the upload box."""
    def load_example(evt: gr.SelectData):
        # Gradio injects ``evt`` because of the SelectData annotation; for a Gallery,
        # evt.index is the position of the clicked thumbnail within ``examples``.
        return gr.update(value=examples[evt.index])
    return load_example


gta_examples = _load_preloaded_images(gta_image_dir)
city_examples = _load_preloaded_images(city_image_dir)

# Gradio UI
with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
    gr.Markdown("# Semantic Segmentation with Real-Time Networks")
    gr.Markdown('A small user interface created to run semantic segmentation on images using city scapes like predictions and real time segmentation networks.')
    gr.Markdown("Upload an image and choose your preferred model for segmentation, or otherwise use one of the preloaded images.")

    with gr.Row():
        with gr.Column():
            model_selector = gr.Radio(
                choices=["BiSeNet", "BiSeNetV2"],
                value="BiSeNet",
                label="Select the real time segmentation model",
            )
            image_input = gr.Image(type="pil", label="Upload image")
            submit_btn = gr.Button("Run prediction")
        with gr.Column():
            result_display = gr.Image(label="Model prediction", visible=True)
            error_text = gr.Markdown("", visible=False)

    with gr.Row():
        gr.Markdown("## Preloaded GTA V images to be used for testing the model")
    with gr.Row():
        gta_gallery = gr.Gallery(
            value=gta_examples,
            label="GTA V Examples",
            show_label=False,
            columns=5,
            rows=1,
            height=200,
            type="pil",
        )

    with gr.Row():
        gr.Markdown("## Preloaded Cityscapes images to be used for testing the model")
    with gr.Row():
        city_gallery = gr.Gallery(
            value=city_examples,
            label="Cityscapes Examples",
            show_label=False,
            columns=5,
            rows=1,
            height=256,
            type="pil",
        )

    submit_btn.click(
        fn=run_prediction,
        inputs=[image_input, model_selector],
        outputs=[result_display, error_text],
    )

    gr.Markdown("Made by group 21 semantic segmentation project. ")

    # On click: put the selected example (not the whole gallery) into image_input.
    gta_gallery.select(fn=_example_loader(gta_examples), outputs=[image_input])
    city_gallery.select(fn=_example_loader(city_examples), outputs=[image_input])

demo.launch()