alrichardbollans committed on
Commit
154ee04
·
1 Parent(s): 3fcc580

Add functionality for downloading segmented images

Browse files
Files changed (2) hide show
  1. app.py +68 -29
  2. styles.css +4 -0
app.py CHANGED
@@ -1,5 +1,7 @@
1
  import base64
 
2
  import tempfile
 
3
  from pathlib import Path
4
 
5
  import cv2
@@ -35,7 +37,7 @@ main_app = ui.page_fluid(
35
  multiple=True,
36
  accept=[".png", ".jpg", ".jpeg"]),
37
 
38
- ui.input_slider("threshold", f"Threshold for Discarding Overlapping Segmentations (Default: {OPTIMAL_NMS_THRESHOLD})",
39
  0, 1.0,
40
  OPTIMAL_NMS_THRESHOLD),
41
 
@@ -88,8 +90,9 @@ main_app = ui.page_fluid(
88
  """ # Style need adding here for slider for some reason
89
  ),
90
  ui.input_action_button("analyse", "Analyse", class_="btn-success"),
 
91
  # ui.input_switch("mask", "Mask", False),
92
- ui.output_ui("download_ui"),
93
  width=300
94
 
95
  ),
@@ -105,7 +108,7 @@ app_ui = ui.page_fluid(
105
  ui.div(
106
  ui.row(
107
  ui.column(5,
108
- ui.panel_title(ui.div("OrchAid", ui.output_image("logo_image", inline=True, width='100px'), class_="navbar-title"))
109
  )
110
  ),
111
  class_="nav-bar"
@@ -153,8 +156,8 @@ app_ui = ui.page_fluid(
153
  " Full details of the model, training process and evaluation can be found on the project ",
154
  ui.a("GitHub repository", href=github_repo_url, target="_blank"),
155
  ". You can find a project overview ", ui.a("here",
156
- href='https://www.kew.org/science/our-science/projects/machine-learning-to-improve-orchid-viability-testing',
157
- target="_blank"), '.'),
158
  class_="body-bar"))
159
  , id='tab'
160
  ),
@@ -189,6 +192,28 @@ def plot_ui():
189
  )
190
 
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  @module.server
193
  def plot_server(input, output, session, r):
194
  @render.plot
@@ -200,21 +225,7 @@ def plot_server(input, output, session, r):
200
  ax.set_axis_off()
201
  # fig.add_axes(ax)
202
 
203
- v = Visualizer(r["image"][:, :, ::-1],
204
- scale=1.2, instance_mode=ColorMode.SEGMENTATION, font_size_scale=1)
205
-
206
- colours = []
207
- for cls in r["instances"].pred_classes:
208
- if cls == 0:
209
- colours.append([1, 0, 0])
210
- elif cls == 1:
211
- colours.append([1, 1, 0])
212
- elif cls == 2:
213
- colours.append([0, 0, 0])
214
-
215
- out = v.overlay_instances(masks=r["instances"].pred_masks.to("cpu"),
216
- assigned_colors=colours,
217
- alpha=input.opacity_slider())
218
 
219
  ax.imshow(cv2.cvtColor(out.get_image()[:, :, ::-1], cv2.COLOR_BGR2RGB))
220
  fig.canvas.draw()
@@ -277,14 +288,14 @@ def server(input, output, session: Session):
277
  # Run prediction with original BGR image
278
  prediction = predictor(im)
279
  print(f"Analyzing image {idx + 1} of {len(files)}")
280
- print(f"NMS threshold: {input.threshold()}")
281
  print(f'Number of instances: {len(prediction["instances"])}')
282
- prediction = apply_nms(prediction, mask=True, cls_agnostic_nms=input.threshold())
283
  print(f'Number of instances after NMS: {len(prediction["instances"])}')
284
 
285
  classes = prediction["instances"].pred_classes.tolist()
286
 
287
- results.append({
288
  "filename": file["name"],
289
  "image_base64": img_base64,
290
  "image": im,
@@ -293,8 +304,10 @@ def server(input, output, session: Session):
293
  "non-viable": classes.count(1),
294
  "empty": classes.count(2),
295
  "total": len(classes),
296
- 'NMS threshold': input.threshold()
297
- })
 
 
298
 
299
  # Update reactive value
300
  analysis_results.set(results)
@@ -335,12 +348,14 @@ def server(input, output, session: Session):
335
  return ui.div(ui_output)
336
 
337
  @render.ui
338
- def download_ui():
339
  if analysis_results.get() and not is_analyzing.get():
340
- return ui.column(4, ui.download_button("download", "Download Results", class_="btn-success"))
 
 
341
 
342
  @render.download()
343
- def download():
344
  results = analysis_results.get()
345
  # if not results:
346
  # None
@@ -355,10 +370,34 @@ def server(input, output, session: Session):
355
  } for r in results])
356
 
357
  # Create in-memory CSV file
358
- with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
 
359
  df.to_csv(tmp.name, index=False)
360
  return tmp.name
361
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
 
363
  app = App(app_ui, server)
364
 
 
1
  import base64
2
+ import os
3
  import tempfile
4
+ import zipfile
5
  from pathlib import Path
6
 
7
  import cv2
 
37
  multiple=True,
38
  accept=[".png", ".jpg", ".jpeg"]),
39
 
40
+ ui.input_slider("nms_threshold", f"Threshold for Discarding Overlapping Segmentations (Default: {OPTIMAL_NMS_THRESHOLD})",
41
  0, 1.0,
42
  OPTIMAL_NMS_THRESHOLD),
43
 
 
90
  """ # Style need adding here for slider for some reason
91
  ),
