|
|
import contextlib |
|
|
import re |
|
|
|
|
|
|
|
|
def extract_input_variables(nodes):
    """Fill in ``input_variables`` for prompt-like nodes from their template text.

    For every node whose ``node["data"]["node"]["template"]`` mapping contains an
    ``"input_variables"`` entry, ``template["input_variables"]["value"]`` is set to
    the list of ``{...}`` placeholder names found in the template text:

    * ``_type == "prompt"``: text is ``template["template"]["value"]``;
    * ``_type == "few_shot"``: text is ``prefix["value"] + suffix["value"]``;
    * any other ``_type``: the list is empty.

    Placeholders are returned in order of appearance, duplicates included.
    Any exception raised while handling a node is suppressed, and that node is
    left untouched, because the assignment is the last step. Examples of such
    exceptions are a missing key or a non-string text value.

    Args:
        nodes: Iterable of node dicts; mutated in place.

    Returns:
        The same ``nodes`` object that was passed in.
    """
    placeholder = re.compile(r"\{(.*?)\}")

    for node in nodes:
        with contextlib.suppress(Exception):
            template = node["data"]["node"]["template"]
            if "input_variables" not in template:
                continue

            kind = template["_type"]
            if kind == "prompt":
                found = placeholder.findall(template["template"]["value"])
            elif kind == "few_shot":
                combined = template["prefix"]["value"] + template["suffix"]["value"]
                found = placeholder.findall(combined)
            else:
                found = []

            template["input_variables"]["value"] = found

    return nodes
|
|
|
|
|
|
|
|
def get_root_vertex(graph):
    """Return the root vertex of the template graph, or ``None``.

    The root is the first vertex in ``graph.vertices`` whose ``id`` is never
    used as the ``source_id`` of any edge in ``graph.edges``. An edgeless
    graph therefore yields its first vertex; an empty graph, or one where every
    vertex is an edge source, yields ``None``.
    """
    # Ids of every vertex that appears as the source of some edge.
    source_ids = set()
    for edge in graph.edges:
        source_ids.add(edge.source_id)

    # A lone vertex with no edges at all is trivially the root.
    if not source_ids and len(graph.vertices) == 1:
        return graph.vertices[0]

    for vertex in graph.vertices:
        if vertex.id not in source_ids:
            return vertex
    return None
|
|
|
|
|
|
|
|
def build_json(root, graph) -> dict:
    """Recursively resolve ``root``'s template into a plain parameter dict.

    Resolution steps:

    * If ``root.data`` has no ``"node"`` entry, ``root`` is a pass-through and
      resolution continues at the target of its first edge.
    * If exactly one neighbouring node is found, resolution continues at that
      node instead of ``root``. The neighbour is either that edge target or the
      single result of ``graph.get_nodes_with_target(root)``.
    * Otherwise every field of ``root.data["node"]["template"]`` is resolved.
      ``"_type"`` is copied verbatim. For the other fields:

      - a non-``None`` ``"value"`` is used as-is;
      - a field whose ``"type"`` string contains ``"dict"`` becomes ``{}``;
      - anything else is built recursively from the children returned by
        ``graph.get_children_by_node_type(neighbour, field_type)`` for each
        neighbour. The result is a list of all of them when the field's
        ``"list"`` flag is truthy, otherwise the first one (or ``None``).

    Args:
        root: Vertex to start from. It must expose ``data``, plus ``edges``
            when ``data`` has no ``"node"`` entry.
        graph: Object providing ``get_nodes_with_target(node)`` and
            ``get_children_by_node_type(node, node_type)``.

    Returns:
        A new dict with the template's keys. The template itself is not
        modified.

    Raises:
        ValueError: If a field marked ``"required"`` has no matching child.
        IndexError: If ``root`` has no ``"node"`` data and no edges.
    """
    if "node" not in root.data:
        # No template of its own: follow the first outgoing edge.
        local_nodes = [root.edges[0].target]
    else:
        local_nodes = graph.get_nodes_with_target(root)

    if len(local_nodes) == 1:
        return build_json(local_nodes[0], graph)

    template = root.data["node"]["template"]
    # Shallow copy: every non-"_type" entry is replaced below, so the caller's
    # template dict is never mutated. Iterating ``template`` (not the copy)
    # avoids writing into the dict being iterated.
    final_dict = template.copy()

    for key, field in template.items():
        if key == "_type":
            continue

        field_type = field["type"]

        if "value" in field and field["value"] is not None:
            # An explicit value wins.
            resolved = field["value"]
        elif "dict" in field_type:
            # Dict-typed fields without a value default to an empty dict.
            resolved = {}
        else:
            # Collect matching children from every neighbouring node.
            children = [
                child
                for neighbour in local_nodes
                for child in graph.get_children_by_node_type(neighbour, field_type)
            ]

            if field["required"] and not children:
                msg = f"No child with type {field_type} found"
                raise ValueError(msg)

            # Every child is built, even when only the first one is kept.
            values = [build_json(child, graph) for child in children]
            resolved = values if field["list"] else next(iter(values), None)

        final_dict[key] = resolved

    return final_dict
|
|
|