Spaces:
Sleeping
Sleeping
| """ | |
| AST-based semantic code chunker - Primary source of truth for code structure. | |
| This module implements the core AST-based chunking strategy that forms the | |
| authority layer of our hybrid chunking pipeline. It uses Python's built-in | |
| AST parser to extract semantic chunks (modules, classes, functions, methods) | |
| while preserving hierarchical relationships. | |
| ARCHITECTURE POSITION: | |
| - Authority Layer: Source of truth for semantic structure | |
| - Primary Chunker: Generates all primary chunks | |
| - Hierarchy Builder: Establishes parent-child relationships | |
| KEY FEATURES: | |
| 1. AST-first parsing for semantic accuracy | |
| 2. Hierarchical chunk generation with depth tracking | |
| 3. Byte-level span calculation for precise positioning | |
| 4. Import and decorator extraction per node | |
| 5. Deterministic chunk ID generation | |
| FLOW: | |
| File → Python AST → ASTChunker visitor → Semantic chunks with hierarchy | |
| USAGE: | |
| from ast_chunker import extract_ast_chunks | |
| chunks = extract_ast_chunks(Path("file.py")) | |
| """ | |
| import ast | |
| from pathlib import Path | |
| from typing import List, Optional, Union, Dict, Tuple | |
| import hashlib | |
| from ..utils.id_utils import deterministic_chunk_id | |
| from .chunk_schema import CodeChunk, ChunkAST, ChunkSpan, ChunkHierarchy, ASTSymbolType, ChunkType | |
| DocNode = Union[ | |
| ast.Module, | |
| ast.ClassDef, | |
| ast.FunctionDef, | |
| ast.AsyncFunctionDef, | |
| ] | |
| class ASTChunker(ast.NodeVisitor): | |
| def __init__(self, source: str, file_path: str): | |
| self.source = source | |
| self.file_path = file_path | |
| self.source_bytes = source.encode('utf-8') | |
| self.chunks: List[CodeChunk] = [] | |
| self.tree = ast.parse(source) | |
| # Track hierarchy | |
| self.current_class: Optional[str] = None | |
| self.imports_list: List[str] = [] | |
| # For hierarchy tracking | |
| self.parent_stack: List[CodeChunk] = [] | |
| self.sibling_counters: Dict[str, int] = {} | |
| # Attach parents to nodes | |
| for node in ast.walk(self.tree): | |
| for child in ast.iter_child_nodes(node): | |
| setattr(child, "parent", node) | |
| # ---------------- utilities ---------------- | |
| def _get_code(self, node: ast.AST) -> str: | |
| code = ast.get_source_segment(self.source, node) | |
| return code.strip() if code else "" | |
| def _get_byte_span(self, start_line: int, end_line: int) -> Tuple[int, int]: | |
| """Convert line numbers to byte positions""" | |
| lines = self.source.split('\n') | |
| # Calculate start byte | |
| start_byte = sum(len(line.encode()) + 1 for line in lines[:start_line-1]) | |
| # Calculate end byte (up to end_line) | |
| end_byte = sum(len(line.encode()) + 1 for line in lines[:end_line]) | |
| return start_byte, end_byte | |
| def _extract_node_imports(self, node: ast.AST) -> List[str]: | |
| """Extract imports specific to this node (not all module imports)""" | |
| imports: List[str] = [] | |
| # Walk through this node's body | |
| for child in ast.walk(node): | |
| if isinstance(child, (ast.Import, ast.ImportFrom)): | |
| try: | |
| imports.append(ast.unparse(child)) | |
| except Exception: | |
| imports.append(str(child)) | |
| return imports | |
| def _extract_decorators(self, node: ast.AST) -> List[str]: | |
| decorators: List[str] = [] | |
| if hasattr(node, "decorator_list"): | |
| for d in node.decorator_list: # type: ignore[attr-defined] | |
| try: | |
| decorators.append(ast.unparse(d)) | |
| except Exception: | |
| decorators.append(str(d)) | |
| return decorators | |
| # ---------------- chunk creation ---------------- | |
| def _create_chunk( | |
| self, | |
| node: DocNode, | |
| chunk_type: ChunkType, | |
| name: str, | |
| parent: Optional[str] = None, | |
| parent_chunk: Optional[CodeChunk] = None, | |
| ) -> CodeChunk: | |
| code = self._get_code(node) | |
| # Get line numbers | |
| start_line = getattr(node, "lineno", None) | |
| end_line = getattr(node, "end_lineno", None) | |
| # Calculate byte span | |
| start_byte, end_byte = None, None | |
| if start_line and end_line: | |
| start_byte, end_byte = self._get_byte_span(start_line, end_line) | |
| # Determine parent if not provided | |
| if parent is None and chunk_type == "method": | |
| parent = self.current_class | |
| decorators: List[str] = [] | |
| if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)): | |
| decorators = self._extract_decorators(node) | |
| # Get imports specific to this node (not all module imports) | |
| node_imports = self._extract_node_imports(node) | |
| # Get docstring only for nodes that can have one | |
| docstring: Optional[str] = None | |
| if hasattr(node, 'body'): | |
| docstring = ast.get_docstring(node) | |
| # Determine hierarchy depth | |
| depth = 0 | |
| lineage: List[str] = [] | |
| sibling_index = 0 | |
| if parent_chunk: | |
| depth = parent_chunk.hierarchy.depth + 1 | |
| lineage = parent_chunk.hierarchy.lineage.copy() | |
| lineage.append(parent_chunk.chunk_id) | |
| # Update sibling counter | |
| parent_key = parent_chunk.chunk_id | |
| self.sibling_counters[parent_key] = self.sibling_counters.get(parent_key, 0) + 1 | |
| sibling_index = self.sibling_counters[parent_key] - 1 | |
| ast_info = ChunkAST( | |
| symbol_type=chunk_type, | |
| name=name, | |
| parent=parent, | |
| docstring=docstring, | |
| decorators=decorators, | |
| imports=node_imports, | |
| ) | |
| span = ChunkSpan( | |
| start_byte=start_byte, | |
| end_byte=end_byte, | |
| start_line=start_line, | |
| end_line=end_line, | |
| ) | |
| # Generate chunk ID | |
| chunk_id = deterministic_chunk_id( | |
| file_path=self.file_path, | |
| chunk_type=chunk_type, | |
| name=name, | |
| parent=parent, | |
| start_line=start_line, | |
| end_line=end_line, | |
| code=code, | |
| ) | |
| chunk = CodeChunk( | |
| chunk_id=chunk_id, | |
| file_path=self.file_path, | |
| language="python", | |
| chunk_type=chunk_type, | |
| code=code, | |
| ast=ast_info, | |
| span=span, | |
| hierarchy=ChunkHierarchy( | |
| parent_id=parent_chunk.chunk_id if parent_chunk else None, | |
| children_ids=[], | |
| depth=depth, | |
| is_primary=True, | |
| is_extracted=False, | |
| lineage=lineage, | |
| sibling_index=sibling_index, | |
| ), | |
| ) | |
| # Add to parent's children if parent exists | |
| if parent_chunk: | |
| parent_chunk.hierarchy.children_ids.append(chunk_id) | |
| self.chunks.append(chunk) | |
| return chunk | |
| def _create_module_chunk(self) -> CodeChunk: | |
| """Create module chunk with all imports""" | |
| module_name = Path(self.file_path).stem | |
| start_line = 1 | |
| end_line = len(self.source.split('\n')) | |
| start_byte, end_byte = self._get_byte_span(start_line, end_line) | |
| # Module code - entire file | |
| module_code = self.source | |
| # Extract ALL imports for module | |
| module_imports: List[str] = [] | |
| for node in ast.walk(self.tree): | |
| if isinstance(node, (ast.Import, ast.ImportFrom)): | |
| try: | |
| module_imports.append(ast.unparse(node)) | |
| except Exception: | |
| pass | |
| chunk_id = deterministic_chunk_id( | |
| file_path=self.file_path, | |
| chunk_type="module", | |
| name=module_name, | |
| parent=None, | |
| start_line=start_line, | |
| end_line=end_line, | |
| code=module_code, | |
| ) | |
| ast_info = ChunkAST( | |
| symbol_type="module", | |
| name=module_name, | |
| parent=None, | |
| docstring=ast.get_docstring(self.tree), | |
| decorators=[], | |
| imports=module_imports, # ALL imports in module | |
| ) | |
| span = ChunkSpan( | |
| start_byte=start_byte, | |
| end_byte=end_byte, | |
| start_line=start_line, | |
| end_line=end_line, | |
| ) | |
| chunk = CodeChunk( | |
| chunk_id=chunk_id, | |
| file_path=self.file_path, | |
| language="python", | |
| chunk_type="module", | |
| code=module_code, | |
| ast=ast_info, | |
| span=span, | |
| hierarchy=ChunkHierarchy( | |
| parent_id=None, | |
| children_ids=[], | |
| depth=0, | |
| is_primary=True, | |
| is_extracted=False, | |
| lineage=[], | |
| sibling_index=0, | |
| ), | |
| ) | |
| self.chunks.append(chunk) | |
| return chunk | |
| # ---------------- visitors ---------------- | |
| def visit_Import(self, node: ast.Import) -> None: | |
| try: | |
| self.imports_list.append(ast.unparse(node)) | |
| except Exception: | |
| pass | |
| self.generic_visit(node) | |
| def visit_ImportFrom(self, node: ast.ImportFrom) -> None: | |
| try: | |
| self.imports_list.append(ast.unparse(node)) | |
| except Exception: | |
| pass | |
| self.generic_visit(node) | |
| def visit_ClassDef(self, node: ast.ClassDef) -> None: | |
| # Create class chunk | |
| class_chunk = self._create_chunk( | |
| node, | |
| "class", | |
| node.name, | |
| parent="module", | |
| parent_chunk=self.parent_stack[-1] if self.parent_stack else None, | |
| ) | |
| # Save current class context | |
| previous_class = self.current_class | |
| self.current_class = node.name | |
| # Push class to stack | |
| self.parent_stack.append(class_chunk) | |
| # Visit class body | |
| self.generic_visit(node) | |
| # Restore previous context | |
| self.current_class = previous_class | |
| self.parent_stack.pop() | |
| def visit_FunctionDef(self, node: ast.FunctionDef) -> None: | |
| parent = getattr(node, "parent", None) | |
| if isinstance(parent, ast.Module): | |
| # Top-level function | |
| self._create_chunk( | |
| node, | |
| "function", | |
| node.name, | |
| parent="module", | |
| parent_chunk=self.parent_stack[-1] if self.parent_stack else None, | |
| ) | |
| elif isinstance(parent, ast.ClassDef): | |
| # Method inside class | |
| self._create_chunk( | |
| node, | |
| "method", | |
| node.name, | |
| parent=parent.name, | |
| parent_chunk=self.parent_stack[-1] if self.parent_stack else None, | |
| ) | |
| self.generic_visit(node) | |
| def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: | |
| parent = getattr(node, "parent", None) | |
| if isinstance(parent, ast.Module): | |
| # Top-level async function | |
| self._create_chunk( | |
| node, | |
| "function", | |
| node.name, | |
| parent="module", | |
| parent_chunk=self.parent_stack[-1] if self.parent_stack else None, | |
| ) | |
| elif isinstance(parent, ast.ClassDef): | |
| # Async method inside class | |
| self._create_chunk( | |
| node, | |
| "method", | |
| node.name, | |
| parent=parent.name, | |
| parent_chunk=self.parent_stack[-1] if self.parent_stack else None, | |
| ) | |
| self.generic_visit(node) | |
| def visit_Module(self, node: ast.Module) -> None: | |
| # Create module chunk first (root) | |
| module_chunk = self._create_module_chunk() | |
| # Push module to stack | |
| self.parent_stack.append(module_chunk) | |
| # Visit children to create classes and functions | |
| self.generic_visit(node) | |
| # Pop module from stack | |
| self.parent_stack.pop() | |
| # ---------------- public API ---------------- | |
| def extract_ast_chunks(file_path: Path) -> List[CodeChunk]: | |
| source = file_path.read_text(encoding="utf-8") | |
| chunker = ASTChunker(source, str(file_path)) | |
| # Visit the tree (creates all chunks with relationships) | |
| chunker.visit(chunker.tree) | |
| return chunker.chunks |