"""Streamlit demo: iU-RWKV interactive segmentation served with ONNX Runtime.

The user picks a demo image (or uploads one), draws at most one bounding box
and any number of positive/negative clicks on a canvas, and runs inference.
Clicks are replayed cumulatively (first click, first two clicks, ... up to K)
and, for every iteration, the app reports prompt-building and ONNX forward
latency, Dice against the ground-truth mask when one is available, and how
well the predicted mask honours the box/click constraints.
"""

import io
import os
import time
import pathlib

import numpy as np
import onnxruntime as ort
import streamlit as st
from huggingface_hub import hf_hub_download
from PIL import Image
from streamlit_drawable_canvas import st_canvas

# Hugging Face repos used as download fallbacks when the ONNX file is not
# present under MODEL_DIR. Both can be overridden through the environment.
SPACE_REPO_ID = os.environ.get("HF_SPACE_REPO_ID", "hbyecoding/iU-RWKV-demo")
MODEL_REPO_ID = os.environ.get("HF_MODEL_REPO_ID", "hbyecoding/iU-RWKV")
HF_TOKEN = os.environ.get("HF_TOKEN")

# Sizes are (width, height): the order PIL's Image.resize expects.
DISPLAY_SIZE = (256, 256)
MODEL_SIZE = (192, 192)

ASSETS_ROOT = pathlib.Path("hf_demo_assets")
MODEL_DIR = pathlib.Path("models")

# Per dataset: demo-asset folder under ASSETS_ROOT and ONNX file under MODEL_DIR.
MODELS = {
    "BUSI": {
        "assets_subdir": "BUSI",
        "onnx_filename": "iu_rwkv_busi_192.onnx",
    },
    "POLY": {
        "assets_subdir": "POLY",
        "onnx_filename": "iu_rwkv_poly_192.onnx",
    },
    "ISIC18": {
        "assets_subdir": "ISIC18",
        "onnx_filename": "iu_rwkv_isic18_192.onnx",
    },
}

# File extensions recognised as demo images.
_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp"}


# ---------------------------------------------------------------------------
# Image / prompt preprocessing
# ---------------------------------------------------------------------------


def _resize_image_rgb(pil_img, size):
    """Convert ``pil_img`` to RGB and resize it to ``size`` (w, h) bilinearly."""
    return pil_img.convert("RGB").resize(size, Image.Resampling.BILINEAR)


def _resize_mask(pil_img, size):
    """Convert ``pil_img`` to 8-bit grayscale and resize it to ``size`` (w, h).

    Nearest-neighbour resampling is used so binary masks stay binary.
    """
    return pil_img.convert("L").resize(size, Image.Resampling.NEAREST)


def _to_gray01(pil_img):
    """Return ``pil_img`` as a float32 (H, W) grayscale array scaled to [0, 1]."""
    return np.asarray(pil_img.convert("L"), dtype=np.float32) / 255.0


def _bbox_channel(box, shape_hw):
    """Rasterize an ``[x0, y0, x1, y1]`` box into a float32 (H, W) 0/1 channel.

    Coordinates are clipped to the image and ``x1``/``y1`` are exclusive.
    A ``None`` box, or one with zero/negative area after clipping, yields an
    all-zero channel.
    """
    h, w = shape_hw
    channel = np.zeros((h, w), dtype=np.float32)
    if box is None:
        return channel
    x0, y0, x1, y1 = box
    x0 = int(np.clip(x0, 0, w))
    x1 = int(np.clip(x1, 0, w))
    y0 = int(np.clip(y0, 0, h))
    y1 = int(np.clip(y1, 0, h))
    if x1 > x0 and y1 > y0:
        channel[y0:y1, x0:x1] = 1.0
    return channel


