github-actions[bot] committed on
Commit 5445ab9 · 1 Parent(s): 0a18089

Auto-sync from demo at Fri Jan 30 10:55:05 UTC 2026

graphgen/bases/base_filter.py CHANGED
@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
-from typing import Any, Union
+from typing import TYPE_CHECKING, Any, Union
 
-import numpy as np
+if TYPE_CHECKING:
+    import numpy as np
 
 
 class BaseFilter(ABC):
@@ -15,7 +16,7 @@ class BaseFilter(ABC):
 
 class BaseValueFilter(BaseFilter, ABC):
     @abstractmethod
-    def filter(self, data: Union[int, float, np.number]) -> bool:
+    def filter(self, data: Union[int, float, "np.number"]) -> bool:
        """
        Filter the numeric value and return True if it passes the filter, False otherwise.
        """
graphgen/bases/base_operator.py CHANGED
@@ -1,14 +1,18 @@
+from __future__ import annotations
+
 import inspect
 import os
 from abc import ABC, abstractmethod
-from typing import Iterable, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, Tuple, Union
 
-import numpy as np
-import pandas as pd
-import ray
+if TYPE_CHECKING:
+    import numpy as np
+    import pandas as pd
 
 
 def convert_to_serializable(obj):
+    import numpy as np
+
     if isinstance(obj, np.ndarray):
         return obj.tolist()
     if isinstance(obj, np.generic):
@@ -40,6 +44,8 @@ class BaseOperator(ABC):
         )
 
         try:
+            import ray
+
             ctx = ray.get_runtime_context()
             worker_id = ctx.get_actor_id() or ctx.get_worker_id()
             worker_id_short = worker_id[-6:] if worker_id else "driver"
@@ -62,9 +68,11 @@ class BaseOperator(ABC):
         )
 
     def __call__(
-        self, batch: pd.DataFrame
-    ) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]:
+        self, batch: "pd.DataFrame"
+    ) -> Union["pd.DataFrame", Iterable["pd.DataFrame"]]:
         # lazy import to avoid circular import
+        import pandas as pd
+
         from graphgen.utils import CURRENT_LOGGER_VAR
 
         logger_token = CURRENT_LOGGER_VAR.set(self.logger)
@@ -106,7 +114,7 @@ class BaseOperator(ABC):
 
         return compute_dict_hash(content, prefix=f"{self.op_name}-")
 
-    def split(self, batch: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
+    def split(self, batch: "pd.DataFrame") -> tuple["pd.DataFrame", "pd.DataFrame"]:
        """
        Split the input batch into to_process & processed based on _meta data in KV_storage
        :param batch
@@ -114,6 +122,8 @@ class BaseOperator(ABC):
            to_process: DataFrame of documents to be chunked
            recovered: Result DataFrame of already chunked documents
        """
+        import pandas as pd
+
         meta_forward = self.get_meta_forward()
         meta_ids = set(meta_forward.keys())
         mask = batch["_trace_id"].isin(meta_ids)
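
The rewrite above moves numpy, pandas, and ray from module scope into the functions that use them. A function-level import pays the real import cost only once per process: after the first call the module object sits in sys.modules, and the import statement reduces to a cache lookup plus a local name binding. A small timing sketch of that behaviour (assumes pandas is installed; numbers vary by machine):

import sys
import time


def uses_pandas():
    # Deferred import: pandas' init code runs on the first call only.
    import pandas as pd

    return pd.DataFrame({"x": [1, 2, 3]})


t0 = time.perf_counter()
uses_pandas()  # first call triggers the real import
print(f"first call:  {time.perf_counter() - t0:.4f}s")

t0 = time.perf_counter()
uses_pandas()  # second call hits the sys.modules cache
print(f"second call: {time.perf_counter() - t0:.6f}s")

print("pandas" in sys.modules)  # True once the first call has run
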
graphgen/bases/base_reader.py CHANGED
@@ -1,10 +1,14 @@
+from __future__ import annotations
+
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Union
 
-import pandas as pd
 import requests
-from ray.data import Dataset
+
+if TYPE_CHECKING:
+    import pandas as pd
+    from ray.data import Dataset
 
 
 class BaseReader(ABC):
@@ -51,6 +55,7 @@ class BaseReader(ABC):
        """
        Validate data format.
        """
+
        if "type" not in batch.columns:
            raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}")
 
graphgen/common/init_llm.py CHANGED
@@ -1,11 +1,12 @@
 import os
-from typing import Any, Dict, Optional
-
-import ray
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 from graphgen.bases import BaseLLMWrapper
 from graphgen.models import Tokenizer
 
+if TYPE_CHECKING:
+    import ray
+
 
 class LLMServiceActor:
    """
@@ -73,7 +74,7 @@ class LLMServiceProxy(BaseLLMWrapper):
     A proxy class to interact with the LLMServiceActor for distributed LLM operations.
    """
 
-    def __init__(self, actor_handle: ray.actor.ActorHandle):
+    def __init__(self, actor_handle: "ray.actor.ActorHandle"):
         super().__init__()
         self.actor_handle = actor_handle
         self._create_local_tokenizer()
@@ -120,6 +121,8 @@ class LLMFactory:
     def create_llm(
         model_type: str, backend: str, config: Dict[str, Any]
     ) -> BaseLLMWrapper:
+        import ray
+
         if not config:
             raise ValueError(
                 f"No configuration provided for LLM {model_type} with backend {backend}."
graphgen/common/init_storage.py CHANGED
@@ -146,7 +146,7 @@ class GraphStorageActor:
 
 
 class RemoteKVStorageProxy(BaseKVStorage):
-    def __init__(self, actor_handle: ray.actor.ActorHandle):
+    def __init__(self, actor_handle: "ray.actor.ActorHandle"):
         super().__init__()
         self.actor = actor_handle
 
@@ -202,68 +202,87 @@ class RemoteGraphStorageProxy(BaseGraphStorage):
         return ray.get(self.actor.get_all_node_degrees.remote())
 
     def get_node_count(self) -> int:
+
         return ray.get(self.actor.get_node_count.remote())
 
     def get_edge_count(self) -> int:
+
         return ray.get(self.actor.get_edge_count.remote())
 
     def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
+
         return ray.get(self.actor.get_connected_components.remote(undirected))
 
     def has_node(self, node_id: str) -> bool:
+
         return ray.get(self.actor.has_node.remote(node_id))
 
     def has_edge(self, source_node_id: str, target_node_id: str):
+
         return ray.get(self.actor.has_edge.remote(source_node_id, target_node_id))
 
     def node_degree(self, node_id: str) -> int:
+
         return ray.get(self.actor.node_degree.remote(node_id))
 
     def edge_degree(self, src_id: str, tgt_id: str) -> int:
+
         return ray.get(self.actor.edge_degree.remote(src_id, tgt_id))
 
     def get_node(self, node_id: str) -> Any:
+
         return ray.get(self.actor.get_node.remote(node_id))
 
     def update_node(self, node_id: str, node_data: dict[str, str]):
+
         return ray.get(self.actor.update_node.remote(node_id, node_data))
 
     def get_all_nodes(self) -> Any:
+
         return ray.get(self.actor.get_all_nodes.remote())
 
     def get_edge(self, source_node_id: str, target_node_id: str):
+
         return ray.get(self.actor.get_edge.remote(source_node_id, target_node_id))
 
     def update_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ):
+
         return ray.get(
             self.actor.update_edge.remote(source_node_id, target_node_id, edge_data)
         )
 
     def get_all_edges(self) -> Any:
+
         return ray.get(self.actor.get_all_edges.remote())
 
     def get_node_edges(self, source_node_id: str) -> Any:
+
         return ray.get(self.actor.get_node_edges.remote(source_node_id))
 
     def upsert_node(self, node_id: str, node_data: dict[str, str]):
+
         return ray.get(self.actor.upsert_node.remote(node_id, node_data))
 
     def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ):
+
         return ray.get(
             self.actor.upsert_edge.remote(source_node_id, target_node_id, edge_data)
         )
 
     def delete_node(self, node_id: str):
+
         return ray.get(self.actor.delete_node.remote(node_id))
 
     def get_neighbors(self, node_id: str) -> List[str]:
+
         return ray.get(self.actor.get_neighbors.remote(node_id))
 
     def reload(self):
+
         return ray.get(self.actor.reload.remote())
 
 
@@ -274,6 +293,7 @@ class StorageFactory:
 
     @staticmethod
     def create_storage(backend: str, working_dir: str, namespace: str):
+
         if backend in ["json_kv", "rocksdb"]:
             actor_name = f"Actor_KV_{namespace}"
             actor_class = KVStorageActor
graphgen/models/__init__.py CHANGED
@@ -1,48 +1,121 @@
-from .evaluator import (
-    AccuracyEvaluator,
-    LengthEvaluator,
-    MTLDEvaluator,
-    RewardEvaluator,
-    StructureEvaluator,
-    UniEvaluator,
-)
-from .filter import RangeFilter
-from .generator import (
-    AggregatedGenerator,
-    AtomicGenerator,
-    CoTGenerator,
-    FillInBlankGenerator,
-    MultiAnswerGenerator,
-    MultiChoiceGenerator,
-    MultiHopGenerator,
-    QuizGenerator,
-    TrueFalseGenerator,
-    VQAGenerator,
-)
-from .kg_builder import LightRAGKGBuilder, MMKGBuilder
-from .llm import HTTPClient, OllamaClient, OpenAIClient
-from .partitioner import (
-    AnchorBFSPartitioner,
-    BFSPartitioner,
-    DFSPartitioner,
-    ECEPartitioner,
-    LeidenPartitioner,
-)
-from .reader import (
-    CSVReader,
-    JSONReader,
-    ParquetReader,
-    PDFReader,
-    PickleReader,
-    RDFReader,
-    TXTReader,
-)
-from .rephraser import StyleControlledRephraser
-from .searcher.db.ncbi_searcher import NCBISearch
-from .searcher.db.rnacentral_searcher import RNACentralSearch
-from .searcher.db.uniprot_searcher import UniProtSearch
-from .searcher.kg.wiki_search import WikiSearch
-from .searcher.web.bing_search import BingSearch
-from .searcher.web.google_search import GoogleSearch
-from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
-from .tokenizer import Tokenizer
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .evaluator import (
+        AccuracyEvaluator,
+        LengthEvaluator,
+        MTLDEvaluator,
+        RewardEvaluator,
+        StructureEvaluator,
+        UniEvaluator,
+    )
+    from .filter import RangeFilter
+    from .generator import (
+        AggregatedGenerator,
+        AtomicGenerator,
+        CoTGenerator,
+        FillInBlankGenerator,
+        MultiAnswerGenerator,
+        MultiChoiceGenerator,
+        MultiHopGenerator,
+        QuizGenerator,
+        TrueFalseGenerator,
+        VQAGenerator,
+    )
+    from .kg_builder import LightRAGKGBuilder, MMKGBuilder
+    from .llm import HTTPClient, OllamaClient, OpenAIClient
+    from .partitioner import (
+        AnchorBFSPartitioner,
+        BFSPartitioner,
+        DFSPartitioner,
+        ECEPartitioner,
+        LeidenPartitioner,
+    )
+    from .reader import (
+        CSVReader,
+        JSONReader,
+        ParquetReader,
+        PDFReader,
+        PickleReader,
+        RDFReader,
+        TXTReader,
+    )
+    from .rephraser import StyleControlledRephraser
+    from .searcher.db.ncbi_searcher import NCBISearch
+    from .searcher.db.rnacentral_searcher import RNACentralSearch
+    from .searcher.db.uniprot_searcher import UniProtSearch
+    from .searcher.kg.wiki_search import WikiSearch
+    from .searcher.web.bing_search import BingSearch
+    from .searcher.web.google_search import GoogleSearch
+    from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
+    from .tokenizer import Tokenizer
+
+_import_map = {
+    # Evaluator
+    "AccuracyEvaluator": ".evaluator",
+    "LengthEvaluator": ".evaluator",
+    "MTLDEvaluator": ".evaluator",
+    "RewardEvaluator": ".evaluator",
+    "StructureEvaluator": ".evaluator",
+    "UniEvaluator": ".evaluator",
+    # Filter
+    "RangeFilter": ".filter",
+    # Generator
+    "AggregatedGenerator": ".generator",
+    "AtomicGenerator": ".generator",
+    "CoTGenerator": ".generator",
+    "FillInBlankGenerator": ".generator",
+    "MultiAnswerGenerator": ".generator",
+    "MultiChoiceGenerator": ".generator",
+    "MultiHopGenerator": ".generator",
+    "QuizGenerator": ".generator",
+    "TrueFalseGenerator": ".generator",
+    "VQAGenerator": ".generator",
+    # KG Builder
+    "LightRAGKGBuilder": ".kg_builder",
+    "MMKGBuilder": ".kg_builder",
+    # LLM
+    "HTTPClient": ".llm",
+    "OllamaClient": ".llm",
+    "OpenAIClient": ".llm",
+    # Partitioner
+    "AnchorBFSPartitioner": ".partitioner",
+    "BFSPartitioner": ".partitioner",
+    "DFSPartitioner": ".partitioner",
+    "ECEPartitioner": ".partitioner",
+    "LeidenPartitioner": ".partitioner",
+    # Reader
+    "CSVReader": ".reader",
+    "JSONReader": ".reader",
+    "ParquetReader": ".reader",
+    "PDFReader": ".reader",
+    "PickleReader": ".reader",
+    "RDFReader": ".reader",
+    "TXTReader": ".reader",
+    # Searcher
+    "NCBISearch": ".searcher.db.ncbi_searcher",
+    "RNACentralSearch": ".searcher.db.rnacentral_searcher",
+    "UniProtSearch": ".searcher.db.uniprot_searcher",
+    "WikiSearch": ".searcher.kg.wiki_search",
+    "BingSearch": ".searcher.web.bing_search",
+    "GoogleSearch": ".searcher.web.google_search",
+    # Splitter
+    "ChineseRecursiveTextSplitter": ".splitter",
+    "RecursiveCharacterSplitter": ".splitter",
+    # Tokenizer
+    "Tokenizer": ".tokenizer",
+    # Rephraser
+    "StyleControlledRephraser": ".rephraser",
+}
+
+
+def __getattr__(name):
+    if name in _import_map:
+        import importlib
+
+        module = importlib.import_module(_import_map[name], package=__name__)
+        return getattr(module, name)
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+__all__ = list(_import_map.keys())
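
This is the core of the commit: graphgen.models no longer imports every submodule eagerly. It relies on module-level __getattr__ (PEP 562, Python 3.7+): `import graphgen.models` now only builds _import_map, and each exported name is resolved on first attribute access; this also makes `from graphgen.models import RangeFilter` lazy, since `from ... import` falls back to the module's __getattr__. The if TYPE_CHECKING: block keeps type checkers and IDE completion aware of the re-exported names. A stripped-down sketch of the same mechanism, for a hypothetical package lazy_pkg with submodules foo and bar:

# lazy_pkg/__init__.py
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .bar import Bar  # visible to type checkers only
    from .foo import Foo

_import_map = {"Foo": ".foo", "Bar": ".bar"}


def __getattr__(name):
    # PEP 562: invoked only when `name` is not found on the module, so
    # each submodule is imported at most once (sys.modules caches it).
    if name in _import_map:
        import importlib

        module = importlib.import_module(_import_map[name], package=__name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


__all__ = list(_import_map.keys())

The trade-off is that a misspelled name now fails at first access rather than at import time, and an import error inside a submodule surfaces only when one of its names is first touched.
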
graphgen/models/evaluator/kg/structure_evaluator.py CHANGED
@@ -1,9 +1,6 @@
 from collections import Counter
 from typing import Any, Dict, Optional
 
-import numpy as np
-from scipy import stats
-
 from graphgen.bases import BaseGraphStorage, BaseKGEvaluator
 from graphgen.utils import logger
 
@@ -75,6 +72,9 @@ class StructureEvaluator(BaseKGEvaluator):
 
     @staticmethod
     def _calculate_powerlaw_r2(degree_map: Dict[str, int]) -> Optional[float]:
+        import numpy as np
+        from scipy import stats
+
         degrees = [deg for deg in degree_map.values() if deg > 0]
 
         if len(degrees) < 10:
graphgen/models/filter/range_filter.py CHANGED
@@ -1,9 +1,10 @@
-from typing import Union
-
-import numpy as np
+from typing import TYPE_CHECKING, Union
 
 from graphgen.bases import BaseValueFilter
 
+if TYPE_CHECKING:
+    import numpy as np
+
 
 class RangeFilter(BaseValueFilter):
    """
@@ -22,7 +23,7 @@ class RangeFilter(BaseValueFilter):
         self.left_inclusive = left_inclusive
         self.right_inclusive = right_inclusive
 
-    def filter(self, data: Union[int, float, np.number]) -> bool:
+    def filter(self, data: Union[int, float, "np.number"]) -> bool:
         value = float(data)
         if self.left_inclusive and self.right_inclusive:
             return self.min_val <= value <= self.max_val
graphgen/models/llm/__init__.py CHANGED
@@ -1,4 +1,27 @@
-from .api.http_client import HTTPClient
-from .api.ollama_client import OllamaClient
-from .api.openai_client import OpenAIClient
-from .local.hf_wrapper import HuggingFaceWrapper
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .api.http_client import HTTPClient
+    from .api.ollama_client import OllamaClient
+    from .api.openai_client import OpenAIClient
+    from .local.hf_wrapper import HuggingFaceWrapper
+
+
+_import_map = {
+    "HTTPClient": ".api.http_client",
+    "OllamaClient": ".api.ollama_client",
+    "OpenAIClient": ".api.openai_client",
+    "HuggingFaceWrapper": ".local.hf_wrapper",
+}
+
+
+def __getattr__(name):
+    if name in _import_map:
+        import importlib
+
+        module = importlib.import_module(_import_map[name], package=__name__)
+        return getattr(module, name)
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+__all__ = list(_import_map.keys())
graphgen/models/partitioner/leiden_partitioner.py CHANGED
@@ -1,12 +1,12 @@
 from collections import defaultdict
-from typing import Any, Dict, List, Set, Tuple
-
-import igraph as ig
-from leidenalg import ModularityVertexPartition, find_partition
+from typing import TYPE_CHECKING, Any, Dict, List, Set, Tuple
 
 from graphgen.bases import BaseGraphStorage, BasePartitioner
 from graphgen.bases.datatypes import Community
 
+if TYPE_CHECKING:
+    import igraph as ig
+
 
 class LeidenPartitioner(BasePartitioner):
    """
@@ -62,6 +62,9 @@ class LeidenPartitioner(BasePartitioner):
         use_lcc: bool = False,
         random_seed: int = 42,
     ) -> Dict[str, int]:
+        import igraph as ig
+        from leidenalg import ModularityVertexPartition, find_partition
+
         # build igraph
         ig_graph = ig.Graph.TupleList(((u, v) for u, v, _ in edges), directed=False)
 
graphgen/models/reader/csv_reader.py CHANGED
@@ -1,10 +1,11 @@
-from typing import List, Union
-
-import ray
-from ray.data import Dataset
+from typing import TYPE_CHECKING, List, Union
 
 from graphgen.bases.base_reader import BaseReader
 
+if TYPE_CHECKING:
+    import ray
+    from ray.data import Dataset
+
 
 class CSVReader(BaseReader):
    """
@@ -14,13 +15,14 @@ class CSVReader(BaseReader):
     - if type is "text", "content" column must be present.
    """
 
-    def read(self, input_path: Union[str, List[str]]) -> Dataset:
+    def read(self, input_path: Union[str, List[str]]) -> "Dataset":
        """
        Read CSV files and return Ray Dataset.
 
        :param input_path: Path to CSV file or list of CSV files.
        :return: Ray Dataset containing validated and filtered data.
        """
+        import ray
 
         ds = ray.data.read_csv(input_path, include_paths=True)
         ds = ds.map_batches(self._validate_batch, batch_format="pandas")
graphgen/models/reader/json_reader.py CHANGED
@@ -1,11 +1,12 @@
 import json
-from typing import List, Union
-
-import ray
-import ray.data
+from typing import TYPE_CHECKING, List, Union
 
 from graphgen.bases.base_reader import BaseReader
 
+if TYPE_CHECKING:
+    import ray
+    import ray.data
+
 
 class JSONReader(BaseReader):
    """
@@ -15,12 +16,14 @@ class JSONReader(BaseReader):
     - if type is "text", "content" column must be present.
    """
 
-    def read(self, input_path: Union[str, List[str]]) -> ray.data.Dataset:
+    def read(self, input_path: Union[str, List[str]]) -> "ray.data.Dataset":
        """
        Read JSON file and return Ray Dataset.
        :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files.
        :return: Ray Dataset containing validated and filtered data.
        """
+        import ray
+
         if self.modalities and len(self.modalities) >= 2:
             ds: ray.data.Dataset = ray.data.from_items([])
             for file in input_path if isinstance(input_path, list) else [input_path]:
graphgen/models/reader/parquet_reader.py CHANGED
@@ -1,10 +1,11 @@
-from typing import List, Union
-
-import ray
-from ray.data import Dataset
+from typing import TYPE_CHECKING, List, Union
 
 from graphgen.bases.base_reader import BaseReader
 
+if TYPE_CHECKING:
+    import ray
+    from ray.data import Dataset
+
 
 class ParquetReader(BaseReader):
    """
@@ -14,13 +15,15 @@ class ParquetReader(BaseReader):
     - if type is "text", "content" column must be present.
    """
 
-    def read(self, input_path: Union[str, List[str]]) -> Dataset:
+    def read(self, input_path: Union[str, List[str]]) -> "Dataset":
        """
        Read Parquet files using Ray Data.
 
        :param input_path: Path to Parquet file or list of Parquet files.
        :return: Ray Dataset containing validated documents.
        """
+        import ray
+
         if not ray.is_initialized():
             ray.init()
 
graphgen/models/reader/pdf_reader.py CHANGED
@@ -3,15 +3,16 @@ import os
 import subprocess
 import tempfile
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
-
-import ray
-from ray.data import Dataset
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from graphgen.bases.base_reader import BaseReader
 from graphgen.models.reader.txt_reader import TXTReader
 from graphgen.utils import logger, pick_device
 
+if TYPE_CHECKING:
+    import ray
+    from ray.data import Dataset
+
 
 class PDFReader(BaseReader):
    """
@@ -69,7 +70,8 @@ class PDFReader(BaseReader):
         self,
         input_path: Union[str, List[str]],
         **override,
-    ) -> Dataset:
+    ) -> "Dataset":
+        import ray
 
         # Ensure input_path is a list
         if isinstance(input_path, str):
graphgen/models/reader/pickle_reader.py CHANGED
@@ -1,13 +1,13 @@
 import pickle
-from typing import List, Union
-
-import pandas as pd
-import ray
-from ray.data import Dataset
+from typing import TYPE_CHECKING, List, Union
 
 from graphgen.bases.base_reader import BaseReader
 from graphgen.utils import logger
 
+if TYPE_CHECKING:
+    import pandas as pd
+    from ray.data import Dataset
+
 
 class PickleReader(BaseReader):
    """
@@ -23,13 +23,16 @@ class PickleReader(BaseReader):
     def read(
         self,
         input_path: Union[str, List[str]],
-    ) -> Dataset:
+    ) -> "Dataset":
        """
        Read Pickle files using Ray Data.
 
        :param input_path: Path to pickle file or list of pickle files.
        :return: Ray Dataset containing validated documents.
        """
+        import pandas as pd
+        import ray
+
         if not ray.is_initialized():
             ray.init()
 
@@ -37,7 +40,7 @@ class PickleReader(BaseReader):
         ds = ray.data.read_binary_files(input_path, include_paths=True)
 
         # Deserialize pickle files and flatten into individual records
-        def deserialize_batch(batch: pd.DataFrame) -> pd.DataFrame:
+        def deserialize_batch(batch: "pd.DataFrame") -> "pd.DataFrame":
             all_records = []
             for _, row in batch.iterrows():
                 try:
graphgen/models/reader/rdf_reader.py CHANGED
@@ -1,15 +1,15 @@
 from pathlib import Path
-from typing import Any, Dict, List, Union
-
-import ray
-import rdflib
-from ray.data import Dataset
-from rdflib import Literal
-from rdflib.util import guess_format
+from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 from graphgen.bases.base_reader import BaseReader
 from graphgen.utils import logger
 
+if TYPE_CHECKING:
+    import ray
+    import rdflib
+    from ray.data import Dataset
+    from rdflib import Literal
+
 
 class RDFReader(BaseReader):
    """
@@ -30,13 +30,15 @@ class RDFReader(BaseReader):
     def read(
         self,
         input_path: Union[str, List[str]],
-    ) -> Dataset:
+    ) -> "Dataset":
        """
        Read RDF file(s) using Ray Data.
 
        :param input_path: Path to RDF file or list of RDF files.
        :return: Ray Dataset containing extracted documents.
        """
+        import ray
+
         if not ray.is_initialized():
             ray.init()
 
@@ -73,6 +75,10 @@ class RDFReader(BaseReader):
        :param file_path: Path to RDF file.
        :return: List of document dictionaries.
        """
+        import rdflib
+        from rdflib import Literal
+        from rdflib.util import guess_format
+
         if not file_path.is_file():
             raise FileNotFoundError(f"RDF file not found: {file_path}")
 
graphgen/models/reader/txt_reader.py CHANGED
@@ -1,21 +1,24 @@
-from typing import List, Union
-
-import ray
-from ray.data import Dataset
+from typing import TYPE_CHECKING, List, Union
 
 from graphgen.bases.base_reader import BaseReader
 
+if TYPE_CHECKING:
+    import ray
+    from ray.data import Dataset
+
 
 class TXTReader(BaseReader):
     def read(
         self,
         input_path: Union[str, List[str]],
-    ) -> Dataset:
+    ) -> "Dataset":
        """
        Read text files from the specified input path.
        :param input_path: Path to the input text file or list of text files.
        :return: Ray Dataset containing the read text data.
        """
+        import ray
+
         docs_ds = ray.data.read_binary_files(
             input_path,
             include_paths=True,
graphgen/models/tokenizer/__init__.py CHANGED
@@ -4,29 +4,21 @@ from graphgen.bases import BaseTokenizer
 
 from .tiktoken_tokenizer import TiktokenTokenizer
 
-try:
-    from transformers import AutoTokenizer
-
-    _HF_AVAILABLE = True
-except ImportError:
-    _HF_AVAILABLE = False
-
 
 def get_tokenizer_impl(tokenizer_name: str = "cl100k_base") -> BaseTokenizer:
     import tiktoken
 
     if tokenizer_name in tiktoken.list_encoding_names():
         return TiktokenTokenizer(model_name=tokenizer_name)
-
-    # 2. HuggingFace
-    if _HF_AVAILABLE:
+    try:
+        # HuggingFace
         from .hf_tokenizer import HFTokenizer
 
         return HFTokenizer(model_name=tokenizer_name)
-
-    raise ValueError(
-        f"Unknown tokenizer {tokenizer_name} and HuggingFace not available."
-    )
+    except ImportError as e:
+        raise ValueError(
+            f"Unknown tokenizer {tokenizer_name} and HuggingFace not available."
+        ) from e
 
 
 class Tokenizer(BaseTokenizer):
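
The tokenizer package swaps an import-time probe (a module-level _HF_AVAILABLE flag set by trying transformers) for an EAFP fallback at call time: the HuggingFace path is attempted only when tiktoken does not know the encoding name, and `raise ... from e` chains the original ImportError as __cause__ so tracebacks show the real failure. A condensed sketch of the pattern (the load_hf helper is hypothetical; AutoTokenizer.from_pretrained is real transformers API):

def load_hf(model_name: str):
    try:
        # Optional dependency touched only on this code path.
        from transformers import AutoTokenizer
    except ImportError as e:
        # Chain the ImportError so the traceback shows the real cause.
        raise ValueError(
            f"{model_name} requires the 'transformers' extra"
        ) from e
    return AutoTokenizer.from_pretrained(model_name)
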
graphgen/models/tokenizer/hf_tokenizer.py CHANGED
@@ -1,13 +1,13 @@
 from typing import List
 
-from transformers import AutoTokenizer
-
 from graphgen.bases import BaseTokenizer
 
 
 class HFTokenizer(BaseTokenizer):
     def __init__(self, model_name: str = "cl100k_base"):
         super().__init__(model_name)
+        from transformers import AutoTokenizer
+
         self.enc = AutoTokenizer.from_pretrained(self.model_name)
 
     def encode(self, text: str) -> List[int]:
graphgen/models/tokenizer/tiktoken_tokenizer.py CHANGED
@@ -1,13 +1,13 @@
 from typing import List
 
-import tiktoken
-
 from graphgen.bases import BaseTokenizer
 
 
 class TiktokenTokenizer(BaseTokenizer):
     def __init__(self, model_name: str = "cl100k_base"):
         super().__init__(model_name)
+        import tiktoken
+
         self.enc = tiktoken.get_encoding(self.model_name)
 
     def encode(self, text: str) -> List[int]:
graphgen/operators/chunk/chunk_service.py CHANGED
@@ -1,36 +1,40 @@
 import os
 from functools import lru_cache
-from typing import Tuple, Union
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
 
 from graphgen.bases import BaseOperator
-from graphgen.models import (
-    ChineseRecursiveTextSplitter,
-    RecursiveCharacterSplitter,
-    Tokenizer,
-)
 from graphgen.utils import detect_main_language
 
-_MAPPING = {
-    "en": RecursiveCharacterSplitter,
-    "zh": ChineseRecursiveTextSplitter,
-}
+if TYPE_CHECKING:
+    from graphgen.models import (
+        ChineseRecursiveTextSplitter,
+        RecursiveCharacterSplitter,
+        Tokenizer,
+    )
 
-SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter]
+if TYPE_CHECKING:
+    SplitterT = Union["RecursiveCharacterSplitter", "ChineseRecursiveTextSplitter"]
+else:
+    SplitterT = Any
 
 
 @lru_cache(maxsize=None)
 def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT:
-    cls = _MAPPING[language]
     kwargs = dict(frozen_kwargs)
-    return cls(**kwargs)
+    if language == "en":
+        from graphgen.models import RecursiveCharacterSplitter
+
+        return RecursiveCharacterSplitter(**kwargs)
+    if language == "zh":
+        from graphgen.models import ChineseRecursiveTextSplitter
+
+        return ChineseRecursiveTextSplitter(**kwargs)
+    raise ValueError(
+        f"Unsupported language: {language}. Supported languages are: en, zh"
+    )
 
 
 def split_chunks(text: str, language: str = "en", **kwargs) -> list:
-    if language not in _MAPPING:
-        raise ValueError(
-            f"Unsupported language: {language}. "
-            f"Supported languages are: {list(_MAPPING.keys())}"
-        )
     frozen_kwargs = frozenset(
         (k, tuple(v) if isinstance(v, list) else v) for k, v in kwargs.items()
     )
@@ -45,10 +49,18 @@ class ChunkService(BaseOperator):
         super().__init__(
             working_dir=working_dir, kv_backend=kv_backend, op_name="chunk"
         )
-        tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
-        self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model)
+        self.tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
+        self._tokenizer_instance: Optional["Tokenizer"] = None
         self.chunk_kwargs = chunk_kwargs
 
+    @property
+    def tokenizer_instance(self) -> "Tokenizer":
+        if self._tokenizer_instance is None:
+            from graphgen.models import Tokenizer
+
+            self._tokenizer_instance = Tokenizer(model_name=self.tokenizer_model)
+        return self._tokenizer_instance
+
     def process(self, batch: list) -> Tuple[list, dict]:
        """
        Chunk the documents in the batch.
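
ChunkService defers both the splitter classes and the tokenizer: _get_splitter imports the needed class inside each language branch, and tokenizer_instance becomes a lazily initialised property, so Tokenizer and its dependencies load only on first use. The caching idiom reduced to its core (illustrative Holder/build_tokenizer names; on Python 3.8+ functools.cached_property is an equivalent stdlib alternative):

from typing import Optional


def build_tokenizer(model_name: str) -> object:
    # Stand-in for an expensive construction (model files, network, etc.).
    return f"tokenizer<{model_name}>"


class Holder:
    def __init__(self, model_name: str):
        self.model_name = model_name
        self._tokenizer: Optional[object] = None  # nothing built yet

    @property
    def tokenizer(self) -> object:
        # Build on first access, then reuse the cached instance.
        if self._tokenizer is None:
            self._tokenizer = build_tokenizer(self.model_name)
        return self._tokenizer


h = Holder("cl100k_base")
print(h.tokenizer)  # constructed here
print(h.tokenizer)  # cached; no second construction
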
graphgen/operators/partition/partition_service.py CHANGED
@@ -3,14 +3,6 @@ from typing import Iterable, Tuple
 
 from graphgen.bases import BaseGraphStorage, BaseOperator, BaseTokenizer
 from graphgen.common.init_storage import init_storage
-from graphgen.models import (
-    AnchorBFSPartitioner,
-    BFSPartitioner,
-    DFSPartitioner,
-    ECEPartitioner,
-    LeidenPartitioner,
-    Tokenizer,
-)
 from graphgen.utils import logger
 
 
@@ -31,21 +23,34 @@ class PartitionService(BaseOperator):
             namespace="graph",
         )
         tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
+
+        from graphgen.models import Tokenizer
+
         self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model)
         method = partition_kwargs["method"]
         self.method_params = partition_kwargs["method_params"]
 
         if method == "bfs":
+            from graphgen.models import BFSPartitioner
+
             self.partitioner = BFSPartitioner()
         elif method == "dfs":
+            from graphgen.models import DFSPartitioner
+
             self.partitioner = DFSPartitioner()
         elif method == "ece":
             # before ECE partitioning, we need to:
             # 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random
+            from graphgen.models import ECEPartitioner
+
             self.partitioner = ECEPartitioner()
         elif method == "leiden":
+            from graphgen.models import LeidenPartitioner
+
             self.partitioner = LeidenPartitioner()
         elif method == "anchor_bfs":
+            from graphgen.models import AnchorBFSPartitioner
+
             self.partitioner = AnchorBFSPartitioner(
                 anchor_type=self.method_params.get("anchor_type"),
                 anchor_ids=set(self.method_params.get("anchor_ids", []))
graphgen/operators/read/read.py CHANGED
@@ -1,7 +1,5 @@
 from pathlib import Path
-from typing import Any, List, Optional, Union
-
-import ray
+from typing import TYPE_CHECKING, Any, List, Optional, Union
 
 from graphgen.common.init_storage import init_storage
 from graphgen.models import (
@@ -17,6 +15,11 @@ from graphgen.utils import compute_dict_hash, logger
 
 from .parallel_file_scanner import ParallelFileScanner
 
+if TYPE_CHECKING:
+    import ray
+    import ray.data
+
+
 _MAPPING = {
     "jsonl": JSONReader,
     "json": JSONReader,
@@ -57,7 +60,7 @@ def read(
     recursive: bool = True,
     read_nums: Optional[int] = None,
     **reader_kwargs: Any,
-) -> ray.data.Dataset:
+) -> "ray.data.Dataset":
    """
    Unified entry point to read files of multiple types using Ray Data.
 
@@ -71,6 +74,8 @@ def read(
    :param reader_kwargs: Additional kwargs passed to readers
    :return: Ray Dataset containing all documents
    """
+    import ray
+
     input_path_cache = init_storage(
         backend=kv_backend, working_dir=working_dir, namespace="input_path"
     )
graphgen/operators/search/search_service.py CHANGED
@@ -1,12 +1,13 @@
 from functools import partial
-from typing import Optional
-
-import pandas as pd
+from typing import TYPE_CHECKING, Optional
 
 from graphgen.bases import BaseOperator
 from graphgen.common.init_storage import init_storage
 from graphgen.utils import compute_content_hash, logger, run_concurrent
 
+if TYPE_CHECKING:
+    import pandas as pd
+
 
 class SearchService(BaseOperator):
    """
@@ -136,7 +137,9 @@ class SearchService(BaseOperator):
 
         return final_results
 
-    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
+    def process(self, batch: "pd.DataFrame") -> "pd.DataFrame":
+        import pandas as pd
+
         docs = batch.to_dict(orient="records")
 
         self._init_searchers()
graphgen/storage/graph/networkx_storage.py CHANGED
@@ -26,6 +26,7 @@ class NetworkXStorage(BaseGraphStorage):
         return self._graph.number_of_edges()
 
     def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
+
         graph = self._graph
 
         if undirected and graph.is_directed():
@@ -36,24 +37,27 @@ class NetworkXStorage(BaseGraphStorage):
         ]
 
     @staticmethod
-    def load_nx_graph(file_name) -> Optional[nx.Graph]:
+    def load_nx_graph(file_name) -> Optional["nx.Graph"]:
+
         if os.path.exists(file_name):
             return nx.read_graphml(file_name)
         return None
 
     @staticmethod
-    def write_nx_graph(graph: nx.Graph, file_name):
+    def write_nx_graph(graph: "nx.Graph", file_name):
+
         nx.write_graphml(graph, file_name)
 
     @staticmethod
-    def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
+    def stable_largest_connected_component(graph: "nx.Graph") -> "nx.Graph":
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
        """
+
         from graspologic.utils import largest_connected_component
 
         graph = graph.copy()
-        graph = cast(nx.Graph, largest_connected_component(graph))
+        graph = cast("nx.Graph", largest_connected_component(graph))
         node_mapping = {
             node: html.unescape(node.upper().strip()) for node in graph.nodes()
         }  # type: ignore
@@ -61,11 +65,12 @@ class NetworkXStorage(BaseGraphStorage):
         return NetworkXStorage._stabilize_graph(graph)
 
     @staticmethod
-    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
+    def _stabilize_graph(graph: "nx.Graph") -> "nx.Graph":
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Ensure an undirected graph with the same relationships will always be read the same way.
        Achieved by sorting the nodes and edges.
        """
+
         fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
 
         sorted_nodes = graph.nodes(data=True)
@@ -97,6 +102,7 @@ class NetworkXStorage(BaseGraphStorage):
        Initialize the NetworkX graph storage by loading an existing graph from a GraphML file,
        if it exists, or creating a new empty graph otherwise.
        """
+
         self._graphml_xml_file = os.path.join(
             self.working_dir, f"{self.namespace}.graphml"
         )
@@ -141,7 +147,7 @@ class NetworkXStorage(BaseGraphStorage):
             return list(self._graph.edges(source_node_id, data=True))
         return None
 
-    def get_graph(self) -> nx.Graph:
+    def get_graph(self) -> "nx.Graph":
         return self._graph
 
     def upsert_node(self, node_id: str, node_data: dict[str, any]):