Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| import time | |
| import pathlib | |
| import numpy as np | |
| import onnxruntime as ort | |
| import streamlit as st | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| from streamlit_drawable_canvas import st_canvas | |
# Hugging Face Hub repositories the ONNX weights can be fetched from; each is overridable
# through an environment variable. HF_TOKEN is None when the variable is unset.
SPACE_REPO_ID = os.getenv("HF_SPACE_REPO_ID", "hbyecoding/iU-RWKV-demo")
MODEL_REPO_ID = os.getenv("HF_MODEL_REPO_ID", "hbyecoding/iU-RWKV")
HF_TOKEN = os.getenv("HF_TOKEN")

# Sizes are PIL-style (width, height) tuples.
DISPLAY_SIZE = (256, 256)  # drawing canvas / preview resolution
MODEL_SIZE = (192, 192)  # resolution of the tensor fed to the ONNX model

ASSETS_ROOT = pathlib.Path("hf_demo_assets")
MODEL_DIR = pathlib.Path("models")

# Per-dataset configuration: demo-asset sub-folder under ASSETS_ROOT and ONNX file name.
MODELS = {
    name: {
        "assets_subdir": name,
        "onnx_filename": f"iu_rwkv_{name.lower()}_192.onnx",
    }
    for name in ("BUSI", "POLY", "ISIC18")
}
def _resize_image_rgb(pil_img, size):
    """Convert *pil_img* to RGB and resample it to *size* ((width, height)) with bilinear filtering."""
    rgb_img = pil_img.convert("RGB")
    return rgb_img.resize(size, resample=Image.Resampling.BILINEAR)
def _resize_mask(pil_img, size):
    """Convert *pil_img* to 8-bit grayscale ("L") and resample it to *size* ((width, height)).

    Nearest-neighbour resampling copies existing pixel values, so no interpolated
    in-between grey levels are introduced into the mask.
    """
    gray_img = pil_img.convert("L")
    return gray_img.resize(size, resample=Image.Resampling.NEAREST)
| def _to_gray01(pil_img): | |
| arr = np.asarray(pil_img.convert("L"), dtype=np.float32) / 255.0 | |
| return arr | |
| def _bbox_channel(box, shape_hw): | |
| h, w = shape_hw | |
| ch = np.zeros((h, w), dtype=np.float32) | |
| if box is None: | |
| return ch | |
| x0, y0, x1, y1 = box | |
| x0 = int(np.clip(x0, 0, w)) | |
| x1 = int(np.clip(x1, 0, w)) | |
| y0 = int(np.clip(y0, 0, h)) | |
| y1 = int(np.clip(y1, 0, h)) | |
| if x1 > x0 and y1 > y0: | |
| ch[y0:y1, x0:x1] = 1.0 | |
| return ch | |
| def _click_channels(clicks, shape_hw): | |
| h, w = shape_hw | |
| pos = np.zeros((h, w), dtype=np.float32) | |
| neg = np.zeros((h, w), dtype=np.float32) | |
| if not clicks: | |
| return pos, neg | |
| for x, y, label in clicks: | |
| x = int(np.clip(x, 0, w - 1)) | |
| y = int(np.clip(y, 0, h - 1)) | |
| if int(label) == 1: | |
| pos[y, x] = 1.0 | |
| else: | |
| neg[y, x] = 1.0 | |
| return pos, neg | |
def _build_model_input(pil_img_resized, box_xyxy, clicks_xy, model_size_hw):
    """Assemble the 5-channel float32 tensor fed to the ONNX model.

    Channel order (axis 1): grayscale image in [0, 1], box map, positive-click map,
    negative-click map, and an all-zero mask-input channel.

    Args:
        pil_img_resized: PIL image, normally already at the model resolution.
        box_xyxy: ``[x0, y0, x1, y1]`` in model pixels, or None.
        clicks_xy: list of ``(x, y, label)`` in model pixels; label 1 is positive.
        model_size_hw: ``(height, width)`` of the model input.

    Returns:
        ndarray of shape ``(1, 5, height, width)``, dtype float32.
    """
    h, w = model_size_hw
    gray = _to_gray01(pil_img_resized)
    if gray.shape != (h, w):
        # Resample with the same pipeline that produces the normal model-size image
        # (bilinear RGB resize, then grayscale) instead of the nearest-neighbour mask resizer.
        gray = _to_gray01(_resize_image_rgb(pil_img_resized, (w, h)))
    box_ch = _bbox_channel(box_xyxy, (h, w))
    pos_ch, neg_ch = _click_channels(clicks_xy, (h, w))
    # NOTE(review): the mask-input channel is always zero (no previous prediction is fed
    # back between click iterations) — confirm this matches how the model was trained.
    mask_ch = np.zeros((h, w), dtype=np.float32)
    x = np.stack([gray, box_ch, pos_ch, neg_ch, mask_ch], axis=0)
    # Add the batch axis; every channel is already float32, so no extra copy is made.
    return x[None].astype(np.float32, copy=False)
| def _scale_xyxy(box, src_size, dst_size): | |
| if box is None: | |
| return None | |
| sx = dst_size[0] / src_size[0] | |
| sy = dst_size[1] / src_size[1] | |
| x0, y0, x1, y1 = box | |
| return [int(round(x0 * sx)), int(round(y0 * sy)), int(round(x1 * sx)), int(round(y1 * sy))] | |
| def _scale_clicks(clicks, src_size, dst_size): | |
| if not clicks: | |
| return [] | |
| sx = dst_size[0] / src_size[0] | |
| sy = dst_size[1] / src_size[1] | |
| out = [] | |
| for x, y, label in clicks: | |
| out.append((int(round(x * sx)), int(round(y * sy)), int(label))) | |
| return out | |
def _list_demo_images(assets_subdir):
    """Return the image files in ``ASSETS_ROOT/<assets_subdir>/images`` sorted by file name.

    Only .png/.jpg/.jpeg/.bmp suffixes (case-insensitive) are kept. A missing folder
    yields an empty list.
    """
    image_dir = ASSETS_ROOT / assets_subdir / "images"
    if not image_dir.exists():
        return []
    allowed_suffixes = {".png", ".jpg", ".jpeg", ".bmp"}
    candidates = (p for p in image_dir.iterdir() if p.suffix.lower() in allowed_suffixes)
    return sorted(candidates, key=lambda p: p.name)
def _find_demo_mask(assets_subdir, stem):
    """Return a file in ``ASSETS_ROOT/<assets_subdir>/masks`` whose stem equals *stem*, else None.

    If several files share the stem (e.g. different extensions), the first one yielded by
    ``Path.iterdir`` is returned. A missing masks folder yields None.
    """
    masks_dir = ASSETS_ROOT / assets_subdir / "masks"
    if not masks_dir.exists():
        return None
    return next((candidate for candidate in masks_dir.iterdir() if candidate.stem == stem), None)
