ahadhassan commited on
Commit
3bd81ab
·
verified ·
1 Parent(s): 59519c9

Update yolo_predictor.py

Browse files
Files changed (1) hide show
  1. yolo_predictor.py +62 -44
yolo_predictor.py CHANGED
@@ -5,7 +5,6 @@ import rasterio
5
  from ultralytics import YOLO
6
  from ndvi_predictor import normalize_rgb, predict_ndvi
7
  import tempfile
8
- from rasterio.transform import from_bounds
9
  from PIL import Image
10
  import tifffile
11
 
@@ -46,38 +45,60 @@ def predict_yolo(yolo_model, image_path, conf=0.001):
46
  """
47
  # Verify the image has 4 channels before prediction
48
  try:
49
- # Check image format and channels
50
- with Image.open(image_path) as img:
51
- if hasattr(img, 'n_frames'):
52
- # Multi-frame TIFF
53
- channels = img.n_frames
54
- else:
55
- # Regular image
56
- channels = len(img.getbands()) if hasattr(img, 'getbands') else 3
57
 
58
- # If not 4 channels, try with tifffile
59
- if channels != 4:
60
- img_array = tifffile.imread(image_path)
61
- if len(img_array.shape) == 3:
62
- if img_array.shape[0] == 4:
63
- channels = 4
64
- elif img_array.shape[2] == 4:
65
- channels = 4
66
- else:
67
- channels = img_array.shape[0] if img_array.shape[0] <= 4 else img_array.shape[2]
68
- else:
69
- channels = 1
70
 
71
- if channels != 4:
72
- raise ValueError(f"YOLO model expects 4-channel images, but got {channels} channels")
 
 
 
 
 
 
73
 
74
- except Exception as e:
75
- raise ValueError(f"Error reading image channels: {str(e)}")
76
-
77
- # Run YOLO prediction
78
- results = yolo_model([image_path], conf=conf)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- return results[0] # Return first result
 
81
 
82
  def create_4channel_tiff(rgb_array, ndvi_array, output_path):
83
  """
@@ -90,7 +111,7 @@ def create_4channel_tiff(rgb_array, ndvi_array, output_path):
90
  """
91
  height, width = rgb_array.shape[:2]
92
 
93
- # Ensure RGB is in uint8 format for better compatibility
94
  if rgb_array.dtype != np.uint8:
95
  if rgb_array.max() <= 1.0:
96
  rgb_normalized = (rgb_array * 255).astype(np.uint8)
@@ -109,13 +130,14 @@ def create_4channel_tiff(rgb_array, ndvi_array, output_path):
109
  four_channel[:, :, 2] = rgb_normalized[:, :, 2] # Blue
110
  four_channel[:, :, 3] = ndvi_normalized # NDVI
111
 
112
- # Save using tifffile with proper format for YOLO compatibility
113
  tifffile.imwrite(
114
- output_path,
115
- four_channel,
116
  photometric='rgb',
117
  compress='lzw',
118
- metadata={'axes': 'YXC'}
 
119
  )
120
 
121
  def load_4channel_tiff(image_path):
@@ -225,7 +247,7 @@ def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
225
  except Exception as e2:
226
  try:
227
  # Method 3: Fall back to PIL for standard formats
228
- img = Image.open(image_path)
229
  if img.mode != 'RGB':
230
  img = img.convert('RGB')
231
  rgb_array = np.array(img)
@@ -238,12 +260,10 @@ def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
238
 
239
  # Ensure RGB is in correct format and range
240
  if rgb_array.dtype == np.uint8:
241
- # Keep as uint8 but also create float version for NDVI prediction
242
  rgb_float = rgb_array.astype(np.float32) / 255.0
243
  else:
244
- # Already float, ensure range is [0, 1]
245
  if rgb_array.max() > 1.0:
246
- rgb_float = rgb_array / 255.0
247
  else:
248
  rgb_float = rgb_array
249
  rgb_array = (rgb_float * 255).astype(np.uint8)
@@ -261,13 +281,11 @@ def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
261
 
262
  # Verify the created file can be read
263
  try:
264
- test_img = Image.open(temp_4ch_path)
265
- if hasattr(test_img, 'n_frames'):
266
- channels = test_img.n_frames
267
  else:
268
- channels = len(test_img.getbands())
269
- test_img.close()
270
-
271
  if channels != 4:
272
  raise ValueError(f"Created TIFF has {channels} channels instead of 4")
273
 
 
5
  from ultralytics import YOLO
6
  from ndvi_predictor import normalize_rgb, predict_ndvi
7
  import tempfile
 
8
  from PIL import Image
9
  import tifffile
10
 
 
45
  """
46
  # Verify the image has 4 channels before prediction
47
  try:
48
+ # Use tifffile for 32-bit TIFF support
49
+ img_array = tifffile.imread(image_path)
 
 
 
 
 
 
50
 
