Spaces:
Running
Running
github-actions[bot]
committed on
Commit
·
ad9c5d9
1
Parent(s):
4331db7
Auto-sync from demo at Tue Jan 13 14:51:14 UTC 2026
Browse files
- graphgen/models/kg_builder/light_rag_kg_builder.py +12 -2
- graphgen/models/partitioner/ece_partitioner.py +5 -3
- graphgen/operators/build_kg/build_kg_service.py +18 -5
- graphgen/operators/build_kg/build_mm_kg.py +5 -3
- graphgen/operators/build_kg/build_text_kg.py +5 -3
- graphgen/operators/evaluate/evaluate_service.py +6 -6
- graphgen/operators/partition/partition_service.py +6 -1
- graphgen/operators/quiz/quiz_service.py +33 -45
- requirements.txt +1 -1
graphgen/models/kg_builder/light_rag_kg_builder.py
CHANGED
|
@@ -99,7 +99,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
|
|
| 99 |
self,
|
| 100 |
node_data: tuple[str, List[dict]],
|
| 101 |
kg_instance: BaseGraphStorage,
|
| 102 |
-
) ->
|
| 103 |
entity_name, node_data = node_data
|
| 104 |
entity_types = []
|
| 105 |
source_ids = []
|
|
@@ -131,16 +131,18 @@ class LightRAGKGBuilder(BaseKGBuilder):
|
|
| 131 |
|
| 132 |
node_data = {
|
| 133 |
"entity_type": entity_type,
|
|
|
|
| 134 |
"description": description,
|
| 135 |
"source_id": source_id,
|
| 136 |
}
|
| 137 |
kg_instance.upsert_node(entity_name, node_data=node_data)
|
|
|
|
| 138 |
|
| 139 |
async def merge_edges(
|
| 140 |
self,
|
| 141 |
edges_data: tuple[Tuple[str, str], List[dict]],
|
| 142 |
kg_instance: BaseGraphStorage,
|
| 143 |
-
) ->
|
| 144 |
(src_id, tgt_id), edge_data = edges_data
|
| 145 |
|
| 146 |
source_ids = []
|
|
@@ -175,11 +177,19 @@ class LightRAGKGBuilder(BaseKGBuilder):
|
|
| 175 |
f"({src_id}, {tgt_id})", description
|
| 176 |
)
|
| 177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
kg_instance.upsert_edge(
|
| 179 |
src_id,
|
| 180 |
tgt_id,
|
| 181 |
edge_data={"source_id": source_id, "description": description},
|
| 182 |
)
|
|
|
|
| 183 |
|
| 184 |
async def _handle_kg_summary(
|
| 185 |
self,
|
|
|
|
| 99 |
self,
|
| 100 |
node_data: tuple[str, List[dict]],
|
| 101 |
kg_instance: BaseGraphStorage,
|
| 102 |
+
) -> dict:
|
| 103 |
entity_name, node_data = node_data
|
| 104 |
entity_types = []
|
| 105 |
source_ids = []
|
|
|
|
| 131 |
|
| 132 |
node_data = {
|
| 133 |
"entity_type": entity_type,
|
| 134 |
+
"entity_name": entity_name,
|
| 135 |
"description": description,
|
| 136 |
"source_id": source_id,
|
| 137 |
}
|
| 138 |
kg_instance.upsert_node(entity_name, node_data=node_data)
|
| 139 |
+
return node_data
|
| 140 |
|
| 141 |
async def merge_edges(
|
| 142 |
self,
|
| 143 |
edges_data: tuple[Tuple[str, str], List[dict]],
|
| 144 |
kg_instance: BaseGraphStorage,
|
| 145 |
+
) -> dict:
|
| 146 |
(src_id, tgt_id), edge_data = edges_data
|
| 147 |
|
| 148 |
source_ids = []
|
|
|
|
| 177 |
f"({src_id}, {tgt_id})", description
|
| 178 |
)
|
| 179 |
|
| 180 |
+
edge_data = {
|
| 181 |
+
"src_id": src_id,
|
| 182 |
+
"tgt_id": tgt_id,
|
| 183 |
+
"description": description,
|
| 184 |
+
"source_id": source_id, # for traceability
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
kg_instance.upsert_edge(
|
| 188 |
src_id,
|
| 189 |
tgt_id,
|
| 190 |
edge_data={"source_id": source_id, "description": description},
|
| 191 |
)
|
| 192 |
+
return edge_data
|
| 193 |
|
| 194 |
async def _handle_kg_summary(
|
| 195 |
self,
|
graphgen/models/partitioner/ece_partitioner.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import random
|
| 2 |
from collections import deque
|
| 3 |
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
|
@@ -34,17 +35,18 @@ class ECEPartitioner(BFSPartitioner):
|
|
| 34 |
:param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
|
| 35 |
:return: sorted units
|
| 36 |
"""
|
|
|
|
| 37 |
if edge_sampling == "random":
|
| 38 |
random.shuffle(units)
|
| 39 |
elif edge_sampling == "min_loss":
|
| 40 |
units = sorted(
|
| 41 |
units,
|
| 42 |
-
key=lambda x: x[-1]
|
| 43 |
)
|
| 44 |
elif edge_sampling == "max_loss":
|
| 45 |
units = sorted(
|
| 46 |
units,
|
| 47 |
-
key=lambda x: x[-1]
|
| 48 |
reverse=True,
|
| 49 |
)
|
| 50 |
else:
|
|
@@ -142,7 +144,7 @@ class ECEPartitioner(BFSPartitioner):
|
|
| 142 |
return Community(
|
| 143 |
id=seed_unit[1],
|
| 144 |
nodes=list(community_nodes.keys()),
|
| 145 |
-
edges=[tuple(sorted(e)) for e in community_edges]
|
| 146 |
)
|
| 147 |
|
| 148 |
for unit in tqdm(all_units, desc="ECE partition"):
|
|
|
|
| 1 |
+
import math
|
| 2 |
import random
|
| 3 |
from collections import deque
|
| 4 |
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
|
|
|
| 35 |
:param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
|
| 36 |
:return: sorted units
|
| 37 |
"""
|
| 38 |
+
default_loss = -math.log(0.1)
|
| 39 |
if edge_sampling == "random":
|
| 40 |
random.shuffle(units)
|
| 41 |
elif edge_sampling == "min_loss":
|
| 42 |
units = sorted(
|
| 43 |
units,
|
| 44 |
+
key=lambda x: x[-1].get("loss", default_loss),
|
| 45 |
)
|
| 46 |
elif edge_sampling == "max_loss":
|
| 47 |
units = sorted(
|
| 48 |
units,
|
| 49 |
+
key=lambda x: x[-1].get("loss", default_loss),
|
| 50 |
reverse=True,
|
| 51 |
)
|
| 52 |
else:
|
|
|
|
| 144 |
return Community(
|
| 145 |
id=seed_unit[1],
|
| 146 |
nodes=list(community_nodes.keys()),
|
| 147 |
+
edges=[tuple(sorted(e)) for e in community_edges],
|
| 148 |
)
|
| 149 |
|
| 150 |
for unit in tqdm(all_units, desc="ECE partition"):
|
graphgen/operators/build_kg/build_kg_service.py
CHANGED
|
@@ -28,10 +28,13 @@ class BuildKGService(BaseOperator):
|
|
| 28 |
docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs]
|
| 29 |
|
| 30 |
# consume the chunks and build kg
|
| 31 |
-
self.build_kg(docs)
|
| 32 |
-
return pd.DataFrame(
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
def build_kg(self, chunks: List[Chunk]) ->
|
| 35 |
"""
|
| 36 |
Build knowledge graph (KG) and merge into kg_instance
|
| 37 |
"""
|
|
@@ -42,24 +45,34 @@ class BuildKGService(BaseOperator):
|
|
| 42 |
if chunk.type in ("image", "video", "table", "formula")
|
| 43 |
]
|
| 44 |
|
|
|
|
|
|
|
|
|
|
| 45 |
if len(text_chunks) == 0:
|
| 46 |
logger.info("All text chunks are already in the storage")
|
| 47 |
else:
|
| 48 |
logger.info("[Text Entity and Relation Extraction] processing ...")
|
| 49 |
-
build_text_kg(
|
| 50 |
llm_client=self.llm_client,
|
| 51 |
kg_instance=self.graph_storage,
|
| 52 |
chunks=text_chunks,
|
| 53 |
max_loop=self.max_loop,
|
| 54 |
)
|
|
|
|
|
|
|
| 55 |
if len(mm_chunks) == 0:
|
| 56 |
logger.info("All multi-modal chunks are already in the storage")
|
| 57 |
else:
|
| 58 |
logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
|
| 59 |
-
build_mm_kg(
|
| 60 |
llm_client=self.llm_client,
|
| 61 |
kg_instance=self.graph_storage,
|
| 62 |
chunks=mm_chunks,
|
| 63 |
)
|
|
|
|
|
|
|
| 64 |
|
| 65 |
self.graph_storage.index_done_callback()
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs]
|
| 29 |
|
| 30 |
# consume the chunks and build kg
|
| 31 |
+
nodes, edges = self.build_kg(docs)
|
| 32 |
+
return pd.DataFrame(
|
| 33 |
+
[{"node": node, "edge": []} for node in nodes]
|
| 34 |
+
+ [{"node": [], "edge": edge} for edge in edges]
|
| 35 |
+
)
|
| 36 |
|
| 37 |
+
def build_kg(self, chunks: List[Chunk]) -> tuple:
|
| 38 |
"""
|
| 39 |
Build knowledge graph (KG) and merge into kg_instance
|
| 40 |
"""
|
|
|
|
| 45 |
if chunk.type in ("image", "video", "table", "formula")
|
| 46 |
]
|
| 47 |
|
| 48 |
+
nodes = []
|
| 49 |
+
edges = []
|
| 50 |
+
|
| 51 |
if len(text_chunks) == 0:
|
| 52 |
logger.info("All text chunks are already in the storage")
|
| 53 |
else:
|
| 54 |
logger.info("[Text Entity and Relation Extraction] processing ...")
|
| 55 |
+
text_nodes, text_edges = build_text_kg(
|
| 56 |
llm_client=self.llm_client,
|
| 57 |
kg_instance=self.graph_storage,
|
| 58 |
chunks=text_chunks,
|
| 59 |
max_loop=self.max_loop,
|
| 60 |
)
|
| 61 |
+
nodes += text_nodes
|
| 62 |
+
edges += text_edges
|
| 63 |
if len(mm_chunks) == 0:
|
| 64 |
logger.info("All multi-modal chunks are already in the storage")
|
| 65 |
else:
|
| 66 |
logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
|
| 67 |
+
mm_nodes, mm_edges = build_mm_kg(
|
| 68 |
llm_client=self.llm_client,
|
| 69 |
kg_instance=self.graph_storage,
|
| 70 |
chunks=mm_chunks,
|
| 71 |
)
|
| 72 |
+
nodes += mm_nodes
|
| 73 |
+
edges += mm_edges
|
| 74 |
|
| 75 |
self.graph_storage.index_done_callback()
|
| 76 |
+
logger.info("Knowledge graph building completed.")
|
| 77 |
+
|
| 78 |
+
return nodes, edges
|
graphgen/operators/build_kg/build_mm_kg.py
CHANGED
|
@@ -12,7 +12,7 @@ def build_mm_kg(
|
|
| 12 |
llm_client: BaseLLMWrapper,
|
| 13 |
kg_instance: BaseGraphStorage,
|
| 14 |
chunks: List[Chunk],
|
| 15 |
-
):
|
| 16 |
"""
|
| 17 |
Build multi-modal KG and merge into kg_instance
|
| 18 |
:param llm_client: Synthesizer LLM model to extract entities and relationships
|
|
@@ -37,14 +37,16 @@ def build_mm_kg(
|
|
| 37 |
for k, v in e.items():
|
| 38 |
edges[tuple(sorted(k))].extend(v)
|
| 39 |
|
| 40 |
-
run_concurrent(
|
| 41 |
lambda kv: mm_builder.merge_nodes(kv, kg_instance=kg_instance),
|
| 42 |
list(nodes.items()),
|
| 43 |
desc="Inserting entities into storage",
|
| 44 |
)
|
| 45 |
|
| 46 |
-
run_concurrent(
|
| 47 |
lambda kv: mm_builder.merge_edges(kv, kg_instance=kg_instance),
|
| 48 |
list(edges.items()),
|
| 49 |
desc="Inserting relationships into storage",
|
| 50 |
)
|
|
|
|
|
|
|
|
|
| 12 |
llm_client: BaseLLMWrapper,
|
| 13 |
kg_instance: BaseGraphStorage,
|
| 14 |
chunks: List[Chunk],
|
| 15 |
+
) -> tuple:
|
| 16 |
"""
|
| 17 |
Build multi-modal KG and merge into kg_instance
|
| 18 |
:param llm_client: Synthesizer LLM model to extract entities and relationships
|
|
|
|
| 37 |
for k, v in e.items():
|
| 38 |
edges[tuple(sorted(k))].extend(v)
|
| 39 |
|
| 40 |
+
nodes = run_concurrent(
|
| 41 |
lambda kv: mm_builder.merge_nodes(kv, kg_instance=kg_instance),
|
| 42 |
list(nodes.items()),
|
| 43 |
desc="Inserting entities into storage",
|
| 44 |
)
|
| 45 |
|
| 46 |
+
edges = run_concurrent(
|
| 47 |
lambda kv: mm_builder.merge_edges(kv, kg_instance=kg_instance),
|
| 48 |
list(edges.items()),
|
| 49 |
desc="Inserting relationships into storage",
|
| 50 |
)
|
| 51 |
+
|
| 52 |
+
return nodes, edges
|
graphgen/operators/build_kg/build_text_kg.py
CHANGED
|
@@ -13,7 +13,7 @@ def build_text_kg(
|
|
| 13 |
kg_instance: BaseGraphStorage,
|
| 14 |
chunks: List[Chunk],
|
| 15 |
max_loop: int = 3,
|
| 16 |
-
):
|
| 17 |
"""
|
| 18 |
:param llm_client: Synthesizer LLM model to extract entities and relationships
|
| 19 |
:param kg_instance
|
|
@@ -39,14 +39,16 @@ def build_text_kg(
|
|
| 39 |
for k, v in e.items():
|
| 40 |
edges[tuple(sorted(k))].extend(v)
|
| 41 |
|
| 42 |
-
run_concurrent(
|
| 43 |
lambda kv: kg_builder.merge_nodes(kv, kg_instance=kg_instance),
|
| 44 |
list(nodes.items()),
|
| 45 |
desc="Inserting entities into storage",
|
| 46 |
)
|
| 47 |
|
| 48 |
-
run_concurrent(
|
| 49 |
lambda kv: kg_builder.merge_edges(kv, kg_instance=kg_instance),
|
| 50 |
list(edges.items()),
|
| 51 |
desc="Inserting relationships into storage",
|
| 52 |
)
|
|
|
|
|
|
|
|
|
| 13 |
kg_instance: BaseGraphStorage,
|
| 14 |
chunks: List[Chunk],
|
| 15 |
max_loop: int = 3,
|
| 16 |
+
) -> tuple:
|
| 17 |
"""
|
| 18 |
:param llm_client: Synthesizer LLM model to extract entities and relationships
|
| 19 |
:param kg_instance
|
|
|
|
| 39 |
for k, v in e.items():
|
| 40 |
edges[tuple(sorted(k))].extend(v)
|
| 41 |
|
| 42 |
+
nodes = run_concurrent(
|
| 43 |
lambda kv: kg_builder.merge_nodes(kv, kg_instance=kg_instance),
|
| 44 |
list(nodes.items()),
|
| 45 |
desc="Inserting entities into storage",
|
| 46 |
)
|
| 47 |
|
| 48 |
+
edges = run_concurrent(
|
| 49 |
lambda kv: kg_builder.merge_edges(kv, kg_instance=kg_instance),
|
| 50 |
list(edges.items()),
|
| 51 |
desc="Inserting relationships into storage",
|
| 52 |
)
|
| 53 |
+
|
| 54 |
+
return nodes, edges
|
graphgen/operators/evaluate/evaluate_service.py
CHANGED
|
@@ -95,10 +95,10 @@ class EvaluateService(BaseOperator):
|
|
| 95 |
answer=str(item.get("answer", "")),
|
| 96 |
)
|
| 97 |
if not qa_pair.question or not qa_pair.answer:
|
| 98 |
-
|
| 99 |
return {}
|
| 100 |
except Exception as e:
|
| 101 |
-
|
| 102 |
return {}
|
| 103 |
|
| 104 |
for metric, evaluator in self.qa_evaluators.items():
|
|
@@ -110,7 +110,7 @@ class EvaluateService(BaseOperator):
|
|
| 110 |
else:
|
| 111 |
item[metric] = float(score)
|
| 112 |
except Exception as e:
|
| 113 |
-
|
| 114 |
item[metric] = None
|
| 115 |
return item
|
| 116 |
|
|
@@ -136,7 +136,7 @@ class EvaluateService(BaseOperator):
|
|
| 136 |
return []
|
| 137 |
|
| 138 |
if not self.qa_evaluators:
|
| 139 |
-
|
| 140 |
return []
|
| 141 |
|
| 142 |
items = transform_messages_format(items)
|
|
@@ -155,11 +155,11 @@ class EvaluateService(BaseOperator):
|
|
| 155 |
|
| 156 |
for metric, evaluator in self.kg_evaluators.items():
|
| 157 |
try:
|
| 158 |
-
|
| 159 |
score = evaluator.evaluate()
|
| 160 |
results[metric] = score
|
| 161 |
except Exception as e:
|
| 162 |
-
|
| 163 |
results[metric] = {"error": str(e)}
|
| 164 |
return results
|
| 165 |
|
|
|
|
| 95 |
answer=str(item.get("answer", "")),
|
| 96 |
)
|
| 97 |
if not qa_pair.question or not qa_pair.answer:
|
| 98 |
+
logger.error("Empty question or answer, skipping.")
|
| 99 |
return {}
|
| 100 |
except Exception as e:
|
| 101 |
+
logger.error("Error in QAPair creation: %s", str(e))
|
| 102 |
return {}
|
| 103 |
|
| 104 |
for metric, evaluator in self.qa_evaluators.items():
|
|
|
|
| 110 |
else:
|
| 111 |
item[metric] = float(score)
|
| 112 |
except Exception as e:
|
| 113 |
+
logger.error("Error in %s evaluation: %s", metric, str(e))
|
| 114 |
item[metric] = None
|
| 115 |
return item
|
| 116 |
|
|
|
|
| 136 |
return []
|
| 137 |
|
| 138 |
if not self.qa_evaluators:
|
| 139 |
+
logger.warning("No QA evaluators initialized, skipping QA evaluation")
|
| 140 |
return []
|
| 141 |
|
| 142 |
items = transform_messages_format(items)
|
|
|
|
| 155 |
|
| 156 |
for metric, evaluator in self.kg_evaluators.items():
|
| 157 |
try:
|
| 158 |
+
logger.info("Running %s evaluation...", metric)
|
| 159 |
score = evaluator.evaluate()
|
| 160 |
results[metric] = score
|
| 161 |
except Exception as e:
|
| 162 |
+
logger.error("Error in %s evaluation: %s", metric, str(e))
|
| 163 |
results[metric] = {"error": str(e)}
|
| 164 |
return results
|
| 165 |
|
graphgen/operators/partition/partition_service.py
CHANGED
|
@@ -79,9 +79,13 @@ class PartitionService(BaseOperator):
|
|
| 79 |
else:
|
| 80 |
raise ValueError(f"Unsupported partition method: {method}")
|
| 81 |
|
| 82 |
-
communities = partitioner.partition(
|
|
|
|
|
|
|
| 83 |
|
|
|
|
| 84 |
for community in communities:
|
|
|
|
| 85 |
batch = partitioner.community2batch(community, g=self.kg_instance)
|
| 86 |
batch = self._attach_additional_data_to_node(batch)
|
| 87 |
|
|
@@ -91,6 +95,7 @@ class PartitionService(BaseOperator):
|
|
| 91 |
"edges": [batch[1]],
|
| 92 |
}
|
| 93 |
)
|
|
|
|
| 94 |
|
| 95 |
def _pre_tokenize(self) -> None:
|
| 96 |
"""Pre-tokenize all nodes and edges to add token length information."""
|
|
|
|
| 79 |
else:
|
| 80 |
raise ValueError(f"Unsupported partition method: {method}")
|
| 81 |
|
| 82 |
+
communities: Iterable = partitioner.partition(
|
| 83 |
+
g=self.kg_instance, **method_params
|
| 84 |
+
)
|
| 85 |
|
| 86 |
+
count = 0
|
| 87 |
for community in communities:
|
| 88 |
+
count += 1
|
| 89 |
batch = partitioner.community2batch(community, g=self.kg_instance)
|
| 90 |
batch = self._attach_additional_data_to_node(batch)
|
| 91 |
|
|
|
|
| 95 |
"edges": [batch[1]],
|
| 96 |
}
|
| 97 |
)
|
| 98 |
+
logger.info("Total communities partitioned: %d", count)
|
| 99 |
|
| 100 |
def _pre_tokenize(self) -> None:
|
| 101 |
"""Pre-tokenize all nodes and edges to add token length information."""
|
graphgen/operators/quiz/quiz_service.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
from collections.abc import Iterable
|
| 2 |
-
|
| 3 |
import pandas as pd
|
| 4 |
|
| 5 |
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper, BaseOperator
|
|
@@ -15,7 +13,6 @@ class QuizService(BaseOperator):
|
|
| 15 |
graph_backend: str = "kuzu",
|
| 16 |
kv_backend: str = "rocksdb",
|
| 17 |
quiz_samples: int = 1,
|
| 18 |
-
concurrency_limit: int = 200,
|
| 19 |
):
|
| 20 |
super().__init__(working_dir=working_dir, op_name="quiz_service")
|
| 21 |
self.quiz_samples = quiz_samples
|
|
@@ -28,21 +25,16 @@ class QuizService(BaseOperator):
|
|
| 28 |
backend=kv_backend, working_dir=working_dir, namespace="quiz"
|
| 29 |
)
|
| 30 |
self.generator = QuizGenerator(self.llm_client)
|
| 31 |
-
self.concurrency_limit = concurrency_limit
|
| 32 |
|
| 33 |
-
def process(self, batch: pd.DataFrame) ->
|
| 34 |
-
|
| 35 |
-
# but for compatibility we keep the interface
|
| 36 |
-
_ = batch.to_dict(orient="records")
|
| 37 |
self.graph_storage.reload()
|
| 38 |
-
|
| 39 |
|
| 40 |
async def _process_single_quiz(self, item: tuple) -> dict | None:
|
| 41 |
# if quiz in quiz_storage exists already, directly get it
|
| 42 |
index, desc = item
|
| 43 |
_quiz_id = compute_dict_hash({"index": index, "description": desc})
|
| 44 |
-
if self.quiz_storage.get_by_id(_quiz_id):
|
| 45 |
-
return None
|
| 46 |
|
| 47 |
tasks = []
|
| 48 |
for i in range(self.quiz_samples):
|
|
@@ -68,47 +60,43 @@ class QuizService(BaseOperator):
|
|
| 68 |
logger.error("Error when quizzing description %s: %s", item, e)
|
| 69 |
return None
|
| 70 |
|
| 71 |
-
def quiz(self) ->
|
| 72 |
"""
|
| 73 |
Get all nodes and edges and quiz their descriptions using QuizGenerator.
|
| 74 |
"""
|
| 75 |
-
edges = self.graph_storage.get_all_edges()
|
| 76 |
-
nodes = self.graph_storage.get_all_nodes()
|
| 77 |
-
|
| 78 |
items = []
|
| 79 |
|
| 80 |
-
for
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
items.append(((edge[0], edge[1]), desc))
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
logger.info("Total descriptions to quiz: %d", len(items))
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
)
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
final_results.append(new_result)
|
| 113 |
-
self.quiz_storage.index_done_callback()
|
| 114 |
-
yield pd.DataFrame(final_results)
|
|
|
|
|
|
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
|
| 3 |
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper, BaseOperator
|
|
|
|
| 13 |
graph_backend: str = "kuzu",
|
| 14 |
kv_backend: str = "rocksdb",
|
| 15 |
quiz_samples: int = 1,
|
|
|
|
| 16 |
):
|
| 17 |
super().__init__(working_dir=working_dir, op_name="quiz_service")
|
| 18 |
self.quiz_samples = quiz_samples
|
|
|
|
| 25 |
backend=kv_backend, working_dir=working_dir, namespace="quiz"
|
| 26 |
)
|
| 27 |
self.generator = QuizGenerator(self.llm_client)
|
|
|
|
| 28 |
|
| 29 |
+
def process(self, batch: pd.DataFrame) -> pd.DataFrame:
|
| 30 |
+
data = batch.to_dict(orient="records")
|
|
|
|
|
|
|
| 31 |
self.graph_storage.reload()
|
| 32 |
+
return self.quiz(data)
|
| 33 |
|
| 34 |
async def _process_single_quiz(self, item: tuple) -> dict | None:
|
| 35 |
# if quiz in quiz_storage exists already, directly get it
|
| 36 |
index, desc = item
|
| 37 |
_quiz_id = compute_dict_hash({"index": index, "description": desc})
|
|
|
|
|
|
|
| 38 |
|
| 39 |
tasks = []
|
| 40 |
for i in range(self.quiz_samples):
|
|
|
|
| 60 |
logger.error("Error when quizzing description %s: %s", item, e)
|
| 61 |
return None
|
| 62 |
|
| 63 |
+
def quiz(self, batch) -> pd.DataFrame:
|
| 64 |
"""
|
| 65 |
Get all nodes and edges and quiz their descriptions using QuizGenerator.
|
| 66 |
"""
|
|
|
|
|
|
|
|
|
|
| 67 |
items = []
|
| 68 |
|
| 69 |
+
for item in batch:
|
| 70 |
+
node_data = item.get("node", [])
|
| 71 |
+
edge_data = item.get("edge", [])
|
|
|
|
| 72 |
|
| 73 |
+
if node_data:
|
| 74 |
+
node_id = node_data["entity_name"]
|
| 75 |
+
desc = node_data["description"]
|
| 76 |
+
items.append((node_id, desc))
|
| 77 |
+
if edge_data:
|
| 78 |
+
edge_key = (edge_data["src_id"], edge_data["tgt_id"])
|
| 79 |
+
desc = edge_data["description"]
|
| 80 |
+
items.append((edge_key, desc))
|
| 81 |
|
| 82 |
logger.info("Total descriptions to quiz: %d", len(items))
|
| 83 |
|
| 84 |
+
results = run_concurrent(
|
| 85 |
+
self._process_single_quiz,
|
| 86 |
+
items,
|
| 87 |
+
desc=f"Quizzing batch of {len(items)} descriptions",
|
| 88 |
+
unit="description",
|
| 89 |
+
)
|
| 90 |
+
valid_results = [res for res in results if res]
|
|
|
|
| 91 |
|
| 92 |
+
for res in valid_results:
|
| 93 |
+
self.quiz_storage.upsert(
|
| 94 |
+
{
|
| 95 |
+
res["_quiz_id"]: {
|
| 96 |
+
"description": res["description"],
|
| 97 |
+
"quizzes": res["quizzes"],
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
)
|
| 101 |
+
self.quiz_storage.index_done_callback()
|
| 102 |
+
return pd.DataFrame(valid_results)
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -22,7 +22,7 @@ trafilatura
|
|
| 22 |
aiohttp
|
| 23 |
socksio
|
| 24 |
pydantic
|
| 25 |
-
ray==2.
|
| 26 |
pyarrow
|
| 27 |
|
| 28 |
leidenalg
|
|
|
|
| 22 |
aiohttp
|
| 23 |
socksio
|
| 24 |
pydantic
|
| 25 |
+
ray==2.53.0
|
| 26 |
pyarrow
|
| 27 |
|
| 28 |
leidenalg
|