"""Gradio demo that segments galaxy images into spiral-arm and bar masks.

An ONNX segmentation model is loaded once at import time. Each input image is
resized to ``IMAGE_SIZE``, run through the model, and the two output channels
are overlaid on the image as coloured, thresholded masks.
"""

from pathlib import Path

import gradio as gr
import numpy as np
import onnxruntime as rt
from PIL import Image

MODEL_PATH = "model.onnx"
EXAMPLES_DIR = Path("examples")
# PIL (width, height) of the model input.
IMAGE_SIZE = (128, 128)
# Size of the composited image shown in the UI (nearest-neighbour upscale).
OUTPUT_SIZE = (512, 512)

# Example images for the gallery; empty when the folder does not exist.
example_images = sorted(EXAMPLES_DIR.glob("*.jpg")) if EXAMPLES_DIR.exists() else []

try:
    sess_options = rt.SessionOptions()
    # Limit ONNX Runtime to two threads for intra- and inter-op parallelism.
    sess_options.intra_op_num_threads = 2
    sess_options.inter_op_num_threads = 2
    session = rt.InferenceSession(
        MODEL_PATH, sess_options=sess_options, providers=["CPUExecutionProvider"]
    )
    input_name = session.get_inputs()[0].name
    output_names = [output.name for output in session.get_outputs()]
except Exception as e:
    # Fail fast at startup; keep the original error as the cause.
    raise RuntimeError(f"Failed to load ONNX model: {e}") from e


def normalize_mask(mask: np.ndarray) -> np.ndarray:
    """Normalizes mask values to [0, 1] range.

    Min-max scales ``mask`` so its smallest value maps to 0 and its largest to
    1. A constant mask (no spread) becomes all zeros.

    NOTE(review): scaling is per image, so the UI "probability" thresholds are
    relative to each image's value range rather than calibrated model
    probabilities.
    """
    min_val = mask.min()
    max_val = mask.max()
    if max_val > min_val:
        return (mask - min_val) / (max_val - min_val)
    return np.zeros_like(mask)


def apply_mask(
    base_pil: Image.Image,
    prob_mask: np.ndarray,
    threshold: float,
    color: tuple[int, int, int],
    binary: bool,
) -> Image.Image:
    """Overlays ``prob_mask`` on ``base_pil`` as a coloured RGBA layer.

    Args:
        base_pil: RGBA image to draw on; must have the same size as
            ``prob_mask`` (required by ``Image.alpha_composite``).
        prob_mask: 2-D array of values expected in [0, 1].
        threshold: Pixels with ``prob_mask <= threshold`` stay transparent.
        color: RGB colour of the overlay.
        binary: If True every active pixel gets a fixed alpha of 150;
            otherwise alpha scales with the value (``prob_mask * 200``).

    Returns:
        A new RGBA image with the overlay composited on top of ``base_pil``.
    """
    # Take (rows, cols) from the mask itself. IMAGE_SIZE is in PIL
    # (width, height) order and would be transposed for non-square sizes.
    height, width = prob_mask.shape[:2]
    mask_arr = np.zeros((height, width, 4), dtype=np.uint8)
    mask_arr[..., :3] = color
    active_mask = prob_mask > threshold
    if binary:
        mask_arr[..., 3] = np.where(active_mask, 150, 0).astype(np.uint8)
    else:
        alpha = (prob_mask * 200).astype(np.uint8)
        mask_arr[..., 3] = np.where(active_mask, alpha, 0).astype(np.uint8)
    mask_layer = Image.fromarray(mask_arr)
    return Image.alpha_composite(base_pil, mask_layer)


def get_processed_data(image: Image.Image | None) -> dict | None:
    """Runs inference and returns masks plus a pre-resized RGBA image for caching.

    Args:
        image: Input PIL image, or None when the input was cleared.

    Returns:
        None if ``image`` is None. Otherwise a dict with:
        ``"masks"``: a ``(spiral_prob, bar_prob)`` tuple of min-max
        normalized 2-D arrays, and ``"img_rgba"``: the resized input
        converted to RGBA, used as the display base layer.
    """
    if image is None:
        return None

    # Preprocess once
    img_resized = image.resize(IMAGE_SIZE, resample=Image.Resampling.BICUBIC)
    img_rgba = img_resized.convert("RGBA")
    # Force 3 channels so grayscale/RGBA inputs produce the same tensor layout
    # as RGB ones. NOTE(review): assumes the model expects RGB; the example
    # loader below also converts to RGB.
    img_array = np.array(img_resized.convert("RGB")).astype("float32") / 255.0
    input_tensor = np.expand_dims(img_array, axis=0)  # add batch axis

    onnx_pred = session.run(output_names, {input_name: input_tensor})
    masks = onnx_pred[0][0]  # first output, first batch item; Shape: (128, 128, 2)

    # Post-process probabilities
    spiral_prob = normalize_mask(masks[..., 0])
    bar_prob = normalize_mask(masks[..., 1])

    return {"masks": (spiral_prob, bar_prob), "img_rgba": img_rgba}


