# iU-RWKV-demo / app.py
# (Repository page header text captured with this file:
#  "hbyecoding's picture", "Upload 143 files", "b2c5353 verified")
import os
import pathlib
import sys
import tempfile

import numpy as np
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from streamlit_drawable_canvas import st_canvas
# Add project root to path to import modules
# (the current working directory is treated as the project root here).
sys.path.append(os.getcwd())
try:
    from demo_inference import load_model, predict
except ImportError:
    # Handle if running from different directory structure:
    # also search the "ScribblePrompt" subdirectory of the working directory
    # and retry the same import. A second ImportError is not caught.
    sys.path.append(os.path.join(os.getcwd(), "ScribblePrompt"))
    from demo_inference import load_model, predict
# Configuration
# Repository id passed to hf_hub_download when a checkpoint is not found locally.
REPO_ID = "hbyecoding/iU-RWKV" # User needs to replace this
# One entry per choice offered in the sidebar. Each entry holds:
#   "dataset"         - sub-directory of ASSETS_ROOT with the demo images/masks
#   "checkpoint_name" - checkpoint file name (looked up locally, then on the Hub)
#   "config"          - config path handed to load_model
MODELS_CONFIG = {
    "ISIC18": {
        "dataset": "ISIC18",
        "checkpoint_name": "isic18_max-val_id-dice_score.pt",
        "config": "configs/train_urwkv_isic18.yaml"
    },
    "POLY": {
        "dataset": "POLY",
        "checkpoint_name": "poly_max-val_id-dice_score.pt",
        "config": "configs/train_urwkv_poly.yaml"
    },
    "BUSI (Custom)": {
        "dataset": "BUSI",
        "checkpoint_name": "custom_max-val_id-dice_score.pt",
        "config": "configs/train_urwkv_custom.yaml"
    }
}
# Assets are now relative to the app directory
# (a relative path, so it resolves against the current working directory).
ASSETS_ROOT = pathlib.Path("hf_demo_assets")
@st.cache_resource
def get_model(config_path, checkpoint_name):
    """Locate a checkpoint and return the model loaded from it.

    The checkpoint is taken from ``checkpoints/<checkpoint_name>`` when that
    path exists (handy for local testing); otherwise it is downloaded from the
    ``REPO_ID`` repository with ``hf_hub_download``. If the download fails,
    two error messages are shown in the app and the script run is stopped.

    Args:
        config_path: Config path forwarded to ``load_model``.
        checkpoint_name: Checkpoint file name, used both for the local lookup
            and as the file name requested from the Hub.

    Returns:
        Whatever ``load_model`` returns, loaded on "cuda" when available and
        on "cpu" otherwise.
    """
    weights_path = os.path.join("checkpoints", checkpoint_name)
    if not os.path.exists(weights_path):
        # No local copy: fall back to the Hugging Face Hub.
        try:
            weights_path = hf_hub_download(repo_id=REPO_ID, filename=checkpoint_name)
        except Exception as err:
            st.error(f"Failed to download model from Hugging Face: {err}")
            st.error("Please ensure you have uploaded the models and updated REPO_ID in the code.")
            st.stop()

    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    return load_model(config_path, weights_path, device=target_device)
def calculate_dice(pred_mask, gt_mask):
    """Return the Dice coefficient between a prediction and a reference mask.

    Args:
        pred_mask: Array-like of prediction scores; values strictly above 0.5
            are treated as foreground.
        gt_mask: Array-like of reference labels; values strictly above 0 are
            treated as foreground.

    Returns:
        float: ``2 * |P & G| / (|P| + |G|)`` as a plain Python float, or 1.0
        when both masks contain no foreground at all.
    """
    pred = np.asarray(pred_mask) > 0.5
    gt = np.asarray(gt_mask) > 0

    # |P| + |G|; this is zero only when *both* masks are empty, so a single
    # check covers the degenerate case.
    total = int(np.count_nonzero(pred)) + int(np.count_nonzero(gt))
    if total == 0:
        return 1.0

    overlap = int(np.count_nonzero(pred & gt))
    return 2.0 * overlap / total
