Spaces:
Running
Running
File size: 12,230 Bytes
7ddc81b 6a2abaa 7ddc81b 6a2abaa 7ddc81b 6a2abaa 7ddc81b 6a2abaa 7ddc81b 6a2abaa 7ddc81b 6a2abaa 7ddc81b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 | """AST-based code dependency graph for structured code analysis."""
import ast
import re
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class CodeNode:
    """One vertex of the code graph.

    ``kind`` is a free-form tag; CodeGraph itself creates ``"file"``,
    ``"function"`` and ``"class"`` nodes.  ``metadata`` holds optional extras
    (CodeGraph stores decorator names under the ``"decorators"`` key).
    """

    id: str  # unique key within the graph, e.g. "pkg/mod.py:func"
    kind: str
    name: str  # short display name of the file/function/class
    file: str  # path of the file the node belongs to
    line: int  # line of the definition; CodeGraph uses 0 for file nodes
    metadata: dict = field(default_factory=dict)  # a new dict per instance
@dataclass
class CodeEdge:
    """One directed relationship in the code graph.

    ``kind`` is a free-form tag; CodeGraph emits ``"contains"``, ``"calls"``,
    ``"imports"`` and ``"inherits"`` edges.
    """

    source: str  # id of the node the edge starts from
    target: str  # a node id ("contains") or a bare name taken from the source
    kind: str
# File suffixes that CodeGraph.build_from_directory() picks up.  Only ".py"
# and the JavaScript/TypeScript family are parsed structurally; files in the
# other languages contribute just a "file" node.
_SOURCE_EXTENSIONS = {
    ".py",
    ".js", ".ts", ".jsx", ".tsx",
    ".java",
    ".go",
    ".rb",
}
class CodeGraph:
    """Lightweight code dependency graph built from AST and regex analysis.

    Python files are parsed with :mod:`ast`, JavaScript/TypeScript files with
    regular-expression heuristics, and other recognised files only get a
    ``file`` node.  Node ids are the file's relative path for files,
    ``"<file>:<name>"`` for top-level functions and classes, and
    ``"<file>:<Class>.<method>"`` for methods.  The targets of ``calls``,
    ``imports`` and ``inherits`` edges are the bare names found in the source,
    not resolved node ids.
    """

    # Regex captures that are keywords, not function definitions.  "function"
    # covers anonymous ``function (...) {`` expressions.
    _JS_NON_DEFINITIONS = frozenset(
        {"if", "for", "while", "switch", "catch", "return", "throw", "function"}
    )
    # ``name(`` occurrences that are keywords rather than calls.
    _JS_NON_CALLS = frozenset(
        {"if", "for", "while", "switch", "catch", "return", "throw",
         "function", "const", "let", "var"}
    )

    _JS_FUNC_RE = re.compile(
        r"(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?(?:function|\([^)]*\)\s*=>)|"
        r"(?:async\s+)?(\w+)\s*\([^)]*\)\s*\{)",
        re.MULTILINE,
    )
    _JS_IMPORT_RE = re.compile(
        r"(?:import\s+.*?from\s+['\"]([^'\"]+)['\"]|require\s*\(\s*['\"]([^'\"]+)['\"]\s*\))",
        re.MULTILINE,
    )
    _JS_CALL_RE = re.compile(r"(\w+)\s*\(", re.MULTILINE)

    def __init__(self) -> None:
        self.nodes: dict[str, CodeNode] = {}
        self.edges: list[CodeEdge] = []
        # Source lines per relative path, used by get_function_source().
        self._file_contents: dict[str, list[str]] = {}

    # ------------------------------------------------------------------ build

    def build_from_directory(self, target_dir: str) -> None:
        """Scan all source files under *target_dir* (recursively) and populate the graph.

        Files are visited in sorted path order and selected by suffix via
        ``_SOURCE_EXTENSIONS``.
        """
        root = Path(target_dir)
        for p in sorted(root.rglob("*")):
            if p.suffix in _SOURCE_EXTENSIONS and p.is_file():
                self.build_from_file(str(p), str(root))

    def build_from_file(self, filepath: str, root: str = "") -> None:
        """Parse a single file and add its nodes/edges.

        Args:
            filepath: Path of the file to read.
            root: Directory the stored path is made relative to.  When empty,
                only the file's base name is used, so same-named files in
                different directories share an id.

        Unreadable files are skipped silently.  ``Path.relative_to`` raises
        ValueError if *filepath* is not inside a non-empty *root*.
        """
        path = Path(filepath)
        rel = str(path.relative_to(root)) if root else path.name
        try:
            # Decode explicitly as UTF-8 (the default encoding of Python
            # source) rather than the platform's locale encoding.
            source = path.read_text(encoding="utf-8", errors="replace")
        except OSError:
            return
        # Split on "\n" only: read_text() already normalised line endings, and
        # str.splitlines() would also break on form feeds and similar
        # characters, drifting away from the line numbers reported by ast and
        # by the newline count used for JavaScript.
        cached = source.split("\n")
        if cached[-1] == "":
            cached.pop()  # trailing newline: mirror splitlines()
        self._file_contents[rel] = cached
        self.nodes[rel] = CodeNode(
            id=rel, kind="file", name=path.name, file=rel, line=0,
        )
        ext = path.suffix
        if ext == ".py":
            self._parse_python(source, rel, rel)
        elif ext in {".js", ".ts", ".jsx", ".tsx"}:
            self._parse_javascript(source, rel, rel)

    # ---------------------------------------------------------------- queries

    def get_file_summary(self, filepath: str) -> str:
        """Return a compact, line-ordered summary of the definitions in *filepath*."""
        lines: list[str] = []
        for node in sorted(
            (n for n in self.nodes.values() if n.file == filepath and n.kind != "file"),
            key=lambda n: n.line,
        ):
            lines.append(f"L{node.line}: [{node.kind}] {node.name}")
        if not lines:
            return f"No structured elements found in {filepath}"
        return f"=== {filepath} ===\n" + "\n".join(lines)

    def get_function_source(self, filepath: str, function_name: str) -> str:
        """Return the numbered source lines (``"<lineno>: <text>"``) of a function.

        The first ``function`` node in *filepath* whose name equals
        *function_name* is used (methods are matched by their bare method
        name).  Python definitions are cut at the end line recorded from the
        AST, so multi-line signatures and methods come out exactly.  Other
        nodes fall back to ``_indented_block_end``.
        """
        raw_lines = self._file_contents.get(filepath, [])
        if not raw_lines:
            return f"File {filepath} not cached."
        match = next(
            (
                n for n in self.nodes.values()
                if n.file == filepath and n.name == function_name and n.kind == "function"
            ),
            None,
        )
        if match is None:
            return f"Function '{function_name}' not found in {filepath}."
        start = match.line - 1
        if start >= len(raw_lines):
            return "Function body not available."
        end_line = match.metadata.get("end_line")
        if end_line is not None:
            stop = min(end_line, len(raw_lines))
        else:
            stop = self._indented_block_end(raw_lines, start)
        return "\n".join(f"{i + 1}: {raw_lines[i]}" for i in range(start, stop))

    @staticmethod
    def _indented_block_end(raw_lines: list[str], start: int) -> int:
        """Return the exclusive end index of the indented block opened at *start*.

        The block ends at the first later non-blank line indented no deeper
        than line *start*.  If that line begins with a closing bracket (e.g. a
        JavaScript ``}``), it is kept as the block's last line.
        """
        first = raw_lines[start]
        base = len(first) - len(first.lstrip())
        for i in range(start + 1, len(raw_lines)):
            text = raw_lines[i]
            stripped = text.lstrip()
            if stripped and len(text) - len(stripped) <= base:
                return i + 1 if stripped[0] in "})]" else i
        return len(raw_lines)

    def trace_calls(self, function_name: str, depth: int = 3) -> str:
        """Trace who calls *function_name*, following callers up to *depth* hops.

        Returns one line per chain, ``function → caller → caller's caller …``.
        A chain ends when it reaches *depth* hops, when a function has no
        recorded callers, or when every caller is already on the chain.
        """
        chains: list[list[str]] = []
        self._trace_recursive(function_name, depth, [function_name], chains)
        if not chains:
            return f"No call chain found for '{function_name}'."
        return "\n".join(" → ".join(chain) for chain in chains)

    def get_callers(self, function_name: str) -> list[str]:
        """Return the ids of nodes that call *function_name*, in first-seen order.

        ``calls`` edges record the callee as the bare name written in the
        source (``foo`` for both ``foo()`` and ``obj.foo()``), so a target
        matches when it equals *function_name*.  Targets of the form
        ``"...:function_name"`` also match.  Each caller is listed once, even
        if it calls the function several times.
        """
        suffix = f":{function_name}"
        return list(dict.fromkeys(
            edge.source
            for edge in self.edges
            if edge.kind == "calls"
            and (edge.target == function_name or edge.target.endswith(suffix))
        ))

    def get_imports(self, filepath: str) -> list[str]:
        """Return the targets of all ``imports`` edges originating from *filepath*."""
        imports: list[str] = []
        for edge in self.edges:
            if edge.kind == "imports" and edge.source == filepath:
                imports.append(edge.target)
        return imports

    def to_text(self) -> str:
        """Compact text representation of the graph for LLM context.

        Lists every file's nodes in line order (with decorator names), followed
        by up to 20 of the edges whose source node lies in that file.
        """
        lines: list[str] = ["=== CODE DEPENDENCY GRAPH ==="]
        lines.append(f"Nodes: {len(self.nodes)} | Edges: {len(self.edges)}")
        by_file: dict[str, list[CodeNode]] = {}
        for node in self.nodes.values():
            by_file.setdefault(node.file, []).append(node)
        # Bucket edges by the file of their source node in a single pass,
        # instead of rescanning every edge once per file.
        edges_by_file: dict[str, list[CodeEdge]] = {}
        for edge in self.edges:
            owner = self.nodes.get(edge.source)
            owner_file = owner.file if owner is not None else edge.source.split(":", 1)[0]
            edges_by_file.setdefault(owner_file, []).append(edge)
        for filepath in sorted(by_file):
            lines.append(f"\n--- {filepath} ---")
            for node in sorted(by_file[filepath], key=lambda n: n.line):
                if node.kind == "file":
                    continue
                decorators = node.metadata.get("decorators", [])
                dec_str = f" @{','.join(decorators)}" if decorators else ""
                lines.append(f" L{node.line} [{node.kind}] {node.name}{dec_str}")
            for edge in edges_by_file.get(filepath, [])[:20]:
                short = edge.target.split(":")[-1]
                lines.append(f" --{edge.kind}--> {short}")
        return "\n".join(lines)

    def get_statistics(self) -> dict:
        """Return counts by node kind and edge kind."""
        node_kinds: dict[str, int] = {}
        for node in self.nodes.values():
            node_kinds[node.kind] = node_kinds.get(node.kind, 0) + 1
        edge_kinds: dict[str, int] = {}
        for edge in self.edges:
            edge_kinds[edge.kind] = edge_kinds.get(edge.kind, 0) + 1
        return {"nodes": node_kinds, "edges": edge_kinds}

    # -------------------------------------------------------------- internals

    def _trace_recursive(
        self, name: str, depth: int, path: list[str], results: list[list[str]]
    ) -> None:
        """Depth-first walk up the caller graph, appending finished chains to *results*.

        *path* is the chain built so far (mutated and restored in place).  A
        chain is recorded when the hop budget is used up, when *name* has no
        callers, or when every caller is already on the chain (a cycle).
        """
        if depth <= 0:
            if len(path) > 1:  # a zero-hop request yields no chain
                results.append(path[:])
            return
        # Call edges store bare callee names, so look methods ("Cls.meth") up
        # by their last component.
        lookup = name.rsplit(".", 1)[-1]
        extended = False
        for caller_id in self.get_callers(lookup):
            caller_name = caller_id.split(":")[-1]
            if caller_name in path:
                continue
            extended = True
            path.append(caller_name)
            self._trace_recursive(caller_name, depth - 1, path, results)
            path.pop()
        if not extended:
            results.append(path[:])

    def _add_node(self, node: CodeNode) -> None:
        """Add *node* unless a node with the same id already exists (first one wins)."""
        if node.id not in self.nodes:
            self.nodes[node.id] = node

    def _add_edge(self, edge: CodeEdge) -> None:
        """Append *edge*; duplicate edges are kept."""
        self.edges.append(edge)

    @staticmethod
    def _decorator_name(decorator: ast.expr) -> str:
        """Return a short, deterministic label for a decorator expression.

        ``@name`` gives ``name`` and ``@obj.attr`` gives ``attr``.  A call such
        as ``@app.route("/")`` is labelled by its callee (``route``).  Anything
        else falls back to ``ast.unparse`` rather than an object repr that
        embeds a memory address.
        """
        if isinstance(decorator, ast.Call):
            decorator = decorator.func
        if isinstance(decorator, ast.Attribute):
            return decorator.attr
        if isinstance(decorator, ast.Name):
            return decorator.id
        return ast.unparse(decorator)

    def _add_python_definition(
        self, node: ast.AST, node_id: str, kind: str, rel_path: str
    ) -> None:
        """Add a node for a Python ``def``/``class`` with its decorators and end line."""
        self._add_node(CodeNode(
            id=node_id, kind=kind, name=node.name,
            file=rel_path, line=node.lineno,
            metadata={
                "decorators": [self._decorator_name(d) for d in node.decorator_list],
                # Last line of the definition, used by get_function_source().
                "end_line": node.end_lineno,
            },
        ))

    def _parse_python(self, source: str, rel_path: str, file_id: str) -> None:
        """Parse Python source using the ast module.

        Only module-level statements are inspected: functions, classes (with
        their directly nested methods and plain-name base classes) and
        imports.  Source that fails to parse is skipped.
        """
        try:
            tree = ast.parse(source, filename=rel_path)
        except (SyntaxError, ValueError):
            # Older Python versions raise ValueError rather than SyntaxError
            # for source containing NUL bytes.
            return
        for node in ast.iter_child_nodes(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                func_id = f"{file_id}:{node.name}"
                self._add_python_definition(node, func_id, "function", rel_path)
                self._add_edge(CodeEdge(source=file_id, target=func_id, kind="contains"))
                self._collect_calls(node, func_id, file_id)
            elif isinstance(node, ast.ClassDef):
                cls_id = f"{file_id}:{node.name}"
                self._add_python_definition(node, cls_id, "class", rel_path)
                self._add_edge(CodeEdge(source=file_id, target=cls_id, kind="contains"))
                for base in node.bases:
                    if isinstance(base, ast.Name):
                        self._add_edge(CodeEdge(source=cls_id, target=base.id, kind="inherits"))
                for item in ast.iter_child_nodes(node):
                    if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                        method_id = f"{cls_id}.{item.name}"
                        self._add_python_definition(item, method_id, "function", rel_path)
                        self._add_edge(CodeEdge(source=cls_id, target=method_id, kind="contains"))
                        self._collect_calls(item, method_id, file_id)
            elif isinstance(node, (ast.Import, ast.ImportFrom)):
                for alias in node.names:
                    self._add_edge(CodeEdge(source=file_id, target=alias.name, kind="imports"))

    def _collect_calls(self, func_node: ast.AST, caller_id: str, file_id: str) -> None:
        """Walk *func_node* (including nested code) and record calls as edges.

        ``foo(...)`` records ``foo``; ``obj.foo(...)`` records just ``foo``.
        Calls through other expressions (e.g. ``f()()``) are not recorded.
        """
        for child in ast.walk(func_node):
            if isinstance(child, ast.Call):
                name = None
                if isinstance(child.func, ast.Name):
                    name = child.func.id
                elif isinstance(child.func, ast.Attribute):
                    name = child.func.attr
                if name:
                    self._add_edge(CodeEdge(source=caller_id, target=name, kind="calls"))

    def _parse_javascript(self, source: str, rel_path: str, file_id: str) -> None:
        """Parse JavaScript/TypeScript source using regex-based heuristics.

        Every accepted ``_JS_FUNC_RE`` match becomes a ``function`` node.  Its
        "body" is approximated as the text from the end of the match to the
        start of the next accepted match (or end of file), and each ``name(``
        in that span becomes a ``calls`` edge.  Import/require specifiers
        become ``imports`` edges from the file.
        """
        # Collect accepted definitions first, so that rejected keyword matches
        # such as ``if (x) {`` do not cut the enclosing function's body short.
        definitions: list[tuple[str, int, int]] = []  # (name, match start, match end)
        for m in self._JS_FUNC_RE.finditer(source):
            name = m.group(1) or m.group(2) or m.group(3)
            if name and name not in self._JS_NON_DEFINITIONS:
                definitions.append((name, m.start(), m.end()))

        next_starts = [begin for _, begin, _ in definitions[1:]] + [len(source)]
        line = 1
        counted_to = 0
        for (name, begin, end), body_end in zip(definitions, next_starts):
            # Matches arrive in source order, so count newlines incrementally
            # instead of re-scanning the whole prefix for every match.
            line += source.count("\n", counted_to, begin)
            counted_to = begin
            func_id = f"{file_id}:{name}"
            self._add_node(CodeNode(
                id=func_id, kind="function", name=name,
                file=rel_path, line=line,
            ))
            self._add_edge(CodeEdge(source=file_id, target=func_id, kind="contains"))
            for cm in self._JS_CALL_RE.finditer(source, end, body_end):
                callee = cm.group(1)
                if callee not in self._JS_NON_CALLS:
                    self._add_edge(CodeEdge(source=func_id, target=callee, kind="calls"))

        for m in self._JS_IMPORT_RE.finditer(source):
            module = m.group(1) or m.group(2)
            if module:
                self._add_edge(CodeEdge(source=file_id, target=module, kind="imports"))
|