Spaces:
Running
Running
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import onnxruntime as rt | |
| from PIL import Image | |
# Path to the exported ONNX segmentation model, loaded once at module import.
MODEL_PATH = "model.onnx"
# Optional folder of example images offered in the gallery.
EXAMPLES_DIR = Path("examples")
# Working resolution handed to PIL's resize()/new(), i.e. (width, height).
IMAGE_SIZE = (128, 128)

# Sorted so the gallery order is stable across runs; empty when the folder
# is missing (the gallery is then not rendered at all).
example_images = (
    sorted(EXAMPLES_DIR.glob("*.jpg")) if EXAMPLES_DIR.is_dir() else []
)
# Build the ONNX Runtime session once at import time; every request reuses it.
try:
    sess_options = rt.SessionOptions()
    # Cap both of ONNX Runtime's thread pools at 2 threads so inference does
    # not claim every core of the host.
    sess_options.intra_op_num_threads = 2
    sess_options.inter_op_num_threads = 2
    session = rt.InferenceSession(
        MODEL_PATH, sess_options=sess_options, providers=["CPUExecutionProvider"]
    )
    # Name of the model's first input and the names of all of its outputs,
    # used as the feed key / fetch list in session.run().
    input_name = session.get_inputs()[0].name
    output_names = [output.name for output in session.get_outputs()]
except Exception as e:
    # Top-level boundary: fail fast, but chain the original error so the
    # underlying onnxruntime traceback is reported as the direct cause.
    raise RuntimeError(f"Failed to load ONNX model: {e}") from e
def normalize_mask(mask: np.ndarray) -> np.ndarray:
    """Linearly rescale ``mask`` so its minimum maps to 0 and its maximum to 1.

    When the values have no spread (the maximum is not greater than the
    minimum, e.g. a constant mask), an all-zeros array with the same shape
    and dtype as ``mask`` is returned instead.
    """
    lo, hi = mask.min(), mask.max()
    if not hi > lo:
        return np.zeros_like(mask)
    return (mask - lo) / (hi - lo)
def _mask_overlay_array(
    prob_mask, threshold, color, binary, *, binary_alpha=150, max_alpha=200
):
    """Build the RGBA ``uint8`` overlay array, shaped ``prob_mask.shape + (4,)``."""
    prob = np.asarray(prob_mask)
    active = prob > threshold
    if binary:
        alpha = np.where(active, binary_alpha, 0)
    else:
        # Opacity proportional to probability, so confident pixels show stronger.
        alpha = np.where(active, prob * max_alpha, 0)
    # Saturate before the uint8 cast: out-of-range values would otherwise wrap.
    alpha = np.clip(alpha, 0, 255).astype(np.uint8)

    # Shape comes from the mask itself (numpy rows x cols), not from a PIL
    # (width, height) constant, so non-square masks are laid out correctly.
    overlay = np.zeros(prob.shape + (4,), dtype=np.uint8)
    overlay[..., :3] = color[:3]
    overlay[..., 3] = alpha
    return overlay


def apply_mask(
    base_pil, prob_mask, threshold, color, binary, *, binary_alpha=150, max_alpha=200
):
    """Tint the pixels of ``base_pil`` where ``prob_mask`` exceeds ``threshold``.

    Args:
        base_pil: RGBA PIL image whose (height, width) matches ``prob_mask``
            (``Image.alpha_composite`` requires equal sizes and RGBA mode).
        prob_mask: 2-D array of probabilities; pixels strictly greater than
            ``threshold`` are tinted, all others stay fully transparent.
        threshold: Minimum probability (exclusive) for a pixel to be shown.
        color: ``(r, g, b)`` tint; only the first three entries are used.
        binary: If true, every active pixel gets the fixed ``binary_alpha``
            opacity; otherwise opacity is ``prob * max_alpha``.
        binary_alpha: Opacity (0-255) for active pixels in binary mode.
        max_alpha: Opacity (0-255) reached by a probability of 1.0.

    Returns:
        A new RGBA PIL image with the overlay composited on top of ``base_pil``.
    """
    overlay = _mask_overlay_array(
        prob_mask,
        threshold,
        color,
        binary,
        binary_alpha=binary_alpha,
        max_alpha=max_alpha,
    )
    # A (H, W, 4) uint8 array is interpreted by Pillow as an RGBA image.
    mask_layer = Image.fromarray(overlay)
    return Image.alpha_composite(base_pil, mask_layer)
def get_processed_data(image):
    """Run the ONNX model on ``image`` and package the results for caching.

    Args:
        image: A PIL image, or ``None`` when the input component is cleared.

    Returns:
        ``None`` if ``image`` is ``None``; otherwise a dict with
        ``"masks"``: a ``(spiral_prob, bar_prob)`` tuple of min-max
        normalised arrays, and ``"img_rgba"``: the input resized to
        ``IMAGE_SIZE`` in RGBA mode, used as the compositing base.
    """
    if image is None:
        return None

    # Force three colour channels before anything else. A grayscale ("L"),
    # alpha ("RGBA") or palette ("P") image would otherwise give a tensor with
    # the wrong channel count. Pillow also forces NEAREST when resizing "P".
    img_resized = image.convert("RGB").resize(
        IMAGE_SIZE, resample=Image.Resampling.BICUBIC
    )
    img_rgba = img_resized.convert("RGBA")

    # Scale pixels to [0, 1] and add a leading batch axis.
    img_array = np.asarray(img_resized, dtype=np.float32) / 255.0
    input_tensor = img_array[np.newaxis, ...]

    onnx_pred = session.run(output_names, {input_name: input_tensor})
    # NOTE(review): assumes the first output is channels-last with channel 0 =
    # spiral and channel 1 = bar, e.g. (1, 128, 128, 2) — confirm against model.
    masks = onnx_pred[0][0]

    spiral_prob = normalize_mask(masks[..., 0])
    bar_prob = normalize_mask(masks[..., 1])
    return {"masks": (spiral_prob, bar_prob), "img_rgba": img_rgba}
