rearrange files
- ag_news_clm.sh → _ag_news_clm.sh +0 -0
- ag_news_load.sh → _ag_news_load.sh +0 -0
- ag_news.sh +2 -2
- datasets/dataset.py +81 -0
ag_news_clm.sh → _ag_news_clm.sh
RENAMED
File without changes

ag_news_load.sh → _ag_news_load.sh
RENAMED
File without changes
ag_news.sh
CHANGED
@@ -1,4 +1,4 @@
-export RUN_NAME=
+export RUN_NAME=ag_news
 
 ./venv/bin/python train.py \
     --t5_model_name_or_path="t5-base" \
@@ -16,5 +16,5 @@ export RUN_NAME=test
     --learning_rate="5e-3" --warmup_steps="1000" \
     --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
     --overwrite_output_dir \
-    --num_train_epochs="
+    --num_train_epochs="3" \
     --push_to_hub \
datasets/dataset.py
ADDED
@@ -0,0 +1,81 @@
+# unused
+"""Wikipedia Sentences"""
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import json
+
+import datasets
+
+
+_DESCRIPTION = """\
+Dataset of sentences from Wikipedia (from the [Optimus paper](https://arxiv.org/abs/2004.04092)).
+Each is of max 64 words & <=256 GPT2 tokens.
+Each row is a tokenised sentence.
+{'token_ids': '{gpt2 token ids}'}
+This is to test the semantics of a Transformer-VAE's latent space by interpolating on sentences.
+"""
+
+NUM_SEGMENTS = 5
+_TRAIN_DOWNLOAD_URL = r"https://storage.googleapis.com/t-vae/wikipedia_json_64_filtered_segment_{0}.zip"
+
+
+class WikiSentencesConfig(datasets.BuilderConfig):
+    """BuilderConfig for WikiSentences."""
+
+    def __init__(self, segment=None, max_num_samples=None, **kwargs):
+        """BuilderConfig for WikiSentences.
+        Args:
+            segment: keyword argument to specify the segment of the dataset to load
+            **kwargs: keyword arguments forwarded to super.
+        """
+        self.segment = segment
+        self.max_num_samples = max_num_samples
+        super(WikiSentencesConfig, self).__init__(**kwargs)
+
+
+class WikiSentences(datasets.GeneratorBasedBuilder):
+    """Sentences from Wikipedia."""
+
+    BUILDER_CONFIGS = [
+        WikiSentencesConfig(
+            name=f"segment_{i}",
+            description=f"Segment {i+1}/{NUM_SEGMENTS} of WikiSentences Dataset for interpolating on natural language.",
+            segment=i
+        ) for i in range(NUM_SEGMENTS)
+    ] + [
+        WikiSentencesConfig(
+            name=f"1M_segment_{i}",
+            description=f"Segment {i+1}/{NUM_SEGMENTS} of WikiSentences Dataset for interpolating on natural language.",
+            segment=i, max_num_samples=1_000_000
+        ) for i in range(NUM_SEGMENTS)
+    ]
+
+    def _info(self):
+        return datasets.DatasetInfo(
+            description=_DESCRIPTION,
+            features=datasets.Features(
+                {
+                    'token_ids': [datasets.Value("int32")],
+                }
+            ),
+            homepage="https://github.com/Fraser-Greenlee/transformer-vae",
+        )
+
+    def _split_generators(self, dl_manager):
+        assert(self.config.segment < NUM_SEGMENTS), f'Segment does not exist, requested segment {self.config.segment}, but max segment num is {NUM_SEGMENTS - 1}'
+        folder_path = dl_manager.download_and_extract(_TRAIN_DOWNLOAD_URL.format(self.config.segment))
+        return [
+            datasets.SplitGenerator(
+                name=datasets.Split.TRAIN, gen_kwargs={"filepath": os.path.join(folder_path, 'segment_output.jsonl')}
+            ),
+        ]
+
+    def _generate_examples(self, filepath):
+        """Generate examples."""
+        with open(filepath, encoding="utf-8") as json_lines_file:
+            for id_, line in enumerate(json_lines_file):
+                yield id_, json.loads(line)
+                # stop early only when this config sets a sample cap
+                if self.config.max_num_samples is not None and id_ + 1 >= self.config.max_num_samples:
+                    break
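
For reference, a minimal usage sketch (not part of this commit) of loading one segment of the new dataset with the datasets library; the script path and config name below are assumptions based on the file added above:

from datasets import load_dataset

# Hypothetical example: load the first Wikipedia-sentences segment via the local
# loading script added in this commit (path assumed relative to the repo root).
wiki = load_dataset("datasets/dataset.py", name="segment_0", split="train")
print(wiki[0]["token_ids"][:10])  # first few GPT2 token ids of the first sentence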