Isabelle Mohr committed
Commit · 86ec569 · 1 Parent(s): e3681c2
feat: add span annotation and chunking pooling to encode
Files changed: modeling_xlm_roberta.py (+46 -5)
modeling_xlm_roberta.py CHANGED

@@ -441,6 +441,23 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
 
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
 
+    def chunking_pooling_inference(model_output, span_annotation):
+        token_embeddings = model_output[0]
+        outputs = []
+
+        for embeddings, annotations in zip(token_embeddings, span_annotation):
+            clamped_embeddings = torch.clamp(embeddings, min=-10, max=10)
+            pooled_embeddings = [
+                clamped_embeddings[start:end].sum(dim=0)
+                / (end - start if end - start > 0 else 1)
+                for start, end in annotations
+            ]
+            pooled_embeddings = [
+                embedding.detach().cpu().numpy() for embedding in pooled_embeddings
+            ]
+            outputs.append(pooled_embeddings)
+
+        return outputs
 
     @torch.inference_mode()
     def encode(
@@ -454,6 +471,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = False,
         truncate_dim: Optional[int] = None,
+        span_annotations: Optional[List[List[Tuple[int]]]] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -485,6 +503,10 @@
                 be used.
            truncate_dim(`int`, *optional*, defaults to None):
                The dimension to truncate sentence embeddings to. `None` does no truncation.
+            span_annotations(`List[List[Tuple[int]]]`, *optional*, defaults to None):
+                List of list of tuples. Each tuple represents the start and end index of a chunk.
+                If provided, the embeddings are pooled for each span, and an embedding for each
+                span is returned.
            tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
                Keyword arguments for the tokenizer
        Returns:
@@ -561,7 +583,12 @@
             elif output_value is None:
                 raise NotImplementedError
             else:
-                if self.config.emb_pooler == 'cls':
+                if span_annotations:
+                    embeddings = self.chunking_pooling_inference(
+                        token_embs,
+                        span_annotations[i : i + batch_size],
+                    )
+                elif self.config.emb_pooler == 'cls':
                     embeddings = self.cls_pooling(
                         token_embs, encoded_input['attention_mask']
                     )
@@ -579,14 +606,28 @@
 
         all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
 
-        truncate_dim = truncate_dim or self.config.truncate_dim
         if truncate_dim:
-            all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
+            if isinstance(all_embeddings[0], list):
+                all_embeddings = [
+                    [self.truncate_embeddings(chunk, truncate_dim) for chunk in emb_batch]
+                    for emb_batch in all_embeddings
+                ]
+            else:
+                all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
 
         if convert_to_tensor:
-            all_embeddings = torch.stack(all_embeddings)
+            if isinstance(all_embeddings[0], list):
+                all_embeddings = [torch.stack(emb_batch) for emb_batch in all_embeddings]
+            else:
+                all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
-            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+            if isinstance(all_embeddings[0], list):
+                all_embeddings = [
+                    np.asarray([chunk.numpy() for chunk in emb_batch])
+                    for emb_batch in all_embeddings
+                ]
+            else:
+                all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
 
         if input_was_string:
             all_embeddings = all_embeddings[0]
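Below is a minimal usage sketch of the new `span_annotations` argument (not part of the commit; the checkpoint path, example text, and chunk boundaries are placeholders, and loading through `AutoModel` with `trust_remote_code=True` is assumed). Character-level chunk boundaries are mapped to token-index spans with the fast tokenizer's offset mapping and passed to `encode`, which then returns one pooled embedding per annotated span instead of one per input.

from transformers import AutoModel, AutoTokenizer

MODEL_PATH = "path/to/checkpoint"  # placeholder for a checkpoint that ships this modeling file

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)  # fast tokenizer, needed for offset mapping
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True)

text = "Berlin is the capital of Germany. It is also its largest city."
char_chunks = [(0, 33), (33, len(text))]  # character-level chunk boundaries, e.g. from a sentence splitter

# Convert character spans into (start, end) token-index spans using the offset mapping.
# Special tokens have empty (0, 0) offsets and are skipped; each chunk is assumed to cover at least one token.
offsets = tokenizer(text, return_offsets_mapping=True)["offset_mapping"]
span_annotations = []
for char_start, char_end in char_chunks:
    token_idx = [
        i
        for i, (tok_start, tok_end) in enumerate(offsets)
        if tok_end > tok_start and tok_start >= char_start and tok_end <= char_end
    ]
    span_annotations.append((token_idx[0], token_idx[-1] + 1))

# encode() expects one list of (start, end) token spans per input text.
embeddings = model.encode([text], span_annotations=[span_annotations])
# embeddings[0] holds one pooled vector per annotated chunk.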
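To make the pooling arithmetic concrete, here is a self-contained toy rerun of the same computation on random data (hypothetical shapes): token embeddings are clamped to [-10, 10], and each annotated span is mean-pooled by summing its token vectors and dividing by the span length.

import torch

# Toy stand-in for a batch of token embeddings: 1 sequence, 6 tokens, hidden size 4.
token_embeddings = torch.randn(1, 6, 4)
span_annotation = [[(0, 3), (3, 6)]]  # two chunks for the single sequence

outputs = []
for embeddings, annotations in zip(token_embeddings, span_annotation):
    clamped = torch.clamp(embeddings, min=-10, max=10)  # bound extreme activations
    pooled = [
        clamped[start:end].sum(dim=0) / max(end - start, 1)  # mean over the span
        for start, end in annotations
    ]
    outputs.append([p.cpu().numpy() for p in pooled])

# outputs[0] now contains two 4-dimensional vectors, one per annotated span.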