Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from viscy.translation.engine import VSUNet | |
| from huggingface_hub import hf_hub_download | |
| from numpy.typing import ArrayLike | |
| import numpy as np | |
| from skimage import exposure | |
| from skimage.transform import resize | |
| from skimage.util import invert | |
| import cmap | |
class VSGradio:
    """Gradio-facing wrapper around a VisCy ``VSUNet`` virtual-staining model.

    The checkpoint is loaded once at construction time onto CUDA when it is
    available (CPU otherwise) and reused for every ``predict`` call.
    """

    def __init__(self, model_config, model_ckpt_path):
        """Store the configuration and eagerly load the model.

        Args:
            model_config: Architecture settings forwarded to
                ``VSUNet.load_from_checkpoint`` as ``model_config``.
            model_ckpt_path: Local path of the checkpoint file to load.
        """
        self.model_config = model_config
        self.model_ckpt_path = model_ckpt_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the checkpoint as a ``UNeXt2_2D`` model on ``self.device``.

        The model is switched to evaluation mode after loading.

        Raises:
            Exception: Any error raised while loading or moving the model is
                printed and re-raised unchanged.
        """
        try:
            print(f"Loading model from checkpoint: {self.model_ckpt_path}")
            self.model = VSUNet.load_from_checkpoint(
                self.model_ckpt_path,
                architecture="UNeXt2_2D",
                model_config=self.model_config,
            )
            self.model.to(self.device)
            # Inference only: disable training-time behaviour (e.g. dropout).
            self.model.eval()
            print("Model loaded successfully and set to evaluation mode")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def normalize_fov(self, input: ArrayLike) -> np.ndarray:
        """Normalize the field of view to zero mean and unit variance.

        A constant image (standard deviation of zero, e.g. a blank upload)
        is only mean-centred, so it becomes all zeros.  This avoids NaN/inf
        values that a division by zero would otherwise feed into the model.
        """
        fov = np.asarray(input)
        mean = np.mean(fov)
        std = np.std(fov)
        centered = fov - mean
        if std == 0:
            return centered
        return centered / std

    def preprocess_image_standard(self, input: ArrayLike):
        """Apply adaptive histogram equalization (``exposure.equalize_adapthist``)."""
        input = exposure.equalize_adapthist(input)
        return input

    def downscale_image(self, inp: ArrayLike, scale_factor: float):
        """Resize a 2D image by ``scale_factor`` along both axes.

        Each output dimension is clamped to at least one pixel.

        Raises:
            ValueError: If ``scale_factor`` is not a positive finite number
                (or cannot be parsed as a float).
        """
        factor = float(scale_factor)
        if not (np.isfinite(factor) and factor > 0):
            raise ValueError(
                f"Scale factor must be a positive finite number, got {scale_factor!r}"
            )
        height, width = inp.shape
        new_height = max(1, int(height * factor))
        new_width = max(1, int(width * factor))
        return resize(inp, (new_height, new_width), anti_aliasing=True)

    def predict(self, inp, scaling_factor: float):
        """Virtually stain one label-free image.

        Args:
            inp: Image to stain, expected to be a 2D grayscale array, or
                ``None`` when nothing has been uploaded.
            scaling_factor: Rescaling factor applied before inference.  It may
                be a string (it comes from a text box) and is parsed by
                ``apply_rescale_image``.

        Returns:
            ``(nucleus_rgb, membrane_rgb)`` colour-mapped ``uint8`` images at
            the input's original size.  Returns ``(None, None)`` when ``inp``
            is ``None``.  If any step fails, returns two black 300x300x3 images.
        """
        try:
            if inp is None:
                print("Error: Input image is None")
                return None, None
            # Normalize first and remember the size so the predictions can be
            # resized back to it afterwards.
            inp = self.normalize_fov(inp)
            original_shape = inp.shape
            inp = apply_rescale_image(inp, scaling_factor)
            # float32 tensor with three leading singleton axes:
            # (Y, X) -> (1, 1, 1, Y, X).  Presumably batch/channel/depth as
            # VSUNet expects -- TODO confirm against VisCy.
            source = torch.from_numpy(np.asarray(inp, dtype=np.float32))
            test_dict = dict(
                index=None,
                source=source[None, None, None].to(self.device),
            )
            with torch.inference_mode():
                self.model.on_predict_start()  # Necessary preprocessing for the model
                # Move the output back to the CPU for post-processing.
                pred = self.model.predict_step(test_dict, 0, 0).cpu().numpy()
            # Output channel 0 is the nucleus prediction, channel 1 the membrane.
            nuc_pred = pred[0, 0, 0]
            mem_pred = pred[0, 1, 0]
            # Resize the predictions back to the original image size.
            nuc_pred = resize(nuc_pred, original_shape, anti_aliasing=True)
            mem_pred = resize(mem_pred, original_shape, anti_aliasing=True)
            nuc_rgb = apply_colormap(nuc_pred, cmap.Colormap("green"))
            mem_rgb = apply_colormap(mem_pred, cmap.Colormap("magenta"))
            return nuc_rgb, mem_rgb
        except Exception as e:
            # UI boundary: report the failure and show blank images rather
            # than propagating the exception into the web app.
            print(f"Error during prediction: {e}")
            empty_img = np.zeros((300, 300, 3), dtype=np.uint8)
            return empty_img, empty_img
def apply_colormap(prediction, colormap: cmap.Colormap):
    """Colour a single-channel prediction with ``colormap``.

    The prediction is stretched to the [0, 1] range with
    ``exposure.rescale_intensity``, passed through ``colormap``, and the
    result is scaled by 255 and cast to ``uint8``.
    """
    unit_range = exposure.rescale_intensity(prediction, out_range=(0, 1))
    return (colormap(unit_range) * 255).astype(np.uint8)
def merge_images(nuc_rgb: ArrayLike, mem_rgb: ArrayLike) -> ArrayLike:
    """Merge nucleus and membrane images into a single RGB image.

    The merge is a per-pixel, per-channel maximum of the two images.

    If one image is missing (``None``, e.g. the merge option was toggled
    before any prediction exists), the other is returned unchanged.  If both
    are missing, ``None`` is returned instead of raising.
    """
    if nuc_rgb is None:
        return mem_rgb
    if mem_rgb is None:
        return nuc_rgb
    return np.maximum(nuc_rgb, mem_rgb)
def apply_image_adjustments(image, invert_image: bool, gamma_factor: float):
    """Build the adjusted preview image from the raw upload.

    Args:
        image: Grayscale image from the upload widget, or ``None`` when
            there is no image (e.g. the upload was cleared).
        invert_image: If true, invert the intensities first.
        gamma_factor: Gamma passed to ``exposure.adjust_gamma``.

    Returns:
        The adjusted image, contrast-stretched to 0-255 and cast to
        ``uint8``, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        # Nothing to adjust; avoid passing None into scikit-image.
        return None
    if invert_image:
        image = invert(image, signed_float=False)
    image = exposure.adjust_gamma(image, gamma_factor)
    return exposure.rescale_intensity(image, out_range=(0, 255)).astype(np.uint8)
