JulioContrerasH committed on
Commit
f122baf
·
verified ·
1 Parent(s): f62c253

Update load.py

Browse files
Files changed (1) hide show
  1. load.py +13 -22
load.py CHANGED
@@ -125,29 +125,30 @@ def compiled_model(
125
  def predict_large(
126
  image: np.ndarray,
127
  model: nn.Module,
128
- chunk_size: int = None, # None = auto
129
  overlap: int = None,
130
  batch_size: int = 1,
131
  device: str = "cpu",
132
  merge_clouds: bool = False,
133
  apply_rules: bool = False,
134
- max_direct_size: int = 4096, # Max size for direct inference
135
  **kwargs
136
  ) -> np.ndarray:
137
  """
138
  Predict on images of any size.
139
- Automatically uses direct inference for small/medium images to avoid tiling artifacts.
 
 
 
140
 
141
  Args:
142
  image: Input image (C, H, W) in reflectance [0, 1]
143
  model: Loaded model from compiled_model()
144
- chunk_size: Size of inference tiles (None = auto-detect)
145
- overlap: Overlap between tiles (None = auto, chunk_size // 2)
146
  batch_size: Tiles per batch (default: 1)
147
  device: 'cpu' or 'cuda'
148
  merge_clouds: If True, merge thin+thick into single cloud class
149
  apply_rules: If True, apply physical rules for bright clouds
150
- max_direct_size: Maximum dimension for direct inference (default: 4096)
151
 
152
  Returns:
153
  Predicted class labels (H, W)
@@ -162,24 +163,15 @@ def predict_large(
162
  merge_clouds = model.merge_clouds
163
 
164
  C, H, W = image.shape
165
- max_dim = max(H, W)
166
-
167
- # === AUTO CHUNK SIZE ===
168
- if chunk_size is None:
169
- if max_dim <= 1024:
170
- chunk_size = max_dim # Process entire image
171
- elif max_dim <= 2048:
172
- chunk_size = 1024
173
- else:
174
- chunk_size = 512
175
 
176
  # Set default overlap
177
  if overlap is None:
178
  overlap = chunk_size // 2
179
 
180
- # === DIRECT INFERENCE (NO TILING) ===
181
- # Use direct inference if image fits in single chunk or is small enough
182
- if max_dim <= max_direct_size or (H <= chunk_size and W <= chunk_size):
 
183
  with torch.no_grad():
184
  img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
185
  logits = model(img_tensor)
@@ -199,8 +191,7 @@ def predict_large(
199
 
200
  return pred
201
 
202
- # === SLIDING WINDOW FOR VERY LARGE IMAGES ===
203
- print(f" Using sliding window: {H}x{W} -> chunks={chunk_size}, overlap={overlap}")
204
 
205
  step = chunk_size - overlap
206
 
@@ -230,7 +221,7 @@ def predict_large(
230
  # Create blending window
231
  window = get_spline_window(chunk_size, power=2)
232
 
233
- # Generate tile coordinates - ensure full coverage
234
  coords = []
235
  for r in range(0, H_pad - chunk_size + 1, step):
236
  for c in range(0, W_pad - chunk_size + 1, step):
 
125
  def predict_large(
126
  image: np.ndarray,
127
  model: nn.Module,
128
+ chunk_size: int = 512,
129
  overlap: int = None,
130
  batch_size: int = 1,
131
  device: str = "cpu",
132
  merge_clouds: bool = False,
133
  apply_rules: bool = False,
 
134
  **kwargs
135
  ) -> np.ndarray:
136
  """
137
  Predict on images of any size.
138
+
139
+ Strategy:
140
+ - Images ≤ 2048px in any dimension: direct inference (no tiling)
141
+ - Images > 2048px: sliding window with specified chunk_size
142
 
143
  Args:
144
  image: Input image (C, H, W) in reflectance [0, 1]
145
  model: Loaded model from compiled_model()
146
+ chunk_size: Size of inference tiles for large images (default: 512)
147
+ overlap: Overlap between tiles (default: chunk_size // 2)
148
  batch_size: Tiles per batch (default: 1)
149
  device: 'cpu' or 'cuda'
150
  merge_clouds: If True, merge thin+thick into single cloud class
151
  apply_rules: If True, apply physical rules for bright clouds
 
152
 
153
  Returns:
154
  Predicted class labels (H, W)
 
163
  merge_clouds = model.merge_clouds
164
 
165
  C, H, W = image.shape
 
 
 
 
 
 
 
 
 
 
166
 
167
  # Set default overlap
168
  if overlap is None:
169
  overlap = chunk_size // 2
170
 
171
+ # === STRATEGY: Use direct inference for images ≤ 2048px ===
172
+ # This avoids tiling artifacts on small/medium images
173
+ if max(H, W) <= 2048:
174
+ # Direct inference - no tiling
175
  with torch.no_grad():
176
  img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
177
  logits = model(img_tensor)
 
191
 
192
  return pred
193
 
194
+ # === SLIDING WINDOW FOR LARGE IMAGES (> 2048px) ===
 
195
 
196
  step = chunk_size - overlap
197
 
 
221
  # Create blending window
222
  window = get_spline_window(chunk_size, power=2)
223
 
224
+ # Generate tile coordinates
225
  coords = []
226
  for r in range(0, H_pad - chunk_size + 1, step):
227
  for c in range(0, W_pad - chunk_size + 1, step):