transformer-vae / datasets /wiki_sentences.py
Fraser's picture
add wiki sentences
4286a16
# unused
"""Wikipedia Sentences"""
from __future__ import absolute_import, division, print_function
import os
import json
import datasets
_DESCRIPTION = """\
Dataset of sentences from Wikipedia (from the [Optimus paper](https://arxiv.org/abs/2004.04092)).
Each is of mex 64 words & <=256 GPT2 tokens.
Each row is a tokenised sentence.
{'token_ids': '{gpt2 token ids}'}
This is to test the semantics of a Transformer-VAEs latent space by interpolating on sentences.
"""
NUM_SEGMENTS = 5
_TRAIN_DOWNLOAD_URL = r"https://storage.googleapis.com/t-vae/wikipedia_json_64_filtered_segment_{0}.zip"
class WikiSentencesConfig(datasets.BuilderConfig):
    """Builder configuration choosing a WikiSentences segment and a sample limit."""

    def __init__(self, segment=None, max_num_samples=None, **kwargs):
        """Record the segment and sample limit, then initialise the base config.

        Args:
            segment: index of the dataset segment to load; the builder
                substitutes it into the download URL.
            max_num_samples: sample limit the builder consults while
                generating examples. Defaults to None.
            **kwargs: keyword arguments forwarded to super.
        """
        # Custom attributes are assigned before the base initialiser runs,
        # mirroring the original ordering.
        self.segment, self.max_num_samples = segment, max_num_samples
        super().__init__(**kwargs)
class WikiSentences(datasets.GeneratorBasedBuilder):
    """Sentences from Wikipedia.

    Two configs are registered per segment: ``segment_{i}`` loads the whole
    segment, and ``1M_segment_{i}`` stops after 1,000,000 examples.
    """

    BUILDER_CONFIGS = [
        WikiSentencesConfig(
            name=f"segment_{i}",
            description=f"Segment {i+1}/{NUM_SEGMENTS} of WikiSentences Dataset for interpolating on natural language.",
            segment=i
        ) for i in range(NUM_SEGMENTS)
    ] + [
        WikiSentencesConfig(
            name=f"1M_segment_{i}",
            description=f"Segment {i+1}/{NUM_SEGMENTS} of WikiSentences Dataset for interpolating on natural language.",
            segment=i, max_num_samples=1_000_000
        ) for i in range(NUM_SEGMENTS)
    ]

    def _info(self):
        """Return the dataset metadata: description, feature schema and homepage."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    # Each example is a list of int32 token ids.
                    'token_ids': [datasets.Value("int32")],
                }
            ),
            homepage="https://github.com/Fraser-Greenlee/transformer-vae",
        )

    def _split_generators(self, dl_manager):
        """Download the configured segment and expose it as the TRAIN split.

        Args:
            dl_manager: download manager used to fetch and extract the
                segment archive.

        Returns:
            A single-element list with the TRAIN split generator, pointing at
            ``segment_output.jsonl`` inside the extracted folder.

        Raises:
            ValueError: if the configured segment is not an integer in
                ``0 .. NUM_SEGMENTS - 1``.
        """
        segment = self.config.segment
        # Explicit raise instead of `assert`: asserts vanish under `python -O`,
        # and the old check neither rejected negative indices nor handled the
        # `segment=None` default (it failed with an unrelated TypeError).
        if not isinstance(segment, int) or not 0 <= segment < NUM_SEGMENTS:
            raise ValueError(
                f'Segment does not exist, requested segment {segment}, but max segment num is {NUM_SEGMENTS - 1}'
            )
        folder_path = dl_manager.download_and_extract(_TRAIN_DOWNLOAD_URL.format(segment))
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN, gen_kwargs={"filepath": os.path.join(folder_path, 'segment_output.jsonl')}
            ),
        ]

    def _generate_examples(self, filepath):
        """Generate examples.

        Args:
            filepath: path to a JSON-lines file, one example per line.

        Yields:
            ``(index, example)`` pairs, where ``index`` is the zero-based line
            number and ``example`` is the object decoded from that line.
            At most ``config.max_num_samples`` pairs are produced; when that
            setting is None the whole file is read.
        """
        limit = self.config.max_num_samples
        with open(filepath, encoding="utf-8") as json_lines_file:
            for id_, line in enumerate(json_lines_file):
                # Test the cap *before* yielding so exactly `limit` examples
                # are emitted (the old post-yield check produced one extra),
                # and skip the comparison entirely when no cap is configured
                # (comparing an int with None raises TypeError).
                if limit is not None and id_ >= limit:
                    break
                yield id_, json.loads(line)