92
  ui.input_action_button("analyse", "Analyse", class_="btn-success"),
93
+ ui.row(class_="analysis-separator"),
94
  # ui.input_switch("mask", "Mask", False),
95
+ ui.output_ui("download_results_ui"),
96
  width=300
97
 
98
  ),
 
108
  ui.div(
109
  ui.row(
110
  ui.column(5,
111
+ ui.panel_title(ui.div("OrchAId", ui.output_image("logo_image", inline=True, width='100px'), class_="navbar-title"))
112
  )
113
  ),
114
  class_="nav-bar"
 
156
  " Full details of the model, training process and evaluation can be found on the project ",
157
  ui.a("GitHub repository", href=github_repo_url, target="_blank"),
158
  ". You can find a project overview ", ui.a("here",
159
+ href='https://www.kew.org/science/our-science/projects/machine-learning-to-improve-orchid-viability-testing',
160
+ target="_blank"), '.'),
161
  class_="body-bar"))
162
  , id='tab'
163
  ),
 
192
  )
193
 
194
 
195
def get_overlayed_image_from_single_result(r, opacity=0.5, palette=None):
    '''
    Build the segmentation-overlay visualisation for one stored analysis result.

    :param r: result dict holding the original BGR image under ``"image"`` and
        the detectron2 predictions under ``"instances"`` (read here:
        ``pred_classes`` and ``pred_masks``).
    :param opacity: alpha used when blending the instance masks over the image.
    :param palette: optional list of RGB colours (0-1 floats) indexed by
        predicted class id. Defaults to red for class 0, yellow for class 1
        and black for class 2.
    :return: the image object produced by ``Visualizer.overlay_instances``.
    '''
    # The stored image is OpenCV BGR; reverse the channel axis so the
    # Visualizer receives RGB.
    v = Visualizer(r["image"][:, :, ::-1],
                   scale=1.2, instance_mode=ColorMode.SEGMENTATION, font_size_scale=1)

    if palette is None:
        palette = [[1, 0, 0], [1, 1, 0], [0, 0, 0]]

    # One colour per predicted instance, looked up by its class id.
    # int() makes the lookup explicit for tensor class ids; a class id
    # outside the palette range raises IndexError (as before).
    colours = [palette[int(cls)] for cls in r["instances"].pred_classes]

    out = v.overlay_instances(masks=r["instances"].pred_masks.to("cpu"),
                              assigned_colors=colours,
                              alpha=opacity)
    return out
215
+
216
+
217
  @module.server
218
  def plot_server(input, output, session, r):
