File size: 12,582 Bytes
463fc7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
"""
AST-based semantic code chunker - Primary source of truth for code structure.

This module implements the core AST-based chunking strategy that forms the 
authority layer of our hybrid chunking pipeline. It uses Python's built-in 
AST parser to extract semantic chunks (modules, classes, functions, methods) 
while preserving hierarchical relationships.

ARCHITECTURE POSITION:
    - Authority Layer: Source of truth for semantic structure
    - Primary Chunker: Generates all primary chunks
    - Hierarchy Builder: Establishes parent-child relationships

KEY FEATURES:
    1. AST-first parsing for semantic accuracy
    2. Hierarchical chunk generation with depth tracking
    3. Byte-offset span calculation (whole-line granularity) for each chunk
    4. Import and decorator extraction per node
    5. Deterministic chunk ID generation

FLOW:
    File → Python AST → ASTChunker visitor → Semantic chunks with hierarchy

USAGE:
    from ast_chunker import extract_ast_chunks
    chunks = extract_ast_chunks(Path("file.py"))
"""

import ast
from pathlib import Path
from typing import List, Optional, Union, Dict, Tuple
import hashlib

from ..utils.id_utils import deterministic_chunk_id
from .chunk_schema import CodeChunk, ChunkAST, ChunkSpan, ChunkHierarchy, ASTSymbolType, ChunkType

# AST node kinds that may carry a docstring (exactly the kinds ast.get_docstring
# accepts); used to annotate the nodes handed to chunk creation.
DocNode = Union[ast.Module, ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef]


