run training with easy setup
- README.md +5 -25
- requirements.txt +0 -3
- setup_venv.sh +19 -0
- train.py +5 -4
- train.sh +18 -0
README.md
CHANGED
```diff
@@ -1,36 +1,16 @@
-#
+# T5-VAE-Python (flax) (WIP)
 
 A Transformer-VAE made using flax.
 
-
-
-Builds on T5, using an autoencoder to convert it into a VAE.
-
-[See training logs.](https://wandb.ai/fraser/flax-vae)
-
-## ToDo
-
-- [ ] Basic training script working. (Fraser + Theo)
-- [ ] Add MMD loss (Theo)
+It has been trained to interpolate on lines of Python code from the [python-lines dataset](https://huggingface.co/datasets/Fraser/python-lines).
 
-
-- [ ] Make a tokenizer using the OPTIMUS tokenized dataset.
-- [ ] Train on the OPTIMUS wikipedia sentences dataset.
-
-- [ ] Make Huggingface widget interpolating sentences! (???) https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#how-to-build-a-demo
-
-Optional ToDos:
-
-- [ ] Add Funnel transformer encoder to FLAX (don't need weights).
-- [ ] Train a Funnel-encoder + T5-decoder transformer VAE.
+Done as part of Huggingface community training ([see forum post](https://discuss.huggingface.co/t/train-a-vae-to-interpolate-on-english-sentences/7548)).
 
-
-- [ ] Poetry (https://www.gwern.net/GPT-2#data-the-project-gutenberg-poetry-corpus)
-- [ ] 8-bit music (https://github.com/chrisdonahue/LakhNES)
+Builds on T5, using an autoencoder to convert it into an MMD-VAE.
 
 ## Setup
 
-Follow all steps to install dependencies from https://
+Follow all steps to install dependencies from https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/README.md#tpu-vm
 
 - [ ] Find dataset storage site.
 - [ ] Ask JAX team for dataset storage.
```
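"Interpolate" here means decoding from latent codes that lie between the codes of two training examples. A minimal sketch of that idea in plain NumPy (the model's real latent layout and decode call are not shown; `z_a`, `z_b`, and `interpolate` are illustrative names, and the 32-dim size matches the `--latent_token_size 32` used in train.sh below):

```python
import numpy as np

def interpolate(z_a: np.ndarray, z_b: np.ndarray, steps: int = 5) -> np.ndarray:
    """Linearly interpolate between two latent codes.

    Each returned row is a latent vector that the decoder would turn
    into an intermediate line of Python code.
    """
    alphas = np.linspace(0.0, 1.0, steps)[:, None]
    return (1.0 - alphas) * z_a + alphas * z_b

# Two illustrative 32-dim latent codes (stand-ins for encoder outputs).
z_a = np.zeros(32)
z_b = np.ones(32)
path = interpolate(z_a, z_b, steps=5)  # path[0] == z_a, path[-1] == z_b
```

Feeding each row of `path` to the decoder yields the sequence of interpolated outputs the README describes.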
requirements.txt
DELETED
```diff
@@ -1,3 +0,0 @@
-jax
-jaxlib
--r requirements-tpu.txt
```
setup_venv.sh
ADDED
```diff
@@ -0,0 +1,19 @@
+# setup training on a TPU VM
+rm -fr venv
+python3 -m venv venv
+source venv/bin/activate
+pip install -U pip
+pip install -U wheel
+pip install requests
+pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+
+cd ..
+git clone https://github.com/huggingface/transformers.git
+cd transformers
+pip install -e ".[flax]"
+cd ..
+
+git clone https://github.com/huggingface/datasets.git
+cd datasets
+pip install -e ".[streaming]"
+cd ..
```
train.py
CHANGED
```diff
@@ -2,8 +2,6 @@
 Pre-training/Fine-tuning seq2seq models on autoencoding a dataset.
 
 TODO:
-- [x] Get this running.
-- [x] Don't make decoder input ids.
 - [ ] Add reg loss
 - [x] calculate MMD loss
 - [ ] schedule MMD loss weight
@@ -87,6 +85,10 @@ class ModelArguments:
             "help": "Number of dimensions to use for each latent token."
         },
     )
+    add_special_tokens: bool = field(
+        default=False,
+        metadata={"help": "Add these special tokens to the tokenizer: {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}"},
+    )
     config_path: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config path"}
     )
@@ -361,8 +363,7 @@ def main():
         model = FlaxT5VaeForAutoencoding.from_pretrained(
             model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
         )
-
-        assert(model.params['t5']['shared'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size.")
+        assert model.params['t5']['shared'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
     else:
         vocab_size = len(tokenizer)
         config.t5.vocab_size = vocab_size
```
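The rewritten assert in the last hunk fixes a classic Python pitfall: `assert(cond, msg)` asserts the two-element tuple `(cond, msg)`, and a non-empty tuple is always truthy, so the check can never fail. A standalone illustration (the function names and sizes are made up for the example):

```python
def check_sizes_buggy(embedding_rows: int, vocab_size: int) -> bool:
    # Parenthesized form: asserts the 2-tuple (cond, msg), which is
    # always truthy, so this "check" never fires.
    assert (embedding_rows == vocab_size, "sizes differ")
    return True

def check_sizes_fixed(embedding_rows: int, vocab_size: int) -> bool:
    # Correct form: the condition and the message are separate operands.
    assert embedding_rows == vocab_size, "sizes differ"
    return True

# The buggy check passes even on mismatched sizes...
check_sizes_buggy(32100, 32000)
# ...while the fixed check raises AssertionError on the same input.
try:
    check_sizes_fixed(32100, 32000)
    raised = False
except AssertionError:
    raised = True
```

Recent CPython versions emit a `SyntaxWarning` ("assertion is always true") for the parenthesized form, which is how bugs like this usually get caught.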
train.sh
CHANGED
```diff
@@ -0,0 +1,18 @@
+export RUN_NAME=single_latent
+
+./venv/bin/python train.py \
+    --t5_model_name_or_path="t5-base" \
+    --output_dir="output/${RUN_NAME}" \
+    --overwrite_output_dir \
+    --dataset_name="Fraser/python-lines" \
+    --do_train --do_eval \
+    --n_latent_tokens 1 \
+    --latent_token_size 32 \
+    --save_steps="2500" \
+    --eval_steps="2500" \
+    --block_size="32" \
+    --per_device_train_batch_size="10" \
+    --per_device_eval_batch_size="10" \
+    --overwrite_output_dir \
+    --num_train_epochs="1" \
+    --push_to_hub \
```
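Note that `--per_device_train_batch_size` is a per-device figure: Flax example scripts multiply it by the local device count at runtime, so on the TPU VM that setup_venv.sh targets (assuming a v3-8 with 8 cores, which this commit does not state explicitly) the settings above give a global batch of 80 per optimizer step. The arithmetic, as a sketch:

```python
def global_batch_size(per_device: int, device_count: int) -> int:
    # Effective batch per optimizer step under data parallelism;
    # flax example scripts compute device_count via jax.device_count().
    return per_device * device_count

# Assumed TPU v3-8 (8 cores) with the script's per-device batch of 10:
assert global_batch_size(10, 8) == 80
```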