Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| import time | |
| import uuid | |
| from datetime import datetime | |
| from decimal import Decimal | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| from settings import DEMO | |
| plt.switch_backend("agg") # fix for "RuntimeError: main thread is not in main loop" | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| from model import GranaAnalyser | |
# Checkpoint files handed to GranaAnalyser, in the positional order of the
# constructor call. The role of each file is defined by the `model` module,
# which is not visible from this file.
_WEIGHT_FILES = (
    "weights/yolo/20240604_yolov8_segm_ABRCR1_all_train4_best.pt",
    "weights/AS_square_v16.ckpt",
    "weights/period_measurer_weights-1.298_real_full-fa12970.ckpt",
)

# Single analyser instance, created once at import time; the submit handler
# calls ga.predict() on it.
ga = GranaAnalyser(*_WEIGHT_FILES)
def calc_ratio(pixels, nano):
    """
    Calculates the ratio of pixels to nanometers, used to populate ratio_input.

    :param pixels: length of the scale bar in pixels
    :param nano: length of the same scale bar in nanometers
    :return: ``pixels / nano`` as a number, or None when either value is
        missing or zero (no ratio can be computed in that case)
    """
    # Guard clause: both lengths are needed, and a zero length is meaningless
    # (a zero `nano` would also divide by zero).
    if not (pixels and nano):
        return None
    return pixels / nano
# https://jakevdp.github.io/PythonDataScienceHandbook/05.13-kernel-density-estimation.html
def KDE(dataset, h):
    """
    Evaluate a Gaussian kernel density estimate at the data points themselves.

    :param dataset: 1-D numeric container exposing ``.size`` and supporting
        element-wise arithmetic (e.g. a numpy array or a pandas Series)
    :param h: kernel bandwidth; must be non-zero, otherwise the scaled
        distances below are divided by zero
    :return: estimated density at every element of ``dataset``, in the same
        order and of the same length
    """

    # the Kernel function: standard normal probability density
    def K(x):
        return np.exp(-(x ** 2) / 2) / np.sqrt(2 * np.pi)

    n_samples = dataset.size
    total_sum = 0
    # every data point contributes one kernel, evaluated at all data points
    for xi in dataset:
        total_sum += K((dataset - xi) / h)
    # normalise so the estimate is a density (kernel area h, n_samples kernels)
    return total_sum / (h * n_samples)
def prepare_files_for_download(
    dir_name,
    grana_data,
    aggregated_data,
    detection_visualizations_dict,
    images_grana_dict,
):
    """
    Save and zip files for download

    :param dir_name: directory in which the zip archive is created; a temporary
        "to_zip" sub-directory is used for staging and removed afterwards
    :param grana_data: DataFrame containing all grana measurements; its
        "granum ID" and "source image" columns are used to name the
        single-granum images
    :param aggregated_data: DataFrame containing aggregated measurements
    :param detection_visualizations_dict: maps source image file name to the
        annotated image (any object with a ``save(path)`` method)
    :param images_grana_dict: maps granum ID to the image of that granum
    :return: path of the created zip archive
    :raises KeyError: if a key of images_grana_dict does not occur in
        grana_data["granum ID"]
    """
    dir_to_zip = f"{dir_name}/to_zip"
    # do not depend on the caller having created the staging directory
    os.makedirs(dir_to_zip, exist_ok=True)

    # raw data
    grana_data.to_csv(f"{dir_to_zip}/grana_raw_data.csv", index=False)

    # aggregated measurements
    aggregated_data.to_csv(f"{dir_to_zip}/grana_aggregated_data.csv")

    # annotated pictures
    masked_images_dir = f"{dir_to_zip}/annotated_images"
    os.makedirs(masked_images_dir, exist_ok=True)
    for img_name, img in detection_visualizations_dict.items():
        # splitext keeps the whole name as stem when there is no extension
        # (a manual split(".") would treat the name itself as the extension)
        stem, extension = os.path.splitext(img_name)
        # fall back to png so the saved file always carries an extension
        extension = extension or ".png"
        img.save(f"{masked_images_dir}/{stem}_annotated{extension}")

    # single_grana images
    grana_images_dir = f"{dir_to_zip}/single_grana_images"
    os.makedirs(grana_images_dir, exist_ok=True)
    # granum ID -> file name of the image the granum was found in
    org_images_dict = pd.Series(
        grana_data["source image"].values, index=grana_data["granum ID"]
    ).to_dict()
    for granum_id, img in images_grana_dict.items():
        org_filename_no_ext = os.path.splitext(org_images_dict[granum_id])[0]
        img.save(f"{grana_images_dir}/{org_filename_no_ext}_granum_{granum_id}.png")

    # zip all files
    date_str = datetime.today().strftime("%Y-%m-%d")
    zip_path = f"{dir_name}/GRANA_results_{date_str}"
    shutil.make_archive(zip_path, "zip", dir_to_zip)

    # delete to_zip dir, only the archive is kept
    shutil.rmtree(dir_to_zip)

    return f"{zip_path}.zip"
