github-actions[bot] committed · Commit 5445ab9 · 1 Parent(s): 0a18089
Auto-sync from demo at Fri Jan 30 10:55:05 UTC 2026
Files changed:
- graphgen/bases/base_filter.py +4 -3
- graphgen/bases/base_operator.py +17 -7
- graphgen/bases/base_reader.py +8 -3
- graphgen/common/init_llm.py +7 -4
- graphgen/common/init_storage.py +21 -1
- graphgen/models/__init__.py +121 -48
- graphgen/models/evaluator/kg/structure_evaluator.py +3 -3
- graphgen/models/filter/range_filter.py +5 -4
- graphgen/models/llm/__init__.py +27 -4
- graphgen/models/partitioner/leiden_partitioner.py +7 -4
- graphgen/models/reader/csv_reader.py +7 -5
- graphgen/models/reader/json_reader.py +8 -5
- graphgen/models/reader/parquet_reader.py +8 -5
- graphgen/models/reader/pdf_reader.py +7 -5
- graphgen/models/reader/pickle_reader.py +10 -7
- graphgen/models/reader/rdf_reader.py +14 -8
- graphgen/models/reader/txt_reader.py +8 -5
- graphgen/models/tokenizer/__init__.py +6 -14
- graphgen/models/tokenizer/hf_tokenizer.py +2 -2
- graphgen/models/tokenizer/tiktoken_tokenizer.py +2 -2
- graphgen/operators/chunk/chunk_service.py +32 -20
- graphgen/operators/partition/partition_service.py +13 -8
- graphgen/operators/read/read.py +9 -4
- graphgen/operators/search/search_service.py +7 -4
- graphgen/storage/graph/networkx_storage.py +12 -6
graphgen/bases/base_filter.py CHANGED
@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
-from typing import Any, Union
+from typing import TYPE_CHECKING, Any, Union
 
-import numpy as np
+if TYPE_CHECKING:
+    import numpy as np
 
 
 class BaseFilter(ABC):
@@ -15,7 +16,7 @@ class BaseFilter(ABC):
 
 class BaseValueFilter(BaseFilter, ABC):
     @abstractmethod
-    def filter(self, data: Union[int, float, np.number]) -> bool:
+    def filter(self, data: Union[int, float, "np.number"]) -> bool:
        """
        Filter the numeric value and return True if it passes the filter, False otherwise.
        """
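Across these files the commit applies one pattern: heavy third-party imports move under typing.TYPE_CHECKING, which is False at runtime, and annotations that referenced them become string forward references. A minimal self-contained sketch of the idea (module and function names here are illustrative, not from the repo):

from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:  # True only for static type checkers, False at runtime
    import numpy as np  # never executed at runtime, so import stays cheap

def is_positive(data: Union[int, float, "np.number"]) -> bool:
    # "np.number" is a string forward reference: checkers resolve it,
    # the interpreter never evaluates it.
    return float(data) > 0

The trade-off is that any code path which actually touches numpy objects at runtime still needs a real import, which is why the diffs below push import numpy as np into the function bodies that use it.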
graphgen/bases/base_operator.py CHANGED
@@ -1,14 +1,18 @@
+from __future__ import annotations
+
 import inspect
 import os
 from abc import ABC, abstractmethod
-from typing import Iterable, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, Tuple, Union
 
-import numpy as np
-import pandas as pd
-import ray
+if TYPE_CHECKING:
+    import numpy as np
+    import pandas as pd
 
 
 def convert_to_serializable(obj):
+    import numpy as np
+
     if isinstance(obj, np.ndarray):
         return obj.tolist()
     if isinstance(obj, np.generic):
@@ -40,6 +44,8 @@ class BaseOperator(ABC):
         )
 
         try:
+            import ray
+
             ctx = ray.get_runtime_context()
             worker_id = ctx.get_actor_id() or ctx.get_worker_id()
             worker_id_short = worker_id[-6:] if worker_id else "driver"
@@ -62,9 +68,11 @@ class BaseOperator(ABC):
         )
 
     def __call__(
-        self, batch: pd.DataFrame
-    ) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]:
+        self, batch: "pd.DataFrame"
+    ) -> Union["pd.DataFrame", Iterable["pd.DataFrame"]]:
         # lazy import to avoid circular import
+        import pandas as pd
+
         from graphgen.utils import CURRENT_LOGGER_VAR
 
         logger_token = CURRENT_LOGGER_VAR.set(self.logger)
@@ -106,7 +114,7 @@ class BaseOperator(ABC):
 
         return compute_dict_hash(content, prefix=f"{self.op_name}-")
 
-    def split(self, batch: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
+    def split(self, batch: "pd.DataFrame") -> tuple["pd.DataFrame", "pd.DataFrame"]:
        """
        Split the input batch into to_process & processed based on _meta data in KV_storage
        :param batch
@@ -114,6 +122,8 @@ class BaseOperator(ABC):
        to_process: DataFrame of documents to be chunked
        recovered: Result DataFrame of already chunked documents
        """
+        import pandas as pd
+
         meta_forward = self.get_meta_forward()
         meta_ids = set(meta_forward.keys())
         mask = batch["_trace_id"].isin(meta_ids)
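base_operator.py additionally adds from __future__ import annotations (PEP 563), under which all annotations are stored as strings and never evaluated at runtime; the explicit quotes on "pd.DataFrame" are then redundant but harmless. The runtime imports inside __call__ and split are only paid once: after the first call, import pandas as pd is a dictionary lookup in sys.modules. A rough sketch of the combination (names illustrative):

from __future__ import annotations  # PEP 563: annotations become strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pandas as pd

def summarize(batch: pd.DataFrame) -> pd.DataFrame:  # unquoted is fine here
    import pandas as pd  # real import, cached in sys.modules after first call

    return pd.DataFrame({"rows": [len(batch)]})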
graphgen/bases/base_reader.py CHANGED
@@ -1,10 +1,14 @@
+from __future__ import annotations
+
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Union
 
-import pandas as pd
 import requests
-from ray.data import Dataset
+
+if TYPE_CHECKING:
+    import pandas as pd
+    from ray.data import Dataset
 
 
 class BaseReader(ABC):
@@ -51,6 +55,7 @@ class BaseReader(ABC):
        """
        Validate data format.
        """
+
        if "type" not in batch.columns:
            raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}")
 
graphgen/common/init_llm.py CHANGED
@@ -1,11 +1,12 @@
 import os
-from typing import Any, Dict, Optional
-
-import ray
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 from graphgen.bases import BaseLLMWrapper
 from graphgen.models import Tokenizer
 
+if TYPE_CHECKING:
+    import ray
+
 
 class LLMServiceActor:
    """
@@ -73,7 +74,7 @@ class LLMServiceProxy(BaseLLMWrapper):
    A proxy class to interact with the LLMServiceActor for distributed LLM operations.
    """
 
-    def __init__(self, actor_handle: ray.actor.ActorHandle):
+    def __init__(self, actor_handle: "ray.actor.ActorHandle"):
        super().__init__()
        self.actor_handle = actor_handle
        self._create_local_tokenizer()
@@ -120,6 +121,8 @@ class LLMFactory:
    def create_llm(
        model_type: str, backend: str, config: Dict[str, Any]
    ) -> BaseLLMWrapper:
+        import ray
+
        if not config:
            raise ValueError(
                f"No configuration provided for LLM {model_type} with backend {backend}."
graphgen/common/init_storage.py CHANGED
@@ -146,7 +146,7 @@ class GraphStorageActor:
 
 
 class RemoteKVStorageProxy(BaseKVStorage):
-    def __init__(self, actor_handle: ray.actor.ActorHandle):
+    def __init__(self, actor_handle: "ray.actor.ActorHandle"):
        super().__init__()
        self.actor = actor_handle
 
@@ -202,68 +202,87 @@ class RemoteGraphStorageProxy(BaseGraphStorage):
        return ray.get(self.actor.get_all_node_degrees.remote())
 
    def get_node_count(self) -> int:
+
        return ray.get(self.actor.get_node_count.remote())
 
    def get_edge_count(self) -> int:
+
        return ray.get(self.actor.get_edge_count.remote())
 
    def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
+
        return ray.get(self.actor.get_connected_components.remote(undirected))
 
    def has_node(self, node_id: str) -> bool:
+
        return ray.get(self.actor.has_node.remote(node_id))
 
    def has_edge(self, source_node_id: str, target_node_id: str):
+
        return ray.get(self.actor.has_edge.remote(source_node_id, target_node_id))
 
    def node_degree(self, node_id: str) -> int:
+
        return ray.get(self.actor.node_degree.remote(node_id))
 
    def edge_degree(self, src_id: str, tgt_id: str) -> int:
+
        return ray.get(self.actor.edge_degree.remote(src_id, tgt_id))
 
    def get_node(self, node_id: str) -> Any:
+
        return ray.get(self.actor.get_node.remote(node_id))
 
    def update_node(self, node_id: str, node_data: dict[str, str]):
+
        return ray.get(self.actor.update_node.remote(node_id, node_data))
 
    def get_all_nodes(self) -> Any:
+
        return ray.get(self.actor.get_all_nodes.remote())
 
    def get_edge(self, source_node_id: str, target_node_id: str):
+
        return ray.get(self.actor.get_edge.remote(source_node_id, target_node_id))
 
    def update_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
+
        return ray.get(
            self.actor.update_edge.remote(source_node_id, target_node_id, edge_data)
        )
 
    def get_all_edges(self) -> Any:
+
        return ray.get(self.actor.get_all_edges.remote())
 
    def get_node_edges(self, source_node_id: str) -> Any:
+
        return ray.get(self.actor.get_node_edges.remote(source_node_id))
 
    def upsert_node(self, node_id: str, node_data: dict[str, str]):
+
        return ray.get(self.actor.upsert_node.remote(node_id, node_data))
 
    def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
+
        return ray.get(
            self.actor.upsert_edge.remote(source_node_id, target_node_id, edge_data)
        )
 
    def delete_node(self, node_id: str):
+
        return ray.get(self.actor.delete_node.remote(node_id))
 
    def get_neighbors(self, node_id: str) -> List[str]:
+
        return ray.get(self.actor.get_neighbors.remote(node_id))
 
    def reload(self):
+
        return ray.get(self.actor.reload.remote())
 
 
@@ -274,6 +293,7 @@ class StorageFactory:
 
    @staticmethod
    def create_storage(backend: str, working_dir: str, namespace: str):
+
        if backend in ["json_kv", "rocksdb"]:
            actor_name = f"Actor_KV_{namespace}"
            actor_class = KVStorageActor
graphgen/models/__init__.py CHANGED
@@ -1,48 +1,121 @@
-from .evaluator import (
-    AccuracyEvaluator,
-    LengthEvaluator,
-    MTLDEvaluator,
-    RewardEvaluator,
-    StructureEvaluator,
-    UniEvaluator,
-)
-from .filter import RangeFilter
-from .generator import (
-    AggregatedGenerator,
-    AtomicGenerator,
-    CoTGenerator,
-    FillInBlankGenerator,
-    MultiAnswerGenerator,
-    MultiChoiceGenerator,
-    MultiHopGenerator,
-    QuizGenerator,
-    TrueFalseGenerator,
-    VQAGenerator,
-)
-from .kg_builder import LightRAGKGBuilder, MMKGBuilder
-from .llm import HTTPClient, OllamaClient, OpenAIClient
-from .partitioner import (
-    AnchorBFSPartitioner,
-    BFSPartitioner,
-    DFSPartitioner,
-    ECEPartitioner,
-    LeidenPartitioner,
-)
-from .reader import (
-    CSVReader,
-    JSONReader,
-    ParquetReader,
-    PDFReader,
-    PickleReader,
-    RDFReader,
-    TXTReader,
-)
-from .rephraser import StyleControlledRephraser
-from .searcher.db.ncbi_searcher import NCBISearch
-from .searcher.db.rnacentral_searcher import RNACentralSearch
-from .searcher.db.uniprot_searcher import UniProtSearch
-from .searcher.kg.wiki_search import WikiSearch
-from .searcher.web.bing_search import BingSearch
-from .searcher.web.google_search import GoogleSearch
-from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
-from .tokenizer import Tokenizer
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .evaluator import (
+        AccuracyEvaluator,
+        LengthEvaluator,
+        MTLDEvaluator,
+        RewardEvaluator,
+        StructureEvaluator,
+        UniEvaluator,
+    )
+    from .filter import RangeFilter
+    from .generator import (
+        AggregatedGenerator,
+        AtomicGenerator,
+        CoTGenerator,
+        FillInBlankGenerator,
+        MultiAnswerGenerator,
+        MultiChoiceGenerator,
+        MultiHopGenerator,
+        QuizGenerator,
+        TrueFalseGenerator,
+        VQAGenerator,
+    )
+    from .kg_builder import LightRAGKGBuilder, MMKGBuilder
+    from .llm import HTTPClient, OllamaClient, OpenAIClient
+    from .partitioner import (
+        AnchorBFSPartitioner,
+        BFSPartitioner,
+        DFSPartitioner,
+        ECEPartitioner,
+        LeidenPartitioner,
+    )
+    from .reader import (
+        CSVReader,
+        JSONReader,
+        ParquetReader,
+        PDFReader,
+        PickleReader,
+        RDFReader,
+        TXTReader,
+    )
+    from .rephraser import StyleControlledRephraser
+    from .searcher.db.ncbi_searcher import NCBISearch
+    from .searcher.db.rnacentral_searcher import RNACentralSearch
+    from .searcher.db.uniprot_searcher import UniProtSearch
+    from .searcher.kg.wiki_search import WikiSearch
+    from .searcher.web.bing_search import BingSearch
+    from .searcher.web.google_search import GoogleSearch
+    from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
+    from .tokenizer import Tokenizer
+
+_import_map = {
+    # Evaluator
+    "AccuracyEvaluator": ".evaluator",
+    "LengthEvaluator": ".evaluator",
+    "MTLDEvaluator": ".evaluator",
+    "RewardEvaluator": ".evaluator",
+    "StructureEvaluator": ".evaluator",
+    "UniEvaluator": ".evaluator",
+    # Filter
+    "RangeFilter": ".filter",
+    # Generator
+    "AggregatedGenerator": ".generator",
+    "AtomicGenerator": ".generator",
+    "CoTGenerator": ".generator",
+    "FillInBlankGenerator": ".generator",
+    "MultiAnswerGenerator": ".generator",
+    "MultiChoiceGenerator": ".generator",
+    "MultiHopGenerator": ".generator",
+    "QuizGenerator": ".generator",
+    "TrueFalseGenerator": ".generator",
+    "VQAGenerator": ".generator",
+    # KG Builder
+    "LightRAGKGBuilder": ".kg_builder",
+    "MMKGBuilder": ".kg_builder",
+    # LLM
+    "HTTPClient": ".llm",
+    "OllamaClient": ".llm",
+    "OpenAIClient": ".llm",
+    # Partitioner
+    "AnchorBFSPartitioner": ".partitioner",
+    "BFSPartitioner": ".partitioner",
+    "DFSPartitioner": ".partitioner",
+    "ECEPartitioner": ".partitioner",
+    "LeidenPartitioner": ".partitioner",
+    # Reader
+    "CSVReader": ".reader",
+    "JSONReader": ".reader",
+    "ParquetReader": ".reader",
+    "PDFReader": ".reader",
+    "PickleReader": ".reader",
+    "RDFReader": ".reader",
+    "TXTReader": ".reader",
+    # Searcher
+    "NCBISearch": ".searcher.db.ncbi_searcher",
+    "RNACentralSearch": ".searcher.db.rnacentral_searcher",
+    "UniProtSearch": ".searcher.db.uniprot_searcher",
+    "WikiSearch": ".searcher.kg.wiki_search",
+    "BingSearch": ".searcher.web.bing_search",
+    "GoogleSearch": ".searcher.web.google_search",
+    # Splitter
+    "ChineseRecursiveTextSplitter": ".splitter",
+    "RecursiveCharacterSplitter": ".splitter",
+    # Tokenizer
+    "Tokenizer": ".tokenizer",
+    # Rephraser
+    "StyleControlledRephraser": ".rephraser",
+}
+
+
+def __getattr__(name):
+    if name in _import_map:
+        import importlib
+
+        module = importlib.import_module(_import_map[name], package=__name__)
+        return getattr(module, name)
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+__all__ = list(_import_map.keys())
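The rewritten graphgen/models/__init__.py keeps the package's public surface intact while deferring all submodule imports: the if TYPE_CHECKING: block preserves IDE and type-checker visibility, and a module-level __getattr__ (PEP 562, Python 3.7+) performs the import on first attribute access. A stripped-down sketch of the mechanism (package and class names are placeholders):

# mypkg/__init__.py
_import_map = {"Thing": ".thing"}  # public name -> defining submodule

def __getattr__(name):
    # called only when `name` is not found in the module's own namespace
    if name in _import_map:
        import importlib

        module = importlib.import_module(_import_map[name], package=__name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

__all__ = list(_import_map.keys())

With this in place, from mypkg import Thing still works, but mypkg.thing (and whatever it imports) is only loaded when Thing is first requested; a side effect is that a misspelled name now fails at first access rather than at package import.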
graphgen/models/evaluator/kg/structure_evaluator.py CHANGED
@@ -1,9 +1,6 @@
 from collections import Counter
 from typing import Any, Dict, Optional
 
-import numpy as np
-from scipy import stats
-
 from graphgen.bases import BaseGraphStorage, BaseKGEvaluator
 from graphgen.utils import logger
 
@@ -75,6 +72,9 @@ class StructureEvaluator(BaseKGEvaluator):
 
    @staticmethod
    def _calculate_powerlaw_r2(degree_map: Dict[str, int]) -> Optional[float]:
+        import numpy as np
+        from scipy import stats
+
        degrees = [deg for deg in degree_map.values() if deg > 0]
 
        if len(degrees) < 10:
graphgen/models/filter/range_filter.py CHANGED
@@ -1,9 +1,10 @@
-from typing import Union
-
-import numpy as np
+from typing import TYPE_CHECKING, Union
 
 from graphgen.bases import BaseValueFilter
 
+if TYPE_CHECKING:
+    import numpy as np
+
 
 class RangeFilter(BaseValueFilter):
    """
@@ -22,7 +23,7 @@ class RangeFilter(BaseValueFilter):
        self.left_inclusive = left_inclusive
        self.right_inclusive = right_inclusive
 
-    def filter(self, data: Union[int, float, np.number]) -> bool:
+    def filter(self, data: Union[int, float, "np.number"]) -> bool:
        value = float(data)
        if self.left_inclusive and self.right_inclusive:
            return self.min_val <= value <= self.max_val
graphgen/models/llm/__init__.py CHANGED
@@ -1,4 +1,27 @@
-from .api.http_client import HTTPClient
-from .api.ollama_client import OllamaClient
-from .api.openai_client import OpenAIClient
-from .local.hf_wrapper import HuggingFaceWrapper
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .api.http_client import HTTPClient
+    from .api.ollama_client import OllamaClient
+    from .api.openai_client import OpenAIClient
+    from .local.hf_wrapper import HuggingFaceWrapper
+
+
+_import_map = {
+    "HTTPClient": ".api.http_client",
+    "OllamaClient": ".api.ollama_client",
+    "OpenAIClient": ".api.openai_client",
+    "HuggingFaceWrapper": ".local.hf_wrapper",
+}
+
+
+def __getattr__(name):
+    if name in _import_map:
+        import importlib
+
+        module = importlib.import_module(_import_map[name], package=__name__)
+        return getattr(module, name)
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+__all__ = list(_import_map.keys())
graphgen/models/partitioner/leiden_partitioner.py CHANGED
@@ -1,12 +1,12 @@
 from collections import defaultdict
-from typing import Any, Dict, List, Set, Tuple
-
-import igraph as ig
-from leidenalg import ModularityVertexPartition, find_partition
+from typing import TYPE_CHECKING, Any, Dict, List, Set, Tuple
 
 from graphgen.bases import BaseGraphStorage, BasePartitioner
 from graphgen.bases.datatypes import Community
 
+if TYPE_CHECKING:
+    import igraph as ig
+
 
 class LeidenPartitioner(BasePartitioner):
    """
@@ -62,6 +62,9 @@ class LeidenPartitioner(BasePartitioner):
        use_lcc: bool = False,
        random_seed: int = 42,
    ) -> Dict[str, int]:
+        import igraph as ig
+        from leidenalg import ModularityVertexPartition, find_partition
+
        # build igraph
        ig_graph = ig.Graph.TupleList(((u, v) for u, v, _ in edges), directed=False)
 
graphgen/models/reader/csv_reader.py CHANGED
@@ -1,10 +1,11 @@
-from typing import List, Union
-
-import ray
-from ray.data import Dataset
+from typing import TYPE_CHECKING, List, Union
 
 from graphgen.bases.base_reader import BaseReader
 
+if TYPE_CHECKING:
+    import ray
+    from ray.data import Dataset
+
 
 class CSVReader(BaseReader):
    """
@@ -14,13 +15,14 @@ class CSVReader(BaseReader):
    - if type is "text", "content" column must be present.
    """
 
-    def read(self, input_path: Union[str, List[str]]) -> Dataset:
+    def read(self, input_path: Union[str, List[str]]) -> "Dataset":
        """
        Read CSV files and return Ray Dataset.
 
        :param input_path: Path to CSV file or list of CSV files.
        :return: Ray Dataset containing validated and filtered data.
        """
+        import ray
 
        ds = ray.data.read_csv(input_path, include_paths=True)
        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
graphgen/models/reader/json_reader.py CHANGED
@@ -1,11 +1,12 @@
 import json
-from typing import List, Union
-
-import ray
-import ray.data
+from typing import TYPE_CHECKING, List, Union
 
 from graphgen.bases.base_reader import BaseReader
 
+if TYPE_CHECKING:
+    import ray
+    import ray.data
+
 
 class JSONReader(BaseReader):
    """
@@ -15,12 +16,14 @@ class JSONReader(BaseReader):
    - if type is "text", "content" column must be present.
    """
 
-    def read(self, input_path: Union[str, List[str]]) -> ray.data.Dataset:
+    def read(self, input_path: Union[str, List[str]]) -> "ray.data.Dataset":
        """
        Read JSON file and return Ray Dataset.
        :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files.
        :return: Ray Dataset containing validated and filtered data.
        """
+        import ray
+
        if self.modalities and len(self.modalities) >= 2:
            ds: ray.data.Dataset = ray.data.from_items([])
            for file in input_path if isinstance(input_path, list) else [input_path]:
graphgen/models/reader/parquet_reader.py CHANGED
@@ -1,10 +1,11 @@
-from typing import List, Union
-
-import ray
-from ray.data import Dataset
+from typing import TYPE_CHECKING, List, Union
 
 from graphgen.bases.base_reader import BaseReader
 
+if TYPE_CHECKING:
+    import ray
+    from ray.data import Dataset
+
 
 class ParquetReader(BaseReader):
    """
@@ -14,13 +15,15 @@ class ParquetReader(BaseReader):
    - if type is "text", "content" column must be present.
    """
 
-    def read(self, input_path: Union[str, List[str]]) -> Dataset:
+    def read(self, input_path: Union[str, List[str]]) -> "Dataset":
        """
        Read Parquet files using Ray Data.
 
        :param input_path: Path to Parquet file or list of Parquet files.
        :return: Ray Dataset containing validated documents.
        """
+        import ray
+
        if not ray.is_initialized():
            ray.init()
 
graphgen/models/reader/pdf_reader.py CHANGED
@@ -3,15 +3,16 @@ import os
 import subprocess
 import tempfile
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
-
-import ray
-from ray.data import Dataset
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from graphgen.bases.base_reader import BaseReader
 from graphgen.models.reader.txt_reader import TXTReader
 from graphgen.utils import logger, pick_device
 
+if TYPE_CHECKING:
+    import ray
+    from ray.data import Dataset
+
 
 class PDFReader(BaseReader):
    """
@@ -69,7 +70,8 @@ class PDFReader(BaseReader):
        self,
        input_path: Union[str, List[str]],
        **override,
-    ) -> Dataset:
+    ) -> "Dataset":
+        import ray
 
        # Ensure input_path is a list
        if isinstance(input_path, str):
graphgen/models/reader/pickle_reader.py CHANGED
@@ -1,13 +1,13 @@
 import pickle
-from typing import List, Union
-
-import pandas as pd
-import ray
-from ray.data import Dataset
+from typing import TYPE_CHECKING, List, Union
 
 from graphgen.bases.base_reader import BaseReader
 from graphgen.utils import logger
 
+if TYPE_CHECKING:
+    import pandas as pd
+    from ray.data import Dataset
+
 
 class PickleReader(BaseReader):
    """
@@ -23,13 +23,16 @@ class PickleReader(BaseReader):
    def read(
        self,
        input_path: Union[str, List[str]],
-    ) -> Dataset:
+    ) -> "Dataset":
        """
        Read Pickle files using Ray Data.
 
        :param input_path: Path to pickle file or list of pickle files.
        :return: Ray Dataset containing validated documents.
        """
+        import pandas as pd
+        import ray
+
        if not ray.is_initialized():
            ray.init()
 
@@ -37,7 +40,7 @@ class PickleReader(BaseReader):
        ds = ray.data.read_binary_files(input_path, include_paths=True)
 
        # Deserialize pickle files and flatten into individual records
-        def deserialize_batch(batch: pd.DataFrame) -> pd.DataFrame:
+        def deserialize_batch(batch: "pd.DataFrame") -> "pd.DataFrame":
            all_records = []
            for _, row in batch.iterrows():
                try:
graphgen/models/reader/rdf_reader.py CHANGED
@@ -1,15 +1,15 @@
 from pathlib import Path
-from typing import Any, Dict, List, Union
-
-import ray
-import rdflib
-from ray.data import Dataset
-from rdflib import Literal
-from rdflib.util import guess_format
+from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 from graphgen.bases.base_reader import BaseReader
 from graphgen.utils import logger
 
+if TYPE_CHECKING:
+    import ray
+    import rdflib
+    from ray.data import Dataset
+    from rdflib import Literal
+
 
 class RDFReader(BaseReader):
    """
@@ -30,13 +30,15 @@ class RDFReader(BaseReader):
    def read(
        self,
        input_path: Union[str, List[str]],
-    ) -> Dataset:
+    ) -> "Dataset":
        """
        Read RDF file(s) using Ray Data.
 
        :param input_path: Path to RDF file or list of RDF files.
        :return: Ray Dataset containing extracted documents.
        """
+        import ray
+
        if not ray.is_initialized():
            ray.init()
 
@@ -73,6 +75,10 @@ class RDFReader(BaseReader):
        :param file_path: Path to RDF file.
        :return: List of document dictionaries.
        """
+        import rdflib
+        from rdflib import Literal
+        from rdflib.util import guess_format
+
        if not file_path.is_file():
            raise FileNotFoundError(f"RDF file not found: {file_path}")
 
graphgen/models/reader/txt_reader.py CHANGED
@@ -1,21 +1,24 @@
-from typing import List, Union
-
-import ray
-from ray.data import Dataset
+from typing import TYPE_CHECKING, List, Union
 
 from graphgen.bases.base_reader import BaseReader
 
+if TYPE_CHECKING:
+    import ray
+    from ray.data import Dataset
+
 
 class TXTReader(BaseReader):
    def read(
        self,
        input_path: Union[str, List[str]],
-    ) -> Dataset:
+    ) -> "Dataset":
        """
        Read text files from the specified input path.
        :param input_path: Path to the input text file or list of text files.
        :return: Ray Dataset containing the read text data.
        """
+        import ray
+
        docs_ds = ray.data.read_binary_files(
            input_path,
            include_paths=True,
graphgen/models/tokenizer/__init__.py CHANGED
@@ -4,29 +4,21 @@ from graphgen.bases import BaseTokenizer
 
 from .tiktoken_tokenizer import TiktokenTokenizer
 
-try:
-    from transformers import AutoTokenizer
-
-    _HF_AVAILABLE = True
-except ImportError:
-    _HF_AVAILABLE = False
-
 
 def get_tokenizer_impl(tokenizer_name: str = "cl100k_base") -> BaseTokenizer:
    import tiktoken
 
    if tokenizer_name in tiktoken.list_encoding_names():
        return TiktokenTokenizer(model_name=tokenizer_name)
-
-
-    if _HF_AVAILABLE:
+    try:
+        # HuggingFace
        from .hf_tokenizer import HFTokenizer
 
        return HFTokenizer(model_name=tokenizer_name)
-
-
-
-
+    except ImportError as e:
+        raise ValueError(
+            f"Unknown tokenizer {tokenizer_name} and HuggingFace not available."
+        ) from e
 
 
 class Tokenizer(BaseTokenizer):
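The tokenizer package drops the module-level transformers probe (_HF_AVAILABLE) in favor of attempting the import only when a non-tiktoken name is requested, turning a missing optional dependency into a late, descriptive error. The same fallback shape in isolation (a sketch, not the project's exact helper):

def get_encoder(name: str):
    import tiktoken

    if name in tiktoken.list_encoding_names():
        return tiktoken.get_encoding(name)
    try:
        # transformers is an optional, heavyweight dependency
        from transformers import AutoTokenizer
    except ImportError as e:
        raise ValueError(f"Unknown tokenizer {name} and HuggingFace not available.") from e
    return AutoTokenizer.from_pretrained(name)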
graphgen/models/tokenizer/hf_tokenizer.py CHANGED
@@ -1,13 +1,13 @@
 from typing import List
 
-from transformers import AutoTokenizer
-
 from graphgen.bases import BaseTokenizer
 
 
 class HFTokenizer(BaseTokenizer):
    def __init__(self, model_name: str = "cl100k_base"):
        super().__init__(model_name)
+        from transformers import AutoTokenizer
+
        self.enc = AutoTokenizer.from_pretrained(self.model_name)
 
    def encode(self, text: str) -> List[int]:
graphgen/models/tokenizer/tiktoken_tokenizer.py CHANGED
@@ -1,13 +1,13 @@
 from typing import List
 
-import tiktoken
-
 from graphgen.bases import BaseTokenizer
 
 
 class TiktokenTokenizer(BaseTokenizer):
    def __init__(self, model_name: str = "cl100k_base"):
        super().__init__(model_name)
+        import tiktoken
+
        self.enc = tiktoken.get_encoding(self.model_name)
 
    def encode(self, text: str) -> List[int]:
graphgen/operators/chunk/chunk_service.py CHANGED
@@ -1,36 +1,40 @@
 import os
 from functools import lru_cache
-from typing import Tuple, Union
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
 
 from graphgen.bases import BaseOperator
-from graphgen.models import (
-    ChineseRecursiveTextSplitter,
-    RecursiveCharacterSplitter,
-    Tokenizer,
-)
 from graphgen.utils import detect_main_language
 
-_MAPPING = {
-    "en": RecursiveCharacterSplitter,
-    "zh": ChineseRecursiveTextSplitter,
-}
+if TYPE_CHECKING:
+    from graphgen.models import (
+        ChineseRecursiveTextSplitter,
+        RecursiveCharacterSplitter,
+        Tokenizer,
+    )
 
-SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter]
+if TYPE_CHECKING:
+    SplitterT = Union["RecursiveCharacterSplitter", "ChineseRecursiveTextSplitter"]
+else:
+    SplitterT = Any
 
 
 @lru_cache(maxsize=None)
 def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT:
-    cls = _MAPPING[language]
    kwargs = dict(frozen_kwargs)
-    return cls(**kwargs)
+    if language == "en":
+        from graphgen.models import RecursiveCharacterSplitter
+
+        return RecursiveCharacterSplitter(**kwargs)
+    if language == "zh":
+        from graphgen.models import ChineseRecursiveTextSplitter
+
+        return ChineseRecursiveTextSplitter(**kwargs)
+    raise ValueError(
+        f"Unsupported language: {language}. Supported languages are: en, zh"
+    )
 
 
 def split_chunks(text: str, language: str = "en", **kwargs) -> list:
-    if language not in _MAPPING:
-        raise ValueError(
-            f"Unsupported language: {language}. "
-            f"Supported languages are: {list(_MAPPING.keys())}"
-        )
    frozen_kwargs = frozenset(
        (k, tuple(v) if isinstance(v, list) else v) for k, v in kwargs.items()
    )
@@ -45,10 +49,18 @@ class ChunkService(BaseOperator):
        super().__init__(
            working_dir=working_dir, kv_backend=kv_backend, op_name="chunk"
        )
-        tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
-        self.tokenizer_instance = Tokenizer(model_name=tokenizer_model)
+        self.tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
+        self._tokenizer_instance: Optional["Tokenizer"] = None
        self.chunk_kwargs = chunk_kwargs
 
+    @property
+    def tokenizer_instance(self) -> "Tokenizer":
+        if self._tokenizer_instance is None:
+            from graphgen.models import Tokenizer
+
+            self._tokenizer_instance = Tokenizer(model_name=self.tokenizer_model)
+        return self._tokenizer_instance
+
    def process(self, batch: list) -> Tuple[list, dict]:
        """
        Chunk the documents in the batch.
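ChunkService replaces an eagerly constructed tokenizer with a memoized property, so workers that never chunk text never build one. The shape of that change, reduced to its essentials (this sketch uses tiktoken directly instead of the project's Tokenizer wrapper):

import os
from typing import Optional

class Service:
    def __init__(self) -> None:
        self.tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
        self._tokenizer: Optional[object] = None  # filled in lazily

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            import tiktoken  # deferred until the first real use

            self._tokenizer = tiktoken.get_encoding(self.tokenizer_model)
        return self._tokenizer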
graphgen/operators/partition/partition_service.py CHANGED
@@ -3,14 +3,6 @@ from typing import Iterable, Tuple
 
 from graphgen.bases import BaseGraphStorage, BaseOperator, BaseTokenizer
 from graphgen.common.init_storage import init_storage
-from graphgen.models import (
-    AnchorBFSPartitioner,
-    BFSPartitioner,
-    DFSPartitioner,
-    ECEPartitioner,
-    LeidenPartitioner,
-    Tokenizer,
-)
 from graphgen.utils import logger
 
 
@@ -31,21 +23,34 @@ class PartitionService(BaseOperator):
            namespace="graph",
        )
        tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
+
+        from graphgen.models import Tokenizer
+
        self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model)
        method = partition_kwargs["method"]
        self.method_params = partition_kwargs["method_params"]
 
        if method == "bfs":
+            from graphgen.models import BFSPartitioner
+
            self.partitioner = BFSPartitioner()
        elif method == "dfs":
+            from graphgen.models import DFSPartitioner
+
            self.partitioner = DFSPartitioner()
        elif method == "ece":
            # before ECE partitioning, we need to:
            # 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random
+            from graphgen.models import ECEPartitioner
+
            self.partitioner = ECEPartitioner()
        elif method == "leiden":
+            from graphgen.models import LeidenPartitioner
+
            self.partitioner = LeidenPartitioner()
        elif method == "anchor_bfs":
+            from graphgen.models import AnchorBFSPartitioner
+
            self.partitioner = AnchorBFSPartitioner(
                anchor_type=self.method_params.get("anchor_type"),
                anchor_ids=set(self.method_params.get("anchor_ids", []))
graphgen/operators/read/read.py CHANGED
@@ -1,7 +1,5 @@
 from pathlib import Path
-from typing import Any, List, Optional, Union
-
-import ray
+from typing import TYPE_CHECKING, Any, List, Optional, Union
 
 from graphgen.common.init_storage import init_storage
 from graphgen.models import (
@@ -17,6 +15,11 @@ from graphgen.utils import compute_dict_hash, logger
 
 from .parallel_file_scanner import ParallelFileScanner
 
+if TYPE_CHECKING:
+    import ray
+    import ray.data
+
+
 _MAPPING = {
    "jsonl": JSONReader,
    "json": JSONReader,
@@ -57,7 +60,7 @@ def read(
    recursive: bool = True,
    read_nums: Optional[int] = None,
    **reader_kwargs: Any,
-) -> ray.data.Dataset:
+) -> "ray.data.Dataset":
    """
    Unified entry point to read files of multiple types using Ray Data.
 
@@ -71,6 +74,8 @@ def read(
    :param reader_kwargs: Additional kwargs passed to readers
    :return: Ray Dataset containing all documents
    """
+    import ray
+
    input_path_cache = init_storage(
        backend=kv_backend, working_dir=working_dir, namespace="input_path"
    )
graphgen/operators/search/search_service.py CHANGED
@@ -1,12 +1,13 @@
 from functools import partial
-from typing import Optional
-
-import pandas as pd
+from typing import TYPE_CHECKING, Optional
 
 from graphgen.bases import BaseOperator
 from graphgen.common.init_storage import init_storage
 from graphgen.utils import compute_content_hash, logger, run_concurrent
 
+if TYPE_CHECKING:
+    import pandas as pd
+
 
 class SearchService(BaseOperator):
    """
@@ -136,7 +137,9 @@ class SearchService(BaseOperator):
 
        return final_results
 
-    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
+    def process(self, batch: "pd.DataFrame") -> "pd.DataFrame":
+        import pandas as pd
+
        docs = batch.to_dict(orient="records")
 
        self._init_searchers()
graphgen/storage/graph/networkx_storage.py CHANGED
@@ -26,6 +26,7 @@ class NetworkXStorage(BaseGraphStorage):
        return self._graph.number_of_edges()
 
    def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
+
        graph = self._graph
 
        if undirected and graph.is_directed():
@@ -36,24 +37,27 @@ class NetworkXStorage(BaseGraphStorage):
        ]
 
    @staticmethod
-    def load_nx_graph(file_name) -> Optional[nx.Graph]:
+    def load_nx_graph(file_name) -> Optional["nx.Graph"]:
+
        if os.path.exists(file_name):
            return nx.read_graphml(file_name)
        return None
 
    @staticmethod
-    def write_nx_graph(graph: nx.Graph, file_name):
+    def write_nx_graph(graph: "nx.Graph", file_name):
+
        nx.write_graphml(graph, file_name)
 
    @staticmethod
-    def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
+    def stable_largest_connected_component(graph: "nx.Graph") -> "nx.Graph":
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
        """
+
        from graspologic.utils import largest_connected_component
 
        graph = graph.copy()
-        graph = cast(nx.Graph, largest_connected_component(graph))
+        graph = cast("nx.Graph", largest_connected_component(graph))
        node_mapping = {
            node: html.unescape(node.upper().strip()) for node in graph.nodes()
        }  # type: ignore
@@ -61,11 +65,12 @@ class NetworkXStorage(BaseGraphStorage):
        return NetworkXStorage._stabilize_graph(graph)
 
    @staticmethod
-    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
+    def _stabilize_graph(graph: "nx.Graph") -> "nx.Graph":
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Ensure an undirected graph with the same relationships will always be read the same way.
        通过对节点和边进行排序来实现
        """
+
        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
 
        sorted_nodes = graph.nodes(data=True)
@@ -97,6 +102,7 @@ class NetworkXStorage(BaseGraphStorage):
        Initialize the NetworkX graph storage by loading an existing graph from a GraphML file,
        if it exists, or creating a new empty graph otherwise.
        """
+
        self._graphml_xml_file = os.path.join(
            self.working_dir, f"{self.namespace}.graphml"
        )
@@ -141,7 +147,7 @@ class NetworkXStorage(BaseGraphStorage):
            return list(self._graph.edges(source_node_id, data=True))
        return None
 
-    def get_graph(self) -> nx.Graph:
+    def get_graph(self) -> "nx.Graph":
        return self._graph
 
    def upsert_node(self, node_id: str, node_data: dict[str, any]):
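One detail worth noting in networkx_storage.py: typing.cast also accepts a string forward reference, so cast("nx.Graph", ...) type-checks without requiring the name to be resolvable at runtime; cast itself does nothing but return its second argument. A trivial sketch:

from typing import cast

def first_string(items: list) -> "str":
    # cast() performs no runtime conversion or check; the string type
    # argument is only meaningful to static analyzers.
    return cast("str", items[0])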