added SentEval
Browse files- SentEval/.gitignore +16 -0
- SentEval/LICENSE +30 -0
- SentEval/README.md +249 -0
- SentEval/examples/bow.py +112 -0
- SentEval/examples/gensen.py +74 -0
- SentEval/examples/googleuse.py +67 -0
- SentEval/examples/infersent.py +76 -0
- SentEval/examples/models.py +265 -0
- SentEval/examples/skipthought.py +61 -0
- SentEval/senteval/__init__.py +10 -0
- SentEval/senteval/binary.py +92 -0
- SentEval/senteval/engine.py +129 -0
- SentEval/senteval/mrpc.py +104 -0
- SentEval/senteval/probing.py +171 -0
- SentEval/senteval/rank.py +108 -0
- SentEval/senteval/sick.py +216 -0
- SentEval/senteval/snli.py +113 -0
- SentEval/senteval/sst.py +96 -0
- SentEval/senteval/sts.py +231 -0
- SentEval/senteval/tools/__init__.py +0 -0
- SentEval/senteval/tools/classifier.py +202 -0
- SentEval/senteval/tools/ranking.py +359 -0
- SentEval/senteval/tools/relatedness.py +134 -0
- SentEval/senteval/tools/validation.py +246 -0
- SentEval/senteval/trec.py +89 -0
- SentEval/senteval/utils.py +95 -0
- SentEval/setup.py +21 -0
- data/._data_csv_default-6b8a73dfc1f26733_0.0.0_6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317.lock +0 -0
- data/csv/default-6b8a73dfc1f26733/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317.incomplete_info.lock +0 -0
- data/csv/default-6b8a73dfc1f26733/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-e43d857791056f6f.arrow +3 -0
- data/csv/default-6b8a73dfc1f26733/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/csv-train.arrow +3 -0
- data/csv/default-6b8a73dfc1f26733/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/dataset_info.json +1 -0
- data/csv/default-6b8a73dfc1f26733/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317_builder.lock +0 -0
- result/sup-simcse-nb-bert-base/config.json +31 -0
- result/sup-simcse-nb-bert-base/pytorch_model.bin +3 -0
- result/sup-simcse-nb-bert-base/special_tokens_map.json +1 -0
- result/sup-simcse-nb-bert-base/tokenizer_config.json +1 -0
- result/sup-simcse-nb-bert-base/train_results.txt +3 -0
- result/sup-simcse-nb-bert-base/trainer_state.json +22 -0
- result/sup-simcse-nb-bert-base/training_args.bin +3 -0
- result/sup-simcse-nb-bert-base/vocab.txt +3 -0
- runs/Oct21_13-13-50_t1v-n-d0240692-w-0/1666358047.7059593/events.out.tfevents.1666358047.t1v-n-d0240692-w-0.37317.1 +3 -0
- runs/Oct21_13-13-50_t1v-n-d0240692-w-0/events.out.tfevents.1666358047.t1v-n-d0240692-w-0.37317.0 +3 -0
- runs/Oct21_13-17-52_t1v-n-d0240692-w-0/1666358281.579476/events.out.tfevents.1666358281.t1v-n-d0240692-w-0.41386.1 +3 -0
- runs/Oct21_13-17-52_t1v-n-d0240692-w-0/events.out.tfevents.1666358281.t1v-n-d0240692-w-0.41386.0 +3 -0
SentEval/.gitignore
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SentEval data and .pyc files
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# python
|
| 6 |
+
__pycache__/
|
| 7 |
+
*.py[cod]
|
| 8 |
+
*$py.class
|
| 9 |
+
|
| 10 |
+
# log files
|
| 11 |
+
*.log
|
| 12 |
+
*.txt
|
| 13 |
+
|
| 14 |
+
# data files
|
| 15 |
+
data/senteval_data*
|
| 16 |
+
data/downstream/
|
SentEval/LICENSE
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
BSD License
|
| 2 |
+
|
| 3 |
+
For SentEval software
|
| 4 |
+
|
| 5 |
+
Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
|
| 6 |
+
|
| 7 |
+
Redistribution and use in source and binary forms, with or without modification,
|
| 8 |
+
are permitted provided that the following conditions are met:
|
| 9 |
+
|
| 10 |
+
* Redistributions of source code must retain the above copyright notice, this
|
| 11 |
+
list of conditions and the following disclaimer.
|
| 12 |
+
|
| 13 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
| 14 |
+
this list of conditions and the following disclaimer in the documentation
|
| 15 |
+
and/or other materials provided with the distribution.
|
| 16 |
+
|
| 17 |
+
* Neither the name Facebook nor the names of its contributors may be used to
|
| 18 |
+
endorse or promote products derived from this software without specific
|
| 19 |
+
prior written permission.
|
| 20 |
+
|
| 21 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
| 22 |
+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
| 23 |
+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 24 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
| 25 |
+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
| 26 |
+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
| 27 |
+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
| 28 |
+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 29 |
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
| 30 |
+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
SentEval/README.md
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Our modification to SentEval:
|
| 2 |
+
|
| 3 |
+
1. Add the `all` setting to all STS tasks.
|
| 4 |
+
2. Change STS-B and SICK-R to not use an additional regressor.
|
| 5 |
+
|
| 6 |
+
# SentEval: evaluation toolkit for sentence embeddings
|
| 7 |
+
|
| 8 |
+
SentEval is a library for evaluating the quality of sentence embeddings. We assess their generalization power by using them as features on a broad and diverse set of "transfer" tasks. **SentEval currently includes 17 downstream tasks**. We also include a suite of **10 probing tasks** which evaluate what linguistic properties are encoded in sentence embeddings. Our goal is to ease the study and the development of general-purpose fixed-size sentence representations.
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
**(04/22) SentEval new tasks: Added probing tasks for evaluating what linguistic properties are encoded in sentence embeddings**
|
| 12 |
+
|
| 13 |
+
**(10/04) SentEval example scripts for three sentence encoders: [SkipThought-LN](https://github.com/ryankiros/layer-norm#skip-thoughts)/[GenSen](https://github.com/Maluuba/gensen)/[Google-USE](https://tfhub.dev/google/universal-sentence-encoder/1)**
|
| 14 |
+
|
| 15 |
+
## Dependencies
|
| 16 |
+
|
| 17 |
+
This code is written in python. The dependencies are:
|
| 18 |
+
|
| 19 |
+
* Python 2/3 with [NumPy](http://www.numpy.org/)/[SciPy](http://www.scipy.org/)
|
| 20 |
+
* [Pytorch](http://pytorch.org/)>=0.4
|
| 21 |
+
* [scikit-learn](http://scikit-learn.org/stable/index.html)>=0.18.0
|
| 22 |
+
|
| 23 |
+
## Transfer tasks
|
| 24 |
+
|
| 25 |
+
### Downstream tasks
|
| 26 |
+
SentEval allows you to evaluate your sentence embeddings as features for the following *downstream* tasks:
|
| 27 |
+
|
| 28 |
+
| Task | Type | #train | #test | needs_train | set_classifier |
|
| 29 |
+
|---------- |------------------------------ |-----------:|----------:|:-----------:|:----------:|
|
| 30 |
+
| [MR](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | movie review | 11k | 11k | 1 | 1 |
|
| 31 |
+
| [CR](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | product review | 4k | 4k | 1 | 1 |
|
| 32 |
+
| [SUBJ](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | subjectivity status | 10k | 10k | 1 | 1 |
|
| 33 |
+
| [MPQA](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | opinion-polarity | 11k | 11k | 1 | 1 |
|
| 34 |
+
| [SST](https://nlp.stanford.edu/sentiment/index.html) | binary sentiment analysis | 67k | 1.8k | 1 | 1 |
|
| 35 |
+
| **[SST](https://nlp.stanford.edu/sentiment/index.html)** | **fine-grained sentiment analysis** | 8.5k | 2.2k | 1 | 1 |
|
| 36 |
+
| [TREC](http://cogcomp.cs.illinois.edu/Data/QA/QC/) | question-type classification | 6k | 0.5k | 1 | 1 |
|
| 37 |
+
| [SICK-E](http://clic.cimec.unitn.it/composes/sick.html) | natural language inference | 4.5k | 4.9k | 1 | 1 |
|
| 38 |
+
| [SNLI](https://nlp.stanford.edu/projects/snli/) | natural language inference | 550k | 9.8k | 1 | 1 |
|
| 39 |
+
| [MRPC](https://aclweb.org/aclwiki/Paraphrase_Identification_(State_of_the_art)) | paraphrase detection | 4.1k | 1.7k | 1 | 1 |
|
| 40 |
+
| [STS 2012](https://www.cs.york.ac.uk/semeval-2012/task6/) | semantic textual similarity | N/A | 3.1k | 0 | 0 |
|
| 41 |
+
| [STS 2013](http://ixa2.si.ehu.es/sts/) | semantic textual similarity | N/A | 1.5k | 0 | 0 |
|
| 42 |
+
| [STS 2014](http://alt.qcri.org/semeval2014/task10/) | semantic textual similarity | N/A | 3.7k | 0 | 0 |
|
| 43 |
+
| [STS 2015](http://alt.qcri.org/semeval2015/task2/) | semantic textual similarity | N/A | 8.5k | 0 | 0 |
|
| 44 |
+
| [STS 2016](http://alt.qcri.org/semeval2016/task1/) | semantic textual similarity | N/A | 9.2k | 0 | 0 |
|
| 45 |
+
| [STS B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark#Results) | semantic textual similarity | 5.7k | 1.4k | 1 | 0 |
|
| 46 |
+
| [SICK-R](http://clic.cimec.unitn.it/composes/sick.html) | semantic textual similarity | 4.5k | 4.9k | 1 | 0 |
|
| 47 |
+
| [COCO](http://mscoco.org/) | image-caption retrieval | 567k | 5*1k | 1 | 0 |
|
| 48 |
+
|
| 49 |
+
where **needs_train** means a model with parameters is learned on top of the sentence embeddings, and **set_classifier** means you can define the parameters of the classifier in the case of a classification task (see below).
|
| 50 |
+
|
| 51 |
+
Note: COCO comes with ResNet-101 2048d image embeddings. [More details on the tasks.](https://arxiv.org/pdf/1705.02364.pdf)
|
| 52 |
+
|
| 53 |
+
### Probing tasks
|
| 54 |
+
SentEval also includes a series of [*probing* tasks](https://github.com/facebookresearch/SentEval/tree/master/data/probing) to evaluate what linguistic properties are encoded in your sentence embeddings:
|
| 55 |
+
|
| 56 |
+
| Task | Type | #train | #test | needs_train | set_classifier |
|
| 57 |
+
|---------- |------------------------------ |-----------:|----------:|:-----------:|:----------:|
|
| 58 |
+
| [SentLen](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Length prediction | 100k | 10k | 1 | 1 |
|
| 59 |
+
| [WC](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Word Content analysis | 100k | 10k | 1 | 1 |
|
| 60 |
+
| [TreeDepth](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Tree depth prediction | 100k | 10k | 1 | 1 |
|
| 61 |
+
| [TopConst](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Top Constituents prediction | 100k | 10k | 1 | 1 |
|
| 62 |
+
| [BShift](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Word order analysis | 100k | 10k | 1 | 1 |
|
| 63 |
+
| [Tense](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Verb tense prediction | 100k | 10k | 1 | 1 |
|
| 64 |
+
| [SubjNum](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Subject number prediction | 100k | 10k | 1 | 1 |
|
| 65 |
+
| [ObjNum](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Object number prediction | 100k | 10k | 1 | 1 |
|
| 66 |
+
| [SOMO](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Semantic odd man out | 100k | 10k | 1 | 1 |
|
| 67 |
+
| [CoordInv](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Coordination Inversion | 100k | 10k | 1 | 1 |
|
| 68 |
+
|
| 69 |
+
## Download datasets
|
| 70 |
+
To get all the transfer tasks datasets, run (in data/downstream/):
|
| 71 |
+
```bash
|
| 72 |
+
./get_transfer_data.bash
|
| 73 |
+
```
|
| 74 |
+
This will automatically download and preprocess the downstream datasets, and store them in data/downstream (warning: for MacOS users, you may have to use p7zip instead of unzip). The probing tasks are already in data/probing by default.
|
| 75 |
+
|
| 76 |
+
## How to use SentEval: examples
|
| 77 |
+
|
| 78 |
+
### examples/bow.py
|
| 79 |
+
|
| 80 |
+
In examples/bow.py, we evaluate the quality of the average of word embeddings.
|
| 81 |
+
|
| 82 |
+
To download state-of-the-art GloVe and fastText embeddings:
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
curl -Lo glove.840B.300d.zip http://nlp.stanford.edu/data/glove.840B.300d.zip
|
| 86 |
+
curl -Lo crawl-300d-2M.vec.zip https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M.vec.zip
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
To reproduce the results for bag-of-vectors, run (in examples/):
|
| 90 |
+
```bash
|
| 91 |
+
python bow.py
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
As required by SentEval, this script implements two functions: **prepare** (optional) and **batcher** (required) that turn text sentences into sentence embeddings. Then SentEval takes care of the evaluation on the transfer tasks using the embeddings as features.
|
| 95 |
+
|
| 96 |
+
### examples/infersent.py
|
| 97 |
+
|
| 98 |
+
To get the **[InferSent](https://www.github.com/facebookresearch/InferSent)** model and reproduce our results, download our best models and run infersent.py (in examples/):
|
| 99 |
+
```bash
|
| 100 |
+
curl -Lo examples/infersent1.pkl https://dl.fbaipublicfiles.com/senteval/infersent/infersent1.pkl
|
| 101 |
+
curl -Lo examples/infersent2.pkl https://dl.fbaipublicfiles.com/senteval/infersent/infersent2.pkl
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
### examples/skipthought.py - examples/gensen.py - examples/googleuse.py
|
| 105 |
+
|
| 106 |
+
We also provide example scripts for three other encoders:
|
| 107 |
+
|
| 108 |
+
* [SkipThought with Layer-Normalization](https://github.com/ryankiros/layer-norm#skip-thoughts) in Theano
|
| 109 |
+
* [GenSen encoder](https://github.com/Maluuba/gensen) in Pytorch
|
| 110 |
+
* [Google encoder](https://tfhub.dev/google/universal-sentence-encoder/1) in TensorFlow
|
| 111 |
+
|
| 112 |
+
Note that for SkipThought and GenSen, following the steps of the associated githubs is necessary.
|
| 113 |
+
The Google encoder script should work as-is.
|
| 114 |
+
|
| 115 |
+
## How to use SentEval
|
| 116 |
+
|
| 117 |
+
To evaluate your sentence embeddings, SentEval requires that you implement two functions:
|
| 118 |
+
|
| 119 |
+
1. **prepare** (sees the whole dataset of each task and can thus construct the word vocabulary, the dictionary of word vectors etc)
|
| 120 |
+
2. **batcher** (transforms a batch of text sentences into sentence embeddings)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
### 1.) prepare(params, samples) (optional)
|
| 124 |
+
|
| 125 |
+
*batcher* only sees one batch at a time while the *samples* argument of *prepare* contains all the sentences of a task.
|
| 126 |
+
|
| 127 |
+
```
|
| 128 |
+
prepare(params, samples)
|
| 129 |
+
```
|
| 130 |
+
* *params*: senteval parameters.
|
| 131 |
+
* *samples*: list of all sentences from the transfer task.
|
| 132 |
+
* *output*: No output. Arguments stored in "params" can further be used by *batcher*.
|
| 133 |
+
|
| 134 |
+
*Example*: in bow.py, prepare is used to build the vocabulary of words and construct the *params.word_vec* dictionary of word vectors.
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
### 2.) batcher(params, batch)
|
| 138 |
+
```
|
| 139 |
+
batcher(params, batch)
|
| 140 |
+
```
|
| 141 |
+
* *params*: senteval parameters.
|
| 142 |
+
* *batch*: numpy array of text sentences (of size params.batch_size)
|
| 143 |
+
* *output*: numpy array of sentence embeddings (of size params.batch_size)
|
| 144 |
+
|
| 145 |
+
*Example*: in bow.py, batcher is used to compute the mean of the word vectors for each sentence in the batch using params.word_vec. Use your own encoder in that function to encode sentences.
|
| 146 |
+
|
| 147 |
+
### 3.) evaluation on transfer tasks
|
| 148 |
+
|
| 149 |
+
After having implemented the batcher and prepare functions for your own sentence encoder,
|
| 150 |
+
|
| 151 |
+
1) to perform the actual evaluation, first import senteval and set its parameters:
|
| 152 |
+
```python
|
| 153 |
+
import senteval
|
| 154 |
+
params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10}
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
2) (optional) set the parameters of the classifier (when applicable):
|
| 158 |
+
```python
|
| 159 |
+
params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
|
| 160 |
+
'tenacity': 5, 'epoch_size': 4}
|
| 161 |
+
```
|
| 162 |
+
You can choose **nhid=0** (Logistic Regression) or **nhid>0** (MLP) and define the parameters for training.
|
| 163 |
+
|
| 164 |
+
3) Create an instance of the class SE:
|
| 165 |
+
```python
|
| 166 |
+
se = senteval.engine.SE(params, batcher, prepare)
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
4) define the set of transfer tasks and run the evaluation:
|
| 170 |
+
```python
|
| 171 |
+
transfer_tasks = ['MR', 'SICKEntailment', 'STS14', 'STSBenchmark']
|
| 172 |
+
results = se.eval(transfer_tasks)
|
| 173 |
+
```
|
| 174 |
+
The current list of available tasks is:
|
| 175 |
+
```python
|
| 176 |
+
['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 'SNLI',
|
| 177 |
+
'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 'ImageCaptionRetrieval',
|
| 178 |
+
'STS12', 'STS13', 'STS14', 'STS15', 'STS16',
|
| 179 |
+
'Length', 'WordContent', 'Depth', 'TopConstituents','BigramShift', 'Tense',
|
| 180 |
+
'SubjNumber', 'ObjNumber', 'OddManOut', 'CoordinationInversion']
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
## SentEval parameters
|
| 184 |
+
Global parameters of SentEval:
|
| 185 |
+
```bash
|
| 186 |
+
# senteval parameters
|
| 187 |
+
task_path # path to SentEval datasets (required)
|
| 188 |
+
seed # seed
|
| 189 |
+
usepytorch # use cuda-pytorch (else scikit-learn) where possible
|
| 190 |
+
kfold # k-fold validation for MR/CR/SUBJ/MPQA.
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
Parameters of the classifier:
|
| 194 |
+
```bash
|
| 195 |
+
nhid: # number of hidden units (0: Logistic Regression, >0: MLP); Default nonlinearity: Tanh
|
| 196 |
+
optim: # optimizer ("sgd,lr=0.1", "adam", "rmsprop" ..)
|
| 197 |
+
tenacity: # how many times dev acc does not increase before training stops
|
| 198 |
+
epoch_size: # each epoch corresponds to epoch_size pass on the train set
|
| 199 |
+
max_epoch: # max number of epochs
|
| 200 |
+
dropout: # dropout for MLP
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
Note that to get a proxy of the results while **dramatically reducing computation time**,
|
| 204 |
+
we suggest the **prototyping config**:
|
| 205 |
+
```python
|
| 206 |
+
params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
|
| 207 |
+
params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
|
| 208 |
+
'tenacity': 3, 'epoch_size': 2}
|
| 209 |
+
```
|
| 210 |
+
which will result in a 5 times speedup for classification tasks.
|
| 211 |
+
|
| 212 |
+
To produce results that are **comparable to the literature**, use the **default config**:
|
| 213 |
+
```python
|
| 214 |
+
params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10}
|
| 215 |
+
params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
|
| 216 |
+
'tenacity': 5, 'epoch_size': 4}
|
| 217 |
+
```
|
| 218 |
+
which takes longer but will produce better and comparable results.
|
| 219 |
+
|
| 220 |
+
For probing tasks, we used an MLP with a Sigmoid nonlinearity and tuned the nhid (in [50, 100, 200]) and dropout (in [0.0, 0.1, 0.2]) on the dev set.
|
| 221 |
+
|
| 222 |
+
## References
|
| 223 |
+
|
| 224 |
+
Please consider citing [[1]](https://arxiv.org/abs/1803.05449) if using this code for evaluating sentence embedding methods.
|
| 225 |
+
|
| 226 |
+
### SentEval: An Evaluation Toolkit for Universal Sentence Representations
|
| 227 |
+
|
| 228 |
+
[1] A. Conneau, D. Kiela, [*SentEval: An Evaluation Toolkit for Universal Sentence Representations*](https://arxiv.org/abs/1803.05449)
|
| 229 |
+
|
| 230 |
+
```
|
| 231 |
+
@article{conneau2018senteval,
|
| 232 |
+
title={SentEval: An Evaluation Toolkit for Universal Sentence Representations},
|
| 233 |
+
author={Conneau, Alexis and Kiela, Douwe},
|
| 234 |
+
journal={arXiv preprint arXiv:1803.05449},
|
| 235 |
+
year={2018}
|
| 236 |
+
}
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
Contact: [aconneau@fb.com](mailto:aconneau@fb.com), [dkiela@fb.com](mailto:dkiela@fb.com)
|
| 240 |
+
|
| 241 |
+
### Related work
|
| 242 |
+
* [J. R Kiros, Y. Zhu, R. Salakhutdinov, R. S. Zemel, A. Torralba, R. Urtasun, S. Fidler - SkipThought Vectors, NIPS 2015](https://arxiv.org/abs/1506.06726)
|
| 243 |
+
* [S. Arora, Y. Liang, T. Ma - A Simple but Tough-to-Beat Baseline for Sentence Embeddings, ICLR 2017](https://openreview.net/pdf?id=SyK00v5xx)
|
| 244 |
+
* [Y. Adi, E. Kermany, Y. Belinkov, O. Lavi, Y. Goldberg - Fine-grained analysis of sentence embeddings using auxiliary prediction tasks, ICLR 2017](https://arxiv.org/abs/1608.04207)
|
| 245 |
+
* [A. Conneau, D. Kiela, L. Barrault, H. Schwenk, A. Bordes - Supervised Learning of Universal Sentence Representations from Natural Language Inference Data, EMNLP 2017](https://arxiv.org/abs/1705.02364)
|
| 246 |
+
* [S. Subramanian, A. Trischler, Y. Bengio, C. J Pal - Learning General Purpose Distributed Sentence Representations via Large Scale Multi-task Learning, ICLR 2018](https://arxiv.org/abs/1804.00079)
|
| 247 |
+
* [A. Nie, E. D. Bennett, N. D. Goodman - DisSent: Sentence Representation Learning from Explicit Discourse Relations, 2018](https://arxiv.org/abs/1710.04334)
|
| 248 |
+
* [D. Cer, Y. Yang, S. Kong, N. Hua, N. Limtiaco, R. St. John, N. Constant, M. Guajardo-Cespedes, S. Yuan, C. Tar, Y. Sung, B. Strope, R. Kurzweil - Universal Sentence Encoder, 2018](https://arxiv.org/abs/1803.11175)
|
| 249 |
+
* [A. Conneau, G. Kruszewski, G. Lample, L. Barrault, M. Baroni - What you can cram into a single vector: Probing sentence embeddings for linguistic properties, ACL 2018](https://arxiv.org/abs/1805.01070)
|
SentEval/examples/bow.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 9 |
+
|
| 10 |
+
import sys
|
| 11 |
+
import io
|
| 12 |
+
import numpy as np
|
| 13 |
+
import logging
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Set PATHs
|
| 17 |
+
PATH_TO_SENTEVAL = '../'
|
| 18 |
+
PATH_TO_DATA = '../data'
|
| 19 |
+
# PATH_TO_VEC = 'glove/glove.840B.300d.txt'
|
| 20 |
+
PATH_TO_VEC = 'fasttext/crawl-300d-2M.vec'
|
| 21 |
+
|
| 22 |
+
# import SentEval
|
| 23 |
+
sys.path.insert(0, PATH_TO_SENTEVAL)
|
| 24 |
+
import senteval
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Create dictionary
|
| 28 |
+
def create_dictionary(sentences, threshold=0):
    """Build a frequency-ordered vocabulary from tokenized sentences.

    Args:
        sentences: iterable of sentences, each an iterable of word tokens.
        threshold: when positive, words seen fewer than ``threshold`` times
            are dropped before the special symbols are added.

    Returns:
        A ``(id2word, word2id)`` pair: ``id2word`` lists the words by
        decreasing count and ``word2id`` maps each word to its position in
        that list.
    """
    counts = {}
    for sentence in sentences:
        for token in sentence:
            counts[token] = counts.get(token, 0) + 1

    if threshold > 0:
        counts = {tok: n for tok, n in counts.items() if n >= threshold}

    # The special symbols receive huge pseudo-counts so that they always
    # sort ahead of every real word, in the order <s>, </s>, <p>.
    counts['<s>'] = 1e9 + 4
    counts['</s>'] = 1e9 + 3
    counts['<p>'] = 1e9 + 2

    # Negated count as key: descending order, ties keep insertion order.
    ranked = sorted(counts.items(), key=lambda item: -item[1])
    id2word = [tok for tok, _ in ranked]
    word2id = {tok: idx for idx, tok in enumerate(id2word)}
    return id2word, word2id
|
| 52 |
+
|
| 53 |
+
# Get word vectors from vocabulary (glove, word2vec, fasttext ..)
|
| 54 |
+
def get_wordvec(path_to_vec, word2id):
    """Load the vectors of in-vocabulary words from a text embedding file.

    Every data line is expected to hold a word followed by its
    space-separated vector components. A leading ``"<count> <dim>"`` header
    line (two integers, as written by word2vec and fastText ``.vec`` files)
    is skipped so it is never mistaken for a word, and lines that carry no
    vector components (e.g. blank lines) are ignored.

    Args:
        path_to_vec: path to the UTF-8 encoded text embedding file.
        word2id: container of vocabulary words; only membership is tested.

    Returns:
        dict mapping each vocabulary word found in the file to the 1-D numpy
        array parsed from the remainder of its line.
    """
    word_vec = {}

    with io.open(path_to_vec, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f):
            word, sep, vec = line.partition(' ')
            if not sep:
                # No space at all: blank or malformed line, nothing to parse.
                continue
            if lineno == 0:
                # word2vec / fastText files start with "<n_words> <dim>".
                header = line.split()
                if len(header) == 2 and all(tok.isdigit() for tok in header):
                    continue
            if word in word2id:
                word_vec[word] = np.fromstring(vec, sep=' ')

    logging.info('Found {0} words with word vectors, out of {1} words'.format(
        len(word_vec), len(word2id)))
    return word_vec
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# SentEval prepare and batcher
|
| 70 |
+
def prepare(params, samples):
    """Build the task vocabulary and load its word vectors onto ``params``.

    Sets ``params.word2id`` (from ``create_dictionary``), ``params.word_vec``
    (from ``get_wordvec`` reading ``PATH_TO_VEC``) and ``params.wvec_dim``
    (fixed to 300). Returns nothing.
    """
    params.word2id = create_dictionary(samples)[1]
    params.word_vec = get_wordvec(PATH_TO_VEC, params.word2id)
    params.wvec_dim = 300
|
| 75 |
+
|
| 76 |
+
def batcher(params, batch):
    """Embed each sentence as the mean of its known word vectors.

    Args:
        params: object exposing ``word_vec`` (word -> vector mapping) and
            ``wvec_dim`` (size of the fallback zero vector).
        batch: sequence of tokenized sentences (lists of words).

    Returns:
        Array built with ``np.vstack`` holding one row per sentence.
    """
    rows = []
    for sent in batch:
        # An empty sentence is treated as the single token '.'.
        tokens = sent if sent != [] else ['.']
        known = [params.word_vec[tok] for tok in tokens
                 if tok in params.word_vec]
        if not known:
            # No token has a vector: fall back to an all-zero embedding.
            known = [np.zeros(params.wvec_dim)]
        rows.append(np.mean(known, 0))

    return np.vstack(rows)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Set params for SentEval
|
| 96 |
+
params_senteval = {
    'task_path': PATH_TO_DATA,
    'usepytorch': True,
    'kfold': 5,
    'classifier': {
        'nhid': 0,
        'optim': 'rmsprop',
        'batch_size': 128,
        'tenacity': 3,
        'epoch_size': 2,
    },
}

# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

if __name__ == "__main__":
    # Evaluate the bag-of-vectors encoder on the listed tasks and print
    # whatever the engine returns.
    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = [
        'STS12', 'STS13', 'STS14', 'STS15', 'STS16',
        'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
        'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
        'Length', 'WordContent', 'Depth', 'TopConstituents',
        'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
        'OddManOut', 'CoordinationInversion',
    ]
    results = se.eval(transfer_tasks)
    print(results)
|
SentEval/examples/gensen.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
Clone GenSen repo here: https://github.com/Maluuba/gensen.git
|
| 10 |
+
And follow instructions for loading the model used in batcher
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
import logging
|
| 17 |
+
# import GenSen package
|
| 18 |
+
from gensen import GenSen, GenSenSingle
|
| 19 |
+
|
| 20 |
+
# Set PATHs
|
| 21 |
+
PATH_TO_SENTEVAL = '../'
|
| 22 |
+
PATH_TO_DATA = '../data'
|
| 23 |
+
|
| 24 |
+
# import SentEval
|
| 25 |
+
sys.path.insert(0, PATH_TO_SENTEVAL)
|
| 26 |
+
import senteval
|
| 27 |
+
|
| 28 |
+
# SentEval prepare and batcher
|
| 29 |
+
def prepare(params, samples):
    """Do nothing: this example needs no per-task preparation."""
    return None
|
| 31 |
+
|
| 32 |
+
def batcher(params, batch):
    """Encode a batch of tokenized sentences with the GenSen encoder.

    Args:
        params: mapping whose ``'gensen'`` entry is the encoder object; its
            ``get_representation`` method is called on the joined sentences.
        batch: sequence of tokenized sentences (lists of words).

    Returns:
        The second element of the pair returned by ``get_representation``.
    """
    # Re-join tokens into plain strings; an empty sentence becomes '.'.
    sentences = [' '.join(sent) if sent != [] else '.' for sent in batch]
    # Use the encoder carried by params: the module name ``gensen`` is never
    # bound in this script (only GenSen / GenSenSingle are imported).
    _, reps_h_t = params['gensen'].get_representation(
        sentences, pool='last', return_numpy=True, tokenize=True
    )
    return reps_h_t
|
| 39 |
+
|
| 40 |
+
# Load GenSen model
|
| 41 |
+
gensen_1 = GenSenSingle(
    model_folder='../data/models',
    filename_prefix='nli_large_bothskip',
    pretrained_emb='../data/embedding/glove.840B.300d.h5'
)
gensen_2 = GenSenSingle(
    model_folder='../data/models',
    filename_prefix='nli_large_bothskip_parse',
    pretrained_emb='../data/embedding/glove.840B.300d.h5'
)
# Combined encoder built from the two single models. No sentences are
# encoded at import time: a module-level get_representation call here would
# reference the unbound names ``gensen`` and ``sentences``.
gensen_encoder = GenSen(gensen_1, gensen_2)
|
| 55 |
+
|
| 56 |
+
# Set params for SentEval
|
| 57 |
+
params_senteval = {
    'task_path': PATH_TO_DATA,
    'usepytorch': True,
    'kfold': 5,
    'classifier': {
        'nhid': 0,
        'optim': 'rmsprop',
        'batch_size': 128,
        'tenacity': 3,
        'epoch_size': 2,
    },
    # The encoder travels with the SentEval parameters.
    'gensen': gensen_encoder,
}

# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

if __name__ == "__main__":
    # Evaluate the GenSen encoder on the listed tasks and print whatever
    # the engine returns.
    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = [
        'STS12', 'STS13', 'STS14', 'STS15', 'STS16',
        'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
        'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
        'Length', 'WordContent', 'Depth', 'TopConstituents',
        'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
        'OddManOut', 'CoordinationInversion',
    ]
    results = se.eval(transfer_tasks)
    print(results)
|
SentEval/examples/googleuse.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
from __future__ import absolute_import, division
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
import logging
|
| 13 |
+
import tensorflow as tf
|
| 14 |
+
import tensorflow_hub as hub
|
| 15 |
+
tf.logging.set_verbosity(0)
|
| 16 |
+
|
| 17 |
+
# Set PATHs
|
| 18 |
+
PATH_TO_SENTEVAL = '../'
|
| 19 |
+
PATH_TO_DATA = '../data'
|
| 20 |
+
|
| 21 |
+
# import SentEval
|
| 22 |
+
sys.path.insert(0, PATH_TO_SENTEVAL)
|
| 23 |
+
import senteval
|
| 24 |
+
|
| 25 |
+
# tensorflow session
|
| 26 |
+
session = tf.Session()
|
| 27 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
| 28 |
+
|
| 29 |
+
# SentEval prepare and batcher
|
| 30 |
+
def prepare(params, samples):
    """No-op preparation hook: this encoder needs no per-task setup.

    Both arguments are accepted for interface compatibility and ignored.
    """
    return None
|
| 32 |
+
|
| 33 |
+
def batcher(params, batch):
    """Embed a batch of tokenized sentences with ``params['google_use']``.

    Each token list is joined with single spaces; an empty token list is
    replaced by '.'. The resulting strings are passed to the callable stored
    under ``params['google_use']`` and its result is returned unchanged.
    """
    texts = []
    for tokens in batch:
        if tokens == []:
            texts.append('.')
        else:
            texts.append(' '.join(tokens))
    return params['google_use'](texts)
|
| 37 |
+
|
| 38 |
+
def make_embed_fn(module):
    """Build a callable that embeds sentences with a TF-Hub module.

    A fresh graph is created holding a string placeholder, the hub module
    loaded from ``module`` and its output tensor; a MonitoredSession is then
    opened while that graph is the default one.

    Returns:
        A one-argument function that feeds its argument to the placeholder
        and returns the result of running the embedding tensor.
    """
    graph = tf.Graph()
    with graph.as_default():
        text_input = tf.placeholder(tf.string)
        hub_module = hub.Module(module)
        encoded = hub_module(text_input)
        sess = tf.train.MonitoredSession()

    def embed_fn(x):
        return sess.run(encoded, {text_input: x})

    return embed_fn
|
| 45 |
+
|
| 46 |
+
# Start TF session and load Google Universal Sentence Encoder
encoder = make_embed_fn("https://tfhub.dev/google/universal-sentence-encoder-large/2")

# Set params for SentEval; the encoder is carried in params so that
# batcher() can look it up under 'google_use'.
params_senteval = {
    'task_path': PATH_TO_DATA,
    'usepytorch': True,
    'kfold': 5,
    'classifier': {
        'nhid': 0,
        'optim': 'rmsprop',
        'batch_size': 128,
        'tenacity': 3,
        'epoch_size': 2,
    },
    'google_use': encoder,
}

# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

if __name__ == "__main__":
    se = senteval.engine.SE(params_senteval, batcher, prepare)
    # Task names grouped by family; the order is the evaluation order.
    transfer_tasks = [
        'STS12', 'STS13', 'STS14', 'STS15', 'STS16',
        'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
        'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
        'Length', 'WordContent', 'Depth', 'TopConstituents',
        'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
        'OddManOut', 'CoordinationInversion',
    ]
    results = se.eval(transfer_tasks)
    print(results)
|
SentEval/examples/infersent.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
InferSent models. See https://github.com/facebookresearch/InferSent.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 13 |
+
|
| 14 |
+
import sys
|
| 15 |
+
import os
|
| 16 |
+
import torch
|
| 17 |
+
import logging
|
| 18 |
+
|
| 19 |
+
# get models.py from InferSent repo
|
| 20 |
+
from models import InferSent
|
| 21 |
+
|
| 22 |
+
# Set PATHs
|
| 23 |
+
PATH_SENTEVAL = '../'
|
| 24 |
+
PATH_TO_DATA = '../data'
|
| 25 |
+
PATH_TO_W2V = 'PATH/TO/glove.840B.300d.txt' # or crawl-300d-2M.vec for V2
|
| 26 |
+
MODEL_PATH = 'infersent1.pkl'
|
| 27 |
+
V = 1 # version of InferSent
|
| 28 |
+
|
| 29 |
+
assert os.path.isfile(MODEL_PATH) and os.path.isfile(PATH_TO_W2V), \
|
| 30 |
+
'Set MODEL and GloVe PATHs'
|
| 31 |
+
|
| 32 |
+
# import senteval
|
| 33 |
+
sys.path.insert(0, PATH_SENTEVAL)
|
| 34 |
+
import senteval
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def prepare(params, samples):
    """Build the InferSent vocabulary from the task's sentences.

    ``samples`` is a list of token lists; each is joined with single spaces
    and handed to ``params.infersent.build_vocab`` with tokenization
    disabled. Nothing is returned.
    """
    joined = [' '.join(tokens) for tokens in samples]
    params.infersent.build_vocab(joined, tokenize=False)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def batcher(params, batch):
    """Encode a batch of tokenized sentences with ``params.infersent``.

    Token lists are joined with single spaces and encoded with
    ``bsize=params.batch_size`` and tokenization disabled; the encoder's
    result is returned as-is.
    """
    joined = [' '.join(tokens) for tokens in batch]
    return params.infersent.encode(
        joined, bsize=params.batch_size, tokenize=False)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Evaluation of trained model on Transfer Tasks (SentEval)

# define senteval params
params_senteval = {
    'task_path': PATH_TO_DATA,
    'usepytorch': True,
    'kfold': 5,
    'classifier': {
        'nhid': 0,
        'optim': 'rmsprop',
        'batch_size': 128,
        'tenacity': 3,
        'epoch_size': 2,
    },
}
# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

if __name__ == "__main__":
    # Load InferSent model
    params_model = {
        'bsize': 64,
        'word_emb_dim': 300,
        'enc_lstm_dim': 2048,
        'pool_type': 'max',
        'dpout_model': 0.0,
        'version': V,
    }
    model = InferSent(params_model)
    model.load_state_dict(torch.load(MODEL_PATH))
    model.set_w2v_path(PATH_TO_W2V)

    # prepare() and batcher() reach the model through params.infersent.
    params_senteval['infersent'] = model.cuda()

    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = [
        'STS12', 'STS13', 'STS14', 'STS15', 'STS16',
        'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
        'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
        'Length', 'WordContent', 'Depth', 'TopConstituents',
        'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
        'OddManOut', 'CoordinationInversion',
    ]
    results = se.eval(transfer_tasks)
    print(results)
|
SentEval/examples/models.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
This file contains the definition of encoders used in https://arxiv.org/pdf/1705.02364.pdf
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import time
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class InferSent(nn.Module):
    """Bidirectional-LSTM sentence encoder with mean or max pooling.

    The constructor reads 'bsize', 'word_emb_dim', 'enc_lstm_dim',
    'pool_type', 'dpout_model' and (optionally) 'version' from ``config``.
    Version 1 uses '<s>'/'</s>' as sentence delimiters; version 2 uses
    '<p>'/'</p>' and a Moses-like tokenization.

    Word vectors are read from the text file set with ``set_w2v_path`` and
    kept in ``self.word_vec`` (word -> numpy vector).
    """

    def __init__(self, config):
        super(InferSent, self).__init__()
        self.bsize = config['bsize']
        self.word_emb_dim = config['word_emb_dim']
        self.enc_lstm_dim = config['enc_lstm_dim']
        self.pool_type = config['pool_type']
        self.dpout_model = config['dpout_model']
        self.version = 1 if 'version' not in config else config['version']

        self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1,
                                bidirectional=True, dropout=self.dpout_model)

        assert self.version in [1, 2]
        if self.version == 1:
            self.bos = '<s>'
            self.eos = '</s>'
            self.max_pad = True
            self.moses_tok = False
        elif self.version == 2:
            self.bos = '<p>'
            self.eos = '</p>'
            self.max_pad = False
            self.moses_tok = True

    def is_cuda(self):
        """Return True when the LSTM weights live on a CUDA device."""
        # either all weights are on cpu or they are on gpu
        return self.enc_lstm.bias_hh_l0.data.is_cuda

    def forward(self, sent_tuple):
        """Encode a padded batch.

        Args:
            sent_tuple: ``(sent, sent_len)`` where ``sent`` is a tensor of
                shape (seqlen x bsize x worddim) and ``sent_len`` is a numpy
                array with the length of each sentence.

        Returns:
            A 2-D tensor with one pooled embedding per sentence, in the
            same order as the input batch.
        """
        sent, sent_len = sent_tuple

        # Sort by length (keep idx); packing needs decreasing lengths.
        sent_len_sorted, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len)
        # [::-1] yields a negative-stride view; copy to get a plain array.
        sent_len_sorted = sent_len_sorted.copy()
        idx_unsort = np.argsort(idx_sort)

        idx_sort = torch.from_numpy(idx_sort).to(sent.device)
        sent = sent.index_select(1, idx_sort)

        # Handling padding in Recurrent Networks
        sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len_sorted)
        sent_output = self.enc_lstm(sent_packed)[0]  # seqlen x batch x 2*nhid
        sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0]

        # Un-sort by length
        idx_unsort = torch.from_numpy(idx_unsort).to(sent_output.device)
        sent_output = sent_output.index_select(1, idx_unsort)

        # Pooling
        if self.pool_type == "mean":
            # Build the lengths on the same device as the LSTM output; the
            # previous unconditional .cuda() crashed CPU-only runs.
            lengths = torch.tensor(np.array(sent_len), dtype=torch.float,
                                   device=sent_output.device).unsqueeze(1)
            # Summing over dim 0 already gives (batch x 2*nhid); an extra
            # squeeze(0) here would drop the batch axis for a batch of one.
            emb = torch.sum(sent_output, 0)
            emb = emb / lengths.expand_as(emb)
        elif self.pool_type == "max":
            if not self.max_pad:
                # Exact zeros (which include the padding) must not win the max.
                sent_output[sent_output == 0] = -1e9
            emb = torch.max(sent_output, 0)[0]
            if emb.ndimension() == 3:
                emb = emb.squeeze(0)
                assert emb.ndimension() == 2

        return emb

    def set_w2v_path(self, w2v_path):
        """Remember the path of the word-vector text file."""
        self.w2v_path = w2v_path

    def get_word_dict(self, sentences, tokenize=True):
        """Return a dict whose keys are the words of ``sentences``.

        The begin/end-of-sentence markers are always included. Values are
        empty strings; only the keys matter.
        """
        # create vocab of words
        word_dict = {}
        sentences = [s.split() if not tokenize else self.tokenize(s) for s in sentences]
        for sent in sentences:
            for word in sent:
                if word not in word_dict:
                    word_dict[word] = ''
        word_dict[self.bos] = ''
        word_dict[self.eos] = ''
        return word_dict

    def get_w2v(self, word_dict):
        """Load the vectors of the words in ``word_dict`` from the w2v file."""
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        # create word_vec with w2v vectors
        word_vec = {}
        with open(self.w2v_path, encoding='utf-8') as f:
            for line in f:
                word, vec = line.split(' ', 1)
                if word in word_dict:
                    word_vec[word] = np.fromstring(vec, sep=' ')
        print('Found %s(/%s) words with w2v vectors' % (len(word_vec), len(word_dict)))
        return word_vec

    def get_w2v_k(self, K):
        """Load the leading entries of the w2v file plus the bos/eos vectors.

        Lines are loaded while the running counter is <= K; after that the
        file is only scanned until both delimiter tokens have been found.
        """
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        # create word_vec with k first w2v vectors
        k = 0
        word_vec = {}
        with open(self.w2v_path, encoding='utf-8') as f:
            for line in f:
                word, vec = line.split(' ', 1)
                if k <= K:
                    word_vec[word] = np.fromstring(vec, sep=' ')
                    k += 1
                if k > K:
                    if word in [self.bos, self.eos]:
                        word_vec[word] = np.fromstring(vec, sep=' ')

                if k > K and all([w in word_vec for w in [self.bos, self.eos]]):
                    break
        return word_vec

    def build_vocab(self, sentences, tokenize=True):
        """Set ``self.word_vec`` to the vectors of the words in ``sentences``."""
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        word_dict = self.get_word_dict(sentences, tokenize)
        self.word_vec = self.get_w2v(word_dict)
        print('Vocab size : %s' % (len(self.word_vec)))

    # build w2v vocab with k most frequent words
    def build_vocab_k_words(self, K):
        """Set ``self.word_vec`` from the leading entries of the w2v file."""
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        self.word_vec = self.get_w2v_k(K)
        print('Vocab size : %s' % (K))

    def update_vocab(self, sentences, tokenize=True):
        """Add to ``self.word_vec`` the words of ``sentences`` not yet known."""
        assert hasattr(self, 'w2v_path'), 'warning : w2v path not set'
        assert hasattr(self, 'word_vec'), 'build_vocab before updating it'
        word_dict = self.get_word_dict(sentences, tokenize)

        # keep only new words
        for word in self.word_vec:
            if word in word_dict:
                del word_dict[word]

        # update vocabulary
        if word_dict:
            new_word_vec = self.get_w2v(word_dict)
            self.word_vec.update(new_word_vec)
        else:
            new_word_vec = []
        print('New vocab size : %s (added %s words)'% (len(self.word_vec), len(new_word_vec)))

    def get_batch(self, batch):
        """Turn token lists into a zero-padded float tensor.

        ``batch`` must be ordered by decreasing length: the first sentence
        fixes the padded length. The result has shape
        (max_len, bsize, word_emb_dim).
        """
        # sent in batch in decreasing order of lengths
        embed = np.zeros((len(batch[0]), len(batch), self.word_emb_dim))

        for i in range(len(batch)):
            for j in range(len(batch[i])):
                embed[j, i, :] = self.word_vec[batch[i][j]]

        return torch.FloatTensor(embed)

    def tokenize(self, s):
        """Tokenize ``s`` with NLTK (adjusted towards Moses for version 2)."""
        from nltk.tokenize import word_tokenize
        if self.moses_tok:
            s = ' '.join(word_tokenize(s))
            s = s.replace(" n't ", "n 't ")  # HACK to get ~MOSES tokenization
            return s.split()
        else:
            return word_tokenize(s)

    def prepare_samples(self, sentences, bsize, tokenize, verbose):
        """Tokenize, filter and sort sentences by decreasing length.

        Words without a vector in ``self.word_vec`` are dropped; a sentence
        left empty is replaced by the end-of-sentence token.

        Returns:
            ``(sentences, lengths, idx_sort)``: a 1-D object array of token
            lists sorted by decreasing length, their lengths, and the
            permutation that was applied.
        """
        sentences = [[self.bos] + s.split() + [self.eos] if not tokenize else
                     [self.bos] + self.tokenize(s) + [self.eos] for s in sentences]
        n_w = np.sum([len(x) for x in sentences])

        # filters words without w2v vectors
        for i in range(len(sentences)):
            s_f = [word for word in sentences[i] if word in self.word_vec]
            if not s_f:
                import warnings
                warnings.warn('No words in "%s" (idx=%s) have w2v vectors. '
                              'Replacing by "</s>"..' % (sentences[i], i))
                s_f = [self.eos]
            sentences[i] = s_f

        lengths = np.array([len(s) for s in sentences])
        n_wk = np.sum(lengths)
        if verbose:
            print('Nb words kept : %s/%s (%.1f%s)' % (
                n_wk, n_w, 100.0 * n_wk / n_w, '%'))

        # sort by decreasing length
        lengths, idx_sort = np.sort(lengths)[::-1], np.argsort(-lengths)
        # Fill an object array element by element: np.array() on token
        # lists of different lengths raises on recent NumPy releases.
        ordered = np.empty(len(sentences), dtype=object)
        for pos, src in enumerate(idx_sort):
            ordered[pos] = sentences[src]
        sentences = ordered

        return sentences, lengths, idx_sort

    def encode(self, sentences, bsize=64, tokenize=True, verbose=False):
        """Encode raw sentences into a numpy matrix, one row per sentence.

        Sentences are processed in chunks of ``bsize`` after sorting by
        length; rows are returned in the original input order.
        """
        tic = time.time()
        sentences, lengths, idx_sort = self.prepare_samples(
            sentences, bsize, tokenize, verbose)

        embeddings = []
        for stidx in range(0, len(sentences), bsize):
            batch = self.get_batch(sentences[stidx:stidx + bsize])
            if self.is_cuda():
                batch = batch.cuda()
            with torch.no_grad():
                batch = self.forward((batch, lengths[stidx:stidx + bsize])).data.cpu().numpy()
            embeddings.append(batch)
        embeddings = np.vstack(embeddings)

        # unsort
        idx_unsort = np.argsort(idx_sort)
        embeddings = embeddings[idx_unsort]

        if verbose:
            print('Speed : %.1f sentences/s (%s mode, bsize=%s)' % (
                len(embeddings)/(time.time()-tic),
                'gpu' if self.is_cuda() else 'cpu', bsize))
        return embeddings

    def visualize(self, sent, tokenize=True):
        """Plot, per word, the share of max-pooled dimensions it provides.

        Returns:
            ``(output, idxs)``: the max over time of the LSTM output and the
            time index selected for each dimension (as a numpy array).
        """
        sent = sent.split() if not tokenize else self.tokenize(sent)
        sent = [[self.bos] + [word for word in sent if word in self.word_vec] + [self.eos]]

        if ' '.join(sent[0]) == '%s %s' % (self.bos, self.eos):
            import warnings
            warnings.warn('No words in "%s" have w2v vectors. Replacing '
                          'by "%s %s"..' % (sent, self.bos, self.eos))
        batch = self.get_batch(sent)

        if self.is_cuda():
            batch = batch.cuda()
        output = self.enc_lstm(batch)[0]
        output, idxs = torch.max(output, 0)
        # output, idxs = output.squeeze(), idxs.squeeze()
        idxs = idxs.data.cpu().numpy()
        argmaxs = [np.sum((idxs == k)) for k in range(len(sent[0]))]

        # visualize model
        import matplotlib.pyplot as plt
        x = range(len(sent[0]))
        y = [100.0 * n / np.sum(argmaxs) for n in argmaxs]
        plt.xticks(x, sent[0], rotation=45)
        plt.bar(x, y)
        plt.ylabel('%')
        plt.title('Visualisation of words importance')
        plt.show()

        return output, idxs
|
SentEval/examples/skipthought.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
Example of file for SkipThought in SentEval
|
| 12 |
+
"""
|
| 13 |
+
import logging
|
| 14 |
+
import sys
|
| 15 |
+
# sys.setdefaultencoding does not exist on Python 3; calling it
# unconditionally aborted the script with AttributeError at import time.
if hasattr(sys, 'setdefaultencoding'):
    sys.setdefaultencoding('utf8')
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Set PATHs
|
| 19 |
+
PATH_TO_SENTEVAL = '../'
|
| 20 |
+
PATH_TO_DATA = '../data/senteval_data/'
|
| 21 |
+
PATH_TO_SKIPTHOUGHT = ''
|
| 22 |
+
|
| 23 |
+
assert PATH_TO_SKIPTHOUGHT != '', 'Download skipthought and set correct PATH'
|
| 24 |
+
|
| 25 |
+
# import skipthought and Senteval
|
| 26 |
+
sys.path.insert(0, PATH_TO_SKIPTHOUGHT)
|
| 27 |
+
import skipthoughts
|
| 28 |
+
sys.path.insert(0, PATH_TO_SENTEVAL)
|
| 29 |
+
import senteval
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def prepare(params, samples):
    """No-op preparation hook: nothing is precomputed from the samples.

    Both arguments are accepted for interface compatibility and ignored.
    """
    return None
|
| 34 |
+
|
| 35 |
+
def batcher(params, batch):
    """Encode a batch of tokenized sentences with SkipThought.

    Args:
        params: SentEval parameters; ``params['encoder']`` is passed to
            ``skipthoughts.encode`` as the model.
        batch: list of sentences, each a list of tokens (text or bytes).

    Returns:
        Whatever ``skipthoughts.encode`` returns for the joined sentences.
    """
    sentences = []
    for sent in batch:
        if sent == []:
            # Empty sentences are replaced by '.' as a placeholder.
            sentences.append('.')
            continue
        # The previous str(text, errors="ignore") raised TypeError on text
        # that is already decoded. Decode only byte tokens, dropping
        # undecodable bytes.
        tokens = [tok.decode('utf-8', 'ignore') if isinstance(tok, bytes) else tok
                  for tok in sent]
        sentences.append(' '.join(tokens))
    embeddings = skipthoughts.encode(params['encoder'], sentences,
                                     verbose=False, use_eos=True)
    return embeddings
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Set params for SentEval
params_senteval = {
    'task_path': PATH_TO_DATA,
    'usepytorch': True,
    'kfold': 10,
    'batch_size': 512,
    'classifier': {
        'nhid': 0,
        'optim': 'adam',
        'batch_size': 64,
        'tenacity': 5,
        'epoch_size': 4,
    },
}
# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

if __name__ == "__main__":
    # Load SkipThought model; batcher() reads it back from params['encoder'].
    params_senteval['encoder'] = skipthoughts.load_model()

    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = [
        'STS12', 'STS13', 'STS14', 'STS15', 'STS16',
        'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
        'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
        'Length', 'WordContent', 'Depth', 'TopConstituents',
        'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
        'OddManOut', 'CoordinationInversion',
    ]
    results = se.eval(transfer_tasks)
    print(results)
|
SentEval/senteval/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
from __future__ import absolute_import
|
| 9 |
+
|
| 10 |
+
from senteval.engine import SE
|
SentEval/senteval/binary.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
'''
|
| 9 |
+
Binary classifier and corresponding datasets : MR, CR, SUBJ, MPQA
|
| 10 |
+
'''
|
| 11 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 12 |
+
|
| 13 |
+
import io
|
| 14 |
+
import os
|
| 15 |
+
import numpy as np
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
from senteval.tools.validation import InnerKFoldClassifier
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class BinaryClassifierEval(object):
    """Shared evaluation logic for the two-class tasks MR, CR, SUBJ, MPQA.

    Samples from ``pos`` get label 1 and samples from ``neg`` get label 0.
    """

    def __init__(self, pos, neg, seed=1111):
        self.seed = seed
        self.samples, self.labels = pos + neg, [1] * len(pos) + [0] * len(neg)
        self.n_samples = len(self.samples)

    def do_prepare(self, params, prepare):
        """Call ``prepare`` on the whole corpus and return its result.

        prepare puts everything it outputs in "params" : params.word2id etc.
        Those outputs will be further used by "batcher".
        """
        # prepare is given the whole text
        return prepare(params, self.samples)

    def loadFile(self, fpath):
        """Read a latin-1 text file and return one token list per line."""
        with io.open(fpath, 'r', encoding='latin-1') as f:
            return [line.split() for line in f.read().splitlines()]

    def run(self, params, batcher):
        """Embed every sample and score it with an inner k-fold classifier.

        Returns:
            dict with 'devacc', 'acc', 'ndev' and 'ntest'; both counts are
            the total number of samples.
        """
        enc_input = []
        # Sort to reduce padding
        sorted_corpus = sorted(zip(self.samples, self.labels),
                               key=lambda z: (len(z[0]), z[1]))
        sorted_samples = [x for (x, y) in sorted_corpus]
        sorted_labels = [y for (x, y) in sorted_corpus]
        logging.info('Generating sentence embeddings')
        for ii in range(0, self.n_samples, params.batch_size):
            batch = sorted_samples[ii:ii + params.batch_size]
            embeddings = batcher(params, batch)
            enc_input.append(embeddings)
        enc_input = np.vstack(enc_input)
        logging.info('Generated sentence embeddings')

        config = {'nclasses': 2, 'seed': self.seed,
                  'usepytorch': params.usepytorch,
                  'classifier': params.classifier,
                  'nhid': params.nhid, 'kfold': params.kfold}
        clf = InnerKFoldClassifier(enc_input, np.array(sorted_labels), config)
        devacc, testacc = clf.run()
        logging.debug('Dev acc : {0} Test acc : {1}\n'.format(devacc, testacc))
        return {'devacc': devacc, 'acc': testacc, 'ndev': self.n_samples,
                'ntest': self.n_samples}


# The subclasses below name their own class in super(): super(self.__class__,
# self) resolves to the most-derived class, so any further subclass would
# re-enter the same __init__ instead of reaching BinaryClassifierEval.

class CREval(BinaryClassifierEval):
    """CR task: loads 'custrev.pos' / 'custrev.neg' from ``task_path``."""

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task : CR *****\n\n')
        pos = self.loadFile(os.path.join(task_path, 'custrev.pos'))
        neg = self.loadFile(os.path.join(task_path, 'custrev.neg'))
        super(CREval, self).__init__(pos, neg, seed)


class MREval(BinaryClassifierEval):
    """MR task: loads 'rt-polarity.pos' / 'rt-polarity.neg' from ``task_path``."""

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task : MR *****\n\n')
        pos = self.loadFile(os.path.join(task_path, 'rt-polarity.pos'))
        neg = self.loadFile(os.path.join(task_path, 'rt-polarity.neg'))
        super(MREval, self).__init__(pos, neg, seed)


class SUBJEval(BinaryClassifierEval):
    """SUBJ task: 'subj.objective' gets label 1, 'subj.subjective' label 0."""

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task : SUBJ *****\n\n')
        obj = self.loadFile(os.path.join(task_path, 'subj.objective'))
        subj = self.loadFile(os.path.join(task_path, 'subj.subjective'))
        super(SUBJEval, self).__init__(obj, subj, seed)


class MPQAEval(BinaryClassifierEval):
    """MPQA task: loads 'mpqa.pos' / 'mpqa.neg' from ``task_path``."""

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task : MPQA *****\n\n')
        pos = self.loadFile(os.path.join(task_path, 'mpqa.pos'))
        neg = self.loadFile(os.path.join(task_path, 'mpqa.neg'))
        super(MPQAEval, self).__init__(pos, neg, seed)
|
SentEval/senteval/engine.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
'''
|
| 9 |
+
|
| 10 |
+
Generic sentence evaluation scripts wrapper
|
| 11 |
+
|
| 12 |
+
'''
|
| 13 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 14 |
+
|
| 15 |
+
from senteval import utils
|
| 16 |
+
from senteval.binary import CREval, MREval, MPQAEval, SUBJEval
|
| 17 |
+
from senteval.snli import SNLIEval
|
| 18 |
+
from senteval.trec import TRECEval
|
| 19 |
+
from senteval.sick import SICKEntailmentEval, SICKEval
|
| 20 |
+
from senteval.mrpc import MRPCEval
|
| 21 |
+
from senteval.sts import STS12Eval, STS13Eval, STS14Eval, STS15Eval, STS16Eval, STSBenchmarkEval, SICKRelatednessEval, STSBenchmarkFinetune
|
| 22 |
+
from senteval.sst import SSTEval
|
| 23 |
+
from senteval.rank import ImageCaptionRetrievalEval
|
| 24 |
+
from senteval.probing import *
|
| 25 |
+
|
| 26 |
+
class SE(object):
    """Evaluation engine: instantiates a named task and runs it.

    The engine is configured with a parameter mapping, a ``batcher`` callable
    (invoked by the tasks as ``batcher(params, batch)``) and an optional
    ``prepare`` callable (invoked as ``prepare(params, samples)``).
    """

    def __init__(self, params, batcher, prepare=None):
        """Store the configuration and fill in defaults for missing keys.

        Args:
            params: mapping of settings; wrapped in ``utils.dotdict``.
                Missing keys default to ``usepytorch=True``, ``seed=1111``,
                ``batch_size=128``, ``nhid=0``, ``kfold=5`` and
                ``classifier={'nhid': 0}``.
            batcher: callable handed to the task's ``run`` method.
            prepare: optional callable handed to the task's ``do_prepare``
                method; a no-op is used when it is falsy.
        """
        # parameters
        params = utils.dotdict(params)
        params.usepytorch = True if 'usepytorch' not in params else params.usepytorch
        params.seed = 1111 if 'seed' not in params else params.seed

        params.batch_size = 128 if 'batch_size' not in params else params.batch_size
        params.nhid = 0 if 'nhid' not in params else params.nhid
        params.kfold = 5 if 'kfold' not in params else params.kfold

        if 'classifier' not in params or not params['classifier']:
            params.classifier = {'nhid': 0}

        assert 'nhid' in params.classifier, 'Set number of hidden units in classifier config!!'

        self.params = params

        # batcher and prepare
        self.batcher = batcher
        self.prepare = prepare if prepare else lambda x, y: None

        self.list_tasks = ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
                           'SICKRelatedness', 'SICKEntailment', 'STSBenchmark',
                           'SNLI', 'ImageCaptionRetrieval', 'STS12', 'STS13',
                           'STS14', 'STS15', 'STS16',
                           'Length', 'WordContent', 'Depth', 'TopConstituents',
                           'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
                           'OddManOut', 'CoordinationInversion', 'SICKRelatedness-finetune', 'STSBenchmark-finetune', 'STSBenchmark-fix']

    def _build_evaluation(self, name, tpath):
        """Instantiate the evaluation object registered for task ``name``.

        The table maps a task name to ``(class, path suffix, extra kwargs)``.
        It replaces a dynamic ``eval(name + 'Eval')`` lookup with explicit
        class references, so no string is ever evaluated as code.
        """
        table = {
            # Original SentEval tasks
            'CR': (CREval, '/downstream/CR', {}),
            'MR': (MREval, '/downstream/MR', {}),
            'MPQA': (MPQAEval, '/downstream/MPQA', {}),
            'SUBJ': (SUBJEval, '/downstream/SUBJ', {}),
            'SST2': (SSTEval, '/downstream/SST/binary', {'nclasses': 2}),
            'SST5': (SSTEval, '/downstream/SST/fine', {'nclasses': 5}),
            'TREC': (TRECEval, '/downstream/TREC', {}),
            'MRPC': (MRPCEval, '/downstream/MRPC', {}),
            'SICKRelatedness': (SICKRelatednessEval, '/downstream/SICK', {}),
            'STSBenchmark': (STSBenchmarkEval, '/downstream/STS/STSBenchmark', {}),
            'STSBenchmark-fix': (STSBenchmarkEval, '/downstream/STS/STSBenchmark-fix', {}),
            'STSBenchmark-finetune': (STSBenchmarkFinetune, '/downstream/STS/STSBenchmark', {}),
            'SICKRelatedness-finetune': (SICKEval, '/downstream/SICK', {}),
            'SICKEntailment': (SICKEntailmentEval, '/downstream/SICK', {}),
            'SNLI': (SNLIEval, '/downstream/SNLI', {}),
            'STS12': (STS12Eval, '/downstream/STS/STS12-en-test', {}),
            'STS13': (STS13Eval, '/downstream/STS/STS13-en-test', {}),
            'STS14': (STS14Eval, '/downstream/STS/STS14-en-test', {}),
            'STS15': (STS15Eval, '/downstream/STS/STS15-en-test', {}),
            'STS16': (STS16Eval, '/downstream/STS/STS16-en-test', {}),
            'ImageCaptionRetrieval': (ImageCaptionRetrievalEval, '/downstream/COCO', {}),
            # Probing Tasks
            'Length': (LengthEval, '/probing', {}),
            'WordContent': (WordContentEval, '/probing', {}),
            'Depth': (DepthEval, '/probing', {}),
            'TopConstituents': (TopConstituentsEval, '/probing', {}),
            'BigramShift': (BigramShiftEval, '/probing', {}),
            'Tense': (TenseEval, '/probing', {}),
            'SubjNumber': (SubjNumberEval, '/probing', {}),
            'ObjNumber': (ObjNumberEval, '/probing', {}),
            'OddManOut': (OddManOutEval, '/probing', {}),
            'CoordinationInversion': (CoordinationInversionEval, '/probing', {}),
        }
        task_cls, suffix, extra = table[name]
        return task_cls(tpath + suffix, seed=self.params.seed, **extra)

    def eval(self, name):
        """Run one task, or every task in a list, and return the results.

        Args:
            name: a task name from ``self.list_tasks``, or a list of such
                names.

        Returns:
            The value returned by the task's ``run`` method, or, for a list,
            a dict mapping each name to its result. The value is also stored
            in ``self.results``.

        Raises:
            AssertionError: if ``name`` is not in ``self.list_tasks``.
        """
        # evaluate on evaluation [name], either takes string or list of strings
        if isinstance(name, list):
            self.results = {x: self.eval(x) for x in name}
            return self.results

        tpath = self.params.task_path
        assert name in self.list_tasks, str(name) + ' not in ' + str(self.list_tasks)

        self.evaluation = self._build_evaluation(name, tpath)

        self.params.current_task = name
        self.evaluation.do_prepare(self.params, self.prepare)

        self.results = self.evaluation.run(self.params, self.batcher)

        return self.results
|
SentEval/senteval/mrpc.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
'''
|
| 9 |
+
MRPC : Microsoft Research Paraphrase (detection) Corpus
|
| 10 |
+
'''
|
| 11 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import logging
|
| 15 |
+
import numpy as np
|
| 16 |
+
import io
|
| 17 |
+
|
| 18 |
+
from senteval.tools.validation import KFoldClassifier
|
| 19 |
+
|
| 20 |
+
from sklearn.metrics import f1_score
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class MRPCEval(object):
    """MRPC transfer task: paraphrase detection over sentence pairs."""

    def __init__(self, task_path, seed=1111):
        """Load the train and test files found under ``task_path``."""
        logging.info('***** Transfer task : MRPC *****\n\n')
        self.seed = seed
        splits = {}
        for split_name, file_name in (('train', 'msr_paraphrase_train.txt'),
                                      ('test', 'msr_paraphrase_test.txt')):
            splits[split_name] = self.loadFile(os.path.join(task_path, file_name))
        self.mrpc_data = splits

    def do_prepare(self, params, prepare):
        """Call ``prepare`` with every sentence of both splits."""
        # TODO : Should we separate samples in "train, test"?
        train, test = self.mrpc_data['train'], self.mrpc_data['test']
        samples = train['X_A'] + train['X_B'] + test['X_A'] + test['X_B']
        return prepare(params, samples)

    def loadFile(self, fpath):
        """Parse a tab-separated file into token lists and integer labels.

        Column 0 is the label and columns 3 and 4 are the two sentences.
        The first line is parsed like any other and then dropped.
        """
        with io.open(fpath, 'r', encoding='utf-8') as f:
            rows = [line.strip().split('\t') for line in f]

        first_sents = [row[3].split() for row in rows]
        second_sents = [row[4].split() for row in rows]
        raw_labels = [row[0] for row in rows]

        return {'X_A': first_sents[1:],
                'X_B': second_sents[1:],
                'y': [int(s) for s in raw_labels[1:]]}

    def run(self, params, batcher):
        """Embed both splits, fit a k-fold classifier and report scores.

        Returns:
            dict with keys ``devacc``, ``acc``, ``f1``, ``ndev`` and
            ``ntest``.
        """
        mrpc_embed = {'train': {}, 'test': {}}
        step = params.batch_size

        for key in self.mrpc_data:
            logging.info('Computing embedding for {0}'.format(key))
            split = self.mrpc_data[key]
            # Sort to reduce padding
            triples = sorted(zip(split['X_A'], split['X_B'], split['y']),
                             key=lambda z: (len(z[0]), len(z[1]), z[2]))
            text_data = {'A': [a for (a, _, _) in triples],
                         'B': [b for (_, b, _) in triples],
                         'y': [lab for (_, _, lab) in triples]}
            total = len(text_data['y'])

            for txt_type in ['A', 'B']:
                chunks = []
                for start in range(0, total, step):
                    chunks.append(batcher(params, text_data[txt_type][start:start + step]))
                mrpc_embed[key][txt_type] = np.vstack(chunks)
            mrpc_embed[key]['y'] = np.array(text_data['y'])
            logging.info('Computed {0} embeddings'.format(key))

        def pair_features(left, right):
            # |u - v| concatenated with u * v, column-wise
            return np.c_[np.abs(left - right), left * right]

        # Train
        trainA, trainB = mrpc_embed['train']['A'], mrpc_embed['train']['B']
        trainF = pair_features(trainA, trainB)
        trainY = mrpc_embed['train']['y']

        # Test
        testA, testB = mrpc_embed['test']['A'], mrpc_embed['test']['B']
        testF = pair_features(testA, testB)
        testY = mrpc_embed['test']['y']

        config = {'nclasses': 2,
                  'seed': self.seed,
                  'usepytorch': params.usepytorch,
                  'classifier': params.classifier,
                  'nhid': params.nhid,
                  'kfold': params.kfold}
        clf = KFoldClassifier(train={'X': trainF, 'y': trainY},
                              test={'X': testF, 'y': testY},
                              config=config)

        devacc, testacc, yhat = clf.run()
        testf1 = round(100*f1_score(testY, yhat), 2)
        logging.debug('Dev acc : {0} Test acc {1}; Test F1 {2} for MRPC.\n'
                      .format(devacc, testacc, testf1))
        return {'devacc': devacc, 'acc': testacc, 'f1': testf1,
                'ndev': len(trainA), 'ntest': len(testA)}
|
SentEval/senteval/probing.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
'''
|
| 9 |
+
probing tasks
|
| 10 |
+
'''
|
| 11 |
+
|
| 12 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import io
|
| 16 |
+
import copy
|
| 17 |
+
import logging
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
from senteval.tools.validation import SplitClassifier
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class PROBINGEval(object):
    """Base class for probing tasks read from one tab-separated file.

    Each line of the file holds a split token (``tr``, ``va`` or ``te``) in
    the first column, a label in the second column and the sentence in the
    last column.
    """

    def __init__(self, task, task_path, seed=1111):
        """Load the task file and log the size of each split.

        Args:
            task: task name, used in log messages and to special-case
                ``"WordContent"`` in ``run``.
            task_path: path of the tab-separated data file.
            seed: seed forwarded to the classifier configuration.
        """
        self.seed = seed
        self.task = task
        logging.debug('***** (Probing) Transfer task : %s classification *****', self.task.upper())
        self.task_data = {'train': {'X': [], 'y': []},
                          'dev': {'X': [], 'y': []},
                          'test': {'X': [], 'y': []}}
        self.loadFile(task_path)
        logging.info('Loaded %s train - %s dev - %s test for %s' %
                     (len(self.task_data['train']['y']), len(self.task_data['dev']['y']),
                      len(self.task_data['test']['y']), self.task))

    def do_prepare(self, params, prepare):
        """Call ``prepare`` with the sentences of all three splits."""
        samples = self.task_data['train']['X'] + self.task_data['dev']['X'] + \
            self.task_data['test']['X']
        return prepare(params, samples)

    def loadFile(self, fpath):
        """Fill ``self.task_data`` from ``fpath`` and index the labels.

        Sets ``self.tok2split``, ``self.tok2label`` (built from the sorted
        unique training labels) and ``self.nclasses``, then replaces every
        label by its integer id.

        Raises:
            KeyError: on an unknown split token, or on a dev/test label that
                never occurs in the training split.
        """
        self.tok2split = {'tr': 'train', 'va': 'dev', 'te': 'test'}
        with io.open(fpath, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.rstrip().split('\t')
                split = self.tok2split[line[0]]
                self.task_data[split]['X'].append(line[-1].split())
                self.task_data[split]['y'].append(line[1])

        labels = sorted(np.unique(self.task_data['train']['y']))
        self.tok2label = dict(zip(labels, range(len(labels))))
        self.nclasses = len(self.tok2label)

        for split in self.task_data:
            for i, y in enumerate(self.task_data[split]['y']):
                self.task_data[split]['y'][i] = self.tok2label[y]

    def run(self, params, batcher):
        """Embed all splits, train a split classifier and report accuracy.

        Note that each split in ``self.task_data`` is re-ordered in place by
        the length-based sort.

        Returns:
            dict with keys ``devacc``, ``acc``, ``ndev`` and ``ntest``.
        """
        task_embed = {'train': {}, 'dev': {}, 'test': {}}
        bsize = params.batch_size
        logging.info('Computing embeddings for train/dev/test')
        for key in self.task_data:
            # Sort to reduce padding
            sorted_data = sorted(zip(self.task_data[key]['X'],
                                     self.task_data[key]['y']),
                                 key=lambda z: (len(z[0]), z[1]))
            self.task_data[key]['X'], self.task_data[key]['y'] = map(list, zip(*sorted_data))

            task_embed[key]['X'] = []
            for ii in range(0, len(self.task_data[key]['y']), bsize):
                batch = self.task_data[key]['X'][ii:ii + bsize]
                embeddings = batcher(params, batch)
                task_embed[key]['X'].append(embeddings)
            task_embed[key]['X'] = np.vstack(task_embed[key]['X'])
            task_embed[key]['y'] = np.array(self.task_data[key]['y'])
        logging.info('Computed embeddings')

        config_classifier = {'nclasses': self.nclasses, 'seed': self.seed,
                             'usepytorch': params.usepytorch,
                             'classifier': params.classifier}

        if self.task == "WordContent" and params.classifier['nhid'] > 0:
            # Deep copy so that forcing nhid to 0 for this task does not
            # modify the caller's params.classifier.
            config_classifier = copy.deepcopy(config_classifier)
            config_classifier['classifier']['nhid'] = 0
            # Was a bare print(); library code should report through logging.
            logging.debug('WordContent: forcing classifier nhid to 0 '
                          '(params.classifier nhid is %s)',
                          params.classifier['nhid'])

        clf = SplitClassifier(X={'train': task_embed['train']['X'],
                                 'valid': task_embed['dev']['X'],
                                 'test': task_embed['test']['X']},
                              y={'train': task_embed['train']['y'],
                                 'valid': task_embed['dev']['y'],
                                 'test': task_embed['test']['y']},
                              config=config_classifier)

        devacc, testacc = clf.run()
        logging.debug('\nDev acc : %.1f Test acc : %.1f for %s classification\n' % (devacc, testacc, self.task.upper()))

        return {'devacc': devacc, 'acc': testacc,
                'ndev': len(task_embed['dev']['X']),
                'ntest': len(task_embed['test']['X'])}
|
| 100 |
+
|
| 101 |
+
"""
|
| 102 |
+
Surface Information
|
| 103 |
+
"""
|
| 104 |
+
class LengthEval(PROBINGEval):
    """Probing task 'Length', read from ``sentence_length.txt``."""

    def __init__(self, task_path, seed=1111):
        # labels: bins
        PROBINGEval.__init__(
            self,
            task='Length',
            task_path=os.path.join(task_path, 'sentence_length.txt'),
            seed=seed)
|
| 109 |
+
|
| 110 |
+
class WordContentEval(PROBINGEval):
    """Probing task 'WordContent', read from ``word_content.txt``."""

    def __init__(self, task_path, seed=1111):
        # labels: 200 target words
        PROBINGEval.__init__(
            self,
            task='WordContent',
            task_path=os.path.join(task_path, 'word_content.txt'),
            seed=seed)
|
| 115 |
+
|
| 116 |
+
"""
|
| 117 |
+
Latent Structural Information
|
| 118 |
+
"""
|
| 119 |
+
class DepthEval(PROBINGEval):
    """Probing task 'Depth', read from ``tree_depth.txt``."""

    def __init__(self, task_path, seed=1111):
        # labels: bins
        PROBINGEval.__init__(
            self,
            task='Depth',
            task_path=os.path.join(task_path, 'tree_depth.txt'),
            seed=seed)
|
| 124 |
+
|
| 125 |
+
class TopConstituentsEval(PROBINGEval):
    """Probing task 'TopConstituents', read from ``top_constituents.txt``."""

    def __init__(self, task_path, seed=1111):
        # labels: 'PP_NP_VP_.' .. (20 classes)
        PROBINGEval.__init__(
            self,
            task='TopConstituents',
            task_path=os.path.join(task_path, 'top_constituents.txt'),
            seed=seed)
|
| 130 |
+
|
| 131 |
+
class BigramShiftEval(PROBINGEval):
    """Probing task 'BigramShift', read from ``bigram_shift.txt``."""

    def __init__(self, task_path, seed=1111):
        # labels: 0 or 1
        PROBINGEval.__init__(
            self,
            task='BigramShift',
            task_path=os.path.join(task_path, 'bigram_shift.txt'),
            seed=seed)
|
| 136 |
+
|
| 137 |
+
# TODO: Voice?
|
| 138 |
+
|
| 139 |
+
"""
|
| 140 |
+
Latent Semantic Information
|
| 141 |
+
"""
|
| 142 |
+
|
| 143 |
+
class TenseEval(PROBINGEval):
    """Probing task 'Tense', read from ``past_present.txt``."""

    def __init__(self, task_path, seed=1111):
        # labels: 'PRES', 'PAST'
        PROBINGEval.__init__(
            self,
            task='Tense',
            task_path=os.path.join(task_path, 'past_present.txt'),
            seed=seed)
|
| 148 |
+
|
| 149 |
+
class SubjNumberEval(PROBINGEval):
    """Probing task 'SubjNumber', read from ``subj_number.txt``."""

    def __init__(self, task_path, seed=1111):
        # labels: 'NN', 'NNS'
        PROBINGEval.__init__(
            self,
            task='SubjNumber',
            task_path=os.path.join(task_path, 'subj_number.txt'),
            seed=seed)
|
| 154 |
+
|
| 155 |
+
class ObjNumberEval(PROBINGEval):
    """Probing task 'ObjNumber', read from ``obj_number.txt``."""

    def __init__(self, task_path, seed=1111):
        # labels: 'NN', 'NNS'
        PROBINGEval.__init__(
            self,
            task='ObjNumber',
            task_path=os.path.join(task_path, 'obj_number.txt'),
            seed=seed)
|
| 160 |
+
|
| 161 |
+
class OddManOutEval(PROBINGEval):
    """Probing task 'OddManOut', read from ``odd_man_out.txt``."""

    def __init__(self, task_path, seed=1111):
        # labels: 'O', 'C'
        PROBINGEval.__init__(
            self,
            task='OddManOut',
            task_path=os.path.join(task_path, 'odd_man_out.txt'),
            seed=seed)
|
| 166 |
+
|
| 167 |
+
class CoordinationInversionEval(PROBINGEval):
    """Probing task 'CoordinationInversion', read from ``coordination_inversion.txt``."""

    def __init__(self, task_path, seed=1111):
        # labels: 'O', 'I'
        PROBINGEval.__init__(
            self,
            task='CoordinationInversion',
            task_path=os.path.join(task_path, 'coordination_inversion.txt'),
            seed=seed)
|
SentEval/senteval/rank.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
'''
|
| 9 |
+
Image-Caption Retrieval with COCO dataset
|
| 10 |
+
'''
|
| 11 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
import logging
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
import cPickle as pickle
|
| 20 |
+
except ImportError:
|
| 21 |
+
import pickle
|
| 22 |
+
|
| 23 |
+
from senteval.tools.ranking import ImageSentenceRankingPytorch
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ImageCaptionRetrievalEval(object):
    """Image-caption retrieval task over pickled COCO splits.

    Each image contributes its first five captions, so sentence ``k`` is
    paired with the image features stored at index ``k``.
    """

    def __init__(self, task_path, seed=1111):
        """Load ``train.pkl``, ``valid.pkl`` and ``test.pkl`` from ``task_path``."""
        logging.debug('***** Transfer task: Image Caption Retrieval *****\n\n')

        # Get captions and image features
        self.seed = seed
        train, dev, test = self.loadFile(task_path)
        self.coco_data = {'train': train, 'dev': dev, 'test': test}

    def do_prepare(self, params, prepare):
        """Call ``prepare`` with the captions of all three splits."""
        samples = self.coco_data['train']['sent'] + \
            self.coco_data['dev']['sent'] + \
            self.coco_data['test']['sent']
        prepare(params, samples)

    def loadFile(self, fpath):
        """Read the three pickled splits found in directory ``fpath``.

        Returns:
            A ``(train, valid, test)`` tuple of dicts, each holding ``sent``
            (a list of token lists) and ``imgfeat`` (a float32 array with one
            row per caption).

        Raises:
            AssertionError: if an image has fewer than five captions.
        """
        coco = {}

        for split in ['train', 'valid', 'test']:
            list_sent = []
            list_img_feat = []
            # NOTE: pickle.load executes arbitrary code from the file; only
            # point this at trusted dataset files.
            # Pickles are binary, so open in 'rb' on every Python version.
            with open(os.path.join(fpath, split + '.pkl'), 'rb') as f:
                if sys.version_info < (3, 0):
                    cocodata = pickle.load(f)
                else:
                    cocodata = pickle.load(f, encoding='latin1')

            for imgkey in range(len(cocodata['features'])):
                assert len(cocodata['image_to_caption_ids'][imgkey]) >= 5, \
                    cocodata['image_to_caption_ids'][imgkey]
                for captkey in cocodata['image_to_caption_ids'][imgkey][0:5]:
                    sent = cocodata['captions'][captkey]['cleaned_caption']
                    sent += ' .'  # add punctuation to end of sentence in COCO
                    # NOTE(review): on Python 3 this produces tokens of type
                    # bytes, not str -- confirm that batchers expect that.
                    list_sent.append(sent.encode('utf-8').split())
                    list_img_feat.append(cocodata['features'][imgkey])
            assert len(list_sent) == len(list_img_feat) and \
                len(list_sent) % 5 == 0
            list_img_feat = np.array(list_img_feat).astype('float32')
            coco[split] = {'sent': list_sent, 'imgfeat': list_img_feat}
        return coco['train'], coco['valid'], coco['test']

    def run(self, params, batcher):
        """Embed the captions and train/evaluate the ranking model.

        Returns:
            dict with keys ``devacc``, ``acc`` (image-to-text and
            text-to-image score tuples), ``ndev`` and ``ntest``.
        """
        coco_embed = {'train': {'sentfeat': [], 'imgfeat': []},
                      'dev': {'sentfeat': [], 'imgfeat': []},
                      'test': {'sentfeat': [], 'imgfeat': []}}

        for key in self.coco_data:
            logging.info('Computing embedding for {0}'.format(key))
            sents = self.coco_data[key]['sent']
            # Sort to reduce padding. Sort *indices* by sentence length
            # rather than calling np.sort on the token lists: np.array over
            # ragged lists fails on recent NumPy, and on equal-length input
            # np.sort would reorder tokens inside each sentence. The stored
            # captions are left untouched so do_prepare keeps working.
            idx_sort = sorted(range(len(sents)), key=lambda i: len(sents[i]))
            sorted_sents = [sents[i] for i in idx_sort]
            # Inverse permutation: restores the original caption order so
            # rows stay aligned with imgfeat.
            idx_unsort = np.argsort(idx_sort)

            # Not read in this class; kept because the whole dict is handed
            # to ImageSentenceRankingPytorch below.
            coco_embed[key]['X'] = []
            nsent = len(sorted_sents)
            for ii in range(0, nsent, params.batch_size):
                batch = sorted_sents[ii:ii + params.batch_size]
                embeddings = batcher(params, batch)
                coco_embed[key]['sentfeat'].append(embeddings)
            coco_embed[key]['sentfeat'] = np.vstack(coco_embed[key]['sentfeat'])[idx_unsort]
            coco_embed[key]['imgfeat'] = np.array(self.coco_data[key]['imgfeat'])
            logging.info('Computed {0} embeddings'.format(key))

        config = {'seed': self.seed, 'projdim': 1000, 'margin': 0.2}
        clf = ImageSentenceRankingPytorch(train=coco_embed['train'],
                                          valid=coco_embed['dev'],
                                          test=coco_embed['test'],
                                          config=config)

        bestdevscore, r1_i2t, r5_i2t, r10_i2t, medr_i2t, \
            r1_t2i, r5_t2i, r10_t2i, medr_t2i = clf.run()

        logging.debug("\nTest scores | Image to text: \
            {0}, {1}, {2}, {3}".format(r1_i2t, r5_i2t, r10_i2t, medr_i2t))
        logging.debug("Test scores | Text to image: \
            {0}, {1}, {2}, {3}\n".format(r1_t2i, r5_t2i, r10_t2i, medr_t2i))

        return {'devacc': bestdevscore,
                'acc': [(r1_i2t, r5_i2t, r10_i2t, medr_i2t),
                        (r1_t2i, r5_t2i, r10_t2i, medr_t2i)],
                'ndev': len(coco_embed['dev']['sentfeat']),
                'ntest': len(coco_embed['test']['sentfeat'])}
|
SentEval/senteval/sick.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
'''
|
| 9 |
+
SICK Relatedness and Entailment
|
| 10 |
+
'''
|
| 11 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import io
|
| 15 |
+
import logging
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
from sklearn.metrics import mean_squared_error
|
| 19 |
+
from scipy.stats import pearsonr, spearmanr
|
| 20 |
+
|
| 21 |
+
from senteval.tools.relatedness import RelatednessPytorch
|
| 22 |
+
from senteval.tools.validation import SplitClassifier
|
| 23 |
+
|
| 24 |
+
class SICKEval(object):
    """SICK relatedness task: predict a similarity score for sentence pairs."""

    def __init__(self, task_path, seed=1111):
        """Load the train, trial (dev) and annotated test files."""
        logging.debug('***** Transfer task : SICK-Relatedness*****\n\n')
        self.seed = seed
        files = (('train', 'SICK_train.txt'),
                 ('dev', 'SICK_trial.txt'),
                 ('test', 'SICK_test_annotated.txt'))
        loaded = {}
        for split_name, file_name in files:
            loaded[split_name] = self.loadFile(os.path.join(task_path, file_name))
        self.sick_data = loaded

    def do_prepare(self, params, prepare):
        """Call ``prepare`` with every sentence of the three splits."""
        samples = []
        for split_name in ('train', 'dev', 'test'):
            samples = samples + self.sick_data[split_name]['X_A']
            samples = samples + self.sick_data[split_name]['X_B']
        return prepare(params, samples)

    def loadFile(self, fpath):
        """Parse a tab-separated SICK file, skipping its first line.

        Columns 1 and 2 are the sentences; column 3 is converted to float.
        """
        sents_a, sents_b, raw_scores = [], [], []
        with io.open(fpath, 'r', encoding='utf-8') as f:
            for lineno, line in enumerate(f):
                if lineno == 0:
                    # first line is not data
                    continue
                fields = line.strip().split('\t')
                sents_a.append(fields[1].split())
                sents_b.append(fields[2].split())
                raw_scores.append(fields[3])

        return {'X_A': sents_a,
                'X_B': sents_b,
                'y': [float(s) for s in raw_scores]}

    def run(self, params, batcher):
        """Embed all splits, fit the relatedness model and score the test set.

        Each split in ``self.sick_data`` is re-ordered in place by the
        length-based sort.

        Returns:
            dict with keys ``devspearman``, ``pearson``, ``spearman``,
            ``mse``, ``yhat``, ``ndev`` and ``ntest``.
        """
        sick_embed = {'train': {}, 'dev': {}, 'test': {}}
        bsize = params.batch_size

        for key in self.sick_data:
            logging.info('Computing embedding for {0}'.format(key))
            split = self.sick_data[key]
            # Sort to reduce padding
            ordered = sorted(zip(split['X_A'], split['X_B'], split['y']),
                             key=lambda z: (len(z[0]), len(z[1]), z[2]))

            split['X_A'] = [a for (a, _, _) in ordered]
            split['X_B'] = [b for (_, b, _) in ordered]
            split['y'] = [score for (_, _, score) in ordered]

            total = len(split['y'])
            for txt_type in ['X_A', 'X_B']:
                pieces = []
                for start in range(0, total, bsize):
                    pieces.append(batcher(params, split[txt_type][start:start + bsize]))
                sick_embed[key][txt_type] = np.vstack(pieces)
            sick_embed[key]['y'] = np.array(split['y'])
            logging.info('Computed {0} embeddings'.format(key))

        def pair_features(left, right):
            # |u - v| concatenated with u * v, column-wise
            return np.c_[np.abs(left - right), left * right]

        # Train
        trainA, trainB = sick_embed['train']['X_A'], sick_embed['train']['X_B']
        trainF = pair_features(trainA, trainB)
        trainY = self.encode_labels(self.sick_data['train']['y'])

        # Dev
        devA, devB = sick_embed['dev']['X_A'], sick_embed['dev']['X_B']
        devF = pair_features(devA, devB)
        devY = self.encode_labels(self.sick_data['dev']['y'])

        # Test
        testA, testB = sick_embed['test']['X_A'], sick_embed['test']['X_B']
        testF = pair_features(testA, testB)
        testY = self.encode_labels(self.sick_data['test']['y'])

        config = {'seed': self.seed, 'nclasses': 5}
        clf = RelatednessPytorch(train={'X': trainF, 'y': trainY},
                                 valid={'X': devF, 'y': devY},
                                 test={'X': testF, 'y': testY},
                                 devscores=self.sick_data['dev']['y'],
                                 config=config)

        devspr, yhat = clf.run()

        gold = self.sick_data['test']['y']
        pr = pearsonr(yhat, gold)[0]
        sr = spearmanr(yhat, gold)[0]
        # x != x holds only for NaN: report 0 for an undefined correlation
        if pr != pr:
            pr = 0
        if sr != sr:
            sr = 0
        se = mean_squared_error(yhat, gold)
        logging.debug('Dev : Spearman {0}'.format(devspr))
        logging.debug('Test : Pearson {0} Spearman {1} MSE {2} \
                       for SICK Relatedness\n'.format(pr, sr, se))

        return {'devspearman': devspr, 'pearson': pr, 'spearman': sr, 'mse': se,
                'yhat': yhat, 'ndev': len(devA), 'ntest': len(testA)}

    def encode_labels(self, labels, nclass=5):
        """
        Label encoding from Tree LSTM paper (Tai, Socher, Manning)

        A score is spread over the two classes adjacent to it: the class
        equal to floor(score) receives ``floor(score) - score + 1`` and the
        next class receives ``score - floor(score)``.
        """
        Y = np.zeros((len(labels), nclass)).astype('float32')
        for row, score in enumerate(labels):
            lower = np.floor(score)
            for col in range(nclass):
                rank = col + 1
                if rank == lower + 1:
                    Y[row, col] = score - lower
                if rank == lower:
                    Y[row, col] = lower - score + 1
        return Y
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class SICKEntailmentEval(SICKEval):
    """
    SICK entailment transfer task (3 classes: contradiction, neutral,
    entailment), evaluated by training a classifier on sentence-pair
    features built from the embeddings returned by ``batcher``.
    """

    def __init__(self, task_path, seed=1111):
        """Load the train / dev / test SICK files found under ``task_path``."""
        logging.debug('***** Transfer task : SICK-Entailment*****\n\n')
        self.seed = seed
        split_files = (('train', 'SICK_train.txt'),
                       ('dev', 'SICK_trial.txt'),
                       ('test', 'SICK_test_annotated.txt'))
        self.sick_data = {}
        for split, fname in split_files:
            self.sick_data[split] = self.loadFile(os.path.join(task_path, fname))

    def loadFile(self, fpath):
        """
        Parse one tab-separated SICK file, skipping its first (header) line.

        Columns 1 and 2 are the two sentences (whitespace-tokenised) and
        column 4 is the entailment label, mapped to an integer id.

        :return: dict with keys ``'X_A'``, ``'X_B'`` and ``'y'``.
        """
        label2id = {'CONTRADICTION': 0, 'NEUTRAL': 1, 'ENTAILMENT': 2}
        first, second, raw_labels = [], [], []
        with io.open(fpath, 'r', encoding='utf-8') as f:
            for lineno, line in enumerate(f):
                if lineno == 0:
                    continue  # header line
                fields = line.strip().split('\t')
                first.append(fields[1].split())
                second.append(fields[2].split())
                raw_labels.append(fields[4])
        return {'X_A': first, 'X_B': second,
                'y': [label2id[s] for s in raw_labels]}

    def run(self, params, batcher):
        """
        Embed every split, build pair features and train/evaluate the
        classifier.

        :param params: object exposing ``batch_size``, ``usepytorch``,
            ``classifier`` and ``nhid``.
        :param batcher: callable ``batcher(params, batch)`` returning the
            embeddings of a batch of tokenised sentences.
        :return: dict with dev/test accuracy and the dev/test set sizes.
        """
        bsize = params.batch_size
        sick_embed = {'train': {}, 'dev': {}, 'test': {}}

        for key in self.sick_data:
            logging.info('Computing embedding for {0}'.format(key))
            split = self.sick_data[key]
            # Sort to reduce padding
            ordered = sorted(zip(split['X_A'], split['X_B'], split['y']),
                             key=lambda z: (len(z[0]), len(z[1]), z[2]))
            split['X_A'] = [triple[0] for triple in ordered]
            split['X_B'] = [triple[1] for triple in ordered]
            split['y'] = [triple[2] for triple in ordered]

            n_samples = len(split['y'])
            for txt_type in ['X_A', 'X_B']:
                chunks = []
                for start in range(0, n_samples, bsize):
                    chunks.append(
                        batcher(params, split[txt_type][start:start + bsize]))
                sick_embed[key][txt_type] = np.vstack(chunks)
            logging.info('Computed {0} embeddings'.format(key))

        # Pair features: |a - b| concatenated with a * b, for each split
        feats, targets = {}, {}
        for key in ('train', 'dev', 'test'):
            emb_a = sick_embed[key]['X_A']
            emb_b = sick_embed[key]['X_B']
            feats[key] = np.c_[np.abs(emb_a - emb_b), emb_a * emb_b]
            targets[key] = np.array(self.sick_data[key]['y'])

        config = {'nclasses': 3, 'seed': self.seed,
                  'usepytorch': params.usepytorch,
                  'classifier': params.classifier,
                  'nhid': params.nhid}
        clf = SplitClassifier(
            X={'train': feats['train'], 'valid': feats['dev'],
               'test': feats['test']},
            y={'train': targets['train'], 'valid': targets['dev'],
               'test': targets['test']},
            config=config)

        devacc, testacc = clf.run()
        logging.debug('\nDev acc : {0} Test acc : {1} for '
                      'SICK entailment\n'.format(devacc, testacc))
        return {'devacc': devacc, 'acc': testacc,
                'ndev': len(sick_embed['dev']['X_A']),
                'ntest': len(sick_embed['test']['X_A'])}
|
SentEval/senteval/snli.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
'''
|
| 9 |
+
SNLI - Entailment
|
| 10 |
+
'''
|
| 11 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 12 |
+
|
| 13 |
+
import codecs
|
| 14 |
+
import os
|
| 15 |
+
import io
|
| 16 |
+
import copy
|
| 17 |
+
import logging
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
from senteval.tools.validation import SplitClassifier
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SNLIEval(object):
    """
    SNLI entailment transfer task.

    Loads the pre-tokenised ``s1.*`` / ``s2.*`` sentence files and the
    ``labels.*`` files for the train, dev and test splits, then trains a
    classifier on features built from the sentence embeddings.
    """

    def __init__(self, taskpath, seed=1111):
        """
        :param taskpath: directory containing ``s1.{train,dev,test}``,
            ``s2.{train,dev,test}`` and ``labels.{train,dev,test}``.
        :param seed: seed forwarded to the classifier.
        """
        logging.debug('***** Transfer task : SNLI Entailment*****\n\n')
        self.seed = seed

        self.samples = []
        self.data = {}
        for split, suffix in (('train', 'train'), ('valid', 'dev'),
                              ('test', 'test')):
            sents1 = self.loadFile(os.path.join(taskpath, 's1.' + suffix))
            sents2 = self.loadFile(os.path.join(taskpath, 's2.' + suffix))
            labels = self._load_labels(
                os.path.join(taskpath, 'labels.' + suffix))

            # sort data (by s2 first) to reduce padding
            ordered = sorted(zip(sents2, sents1, labels),
                             key=lambda z: (len(z[0]), len(z[1]), z[2]))
            sents2, sents1, labels = map(list, zip(*ordered))

            self.samples += sents1 + sents2
            self.data[split] = (sents1, sents2, labels)

    @staticmethod
    def _load_labels(fpath):
        """Read one label per line from ``fpath`` and close the file."""
        with io.open(fpath, encoding='utf-8') as f:
            return f.read().splitlines()

    def do_prepare(self, params, prepare):
        """Hand every sentence of the task to the user ``prepare`` callback."""
        return prepare(params, self.samples)

    def loadFile(self, fpath):
        """Return the whitespace-tokenised lines of ``fpath`` (latin-1)."""
        with codecs.open(fpath, 'rb', 'latin-1') as f:
            return [line.split() for line in
                    f.read().splitlines()]

    def run(self, params, batcher):
        """
        Encode all sentence pairs and train/evaluate the classifier.

        For each pair the feature vector is the concatenation of
        ``enc1``, ``enc2``, ``enc1 * enc2`` and ``|enc1 - enc2|``.

        :param params: object exposing ``batch_size``, ``usepytorch``,
            ``nhid`` and ``classifier``.
        :param batcher: callable ``batcher(params, batch)`` returning the
            embeddings of a batch of tokenised sentences.
        :return: dict with dev/test accuracy and the dev/test set sizes.
        """
        self.X, self.y = {}, {}
        dico_label = {'entailment': 0, 'neutral': 1, 'contradiction': 2}
        bsize = params.batch_size
        for key in self.data:
            input1, input2, mylabels = self.data[key]
            enc_input = []
            n_labels = len(mylabels)
            for ii in range(0, n_labels, bsize):
                batch1 = input1[ii:ii + bsize]
                batch2 = input2[ii:ii + bsize]

                if len(batch1) == len(batch2) and len(batch1) > 0:
                    enc1 = batcher(params, batch1)
                    enc2 = batcher(params, batch2)
                    enc_input.append(np.hstack((enc1, enc2, enc1 * enc2,
                                                np.abs(enc1 - enc2))))
                # ii is always a multiple of the batch size, so this fires
                # only when the offset is also a multiple of 20000.
                if ii % 20000 == 0:
                    logging.info("PROGRESS (encoding): %.2f%%",
                                 100 * ii / n_labels)
            self.X[key] = np.vstack(enc_input)
            # Use an ndarray, like the other tasks in this package: the
            # PyTorch classifier converts labels with torch.from_numpy,
            # which does not accept a plain list.
            self.y[key] = np.array([dico_label[y] for y in mylabels])

        config = {'nclasses': 3, 'seed': self.seed,
                  'usepytorch': params.usepytorch,
                  'cudaEfficient': True,
                  'nhid': params.nhid, 'noreg': True}

        # Copy so that the caller's classifier settings are not modified.
        config_classifier = copy.deepcopy(params.classifier)
        config_classifier['max_epoch'] = 15
        config_classifier['epoch_size'] = 1
        config['classifier'] = config_classifier

        clf = SplitClassifier(self.X, self.y, config)
        devacc, testacc = clf.run()
        logging.debug('Dev acc : {0} Test acc : {1} for SNLI\n'
                      .format(devacc, testacc))
        return {'devacc': devacc, 'acc': testacc,
                'ndev': len(self.data['valid'][0]),
                'ntest': len(self.data['test'][0])}
|
SentEval/senteval/sst.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
'''
|
| 9 |
+
SST - binary classification
|
| 10 |
+
'''
|
| 11 |
+
|
| 12 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import io
|
| 16 |
+
import logging
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
from senteval.tools.validation import SplitClassifier
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SSTEval(object):
    """
    SST sentiment transfer task, either binary (``nclasses=2``) or
    fine-grained (``nclasses=5``).
    """

    def __init__(self, task_path, nclasses=2, seed=1111):
        """Load the ``sentiment-{train,dev,test}`` files under ``task_path``."""
        self.seed = seed

        # binary of fine-grained
        assert nclasses in [2, 5]
        self.nclasses = nclasses
        if self.nclasses == 2:
            self.task_name = 'Binary'
        else:
            self.task_name = 'Fine-Grained'
        logging.debug('***** Transfer task : SST %s classification *****\n\n', self.task_name)

        self.sst_data = {}
        for split in ('train', 'dev', 'test'):
            path = os.path.join(task_path, 'sentiment-' + split)
            self.sst_data[split] = self.loadFile(path)

    def do_prepare(self, params, prepare):
        """Pass all train, dev and test sentences to the ``prepare`` callback."""
        samples = []
        for split in ('train', 'dev', 'test'):
            samples = samples + self.sst_data[split]['X']
        return prepare(params, samples)

    def loadFile(self, fpath):
        """
        Parse one SST file.

        Binary files hold ``sentence<TAB>label`` per line; fine-grained files
        hold ``label sentence`` per line. The largest label found must be
        ``nclasses - 1``.

        :return: dict with tokenised sentences under ``'X'`` and integer
            labels under ``'y'``.
        """
        sentences, labels = [], []
        with io.open(fpath, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if self.nclasses == 2:
                    parts = stripped.split('\t')
                    labels.append(int(parts[1]))
                    sentences.append(parts[0].split())
                elif self.nclasses == 5:
                    parts = stripped.split(' ', 1)
                    labels.append(int(parts[0]))
                    sentences.append(parts[1].split())
        sst_data = {'X': sentences, 'y': labels}
        assert max(sst_data['y']) == self.nclasses - 1
        return sst_data

    def run(self, params, batcher):
        """
        Embed every split and train/evaluate the classifier on it.

        :param params: object exposing ``batch_size``, ``usepytorch`` and
            ``classifier``.
        :param batcher: callable ``batcher(params, batch)`` returning the
            embeddings of a batch of tokenised sentences.
        :return: dict with dev/test accuracy and the dev/test set sizes.
        """
        bsize = params.batch_size
        sst_embed = {'train': {}, 'dev': {}, 'test': {}}

        for key in self.sst_data:
            logging.info('Computing embedding for {0}'.format(key))
            split = self.sst_data[key]
            # Sort to reduce padding
            ordered = sorted(zip(split['X'], split['y']),
                             key=lambda z: (len(z[0]), z[1]))
            split['X'], split['y'] = map(list, zip(*ordered))

            chunks = []
            for start in range(0, len(split['y']), bsize):
                chunks.append(batcher(params, split['X'][start:start + bsize]))
            sst_embed[key]['X'] = np.vstack(chunks)
            sst_embed[key]['y'] = np.array(split['y'])
            logging.info('Computed {0} embeddings'.format(key))

        config_classifier = {'nclasses': self.nclasses, 'seed': self.seed,
                             'usepytorch': params.usepytorch,
                             'classifier': params.classifier}

        # The classifier names the dev split 'valid'.
        rename = (('train', 'train'), ('valid', 'dev'), ('test', 'test'))
        features = {dst: sst_embed[src]['X'] for dst, src in rename}
        targets = {dst: sst_embed[src]['y'] for dst, src in rename}
        clf = SplitClassifier(X=features, y=targets, config=config_classifier)

        devacc, testacc = clf.run()
        logging.debug('\nDev acc : {0} Test acc : {1} for '
                      'SST {2} classification\n'.format(devacc, testacc, self.task_name))

        return {'devacc': devacc, 'acc': testacc,
                'ndev': len(sst_embed['dev']['X']),
                'ntest': len(sst_embed['test']['X'])}
|
SentEval/senteval/sts.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
'''
|
| 9 |
+
STS-{2012,2013,2014,2015,2016} (unsupervised) and
|
| 10 |
+
STS-benchmark (supervised) tasks
|
| 11 |
+
'''
|
| 12 |
+
|
| 13 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import io
|
| 17 |
+
import numpy as np
|
| 18 |
+
import logging
|
| 19 |
+
|
| 20 |
+
from scipy.stats import spearmanr, pearsonr
|
| 21 |
+
|
| 22 |
+
from senteval.utils import cosine
|
| 23 |
+
from senteval.sick import SICKEval
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class STSEval(object):
    """
    Base class for the unsupervised STS tasks.

    Subclasses set ``self.datasets`` (list of sub-dataset names) and then
    call ``loadFile``. ``run`` scores each sentence pair with
    ``self.similarity`` and reports Pearson / Spearman correlations against
    the gold scores.
    """

    def loadFile(self, fpath):
        """
        Load every sub-dataset listed in ``self.datasets`` from ``fpath``.

        For each dataset, ``STS.input.<name>.txt`` holds tab-separated
        sentence pairs and ``STS.gs.<name>.txt`` holds one gold score per
        line. Pairs whose gold score line is empty are dropped. Results are
        stored in ``self.data[name]`` as ``(sent1, sent2, gs_scores)`` and
        all sentences are appended to ``self.samples``.
        """
        self.data = {}
        self.samples = []

        for dataset in self.datasets:
            input_path = os.path.join(fpath, 'STS.input.%s.txt' % dataset)
            with io.open(input_path, encoding='utf8') as f:
                sent1, sent2 = zip(*[l.split("\t") for l in
                                     f.read().splitlines()])
            gs_path = os.path.join(fpath, 'STS.gs.%s.txt' % dataset)
            with io.open(gs_path, encoding='utf8') as f:
                raw_scores = f.read().splitlines()

            # Filter with plain lists: building a NumPy array out of token
            # lists of different lengths is rejected by recent NumPy.
            kept = [idx for idx, score in enumerate(raw_scores)
                    if score != '']
            gs_scores = [float(raw_scores[idx]) for idx in kept]
            sent1 = [sent1[idx].split() for idx in kept]
            sent2 = [sent2[idx].split() for idx in kept]

            # sort data by length to minimize padding in batcher
            sorted_data = sorted(zip(sent1, sent2, gs_scores),
                                 key=lambda z: (len(z[0]), len(z[1]), z[2]))
            sent1, sent2, gs_scores = map(list, zip(*sorted_data))

            self.data[dataset] = (sent1, sent2, gs_scores)
            self.samples += sent1 + sent2

    def do_prepare(self, params, prepare):
        """
        Select the similarity function, then call the ``prepare`` callback.

        ``params.similarity`` is used when ``'similarity' in params``;
        otherwise the default is a NaN-safe cosine similarity.
        """
        if 'similarity' in params:
            self.similarity = params.similarity
        else:  # Default similarity is cosine
            self.similarity = lambda s1, s2: np.nan_to_num(cosine(np.nan_to_num(s1), np.nan_to_num(s2)))
        return prepare(params, self.samples)

    def run(self, params, batcher):
        """
        Score every sentence pair and correlate with the gold scores.

        :param params: object exposing ``batch_size``.
        :param batcher: callable ``batcher(params, batch)`` returning the
            embeddings of a batch of tokenised sentences.
        :return: dict mapping each dataset name to its Pearson / Spearman
            results and sample count, plus an ``'all'`` entry holding the
            pooled, mean and sample-weighted mean correlations.
        """
        results = {}
        all_sys_scores = []
        all_gs_scores = []
        for dataset in self.datasets:
            sys_scores = []
            input1, input2, gs_scores = self.data[dataset]
            for ii in range(0, len(gs_scores), params.batch_size):
                batch1 = input1[ii:ii + params.batch_size]
                batch2 = input2[ii:ii + params.batch_size]

                # we assume get_batch already throws out the faulty ones
                if len(batch1) == len(batch2) and len(batch1) > 0:
                    enc1 = batcher(params, batch1)
                    enc2 = batcher(params, batch2)

                    for kk in range(enc2.shape[0]):
                        sys_score = self.similarity(enc1[kk], enc2[kk])
                        sys_scores.append(sys_score)
            all_sys_scores.extend(sys_scores)
            all_gs_scores.extend(gs_scores)
            results[dataset] = {'pearson': pearsonr(sys_scores, gs_scores),
                                'spearman': spearmanr(sys_scores, gs_scores),
                                'nsamples': len(sys_scores)}
            logging.debug('%s : pearson = %.4f, spearman = %.4f' %
                          (dataset, results[dataset]['pearson'][0],
                           results[dataset]['spearman'][0]))

        # Computed before 'all' is added, so only real datasets are averaged.
        weights = [results[dset]['nsamples'] for dset in results]
        list_prs = np.array([results[dset]['pearson'][0] for
                             dset in results])
        list_spr = np.array([results[dset]['spearman'][0] for
                             dset in results])

        avg_pearson = np.average(list_prs)
        avg_spearman = np.average(list_spr)
        wavg_pearson = np.average(list_prs, weights=weights)
        wavg_spearman = np.average(list_spr, weights=weights)
        all_pearson = pearsonr(all_sys_scores, all_gs_scores)
        all_spearman = spearmanr(all_sys_scores, all_gs_scores)
        results['all'] = {'pearson': {'all': all_pearson[0],
                                      'mean': avg_pearson,
                                      'wmean': wavg_pearson},
                          'spearman': {'all': all_spearman[0],
                                       'mean': avg_spearman,
                                       'wmean': wavg_spearman}}
        logging.debug('ALL : Pearson = %.4f, '
                      'Spearman = %.4f' % (all_pearson[0], all_spearman[0]))
        logging.debug('ALL (weighted average) : Pearson = %.4f, '
                      'Spearman = %.4f' % (wavg_pearson, wavg_spearman))
        logging.debug('ALL (average) : Pearson = %.4f, '
                      'Spearman = %.4f\n' % (avg_pearson, avg_spearman))

        return results
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class STS12Eval(STSEval):
    """STS 2012 task: loads its five sub-datasets from ``taskpath``."""

    def __init__(self, taskpath, seed=1111):
        logging.debug('***** Transfer task : STS12 *****\n\n')
        subtasks = ['MSRpar', 'MSRvid', 'SMTeuroparl',
                    'surprise.OnWN', 'surprise.SMTnews']
        self.seed, self.datasets = seed, subtasks
        self.loadFile(taskpath)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class STS13Eval(STSEval):
    """
    STS 2013 task: loads its three sub-datasets from ``taskpath``.

    STS13 here does not contain the "SMT" subtask due to LICENSE issue.
    """

    def __init__(self, taskpath, seed=1111):
        logging.debug('***** Transfer task : STS13 (-SMT) *****\n\n')
        subtasks = ['FNWN', 'headlines', 'OnWN']
        self.seed, self.datasets = seed, subtasks
        self.loadFile(taskpath)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class STS14Eval(STSEval):
    """STS 2014 task: loads its six sub-datasets from ``taskpath``."""

    def __init__(self, taskpath, seed=1111):
        logging.debug('***** Transfer task : STS14 *****\n\n')
        subtasks = ['deft-forum', 'deft-news', 'headlines',
                    'images', 'OnWN', 'tweet-news']
        self.seed, self.datasets = seed, subtasks
        self.loadFile(taskpath)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class STS15Eval(STSEval):
    """STS 2015 task: loads its five sub-datasets from ``taskpath``."""

    def __init__(self, taskpath, seed=1111):
        logging.debug('***** Transfer task : STS15 *****\n\n')
        subtasks = ['answers-forums', 'answers-students',
                    'belief', 'headlines', 'images']
        self.seed, self.datasets = seed, subtasks
        self.loadFile(taskpath)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class STS16Eval(STSEval):
    """STS 2016 task: loads its five sub-datasets from ``taskpath``."""

    def __init__(self, taskpath, seed=1111):
        logging.debug('***** Transfer task : STS16 *****\n\n')
        subtasks = ['answer-answer', 'headlines', 'plagiarism',
                    'postediting', 'question-question']
        self.seed, self.datasets = seed, subtasks
        self.loadFile(taskpath)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class STSBenchmarkEval(STSEval):
    """
    STS Benchmark evaluated like the unsupervised STS tasks: the train,
    dev and test files are each treated as one dataset.
    """

    def __init__(self, task_path, seed=1111):
        """Load ``sts-{train,dev,test}.csv`` from ``task_path``."""
        logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n')
        self.seed = seed
        self.samples = []
        loaded = {}
        for split in ('train', 'dev', 'test'):
            csv_path = os.path.join(task_path, 'sts-%s.csv' % split)
            loaded[split] = self.loadFile(csv_path)
        self.datasets = ['train', 'dev', 'test']
        self.data = loaded

    def loadFile(self, fpath):
        """
        Parse one tab-separated STS Benchmark file.

        Columns 5 and 6 are the two sentences (whitespace-tokenised) and
        column 4 is the gold score. All sentences are also appended to
        ``self.samples``.

        :return: tuple ``(sentences_a, sentences_b, scores)``.
        """
        first, second, raw_scores = [], [], []
        with io.open(fpath, 'r', encoding='utf-8') as f:
            for line in f:
                cols = line.strip().split('\t')
                first.append(cols[5].split())
                second.append(cols[6].split())
                raw_scores.append(cols[4])

        scores = [float(s) for s in raw_scores]
        self.samples += first + second
        return (first, second, scores)
|
| 183 |
+
|
| 184 |
+
class STSBenchmarkFinetune(SICKEval):
    """
    STS Benchmark loaded in the data layout used by ``SICKEval``
    (a dict per split stored in ``self.sick_data``).
    """

    def __init__(self, task_path, seed=1111):
        """Load ``sts-{train,dev,test}.csv`` from ``task_path``."""
        logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n')
        self.seed = seed
        loaded = {}
        for split in ('train', 'dev', 'test'):
            csv_path = os.path.join(task_path, 'sts-%s.csv' % split)
            loaded[split] = self.loadFile(csv_path)
        self.sick_data = loaded

    def loadFile(self, fpath):
        """
        Parse one tab-separated STS Benchmark file.

        Columns 5 and 6 are the two sentences (whitespace-tokenised) and
        column 4 is the gold score.

        :return: dict with keys ``'X_A'``, ``'X_B'`` and ``'y'``.
        """
        first, second, raw_scores = [], [], []
        with io.open(fpath, 'r', encoding='utf-8') as f:
            for line in f:
                cols = line.strip().split('\t')
                first.append(cols[5].split())
                second.append(cols[6].split())
                raw_scores.append(cols[4])

        return {'X_A': first, 'X_B': second,
                'y': [float(s) for s in raw_scores]}
|
| 204 |
+
|
| 205 |
+
class SICKRelatednessEval(STSEval):
    """
    SICK relatedness evaluated like the unsupervised STS tasks: the train,
    trial (dev) and test files are each treated as one dataset.
    """

    def __init__(self, task_path, seed=1111):
        """Load the three SICK files found under ``task_path``."""
        logging.debug('\n\n***** Transfer task : SICKRelatedness*****\n\n')
        self.seed = seed
        self.samples = []
        split_files = (('train', 'SICK_train.txt'),
                       ('dev', 'SICK_trial.txt'),
                       ('test', 'SICK_test_annotated.txt'))
        loaded = {}
        for split, fname in split_files:
            loaded[split] = self.loadFile(os.path.join(task_path, fname))
        self.datasets = ['train', 'dev', 'test']
        self.data = loaded

    def loadFile(self, fpath):
        """
        Parse one tab-separated SICK file, skipping its first (header) line.

        Columns 1 and 2 are the two sentences (whitespace-tokenised) and
        column 3 is the relatedness score. All sentences are also appended
        to ``self.samples``.

        :return: tuple ``(sentences_a, sentences_b, scores)``.
        """
        first, second, raw_scores = [], [], []
        with io.open(fpath, 'r', encoding='utf-8') as f:
            next(f, None)  # header line
            for line in f:
                cols = line.strip().split('\t')
                first.append(cols[1].split())
                second.append(cols[2].split())
                raw_scores.append(cols[3])

        scores = [float(s) for s in raw_scores]
        self.samples += first + second
        return (first, second, scores)
|
SentEval/senteval/tools/__init__.py
ADDED
|
File without changes
|
SentEval/senteval/tools/classifier.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
Pytorch Classifier class in the style of scikit-learn
|
| 10 |
+
Classifiers include Logistic Regression and MLP
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import copy
|
| 17 |
+
from senteval import utils
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from torch import nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class PyTorchClassifier(object):
    """
    Scikit-learn style base class for PyTorch classifiers.

    The methods below read ``self.model``, ``self.loss_fn``,
    ``self.optimizer``, ``self.max_epoch``, ``self.epoch_size`` and
    ``self.tenacity``; none of them is set here, so a subclass must define
    them before ``fit`` is called.

    Tensors are moved to whatever device holds the model parameters, so the
    class works both with a CUDA model and with a CPU model.
    """

    def __init__(self, inputdim, nclasses, l2reg=0., batch_size=64, seed=1111,
                 cudaEfficient=False):
        """
        :param inputdim: dimension of the input features.
        :param nclasses: number of target classes.
        :param l2reg: L2 regularisation strength (stored for subclasses).
        :param batch_size: mini-batch size for training and inference.
        :param seed: seed applied to NumPy and PyTorch.
        :param cudaEfficient: when True the data set is kept on the CPU and
            only the current batch is moved to the model's device.
        """
        # fix seed
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.inputdim = inputdim
        self.nclasses = nclasses
        self.l2reg = l2reg
        self.batch_size = batch_size
        self.cudaEfficient = cudaEfficient

    def _model_device(self):
        """Return the device of the model parameters (CUDA if available, else CPU, when there is no model yet)."""
        model = getattr(self, 'model', None)
        if model is not None:
            for param in model.parameters():
                return param.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _as_tensor(data, dtype, device):
        """Convert an array-like (ndarray, list or tensor) to a tensor of ``dtype`` on ``device``."""
        if not torch.is_tensor(data):
            data = torch.from_numpy(np.asarray(data))
        return data.to(device=device, dtype=dtype)

    def prepare_split(self, X, y, validation_data=None, validation_split=None):
        """
        Build the train / dev tensors.

        Either ``validation_data`` (a ``(devX, devy)`` pair) is given, or a
        fraction ``validation_split`` of ``X`` is randomly held out.

        :return: ``(trainX, trainy, devX, devy)`` as float32 / int64 tensors.
        """
        # Preparing validation data
        assert validation_split or validation_data
        if validation_data is not None:
            trainX, trainy = X, y
            devX, devy = validation_data
        else:
            # asarray so that list inputs can be indexed with an index array
            X, y = np.asarray(X), np.asarray(y)
            permutation = np.random.permutation(len(X))
            n_dev = int(validation_split * len(X))
            trainidx = permutation[n_dev:]
            devidx = permutation[0:n_dev]
            trainX, trainy = X[trainidx], y[trainidx]
            devX, devy = X[devidx], y[devidx]

        if self.cudaEfficient:
            device = torch.device('cpu')
        else:
            device = self._model_device()

        trainX = self._as_tensor(trainX, torch.float32, device)
        trainy = self._as_tensor(trainy, torch.int64, device)
        devX = self._as_tensor(devX, torch.float32, device)
        devy = self._as_tensor(devy, torch.int64, device)

        return trainX, trainy, devX, devy

    def fit(self, X, y, validation_data=None, validation_split=None,
            early_stop=True):
        """
        Train the model, keeping the weights with the best dev accuracy.

        Training stops once ``self.nepoch`` exceeds ``self.max_epoch`` or,
        with ``early_stop``, after the dev accuracy failed to improve more
        than ``self.tenacity`` times.

        :return: the best dev accuracy observed.
        """
        self.nepoch = 0
        bestaccuracy = -1
        stop_train = False
        early_stop_count = 0

        # Preparing validation data
        trainX, trainy, devX, devy = self.prepare_split(X, y, validation_data,
                                                        validation_split)

        # Training
        while not stop_train and self.nepoch <= self.max_epoch:
            self.trainepoch(trainX, trainy, epoch_size=self.epoch_size)
            accuracy = self.score(devX, devy)
            if accuracy > bestaccuracy:
                bestaccuracy = accuracy
                bestmodel = copy.deepcopy(self.model)
            elif early_stop:
                if early_stop_count >= self.tenacity:
                    stop_train = True
                early_stop_count += 1
        self.model = bestmodel
        return bestaccuracy

    def trainepoch(self, X, y, epoch_size=1):
        """Run ``epoch_size`` shuffled passes over ``(X, y)`` and advance ``self.nepoch``."""
        self.model.train()
        device = self._model_device()
        for _ in range(self.nepoch, self.nepoch + epoch_size):
            permutation = np.random.permutation(len(X))
            for i in range(0, len(X), self.batch_size):
                # forward
                idx = torch.from_numpy(permutation[i:i + self.batch_size]).long().to(X.device)

                # no-op when the data already lives on the model's device
                Xbatch = X[idx].to(device)
                ybatch = y[idx].to(device)
                output = self.model(Xbatch)
                # loss
                loss = self.loss_fn(output, ybatch)
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                # Update parameters
                self.optimizer.step()
        self.nepoch += epoch_size

    def score(self, devX, devy):
        """Return the accuracy of the model on ``(devX, devy)``."""
        self.model.eval()
        device = self._model_device()
        storage = torch.device('cpu') if self.cudaEfficient else device
        devX = self._as_tensor(devX, torch.float32, storage)
        devy = self._as_tensor(devy, torch.int64, storage)
        correct = 0
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size].to(device)
                ybatch = devy[i:i + self.batch_size].to(device)
                output = self.model(Xbatch)
                pred = output.max(1)[1]
                correct += pred.long().eq(ybatch.long()).sum().item()
        accuracy = 1.0 * correct / len(devX)
        return accuracy

    def predict(self, devX):
        """
        Predict the class index of every row of ``devX``.

        :return: float64 column vector of shape ``(len(devX), 1)``.
        """
        self.model.eval()
        devX = self._as_tensor(devX, torch.float32, self._model_device())
        batches = []
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size]
                output = self.model(Xbatch)
                batches.append(output.max(1)[1].cpu().numpy())
        yhat = np.concatenate(batches) if batches else np.array([])
        return yhat.astype(np.float64).reshape(-1, 1)

    def predict_proba(self, devX):
        """
        Return the softmax class probabilities for every row of ``devX``.

        :return: array of shape ``(len(devX), n_outputs)``.
        """
        self.model.eval()
        devX = self._as_tensor(devX, torch.float32, self._model_device())
        probas = []
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size]
                # softmax must be taken on the tensor, over the class axis
                vals = F.softmax(self.model(Xbatch), dim=-1)
                probas.append(vals.cpu().numpy())
        if not probas:
            return np.zeros((0, self.nclasses), dtype=np.float32)
        return np.concatenate(probas, axis=0)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
"""
|
| 159 |
+
MLP with Pytorch (nhid=0 --> Logistic Regression)
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
class MLP(PyTorchClassifier):
    """MLP classifier built with PyTorch (nhid=0 gives logistic regression).

    PARAMETERS (read from ``params``, all optional):
        -nhid:       number of hidden units (0: Logistic Regression)
        -optim:      optimizer ("sgd,lr=0.1", "adam", "rmsprop" ..)
        -tenacity:   how many times dev acc does not increase before stopping
        -epoch_size: each epoch corresponds to epoch_size pass on the train set
        -max_epoch:  max number of epoches
        -dropout:    dropout for MLP
        -batch_size: mini-batch size (falls back to the ``batch_size``
                     constructor argument)
    """

    def __init__(self, params, inputdim, nclasses, l2reg=0., batch_size=64,
                 seed=1111, cudaEfficient=False):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as MLP itself is subclassed.
        super(MLP, self).__init__(inputdim, nclasses, l2reg,
                                  batch_size, seed, cudaEfficient)

        self.nhid = params.get("nhid", 0)
        self.optim = params.get("optim", "adam")
        self.tenacity = params.get("tenacity", 5)
        self.epoch_size = params.get("epoch_size", 4)
        self.max_epoch = params.get("max_epoch", 200)
        self.dropout = params.get("dropout", 0.)
        # Honour the constructor argument when params carries no override.
        self.batch_size = params.get("batch_size", batch_size)

        # Use the defaulted self.nhid so a missing "nhid" key means logistic
        # regression instead of raising KeyError.
        if self.nhid == 0:
            self.model = nn.Sequential(
                nn.Linear(self.inputdim, self.nclasses),
            ).cuda()
        else:
            self.model = nn.Sequential(
                nn.Linear(self.inputdim, self.nhid),
                nn.Dropout(p=self.dropout),
                nn.Sigmoid(),
                nn.Linear(self.nhid, self.nclasses),
            ).cuda()

        self.loss_fn = nn.CrossEntropyLoss().cuda()
        # NOTE(review): recent PyTorch releases select the reduction through
        # the `reduction` constructor argument -- confirm this attribute
        # assignment still has the intended (summed-loss) effect.
        self.loss_fn.size_average = False

        optim_fn, optim_params = utils.get_optimizer(self.optim)
        self.optimizer = optim_fn(self.model.parameters(), **optim_params)
        self.optimizer.param_groups[0]['weight_decay'] = self.l2reg
|
SentEval/senteval/tools/ranking.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
Image Annotation/Search for COCO with Pytorch
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
import copy
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import nn
|
| 19 |
+
from torch.autograd import Variable
|
| 20 |
+
import torch.optim as optim
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class COCOProjNet(nn.Module):
    """Linear projections of image and sentence features into a shared space.

    ``config`` supplies ``imgdim``, ``sentdim`` and ``projdim``. Every
    projected row is divided by its Euclidean norm.
    """

    def __init__(self, config):
        super(COCOProjNet, self).__init__()
        self.imgdim = config['imgdim']
        self.sentdim = config['sentdim']
        self.projdim = config['projdim']
        self.imgproj = nn.Sequential(nn.Linear(self.imgdim, self.projdim))
        self.sentproj = nn.Sequential(nn.Linear(self.sentdim, self.projdim))

    @staticmethod
    def _unit_rows(mat):
        """Divide each row of ``mat`` by its L2 norm."""
        norms = torch.sqrt(torch.pow(mat, 2).sum(1, keepdim=True))
        return mat / norms.expand_as(mat)

    def forward(self, img, sent, imgc, sentc):
        """Score positive pairs and contrastive pairs.

        Shapes (as documented by the original author):
            img   : (bsize, imgdim)
            sent  : (bsize, sentdim)
            imgc  : (bsize, ncontrast, imgdim)
            sentc : (bsize, ncontrast, sentdim)

        Returns four vectors of length bsize*ncontrast:
        ``anchor1``, ``anchor2`` (both the image/sentence dot product),
        ``img_sentc`` (image vs. contrastive sentence) and ``sent_imgc``
        (sentence vs. contrastive image).
        """
        # Repeat each anchor once per contrastive example, then flatten.
        flat_img = img.unsqueeze(1).expand_as(imgc).contiguous()
        flat_img = flat_img.view(-1, self.imgdim)
        flat_imgc = imgc.view(-1, self.imgdim)
        flat_sent = sent.unsqueeze(1).expand_as(sentc).contiguous()
        flat_sent = flat_sent.view(-1, self.sentdim)
        flat_sentc = sentc.view(-1, self.sentdim)

        # All of these are (bsize*ncontrast, projdim) with unit-norm rows.
        img_emb = self._unit_rows(self.imgproj(flat_img))
        imgc_emb = self._unit_rows(self.imgproj(flat_imgc))
        sent_emb = self._unit_rows(self.sentproj(flat_sent))
        sentc_emb = self._unit_rows(self.sentproj(flat_sentc))

        anchor1 = torch.sum(img_emb * sent_emb, 1)
        anchor2 = torch.sum(sent_emb * img_emb, 1)
        img_sentc = torch.sum(img_emb * sentc_emb, 1)
        sent_imgc = torch.sum(sent_emb * imgc_emb, 1)

        return anchor1, anchor2, img_sentc, sent_imgc

    def proj_sentence(self, sent):
        """Project sentence features; returns (bsize, projdim) unit rows."""
        return self._unit_rows(self.sentproj(sent))

    def proj_image(self, img):
        """Project image features; returns (bsize, projdim) unit rows."""
        return self._unit_rows(self.imgproj(img))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class PairwiseRankingLoss(nn.Module):
    """Summed hinge ranking loss with a fixed margin.

    For each element the loss is ``max(0, margin - anchor + contrastive)``,
    computed once for (anchor1, img_sentc) and once for (anchor2, sent_imgc);
    the two sums are added.
    """

    def __init__(self, margin):
        super(PairwiseRankingLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor1, anchor2, img_sentc, sent_imgc):
        """Return the scalar sum of both hinge terms."""
        hinge_sent = (self.margin - anchor1 + img_sentc).clamp(min=0.0)
        hinge_img = (self.margin - anchor2 + sent_imgc).clamp(min=0.0)
        return hinge_sent.sum() + hinge_img.sum()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class ImageSentenceRankingPytorch(object):
    """Image/sentence ranking on COCO with PyTorch.

    Trains a ``COCOProjNet`` with a pairwise ranking loss and reports
    recall@1/5/10 and median rank for image-to-text and text-to-image
    retrieval, averaged over ``nsplits`` consecutive slices of
    ``split_size`` rows.
    """

    # Evaluation protocol used by run(): 5 folds of 5000 rows each.
    nsplits = 5
    split_size = 5000

    def __init__(self, train, valid, test, config):
        """Store the data splits and build model, loss and optimizer.

        ``train``/``valid``/``test`` are indexed with the keys 'sentfeat'
        and 'imgfeat'; ``config`` supplies 'seed', 'projdim' and 'margin'.
        The model and loss are moved to CUDA.
        """
        # fix seed
        self.seed = config['seed']
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

        self.train = train
        self.valid = valid
        self.test = test

        self.imgdim = len(train['imgfeat'][0])
        self.sentdim = len(train['sentfeat'][0])
        self.projdim = config['projdim']
        self.margin = config['margin']

        self.batch_size = 128
        self.ncontrast = 30
        self.maxepoch = 20
        self.early_stop = True

        config_model = {'imgdim': self.imgdim, 'sentdim': self.sentdim,
                        'projdim': self.projdim}
        self.model = COCOProjNet(config_model).cuda()

        self.loss_fn = PairwiseRankingLoss(margin=self.margin).cuda()

        self.optimizer = optim.Adam(self.model.parameters())

    def prepare_data(self, trainTxt, trainImg, devTxt, devImg,
                     testTxt, testImg):
        """Convert all feature matrices to float tensors.

        Training features are left where ``torch.FloatTensor`` creates them
        (batches are uploaded one at a time in ``trainepoch``); dev and test
        features are moved to CUDA immediately.
        """
        trainTxt = torch.FloatTensor(trainTxt)
        trainImg = torch.FloatTensor(trainImg)
        devTxt = torch.FloatTensor(devTxt).cuda()
        devImg = torch.FloatTensor(devImg).cuda()
        testTxt = torch.FloatTensor(testTxt).cuda()
        testImg = torch.FloatTensor(testImg).cuda()

        return trainTxt, trainImg, devTxt, devImg, testTxt, testImg

    def _evaluate_splits(self, images, captions, verbose=False):
        """Average i2t/t2i metrics over the evaluation folds.

        Returns ``(results, score)`` where ``results`` holds the mean
        r1/r5/r10/medr for 'i2t' and 't2i', and ``score`` is the mean over
        folds of the sum of the six recall values.
        """
        results = {'i2t': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0},
                   't2i': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0}}
        names = ('r1', 'r5', 'r10', 'medr')
        score = 0
        for i in range(self.nsplits):
            lo, hi = i * self.split_size, (i + 1) * self.split_size
            txt_i = captions[lo:hi]
            img_i = images[lo:hi]

            r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(img_i, txt_i)
            for name, value in zip(names,
                                   (r1_i2t, r5_i2t, r10_i2t, medr_i2t)):
                results['i2t'][name] += value / self.nsplits
            if verbose:
                logging.info("Image to text: {0}, {1}, {2}, {3}"
                             .format(r1_i2t, r5_i2t, r10_i2t, medr_i2t))

            r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(img_i, txt_i)
            for name, value in zip(names,
                                   (r1_t2i, r5_t2i, r10_t2i, medr_t2i)):
                results['t2i'][name] += value / self.nsplits
            if verbose:
                logging.info("Text to Image: {0}, {1}, {2}, {3}"
                             .format(r1_t2i, r5_t2i, r10_t2i, medr_t2i))

            score += (r1_i2t + r5_i2t + r10_i2t +
                      r1_t2i + r5_t2i + r10_t2i) / self.nsplits
        return results, score

    def run(self):
        """Train with early stopping on the dev score, then evaluate on test.

        Returns ``(bestdevscore, i2t r1, i2t r5, i2t r10, i2t medr,
        t2i r1, t2i r5, t2i r10, t2i medr)`` for the test folds.
        """
        self.nepoch = 0
        bestdevscore = -1
        early_stop_count = 0
        stop_train = False
        # Fallback so the restore below is well-defined even if no epoch
        # ever improves on the initial score (e.g. a NaN dev score).
        bestmodel = self.model

        # Preparing data
        logging.info('prepare data')
        trainTxt, trainImg, devTxt, devImg, testTxt, testImg = \
            self.prepare_data(self.train['sentfeat'], self.train['imgfeat'],
                              self.valid['sentfeat'], self.valid['imgfeat'],
                              self.test['sentfeat'], self.test['imgfeat'])

        # Training
        while not stop_train and self.nepoch <= self.maxepoch:
            logging.info('start epoch')
            self.trainepoch(trainTxt, trainImg, devTxt, devImg, nepoches=1)
            logging.info('Epoch {0} finished'.format(self.nepoch))

            results, score = self._evaluate_splits(devImg, devTxt,
                                                   verbose=True)

            logging.info("Dev mean Text to Image: {0}, {1}, {2}, {3}".format(
                results['t2i']['r1'], results['t2i']['r5'],
                results['t2i']['r10'], results['t2i']['medr']))
            logging.info("Dev mean Image to text: {0}, {1}, {2}, {3}".format(
                results['i2t']['r1'], results['i2t']['r5'],
                results['i2t']['r10'], results['i2t']['medr']))

            # early stop on the summed dev recall score
            if score > bestdevscore:
                bestdevscore = score
                bestmodel = copy.deepcopy(self.model)
            elif self.early_stop:
                if early_stop_count >= 3:
                    stop_train = True
                early_stop_count += 1
        self.model = bestmodel

        # Compute test for the evaluation folds
        results, _ = self._evaluate_splits(testImg, testTxt)

        return bestdevscore, results['i2t']['r1'], results['i2t']['r5'], \
            results['i2t']['r10'], results['i2t']['medr'], \
            results['t2i']['r1'], results['t2i']['r5'], \
            results['t2i']['r10'], results['t2i']['medr']

    def trainepoch(self, trainTxt, trainImg, devTxt, devImg, nepoches=1):
        """Run ``nepoches`` passes over the training pairs.

        Every 500 batches the current dev retrieval metrics are logged.
        """
        self.model.train()
        for _ in range(self.nepoch, self.nepoch + nepoches):
            permutation = list(np.random.permutation(len(trainTxt)))
            for i in range(0, len(trainTxt), self.batch_size):
                # periodic progress report on the dev set
                if i % (self.batch_size*500) == 0 and i > 0:
                    logging.info('samples : {0}'.format(i))
                    r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(devImg,
                                                                 devTxt)
                    logging.info("Image to text: {0}, {1}, {2}, {3}".format(
                        r1_i2t, r5_i2t, r10_i2t, medr_i2t))
                    r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(devImg,
                                                                 devTxt)
                    logging.info("Text to Image: {0}, {1}, {2}, {3}".format(
                        r1_t2i, r5_t2i, r10_t2i, medr_t2i))

                # forward
                idx = torch.LongTensor(permutation[i:i + self.batch_size])
                imgbatch = trainImg.index_select(0, idx).cuda()
                sentbatch = trainTxt.index_select(0, idx).cuda()

                # Contrastive examples are drawn (with replacement) from
                # everything outside the current batch.
                pool = permutation[:i] + permutation[i + self.batch_size:]
                ndraw = self.ncontrast*idx.size(0)
                idximgc = torch.LongTensor(np.random.choice(pool, ndraw))
                idxsentc = torch.LongTensor(np.random.choice(pool, ndraw))

                imgcbatch = trainImg.index_select(0, idximgc).view(
                    -1, self.ncontrast, self.imgdim).cuda()
                sentcbatch = trainTxt.index_select(0, idxsentc).view(
                    -1, self.ncontrast, self.sentdim).cuda()

                anchor1, anchor2, img_sentc, sent_imgc = self.model(
                    imgbatch, sentbatch, imgcbatch, sentcbatch)
                # loss
                loss = self.loss_fn(anchor1, anchor2, img_sentc, sent_imgc)
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                # Update parameters
                self.optimizer.step()
        self.nepoch += nepoches

    def _embed(self, images, captions):
        """Project images and captions batch by batch; returns two tensors."""
        img_embed, sent_embed = [], []
        for i in range(0, len(images), self.batch_size):
            img_embed.append(
                self.model.proj_image(images[i:i + self.batch_size]))
            sent_embed.append(
                self.model.proj_sentence(captions[i:i + self.batch_size]))
        return torch.cat(img_embed, 0), torch.cat(sent_embed, 0)

    @staticmethod
    def _recall_metrics(ranks):
        """Turn 0-based ranks into (recall@1, recall@5, recall@10, medr)."""
        r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
        r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
        r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
        medr = np.floor(np.median(ranks)) + 1
        return (r1, r5, r10, medr)

    def t2i(self, images, captions):
        """
        Images: (5N, imgdim) matrix of images
        Captions: (5N, sentdim) matrix of captions

        Text-to-image retrieval: each caption is ranked against the N
        distinct images (every 5th row of ``images``).
        """
        with torch.no_grad():
            img_embed, sent_embed = self._embed(images, captions)

            npts = int(img_embed.size(0) / 5)
            # Build the index on the embeddings' own device rather than
            # hard-coding a CUDA tensor.
            idxs = torch.arange(0, len(img_embed), 5,
                                device=img_embed.device)
            ims = img_embed.index_select(0, idxs)

            ranks = np.zeros(5 * npts)
            for index in range(npts):
                # Get query captions
                queries = sent_embed[5*index: 5*index + 5]

                # Compute scores
                scores = torch.mm(queries, ims.transpose(0, 1)).cpu().numpy()
                for i in range(len(scores)):
                    order = np.argsort(scores[i])[::-1]
                    ranks[5 * index + i] = np.where(order == index)[0][0]

            return self._recall_metrics(ranks)

    def i2t(self, images, captions):
        """
        Images: (5N, imgdim) matrix of images
        Captions: (5N, sentdim) matrix of captions

        Image-to-text retrieval: each distinct image is ranked against all
        captions; its rank is the best rank among its 5 own captions.
        """
        with torch.no_grad():
            img_embed, sent_embed = self._embed(images, captions)

            npts = int(img_embed.size(0) / 5)

            ranks = np.zeros(npts)
            for index in range(npts):
                # Get query image
                query_img = img_embed[5 * index]

                # Compute scores
                scores = torch.mm(query_img.view(1, -1),
                                  sent_embed.transpose(0, 1)).view(-1)
                inds = np.argsort(scores.cpu().numpy())[::-1]

                # Score: best position of any of the 5 matching captions
                ranks[index] = min(np.where(inds == i)[0][0]
                                   for i in range(5 * index, 5 * index + 5))

            return self._recall_metrics(ranks)
|
SentEval/senteval/tools/relatedness.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
Semantic Relatedness (supervised) with Pytorch
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 12 |
+
|
| 13 |
+
import copy
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from torch import nn
|
| 18 |
+
import torch.optim as optim
|
| 19 |
+
|
| 20 |
+
from scipy.stats import pearsonr, spearmanr
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class RelatednessPytorch(object):
    """Supervised semantic relatedness with a softmax layer and MSE loss.

    Can be used for SICK-Relatedness, and STS14. A single linear layer
    followed by a softmax is fitted to the target distributions; a scalar
    score is the dot product of the predicted probabilities with the class
    values ``1 .. nclasses``.
    """

    def __init__(self, train, valid, test, devscores, config):
        """Store the splits and build model, loss and optimizer.

        ``train``/``valid``/``test`` are indexed with the keys 'X' and 'y';
        ``config`` supplies 'seed' and 'nclasses'. CUDA is required.
        """
        # fix seed
        np.random.seed(config['seed'])
        torch.manual_seed(config['seed'])
        assert torch.cuda.is_available(), 'torch.cuda required for Relatedness'
        torch.cuda.manual_seed(config['seed'])

        self.train = train
        self.valid = valid
        self.test = test
        self.devscores = devscores

        self.inputdim = train['X'].shape[1]
        self.nclasses = config['nclasses']
        self.seed = config['seed']
        self.l2reg = 0.
        self.batch_size = 64
        self.maxepoch = 1000
        self.early_stop = True

        self.model = nn.Sequential(
            nn.Linear(self.inputdim, self.nclasses),
            nn.Softmax(dim=-1),
        )
        self.loss_fn = nn.MSELoss()

        if torch.cuda.is_available():
            self.model = self.model.cuda()
            self.loss_fn = self.loss_fn.cuda()

        # NOTE(review): recent PyTorch releases select the reduction through
        # the `reduction` constructor argument -- confirm this attribute
        # assignment still has the intended (summed-loss) effect.
        self.loss_fn.size_average = False
        self.optimizer = optim.Adam(self.model.parameters(),
                                    weight_decay=self.l2reg)

    def prepare_data(self, trainX, trainy, devX, devy, testX, testy):
        """Convert the six NumPy arrays to float CUDA tensors."""
        trainX = torch.from_numpy(trainX).float().cuda()
        trainy = torch.from_numpy(trainy).float().cuda()
        devX = torch.from_numpy(devX).float().cuda()
        devy = torch.from_numpy(devy).float().cuda()
        testX = torch.from_numpy(testX).float().cuda()
        # Assign back to the name that is returned, so the converted tensor
        # (not the raw input) is handed to the caller.
        testy = torch.from_numpy(testy).float().cuda()

        return trainX, trainy, devX, devy, testX, testy

    def run(self):
        """Train with early stopping and return ``(best dev score, yhat)``.

        The dev score is the Spearman correlation between predicted scores
        and ``self.devscores``; ``yhat`` holds the predicted scores for the
        test set from the best model.
        """
        self.nepoch = 0
        bestpr = -1
        early_stop_count = 0
        # Class values 1..nclasses used to turn probabilities into a score.
        r = np.arange(1, self.nclasses + 1)
        stop_train = False

        # Preparing data
        trainX, trainy, devX, devy, testX, testy = self.prepare_data(
            self.train['X'], self.train['y'],
            self.valid['X'], self.valid['y'],
            self.test['X'], self.test['y'])

        # Training
        while not stop_train and self.nepoch <= self.maxepoch:
            self.trainepoch(trainX, trainy, nepoches=50)
            yhat = np.dot(self.predict_proba(devX), r)
            pr = spearmanr(yhat, self.devscores)[0]
            pr = 0 if pr != pr else pr  # if NaN bc std=0
            # early stop on the dev Spearman correlation
            if pr > bestpr:
                bestpr = pr
                bestmodel = copy.deepcopy(self.model)
            elif self.early_stop:
                if early_stop_count >= 3:
                    stop_train = True
                early_stop_count += 1
        self.model = bestmodel

        yhat = np.dot(self.predict_proba(testX), r)

        return bestpr, yhat

    def trainepoch(self, X, y, nepoches=1):
        """Run ``nepoches`` shuffled passes over ``(X, y)``."""
        self.model.train()
        for _ in range(self.nepoch, self.nepoch + nepoches):
            permutation = np.random.permutation(len(X))
            for i in range(0, len(X), self.batch_size):
                # forward: index on the same device as the data
                idx = torch.from_numpy(
                    permutation[i:i + self.batch_size]).long().to(X.device)
                Xbatch = X[idx]
                ybatch = y[idx]
                output = self.model(Xbatch)
                # loss
                loss = self.loss_fn(output, ybatch)
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                # Update parameters
                self.optimizer.step()
        self.nepoch += nepoches

    def predict_proba(self, devX):
        """Return the model output for ``devX`` as one NumPy array.

        Returns an empty list when ``devX`` is empty.
        """
        self.model.eval()
        chunks = []
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size]
                chunks.append(self.model(Xbatch).data.cpu().numpy())
        if not chunks:
            return []
        return np.concatenate(chunks, axis=0)
|
SentEval/senteval/tools/validation.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
Validation and classification
|
| 10 |
+
(train) : inner-kfold classifier
|
| 11 |
+
(train, test) : kfold classifier
|
| 12 |
+
(train, dev, test) : split classifier
|
| 13 |
+
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
import numpy as np
|
| 19 |
+
from senteval.tools.classifier import MLP
|
| 20 |
+
|
| 21 |
+
import sklearn
|
| 22 |
+
assert(sklearn.__version__ >= "0.18.0"), \
|
| 23 |
+
"need to update sklearn to version >= 0.18.0"
|
| 24 |
+
from sklearn.linear_model import LogisticRegression
|
| 25 |
+
from sklearn.model_selection import StratifiedKFold
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_classif_name(classifier_config, usepytorch):
    """Return a short descriptive name for the configured classifier.

    Without PyTorch the name is always 'sklearn-LogReg'. With PyTorch the
    name encodes the hidden size, optimizer and batch size; keys missing
    from ``classifier_config`` fall back to nhid=0, optim='adam' and
    batch_size=64.
    """
    if not usepytorch:
        return 'sklearn-LogReg'
    # A missing 'nhid' is treated like the other optional keys instead of
    # raising KeyError.
    nhid = classifier_config.get('nhid', 0)
    optim = classifier_config.get('optim', 'adam')
    bs = classifier_config.get('batch_size', 64)
    return 'pytorch-MLP-nhid%s-%s-bs%s' % (nhid, optim, bs)
|
| 37 |
+
|
| 38 |
+
# Pytorch version
|
| 39 |
+
class InnerKFoldClassifier(object):
    """
    (train) split classifier : InnerKfold.

    Nested cross-validation: an outer stratified k-fold provides the test
    folds, and for each outer fold an inner stratified k-fold selects the
    L2 regularisation strength.
    """
    def __init__(self, X, y, config):
        """Store the data and read 'nclasses', 'seed', 'usepytorch',
        'classifier' and the optional 'kfold' (default 5) from ``config``."""
        self.X = X
        self.y = y
        self.featdim = X.shape[1]
        self.nclasses = config['nclasses']
        self.seed = config['seed']
        self.devresults = []
        self.testresults = []
        self.usepytorch = config['usepytorch']
        self.classifier_config = config['classifier']
        self.modelname = get_classif_name(self.classifier_config, self.usepytorch)

        self.k = 5 if 'kfold' not in config else config['kfold']

    def _train_classifier(self, reg, X, y, **mlp_fit_kwargs):
        """Build and fit one classifier with regularisation ``reg``.

        ``mlp_fit_kwargs`` (validation_data / validation_split) are only
        forwarded to the PyTorch MLP; the sklearn model ignores them.
        """
        if self.usepytorch:
            clf = MLP(self.classifier_config, inputdim=self.featdim,
                      nclasses=self.nclasses, l2reg=reg,
                      seed=self.seed)
            clf.fit(X, y, **mlp_fit_kwargs)
        else:
            clf = LogisticRegression(C=reg, random_state=self.seed)
            clf.fit(X, y)
        return clf

    def run(self):
        """Run the nested cross-validation.

        Returns ``(devaccuracy, testaccuracy)``: the mean best inner-fold
        accuracy and the mean outer-fold test accuracy, both in percent and
        rounded to two decimals.
        """
        logging.info('Training {0} with (inner) {1}-fold cross-validation'
                     .format(self.modelname, self.k))

        regs = [10**t for t in range(-5, -1)] if self.usepytorch else \
               [2**t for t in range(-2, 4, 1)]
        skf = StratifiedKFold(n_splits=self.k, shuffle=True, random_state=1111)
        innerskf = StratifiedKFold(n_splits=self.k, shuffle=True,
                                   random_state=1111)
        count = 0
        for train_idx, test_idx in skf.split(self.X, self.y):
            count += 1
            X_train, X_test = self.X[train_idx], self.X[test_idx]
            y_train, y_test = self.y[train_idx], self.y[test_idx]
            scores = []
            for reg in regs:
                regscores = []
                for inner_train_idx, inner_test_idx in innerskf.split(X_train, y_train):
                    X_in_train, X_in_test = X_train[inner_train_idx], X_train[inner_test_idx]
                    y_in_train, y_in_test = y_train[inner_train_idx], y_train[inner_test_idx]
                    clf = self._train_classifier(
                        reg, X_in_train, y_in_train,
                        validation_data=(X_in_test, y_in_test))
                    regscores.append(clf.score(X_in_test, y_in_test))
                scores.append(round(100*np.mean(regscores), 2))
            optreg = regs[np.argmax(scores)]
            # Adjacent literals keep the source indentation out of the
            # logged message (a backslash inside the string embeds it).
            logging.info('Best param found at split {0}: l2reg = {1} '
                         'with score {2}'.format(count, optreg, np.max(scores)))
            self.devresults.append(np.max(scores))

            # Refit on the whole outer training fold with the selected reg.
            clf = self._train_classifier(optreg, X_train, y_train,
                                         validation_split=0.05)

            self.testresults.append(round(100*clf.score(X_test, y_test), 2))

        devaccuracy = round(np.mean(self.devresults), 2)
        testaccuracy = round(np.mean(self.testresults), 2)
        return devaccuracy, testaccuracy
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class KFoldClassifier(object):
    """
    (train, test) split classifier : cross-validation on train.

    The regularisation value is chosen by stratified k-fold cross-validation
    on the training data; a classifier is then refit on the whole training
    set with the selected value and scored on the test set.
    """
    def __init__(self, train, test, config):
        """
        train, test: mappings holding a feature matrix under 'X' and the
            labels under 'y' (both are indexed with fold index arrays).
        config: mapping with 'nclasses', 'seed', 'usepytorch', 'classifier'
            and optionally 'kfold' (number of folds, 5 when absent).
        """
        self.train = train
        self.test = test
        # Feature dimension is the second axis of the training matrix.
        self.featdim = self.train['X'].shape[1]
        self.nclasses = config['nclasses']
        self.seed = config['seed']
        self.usepytorch = config['usepytorch']
        self.classifier_config = config['classifier']
        self.modelname = get_classif_name(self.classifier_config, self.usepytorch)

        self.k = config['kfold'] if 'kfold' in config else 5

    def _fit_classifier(self, reg, X, y, **fit_kwargs):
        """Build and fit one classifier with regularisation value `reg`.

        `fit_kwargs` (validation_data / validation_split) are only forwarded
        to the MLP; LogisticRegression is fit on (X, y) alone.
        """
        if self.usepytorch:
            clf = MLP(self.classifier_config, inputdim=self.featdim,
                      nclasses=self.nclasses, l2reg=reg,
                      seed=self.seed)
            clf.fit(X, y, **fit_kwargs)
        else:
            clf = LogisticRegression(C=reg, random_state=self.seed)
            clf.fit(X, y)
        return clf

    def run(self):
        """
        Select the regularisation by cross-validation, then evaluate.

        Returns (devaccuracy, testaccuracy, yhat): the best mean
        cross-validation accuracy, the test accuracy (both in percent,
        rounded to 2 decimals) and the predictions on the test features.
        """
        # cross-validation
        logging.info('Training {0} with {1}-fold cross-validation'
                     .format(self.modelname, self.k))
        # Grid is passed as `l2reg` to the MLP and as `C` to LogisticRegression.
        if self.usepytorch:
            regs = [10**t for t in range(-5, -1)]
        else:
            regs = [2**t for t in range(-1, 6, 1)]
        skf = StratifiedKFold(n_splits=self.k, shuffle=True,
                              random_state=self.seed)
        scores = []

        for reg in regs:
            scanscores = []
            for train_idx, test_idx in skf.split(self.train['X'],
                                                 self.train['y']):
                # Split data
                X_train, y_train = self.train['X'][train_idx], self.train['y'][train_idx]
                X_test, y_test = self.train['X'][test_idx], self.train['y'][test_idx]

                # Train classifier; the held-out fold doubles as validation
                # data for the MLP.
                clf = self._fit_classifier(reg, X_train, y_train,
                                           validation_data=(X_test, y_test))
                scanscores.append(clf.score(X_test, y_test))
            # Append mean score
            scores.append(round(100*np.mean(scanscores), 2))

        # evaluation
        logging.info([('reg:' + str(reg), score)
                      for reg, score in zip(regs, scores)])
        optreg = regs[np.argmax(scores)]
        devaccuracy = np.max(scores)
        logging.info('Cross-validation : best param found is reg = {0} with score {1}'.format(optreg, devaccuracy))

        logging.info('Evaluating...')
        # Refit on the full training set with the selected regularisation.
        clf = self._fit_classifier(optreg, self.train['X'], self.train['y'],
                                   validation_split=0.05)
        yhat = clf.predict(self.test['X'])

        testaccuracy = clf.score(self.test['X'], self.test['y'])
        testaccuracy = round(100*testaccuracy, 2)

        return devaccuracy, testaccuracy, yhat
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class SplitClassifier(object):
    """
    (train, valid, test) split classifier.

    The regularisation value is selected on the validation split; the
    classifier is then retrained with it and scored on the test split.
    """
    def __init__(self, X, y, config):
        """
        X, y: mappings with 'train', 'valid' and 'test' entries (features
            and labels respectively).
        config: mapping with 'nclasses', 'seed', 'usepytorch', 'classifier'
            and optionally 'cudaEfficient' and 'noreg' (both False when
            absent).
        """
        self.X = X
        self.y = y
        self.nclasses = config['nclasses']
        # Feature dimension is the second axis of the training matrix.
        self.featdim = self.X['train'].shape[1]
        self.seed = config['seed']
        self.usepytorch = config['usepytorch']
        self.classifier_config = config['classifier']
        self.cudaEfficient = config['cudaEfficient'] \
            if 'cudaEfficient' in config else False
        self.modelname = get_classif_name(self.classifier_config, self.usepytorch)
        self.noreg = config['noreg'] if 'noreg' in config else False
        self.config = config

    def _fit_classifier(self, reg):
        """Build a classifier with regularisation `reg` and fit it on train.

        The MLP additionally receives the validation split as
        validation_data; LogisticRegression is fit on the train split only.
        """
        if self.usepytorch:
            clf = MLP(self.classifier_config, inputdim=self.featdim,
                      nclasses=self.nclasses, l2reg=reg,
                      seed=self.seed, cudaEfficient=self.cudaEfficient)

            # TODO: Find a hack for reducing nb epoches in SNLI
            clf.fit(self.X['train'], self.y['train'],
                    validation_data=(self.X['valid'], self.y['valid']))
        else:
            clf = LogisticRegression(C=reg, random_state=self.seed)
            clf.fit(self.X['train'], self.y['train'])
        return clf

    def run(self):
        """
        Tune on the validation split and evaluate on the test split.

        Returns (devaccuracy, testaccuracy), both in percent rounded to
        2 decimals.
        """
        logging.info('Training {0} with standard validation..'
                     .format(self.modelname))
        # Grid is passed as `l2reg` to the MLP and as `C` to LogisticRegression.
        if self.usepytorch:
            regs = [10**t for t in range(-5, -1)]
        else:
            regs = [2**t for t in range(-2, 4, 1)]
        if self.noreg:
            # Single value instead of a grid: tiny l2reg for the MLP,
            # huge C for LogisticRegression.
            regs = [1e-9 if self.usepytorch else 1e9]

        scores = []
        for reg in regs:
            clf = self._fit_classifier(reg)
            scores.append(round(100*clf.score(self.X['valid'],
                                              self.y['valid']), 2))
        logging.info([('reg:'+str(reg), score)
                      for reg, score in zip(regs, scores)])
        optreg = regs[np.argmax(scores)]
        devaccuracy = np.max(scores)
        logging.info('Validation : best param found is reg = {0} with score {1}'.format(optreg, devaccuracy))

        logging.info('Evaluating...')
        # Retrain with the selected regularisation before scoring on test.
        clf = self._fit_classifier(optreg)

        testaccuracy = clf.score(self.X['test'], self.y['test'])
        testaccuracy = round(100*testaccuracy, 2)
        return devaccuracy, testaccuracy
|
SentEval/senteval/trec.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
'''
|
| 9 |
+
TREC question-type classification
|
| 10 |
+
'''
|
| 11 |
+
|
| 12 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import io
|
| 16 |
+
import logging
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
from senteval.tools.validation import KFoldClassifier
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TRECEval(object):
    """TREC question-type classification transfer task (6 classes)."""

    def __init__(self, task_path, seed=1111):
        """Load the train and test label files found under `task_path`."""
        logging.info('***** Transfer task : TREC *****\n\n')
        self.seed = seed
        self.train = self.loadFile(os.path.join(task_path, 'train_5500.label'))
        self.test = self.loadFile(os.path.join(task_path, 'TREC_10.label'))

    def do_prepare(self, params, prepare):
        """Call `prepare` with every tokenised sample (train then test)."""
        samples = self.train['X'] + self.test['X']
        return prepare(params, samples)

    def loadFile(self, fpath):
        """
        Parse a TREC label file into {'X': token lists, 'y': class ids}.

        Each non-empty line is split at the first ':'; the part before it is
        the class name, and the first space-delimited field after it is
        dropped before the remainder is whitespace-tokenised. Blank lines
        are skipped. An unknown class name fails the assertion.
        """
        trec_data = {'X': [], 'y': []}
        tgt2idx = {'ABBR': 0, 'DESC': 1, 'ENTY': 2,
                   'HUM': 3, 'LOC': 4, 'NUM': 5}
        with io.open(fpath, 'r', encoding='latin-1') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # A blank (e.g. trailing) line has no "target:sample"
                    # structure and would otherwise fail to unpack.
                    continue
                target, sample = line.split(':', 1)
                sample = sample.split(' ', 1)[1].split()
                assert target in tgt2idx, target
                trec_data['X'].append(sample)
                trec_data['y'].append(tgt2idx[target])
        return trec_data

    @staticmethod
    def _sort_by_length(data):
        """Return (samples, labels) ordered by (sample length, label)."""
        ordered = sorted(zip(data['X'], data['y']),
                         key=lambda z: (len(z[0]), z[1]))
        samples = [x for (x, y) in ordered]
        labels = [y for (x, y) in ordered]
        return samples, labels

    @staticmethod
    def _embed(params, batcher, samples):
        """Run `batcher` over `samples` in batches and stack the outputs."""
        embeddings = []
        for ii in range(0, len(samples), params.batch_size):
            batch = samples[ii:ii + params.batch_size]
            embeddings.append(batcher(params, batch))
        return np.vstack(embeddings)

    def run(self, params, batcher):
        """
        Embed both splits with `batcher` and run the k-fold classifier.

        Returns a dict with 'devacc', 'acc', 'ndev' and 'ntest'.
        """
        # Sort to reduce padding
        train_samples, train_labels = self._sort_by_length(self.train)
        test_samples, test_labels = self._sort_by_length(self.test)

        # Get train embeddings
        train_embeddings = self._embed(params, batcher, train_samples)
        logging.info('Computed train embeddings')

        # Get test embeddings
        test_embeddings = self._embed(params, batcher, test_samples)
        logging.info('Computed test embeddings')

        config_classifier = {'nclasses': 6, 'seed': self.seed,
                             'usepytorch': params.usepytorch,
                             'classifier': params.classifier,
                             'kfold': params.kfold}
        clf = KFoldClassifier({'X': train_embeddings,
                               'y': np.array(train_labels)},
                              {'X': test_embeddings,
                               'y': np.array(test_labels)},
                              config_classifier)
        devacc, testacc, _ = clf.run()
        logging.debug('\nDev acc : {0} Test acc : {1} for TREC\n'.format(devacc, testacc))
        return {'devacc': devacc, 'acc': testacc,
                'ndev': len(self.train['X']), 'ntest': len(self.test['X'])}
|
SentEval/senteval/utils.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
from __future__ import absolute_import, division, unicode_literals
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import re
|
| 12 |
+
import inspect
|
| 13 |
+
from torch import optim
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def create_dictionary(sentences):
    """
    Build a frequency-ordered vocabulary from tokenised sentences.

    sentences: iterable of token sequences.
    Returns (id2word, word2id): the list of words sorted by decreasing
    count, and the mapping from each word to its index in that list.
    The special tokens '<s>', '</s>' and '<p>' are always present and,
    because of their artificially large counts, come first in that order.
    """
    words = {}
    for s in sentences:
        for word in s:
            words[word] = words.get(word, 0) + 1
    # Counts above any real frequency pin the special tokens to ids 0, 1, 2.
    words['<s>'] = 1e9 + 4
    words['</s>'] = 1e9 + 3
    words['<p>'] = 1e9 + 2
    # words['<UNK>'] = 1e9 + 1
    sorted_words = sorted(words.items(), key=lambda x: -x[1])  # inverse sort
    id2word = [w for w, _ in sorted_words]
    word2id = {w: i for i, w in enumerate(id2word)}

    return id2word, word2id
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def cosine(u, v):
    """Return the dot product of u and v divided by the product of their norms."""
    norm_product = np.linalg.norm(u) * np.linalg.norm(v)
    return np.dot(u, v) / norm_product
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class dotdict(dict):
    """ dot.notation access to dictionary attributes

    Attribute reads, writes and deletes are aliased directly to the dict
    item methods below. Names that resolve through normal attribute lookup
    (such as the dict methods themselves) are not affected by __getattr__.
    """
    # d.name -> d.get('name'): a missing key yields None, not AttributeError.
    __getattr__ = dict.get
    # d.name = value -> d['name'] = value
    __setattr__ = dict.__setitem__
    # del d.name -> del d['name']: a missing key raises KeyError.
    __delattr__ = dict.__delitem__
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_optimizer(s):
    """
    Parse optimizer parameters.
    Input should be of the form:
        - "sgd,lr=0.01"
        - "adagrad,lr=0.1,lr_decay=0.05"

    Returns (optim_fn, optim_params): the torch.optim optimizer class and a
    dict mapping each parameter name to its float value. Values may use
    plain decimal or scientific notation (e.g. "lr=1e-3").

    Raises AssertionError for a malformed "name=value" item or for "sgd"
    without "lr", and Exception for an unknown method or for a parameter
    the optimizer constructor does not accept.
    """
    method, sep, param_str = s.partition(',')
    optim_params = {}
    if sep:
        for x in param_str.split(','):
            split = x.split('=')
            assert len(split) == 2
            assert re.match(r"^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$",
                            split[1]) is not None
            optim_params[split[0]] = float(split[1])

    optimizers = {
        'adadelta': optim.Adadelta,
        'adagrad': optim.Adagrad,
        'adam': optim.Adam,
        'adamax': optim.Adamax,
        'asgd': optim.ASGD,
        'rmsprop': optim.RMSprop,
        'rprop': optim.Rprop,
        'sgd': optim.SGD,
    }
    if method not in optimizers:
        raise Exception('Unknown optimization method: "%s"' % method)
    optim_fn = optimizers[method]
    if method == 'sgd':
        assert 'lr' in optim_params

    # check that we give good parameters to the optimizer
    # inspect.getargspec was removed in Python 3.11; prefer getfullargspec
    # and only fall back to getargspec where it is the sole option.
    argspec_fn = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    spec = argspec_fn(optim_fn.__init__)
    expected_args = list(spec[0])
    assert expected_args[:2] == ['self', 'params']
    # Keyword-only constructor arguments are valid optimizer parameters too.
    accepted = expected_args[2:] + list(getattr(spec, 'kwonlyargs', None) or [])
    if not all(k in accepted for k in optim_params.keys()):
        raise Exception('Unexpected parameters: expected "%s", got "%s"' % (
            str(accepted), str(optim_params.keys())))

    return optim_fn, optim_params
|
SentEval/setup.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
import io
|
| 9 |
+
from setuptools import setup, find_packages
|
| 10 |
+
|
| 11 |
+
# The README doubles as the package's long description.
with io.open('./README.md', encoding='utf-8') as readme_file:
    readme = readme_file.read()

# Everything except the example scripts is packaged.
packages = find_packages(exclude=['examples'])

setup(
    name='SentEval',
    version='0.1.0',
    url='https://github.com/facebookresearch/SentEval',
    packages=packages,
    license='Attribution-NonCommercial 4.0 International',
    long_description=readme,
)
|
data/._data_csv_default-6b8a73dfc1f26733_0.0.0_6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317.lock
ADDED
|
File without changes
|
data/csv/default-6b8a73dfc1f26733/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317.incomplete_info.lock
ADDED
|
File without changes
|
data/csv/default-6b8a73dfc1f26733/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-e43d857791056f6f.arrow
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c38cded0e7f6b3da19c50174db1a0260c4a306b848171685d77ae7fcf6358bb8
|
| 3 |
+
size 2136
|
data/csv/default-6b8a73dfc1f26733/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/csv-train.arrow
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:571626bc0264e30c47b15fb4f1ed4ce55fc3aa078e88fe00562ab13f2cec5583
|
| 3 |
+
size 600
|
data/csv/default-6b8a73dfc1f26733/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/dataset_info.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"description": "", "citation": "", "homepage": "", "license": "", "features": {"version https://git-lfs.github.com/spec/v1": {"dtype": "string", "_type": "Value"}}, "builder_name": "csv", "config_name": "default", "version": {"version_str": "0.0.0", "major": 0, "minor": 0, "patch": 0}, "splits": {"train": {"name": "train", "num_bytes": 96, "num_examples": 2, "dataset_name": "csv"}}, "download_checksums": {"/home/perk/models/SimCSE-test/data/mnli_no_for_simcse.csv": {"num_bytes": 133, "checksum": "e98d34ec65c4e9843be795896c6f82c6d0b8e7379c7d755bb600f534e336097a"}}, "download_size": 133, "dataset_size": 96, "size_in_bytes": 229}
|
data/csv/default-6b8a73dfc1f26733/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317_builder.lock
ADDED
|
File without changes
|
result/sup-simcse-nb-bert-base/config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "NbAiLab/nb-bert-base",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"BertForCL"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.1,
|
| 7 |
+
"directionality": "bidi",
|
| 8 |
+
"gradient_checkpointing": false,
|
| 9 |
+
"hidden_act": "gelu",
|
| 10 |
+
"hidden_dropout_prob": 0.1,
|
| 11 |
+
"hidden_size": 768,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 3072,
|
| 14 |
+
"layer_norm_eps": 1e-12,
|
| 15 |
+
"max_position_embeddings": 512,
|
| 16 |
+
"model_type": "bert",
|
| 17 |
+
"num_attention_heads": 12,
|
| 18 |
+
"num_hidden_layers": 12,
|
| 19 |
+
"pad_token_id": 0,
|
| 20 |
+
"pooler_fc_size": 768,
|
| 21 |
+
"pooler_num_attention_heads": 12,
|
| 22 |
+
"pooler_num_fc_layers": 3,
|
| 23 |
+
"pooler_size_per_head": 128,
|
| 24 |
+
"pooler_type": "first_token_transform",
|
| 25 |
+
"position_embedding_type": "absolute",
|
| 26 |
+
"transformers_version": "4.2.1",
|
| 27 |
+
"type_vocab_size": 2,
|
| 28 |
+
"use_cache": true,
|
| 29 |
+
"vocab_size": 119547,
|
| 30 |
+
"xla_device": true
|
| 31 |
+
}
|
result/sup-simcse-nb-bert-base/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:96826d1d8fe607692b7bba8bace83fffbfd4ffe259633100b83a5701c9e05ea5
|
| 3 |
+
size 711481329
|
result/sup-simcse-nb-bert-base/special_tokens_map.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
|
result/sup-simcse-nb-bert-base/tokenizer_config.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"do_lower_case": false, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "special_tokens_map_file": null, "name_or_path": "NbAiLab/nb-bert-base", "do_basic_tokenize": true, "never_split": null}
|
result/sup-simcse-nb-bert-base/train_results.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7459efd343d212b0aacd508abd903c6fbec4d41b37e6ac1133f57db2e79965df
|
| 3 |
+
size 68
|
result/sup-simcse-nb-bert-base/trainer_state.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_metric": null,
|
| 3 |
+
"best_model_checkpoint": null,
|
| 4 |
+
"epoch": 3.0,
|
| 5 |
+
"global_step": 3,
|
| 6 |
+
"is_hyper_param_search": false,
|
| 7 |
+
"is_local_process_zero": true,
|
| 8 |
+
"is_world_process_zero": true,
|
| 9 |
+
"log_history": [
|
| 10 |
+
{
|
| 11 |
+
"epoch": 3.0,
|
| 12 |
+
"step": 3,
|
| 13 |
+
"train_runtime": 0.3583,
|
| 14 |
+
"train_samples_per_second": 8.373
|
| 15 |
+
}
|
| 16 |
+
],
|
| 17 |
+
"max_steps": 3,
|
| 18 |
+
"num_train_epochs": 3,
|
| 19 |
+
"total_flos": 409774325760,
|
| 20 |
+
"trial_name": null,
|
| 21 |
+
"trial_params": null
|
| 22 |
+
}
|
result/sup-simcse-nb-bert-base/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:240f693f0920769bbe8786c88d43cae1f7c134ac3692d3396d9df47cb9ae14e4
|
| 3 |
+
size 2095
|
result/sup-simcse-nb-bert-base/vocab.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fe0fda7c425b48c516fc8f160d594c8022a0808447475c1a7c6d6479763f310c
|
| 3 |
+
size 995526
|
runs/Oct21_13-13-50_t1v-n-d0240692-w-0/1666358047.7059593/events.out.tfevents.1666358047.t1v-n-d0240692-w-0.37317.1
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9d3b24404456d0b6365784d0b4d4104a121578d0149c86e8eccbb3efb7767b6f
|
| 3 |
+
size 3146
|
runs/Oct21_13-13-50_t1v-n-d0240692-w-0/events.out.tfevents.1666358047.t1v-n-d0240692-w-0.37317.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2e9cbc308f95994c70c4d65ff6bc29bbc13432ab236c5e648ece19d0d0a46197
|
| 3 |
+
size 2738
|
runs/Oct21_13-17-52_t1v-n-d0240692-w-0/1666358281.579476/events.out.tfevents.1666358281.t1v-n-d0240692-w-0.41386.1
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6cb2e2ceb22add60dc558d997a8b51dff053c36da73040f3de7f87e90f2b53b4
|
| 3 |
+
size 3146
|
runs/Oct21_13-17-52_t1v-n-d0240692-w-0/events.out.tfevents.1666358281.t1v-n-d0240692-w-0.41386.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7c3d121de020f204bf8afa1dd2e791ae646fd8c9a0fb5c6df975557f0f966ed1
|
| 3 |
+
size 2738
|