ahadhassan commited on
Commit
cd20fc4
·
verified ·
1 Parent(s): d96f558

Update yolo_predictor.py

Browse files
Files changed (1) hide show
  1. yolo_predictor.py +136 -59
yolo_predictor.py CHANGED
@@ -44,6 +44,36 @@ def predict_yolo(yolo_model, image_path, conf=0.001):
44
  Returns:
45
  results: YOLO results object
46
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # Run YOLO prediction
48
  results = yolo_model([image_path], conf=conf)
49
 
@@ -51,7 +81,7 @@ def predict_yolo(yolo_model, image_path, conf=0.001):
51
 
52
  def create_4channel_tiff(rgb_array, ndvi_array, output_path):
53
  """
54
- Create a 4-channel TIFF file from RGB and NDVI arrays
55
 
56
  Args:
57
  rgb_array: RGB image as numpy array (H, W, 3)
@@ -60,24 +90,33 @@ def create_4channel_tiff(rgb_array, ndvi_array, output_path):
60
  """
61
  height, width = rgb_array.shape[:2]
62
 
63
- # Stack RGB and NDVI to create 4-channel image
64
- four_channel = np.zeros((4, height, width), dtype=np.float32)
65
-
66
- # Convert RGB to proper format and range
67
- if rgb_array.dtype == np.uint8:
68
- rgb_normalized = rgb_array.astype(np.float32) / 255.0
69
  else:
70
- rgb_normalized = rgb_array.astype(np.float32)
71
 
72
- # Assign channels in (C, H, W) format for rasterio
73
- four_channel[0] = rgb_normalized[:, :, 0] # Red
74
- four_channel[1] = rgb_normalized[:, :, 1] # Green
75
- four_channel[2] = rgb_normalized[:, :, 2] # Blue
76
- four_channel[3] = ndvi_array.astype(np.float32) # NDVI
77
 
78
- # Use tifffile for better compatibility with YOLO
79
- import tifffile
80
- tifffile.imwrite(output_path, four_channel, photometric='rgb')
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  def load_4channel_tiff(image_path):
83
  """
@@ -91,45 +130,52 @@ def load_4channel_tiff(image_path):
91
  ndvi_array: NDVI channel as numpy array (H, W)
92
  """
93
  try:
94
- with rasterio.open(image_path) as src:
95
- # Read all 4 channels
96
- channels = src.read() # Shape: (4, H, W)
97
-
98
- # Extract RGB and NDVI
99
- rgb_array = np.transpose(channels[:3], (1, 2, 0)) # (H, W, 3)
100
- ndvi_array = channels[3] # (H, W)
101
-
102
- # If NDVI was scaled to uint8, convert back to [-1, 1] range
103
- if channels.dtype == np.uint8:
104
- ndvi_array = (ndvi_array.astype(np.float32) / 127.5) - 1
105
-
106
- return rgb_array, ndvi_array
107
- except Exception as e:
108
- # Try with tifffile as fallback
109
- import tifffile
110
  img_array = tifffile.imread(image_path)
111
 
112
- if len(img_array.shape) == 3 and img_array.shape[0] == 4:
113
- # Shape is (4, H, W)
114
- rgb_array = np.transpose(img_array[:3], (1, 2, 0)) # (H, W, 3)
115
- ndvi_array = img_array[3] # (H, W)
116
- elif len(img_array.shape) == 3 and img_array.shape[2] == 4:
117
- # Shape is (H, W, 4)
118
- rgb_array = img_array[:, :, :3] # (H, W, 3)
119
- ndvi_array = img_array[:, :, 3] # (H, W)
120
- else:
121
- raise ValueError(f"Unexpected image shape: {img_array.shape}")
122
 
123
- # Normalize NDVI if needed
124
  if img_array.dtype == np.uint8:
125
  ndvi_array = (ndvi_array.astype(np.float32) / 127.5) - 1
126
 
