ahadhassan committed on
Commit
503cb09
·
verified ·
1 Parent(s): cf53d6c

Update yolo_predictor.py

Browse files
Files changed (1) hide show
  1. yolo_predictor.py +57 -254
yolo_predictor.py CHANGED
@@ -1,297 +1,100 @@
1
  # yolo_predictor.py
2
  import os
3
- import numpy as np
4
  import rasterio
5
  from modified_ultralytics import YOLO
6
- from ndvi_predictor import normalize_rgb, predict_ndvi
7
- import tempfile
8
- from PIL import Image
9
  import tifffile
10
 
11
  def load_yolo_model(model_path):
12
  """Load YOLO model from .pt file"""
13
  return YOLO(model_path)
14
 
15
def predict_ndvi_from_rgb(ndvi_model, rgb_array):
    """
    Produce an NDVI prediction for an RGB array.

    The array is passed through ``normalize_rgb`` and the result is handed
    to ``predict_ndvi`` together with the model.

    Args:
        ndvi_model: Loaded NDVI prediction model
        rgb_array: RGB image as numpy array (H, W, 3)

    Returns:
        ndvi_array: Whatever ``predict_ndvi`` returns for the normalized
            input (documented by the original author as an (H, W) array)
    """
    return predict_ndvi(ndvi_model, normalize_rgb(rgb_array))
33
-
34
def predict_yolo(yolo_model, image_path, conf=0.001):
    """
    Run YOLO on a 4-channel image, converting it to uint8 first if needed.

    The image is read with tifffile. A channel-first (4, H, W) array is
    transposed to (H, W, 4). Data that is not already uint8 is rescaled
    (RGB to 0-255, fourth channel as (value + 1) * 127.5) and written to a
    temporary TIFF, which is then passed to the model instead of the
    original file. The temporary file is always removed afterwards.

    Args:
        yolo_model: Loaded YOLO model, called as ``yolo_model([path], conf=...)``
        image_path: Path to the 4-channel image
        conf: Confidence threshold forwarded to the model

    Returns:
        The first element of the model's result list

    Raises:
        ValueError: If the image cannot be read, is not a 3-D array with
            4 channels, or prediction fails (the original error is chained)
    """
    import tifffile

    temp_path = None
    try:
        img_array = tifffile.imread(image_path)
        print(f"[DEBUG] Loaded image shape: {img_array.shape}, dtype: {img_array.dtype}")

        if len(img_array.shape) != 3:
            raise ValueError(f"[ERROR] Unexpected image shape: {img_array.shape}")

        if img_array.shape[0] == 4:
            # Channel-first input: after this transpose the last axis is
            # always 4, so no second channel check is needed.
            img_array = np.transpose(img_array, (1, 2, 0))
            print(f"[DEBUG] Transposed image shape to (H,W,C): {img_array.shape}")
        elif img_array.shape[2] != 4:
            raise ValueError(f"[ERROR] Expected 4 channels, got {img_array.shape[2]}")

        print(f"[DEBUG] Image dtype before normalization: {img_array.dtype}")

        if img_array.dtype != np.uint8:
            print("[DEBUG] Converting image to uint8")
            rgb_array = img_array[:, :, :3]
            ndvi_array = img_array[:, :, 3]

            # Normalize RGB
            if rgb_array.max() > 1.0:
                rgb_array = np.clip(rgb_array / rgb_array.max() * 255, 0, 255).astype(np.uint8)
            else:
                rgb_array = np.clip(rgb_array * 255, 0, 255).astype(np.uint8)

            # Normalize NDVI. Work in float64 and clip so that values outside
            # [-1, 1] saturate instead of wrapping around in the uint8 cast.
            ndvi_scaled = (ndvi_array.astype(np.float64) + 1) * 127.5
            ndvi_normalized = np.clip(ndvi_scaled, 0, 255).astype(np.uint8)

            img_array = np.zeros((img_array.shape[0], img_array.shape[1], 4), dtype=np.uint8)
            img_array[:, :, :3] = rgb_array
            img_array[:, :, 3] = ndvi_normalized

            print(f"[DEBUG] Image converted to uint8 with shape: {img_array.shape}")

            # Reserve a temp file name, then write once the handle is closed
            # so the write does not depend on being able to reopen an open file.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
                temp_path = tmp_file.name
            tifffile.imwrite(temp_path, img_array)
            image_path = temp_path

        print(f"[DEBUG] Final image ready for YOLO, path: {image_path}")

        results = yolo_model([image_path], conf=conf)
        return results[0]

    except Exception as e:
        raise ValueError(f"Error processing image: {str(e)}") from e

    finally:
        # Runs on success and on failure, so the temp file is never leaked.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)
