niobures committed (verified) · Commit 0a2677f · 1 parent: 8dacc12

allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large

allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/.gitattributes ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/README.md ADDED
@@ -0,0 +1,89 @@
---
license: mit
base_model:
- nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large
pipeline_tag: token-classification
tags:
- coreference-resolution
- multilingual
- onnx
---

## Usage

```sh
$ pip install coref-onnx
```

```python
from coref_onnx import CoreferenceResolver, decode_clusters

resolver = CoreferenceResolver.from_pretrained("talmago/allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large")

sentences = [
    ["Barack", "Obama", "was", "the", "44th", "President", "of", "the", "United", "States", "."],
    ["He", "was", "born", "in", "Hawaii", "."]
]

pred = resolver(sentences)

print("Clusters:", pred["clusters"][0])
print("Decoded clusters:", decode_clusters(sentences, pred["clusters"][0]))
```

Output:

```
Clusters: [[[(0, 1), (11, 11)]]]
Decoded clusters: [['Barack Obama', 'He']]
```

Span indices are inclusive token offsets into the flattened document, so `(0, 1)` covers "Barack Obama" and `(11, 11)` is "He".

## ONNX

Download the MiniLM model archive:

```sh
$ mkdir -p models/minillm
$ wget -P models/minillm https://storage.googleapis.com/pandora-intelligence/models/crosslingual-coreference/minilm/model.tar.gz
```

Run the docker container:

```sh
$ docker run -it --platform linux/amd64 --entrypoint /bin/bash -v $(pwd)/models/minillm:/models/minillm allennlp/allennlp:latest
```

Install `allennlp_models`:

```sh
$ pip install allennlp_models
```

From another tab, copy the patched source files and the export script into the container:

```sh
$ docker cp allennlp_models/coref/models/coref.py <CONTAINER_ID>:/opt/conda/lib/python3.8/site-packages/allennlp_models/coref/models/coref.py
$ docker cp allennlp/nn/util.py <CONTAINER_ID>:/stage/allennlp/allennlp/nn/util.py
$ docker cp export_onnx.py <CONTAINER_ID>:/app/export_onnx.py
```

In the container, run:

```sh
$ mkdir nreimers
$ git clone https://huggingface.co/nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large nreimers
```

And then run the export script:

```sh
$ python export_onnx.py
```

## Model Optimization

Run `onnxsim` to simplify the exported graph:

```sh
$ python -m onnxsim models/minillm/model.onnx optimized_model.onnx
```
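
After simplification, the optimized graph can also be run directly with ONNX Runtime for debugging. Below is a minimal sketch, assuming `onnxruntime` and `numpy` are installed; the input and output names come from the `torch.onnx.export` call in `export_onnx.py` further down, but the zero-filled tensors are placeholders standing in for a real tokenized document, not meaningful inputs:

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("optimized_model.onnx")

# Shapes follow the dynamic axes declared in export_onnx.py:
# seq_len = wordpieces, n_words = original tokens, n_spans = padded span count.
seq_len, n_words, n_spans = 16, 10, 300
feeds = {
    "token_ids": np.zeros((1, seq_len), dtype=np.int64),
    "mask": np.ones((1, n_words), dtype=bool),
    "type_ids": np.zeros((1, seq_len), dtype=np.int64),
    "wordpiece_mask": np.ones((1, seq_len), dtype=bool),
    "segment_concat_mask": np.ones((1, seq_len), dtype=bool),
    "offsets": np.zeros((1, n_words, 2), dtype=np.int64),
    "spans": np.zeros((1, n_spans, 2), dtype=np.int64),
}
top_spans, antecedent_indices, predicted_antecedents = sess.run(None, feeds)
print(top_spans.shape, antecedent_indices.shape, predicted_antecedents.shape)
```

In normal use the `coref-onnx` package from the Usage section presumably handles tokenization and input preparation; the raw session is mainly useful for checking that the exported graph loads and produces outputs of the expected rank.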
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/config.json ADDED
@@ -0,0 +1,19 @@
{
    "max_span_width": 5,
    "max_spans": 512,
    "spans_per_word": 0.4,
    "max_antecedents": 50,
    "inference_order": 2,
    "feature_size": 20,
    "coarse_to_fine": true,
    "model_hidden_dims": 1500,
    "model_dropout": 0.3,
    "mention_input_dim": 1172,
    "antecedent_input_dim": 3536,
    "transformer": {
        "model_name": "nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large",
        "hidden_size": 384,
        "max_position_embeddings": 512
    },
    "max_sentences": 120
}
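
The two `*_input_dim` values are not free parameters: they follow from how `coref.py` (below) assembles span representations. The endpoint extractor (`combination="x,y"`) contributes `2 * hidden_size`, the span-width embedding adds `feature_size`, and the self-attentive extractor adds another `hidden_size`; an antecedent pair then concatenates `[g_i, g_j, g_i * g_j]` with a `feature_size` distance embedding. A quick consistency check in plain Python (illustrative only):

```python
hidden_size = 384    # transformer.hidden_size
feature_size = 20    # width / distance embedding size

# Endpoint extractor ("x,y") + span-width embedding + self-attentive extractor.
mention_input_dim = 2 * hidden_size + feature_size + hidden_size
assert mention_input_dim == 1172

# [g_i, g_j, g_i * g_j] concatenated with the distance embedding.
antecedent_input_dim = 3 * mention_input_dim + feature_size
assert antecedent_input_dim == 3536
```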
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/export/allennlp/nn/util.py ADDED
The diff for this file is too large to render. See raw diff
 
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/export/allennlp_models/coref/models/coref.py ADDED
@@ -0,0 +1,943 @@
import logging
import math
from typing import Any, Dict, List, Tuple

import torch
import torch.nn.functional as F


from allennlp.data import TextFieldTensors, Vocabulary
from allennlp.models.model import Model
from allennlp.modules.token_embedders import Embedding
from allennlp.modules import FeedForward, GatedSum
from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder
from allennlp.modules.span_extractors import (
    SelfAttentiveSpanExtractor,
    EndpointSpanExtractor,
)
from allennlp.nn import util, InitializerApplicator

from allennlp_models.coref.metrics.conll_coref_scores import ConllCorefScores
from allennlp_models.coref.metrics.mention_recall import MentionRecall

logger = logging.getLogger(__name__)


@Model.register("coref")
class CoreferenceResolver(Model):
    """
    This `Model` implements the coreference resolution model described in
    [Higher-order Coreference Resolution with Coarse-to-fine Inference](https://arxiv.org/pdf/1804.05392.pdf)
    by Lee et al., 2018.
    The basic outline of this model is to get an embedded representation of each span in the
    document. These span representations are scored and used to prune away spans that are unlikely
    to occur in a coreference cluster. For the remaining spans, the model decides which antecedent
    span (if any) they are coreferent with. The resulting coreference links, after applying
    transitivity, imply a clustering of the spans in the document.

    # Parameters

    vocab : `Vocabulary`
    text_field_embedder : `TextFieldEmbedder`
        Used to embed the `text` `TextField` we get as input to the model.
    context_layer : `Seq2SeqEncoder`
        This layer incorporates contextual information for each word in the document.
    mention_feedforward : `FeedForward`
        This feedforward network is applied to the span representations which is then scored
        by a linear layer.
    antecedent_feedforward : `FeedForward`
        This feedforward network is applied to pairs of span representations, along with any
        pairwise features, which is then scored by a linear layer.
    feature_size : `int`
        The embedding size for all the embedded features, such as distances or span widths.
    max_span_width : `int`
        The maximum width of candidate spans.
    spans_per_word: `float`, required.
        A multiplier between zero and one which controls what percentage of candidate mention
        spans we retain with respect to the number of words in the document.
    max_antecedents: `int`, required.
        For each mention which survives the pruning stage, we consider this many antecedents.
    coarse_to_fine: `bool`, optional (default = `False`)
        Whether or not to apply the coarse-to-fine filtering.
    inference_order: `int`, optional (default = `1`)
        The number of inference orders. When greater than 1, the span representations are
        updated and coreference scores re-computed.
    lexical_dropout : `float`
        The probability of dropping out dimensions of the embedded text.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        context_layer: Seq2SeqEncoder,
        mention_feedforward: FeedForward,
        antecedent_feedforward: FeedForward,
        feature_size: int,
        max_span_width: int,
        spans_per_word: float,
        max_antecedents: int,
        coarse_to_fine: bool = False,
        inference_order: int = 1,
        lexical_dropout: float = 0.2,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        self._text_field_embedder = text_field_embedder
        self._context_layer = context_layer
        self._mention_feedforward = TimeDistributed(mention_feedforward)
        self._mention_scorer = TimeDistributed(
            torch.nn.Linear(mention_feedforward.get_output_dim(), 1)
        )
        self._antecedent_feedforward = TimeDistributed(antecedent_feedforward)
        self._antecedent_scorer = TimeDistributed(
            torch.nn.Linear(antecedent_feedforward.get_output_dim(), 1)
        )

        self._endpoint_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False,
        )
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=text_field_embedder.get_output_dim()
        )

        # 10 possible distance buckets.
        self._num_distance_buckets = 10
        self._distance_embedding = Embedding(
            embedding_dim=feature_size, num_embeddings=self._num_distance_buckets
        )

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        self._max_antecedents = max_antecedents

        self._coarse_to_fine = coarse_to_fine
        if self._coarse_to_fine:
            self._coarse2fine_scorer = torch.nn.Linear(
                mention_feedforward.get_input_dim(), mention_feedforward.get_input_dim()
            )
        self._inference_order = inference_order
        if self._inference_order > 1:
            self._span_updating_gated_sum = GatedSum(
                mention_feedforward.get_input_dim()
            )

        self._mention_recall = MentionRecall()
        self._conll_coref_scores = ConllCorefScores()
        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x
        initializer(self)

    def forward(
        self,  # type: ignore
        text: TextFieldTensors,
        spans: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        text : `TextFieldTensors`, required.
            The output of a `TextField` representing the text of
            the document.
        spans : `torch.IntTensor`, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a `ListField[SpanField]` of
            indices into the text of the document.
        span_labels : `torch.IntTensor`, optional (default = `None`).
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.
        metadata : `List[Dict[str, Any]]`, optional (default = `None`).
            A metadata dictionary for each instance in the batch. We use the "original_text" and "clusters" keys
            from this dictionary, which respectively have the original text and the annotated gold coreference
            clusters for that instance.

        # Returns

        An output dictionary consisting of:

        top_spans : `torch.IntTensor`
            A tensor of shape `(batch_size, num_spans_to_keep, 2)` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : `torch.IntTensor`
            A tensor of shape `(num_spans_to_keep, max_antecedents)` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : `torch.IntTensor`
            A tensor of shape `(batch_size, num_spans_to_keep)` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        batch_size = spans.shape[0]
        document_length = text_embeddings.shape[1]
        num_spans = spans.shape[1]

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text)

        # Shape: (batch_size, num_spans)
        # span_mask = (spans[:, :, 0] >= 0).squeeze(-1)
        span_mask = spans[:, :, 0] >= 0

        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans
        )
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans
        )

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1
        )

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
        num_spans_to_keep = min(num_spans_to_keep, num_spans)
        # num_spans_to_keep = num_spans

        # Shape: (batch_size, num_spans)
        span_mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings)
        ).squeeze(-1)
        k = torch.full(
            (batch_size,), num_spans_to_keep, dtype=torch.long, device=spans.device
        )
        # Shape: (batch_size, num_spans_to_keep) for all 3 tensors
        top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
            span_mention_scores, span_mask, k, dim=1
        )

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans
        )

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(
            spans, top_span_indices, flat_top_span_indices
        )

        # Shape: (batch_size, num_spans_to_keep, embedding_size)
        top_span_embeddings = util.batched_index_select(
            span_embeddings, top_span_indices, flat_top_span_indices
        )

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        # index to the indices of its allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        # we can use to make coreference decisions between valid span pairs.

        if self._coarse_to_fine:
            pruned_antecedents = self._coarse_to_fine_pruning(
                top_span_embeddings,
                top_span_mention_scores,
                top_span_mask,
                max_antecedents,
            )
        else:
            pruned_antecedents = self._distance_pruning(
                top_span_embeddings, top_span_mention_scores, max_antecedents
            )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents) for all 4 tensors
        (
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
            top_antecedent_indices,
        ) = pruned_antecedents

        flat_top_antecedent_indices = util.flatten_and_batch_shift_indices(
            top_antecedent_indices, num_spans_to_keep
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        top_antecedent_embeddings = util.batched_index_select(
            top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
        )
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            top_span_embeddings,
            top_antecedent_embeddings,
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
        )

        for _ in range(self._inference_order - 1):
            dummy_mask = top_antecedent_mask.new_ones(batch_size, num_spans_to_keep, 1)
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents,)
            top_antecedent_with_dummy_mask = torch.cat(
                [dummy_mask, top_antecedent_mask], -1
            )
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
            attention_weight = util.masked_softmax(
                coreference_scores,
                top_antecedent_with_dummy_mask,
                memory_efficient=True,
            )
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents, embedding_size)
            top_antecedent_with_dummy_embeddings = torch.cat(
                [top_span_embeddings.unsqueeze(2), top_antecedent_embeddings], 2
            )
            # Shape: (batch_size, num_spans_to_keep, embedding_size)
            attended_embeddings = util.weighted_sum(
                top_antecedent_with_dummy_embeddings, attention_weight
            )
            # Shape: (batch_size, num_spans_to_keep, embedding_size)
            top_span_embeddings = self._span_updating_gated_sum(
                top_span_embeddings, attended_embeddings
            )

            # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
            top_antecedent_embeddings = util.batched_index_select(
                top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
            )
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
            coreference_scores = self._compute_coreference_scores(
                top_span_embeddings,
                top_antecedent_embeddings,
                top_partial_coreference_scores,
                top_antecedent_mask,
                top_antecedent_offsets,
            )

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": top_antecedent_indices,
            "predicted_antecedents": predicted_antecedents,
        }
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            # Shape: (batch_size, num_spans_to_keep, 1)
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices
            )

            # Shape: (batch_size, num_spans_to_keep, max_antecedents)
            antecedent_labels = util.batched_index_select(
                pruned_gold_labels, top_antecedent_indices, flat_top_antecedent_indices
            ).squeeze(-1)
            antecedent_labels = util.replace_masked_values(
                antecedent_labels, top_antecedent_mask, -100
            )

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels
            )
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            # the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask.unsqueeze(-1)
            )
            correct_antecedent_log_probs = (
                coreference_log_probs + gold_antecedent_labels.log()
            )
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs
            ).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(
                top_spans, top_antecedent_indices, predicted_antecedents, metadata
            )

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict

    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]):
        """
        Converts the list of spans and predicted antecedent indices into clusters
        of spans for each element in the batch.

        # Parameters

        output_dict : `Dict[str, torch.Tensor]`, required.
            The result of calling :func:`forward` on an instance or batch of instances.

        # Returns

        The same output dictionary, but with an additional `clusters` key:

        clusters : `List[List[List[Tuple[int, int]]]]`
            A nested list, representing, for each instance in the batch, the list of clusters,
            which are in turn comprised of a list of (start, end) inclusive spans into the
            original document.
        """

        # A tensor of shape (batch_size, num_spans_to_keep, 2), representing
        # the start and end indices of each span.
        batch_top_spans = output_dict["top_spans"].detach().cpu()

        # A tensor of shape (batch_size, num_spans_to_keep) representing, for each span,
        # the index into `antecedent_indices` which specifies the antecedent span. Additionally,
        # the index can be -1, specifying that the span has no predicted antecedent.
        batch_predicted_antecedents = (
            output_dict["predicted_antecedents"].detach().cpu()
        )

        # A tensor of shape (num_spans_to_keep, max_antecedents), representing the indices
        # of the predicted antecedents with respect to the 2nd dimension of `batch_top_spans`
        # for each antecedent we considered.
        batch_antecedent_indices = output_dict["antecedent_indices"].detach().cpu()
        batch_clusters: List[List[List[Tuple[int, int]]]] = []

        # Calling zip() on two tensors results in an iterator over their
        # first dimension. This is iterating over instances in the batch.
        for top_spans, predicted_antecedents, antecedent_indices in zip(
            batch_top_spans, batch_predicted_antecedents, batch_antecedent_indices
        ):
            spans_to_cluster_ids: Dict[Tuple[int, int], int] = {}
            clusters: List[List[Tuple[int, int]]] = []

            for i, (span, predicted_antecedent) in enumerate(
                zip(top_spans, predicted_antecedents)
            ):
                if predicted_antecedent < 0:
                    # We don't care about spans which are
                    # not co-referent with anything.
                    continue

                # Find the right cluster to update with this span.
                # To do this, we find the row in `antecedent_indices`
                # corresponding to this span we are considering.
                # The predicted antecedent is then an index into this list
                # of indices, denoting the span from `top_spans` which is the
                # most likely antecedent.
                predicted_index = antecedent_indices[i, predicted_antecedent]

                antecedent_span = (
                    top_spans[predicted_index, 0].item(),
                    top_spans[predicted_index, 1].item(),
                )

                # Check if we've seen the span before.
                if antecedent_span in spans_to_cluster_ids:
                    predicted_cluster_id: int = spans_to_cluster_ids[antecedent_span]
                else:
                    # We start a new cluster.
                    predicted_cluster_id = len(clusters)
                    # Append a new cluster containing only this span.
                    clusters.append([antecedent_span])
                    # Record the new id of this span.
                    spans_to_cluster_ids[antecedent_span] = predicted_cluster_id

                # Now add the span we are currently considering.
                span_start, span_end = span[0].item(), span[1].item()
                clusters[predicted_cluster_id].append((span_start, span_end))
                spans_to_cluster_ids[(span_start, span_end)] = predicted_cluster_id
            batch_clusters.append(clusters)

        output_dict["clusters"] = batch_clusters
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        mention_recall = self._mention_recall.get_metric(reset)
        coref_precision, coref_recall, coref_f1 = self._conll_coref_scores.get_metric(
            reset
        )

        return {
            "coref_precision": coref_precision,
            "coref_recall": coref_recall,
            "coref_f1": coref_f1,
            "mention_recall": mention_recall,
        }

    @staticmethod
    def _generate_valid_antecedents(
        num_spans_to_keep: int, max_antecedents: int, device: int
    ) -> Tuple[torch.IntTensor, torch.IntTensor, torch.BoolTensor]:
        """
        This method generates possible antecedents per span which survived the pruning
        stage. This procedure is `generic across the batch`. The reason this is the case is
        that each span in a batch can be coreferent with any previous span, but here we
        are computing the possible `indices` of these spans. So, regardless of the batch,
        the 1st span _cannot_ have any antecedents, because there are none to select from.
        Similarly, each element can only predict previous spans, so this returns a matrix
        of shape (num_spans_to_keep, max_antecedents), where the (i,j)-th index is equal to
        (i - 1) - j if j <= i, or zero otherwise.

        # Parameters

        num_spans_to_keep : `int`, required.
            The number of spans that were kept while pruning.
        max_antecedents : `int`, required.
            The maximum number of antecedent spans to consider for every span.
        device : `int`, required.
            The CUDA device to use.

        # Returns

        valid_antecedent_indices : `torch.LongTensor`
            The indices of every antecedent to consider with respect to the top k spans.
            Has shape `(num_spans_to_keep, max_antecedents)`.
        valid_antecedent_offsets : `torch.LongTensor`
            The distance between the span and each of its antecedents in terms of the number
            of considered spans (i.e. not the word distance between the spans).
            Has shape `(1, max_antecedents)`.
        valid_antecedent_mask : `torch.BoolTensor`
            The mask representing whether each antecedent span is valid. Required since
            different spans have different numbers of valid antecedents. For example, the first
            span in the document should have no valid antecedents.
            Has shape `(1, num_spans_to_keep, max_antecedents)`.
        """
        # Shape: (num_spans_to_keep, 1)
        target_indices = util.get_range_vector(num_spans_to_keep, device).unsqueeze(1)

        # Shape: (1, max_antecedents)
        valid_antecedent_offsets = (
            util.get_range_vector(max_antecedents, device) + 1
        ).unsqueeze(0)

        # This is a broadcasted subtraction.
        # Shape: (num_spans_to_keep, max_antecedents)
        raw_antecedent_indices = target_indices - valid_antecedent_offsets

        # In our matrix of indices, the upper triangular part will be negative
        # because the offsets will be > the target indices. We want to mask these,
        # because these are exactly the indices which we don't want to predict, per span.
        # Shape: (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_mask = (raw_antecedent_indices >= 0).unsqueeze(0)

        # Shape: (num_spans_to_keep, max_antecedents)
        valid_antecedent_indices = F.relu(raw_antecedent_indices.float()).long()
        return valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_mask

    def _distance_pruning(
        self,
        top_span_embeddings: torch.FloatTensor,
        top_span_mention_scores: torch.FloatTensor,
        max_antecedents: int,
    ) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.LongTensor, torch.LongTensor]:
        """
        Generates antecedents for each span and prunes down to `max_antecedents`. This method
        prunes antecedents only based on distance (i.e. number of intervening spans). The closest
        antecedents are kept.

        # Parameters

        top_span_embeddings: `torch.FloatTensor`, required.
            The embeddings of the top spans.
            (batch_size, num_spans_to_keep, embedding_size).
        top_span_mention_scores: `torch.FloatTensor`, required.
            The mention scores of the top spans.
            (batch_size, num_spans_to_keep).
        max_antecedents: `int`, required.
            The maximum number of antecedents to keep for each span.

        # Returns

        top_partial_coreference_scores: `torch.FloatTensor`
            The partial antecedent scores for each span-antecedent pair. Computed by summing
            the span mention scores of the span and the antecedent. This score is partial because
            compared to the full coreference scores, it lacks the interaction term
            w * FFNN([g_i, g_j, g_i * g_j, features]).
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_mask: `torch.BoolTensor`
            The mask representing whether each antecedent span is valid. Required since
            different spans have different numbers of valid antecedents. For example, the first
            span in the document should have no valid antecedents.
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_offsets: `torch.LongTensor`
            The distance between the span and each of its antecedents in terms of the number
            of considered spans (i.e. not the word distance between the spans).
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_indices: `torch.LongTensor`
            The indices of every antecedent to consider with respect to the top k spans.
            (batch_size, num_spans_to_keep, max_antecedents)
        """
        # These antecedent matrices are independent of the batch dimension - they're just a function
        # of the span's position in top_spans.
        # The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        num_spans_to_keep = top_span_embeddings.size(1)
        device = util.get_device_of(top_span_embeddings)

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        (
            top_antecedent_indices,
            top_antecedent_offsets,
            top_antecedent_mask,
        ) = self._generate_valid_antecedents(  # noqa
            num_spans_to_keep, max_antecedents, device
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores.unsqueeze(-1), top_antecedent_indices
        ).squeeze(-1)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents) * 4
        top_partial_coreference_scores = (
            top_span_mention_scores.unsqueeze(-1) + top_antecedent_mention_scores
        )
        top_antecedent_indices = top_antecedent_indices.unsqueeze(0).expand_as(
            top_partial_coreference_scores
        )
        top_antecedent_offsets = top_antecedent_offsets.unsqueeze(0).expand_as(
            top_partial_coreference_scores
        )
        top_antecedent_mask = top_antecedent_mask.expand_as(
            top_partial_coreference_scores
        )

        return (
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
            top_antecedent_indices,
        )

    def _coarse_to_fine_pruning(
        self,
        top_span_embeddings: torch.FloatTensor,
        top_span_mention_scores: torch.FloatTensor,
        top_span_mask: torch.BoolTensor,
        max_antecedents: int,
    ) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.LongTensor, torch.LongTensor]:
        """
        Generates antecedents for each span and prunes down to `max_antecedents`. This method
        prunes antecedents using a fast bilinear interaction score between a span and a candidate
        antecedent, and the highest-scoring antecedents are kept.

        # Parameters

        top_span_embeddings: `torch.FloatTensor`, required.
            The embeddings of the top spans.
            (batch_size, num_spans_to_keep, embedding_size).
        top_span_mention_scores: `torch.FloatTensor`, required.
            The mention scores of the top spans.
            (batch_size, num_spans_to_keep).
        top_span_mask: `torch.BoolTensor`, required.
            The mask for the top spans.
            (batch_size, num_spans_to_keep).
        max_antecedents: `int`, required.
            The maximum number of antecedents to keep for each span.

        # Returns

        top_partial_coreference_scores: `torch.FloatTensor`
            The partial antecedent scores for each span-antecedent pair. Computed by summing
            the span mention scores of the span and the antecedent as well as a bilinear
            interaction term. This score is partial because compared to the full coreference scores,
            it lacks the interaction term
            `w * FFNN([g_i, g_j, g_i * g_j, features])`.
            `(batch_size, num_spans_to_keep, max_antecedents)`
        top_antecedent_mask: `torch.BoolTensor`
            The mask representing whether each antecedent span is valid. Required since
            different spans have different numbers of valid antecedents. For example, the first
            span in the document should have no valid antecedents.
            `(batch_size, num_spans_to_keep, max_antecedents)`
        top_antecedent_offsets: `torch.LongTensor`
            The distance between the span and each of its antecedents in terms of the number
            of considered spans (i.e. not the word distance between the spans).
            `(batch_size, num_spans_to_keep, max_antecedents)`
        top_antecedent_indices: `torch.LongTensor`
            The indices of every antecedent to consider with respect to the top k spans.
            `(batch_size, num_spans_to_keep, max_antecedents)`
        """
        # batch_size, num_spans_to_keep = top_span_embeddings.size()[:2]
        batch_size, num_spans_to_keep = top_span_embeddings.shape[:2]
        device = util.get_device_of(top_span_embeddings)

        # Shape: (1, num_spans_to_keep, num_spans_to_keep)
        _, _, valid_antecedent_mask = self._generate_valid_antecedents(
            num_spans_to_keep, num_spans_to_keep, device
        )

        mention_one_score = top_span_mention_scores.unsqueeze(1)
        mention_two_score = top_span_mention_scores.unsqueeze(2)
        bilinear_weights = self._coarse2fine_scorer(top_span_embeddings).transpose(1, 2)
        bilinear_score = torch.matmul(top_span_embeddings, bilinear_weights)
        # Shape: (batch_size, num_spans_to_keep, num_spans_to_keep); broadcast op
        partial_antecedent_scores = (
            mention_one_score + mention_two_score + bilinear_score
        )

        # Shape: (batch_size, num_spans_to_keep, num_spans_to_keep); broadcast op
        span_pair_mask = top_span_mask.unsqueeze(-1) & valid_antecedent_mask

        # Shape:
        # (batch_size, num_spans_to_keep, max_antecedents) * 3
        # k_tensor = torch.full((batch_size,), max_antecedents, dtype=torch.long, device=top_span_embeddings.device)
        k_tensor = torch.full(
            (batch_size, num_spans_to_keep),
            max_antecedents,
            dtype=torch.long,
            device=top_span_embeddings.device,
        )

        (
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_indices,
        ) = util.masked_topk(
            partial_antecedent_scores, span_pair_mask, k_tensor, dim=-1
        )
        # (
        #     top_partial_coreference_scores,
        #     top_antecedent_mask,
        #     top_antecedent_indices,
        # ) = util.masked_topk(partial_antecedent_scores, span_pair_mask, max_antecedents)

        top_span_range = util.get_range_vector(num_spans_to_keep, device)
        # Shape: (num_spans_to_keep, num_spans_to_keep); broadcast op
        valid_antecedent_offsets = top_span_range.unsqueeze(
            -1
        ) - top_span_range.unsqueeze(0)

        # TODO: we need to make `batched_index_select` more general to make this less awkward.
        top_antecedent_offsets = util.batched_index_select(
            valid_antecedent_offsets.unsqueeze(0)
            .expand(batch_size, num_spans_to_keep, num_spans_to_keep)
            .reshape(batch_size * num_spans_to_keep, num_spans_to_keep, 1),
            top_antecedent_indices.view(-1, max_antecedents),
        ).reshape(batch_size, num_spans_to_keep, max_antecedents)

        return (
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
            top_antecedent_indices,
        )

    def _compute_span_pair_embeddings(
        self,
        top_span_embeddings: torch.FloatTensor,
        antecedent_embeddings: torch.FloatTensor,
        antecedent_offsets: torch.FloatTensor,
    ):
        """
        Computes an embedding representation of pairs of spans for the pairwise scoring function
        to consider. This includes both the original span representations, the element-wise
        similarity of the span representations, and an embedding representation of the distance
        between the two spans.

        # Parameters

        top_span_embeddings : `torch.FloatTensor`, required.
            Embedding representations of the top spans. Has shape
            (batch_size, num_spans_to_keep, embedding_size).
        antecedent_embeddings : `torch.FloatTensor`, required.
            Embedding representations of the antecedent spans we are considering
            for each top span. Has shape
            (batch_size, num_spans_to_keep, max_antecedents, embedding_size).
        antecedent_offsets : `torch.IntTensor`, required.
            The offsets between each top span and its antecedent spans in terms
            of spans we are considering. Has shape (batch_size, num_spans_to_keep, max_antecedents).

        # Returns

        span_pair_embeddings : `torch.FloatTensor`
            Embedding representation of the pair of spans to consider. Has shape
            (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        """
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        target_embeddings = top_span_embeddings.unsqueeze(2).expand_as(
            antecedent_embeddings
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        antecedent_distance_embeddings = self._distance_embedding(
            util.bucket_values(
                antecedent_offsets, num_total_buckets=self._num_distance_buckets
            )
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = torch.cat(
            [
                target_embeddings,
                antecedent_embeddings,
                antecedent_embeddings * target_embeddings,
                antecedent_distance_embeddings,
            ],
            -1,
        )
        return span_pair_embeddings

    @staticmethod
    def _compute_antecedent_gold_labels(
        top_span_labels: torch.IntTensor, antecedent_labels: torch.IntTensor
    ):
        """
        Generates a binary indicator for every pair of spans. This label is one if and
        only if the pair of spans belong to the same cluster. The labels are augmented
        with a dummy antecedent at the zeroth position, which represents the prediction
        that a span does not have any antecedent.

        # Parameters

        top_span_labels : `torch.IntTensor`, required.
            The cluster id label for every span. The id is arbitrary,
            as we just care about the clustering. Has shape (batch_size, num_spans_to_keep).
        antecedent_labels : `torch.IntTensor`, required.
            The cluster id label for every antecedent span. The id is arbitrary,
            as we just care about the clustering. Has shape
            (batch_size, num_spans_to_keep, max_antecedents).

        # Returns

        pairwise_labels_with_dummy_label : `torch.FloatTensor`
            A binary tensor representing whether a given pair of spans belong to
            the same cluster in the gold clustering.
            Has shape (batch_size, num_spans_to_keep, max_antecedents + 1).

        """
        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        target_labels = top_span_labels.expand_as(antecedent_labels)
        same_cluster_indicator = (target_labels == antecedent_labels).float()
        non_dummy_indicator = (target_labels >= 0).float()
        pairwise_labels = same_cluster_indicator * non_dummy_indicator

        # Shape: (batch_size, num_spans_to_keep, 1)
        dummy_labels = (1 - pairwise_labels).prod(-1, keepdim=True)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        pairwise_labels_with_dummy_label = torch.cat(
            [dummy_labels, pairwise_labels], -1
        )
        return pairwise_labels_with_dummy_label

    def _compute_coreference_scores(
        self,
        top_span_embeddings: torch.FloatTensor,
        top_antecedent_embeddings: torch.FloatTensor,
        top_partial_coreference_scores: torch.FloatTensor,
        top_antecedent_mask: torch.BoolTensor,
        top_antecedent_offsets: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Computes scores for every pair of spans. Additionally, a dummy label is included,
        representing the decision that the span is not coreferent with anything. For the dummy
        label, the score is always zero. For the true antecedent spans, the score consists of
        the pairwise antecedent score and the unary mention scores for the span and its
        antecedent. The factoring allows the model to blame many of the absent links on bad
        spans, enabling the pruning strategy used in the forward pass.

        # Parameters

        top_span_embeddings : `torch.FloatTensor`, required.
            Embedding representations of the kept spans. Has shape
            (batch_size, num_spans_to_keep, embedding_size)
        top_antecedent_embeddings: `torch.FloatTensor`, required.
            The embeddings of antecedents for each span candidate. Has shape
            (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        top_partial_coreference_scores : `torch.FloatTensor`, required.
            Sum of span mention score and antecedent mention score. The coarse-to-fine setting
            has an additional term which is the coarse bilinear score.
            (batch_size, num_spans_to_keep, max_antecedents).
        top_antecedent_mask : `torch.BoolTensor`, required.
            The mask for valid antecedents.
            (batch_size, num_spans_to_keep, max_antecedents).
        top_antecedent_offsets : `torch.FloatTensor`, required.
            The distance between the span and each of its antecedents in terms of the number
            of considered spans (i.e. not the word distance between the spans).
            (batch_size, num_spans_to_keep, max_antecedents).

        # Returns

        coreference_scores : `torch.FloatTensor`
            A tensor of shape (batch_size, num_spans_to_keep, max_antecedents + 1),
            representing the unnormalised score for each (span, antecedent) pair
            we considered.

        """
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, top_antecedent_embeddings, top_antecedent_offsets
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        antecedent_scores = self._antecedent_scorer(
            self._antecedent_feedforward(span_pair_embeddings)
        ).squeeze(-1)
        antecedent_scores += top_partial_coreference_scores
        antecedent_scores = util.replace_masked_values(
            antecedent_scores,
            top_antecedent_mask,
            util.min_value_of_dtype(antecedent_scores.dtype),
        )

        # Shape: (batch_size, num_spans_to_keep, 1)
        shape = [antecedent_scores.size(0), antecedent_scores.size(1), 1]
        dummy_scores = antecedent_scores.new_zeros(*shape)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        coreference_scores = torch.cat([dummy_scores, antecedent_scores], -1)
        return coreference_scores

    default_predictor = "coreference_resolution"
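
As a toy illustration of `_generate_valid_antecedents` (not part of the committed file): with 4 kept spans and up to 2 antecedents each, the broadcasted subtraction gives every span the indices of the spans immediately before it, with invalid negative positions masked and clamped to zero:

```python
import torch
import torch.nn.functional as F

num_spans_to_keep, max_antecedents = 4, 2
target_indices = torch.arange(num_spans_to_keep).unsqueeze(1)  # (4, 1)
offsets = (torch.arange(max_antecedents) + 1).unsqueeze(0)     # (1, 2)
raw = target_indices - offsets            # row i holds [i - 1, i - 2]
mask = raw >= 0                           # span 0 has no valid antecedents
indices = F.relu(raw.float()).long()      # clamp negatives to index 0

print(indices.tolist())  # [[0, 0], [0, 0], [1, 0], [2, 1]]
print(mask.tolist())     # [[False, False], [True, False], [True, True], [True, True]]
```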
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/export/example.py ADDED
@@ -0,0 +1,16 @@
import allennlp_models.coref

from allennlp.models.archival import load_archive
from allennlp.predictors.predictor import Predictor

archive_path = "/models/minillm/model.tar.gz"

archive = load_archive(archive_path)
predictor = Predictor.from_archive(archive, predictor_name="coreference_resolution")

text = (
    "Barack Obama was the 44th President of the United States. He was born in Hawaii."
)
result = predictor.predict(document=text)

print(result["clusters"])
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/export/export_onnx.py ADDED
@@ -0,0 +1,166 @@
import torch
import torch.nn as nn

from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor
from allennlp.data import Batch

import allennlp_models.coref  # Registers coref models and readers


class CorefONNXWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        token_ids,
        mask,
        type_ids,
        wordpiece_mask,
        segment_concat_mask,
        offsets,
        spans,
    ):
        text = {
            "tokens": {
                "token_ids": token_ids,
                "mask": mask,
                "type_ids": type_ids,
                "wordpiece_mask": wordpiece_mask,
                "segment_concat_mask": segment_concat_mask,
                "offsets": offsets,
            }
        }
        output = self.model(text=text, spans=spans)
        return (
            output["top_spans"],
            output["antecedent_indices"],
            output["predicted_antecedents"],
        )


def pad_spans(spans: torch.Tensor, max_len: int) -> torch.Tensor:
    """
    Pads a tensor of spans to max_len along dimension 0.

    Args:
        spans: Tensor of shape (num_spans, 2)
        max_len: Desired number of spans (along dim 0)

    Returns:
        Tensor of shape (max_len, 2)
    """
    num_spans = spans.size(0)

    if num_spans >= max_len:
        return spans[:max_len]
    else:
        padding = torch.zeros(
            (max_len - num_spans, 2), dtype=spans.dtype, device=spans.device
        )
        return torch.cat([spans, padding], dim=0)


def export_model_to_onnx(archive_path: str, onnx_path: str):
    # Load archive and predictor
    archive = load_archive(archive_path)
    predictor = Predictor.from_archive(archive, predictor_name="coreference_resolution")
    model = predictor._model
    dataset_reader = predictor._dataset_reader

    # Example input
    input_json = {
        "document": ["My", "sister", "has", "a", "dog", ".", "She", "loves", "him", "."]
    }

    # Convert input to Instance and batch
    instance = dataset_reader.text_to_instance(input_json["document"])
    instances = [instance]
    dataset_reader.apply_token_indexers(instances)

    batch = Batch(instances)
    batch.index_instances(model.vocab)
    tensor_dict = batch.as_tensor_dict()

    # Filter only required args for forward()
    model_input = {
        "text": tensor_dict["text"],  # Nested dict of token tensors
        "spans": tensor_dict["spans"],  # Tensor of span indices
    }

    for k, v in model_input["text"]["tokens"].items():
        print(k, v.shape)

    print("spans", model_input["spans"].shape)
    print(model_input["spans"])

    # Move to CPU and eval mode
    device = torch.device("cpu")
    model = model.to(device).eval()

    for k, v in model_input.items():
        if isinstance(v, torch.Tensor):
            model_input[k] = v.to(device)
        elif isinstance(v, dict):
            model_input[k] = {
                kk: vv.to(device) if isinstance(vv, torch.Tensor) else vv
                for kk, vv in v.items()
            }

    # Wrap and prepare export
    wrapper = CorefONNXWrapper(model)
    max_num_spans = 300  # <-- or any upper bound you want
    padded_spans = pad_spans(model_input["spans"].squeeze(0), max_num_spans).unsqueeze(
        0
    )

    example_inputs = (
        model_input["text"]["tokens"]["token_ids"],
        model_input["text"]["tokens"]["mask"],
        model_input["text"]["tokens"]["type_ids"],
        model_input["text"]["tokens"]["wordpiece_mask"],
        model_input["text"]["tokens"]["segment_concat_mask"],
        model_input["text"]["tokens"]["offsets"],
        padded_spans,
    )
    torch.onnx.export(
        wrapper,
        args=example_inputs,
        f=onnx_path,
        input_names=[
            "token_ids",
            "mask",
            "type_ids",
            "wordpiece_mask",
            "segment_concat_mask",
            "offsets",
            "spans",
        ],
        output_names=["top_spans", "antecedent_indices", "predicted_antecedents"],
        dynamic_axes={
            "token_ids": {0: "batch_size", 1: "seq_len"},
            "mask": {0: "batch_size", 1: "orig_seq_len"},
            "type_ids": {0: "batch_size", 1: "seq_len"},
            "wordpiece_mask": {0: "batch_size", 1: "seq_len"},
            "segment_concat_mask": {0: "batch_size", 1: "seq_len"},
            "offsets": {0: "batch_size", 1: "orig_seq_len"},
            "spans": {0: "batch_size", 1: "num_spans"},
            "top_spans": {0: "batch_size", 1: "num_spans_to_keep"},
            "antecedent_indices": {
                0: "batch_size",
                1: "num_spans_to_keep",
                2: "max_antecedents",
            },
            "predicted_antecedents": {0: "batch_size", 1: "num_spans_to_keep"},
        },
        opset_version=15,
        do_constant_folding=True,
    )


if __name__ == "__main__":
    archive_path = "/models/minillm/model.tar.gz"
    onnx_path = "/models/minillm/model.onnx"
    export_model_to_onnx(archive_path, onnx_path)
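
One sanity check worth running after the export is to feed the same example through the PyTorch wrapper and the exported graph and compare outputs. A rough sketch under stated assumptions: `onnxruntime` is installed in the container, this is called with the `wrapper` and `example_inputs` built in `export_model_to_onnx` above, and `verify_onnx` is a name introduced here, not part of the committed script:

```python
import numpy as np
import onnxruntime as ort
import torch


def verify_onnx(onnx_path, wrapper, example_inputs):
    """Compare the PyTorch wrapper's outputs against the exported ONNX graph."""
    sess = ort.InferenceSession(onnx_path)
    # ONNX preserves input order from export, so names align with the tuple.
    feeds = {
        inp.name: tensor.numpy()
        for inp, tensor in zip(sess.get_inputs(), example_inputs)
    }
    onnx_outputs = sess.run(None, feeds)
    with torch.no_grad():
        torch_outputs = wrapper(*example_inputs)
    for ours, theirs in zip(torch_outputs, onnx_outputs):
        np.testing.assert_allclose(ours.numpy(), theirs, rtol=1e-3, atol=1e-3)
```

Ties in the top-k mention scores can legitimately reorder indices between backends, so a mismatch here is a prompt for inspection rather than proof of a broken export.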
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/model.onnx ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:be2ad55e2b36c7e4e007aadc01d151e3d1f563c8f62a349736a17cb5ea27abe4
size 522012569
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/source.txt ADDED
@@ -0,0 +1 @@
https://huggingface.co/talmago/allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large
allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff