File size: 1,864 Bytes
c668e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from onmt.transforms import register_transform
from .transform import Transform
from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer


@register_transform(name="inferfeats")
class InferFeatsTransform(Transform):
    """Re-index source features so there is one value per source token.

    ``example["src_feats"]`` is treated as a collection of feature
    sequences. Each sequence is indexed with the ids of a word-to-subword
    mapping, so every token in ``example["src"]`` receives the feature
    value found at its mapped index.
    """

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Register the options understood by this transform on ``parser``."""
        infer_group = parser.add_argument_group("Transform/InferFeats")
        infer_group.add(
            "--reversible_tokenization",
            "-reversible_tokenization",
            default="joiner",
            choices=["joiner", "spacer"],
            help="Type of reversible tokenization applied on the tokenizer.",
        )

    def _parse_opts(self):
        """Copy the tokenization mode from the options onto the instance."""
        super()._parse_opts()
        self.reversible_tokenization = self.opts.reversible_tokenization

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Replace ``example["src_feats"]`` with per-token feature values.

        Args:
            example: mapping read for ``"src"``, ``"src_feats"`` and, in
                ``"joiner"`` mode, ``"src_original"``.
            is_train: accepted for interface compatibility; not used here.
            stats: accepted for interface compatibility; not used here.
            **kwargs: ignored.

        Returns:
            The same ``example`` object. It is returned untouched when it
            has no ``"src_feats"`` key; otherwise its ``"src_feats"`` entry
            is replaced by a new list of lists.
        """
        if "src_feats" not in example:
            # Nothing to infer for examples that carry no features.
            return example

        if self.reversible_tokenization == "joiner":
            untokenized_src = example["src_original"]
            subword_word_ids = subword_map_by_joiner(
                example["src"], original_subwords=untokenized_src
            )
        else:
            # Any non-"joiner" value is handled as spacer mode.
            subword_word_ids = subword_map_by_spacer(example["src"])

        # Pair ids with the tokens so that the result is cut to the shorter
        # of the two sequences; the token text itself is not needed.
        aligned_ids = [word_id for _, word_id in zip(example["src"], subword_word_ids)]

        example["src_feats"] = [
            [word_level_values[word_id] for word_id in aligned_ids]
            for word_level_values in example["src_feats"]
        ]
        return example

    def _repr_args(self):
        """Return the argument string used in the repr (always empty)."""
        return ""