127
  return rgb_array, ndvi_array
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
130
  """
131
- Full pipeline: Load 4-channel image -> Extract RGB -> Predict NDVI ->
132
- Create new 4-channel with predicted NDVI -> Run YOLO prediction
133
 
134
  Args:
135
  ndvi_model: Loaded NDVI prediction model
@@ -142,40 +188,48 @@ def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
142
  """
143
  rgb_array = None
144
 
145
- # Try multiple methods to load the image
146
  try:
147
  # Method 1: Try with tifffile first (best for complex TIFF files)
148
- import tifffile
149
  img_array = tifffile.imread(image_path)
150
 
151
  if len(img_array.shape) == 3:
152
  if img_array.shape[0] == 4:
153
  # Shape is (4, H, W) - extract RGB
154
  rgb_array = np.transpose(img_array[:3], (1, 2, 0))
 
 
 
155
  elif img_array.shape[2] == 4:
156
  # Shape is (H, W, 4) - extract RGB
157
  rgb_array = img_array[:, :, :3]
158
  elif img_array.shape[2] == 3:
159
  # Shape is (H, W, 3) - already RGB
160
  rgb_array = img_array
161
- elif img_array.shape[0] == 3:
162
- # Shape is (3, H, W) - transpose to RGB
163
- rgb_array = np.transpose(img_array, (1, 2, 0))
 
164
  except Exception as e1:
165
  try:
166
  # Method 2: Try with rasterio
167
  with rasterio.open(image_path) as src:
 
168
  if src.count >= 3:
169
- channels = src.read()
170
- if src.count == 4:
171
- rgb_array = np.transpose(channels[:3], (1, 2, 0))
172
- else:
173
- rgb_array = np.transpose(channels, (1, 2, 0))
 
174
  except Exception as e2:
175
  try:
176
  # Method 3: Fall back to PIL for standard formats
177
- img = Image.open(image_path).convert("RGB")
 
 
178
  rgb_array = np.array(img)
 
179
  except Exception as e3:
180
  raise ValueError(f"Could not load image with any method. Errors: tifffile={e1}, rasterio={e2}, PIL={e3}")
181
 
@@ -183,11 +237,19 @@ def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
183
  raise ValueError("Failed to extract RGB data from image")
184
 
185
  # Ensure RGB is in correct format and range
186
- if rgb_array.max() > 1:
187
- rgb_array = rgb_array.astype(np.float32) / 255.0
 
 
 
 
 
 
 
 
188
 
189
  # Predict NDVI from RGB
190
- ndvi_pred = predict_ndvi_from_rgb(ndvi_model, rgb_array)
191
 
192
  # Create temporary 4-channel TIFF file
193
  with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
@@ -197,6 +259,21 @@ def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
197
  # Create 4-channel TIFF with predicted NDVI
198
  create_4channel_tiff(rgb_array, ndvi_pred, temp_4ch_path)
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  # Run YOLO prediction on 4-channel image
201
  results = predict_yolo(yolo_model, temp_4ch_path, conf=conf)
202
 
 
44
  Returns:
45
  results: YOLO results object
46
  """
47
+ # Verify the image has 4 channels before prediction
48
+ try:
49
+ # Check image format and channels
50
+ with Image.open(image_path) as img:
51
+ if hasattr(img, 'n_frames'):
52
+ # Multi-frame TIFF
53
+ channels = img.n_frames
54
+ else:
55
+ # Regular image
56
+ channels = len(img.getbands()) if hasattr(img, 'getbands') else 3
57
+
58
+ # If not 4 channels, try with tifffile
59
+ if channels != 4:
60
+ img_array = tifffile.imread(image_path)
61
+ if len(img_array.shape) == 3:
62
+ if img_array.shape[0] == 4:
63
+ channels = 4
64
+ elif img_array.shape[2] == 4:
65
+ channels = 4
66
+ else:
67
+ channels = img_array.shape[0] if img_array.shape[0] <= 4 else img_array.shape[2]
68
+ else:
69
+ channels = 1
70
+
71
+ if channels != 4:
72
+ raise ValueError(f"YOLO model expects 4-channel images, but got {channels} channels")
73
+
74
+ except Exception as e:
75
+ raise ValueError(f"Error reading image channels: {str(e)}")
76
+
77
  # Run YOLO prediction
78
  results = yolo_model([image_path], conf=conf)
79
 
 
81
 
82
  def create_4channel_tiff(rgb_array, ndvi_array, output_path):
83
  """
