Fahimeh Orvati Nia commited on
Commit
60e6efb
·
1 Parent(s): 1bcb567

update the morphology, remove yolo, and correct the display

Browse files
app.py CHANGED
@@ -40,14 +40,41 @@ def process(file_path, preset_choice):
40
  # Fallback: save via PIL if direct copy fails
41
  Image.open(src).save(img_path)
42
 
43
- # Show input image immediately
44
  try:
45
- preview_im = Image.open(img_path)
46
- input_preview = preview_im.copy()
47
- preview_im.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  except Exception:
49
- input_preview = None
 
 
 
 
 
50
 
 
51
  yield (
52
  input_preview, # input image shown immediately
53
  None, # composite
@@ -61,9 +88,7 @@ def process(file_path, preset_choice):
61
  "", # stats
62
  )
63
 
64
- # Run the pipeline
65
- outputs = run_pipeline_on_image(str(img_path), tmpdir, save_artifacts=True)
66
-
67
  def load_pil(path_str):
68
  try:
69
  if not path_str:
@@ -75,39 +100,42 @@ def process(file_path, preset_choice):
75
  except Exception:
76
  return None
77
 
78
- composite = load_pil(outputs.get('Composite'))
79
- overlay = load_pil(outputs.get('Overlay'))
80
- mask = load_pil(outputs.get('Mask'))
81
- input_img = load_pil(outputs.get('InputImage')) or input_preview
82
- size_img = load_pil(str(Path(tmpdir) / 'results/size.size_analysis.png'))
83
-
84
- # Texture images (green band)
85
- lbp_path = Path(tmpdir) / 'texture_output/lbp_green.png'
86
- hog_path = Path(tmpdir) / 'texture_output/hog_green.png'
87
- lac1_path = Path(tmpdir) / 'texture_output/lac1_green.png'
88
- texture_img = load_pil(str(lbp_path)) if lbp_path.exists() else None
89
- hog_img = load_pil(str(hog_path)) if hog_path.exists() else None
90
- lac1_img = load_pil(str(lac1_path)) if lac1_path.exists() else None
91
-
92
- # Vegetation indices
93
- order = ['NDVI', 'GNDVI', 'SAVI']
94
- gallery_items = [load_pil(outputs[k]) for k in order if k in outputs]
95
-
96
- stats_text = outputs.get('StatsText', '')
97
-
98
- # Final yield with all results (10 outputs, no YOLO)
99
- yield (
100
- input_img,
101
- composite,
102
- mask,
103
- overlay,
104
- texture_img,
105
- hog_img,
106
- lac1_img,
107
- gallery_items,
108
- size_img,
109
- stats_text,
110
- )
 
 
 
111
 
112
 
113
  with gr.Blocks() as demo:
 
40
  # Fallback: save via PIL if direct copy fails
41
  Image.open(src).save(img_path)
42
 
43
+ # Show input image immediately (read exactly like pipeline for correctness)
44
  try:
45
+ import imghdr
46
+ import tifffile # type: ignore
47
+ import cv2 # type: ignore
48
+
49
+ kind = imghdr.what(str(img_path))
50
+ suffix = img_path.suffix.lower()
51
+
52
+ arr = None
53
+ if kind == "tiff" or suffix in [".tif", ".tiff"]:
54
+ try:
55
+ arr = tifffile.imread(str(img_path))
56
+ except Exception:
57
+ arr = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED)
58
+ else:
59
+ arr = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED)
60
+
61
+ if arr is None:
62
+ raise ValueError("Could not read image for preview")
63
+ if arr.ndim > 3:
64
+ arr = arr[..., 0]
65
+ if arr.ndim == 3 and arr.shape[-1] == 1:
66
+ arr = arr[..., 0]
67
+
68
+ input_preview = Image.fromarray(arr)
69
  except Exception:
70
+ try:
71
+ preview_im = Image.open(img_path)
72
+ input_preview = preview_im.copy()
73
+ preview_im.close()
74
+ except Exception:
75
+ input_preview = None
76
 
77
+ # Initial yield showing input preview
78
  yield (
79
  input_preview, # input image shown immediately
80
  None, # composite
 
88
  "", # stats
89
  )
90
 
91
+ # Helper to load PIL images
 
 
92
  def load_pil(path_str):
93
  try:
94
  if not path_str:
 
100
  except Exception:
101
  return None
102
 
103
+ # Run the pipeline progressively (generator)
104
+ for outputs in run_pipeline_on_image(str(img_path), tmpdir, save_artifacts=True):
105
+ # Load all available outputs progressively
106
+ composite = load_pil(outputs.get('Composite'))
107
+ overlay = load_pil(outputs.get('Overlay'))
108
+ mask = load_pil(outputs.get('Mask'))
109
+ input_img = load_pil(outputs.get('InputImage')) or input_preview
110
+ size_img = load_pil(str(Path(tmpdir) / 'results/size.size_analysis.png'))
111
+
112
+ # Texture images (green band)
113
+ lbp_path = Path(tmpdir) / 'texture_output/lbp_green.png'
114
+ hog_path = Path(tmpdir) / 'texture_output/hog_green.png'
115
+ lac1_path = Path(tmpdir) / 'texture_output/lac1_green.png'
116
+ texture_img = load_pil(str(lbp_path)) if lbp_path.exists() else None
117
+ hog_img = load_pil(str(hog_path)) if hog_path.exists() else None
118
+ lac1_img = load_pil(str(lac1_path)) if lac1_path.exists() else None
119
+
120
+ # Vegetation indices
121
+ order = ['NDVI', 'GNDVI', 'SAVI']
122
+ gallery_items = [load_pil(outputs[k]) for k in order if k in outputs]
123
+
124
+ stats_text = outputs.get('StatsText', '')
125
+
126
+ # Yield intermediate/final results as they become available
127
+ yield (
128
+ input_img,
129
+ composite,
130
+ mask,
131
+ overlay,
132
+ texture_img,
133
+ hog_img,
134
+ lac1_img,
135
+ gallery_items,
136
+ size_img,
137
+ stats_text,
138
+ )
139
 
140
 
141
  with gr.Blocks() as demo:
sorghum_pipeline/features/morphology.py CHANGED
@@ -19,16 +19,14 @@ logger = logging.getLogger(__name__)
19
 
20
 
21
  class MorphologyExtractor:
22
- """Morphology extraction: size analysis image + simple traits + YOLO tips overlay."""
23
 
24
  def __init__(self, pixel_to_cm: float = 0.1099609375, prune_sizes: List[int] = None,
25
- yolo_weights_path: str = "/home/grads/f/fahimehorvatinia/plant-analysis-demo/SSL_greenhouse_tip_detection.pt",
26
  min_component_area_for_size: int = 3000):
27
  """Initialize."""
28
  self.pixel_to_cm = pixel_to_cm
29
  self.prune_sizes = prune_sizes or [200, 100, 50, 30, 10]
30
- self.yolo_weights_path = yolo_weights_path
31
- # Used only for the Morphology Size visualization (not for height or YOLO)
32
  self.min_component_area_for_size = int(min_component_area_for_size)
33
 
34
  if PLANT_CV_AVAILABLE:
@@ -39,7 +37,7 @@ class MorphologyExtractor:
39
  pcv.params.dpi = 100
40
 
41
  def extract_morphology_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
42
- """Fast size visualization (YOLO disabled for speed). Simplified for performance."""
43
  features: Dict[str, Any] = {'traits': {}, 'images': {}, 'success': False}
44
 
45
  try:
@@ -52,21 +50,43 @@ class MorphologyExtractor:
52
  if rgb is None:
53
  return features
54
 
55
- # Simple size visualization without PlantCV for speed
56
- vis = self._simple_size_visual(rgb, ((clean_mask > 0).astype(np.uint8) * 255))
57
- features['images']['size_analysis'] = vis
58
- features['success'] = True
59
-
60
- # Compute simple plant height from mask (px → cm)
61
- rows = np.where(clean_mask > 0)[0]
62
- if rows.size:
63
- height_px = int(rows.max() - rows.min() + 1)
64
- features['traits']['plant_height_cm'] = float(height_px * self.pixel_to_cm)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  else:
66
  features['traits']['plant_height_cm'] = 0.0
67
 
68
- # YOLO disabled for speed
69
- features['traits']['num_yolo_tips'] = 0
 
 
70
 
71
  except Exception as e:
72
  logger.error(f"Morphology extraction failed: {e}")
@@ -140,103 +160,47 @@ class MorphologyExtractor:
140
  return arr
141
 
