alrichardbollans commited on
Commit
ae2089c
·
1 Parent(s): 5fe3fb1

Improve UI explanations, fix NMS slider default, and add model version to results

Browse files
Files changed (2) hide show
  1. app.py +31 -17
  2. python_utils/get_model.py +10 -6
app.py CHANGED
@@ -21,15 +21,13 @@ import matplotlib.pyplot as plt
21
  from PIL import Image
22
  from shiny import App, ui, render, reactive, Session, module
23
  from detectron2.utils.visualizer import Visualizer, ColorMode
24
- from detectron2.data import Metadata
25
 
26
- from python_utils import load_model, apply_nms
27
 
28
  # Load data and compute static values
29
  app_dir = Path(__file__).parent
30
 
31
  protocol_url = 'https://pgomba.github.io/orchid_protocol/'
32
- discussion_url = 'https://huggingface.co/spaces/TZProject/TZSeedApp/discussions'
33
  acknowledgement_text = "The OrchAId TZ viability dataset used to develop the model was created by the Royal Botanic Gardens, Kew, Silo National des Graines Forestieres, Madagascar, the Ministry of Agriculture, Lands, Housing and Environment, Monsterrat, Instituto de Investigação Agrária de Moçambique, Mozambique, Departmento de Recursos Naturales y Ambientales, Puerto Rico & the National Parks Trust of the Virgin Islands."
34
 
35
  # Load the prediction model
@@ -43,8 +41,8 @@ main_app = ui.page_fluid(
43
  multiple=True,
44
  accept=[".png", ".jpg", ".jpeg"]),
45
 
46
- ui.input_slider("threshold", "Threshold for Discarding Overlapping Segmentations"
47
- " (ADD a more descriptive label, and explain in text above).", 0, 1.0, 0.7),
48
 
