JulioContrerasH commited on
Commit
ff46e70
·
verified ·
1 Parent(s): 4b10cbe

Upload: load.py

Browse files
Files changed (1) hide show
  1. single/load.py +19 -11
single/load.py CHANGED
@@ -71,20 +71,24 @@ class EnsembleModel(torch.nn.Module):
71
  uncertainty = torch.zeros_like(output_probs)
72
 
73
  return output_probs, uncertainty # Both (B, 1, H, W)
74
-
75
  def compiled_model(
76
  path: pathlib.Path,
77
  stac_item: pystac.Item,
78
  mode: Literal["min", "mean", "median", "max"] = "max",
79
  *args, **kwargs
80
- ) -> EnsembleModel:
81
  """
82
- Loads the ensemble dynamically using the 'stac_item' as the source of truth.
 
 
 
83
 
84
  Args:
85
- mode: Aggregation mode - 'min', 'mean', 'median', or 'max' (default: 'max')
86
 
87
- Returns an EnsembleModel that outputs (probabilities, uncertainty) tuple.
 
88
  """
89
  model_paths = []
90
  for asset_key, asset in stac_item.assets.items():
@@ -93,10 +97,16 @@ def compiled_model(
93
 
94
  if not model_paths:
95
  raise ValueError("No .pt2 files found in STAC item assets.")
 
96
  model_paths.sort()
97
- models = [torch.export.load(p).module() for p in model_paths]
98
 
99
- return EnsembleModel(*models, mode=mode)
 
 
 
 
 
 
100
 
101
  def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0):
102
  """
@@ -172,8 +182,7 @@ def predict_large(
172
  except (NotImplementedError, AttributeError):
173
  # Exported model (.pt2) or EnsembleModel
174
  model = model.to(device)
175
-
176
- # Test if model returns tuple (ensemble) or single output
177
  test_input = torch.zeros(1, bands, chunk_size, chunk_size).to(device)
178
  with torch.no_grad():
179
  test_output = model(test_input)
@@ -300,8 +309,7 @@ def predict_large(
300
  offset_y : offset_y + length_y,
301
  offset_x : offset_x + length_x
302
  ] = to_write_uncertainty
303
-
304
- # Return based on model type
305
  if is_ensemble:
306
  return output_probs, output_uncertainty
307
  else:
 
71
  uncertainty = torch.zeros_like(output_probs)
72
 
73
  return output_probs, uncertainty # Both (B, 1, H, W)
74
+
75
  def compiled_model(
76
  path: pathlib.Path,
77
  stac_item: pystac.Item,
78
  mode: Literal["min", "mean", "median", "max"] = "max",
79
  *args, **kwargs
80
+ ):
81
  """
82
+ Loads model(s) dynamically based on STAC metadata.
83
+
84
+ - If single .pt2 → returns single model
85
+ - If multiple .pt2 → returns EnsembleModel
86
 
87
  Args:
88
+ mode: Aggregation mode for ensembles (ignored for single models)
89
 
90
+ Returns:
91
+ Single model or EnsembleModel
92
  """
93
  model_paths = []
94
  for asset_key, asset in stac_item.assets.items():
 
97
 
98
  if not model_paths:
99
  raise ValueError("No .pt2 files found in STAC item assets.")
100
+
101
  model_paths.sort()
 
102
 
103
+ if len(model_paths) == 1:
104
+ # Single model
105
+ return torch.export.load(model_paths[0]).module()
106
+ else:
107
+ # Ensemble model
108
+ models = [torch.export.load(p).module() for p in model_paths]
109
+ return EnsembleModel(*models, mode=mode)
110
 
111
  def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0):
112
  """
 
182
  except (NotImplementedError, AttributeError):
183
  # Exported model (.pt2) or EnsembleModel
184
  model = model.to(device)
185
+
 
186
  test_input = torch.zeros(1, bands, chunk_size, chunk_size).to(device)
187
  with torch.no_grad():
188
  test_output = model(test_input)
 
309
  offset_y : offset_y + length_y,
310
  offset_x : offset_x + length_x
311
  ] = to_write_uncertainty
312
+
 
313
  if is_ensemble:
314
  return output_probs, output_uncertainty
315
  else: