Fraser committed on
Commit
2df9eb0
·
1 Parent(s): 7633929

rearrange files

Browse files
ag_news_clm.sh → _ag_news_clm.sh RENAMED
File without changes
ag_news_load.sh → _ag_news_load.sh RENAMED
File without changes
ag_news.sh CHANGED
@@ -1,4 +1,4 @@
1
- export RUN_NAME=test
2
 
3
  ./venv/bin/python train.py \
4
  --t5_model_name_or_path="t5-base" \
@@ -16,5 +16,5 @@ export RUN_NAME=test
16
  --learning_rate="5e-3" --warmup_steps="1000" \
17
  --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
18
  --overwrite_output_dir \
19
- --num_train_epochs="20" \
20
  --push_to_hub \
 
1
+ export RUN_NAME=ag_news
2
 
3
  ./venv/bin/python train.py \
4
  --t5_model_name_or_path="t5-base" \
 
16
  --learning_rate="5e-3" --warmup_steps="1000" \
17
  --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
18
  --overwrite_output_dir \
19
+ --num_train_epochs="3" \
20
  --push_to_hub \
datasets/dataset.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # unused
2
+ """Wikipedia Sentences"""
3
+
4
+ from __future__ import absolute_import, division, print_function
5
+
6
+ import os
7
+ import json
8
+
9
+ import datasets
10
+
11
+
12
+ _DESCRIPTION = """\
13
+ Dataset of sentences from Wikipedia (from the [Optimus paper](https://arxiv.org/abs/2004.04092)).
14
+ Each is of mex 64 words & <=256 GPT2 tokens.
15
+ Each row is a tokenised sentence.
16
+ {'token_ids': '{gpt2 token ids}'}
17
+ This is to test the semantics of a Transformer-VAEs latent space by interpolating on sentences.
18
+ """
19
+
20
+ NUM_SEGMENTS = 5
21
+ _TRAIN_DOWNLOAD_URL = r"https://storage.googleapis.com/t-vae/wikipedia_json_64_filtered_segment_{0}.zip"
22
+
23
+
24
+ class WikiSentencesConfig(datasets.BuilderConfig):
25
+ """BuilderConfig for WikiSentences."""
26
+
27
+ def __init__(self, segment=None, max_num_samples=None, **kwargs):
28
+ """BuilderConfig for WikiSentences.
29
+ Args:
30
+ segment_num: keyword argument to specify the segment of the dataset to load
31
+ **kwargs: keyword arguments forwarded to super.
32
+ """
33
+ self.segment = segment
34
+ self.max_num_samples = max_num_samples
35
+ super(WikiSentencesConfig, self).__init__(**kwargs)
36
+
37
+
38
+ class WikiSentences(datasets.GeneratorBasedBuilder):
39
+ """Sentences from Wikipedia."""
40
+
41
+ BUILDER_CONFIGS = [
42
+ WikiSentencesConfig(
43
+ name=f"segment_{i}",
44
+ description=f"Segment {i+1}/{NUM_SEGMENTS} of WikiSentences Dataset for interpolating on natural language.",
45
+ segment=i
46
+ ) for i in range(NUM_SEGMENTS)
47
+ ] + [
48
+ WikiSentencesConfig(
49
+ name=f"1M_segment_{i}",
50
+ description=f"Segment {i+1}/{NUM_SEGMENTS} of WikiSentences Dataset for interpolating on natural language.",
51
+ segment=i, max_num_samples=1_000_000
52
+ ) for i in range(NUM_SEGMENTS)
53
+ ]
54
+
55
+ def _info(self):
56
+ return datasets.DatasetInfo(
57
+ description=_DESCRIPTION,
58
+ features=datasets.Features(
59
+ {
60
+ 'token_ids': [datasets.Value("int32")],
61
+ }
62
+ ),
63
+ homepage="https://github.com/Fraser-Greenlee/transformer-vae",
64
+ )
65
+
66
+ def _split_generators(self, dl_manager):
67
+ assert(self.config.segment < NUM_SEGMENTS), f'Segment does not exist, requested segment {self.config.segment}, but max segment num is {NUM_SEGMENTS - 1}'
68
+ folder_path = dl_manager.download_and_extract(_TRAIN_DOWNLOAD_URL.format(self.config.segment))
69
+ return [
70
+ datasets.SplitGenerator(
71
+ name=datasets.Split.TRAIN, gen_kwargs={"filepath": os.path.join(folder_path, 'segment_output.jsonl')}
72
+ ),
73
+ ]
74
+
75
+ def _generate_examples(self, filepath):
76
+ """Generate examples."""
77
+ with open(filepath, encoding="utf-8") as json_lines_file:
78
+ for id_, line in enumerate(json_lines_file):
79
+ yield id_, json.loads(line)
80
+ if id_ >= self.config.max_num_samples:
81
+ break