Upload: load.py
Browse files- single/load.py +11 -6
single/load.py
CHANGED
|
@@ -14,8 +14,8 @@ class EnsembleModel(torch.nn.Module):
|
|
| 14 |
super(EnsembleModel, self).__init__()
|
| 15 |
self.models = torch.nn.ModuleList(models)
|
| 16 |
self.mode = mode
|
| 17 |
-
if mode not in ["min", "mean", "max", "none"]:
|
| 18 |
-
raise ValueError("Mode must be 'none', 'min', 'mean', or 'max'.")
|
| 19 |
|
| 20 |
def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 21 |
"""
|
|
@@ -38,18 +38,20 @@ class EnsembleModel(torch.nn.Module):
|
|
| 38 |
stacked_outputs = torch.stack(outputs, dim=1) # (B, N, 1, H, W)
|
| 39 |
stacked_outputs = stacked_outputs.squeeze(2) # (B, N, H, W)
|
| 40 |
|
| 41 |
-
# Calculate
|
| 42 |
if self.mode == "max":
|
| 43 |
output_probs = torch.max(stacked_outputs, dim=1, keepdim=True)[0]
|
| 44 |
elif self.mode == "mean":
|
| 45 |
output_probs = torch.mean(stacked_outputs, dim=1, keepdim=True)
|
|
|
|
|
|
|
| 46 |
elif self.mode == "min":
|
| 47 |
output_probs = torch.min(stacked_outputs, dim=1, keepdim=True)[0]
|
| 48 |
elif self.mode == "none":
|
| 49 |
# Return all predictions without aggregation
|
| 50 |
return stacked_outputs, None
|
| 51 |
else:
|
| 52 |
-
raise ValueError("Mode must be 'min', 'mean', or 'max'.")
|
| 53 |
|
| 54 |
# Calculate uncertainty (normalized standard deviation)
|
| 55 |
N = len(outputs)
|
|
@@ -69,16 +71,19 @@ class EnsembleModel(torch.nn.Module):
|
|
| 69 |
uncertainty = torch.zeros_like(output_probs)
|
| 70 |
|
| 71 |
return output_probs, uncertainty # Both (B, 1, H, W)
|
| 72 |
-
|
| 73 |
def compiled_model(
|
| 74 |
path: pathlib.Path,
|
| 75 |
stac_item: pystac.Item,
|
| 76 |
-
mode: Literal["min", "mean", "max"] = "max",
|
| 77 |
*args, **kwargs
|
| 78 |
) -> EnsembleModel:
|
| 79 |
"""
|
| 80 |
Loads the ensemble dynamically using the 'stac_item' as the source of truth.
|
| 81 |
|
|
|
|
|
|
|
|
|
|
| 82 |
Returns an EnsembleModel that outputs (probabilities, uncertainty) tuple.
|
| 83 |
"""
|
| 84 |
model_paths = []
|
|
|
|
| 14 |
super(EnsembleModel, self).__init__()
|
| 15 |
self.models = torch.nn.ModuleList(models)
|
| 16 |
self.mode = mode
|
| 17 |
+
if mode not in ["min", "mean", "median", "max", "none"]:
|
| 18 |
+
raise ValueError("Mode must be 'none', 'min', 'mean', 'median', or 'max'.")
|
| 19 |
|
| 20 |
def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 21 |
"""
|
|
|
|
| 38 |
stacked_outputs = torch.stack(outputs, dim=1) # (B, N, 1, H, W)
|
| 39 |
stacked_outputs = stacked_outputs.squeeze(2) # (B, N, H, W)
|
| 40 |
|
| 41 |
+
# Calculate aggregated probabilities
|
| 42 |
if self.mode == "max":
|
| 43 |
output_probs = torch.max(stacked_outputs, dim=1, keepdim=True)[0]
|
| 44 |
elif self.mode == "mean":
|
| 45 |
output_probs = torch.mean(stacked_outputs, dim=1, keepdim=True)
|
| 46 |
+
elif self.mode == "median":
|
| 47 |
+
output_probs = torch.median(stacked_outputs, dim=1, keepdim=True)[0]
|
| 48 |
elif self.mode == "min":
|
| 49 |
output_probs = torch.min(stacked_outputs, dim=1, keepdim=True)[0]
|
| 50 |
elif self.mode == "none":
|
| 51 |
# Return all predictions without aggregation
|
| 52 |
return stacked_outputs, None
|
| 53 |
else:
|
| 54 |
+
raise ValueError("Mode must be 'min', 'mean', 'median', or 'max'.")
|
| 55 |
|
| 56 |
# Calculate uncertainty (normalized standard deviation)
|
| 57 |
N = len(outputs)
|
|
|
|
| 71 |
uncertainty = torch.zeros_like(output_probs)
|
| 72 |
|
| 73 |
return output_probs, uncertainty # Both (B, 1, H, W)
|
| 74 |
+
|
| 75 |
def compiled_model(
|
| 76 |
path: pathlib.Path,
|
| 77 |
stac_item: pystac.Item,
|
| 78 |
+
mode: Literal["min", "mean", "median", "max"] = "max",
|
| 79 |
*args, **kwargs
|
| 80 |
) -> EnsembleModel:
|
| 81 |
"""
|
| 82 |
Loads the ensemble dynamically using the 'stac_item' as the source of truth.
|
| 83 |
|
| 84 |
+
Args:
|
| 85 |
+
mode: Aggregation mode - 'min', 'mean', 'median', or 'max' (default: 'max')
|
| 86 |
+
|
| 87 |
Returns an EnsembleModel that outputs (probabilities, uncertainty) tuple.
|
| 88 |
"""
|
| 89 |
model_paths = []
|