File size: 5,527 Bytes
c668e80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
from onmt.constants import DefaultTokens
from onmt.transforms import register_transform
from .transform import Transform
import copy
@register_transform(name="docify")
class DocifyTransform(Transform):
    """Convert source and target examples to doc-level segments.

    Consecutive segments are concatenated with ``DefaultTokens.SEP`` until a
    doc reaches ``--doc_length`` tokens or accumulates ``--max_context``
    separators. An empty example (both sides empty) marks a document break.
    """

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Add an option for the corpus ratio to apply this transform."""
        group = parser.add_argument_group("Transform/Docify")
        group.add(
            "--doc_length",
            "-doc_length",
            type=int,
            default=200,
            help="Number of tokens per doc.",
        )
        group.add(
            "--max_context",
            "-max_context",
            type=int,
            default=1,
            help="Max context segments.",
        )

    def _parse_opts(self):
        # stride = number of parallel consumers of the corpus
        # (num_workers * world_size); validated in warm_up so that doc
        # grouping stays aligned across workers.
        if hasattr(self.opts, "num_workers") and hasattr(self.opts, "world_size"):
            self.stride = self.opts.num_workers * self.opts.world_size
        else:
            self.stride = 1
        self.doc_length = self.opts.doc_length
        self.max_context = self.opts.max_context

    @classmethod
    def get_specials(cls, opts):
        """Add the separator tag to src and tgt vocabs."""
        src_specials, tgt_specials = [DefaultTokens.SEP], [DefaultTokens.SEP]
        return (src_specials, tgt_specials)

    def warm_up(self, vocabs=None):
        super().warm_up(None)
        if self.stride != 1:
            # A doc spans up to (max_context + 1) consecutive segments, so the
            # shard stride must divide evenly into whole docs.
            assert self.stride % (self.max_context + 1) == 0, (
                "num_workers * world_size must be a multiple "
                "of (max_context + 1)"
            )

    def batch_apply(self, batch, is_train=False, stats=None, **kwargs):
        """Convert source and target examples to doc level segments.

        Args:
            batch: list of ``(example, transform, corpus_id)`` tuples; each
                example is a dict with at least ``src``, ``tgt``, ``indices``
                (plus ``src_original`` / ``tgt_original`` when cumulating).

        Returns:
            List of ``(doc_example, self, corpus_id)`` tuples, where each doc
            groups up to ``max_context + 1`` segments joined by SEP tokens.
        """
        if self.max_context == 0:
            # Context disabled: pass the batch through untouched.
            return batch
        trf_batch = []
        doc = {}
        doc["src"] = []
        doc["tgt"] = []
        doc["indices"] = 0
        for ex, _, cid in batch:
            if ex["tgt"] is not None:
                # Length the doc would reach after appending this example.
                # len(a) + len(b) avoids building throwaway concatenated
                # lists just to measure them.
                cur_len = max(
                    len(doc["src"]) + len(ex["src"]),
                    len(doc["tgt"]) + len(ex["tgt"]),
                )
                if len(ex["src"]) == 0 and len(ex["tgt"]) == 0:
                    # doc break: we add the current doc, restart a new one
                    trf_batch.append((doc, self, cid))
                    doc = {}
                    doc["src"] = []
                    doc["tgt"] = []
                    doc["indices"] = ex["indices"]
                elif cur_len > self.doc_length:
                    if len(doc["src"]) == 0:
                        # case 1st ex is already longer than doc_length
                        trf_batch.append((ex, self, cid))
                    else:
                        # adding cur ex is too long: we add cur doc
                        # and reset doc to cur ex
                        trf_batch.append((doc, self, cid))
                        doc = copy.deepcopy(ex)
                else:
                    if len(doc["src"]) == 0:
                        # we start the new doc with cur ex
                        doc = copy.deepcopy(ex)
                    else:
                        # we cumulate cur ex to cur doc
                        doc["src"] += [DefaultTokens.SEP] + ex["src"]
                        doc["src_original"] += [DefaultTokens.SEP] + ex["src_original"]
                        doc["tgt"] += [DefaultTokens.SEP] + ex["tgt"]
                        doc["tgt_original"] += [DefaultTokens.SEP] + ex["tgt_original"]
                    nb_ctx = doc["src"].count(DefaultTokens.SEP)
                    if nb_ctx >= self.max_context:
                        # doc reached its max context size: flush it
                        trf_batch.append((doc, self, cid))
                        doc = {}
                        doc["src"] = []
                        doc["tgt"] = []
                        doc["indices"] = ex["indices"]
            else:
                # Inference-style example: no target side.
                cur_len = len(doc["src"]) + len(ex["src"])
                doc["tgt"] = None
                if len(ex["src"]) == 0:
                    # doc break
                    trf_batch.append((doc, self, cid))
                    doc = {}
                    doc["src"] = []
                    doc["indices"] = ex["indices"]
                elif cur_len > self.doc_length:
                    if len(doc["src"]) == 0:
                        trf_batch.append((ex, self, cid))
                    else:
                        trf_batch.append((doc, self, cid))
                        doc = copy.deepcopy(ex)
                else:
                    if len(doc["src"]) == 0:
                        doc = copy.deepcopy(ex)
                    else:
                        doc["src"] += [DefaultTokens.SEP] + ex["src"]
                        doc["src_original"] += [DefaultTokens.SEP] + ex["src_original"]
                    nb_ctx = doc["src"].count(DefaultTokens.SEP)
                    if nb_ctx >= self.max_context:
                        trf_batch.append((doc, self, cid))
                        doc = {}
                        doc["src"] = []
                        doc["indices"] = ex["indices"]
        if len(doc["src"]) > 0:
            # flush the trailing, partially-filled doc
            trf_batch.append((doc, self, cid))
        return trf_batch

    def apply_reverse(self, translated):
        """Split a translated doc string back into per-segment strings."""
        segments = translated.split(DefaultTokens.SEP)
        segments = [segment.strip(" ") for segment in segments]
        return segments
|