# OrchAId / app.py
# Last change by alrichardbollans: "Update app.py" (commit 39bb4b0, verified), 7.36 kB.
# import detectron2
# import torch
#
# # Check this in logs
# try:
# print(f"Is CUDA available: {torch.cuda.is_available()}")
# # True
# print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
# except:
# print('Couldnt find CUDA device')
import base64
import os
import tempfile
from io import BytesIO

import cv2
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from shiny import App, ui, render, reactive, Session, module
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import Metadata

from python_utils import load_model, apply_nms
# Load data and compute static values
from shared import app_dir
# Load the prediction model (via python_utils.load_model) once, at import time.
# The resulting module-level `predictor` is shared by every session's server.
predictor = load_model()
# Page layout: a title bar on top, then a sidebar with the upload, threshold,
# analyze and download controls next to the per-image results panel.
app_ui = ui.page_fluid(
    # Resolve the stylesheet against `app_dir` (imported from shared.py;
    # presumably the directory holding this app - confirm) instead of the
    # process's working directory, so styling survives launching from elsewhere.
    # os.path.join accepts both str and pathlib.Path for app_dir.
    ui.include_css(os.path.join(app_dir, "styles.css")),
    ui.div(
        ui.row(
            ui.column(
                6,
                ui.panel_title(
                    ui.div("Orchid TZ Viability Analyzer", class_="navbar-title")
                ),
            ),
        ),
        class_="nav-bar",
    ),
    ui.div(
        ui.layout_sidebar(
            ui.sidebar(
                ui.input_file(
                    "upload",
                    "Upload Images",
                    multiple=True,
                    accept=[".png", ".jpg", ".jpeg"],
                ),
                # Passed to apply_nms() by the server when "Analyze" is clicked.
                ui.input_slider("threshold", "Threshold", 0, 1.0, 0.8),
                ui.input_action_button("analyze", "Analyze", class_="btn-success"),
                ui.column(
                    4,
                    ui.download_button(
                        "download", "Download Results", class_="btn-primary"
                    ),
                ),
                width=300,
            ),
            ui.output_ui("results_container"),
            border=False,
            border_radius=False,
        )
    ),
    # NOTE: a footer row of logo images is currently disabled.
)
@module.ui
def plot_ui():
    """UI half of the per-image plot module.

    Returns a row holding the overlay plot output followed by the slider that
    controls the mask opacity (default 0.5) used by the module's server half.
    """
    return ui.row(
        ui.output_plot("plot_prediction"),
        ui.input_slider("opacity_slider", "Opacity", 0, 1.0, 0.5),
    )
# Normalised RGB colour painted over each predicted class's masks. The class
# ids match the server's counts: 0 = viable, 1 = non-viable, 2 = empty.
_CLASS_COLOURS = {
    0: (1, 0, 0),  # viable: red
    1: (1, 1, 0),  # non-viable: yellow
    2: (0, 0, 0),  # empty: black
}
# Colour for any class id not listed above, so every mask still gets one.
_FALLBACK_COLOUR = (0.5, 0.5, 0.5)


@module.server
def plot_server(input, output, session, r):
    """Server half of the per-image plot module.

    Parameters
    ----------
    r : dict
        One analysis result. Reads ``r["image"]`` (the array returned by
        ``cv2.imread``, i.e. BGR channel order) and ``r["instances"]`` (an
        Instances-like object exposing ``pred_classes`` and ``pred_masks``).
    """

    @render.plot
    def plot_prediction():
        """Draw the image with class-coloured instance masks.

        The mask opacity comes from this module's ``opacity_slider`` input.
        """
        plt.ioff()
        fig, ax = plt.subplots()
        ax.set_axis_off()
        instances = r["instances"]
        # Visualizer expects RGB, so reverse the channel axis of the BGR image.
        v = Visualizer(
            r["image"][:, :, ::-1],
            scale=1.2,
            instance_mode=ColorMode.SEGMENTATION,
            font_size_scale=1,
        )
        # Exactly one colour per instance, in mask order. Unknown class ids
        # fall back to grey instead of shortening the list, which would
        # make overlay_instances index past its end.
        colours = [
            _CLASS_COLOURS.get(cls, _FALLBACK_COLOUR)
            for cls in instances.pred_classes.tolist()
        ]
        out = v.overlay_instances(
            masks=instances.pred_masks.to("cpu"),
            assigned_colors=colours,
            alpha=input.opacity_slider(),
        )
        # VisImage.get_image() is already RGB, so it can be shown directly.
        ax.imshow(out.get_image())
        # Return the figure explicitly rather than relying on pyplot's
        # implicit "current figure".
        return fig