def show_info_on_submit(s):
    """
    Switch the interface into its "processing" state.

    :param s: value of the component wired as the event input; not used
    :return: four component updates, in the order of the event's outputs:
        two disabled buttons, a visible row and a hidden row
    """
    disabled_buttons = [gr.Button(interactive=False) for _ in range(2)]
    shown_row = gr.Row(visible=True)
    hidden_row = gr.Row(visible=False)
    return (*disabled_buttons, shown_row, hidden_row)
def load_css():
    """
    Read the custom stylesheet that is passed to gr.Blocks.

    :return: full text of "styles.css", resolved against the current working
        directory
    :raises OSError: if the file cannot be opened
    """
    # explicit encoding: the platform default is not UTF-8 everywhere
    with open("styles.css", "r", encoding="utf-8") as f:
        return f.read()
# Custom green palette for the theme, one hex colour per shade keyword of
# gr.themes.Color (the values run from the lightest, c50, to black, c950).
_PRIMARY_SHADES = {
    "c50": "#e1f8ee",
    "c100": "#b7efd5",
    "c200": "#8de6bd",
    "c300": "#63dda5",
    "c400": "#39d48d",
    "c500": "#27b373",
    "c600": "#1e8958",
    "c700": "#155f3d",
    "c800": "#0c3522",
    "c900": "#030b07",
    "c950": "#000",
}
primary_hue = gr.themes.Color(**_PRIMARY_SHADES)