def update_display(
    data,
    spiral_threshold,
    bar_threshold,
    binary_mask,
    show_image,
    show_spiral,
    show_bar,
):
    """Render the output preview from cached inference results.

    Args:
        data: Dict from ``get_processed_data`` (``"masks"`` and
            ``"img_rgba"``), or ``None`` when there is no input image.
        spiral_threshold / bar_threshold: Probability cut-offs passed to
            ``apply_mask`` for each layer.
        binary_mask: Draw active pixels at a fixed opacity instead of one
            proportional to probability.
        show_image: Use the input image as the base; otherwise an opaque
            black canvas.
        show_spiral / show_bar: Toggle each overlay layer.

    Returns:
        ``None`` when ``data`` is ``None``; otherwise the composited image
        upscaled to 512x512 with nearest-neighbour resampling.
    """
    if data is None:
        return None

    spiral_prob, bar_prob = data["masks"]
    base_rgba = data["img_rgba"]

    canvas = (
        base_rgba if show_image else Image.new("RGBA", IMAGE_SIZE, (0, 0, 0, 255))
    )

    # Layers are composited in order: spiral first, then bar on top of it.
    layers = (
        (show_spiral, spiral_prob, spiral_threshold, (0, 255, 255)),
        (show_bar, bar_prob, bar_threshold, (218, 165, 32)),
    )
    for enabled, prob, cutoff, rgb in layers:
        if enabled:
            canvas = apply_mask(canvas, prob, cutoff, rgb, binary_mask)

    return canvas.resize((512, 512), resample=Image.Resampling.NEAREST)
# --- Gradio Interface ---
with gr.Blocks(title="Galaxy Segmentor", delete_cache=(7200, 7200)) as demo:
    # Per-session cache of inference results, so slider/checkbox changes only
    # re-composite the layers instead of re-running the model.
    cached_data = gr.State(None)

    gr.Markdown("# Galaxy Segmentor")
    gr.Markdown(
        "Upload a galaxy image to automatically segment into spiral arms and bars. Adjust thresholds to filter masks. "
        + "Trained with data from [Galaxy Zoo 3D](https://www.zooniverse.org/projects/klmasters/galaxy-zoo-3d/about/results). "
        + "Used in [this paper](https://arxiv.org/abs/2309.02380)."
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Input Galaxy",
                sources=["upload", "clipboard"],
            )
            with gr.Accordion("Minimum Thresholds", open=True):
                spiral_thresh = gr.Slider(
                    0.0, 1.0, value=0.5, label="Spiral Probability"
                )
                bar_thresh = gr.Slider(0.0, 1.0, value=0.5, label="Bar Probability")

            if example_images:
                example_gallery = gr.Gallery(
                    value=[str(p) for p in example_images],
                    label="Example Galaxies",
                    columns=5,
                    height=128,
                    allow_preview=False,
                    interactive=False,
                    object_fit="contain",
                )

                def handle_select(evt: gr.SelectData):
                    """Load the clicked gallery example into the input component."""
                    # The context manager closes the file handle right away.
                    # convert() returns a fully loaded copy that stays valid
                    # after the file is closed.
                    with Image.open(example_images[evt.index]) as example:
                        return example.convert("RGB")

                example_gallery.select(
                    fn=handle_select,
                    outputs=input_image,
                    show_progress="hidden",
                )

        with gr.Column():
            output_image = gr.Image(label="Output")
            with gr.Accordion("Output Settings", open=True):
                with gr.Row():
                    show_img_check = gr.Checkbox(label="Show Image", value=True)
                    show_spiral_check = gr.Checkbox(label="Show Spiral", value=True)
                    show_bar_check = gr.Checkbox(label="Show Bar", value=True)
                    binary_check = gr.Checkbox(label="Binarize Masks", value=False)

    # Define update logic: argument order must match update_display's signature.
    display_inputs = [
        cached_data,
        spiral_thresh,
        bar_thresh,
        binary_check,
        show_img_check,
        show_spiral_check,
        show_bar_check,
    ]

    # Event: Image changes -> run inference once, cache it, then render.
    input_image.change(
        get_processed_data,
        inputs=input_image,
        outputs=cached_data,
        show_progress="minimal",
    ).then(
        update_display,
        inputs=display_inputs,
        outputs=output_image,
        show_progress="hidden",
    )

    # Event: Settings change -> re-render from the cache only.
    settings_components = [
        spiral_thresh,
        bar_thresh,
        binary_check,
        show_img_check,
        show_spiral_check,
        show_bar_check,
    ]
    gr.on(
        triggers=[c.change for c in settings_components],
        fn=update_display,
        inputs=display_inputs,
        outputs=output_image,
        show_progress="hidden",
        # Collapse rapid slider drags into a single pending re-render.
        trigger_mode="always_last",
    )
if __name__ == "__main__":
    # Turn on Gradio's request queue before serving the app.
    demo.queue()
    # Launch options gathered in one place: iframe width, the upload size
    # limit, and the UI theme.
    launch_options = {
        "width": 1280,
        "max_file_size": "10mb",
        "theme": gr.themes.Base(primary_hue="blue"),
    }
    demo.launch(**launch_options)