"""Shiny app that runs an instance-segmentation model on uploaded orchid seed images.

The UI lets the user upload images, choose the NMS threshold, run the model and
download the per-image counts (CSV) or the segmented overlays (ZIP).
"""
import base64
import os
import tempfile
import zipfile
from io import BytesIO
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from detectron2.utils.visualizer import Visualizer, ColorMode
from shiny import App, ui, render, reactive, Session, module

from python_utils import load_model, apply_nms, OPTIMAL_NMS_THRESHOLD, MODEL_VERSION, discussion_url, model_page, github_repo_url

# Load data and compute static values
app_dir = Path(__file__).parent
protocol_url = 'https://pgomba.github.io/orchid_protocol/'
# NOTE(review): "Monsterrat" and "Departmento" look like typos for "Montserrat" / "Departamento" -
# left untouched because this is user-facing text; confirm with the project owners.
acknowledgement_text = (
    "The OrchAId TZ viability dataset used to develop the model was created by the Royal Botanic Gardens, Kew, Silo National des "
    "Graines Forestieres, Madagascar, the Ministry of Agriculture, Lands, Housing and Environment, Monsterrat, "
    "Instituto de Investigação Agrária de Moçambique, Mozambique, Departmento de Recursos Naturales y Ambientales, "
    "Puerto Rico & the National Parks Trust of the Virgin Islands.")
disclaimer_text = (
    ui.HTML("DISCLAIMER"),
    ': the evaluation of the model applies to our dataset and there are many factors that may influence performance of the '
    'model on new images.'
    ' We recommend visually inspecting at least a few images to ensure the model is performing as expected on your batch of images.')

# Load the prediction model
predictor = load_model()

main_app = ui.page_fluid(
    ui.div(
        ui.layout_sidebar(
            ui.sidebar(
                ui.input_file("upload", "Upload Images", multiple=True, accept=[".png", ".jpg", ".jpeg"]),
                ui.input_slider("nms_threshold",
                                f"Threshold for Discarding Overlapping Segmentations (Default: {OPTIMAL_NMS_THRESHOLD})",
                                0, 1.0, OPTIMAL_NMS_THRESHOLD),
                ui.tags.style("""
                    .irs.irs--shiny .irs-single { /* square with number */
                        background-color: #357abd;
                        font-size: 1rem;
                    }
                    .irs.irs--shiny .irs-min { /* square with number */
                        font-size: 1rem;
                    }
                    .irs.irs--shiny .irs-max { /* square with number */
                        font-size: 1rem;
                    }
                    .irs-bar.irs-bar--single { /* line */
                        background-color: #357abd;
                    }
                    .irs-handle.single { /* circle */
                        background-color: #357abd;
                    }
                    .irs-handle.single:hover { /* circle */
                        background-color: #2c3e50;
                    }
                    # .irs-handle.single:focus { /* circle */
                    #     outline: 5px solid #ffab00 !important; /* Highly visible gold/orange outline */
                    #     outline-offset: 0px;
                    #     box-shadow: 0 0 0 6px rgba(255, 171, 0, 0.25); /* Soft glow for extra contrast */
                    #     z-index: 2;
                    #     transition: outline-color 0.2s, box-shadow 0.2s;
                    # }
                    #
                    # .irs-handle.single:focus-visible { /* circle */
                    #     outline: 5px solid #ffab00 !important; /* Highly visible gold/orange outline */
                    #     outline-offset: 0px;
                    #     box-shadow: 0 0 0 6px rgba(255, 171, 0, 0.25); /* Soft glow for extra contrast */
                    #     z-index: 2;
                    #     transition: outline-color 0.2s, box-shadow 0.2s;
                    # }
                    #
                    # .irs-handle.single:focus-within { /* circle */
                    #     outline: 5px solid #ffab00 !important; /* Highly visible gold/orange outline */
                    #     outline-offset: 0px;
                    #     box-shadow: 0 0 0 6px rgba(255, 171, 0, 0.25); /* Soft glow for extra contrast */
                    #     z-index: 2;
                    #     transition: outline-color 0.2s, box-shadow 0.2s;
                    # }
                    """  # Style need adding here for slider for some reason
                ),
                ui.input_action_button("analyse", "Analyse", class_="btn-success"),
                # Add script to set 'aria-label' on input, since direct attribute isn't supported
                ui.tags.script("""
                    setTimeout(function() {
                        var fileInput = document.querySelector('input[type=file][id^=upload]');
                        if (fileInput) fileInput.setAttribute('aria-label', 'Upload images');
                    }, 100);
                    """),
                ui.tags.script("""
                    setTimeout(function() {
                        var analyseBtn = document.querySelector('button[id^="analyse"]');
                        if (analyseBtn) analyseBtn.setAttribute('aria-label', 'Analyse uploaded images');
                    }, 100);
                    """),
                ui.row(class_="analysis-separator"),
                # ui.input_switch("mask", "Mask", False),
                ui.output_ui("download_results_ui"),
                width=300
            ),
            ui.output_ui("results_container"),
            border=False,
            border_radius=False
        ),
        class_="side-bar"
    )
)

