Spaces:
Runtime error
Commit 6098a23
Duplicate from IkIzma/Transformer_Homework
Browse files
- .gitattributes +34 -0
- FineTune.ipynb +877 -0
- Generate_text.py +28 -0
- README.md +17 -0
- app.py +30 -0
- essay_dataset/train.txt +0 -0
- essay_dataset/valid.txt +0 -0
- models/essays/README.md +55 -0
- models/essays/added_tokens.json +3 -0
- models/essays/all_results.json +15 -0
- models/essays/config.json +34 -0
- models/essays/eval_results.json +10 -0
- models/essays/generation_config.json +6 -0
- models/essays/merges.txt +0 -0
- models/essays/pytorch_model.bin +3 -0
- models/essays/special_tokens_map.json +23 -0
- models/essays/tokenizer.json +0 -0
- models/essays/tokenizer_config.json +33 -0
- models/essays/train_results.json +8 -0
- models/essays/trainer_state.json +25 -0
- models/essays/training_args.bin +3 -0
- models/essays/vocab.json +0 -0
- requirements.txt +2 -0
- run_clm.py +635 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
FineTune.ipynb
ADDED
@@ -0,0 +1,877 @@
# Finetune
Fine-tuning a RuGPTs model with Hugging Face.

## Install env
In [1]:
%%bash
git clone https://github.com/huggingface/transformers
cd transformers
pip install .

Cloning into 'transformers'...
Defaulting to user installation because normal site-packages is not writeable
Processing /home/kamil/Documents/SHAD/ML/Part 2/Seminar 7/transformers
  Building wheel for transformers (pyproject.toml): finished with status 'done'
  Found existing installation: transformers 4.27.4
  Successfully uninstalled transformers-4.27.4
Successfully installed transformers-4.29.0.dev0

In [2]:
!pip install datasets

Collecting datasets
  Downloading datasets-2.11.0-py3-none-any.whl (468 kB)
Installing collected packages: xxhash, multidict, fsspec, frozenlist, dill, async-timeout, yarl, responses, multiprocess, aiosignal, aiohttp, datasets
Successfully installed aiohttp-3.8.4 aiosignal-1.3.1 async-timeout-4.0.2 datasets-2.11.0 dill-0.3.6 frozenlist-1.3.3 fsspec-2023.4.0 multidict-6.0.4 multiprocess-0.70.14 responses-0.18.0 xxhash-3.2.0 yarl-1.8.2

In [3]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
Installing collected packages: evaluate
Successfully installed evaluate-0.4.0

In [3]:
!mkdir models/
## Download files

In [20]:
!wget -O train.txt https://www.dropbox.com/s/oa3v9c7g9bp40xw/train.txt?dl=0
!wget -O valid.txt https://www.dropbox.com/s/mworl3ld6r3bg62/valid.txt?dl=0

--2023-04-16 19:47:12--  https://www.dropbox.com/s/oa3v9c7g9bp40xw/train.txt?dl=0
HTTP request sent, awaiting response... 200 OK
Length: 1654900 (1,6M) [text/plain]
Saving to: ‘train.txt’
2023-04-16 19:47:14 (8,43 MB/s) - ‘train.txt’ saved [1654900/1654900]

--2023-04-16 19:47:14--  https://www.dropbox.com/s/mworl3ld6r3bg62/valid.txt?dl=0
HTTP request sent, awaiting response... 200 OK
Length: 167021 (163K) [text/plain]
Saving to: ‘valid.txt’
2023-04-16 19:47:15 (2,02 MB/s) - ‘valid.txt’ saved [167021/167021]
## Train
The following code downloads the model and tokenizer from Hugging Face and fine-tunes the model to generate essays.
In [24]:
!python3 run_clm.py \
    --model_name_or_path sberbank-ai/rugpt3small_based_on_gpt2 \
    --train_file train.txt \
    --validation_file valid.txt \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --block_size 2048 \
    --dataset_config_name plain_text \
    --do_train \
    --do_eval \
    --output_dir models/essays2

04/16/2023 19:47:40 - WARNING - __main__ - Process rank: -1, device: cuda:0, n_gpu: 1distributed training: False, 16-bits training: False
04/16/2023 19:47:40 - INFO - __main__ - Training/evaluation parameters TrainingArguments(
do_train=True,
do_eval=True,
evaluation_strategy=no,
fp16=False,
learning_rate=5e-05,
lr_scheduler_type=linear,
num_train_epochs=3.0,
optim=adamw_hf,
output_dir=models/essays2,
per_device_eval_batch_size=1,
per_device_train_batch_size=1,
save_steps=500,
seed=42,
weight_decay=0.0,
…
)
04/16/2023 19:47:40 - INFO - datasets.builder - Generating train split
04/16/2023 19:47:40 - INFO - datasets.builder - Generating validation split
[INFO|configuration_utils.py:720] 2023-04-16 19:47:41,751 >> Model config GPT2Config {
  "_name_or_path": "sberbank-ai/rugpt3small_based_on_gpt2",
  "model_type": "gpt2",
  "n_ctx": 2048,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 2048,
  "vocab_size": 50264,
  …
}
[WARNING|logging.py:280] 2023-04-16 19:47:47,765 >> Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Downloading pytorch_model.bin: 100%|█████████| 551M/551M [00:36<00:00, 15.2MB/s]
[INFO|modeling_utils.py:3190] 2023-04-16 19:48:26,046 >> All model checkpoint weights were used when initializing GPT2LMHeadModel.
Running tokenizer on dataset; grouping texts in chunks of 2048
[INFO|trainer.py:1769] 2023-04-16 19:48:28,570 >> ***** Running training *****
[INFO|trainer.py:1770] 2023-04-16 19:48:28,570 >> Num examples = 92
[INFO|trainer.py:1771] 2023-04-16 19:48:28,570 >> Num Epochs = 3
[INFO|trainer.py:1772] 2023-04-16 19:48:28,570 >> Instantaneous batch size per device = 1
[INFO|trainer.py:1775] 2023-04-16 19:48:28,570 >> Total optimization steps = 276
[INFO|trainer.py:1776] 2023-04-16 19:48:28,570 >> Number of trainable parameters = 125,231,616
Traceback (most recent call last):
  File "/home/kamil/Documents/SHAD/ML/Part 2/Seminar 7/run_clm.py", line 635, in <module>
    main()
  File "/home/kamil/Documents/SHAD/ML/Part 2/Seminar 7/run_clm.py", line 583, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  …
  File "/home/kamil/.local/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 185, in _attn
    attn_weights = attn_weights / torch.full(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 192.00 MiB (GPU 0; 7.79 GiB total capacity; 6.07 GiB already allocated; 171.81 MiB free; 6.09 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
  0%|          | 0/276 [00:01<?, ?it/s]
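The cell above fails with torch.cuda.OutOfMemoryError: with --block_size 2048, even a single-sequence batch of this 125M-parameter model does not fit the ~8 GiB GPU. As a hedged illustration only (this is not the command the notebook actually used to produce models/essays), here is a minimal Trainer-API sketch of the same causal-LM fine-tune with a few assumed memory-saving settings: a smaller block size, fp16, and gradient checkpointing.

# Hypothetical sketch, not the notebook's actual fix: same data, same base model,
# but with memory-saving options chosen as assumptions for an ~8 GiB GPU.
from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
                          TrainingArguments, default_data_collator)

model_name = "sberbank-ai/rugpt3small_based_on_gpt2"
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.gradient_checkpointing_enable()  # recompute activations instead of storing them

raw = load_dataset("text", data_files={"train": "train.txt", "validation": "valid.txt"})
block_size = 512  # assumption: 512 instead of 2048 to reduce activation memory

def tokenize(batch):
    return tok(batch["text"])

def group(batch):
    # Concatenate all token ids and cut them into fixed-length blocks
    # (remainder dropped) -- the same grouping idea run_clm.py uses.
    ids = sum(batch["input_ids"], [])
    total = (len(ids) // block_size) * block_size
    chunks = [ids[i:i + block_size] for i in range(0, total, block_size)]
    return {"input_ids": chunks,
            "attention_mask": [[1] * block_size for _ in chunks],
            "labels": [list(c) for c in chunks]}

tokenized = raw.map(tokenize, batched=True, remove_columns=["text"])
lm_datasets = tokenized.map(group, batched=True,
                            remove_columns=tokenized["train"].column_names)

args = TrainingArguments(output_dir="models/essays2", num_train_epochs=3,
                         per_device_train_batch_size=1, per_device_eval_batch_size=1,
                         fp16=True)  # half precision further cuts memory use
trainer = Trainer(model=model, args=args,
                  train_dataset=lm_datasets["train"],
                  eval_dataset=lm_datasets["validation"],
                  data_collator=default_data_collator)
trainer.train()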
## Evaluate model

In [4]:
import numpy as np
import torch

In [5]:
np.random.seed(42)
torch.manual_seed(42)

Out[5]: <torch._C.Generator at 0x7fe5314a4c50>

In [6]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [25]:
tok = GPT2Tokenizer.from_pretrained("models/essays")

In [26]:
model = GPT2LMHeadModel.from_pretrained("models/essays")
| 743 |
+
{
|
| 744 |
+
"cell_type": "code",
|
| 745 |
+
"execution_count": 27,
|
| 746 |
+
"metadata": {
|
| 747 |
+
"collapsed": true,
|
| 748 |
+
"id": "irh4H-HDBb6V"
|
| 749 |
+
},
|
| 750 |
+
"outputs": [
|
| 751 |
+
{
|
| 752 |
+
"data": {
|
| 753 |
+
"text/plain": [
|
| 754 |
+
"GPT2LMHeadModel(\n",
|
| 755 |
+
" (transformer): GPT2Model(\n",
|
| 756 |
+
" (wte): Embedding(50264, 768)\n",
|
| 757 |
+
" (wpe): Embedding(2048, 768)\n",
|
| 758 |
+
" (drop): Dropout(p=0.1, inplace=False)\n",
|
| 759 |
+
" (h): ModuleList(\n",
|
| 760 |
+
" (0-11): 12 x GPT2Block(\n",
|
| 761 |
+
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
| 762 |
+
" (attn): GPT2Attention(\n",
|
| 763 |
+
" (c_attn): Conv1D()\n",
|
| 764 |
+
" (c_proj): Conv1D()\n",
|
| 765 |
+
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
|
| 766 |
+
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
|
| 767 |
+
" )\n",
|
| 768 |
+
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
| 769 |
+
" (mlp): GPT2MLP(\n",
|
| 770 |
+
" (c_fc): Conv1D()\n",
|
| 771 |
+
" (c_proj): Conv1D()\n",
|
| 772 |
+
" (act): NewGELUActivation()\n",
|
| 773 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
| 774 |
+
" )\n",
|
| 775 |
+
" )\n",
|
| 776 |
+
" )\n",
|
| 777 |
+
" (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
| 778 |
+
" )\n",
|
| 779 |
+
" (lm_head): Linear(in_features=768, out_features=50264, bias=False)\n",
|
| 780 |
+
")"
|
| 781 |
+
]
|
| 782 |
+
},
|
| 783 |
+
"execution_count": 27,
|
| 784 |
+
"metadata": {},
|
| 785 |
+
"output_type": "execute_result"
|
| 786 |
+
}
|
| 787 |
+
],
|
| 788 |
+
"source": [
|
| 789 |
+
"model.cuda()"
|
| 790 |
+
]
|
| 791 |
+
},
|
| 792 |
+
{
|
| 793 |
+
"cell_type": "code",
|
| 794 |
+
"execution_count": 31,
|
| 795 |
+
"metadata": {
|
| 796 |
+
"id": "hQY6A5q7Bd4O"
|
| 797 |
+
},
|
| 798 |
+
"outputs": [],
|
| 799 |
+
"source": [
|
| 800 |
+
"text = \"<s>Тема: «В чем смысл жизни?»\\nСочинение: \"\n",
|
| 801 |
+
"inpt = tok.encode(text, return_tensors=\"pt\")"
|
| 802 |
+
]
|
| 803 |
+
},
|
| 804 |
+
{
|
| 805 |
+
"cell_type": "code",
|
| 806 |
+
"execution_count": 32,
|
| 807 |
+
"metadata": {
|
| 808 |
+
"id": "1gfJFmeOBj_t"
|
| 809 |
+
},
|
| 810 |
+
"outputs": [
|
| 811 |
+
{
|
| 812 |
+
"name": "stderr",
|
| 813 |
+
"output_type": "stream",
|
| 814 |
+
"text": [
|
| 815 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
| 816 |
+
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
|
| 817 |
+
]
|
| 818 |
+
}
|
| 819 |
+
],
|
| 820 |
+
"source": [
|
| 821 |
+
"out = model.generate(inpt.cuda(), max_length=200, repetition_penalty=5.0, do_sample=True, top_k=5, top_p=0.95, temperature=1)"
|
| 822 |
+
]
|
| 823 |
+
},
|
| 824 |
+
{
|
| 825 |
+
"cell_type": "code",
|
| 826 |
+
"execution_count": 33,
|
| 827 |
+
"metadata": {
|
| 828 |
+
"colab": {
|
| 829 |
+
"base_uri": "https://localhost:8080/",
|
| 830 |
+
"height": 123
|
| 831 |
+
},
|
| 832 |
+
"id": "gWZ9SUCxB2Ki",
|
| 833 |
+
"outputId": "31d8e1a3-376f-4f27-bd11-ba59a44983eb"
|
| 834 |
+
},
|
| 835 |
+
"outputs": [
|
| 836 |
+
{
|
| 837 |
+
"name": "stdout",
|
| 838 |
+
"output_type": "stream",
|
| 839 |
+
"text": [
|
| 840 |
+
"<s>Тема: «В чем смысл жизни?»\n",
|
| 841 |
+
"Сочинение: 📹Как часто в наше время мы слышим фразу \"жить надо так, чтобы было хорошо всем\". Однако не все могут себе позволить жить по-другому. В современном мире многие люди хотят изменить свою жизнь к лучшему и сделать ее комфортной для всех без исключения граждан нашей страны.</span] (по пьесе Мольера) \n",
|
| 842 |
+
" Существование системы образования является одним из важнейших условий становления цивилизованного общества на земле – формирования личности гражданина как носителя социально значимых ценностей.Начнем с определения понятия образование - наука о человеке или его способности организовывать свои действия во внешней среде посредством усвоения знаний об окружающем нас обществе; рассмотрим основные функции обучения, которые выполняет школьное учреждение : формирование у детей культуры общения со взрослыми людьми через обучение умению общаться при помощи словаря иностранных слов — это важнейший фактор социализации человека от рождения до смерти || • воспитание нравственности между детьми дошкольного возраста /под ред Н\n"
|
| 843 |
+
]
|
| 844 |
+
}
|
| 845 |
+
],
|
| 846 |
+
"source": [
|
| 847 |
+
"print(tok.decode(out[0]))"
|
| 848 |
+
]
|
| 849 |
+
}
|
| 850 |
+
],
|
| 851 |
+
"metadata": {
|
| 852 |
+
"accelerator": "GPU",
|
| 853 |
+
"colab": {
|
| 854 |
+
"name": "RuGPT3FinetuneHF.ipynb",
|
| 855 |
+
"provenance": []
|
| 856 |
+
},
|
| 857 |
+
"kernelspec": {
|
| 858 |
+
"display_name": "Python 3 (ipykernel)",
|
| 859 |
+
"language": "python",
|
| 860 |
+
"name": "python3"
|
| 861 |
+
},
|
| 862 |
+
"language_info": {
|
| 863 |
+
"codemirror_mode": {
|
| 864 |
+
"name": "ipython",
|
| 865 |
+
"version": 3
|
| 866 |
+
},
|
| 867 |
+
"file_extension": ".py",
|
| 868 |
+
"mimetype": "text/x-python",
|
| 869 |
+
"name": "python",
|
| 870 |
+
"nbconvert_exporter": "python",
|
| 871 |
+
"pygments_lexer": "ipython3",
|
| 872 |
+
"version": "3.10.6"
|
| 873 |
+
}
|
| 874 |
+
},
|
| 875 |
+
"nbformat": 4,
|
| 876 |
+
"nbformat_minor": 1
|
| 877 |
+
}
|
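The training cell above fails with a CUDA out-of-memory error because --block_size 2048 does not fit on the 8 GiB GPU. A minimal sketch of one possible workaround, following the hint in the error message itself: set PYTORCH_CUDA_ALLOC_CONF and retry with a smaller block size. The values max_split_size_mb:128 and block_size 512 are assumptions, not settings used in this repository.

import os
import subprocess

# Allocator hint suggested by the OOM message (the 128 MiB value is an assumption).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# Same run_clm.py invocation as in the notebook, but with a smaller --block_size
# (512 is an assumption; 2048 was the value that ran out of memory).
subprocess.run(
    [
        "python3", "run_clm.py",
        "--model_name_or_path", "sberbank-ai/rugpt3small_based_on_gpt2",
        "--train_file", "train.txt",
        "--validation_file", "valid.txt",
        "--per_device_train_batch_size", "1",
        "--per_device_eval_batch_size", "1",
        "--block_size", "512",
        "--dataset_config_name", "plain_text",
        "--do_train",
        "--do_eval",
        "--output_dir", "models/essays2",
    ],
    check=True,
)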
Generate_text.py
ADDED
|
@@ -0,0 +1,28 @@
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
np.random.seed(17)
|
| 5 |
+
torch.manual_seed(17)
|
| 6 |
+
|
| 7 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
| 8 |
+
|
| 9 |
+
def load_tokenizer_and_model(model_name_or_path, device):
|
| 10 |
+
return GPT2Tokenizer.from_pretrained(model_name_or_path), GPT2LMHeadModel.from_pretrained(model_name_or_path).to(device)
|
| 11 |
+
|
| 12 |
+
def generate(
|
| 13 |
+
model, tok, text, device,
|
| 14 |
+
do_sample=True, max_length=200, repetition_penalty=5.0,
|
| 15 |
+
top_k=5, top_p=0.95, temperature=1,
|
| 16 |
+
num_beams=None,
|
| 17 |
+
no_repeat_ngram_size=3
|
| 18 |
+
):
|
| 19 |
+
input_ids = tok.encode(text, return_tensors="pt").to(device)
|
| 20 |
+
out = model.generate(
|
| 21 |
+
input_ids.to(device),
|
| 22 |
+
max_length=max_length,
|
| 23 |
+
repetition_penalty=repetition_penalty,
|
| 24 |
+
do_sample=do_sample,
|
| 25 |
+
top_k=top_k, top_p=top_p, temperature=temperature,
|
| 26 |
+
num_beams=num_beams, no_repeat_ngram_size=no_repeat_ngram_size
|
| 27 |
+
)
|
| 28 |
+
return list(map(tok.decode, out))
|
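A minimal usage sketch for the two helpers above. The prompt string mirrors the one used in the notebook; the CUDA/CPU fallback is an assumption.

import torch
from Generate_text import load_tokenizer_and_model, generate

device = "cuda" if torch.cuda.is_available() else "cpu"  # fallback choice is an assumption
tok, model = load_tokenizer_and_model("models/essays", device)

prompt = "<s>Тема: «В чем смысл жизни?»\nСочинение: "
# generate() returns a list of decoded strings; print the first sample.
print(generate(model, tok, prompt, device, max_length=200)[0])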
README.md
ADDED
|
@@ -0,0 +1,17 @@
|
| 1 |
+
---
|
| 2 |
+
title: Transformer Homework
|
| 3 |
+
emoji: 😻
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
sdk_version: 1.17.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: openrail
|
| 11 |
+
duplicated_from: IkIzma/Transformer_Homework
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
This is the https://huggingface.co/ai-forever/rugpt3small_based_on_gpt2 model, fine-tuned on the essay dataset at https://www.dropbox.com/s/oa3v9c7g9bp40xw. The full code, as well as the fine-tuned model and the dataset, is included in this repository.
|
| 15 |
+
The main fine-tuning script is run_clm.py, which is launched from the FineTune.ipynb notebook.
|
| 16 |
+
|
| 17 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,30 @@
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from Generate_text import generate, load_tokenizer_and_model
|
| 3 |
+
|
| 4 |
+
device = "cpu"
|
| 5 |
+
tok, model = load_tokenizer_and_model("models/essays", device)
|
| 6 |
+
|
| 7 |
+
st.markdown("## Генератор сочинений")
|
| 8 |
+
st.markdown("Это приложение представляет из себя примитивный генератор сочинений. В качестве основной модели был взят трансформер RuGPT от Сбера https://huggingface.co/ai-forever/rugpt3small_based_on_gpt2. После этого на том же HF был скачан датасет с сочинениями на русском языке, который использовался для finetuning'а (см. ReadMe)")
|
| 9 |
+
|
| 10 |
+
st.markdown("Для работы с приложением укажите максимальный размер сочинения, который хотите получить. Далее напишите тему. Для генерации сочинения длиной в 200 токенов потребуется подождать 6-7 минут")
|
| 11 |
+
st.markdown("Модель rugpt3small_based_on_gpt2 маленькая и датасет сочинений, на котором она дообучалась, тоже маленький, поэтому не стоит рассчитывать, что получившееся сочинение будет хорошо отражать выбранную тему). Но, по крайней мере, текст получается осмысленным и его стиль вполне соответствует стилю написания сочинений.")
|
| 12 |
+
st.markdown("Чем меньше макисмальная длина, тем хуже получается итоговый результат. Относительно неплохие сочинения получаются при длине 500+, но по времни такая генерация занимает десятки минут")
|
| 13 |
+
with st.columns(3)[1]:
|
| 14 |
+
st.markdown("<img width=200px src='https://ps-static.cdn-tinkoff.ru/static/ai-pushkin/portrait-2021-12-10-12-24-56.png'>", unsafe_allow_html=True)
|
| 15 |
+
# ^-- you can show the user text, images, and a limited subset of HTML here, just like in Jupyter
|
| 16 |
+
|
| 17 |
+
max_len = st.text_area("Максимальная длина")
|
| 18 |
+
text = st.text_area("Тема сочинения")
|
| 19 |
+
# ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
|
| 20 |
+
thesis = "<s>Тема:" + "«" + text + "»." + "\nСочинение: "
|
| 21 |
+
|
| 22 |
+
if max_len != "":
|
| 23 |
+
begin_index = len(thesis)
|
| 24 |
+
max_len = int(max_len)
|
| 25 |
+
generated = generate(model, tok, thesis, device, max_length=max_len, num_beams=10)
|
| 26 |
+
end_index = (generated[0]).find("</s>")
|
| 27 |
+
|
| 28 |
+
st.markdown(f"Сочинение:\n")
|
| 29 |
+
st.markdown(generated[0][begin_index:end_index])
|
| 30 |
+
# show the model's output to the user
|
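One fragile spot in app.py is the bare int(max_len) cast: a non-numeric value in the text field raises a ValueError. A small sketch (not the app's code) of how a numeric widget could replace the manual cast; the bounds and default shown are assumptions.

import streamlit as st

# st.number_input rejects non-numeric input and enforces bounds, so no int() cast is needed.
max_len = st.number_input("Максимальная длина", min_value=50, max_value=1000, value=200, step=50)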
essay_dataset/train.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
essay_dataset/valid.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/essays/README.md
ADDED
|
@@ -0,0 +1,55 @@
|
| 1 |
+
---
|
| 2 |
+
tags:
|
| 3 |
+
- generated_from_trainer
|
| 4 |
+
metrics:
|
| 5 |
+
- accuracy
|
| 6 |
+
model-index:
|
| 7 |
+
- name: essays
|
| 8 |
+
results: []
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
|
| 12 |
+
should probably proofread and complete it, then remove this comment. -->
|
| 13 |
+
|
| 14 |
+
# essays
|
| 15 |
+
|
| 16 |
+
This model is a fine-tuned version of [sberbank-ai/rugpt3small_based_on_gpt2](https://huggingface.co/sberbank-ai/rugpt3small_based_on_gpt2) on an unknown dataset.
|
| 17 |
+
It achieves the following results on the evaluation set:
|
| 18 |
+
- Loss: 2.7676
|
| 19 |
+
- Accuracy: 0.4758
|
| 20 |
+
|
| 21 |
+
## Model description
|
| 22 |
+
|
| 23 |
+
More information needed
|
| 24 |
+
|
| 25 |
+
## Intended uses & limitations
|
| 26 |
+
|
| 27 |
+
More information needed
|
| 28 |
+
|
| 29 |
+
## Training and evaluation data
|
| 30 |
+
|
| 31 |
+
More information needed
|
| 32 |
+
|
| 33 |
+
## Training procedure
|
| 34 |
+
|
| 35 |
+
### Training hyperparameters
|
| 36 |
+
|
| 37 |
+
The following hyperparameters were used during training:
|
| 38 |
+
- learning_rate: 5e-05
|
| 39 |
+
- train_batch_size: 1
|
| 40 |
+
- eval_batch_size: 1
|
| 41 |
+
- seed: 42
|
| 42 |
+
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
|
| 43 |
+
- lr_scheduler_type: linear
|
| 44 |
+
- num_epochs: 3.0
|
| 45 |
+
|
| 46 |
+
### Training results
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
### Framework versions
|
| 51 |
+
|
| 52 |
+
- Transformers 4.29.0.dev0
|
| 53 |
+
- Pytorch 1.9.1+cu111
|
| 54 |
+
- Datasets 2.11.0
|
| 55 |
+
- Tokenizers 0.13.3
|
models/essays/added_tokens.json
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
{
|
| 2 |
+
"<|endoftext|>": 50257
|
| 3 |
+
}
|
models/essays/all_results.json
ADDED
|
@@ -0,0 +1,15 @@
|
| 1 |
+
{
|
| 2 |
+
"epoch": 3.0,
|
| 3 |
+
"eval_accuracy": 0.4757639906638441,
|
| 4 |
+
"eval_loss": 2.767648458480835,
|
| 5 |
+
"eval_runtime": 1.0081,
|
| 6 |
+
"eval_samples": 9,
|
| 7 |
+
"eval_samples_per_second": 8.928,
|
| 8 |
+
"eval_steps_per_second": 8.928,
|
| 9 |
+
"perplexity": 15.921150708373386,
|
| 10 |
+
"train_loss": 2.7605772156646284,
|
| 11 |
+
"train_runtime": 87.21,
|
| 12 |
+
"train_samples": 92,
|
| 13 |
+
"train_samples_per_second": 3.165,
|
| 14 |
+
"train_steps_per_second": 3.165
|
| 15 |
+
}
|
models/essays/config.json
ADDED
|
@@ -0,0 +1,34 @@
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "sberbank-ai/rugpt3small_based_on_gpt2",
|
| 3 |
+
"activation_function": "gelu_new",
|
| 4 |
+
"architectures": [
|
| 5 |
+
"GPT2LMHeadModel"
|
| 6 |
+
],
|
| 7 |
+
"attn_pdrop": 0.1,
|
| 8 |
+
"bos_token_id": 50256,
|
| 9 |
+
"embd_pdrop": 0.1,
|
| 10 |
+
"eos_token_id": 50256,
|
| 11 |
+
"gradient_checkpointing": false,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"layer_norm_epsilon": 1e-05,
|
| 14 |
+
"model_type": "gpt2",
|
| 15 |
+
"n_ctx": 2048,
|
| 16 |
+
"n_embd": 768,
|
| 17 |
+
"n_head": 12,
|
| 18 |
+
"n_inner": null,
|
| 19 |
+
"n_layer": 12,
|
| 20 |
+
"n_positions": 2048,
|
| 21 |
+
"reorder_and_upcast_attn": false,
|
| 22 |
+
"resid_pdrop": 0.1,
|
| 23 |
+
"scale_attn_by_inverse_layer_idx": false,
|
| 24 |
+
"scale_attn_weights": true,
|
| 25 |
+
"summary_activation": null,
|
| 26 |
+
"summary_first_dropout": 0.1,
|
| 27 |
+
"summary_proj_to_labels": true,
|
| 28 |
+
"summary_type": "cls_index",
|
| 29 |
+
"summary_use_proj": true,
|
| 30 |
+
"torch_dtype": "float32",
|
| 31 |
+
"transformers_version": "4.29.0.dev0",
|
| 32 |
+
"use_cache": true,
|
| 33 |
+
"vocab_size": 50264
|
| 34 |
+
}
|
models/essays/eval_results.json
ADDED
|
@@ -0,0 +1,10 @@
|
| 1 |
+
{
|
| 2 |
+
"epoch": 3.0,
|
| 3 |
+
"eval_accuracy": 0.4757639906638441,
|
| 4 |
+
"eval_loss": 2.767648458480835,
|
| 5 |
+
"eval_runtime": 1.0081,
|
| 6 |
+
"eval_samples": 9,
|
| 7 |
+
"eval_samples_per_second": 8.928,
|
| 8 |
+
"eval_steps_per_second": 8.928,
|
| 9 |
+
"perplexity": 15.921150708373386
|
| 10 |
+
}
|
models/essays/generation_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 50256,
|
| 4 |
+
"eos_token_id": 50256,
|
| 5 |
+
"transformers_version": "4.29.0.dev0"
|
| 6 |
+
}
|
models/essays/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/essays/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c8f145dc6ca97655187b357bc9b9484ecde90f445362bae691bc45c2faf0d587
|
| 3 |
+
size 551312489
|
models/essays/special_tokens_map.json
ADDED
|
@@ -0,0 +1,23 @@
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<|endoftext|>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": true,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "<|endoftext|>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": true,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"unk_token": {
|
| 17 |
+
"content": "<|endoftext|>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": true,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
}
|
| 23 |
+
}
|
models/essays/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/essays/tokenizer_config.json
ADDED
|
@@ -0,0 +1,33 @@
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_prefix_space": false,
|
| 4 |
+
"bos_token": {
|
| 5 |
+
"__type": "AddedToken",
|
| 6 |
+
"content": "<|endoftext|>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": true,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false
|
| 11 |
+
},
|
| 12 |
+
"clean_up_tokenization_spaces": true,
|
| 13 |
+
"eos_token": {
|
| 14 |
+
"__type": "AddedToken",
|
| 15 |
+
"content": "<|endoftext|>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": true,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false
|
| 20 |
+
},
|
| 21 |
+
"errors": "replace",
|
| 22 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 23 |
+
"pad_token": null,
|
| 24 |
+
"tokenizer_class": "GPT2Tokenizer",
|
| 25 |
+
"unk_token": {
|
| 26 |
+
"__type": "AddedToken",
|
| 27 |
+
"content": "<|endoftext|>",
|
| 28 |
+
"lstrip": false,
|
| 29 |
+
"normalized": true,
|
| 30 |
+
"rstrip": false,
|
| 31 |
+
"single_word": false
|
| 32 |
+
}
|
| 33 |
+
}
|
models/essays/train_results.json
ADDED
|
@@ -0,0 +1,8 @@
|
| 1 |
+
{
|
| 2 |
+
"epoch": 3.0,
|
| 3 |
+
"train_loss": 2.7605772156646284,
|
| 4 |
+
"train_runtime": 87.21,
|
| 5 |
+
"train_samples": 92,
|
| 6 |
+
"train_samples_per_second": 3.165,
|
| 7 |
+
"train_steps_per_second": 3.165
|
| 8 |
+
}
|
models/essays/trainer_state.json
ADDED
|
@@ -0,0 +1,25 @@
|
| 1 |
+
{
|
| 2 |
+
"best_metric": null,
|
| 3 |
+
"best_model_checkpoint": null,
|
| 4 |
+
"epoch": 3.0,
|
| 5 |
+
"global_step": 276,
|
| 6 |
+
"is_hyper_param_search": false,
|
| 7 |
+
"is_local_process_zero": true,
|
| 8 |
+
"is_world_process_zero": true,
|
| 9 |
+
"log_history": [
|
| 10 |
+
{
|
| 11 |
+
"epoch": 3.0,
|
| 12 |
+
"step": 276,
|
| 13 |
+
"total_flos": 288466403328000.0,
|
| 14 |
+
"train_loss": 2.7605772156646284,
|
| 15 |
+
"train_runtime": 87.21,
|
| 16 |
+
"train_samples_per_second": 3.165,
|
| 17 |
+
"train_steps_per_second": 3.165
|
| 18 |
+
}
|
| 19 |
+
],
|
| 20 |
+
"max_steps": 276,
|
| 21 |
+
"num_train_epochs": 3,
|
| 22 |
+
"total_flos": 288466403328000.0,
|
| 23 |
+
"trial_name": null,
|
| 24 |
+
"trial_params": null
|
| 25 |
+
}
|
models/essays/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ffc90d13a3c1e167dedc3e673a03634beb1998f6a5a2653d6c61b9729498f4d7
|
| 3 |
+
size 3567
|
models/essays/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,2 @@
|
| 1 |
+
transformers
|
| 2 |
+
torch
|
run_clm.py
ADDED
|
@@ -0,0 +1,635 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""
|
| 17 |
+
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
|
| 18 |
+
|
| 19 |
+
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
| 20 |
+
https://huggingface.co/models?filter=text-generation
|
| 21 |
+
"""
|
| 22 |
+
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
|
| 23 |
+
|
| 24 |
+
import logging
|
| 25 |
+
import math
|
| 26 |
+
import os
|
| 27 |
+
import sys
|
| 28 |
+
from dataclasses import dataclass, field
|
| 29 |
+
from itertools import chain
|
| 30 |
+
from typing import Optional
|
| 31 |
+
|
| 32 |
+
import datasets
|
| 33 |
+
import evaluate
|
| 34 |
+
import torch
|
| 35 |
+
from datasets import load_dataset
|
| 36 |
+
|
| 37 |
+
import transformers
|
| 38 |
+
from transformers import (
|
| 39 |
+
CONFIG_MAPPING,
|
| 40 |
+
MODEL_FOR_CAUSAL_LM_MAPPING,
|
| 41 |
+
AutoConfig,
|
| 42 |
+
AutoModelForCausalLM,
|
| 43 |
+
AutoTokenizer,
|
| 44 |
+
HfArgumentParser,
|
| 45 |
+
Trainer,
|
| 46 |
+
TrainingArguments,
|
| 47 |
+
default_data_collator,
|
| 48 |
+
is_torch_tpu_available,
|
| 49 |
+
set_seed,
|
| 50 |
+
)
|
| 51 |
+
from transformers.testing_utils import CaptureLogger
|
| 52 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 53 |
+
from transformers.utils import check_min_version, send_example_telemetry
|
| 54 |
+
from transformers.utils.versions import require_version
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
| 58 |
+
check_min_version("4.29.0.dev0")
|
| 59 |
+
|
| 60 |
+
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
| 61 |
+
|
| 62 |
+
logger = logging.getLogger(__name__)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
|
| 66 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@dataclass
|
| 70 |
+
class ModelArguments:
|
| 71 |
+
"""
|
| 72 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
model_name_or_path: Optional[str] = field(
|
| 76 |
+
default=None,
|
| 77 |
+
metadata={
|
| 78 |
+
"help": (
|
| 79 |
+
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
| 80 |
+
)
|
| 81 |
+
},
|
| 82 |
+
)
|
| 83 |
+
model_type: Optional[str] = field(
|
| 84 |
+
default=None,
|
| 85 |
+
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
| 86 |
+
)
|
| 87 |
+
config_overrides: Optional[str] = field(
|
| 88 |
+
default=None,
|
| 89 |
+
metadata={
|
| 90 |
+
"help": (
|
| 91 |
+
"Override some existing default config settings when a model is trained from scratch. Example: "
|
| 92 |
+
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
| 93 |
+
)
|
| 94 |
+
},
|
| 95 |
+
)
|
| 96 |
+
config_name: Optional[str] = field(
|
| 97 |
+
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
| 98 |
+
)
|
| 99 |
+
tokenizer_name: Optional[str] = field(
|
| 100 |
+
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
| 101 |
+
)
|
| 102 |
+
cache_dir: Optional[str] = field(
|
| 103 |
+
default=None,
|
| 104 |
+
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
| 105 |
+
)
|
| 106 |
+
use_fast_tokenizer: bool = field(
|
| 107 |
+
default=True,
|
| 108 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
| 109 |
+
)
|
| 110 |
+
model_revision: str = field(
|
| 111 |
+
default="main",
|
| 112 |
+
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
| 113 |
+
)
|
| 114 |
+
use_auth_token: bool = field(
|
| 115 |
+
default=False,
|
| 116 |
+
metadata={
|
| 117 |
+
"help": (
|
| 118 |
+
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
|
| 119 |
+
"with private models)."
|
| 120 |
+
)
|
| 121 |
+
},
|
| 122 |
+
)
|
| 123 |
+
torch_dtype: Optional[str] = field(
|
| 124 |
+
default=None,
|
| 125 |
+
metadata={
|
| 126 |
+
"help": (
|
| 127 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
| 128 |
+
"dtype will be automatically derived from the model's weights."
|
| 129 |
+
),
|
| 130 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
| 131 |
+
},
|
| 132 |
+
)
|
| 133 |
+
low_cpu_mem_usage: bool = field(
|
| 134 |
+
default=False,
|
| 135 |
+
metadata={
|
| 136 |
+
"help": (
|
| 137 |
+
"It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded."
|
| 138 |
+
"set True will benefit LLM loading time and RAM consumption."
|
| 139 |
+
)
|
| 140 |
+
},
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
def __post_init__(self):
|
| 144 |
+
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
|
| 145 |
+
raise ValueError(
|
| 146 |
+
"--config_overrides can't be used in combination with --config_name or --model_name_or_path"
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@dataclass
|
| 151 |
+
class DataTrainingArguments:
|
| 152 |
+
"""
|
| 153 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
dataset_name: Optional[str] = field(
|
| 157 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
| 158 |
+
)
|
| 159 |
+
dataset_config_name: Optional[str] = field(
|
| 160 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
| 161 |
+
)
|
| 162 |
+
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
| 163 |
+
validation_file: Optional[str] = field(
|
| 164 |
+
default=None,
|
| 165 |
+
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
| 166 |
+
)
|
| 167 |
+
max_train_samples: Optional[int] = field(
|
| 168 |
+
default=None,
|
| 169 |
+
metadata={
|
| 170 |
+
"help": (
|
| 171 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 172 |
+
"value if set."
|
| 173 |
+
)
|
| 174 |
+
},
|
| 175 |
+
)
|
| 176 |
+
max_eval_samples: Optional[int] = field(
|
| 177 |
+
default=None,
|
| 178 |
+
metadata={
|
| 179 |
+
"help": (
|
| 180 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
| 181 |
+
"value if set."
|
| 182 |
+
)
|
| 183 |
+
},
|
| 184 |
+
)
|
| 185 |
+
streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
|
| 186 |
+
block_size: Optional[int] = field(
|
| 187 |
+
default=None,
|
| 188 |
+
metadata={
|
| 189 |
+
"help": (
|
| 190 |
+
"Optional input sequence length after tokenization. "
|
| 191 |
+
"The training dataset will be truncated in block of this size for training. "
|
| 192 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
| 193 |
+
)
|
| 194 |
+
},
|
| 195 |
+
)
|
| 196 |
+
overwrite_cache: bool = field(
|
| 197 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
| 198 |
+
)
|
| 199 |
+
validation_split_percentage: Optional[int] = field(
|
| 200 |
+
default=5,
|
| 201 |
+
metadata={
|
| 202 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
| 203 |
+
},
|
| 204 |
+
)
|
| 205 |
+
preprocessing_num_workers: Optional[int] = field(
|
| 206 |
+
default=None,
|
| 207 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
| 208 |
+
)
|
| 209 |
+
keep_linebreaks: bool = field(
|
| 210 |
+
default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
def __post_init__(self):
|
| 214 |
+
if self.streaming:
|
| 215 |
+
require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
|
| 216 |
+
|
| 217 |
+
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
| 218 |
+
raise ValueError("Need either a dataset name or a training/validation file.")
|
| 219 |
+
else:
|
| 220 |
+
if self.train_file is not None:
|
| 221 |
+
extension = self.train_file.split(".")[-1]
|
| 222 |
+
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
|
| 223 |
+
if self.validation_file is not None:
|
| 224 |
+
extension = self.validation_file.split(".")[-1]
|
| 225 |
+
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def main():
|
| 229 |
+
# See all possible arguments in src/transformers/training_args.py
|
| 230 |
+
# or by passing the --help flag to this script.
|
| 231 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
| 232 |
+
|
| 233 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
| 234 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
| 235 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
| 236 |
+
# let's parse it to get our arguments.
|
| 237 |
+
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
| 238 |
+
else:
|
| 239 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 240 |
+
|
| 241 |
+
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
| 242 |
+
# information sent is the one passed as arguments along with your Python/PyTorch versions.
|
| 243 |
+
send_example_telemetry("run_clm", model_args, data_args)
|
| 244 |
+
|
| 245 |
+
# Setup logging
|
| 246 |
+
logging.basicConfig(
|
| 247 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 248 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 249 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
if training_args.should_log:
|
| 253 |
+
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
| 254 |
+
transformers.utils.logging.set_verbosity_info()
|
| 255 |
+
|
| 256 |
+
log_level = training_args.get_process_log_level()
|
| 257 |
+
logger.setLevel(log_level)
|
| 258 |
+
datasets.utils.logging.set_verbosity(log_level)
|
| 259 |
+
transformers.utils.logging.set_verbosity(log_level)
|
| 260 |
+
transformers.utils.logging.enable_default_handler()
|
| 261 |
+
transformers.utils.logging.enable_explicit_format()
|
| 262 |
+
|
| 263 |
+
# Log on each process the small summary:
|
| 264 |
+
logger.warning(
|
| 265 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
| 266 |
+
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
| 267 |
+
)
|
| 268 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
| 269 |
+
|
| 270 |
+
# Detecting last checkpoint.
|
| 271 |
+
last_checkpoint = None
|
| 272 |
+
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
| 273 |
+
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
| 274 |
+
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
| 275 |
+
raise ValueError(
|
| 276 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
| 277 |
+
"Use --overwrite_output_dir to overcome."
|
| 278 |
+
)
|
| 279 |
+
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
| 280 |
+
logger.info(
|
| 281 |
+
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
| 282 |
+
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# Set seed before initializing model.
|
| 286 |
+
set_seed(training_args.seed)
|
| 287 |
+
|
| 288 |
+
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
| 289 |
+
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
| 290 |
+
# (the dataset will be downloaded automatically from the datasets Hub).
|
| 291 |
+
#
|
| 292 |
+
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
| 293 |
+
# 'text' is found. You can easily tweak this behavior (see below).
|
| 294 |
+
#
|
| 295 |
+
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
| 296 |
+
# download the dataset.
|
| 297 |
+
if data_args.dataset_name is not None:
|
| 298 |
+
# Downloading and loading a dataset from the hub.
|
| 299 |
+
raw_datasets = load_dataset(
|
| 300 |
+
data_args.dataset_name,
|
| 301 |
+
data_args.dataset_config_name,
|
| 302 |
+
cache_dir=model_args.cache_dir,
|
| 303 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
| 304 |
+
streaming=data_args.streaming,
|
| 305 |
+
)
|
| 306 |
+
if "validation" not in raw_datasets.keys():
|
| 307 |
+
raw_datasets["validation"] = load_dataset(
|
| 308 |
+
data_args.dataset_name,
|
| 309 |
+
data_args.dataset_config_name,
|
| 310 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
| 311 |
+
cache_dir=model_args.cache_dir,
|
| 312 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
| 313 |
+
streaming=data_args.streaming,
|
| 314 |
+
)
|
| 315 |
+
raw_datasets["train"] = load_dataset(
|
| 316 |
+
data_args.dataset_name,
|
| 317 |
+
data_args.dataset_config_name,
|
| 318 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
| 319 |
+
cache_dir=model_args.cache_dir,
|
| 320 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
| 321 |
+
streaming=data_args.streaming,
|
| 322 |
+
)
|
| 323 |
+
else:
|
| 324 |
+
data_files = {}
|
| 325 |
+
dataset_args = {}
|
| 326 |
+
if data_args.train_file is not None:
|
| 327 |
+
data_files["train"] = data_args.train_file
|
| 328 |
+
if data_args.validation_file is not None:
|
| 329 |
+
data_files["validation"] = data_args.validation_file
|
| 330 |
+
extension = (
|
| 331 |
+
data_args.train_file.split(".")[-1]
|
| 332 |
+
if data_args.train_file is not None
|
| 333 |
+
else data_args.validation_file.split(".")[-1]
|
| 334 |
+
)
|
| 335 |
+
if extension == "txt":
|
| 336 |
+
extension = "text"
|
| 337 |
+
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
|
| 338 |
+
raw_datasets = load_dataset(
|
| 339 |
+
extension,
|
| 340 |
+
data_files=data_files,
|
| 341 |
+
cache_dir=model_args.cache_dir,
|
| 342 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
| 343 |
+
**dataset_args,
|
| 344 |
+
)
|
| 345 |
+
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
| 346 |
+
if "validation" not in raw_datasets.keys():
|
| 347 |
+
raw_datasets["validation"] = load_dataset(
|
| 348 |
+
extension,
|
| 349 |
+
data_files=data_files,
|
| 350 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
| 351 |
+
cache_dir=model_args.cache_dir,
|
| 352 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
| 353 |
+
**dataset_args,
|
| 354 |
+
)
|
| 355 |
+
raw_datasets["train"] = load_dataset(
|
| 356 |
+
extension,
|
| 357 |
+
data_files=data_files,
|
| 358 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
| 359 |
+
cache_dir=model_args.cache_dir,
|
| 360 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
| 361 |
+
**dataset_args,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
| 365 |
+
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
| 366 |
+
|
| 367 |
+
# Load pretrained model and tokenizer
|
| 368 |
+
#
|
| 369 |
+
# Distributed training:
|
| 370 |
+
# The .from_pretrained methods guarantee that only one local process can concurrently
|
| 371 |
+
# download model & vocab.
|
| 372 |
+
|
| 373 |
+
config_kwargs = {
|
| 374 |
+
"cache_dir": model_args.cache_dir,
|
| 375 |
+
"revision": model_args.model_revision,
|
| 376 |
+
"use_auth_token": True if model_args.use_auth_token else None,
|
| 377 |
+
}
|
| 378 |
+
if model_args.config_name:
|
| 379 |
+
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
|
| 380 |
+
elif model_args.model_name_or_path:
|
| 381 |
+
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
| 382 |
+
else:
|
| 383 |
+
config = CONFIG_MAPPING[model_args.model_type]()
|
| 384 |
+
logger.warning("You are instantiating a new config instance from scratch.")
|
| 385 |
+
if model_args.config_overrides is not None:
|
| 386 |
+
logger.info(f"Overriding config: {model_args.config_overrides}")
|
| 387 |
+
config.update_from_string(model_args.config_overrides)
|
| 388 |
+
logger.info(f"New config: {config}")
|
| 389 |
+
|
| 390 |
+
tokenizer_kwargs = {
|
| 391 |
+
"cache_dir": model_args.cache_dir,
|
| 392 |
+
"use_fast": model_args.use_fast_tokenizer,
|
| 393 |
+
"revision": model_args.model_revision,
|
| 394 |
+
"use_auth_token": True if model_args.use_auth_token else None,
|
| 395 |
+
}
|
| 396 |
+
if model_args.tokenizer_name:
|
| 397 |
+
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
|
| 398 |
+
elif model_args.model_name_or_path:
|
| 399 |
+
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
|
| 400 |
+
else:
|
| 401 |
+
raise ValueError(
|
| 402 |
+
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
| 403 |
+
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
if model_args.model_name_or_path:
|
| 407 |
+
torch_dtype = (
|
| 408 |
+
model_args.torch_dtype
|
| 409 |
+
if model_args.torch_dtype in ["auto", None]
|
| 410 |
+
else getattr(torch, model_args.torch_dtype)
|
| 411 |
+
)
|
| 412 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 413 |
+
model_args.model_name_or_path,
|
| 414 |
+
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
| 415 |
+
config=config,
|
| 416 |
+
cache_dir=model_args.cache_dir,
|
| 417 |
+
revision=model_args.model_revision,
|
| 418 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
| 419 |
+
torch_dtype=torch_dtype,
|
| 420 |
+
low_cpu_mem_usage=model_args.low_cpu_mem_usage,
|
| 421 |
+
)
|
| 422 |
+
else:
|
| 423 |
+
model = AutoModelForCausalLM.from_config(config)
|
| 424 |
+
n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
|
| 425 |
+
logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
|
| 426 |
+
|
| 427 |
+
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
|
| 428 |
+
# on a small vocab and want a smaller embedding size, remove this test.
|
| 429 |
+
embedding_size = model.get_input_embeddings().weight.shape[0]
|
| 430 |
+
if len(tokenizer) > embedding_size:
|
| 431 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 432 |
+
|
| 433 |
+
# Preprocessing the datasets.
|
| 434 |
+
# First we tokenize all the texts.
|
| 435 |
+
if training_args.do_train:
|
| 436 |
+
column_names = list(raw_datasets["train"].features)
|
| 437 |
+
else:
|
| 438 |
+
column_names = list(raw_datasets["validation"].features)
|
| 439 |
+
text_column_name = "text" if "text" in column_names else column_names[0]
|
| 440 |
+
|
| 441 |
+
# since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
|
| 442 |
+
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
|
| 443 |
+
|
| 444 |
+
def tokenize_function(examples):
|
| 445 |
+
with CaptureLogger(tok_logger) as cl:
|
| 446 |
+
output = tokenizer(examples[text_column_name])
|
| 447 |
+
# clm input could be much much longer than block_size
|
| 448 |
+
if "Token indices sequence length is longer than the" in cl.out:
|
| 449 |
+
tok_logger.warning(
|
| 450 |
+
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
|
| 451 |
+
" before being passed to the model."
|
| 452 |
+
)
|
| 453 |
+
return output
|
| 454 |
+
|
| 455 |
+
with training_args.main_process_first(desc="dataset map tokenization"):
|
| 456 |
+
if not data_args.streaming:
|
| 457 |
+
tokenized_datasets = raw_datasets.map(
|
| 458 |
+
tokenize_function,
|
| 459 |
+
batched=True,
|
| 460 |
+
num_proc=data_args.preprocessing_num_workers,
|
| 461 |
+
remove_columns=column_names,
|
| 462 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
| 463 |
+
desc="Running tokenizer on dataset",
|
| 464 |
+
)
|
| 465 |
+
else:
|
| 466 |
+
tokenized_datasets = raw_datasets.map(
|
| 467 |
+
tokenize_function,
|
| 468 |
+
batched=True,
|
| 469 |
+
remove_columns=column_names,
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
if data_args.block_size is None:
|
| 473 |
+
block_size = tokenizer.model_max_length
|
| 474 |
+
if block_size > 1024:
|
| 475 |
+
logger.warning(
|
| 476 |
+
"The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
|
| 477 |
+
" of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
|
| 478 |
+
" override this default with `--block_size xxx`."
|
| 479 |
+
)
|
| 480 |
+
block_size = 1024
|
| 481 |
+
else:
|
| 482 |
+
if data_args.block_size > tokenizer.model_max_length:
|
| 483 |
+
logger.warning(
|
| 484 |
+
f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
|
| 485 |
+
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
|
| 486 |
+
)
|
| 487 |
+
block_size = min(data_args.block_size, tokenizer.model_max_length)
|
| 488 |
+
|
| 489 |
+
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
|
| 490 |
+
def group_texts(examples):
|
| 491 |
+
# Concatenate all texts.
|
| 492 |
+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
|
| 493 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
| 494 |
+
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
| 495 |
+
# customize this part to your needs.
|
| 496 |
+
if total_length >= block_size:
|
| 497 |
+
total_length = (total_length // block_size) * block_size
|
| 498 |
+
# Split by chunks of max_len.
|
| 499 |
+
result = {
|
| 500 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
| 501 |
+
for k, t in concatenated_examples.items()
|
| 502 |
+
}
|
| 503 |
+
result["labels"] = result["input_ids"].copy()
|
| 504 |
+
return result
|
| 505 |
+
|
| 506 |
+
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
|
| 507 |
+
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
|
| 508 |
+
# to preprocess.
|
| 509 |
+
#
|
| 510 |
+
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
| 511 |
+
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
| 512 |
+
|
| 513 |
+
with training_args.main_process_first(desc="grouping texts together"):
|
| 514 |
+
if not data_args.streaming:
|
| 515 |
+
lm_datasets = tokenized_datasets.map(
|
| 516 |
+
group_texts,
|
| 517 |
+
batched=True,
|
| 518 |
+
num_proc=data_args.preprocessing_num_workers,
|
| 519 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
| 520 |
+
desc=f"Grouping texts in chunks of {block_size}",
|
| 521 |
+
)
|
| 522 |
+
else:
|
| 523 |
+
lm_datasets = tokenized_datasets.map(
|
| 524 |
+
group_texts,
|
| 525 |
+
batched=True,
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
if training_args.do_train:
|
| 529 |
+
if "train" not in tokenized_datasets:
|
| 530 |
+
raise ValueError("--do_train requires a train dataset")
|
| 531 |
+
train_dataset = lm_datasets["train"]
|
| 532 |
+
if data_args.max_train_samples is not None:
|
| 533 |
+
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
| 534 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
| 535 |
+
|
| 536 |
+
if training_args.do_eval:
|
| 537 |
+
if "validation" not in tokenized_datasets:
|
| 538 |
+
raise ValueError("--do_eval requires a validation dataset")
|
| 539 |
+
eval_dataset = lm_datasets["validation"]
|
| 540 |
+
if data_args.max_eval_samples is not None:
|
| 541 |
+
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
| 542 |
+
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
| 543 |
+
|
| 544 |
+
def preprocess_logits_for_metrics(logits, labels):
|
| 545 |
+
if isinstance(logits, tuple):
|
| 546 |
+
# Depending on the model and config, logits may contain extra tensors,
|
| 547 |
+
# like past_key_values, but logits always come first
|
| 548 |
+
logits = logits[0]
|
| 549 |
+
return logits.argmax(dim=-1)
|
| 550 |
+
|
| 551 |
+
metric = evaluate.load("accuracy")
|
| 552 |
+
|
| 553 |
+
def compute_metrics(eval_preds):
|
| 554 |
+
preds, labels = eval_preds
|
| 555 |
+
# preds have the same shape as the labels, after the argmax(-1) has been calculated
|
| 556 |
+
# by preprocess_logits_for_metrics but we need to shift the labels
|
| 557 |
+
labels = labels[:, 1:].reshape(-1)
|
| 558 |
+
preds = preds[:, :-1].reshape(-1)
|
| 559 |
+
return metric.compute(predictions=preds, references=labels)
|
| 560 |
+
|
| 561 |
+
# Initialize our Trainer
|
| 562 |
+
trainer = Trainer(
|
| 563 |
+
model=model,
|
| 564 |
+
args=training_args,
|
| 565 |
+
train_dataset=train_dataset if training_args.do_train else None,
|
| 566 |
+
eval_dataset=eval_dataset if training_args.do_eval else None,
|
| 567 |
+
tokenizer=tokenizer,
|
| 568 |
+
# Data collator will default to DataCollatorWithPadding, so we change it.
|
| 569 |
+
data_collator=default_data_collator,
|
| 570 |
+
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
|
| 571 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics
|
| 572 |
+
if training_args.do_eval and not is_torch_tpu_available()
|
| 573 |
+
else None,
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
# Training
|
| 577 |
+
if training_args.do_train:
|
| 578 |
+
checkpoint = None
|
| 579 |
+
if training_args.resume_from_checkpoint is not None:
|
| 580 |
+
checkpoint = training_args.resume_from_checkpoint
|
| 581 |
+
elif last_checkpoint is not None:
|
| 582 |
+
checkpoint = last_checkpoint
|
| 583 |
+
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
| 584 |
+
trainer.save_model() # Saves the tokenizer too for easy upload
|
| 585 |
+
|
| 586 |
+
metrics = train_result.metrics
|
| 587 |
+
|
| 588 |
+
max_train_samples = (
|
| 589 |
+
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
| 590 |
+
)
|
| 591 |
+
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
| 592 |
+
|
| 593 |
+
trainer.log_metrics("train", metrics)
|
| 594 |
+
trainer.save_metrics("train", metrics)
|
| 595 |
+
trainer.save_state()
|
| 596 |
+
|
| 597 |
+
# Evaluation
|
| 598 |
+
if training_args.do_eval:
|
| 599 |
+
logger.info("*** Evaluate ***")
|
| 600 |
+
|
| 601 |
+
metrics = trainer.evaluate()
|
| 602 |
+
|
| 603 |
+
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
| 604 |
+
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
| 605 |
+
try:
|
| 606 |
+
perplexity = math.exp(metrics["eval_loss"])
|
| 607 |
+
except OverflowError:
|
| 608 |
+
perplexity = float("inf")
|
| 609 |
+
metrics["perplexity"] = perplexity
|
| 610 |
+
|
| 611 |
+
trainer.log_metrics("eval", metrics)
|
| 612 |
+
trainer.save_metrics("eval", metrics)
|
| 613 |
+
|
| 614 |
+
kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
|
| 615 |
+
if data_args.dataset_name is not None:
|
| 616 |
+
kwargs["dataset_tags"] = data_args.dataset_name
|
| 617 |
+
if data_args.dataset_config_name is not None:
|
| 618 |
+
kwargs["dataset_args"] = data_args.dataset_config_name
|
| 619 |
+
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
|
| 620 |
+
else:
|
| 621 |
+
kwargs["dataset"] = data_args.dataset_name
|
| 622 |
+
|
| 623 |
+
if training_args.push_to_hub:
|
| 624 |
+
trainer.push_to_hub(**kwargs)
|
| 625 |
+
else:
|
| 626 |
+
trainer.create_model_card(**kwargs)
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def _mp_fn(index):
|
| 630 |
+
# For xla_spawn (TPUs)
|
| 631 |
+
main()
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
if __name__ == "__main__":
|
| 635 |
+
main()
|
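run_clm.py reports perplexity as the exponential of the evaluation loss (see the try/except around math.exp above). A quick check against the values stored in models/essays/eval_results.json:

import math

eval_loss = 2.767648458480835  # "eval_loss" from models/essays/eval_results.json
print(math.exp(eval_loss))     # ≈ 15.9212, matching the stored "perplexity" value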