# Make it explicit that we do it the Python 3 way
from __future__ import absolute_import, division, print_function, unicode_literals
from builtins import *
import torch
import re
from importlib.util import spec_from_file_location, module_from_spec
from pathlib import Path
from typing import Optional, Union
# Credit: Ryuichi Yamamoto (https://github.com/r9y9/wavenet_vocoder/blob/1717f145c8f8c0f3f85ccdf346b5209fa2e1c920/train.py#L599)
# Modified by: Ryan Butler (https://github.com/TheButlah)
# workaround for https://github.com/pytorch/pytorch/issues/15716
# the idea is to keep explicit references to the outputs and replicas so that
# PyTorch does not release the graph nodes (this is a PyTorch bug)
_output_ref = None
_replicas_ref = None
def data_parallel_workaround(model, *input):
global _output_ref
global _replicas_ref
device_ids = list(range(torch.cuda.device_count()))
output_device = device_ids[0]
replicas = torch.nn.parallel.replicate(model, device_ids)
# input.shape = (num_args, batch, ...)
inputs = torch.nn.parallel.scatter(input, device_ids)
# inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
replicas = replicas[:len(inputs)]
outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
y_hat = torch.nn.parallel.gather(outputs, output_device)
_output_ref = outputs
_replicas_ref = replicas
return y_hat
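# Example usage (a sketch, not part of the original file; assumes at least one
# visible CUDA device and a model whose forward takes one batched tensor):
#
#   model = MyModel().to("cuda:0")                 # hypothetical nn.Module
#   x = torch.randn(32, 80, 200, device="cuda:0")  # hypothetical input batch
#   y_hat = data_parallel_workaround(model, x)     # same output as model(x),
#                                                  # batch split across GPUs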
###### Deal with hparams import that has to be configured at runtime ######
class __HParams:
"""Manages the hyperparams pseudo-module"""
    def __init__(self, path: Optional[Union[str, Path]] = None):
        """Constructs the hyperparameters from a path to a python module. If
        `path` is None, raises an AttributeError whenever its attributes are
        accessed. Otherwise, configures self based on `path`."""
        # Set the flag up front so that attribute lookups never recurse
        # through __getattr__ before configure() has finished.
        self._configured = False
        if path is not None:
            self.configure(path)
    def __getattr__(self, item):
        # Only called when normal attribute lookup fails, i.e. `item` was
        # never copied over from the hparams module.
        if not self.is_configured():
            raise AttributeError("HParams not configured yet. Call self.configure()")
        raise AttributeError(item)
    def configure(self, path: Union[str, Path]):
        """Configures hparams by copying over attributes from a module at the
        given path. Raises an exception if already configured."""
        if self.is_configured():
            raise RuntimeError("Cannot reconfigure hparams!")
        ###### Check for proper path ######
        path = Path(path).expanduser()
        if not path.exists():
            raise FileNotFoundError(f"Could not find hparams file {path}")
        elif path.suffix != ".py":
            raise ValueError("`path` must be a python file")
        ###### Load in attributes from module ######
        m = _import_from_file("hparams", path)
        reg = re.compile(r"^__.+__$")  # Matches dunder names
        for name, value in m.__dict__.items():
            if reg.match(name):
                # Skip module dunders such as __name__ and __builtins__
                continue
            if name in self.__dict__:
                # Cannot overwrite already existing attributes
                raise AttributeError(
                    f"module at `path` cannot contain attribute {name} as it "
                    "overwrites an attribute of the same name in utils.hparams")
            # Fair game to copy over the attribute
            self.__setattr__(name, value)
        self._configured = True
def is_configured(self):
return self._configured
hparams = __HParams()
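# Example usage (a sketch; "my_hparams.py" is a hypothetical file defining
# plain module-level attributes such as `sample_rate = 16000`; the import
# path below is taken from the utils.hparams reference in the error above):
#
#   from utils.hparams import hparams
#   hparams.configure("my_hparams.py")
#   print(hparams.sample_rate)       # -> 16000
#   hparams.configure("other.py")    # raises RuntimeError: already configured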
def _import_from_file(name, path: Path):
    """Programmatically returns a module object from a filepath"""
    if not Path(path).exists():
        raise FileNotFoundError(f'"{path}" doesn\'t exist!')
    spec = spec_from_file_location(name, path)
    if spec is None:
        raise ValueError(f'could not load module from "{path}"')
    m = module_from_spec(spec)
    spec.loader.exec_module(m)
    return m
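# Example (a sketch; "conf.py" is a hypothetical file containing `lr = 1e-3`):
#
#   m = _import_from_file("conf", Path("conf.py"))
#   print(m.lr)  # -> 0.001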