# CodeTribunal / src/code_tribunal/code_graph.py
# amine-yagoub
# refactor: clean up core modules by removing comment headers and unused code
# 6a2abaa
"""AST-based code dependency graph for structured code analysis."""
import ast
import re
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class CodeNode:
    """A node in the code graph.

    ``CodeGraph`` creates nodes whose ``kind`` is ``"file"``, ``"function"``
    (used for both module-level functions and methods) or ``"class"``.
    Imports are recorded as ``"imports"`` edges, not as nodes.
    """
    id: str  # key in CodeGraph.nodes, e.g. "pkg/mod.py", "pkg/mod.py:func", "pkg/mod.py:Cls.meth"
    kind: str  # "file" | "function" | "class"
    name: str  # unqualified name; the file's base name for file nodes
    file: str  # relative path of the containing file (the file node's id)
    line: int  # 1-based line of the definition; 0 for file nodes
    metadata: dict = field(default_factory=dict)  # parser extras, e.g. {"decorators": [...]}
@dataclass
class CodeEdge:
    """A directed edge in the code graph.

    ``kind`` is one of the strings ``CodeGraph`` emits: ``"contains"``
    (file -> function/class, class -> method), ``"calls"``, ``"imports"`` or
    ``"inherits"``.
    """
    source: str  # id of the originating node (a file, function, class or method id)
    # For "contains" this is a node id; for "calls"/"imports"/"inherits" it is the
    # bare callee name, module string or base-class name, which may not be a node id.
    target: str
    kind: str
