# iU-RWKV-demo / src/streamlit_app.py
# Last commit: 07d6237 by hbyecoding
# "Add demo assets + fallback to upload when missing"
import io
import os
import time
import pathlib
import numpy as np
import onnxruntime as ort
import streamlit as st
from huggingface_hub import hf_hub_download
from PIL import Image
from streamlit_drawable_canvas import st_canvas
# Hugging Face repositories the ONNX weights may be downloaded from; both can be
# overridden through environment variables. HF_TOKEN is optional (None when unset).
SPACE_REPO_ID = os.getenv("HF_SPACE_REPO_ID", "hbyecoding/iU-RWKV-demo")
MODEL_REPO_ID = os.getenv("HF_MODEL_REPO_ID", "hbyecoding/iU-RWKV")
HF_TOKEN = os.getenv("HF_TOKEN")

# Sizes are (width, height), the order PIL's Image.resize expects.
DISPLAY_SIZE = (256, 256)  # drawing canvas / visualisation resolution
MODEL_SIZE = (192, 192)  # network input resolution

ASSETS_ROOT = pathlib.Path("hf_demo_assets")  # <root>/<subdir>/images and <root>/<subdir>/masks
MODEL_DIR = pathlib.Path("models")  # local ONNX files, checked before any download

# Per-dataset configuration: demo-asset subfolder and ONNX weights file name.
MODELS = dict(
    BUSI=dict(assets_subdir="BUSI", onnx_filename="iu_rwkv_busi_192.onnx"),
    POLY=dict(assets_subdir="POLY", onnx_filename="iu_rwkv_poly_192.onnx"),
    ISIC18=dict(assets_subdir="ISIC18", onnx_filename="iu_rwkv_isic18_192.onnx"),
)
def _resize_image_rgb(pil_img, size):
    """Return an RGB copy of *pil_img* resized to *size* (width, height) with bilinear filtering."""
    rgb = pil_img.convert("RGB")
    return rgb.resize(size, resample=Image.Resampling.BILINEAR)
def _resize_mask(pil_img, size):
    """Return a single-channel ("L") copy of *pil_img* resized to *size* (width, height).

    Nearest-neighbour resampling is used so no new intermediate pixel values are
    blended in, which keeps label masks crisp.
    """
    gray = pil_img.convert("L")
    return gray.resize(size, resample=Image.Resampling.NEAREST)
def _to_gray01(pil_img):
arr = np.asarray(pil_img.convert("L"), dtype=np.float32) / 255.0
return arr
def _bbox_channel(box, shape_hw):
h, w = shape_hw
ch = np.zeros((h, w), dtype=np.float32)
if box is None:
return ch
x0, y0, x1, y1 = box
x0 = int(np.clip(x0, 0, w))
x1 = int(np.clip(x1, 0, w))
y0 = int(np.clip(y0, 0, h))
y1 = int(np.clip(y1, 0, h))
if x1 > x0 and y1 > y0:
ch[y0:y1, x0:x1] = 1.0
return ch
def _click_channels(clicks, shape_hw):
h, w = shape_hw
pos = np.zeros((h, w), dtype=np.float32)
neg = np.zeros((h, w), dtype=np.float32)
if not clicks:
return pos, neg
for x, y, label in clicks:
x = int(np.clip(x, 0, w - 1))
y = int(np.clip(y, 0, h - 1))
if int(label) == 1:
pos[y, x] = 1.0
else:
neg[y, x] = 1.0
return pos, neg
def _build_model_input(pil_img_resized, box_xyxy, clicks_xy, model_size_hw):
    """Assemble the float32 network input of shape (1, 5, h, w).

    Channel layout:
        0  grayscale image scaled to [0, 1]
        1  box plane (1.0 inside the box)
        2  positive-click plane
        3  negative-click plane
        4  mask-input plane, left all zeros here
    Box and click coordinates are used as-is, so they must already be in the
    (h, w) model pixel space given by *model_size_hw*.
    """
    h, w = model_size_hw
    gray = _to_gray01(pil_img_resized)
    if gray.shape != (h, w):
        # Fallback for an image that is not at model size: resample it to (w, h)
        # through the nearest-neighbour grayscale path.
        gray = np.asarray(_resize_mask(pil_img_resized, (w, h)), dtype=np.float32) / 255.0
    box_plane = _bbox_channel(box_xyxy, (h, w))
    pos_plane, neg_plane = _click_channels(clicks_xy, (h, w))

    batch = np.zeros((1, 5, h, w), dtype=np.float32)
    batch[0, 0] = gray
    batch[0, 1] = box_plane
    batch[0, 2] = pos_plane
    batch[0, 3] = neg_plane
    # batch[0, 4] (mask input) intentionally stays zero.
    return batch
