Spaces:
Running
Running
github-actions[bot]
committed on
Commit
·
ac15317
1
Parent(s):
d0887fa
Auto-sync from demo at Wed Dec 24 10:52:13 UTC 2025
Browse files
- graphgen/bases/datatypes.py +5 -1
- graphgen/engine.py +7 -19
graphgen/bases/datatypes.py
CHANGED
|
@@ -63,7 +63,11 @@ class Node(BaseModel):
|
|
| 63 |
default_factory=list, description="list of dependent node ids"
|
| 64 |
)
|
| 65 |
execution_params: dict = Field(
|
| 66 |
-
default_factory=dict,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
)
|
| 68 |
|
| 69 |
@classmethod
|
|
|
|
| 63 |
default_factory=list, description="list of dependent node ids"
|
| 64 |
)
|
| 65 |
execution_params: dict = Field(
|
| 66 |
+
default_factory=dict,
|
| 67 |
+
description="execution parameters like replicas, batch_size",
|
| 68 |
+
)
|
| 69 |
+
save_output: bool = Field(
|
| 70 |
+
default=False, description="whether to save the output of this node"
|
| 71 |
)
|
| 72 |
|
| 73 |
@classmethod
|
graphgen/engine.py
CHANGED
|
@@ -1,21 +1,22 @@
|
|
| 1 |
-
import os
|
| 2 |
import inspect
|
| 3 |
import logging
|
|
|
|
| 4 |
from collections import defaultdict, deque
|
| 5 |
from functools import wraps
|
| 6 |
from typing import Any, Callable, Dict, List, Set
|
| 7 |
-
from dotenv import load_dotenv
|
| 8 |
|
| 9 |
import ray
|
| 10 |
import ray.data
|
|
|
|
| 11 |
from ray.data import DataContext
|
| 12 |
|
| 13 |
from graphgen.bases import Config, Node
|
| 14 |
-
from graphgen.utils import logger
|
| 15 |
from graphgen.common import init_llm, init_storage
|
|
|
|
| 16 |
|
| 17 |
load_dotenv()
|
| 18 |
|
|
|
|
| 19 |
class Engine:
|
| 20 |
def __init__(
|
| 21 |
self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
|
|
@@ -42,7 +43,7 @@ class Engine:
|
|
| 42 |
existing_env_vars = ray_init_kwargs["runtime_env"].get("env_vars", {})
|
| 43 |
ray_init_kwargs["runtime_env"]["env_vars"] = {
|
| 44 |
**all_env_vars,
|
| 45 |
-
**existing_env_vars
|
| 46 |
}
|
| 47 |
|
| 48 |
if not ray.is_initialized():
|
|
@@ -265,24 +266,11 @@ class Engine:
|
|
| 265 |
f"Unsupported node type {node.type} for node {node.id}"
|
| 266 |
)
|
| 267 |
|
| 268 |
-
@staticmethod
|
| 269 |
-
def _find_leaf_nodes(nodes: List[Node]) -> Set[str]:
|
| 270 |
-
all_ids = {n.id for n in nodes}
|
| 271 |
-
deps_set = set()
|
| 272 |
-
for n in nodes:
|
| 273 |
-
deps_set.update(n.dependencies)
|
| 274 |
-
return all_ids - deps_set
|
| 275 |
-
|
| 276 |
def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
|
| 277 |
sorted_nodes = self._topo_sort(self.config.nodes)
|
| 278 |
|
| 279 |
for node in sorted_nodes:
|
| 280 |
self._execute_node(node, initial_ds)
|
| 281 |
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
@ray.remote
|
| 285 |
-
def _fetch_result(ds: ray.data.Dataset) -> List[Any]:
|
| 286 |
-
return ds.take_all()
|
| 287 |
-
|
| 288 |
-
return {node_id: self.datasets[node_id] for node_id in leaf_nodes}
|
|
|
|
|
|
|
| 1 |
import inspect
|
| 2 |
import logging
|
| 3 |
+
import os
|
| 4 |
from collections import defaultdict, deque
|
| 5 |
from functools import wraps
|
| 6 |
from typing import Any, Callable, Dict, List, Set
|
|
|
|
| 7 |
|
| 8 |
import ray
|
| 9 |
import ray.data
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
from ray.data import DataContext
|
| 12 |
|
| 13 |
from graphgen.bases import Config, Node
|
|
|
|
| 14 |
from graphgen.common import init_llm, init_storage
|
| 15 |
+
from graphgen.utils import logger
|
| 16 |
|
| 17 |
load_dotenv()
|
| 18 |
|
| 19 |
+
|
| 20 |
class Engine:
|
| 21 |
def __init__(
|
| 22 |
self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
|
|
|
|
| 43 |
existing_env_vars = ray_init_kwargs["runtime_env"].get("env_vars", {})
|
| 44 |
ray_init_kwargs["runtime_env"]["env_vars"] = {
|
| 45 |
**all_env_vars,
|
| 46 |
+
**existing_env_vars,
|
| 47 |
}
|
| 48 |
|
| 49 |
if not ray.is_initialized():
|
|
|
|
| 266 |
f"Unsupported node type {node.type} for node {node.id}"
|
| 267 |
)
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
|
| 270 |
sorted_nodes = self._topo_sort(self.config.nodes)
|
| 271 |
|
| 272 |
for node in sorted_nodes:
|
| 273 |
self._execute_node(node, initial_ds)
|
| 274 |
|
| 275 |
+
output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)]
|
| 276 |
+
return {node.id: self.datasets[node.id] for node in output_nodes}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|