# code_graph / utility.py
# Uploaded by tlarsson ("Upload 3 files", commit b82e7a7, verified; 3.72 kB).
import ast
# Node types that introduce a named function definition.
_FUNC_DEFS = (ast.FunctionDef, ast.AsyncFunctionDef)


def _infer_type_from_value(value_node):
    """Best-effort guess of the type label an expression evaluates to.

    Returns a short label such as "list", "int" or "pd.DataFrame", or "?"
    when the expression is not one of the recognised forms.
    """
    if isinstance(value_node, ast.Call):
        # Only attribute calls such as `pd.read_csv(...)` / `np.array(...)`
        # are recognised, matched on the attribute name alone.
        if isinstance(value_node.func, ast.Attribute):
            if value_node.func.attr in ("read_csv", "DataFrame"):
                return "pd.DataFrame"
            if value_node.func.attr == "array":
                return "np.ndarray"
        return "?"
    if isinstance(value_node, (ast.List, ast.ListComp)):
        return "list"
    if isinstance(value_node, (ast.Dict, ast.DictComp)):
        return "dict"
    if isinstance(value_node, (ast.Set, ast.SetComp)):
        return "set"
    if isinstance(value_node, ast.Tuple):
        return "tuple"
    if isinstance(value_node, ast.JoinedStr):
        return "str"
    if isinstance(value_node, ast.Constant):
        value = value_node.value
        # bool must be tested before int: bool is a subclass of int.
        for py_type, label in ((str, "str"), (bool, "bool"), (int, "int"), (float, "float")):
            if isinstance(value, py_type):
                return label
    return "?"


def _walk_own_scope(func_node):
    """Yield the nodes of *func_node* breadth-first, without entering nested defs.

    A nested function's body belongs to that nested function (it is
    summarised separately), so it is skipped. Its parameters (defaults) and
    decorators are still visited because they are evaluated in the enclosing
    function's scope.
    """
    from collections import deque

    todo = deque(ast.iter_child_nodes(func_node))
    while todo:
        node = todo.popleft()
        if isinstance(node, _FUNC_DEFS):
            todo.append(node.args)
            todo.extend(node.decorator_list)
        else:
            todo.extend(ast.iter_child_nodes(node))
        yield node


def _format_args(arguments):
    """Render an ``ast.arguments`` node as "name: annotation" strings in signature order.

    Unannotated parameters get "?"; ``*args`` / ``**kwargs`` keep their star
    prefix so they stay distinguishable from ordinary parameters.
    """
    def render(arg, prefix=""):
        arg_type = ast.unparse(arg.annotation) if arg.annotation else "?"
        return f"{prefix}{arg.arg}: {arg_type}"

    rendered = [render(a) for a in arguments.posonlyargs + arguments.args]
    if arguments.vararg is not None:
        rendered.append(render(arguments.vararg, "*"))
    rendered.extend(render(a) for a in arguments.kwonlyargs)
    if arguments.kwarg is not None:
        rendered.append(render(arguments.kwarg, "**"))
    return rendered


def _describe_function(func_node, fname, defined_funcs):
    """Build the summary dict for a single (async) function definition."""
    calls = []
    local_types = {}
    return_values = []
    # Gather calls and local assignment types first. Return statements are
    # only remembered here and resolved afterwards. Breadth-first walking
    # visits a body-level `return x` before an `x = ...` nested inside an
    # `if`/`for`, so resolving returns on the fly would miss those types.
    for sub in _walk_own_scope(func_node):
        if isinstance(sub, ast.Call):
            if isinstance(sub.func, ast.Name) and sub.func.id in defined_funcs:
                calls.append(sub.func.id)
        elif isinstance(sub, ast.Assign):
            for target in sub.targets:
                if isinstance(target, ast.Name):
                    local_types[target.id] = _infer_type_from_value(sub.value)
        elif isinstance(sub, ast.AnnAssign):
            # An explicit annotation beats guessing from the value.
            if isinstance(sub.target, ast.Name):
                local_types[sub.target.id] = ast.unparse(sub.annotation)
        elif isinstance(sub, ast.Return) and sub.value is not None:
            return_values.append(sub.value)

    returns = []
    for value in return_values:
        # `return a, b` is reported as two separate returned values.
        elts = value.elts if isinstance(value, ast.Tuple) else [value]
        for elt in elts:
            label = ast.unparse(elt)
            returns.append(f"{label}: {local_types.get(label, _infer_type_from_value(elt))}")

    return {
        "args": _format_args(func_node.args),
        "returns": returns,
        "calls": calls,
        "filename": fname,
        # Placeholders: state read/write tracking is not implemented, so these
        # are always empty. The keys are kept so callers can rely on them.
        "reads_state": [],
        "writes_state": [],
    }


def parse_functions_from_files(file_dict):
    """Summarise every function (sync or async) defined in a set of Python sources.

    Args:
        file_dict: Mapping of filename -> source code string. Each source is
            parsed with ``ast.parse``; a ``SyntaxError`` propagates.

    Returns:
        dict mapping function name -> summary dict with keys:
          "args": "name: annotation" strings in signature order ("?" when
              unannotated; ``*``/``**`` prefixes for var-arg parameters),
          "returns": "expr: type" strings, one per returned value (tuple
              returns are split into their elements; type is "?" if unknown),
          "calls": names of functions defined anywhere in *file_dict* that
              this function calls by plain name (attribute calls such as
              ``obj.m()`` are not recorded; repeats are kept),
          "filename": the *file_dict* key the definition came from,
          "reads_state" / "writes_state": always empty lists.

        Functions are keyed by bare name, so same-named definitions (e.g.
        methods on different classes, or functions in different files)
        overwrite each other; the last one walked wins.
    """
    # Parse each file once; both passes below reuse the trees.
    trees = {fname: ast.parse(code, filename=fname) for fname, code in file_dict.items()}

    # Pass 1: every function name in the project, so that pass 2 can filter
    # calls down to project-defined functions.
    defined_funcs = {
        node.name
        for tree in trees.values()
        for node in ast.walk(tree)
        if isinstance(node, _FUNC_DEFS)
    }

    # Pass 2: summarise each definition, including nested functions and methods.
    functions = {}
    for fname, tree in trees.items():
        for node in ast.walk(tree):
            if isinstance(node, _FUNC_DEFS):
                functions[node.name] = _describe_function(node, fname, defined_funcs)
    return functions
def get_reachable_functions(start, graph):
    """Return the set of nodes reachable from *start*, *start* included.

    *graph* maps each node to an iterable of its successors (e.g. caller ->
    callees). Nodes absent from the mapping are treated as having no
    successors. Cycles are handled by skipping already-seen nodes.
    """
    seen = set()
    pending = [start]
    while pending:
        current = pending.pop()
        if current in seen:
            continue
        seen.add(current)
        pending += graph.get(current, [])
    return seen
def get_backtrace_functions(target, graph):
    """Return every node from which *target* can be reached, *target* included.

    *graph* maps caller -> iterable of callees. It is inverted into a
    callee -> callers index and then walked depth-first from *target*.
    """
    callers_of = {}
    for caller, callees in graph.items():
        for callee in callees:
            if callee not in callers_of:
                callers_of[callee] = []
            callers_of[callee].append(caller)

    seen = set()
    pending = [target]
    while pending:
        current = pending.pop()
        if current in seen:
            continue
        seen.add(current)
        pending += callers_of.get(current, [])
    return seen