def get_image_list(dataset_name):
    """List the demo images available for a dataset.

    Scans ``ASSETS_ROOT / dataset_name / "images"`` for entries whose suffix
    (compared case-insensitively) is .png, .jpg, .jpeg or .bmp.

    Args:
        dataset_name: Name of the dataset sub-directory under ``ASSETS_ROOT``.

    Returns:
        A sorted list of ``(path_as_str, stem, suffix)`` tuples; an empty list
        when the images directory does not exist.
    """
    images_root = ASSETS_ROOT / dataset_name / "images"
    if not images_root.exists():
        return []

    accepted = [".png", ".jpg", ".jpeg", ".bmp"]
    entries = [
        (str(entry), entry.stem, entry.suffix)
        for entry in images_root.iterdir()
        if entry.suffix.lower() in accepted
    ]
    entries.sort()
    return entries
def get_mask_path(dataset_name, name):
    """Find the mask file whose stem equals ``name`` for a dataset.

    Looks inside ``ASSETS_ROOT / dataset_name / "masks"`` and returns the
    first entry yielded by the directory listing whose stem matches; the file
    extension is not considered.

    Args:
        dataset_name: Name of the dataset sub-directory under ``ASSETS_ROOT``.
        name: File stem (name without extension) to look for.

    Returns:
        The matching path as a string, or ``None`` when the masks directory is
        missing or holds no entry with that stem.
    """
    masks_root = ASSETS_ROOT / dataset_name / "masks"
    if masks_root.exists():
        hit = next(
            (entry for entry in masks_root.iterdir() if entry.stem == name),
            None,
        )
        if hit is not None:
            return str(hit)
    return None
# Page configuration, title and introductory text.
st.set_page_config(page_title="iU-RWKV Interactive Segmentation", layout="wide")
st.title("iU-RWKV Interactive Segmentation Demo")
st.markdown("This demo showcases the interactive segmentation capabilities of the **iU-RWKV** model.")
st.sidebar.header("Settings")
# Model Selection
# The options are the keys of MODELS_CONFIG; the chosen entry also fixes the
# dataset whose demo images are offered further down.
selected_model_name = st.sidebar.selectbox("Select Model & Dataset", list(MODELS_CONFIG.keys()))
model_info = MODELS_CONFIG[selected_model_name]
dataset_name = model_info["dataset"]
# Load specific model
exp = get_model(model_info["config"], model_info["checkpoint_name"])
# Without any demo image for the dataset there is nothing to segment:
# show an error and stop this script run.
image_list_info = get_image_list(dataset_name)
if not image_list_info:
    st.error(f"No demo images found for {dataset_name}. Please check hf_demo_assets directory.")
    st.stop()
# Image Selection
# Each option is a (path, stem, suffix) tuple from get_image_list; only the
# file's base name is displayed in the dropdown.
selected_image_info = st.sidebar.selectbox("Select Image", image_list_info, format_func=lambda x: os.path.basename(x[0]))
selected_image_path, img_name, img_ext = selected_image_info
# Try to find GT Mask
# The mask is matched by the image's stem; gt_mask stays None when no mask
# file is found, and the Dice score / GT views are skipped later on.
gt_mask_path = get_mask_path(dataset_name, img_name)
gt_mask = None
if gt_mask_path:
    gt_mask_img = Image.open(gt_mask_path).convert("L")
    # We will resize GT mask to 256x256 to match prediction
    # (nearest-neighbour resampling, so no intermediate grey values are created).
    # NOTE(review): this size is hard-coded separately from DISPLAY_SIZE below;
    # presumably both must match the prediction size - confirm against predict().
    gt_mask = np.array(gt_mask_img.resize((256, 256), Image.NEAREST))
# Load Image
# The RGB image is resized to DISPLAY_SIZE; the resized copy is what gets
# drawn on, fed to the model and shown in the results.
image = Image.open(selected_image_path).convert("RGB")
DISPLAY_SIZE = (256, 256)
image_resized = image.resize(DISPLAY_SIZE)
# Interactive workspace: a narrow tools column next to a wider canvas column.
st.header("Interactive Workspace")
st.write("Draw a **Box (Rect)** or **Click (Point)** on the image to guide the segmentation.")
col_tools, col_canvas = st.columns([1, 3])
with col_tools:
    # Interaction Mode
    interaction_mode = st.radio("Interaction Mode", ["Box", "Positive Click", "Negative Click"])
    # "Box" draws rectangles; both click modes draw points.
    drawing_mode = "rect" if interaction_mode == "Box" else "point"
    # Stroke colour encodes the prompt type: green = positive click,
    # red = negative click, blue = box. The prediction step reads the colour
    # back from each drawn object to decide the click label.
    stroke_color = "green" if interaction_mode == "Positive Click" else "red"
    if interaction_mode == "Box": stroke_color = "blue"
    st.info("Tip: Draw a box around the object, or click green points on the object and red points on the background.")
