vermouthdky committed on
Commit
56ccb13
·
verified ·
1 Parent(s): 7c50ccf

Upload nets.py

Browse files
Files changed (1) hide show
  1. nets.py +30 -13
nets.py CHANGED
@@ -14,6 +14,8 @@
14
 
15
  """Deep networks."""
16
 
 
 
17
  import numpy as np
18
  import torch
19
  import torch.nn.functional as F
@@ -131,10 +133,14 @@ class EnsembleFC(nn.Module):
131
  return torch.add(wx, self.bias[:, None, None, :]) # w times x + b
132
 
133
 
134
- class EnsembleModel(nn.Module):
 
 
 
 
135
  def __init__(self, encoding_dim, num_ensemble, hidden_dim=128, activation="relu", dtype=torch.float32) -> None:
136
  # super().__init__(encoding_dim, hidden_dim, activation)
137
- super(EnsembleModel, self).__init__()
138
  self.num_ensemble = num_ensemble
139
  self.hidden_dim = hidden_dim
140
  self.output_dim = 1
@@ -152,23 +158,34 @@ class EnsembleModel(nn.Module):
152
  else:
153
  raise ValueError(f"Unknown activation {activation}")
154
 
155
- def get_params(self) -> torch.Tensor:
156
- params = []
157
- for pp in list(self.parameters()):
158
- params.append(pp.view(-1))
159
- return torch.cat(params)
160
-
161
  def forward(self, encoding: torch.Tensor) -> torch.Tensor:
162
  x = self.activation(self.nn1(encoding))
163
  x = self.activation(self.nn2(x))
164
  score = self.nn_out(x)
165
  return score
166
 
167
- def init(self):
168
- self.init_params = self.get_params().data.clone()
169
- if torch.cuda.is_available():
170
- self.init_params = self.init_params.cuda()
171
-
172
  def regularization(self):
173
  """Prior towards independent initialization."""
174
  return ((self.get_params() - self.init_params) ** 2).mean()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  """Deep networks."""
16
 
17
+ from copy import deepcopy
18
+
19
  import numpy as np
20
  import torch
21
  import torch.nn.functional as F
 
133
  return torch.add(wx, self.bias[:, None, None, :]) # w times x + b
134
 
135
 
136
def get_params(model):
    """Return all parameters of ``model`` flattened and concatenated into one 1-D tensor."""
    pieces = []
    for param in model.parameters():
        pieces.append(param.view(-1))
    return torch.cat(pieces)
138
+
139
+
140
+ class _EnsembleModel(nn.Module):
141
  def __init__(self, encoding_dim, num_ensemble, hidden_dim=128, activation="relu", dtype=torch.float32) -> None:
142
  # super().__init__(encoding_dim, hidden_dim, activation)
143
+ super(_EnsembleModel, self).__init__()
144
  self.num_ensemble = num_ensemble
145
  self.hidden_dim = hidden_dim
146
  self.output_dim = 1
 
158
  else:
159
  raise ValueError(f"Unknown activation {activation}")
160
 
 
 
 
 
 
 
161
  def forward(self, encoding: torch.Tensor) -> torch.Tensor:
162
  x = self.activation(self.nn1(encoding))
163
  x = self.activation(self.nn2(x))
164
  score = self.nn_out(x)
165
  return score
166
 
 
 
 
 
 
167
    def regularization(self):
        """Prior towards independent initialization."""
        # NOTE(review): neither `get_params` nor `init_params` is defined on this
        # class in this version of the file — `get_params` is now a module-level
        # function and the old `init()`/`get_params()` methods were removed in this
        # commit — so calling this method raises AttributeError. It appears to be
        # superseded by EnsembleModel.regularization; confirm and remove.
        return ((self.get_params() - self.init_params) ** 2).mean()
170
+
171
+
172
class EnsembleModel(nn.Module):
    """Ensemble scorer: a trainable network paired with a frozen copy of its
    freshly initialized weights, which anchors the regularization prior."""

    def __init__(self, encoding_dim, num_ensemble, hidden_dim=128, activation="relu", dtype=torch.float32) -> None:
        super(EnsembleModel, self).__init__()
        self.encoding_dim = encoding_dim
        self.num_ensemble = num_ensemble
        self.hidden_dim = hidden_dim
        self.model = _EnsembleModel(encoding_dim, num_ensemble, hidden_dim, activation, dtype)
        # Frozen snapshot of the initialization; only used for regularization.
        self.reg_model = deepcopy(self.model)
        for frozen in self.reg_model.parameters():
            frozen.requires_grad = False

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        """Delegate scoring to the trainable sub-model."""
        return self.model(encoding)

    def regularization(self):
        """Prior towards independent initialization."""
        current = get_params(self.model)
        anchor = get_params(self.reg_model).detach()
        return ((current - anchor) ** 2).mean()