import base64
import os
import tempfile
import zipfile
from pathlib import Path
import cv2
from io import BytesIO
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from shiny import App, ui, render, reactive, Session, module
from detectron2.utils.visualizer import Visualizer, ColorMode
from python_utils import load_model, apply_nms, OPTIMAL_NMS_THRESHOLD, MODEL_VERSION, discussion_url, model_page, github_repo_url
# Load data and compute static values
# Directory containing this file; used to locate static files (styles, logos)
# independently of the current working directory.
app_dir = Path(__file__).parent
protocol_url = 'https://pgomba.github.io/orchid_protocol/'
# Shown under the "Acknowledgements" heading on every page.
acknowledgement_text = (
    "The OrchAId TZ viability dataset used to develop the model was created by the Royal Botanic Gardens, Kew, Silo National des "
    "Graines Forestières, Madagascar, the Ministry of Agriculture, Lands, Housing and Environment, Montserrat, "
    "Instituto de Investigação Agrária de Moçambique, Mozambique, Departamento de Recursos Naturales y Ambientales, "
    "Puerto Rico & the National Parks Trust of the Virgin Islands."
)
# A (tag, text) tuple so it can be passed as the children of a paragraph: the
# word "DISCLAIMER" as raw HTML followed by one implicitly-concatenated string.
disclaimer_text = (
    ui.HTML("DISCLAIMER"),
    ': the evaluation of the model applies to our dataset and there are many factors that may influence performance of the '
    'model on new images.'
    ' We recommend visually inspecting at least a few images to ensure the model is performing as expected on your batch of images.'
)
# Load the prediction model once at import time; this module-level object is
# used by every session's server function.
predictor = load_model()
# Content of the "App" tab: a sidebar holding the upload, NMS-threshold and
# "Analyse" controls plus the download buttons, and a main area where the
# `results_container` output (rendered in `server`) shows per-image results.
main_app = ui.page_fluid(
ui.div(
ui.layout_sidebar(
ui.sidebar(
# Multiple PNG/JPEG files; read by `process_images` via input.upload().
ui.input_file("upload", "Upload Images",
multiple=True,
accept=[".png", ".jpg", ".jpeg"]),
# Threshold passed to apply_nms for each image when "Analyse" is clicked,
# and recorded per image in the downloaded CSV.
ui.input_slider("nms_threshold", f"Threshold for Discarding Overlapping Segmentations (Default: {OPTIMAL_NMS_THRESHOLD})",
0, 1.0,
OPTIMAL_NMS_THRESHOLD),
# Colour/size overrides for the ion.rangeSlider widgets.
# NOTE(review): the lines starting with '#' inside this CSS string are not CSS
# comments (CSS only supports /* ... */); they are presumably discarded by the
# browser as invalid rules - confirm, and consider deleting or converting them.
ui.tags.style("""
.irs.irs--shiny .irs-single { /* square with number */
background-color: #357abd;
font-size: 1rem;
}
.irs.irs--shiny .irs-min { /* square with number */
font-size: 1rem;
}
.irs.irs--shiny .irs-max { /* square with number */
font-size: 1rem;
}
.irs-bar.irs-bar--single { /* line */
background-color: #357abd;
}
.irs-handle.single { /* circle */
background-color: #357abd;
}
.irs-handle.single:hover { /* circle */
background-color: #2c3e50;
}
# .irs-handle.single:focus { /* circle */
# outline: 5px solid #ffab00 !important; /* Highly visible gold/orange outline */
# outline-offset: 0px;
# box-shadow: 0 0 0 6px rgba(255, 171, 0, 0.25); /* Soft glow for extra contrast */
# z-index: 2;
# transition: outline-color 0.2s, box-shadow 0.2s;
# }
#
# .irs-handle.single:focus-visible { /* circle */
# outline: 5px solid #ffab00 !important; /* Highly visible gold/orange outline */
# outline-offset: 0px;
# box-shadow: 0 0 0 6px rgba(255, 171, 0, 0.25); /* Soft glow for extra contrast */
# z-index: 2;
# transition: outline-color 0.2s, box-shadow 0.2s;
# }
#
# .irs-handle.single:focus-within { /* circle */
# outline: 5px solid #ffab00 !important; /* Highly visible gold/orange outline */
# outline-offset: 0px;
# box-shadow: 0 0 0 6px rgba(255, 171, 0, 0.25); /* Soft glow for extra contrast */
# z-index: 2;
# transition: outline-color 0.2s, box-shadow 0.2s;
# }
""" # Style need adding here for slider for some reason
),
# Clicking this triggers the `process_images` effect in `server`.
ui.input_action_button("analyse", "Analyse", class_="btn-success"),
# Add script to set 'aria-label' on input, since direct attribute isn't supported
# Both scripts wait 100 ms before looking the elements up by id prefix.
ui.tags.script("""
setTimeout(function() {
var fileInput = document.querySelector('input[type=file][id^=upload]');
if (fileInput) fileInput.setAttribute('aria-label', 'Upload images');
}, 100);
"""),
ui.tags.script("""
setTimeout(function() {
var analyseBtn = document.querySelector('button[id^="analyse"]');
if (analyseBtn) analyseBtn.setAttribute('aria-label', 'Analyse uploaded images');
}, 100);
"""),
# Empty row used as a separator (presumably styled by styles.css - confirm).
ui.row(class_="analysis-separator"),
# ui.input_switch("mask", "Mask", False),
# Download buttons; `download_results_ui` only renders them once results exist.
ui.output_ui("download_results_ui"),
width=300
),
# Main panel: cards with counts and segmented plots for each analysed image.
ui.output_ui("results_container"),
border=False,
border_radius=False
), class_="side-bar"
)
)
# Top-level page: header bar, the three tabs (App / Instructions / Model
# Overview), the acknowledgements section and the partner-logo footer.
app_ui = ui.page_fluid(
    # Set charset in head
    ui.tags.meta(charset="utf-8"),
    # Set lang="en" on the <html> element via a small inline script.
    ui.tags.script("""
document.documentElement.setAttribute('lang', 'en');
"""),
    # Resolve the stylesheet relative to this file rather than the current
    # working directory, so the app works no matter where it is launched from.
    ui.include_css(app_dir / "styles.css"),
    # Header bar: logo (rendered by the `logo_image` output) and tagline.
    ui.div(
        ui.row(
            ui.column(
                2,
                ui.panel_title(ui.div(ui.output_image("logo_image", inline=True), class_="navbar-logo")),
                class_="navbar-col",
            ),
            ui.column(
                4,
                ui.div('A tool to automate the analysis of epiphytic orchid viability tests with machine learning', class_="navbar-text"),
                class_="navbar-col",
            ),
        ),
        class_="nav-bar",
    ),
    ui.navset_tab(
        ui.nav_panel('App', main_app),
        ui.nav_panel('Instructions', ui.div(
            ui.p(
                "This app uses a computer vision model trained to analyse images of orchid tetrazolium chloride tests to count the number of "
                "viable, non-viable and empty orchid seeds. "),
            ui.p(
                "The app is built for use with ", ui.HTML("specific types of images"),
                " -- the protocol for taking images compatible with this model is available on ",
                ui.a("GitHub", href=protocol_url, target="_blank", **{'aria-label': 'Image taking protocols'}),
                ". The protocol will shortly be available in English, Indonesian, Thai, French, Spanish, Portuguese, Arabic, Mandarin, Malagasy and Japanese."),
            ui.p(
                "To use this app, upload images* and click 'Analyse'."
                " Segmented images will be displayed in the right-hand panel, showing viable seeds in red, non-viable in yellow and empty in black."
                " An opacity slider can be used to adjust the transparency of the segmentation masks."
                " The counts will also be displayed as text and results can be downloaded using the 'Download Results' button, providing a data "
                "table with the filename of each image and the counts of viable, non-viable and empty seeds."),
            ui.p(
                f"Before analysing images it is possible to change the threshold used to discard overlapping segmentations produced by the model. "
                f"The default threshold is {OPTIMAL_NMS_THRESHOLD} as this was found to be optimal for our data, but you can adjust this value in "
                f"the slider."
                f" We recommend leaving this as the default, and only decreasing the value if you find that your images have many overlapping seeds "
                f"and some of them are not being included in the output. Similarly, you can increase this value if your images have very few "
                f"overlapping seeds and the output includes multiple segmentations of the same seed."),
            ui.p("Note that ", ui.HTML("the upper limit for the number of detected seeds in a single image is 800"), ' and ',
                 ui.HTML("the app has a maximum capacity of approx. 50 images"),
                 '. When running on a CPU the app takes around 60 seconds to analyse a single image, compared to 2 seconds on a T4 GPU.'),
            ui.p(" If you have any feedback on the app, please start a discussion on the project ",
                 ui.a("HuggingFace Space", href=discussion_url, target="_blank", **{'aria-label': 'Project discussion space'}), '.'
                 ),
            ui.p(disclaimer_text),
            ui.p("* Images are stored temporarily on HuggingFace servers and deleted at the end of your session."),
            class_="body-bar",
        )),
        ui.nav_panel('Model Overview',
                     ui.div(
                         ui.p(
                             " Full details of the model, training process and evaluation can be found on the project ",
                             ui.a("GitHub repository", href=github_repo_url, target="_blank", **{'aria-label': 'GitHub repository'}),
                             ". You can find a project overview ", ui.a("here",
                                                                        href='https://www.kew.org/science/our-science/projects/machine-learning-to-improve-orchid-viability-testing',
                                                                        target="_blank", **{'aria-label': 'Project overview'}), '.'),
                         class_="body-bar")),
        id='tab',
    ),
    ui.div(
        ui.h4("Acknowledgements"),
        ui.p(
            acknowledgement_text
        ),
        ui.p(
            "The developers acknowledge Research Computing at the James Hutton Institute for providing computational resources and technical "
            "support for the 'UK’s Crop Diversity Bioinformatics HPC' (BBSRC grants BB/S019669/1 and BB/X019683/1), use of which has contributed to "
            "the development of the model used in this app."),
        class_="acknowledgement-bar",
    ),
    # Footer with partner logos; each image is rendered by the same-named output in `server`.
    ui.div(
        ui.layout_column_wrap(
            ui.output_image("rbg_kew", height='100%', fill=True),
            ui.output_image("bloomberg", height='100%', fill=True),
            ui.output_image("brin", height='100%', fill=True),
            ui.output_image("abg", height='100%', fill=True),
        ),
        class_="footer",
    ),
)
# Module UI/server pair so that each analysed image gets its own plot and opacity slider.
@module.ui
def plot_ui():
    """Module UI for one result: its segmentation plot and a private opacity slider.

    Both ids are namespaced by the module id, so every result card gets an
    independent ``opacity_slider`` (read by ``plot_server``) and plot output.
    """
    return ui.row(
        ui.output_plot("plot_prediction"),
        ui.input_slider("opacity_slider", "Opacity", 0, 1.0, 0.5),
    )
def get_overlayed_image_from_single_result(r, opacity=0.5, palette=None):
    """Draw the predicted instance masks of one stored result over its image.

    :param r: a result dict holding the OpenCV ``"image"`` array and the model's
        ``"instances"`` (with ``pred_classes`` and ``pred_masks``).
    :param opacity: mask transparency, forwarded to the visualiser as ``alpha``.
    :param palette: list of colours indexed by predicted class id; defaults to
        red for class 0, yellow for class 1 and black for class 2.
    :return: the object returned by ``Visualizer.overlay_instances`` (callers use
        its ``get_image()`` and ``save()`` methods).
    """
    # Stored images come straight from cv2.imread (BGR); reversing the channel
    # axis hands the visualiser an RGB array.
    visualiser = Visualizer(
        r["image"][:, :, ::-1],
        scale=1.2,
        instance_mode=ColorMode.SEGMENTATION,
        font_size_scale=1,
    )
    colour_table = palette if palette is not None else [[1, 0, 0], [1, 1, 0], [0, 0, 0]]
    instances = r["instances"]
    # One colour per detected instance, chosen by its predicted class.
    instance_colours = [colour_table[cls] for cls in instances.pred_classes]
    return visualiser.overlay_instances(
        masks=instances.pred_masks.to("cpu"),
        assigned_colors=instance_colours,
        alpha=opacity,
    )
