File size: 1,792 Bytes
c668e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Module for dynamic data transfrom."""
import os
import importlib

from .transform import (
    make_transforms,
    get_specials,
    save_transforms,
    load_transforms,
    TransformPipe,
    Transform,
)


# Registry mapping a transform name to its class. It is filled by
# ``register_transform`` and read by ``get_transforms_cls``.
AVAILABLE_TRANSFORMS = dict()


def get_transforms_cls(transform_names):
    """Map each requested transform name to its registered class.

    Args:
        transform_names: iterable of names to look up in
            ``AVAILABLE_TRANSFORMS``; it is iterated exactly once.

    Returns:
        dict: ``{name: transform_class}`` in the order the names were given.

    Raises:
        ValueError: on the first name that has not been registered.
    """

    def _resolve(requested):
        # Guard-clause lookup so an unknown name fails with a clear message.
        if requested in AVAILABLE_TRANSFORMS:
            return AVAILABLE_TRANSFORMS[requested]
        raise ValueError("%s transform not supported!" % requested)

    return {requested: _resolve(requested) for requested in transform_names}


# Public API for ``from onmt.transforms import *``. Every entry must be a name
# actually bound in this module; a missing name makes the star-import raise
# AttributeError. ("prepare_transforms" was listed here but is neither
# imported from ``.transform`` nor defined in this module, so it was removed.)
__all__ = [
    "get_transforms_cls",
    "get_specials",
    "make_transforms",
    "load_transforms",
    "save_transforms",
    "TransformPipe",
]


def register_transform(name):
    """Return a class decorator that registers a transform under ``name``.

    The decorated class is stored in ``AVAILABLE_TRANSFORMS[name]``, which is
    where ``get_transforms_cls`` looks it up. The class itself is returned
    unchanged.

    Args:
        name: registry key for the decorated class.

    Returns:
        A decorator ``(cls) -> cls``.

    Raises (when the decorator is applied):
        ValueError: if ``name`` is already registered, or if the decorated
            object is not a subclass of ``Transform``.
    """

    def register_transform_cls(cls):
        if name in AVAILABLE_TRANSFORMS:
            raise ValueError("Cannot register duplicate transform ({})".format(name))
        # Check ``isinstance(cls, type)`` first: ``issubclass`` raises a bare
        # TypeError for non-class objects (e.g. a decorated function), which
        # would hide the intended "must extend Transform" message.
        if not (isinstance(cls, type) and issubclass(cls, Transform)):
            raise ValueError(
                "transform ({}: {}) must extend Transform".format(
                    name, getattr(cls, "__name__", repr(cls))
                )
            )
        AVAILABLE_TRANSFORMS[name] = cls
        return cls

    return register_transform_cls


# Auto-import every ``*.py`` file and every subdirectory in this package's
# directory. Presumably this is so that ``register_transform`` decorators in
# those submodules run and fill ``AVAILABLE_TRANSFORMS`` at import time
# (TODO confirm against the submodules). Names starting with "_" (e.g.
# ``__init__.py``, ``__pycache__``) or "." (hidden entries) are skipped.
# Entries are visited in sorted order so import/registration order does not
# depend on the filesystem's directory-listing order.
transform_dir = os.path.dirname(__file__)
for file in sorted(os.listdir(transform_dir)):
    path = os.path.join(transform_dir, file)
    if file.startswith(("_", ".")):
        continue
    if file.endswith(".py"):
        # Strip only the trailing ".py"; ``file.find(".py")`` would cut at the
        # first occurrence (e.g. "a.py_b.py" -> "a").
        file_name = file[: -len(".py")]
    elif os.path.isdir(path):
        file_name = file
    else:
        continue
    # Import relative to this package rather than a hard-coded dotted path,
    # so the loop keeps working if the package is relocated or vendored.
    importlib.import_module("." + file_name, package=__name__)