from onmt.constants import DefaultTokens
from onmt.transforms import register_transform
from onmt.utils.logging import logger
from .transform import Transform


@register_transform(name="insert_mask_before_placeholder")
class InsertMaskBeforePlaceholdersTransform(Transform):
    """Add the `DefaultTokens.MASK_BEFORE` placeholder between
    the prompt and the response in an LM finetuning example.
    This is necessary to enable the 'zero-out prompt loss' mechanism.
    """

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Options for mask_before placeholder insertion."""
        group = parser.add_argument_group(
            "Transform/InsertMaskBeforePlaceholdersTransform"
        )
        group.add(
            "--response_pattern",
            "-response_pattern",
            type=str,
            help="Response pattern to locate the end of the prompt",
            default="Response : ⦅newline⦆",
        )
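
    # A minimal sketch (an assumption, not taken from the documentation):
    # with the usual OpenNMT-py YAML configuration, this transform would
    # typically be enabled as
    #
    #     transforms: [insert_mask_before_placeholder]
    #     response_pattern: "Response : ⦅newline⦆"
    #
    # where `response_pattern` must match the marker that separates the
    # prompt from the response in the training data.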

    def _parse_opts(self):
        self.response_pattern = self.opts.response_pattern

    def apply(self, example, is_train=False, stats=None, **kwargs):
        # Work on the space-joined token string so the response pattern can
        # be located even when it spans several tokens.
        _src = " ".join(example["src"])
        if len(_src.split(self.response_pattern)) != 2:
            logger.info("The mask_before could not be inserted")
            return example
        prompt, response = _src.split(self.response_pattern)
        # Re-attach the response pattern and place the mask placeholder
        # between it and the response body.
        response = DefaultTokens.MASK_BEFORE.join([self.response_pattern, response])
        _src = "".join([prompt, response])
        example["src"] = _src.split(" ")
        example["tgt"] = _src.split(" ")
        return example
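

# Illustrative sketch (an assumption, not part of the library): how `apply`
# rewrites an example. The source tokens are joined into a string, split on
# `response_pattern`, and re-joined with `DefaultTokens.MASK_BEFORE` placed
# right after the response pattern, so the loss on the prompt tokens can be
# zeroed out downstream. Assumes the base `Transform.__init__` stores `opts`
# and calls `_parse_opts`.
#
#     >>> from argparse import Namespace
#     >>> opts = Namespace(response_pattern="Response : ⦅newline⦆")
#     >>> transform = InsertMaskBeforePlaceholdersTransform(opts)
#     >>> ex = {"src": "Say hi Response : ⦅newline⦆ hi".split(" ")}
#     >>> out = transform.apply(ex)
#     >>> # out["src"] == out["tgt"]: the original tokens with the mask
#     >>> # placeholder inserted between the prompt and the response.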