class ASTChunker(ast.NodeVisitor):
    """Visitor that turns one Python source string into hierarchical ``CodeChunk``s.

    Chunks are created for:
      * the module itself (always first, created in ``visit_Module``),
      * every ``class`` definition in the file, at any nesting depth,
      * functions defined directly in the module body (``"function"``),
      * functions defined directly in a class body (``"method"``).

    Functions nested inside other functions or inside compound statements
    (``if``/``try``/``with``/...) are not chunked on their own; their text
    stays part of the enclosing chunk's ``code``.

    Only the module chunk and class chunks are pushed on ``parent_stack``, so
    a chunk's hierarchy parent is the innermost enclosing class chunk, or the
    module chunk when there is none.

    Typical use: construct, call ``visit(self.tree)``, then read ``chunks``.
    """

    def __init__(self, source: str, file_path: str):
        """Parse ``source`` and prepare the per-file lookup state.

        Args:
            source: Full Python source text of the file.
            file_path: Path recorded on every chunk and fed into chunk IDs;
                also passed to ``ast.parse`` so syntax errors name the file.

        Raises:
            SyntaxError: If ``source`` is not valid Python.
        """
        self.source = source
        self.file_path = file_path
        self.source_bytes = source.encode('utf-8')
        self.chunks: List[CodeChunk] = []
        self.tree = ast.parse(source, filename=file_path)

        # Name of the class whose body is currently being visited (None at
        # module level); fallback ``parent`` for method chunks.
        self.current_class: Optional[str] = None
        # Every import statement seen by visit_Import / visit_ImportFrom.
        self.imports_list: List[str] = []

        # Innermost-last stack of the module/class chunks enclosing the node
        # currently being visited.
        self.parent_stack: List[CodeChunk] = []
        # Number of children created so far under each parent chunk ID.
        self.sibling_counters: Dict[str, int] = {}

        # _line_offsets[k] is the byte offset where 1-based line k + 1 starts,
        # i.e. the end (including its '\n') of line k. Built once here rather
        # than re-splitting and re-encoding the source for every chunk.
        self._line_offsets: List[int] = [0]
        for line in source.split('\n'):
            self._line_offsets.append(self._line_offsets[-1] + len(line.encode('utf-8')) + 1)
        # The final line has no trailing '\n'; clamp so that no span can point
        # past the end of source_bytes.
        self._line_offsets[-1] = len(self.source_bytes)

        # Give every node a ``parent`` attribute so the function visitor can
        # tell a module-level function from a method.
        for node in ast.walk(self.tree):
            for child in ast.iter_child_nodes(node):
                setattr(child, "parent", node)

    # ---------------- utilities ----------------

    def _get_code(self, node: ast.AST) -> str:
        """Return the stripped source text of ``node``, or ``""`` if unavailable."""
        code = ast.get_source_segment(self.source, node)
        return code.strip() if code else ""

    def _get_byte_span(self, start_line: int, end_line: int) -> Tuple[int, int]:
        """Convert an inclusive 1-based line range to a UTF-8 byte range.

        The range starts at the first byte of ``start_line`` and ends just
        past ``end_line``'s newline, or at the end of the source when
        ``end_line`` is the last line. Line numbers outside the source are
        clamped to it.
        """
        last_index = len(self._line_offsets) - 1
        start_index = min(max(start_line - 1, 0), last_index)
        end_index = min(max(end_line, 0), last_index)
        return self._line_offsets[start_index], self._line_offsets[end_index]

    def _unparse(self, node: ast.AST) -> Optional[str]:
        """Render ``node`` as source text, or return ``None`` if that fails.

        ``ast.unparse`` is tried first and the node's original source segment
        is the fallback. Neither embeds a memory address the way ``str(node)``
        does, so the output is deterministic across runs.
        """
        try:
            return ast.unparse(node)
        except Exception:  # best effort: fall back to the original text
            segment = ast.get_source_segment(self.source, node)
            return segment.strip() if segment else None

    def _extract_node_imports(self, node: ast.AST) -> List[str]:
        """Return every import statement found anywhere inside ``node``.

        This includes imports in nested functions and classes, in
        ``ast.walk`` order. Statements that cannot be rendered are skipped.
        """
        imports: List[str] = []
        for child in ast.walk(node):
            if isinstance(child, (ast.Import, ast.ImportFrom)):
                text = self._unparse(child)
                if text is not None:
                    imports.append(text)
        return imports

    def _extract_decorators(self, node: ast.AST) -> List[str]:
        """Return the rendered decorator expressions of ``node``, in source order."""
        decorators: List[str] = []
        for decorator in getattr(node, "decorator_list", []):
            text = self._unparse(decorator)
            if text is not None:
                decorators.append(text)
        return decorators

    def _current_parent_chunk(self) -> Optional[CodeChunk]:
        """Return the innermost enclosing module/class chunk, if any."""
        return self.parent_stack[-1] if self.parent_stack else None

    # ---------------- chunk creation ----------------

    def _create_chunk(
        self,
        node: DocNode,
        chunk_type: ChunkType,
        name: str,
        parent: Optional[str] = None,
        parent_chunk: Optional[CodeChunk] = None,
    ) -> CodeChunk:
        """Build a chunk for ``node``, record it, and link it under ``parent_chunk``.

        Args:
            node: Class or function node (the module uses ``_create_module_chunk``).
            chunk_type: Chunk/symbol type label such as ``"class"``,
                ``"function"`` or ``"method"``.
            name: Symbol name stored on the chunk and fed into the chunk ID.
            parent: Logical parent name (``"module"`` or a class name). For
                ``"method"`` chunks it defaults to ``self.current_class``.
            parent_chunk: Enclosing chunk for hierarchy links; ``None`` yields
                a depth-0 chunk with empty lineage.

        Returns:
            The new chunk. It is also appended to ``self.chunks`` and its ID to
            ``parent_chunk.hierarchy.children_ids``.
        """
        code = self._get_code(node)

        # 1-based inclusive line range of the node.
        start_line = getattr(node, "lineno", None)
        end_line = getattr(node, "end_lineno", None)

        start_byte: Optional[int] = None
        end_byte: Optional[int] = None
        if start_line and end_line:
            start_byte, end_byte = self._get_byte_span(start_line, end_line)

        if parent is None and chunk_type == "method":
            parent = self.current_class

        decorators: List[str] = []
        if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
            decorators = self._extract_decorators(node)

        # Imports written inside this node, not the module-wide list.
        node_imports = self._extract_node_imports(node)

        # ast.get_docstring raises TypeError for any other node type, so check
        # the type explicitly rather than just the presence of ``body``.
        docstring: Optional[str] = None
        if isinstance(node, (ast.Module, ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
            docstring = ast.get_docstring(node)

        depth = 0
        lineage: List[str] = []
        sibling_index = 0
        if parent_chunk is not None:
            depth = parent_chunk.hierarchy.depth + 1
            # Ancestor IDs from the root down to (and including) the direct parent.
            lineage = parent_chunk.hierarchy.lineage.copy()
            lineage.append(parent_chunk.chunk_id)
            # 0-based position among the parent's children, in creation order.
            parent_key = parent_chunk.chunk_id
            sibling_index = self.sibling_counters.get(parent_key, 0)
            self.sibling_counters[parent_key] = sibling_index + 1

        ast_info = ChunkAST(
            symbol_type=chunk_type,
            name=name,
            parent=parent,
            docstring=docstring,
            decorators=decorators,
            imports=node_imports,
        )

        span = ChunkSpan(
            start_byte=start_byte,
            end_byte=end_byte,
            start_line=start_line,
            end_line=end_line,
        )

        chunk_id = deterministic_chunk_id(
            file_path=self.file_path,
            chunk_type=chunk_type,
            name=name,
            parent=parent,
            start_line=start_line,
            end_line=end_line,
            code=code,
        )

        chunk = CodeChunk(
            chunk_id=chunk_id,
            file_path=self.file_path,
            language="python",
            chunk_type=chunk_type,
            code=code,
            ast=ast_info,
            span=span,
            hierarchy=ChunkHierarchy(
                parent_id=parent_chunk.chunk_id if parent_chunk is not None else None,
                children_ids=[],
                depth=depth,
                is_primary=True,
                is_extracted=False,
                lineage=lineage,
                sibling_index=sibling_index,
            ),
        )

        if parent_chunk is not None:
            parent_chunk.hierarchy.children_ids.append(chunk_id)

        self.chunks.append(chunk)
        return chunk

    def _create_module_chunk(self) -> CodeChunk:
        """Create and record the root chunk spanning the whole file.

        Unlike other chunks, its ``code`` is the unstripped source, and its
        ``imports`` list every import statement anywhere in the file.
        """
        module_name = Path(self.file_path).stem
        start_line = 1
        # Number of '\n'-separated pieces; a trailing newline counts as an
        # empty final line.
        end_line = len(self.source.split('\n'))
        start_byte, end_byte = self._get_byte_span(start_line, end_line)

        module_code = self.source
        module_imports = self._extract_node_imports(self.tree)

        chunk_id = deterministic_chunk_id(
            file_path=self.file_path,
            chunk_type="module",
            name=module_name,
            parent=None,
            start_line=start_line,
            end_line=end_line,
            code=module_code,
        )

        ast_info = ChunkAST(
            symbol_type="module",
            name=module_name,
            parent=None,
            docstring=ast.get_docstring(self.tree),
            decorators=[],
            imports=module_imports,
        )

        span = ChunkSpan(
            start_byte=start_byte,
            end_byte=end_byte,
            start_line=start_line,
            end_line=end_line,
        )

        chunk = CodeChunk(
            chunk_id=chunk_id,
            file_path=self.file_path,
            language="python",
            chunk_type="module",
            code=module_code,
            ast=ast_info,
            span=span,
            hierarchy=ChunkHierarchy(
                parent_id=None,
                children_ids=[],
                depth=0,
                is_primary=True,
                is_extracted=False,
                lineage=[],
                sibling_index=0,
            ),
        )

        self.chunks.append(chunk)
        return chunk

    # ---------------- visitors ----------------

    def _record_import(self, node: Union[ast.Import, ast.ImportFrom]) -> None:
        """Append the rendered import to ``imports_list`` and keep visiting."""
        text = self._unparse(node)
        if text is not None:
            self.imports_list.append(text)
        self.generic_visit(node)

    def visit_Import(self, node: ast.Import) -> None:
        self._record_import(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        self._record_import(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Chunk a class, then visit its body with the class as the parent."""
        # A nested class's logical parent is its enclosing class, matching the
        # hierarchy parent taken from parent_stack; top-level classes get "module".
        class_chunk = self._create_chunk(
            node,
            "class",
            node.name,
            parent=self.current_class or "module",
            parent_chunk=self._current_parent_chunk(),
        )

        previous_class = self.current_class
        self.current_class = node.name
        self.parent_stack.append(class_chunk)

        self.generic_visit(node)

        self.current_class = previous_class
        self.parent_stack.pop()

    def _visit_function(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> None:
        """Chunk module-level functions and methods (sync or async), then recurse.

        Functions whose direct AST parent is neither the module nor a class
        body get no chunk of their own, but their bodies are still visited so
        that nested classes are found.
        """
        enclosing = getattr(node, "parent", None)
        if isinstance(enclosing, ast.Module):
            self._create_chunk(
                node,
                "function",
                node.name,
                parent="module",
                parent_chunk=self._current_parent_chunk(),
            )
        elif isinstance(enclosing, ast.ClassDef):
            self._create_chunk(
                node,
                "method",
                node.name,
                parent=enclosing.name,
                parent_chunk=self._current_parent_chunk(),
            )
        self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self._visit_function(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self._visit_function(node)

    def visit_Module(self, node: ast.Module) -> None:
        """Create the root module chunk, then visit the module body beneath it."""
        module_chunk = self._create_module_chunk()
        self.parent_stack.append(module_chunk)
        self.generic_visit(node)
        self.parent_stack.pop()


# ---------------- public API ----------------

def extract_ast_chunks(file_path: Path) -> List[CodeChunk]:
    """Parse a Python file and return its AST-derived chunks.

    The module chunk comes first, followed by class, function and method
    chunks in the order ``ASTChunker`` visits them.

    Args:
        file_path: Path to a Python source file. It is decoded as UTF-8 and a
            leading byte-order mark is stripped (``utf-8-sig``), because
            ``ast.parse`` rejects a U+FEFF character at the start of ``str``
            source. Spans are relative to this decoded text (BOM removed,
            newlines normalised by text-mode reading), not the raw file bytes.

    Returns:
        All chunks created by ``ASTChunker``, with hierarchy links populated.

    Raises:
        OSError: If the file cannot be read.
        UnicodeDecodeError: If the file is not valid UTF-8.
        SyntaxError: If the file is not valid Python.
    """
    source = file_path.read_text(encoding="utf-8-sig")
    chunker = ASTChunker(source, str(file_path))

    # visit_Module creates the root chunk, then every nested chunk with its links.
    chunker.visit(chunker.tree)

    return chunker.chunks