"""AST-based code dependency graph for structured code analysis.""" import ast import re from dataclasses import dataclass, field from pathlib import Path @dataclass class CodeNode: """A node in the code graph — file, function, class, or import.""" id: str kind: str name: str file: str line: int metadata: dict = field(default_factory=dict) @dataclass class CodeEdge: """A directed edge — call, import, containment, or inheritance.""" source: str target: str kind: str _SOURCE_EXTENSIONS = {".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".go", ".rb"} class CodeGraph: """Lightweight code dependency graph built from AST and regex analysis.""" def __init__(self) -> None: self.nodes: dict[str, CodeNode] = {} self.edges: list[CodeEdge] = [] self._file_contents: dict[str, list[str]] = {} def build_from_directory(self, target_dir: str) -> None: """Scan all source files in *target_dir* and populate the graph.""" root = Path(target_dir) for p in sorted(root.rglob("*")): if p.suffix in _SOURCE_EXTENSIONS and p.is_file(): self.build_from_file(str(p), str(root)) def build_from_file(self, filepath: str, root: str = "") -> None: """Parse a single file and add its nodes/edges.""" path = Path(filepath) rel = path.relative_to(root) if root else path.name ext = path.suffix try: source = path.read_text(errors="replace") except OSError: return self._file_contents[str(rel)] = source.splitlines() file_id = str(rel) self.nodes[file_id] = CodeNode( id=file_id, kind="file", name=path.name, file=str(rel), line=0, ) if ext == ".py": self._parse_python(source, str(rel), file_id) elif ext in {".js", ".ts", ".jsx", ".tsx"}: self._parse_javascript(source, str(rel), file_id) def get_file_summary(self, filepath: str) -> str: """Return a compact summary of what a file contains.""" lines: list[str] = [] for node in sorted( (n for n in self.nodes.values() if n.file == filepath and n.kind != "file"), key=lambda n: n.line, ): lines.append(f"L{node.line}: [{node.kind}] {node.name}") if not lines: return f"No structured elements found in {filepath}" return f"=== {filepath} ===\n" + "\n".join(lines) def get_function_source(self, filepath: str, function_name: str) -> str: """Return the source lines of a function.""" raw_lines = self._file_contents.get(filepath, []) if not raw_lines: return f"File {filepath} not cached." for node in self.nodes.values(): if node.file == filepath and node.name == function_name and node.kind == "function": start = node.line - 1 collected: list[str] = [] for i in range(start, len(raw_lines)): if i > start and raw_lines[i].strip() and not raw_lines[i].startswith((" ", "\t")): break collected.append(f"{i+1}: {raw_lines[i]}") return "\n".join(collected) if collected else "Function body not available." return f"Function '{function_name}' not found in {filepath}." def trace_calls(self, function_name: str, depth: int = 3) -> str: """Trace the call chain from *function_name* up to *depth* hops.""" chains: list[list[str]] = [] self._trace_recursive(function_name, depth, [function_name], chains) if not chains: return f"No call chain found for '{function_name}'." lines: list[str] = [] for chain in chains: lines.append(" → ".join(chain)) return "\n".join(lines) def get_callers(self, function_name: str) -> list[str]: """Return all functions that call *function_name*.""" callers: list[str] = [] for edge in self.edges: if edge.kind == "calls" and edge.target.endswith(f":{function_name}"): callers.append(edge.source) return callers def get_imports(self, filepath: str) -> list[str]: """Return all imports in a file.""" imports: list[str] = [] for edge in self.edges: if edge.kind == "imports" and edge.source == filepath: imports.append(edge.target) return imports def to_text(self) -> str: """Compact text representation of the graph for LLM context.""" lines: list[str] = ["=== CODE DEPENDENCY GRAPH ==="] lines.append(f"Nodes: {len(self.nodes)} | Edges: {len(self.edges)}") by_file: dict[str, list[CodeNode]] = {} for node in self.nodes.values(): by_file.setdefault(node.file, []).append(node) for filepath in sorted(by_file): file_nodes = sorted(by_file[filepath], key=lambda n: n.line) lines.append(f"\n--- {filepath} ---") for node in file_nodes: if node.kind == "file": continue decorators = node.metadata.get("decorators", []) dec_str = f" @{','.join(decorators)}" if decorators else "" lines.append(f" L{node.line} [{node.kind}] {node.name}{dec_str}") file_edges = [ e for e in self.edges if e.source == filepath or e.source.startswith(f"{filepath}:") ] if file_edges: for edge in file_edges[:20]: short = edge.target.split(":")[-1] if ":" in edge.target else edge.target lines.append(f" --{edge.kind}--> {short}") return "\n".join(lines) def get_statistics(self) -> dict: """Return counts by node kind and edge kind.""" node_kinds: dict[str, int] = {} for node in self.nodes.values(): node_kinds[node.kind] = node_kinds.get(node.kind, 0) + 1 edge_kinds: dict[str, int] = {} for edge in self.edges: edge_kinds[edge.kind] = edge_kinds.get(edge.kind, 0) + 1 return {"nodes": node_kinds, "edges": edge_kinds} def _trace_recursive( self, name: str, depth: int, path: list[str], results: list[list[str]] ) -> None: if depth <= 0: return callers = self.get_callers(name) if not callers: results.append(path[:]) return for caller_id in callers: caller_name = caller_id.split(":")[-1] if ":" in caller_id else caller_id if caller_name in path: continue path.append(caller_name) self._trace_recursive(caller_name, depth - 1, path, results) path.pop() def _add_node(self, node: CodeNode) -> None: if node.id not in self.nodes: self.nodes[node.id] = node def _add_edge(self, edge: CodeEdge) -> None: self.edges.append(edge) def _parse_python(self, source: str, rel_path: str, file_id: str) -> None: """Parse Python source using the ast module.""" try: tree = ast.parse(source, filename=rel_path) except SyntaxError: return for node in ast.iter_child_nodes(tree): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): func_id = f"{file_id}:{node.name}" decorators = [ d.attr if isinstance(d, ast.Attribute) else (d.id if isinstance(d, ast.Name) else str(d)) for d in node.decorator_list ] self._add_node(CodeNode( id=func_id, kind="function", name=node.name, file=rel_path, line=node.lineno, metadata={"decorators": decorators}, )) self._add_edge(CodeEdge(source=file_id, target=func_id, kind="contains")) self._collect_calls(node, func_id, file_id) elif isinstance(node, ast.ClassDef): cls_id = f"{file_id}:{node.name}" decorators = [ d.attr if isinstance(d, ast.Attribute) else (d.id if isinstance(d, ast.Name) else str(d)) for d in node.decorator_list ] self._add_node(CodeNode( id=cls_id, kind="class", name=node.name, file=rel_path, line=node.lineno, metadata={"decorators": decorators}, )) self._add_edge(CodeEdge(source=file_id, target=cls_id, kind="contains")) for base in node.bases: if isinstance(base, ast.Name): self._add_edge(CodeEdge(source=cls_id, target=base.id, kind="inherits")) for item in ast.iter_child_nodes(node): if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): method_id = f"{cls_id}.{item.name}" self._add_node(CodeNode( id=method_id, kind="function", name=item.name, file=rel_path, line=item.lineno, )) self._add_edge(CodeEdge(source=cls_id, target=method_id, kind="contains")) self._collect_calls(item, method_id, file_id) elif isinstance(node, (ast.Import, ast.ImportFrom)): names = [alias.name for alias in node.names] for name in names: self._add_edge(CodeEdge(source=file_id, target=name, kind="imports")) def _collect_calls(self, func_node: ast.AST, caller_id: str, file_id: str) -> None: """Walk *func_node* and record function calls as edges.""" for child in ast.walk(func_node): if isinstance(child, ast.Call): name = None if isinstance(child.func, ast.Name): name = child.func.id elif isinstance(child.func, ast.Attribute): name = child.func.attr if name: self._add_edge(CodeEdge(source=caller_id, target=name, kind="calls")) _JS_FUNC_RE = re.compile( r"(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?(?:function|\([^)]*\)\s*=>)|" r"(?:async\s+)?(\w+)\s*\([^)]*\)\s*\{)", re.MULTILINE, ) _JS_IMPORT_RE = re.compile( r"(?:import\s+.*?from\s+['\"]([^'\"]+)['\"]|require\s*\(\s*['\"]([^'\"]+)['\"]\s*\))", re.MULTILINE, ) _JS_CALL_RE = re.compile(r"(\w+)\s*\(", re.MULTILINE) def _parse_javascript(self, source: str, rel_path: str, file_id: str) -> None: """Parse JavaScript/TypeScript source using regex-based heuristics.""" for m in self._JS_FUNC_RE.finditer(source): name = m.group(1) or m.group(2) or m.group(3) if not name or name in {"if", "for", "while", "switch", "catch", "return", "throw"}: continue line = source[: m.start()].count("\n") + 1 func_id = f"{file_id}:{name}" self._add_node(CodeNode( id=func_id, kind="function", name=name, file=rel_path, line=line, )) self._add_edge(CodeEdge(source=file_id, target=func_id, kind="contains")) body_start = m.end() next_match = self._JS_FUNC_RE.search(source, m.end()) body_end = next_match.start() if next_match else len(source) body = source[body_start:body_end] for cm in self._JS_CALL_RE.finditer(body): callee = cm.group(1) if callee not in {"if", "for", "while", "switch", "catch", "return", "throw", "function", "const", "let", "var"}: self._add_edge(CodeEdge(source=func_id, target=callee, kind="calls")) for m in self._JS_IMPORT_RE.finditer(source): module = m.group(1) or m.group(2) if module: self._add_edge(CodeEdge(source=file_id, target=module, kind="imports"))