Fahimeh Orvati Nia commited on
Commit
f8ac29e
·
1 Parent(s): 6926cc6

update sorghum for multiple plants

Browse files
app.py CHANGED
@@ -23,11 +23,14 @@ PRESET_IMAGES = {
23
 
24
  def process(file_path, preset_choice):
25
  """Process image and yield results progressively for immediate display."""
26
- # If a preset is chosen, override the uploaded file path
 
27
  if preset_choice:
28
  chosen = PRESET_IMAGES.get(preset_choice)
29
  if chosen:
30
  file_path = chosen
 
 
31
 
32
  if not file_path:
33
  # Return 10 outputs (removed YOLO tips)
@@ -62,7 +65,7 @@ def process(file_path, preset_choice):
62
  return None
63
 
64
  # Run the pipeline progressively (generator)
65
- for outputs in run_pipeline_on_image(str(img_path), tmpdir, save_artifacts=True):
66
  # Load all available outputs progressively
67
  composite = load_pil(outputs.get('Composite'))
68
  overlay = load_pil(outputs.get('Overlay'))
 
23
 
24
  def process(file_path, preset_choice):
25
  """Process image and yield results progressively for immediate display."""
26
+ # Determine dataset type (single-plant mode for Corn, multi-plant for others)
27
+ single_plant_mode = False
28
  if preset_choice:
29
  chosen = PRESET_IMAGES.get(preset_choice)
30
  if chosen:
31
  file_path = chosen
32
+ # Corn uses single-plant mode
33
+ single_plant_mode = (preset_choice == "Corn")
34
 
35
  if not file_path:
36
  # Return 10 outputs (removed YOLO tips)
 
65
  return None
66
 
67
  # Run the pipeline progressively (generator)
68
+ for outputs in run_pipeline_on_image(str(img_path), tmpdir, save_artifacts=True, single_plant_mode=single_plant_mode):
69
  # Load all available outputs progressively
70
  composite = load_pil(outputs.get('Composite'))
71
  overlay = load_pil(outputs.get('Overlay'))
sorghum_pipeline/features/morphology.py CHANGED
@@ -22,12 +22,13 @@ class MorphologyExtractor:
22
  """Morphology extraction: size analysis image + simple traits."""
23
 
24
  def __init__(self, pixel_to_cm: float = 0.1099609375, prune_sizes: List[int] = None,
25
- min_component_area_for_size: int = 3000):
26
  """Initialize."""
27
  self.pixel_to_cm = pixel_to_cm
28
  self.prune_sizes = prune_sizes or [200, 100, 50, 30, 10]
29
  # Used only for the Morphology Size visualization (not for height)
30
  self.min_component_area_for_size = int(min_component_area_for_size)
 
31
 
32
  if PLANT_CV_AVAILABLE:
33
  pcv.params.debug = None
@@ -72,19 +73,22 @@ class MorphologyExtractor:
72
  plant_heights[f'plant_{plant_idx}'] = height_cm
73
  plant_areas[f'plant_{plant_idx}'] = area
74
 
75
- # Keep only the largest plant (main plant) for single-plant datasets
76
- if len(plant_heights) > 1:
77
  # Find the largest plant by area
78
  largest_plant = max(plant_areas.items(), key=lambda x: x[1])[0]
79
  plant_heights = {largest_plant: plant_heights[largest_plant]}
80
 
81
  # Store individual plant heights
82
  features['traits']['plant_heights'] = plant_heights
83
- features['traits']['num_plants'] = 1 if len(plant_heights) > 0 else 0
84
 
85
- # Store single plant height
86
  if len(plant_heights) == 1:
87
  features['traits']['plant_height_cm'] = list(plant_heights.values())[0]
 
 
 
88
  else:
89
  features['traits']['plant_height_cm'] = 0.0
90
 
@@ -165,43 +169,54 @@ class MorphologyExtractor:
165
  return arr
166
 
167
  def _simple_size_visual(self, rgb: np.ndarray, mask: np.ndarray) -> np.ndarray:
168
- """Draw contours and bbox for the largest plant on RGB image."""
169
  vis = rgb.copy()
170
 
171
  # Find connected components to identify individual plants
172
  num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
173
 
174
- # Find the largest plant (skip background 0)
175
- largest_idx = -1
176
- largest_area = 0
177
- for plant_idx in range(1, num_labels):
178
- area = stats[plant_idx, cv2.CC_STAT_AREA]
179
- if area > largest_area and area >= 100: # Filter noise
180
- largest_area = area
181
- largest_idx = plant_idx
182
-
183
- # Draw only the largest plant
184
- if largest_idx > 0:
185
- # Get mask for largest plant
186
- plant_mask = ((labels == largest_idx).astype(np.uint8) * 255)
187
-
188
- # Find contours
189
- contours, _ = cv2.findContours(plant_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
190
-
191
- # Use blue color for main plant
192
- color = (255, 0, 0)
193
 
194
- # Draw contours
195
- cv2.drawContours(vis, contours, -1, color, 2)
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
- # Draw bounding box
198
- if contours:
199
- x, y, w, h = cv2.boundingRect(contours[0])
200
- cv2.rectangle(vis, (x, y), (x + w, y + h), (0, 255, 0), 2)
 
 
 
 
 
 
 
 
 
 
 
201
 
202
- # Add "Plant 1" label
203
- cv2.putText(vis, "Plant 1", (x, y - 5),
204
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2, cv2.LINE_AA)
205
 
206
  return vis
207
 
 
22
  """Morphology extraction: size analysis image + simple traits."""
23
 
24
  def __init__(self, pixel_to_cm: float = 0.1099609375, prune_sizes: List[int] = None,
25
+ min_component_area_for_size: int = 3000, single_plant_mode: bool = False):
26
  """Initialize."""
27
  self.pixel_to_cm = pixel_to_cm
28
  self.prune_sizes = prune_sizes or [200, 100, 50, 30, 10]
29
  # Used only for the Morphology Size visualization (not for height)
30
  self.min_component_area_for_size = int(min_component_area_for_size)
31
+ self.single_plant_mode = single_plant_mode
32
 
33
  if PLANT_CV_AVAILABLE:
34
  pcv.params.debug = None
 
73
  plant_heights[f'plant_{plant_idx}'] = height_cm
74
  plant_areas[f'plant_{plant_idx}'] = area
75
 
76
+ # Keep only the largest plant if in single-plant mode (e.g., corn)
77
+ if self.single_plant_mode and len(plant_heights) > 1:
78
  # Find the largest plant by area
79
  largest_plant = max(plant_areas.items(), key=lambda x: x[1])[0]
80
  plant_heights = {largest_plant: plant_heights[largest_plant]}
81
 
82
  # Store individual plant heights
83
  features['traits']['plant_heights'] = plant_heights
84
+ features['traits']['num_plants'] = len(plant_heights) if not self.single_plant_mode else (1 if len(plant_heights) > 0 else 0)
85
 
86
+ # For backward compatibility, store single height
87
  if len(plant_heights) == 1:
88
  features['traits']['plant_height_cm'] = list(plant_heights.values())[0]
89
+ elif len(plant_heights) > 1:
90
+ # Store max height as overall height for multi-plant
91
+ features['traits']['plant_height_cm'] = max(plant_heights.values())
92
  else:
93
  features['traits']['plant_height_cm'] = 0.0
94
 
 
169
  return arr
170
 
171
  def _simple_size_visual(self, rgb: np.ndarray, mask: np.ndarray) -> np.ndarray:
172
+ """Draw contours and bbox for plants on RGB image."""
173
  vis = rgb.copy()
174
 
175
  # Find connected components to identify individual plants
176
  num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
177
 
178
+ if self.single_plant_mode:
179
+ # Single-plant mode: draw only the largest plant
180
+ largest_idx = -1
181
+ largest_area = 0
182
+ for plant_idx in range(1, num_labels):
183
+ area = stats[plant_idx, cv2.CC_STAT_AREA]
184
+ if area > largest_area and area >= 100: # Filter noise
185
+ largest_area = area
186
+ largest_idx = plant_idx
 
 
 
 
 
 
 
 
 
 
187
 
188
+ if largest_idx > 0:
189
+ plant_mask = ((labels == largest_idx).astype(np.uint8) * 255)
190
+ contours, _ = cv2.findContours(plant_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
191
+ color = (255, 0, 0)
192
+ cv2.drawContours(vis, contours, -1, color, 2)
193
+ if contours:
194
+ x, y, w, h = cv2.boundingRect(contours[0])
195
+ cv2.rectangle(vis, (x, y), (x + w, y + h), (0, 255, 0), 2)
196
+ cv2.putText(vis, "Plant 1", (x, y - 5),
197
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2, cv2.LINE_AA)
198
+ else:
199
+ # Multi-plant mode: draw all plants with different colors
200
+ colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255)]
201
+ plant_count = 0
202
 
203
+ for plant_idx in range(1, num_labels):
204
+ area = stats[plant_idx, cv2.CC_STAT_AREA]
205
+ if area < 100: # Filter noise
206
+ continue
207
+
208
+ plant_mask = ((labels == plant_idx).astype(np.uint8) * 255)
209
+ contours, _ = cv2.findContours(plant_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
210
+ color = colors[plant_count % len(colors)]
211
+ cv2.drawContours(vis, contours, -1, color, 2)
212
+
213
+ if contours:
214
+ x, y, w, h = cv2.boundingRect(contours[0])
215
+ cv2.rectangle(vis, (x, y), (x + w, y + h), color, 2)
216
+ cv2.putText(vis, f"P{plant_idx}", (x, y - 5),
217
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2, cv2.LINE_AA)
218
 
219
+ plant_count += 1
 
 
220
 
221
  return vis
222
 
sorghum_pipeline/pipeline.py CHANGED
@@ -20,18 +20,19 @@ logger = logging.getLogger(__name__)
20
  class SorghumPipeline:
21
  """Minimal pipeline for single-image processing."""
22
 
23
- def __init__(self, config: Config):
24
  """Initialize pipeline."""
25
  logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
26
  self.config = config
27
  self.config.validate()
 
28
 
29
  # Initialize components with defaults
30
  self.preprocessor = ImagePreprocessor()
31
  self.mask_handler = MaskHandler()
32
  self.texture_extractor = TextureExtractor()
33
  self.vegetation_extractor = VegetationIndexExtractor()
34
- self.morphology_extractor = MorphologyExtractor()
35
  self.segmentation_manager = SegmentationManager(
36
  model_name="briaai/RMBG-2.0",
37
  device=self.config.get_device(),
@@ -41,7 +42,7 @@ class SorghumPipeline:
41
  output_folder=self.config.paths.output_folder,
42
  settings=self.config.output
43
  )
44
- logger.info("Pipeline initialized")
45
 
46
  def run(self, single_image_path: str) -> Dict[str, Any]:
47
  """Run pipeline on single image."""
@@ -105,7 +106,7 @@ class SorghumPipeline:
105
 
106
  return {"plants": plants, "timing": elapsed}
107
 
108
- def run_with_progress(self, single_image_path: str, progress_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None) -> Generator[Dict[str, Any], None, None]:
109
  """Run pipeline on single image, yielding intermediate results progressively."""
110
  logger.info("Processing single image with progress...")
111
 
 
20
  class SorghumPipeline:
21
  """Minimal pipeline for single-image processing."""
22
 
23
+ def __init__(self, config: Config, single_plant_mode: bool = False):
24
  """Initialize pipeline."""
25
  logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
26
  self.config = config
27
  self.config.validate()
28
+ self.single_plant_mode = single_plant_mode
29
 
30
  # Initialize components with defaults
31
  self.preprocessor = ImagePreprocessor()
32
  self.mask_handler = MaskHandler()
33
  self.texture_extractor = TextureExtractor()
34
  self.vegetation_extractor = VegetationIndexExtractor()
35
+ self.morphology_extractor = MorphologyExtractor(single_plant_mode=single_plant_mode)
36
  self.segmentation_manager = SegmentationManager(
37
  model_name="briaai/RMBG-2.0",
38
  device=self.config.get_device(),
 
42
  output_folder=self.config.paths.output_folder,
43
  settings=self.config.output
44
  )
45
+ logger.info(f"Pipeline initialized (single_plant_mode={single_plant_mode})")
46
 
47
  def run(self, single_image_path: str) -> Dict[str, Any]:
48
  """Run pipeline on single image."""
 
106
 
107
  return {"plants": plants, "timing": elapsed}
108
 
109
+ def run_with_progress(self, single_image_path: str, progress_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None, single_plant_mode: bool = False) -> Generator[Dict[str, Any], None, None]:
110
  """Run pipeline on single image, yielding intermediate results progressively."""
111
  logger.info("Processing single image with progress...")
112
 
wrapper.py CHANGED
@@ -10,7 +10,8 @@ from sorghum_pipeline.config import Config, Paths
10
 
11
 
12
  def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts: bool = True,
13
- progress_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None) -> Generator[Dict[str, str], None, None]:
 
14
  """
15
  Run sorghum pipeline on a single image (no instance segmentation).
16
  Yields dict[label -> image_path] progressively for gallery display.
@@ -42,10 +43,10 @@ def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts:
42
  output_folder=str(work),
43
  boundingbox_dir=str(work)
44
  )
45
- pipeline = SorghumPipeline(config=cfg)
46
 
47
  # Run the pipeline with progress callback (generator)
48
- for stage_result in pipeline.run_with_progress(single_image_path=str(input_path), progress_callback=progress_callback):
49
  # Yield intermediate outputs as they become available
50
  outputs = _collect_outputs(work, stage_result.get('plants', {}))
51
  yield outputs
@@ -114,11 +115,20 @@ def _collect_outputs(work: Path, plants: Dict[str, Any]) -> Dict[str, str]:
114
  plant_heights = traits.get('plant_heights', {})
115
  num_plants = traits.get('num_plants', 0)
116
 
117
- # Always show as single plant (largest component)
118
- if num_plants > 0 and isinstance(plant_heights, dict) and len(plant_heights) >= 1:
119
- height_cm = list(plant_heights.values())[0]
120
- stats_lines.append(f"Number of plants: 1")
121
- stats_lines.append(f"Plant height: {height_cm:.2f} cm")
 
 
 
 
 
 
 
 
 
122
  else:
123
  # Fallback to old single height field
124
  height_cm = traits.get('plant_height_cm')
 
10
 
11
 
12
  def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts: bool = True,
13
+ progress_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
14
+ single_plant_mode: bool = False) -> Generator[Dict[str, str], None, None]:
15
  """
16
  Run sorghum pipeline on a single image (no instance segmentation).
17
  Yields dict[label -> image_path] progressively for gallery display.
 
43
  output_folder=str(work),
44
  boundingbox_dir=str(work)
45
  )
