| |
| |
| |
| |
|
|
| import logging |
| from dataclasses import dataclass, field |
| from typing import Optional |
|
|
| import torch |
| from .dummy_dataset import DummyDataset |
| from fairseq.data import Dictionary |
| from fairseq.dataclass import FairseqDataclass |
| from fairseq.tasks import FairseqTask, register_task |
| from omegaconf import II |
|
|
|
|
# Module-level logger named after this module; used by DummyLMTask to report
# the size of the synthetic dictionary it builds.
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class DummyLMConfig(FairseqDataclass):
    """Configuration for the ``dummy_lm`` task.

    The task registered under ``dummy_lm`` reads these values to size its
    synthetic dictionary and the single fixed batch it serves.
    """

    # Number of "word{i}" symbols added to the task's dictionary.
    dict_size: int = 49996
    # Passed to the dummy dataset as its number of items.
    dataset_size: int = 100000
    # Length of the dummy source/target sequences.
    tokens_per_sample: int = field(
        default=512,
        metadata={"help": "max sequence length"},
    )
    # NOTE(review): not read by the task code visible in this file.
    add_bos_token: bool = False
    # The next two are interpolated from the global dataset config; the task
    # uses batch_size when set and otherwise derives it from max_tokens.
    batch_size: Optional[int] = II("dataset.batch_size")
    max_tokens: Optional[int] = II("dataset.max_tokens")
    # Interpolated so that it always mirrors tokens_per_sample above.
    max_target_positions: int = II("task.tokens_per_sample")
|
|
|
|
@register_task("dummy_lm", dataclass=DummyLMConfig)
class DummyLMTask(FairseqTask):
    """Language-modeling task backed by one fixed, synthetic batch.

    No data is read from disk: the task builds a dictionary of placeholder
    words and a single arithmetic token sequence, then hands a pre-collated
    batch made of copies of that sequence to ``DummyDataset``.
    """

    def __init__(self, cfg: DummyLMConfig):
        """Build the synthetic dictionary and the dummy source/target pair.

        Args:
            cfg: task configuration; ``dict_size`` and ``tokens_per_sample``
                are read here.
        """
        super().__init__(cfg)

        # Populate the dictionary with ``cfg.dict_size`` placeholder words.
        self.dictionary = Dictionary()
        for index in range(cfg.dict_size):
            self.dictionary.add_symbol("word{}".format(index))
        # NOTE(review): presumably rounds the vocabulary size up to a multiple
        # of 8 -- confirm against the Dictionary implementation.
        self.dictionary.pad_to_multiple_(8)
        logger.info("dictionary: %s types", len(self.dictionary))

        # One sequence of consecutive ids starting just after the pad index,
        # one element longer than a sample so that the target can be the
        # source shifted left by one position.
        seq = torch.arange(cfg.tokens_per_sample + 1) + self.dictionary.pad() + 1
        self.dummy_src = seq[:-1]
        self.dummy_tgt = seq[1:]

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        The "dataset" is a ``DummyDataset`` wrapping one pre-built batch whose
        rows are all copies of the dummy source/target sequences.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch: unused; accepted for interface compatibility.
            combine: unused; accepted for interface compatibility.
            **kwargs: unused; accepted for interface compatibility.

        Raises:
            ValueError: if neither ``batch_size`` nor ``max_tokens`` is set,
                so no batch size can be derived.
        """
        tokens_per_sample = self.cfg.tokens_per_sample

        bsz = self.cfg.batch_size
        if bsz is None:
            if self.cfg.max_tokens is None:
                # Without this guard the floor division below would fail with
                # an opaque "unsupported operand type(s)" TypeError.
                raise ValueError(
                    "dummy_lm needs either batch_size or max_tokens to be set "
                    "in order to size the dummy batch"
                )
            # Fit as many full samples as the token budget allows, at least 1.
            bsz = max(1, self.cfg.max_tokens // tokens_per_sample)

        self.datasets[split] = DummyDataset(
            {
                "id": 1,
                "net_input": {
                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
                    "src_lengths": torch.full(
                        (bsz,), tokens_per_sample, dtype=torch.long
                    ),
                },
                "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
                "nsentences": bsz,
                "ntokens": bsz * tokens_per_sample,
            },
            num_items=self.cfg.dataset_size,
            item_size=tokens_per_sample,
        )

    @property
    def source_dictionary(self):
        """Dictionary used for source tokens (shared with the target side)."""
        return self.dictionary

    @property
    def target_dictionary(self):
        """Dictionary used for target tokens (shared with the source side)."""
        return self.dictionary
|
|