JulioContrerasH committed on
Commit
58ad7ff
·
verified ·
1 Parent(s): 919f17a

Update load.py

Browse files
Files changed (1) hide show
  1. load.py +30 -9
load.py CHANGED
@@ -121,8 +121,8 @@ def compiled_model(
121
  def predict_large(
122
  image: np.ndarray,
123
  model: nn.Module,
124
- chunk_size: int = 512,
125
- overlap: int = 256,
126
  batch_size: int = 1,
127
  device: str = "cpu",
128
  merge_clouds: bool = False,
@@ -136,7 +136,7 @@ def predict_large(
136
  image: Input image (C, H, W) in reflectance [0, 1]
137
  model: Loaded model from compiled_model()
138
  chunk_size: Size of inference tiles (default: 512)
139
- overlap: Overlap between tiles (default: 256)
140
  batch_size: Tiles per batch (default: 1)
141
  device: 'cpu' or 'cuda'
142
  merge_clouds: If True, merge thin+thick into single cloud class
@@ -150,6 +150,7 @@ def predict_large(
150
  model.eval()
151
  model.to(device)
152
 
 
153
  if not hasattr(model, 'merge_clouds'):
154
  model.merge_clouds = merge_clouds
155
  else:
@@ -157,17 +158,23 @@ def predict_large(
157
 
158
  C, H, W = image.shape
159
 
 
 
 
 
 
160
  if H <= chunk_size and W <= chunk_size:
161
  with torch.no_grad():
162
  img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
163
  logits = model(img_tensor)
164
 
165
  if merge_clouds:
 
166
  probs = torch.softmax(logits, dim=1)
167
  probs_merged = torch.zeros(1, 3, H, W, device=device)
168
- probs_merged[:, 0] = probs[:, 0]
169
- probs_merged[:, 1] = probs[:, 1] + probs[:, 2]
170
- probs_merged[:, 2] = probs[:, 3]
171
  pred = probs_merged.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
172
  else:
173
  pred = logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
@@ -177,9 +184,11 @@ def predict_large(
177
 
178
  return pred
179
 
 
180
  step = chunk_size - overlap
181
  half_tile = chunk_size // 2
182
 
 
183
  image_padded = np.pad(
184
  image,
185
  ((0, 0), (half_tile, half_tile + chunk_size), (half_tile, half_tile + chunk_size)),
@@ -188,49 +197,61 @@ def predict_large(
188
 
189
  _, H_pad, W_pad = image_padded.shape
190
 
 
191
  num_classes = 4
192
  probs_sum = np.zeros((num_classes, H_pad, W_pad), dtype=np.float32)
193
  weight_sum = np.zeros((H_pad, W_pad), dtype=np.float32)
194
 
 
195
  window = get_spline_window(chunk_size, power=2)
196
 
 
197
  coords = [
198
  (r, c)
199
  for r in range(0, H_pad - chunk_size + 1, step)
200
  for c in range(0, W_pad - chunk_size + 1, step)
201
  ]
202
 
 
203
  with torch.no_grad():
204
  for i in tqdm(range(0, len(coords), batch_size), desc=" Tiles", leave=False, disable=True):
205
  batch_coords = coords[i:i + batch_size]
206
 
 
207
  tiles = np.stack([
208
  image_padded[:, r:r + chunk_size, c:c + chunk_size]
209
  for r, c in batch_coords
210
  ])
211
 
 
212
  tiles_tensor = torch.from_numpy(tiles).float().to(device)
213
  logits = model(tiles_tensor)
214
  probs = torch.softmax(logits, dim=1).cpu().numpy()
215
 
 
216
  for j, (r, c) in enumerate(batch_coords):
217
  probs_sum[:, r:r + chunk_size, c:c + chunk_size] += probs[j] * window
218
  weight_sum[r:r + chunk_size, c:c + chunk_size] += window
219
 
 
220
  weight_sum = np.maximum(weight_sum, 1e-8)
221
  probs_final = probs_sum / weight_sum
222
 
 
223
  probs_final = probs_final[:, half_tile:half_tile + H, half_tile:half_tile + W]
224
 
 
225
  if merge_clouds:
 
226
  probs_merged = np.zeros((3, H, W), dtype=np.float32)
227
- probs_merged[0] = probs_final[0]
228
- probs_merged[1] = probs_final[1] + probs_final[2]
229
- probs_merged[2] = probs_final[3]
230
  pred = np.argmax(probs_merged, axis=0).astype(np.uint8)
231
  else:
232
  pred = np.argmax(probs_final, axis=0).astype(np.uint8)
233
 
 
234
  if apply_rules:
235
  pred = apply_physical_rules(pred, image, merge_clouds)
236
 
 
121
  def predict_large(
122
  image: np.ndarray,
123
  model: nn.Module,
124
+ chunk_size: int = 512,
125
+ overlap: int = None,
126
  batch_size: int = 1,
127
  device: str = "cpu",
128
  merge_clouds: bool = False,
 
136
  image: Input image (C, H, W) in reflectance [0, 1]
137
  model: Loaded model from compiled_model()
138
  chunk_size: Size of inference tiles (default: 512)
139
+ overlap: Overlap between tiles (default: chunk_size // 2)
140
  batch_size: Tiles per batch (default: 1)
141
  device: 'cpu' or 'cuda'
142
  merge_clouds: If True, merge thin+thick into single cloud class
 
150
  model.eval()
151
  model.to(device)
152
 
153
+ # Get merge_clouds from model if not specified
154
  if not hasattr(model, 'merge_clouds'):
155
  model.merge_clouds = merge_clouds
156
  else:
 
158
 
159
  C, H, W = image.shape
160
 
161
+ # Set default overlap if not specified
162
+ if overlap is None:
163
+ overlap = chunk_size // 2
164
+
165
+ # Direct inference if image fits within chunk_size
166
  if H <= chunk_size and W <= chunk_size:
167
  with torch.no_grad():
168
  img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
169
  logits = model(img_tensor)
170
 
171
  if merge_clouds:
172
+ # Merge thin+thick clouds into single cloud class
173
  probs = torch.softmax(logits, dim=1)
174
  probs_merged = torch.zeros(1, 3, H, W, device=device)
175
+ probs_merged[:, 0] = probs[:, 0] # Clear
176
+ probs_merged[:, 1] = probs[:, 1] + probs[:, 2] # Cloud (thin+thick)
177
+ probs_merged[:, 2] = probs[:, 3] # Shadow
178
  pred = probs_merged.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
179
  else:
180
  pred = logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
 
184
 
185
  return pred
186
 
187
+ # Sliding window inference for larger images
188
  step = chunk_size - overlap
189
  half_tile = chunk_size // 2
190
 
191
+ # Pad image to ensure tiles cover entire image
192
  image_padded = np.pad(
193
  image,
194
  ((0, 0), (half_tile, half_tile + chunk_size), (half_tile, half_tile + chunk_size)),
 
197
 
198
  _, H_pad, W_pad = image_padded.shape
199
 
200
+ # Initialize accumulation buffers
201
  num_classes = 4
202
  probs_sum = np.zeros((num_classes, H_pad, W_pad), dtype=np.float32)
203
  weight_sum = np.zeros((H_pad, W_pad), dtype=np.float32)
204
 
205
+ # Create blending window
206
  window = get_spline_window(chunk_size, power=2)
207
 
208
+ # Generate tile coordinates
209
  coords = [
210
  (r, c)
211
  for r in range(0, H_pad - chunk_size + 1, step)
212
  for c in range(0, W_pad - chunk_size + 1, step)
213
  ]
214
 
215
+ # Process tiles in batches
216
  with torch.no_grad():
217
  for i in tqdm(range(0, len(coords), batch_size), desc=" Tiles", leave=False, disable=True):
218
  batch_coords = coords[i:i + batch_size]
219
 
220
+ # Extract tiles
221
  tiles = np.stack([
222
  image_padded[:, r:r + chunk_size, c:c + chunk_size]
223
  for r, c in batch_coords
224
  ])
225
 
226
+ # Run inference
227
  tiles_tensor = torch.from_numpy(tiles).float().to(device)
228
  logits = model(tiles_tensor)
229
  probs = torch.softmax(logits, dim=1).cpu().numpy()
230
 
231
+ # Accumulate weighted predictions
232
  for j, (r, c) in enumerate(batch_coords):
233
  probs_sum[:, r:r + chunk_size, c:c + chunk_size] += probs[j] * window
234
  weight_sum[r:r + chunk_size, c:c + chunk_size] += window
235
 
236
+ # Normalize by accumulated weights
237
  weight_sum = np.maximum(weight_sum, 1e-8)
238
  probs_final = probs_sum / weight_sum
239
 
240
+ # Remove padding
241
  probs_final = probs_final[:, half_tile:half_tile + H, half_tile:half_tile + W]
242
 
243
+ # Get final prediction
244
  if merge_clouds:
245
+ # Merge thin+thick clouds into single cloud class
246
  probs_merged = np.zeros((3, H, W), dtype=np.float32)
247
+ probs_merged[0] = probs_final[0] # Clear
248
+ probs_merged[1] = probs_final[1] + probs_final[2] # Cloud (thin+thick)
249
+ probs_merged[2] = probs_final[3] # Shadow
250
  pred = np.argmax(probs_merged, axis=0).astype(np.uint8)
251
  else:
252
  pred = np.argmax(probs_final, axis=0).astype(np.uint8)
253
 
254
+ # Apply physical rules if requested
255
  if apply_rules:
256
  pred = apply_physical_rules(pred, image, merge_clouds)
257