from __future__ import annotations

import copy
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import (
    AbstractSet,
    Any,
    Callable,
    Collection,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
)

from langchain_core.documents import BaseDocumentTransformer, Document

logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")

class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: Union[bool, Literal["start", "end"]] = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator and where to place it
                in each corresponding chunk (True='start')
            add_start_index: If `True`, includes chunk's start index in metadata
            strip_whitespace: If `True`, strips whitespace from the start and end of
                every document
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    offset = index + previous_chunk_len - self._chunk_overlap
                    index = text.find(chunk, max(0, offset))
                    metadata["start_index"] = index
                    previous_chunk_len = len(chunk)
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError(
                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                )

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )
        return cls(length_function=_huggingface_tokenizer_length, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))
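

# Usage sketch (illustrative, not part of the original module): any concrete
# subclass that implements split_text inherits the document helpers above.
# With add_start_index=True, each returned Document also records where its
# chunk begins in the source text. TokenTextSplitter (defined below) is used
# here as the concrete subclass and requires tiktoken to be installed.
#
#     splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=20, add_start_index=True)
#     docs = splitter.create_documents([raw_text], metadatas=[{"source": "a.txt"}])
#     docs[0].metadata  # {"source": "a.txt", "start_index": 0}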


class TokenTextSplitter(TextSplitter):
    """Splitting text to tokens using model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        def _encode(_text: str) -> List[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
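

# Usage sketch (illustrative, not part of the original module; assumes
# tiktoken is installed): TokenTextSplitter measures chunk_size and
# chunk_overlap in tokens rather than characters.
#
#     splitter = TokenTextSplitter(model_name="gpt-3.5-turbo", chunk_size=256, chunk_overlap=32)
#     chunks = splitter.split_text(long_text)
#
# The classmethod route, TokenTextSplitter.from_tiktoken_encoder(...), behaves
# the same but also forwards the encoding arguments so that length counting
# and splitting use one tokenizer.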


class Language(str, Enum):
    """Enum of the programming languages."""

    CPP = "cpp"
    GO = "go"
    JAVA = "java"
    KOTLIN = "kotlin"
    JS = "js"
    TS = "ts"
    PHP = "php"
    PROTO = "proto"
    PYTHON = "python"
    RST = "rst"
    RUBY = "ruby"
    RUST = "rust"
    SCALA = "scala"
    SWIFT = "swift"
    MARKDOWN = "markdown"
    LATEX = "latex"
    HTML = "html"
    SOL = "sol"
    CSHARP = "csharp"
    COBOL = "cobol"
    C = "c"
    LUA = "lua"
    PERL = "perl"
    HASKELL = "haskell"


@dataclass(frozen=True)
class Tokenizer:
    """Tokenizer data class."""

    chunk_overlap: int
    """Overlap in tokens between chunks"""
    tokens_per_chunk: int
    """Maximum number of tokens per chunk"""
    decode: Callable[[List[int]], str]
    """Function to decode a list of token ids to a string"""
    encode: Callable[[str], List[int]]
    """Function to encode a string to a list of token ids"""


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: List[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        if cur_idx == len(input_ids):
            break
        # Slide the window forward by (tokens_per_chunk - chunk_overlap) so that
        # consecutive chunks share chunk_overlap tokens.
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
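

# Illustrative sketch (not part of the original module): split_text_on_tokens
# works with any Tokenizer whose encode/decode functions round-trip, so a toy
# whitespace "tokenizer" is enough to see the sliding-window behaviour without
# tiktoken. With tokens_per_chunk=4 and chunk_overlap=1, consecutive chunks
# share one word.
if __name__ == "__main__":
    _vocab: List[str] = []

    def _toy_encode(s: str) -> List[int]:
        # Assign each whitespace-delimited word an integer id.
        ids: List[int] = []
        for word in s.split():
            if word not in _vocab:
                _vocab.append(word)
            ids.append(_vocab.index(word))
        return ids

    def _toy_decode(ids: List[int]) -> str:
        return " ".join(_vocab[i] for i in ids)

    toy_tokenizer = Tokenizer(
        chunk_overlap=1,
        tokens_per_chunk=4,
        decode=_toy_decode,
        encode=_toy_encode,
    )
    sample = "one two three four five six seven eight nine"
    for chunk in split_text_on_tokens(text=sample, tokenizer=toy_tokenizer):
        print(chunk)
    # Prints:
    #   one two three four
    #   four five six seven
    #   seven eight nine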