142
  def _simple_size_visual(self, rgb: np.ndarray, mask: np.ndarray) -> np.ndarray:
143
- """Draw contours, bbox, and area on RGB image."""
144
  vis = rgb.copy()
145
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
146
- cv2.drawContours(vis, contours, -1, (255, 0, 0), 2)
147
- if contours:
148
- largest = max(contours, key=cv2.contourArea)
149
- x, y, w, h = cv2.boundingRect(largest)
150
- cv2.rectangle(vis, (x, y), (x + w, y + h), (0, 255, 0), 2)
151
- area_px = int(cv2.countNonZero(mask))
152
- cv2.putText(vis, f"Area: {area_px} px", (10, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2, cv2.LINE_AA)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  return vis
154
 
155
- def _create_white_background_overlay(self, rgb: np.ndarray, mask: np.ndarray) -> np.ndarray:
156
- """Return white background with plant pixels in original colors.
157
- The mask here is normalized to a single-channel binary mask so it matches the frontend mask behavior.
158
- """
159
- img_no_bg = rgb.copy()
160
- # Normalize mask to single-channel binary
161
- if mask is None:
162
- bin_mask = None
163
- else:
164
- m = mask
165
- if m.ndim == 3:
166
- m = cv2.cvtColor(m, cv2.COLOR_BGR2GRAY)
167
- m = m.astype(np.uint8)
168
- _, bin_mask = cv2.threshold(m, 0, 255, cv2.THRESH_BINARY)
169
- if bin_mask is None:
170
- return np.full_like(rgb, 255, dtype=np.uint8)
171
- img_no_bg[bin_mask == 0] = 0
172
- overlay = np.full_like(rgb, 255, dtype=np.uint8)
173
- overlay[bin_mask > 0] = img_no_bg[bin_mask > 0]
174
- return overlay
175
-
176
- def _detect_yolo_tips(self, rgb: np.ndarray, mask: np.ndarray):
177
- """Detect tips using a YOLO model if available. Returns (overlay_img, tips_list)."""
178
- try:
179
- from ultralytics import YOLO # type: ignore
180
- except Exception as e:
181
- logger.warning(f"Ultralytics not available: {e}")
182
- return None, []
183
-
184
- try:
185
- # Resolve weights path robustly
186
- weights_path = self.yolo_weights_path
187
- if not isinstance(weights_path, str) or not weights_path:
188
- weights_path = "SSL_greenhouse_tip_detection.pt"
189
- # Try absolute, then repo root, then cwd
190
- candidates = [
191
- weights_path,
192
- "/home/grads/f/fahimehorvatinia/plant-analysis-demo/SSL_greenhouse_tip_detection.pt",
193
- "./SSL_greenhouse_tip_detection.pt",
194
- ]
195
- chosen = None
196
- for p in candidates:
197
- try:
198
- import os
199
- if os.path.exists(p):
200
- chosen = p
201
- break
202
- except Exception:
203
- pass
204
- if chosen is None:
205
- logger.warning("YOLO weights not found; skipping YOLO tips")
206
- return None, []
207
- model = YOLO(chosen)
208
- except Exception as e:
209
- logger.warning(f"Failed to load YOLO model: {e}")
210
- return None, []
211
-
212
- try:
213
- overlay_img = self._create_white_background_overlay(rgb, mask)
214
- # Run inference; allow low conf to let thresholding below handle
215
- results = model(overlay_img, conf=0.01, imgsz=640)
216
- tips = []
217
- for r in results:
218
- if getattr(r, 'keypoints', None) is not None and getattr(r.keypoints, 'xy', None) is not None:
219
- kps_xy = r.keypoints.xy.cpu().numpy()
220
- kps_conf = None
221
- if getattr(r.keypoints, 'conf', None) is not None:
222
- kps_conf = r.keypoints.conf.cpu().numpy()
223
- for i, det_xy in enumerate(kps_xy):
224
- for j, pt in enumerate(det_xy):
225
- x, y = float(pt[0]), float(pt[1])
226
- if not np.isnan(x) and not np.isnan(y):
227
- conf = float(kps_conf[i][j]) if kps_conf is not None else 1.0
228
- # Slightly relax threshold to 0.4 to improve recall
229
- if conf >= 0.4:
230
- tips.append((int(x), int(y), conf))
231
-
232
- # Draw tips
233
- vis = overlay_img.copy()
234
- for (x, y, _c) in tips:
235
- cv2.circle(vis, (int(x), int(y)), 8, (255, 0, 0), -1)
236
- return vis, tips
237
- except Exception as e:
238
- logger.warning(f"YOLO detection failed: {e}")
239
- return None, []
240
 
241
  class _FilteredStream:
242
  """Filter PlantCV output."""
 
19
 
20
 
21
  class MorphologyExtractor:
22
+ """Morphology extraction: size analysis image + simple traits."""
23
 
24
  def __init__(self, pixel_to_cm: float = 0.1099609375, prune_sizes: List[int] = None,
 
25
  min_component_area_for_size: int = 3000):
26
  """Initialize."""
27
  self.pixel_to_cm = pixel_to_cm
28
  self.prune_sizes = prune_sizes or [200, 100, 50, 30, 10]
29
+ # Used only for the Morphology Size visualization (not for height)
 
30
  self.min_component_area_for_size = int(min_component_area_for_size)
31
 
32
  if PLANT_CV_AVAILABLE:
 
37
  pcv.params.dpi = 100
38
 
39
  def extract_morphology_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
40
+ """Fast size visualization with multi-plant support. Simplified for performance."""
41
  features: Dict[str, Any] = {'traits': {}, 'images': {}, 'success': False}
42
 
43
  try:
 
50
  if rgb is None:
51
  return features
52
 
53
+ # Detect multiple plants using connected components
54
+ binary_mask = ((clean_mask > 0).astype(np.uint8) * 255)
55
+ num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)
56
+
57
+ # Calculate height for each plant (skip background label 0)
58
+ plant_heights = {}
59
+ for plant_idx in range(1, num_labels):
60
+ area = stats[plant_idx, cv2.CC_STAT_AREA]
61
+ # Filter out very small components (noise)
62
+ if area < 100:
63
+ continue
64
+
65
+ # Get rows for this plant
66
+ plant_mask = (labels == plant_idx)
67
+ rows = np.where(plant_mask)[0]
68
+ if rows.size:
69
+ height_px = int(rows.max() - rows.min() + 1)
70
+ height_cm = float(height_px * self.pixel_to_cm)
71
+ plant_heights[f'plant_{plant_idx}'] = height_cm
72
+
73
+ # Store individual plant heights
74
+ features['traits']['plant_heights'] = plant_heights
75
+ features['traits']['num_plants'] = len(plant_heights)
76
+
77
+ # For backward compatibility, store total height if single plant
78
+ if len(plant_heights) == 1:
79
+ features['traits']['plant_height_cm'] = list(plant_heights.values())[0]
80
+ elif len(plant_heights) > 1:
81
+ # Store max height as overall height
82
+ features['traits']['plant_height_cm'] = max(plant_heights.values())
83
  else:
84
  features['traits']['plant_height_cm'] = 0.0
85
 
86
+ # Simple size visualization without PlantCV for speed
87
+ vis = self._simple_size_visual(rgb, binary_mask)
88
+ features['images']['size_analysis'] = vis
89
+ features['success'] = True
90
 
91
  except Exception as e:
92
  logger.error(f"Morphology extraction failed: {e}")
 
160
  return arr
161
 
162
  def _simple_size_visual(self, rgb: np.ndarray, mask: np.ndarray) -> np.ndarray:
163
+ """Draw contours and bbox for each plant on RGB image."""
164
  vis = rgb.copy()
165
+
166
+ # Find connected components to identify individual plants
167
+ num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
168
+
169
+ # Use different colors for different plants
170
+ colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255)]
171
+
172
+ plant_count = 0
173
+ for plant_idx in range(1, num_labels): # Skip background (0)
174
+ area = stats[plant_idx, cv2.CC_STAT_AREA]
175
+ # Filter out very small components (noise)
176
+ if area < 100:
177
+ continue
178
+
179
+ # Get individual plant mask
180
+ plant_mask = ((labels == plant_idx).astype(np.uint8) * 255)
181
+
182
+ # Find contours for this plant
183
+ contours, _ = cv2.findContours(plant_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
184
+
185
+ # Pick color for this plant
186
+ color = colors[plant_count % len(colors)]
187
+
188
+ # Draw contours
189
+ cv2.drawContours(vis, contours, -1, color, 2)
190
+
191
+ # Draw bounding box
192
+ if contours:
193
+ x, y, w, h = cv2.boundingRect(contours[0])
194
+ cv2.rectangle(vis, (x, y), (x + w, y + h), color, 2)
195
+
196
+ # Add plant number label
197
+ cv2.putText(vis, f"P{plant_idx}", (x, y - 5),
198
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2, cv2.LINE_AA)
199
+
200
+ plant_count += 1
201
+
202
  return vis
203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  class _FilteredStream:
206
  """Filter PlantCV output."""
sorghum_pipeline/pipeline.py CHANGED
@@ -4,7 +4,7 @@ Minimal single-image pipeline for Hugging Face demo.
4
 
5
  import logging
6
  from pathlib import Path
7
- from typing import Dict, Any
8
  import numpy as np
9
  import cv2
10
 
@@ -105,6 +105,90 @@ class SorghumPipeline:
105
 
106
  return {"plants": plants, "timing": elapsed}
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def _segment(self, plants: Dict[str, Any]) -> Dict[str, Any]:
109
  """Segment using BRIA."""
110
  for key, pdata in plants.items():
 
4
 
5
  import logging
6
  from pathlib import Path
7
+ from typing import Dict, Any, Callable, Optional, Generator
8
  import numpy as np
9
  import cv2
10
 
 
105
 
106
  return {"plants": plants, "timing": elapsed}
107
 
108
+ def run_with_progress(self, single_image_path: str, progress_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None) -> Generator[Dict[str, Any], None, None]:
109
+ """Run pipeline on single image, yielding intermediate results progressively."""
110
+ logger.info("Processing single image with progress...")
111
+
112
+ import time, imghdr, tifffile
113
+ from PIL import Image
114
+
115
+ start = time.perf_counter()
116
+
117
+ # --- Load image with TIFF preference ---
118
+ kind = imghdr.what(single_image_path)
119
+ suffix = Path(single_image_path).suffix.lower()
120
+
121
+ arr = None
122
+ if kind == "tiff" or suffix in [".tif", ".tiff"]:
123
+ try:
124
+ arr = tifffile.imread(single_image_path)
125
+ logger.info(f"Loaded TIFF: shape={arr.shape}, dtype={arr.dtype}")
126
+ except Exception as e:
127
+ logger.warning(f"tifffile failed ({e}), falling back to cv2")
128
+ arr = cv2.imread(single_image_path, cv2.IMREAD_UNCHANGED)
129
+ logger.info(f"Fallback read: shape={arr.shape}, dtype={arr.dtype}")
130
+ else:
131
+ arr = cv2.imread(single_image_path, cv2.IMREAD_UNCHANGED)
132
+ logger.info(f"Loaded non-TIFF: shape={arr.shape}, dtype={arr.dtype}")
133
+
134
+ # --- Normalize array shape ---
135
+ if arr is None:
136
+ raise ValueError(f"Could not read image: {single_image_path}")
137
+ if arr.ndim > 3:
138
+ arr = arr[..., 0] # drop extra dimension
139
+ if arr.ndim == 3 and arr.shape[-1] == 1:
140
+ arr = arr[..., 0] # squeeze singleton
141
+
142
+ logger.info(f"DEBUG normalized input: shape={arr.shape}, dtype={arr.dtype}")
143
+
144
+ # Wrap into PIL image for downstream pipeline
145
+ img = Image.fromarray(arr)
146
+
147
+ plants = {
148
+ "demo": {
149
+ "raw_image": (img, Path(single_image_path).name),
150
+ "plant_name": "demo",
151
+ "normalized_input": arr,
152
+ }
153
+ }
154
+
155
+ # Create output directories early
156
+ self.output_manager.create_output_directories()
157
+
158
+ # Stage 1: Create composite
159
+ logger.info("Stage 1: Creating composite...")
160
+ plants = self.preprocessor.create_composites(plants)
161
+ if progress_callback:
162
+ progress_callback("composite", plants)
163
+ # Save composite immediately for display
164
+ for key, pdata in plants.items():
165
+ self.output_manager.save_plant_results(key, pdata)
166
+ yield {"plants": plants, "stage": "composite"}
167
+
168
+ # Stage 2: Segmentation
169
+ logger.info("Stage 2: Segmentation...")
170
+ plants = self._segment(plants)
171
+ if progress_callback:
172
+ progress_callback("segmentation", plants)
173
+ # Save mask and overlay
174
+ for key, pdata in plants.items():
175
+ self.output_manager.save_plant_results(key, pdata)
176
+ yield {"plants": plants, "stage": "segmentation"}
177
+
178
+ # Stage 3: Extract features (texture, vegetation, morphology)
179
+ logger.info("Stage 3: Extracting features...")
180
+ plants = self._extract_features(plants)
181
+ if progress_callback:
182
+ progress_callback("features", plants)
183
+ # Save all final outputs
184
+ for key, pdata in plants.items():
185
+ self.output_manager.save_plant_results(key, pdata)
186
+
187
+ elapsed = time.perf_counter() - start
188
+ logger.info(f"Completed in {elapsed:.2f}s")
189
+
190
+ yield {"plants": plants, "timing": elapsed, "stage": "complete"}
191
+
192
  def _segment(self, plants: Dict[str, Any]) -> Dict[str, Any]:
193
  """Segment using BRIA."""
194
  for key, pdata in plants.items():
wrapper.py CHANGED
@@ -1,5 +1,5 @@
1
  from pathlib import Path
2
- from typing import Dict
3
  import shutil
4
  from PIL import Image
5
  import glob
@@ -9,10 +9,20 @@ from sorghum_pipeline.pipeline import SorghumPipeline
9
  from sorghum_pipeline.config import Config, Paths
10
 
11
 
12
- def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts: bool = True) -> Dict[str, str]:
 
13
  """
14
  Run sorghum pipeline on a single image (no instance segmentation).
15
- Returns dict[label -> image_path] for gallery display.
 
 
 
 
 
 
 
 
 
16
  """
17
 
18
  work = Path(work_dir)
@@ -34,11 +44,20 @@ def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts:
34
  )
35
  pipeline = SorghumPipeline(config=cfg)
36
 
37
- # Run the pipeline (single image minimal demo)
38
- results = pipeline.run(single_image_path=str(input_path))
 
 
 
 
 
 
39
 
40
- # Collect outputs
 
 
41
  outputs: Dict[str, str] = {}
 
42
  try:
43
  # Log immediate output directory contents for debugging
44
  for sub in ['results', 'Vegetation_indices_images', 'texture_output']:
@@ -66,7 +85,6 @@ def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts:
66
  overlay_path = work / 'results/overlay.png'
67
  mask_path = work / 'results/mask.png'
68
  composite_path = work / 'results/composite.png'
69
- yolo_tips_path = work / 'results/yolo_tips.png'
70
  input_img_path = work / 'results/input_image.png'
71
  if overlay_path.exists():
72
  outputs['Overlay'] = str(overlay_path)
@@ -74,14 +92,11 @@ def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts:
74
  outputs['Mask'] = str(mask_path)