46
+ pipeline = SorghumPipeline(config=cfg, single_plant_mode=single_plant_mode)
47
 
48
  # Run the pipeline with progress callback (generator)
49
+ for stage_result in pipeline.run_with_progress(single_image_path=str(input_path), progress_callback=progress_callback, single_plant_mode=single_plant_mode):
50
  # Yield intermediate outputs as they become available
51
  outputs = _collect_outputs(work, stage_result.get('plants', {}))
52
  yield outputs
 
115
  plant_heights = traits.get('plant_heights', {})
116
  num_plants = traits.get('num_plants', 0)
117
 
118
+ # Display plant info based on mode
119
+ if num_plants > 0 and isinstance(plant_heights, dict):
120
+ if num_plants == 1 or len(plant_heights) == 1:
121
+ # Single plant display
122
+ height_cm = list(plant_heights.values())[0]
123
+ stats_lines.append(f"Number of plants: 1")
124
+ stats_lines.append(f"Plant height: {height_cm:.2f} cm")
125
+ else:
126
+ # Multiple plants display
127
+ stats_lines.append(f"Number of plants: {num_plants}")
128
+ sorted_plants = sorted(plant_heights.items(), key=lambda x: int(x[0].split('_')[1]))
129
+ for plant_name, height_cm in sorted_plants:
130
+ plant_num = plant_name.split('_')[1]
131
+ stats_lines.append(f" Plant {plant_num}: {height_cm:.2f} cm")
132
  else:
133
  # Fallback to old single height field
134
  height_cm = traits.get('plant_height_cm')