Spaces:
Sleeping
Sleeping
| # import detectron2 | |
| # import torch | |
| # | |
| # # Check this in logs | |
| # try: | |
| # print(f"Is CUDA available: {torch.cuda.is_available()}") | |
| # # True | |
| # print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") | |
| # except: | |
| # print('Couldnt find CUDA device') | |
| import base64 | |
| import tempfile | |
| import cv2 | |
| from io import BytesIO | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| from shiny import App, ui, render, reactive, Session, module | |
| from detectron2.utils.visualizer import Visualizer, ColorMode | |
| from detectron2.data import Metadata | |
| from python_utils import load_model, apply_nms | |
| # Load data and compute static values | |
| from shared import app_dir | |
# Build the predictor once at import time; `server` calls it on every uploaded image.
predictor = load_model()

# Static page layout: a title bar on top, then a sidebar with the inputs next
# to the dynamically rendered results area.
app_ui = ui.page_fluid(
    ui.include_css("styles.css"),
    # Title bar.
    ui.div(
        ui.row(
            ui.column(
                6,
                ui.panel_title(
                    ui.div("Orchid TZ Viability Analyzer", class_="navbar-title")
                ),
            ),
        ),
        class_="nav-bar",
    ),
    # Main area: inputs in the sidebar, per-image results in the body.
    ui.div(
        ui.layout_sidebar(
            ui.sidebar(
                ui.input_file(
                    "upload",
                    "Upload Images",
                    multiple=True,
                    accept=[".png", ".jpg", ".jpeg"],
                ),
                ui.input_slider("threshold", "Threshold", 0, 1.0, 0.8),
                ui.input_action_button("analyze", "Analyze", class_="btn-success"),
                # ui.input_switch("mask", "Mask", False),
                ui.column(
                    4,
                    ui.download_button(
                        "download", "Download Results", class_="btn-primary"
                    ),
                ),
                width=300,
            ),
            ui.output_ui("results_container"),
            border=False,
            border_radius=False,
        )
    ),
    # Disabled logo footer, kept for reference:
    # ui.div(
    #     ui.layout_column_wrap(
    #         ui.output_image("rbg_kew"),
    #         ui.output_image("bloomberg"),
    #         ui.output_image("brin"),
    #         ui.output_image("abg"),
    #     ),
    #     class_="footer",
    # ),
)
@module.ui
def plot_ui():
    """Return the UI fragment for one analysed image.

    The fragment is a row holding the overlay plot output
    (``plot_prediction``) and the slider (``opacity_slider``) that controls
    the mask opacity.

    This is a Shiny module UI function: callers pass the module id as the
    first argument (``plot_ui("plot_0")``) and ``@module.ui`` namespaces the
    input/output ids with it, so several instances can coexist on one page.
    """
    return ui.row(
        ui.output_plot("plot_prediction"),
        ui.input_slider("opacity_slider", "Opacity", 0, 1.0, 0.5),
    )
@module.server
def plot_server(input, output, session, r):
    """Server half of the per-image plot module.

    Called as ``plot_server(module_id, r=r)``; ``@module.server`` supplies the
    namespaced ``input``/``output``/``session`` objects.

    Args:
        input: Module-scoped inputs; ``opacity_slider`` is read here.
        output: Module-scoped outputs (unused directly).
        session: Module-scoped session (unused directly).
        r: Result dict for one image. ``r["image"]`` is the image array as
            loaded by OpenCV and ``r["instances"]`` is the prediction object
            exposing ``pred_classes`` and ``pred_masks``.
    """
    # Overlay colour (RGB, 0-1) per predicted class id.
    class_colours = {
        0: [1, 0, 0],
        1: [1, 1, 0],
        2: [0, 0, 0],
    }
    # Used for any class id not listed above, so the colour list always has
    # exactly one entry per instance and stays aligned with the masks.
    fallback_colour = [0.5, 0.5, 0.5]

    @render.plot
    def plot_prediction():
        """Draw the image with the predicted masks overlaid and return the figure."""
        plt.ioff()
        fig, ax = plt.subplots()
        ax.set_axis_off()

        instances = r["instances"]
        # Reverse the channel axis of the OpenCV-loaded image before handing
        # it to the visualizer.
        visualizer = Visualizer(
            r["image"][:, :, ::-1],
            scale=1.2,
            instance_mode=ColorMode.SEGMENTATION,
            font_size_scale=1,
        )
        colours = [
            class_colours.get(class_id, fallback_colour)
            for class_id in instances.pred_classes.tolist()
        ]
        overlay = visualizer.overlay_instances(
            masks=instances.pred_masks.to("cpu"),
            assigned_colors=colours,
            alpha=input.opacity_slider(),
        )
        # The previous code reversed the channels and then converted BGR->RGB,
        # which is an identity round-trip; show the visualizer output directly.
        ax.imshow(overlay.get_image())
        return fig