app_ui = ui.page_fluid(
    # Set charset in head
    ui.tags.meta(charset="utf-8"),
    # Set lang attribute on the document root element
    ui.tags.script("""
        document.documentElement.setAttribute('lang', 'en');
        """),
    ui.include_css("styles.css"),
    ui.div(
        ui.row(
            ui.column(2,
                      ui.panel_title(ui.div(ui.output_image("logo_image", inline=True), class_="navbar-logo")),
                      class_="navbar-col"
                      ),
            ui.column(4,
                      ui.div('A tool to automate the analysis of epiphytic orchid viability tests with machine learning',
                             class_="navbar-text"),
                      class_="navbar-col")
        ),
        class_="nav-bar"
    ),
    ui.navset_tab(
        ui.nav_panel('App',
                     main_app
                     ),
        ui.nav_panel('Instructions',
                     ui.div(
                         # ui.h4("Using this App"),
                         ui.p(
                             "This app uses a computer vision model trained to analyse images of orchid tetrazolium chloride tests to count the number of "
                             "viable, non-viable and empty orchid seeds. "),
                         ui.p(
                             "The app is built for use with ",
                             ui.HTML("specific types of images"),
                             " -- the protocol for taking images compatible with this model is available on ",
                             ui.a("GitHub", href=protocol_url, target="_blank", **{'aria-label': 'Image taking protocols'}),
                             ". The protocol will shortly be available in English, Indonesian, Thai, French, Spanish, Portuguese, Arabic, Mandarin, Malagasy and Japanese."),
                         ui.p(
                             "To use this app, upload images* and click 'Analyse'."
                             " Segmented images will be displayed in the right-hand panel, showing viable seeds in red, non-viable in yellow and empty in black."
                             " An opacity slider can be used to adjust the transparency of the segmentation masks."
                             " The counts will also be displayed as text and results can be downloaded using the 'Download Results' button, providing a data "
                             "table with the filename of each image and the counts of viable, non-viable and empty seeds."),
                         ui.p(
                             f"Before analysing images it is possible to change the threshold used to discard overlapping segmentations produced by the model. "
                             f"The default threshold is {OPTIMAL_NMS_THRESHOLD} as this was found to be optimal for our data, but you can adjust this value in "
                             f"the slider."
                             f" We recommend leaving this as the default, and only decreasing the value if you find that your images have many overlapping seeds "
                             f"and some of them are not being included in the output. Similarly, you can increase this value if your images have very few "
                             f"overlapping seeds and the output includes multiple segmentations of the same seed."),
                         ui.p("Note that ",
                              ui.HTML("the upper limit for the number of detected seeds in a single image is 800"),
                              ' and ',
                              ui.HTML("the app has a maximum capacity of approx. 50 images"),
                              '. When running on a CPU the app takes around 60 seconds to analyse a single image, compared to 2 seconds on a T4 GPU.'),
                         ui.p(" If you have any feedback on the app, please start a discussion on the project ",
                              ui.a("HuggingFace Space", href=discussion_url, target="_blank",
                                   **{'aria-label': 'Project discussion space'}),
                              '.'
                              ),
                         ui.p(disclaimer_text),
                         ui.p("* Images are stored temporarily on HuggingFace servers and deleted at the end of your session."),
                         class_="body-bar"
                     )),
        ui.nav_panel('Model Overview',
                     ui.div(
                         ui.p(
                             " Full details of the model, training process and evaluation can be found on the project ",
                             ui.a("GitHub repository", href=github_repo_url, target="_blank",
                                  **{'aria-label': 'GitHub repository'}),
                             ". You can find a project overview ",
                             ui.a("here",
                                  href='https://www.kew.org/science/our-science/projects/machine-learning-to-improve-orchid-viability-testing',
                                  target="_blank", **{'aria-label': 'Project overview'}),
                             '.'),
                         class_="body-bar"))
        ,
        id='tab'
    ),
    ui.div(
        ui.h4("Acknowledgements"),
        ui.p(
            acknowledgement_text
        ),
        ui.p(
            "The developers acknowledge Research Computing at the James Hutton Institute for providing computational resources and technical "
            "support for the 'UK’s Crop Diversity Bioinformatics HPC' (BBSRC grants BB/S019669/1 and BB/X019683/1), use of which has contributed to "
            "the development of the model used in this app."),
        class_="acknowledgement-bar"
    ),
    ui.div(
        ui.layout_column_wrap(
            ui.output_image("rbg_kew", height='100%', fill=True),
            ui.output_image("bloomberg", height='100%', fill=True),
            ui.output_image("brin", height='100%', fill=True),
            ui.output_image("abg", height='100%', fill=True),
        ),
        class_="footer"
    )
)