84
+ Create a 4-channel TIFF file from RGB and NDVI arrays compatible with PIL and YOLO
85
 
86
  Args:
87
  rgb_array: RGB image as numpy array (H, W, 3)
 
90
  """
91
  height, width = rgb_array.shape[:2]
92
 
93
+ # Ensure RGB is in uint8 format for better compatibility
94
+ if rgb_array.dtype != np.uint8:
95
+ if rgb_array.max() <= 1.0:
96
+ rgb_normalized = (rgb_array * 255).astype(np.uint8)
97
+ else:
98
+ rgb_normalized = np.clip(rgb_array, 0, 255).astype(np.uint8)
99
  else:
100
+ rgb_normalized = rgb_array
101
 
102
+ # Convert NDVI from [-1, 1] to [0, 255] for uint8 storage
103
+ ndvi_normalized = ((ndvi_array + 1) * 127.5).astype(np.uint8)
 
 
 
104
 
105
+ # Create 4-channel array in (H, W, 4) format
106
+ four_channel = np.zeros((height, width, 4), dtype=np.uint8)
107
+ four_channel[:, :, 0] = rgb_normalized[:, :, 0] # Red
108
+ four_channel[:, :, 1] = rgb_normalized[:, :, 1] # Green
109
+ four_channel[:, :, 2] = rgb_normalized[:, :, 2] # Blue
110
+ four_channel[:, :, 3] = ndvi_normalized # NDVI
111
+
112
+ # Save using tifffile with proper format for YOLO compatibility
113
+ tifffile.imwrite(
114
+ output_path,
115
+ four_channel,
116
+ photometric='rgb',
117
+ compress='lzw',
118
+ metadata={'axes': 'YXC'}
119
+ )
120
 
121
  def load_4channel_tiff(image_path):
122
  """
 
130
  ndvi_array: NDVI channel as numpy array (H, W)
131
  """
132
  try:
133
+ # Try with tifffile first for better TIFF support
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  img_array = tifffile.imread(image_path)
135
 
136
+ if len(img_array.shape) == 3:
137
+ if img_array.shape[0] == 4:
138
+ # Shape is (4, H, W) - transpose to (H, W, 4)
139
+ img_array = np.transpose(img_array, (1, 2, 0))
140
+ elif img_array.shape[2] != 4:
141
+ raise ValueError(f"Expected 4 channels, got {img_array.shape}")
142
+
143
+ # Extract RGB and NDVI from (H, W, 4) format
144
+ rgb_array = img_array[:, :, :3]
145
+ ndvi_array = img_array[:, :, 3]
146
 
147
+ # Convert NDVI back from [0, 255] to [-1, 1] if it was stored as uint8
148
  if img_array.dtype == np.uint8:
149
  ndvi_array = (ndvi_array.astype(np.float32) / 127.5) - 1
150
 
151
  return rgb_array, ndvi_array
152
+
153
+ except Exception as e:
154
+ # Fallback to rasterio
155
+ try:
156
+ with rasterio.open(image_path) as src:
157
+ if src.count != 4:
158
+ raise ValueError(f"Expected 4 channels, got {src.count}")
159
+
160
+ channels = src.read() # Shape: (4, H, W)
161
+
162
+ # Extract RGB and NDVI
163
+ rgb_array = np.transpose(channels[:3], (1, 2, 0)) # (H, W, 3)
164
+ ndvi_array = channels[3] # (H, W)
165
+
166
+ # Convert NDVI if needed
167
+ if channels.dtype == np.uint8:
168
+ ndvi_array = (ndvi_array.astype(np.float32) / 127.5) - 1
169
+
170
+ return rgb_array, ndvi_array
171
+
172
+ except Exception as e2:
173
+ raise ValueError(f"Could not load 4-channel TIFF. Errors: tifffile={e}, rasterio={e2}")
174
 
175
  def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
176
  """
