import logging
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pydantic import BaseModel
from tqdm import tqdm
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import ModernBertConfig, ModernBertModel, ModernBertPreTrainedModel


class TextSpan(BaseModel):
    # Start (inclusive) and end (exclusive) token indices into the encoded sequence.
    s: int
    e: int
    # Name of the projection head in DeweyV1.linear_dict ("cls_linear" or "chunk_linear").
    module_name: str
    # Optional raw text of the chunk this span covers.
    text: Optional[str] = None


class Instance(BaseModel):
    original_text: str
    text_spans: List[TextSpan]
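

# Illustrative only (the values below are made up): a span covering token
# positions 1..9 that will be mean-pooled and projected by the "chunk_linear" head:
#
#     span = TextSpan(s=1, e=9, module_name="chunk_linear", text="some chunk")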


def recursive_split(text, chunk_size=256, chunk_overlap=32):
    """Recursively split a text using RecursiveCharacterTextSplitter from langchain_text_splitters."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        # Chunk length is measured in whitespace-separated words, not characters.
        length_function=lambda x: len(x.split()),
        separators=["\n\n", "\n", ". ", "? ", "! ", "; "],
    )
    chunks = splitter.split_text(text)
    if not chunks:
        logging.error(f"Error, chunks is empty, text:{text}")
        return [text], [[0, len(text)]]
    # Locate each chunk's [start, end) character span in the original text.
    # Note: str.find returns the first occurrence, so repeated chunks map to the same span.
    chunk_span = [
        [text.find(chunk), text.find(chunk) + len(chunk)]
        for chunk in chunks
    ]
    assert chunk_span[0][0] == 0
    assert all(span[0] >= 0 for span in chunk_span)
    return chunks, chunk_span
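

# A minimal usage sketch for recursive_split (the sample text is made up):
#
#     chunks, spans = recursive_split("Some long document ...", chunk_size=256, chunk_overlap=32)
#     # `chunks` is a list of text pieces; `spans` holds the matching
#     # [start, end) character offsets into the original string.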


def make_batch_input_for_prediction(
    texts: List[str],
    tokenizer,
    max_seq_length: int,
    chunk_size=256,
    chunk_overlap=32,
    prompt: str = "",
    fast_chunk: bool = False,
    batch_text_spans: Optional[List[List[TextSpan]]] = None,
):
    """Prepare tokenized batch input and per-text spans for prediction."""
    # If spans are supplied by the caller, only validate them against the tokenized lengths.
    if batch_text_spans is not None:
        ipt = tokenizer(
            [prompt + i for i in texts],
            padding="longest",
            truncation=True,
            max_length=max_seq_length,
            return_tensors="pt"
        )
        for text_spans, data_len in zip(batch_text_spans, ipt["attention_mask"].sum(dim=1)):
            for text_span in text_spans:
                assert -1 < text_span.s < text_span.e <= data_len
        ipt["batch_text_spans"] = batch_text_spans
        return ipt

    # Truncate each text so that prompt + text + special tokens fit in max_seq_length.
    prompt_len = len(tokenizer.tokenize(prompt))
    truncated_texts = [
        tokenizer.decode(
            tokenizer.encode(text)[:max_seq_length - prompt_len - 2],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        ).strip()
        for text in texts
    ]
    ipt = tokenizer(
        [prompt + i for i in truncated_texts],
        padding="longest",
        truncation=True,
        max_length=max_seq_length,
        return_tensors="pt"
    )
    batch_text_spans = []
    for text, data_len in zip(truncated_texts, ipt["attention_mask"].sum(dim=1)):
        # Always emit two spans: the [CLS] token and the full (prompt-free) text.
        text_spans = [
            TextSpan(
                s=0,
                e=1,
                module_name="cls_linear",
            ),
            TextSpan(
                s=1 + prompt_len,
                e=data_len - 1,
                module_name="chunk_linear",
            ),
        ]

        if chunk_size > 1 and chunk_overlap > -1:
            if fast_chunk:
                # Chunk directly over token positions, stepping by chunk_size with chunk_overlap.
                start_pos, end_pos = 1 + prompt_len, data_len - 1
                for s in range(start_pos, end_pos, chunk_size):
                    s -= chunk_overlap
                    s = max(s, start_pos)
                    e = min(s + chunk_size, end_pos)
                    # Skip empty chunks and any chunk identical to the full-text span above.
                    if e - s > 0 and not (s == start_pos and e == end_pos):
                        text_spans.append(
                            TextSpan(
                                s=s,
                                e=e,
                                module_name="chunk_linear",
                            )
                        )
            else:
                # Chunk on raw text, then map character offsets to token positions.
                chunks, chunk_span = recursive_split(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
                if len(chunks) > 1:
                    for (s, e), chunk in zip(chunk_span, chunks):
                        s = len(tokenizer.tokenize(text[:s])) + 1 + prompt_len
                        e = len(tokenizer.tokenize(text[:e])) + 1 + prompt_len
                        if s >= e:
                            continue

                        text_spans.append(
                            TextSpan(
                                s=s,
                                e=e,
                                module_name="chunk_linear",
                                text=chunk
                            )
                        )

        batch_text_spans.append(text_spans)
    ipt["batch_text_spans"] = batch_text_spans
    return ipt
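

# Token layout produced above for one encoded text, with p = prompt length and
# L = data_len (the number of non-padding tokens):
#
#     position:  0      1 .. p          1+p .. L-2     L-1
#     token:     [CLS]  prompt tokens   text tokens    [SEP]
#
#     cls_linear span:   s=0,   e=1    (the [CLS] token)
#     chunk_linear span: s=1+p, e=L-1  (the full text, excluding [SEP]), plus one span per chunk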


class DeweyV1(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.config = config
        self.model = ModernBertModel(config)
        hidden_size = config.hidden_size
        # vector_size is a custom attribute expected on the ModernBertConfig.
        vector_size = config.vector_size
        # One projection head for the [CLS] vector and one for chunk vectors.
        self.linear_dict = nn.ModuleDict(
            {
                "cls_linear": nn.Linear(hidden_size, vector_size, bias=True),
                "chunk_linear": nn.Linear(hidden_size, vector_size, bias=True),
            }
        )

        self.post_init()

    def get_multi_vectors(
        self,
        batch_token_embeddings: torch.Tensor,
        batch_text_spans: List[List[TextSpan]],
        normalize_embeddings: bool = True
    ) -> List[torch.Tensor]:
        multi_vectors = []
        for token_embeddings, text_spans in zip(batch_token_embeddings, batch_text_spans):
            chunk_vectors = []
            for text_span in text_spans:
                s, e = text_span.s, text_span.e
                if s >= token_embeddings.shape[0] or s >= e:
                    logging.warning(
                        f"invalid span, falling back to [0, 1); s, e, token_embeddings.shape: "
                        f"{s, e, token_embeddings.shape}",
                    )
                    s, e = 0, 1
                # Mean-pool the token embeddings inside the span, then project with
                # the head selected by the span's module_name.
                mean_tokens_embs = token_embeddings[s:e, :].mean(dim=0, keepdim=True)
                chunk_vectors.append(
                    self.linear_dict[text_span.module_name](mean_tokens_embs),
                )
            chunk_vectors = torch.cat(chunk_vectors, dim=0)
            if normalize_embeddings:
                multi_vectors.append(F.normalize(chunk_vectors, p=2, dim=-1))
            else:
                multi_vectors.append(chunk_vectors)
        return multi_vectors

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        batch_text_spans: List[List[TextSpan]],
        normalize_embeddings: bool = True,
        *args,
        **kwargs
    ) -> List[torch.Tensor]:
        # Last hidden states from the backbone: (batch, seq_len, hidden_size).
        batch_token_embeddings = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        multi_vectors = self.get_multi_vectors(
            batch_token_embeddings=batch_token_embeddings,
            batch_text_spans=batch_text_spans,
            normalize_embeddings=normalize_embeddings
        )
        return multi_vectors

    @torch.no_grad()
    def encode(
        self,
        sentences: str | list[str],
        batch_size: int = 32,
        use_cuda: bool = True,
        show_progress_bar: bool = True,
        chunk_size: int = 256,
        chunk_overlap: int = 32,
        convert_to_tensor: bool = False,
        max_seq_length: int = 8192,
        normalize_embeddings: bool = True,
        prompt: str = "",
        fast_chunk: bool = False,
        batch_text_spans: Optional[List[List[TextSpan]]] = None,
        *args,
        **kwargs
    ) -> Tuple[List[Union[np.ndarray, torch.Tensor]] | torch.Tensor | np.ndarray, List[List[TextSpan]]]:
        """
        Encode sentences into multi vectors.
        Args:
            sentences: str | list[str], the sentences to embed
            batch_size: int
            use_cuda: bool, whether to use GPU for inference
            show_progress_bar: bool, whether to display a progress bar
            chunk_size: int, the number of tokens per chunk; the recommended size is between 64 and 1024.
                The larger the value, the faster the encoding, but quality may decrease; very small values
                are slower and may also hurt quality.
            chunk_overlap: int, overlap between adjacent chunks, measured in the same units as chunk_size
            convert_to_tensor: bool, if True return torch fp32 tensors, otherwise fp32 ndarrays
            max_seq_length: int, max token length of a text
            normalize_embeddings: bool, whether to L2-normalize the vectors
            prompt: str, the prompt for the text; the final text to be encoded is "[CLS]{prompt}{sentence}[SEP]".
                Note: you CANNOT manually prepend the prompt to the sentence yourself, as this will break the
                length calculation!
            fast_chunk: bool, if True chunk directly on input ids, else use RecursiveCharacterTextSplitter
            batch_text_spans: List[List[TextSpan]], default None; if provided, the model will not chunk the text
            *args:
            **kwargs:

        Returns:
            A tuple of (each text's multi vectors as a list of tensors or ndarrays,
            each text's List[TextSpan]).
        """
        self.eval()

        if isinstance(sentences, str):
            sentences = [sentences]
        # Deduplicate and sort by length (longest first) so each batch holds texts of similar length.
        deduplicate_sentences = list(set(sentences))
        deduplicate_sentences.sort(key=lambda x: len(x), reverse=True)

        vectors_list, text_spans = [], []
        for start in tqdm(
            range(0, len(deduplicate_sentences), batch_size),
            desc="encoding text...",
            disable=not show_progress_bar
        ):
            batch = deduplicate_sentences[start:start + batch_size]
            # Note: self.tokenizer is not set in __init__; attach it after loading the model.
            ipt = make_batch_input_for_prediction(
                batch,
                tokenizer=self.tokenizer,
                max_seq_length=max_seq_length,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                prompt=prompt,
                fast_chunk=fast_chunk,
                batch_text_spans=batch_text_spans
            )
            text_spans.extend(ipt["batch_text_spans"])
            ipt = {k: v.cuda() if use_cuda and isinstance(v, torch.Tensor) else v for k, v in ipt.items()}
            vectors_list.extend(self(**ipt, normalize_embeddings=normalize_embeddings))

        # Map vectors and spans back to the original (possibly duplicated) sentence order.
        assert len(deduplicate_sentences) == len(vectors_list)
        sen2vecs = dict(zip(deduplicate_sentences, vectors_list))
        sen2spans = dict(zip(deduplicate_sentences, text_spans))

        text_spans = [sen2spans[sen] for sen in sentences]
        if convert_to_tensor:
            result = [sen2vecs[sen].cpu().float() for sen in sentences]
        else:
            result = [sen2vecs[sen].cpu().float().numpy() for sen in sentences]
        return result, text_spans
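

# A minimal end-to-end sketch. "your-model-path" is a placeholder, not part of
# this module; the tokenizer must be attached manually because encode() reads
# self.tokenizer:
#
#     from transformers import AutoTokenizer
#
#     model = DeweyV1.from_pretrained("your-model-path")
#     model.tokenizer = AutoTokenizer.from_pretrained("your-model-path")
#     if torch.cuda.is_available():
#         model = model.cuda()
#     vectors, spans = model.encode(
#         ["An example passage to embed."],
#         use_cuda=torch.cuda.is_available(),
#         chunk_size=256,
#         chunk_overlap=32,
#     )
#     # vectors[0] has shape (num_spans, vector_size); spans[0] lists the
#     # TextSpan behind each row ([CLS] vector first, then chunk vectors).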