75
  if composite_path.exists():
76
  outputs['Composite'] = str(composite_path)
77
- if yolo_tips_path.exists():
78
- outputs['YOLOTips'] = str(yolo_tips_path)
79
  if input_img_path.exists():
80
  outputs['InputImage'] = str(input_img_path)
81
 
82
  # Extract simple stats for display if present in pipeline results
83
  try:
84
- plants = results.get('plants', {}) if isinstance(results, dict) else {}
85
  if plants:
86
  _, pdata = next(iter(plants.items()))
87
  veg = pdata.get('vegetation_indices', {})
@@ -91,15 +106,31 @@ def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts:
91
  st = entry.get('statistics', {}) if isinstance(entry, dict) else {}
92
  if st:
93
  stats_lines.append(f"{name}: mean={st.get('mean', 0):.3f}, std={st.get('std', 0):.3f}")
94
- # Morphology stats (height, yolo tips)
95
  morph = pdata.get('morphology_features', {}) if isinstance(pdata, dict) else {}
96
  traits = morph.get('traits', {}) if isinstance(morph, dict) else {}
97
- height_cm = traits.get('plant_height_cm')
98
- if isinstance(height_cm, (int, float)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  stats_lines.append(f"Plant height: {height_cm:.2f} cm")
100
- num_tips = traits.get('num_yolo_tips')
101
- if isinstance(num_tips, (int, float)):
102
- stats_lines.append(f"YOLO tips: {int(num_tips)}")
 
 
103
  if stats_lines:
104
  outputs['StatsText'] = "\n".join(stats_lines)
105
  except Exception:
 
1
  from pathlib import Path
2
+ from typing import Dict, Callable, Optional, Generator, Any
3
  import shutil
4
  from PIL import Image
5
  import glob
 
9
  from sorghum_pipeline.config import Config, Paths
10
 
11
 
12
+ def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts: bool = True,
13
+ progress_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None) -> Generator[Dict[str, str], None, None]:
14
  """
15
  Run sorghum pipeline on a single image (no instance segmentation).
16
+ Yields dict[label -> image_path] progressively for gallery display.
17
+
18
+ Args:
19
+ input_image_path: Path to input image
20
+ work_dir: Working directory for outputs
21
+ save_artifacts: Whether to save artifacts
22
+ progress_callback: Optional callback(stage_name, data) called after each pipeline stage
23
+
24
+ Yields:
25
+ Dictionary of output paths progressively as they become available
26
  """