49
  ui.tags.style("""
50
  .irs.irs--shiny .irs-single { /* square with number */
@@ -119,23 +117,35 @@ app_ui = ui.page_fluid(
119
  ),
120
 
121
  ui.div(
122
- ui.h4("Using this App"),
123
  ui.p(
124
- "This app is built to use a computer vision model to analyse images of orchid TZ tests and count the number of viable, non-viable and empty seeds."
125
- " To use this app, upload images and click 'Analyse'."
126
- " Segmented images will be displayed in the right-hand panel and results can be downloaded."),
 
 
 
 
 
 
 
127
  ui.p(
128
- "This app is built for use with specific types of images -- the protocol for taking images compatible with this model is available on ",
 
129
  ui.a("GitHub", href=protocol_url, target="_blank"),
130
  ". The protocol is available in English, Indonesian, Thai, French, Spanish, Portuguese, Arabic, Mandarin, Malagasy and Japanese."),
131
  ui.p(
132
- "NMS set to 0.7 as this was found to be optimal for our data, but you can adjust this value in the slider and lower values may be useful for images"
133
- "with more overlapping seeds."),
134
- ui.p(" If you have any feedback on the app, please start a discussion on ",
135
- ui.a("the HuggingFace space", href=discussion_url, target="_blank")
 
 
 
 
 
136
  ),
137
- ui.h4("Model Details and Accuracy"),
138
- ui.p(),
139
  class_="body-bar"
140
  ),
141
  main_app,
@@ -260,8 +270,11 @@ def server(input, output, session: Session):
260
 
261
  # Run prediction with original BGR image
262
  prediction = predictor(im)
263
-
 
 
264
  prediction = apply_nms(prediction, mask=True, cls_agnostic_nms=input.threshold())
 
265
 
266
  classes = prediction["instances"].pred_classes.tolist()
267
 
@@ -332,6 +345,7 @@ def server(input, output, session: Session):
332
  "Empty": r.get("empty", ""),
333
  "Total": r.get("total", ""),
334
  'NMS Threshold': r.get('NMS threshold', ''),
 
335
  } for r in results])
336
 
337
  # Create in-memory CSV file
 
21
  from PIL import Image
22
  from shiny import App, ui, render, reactive, Session, module
23
  from detectron2.utils.visualizer import Visualizer, ColorMode
 
24
 
25
+ from python_utils import load_model, apply_nms, OPTIMAL_NMS_THRESHOLD, MODEL_VERSION, discussion_url, model_page, github_repo_url
26
 
27
  # Load data and compute static values
28
  app_dir = Path(__file__).parent
29
 
30
  protocol_url = 'https://pgomba.github.io/orchid_protocol/'
 
31
  acknowledgement_text = "The OrchAId TZ viability dataset used to develop the model was created by the Royal Botanic Gardens, Kew, Silo National des Graines Forestieres, Madagascar, the Ministry of Agriculture, Lands, Housing and Environment, Monsterrat, Instituto de Investigação Agrária de Moçambique, Mozambique, Departmento de Recursos Naturales y Ambientales, Puerto Rico & the National Parks Trust of the Virgin Islands."
32
 
33
  # Load the prediction model
 
41
  multiple=True,
42
  accept=[".png", ".jpg", ".jpeg"]),
43
 
44
+ ui.input_slider("threshold", f"Threshold for Discarding Overlapping Segmentations (Default: {OPTIMAL_NMS_THRESHOLD})", 0, 1.0,
45
+ OPTIMAL_NMS_THRESHOLD),
46
 
47
  ui.tags.style("""
48
  .irs.irs--shiny .irs-single { /* square with number */
 
117
  ),
118
 
119
  ui.div(
120
+ ui.h4("Model Overview"),
121
  ui.p(
122
+ "This app uses a computer vision model trained to analyse images of orchid tetrazolium chloride tests to count the number of viable, non-viable and empty seeds."
123
+ " Full details of the model, training process and evaluation can be found on the project ",
124
+ ui.a("GitHub repository", href=github_repo_url, target="_blank"),
125
+ "."),
126
+
127
+ ui.h5('Performance'),
128
+ ui.p('To briefly summarise model performance on our test dataset, '),
129
+ ui.p('Disclaimer: the evaluation of the model applies to our dataset and there are many factors that may influence performance of the model on new images.'
130
+ ' We recommend visually inspecting at least a few images to ensure the model is performing as expected on your batch of images.'),
131
+ ui.h4("Using this App"),
132
  ui.p(
133
+ "This app is built for use with ", ui.HTML("<b>specific types of images</b>"),
134
+ " -- the protocol for taking images compatible with this model is available on ",
135
  ui.a("GitHub", href=protocol_url, target="_blank"),
136
  ". The protocol is available in English, Indonesian, Thai, French, Spanish, Portuguese, Arabic, Mandarin, Malagasy and Japanese."),
137
  ui.p(
138
+ "To use this app, upload images and click 'Analyse'."
139
+ " Segmented images will be displayed in the right-hand panel, showing viable seeds in red, non-viable in yellow and empty in black. An opacity slider can be used to adjust the transparency of the segmentation masks."
140
+ " The counts will also be displayed as text and results can be downloaded using the 'Download Results' button."),
141
+
142
+ ui.p(
143
+ f"Before analysing images it is possible to change the threshold used to discard overlapping segmentations produced by the model. The default threshold is {OPTIMAL_NMS_THRESHOLD} as this was found to be optimal for our data, but you can adjust this value in the slider."
144
+ f" We recommend leaving this as the default, and only decreasing the value if you find that your images have many overlapping seeds and some of them are not being included in the output. Similarly, you can increase this value if your images have very few overlapping seeds and the output includes multiple segmentations of the same seed."),
145
+ ui.p(" If you have any feedback on the app, please start a discussion on the project ",
146
+ ui.a("HuggingFace space", href=discussion_url, target="_blank"), '.'
147
  ),
148
+
 
149
  class_="body-bar"
150
  ),
151
  main_app,
 
270
 
271
  # Run prediction with original BGR image
272
  prediction = predictor(im)
273
+ print(f"Analyzing image {idx + 1} of {len(files)}")
274
+ print(f"NMS threshold: {input.threshold()}")
275
+ print(f'Number of instances: {len(prediction["instances"])}')
276
  prediction = apply_nms(prediction, mask=True, cls_agnostic_nms=input.threshold())
277
+ print(f'Number of instances after NMS: {len(prediction["instances"])}')
278
 
279
  classes = prediction["instances"].pred_classes.tolist()
280
 
 
345
  "Empty": r.get("empty", ""),
346
  "Total": r.get("total", ""),
347
  'NMS Threshold': r.get('NMS threshold', ''),
348
+ 'Model Version': MODEL_VERSION
349
  } for r in results])
350
 
351
  # Create in-memory CSV file
python_utils/get_model.py CHANGED
@@ -4,7 +4,7 @@ import tempfile
4
  ## Urls and model variables that might change.
5
  OPTIMAL_NMS_THRESHOLD = 0.7
6
  model_page = "https://huggingface.co/TZProject/final_tz_segmentor"
7
- _model_config_url = model_page+"/resolve/main/final_model_config.yaml"
8
  MODEL_VERSION = "v1.0"
9
  discussion_url = 'https://huggingface.co/spaces/TZProject/TZSeedApp/discussions'
10
  github_repo_url = 'https://github.com/JATamura/TZSegmenting'
@@ -19,7 +19,8 @@ def get_set_up():
19
 
20
  # print("detectron2:", detectron2.__version__)
21
 
22
- def load_model(using_final_model:bool=True):
 
23
  """
24
  Load and configure a Detectron2 model predictor. The method creates a configuration
25
  object, merges it with a specified configuration file fetched from a remote URL,
@@ -57,11 +58,12 @@ def load_model(using_final_model:bool=True):
57
  ## but these are used during training and not inference and shouldn't affect the model performance
58
  ## code below
59
  cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = False
60
-
61
  predictor = DefaultPredictor(cfg)
62
 
63
  return predictor
64
 
 
65
  def mask_nms(masks, scores, nms_threshold=OPTIMAL_NMS_THRESHOLD):
66
  """
67
  Runs class agnostic NMS on masks/segmentations instead of the bounding boxes.
@@ -72,7 +74,7 @@ def mask_nms(masks, scores, nms_threshold=OPTIMAL_NMS_THRESHOLD):
72
  """
73
  import supervision as sv
74
  from shapely.geometry.polygon import Polygon
75
-
76
  polygons = []
77
  for mask in masks:
78
  contour = sv.mask_to_polygons(mask)
@@ -96,6 +98,7 @@ def mask_nms(masks, scores, nms_threshold=OPTIMAL_NMS_THRESHOLD):
96
  order.remove(j)
97
  return masks_kept
98
 
 
99
  def apply_nms(prediction, mask=False, cls_agnostic_nms=OPTIMAL_NMS_THRESHOLD):
100
  """
101
  Applies Non-Maximum Suppression (NMS) to filter redundant bounding boxes
@@ -118,7 +121,7 @@ def apply_nms(prediction, mask=False, cls_agnostic_nms=OPTIMAL_NMS_THRESHOLD):
118
  Defaults to ``"""
119
  from torchvision.ops import nms
120
  from detectron2.structures import Instances
121
-
122
  if mask:
123
  nms_indices = mask_nms(prediction["instances"].pred_masks.numpy(),
124
  prediction["instances"]._fields["scores"], cls_agnostic_nms)
@@ -134,6 +137,7 @@ def apply_nms(prediction, mask=False, cls_agnostic_nms=OPTIMAL_NMS_THRESHOLD):
134
 
135
  return pred
136
 
 
137
  if __name__ == '__main__':
138
  # get_set_up()
139
- load_model()
 
4
  ## Urls and model variables that might change.
5
  OPTIMAL_NMS_THRESHOLD = 0.7
6
  model_page = "https://huggingface.co/TZProject/final_tz_segmentor"
7
+ _model_config_url = model_page + "/resolve/main/final_model_config.yaml"
8
  MODEL_VERSION = "v1.0"
9
  discussion_url = 'https://huggingface.co/spaces/TZProject/TZSeedApp/discussions'
10
  github_repo_url = 'https://github.com/JATamura/TZSegmenting'
 
19
 
20
  # print("detectron2:", detectron2.__version__)
21
 
22
+
23
+ def load_model(using_final_model: bool = True):
24
  """
25
  Load and configure a Detectron2 model predictor. The method creates a configuration
26
  object, merges it with a specified configuration file fetched from a remote URL,
 
58
  ## but these are used during training and not inference and shouldn't affect the model performance
59
  ## code below
60
  cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = False
61
+
62
  predictor = DefaultPredictor(cfg)
63
 
64
  return predictor
65
 
66
+
67
  def mask_nms(masks, scores, nms_threshold=OPTIMAL_NMS_THRESHOLD):
68
  """
69
  Runs class agnostic NMS on masks/segmentations instead of the bounding boxes.
 
74
  """
75
  import supervision as sv
76
  from shapely.geometry.polygon import Polygon
77
+
78
  polygons = []
79
  for mask in masks:
80
  contour = sv.mask_to_polygons(mask)
 
98
  order.remove(j)
99
  return masks_kept
100
 
101
+
102
  def apply_nms(prediction, mask=False, cls_agnostic_nms=OPTIMAL_NMS_THRESHOLD):
103
  """
104
  Applies Non-Maximum Suppression (NMS) to filter redundant bounding boxes
 
121
  Defaults to ``"""
122
  from torchvision.ops import nms
123
  from detectron2.structures import Instances
124
+ print(f'applying nms with threshold {cls_agnostic_nms} and mask {mask}... \n')
125
  if mask:
126
  nms_indices = mask_nms(prediction["instances"].pred_masks.numpy(),
127
  prediction["instances"]._fields["scores"], cls_agnostic_nms)
 
137
 
138
  return pred
139
 
140
+
141
  if __name__ == '__main__':
142
  # get_set_up()
143
+ load_model()