96
-
97
def create_4channel_tiff(rgb_array, ndvi_array, output_path):
    """
    Write RGB and NDVI arrays as one 4-channel uint8 TIFF.

    Args:
        rgb_array: RGB image as numpy array (H, W, 3). uint8 data is used
            as-is; other dtypes are scaled by 255 when their maximum is
            <= 1.0, otherwise clipped to 0-255.
        ndvi_array: NDVI image as numpy array (H, W), mapped to uint8 as
            (value + 1) * 127.5
        output_path: Path to save the 4-channel TIFF
    """
    height, width = rgb_array.shape[:2]

    # Ensure RGB is in uint8 format. Both non-uint8 paths clip before the
    # cast so out-of-range values saturate instead of wrapping modulo 256.
    if rgb_array.dtype == np.uint8:
        rgb_normalized = rgb_array
    elif rgb_array.max() <= 1.0:
        rgb_normalized = np.clip(rgb_array * 255, 0, 255).astype(np.uint8)
    else:
        rgb_normalized = np.clip(rgb_array, 0, 255).astype(np.uint8)

    # Convert NDVI from [-1, 1] to [0, 255] for uint8 storage; values outside
    # that range are clipped rather than wrapped.
    ndvi_normalized = np.clip((ndvi_array + 1) * 127.5, 0, 255).astype(np.uint8)

    # Create 4-channel array in (H, W, 4) format: R, G, B, NDVI
    four_channel = np.zeros((height, width, 4), dtype=np.uint8)
    four_channel[:, :, :3] = rgb_normalized[:, :, :3]
    four_channel[:, :, 3] = ndvi_normalized

    # NOTE(review): `compress=` and passing resolution inside `metadata` are
    # kept exactly as written; confirm both against the installed tifffile
    # version (newer releases use `compression=` / `resolution=`).
    tifffile.imwrite(
        output_path,
        four_channel,
        photometric='rgb',
        compress='lzw',
        metadata={'axes': 'YXC', 'resolution': (1, 1)},
        bitspersample=8
    )
136
-
137
def load_4channel_tiff(image_path):
    """
    Read a 4-channel TIFF and split it into RGB and NDVI arrays.

    tifffile is tried first; if anything on that path raises, the file is
    read again with rasterio.

    Args:
        image_path: Path to 4-channel TIFF image

    Returns:
        rgb_array: First three channels as numpy array (H, W, 3)
        ndvi_array: Fourth channel as numpy array (H, W); uint8 data is
            converted to float32 via value / 127.5 - 1

    Raises:
        ValueError: If both the tifffile and the rasterio attempt fail
    """
    try:
        pixels = tifffile.imread(image_path)

        if pixels.ndim == 3:
            if pixels.shape[0] == 4:
                # Channel-first (4, H, W) -> channel-last (H, W, 4)
                pixels = np.transpose(pixels, (1, 2, 0))
            elif pixels.shape[2] != 4:
                raise ValueError(f"Expected 4 channels, got {pixels.shape}")

        rgb_array, ndvi_array = pixels[:, :, :3], pixels[:, :, 3]

        # uint8 storage holds NDVI scaled to [0, 255]; map it back to [-1, 1]
        if pixels.dtype == np.uint8:
            ndvi_array = (ndvi_array.astype(np.float32) / 127.5) - 1

        return rgb_array, ndvi_array

    except Exception as e:
        # Second attempt: rasterio
        try:
            with rasterio.open(image_path) as src:
                if src.count != 4:
                    raise ValueError(f"Expected 4 channels, got {src.count}")

                bands = src.read()  # band-first: (4, H, W)
                rgb_array = np.transpose(bands[:3], (1, 2, 0))
                ndvi_array = bands[3]

                if bands.dtype == np.uint8:
                    ndvi_array = (ndvi_array.astype(np.float32) / 127.5) - 1

                return rgb_array, ndvi_array

        except Exception as e2:
            raise ValueError(f"Could not load 4-channel TIFF. Errors: tifffile={e}, rasterio={e2}")
190
 
191
def _load_rgb_array(image_path):
    """Return the image's RGB data as an (H, W, 3) array, or None if the layout is not recognised.

    Readers are tried in order: tifffile, then rasterio, then PIL. A later
    reader is only used when the previous one raised.
    """
    try:
        # Method 1: tifffile
        img_array = tifffile.imread(image_path)

        if len(img_array.shape) == 2:
            # Grayscale - replicate into three channels
            return np.stack([img_array] * 3, axis=-1)
        if len(img_array.shape) == 3:
            if img_array.shape[0] == 4:
                # (4, H, W) - keep the first three channels
                return np.transpose(img_array[:3], (1, 2, 0))
            if img_array.shape[0] == 3:
                # (3, H, W)
                return np.transpose(img_array, (1, 2, 0))
            if img_array.shape[2] == 4:
                # (H, W, 4) - keep the first three channels
                return img_array[:, :, :3]
            if img_array.shape[2] == 3:
                # (H, W, 3) - already RGB
                return img_array
        return None

    except Exception as e1:
        try:
            # Method 2: rasterio
            with rasterio.open(image_path) as src:
                channels = src.read()
                if src.count >= 3:
                    return np.transpose(channels[:3], (1, 2, 0))
                if src.count == 1:
                    return np.stack([channels[0]] * 3, axis=-1)
            return None

        except Exception as e2:
            try:
                # Method 3: PIL for standard formats. Image.open is the
                # Pillow entry point; the context manager closes the file.
                with Image.open(image_path) as img:
                    rgb_img = img if img.mode == 'RGB' else img.convert('RGB')
                    return np.array(rgb_img)

            except Exception as e3:
                raise ValueError(
                    f"Could not load image with any method. Errors: tifffile={e1}, rasterio={e2}, PIL={e3}"
                ) from e3


def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
    """
    Full pipeline: Load image -> Extract RGB -> Predict NDVI ->
    Create 4-channel TIFF -> Run YOLO prediction

    Args:
        ndvi_model: Loaded NDVI prediction model
        yolo_model: Loaded YOLO model
        image_path: Path to input image (can be RGB or 4-channel TIFF)
        conf: Confidence threshold for YOLO

    Returns:
        results: YOLO results object

    Raises:
        ValueError: If the image cannot be loaded, no RGB data can be
            extracted, or the intermediate 4-channel TIFF cannot be read back
    """
    rgb_array = _load_rgb_array(image_path)
    if rgb_array is None:
        raise ValueError("Failed to extract RGB data from image")

    # Build a float version for the NDVI model and a uint8 version for the TIFF
    if rgb_array.dtype == np.uint8:
        rgb_float = rgb_array.astype(np.float32) / 255.0
    else:
        if rgb_array.max() > 1.0:
            rgb_float = rgb_array / rgb_array.max()
        else:
            rgb_float = rgb_array
        # Clip before the cast so negative values saturate at 0 instead of wrapping
        rgb_array = np.clip(rgb_float * 255, 0, 255).astype(np.uint8)

    # Predict NDVI from RGB
    ndvi_pred = predict_ndvi_from_rgb(ndvi_model, rgb_float)

    # Reserve a temporary path for the 4-channel TIFF
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
        temp_4ch_path = tmp_file.name

    try:
        create_4channel_tiff(rgb_array, ndvi_pred, temp_4ch_path)

        # Read the file back to confirm it really has 4 channels
        try:
            test_array = tifffile.imread(temp_4ch_path)
            shape = test_array.shape
            if len(shape) == 3 and (shape[0] == 4 or shape[2] == 4):
                channels = 4
            else:
                channels = shape[2] if len(shape) == 3 else 1
            if channels != 4:
                raise ValueError(f"Created TIFF has {channels} channels instead of 4")

        except Exception as e:
            raise ValueError(f"Created TIFF file is not readable: {str(e)}") from e

        # Run YOLO prediction on 4-channel image
        return predict_yolo(yolo_model, temp_4ch_path, conf=conf)

    finally:
        # Clean up temporary file
        if os.path.exists(temp_4ch_path):
            os.unlink(temp_4ch_path)
 
1
  # yolo_predictor.py
2
  import os
 
3
  import rasterio
4
  from modified_ultralytics import YOLO
 
 
 
5
  import tifffile
6
 
7
  def load_yolo_model(model_path):
8
  """Load YOLO model from .pt file"""
9
  return YOLO(model_path)
10
 
11
def validate_4channel_tiff(image_path):
    """
    Validate that the input TIFF file has 4 channels and is readable

    The file is read with tifffile first. If that attempt raises for any
    reason (including a failed channel check), the channel count is checked
    again through rasterio before giving up.

    Args:
        image_path: Path to input TIFF image

    Returns:
        bool: True if valid 4-channel TIFF

    Raises:
        ValueError: If the file does not exist, or if neither reader
            confirms a 4-channel image
    """
    if not os.path.exists(image_path):
        raise ValueError(f"Image file does not exist: {image_path}")

    try:
        # Primary validation with tifffile
        img_array = tifffile.imread(image_path)
        shape = img_array.shape

        if len(shape) != 3:
            raise ValueError(f"Invalid image shape: {shape}. Expected 3D array with 4 channels.")

        if shape[0] == 4:
            # Shape is (4, H, W)
            height, width = shape[1], shape[2]
        elif shape[2] == 4:
            # Shape is (H, W, 4)
            height, width = shape[0], shape[1]
        else:
            # Neither outer axis is 4; report the smaller one as the channel count
            channels = min(shape[0], shape[2])
            raise ValueError(f"YOLO model expects 4-channel images, but got {channels} channels")

        print(f"Validation successful: 4 channels, {height}x{width}, dtype: {img_array.dtype}")
        return True

    except Exception as e:
        # Fallback validation with rasterio
        try:
            with rasterio.open(image_path) as src:
                if src.count != 4:
                    raise ValueError(f"YOLO model expects 4-channel images, but got {src.count} channels")

                # Dimensions are printed height-first, matching the tifffile message
                print(f"Validation successful (rasterio): {src.count} channels, {src.height}x{src.width}, dtype: {src.dtypes[0]}")
                return True

        except Exception as e2:
            raise ValueError(f"Could not validate TIFF file. Errors: tifffile={str(e)}, rasterio={str(e2)}") from e2
65
 
66
def predict_yolo(yolo_model, image_path, conf=0.001):
    """
    Validate a 4-channel TIFF and run the YOLO model on it

    Args:
        yolo_model: Loaded YOLO model
        image_path: Path to 4-channel TIFF image
        conf: Confidence threshold

    Returns:
        results: First entry of the model's result list
    """
    # Raises ValueError if the file is not an acceptable 4-channel TIFF
    validate_4channel_tiff(image_path)

    # The model takes a list of paths; only one image is submitted here
    first_result = yolo_model([image_path], conf=conf)[0]
    return first_result
85
+
86
def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
    """
    Simplified pipeline that delegates straight to ``predict_yolo``

    Args:
        ndvi_model: Not used (kept for API compatibility)
        yolo_model: Loaded YOLO model
        image_path: Path to input 4-channel TIFF image
        conf: Confidence threshold for YOLO

    Returns:
        results: YOLO results object
    """
    # ndvi_model is intentionally ignored; validation happens inside predict_yolo
    results = predict_yolo(yolo_model, image_path, conf=conf)
    return results