| from __future__ import annotations |
|
|
| import copy |
| import logging |
| from abc import ABC, abstractmethod |
| from dataclasses import dataclass |
| from enum import Enum |
| from typing import ( |
| AbstractSet, |
| Any, |
| Callable, |
| Collection, |
| Iterable, |
| List, |
| Literal, |
| Optional, |
| Sequence, |
| Type, |
| TypeVar, |
| Union, |
| ) |
|
|
| from langchain_core.documents import BaseDocumentTransformer, Document |
|
|
# Module-level logger named after this module; used by TextSplitter._merge_splits
# to warn when an emitted chunk is longer than the configured chunk size.
logger = logging.getLogger(__name__)


# Type variable bound to TextSplitter; lets TextSplitter.from_tiktoken_encoder
# be annotated as returning the concrete subclass it was called on.
TS = TypeVar("TS", bound="TextSplitter")
|
|
|
|
class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks.

    Subclasses implement `split_text`. This base class supplies the shared
    configuration, the conversion of chunks into `Document` objects, the
    merging of small splits into size-bounded chunks, and alternate
    constructors that measure length in tokenizer tokens.
    """

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: Union[bool, Literal["start", "end"]] = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return, measured with
                `length_function`.
            chunk_overlap: Overlap between consecutive chunks, measured with
                `length_function` (characters for the default `len`).
            length_function: Function that measures the length of given chunks.
            keep_separator: Whether to keep the separator and where to place it
                in each corresponding chunk (True='start'). Only stored here;
                the base class logic in this class does not read it.
            add_start_index: If `True`, includes chunk's start index in metadata.
            strip_whitespace: If `True`, strips whitespace from the start and end
                of every document.

        Raises:
            ValueError: If `chunk_size` is not positive, if `chunk_overlap` is
                negative, or if `chunk_overlap` is larger than `chunk_size`.
        """
        # Reject sizes that cannot describe a real chunk before they reach the
        # merging logic, where they would silently produce nonsensical output.
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be > 0, got {chunk_size}")
        if chunk_overlap < 0:
            raise ValueError(f"chunk_overlap must be >= 0, got {chunk_overlap}")
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts.

        Each text is split with `split_text` and every resulting chunk becomes
        one `Document` carrying a deep copy of that text's metadata.

        Args:
            texts: Texts to split.
            metadatas: One metadata dict per text, matched by position. When
                omitted (or empty), every text gets an empty dict.

        Returns:
            The documents for all chunks of all texts, in input order.
        """
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text(text):
                # Deep copy so documents never share (or mutate) caller metadata.
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    # Start searching just before the end of the previous chunk,
                    # backing up by the overlap so an overlapping chunk is found.
                    offset = index + previous_chunk_len - self._chunk_overlap
                    # str.find returns -1 when the chunk does not occur verbatim
                    # at or after the offset; that value is recorded as-is.
                    index = text.find(chunk, max(0, offset))
                    metadata["start_index"] = index
                    previous_chunk_len = len(chunk)
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split documents.

        Collects each document's `page_content` and `metadata` and delegates to
        `create_documents`.
        """
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        """Join pieces with `separator`; return `None` if the result is empty.

        The joined text is stripped first when `strip_whitespace` is enabled.
        """
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Combine small splits into chunks bounded by the chunk size.

        Splits are accumulated (joined by `separator`) until adding the next one
        would exceed the chunk size; the accumulated chunk is then emitted and
        leading pieces are dropped until at most the chunk overlap remains.
        Lengths are measured with the configured length function. A single
        split longer than the chunk size is still emitted, with a warning.

        Args:
            splits: Pieces of text to merge, in order.
            separator: String placed between merged pieces.

        Returns:
            The merged chunks; empty chunks are omitted.
        """
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        # Running length of current_doc including the separators between pieces.
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, "
                        "which is longer than the specified %s",
                        total,
                        self._chunk_size,
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep dropping leading pieces while:
                    # - more than the allowed overlap is still held, or
                    # - something is held and adding `d` would still overflow.
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        # Flush whatever is left after the last split.
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    @classmethod
    def from_huggingface_tokenizer(
        cls, tokenizer: Any, **kwargs: Any
    ) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length.

        Args:
            tokenizer: A `transformers.PreTrainedTokenizerBase` instance.
            **kwargs: Forwarded to the class constructor.

        Returns:
            A splitter whose length function is the number of tokens produced
            by `tokenizer.tokenize`.

        Raises:
            ValueError: If `transformers` cannot be imported, or if `tokenizer`
                is not a `PreTrainedTokenizerBase`.
        """
        # Only the import is guarded, so an ImportError raised anywhere else is
        # never misreported as a missing package.
        try:
            from transformers import PreTrainedTokenizerBase
        except ImportError as err:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            ) from err

        if not isinstance(tokenizer, PreTrainedTokenizerBase):
            raise ValueError(
                "Tokenizer received was not an instance of PreTrainedTokenizerBase"
            )

        def _huggingface_tokenizer_length(text: str) -> int:
            return len(tokenizer.tokenize(text))

        return cls(length_function=_huggingface_tokenizer_length, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        # NOTE: shared default set; this method only reads and forwards it.
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length.

        Args:
            encoding_name: Name passed to `tiktoken.get_encoding` when
                `model_name` is not given.
            model_name: If not `None`, the encoder is looked up with
                `tiktoken.encoding_for_model` instead of `encoding_name`.
            allowed_special: Forwarded to the encoder's `encode` call.
            disallowed_special: Forwarded to the encoder's `encode` call.
            **kwargs: Forwarded to the class constructor.

        Returns:
            An instance of `cls` whose length function counts tiktoken tokens.

        Raises:
            ImportError: If `tiktoken` is not installed.
        """
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            ) from err

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        # TokenTextSplitter builds its own encoder, so it must receive the same
        # encoder settings that the length function above uses.
        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them.

        Extra keyword arguments are accepted but not used.
        """
        return self.split_documents(list(documents))
|
|
|
|
class TokenTextSplitter(TextSplitter):
    """Splitting text to tokens using model tokenizer.

    The configured chunk size and chunk overlap are interpreted as numbers of
    tiktoken tokens rather than characters.
    """

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        # NOTE: shared default set; this class only stores and forwards it.
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TokenTextSplitter.

        Args:
            encoding_name: Name passed to `tiktoken.get_encoding` when
                `model_name` is not given.
            model_name: If not `None`, the encoder is looked up with
                `tiktoken.encoding_for_model` instead of `encoding_name`.
            allowed_special: Forwarded to the encoder's `encode` call.
            disallowed_special: Forwarded to the encoder's `encode` call.
            **kwargs: Forwarded to `TextSplitter.__init__`.

        Raises:
            ImportError: If `tiktoken` is not installed.
        """
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to use TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            ) from err

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        """Split the input text into chunks of tokens.

        The text is encoded with the configured tiktoken encoder, cut into
        windows of `chunk_size` tokens that overlap by `chunk_overlap` tokens,
        and each window is decoded back to text. The windowing itself is done
        by `split_text_on_tokens`.

        Args:
            text: The input text to be split.

        Returns:
            The decoded text of each token window, in order.
        """

        def _encode(_text: str) -> List[int]:
            # Apply the special-token policy chosen at construction time.
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
|
|
|
|
class Language(str, Enum):
    """Enum of the programming languages.

    Members are also `str` instances, so each one compares equal to its
    lowercase string value and can be looked up from it, e.g.
    `Language("python") is Language.PYTHON`.
    """

    # Declaration order is the iteration order of the enum; keep it stable.
    CPP = "cpp"
    GO = "go"
    JAVA = "java"
    KOTLIN = "kotlin"
    JS = "js"
    TS = "ts"
    PHP = "php"
    PROTO = "proto"
    PYTHON = "python"
    RST = "rst"
    RUBY = "ruby"
    RUST = "rust"
    SCALA = "scala"
    SWIFT = "swift"
    MARKDOWN = "markdown"
    LATEX = "latex"
    HTML = "html"
    SOL = "sol"
    CSHARP = "csharp"
    COBOL = "cobol"
    C = "c"
    LUA = "lua"
    PERL = "perl"
    HASKELL = "haskell"
    ELIXIR = "elixir"
    POWERSHELL = "powershell"
|
|
|
|
@dataclass(frozen=True)
class Tokenizer:
    """Tokenizer data class.

    Immutable bundle of the chunking parameters and the encode/decode
    callables consumed by `split_text_on_tokens`.
    """

    # Overlap in tokens between chunks.
    chunk_overlap: int
    # Maximum number of tokens per chunk.
    tokens_per_chunk: int
    # Function to decode a list of token ids to a string.
    decode: Callable[[List[int]], str]
    # Function to encode a string to a list of token ids.
    encode: Callable[[str], List[int]]
|
|
|
|
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
    """Split incoming text and return chunks using tokenizer.

    The text is encoded to token ids, cut into windows of
    `tokenizer.tokens_per_chunk` ids whose starts are
    `tokens_per_chunk - chunk_overlap` ids apart, and each window is decoded
    back to text.

    Args:
        text: Text to split.
        tokenizer: Provides `encode`, `decode`, `tokens_per_chunk` and
            `chunk_overlap`.

    Returns:
        The decoded chunks in order; an empty list when the text encodes to no
        tokens.

    Raises:
        ValueError: If the text needs more than one chunk but `chunk_overlap`
            is not smaller than `tokens_per_chunk`, so the window could never
            advance.
    """
    input_ids = tokenizer.encode(text)
    total = len(input_ids)
    window = tokenizer.tokens_per_chunk
    # Distance between the starts of two consecutive windows.
    stride = window - tokenizer.chunk_overlap

    splits: List[str] = []
    start_idx = 0
    while start_idx < total:
        cur_idx = min(start_idx + window, total)
        splits.append(tokenizer.decode(input_ids[start_idx:cur_idx]))
        if cur_idx == total:
            break
        # Another window is needed. With a non-positive stride the start index
        # would never move forward and this loop would append chunks forever.
        if stride <= 0:
            raise ValueError(
                f"chunk_overlap ({tokenizer.chunk_overlap}) must be smaller than "
                f"tokens_per_chunk ({tokenizer.tokens_per_chunk}) to split text "
                "that does not fit in a single chunk."
            )
        start_idx += stride
    return splits
|
|