Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import nibabel as nib | |
| import numpy as np | |
| import os | |
| from PIL import Image | |
| import pandas as pd | |
# Example inputs offered through gr.Examples: one row per example, one value
# per input component (here, a single NIfTI file path). Further samples can be
# appended as extra rows.
example_files = [
    ["./resampled_green_25.nii.gz"],
]

# Module-level state shared by the Gradio callbacks.
# NOTE(review): this state is process-wide, so concurrent users of the app
# would share it — confirm that is acceptable for this demo.
coronal_slices = []  # 2-D slices taken along axis 0 of the loaded volume (set by load_nifti)
last_probabilities = []  # most recent probability vector produced by run_mapping
prob_df = pd.DataFrame()  # empty initial value handed to the bar plot
# Cell-type labels reported by the mapping step. Their order defines the
# order of every probability vector paired with them in plot_probabilities.
cell_types = [
    "ABC.NN",
    "Astro.TE.NN",
    "CLA.EPd.CTX.Car3.Glut",
    "Endo.NN",
    "L2.3.IT.CTX.Glut",
    "L4.5.IT.CTX.Glut",
    "L5.ET.CTX.Glut",
    "L5.IT.CTX.Glut",
    "L5.NP.CTX.Glut",
    "L6.CT.CTX.Glut",
    "L6.IT.CTX.Glut",
    "L6b.CTX.Glut",
    "Lamp5.Gaba",
    "Lamp5.Lhx6.Gaba",
    "Lymphoid.NN",
    "Microglia.NN",
    "OPC.NN",
    "Oligo.NN",
    "Peri.NN",
    "Pvalb.Gaba",
    "Pvalb.chandelier.Gaba",
    "SMC.NN",
    "Sncg.Gaba",
    "Sst.Chodl.Gaba",
    "Sst.Gaba",
    "VLMC.NN",
    "Vip.Gaba",
]

# actual_ids[i] is compared against the current slice index in update_slice,
# and i is then used as the gallery's selected index — i.e. entry i is the
# slice number associated with gallery image i.
actual_ids = [
    30, 52, 71, 91, 104, 109, 118, 126, 131, 137,
    141, 164, 178, 182, 197, 208, 218, 226, 232, 242,
    244, 248, 256, 262, 270, 282, 293, 297, 308, 323,
    339, 344, 350, 355, 364, 372, 379, 389, 395, 401,
    410, 415, 418, 424, 429, 434, 440, 444, 469, 479,
    487, 509,
]

# NOTE(review): gallery_ids is not referenced by the visible code and has 53
# entries vs. 52 in actual_ids — confirm whether it is still needed.
gallery_ids = [
    5, 6, 8, 9, 10, 11, 12, 13, 14, 15,
    16, 17, 18, 19, 24, 25, 26, 27, 28, 29,
    30, 31, 32, 33, 35, 36, 37, 38, 39, 40,
    42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
    52, 54, 55, 56, 57, 58, 59, 60, 61, 62,
    64, 66, 67,
]
def load_nifti(file):
    """Load an uploaded NIfTI volume and initialise the slice viewer.

    The volume is split along its first axis into 2-D slices, which are stored
    in the module-level ``coronal_slices`` list. The middle slice is rendered
    and the atlas gallery is (re)loaded.

    Args:
        file: Value of the ``gr.File`` component. Depending on the Gradio
            version this is a tempfile-like object exposing ``.name`` or a
            plain filepath string. It is ``None`` when the upload is cleared.

    Returns:
        Values for ``[image_display, slice_slider, gallery, slice_row,
        plot_row]``: the middle slice as a PIL image, a slider update spanning
        every slice, the gallery image paths, and visibility updates for the
        two rows.
    """
    global coronal_slices
    if file is None:
        # Upload cleared: reset the viewer instead of failing on ``None.name``.
        coronal_slices = []
        return (
            None,
            gr.update(visible=False),
            [],
            gr.update(visible=False),
            gr.update(visible=False),
        )

    # Accept both a plain path and a tempfile-like object with ``.name``.
    path = getattr(file, "name", file)
    vol = nib.load(path).get_fdata()
    # NOTE(review): assumes a 3-D volume; a 4-D input would yield 3-D "slices"
    # that Image.fromarray cannot render — confirm inputs are always 3-D.
    coronal_slices = [vol[i, :, :] for i in range(vol.shape[0])]
    mid_index = vol.shape[0] // 2

    mid_slice = coronal_slices[mid_index]
    peak = np.max(mid_slice)
    # Scale so the slice maximum maps to 255. An all-zero (or NaN/non-positive)
    # slice would otherwise divide by zero; clipping keeps negative values
    # from wrapping around when cast to uint8.
    scaled = mid_slice / peak * 255 if peak > 0 else np.zeros_like(mid_slice)
    slice_img = Image.fromarray(np.clip(scaled, 0, 255).astype(np.uint8))

    gallery_images = load_gallery_images()
    return (
        slice_img,
        gr.update(visible=True, maximum=len(coronal_slices) - 1, value=mid_index),
        gallery_images,
        gr.update(visible=True),
        gr.update(visible=False),
    )
def update_slice(index):
    """Show the slice at ``index`` and refresh the plot and gallery selection.

    Args:
        index: Slider value. Gradio delivers slider values as numbers that may
            be floats, which cannot index a list. The value is therefore
            rounded to an ``int`` and clamped to the valid slice range.

    Returns:
        Values for ``[image_display, prob_plot, gallery]``: the slice as a PIL
        image, the table returned by ``plot_probabilities``, and a gallery
        update selecting the entry whose ``actual_ids`` value is nearest to
        ``index``. Returns ``(None, None, None)`` when no volume is loaded.
    """
    if not coronal_slices:
        return None, None, None
    index = min(max(int(round(index)), 0), len(coronal_slices) - 1)

    current = coronal_slices[index]
    peak = np.max(current)
    # Map the slice maximum to 255; guard against a zero maximum (division by
    # zero -> NaN) and clip so negative values do not wrap in the uint8 cast.
    scaled = current / peak * 255 if peak > 0 else np.zeros_like(current)
    slice_img = Image.fromarray(np.clip(scaled, 0, 255).astype(np.uint8))

    # Select the gallery image whose associated slice id is closest to this slice.
    closest_idx = min(range(len(actual_ids)), key=lambda i: abs(actual_ids[i] - index))
    gallery_selection = gr.update(selected_index=closest_idx)

    # Slightly jitter the last mapping result (renormalised to sum to 1). Before
    # any mapping has been run, fall back to a fresh random vector.
    if last_probabilities:
        noise = np.random.normal(0, 0.01, size=len(last_probabilities))
        new_probs = np.clip(np.array(last_probabilities) + noise, 0, None)
        new_probs /= new_probs.sum()
    else:
        new_probs = generate_random_probabilities()
    return slice_img, plot_probabilities(new_probs), gallery_selection
def load_gallery_images(folder="Overlapped_updated", extensions=(".png", ".jpg", ".jpeg")):
    """Return the paths of the image files directly inside ``folder``.

    Args:
        folder: Directory to scan (non-recursive). Defaults to the atlas
            overlay folder used by the app.
        extensions: Filename suffixes treated as images. Matching is
            case-insensitive.

    Returns:
        A list of ``os.path.join(folder, name)`` paths ordered by filename.
        Returns an empty list when ``folder`` is not an existing directory; a
        regular file would otherwise make ``os.listdir`` raise.
    """
    if not os.path.isdir(folder):
        return []
    suffixes = tuple(ext.lower() for ext in extensions)
    return [
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if name.lower().endswith(suffixes)
    ]
def generate_random_probabilities():
    """Return a random probability vector with one entry per cell type.

    A uniform weight is drawn for every entry of ``cell_types``. Five distinct,
    randomly chosen entries are then replaced by a value in ``[0, 0.01)``, and
    the weights are normalised to sum to 1.

    Returns:
        A list of Python floats of length ``len(cell_types)``.
    """
    n_types = len(cell_types)
    weights = np.random.rand(n_types)
    damped = np.random.choice(n_types, size=5, replace=False)
    # One vectorised draw consumes the global RNG exactly like one scalar
    # np.random.rand() per chosen index, in the same order.
    weights[damped] = np.random.rand(damped.size) * 0.01
    total = weights.sum()
    return (weights / total).tolist()