def _click_channels(clicks, shape_hw):
    """Rasterize clicks into positive and negative float32 (H, W) 0/1 channels.

    ``clicks`` is an iterable of ``(x, y, label)``. Label 1 marks a positive
    click and any other label a negative one. Coordinates are clipped so that
    off-canvas clicks land on the nearest border pixel.
    """
    h, w = shape_hw
    pos = np.zeros((h, w), dtype=np.float32)
    neg = np.zeros((h, w), dtype=np.float32)
    for x, y, label in clicks or ():
        x = int(np.clip(x, 0, w - 1))
        y = int(np.clip(y, 0, h - 1))
        target = pos if int(label) == 1 else neg
        target[y, x] = 1.0
    return pos, neg


def _build_model_input(pil_img_resized, box_xyxy, clicks_xy, model_size_hw):
    """Assemble the (1, 5, H, W) float32 model input tensor.

    Channel order:
    1. grayscale image in [0, 1];
    2. box channel;
    3. positive-click channel;
    4. negative-click channel;
    5. an all-zero mask-input channel.
    """
    h, w = model_size_hw
    gray = _to_gray01(pil_img_resized)
    if gray.shape != (h, w):
        # This is an intensity image, so interpolate bilinearly
        # (nearest-neighbour is reserved for masks).
        resized = pil_img_resized.convert("L").resize((w, h), Image.Resampling.BILINEAR)
        gray = _to_gray01(resized)
    box_ch = _bbox_channel(box_xyxy, (h, w))
    pos_ch, neg_ch = _click_channels(clicks_xy, (h, w))
    # NOTE(review): this looks like a previous-mask prompt slot. It is always
    # zeros here, so iterations do not feed back earlier predictions.
    mask_input_ch = np.zeros((h, w), dtype=np.float32)
    x = np.stack([gray, box_ch, pos_ch, neg_ch, mask_input_ch], axis=0)
    return x[None].astype(np.float32)


def _scale_xyxy(box, src_size, dst_size):
    """Rescale an ``[x0, y0, x1, y1]`` box from ``src_size`` to ``dst_size`` (w, h).

    Coordinates are rounded to ints. ``None`` passes through unchanged.
    """
    if box is None:
        return None
    sx = dst_size[0] / src_size[0]
    sy = dst_size[1] / src_size[1]
    x0, y0, x1, y1 = box
    return [int(round(x0 * sx)), int(round(y0 * sy)), int(round(x1 * sx)), int(round(y1 * sy))]


def _scale_clicks(clicks, src_size, dst_size):
    """Rescale ``(x, y, label)`` clicks from ``src_size`` to ``dst_size`` (w, h).

    Coordinates are rounded to ints. Empty or ``None`` input returns ``[]``.
    """
    if not clicks:
        return []
    sx = dst_size[0] / src_size[0]
    sy = dst_size[1] / src_size[1]
    return [(int(round(x * sx)), int(round(y * sy)), int(label)) for x, y, label in clicks]


# ---------------------------------------------------------------------------
# Demo assets and model loading
# ---------------------------------------------------------------------------


def _list_demo_images(assets_subdir):
    """Return the image files in ``ASSETS_ROOT/<assets_subdir>/images``, sorted by name.

    Returns ``[]`` when the directory does not exist.
    """
    img_dir = ASSETS_ROOT / assets_subdir / "images"
    if not img_dir.is_dir():
        return []
    return sorted(
        (p for p in img_dir.iterdir() if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES),
        key=lambda p: p.name,
    )


def _find_demo_mask(assets_subdir, stem):
    """Return the file in ``ASSETS_ROOT/<assets_subdir>/masks`` whose stem is ``stem``.

    If several files share the stem, the first in sorted order wins. Returns
    ``None`` when there is no match or no masks directory.
    """
    mask_dir = ASSETS_ROOT / assets_subdir / "masks"
    if not mask_dir.is_dir():
        return None
    return next((p for p in sorted(mask_dir.iterdir()) if p.is_file() and p.stem == stem), None)