219
  @render.plot
 
225
  ax.set_axis_off()
226
  # fig.add_axes(ax)
227
 
228
+ out = get_overlayed_image_from_single_result(r, input.opacity_slider())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  ax.imshow(cv2.cvtColor(out.get_image()[:, :, ::-1], cv2.COLOR_BGR2RGB))
231
  fig.canvas.draw()
 
288
  # Run prediction with original BGR image
289
  prediction = predictor(im)
290
  print(f"Analyzing image {idx + 1} of {len(files)}")
291
+ print(f"NMS threshold: {input.nms_threshold()}")
292
  print(f'Number of instances: {len(prediction["instances"])}')
293
+ prediction = apply_nms(prediction, mask=True, cls_agnostic_nms=input.nms_threshold())
294
  print(f'Number of instances after NMS: {len(prediction["instances"])}')
295
 
296
  classes = prediction["instances"].pred_classes.tolist()
297
 
298
+ single_result = {
299
  "filename": file["name"],
300
  "image_base64": img_base64,
301
  "image": im,
 
304
  "non-viable": classes.count(1),
305
  "empty": classes.count(2),
306
  "total": len(classes),
307
+ 'NMS threshold': input.nms_threshold()
308
+ }
309
+ results.append(single_result)
310
+ # print(f'Size of result: {sys.getsizeof(single_result)} bytes')
311
 
312
  # Update reactive value
313
  analysis_results.set(results)
 
348
  return ui.div(ui_output)
349
 
350
@render.ui
def download_results_ui():
    # Only expose the download buttons once an analysis has finished.
    if analysis_results.get() and not is_analyzing.get():
        results_button = ui.download_button(
            "download_results", "Download Results", class_="btn-success")
        images_button = ui.download_button(
            "download_segmented_images", "Download Segmented Images",
            class_="btn-success")
        return results_button, images_button
356
 
357
  @render.download()
358
+ def download_results():
359
  results = analysis_results.get()
360
  # if not results:
361
  # None
 
370
  } for r in results])
371
 
372
  # Create in-memory CSV file
373
+ with tempfile.NamedTemporaryFile(delete=False, delete_on_close=True, suffix=".csv") as tmp:
374
+ print(f'result tmp csv: {tmp.name}')
375
  df.to_csv(tmp.name, index=False)
376
  return tmp.name
377
 
378
@render.download()
def download_segmented_images():
    '''Render every stored result to an image and return a zip of them all.'''
    results = analysis_results.get()

    with tempfile.TemporaryDirectory() as temp_dir:
        # Write each overlayed image into the temp dir.
        # NOTE(review): duplicate upload filenames would overwrite each
        # other here — confirm whether that can happen upstream.
        tmp_img_files = []
        for r in results:
            named_file = os.path.join(temp_dir, r['filename'])
            img = get_overlayed_image_from_single_result(r)
            img.save(named_file)
            tmp_img_files.append(named_file)

        # Zip while the temp dir (and therefore the files) still exists.
        # delete=False: the framework serves the file from the returned
        # path after this function exits.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:
            with zipfile.ZipFile(tmp.name, 'w') as zip_out:
                for file_path in tmp_img_files:
                    # arcname keeps only the basename so the archive
                    # unpacks flat instead of recreating the absolute
                    # temp-directory path.
                    zip_out.write(file_path,
                                  arcname=os.path.basename(file_path),
                                  compress_type=zipfile.ZIP_DEFLATED)

        return tmp.name
400
+
401
 
402
  app = App(app_ui, server)
403
 
styles.css CHANGED
@@ -62,6 +62,10 @@ position: sticky;
62
  border-top: 2px solid #ddd;
63
  background: white;
64
  }
 
 
 
 
65
 
66
  /* Sidebar styling */
67
  .card.shiny-input-container {
 
62
  border-top: 2px solid #ddd;
63
  background: white;
64
  }
65
+ .analysis-separator {
66
+ border-top: 2px solid #ddd;
67
+ background: white;
68
+ }
69
 
70
  /* Sidebar styling */
71
  .card.shiny-input-container {