File size: 4,248 Bytes
2e62044
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Make it explicit that we do it the Python 3 way
from __future__ import absolute_import, division, print_function, unicode_literals
from builtins import *
import sys
import torch
import re

from importlib.util import spec_from_file_location, module_from_spec
from pathlib import Path
from typing import Union

# Credit: Ryuichi Yamamoto (https://github.com/r9y9/wavenet_vocoder/blob/1717f145c8f8c0f3f85ccdf346b5209fa2e1c920/train.py#L599)
# Modified by: Ryan Butler (https://github.com/TheButlah)
# workaround for https://github.com/pytorch/pytorch/issues/15716
# the idea is to return outputs and replicas explicitly, so that making pytorch
# not to release the nodes (this is a pytorch bug though)

# Module-level refs that keep the most recent replicas/outputs alive between
# calls -- this retention is the whole point of the workaround (see the
# pytorch issue referenced above).
_output_ref = None
_replicas_ref = None

def data_parallel_workaround(model, *input):
    """Runs `model` data-parallel across all visible CUDA devices.

    Replicates `model` onto every CUDA device, scatters the positional args
    across them, applies the replicas in parallel, and gathers the outputs
    back onto device 0, which is returned as `y_hat`.

    NOTE(review): `input` shadows the `input` builtin; it is the tuple of
    positional arguments forwarded to each replica.
    NOTE(review): assumes at least one CUDA device is visible --
    `device_ids[0]` would raise IndexError otherwise; confirm callers only
    invoke this when CUDA is available.
    """
    global _output_ref
    global _replicas_ref
    device_ids = list(range(torch.cuda.device_count()))
    output_device = device_ids[0]
    replicas = torch.nn.parallel.replicate(model, device_ids)
    # input.shape = (num_args, batch, ...)
    inputs = torch.nn.parallel.scatter(input, device_ids)
    # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
    # scatter may produce fewer chunks than devices (small batches), so trim
    # the replica list to match.
    replicas = replicas[:len(inputs)]
    outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
    y_hat = torch.nn.parallel.gather(outputs, output_device)
    # Stash references in module globals so pytorch does not release the
    # graph nodes prematurely (workaround for pytorch issue #15716).
    _output_ref = outputs
    _replicas_ref = replicas
    return y_hat


###### Deal with hparams import that has to be configured at runtime ######
class __HParams:
    """Manages the hyperparams pseudo-module"""
    def __init__(self, path: Union[str, Path]=None):
        """Constructs the hyperparameters from a path to a python module. If

        `path` is None, will raise an AttributeError whenever its attributes

        are accessed. Otherwise, configures self based on `path`."""
        if path is None:
            print("path is none")
            self._configured = False
        else:
            self.configure(path)

    def __getattr__(self, item):
        print("self config2222:",self.is_configured())
        if not self.is_configured():
            raise AttributeError("HParams not configured yet. Call self.configure()")
        else:
            return super().__getattr__(item)

    def configure(self, path: Union[str, Path]):
        """Configures hparams by copying over atrributes from a module with the

        given path. Raises an exception if already configured."""
        if self.is_configured():
            raise RuntimeError("Cannot reconfigure hparams!")
        print("path=",path)
        ###### Check for proper path ######
        if not isinstance(path, Path):
            path = Path(path).expanduser()
        if not path.exists():
            raise FileNotFoundError(f"Could not find hparams file {path}")
        elif path.suffix != ".py":
            raise ValueError("`path` must be a python file")

        ###### Load in attributes from module ######
        m = _import_from_file("hparams", path)

        reg = re.compile(r"^__.+__$")  # Matches magic methods
        for name, value in m.__dict__.items():
            if reg.match(name):
                # Skip builtins
                continue
            if name in self.__dict__:
                # Cannot overwrite already existing attributes
                raise AttributeError(
                    f"module at `path` cannot contain attribute {name} as it "
                    "overwrites an attribute of the same name in utils.hparams")
            # Fair game to copy over the attribute
            self.__setattr__(name, value)

        self._configured = True
        print("self config1111:",self._configured)

    def is_configured(self):
        return self._configured

hparams = __HParams()


def _import_from_file(name, path: Path):
    """Programmatically returns a module object from a filepath"""
    if not Path(path).exists():
        raise FileNotFoundError('"%s" doesn\'t exist!' % path)
    spec = spec_from_file_location(name, path)
    if spec is None:
        raise ValueError('could not load module from "%s"' % path)
    m = module_from_spec(spec)
    spec.loader.exec_module(m)
    return m