Upload: load.py
Browse files- single/load.py +19 -11
single/load.py
CHANGED
|
@@ -71,20 +71,24 @@ class EnsembleModel(torch.nn.Module):
|
|
| 71 |
uncertainty = torch.zeros_like(output_probs)
|
| 72 |
|
| 73 |
return output_probs, uncertainty # Both (B, 1, H, W)
|
| 74 |
-
|
| 75 |
def compiled_model(
|
| 76 |
path: pathlib.Path,
|
| 77 |
stac_item: pystac.Item,
|
| 78 |
mode: Literal["min", "mean", "median", "max"] = "max",
|
| 79 |
*args, **kwargs
|
| 80 |
-
)
|
| 81 |
"""
|
| 82 |
-
Loads
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
Args:
|
| 85 |
-
mode: Aggregation mode
|
| 86 |
|
| 87 |
-
Returns
|
|
|
|
| 88 |
"""
|
| 89 |
model_paths = []
|
| 90 |
for asset_key, asset in stac_item.assets.items():
|
|
@@ -93,10 +97,16 @@ def compiled_model(
|
|
| 93 |
|
| 94 |
if not model_paths:
|
| 95 |
raise ValueError("No .pt2 files found in STAC item assets.")
|
|
|
|
| 96 |
model_paths.sort()
|
| 97 |
-
models = [torch.export.load(p).module() for p in model_paths]
|
| 98 |
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0):
|
| 102 |
"""
|
|
@@ -172,8 +182,7 @@ def predict_large(
|
|
| 172 |
except (NotImplementedError, AttributeError):
|
| 173 |
# Exported model (.pt2) or EnsembleModel
|
| 174 |
model = model.to(device)
|
| 175 |
-
|
| 176 |
-
# Test if model returns tuple (ensemble) or single output
|
| 177 |
test_input = torch.zeros(1, bands, chunk_size, chunk_size).to(device)
|
| 178 |
with torch.no_grad():
|
| 179 |
test_output = model(test_input)
|
|
@@ -300,8 +309,7 @@ def predict_large(
|
|
| 300 |
offset_y : offset_y + length_y,
|
| 301 |
offset_x : offset_x + length_x
|
| 302 |
] = to_write_uncertainty
|
| 303 |
-
|
| 304 |
-
# Return based on model type
|
| 305 |
if is_ensemble:
|
| 306 |
return output_probs, output_uncertainty
|
| 307 |
else:
|
|
|
|
| 71 |
uncertainty = torch.zeros_like(output_probs)
|
| 72 |
|
| 73 |
return output_probs, uncertainty # Both (B, 1, H, W)
|
| 74 |
+
|
| 75 |
def compiled_model(
|
| 76 |
path: pathlib.Path,
|
| 77 |
stac_item: pystac.Item,
|
| 78 |
mode: Literal["min", "mean", "median", "max"] = "max",
|
| 79 |
*args, **kwargs
|
| 80 |
+
):
|
| 81 |
"""
|
| 82 |
+
Loads model(s) dynamically based on STAC metadata.
|
| 83 |
+
|
| 84 |
+
- If single .pt2 → returns single model
|
| 85 |
+
- If multiple .pt2 → returns EnsembleModel
|
| 86 |
|
| 87 |
Args:
|
| 88 |
+
mode: Aggregation mode for ensembles (ignored for single models)
|
| 89 |
|
| 90 |
+
Returns:
|
| 91 |
+
Single model or EnsembleModel
|
| 92 |
"""
|
| 93 |
model_paths = []
|
| 94 |
for asset_key, asset in stac_item.assets.items():
|
|
|
|
| 97 |
|
| 98 |
if not model_paths:
|
| 99 |
raise ValueError("No .pt2 files found in STAC item assets.")
|
| 100 |
+
|
| 101 |
model_paths.sort()
|
|
|
|
| 102 |
|
| 103 |
+
if len(model_paths) == 1:
|
| 104 |
+
# Single model
|
| 105 |
+
return torch.export.load(model_paths[0]).module()
|
| 106 |
+
else:
|
| 107 |
+
# Ensemble model
|
| 108 |
+
models = [torch.export.load(p).module() for p in model_paths]
|
| 109 |
+
return EnsembleModel(*models, mode=mode)
|
| 110 |
|
| 111 |
def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0):
|
| 112 |
"""
|
|
|
|
| 182 |
except (NotImplementedError, AttributeError):
|
| 183 |
# Exported model (.pt2) or EnsembleModel
|
| 184 |
model = model.to(device)
|
| 185 |
+
|
|
|
|
| 186 |
test_input = torch.zeros(1, bands, chunk_size, chunk_size).to(device)
|
| 187 |
with torch.no_grad():
|
| 188 |
test_output = model(test_input)
|
|
|
|
| 309 |
offset_y : offset_y + length_y,
|
| 310 |
offset_x : offset_x + length_x
|
| 311 |
] = to_write_uncertainty
|
| 312 |
+
|
|
|
|
| 313 |
if is_ensemble:
|
| 314 |
return output_probs, output_uncertainty
|
| 315 |
else:
|