def _scale_xyxy(box, src_size, dst_size):
if box is None:
return None
sx = dst_size[0] / src_size[0]
sy = dst_size[1] / src_size[1]
x0, y0, x1, y1 = box
return [int(round(x0 * sx)), int(round(y0 * sy)), int(round(x1 * sx)), int(round(y1 * sy))]
def _scale_clicks(clicks, src_size, dst_size):
if not clicks:
return []
sx = dst_size[0] / src_size[0]
sy = dst_size[1] / src_size[1]
out = []
for x, y, label in clicks:
out.append((int(round(x * sx)), int(round(y * sy)), int(label)))
return out
def _list_demo_images(assets_subdir):
    """Return demo image paths under ASSETS_ROOT/<assets_subdir>/images, sorted by file name.

    Only .png, .jpg, .jpeg and .bmp files are listed (suffix compared
    case-insensitively). A missing directory yields an empty list.
    """
    image_dir = ASSETS_ROOT / assets_subdir / "images"
    if not image_dir.exists():
        return []
    allowed = {".png", ".jpg", ".jpeg", ".bmp"}
    found = [path for path in image_dir.iterdir() if path.suffix.lower() in allowed]
    found.sort(key=lambda path: path.name)
    return found
def _find_demo_mask(assets_subdir, stem):
    """Return the file in ASSETS_ROOT/<assets_subdir>/masks whose stem equals *stem*.

    Returns None when the masks directory is missing or nothing matches.
    Entries are scanned in file-name order, so the result is deterministic when
    several files share a stem (e.g. ``a.png`` and ``a.jpg``). Directories are
    skipped because the caller opens the result as an image.
    """
    mask_dir = ASSETS_ROOT / assets_subdir / "masks"
    if not mask_dir.exists():
        return None
    # Path.iterdir() yields entries in arbitrary order; sort for reproducibility.
    for path in sorted(mask_dir.iterdir(), key=lambda p: p.name):
        if path.stem == stem and path.is_file():
            return path
    return None
@st.cache_resource
def get_ort_session(onnx_filename, num_threads):
    """Create a CPU ONNX Runtime session for *onnx_filename* (cached by Streamlit).

    The weights file is resolved in this order:
      1. MODEL_DIR/<onnx_filename> on local disk;
      2. the same relative path downloaded from the Space repo (SPACE_REPO_ID);
      3. if that download raises, the same path from the model repo (MODEL_REPO_ID).
    An exception from the last attempt propagates to the caller.

    Returns:
        (session, input_name, source, model_path), where ``input_name`` is the
        name of the model's first input and ``source`` is "local", "space" or
        "model".
    """
    candidate = MODEL_DIR / onnx_filename
    if candidate.exists():
        resolved_path, origin = str(candidate), "local"
    else:
        repo_filename = candidate.as_posix()
        try:
            resolved_path = hf_hub_download(
                repo_id=SPACE_REPO_ID,
                repo_type="space",
                filename=repo_filename,
                token=HF_TOKEN,
            )
            origin = "space"
        except Exception:
            # Any exception from the Space download falls back to the model repo.
            resolved_path = hf_hub_download(
                repo_id=MODEL_REPO_ID,
                repo_type="model",
                filename=repo_filename,
                token=HF_TOKEN,
            )
            origin = "model"

    options = ort.SessionOptions()
    options.intra_op_num_threads = int(num_threads)
    options.inter_op_num_threads = 1  # inter-op parallelism off; threading comes from intra-op
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    onnx_session = ort.InferenceSession(
        resolved_path, sess_options=options, providers=["CPUExecutionProvider"]
    )
    first_input = onnx_session.get_inputs()[0].name
    return onnx_session, first_input, origin, resolved_path
def run_onnx(session, input_name, x):
    """Feed *x* to *input_name* of *session* and return the first output."""
    outputs = session.run(None, {input_name: x})  # None: request every output
    return outputs[0]
def sigmoid(x):
    """Element-wise logistic function 1 / (1 + exp(-x)), computed without overflow.

    The naive form overflows ``exp(-x)`` for strongly negative inputs (below
    about -88 in float32), which emits RuntimeWarnings on every background-heavy
    logit map. Here exp() is only applied to -|x| <= 0, so it stays in (0, 1].
    The array dtype (e.g. float32) is preserved; scalar input returns a NumPy scalar.
    """
    x = np.asarray(x)
    e = np.exp(-np.abs(x))
    # For x >= 0: 1 / (1 + e^-x).  For x < 0: e^x / (1 + e^x) (same value, no overflow).
    out = np.where(x >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
    return out[()]
def dice(pred01, gt01, eps=1e-7):
    """Dice coefficient 2|P∩G| / (|P| + |G|) between two 0/1 mask arrays.

    *eps* is added to both numerator and denominator, so two empty masks score
    1.0 instead of dividing by zero. Returns a Python float.
    """
    p = pred01.astype(np.float32)
    g = gt01.astype(np.float32)
    overlap = (p * g).sum()
    total = p.sum() + g.sum()
    return float((2.0 * overlap + eps) / (total + eps))
def constraint_metrics(pred01, box_xyxy, clicks_xy, shape_hw):
    """Measure how well a binary prediction satisfies the user's prompts.

    Args:
        pred01: (h, w) array of 0/1 predictions.
        box_xyxy: (x0, y0, x1, y1) box in pred01's pixel space, or None.
        clicks_xy: iterable of (x, y, label) clicks, or None / empty. Label 1 is a
            positive click and any other label a negative one, the same rule
            ``_click_channels`` uses when encoding clicks for the model.
        shape_hw: (h, w) of pred01.

    Returns:
        dict with keys:
          * ``pos_hit_rate``: fraction of positive clicks on predicted foreground.
          * ``neg_ok_rate``: fraction of negative clicks on predicted background.
          * ``bbox_outside_ratio``: fraction of predicted foreground outside the
            clipped box (0.0 for an empty prediction).
          * ``pred_area_ratio``: predicted foreground fraction of the image.
        The first three are None when the corresponding prompt is absent.
    """
    h, w = shape_hw
    clicks = list(clicks_xy or [])
    pos = [(x, y) for x, y, lab in clicks if int(lab) == 1]
    neg = [(x, y) for x, y, lab in clicks if int(lab) != 1]

    def _values_at(points):
        # Clip so out-of-range clicks sample the nearest border pixel.
        return [pred01[int(np.clip(y, 0, h - 1)), int(np.clip(x, 0, w - 1))] for x, y in points]

    pos_hit = float(np.mean([int(v == 1) for v in _values_at(pos)])) if pos else None
    neg_ok = float(np.mean([int(v == 0) for v in _values_at(neg)])) if neg else None

    outside_ratio = None
    if box_xyxy is not None:
        x0, y0, x1, y1 = box_xyxy
        x0 = int(np.clip(x0, 0, w))
        x1 = int(np.clip(x1, 0, w))
        y0 = int(np.clip(y0, 0, h))
        y1 = int(np.clip(y1, 0, h))
        bbox_mask = np.zeros((h, w), dtype=np.uint8)
        if x1 > x0 and y1 > y0:
            bbox_mask[y0:y1, x0:x1] = 1
        pred_sum = float(np.sum(pred01))
        if pred_sum > 0:
            outside_ratio = float(np.sum(pred01 * (1 - bbox_mask)) / pred_sum)
        else:
            outside_ratio = 0.0

    pred_area_ratio = float(np.sum(pred01)) / float(h * w)
    return {
        "pos_hit_rate": pos_hit,
        "neg_ok_rate": neg_ok,
        "bbox_outside_ratio": outside_ratio,
        "pred_area_ratio": pred_area_ratio,
    }
st.set_page_config(page_title="iU-RWKV Interactive Segmentation (ONNX)", layout="wide")
st.title("iU-RWKV Interactive Segmentation Demo (Hugging Face Spaces)")
st.markdown(
    "This Space runs iU-RWKV as an **ONNX Runtime** model on CPU. "
    "We report **per-click iteration latency** (prompt update + ONNX forward) and **interaction-consistency metrics** "
    "(how well the predicted mask satisfies your clicks/box constraints) to match clinical interaction experience."
)

with st.sidebar:
    st.header("Settings")
    model_key = st.selectbox("Dataset / Model", list(MODELS.keys()))
    num_threads = st.slider("CPU threads (intra-op)", 1, 16, 8)
    max_clicks = st.slider("Max clicks to replay (K)", 1, 10, 5)
    show_intermediate = st.checkbox("Show per-iter masks", value=False)
    image_source = st.radio("Image source", ["Demo assets", "Upload"], index=0)

assets_subdir = MODELS[model_key]["assets_subdir"]
onnx_filename = MODELS[model_key]["onnx_filename"]
session, input_name, model_source, model_path = get_ort_session(onnx_filename, num_threads)
with st.sidebar:
    st.caption(f"Model file: {onnx_filename}")
    st.caption(f"Loaded from: {model_source}")

# Fall back to uploading when the selected dataset ships no demo images.
demo_images = _list_demo_images(assets_subdir) if image_source == "Demo assets" else []
if image_source == "Demo assets" and not demo_images:
    with st.sidebar:
        st.warning(f"No demo assets found for {model_key}. Please upload an image instead.")
    image_source = "Upload"

# Ground truth is optional: both GT variables stay None unless a demo mask exists.
gt_mask_model = None
gt_mask_display = None
img_display = None
img_model = None
if image_source == "Demo assets":
    # demo_images is non-empty here: the empty case was redirected to "Upload" above.
    selected = st.sidebar.selectbox("Select demo image", demo_images, format_func=lambda p: p.name)
    with Image.open(selected) as pil_img:
        img_display = _resize_image_rgb(pil_img, DISPLAY_SIZE)
        img_model = _resize_image_rgb(pil_img, MODEL_SIZE)
    mask_path = _find_demo_mask(assets_subdir, selected.stem)
    if mask_path is not None:
        # Open the mask once and derive both resolutions from it.
        with Image.open(mask_path) as mask_img:
            gt_mask_display = _resize_mask(mask_img, DISPLAY_SIZE)
            gt_mask_model = _resize_mask(mask_img, MODEL_SIZE)
else:
    uploaded = st.sidebar.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "bmp"])
    if uploaded is None:
        st.info("Upload an image to start.")
        st.stop()
    # The uploaded buffer is owned by Streamlit, so it is not closed here.
    pil_img = Image.open(uploaded)
    img_display = _resize_image_rgb(pil_img, DISPLAY_SIZE)
    img_model = _resize_image_rgb(pil_img, MODEL_SIZE)