# Cached so the session is built once per (onnx_filename, num_threads) pair instead of on
# every Streamlit rerun (the canvas uses update_streamlit=True, so each stroke reruns the
# script). max_entries bounds how many sessions stay in memory as the thread slider moves.
# NOTE(review): the cached session is shared across browser sessions; ONNX Runtime documents
# concurrent Run() calls on one session as supported — confirm for the deployed ORT version.
@st.cache_resource(show_spinner="Loading ONNX model...", max_entries=4)
def get_ort_session(onnx_filename, num_threads):
    """Create a CPU ONNX Runtime session for *onnx_filename*.

    The model file is resolved in this order:
      1. ``MODEL_DIR/onnx_filename`` on local disk;
      2. the same relative path in the Hugging Face Space repo ``SPACE_REPO_ID``;
      3. the same relative path in the model repo ``MODEL_REPO_ID``.

    Args:
        onnx_filename: File name of the ONNX export inside ``MODEL_DIR``.
        num_threads: Intra-op thread count for the session.

    Returns:
        ``(session, input_name, source, model_path)``. ``input_name`` is the name of the
        model's first input, and ``source`` is ``"local"``, ``"space"`` or ``"model"``.

    Raises:
        Whatever ``hf_hub_download`` raises for the model repo when the file is neither
        present locally nor downloadable from the Space repo.
    """
    local_path = MODEL_DIR / onnx_filename
    # The Hub repos are expected to mirror the local layout ("models/<file>").
    hub_filename = local_path.as_posix()
    if local_path.exists():
        model_path, source = str(local_path), "local"
    else:
        try:
            model_path = hf_hub_download(
                repo_id=SPACE_REPO_ID,
                repo_type="space",
                filename=hub_filename,
                token=HF_TOKEN,
            )
            source = "space"
        except Exception:
            # Deliberate best-effort fallback: any failure on the Space repo falls through to
            # the model repo. If that also fails, its exception propagates, chained to this one.
            model_path = hf_hub_download(
                repo_id=MODEL_REPO_ID,
                repo_type="model",
                filename=hub_filename,
                token=HF_TOKEN,
            )
            source = "model"

    sess_opts = ort.SessionOptions()
    sess_opts.intra_op_num_threads = int(num_threads)
    sess_opts.inter_op_num_threads = 1
    sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = ort.InferenceSession(model_path, sess_options=sess_opts, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    return session, input_name, source, model_path
def run_onnx(session, input_name, x):
    """Run one forward pass, feeding *x* as *input_name*, and return the first model output."""
    outputs = session.run(None, {input_name: x})  # None -> request every output
    return outputs[0]
def sigmoid(x):
    """Numerically stable logistic function ``1 / (1 + exp(-x))``, applied elementwise.

    ``exp`` is only ever evaluated at ``-|x|``, so very large negative inputs cannot
    overflow (the naive form emits an overflow RuntimeWarning for float32 logits
    below about -88). Floating inputs keep their dtype; other inputs are computed in
    float64. A scalar input returns a NumPy scalar, and an array input returns an array.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(np.float64)
    e = np.exp(-np.abs(x))
    denom = 1.0 + e
    # For x >= 0: 1 / (1 + e^-x). For x < 0: e^x / (1 + e^x). Both use e = e^-|x| in (0, 1].
    out = np.where(x >= 0, 1.0 / denom, e / denom)
    return out[()]  # unwraps 0-d results to a scalar; no-op view for n-d arrays
def dice(pred01, gt01, eps=1e-7):
    """Dice similarity coefficient between two binary masks, returned as a Python float.

    ``(2 * |pred ∩ gt| + eps) / (|pred| + |gt| + eps)``. *eps* keeps the ratio defined
    and makes two empty masks score 1.0.
    """
    p = pred01.astype(np.float32)
    g = gt01.astype(np.float32)
    numerator = 2.0 * np.sum(p * g) + eps
    denominator = np.sum(p) + np.sum(g) + eps
    return float(numerator / denominator)
def constraint_metrics(pred01, box_xyxy, clicks_xy, shape_hw):
    """Measure how well a binary prediction honours the user's click and box prompts.

    Args:
        pred01: binary (0/1) mask indexed as ``pred01[y, x]``, of shape *shape_hw*.
        box_xyxy: ``[x0, y0, x1, y1]`` box in the same pixel space, or None.
        clicks_xy: iterable of ``(x, y, label)``; label 1 is a positive click and any other
            label is negative (same convention as ``_click_channels``). None means no clicks.
        shape_hw: ``(h, w)`` of the mask.

    Returns:
        dict with:
          ``pos_hit_rate``: fraction of positive clicks on predicted foreground (None if none);
          ``neg_ok_rate``: fraction of negative clicks on predicted background (None if none);
          ``bbox_outside_ratio``: fraction of foreground pixels outside the box
              (0.0 for an empty prediction, None without a box);
          ``pred_area_ratio``: foreground pixels / (h * w).
    """
    h, w = shape_hw
    clicks_xy = clicks_xy or []

    def _value_at(x, y):
        # Clip exactly like _click_channels, so we inspect the pixel the model was prompted with.
        return pred01[int(np.clip(y, 0, h - 1)), int(np.clip(x, 0, w - 1))]

    pos = [(x, y) for x, y, lab in clicks_xy if int(lab) == 1]
    neg = [(x, y) for x, y, lab in clicks_xy if int(lab) != 1]
    pos_hit = float(np.mean([_value_at(x, y) == 1 for x, y in pos])) if pos else None
    neg_ok = float(np.mean([_value_at(x, y) == 0 for x, y in neg])) if neg else None

    pred_sum = float(np.sum(pred01))
    outside_ratio = None
    if box_xyxy is not None:
        x0, y0, x1, y1 = box_xyxy
        x0, x1 = int(np.clip(x0, 0, w)), int(np.clip(x1, 0, w))
        y0, y1 = int(np.clip(y0, 0, h)), int(np.clip(y1, 0, h))
        bbox_mask = np.zeros((h, w), dtype=np.uint8)
        if x1 > x0 and y1 > y0:
            bbox_mask[y0:y1, x0:x1] = 1
        if pred_sum > 0:
            outside_ratio = float(np.sum(pred01 * (1 - bbox_mask)) / pred_sum)
        else:
            outside_ratio = 0.0

    return {
        "pos_hit_rate": pos_hit,
        "neg_ok_rate": neg_ok,
        "bbox_outside_ratio": outside_ratio,
        "pred_area_ratio": pred_sum / float(h * w),
    }
# Page configuration and the introductory text shown above the workspace.
st.set_page_config(page_title="iU-RWKV Interactive Segmentation (ONNX)", layout="wide")
st.title("iU-RWKV Interactive Segmentation Demo (Hugging Face Spaces)")
st.markdown(
    "This Space runs iU-RWKV as an **ONNX Runtime** model on CPU. "
    "We report **per-click iteration latency** (prompt update + ONNX forward) and **interaction-consistency metrics** "
    "(how well the predicted mask satisfies your clicks/box constraints) to match clinical interaction experience."
)
# Sidebar controls: dataset/model choice, ONNX Runtime thread count, click-replay budget,
# whether to show per-iteration masks, and where the input image comes from.
with st.sidebar:
    st.header("Settings")
    model_key = st.selectbox("Dataset / Model", list(MODELS.keys()))
    # Passed to get_ort_session, which uses it as the session's intra-op thread count.
    num_threads = st.slider("CPU threads (intra-op)", 1, 16, 8)
    # Upper bound on how many of the user's clicks are replayed, one iteration per click prefix.
    max_clicks = st.slider("Max clicks to replay (K)", 1, 10, 5)
    show_intermediate = st.checkbox("Show per-iter masks", value=False)
    image_source = st.radio("Image source", ["Demo assets", "Upload"], index=0)
assets_subdir = MODELS[model_key]["assets_subdir"]
onnx_filename = MODELS[model_key]["onnx_filename"]
# model_source reports where the .onnx file was found: "local", "space" or "model".
session, input_name, model_source, model_path = get_ort_session(onnx_filename, num_threads)
# Show which ONNX file is in use and where it was loaded from.
with st.sidebar:
    st.caption(f"Model file: {onnx_filename}")
    st.caption(f"Loaded from: {model_source}")
# Demo images are listed only when the user picked the demo-asset source.
demo_images = _list_demo_images(assets_subdir) if image_source == "Demo assets" else []
# This dataset ships no demo images: warn and fall back to the upload flow.
if image_source == "Demo assets" and not demo_images:
    with st.sidebar:
        st.warning(f"No demo assets found for {model_key}. Please upload an image instead.")
    image_source = "Upload"
# Load the input image at both canvas and model resolution, plus — for demo assets that ship
# one — the ground-truth mask. gt_mask_display / gt_mask_model stay None when there is no mask.
gt_mask_model = None
gt_mask_display = None
img_display = None
img_model = None
mask_path = None
if image_source == "Demo assets":
    if not demo_images:
        st.error("No demo images available. Switch to 'Upload' in the sidebar.")
        st.stop()
    selected = st.sidebar.selectbox("Select demo image", demo_images, format_func=lambda p: p.name)
    image_file = selected
    mask_path = _find_demo_mask(assets_subdir, selected.stem)
else:
    uploaded = st.sidebar.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "bmp"])
    if uploaded is None:
        st.info("Upload an image to start.")
        st.stop()
    image_file = uploaded

# Context managers release the file handles Image.open keeps; the resized copies are
# independent images, so they remain valid after the source is closed.
with Image.open(image_file) as pil_img:
    img_display = _resize_image_rgb(pil_img, DISPLAY_SIZE)
    img_model = _resize_image_rgb(pil_img, MODEL_SIZE)

if mask_path is not None:
    # Open the mask once and derive both resolutions from the same image.
    with Image.open(mask_path) as mask_img:
        gt_mask_display = _resize_mask(mask_img, DISPLAY_SIZE)
        gt_mask_model = _resize_mask(mask_img, MODEL_SIZE)
st.subheader("Interactive workspace")
st.write("Draw **one box** (blue) and/or add multiple **points** (green=positive, red=negative), then run inference.")
col_tools, col_canvas = st.columns([1, 3])
with col_tools:
    interaction_mode = st.radio("Tool", ["Box", "Positive Click", "Negative Click"])
    # Tool -> (st_canvas drawing mode, stroke colour). parse_canvas reads click polarity back
    # from the stroke colour, so "green" must remain the positive-click colour.
    drawing_mode, stroke_color = {
        "Box": ("rect", "blue"),
        "Positive Click": ("point", "green"),
    }.get(interaction_mode, ("point", "red"))
    st.caption("Tip: use 1 box to localize, then refine with clicks.")
with col_canvas:
    # The canvas is DISPLAY_SIZE pixels; prompts drawn here are rescaled to MODEL_SIZE later.
    canvas = st_canvas(
        key="canvas",
        background_image=img_display,
        width=DISPLAY_SIZE[0],
        height=DISPLAY_SIZE[1],
        drawing_mode=drawing_mode,
        stroke_color=stroke_color,
        stroke_width=3,
        fill_color="rgba(255, 165, 0, 0.2)",
        update_streamlit=True,
    )
def parse_canvas(canvas_json):
    """Extract the box prompt and click prompts from ``st_canvas`` JSON output.

    Args:
        canvas_json: the canvas' ``json_data`` dict (fabric.js serialization), or None.

    Returns:
        ``(bbox, clicks)`` in canvas pixels:
          ``bbox`` is ``[x_min, y_min, x_max, y_max]`` from the last ``rect`` object, or None;
          ``clicks`` lists ``(x, y, label)`` for every ``circle``/``point`` object, with label 1
          when its stroke is ``"green"`` and 0 otherwise. Coordinates are truncated to int.
    """
    bbox = None
    clicks = []
    if canvas_json is None:
        return bbox, clicks

    def _rendered_size(obj):
        # fabric.js stores the unscaled width/height; the drawn size multiplies in scaleX/scaleY.
        return obj["width"] * obj.get("scaleX", 1), obj["height"] * obj.get("scaleY", 1)

    for obj in canvas_json.get("objects") or []:
        kind = obj.get("type")
        if kind == "rect":
            width, height = _rendered_size(obj)
            bbox = [
                int(obj["left"]),
                int(obj["top"]),
                int(obj["left"] + width),
                int(obj["top"] + height),
            ]
        elif kind in ("circle", "point"):
            width, height = _rendered_size(obj)
            x = int(obj["left"] + width / 2)
            y = int(obj["top"] + height / 2)
            label = 1 if obj.get("stroke") == "green" else 0
            clicks.append((x, y, label))
    return bbox, clicks
# Column order of the downloadable per-iteration CSV (matches the record dict keys).
_CSV_COLUMNS = [
    "iter",
    "n_clicks_used",
    "prompt_ms",
    "onnx_forward_ms",
    "total_ms",
    "dice",
    "pos_hit_rate",
    "neg_ok_rate",
    "bbox_outside_ratio",
    "pred_area_ratio",
]


def _records_to_csv(records):
    """Render per-iteration records as CSV text: a header line, then one line per iteration.

    Timings use 3 decimals, dice and hit rates 4, and the two ratios 6. Metrics that are None
    (no ground truth, no clicks of that sign, or no box) become empty cells.
    """

    def opt(value, spec):
        return "" if value is None else format(value, spec)

    buf = io.StringIO()
    buf.write(",".join(_CSV_COLUMNS) + "\n")
    for r in records:
        cells = [
            str(r["iter"]),
            str(r["n_clicks_used"]),
            format(r["prompt_ms"], ".3f"),
            format(r["onnx_forward_ms"], ".3f"),
            format(r["total_ms"], ".3f"),
            opt(r["dice"], ".4f"),
            opt(r["pos_hit_rate"], ".4f"),
            opt(r["neg_ok_rate"], ".4f"),
            opt(r["bbox_outside_ratio"], ".6f"),
            format(r["pred_area_ratio"], ".6f"),
        ]
        buf.write(",".join(cells) + "\n")
    return buf.getvalue()


if st.button("Run inference", type="primary"):
    bbox_display, clicks_display = parse_canvas(canvas.json_data)
    # Prompts were drawn on the DISPLAY_SIZE canvas; rescale them into model pixel space.
    bbox_model = _scale_xyxy(bbox_display, DISPLAY_SIZE, MODEL_SIZE)
    clicks_model = _scale_clicks(clicks_display, DISPLAY_SIZE, MODEL_SIZE)
    if len(clicks_model) == 0 and bbox_model is None:
        st.warning("Please draw a box or add clicks before running.")
        st.stop()

    # Replay the clicks as growing prefixes (at most max_clicks of them), one model call per
    # prefix. Box-only input runs a single iteration with no clicks.
    k = min(int(max_clicks), len(clicks_model)) if clicks_model else 1
    if clicks_model:
        click_prefixes = [clicks_model[:i] for i in range(1, k + 1)]
    else:
        click_prefixes = [[]]

    model_hw = (MODEL_SIZE[1], MODEL_SIZE[0])
    # The ground truth is the same for every iteration, so binarize it once.
    gt01 = None
    if gt_mask_model is not None:
        gt01 = (np.asarray(gt_mask_model, dtype=np.uint8) > 127).astype(np.uint8)

    records = []
    masks_display = []
    final_mask_display = None
    for it, clicks_it in enumerate(click_prefixes, start=1):
        # prompt_ms covers building the input tensor; onnx_forward_ms covers session.run only.
        t_prompt0 = time.perf_counter()
        x = _build_model_input(img_model, bbox_model, clicks_it, model_size_hw=model_hw)
        t_prompt1 = time.perf_counter()
        t_fwd0 = time.perf_counter()
        logits = run_onnx(session, input_name, x)
        t_fwd1 = time.perf_counter()

        # NOTE(review): assumes the model returns logits shaped (1, 1, H, W) — confirm with the export.
        prob = sigmoid(logits[0, 0])
        pred01 = (prob > 0.5).astype(np.uint8)
        pred_pil_model = Image.fromarray((pred01 * 255).astype(np.uint8))
        pred_display = np.asarray(_resize_mask(pred_pil_model, DISPLAY_SIZE), dtype=np.uint8)
        pred_display01 = (pred_display > 127).astype(np.uint8)
        masks_display.append(pred_display01)
        final_mask_display = pred_display01

        dsc = dice(pred01, gt01) if gt01 is not None else None
        cm = constraint_metrics(pred01, bbox_model, clicks_it, shape_hw=model_hw)
        records.append(
            {
                "iter": it,
                "n_clicks_used": len(clicks_it),
                "prompt_ms": (t_prompt1 - t_prompt0) * 1000.0,
                "onnx_forward_ms": (t_fwd1 - t_fwd0) * 1000.0,
                "total_ms": (t_fwd1 - t_prompt0) * 1000.0,
                "dice": dsc,
                "pos_hit_rate": cm["pos_hit_rate"],
                "neg_ok_rate": cm["neg_ok_rate"],
                "bbox_outside_ratio": cm["bbox_outside_ratio"],
                "pred_area_ratio": cm["pred_area_ratio"],
            }
        )

    st.divider()
    st.subheader("Results")
    left, right = st.columns([2, 1])
    with left:
        cols = st.columns(3 if gt_mask_model is not None else 2)
        cols[0].image(img_display, caption="Input", use_column_width=True)
        cols[1].image(final_mask_display * 255, caption="Prediction (final)", clamp=True, use_column_width=True)
        if gt_mask_model is not None:
            cols[2].image(gt_mask_display, caption="Ground truth", clamp=True, use_column_width=True)
    with right:
        st.write("Per-click iteration metrics:")
        st.dataframe(records, use_container_width=True)
        st.download_button(
            label="Download per-iter CSV",
            data=_records_to_csv(records).encode("utf-8"),
            file_name=f"per_iter_{model_key.lower()}_{MODEL_SIZE[0]}.csv",
            mime="text/csv",
        )
    if show_intermediate:
        st.subheader("Intermediate masks (per iter)")
        cols = st.columns(len(masks_display))
        for i, (col, m) in enumerate(zip(cols, masks_display), start=1):
            col.image(m * 255, caption=f"Iter {i}", clamp=True, use_column_width=True)