Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from copy import deepcopy | |
| import pandas as pd | |
| from PIL import Image | |
def get_index_of_element_containing_word(lst, word):
    """Return the index of the first element of *lst* that contains *word*.

    The comparison is a case-insensitive substring test. Elements that are
    not strings (e.g. a ``NaN`` read from a CSV) are skipped instead of
    raising ``AttributeError``.

    Args:
        lst: Sequence of candidate strings to search.
        word: Substring to look for.

    Returns:
        The index of the first matching element, or ``-1`` if none matches.
    """
    needle = word.lower()  # lower-case the search term once, not per element
    # next() stops at the first hit instead of collecting every match.
    return next(
        (
            i
            for i, element in enumerate(lst)
            if isinstance(element, str) and needle in element.lower()
        ),
        -1,
    )
# --- Mutable UI state shared between the Gradio callbacks below. ---
# Raw prediction map for the currently selected species (set by text_fn).
pred_global = None
# Blend weight of the heatmap over the base map; starts at 0.5, the same
# default as the "Image Transparency" slider.
alpha_global = 0.5
# The map currently displayed (raw or thresholded); re-rendered by alpha_fn.
alpha_image = None

# --- Static data loaded once at import time. ---
# Prediction matrix; text_fn takes one column per species and reshapes it to
# 510x510, so rows presumably correspond to map pixels (TODO confirm).
stl_preds = np.load("stl_species.npy")
# Species names; text_fn uses a name's position here as the column index into
# stl_preds, so the order is assumed to match (verify against the data pipeline).
# Only the one needed column is parsed.
obs = pd.read_csv("unique_species.csv", usecols=["NameList"])["NameList"].tolist()
# Base map the heatmap is blended onto. Image.blend needs both images in the
# same mode, and the heatmap is RGB. The context manager closes the file handle.
with Image.open("stl_base.png") as _base_img:
    stl_base = _base_img.convert("RGB")
def update_fn(val):
    """Build a fresh "Name" dropdown listing the names for taxonomic level *val*.

    Returns ``None`` when *val* is not one of the recognised levels.
    """
    # Each level maps to a thunk, so a name list is only looked up when that
    # level is actually requested. NOTE(review): class_list, order_list,
    # family_list and genus_list are not defined in the visible source; the UI
    # only offers "Species", so those entries are never reached from it.
    levels = (
        ("Class", lambda: class_list),
        ("Order", lambda: order_list),
        ("Family", lambda: family_list),
        ("Genus", lambda: genus_list),
        ("Species", lambda: obs),
    )
    for level, choices in levels:
        if val == level:
            return gr.Dropdown(label="Name", choices=choices(), interactive=True)
    return None
def text_fn(taxon, name):
    """Render the predicted distribution of species *name* over the base map.

    Also stores the raw prediction map in ``pred_global`` and ``alpha_image``
    so the threshold and transparency sliders can re-render it later.

    Args:
        taxon: Selected taxonomic level. Unused: predictions are looked up
            only by the position of *name* in ``obs``.
        name: Name chosen in the "Name" dropdown.

    Returns:
        An RGB ``uint8`` array: the plasma-coloured heatmap blended onto
        ``stl_base`` with weight ``alpha_global``.

    Raises:
        gr.Error: If no name is selected or it matches no known species.
    """
    global pred_global, alpha_global, alpha_image
    if not name:
        raise gr.Error("Please select a name before running the model.")
    # Prefer an exact match: a plain substring search could pick an earlier
    # species whose name merely contains the selected one.
    if name in obs:
        species_index = obs.index(name)
    else:
        species_index = get_index_of_element_containing_word(obs, name)
    if species_index < 0:
        # Index -1 would otherwise silently render the last species' column.
        raise gr.Error(f"No predictions found for '{name}'.")
    # Take this species' column as a 510x510 grid and mirror it left-right
    # (presumably to match stl_base's orientation -- confirm).
    preds = np.flip(stl_preds[:, species_index].reshape(510, 510), 1)
    pred_global = preds
    alpha_image = preds
    rgb_img = plt.get_cmap('plasma')(preds)[..., :3]  # drop the alpha channel
    heatmap = Image.fromarray((rgb_img * 255).astype(np.uint8))
    blend = Image.blend(stl_base, heatmap, alpha_global)
    return np.array(blend)
def thresh_fn(val):
    """Re-render the current prediction map as a binary mask at threshold *val*.

    Pixels with a prediction ``>= val`` become 1 and the rest become 0. The
    mask is stored in ``alpha_image`` so later transparency changes keep it.

    Args:
        val: Confidence threshold from the slider (0 to 1).

    Returns:
        The blended RGB ``uint8`` array, or ``None`` if the model has not been
        run yet (there is nothing to threshold).
    """
    global pred_global, alpha_global, alpha_image
    if pred_global is None:
        # Slider moved before "Run Model"; comparing None < val would raise.
        return None
    preds = pred_global.copy()  # never mutate the raw predictions
    preds[preds < val] = 0
    preds[preds >= val] = 1
    # preds is a fresh array that nothing below mutates, so no extra copy.
    alpha_image = preds
    rgb_img = plt.get_cmap('plasma')(preds)[..., :3]  # drop the alpha channel
    heatmap = Image.fromarray((rgb_img * 255).astype(np.uint8))
    blend = Image.blend(stl_base, heatmap, alpha_global)
    return np.array(blend)
def alpha_fn(val):
    """Re-blend the currently displayed map over ``stl_base`` with weight *val*.

    Args:
        val: Heatmap opacity from the "Image Transparency" slider. With
            ``Image.blend``, 0 shows only the base map and 1 only the heatmap.

    Returns:
        The blended RGB ``uint8`` array, or ``None`` if the model has not been
        run yet.
    """
    global pred_global, alpha_global, alpha_image
    # Remember the weight even before a run, so the first render uses it.
    alpha_global = val
    if alpha_image is None:
        return None
    # The colormap does not modify its input, so no defensive copy is needed.
    rgb_img = plt.get_cmap('plasma')(alpha_image)[..., :3]  # drop the alpha channel
    heatmap = Image.fromarray((rgb_img * 255).astype(np.uint8))
    blend = Image.blend(stl_base, heatmap, alpha_global)
    return np.array(blend)
# Gradio UI: pick a species, run the model, then tune threshold / transparency.
with gr.Blocks() as demo:
    gr.Markdown(
    """
    # St Louis Species Distribution Model!
    This model predicts the distribution of species based on geographic, and satellite image features.
    """)
    with gr.Row():
        # "Species" is the only level offered, so pre-select it and populate
        # the names up front instead of leaving "Name" empty until it is picked.
        inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=["Species"], value="Species")
        out = gr.Dropdown(label="Name", choices=obs, interactive=True)
    inp.change(update_fn, inp, out)
    with gr.Row():
        check_button = gr.Button("Run Model")
    with gr.Row():
        slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold")
    with gr.Row():
        # Default 0.5 matches the initial module-level alpha_global.
        alpha = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Image Transparency")
    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)
    check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
    slider.change(thresh_fn, slider, outputs=pred)
    alpha.change(alpha_fn, alpha, outputs=pred)

# Only start the server when run as a script, so the module can be imported
# (e.g. by Gradio's reload mode) without side effects.
if __name__ == "__main__":
    demo.launch()