# Streamlit demo app for interactive iU-RWKV segmentation.
#
# The user picks a model/dataset and an image, draws a box and/or
# positive/negative clicks on a canvas, and the app runs inference. When a
# ground-truth mask is available it also reports the Dice score and shows an
# overlay of prediction vs. ground truth.

import os
import pathlib
import sys
import tempfile

import numpy as np
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from streamlit_drawable_canvas import st_canvas

# Add project root to path to import modules
sys.path.append(os.getcwd())
try:
    from demo_inference import load_model, predict
except ImportError:
    # Handle if running from different directory structure
    sys.path.append(os.path.join(os.getcwd(), "ScribblePrompt"))
    from demo_inference import load_model, predict

# Configuration
# Hugging Face Hub repository hosting the checkpoints listed in MODELS_CONFIG.
REPO_ID = "hbyecoding/iU-RWKV"
MODELS_CONFIG = {
    "ISIC18": {
        "dataset": "ISIC18",
        "checkpoint_name": "isic18_max-val_id-dice_score.pt",
        "config": "configs/train_urwkv_isic18.yaml",
    },
    "POLY": {
        "dataset": "POLY",
        "checkpoint_name": "poly_max-val_id-dice_score.pt",
        "config": "configs/train_urwkv_poly.yaml",
    },
    "BUSI (Custom)": {
        "dataset": "BUSI",
        "checkpoint_name": "custom_max-val_id-dice_score.pt",
        "config": "configs/train_urwkv_custom.yaml",
    },
}

# Demo assets (<ASSETS_ROOT>/<dataset>/images and .../masks) are resolved
# relative to the current working directory.
ASSETS_ROOT = pathlib.Path("hf_demo_assets")

# (width, height) of the canvas; the input image is resized to this before
# inference, and the GT mask is resized to match so Dice can be computed.
DISPLAY_SIZE = (256, 256)

# Canvas stroke colours. Click labels are decoded from the stroke colour, so
# the same constants are used when drawing and when parsing.
BOX_COLOR = "blue"
POSITIVE_COLOR = "green"
NEGATIVE_COLOR = "red"

IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg", ".bmp")


@st.cache_resource
def get_model(config_path, checkpoint_name):
    """Load an iU-RWKV experiment, cached across Streamlit reruns.

    A checkpoint at ``checkpoints/<checkpoint_name>`` is preferred (useful for
    local testing); otherwise it is downloaded from ``REPO_ID`` on the Hugging
    Face Hub. If the download fails, an error is shown and the run is stopped.

    Args:
        config_path: Path to the training YAML config passed to ``load_model``.
        checkpoint_name: Checkpoint filename (local or on the Hub).

    Returns:
        Whatever ``load_model`` returns; it is passed to ``predict`` later.
    """
    local_path = os.path.join("checkpoints", checkpoint_name)
    if os.path.exists(local_path):
        ckpt_path = local_path
    else:
        try:
            ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=checkpoint_name)
        except Exception as e:  # UI boundary: surface any download failure.
            st.error(f"Failed to download model from Hugging Face: {e}")
            st.error("Please ensure you have uploaded the models and updated REPO_ID in the code.")
            st.stop()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return load_model(config_path, ckpt_path, device=device)


def calculate_dice(pred_mask, gt_mask):
    """Return the Dice coefficient between a prediction and a GT mask.

    The prediction is binarised at ``> 0.5`` and the ground truth at ``> 0``.
    If both masks are empty, the result is 1.0 (perfect agreement).
    """
    pred = (pred_mask > 0.5).astype(np.uint8)
    gt = (gt_mask > 0).astype(np.uint8)
    denom = int(pred.sum()) + int(gt.sum())
    if denom == 0:
        # Both masks are empty: nothing to disagree about.
        return 1.0
    intersection = int(np.sum(pred * gt))
    return 2.0 * intersection / denom


def get_image_list(dataset_name):
    """Return sorted ``(path, stem, suffix)`` tuples for a dataset's demo images.

    Returns an empty list when ``<ASSETS_ROOT>/<dataset_name>/images`` is missing.
    """
    dataset_dir = ASSETS_ROOT / dataset_name / "images"
    if not dataset_dir.exists():
        return []
    return sorted(
        (str(file), file.stem, file.suffix)
        for file in dataset_dir.iterdir()
        if file.suffix.lower() in IMAGE_SUFFIXES
    )


def get_mask_path(dataset_name, name):
    """Return the path of the first mask whose stem equals ``name``, else None."""
    mask_dir = ASSETS_ROOT / dataset_name / "masks"
    if not mask_dir.exists():
        return None
    return next((str(file) for file in mask_dir.iterdir() if file.stem == name), None)


def _extract_prompts(json_data):
    """Convert canvas JSON into ``(clicks, bbox)`` prompts for ``predict``.

    ``clicks`` is a list of ``(x, y, label)`` with ``label`` 1 for positive
    (POSITIVE_COLOR stroke) and 0 otherwise. ``bbox`` is ``[x_min, y_min,
    x_max, y_max]`` from the last rectangle drawn, or None. Coordinates are in
    canvas (DISPLAY_SIZE) pixels.
    """
    clicks = []
    bbox = None
    if json_data is None:
        return clicks, bbox
    for obj in json_data["objects"]:
        if obj["type"] == "rect":
            x_min = int(obj["left"])
            y_min = int(obj["top"])
            x_max = int(obj["left"] + obj["width"])
            y_max = int(obj["top"] + obj["height"])
            bbox = [x_min, y_min, x_max, y_max]  # later boxes override earlier ones
        elif obj["type"] in ["circle", "point"]:
            x = int(obj["left"] + obj["width"] / 2)
            y = int(obj["top"] + obj["height"] / 2)
            label = 1 if obj["stroke"] == POSITIVE_COLOR else 0
            clicks.append((x, y, label))
    return clicks, bbox