# This allows individual opacity sliders.
@module.ui
def plot_ui():
    """Build the UI of one result: the prediction plot followed by its own opacity slider."""
    opacity_slider = ui.input_slider("opacity_slider", "Opacity", 0, 1.0, 0.5)
    return ui.row(
        ui.output_plot("plot_prediction"),
        opacity_slider
    )


def get_overlayed_image_from_single_result(r, opacity=0.5, palette=None):
    """Draw the predicted masks of one stored result on top of its image.

    :param r: result dict holding the image under ``"image"`` (as read by ``cv2.imread``)
        and the model output under ``"instances"``.
    :param opacity: alpha value passed to the visualizer for the masks.
    :param palette: one colour per class index; defaults to red, yellow and black
        for classes 0, 1 and 2.
    :return: the object returned by ``Visualizer.overlay_instances``.
    :raises KeyError: if ``r`` has no ``"image"`` or ``"instances"`` entry
        (this is the case for images that could not be analysed).
    """
    # The stored image comes from cv2.imread; reverse the channel axis before handing it to the visualizer.
    v = Visualizer(r["image"][:, :, ::-1], scale=1.2, instance_mode=ColorMode.SEGMENTATION, font_size_scale=1)
    if palette is None:
        palette = [[1, 0, 0], [1, 1, 0], [0, 0, 0]]
    instances = r["instances"]
    # Convert to plain ints first so that any sequence type works as a palette.
    colours = [palette[cls] for cls in instances.pred_classes.tolist()]
    return v.overlay_instances(masks=instances.pred_masks.to("cpu"), assigned_colors=colours, alpha=opacity)


@module.server
def plot_server(input, output, session, r):
    """Server side of one result module: renders the overlay of result ``r``.

    The plot is re-rendered whenever this module's opacity slider changes.
    """

    @render.plot
    def plot_prediction():
        plt.ioff()
        fig, ax = plt.subplots()
        ax.set_axis_off()
        out = get_overlayed_image_from_single_result(r, input.opacity_slider())
        # The original reversed the channels and then converted BGR->RGB, which
        # cancels out; show the visualizer output directly.
        ax.imshow(out.get_image())
        # Return the figure explicitly instead of relying on the global pyplot state.
        return fig