51
+ # Handle different array shapes
52
+ if len(img_array.shape) == 3:
53
+ if img_array.shape[0] == 4:
54
+ # Shape is (4, H, W) - transpose to (H, W, 4)
55
+ img_array = np.transpose(img_array, (1, 2, 0))
56
+ elif img_array.shape[2] != 4:
57
+ raise ValueError(f"YOLO model expects 4-channel images, but got {img_array.shape[2]} channels")
58
+ else:
59
+ raise ValueError(f"Unexpected image shape: {img_array.shape}")
 
 
 
60
 
61
+ # Convert 32-bit float to uint8 if necessary
62
+ if img_array.dtype != np.uint8:
63
+ # Normalize to [0, 255] for RGB channels (first 3)
64
+ rgb_array = img_array[:, :, :3]
65
+ if rgb_array.max() > 1.0:
66
+ rgb_array = np.clip(rgb_array / rgb_array.max() * 255, 0, 255).astype(np.uint8)
67
+ else:
68
+ rgb_array = np.clip(rgb_array * 255, 0, 255).astype(np.uint8)
69
 
70
+ # Normalize NDVI (4th channel) from [-1, 1] to [0, 255]
71
+ ndvi_array = img_array[:, :, 3]
72
+ ndvi_normalized = ((ndvi_array + 1) * 127.5).astype(np.uint8)
73
+
74
+ # Recombine into 4-channel uint8 array
75
+ img_array = np.zeros((img_array.shape[0], img_array.shape[1], 4), dtype=np.uint8)
76
+ img_array[:, :, :3] = rgb_array
77
+ img_array[:, :, 3] = ndvi_normalized
78
+
79
+ # Save normalized image to temporary file
80
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
81
+ temp_path = tmp_file.name
82
+ tifffile.imwrite(
83
+ temp_path,
84
+ img_array,
85
+ photometric='rgb',
86
+ compress='lzw',
87
+ metadata={'axes': 'YXC', 'resolution': (1, 1)} # DPI=1
88
+ )
89
+ image_path = temp_path
90
+
91
+ # Run YOLO prediction
92
+ results = yolo_model([image_path], conf=conf)
93
+
94
+ # Clean up temporary file if created
95
+ if 'temp_path' in locals() and os.path.exists(temp_path):
96
+ os.unlink(temp_path)
97
+
98
+ return results[0] # Return first result
99
 
100
+ except Exception as e:
101
+ raise ValueError(f"Error processing image: {str(e)}")
102
 
103
  def create_4channel_tiff(rgb_array, ndvi_array, output_path):
104
  """
 
111
  """
112
  height, width = rgb_array.shape[:2]
113
 
114
+ # Ensure RGB is in uint8 format
115
  if rgb_array.dtype != np.uint8:
116
  if rgb_array.max() <= 1.0:
117
  rgb_normalized = (rgb_array * 255).astype(np.uint8)
 
130
  four_channel[:, :, 2] = rgb_normalized[:, :, 2] # Blue
131
  four_channel[:, :, 3] = ndvi_normalized # NDVI
132
 
133
+ # Save using tifffile with explicit 32-bit compatibility and DPI=1
134
  tifffile.imwrite(
135
+ output_path,
136
+ four_channel,
137
  photometric='rgb',
138
  compress='lzw',
139
+ metadata={'axes': 'YXC', 'resolution': (1, 1)}, # DPI=1
140
+ bitspersample=8 # Explicitly set to 8-bit per channel
141
  )
142
 
143
  def load_4channel_tiff(image_path):
 
247
  except Exception as e2:
248
  try:
249
  # Method 3: Fall back to PIL for standard formats
250
+ img = Image.PIL(image_path)
251
  if img.mode != 'RGB':
252
  img = img.convert('RGB')
253
  rgb_array = np.array(img)
 
260
 
261
  # Ensure RGB is in correct format and range
262
  if rgb_array.dtype == np.uint8:
 
263
  rgb_float = rgb_array.astype(np.float32) / 255.0
264
  else:
 
265
  if rgb_array.max() > 1.0:
266
+ rgb_float = rgb_array / rgb_array.max()
267
  else:
268
  rgb_float = rgb_array
269
  rgb_array = (rgb_float * 255).astype(np.uint8)
 
281
 
282
  # Verify the created file can be read
283
  try:
284
+ test_array = tifffile.imread(temp_4ch_path)
285
+ if len(test_array.shape) == 3 and (test_array.shape[0] == 4 or test_array.shape[2] == 4):
286
+ channels = 4
287
  else:
288
+ channels = test_array.shape[2] if len(test_array.shape) == 3 else 1
 
 
289
  if channels != 4:
290
  raise ValueError(f"Created TIFF has {channels} channels instead of 4")
291