st.subheader("Interactive workspace")
st.write("Draw **one box** (blue) and/or add multiple **points** (green=positive, red=negative), then run inference.")
col_tools, col_canvas = st.columns([1, 3])

# Each tool maps to (st_canvas drawing mode, stroke colour). parse_canvas reads
# the stroke colour back to tell positive (green) from negative clicks.
_TOOL_STYLES = {
    "Box": ("rect", "blue"),
    "Positive Click": ("point", "green"),
    "Negative Click": ("point", "red"),
}

with col_tools:
    interaction_mode = st.radio("Tool", list(_TOOL_STYLES))
    drawing_mode, stroke_color = _TOOL_STYLES[interaction_mode]
    st.caption("Tip: use 1 box to localize, then refine with clicks.")

with col_canvas:
    canvas = st_canvas(
        fill_color="rgba(255, 165, 0, 0.2)",
        stroke_width=3,
        stroke_color=stroke_color,
        background_image=img_display,
        update_streamlit=True,
        width=DISPLAY_SIZE[0],
        height=DISPLAY_SIZE[1],
        drawing_mode=drawing_mode,
        key="canvas",
    )
def parse_canvas(canvas_json):
    """Extract the box prompt and click prompts from st_canvas JSON output.

    Args:
        canvas_json: the canvas ``json_data`` dict (with an ``objects`` list),
            or None when nothing is available yet.

    Returns:
        (bbox, clicks):
          * ``bbox`` is [x_min, y_min, x_max, y_max] in canvas pixels, taken from
            the last rectangle (None if there is none).
          * ``clicks`` is a list of (x, y, label) at each point's centre, with
            label 1 for a green stroke and 0 otherwise.
        Objects of any other type are ignored.
    """
    bbox = None
    clicks = []
    if canvas_json is None:
        return bbox, clicks
    # The key can be present but null; treat that like "no objects".
    for obj in canvas_json.get("objects") or []:
        kind = obj.get("type")
        if kind == "rect":
            # Fabric.js stores the unscaled width/height plus scaleX/scaleY
            # (default 1); the on-canvas extent is their product.
            x_a = obj["left"]
            x_b = x_a + obj["width"] * obj.get("scaleX", 1)
            y_a = obj["top"]
            y_b = y_a + obj["height"] * obj.get("scaleY", 1)
            # Normalise corner order in case a negative extent is reported.
            bbox = [int(min(x_a, x_b)), int(min(y_a, y_b)), int(max(x_a, x_b)), int(max(y_a, y_b))]
        elif kind in ["circle", "point"]:
            x = int(obj["left"] + obj["width"] * obj.get("scaleX", 1) / 2)
            y = int(obj["top"] + obj["height"] * obj.get("scaleY", 1) / 2)
            label = 1 if obj.get("stroke") == "green" else 0
            clicks.append((x, y, label))
    return bbox, clicks
