Spaces:
Sleeping
Sleeping
| """ | |
| SAM 3 Segmentation Module for CropDoctor-Semantic | |
| ================================================== | |
| This module provides the core segmentation functionality using Meta's SAM 3 | |
| for concept-based plant disease detection. | |
| SAM 3 enables zero-shot segmentation using natural language prompts, | |
| allowing detection of disease symptoms without task-specific training. | |
| """ | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| from typing import List, Dict, Tuple, Optional, Union | |
| from dataclasses import dataclass | |
| import yaml | |
| import logging | |
# Module-level logger used throughout this file.
# NOTE: basicConfig() below configures the *root* logger at INFO level as an
# import-time side effect; it does nothing if the root logger already has
# handlers (e.g. the host application configured logging first).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@dataclass(eq=False)
class SegmentationResult:
    """
    Container for segmentation results.

    Declared as a dataclass so it can be built with keyword arguments
    (``SegmentationResult(masks=..., boxes=..., ...)``) as the segmenters do.
    ``eq=False`` keeps identity comparison: a generated ``__eq__`` would
    compare the ndarray fields and raise "truth value is ambiguous".
    """
    masks: np.ndarray  # Shape: (N, H, W) boolean masks
    boxes: np.ndarray  # Shape: (N, 4) bounding boxes [x1, y1, x2, y2]
    scores: np.ndarray  # Shape: (N,) confidence scores
    prompts: List[str]  # Prompts used for the query (indexed by prompt_indices)
    prompt_indices: np.ndarray  # Shape: (N,) index into `prompts` for each mask
