| """ |
| AST-based semantic code chunker - Primary source of truth for code structure. |
| |
| This module implements the core AST-based chunking strategy that forms the |
| authority layer of our hybrid chunking pipeline. It uses Python's built-in |
| AST parser to extract semantic chunks (modules, classes, functions, methods) |
| while preserving hierarchical relationships. |
| |
| ARCHITECTURE POSITION: |
| - Authority Layer: Source of truth for semantic structure |
| - Primary Chunker: Generates all primary chunks |
| - Hierarchy Builder: Establishes parent-child relationships |
| |
| KEY FEATURES: |
| 1. AST-first parsing for semantic accuracy |
| 2. Hierarchical chunk generation with depth tracking |
| 3. Byte-level span calculation for precise positioning |
| 4. Import and decorator extraction per node |
| 5. Deterministic chunk ID generation |
| |
| FLOW: |
| File → Python AST → ASTChunker visitor → Semantic chunks with hierarchy |
| |
| USAGE: |
| from ast_chunker import extract_ast_chunks |
| chunks = extract_ast_chunks(Path("file.py")) |
| """ |
|
|
| import ast |
| from pathlib import Path |
| from typing import List, Optional, Union, Dict, Tuple |
| import hashlib |
|
|
| from ..utils.id_utils import deterministic_chunk_id |
| from .chunk_schema import CodeChunk, ChunkAST, ChunkSpan, ChunkHierarchy, ASTSymbolType, ChunkType |
|
|
# AST node kinds that this module turns into chunks (module, class, sync and
# async function definitions).
DocNode = Union[ast.Module, ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef]
|
|
|
|
class ASTChunker(ast.NodeVisitor):
    """Walk a Python module's AST and emit one ``CodeChunk`` per semantic unit.

    Chunks are produced for the module itself, for every class definition
    reached during traversal, for functions whose direct AST parent is the
    module, and for functions whose direct AST parent is a class (methods).
    Functions nested inside other functions or inside compound statements
    (``if``/``try``/...) get no chunk of their own, but their bodies are still
    traversed, so classes defined inside them are chunked.

    Only module and class chunks are pushed on ``parent_stack``; every chunk is
    therefore attached to the nearest enclosing module/class chunk.

    Attributes:
        source: The module source text that was parsed.
        file_path: Path string recorded on every chunk and used for chunk IDs.
        source_bytes: ``source`` encoded as UTF-8.
        chunks: All chunks created so far, in creation order.
        tree: The parsed ``ast.Module``; every non-root node carries a
            ``parent`` attribute pointing at its direct AST parent.
        current_class: Name of the class currently being visited, if any.
        imports_list: Unparsed text of every import statement visited.
        parent_stack: Enclosing module/class chunks, innermost last.
        sibling_counters: Number of children created so far per parent
            chunk ID; used to assign ``sibling_index``.

    Raises:
        SyntaxError: From ``ast.parse`` when ``source`` is not valid Python.
    """

    def __init__(self, source: str, file_path: str):
        self.source = source
        self.file_path = file_path
        self.source_bytes = source.encode('utf-8')
        self.chunks: List[CodeChunk] = []
        # Pass the file name so a SyntaxError points at the offending file
        # instead of "<unknown>".
        self.tree = ast.parse(source, filename=file_path)

        self.current_class: Optional[str] = None
        self.imports_list: List[str] = []

        self.parent_stack: List[CodeChunk] = []
        self.sibling_counters: Dict[str, int] = {}

        # Lazily built prefix-sum table used by _get_byte_span.
        self._line_offsets: Optional[List[int]] = None

        # Annotate every node with its direct parent so the function visitors
        # can tell module-level functions from methods.
        for node in ast.walk(self.tree):
            for child in ast.iter_child_nodes(node):
                setattr(child, "parent", node)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _safe_unparse(node: ast.AST) -> Optional[str]:
        """Return ``ast.unparse(node)``, or ``None`` if unparsing fails."""
        try:
            return ast.unparse(node)
        except Exception:
            # Best effort: callers decide whether to skip or substitute.
            return None

    def _get_code(self, node: ast.AST) -> str:
        """Return the stripped source text of ``node`` ("" if unavailable)."""
        code = ast.get_source_segment(self.source, node)
        return code.strip() if code else ""

    def _line_start_offsets(self) -> List[int]:
        """Return the cached table of UTF-8 byte offsets per line.

        Entry ``k`` is the byte offset at which 0-based line ``k`` starts
        (lines are split on ``'\\n'``); the final entry is the total byte
        length of ``self.source`` encoded as UTF-8.
        """
        offsets = getattr(self, "_line_offsets", None)
        if offsets is None:
            offsets = [0]
            for line in self.source.split('\n'):
                offsets.append(offsets[-1] + len(line.encode('utf-8')) + 1)
            # The last line is not followed by a '\n' separator, so drop the
            # byte that the loop counted for it.
            offsets[-1] -= 1
            self._line_offsets = offsets
        return offsets

    def _get_byte_span(self, start_line: int, end_line: int) -> Tuple[int, int]:
        """Convert a 1-based inclusive line range to a byte range.

        Offsets are relative to ``self.source`` encoded as UTF-8. The start is
        the first byte of ``start_line``; the end is exclusive and includes
        the newline terminating ``end_line`` when there is one, and never
        exceeds the encoded length of the source.

        Args:
            start_line: 1-based first line of the range.
            end_line: 1-based last line of the range (inclusive).

        Returns:
            ``(start_byte, end_byte)``.
        """
        offsets = self._line_start_offsets()
        last = len(offsets) - 1
        start_idx = min(max(start_line - 1, 0), last)
        end_idx = min(max(end_line, 0), last)
        return offsets[start_idx], offsets[end_idx]

    def _extract_node_imports(self, node: ast.AST) -> List[str]:
        """Return the import statements found anywhere inside ``node``.

        Each statement is rendered with ``ast.unparse``; if that fails the
        node's ``str()`` is used instead so the import is still recorded.
        """
        imports: List[str] = []
        for child in ast.walk(node):
            if isinstance(child, (ast.Import, ast.ImportFrom)):
                text = self._safe_unparse(child)
                imports.append(text if text is not None else str(child))
        return imports

    def _extract_decorators(self, node: ast.AST) -> List[str]:
        """Return the decorators of ``node`` as source text, in order.

        Nodes without a ``decorator_list`` yield an empty list. A decorator
        that cannot be unparsed is recorded as its ``str()``.
        """
        decorators: List[str] = []
        for decorator in getattr(node, "decorator_list", []):
            text = self._safe_unparse(decorator)
            decorators.append(text if text is not None else str(decorator))
        return decorators

    # ------------------------------------------------------------------
    # Chunk construction
    # ------------------------------------------------------------------

    def _create_chunk(
        self,
        node: DocNode,
        chunk_type: ChunkType,
        name: str,
        parent: Optional[str] = None,
        parent_chunk: Optional[CodeChunk] = None,
    ) -> CodeChunk:
        """Build a chunk for ``node``, link it to its parent and record it.

        Args:
            node: The AST node being chunked.
            chunk_type: Symbol kind stored on the chunk (e.g. "class").
            name: Symbol name.
            parent: Name of the logical parent; for a "method" chunk with no
                explicit parent, ``self.current_class`` is used.
            parent_chunk: Chunk that becomes this chunk's hierarchy parent.
                When given, the new chunk's ID is appended to its
                ``children_ids`` and its sibling counter is advanced.

        Returns:
            The new chunk, which has also been appended to ``self.chunks``.
        """
        code = self._get_code(node)

        start_line = getattr(node, "lineno", None)
        end_line = getattr(node, "end_lineno", None)

        start_byte, end_byte = None, None
        if start_line and end_line:
            start_byte, end_byte = self._get_byte_span(start_line, end_line)

        if parent is None and chunk_type == "method":
            parent = self.current_class

        decorators: List[str] = []
        if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
            decorators = self._extract_decorators(node)

        node_imports = self._extract_node_imports(node)

        docstring: Optional[str] = None
        if hasattr(node, 'body'):
            docstring = ast.get_docstring(node)

        depth = 0
        lineage: List[str] = []
        sibling_index = 0

        if parent_chunk:
            depth = parent_chunk.hierarchy.depth + 1
            lineage = parent_chunk.hierarchy.lineage.copy()
            lineage.append(parent_chunk.chunk_id)

            # The index is the number of siblings created before this chunk.
            parent_key = parent_chunk.chunk_id
            sibling_index = self.sibling_counters.get(parent_key, 0)
            self.sibling_counters[parent_key] = sibling_index + 1

        ast_info = ChunkAST(
            symbol_type=chunk_type,
            name=name,
            parent=parent,
            docstring=docstring,
            decorators=decorators,
            imports=node_imports,
        )

        span = ChunkSpan(
            start_byte=start_byte,
            end_byte=end_byte,
            start_line=start_line,
            end_line=end_line,
        )

        chunk_id = deterministic_chunk_id(
            file_path=self.file_path,
            chunk_type=chunk_type,
            name=name,
            parent=parent,
            start_line=start_line,
            end_line=end_line,
            code=code,
        )

        chunk = CodeChunk(
            chunk_id=chunk_id,
            file_path=self.file_path,
            language="python",
            chunk_type=chunk_type,
            code=code,
            ast=ast_info,
            span=span,
            hierarchy=ChunkHierarchy(
                parent_id=parent_chunk.chunk_id if parent_chunk else None,
                children_ids=[],
                depth=depth,
                is_primary=True,
                is_extracted=False,
                lineage=lineage,
                sibling_index=sibling_index,
            ),
        )

        if parent_chunk:
            parent_chunk.hierarchy.children_ids.append(chunk_id)

        self.chunks.append(chunk)
        return chunk

    def _create_module_chunk(self) -> CodeChunk:
        """Build, record and return the root chunk covering the whole module.

        The chunk is named after the file stem, holds the full source as its
        code and lists every import statement found anywhere in the tree
        (statements that cannot be unparsed are skipped).
        """
        module_name = Path(self.file_path).stem
        start_line = 1
        # Same value as len(source.split('\n')). When the source ends with a
        # newline this counts the empty string after it as a line; kept as-is
        # because the value is passed to deterministic_chunk_id below.
        end_line = self.source.count('\n') + 1
        start_byte, end_byte = self._get_byte_span(start_line, end_line)

        module_code = self.source

        module_imports: List[str] = []
        for node in ast.walk(self.tree):
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                text = self._safe_unparse(node)
                if text is not None:
                    module_imports.append(text)

        chunk_id = deterministic_chunk_id(
            file_path=self.file_path,
            chunk_type="module",
            name=module_name,
            parent=None,
            start_line=start_line,
            end_line=end_line,
            code=module_code,
        )

        ast_info = ChunkAST(
            symbol_type="module",
            name=module_name,
            parent=None,
            docstring=ast.get_docstring(self.tree),
            decorators=[],
            imports=module_imports,
        )

        span = ChunkSpan(
            start_byte=start_byte,
            end_byte=end_byte,
            start_line=start_line,
            end_line=end_line,
        )

        chunk = CodeChunk(
            chunk_id=chunk_id,
            file_path=self.file_path,
            language="python",
            chunk_type="module",
            code=module_code,
            ast=ast_info,
            span=span,
            hierarchy=ChunkHierarchy(
                parent_id=None,
                children_ids=[],
                depth=0,
                is_primary=True,
                is_extracted=False,
                lineage=[],
                sibling_index=0,
            ),
        )

        self.chunks.append(chunk)
        return chunk

    # ------------------------------------------------------------------
    # Visitors
    # ------------------------------------------------------------------

    def _record_import(self, node: ast.AST) -> None:
        """Append the import's source text to ``imports_list`` and descend."""
        text = self._safe_unparse(node)
        if text is not None:
            self.imports_list.append(text)
        self.generic_visit(node)

    def visit_Import(self, node: ast.Import) -> None:
        """Record an ``import x`` statement."""
        self._record_import(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Record a ``from x import y`` statement."""
        self._record_import(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Create a class chunk and visit the class body beneath it."""
        class_chunk = self._create_chunk(
            node,
            "class",
            node.name,
            parent="module",
            parent_chunk=self.parent_stack[-1] if self.parent_stack else None,
        )

        # Make this class the current context while its body is visited,
        # then restore whatever was active before (supports nested classes).
        previous_class = self.current_class
        self.current_class = node.name
        self.parent_stack.append(class_chunk)

        self.generic_visit(node)

        self.current_class = previous_class
        self.parent_stack.pop()

    def _visit_function(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Shared handler for sync and async function definitions.

        A chunk is created only when the function's direct AST parent is the
        module ("function") or a class ("method"); other functions are
        traversed without producing a chunk.
        """
        container = getattr(node, "parent", None)
        enclosing = self.parent_stack[-1] if self.parent_stack else None

        if isinstance(container, ast.Module):
            self._create_chunk(
                node,
                "function",
                node.name,
                parent="module",
                parent_chunk=enclosing,
            )
        elif isinstance(container, ast.ClassDef):
            self._create_chunk(
                node,
                "method",
                node.name,
                parent=container.name,
                parent_chunk=enclosing,
            )

        self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Chunk a ``def`` if it is a module-level function or a method."""
        self._visit_function(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Chunk an ``async def`` if it is a module-level function or a method."""
        self._visit_function(node)

    def visit_Module(self, node: ast.Module) -> None:
        """Create the root module chunk and visit everything beneath it."""
        module_chunk = self._create_module_chunk()

        self.parent_stack.append(module_chunk)
        self.generic_visit(node)
        self.parent_stack.pop()
|
|
|
|
| |
|
|
def extract_ast_chunks(file_path: Union[str, Path]) -> List[CodeChunk]:
    """Parse a Python file and return its AST-derived chunks.

    The file is read as UTF-8 text, parsed by ``ASTChunker`` and traversed
    from the module root.

    Args:
        file_path: Location of the Python source file; a ``Path`` or a plain
            string path.

    Returns:
        The chunker's ``chunks`` list, in creation order.

    Raises:
        OSError: If the file cannot be read.
        UnicodeDecodeError: If the file is not valid UTF-8.
        SyntaxError: If the file is not valid Python.
    """
    path = Path(file_path)
    # NOTE(review): read_text uses text mode, which translates line endings to
    # '\n'; byte spans are therefore relative to the translated text, not
    # necessarily to the bytes on disk. Confirm this is acceptable downstream.
    source = path.read_text(encoding="utf-8")
    chunker = ASTChunker(source, str(path))

    chunker.visit(chunker.tree)

    return chunker.chunks