"""Enhanced chunker with proper token counting and merging strategies, inspired by Sage.""" import logging import os from typing import List, Dict, Any, Optional from dataclasses import dataclass from functools import cached_property import pygments import tiktoken from langchain_core.documents import Document from tree_sitter import Language, Parser, Node import tree_sitter_python import tree_sitter_javascript logger = logging.getLogger(__name__) tokenizer = tiktoken.get_encoding("cl100k_base") @dataclass class FileChunk: """Represents a chunk of code with byte positions and rich metadata.""" file_content: str file_metadata: Dict start_byte: int end_byte: int # Enhanced metadata fields symbols_defined: Optional[List[str]] = None # Functions/classes defined in this chunk imports_used: Optional[List[str]] = None # Import statements relevant to chunk complexity_score: Optional[int] = None # Cyclomatic complexity parent_context: Optional[str] = None # Parent class/module name @cached_property def filename(self): if "file_path" not in self.file_metadata: raise ValueError("file_metadata must contain a 'file_path' key.") return self.file_metadata["file_path"] @cached_property def content(self) -> str: """The text content to be embedded. Includes filename for context.""" return self.filename + "\n\n" + self.file_content[self.start_byte : self.end_byte] @cached_property def num_tokens(self): """Number of tokens in this chunk.""" return len(tokenizer.encode(self.content, disallowed_special=())) def to_document(self) -> Document: """Convert to LangChain Document with enhanced metadata.""" chunk_type = self.file_metadata.get("chunk_type", "code") name = self.file_metadata.get("name", None) # Calculate line range from byte positions lines_before = self.file_content[:self.start_byte].count('\n') lines_in_chunk = self.file_content[self.start_byte:self.end_byte].count('\n') line_range = f"L{lines_before + 1}-L{lines_before + lines_in_chunk + 1}" # Get language from file extension ext = self.filename.split('.')[-1].lower() if '.' in self.filename else 'unknown' language_map = { 'py': 'python', 'js': 'javascript', 'ts': 'typescript', 'jsx': 'javascript', 'tsx': 'typescript', 'java': 'java', 'cpp': 'cpp', 'c': 'c', 'go': 'go', 'rs': 'rust' } language = language_map.get(ext, ext) metadata = { **self.file_metadata, "id": f"{self.filename}_{self.start_byte}_{self.end_byte}", "start_byte": self.start_byte, "end_byte": self.end_byte, "length": self.end_byte - self.start_byte, "line_range": line_range, "language": language, "chunk_type": chunk_type, "name": name, } # Add enhanced metadata if available if self.symbols_defined: metadata["symbols"] = self.symbols_defined if self.imports_used: metadata["imports"] = self.imports_used if self.complexity_score is not None: metadata["complexity"] = self.complexity_score if self.parent_context: metadata["parent_context"] = self.parent_context return Document(page_content=self.content, metadata=metadata) class StructuralChunker: """ Chunks code files based on their AST structure (Functions, Classes) using Tree-sitter. Uses proper token counting with tiktoken and implements merging strategies to avoid pathologically small chunks. """ def __init__(self, max_tokens: int = 800): self.max_tokens = max_tokens self.parsers = {} self._init_parsers() def _init_parsers(self): try: self.parsers['py'] = Parser(Language(tree_sitter_python.language())) self.parsers['python'] = self.parsers['py'] js_parser = Parser(Language(tree_sitter_javascript.language())) self.parsers['js'] = js_parser self.parsers['javascript'] = js_parser self.parsers['jsx'] = js_parser self.parsers['ts'] = js_parser self.parsers['tsx'] = js_parser except Exception as e: logger.error(f"Error initializing parsers in Chunker: {e}") @staticmethod def _get_language_from_filename(filename: str) -> Optional[str]: """Returns a canonical name for the language based on file extension.""" extension = os.path.splitext(filename)[1] if extension == ".tsx": return "tsx" try: lexer = pygments.lexers.get_lexer_for_filename(filename) return lexer.name.lower() except pygments.util.ClassNotFound: return None @staticmethod def is_code_file(filename: str) -> bool: """Checks whether the file can be parsed as code.""" language = StructuralChunker._get_language_from_filename(filename) return language and language not in ["text only", "none"] def chunk(self, content: str, file_path: str) -> List[Document]: """Main chunking entry point.""" ext = file_path.split('.')[-1].lower() parser = self.parsers.get(ext) if "\0" in content: logger.warning(f"Binary content detected in {file_path}, skipping chunking") return [] if not parser: logger.warning(f"No parser found for extension: {ext}, treating as text file") # Fallback to simple text chunking for non-code files return self._chunk_text_file(content, file_path) try: tree = parser.parse(bytes(content, "utf8")) if not tree.root_node.children or tree.root_node.children[0].type == "ERROR": logger.warning(f"Failed to parse code in {file_path}, falling back to text chunking") return self._chunk_text_file(content, file_path) file_metadata = {"file_path": file_path, "chunk_type": "code", "_full_content": content} file_chunks = self._chunk_node(tree.root_node, content, file_metadata) # Convert FileChunk objects to Documents return [chunk.to_document() for chunk in file_chunks] except Exception as e: logger.error(f"Failed to chunk {file_path}: {e}, falling back to text chunking") return self._chunk_text_file(content, file_path) def _chunk_text_file(self, content: str, file_path: str) -> List[Document]: """Fallback chunking for text files.""" from langchain_text_splitters import RecursiveCharacterTextSplitter splitter = RecursiveCharacterTextSplitter( chunk_size=self.max_tokens * 4, # Approximate char count chunk_overlap=200, separators=["\n\n", "\n", " ", ""] ) texts = splitter.split_text(content) return [ Document( page_content=f"{file_path}\n\n{text}", metadata={"file_path": file_path, "chunk_type": "text"} ) for text in texts ] def _chunk_node(self, node: Node, file_content: str, file_metadata: Dict) -> List[FileChunk]: """ Recursively splits a node into chunks. If a node is small enough, returns it as a single chunk. If too large, recursively chunks its children and merges neighboring chunks when possible. """ node_chunk = FileChunk(file_content, file_metadata, node.start_byte, node.end_byte) # If chunk is small enough and not a module/program node, return it if node_chunk.num_tokens <= self.max_tokens and node.type not in ["module", "program"]: # Add metadata about the node type and name chunk_metadata = {**file_metadata} chunk_metadata["chunk_type"] = node.type name = self._get_node_name(node, file_content) if name: chunk_metadata["name"] = name # Extract enhanced metadata node_chunk.file_metadata = chunk_metadata node_chunk.symbols_defined = self._extract_symbols(node, file_content) node_chunk.imports_used = self._extract_imports(node, file_content) node_chunk.complexity_score = self._calculate_complexity(node, file_content) node_chunk.parent_context = self._get_parent_context(node, file_content) return [node_chunk] # If leaf node is too large, split it as text if not node.children: return self._chunk_large_text( file_content[node.start_byte : node.end_byte], node.start_byte, file_metadata ) # Recursively chunk children chunks = [] for child in node.children: chunks.extend(self._chunk_node(child, file_content, file_metadata)) # Merge neighboring chunks if their combined size doesn't exceed max_tokens merged_chunks = [] for chunk in chunks: if not merged_chunks: merged_chunks.append(chunk) elif merged_chunks[-1].num_tokens + chunk.num_tokens < self.max_tokens - 50: # Try merging merged = FileChunk( file_content, file_metadata, merged_chunks[-1].start_byte, chunk.end_byte, ) if merged.num_tokens <= self.max_tokens: merged_chunks[-1] = merged else: merged_chunks.append(chunk) else: merged_chunks.append(chunk) # Verify all chunks are within token limit for chunk in merged_chunks: if chunk.num_tokens > self.max_tokens: logger.warning( f"Chunk size {chunk.num_tokens} exceeds max_tokens {self.max_tokens} " f"for {chunk.filename} at bytes {chunk.start_byte}-{chunk.end_byte}" ) return merged_chunks def _chunk_large_text(self, text: str, start_offset: int, file_metadata: Dict) -> List[FileChunk]: """Splits large text (e.g., long comments or strings) into smaller chunks.""" # Need full file content for FileChunk to work properly file_content = file_metadata.get("_full_content", "") if not file_content: logger.warning("Cannot chunk large text without full file content") return [] from langchain_text_splitters import RecursiveCharacterTextSplitter splitter = RecursiveCharacterTextSplitter( chunk_size=self.max_tokens * 4, chunk_overlap=200 ) texts = splitter.split_text(text) chunks = [] current_offset = start_offset for text_chunk in texts: end_offset = current_offset + len(text_chunk) chunk = FileChunk( file_content, {**file_metadata, "chunk_type": "large_text"}, current_offset, end_offset ) chunks.append(chunk) current_offset = end_offset return chunks def _get_node_name(self, node: Node, content: str) -> Optional[str]: """Extracts the name of a function or class node.""" name_node = node.child_by_field_name("name") if name_node: return content[name_node.start_byte:name_node.end_byte] return None def _extract_symbols(self, node: Node, content: str) -> List[str]: """ Extract function and class names defined in this node. Returns: List of symbol names (e.g., ['MyClass', 'MyClass.my_method']) """ symbols = [] def traverse(n: Node, parent_class: Optional[str] = None): # Check if this is a function or class definition if n.type in ['function_definition', 'class_definition', 'method_definition']: name = self._get_node_name(n, content) if name: if parent_class: symbols.append(f"{parent_class}.{name}") else: symbols.append(name) # If it's a class, traverse its children with this class as parent if n.type == 'class_definition': for child in n.children: traverse(child, name) return # Don't traverse children again # Traverse children for child in n.children: traverse(child, parent_class) traverse(node) return symbols def _extract_imports(self, node: Node, content: str) -> List[str]: """ Extract import statements from this node. Returns: List of import statements (e.g., ['import os', 'from typing import List']) """ imports = [] def traverse(n: Node): # Python imports if n.type in ['import_statement', 'import_from_statement']: import_text = content[n.start_byte:n.end_byte].strip() imports.append(import_text) # JavaScript/TypeScript imports elif n.type == 'import_statement': import_text = content[n.start_byte:n.end_byte].strip() imports.append(import_text) # Traverse children for child in n.children: traverse(child) traverse(node) return imports def _calculate_complexity(self, node: Node, content: str) -> int: """ Calculate cyclomatic complexity for a code chunk. Cyclomatic complexity = number of decision points + 1 Decision points: if, elif, for, while, except, and, or, case, etc. Returns: Complexity score (integer) """ complexity = 1 # Base complexity # Decision point node types decision_nodes = { 'if_statement', 'elif_clause', 'else_clause', 'for_statement', 'while_statement', 'except_clause', 'case_clause', 'conditional_expression', # ternary operator 'boolean_operator', # and, or } def traverse(n: Node): nonlocal complexity if n.type in decision_nodes: complexity += 1 for child in n.children: traverse(child) traverse(node) return complexity def _get_parent_context(self, node: Node, content: str) -> Optional[str]: """ Get the parent class or module context for this node. Returns: Parent class name or None """ current = node.parent while current: if current.type == 'class_definition': name = self._get_node_name(current, content) if name: return name current = current.parent return None