@st.cache_resource
def get_ort_session(onnx_filename, num_threads):
    """Create a CPU ONNX Runtime session for ``onnx_filename`` (cached by Streamlit).

    The model is looked up in this order:
    1. ``MODEL_DIR`` on local disk;
    2. the Space repo on the Hugging Face Hub;
    3. the model repo on the Hugging Face Hub.

    Both Hub lookups use the same relative path (``models/<onnx_filename>``).

    Returns:
        ``(session, input_name, source, model_path)``, where ``source`` is
        ``"local"``, ``"space"`` or ``"model"``.
    """
    local_path = MODEL_DIR / onnx_filename
    if local_path.exists():
        model_path = str(local_path)
        source = "local"
    else:
        repo_filename = local_path.as_posix()
        try:
            model_path = hf_hub_download(
                repo_id=SPACE_REPO_ID,
                repo_type="space",
                filename=repo_filename,
                token=HF_TOKEN,
            )
            source = "space"
        except Exception:
            # Deliberate best-effort: any failure fetching from the Space falls
            # back to the model repo. A failure there propagates to the caller.
            model_path = hf_hub_download(
                repo_id=MODEL_REPO_ID,
                repo_type="model",
                filename=repo_filename,
                token=HF_TOKEN,
            )
            source = "model"

    sess_opts = ort.SessionOptions()
    sess_opts.intra_op_num_threads = int(num_threads)
    sess_opts.inter_op_num_threads = 1
    sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = ort.InferenceSession(model_path, sess_options=sess_opts, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    return session, input_name, source, model_path


def run_onnx(session, input_name, x):
    """Run ``session`` with ``x`` bound to ``input_name`` and return the first output."""
    return session.run(None, {input_name: x})[0]


# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------


def sigmoid(x):
    """Numerically stable logistic function, ``1 / (1 + exp(-x))``.

    It is computed as ``exp(-log(1 + exp(-x)))`` via ``logaddexp``. This way
    large-magnitude logits never overflow ``exp``; the naive form emits a
    RuntimeWarning for large negative inputs.
    """
    return np.exp(-np.logaddexp(0.0, -x))


def dice(pred01, gt01, eps=1e-7):
    """Dice coefficient between two 0/1 masks of equal shape.

    ``eps`` keeps the ratio defined when both masks are empty; the result
    is then 1.0.
    """
    pred = pred01.astype(np.float32)
    gt = gt01.astype(np.float32)
    inter = np.sum(pred * gt)
    denom = np.sum(pred) + np.sum(gt)
    return float((2.0 * inter + eps) / (denom + eps))


def constraint_metrics(pred01, box_xyxy, clicks_xy, shape_hw):
    """Measure how well a 0/1 prediction honours the interaction prompts.

    Returns a dict with:
        pos_hit_rate: fraction of positive clicks (label 1) on predicted
            foreground; ``None`` if there are no positive clicks.
        neg_ok_rate: fraction of negative clicks (any other label, matching
            ``_click_channels``) on predicted background; ``None`` if there
            are none.
        bbox_outside_ratio: fraction of predicted foreground pixels outside
            the box. It is ``0.0`` for an empty prediction and ``None`` when
            there is no box.
        pred_area_ratio: predicted foreground pixels / total pixels.
    """
    h, w = shape_hw
    clicks_xy = clicks_xy or []

    def _value_at(x, y):
        # Clip like _click_channels so metrics look at the pixel the model saw.
        return pred01[int(np.clip(y, 0, h - 1)), int(np.clip(x, 0, w - 1))]

    pos_vals = [_value_at(x, y) for x, y, lab in clicks_xy if int(lab) == 1]
    neg_vals = [_value_at(x, y) for x, y, lab in clicks_xy if int(lab) != 1]
    pos_hit = float(np.mean([v == 1 for v in pos_vals])) if pos_vals else None
    neg_ok = float(np.mean([v == 0 for v in neg_vals])) if neg_vals else None

    outside_ratio = None
    if box_xyxy is not None:
        bbox_mask = _bbox_channel(box_xyxy, (h, w)).astype(np.uint8)
        pred_sum = float(np.sum(pred01))
        if pred_sum > 0:
            outside_ratio = float(np.sum(pred01 * (1 - bbox_mask)) / pred_sum)
        else:
            outside_ratio = 0.0

    pred_area_ratio = float(np.sum(pred01)) / float(h * w)
    return {
        "pos_hit_rate": pos_hit,
        "neg_ok_rate": neg_ok,
        "bbox_outside_ratio": outside_ratio,
        "pred_area_ratio": pred_area_ratio,
    }


# ---------------------------------------------------------------------------
# Canvas parsing and CSV export
# ---------------------------------------------------------------------------

# Fraction of an object's width/height between its fabric.js origin and its
# left/top edge. Unknown origins are treated like the default ("left"/"top").
_ORIGIN_SHIFT_X = {"left": 0.0, "center": 0.5, "right": 1.0}
_ORIGIN_SHIFT_Y = {"top": 0.0, "center": 0.5, "bottom": 1.0}


def _object_geometry(obj):
    """Return ``(left, top, width, height)`` of a fabric.js object, top-left anchored.

    The size includes the object's ``scaleX``/``scaleY``. ``left``/``top``
    are shifted when the object's origin is not its top-left corner.
    """
    width = float(obj.get("width", 0)) * float(obj.get("scaleX", 1))
    height = float(obj.get("height", 0)) * float(obj.get("scaleY", 1))
    left = float(obj["left"]) - _ORIGIN_SHIFT_X.get(obj.get("originX", "left"), 0.0) * width
    top = float(obj["top"]) - _ORIGIN_SHIFT_Y.get(obj.get("originY", "top"), 0.0) * height
    return left, top, width, height


def parse_canvas(canvas_json):
    """Extract the box and clicks drawn on the canvas.

    Args:
        canvas_json: the ``json_data`` of ``st_canvas`` (fabric.js JSON), or None.

    Returns:
        ``(bbox, clicks)``, both in display pixels:
        - ``bbox`` is ``[x_min, y_min, x_max, y_max]`` of the last rectangle
          drawn, or None if there is none.
        - ``clicks`` is a list of ``(x, y, label)`` at each point's centre.
          Label is 1 for green-stroked (positive) points and 0 otherwise.
    """
    bbox = None
    clicks = []
    if canvas_json is None:
        return bbox, clicks
    for obj in canvas_json.get("objects", []):
        obj_type = obj.get("type")
        if obj_type not in ("rect", "circle", "point"):
            continue
        left, top, width, height = _object_geometry(obj)
        if obj_type == "rect":
            bbox = [int(left), int(top), int(left + width), int(top + height)]
        else:
            label = 1 if obj.get("stroke") == "green" else 0
            clicks.append((int(left + width / 2), int(top + height / 2), label))
    return bbox, clicks


# (record key, format spec) in CSV column order. ``None`` becomes an empty cell.
_CSV_COLUMNS = (
    ("iter", ""),
    ("n_clicks_used", ""),
    ("prompt_ms", ".3f"),
    ("onnx_forward_ms", ".3f"),
    ("total_ms", ".3f"),
    ("dice", ".4f"),
    ("pos_hit_rate", ".4f"),
    ("neg_ok_rate", ".4f"),
    ("bbox_outside_ratio", ".6f"),
    ("pred_area_ratio", ".6f"),
)


def _records_to_csv(records):
    """Serialize per-iteration metric records to CSV text (header row first)."""
    buf = io.StringIO()
    buf.write(",".join(name for name, _ in _CSV_COLUMNS) + "\n")
    for record in records:
        cells = ("" if record[name] is None else format(record[name], spec) for name, spec in _CSV_COLUMNS)
        buf.write(",".join(cells) + "\n")
    return buf.getvalue()


# Canvas tool -> (st_canvas drawing mode, stroke colour).
# parse_canvas relies on "green" marking positive clicks.
_TOOLS = {
    "Box": ("rect", "blue"),
    "Positive Click": ("point", "green"),
    "Negative Click": ("point", "red"),
}


# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------

st.set_page_config(page_title="iU-RWKV Interactive Segmentation (ONNX)", layout="wide")
st.title("iU-RWKV Interactive Segmentation Demo (Hugging Face Spaces)")
st.markdown(
    "This Space runs iU-RWKV as an **ONNX Runtime** model on CPU. "
    "We report **per-click iteration latency** (prompt update + ONNX forward) and **interaction-consistency metrics** "
    "(how well the predicted mask satisfies your clicks/box constraints) to match clinical interaction experience."
)

with st.sidebar:
    st.header("Settings")
    model_key = st.selectbox("Dataset / Model", list(MODELS.keys()))
    num_threads = st.slider("CPU threads (intra-op)", 1, 16, 8)
    max_clicks = st.slider("Max clicks to replay (K)", 1, 10, 5)
    show_intermediate = st.checkbox("Show per-iter masks", value=False)
    image_source = st.radio("Image source", ["Demo assets", "Upload"], index=0)

assets_subdir = MODELS[model_key]["assets_subdir"]
onnx_filename = MODELS[model_key]["onnx_filename"]
session, input_name, model_source, model_path = get_ort_session(onnx_filename, num_threads)

with st.sidebar:
    st.caption(f"Model file: {onnx_filename}")
    st.caption(f"Loaded from: {model_source}")

demo_images = _list_demo_images(assets_subdir) if image_source == "Demo assets" else []
if image_source == "Demo assets" and not demo_images:
    with st.sidebar:
        st.warning(f"No demo assets found for {model_key}. Please upload an image instead.")
    image_source = "Upload"

# Ground-truth masks exist only for demo assets that ship a matching mask file.
gt_mask_display = None
gt_mask_model = None
img_display = None
img_model = None

if image_source == "Demo assets":
    if not demo_images:
        st.error("No demo images available. Switch to 'Upload' in the sidebar.")
        st.stop()
    selected = st.sidebar.selectbox("Select demo image", demo_images, format_func=lambda p: p.name)
    image_id = f"demo/{assets_subdir}/{selected.name}"
    # Context managers release the file handles; the resized copies stay valid.
    with Image.open(selected) as pil_img:
        img_display = _resize_image_rgb(pil_img, DISPLAY_SIZE)
        img_model = _resize_image_rgb(pil_img, MODEL_SIZE)
    mask_path = _find_demo_mask(assets_subdir, selected.stem)
    if mask_path is not None:
        with Image.open(mask_path) as mask_img:
            gt_mask_display = _resize_mask(mask_img, DISPLAY_SIZE)
            gt_mask_model = _resize_mask(mask_img, MODEL_SIZE)
else:
    uploaded = st.sidebar.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "bmp"])
    if uploaded is None:
        st.info("Upload an image to start.")
        st.stop()
    image_id = f"upload/{uploaded.name}"
    with Image.open(uploaded) as pil_img:
        img_display = _resize_image_rgb(pil_img, DISPLAY_SIZE)
        img_model = _resize_image_rgb(pil_img, MODEL_SIZE)