# Result-dict keys for the model's class ids, in id order: the count for class
# id ``i`` is stored under ``_CLASS_LABELS[i]``.
_CLASS_LABELS = ("viable", "non-viable", "empty")

# Column order of the downloadable results CSV.
_CSV_COLUMNS = ["Filename", "Viable", "Non-Viable", "Empty", "Total"]


def _class_counts(classes):
    """Count predicted instances per class.

    Parameters
    ----------
    classes : list[int]
        Predicted class id of every detected instance.

    Returns
    -------
    dict
        ``{"viable": ..., "non-viable": ..., "empty": ..., "total": ...}``,
        where ``total`` is ``len(classes)`` (ids outside 0-2 still count).
    """
    counts = {
        label: classes.count(class_id)
        for class_id, label in enumerate(_CLASS_LABELS)
    }
    counts["total"] = len(classes)
    return counts


def _encode_png_base64(im_bgr):
    """Encode an OpenCV (BGR) image as a base64 PNG string for inline HTML."""
    im_rgb = cv2.cvtColor(im_bgr, cv2.COLOR_BGR2RGB)
    buffered = BytesIO()
    Image.fromarray(im_rgb).save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def _results_to_csv(results):
    """Serialise analysis results to CSV text.

    The header row is always written, so downloading before any analysis
    gives a valid empty table rather than a blank file.
    """
    rows = [
        {
            "Filename": r["filename"],
            "Viable": r.get("viable", ""),
            "Non-Viable": r.get("non-viable", ""),
            "Empty": r.get("empty", ""),
            "Total": r.get("total", ""),
        }
        for r in results
    ]
    return pd.DataFrame(rows, columns=_CSV_COLUMNS).to_csv(index=False)


def server(input, output, session: Session):
    """Wire up image analysis, the results display and the CSV download."""
    # NOTE: the app_ui footer with logo images (rbg_kew, bloomberg, brin, abg)
    # is disabled, so no logo outputs are rendered here.

    # Results of the most recent "Analyze" click; one dict per readable image.
    analysis_results = reactive.Value([])

    @reactive.Effect
    @reactive.event(input.analyze)
    async def process_images():
        """Run the predictor on every uploaded image and publish the results."""
        files = input.upload()
        if not files:
            return
        threshold = input.threshold()
        results = []
        for file in files:
            # The predictor is fed the BGR array exactly as cv2.imread returns it.
            im = cv2.imread(file["datapath"])
            if im is None:
                # imread returns None (rather than raising) for files it cannot decode.
                ui.notification_show(
                    f"Could not read {file['name']} as an image; skipped.",
                    type="warning",
                )
                continue
            prediction = apply_nms(predictor(im), True, threshold)
            classes = prediction["instances"].pred_classes.tolist()
            results.append(
                {
                    "filename": file["name"],
                    "image_base64": _encode_png_base64(im),
                    "image": im,
                    **prediction,
                    **_class_counts(classes),
                }
            )
        analysis_results.set(results)

    @output
    @render.ui
    def results_container():
        """Render one card per analysed image: class counts plus an overlay plot."""
        results = analysis_results.get()
        if not results:
            return ui.div(
                "No results yet. Upload images and click 'Analyze'.",
                class_="text-muted",
            )
        cards = []
        for idx, r in enumerate(results):
            module_id = f"plot_{idx}"
            # (Re)register the plot module's server so its outputs match this result.
            plot_server(module_id, r=r)
            cards.append(
                ui.div(
                    ui.h5(r["filename"], style="margin-top: 15px;"),
                    ui.div(
                        ui.span(f"Viable = {r.get('viable', '? ')}", style="margin: 0 15px;"),
                        ui.span(f"Non-Viable = {r.get('non-viable', '? ')}", style="margin: 0 15px;"),
                        ui.span(f"Empty = {r.get('empty', '? ')}", style="margin: 0 15px;"),
                        ui.span(f"Total = {r.get('total', '? ')}", style="margin: 0 15px;"),
                        class_="results-text",
                    ),
                    ui.row(plot_ui(module_id)),
                    class_="card p-3",
                )
            )
        return ui.div(cards)

    @session.download(filename="orchid_tz_results.csv")
    def download():
        """Stream the current results as a CSV file.

        Yielding the text avoids writing a temporary file to disk on every click.
        """
        yield _results_to_csv(analysis_results.get())
# The Shiny application object: the `app_ui` page layout paired with `server`.
app = App(app_ui, server)
# --------------------------------------------------------
# Reactive calculations and effects
# --------------------------------------------------------