ahadhassan commited on
Commit
65137d3
·
verified ·
1 Parent(s): 458f0f7

Update yolo_predictor.py

Browse files
Files changed (1) hide show
  1. yolo_predictor.py +90 -50
yolo_predictor.py CHANGED
@@ -7,6 +7,7 @@ from ndvi_predictor import normalize_rgb, predict_ndvi
7
  import tempfile
8
  from rasterio.transform import from_bounds
9
  from PIL import Image
 
10
 
11
  def load_yolo_model(model_path):
12
  """Load YOLO model from .pt file"""
@@ -60,35 +61,23 @@ def create_4channel_tiff(rgb_array, ndvi_array, output_path):
60
  height, width = rgb_array.shape[:2]
61
 
62
  # Stack RGB and NDVI to create 4-channel image
63
- four_channel = np.zeros((height, width, 4), dtype=rgb_array.dtype)
64
- four_channel[:, :, :3] = rgb_array # RGB channels
65
 
66
- # Normalize NDVI to match RGB data type range
67
  if rgb_array.dtype == np.uint8:
68
- # Scale NDVI from [-1, 1] to [0, 255]
69
- ndvi_scaled = ((ndvi_array + 1) * 127.5).astype(np.uint8)
70
  else:
71
- # Keep NDVI in original range for float types
72
- ndvi_scaled = ndvi_array.astype(rgb_array.dtype)
73
-
74
- four_channel[:, :, 3] = ndvi_scaled # NDVI channel
75
-
76
- # Create transform (assuming no specific georeferencing needed)
77
- transform = from_bounds(0, 0, width, height, width, height)
78
-
79
- # Write 4-channel TIFF
80
- with rasterio.open(
81
- output_path,
82
- 'w',
83
- driver='GTiff',
84
- height=height,
85
- width=width,
86
- count=4,
87
- dtype=four_channel.dtype,
88
- transform=transform
89
- ) as dst:
90
- for i in range(4):
91
- dst.write(four_channel[:, :, i], i + 1)
92
 
93
  def load_4channel_tiff(image_path):
94
  """
@@ -101,16 +90,38 @@ def load_4channel_tiff(image_path):
101
  rgb_array: RGB channels as numpy array (H, W, 3)
102
  ndvi_array: NDVI channel as numpy array (H, W)
103
  """
104
- with rasterio.open(image_path) as src:
105
- # Read all 4 channels
106
- channels = src.read() # Shape: (4, H, W)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- # Extract RGB and NDVI
109
- rgb_array = np.transpose(channels[:3], (1, 2, 0)) # (H, W, 3)
110
- ndvi_array = channels[3] # (H, W)
 
 
 
 
 
 
 
111
 
112
- # If NDVI was scaled to uint8, convert back to [-1, 1] range
113
- if channels.dtype == np.uint8:
114
  ndvi_array = (ndvi_array.astype(np.float32) / 127.5) - 1
115
 
116
  return rgb_array, ndvi_array
@@ -129,28 +140,57 @@ def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
129
  Returns:
130
  results: YOLO results object
131
  """
132
- # Try to load as 4-channel TIFF first, fall back to RGB
 
 
133
  try:
134
- with rasterio.open(image_path) as src:
135
- if src.count == 4:
136
- # Load 4-channel TIFF
137
- rgb_array, _ = load_4channel_tiff(image_path)
138
- elif src.count == 3:
139
- # Load as RGB TIFF
140
- channels = src.read()
141
- rgb_array = np.transpose(channels, (1, 2, 0))
142
- else:
143
- raise ValueError(f"Unsupported number of channels: {src.count}")
144
- except:
145
- # Fall back to PIL for standard image formats
146
- img = Image.open(image_path).convert("RGB")
147
- rgb_array = np.array(img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  # Predict NDVI from RGB
150
  ndvi_pred = predict_ndvi_from_rgb(ndvi_model, rgb_array)
151
 
152
  # Create temporary 4-channel TIFF file
153
- with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as tmp_file:
154
  temp_4ch_path = tmp_file.name
155
 
156
  try:
 
7
  import tempfile
8
  from rasterio.transform import from_bounds
9
  from PIL import Image
10
+ import tifffile
11
 
12
  def load_yolo_model(model_path):
13
  """Load YOLO model from .pt file"""
 
61
  height, width = rgb_array.shape[:2]
62
 
63
  # Stack RGB and NDVI to create 4-channel image
64
+ four_channel = np.zeros((4, height, width), dtype=np.float32)
 
65
 
66
+ # Convert RGB to proper format and range
67
  if rgb_array.dtype == np.uint8:
68
+ rgb_normalized = rgb_array.astype(np.float32) / 255.0
 
69
  else:
70
+ rgb_normalized = rgb_array.astype(np.float32)
71
+
72
+ # Assign channels in (C, H, W) format for rasterio
73
+ four_channel[0] = rgb_normalized[:, :, 0] # Red
74
+ four_channel[1] = rgb_normalized[:, :, 1] # Green
75
+ four_channel[2] = rgb_normalized[:, :, 2] # Blue
76
+ four_channel[3] = ndvi_array.astype(np.float32) # NDVI
77
+
78
+ # Use tifffile for better compatibility with YOLO
79
+ import tifffile
80
+ tifffile.imwrite(output_path, four_channel, photometric='rgb')
 
 
 
 
 
 
 
 
 
 
81
 
82
  def load_4channel_tiff(image_path):
83
  """
 
90
  rgb_array: RGB channels as numpy array (H, W, 3)
91
  ndvi_array: NDVI channel as numpy array (H, W)
92
  """
93
+ try:
94
+ with rasterio.open(image_path) as src:
95
+ # Read all 4 channels
96
+ channels = src.read() # Shape: (4, H, W)
97
+
98
+ # Extract RGB and NDVI
99
+ rgb_array = np.transpose(channels[:3], (1, 2, 0)) # (H, W, 3)
100
+ ndvi_array = channels[3] # (H, W)
101
+
102
+ # If NDVI was scaled to uint8, convert back to [-1, 1] range
103
+ if channels.dtype == np.uint8:
104
+ ndvi_array = (ndvi_array.astype(np.float32) / 127.5) - 1
105
+
106
+ return rgb_array, ndvi_array
107
+ except Exception as e:
108
+ # Try with tifffile as fallback
109
+ import tifffile
110
+ img_array = tifffile.imread(image_path)
111
 
112
+ if len(img_array.shape) == 3 and img_array.shape[0] == 4:
113
+ # Shape is (4, H, W)
114
+ rgb_array = np.transpose(img_array[:3], (1, 2, 0)) # (H, W, 3)
115
+ ndvi_array = img_array[3] # (H, W)
116
+ elif len(img_array.shape) == 3 and img_array.shape[2] == 4:
117
+ # Shape is (H, W, 4)
118
+ rgb_array = img_array[:, :, :3] # (H, W, 3)
119
+ ndvi_array = img_array[:, :, 3] # (H, W)
120
+ else:
121
+ raise ValueError(f"Unexpected image shape: {img_array.shape}")
122
 
123
+ # Normalize NDVI if needed
124
+ if img_array.dtype == np.uint8:
125
  ndvi_array = (ndvi_array.astype(np.float32) / 127.5) - 1
126
 
127
  return rgb_array, ndvi_array
 
140
  Returns:
141
  results: YOLO results object
142
  """
143
+ rgb_array = None
144
+
145
+ # Try multiple methods to load the image
146
  try:
147
+ # Method 1: Try with tifffile first (best for complex TIFF files)
148
+ import tifffile
149
+ img_array = tifffile.imread(image_path)
150
+
151
+ if len(img_array.shape) == 3:
152
+ if img_array.shape[0] == 4:
153
+ # Shape is (4, H, W) - extract RGB
154
+ rgb_array = np.transpose(img_array[:3], (1, 2, 0))
155
+ elif img_array.shape[2] == 4:
156
+ # Shape is (H, W, 4) - extract RGB
157
+ rgb_array = img_array[:, :, :3]
158
+ elif img_array.shape[2] == 3:
159
+ # Shape is (H, W, 3) - already RGB
160
+ rgb_array = img_array
161
+ elif img_array.shape[0] == 3:
162
+ # Shape is (3, H, W) - transpose to RGB
163
+ rgb_array = np.transpose(img_array, (1, 2, 0))
164
+ except Exception as e1:
165
+ try:
166
+ # Method 2: Try with rasterio
167
+ with rasterio.open(image_path) as src:
168
+ if src.count >= 3:
169
+ channels = src.read()
170
+ if src.count == 4:
171
+ rgb_array = np.transpose(channels[:3], (1, 2, 0))
172
+ else:
173
+ rgb_array = np.transpose(channels, (1, 2, 0))
174
+ except Exception as e2:
175
+ try:
176
+ # Method 3: Fall back to PIL for standard formats
177
+ img = Image.open(image_path).convert("RGB")
178
+ rgb_array = np.array(img)
179
+ except Exception as e3:
180
+ raise ValueError(f"Could not load image with any method. Errors: tifffile={e1}, rasterio={e2}, PIL={e3}")
181
+
182
+ if rgb_array is None:
183
+ raise ValueError("Failed to extract RGB data from image")
184
+
185
+ # Ensure RGB is in correct format and range
186
+ if rgb_array.max() > 1:
187
+ rgb_array = rgb_array.astype(np.float32) / 255.0
188
 
189
  # Predict NDVI from RGB
190
  ndvi_pred = predict_ndvi_from_rgb(ndvi_model, rgb_array)
191
 
192
  # Create temporary 4-channel TIFF file
193
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
194
  temp_4ch_path = tmp_file.name
195
 
196
  try: