SavirD commited on
Commit
d9452da
·
verified ·
1 Parent(s): ab2a5b9

Upload folder using huggingface_hub

Browse files
server/meta_optimizer_environment.py CHANGED
@@ -33,6 +33,11 @@ BATCH_SIZE = 32
33
  DENSE_REWARD_SCALE = 0.2
34
 
35
 
 
 
 
 
 
36
  def _build_model(spec: TaskSpec) -> nn.Module:
37
  """Build a 2-layer MLP for the given task spec."""
38
  torch.manual_seed(spec.arch_seed)
@@ -71,7 +76,7 @@ def run_adam_baseline(
71
  raise ValueError("Provide exactly one of task_id or task_spec")
72
  if seed is not None:
73
  torch.manual_seed(seed)
74
- device = torch.device("cpu")
75
  spec = task_spec_from_dict(task_spec) if task_spec is not None else get_task(task_id)
76
  model = _build_model(spec).to(device)
77
  opt = torch.optim.Adam(model.parameters(), lr=lr)
@@ -121,7 +126,7 @@ def run_sgd_baseline(
121
  raise ValueError("Provide exactly one of task_id or task_spec")
122
  if seed is not None:
123
  torch.manual_seed(seed)
124
- device = torch.device("cpu")
125
  spec = task_spec_from_dict(task_spec) if task_spec is not None else get_task(task_id)
126
  model = _build_model(spec).to(device)
127
  opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
@@ -218,7 +223,7 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
218
  super().__init__(**kwargs)
219
  self.loss_threshold = loss_threshold
220
  self.max_steps = max_steps
221
- self._device = torch.device("cpu")
222
 
223
  # Episode state (set in reset)
224
  self._task_spec: Optional[TaskSpec] = None
 
33
  DENSE_REWARD_SCALE = 0.2
34
 
35
 
36
+ def _default_device() -> torch.device:
37
+ """Use CUDA when available, otherwise CPU."""
38
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+
40
+
41
  def _build_model(spec: TaskSpec) -> nn.Module:
42
  """Build a 2-layer MLP for the given task spec."""
43
  torch.manual_seed(spec.arch_seed)
 
76
  raise ValueError("Provide exactly one of task_id or task_spec")
77
  if seed is not None:
78
  torch.manual_seed(seed)
79
+ device = _default_device()
80
  spec = task_spec_from_dict(task_spec) if task_spec is not None else get_task(task_id)
81
  model = _build_model(spec).to(device)
82
  opt = torch.optim.Adam(model.parameters(), lr=lr)
 
126
  raise ValueError("Provide exactly one of task_id or task_spec")
127
  if seed is not None:
128
  torch.manual_seed(seed)
129
+ device = _default_device()
130
  spec = task_spec_from_dict(task_spec) if task_spec is not None else get_task(task_id)
131
  model = _build_model(spec).to(device)
132
  opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
 
223
  super().__init__(**kwargs)
224
  self.loss_threshold = loss_threshold
225
  self.max_steps = max_steps
226
+ self._device = _default_device()
227
 
228
  # Episode state (set in reset)
229
  self._task_spec: Optional[TaskSpec] = None