def apply_rescale_image(image, scaling_factor: float):
    """Resize ``image`` by ``scaling_factor`` along its first two axes.

    Args:
        image: Array whose first two axes (rows, columns) are scaled.
        scaling_factor: Scale factor; strings (from a text box) are parsed
            with ``float``.

    Returns:
        The resized image as produced by ``skimage.transform.resize``.  Each
        scaled dimension is truncated to an integer and clamped to at least
        one pixel.

    Raises:
        ValueError: If ``scaling_factor`` cannot be parsed or is not a
            positive finite number.
    """
    factor = float(scaling_factor)
    if not (np.isfinite(factor) and factor > 0):
        raise ValueError(
            f"Rescaling factor must be a positive finite number, got {scaling_factor!r}"
        )
    new_shape = tuple(max(1, int(dim * factor)) for dim in image.shape[:2])
    return resize(image, new_shape, anti_aliasing=True)
def clear_outputs(image):
    """Pass ``image`` through unchanged and blank two output slots.

    Returns:
        The 3-tuple ``(image, None, None)``.
    """
    cleared_nucleus = cleared_membrane = None
    return image, cleared_nucleus, cleared_membrane
def load_css(file_path):
    """Return the full text of the CSS file at ``file_path``.

    The file is decoded as UTF-8 explicitly, so the result does not depend on
    the platform's default locale encoding.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()
if __name__ == "__main__":
    try:
        print("Downloading model checkpoint...")
        model_ckpt_path = hf_hub_download(
            repo_id="compmicro-czb/VSCyto2D", filename="epoch=399-step=23200.ckpt"
        )
        print(f"Model downloaded successfully to: {model_ckpt_path}")
        # Architecture settings passed to VSUNet; presumably they must match
        # the downloaded checkpoint.
        model_config = {
            "in_channels": 1,
            "out_channels": 2,
            "encoder_blocks": [3, 3, 9, 3],
            "dims": [96, 192, 384, 768],
            "decoder_conv_blocks": 2,
            "stem_kernel_size": [1, 2, 2],
            "in_stack_depth": 1,
            "pretraining": False,
        }
        print("Initializing VSGradio...")
        vsgradio = VSGradio(model_config, model_ckpt_path)
        print(f"VSGradio initialized successfully! Using device: {vsgradio.device}")

        # Initialize the Gradio app using Blocks
        with gr.Blocks(css=load_css("style.css")) as demo:
            # Title and description
            gr.HTML(
                """
                <div style="display: flex; flex-direction: column; justify-content: center; align-items: center; text-align: center;">
                    <a href="https://www.czbiohub.org/sf/" target="_blank">
                        <div style="height: 60px; overflow: hidden; display: flex; align-items: center; justify-content: center;">
                            <img src="https://huggingface.co/spaces/compmicro-czb/VirtualStaining/resolve/main/misc/biohub_logo.png" style="width: 300px; height: auto; object-fit: contain;">
                        </div>
                    </a>
                    <div class='title-block'> Robust virtual staining of landmark organelles with Cytoland </div>
                </div>
                """
            )
            gr.HTML(
                """
                <div class='description-block'>
                    <p><b>Model:</b> VSCyto2D</p>
                    <p><b>Input:</b> label-free image (e.g., QPI or phase contrast).</p>
                    <p><b>Output:</b> Virtual staining of nucleus and membrane.</p>
                    <p><b>Note:</b> The model works well with QPI, and sometimes generalizes to phase contrast and DIC.<br>
                    It was trained primarily on HEK293T, BJ5, and A549 cells imaged at 20x. <br>
                    We continue to diagnose and improve generalization.</p>
                    <p>Check out our paper: <a href='https://doi.org/10.1038/s42256-025-01046-2' target='_blank'><i>Liu et al., Robust virtual staining of landmark organelles with Cytoland</i></a></p>
                    <p> For training your own model and analyzing large amounts of data, use our <a href='https://github.com/mehta-lab/VisCy/tree/main/examples/virtual_staining/dlmbl_exercise' target='_blank'>GitHub repository</a>.</p>
                </div>
                """
            )

            # Layout for input and output images
            with gr.Row():
                input_image = gr.Image(
                    type="numpy", image_mode="L", label="Upload Image"
                )
                adjusted_image = gr.Image(
                    type="numpy",
                    image_mode="L",
                    label="Adjusted Image (Preview)",
                    interactive=False,
                )
                with gr.Column():
                    output_nucleus = gr.Image(
                        type="numpy", image_mode="RGB", label="VS Nucleus"
                    )
                    output_membrane = gr.Image(
                        type="numpy", image_mode="RGB", label="VS Membrane"
                    )
                merged_image = gr.Image(
                    type="numpy",
                    image_mode="RGB",
                    label="Merged Image",
                    visible=False,
                )

            preprocess_invert = gr.Checkbox(label="Invert Image", value=False)
            gamma_factor = gr.Slider(
                label="Adjust Gamma", minimum=0.01, maximum=5.0, value=1.0, step=0.1
            )
            # Free-text rescaling factor for the input image; parsed with
            # float() by apply_rescale_image.
            scaling_factor = gr.Textbox(
                label="Rescaling image factor",
                value="1.0",
                placeholder="Rescaling factor for the input image",
            )
            # Checkbox for merging predictions
            merge_checkbox = gr.Checkbox(
                label="Merge Predictions into one image", value=True
            )

            # Hidden fields that only carry example metadata (see gr.Examples).
            cell_name = gr.Textbox(
                label="Cell Name", placeholder="Cell Type", visible=False
            )
            imaging_modality = gr.Textbox(
                label="Imaging Modality", placeholder="Imaging Modality", visible=False
            )
            references = gr.Textbox(
                label="References", placeholder="References", visible=False
            )

            # Keep the adjusted preview (which is the model input) in sync with
            # every control that affects it.
            for control in (input_image, preprocess_invert, gamma_factor):
                control.change(
                    fn=apply_image_adjustments,
                    inputs=[input_image, preprocess_invert, gamma_factor],
                    outputs=adjusted_image,
                )

            def clear_predictions():
                """Blank every prediction panel when a new image is loaded.

                ``adjusted_image`` is deliberately not touched here: it is owned
                by ``apply_image_adjustments``, and writing it from a second
                listener on the same event could overwrite the adjusted preview
                with the raw upload.
                """
                return None, None, None

            input_image.change(
                fn=clear_predictions,
                inputs=None,
                outputs=[output_nucleus, output_membrane, merged_image],
            )

            # Button to trigger prediction and update the output images
            submit_button = gr.Button(
                "Virtually Stain Image", elem_classes=["submit-button"]
            )

            def submit_and_merge(inp, rescale_factor, merge):
                """Run the model on the adjusted image and lay out the results.

                Returns one update per output component (merged, nucleus,
                membrane) that carries both its value and its visibility.

                Raises:
                    gr.Error: If no image has been uploaded yet.
                """
                if inp is None:
                    raise gr.Error("Please upload an image before running virtual staining.")
                nucleus, membrane = vsgradio.predict(inp, rescale_factor)
                if merge:
                    return (
                        gr.update(value=merge_images(nucleus, membrane), visible=True),
                        gr.update(value=nucleus, visible=False),
                        gr.update(value=membrane, visible=False),
                    )
                return (
                    gr.update(value=None, visible=False),
                    gr.update(value=nucleus, visible=True),
                    gr.update(value=membrane, visible=True),
                )

            submit_button.click(
                fn=submit_and_merge,
                inputs=[adjusted_image, scaling_factor, merge_checkbox],
                outputs=[merged_image, output_nucleus, output_membrane],
            )

            def merge_predictions_fn(nucleus_image, membrane_image, merge):
                """Switch between the merged view and the separate channels.

                Toggling before any prediction exists yields an empty merged
                view instead of an error.
                """
                if merge:
                    have_both = nucleus_image is not None and membrane_image is not None
                    merged = merge_images(nucleus_image, membrane_image) if have_both else None
                    return (
                        gr.update(value=merged, visible=True),
                        gr.update(visible=False),
                        gr.update(visible=False),
                    )
                return (
                    gr.update(value=None, visible=False),
                    gr.update(visible=True),
                    gr.update(visible=True),
                )

            merge_checkbox.change(
                fn=merge_predictions_fn,
                inputs=[output_nucleus, output_membrane, merge_checkbox],
                outputs=[merged_image, output_nucleus, output_membrane],
            )

            # Example images and article
            examples_component = gr.Examples(
                examples=[
                    ["examples/a549.png", "A549", "QPI", 1.0, False, "1.0", "1"],
                    ["examples/hek.png", "HEK293T", "QPI", 1.0, False, "1.0", "1"],
                    ["examples/HEK_PhC.png", "HEK293T", "PhC", 1.2, True, "1.0", "1"],
                    [
                        "examples/livecell_A172.png",
                        "A172",
                        "PhC",
                        1.0,
                        True,
                        "1.0",
                        "2",
                    ],
                    ["examples/ctc_HeLa.png", "HeLa", "DIC", 0.7, False, "0.7", "3"],
                    [
                        "examples/ctc_glioblastoma_astrocytoma_U373.png",
                        "Glioblastoma",
                        "PhC",
                        1.0,
                        True,
                        "2.0",
                        "3",
                    ],
                    [
                        "examples/U2OS_BF.png",
                        "U2OS",
                        "Brightfield",
                        1.0,
                        False,
                        "0.3",
                        "4",
                    ],
                    ["examples/U2OS_QPI.png", "U2OS", "QPI", 1.0, False, "0.3", "4"],
                    [
                        "examples/neuromast2.png",
                        "Zebrafish neuromast",
                        "QPI",
                        0.6,
                        False,
                        "1.2",
                        "1",
                    ],
                    [
                        "examples/mousekidney.png",
                        "Mouse Kidney",
                        "QPI",
                        0.8,
                        False,
                        "0.6",
                        "4",
                    ],
                ],
                inputs=[
                    input_image,
                    cell_name,
                    imaging_modality,
                    gamma_factor,
                    preprocess_invert,
                    scaling_factor,
                    references,
                ],
            )

            # Article or footer information
            gr.HTML(
                """
                <div class='article-block'>
                <li>1. <a href='https://doi.org/10.1038/s42256-025-01046-2' target='_blank'>Liu et al., Robust virtual staining of landmark organelles with Cytoland</a></li>
                <li>2. <a href='https://sartorius-research.github.io/LIVECell/' target='_blank'>Edlund et al., LIVECell - A large-scale dataset for label-free live cell segmentation</a></li>
                <li>3. <a href='https://celltrackingchallenge.net/' target='_blank'>Maska et al., The cell tracking challenge: 10 years of objective benchmarking </a></li>
                <li>4. <a href='https://elifesciences.org/articles/55502' target='_blank'>Guo et al., Revealing architectural order with quantitative label-free imaging and deep learning</a></li>
                </div>
                """
            )

        # Launch the Gradio app
        demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
    except Exception as e:
        # Log a short message, then re-raise so the full traceback is shown and
        # the process exits non-zero instead of appearing to finish cleanly.
        print(f"Error initializing VSGradio: {e}")
        raise