def _overlay_mask(base_rgba, mask, rgb, alpha_scale=100):
    """Alpha-composite a solid ``rgb`` tint over ``base_rgba`` where ``mask`` is set.

    Mask values are clipped to [0, 1] and scaled by ``alpha_scale`` to form the
    alpha channel. The clip prevents uint8 wrap-around for out-of-range values.
    """
    mask = np.asarray(mask, dtype=np.float64)
    alpha = (np.clip(mask, 0.0, 1.0) * alpha_scale).astype(np.uint8)
    layer = np.zeros((*alpha.shape, 4), dtype=np.uint8)
    layer[..., :3] = rgb
    layer[..., 3] = alpha
    return Image.alpha_composite(base_rgba, Image.fromarray(layer))


st.set_page_config(page_title="iU-RWKV Interactive Segmentation", layout="wide")
st.title("iU-RWKV Interactive Segmentation Demo")
st.markdown("This demo showcases the interactive segmentation capabilities of the **iU-RWKV** model.")

st.sidebar.header("Settings")

# Model Selection
selected_model_name = st.sidebar.selectbox("Select Model & Dataset", list(MODELS_CONFIG.keys()))
model_info = MODELS_CONFIG[selected_model_name]
dataset_name = model_info["dataset"]

# Load specific model
exp = get_model(model_info["config"], model_info["checkpoint_name"])

image_list_info = get_image_list(dataset_name)
if not image_list_info:
    st.error(f"No demo images found for {dataset_name}. Please check hf_demo_assets directory.")
    st.stop()

# Image Selection
selected_image_info = st.sidebar.selectbox(
    "Select Image", image_list_info, format_func=lambda x: os.path.basename(x[0])
)
selected_image_path, img_name, img_ext = selected_image_info

# Try to find GT Mask
gt_mask_path = get_mask_path(dataset_name, img_name)
gt_mask = None
if gt_mask_path:
    gt_mask_img = Image.open(gt_mask_path).convert("L")
    # Resize GT to the inference resolution so it lines up with the prediction.
    gt_mask = np.array(gt_mask_img.resize(DISPLAY_SIZE, Image.NEAREST))

# Load Image
image = Image.open(selected_image_path).convert("RGB")
image_resized = image.resize(DISPLAY_SIZE)

st.header("Interactive Workspace")
st.write("Draw a **Box (Rect)** or **Click (Point)** on the image to guide the segmentation.")

col_tools, col_canvas = st.columns([1, 3])

with col_tools:
    # Interaction Mode
    interaction_mode = st.radio("Interaction Mode", ["Box", "Positive Click", "Negative Click"])
    drawing_mode = "rect" if interaction_mode == "Box" else "point"
    if interaction_mode == "Box":
        stroke_color = BOX_COLOR
    elif interaction_mode == "Positive Click":
        stroke_color = POSITIVE_COLOR
    else:
        stroke_color = NEGATIVE_COLOR
    st.info("Tip: Draw a box around the object, or click green points on the object and red points on the background.")

with col_canvas:
    # Canvas. The key is per image so prompts drawn on one image are not
    # carried over to (and sent as prompts for) a different image.
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=3,
        stroke_color=stroke_color,
        background_image=image_resized,
        update_streamlit=True,
        height=DISPLAY_SIZE[1],
        width=DISPLAY_SIZE[0],
        drawing_mode=drawing_mode,
        key=f"canvas_{dataset_name}_{img_name}",
    )

if st.button("Run Prediction", type="primary"):
    clicks, bbox = _extract_prompts(canvas_result.json_data)
    if not clicks and bbox is None:
        st.warning("Please draw a box or place at least one click before running prediction.")
        st.stop()

    with st.spinner("Running iU-RWKV Inference..."):
        # Use a private temp dir per run: a fixed filename in the CWD is shared
        # (and raced on) by all concurrent sessions, and was never cleaned up.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_img_path = os.path.join(tmp_dir, "temp_input.png")
            image_resized.save(temp_img_path)
            pred_mask, _logits = predict(exp, temp_img_path, clicks=clicks, bbox=bbox)
        pred_np = pred_mask.squeeze().cpu().numpy()

    dice_score = None
    if gt_mask is not None:
        dice_score = calculate_dice(pred_np, gt_mask)

    st.divider()
    st.subheader("Results")

    if dice_score is not None:
        st.success(f"**Dice Score:** {dice_score:.4f}")
    else:
        st.warning("Ground Truth mask not found, cannot calculate Dice score.")

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.image(image_resized, caption="Input Image", use_column_width=True)
    with col2:
        st.image(pred_np, caption="Prediction Mask", clamp=True, use_column_width=True)
    with col3:
        if gt_mask is not None:
            st.image(gt_mask, caption="Ground Truth", clamp=True, use_column_width=True)
        else:
            st.write("No GT Mask")
    with col4:
        # Overlay: green for prediction, red for GT.
        final_overlay = _overlay_mask(image_resized.convert("RGBA"), pred_np, (0, 255, 0))
        if gt_mask is not None:
            final_overlay = _overlay_mask(final_overlay, gt_mask > 0, (255, 0, 0))
            st.image(final_overlay, caption="Overlay (Green: Pred, Red: GT)", use_column_width=True)
        else:
            st.image(final_overlay, caption="Overlay (Green: Pred)", use_column_width=True)