Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import nibabel as nib | |
| import numpy as np | |
| import os | |
| from PIL import Image | |
| import pandas as pd | |
| import nrrd | |
| import ants | |
| from natsort import natsorted | |
| from scipy.ndimage import zoom, rotate | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
| import torchvision.transforms as transforms | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from skimage.transform import resize | |
| import cv2 | |
def square_padd(original_data, square_size=(120, 152, 184), order=1):
    """Resample a 3D volume to ``square_size`` and centre it on a filled canvas.

    Each axis is resampled independently with ``scipy.ndimage.zoom`` so the
    volume's shape becomes ``square_size`` (aspect ratio is NOT preserved).
    The result is written into the centre of a canvas of shape ``square_size``
    whose background value is the volume's corner voxel
    ``original_data[0, 0, 0]``.

    Args:
        original_data: 3D array to resample.
        square_size: target (d0, d1, d2) output shape.
        order: spline interpolation order forwarded to ``zoom``.

    Returns:
        float64 array of shape ``square_size``.
    """
    data = np.asarray(original_data)
    target = tuple(int(s) for s in square_size)
    zoom_factors = [t / float(s) for s, t in zip(data.shape, target)]
    resampled = zoom(data, zoom=zoom_factors, order=order)
    # zoom rounds each output axis; crop any one-voxel overshoot so the
    # assignment into the canvas below always fits.
    resampled = resampled[tuple(slice(0, t) for t in target)]

    canvas = np.ones(target) * data[0, 0, 0]
    offsets = [(t - r) // 2 for t, r in zip(target, resampled.shape)]
    region = tuple(slice(o, o + r) for o, r in zip(offsets, resampled.shape))
    canvas[region] = resampled
    return canvas
def square_padding_RGB(single_RGB, square_size=256):
    """Fit an RGB image into a ``square_size`` x ``square_size`` canvas.

    The longer side is scaled to ``square_size`` (aspect ratio kept,
    ``cv2.INTER_AREA``). The resized image is centred on a float canvas with
    3 channels, filled with the mean of the image's top-left 10x10 corner.

    Args:
        single_RGB: image of shape (H, W, 3).
        square_size: side length of the output square.

    Returns:
        float array of shape (square_size, square_size, 3).
    """
    rows, cols = single_RGB.shape[0], single_RGB.shape[1]
    longest = cols if cols > rows else rows
    scale_percent = (square_size / longest) * 100
    new_size = (int(cols * scale_percent / 100), int(rows * scale_percent / 100))
    resized = cv2.resize(single_RGB, new_size, interpolation=cv2.INTER_AREA)

    background = np.mean(single_RGB[:10, :10])
    canvas = np.ones((square_size, square_size, 3)) * background
    out_rows, out_cols = np.shape(resized)[0], np.shape(resized)[1]
    top = int((square_size - out_rows) / 2)
    left = int((square_size - out_cols) / 2)
    canvas[top:top + out_rows, left:left + out_cols, :] = resized
    return canvas
def square_padding(single_gray, square_size=256):
    """Fit a 2D image into a ``square_size`` x ``square_size`` zero canvas.

    The longer side is scaled to ``square_size`` (aspect ratio kept,
    ``cv2.INTER_AREA``) and the result is centred. Inputs with more than two
    dimensions are delegated, using their first three channels, to
    ``square_padding_RGB`` with the same ``square_size``.

    Args:
        single_gray: 2D image, or an image with a trailing channel axis.
        square_size: side length of the output square.

    Returns:
        (square_size, square_size) float array for 2D input; otherwise the
        result of ``square_padding_RGB``.
    """
    if len(np.shape(single_gray)) > 2:
        # Pass square_size through so multi-channel inputs honour the
        # requested size too.
        return square_padding_RGB(single_gray[:, :, :3], square_size)

    rows, cols = single_gray.shape[0], single_gray.shape[1]
    longest = cols if cols > rows else rows
    scale_percent = (square_size / longest) * 100
    dim = (int(cols * scale_percent / 100), int(rows * scale_percent / 100))
    resized = cv2.resize(single_gray, dim, interpolation=cv2.INTER_AREA)

    # Zero background. Do not index into the image for a fill value here: a
    # fixed offset such as [-20, -20] fails on images smaller than 20 px.
    canvas = np.zeros((square_size, square_size))
    out_rows, out_cols = np.shape(resized)[0], np.shape(resized)[1]
    top = int((square_size - out_rows) / 2)
    left = int((square_size - out_cols) / 2)
    canvas[top:top + out_rows, left:left + out_cols] = resized
    return canvas
def affine_reg(fixed_image, moving_image, gauss_param=100):
    """Affinely register ``moving_image`` onto ``fixed_image`` with ANTsPy.

    Works for 2D or 3D ANTs images. Returns the result dict of
    ``ants.registration``; callers in this file read its ``'warpedmovout'``
    and ``'fwdtransforms'`` entries.

    NOTE(review): ``gauss_param`` is forwarded (x4) as ``reg_iterations``.
    ANTsPy documents ``aff_iterations`` as the iteration control for the
    affine stage — confirm whether ``reg_iterations`` has any effect here.
    """
    # TODO: optionally save the transforms / displacement fields to disk so
    # they can be re-applied later.
    iterations = (gauss_param,) * 4
    result = ants.registration(
        fixed=fixed_image,
        moving=moving_image,
        type_of_transform='Affine',
        reg_iterations=iterations,
    )
    print('affine registration completed')
    return result
def nonrigid_reg(fixed_image, mytx, type_of_transform='SyN', grad_step=0.25,
                 reg_iterations=(50, 50, 50), flow_sigma=9, total_sigma=0.2):
    """Deformably register an affine result onto ``fixed_image`` with ANTsPy.

    The moving image is ``mytx['warpedmovout']`` (the output of
    ``affine_reg``). Only 'SyN' and 'SyNRA' are accepted; any other
    ``type_of_transform`` raises ``KeyError``. ``grad_step``,
    ``reg_iterations``, ``flow_sigma`` and ``total_sigma`` are passed to
    ``ants.registration`` unchanged. Transform types:
    https://antspy.readthedocs.io/en/latest/registration.html

    Returns:
        The result dict of ``ants.registration``.
    """
    # TODO: support the extra parameters of other transform types, and run
    # affine + non-rigid in one call for SyNRA.
    register = ants.registration
    moving = mytx['warpedmovout']
    if type_of_transform not in {'SyN', 'SyNRA'}:
        raise KeyError(type_of_transform)
    result = register(
        fixed=fixed_image,
        moving=moving,
        type_of_transform=type_of_transform,
        grad_step=grad_step,
        reg_iterations=reg_iterations,
        flow_sigma=flow_sigma,
        total_sigma=total_sigma,
    )
    print('non-rigid registration completed')
    return result
# An exact second copy of `affine_reg` was defined here. Redefining it
# silently shadowed the earlier definition in this file, so edits made there
# had no effect. The duplicate has been removed; `affine_reg` is defined once,
# above.
# An exact second copy of `nonrigid_reg` was defined here. Redefining it
# silently shadowed the earlier definition in this file, so edits made there
# had no effect. The duplicate has been removed; `nonrigid_reg` is defined
# once, above.
def run_3D_registration(user_section):
    """Register the CCF average template to a 3D volume and build overlays.

    Steps:
      1. Resample ``user_section`` to (60, 76, 92) and the CCF average
         template (``allen_template_ccf``) to the same shape (``square_padd``).
      2. Register template -> user volume: affine, then SyN.
      3. Warp the template once with both transforms.
      4. For each natsorted gallery image except the last 10, recolour it with
         slice ``i`` (axis 0) of the warped template via ``reconvert_to_rgb``.
         Each result is saved under ``Overlaped_registered/`` with the gallery
         file name.

    Args:
        user_section: 3D volume (e.g. the loaded NIfTI data).

    Returns:
        List of paths of the overlay images written, in gallery order.
    """
    user_section = square_padd(user_section, (60, 76, 92))
    template_section = square_padd(allen_template_ccf, user_section.shape)
    fixed_image = ants.from_numpy(user_section)
    moving_image = ants.from_numpy(template_section)
    mytx = affine_reg(fixed_image, moving_image)
    mytx_non_rigid = nonrigid_reg(fixed_image, mytx)

    # The warp does not depend on the gallery image, so compute it once
    # rather than once per overlay.
    affined_template = ants.apply_transforms(fixed=fixed_image,
                                             moving=moving_image,
                                             transformlist=mytx['fwdtransforms'],
                                             interpolator='nearestNeighbor')
    warped_template = ants.apply_transforms(fixed=fixed_image,
                                            moving=affined_template,
                                            transformlist=mytx_non_rigid['fwdtransforms'],
                                            interpolator='nearestNeighbor').numpy()

    gallery_images = natsorted(load_gallery_images())
    os.makedirs("Overlaped_registered", exist_ok=True)

    # Gallery image i is paired with slice i of the warped volume; the last 10
    # gallery images are skipped and we never index past the volume's first axis.
    n_overlays = max(min(len(gallery_images) - 10, warped_template.shape[0]), 0)
    transformed_images = []
    for i, gallery_path in enumerate(gallery_images[:n_overlays]):
        im = plt.imread(gallery_path)
        fname = os.path.split(gallery_path)[-1]
        reconverted_img = reconvert_to_rgb(im[:, :, :3], warped_template[i, :, :])
        out_path = f'Overlaped_registered/{fname}'
        plt.imsave(out_path, (reconverted_img * 255).astype(np.uint8))
        transformed_images.append(out_path)
    # Return the registered overlays (not the unregistered gallery).
    return transformed_images
def run_2D_registration(user_section, slice_idx):
    """Register a CCF template slice to a 2D section and overlay the warp.

    ``allen_template_ccf[slice_idx]`` and ``user_section`` are each padded to
    256x256 with ``square_padding``, and the template is registered to the
    section (affine, then SyN). Both transforms are applied to the grayscale
    version of gallery image ``gallery_selected_data`` (natsorted gallery).
    That image is recoloured with ``reconvert_to_rgb`` and written to
    ``Overlaped_registered/registered_slice_reconverted_1.png``.

    Args:
        user_section: 2D image array.
        slice_idx: index along axis 0 of the CCF volumes.

    Returns:
        One-element list containing the path of the written PNG.

    Raises:
        ValueError: if ``slice_idx`` or the global ``gallery_selected_data``
            has not been set.
    """
    if slice_idx is None:
        raise ValueError("No atlas slice index available; upload a 2D image first.")
    if gallery_selected_data is None:
        raise ValueError("No gallery image selected; upload a 2D image or pick a gallery image first.")

    template_section = allen_template_ccf[slice_idx, :, :]
    # The annotation slice is only reported here; it is not warped.
    print(np.shape(allen_atlas_ccf[slice_idx, :, :]), np.shape(template_section))

    fixed_image = ants.from_numpy(square_padding(user_section))
    moving_image = ants.from_numpy(square_padding(template_section))
    mytx = affine_reg(fixed_image, moving_image)
    mytx_non_rigid = nonrigid_reg(fixed_image, mytx)

    gallery_imgs = natsorted(load_gallery_images())
    im = plt.imread(gallery_imgs[gallery_selected_data])
    print(im.shape)
    moving_gallery_img = ants.from_numpy(square_padding(gray_scale(im)))
    affined = ants.apply_transforms(fixed=fixed_image,
                                    moving=moving_gallery_img,
                                    transformlist=mytx['fwdtransforms'],
                                    interpolator='nearestNeighbor')
    warped = ants.apply_transforms(fixed=fixed_image,
                                   moving=affined,
                                   transformlist=mytx_non_rigid['fwdtransforms'],
                                   interpolator='nearestNeighbor')

    os.makedirs("Overlaped_registered", exist_ok=True)
    print("Reconverting Image")
    reconverted_img = reconvert_to_rgb(im[:, :, :3], warped.numpy())
    out_path = 'Overlaped_registered/registered_slice_reconverted_1.png'
    plt.imsave(out_path, (reconverted_img * 255).astype(np.uint8))
    return [out_path]
def reconvert_to_rgb(img_rgb, img_gray_processed):
    """Re-apply the colours of ``img_rgb`` to a processed grayscale image.

    ``img_rgb`` is resized (skimage ``resize``, range preserved) to the first
    two dimensions of ``img_gray_processed``. Every pixel's channels are then
    multiplied by ``processed / mean_over_channels(resized_rgb)``, so the
    channel mean follows the processed grayscale value.

    The result is not clipped and may fall outside the input's value range.
    """
    height, width = img_gray_processed.shape[0], img_gray_processed.shape[1]
    rgb = resize(img_rgb, (height, width), preserve_range=True)
    # The tiny epsilon keeps the division finite on pure-black pixels.
    channel_mean = np.mean(rgb, axis=2) + 1e-8
    gain = img_gray_processed / channel_mean
    return rgb * gain[..., None]
# Slice encoders cached per device string, so the pretrained ResNet-18 is
# built (and its weights loaded) once per process instead of on every call.
_SLICE_ENCODER_CACHE = {}


def _build_slice_encoder():
    """Build a ResNet-18 feature extractor (final FC layer removed): (B, 3, H, W) -> (B, 512)."""

    class SliceEncoder(nn.Module):
        def __init__(self):
            super(SliceEncoder, self).__init__()
            base = models.resnet18(pretrained=True)
            # Keep everything up to the global average pool; drop the classifier.
            self.backbone = nn.Sequential(*list(base.children())[:-1])

        def forward(self, x):
            x = self.backbone(x)  # (B, 512, 1, 1)
            return x.view(x.size(0), -1)  # (B, 512)

    return SliceEncoder()


def _get_slice_encoder(device):
    """Return the cached eval-mode slice encoder for ``device``, building it on first use."""
    key = str(device)
    encoder = _SLICE_ENCODER_CACHE.get(key)
    if encoder is None:
        encoder = _build_slice_encoder().to(device).eval()
        _SLICE_ENCODER_CACHE[key] = encoder
    return encoder


def embeddings_classifier(user_section, atlas_embeddings, atlas_labels):
    """Return the label of the atlas embedding most cosine-similar to ``user_section``.

    Args:
        user_section: 2D image. It is multiplied by 255 and cast to uint8, so
            values are expected in [0, 1].
        atlas_embeddings: 2D array with one precomputed embedding per row.
            NOTE(review): presumably produced by this same encoder — confirm.
        atlas_labels: sequence whose i-th entry labels row i of
            ``atlas_embeddings``.

    Returns:
        The best-matching label as an int.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder = _get_slice_encoder(device)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    img = Image.fromarray((user_section * 255).astype(np.uint8)).convert('RGB')
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = encoder(img_tensor)
    query_emb = embedding.cpu().numpy().flatten().reshape(1, -1)

    sims = cosine_similarity(query_emb, atlas_embeddings)[0]
    return int(atlas_labels[np.argmax(sims)])
def gray_scale(image):
    """Convert an image with a channel axis to a single-channel grayscale image.

    2D input is returned unchanged (the same object). For input with more than
    two dimensions, only the first three channels are considered:
      - one or two channels: the first channel is returned (e.g. gray+alpha);
      - three channels: converted with ``cv2.COLOR_RGB2GRAY``. float64 data
        is cast to float32 first, because ``cv2.cvtColor`` accepts only
        8-bit, 16-bit unsigned or 32-bit float images.
    """
    if len(np.shape(image)) <= 2:
        return image
    rgb = image[:, :, :3]
    if rgb.shape[2] < 3:
        return rgb[:, :, 0]
    if rgb.dtype == np.float64:
        rgb = rgb.astype(np.float32)
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
def atlas_slice_prediction(user_section, axis='coronal'):
    """Predict the atlas slice index that best matches a 2D section.

    Preprocessing: grayscale, pad to 256, grayscale again, pad to 224, then
    min-max normalise to [0, 1]. A constant image (no contrast) becomes all
    zeros instead of 0/0 NaNs. The section is then compared, via
    ``embeddings_classifier``, against the embeddings and labels stored in
    ``registration/atlas_embeddings_{axis}.npy`` and
    ``registration/atlas_labels_{axis}.npy``.

    Args:
        user_section: 2D or RGB(A) image array.
        axis: suffix selecting which embedding/label files to load.

    Returns:
        The predicted label as an int.
    """
    section = gray_scale(square_padding(gray_scale(user_section)))
    section = gray_scale(section)
    section = square_padding(section, 224)
    lo, hi = np.min(section), np.max(section)
    span = hi - lo
    if span > 0:
        section = (section - lo) / span
    else:
        section = np.zeros_like(section)
    print("Loading model")
    atlas_embeddings = np.load(f"registration/atlas_embeddings_{axis}.npy")
    atlas_labels = np.load(f"registration/atlas_labels_{axis}.npy")
    return embeddings_classifier(section, atlas_embeddings, atlas_labels)
# Example inputs for gr.Examples: [file path, sample type, data type].
example_files = [
    ["./resampled_green_25.nii.gz", "CCF registered Sample", "3D"],
    ["./Brain_1.png", "Custom Sample", "2D"],
    # ["examples/sample3.nii.gz"]
]
# Global variables
# Module-level UI state shared between the Gradio callbacks.
coronal_slices = []  # 2D slices of the loaded volume/image (set by load_nifti_or_png)
last_probabilities = []  # last list from generate_random_probabilities (set by run_mapping)
prob_df = pd.DataFrame()  # initial (empty) data for the probability BarPlot
vol = None  # loaded volume or image array (set by load_nifti_or_png)
slice_idx = None  # predicted atlas slice for 2D uploads (set by load_nifti_or_png)
# Target cell types
# Labels of the probability bar plot / CSV, one per probability entry.
cell_types = [
    "ABC.NN", "Astro.TE.NN", "CLA.EPd.CTX.Car3.Glut", "Endo.NN", "L2.3.IT.CTX.Glut",
    "L4.5.IT.CTX.Glut", "L5.ET.CTX.Glut", "L5.IT.CTX.Glut", "L5.NP.CTX.Glut", "L6.CT.CTX.Glut",
    "L6.IT.CTX.Glut", "L6b.CTX.Glut", "Lamp5.Gaba", "Lamp5.Lhx6.Gaba", "Lymphoid.NN", "Microglia.NN",
    "OPC.NN", "Oligo.NN", "Peri.NN", "Pvalb.Gaba", "Pvalb.chandelier.Gaba", "SMC.NN", "Sncg.Gaba",
    "Sst.Chodl.Gaba", "Sst.Gaba", "VLMC.NN", "Vip.Gaba"
]
# Axis-0 slice indices of the loaded volume shown in 3D mode; also used to map
# a predicted atlas index to its nearest gallery position.
actual_ids = [30,52,71,91,104,109,118,126,131,137,141,164,178,182,197,208,218,226,232,242,244,248,256,262,270,282,293,297,308,323,339,344,350,355,364,372,379,389,395,401,410,415,418,424,429,434,440,444,469,479,487,509]
# Only its length is used (bound on gallery selection indices).
gallery_ids = [5,6,8,9,10,11,12,13,14,15,16,17,18,19,24,25,26,27,28,29,30,31,32,33,35,36,37,38,39,40,42,43,44,45,46,47,48,49,50,51,52,54,55,56,57,58,59,60,61,62,64,66,67]
# gallery_ids.reverse()
# Allen CCF annotation labels and average template, loaded once at import.
# NOTE(review): the "_25" suffix presumably means 25 um resolution — confirm.
allen_atlas_ccf, header = nrrd.read('./registration/annotation_25.nrrd')
allen_template_ccf, _ = nrrd.read("./registration/average_template_25.nrrd")
# colored_atlas,_ = nrrd.read('./registration/colored_atlas_turbo.nrrd')
# Index into the natsorted gallery used by run_2D_registration.
gallery_selected_data = None
def load_nifti_or_png(file, sample_type, data_type):
    """Load an uploaded NIfTI volume or 2D image and refresh the viewer widgets.

    NIfTI (.nii / .nii.gz): the data is stored in the global ``vol`` and split
    along axis 0 into ``coronal_slices``.
      - data_type "2D": the middle slice is shown and the slider hidden.
      - Otherwise: only the slices at ``actual_ids`` are kept, and the slider
        is shown over them.

    Any other file is opened with PIL as 8-bit grayscale. Its atlas slice is
    predicted with ``atlas_slice_prediction`` and stored in ``slice_idx``.
    The matching gallery position is stored in ``gallery_selected_data`` and
    selected in the gallery.

    Args:
        file: the upload — a path (str / os.PathLike) or an object whose
            ``.name`` is the path; None when the upload is cleared.
        sample_type: "CCF registered Sample" or "Custom Sample"; the
            registration button is shown only for "Custom Sample".
        data_type: "2D" or "3D".

    Returns:
        6-tuple for [image_display, slice_slider, gallery, slice_row,
        plot_row, reg_button].
    """
    global coronal_slices, vol, slice_idx, gallery_selected_data

    if file is None:
        # Upload cleared: leave every output unchanged instead of crashing.
        return tuple(gr.update() for _ in range(6))

    path = os.fspath(file) if isinstance(file, (str, os.PathLike)) else file.name

    def to_display_image(slice_2d):
        """Scale a slice by its maximum to 0-255 uint8; an all-zero slice stays black (no 0/0)."""
        peak = np.max(slice_2d)
        scaled = slice_2d / peak if peak else np.zeros_like(slice_2d, dtype=float)
        return Image.fromarray((scaled * 255).astype(np.uint8))

    if path.endswith((".nii", ".nii.gz")):
        vol = nib.load(path).get_fdata()
        coronal_slices = [vol[i, :, :] for i in range(vol.shape[0])]
        if data_type == "2D":
            mid_index = vol.shape[0] // 2
            return (
                to_display_image(coronal_slices[mid_index]),
                gr.update(visible=False),
                load_gallery_images(),
                gr.update(visible=True),
                gr.update(visible=True),
                gr.update(visible=(sample_type == "Custom Sample")),
            )
        # 3D: only the slices listed in actual_ids are browsable.
        coronal_slices = [vol[i, :, :] for i in actual_ids]
        idx = len(actual_ids) // 2
        return (
            to_display_image(coronal_slices[idx]),
            gr.update(visible=True, minimum=0, maximum=len(coronal_slices) - 1, value=idx),
            natsorted(load_gallery_images()),
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=(sample_type == "Custom Sample")),
        )

    # Close the source file as soon as the grayscale copy exists.
    with Image.open(path) as source:
        img = source.convert("L")
    vol = np.array(img)
    coronal_slices = [np.array(img)]
    idx = atlas_slice_prediction(np.array(img))
    slice_idx = idx
    # Snap the predicted atlas index to the nearest entry of actual_ids.
    closest_actual_idx = min(actual_ids, key=lambda x: abs(x - idx))
    gallery_index = actual_ids.index(closest_actual_idx)
    print(gallery_index, len(actual_ids) - (gallery_index))
    # NOTE(review): the gallery position is counted from the end of
    # actual_ids, i.e. the gallery appears to be in reverse order — confirm.
    gallery_selected_data = len(actual_ids) - gallery_index
    return (
        img,
        gr.update(visible=False),
        gr.update(selected_index=gallery_selected_data if gallery_index < len(gallery_ids) else 0, visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=(sample_type == "Custom Sample")),
    )
def update_slice(index):
    """Show slice ``index`` of the loaded data and refresh the plot and gallery.

    Returns ``(None, None, None)`` when nothing is loaded. Otherwise returns:
      - the slice scaled to 0-255 (an all-zero slice stays black instead of
        dividing by zero);
      - the probability DataFrame from ``plot_probabilities``. When a mapping
        has been run, small Gaussian noise (sd 0.01) is added to
        ``last_probabilities``, the values are clipped at 0 and renormalised;
      - a gallery update selecting ``len(gallery_ids) - index`` (0 when
        ``index`` is out of range).

    Outputs: [image_display, prob_plot, gallery].
    """
    if not coronal_slices:
        return None, None, None
    # Accept whole-number floats from the slider as well as ints.
    index = int(index)
    current = coronal_slices[index]
    peak = np.max(current)
    scaled = current / peak if peak else np.zeros_like(current, dtype=float)
    slice_img = Image.fromarray((scaled * 255).astype(np.uint8))
    gallery_selection = gr.update(selected_index=len(gallery_ids) - index if index < len(gallery_ids) else 0)
    if last_probabilities:
        noise = np.random.normal(0, 0.01, size=len(last_probabilities))
        new_probs = np.clip(np.array(last_probabilities) + noise, 0, None)
        new_probs /= new_probs.sum()
    else:
        new_probs = []
    return slice_img, plot_probabilities(new_probs), gallery_selection
def load_gallery_images():
    """List the image files in ``Overlapped_updated``, in sorted file-name order.

    Only names ending in .png/.jpg/.jpeg (case-insensitive) are kept. Each
    entry is ``os.path.join("Overlapped_updated", name)``. Returns an empty
    list when the folder does not exist.
    """
    gallery_dir = "Overlapped_updated"
    if not os.path.exists(gallery_dir):
        return []
    image_suffixes = ('.png', '.jpg', '.jpeg')
    return [
        os.path.join(gallery_dir, name)
        for name in sorted(os.listdir(gallery_dir))
        if name.lower().endswith(image_suffixes)
    ]
def generate_random_probabilities():
    """Return a random list of floats summing to 1, one entry per cell type.

    Values are drawn from the global NumPy RNG. Five distinct, randomly chosen
    entries are first replaced by values uniform in [0, 0.01), then the whole
    vector is normalised.
    """
    n_types = len(cell_types)
    weights = np.random.rand(n_types)
    suppressed = np.random.choice(n_types, size=5, replace=False)
    # One scalar draw per suppressed index, in order (same RNG stream as a loop).
    weights[suppressed] = [np.random.rand() * 0.01 for _ in suppressed]
    weights = weights / weights.sum()
    return weights.tolist()
def plot_probabilities(probabilities):
    """Build the bar-plot DataFrame and save it as CSV.

    Returns None for an empty ``probabilities``. Otherwise builds a DataFrame
    with columns ``labels`` (``cell_types``) and ``values``
    (``probabilities``), writes it to ``outputs/Cell_types_predictions.csv``
    (creating ``outputs/`` if needed) and returns it.
    """
    if len(probabilities) == 0:
        return None
    table = pd.DataFrame(dict(labels=cell_types, values=probabilities))
    os.makedirs("outputs", exist_ok=True)
    table.to_csv('outputs/Cell_types_predictions.csv', index=False)
    return table
def run_mapping():
    """Produce random cell-type probabilities and reveal the result widgets.

    Stores the probabilities in the global ``last_probabilities``.
    Returns updates for [prob_plot, plot_row, download_button, download_step].
    """
    global last_probabilities
    probabilities = generate_random_probabilities()
    last_probabilities = probabilities
    plot_data = plot_probabilities(probabilities)
    return (
        plot_data,
        gr.update(visible=True),
        gr.update(value='outputs/Cell_types_predictions.csv', visible=True),
        gr.update(visible=True),
    )
def run_registration(data_type, selected_idx=None):
    """Register the loaded data and return the overlay image paths for the gallery.

    Args:
        data_type: "3D" runs ``run_3D_registration(vol)``; any other value runs
            ``run_2D_registration(vol, slice_idx)``.
        selected_idx: unused. It defaults to None because the registration
            button passes only ``data_type``; a required second parameter made
            every click fail.

    Returns:
        List of image paths produced by the chosen registration.

    Raises:
        ValueError: if no data has been loaded yet (``vol`` is None).
    """
    if vol is None:
        raise ValueError("No data loaded; upload a sample before running registration.")
    print("Running registration logic here..., Vol shape::", vol.shape)
    if data_type == "3D":
        return run_3D_registration(vol)
    return run_2D_registration(vol, slice_idx)
def download_csv():
    """Return the path of the predictions CSV that plot_probabilities writes."""
    csv_path = 'outputs/Cell_types_predictions.csv'
    return csv_path
def handle_data_type_change(dt):
    """Return the slice-slider update for a data-type change.

    "2D" hides the slider. Anything else shows it over 0..len(actual_ids)-1,
    positioned at the middle entry.
    """
    if dt != "2D":
        last_index = len(actual_ids) - 1
        middle = len(actual_ids) // 2
        return gr.update(visible=True, minimum=0, maximum=last_index, value=middle)
    return gr.update(visible=False)
def on_select(evt: gr.SelectData):
    """Remember which gallery image the user clicked.

    Stores ``evt.index`` in the global ``gallery_selected_data``, which
    ``run_2D_registration`` uses to pick the gallery image to warp. Without
    the ``global`` declaration the assignment only created a local variable,
    and the click was lost.
    """
    global gallery_selected_data
    print("Selected index:", evt)
    print("Selected value:", evt.value)
    print("Selected coordinates:", evt.selected)
    gallery_selected_data = evt.index
# Initial gallery contents (natural sort so numbered file names appear in order).
gallery_images = natsorted(load_gallery_images())
# Gradio UI. Components render in creation order inside their enclosing
# Row/Column contexts, so statement order here defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("# Map My Sections\n### This GUI is part of the submission to the Allen Institute's Map My Sections tool by Tibbling Technologies.")
    with gr.Row():
        gr.Markdown("### Step 1: Upload your sample, currently only .nii.gz (3D) and .png (2D) supported")
        gr.Markdown("### Step 2: Select your sample and data type.")
    with gr.Row():
        nifti_file = gr.File(label="File Upload")
        with gr.Column():
            sample_type = gr.Dropdown(choices=["CCF registered Sample", "Custom Sample"], value="CCF registered Sample", label="Sample Type")
            data_type = gr.Radio(choices=["2D", "3D"], value="3D", label="Data Type")
    # Each example row fills [file, sample type, data type].
    gr.Examples(examples=example_files, inputs=[nifti_file, sample_type, data_type], label="Try one of our example samples")
    # Viewer row: hidden until a file is loaded (see load_nifti_or_png).
    with gr.Row(visible=False) as slice_row:
        with gr.Column(scale=1):
            gr.Markdown("### Step 3: Visualizing your uploaded sample")
            image_display = gr.Image(height=450)
            slice_slider = gr.Slider(minimum=0, maximum=0, value=0, step=1, label="Slices", visible=False)
        with gr.Column(scale=1):
            gr.Markdown("### Step 4: Visualizing Allen Brain Cell Types Atlas")
            gallery = gr.Gallery(label="ABC Atlas", value = gallery_images,columns = 5, height = 450)
    gr.Markdown("### Step 5: Run cell type mapping and/or registeration. ")
    with gr.Row():
        run_button = gr.Button("Map My Sections")
        # Shown only for "Custom Sample" uploads.
        reg_button = gr.Button("Run Registration (Optional)", visible=False)
    # Results area: revealed by run_mapping.
    with gr.Column(visible=False) as plot_row:
        gr.Markdown("### Step 6: Quantitative results of the mapping model.")
        prob_plot = gr.BarPlot(prob_df, x="labels", y="values", title="Cell Type Probabilities", scroll_to_output=True, x_label_angle=-90, height=400)
        download_step = gr.Markdown("### Step 7: Download Results.", visible = False)
        download_button = gr.DownloadButton(label="Download Results", visible = False)
    # Event wiring.
    nifti_file.change(
        load_nifti_or_png,
        inputs=[nifti_file, sample_type, data_type],
        outputs=[image_display, slice_slider, gallery, slice_row, plot_row, reg_button]
    )
    sample_type.change(
        lambda s: (gr.update(visible=True), gr.update(visible=(s == "Custom Sample"))),
        inputs=sample_type,
        outputs=[slice_row, reg_button]
    )
    data_type.change(
        handle_data_type_change,
        inputs=data_type,
        outputs=slice_slider
    )
    gallery.select(on_select, inputs=None, outputs=None)
    slice_slider.change(update_slice, inputs=slice_slider, outputs=[image_display, prob_plot, gallery])
    run_button.click(run_mapping, outputs=[prob_plot, plot_row, download_button, download_step])
    # NOTE(review): only data_type is passed, so run_registration is invoked
    # with a single argument — its selected_idx parameter needs a default.
    reg_button.click(run_registration,inputs = [data_type], outputs=[gallery])
demo.launch()