Spaces:
Sleeping
Sleeping
Nunzio commited on
Commit ·
2eadc64
1
Parent(s): 37570fa
added new features to the Gradio interface and fixed some bugs in the prediction function.
Browse files- app.py +5 -6
- utils/imageHandling.py +38 -1
- utils2.py +0 -37
app.py
CHANGED
|
@@ -39,7 +39,7 @@ def run_prediction(image: gr.Image, selected_model: str)-> tuple[torch.Tensor]:
|
|
| 39 |
return (gr.update(value=prediction, visible=True), gr.update(value="", visible=False))
|
| 40 |
|
| 41 |
# Gradio UI
|
| 42 |
-
with gr.Blocks(title="
|
| 43 |
gr.Markdown("## Semantic Segmentation with Real-Time Networks")
|
| 44 |
gr.Markdown('A small user interface created to run semantic segmentation on images using city scapes like predictions and real time segmentation networks.')
|
| 45 |
gr.Markdown("Upload an image and choose your preferred model for segmentation.")
|
|
@@ -54,16 +54,15 @@ with gr.Blocks(title="🔀 BiSeNet | BiSeNetV2 Predictor") as demo:
|
|
| 54 |
image_input = gr.Image(type="pil", label="Upload image")
|
| 55 |
submit_btn = gr.Button("Run prediction")
|
| 56 |
with gr.Column():
|
| 57 |
-
result_display = gr.Image(label="Model prediction")
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
submit_btn.click(
|
| 62 |
fn=run_prediction,
|
| 63 |
inputs=[image_input, model_selector],
|
| 64 |
outputs=[result_display, error_text],
|
| 65 |
)
|
| 66 |
-
|
| 67 |
gr.Markdown("Made by group 21 semantic segmentation project. ")
|
| 68 |
|
| 69 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
| 39 |
return (gr.update(value=prediction, visible=True), gr.update(value="", visible=False))
|
| 40 |
|
| 41 |
# Gradio UI
|
| 42 |
+
with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
|
| 43 |
gr.Markdown("## Semantic Segmentation with Real-Time Networks")
|
| 44 |
gr.Markdown('A small user interface created to run semantic segmentation on images using city scapes like predictions and real time segmentation networks.')
|
| 45 |
gr.Markdown("Upload an image and choose your preferred model for segmentation.")
|
|
|
|
| 54 |
image_input = gr.Image(type="pil", label="Upload image")
|
| 55 |
submit_btn = gr.Button("Run prediction")
|
| 56 |
with gr.Column():
|
| 57 |
+
result_display = gr.Image(label="Model prediction", visible=False)
|
| 58 |
+
error_text = gr.Markdown("", visible=False)
|
| 59 |
+
|
|
|
|
| 60 |
submit_btn.click(
|
| 61 |
fn=run_prediction,
|
| 62 |
inputs=[image_input, model_selector],
|
| 63 |
outputs=[result_display, error_text],
|
| 64 |
)
|
| 65 |
+
|
| 66 |
gr.Markdown("Made by group 21 semantic segmentation project. ")
|
| 67 |
|
| 68 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
utils/imageHandling.py
CHANGED
|
@@ -31,6 +31,43 @@ def preprocessing(image_tensor: torch.Tensor) -> torch.Tensor:
|
|
| 31 |
return torchvision.transforms.functional.normalize(
|
| 32 |
image_tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
| 33 |
).unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
# %% postprocessing
|
| 36 |
def postprocessing(pred: torch.Tensor) -> torch.Tensor:
|
|
@@ -43,4 +80,4 @@ def postprocessing(pred: torch.Tensor) -> torch.Tensor:
|
|
| 43 |
Returns:
|
| 44 |
torch.Tensor: Processed tensor of shape (3, H, W) for visualization.
|
| 45 |
"""
|
| 46 |
-
return torchvision.transforms.functional.to_pil_image(pred.squeeze(0).cpu().to(torch.uint8))
|
|
|
|
| 31 |
return torchvision.transforms.functional.normalize(
|
| 32 |
image_tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
| 33 |
).unsqueeze(0)
|
| 34 |
+
|
| 35 |
+
# %% print mask on a sem seg style
|
| 36 |
+
def print_mask(mask:torch.Tensor, numClasses:int=19)->None:
|
| 37 |
+
"""
|
| 38 |
+
Visualizes the segmentation mask by mapping each class to a specific color.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
mask (torch.Tensor): The segmentation mask to visualize.
|
| 42 |
+
numClasses (int, optional): Number of classes in the segmentation mask. Defaults to 19.
|
| 43 |
+
"""
|
| 44 |
+
colors = [
|
| 45 |
+
(128, 64, 128), # 0: road
|
| 46 |
+
(244, 35, 232), # 1: sidewalk
|
| 47 |
+
(70, 70, 70), # 2: building
|
| 48 |
+
(102, 102, 156), # 3: wall
|
| 49 |
+
(190, 153, 153), # 4: fence
|
| 50 |
+
(153, 153, 153), # 5: pole
|
| 51 |
+
(250, 170, 30), # 6: traffic light
|
| 52 |
+
(220, 220, 0), # 7: traffic sign
|
| 53 |
+
(107, 142, 35), # 8: vegetation
|
| 54 |
+
(152, 251, 152), # 9: terrain
|
| 55 |
+
(70, 130, 180), # 10: sky
|
| 56 |
+
(220, 20, 60), # 11: person
|
| 57 |
+
(255, 0, 0), # 12: rider
|
| 58 |
+
(0, 0, 142), # 13: car
|
| 59 |
+
(0, 0, 70), # 14: truck
|
| 60 |
+
(0, 60, 100), # 15: bus
|
| 61 |
+
(0, 80, 100), # 16: train
|
| 62 |
+
(0, 0, 230), # 17: motorcycle
|
| 63 |
+
(119, 11, 32) # 18: bicycle
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
new_mask = torch.zeros((mask.shape[0], mask.shape[1], 3),dtype=torch.uint8)
|
| 67 |
+
new_mask[mask == 255] = (0,0,0)
|
| 68 |
+
for i in range (numClasses):
|
| 69 |
+
new_mask[mask == i] = colors[i][:3]
|
| 70 |
+
return new_mask.permute(2,0,1)
|
| 71 |
|
| 72 |
# %% postprocessing
|
| 73 |
def postprocessing(pred: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 80 |
Returns:
|
| 81 |
torch.Tensor: Processed tensor of shape (3, H, W) for visualization.
|
| 82 |
"""
|
| 83 |
+
return torchvision.transforms.functional.to_pil_image(print_mask(pred.squeeze(0).cpu().to(torch.uint8)))
|
utils2.py
DELETED
|
@@ -1,37 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
|
| 3 |
-
def print_mask(mask:torch.Tensor, numClasses:int=19)->None:
|
| 4 |
-
"""
|
| 5 |
-
Visualizes the segmentation mask by mapping each class to a specific color.
|
| 6 |
-
|
| 7 |
-
Args:
|
| 8 |
-
mask (torch.Tensor): The segmentation mask to visualize.
|
| 9 |
-
numClasses (int, optional): Number of classes in the segmentation mask. Defaults to 19.
|
| 10 |
-
"""
|
| 11 |
-
colors = [
|
| 12 |
-
(128, 64, 128), # 0: road
|
| 13 |
-
(244, 35, 232), # 1: sidewalk
|
| 14 |
-
(70, 70, 70), # 2: building
|
| 15 |
-
(102, 102, 156), # 3: wall
|
| 16 |
-
(190, 153, 153), # 4: fence
|
| 17 |
-
(153, 153, 153), # 5: pole
|
| 18 |
-
(250, 170, 30), # 6: traffic light
|
| 19 |
-
(220, 220, 0), # 7: traffic sign
|
| 20 |
-
(107, 142, 35), # 8: vegetation
|
| 21 |
-
(152, 251, 152), # 9: terrain
|
| 22 |
-
(70, 130, 180), # 10: sky
|
| 23 |
-
(220, 20, 60), # 11: person
|
| 24 |
-
(255, 0, 0), # 12: rider
|
| 25 |
-
(0, 0, 142), # 13: car
|
| 26 |
-
(0, 0, 70), # 14: truck
|
| 27 |
-
(0, 60, 100), # 15: bus
|
| 28 |
-
(0, 80, 100), # 16: train
|
| 29 |
-
(0, 0, 230), # 17: motorcycle
|
| 30 |
-
(119, 11, 32) # 18: bicycle
|
| 31 |
-
]
|
| 32 |
-
|
| 33 |
-
new_mask = torch.zeros((mask.shape[0], mask.shape[1], 3),dtype=torch.uint8)
|
| 34 |
-
new_mask[mask == 255] = (0,0,0)
|
| 35 |
-
for i in range (numClasses):
|
| 36 |
-
new_mask[mask == i] = colors[i][:3]
|
| 37 |
-
return new_mask.permute(2,0,1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|