from __future__ import annotations

import copy
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
from typing import (
    Any,
    Literal,
    Optional,
    TypedDict,
    TypeVar,
    Union,
)

from core.rag.models.document import BaseDocumentTransformer, Document

logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")


def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
    # Now that we have the separator, split the text.
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result,
            # so each separator can be re-attached to the piece before it.
            _splits = re.split(f"({re.escape(separator)})", text)
            splits = [_splits[i - 1] + _splits[i] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 != 0:
                splits += _splits[-1:]
        else:
            # Note: in this branch the separator is treated as a regex pattern,
            # not as a literal string.
            splits = re.split(separator, text)
    else:
        # An empty separator means split into individual characters.
        splits = list(text)
    return [s for s in splits if (s not in {"", "\n"})]
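

# A minimal sketch of the behavior above (not part of the module API); with
# keep_separator=True the separator stays attached to the preceding piece:
#
#     _split_text_with_regex("a, b, c", ", ", keep_separator=True)
#     # -> ["a, ", "b, ", "c"]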


class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size ({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index

    @abstractmethod
    def split_text(self, text: str) -> list[str]:
        """Split text into multiple components."""

    def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    # Locate this chunk in the source text, searching past the
                    # previous match so repeated chunks get distinct indices.
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents
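
    # A hypothetical usage sketch (``MySplitter`` stands in for any concrete
    # subclass): with add_start_index=True each chunk's metadata records the
    # character offset where the chunk begins in its source text, e.g.
    #
    #     docs = MySplitter(add_start_index=True).create_documents(
    #         ["some text"], metadatas=[{"source": "a.txt"}]
    #     )
    #     # docs[0].metadata -> {"source": "a.txt", "start_index": 0}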

    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: list[str] = []
        total = 0
        index = 0
        for d in splits:
            _len = lengths[index]
            if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0)
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
            index += 1
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs
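
    # Worked example (comments only): with chunk_size=10, chunk_overlap=5,
    # separator=" " and splits ["aaaa", "bbbb", "cccc"], the first chunk is
    # "aaaa bbbb" (9 chars); "aaaa" is then evicted to bring the window under
    # the overlap budget, so the second chunk is "bbbb cccc".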

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. Please install it with `pip install transformers`."
            )
        return cls(length_function=_huggingface_tokenizer_length, **kwargs)
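
    # Hypothetical usage sketch (assumes the ``transformers`` package and the
    # publicly available "bert-base-uncased" tokenizer):
    #
    #     from transformers import AutoTokenizer
    #
    #     tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    #     splitter = CharacterTextSplitter.from_huggingface_tokenizer(
    #         tok, chunk_size=100, chunk_overlap=10
    #     )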

    @classmethod
    def from_tiktoken_encoder(
        cls: type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate lengths with the tiktoken encoder. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, TokenTextSplitter):
            # Forward the tiktoken settings so the TokenTextSplitter builds
            # the same encoder internally.
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)
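
    # Hypothetical usage sketch (assumes ``tiktoken`` is installed;
    # "cl100k_base" is one of tiktoken's built-in encodings):
    #
    #     splitter = TokenTextSplitter.from_tiktoken_encoder(
    #         encoding_name="cl100k_base", chunk_size=512, chunk_overlap=50
    #     )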

    def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))

    async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents by splitting them."""
        raise NotImplementedError


class CharacterTextSplitter(TextSplitter):
    """Splitting text that looks at characters."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = _split_text_with_regex(text, self._separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        _good_splits_lengths = []
        for split in splits:
            _good_splits_lengths.append(self._length_function(split))
        return self._merge_splits(splits, _separator, _good_splits_lengths)
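

# Hypothetical usage sketch: split on blank lines, measuring length in
# characters (``document_text`` is assumed to be defined by the caller):
#
#     splitter = CharacterTextSplitter(separator="\n\n", chunk_size=1000, chunk_overlap=100)
#     chunks = splitter.split_text(document_text)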


class LineType(TypedDict):
    """Line type as typed dict."""

    metadata: dict[str, str]
    content: str


class HeaderType(TypedDict):
    """Header type as typed dict."""

    level: int
    name: str
    data: str


class MarkdownHeaderTextSplitter:
    """Splitting markdown files based on specified headers."""

    def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False):
        """Create a new MarkdownHeaderTextSplitter.

        Args:
            headers_to_split_on: Headers we want to track
            return_each_line: Return each line w/ associated headers
        """
        # Output line-by-line or aggregated into chunks w/ common headers
        self.return_each_line = return_each_line
        # Given the headers we want to split on (e.g., "#, ##, etc"),
        # order by length so longer header markers match first.
        self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True)

    def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
        """Combine lines with common metadata into chunks.

        Args:
            lines: Line of text / associated header metadata
        """
        aggregated_chunks: list[LineType] = []

        for line in lines:
            if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]:
                # If the last line in the aggregated list has the same metadata
                # as the current line, append the current content to the last
                # line's content.
                aggregated_chunks[-1]["content"] += " \n" + line["content"]
            else:
                # Otherwise, append the current line to the aggregated list.
                aggregated_chunks.append(line)

        return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks]

    def split_text(self, text: str) -> list[Document]:
        """Split markdown file.

        Args:
            text: Markdown file
        """
        # Split the input text by newlines.
        lines = text.split("\n")
        # Final output
        lines_with_metadata: list[LineType] = []
        # Content and metadata of the chunk currently being processed
        current_content: list[str] = []
        current_metadata: dict[str, str] = {}
        # Keep track of the nested header structure
        header_stack: list[HeaderType] = []
        initial_metadata: dict[str, str] = {}

        for line in lines:
            stripped_line = line.strip()
            # Check each line against each of the header types (e.g., #, ##)
            for sep, name in self.headers_to_split_on:
                # Check if line starts with a header that we intend to split on
                if stripped_line.startswith(sep) and (
                    # Header with no text OR header followed by a space:
                    # both are valid conditions for sep being used as a header.
                    len(stripped_line) == len(sep) or stripped_line[len(sep)] == " "
                ):
                    # Ensure we are tracking the header as metadata
                    if name is not None:
                        # Get the current header level
                        current_header_level = sep.count("#")

                        # Pop out headers of lower or same level from the stack
                        while header_stack and header_stack[-1]["level"] >= current_header_level:
                            # We have encountered a new header at the same or
                            # higher level
                            popped_header = header_stack.pop()
                            # Clear the metadata for the popped header in
                            # initial_metadata
                            if popped_header["name"] in initial_metadata:
                                initial_metadata.pop(popped_header["name"])

                        # Push the current header to the stack
                        header: HeaderType = {
                            "level": current_header_level,
                            "name": name,
                            "data": stripped_line[len(sep) :].strip(),
                        }
                        header_stack.append(header)
                        # Update initial_metadata with the current header
                        initial_metadata[name] = header["data"]

                    # Add the previous line to lines_with_metadata only if
                    # current_content is not empty
                    if current_content:
                        lines_with_metadata.append(
                            {
                                "content": "\n".join(current_content),
                                "metadata": current_metadata.copy(),
                            }
                        )
                        current_content.clear()

                    break
            else:
                if stripped_line:
                    current_content.append(stripped_line)
                elif current_content:
                    lines_with_metadata.append(
                        {
                            "content": "\n".join(current_content),
                            "metadata": current_metadata.copy(),
                        }
                    )
                    current_content.clear()

            current_metadata = initial_metadata.copy()

        if current_content:
            lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata})

        # lines_with_metadata has each line with associated header metadata;
        # aggregate these into chunks based on common metadata
        if not self.return_each_line:
            return self.aggregate_lines_to_chunks(lines_with_metadata)
        else:
            return [
                Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata
            ]
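

# Hypothetical usage sketch: track H1/H2 headings as metadata.
#
#     md_splitter = MarkdownHeaderTextSplitter(
#         headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")]
#     )
#     docs = md_splitter.split_text("# Title\n\nIntro text\n\n## Section\n\nBody")
#     # docs[1].metadata -> {"Header 1": "Title", "Header 2": "Section"}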


@dataclass(frozen=True)
class Tokenizer:
    chunk_overlap: int
    """Overlap in tokens between chunks."""
    tokens_per_chunk: int
    """Maximum number of tokens per chunk."""
    decode: Callable[[list[int]], str]
    """Function to decode a list of token ids to a string."""
    encode: Callable[[str], list[int]]
    """Function to encode a string to a list of token ids."""


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: list[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        # Slide the window forward by tokens_per_chunk minus the overlap.
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
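

# Illustrative sketch with a toy character-level "tokenizer" (every character
# counts as one token) to show the sliding window:
#
#     toy = Tokenizer(
#         chunk_overlap=2,
#         tokens_per_chunk=5,
#         decode=lambda ids: "".join(chr(i) for i in ids),
#         encode=lambda s: [ord(c) for c in s],
#     )
#     split_text_on_tokens(text="abcdefgh", tokenizer=toy)
#     # -> ["abcde", "defgh", "gh"]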


class TokenTextSplitter(TextSplitter):
    """Splitting text to tokens using model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> list[str]:
        def _encode(_text: str) -> list[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
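

# Hypothetical usage sketch (assumes ``tiktoken`` is installed and
# ``long_text`` is defined by the caller):
#
#     splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=256, chunk_overlap=32)
#     chunks = splitter.split_text(long_text)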


class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively looking at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[list[str]] = None,
        keep_separator: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        final_chunks = []
        # Get the appropriate separator to use: the first one that appears in
        # the text, falling back to the last (most fine-grained) separator.
        separator = separators[-1]
        new_separators = []

        for i, _s in enumerate(separators):
            if _s == "":
                separator = _s
                break
            if re.search(_s, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        splits = _split_text_with_regex(text, separator, self._keep_separator)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _good_splits_lengths = []
        _separator = "" if self._keep_separator else separator

        for s in splits:
            s_len = self._length_function(s)
            if s_len < self._chunk_size:
                _good_splits.append(s)
                _good_splits_lengths.append(s_len)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                    _good_splits_lengths = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)

        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
            final_chunks.extend(merged_text)

        return final_chunks

    def split_text(self, text: str) -> list[str]:
        return self._split_text(text, self._separators)
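

# Hypothetical usage sketch: fall back from paragraphs to lines to words to
# characters until every chunk fits (``long_text`` is assumed defined):
#
#     splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=20)
#     chunks = splitter.split_text(long_text)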