File size: 10,117 Bytes
c668e80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 |
from onmt.utils.logging import logger
from onmt.transforms import register_transform
from .transform import Transform, ObservableStats
class FilterTooLongStats(ObservableStats):
"""Runing statistics for FilterTooLongTransform."""
__slots__ = ["filtered"]
def __init__(self):
self.filtered = 1
def update(self, other: "FilterTooLongStats"):
self.filtered += other.filtered
@register_transform(name="filtertoolong")
class FilterTooLongTransform(Transform):
"""Filter out sentence that are too long."""
def __init__(self, opts):
super().__init__(opts)
@classmethod
def add_options(cls, parser):
"""
Available options relate to this Transform.
For performance it is better to use multiple of 8
On target side, since we'll add BOS/EOS, we filter with minus 2
"""
group = parser.add_argument_group("Transform/Filter")
group.add(
"--src_seq_length",
"-src_seq_length",
type=int,
default=192,
help="Maximum source sequence length.",
)
group.add(
"--tgt_seq_length",
"-tgt_seq_length",
type=int,
default=192,
help="Maximum target sequence length.",
)
def _parse_opts(self):
self.src_seq_length = self.opts.src_seq_length
self.tgt_seq_length = self.opts.tgt_seq_length
def apply(self, example, is_train=False, stats=None, **kwargs):
"""Return None if too long else return as is."""
if (
len(example["src"]) > self.src_seq_length
or len(example["tgt"]) > self.tgt_seq_length - 2
):
if stats is not None:
stats.update(FilterTooLongStats())
return None
else:
return example
def _repr_args(self):
"""Return str represent key arguments for class."""
return "{}={}, {}={}".format(
"src_seq_length", self.src_seq_length, "tgt_seq_length", self.tgt_seq_length
)
@register_transform(name="prefix")
class PrefixTransform(Transform):
"""Add Prefix to src (& tgt) sentence."""
def __init__(self, opts):
super().__init__(opts)
@classmethod
def add_options(cls, parser):
"""Avalailable options relate to this Transform."""
group = parser.add_argument_group("Transform/Prefix")
group.add(
"--src_prefix",
"-src_prefix",
type=str,
default="",
help="String to prepend to all source example.",
)
group.add(
"--tgt_prefix",
"-tgt_prefix",
type=str,
default="",
help="String to prepend to all target example.",
)
@staticmethod
def _get_prefix(corpus):
"""Get prefix string of a `corpus`."""
if "prefix" in corpus["transforms"]:
src_prefix = corpus.get("src_prefix", "")
tgt_prefix = corpus.get("tgt_prefix", "")
prefix = {"src": src_prefix, "tgt": tgt_prefix}
else:
prefix = None
return prefix
@classmethod
def get_prefix_dict(cls, opts):
"""Get all needed prefix correspond to corpus in `opts`."""
prefix_dict = {}
# prefix src/tgt for each dataset
if hasattr(opts, "data"):
for c_name, corpus in opts.data.items():
prefix = cls._get_prefix(corpus)
if prefix is not None:
logger.debug(f"Get prefix for {c_name}: {prefix}")
prefix_dict[c_name] = prefix
# prefix as general option for inference
if hasattr(opts, "src_prefix"):
if "infer" not in prefix_dict.keys():
prefix_dict["infer"] = {}
prefix_dict["infer"]["src"] = opts.src_prefix
logger.debug(f"Get prefix for src infer: {opts.src_prefix}")
if hasattr(opts, "tgt_prefix"):
if "infer" not in prefix_dict.keys():
prefix_dict["infer"] = {}
prefix_dict["infer"]["tgt"] = opts.tgt_prefix
logger.debug(f"Get prefix for tgt infer: {opts.tgt_prefix}")
return prefix_dict
@classmethod
def get_specials(cls, opts):
"""Get special vocabs added by prefix transform."""
prefix_dict = cls.get_prefix_dict(opts)
src_specials, tgt_specials = set(), set()
for _, prefix in prefix_dict.items():
src_specials.update(prefix["src"].split())
tgt_specials.update(prefix["tgt"].split())
return (src_specials, tgt_specials)
def warm_up(self, vocabs=None):
"""Warm up to get prefix dictionary."""
super().warm_up(None)
self.prefix_dict = self.get_prefix_dict(self.opts)
def _prepend(self, example, prefix):
"""Prepend `prefix` to `tokens`."""
for side, side_prefix in prefix.items():
if example.get(side) is not None:
example[side] = side_prefix.split() + example[side]
elif len(side_prefix) > 0:
example[side] = side_prefix.split()
return example
def apply(self, example, is_train=False, stats=None, **kwargs):
"""Apply prefix prepend to example.
Should provide `corpus_name` to get correspond prefix.
"""
corpus_name = kwargs.get("corpus_name", None)
if corpus_name is None:
raise ValueError("corpus_name is required.")
corpus_prefix = self.prefix_dict.get(corpus_name, None)
if corpus_prefix is None:
raise ValueError(f"prefix for {corpus_name} does not exist.")
return self._prepend(example, corpus_prefix)
def apply_reverse(self, translated):
def _removeprefix(s, prefix):
if s.startswith(prefix) and len(prefix) > 0:
return s[len(prefix) + 1 :]
else:
return s
corpus_prefix = self.prefix_dict.get("infer", None)
return _removeprefix(translated, corpus_prefix["tgt"])
def _repr_args(self):
"""Return str represent key arguments for class."""
return "{}={}".format("prefix_dict", self.prefix_dict)
@register_transform(name="suffix")
class SuffixTransform(Transform):
"""Add Suffix to src (& tgt) sentence."""
def __init__(self, opts):
super().__init__(opts)
@classmethod
def add_options(cls, parser):
"""Avalailable options relate to this Transform."""
group = parser.add_argument_group("Transform/Suffix")
group.add(
"--src_suffix",
"-src_suffix",
type=str,
default="",
help="String to append to all source example.",
)
group.add(
"--tgt_suffix",
"-tgt_suffix",
type=str,
default="",
help="String to append to all target example.",
)
@staticmethod
def _get_suffix(corpus):
"""Get suffix string of a `corpus`."""
if "suffix" in corpus["transforms"]:
src_suffix = corpus.get("src_suffix", "")
tgt_suffix = corpus.get("tgt_suffix", "")
suffix = {"src": src_suffix, "tgt": tgt_suffix}
else:
suffix = None
return suffix
@classmethod
def get_suffix_dict(cls, opts):
"""Get all needed suffix correspond to corpus in `opts`."""
suffix_dict = {}
# suffix src/tgt for each dataset
if hasattr(opts, "data"):
for c_name, corpus in opts.data.items():
suffix = cls._get_suffix(corpus)
if suffix is not None:
logger.debug(f"Get suffix for {c_name}: {suffix}")
suffix_dict[c_name] = suffix
# suffix as general option for inference
if hasattr(opts, "src_suffix"):
if "infer" not in suffix_dict.keys():
suffix_dict["infer"] = {}
suffix_dict["infer"]["src"] = opts.src_suffix
logger.debug(f"Get suffix for src infer: {opts.src_suffix}")
if hasattr(opts, "tgt_suffix"):
if "infer" not in suffix_dict.keys():
suffix_dict["infer"] = {}
suffix_dict["infer"]["tgt"] = opts.tgt_suffix
logger.debug(f"Get suffix for tgt infer: {opts.tgt_suffix}")
return suffix_dict
@classmethod
def get_specials(cls, opts):
"""Get special vocabs added by suffix transform."""
suffix_dict = cls.get_suffix_dict(opts)
src_specials, tgt_specials = set(), set()
for _, suffix in suffix_dict.items():
src_specials.update(suffix["src"].split())
tgt_specials.update(suffix["tgt"].split())
return (src_specials, tgt_specials)
def warm_up(self, vocabs=None):
"""Warm up to get suffix dictionary."""
super().warm_up(None)
self.suffix_dict = self.get_suffix_dict(self.opts)
def _append(self, example, suffix):
"""Prepend `suffix` to `tokens`."""
for side, side_suffix in suffix.items():
if example.get(side) is not None:
example[side] = example[side] + side_suffix.split()
elif len(side_suffix) > 0:
example[side] = side_suffix.split()
return example
def apply(self, example, is_train=False, stats=None, **kwargs):
"""Apply suffix append to example.
Should provide `corpus_name` to get correspond suffix.
"""
corpus_name = kwargs.get("corpus_name", None)
if corpus_name is None:
raise ValueError("corpus_name is required.")
corpus_suffix = self.suffix_dict.get(corpus_name, None)
if corpus_suffix is None:
raise ValueError(f"suffix for {corpus_name} does not exist.")
return self._append(example, corpus_suffix)
def _repr_args(self):
"""Return str represent key arguments for class."""
return "{}={}".format("suffix_dict", self.suffix_dict)
|