_SOURCE_EXTENSIONS = {".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".go", ".rb"}
class CodeGraph:
    """Lightweight code dependency graph built from AST and regex analysis.

    Python files are parsed with the ``ast`` module; JavaScript/TypeScript files
    are scanned with regex heuristics. Other extensions in
    ``_SOURCE_EXTENSIONS`` only get a ``"file"`` node.

    Node ids:
      * file:     path relative to the scan root, e.g. ``"pkg/mod.py"``
      * function: ``"<file>:<name>"``
      * class:    ``"<file>:<Class>"``
      * method:   ``"<file>:<Class>.<method>"``

    ``"calls"`` edges point at the *bare* callee name as written at the call
    site (``foo`` for both ``foo()`` and ``obj.foo()``); they are not resolved
    to node ids.
    """

    def __init__(self) -> None:
        self.nodes: dict[str, CodeNode] = {}
        self.edges: list[CodeEdge] = []
        # Source lines keyed by relative path; backs get_function_source().
        self._file_contents: dict[str, list[str]] = {}

    def build_from_directory(self, target_dir: str) -> None:
        """Recursively scan *target_dir* and add every source file to the graph.

        Files are visited in sorted path order, so node insertion order is
        deterministic.
        """
        root = Path(target_dir)
        for path in sorted(root.rglob("*")):
            if path.suffix in _SOURCE_EXTENSIONS and path.is_file():
                self.build_from_file(str(path), str(root))

    def build_from_file(self, filepath: str, root: str = "") -> None:
        """Parse a single file and add its nodes/edges.

        Args:
            filepath: Path of the file to parse.
            root: When given, ids use *filepath* relative to *root*
                (``Path.relative_to`` raises ``ValueError`` if it is not under
                *root*); otherwise only the file's base name is used.

        Unreadable files are skipped silently. Python files that fail to parse
        still get a file node, but no children.
        """
        path = Path(filepath)
        rel = str(path.relative_to(root)) if root else path.name
        try:
            # Decode as UTF-8 explicitly: the locale default (e.g. cp1252 on
            # Windows) garbles non-ASCII source. "-sig" drops a leading BOM,
            # which ast.parse would otherwise reject.
            source = path.read_text(encoding="utf-8-sig", errors="replace")
        except OSError:
            return
        self._file_contents[rel] = source.splitlines()
        self.nodes[rel] = CodeNode(
            id=rel, kind="file", name=path.name, file=rel, line=0,
        )
        ext = path.suffix
        if ext == ".py":
            self._parse_python(source, rel, rel)
        elif ext in {".js", ".ts", ".jsx", ".tsx"}:
            self._parse_javascript(source, rel, rel)

    def get_file_summary(self, filepath: str) -> str:
        """Return a compact, line-ordered summary of the definitions in *filepath*."""
        members = sorted(
            (n for n in self.nodes.values() if n.file == filepath and n.kind != "file"),
            key=lambda n: n.line,
        )
        if not members:
            return f"No structured elements found in {filepath}"
        return f"=== {filepath} ===\n" + "\n".join(
            f"L{n.line}: [{n.kind}] {n.name}" for n in members
        )

    def get_function_source(self, filepath: str, function_name: str) -> str:
        """Return the numbered source lines of the first function/method named *function_name*.

        Python definitions use the exact end line recorded at parse time.
        Other languages fall back to an indentation heuristic (see
        ``_indent_block_end``).
        """
        raw_lines = self._file_contents.get(filepath, [])
        if not raw_lines:
            return f"File {filepath} not cached."
        for node in self.nodes.values():
            if not (node.file == filepath and node.name == function_name
                    and node.kind == "function"):
                continue
            start = node.line - 1
            if not 0 <= start < len(raw_lines):
                return "Function body not available."
            end_line = node.metadata.get("end_line")
            if end_line:
                stop = min(end_line, len(raw_lines))
            else:
                stop = self._indent_block_end(raw_lines, start)
            collected = [f"{i + 1}: {raw_lines[i]}" for i in range(start, stop)]
            return "\n".join(collected) if collected else "Function body not available."
        return f"Function '{function_name}' not found in {filepath}."

    def trace_calls(self, function_name: str, depth: int = 3) -> str:
        """Trace who calls *function_name*, following callers up to *depth* hops.

        Each output line is one chain ``callee → caller → caller-of-caller …``.
        A chain ends when a function has no recorded callers, when the next
        caller is already in the chain (recursion), or after *depth* hops.
        Returns a "No call chain found" message when nothing calls the function
        (or *depth* <= 0).
        """
        chains: list[list[str]] = []
        self._trace_recursive(function_name, depth, [function_name], chains)
        if not chains or chains == [[function_name]]:
            return f"No call chain found for '{function_name}'."
        return "\n".join(" → ".join(chain) for chain in chains)

    def get_callers(self, function_name: str) -> list[str]:
        """Return ids of functions that call *function_name* (deduplicated, first-seen order).

        Call edges store the bare callee name, so a target matches when it
        equals *function_name*; qualified targets ending in
        ``":<function_name>"`` also match.
        """
        suffix = f":{function_name}"
        matches = (
            edge.source
            for edge in self.edges
            if edge.kind == "calls"
            and (edge.target == function_name or edge.target.endswith(suffix))
        )
        # A function calling the target twice yields two edges; report it once.
        return list(dict.fromkeys(matches))

    def get_imports(self, filepath: str) -> list[str]:
        """Return the import targets recorded for *filepath*, in source order."""
        return [
            edge.target
            for edge in self.edges
            if edge.kind == "imports" and edge.source == filepath
        ]

    def to_text(self) -> str:
        """Compact text representation of the graph for LLM context.

        Per file: its definitions ordered by line, then at most 20 outgoing
        edges whose source is the file or one of its definitions.
        """
        lines: list[str] = ["=== CODE DEPENDENCY GRAPH ==="]
        lines.append(f"Nodes: {len(self.nodes)} | Edges: {len(self.edges)}")
        by_file: dict[str, list[CodeNode]] = {}
        for node in self.nodes.values():
            by_file.setdefault(node.file, []).append(node)
        # Group edges by owning file in one pass rather than rescanning every
        # edge for every file. Non-file sources look like "file:name" and
        # definition names never contain ":", so rsplit recovers the file.
        edges_by_file: dict[str, list[CodeEdge]] = {}
        for edge in self.edges:
            owner = edge.source if edge.source in by_file else edge.source.rsplit(":", 1)[0]
            edges_by_file.setdefault(owner, []).append(edge)
        for filepath in sorted(by_file):
            lines.append(f"\n--- {filepath} ---")
            for node in sorted(by_file[filepath], key=lambda n: n.line):
                if node.kind == "file":
                    continue
                decorators = node.metadata.get("decorators", [])
                dec_str = f" @{','.join(decorators)}" if decorators else ""
                lines.append(f" L{node.line} [{node.kind}] {node.name}{dec_str}")
            for edge in edges_by_file.get(filepath, [])[:20]:
                short = edge.target.rsplit(":", 1)[-1]
                lines.append(f" --{edge.kind}--> {short}")
        return "\n".join(lines)

    def get_statistics(self) -> dict:
        """Return counts by node kind and edge kind."""
        node_kinds: dict[str, int] = {}
        for node in self.nodes.values():
            node_kinds[node.kind] = node_kinds.get(node.kind, 0) + 1
        edge_kinds: dict[str, int] = {}
        for edge in self.edges:
            edge_kinds[edge.kind] = edge_kinds.get(edge.kind, 0) + 1
        return {"nodes": node_kinds, "edges": edge_kinds}

    @staticmethod
    def _bare_name(name: str) -> str:
        """Strip a file prefix and class qualifier: ``"f.py:Cls.m"`` -> ``"m"``."""
        return name.rsplit(":", 1)[-1].rsplit(".", 1)[-1]

    def _trace_recursive(
        self, name: str, depth: int, path: list[str], results: list[list[str]]
    ) -> None:
        """Extend *path* with callers of *name*, appending each finished chain to *results*.

        *path* is used as a stack (mutated in place and restored); *results*
        receives copies of it.
        """
        if depth <= 0:
            # Hop budget exhausted: keep the chain built so far instead of dropping it.
            results.append(path[:])
            return
        visited = {self._bare_name(entry) for entry in path}
        extended = False
        # Call edges target bare names, so look callers up by the bare name even
        # when *name* is a qualified method like "Class.method".
        for caller_id in self.get_callers(self._bare_name(name)):
            caller_name = caller_id.rsplit(":", 1)[-1]  # "file:Cls.m" -> "Cls.m"
            if self._bare_name(caller_name) in visited:
                continue  # recursion / call cycle
            extended = True
            path.append(caller_name)
            self._trace_recursive(caller_name, depth - 1, path, results)
            path.pop()
        if not extended:
            # No callers, or all of them would revisit the chain: it ends here.
            results.append(path[:])

    @staticmethod
    def _indent_block_end(lines: list[str], start: int) -> int:
        """Return the exclusive end index of the indented block starting at *start*.

        Heuristic for definitions without a recorded end line (the regex-parsed
        JavaScript family): the block ends at the first later non-blank line
        indented no deeper than the start line. A terminating line that begins
        with ``}`` is kept as the block's closing brace. Trailing blank lines
        are dropped.
        """
        first = lines[start]
        base = len(first) - len(first.lstrip())
        end = len(lines)
        for i in range(start + 1, len(lines)):
            text = lines[i]
            stripped = text.lstrip()
            if stripped and len(text) - len(stripped) <= base:
                end = i + 1 if stripped.startswith("}") else i
                break
        while end > start + 1 and not lines[end - 1].strip():
            end -= 1
        return end

    def _add_node(self, node: CodeNode) -> None:
        """Register *node* unless a node with the same id already exists (first wins)."""
        if node.id not in self.nodes:
            self.nodes[node.id] = node

    def _add_edge(self, edge: CodeEdge) -> None:
        """Append *edge*; duplicates are kept."""
        self.edges.append(edge)

    @staticmethod
    def _decorator_name(dec: ast.expr) -> str:
        """Readable name for a decorator: ``@a.b`` -> ``"b"``, ``@f(x)`` -> ``"f"``."""
        target = dec.func if isinstance(dec, ast.Call) else dec
        if isinstance(target, ast.Attribute):
            return target.attr
        if isinstance(target, ast.Name):
            return target.id
        # Anything else (subscripts, lambdas, ...): stable source text rather
        # than the "<ast.X object at 0x...>" repr.
        return ast.unparse(dec)

    def _def_metadata(self, node: ast.AST) -> dict:
        """Metadata shared by Python function, method and class nodes."""
        return {
            "decorators": [self._decorator_name(d) for d in node.decorator_list],
            # Exact last line, used by get_function_source().
            "end_line": getattr(node, "end_lineno", None),
        }

    def _parse_python(self, source: str, rel_path: str, file_id: str) -> None:
        """Parse Python source with ``ast`` and add top-level definitions and imports."""
        try:
            tree = ast.parse(source, filename=rel_path)
        except (SyntaxError, ValueError):
            # ValueError: null bytes in the source on Python < 3.12.
            return
        for node in ast.iter_child_nodes(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                func_id = f"{file_id}:{node.name}"
                self._add_node(CodeNode(
                    id=func_id, kind="function", name=node.name,
                    file=rel_path, line=node.lineno,
                    metadata=self._def_metadata(node),
                ))
                self._add_edge(CodeEdge(source=file_id, target=func_id, kind="contains"))
                self._collect_calls(node, func_id, file_id)
            elif isinstance(node, ast.ClassDef):
                self._parse_python_class(node, rel_path, file_id)
            elif isinstance(node, ast.Import):
                for alias in node.names:
                    self._add_edge(CodeEdge(source=file_id, target=alias.name, kind="imports"))
            elif isinstance(node, ast.ImportFrom):
                # Keep the source module: "from os import path" -> "os.path",
                # "from . import x" -> ".x", "from ..pkg import y" -> "..pkg.y".
                prefix = "." * (node.level or 0) + (f"{node.module}." if node.module else "")
                for alias in node.names:
                    self._add_edge(CodeEdge(
                        source=file_id, target=prefix + alias.name, kind="imports",
                    ))

    def _parse_python_class(self, node: ast.ClassDef, rel_path: str, file_id: str) -> None:
        """Add a class node, its inheritance edges, and its directly defined methods."""
        cls_id = f"{file_id}:{node.name}"
        self._add_node(CodeNode(
            id=cls_id, kind="class", name=node.name,
            file=rel_path, line=node.lineno,
            metadata=self._def_metadata(node),
        ))
        self._add_edge(CodeEdge(source=file_id, target=cls_id, kind="contains"))
        for base in node.bases:
            if isinstance(base, ast.Name):
                base_name = base.id
            elif isinstance(base, ast.Attribute):
                base_name = ast.unparse(base)  # e.g. "models.Model"
            else:
                continue  # subscripted/computed bases are not tracked
            self._add_edge(CodeEdge(source=cls_id, target=base_name, kind="inherits"))
        for item in node.body:
            if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                method_id = f"{cls_id}.{item.name}"
                self._add_node(CodeNode(
                    id=method_id, kind="function", name=item.name,
                    file=rel_path, line=item.lineno,
                    metadata=self._def_metadata(item),
                ))
                self._add_edge(CodeEdge(source=cls_id, target=method_id, kind="contains"))
                self._collect_calls(item, method_id, file_id)

    def _collect_calls(self, func_node: ast.AST, caller_id: str, file_id: str) -> None:
        """Record a ``calls`` edge from *caller_id* for every call in *func_node*'s body.

        Only the body is walked: decorators and default-argument expressions run
        at definition time, not when the function is called. Calls inside
        nested functions/lambdas are attributed to the enclosing function.
        *file_id* is accepted for interface compatibility and is unused.
        """
        roots = getattr(func_node, "body", None)
        if not isinstance(roots, list):
            roots = [func_node]
        for root in roots:
            for child in ast.walk(root):
                if not isinstance(child, ast.Call):
                    continue
                callee = child.func
                if isinstance(callee, ast.Name):
                    name = callee.id
                elif isinstance(callee, ast.Attribute):
                    name = callee.attr
                else:
                    continue
                self._add_edge(CodeEdge(source=caller_id, target=name, kind="calls"))

    _JS_FUNC_RE = re.compile(
        r"(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?(?:function|\([^)]*\)\s*=>)|"
        r"(?:async\s+)?(\w+)\s*\([^)]*\)\s*\{)",
        re.MULTILINE,
    )
    # Matches ES imports (including multi-line "import {\n a,\n} from 'x'"),
    # side-effect imports ("import 'x'") and CommonJS require('x').
    _JS_IMPORT_RE = re.compile(
        r"(?:\bimport\s+[^;'\"]*?\bfrom\s+['\"]([^'\"]+)['\"]"
        r"|\bimport\s+['\"]([^'\"]+)['\"]"
        r"|\brequire\s*\(\s*['\"]([^'\"]+)['\"]\s*\))",
        re.MULTILINE,
    )
    _JS_CALL_RE = re.compile(r"(\w+)\s*\(", re.MULTILINE)
    # Words _JS_FUNC_RE's "name(...) {" branch picks up that are not function
    # names: control statements and anonymous "function (...) {" expressions.
    _JS_FUNC_NAME_SKIP = frozenset(
        {"if", "for", "while", "switch", "catch", "return", "throw", "function"}
    )
    # Words _JS_CALL_RE matches that are not calls.
    _JS_CALL_SKIP = frozenset(
        {"if", "for", "while", "switch", "catch", "return", "throw",
         "function", "const", "let", "var"}
    )

    def _parse_javascript(self, source: str, rel_path: str, file_id: str) -> None:
        """Parse JavaScript/TypeScript source using regex-based heuristics.

        A function's "body" is the text from the end of its declaration match to
        the start of the next *accepted* declaration, so control statements such
        as ``if (...) {`` no longer truncate it.
        """
        declarations: list[tuple[re.Match, str]] = []
        for m in self._JS_FUNC_RE.finditer(source):
            name = m.group(1) or m.group(2) or m.group(3)
            if name and name not in self._JS_FUNC_NAME_SKIP:
                declarations.append((m, name))

        line, scanned = 1, 0
        for idx, (m, name) in enumerate(declarations):
            # Count newlines incrementally instead of rescanning from offset 0.
            line += source.count("\n", scanned, m.start())
            scanned = m.start()
            func_id = f"{file_id}:{name}"
            self._add_node(CodeNode(
                id=func_id, kind="function", name=name,
                file=rel_path, line=line,
            ))
            self._add_edge(CodeEdge(source=file_id, target=func_id, kind="contains"))
            body_end = declarations[idx + 1][0].start() if idx + 1 < len(declarations) else len(source)
            for cm in self._JS_CALL_RE.finditer(source, m.end(), body_end):
                callee = cm.group(1)
                if callee not in self._JS_CALL_SKIP:
                    self._add_edge(CodeEdge(source=func_id, target=callee, kind="calls"))

        for m in self._JS_IMPORT_RE.finditer(source):
            module = m.group(1) or m.group(2) or m.group(3)
            if module:
                self._add_edge(CodeEdge(source=file_id, target=module, kind="imports"))