File size: 3,337 Bytes
31086ae
0bd1b0f
31086ae
0bd1b0f
0a18089
31086ae
 
 
 
 
 
 
 
0bd1b0f
31086ae
 
0bd1b0f
 
 
31086ae
 
 
 
 
 
5445ab9
 
 
31086ae
0bd1b0f
2202d61
31086ae
 
5445ab9
 
0bd1b0f
31086ae
5445ab9
 
0bd1b0f
31086ae
f5f93a3
 
5445ab9
 
0bd1b0f
31086ae
5445ab9
 
0bd1b0f
31086ae
5445ab9
 
0bd1b0f
 
 
 
31086ae
 
2202d61
 
 
 
 
 
 
 
31086ae
 
 
0bd1b0f
 
 
 
31086ae
0bd1b0f
 
 
31086ae
0bd1b0f
 
 
 
 
31086ae
0bd1b0f
 
 
 
 
 
 
31086ae
0bd1b0f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
from typing import Iterable, Tuple

from graphgen.bases import BaseGraphStorage, BaseOperator, BaseTokenizer
from graphgen.common.init_storage import init_storage
from graphgen.utils import logger


class PartitionService(BaseOperator):
    """Operator that partitions the knowledge graph into communities.

    The partitioning strategy is chosen at construction time via the
    ``method`` keyword ("bfs", "dfs", "ece", "leiden", "anchor_bfs",
    "triple", or "quintuple"); extra strategy options are forwarded
    through ``method_params``.
    """

    def __init__(
        self,
        working_dir: str = "cache",
        kv_backend: str = "rocksdb",
        graph_backend: str = "kuzu",
        **partition_kwargs,
    ):
        """Initialize storage, tokenizer, and the selected partitioner.

        :param working_dir: directory used by the KV and graph backends.
        :param kv_backend: key-value storage backend name.
        :param graph_backend: graph storage backend name.
        :param partition_kwargs: must contain ``method``; may contain
            ``method_params`` (dict of options forwarded to the partitioner
            and to ``partition()``).
        :raises KeyError: if ``method`` is missing from partition_kwargs.
        :raises ValueError: if ``method`` names an unsupported strategy.
        """
        super().__init__(
            working_dir=working_dir, kv_backend=kv_backend, op_name="partition"
        )
        self.kg_instance: BaseGraphStorage = init_storage(
            backend=graph_backend,
            working_dir=working_dir,
            namespace="graph",
        )
        tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")

        # Imported lazily to keep module import lightweight.
        from graphgen.models import Tokenizer

        self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model)
        method = partition_kwargs["method"]
        self.method_params = partition_kwargs.get("method_params", {})
        self.partitioner = self._build_partitioner(method)

    def _build_partitioner(self, method: str):
        """Return a partitioner instance for *method*.

        Partitioner classes are imported lazily so only the selected
        strategy's dependencies are loaded.
        """
        if method == "bfs":
            from graphgen.models import BFSPartitioner

            return BFSPartitioner()
        if method == "dfs":
            from graphgen.models import DFSPartitioner

            return DFSPartitioner()
        if method == "ece":
            # before ECE partitioning, we need to:
            # 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random
            from graphgen.models import ECEPartitioner

            return ECEPartitioner()
        if method == "leiden":
            from graphgen.models import LeidenPartitioner

            return LeidenPartitioner()
        if method == "anchor_bfs":
            from graphgen.models import AnchorBFSPartitioner

            # Fetch once; an empty/absent anchor_ids list maps to None.
            anchor_ids = self.method_params.get("anchor_ids")
            return AnchorBFSPartitioner(
                anchor_type=self.method_params.get("anchor_type"),
                anchor_ids=set(anchor_ids) if anchor_ids else None,
            )
        if method == "triple":
            from graphgen.models import TriplePartitioner

            return TriplePartitioner()
        if method == "quintuple":
            from graphgen.models import QuintuplePartitioner

            return QuintuplePartitioner()
        raise ValueError(f"Unsupported partition method: {method}")

    def process(self, batch: list) -> Tuple[Iterable[dict], dict]:
        """Partition the reloaded graph and stream one dict per community.

        ``batch`` is ignored — this operator does not consume batch data,
        but the parameter is kept for interface compatibility. Returns a
        lazy generator of ``{"nodes": ..., "edges": ..., "_trace_id": ...}``
        dicts plus an empty stats dict; the community count is logged once
        the generator is exhausted.
        """
        self.kg_instance.reload()

        communities: Iterable = self.partitioner.partition(
            g=self.kg_instance, **self.method_params
        )

        def generator():
            count = 0
            for community in communities:
                count += 1
                # community2batch returns an indexable pair-like value:
                # index 0 -> nodes, index 1 -> edges.
                b = self.partitioner.community2batch(community, g=self.kg_instance)

                result = {
                    "nodes": b[0],
                    "edges": b[1],
                }
                result["_trace_id"] = self.get_trace_id(result)
                yield result
            logger.info("Total communities partitioned: %d", count)

        return generator(), {}