# Default Gradio theme with the palette above and the Ubuntu font, followed by
# generic fallbacks.
_FONT_STACK = [
    gr.themes.GoogleFont("Ubuntu"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]
theme = gr.themes.Default(primary_hue=primary_hue, font=_FONT_STACK)
def draw_violin_plot(y, ylabel, title):
    """
    Draw a violin plot overlaid with jittered data points, a box plot and the mean.

    :param y: pandas Series of numeric values; callers pass it through
        ``.dropna()`` first
    :param ylabel: label of the y axis
    :param title: title of the plot
    :return: the matplotlib figure, or None when ``y`` holds fewer than 3 values
    """
    # only generate plot for 3 or more values
    if y.count() < 3:
        return None

    # Colors
    RED_DARK = "#850e00"
    DARK_GREEN = "#0c3522"
    BRIGHT_GREEN = "#8de6bd"

    # Create jittered version of "x" (which is only 1): each point is moved
    # sideways by a random amount bounded by the local density, so the point
    # cloud follows the outline of the violin (at most 0.2 to either side).
    bandwidth = (y.max() - y.min()) / y.size / 2
    if bandwidth > 0:
        density = np.asarray(KDE(y, bandwidth))
        half_widths = density / density.max() * 0.2
    else:
        # All values are identical: a zero bandwidth would turn the KDE into
        # NaNs (0 / 0), so give every point the maximum spread instead.
        half_widths = np.full(y.size, 0.2)
    x_jittered = 1 + np.random.uniform(-half_widths, half_widths)

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.scatter(x=x_jittered, y=y, s=20, alpha=0.4, c=DARK_GREEN)
    violins = ax.violinplot(
        y,
        widths=0.45,
        bw_method="silverman",
        showmeans=False,
        showmedians=False,
        showextrema=False,
    )

    # change violin color
    for pc in violins["bodies"]:
        pc.set_facecolor(BRIGHT_GREEN)

    # add a boxplot to ax, with whiskers that end one standard deviation
    # below and above the mean instead of the usual IQR-based whiskers
    lower = np.mean(y) - 1 * np.std(y)
    upper = np.mean(y) + 1 * np.std(y)
    medianprops = dict(linewidth=1, color="black", solid_capstyle="butt")
    boxplot_stats = [
        {
            "med": np.median(y),
            "q1": np.percentile(y, 25),
            "q3": np.percentile(y, 75),
            "whislo": lower,
            "whishi": upper,
        }
    ]
    ax.bxp(
        boxplot_stats,  # data for the boxplot
        showfliers=False,  # do not show the outliers beyond the caps.
        showcaps=True,  # show the caps
        medianprops=medianprops,
    )

    # Add mean value point
    ax.scatter(1, y.mean(), s=30, color=RED_DARK, zorder=3)

    ax.set_xticks([])
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    fig.tight_layout()

    # Unregister the figure from pyplot so figures do not pile up in a
    # long-running server; the Figure object itself stays usable.
    plt.close(fig)
    return fig
def transform_aggregated_results_table(results_dict):
    """
    Reshape the aggregated measurements into two parallel columns for display.

    :param results_dict: mapping that holds, for every measurement, the value
        under e.g. "area nm^2" and its SD under "area nm^2 std", plus the
        grana count under "N grana"
    :return: dict with the keys "measurement [unit]" and "value +-SD", each a
        list with one entry per table row
    """
    MEASUREMENT_HEADER = "measurement [unit]"
    VALUE_HEADER = "value +-SD"
    TWO_DECIMALS = Decimal("0.01")

    # (label shown in the table, key of the value, key of its SD)
    rows = (
        ("area [nm^2]", "area nm^2", "area nm^2 std"),
        ("perimeter [nm]", "perimeter nm", "perimeter nm std"),
        ("diameter [nm]", "diameter nm", "diameter nm std"),
        ("height [nm]", "height nm", "height nm std"),
        ("number of thylakoids", "Number of layers", "Number of layers std"),
        ("SRD [nm]", "period nm", "period nm std"),
        ("GSI", "GSI", "GSI std"),
    )

    def format_cell(value, std):
        # a missing value or SD is shown as a dash
        if np.isnan(value) or np.isnan(std):
            return "-"
        value_str, std_str = (
            str(Decimal(str(number)).quantize(TWO_DECIMALS)) for number in (value, std)
        )
        return f"{value_str} +-{std_str}"

    labels = [label for label, _, _ in rows]
    cells = [
        format_cell(results_dict[value_key], results_dict[sd_key])
        for _, value_key, sd_key in rows
    ]

    # the grana count has no SD and is shown as a plain integer
    labels.append("number of grana")
    cells.append(str(int(results_dict["N grana"])))

    return {MEASUREMENT_HEADER: labels, VALUE_HEADER: cells}
def rename_columns_in_results_table(results_table):
    """
    Give the columns of the full results table their display names.

    :param results_table: DataFrame with the per-granum measurements
    :return: a renamed copy of the table; columns without an entry in the
        mapping below keep their name
    """
    display_names = dict(
        [
            ("Granum ID", "granum ID"),
            ("File name", "source image"),
            ("area nm^2", "area [nm^2]"),
            ("perimeter nm", "perimeter [nm]"),
            ("diameter nm", "diameter [nm]"),
            ("height nm", "height [nm]"),
            ("Number of layers", "number of thylakoids"),
            ("period nm", "SRD [nm]"),
            ("period SD nm", "SRD SD [nm]"),
        ]
    )
    return results_table.rename(columns=display_names)
# Limits and formats enforced by gradio_analize_image.
_DEMO_IMAGE_LIMIT = 5
_SUPPORTED_EXTENSIONS = (".png", ".tif", ".tiff", ".jpg", ".jpeg")
_RESULTS_MAX_AGE_SECONDS = 60 * 60

# (column of the full results table, y-axis label, plot title), listed in the
# order of the plot components among the outputs of the submit event
_VIOLIN_PLOTS = (
    ("area [nm^2]", "Granum area [nm^2]", "Grana areas from all uploaded images"),
    (
        "perimeter [nm]",
        "Granum perimeter [nm]",
        "Grana perimeters from all uploaded images",
    ),
    ("GSI", "GSI", "GSI from all uploaded images"),
    (
        "diameter [nm]",
        "Granum diameter [nm]",
        "Grana diameters from all uploaded images",
    ),
    ("height [nm]", "Granum height [nm]", "Grana heights from all uploaded images"),
    ("SRD [nm]", "SRD [nm]", "SRD from all uploaded images"),
)


def enable_submit():
    """
    Leave the "processing" state once the analysis event has finished.

    :return: updates for the submit button (enabled), the clear button
        (enabled) and the loading row (hidden)
    """
    return (
        gr.Button(interactive=True),
        gr.Button(interactive=True),
        gr.Row(visible=False),
    )


def _remove_stale_results(max_age_seconds=_RESULTS_MAX_AGE_SECONDS):
    """
    Best-effort removal of old "results_*" directories from the working directory.

    :param max_age_seconds: directories whose ctime is older than this are deleted
    """
    now = time.time()
    for directory_name in os.listdir():
        if not directory_name.startswith("results_"):
            continue
        dir_path = os.path.join(os.getcwd(), directory_name)
        try:
            expired = (
                os.path.isdir(dir_path)
                and now - os.path.getctime(dir_path) > max_age_seconds
            )
        except OSError:
            # the directory vanished between listdir() and getctime(),
            # e.g. because another request removed it
            continue
        if expired:
            # housekeeping must never abort the user's analysis
            shutil.rmtree(dir_path, ignore_errors=True)


def gradio_analize_image(images, scale):
    """
    Validate the submission, run the model and build all outputs of the results view.

    Model accepts following parameters:
    :param images: uploaded files (objects whose ``.name`` is the file path),
        in png, tiff or jpg format
    :param scale: float, value of the "pixel per nm" field; forwarded
        unchanged to ``ga.predict``
    Model returns the following objects:
    - detection_visualizations: dict of images with masks to be displayed as gallery and served to download
        as zip of images
    - grana_data: dataframe with measurements for each image to be served to download as a csv file
    - images_grana: dict of images with single grana to be served to download as zip of images
    - aggregated_data: aggregated measurements for all images to be displayed as table and served
        to download as csv
    :return: list with one value per output of the submit event, in order:
        loading row, output row, path of the zip archive, annotated images,
        aggregated table, the six violin plots (area, perimeter, GSI, diameter,
        height, SRD; None where fewer than 3 values exist), single-grana
        images, full results table
    :raises gr.Error: when no image is uploaded, the demo image limit is
        exceeded, the scale is missing or not positive, or a file has an
        unsupported extension
    """
    # validate that at least one image has been uploaded
    if not images:
        raise gr.Error("Please upload at least one image")

    # on demo instance, we limit the number of images
    if DEMO and len(images) > _DEMO_IMAGE_LIMIT:
        raise gr.Error(
            f"In demo version it is possible to analyze up to {_DEMO_IMAGE_LIMIT} images."
        )

    # validate that scale has been provided correctly (a ratio must be positive)
    if scale is None or scale <= 0:
        raise gr.Error("Please provide scale. Use dot as decimal separator")

    # validate that all images have a supported format
    if not all(
        image.name.lower().endswith(_SUPPORTED_EXTENSIONS) for image in images
    ):
        raise gr.Error("Only png, tiff, jpg and jpeg images are supported")

    # clean up results of earlier submissions
    _remove_stale_results()

    # create a directory for results
    results_dir_name = "results_{uuid}".format(uuid=uuid.uuid4().hex)
    os.makedirs(results_dir_name)
    os.makedirs(f"{results_dir_name}/to_zip")

    # model takes a dict of images keyed by file name
    # NOTE(review): two uploads with the same base name overwrite each other
    # here - confirm whether that can happen with the upload component.
    images_dict = {
        os.path.basename(image.name): Image.open(image.name) for image in images
    }

    # model works here
    (
        detection_visualizations_dict,
        grana_data,
        images_grana_dict,
        aggregated_data,
    ) = ga.predict(images_dict, scale)
    detection_visualizations = list(detection_visualizations_dict.values())
    images_grana = list(images_grana_dict.values())

    # rearrange aggregated data to be displayed as table
    aggregated_dict = transform_aggregated_results_table(aggregated_data)
    aggregated_df_transposed = pd.DataFrame.from_dict(aggregated_dict)

    # rename columns in full results
    grana_data = rename_columns_in_results_table(grana_data)

    # save files returned by model to disk so they can be retrieved for downloading
    download_file_path = prepare_files_for_download(
        results_dir_name,
        grana_data,
        aggregated_df_transposed,
        detection_visualizations_dict,
        images_grana_dict,
    )

    # generate plots
    figures = [
        draw_violin_plot(grana_data[column].dropna(), ylabel, plot_title)
        for column, ylabel, plot_title in _VIOLIN_PLOTS
    ]

    return [
        gr.Row(visible=True),  # loading_row
        gr.Row(visible=True),  # output_row
        download_file_path,
        detection_visualizations,
        aggregated_df_transposed,
        *figures,
        images_grana,
        grana_data,
    ]


# Page layout and event wiring.
# NOTE(review): the nesting of the layout containers below was reconstructed;
# compare the rendered page against the deployed app.
with gr.Blocks(css=load_css(), theme=theme) as demo:
    svg = """
    <svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 30.73 33.38">
    <defs>
    <style>
    .cls-1 {
    fill: #27b373;
    stroke-width: 0px;
    }
    </style>
    </defs>
    <path class="cls-1" d="M19.69,11.73h-3.22c-2.74,0-4.96,2.22-4.96,4.96h0c0,2.74,2.22,4.96,4.96,4.96h3.43c.56,0,1,.51.89,1.09-.08.43-.49.72-.92.72h-8.62c-.74,0-1.34-.6-1.34-1.34v-10.87c0-.74.6-1.34,1.34-1.34h13.44c2.73,0,4.95-2.22,4.95-4.95h0c0-2.75-2.22-4.97-4.96-4.97h-13.85C4.85,0,0,4.85,0,10.83v11.71c0,5.98,4.85,10.83,10.83,10.83h9.07c5.76,0,10.49-4.52,10.81-10.21.35-6.29-4.72-11.44-11.02-11.44ZM19.9,31.4h-9.07c-4.89,0-8.85-3.96-8.85-8.85v-11.71C1.98,5.95,5.95,1.98,10.83,1.98h13.81c1.64,0,2.97,1.33,2.97,2.97h0c0,1.65-1.33,2.97-2.96,2.97h-13.4c-1.83,0-3.32,1.49-3.32,3.32v10.87c0,1.83,1.49,3.32,3.32,3.32h8.56c1.51,0,2.83-1.12,2.97-2.62.16-1.72-1.2-3.16-2.88-3.16h-3.52c-1.64,0-2.97-1.33-2.97-2.97h0c0-1.64,1.33-2.97,2.97-2.97h3.34c4.83,0,8.9,3.81,9.01,8.64s-3.9,9.04-8.84,9.04Z"/>
    <path class="cls-1" d="M19.9,29.41h-9.07c-3.79,0-6.87-3.07-6.87-6.87v-11.71c0-3.79,3.07-6.87,6.87-6.87h13.81c.55,0,.99.44.99.99h0c0,.55-.44.99-.99.99h-13.81c-2.7,0-4.88,2.19-4.88,4.88v11.71c0,2.7,2.19,4.88,4.88,4.88h8.94c2.64,0,4.91-2.05,5-4.7s-2.12-5.05-4.87-5.05h-3.52c-.55,0-.99-.44-.99-.99h0c0-.55.44-.99.99-.99h3.36c3.74,0,6.9,2.92,7.01,6.66.11,3.87-3.01,7.06-6.85,7.06Z"/>
    </svg>
    """
    # header: logo + application name (every opened <div> is closed again)
    gr.HTML(
        f'<div class="header"><div id="header-logo">{svg}</div><div id="header-text">GRANA</div></div>'
    )

    with gr.Row(elem_classes="input-row"):  # input
        with gr.Column():
            gr.HTML(
                "<h1>1. Choose images to upload. All the images need to be of the same scale and experimental variant.</h1>"
            )
            img_input = gr.File(file_count="multiple")
            gr.HTML("<h1>2. Set the scale of the images for the measurements.</h1>")
            with gr.Row():
                with gr.Column():
                    gr.HTML("Either provide pixel per nanometer ratio...")
                    ratio_input = gr.Number(
                        label="pixel per nm", precision=3, step=0.001
                    )
                with gr.Column():
                    gr.HTML("...or length of the scale bar in pixels and nanometers.")
                    pixels_input = gr.Number(label="Length in pixels")
                    nano_input = gr.Number(label="Length in nanometers")
                    # editing either length recomputes the ratio field
                    pixels_input.change(
                        calc_ratio,
                        inputs=[pixels_input, nano_input],
                        outputs=ratio_input,
                    )
                    nano_input.change(
                        calc_ratio,
                        inputs=[pixels_input, nano_input],
                        outputs=ratio_input,
                    )
            with gr.Row():
                clear_btn = gr.ClearButton(img_input, "Clear")
                submit_btn = gr.Button("Submit", variant="primary")

    # shown while a submission is being processed
    with gr.Row(visible=False) as loading_row:
        with gr.Column():
            gr.HTML(
                "<div class='processed-info'>Images are being processed. This may take a while...</div>"
            )

    # results view, hidden until an analysis has returned
    with gr.Row(visible=False) as output_row:
        with gr.Column():
            gr.HTML(
                '<div class="results-header">Results</div>'
                "<p>Full results are a zip file containing:</p>"
                "<ul>- grana_raw_data.csv: a table with full grana measurements,</ul>"
                "<ul>- grana_aggregated_data.csv: a table with aggregated measurements,</ul>"
                '<ul>- directory "annotated_images" with all submitted images with masks on detected grana,</ul>'
                '<ul>- directory "single_grana_images" with images of all detected grana.</ul>'
                "<p>Note that GRANA only stores the result files for 1 hour.</p>",
                elem_classes="input-row",
            )
            with gr.Row(elem_classes="input-row"):
                download_file_out = gr.DownloadButton(
                    label="Download results",
                    variant="primary",
                    elem_classes="margin-bottom",
                )
            with gr.Row():
                gr.HTML(
                    '<h2 class="title">Annotated images</h2>'
                    "Gallery of uploaded images with masks of recognized grana structures. "
                    "Each granum mask is "
                    "labeled with its number. Note that only fully visible grana in the image are masked."
                )
            with gr.Row(elem_classes="margin-bottom"):
                gallery_out = gr.Gallery(
                    columns=4,
                    rows=2,
                    object_fit="contain",
                    label="Detection visualizations",
                    show_download_button=False,
                )
            with gr.Row(elem_classes="input-row"):
                gr.HTML(
                    '<h2 class="title">Aggregated results for all uploaded images</h2>'
                )
            with gr.Row(elem_classes=["input-row", "margin-bottom"]):
                table_out = gr.Dataframe(label="Aggregated data")
            with gr.Row():
                gr.HTML(
                    '<h2 class="title">Violin graphs</h2>'
                    "These graphs present aggregated results for selected structural parameters. "
                    "The graph for each parameter is only generated if three or more values are available. "
                    "Each graph "
                    "displays individual data points, a box plot indicating the first and third quartiles, whiskers "
                    "marking the standard deviation (SD), the median value (horizontal line on the box plot), "
                    "the mean value (red dot), and a density plot where the width represents the frequency."
                )
            with gr.Row():
                area_plot_out = gr.Plot(label="Area")
                perimeter_plot_out = gr.Plot(label="Perimeter")
                gsi_plot_out = gr.Plot(label="GSI")
            with gr.Row(elem_classes="margin-bottom"):
                diameter_plot_out = gr.Plot(label="Diameter")
                height_plot_out = gr.Plot(label="Height")
                srd_plot_out = gr.Plot(label="SRD")
            with gr.Row():
                gr.HTML(
                    '<h2 class="title">Recognized and rotated grana structures</h2>'
                )
            with gr.Row(elem_classes="margin-bottom"):
                gallery_single_grana_out = gr.Gallery(
                    columns=4,
                    rows=2,
                    object_fit="contain",
                    label="Single grana images",
                    show_download_button=False,
                )
            with gr.Row():
                gr.HTML(
                    '<h2 class="title">Full results</h2>'
                    "Note that structural parameters other than area and perimeter are only calculated for the grana "
                    "whose direction and/or SRD could be estimated."
                )
            with gr.Row():
                table_full_out = gr.Dataframe(label="Full measurements data")

    # First listener on the submit button: enter the "processing" state.
    submit_btn.click(
        show_info_on_submit,
        inputs=[submit_btn],
        outputs=[submit_btn, clear_btn, loading_row, output_row],
    )

    # Second listener on the same button: run the analysis, then restore the
    # buttons. The outputs are in the order of gradio_analize_image's return list.
    submit_btn.click(
        fn=gradio_analize_image,
        inputs=[
            img_input,
            ratio_input,
        ],
        outputs=[
            loading_row,
            output_row,
            download_file_out,
            gallery_out,
            table_out,
            area_plot_out,
            perimeter_plot_out,
            gsi_plot_out,
            diameter_plot_out,
            height_plot_out,
            srd_plot_out,
            gallery_single_grana_out,
            table_full_out,
        ],
    ).then(fn=enable_submit, inputs=[], outputs=[submit_btn, clear_btn, loading_row])
# Start serving the app. The options are unchanged: no share link, debug mode
# on, server bound to 0.0.0.0, and the logo file added to the allowed paths.
demo.launch(
    share=False,
    debug=True,
    server_name="0.0.0.0",
    allowed_paths=["images/logo.svg"],
)