| """Chunker abstraction and implementations.""" | |
| import logging | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from typing import List, Optional | |
| import nbformat | |
| import pygments | |
| import tiktoken | |
| from semchunk import chunk as chunk_via_semchunk | |
| from tree_sitter import Node | |
| from tree_sitter_language_pack import get_parser | |
| logger = logging.getLogger(__name__) | |
@dataclass
class Chunk:
    """A chunk of code or text extracted from a file in the repository."""

    filename: str
    start_byte: int
    end_byte: int
    _content: Optional[str] = None

    @property
    def content(self) -> Optional[str]:
        """The text content to be embedded. Might contain information beyond just the text snippet from the file."""
        return self._content

    def to_metadata(self):
        """Converts the chunk to a dictionary that can be passed to a vector store."""
        # Some vector stores require the IDs to be ASCII.
        filename_ascii = self.filename.encode("ascii", "ignore").decode("ascii")
        return {
            "id": f"{filename_ascii}_{self.start_byte}_{self.end_byte}",
            "filename": self.filename,
            "start_byte": self.start_byte,
            "end_byte": self.end_byte,
            # Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
            # size limit. In that case, you can simply store the start/end bytes above, and fetch the content
            # directly from the repository when needed.
            "text": self.content,
        }

    def populate_content(self, file_content: str):
        """Populates the content of the chunk with the file path and file content."""
        self._content = self.filename + "\n\n" + file_content[self.start_byte : self.end_byte]

    def num_tokens(self, tokenizer) -> int:
        """Counts the number of tokens in the chunk."""
        if not self.content:
            raise ValueError("Content not populated.")
        return Chunk._cached_num_tokens(self.content, tokenizer)

    @staticmethod
    @lru_cache(maxsize=10_000)
    def _cached_num_tokens(content: str, tokenizer) -> int:
        """Static method to cache token counts, since the same chunk may be re-counted repeatedly during merging."""
        return len(tokenizer.encode(content, disallowed_special=()))

    def __eq__(self, other):
        if isinstance(other, Chunk):
            return (
                self.filename == other.filename
                and self.start_byte == other.start_byte
                and self.end_byte == other.end_byte
            )
        return False

    def __hash__(self):
        return hash((self.filename, self.start_byte, self.end_byte))

class Chunker(ABC):
    """Abstract class for chunking a file into smaller pieces."""

    @abstractmethod
    def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
        """Chunks a file into smaller pieces."""

class CodeChunker(Chunker):
    """Splits a code file into chunks of at most `max_tokens` tokens each."""

    def __init__(self, max_tokens: int):
        self.max_tokens = max_tokens
        self.tokenizer = tiktoken.get_encoding("cl100k_base")
        self.text_chunker = TextChunker(max_tokens)

    @staticmethod
    def _get_language_from_filename(filename: str):
        """Returns a canonical name for the language of the file, based on its extension.

        Returns None if the language is unknown to the pygments lexer.
        """
        try:
            lexer = pygments.lexers.get_lexer_for_filename(filename)
            return lexer.name.lower()
        except pygments.util.ClassNotFound:
            return None

    def _chunk_node(self, node: Node, filename: str, file_content: str) -> List[Chunk]:
        """Splits a node in the parse tree into a flat list of chunks."""
        node_chunk = Chunk(filename, node.start_byte, node.end_byte)
        node_chunk.populate_content(file_content)
        if node_chunk.num_tokens(self.tokenizer) <= self.max_tokens:
            return [node_chunk]

        if not node.children:
            # This is a leaf node, but it's too long. We'll have to split it with a text tokenizer.
            return self.text_chunker.chunk(filename, file_content[node.start_byte : node.end_byte])

        chunks = []
        for child in node.children:
            chunks.extend(self._chunk_node(child, filename, file_content))
        for chunk in chunks:
            # This should always be true. Otherwise there must be a bug in the code.
            assert chunk.content and chunk.num_tokens(self.tokenizer) <= self.max_tokens

        # Merge neighboring chunks if their combined size doesn't exceed max_tokens. The goal is to avoid
        # pathologically small chunks that end up being undeservedly preferred by the retriever.
        merged_chunks = []
        for chunk in chunks:
            if not merged_chunks:
                merged_chunks.append(chunk)
            elif merged_chunks[-1].num_tokens(self.tokenizer) + chunk.num_tokens(self.tokenizer) < self.max_tokens - 50:
                # There's a good chance that merging these two chunks will be under the token limit. We're not
                # 100% sure at this point, because tokenization is not necessarily additive.
                merged = Chunk(
                    merged_chunks[-1].filename,
                    merged_chunks[-1].start_byte,
                    chunk.end_byte,
                )
                merged.populate_content(file_content)
                if merged.num_tokens(self.tokenizer) <= self.max_tokens:
                    merged_chunks[-1] = merged
                else:
                    merged_chunks.append(chunk)
            else:
                merged_chunks.append(chunk)

        for chunk in merged_chunks:
            # This should always be true. Otherwise there's a bug worth investigating.
            assert chunk.content and chunk.num_tokens(self.tokenizer) <= self.max_tokens
        return merged_chunks

    @staticmethod
    def is_code_file(filename: str) -> bool:
        """Checks whether pygments & tree_sitter can parse the file as code."""
        language = CodeChunker._get_language_from_filename(filename)
        return language is not None and language not in ["text only", "None"]

    @staticmethod
    def parse_tree(filename: str, content: str) -> Optional[Tree]:
        """Parses the code in a file and returns the parse tree, or None if the file can't be parsed as code."""
        language = CodeChunker._get_language_from_filename(filename)
        if not language or language in ["text only", "None"]:
            logging.debug("%s doesn't seem to be a code file.", filename)
            return None
        try:
            parser = get_parser(language)
        except LookupError:
            logging.debug("%s doesn't seem to be a code file.", filename)
            return None
        tree = parser.parse(bytes(content, "utf8"))
        if not tree.root_node.children or tree.root_node.children[0].type == "ERROR":
            logging.warning("Failed to parse code in %s.", filename)
            return None
        return tree

    def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
        """Chunks a code file into smaller pieces."""
        if not file_content.strip():
            return []

        tree = self.parse_tree(file_path, file_content)
        if tree is None:
            return []

        chunks = self._chunk_node(tree.root_node, file_path, file_content)
        for chunk in chunks:
            # Make sure that the chunk has content and doesn't exceed the max_tokens limit. Otherwise there must
            # be a bug in the code.
            assert chunk.content
            size = chunk.num_tokens(self.tokenizer)
            assert size <= self.max_tokens, f"Chunk size {size} exceeds max_tokens {self.max_tokens}."
        return chunks

class TextChunker(Chunker):
    """Wrapper around semchunk: https://github.com/umarbutler/semchunk."""

    def __init__(self, max_tokens: int):
        self.max_tokens = max_tokens
        tokenizer = tiktoken.get_encoding("cl100k_base")
        self.count_tokens = lambda text: len(tokenizer.encode(text, disallowed_special=()))

    def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
        """Chunks a text file into smaller pieces."""
        # We need to allocate some tokens for the filename, which is part of the chunk content.
        extra_tokens = self.count_tokens(file_path + "\n\n")
        text_chunks = chunk_via_semchunk(file_content, self.max_tokens - extra_tokens, self.count_tokens)

        chunks = []
        start = 0
        for text_chunk in text_chunks:
            # This assertion should always be true. Otherwise there's a bug worth finding.
            assert self.count_tokens(text_chunk) <= self.max_tokens - extra_tokens

            # Find the start/end positions of the chunk. Use find() rather than index() so that a missing
            # chunk yields -1 instead of raising ValueError.
            start = file_content.find(text_chunk, start)
            if start == -1:
                logging.warning("Couldn't find semchunk in content: %s", text_chunk)
            else:
                end = start + len(text_chunk)
                chunks.append(Chunk(file_path, start, end, text_chunk))
                start = end
        return chunks

class IPYNBChunker(Chunker):
    """Extracts the Python code from a Jupyter notebook, removing all the boilerplate.

    Based on https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb
    """

    def __init__(self, code_chunker: CodeChunker):
        self.code_chunker = code_chunker

    def chunk(self, filename: str, content: str) -> List[Chunk]:
        if not filename.lower().endswith(".ipynb"):
            logging.warning("IPYNBChunker is only for .ipynb files.")
            return []

        notebook = nbformat.reads(content, as_version=nbformat.NO_CONVERT)
        python_code = "\n".join([cell.source for cell in notebook.cells if cell.cell_type == "code"])

        # Chunk the extracted code as if it were a regular Python file.
        chunks = self.code_chunker.chunk(filename.replace(".ipynb", ".py"), python_code)
        # Change the filenames back to .ipynb.
        for chunk in chunks:
            chunk.filename = chunk.filename.replace(".py", ".ipynb")
        return chunks

class UniversalChunker(Chunker):
    """Chunks a file into smaller pieces, regardless of whether it's code or text."""

    def __init__(self, max_tokens: int):
        self.code_chunker = CodeChunker(max_tokens)
        self.text_chunker = TextChunker(max_tokens)

    def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
        if file_path.lower().endswith(".ipynb"):
            return IPYNBChunker(self.code_chunker).chunk(file_path, file_content)
        if CodeChunker.is_code_file(file_path):
            return self.code_chunker.chunk(file_path, file_content)
        return self.text_chunker.chunk(file_path, file_content)
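

# Minimal usage sketch (illustrative only, not part of the module's API): the file paths, contents, and
# max_tokens value below are made-up examples. It shows UniversalChunker dispatching a .py file to the
# code chunker and a plain-text file to the text chunker, and printing the metadata dicts that would be
# sent to a vector store.
if __name__ == "__main__":
    chunker = UniversalChunker(max_tokens=800)  # hypothetical per-chunk token budget

    sample_code = "def greet(name):\n    return f'Hello, {name}!'\n"
    for chunk in chunker.chunk("example/hello.py", sample_code):
        print(chunk.to_metadata())

    sample_text = "These are plain-text notes about installing and running the project.\n"
    for chunk in chunker.chunk("example/notes.txt", sample_text):
        print(chunk.to_metadata())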