import base64
import os
import tempfile
import zipfile
from io import BytesIO
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from detectron2.utils.visualizer import Visualizer, ColorMode
from shiny import App, ui, render, reactive, Session, module

from python_utils import load_model, apply_nms, OPTIMAL_NMS_THRESHOLD, MODEL_VERSION, discussion_url, model_page, github_repo_url
# Load data and compute static values
app_dir = Path(__file__).parent
protocol_url = 'https://pgomba.github.io/orchid_protocol/'
acknowledgement_text = ("The OrchAId TZ viability dataset used to develop the model was created by the Royal Botanic Gardens, Kew, Silo National des "
                        "Graines Forestieres, Madagascar, the Ministry of Agriculture, Lands, Housing and Environment, Montserrat, "
                        "Instituto de Investigação Agrária de Moçambique, Mozambique, Departamento de Recursos Naturales y Ambientales, "
                        "Puerto Rico & the National Parks Trust of the Virgin Islands.")
disclaimer_text = (
    ui.HTML("<b>DISCLAIMER</b>"),
    ': the evaluation of the model applies to our dataset, and there are many factors that may influence the performance of the '
    'model on new images.'
    ' We recommend visually inspecting at least a few images to ensure the model is performing as expected on your batch of images.')
# Load the prediction model
predictor = load_model()
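# NOTE: `load_model()` and `apply_nms()` live in the local `python_utils` module.
# The code below assumes `predictor` behaves like a detectron2 DefaultPredictor:
# calling `predictor(image_bgr)` returns a dict with an "instances" entry holding
# `pred_classes` and `pred_masks`. A minimal sketch of that assumed contract:
#   prediction = predictor(cv2.imread("seeds.jpg"))   # hypothetical input file
#   prediction["instances"].pred_classes              # per-seed class ids (0/1/2)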
main_app = ui.page_fluid(
    ui.div(
        ui.layout_sidebar(
            ui.sidebar(
                ui.input_file("upload", "Upload Images",
                              multiple=True,
                              accept=[".png", ".jpg", ".jpeg"]),
                ui.input_slider("nms_threshold",
                                f"Threshold for Discarding Overlapping Segmentations (Default: {OPTIMAL_NMS_THRESHOLD})",
                                0, 1.0,
                                OPTIMAL_NMS_THRESHOLD),
                ui.tags.style("""
                    .irs.irs--shiny .irs-single { /* square with number */
                        background-color: #357abd;
                        font-size: 1rem;
                    }
                    .irs.irs--shiny .irs-min { /* square with number */
                        font-size: 1rem;
                    }
                    .irs.irs--shiny .irs-max { /* square with number */
                        font-size: 1rem;
                    }
                    .irs-bar.irs-bar--single { /* line */
                        background-color: #357abd;
                    }
                    .irs-handle.single { /* circle */
                        background-color: #357abd;
                    }
                    .irs-handle.single:hover { /* circle */
                        background-color: #2c3e50;
                    }
| """ # Style need adding here for slider for some reason | |
| ), | |
| ui.input_action_button("analyse", "Analyse", class_="btn-success"), | |
| # Add script to set 'aria-label' on input, since direct attribute isn't supported | |
| ui.tags.script(""" | |
| setTimeout(function() { | |
| var fileInput = document.querySelector('input[type=file][id^=upload]'); | |
| if (fileInput) fileInput.setAttribute('aria-label', 'Upload images'); | |
| }, 100); | |
| """), | |
| ui.tags.script(""" | |
| setTimeout(function() { | |
| var analyseBtn = document.querySelector('button[id^="analyse"]'); | |
| if (analyseBtn) analyseBtn.setAttribute('aria-label', 'Analyse uploaded images'); | |
| }, 100); | |
| """), | |
| ui.row(class_="analysis-separator"), | |
| # ui.input_switch("mask", "Mask", False), | |
| ui.output_ui("download_results_ui"), | |
| width=300 | |
| ), | |
| ui.output_ui("results_container"), | |
| border=False, | |
| border_radius=False | |
| ), class_="side-bar" | |
| ) | |
| ) | |
app_ui = ui.page_fluid(
    # Set charset in head
    ui.tags.meta(charset="utf-8"),
    # Set lang attribute on <html>
    ui.tags.script("""
        document.documentElement.setAttribute('lang', 'en');
    """),
    ui.include_css("styles.css"),
    ui.div(
        ui.row(
            ui.column(2,
                      ui.panel_title(ui.div(ui.output_image("logo_image", inline=True), class_="navbar-logo")),
                      class_="navbar-col"
                      ),
            ui.column(4,
                      ui.div('A tool to automate the analysis of epiphytic orchid viability tests with machine learning',
                             class_="navbar-text"),
                      class_="navbar-col")
        ),
        class_="nav-bar"
    ),
    ui.navset_tab(
        ui.nav_panel('App', main_app),
        ui.nav_panel('Instructions', ui.div(
            # ui.h4("Using this App"),
            ui.p(
                "This app uses a computer vision model trained to analyse images of orchid tetrazolium chloride tests to count the number of "
                "viable, non-viable and empty orchid seeds."),
            ui.p(
                "The app is built for use with ", ui.HTML("<b>specific types of images</b>"),
                " -- the protocol for taking images compatible with this model is available on ",
                ui.a("GitHub", href=protocol_url, target="_blank", **{'aria-label': 'Image taking protocols'}),
                ". The protocol will shortly be available in English, Indonesian, Thai, French, Spanish, Portuguese, Arabic, Mandarin, Malagasy and Japanese."),
            ui.p(
                "To use this app, upload images* and click 'Analyse'."
                " Segmented images will be displayed in the right-hand panel, showing viable seeds in red, non-viable in yellow and empty in black."
                " An opacity slider can be used to adjust the transparency of the segmentation masks."
                " The counts will also be displayed as text, and results can be downloaded using the 'Download Results' button, providing a data "
                "table with the filename of each image and the counts of viable, non-viable and empty seeds."),
            ui.p(
                f"Before analysing images, you can change the threshold used to discard overlapping segmentations produced by the model. "
                f"The default threshold is {OPTIMAL_NMS_THRESHOLD}, as this was found to be optimal for our data, but you can adjust this value "
                f"with the slider."
                f" We recommend leaving this at the default, and only decreasing the value if you find that your images have many overlapping seeds "
                f"and some of them are not being included in the output. Similarly, you can increase this value if your images have very few "
                f"overlapping seeds and the output includes multiple segmentations of the same seed."),
            ui.p("Note that ", ui.HTML("<b>the upper limit for the number of detected seeds in a single image is 800</b>"), ' and ',
                 ui.HTML("<b>the app has a maximum capacity of approx. 50 images</b>"),
                 '. When running on a CPU the app takes around 60 seconds to analyse a single image, compared to 2 seconds on a T4 GPU.'),
            ui.p("If you have any feedback on the app, please start a discussion on the project ",
                 ui.a("HuggingFace Space", href=discussion_url, target="_blank", **{'aria-label': 'Project discussion space'}), '.'
                 ),
            ui.p(disclaimer_text),
            ui.p("* Images are stored temporarily on HuggingFace servers and deleted at the end of your session."),
| class_="body-bar" | |
| )), | |
| ui.nav_panel('Model Overview', | |
| ui.div( | |
| ui.p( | |
| " Full details of the model, training process and evaluation can be found on the project ", | |
| ui.a("GitHub repository", href=github_repo_url, target="_blank", **{'aria-label': 'GitHub repository'}), | |
| ". You can find a project overview ", ui.a("here", | |
| href='https://www.kew.org/science/our-science/projects/machine-learning-to-improve-orchid-viability-testing', | |
| target="_blank", **{'aria-label': 'Project overview'}), '.'), | |
| class_="body-bar")) | |
| , id='tab' | |
| ), | |
| ui.div( | |
| ui.h4("Acknowledgements"), | |
| ui.p( | |
| acknowledgement_text | |
| ), ui.p( | |
| "The developers acknowledge Research Computing at the James Hutton Institute for providing computational resources and technical " | |
| "support for the 'UK’s Crop Diversity Bioinformatics HPC' (BBSRC grants BB/S019669/1 and BB/X019683/1), use of which has contributed to " | |
| "the development of the model used in this app."), | |
| class_="acknowledgement-bar" | |
| ), | |
| ui.div( | |
| ui.layout_column_wrap( | |
| ui.output_image("rbg_kew", height='100%', fill=True), | |
| ui.output_image("bloomberg", height='100%', fill=True), | |
| ui.output_image("brin", height='100%', fill=True), | |
| ui.output_image("abg", height='100%', fill=True), | |
| ), class_="footer" | |
| ) | |
| ) | |
# Shiny module: this allows individual opacity sliders, one per image.
@module.ui
def plot_ui():
    opacity_slider = ui.input_slider("opacity_slider", "Opacity", 0, 1.0, 0.5)
    return ui.row(
        ui.output_plot("plot_prediction"),
        opacity_slider
    )
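# Illustrative pairing (as used in results_container below): each image index
# gets its own module instance, so opacity sliders don't collide across images:
#   plot_server(f"plot_{idx}", r=r)   # server side
#   plot_ui(f"plot_{idx}")            # UI side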
def get_overlayed_image_from_single_result(r, opacity=0.5, palette=None):
    '''
    From the stored result, get the overlayed image.
    :param r: a single stored result dict holding the original image and predicted instances
    :param opacity: alpha value applied to the segmentation masks
    :param palette: list of RGB colours per class; defaults to red (viable), yellow (non-viable), black (empty)
    :return: the visualised image with instance masks overlaid
    '''
    v = Visualizer(r["image"][:, :, ::-1],
                   scale=1.2, instance_mode=ColorMode.SEGMENTATION, font_size_scale=1)
    if palette is None:
        palette = [[1, 0, 0], [1, 1, 0], [0, 0, 0]]  # red, yellow, black
    colours = [palette[cls] for cls in r["instances"].pred_classes]
    out = v.overlay_instances(masks=r["instances"].pred_masks.to("cpu"),
                              assigned_colors=colours,
                              alpha=opacity)
    return out
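# Usage sketch (this is what download_segmented_images does below): the returned
# detectron2 VisImage can be written straight to disk, e.g.
#   vis = get_overlayed_image_from_single_result(r, opacity=0.7)
#   vis.save("overlay.png")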
@module.server
def plot_server(input, output, session, r):
    @render.plot
    def plot_prediction():
        plt.ioff()
        fig, ax = plt.subplots()
        # ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        # fig.add_axes(ax)
        out = get_overlayed_image_from_single_result(r, input.opacity_slider())
        # get_image() already returns RGB, so no colour conversion is needed
        ax.imshow(out.get_image())
        fig.canvas.draw()
        fig.canvas.flush_events()
        return fig
def server(input, output, session: Session):
    @render.image
    def rbg_kew():
        return {"src": "assets/rbg_kew.jpg", 'aria-label': 'Royal Botanic Gardens, Kew logo'}

    @render.image
    def bloomberg():
        return {"src": "assets/bloomberg_philanthropies.jpg", 'aria-label': 'Bloomberg Philanthropies logo'}

    @render.image
    def brin():
        return {"src": "assets/brin.jpg", 'aria-label': 'Badan Riset Dan Inovasi Nasional logo'}

    @render.image
    def abg():
        return {"src": "assets/abg.png", 'aria-label': 'Atlantic Botanical Garden logo'}

    @render.image
    def logo_image():
        return {"src": "assets/logo3.png", "height": "100px", "width": "138px", 'alt': 'OrchAId', 'aria-label': 'OrchAId logo'}
    analysis_results = reactive.Value([])
    is_analyzing = reactive.Value(False)  # Track if analysis is in progress

    @reactive.effect
    @reactive.event(input.analyse)
    async def process_images():
        is_analyzing.set(True)  # Set analyzing flag to True
        files = input.upload()
        if not files:
            is_analyzing.set(False)  # Reset flag if no files
            return
        results = []
        with tempfile.TemporaryDirectory() as temp_dir:
            for idx, file in enumerate(files):
                # Read image using OpenCV
                im = cv2.imread(file["datapath"])
                # Convert BGR to RGB for display
                try:
                    im_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                    pil_img = Image.fromarray(im_rgb)
                    # Convert to base64 for HTML display
                    buffered = BytesIO()
                    pil_img.save(buffered, format="PNG")
                    img_base64 = base64.b64encode(buffered.getvalue()).decode()
                    # Run prediction with the original BGR image
                    prediction = predictor(im)
                    print(f"Analyzing image {idx + 1} of {len(files)}")
                    print(f"NMS threshold: {input.nms_threshold()}")
                    print(f'Number of instances: {len(prediction["instances"])}')
                    prediction = apply_nms(prediction, mask=True, cls_agnostic_nms=input.nms_threshold())
                    print(f'Number of instances after NMS: {len(prediction["instances"])}')
                    classes = prediction["instances"].pred_classes.tolist()
                    single_result = {
                        "filename": file["name"],
                        "image_base64": img_base64,
                        "image": im,
                        **prediction,
                        "viable": classes.count(0),
                        "non-viable": classes.count(1),
                        "empty": classes.count(2),
                        "total": len(classes),
                        'NMS threshold': input.nms_threshold()
                    }
                    results.append(single_result)
                except cv2.error as e:
                    print(f"Error reading image {file['name']}: {e}")
                    single_result = {
                        "filename": file["name"]
                    }
                    results.append(single_result)
                # print(f'Size of result: {sys.getsizeof(single_result)} bytes')
        # Update reactive value
        analysis_results.set(results)
        is_analyzing.set(False)  # Set analyzing flag to False when done
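    # Shape of each stored result, for reference (these are the keys used by the
    # render functions below; failed images carry only "filename"):
    #   {"filename": str, "image_base64": str, "image": <BGR ndarray>,
    #    "instances": <detectron2 Instances>, "viable": int, "non-viable": int,
    #    "empty": int, "total": int, "NMS threshold": float}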
    @render.ui
    def results_container():
        # Check the in-progress flag first, so the first batch shows "Analyzing..."
        # rather than "No results yet"
        if is_analyzing.get():
            return ui.div("Analyzing...",
                          class_="text-muted")
        results = analysis_results.get()
        if not results:
            return ui.div("No results yet. Upload images and click 'Analyse'.",
                          class_="text-muted")
        ui_output = []
        for idx, r in enumerate(results):
            plot_server(f"plot_{idx}", r=r)
            ui_output.append(
                ui.div(
                    ui.h5(r['filename'], style="margin-top: 15px;"),
                    ui.div(
                        ui.span("Viable ",
                                ui.HTML('(<span style="color: rgba(255,0,0,1); font-weight:bold;">■</span>)'),
                                f" = {r.get('viable', '? ')}", style="margin: 0 15px;"),
                        ui.span("Non-Viable ", ui.HTML('(<span style="color: rgba(220,220,0,1); font-weight:bold">■</span>)'),
                                f" = {r.get('non-viable', '? ')}", style="margin: 0 15px;"),
                        ui.span("Empty ", ui.HTML('(<span style="color: rgba(0,0,0,0.5); font-weight:bold">■</span>)'),
                                f" = {r.get('empty', '? ')}", style="margin: 0 15px;"),
                        ui.span(f"Total = {r.get('total', '? ')}", style="margin: 0 15px;"),
                        class_="results-text"
                    ),
                    ui.row(
                        # ui.column(4, ui.img(src=f"data:image/png;base64,{r['image_base64']}")),
                        plot_ui(f"plot_{idx}"),
                    ),
                    class_="card p-3"
                )
            )
        return ui.div(ui_output)
    @render.ui
    def download_results_ui():
        if analysis_results.get() and not is_analyzing.get():
            # results = analysis_results.get()
            # current_nms = input.nms_threshold()
            # print(f'Current NMS threshold: {current_nms}')
            # if results[0].get('NMS threshold') != current_nms:
            #     print('NMS changed')
            # else:
            return ui.TagList(
                ui.download_button("download_results", "Download Results", class_="btn-success"),
                ui.download_button("download_segmented_images", "Download Segmented Images", class_="btn-success")
            )
    @render.download(filename="results.csv")
    def download_results():
        results = analysis_results.get()
        df = pd.DataFrame([{
            "Filename": r["filename"],
            "Viable": r.get("viable", ""),
            "Non-Viable": r.get("non-viable", ""),
            "Empty": r.get("empty", ""),
            "Total": r.get("total", ""),
            'NMS Threshold': r.get('NMS threshold', ''),
            'Model Version': MODEL_VERSION
        } for r in results])
        # Write the results to a temporary CSV file and return its path
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
            df.to_csv(tmp.name, index=False)
            return tmp.name
    @render.download(filename="segmented_images.zip")
    def download_segmented_images():
        results = analysis_results.get()
        tmp_img_files = []
        with tempfile.TemporaryDirectory() as temp_dir:
            for r in results:
                # Save the overlayed image for each successfully analysed result
                named_file = os.path.join(temp_dir, r['filename'])
                try:
                    img = get_overlayed_image_from_single_result(r)
                    img.save(named_file)
                    tmp_img_files.append(named_file)
                except KeyError as e:
                    print(f"Error reading image {r['filename']}: {e}")
            # Zip the images while the temporary directory still exists
            with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:
                with zipfile.ZipFile(tmp.name, 'w') as zipMe:
                    for file in tmp_img_files:
                        # Store each image under its bare filename, not its temp path
                        zipMe.write(file, arcname=os.path.basename(file), compress_type=zipfile.ZIP_DEFLATED)
                return tmp.name
app = App(app_ui, server)