@module.server
def plot_server(input, output, session, r):
    """Module server rendering one result's segmentation plot.

    :param r: a single result dict built by ``server``; it either holds the image
        and model outputs, or only ``"filename"`` when the file could not be read.
    """
    @render.plot
    def plot_prediction():
        plt.ioff()
        fig, ax = plt.subplots()
        ax.set_axis_off()
        # Results for unreadable files carry no image/instances; show a message
        # instead of letting the render fail with a KeyError.
        if "image" not in r or "instances" not in r:
            ax.text(0.5, 0.5, "Image could not be analysed",
                    ha="center", va="center", transform=ax.transAxes)
            return fig
        out = get_overlayed_image_from_single_result(r, input.opacity_slider())
        # The visualiser output is already RGB, which is what imshow expects; the
        # former RGB->BGR slice followed by cv2 BGR->RGB was a no-op round trip.
        ax.imshow(out.get_image())
        # Return the figure explicitly instead of relying on the current pyplot figure.
        return fig
def server(input, output, session: Session):
    """Per-session server logic for the OrchAId app.

    Registers the logo/footer image outputs, the "Analyse" effect that runs the
    model on every uploaded file, the results panel and two download handlers.
    ``analysis_results`` holds one dict per uploaded file (filename, OpenCV image,
    model outputs and per-class counts; or only the filename when the file could
    not be analysed). ``is_analyzing`` is True while an analysis run is in progress.
    """

    def _asset_path(name):
        # Resolve logos relative to the app file rather than the current working
        # directory, so the app also works when launched from elsewhere.
        return str(app_dir / "assets" / name)

    @render.image
    def rbg_kew():
        return {"src": _asset_path("rbg_kew.jpg"), 'aria-label': 'Royal Botanic Gardens, Kew logo'}

    @render.image
    def bloomberg():
        return {"src": _asset_path("bloomberg_philanthropies.jpg"), 'aria-label': 'Bloomberg Philanthropies logo'}

    @render.image
    def brin():
        return {"src": _asset_path("brin.jpg"), 'aria-label': 'Badan Riset Dan Inovasi Nasional logo'}

    @render.image
    def abg():
        return {"src": _asset_path("abg.png"), 'aria-label': 'Atlantic Botanical Garden logo'}

    @render.image
    def logo_image():
        return {"src": _asset_path("logo3.png"), "height": "100px", "width": "138px",
                'alt': 'OrchAId', 'aria-label': 'OrchAId logo'}

    analysis_results = reactive.Value([])
    is_analyzing = reactive.Value(False)  # Track if analysis is in progress

    def _failed_result(name):
        # Placeholder for a file that could not be analysed: the results panel
        # shows '?' counts and the CSV leaves the count columns empty.
        return {"filename": name}

    def _analyse_file(file, idx, n_files, nms_threshold):
        """Run the model plus NMS on one uploaded file and return its result dict."""
        im = cv2.imread(file["datapath"])
        # cv2.imread returns None (rather than raising) when a file cannot be decoded.
        if im is None:
            print(f"Error reading image {file['name']}: file could not be decoded")
            return _failed_result(file["name"])
        try:
            print(f"Analyzing image {idx + 1} of {n_files}")
            print(f"NMS threshold: {nms_threshold}")
            # Run prediction with the original BGR image.
            prediction = predictor(im)
            print(f'Number of instances: {len(prediction["instances"])}')
            # NOTE(review): the threshold is passed as `cls_agnostic_nms`; confirm
            # this matches apply_nms's signature in python_utils.
            prediction = apply_nms(prediction, mask=True, cls_agnostic_nms=nms_threshold)
            print(f'Number of instances after NMS: {len(prediction["instances"])}')
        except cv2.error as e:
            print(f"Error reading image {file['name']}: {e}")
            return _failed_result(file["name"])
        classes = prediction["instances"].pred_classes.tolist()
        return {
            "filename": file["name"],
            "image": im,
            **prediction,
            "viable": classes.count(0),
            "non-viable": classes.count(1),
            "empty": classes.count(2),
            "total": len(classes),
            'NMS threshold': nms_threshold,
        }

    @reactive.Effect
    @reactive.event(input.analyse)
    async def process_images():
        files = input.upload()
        if not files:
            return
        # Read the slider once so every image in this run uses the same threshold.
        nms_threshold = input.nms_threshold()
        is_analyzing.set(True)
        try:
            results = [
                _analyse_file(file, idx, len(files), nms_threshold)
                for idx, file in enumerate(files)
            ]
            analysis_results.set(results)
        finally:
            # Always clear the flag, even if the model raises, so the results panel
            # and download buttons do not stay stuck in the "analysing" state.
            is_analyzing.set(False)

    def _count_span(label, value):
        # One "<label> (■) = <value>" entry of the per-image counts line.
        return ui.span(label, ui.HTML('(■)'), f" = {value}", style="margin: 0 15px;")

    @render.ui
    def results_container():
        # The in-progress message takes precedence over (possibly stale) results.
        if is_analyzing.get():
            return ui.div("Analyzing...",
                          class_="text-muted")
        results = analysis_results.get()
        if not results:
            return ui.div("No results yet. Upload images and click 'Analyse'.",
                          class_="text-muted")
        cards = []
        for idx, r in enumerate(results):
            # Register the module server backing this card's plot and opacity slider.
            plot_server(f"plot_{idx}", r=r)
            cards.append(
                ui.div(
                    ui.h5(r['filename'], style="margin-top: 15px;"),
                    ui.div(
                        _count_span("Viable ", r.get('viable', '? ')),
                        _count_span("Non-Viable ", r.get('non-viable', '? ')),
                        _count_span("Empty ", r.get('empty', '? ')),
                        ui.span(f"Total = {r.get('total', '? ')}", style="margin: 0 15px;"),
                        class_="results-text",
                    ),
                    ui.row(
                        plot_ui(f"plot_{idx}"),
                    ),
                    class_="card p-3",
                )
            )
        return ui.div(cards)

    @render.ui
    def download_results_ui():
        # Only offer downloads once a finished analysis has produced results.
        if analysis_results.get() and not is_analyzing.get():
            return (
                ui.download_button("download_results", "Download Results", class_="btn-success"),
                ui.download_button("download_segmented_images",
                                   "Download Segmented Images",
                                   class_="btn-success"),
            )

    @render.download()
    def download_results():
        """Write one CSV row of counts per image to a temporary file; return its path."""
        results = analysis_results.get()
        df = pd.DataFrame([{
            "Filename": r["filename"],
            "Viable": r.get("viable", ""),
            "Non-Viable": r.get("non-viable", ""),
            "Empty": r.get("empty", ""),
            "Total": r.get("total", ""),
            'NMS Threshold': r.get('NMS threshold', ''),
            'Model Version': MODEL_VERSION
        } for r in results])
        # delete=False: the file must outlive this function so it can be served.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
            df.to_csv(tmp.name, index=False)
        return tmp.name

    @render.download()
    def download_segmented_images():
        """Zip the segmentation overlays of all analysed images; return the zip path."""
        results = analysis_results.get()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:
            zip_path = tmp.name
        # Build the archive while the temporary image directory still exists.
        with tempfile.TemporaryDirectory() as temp_dir, zipfile.ZipFile(zip_path, 'w') as zip_file:
            for idx, r in enumerate(results):
                # The upload name is client-supplied: keep only its final component
                # so it cannot escape temp_dir.
                safe_name = os.path.basename(r['filename']) or f"image_{idx}.png"
                named_file = os.path.join(temp_dir, safe_name)
                try:
                    img = get_overlayed_image_from_single_result(r)
                    img.save(named_file)
                except KeyError as e:
                    # Results for unreadable files have no image/instances to draw.
                    print(f"Error reading image {r['filename']}: {e}")
                    continue
                # Store entries by file name only, not by their temp-directory path.
                zip_file.write(named_file, arcname=safe_name, compress_type=zipfile.ZIP_DEFLATED)
        return zip_path
# Assemble the Shiny application from the page definition and the per-session server.
app = App(ui=app_ui, server=server)
# --------------------------------------------------------
# Reactive calculations and effects
# --------------------------------------------------------