Spaces:
Sleeping
Sleeping
import os
import pathlib
import sys
import tempfile

import numpy as np
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from streamlit_drawable_canvas import st_canvas
def _add_to_sys_path(directory):
    """Append *directory* to ``sys.path`` unless it is already listed.

    Streamlit re-runs this whole script on every interaction and ``sys.path``
    is process-wide, so a plain ``append`` would add one duplicate entry per
    rerun.
    """
    if directory not in sys.path:
        sys.path.append(directory)


# Add the project root (current working directory) to the path so the
# project-local ``demo_inference`` module can be imported.
_add_to_sys_path(os.getcwd())
try:
    from demo_inference import load_model, predict
except ImportError:
    # Handle running from a different directory structure: fall back to the
    # ScribblePrompt sub-folder of the working directory.
    _add_to_sys_path(os.path.join(os.getcwd(), "ScribblePrompt"))
    from demo_inference import load_model, predict
# Configuration

# Hugging Face Hub repository that hosts the checkpoints listed in
# MODELS_CONFIG; change it when deploying from a fork that hosts its own
# weights.
REPO_ID = "hbyecoding/iU-RWKV"

# Sidebar display name -> model settings:
#   dataset:         sub-folder of ASSETS_ROOT holding this model's demo
#                    images/ and masks/
#   checkpoint_name: file looked up in ./checkpoints first, then downloaded
#                    from REPO_ID
#   config:          YAML config path passed to load_model()
MODELS_CONFIG = {
    "ISIC18": {
        "dataset": "ISIC18",
        "checkpoint_name": "isic18_max-val_id-dice_score.pt",
        "config": "configs/train_urwkv_isic18.yaml",
    },
    "POLY": {
        "dataset": "POLY",
        "checkpoint_name": "poly_max-val_id-dice_score.pt",
        "config": "configs/train_urwkv_poly.yaml",
    },
    "BUSI (Custom)": {
        "dataset": "BUSI",
        "checkpoint_name": "custom_max-val_id-dice_score.pt",
        "config": "configs/train_urwkv_custom.yaml",
    },
}

# Demo assets live under <ASSETS_ROOT>/<dataset>/{images,masks}. This is a
# relative path, so it resolves against the current working directory
# (presumably the app directory when the Space launches the app from its
# root), not against this file's location.
ASSETS_ROOT = pathlib.Path("hf_demo_assets")
@st.cache_resource(show_spinner="Downloading checkpoint...")
def _download_checkpoint(checkpoint_name):
    """Download *checkpoint_name* from the REPO_ID Hub repo; return its local path.

    Cached for the lifetime of the Streamlit process so reruns do not contact
    the Hub again. If the download raises, nothing is cached and the next
    rerun retries.
    """
    return hf_hub_download(repo_id=REPO_ID, filename=checkpoint_name)


@st.cache_resource(show_spinner="Loading iU-RWKV model...")
def _load_model_cached(config_path, ckpt_path, device):
    """Build the model once per (config, checkpoint, device) combination."""
    # NOTE(review): st.cache_resource shares this object across all sessions;
    # confirm predict() does not mutate the experiment object it is given.
    return load_model(config_path, ckpt_path, device=device)


def get_model(config_path, checkpoint_name):
    """Return the loaded iU-RWKV experiment for *config_path* / *checkpoint_name*.

    The checkpoint is read from ``checkpoints/<checkpoint_name>`` when that
    file exists (for local testing). Otherwise it is downloaded from the
    Hugging Face Hub repository REPO_ID. On download failure an error is shown
    and the Streamlit run is stopped via ``st.stop()``.

    Streamlit re-executes the whole script on every widget interaction, so
    both the download and the model construction are cached instead of being
    repeated on each rerun. The model is placed on CUDA when available,
    otherwise on the CPU.
    """
    # Try local first (for testing), then HF Hub.
    local_path = os.path.join("checkpoints", checkpoint_name)
    if os.path.exists(local_path):
        ckpt_path = local_path
    else:
        try:
            ckpt_path = _download_checkpoint(checkpoint_name)
        except Exception as e:  # UI boundary: report any Hub failure, then halt the run
            st.error(f"Failed to download model from Hugging Face: {e}")
            st.error("Please ensure you have uploaded the models and updated REPO_ID in the code.")
            st.stop()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return _load_model_cached(config_path, ckpt_path, device)
def calculate_dice(pred_mask, gt_mask, threshold=0.5):
    """Return the Dice coefficient between a predicted and a ground-truth mask.

    Args:
        pred_mask: Prediction array. Pixels strictly greater than *threshold*
            count as foreground.
        gt_mask: Ground-truth array. Any pixel > 0 counts as foreground
            (so 0/1 and 0/255 masks both work).
        threshold: Binarisation cut-off for *pred_mask* (default 0.5).

    Returns:
        ``2 * |P & G| / (|P| + |G|)`` as a Python float in [0, 1]. When both
        masks are empty they agree perfectly, so 1.0 is returned.

    Raises:
        ValueError: if the two masks do not have the same shape. Without this
            check, NumPy broadcasting could silently compare unrelated pixels.
    """
    pred = np.asarray(pred_mask) > threshold
    gt = np.asarray(gt_mask) > 0
    if pred.shape != gt.shape:
        raise ValueError(
            f"Mask shapes differ: prediction {pred.shape} vs ground truth {gt.shape}"
        )
    # Sum of foreground sizes (not the set union): the Dice denominator.
    total = int(np.count_nonzero(pred)) + int(np.count_nonzero(gt))
    if total == 0:
        return 1.0
    intersection = int(np.count_nonzero(pred & gt))
    return 2.0 * intersection / total
