Spaces:
Sleeping
Sleeping
| import os, torch | |
| import gradio as gr | |
| from PIL import Image | |
| from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing | |
| from model.modelLoading import loadModel | |
# %% CONSTANTS
# Folders holding the example images shown in the two preloaded galleries.
gta_image_dir = "./preloadedImages/GTAV"
city_image_dir = "./preloadedImages/cityScapes"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# %% prediction on an image
def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
    """
    Predict the segmentation mask for the input image using the provided model.

    The model is switched to evaluation mode before the forward pass.
    ``torch.no_grad()`` only disables gradient tracking. It does not stop
    BatchNorm from using per-batch statistics or Dropout from dropping
    activations, and both corrupt single-image inference. Calling ``eval()``
    is a no-op if the model is already in evaluation mode.

    Args:
        inputImage (torch.Tensor): The input image tensor. A clone is passed to
            ``preprocessing``, so the caller's tensor is not handed over directly.
        model (BiSeNet): The BiSeNet model for segmentation.

    Returns:
        prediction (torch.Tensor): Per-pixel arg-max class indices for the first
            image of the batch. The class axis is kept as a singleton dimension
            (``keepdim=True``).
    """
    model.eval()
    with torch.no_grad():
        output = model(preprocessing(inputImage.clone()).to(device))
        # If the model returns several outputs (tuple/list), only the first is used.
        output = output[0] if isinstance(output, (tuple, list)) else output
    # Assumes output is (batch, classes, H, W) -- TODO confirm: [0] takes the
    # first batch item, then argmax over dim 0 picks the best class per pixel.
    return output[0].argmax(dim=0, keepdim=True).to(device)
# %% Gradio interface
def _error_updates(message: str) -> tuple[dict, dict]:
    """Return updates that hide the prediction image and show ``message`` as an error."""
    return (gr.update(value=None, visible=False), gr.update(value=f"❌ {message}", visible=True))


def run_prediction(image: Image.Image, selected_model: str) -> tuple[dict, dict]:
    """
    Gradio callback: segment ``image`` with the model named ``selected_model``.

    Args:
        image: The uploaded image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when nothing has been uploaded.
        selected_model: Name of the model chosen in the radio selector (e.g.
            ``"BiSeNet"``), or ``None``/empty when nothing is selected.

    Returns:
        A pair of ``gr.update`` values for ``(result_display, error_text)``. On
        success the prediction is shown and the error text is hidden. On any
        failure the prediction is hidden and an error message is shown instead.
    """
    if image is None:
        return _error_updates("No image provided for prediction.")
    if not selected_model:
        return _error_updates("No model selected for prediction.")
    try:
        model = loadModel(selected_model, device)
        # Resize to the fixed 1024x512 resolution the pipeline works at.
        tensor = hfImageToTensor(image, width=1024, height=512)
        prediction = postprocessing(predict(tensor, model))
    except Exception as e:  # UI boundary: report every failure to the user
        import traceback

        # Keep the full stack trace in the server log; the UI only shows str(e).
        traceback.print_exc()
        return _error_updates(f"{e}.")
    return (gr.update(value=prediction, visible=True), gr.update(value="", visible=False))
# Gradio UI
def _load_gallery_images(directory: str) -> list:
    """
    Load every ``.png`` file in ``directory`` as an RGB PIL image, ordered by file name.

    The file names are sorted, not the images: PIL images define no ordering,
    so sorting the loaded images raises ``TypeError`` once there are two of
    them. Each file is opened in a ``with`` block so its handle is closed as
    soon as the RGB copy has been made.
    """
    images = []
    for filename in sorted(os.listdir(directory)):
        if filename.endswith(".png"):
            with Image.open(os.path.join(directory, filename)) as img:
                images.append(img.convert("RGB"))
    return images


def load_example(gallery, evt: gr.SelectData):
    """
    Return the clicked gallery image so it can be placed in ``image_input``.

    Gradio passes the gallery's current value as ``gallery`` (per the Gradio 4
    Gallery docs, a list of ``(image, caption)`` pairs when ``type="pil"``). It
    also injects the click event as ``evt``, where ``evt.index`` is the position
    of the selected item. The previous handler returned the whole list, which a
    single ``gr.Image`` cannot display.
    """
    if not gallery or evt.index is None:
        return gr.update()
    return gr.update(value=gallery[evt.index][0])


with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
    gr.Markdown("# Semantic Segmentation with Real-Time Networks")
    gr.Markdown('A small user interface created to run semantic segmentation on images using city scapes like predictions and real time segmentation networks.')
    gr.Markdown("Upload an image and choose your preferred model for segmentation, or otherwise use one of the preloaded images.")
    with gr.Row():
        with gr.Column():
            model_selector = gr.Radio(
                choices=["BiSeNet", "BiSeNetV2"],
                value="BiSeNet",
                label="Select the real time segmentation model",
            )
            image_input = gr.Image(type="pil", label="Upload image")
            submit_btn = gr.Button("Run prediction")
        with gr.Column():
            result_display = gr.Image(label="Model prediction", visible=True)
            error_text = gr.Markdown("", visible=False)
    with gr.Row():
        gr.Markdown("## Preloaded GTA V images to be used for testing the model")
    with gr.Row():
        gta_gallery = gr.Gallery(
            value=_load_gallery_images(gta_image_dir),
            label="GTA V Examples",
            show_label=False,
            columns=5,
            rows=1,
            height=200,
            type="pil",
        )
    with gr.Row():
        gr.Markdown("## Preloaded Cityscapes images to be used for testing the model")
    with gr.Row():
        city_gallery = gr.Gallery(
            value=_load_gallery_images(city_image_dir),
            label="Cityscapes Examples",
            show_label=False,
            columns=5,
            rows=1,
            height=256,
            type="pil",
        )
    submit_btn.click(
        fn=run_prediction,
        inputs=[image_input, model_selector],
        outputs=[result_display, error_text],
    )
    gr.Markdown("Made by group 21 semantic segmentation project. ")
    # On click: copy the selected example image into image_input.
    gta_gallery.select(fn=load_example, inputs=[gta_gallery], outputs=[image_input])
    city_gallery.select(fn=load_example, inputs=[city_gallery], outputs=[image_input])

if __name__ == "__main__":
    demo.launch()