| import torch |
| from typing import Optional |
| from pydantic import BaseModel |
| from sentence_transformers.models import Transformer as BaseTransformer |
|
|
|
|
class TextSpan(BaseModel):
    """A token span plus the name of the module that should pool/project it.

    Usage in ``DeweyTransformer.forward`` (``s=0, e=1`` for a single CLS
    token) suggests the span is half-open ``[s, e)`` in token-index space —
    confirm against the consumer inside ``auto_model``.
    """
    # Start token index (inclusive).
    s: int
    # End token index (appears exclusive — see class docstring).
    e: int
    # Target module for this span, e.g. "cls_linear" or "chunk_linear";
    # interpreted downstream by the wrapped auto_model.
    module_name: str
    # Optional raw text backing the span; not populated by this file.
    text: Optional[str] = None
|
|
|
|
class DeweyTransformer(BaseTransformer):
    """sentence-transformers ``Transformer`` wrapper that builds per-row
    :class:`TextSpan` pooling instructions for the wrapped ``auto_model``.

    The pooling strategy is read from ``config_args["single_vector_type"]``:

    * ``"cls"``          — pool only the leading CLS token via ``cls_linear``.
    * ``"mean"``         — pool the content tokens via ``chunk_linear``.
    * ``"cls_add_mean"`` — both of the above (embeddings are normalized).
    """

    # Closed set of supported pooling strategies.
    _VALID_VECTOR_TYPES = ("cls", "mean", "cls_add_mean")

    def __init__(
        self,
        model_name_or_path: str,
        **kwargs,
    ):
        # Read the strategy before the base class consumes/forwards kwargs.
        self.single_vector_type = kwargs.get("config_args", {}).get("single_vector_type", "mean")
        # Fail fast on a bad config instead of deferring to the first
        # forward() call (previously an empty batch would not raise at all).
        if self.single_vector_type not in self._VALID_VECTOR_TYPES:
            raise ValueError("single_vector_type should be in {cls, mean or cls_add_mean}")
        super().__init__(model_name_or_path, **kwargs)

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        """Encode a batch and attach ``sentence_embedding`` to ``features``.

        Expects ``features`` to contain ``input_ids`` and ``attention_mask``
        (and optionally ``prompt_length``); mutates and returns ``features``.

        Raises:
            ValueError: if ``single_vector_type`` is not a supported value.
        """
        vector_type = self.single_vector_type
        # Guard again in case the attribute was mutated after __init__.
        if vector_type not in self._VALID_VECTOR_TYPES:
            raise ValueError("single_vector_type should be in {cls, mean or cls_add_mean}")

        prompt_length = features.get("prompt_length", 0)
        if prompt_length > 0:
            # NOTE(review): the -1 presumably compensates for a special token
            # counted inside prompt_length — confirm against the tokenizer.
            prompt_length -= 1

        use_cls = vector_type in ("cls", "cls_add_mean")
        use_mean = vector_type in ("mean", "cls_add_mean")

        batch_text_spans = []
        # attention_mask.sum(dim=1) gives the unpadded length of each row.
        for data_len in features["attention_mask"].sum(dim=1):
            spans = []
            if use_cls:
                # CLS token only: span [0, 1).
                spans.append(TextSpan(s=0, e=1, module_name="cls_linear"))
            if use_mean:
                # Content tokens: skip CLS + prompt, drop the final token
                # (presumably EOS/SEP). int() coerces the 0-dim tensor so the
                # pydantic int fields validate under strict coercion too.
                spans.append(
                    TextSpan(s=1 + prompt_length, e=int(data_len) - 1, module_name="chunk_linear")
                )
            batch_text_spans.append(spans)

        trans_features = {
            "input_ids": features["input_ids"],
            "attention_mask": features["attention_mask"],
            "batch_text_spans": batch_text_spans,
            # Only the combined strategy requests normalized embeddings.
            "normalize_embeddings": vector_type == "cls_add_mean",
        }

        # auto_model returns one tensor of span vectors per batch row; the
        # per-row vectors are averaged into a single sentence embedding.
        vectors_list = self.auto_model(**trans_features, **kwargs)
        sentence_embedding = torch.cat(
            [vecs.mean(dim=0, keepdim=True) for vecs in vectors_list],
            dim=0,
        )
        features.update({"sentence_embedding": sentence_embedding})
        return features
|
|