| """**Text Splitters** are classes for splitting text. | |
| **Class hierarchy:** | |
| .. code-block:: | |
| BaseDocumentTransformer --> TextSplitter --> <name>TextSplitter # Example: CharacterTextSplitter | |
| RecursiveCharacterTextSplitter --> <name>TextSplitter | |
| Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive from TextSplitter. | |
| **Main helpers:** | |
| .. code-block:: | |
| Document, Tokenizer, Language, LineType, HeaderType | |
| """ # noqa: E501 | |

from __future__ import annotations

import asyncio
import copy
import logging
import pathlib
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from functools import partial
from io import BytesIO, StringIO
from typing import (
    AbstractSet,
    Any,
    Callable,
    Collection,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypedDict,
    TypeVar,
    Union,
    cast,
)

import requests
from langchain_core.documents import BaseDocumentTransformer, Document

logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")


def _make_spacy_pipeline_for_splitting(pipeline: str) -> Any:  # avoid importing spacy
    try:
        import spacy
    except ImportError:
        raise ImportError(
            "Spacy is not installed, please install it with `pip install spacy`."
        )
    if pipeline == "sentencizer":
        from spacy.lang.en import English

        sentencizer = English()
        sentencizer.add_pipe("sentencizer")
    else:
        sentencizer = spacy.load(pipeline, exclude=["ner", "tagger"])
    return sentencizer


def _split_text_with_regex(
    text: str, separator: str, keep_separator: bool
) -> List[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({separator})", text)
            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
            splits = [_splits[0]] + splits
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]
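

# Illustrative behavior of _split_text_with_regex (a hedged sketch, not part of
# the upstream module): with keep_separator=True each separator stays attached
# to the split that follows it; with keep_separator=False separators are dropped.
#
#     >>> _split_text_with_regex("a.b.c", r"\.", True)
#     ['a', '.b', '.c']
#     >>> _split_text_with_regex("a.b.c", r"\.", False)
#     ['a', 'b', 'c']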


class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
            strip_whitespace: If `True`, strips whitespace from the start and end of
                every document
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError(
                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                )

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )

        return cls(length_function=_huggingface_tokenizer_length, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents by splitting them."""
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self.transform_documents, **kwargs), documents
        )
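

# Example usage (an illustrative sketch; requires the optional tiktoken package,
# and CharacterTextSplitter is defined below): build a splitter through
# from_tiktoken_encoder so chunk length is counted in tokens, not characters.
#
#     >>> splitter = CharacterTextSplitter.from_tiktoken_encoder(
#     ...     encoding_name="gpt2", chunk_size=100, chunk_overlap=0
#     ... )
#     >>> chunks = splitter.split_text("some long text " * 50)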


class CharacterTextSplitter(TextSplitter):
    """Splitting text that looks at characters."""

    def __init__(
        self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator
        self._is_separator_regex = is_separator_regex

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        separator = (
            self._separator if self._is_separator_regex else re.escape(self._separator)
        )
        splits = _split_text_with_regex(text, separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        return self._merge_splits(splits, _separator)
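

# Example usage (illustrative): split on the default "\n\n" separator and merge
# splits greedily up to chunk_size characters.
#
#     >>> splitter = CharacterTextSplitter(chunk_size=20, chunk_overlap=0)
#     >>> splitter.split_text("first paragraph\n\nsecond paragraph")
#     ['first paragraph', 'second paragraph']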


class LineType(TypedDict):
    """Line type as typed dict."""

    metadata: Dict[str, str]
    content: str


class HeaderType(TypedDict):
    """Header type as typed dict."""

    level: int
    name: str
    data: str


class MarkdownHeaderTextSplitter:
    """Splitting markdown files based on specified headers."""

    def __init__(
        self, headers_to_split_on: List[Tuple[str, str]], return_each_line: bool = False
    ):
        """Create a new MarkdownHeaderTextSplitter.

        Args:
            headers_to_split_on: Headers we want to track
            return_each_line: Return each line w/ associated headers
        """
        # Output line-by-line or aggregated into chunks w/ common headers
        self.return_each_line = return_each_line
        # Given the headers we want to split on,
        # (e.g., "#, ##, etc") order by length
        self.headers_to_split_on = sorted(
            headers_to_split_on, key=lambda split: len(split[0]), reverse=True
        )

    def aggregate_lines_to_chunks(self, lines: List[LineType]) -> List[Document]:
        """Combine lines with common metadata into chunks.

        Args:
            lines: Line of text / associated header metadata
        """
        aggregated_chunks: List[LineType] = []

        for line in lines:
            if (
                aggregated_chunks
                and aggregated_chunks[-1]["metadata"] == line["metadata"]
            ):
                # If the last line in the aggregated list
                # has the same metadata as the current line,
                # append the current content to the last line's content
                aggregated_chunks[-1]["content"] += "  \n" + line["content"]
            else:
                # Otherwise, append the current line to the aggregated list
                aggregated_chunks.append(line)

        return [
            Document(page_content=chunk["content"], metadata=chunk["metadata"])
            for chunk in aggregated_chunks
        ]

    def split_text(self, text: str) -> List[Document]:
        """Split markdown file.

        Args:
            text: Markdown file
        """
        # Split the input text by newline character ("\n").
        lines = text.split("\n")
        # Final output
        lines_with_metadata: List[LineType] = []
        # Content and metadata of the chunk currently being processed
        current_content: List[str] = []
        current_metadata: Dict[str, str] = {}
        # Keep track of the nested header structure
        # header_stack: List[Dict[str, Union[int, str]]] = []
        header_stack: List[HeaderType] = []
        initial_metadata: Dict[str, str] = {}

        in_code_block = False

        for line in lines:
            stripped_line = line.strip()

            if stripped_line.startswith("```"):
                # code block in one row
                if stripped_line.count("```") >= 2:
                    in_code_block = False
                else:
                    in_code_block = not in_code_block

            if in_code_block:
                current_content.append(stripped_line)
                continue

            # Check each line against each of the header types (e.g., #, ##)
            for sep, name in self.headers_to_split_on:
                # Check if line starts with a header that we intend to split on
                if stripped_line.startswith(sep) and (
                    # Header with no text OR header is followed by space
                    # Both are valid conditions that sep is being used as a header
                    len(stripped_line) == len(sep) or stripped_line[len(sep)] == " "
                ):
                    # Ensure we are tracking the header as metadata
                    if name is not None:
                        # Get the current header level
                        current_header_level = sep.count("#")

                        # Pop out headers of lower or same level from the stack
                        while (
                            header_stack
                            and header_stack[-1]["level"] >= current_header_level
                        ):
                            # We have encountered a new header
                            # at the same or higher level
                            popped_header = header_stack.pop()
                            # Clear the metadata for the
                            # popped header in initial_metadata
                            if popped_header["name"] in initial_metadata:
                                initial_metadata.pop(popped_header["name"])

                        # Push the current header to the stack
                        header: HeaderType = {
                            "level": current_header_level,
                            "name": name,
                            "data": stripped_line[len(sep) :].strip(),
                        }
                        header_stack.append(header)
                        # Update initial_metadata with the current header
                        initial_metadata[name] = header["data"]

                    # Add the previous line to the lines_with_metadata
                    # only if current_content is not empty
                    if current_content:
                        lines_with_metadata.append(
                            {
                                "content": "\n".join(current_content),
                                "metadata": current_metadata.copy(),
                            }
                        )
                        current_content.clear()

                    break
            else:
                if stripped_line:
                    current_content.append(stripped_line)
                elif current_content:
                    lines_with_metadata.append(
                        {
                            "content": "\n".join(current_content),
                            "metadata": current_metadata.copy(),
                        }
                    )
                    current_content.clear()

            current_metadata = initial_metadata.copy()

        if current_content:
            lines_with_metadata.append(
                {"content": "\n".join(current_content), "metadata": current_metadata}
            )

        # lines_with_metadata has each line with associated header metadata
        # aggregate these into chunks based on common metadata
        if not self.return_each_line:
            return self.aggregate_lines_to_chunks(lines_with_metadata)
        else:
            return [
                Document(page_content=chunk["content"], metadata=chunk["metadata"])
                for chunk in lines_with_metadata
            ]
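

# Example usage (illustrative): split a markdown string on header levels and
# carry the header hierarchy along as metadata.
#
#     >>> md = "# Title\n\n## Part A\ntext under A\n\n## Part B\ntext under B"
#     >>> splitter = MarkdownHeaderTextSplitter(
#     ...     headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")]
#     ... )
#     >>> docs = splitter.split_text(md)
#     >>> docs[0].page_content, docs[0].metadata
#     ('text under A', {'Header 1': 'Title', 'Header 2': 'Part A'})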


class ElementType(TypedDict):
    """Element type as typed dict."""

    url: str
    xpath: str
    content: str
    metadata: Dict[str, str]


class HTMLHeaderTextSplitter:
    """
    Splitting HTML files based on specified headers.
    Requires lxml package.
    """

    def __init__(
        self,
        headers_to_split_on: List[Tuple[str, str]],
        return_each_element: bool = False,
    ):
        """Create a new HTMLHeaderTextSplitter.

        Args:
            headers_to_split_on: list of tuples of headers we want to track mapped to
                (arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4,
                h5, h6 e.g. [("h1", "Header 1"), ("h2", "Header 2")].
            return_each_element: Return each element w/ associated headers.
        """
        # Output element-by-element or aggregated into chunks w/ common headers
        self.return_each_element = return_each_element
        self.headers_to_split_on = sorted(headers_to_split_on)

    def aggregate_elements_to_chunks(
        self, elements: List[ElementType]
    ) -> List[Document]:
        """Combine elements with common metadata into chunks.

        Args:
            elements: HTML element content with associated identifying info and metadata
        """
        aggregated_chunks: List[ElementType] = []

        for element in elements:
            if (
                aggregated_chunks
                and aggregated_chunks[-1]["metadata"] == element["metadata"]
            ):
                # If the last element in the aggregated list
                # has the same metadata as the current element,
                # append the current content to the last element's content
                aggregated_chunks[-1]["content"] += "  \n" + element["content"]
            else:
                # Otherwise, append the current element to the aggregated list
                aggregated_chunks.append(element)

        return [
            Document(page_content=chunk["content"], metadata=chunk["metadata"])
            for chunk in aggregated_chunks
        ]

    def split_text_from_url(self, url: str) -> List[Document]:
        """Split HTML from web URL.

        Args:
            url: web URL
        """
        r = requests.get(url)
        return self.split_text_from_file(BytesIO(r.content))

    def split_text(self, text: str) -> List[Document]:
        """Split HTML text string.

        Args:
            text: HTML text
        """
        return self.split_text_from_file(StringIO(text))

    def split_text_from_file(self, file: Any) -> List[Document]:
        """Split HTML file.

        Args:
            file: HTML file
        """
        try:
            from lxml import etree
        except ImportError as e:
            raise ImportError(
                "Unable to import lxml, please install with `pip install lxml`."
            ) from e
        # use lxml library to parse html document and return xml ElementTree
        parser = etree.HTMLParser()
        tree = etree.parse(file, parser)

        # document transformation for "structure-aware" chunking is handled with xsl.
        # see comments in html_chunks_with_headers.xslt for more detailed information.
        xslt_path = (
            pathlib.Path(__file__).parent
            / "document_transformers/xsl/html_chunks_with_headers.xslt"
        )
        xslt_tree = etree.parse(xslt_path)
        transform = etree.XSLT(xslt_tree)
        result = transform(tree)
        result_dom = etree.fromstring(str(result))

        # create filter and mapping for header metadata
        header_filter = [header[0] for header in self.headers_to_split_on]
        header_mapping = dict(self.headers_to_split_on)

        # map xhtml namespace prefix
        ns_map = {"h": "http://www.w3.org/1999/xhtml"}

        # build list of elements from DOM
        elements = []
        for element in result_dom.findall("*//*", ns_map):
            if element.findall("*[@class='headers']") or element.findall(
                "*[@class='chunk']"
            ):
                elements.append(
                    ElementType(
                        url=file,
                        xpath="".join(
                            [
                                node.text
                                for node in element.findall("*[@class='xpath']", ns_map)
                            ]
                        ),
                        content="".join(
                            [
                                node.text
                                for node in element.findall("*[@class='chunk']", ns_map)
                            ]
                        ),
                        metadata={
                            # Add text of specified headers to metadata using header
                            # mapping.
                            header_mapping[node.tag]: node.text
                            for node in filter(
                                lambda x: x.tag in header_filter,
                                element.findall("*[@class='headers']/*", ns_map),
                            )
                        },
                    )
                )

        if not self.return_each_element:
            return self.aggregate_elements_to_chunks(elements)
        else:
            return [
                Document(page_content=chunk["content"], metadata=chunk["metadata"])
                for chunk in elements
            ]
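

# Example usage (illustrative; requires the optional lxml package and the
# bundled html_chunks_with_headers.xslt stylesheet):
#
#     >>> splitter = HTMLHeaderTextSplitter(
#     ...     headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")]
#     ... )
#     >>> docs = splitter.split_text("<html><body><h1>Intro</h1><p>hi</p></body></html>")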


# should be in newer Python versions (3.10+)
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
    """Tokenizer data class."""

    chunk_overlap: int
    """Overlap in tokens between chunks"""
    tokens_per_chunk: int
    """Maximum number of tokens per chunk"""
    decode: Callable[[List[int]], str]
    """Function to decode a list of token ids to a string"""
    encode: Callable[[str], List[int]]
    """Function to encode a string to a list of token ids"""


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: List[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
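

# Illustrative behavior of split_text_on_tokens with a toy tokenizer that treats
# each character as one token (a sketch, not part of the upstream module): the
# window advances by tokens_per_chunk - chunk_overlap tokens per step.
#
#     >>> toy = Tokenizer(
#     ...     chunk_overlap=2,
#     ...     tokens_per_chunk=4,
#     ...     decode=lambda ids: "".join(chr(i) for i in ids),
#     ...     encode=lambda text: [ord(c) for c in text],
#     ... )
#     >>> split_text_on_tokens(text="abcdefgh", tokenizer=toy)
#     ['abcd', 'cdef', 'efgh', 'gh']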


class TokenTextSplitter(TextSplitter):
    """Splitting text to tokens using model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        def _encode(_text: str) -> List[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
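

# Example usage (illustrative; requires tiktoken): here both chunk_size and
# chunk_overlap are measured in tokens rather than characters.
#
#     >>> splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=10, chunk_overlap=2)
#     >>> chunks = splitter.split_text("some long text " * 50)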


class SentenceTransformersTokenTextSplitter(TextSplitter):
    """Splitting text to tokens using sentence model tokenizer."""

    def __init__(
        self,
        chunk_overlap: int = 50,
        model_name: str = "sentence-transformers/all-mpnet-base-v2",
        tokens_per_chunk: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs, chunk_overlap=chunk_overlap)

        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "This is needed for SentenceTransformersTokenTextSplitter. "
                "Please install it with `pip install sentence-transformers`."
            )

        self.model_name = model_name
        self._model = SentenceTransformer(self.model_name)
        self.tokenizer = self._model.tokenizer
        self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)

    def _initialize_chunk_configuration(
        self, *, tokens_per_chunk: Optional[int]
    ) -> None:
        self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)

        if tokens_per_chunk is None:
            self.tokens_per_chunk = self.maximum_tokens_per_chunk
        else:
            self.tokens_per_chunk = tokens_per_chunk

        if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
            raise ValueError(
                f"The token limit of the model '{self.model_name}'"
                f" is: {self.maximum_tokens_per_chunk}."
                f" Argument tokens_per_chunk={self.tokens_per_chunk}"
                f" > maximum token limit."
            )

    def split_text(self, text: str) -> List[str]:
        def encode_strip_start_and_stop_token_ids(text: str) -> List[int]:
            return self._encode(text)[1:-1]

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self.tokens_per_chunk,
            decode=self.tokenizer.decode,
            encode=encode_strip_start_and_stop_token_ids,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)

    def count_tokens(self, *, text: str) -> int:
        return len(self._encode(text))

    _max_length_equal_32_bit_integer: int = 2**32

    def _encode(self, text: str) -> List[int]:
        token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
            text,
            max_length=self._max_length_equal_32_bit_integer,
            truncation="do_not_truncate",
        )
        return token_ids_with_start_and_end_token_ids
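

# Example usage (illustrative; downloads the sentence-transformers model on
# first use, so this is a sketch rather than a doctest):
#
#     >>> splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
#     >>> n_tokens = splitter.count_tokens(text="hello world")
#     >>> chunks = splitter.split_text("hello world")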


class Language(str, Enum):
    """Enum of the programming languages."""

    CPP = "cpp"
    GO = "go"
    JAVA = "java"
    KOTLIN = "kotlin"
    JS = "js"
    TS = "ts"
    PHP = "php"
    PROTO = "proto"
    PYTHON = "python"
    RST = "rst"
    RUBY = "ruby"
    RUST = "rust"
    SCALA = "scala"
    SWIFT = "swift"
    MARKDOWN = "markdown"
    LATEX = "latex"
    HTML = "html"
    SOL = "sol"
    CSHARP = "csharp"
    COBOL = "cobol"


class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively looking at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex(text, _separator, self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> List[str]:
        return self._split_text(text, self._separators)
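
    # Example usage (illustrative): the default separators try paragraphs first,
    # then lines, then words, then individual characters.
    #
    #     >>> splitter = RecursiveCharacterTextSplitter(chunk_size=30, chunk_overlap=0)
    #     >>> splitter.split_text("para one\n\npara two is a bit longer")
    #     ['para one', 'para two is a bit longer']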

    @classmethod
    def from_language(
        cls, language: Language, **kwargs: Any
    ) -> RecursiveCharacterTextSplitter:
        separators = cls.get_separators_for_language(language)
        return cls(separators=separators, is_separator_regex=True, **kwargs)

    @staticmethod
    def get_separators_for_language(language: Language) -> List[str]:
        if language == Language.CPP:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along function definitions
                "\nvoid ",
                "\nint ",
                "\nfloat ",
                "\ndouble ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.GO:
            return [
                # Split along function definitions
                "\nfunc ",
                "\nvar ",
                "\nconst ",
                "\ntype ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.JAVA:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along method definitions
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\nstatic ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.KOTLIN:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along method definitions
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\ninternal ",
                "\ncompanion ",
                "\nfun ",
                "\nval ",
                "\nvar ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nwhen ",
                "\ncase ",
                "\nelse ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.JS:
            return [
                # Split along function definitions
                "\nfunction ",
                "\nconst ",
                "\nlet ",
                "\nvar ",
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\ndefault ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.TS:
            return [
                "\nenum ",
                "\ninterface ",
                "\nnamespace ",
                "\ntype ",
                # Split along class definitions
                "\nclass ",
                # Split along function definitions
                "\nfunction ",
                "\nconst ",
                "\nlet ",
                "\nvar ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\ndefault ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PHP:
            return [
                # Split along function definitions
                "\nfunction ",
                # Split along class definitions
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nforeach ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PROTO:
            return [
                # Split along message definitions
                "\nmessage ",
                # Split along service definitions
                "\nservice ",
                # Split along enum definitions
                "\nenum ",
                # Split along option definitions
                "\noption ",
                # Split along import statements
                "\nimport ",
                # Split along syntax declarations
                "\nsyntax ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PYTHON:
            return [
                # First, try to split along class definitions
                "\nclass ",
                "\ndef ",
                "\n\tdef ",
                # Now split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RST:
            return [
                # Split along section titles
                "\n=+\n",
                "\n-+\n",
                "\n\\*+\n",
                # Split along directive markers
                "\n\n.. *\n\n",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RUBY:
            return [
                # Split along method definitions
                "\ndef ",
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nunless ",
                "\nwhile ",
                "\nfor ",
                "\ndo ",
                "\nbegin ",
                "\nrescue ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RUST:
            return [
                # Split along function definitions
                "\nfn ",
                "\nconst ",
                "\nlet ",
                # Split along control flow statements
                "\nif ",
                "\nwhile ",
                "\nfor ",
                "\nloop ",
                "\nmatch ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SCALA:
            return [
                # Split along class definitions
                "\nclass ",
                "\nobject ",
                # Split along method definitions
                "\ndef ",
                "\nval ",
                "\nvar ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nmatch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SWIFT:
            return [
                # Split along function definitions
                "\nfunc ",
                # Split along class definitions
                "\nclass ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.MARKDOWN:
            return [
                # First, try to split along Markdown headings (starting with level 2)
                "\n#{1,6} ",
                # Note the alternative syntax for headings (below) is not handled here
                # Heading level 2
                # ---------------
                # End of code block
                "```\n",
                # Horizontal lines (three or more of ***, ---, or ___)
                "\n\\*\\*\\*+\n",
                "\n---+\n",
                "\n___+\n",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.LATEX:
            return [
                # First, try to split along Latex sections
                "\n\\\\chapter{",
                "\n\\\\section{",
                "\n\\\\subsection{",
                "\n\\\\subsubsection{",
                # Now split by environments
                "\n\\\\begin{enumerate}",
                "\n\\\\begin{itemize}",
                "\n\\\\begin{description}",
                "\n\\\\begin{list}",
                "\n\\\\begin{quote}",
                "\n\\\\begin{quotation}",
                "\n\\\\begin{verse}",
                "\n\\\\begin{verbatim}",
                # Now split by math environments
                "\n\\\\begin{align}",
                "$$",
                "$",
                # Now split by the normal type of lines
                " ",
                "",
            ]
        elif language == Language.HTML:
            return [
                # First, try to split along HTML tags
                "<body",
                "<div",
                "<p",
                "<br",
                "<li",
                "<h1",
                "<h2",
                "<h3",
                "<h4",
                "<h5",
                "<h6",
                "<span",
                "<table",
                "<tr",
                "<td",
                "<th",
                "<ul",
                "<ol",
                "<header",
                "<footer",
                "<nav",
                # Head
                "<head",
                "<style",
                "<script",
                "<meta",
                "<title",
                "",
            ]
        elif language == Language.CSHARP:
            return [
                "\ninterface ",
                "\nenum ",
                "\nimplements ",
                "\ndelegate ",
                "\nevent ",
                # Split along class definitions
                "\nclass ",
                "\nabstract ",
                # Split along method definitions
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\nstatic ",
                "\nreturn ",
                # Split along control flow statements
                "\nif ",
                "\ncontinue ",
                "\nfor ",
                "\nforeach ",
                "\nwhile ",
                "\nswitch ",
                "\nbreak ",
                "\ncase ",
                "\nelse ",
                # Split by exceptions
                "\ntry ",
                "\nthrow ",
                "\nfinally ",
                "\ncatch ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SOL:
            return [
                # Split along compiler information definitions
                "\npragma ",
                "\nusing ",
                # Split along contract definitions
                "\ncontract ",
                "\ninterface ",
                "\nlibrary ",
                # Split along method definitions
                "\nconstructor ",
                "\ntype ",
                "\nfunction ",
                "\nevent ",
                "\nmodifier ",
                "\nerror ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo while ",
                "\nassembly ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.COBOL:
            return [
                # Split along divisions
                "\nIDENTIFICATION DIVISION.",
                "\nENVIRONMENT DIVISION.",
                "\nDATA DIVISION.",
                "\nPROCEDURE DIVISION.",
                # Split along sections within DATA DIVISION
                "\nWORKING-STORAGE SECTION.",
                "\nLINKAGE SECTION.",
                "\nFILE SECTION.",
                # Split along sections within PROCEDURE DIVISION
                "\nINPUT-OUTPUT SECTION.",
                # Split along paragraphs and common statements
                "\nOPEN ",
                "\nCLOSE ",
                "\nREAD ",
                "\nWRITE ",
                "\nIF ",
                "\nELSE ",
                "\nMOVE ",
                "\nPERFORM ",
                "\nUNTIL ",
                "\nVARYING ",
                "\nACCEPT ",
                "\nDISPLAY ",
                "\nSTOP RUN.",
                # Split by the normal type of lines
                "\n",
                " ",
                "",
            ]
        else:
            raise ValueError(
                f"Language {language} is not supported! "
                f"Please choose from {list(Language)}"
            )
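

# Example usage (illustrative): from_language configures the splitter with the
# language-specific separators above, treated as regex patterns.
#
#     >>> source = "def f():\n    return 1\n\ndef g():\n    return 2\n"
#     >>> splitter = RecursiveCharacterTextSplitter.from_language(
#     ...     Language.PYTHON, chunk_size=32, chunk_overlap=0
#     ... )
#     >>> splitter.split_text(source)
#     ['def f():\n    return 1', 'def g():\n    return 2']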


class NLTKTextSplitter(TextSplitter):
    """Splitting text using NLTK package."""

    def __init__(
        self, separator: str = "\n\n", language: str = "english", **kwargs: Any
    ) -> None:
        """Initialize the NLTK splitter."""
        super().__init__(**kwargs)
        try:
            from nltk.tokenize import sent_tokenize

            self._tokenizer = sent_tokenize
        except ImportError:
            raise ImportError(
                "NLTK is not installed, please install it with `pip install nltk`."
            )
        self._separator = separator
        self._language = language

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = self._tokenizer(text, language=self._language)
        return self._merge_splits(splits, self._separator)
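

# Example usage (illustrative; requires nltk and its punkt sentence model,
# e.g. via nltk.download("punkt")):
#
#     >>> splitter = NLTKTextSplitter(chunk_size=100, chunk_overlap=0)
#     >>> chunks = splitter.split_text("First sentence. Second sentence.")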


class SpacyTextSplitter(TextSplitter):
    """Splitting text using Spacy package.

    By default, Spacy's `en_core_web_sm` model is used. For a faster, but
    potentially less accurate splitting, you can use `pipeline='sentencizer'`.
    """

    def __init__(
        self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any
    ) -> None:
        """Initialize the spacy text splitter."""
        super().__init__(**kwargs)
        self._tokenizer = _make_spacy_pipeline_for_splitting(pipeline)
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        splits = (s.text for s in self._tokenizer(text).sents)
        return self._merge_splits(splits, self._separator)


# For backwards compatibility
class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Python syntax."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a PythonCodeTextSplitter."""
        separators = self.get_separators_for_language(Language.PYTHON)
        super().__init__(separators=separators, **kwargs)


class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Markdown-formatted headings."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a MarkdownTextSplitter."""
        separators = self.get_separators_for_language(Language.MARKDOWN)
        super().__init__(separators=separators, **kwargs)


class LatexTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Latex-formatted layout elements."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a LatexTextSplitter."""
        separators = self.get_separators_for_language(Language.LATEX)
        super().__init__(separators=separators, **kwargs)
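

# Example usage (illustrative): the compatibility subclasses are thin wrappers
# that preconfigure RecursiveCharacterTextSplitter with language separators.
#
#     >>> md_splitter = MarkdownTextSplitter(chunk_size=100, chunk_overlap=0)
#     >>> chunks = md_splitter.split_text("# Title\n\nSome body text.")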