Spaces:
Running
Running
| """Enhanced chunker with proper token counting and merging strategies, inspired by Sage.""" | |
| import logging | |
| import os | |
| from typing import List, Dict, Any, Optional | |
| from dataclasses import dataclass | |
| from functools import cached_property | |
| import pygments | |
| import tiktoken | |
| from langchain_core.documents import Document | |
| from tree_sitter import Language, Parser, Node | |
| import tree_sitter_python | |
| import tree_sitter_javascript | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Shared tiktoken encoder for the "cl100k_base" encoding; FileChunk.num_tokens
# uses it to count the tokens of a chunk's content.
tokenizer = tiktoken.get_encoding("cl100k_base")
@dataclass
class FileChunk:
    """Represents a chunk of code with byte positions and rich metadata.

    The chunk is the slice ``file_content[start_byte:end_byte]``; the full
    file text is kept so that line numbers can be derived later.

    NOTE(review): ``start_byte``/``end_byte`` are used directly as ``str``
    slice bounds. Callers fill them from tree-sitter byte offsets, which only
    coincide with character indices for ASCII content -- confirm how
    non-ASCII files should be handled.

    Attributes:
        file_content: Full text of the file the chunk was cut from.
        file_metadata: Per-file/per-chunk metadata; must contain ``file_path``.
        start_byte: Start offset of the chunk inside ``file_content``.
        end_byte: End offset (exclusive) of the chunk inside ``file_content``.
    """

    file_content: str
    file_metadata: Dict
    start_byte: int
    end_byte: int
    # Enhanced metadata fields
    symbols_defined: Optional[List[str]] = None  # Functions/classes defined in this chunk
    imports_used: Optional[List[str]] = None  # Import statements relevant to chunk
    complexity_score: Optional[int] = None  # Cyclomatic complexity
    parent_context: Optional[str] = None  # Parent class/module name

    @cached_property
    def filename(self) -> str:
        """Path of the source file, read from ``file_metadata["file_path"]``.

        Raises:
            ValueError: If ``file_metadata`` has no ``file_path`` key.
        """
        if "file_path" not in self.file_metadata:
            raise ValueError("file_metadata must contain a 'file_path' key.")
        return self.file_metadata["file_path"]

    @cached_property
    def content(self) -> str:
        """The text content to be embedded. Includes filename for context."""
        return self.filename + "\n\n" + self.file_content[self.start_byte : self.end_byte]

    @cached_property
    def num_tokens(self) -> int:
        """Number of tokens in this chunk (filename header included)."""
        # disallowed_special=() lets text that looks like a special token be
        # encoded as ordinary text instead of raising.
        return len(tokenizer.encode(self.content, disallowed_special=()))

    def to_document(self) -> "Document":
        """Convert to LangChain Document with enhanced metadata.

        Returns:
            A ``Document`` whose ``page_content`` is :attr:`content` and whose
            metadata is ``file_metadata`` extended with an id, the offsets,
            the 1-based line range, the language, and any enhanced fields
            that are set.
        """
        chunk_type = self.file_metadata.get("chunk_type", "code")
        name = self.file_metadata.get("name", None)

        # Calculate line range from byte positions
        lines_before = self.file_content[: self.start_byte].count("\n")
        lines_in_chunk = self.file_content[self.start_byte : self.end_byte].count("\n")
        line_range = f"L{lines_before + 1}-L{lines_before + lines_in_chunk + 1}"

        # Get language from file extension; unknown extensions are passed
        # through unchanged.
        ext = self.filename.split(".")[-1].lower() if "." in self.filename else "unknown"
        language_map = {
            "py": "python", "js": "javascript", "ts": "typescript",
            "jsx": "javascript", "tsx": "typescript", "java": "java",
            "cpp": "cpp", "c": "c", "go": "go", "rs": "rust",
        }
        language = language_map.get(ext, ext)

        metadata = {
            **self.file_metadata,
            "id": f"{self.filename}_{self.start_byte}_{self.end_byte}",
            "start_byte": self.start_byte,
            "end_byte": self.end_byte,
            "length": self.end_byte - self.start_byte,
            "line_range": line_range,
            "language": language,
            "chunk_type": chunk_type,
            "name": name,
        }

        # Add enhanced metadata if available
        if self.symbols_defined:
            metadata["symbols"] = self.symbols_defined
        if self.imports_used:
            metadata["imports"] = self.imports_used
        if self.complexity_score is not None:
            metadata["complexity"] = self.complexity_score
        if self.parent_context:
            metadata["parent_context"] = self.parent_context

        return Document(page_content=self.content, metadata=metadata)