st.subheader("Interactive workspace")
st.write("Draw **one box** (blue) and/or add multiple **points** (green=positive, red=negative), then run inference.")

col_tools, col_canvas = st.columns([1, 3])
with col_tools:
    interaction_mode = st.radio("Tool", list(_TOOLS))
    drawing_mode, stroke_color = _TOOLS[interaction_mode]
    st.caption("Tip: use 1 box to localize, then refine with clicks.")

with col_canvas:
    canvas = st_canvas(
        fill_color="rgba(255, 165, 0, 0.2)",
        stroke_width=3,
        stroke_color=stroke_color,
        background_image=img_display,
        update_streamlit=True,
        height=DISPLAY_SIZE[1],
        width=DISPLAY_SIZE[0],
        drawing_mode=drawing_mode,
        # Key the canvas per image so prompts drawn on one image are not
        # carried over to another (their coordinates would be meaningless).
        key=f"canvas::{image_id}",
    )

if st.button("Run inference", type="primary"):
    bbox_display, clicks_display = parse_canvas(canvas.json_data)
    bbox_model = _scale_xyxy(bbox_display, DISPLAY_SIZE, MODEL_SIZE)
    clicks_model = _scale_clicks(clicks_display, DISPLAY_SIZE, MODEL_SIZE)

    if not clicks_model and bbox_model is None:
        st.warning("Please draw a box or add clicks before running.")
        st.stop()

    # Replay cumulative click prefixes [c1], [c1, c2], ... up to K clicks.
    # With only a box, run a single iteration without clicks.
    if clicks_model:
        k = min(int(max_clicks), len(clicks_model))
        click_prefixes = [clicks_model[:i] for i in range(1, k + 1)]
    else:
        click_prefixes = [[]]

    model_hw = (MODEL_SIZE[1], MODEL_SIZE[0])
    # The GT mask is the same for every iteration, so binarize it once.
    gt01 = None
    if gt_mask_model is not None:
        gt01 = (np.asarray(gt_mask_model, dtype=np.uint8) > 127).astype(np.uint8)

    records = []
    masks_display = []
    for it, clicks_it in enumerate(click_prefixes, start=1):
        t_prompt0 = time.perf_counter()
        x = _build_model_input(img_model, bbox_model, clicks_it, model_size_hw=model_hw)
        t_prompt1 = time.perf_counter()

        t_fwd0 = time.perf_counter()
        logits = run_onnx(session, input_name, x)
        t_fwd1 = time.perf_counter()

        # NOTE(review): assumes the first output is (1, 1, H, W) logits — confirm.
        prob = sigmoid(logits[0, 0])
        pred01 = (prob > 0.5).astype(np.uint8)

        pred_pil_model = Image.fromarray((pred01 * 255).astype(np.uint8))
        pred_display = np.asarray(_resize_mask(pred_pil_model, DISPLAY_SIZE), dtype=np.uint8)
        masks_display.append((pred_display > 127).astype(np.uint8))

        dsc = None if gt01 is None else dice(pred01, gt01)
        cm = constraint_metrics(pred01, bbox_model, clicks_it, shape_hw=model_hw)

        records.append(
            {
                "iter": it,
                "n_clicks_used": len(clicks_it),
                "prompt_ms": (t_prompt1 - t_prompt0) * 1000.0,
                "onnx_forward_ms": (t_fwd1 - t_fwd0) * 1000.0,
                "total_ms": (t_fwd1 - t_prompt0) * 1000.0,
                "dice": dsc,
                "pos_hit_rate": cm["pos_hit_rate"],
                "neg_ok_rate": cm["neg_ok_rate"],
                "bbox_outside_ratio": cm["bbox_outside_ratio"],
                "pred_area_ratio": cm["pred_area_ratio"],
            }
        )

    # click_prefixes always has at least one entry, so this is never empty.
    final_mask_display = masks_display[-1]

    st.divider()
    st.subheader("Results")

    left, right = st.columns([2, 1])
    with left:
        cols = st.columns(3 if gt_mask_model is not None else 2)
        cols[0].image(img_display, caption="Input", use_column_width=True)
        cols[1].image(final_mask_display * 255, caption="Prediction (final)", clamp=True, use_column_width=True)
        if gt_mask_model is not None:
            cols[2].image(gt_mask_display, caption="Ground truth", clamp=True, use_column_width=True)

    with right:
        st.write("Per-click iteration metrics:")
        st.dataframe(records, use_container_width=True)
        st.download_button(
            label="Download per-iter CSV",
            data=_records_to_csv(records).encode("utf-8"),
            file_name=f"per_iter_{model_key.lower()}_{MODEL_SIZE[0]}.csv",
            mime="text/csv",
        )

    if show_intermediate:
        st.subheader("Intermediate masks (per iter)")
        cols = st.columns(len(masks_display))
        for i, (col, m) in enumerate(zip(cols, masks_display), start=1):
            col.image(m * 255, caption=f"Iter {i}", clamp=True, use_column_width=True)