# Setting up a Google Cloud TPU VM for training a tokenizer
## TPU VM Configurations
To start off, follow the guide from the Flax/JAX community week 2021 [here](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#how-to-setup-tpu-vm), but **NOTE**: change every `pip` command to `pip3`.
Some users may run into the following error message when installing JAX:
```
Building wheel for jax (setup.py) ... error
ERROR: Command errored out with exit status 1:
command: /home/patrick/patrick/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"'; __file__='"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-pydotzlo
cwd: /tmp/pip-install-lwseckn1/jax/
Complete output (6 lines):
usage: setup.py [global_opts] cmd1 [cmd1_opts] [cmd2 [cmd2_opts] ...]
or: setup.py --help [cmd1 cmd2 ...]
or: setup.py --help-commands
or: setup.py cmd --help
error: invalid command 'bdist_wheel'
----------------------------------------
ERROR: Failed building wheel for jax
```
If you encounter this error, run the following commands:
```bash
pip3 install --upgrade clu
pip3 install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
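To confirm that the reinstall left a JAX version satisfying the `>=0.2.16` constraint from the command above, a small standard-library check can help. This is a minimal sketch: the helper names `parse_release` and `jax_is_recent_enough` are hypothetical, and the parser assumes plain `X.Y.Z` version strings (a pre-release suffix such as `rc1` is ignored).

```python
from importlib.metadata import PackageNotFoundError, version

# Lower bound taken from the `jax[tpu]>=0.2.16` requirement in the pip3 command
MIN_JAX = (0, 2, 16)


def parse_release(v: str) -> tuple:
    """Return the leading numeric components of a version string.

    Assumes plain "X.Y.Z" versions: "0.2.17" -> (0, 2, 17). Anything after the
    digits of a component (e.g. "rc1" in "0.4.1rc1") is ignored.
    """
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def jax_is_recent_enough(minimum: tuple = MIN_JAX) -> bool:
    """True if an installed `jax` distribution satisfies the minimum version."""
    try:
        installed = version("jax")
    except PackageNotFoundError:
        return False
    return parse_release(installed) >= minimum


if __name__ == "__main__":
    print("jax >= 0.2.16 installed:", jax_is_recent_enough())
```

Tuple comparison does the ordering, so `(0, 3)` counts as newer than `(0, 2, 16)` without any extra logic.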
Then make the files in `/tmp`, including the TPU logs in `/tmp/tpu_logs`, readable, writable, and executable by all users:
```bash
chmod a+rwx /tmp/*
chmod a+rwx /tmp/tpu_logs/* # Just to be sure ;-)
```
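`chmod a+rwx` adds the read, write, and execute bits for the owner, the group, and everyone else (octal `0o777`) without clearing any other mode bits. If you prefer to apply the same change from a Python setup script, the sketch below mirrors `chmod a+rwx <dir>/*`. The function name `add_all_rwx` is hypothetical; like the shell glob, it skips dot-files and does not recurse into subdirectories, and like `chmod` it raises a permission error on files you do not own unless run as root.

```python
import os
import stat
from pathlib import Path

# "a+rwx": read/write/execute for user, group and others (0o777)
ALL_RWX = stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO


def add_all_rwx(directory: str) -> list:
    """Add rwx for everyone to each direct entry of `directory`, like `chmod a+rwx dir/*`.

    Entries whose names start with "." are skipped, as the shell glob `*` would.
    Returns the paths whose permissions were changed.
    """
    changed = []
    for entry in sorted(Path(directory).iterdir()):
        if entry.name.startswith("."):
            continue
        current = stat.S_IMODE(entry.stat().st_mode)
        # OR-ing keeps any existing special bits (setuid, sticky, ...) intact
        os.chmod(entry, current | ALL_RWX)
        changed.append(entry)
    return changed
```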
Afterwards, you can verify the installation either by running the following script:
```python
from transformers import FlaxRobertaModel, RobertaTokenizerFast
from datasets import load_dataset

# Stream the English OSCAR split so the full dataset is not downloaded
dataset = load_dataset("oscar", "unshuffled_deduplicated_en", split="train", streaming=True)
dummy_input = next(iter(dataset))["text"]

# Tokenize one example and keep only the first 10 token ids
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]

# Load a small dummy Flax RoBERTa model and run a forward pass;
# it should return a `FlaxBaseModelOutputWithPooling` object
model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")
outputs = model(input_ids)
print(type(outputs).__name__)
```
Or by running the following Python commands, which should list the TPU devices visible to JAX:
```python
import jax

print(jax.devices())
```
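`jax.devices()` only lists the devices. To check that a computation actually runs on all of them, you can map a trivial function over every local device with `jax.pmap`. A minimal sketch, assuming a single-host TPU VM (a v3-8 typically exposes 8 devices; on a CPU-only machine the same code runs on a single device):

```python
import jax
import jax.numpy as jnp

# Number of accelerator devices attached to this host
n_devices = jax.local_device_count()

# jax.pmap splits the leading axis across devices: device i squares the value i
squares = jax.pmap(lambda x: x ** 2)(jnp.arange(n_devices))

print(jax.devices())
print(squares)  # one squared value per device
```

If the TPU runtime is picked up correctly, `squares` holds `n_devices` values, the squares of `0` to `n_devices - 1`.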
## Training the tokenizer
To train the tokenizer, run the `train_tokenizer.py` script:
```bash
python3 train_tokenizer.py
```
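The contents of `train_tokenizer.py` are not reproduced here. As a rough illustration only, the sketch below trains a byte-level BPE tokenizer (the scheme RoBERTa uses, whose tokenizer the verification script loads) from an iterable of strings with Hugging Face's `tokenizers` library. The helper name `train_byte_level_bpe`, the RoBERTa-style special tokens, and the default `vocab_size=50_265` are assumptions, not values taken from the script.

```python
from typing import Iterable

from tokenizers import ByteLevelBPETokenizer

# RoBERTa-style special tokens (assumption; adjust to the target model)
SPECIAL_TOKENS = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]


def train_byte_level_bpe(
    texts: Iterable[str], vocab_size: int = 50_265, min_frequency: int = 2
) -> ByteLevelBPETokenizer:
    """Hypothetical helper: train a byte-level BPE tokenizer on an iterable of strings."""
    tokenizer = ByteLevelBPETokenizer()
    tokenizer.train_from_iterator(
        texts,
        vocab_size=vocab_size,  # special tokens count towards this limit
        min_frequency=min_frequency,  # ignore pairs seen fewer times than this
        show_progress=False,
        special_tokens=SPECIAL_TOKENS,
    )
    return tokenizer


if __name__ == "__main__":
    corpus = ["hello world", "hello there"] * 50  # toy corpus for illustration
    tok = train_byte_level_bpe(corpus, vocab_size=300)
    print(tok.encode("hello world").tokens)
```

With a real corpus, such as the text column of the mC4 data, `tokenizer.save_model(output_dir)` writes `vocab.json` and `merges.txt`, which `RobertaTokenizerFast.from_pretrained(output_dir)` can load.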
### Problems encountered while developing the script
- Loading the *mC4* dataset with `load_dataset()` from Hugging Face's `datasets` package did not work for multiple languages in a single call, contrary to what is described [here](https://huggingface.co/datasets/mc4). We therefore loaded each language separately and concatenated the results.
- Furthermore, it seems that even when only a subset is requested through the `split` argument, the entire dataset still has to be downloaded.
- A bug occurred when downloading the Danish dataset; we worked around it by forcing the VM to redownload the data.