class StructuralChunker:
    """
    Chunks code files based on their AST structure (Functions, Classes) using Tree-sitter.

    Uses proper token counting with tiktoken and implements merging strategies to avoid
    pathologically small chunks. Files without a registered parser, and files that
    fail to parse, fall back to plain character-based text splitting.

    NOTE(review): tree-sitter reports UTF-8 byte offsets, while this class (and
    FileChunk) use them as ``str`` indices. The two only agree for ASCII
    content -- confirm how non-ASCII files should be handled.
    """

    # Node types that add an execution path. ``else_clause`` is deliberately
    # absent: an else branch is the fall-through of a decision already counted.
    _DECISION_NODE_TYPES = frozenset({
        'if_statement', 'elif_clause',
        'for_statement', 'while_statement',
        'except_clause', 'case_clause',
        'conditional_expression',  # ternary operator
        'boolean_operator',  # and, or
    })

    def __init__(self, max_tokens: int = 800):
        """
        Args:
            max_tokens: Upper bound on the token count of a single chunk.
        """
        self.max_tokens = max_tokens
        self.parsers = {}
        self._init_parsers()

    def _init_parsers(self):
        """Register tree-sitter parsers keyed by file extension / language name.

        Failures are logged rather than raised, so the chunker still works
        (through the text fallback) for every extension left unregistered.
        """
        try:
            self.parsers['py'] = Parser(Language(tree_sitter_python.language()))
            self.parsers['python'] = self.parsers['py']

            js_parser = Parser(Language(tree_sitter_javascript.language()))
            # The JavaScript grammar is registered for the TypeScript
            # extensions as well.
            for key in ('js', 'javascript', 'jsx', 'ts', 'tsx'):
                self.parsers[key] = js_parser
        except Exception as e:
            logger.error(f"Error initializing parsers in Chunker: {e}")

    @staticmethod
    def _get_language_from_filename(filename: str) -> Optional[str]:
        """Returns a canonical name for the language based on file extension.

        Returns:
            ``"tsx"`` for ``.tsx`` files, otherwise the lower-cased name of the
            Pygments lexer matching the filename, or None when no lexer matches.
        """
        # Import the submodules explicitly: a bare ``import pygments`` does not
        # guarantee that ``pygments.lexers`` / ``pygments.util`` are loaded.
        from pygments.lexers import get_lexer_for_filename
        from pygments.util import ClassNotFound

        extension = os.path.splitext(filename)[1]
        if extension == ".tsx":
            return "tsx"
        try:
            lexer = get_lexer_for_filename(filename)
        except ClassNotFound:
            return None
        return lexer.name.lower()

    @staticmethod
    def is_code_file(filename: str) -> bool:
        """Checks whether the file can be parsed as code."""
        language = StructuralChunker._get_language_from_filename(filename)
        # bool() so that callers always get True/False, never None or "".
        return bool(language) and language not in ("text only", "none")

    def chunk(self, content: str, file_path: str) -> List["Document"]:
        """Main chunking entry point.

        Args:
            content: Full text of the file.
            file_path: Path of the file; its extension selects the parser.

        Returns:
            One Document per chunk. Empty when the content contains a NUL
            character (treated as binary).
        """
        ext = file_path.split('.')[-1].lower()
        parser = self.parsers.get(ext)

        if "\0" in content:
            logger.warning(f"Binary content detected in {file_path}, skipping chunking")
            return []

        if not parser:
            logger.warning(f"No parser found for extension: {ext}, treating as text file")
            # Fallback to simple text chunking for non-code files
            return self._chunk_text_file(content, file_path)

        try:
            tree = parser.parse(bytes(content, "utf8"))
            if not tree.root_node.children or tree.root_node.children[0].type == "ERROR":
                logger.warning(f"Failed to parse code in {file_path}, falling back to text chunking")
                return self._chunk_text_file(content, file_path)

            # The file text is passed around explicitly instead of being
            # stored in the metadata, which is copied into every Document.
            file_metadata = {"file_path": file_path, "chunk_type": "code"}
            file_chunks = self._chunk_node(tree.root_node, content, file_metadata)

            # Convert FileChunk objects to Documents
            return [file_chunk.to_document() for file_chunk in file_chunks]
        except Exception as e:
            # Best effort: any failure in structural chunking degrades to
            # text chunking rather than losing the file.
            logger.error(f"Failed to chunk {file_path}: {e}, falling back to text chunking")
            return self._chunk_text_file(content, file_path)

    def _chunk_text_file(self, content: str, file_path: str) -> List["Document"]:
        """Fallback chunking for text files.

        Splits by characters; each Document's content is prefixed with the
        file path, mirroring FileChunk.content.
        """
        from langchain_text_splitters import RecursiveCharacterTextSplitter

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.max_tokens * 4,  # Approximate char count
            chunk_overlap=200,
            separators=["\n\n", "\n", " ", ""],
        )
        return [
            Document(
                page_content=f"{file_path}\n\n{text}",
                metadata={"file_path": file_path, "chunk_type": "text"},
            )
            for text in splitter.split_text(content)
        ]

    def _chunk_node(self, node: "Node", file_content: str, file_metadata: Dict) -> List["FileChunk"]:
        """
        Recursively splits a node into chunks.

        If a node is small enough, returns it as a single chunk.
        If too large, recursively chunks its children and merges neighboring chunks when possible.

        Args:
            node: Tree-sitter node to split.
            file_content: Full text of the file the node belongs to.
            file_metadata: Metadata shared by every chunk of the file.

        Returns:
            Chunks in source order.
        """
        node_chunk = FileChunk(file_content, file_metadata, node.start_byte, node.end_byte)

        # If chunk is small enough and not a module/program node, return it
        if node_chunk.num_tokens <= self.max_tokens and node.type not in ["module", "program"]:
            # Add metadata about the node type and name
            chunk_metadata = {**file_metadata}
            chunk_metadata["chunk_type"] = node.type
            name = self._get_node_name(node, file_content)
            if name:
                chunk_metadata["name"] = name

            # Extract enhanced metadata
            node_chunk.file_metadata = chunk_metadata
            node_chunk.symbols_defined = self._extract_symbols(node, file_content)
            node_chunk.imports_used = self._extract_imports(node, file_content)
            node_chunk.complexity_score = self._calculate_complexity(node, file_content)
            node_chunk.parent_context = self._get_parent_context(node, file_content)
            return [node_chunk]

        # If leaf node is too large, split it as text
        if not node.children:
            return self._chunk_large_text(
                file_content[node.start_byte : node.end_byte],
                node.start_byte,
                file_metadata,
                file_content,
            )

        # Recursively chunk children
        chunks = []
        for child in node.children:
            chunks.extend(self._chunk_node(child, file_content, file_metadata))

        # Merge neighboring chunks if their combined size doesn't exceed max_tokens
        merged_chunks = []
        for chunk in chunks:
            if not merged_chunks:
                merged_chunks.append(chunk)
            elif merged_chunks[-1].num_tokens + chunk.num_tokens < self.max_tokens - 50:
                # The sum is only an estimate (each side counts the filename
                # header, the gap between them is not counted), so the merged
                # span is re-measured before it is accepted.
                merged = FileChunk(
                    file_content,
                    file_metadata,
                    merged_chunks[-1].start_byte,
                    chunk.end_byte,
                )
                if merged.num_tokens <= self.max_tokens:
                    merged_chunks[-1] = merged
                else:
                    merged_chunks.append(chunk)
            else:
                merged_chunks.append(chunk)

        # Verify all chunks are within token limit
        for chunk in merged_chunks:
            if chunk.num_tokens > self.max_tokens:
                logger.warning(
                    f"Chunk size {chunk.num_tokens} exceeds max_tokens {self.max_tokens} "
                    f"for {chunk.filename} at bytes {chunk.start_byte}-{chunk.end_byte}"
                )

        return merged_chunks

    def _chunk_large_text(
        self,
        text: str,
        start_offset: int,
        file_metadata: Dict,
        file_content: Optional[str] = None,
    ) -> List["FileChunk"]:
        """Splits large text (e.g., long comments or strings) into smaller chunks.

        Args:
            text: The oversized span to split.
            start_offset: Offset of ``text`` inside the full file content.
            file_metadata: Metadata shared by every chunk of the file.
            file_content: Full file text. When omitted, the legacy
                ``file_metadata["_full_content"]`` entry is used instead.

        Returns:
            Non-overlapping chunks whose offsets point at the actual position
            of each piece; empty when the full file content is unavailable.
        """
        if file_content is None:
            file_content = file_metadata.get("_full_content", "")
        if not file_content:
            logger.warning("Cannot chunk large text without full file content")
            return []

        from langchain_text_splitters import RecursiveCharacterTextSplitter

        # No overlap: the pieces are mapped back to offset ranges, and the
        # callers treat the result as consecutive, non-overlapping spans.
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.max_tokens * 4,
            chunk_overlap=0,
        )

        # Never copy the whole file into per-chunk metadata.
        chunk_metadata = {
            key: value for key, value in file_metadata.items() if key != "_full_content"
        }
        chunk_metadata["chunk_type"] = "large_text"

        chunks = []
        cursor = 0
        for piece in splitter.split_text(text):
            # The splitter strips surrounding whitespace, so a running sum of
            # piece lengths drifts; locate each piece in the original text.
            position = text.find(piece, cursor)
            if position == -1:
                position = cursor  # Defensive: fall back to contiguous placement
            begin = start_offset + position
            end = begin + len(piece)
            chunks.append(FileChunk(file_content, chunk_metadata, begin, end))
            cursor = position + len(piece)

        return chunks

    def _get_node_name(self, node: "Node", content: str) -> Optional[str]:
        """Extracts the name of a function or class node.

        Returns:
            The text of the node's ``name`` field, or None when it has none.
        """
        name_node = node.child_by_field_name("name")
        if name_node:
            return content[name_node.start_byte:name_node.end_byte]
        return None

    def _extract_symbols(self, node: "Node", content: str) -> List[str]:
        """
        Extract function and class names defined in this node.

        Returns:
            List of symbol names (e.g., ['MyClass', 'MyClass.my_method'])
        """
        symbols = []

        def traverse(n, parent_class: Optional[str] = None):
            # Check if this is a function or class definition
            if n.type in ('function_definition', 'class_definition', 'method_definition'):
                name = self._get_node_name(n, content)
                if name:
                    symbols.append(f"{parent_class}.{name}" if parent_class else name)

                # If it's a class, traverse its children with this class as parent
                if n.type == 'class_definition':
                    for child in n.children:
                        traverse(child, name)
                    return  # Don't traverse children again

            # Traverse children
            for child in n.children:
                traverse(child, parent_class)

        traverse(node)
        return symbols

    def _extract_imports(self, node: "Node", content: str) -> List[str]:
        """
        Extract import statements from this node.

        Returns:
            List of import statements (e.g., ['import os', 'from typing import List'])
        """
        imports = []

        def traverse(n):
            # 'import_statement' is matched here for both Python and
            # JavaScript sources; 'import_from_statement' is the Python
            # ``from ... import`` form.
            if n.type in ('import_statement', 'import_from_statement'):
                imports.append(content[n.start_byte:n.end_byte].strip())

            # Traverse children
            for child in n.children:
                traverse(child)

        traverse(node)
        return imports

    def _calculate_complexity(self, node: "Node", content: str) -> int:
        """
        Calculate cyclomatic complexity for a code chunk.

        Cyclomatic complexity = number of decision points + 1
        Decision points: if, elif, for, while, except, and, or, case, etc.

        Returns:
            Complexity score (integer)
        """
        complexity = 1  # Base complexity

        # Iterative walk: the count does not depend on visiting order, and an
        # explicit stack avoids recursion limits on deeply nested code.
        pending = [node]
        while pending:
            current = pending.pop()
            if current.type in self._DECISION_NODE_TYPES:
                complexity += 1
            pending.extend(current.children)

        return complexity

    def _get_parent_context(self, node: "Node", content: str) -> Optional[str]:
        """
        Get the parent class or module context for this node.

        Returns:
            Name of the nearest enclosing named class, or None
        """
        current = node.parent
        while current:
            if current.type == 'class_definition':
                name = self._get_node_name(current, content)
                if name:
                    return name
            current = current.parent
        return None