JulioContrerasH commited on
Commit
4b10cbe
·
verified ·
1 Parent(s): 9ef8333

Upload: load.py

Browse files
Files changed (1) hide show
  1. single/load.py +11 -6
single/load.py CHANGED
@@ -14,8 +14,8 @@ class EnsembleModel(torch.nn.Module):
14
  super(EnsembleModel, self).__init__()
15
  self.models = torch.nn.ModuleList(models)
16
  self.mode = mode
17
- if mode not in ["min", "mean", "max", "none"]:
18
- raise ValueError("Mode must be 'none', 'min', 'mean', or 'max'.")
19
 
20
  def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
21
  """
@@ -38,18 +38,20 @@ class EnsembleModel(torch.nn.Module):
38
  stacked_outputs = torch.stack(outputs, dim=1) # (B, N, 1, H, W)
39
  stacked_outputs = stacked_outputs.squeeze(2) # (B, N, H, W)
40
 
41
- # Calculate mean (probabilities)
42
  if self.mode == "max":
43
  output_probs = torch.max(stacked_outputs, dim=1, keepdim=True)[0]
44
  elif self.mode == "mean":
45
  output_probs = torch.mean(stacked_outputs, dim=1, keepdim=True)
 
 
46
  elif self.mode == "min":
47
  output_probs = torch.min(stacked_outputs, dim=1, keepdim=True)[0]
48
  elif self.mode == "none":
49
  # Return all predictions without aggregation
50
  return stacked_outputs, None
51
  else:
52
- raise ValueError("Mode must be 'min', 'mean', or 'max'.")
53
 
54
  # Calculate uncertainty (normalized standard deviation)
55
  N = len(outputs)
@@ -69,16 +71,19 @@ class EnsembleModel(torch.nn.Module):
69
  uncertainty = torch.zeros_like(output_probs)
70
 
71
  return output_probs, uncertainty # Both (B, 1, H, W)
72
-
73
  def compiled_model(
74
  path: pathlib.Path,
75
  stac_item: pystac.Item,
76
- mode: Literal["min", "mean", "max"] = "max",
77
  *args, **kwargs
78
  ) -> EnsembleModel:
79
  """
80
  Loads the ensemble dynamically using the 'stac_item' as the source of truth.
81
 
 
 
 
82
  Returns an EnsembleModel that outputs (probabilities, uncertainty) tuple.
83
  """
84
  model_paths = []
 
14
  super(EnsembleModel, self).__init__()
15
  self.models = torch.nn.ModuleList(models)
16
  self.mode = mode
17
+ if mode not in ["min", "mean", "median", "max", "none"]:
18
+ raise ValueError("Mode must be 'none', 'min', 'mean', 'median', or 'max'.")
19
 
20
  def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
21
  """
 
38
  stacked_outputs = torch.stack(outputs, dim=1) # (B, N, 1, H, W)
39
  stacked_outputs = stacked_outputs.squeeze(2) # (B, N, H, W)
40
 
41
+ # Calculate aggregated probabilities
42
  if self.mode == "max":
43
  output_probs = torch.max(stacked_outputs, dim=1, keepdim=True)[0]
44
  elif self.mode == "mean":
45
  output_probs = torch.mean(stacked_outputs, dim=1, keepdim=True)
46
+ elif self.mode == "median":
47
+ output_probs = torch.median(stacked_outputs, dim=1, keepdim=True)[0]
48
  elif self.mode == "min":
49
  output_probs = torch.min(stacked_outputs, dim=1, keepdim=True)[0]
50
  elif self.mode == "none":
51
  # Return all predictions without aggregation
52
  return stacked_outputs, None
53
  else:
54
+ raise ValueError("Mode must be 'min', 'mean', 'median', or 'max'.")
55
 
56
  # Calculate uncertainty (normalized standard deviation)
57
  N = len(outputs)
 
71
  uncertainty = torch.zeros_like(output_probs)
72
 
73
  return output_probs, uncertainty # Both (B, 1, H, W)
74
+
75
  def compiled_model(
76
  path: pathlib.Path,
77
  stac_item: pystac.Item,
78
+ mode: Literal["min", "mean", "median", "max"] = "max",
79
  *args, **kwargs
80
  ) -> EnsembleModel:
81
  """
82
  Loads the ensemble dynamically using the 'stac_item' as the source of truth.
83
 
84
+ Args:
85
+ mode: Aggregation mode - 'min', 'mean', 'median', or 'max' (default: 'max')
86
+
87
  Returns an EnsembleModel that outputs (probabilities, uncertainty) tuple.
88
  """
89
  model_paths = []