def update_display(
    data: dict | None,
    spiral_threshold: float,
    bar_threshold: float,
    binary_mask: bool,
    show_image: bool,
    show_spiral: bool,
    show_bar: bool,
) -> Image.Image | None:
    """Composites layers using cached data.

    Builds the base layer from the cached RGBA image (or opaque black when
    ``show_image`` is off). It then overlays the spiral mask in cyan and the
    bar mask in goldenrod when enabled, and upscales the result to
    ``OUTPUT_SIZE``. Returns None when there is no cached data.
    """
    if data is None:
        return None

    spiral_prob, bar_prob = data["masks"]
    img_rgba = data["img_rgba"]

    if show_image:
        base_pil = img_rgba
    else:
        base_pil = Image.new("RGBA", IMAGE_SIZE, (0, 0, 0, 255))

    comp = base_pil
    if show_spiral:
        comp = apply_mask(
            comp, spiral_prob, spiral_threshold, (0, 255, 255), binary_mask
        )
    if show_bar:
        comp = apply_mask(comp, bar_prob, bar_threshold, (218, 165, 32), binary_mask)

    # NEAREST keeps mask edges crisp instead of blurring the 128px result.
    return comp.resize(OUTPUT_SIZE, resample=Image.Resampling.NEAREST)


# --- Gradio Interface ---
with gr.Blocks(title="Galaxy Segmentor", delete_cache=(7200, 7200)) as demo:
    # Per-session cache of inference results so slider/checkbox changes only
    # re-composite instead of re-running the model.
    cached_data = gr.State(None)

    gr.Markdown("# Galaxy Segmentor")
    gr.Markdown(
        "Upload a galaxy image to automatically segment into spiral arms and bars. Adjust thresholds to filter masks. "
        + "Trained with data from [Galaxy Zoo 3D](https://www.zooniverse.org/projects/klmasters/galaxy-zoo-3d/about/results). "
        + "Used in [this paper](https://arxiv.org/abs/2309.02380)."
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Input Galaxy",
                sources=["upload", "clipboard"],
            )
            with gr.Accordion("Minimum Thresholds", open=True):
                spiral_thresh = gr.Slider(
                    0.0, 1.0, value=0.5, label="Spiral Probability"
                )
                bar_thresh = gr.Slider(0.0, 1.0, value=0.5, label="Bar Probability")

            if example_images:
                example_gallery = gr.Gallery(
                    value=[str(p) for p in example_images],
                    label="Example Galaxies",
                    columns=5,
                    height=128,
                    allow_preview=False,
                    interactive=False,
                    object_fit="contain",
                )

                def handle_select(evt: gr.SelectData):
                    """Loads the clicked gallery example as an RGB PIL image."""
                    # The context manager closes the file handle; convert()
                    # returns a fully loaded copy that stays valid afterwards.
                    with Image.open(example_images[evt.index]) as img:
                        return img.convert("RGB")

                example_gallery.select(
                    fn=handle_select,
                    outputs=input_image,
                    show_progress="hidden",
                )

        with gr.Column():
            output_image = gr.Image(label="Output")
            with gr.Accordion("Output Settings", open=True):
                with gr.Row():
                    show_img_check = gr.Checkbox(label="Show Image", value=True)
                    show_spiral_check = gr.Checkbox(label="Show Spiral", value=True)
                    show_bar_check = gr.Checkbox(label="Show Bar", value=True)
                binary_check = gr.Checkbox(label="Binarize Masks", value=False)

    # Define update logic
    display_inputs = [
        cached_data,
        spiral_thresh,
        bar_thresh,
        binary_check,
        show_img_check,
        show_spiral_check,
        show_bar_check,
    ]

    # Event: Image changes -> run inference, then redraw.
    input_image.change(
        get_processed_data,
        inputs=input_image,
        outputs=cached_data,
        show_progress="minimal",
    ).then(
        update_display,
        inputs=display_inputs,
        outputs=output_image,
        show_progress="hidden",
    )

    # Event: Settings change -> redraw from the cached inference result only.
    settings_components = [
        spiral_thresh,
        bar_thresh,
        binary_check,
        show_img_check,
        show_spiral_check,
        show_bar_check,
    ]
    gr.on(
        triggers=[c.change for c in settings_components],
        fn=update_display,
        inputs=display_inputs,
        outputs=output_image,
        show_progress="hidden",
        trigger_mode="always_last",
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch(
        width=1280,
        max_file_size="10mb",
        theme=gr.themes.Base(primary_hue="blue"),
    )