def server(input, output, session: Session):
    """Shiny server function: logo outputs, image analysis, result display and downloads."""

    @render.image
    def rbg_kew():
        return {"src": "assets/rbg_kew.jpg", 'aria-label': 'Royal Botanic Gardens, Kew logo'}

    @render.image
    def bloomberg():
        return {"src": "assets/bloomberg_philanthropies.jpg", 'aria-label': 'Bloomberg Philanthropies logo'}

    @render.image
    def brin():
        return {"src": "assets/brin.jpg", 'aria-label': 'Badan Riset Dan Inovasi Nasional logo'}

    @render.image
    def abg():
        return {"src": "assets/abg.png", 'aria-label': 'Atlantic Botanical Garden logo'}

    @render.image
    def logo_image():
        return {"src": "assets/logo3.png", "height": "100px", "width": "138px", 'alt': 'OrchAId',
                'aria-label': 'OrchAId logo'}

    analysis_results = reactive.Value([])
    is_analyzing = reactive.Value(False)  # Track if analysis is in progress

    @reactive.Effect
    @reactive.event(input.analyse)
    async def process_images():
        """Run the model on every uploaded file and store one result dict per file.

        Files that cannot be read or converted get a result holding only their filename.
        """
        is_analyzing.set(True)
        try:
            files = input.upload()
            if not files:
                return

            nms_threshold = input.nms_threshold()
            results = []
            for idx, file in enumerate(files):
                # Read image using OpenCV; imread returns None when the file cannot be decoded.
                im = cv2.imread(file["datapath"])
                if im is None:
                    print(f"Error reading image {file['name']}: file could not be decoded")
                    results.append({"filename": file["name"]})
                    continue

                try:
                    # Convert BGR to RGB for display
                    im_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                    pil_img = Image.fromarray(im_rgb)

                    # Convert to base64 for HTML display
                    # NOTE(review): "image_base64" is stored but not read anywhere in this file.
                    buffered = BytesIO()
                    pil_img.save(buffered, format="PNG")
                    img_base64 = base64.b64encode(buffered.getvalue()).decode()

                    # Run prediction with original BGR image
                    prediction = predictor(im)
                    print(f"Analyzing image {idx + 1} of {len(files)}")
                    print(f"NMS threshold: {nms_threshold}")
                    print(f'Number of instances: {len(prediction["instances"])}')
                    prediction = apply_nms(prediction, mask=True, cls_agnostic_nms=nms_threshold)
                    print(f'Number of instances after NMS: {len(prediction["instances"])}')

                    classes = prediction["instances"].pred_classes.tolist()
                    results.append({
                        "filename": file["name"],
                        "image_base64": img_base64,
                        "image": im,
                        **prediction,
                        "viable": classes.count(0),
                        "non-viable": classes.count(1),
                        "empty": classes.count(2),
                        "total": len(classes),
                        'NMS threshold': nms_threshold,
                    })
                except cv2.error as e:
                    print(f"Error reading image {file['name']}: {e}")
                    results.append({"filename": file["name"]})

            # Update reactive value
            analysis_results.set(results)
        finally:
            # Always clear the flag, even when the model raises, so the UI cannot get stuck.
            is_analyzing.set(False)

    @render.ui
    def results_container():
        """Render one card per analysed image: filename, counts and the overlay plot."""
        results = analysis_results.get()
        if not results:
            return ui.div("No results yet. Upload images and click 'Analyse'.", class_="text-muted")
        if is_analyzing.get():
            return ui.div("Analyzing...", class_="text-muted")

        ui_output = []
        for idx, r in enumerate(results):
            if "instances" in r:
                plot_server(f"plot_{idx}", r=r)
                plot_area = ui.row(plot_ui(f"plot_{idx}"))
            else:
                # Failed images carry only a filename; starting the plot module
                # for them would raise KeyError when it tries to draw.
                plot_area = ui.p("This image could not be analysed.", class_="text-muted")

            ui_output.append(
                ui.div(
                    ui.h5(r['filename'], style="margin-top: 15px;"),
                    ui.div(
                        ui.span("Viable ", ui.HTML('()'), f" = {r.get('viable', '? ')}", style="margin: 0 15px;"),
                        ui.span("Non-Viable ", ui.HTML('()'), f" = {r.get('non-viable', '? ')}", style="margin: 0 15px;"),
                        ui.span("Empty ", ui.HTML('()'), f" = {r.get('empty', '? ')}", style="margin: 0 15px;"),
                        ui.span(f"Total = {r.get('total', '? ')}", style="margin: 0 15px;"),
                        class_="results-text"
                    ),
                    plot_area,
                    class_="card p-3"
                )
            )
        return ui.div(ui_output)

    @render.ui
    def download_results_ui():
        """Show the two download buttons once results exist and no analysis is running."""
        if analysis_results.get() and not is_analyzing.get():
            return (
                ui.download_button("download_results", "Download Results", class_="btn-success"),
                ui.download_button("download_segmented_images", "Download Segmented Images", class_="btn-success"),
            )

    @render.download()
    def download_results():
        """Write the per-image counts to a temporary CSV file and return its path."""
        results = analysis_results.get()
        df = pd.DataFrame([{
            "Filename": r["filename"],
            "Viable": r.get("viable", ""),
            "Non-Viable": r.get("non-viable", ""),
            "Empty": r.get("empty", ""),
            "Total": r.get("total", ""),
            'NMS Threshold': r.get('NMS threshold', ''),
            'Model Version': MODEL_VERSION
        } for r in results])

        # Write through the already-open handle; reopening the file by name
        # while it is still open fails on Windows.
        with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".csv",
                                         newline="", encoding="utf-8") as tmp:
            df.to_csv(tmp, index=False)
        return tmp.name

    @render.download()
    def download_segmented_images():
        """Save the overlay of every analysed image into a temporary ZIP file and return its path.

        Results without image data (failed analyses) are skipped.
        """
        results = analysis_results.get()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:
            # The archive is finished inside the temporary directory's scope,
            # so the image files still exist when they are added.
            with tempfile.TemporaryDirectory() as temp_dir, zipfile.ZipFile(tmp, 'w') as archive:
                for r in results:
                    # Keep only the final path component of the uploaded name so it
                    # cannot point outside temp_dir.
                    safe_name = os.path.basename(r['filename'])
                    named_file = os.path.join(temp_dir, safe_name)
                    try:
                        img = get_overlayed_image_from_single_result(r)
                    except KeyError as e:
                        print(f"Error reading image {r['filename']}: {e}")
                        continue
                    img.save(named_file)
                    # arcname keeps the server's temporary directory path out of the archive.
                    archive.write(named_file, arcname=safe_name, compress_type=zipfile.ZIP_DEFLATED)
        return tmp.name


app = App(app_ui, server)

# --------------------------------------------------------
# Reactive calculations and effects
# --------------------------------------------------------