class SAM3Segmenter:
    """
    SAM 3 based segmentation for plant disease detection.

    Uses text prompts to detect and segment disease symptoms in plant images.
    SAM 3's Promptable Concept Segmentation (PCS) enables open-vocabulary
    detection without fine-tuning.

    Example:
        >>> segmenter = SAM3Segmenter("models/sam3/sam3.pt")
        >>> result = segmenter.segment_with_concepts(
        ...     "leaf_image.jpg",
        ...     ["leaf with brown spots", "healthy leaf"]
        ... )
        >>> print(f"Found {len(result.masks)} regions")
    """

    def __init__(
        self,
        checkpoint_path: str = "models/sam3/sam3.pt",
        config_path: str = "configs/sam3_config.yaml",
        device: Optional[str] = None
    ):
        """
        Initialize SAM 3 segmenter.

        The model itself is loaded lazily on first use (see ``load_model``).

        Args:
            checkpoint_path: Path to SAM 3 checkpoint
            config_path: Path to configuration file (defaults used if missing)
            device: Device to use (cuda, cpu, mps). Auto-detected if None,
                preferring cuda, then mps, then cpu.
        """
        self.checkpoint_path = Path(checkpoint_path)
        self.config = self._load_config(config_path)

        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self.device = device
        logger.info(f"Using device: {self.device}")

        # Model will be loaded lazily
        self.model = None
        self.processor = None

    def _load_config(self, config_path: str) -> dict:
        """
        Load configuration from a YAML file, falling back to defaults.

        Sections in the YAML file are merged (one level deep) over
        ``_default_config()``, so a partial file still provides keys such as
        ``config["inference"]["confidence_threshold"]``.

        Raises:
            ValueError: If the YAML document is not a mapping.
        """
        config_path = Path(config_path)
        if not config_path.exists():
            logger.warning(f"Config not found at {config_path}, using defaults")
            return self._default_config()

        with open(config_path, "r", encoding="utf-8") as f:
            # An empty YAML file loads as None: treat it as "no overrides".
            loaded = yaml.safe_load(f) or {}
        if not isinstance(loaded, dict):
            raise ValueError(
                f"Config at {config_path} must be a YAML mapping, "
                f"got {type(loaded).__name__}"
            )

        config = self._default_config()
        for section, values in loaded.items():
            if isinstance(values, dict) and isinstance(config.get(section), dict):
                config[section] = {**config[section], **values}
            else:
                config[section] = values
        return config

    def _default_config(self) -> dict:
        """Return the default configuration (inference and visualization sections)."""
        return {
            "inference": {
                "confidence_threshold": 0.25,
                "presence_threshold": 0.5,
                "max_objects_per_prompt": 50,
                "min_mask_area": 100
            },
            "visualization": {
                "mask_alpha": 0.5,
                "show_confidence": True
            }
        }

    def load_model(self):
        """
        Load SAM 3 model and processor (no-op if already loaded).

        On CUDA the model is converted to half precision unless
        ``config["model"]["half_precision"]`` is set to False.

        Raises:
            ImportError: If the ``sam3`` package cannot be imported.
            FileNotFoundError: Re-raised (after logging) if loading the
                checkpoint raises it.
        """
        if self.model is not None:
            return

        logger.info("Loading SAM 3 model...")
        try:
            from sam3.model_builder import build_sam3_image_model
            from sam3.model.sam3_image_processor import Sam3Processor

            self.model = build_sam3_image_model(checkpoint=str(self.checkpoint_path))
            self.model.to(self.device)
            if self.config.get("model", {}).get("half_precision", True) and self.device == "cuda":
                self.model = self.model.half()
            self.model.eval()

            self.processor = Sam3Processor(self.model)
            logger.info("SAM 3 model loaded successfully")
        except ImportError:
            logger.error("SAM 3 not installed. Please install from: https://github.com/facebookresearch/sam3")
            raise
        except FileNotFoundError:
            logger.error(f"Checkpoint not found at {self.checkpoint_path}")
            raise

    @staticmethod
    def _load_image(image: Union[str, Path, Image.Image, np.ndarray]) -> Image.Image:
        """
        Normalise a path / PIL image / numpy array to an RGB PIL image.

        Files are opened in a ``with`` block so the handle is closed once the
        pixels are loaded; grayscale, RGBA and palette inputs are converted.
        """
        if isinstance(image, (str, Path)):
            with Image.open(image) as img:
                return img.convert("RGB")
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        return image if image.mode == "RGB" else image.convert("RGB")

    @staticmethod
    def _to_numpy(value) -> np.ndarray:
        """Convert a (possibly grad-tracking, possibly GPU) tensor or array-like to numpy."""
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def segment_with_concepts(
        self,
        image: Union[str, Path, Image.Image, np.ndarray],
        text_prompts: List[str],
        confidence_threshold: Optional[float] = None
    ) -> SegmentationResult:
        """
        Segment image using text prompts.

        Args:
            image: Input image (path, PIL Image, or numpy array)
            text_prompts: List of text prompts describing concepts to detect
            confidence_threshold: Override default confidence threshold.
                An explicit 0.0 is honoured (keeps every detection).

        Returns:
            SegmentationResult containing masks, boxes, scores, and prompt info
        """
        self.load_model()
        image = self._load_image(image)

        # `is None` (not `or`) so that an explicit 0.0 threshold is respected.
        if confidence_threshold is None:
            threshold = self.config["inference"]["confidence_threshold"]
        else:
            threshold = confidence_threshold

        inference_state = self.processor.set_image(image)

        all_masks, all_boxes, all_scores, all_prompt_indices = [], [], [], []
        for prompt_idx, prompt in enumerate(text_prompts):
            logger.debug(f"Processing prompt: {prompt}")
            output = self.processor.set_text_prompt(
                state=inference_state,
                prompt=prompt
            )
            masks = output["masks"]
            if masks is None or len(masks) == 0:
                continue

            masks_np = self._to_numpy(masks)
            boxes_np = self._to_numpy(output["boxes"])
            scores_np = self._to_numpy(output["scores"])

            keep = scores_np >= threshold
            if keep.any():
                all_masks.append(masks_np[keep])
                all_boxes.append(boxes_np[keep])
                all_scores.append(scores_np[keep])
                all_prompt_indices.append(
                    np.full(keep.sum(), prompt_idx, dtype=np.int32)
                )

        if all_masks:
            combined_masks = np.concatenate(all_masks, axis=0)
            combined_boxes = np.concatenate(all_boxes, axis=0)
            combined_scores = np.concatenate(all_scores, axis=0)
            combined_indices = np.concatenate(all_prompt_indices, axis=0)
        else:
            w, h = image.size
            combined_masks = np.zeros((0, h, w), dtype=bool)
            combined_boxes = np.zeros((0, 4), dtype=np.float32)
            combined_scores = np.zeros((0,), dtype=np.float32)
            combined_indices = np.zeros((0,), dtype=np.int32)

        return SegmentationResult(
            masks=combined_masks,
            boxes=combined_boxes,
            scores=combined_scores,
            prompts=text_prompts,
            prompt_indices=combined_indices
        )

    def segment_disease_regions(
        self,
        image: Union[str, Path, Image.Image, np.ndarray],
        profile: str = "standard",
        confidence_threshold: Optional[float] = None
    ) -> SegmentationResult:
        """
        Segment disease regions using predefined prompt profiles.

        Args:
            image: Input image
            profile: Analysis profile name, looked up in
                ``config["analysis_profiles"]`` (e.g. "quick_scan", "standard",
                "comprehensive", "pest_focused")
            confidence_threshold: Optional override passed through to
                ``segment_with_concepts``

        Returns:
            SegmentationResult for the specified analysis profile

        Raises:
            ValueError: If the profile is not defined in the config.
        """
        profiles = self.config.get("analysis_profiles", {})
        if profile not in profiles:
            available = list(profiles.keys())
            raise ValueError(f"Profile '{profile}' not found. Available: {available}")

        prompts = profiles[profile]["prompts"]
        logger.info(f"Using profile '{profile}' with {len(prompts)} prompts")
        return self.segment_with_concepts(image, prompts, confidence_threshold)

    def calculate_affected_area(
        self,
        result: SegmentationResult,
        healthy_prompt_idx: Optional[int] = None
    ) -> Dict[str, float]:
        """
        Calculate the percentage of the image covered by detections.

        Args:
            result: Segmentation result
            healthy_prompt_idx: Index of the "healthy" prompt; its masks are
                reported as ``healthy_percent`` and excluded from the total.

        Returns:
            Dictionary with:
              - ``total_affected_percent``: union of all non-healthy masks, in %
                of the image (overlaps counted once)
              - ``per_symptom``: prompt -> % of image covered by that prompt's masks
              - ``healthy_percent``: % covered by the healthy prompt, or None if
                ``healthy_prompt_idx`` is None
        """
        no_healthy = None if healthy_prompt_idx is None else 0.0
        if len(result.masks) == 0:
            return {
                "total_affected_percent": 0.0,
                "per_symptom": {},
                "healthy_percent": no_healthy,
            }

        h, w = result.masks[0].shape
        total_area = h * w

        per_symptom = {}
        diseased_union = np.zeros((h, w), dtype=bool)
        healthy_area = 0
        for prompt_idx, prompt in enumerate(result.prompts):
            selected = result.prompt_indices == prompt_idx
            if not selected.any():
                continue
            prompt_union = result.masks[selected].any(axis=0)
            area = int(prompt_union.sum())
            per_symptom[prompt] = (area / total_area) * 100
            # `==` against None is False, so with no healthy prompt every
            # prompt counts as a symptom.
            if prompt_idx == healthy_prompt_idx:
                healthy_area = area
            else:
                diseased_union |= prompt_union

        affected_percent = (int(diseased_union.sum()) / total_area) * 100
        return {
            "total_affected_percent": affected_percent,
            "per_symptom": per_symptom,
            # `is not None` so that index 0 is a valid healthy prompt.
            "healthy_percent": (
                (healthy_area / total_area) * 100
                if healthy_prompt_idx is not None else None
            ),
        }

    def get_disease_prompts(self, category: str = "all") -> List[str]:
        """
        Get predefined disease detection prompts from ``config["prompts"]``.

        Args:
            category: Prompt category key (e.g. "general", "fungal",
                "bacterial", "viral", "nutrient", "pest") or "all" to
                concatenate every category.

        Returns:
            A new list of prompts (safe to mutate without affecting the config)

        Raises:
            ValueError: If the category is unknown.
        """
        prompts_config = self.config.get("prompts", {})
        if category == "all":
            all_prompts = []
            for cat_prompts in prompts_config.values():
                all_prompts.extend(cat_prompts or [])
            return all_prompts
        if category in prompts_config:
            return list(prompts_config[category])
        available = list(prompts_config.keys()) + ["all"]
        raise ValueError(f"Category '{category}' not found. Available: {available}")