with col_canvas:
    # Canvas
    # Sized exactly like the resized image used as background, so drawn
    # coordinates are in the pixel space of image_resized.
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=3,
        stroke_color=stroke_color,
        background_image=image_resized,
        update_streamlit=True,
        height=DISPLAY_SIZE[1],
        width=DISPLAY_SIZE[0],
        drawing_mode=drawing_mode,
        key="canvas",
    )
if st.button("Run Prediction", type="primary"):
    # Translate the shapes drawn on the canvas into prompts for the model.
    bbox = None
    clicks = []
    if canvas_result.json_data is not None:
        objects = canvas_result.json_data["objects"]
        for obj in objects:
            if obj["type"] == "rect":
                # Box as [x_min, y_min, x_max, y_max]; when several boxes
                # were drawn, the last one in the list is used.
                x_min = int(obj["left"])
                y_min = int(obj["top"])
                x_max = int(obj["left"] + obj["width"])
                y_max = int(obj["top"] + obj["height"])
                bbox = [x_min, y_min, x_max, y_max]
            elif obj["type"] in ["circle", "point"]:
                # The click position is the centre of the drawn marker.
                x = int(obj["left"] + obj["width"] / 2)
                y = int(obj["top"] + obj["height"] / 2)
                # Green strokes are positive clicks; any other colour is
                # treated as a negative click.
                label = 1 if obj["stroke"] == "green" else 0
                clicks.append((x, y, label))

    with st.spinner("Running iU-RWKV Inference..."):
        # predict() is given a file path, so the resized image is written to
        # disk first. A private temporary directory is used instead of a
        # fixed file name in the working directory: concurrent sessions can
        # no longer overwrite each other's input, and the file is removed as
        # soon as inference is done.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_img_path = os.path.join(temp_dir, "temp_input.png")
            image_resized.save(temp_img_path)
            pred_mask, _logits = predict(exp, temp_img_path, clicks=clicks, bbox=bbox)
        pred_np = pred_mask.squeeze().cpu().numpy()

    # The Dice score is only available when a ground-truth mask was found.
    dice_score = None
    if gt_mask is not None:
        dice_score = calculate_dice(pred_np, gt_mask)

    st.divider()
    st.subheader("Results")
    if dice_score is not None:
        st.success(f"**Dice Score:** {dice_score:.4f}")
    else:
        st.warning("Ground Truth mask not found, cannot calculate Dice score.")

    # Four side-by-side views: input, prediction, ground truth, overlay.
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.image(image_resized, caption="Input Image", use_column_width=True)
    with col2:
        st.image(pred_np, caption="Prediction Mask", clamp=True, use_column_width=True)
    with col3:
        if gt_mask is not None:
            st.image(gt_mask, caption="Ground Truth", clamp=True, use_column_width=True)
        else:
            st.write("No GT Mask")
    with col4:
        # Overlay
        # Each layer is a solid colour whose alpha channel is the mask scaled
        # to 100 (out of 255), i.e. a semi-transparent tint over the image.
        mask_rgba = Image.new("RGBA", DISPLAY_SIZE, (0, 255, 0, 0))  # Green for prediction
        mask_data = np.array(mask_rgba)
        mask_data[:, :, 3] = (pred_np * 100).astype(np.uint8)
        mask_rgba = Image.fromarray(mask_data)
        final_overlay = Image.alpha_composite(image_resized.convert("RGBA"), mask_rgba)
        overlay_caption = "Overlay (Green: Pred)"
        if gt_mask is not None:
            gt_rgba = Image.new("RGBA", DISPLAY_SIZE, (255, 0, 0, 0))  # Red for GT
            gt_data = np.array(gt_rgba)
            gt_data[:, :, 3] = ((gt_mask > 0) * 100).astype(np.uint8)
            gt_rgba = Image.fromarray(gt_data)
            final_overlay = Image.alpha_composite(final_overlay, gt_rgba)
            overlay_caption = "Overlay (Green: Pred, Red: GT)"
        st.image(final_overlay, caption=overlay_caption, use_column_width=True)