def _format_optional(value, spec):
    """Format *value* with format spec *spec*, or return "" (an empty CSV cell) for None."""
    return "" if value is None else format(value, spec)


def _per_iter_csv(records):
    """Serialise per-iteration metric records to CSV text with a header row.

    Timings use 3 decimals, Dice and click rates 4, and box-outside and area
    ratios 6. Metrics that do not apply (None) become empty cells.
    """
    header = [
        "iter",
        "n_clicks_used",
        "prompt_ms",
        "onnx_forward_ms",
        "total_ms",
        "dice",
        "pos_hit_rate",
        "neg_ok_rate",
        "bbox_outside_ratio",
        "pred_area_ratio",
    ]
    buf = io.StringIO()
    buf.write(",".join(header) + "\n")
    for r in records:
        row = [
            str(r["iter"]),
            str(r["n_clicks_used"]),
            f"{r['prompt_ms']:.3f}",
            f"{r['onnx_forward_ms']:.3f}",
            f"{r['total_ms']:.3f}",
            _format_optional(r["dice"], ".4f"),
            _format_optional(r["pos_hit_rate"], ".4f"),
            _format_optional(r["neg_ok_rate"], ".4f"),
            _format_optional(r["bbox_outside_ratio"], ".6f"),
            f"{r['pred_area_ratio']:.6f}",
        ]
        buf.write(",".join(row) + "\n")
    return buf.getvalue()