class MockSAM3Segmenter(SAM3Segmenter):
    """
    Color-based segmentation for plant disease detection.

    Analyzes actual image colors to detect disease symptoms:
    - Green regions = healthy tissue
    - Brown/yellow/spotted regions = potential disease

    Uses scipy.ndimage for blob detection on non-green regions. No SAM 3
    model is loaded; text prompts are only used to label the detections.
    """

    def load_model(self):
        """Skip model loading for color-based analysis."""
        logger.info("Using MockSAM3Segmenter (color-based analysis)")
        self.model = "color_analysis"
        self.processor = "color_analysis"

    def _compute_hsv(self, img_array: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Convert an RGB image to HSV channels (OpenCV-style ranges).

        Args:
            img_array: (H, W, C) array; channels 0, 1, 2 are read as R, G, B.

        Returns:
            (h, s, v): hue as uint8 in [0, 180) (degrees / 2), saturation as
            uint8 scaled to [0, 255], and value = per-pixel max(R, G, B) in
            the input dtype.
        """
        r, g, b = img_array[:, :, 0], img_array[:, :, 1], img_array[:, :, 2]

        rgb_max = np.maximum(np.maximum(r, g), b)
        rgb_min = np.minimum(np.minimum(r, g), b)
        # Tiny epsilon avoids division by zero on gray pixels (max == min).
        delta = (rgb_max - rgb_min).astype(np.float32) + 1e-10

        # Value
        v = rgb_max

        # Saturation
        s = np.where(rgb_max > 0, (delta / (rgb_max.astype(np.float32) + 1e-10)) * 255, 0).astype(np.uint8)

        # Hue: piecewise by which channel is the maximum (ties resolved R, G, B)
        h_channel = np.zeros_like(r, dtype=np.float32)
        mask_r = (rgb_max == r)
        h_channel[mask_r] = 60 * (((g[mask_r].astype(np.float32) - b[mask_r]) / delta[mask_r]) % 6)
        mask_g = (rgb_max == g) & ~mask_r
        h_channel[mask_g] = 60 * (((b[mask_g].astype(np.float32) - r[mask_g]) / delta[mask_g]) + 2)
        mask_b = (rgb_max == b) & ~mask_r & ~mask_g
        h_channel[mask_b] = 60 * (((r[mask_b].astype(np.float32) - g[mask_b]) / delta[mask_b]) + 4)
        h_channel = (h_channel / 2).astype(np.uint8)  # 0-180 range

        return h_channel, s, v

    def _segment_leaf(self, img_array: np.ndarray, h: np.ndarray, s: np.ndarray, v: np.ndarray) -> np.ndarray:
        """
        Segment the leaf/plant tissue from the background.

        Uses color analysis to find plant material:
        - High saturation (plants are colorful, backgrounds are often gray/neutral)
        - Green to yellow-brown hue range (plant tissue colors)
        - Reasonable brightness

        Returns:
            Boolean (H, W) mask of the largest connected candidate region, or
            an all-True mask if no region is found or it covers < 10% of the image.
        """
        from scipy import ndimage

        img_h, img_w = img_array.shape[:2]

        # Hue 5-90 covers both green/yellow-green (15-90) and brown/orange
        # diseased tissue (5-30); the two ranges overlap, so one interval suffices.
        plant_hue_mask = (h >= 5) & (h <= 90)

        plant_saturation_mask = (s >= 25)  # Saturated (not gray)
        plant_brightness_mask = (v >= 30) & (v <= 250)  # Not too dark, not blown out

        potential_leaf = plant_hue_mask & plant_saturation_mask & plant_brightness_mask

        # Also include high saturation areas regardless of hue (catches more plant tissue)
        high_saturation = (s >= 50) & plant_brightness_mask
        potential_leaf = potential_leaf | high_saturation

        # Clean up with morphological operations
        potential_leaf = ndimage.binary_closing(potential_leaf, iterations=3)
        potential_leaf = ndimage.binary_opening(potential_leaf, iterations=2)
        potential_leaf = ndimage.binary_fill_holes(potential_leaf)

        labeled, num_features = ndimage.label(potential_leaf)
        if num_features == 0:
            logger.warning("No leaf detected, using full image")
            return np.ones((img_h, img_w), dtype=bool)

        # Largest connected component is taken as the main leaf
        component_sizes = ndimage.sum(potential_leaf, labeled, range(1, num_features + 1))
        largest_idx = np.argmax(component_sizes) + 1
        leaf_mask = (labeled == largest_idx)

        leaf_coverage = leaf_mask.sum() / (img_h * img_w)
        if leaf_coverage < 0.10:
            logger.warning(f"Leaf too small ({leaf_coverage:.1%}), using full image")
            return np.ones((img_h, img_w), dtype=bool)

        logger.debug(f"Leaf segmented: {leaf_coverage:.1%} of image")
        return leaf_mask

    def _detect_disease_regions(self, image: Image.Image) -> Tuple[np.ndarray, List[Dict]]:
        """
        Detect disease regions based on color analysis.

        First segments the leaf from background, then analyzes only
        the leaf area for disease symptoms.

        Returns:
            Tuple of (binary mask of all abnormal regions, list of blob dicts
            with keys 'mask', 'bbox' [x_min, y_min, x_max, y_max] (inclusive),
            'area' and 'confidence').
        """
        from scipy import ndimage

        img_array = np.array(image)
        img_h, img_w = img_array.shape[:2]

        h_channel, s, v = self._compute_hsv(img_array)

        # Step 1: Segment the leaf from background
        leaf_mask = self._segment_leaf(img_array, h_channel, s, v)
        leaf_area = leaf_mask.sum()
        if leaf_area == 0:
            return np.zeros((img_h, img_w), dtype=bool), []

        # Step 2: Within the leaf, find healthy green regions
        green_mask = (
            (h_channel >= 35) & (h_channel <= 85) &  # Green hue
            (s >= 30) &  # Saturated
            (v >= 30) &  # Not too dark
            leaf_mask
        )
        green_area = green_mask.sum()
        green_ratio = green_area / leaf_area
        logger.debug(f"Within leaf - Green: {green_ratio:.1%}, Leaf area: {leaf_area}px")

        # Step 3: Disease colors (only within the leaf)
        brown_mask = (  # Brown spots: low hue, moderate saturation
            (h_channel >= 5) & (h_channel <= 25) &
            (s >= 30) &
            (v >= 40) & (v <= 200) &
            leaf_mask
        )
        yellow_mask = (  # Yellow/chlorosis: yellow hue, high saturation
            (h_channel >= 20) & (h_channel <= 40) &
            (s >= 40) &
            (v >= 80) &
            leaf_mask
        )
        dark_spots = (  # Necrotic dark spots
            (v <= 60) &
            (s >= 15) &  # Some color, not pure black
            leaf_mask
        )
        white_spots = (  # White spots (powdery mildew)
            (v >= 200) &
            (s <= 40) &
            leaf_mask
        )

        abnormal_mask = (brown_mask | yellow_mask | dark_spots | white_spots)
        abnormal_area = abnormal_mask.sum()
        logger.debug(f"Abnormal pixels within leaf: {abnormal_area} ({abnormal_area/leaf_area:.1%} of leaf)")

        # If mostly green (>80% of leaf is green), consider healthy
        if green_ratio > 0.80 and abnormal_area < leaf_area * 0.05:
            logger.info(f"Leaf appears healthy ({green_ratio:.0%} green)")
            return np.zeros((img_h, img_w), dtype=bool), []

        # If very little abnormal tissue, also healthy
        if abnormal_area < leaf_area * 0.02:
            logger.info("Minimal abnormal tissue detected - healthy")
            return np.zeros((img_h, img_w), dtype=bool), []

        # Clean up the abnormal mask
        abnormal_mask = ndimage.binary_opening(abnormal_mask, iterations=1)
        abnormal_mask = ndimage.binary_closing(abnormal_mask, iterations=2)

        labeled_array, num_features = ndimage.label(abnormal_mask)

        # Filter blobs by size (relative to leaf, not image)
        min_blob_area = max(50, leaf_area * 0.005)  # At least 0.5% of leaf
        max_blob_area = leaf_area * 0.6  # At most 60% of leaf

        # Per-label pixel counts and bounding slices in one pass each, instead
        # of a full-image comparison for every label (index 0 = background).
        blob_areas = np.bincount(labeled_array.ravel(), minlength=num_features + 1)
        blob_slices = ndimage.find_objects(labeled_array)

        blobs = []
        for label_idx in range(1, num_features + 1):
            blob_area = blob_areas[label_idx]
            if not (min_blob_area <= blob_area <= max_blob_area):
                continue

            row_slice, col_slice = blob_slices[label_idx - 1]
            y_min, y_max = row_slice.start, row_slice.stop - 1
            x_min, x_max = col_slice.start, col_slice.stop - 1
            blob_mask = (labeled_array == label_idx)

            # Confidence from mean color: more red/brown/yellow -> higher
            avg_color = img_array[blob_mask].mean(axis=0)
            r_ratio = avg_color[0] / 255
            g_ratio = avg_color[1] / 255
            b_ratio = avg_color[2] / 255
            color_score = np.clip(r_ratio - 0.5 * g_ratio + 0.3 * (1 - b_ratio), 0, 1)

            # Area score relative to leaf
            area_score = min(1.0, (blob_area / leaf_area) * 10)

            confidence = np.clip(0.4 + 0.4 * color_score + 0.2 * area_score, 0.3, 0.95)

            blobs.append({
                'mask': blob_mask,
                'bbox': [x_min, y_min, x_max, y_max],
                'area': blob_area,
                'confidence': float(confidence)
            })

        return abnormal_mask, blobs

    def segment_with_concepts(
        self,
        image: Union[str, Path, Image.Image, np.ndarray],
        text_prompts: List[str],
        confidence_threshold: Optional[float] = None
    ) -> SegmentationResult:
        """
        Segment disease regions based on color analysis.

        Analyzes the image colors to detect abnormal (non-green) regions
        that may indicate disease. Returns empty results for healthy images.
        All detections are attributed to the first prompt not containing
        "healthy" (index 0 if every prompt does).

        Args:
            image: Input image (path, PIL Image, or numpy array); converted to RGB
            text_prompts: Prompts used only for labelling detections
            confidence_threshold: Override default threshold (0.0 is honoured)
        """
        if isinstance(image, (str, Path)):
            with Image.open(image) as img:
                image = img.convert("RGB")
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # _compute_hsv indexes three channels: grayscale/palette inputs must be
        # converted (RGBA keeps its RGB values).
        if image.mode != "RGB":
            image = image.convert("RGB")

        w, h = image.size
        if confidence_threshold is None:
            threshold = self.config["inference"]["confidence_threshold"]
        else:
            threshold = confidence_threshold

        _, blobs = self._detect_disease_regions(image)
        blobs = [b for b in blobs if b['confidence'] >= threshold]

        if not blobs:
            logger.info("No disease regions detected (healthy image)")
            return SegmentationResult(
                masks=np.zeros((0, h, w), dtype=bool),
                boxes=np.zeros((0, 4), dtype=np.float32),
                scores=np.zeros((0,), dtype=np.float32),
                prompts=text_prompts,
                prompt_indices=np.zeros((0,), dtype=np.int32)
            )

        num_detections = len(blobs)
        masks = np.stack([b['mask'] for b in blobs]).astype(bool, copy=False)
        boxes = np.array([b['bbox'] for b in blobs], dtype=np.float32)
        scores = np.array([b['confidence'] for b in blobs], dtype=np.float32)

        # Assign detections to first disease-related prompt (skip "healthy" prompts)
        disease_prompt_idx = next(
            (idx for idx, prompt in enumerate(text_prompts) if "healthy" not in prompt.lower()),
            0
        )
        prompt_indices = np.full(num_detections, disease_prompt_idx, dtype=np.int32)

        logger.info(f"Detected {num_detections} disease region(s)")
        return SegmentationResult(
            masks=masks,
            boxes=boxes,
            scores=scores,
            prompts=text_prompts,
            prompt_indices=prompt_indices
        )
class RFDETRSegmenter(SAM3Segmenter):
    """
    RF-DETR based object detection for plant disease detection.

    Uses a trained RF-DETR model (DETR-based detector) instead of SAM 3.
    RF-DETR is trained on annotated plant disease datasets with bounding boxes,
    so the returned masks are filled bounding-box rectangles.

    Example:
        >>> segmenter = RFDETRSegmenter("models/rfdetr/best.pt")
        >>> result = segmenter.segment_with_concepts(image, ["disease"])
        >>> print(f"Found {len(result.masks)} disease regions")
    """

    def __init__(
        self,
        checkpoint_path: str = "models/rfdetr/best.pt",
        config_path: str = "configs/sam3_config.yaml",
        device: Optional[str] = None,
        model_size: str = "medium"
    ):
        """
        Initialize RF-DETR segmenter.

        Args:
            checkpoint_path: Path to trained RF-DETR checkpoint
            config_path: Path to configuration file
            device: Device to use (auto-detected if None)
            model_size: RF-DETR model size (nano, small, medium, base);
                unknown values fall back to base with a warning.
        """
        super().__init__(checkpoint_path, config_path, device)
        self.model_size = model_size
        # Default class name. NOTE(review): nothing in this class updates it
        # after loading; keep it in sync with the trained checkpoint.
        self.class_names = ["Pestalotiopsis"]

    def load_model(self):
        """
        Load the RF-DETR model (no-op if already loaded).

        Custom weights are loaded from ``checkpoint_path`` when that file
        exists; otherwise the variant's default pretrained weights are used.
        NOTE(review): ``self.device`` is not passed to the model here; device
        placement is left to the rfdetr package.

        Raises:
            ImportError: If rfdetr (or the requested variant) cannot be imported.
        """
        if self.model is not None:
            return

        logger.info(f"Loading RF-DETR {self.model_size} model...")
        if self.model_size not in ("nano", "small", "medium", "base"):
            logger.warning(f"Unknown RF-DETR model size '{self.model_size}', falling back to 'base'")

        try:
            if self.model_size == "nano":
                from rfdetr import RFDETRNano as RFDETRModel
            elif self.model_size == "small":
                from rfdetr import RFDETRSmall as RFDETRModel
            elif self.model_size == "medium":
                from rfdetr import RFDETRMedium as RFDETRModel
            else:
                from rfdetr import RFDETRBase as RFDETRModel

            checkpoint = Path(self.checkpoint_path)
            if checkpoint.exists():
                logger.info(f"Loading custom weights from {checkpoint}")
                self.model = RFDETRModel(pretrain_weights=str(checkpoint))
            else:
                logger.warning(f"Checkpoint not found at {checkpoint}, using pretrained weights")
                self.model = RFDETRModel()

            logger.info("RF-DETR model loaded successfully")
        except ImportError as e:
            logger.error(f"RF-DETR not installed: {e}")
            logger.error("Install with: pip install rfdetr")
            raise

    @staticmethod
    def _empty_result(h: int, w: int, text_prompts: List[str]) -> SegmentationResult:
        """Return a SegmentationResult with zero detections for an (h, w) image."""
        return SegmentationResult(
            masks=np.zeros((0, h, w), dtype=bool),
            boxes=np.zeros((0, 4), dtype=np.float32),
            scores=np.zeros((0,), dtype=np.float32),
            prompts=text_prompts,
            prompt_indices=np.zeros((0,), dtype=np.int32)
        )

    def segment_with_concepts(
        self,
        image: Union[str, Path, Image.Image, np.ndarray],
        text_prompts: List[str],
        confidence_threshold: Optional[float] = None
    ) -> SegmentationResult:
        """
        Detect disease regions using RF-DETR.

        Note: RF-DETR is class-based (not prompt-based), so text_prompts
        do not affect detection. They are only used to label results: every
        detection is attributed to the first prompt not containing "healthy".

        Args:
            image: Input image (path, PIL Image, or numpy array); converted to RGB
            text_prompts: Labels for the result (RF-DETR uses trained classes)
            confidence_threshold: Detection confidence threshold (0.0 is honoured)

        Returns:
            SegmentationResult with detected disease regions. If prediction
            raises, the error is logged and an empty result is returned.
        """
        self.load_model()

        if isinstance(image, (str, Path)):
            with Image.open(image) as img:
                pil_image = img.convert("RGB")
        elif isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image).convert("RGB")
        else:
            pil_image = image if image.mode == "RGB" else image.convert("RGB")

        w, h = pil_image.size
        if confidence_threshold is None:
            threshold = self.config["inference"]["confidence_threshold"]
        else:
            threshold = confidence_threshold

        try:
            detections = self.model.predict(pil_image, threshold=threshold)
        except Exception as e:
            # Best-effort: report no detections instead of failing the request,
            # but keep the traceback in the log.
            logger.exception(f"RF-DETR prediction failed: {e}")
            return self._empty_result(h, w, text_prompts)

        num_detections = len(detections)
        if num_detections == 0:
            logger.info("No disease regions detected")
            return self._empty_result(h, w, text_prompts)

        boxes = np.asarray(detections.xyxy, dtype=np.float32).reshape(-1, 4)  # [x1, y1, x2, y2]
        if detections.confidence is not None:
            scores = np.asarray(detections.confidence, dtype=np.float32)
        else:
            scores = np.ones(num_detections, dtype=np.float32)

        # RF-DETR gives boxes, not masks: rasterise each box. Floor the top-left
        # and ceil the bottom-right corner so partially covered pixels count,
        # then clamp both ends to the image (a negative stop would otherwise
        # wrap around in numpy slicing).
        masks = np.zeros((num_detections, h, w), dtype=bool)
        for i, (x1, y1, x2, y2) in enumerate(boxes):
            col0 = min(w, max(0, int(np.floor(x1))))
            row0 = min(h, max(0, int(np.floor(y1))))
            col1 = min(w, max(0, int(np.ceil(x2))))
            row1 = min(h, max(0, int(np.ceil(y2))))
            masks[i, row0:row1, col0:col1] = True

        disease_prompt_idx = next(
            (idx for idx, prompt in enumerate(text_prompts) if "healthy" not in prompt.lower()),
            0
        )
        prompt_indices = np.full(num_detections, disease_prompt_idx, dtype=np.int32)

        logger.info(f"RF-DETR detected {num_detections} disease region(s)")
        return SegmentationResult(
            masks=masks,
            boxes=boxes,
            scores=scores,
            prompts=text_prompts,
            prompt_indices=prompt_indices
        )
def create_segmenter(
    checkpoint_path: str = "models/sam3/sam3.pt",
    config_path: str = "configs/sam3_config.yaml",
    use_mock: bool = False,
    use_rfdetr: bool = False,
    rfdetr_checkpoint: str = "models/rfdetr/best.pt",
    rfdetr_model_size: str = "medium",
    device: Optional[str] = None,
) -> SAM3Segmenter:
    """
    Factory function to create the appropriate segmenter.

    ``use_rfdetr`` takes precedence over ``use_mock``; with neither set, a
    real SAM 3 segmenter is returned.

    Args:
        checkpoint_path: Path to SAM 3 checkpoint (SAM 3 and mock segmenters)
        config_path: Path to configuration
        use_mock: If True, use color-based mock segmenter
        use_rfdetr: If True, use RF-DETR detector
        rfdetr_checkpoint: Path to RF-DETR checkpoint
        rfdetr_model_size: RF-DETR model size (nano, small, medium, base)
        device: Device to use (cuda, cpu, mps); auto-detected if None

    Returns:
        SAM3Segmenter, MockSAM3Segmenter, or RFDETRSegmenter instance
    """
    if use_rfdetr:
        return RFDETRSegmenter(
            checkpoint_path=rfdetr_checkpoint,
            config_path=config_path,
            device=device,
            model_size=rfdetr_model_size
        )
    if use_mock:
        return MockSAM3Segmenter(checkpoint_path, config_path, device)
    return SAM3Segmenter(checkpoint_path, config_path, device)
if __name__ == "__main__":
    # Smoke test: run the color-based mock segmenter on a synthetic,
    # uniformly forest-green image.
    demo_segmenter = create_segmenter(use_mock=True)

    forest_green = (34, 139, 34)
    demo_image = Image.new("RGB", (640, 480), color=forest_green)
    demo_prompts = ["diseased leaf", "brown spots", "healthy tissue"]

    outcome = demo_segmenter.segment_with_concepts(demo_image, demo_prompts)
    print(f"Found {len(outcome.masks)} regions")
    print(f"Scores: {outcome.scores}")
    used_prompts = [outcome.prompts[i] for i in outcome.prompt_indices]
    print(f"Prompts used: {used_prompts}")

    coverage = demo_segmenter.calculate_affected_area(outcome)
    print(f"Affected area: {coverage['total_affected_percent']:.1f}%")