def get_image_list(dataset_name, assets_root=None):
    """List the demo images available for *dataset_name*.

    Images are read from ``<assets_root>/<dataset_name>/images``, where
    *assets_root* defaults to ``ASSETS_ROOT``. Only regular files whose
    extension (case-insensitive) is .png, .jpg, .jpeg or .bmp are kept.

    Returns:
        A list of ``(path_str, stem, suffix)`` tuples sorted by path, or an
        empty list if the images directory does not exist.
    """
    root = ASSETS_ROOT if assets_root is None else pathlib.Path(assets_root)
    image_dir = root / dataset_name / "images"
    # is_dir() rather than exists(): iterdir() on a plain file would raise.
    if not image_dir.is_dir():
        return []
    allowed_suffixes = {".png", ".jpg", ".jpeg", ".bmp"}
    return sorted(
        (str(path), path.stem, path.suffix)
        for path in image_dir.iterdir()
        if path.is_file() and path.suffix.lower() in allowed_suffixes
    )
def get_mask_path(dataset_name, name, assets_root=None):
    """Return the path of the ground-truth mask whose file stem equals *name*.

    Masks are looked up in ``<assets_root>/<dataset_name>/masks``, where
    *assets_root* defaults to ``ASSETS_ROOT``. The extension is not checked,
    so ``img1.png`` and ``img1.bmp`` both match ``name="img1"``. Only regular
    files are considered.

    Returns:
        The matching path as a string, or None when the masks directory or a
        matching file does not exist. If several files share the stem, the
        first one in sorted order is returned.
    """
    root = ASSETS_ROOT if assets_root is None else pathlib.Path(assets_root)
    mask_dir = root / dataset_name / "masks"
    if not mask_dir.is_dir():
        return None
    # iterdir() order is filesystem-dependent; sort so the same file wins
    # on every run.
    for candidate in sorted(mask_dir.iterdir()):
        if candidate.is_file() and candidate.stem == name:
            return str(candidate)
    return None
st.set_page_config(page_title="iU-RWKV Interactive Segmentation", layout="wide")
st.title("iU-RWKV Interactive Segmentation Demo")
st.markdown("This demo showcases the interactive segmentation capabilities of the **iU-RWKV** model.")

st.sidebar.header("Settings")

# Model selection: each MODELS_CONFIG entry pairs a checkpoint/config with
# the demo dataset whose images are offered below.
selected_model_name = st.sidebar.selectbox("Select Model & Dataset", list(MODELS_CONFIG.keys()))
model_info = MODELS_CONFIG[selected_model_name]
dataset_name = model_info["dataset"]

# Load the selected model. get_model() shows an error and stops the run if
# the checkpoint cannot be obtained.
exp = get_model(model_info["config"], model_info["checkpoint_name"])

image_list_info = get_image_list(dataset_name)
if not image_list_info:
    st.error(f"No demo images found for {dataset_name}. Please check hf_demo_assets directory.")
    st.stop()

# Image selection: options are (path, stem, suffix) tuples; show only the
# file name in the dropdown.
selected_image_info = st.sidebar.selectbox("Select Image", image_list_info, format_func=lambda x: os.path.basename(x[0]))
selected_image_path, img_name, img_ext = selected_image_info

# (width, height) used for the canvas, the model input and the GT mask, so
# the prediction and the GT mask can be compared pixel-for-pixel.
DISPLAY_SIZE = (256, 256)

# Try to find the GT mask (matched by file stem).
gt_mask_path = get_mask_path(dataset_name, img_name)
gt_mask = None
if gt_mask_path:
    # `with` closes the file handle right away instead of leaving it to the
    # garbage collector; this script re-runs on every interaction.
    with Image.open(gt_mask_path) as gt_mask_img:
        # NEAREST resampling keeps mask labels crisp (no interpolated values).
        gt_mask = np.array(gt_mask_img.convert("L").resize(DISPLAY_SIZE, Image.NEAREST))

# Load the input image. convert() returns a new in-memory image, so it stays
# usable after the file is closed.
with Image.open(selected_image_path) as raw_image:
    image = raw_image.convert("RGB")
image_resized = image.resize(DISPLAY_SIZE)
st.header("Interactive Workspace")
st.write("Draw a **Box (Rect)** or **Click (Point)** on the image to guide the segmentation.")

col_tools, col_canvas = st.columns([1, 3])

