github-actions[bot] committed
Commit 9a57b42 · Parent: 76b2991

Auto-sync from demo at Fri Jan 30 05:51:20 UTC 2026

graphgen/bases/__init__.py CHANGED
@@ -1,5 +1,6 @@
 from .base_evaluator import BaseKGEvaluator, BaseQAEvaluator, BaseTripleEvaluator
 from .base_extractor import BaseExtractor
+from .base_filter import BaseValueFilter
 from .base_generator import BaseGenerator
 from .base_kg_builder import BaseKGBuilder
 from .base_llm_wrapper import BaseLLMWrapper
graphgen/bases/base_filter.py ADDED
@@ -0,0 +1,30 @@
+from abc import ABC, abstractmethod
+from typing import Any, Union
+
+import numpy as np
+
+
+class BaseFilter(ABC):
+    @abstractmethod
+    def filter(self, data: Any) -> bool:
+        """
+        Filter the data and return True if it passes the filter, False otherwise.
+        """
+        raise NotImplementedError
+
+
+class BaseValueFilter(BaseFilter, ABC):
+    @abstractmethod
+    def filter(self, data: Union[int, float, np.number]) -> bool:
+        """
+        Filter the numeric value and return True if it passes the filter, False otherwise.
+        """
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def filter_type(self) -> str:
+        """
+        Return the type of filter (e.g., "greater_than", "less_than", etc.)
+        """
+        raise NotImplementedError
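
For readers skimming the interface, here is a minimal sketch of what a concrete subclass looks like. `GreaterThanFilter` is hypothetical, named after the `"greater_than"` hint in the docstring; it is not part of this commit.

```python
# Hypothetical illustration only; not part of this commit.
from typing import Union

import numpy as np

from graphgen.bases import BaseValueFilter


class GreaterThanFilter(BaseValueFilter):
    """Keeps values strictly greater than a threshold."""

    def __init__(self, threshold: float):
        self.threshold = threshold

    def filter(self, data: Union[int, float, np.number]) -> bool:
        return float(data) > self.threshold

    @property
    def filter_type(self) -> str:
        return "greater_than"
```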
graphgen/engine.py CHANGED
@@ -2,7 +2,6 @@ import inspect
 import logging
 import os
 from collections import defaultdict, deque
-from functools import wraps
 from typing import Any, Callable, Dict, List, Set
 
 import ray
@@ -103,7 +102,6 @@ class Engine:
         kv_namespaces = set()
         graph_namespaces = set()
 
-        # TODO: Temporarily hard-coded; node storage will be centrally managed later.
         for node in self.config.nodes:
             op_name = node.op_name
             if self._function_needs_param(op_name, "kv_backend"):
@@ -232,62 +230,38 @@ class Engine:
 
         input_ds = self._get_input_dataset(node, initial_ds)
 
-        if inspect.isclass(op_handler):
-            execution_params = node.execution_params or {}
-            replicas = execution_params.get("replicas", 1)
-            batch_size = (
-                int(execution_params.get("batch_size"))
-                if "batch_size" in execution_params
-                else "default"
-            )
-            compute_resources = execution_params.get("compute_resources", {})
-
-            if node.type == "aggregate":
-                self.datasets[node.id] = input_ds.repartition(1).map_batches(
-                    op_handler,
-                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
-                    batch_size=None,  # aggregate processes the whole dataset at once
-                    num_gpus=compute_resources.get("num_gpus", 0)
-                    if compute_resources
-                    else 0,
-                    fn_constructor_kwargs=node_params,
-                    batch_format="pandas",
-                )
-            else:
-                # others like map, filter, flatmap, map_batch let actors process data inside batches
-                self.datasets[node.id] = input_ds.map_batches(
-                    op_handler,
-                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
-                    batch_size=batch_size,
-                    num_gpus=compute_resources.get("num_gpus", 0)
-                    if compute_resources
-                    else 0,
-                    fn_constructor_kwargs=node_params,
-                    batch_format="pandas",
-                )
-
-        else:
-
-            @wraps(op_handler)
-            def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
-                return op_handler(row_or_batch, **node_params)
-
-            if node.type == "map":
-                self.datasets[node.id] = input_ds.map(func_wrapper)
-            elif node.type == "filter":
-                self.datasets[node.id] = input_ds.filter(func_wrapper)
-            elif node.type == "flatmap":
-                self.datasets[node.id] = input_ds.flat_map(func_wrapper)
-            elif node.type == "aggregate":
-                self.datasets[node.id] = input_ds.repartition(1).map_batches(
-                    func_wrapper, batch_format="default"
-                )
-            elif node.type == "map_batch":
-                self.datasets[node.id] = input_ds.map_batches(func_wrapper)
-            else:
-                raise ValueError(
-                    f"Unsupported node type {node.type} for node {node.id}"
-                )
+        # if inspect.isclass(op_handler):
+        execution_params = node.execution_params or {}
+        replicas = execution_params.get("replicas", 1)
+        batch_size = (
+            int(execution_params.get("batch_size"))
+            if "batch_size" in execution_params
+            else "default"
+        )
+        compute_resources = execution_params.get("compute_resources", {})
+
+        if node.type == "aggregate":
+            self.datasets[node.id] = input_ds.repartition(1).map_batches(
+                op_handler,
+                compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
+                batch_size=None,  # aggregate processes the whole dataset at once
+                num_gpus=compute_resources.get("num_gpus", 0)
+                if compute_resources
+                else 0,
+                fn_constructor_kwargs=node_params,
+                batch_format="pandas",
+            )
+        else:
+            self.datasets[node.id] = input_ds.map_batches(
+                op_handler,
+                compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
+                batch_size=batch_size,
+                num_gpus=compute_resources.get("num_gpus", 0)
+                if compute_resources
+                else 0,
+                fn_constructor_kwargs=node_params,
+                batch_format="pandas",
+            )
 
     def execute(
         self, initial_ds: ray.data.Dataset, output_dir: str
@@ -315,6 +289,14 @@ class Engine:
             logger.info("Node %s output saved to %s", node.id, node_output_path)
 
             # ray will lazy read the dataset
-            self.datasets[node.id] = ray.data.read_json(node_output_path)
+            if os.path.exists(node_output_path) and os.listdir(node_output_path):
+                self.datasets[node.id] = ray.data.read_json(node_output_path)
+            else:
+                self.datasets[node.id] = ray.data.from_items([])
+                logger.warning(
+                    "Node %s output path %s is empty. Created an empty dataset.",
+                    node.id,
+                    node_output_path,
+                )
 
         return self.datasets
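
The refactor funnels every operator through Ray Data's callable-class path: `map_batches` constructs one handler instance per pool actor and forwards `node_params` via `fn_constructor_kwargs`. A minimal, self-contained sketch of that pattern (the `DoubleValues` class and the `value` column are illustrative, not from GraphGen):

```python
import pandas as pd
import ray


class DoubleValues:
    """Callable class: Ray Data builds one instance per pool actor."""

    def __init__(self, factor: int = 2):
        # Constructor kwargs arrive here via fn_constructor_kwargs.
        self.factor = factor

    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
        batch["value"] = batch["value"] * self.factor
        return batch


ds = ray.data.from_items([{"value": i} for i in range(8)])
out = ds.map_batches(
    DoubleValues,
    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=2),
    batch_size=4,
    fn_constructor_kwargs={"factor": 3},
    batch_format="pandas",
)
print(out.take_all())  # each value multiplied by 3
```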
graphgen/models/__init__.py CHANGED
@@ -6,6 +6,7 @@ from .evaluator import (
     StructureEvaluator,
     UniEvaluator,
 )
+from .filter import RangeFilter
 from .generator import (
     AggregatedGenerator,
     AtomicGenerator,
graphgen/models/filter/__init__.py ADDED
@@ -0,0 +1 @@
+from .range_filter import RangeFilter
graphgen/models/filter/range_filter.py ADDED
@@ -0,0 +1,40 @@
+from typing import Union
+
+import numpy as np
+
+from graphgen.bases import BaseValueFilter
+
+
+class RangeFilter(BaseValueFilter):
+    """
+    keeps values within a specified range [min_val, max_val] (inclusive or exclusive)
+    """
+
+    def __init__(
+        self,
+        min_val: float,
+        max_val: float,
+        left_inclusive: bool = True,
+        right_inclusive: bool = True,
+    ):
+        self.min_val = min_val
+        self.max_val = max_val
+        self.left_inclusive = left_inclusive
+        self.right_inclusive = right_inclusive
+
+    def filter(self, data: Union[int, float, np.number]) -> bool:
+        value = float(data)
+        if self.left_inclusive and self.right_inclusive:
+            return self.min_val <= value <= self.max_val
+        if self.left_inclusive and not self.right_inclusive:
+            return self.min_val <= value < self.max_val
+        if not self.left_inclusive and self.right_inclusive:
+            return self.min_val < value <= self.max_val
+        return self.min_val < value < self.max_val
+
+    @property
+    def filter_type(self) -> str:
+        return "range"
+
+    def __repr__(self) -> str:
+        return f"RangeFilter({self.min_val}, {self.max_val})"
graphgen/operators/__init__.py CHANGED
@@ -2,6 +2,7 @@ from .build_kg import BuildKGService
 from .chunk import ChunkService
 from .evaluate import EvaluateService
 from .extract import ExtractService
+from .filter import FilterService
 from .generate import GenerateService
 from .judge import JudgeService
 from .partition import PartitionService
@@ -9,7 +10,6 @@ from .quiz import QuizService
 from .read import read
 from .search import SearchService
 
-
 operators = {
     "read": read,
     "chunk": ChunkService,
@@ -21,4 +21,5 @@ operators = {
     "partition": PartitionService,
     "generate": GenerateService,
     "evaluate": EvaluateService,
+    "filter": FilterService,
 }
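
With this wiring, the engine resolves a node whose `op_name` is `filter` to the new service class like any other operator. A one-line sanity check:

```python
from graphgen.operators import FilterService, operators

# The engine instantiates the class it finds here, passing node params
# through fn_constructor_kwargs (see the engine.py hunk above).
assert operators["filter"] is FilterService
```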
graphgen/operators/filter/__init__.py ADDED
@@ -0,0 +1 @@
+from .filter_service import FilterService
graphgen/operators/filter/filter_service.py ADDED
@@ -0,0 +1,49 @@
+from typing import Tuple
+
+from graphgen.bases import BaseOperator
+from graphgen.utils import logger
+
+
+class FilterService(BaseOperator):
+    def __init__(
+        self, working_dir: str = "cache", kv_backend: str = "rocksdb", **filter_kwargs
+    ):
+        super().__init__(
+            working_dir=working_dir, kv_backend=kv_backend, op_name="filter"
+        )
+        method = filter_kwargs["method"]
+        method_params = filter_kwargs["method_params"]
+        self.metric = method_params["metric"]
+        if method == "range":
+            from graphgen.models import RangeFilter
+
+            self.filter_instance = RangeFilter(
+                min_val=method_params["min_val"],
+                max_val=method_params["max_val"],
+                left_inclusive=method_params.get("left_inclusive", True),
+                right_inclusive=method_params.get("right_inclusive", True),
+            )
+        else:
+            raise ValueError(f"Unsupported filter method: {method}")
+
+    def process(self, batch: list) -> Tuple[list, dict]:
+        """
+        Filter the items in the batch.
+        :return: A tuple of (results, meta_updates)
+            results: A list of filtered items.
+            meta_updates: empty as filtering does not create new items.
+        """
+        results = []
+        meta_updates = {}
+
+        for item in batch:
+            value = item["metrics"].get(self.metric)
+            if value is None:
+                logger.warning(
+                    f"Item {item} does not have metric {self.metric}. Skipping."
+                )
+                continue
+            if self.filter_instance.filter(value):
+                results.append(item)
+
+        return results, meta_updates
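
To make the contract concrete, a toy walk-through of `process`, assuming the default `working_dir`/`kv_backend` setup succeeds locally; the metric name `score` and the item shape are illustrative:

```python
from graphgen.operators import FilterService

svc = FilterService(
    method="range",
    method_params={"metric": "score", "min_val": 0.5, "max_val": 1.0},
)
batch = [
    {"id": 1, "metrics": {"score": 0.9}},  # kept: 0.5 <= 0.9 <= 1.0
    {"id": 2, "metrics": {"score": 0.1}},  # dropped: below min_val
    {"id": 3, "metrics": {}},              # missing metric: warned and skipped
]
kept, meta = svc.process(batch)
# kept == [{"id": 1, "metrics": {"score": 0.9}}]; meta == {}
```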