if st.button("Run inference", type="primary"):
    bbox_display, clicks_display = parse_canvas(canvas.json_data)
    # Prompts are drawn on the DISPLAY_SIZE canvas; the network runs at MODEL_SIZE.
    bbox_model = _scale_xyxy(bbox_display, DISPLAY_SIZE, MODEL_SIZE)
    clicks_model = _scale_clicks(clicks_display, DISPLAY_SIZE, MODEL_SIZE)
    if not clicks_model and bbox_model is None:
        st.warning("Please draw a box or add clicks before running.")
        st.stop()

    # Replay the session click by click: iteration i uses the first i clicks
    # (plus the box, if any). A box-only prompt is a single iteration.
    if clicks_model:
        k = min(int(max_clicks), len(clicks_model))
        click_prefixes = [clicks_model[:i] for i in range(1, k + 1)]
        if k < len(clicks_model):
            st.info(
                f"Only the first {k} of {len(clicks_model)} clicks were replayed; "
                "increase 'Max clicks to replay (K)' in the sidebar to use them all."
            )
    else:
        click_prefixes = [[]]

    model_hw = (MODEL_SIZE[1], MODEL_SIZE[0])
    # The binarised ground truth does not change between iterations; compute it once.
    gt01 = None
    if gt_mask_model is not None:
        gt01 = (np.asarray(gt_mask_model, dtype=np.uint8) > 127).astype(np.uint8)

    records = []
    masks_display = []
    for it, clicks_it in enumerate(click_prefixes, start=1):
        # prompt_ms covers building the input tensor; onnx_forward_ms the session run.
        t_prompt0 = time.perf_counter()
        x = _build_model_input(img_model, bbox_model, clicks_it, model_size_hw=model_hw)
        t_prompt1 = time.perf_counter()
        t_fwd0 = time.perf_counter()
        logits = run_onnx(session, input_name, x)
        t_fwd1 = time.perf_counter()

        prob = sigmoid(logits[0, 0])
        pred01 = (prob > 0.5).astype(np.uint8)
        pred_pil_model = Image.fromarray((pred01 * 255).astype(np.uint8))
        pred_display = np.asarray(_resize_mask(pred_pil_model, DISPLAY_SIZE), dtype=np.uint8)
        masks_display.append((pred_display > 127).astype(np.uint8))

        dsc = dice(pred01, gt01) if gt01 is not None else None
        cm = constraint_metrics(pred01, bbox_model, clicks_it, shape_hw=model_hw)
        records.append(
            {
                "iter": it,
                "n_clicks_used": len(clicks_it),
                "prompt_ms": (t_prompt1 - t_prompt0) * 1000.0,
                "onnx_forward_ms": (t_fwd1 - t_fwd0) * 1000.0,
                "total_ms": (t_fwd1 - t_prompt0) * 1000.0,
                "dice": dsc,
                "pos_hit_rate": cm["pos_hit_rate"],
                "neg_ok_rate": cm["neg_ok_rate"],
                "bbox_outside_ratio": cm["bbox_outside_ratio"],
                "pred_area_ratio": cm["pred_area_ratio"],
            }
        )
    final_mask_display = masks_display[-1]

    st.divider()
    st.subheader("Results")
    left, right = st.columns([2, 1])
    with left:
        cols = st.columns(3 if gt_mask_model is not None else 2)
        cols[0].image(img_display, caption="Input", use_column_width=True)
        cols[1].image(final_mask_display * 255, caption="Prediction (final)", clamp=True, use_column_width=True)
        if gt_mask_model is not None:
            cols[2].image(gt_mask_display, caption="Ground truth", clamp=True, use_column_width=True)
    with right:
        st.write("Per-click iteration metrics:")
        st.dataframe(records, use_container_width=True)
        st.download_button(
            label="Download per-iter CSV",
            data=_per_iter_csv(records).encode("utf-8"),
            file_name=f"per_iter_{model_key.lower()}_{MODEL_SIZE[0]}.csv",
            mime="text/csv",
        )
    if show_intermediate:
        st.subheader("Intermediate masks (per iter)")
        cols = st.columns(len(masks_display))
        for it, (col, m) in enumerate(zip(cols, masks_display), start=1):
            col.image(m * 255, caption=f"Iter {it}", clamp=True, use_column_width=True)