import os
from typing import Iterable, Tuple
from graphgen.bases import BaseGraphStorage, BaseOperator, BaseTokenizer
from graphgen.common.init_storage import init_storage
from graphgen.utils import logger
class PartitionService(BaseOperator):
    """Operator that partitions the knowledge graph into communities.

    The partitioning strategy is selected at construction time via the
    ``method`` keyword argument (e.g. ``"bfs"``, ``"leiden"``); extra
    strategy-specific options are passed through ``method_params``.
    """

    def __init__(
        self,
        working_dir: str = "cache",
        kv_backend: str = "rocksdb",
        graph_backend: str = "kuzu",
        **partition_kwargs,
    ):
        """Initialize storage, tokenizer, and the selected partitioner.

        :param working_dir: directory used for cached storage artifacts
        :param kv_backend: key-value storage backend passed to the base operator
        :param graph_backend: graph storage backend (e.g. ``"kuzu"``)
        :param partition_kwargs: must contain ``method``; may contain
            ``method_params`` (dict of strategy-specific options)
        :raises ValueError: if ``method`` is missing or unsupported
        """
        super().__init__(
            working_dir=working_dir, kv_backend=kv_backend, op_name="partition"
        )

        self.kg_instance: BaseGraphStorage = init_storage(
            backend=graph_backend,
            working_dir=working_dir,
            namespace="graph",
        )

        tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
        # Deferred import keeps module import light, matching the file's
        # convention for graphgen.models imports.
        from graphgen.models import Tokenizer

        self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model)

        # Fail with a clear, consistent error instead of an opaque KeyError
        # when the caller forgets to specify the partition method.
        method = partition_kwargs.get("method")
        if method is None:
            raise ValueError(
                "Partition method must be specified via the 'method' keyword argument"
            )
        self.method_params = partition_kwargs.get("method_params", {})
        self.partitioner = self._build_partitioner(method)

    def _build_partitioner(self, method: str):
        """Instantiate the partitioner implementation for *method*.

        :raises ValueError: if *method* is not a supported strategy
        """
        if method == "bfs":
            from graphgen.models import BFSPartitioner

            return BFSPartitioner()
        if method == "dfs":
            from graphgen.models import DFSPartitioner

            return DFSPartitioner()
        if method == "ece":
            # before ECE partitioning, we need to:
            # 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random
            from graphgen.models import ECEPartitioner

            return ECEPartitioner()
        if method == "leiden":
            from graphgen.models import LeidenPartitioner

            return LeidenPartitioner()
        if method == "anchor_bfs":
            from graphgen.models import AnchorBFSPartitioner

            # anchor_ids is normalized to a set when provided, None otherwise.
            anchor_ids = self.method_params.get("anchor_ids")
            return AnchorBFSPartitioner(
                anchor_type=self.method_params.get("anchor_type"),
                anchor_ids=set(anchor_ids) if anchor_ids else None,
            )
        if method == "triple":
            from graphgen.models import TriplePartitioner

            return TriplePartitioner()
        if method == "quintuple":
            from graphgen.models import QuintuplePartitioner

            return QuintuplePartitioner()
        raise ValueError(f"Unsupported partition method: {method}")

    def process(self, batch: list) -> Tuple[Iterable[dict], dict]:
        """Partition the graph and yield one batch dict per community.

        :param batch: ignored — this operator does not consume any batch data,
            but the parameter is kept for interface compatibility
        :return: a lazy generator of ``{"nodes", "edges", "_trace_id"}`` dicts
            and an (empty) metadata dict
        """
        self.kg_instance.reload()
        communities: Iterable = self.partitioner.partition(
            g=self.kg_instance, **self.method_params
        )

        def generator():
            count = 0
            for community in communities:
                count += 1
                b = self.partitioner.community2batch(community, g=self.kg_instance)
                result = {
                    "nodes": b[0],
                    "edges": b[1],
                }
                result["_trace_id"] = self.get_trace_id(result)
                yield result
            # Logged only after the consumer exhausts the generator.
            logger.info("Total communities partitioned: %d", count)

        return generator(), {}