27
 
28
  work = Path(work_dir)
 
44
  )
45
  pipeline = SorghumPipeline(config=cfg)
46
 
47
+ # Run the pipeline with progress callback (generator)
48
+ for stage_result in pipeline.run_with_progress(single_image_path=str(input_path), progress_callback=progress_callback):
49
+ # Yield intermediate outputs as they become available
50
+ outputs = _collect_outputs(work, stage_result.get('plants', {}))
51
+ yield outputs
52
+
53
+ # Final results
54
+ results = stage_result
55
 
56
+
57
+ def _collect_outputs(work: Path, plants: Dict[str, Any]) -> Dict[str, str]:
58
+ """Collect all available outputs from work directory and plants data."""
59
  outputs: Dict[str, str] = {}
60
+
61
  try:
62
  # Log immediate output directory contents for debugging
63
  for sub in ['results', 'Vegetation_indices_images', 'texture_output']:
 
85
  overlay_path = work / 'results/overlay.png'
86
  mask_path = work / 'results/mask.png'
87
  composite_path = work / 'results/composite.png'
 
88
  input_img_path = work / 'results/input_image.png'
89
  if overlay_path.exists():
90
  outputs['Overlay'] = str(overlay_path)
 
92
  outputs['Mask'] = str(mask_path)
93
  if composite_path.exists():
94
  outputs['Composite'] = str(composite_path)
 
 
95
  if input_img_path.exists():
96
  outputs['InputImage'] = str(input_img_path)
97
 
98
  # Extract simple stats for display if present in pipeline results
99
  try:
 
100
  if plants:
101
  _, pdata = next(iter(plants.items()))
102
  veg = pdata.get('vegetation_indices', {})
 
106
  st = entry.get('statistics', {}) if isinstance(entry, dict) else {}
107
  if st:
108
  stats_lines.append(f"{name}: mean={st.get('mean', 0):.3f}, std={st.get('std', 0):.3f}")
109
+ # Morphology stats (height for multiple plants)
110
  morph = pdata.get('morphology_features', {}) if isinstance(pdata, dict) else {}
111
  traits = morph.get('traits', {}) if isinstance(morph, dict) else {}
112
+
113
+ # Check if we have multiple plants
114
+ plant_heights = traits.get('plant_heights', {})
115
+ num_plants = traits.get('num_plants', 0)
116
+
117
+ if isinstance(plant_heights, dict) and len(plant_heights) > 1:
118
+ # Multiple plants detected
119
+ stats_lines.append(f"Number of plants: {num_plants}")
120
+ # Sort by plant index for consistent display
121
+ sorted_plants = sorted(plant_heights.items(), key=lambda x: int(x[0].split('_')[1]))
122
+ for plant_name, height_cm in sorted_plants:
123
+ plant_num = plant_name.split('_')[1]
124
+ stats_lines.append(f" Plant {plant_num}: {height_cm:.2f} cm")
125
+ elif isinstance(plant_heights, dict) and len(plant_heights) == 1:
126
+ # Single plant
127
+ height_cm = list(plant_heights.values())[0]
128
  stats_lines.append(f"Plant height: {height_cm:.2f} cm")
129
+ else:
130
+ # Fallback to old single height field
131
+ height_cm = traits.get('plant_height_cm')
132
+ if isinstance(height_cm, (int, float)) and height_cm > 0:
133
+ stats_lines.append(f"Plant height: {height_cm:.2f} cm")
134
  if stats_lines:
135
  outputs['StatsText'] = "\n".join(stats_lines)
136
  except Exception: