Spaces:
Running
Running
| """ | |
| Leaf Segmentation using SAM2. | |
| This module provides leaf segmentation functionality to isolate leaves | |
| from backgrounds before disease detection. | |
| """ | |
| import numpy as np | |
| from PIL import Image | |
| from typing import Optional, Tuple, List | |
| import torch | |
class SAM2LeafSegmenter:
    """
    Segments leaves from images using SAM2 (Segment Anything Model 2).

    This is used as a preprocessing step to:
    1. Isolate the leaf from the background
    2. Create a white background image with just the leaf
    3. Reduce false positives in disease detection

    The SAM2 weights are not read in ``__init__``; every public segmentation
    method calls ``load_model()`` first, so the model is loaded on first use.
    """

    def __init__(
        self,
        checkpoint_path: str = "models/sam2/sam2.1_hiera_small.pt",
        config_file: str = "configs/sam2.1/sam2.1_hiera_s.yaml",
        device: Optional[str] = None
    ):
        """
        Initialize SAM2 leaf segmenter (does not load the model).

        Args:
            checkpoint_path: Path to SAM2 checkpoint.
            config_file: SAM2 config file name. The defaults pair the "small"
                config with the "small" checkpoint; change both together.
            device: Device to use ('cuda', 'mps', 'cpu'). If None, picks CUDA
                when available, otherwise MPS, otherwise CPU.
        """
        self.checkpoint_path = checkpoint_path
        self.config_file = config_file

        if device is None:
            if torch.cuda.is_available():
                device = 'cuda'
            elif torch.backends.mps.is_available():
                device = 'mps'
            else:
                device = 'cpu'
        self.device = device

        # Populated lazily by load_model() / load_mask_generator().
        self.model = None
        self.predictor = None
        self.mask_generator = None

    def load_model(self):
        """
        Load the SAM2 model and its image predictor (no-op if already loaded).

        ``self.model`` and ``self.predictor`` are assigned only after both have
        been built. If construction fails part-way, the segmenter stays
        unloaded and the next call retries, instead of returning early with a
        model but no predictor.
        """
        if self.model is not None and self.predictor is not None:
            return

        # Imported lazily so the module can be imported without sam2 installed.
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        print(f"Loading SAM2 model on {self.device}...")
        model = build_sam2(
            config_file=self.config_file,
            ckpt_path=self.checkpoint_path,
            device=self.device
        )
        predictor = SAM2ImagePredictor(model)

        self.model = model
        self.predictor = predictor
        print("SAM2 model loaded.")

    def load_mask_generator(self):
        """
        Load the SAM2 automatic mask generator for multi-object segmentation.

        Loads the base model first if needed. The generator uses a 32x32 point
        grid (batches of 64), one crop layer, predicted-IoU threshold 0.7,
        stability-score threshold 0.92 and a 500-pixel minimum region area.
        """
        self.load_model()
        if self.mask_generator is not None:
            return

        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

        print("Initializing SAM2 automatic mask generator...")
        self.mask_generator = SAM2AutomaticMaskGenerator(
            model=self.model,
            points_per_side=32,
            points_per_batch=64,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.92,
            crop_n_layers=1,
            min_mask_region_area=500,
        )
        print("SAM2 mask generator ready.")

    # ------------------------------------------------------------------ #
    # Private helpers shared by the segmentation entry points
    # ------------------------------------------------------------------ #

    def _predict_point(self, point: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray]:
        """Prompt SAM2 with one foreground point on the image passed to set_image()."""
        masks, scores, _ = self.predictor.predict(
            point_coords=np.array([[point[0], point[1]]]),
            point_labels=np.array([1]),  # 1 = foreground
            multimask_output=True
        )
        return masks, scores

    def _predict_box(self, box) -> Tuple[np.ndarray, np.ndarray]:
        """Prompt SAM2 with one (x1, y1, x2, y2) box on the image passed to set_image()."""
        masks, scores, _ = self.predictor.predict(
            box=np.array([box]),
            multimask_output=True
        )
        return masks, scores

    @staticmethod
    def _select_best_mask(masks: np.ndarray, scores: np.ndarray) -> Tuple[np.ndarray, float]:
        """Return the highest-scoring candidate mask (cast to bool) and its score."""
        best_idx = np.argmax(scores)
        return masks[best_idx].astype(bool), scores[best_idx]

    @staticmethod
    def _compose_result(
        image_np: np.ndarray,
        mask: np.ndarray,
        return_mask: bool
    ) -> Image.Image | Tuple[Image.Image, np.ndarray]:
        """Copy the masked pixels onto a white canvas; optionally also return the mask."""
        result = np.full_like(image_np, 255)  # white background
        result[mask] = image_np[mask]  # copy leaf pixels
        result_image = Image.fromarray(result.astype(np.uint8))
        if return_mask:
            return result_image, mask
        return result_image

    # ------------------------------------------------------------------ #
    # Public segmentation API
    # ------------------------------------------------------------------ #

    def segment_leaf(
        self,
        image: Image.Image,
        point: Optional[Tuple[int, int]] = None,
        return_mask: bool = False
    ) -> Image.Image | Tuple[Image.Image, np.ndarray]:
        """
        Segment the leaf from the image using a single point prompt.

        Args:
            image: PIL Image to segment (converted to RGB internally).
            point: (x, y) pixel on the leaf. If None, uses the image center.
            return_mask: If True, also returns the boolean mask.

        Returns:
            Image with leaf on white background (and mask if return_mask=True).
            Of SAM2's candidate masks, the one with the highest score is used.
        """
        self.load_model()

        image_np = np.array(image.convert('RGB'))
        h, w = image_np.shape[:2]
        if point is None:
            point = (w // 2, h // 2)

        self.predictor.set_image(image_np)
        masks, scores = self._predict_point(point)
        mask, _ = self._select_best_mask(masks, scores)

        return self._compose_result(image_np, mask, return_mask)

    def segment_leaf_with_bbox(
        self,
        image: Image.Image,
        bbox: Optional[Tuple[int, int, int, int]] = None,
        return_mask: bool = False
    ) -> Image.Image | Tuple[Image.Image, np.ndarray]:
        """
        Segment the leaf using a bounding box prompt.

        Args:
            image: PIL Image to segment (converted to RGB internally).
            bbox: (x1, y1, x2, y2) bounding box. If None, uses the whole image
                inset on every side by 1/20 of its shorter dimension.
            return_mask: If True, also returns the boolean mask.

        Returns:
            Image with leaf on white background (and mask if return_mask=True).
            Of SAM2's candidate masks, the one with the highest score is used.
        """
        self.load_model()

        image_np = np.array(image.convert('RGB'))
        h, w = image_np.shape[:2]
        if bbox is None:
            # Slightly inset box so the prompt focuses on the subject, not the frame.
            margin = min(w, h) // 20
            bbox = (margin, margin, w - margin, h - margin)

        self.predictor.set_image(image_np)
        masks, scores = self._predict_box(bbox)
        mask, _ = self._select_best_mask(masks, scores)

        return self._compose_result(image_np, mask, return_mask)

    def auto_segment_leaf(
        self,
        image: Image.Image,
        return_mask: bool = False
    ) -> Image.Image | Tuple[Image.Image, np.ndarray]:
        """
        Automatically segment the main leaf/plant from the image.

        Strategy:
        1. Build a heuristic leaf-colour mask (``_detect_green_regions``).
        2. Prompt SAM2 with single foreground points at: the median position
           of the leaf-coloured pixels (or the image center if there are 100
           or fewer), the image center, and the left/right thirds at mid-height.
        3. Discard candidate masks covering <= 5% or >= 95% of the image, and
           rank the rest by
           0.5 * SAM score + 0.2 * (1 - |coverage - 0.5|) + 0.3 * green fraction.
        4. If no candidate survives step 3, use the highest-scoring SAM2 mask
           for the image-center prompt.

        Args:
            image: PIL Image to segment (converted to RGB internally).
            return_mask: If True, also returns the boolean mask.

        Returns:
            Image with leaf on white background (and mask if return_mask=True).
        """
        self.load_model()

        image_np = np.array(image.convert('RGB'))
        h, w = image_np.shape[:2]
        total_pixels = h * w

        self.predictor.set_image(image_np)

        # Anchor the first prompt on leaf-coloured pixels when there are enough.
        green_mask = self._detect_green_regions(image_np)
        image_center = (w // 2, h // 2)
        if green_mask.sum() > 100:
            ys, xs = np.where(green_mask)
            anchor = (int(np.median(xs)), int(np.median(ys)))
        else:
            anchor = image_center

        candidate_points = [
            anchor,                # green centroid (or image center)
            image_center,          # image center
            (w // 3, h // 2),      # left third
            (2 * w // 3, h // 2),  # right third
        ]
        # Drop repeated prompts (e.g. anchor == center) but keep the order. A
        # repeat would only re-score the same candidates, and the strict ">"
        # below means it could never replace the current best.
        candidate_points = list(dict.fromkeys(candidate_points))

        best_mask = None
        best_score = -1
        center_prediction = None
        for point in candidate_points:
            masks, scores = self._predict_point(point)
            if point == image_center:
                # Cached for the fallback so SAM2 is not run on it twice.
                center_prediction = (masks, scores)

            for raw_mask, sam_score in zip(masks, scores):
                mask = raw_mask.astype(bool)  # bool is required for indexing
                area = mask.sum()
                coverage = area / total_pixels

                # Hard filter: reject near-empty masks and masks that swallow the frame.
                if not (0.05 < coverage < 0.95):
                    continue

                # Fraction of the mask's pixels that look leaf-coloured.
                greenness = green_mask[mask].sum() / max(area, 1)
                # 1.0 at 50% coverage, decreasing linearly away from it.
                coverage_score = 1 - abs(coverage - 0.5)
                combined_score = sam_score * 0.5 + coverage_score * 0.2 + greenness * 0.3

                if combined_score > best_score:
                    best_score = combined_score
                    best_mask = mask

        if best_mask is None:
            # Nothing passed the coverage filter: fall back to the top-scoring
            # mask for the image-center prompt (always among candidate_points).
            best_mask, _ = self._select_best_mask(*center_prediction)

        return self._compose_result(image_np, best_mask, return_mask)

    def _detect_green_regions(self, image_np: np.ndarray) -> np.ndarray:
        """
        Return a boolean mask of pixels that look leaf-coloured.

        This is a plain RGB threshold heuristic (no colour-space conversion).
        A pixel is kept if it passes either test:

        * "green":        G > 0.9*R, G > 0.9*B and 40 < G < 250
        * "yellow-green": R > 50, G > 50, B < 0.8*R and |G - R| < 80

        NOTE(review): because of the 0.9 factor, neutral greys (R == G == B
        with 40 < G < 250) also pass the "green" test. Warm browns/oranges
        such as (200, 130, 50) pass the "yellow-green" test. Confirm this
        leniency is intended.

        Args:
            image_np: RGB image array indexed as [row, col, channel]; callers
                in this class pass ``np.array(image.convert('RGB'))``.

        Returns:
            Boolean array with the image's height and width.
        """
        red = image_np[:, :, 0]
        green = image_np[:, :, 1]
        blue = image_np[:, :, 2]

        green_dominant = (
            (green > red * 0.9)     # green not much below red
            & (green > blue * 0.9)  # green not much below blue
            & (green > 40)          # not too dark
            & (green < 250)         # not near-white
        )

        # Yellow-green is common in leaves. Cast before subtracting: with
        # uint8 input (as produced by PIL) the difference would wrap around.
        yellow_green = (
            (green > 50)
            & (red > 50)
            & (blue < red * 0.8)  # blue clearly below red
            & (abs(green.astype(int) - red.astype(int)) < 80)  # R and G similar
        )

        return green_dominant | yellow_green

    def refine_boxes_to_masks(
        self,
        image: Image.Image,
        boxes: np.ndarray,
        return_scores: bool = False
    ) -> np.ndarray | Tuple[np.ndarray, np.ndarray]:
        """
        Refine bounding boxes into precise segmentation masks using SAM2.

        This is used to convert RF-DETR detection boxes into proper
        segmentation masks for disease regions. Each box is prompted
        separately, and its highest-scoring candidate mask is kept.

        Args:
            image: PIL Image (converted to RGB internally).
            boxes: Array of bounding boxes [N, 4] in xyxy pixel format.
            return_scores: If True, also returns the SAM2 score per mask.

        Returns:
            Boolean masks [N, H, W] (and float32 scores [N] if
            return_scores=True). For N == 0 the model is loaded but no
            SAM2 inference runs.
        """
        self.load_model()

        image_np = np.array(image.convert('RGB'))
        h, w = image_np.shape[:2]

        if len(boxes) == 0:
            empty_masks = np.zeros((0, h, w), dtype=bool)
            if return_scores:
                return empty_masks, np.zeros((0,), dtype=np.float32)
            return empty_masks

        self.predictor.set_image(image_np)

        masks_list = []
        scores_list = []
        for box in boxes:
            masks, scores = self._predict_box(box)
            best_mask, best_score = self._select_best_mask(masks, scores)
            masks_list.append(best_mask)
            scores_list.append(best_score)

        # boxes is non-empty here, so masks_list has at least one entry.
        result_masks = np.stack(masks_list, axis=0)
        result_scores = np.array(scores_list, dtype=np.float32)

        if return_scores:
            return result_masks, result_scores
        return result_masks
# Convenience function
def create_leaf_segmenter(
    checkpoint_path: str = "models/sam2/sam2.1_hiera_small.pt",
    device: Optional[str] = None,
    config_file: str = "configs/sam2.1/sam2.1_hiera_s.yaml",
) -> SAM2LeafSegmenter:
    """
    Create a SAM2 leaf segmenter instance.

    The SAM2 model is not loaded here. The segmenter loads it lazily on its
    first segmentation call.

    Args:
        checkpoint_path: Path to the SAM2 checkpoint.
        device: 'cuda', 'mps' or 'cpu'; auto-detected by the segmenter if None.
        config_file: SAM2 config file name. The default matches the default
            "small" checkpoint. Pass the matching config when pointing
            ``checkpoint_path`` at a different model size.

    Returns:
        A new ``SAM2LeafSegmenter``.
    """
    return SAM2LeafSegmenter(
        checkpoint_path=checkpoint_path,
        config_file=config_file,
        device=device,
    )