code_graph / utility.py
tlarsson's picture
Upload 2 files
3145636 verified
import ast
import ast
class CallCollector(ast.NodeVisitor):
    """AST visitor that records calls made to a known set of function names.

    Only calls whose callee is a bare name (``foo(...)``) are considered;
    attribute calls such as ``obj.foo(...)`` are not recorded. Matching names
    are appended to ``self.calls`` once per call site, so repeats are kept.
    """

    def __init__(self, defined_funcs):
        # Container of function names that count as "known"; only membership
        # tests are performed on it.
        self.defined_funcs = defined_funcs
        self.calls = []

    def visit_Call(self, node):
        callee = node.func
        is_known_name = (
            isinstance(callee, ast.Name) and callee.id in self.defined_funcs
        )
        if is_known_name:
            self.calls.append(callee.id)
        # Always descend so calls nested in the callee expression or in the
        # arguments are seen as well.
        self.generic_visit(node)
def parse_functions_from_files(file_dict):
    """Collect signature, return and call information for every function.

    Args:
        file_dict: Mapping of file name to Python source text.

    Returns:
        A dict keyed by function name. Each value is a dict with the keys:

        * ``args``: ``"name: annotation"`` strings for the regular
          positional parameters (``"?"`` when there is no annotation).
        * ``returns``: ``"expr: type"`` strings, one per returned expression
          (tuple returns contribute one entry per element).
        * ``calls``: names gathered by ``CallCollector`` for calls to
          functions defined anywhere in ``file_dict``.
        * ``filename``: the ``file_dict`` key the function was found under.
        * ``reads_state`` / ``writes_state``: always empty lists.

        Functions are keyed by bare name, so when two definitions share a
        name the one encountered last replaces the earlier one.

    Raises:
        SyntaxError: If a source text cannot be parsed by ``ast.parse``.
    """
    # ``async def`` functions are analysed the same way as plain ones.
    function_node_types = (ast.FunctionDef, ast.AsyncFunctionDef)

    def infer_type_from_value(value_node):
        """Guess a type label for an expression node, or "?" if unknown."""
        if isinstance(value_node, ast.Call) and isinstance(value_node.func, ast.Attribute):
            if value_node.func.attr in ("read_csv", "DataFrame"):
                return "pd.DataFrame"
            if value_node.func.attr == "array":
                return "np.ndarray"
        elif isinstance(value_node, ast.List):
            return "list"
        elif isinstance(value_node, ast.Dict):
            return "dict"
        elif isinstance(value_node, ast.Set):
            return "set"
        elif isinstance(value_node, ast.Constant):
            if isinstance(value_node.value, str):
                return "str"
            # bool must be tested before int because bool subclasses int.
            if isinstance(value_node.value, bool):
                return "bool"
            if isinstance(value_node.value, int):
                return "int"
            if isinstance(value_node.value, float):
                return "float"
        return "?"

    def describe_return(expr, local_types):
        """Format one returned expression as ``"source: type"``."""
        label = ast.unparse(expr)
        return f"{label}: {local_types.get(label, infer_type_from_value(expr))}"

    # Parse each file once and keep the function nodes, so the set of defined
    # names and the per-function analysis share the same trees.
    found = [
        (fname, node)
        for fname, code in file_dict.items()
        for node in ast.walk(ast.parse(code))
        if isinstance(node, function_node_types)
    ]
    defined_funcs = {node.name for _, node in found}

    functions = {}
    for fname, node in found:
        args = [
            f"{arg.arg}: {ast.unparse(arg.annotation) if arg.annotation else '?'}"
            for arg in node.args.args
        ]

        collector = CallCollector(defined_funcs)
        collector.visit(node)

        descendants = list(ast.walk(node))

        # Gather every simple-name assignment BEFORE looking at returns.
        # ast.walk yields nodes in no guaranteed order (breadth-first in
        # practice), so a top-level ``return x`` can be yielded before an
        # ``x = ...`` nested inside an if/for/with block.
        local_assignments = {}
        for sub in descendants:
            if not isinstance(sub, ast.Assign):
                continue
            for target in sub.targets:
                if isinstance(target, ast.Name):
                    local_assignments[target.id] = infer_type_from_value(sub.value)

        returns = []
        for sub in descendants:
            if not isinstance(sub, ast.Return) or sub.value is None:
                continue
            if isinstance(sub.value, ast.Tuple):
                returned = sub.value.elts
            else:
                returned = [sub.value]
            returns.extend(describe_return(expr, local_assignments) for expr in returned)

        functions[node.name] = {
            "args": args,
            "returns": returns,
            "calls": collector.calls,
            "filename": fname,
            # State tracking is not implemented; the keys are kept so the
            # shape of the result stays the same for consumers.
            "reads_state": [],
            "writes_state": [],
        }
    return functions
def get_reachable_functions(start, graph):
    """Return the set of nodes reachable from ``start``, including ``start``.

    ``graph`` maps a node to an iterable of its successors; nodes missing
    from the mapping are treated as having no successors.
    """
    seen = set()
    pending = [start]
    while pending:
        current = pending.pop()
        if current in seen:
            continue
        seen.add(current)
        # Unknown nodes simply contribute nothing further to explore.
        pending.extend(graph.get(current, []))
    return seen
def get_backtrace_functions(target, graph):
    """Return ``target`` plus every node from which ``target`` can be reached.

    ``graph`` maps a caller to an iterable of its callees. The mapping is
    inverted first, then explored from ``target`` along the inverted edges.
    """
    # callee -> list of callers (one entry per edge, so repeats are possible)
    callers_of = {}
    for caller, callees in graph.items():
        for callee in callees:
            callers_of.setdefault(callee, []).append(caller)

    seen = set()
    pending = [target]
    while pending:
        current = pending.pop()
        if current in seen:
            continue
        seen.add(current)
        pending.extend(callers_of.get(current, []))
    return seen
from collections import deque, defaultdict
def build_nodes_and_edges(parsed, root_func, reachable, reverse=False, max_depth=10):
    """Lay out the call graph around ``root_func`` as node and edge dicts.

    A breadth-first walk starts at ``root_func`` and follows callees (or
    callers when ``reverse`` is true), restricted to names in ``reachable``
    and to at most ``max_depth`` steps from the root.

    Args:
        parsed: Mapping of function name to a metadata dict that has a
            ``"calls"`` list of callee names.
        root_func: Name of the function the layout starts from. It gets the
            ``"main"`` class in the node output.
        reachable: Container of function names allowed in the walk.
        reverse: When true, walk from callee to callers and place levels at
            negative ``y`` values.
        max_depth: Largest BFS depth that is still laid out.

    Returns:
        ``(nodes, edges)``. Each node is
        ``{"data": {"id", "label"}, "position": {"x", "y"}, "classes"}`` and
        each edge is ``{"data": {"source", "target", "label"}, "style":
        {"line-width"}}``. Both lists follow the BFS visit order, so the
        output is the same from run to run. Edge labels number the emitted
        edges of a caller starting at 1; an edge is drawn thicker (4 instead
        of 2) when the caller calls that target more than once.

    A function that is absent from ``parsed`` (for example an unknown
    ``root_func``) is treated as calling nothing instead of raising
    ``KeyError``.
    """

    def calls_of(name):
        # Unknown functions have no outgoing calls.
        return parsed.get(name, {}).get("calls", [])

    column_at_level = defaultdict(int)
    # Insertion order of this dict is the BFS visit order; it also serves as
    # the "visited" set.
    positions = {}
    queue = deque([(root_func, 0)])

    while queue:
        current, depth = queue.popleft()
        if current in positions or depth > max_depth:
            continue

        level = -depth if reverse else depth
        positions[current] = {"x": column_at_level[level] * 300, "y": level * 150}
        column_at_level[level] += 1

        if reverse:
            # Callers of the current function.
            next_funcs = [
                caller for caller, meta in parsed.items()
                if current in meta["calls"] and caller in reachable
            ]
        else:
            # Callees of the current function.
            next_funcs = [
                callee for callee in calls_of(current)
                if callee in reachable
            ]
        queue.extend((nxt, depth + 1) for nxt in next_funcs)

    nodes = [
        {
            "data": {"id": name, "label": name},
            "position": position,
            "classes": "main" if name == root_func else "",
        }
        for name, position in positions.items()
    ]

    edges = []
    for src in positions:
        call_sequence = calls_of(src)

        # Tally once per caller instead of calling list.count for every edge.
        call_counts = defaultdict(int)
        for tgt in call_sequence:
            call_counts[tgt] += 1

        call_index = 1
        for tgt in call_sequence:
            if tgt not in positions:
                continue
            edges.append({
                "data": {
                    "source": src,
                    "target": tgt,
                    "label": str(call_index),
                },
                "style": {
                    "line-width": 4 if call_counts[tgt] > 1 else 2,
                },
            })
            call_index += 1
    return nodes, edges