allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large
- allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/.gitattributes +35 -0
- allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/README.md +89 -0
- allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/config.json +19 -0
- allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/export/allennlp/nn/util.py +0 -0
- allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/export/allennlp_models/coref/models/coref.py +943 -0
- allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/export/example.py +16 -0
- allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/export/export_onnx.py +166 -0
- allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/model.onnx +3 -0
- allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/source.txt +1 -0
- allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/tokenizer.json +0 -0
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/README.md
ADDED
@@ -0,0 +1,89 @@
---
license: mit
base_model:
- nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large
pipeline_tag: token-classification
tags:
- coreference-resolution
- multilingual
- onnx
---

## Usage

```sh
$ pip install coref-onnx
```

```python
from coref_onnx import CoreferenceResolver, decode_clusters

resolver = CoreferenceResolver.from_pretrained("talmago/allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large")

sentences = [
    ["Barack", "Obama", "was", "the", "44th", "President", "of", "the", "United", "States", "."],
    ["He", "was", "born", "in", "Hawaii", "."]
]

pred = resolver(sentences)

print("Clusters:", pred["clusters"][0])
print("Decoded clusters:", decode_clusters(sentences, pred["clusters"][0]))
```

The output is:

```
Clusters: [[(0, 1), (11, 11)]]
Decoded clusters: [['Barack Obama', 'He']]
```

## ONNX

Download the MiniLM model archive:

```sh
$ mkdir -p models/minillm
$ wget -P models/minillm https://storage.googleapis.com/pandora-intelligence/models/crosslingual-coreference/minilm/model.tar.gz
```

Run the Docker container:

```sh
$ docker run -it --platform linux/amd64 --entrypoint /bin/bash -v $(pwd)/models/minillm:/models/minillm allennlp/allennlp:latest
```

Install `allennlp_models`:

```sh
$ pip install allennlp_models
```

In another tab, copy the patched source files and the export script into the container:

```sh
$ docker cp allennlp_models/coref/models/coref.py <CONTAINER_ID>:/opt/conda/lib/python3.8/site-packages/allennlp_models/coref/models/coref.py
$ docker cp allennlp/nn/util.py <CONTAINER_ID>:/stage/allennlp/allennlp/nn/util.py
$ docker cp export_onnx.py <CONTAINER_ID>:/app/export_onnx.py
```

In the container, download the base transformer:

```sh
$ mkdir nreimers
$ git clone https://huggingface.co/nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large nreimers
```

Then run the export script:

```sh
$ python export_onnx.py
```

## Model Optimization

Run `onnxsim`:

```sh
$ python -m onnxsim models/minillm/model.onnx optimized_model.onnx
```
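For reference, if you want to inspect `model.onnx` directly rather than going through `coref-onnx`, a minimal `onnxruntime` sketch (not part of this repo) looks like the following; the graph's input and output names are not assumed, since they are read back from the session:

```python
import onnxruntime as ort

# Open the exported graph on CPU; swap in CUDAExecutionProvider if available.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Print the graph signature to see exactly what the coref-onnx wrapper
# has to feed the model and what it gets back.
for inp in session.get_inputs():
    print("input:", inp.name, inp.shape, inp.type)
for out in session.get_outputs():
    print("output:", out.name, out.shape, out.type)
```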
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/config.json
ADDED
@@ -0,0 +1,19 @@
{
  "max_span_width": 5,
  "max_spans": 512,
  "spans_per_word": 0.4,
  "max_antecedents": 50,
  "inference_order": 2,
  "feature_size": 20,
  "coarse_to_fine": true,
  "model_hidden_dims": 1500,
  "model_dropout": 0.3,
  "mention_input_dim": 1172,
  "antecedent_input_dim": 3536,
  "transformer": {
    "model_name": "nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large",
    "hidden_size": 384,
    "max_position_embeddings": 512
  },
  "max_sentences": 120
}
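These values feed the pruning arithmetic in `coref.py` below: the model keeps roughly `spans_per_word` spans per word of the document, and each kept span considers at most `max_antecedents` prior spans. A small sketch of that arithmetic, assuming a hypothetical 300-token document (the document length is made up for illustration; in the actual model the span count is additionally capped by the number of candidate spans):

```python
import math

# Values from config.json above.
spans_per_word = 0.4
max_antecedents = 50

document_length = 300  # hypothetical document length, for illustration only

# Mirrors CoreferenceResolver.forward below: keep floor(0.4 * 300) = 120 spans
# after mention scoring, then let each kept span look at up to
# min(50, 120) = 50 prior spans as antecedent candidates.
num_spans_to_keep = int(math.floor(spans_per_word * document_length))
max_antecedents = min(max_antecedents, num_spans_to_keep)
print(num_spans_to_keep, max_antecedents)  # -> 120 50
```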
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/export/allennlp/nn/util.py
ADDED
The diff for this file is too large to render. See raw diff.
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/export/allennlp_models/coref/models/coref.py
ADDED
@@ -0,0 +1,943 @@
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
from typing import Any, Dict, List, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
from allennlp.data import TextFieldTensors, Vocabulary
|
| 10 |
+
from allennlp.models.model import Model
|
| 11 |
+
from allennlp.modules.token_embedders import Embedding
|
| 12 |
+
from allennlp.modules import FeedForward, GatedSum
|
| 13 |
+
from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder
|
| 14 |
+
from allennlp.modules.span_extractors import (
|
| 15 |
+
SelfAttentiveSpanExtractor,
|
| 16 |
+
EndpointSpanExtractor,
|
| 17 |
+
)
|
| 18 |
+
from allennlp.nn import util, InitializerApplicator
|
| 19 |
+
|
| 20 |
+
from allennlp_models.coref.metrics.conll_coref_scores import ConllCorefScores
|
| 21 |
+
from allennlp_models.coref.metrics.mention_recall import MentionRecall
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@Model.register("coref")
|
| 27 |
+
class CoreferenceResolver(Model):
|
| 28 |
+
"""
|
| 29 |
+
This `Model` implements the coreference resolution model described in
|
| 30 |
+
[Higher-order Coreference Resolution with Coarse-to-fine Inference](https://arxiv.org/pdf/1804.05392.pdf)
|
| 31 |
+
by Lee et al., 2018.
|
| 32 |
+
The basic outline of this model is to get an embedded representation of each span in the
|
| 33 |
+
document. These span representations are scored and used to prune away spans that are unlikely
|
| 34 |
+
to occur in a coreference cluster. For the remaining spans, the model decides which antecedent
|
| 35 |
+
span (if any) they are coreferent with. The resulting coreference links, after applying
|
| 36 |
+
transitivity, imply a clustering of the spans in the document.
|
| 37 |
+
|
| 38 |
+
# Parameters
|
| 39 |
+
|
| 40 |
+
vocab : `Vocabulary`
|
| 41 |
+
text_field_embedder : `TextFieldEmbedder`
|
| 42 |
+
Used to embed the `text` `TextField` we get as input to the model.
|
| 43 |
+
context_layer : `Seq2SeqEncoder`
|
| 44 |
+
This layer incorporates contextual information for each word in the document.
|
| 45 |
+
mention_feedforward : `FeedForward`
|
| 46 |
+
This feedforward network is applied to the span representations which is then scored
|
| 47 |
+
by a linear layer.
|
| 48 |
+
antecedent_feedforward : `FeedForward`
|
| 49 |
+
This feedforward network is applied to pairs of span representation, along with any
|
| 50 |
+
pairwise features, which is then scored by a linear layer.
|
| 51 |
+
feature_size : `int`
|
| 52 |
+
The embedding size for all the embedded features, such as distances or span widths.
|
| 53 |
+
max_span_width : `int`
|
| 54 |
+
The maximum width of candidate spans.
|
| 55 |
+
spans_per_word: `float`, required.
|
| 56 |
+
A multiplier between zero and one which controls what percentage of candidate mention
|
| 57 |
+
spans we retain with respect to the number of words in the document.
|
| 58 |
+
max_antecedents: `int`, required.
|
| 59 |
+
For each mention which survives the pruning stage, we consider this many antecedents.
|
| 60 |
+
coarse_to_fine: `bool`, optional (default = `False`)
|
| 61 |
+
Whether or not to apply the coarse-to-fine filtering.
|
| 62 |
+
inference_order: `int`, optional (default = `1`)
|
| 63 |
+
The number of inference orders. When greater than 1, the span representations are
|
| 64 |
+
updated and coreference scores re-computed.
|
| 65 |
+
lexical_dropout : `int`
|
| 66 |
+
The probability of dropping out dimensions of the embedded text.
|
| 67 |
+
initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
|
| 68 |
+
Used to initialize the model parameters.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def __init__(
|
| 72 |
+
self,
|
| 73 |
+
vocab: Vocabulary,
|
| 74 |
+
text_field_embedder: TextFieldEmbedder,
|
| 75 |
+
context_layer: Seq2SeqEncoder,
|
| 76 |
+
mention_feedforward: FeedForward,
|
| 77 |
+
antecedent_feedforward: FeedForward,
|
| 78 |
+
feature_size: int,
|
| 79 |
+
max_span_width: int,
|
| 80 |
+
spans_per_word: float,
|
| 81 |
+
max_antecedents: int,
|
| 82 |
+
coarse_to_fine: bool = False,
|
| 83 |
+
inference_order: int = 1,
|
| 84 |
+
lexical_dropout: float = 0.2,
|
| 85 |
+
initializer: InitializerApplicator = InitializerApplicator(),
|
| 86 |
+
**kwargs,
|
| 87 |
+
) -> None:
|
| 88 |
+
super().__init__(vocab, **kwargs)
|
| 89 |
+
|
| 90 |
+
self._text_field_embedder = text_field_embedder
|
| 91 |
+
self._context_layer = context_layer
|
| 92 |
+
self._mention_feedforward = TimeDistributed(mention_feedforward)
|
| 93 |
+
self._mention_scorer = TimeDistributed(
|
| 94 |
+
torch.nn.Linear(mention_feedforward.get_output_dim(), 1)
|
| 95 |
+
)
|
| 96 |
+
self._antecedent_feedforward = TimeDistributed(antecedent_feedforward)
|
| 97 |
+
self._antecedent_scorer = TimeDistributed(
|
| 98 |
+
torch.nn.Linear(antecedent_feedforward.get_output_dim(), 1)
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
self._endpoint_span_extractor = EndpointSpanExtractor(
|
| 102 |
+
context_layer.get_output_dim(),
|
| 103 |
+
combination="x,y",
|
| 104 |
+
num_width_embeddings=max_span_width,
|
| 105 |
+
span_width_embedding_dim=feature_size,
|
| 106 |
+
bucket_widths=False,
|
| 107 |
+
)
|
| 108 |
+
self._attentive_span_extractor = SelfAttentiveSpanExtractor(
|
| 109 |
+
input_dim=text_field_embedder.get_output_dim()
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# 10 possible distance buckets.
|
| 113 |
+
self._num_distance_buckets = 10
|
| 114 |
+
self._distance_embedding = Embedding(
|
| 115 |
+
embedding_dim=feature_size, num_embeddings=self._num_distance_buckets
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
self._max_span_width = max_span_width
|
| 119 |
+
self._spans_per_word = spans_per_word
|
| 120 |
+
self._max_antecedents = max_antecedents
|
| 121 |
+
|
| 122 |
+
self._coarse_to_fine = coarse_to_fine
|
| 123 |
+
if self._coarse_to_fine:
|
| 124 |
+
self._coarse2fine_scorer = torch.nn.Linear(
|
| 125 |
+
mention_feedforward.get_input_dim(), mention_feedforward.get_input_dim()
|
| 126 |
+
)
|
| 127 |
+
self._inference_order = inference_order
|
| 128 |
+
if self._inference_order > 1:
|
| 129 |
+
self._span_updating_gated_sum = GatedSum(
|
| 130 |
+
mention_feedforward.get_input_dim()
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
self._mention_recall = MentionRecall()
|
| 134 |
+
self._conll_coref_scores = ConllCorefScores()
|
| 135 |
+
if lexical_dropout > 0:
|
| 136 |
+
self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
|
| 137 |
+
else:
|
| 138 |
+
self._lexical_dropout = lambda x: x
|
| 139 |
+
initializer(self)
|
| 140 |
+
|
| 141 |
+
def forward(
|
| 142 |
+
self, # type: ignore
|
| 143 |
+
text: TextFieldTensors,
|
| 144 |
+
spans: torch.IntTensor,
|
| 145 |
+
span_labels: torch.IntTensor = None,
|
| 146 |
+
metadata: List[Dict[str, Any]] = None,
|
| 147 |
+
) -> Dict[str, torch.Tensor]:
|
| 148 |
+
"""
|
| 149 |
+
# Parameters
|
| 150 |
+
|
| 151 |
+
text : `TextFieldTensors`, required.
|
| 152 |
+
The output of a `TextField` representing the text of
|
| 153 |
+
the document.
|
| 154 |
+
spans : `torch.IntTensor`, required.
|
| 155 |
+
A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
|
| 156 |
+
indices of candidate spans for mentions. Comes from a `ListField[SpanField]` of
|
| 157 |
+
indices into the text of the document.
|
| 158 |
+
span_labels : `torch.IntTensor`, optional (default = `None`).
|
| 159 |
+
A tensor of shape (batch_size, num_spans), representing the cluster ids
|
| 160 |
+
of each span, or -1 for those which do not appear in any clusters.
|
| 161 |
+
metadata : `List[Dict[str, Any]]`, optional (default = `None`).
|
| 162 |
+
A metadata dictionary for each instance in the batch. We use the "original_text" and "clusters" keys
|
| 163 |
+
from this dictionary, which respectively have the original text and the annotated gold coreference
|
| 164 |
+
clusters for that instance.
|
| 165 |
+
|
| 166 |
+
# Returns
|
| 167 |
+
|
| 168 |
+
An output dictionary consisting of:
|
| 169 |
+
|
| 170 |
+
top_spans : `torch.IntTensor`
|
| 171 |
+
A tensor of shape `(batch_size, num_spans_to_keep, 2)` representing
|
| 172 |
+
the start and end word indices of the top spans that survived the pruning stage.
|
| 173 |
+
antecedent_indices : `torch.IntTensor`
|
| 174 |
+
A tensor of shape `(num_spans_to_keep, max_antecedents)` representing for each top span
|
| 175 |
+
the index (with respect to top_spans) of the possible antecedents the model considered.
|
| 176 |
+
predicted_antecedents : `torch.IntTensor`
|
| 177 |
+
A tensor of shape `(batch_size, num_spans_to_keep)` representing, for each top span, the
|
| 178 |
+
index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
|
| 179 |
+
was no predicted link.
|
| 180 |
+
loss : `torch.FloatTensor`, optional
|
| 181 |
+
A scalar loss to be optimised.
|
| 182 |
+
"""
|
| 183 |
+
# Shape: (batch_size, document_length, embedding_size)
|
| 184 |
+
text_embeddings = self._lexical_dropout(self._text_field_embedder(text))
|
| 185 |
+
|
| 186 |
+
batch_size = spans.shape[0]
|
| 187 |
+
document_length = text_embeddings.shape[1]
|
| 188 |
+
num_spans = spans.shape[1]
|
| 189 |
+
|
| 190 |
+
# Shape: (batch_size, document_length)
|
| 191 |
+
text_mask = util.get_text_field_mask(text)
|
| 192 |
+
|
| 193 |
+
# Shape: (batch_size, num_spans)
|
| 194 |
+
# span_mask = (spans[:, :, 0] >= 0).squeeze(-1)
|
| 195 |
+
span_mask = spans[:, :, 0] >= 0
|
| 196 |
+
|
| 197 |
+
# SpanFields return -1 when they are used as padding. As we do
|
| 198 |
+
# some comparisons based on span widths when we attend over the
|
| 199 |
+
# span representations that we generate from these indices, we
|
| 200 |
+
# need them to be <= 0. This is only relevant in edge cases where
|
| 201 |
+
# the number of spans we consider after the pruning stage is >= the
|
| 202 |
+
# total number of spans, because in this case, it is possible we might
|
| 203 |
+
# consider a masked span.
|
| 204 |
+
# Shape: (batch_size, num_spans, 2)
|
| 205 |
+
spans = F.relu(spans.float()).long()
|
| 206 |
+
|
| 207 |
+
# Shape: (batch_size, document_length, encoding_dim)
|
| 208 |
+
contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
|
| 209 |
+
# Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
|
| 210 |
+
endpoint_span_embeddings = self._endpoint_span_extractor(
|
| 211 |
+
contextualized_embeddings, spans
|
| 212 |
+
)
|
| 213 |
+
# Shape: (batch_size, num_spans, emebedding_size)
|
| 214 |
+
attended_span_embeddings = self._attentive_span_extractor(
|
| 215 |
+
text_embeddings, spans
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# Shape: (batch_size, num_spans, emebedding_size + 2 * encoding_dim + feature_size)
|
| 219 |
+
span_embeddings = torch.cat(
|
| 220 |
+
[endpoint_span_embeddings, attended_span_embeddings], -1
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# Prune based on mention scores.
|
| 224 |
+
num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
|
| 225 |
+
num_spans_to_keep = min(num_spans_to_keep, num_spans)
|
| 226 |
+
# num_spans_to_keep = num_spans
|
| 227 |
+
|
| 228 |
+
# Shape: (batch_size, num_spans)
|
| 229 |
+
span_mention_scores = self._mention_scorer(
|
| 230 |
+
self._mention_feedforward(span_embeddings)
|
| 231 |
+
).squeeze(-1)
|
| 232 |
+
k = torch.full(
|
| 233 |
+
(batch_size,), num_spans_to_keep, dtype=torch.long, device=spans.device
|
| 234 |
+
)
|
| 235 |
+
# Shape: (batch_size, num_spans) for all 3 tensors
|
| 236 |
+
top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
|
| 237 |
+
span_mention_scores, span_mask, k, dim=1
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# Shape: (batch_size * num_spans_to_keep)
|
| 241 |
+
# torch.index_select only accepts 1D indices, but here
|
| 242 |
+
# we need to select spans for each element in the batch.
|
| 243 |
+
# This reformats the indices to take into account their
|
| 244 |
+
# index into the batch. We precompute this here to make
|
| 245 |
+
# the multiple calls to util.batched_index_select below more efficient.
|
| 246 |
+
flat_top_span_indices = util.flatten_and_batch_shift_indices(
|
| 247 |
+
top_span_indices, num_spans
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# Compute final predictions for which spans to consider as mentions.
|
| 251 |
+
# Shape: (batch_size, num_spans_to_keep, 2)
|
| 252 |
+
top_spans = util.batched_index_select(
|
| 253 |
+
spans, top_span_indices, flat_top_span_indices
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
# Shape: (batch_size, num_spans_to_keep, embedding_size)
|
| 257 |
+
top_span_embeddings = util.batched_index_select(
|
| 258 |
+
span_embeddings, top_span_indices, flat_top_span_indices
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
# Compute indices for antecedent spans to consider.
|
| 262 |
+
max_antecedents = min(self._max_antecedents, num_spans_to_keep)
|
| 263 |
+
|
| 264 |
+
# Now that we have our variables in terms of num_spans_to_keep, we need to
|
| 265 |
+
# compare span pairs to decide each span's antecedent. Each span can only
|
| 266 |
+
# have prior spans as antecedents, and we only consider up to max_antecedents
|
| 267 |
+
# prior spans. So the first thing we do is construct a matrix mapping a span's
|
| 268 |
+
# index to the indices of its allowed antecedents.
|
| 269 |
+
|
| 270 |
+
# Once we have this matrix, we reformat our variables again to get embeddings
|
| 271 |
+
# for all valid antecedents for each span. This gives us variables with shapes
|
| 272 |
+
# like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
|
| 273 |
+
# we can use to make coreference decisions between valid span pairs.
|
| 274 |
+
|
| 275 |
+
if self._coarse_to_fine:
|
| 276 |
+
pruned_antecedents = self._coarse_to_fine_pruning(
|
| 277 |
+
top_span_embeddings,
|
| 278 |
+
top_span_mention_scores,
|
| 279 |
+
top_span_mask,
|
| 280 |
+
max_antecedents,
|
| 281 |
+
)
|
| 282 |
+
else:
|
| 283 |
+
pruned_antecedents = self._distance_pruning(
|
| 284 |
+
top_span_embeddings, top_span_mention_scores, max_antecedents
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Shape: (batch_size, num_spans_to_keep, max_antecedents) for all 4 tensors
|
| 288 |
+
(
|
| 289 |
+
top_partial_coreference_scores,
|
| 290 |
+
top_antecedent_mask,
|
| 291 |
+
top_antecedent_offsets,
|
| 292 |
+
top_antecedent_indices,
|
| 293 |
+
) = pruned_antecedents
|
| 294 |
+
|
| 295 |
+
flat_top_antecedent_indices = util.flatten_and_batch_shift_indices(
|
| 296 |
+
top_antecedent_indices, num_spans_to_keep
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
# Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
|
| 300 |
+
top_antecedent_embeddings = util.batched_index_select(
|
| 301 |
+
top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
|
| 302 |
+
)
|
| 303 |
+
# Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
|
| 304 |
+
coreference_scores = self._compute_coreference_scores(
|
| 305 |
+
top_span_embeddings,
|
| 306 |
+
top_antecedent_embeddings,
|
| 307 |
+
top_partial_coreference_scores,
|
| 308 |
+
top_antecedent_mask,
|
| 309 |
+
top_antecedent_offsets,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
for _ in range(self._inference_order - 1):
|
| 313 |
+
dummy_mask = top_antecedent_mask.new_ones(batch_size, num_spans_to_keep, 1)
|
| 314 |
+
# Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents,)
|
| 315 |
+
top_antecedent_with_dummy_mask = torch.cat(
|
| 316 |
+
[dummy_mask, top_antecedent_mask], -1
|
| 317 |
+
)
|
| 318 |
+
# Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
|
| 319 |
+
attention_weight = util.masked_softmax(
|
| 320 |
+
coreference_scores,
|
| 321 |
+
top_antecedent_with_dummy_mask,
|
| 322 |
+
memory_efficient=True,
|
| 323 |
+
)
|
| 324 |
+
# Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents, embedding_size)
|
| 325 |
+
top_antecedent_with_dummy_embeddings = torch.cat(
|
| 326 |
+
[top_span_embeddings.unsqueeze(2), top_antecedent_embeddings], 2
|
| 327 |
+
)
|
| 328 |
+
# Shape: (batch_size, num_spans_to_keep, embedding_size)
|
| 329 |
+
attended_embeddings = util.weighted_sum(
|
| 330 |
+
top_antecedent_with_dummy_embeddings, attention_weight
|
| 331 |
+
)
|
| 332 |
+
# Shape: (batch_size, num_spans_to_keep, embedding_size)
|
| 333 |
+
top_span_embeddings = self._span_updating_gated_sum(
|
| 334 |
+
top_span_embeddings, attended_embeddings
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
# Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
|
| 338 |
+
top_antecedent_embeddings = util.batched_index_select(
|
| 339 |
+
top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
|
| 340 |
+
)
|
| 341 |
+
# Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
|
| 342 |
+
coreference_scores = self._compute_coreference_scores(
|
| 343 |
+
top_span_embeddings,
|
| 344 |
+
top_antecedent_embeddings,
|
| 345 |
+
top_partial_coreference_scores,
|
| 346 |
+
top_antecedent_mask,
|
| 347 |
+
top_antecedent_offsets,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
# We now have, for each span which survived the pruning stage,
|
| 351 |
+
# a predicted antecedent. This implies a clustering if we group
|
| 352 |
+
# mentions which refer to each other in a chain.
|
| 353 |
+
# Shape: (batch_size, num_spans_to_keep)
|
| 354 |
+
_, predicted_antecedents = coreference_scores.max(2)
|
| 355 |
+
# Subtract one here because index 0 is the "no antecedent" class,
|
| 356 |
+
# so this makes the indices line up with actual spans if the prediction
|
| 357 |
+
# is greater than -1.
|
| 358 |
+
predicted_antecedents -= 1
|
| 359 |
+
|
| 360 |
+
output_dict = {
|
| 361 |
+
"top_spans": top_spans,
|
| 362 |
+
"antecedent_indices": top_antecedent_indices,
|
| 363 |
+
"predicted_antecedents": predicted_antecedents,
|
| 364 |
+
}
|
| 365 |
+
if span_labels is not None:
|
| 366 |
+
# Find the gold labels for the spans which we kept.
|
| 367 |
+
# Shape: (batch_size, num_spans_to_keep, 1)
|
| 368 |
+
pruned_gold_labels = util.batched_index_select(
|
| 369 |
+
span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
# Shape: (batch_size, num_spans_to_keep, max_antecedents)
|
| 373 |
+
antecedent_labels = util.batched_index_select(
|
| 374 |
+
pruned_gold_labels, top_antecedent_indices, flat_top_antecedent_indices
|
| 375 |
+
).squeeze(-1)
|
| 376 |
+
antecedent_labels = util.replace_masked_values(
|
| 377 |
+
antecedent_labels, top_antecedent_mask, -100
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
# Compute labels.
|
| 381 |
+
# Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
|
| 382 |
+
gold_antecedent_labels = self._compute_antecedent_gold_labels(
|
| 383 |
+
pruned_gold_labels, antecedent_labels
|
| 384 |
+
)
|
| 385 |
+
# Now, compute the loss using the negative marginal log-likelihood.
|
| 386 |
+
# This is equal to the log of the sum of the probabilities of all antecedent predictions
|
| 387 |
+
# that would be consistent with the data, in the sense that we are minimising, for a
|
| 388 |
+
# given span, the negative marginal log likelihood of all antecedents which are in the
|
| 389 |
+
# same gold cluster as the span we are currently considering. Each span i predicts a
|
| 390 |
+
# single antecedent j, but there might be several prior mentions k in the same
|
| 391 |
+
# coreference cluster that would be valid antecedents. Our loss is the sum of the
|
| 392 |
+
# probability assigned to all valid antecedents. This is a valid objective for
|
| 393 |
+
# clustering as we don't mind which antecedent is predicted, so long as they are in
|
| 394 |
+
# the same coreference cluster.
|
| 395 |
+
coreference_log_probs = util.masked_log_softmax(
|
| 396 |
+
coreference_scores, top_span_mask.unsqueeze(-1)
|
| 397 |
+
)
|
| 398 |
+
correct_antecedent_log_probs = (
|
| 399 |
+
coreference_log_probs + gold_antecedent_labels.log()
|
| 400 |
+
)
|
| 401 |
+
negative_marginal_log_likelihood = -util.logsumexp(
|
| 402 |
+
correct_antecedent_log_probs
|
| 403 |
+
).sum()
|
| 404 |
+
|
| 405 |
+
self._mention_recall(top_spans, metadata)
|
| 406 |
+
self._conll_coref_scores(
|
| 407 |
+
top_spans, top_antecedent_indices, predicted_antecedents, metadata
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
output_dict["loss"] = negative_marginal_log_likelihood
|
| 411 |
+
|
| 412 |
+
if metadata is not None:
|
| 413 |
+
output_dict["document"] = [x["original_text"] for x in metadata]
|
| 414 |
+
return output_dict
|
| 415 |
+
|
| 416 |
+
def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]):
|
| 417 |
+
"""
|
| 418 |
+
Converts the list of spans and predicted antecedent indices into clusters
|
| 419 |
+
of spans for each element in the batch.
|
| 420 |
+
|
| 421 |
+
# Parameters
|
| 422 |
+
|
| 423 |
+
output_dict : `Dict[str, torch.Tensor]`, required.
|
| 424 |
+
The result of calling :func:`forward` on an instance or batch of instances.
|
| 425 |
+
|
| 426 |
+
# Returns
|
| 427 |
+
|
| 428 |
+
The same output dictionary, but with an additional `clusters` key:
|
| 429 |
+
|
| 430 |
+
clusters : `List[List[List[Tuple[int, int]]]]`
|
| 431 |
+
A nested list, representing, for each instance in the batch, the list of clusters,
|
| 432 |
+
which are in turn comprised of a list of (start, end) inclusive spans into the
|
| 433 |
+
original document.
|
| 434 |
+
"""
|
| 435 |
+
|
| 436 |
+
# A tensor of shape (batch_size, num_spans_to_keep, 2), representing
|
| 437 |
+
# the start and end indices of each span.
|
| 438 |
+
batch_top_spans = output_dict["top_spans"].detach().cpu()
|
| 439 |
+
|
| 440 |
+
# A tensor of shape (batch_size, num_spans_to_keep) representing, for each span,
|
| 441 |
+
# the index into `antecedent_indices` which specifies the antecedent span. Additionally,
|
| 442 |
+
# the index can be -1, specifying that the span has no predicted antecedent.
|
| 443 |
+
batch_predicted_antecedents = (
|
| 444 |
+
output_dict["predicted_antecedents"].detach().cpu()
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
# A tensor of shape (num_spans_to_keep, max_antecedents), representing the indices
|
| 448 |
+
# of the predicted antecedents with respect to the 2nd dimension of `batch_top_spans`
|
| 449 |
+
# for each antecedent we considered.
|
| 450 |
+
batch_antecedent_indices = output_dict["antecedent_indices"].detach().cpu()
|
| 451 |
+
batch_clusters: List[List[List[Tuple[int, int]]]] = []
|
| 452 |
+
|
| 453 |
+
# Calling zip() on two tensors results in an iterator over their
|
| 454 |
+
# first dimension. This is iterating over instances in the batch.
|
| 455 |
+
for top_spans, predicted_antecedents, antecedent_indices in zip(
|
| 456 |
+
batch_top_spans, batch_predicted_antecedents, batch_antecedent_indices
|
| 457 |
+
):
|
| 458 |
+
spans_to_cluster_ids: Dict[Tuple[int, int], int] = {}
|
| 459 |
+
clusters: List[List[Tuple[int, int]]] = []
|
| 460 |
+
|
| 461 |
+
for i, (span, predicted_antecedent) in enumerate(
|
| 462 |
+
zip(top_spans, predicted_antecedents)
|
| 463 |
+
):
|
| 464 |
+
if predicted_antecedent < 0:
|
| 465 |
+
# We don't care about spans which are
|
| 466 |
+
# not co-referent with anything.
|
| 467 |
+
continue
|
| 468 |
+
|
| 469 |
+
# Find the right cluster to update with this span.
|
| 470 |
+
# To do this, we find the row in `antecedent_indices`
|
| 471 |
+
# corresponding to this span we are considering.
|
| 472 |
+
# The predicted antecedent is then an index into this list
|
| 473 |
+
# of indices, denoting the span from `top_spans` which is the
|
| 474 |
+
# most likely antecedent.
|
| 475 |
+
predicted_index = antecedent_indices[i, predicted_antecedent]
|
| 476 |
+
|
| 477 |
+
antecedent_span = (
|
| 478 |
+
top_spans[predicted_index, 0].item(),
|
| 479 |
+
top_spans[predicted_index, 1].item(),
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
# Check if we've seen the span before.
|
| 483 |
+
if antecedent_span in spans_to_cluster_ids:
|
| 484 |
+
predicted_cluster_id: int = spans_to_cluster_ids[antecedent_span]
|
| 485 |
+
else:
|
| 486 |
+
# We start a new cluster.
|
| 487 |
+
predicted_cluster_id = len(clusters)
|
| 488 |
+
# Append a new cluster containing only this span.
|
| 489 |
+
clusters.append([antecedent_span])
|
| 490 |
+
# Record the new id of this span.
|
| 491 |
+
spans_to_cluster_ids[antecedent_span] = predicted_cluster_id
|
| 492 |
+
|
| 493 |
+
# Now add the span we are currently considering.
|
| 494 |
+
span_start, span_end = span[0].item(), span[1].item()
|
| 495 |
+
clusters[predicted_cluster_id].append((span_start, span_end))
|
| 496 |
+
spans_to_cluster_ids[(span_start, span_end)] = predicted_cluster_id
|
| 497 |
+
batch_clusters.append(clusters)
|
| 498 |
+
|
| 499 |
+
output_dict["clusters"] = batch_clusters
|
| 500 |
+
return output_dict
|
| 501 |
+
|
| 502 |
+
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
|
| 503 |
+
mention_recall = self._mention_recall.get_metric(reset)
|
| 504 |
+
coref_precision, coref_recall, coref_f1 = self._conll_coref_scores.get_metric(
|
| 505 |
+
reset
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
return {
|
| 509 |
+
"coref_precision": coref_precision,
|
| 510 |
+
"coref_recall": coref_recall,
|
| 511 |
+
"coref_f1": coref_f1,
|
| 512 |
+
"mention_recall": mention_recall,
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
@staticmethod
|
| 516 |
+
def _generate_valid_antecedents(
|
| 517 |
+
num_spans_to_keep: int, max_antecedents: int, device: int
|
| 518 |
+
) -> Tuple[torch.IntTensor, torch.IntTensor, torch.BoolTensor]:
|
| 519 |
+
"""
|
| 520 |
+
This method generates possible antecedents per span which survived the pruning
|
| 521 |
+
stage. This procedure is `generic across the batch`. The reason this is the case is
|
| 522 |
+
that each span in a batch can be coreferent with any previous span, but here we
|
| 523 |
+
are computing the possible `indices` of these spans. So, regardless of the batch,
|
| 524 |
+
the 1st span _cannot_ have any antecedents, because there are none to select from.
|
| 525 |
+
Similarly, each element can only predict previous spans, so this returns a matrix
|
| 526 |
+
of shape (num_spans_to_keep, max_antecedents), where the (i,j)-th index is equal to
|
| 527 |
+
(i - 1) - j if j <= i, or zero otherwise.
|
| 528 |
+
|
| 529 |
+
# Parameters
|
| 530 |
+
|
| 531 |
+
num_spans_to_keep : `int`, required.
|
| 532 |
+
The number of spans that were kept while pruning.
|
| 533 |
+
max_antecedents : `int`, required.
|
| 534 |
+
The maximum number of antecedent spans to consider for every span.
|
| 535 |
+
device : `int`, required.
|
| 536 |
+
The CUDA device to use.
|
| 537 |
+
|
| 538 |
+
# Returns
|
| 539 |
+
|
| 540 |
+
valid_antecedent_indices : `torch.LongTensor`
|
| 541 |
+
The indices of every antecedent to consider with respect to the top k spans.
|
| 542 |
+
Has shape `(num_spans_to_keep, max_antecedents)`.
|
| 543 |
+
valid_antecedent_offsets : `torch.LongTensor`
|
| 544 |
+
The distance between the span and each of its antecedents in terms of the number
|
| 545 |
+
of considered spans (i.e not the word distance between the spans).
|
| 546 |
+
Has shape `(1, max_antecedents)`.
|
| 547 |
+
valid_antecedent_mask : `torch.BoolTensor`
|
| 548 |
+
The mask representing whether each antecedent span is valid. Required since
|
| 549 |
+
different spans have different numbers of valid antecedents. For example, the first
|
| 550 |
+
span in the document should have no valid antecedents.
|
| 551 |
+
Has shape `(1, num_spans_to_keep, max_antecedents)`.
|
| 552 |
+
"""
|
| 553 |
+
# Shape: (num_spans_to_keep, 1)
|
| 554 |
+
target_indices = util.get_range_vector(num_spans_to_keep, device).unsqueeze(1)
|
| 555 |
+
|
| 556 |
+
# Shape: (1, max_antecedents)
|
| 557 |
+
valid_antecedent_offsets = (
|
| 558 |
+
util.get_range_vector(max_antecedents, device) + 1
|
| 559 |
+
).unsqueeze(0)
|
| 560 |
+
|
| 561 |
+
# This is a broadcasted subtraction.
|
| 562 |
+
# Shape: (num_spans_to_keep, max_antecedents)
|
| 563 |
+
raw_antecedent_indices = target_indices - valid_antecedent_offsets
|
| 564 |
+
|
| 565 |
+
# In our matrix of indices, the upper triangular part will be negative
|
| 566 |
+
# because the offsets will be > the target indices. We want to mask these,
|
| 567 |
+
# because these are exactly the indices which we don't want to predict, per span.
|
| 568 |
+
# Shape: (1, num_spans_to_keep, max_antecedents)
|
| 569 |
+
valid_antecedent_mask = (raw_antecedent_indices >= 0).unsqueeze(0)
|
| 570 |
+
|
| 571 |
+
# Shape: (num_spans_to_keep, max_antecedents)
|
| 572 |
+
valid_antecedent_indices = F.relu(raw_antecedent_indices.float()).long()
|
| 573 |
+
return valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_mask
|
| 574 |
+
|
| 575 |
+
def _distance_pruning(
|
| 576 |
+
self,
|
| 577 |
+
top_span_embeddings: torch.FloatTensor,
|
| 578 |
+
top_span_mention_scores: torch.FloatTensor,
|
| 579 |
+
max_antecedents: int,
|
| 580 |
+
) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.LongTensor, torch.LongTensor]:
|
| 581 |
+
"""
|
| 582 |
+
Generates antecedents for each span and prunes down to `max_antecedents`. This method
|
| 583 |
+
prunes antecedents only based on distance (i.e. number of intervening spans). The closest
|
| 584 |
+
antecedents are kept.
|
| 585 |
+
|
| 586 |
+
# Parameters
|
| 587 |
+
|
| 588 |
+
top_span_embeddings: `torch.FloatTensor`, required.
|
| 589 |
+
The embeddings of the top spans.
|
| 590 |
+
(batch_size, num_spans_to_keep, embedding_size).
|
| 591 |
+
top_span_mention_scores: `torch.FloatTensor`, required.
|
| 592 |
+
The mention scores of the top spans.
|
| 593 |
+
(batch_size, num_spans_to_keep).
|
| 594 |
+
max_antecedents: `int`, required.
|
| 595 |
+
The maximum number of antecedents to keep for each span.
|
| 596 |
+
|
| 597 |
+
# Returns
|
| 598 |
+
|
| 599 |
+
top_partial_coreference_scores: `torch.FloatTensor`
|
| 600 |
+
The partial antecedent scores for each span-antecedent pair. Computed by summing
|
| 601 |
+
the span mentions scores of the span and the antecedent. This score is partial because
|
| 602 |
+
compared to the full coreference scores, it lacks the interaction term
|
| 603 |
+
w * FFNN([g_i, g_j, g_i * g_j, features]).
|
| 604 |
+
(batch_size, num_spans_to_keep, max_antecedents)
|
| 605 |
+
top_antecedent_mask: `torch.BoolTensor`
|
| 606 |
+
The mask representing whether each antecedent span is valid. Required since
|
| 607 |
+
different spans have different numbers of valid antecedents. For example, the first
|
| 608 |
+
span in the document should have no valid antecedents.
|
| 609 |
+
(batch_size, num_spans_to_keep, max_antecedents)
|
| 610 |
+
top_antecedent_offsets: `torch.LongTensor`
|
| 611 |
+
The distance between the span and each of its antecedents in terms of the number
|
| 612 |
+
of considered spans (i.e not the word distance between the spans).
|
| 613 |
+
(batch_size, num_spans_to_keep, max_antecedents)
|
| 614 |
+
top_antecedent_indices: `torch.LongTensor`
|
| 615 |
+
The indices of every antecedent to consider with respect to the top k spans.
|
| 616 |
+
(batch_size, num_spans_to_keep, max_antecedents)
|
| 617 |
+
"""
|
| 618 |
+
# These antecedent matrices are independent of the batch dimension - they're just a function
|
| 619 |
+
# of the span's position in top_spans.
|
| 620 |
+
# The spans are in document order, so we can just use the relative
|
| 621 |
+
# index of the spans to know which other spans are allowed antecedents.
|
| 622 |
+
|
| 623 |
+
num_spans_to_keep = top_span_embeddings.size(1)
|
| 624 |
+
device = util.get_device_of(top_span_embeddings)
|
| 625 |
+
|
| 626 |
+
# Shapes:
|
| 627 |
+
# (num_spans_to_keep, max_antecedents),
|
| 628 |
+
# (1, max_antecedents),
|
| 629 |
+
# (1, num_spans_to_keep, max_antecedents)
|
| 630 |
+
(
|
| 631 |
+
top_antecedent_indices,
|
| 632 |
+
top_antecedent_offsets,
|
| 633 |
+
top_antecedent_mask,
|
| 634 |
+
) = self._generate_valid_antecedents( # noqa
|
| 635 |
+
num_spans_to_keep, max_antecedents, device
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
# Shape: (batch_size, num_spans_to_keep, max_antecedents)
|
| 639 |
+
top_antecedent_mention_scores = util.flattened_index_select(
|
| 640 |
+
top_span_mention_scores.unsqueeze(-1), top_antecedent_indices
|
| 641 |
+
).squeeze(-1)
|
| 642 |
+
|
| 643 |
+
# Shape: (batch_size, num_spans_to_keep, max_antecedents) * 4
|
| 644 |
+
top_partial_coreference_scores = (
|
| 645 |
+
top_span_mention_scores.unsqueeze(-1) + top_antecedent_mention_scores
|
| 646 |
+
)
|
| 647 |
+
top_antecedent_indices = top_antecedent_indices.unsqueeze(0).expand_as(
|
| 648 |
+
top_partial_coreference_scores
|
| 649 |
+
)
|
| 650 |
+
top_antecedent_offsets = top_antecedent_offsets.unsqueeze(0).expand_as(
|
| 651 |
+
top_partial_coreference_scores
|
| 652 |
+
)
|
| 653 |
+
top_antecedent_mask = top_antecedent_mask.expand_as(
|
| 654 |
+
top_partial_coreference_scores
|
| 655 |
+
)
|
| 656 |
+
|
| 657 |
+
return (
|
| 658 |
+
top_partial_coreference_scores,
|
| 659 |
+
top_antecedent_mask,
|
| 660 |
+
top_antecedent_offsets,
|
| 661 |
+
top_antecedent_indices,
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
def _coarse_to_fine_pruning(
|
| 665 |
+
self,
|
| 666 |
+
top_span_embeddings: torch.FloatTensor,
|
| 667 |
+
top_span_mention_scores: torch.FloatTensor,
|
| 668 |
+
top_span_mask: torch.BoolTensor,
|
| 669 |
+
max_antecedents: int,
|
| 670 |
+
) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.LongTensor, torch.LongTensor]:
|
| 671 |
+
"""
|
| 672 |
+
Generates antecedents for each span and prunes down to `max_antecedents`. This method
|
| 673 |
+
prunes antecedents using a fast bilinar interaction score between a span and a candidate
|
| 674 |
+
antecedent, and the highest-scoring antecedents are kept.
|
| 675 |
+
|
| 676 |
+
# Parameters
|
| 677 |
+
|
| 678 |
+
top_span_embeddings: `torch.FloatTensor`, required.
|
| 679 |
+
The embeddings of the top spans.
|
| 680 |
+
(batch_size, num_spans_to_keep, embedding_size).
|
| 681 |
+
top_span_mention_scores: `torch.FloatTensor`, required.
|
| 682 |
+
The mention scores of the top spans.
|
| 683 |
+
(batch_size, num_spans_to_keep).
|
| 684 |
+
top_span_mask: `torch.BoolTensor`, required.
|
| 685 |
+
The mask for the top spans.
|
| 686 |
+
(batch_size, num_spans_to_keep).
|
| 687 |
+
max_antecedents: `int`, required.
|
| 688 |
+
The maximum number of antecedents to keep for each span.
|
| 689 |
+
|
| 690 |
+
# Returns
|
| 691 |
+
|
| 692 |
+
top_partial_coreference_scores: `torch.FloatTensor`
|
| 693 |
+
The partial antecedent scores for each span-antecedent pair. Computed by summing
|
| 694 |
+
the span mentions scores of the span and the antecedent as well as a bilinear
|
| 695 |
+
interaction term. This score is partial because compared to the full coreference scores,
|
| 696 |
+
it lacks the interaction term
|
| 697 |
+
`w * FFNN([g_i, g_j, g_i * g_j, features])`.
|
| 698 |
+
`(batch_size, num_spans_to_keep, max_antecedents)`
|
| 699 |
+
top_antecedent_mask: `torch.BoolTensor`
|
| 700 |
+
The mask representing whether each antecedent span is valid. Required since
|
| 701 |
+
different spans have different numbers of valid antecedents. For example, the first
|
| 702 |
+
span in the document should have no valid antecedents.
|
| 703 |
+
`(batch_size, num_spans_to_keep, max_antecedents)`
|
| 704 |
+
top_antecedent_offsets: `torch.LongTensor`
|
| 705 |
+
The distance between the span and each of its antecedents in terms of the number
|
| 706 |
+
of considered spans (i.e not the word distance between the spans).
|
| 707 |
+
`(batch_size, num_spans_to_keep, max_antecedents)`
|
| 708 |
+
top_antecedent_indices: `torch.LongTensor`
|
| 709 |
+
The indices of every antecedent to consider with respect to the top k spans.
|
| 710 |
+
`(batch_size, num_spans_to_keep, max_antecedents)`
|
| 711 |
+
"""
|
| 712 |
+
# batch_size, num_spans_to_keep = top_span_embeddings.size()[:2]
|
| 713 |
+
batch_size, num_spans_to_keep = top_span_embeddings.shape[:2]
|
| 714 |
+
device = util.get_device_of(top_span_embeddings)
|
| 715 |
+
|
| 716 |
+
# Shape: (1, num_spans_to_keep, num_spans_to_keep)
|
| 717 |
+
_, _, valid_antecedent_mask = self._generate_valid_antecedents(
|
| 718 |
+
num_spans_to_keep, num_spans_to_keep, device
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
mention_one_score = top_span_mention_scores.unsqueeze(1)
|
| 722 |
+
mention_two_score = top_span_mention_scores.unsqueeze(2)
|
| 723 |
+
bilinear_weights = self._coarse2fine_scorer(top_span_embeddings).transpose(1, 2)
|
| 724 |
+
bilinear_score = torch.matmul(top_span_embeddings, bilinear_weights)
|
| 725 |
+
# Shape: (batch_size, num_spans_to_keep, num_spans_to_keep); broadcast op
|
| 726 |
+
partial_antecedent_scores = (
|
| 727 |
+
mention_one_score + mention_two_score + bilinear_score
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
# Shape: (batch_size, num_spans_to_keep, num_spans_to_keep); broadcast op
|
| 731 |
+
span_pair_mask = top_span_mask.unsqueeze(-1) & valid_antecedent_mask
|
| 732 |
+
|
| 733 |
+
# Shape:
|
| 734 |
+
# (batch_size, num_spans_to_keep, max_antecedents) * 3
|
| 735 |
+
# k_tensor = torch.full((batch_size,), max_antecedents, dtype=torch.long, device=top_span_embeddings.device)
|
| 736 |
+
k_tensor = torch.full(
|
| 737 |
+
(batch_size, num_spans_to_keep),
|
| 738 |
+
max_antecedents,
|
| 739 |
+
dtype=torch.long,
|
| 740 |
+
device=top_span_embeddings.device,
|
| 741 |
+
)
|
| 742 |
+
|
| 743 |
+
(
|
| 744 |
+
top_partial_coreference_scores,
|
| 745 |
+
top_antecedent_mask,
|
| 746 |
+
top_antecedent_indices,
|
| 747 |
+
) = util.masked_topk(
|
| 748 |
+
partial_antecedent_scores, span_pair_mask, k_tensor, dim=-1
|
| 749 |
+
)
|
| 750 |
+
# (
|
| 751 |
+
# top_partial_coreference_scores,
|
| 752 |
+
# top_antecedent_mask,
|
| 753 |
+
# top_antecedent_indices,
|
| 754 |
+
# ) = util.masked_topk(partial_antecedent_scores, span_pair_mask, max_antecedents)
|
| 755 |
+
|
| 756 |
+
top_span_range = util.get_range_vector(num_spans_to_keep, device)
|
| 757 |
+
# Shape: (num_spans_to_keep, num_spans_to_keep); broadcast op
|
| 758 |
+
valid_antecedent_offsets = top_span_range.unsqueeze(
|
| 759 |
+
-1
|
| 760 |
+
) - top_span_range.unsqueeze(0)
|
| 761 |
+
|
| 762 |
+
# TODO: we need to make `batched_index_select` more general to make this less awkward.
|
| 763 |
+
top_antecedent_offsets = util.batched_index_select(
|
| 764 |
+
valid_antecedent_offsets.unsqueeze(0)
|
| 765 |
+
.expand(batch_size, num_spans_to_keep, num_spans_to_keep)
|
| 766 |
+
.reshape(batch_size * num_spans_to_keep, num_spans_to_keep, 1),
|
| 767 |
+
top_antecedent_indices.view(-1, max_antecedents),
|
| 768 |
+
).reshape(batch_size, num_spans_to_keep, max_antecedents)
|
| 769 |
+
|
| 770 |
+
return (
|
| 771 |
+
top_partial_coreference_scores,
|
| 772 |
+
top_antecedent_mask,
|
| 773 |
+
top_antecedent_offsets,
|
| 774 |
+
top_antecedent_indices,
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
def _compute_span_pair_embeddings(
|
| 778 |
+
self,
|
| 779 |
+
top_span_embeddings: torch.FloatTensor,
|
| 780 |
+
antecedent_embeddings: torch.FloatTensor,
|
| 781 |
+
antecedent_offsets: torch.FloatTensor,
|
| 782 |
+
):
|
| 783 |
+
"""
|
| 784 |
+
Computes an embedding representation of pairs of spans for the pairwise scoring function
|
| 785 |
+
to consider. This includes both the original span representations, the element-wise
|
| 786 |
+
similarity of the span representations, and an embedding representation of the distance
|
| 787 |
+
between the two spans.
|
| 788 |
+
|
| 789 |
+
# Parameters
|
| 790 |
+
|
| 791 |
+
top_span_embeddings : `torch.FloatTensor`, required.
|
| 792 |
+
Embedding representations of the top spans. Has shape
|
| 793 |
+
(batch_size, num_spans_to_keep, embedding_size).
|
| 794 |
+
antecedent_embeddings : `torch.FloatTensor`, required.
|
| 795 |
+
Embedding representations of the antecedent spans we are considering
|
| 796 |
+
for each top span. Has shape
|
| 797 |
+
(batch_size, num_spans_to_keep, max_antecedents, embedding_size).
|
| 798 |
+
antecedent_offsets : `torch.IntTensor`, required.
|
| 799 |
+
The offsets between each top span and its antecedent spans in terms
|
| 800 |
+
of spans we are considering. Has shape (batch_size, num_spans_to_keep, max_antecedents).
|
| 801 |
+
|
| 802 |
+
# Returns
|
| 803 |
+
|
| 804 |
+
span_pair_embeddings : `torch.FloatTensor`
|
| 805 |
+
Embedding representation of the pair of spans to consider. Has shape
|
| 806 |
+
(batch_size, num_spans_to_keep, max_antecedents, embedding_size)
|
| 807 |
+
"""
|
| 808 |
+
# Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
|
| 809 |
+
target_embeddings = top_span_embeddings.unsqueeze(2).expand_as(
|
| 810 |
+
antecedent_embeddings
|
| 811 |
+
)
|
| 812 |
+
|
| 813 |
+
# Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
|
| 814 |
+
antecedent_distance_embeddings = self._distance_embedding(
|
| 815 |
+
util.bucket_values(
|
| 816 |
+
antecedent_offsets, num_total_buckets=self._num_distance_buckets
|
| 817 |
+
)
|
| 818 |
+
)
|
| 819 |
+
|
| 820 |
+
# Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
|
| 821 |
+
span_pair_embeddings = torch.cat(
|
| 822 |
+
[
|
| 823 |
+
target_embeddings,
|
| 824 |
+
antecedent_embeddings,
|
| 825 |
+
antecedent_embeddings * target_embeddings,
|
| 826 |
+
antecedent_distance_embeddings,
|
| 827 |
+
],
|
| 828 |
+
-1,
|
| 829 |
+
)
|
| 830 |
+
return span_pair_embeddings
|
| 831 |
+
|
| 832 |
+
    @staticmethod
    def _compute_antecedent_gold_labels(
        top_span_labels: torch.IntTensor, antecedent_labels: torch.IntTensor
    ):
        """
        Generates a binary indicator for every pair of spans. This label is one if and
        only if the pair of spans belong to the same cluster. The labels are augmented
        with a dummy antecedent at the zeroth position, which represents the prediction
        that a span does not have any antecedent.

        # Parameters

        top_span_labels : `torch.IntTensor`, required.
            The cluster id label for every span. The id is arbitrary,
            as we just care about the clustering. Has shape (batch_size, num_spans_to_keep).
        antecedent_labels : `torch.IntTensor`, required.
            The cluster id label for every antecedent span. The id is arbitrary,
            as we just care about the clustering. Has shape
            (batch_size, num_spans_to_keep, max_antecedents).

        # Returns

        pairwise_labels_with_dummy_label : `torch.FloatTensor`
            A binary tensor representing whether a given pair of spans belong to
            the same cluster in the gold clustering.
            Has shape (batch_size, num_spans_to_keep, max_antecedents + 1).
        """
        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        target_labels = top_span_labels.expand_as(antecedent_labels)
        same_cluster_indicator = (target_labels == antecedent_labels).float()
        non_dummy_indicator = (target_labels >= 0).float()
        pairwise_labels = same_cluster_indicator * non_dummy_indicator

        # Shape: (batch_size, num_spans_to_keep, 1)
        dummy_labels = (1 - pairwise_labels).prod(-1, keepdim=True)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        pairwise_labels_with_dummy_label = torch.cat(
            [dummy_labels, pairwise_labels], -1
        )
        return pairwise_labels_with_dummy_label

    def _compute_coreference_scores(
        self,
        top_span_embeddings: torch.FloatTensor,
        top_antecedent_embeddings: torch.FloatTensor,
        top_partial_coreference_scores: torch.FloatTensor,
        top_antecedent_mask: torch.BoolTensor,
        top_antecedent_offsets: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Computes scores for every pair of spans. Additionally, a dummy label is included,
        representing the decision that the span is not coreferent with anything. For the dummy
        label, the score is always zero. For the true antecedent spans, the score consists of
        the pairwise antecedent score and the unary mention scores for the span and its
        antecedent. The factoring allows the model to blame many of the absent links on bad
        spans, enabling the pruning strategy used in the forward pass.

        # Parameters

        top_span_embeddings : `torch.FloatTensor`, required.
            Embedding representations of the kept spans. Has shape
            (batch_size, num_spans_to_keep, embedding_size)
        top_antecedent_embeddings : `torch.FloatTensor`, required.
            The embeddings of antecedents for each span candidate. Has shape
            (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        top_partial_coreference_scores : `torch.FloatTensor`, required.
            Sum of span mention score and antecedent mention score. The coarse-to-fine
            setting has an additional term, the coarse bilinear score.
            (batch_size, num_spans_to_keep, max_antecedents).
        top_antecedent_mask : `torch.BoolTensor`, required.
            The mask for valid antecedents.
            (batch_size, num_spans_to_keep, max_antecedents).
        top_antecedent_offsets : `torch.FloatTensor`, required.
            The distance between the span and each of its antecedents in terms of the number
            of considered spans (i.e. not the word distance between the spans).
            (batch_size, num_spans_to_keep, max_antecedents).

        # Returns

        coreference_scores : `torch.FloatTensor`
            A tensor of shape (batch_size, num_spans_to_keep, max_antecedents + 1),
            representing the unnormalised score for each (span, antecedent) pair
            we considered.
        """
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, top_antecedent_embeddings, top_antecedent_offsets
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        antecedent_scores = self._antecedent_scorer(
            self._antecedent_feedforward(span_pair_embeddings)
        ).squeeze(-1)
        antecedent_scores += top_partial_coreference_scores
        antecedent_scores = util.replace_masked_values(
            antecedent_scores,
            top_antecedent_mask,
            util.min_value_of_dtype(antecedent_scores.dtype),
        )

        # Shape: (batch_size, num_spans_to_keep, 1)
        shape = [antecedent_scores.size(0), antecedent_scores.size(1), 1]
        dummy_scores = antecedent_scores.new_zeros(*shape)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        coreference_scores = torch.cat([dummy_scores, antecedent_scores], -1)
        return coreference_scores

    default_predictor = "coreference_resolution"
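
A minimal toy-tensor sketch (not part of the repository) of the _compute_antecedent_gold_labels logic above, with a batch of one, two kept spans and two candidate antecedents per span; a cluster id of -1 marks a span outside any gold cluster:

import torch

top_span_labels = torch.tensor([[[0], [-1]]])             # (1, 2, 1)
antecedent_labels = torch.tensor([[[0, 1],                 # antecedents of span 0
                                   [0, 1]]])               # antecedents of span 1

target_labels = top_span_labels.expand_as(antecedent_labels)
pairwise = ((target_labels == antecedent_labels).float()
            * (target_labels >= 0).float())                # [[[1., 0.], [0., 0.]]]
dummy = (1 - pairwise).prod(-1, keepdim=True)              # [[[0.], [1.]]]
print(torch.cat([dummy, pairwise], -1))                    # [[[0., 1., 0.], [1., 0., 0.]]]

Span 0 shares cluster 0 with its first antecedent, so its dummy label is 0; span 1 has no gold antecedent, so only its dummy label fires. This mirrors _compute_coreference_scores, where the dummy column is pinned at score zero: a span whose real antecedent scores all fall below zero is predicted as having no antecedent.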
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/export/example.py
ADDED
@@ -0,0 +1,16 @@
import allennlp_models.coref  # Registers the coref model and predictor

from allennlp.models.archival import load_archive
from allennlp.predictors.predictor import Predictor

archive_path = "/models/minillm/model.tar.gz"

archive = load_archive(archive_path)
predictor = Predictor.from_archive(archive, predictor_name="coreference_resolution")

text = (
    "Barack Obama was the 44th President of the United States. He was born in Hawaii."
)
result = predictor.predict(document=text)

print(result["clusters"])
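
Since the repository ships model.onnx alongside this example, the exported graph can also be sanity-checked without AllenNLP. A minimal sketch (assumes onnxruntime is installed; the input and output names follow export_onnx.py below):

import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Expect token_ids, mask, type_ids, wordpiece_mask, segment_concat_mask,
# offsets and spans as inputs (see export_onnx.py).
for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)

# Expect top_spans, antecedent_indices and predicted_antecedents as outputs.
for out in session.get_outputs():
    print(out.name, out.shape)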
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/export/export_onnx.py
ADDED
@@ -0,0 +1,166 @@
import torch
import torch.nn as nn

from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor
from allennlp.data import Batch

import allennlp_models.coref  # Registers coref models and readers


class CorefONNXWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        token_ids,
        mask,
        type_ids,
        wordpiece_mask,
        segment_concat_mask,
        offsets,
        spans,
    ):
        text = {
            "tokens": {
                "token_ids": token_ids,
                "mask": mask,
                "type_ids": type_ids,
                "wordpiece_mask": wordpiece_mask,
                "segment_concat_mask": segment_concat_mask,
                "offsets": offsets,
            }
        }
        output = self.model(text=text, spans=spans)
        return (
            output["top_spans"],
            output["antecedent_indices"],
            output["predicted_antecedents"],
        )


def pad_spans(spans: torch.Tensor, max_len: int) -> torch.Tensor:
    """
    Pads a tensor of spans to max_len along dimension 0.

    Args:
        spans: Tensor of shape (num_spans, 2)
        max_len: Desired number of spans (along dim 0)

    Returns:
        Tensor of shape (max_len, 2)
    """
    num_spans = spans.size(0)

    if num_spans >= max_len:
        return spans[:max_len]
    else:
        padding = torch.zeros(
            (max_len - num_spans, 2), dtype=spans.dtype, device=spans.device
        )
        return torch.cat([spans, padding], dim=0)


def export_model_to_onnx(archive_path: str, onnx_path: str):
    # Load archive and predictor
    archive = load_archive(archive_path)
    predictor = Predictor.from_archive(archive, predictor_name="coreference_resolution")
    model = predictor._model
    dataset_reader = predictor._dataset_reader

    # Example input
    input_json = {
        "document": ["My", "sister", "has", "a", "dog", ".", "She", "loves", "him", "."]
    }

    # Convert input to Instance and batch
    instance = dataset_reader.text_to_instance(input_json["document"])
    instances = [instance]
    dataset_reader.apply_token_indexers(instances)

    batch = Batch(instances)
    batch.index_instances(model.vocab)
    tensor_dict = batch.as_tensor_dict()

    # Filter only required args for forward()
    model_input = {
        "text": tensor_dict["text"],  # Nested dict of token tensors
        "spans": tensor_dict["spans"],  # Tensor of span indices
    }

    for k, v in model_input["text"]["tokens"].items():
        print(k, v.shape)

    print("spans", model_input["spans"].shape)
    print(model_input["spans"])

    # Move to CPU and eval mode
    device = torch.device("cpu")
    model = model.to(device).eval()

    for k, v in model_input.items():
        if isinstance(v, torch.Tensor):
            model_input[k] = v.to(device)
        elif isinstance(v, dict):
            model_input[k] = {
                kk: vv.to(device) if isinstance(vv, torch.Tensor) else vv
                for kk, vv in v.items()
            }

    # Wrap and prepare export
    wrapper = CorefONNXWrapper(model)
    max_num_spans = 300  # <-- or any upper bound you want
    padded_spans = pad_spans(model_input["spans"].squeeze(0), max_num_spans).unsqueeze(
        0
    )

    example_inputs = (
        model_input["text"]["tokens"]["token_ids"],
        model_input["text"]["tokens"]["mask"],
        model_input["text"]["tokens"]["type_ids"],
        model_input["text"]["tokens"]["wordpiece_mask"],
        model_input["text"]["tokens"]["segment_concat_mask"],
        model_input["text"]["tokens"]["offsets"],
        padded_spans,
    )
    torch.onnx.export(
        wrapper,
        args=example_inputs,
        f=onnx_path,
        input_names=[
            "token_ids",
            "mask",
            "type_ids",
            "wordpiece_mask",
            "segment_concat_mask",
            "offsets",
            "spans",
        ],
        output_names=["top_spans", "antecedent_indices", "predicted_antecedents"],
        dynamic_axes={
            "token_ids": {0: "batch_size", 1: "seq_len"},
            "mask": {0: "batch_size", 1: "orig_seq_len"},
            "type_ids": {0: "batch_size", 1: "seq_len"},
            "wordpiece_mask": {0: "batch_size", 1: "seq_len"},
            "segment_concat_mask": {0: "batch_size", 1: "seq_len"},
            "offsets": {0: "batch_size", 1: "orig_seq_len"},
            "spans": {0: "batch_size", 1: "num_spans"},
            "top_spans": {0: "batch_size", 1: "num_spans_to_keep"},
            "antecedent_indices": {
                0: "batch_size",
                1: "num_spans_to_keep",
                2: "max_antecedents",
            },
            "predicted_antecedents": {0: "batch_size", 1: "num_spans_to_keep"},
        },
        opset_version=15,
        do_constant_folding=True,
    )


if __name__ == "__main__":
    archive_path = "/models/minillm/model.tar.gz"
    onnx_path = "/models/minillm/model.onnx"
    export_model_to_onnx(archive_path, onnx_path)
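
torch.onnx.export traces the wrapper on this single example, so it is worth checking that the exported graph reproduces the PyTorch outputs. A verification sketch, not part of the script above; it assumes onnxruntime is installed and would run at the end of export_model_to_onnx, where wrapper, example_inputs and onnx_path are in scope:

import numpy as np
import onnxruntime as ort

with torch.no_grad():
    torch_outputs = wrapper(*example_inputs)

session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
input_names = [
    "token_ids", "mask", "type_ids", "wordpiece_mask",
    "segment_concat_mask", "offsets", "spans",
]
onnx_outputs = session.run(
    None, {name: t.numpy() for name, t in zip(input_names, example_inputs)}
)

# All three outputs are index tensors, so they should match exactly
# on the example the graph was traced with.
for torch_out, onnx_out in zip(torch_outputs, onnx_outputs):
    np.testing.assert_array_equal(torch_out.numpy(), onnx_out)

Because export relies on tracing, parity on one example does not rule out divergence on inputs with very different lengths or span counts; the fixed max_num_spans padding above is part of managing that.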
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/model.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:be2ad55e2b36c7e4e007aadc01d151e3d1f563c8f62a349736a17cb5ea27abe4
size 522012569
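
The three lines above are a Git LFS pointer: the actual ONNX graph (~522 MB) lives in LFS storage, and a clone without git-lfs only retrieves this stub. One way to fetch the real file is via huggingface_hub; a sketch assuming the upstream repo from source.txt keeps model.onnx at its root:

from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="talmago/allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large",
    filename="model.onnx",
)
print(model_path)  # Local cache path of the downloaded weights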
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/source.txt
ADDED
@@ -0,0 +1 @@
https://huggingface.co/talmago/allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/tokenizer.json
ADDED
The diff for this file is too large to render.