def plot_probabilities(probabilities):
    """Build the bar-plot table for ``probabilities`` and save it as CSV.

    Side effect: overwrites ``Cell_types_predictions.csv`` in the working
    directory, which is the path ``download_csv`` returns. The table is held in
    a local variable; the module-level ``prob_df`` is not modified.

    Args:
        probabilities: Sequence (list or 1-D array) aligned with ``cell_types``.

    Returns:
        A DataFrame with ``labels``/``values`` columns, or ``None`` when
        ``probabilities`` is empty.
    """
    if not len(probabilities):
        return None
    table = pd.DataFrame({"labels": cell_types, "values": probabilities})
    table.to_csv('Cell_types_predictions.csv', index=False)
    return table
def run_mapping():
    """Produce a new random probability vector and reveal the results panel.

    The vector is remembered in ``last_probabilities`` so that ``update_slice``
    can derive per-slice variations from it.

    Returns:
        Values for ``[prob_plot, plot_row]``: the table from
        ``plot_probabilities`` and an update making the results column visible.
    """
    global last_probabilities
    probabilities = generate_random_probabilities()
    last_probabilities = probabilities
    table = plot_probabilities(probabilities)
    return table, gr.update(visible=True)
def download_csv():
    """Return the path of the predictions CSV written by ``plot_probabilities``."""
    csv_path = 'Cell_types_predictions.csv'
    return csv_path
# Gradio UI: upload -> slice viewer + atlas gallery -> run mapping -> results.
with gr.Blocks() as demo:
    gr.Markdown("# Map My Sections")
    gr.Markdown("### Step 1: Upload your CCF registered data")
    nifti_file = gr.File(label="File Upload")
    gr.Examples(
        examples=example_files,
        inputs=nifti_file,
        label="Try one of our example samples",
    )

    # Hidden until a volume is loaded (load_nifti returns visible=True for it).
    with gr.Row(visible=False) as slice_row:
        with gr.Column(scale=2):
            gr.Markdown("### Step 2: Visualizing your uploaded sample")
            image_display = gr.Image()
            slice_slider = gr.Slider(
                minimum=0,
                maximum=0,
                value=0,
                step=1,
                label="Browse Slices",
                visible=False,
            )
        with gr.Column(scale=1):
            gr.Markdown("### Step 3: Visualizing Allen Brain Cell Types Atlas")
            gallery = gr.Gallery(label="ABC Atlas")
            gr.Markdown("**Step 4: Run cell type mapping**")
            run_button = gr.Button("Run Mapping")

    # Hidden until mapping has been run (run_mapping returns visible=True for it).
    with gr.Column(visible=False) as plot_row:
        gr.Markdown("### Step 5: Quantitative results of the mapping model.")
        # scroll_to_output is an event-listener option, not a BarPlot argument;
        # it is passed to run_button.click below instead.
        prob_plot = gr.BarPlot(
            prob_df,
            x="labels",
            y="values",
            title="Cell Type Probabilities",
            x_label_angle=-90,
        )
        gr.Markdown("### Step 6: Download Results.")
        # No initial file: the CSV only exists once plot_probabilities has
        # written it. The path is filled in after every run/slice change so the
        # button always serves the latest results.
        download_button = gr.DownloadButton(label="Download Results", value=None)

    nifti_file.change(
        load_nifti,
        inputs=nifti_file,
        outputs=[image_display, slice_slider, gallery, slice_row, plot_row],
    )
    slice_slider.change(
        update_slice,
        inputs=slice_slider,
        outputs=[image_display, prob_plot, gallery],
    ).then(download_csv, outputs=download_button)
    run_button.click(
        run_mapping,
        outputs=[prob_plot, plot_row],
        scroll_to_output=True,
    ).then(download_csv, outputs=download_button)
# Start the server only when run as a script (e.g. `python app.py`), so that
# importing the module — for tests or Gradio's reload mode — has no side effects.
if __name__ == "__main__":
    demo.launch()