CodeMode / scripts /core /ingestion /ast_chunker.py
CodeMode Agent
Deploy CodeMode via Agent
463fc7e
"""
AST-based semantic code chunker - Primary source of truth for code structure.
This module implements the core AST-based chunking strategy that forms the
authority layer of our hybrid chunking pipeline. It uses Python's built-in
AST parser to extract semantic chunks (modules, classes, functions, methods)
while preserving hierarchical relationships.
ARCHITECTURE POSITION:
- Authority Layer: Source of truth for semantic structure
- Primary Chunker: Generates all primary chunks
- Hierarchy Builder: Establishes parent-child relationships
KEY FEATURES:
1. AST-first parsing for semantic accuracy
2. Hierarchical chunk generation with depth tracking
3. Byte-level span calculation for precise positioning
4. Import and decorator extraction per node
5. Deterministic chunk ID generation
FLOW:
File → Python AST → ASTChunker visitor → Semantic chunks with hierarchy
USAGE:
from ast_chunker import extract_ast_chunks
chunks = extract_ast_chunks(Path("file.py"))
"""
import ast
from pathlib import Path
from typing import List, Optional, Union, Dict, Tuple
import hashlib
from ..utils.id_utils import deterministic_chunk_id
from .chunk_schema import CodeChunk, ChunkAST, ChunkSpan, ChunkHierarchy, ASTSymbolType, ChunkType
DocNode = Union[
ast.Module,
ast.ClassDef,
ast.FunctionDef,
ast.AsyncFunctionDef,
]
class ASTChunker(ast.NodeVisitor):
def __init__(self, source: str, file_path: str):
self.source = source
self.file_path = file_path
self.source_bytes = source.encode('utf-8')
self.chunks: List[CodeChunk] = []
self.tree = ast.parse(source)
# Track hierarchy
self.current_class: Optional[str] = None
self.imports_list: List[str] = []
# For hierarchy tracking
self.parent_stack: List[CodeChunk] = []
self.sibling_counters: Dict[str, int] = {}
# Attach parents to nodes
for node in ast.walk(self.tree):
for child in ast.iter_child_nodes(node):
setattr(child, "parent", node)
# ---------------- utilities ----------------
def _get_code(self, node: ast.AST) -> str:
code = ast.get_source_segment(self.source, node)
return code.strip() if code else ""
def _get_byte_span(self, start_line: int, end_line: int) -> Tuple[int, int]:
"""Convert line numbers to byte positions"""
lines = self.source.split('\n')
# Calculate start byte
start_byte = sum(len(line.encode()) + 1 for line in lines[:start_line-1])
# Calculate end byte (up to end_line)
end_byte = sum(len(line.encode()) + 1 for line in lines[:end_line])
return start_byte, end_byte
def _extract_node_imports(self, node: ast.AST) -> List[str]:
"""Extract imports specific to this node (not all module imports)"""
imports: List[str] = []
# Walk through this node's body
for child in ast.walk(node):
if isinstance(child, (ast.Import, ast.ImportFrom)):
try:
imports.append(ast.unparse(child))
except Exception:
imports.append(str(child))
return imports
def _extract_decorators(self, node: ast.AST) -> List[str]:
decorators: List[str] = []
if hasattr(node, "decorator_list"):
for d in node.decorator_list: # type: ignore[attr-defined]
try:
decorators.append(ast.unparse(d))
except Exception:
decorators.append(str(d))
return decorators
# ---------------- chunk creation ----------------
def _create_chunk(
self,
node: DocNode,
chunk_type: ChunkType,
name: str,
parent: Optional[str] = None,
parent_chunk: Optional[CodeChunk] = None,
) -> CodeChunk:
code = self._get_code(node)
# Get line numbers
start_line = getattr(node, "lineno", None)
end_line = getattr(node, "end_lineno", None)
# Calculate byte span
start_byte, end_byte = None, None
if start_line and end_line:
start_byte, end_byte = self._get_byte_span(start_line, end_line)
# Determine parent if not provided
if parent is None and chunk_type == "method":
parent = self.current_class
decorators: List[str] = []
if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
decorators = self._extract_decorators(node)
# Get imports specific to this node (not all module imports)
node_imports = self._extract_node_imports(node)
# Get docstring only for nodes that can have one
docstring: Optional[str] = None
if hasattr(node, 'body'):
docstring = ast.get_docstring(node)
# Determine hierarchy depth
depth = 0
lineage: List[str] = []
sibling_index = 0
if parent_chunk:
depth = parent_chunk.hierarchy.depth + 1
lineage = parent_chunk.hierarchy.lineage.copy()
lineage.append(parent_chunk.chunk_id)
# Update sibling counter
parent_key = parent_chunk.chunk_id
self.sibling_counters[parent_key] = self.sibling_counters.get(parent_key, 0) + 1
sibling_index = self.sibling_counters[parent_key] - 1
ast_info = ChunkAST(
symbol_type=chunk_type,
name=name,
parent=parent,
docstring=docstring,
decorators=decorators,
imports=node_imports,
)
span = ChunkSpan(
start_byte=start_byte,
end_byte=end_byte,
start_line=start_line,
end_line=end_line,
)
# Generate chunk ID
chunk_id = deterministic_chunk_id(
file_path=self.file_path,
chunk_type=chunk_type,
name=name,
parent=parent,
start_line=start_line,
end_line=end_line,
code=code,
)
chunk = CodeChunk(
chunk_id=chunk_id,
file_path=self.file_path,
language="python",
chunk_type=chunk_type,
code=code,
ast=ast_info,
span=span,
hierarchy=ChunkHierarchy(
parent_id=parent_chunk.chunk_id if parent_chunk else None,
children_ids=[],
depth=depth,
is_primary=True,
is_extracted=False,
lineage=lineage,
sibling_index=sibling_index,
),
)
# Add to parent's children if parent exists
if parent_chunk:
parent_chunk.hierarchy.children_ids.append(chunk_id)
self.chunks.append(chunk)
return chunk
def _create_module_chunk(self) -> CodeChunk:
"""Create module chunk with all imports"""
module_name = Path(self.file_path).stem
start_line = 1
end_line = len(self.source.split('\n'))
start_byte, end_byte = self._get_byte_span(start_line, end_line)
# Module code - entire file
module_code = self.source
# Extract ALL imports for module
module_imports: List[str] = []
for node in ast.walk(self.tree):
if isinstance(node, (ast.Import, ast.ImportFrom)):
try:
module_imports.append(ast.unparse(node))
except Exception:
pass
chunk_id = deterministic_chunk_id(
file_path=self.file_path,
chunk_type="module",
name=module_name,
parent=None,
start_line=start_line,
end_line=end_line,
code=module_code,
)
ast_info = ChunkAST(
symbol_type="module",
name=module_name,
parent=None,
docstring=ast.get_docstring(self.tree),
decorators=[],
imports=module_imports, # ALL imports in module
)
span = ChunkSpan(
start_byte=start_byte,
end_byte=end_byte,
start_line=start_line,
end_line=end_line,
)
chunk = CodeChunk(
chunk_id=chunk_id,
file_path=self.file_path,
language="python",
chunk_type="module",
code=module_code,
ast=ast_info,
span=span,
hierarchy=ChunkHierarchy(
parent_id=None,
children_ids=[],
depth=0,
is_primary=True,
is_extracted=False,
lineage=[],
sibling_index=0,
),
)
self.chunks.append(chunk)
return chunk
# ---------------- visitors ----------------
def visit_Import(self, node: ast.Import) -> None:
try:
self.imports_list.append(ast.unparse(node))
except Exception:
pass
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
try:
self.imports_list.append(ast.unparse(node))
except Exception:
pass
self.generic_visit(node)
def visit_ClassDef(self, node: ast.ClassDef) -> None:
# Create class chunk
class_chunk = self._create_chunk(
node,
"class",
node.name,
parent="module",
parent_chunk=self.parent_stack[-1] if self.parent_stack else None,
)
# Save current class context
previous_class = self.current_class
self.current_class = node.name
# Push class to stack
self.parent_stack.append(class_chunk)
# Visit class body
self.generic_visit(node)
# Restore previous context
self.current_class = previous_class
self.parent_stack.pop()
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
parent = getattr(node, "parent", None)
if isinstance(parent, ast.Module):
# Top-level function
self._create_chunk(
node,
"function",
node.name,
parent="module",
parent_chunk=self.parent_stack[-1] if self.parent_stack else None,
)
elif isinstance(parent, ast.ClassDef):
# Method inside class
self._create_chunk(
node,
"method",
node.name,
parent=parent.name,
parent_chunk=self.parent_stack[-1] if self.parent_stack else None,
)
self.generic_visit(node)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
parent = getattr(node, "parent", None)
if isinstance(parent, ast.Module):
# Top-level async function
self._create_chunk(
node,
"function",
node.name,
parent="module",
parent_chunk=self.parent_stack[-1] if self.parent_stack else None,
)
elif isinstance(parent, ast.ClassDef):
# Async method inside class
self._create_chunk(
node,
"method",
node.name,
parent=parent.name,
parent_chunk=self.parent_stack[-1] if self.parent_stack else None,
)
self.generic_visit(node)
def visit_Module(self, node: ast.Module) -> None:
# Create module chunk first (root)
module_chunk = self._create_module_chunk()
# Push module to stack
self.parent_stack.append(module_chunk)
# Visit children to create classes and functions
self.generic_visit(node)
# Pop module from stack
self.parent_stack.pop()
# ---------------- public API ----------------
def extract_ast_chunks(file_path: Path) -> List[CodeChunk]:
source = file_path.read_text(encoding="utf-8")
chunker = ASTChunker(source, str(file_path))
# Visit the tree (creates all chunks with relationships)
chunker.visit(chunker.tree)
return chunker.chunks