github-actions[bot] committed on
Commit
ac15317
·
1 Parent(s): d0887fa

Auto-sync from demo at Wed Dec 24 10:52:13 UTC 2025

Browse files
Files changed (2) hide show
  1. graphgen/bases/datatypes.py +5 -1
  2. graphgen/engine.py +7 -19
graphgen/bases/datatypes.py CHANGED
@@ -63,7 +63,11 @@ class Node(BaseModel):
63
  default_factory=list, description="list of dependent node ids"
64
  )
65
  execution_params: dict = Field(
66
- default_factory=dict, description="execution parameters like replicas, batch_size"
 
 
 
 
67
  )
68
 
69
  @classmethod
 
63
  default_factory=list, description="list of dependent node ids"
64
  )
65
  execution_params: dict = Field(
66
+ default_factory=dict,
67
+ description="execution parameters like replicas, batch_size",
68
+ )
69
+ save_output: bool = Field(
70
+ default=False, description="whether to save the output of this node"
71
  )
72
 
73
  @classmethod
graphgen/engine.py CHANGED
@@ -1,21 +1,22 @@
1
- import os
2
  import inspect
3
  import logging
 
4
  from collections import defaultdict, deque
5
  from functools import wraps
6
  from typing import Any, Callable, Dict, List, Set
7
- from dotenv import load_dotenv
8
 
9
  import ray
10
  import ray.data
 
11
  from ray.data import DataContext
12
 
13
  from graphgen.bases import Config, Node
14
- from graphgen.utils import logger
15
  from graphgen.common import init_llm, init_storage
 
16
 
17
  load_dotenv()
18
 
 
19
  class Engine:
20
  def __init__(
21
  self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
@@ -42,7 +43,7 @@ class Engine:
42
  existing_env_vars = ray_init_kwargs["runtime_env"].get("env_vars", {})
43
  ray_init_kwargs["runtime_env"]["env_vars"] = {
44
  **all_env_vars,
45
- **existing_env_vars
46
  }
47
 
48
  if not ray.is_initialized():
@@ -265,24 +266,11 @@ class Engine:
265
  f"Unsupported node type {node.type} for node {node.id}"
266
  )
267
 
268
- @staticmethod
269
- def _find_leaf_nodes(nodes: List[Node]) -> Set[str]:
270
- all_ids = {n.id for n in nodes}
271
- deps_set = set()
272
- for n in nodes:
273
- deps_set.update(n.dependencies)
274
- return all_ids - deps_set
275
-
276
  def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
277
  sorted_nodes = self._topo_sort(self.config.nodes)
278
 
279
  for node in sorted_nodes:
280
  self._execute_node(node, initial_ds)
281
 
282
- leaf_nodes = self._find_leaf_nodes(sorted_nodes)
283
-
284
- @ray.remote
285
- def _fetch_result(ds: ray.data.Dataset) -> List[Any]:
286
- return ds.take_all()
287
-
288
- return {node_id: self.datasets[node_id] for node_id in leaf_nodes}
 
 
1
  import inspect
2
  import logging
3
+ import os
4
  from collections import defaultdict, deque
5
  from functools import wraps
6
  from typing import Any, Callable, Dict, List, Set
 
7
 
8
  import ray
9
  import ray.data
10
+ from dotenv import load_dotenv
11
  from ray.data import DataContext
12
 
13
  from graphgen.bases import Config, Node
 
14
  from graphgen.common import init_llm, init_storage
15
+ from graphgen.utils import logger
16
 
17
  load_dotenv()
18
 
19
+
20
  class Engine:
21
  def __init__(
22
  self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
 
43
  existing_env_vars = ray_init_kwargs["runtime_env"].get("env_vars", {})
44
  ray_init_kwargs["runtime_env"]["env_vars"] = {
45
  **all_env_vars,
46
+ **existing_env_vars,
47
  }
48
 
49
  if not ray.is_initialized():
 
266
  f"Unsupported node type {node.type} for node {node.id}"
267
  )
268
 
 
 
 
 
 
 
 
 
269
  def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
270
  sorted_nodes = self._topo_sort(self.config.nodes)
271
 
272
  for node in sorted_nodes:
273
  self._execute_node(node, initial_ds)
274
 
275
+ output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)]
276
+ return {node.id: self.datasets[node.id] for node in output_nodes}