import base64
import os
import tempfile
import zipfile
from pathlib import Path

import cv2
from io import BytesIO

import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from shiny import App, ui, render, reactive, Session, module
from detectron2.utils.visualizer import Visualizer, ColorMode

from python_utils import load_model, apply_nms, OPTIMAL_NMS_THRESHOLD, MODEL_VERSION, discussion_url, model_page, github_repo_url

# Load data and compute static values
app_dir = Path(__file__).parent

protocol_url = 'https://pgomba.github.io/orchid_protocol/'
acknowledgement_text = ("The OrchAId TZ viability dataset used to develop the model was created by the Royal Botanic Gardens, Kew, Silo National des "
                        "Graines Forestieres, Madagascar, the Ministry of Agriculture, Lands, Housing and Environment, Monsterrat, "
                        "Instituto de Investigação Agrária de Moçambique, Mozambique, Departmento de Recursos Naturales y Ambientales, "
                        "Puerto Rico & the National Parks Trust of the Virgin Islands.")
disclaimer_text = (
    ui.HTML("<b>DISCLAIMER</b>"),
    ': the evaluation of the model applies to our dataset, and there are many factors that may influence the performance of the '
    'model on new images.'
    ' We recommend visually inspecting at least a few images to ensure the model is performing as expected on your batch of images.')
# Load the prediction model
predictor = load_model()
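
# Layout for the 'App' tab: a sidebar with the upload, NMS-threshold and analyse controls next to the results panel.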
main_app = ui.page_fluid(

    ui.div(
        ui.layout_sidebar(
            ui.sidebar(
                ui.input_file("upload", "Upload Images",
                              multiple=True,
                              accept=[".png", ".jpg", ".jpeg"]),

                ui.input_slider("nms_threshold", f"Threshold for Discarding Overlapping Segmentations (Default: {OPTIMAL_NMS_THRESHOLD})",
                                0, 1.0,
                                OPTIMAL_NMS_THRESHOLD),

                ui.tags.style(""" 
                    .irs.irs--shiny .irs-single { /* square with number */
                        background-color: #357abd;
                        font-size: 1rem;
                    }
                    .irs.irs--shiny .irs-min { /* square with number */

                        font-size: 1rem;
                    }
                    .irs.irs--shiny .irs-max { /* square with number */
                        font-size: 1rem;
                    }
                    .irs-bar.irs-bar--single { /* line */
                        background-color: #357abd;
                    }
                    .irs-handle.single { /* circle */
                        background-color: #357abd;
                    }
                    .irs-handle.single:hover { /* circle */
                        background-color: #2c3e50;
                    }
                    
                    /* Disabled: high-visibility focus styling for the slider handle.
                    .irs-handle.single:focus,
                    .irs-handle.single:focus-visible,
                    .irs-handle.single:focus-within {
                        outline: 5px solid #ffab00 !important;
                        outline-offset: 0px;
                        box-shadow: 0 0 0 6px rgba(255, 171, 0, 0.25);
                        z-index: 2;
                        transition: outline-color 0.2s, box-shadow 0.2s;
                    }
                    */

                    """  # Style need adding here for slider for some reason
                              ),
                ui.input_action_button("analyse", "Analyse", class_="btn-success"),
                # Add script to set 'aria-label' on input, since direct attribute isn't supported
                ui.tags.script("""
                            setTimeout(function() {
                                var fileInput = document.querySelector('input[type=file][id^=upload]');
                                if (fileInput) fileInput.setAttribute('aria-label', 'Upload images');
                            }, 100);
                        """),
                ui.tags.script("""
                                setTimeout(function() {
                                    var analyseBtn = document.querySelector('button[id^="analyse"]');
                                    if (analyseBtn) analyseBtn.setAttribute('aria-label', 'Analyse uploaded images');
                                }, 100);
                            """),

                ui.row(class_="analysis-separator"),
                # ui.input_switch("mask", "Mask", False),
                ui.output_ui("download_results_ui"),
                width=300

            ),
            ui.output_ui("results_container"),
            border=False,
            border_radius=False
        ), class_="side-bar"
    )
)

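# Top-level page: header bar, tabbed navigation (App / Instructions / Model Overview), acknowledgements and partner logos.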
app_ui = ui.page_fluid(
    # Set charset in head
    ui.tags.meta(charset="utf-8"),
    # Set lang attribute on <html>
    ui.tags.script("""
            document.documentElement.setAttribute('lang', 'en');
        """),

    ui.include_css("styles.css"),

    ui.div(
        ui.row(
            ui.column(2,
                      ui.panel_title(ui.div(ui.output_image("logo_image", inline=True), class_="navbar-logo")),
                      class_="navbar-col"
                      ),
            ui.column(4, ui.div('A tool to automate the analysis of epiphytic orchid viability tests with machine learning', class_="navbar-text"),
                      class_="navbar-col")
        ),
        class_="nav-bar"
    ),
    ui.navset_tab(
        ui.nav_panel('App', main_app
                     ),
        ui.nav_panel('Instructions', ui.div(
            # ui.h4("Using this App"),
            ui.p(
                "This app uses a computer vision model trained to analyse images of orchid tetrazolium chloride tests to count the number of "
                "viable, non-viable and empty orchid seeds. "),
            ui.p(
                "The app is built for use with ", ui.HTML("<b>specific types of images</b>"),
                " -- the protocol for taking images compatible with this model is available on ",
                ui.a("GitHub", href=protocol_url, target="_blank", **{'aria-label': 'Image taking protocols'}),
                ". The protocol will shortly be available in English, Indonesian, Thai, French, Spanish, Portuguese, Arabic, Mandarin, Malagasy and Japanese."),
            ui.p(
                "To use this app, upload images* and click 'Analyse'."
                " Segmented images will be displayed in the right-hand panel, showing viable seeds in red, non-viable in yellow and empty in black."
                " An opacity slider can be used to adjust the transparency of the segmentation masks."
                " The counts will also be displayed as text and results can be downloaded using the 'Download Results' button, providing a data "
                "table with the filename of each image and the counts of viable, non-viable and empty seeds."),

            ui.p(
                f"Before analysing images it is possible to change the threshold used to discard overlapping segmentations produced by the model. "
                f"The default threshold is {OPTIMAL_NMS_THRESHOLD} as this was found to be optimal for our data, but you can adjust this value in "
                f"the slider."
                f" We recommend leaving this as the default, and only decreasing the value if you find that your images have many overlapping seeds "
                f"and some of them are not being included in the output. Similarly, you can increase this value if your images have very few "
                f"overlapping seeds and the output includes multiple segmentations of the same seed."),
            ui.p("Note that ", ui.HTML("<b>the upper limit for the number of detected seeds in a single image is 800</b>"), ' and ',
                 ui.HTML("<b>the app has a maximum capacity of approx. 50 images</b>"),
                 '. When running on a CPU the app takes around 60 seconds to analyse a single image, compared to 2 seconds on a T4 GPU.'),

            ui.p(" If you have any feedback on the app, please start a discussion on the project ",
                 ui.a("HuggingFace Space", href=discussion_url, target="_blank", **{'aria-label': 'Project discussion space'}), '.'
                 ),
            ui.p(disclaimer_text),
            ui.p("* Images are stored temporarily on HuggingFace servers and deleted at the end of your session."),

            class_="body-bar"

        )),
        ui.nav_panel('Model Overview',
                     ui.div(
                         ui.p(

                             " Full details of the model, training process and evaluation can be found on the project ",
                             ui.a("GitHub repository", href=github_repo_url, target="_blank", **{'aria-label': 'GitHub repository'}),
                             ". You can find a project overview ", ui.a("here",
                                                                        href='https://www.kew.org/science/our-science/projects/machine-learning-to-improve-orchid-viability-testing',
                                                                        target="_blank", **{'aria-label': 'Project overview'}), '.'),
                         class_="body-bar"))
        , id='tab'
    ),
    ui.div(
        ui.h4("Acknowledgements"),
        ui.p(
            acknowledgement_text
        ), ui.p(
            "The developers acknowledge Research Computing at the James Hutton Institute for providing computational resources and technical "
            "support for the 'UK’s Crop Diversity Bioinformatics HPC' (BBSRC grants BB/S019669/1 and BB/X019683/1), use of which has contributed to "
            "the development of the model used in this app."),
        class_="acknowledgement-bar"
    ),
    ui.div(
        ui.layout_column_wrap(
            ui.output_image("rbg_kew", height='100%', fill=True),
            ui.output_image("bloomberg", height='100%', fill=True),
            ui.output_image("brin", height='100%', fill=True),
            ui.output_image("abg", height='100%', fill=True),
        ), class_="footer"
    )
)


# This allows individual opacity sliders.
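# plot_ui/plot_server form a Shiny module, so each analysed image gets its own namespaced plot and opacity slider.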
@module.ui
def plot_ui():
    opacity_slider = ui.input_slider("opacity_slider", "Opacity", 0, 1.0, 0.5)
    return ui.row(
        ui.output_plot("plot_prediction"),
        opacity_slider
    )


def get_overlayed_image_from_single_result(r, opacity=0.5, palette=None):
    '''
    From a stored result, build the image with the predicted segmentation masks overlaid.
    :param r: result dict containing the original image and the predicted instances
    :param opacity: alpha value used for the overlaid masks
    :param palette: list of RGB colours for the viable, non-viable and empty classes (defaults to red, yellow, black)
    :return: detectron2 VisImage with the masks drawn on the image
    '''
    v = Visualizer(r["image"][:, :, ::-1],
                   scale=1.2, instance_mode=ColorMode.SEGMENTATION, font_size_scale=1)

    if palette is None:
        palette = [[1, 0, 0], [1, 1, 0], [0, 0, 0]]

    colours = []
    for cls in r["instances"].pred_classes:
        colours.append(palette[cls])
    out = v.overlay_instances(masks=r["instances"].pred_masks.to("cpu"),
                              assigned_colors=colours,
                              alpha=opacity)
    return out


@module.server
def plot_server(input, output, session, r):
    @render.plot
    def plot_prediction():
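        # Redraw the segmentation overlay whenever this image's opacity slider changes.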
        plt.ioff()
        fig, ax = plt.subplots()

        # ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        # fig.add_axes(ax)

        out = get_overlayed_image_from_single_result(r, input.opacity_slider())

        ax.imshow(cv2.cvtColor(out.get_image()[:, :, ::-1], cv2.COLOR_BGR2RGB))
        fig.canvas.draw()
        fig.canvas.flush_events()


def server(input, output, session: Session):
    @render.image
    def rbg_kew():
        img = {"src": "assets/rbg_kew.jpg", 'aria-label': 'Royal Botanic Gardens, Kew logo'}
        return img

    @render.image
    def bloomberg():
        img = {"src": "assets/bloomberg_philanthropies.jpg", 'aria-label': 'Bloomberg Philanthropies logo'}
        return img

    @render.image
    def brin():
        img = {"src": "assets/brin.jpg", 'aria-label': 'Badan Riset Dan Inovasi Nasional logo'}
        return img

    @render.image
    def abg():
        img = {"src": "assets/abg.png", 'aria-label': 'Atlantic Botanical Garden logo'}
        return img

    @render.image
    def logo_image():
        img = {"src": "assets/logo3.png", "height": "100px", "width": "138px", 'alt': 'OrchAId', 'aria-label': 'OrchAId logo'}
        return img

    analysis_results = reactive.Value([])
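    # analysis_results holds one dict per uploaded image: the filename plus, when prediction succeeds, the image, predicted instances and class counts.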
    is_analyzing = reactive.Value(False)  # Track if analysis is in progress

    @reactive.Effect
    @reactive.event(input.analyse)
    async def process_images():
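        # Run the predictor on each uploaded image and collect per-image seed counts.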
        is_analyzing.set(True)  # Set analyzing flag to True
        files = input.upload()
        if not files:
            is_analyzing.set(False)  # Reset flag if no files
            return

        results = []
        with tempfile.TemporaryDirectory() as temp_dir:
            for idx, file in enumerate(files):
                # Read image using OpenCV
                im = cv2.imread(file["datapath"])

                # Convert BGR to RGB for display
                try:
                    im_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                    pil_img = Image.fromarray(im_rgb)

                    # Convert to base64 for HTML display
                    buffered = BytesIO()
                    pil_img.save(buffered, format="PNG")
                    img_base64 = base64.b64encode(buffered.getvalue()).decode()

                    # Run prediction with original BGR image
                    prediction = predictor(im)
                    print(f"Analyzing image {idx + 1} of {len(files)}")
                    print(f"NMS threshold: {input.nms_threshold()}")
                    print(f'Number of instances: {len(prediction["instances"])}')
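                    # apply_nms (from python_utils) discards overlapping segmentations above the user-selected threshold.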
                    prediction = apply_nms(prediction, mask=True, cls_agnostic_nms=input.nms_threshold())
                    print(f'Number of instances after NMS: {len(prediction["instances"])}')

                    classes = prediction["instances"].pred_classes.tolist()

                    single_result = {
                        "filename": file["name"],
                        "image_base64": img_base64,
                        "image": im,
                        **prediction,
                        "viable": classes.count(0),
                        "non-viable": classes.count(1),
                        "empty": classes.count(2),
                        "total": len(classes),
                        'NMS threshold': input.nms_threshold()
                    }
                    results.append(single_result)

                except cv2.error as e:
                    print(f"Error reading image {file['name']}: {e}")
                    single_result = {
                        "filename": file["name"]
                    }
                    results.append(single_result)

                # print(f'Size of result: {sys.getsizeof(single_result)} bytes')

        # Update reactive value
        analysis_results.set(results)
        is_analyzing.set(False)  # Set analyzing flag to False when done

    @render.ui
    def results_container():
        results = analysis_results.get()
        # Report analysis progress before falling back to the empty-state message.
        if is_analyzing.get():
            return ui.div("Analysing...",
                          class_="text-muted")
        if not results:
            return ui.div("No results yet. Upload images and click 'Analyse'.",
                          class_="text-muted")

        ui_output = []
        for idx, r in enumerate(results):
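            # Register a namespaced plot module for this result; plot_ui(f"plot_{idx}") below renders it with its own opacity slider.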
            plot_server(f"plot_{idx}", r=r)

            ui_output.append(
                ui.div(
                    ui.h5(r['filename'], style="margin-top: 15px;"),
                    ui.div(
                        ui.span(f"Viable ",
                                ui.HTML('(<span style="color: rgba(255,0,0,1); font-weight:bold;">&#9632</span>)'),
                                f" = {r.get('viable', '? ')}", style="margin: 0 15px;"),
                        ui.span(f"Non-Viable ", ui.HTML('(<span style="color: rgba(220,220,0,1); font-weight:bold">&#9632</span>)'),
                                f" = {r.get('non-viable', '? ')}", style="margin: 0 15px;"),
                        ui.span(f"Empty ", ui.HTML('(<span style="color: rgba(0,0,0,0.5); font-weight:bold">&#9632</span>)'),
                                f" = {r.get('empty', '? ')}", style="margin: 0 15px;"),
                        ui.span(f"Total = {r.get('total', '? ')}", style="margin: 0 15px;"),
                        class_="results-text"
                    ),
                    ui.row(
                        # ui.column(4, ui.img(src=f"data:image/png;base64,{r['image_base64']}")),
                        plot_ui(f"plot_{idx}"),
                    ),
                    class_="card p-3"
                )
            )

        return ui.div(ui_output)

    @render.ui
    def download_results_ui():
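        # Only show the download buttons once an analysis has finished.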

        if analysis_results.get() and not is_analyzing.get():
            # results = analysis_results.get()
            # current_nms = input.nms_threshold()
            # print(f'Current NMS threshold: {current_nms}')
            # if results[0].get('NMS threshold') != current_nms:
            #     print('NMS changed')
            # else:
            return (
                ui.download_button("download_results", "Download Results", class_="btn-success"),
                ui.download_button("download_segmented_images", "Download Segmented Images", class_="btn-success"),
            )

    @render.download()
    def download_results():
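        # Export the per-image seed counts, NMS threshold and model version as a CSV.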
        results = analysis_results.get()
        # if not results:
        #     None
        df = pd.DataFrame([{
            "Filename": r["filename"],
            "Viable": r.get("viable", ""),
            "Non-Viable": r.get("non-viable", ""),
            "Empty": r.get("empty", ""),
            "Total": r.get("total", ""),
            'NMS Threshold': r.get('NMS threshold', ''),
            'Model Version': MODEL_VERSION
        } for r in results])

        # Write the results to a temporary CSV file; delete=False so Shiny can serve it for download
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
            # print(f'result tmp csv: {tmp.name}')
            df.to_csv(tmp.name, index=False)
            return tmp.name

    @render.download()
    def download_segmented_images():
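        # Build a zip archive of the segmented (mask-overlaid) images for download.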
        results = analysis_results.get()

        tmp_img_files = []

        with tempfile.TemporaryDirectory() as temp_dir:
            # print(os.listdir(os.path.dirname(temp_dir)))
            for r in results:
                # Render and save the overlaid image for each result into the temporary directory before zipping
                named_file = os.path.join(temp_dir, r['filename'])
                try:
                    img = get_overlayed_image_from_single_result(r)
                    img.save(named_file)
                    tmp_img_files.append(named_file)
                except KeyError as e:
                    print(f"Error reading image {r['filename']}: {e}")

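            # Bundle the saved images into a zip; delete=False so the file survives for Shiny to serve as the download.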
            with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:

                with zipfile.ZipFile(tmp.name, 'w') as zipMe:
                    for file in tmp_img_files:
                        # Store only the filename in the archive, not the full temporary path.
                        zipMe.write(file, arcname=os.path.basename(file), compress_type=zipfile.ZIP_DEFLATED)

                return tmp.name


app = App(app_ui, server)
