# Mitchell Cavanagh
# Update readme and app description
# d7d2254 unverified
from pathlib import Path
import gradio as gr
import numpy as np
import onnxruntime as rt
from PIL import Image
# --- Configuration ---
# Path of the ONNX model file passed to the inference session.
MODEL_PATH = "model.onnx"
# Directory searched for bundled example images.
EXAMPLES_DIR = Path("examples")
# Size every input image is resized to before inference.
IMAGE_SIZE = (128, 128)

# Sorted paths of the example JPEGs. The list is empty when the directory is
# missing or contains no ``*.jpg`` files, so no separate emptiness fallback
# is needed.
example_images = []
if EXAMPLES_DIR.exists():
    example_images = sorted(EXAMPLES_DIR.glob("*.jpg"))
# --- Model loading ---
# Create the ONNX Runtime session once at import time and record the tensor
# names needed to call ``session.run`` later.
try:
    sess_options = rt.SessionOptions()
    # Cap both thread pools at two threads.
    sess_options.intra_op_num_threads = 2
    sess_options.inter_op_num_threads = 2
    session = rt.InferenceSession(
        MODEL_PATH, sess_options=sess_options, providers=["CPUExecutionProvider"]
    )
    # Only the first model input is fed; every model output is requested.
    input_name = session.get_inputs()[0].name
    output_names = [output.name for output in session.get_outputs()]
except Exception as e:
    # Chain explicitly so the traceback reports the original failure as the
    # direct cause of the RuntimeError.
    raise RuntimeError(f"Failed to load ONNX model: {e}") from e
def normalize_mask(mask: np.ndarray) -> np.ndarray:
    """Min-max rescale ``mask`` so its values span the [0, 1] range.

    When the maximum is not strictly greater than the minimum (for example a
    constant mask), an all-zero array shaped like ``mask`` is returned instead.
    """
    lo, hi = mask.min(), mask.max()
    # Guard first: a flat mask has no range to stretch.
    if not hi > lo:
        return np.zeros_like(mask)
    return (mask - lo) / (hi - lo)
def apply_mask(base_pil, prob_mask, threshold, color, binary):
    """Composite a colored, semi-transparent mask layer over an RGBA image.

    Args:
        base_pil: RGBA PIL image used as the bottom layer. It is passed to
            ``Image.alpha_composite``, so it must have the same pixel size
            as ``prob_mask``.
        prob_mask: 2-D array of per-pixel scores. Presumably already scaled
            to [0, 1] (the caller in this file passes ``normalize_mask``
            output).
        threshold: Pixels whose score is not strictly greater than this
            value are left fully transparent.
        color: ``(R, G, B)`` color of the overlay.
        binary: When true, every active pixel gets a fixed alpha of 150;
            otherwise alpha is the score scaled by 200.

    Returns:
        A new image with the mask layer composited over ``base_pil``.
    """
    # Size the overlay from the mask itself rather than the global
    # IMAGE_SIZE: that tuple is used as PIL (width, height) elsewhere, so
    # reading it as (rows, cols) would be transposed for a non-square size.
    rows, cols = prob_mask.shape
    mask_arr = np.zeros((rows, cols, 4), dtype=np.uint8)
    mask_arr[..., :3] = color

    active_mask = prob_mask > threshold
    if binary:
        alpha = np.uint8(150)
    else:
        # Clip before the uint8 cast so out-of-range scores cannot wrap.
        alpha = np.clip(prob_mask * 200, 0, 255).astype(np.uint8)
    mask_arr[..., 3] = np.where(active_mask, alpha, 0).astype(np.uint8)

    mask_layer = Image.fromarray(mask_arr)
    return Image.alpha_composite(base_pil, mask_layer)
def get_processed_data(image):
    """Run inference on ``image`` and return the data needed for display.

    Args:
        image: PIL image to segment, or ``None`` when no image is set.

    Returns:
        ``None`` when ``image`` is ``None``; otherwise a dict with
        ``"masks"``, a ``(spiral_prob, bar_prob)`` tuple of min-max
        normalized arrays, and ``"img_rgba"``, the resized input as an
        RGBA image.
    """
    if image is None:
        return None

    # Preprocess once. Convert to RGB here so the tensor always has three
    # channels, whatever mode the caller supplies; for an RGB image this is
    # a no-op copy.
    # NOTE(review): presumably the model expects 3-channel input - confirm.
    img_resized = image.convert("RGB").resize(
        IMAGE_SIZE, resample=Image.Resampling.BICUBIC
    )
    img_rgba = img_resized.convert("RGBA")

    # Scale pixel values to [0, 1] and add a leading batch axis.
    img_array = np.array(img_resized).astype("float32") / 255.0
    input_tensor = np.expand_dims(img_array, axis=0)

    onnx_pred = session.run(output_names, {input_name: input_tensor})
    # First output, first batch item; expected shape (128, 128, 2) per the
    # original author's note.
    masks = onnx_pred[0][0]

    # Post-process probabilities: channel 0 is treated as the spiral mask
    # and channel 1 as the bar mask.
    spiral_prob = normalize_mask(masks[..., 0])
    bar_prob = normalize_mask(masks[..., 1])
    return {"masks": (spiral_prob, bar_prob), "img_rgba": img_rgba}
