Nunzio commited on
Commit
2eadc64
·
1 Parent(s): 37570fa

added new features to the Gradio interface and fixed some bugs in the prediction function.

Browse files
Files changed (3) hide show
  1. app.py +5 -6
  2. utils/imageHandling.py +38 -1
  3. utils2.py +0 -37
app.py CHANGED
@@ -39,7 +39,7 @@ def run_prediction(image: gr.Image, selected_model: str)-> tuple[torch.Tensor]:
39
  return (gr.update(value=prediction, visible=True), gr.update(value="", visible=False))
40
 
41
  # Gradio UI
42
- with gr.Blocks(title="🔀 BiSeNet | BiSeNetV2 Predictor") as demo:
43
  gr.Markdown("## Semantic Segmentation with Real-Time Networks")
44
  gr.Markdown('A small user interface created to run semantic segmentation on images using city scapes like predictions and real time segmentation networks.')
45
  gr.Markdown("Upload an image and choose your preferred model for segmentation.")
@@ -54,16 +54,15 @@ with gr.Blocks(title="🔀 BiSeNet | BiSeNetV2 Predictor") as demo:
54
  image_input = gr.Image(type="pil", label="Upload image")
55
  submit_btn = gr.Button("Run prediction")
56
  with gr.Column():
57
- result_display = gr.Image(label="Model prediction")
58
-
59
- error_text = gr.Textbox(label="Message", interactive=False, visible=False)
60
-
61
  submit_btn.click(
62
  fn=run_prediction,
63
  inputs=[image_input, model_selector],
64
  outputs=[result_display, error_text],
65
  )
66
-
67
  gr.Markdown("Made by group 21 semantic segmentation project. ")
68
 
69
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
39
  return (gr.update(value=prediction, visible=True), gr.update(value="", visible=False))
40
 
41
  # Gradio UI
42
+ with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
43
  gr.Markdown("## Semantic Segmentation with Real-Time Networks")
44
  gr.Markdown('A small user interface created to run semantic segmentation on images using city scapes like predictions and real time segmentation networks.')
45
  gr.Markdown("Upload an image and choose your preferred model for segmentation.")
 
54
  image_input = gr.Image(type="pil", label="Upload image")
55
  submit_btn = gr.Button("Run prediction")
56
  with gr.Column():
57
+ result_display = gr.Image(label="Model prediction", visible=False)
58
+ error_text = gr.Markdown("", visible=False)
59
+
 
60
  submit_btn.click(
61
  fn=run_prediction,
62
  inputs=[image_input, model_selector],
63
  outputs=[result_display, error_text],
64
  )
65
+
66
  gr.Markdown("Made by group 21 semantic segmentation project. ")
67
 
68
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
utils/imageHandling.py CHANGED
@@ -31,6 +31,43 @@ def preprocessing(image_tensor: torch.Tensor) -> torch.Tensor:
31
  return torchvision.transforms.functional.normalize(
32
  image_tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
33
  ).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  # %% postprocessing
36
  def postprocessing(pred: torch.Tensor) -> torch.Tensor:
@@ -43,4 +80,4 @@ def postprocessing(pred: torch.Tensor) -> torch.Tensor:
43
  Returns:
44
  torch.Tensor: Processed tensor of shape (3, H, W) for visualization.
45
  """
46
- return torchvision.transforms.functional.to_pil_image(pred.squeeze(0).cpu().to(torch.uint8))
 
31
  return torchvision.transforms.functional.normalize(
32
  image_tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
33
  ).unsqueeze(0)
34
+
35
+ # %% print mask on a sem seg style
36
+ def print_mask(mask:torch.Tensor, numClasses:int=19)->None:
37
+ """
38
+ Visualizes the segmentation mask by mapping each class to a specific color.
39
+
40
+ Args:
41
+ mask (torch.Tensor): The segmentation mask to visualize.
42
+ numClasses (int, optional): Number of classes in the segmentation mask. Defaults to 19.
43
+ """
44
+ colors = [
45
+ (128, 64, 128), # 0: road
46
+ (244, 35, 232), # 1: sidewalk
47
+ (70, 70, 70), # 2: building
48
+ (102, 102, 156), # 3: wall
49
+ (190, 153, 153), # 4: fence
50
+ (153, 153, 153), # 5: pole
51
+ (250, 170, 30), # 6: traffic light
52
+ (220, 220, 0), # 7: traffic sign
53
+ (107, 142, 35), # 8: vegetation
54
+ (152, 251, 152), # 9: terrain
55
+ (70, 130, 180), # 10: sky
56
+ (220, 20, 60), # 11: person
57
+ (255, 0, 0), # 12: rider
58
+ (0, 0, 142), # 13: car
59
+ (0, 0, 70), # 14: truck
60
+ (0, 60, 100), # 15: bus
61
+ (0, 80, 100), # 16: train
62
+ (0, 0, 230), # 17: motorcycle
63
+ (119, 11, 32) # 18: bicycle
64
+ ]
65
+
66
+ new_mask = torch.zeros((mask.shape[0], mask.shape[1], 3),dtype=torch.uint8)
67
+ new_mask[mask == 255] = (0,0,0)
68
+ for i in range (numClasses):
69
+ new_mask[mask == i] = colors[i][:3]
70
+ return new_mask.permute(2,0,1)
71
 
72
  # %% postprocessing
73
  def postprocessing(pred: torch.Tensor) -> torch.Tensor:
 
80
  Returns:
81
  torch.Tensor: Processed tensor of shape (3, H, W) for visualization.
82
  """
83
+ return torchvision.transforms.functional.to_pil_image(print_mask(pred.squeeze(0).cpu().to(torch.uint8)))
utils2.py DELETED
@@ -1,37 +0,0 @@
1
- import torch
2
-
3
- def print_mask(mask:torch.Tensor, numClasses:int=19)->None:
4
- """
5
- Visualizes the segmentation mask by mapping each class to a specific color.
6
-
7
- Args:
8
- mask (torch.Tensor): The segmentation mask to visualize.
9
- numClasses (int, optional): Number of classes in the segmentation mask. Defaults to 19.
10
- """
11
- colors = [
12
- (128, 64, 128), # 0: road
13
- (244, 35, 232), # 1: sidewalk
14
- (70, 70, 70), # 2: building
15
- (102, 102, 156), # 3: wall
16
- (190, 153, 153), # 4: fence
17
- (153, 153, 153), # 5: pole
18
- (250, 170, 30), # 6: traffic light
19
- (220, 220, 0), # 7: traffic sign
20
- (107, 142, 35), # 8: vegetation
21
- (152, 251, 152), # 9: terrain
22
- (70, 130, 180), # 10: sky
23
- (220, 20, 60), # 11: person
24
- (255, 0, 0), # 12: rider
25
- (0, 0, 142), # 13: car
26
- (0, 0, 70), # 14: truck
27
- (0, 60, 100), # 15: bus
28
- (0, 80, 100), # 16: train
29
- (0, 0, 230), # 17: motorcycle
30
- (119, 11, 32) # 18: bicycle
31
- ]
32
-
33
- new_mask = torch.zeros((mask.shape[0], mask.shape[1], 3),dtype=torch.uint8)
34
- new_mask[mask == 255] = (0,0,0)
35
- for i in range (numClasses):
36
- new_mask[mask == i] = colors[i][:3]
37
- return new_mask.permute(2,0,1)