|
|
|
|
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
|
from builtins import *
|
|
|
import sys
|
|
|
import torch
|
|
|
import re
|
|
|
|
|
|
from importlib.util import spec_from_file_location, module_from_spec
|
|
|
from pathlib import Path
|
|
|
from typing import Union
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level references that data_parallel_workaround() refreshes on every
# call. NOTE(review): presumably held to keep the per-device outputs/replicas
# alive past the call as a workaround for a PyTorch DataParallel memory issue
# (hence the function's name) — confirm against project history before removing.
_output_ref = None
_replicas_ref = None
|
|
|
|
|
|
def data_parallel_workaround(model, *input):
    """Run `model` on `input` across all visible CUDA devices, gathering the
    result onto device 0.

    Manual re-implementation of torch.nn.DataParallel's forward pass
    (replicate -> scatter -> parallel_apply -> gather). The intermediate
    `outputs` and `replicas` are stashed in module-level globals — presumably
    to keep them alive after the call as a workaround for a PyTorch
    memory-release issue; confirm before removing.

    NOTE(review): requires at least one CUDA device — `device_ids[0]` raises
    IndexError when torch.cuda.device_count() == 0.
    """
    global _output_ref
    global _replicas_ref
    # Use every visible CUDA device; results are gathered onto the first one.
    device_ids = list(range(torch.cuda.device_count()))
    output_device = device_ids[0]
    replicas = torch.nn.parallel.replicate(model, device_ids)
    # Scatter the input tuple across devices. Scatter can produce fewer
    # chunks than devices (small batches), so trim the replicas to match.
    inputs = torch.nn.parallel.scatter(input, device_ids)
    replicas = replicas[:len(inputs)]
    outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
    y_hat = torch.nn.parallel.gather(outputs, output_device)
    # Keep references alive at module scope (see docstring).
    _output_ref = outputs
    _replicas_ref = replicas
    return y_hat
|
|
|
|
|
|
|
|
|
|
|
|
class __HParams:
    """Manages the hyperparams pseudo-module.

    Attributes of a user-supplied python file are copied onto this object by
    `configure()`, after which client code reads them as plain attributes.
    Accessing any attribute before configuration raises AttributeError.
    """

    def __init__(self, path: Union[str, Path] = None):
        """Constructs the hyperparameters from a path to a python module.

        If `path` is None, will raise an AttributeError whenever its
        attributes are accessed. Otherwise, configures self based on `path`.
        """
        # Set the flag first so configure() (and __getattr__) never sees a
        # half-initialized instance. The original set it only in the
        # path-is-None branch, so __HParams(path) recursed infinitely:
        # configure -> is_configured -> missing _configured -> __getattr__
        # -> is_configured -> ...
        self._configured = False
        if path is not None:
            self.configure(path)

    def __getattr__(self, item):
        """Called only when normal attribute lookup fails."""
        # Guard against recursion for instances created without __init__
        # (e.g. copy/pickle protocols probing for attributes).
        if item == "_configured":
            raise AttributeError(item)
        if not self.is_configured():
            raise AttributeError("HParams not configured yet. Call self.configure()")
        # BUG FIX: `object` has no __getattr__, so the original
        # `super().__getattr__(item)` raised a confusing AttributeError about
        # 'super'. Raise a proper error for the missing item instead.
        raise AttributeError(item)

    def configure(self, path: Union[str, Path]):
        """Configures hparams by copying over attributes from a module with the
        given path. Raises an exception if already configured."""
        if self.is_configured():
            raise RuntimeError("Cannot reconfigure hparams!")

        # Normalize to an expanded Path. (The original only expanded `~` for
        # str inputs; doing it unconditionally is backward-compatible.)
        path = Path(path).expanduser()
        if not path.exists():
            raise FileNotFoundError(f"Could not find hparams file {path}")
        elif path.suffix != ".py":
            raise ValueError("`path` must be a python file")

        m = _import_from_file("hparams", path)

        # Copy everything except dunders (__name__, __file__, ...), refusing
        # to clobber anything already defined on this object.
        reg = re.compile(r"^__.+__$")
        for name, value in m.__dict__.items():
            if reg.match(name):
                continue
            if name in self.__dict__:
                raise AttributeError(
                    f"module at `path` cannot contain attribute {name} as it "
                    "overwrites an attribute of the same name in utils.hparams")
            self.__setattr__(name, value)

        self._configured = True

    def is_configured(self):
        """Returns True once configure() has completed successfully."""
        return self._configured
|
|
|
|
|
|
# Module-level singleton: import `hparams` elsewhere and call
# `hparams.configure(path)` once at startup before reading attributes.
hparams = __HParams()
|
|
|
|
|
|
|
|
|
def _import_from_file(name, path: Path):
    """Load and return the python module found at `path`, registered under
    the module name `name`.

    Raises FileNotFoundError when the file is absent and ValueError when no
    import spec can be built for it.
    """
    if not Path(path).exists():
        raise FileNotFoundError('"%s" doesn\'t exist!' % path)
    module_spec = spec_from_file_location(name, path)
    if module_spec is None:
        raise ValueError('could not load module from "%s"' % path)
    module = module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return module