def update_display(
    data,
    spiral_threshold,
    bar_threshold,
    binary_mask,
    show_image,
    show_spiral,
    show_bar,
):
    """Build the output image from cached inference results.

    Returns ``None`` when ``data`` is ``None``. Otherwise starts from the
    cached RGBA image (or an opaque black canvas when ``show_image`` is
    false), overlays each enabled mask in turn, and returns the result
    resized to 512x512 with nearest-neighbor resampling.
    """
    if data is None:
        return None

    spiral_prob, bar_prob = data["masks"]
    img_rgba = data["img_rgba"]

    canvas = (
        img_rgba if show_image else Image.new("RGBA", IMAGE_SIZE, (0, 0, 0, 255))
    )

    # (enabled, scores, threshold, overlay color); spiral is drawn before bar.
    layers = (
        (show_spiral, spiral_prob, spiral_threshold, (0, 255, 255)),
        (show_bar, bar_prob, bar_threshold, (218, 165, 32)),
    )
    for enabled, scores, cutoff, rgb in layers:
        if enabled:
            canvas = apply_mask(canvas, scores, cutoff, rgb, binary_mask)

    return canvas.resize((512, 512), resample=Image.Resampling.NEAREST)
# --- Gradio Interface ---
# Builds the whole UI: an input column (image upload, threshold sliders and an
# optional example gallery) next to an output column (result image and display
# toggles), then wires the events that run inference and refresh the output.
# NOTE(review): delete_cache is presumably (cleanup frequency, max age) in
# seconds - confirm against the Gradio Blocks documentation.
with gr.Blocks(title="Galaxy Segmentor", delete_cache=(7200, 7200)) as demo:
    # Holds the value returned by get_processed_data for the current image
    # so settings changes can redraw without running the model again.
    cached_data = gr.State(None)
    gr.Markdown("# Galaxy Segmentor")
    gr.Markdown(
        "Upload a galaxy image to automatically segment into spiral arms and bars. Adjust thresholds to filter masks. "
        + "Trained with data from [Galaxy Zoo 3D](https://www.zooniverse.org/projects/klmasters/galaxy-zoo-3d/about/results). "
        + "Used in [this paper](https://arxiv.org/abs/2309.02380)."
    )
    with gr.Row():
        # Left column: input image and mask thresholds.
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Input Galaxy",
                sources=["upload", "clipboard"],
            )
            with gr.Accordion("Minimum Thresholds", open=True):
                spiral_thresh = gr.Slider(
                    0.0, 1.0, value=0.5, label="Spiral Probability"
                )
                bar_thresh = gr.Slider(0.0, 1.0, value=0.5, label="Bar Probability")
            # The gallery is only created when example images were found.
            if example_images:
                example_gallery = gr.Gallery(
                    value=[str(p) for p in example_images],
                    label="Example Galaxies",
                    columns=5,
                    height=128,
                    allow_preview=False,
                    interactive=False,
                    object_fit="contain",
                )
                def handle_select(evt: gr.SelectData):
                    # Load the selected example by its index and return it
                    # as an RGB image for the input component.
                    idx = evt.index
                    return Image.open(example_images[idx]).convert("RGB")
                example_gallery.select(
                    fn=handle_select,
                    outputs=input_image,
                    show_progress="hidden",
                )
        # Right column: rendered result and display toggles.
        with gr.Column():
            output_image = gr.Image(label="Output")
            with gr.Accordion("Output Settings", open=True):
                with gr.Row():
                    show_img_check = gr.Checkbox(label="Show Image", value=True)
                    show_spiral_check = gr.Checkbox(label="Show Spiral", value=True)
                    show_bar_check = gr.Checkbox(label="Show Bar", value=True)
                    binary_check = gr.Checkbox(label="Binarize Masks", value=False)
    # Define update logic
    # Order matches the positional parameters of update_display.
    display_inputs = [
        cached_data,
        spiral_thresh,
        bar_thresh,
        binary_check,
        show_img_check,
        show_spiral_check,
        show_bar_check,
    ]
    # Event: Image changes
    # Run inference into the cache first, then render from the cache.
    input_image.change(
        get_processed_data,
        inputs=input_image,
        outputs=cached_data,
        show_progress="minimal",
    ).then(
        update_display,
        inputs=display_inputs,
        outputs=output_image,
        show_progress="hidden",
    )
    # Event: Settings change
    # Any slider or checkbox change re-renders from the cached data only;
    # get_processed_data is not called again.
    settings_components = [
        spiral_thresh,
        bar_thresh,
        binary_check,
        show_img_check,
        show_spiral_check,
        show_bar_check,
    ]
    gr.on(
        triggers=[c.change for c in settings_components],
        fn=update_display,
        inputs=display_inputs,
        outputs=output_image,
        show_progress="hidden",
        trigger_mode="always_last",
    )
if __name__ == "__main__":
    # Script entry point: enable queuing, then launch the app.
    demo.queue()
    # Launch settings are collected first so the call itself stays short.
    launch_options = {
        "width": 1280,
        "max_file_size": "10mb",
        "theme": gr.themes.Base(primary_hue="blue"),
    }
    demo.launch(**launch_options)