from onmt.constants import DefaultTokens
from onmt.transforms import register_transform
from .transform import Transform
import copy
@register_transform(name="docify")
class DocifyTransform(Transform):
    """
    Convert source and target examples to doc level segments.
    It concatenates segments with a DefaultTokens.SEP
    until it reaches --doc_length tokens (or --max_context segments).
    """
    def __init__(self, opts):
        super().__init__(opts)
    @classmethod
    def add_options(cls, parser):
        """Add the --doc_length and --max_context options for this transform."""
        group = parser.add_argument_group("Transform/Docify")
        group.add(
            "--doc_length",
            "-doc_length",
            type=int,
            default=200,
            help="Number of tokens per doc.",
        )
        group.add(
            "--max_context",
            "-max_context",
            type=int,
            default=1,
            help="Max context segments.",
        )
    def _parse_opts(self):
        # Cache the options this transform reads.
        # `stride` = number of parallel consumers of the example stream;
        # only used by warm_up() to validate doc-grouping alignment.
        if hasattr(self.opts, "num_workers") and hasattr(self.opts, "world_size"):
            self.stride = self.opts.num_workers * self.opts.world_size
        else:
            # Single-process case (e.g. inference opts): no sharding.
            self.stride = 1
        self.doc_length = self.opts.doc_length
        self.max_context = self.opts.max_context
    @classmethod
    def get_specials(cls, opts):
        """Add newline tag to src and tgt vocabs."""
        # The SEP token is used below to join segments, so both vocabs
        # must declare it as a special token.
        src_specials, tgt_specials = [DefaultTokens.SEP], [DefaultTokens.SEP]
        return (src_specials, tgt_specials)
    def warm_up(self, vocabs=None):
        # Vocabs are deliberately not forwarded: this transform only needs
        # the SEP special declared in get_specials().
        super().warm_up(None)
        if self.stride != 1:
            # Each emitted doc consumes up to (max_context + 1) consecutive
            # segments; the divisibility check keeps doc grouping aligned
            # when the stream is sharded across workers.
            # NOTE(review): the message reads inverted w.r.t. the check
            # (stride % (max_context+1) == 0 means stride is the multiple).
            assert (
                self.stride % (self.max_context + 1) == 0
            ), "(max_context+1) must be a multiple \
of num_workers * world_size"
    def batch_apply(self, batch, is_train=False, stats=None, **kwargs):
        """Convert source and target examples to doc level segments.

        `batch` is an iterable of (example, transform, corpus_id) triples;
        examples are dicts with at least "src", "tgt", "indices" (and the
        "*_original" token lists when segments get concatenated).
        Returns a new list of (doc_example, self, cid) triples.
        """
        if self.max_context == 0:
            # Context disabled: pass the batch through untouched.
            return batch
        trf_batch = []
        # `doc` accumulates consecutive segments until a doc break, the
        # token budget, or the context limit forces a flush.
        doc = {}
        doc["src"] = []
        doc["tgt"] = []
        doc["indices"] = 0
        for ex, _, cid in batch:
            if ex["tgt"] is not None:
                # Paired (src, tgt) example (training path).
                # Length the doc would have if this segment were appended
                # (len(a + b) == len(a) + len(b)).
                cur_len = max(len(doc["src"] + ex["src"]), len(doc["tgt"] + ex["tgt"]))
                if len(ex["src"]) == 0 and len(ex["tgt"]) == 0:
                    # doc break we add it, restart new doc
                    # NOTE(review): on consecutive doc breaks `doc` is empty
                    # here and an empty example is emitted — confirm that
                    # downstream filtering drops these.
                    trf_batch.append((doc, self, cid))
                    doc = {}
                    doc["src"] = []
                    doc["tgt"] = []
                    doc["indices"] = ex["indices"]
                elif cur_len > self.doc_length:
                    if len(doc["src"]) == 0:
                        # case 1st ex is already longer
                        trf_batch.append((ex, self, cid))
                    else:
                        # adding cur ex is too long we add cur doc
                        # and reset doc to cur ex
                        trf_batch.append((doc, self, cid))
                        doc = copy.deepcopy(ex)
                else:
                    if len(doc["src"]) == 0:
                        # we start the new doc with cur ex
                        # (deepcopy so later += never mutates the input ex)
                        doc = copy.deepcopy(ex)
                    else:
                        # we cumulate cur ex to cur doc
                        doc["src"] += [DefaultTokens.SEP] + ex["src"]
                        doc["src_original"] += [DefaultTokens.SEP] + ex["src_original"]
                        doc["tgt"] += [DefaultTokens.SEP] + ex["tgt"]
                        doc["tgt_original"] += [DefaultTokens.SEP] + ex["tgt_original"]
                    # One SEP per joined segment, so the SEP count equals
                    # the number of context segments in the doc.
                    nb_ctx = doc["src"].count(DefaultTokens.SEP)
                    if nb_ctx >= self.max_context:
                        trf_batch.append((doc, self, cid))
                        doc = {}
                        doc["src"] = []
                        doc["tgt"] = []
                        doc["indices"] = ex["indices"]
            else:
                # Source-only example (e.g. inference): same logic on the
                # src side only; doc["tgt"] is forced to None.
                # NOTE(review): a stream mixing tgt-present and tgt-absent
                # examples would later hit `None + list` above — presumably
                # a batch is homogeneous; confirm against the caller.
                cur_len = len(doc["src"] + ex["src"])
                doc["tgt"] = None
                if len(ex["src"]) == 0:
                    # Empty src segment = doc break.
                    trf_batch.append((doc, self, cid))
                    doc = {}
                    doc["src"] = []
                    doc["indices"] = ex["indices"]
                elif cur_len > self.doc_length:
                    if len(doc["src"]) == 0:
                        # First segment already exceeds the budget: emit as-is.
                        trf_batch.append((ex, self, cid))
                    else:
                        # Flush the current doc, restart from this segment.
                        trf_batch.append((doc, self, cid))
                        doc = copy.deepcopy(ex)
                else:
                    if len(doc["src"]) == 0:
                        doc = copy.deepcopy(ex)
                    else:
                        doc["src"] += [DefaultTokens.SEP] + ex["src"]
                        doc["src_original"] += [DefaultTokens.SEP] + ex["src_original"]
                    nb_ctx = doc["src"].count(DefaultTokens.SEP)
                    if nb_ctx >= self.max_context:
                        # Context limit reached: flush and reset.
                        trf_batch.append((doc, self, cid))
                        doc = {}
                        doc["src"] = []
                        doc["indices"] = ex["indices"]
        if len(doc["src"]) > 0:
            # Flush the final, partially filled doc (`cid` is the last
            # corpus id seen; only reachable when the batch was non-empty).
            trf_batch.append((doc, self, cid))
        return trf_batch
    def apply_reverse(self, translated):
        """Split a translated doc string back into per-segment strings."""
        segments = translated.split(DefaultTokens.SEP)
        # Strip the spaces that surrounded the SEP joints.
        segments = [segment.strip(" ") for segment in segments]
        return segments