csaybar commited on
Commit
a09c701
·
verified ·
1 Parent(s): 62a17d0

Update ensemble/load.py

Browse files
Files changed (1) hide show
  1. ensemble/load.py +25 -3
ensemble/load.py CHANGED
@@ -17,14 +17,15 @@ def load_model_module(model_path: pathlib.Path):
17
  return model
18
 
19
  class EnsembleModel(torch.nn.Module):
20
- def __init__(self, model1, model2, model3, model4, model5, mode="max"):
21
  super(EnsembleModel, self).__init__()
22
  self.model1 = model1
23
  self.model2 = model2
24
  self.model3 = model3
25
  self.model4 = model4
26
  self.model5 = model5
27
- self.models = [model1, model2, model3, model4, model5]
 
28
  self.mode = mode
29
  if mode not in ["min", "mean", "max"]:
30
  raise ValueError("Mode must be 'min', 'mean', or 'max'.")
@@ -42,6 +43,8 @@ class EnsembleModel(torch.nn.Module):
42
  output_probs = torch.mean(torch.cat(outputs, dim=1), dim=1)[0].squeeze()
43
  elif self.mode == "min":
44
  output_probs = torch.min(torch.cat(outputs, dim=1), dim=1)[0].squeeze()
 
 
45
  else:
46
  raise ValueError("Mode must be 'min', 'mean', or 'max'.")
47
 
@@ -68,6 +71,7 @@ def compiled_model(path, device: str = "cpu", mode: Literal["min", "mean", "max"
68
  model3_f = path / "1dpwunetpp.safetensor"
69
  model4_f = path / "unet.safetensor"
70
  model5_f = path / "unetpp.safetensor"
 
71
 
72
  # Load model parameters
73
  model1_weights = safetensors.torch.load_file(model1_f)
@@ -75,6 +79,7 @@ def compiled_model(path, device: str = "cpu", mode: Literal["min", "mean", "max"
75
  model3_weights = safetensors.torch.load_file(model3_f)
76
  model4_weights = safetensors.torch.load_file(model4_f)
77
  model5_weights = safetensors.torch.load_file(model5_f)
 
78
 
79
  # Load all models
80
 
@@ -128,12 +133,29 @@ def compiled_model(path, device: str = "cpu", mode: Literal["min", "mean", "max"
128
  for param in model5.parameters():
129
  param.requires_grad = False
130
  model5 = model5.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  # Create ensemble model
133
- cloud_model = EnsembleModel(model1, model2, model3, model4, model5, mode=mode)
134
 
135
  return cloud_model
136
 
 
 
137
  def display_results(path: pathlib.Path, device: str = "cpu", mode: Literal["min", "mean", "max"] ="max", *args, **kwargs):
138
  # Load model
139
  model = compiled_model(path, device, mode=mode)
 
17
  return model
18
 
19
  class EnsembleModel(torch.nn.Module):
20
+ def __init__(self, model1, model2, model3, model4, model5, model6, mode="max"):
21
  super(EnsembleModel, self).__init__()
22
  self.model1 = model1
23
  self.model2 = model2
24
  self.model3 = model3
25
  self.model4 = model4
26
  self.model5 = model5
27
+ self.model6 = model6
28
+ self.models = [model1, model2, model3, model4, model5, model6]
29
  self.mode = mode
30
  if mode not in ["min", "mean", "max"]:
31
  raise ValueError("Mode must be 'min', 'mean', or 'max'.")
 
43
  output_probs = torch.mean(torch.cat(outputs, dim=1), dim=1)[0].squeeze()
44
  elif self.mode == "min":
45
  output_probs = torch.min(torch.cat(outputs, dim=1), dim=1)[0].squeeze()
46
+ elif self.mode == "none":
47
+ return torch.cat(outputs, dim=1)
48
  else:
49
  raise ValueError("Mode must be 'min', 'mean', or 'max'.")
50
 
 
71
  model3_f = path / "1dpwunetpp.safetensor"
72
  model4_f = path / "unet.safetensor"
73
  model5_f = path / "unetpp.safetensor"
74
+ model6_f = path / "c2r1km.safetensor"
75
 
76
  # Load model parameters
77
  model1_weights = safetensors.torch.load_file(model1_f)
 
79
  model3_weights = safetensors.torch.load_file(model3_f)
80
  model4_weights = safetensors.torch.load_file(model4_f)
81
  model5_weights = safetensors.torch.load_file(model5_f)
82
+ model6_weights = safetensors.torch.load_file(model6_f)
83
 
84
  # Load all models
85
 
 
133
  for param in model5.parameters():
134
  param.requires_grad = False
135
  model5 = model5.eval()
136
+
137
+ # Model 6 (C2R1KM)
138
+ model6 = load_model_module(path / "c2r1km.py").CloudMaskOne(
139
+ hidden_layer_sizes=(21, 20),
140
+ activation='relu',
141
+ last_activation='sigmoid',
142
+ dropout_rate=0.1,
143
+ input_dim=40,
144
+ batch_norm=False
145
+ )
146
+ model6.load_state_dict(model6_weights)
147
+ model6 = model6.to(device)
148
+ for param in model6.parameters():
149
+ param.requires_grad = False
150
+ model6 = model6.eval()
151
 
152
  # Create ensemble model
153
+ cloud_model = EnsembleModel(model1, model2, model3, model4, model5, model6, mode=mode)
154
 
155
  return cloud_model
156
 
157
+
158
+
159
  def display_results(path: pathlib.Path, device: str = "cpu", mode: Literal["min", "mean", "max"] ="max", *args, **kwargs):
160
  # Load model
161
  model = compiled_model(path, device, mode=mode)