177
+ Full pipeline: Load image -> Extract RGB -> Predict NDVI ->
178
+ Create 4-channel TIFF -> Run YOLO prediction
179
 
180
  Args:
181
  ndvi_model: Loaded NDVI prediction model
 
188
  """
189
  rgb_array = None
190
 
191
+ # Try multiple methods to load the image and extract RGB
192
  try:
193
  # Method 1: Try with tifffile first (best for complex TIFF files)
 
194
  img_array = tifffile.imread(image_path)
195
 
196
  if len(img_array.shape) == 3:
197
  if img_array.shape[0] == 4:
198
  # Shape is (4, H, W) - extract RGB
199
  rgb_array = np.transpose(img_array[:3], (1, 2, 0))
200
+ elif img_array.shape[0] == 3:
201
+ # Shape is (3, H, W) - transpose to RGB
202
+ rgb_array = np.transpose(img_array, (1, 2, 0))
203
  elif img_array.shape[2] == 4:
204
  # Shape is (H, W, 4) - extract RGB
205
  rgb_array = img_array[:, :, :3]
206
  elif img_array.shape[2] == 3:
207
  # Shape is (H, W, 3) - already RGB
208
  rgb_array = img_array
209
+ elif len(img_array.shape) == 2:
210
+ # Grayscale - convert to RGB
211
+ rgb_array = np.stack([img_array] * 3, axis=-1)
212
+
213
  except Exception as e1:
214
  try:
215
  # Method 2: Try with rasterio
216
  with rasterio.open(image_path) as src:
217
+ channels = src.read()
218
  if src.count >= 3:
219
+ rgb_array = np.transpose(channels[:3], (1, 2, 0))
220
+ elif src.count == 1:
221
+ # Single channel - convert to RGB
222
+ single_channel = channels[0]
223
+ rgb_array = np.stack([single_channel] * 3, axis=-1)
224
+
225
  except Exception as e2:
226
  try:
227
  # Method 3: Fall back to PIL for standard formats
228
+ img = Image.open(image_path)
229
+ if img.mode != 'RGB':
230
+ img = img.convert('RGB')
231
  rgb_array = np.array(img)
232
+
233
  except Exception as e3:
234
  raise ValueError(f"Could not load image with any method. Errors: tifffile={e1}, rasterio={e2}, PIL={e3}")
235
 
 
237
  raise ValueError("Failed to extract RGB data from image")
238
 
239
  # Ensure RGB is in correct format and range
240
+ if rgb_array.dtype == np.uint8:
241
+ # Keep as uint8 but also create float version for NDVI prediction
242
+ rgb_float = rgb_array.astype(np.float32) / 255.0
243
+ else:
244
+ # Already float, ensure range is [0, 1]
245
+ if rgb_array.max() > 1.0:
246
+ rgb_float = rgb_array / 255.0
247
+ else:
248
+ rgb_float = rgb_array
249
+ rgb_array = (rgb_float * 255).astype(np.uint8)
250
 
251
  # Predict NDVI from RGB
252
+ ndvi_pred = predict_ndvi_from_rgb(ndvi_model, rgb_float)
253
 
254
  # Create temporary 4-channel TIFF file
255
  with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
 
259
  # Create 4-channel TIFF with predicted NDVI
260
  create_4channel_tiff(rgb_array, ndvi_pred, temp_4ch_path)
261
 
262
+ # Verify the created file can be read
263
+ try:
264
+ test_img = Image.open(temp_4ch_path)
265
+ if hasattr(test_img, 'n_frames'):
266
+ channels = test_img.n_frames
267
+ else:
268
+ channels = len(test_img.getbands())
269
+ test_img.close()
270
+
271
+ if channels != 4:
272
+ raise ValueError(f"Created TIFF has {channels} channels instead of 4")
273
+
274
+ except Exception as e:
275
+ raise ValueError(f"Created TIFF file is not readable: {str(e)}")
276
+
277
  # Run YOLO prediction on 4-channel image
278
  results = predict_yolo(yolo_model, temp_4ch_path, conf=conf)
279