# *************************************************************************
# Grasp Any Region (GAR) - Gradio Demo
# Region-level Multimodal Understanding for Vision-Language Models
# *************************************************************************

# 🚨 CRITICAL: Import spaces FIRST before any CUDA-related packages
import spaces

# Now import CUDA-related packages
import torch
import numpy as np
from PIL import Image
import gradio as gr
from transformers import (
    AutoModel,
    AutoProcessor,
    GenerationConfig,
    SamModel,
    SamProcessor,
)
import cv2
import sys
import os

# Add project root to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

try:
    from evaluation.eval_dataset import SingleRegionCaptionDataset
except ImportError:
    print("Warning: Could not import SingleRegionCaptionDataset. Using simplified version.")
    SingleRegionCaptionDataset = None

# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Global model variables (loaded once)
gar_model = None
gar_processor = None
sam_model = None
sam_processor = None


def load_models():
    """Load models once at startup"""
    global gar_model, gar_processor, sam_model, sam_processor

    if gar_model is None:
        print("Loading GAR model...")
        model_path = "HaochenWang/GAR-1B"
        gar_model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        ).eval()
        gar_processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        print("GAR model loaded successfully!")

    if sam_model is None:
        print("Loading SAM model...")
        sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
        sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
        print("SAM model loaded successfully!")


@spaces.GPU(duration=120)
def generate_mask_from_points(image, points_str):
    """Generate mask using SAM from point coordinates"""
    try:
        load_models()

        if not points_str or points_str.strip() == "":
            return None, "Please provide points in format: x1,y1;x2,y2"

        # Parse points
        points = []
        labels = []
        for point in points_str.split(';'):
            point = point.strip()
            if point:
                x, y = map(float, point.split(','))
                points.append([x, y])
                labels.append(1)  # Foreground point

        if not points:
            return None, "No valid points provided"

        # Apply SAM
        inputs = sam_processor(
            image,
            input_points=[points],
            input_labels=[labels],
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = sam_model(**inputs)

        masks = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0][0]
        scores = outputs.iou_scores[0, 0]
        mask_selection_index = scores.argmax()
        mask_np = masks[mask_selection_index].numpy()

        # Visualize mask
        mask_img = (mask_np * 255).astype(np.uint8)
        return Image.fromarray(mask_img), "Mask generated successfully!"
    except Exception as e:
        return None, f"Error generating mask: {str(e)}"
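
# -------------------------------------------------------------------------
# Illustrative sketch (not wired into the UI): SAM also accepts background
# points, labeled 0, that tell the model which areas to exclude from the
# mask. The helper below is an assumed extension of
# generate_mask_from_points; "neg_points_str" is a hypothetical argument
# using the same "x1,y1;x2,y2" format.
# -------------------------------------------------------------------------
def _mask_from_pos_neg_points(image, pos_points_str, neg_points_str=""):
    """Hedged example: mix foreground (label 1) and background (label 0) points."""
    load_models()
    points, labels = [], []
    for raw, label in ((pos_points_str, 1), (neg_points_str, 0)):
        for point in raw.split(';'):
            point = point.strip()
            if point:
                x, y = map(float, point.split(','))
                points.append([x, y])
                labels.append(label)
    inputs = sam_processor(
        image,
        input_points=[points],
        input_labels=[labels],
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        outputs = sam_model(**inputs)
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0][0]
    # Keep the candidate mask with the highest predicted IoU, as above
    best = outputs.iou_scores[0, 0].argmax()
    return Image.fromarray((masks[best].numpy() * 255).astype(np.uint8))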

@spaces.GPU(duration=120)
def generate_mask_from_box(image, box_str):
    """Generate mask using SAM from bounding box"""
    try:
        load_models()

        if not box_str or box_str.strip() == "":
            return None, "Please provide box in format: x1,y1,x2,y2"

        # Parse box
        box = list(map(float, box_str.split(',')))
        if len(box) != 4:
            return None, "Box must have 4 coordinates: x1,y1,x2,y2"

        # Apply SAM
        inputs = sam_processor(
            image,
            input_boxes=[[box]],
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = sam_model(**inputs)

        masks = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0][0]
        scores = outputs.iou_scores[0, 0]
        mask_selection_index = scores.argmax()
        mask_np = masks[mask_selection_index].numpy()

        # Visualize mask
        mask_img = (mask_np * 255).astype(np.uint8)
        return Image.fromarray(mask_img), "Mask generated successfully!"
    except Exception as e:
        return None, f"Error generating mask: {str(e)}"
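
# Note on the captioning path below: following the GAR evaluation pipeline,
# SingleRegionCaptionDataset (evaluation/eval_dataset.py) is expected to pack
# the image, the binarized region mask, and the model's visual-prompt tokens
# into a single sample that gar_model.generate() can consume directly.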

@spaces.GPU(duration=120)
def describe_region(image, mask):
    """Generate description for a region defined by a mask"""
    try:
        load_models()

        if image is None:
            return "Please provide an image"
        if mask is None:
            return "Please provide a mask (upload or generate using SAM)"

        # Convert mask to numpy
        if isinstance(mask, Image.Image):
            mask_np = np.array(mask.convert("L"))
        else:
            mask_np = np.array(mask)

        # Ensure mask is binary
        mask_np = (mask_np > 127).astype(np.uint8)

        # Prepare data
        prompt_number = gar_model.config.prompt_numbers
        prompt_tokens = [f"" for i_p in range(prompt_number)] + [""]

        if SingleRegionCaptionDataset is not None:
            dataset = SingleRegionCaptionDataset(
                image=image,
                mask=mask_np,
                processor=gar_processor,
                prompt_number=prompt_number,
                visual_prompt_tokens=prompt_tokens,
                data_dtype=torch.bfloat16,
            )
            data_sample = dataset[0]
        else:
            # Simplified processing if dataset class not available
            # This is a fallback - the actual implementation requires SingleRegionCaptionDataset
            return "Error: SingleRegionCaptionDataset not available. Please check installation."

        # Generate description
        with torch.no_grad():
            generate_ids = gar_model.generate(
                **data_sample,
                generation_config=GenerationConfig(
                    max_new_tokens=1024,
                    do_sample=False,
                    eos_token_id=gar_processor.tokenizer.eos_token_id,
                    pad_token_id=gar_processor.tokenizer.pad_token_id,
                ),
                return_dict=True,
            )

        output_caption = gar_processor.tokenizer.decode(
            generate_ids.sequences[0],
            skip_special_tokens=True,
        ).strip()

        return output_caption
    except Exception as e:
        return f"Error generating description: {str(e)}"


def create_visualization(image, mask, points_str=None, box_str=None):
    """Create visualization with mask overlay"""
    try:
        if image is None or mask is None:
            return None

        img_np = np.array(image).astype(float) / 255.0

        if isinstance(mask, Image.Image):
            mask_np = np.array(mask.convert("L")) > 127
        else:
            mask_np = np.array(mask) > 127

        # Draw contour
        mask_uint8 = mask_np.astype(np.uint8) * 255
        contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        img_vis = img_np.copy()
        cv2.drawContours(img_vis, contours, -1, (1.0, 1.0, 0.0), thickness=3)

        # Draw points if provided
        if points_str:
            for point in points_str.split(';'):
                point = point.strip()
                if point:
                    x, y = map(float, point.split(','))
                    cv2.circle(img_vis, (int(x), int(y)), radius=8, color=(1.0, 0.0, 0.0), thickness=-1)
                    cv2.circle(img_vis, (int(x), int(y)), radius=8, color=(1.0, 1.0, 1.0), thickness=2)

        # Draw box if provided
        if box_str:
            coords = list(map(float, box_str.split(',')))
            if len(coords) == 4:
                x1, y1, x2, y2 = map(int, coords)
                cv2.rectangle(img_vis, (x1, y1), (x2, y2), color=(1.0, 1.0, 1.0), thickness=3)
                cv2.rectangle(img_vis, (x1, y1), (x2, y2), color=(1.0, 0.0, 0.0), thickness=1)

        img_pil = Image.fromarray((img_vis * 255.0).astype(np.uint8))
        return img_pil
    except Exception as e:
        print(f"Error creating visualization: {str(e)}")
        return None
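
# -------------------------------------------------------------------------
# Illustrative alternative (not used by the UI): a semi-transparent fill
# instead of create_visualization's contour outline. "color" and "alpha"
# are assumed parameters, not part of the original demo.
# -------------------------------------------------------------------------
def _overlay_mask(image, mask, color=(255, 255, 0), alpha=0.4):
    """Hedged sketch: alpha-blend a flat color over the masked pixels."""
    img = np.array(image).astype(np.float32)
    if isinstance(mask, Image.Image):
        m = np.array(mask.convert("L")) > 127
    else:
        m = np.array(mask) > 127
    # Blend only where the mask is set; broadcasting applies the color per pixel
    img[m] = (1.0 - alpha) * img[m] + alpha * np.asarray(color, dtype=np.float32)
    return Image.fromarray(img.astype(np.uint8))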

# Create Gradio interface
with gr.Blocks(title="Grasp Any Region (GAR) Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎯 Grasp Any Region (GAR)

    **Region-level Multimodal Understanding for Vision-Language Models**

    This demo showcases GAR's ability to understand and describe specific regions in images:
    - 🎨 **Single Region Understanding**: Describe specific areas using points, boxes, or masks
    - 🔍 **SAM Integration**: Generate masks interactively using Segment Anything Model
    - 💡 **Detailed Descriptions**: Get comprehensive descriptions of any region

    Built on top of Perception-LM with the RoI-aligned feature replay technique.

    📄 [Paper](https://arxiv.org/abs/2510.18876) | 💻 [GitHub](https://github.com/Haochen-Wang409/Grasp-Any-Region) | 🤗 [Model](https://huggingface.co/HaochenWang/GAR-1B)
    """)

    with gr.Tabs():
        # Tab 1: Points-based segmentation
        with gr.Tab("🎯 Points → Describe"):
            gr.Markdown("### Click points on the image or enter coordinates to segment and describe a region")

            with gr.Row():
                with gr.Column():
                    img_points = gr.Image(label="Input Image", type="pil")
                    points_input = gr.Textbox(
                        label="Points (format: x1,y1;x2,y2;...)",
                        placeholder="e.g., 1172,812;1572,800",
                        value="1172,812;1572,800",
                    )
                    with gr.Row():
                        gen_mask_points_btn = gr.Button("Generate Mask", variant="primary")
                        describe_points_btn = gr.Button("Describe Region", variant="secondary")

                with gr.Column():
                    mask_points = gr.Image(label="Generated Mask", type="pil")
                    vis_points = gr.Image(label="Visualization")
                    desc_points = gr.Textbox(label="Region Description", lines=5)

            points_status = gr.Textbox(label="Status", visible=False)

            gen_mask_points_btn.click(
                fn=generate_mask_from_points,
                inputs=[img_points, points_input],
                outputs=[mask_points, points_status],
            )

            describe_points_btn.click(
                fn=describe_region,
                inputs=[img_points, mask_points],
                outputs=desc_points,
            ).then(
                fn=create_visualization,
                inputs=[img_points, mask_points, points_input, gr.Textbox(visible=False)],
                outputs=vis_points,
            )

            gr.Examples(
                examples=[
                    ["assets/demo_image_2.jpg", "1172,812;1572,800"],
                ],
                inputs=[img_points, points_input],
                label="Example Images",
            )

        # Tab 2: Box-based segmentation
        with gr.Tab("📦 Box → Describe"):
            gr.Markdown("### Draw a bounding box or enter coordinates to segment and describe a region")

            with gr.Row():
                with gr.Column():
                    img_box = gr.Image(label="Input Image", type="pil")
                    box_input = gr.Textbox(
                        label="Bounding Box (format: x1,y1,x2,y2)",
                        placeholder="e.g., 800,500,1800,1000",
                        value="800,500,1800,1000",
                    )
                    with gr.Row():
                        gen_mask_box_btn = gr.Button("Generate Mask", variant="primary")
                        describe_box_btn = gr.Button("Describe Region", variant="secondary")

                with gr.Column():
                    mask_box = gr.Image(label="Generated Mask", type="pil")
                    vis_box = gr.Image(label="Visualization")
                    desc_box = gr.Textbox(label="Region Description", lines=5)

            box_status = gr.Textbox(label="Status", visible=False)

            gen_mask_box_btn.click(
                fn=generate_mask_from_box,
                inputs=[img_box, box_input],
                outputs=[mask_box, box_status],
            )

            describe_box_btn.click(
                fn=describe_region,
                inputs=[img_box, mask_box],
                outputs=desc_box,
            ).then(
                fn=create_visualization,
                inputs=[img_box, mask_box, gr.Textbox(visible=False), box_input],
                outputs=vis_box,
            )

            gr.Examples(
                examples=[
                    ["assets/demo_image_2.jpg", "800,500,1800,1000"],
                ],
                inputs=[img_box, box_input],
                label="Example Images",
            )

        # Tab 3: Direct mask upload
        with gr.Tab("🎭 Mask → Describe"):
            gr.Markdown("### Upload a pre-made mask to describe a region")

            with gr.Row():
                with gr.Column():
                    img_mask = gr.Image(label="Input Image", type="pil")
                    mask_upload = gr.Image(label="Upload Mask", type="pil")
                    describe_mask_btn = gr.Button("Describe Region", variant="primary")

                with gr.Column():
                    vis_mask = gr.Image(label="Visualization")
                    desc_mask = gr.Textbox(label="Region Description", lines=5)

            describe_mask_btn.click(
                fn=describe_region,
                inputs=[img_mask, mask_upload],
                outputs=desc_mask,
            ).then(
                fn=create_visualization,
                inputs=[img_mask, mask_upload, gr.Textbox(visible=False), gr.Textbox(visible=False)],
                outputs=vis_mask,
            )

            gr.Examples(
                examples=[
                    ["assets/demo_image_1.png", "assets/demo_mask_1.png"],
                ],
                inputs=[img_mask, mask_upload],
                label="Example Images",
            )

    gr.Markdown("""
    ---
    ### 📖 How to Use:
    1. **Points → Describe**: Click or enter point coordinates, generate mask, then describe
    2. **Box → Describe**: Draw or enter a bounding box, generate mask, then describe
    3. **Mask → Describe**: Upload a pre-made mask directly and describe

    ### 🔧 Technical Details:
    - **Model**: GAR-1B (1 billion parameters)
    - **Base**: Facebook Perception-LM with RoI-aligned feature replay
    - **Segmentation**: Segment Anything Model (SAM ViT-Huge)
    - **Hardware**: Powered by ZeroGPU (NVIDIA H200, 70GB VRAM)

    ### 📚 Citation:
    ```bibtex
    @article{wang2025grasp,
      title={Grasp Any Region: Prompting MLLM to Understand the Dense World},
      author={Wang, Haochen and others},
      journal={arXiv preprint arXiv:2510.18876},
      year={2025}
    }
    ```
    """)
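
# -------------------------------------------------------------------------
# Minimal programmatic sketch (assumed usage, not invoked by the app):
# exercises the point -> mask -> caption pipeline without the UI, reusing
# the bundled example from the "Points → Describe" tab.
# -------------------------------------------------------------------------
def _smoke_test(image_path="assets/demo_image_2.jpg", points="1172,812;1572,800"):
    """Hedged helper for local smoke-testing; run manually from a Python shell."""
    image = Image.open(image_path).convert("RGB")
    mask, status = generate_mask_from_points(image, points)
    print(status)
    if mask is not None:
        print(describe_region(image, mask))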

# Load models on startup
try:
    load_models()
except Exception as e:
    print(f"Warning: Could not pre-load models: {e}")
    print("Models will be loaded on first use.")

if __name__ == "__main__":
    demo.launch()
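    # When running outside HF Spaces (no ZeroGPU scheduler), serializing GPU
    # work through Gradio's queue is a reasonable default, e.g.:
    #     demo.queue(max_size=16).launch()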