with col_tools:
    # The interaction mode selects the canvas tool and the stroke colour. The
    # colour is what later distinguishes positive (green) from negative (red)
    # clicks when the canvas objects are turned into prompts.
    interaction_mode = st.radio("Interaction Mode", ["Box", "Positive Click", "Negative Click"])
    drawing_mode = "rect" if interaction_mode == "Box" else "point"
    stroke_color = {
        "Box": "blue",
        "Positive Click": "green",
        "Negative Click": "red",
    }[interaction_mode]
    st.info("Tip: Draw a box around the object, or click green points on the object and red points on the background.")

with col_canvas:
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=3,
        stroke_color=stroke_color,
        background_image=image_resized,
        update_streamlit=True,
        height=DISPLAY_SIZE[1],
        width=DISPLAY_SIZE[0],
        drawing_mode=drawing_mode,
        # Key the canvas on the selected image so shapes drawn on a
        # previously selected image are discarded rather than being sent as
        # prompts for the new one. Switching interaction mode keeps the same
        # key, so boxes and clicks can still be combined on one image.
        key=f"canvas_{selected_image_path}",
    )
def _canvas_prompts(json_data):
    """Convert ``st_canvas`` JSON output into iU-RWKV prompts.

    Args:
        json_data: ``canvas_result.json_data``; may be None. Its ``"objects"``
            list holds the drawn shapes.

    Returns:
        ``(bbox, clicks)``. ``bbox`` is ``[x_min, y_min, x_max, y_max]`` from
        the last rectangle in the list, or None. ``clicks`` is a list of
        ``(x, y, label)`` tuples, with label 1 for green strokes (positive)
        and 0 for any other colour. Coordinates are canvas pixels, i.e.
        pixels of the DISPLAY_SIZE image given to ``predict``.
    """
    bbox = None
    clicks = []
    if json_data is None:
        return bbox, clicks
    for obj in json_data["objects"]:
        if obj["type"] == "rect":
            # Only one box is used: a later rectangle overrides earlier ones.
            x_min = int(obj["left"])
            y_min = int(obj["top"])
            x_max = int(obj["left"] + obj["width"])
            y_max = int(obj["top"] + obj["height"])
            bbox = [x_min, y_min, x_max, y_max]
        elif obj["type"] in ["circle", "point"]:
            # Use the centre of the drawn marker as the click location.
            x = int(obj["left"] + obj["width"] / 2)
            y = int(obj["top"] + obj["height"] / 2)
            label = 1 if obj["stroke"] == "green" else 0
            clicks.append((x, y, label))
    return bbox, clicks


def _composite_mask(base_rgba, alpha, rgb):
    """Blend a solid *rgb* layer over *base_rgba* with per-pixel uint8 *alpha*."""
    layer = np.zeros((base_rgba.height, base_rgba.width, 4), dtype=np.uint8)
    layer[..., :3] = rgb
    layer[..., 3] = alpha
    return Image.alpha_composite(base_rgba, Image.fromarray(layer))


if st.button("Run Prediction", type="primary"):
    bbox, clicks = _canvas_prompts(canvas_result.json_data)

    with st.spinner("Running iU-RWKV Inference..."):
        # predict() takes a file path. A fixed "temp_input.png" in the working
        # directory would be shared by every session served by this Streamlit
        # process, so concurrent users could overwrite each other's input.
        # Use a unique temporary file and always remove it afterwards.
        fd, temp_img_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        try:
            image_resized.save(temp_img_path)
            pred_mask, logits = predict(exp, temp_img_path, clicks=clicks, bbox=bbox)
        finally:
            os.remove(temp_img_path)
        pred_np = pred_mask.squeeze().cpu().numpy()

        dice_score = None
        if gt_mask is not None:
            dice_score = calculate_dice(pred_np, gt_mask)

    st.divider()
    st.subheader("Results")
    if dice_score is not None:
        st.success(f"**Dice Score:** {dice_score:.4f}")
    else:
        st.warning("Ground Truth mask not found, cannot calculate Dice score.")

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.image(image_resized, caption="Input Image", use_column_width=True)
    with col2:
        st.image(pred_np, caption="Prediction Mask", clamp=True, use_column_width=True)
    with col3:
        if gt_mask is not None:
            st.image(gt_mask, caption="Ground Truth", clamp=True, use_column_width=True)
        else:
            st.write("No GT Mask")
    with col4:
        # Green layer = prediction, alpha up to 100/255. Clip to [0, 1] first
        # so out-of-range values cannot wrap around in the uint8 cast.
        pred_alpha = (np.clip(pred_np, 0, 1) * 100).astype(np.uint8)
        final_overlay = _composite_mask(image_resized.convert("RGBA"), pred_alpha, (0, 255, 0))
        if gt_mask is not None:
            # Red layer = ground truth (any non-zero pixel).
            gt_alpha = ((gt_mask > 0) * 100).astype(np.uint8)
            final_overlay = _composite_mask(final_overlay, gt_alpha, (255, 0, 0))
            st.image(final_overlay, caption="Overlay (Green: Pred, Red: GT)", use_column_width=True)
        else:
            st.image(final_overlay, caption="Overlay (Green: Pred)", use_column_width=True)