def server(input, output, session: Session):
    """Wire up the reactive logic for one browser session.

    Registers three handlers:

    * ``process_images`` runs the predictor on the uploaded files when the
      "analyze" button is clicked and stores the per-image results.
    * ``results_container`` renders one card per analysed image.
    * ``download`` serves the per-image counts as a CSV file.
    """
    # Disabled logo outputs, kept for reference:
    # @render.image
    # def rbg_kew():
    #     img = {"src": "logos/rbg_kew.png", "height": "100px"}
    #     return img
    # @render.image
    # def bloomberg():
    #     img = {"src": "logos/bloomberg.png", "height": "100px"}
    #     return img
    # @render.image
    # def brin():
    #     img = {"src": "logos/brin.png", "height": "100px"}
    #     return img
    # @render.image
    # def abg():
    #     img = {"src": "logos/abg.png", "height": "100px"}
    #     return img

    # List of per-image result dicts produced by the last analysis run.
    analysis_results = reactive.Value([])

    @reactive.effect
    @reactive.event(input.analyze)
    async def process_images():
        """Run the predictor on every uploaded image and publish the results."""
        files = input.upload()
        if not files:
            return

        threshold = input.threshold()
        results = []
        for file in files:
            # Read image using OpenCV (BGR channel order).
            im = cv2.imread(file["datapath"])
            if im is None:
                # cv2.imread signals failure by returning None; skip the file
                # instead of crashing the whole batch in cvtColor.
                ui.notification_show(
                    f"Could not read image: {file['name']}", type="warning"
                )
                continue

            # Convert BGR to RGB and encode as base64 PNG for HTML display.
            pil_img = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
            buffered = BytesIO()
            pil_img.save(buffered, format="PNG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode()

            # Run prediction with the original BGR image.
            prediction = predictor(im)
            prediction = apply_nms(prediction, True, threshold)
            classes = prediction["instances"].pred_classes.tolist()

            results.append({
                "filename": file["name"],
                "image_base64": img_base64,
                "image": im,
                **prediction,
                "viable": classes.count(0),
                "non-viable": classes.count(1),
                "empty": classes.count(2),
                "total": len(classes),
            })

        # Update reactive value so dependent outputs re-render.
        analysis_results.set(results)

    @render.ui
    def results_container():
        """Render one card (title, counts, overlay plot) per analysed image."""
        results = analysis_results.get()
        if not results:
            return ui.div(
                "No results yet. Upload images and click 'Analyze'.",
                class_="text-muted",
            )

        cards = []
        for idx, r in enumerate(results):
            module_id = f"plot_{idx}"
            # Start the module server that draws this card's plot.
            plot_server(module_id, r=r)
            cards.append(
                ui.div(
                    ui.h5(r['filename'], style="margin-top: 15px;"),
                    ui.div(
                        ui.span(f"Viable = {r.get('viable', '? ')}", style="margin: 0 15px;"),
                        ui.span(f"Non-Viable = {r.get('non-viable', '? ')}", style="margin: 0 15px;"),
                        ui.span(f"Empty = {r.get('empty', '? ')}", style="margin: 0 15px;"),
                        ui.span(f"Total = {r.get('total', '? ')}", style="margin: 0 15px;"),
                        class_="results-text",
                    ),
                    ui.row(
                        # ui.column(4, ui.img(src=f"data:image/png;base64,{r['image_base64']}")),
                        plot_ui(module_id),
                    ),
                    class_="card p-3",
                )
            )
        return ui.div(cards)

    @render.download(filename="results.csv")
    def download():
        """Stream the per-image counts as CSV text.

        The CSV is produced in memory and yielded to Shiny, so no temporary
        file is left behind on disk after each download.
        """
        columns = ["Filename", "Viable", "Non-Viable", "Empty", "Total"]
        rows = [
            {
                "Filename": r["filename"],
                "Viable": r.get("viable", ""),
                "Non-Viable": r.get("non-viable", ""),
                "Empty": r.get("empty", ""),
                "Total": r.get("total", ""),
            }
            for r in analysis_results.get()
        ]
        # Passing the columns explicitly keeps the header row even when there
        # are no results yet.
        df = pd.DataFrame(rows, columns=columns)
        yield df.to_csv(index=False)
# Application object picked up by the Shiny runner: static layout plus
# per-session server function.
app = App(app_ui, server)
| # -------------------------------------------------------- | |
| # Reactive calculations and effects | |
| # -------------------------------------------------------- | |