Spaces:
Sleeping
Sleeping
github-actions[bot]
commited on
Commit
·
0bd1b0f
1
Parent(s):
b275e29
Auto-sync from demo at Thu Jan 29 12:51:48 UTC 2026
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- graphgen/bases/__init__.py +1 -1
- graphgen/bases/base_evaluator.py +21 -2
- graphgen/bases/base_generator.py +34 -49
- graphgen/bases/base_operator.py +110 -9
- graphgen/bases/base_storage.py +6 -0
- graphgen/bases/datatypes.py +7 -0
- graphgen/common/init_storage.py +16 -4
- graphgen/models/__init__.py +0 -8
- graphgen/models/evaluator/__init__.py +2 -1
- graphgen/models/evaluator/kg/__init__.py +0 -17
- graphgen/models/evaluator/kg/accuracy_evaluator.py +0 -350
- graphgen/models/evaluator/kg/consistency_evaluator.py +0 -388
- graphgen/models/evaluator/kg/structure_evaluator.py +15 -15
- graphgen/models/evaluator/qa/length_evaluator.py +8 -7
- graphgen/models/evaluator/qa/mtld_evaluator.py +7 -6
- graphgen/models/evaluator/qa/reward_evaluator.py +12 -8
- graphgen/models/evaluator/qa/uni_evaluator.py +17 -9
- graphgen/models/evaluator/triple/__init__.py +1 -0
- graphgen/models/evaluator/triple/accuracy_evaluator.py +94 -0
- graphgen/models/extractor/schema_guided_extractor.py +5 -33
- graphgen/models/generator/aggregated_generator.py +7 -11
- graphgen/models/generator/atomic_generator.py +4 -9
- graphgen/models/generator/cot_generator.py +6 -9
- graphgen/models/generator/fill_in_blank_generator.py +11 -11
- graphgen/models/generator/multi_answer_generator.py +14 -12
- graphgen/models/generator/multi_choice_generator.py +11 -11
- graphgen/models/generator/multi_hop_generator.py +4 -9
- graphgen/models/generator/quiz_generator.py +18 -14
- graphgen/models/generator/true_false_generator.py +10 -10
- graphgen/models/generator/vqa_generator.py +50 -65
- graphgen/models/kg_builder/light_rag_kg_builder.py +14 -3
- graphgen/models/kg_builder/mm_kg_builder.py +2 -0
- graphgen/models/reader/csv_reader.py +1 -1
- graphgen/models/reader/json_reader.py +4 -1
- graphgen/models/reader/parquet_reader.py +1 -1
- graphgen/models/reader/rdf_reader.py +1 -1
- graphgen/models/reader/txt_reader.py +2 -1
- graphgen/models/storage/__init__.py +0 -6
- graphgen/models/storage/rocksdb_cache.py +0 -43
- graphgen/models/vis/__init__.py +0 -0
- graphgen/models/vis/community_visualizer.py +0 -48
- graphgen/operators/build_kg/build_kg_service.py +46 -18
- graphgen/operators/build_kg/build_text_kg.py +1 -0
- graphgen/operators/chunk/chunk_service.py +39 -45
- graphgen/operators/evaluate/evaluate_kg.py +15 -0
- graphgen/operators/evaluate/evaluate_qa.py +107 -0
- graphgen/operators/evaluate/evaluate_service.py +120 -150
- graphgen/operators/evaluate/evaluate_triple.py +39 -0
- graphgen/operators/extract/extract_service.py +31 -20
- graphgen/operators/generate/generate_service.py +30 -28
graphgen/bases/__init__.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from .base_extractor import BaseExtractor
|
| 2 |
from .base_generator import BaseGenerator
|
| 3 |
from .base_kg_builder import BaseKGBuilder
|
|
@@ -9,5 +10,4 @@ from .base_searcher import BaseSearcher
|
|
| 9 |
from .base_splitter import BaseSplitter
|
| 10 |
from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
|
| 11 |
from .base_tokenizer import BaseTokenizer
|
| 12 |
-
from .base_evaluator import BaseEvaluator
|
| 13 |
from .datatypes import Chunk, Config, Node, QAPair, Token
|
|
|
|
| 1 |
+
from .base_evaluator import BaseKGEvaluator, BaseQAEvaluator, BaseTripleEvaluator
|
| 2 |
from .base_extractor import BaseExtractor
|
| 3 |
from .base_generator import BaseGenerator
|
| 4 |
from .base_kg_builder import BaseKGBuilder
|
|
|
|
| 10 |
from .base_splitter import BaseSplitter
|
| 11 |
from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
|
| 12 |
from .base_tokenizer import BaseTokenizer
|
|
|
|
| 13 |
from .datatypes import Chunk, Config, Node, QAPair, Token
|
graphgen/bases/base_evaluator.py
CHANGED
|
@@ -1,10 +1,29 @@
|
|
| 1 |
from abc import ABC, abstractmethod
|
|
|
|
|
|
|
|
|
|
| 2 |
from .datatypes import QAPair
|
| 3 |
|
| 4 |
|
| 5 |
-
class
|
| 6 |
@abstractmethod
|
| 7 |
-
def evaluate(self, pair: QAPair) -> float:
|
| 8 |
"""
|
| 9 |
Evaluate the text and return a score.
|
| 10 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
from .base_storage import BaseGraphStorage
|
| 5 |
from .datatypes import QAPair
|
| 6 |
|
| 7 |
|
| 8 |
+
class BaseQAEvaluator(ABC):
|
| 9 |
@abstractmethod
|
| 10 |
+
async def evaluate(self, pair: QAPair) -> dict[str, float]:
|
| 11 |
"""
|
| 12 |
Evaluate the text and return a score.
|
| 13 |
"""
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseKGEvaluator(ABC):
|
| 17 |
+
@abstractmethod
|
| 18 |
+
def evaluate(self, kg: BaseGraphStorage) -> dict[str, Any]:
|
| 19 |
+
"""
|
| 20 |
+
Evaluate the whole graph and return a dict of scores.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class BaseTripleEvaluator(ABC):
|
| 25 |
+
@abstractmethod
|
| 26 |
+
async def evaluate(self, unit: dict) -> dict[str, float]:
|
| 27 |
+
"""
|
| 28 |
+
Evaluate a node/edge and return a score.
|
| 29 |
+
"""
|
graphgen/bases/base_generator.py
CHANGED
|
@@ -21,7 +21,7 @@ class BaseGenerator(ABC):
|
|
| 21 |
|
| 22 |
@staticmethod
|
| 23 |
@abstractmethod
|
| 24 |
-
def parse_response(response: str) ->
|
| 25 |
"""Parse the LLM response and return the generated QAs"""
|
| 26 |
|
| 27 |
async def generate(
|
|
@@ -29,64 +29,49 @@ class BaseGenerator(ABC):
|
|
| 29 |
batch: tuple[
|
| 30 |
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
|
| 31 |
],
|
| 32 |
-
) -> dict
|
| 33 |
"""
|
| 34 |
Generate QAs based on a given batch.
|
| 35 |
:param batch
|
| 36 |
:return: QA pairs
|
| 37 |
"""
|
| 38 |
-
result = {}
|
| 39 |
prompt = self.build_prompt(batch)
|
| 40 |
response = await self.llm_client.generate_answer(prompt)
|
| 41 |
qa_pairs = self.parse_response(response) # generate one or more QA pairs
|
| 42 |
-
|
| 43 |
-
return result
|
| 44 |
|
| 45 |
@staticmethod
|
| 46 |
def format_generation_results(
|
| 47 |
-
|
| 48 |
-
) ->
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
options = qa_data["options"]
|
| 57 |
-
options_str = "\n".join(
|
| 58 |
-
[f"{key}. {options[key]}" for key in sorted(options.keys())]
|
| 59 |
-
)
|
| 60 |
-
question += f"\nOptions:\n{options_str}"
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
}
|
| 78 |
-
)
|
| 79 |
-
elif output_data_format == "ChatML":
|
| 80 |
-
flat_results.append(
|
| 81 |
-
{
|
| 82 |
-
"messages": [
|
| 83 |
-
{"role": "user", "content": question},
|
| 84 |
-
{"role": "assistant", "content": answer},
|
| 85 |
-
]
|
| 86 |
-
}
|
| 87 |
-
)
|
| 88 |
-
else:
|
| 89 |
-
raise ValueError(
|
| 90 |
-
f"Unknown output data format: {output_data_format}"
|
| 91 |
-
)
|
| 92 |
-
return flat_results
|
|
|
|
| 21 |
|
| 22 |
@staticmethod
|
| 23 |
@abstractmethod
|
| 24 |
+
def parse_response(response: str) -> list[dict]:
|
| 25 |
"""Parse the LLM response and return the generated QAs"""
|
| 26 |
|
| 27 |
async def generate(
|
|
|
|
| 29 |
batch: tuple[
|
| 30 |
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
|
| 31 |
],
|
| 32 |
+
) -> list[dict]:
|
| 33 |
"""
|
| 34 |
Generate QAs based on a given batch.
|
| 35 |
:param batch
|
| 36 |
:return: QA pairs
|
| 37 |
"""
|
|
|
|
| 38 |
prompt = self.build_prompt(batch)
|
| 39 |
response = await self.llm_client.generate_answer(prompt)
|
| 40 |
qa_pairs = self.parse_response(response) # generate one or more QA pairs
|
| 41 |
+
return qa_pairs
|
|
|
|
| 42 |
|
| 43 |
@staticmethod
|
| 44 |
def format_generation_results(
|
| 45 |
+
result: dict, output_data_format: str
|
| 46 |
+
) -> dict[str, Any]:
|
| 47 |
+
question = result.get("question", "")
|
| 48 |
+
answer = result.get("answer", "")
|
| 49 |
+
if "options" in result and result["options"]:
|
| 50 |
+
options = result["options"]
|
| 51 |
+
options_str = "\n".join(
|
| 52 |
+
[f"{key}. {options[key]}" for key in sorted(options.keys())]
|
| 53 |
+
)
|
| 54 |
+
question += f"\nOptions:\n{options_str}"
|
| 55 |
|
| 56 |
+
if output_data_format == "Alpaca":
|
| 57 |
+
return {
|
| 58 |
+
"instruction": question,
|
| 59 |
+
"input": "",
|
| 60 |
+
"output": answer,
|
| 61 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
+
if output_data_format == "Sharegpt":
|
| 64 |
+
return {
|
| 65 |
+
"conversations": [
|
| 66 |
+
{"from": "human", "value": question},
|
| 67 |
+
{"from": "gpt", "value": answer},
|
| 68 |
+
]
|
| 69 |
+
}
|
| 70 |
+
if output_data_format == "ChatML":
|
| 71 |
+
return {
|
| 72 |
+
"messages": [
|
| 73 |
+
{"role": "user", "content": question},
|
| 74 |
+
{"role": "assistant", "content": answer},
|
| 75 |
+
]
|
| 76 |
+
}
|
| 77 |
+
raise ValueError(f"Unknown output data format: {output_data_format}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/bases/base_operator.py
CHANGED
|
@@ -1,19 +1,43 @@
|
|
| 1 |
import inspect
|
| 2 |
import os
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
-
from typing import Iterable, Union
|
| 5 |
|
|
|
|
| 6 |
import pandas as pd
|
| 7 |
import ray
|
| 8 |
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
class BaseOperator(ABC):
|
| 11 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
# lazy import to avoid circular import
|
|
|
|
| 13 |
from graphgen.utils import set_logger
|
| 14 |
|
| 15 |
log_dir = os.path.join(working_dir, "logs")
|
| 16 |
self.op_name = op_name or self.__class__.__name__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
try:
|
| 19 |
ctx = ray.get_runtime_context()
|
|
@@ -45,17 +69,94 @@ class BaseOperator(ABC):
|
|
| 45 |
|
| 46 |
logger_token = CURRENT_LOGGER_VAR.set(self.logger)
|
| 47 |
try:
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
if inspect.isgenerator(result):
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
else:
|
| 52 |
-
yield result
|
|
|
|
| 53 |
finally:
|
| 54 |
CURRENT_LOGGER_VAR.reset(logger_token)
|
| 55 |
|
| 56 |
-
@abstractmethod
|
| 57 |
-
def process(self, batch):
|
| 58 |
-
raise NotImplementedError("Subclasses must implement the process method.")
|
| 59 |
-
|
| 60 |
def get_logger(self):
|
| 61 |
return self.logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import inspect
|
| 2 |
import os
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Iterable, Tuple, Union
|
| 5 |
|
| 6 |
+
import numpy as np
|
| 7 |
import pandas as pd
|
| 8 |
import ray
|
| 9 |
|
| 10 |
|
| 11 |
+
def convert_to_serializable(obj):
|
| 12 |
+
if isinstance(obj, np.ndarray):
|
| 13 |
+
return obj.tolist()
|
| 14 |
+
if isinstance(obj, np.generic):
|
| 15 |
+
return obj.item()
|
| 16 |
+
if isinstance(obj, dict):
|
| 17 |
+
return {k: convert_to_serializable(v) for k, v in obj.items()}
|
| 18 |
+
if isinstance(obj, list):
|
| 19 |
+
return [convert_to_serializable(v) for v in obj]
|
| 20 |
+
return obj
|
| 21 |
+
|
| 22 |
+
|
| 23 |
class BaseOperator(ABC):
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
working_dir: str = "cache",
|
| 27 |
+
kv_backend: str = "rocksdb",
|
| 28 |
+
op_name: str = None,
|
| 29 |
+
):
|
| 30 |
# lazy import to avoid circular import
|
| 31 |
+
from graphgen.common import init_storage
|
| 32 |
from graphgen.utils import set_logger
|
| 33 |
|
| 34 |
log_dir = os.path.join(working_dir, "logs")
|
| 35 |
self.op_name = op_name or self.__class__.__name__
|
| 36 |
+
self.working_dir = working_dir
|
| 37 |
+
self.kv_backend = kv_backend
|
| 38 |
+
self.kv_storage = init_storage(
|
| 39 |
+
backend=kv_backend, working_dir=working_dir, namespace=self.op_name
|
| 40 |
+
)
|
| 41 |
|
| 42 |
try:
|
| 43 |
ctx = ray.get_runtime_context()
|
|
|
|
| 69 |
|
| 70 |
logger_token = CURRENT_LOGGER_VAR.set(self.logger)
|
| 71 |
try:
|
| 72 |
+
self.kv_storage.reload()
|
| 73 |
+
to_process, recovered = self.split(batch)
|
| 74 |
+
# yield recovered chunks first
|
| 75 |
+
if not recovered.empty:
|
| 76 |
+
yield recovered
|
| 77 |
+
|
| 78 |
+
if to_process.empty:
|
| 79 |
+
return
|
| 80 |
+
|
| 81 |
+
data = to_process.to_dict(orient="records")
|
| 82 |
+
result, meta_update = self.process(data)
|
| 83 |
if inspect.isgenerator(result):
|
| 84 |
+
is_first = True
|
| 85 |
+
for res in result:
|
| 86 |
+
yield pd.DataFrame([res])
|
| 87 |
+
self.store([res], meta_update if is_first else {})
|
| 88 |
+
is_first = False
|
| 89 |
else:
|
| 90 |
+
yield pd.DataFrame(result)
|
| 91 |
+
self.store(result, meta_update)
|
| 92 |
finally:
|
| 93 |
CURRENT_LOGGER_VAR.reset(logger_token)
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
def get_logger(self):
|
| 96 |
return self.logger
|
| 97 |
+
|
| 98 |
+
def get_meta_forward(self):
|
| 99 |
+
return self.kv_storage.get_by_id("_meta_forward") or {}
|
| 100 |
+
|
| 101 |
+
def get_meta_inverse(self):
|
| 102 |
+
return self.kv_storage.get_by_id("_meta_inverse") or {}
|
| 103 |
+
|
| 104 |
+
def get_trace_id(self, content: dict) -> str:
|
| 105 |
+
from graphgen.utils import compute_dict_hash
|
| 106 |
+
|
| 107 |
+
return compute_dict_hash(content, prefix=f"{self.op_name}-")
|
| 108 |
+
|
| 109 |
+
def split(self, batch: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
|
| 110 |
+
"""
|
| 111 |
+
Split the input batch into to_process & processed based on _meta data in KV_storage
|
| 112 |
+
:param batch
|
| 113 |
+
:return:
|
| 114 |
+
to_process: DataFrame of documents to be chunked
|
| 115 |
+
recovered: Result DataFrame of already chunked documents
|
| 116 |
+
"""
|
| 117 |
+
meta_forward = self.get_meta_forward()
|
| 118 |
+
meta_ids = set(meta_forward.keys())
|
| 119 |
+
mask = batch["_trace_id"].isin(meta_ids)
|
| 120 |
+
to_process = batch[~mask]
|
| 121 |
+
processed = batch[mask]
|
| 122 |
+
|
| 123 |
+
if processed.empty:
|
| 124 |
+
return to_process, pd.DataFrame()
|
| 125 |
+
|
| 126 |
+
all_ids = [
|
| 127 |
+
pid for tid in processed["_trace_id"] for pid in meta_forward.get(tid, [])
|
| 128 |
+
]
|
| 129 |
+
|
| 130 |
+
recovered_chunks = self.kv_storage.get_by_ids(all_ids)
|
| 131 |
+
recovered_chunks = [c for c in recovered_chunks if c is not None]
|
| 132 |
+
return to_process, pd.DataFrame(recovered_chunks)
|
| 133 |
+
|
| 134 |
+
def store(self, results: list, meta_update: dict):
|
| 135 |
+
results = convert_to_serializable(results)
|
| 136 |
+
meta_update = convert_to_serializable(meta_update)
|
| 137 |
+
|
| 138 |
+
batch = {res["_trace_id"]: res for res in results}
|
| 139 |
+
self.kv_storage.upsert(batch)
|
| 140 |
+
|
| 141 |
+
# update forward meta
|
| 142 |
+
forward_meta = self.get_meta_forward()
|
| 143 |
+
forward_meta.update(meta_update)
|
| 144 |
+
self.kv_storage.update({"_meta_forward": forward_meta})
|
| 145 |
+
|
| 146 |
+
# update inverse meta
|
| 147 |
+
inverse_meta = self.get_meta_inverse()
|
| 148 |
+
for k, v_list in meta_update.items():
|
| 149 |
+
for v in v_list:
|
| 150 |
+
inverse_meta[v] = k
|
| 151 |
+
self.kv_storage.update({"_meta_inverse": inverse_meta})
|
| 152 |
+
self.kv_storage.index_done_callback()
|
| 153 |
+
|
| 154 |
+
@abstractmethod
|
| 155 |
+
def process(self, batch: list) -> Tuple[Union[list, Iterable[dict]], dict]:
|
| 156 |
+
"""
|
| 157 |
+
Process the input batch and return the result.
|
| 158 |
+
:param batch
|
| 159 |
+
:return:
|
| 160 |
+
result: DataFrame of processed documents
|
| 161 |
+
meta_update: dict of meta data to be updated
|
| 162 |
+
"""
|
graphgen/bases/base_storage.py
CHANGED
|
@@ -39,6 +39,12 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
|
|
| 39 |
def upsert(self, data: dict[str, T]):
|
| 40 |
raise NotImplementedError
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
def drop(self):
|
| 43 |
raise NotImplementedError
|
| 44 |
|
|
|
|
| 39 |
def upsert(self, data: dict[str, T]):
|
| 40 |
raise NotImplementedError
|
| 41 |
|
| 42 |
+
def update(self, data: dict[str, T]):
|
| 43 |
+
raise NotImplementedError
|
| 44 |
+
|
| 45 |
+
def delete(self, ids: list[str]):
|
| 46 |
+
raise NotImplementedError
|
| 47 |
+
|
| 48 |
def drop(self):
|
| 49 |
raise NotImplementedError
|
| 50 |
|
graphgen/bases/datatypes.py
CHANGED
|
@@ -31,6 +31,13 @@ class QAPair:
|
|
| 31 |
question: str
|
| 32 |
answer: str
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
@dataclass
|
| 36 |
class Token:
|
|
|
|
| 31 |
question: str
|
| 32 |
answer: str
|
| 33 |
|
| 34 |
+
@staticmethod
|
| 35 |
+
def from_dict(data: dict) -> "QAPair":
|
| 36 |
+
return QAPair(
|
| 37 |
+
question=data.get("question", ""),
|
| 38 |
+
answer=data.get("answer", ""),
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
|
| 42 |
@dataclass
|
| 43 |
class Token:
|
graphgen/common/init_storage.py
CHANGED
|
@@ -8,11 +8,11 @@ from graphgen.bases.base_storage import BaseGraphStorage, BaseKVStorage
|
|
| 8 |
class KVStorageActor:
|
| 9 |
def __init__(self, backend: str, working_dir: str, namespace: str):
|
| 10 |
if backend == "json_kv":
|
| 11 |
-
from graphgen.
|
| 12 |
|
| 13 |
self.kv = JsonKVStorage(working_dir, namespace)
|
| 14 |
elif backend == "rocksdb":
|
| 15 |
-
from graphgen.
|
| 16 |
|
| 17 |
self.kv = RocksDBKVStorage(working_dir, namespace)
|
| 18 |
else:
|
|
@@ -42,6 +42,12 @@ class KVStorageActor:
|
|
| 42 |
def upsert(self, data: dict) -> dict:
|
| 43 |
return self.kv.upsert(data)
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
def drop(self):
|
| 46 |
return self.kv.drop()
|
| 47 |
|
|
@@ -55,11 +61,11 @@ class KVStorageActor:
|
|
| 55 |
class GraphStorageActor:
|
| 56 |
def __init__(self, backend: str, working_dir: str, namespace: str):
|
| 57 |
if backend == "networkx":
|
| 58 |
-
from graphgen.
|
| 59 |
|
| 60 |
self.graph = NetworkXStorage(working_dir, namespace)
|
| 61 |
elif backend == "kuzu":
|
| 62 |
-
from graphgen.
|
| 63 |
|
| 64 |
self.graph = KuzuStorage(working_dir, namespace)
|
| 65 |
else:
|
|
@@ -168,6 +174,12 @@ class RemoteKVStorageProxy(BaseKVStorage):
|
|
| 168 |
def upsert(self, data: Dict[str, Any]):
|
| 169 |
return ray.get(self.actor.upsert.remote(data))
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
def drop(self):
|
| 172 |
return ray.get(self.actor.drop.remote())
|
| 173 |
|
|
|
|
| 8 |
class KVStorageActor:
|
| 9 |
def __init__(self, backend: str, working_dir: str, namespace: str):
|
| 10 |
if backend == "json_kv":
|
| 11 |
+
from graphgen.storage import JsonKVStorage
|
| 12 |
|
| 13 |
self.kv = JsonKVStorage(working_dir, namespace)
|
| 14 |
elif backend == "rocksdb":
|
| 15 |
+
from graphgen.storage import RocksDBKVStorage
|
| 16 |
|
| 17 |
self.kv = RocksDBKVStorage(working_dir, namespace)
|
| 18 |
else:
|
|
|
|
| 42 |
def upsert(self, data: dict) -> dict:
|
| 43 |
return self.kv.upsert(data)
|
| 44 |
|
| 45 |
+
def update(self, data: dict):
|
| 46 |
+
return self.kv.update(data)
|
| 47 |
+
|
| 48 |
+
def delete(self, ids: list[str]):
|
| 49 |
+
return self.kv.delete(ids)
|
| 50 |
+
|
| 51 |
def drop(self):
|
| 52 |
return self.kv.drop()
|
| 53 |
|
|
|
|
| 61 |
class GraphStorageActor:
|
| 62 |
def __init__(self, backend: str, working_dir: str, namespace: str):
|
| 63 |
if backend == "networkx":
|
| 64 |
+
from graphgen.storage import NetworkXStorage
|
| 65 |
|
| 66 |
self.graph = NetworkXStorage(working_dir, namespace)
|
| 67 |
elif backend == "kuzu":
|
| 68 |
+
from graphgen.storage import KuzuStorage
|
| 69 |
|
| 70 |
self.graph = KuzuStorage(working_dir, namespace)
|
| 71 |
else:
|
|
|
|
| 174 |
def upsert(self, data: Dict[str, Any]):
|
| 175 |
return ray.get(self.actor.upsert.remote(data))
|
| 176 |
|
| 177 |
+
def update(self, data: Dict[str, Any]):
|
| 178 |
+
return ray.get(self.actor.update.remote(data))
|
| 179 |
+
|
| 180 |
+
def delete(self, ids: list[str]):
|
| 181 |
+
return ray.get(self.actor.delete.remote(ids))
|
| 182 |
+
|
| 183 |
def drop(self):
|
| 184 |
return ray.get(self.actor.drop.remote())
|
| 185 |
|
graphgen/models/__init__.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
from .evaluator import (
|
| 2 |
AccuracyEvaluator,
|
| 3 |
-
ConsistencyEvaluator,
|
| 4 |
LengthEvaluator,
|
| 5 |
MTLDEvaluator,
|
| 6 |
RewardEvaluator,
|
|
@@ -44,11 +43,4 @@ from .searcher.kg.wiki_search import WikiSearch
|
|
| 44 |
from .searcher.web.bing_search import BingSearch
|
| 45 |
from .searcher.web.google_search import GoogleSearch
|
| 46 |
from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
|
| 47 |
-
from .storage import (
|
| 48 |
-
JsonKVStorage,
|
| 49 |
-
KuzuStorage,
|
| 50 |
-
NetworkXStorage,
|
| 51 |
-
RocksDBCache,
|
| 52 |
-
RocksDBKVStorage,
|
| 53 |
-
)
|
| 54 |
from .tokenizer import Tokenizer
|
|
|
|
| 1 |
from .evaluator import (
|
| 2 |
AccuracyEvaluator,
|
|
|
|
| 3 |
LengthEvaluator,
|
| 4 |
MTLDEvaluator,
|
| 5 |
RewardEvaluator,
|
|
|
|
| 43 |
from .searcher.web.bing_search import BingSearch
|
| 44 |
from .searcher.web.google_search import GoogleSearch
|
| 45 |
from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
from .tokenizer import Tokenizer
|
graphgen/models/evaluator/__init__.py
CHANGED
|
@@ -1,2 +1,3 @@
|
|
| 1 |
-
from .kg import
|
| 2 |
from .qa import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
|
|
|
|
|
|
| 1 |
+
from .kg import StructureEvaluator
|
| 2 |
from .qa import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
|
| 3 |
+
from .triple import AccuracyEvaluator
|
graphgen/models/evaluator/kg/__init__.py
CHANGED
|
@@ -1,18 +1 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Knowledge Graph Quality Evaluator
|
| 3 |
-
|
| 4 |
-
This module provides comprehensive quality evaluation for knowledge graphs,
|
| 5 |
-
1. accuracy assessment (entity/relation/triple validation),
|
| 6 |
-
2. consistency assessment (attribute conflict detection), and structural
|
| 7 |
-
3. robustness assessment (noise ratio, connectivity, degree distribution).
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
-
from .accuracy_evaluator import AccuracyEvaluator
|
| 11 |
-
from .consistency_evaluator import ConsistencyEvaluator
|
| 12 |
from .structure_evaluator import StructureEvaluator
|
| 13 |
-
|
| 14 |
-
__all__ = [
|
| 15 |
-
"AccuracyEvaluator",
|
| 16 |
-
"ConsistencyEvaluator",
|
| 17 |
-
"StructureEvaluator",
|
| 18 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from .structure_evaluator import StructureEvaluator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/models/evaluator/kg/accuracy_evaluator.py
DELETED
|
@@ -1,350 +0,0 @@
|
|
| 1 |
-
import asyncio
|
| 2 |
-
import json
|
| 3 |
-
import re
|
| 4 |
-
from typing import Any, Dict, List
|
| 5 |
-
|
| 6 |
-
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
|
| 7 |
-
from graphgen.bases.datatypes import Chunk
|
| 8 |
-
from graphgen.templates import ACCURACY_EVALUATION_PROMPT
|
| 9 |
-
from graphgen.utils import detect_main_language, logger
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class AccuracyEvaluator:
|
| 13 |
-
"""Evaluates accuracy of entity recognition and relation extraction using LLM-as-a-Judge.
|
| 14 |
-
|
| 15 |
-
For each chunk, uses LLM to evaluate the quality of extracted entities and relations
|
| 16 |
-
by comparing them with the original chunk content. Provides multi-dimensional quality
|
| 17 |
-
scores (accuracy, completeness, precision).
|
| 18 |
-
"""
|
| 19 |
-
|
| 20 |
-
def __init__(
|
| 21 |
-
self,
|
| 22 |
-
graph_storage: BaseGraphStorage,
|
| 23 |
-
chunk_storage: BaseKVStorage,
|
| 24 |
-
llm_client: BaseLLMWrapper,
|
| 25 |
-
):
|
| 26 |
-
self.graph_storage = graph_storage
|
| 27 |
-
self.chunk_storage = chunk_storage
|
| 28 |
-
self.llm_client = llm_client
|
| 29 |
-
|
| 30 |
-
def evaluate(self) -> Dict[str, Any]:
|
| 31 |
-
"""Evaluate entity and relation extraction quality using LLM-as-a-Judge.
|
| 32 |
-
|
| 33 |
-
Returns:
|
| 34 |
-
Dictionary containing entity_accuracy and relation_accuracy metrics.
|
| 35 |
-
"""
|
| 36 |
-
# 1. Load all chunks from storage
|
| 37 |
-
chunks = self._load_chunks_from_storage()
|
| 38 |
-
|
| 39 |
-
if not chunks:
|
| 40 |
-
logger.warning("No chunks found in storage")
|
| 41 |
-
return {"error": "No chunks found in storage"}
|
| 42 |
-
|
| 43 |
-
logger.info(f"Found {len(chunks)} chunks to evaluate")
|
| 44 |
-
|
| 45 |
-
# 2. Evaluate each chunk
|
| 46 |
-
entity_evaluations, relation_evaluations = self._evaluate_all_chunks(chunks)
|
| 47 |
-
|
| 48 |
-
# 3. Aggregate results
|
| 49 |
-
return self._aggregate_evaluation_results(
|
| 50 |
-
entity_evaluations, relation_evaluations
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
def _load_chunks_from_storage(self) -> List[Chunk]:
|
| 54 |
-
"""Load all chunks from chunk storage."""
|
| 55 |
-
chunks = []
|
| 56 |
-
all_chunk_data = self.chunk_storage.get_all()
|
| 57 |
-
|
| 58 |
-
for chunk_id, chunk_data in all_chunk_data.items():
|
| 59 |
-
try:
|
| 60 |
-
chunk = Chunk.from_dict(chunk_id, chunk_data)
|
| 61 |
-
chunks.append(chunk)
|
| 62 |
-
except Exception as e:
|
| 63 |
-
logger.warning(f"Failed to load chunk {chunk_id}: {e}")
|
| 64 |
-
continue
|
| 65 |
-
|
| 66 |
-
return chunks
|
| 67 |
-
|
| 68 |
-
def _get_extracted_entities_for_chunk(self, chunk_id: str) -> List[Dict]:
|
| 69 |
-
"""Get all entities extracted from the specified chunk."""
|
| 70 |
-
entities = []
|
| 71 |
-
all_nodes = self.graph_storage.get_all_nodes() or []
|
| 72 |
-
|
| 73 |
-
for node_id, node_data in all_nodes:
|
| 74 |
-
if not isinstance(node_data, dict):
|
| 75 |
-
continue
|
| 76 |
-
source_ids = node_data.get("source_id", "").split("<SEP>")
|
| 77 |
-
# Check if this chunk_id is in the source_ids
|
| 78 |
-
if chunk_id in [sid.strip() for sid in source_ids if sid.strip()]:
|
| 79 |
-
entities.append(
|
| 80 |
-
{
|
| 81 |
-
"entity_name": node_data.get("entity_name", node_id),
|
| 82 |
-
"entity_type": node_data.get("entity_type", ""),
|
| 83 |
-
"description": node_data.get("description", ""),
|
| 84 |
-
}
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
return entities
|
| 88 |
-
|
| 89 |
-
def _get_extracted_relations_for_chunk(self, chunk_id: str) -> List[Dict]:
|
| 90 |
-
"""Get all relations extracted from the specified chunk."""
|
| 91 |
-
relations = []
|
| 92 |
-
all_edges = self.graph_storage.get_all_edges() or []
|
| 93 |
-
|
| 94 |
-
for src_id, dst_id, edge_data in all_edges:
|
| 95 |
-
if not isinstance(edge_data, dict):
|
| 96 |
-
continue
|
| 97 |
-
source_ids = edge_data.get("source_id", "").split("<SEP>")
|
| 98 |
-
# Check if this chunk_id is in the source_ids
|
| 99 |
-
if chunk_id in [sid.strip() for sid in source_ids if sid.strip()]:
|
| 100 |
-
src_node = self.graph_storage.get_node(src_id) or {}
|
| 101 |
-
dst_node = self.graph_storage.get_node(dst_id) or {}
|
| 102 |
-
relations.append(
|
| 103 |
-
{
|
| 104 |
-
"source_entity": src_node.get("entity_name", src_id),
|
| 105 |
-
"target_entity": dst_node.get("entity_name", dst_id),
|
| 106 |
-
"relationship_summary": edge_data.get("description", ""),
|
| 107 |
-
}
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
return relations
|
| 111 |
-
|
| 112 |
-
def _evaluate_all_chunks(
|
| 113 |
-
self, chunks: List[Chunk]
|
| 114 |
-
) -> tuple[List[Dict], List[Dict]]:
|
| 115 |
-
"""Evaluate all chunks sequentially."""
|
| 116 |
-
entity_evaluations = []
|
| 117 |
-
relation_evaluations = []
|
| 118 |
-
|
| 119 |
-
for chunk in chunks:
|
| 120 |
-
try:
|
| 121 |
-
entities = self._get_extracted_entities_for_chunk(chunk.id)
|
| 122 |
-
relations = self._get_extracted_relations_for_chunk(chunk.id)
|
| 123 |
-
|
| 124 |
-
entity_eval = self._evaluate_entity_extraction(chunk, entities)
|
| 125 |
-
relation_eval = self._evaluate_relation_extraction(chunk, relations)
|
| 126 |
-
|
| 127 |
-
entity_evaluations.append(entity_eval)
|
| 128 |
-
relation_evaluations.append(relation_eval)
|
| 129 |
-
except Exception as e:
|
| 130 |
-
logger.error(f"Failed to evaluate chunk {chunk.id}: {e}")
|
| 131 |
-
continue
|
| 132 |
-
|
| 133 |
-
return entity_evaluations, relation_evaluations
|
| 134 |
-
|
| 135 |
-
def _evaluate_entity_extraction(
|
| 136 |
-
self, chunk: Chunk, extracted_entities: List[Dict]
|
| 137 |
-
) -> Dict[str, Any]:
|
| 138 |
-
"""Use LLM to evaluate entity extraction quality."""
|
| 139 |
-
try:
|
| 140 |
-
lang = detect_main_language(chunk.content)
|
| 141 |
-
|
| 142 |
-
prompt = ACCURACY_EVALUATION_PROMPT[lang]["ENTITY"].format(
|
| 143 |
-
chunk_content=chunk.content,
|
| 144 |
-
extracted_entities=json.dumps(
|
| 145 |
-
extracted_entities, ensure_ascii=False, indent=2
|
| 146 |
-
),
|
| 147 |
-
)
|
| 148 |
-
|
| 149 |
-
response = asyncio.run(self.llm_client.generate_answer(prompt))
|
| 150 |
-
|
| 151 |
-
# Try to parse JSON response
|
| 152 |
-
try:
|
| 153 |
-
evaluation_result = json.loads(response)
|
| 154 |
-
except json.JSONDecodeError:
|
| 155 |
-
# Try to extract JSON from markdown code blocks or other formats
|
| 156 |
-
json_match = re.search(r"\{.*\}", response, re.DOTALL)
|
| 157 |
-
if json_match:
|
| 158 |
-
evaluation_result = json.loads(json_match.group(0))
|
| 159 |
-
else:
|
| 160 |
-
logger.warning(
|
| 161 |
-
f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}"
|
| 162 |
-
)
|
| 163 |
-
# Return default evaluation
|
| 164 |
-
evaluation_result = {
|
| 165 |
-
"accuracy": 0.0,
|
| 166 |
-
"completeness": 0.0,
|
| 167 |
-
"precision": 0.0,
|
| 168 |
-
"overall_score": 0.0,
|
| 169 |
-
"accuracy_reasoning": "Failed to parse LLM response",
|
| 170 |
-
"completeness_reasoning": "",
|
| 171 |
-
"precision_reasoning": "",
|
| 172 |
-
"issues": ["LLM response parsing failed"],
|
| 173 |
-
}
|
| 174 |
-
|
| 175 |
-
# Validate and calculate overall_score if not provided
|
| 176 |
-
if "overall_score" not in evaluation_result:
|
| 177 |
-
accuracy = float(evaluation_result.get("accuracy", 0.0))
|
| 178 |
-
completeness = float(evaluation_result.get("completeness", 0.0))
|
| 179 |
-
precision = float(evaluation_result.get("precision", 0.0))
|
| 180 |
-
evaluation_result["overall_score"] = (
|
| 181 |
-
0.4 * accuracy + 0.4 * completeness + 0.2 * precision
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
-
return {
|
| 185 |
-
"chunk_id": chunk.id,
|
| 186 |
-
"chunk_content": chunk.content[:200]
|
| 187 |
-
if chunk.content
|
| 188 |
-
else "", # First 200 chars for debugging
|
| 189 |
-
"extracted_entities_count": len(extracted_entities),
|
| 190 |
-
**evaluation_result,
|
| 191 |
-
}
|
| 192 |
-
except Exception as e:
|
| 193 |
-
logger.error(
|
| 194 |
-
f"Error evaluating entity extraction for chunk {chunk.id}: {e}"
|
| 195 |
-
)
|
| 196 |
-
return {
|
| 197 |
-
"chunk_id": chunk.id,
|
| 198 |
-
"chunk_content": chunk.content[:200] if chunk.content else "",
|
| 199 |
-
"extracted_entities_count": len(extracted_entities),
|
| 200 |
-
"accuracy": 0.0,
|
| 201 |
-
"completeness": 0.0,
|
| 202 |
-
"precision": 0.0,
|
| 203 |
-
"overall_score": 0.0,
|
| 204 |
-
"accuracy_reasoning": f"Evaluation failed: {str(e)}",
|
| 205 |
-
"completeness_reasoning": "",
|
| 206 |
-
"precision_reasoning": "",
|
| 207 |
-
"issues": [f"Evaluation error: {str(e)}"],
|
| 208 |
-
}
|
| 209 |
-
|
| 210 |
-
def _evaluate_relation_extraction(
|
| 211 |
-
self, chunk: Chunk, extracted_relations: List[Dict]
|
| 212 |
-
) -> Dict[str, Any]:
|
| 213 |
-
"""Use LLM to evaluate relation extraction quality."""
|
| 214 |
-
try:
|
| 215 |
-
lang = detect_main_language(chunk.content)
|
| 216 |
-
prompt = ACCURACY_EVALUATION_PROMPT[lang]["RELATION"].format(
|
| 217 |
-
chunk_content=chunk.content,
|
| 218 |
-
extracted_relations=json.dumps(
|
| 219 |
-
extracted_relations, ensure_ascii=False, indent=2
|
| 220 |
-
),
|
| 221 |
-
)
|
| 222 |
-
|
| 223 |
-
response = asyncio.run(self.llm_client.generate_answer(prompt))
|
| 224 |
-
|
| 225 |
-
# Try to parse JSON response
|
| 226 |
-
try:
|
| 227 |
-
evaluation_result = json.loads(response)
|
| 228 |
-
except json.JSONDecodeError:
|
| 229 |
-
# Try to extract JSON from markdown code blocks or other formats
|
| 230 |
-
json_match = re.search(r"\{.*\}", response, re.DOTALL)
|
| 231 |
-
if json_match:
|
| 232 |
-
evaluation_result = json.loads(json_match.group(0))
|
| 233 |
-
else:
|
| 234 |
-
logger.warning(
|
| 235 |
-
f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}"
|
| 236 |
-
)
|
| 237 |
-
# Return default evaluation
|
| 238 |
-
evaluation_result = {
|
| 239 |
-
"accuracy": 0.0,
|
| 240 |
-
"completeness": 0.0,
|
| 241 |
-
"precision": 0.0,
|
| 242 |
-
"overall_score": 0.0,
|
| 243 |
-
"accuracy_reasoning": "Failed to parse LLM response",
|
| 244 |
-
"completeness_reasoning": "",
|
| 245 |
-
"precision_reasoning": "",
|
| 246 |
-
"issues": ["LLM response parsing failed"],
|
| 247 |
-
}
|
| 248 |
-
|
| 249 |
-
# Validate and calculate overall_score if not provided
|
| 250 |
-
if "overall_score" not in evaluation_result:
|
| 251 |
-
accuracy = float(evaluation_result.get("accuracy", 0.0))
|
| 252 |
-
completeness = float(evaluation_result.get("completeness", 0.0))
|
| 253 |
-
precision = float(evaluation_result.get("precision", 0.0))
|
| 254 |
-
evaluation_result["overall_score"] = (
|
| 255 |
-
0.4 * accuracy + 0.4 * completeness + 0.2 * precision
|
| 256 |
-
)
|
| 257 |
-
|
| 258 |
-
return {
|
| 259 |
-
"chunk_id": chunk.id,
|
| 260 |
-
"chunk_content": chunk.content[:200] if chunk.content else "",
|
| 261 |
-
"extracted_relations_count": len(extracted_relations),
|
| 262 |
-
**evaluation_result,
|
| 263 |
-
}
|
| 264 |
-
except Exception as e:
|
| 265 |
-
logger.error(
|
| 266 |
-
f"Error evaluating relation extraction for chunk {chunk.id}: {e}"
|
| 267 |
-
)
|
| 268 |
-
return {
|
| 269 |
-
"chunk_id": chunk.id,
|
| 270 |
-
"chunk_content": chunk.content[:200] if chunk.content else "",
|
| 271 |
-
"extracted_relations_count": len(extracted_relations),
|
| 272 |
-
"accuracy": 0.0,
|
| 273 |
-
"completeness": 0.0,
|
| 274 |
-
"precision": 0.0,
|
| 275 |
-
"overall_score": 0.0,
|
| 276 |
-
"accuracy_reasoning": f"Evaluation failed: {str(e)}",
|
| 277 |
-
"completeness_reasoning": "",
|
| 278 |
-
"precision_reasoning": "",
|
| 279 |
-
"issues": [f"Evaluation error: {str(e)}"],
|
| 280 |
-
}
|
| 281 |
-
|
| 282 |
-
@staticmethod
|
| 283 |
-
def _aggregate_evaluation_results(
|
| 284 |
-
entity_evaluations: List[Dict], relation_evaluations: List[Dict]
|
| 285 |
-
) -> Dict[str, Any]:
|
| 286 |
-
"""Aggregate evaluation results from all chunks."""
|
| 287 |
-
|
| 288 |
-
def calculate_stats(scores: List[float]) -> Dict[str, float]:
|
| 289 |
-
if not scores:
|
| 290 |
-
return {"mean": 0.0, "median": 0.0, "min": 0.0, "max": 0.0, "std": 0.0}
|
| 291 |
-
sorted_scores = sorted(scores)
|
| 292 |
-
n = len(scores)
|
| 293 |
-
mean = sum(scores) / n
|
| 294 |
-
median = (
|
| 295 |
-
sorted_scores[n // 2]
|
| 296 |
-
if n % 2 == 1
|
| 297 |
-
else (sorted_scores[n // 2 - 1] + sorted_scores[n // 2]) / 2
|
| 298 |
-
)
|
| 299 |
-
variance = sum((x - mean) ** 2 for x in scores) / n
|
| 300 |
-
std = variance**0.5
|
| 301 |
-
|
| 302 |
-
return {
|
| 303 |
-
"mean": mean,
|
| 304 |
-
"median": median,
|
| 305 |
-
"min": min(scores),
|
| 306 |
-
"max": max(scores),
|
| 307 |
-
"std": std,
|
| 308 |
-
}
|
| 309 |
-
|
| 310 |
-
# Extract scores
|
| 311 |
-
entity_overall_scores = [
|
| 312 |
-
e.get("overall_score", 0.0) for e in entity_evaluations
|
| 313 |
-
]
|
| 314 |
-
entity_accuracy_scores = [e.get("accuracy", 0.0) for e in entity_evaluations]
|
| 315 |
-
entity_completeness_scores = [
|
| 316 |
-
e.get("completeness", 0.0) for e in entity_evaluations
|
| 317 |
-
]
|
| 318 |
-
entity_precision_scores = [e.get("precision", 0.0) for e in entity_evaluations]
|
| 319 |
-
|
| 320 |
-
relation_overall_scores = [
|
| 321 |
-
r.get("overall_score", 0.0) for r in relation_evaluations
|
| 322 |
-
]
|
| 323 |
-
relation_accuracy_scores = [
|
| 324 |
-
r.get("accuracy", 0.0) for r in relation_evaluations
|
| 325 |
-
]
|
| 326 |
-
relation_completeness_scores = [
|
| 327 |
-
r.get("completeness", 0.0) for r in relation_evaluations
|
| 328 |
-
]
|
| 329 |
-
relation_precision_scores = [
|
| 330 |
-
r.get("precision", 0.0) for r in relation_evaluations
|
| 331 |
-
]
|
| 332 |
-
|
| 333 |
-
return {
|
| 334 |
-
"entity_accuracy": {
|
| 335 |
-
"overall_score": calculate_stats(entity_overall_scores),
|
| 336 |
-
"accuracy": calculate_stats(entity_accuracy_scores),
|
| 337 |
-
"completeness": calculate_stats(entity_completeness_scores),
|
| 338 |
-
"precision": calculate_stats(entity_precision_scores),
|
| 339 |
-
"total_chunks": len(entity_evaluations),
|
| 340 |
-
"detailed_results": entity_evaluations,
|
| 341 |
-
},
|
| 342 |
-
"relation_accuracy": {
|
| 343 |
-
"overall_score": calculate_stats(relation_overall_scores),
|
| 344 |
-
"accuracy": calculate_stats(relation_accuracy_scores),
|
| 345 |
-
"completeness": calculate_stats(relation_completeness_scores),
|
| 346 |
-
"precision": calculate_stats(relation_precision_scores),
|
| 347 |
-
"total_chunks": len(relation_evaluations),
|
| 348 |
-
"detailed_results": relation_evaluations,
|
| 349 |
-
},
|
| 350 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/models/evaluator/kg/consistency_evaluator.py
DELETED
|
@@ -1,388 +0,0 @@
|
|
| 1 |
-
import asyncio
|
| 2 |
-
import json
|
| 3 |
-
import re
|
| 4 |
-
from typing import Any, Dict, List
|
| 5 |
-
|
| 6 |
-
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
|
| 7 |
-
from graphgen.bases.datatypes import Chunk
|
| 8 |
-
from graphgen.templates.evaluation.kg.consistency_evaluation import (
|
| 9 |
-
CONSISTENCY_EVALUATION_PROMPT,
|
| 10 |
-
)
|
| 11 |
-
from graphgen.utils import detect_main_language, logger
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class ConsistencyEvaluator:
|
| 15 |
-
"""Evaluates consistency by detecting semantic conflicts using LLM-as-a-Judge.
|
| 16 |
-
|
| 17 |
-
For entities with multiple source chunks, compares entity_type and description
|
| 18 |
-
extracted from different chunks to detect semantic conflicts.
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
def __init__(
|
| 22 |
-
self,
|
| 23 |
-
graph_storage: BaseGraphStorage,
|
| 24 |
-
chunk_storage: BaseKVStorage,
|
| 25 |
-
llm_client: BaseLLMWrapper,
|
| 26 |
-
):
|
| 27 |
-
self.graph_storage = graph_storage
|
| 28 |
-
self.chunk_storage = chunk_storage
|
| 29 |
-
self.llm_client = llm_client
|
| 30 |
-
|
| 31 |
-
def evaluate(self) -> Dict[str, Any]:
|
| 32 |
-
"""Evaluate consistency by detecting semantic conflicts."""
|
| 33 |
-
all_nodes = self.graph_storage.get_all_nodes() or []
|
| 34 |
-
if not all_nodes:
|
| 35 |
-
return {"error": "Empty graph"}
|
| 36 |
-
|
| 37 |
-
return self._evaluate_consistency(all_nodes)
|
| 38 |
-
|
| 39 |
-
def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]:
|
| 40 |
-
"""Evaluate consistency by detecting semantic conflicts."""
|
| 41 |
-
# Filter entities with multiple source chunks
|
| 42 |
-
entities_with_multiple_sources = []
|
| 43 |
-
for node_id, node_data in all_nodes:
|
| 44 |
-
if not isinstance(node_data, dict):
|
| 45 |
-
continue
|
| 46 |
-
source_ids = node_data.get("source_id", "").split("<SEP>")
|
| 47 |
-
source_ids = [sid.strip() for sid in source_ids if sid.strip()]
|
| 48 |
-
if len(source_ids) > 1: # Only check entities from multiple chunks
|
| 49 |
-
entities_with_multiple_sources.append((node_id, node_data, source_ids))
|
| 50 |
-
|
| 51 |
-
if not entities_with_multiple_sources:
|
| 52 |
-
logger.info(
|
| 53 |
-
"No entities with multiple sources found, skipping consistency check"
|
| 54 |
-
)
|
| 55 |
-
return {
|
| 56 |
-
"conflict_rate": 0.0,
|
| 57 |
-
"conflict_entities_count": 0,
|
| 58 |
-
"total_entities": len(all_nodes),
|
| 59 |
-
"conflicts": [],
|
| 60 |
-
}
|
| 61 |
-
|
| 62 |
-
logger.info(
|
| 63 |
-
f"Checking consistency for {len(entities_with_multiple_sources)} entities with multiple sources"
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
# Evaluate entities sequentially
|
| 67 |
-
conflicts = []
|
| 68 |
-
conflict_entities = set()
|
| 69 |
-
|
| 70 |
-
for entity_info in entities_with_multiple_sources:
|
| 71 |
-
try:
|
| 72 |
-
entity_id, entity_conflicts = self._evaluate_entity_consistency(entity_info)
|
| 73 |
-
if entity_conflicts:
|
| 74 |
-
conflicts.extend(entity_conflicts)
|
| 75 |
-
conflict_entities.add(entity_id)
|
| 76 |
-
except Exception as e:
|
| 77 |
-
logger.error(
|
| 78 |
-
f"Failed to evaluate entity {entity_info[0]}: {e}"
|
| 79 |
-
)
|
| 80 |
-
continue
|
| 81 |
-
|
| 82 |
-
total_entities = len(all_nodes)
|
| 83 |
-
conflict_rate = (
|
| 84 |
-
len(conflict_entities) / total_entities if total_entities > 0 else 0
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
return {
|
| 88 |
-
"conflict_rate": conflict_rate,
|
| 89 |
-
"conflict_entities_count": len(conflict_entities),
|
| 90 |
-
"total_entities": total_entities,
|
| 91 |
-
"entities_checked": len(entities_with_multiple_sources),
|
| 92 |
-
"conflicts": conflicts[:100], # Limit to first 100 conflicts
|
| 93 |
-
}
|
| 94 |
-
|
| 95 |
-
def _clean_entity_id(self, entity_id: str) -> str:
|
| 96 |
-
"""Clean entity ID by removing surrounding quotes."""
|
| 97 |
-
clean_id = entity_id.strip()
|
| 98 |
-
if (clean_id.startswith('"') and clean_id.endswith('"')) or (
|
| 99 |
-
clean_id.startswith("'") and clean_id.endswith("'")
|
| 100 |
-
):
|
| 101 |
-
clean_id = clean_id[1:-1].strip()
|
| 102 |
-
return clean_id
|
| 103 |
-
|
| 104 |
-
def _evaluate_entity_consistency(
|
| 105 |
-
self, entity_info: tuple
|
| 106 |
-
) -> tuple[str, List[Dict]]:
|
| 107 |
-
"""Evaluate consistency for a single entity."""
|
| 108 |
-
entity_id, _node_data, source_ids = entity_info
|
| 109 |
-
# Clean entity_id for display
|
| 110 |
-
clean_entity_id = self._clean_entity_id(entity_id)
|
| 111 |
-
conflicts = []
|
| 112 |
-
|
| 113 |
-
# Get chunks for this entity
|
| 114 |
-
chunks = self._get_entity_chunks(source_ids)
|
| 115 |
-
if len(chunks) < 2:
|
| 116 |
-
return entity_id, []
|
| 117 |
-
|
| 118 |
-
# Extract entity attributes from each chunk
|
| 119 |
-
entity_extractions = {}
|
| 120 |
-
for chunk in chunks:
|
| 121 |
-
extraction = self._extract_entity_from_chunk(entity_id, chunk)
|
| 122 |
-
if extraction:
|
| 123 |
-
entity_extractions[chunk.id] = extraction
|
| 124 |
-
|
| 125 |
-
if len(entity_extractions) < 2:
|
| 126 |
-
return entity_id, []
|
| 127 |
-
|
| 128 |
-
# Check entity type consistency
|
| 129 |
-
type_extractions = {
|
| 130 |
-
chunk_id: ext.get("entity_type", "")
|
| 131 |
-
for chunk_id, ext in entity_extractions.items()
|
| 132 |
-
}
|
| 133 |
-
type_conflict = self._check_entity_type_consistency(
|
| 134 |
-
entity_id, type_extractions
|
| 135 |
-
)
|
| 136 |
-
if type_conflict and type_conflict.get("has_conflict", False):
|
| 137 |
-
conflicts.append(
|
| 138 |
-
{
|
| 139 |
-
"entity_id": clean_entity_id,
|
| 140 |
-
"conflict_type": "entity_type",
|
| 141 |
-
"conflict_severity": type_conflict.get("conflict_severity", 0.0),
|
| 142 |
-
"conflict_reasoning": type_conflict.get("conflict_reasoning", ""),
|
| 143 |
-
"conflicting_values": type_conflict.get("conflicting_types", []),
|
| 144 |
-
"recommended_value": type_conflict.get("recommended_type", ""),
|
| 145 |
-
}
|
| 146 |
-
)
|
| 147 |
-
|
| 148 |
-
# Check entity description consistency
|
| 149 |
-
descriptions = {
|
| 150 |
-
chunk_id: ext.get("description", "")
|
| 151 |
-
for chunk_id, ext in entity_extractions.items()
|
| 152 |
-
}
|
| 153 |
-
desc_conflict = self._check_entity_description_consistency(
|
| 154 |
-
entity_id, descriptions
|
| 155 |
-
)
|
| 156 |
-
if desc_conflict and desc_conflict.get("has_conflict", False):
|
| 157 |
-
conflicts.append(
|
| 158 |
-
{
|
| 159 |
-
"entity_id": clean_entity_id,
|
| 160 |
-
"conflict_type": "description",
|
| 161 |
-
"conflict_severity": desc_conflict.get("conflict_severity", 0.0),
|
| 162 |
-
"conflict_reasoning": desc_conflict.get("conflict_reasoning", ""),
|
| 163 |
-
"conflicting_values": desc_conflict.get(
|
| 164 |
-
"conflicting_descriptions", []
|
| 165 |
-
),
|
| 166 |
-
"conflict_details": desc_conflict.get("conflict_details", ""),
|
| 167 |
-
}
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
return entity_id, conflicts
|
| 171 |
-
|
| 172 |
-
def _get_entity_chunks(self, source_ids: List[str]) -> List[Chunk]:
|
| 173 |
-
"""Get all chunks related to an entity."""
|
| 174 |
-
chunks = []
|
| 175 |
-
for chunk_id in source_ids:
|
| 176 |
-
chunk_data = self.chunk_storage.get_by_id(chunk_id)
|
| 177 |
-
if chunk_data:
|
| 178 |
-
try:
|
| 179 |
-
chunk = Chunk.from_dict(chunk_id, chunk_data)
|
| 180 |
-
chunks.append(chunk)
|
| 181 |
-
except Exception as e:
|
| 182 |
-
logger.warning(f"Failed to load chunk {chunk_id}: {e}")
|
| 183 |
-
continue
|
| 184 |
-
return chunks
|
| 185 |
-
|
| 186 |
-
def _extract_entity_from_chunk(
|
| 187 |
-
self, entity_id: str, chunk: Chunk
|
| 188 |
-
) -> Dict[str, str]:
|
| 189 |
-
"""Extract entity attributes from a chunk using LLM."""
|
| 190 |
-
try:
|
| 191 |
-
# Clean entity_id: remove surrounding quotes if present
|
| 192 |
-
clean_entity_id = self._clean_entity_id(entity_id)
|
| 193 |
-
|
| 194 |
-
# Detect language and get appropriate prompt
|
| 195 |
-
lang = detect_main_language(chunk.content)
|
| 196 |
-
prompt = CONSISTENCY_EVALUATION_PROMPT[lang]["ENTITY_EXTRACTION"].format(
|
| 197 |
-
entity_name=clean_entity_id,
|
| 198 |
-
chunk_content=chunk.content[:2000]
|
| 199 |
-
if chunk.content
|
| 200 |
-
else "", # Limit content length
|
| 201 |
-
)
|
| 202 |
-
|
| 203 |
-
response = asyncio.run(self.llm_client.generate_answer(prompt))
|
| 204 |
-
|
| 205 |
-
# Try to parse JSON response
|
| 206 |
-
try:
|
| 207 |
-
extraction = json.loads(response)
|
| 208 |
-
except json.JSONDecodeError:
|
| 209 |
-
# Try to extract JSON from markdown code blocks
|
| 210 |
-
json_match = re.search(r"\{.*\}", response, re.DOTALL)
|
| 211 |
-
if json_match:
|
| 212 |
-
extraction = json.loads(json_match.group(0))
|
| 213 |
-
else:
|
| 214 |
-
logger.warning(
|
| 215 |
-
f"Failed to parse extraction response for {entity_id} in chunk {chunk.id}"
|
| 216 |
-
)
|
| 217 |
-
return {}
|
| 218 |
-
|
| 219 |
-
# Normalize entity_type to lowercase and validate
|
| 220 |
-
entity_type = extraction.get("entity_type", "").lower().strip()
|
| 221 |
-
# Valid preset types
|
| 222 |
-
valid_types = {
|
| 223 |
-
"concept",
|
| 224 |
-
"date",
|
| 225 |
-
"location",
|
| 226 |
-
"keyword",
|
| 227 |
-
"organization",
|
| 228 |
-
"person",
|
| 229 |
-
"event",
|
| 230 |
-
"work",
|
| 231 |
-
"nature",
|
| 232 |
-
"artificial",
|
| 233 |
-
"science",
|
| 234 |
-
"technology",
|
| 235 |
-
"mission",
|
| 236 |
-
"gene",
|
| 237 |
-
}
|
| 238 |
-
# If entity_type is not in valid types, default to "concept"
|
| 239 |
-
if entity_type not in valid_types:
|
| 240 |
-
if entity_type: # If LLM provided a type but it's invalid
|
| 241 |
-
logger.warning(
|
| 242 |
-
f"Invalid entity_type '{entity_type}' for entity {clean_entity_id} in chunk {chunk.id}, "
|
| 243 |
-
f"defaulting to 'concept'"
|
| 244 |
-
)
|
| 245 |
-
entity_type = "concept"
|
| 246 |
-
|
| 247 |
-
return {
|
| 248 |
-
"entity_type": entity_type,
|
| 249 |
-
"description": extraction.get("description", ""),
|
| 250 |
-
}
|
| 251 |
-
except Exception as e:
|
| 252 |
-
logger.error(
|
| 253 |
-
f"Error extracting entity {entity_id} from chunk {chunk.id}: {e}"
|
| 254 |
-
)
|
| 255 |
-
return {}
|
| 256 |
-
|
| 257 |
-
def _check_entity_type_consistency(
|
| 258 |
-
self, entity_id: str, type_extractions: Dict[str, str]
|
| 259 |
-
) -> Dict[str, Any]:
|
| 260 |
-
"""Check entity type consistency using LLM."""
|
| 261 |
-
if len(set(type_extractions.values())) <= 1:
|
| 262 |
-
# All types are the same, no conflict
|
| 263 |
-
return {"has_conflict": False}
|
| 264 |
-
|
| 265 |
-
try:
|
| 266 |
-
type_list = [
|
| 267 |
-
f"Chunk {chunk_id}: {entity_type}"
|
| 268 |
-
for chunk_id, entity_type in type_extractions.items()
|
| 269 |
-
if entity_type
|
| 270 |
-
]
|
| 271 |
-
|
| 272 |
-
# Detect language from type extraction text
|
| 273 |
-
type_text = "\n".join(type_list)
|
| 274 |
-
lang = detect_main_language(type_text)
|
| 275 |
-
prompt = CONSISTENCY_EVALUATION_PROMPT[lang]["ENTITY_TYPE_CONFLICT"].format(
|
| 276 |
-
entity_name=entity_id, type_extractions=type_text
|
| 277 |
-
)
|
| 278 |
-
|
| 279 |
-
response = asyncio.run(self.llm_client.generate_answer(prompt))
|
| 280 |
-
|
| 281 |
-
# Parse JSON response
|
| 282 |
-
try:
|
| 283 |
-
result = json.loads(response)
|
| 284 |
-
except json.JSONDecodeError:
|
| 285 |
-
json_match = re.search(r"\{.*\}", response, re.DOTALL)
|
| 286 |
-
if json_match:
|
| 287 |
-
result = json.loads(json_match.group(0))
|
| 288 |
-
else:
|
| 289 |
-
logger.warning(
|
| 290 |
-
f"Failed to parse conflict detection response for {entity_id}"
|
| 291 |
-
)
|
| 292 |
-
return {"has_conflict": False}
|
| 293 |
-
|
| 294 |
-
return result
|
| 295 |
-
except Exception as e:
|
| 296 |
-
logger.error(f"Error checking type consistency for {entity_id}: {e}")
|
| 297 |
-
return {"has_conflict": False}
|
| 298 |
-
|
| 299 |
-
def _check_entity_description_consistency(
|
| 300 |
-
self, entity_id: str, descriptions: Dict[str, str]
|
| 301 |
-
) -> Dict[str, Any]:
|
| 302 |
-
"""Check entity description consistency using LLM."""
|
| 303 |
-
# Filter out empty descriptions
|
| 304 |
-
valid_descriptions = {k: v for k, v in descriptions.items() if v}
|
| 305 |
-
if len(valid_descriptions) < 2:
|
| 306 |
-
return {"has_conflict": False}
|
| 307 |
-
|
| 308 |
-
if len(set(valid_descriptions.values())) <= 1:
|
| 309 |
-
# All descriptions are the same, no conflict
|
| 310 |
-
return {"has_conflict": False}
|
| 311 |
-
|
| 312 |
-
try:
|
| 313 |
-
desc_list = [
|
| 314 |
-
f"Chunk {chunk_id}: {description}"
|
| 315 |
-
for chunk_id, description in valid_descriptions.items()
|
| 316 |
-
]
|
| 317 |
-
|
| 318 |
-
# Detect language from description text
|
| 319 |
-
desc_text = "\n".join(desc_list)
|
| 320 |
-
lang = detect_main_language(desc_text)
|
| 321 |
-
prompt = CONSISTENCY_EVALUATION_PROMPT[lang]["ENTITY_DESCRIPTION_CONFLICT"].format(
|
| 322 |
-
entity_name=entity_id, descriptions=desc_text
|
| 323 |
-
)
|
| 324 |
-
|
| 325 |
-
response = asyncio.run(self.llm_client.generate_answer(prompt))
|
| 326 |
-
|
| 327 |
-
# Parse JSON response
|
| 328 |
-
try:
|
| 329 |
-
result = json.loads(response)
|
| 330 |
-
except json.JSONDecodeError:
|
| 331 |
-
json_match = re.search(r"\{.*\}", response, re.DOTALL)
|
| 332 |
-
if json_match:
|
| 333 |
-
result = json.loads(json_match.group(0))
|
| 334 |
-
else:
|
| 335 |
-
logger.warning(
|
| 336 |
-
f"Failed to parse conflict detection response for {entity_id}"
|
| 337 |
-
)
|
| 338 |
-
return {"has_conflict": False}
|
| 339 |
-
|
| 340 |
-
return result
|
| 341 |
-
except Exception as e:
|
| 342 |
-
logger.error(f"Error checking description consistency for {entity_id}: {e}")
|
| 343 |
-
return {"has_conflict": False}
|
| 344 |
-
|
| 345 |
-
def _check_relation_consistency(
|
| 346 |
-
self, src_id: str, dst_id: str, relation_extractions: Dict[str, str]
|
| 347 |
-
) -> Dict[str, Any]:
|
| 348 |
-
"""Check relation consistency using LLM."""
|
| 349 |
-
if len(set(relation_extractions.values())) <= 1:
|
| 350 |
-
return {"has_conflict": False}
|
| 351 |
-
|
| 352 |
-
try:
|
| 353 |
-
rel_list = [
|
| 354 |
-
f"Chunk {chunk_id}: {relation}"
|
| 355 |
-
for chunk_id, relation in relation_extractions.items()
|
| 356 |
-
if relation
|
| 357 |
-
]
|
| 358 |
-
|
| 359 |
-
# Detect language from relation description text
|
| 360 |
-
rel_text = "\n".join(rel_list)
|
| 361 |
-
lang = detect_main_language(rel_text)
|
| 362 |
-
prompt = CONSISTENCY_EVALUATION_PROMPT[lang]["RELATION_CONFLICT"].format(
|
| 363 |
-
source_entity=src_id,
|
| 364 |
-
target_entity=dst_id,
|
| 365 |
-
relation_descriptions=rel_text,
|
| 366 |
-
)
|
| 367 |
-
|
| 368 |
-
response = asyncio.run(self.llm_client.generate_answer(prompt))
|
| 369 |
-
|
| 370 |
-
# Parse JSON response
|
| 371 |
-
try:
|
| 372 |
-
result = json.loads(response)
|
| 373 |
-
except json.JSONDecodeError:
|
| 374 |
-
json_match = re.search(r"\{.*\}", response, re.DOTALL)
|
| 375 |
-
if json_match:
|
| 376 |
-
result = json.loads(json_match.group(0))
|
| 377 |
-
else:
|
| 378 |
-
logger.warning(
|
| 379 |
-
f"Failed to parse relation conflict response for {src_id}->{dst_id}"
|
| 380 |
-
)
|
| 381 |
-
return {"has_conflict": False}
|
| 382 |
-
|
| 383 |
-
return result
|
| 384 |
-
except Exception as e:
|
| 385 |
-
logger.error(
|
| 386 |
-
f"Error checking relation consistency for {src_id}->{dst_id}: {e}"
|
| 387 |
-
)
|
| 388 |
-
return {"has_conflict": False}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/models/evaluator/kg/structure_evaluator.py
CHANGED
|
@@ -4,49 +4,49 @@ from typing import Any, Dict, Optional
|
|
| 4 |
import numpy as np
|
| 5 |
from scipy import stats
|
| 6 |
|
| 7 |
-
from graphgen.bases import BaseGraphStorage
|
| 8 |
from graphgen.utils import logger
|
| 9 |
|
| 10 |
|
| 11 |
-
class StructureEvaluator:
|
| 12 |
"""Evaluates structural robustness of the graph."""
|
| 13 |
|
| 14 |
def __init__(
|
| 15 |
self,
|
| 16 |
-
graph_storage: BaseGraphStorage,
|
| 17 |
noise_ratio_threshold: float = 0.15,
|
| 18 |
largest_cc_ratio_threshold: float = 0.90,
|
| 19 |
avg_degree_min: float = 2.0,
|
| 20 |
avg_degree_max: float = 5.0,
|
| 21 |
powerlaw_r2_threshold: float = 0.75,
|
| 22 |
):
|
| 23 |
-
self.graph_storage = graph_storage
|
| 24 |
self.noise_ratio_threshold = noise_ratio_threshold
|
| 25 |
self.largest_cc_ratio_threshold = largest_cc_ratio_threshold
|
| 26 |
self.avg_degree_min = avg_degree_min
|
| 27 |
self.avg_degree_max = avg_degree_max
|
| 28 |
self.powerlaw_r2_threshold = powerlaw_r2_threshold
|
| 29 |
|
| 30 |
-
def evaluate(self) -> Dict[str, Any]:
|
| 31 |
"""
|
| 32 |
Evaluate the structural robustness of the graph.
|
| 33 |
-
:return:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
"""
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
if total_nodes == 0:
|
| 39 |
-
return {"error": "Empty graph"}
|
| 40 |
-
|
| 41 |
-
total_edges = storage.get_edge_count()
|
| 42 |
-
degree_map = storage.get_all_node_degrees()
|
| 43 |
|
| 44 |
# Noise ratio: isolated nodes / total nodes
|
| 45 |
isolated_nodes = [nid for nid, deg in degree_map.items() if deg == 0]
|
| 46 |
noise_ratio = len(isolated_nodes) / total_nodes
|
| 47 |
|
| 48 |
# Largest connected component
|
| 49 |
-
components =
|
| 50 |
largest_cc_ratio = (
|
| 51 |
len(max(components, key=len)) / total_nodes if components else 0
|
| 52 |
)
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
from scipy import stats
|
| 6 |
|
| 7 |
+
from graphgen.bases import BaseGraphStorage, BaseKGEvaluator
|
| 8 |
from graphgen.utils import logger
|
| 9 |
|
| 10 |
|
| 11 |
+
class StructureEvaluator(BaseKGEvaluator):
|
| 12 |
"""Evaluates structural robustness of the graph."""
|
| 13 |
|
| 14 |
def __init__(
|
| 15 |
self,
|
|
|
|
| 16 |
noise_ratio_threshold: float = 0.15,
|
| 17 |
largest_cc_ratio_threshold: float = 0.90,
|
| 18 |
avg_degree_min: float = 2.0,
|
| 19 |
avg_degree_max: float = 5.0,
|
| 20 |
powerlaw_r2_threshold: float = 0.75,
|
| 21 |
):
|
|
|
|
| 22 |
self.noise_ratio_threshold = noise_ratio_threshold
|
| 23 |
self.largest_cc_ratio_threshold = largest_cc_ratio_threshold
|
| 24 |
self.avg_degree_min = avg_degree_min
|
| 25 |
self.avg_degree_max = avg_degree_max
|
| 26 |
self.powerlaw_r2_threshold = powerlaw_r2_threshold
|
| 27 |
|
| 28 |
+
def evaluate(self, kg: BaseGraphStorage) -> Dict[str, Any]:
|
| 29 |
"""
|
| 30 |
Evaluate the structural robustness of the graph.
|
| 31 |
+
:return: Dictionary of structural metrics and robustness verdict. The keys include:
|
| 32 |
+
- total_nodes: Total number of nodes in the graph
|
| 33 |
+
- total_edges: Total number of edges in the graph
|
| 34 |
+
- noise_ratio: Ratio of isolated nodes to total nodes
|
| 35 |
+
- largest_cc_ratio: Ratio of largest connected component size to total nodes
|
| 36 |
+
- avg_degree: Average node degree
|
| 37 |
+
- powerlaw_r2: R² value of power law fit to degree distribution
|
| 38 |
+
- is_robust: Boolean indicating if the graph is structurally robust
|
| 39 |
"""
|
| 40 |
+
total_nodes = kg.get_node_count()
|
| 41 |
+
total_edges = kg.get_edge_count()
|
| 42 |
+
degree_map = kg.get_all_node_degrees()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
# Noise ratio: isolated nodes / total nodes
|
| 45 |
isolated_nodes = [nid for nid, deg in degree_map.items() if deg == 0]
|
| 46 |
noise_ratio = len(isolated_nodes) / total_nodes
|
| 47 |
|
| 48 |
# Largest connected component
|
| 49 |
+
components = kg.get_connected_components(undirected=True)
|
| 50 |
largest_cc_ratio = (
|
| 51 |
len(max(components, key=len)) / total_nodes if components else 0
|
| 52 |
)
|
graphgen/models/evaluator/qa/length_evaluator.py
CHANGED
|
@@ -1,18 +1,19 @@
|
|
| 1 |
-
|
| 2 |
import os
|
| 3 |
-
|
|
|
|
| 4 |
from graphgen.models.tokenizer import Tokenizer
|
| 5 |
|
| 6 |
|
| 7 |
-
class LengthEvaluator(
|
| 8 |
def __init__(self, tokenizer_name: str = None):
|
| 9 |
-
tokenizer_model = tokenizer_name or os.environ.get(
|
|
|
|
|
|
|
| 10 |
self.tokenizer: Tokenizer = Tokenizer(tokenizer_model)
|
| 11 |
|
| 12 |
-
def evaluate(self, pair: QAPair) -> float:
|
| 13 |
"""
|
| 14 |
Evaluate the length of the qa pair.
|
| 15 |
"""
|
| 16 |
content = pair.question + pair.answer
|
| 17 |
-
|
| 18 |
-
return len(tokens)
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
|
| 3 |
+
from graphgen.bases import BaseQAEvaluator, QAPair
|
| 4 |
from graphgen.models.tokenizer import Tokenizer
|
| 5 |
|
| 6 |
|
| 7 |
+
class LengthEvaluator(BaseQAEvaluator):
|
| 8 |
def __init__(self, tokenizer_name: str = None):
|
| 9 |
+
tokenizer_model = tokenizer_name or os.environ.get(
|
| 10 |
+
"TOKENIZER_MODEL", "cl100k_base"
|
| 11 |
+
)
|
| 12 |
self.tokenizer: Tokenizer = Tokenizer(tokenizer_model)
|
| 13 |
|
| 14 |
+
async def evaluate(self, pair: QAPair) -> dict[str, float]:
|
| 15 |
"""
|
| 16 |
Evaluate the length of the qa pair.
|
| 17 |
"""
|
| 18 |
content = pair.question + pair.answer
|
| 19 |
+
return {"length": self.tokenizer.count_tokens(content)}
|
|
|
graphgen/models/evaluator/qa/mtld_evaluator.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
from typing import Set
|
| 2 |
|
| 3 |
-
from graphgen.bases import
|
| 4 |
from graphgen.utils import NLTKHelper, detect_main_language
|
| 5 |
|
| 6 |
|
| 7 |
-
class MTLDEvaluator(
|
| 8 |
"""
|
| 9 |
Metrics for measuring the lexical diversity of text.
|
| 10 |
"""
|
|
@@ -15,7 +15,7 @@ class MTLDEvaluator(BaseEvaluator):
|
|
| 15 |
self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("zh"))
|
| 16 |
self.threshold = threshold
|
| 17 |
|
| 18 |
-
def evaluate(self, pair: QAPair) -> float:
|
| 19 |
"""
|
| 20 |
Calculate the MTLD (Mean Token Length Diversity) score for a given text.
|
| 21 |
|
|
@@ -24,7 +24,7 @@ class MTLDEvaluator(BaseEvaluator):
|
|
| 24 |
"""
|
| 25 |
text = pair.answer
|
| 26 |
if not text or not text.strip():
|
| 27 |
-
return 0
|
| 28 |
|
| 29 |
lang = detect_main_language(text)
|
| 30 |
tokens = self.nltk_helper.word_tokenize(text, lang)
|
|
@@ -34,7 +34,7 @@ class MTLDEvaluator(BaseEvaluator):
|
|
| 34 |
filtered_tokens = [word for word in filtered_tokens if word.isalnum()]
|
| 35 |
|
| 36 |
if not filtered_tokens:
|
| 37 |
-
return 0
|
| 38 |
|
| 39 |
# Compute forward factors
|
| 40 |
forward_factors = self._compute_factors(filtered_tokens, self.threshold)
|
|
@@ -43,7 +43,8 @@ class MTLDEvaluator(BaseEvaluator):
|
|
| 43 |
backward_factors = self._compute_factors(filtered_tokens[::-1], self.threshold)
|
| 44 |
|
| 45 |
# Compute average factors
|
| 46 |
-
|
|
|
|
| 47 |
|
| 48 |
@staticmethod
|
| 49 |
def _compute_factors(tokens: list, threshold: float) -> float:
|
|
|
|
| 1 |
from typing import Set
|
| 2 |
|
| 3 |
+
from graphgen.bases import BaseQAEvaluator, QAPair
|
| 4 |
from graphgen.utils import NLTKHelper, detect_main_language
|
| 5 |
|
| 6 |
|
| 7 |
+
class MTLDEvaluator(BaseQAEvaluator):
|
| 8 |
"""
|
| 9 |
Metrics for measuring the lexical diversity of text.
|
| 10 |
"""
|
|
|
|
| 15 |
self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("zh"))
|
| 16 |
self.threshold = threshold
|
| 17 |
|
| 18 |
+
async def evaluate(self, pair: QAPair) -> dict[str, float]:
|
| 19 |
"""
|
| 20 |
Calculate the MTLD (Mean Token Length Diversity) score for a given text.
|
| 21 |
|
|
|
|
| 24 |
"""
|
| 25 |
text = pair.answer
|
| 26 |
if not text or not text.strip():
|
| 27 |
+
return {"mtld": 0}
|
| 28 |
|
| 29 |
lang = detect_main_language(text)
|
| 30 |
tokens = self.nltk_helper.word_tokenize(text, lang)
|
|
|
|
| 34 |
filtered_tokens = [word for word in filtered_tokens if word.isalnum()]
|
| 35 |
|
| 36 |
if not filtered_tokens:
|
| 37 |
+
return {"mtld": 0}
|
| 38 |
|
| 39 |
# Compute forward factors
|
| 40 |
forward_factors = self._compute_factors(filtered_tokens, self.threshold)
|
|
|
|
| 43 |
backward_factors = self._compute_factors(filtered_tokens[::-1], self.threshold)
|
| 44 |
|
| 45 |
# Compute average factors
|
| 46 |
+
mtld_score = (forward_factors + backward_factors) / 2
|
| 47 |
+
return {"mtld": mtld_score}
|
| 48 |
|
| 49 |
@staticmethod
|
| 50 |
def _compute_factors(tokens: list, threshold: float) -> float:
|
graphgen/models/evaluator/qa/reward_evaluator.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
from typing import Optional
|
| 2 |
-
from graphgen.bases import BaseEvaluator, QAPair
|
| 3 |
|
|
|
|
| 4 |
|
| 5 |
-
|
|
|
|
| 6 |
"""
|
| 7 |
Reward Model Evaluator for single QAPair evaluation.
|
| 8 |
"""
|
|
@@ -15,7 +16,7 @@ class RewardEvaluator(BaseEvaluator):
|
|
| 15 |
):
|
| 16 |
"""
|
| 17 |
Initialize the reward evaluator.
|
| 18 |
-
|
| 19 |
Args:
|
| 20 |
reward_name: Model name or path on HuggingFace Hub
|
| 21 |
max_length: Maximum token length for the model
|
|
@@ -26,6 +27,7 @@ class RewardEvaluator(BaseEvaluator):
|
|
| 26 |
|
| 27 |
import torch
|
| 28 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
|
|
| 29 |
self.torch = torch
|
| 30 |
|
| 31 |
# Set device (auto-detect if not specified)
|
|
@@ -37,15 +39,17 @@ class RewardEvaluator(BaseEvaluator):
|
|
| 37 |
self.model.to(self.device)
|
| 38 |
self.model.eval()
|
| 39 |
except Exception as e:
|
| 40 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
def evaluate(self, pair: QAPair) -> float:
|
| 43 |
"""
|
| 44 |
Evaluate a single question-answer pair using the reward model.
|
| 45 |
-
|
| 46 |
Args:
|
| 47 |
pair: QAPair containing question and answer strings
|
| 48 |
-
|
| 49 |
Returns:
|
| 50 |
Score as a float
|
| 51 |
"""
|
|
@@ -63,4 +67,4 @@ class RewardEvaluator(BaseEvaluator):
|
|
| 63 |
with self.torch.no_grad():
|
| 64 |
score = self.model(**inputs).logits[0].item()
|
| 65 |
|
| 66 |
-
return score
|
|
|
|
| 1 |
from typing import Optional
|
|
|
|
| 2 |
|
| 3 |
+
from graphgen.bases import BaseQAEvaluator, QAPair
|
| 4 |
|
| 5 |
+
|
| 6 |
+
class RewardEvaluator(BaseQAEvaluator):
|
| 7 |
"""
|
| 8 |
Reward Model Evaluator for single QAPair evaluation.
|
| 9 |
"""
|
|
|
|
| 16 |
):
|
| 17 |
"""
|
| 18 |
Initialize the reward evaluator.
|
| 19 |
+
|
| 20 |
Args:
|
| 21 |
reward_name: Model name or path on HuggingFace Hub
|
| 22 |
max_length: Maximum token length for the model
|
|
|
|
| 27 |
|
| 28 |
import torch
|
| 29 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 30 |
+
|
| 31 |
self.torch = torch
|
| 32 |
|
| 33 |
# Set device (auto-detect if not specified)
|
|
|
|
| 39 |
self.model.to(self.device)
|
| 40 |
self.model.eval()
|
| 41 |
except Exception as e:
|
| 42 |
+
raise RuntimeError(
|
| 43 |
+
f"Failed to load reward model '{reward_name}': {e}"
|
| 44 |
+
) from e
|
| 45 |
|
| 46 |
+
async def evaluate(self, pair: QAPair) -> dict[str, float]:
|
| 47 |
"""
|
| 48 |
Evaluate a single question-answer pair using the reward model.
|
| 49 |
+
|
| 50 |
Args:
|
| 51 |
pair: QAPair containing question and answer strings
|
| 52 |
+
|
| 53 |
Returns:
|
| 54 |
Score as a float
|
| 55 |
"""
|
|
|
|
| 67 |
with self.torch.no_grad():
|
| 68 |
score = self.model(**inputs).logits[0].item()
|
| 69 |
|
| 70 |
+
return {"reward_score": score}
|
graphgen/models/evaluator/qa/uni_evaluator.py
CHANGED
|
@@ -1,14 +1,15 @@
|
|
| 1 |
# https://github.com/maszhongming/UniEval/tree/main
|
| 2 |
-
from typing import
|
| 3 |
-
from graphgen.bases import BaseEvaluator, QAPair
|
| 4 |
|
|
|
|
| 5 |
|
| 6 |
-
|
|
|
|
| 7 |
"""
|
| 8 |
UniEvaluator for single QAPair evaluation across quality dimensions.
|
| 9 |
-
|
| 10 |
Dimensions: naturalness, coherence, understandability
|
| 11 |
-
|
| 12 |
Usage:
|
| 13 |
evaluator = UniEvaluator()
|
| 14 |
pair = QAPair(question="...", answer="...")
|
|
@@ -34,6 +35,7 @@ class UniEvaluator(BaseEvaluator):
|
|
| 34 |
"""
|
| 35 |
import torch
|
| 36 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
|
| 37 |
self.torch = torch
|
| 38 |
|
| 39 |
self.model_name = model_name or self.DEFAULT_MODEL
|
|
@@ -58,10 +60,12 @@ class UniEvaluator(BaseEvaluator):
|
|
| 58 |
if dimension == "coherence":
|
| 59 |
return f"question: Is this a coherent response? </s> response: {answer} </s> history: {question}"
|
| 60 |
if dimension == "understandability":
|
| 61 |
-
return
|
|
|
|
|
|
|
| 62 |
raise NotImplementedError(f"Unsupported dimension '{dimension}'")
|
| 63 |
|
| 64 |
-
def evaluate(
|
| 65 |
self,
|
| 66 |
pair: QAPair,
|
| 67 |
dimensions: Optional[List[str]] = None,
|
|
@@ -72,7 +76,9 @@ class UniEvaluator(BaseEvaluator):
|
|
| 72 |
# Validate dimensions
|
| 73 |
invalid = set(dimensions) - set(self.DEFAULT_DIMS)
|
| 74 |
if invalid:
|
| 75 |
-
raise ValueError(
|
|
|
|
|
|
|
| 76 |
|
| 77 |
results = {}
|
| 78 |
no_token = self.torch.tensor([[self._no_id]], device=self.device)
|
|
@@ -95,7 +101,9 @@ class UniEvaluator(BaseEvaluator):
|
|
| 95 |
attention_mask=src_mask,
|
| 96 |
labels=no_token,
|
| 97 |
use_cache=False,
|
| 98 |
-
).logits[
|
|
|
|
|
|
|
| 99 |
|
| 100 |
probs = self.torch.softmax(logits, dim=-1)[0]
|
| 101 |
score = probs[self._yes_id] / (probs[self._yes_id] + probs[self._no_id])
|
|
|
|
| 1 |
# https://github.com/maszhongming/UniEval/tree/main
|
| 2 |
+
from typing import List, Optional
|
|
|
|
| 3 |
|
| 4 |
+
from graphgen.bases import BaseQAEvaluator, QAPair
|
| 5 |
|
| 6 |
+
|
| 7 |
+
class UniEvaluator(BaseQAEvaluator):
|
| 8 |
"""
|
| 9 |
UniEvaluator for single QAPair evaluation across quality dimensions.
|
| 10 |
+
|
| 11 |
Dimensions: naturalness, coherence, understandability
|
| 12 |
+
|
| 13 |
Usage:
|
| 14 |
evaluator = UniEvaluator()
|
| 15 |
pair = QAPair(question="...", answer="...")
|
|
|
|
| 35 |
"""
|
| 36 |
import torch
|
| 37 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 38 |
+
|
| 39 |
self.torch = torch
|
| 40 |
|
| 41 |
self.model_name = model_name or self.DEFAULT_MODEL
|
|
|
|
| 60 |
if dimension == "coherence":
|
| 61 |
return f"question: Is this a coherent response? </s> response: {answer} </s> history: {question}"
|
| 62 |
if dimension == "understandability":
|
| 63 |
+
return (
|
| 64 |
+
f"question: Is this an understandable response? </s> response: {answer}"
|
| 65 |
+
)
|
| 66 |
raise NotImplementedError(f"Unsupported dimension '{dimension}'")
|
| 67 |
|
| 68 |
+
async def evaluate(
|
| 69 |
self,
|
| 70 |
pair: QAPair,
|
| 71 |
dimensions: Optional[List[str]] = None,
|
|
|
|
| 76 |
# Validate dimensions
|
| 77 |
invalid = set(dimensions) - set(self.DEFAULT_DIMS)
|
| 78 |
if invalid:
|
| 79 |
+
raise ValueError(
|
| 80 |
+
f"Invalid dimensions: {invalid}. Available: {self.DEFAULT_DIMS}"
|
| 81 |
+
)
|
| 82 |
|
| 83 |
results = {}
|
| 84 |
no_token = self.torch.tensor([[self._no_id]], device=self.device)
|
|
|
|
| 101 |
attention_mask=src_mask,
|
| 102 |
labels=no_token,
|
| 103 |
use_cache=False,
|
| 104 |
+
).logits[
|
| 105 |
+
:, 0, :
|
| 106 |
+
] # [1, vocab_size]
|
| 107 |
|
| 108 |
probs = self.torch.softmax(logits, dim=-1)[0]
|
| 109 |
score = probs[self._yes_id] / (probs[self._yes_id] + probs[self._no_id])
|
graphgen/models/evaluator/triple/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .accuracy_evaluator import AccuracyEvaluator
|
graphgen/models/evaluator/triple/accuracy_evaluator.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json
import re
from typing import Any, Dict

from graphgen.bases import BaseLLMWrapper, BaseTripleEvaluator
from graphgen.templates import ACCURACY_EVALUATION_PROMPT
from graphgen.utils import detect_main_language, logger


class AccuracyEvaluator(BaseTripleEvaluator):
    """Evaluates accuracy of entity recognition and relation extraction using LLM-as-a-Judge.

    For each chunk, uses LLM to evaluate the quality of extracted entities and relations
    by comparing them with the original chunk content. Provides multi-dimensional quality
    scores (accuracy, completeness, precision).
    """

    def __init__(
        self,
        llm_client: BaseLLMWrapper,
    ):
        self.llm_client = llm_client

    @staticmethod
    def _parse_evaluation(response: str) -> Dict[str, Any]:
        """Parse an LLM judge response into an evaluation dict.

        Tries strict JSON first, then a best-effort extraction of the first
        ``{...}`` span (e.g. when the model wraps the JSON in markdown code
        fences). Returns a zero-score fallback dict when nothing parseable
        is found, so a single malformed response never aborts evaluation.
        """
        try:
            return json.loads(response)
        except json.JSONDecodeError:
            json_match = re.search(r"\{.*\}", response, re.DOTALL)
            if json_match:
                try:
                    return json.loads(json_match.group(0))
                except json.JSONDecodeError:
                    # The extracted span was not valid JSON either; fall
                    # through to the default evaluation below instead of
                    # letting the exception escape.
                    pass
        logger.warning("Failed to parse LLM response.")
        # default evaluation
        return {
            "accuracy": 0.0,
            "completeness": 0.0,
            "precision": 0.0,
            "overall_score": 0.0,
            "accuracy_reasoning": "Failed to parse LLM response",
            "completeness_reasoning": "",
            "precision_reasoning": "",
            "issues": ["LLM response parsing failed"],
        }

    async def evaluate(self, unit: tuple) -> Dict[str, Any]:
        """Evaluate entity and relation extraction quality using LLM-as-a-Judge.

        Args:
            unit: A ``(chunk_content, nodes, edges)`` triple for one chunk.

        Returns:
            Dictionary containing entity_accuracy and relation_accuracy metrics.
        """
        chunk_content, nodes, edges = unit
        lang = detect_main_language(chunk_content)

        # node: judge the extracted entities against the source chunk
        prompt = ACCURACY_EVALUATION_PROMPT[lang]["ENTITY"].format(
            chunk_content=chunk_content,
            extracted_entities=json.dumps(nodes, ensure_ascii=False, indent=2),
        )
        response = await self.llm_client.generate_answer(prompt)
        node_evaluation_result = self._parse_evaluation(response)

        # edge: judge the extracted relations against the source chunk
        prompt = ACCURACY_EVALUATION_PROMPT[lang]["RELATION"].format(
            chunk_content=chunk_content,
            extracted_relations=json.dumps(edges, ensure_ascii=False, indent=2),
        )
        response = await self.llm_client.generate_answer(prompt)
        edge_evaluation_result = self._parse_evaluation(response)

        return {
            "entity_accuracy": node_evaluation_result,
            "relation_accuracy": edge_evaluation_result,
        }
|
graphgen/models/extractor/schema_guided_extractor.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
| 1 |
import json
|
| 2 |
-
from typing import Dict, List
|
| 3 |
|
| 4 |
-
from graphgen.bases import BaseExtractor, BaseLLMWrapper
|
| 5 |
from graphgen.templates import SCHEMA_GUIDED_EXTRACTION_PROMPT
|
| 6 |
-
from graphgen.utils import
|
| 7 |
|
| 8 |
|
| 9 |
class SchemaGuidedExtractor(BaseExtractor):
|
|
@@ -59,9 +58,8 @@ class SchemaGuidedExtractor(BaseExtractor):
|
|
| 59 |
)
|
| 60 |
return prompt
|
| 61 |
|
| 62 |
-
async def extract(self, chunk:
|
| 63 |
-
|
| 64 |
-
text = chunk.get("content", "")
|
| 65 |
|
| 66 |
prompt = self.build_prompt(text)
|
| 67 |
response = await self.llm_client.generate_answer(prompt)
|
|
@@ -74,35 +72,9 @@ class SchemaGuidedExtractor(BaseExtractor):
|
|
| 74 |
if any(extracted_info[key] == "" for key in self.required_keys):
|
| 75 |
logger.debug("Missing required keys in extraction: %s", extracted_info)
|
| 76 |
return {}
|
| 77 |
-
main_keys_info = {key: extracted_info[key] for key in self.required_keys}
|
| 78 |
logger.debug("Extracted info: %s", extracted_info)
|
|
|
|
| 79 |
|
| 80 |
-
# add chunk metadata
|
| 81 |
-
extracted_info["_chunk_id"] = _chunk_id
|
| 82 |
-
|
| 83 |
-
return {
|
| 84 |
-
compute_dict_hash(main_keys_info, prefix="extract-"): extracted_info
|
| 85 |
-
}
|
| 86 |
except json.JSONDecodeError:
|
| 87 |
logger.error("Failed to parse extraction response: %s", response)
|
| 88 |
return {}
|
| 89 |
-
|
| 90 |
-
@staticmethod
|
| 91 |
-
def merge_extractions(extraction_list: List[Dict[str, dict]]) -> Dict[str, dict]:
|
| 92 |
-
"""
|
| 93 |
-
Merge multiple extraction results based on their hashes.
|
| 94 |
-
:param extraction_list: List of extraction results, each is a dict with hash as key and record as value.
|
| 95 |
-
:return: Merged extraction results.
|
| 96 |
-
"""
|
| 97 |
-
merged: Dict[str, dict] = {}
|
| 98 |
-
for ext in extraction_list:
|
| 99 |
-
for h, rec in ext.items():
|
| 100 |
-
if h not in merged:
|
| 101 |
-
merged[h] = rec.copy()
|
| 102 |
-
else:
|
| 103 |
-
for k, v in rec.items():
|
| 104 |
-
if k not in merged[h] or merged[h][k] == v:
|
| 105 |
-
merged[h][k] = v
|
| 106 |
-
else:
|
| 107 |
-
merged[h][k] = f"{merged[h][k]}<SEP>{v}"
|
| 108 |
-
return merged
|
|
|
|
| 1 |
import json
|
|
|
|
| 2 |
|
| 3 |
+
from graphgen.bases import BaseExtractor, BaseLLMWrapper, Chunk
|
| 4 |
from graphgen.templates import SCHEMA_GUIDED_EXTRACTION_PROMPT
|
| 5 |
+
from graphgen.utils import detect_main_language, logger
|
| 6 |
|
| 7 |
|
| 8 |
class SchemaGuidedExtractor(BaseExtractor):
|
|
|
|
| 58 |
)
|
| 59 |
return prompt
|
| 60 |
|
| 61 |
+
async def extract(self, chunk: Chunk) -> dict:
|
| 62 |
+
text = chunk.content
|
|
|
|
| 63 |
|
| 64 |
prompt = self.build_prompt(text)
|
| 65 |
response = await self.llm_client.generate_answer(prompt)
|
|
|
|
| 72 |
if any(extracted_info[key] == "" for key in self.required_keys):
|
| 73 |
logger.debug("Missing required keys in extraction: %s", extracted_info)
|
| 74 |
return {}
|
|
|
|
| 75 |
logger.debug("Extracted info: %s", extracted_info)
|
| 76 |
+
return extracted_info
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
except json.JSONDecodeError:
|
| 79 |
logger.error("Failed to parse extraction response: %s", response)
|
| 80 |
return {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/models/generator/aggregated_generator.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import Any, Optional
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import AGGREGATED_GENERATION_PROMPT
|
| 6 |
-
from graphgen.utils import
|
| 7 |
|
| 8 |
|
| 9 |
class AggregatedGenerator(BaseGenerator):
|
|
@@ -101,30 +101,26 @@ class AggregatedGenerator(BaseGenerator):
|
|
| 101 |
batch: tuple[
|
| 102 |
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
|
| 103 |
],
|
| 104 |
-
) -> dict
|
| 105 |
"""
|
| 106 |
Generate QAs based on a given batch.
|
| 107 |
:param batch
|
| 108 |
:return: QA pairs
|
| 109 |
"""
|
| 110 |
-
result = {}
|
| 111 |
rephrasing_prompt = self.build_prompt(batch)
|
| 112 |
response = await self.llm_client.generate_answer(rephrasing_prompt)
|
| 113 |
context = self.parse_rephrased_text(response)
|
| 114 |
if not context:
|
| 115 |
-
return
|
| 116 |
question_generation_prompt = self._build_prompt_for_question_generation(context)
|
| 117 |
response = await self.llm_client.generate_answer(question_generation_prompt)
|
| 118 |
question = self.parse_response(response)["question"]
|
| 119 |
if not question:
|
| 120 |
-
return
|
| 121 |
logger.debug("Question: %s", question)
|
| 122 |
logger.debug("Answer: %s", context)
|
| 123 |
qa_pairs = {
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
"answer": context,
|
| 127 |
-
}
|
| 128 |
}
|
| 129 |
-
|
| 130 |
-
return result
|
|
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import AGGREGATED_GENERATION_PROMPT
|
| 6 |
+
from graphgen.utils import detect_main_language, logger
|
| 7 |
|
| 8 |
|
| 9 |
class AggregatedGenerator(BaseGenerator):
|
|
|
|
| 101 |
batch: tuple[
|
| 102 |
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
|
| 103 |
],
|
| 104 |
+
) -> list[dict]:
|
| 105 |
"""
|
| 106 |
Generate QAs based on a given batch.
|
| 107 |
:param batch
|
| 108 |
:return: QA pairs
|
| 109 |
"""
|
|
|
|
| 110 |
rephrasing_prompt = self.build_prompt(batch)
|
| 111 |
response = await self.llm_client.generate_answer(rephrasing_prompt)
|
| 112 |
context = self.parse_rephrased_text(response)
|
| 113 |
if not context:
|
| 114 |
+
return []
|
| 115 |
question_generation_prompt = self._build_prompt_for_question_generation(context)
|
| 116 |
response = await self.llm_client.generate_answer(question_generation_prompt)
|
| 117 |
question = self.parse_response(response)["question"]
|
| 118 |
if not question:
|
| 119 |
+
return []
|
| 120 |
logger.debug("Question: %s", question)
|
| 121 |
logger.debug("Answer: %s", context)
|
| 122 |
qa_pairs = {
|
| 123 |
+
"question": question,
|
| 124 |
+
"answer": context,
|
|
|
|
|
|
|
| 125 |
}
|
| 126 |
+
return [qa_pairs]
|
|
|
graphgen/models/generator/atomic_generator.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import Any
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import ATOMIC_GENERATION_PROMPT
|
| 6 |
-
from graphgen.utils import
|
| 7 |
|
| 8 |
|
| 9 |
class AtomicGenerator(BaseGenerator):
|
|
@@ -23,7 +23,7 @@ class AtomicGenerator(BaseGenerator):
|
|
| 23 |
return prompt
|
| 24 |
|
| 25 |
@staticmethod
|
| 26 |
-
def parse_response(response: str) -> dict:
|
| 27 |
"""
|
| 28 |
AtomicGenerator normally generates one QA pair per response.
|
| 29 |
So we just need to parse one QA pair from the response.
|
|
@@ -38,15 +38,10 @@ class AtomicGenerator(BaseGenerator):
|
|
| 38 |
answer = answer_match.group(1).strip()
|
| 39 |
else:
|
| 40 |
logger.warning("Failed to parse response: %s", response)
|
| 41 |
-
return
|
| 42 |
|
| 43 |
question = question.strip('"').strip("'")
|
| 44 |
answer = answer.strip('"').strip("'")
|
| 45 |
logger.debug("Question: %s", question)
|
| 46 |
logger.debug("Answer: %s", answer)
|
| 47 |
-
return {
|
| 48 |
-
compute_content_hash(question): {
|
| 49 |
-
"question": question,
|
| 50 |
-
"answer": answer,
|
| 51 |
-
}
|
| 52 |
-
}
|
|
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import ATOMIC_GENERATION_PROMPT
|
| 6 |
+
from graphgen.utils import detect_main_language, logger
|
| 7 |
|
| 8 |
|
| 9 |
class AtomicGenerator(BaseGenerator):
|
|
|
|
| 23 |
return prompt
|
| 24 |
|
| 25 |
@staticmethod
|
| 26 |
+
def parse_response(response: str) -> list[dict]:
|
| 27 |
"""
|
| 28 |
AtomicGenerator normally generates one QA pair per response.
|
| 29 |
So we just need to parse one QA pair from the response.
|
|
|
|
| 38 |
answer = answer_match.group(1).strip()
|
| 39 |
else:
|
| 40 |
logger.warning("Failed to parse response: %s", response)
|
| 41 |
+
return []
|
| 42 |
|
| 43 |
question = question.strip('"').strip("'")
|
| 44 |
answer = answer.strip('"').strip("'")
|
| 45 |
logger.debug("Question: %s", question)
|
| 46 |
logger.debug("Answer: %s", answer)
|
| 47 |
+
return [{"question": question, "answer": answer}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/models/generator/cot_generator.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import Any
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import COT_GENERATION_PROMPT
|
| 6 |
-
from graphgen.utils import
|
| 7 |
|
| 8 |
|
| 9 |
class CoTGenerator(BaseGenerator):
|
|
@@ -100,28 +100,25 @@ class CoTGenerator(BaseGenerator):
|
|
| 100 |
batch: tuple[
|
| 101 |
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
|
| 102 |
],
|
| 103 |
-
) -> dict
|
| 104 |
"""
|
| 105 |
Generate QAs based on a given batch.
|
| 106 |
:param batch
|
| 107 |
:return: QA pairs
|
| 108 |
"""
|
| 109 |
-
result = {}
|
| 110 |
prompt = self.build_prompt(batch)
|
| 111 |
response = await self.llm_client.generate_answer(prompt)
|
| 112 |
response = self.parse_response(response)
|
| 113 |
if not response:
|
| 114 |
-
return
|
| 115 |
question, reasoning_path = response["question"], response["reasoning_path"]
|
| 116 |
prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path)
|
| 117 |
cot_answer = await self.llm_client.generate_answer(prompt)
|
| 118 |
logger.debug("CoT Answer: %s", cot_answer)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
"question": question,
|
| 122 |
"answer": cot_answer,
|
| 123 |
"reasoning_path": reasoning_path,
|
| 124 |
}
|
| 125 |
-
|
| 126 |
-
result.update(qa_pairs)
|
| 127 |
-
return result
|
|
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import COT_GENERATION_PROMPT
|
| 6 |
+
from graphgen.utils import detect_main_language, logger
|
| 7 |
|
| 8 |
|
| 9 |
class CoTGenerator(BaseGenerator):
|
|
|
|
| 100 |
batch: tuple[
|
| 101 |
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
|
| 102 |
],
|
| 103 |
+
) -> list[dict]:
|
| 104 |
"""
|
| 105 |
Generate QAs based on a given batch.
|
| 106 |
:param batch
|
| 107 |
:return: QA pairs
|
| 108 |
"""
|
|
|
|
| 109 |
prompt = self.build_prompt(batch)
|
| 110 |
response = await self.llm_client.generate_answer(prompt)
|
| 111 |
response = self.parse_response(response)
|
| 112 |
if not response:
|
| 113 |
+
return []
|
| 114 |
question, reasoning_path = response["question"], response["reasoning_path"]
|
| 115 |
prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path)
|
| 116 |
cot_answer = await self.llm_client.generate_answer(prompt)
|
| 117 |
logger.debug("CoT Answer: %s", cot_answer)
|
| 118 |
+
return [
|
| 119 |
+
{
|
| 120 |
"question": question,
|
| 121 |
"answer": cot_answer,
|
| 122 |
"reasoning_path": reasoning_path,
|
| 123 |
}
|
| 124 |
+
]
|
|
|
|
|
|
graphgen/models/generator/fill_in_blank_generator.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import Any
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import FILL_IN_BLANK_GENERATION_PROMPT
|
| 6 |
-
from graphgen.utils import
|
| 7 |
|
| 8 |
|
| 9 |
class FillInBlankGenerator(BaseGenerator):
|
|
@@ -12,7 +12,7 @@ class FillInBlankGenerator(BaseGenerator):
|
|
| 12 |
self.num_of_questions = num_of_questions
|
| 13 |
|
| 14 |
@staticmethod
|
| 15 |
-
def parse_response(response: str) ->
|
| 16 |
"""
|
| 17 |
Parse fill-in-the-blank QA pairs from the LLM response.
|
| 18 |
Each QA pair contains question text with placeholders and the correct answer(s).
|
|
@@ -21,14 +21,14 @@ class FillInBlankGenerator(BaseGenerator):
|
|
| 21 |
:return: Dictionary mapping question hash to question data, where each
|
| 22 |
value is a dict with "question", "answer", and "answers" keys
|
| 23 |
"""
|
| 24 |
-
qa_pairs =
|
| 25 |
|
| 26 |
# Extract all QA pair blocks
|
| 27 |
qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
|
| 28 |
|
| 29 |
if not qa_blocks:
|
| 30 |
logger.warning("No QA pairs found in response: %s", response)
|
| 31 |
-
return
|
| 32 |
|
| 33 |
for block in qa_blocks:
|
| 34 |
# Extract and clean question text
|
|
@@ -55,13 +55,13 @@ class FillInBlankGenerator(BaseGenerator):
|
|
| 55 |
logger.warning("No valid answers found in: %s", answer_text)
|
| 56 |
continue
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
|
| 66 |
logger.debug(
|
| 67 |
"Successfully parsed fill-in-the-blank question: %s", question[:50]
|
|
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import FILL_IN_BLANK_GENERATION_PROMPT
|
| 6 |
+
from graphgen.utils import detect_main_language, logger
|
| 7 |
|
| 8 |
|
| 9 |
class FillInBlankGenerator(BaseGenerator):
|
|
|
|
| 12 |
self.num_of_questions = num_of_questions
|
| 13 |
|
| 14 |
@staticmethod
|
| 15 |
+
def parse_response(response: str) -> list[dict]:
|
| 16 |
"""
|
| 17 |
Parse fill-in-the-blank QA pairs from the LLM response.
|
| 18 |
Each QA pair contains question text with placeholders and the correct answer(s).
|
|
|
|
| 21 |
:return: Dictionary mapping question hash to question data, where each
|
| 22 |
value is a dict with "question", "answer", and "answers" keys
|
| 23 |
"""
|
| 24 |
+
qa_pairs = []
|
| 25 |
|
| 26 |
# Extract all QA pair blocks
|
| 27 |
qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
|
| 28 |
|
| 29 |
if not qa_blocks:
|
| 30 |
logger.warning("No QA pairs found in response: %s", response)
|
| 31 |
+
return qa_pairs
|
| 32 |
|
| 33 |
for block in qa_blocks:
|
| 34 |
# Extract and clean question text
|
|
|
|
| 55 |
logger.warning("No valid answers found in: %s", answer_text)
|
| 56 |
continue
|
| 57 |
|
| 58 |
+
qa_pairs.append(
|
| 59 |
+
{
|
| 60 |
+
"question": question,
|
| 61 |
+
"answer": answer_text, # Original answer text with commas
|
| 62 |
+
"answers": answers, # List of individual answers: ["A8X"] or ["A8X", "八百万"]
|
| 63 |
+
}
|
| 64 |
+
)
|
| 65 |
|
| 66 |
logger.debug(
|
| 67 |
"Successfully parsed fill-in-the-blank question: %s", question[:50]
|
graphgen/models/generator/multi_answer_generator.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import Any
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import MAQ_GENERATION_PROMPT
|
| 6 |
-
from graphgen.utils import
|
| 7 |
|
| 8 |
|
| 9 |
class MultiAnswerGenerator(BaseGenerator):
|
|
@@ -12,7 +12,7 @@ class MultiAnswerGenerator(BaseGenerator):
|
|
| 12 |
self.num_of_questions = num_of_questions
|
| 13 |
|
| 14 |
@staticmethod
|
| 15 |
-
def parse_response(response: str) ->
|
| 16 |
"""
|
| 17 |
Parse multiple-answer QA pairs from the LLM response.
|
| 18 |
Each QA pair contains question text, four options, and the correct answers (one or more).
|
|
@@ -21,14 +21,14 @@ class MultiAnswerGenerator(BaseGenerator):
|
|
| 21 |
:return: Dictionary mapping question hash to question data, where each
|
| 22 |
value is a dict with "question", "options", and "answer" keys
|
| 23 |
"""
|
| 24 |
-
qa_pairs =
|
| 25 |
|
| 26 |
# Extract all QA pair blocks
|
| 27 |
qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
|
| 28 |
|
| 29 |
if not qa_blocks:
|
| 30 |
logger.warning("No QA pairs found in response: %s", response)
|
| 31 |
-
return
|
| 32 |
|
| 33 |
for block in qa_blocks:
|
| 34 |
# Extract and clean question text
|
|
@@ -61,7 +61,9 @@ class MultiAnswerGenerator(BaseGenerator):
|
|
| 61 |
logger.warning("Failed to parse answer from block: %s", block)
|
| 62 |
continue
|
| 63 |
answer_text = ans_match.group(1).strip().strip('"').strip("'")
|
| 64 |
-
answers = [
|
|
|
|
|
|
|
| 65 |
invalid_answers = [ans for ans in answers if ans not in options]
|
| 66 |
if invalid_answers:
|
| 67 |
logger.warning(
|
|
@@ -76,13 +78,13 @@ class MultiAnswerGenerator(BaseGenerator):
|
|
| 76 |
logger.warning("No valid answers found in: %s", answer_text)
|
| 77 |
continue
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
|
| 87 |
logger.debug("Successfully parsed MAQ: %s", question[:50])
|
| 88 |
|
|
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import MAQ_GENERATION_PROMPT
|
| 6 |
+
from graphgen.utils import detect_main_language, logger
|
| 7 |
|
| 8 |
|
| 9 |
class MultiAnswerGenerator(BaseGenerator):
|
|
|
|
| 12 |
self.num_of_questions = num_of_questions
|
| 13 |
|
| 14 |
@staticmethod
|
| 15 |
+
def parse_response(response: str) -> list[dict]:
|
| 16 |
"""
|
| 17 |
Parse multiple-answer QA pairs from the LLM response.
|
| 18 |
Each QA pair contains question text, four options, and the correct answers (one or more).
|
|
|
|
| 21 |
:return: Dictionary mapping question hash to question data, where each
|
| 22 |
value is a dict with "question", "options", and "answer" keys
|
| 23 |
"""
|
| 24 |
+
qa_pairs = []
|
| 25 |
|
| 26 |
# Extract all QA pair blocks
|
| 27 |
qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
|
| 28 |
|
| 29 |
if not qa_blocks:
|
| 30 |
logger.warning("No QA pairs found in response: %s", response)
|
| 31 |
+
return qa_pairs
|
| 32 |
|
| 33 |
for block in qa_blocks:
|
| 34 |
# Extract and clean question text
|
|
|
|
| 61 |
logger.warning("Failed to parse answer from block: %s", block)
|
| 62 |
continue
|
| 63 |
answer_text = ans_match.group(1).strip().strip('"').strip("'")
|
| 64 |
+
answers = [
|
| 65 |
+
ans.strip().upper() for ans in answer_text.split(",") if ans.strip()
|
| 66 |
+
]
|
| 67 |
invalid_answers = [ans for ans in answers if ans not in options]
|
| 68 |
if invalid_answers:
|
| 69 |
logger.warning(
|
|
|
|
| 78 |
logger.warning("No valid answers found in: %s", answer_text)
|
| 79 |
continue
|
| 80 |
|
| 81 |
+
qa_pairs.append(
|
| 82 |
+
{
|
| 83 |
+
"question": question,
|
| 84 |
+
"options": options, # Dict like {"A": "text", "B": "text", ...}
|
| 85 |
+
"answers": answers, # List of correct answers: ["A", "C"]
|
| 86 |
+
}
|
| 87 |
+
)
|
| 88 |
|
| 89 |
logger.debug("Successfully parsed MAQ: %s", question[:50])
|
| 90 |
|
graphgen/models/generator/multi_choice_generator.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import Any
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import MCQ_GENERATION_PROMPT
|
| 6 |
-
from graphgen.utils import
|
| 7 |
|
| 8 |
|
| 9 |
class MultiChoiceGenerator(BaseGenerator):
|
|
@@ -12,7 +12,7 @@ class MultiChoiceGenerator(BaseGenerator):
|
|
| 12 |
self.num_of_questions = num_of_questions
|
| 13 |
|
| 14 |
@staticmethod
|
| 15 |
-
def parse_response(response: str) ->
|
| 16 |
"""
|
| 17 |
Parse multiple choice QA pairs from the LLM response.
|
| 18 |
Each QA pair contains question text, four options, and the correct answer.
|
|
@@ -21,14 +21,14 @@ class MultiChoiceGenerator(BaseGenerator):
|
|
| 21 |
:return: Dictionary mapping question hash to question data, where each
|
| 22 |
value is a dict with "question", "options", and "answer" keys
|
| 23 |
"""
|
| 24 |
-
qa_pairs =
|
| 25 |
|
| 26 |
# Extract all QA pair blocks
|
| 27 |
qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
|
| 28 |
|
| 29 |
if not qa_blocks:
|
| 30 |
logger.warning("No QA pairs found in response: %s", response)
|
| 31 |
-
return
|
| 32 |
|
| 33 |
for block in qa_blocks:
|
| 34 |
# Extract and clean question text
|
|
@@ -76,13 +76,13 @@ class MultiChoiceGenerator(BaseGenerator):
|
|
| 76 |
)
|
| 77 |
continue
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
|
| 87 |
logger.debug("Successfully parsed MCQ: %s", question[:50])
|
| 88 |
|
|
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import MCQ_GENERATION_PROMPT
|
| 6 |
+
from graphgen.utils import detect_main_language, logger
|
| 7 |
|
| 8 |
|
| 9 |
class MultiChoiceGenerator(BaseGenerator):
|
|
|
|
| 12 |
self.num_of_questions = num_of_questions
|
| 13 |
|
| 14 |
@staticmethod
|
| 15 |
+
def parse_response(response: str) -> list[dict]:
|
| 16 |
"""
|
| 17 |
Parse multiple choice QA pairs from the LLM response.
|
| 18 |
Each QA pair contains question text, four options, and the correct answer.
|
|
|
|
| 21 |
:return: Dictionary mapping question hash to question data, where each
|
| 22 |
value is a dict with "question", "options", and "answer" keys
|
| 23 |
"""
|
| 24 |
+
qa_pairs = []
|
| 25 |
|
| 26 |
# Extract all QA pair blocks
|
| 27 |
qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
|
| 28 |
|
| 29 |
if not qa_blocks:
|
| 30 |
logger.warning("No QA pairs found in response: %s", response)
|
| 31 |
+
return qa_pairs
|
| 32 |
|
| 33 |
for block in qa_blocks:
|
| 34 |
# Extract and clean question text
|
|
|
|
| 76 |
)
|
| 77 |
continue
|
| 78 |
|
| 79 |
+
qa_pairs.append(
|
| 80 |
+
{
|
| 81 |
+
"question": question,
|
| 82 |
+
"options": options, # Dict like {"A": "text", "B": "text", ...}
|
| 83 |
+
"answer": answer, # Single letter: "A", "B", "C", or "D"
|
| 84 |
+
}
|
| 85 |
+
)
|
| 86 |
|
| 87 |
logger.debug("Successfully parsed MCQ: %s", question[:50])
|
| 88 |
|
graphgen/models/generator/multi_hop_generator.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import Any
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import MULTI_HOP_GENERATION_PROMPT
|
| 6 |
-
from graphgen.utils import
|
| 7 |
|
| 8 |
|
| 9 |
class MultiHopGenerator(BaseGenerator):
|
|
@@ -32,7 +32,7 @@ class MultiHopGenerator(BaseGenerator):
|
|
| 32 |
return prompt
|
| 33 |
|
| 34 |
@staticmethod
|
| 35 |
-
def parse_response(response: str) -> dict:
|
| 36 |
question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
|
| 37 |
answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
|
| 38 |
|
|
@@ -41,15 +41,10 @@ class MultiHopGenerator(BaseGenerator):
|
|
| 41 |
answer = answer_match.group(1).strip()
|
| 42 |
else:
|
| 43 |
logger.warning("Failed to parse response: %s", response)
|
| 44 |
-
return
|
| 45 |
|
| 46 |
question = question.strip('"').strip("'")
|
| 47 |
answer = answer.strip('"').strip("'")
|
| 48 |
logger.debug("Question: %s", question)
|
| 49 |
logger.debug("Answer: %s", answer)
|
| 50 |
-
return {
|
| 51 |
-
compute_content_hash(question): {
|
| 52 |
-
"question": question,
|
| 53 |
-
"answer": answer,
|
| 54 |
-
}
|
| 55 |
-
}
|
|
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import MULTI_HOP_GENERATION_PROMPT
|
| 6 |
+
from graphgen.utils import detect_main_language, logger
|
| 7 |
|
| 8 |
|
| 9 |
class MultiHopGenerator(BaseGenerator):
|
|
|
|
| 32 |
return prompt
|
| 33 |
|
| 34 |
@staticmethod
|
| 35 |
+
def parse_response(response: str) -> list[dict]:
|
| 36 |
question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
|
| 37 |
answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
|
| 38 |
|
|
|
|
| 41 |
answer = answer_match.group(1).strip()
|
| 42 |
else:
|
| 43 |
logger.warning("Failed to parse response: %s", response)
|
| 44 |
+
return []
|
| 45 |
|
| 46 |
question = question.strip('"').strip("'")
|
| 47 |
answer = answer.strip('"').strip("'")
|
| 48 |
logger.debug("Question: %s", question)
|
| 49 |
logger.debug("Answer: %s", answer)
|
| 50 |
+
return [{"question": question, "answer": answer}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/models/generator/quiz_generator.py
CHANGED
|
@@ -31,12 +31,16 @@ class QuizGenerator(BaseGenerator):
|
|
| 31 |
description = edges[0][2].get("description", "")
|
| 32 |
template_type = edges[0][2].get("template_type", "TEMPLATE")
|
| 33 |
else:
|
| 34 |
-
raise ValueError(
|
|
|
|
|
|
|
| 35 |
|
| 36 |
return QuizGenerator.build_prompt_for_description(description, template_type)
|
| 37 |
|
| 38 |
@staticmethod
|
| 39 |
-
def build_prompt_for_description(
|
|
|
|
|
|
|
| 40 |
"""
|
| 41 |
Build prompt for rephrasing a single description.
|
| 42 |
:param description: The description to rephrase
|
|
@@ -49,17 +53,6 @@ class QuizGenerator(BaseGenerator):
|
|
| 49 |
)
|
| 50 |
return prompt
|
| 51 |
|
| 52 |
-
@staticmethod
|
| 53 |
-
def parse_rephrased_text(response: str) -> str:
|
| 54 |
-
"""
|
| 55 |
-
Parse the rephrased text from the response.
|
| 56 |
-
:param response:
|
| 57 |
-
:return:
|
| 58 |
-
"""
|
| 59 |
-
rephrased_text = response.strip().strip('"')
|
| 60 |
-
logger.debug("Rephrased Text: %s", rephrased_text)
|
| 61 |
-
return rephrased_text
|
| 62 |
-
|
| 63 |
@staticmethod
|
| 64 |
def parse_response(response: str) -> Any:
|
| 65 |
"""
|
|
@@ -67,4 +60,15 @@ class QuizGenerator(BaseGenerator):
|
|
| 67 |
:param response: LLM response
|
| 68 |
:return: Rephrased text
|
| 69 |
"""
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
description = edges[0][2].get("description", "")
|
| 32 |
template_type = edges[0][2].get("template_type", "TEMPLATE")
|
| 33 |
else:
|
| 34 |
+
raise ValueError(
|
| 35 |
+
"Batch must contain at least one node or edge with description"
|
| 36 |
+
)
|
| 37 |
|
| 38 |
return QuizGenerator.build_prompt_for_description(description, template_type)
|
| 39 |
|
| 40 |
@staticmethod
|
| 41 |
+
def build_prompt_for_description(
|
| 42 |
+
description: str, template_type: str = "TEMPLATE"
|
| 43 |
+
) -> str:
|
| 44 |
"""
|
| 45 |
Build prompt for rephrasing a single description.
|
| 46 |
:param description: The description to rephrase
|
|
|
|
| 53 |
)
|
| 54 |
return prompt
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
@staticmethod
|
| 57 |
def parse_response(response: str) -> Any:
|
| 58 |
"""
|
|
|
|
| 60 |
:param response: LLM response
|
| 61 |
:return: Rephrased text
|
| 62 |
"""
|
| 63 |
+
|
| 64 |
+
def parse_rephrased_text(content: str) -> str:
|
| 65 |
+
"""
|
| 66 |
+
Parse the rephrased text from the response.
|
| 67 |
+
:param content: LLM response content
|
| 68 |
+
:return:
|
| 69 |
+
"""
|
| 70 |
+
rephrased_text = content.strip().strip('"')
|
| 71 |
+
logger.debug("Rephrased Text: %s", rephrased_text)
|
| 72 |
+
return rephrased_text
|
| 73 |
+
|
| 74 |
+
return parse_rephrased_text(response)
|
graphgen/models/generator/true_false_generator.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import Any
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import TF_GENERATION_PROMPT
|
| 6 |
-
from graphgen.utils import
|
| 7 |
|
| 8 |
|
| 9 |
class TrueFalseGenerator(BaseGenerator):
|
|
@@ -12,7 +12,7 @@ class TrueFalseGenerator(BaseGenerator):
|
|
| 12 |
self.num_of_questions = num_of_questions
|
| 13 |
|
| 14 |
@staticmethod
|
| 15 |
-
def parse_response(response: str) ->
|
| 16 |
"""
|
| 17 |
Parse true/false QA pairs from the LLM response.
|
| 18 |
Each QA pair contains a statement question and True/False answer.
|
|
@@ -21,14 +21,14 @@ class TrueFalseGenerator(BaseGenerator):
|
|
| 21 |
:return: Dictionary mapping question hash to question data, where each
|
| 22 |
value is a dict with "question", "options", and "answer" keys
|
| 23 |
"""
|
| 24 |
-
qa_pairs: dict[str,
|
| 25 |
|
| 26 |
# Extract all QA pair blocks
|
| 27 |
qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
|
| 28 |
|
| 29 |
if not qa_blocks:
|
| 30 |
logger.warning("No QA pairs found in response: %s", response)
|
| 31 |
-
return
|
| 32 |
|
| 33 |
for block in qa_blocks:
|
| 34 |
# Extract and clean question text
|
|
@@ -50,12 +50,12 @@ class TrueFalseGenerator(BaseGenerator):
|
|
| 50 |
logger.warning("Invalid answer '%s' in block: %s", answer, block)
|
| 51 |
continue
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
|
| 60 |
logger.debug("Successfully parsed TF question: %s", question[:50])
|
| 61 |
|
|
|
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import TF_GENERATION_PROMPT
|
| 6 |
+
from graphgen.utils import detect_main_language, logger
|
| 7 |
|
| 8 |
|
| 9 |
class TrueFalseGenerator(BaseGenerator):
|
|
|
|
| 12 |
self.num_of_questions = num_of_questions
|
| 13 |
|
| 14 |
@staticmethod
|
| 15 |
+
def parse_response(response: str) -> list[dict]:
|
| 16 |
"""
|
| 17 |
Parse true/false QA pairs from the LLM response.
|
| 18 |
Each QA pair contains a statement question and True/False answer.
|
|
|
|
| 21 |
:return: Dictionary mapping question hash to question data, where each
|
| 22 |
value is a dict with "question", "options", and "answer" keys
|
| 23 |
"""
|
| 24 |
+
qa_pairs: list[dict[str, str]] = []
|
| 25 |
|
| 26 |
# Extract all QA pair blocks
|
| 27 |
qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
|
| 28 |
|
| 29 |
if not qa_blocks:
|
| 30 |
logger.warning("No QA pairs found in response: %s", response)
|
| 31 |
+
return qa_pairs
|
| 32 |
|
| 33 |
for block in qa_blocks:
|
| 34 |
# Extract and clean question text
|
|
|
|
| 50 |
logger.warning("Invalid answer '%s' in block: %s", answer, block)
|
| 51 |
continue
|
| 52 |
|
| 53 |
+
qa_pairs.append(
|
| 54 |
+
{
|
| 55 |
+
"question": question,
|
| 56 |
+
"answer": answer, # "True" or "False"
|
| 57 |
+
}
|
| 58 |
+
)
|
| 59 |
|
| 60 |
logger.debug("Successfully parsed TF question: %s", question[:50])
|
| 61 |
|
graphgen/models/generator/vqa_generator.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
|
|
| 1 |
import re
|
| 2 |
from typing import Any
|
| 3 |
|
| 4 |
from graphgen.bases import BaseGenerator
|
| 5 |
from graphgen.templates import VQA_GENERATION_PROMPT
|
| 6 |
-
from graphgen.utils import
|
| 7 |
|
| 8 |
|
| 9 |
class VQAGenerator(BaseGenerator):
|
|
@@ -32,13 +33,13 @@ class VQAGenerator(BaseGenerator):
|
|
| 32 |
return prompt
|
| 33 |
|
| 34 |
@staticmethod
|
| 35 |
-
def parse_response(response: str) ->
|
| 36 |
"""
|
| 37 |
Parse the LLM response and return the generated QAs
|
| 38 |
:param response
|
| 39 |
:return: QA pairs
|
| 40 |
"""
|
| 41 |
-
qa_pairs =
|
| 42 |
pattern = r"<question>(.*?)</question>\s*<answer>(.*?)</answer>"
|
| 43 |
matches = re.findall(pattern, response, re.DOTALL)
|
| 44 |
|
|
@@ -48,10 +49,12 @@ class VQAGenerator(BaseGenerator):
|
|
| 48 |
answer = answer.strip().strip('"').strip("'")
|
| 49 |
logger.debug("Question: %s", question)
|
| 50 |
logger.debug("Answer: %s", answer)
|
| 51 |
-
qa_pairs
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
| 55 |
else:
|
| 56 |
logger.warning("Error parsing the response %s", response)
|
| 57 |
return qa_pairs
|
|
@@ -61,76 +64,58 @@ class VQAGenerator(BaseGenerator):
|
|
| 61 |
batch: tuple[
|
| 62 |
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
|
| 63 |
],
|
| 64 |
-
) -> dict
|
| 65 |
"""
|
| 66 |
Generate QAs based on a given batch.
|
| 67 |
:param batch
|
| 68 |
:return: QA pairs
|
| 69 |
"""
|
| 70 |
-
result = {}
|
| 71 |
prompt = self.build_prompt(batch)
|
| 72 |
response = await self.llm_client.generate_answer(prompt)
|
| 73 |
qa_pairs = self.parse_response(response) # generate one or more QA pairs
|
| 74 |
nodes, _ = batch
|
| 75 |
for node in nodes:
|
| 76 |
node_data = node[1]
|
| 77 |
-
if "
|
| 78 |
-
|
| 79 |
-
|
|
|
|
| 80 |
qa["img_path"] = img_path
|
| 81 |
-
|
| 82 |
-
return result
|
| 83 |
|
| 84 |
@staticmethod
|
| 85 |
-
def format_generation_results(
|
| 86 |
-
|
| 87 |
-
|
|
|
|
| 88 |
if output_data_format == "Alpaca":
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
"
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
{
|
| 120 |
-
"role": "user",
|
| 121 |
-
"content": [
|
| 122 |
-
{"text": v["question"], "image": v.get("img_path", "")}
|
| 123 |
-
],
|
| 124 |
-
},
|
| 125 |
-
{
|
| 126 |
-
"role": "assistant",
|
| 127 |
-
"content": [{"type": "text", "text": v["answer"]}],
|
| 128 |
-
},
|
| 129 |
-
]
|
| 130 |
-
}
|
| 131 |
-
for item in results
|
| 132 |
-
for k, v in item.items()
|
| 133 |
-
]
|
| 134 |
-
else:
|
| 135 |
-
raise ValueError(f"Unknown output data format: {output_data_format}")
|
| 136 |
-
return results
|
|
|
|
| 1 |
+
import json
|
| 2 |
import re
|
| 3 |
from typing import Any
|
| 4 |
|
| 5 |
from graphgen.bases import BaseGenerator
|
| 6 |
from graphgen.templates import VQA_GENERATION_PROMPT
|
| 7 |
+
from graphgen.utils import detect_main_language, logger
|
| 8 |
|
| 9 |
|
| 10 |
class VQAGenerator(BaseGenerator):
|
|
|
|
| 33 |
return prompt
|
| 34 |
|
| 35 |
@staticmethod
|
| 36 |
+
def parse_response(response: str) -> list[dict]:
|
| 37 |
"""
|
| 38 |
Parse the LLM response and return the generated QAs
|
| 39 |
:param response
|
| 40 |
:return: QA pairs
|
| 41 |
"""
|
| 42 |
+
qa_pairs = []
|
| 43 |
pattern = r"<question>(.*?)</question>\s*<answer>(.*?)</answer>"
|
| 44 |
matches = re.findall(pattern, response, re.DOTALL)
|
| 45 |
|
|
|
|
| 49 |
answer = answer.strip().strip('"').strip("'")
|
| 50 |
logger.debug("Question: %s", question)
|
| 51 |
logger.debug("Answer: %s", answer)
|
| 52 |
+
qa_pairs.append(
|
| 53 |
+
{
|
| 54 |
+
"question": question,
|
| 55 |
+
"answer": answer,
|
| 56 |
+
}
|
| 57 |
+
)
|
| 58 |
else:
|
| 59 |
logger.warning("Error parsing the response %s", response)
|
| 60 |
return qa_pairs
|
|
|
|
| 64 |
batch: tuple[
|
| 65 |
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
|
| 66 |
],
|
| 67 |
+
) -> list[dict]:
|
| 68 |
"""
|
| 69 |
Generate QAs based on a given batch.
|
| 70 |
:param batch
|
| 71 |
:return: QA pairs
|
| 72 |
"""
|
|
|
|
| 73 |
prompt = self.build_prompt(batch)
|
| 74 |
response = await self.llm_client.generate_answer(prompt)
|
| 75 |
qa_pairs = self.parse_response(response) # generate one or more QA pairs
|
| 76 |
nodes, _ = batch
|
| 77 |
for node in nodes:
|
| 78 |
node_data = node[1]
|
| 79 |
+
if "metadata" in node_data and node_data["metadata"]:
|
| 80 |
+
metadata = json.loads(node_data["metadata"])["metadata"]
|
| 81 |
+
img_path = metadata.get("path", "")
|
| 82 |
+
for qa in qa_pairs:
|
| 83 |
qa["img_path"] = img_path
|
| 84 |
+
return qa_pairs
|
|
|
|
| 85 |
|
| 86 |
@staticmethod
|
| 87 |
+
def format_generation_results(result: dict, output_data_format: str) -> dict:
|
| 88 |
+
question = result.get("question", "")
|
| 89 |
+
answer = result.get("answer", "")
|
| 90 |
+
img_path = result.get("img_path", "")
|
| 91 |
if output_data_format == "Alpaca":
|
| 92 |
+
return {
|
| 93 |
+
"instruction": question,
|
| 94 |
+
"input": "",
|
| 95 |
+
"output": answer,
|
| 96 |
+
"image": img_path,
|
| 97 |
+
}
|
| 98 |
+
if output_data_format == "Sharegpt":
|
| 99 |
+
return {
|
| 100 |
+
"conversations": [
|
| 101 |
+
{
|
| 102 |
+
"from": "human",
|
| 103 |
+
"value": [{"text": question, "image": img_path}],
|
| 104 |
+
},
|
| 105 |
+
{"from": "gpt", "value": [{"text": answer}]},
|
| 106 |
+
]
|
| 107 |
+
}
|
| 108 |
+
if output_data_format == "ChatML":
|
| 109 |
+
return {
|
| 110 |
+
"messages": [
|
| 111 |
+
{
|
| 112 |
+
"role": "user",
|
| 113 |
+
"content": [{"text": question, "image": img_path}],
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"role": "assistant",
|
| 117 |
+
"content": [{"type": "text", "text": answer}],
|
| 118 |
+
},
|
| 119 |
+
]
|
| 120 |
+
}
|
| 121 |
+
raise ValueError(f"Unknown output data format: {output_data_format}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/models/kg_builder/light_rag_kg_builder.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import re
|
| 2 |
from collections import Counter, defaultdict
|
| 3 |
from typing import Dict, List, Tuple
|
|
@@ -130,15 +131,25 @@ class LightRAGKGBuilder(BaseKGBuilder):
|
|
| 130 |
set([dp["source_id"] for dp in node_data] + source_ids)
|
| 131 |
)
|
| 132 |
|
| 133 |
-
|
| 134 |
"entity_type": entity_type,
|
| 135 |
"entity_name": entity_name,
|
| 136 |
"description": description,
|
| 137 |
"source_id": source_id,
|
| 138 |
"length": self.tokenizer.count_tokens(description),
|
| 139 |
}
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
async def merge_edges(
|
| 144 |
self,
|
|
|
|
| 1 |
+
import json
|
| 2 |
import re
|
| 3 |
from collections import Counter, defaultdict
|
| 4 |
from typing import Dict, List, Tuple
|
|
|
|
| 131 |
set([dp["source_id"] for dp in node_data] + source_ids)
|
| 132 |
)
|
| 133 |
|
| 134 |
+
node_data_dict = {
|
| 135 |
"entity_type": entity_type,
|
| 136 |
"entity_name": entity_name,
|
| 137 |
"description": description,
|
| 138 |
"source_id": source_id,
|
| 139 |
"length": self.tokenizer.count_tokens(description),
|
| 140 |
}
|
| 141 |
+
|
| 142 |
+
if entity_type in ("IMAGE", "TABLE", "FORMULA"):
|
| 143 |
+
metadata = next(
|
| 144 |
+
(dp["metadata"] for dp in node_data if dp.get("metadata")), None
|
| 145 |
+
)
|
| 146 |
+
if metadata:
|
| 147 |
+
node_data_dict["metadata"] = json.dumps(
|
| 148 |
+
metadata, ensure_ascii=False, default=str
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
kg_instance.upsert_node(entity_name, node_data=node_data_dict)
|
| 152 |
+
return node_data_dict
|
| 153 |
|
| 154 |
async def merge_edges(
|
| 155 |
self,
|
graphgen/models/kg_builder/mm_kg_builder.py
CHANGED
|
@@ -70,6 +70,8 @@ class MMKGBuilder(LightRAGKGBuilder):
|
|
| 70 |
|
| 71 |
entity = await handle_single_entity_extraction(attributes, chunk_id)
|
| 72 |
if entity is not None:
|
|
|
|
|
|
|
| 73 |
nodes[entity["entity_name"]].append(entity)
|
| 74 |
continue
|
| 75 |
|
|
|
|
| 70 |
|
| 71 |
entity = await handle_single_entity_extraction(attributes, chunk_id)
|
| 72 |
if entity is not None:
|
| 73 |
+
if entity["entity_type"] == "IMAGE":
|
| 74 |
+
entity["metadata"] = chunk.metadata
|
| 75 |
nodes[entity["entity_name"]].append(entity)
|
| 76 |
continue
|
| 77 |
|
graphgen/models/reader/csv_reader.py
CHANGED
|
@@ -22,7 +22,7 @@ class CSVReader(BaseReader):
|
|
| 22 |
:return: Ray Dataset containing validated and filtered data.
|
| 23 |
"""
|
| 24 |
|
| 25 |
-
ds = ray.data.read_csv(input_path)
|
| 26 |
ds = ds.map_batches(self._validate_batch, batch_format="pandas")
|
| 27 |
ds = ds.filter(self._should_keep_item)
|
| 28 |
return ds
|
|
|
|
| 22 |
:return: Ray Dataset containing validated and filtered data.
|
| 23 |
"""
|
| 24 |
|
| 25 |
+
ds = ray.data.read_csv(input_path, include_paths=True)
|
| 26 |
ds = ds.map_batches(self._validate_batch, batch_format="pandas")
|
| 27 |
ds = ds.filter(self._should_keep_item)
|
| 28 |
return ds
|
graphgen/models/reader/json_reader.py
CHANGED
|
@@ -34,10 +34,13 @@ class JSONReader(BaseReader):
|
|
| 34 |
with open(file, "r", encoding="utf-8") as f:
|
| 35 |
data = json.load(f)
|
| 36 |
data = self._unify_schema(data)
|
|
|
|
|
|
|
|
|
|
| 37 |
file_ds: ray.data.Dataset = ray.data.from_items(data)
|
| 38 |
ds = ds.union(file_ds) # type: ignore
|
| 39 |
else:
|
| 40 |
-
ds = ray.data.read_json(input_path)
|
| 41 |
ds = ds.map_batches(self._validate_batch, batch_format="pandas")
|
| 42 |
ds = ds.filter(self._should_keep_item)
|
| 43 |
return ds
|
|
|
|
| 34 |
with open(file, "r", encoding="utf-8") as f:
|
| 35 |
data = json.load(f)
|
| 36 |
data = self._unify_schema(data)
|
| 37 |
+
# add path
|
| 38 |
+
for item in data:
|
| 39 |
+
item["path"] = file
|
| 40 |
file_ds: ray.data.Dataset = ray.data.from_items(data)
|
| 41 |
ds = ds.union(file_ds) # type: ignore
|
| 42 |
else:
|
| 43 |
+
ds = ray.data.read_json(input_path, include_paths=True)
|
| 44 |
ds = ds.map_batches(self._validate_batch, batch_format="pandas")
|
| 45 |
ds = ds.filter(self._should_keep_item)
|
| 46 |
return ds
|
graphgen/models/reader/parquet_reader.py
CHANGED
|
@@ -24,7 +24,7 @@ class ParquetReader(BaseReader):
|
|
| 24 |
if not ray.is_initialized():
|
| 25 |
ray.init()
|
| 26 |
|
| 27 |
-
ds = ray.data.read_parquet(input_path)
|
| 28 |
ds = ds.map_batches(self._validate_batch, batch_format="pandas")
|
| 29 |
ds = ds.filter(self._should_keep_item)
|
| 30 |
return ds
|
|
|
|
| 24 |
if not ray.is_initialized():
|
| 25 |
ray.init()
|
| 26 |
|
| 27 |
+
ds = ray.data.read_parquet(input_path, include_paths=True)
|
| 28 |
ds = ds.map_batches(self._validate_batch, batch_format="pandas")
|
| 29 |
ds = ds.filter(self._should_keep_item)
|
| 30 |
return ds
|
graphgen/models/reader/rdf_reader.py
CHANGED
|
@@ -118,7 +118,7 @@ class RDFReader(BaseReader):
|
|
| 118 |
"id": str(subj),
|
| 119 |
self.text_column: text,
|
| 120 |
"properties": props,
|
| 121 |
-
"
|
| 122 |
}
|
| 123 |
docs.append(doc)
|
| 124 |
|
|
|
|
| 118 |
"id": str(subj),
|
| 119 |
self.text_column: text,
|
| 120 |
"properties": props,
|
| 121 |
+
"path": str(file_path),
|
| 122 |
}
|
| 123 |
docs.append(doc)
|
| 124 |
|
graphgen/models/reader/txt_reader.py
CHANGED
|
@@ -18,13 +18,14 @@ class TXTReader(BaseReader):
|
|
| 18 |
"""
|
| 19 |
docs_ds = ray.data.read_binary_files(
|
| 20 |
input_path,
|
| 21 |
-
include_paths=
|
| 22 |
)
|
| 23 |
|
| 24 |
docs_ds = docs_ds.map(
|
| 25 |
lambda row: {
|
| 26 |
"type": "text",
|
| 27 |
self.text_column: row["bytes"].decode("utf-8"),
|
|
|
|
| 28 |
}
|
| 29 |
)
|
| 30 |
|
|
|
|
| 18 |
"""
|
| 19 |
docs_ds = ray.data.read_binary_files(
|
| 20 |
input_path,
|
| 21 |
+
include_paths=True,
|
| 22 |
)
|
| 23 |
|
| 24 |
docs_ds = docs_ds.map(
|
| 25 |
lambda row: {
|
| 26 |
"type": "text",
|
| 27 |
self.text_column: row["bytes"].decode("utf-8"),
|
| 28 |
+
"path": row["path"],
|
| 29 |
}
|
| 30 |
)
|
| 31 |
|
graphgen/models/storage/__init__.py
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
from graphgen.models.storage.graph.kuzu_storage import KuzuStorage
|
| 2 |
-
from graphgen.models.storage.graph.networkx_storage import NetworkXStorage
|
| 3 |
-
from graphgen.models.storage.kv.json_storage import JsonKVStorage
|
| 4 |
-
from graphgen.models.storage.kv.rocksdb_storage import RocksDBKVStorage
|
| 5 |
-
|
| 6 |
-
from .rocksdb_cache import RocksDBCache
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/models/storage/rocksdb_cache.py
DELETED
|
@@ -1,43 +0,0 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
-
from typing import Any, Iterator, Optional
|
| 3 |
-
|
| 4 |
-
# rocksdict is a lightweight C wrapper around RocksDB for Python, pylint may not recognize it
|
| 5 |
-
# pylint: disable=no-name-in-module
|
| 6 |
-
from rocksdict import Rdict
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class RocksDBCache:
|
| 10 |
-
def __init__(self, cache_dir: str):
|
| 11 |
-
self.db_path = Path(cache_dir)
|
| 12 |
-
self.db = Rdict(str(self.db_path))
|
| 13 |
-
|
| 14 |
-
def get(self, key: str) -> Optional[Any]:
|
| 15 |
-
return self.db.get(key)
|
| 16 |
-
|
| 17 |
-
def set(self, key: str, value: Any):
|
| 18 |
-
self.db[key] = value
|
| 19 |
-
|
| 20 |
-
def delete(self, key: str):
|
| 21 |
-
try:
|
| 22 |
-
del self.db[key]
|
| 23 |
-
except KeyError:
|
| 24 |
-
# If the key does not exist, do nothing (deletion is idempotent for caches)
|
| 25 |
-
pass
|
| 26 |
-
|
| 27 |
-
def close(self):
|
| 28 |
-
if hasattr(self, "db") and self.db is not None:
|
| 29 |
-
self.db.close()
|
| 30 |
-
self.db = None
|
| 31 |
-
|
| 32 |
-
def __del__(self):
|
| 33 |
-
# Ensure the database is closed when the object is destroyed
|
| 34 |
-
self.close()
|
| 35 |
-
|
| 36 |
-
def __enter__(self):
|
| 37 |
-
return self
|
| 38 |
-
|
| 39 |
-
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 40 |
-
self.close()
|
| 41 |
-
|
| 42 |
-
def __iter__(self) -> Iterator[str]:
|
| 43 |
-
return iter(self.db.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/models/vis/__init__.py
DELETED
|
File without changes
|
graphgen/models/vis/community_visualizer.py
DELETED
|
@@ -1,48 +0,0 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
from typing import Dict
|
| 3 |
-
|
| 4 |
-
import matplotlib.pyplot as plt
|
| 5 |
-
import networkx as nx
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
@dataclass
|
| 9 |
-
class Visualizer:
|
| 10 |
-
"""
|
| 11 |
-
Class for visualizing graphs using NetworkX and Matplotlib.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
graph: nx.Graph = None
|
| 15 |
-
communities: Dict[str, int] = None
|
| 16 |
-
layout: str = "spring"
|
| 17 |
-
max_nodes: int = 1000
|
| 18 |
-
node_size: int = 10
|
| 19 |
-
alpha: float = 0.6
|
| 20 |
-
|
| 21 |
-
def visualize(self, save_path: str = None):
|
| 22 |
-
n = self.graph.number_of_nodes()
|
| 23 |
-
if self.layout == "spring":
|
| 24 |
-
k = max(0.1, 1.0 / (n**0.5))
|
| 25 |
-
pos = nx.spring_layout(self.graph, k=k, seed=42)
|
| 26 |
-
else:
|
| 27 |
-
raise ValueError(f"Unknown layout: {self.layout}")
|
| 28 |
-
|
| 29 |
-
plt.figure(figsize=(10, 10))
|
| 30 |
-
|
| 31 |
-
node_colors = [self.communities.get(node, 0) for node in self.graph.nodes()]
|
| 32 |
-
|
| 33 |
-
nx.draw_networkx_nodes(
|
| 34 |
-
self.graph,
|
| 35 |
-
pos,
|
| 36 |
-
node_size=self.node_size,
|
| 37 |
-
node_color=node_colors,
|
| 38 |
-
cmap=plt.cm.tab20,
|
| 39 |
-
alpha=self.alpha,
|
| 40 |
-
)
|
| 41 |
-
nx.draw_networkx_edges(self.graph, pos, alpha=0.3, width=0.2)
|
| 42 |
-
plt.axis("off")
|
| 43 |
-
|
| 44 |
-
if save_path:
|
| 45 |
-
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
| 46 |
-
print("Saved to", save_path)
|
| 47 |
-
else:
|
| 48 |
-
plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/operators/build_kg/build_kg_service.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
-
from typing import
|
| 2 |
-
|
| 3 |
-
import pandas as pd
|
| 4 |
|
| 5 |
from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator
|
| 6 |
from graphgen.bases.datatypes import Chunk
|
|
@@ -13,9 +11,15 @@ from .build_text_kg import build_text_kg
|
|
| 13 |
|
| 14 |
class BuildKGService(BaseOperator):
|
| 15 |
def __init__(
|
| 16 |
-
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
):
|
| 18 |
-
super().__init__(
|
|
|
|
|
|
|
| 19 |
self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
|
| 20 |
self.graph_storage: BaseGraphStorage = init_storage(
|
| 21 |
backend=graph_backend, working_dir=working_dir, namespace="graph"
|
|
@@ -23,21 +27,15 @@ class BuildKGService(BaseOperator):
|
|
| 23 |
self.build_kwargs = build_kwargs
|
| 24 |
self.max_loop: int = int(self.build_kwargs.get("max_loop", 3))
|
| 25 |
|
| 26 |
-
def process(self, batch:
|
| 27 |
-
docs = batch.to_dict(orient="records")
|
| 28 |
-
docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs]
|
| 29 |
-
|
| 30 |
-
# consume the chunks and build kg
|
| 31 |
-
nodes, edges = self.build_kg(docs)
|
| 32 |
-
return pd.DataFrame(
|
| 33 |
-
[{"node": node, "edge": []} for node in nodes]
|
| 34 |
-
+ [{"node": [], "edge": edge} for edge in edges]
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
def build_kg(self, chunks: List[Chunk]) -> tuple:
|
| 38 |
"""
|
| 39 |
Build knowledge graph (KG) and merge into kg_instance
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
"""
|
|
|
|
| 41 |
text_chunks = [chunk for chunk in chunks if chunk.type == "text"]
|
| 42 |
mm_chunks = [
|
| 43 |
chunk
|
|
@@ -75,4 +73,34 @@ class BuildKGService(BaseOperator):
|
|
| 75 |
self.graph_storage.index_done_callback()
|
| 76 |
logger.info("Knowledge graph building completed.")
|
| 77 |
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator
|
| 4 |
from graphgen.bases.datatypes import Chunk
|
|
|
|
| 11 |
|
| 12 |
class BuildKGService(BaseOperator):
|
| 13 |
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
working_dir: str = "cache",
|
| 16 |
+
kv_backend: str = "rocksdb",
|
| 17 |
+
graph_backend: str = "kuzu",
|
| 18 |
+
**build_kwargs
|
| 19 |
):
|
| 20 |
+
super().__init__(
|
| 21 |
+
working_dir=working_dir, kv_backend=kv_backend, op_name="build_kg"
|
| 22 |
+
)
|
| 23 |
self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
|
| 24 |
self.graph_storage: BaseGraphStorage = init_storage(
|
| 25 |
backend=graph_backend, working_dir=working_dir, namespace="graph"
|
|
|
|
| 27 |
self.build_kwargs = build_kwargs
|
| 28 |
self.max_loop: int = int(self.build_kwargs.get("max_loop", 3))
|
| 29 |
|
| 30 |
+
def process(self, batch: list) -> Tuple[list, dict]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
"""
|
| 32 |
Build knowledge graph (KG) and merge into kg_instance
|
| 33 |
+
:return: A tuple of (results, meta_updates)
|
| 34 |
+
results: A list of dicts containing nodes and edges added to the KG. Each dict has the structure:
|
| 35 |
+
{"_trace_id": str, "node": dict, "edge": dict}
|
| 36 |
+
meta_updates: A dict mapping source IDs to lists of trace IDs for nodes and edges added.
|
| 37 |
"""
|
| 38 |
+
chunks = [Chunk.from_dict(doc["_trace_id"], doc) for doc in batch]
|
| 39 |
text_chunks = [chunk for chunk in chunks if chunk.type == "text"]
|
| 40 |
mm_chunks = [
|
| 41 |
chunk
|
|
|
|
| 73 |
self.graph_storage.index_done_callback()
|
| 74 |
logger.info("Knowledge graph building completed.")
|
| 75 |
|
| 76 |
+
meta_updates = {}
|
| 77 |
+
results = []
|
| 78 |
+
for node in nodes:
|
| 79 |
+
if not node:
|
| 80 |
+
continue
|
| 81 |
+
trace_id = node["entity_name"]
|
| 82 |
+
results.append(
|
| 83 |
+
{
|
| 84 |
+
"_trace_id": trace_id,
|
| 85 |
+
"node": node,
|
| 86 |
+
"edge": {},
|
| 87 |
+
}
|
| 88 |
+
)
|
| 89 |
+
source_ids = node.get("source_id", "").split("<SEP>")
|
| 90 |
+
for source_id in source_ids:
|
| 91 |
+
meta_updates.setdefault(source_id, []).append(trace_id)
|
| 92 |
+
for edge in edges:
|
| 93 |
+
if not edge:
|
| 94 |
+
continue
|
| 95 |
+
trace_id = frozenset((edge["src_id"], edge["tgt_id"]))
|
| 96 |
+
results.append(
|
| 97 |
+
{
|
| 98 |
+
"_trace_id": str(trace_id),
|
| 99 |
+
"node": {},
|
| 100 |
+
"edge": edge,
|
| 101 |
+
}
|
| 102 |
+
)
|
| 103 |
+
source_ids = edge.get("source_id", "").split("<SEP>")
|
| 104 |
+
for source_id in source_ids:
|
| 105 |
+
meta_updates.setdefault(source_id, []).append(str(trace_id))
|
| 106 |
+
return results, meta_updates
|
graphgen/operators/build_kg/build_text_kg.py
CHANGED
|
@@ -30,6 +30,7 @@ def build_text_kg(
|
|
| 30 |
desc="[2/4]Extracting entities and relationships from chunks",
|
| 31 |
unit="chunk",
|
| 32 |
)
|
|
|
|
| 33 |
|
| 34 |
nodes = defaultdict(list)
|
| 35 |
edges = defaultdict(list)
|
|
|
|
| 30 |
desc="[2/4]Extracting entities and relationships from chunks",
|
| 31 |
unit="chunk",
|
| 32 |
)
|
| 33 |
+
results = [res for res in results if res]
|
| 34 |
|
| 35 |
nodes = defaultdict(list)
|
| 36 |
edges = defaultdict(list)
|
graphgen/operators/chunk/chunk_service.py
CHANGED
|
@@ -1,17 +1,14 @@
|
|
| 1 |
import os
|
| 2 |
from functools import lru_cache
|
| 3 |
-
from typing import Union
|
| 4 |
-
|
| 5 |
-
import pandas as pd
|
| 6 |
|
| 7 |
from graphgen.bases import BaseOperator
|
| 8 |
-
from graphgen.common import init_storage
|
| 9 |
from graphgen.models import (
|
| 10 |
ChineseRecursiveTextSplitter,
|
| 11 |
RecursiveCharacterSplitter,
|
| 12 |
Tokenizer,
|
| 13 |
)
|
| 14 |
-
from graphgen.utils import
|
| 15 |
|
| 16 |
_MAPPING = {
|
| 17 |
"en": RecursiveCharacterSplitter,
|
|
@@ -45,26 +42,25 @@ class ChunkService(BaseOperator):
|
|
| 45 |
def __init__(
|
| 46 |
self, working_dir: str = "cache", kv_backend: str = "rocksdb", **chunk_kwargs
|
| 47 |
):
|
| 48 |
-
super().__init__(
|
|
|
|
|
|
|
| 49 |
tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
|
| 50 |
self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model)
|
| 51 |
-
self.chunk_storage = init_storage(
|
| 52 |
-
backend=kv_backend,
|
| 53 |
-
working_dir=working_dir,
|
| 54 |
-
namespace="chunk",
|
| 55 |
-
)
|
| 56 |
self.chunk_kwargs = chunk_kwargs
|
| 57 |
|
| 58 |
-
def process(self, batch:
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
| 66 |
doc_type = doc.get("type")
|
| 67 |
-
|
| 68 |
if doc_type == "text":
|
| 69 |
doc_language = detect_main_language(doc["content"])
|
| 70 |
text_chunks = split_chunks(
|
|
@@ -72,32 +68,30 @@ class ChunkService(BaseOperator):
|
|
| 72 |
language=doc_language,
|
| 73 |
**self.chunk_kwargs,
|
| 74 |
)
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
),
|
| 82 |
-
"content": chunk_text,
|
| 83 |
-
"type": "text",
|
| 84 |
-
"_doc_id": doc_id,
|
| 85 |
-
"length": len(self.tokenizer_instance.encode(chunk_text))
|
| 86 |
if self.tokenizer_instance
|
| 87 |
-
else len(
|
| 88 |
"language": doc_language,
|
| 89 |
-
}
|
| 90 |
-
|
| 91 |
-
]
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
| 93 |
else:
|
| 94 |
# other types of documents(images, sequences) are not chunked
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
return
|
|
|
|
| 1 |
import os
|
| 2 |
from functools import lru_cache
|
| 3 |
+
from typing import Tuple, Union
|
|
|
|
|
|
|
| 4 |
|
| 5 |
from graphgen.bases import BaseOperator
|
|
|
|
| 6 |
from graphgen.models import (
|
| 7 |
ChineseRecursiveTextSplitter,
|
| 8 |
RecursiveCharacterSplitter,
|
| 9 |
Tokenizer,
|
| 10 |
)
|
| 11 |
+
from graphgen.utils import detect_main_language
|
| 12 |
|
| 13 |
_MAPPING = {
|
| 14 |
"en": RecursiveCharacterSplitter,
|
|
|
|
| 42 |
def __init__(
    self, working_dir: str = "cache", kv_backend: str = "rocksdb", **chunk_kwargs
):
    """Set up the chunk operator: base storage wiring, tokenizer, splitter options.

    :param working_dir: cache directory handed to the base operator.
    :param kv_backend: key-value backend name (e.g. "rocksdb").
    :param chunk_kwargs: options forwarded verbatim to the text splitter.
    """
    super().__init__(
        working_dir=working_dir, kv_backend=kv_backend, op_name="chunk"
    )
    # Tokenizer is configurable via env var; defaults to tiktoken's cl100k_base.
    model_name = os.getenv("TOKENIZER_MODEL", "cl100k_base")
    self.tokenizer_instance: Tokenizer = Tokenizer(model_name=model_name)
    self.chunk_kwargs = chunk_kwargs
|
| 51 |
|
| 52 |
+
def process(self, batch: list) -> Tuple[list, dict]:
|
| 53 |
+
"""
|
| 54 |
+
Chunk the documents in the batch.
|
| 55 |
+
:return: A tuple of (results, meta_updates)
|
| 56 |
+
results: A list of chunked documents. Each chunked document is a dict with the structure:
|
| 57 |
+
{"_trace_id": str, "content": str, "type": str, "metadata": {"length": int, "language": str, ...}
|
| 58 |
+
meta_updates: A dict mapping source document IDs to lists of trace IDs for the chunked documents.
|
| 59 |
+
"""
|
| 60 |
+
results = []
|
| 61 |
+
meta_updates = {}
|
| 62 |
+
for doc in batch:
|
| 63 |
doc_type = doc.get("type")
|
|
|
|
| 64 |
if doc_type == "text":
|
| 65 |
doc_language = detect_main_language(doc["content"])
|
| 66 |
text_chunks = split_chunks(
|
|
|
|
| 68 |
language=doc_language,
|
| 69 |
**self.chunk_kwargs,
|
| 70 |
)
|
| 71 |
+
for text_chunk in text_chunks:
|
| 72 |
+
chunk = {
|
| 73 |
+
"content": text_chunk,
|
| 74 |
+
"type": "text",
|
| 75 |
+
"metadata": {
|
| 76 |
+
"length": len(self.tokenizer_instance.encode(text_chunk))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
if self.tokenizer_instance
|
| 78 |
+
else len(text_chunk),
|
| 79 |
"language": doc_language,
|
| 80 |
+
},
|
| 81 |
+
}
|
| 82 |
+
chunk["_trace_id"] = self.get_trace_id(chunk)
|
| 83 |
+
results.append(chunk)
|
| 84 |
+
meta_updates.setdefault(doc["_trace_id"], []).append(
|
| 85 |
+
chunk["_trace_id"]
|
| 86 |
+
)
|
| 87 |
else:
|
| 88 |
# other types of documents(images, sequences) are not chunked
|
| 89 |
+
data = doc.copy()
|
| 90 |
+
input_trace_id = data.pop("_trace_id")
|
| 91 |
+
content = data.pop("content") if "content" in data else ""
|
| 92 |
+
doc_type = data.pop("type")
|
| 93 |
+
chunk = {"content": content, "type": doc_type, "metadata": data}
|
| 94 |
+
chunk["_trace_id"] = self.get_trace_id(chunk)
|
| 95 |
+
results.append(chunk)
|
| 96 |
+
meta_updates.setdefault(input_trace_id, []).append(chunk["_trace_id"])
|
| 97 |
+
return results, meta_updates
|
graphgen/operators/evaluate/evaluate_kg.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict
|
| 2 |
+
|
| 3 |
+
from graphgen.bases import BaseGraphStorage
|
| 4 |
+
from graphgen.utils import logger
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def evaluate_kg(
    kg_evaluators: Dict[str, Any],
    kg_instance: BaseGraphStorage,
) -> Dict[str, Any]:
    """Run every configured KG evaluator against a knowledge-graph instance.

    :param kg_evaluators: mapping of metric name -> evaluator object exposing
        ``evaluate(kg_instance)``.
    :param kg_instance: graph storage backend to evaluate.
    :return: mapping of metric name -> that evaluator's result.
    """
    results: Dict[str, Any] = {}
    for key, kg_evaluator in kg_evaluators.items():
        results[key] = kg_evaluator.evaluate(kg_instance)
        # Lazy %-style args: the message is only formatted when INFO is enabled.
        logger.info("KG Evaluation result for %s: %s", key, results[key])
    return results
|
graphgen/operators/evaluate/evaluate_qa.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
from graphgen.bases import QAPair
|
| 4 |
+
from graphgen.utils import run_concurrent
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def transform_to_qa_format(
    items: list[dict], format_hint: str = "auto"
) -> list[dict[str, str]]:
    """Normalize ChatML / Alpaca / Sharegpt records into QA dicts.

    Each output dict has ``question`` and ``answer`` keys; when the question
    text embeds an "Options:" section, it is split out into an ``options``
    mapping (label -> text).

    :param items: raw records in one of the supported formats.
    :param format_hint: "ChatML", "Alpaca", "Sharegpt", or "auto" to detect
        per-item from the record's keys.
    :raises ValueError: when auto-detection fails for an item.
    """

    def _first(gen):
        # First matching element, or empty string when none exists.
        return next(gen, "")

    def _chatml(x):
        msgs = x.get("messages", [])
        return (
            _first(m["content"] for m in msgs if m.get("role") == "user"),
            _first(m["content"] for m in msgs if m.get("role") == "assistant"),
        )

    def _alpaca(x):
        instruction = x.get("instruction", "")
        if x.get("input"):
            question = f"{instruction}\n\n{x['input']}".strip()
        else:
            question = instruction
        return question, x.get("output", "")

    def _sharegpt(x):
        convs = x.get("conversations", [])
        return (
            _first(c["value"] for c in convs if c.get("from") == "human"),
            _first(c["value"] for c in convs if c.get("from") in ("gpt", "assistant")),
        )

    extractors = {"ChatML": _chatml, "Alpaca": _alpaca, "Sharegpt": _sharegpt}
    # Presence of a characteristic key identifies the record's format.
    auto_detect = {
        "messages": "ChatML",
        "conversations": "Sharegpt",
        "instruction": "Alpaca",
    }

    transformed = []
    for item in items:
        fmt = format_hint
        if fmt == "auto":
            fmt = next(
                (name for key, name in auto_detect.items() if key in item), None
            )
            if not fmt:
                raise ValueError(
                    "Could not auto-detect format. Please specify format_hint."
                )

        question, answer = extractors[fmt](item)
        options = None
        if "\nOptions:\n" in question:
            question, opt_part = question.split("\nOptions:\n", 1)
            options = {}
            for line in opt_part.strip().split("\n"):
                if "." in line:
                    label, text = line.split(".", 1)
                    options[label.strip()] = text.strip()

        entry = {"question": question.strip(), "answer": answer.strip()}
        if options:
            entry["options"] = options
        transformed.append(entry)

    return transformed
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def evaluate_qa(
    qa_evaluators: dict[str, Any], items: list[dict[str, Any]]
) -> dict[str, Any]:
    """Score QA items with every configured evaluator.

    Items are normalized to question/answer dicts, wrapped as ``QAPair``
    objects, then each evaluator is run over them concurrently.

    :return: mapping of metric name -> per-item evaluation results.
    """
    qa_pairs = [QAPair.from_dict(d) for d in transform_to_qa_format(items)]

    results: dict[str, Any] = {}
    for key, qa_evaluator in qa_evaluators.items():
        results[key] = run_concurrent(
            qa_evaluator.evaluate,
            qa_pairs,
            desc=f"Evaluating QA with {key}",
        )
    return results
|
graphgen/operators/evaluate/evaluate_service.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
-
from typing import
|
| 2 |
|
| 3 |
-
import
|
| 4 |
-
|
| 5 |
-
from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair
|
| 6 |
from graphgen.common import init_llm, init_storage
|
| 7 |
-
from graphgen.utils import logger
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class EvaluateService(BaseOperator):
|
|
@@ -15,167 +17,135 @@ class EvaluateService(BaseOperator):
|
|
| 15 |
|
| 16 |
def __init__(
|
| 17 |
self,
|
|
|
|
|
|
|
| 18 |
working_dir: str = "cache",
|
| 19 |
-
metrics: list[str] = None,
|
| 20 |
graph_backend: str = "kuzu",
|
| 21 |
kv_backend: str = "rocksdb",
|
| 22 |
**kwargs,
|
| 23 |
):
|
| 24 |
-
super().__init__(
|
|
|
|
|
|
|
| 25 |
self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
|
| 26 |
self.metrics = metrics or []
|
| 27 |
self.kwargs = kwargs
|
| 28 |
self.graph_storage = init_storage(
|
| 29 |
backend=graph_backend, working_dir=working_dir, namespace="graph"
|
| 30 |
)
|
| 31 |
-
self.chunk_storage = init_storage(
|
| 32 |
-
backend=kv_backend, working_dir=working_dir, namespace="chunk"
|
| 33 |
-
)
|
| 34 |
|
| 35 |
# Initialize evaluators
|
| 36 |
-
self.
|
| 37 |
-
self.
|
| 38 |
-
self.
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
self.kg_evaluators[metric] = StructureEvaluator(
|
| 85 |
-
graph_storage=self.graph_storage,
|
| 86 |
-
**self.kwargs.get("structure_params", {}),
|
| 87 |
-
)
|
| 88 |
-
else:
|
| 89 |
-
raise ValueError(f"Unknown QA metric: {metric}")
|
| 90 |
-
|
| 91 |
-
async def _process_single_qa(self, item: dict[str, Any]) -> dict[str, Any]:
|
| 92 |
-
try:
|
| 93 |
-
qa_pair = QAPair(
|
| 94 |
-
question=str(item.get("question", "")),
|
| 95 |
-
answer=str(item.get("answer", "")),
|
| 96 |
)
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
item[f"{metric}_{sub_metric}"] = float(sub_score)
|
| 110 |
-
else:
|
| 111 |
-
item[metric] = float(score)
|
| 112 |
-
except Exception as e:
|
| 113 |
-
logger.error("Error in %s evaluation: %s", metric, str(e))
|
| 114 |
-
item[metric] = None
|
| 115 |
-
return item
|
| 116 |
-
|
| 117 |
-
def _evaluate_qa(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
| 118 |
-
def transform_messages_format(items: list[dict]) -> list[dict]:
|
| 119 |
-
"""
|
| 120 |
-
Transform from [{'messages': [...]}, ...] to [{'question': '...', 'answer': '...'}, ...]
|
| 121 |
-
"""
|
| 122 |
-
transformed = []
|
| 123 |
-
for item in items:
|
| 124 |
-
messages = item.get("messages", [])
|
| 125 |
-
question = next(
|
| 126 |
-
(m["content"] for m in messages if m.get("role") == "user"), ""
|
| 127 |
-
)
|
| 128 |
-
answer = next(
|
| 129 |
-
(m["content"] for m in messages if m.get("role") == "assistant"), ""
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
transformed.append({"question": question, "answer": answer})
|
| 133 |
-
return transformed
|
| 134 |
-
|
| 135 |
-
if not items:
|
| 136 |
-
return []
|
| 137 |
-
|
| 138 |
-
if not self.qa_evaluators:
|
| 139 |
-
logger.warning("No QA evaluators initialized, skipping QA evaluation")
|
| 140 |
-
return []
|
| 141 |
-
|
| 142 |
-
items = transform_messages_format(items)
|
| 143 |
-
results = run_concurrent(
|
| 144 |
-
self._process_single_qa,
|
| 145 |
-
items,
|
| 146 |
-
desc="Evaluating QA items",
|
| 147 |
-
unit="item",
|
| 148 |
)
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
# No metrics specified
|
| 180 |
logger.warning("No metrics specified, returning empty DataFrame")
|
| 181 |
-
return
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
|
| 3 |
+
from graphgen.bases import BaseLLMWrapper, BaseOperator
|
|
|
|
|
|
|
| 4 |
from graphgen.common import init_llm, init_storage
|
| 5 |
+
from graphgen.utils import logger
|
| 6 |
+
|
| 7 |
+
from .evaluate_kg import evaluate_kg
|
| 8 |
+
from .evaluate_qa import evaluate_qa
|
| 9 |
+
from .evaluate_triple import evaluate_triple
|
| 10 |
|
| 11 |
|
| 12 |
class EvaluateService(BaseOperator):
|
|
|
|
| 17 |
|
| 18 |
def __init__(
    self,
    target: str,
    metrics: list[str],
    working_dir: str = "cache",
    graph_backend: str = "kuzu",
    kv_backend: str = "rocksdb",
    **kwargs,
):
    """Create an evaluation operator for one target.

    :param target: what to evaluate — "qa", "kg" or "triple".
    :param metrics: metric names understood by the chosen target
        (e.g. "length"/"mtld"/"reward_score"/"uni_score" for qa).
    :param working_dir: cache directory shared with the other operators.
    :param graph_backend: graph storage backend name.
    :param kv_backend: key-value storage backend name.
    :param kwargs: per-metric parameter dicts (e.g. ``mtld_params``) and,
        for the "triple" target, ``src_namespace``/``tgt_namespace``.
    """
    super().__init__(
        working_dir=working_dir, kv_backend=kv_backend, op_name="evaluate"
    )
    self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
    self.metrics = metrics or []
    self.kwargs = kwargs
    self.graph_storage = init_storage(
        backend=graph_backend, working_dir=working_dir, namespace="graph"
    )

    # Initialize evaluators
    self.target = target
    # src/tgt KV storages are only populated by the "triple" initializer.
    self.src_storage = None
    self.tgt_storage = None
    self.evaluators = {}
    self._init_evaluators(self.target, metrics)
|
| 43 |
+
|
| 44 |
+
def _init_evaluators(self, target: str, metrics: list[str]):
    """Validate the target and delegate to its metric initializer."""
    supported = {"qa", "kg", "triple"}
    if target not in supported:
        raise ValueError(f"Unknown evaluation target: {target}")

    # Dispatch by name: _init_qa_evaluators / _init_kg_evaluators / _init_triple_evaluators.
    initializer = getattr(self, f"_init_{target}_evaluators")
    initializer(metrics)
|
| 51 |
+
|
| 52 |
+
def _init_qa_evaluators(self, metrics: list[str]):
    """Register one QA evaluator per requested metric."""
    self.evaluators.update(
        (metric, self._create_qa_evaluator(metric)) for metric in metrics
    )
|
| 56 |
+
|
| 57 |
+
def _create_qa_evaluator(self, metric: str):
    """Factory method for QA evaluator instances.

    Imports stay inside the factories so heavyweight evaluator modules are
    only loaded for metrics that are actually requested.
    """

    def _length():
        from graphgen.models import LengthEvaluator

        return LengthEvaluator()

    def _mtld():
        from graphgen.models import MTLDEvaluator

        return MTLDEvaluator(**self.kwargs.get("mtld_params", {}))

    def _reward():
        from graphgen.models import RewardEvaluator

        return RewardEvaluator(**self.kwargs.get("reward_params", {}))

    def _uni():
        from graphgen.models import UniEvaluator

        return UniEvaluator(**self.kwargs.get("uni_params", {}))

    factories = {
        "length": _length,
        "mtld": _mtld,
        "reward_score": _reward,
        "uni_score": _uni,
    }
    factory = factories.get(metric)
    if factory is None:
        raise ValueError(f"Unknown QA metric: {metric}")
    return factory()
|
| 76 |
+
|
| 77 |
+
def _init_kg_evaluators(self, metrics: list[str]):
    """Register KG evaluators; only the "structure" metric is supported."""
    for metric in metrics:
        if metric != "structure":
            raise ValueError(f"Unknown KG metric: {metric}")
        # Local import keeps the evaluator module out of the default path.
        from graphgen.models import StructureEvaluator

        params = self.kwargs.get("structure_params", {})
        self.evaluators[metric] = StructureEvaluator(**params)
|
| 87 |
+
|
| 88 |
+
def _init_triple_evaluators(self, metrics: list[str]):
    """Register triple evaluators and open the source/target KV storages.

    Triple evaluation compares two KV namespaces (source chunks vs.
    extracted units), so both storages are opened here.
    """

    def _open_kv(namespace_key: str):
        # Backend and working dir are shared; only the namespace differs.
        return init_storage(
            backend=self.kv_backend,
            working_dir=self.working_dir,
            namespace=self.kwargs[namespace_key],
        )

    self.src_storage = _open_kv("src_namespace")
    self.tgt_storage = _open_kv("tgt_namespace")

    for metric in metrics:
        if metric != "accuracy":
            raise ValueError(f"Unknown Triple metric: {metric}")
        from graphgen.models import AccuracyEvaluator

        self.evaluators[metric] = AccuracyEvaluator(llm_client=self.llm_client)
|
| 107 |
+
|
| 108 |
+
def process(self, batch: list) -> Tuple[list, dict]:
    """Run the evaluation for the configured target and return traced results.

    :param batch: input records; only consumed by the "qa" target, where each
        item is a generated QA record carrying a ``_trace_id``.
    :return: (results, meta_updates) — results is a list of dicts each with a
        ``_trace_id``; meta_updates maps input trace IDs to output trace IDs
        (empty for the kg/triple targets, which produce one aggregate record).
    """
    final_results = []
    meta_updates = {}

    # 1. QA Evaluation (per item)
    if self.target == "qa" and self.evaluators:
        results: dict = evaluate_qa(self.evaluators, batch)
        for i, item in enumerate(batch):
            # Merge per-metric scores for this item into one dict.
            metrics = {}
            for _, scores in results.items():
                metrics.update(scores[i])
            item.update({"metrics": metrics})
            # Re-derive the trace id: the item content changed (metrics added).
            input_trace_id = item.pop("_trace_id")
            item["_trace_id"] = self.get_trace_id(item)
            final_results.append(item)
            meta_updates.setdefault(input_trace_id, []).append(item["_trace_id"])

        return final_results, meta_updates

    # 2. KG evaluation
    if self.target == "kg" and self.evaluators:
        results = evaluate_kg(
            self.evaluators,
            self.graph_storage,
        )
        if not results:
            logger.warning("No KG evaluation results, returning empty DataFrame")
            return [], {}
        results["_trace_id"] = self.get_trace_id(results)
        final_results.append(results)
        return final_results, {}

    # 3. Triple evaluation
    if self.target == "triple" and self.evaluators:
        results = evaluate_triple(
            self.evaluators, self.src_storage, self.tgt_storage
        )
        # Fixed trace id: there is exactly one aggregate triple-eval record.
        results["_trace_id"] = "evaluate-triple-result"
        final_results.append(results)
        return final_results, {}

    # No metrics specified
    logger.warning("No metrics specified, returning empty DataFrame")
    return [], {}
|
graphgen/operators/evaluate/evaluate_triple.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
from graphgen.bases import BaseKVStorage
|
| 4 |
+
from graphgen.utils import logger, run_concurrent
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def evaluate_triple(
    triple_evaluators: dict[str, Any],
    src_storage: BaseKVStorage,
    tgt_storage: BaseKVStorage,
) -> dict[str, Any]:
    """Evaluate extracted triples (nodes/edges) against their source chunks.

    Reads the forward mapping (chunk id -> extracted unit ids) from
    ``tgt_storage``, rebuilds ``(chunk_content, nodes, edges)`` tasks, and
    runs every configured evaluator over them concurrently.

    :param triple_evaluators: mapping of metric name -> evaluator exposing
        ``evaluate(task)``.
    :param src_storage: KV storage holding the original chunk contents.
    :param tgt_storage: KV storage holding extracted units and the meta record.
    :return: mapping of metric name -> that evaluator's results.
    """
    # ``or {}`` guards a missing meta record (KV stores conventionally return
    # None for an absent id — TODO confirm for all backends); previously this
    # would crash on ``None.items()``.
    forward_meta = tgt_storage.get_by_id("_meta_forward") or {}

    tasks = []
    for chunk_id, unit_ids in forward_meta.items():
        chunk_content = str(src_storage.get_by_id(chunk_id))

        nodes = []
        edges = []
        for unit_id in unit_ids:
            unit_data = tgt_storage.get_by_id(unit_id)
            if not unit_data:
                # Skip dangling unit ids instead of crashing on None lookups.
                continue
            if unit_data.get("node"):
                nodes.append(unit_data["node"])
            if unit_data.get("edge"):
                edges.append(unit_data["edge"])

        tasks.append((chunk_content, nodes, edges))

    results = {}
    for key, triple_evaluator in triple_evaluators.items():
        # Lazy %-style args avoid formatting when INFO is disabled.
        logger.info("Evaluating Triples with metric: %s...", key)
        results[key] = run_concurrent(
            triple_evaluator.evaluate,
            tasks,
            desc=f"Evaluating Triples with {key}",
        )
    return results
|
graphgen/operators/extract/extract_service.py
CHANGED
|
@@ -1,16 +1,19 @@
|
|
| 1 |
import json
|
|
|
|
| 2 |
|
| 3 |
-
import
|
| 4 |
-
|
| 5 |
-
from graphgen.bases import BaseLLMWrapper, BaseOperator
|
| 6 |
from graphgen.common import init_llm
|
| 7 |
from graphgen.models.extractor import SchemaGuidedExtractor
|
| 8 |
from graphgen.utils import logger, run_concurrent
|
| 9 |
|
| 10 |
|
| 11 |
class ExtractService(BaseOperator):
|
| 12 |
-
def __init__(
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
|
| 15 |
self.extract_kwargs = extract_kwargs
|
| 16 |
self.method = self.extract_kwargs.get("method")
|
|
@@ -22,24 +25,32 @@ class ExtractService(BaseOperator):
|
|
| 22 |
else:
|
| 23 |
raise ValueError(f"Unsupported extraction method: {self.method}")
|
| 24 |
|
| 25 |
-
def process(self, batch:
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
| 33 |
results = run_concurrent(
|
| 34 |
self.extractor.extract,
|
| 35 |
-
|
| 36 |
desc="Extracting information",
|
| 37 |
unit="item",
|
| 38 |
)
|
| 39 |
-
results = self.extractor.merge_extractions(results)
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
+
from typing import Tuple
|
| 3 |
|
| 4 |
+
from graphgen.bases import BaseLLMWrapper, BaseOperator, Chunk
|
|
|
|
|
|
|
| 5 |
from graphgen.common import init_llm
|
| 6 |
from graphgen.models.extractor import SchemaGuidedExtractor
|
| 7 |
from graphgen.utils import logger, run_concurrent
|
| 8 |
|
| 9 |
|
| 10 |
class ExtractService(BaseOperator):
|
| 11 |
+
def __init__(
|
| 12 |
+
self, working_dir: str = "cache", kv_backend: str = "rocksdb", **extract_kwargs
|
| 13 |
+
):
|
| 14 |
+
super().__init__(
|
| 15 |
+
working_dir=working_dir, kv_backend=kv_backend, op_name="extract"
|
| 16 |
+
)
|
| 17 |
self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
|
| 18 |
self.extract_kwargs = extract_kwargs
|
| 19 |
self.method = self.extract_kwargs.get("method")
|
|
|
|
| 25 |
else:
|
| 26 |
raise ValueError(f"Unsupported extraction method: {self.method}")
|
| 27 |
|
| 28 |
+
def process(self, batch: list) -> Tuple[list, dict]:
    """
    Extract information from the batch of chunks.
    :return: A tuple of (results, meta_updates)
        results: A list of dicts containing extracted information. Each dict has the structure:
            {"_trace_id": str, "content": dict}
        meta_updates: A dict mapping source IDs to lists of trace IDs for the extracted information.
    """
    logger.info("Start extracting information from %d items", len(batch))
    chunks = [Chunk.from_dict(item["_trace_id"], item) for item in batch]
    extractions = run_concurrent(
        self.extractor.extract,
        chunks,
        desc="Extracting information",
        unit="item",
    )

    meta_updates = {}
    final_results = []
    # Pair each source chunk with its extraction; empty extractions are dropped.
    for item, extraction in zip(batch, extractions):
        if not extraction:
            continue
        wrapped = {"_trace_id": self.get_trace_id(extraction), "content": extraction}
        meta_updates.setdefault(item["_trace_id"], []).append(wrapped["_trace_id"])
        final_results.append(wrapped)
    return final_results, meta_updates
|
graphgen/operators/generate/generate_service.py
CHANGED
|
@@ -1,9 +1,6 @@
|
|
| 1 |
-
import
|
| 2 |
-
|
| 3 |
-
import
|
| 4 |
-
|
| 5 |
-
from graphgen.bases import BaseLLMWrapper, BaseOperator
|
| 6 |
-
from graphgen.common import init_llm
|
| 7 |
from graphgen.utils import logger, run_concurrent
|
| 8 |
|
| 9 |
|
|
@@ -15,12 +12,18 @@ class GenerateService(BaseOperator):
|
|
| 15 |
def __init__(
|
| 16 |
self,
|
| 17 |
working_dir: str = "cache",
|
|
|
|
| 18 |
method: str = "aggregated",
|
| 19 |
data_format: str = "ChatML",
|
| 20 |
**generate_kwargs,
|
| 21 |
):
|
| 22 |
-
super().__init__(
|
|
|
|
|
|
|
| 23 |
self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
self.method = method
|
| 26 |
self.data_format = data_format
|
|
@@ -76,32 +79,31 @@ class GenerateService(BaseOperator):
|
|
| 76 |
else:
|
| 77 |
raise ValueError(f"Unsupported generation mode: {method}")
|
| 78 |
|
| 79 |
-
def process(self, batch:
|
| 80 |
-
items = batch.to_dict(orient="records")
|
| 81 |
-
return pd.DataFrame(self.generate(items))
|
| 82 |
-
|
| 83 |
-
def generate(self, items: list[dict]) -> list[dict]:
|
| 84 |
"""
|
| 85 |
Generate question-answer pairs based on nodes and edges.
|
| 86 |
-
:param items
|
| 87 |
-
:return: QA pairs
|
| 88 |
"""
|
| 89 |
-
logger.info("[Generation] mode: %s, batches: %d", self.method, len(
|
| 90 |
-
|
| 91 |
-
(json.loads(item["nodes"]), json.loads(item["edges"])) for item in items
|
| 92 |
-
]
|
| 93 |
results = run_concurrent(
|
| 94 |
self.generator.generate,
|
| 95 |
-
|
| 96 |
-
desc="
|
| 97 |
unit="batch",
|
| 98 |
)
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
from graphgen.bases import BaseKVStorage, BaseLLMWrapper, BaseOperator
|
| 3 |
+
from graphgen.common import init_llm, init_storage
|
|
|
|
|
|
|
|
|
|
| 4 |
from graphgen.utils import logger, run_concurrent
|
| 5 |
|
| 6 |
|
|
|
|
| 12 |
def __init__(
|
| 13 |
self,
|
| 14 |
working_dir: str = "cache",
|
| 15 |
+
kv_backend: str = "rocksdb",
|
| 16 |
method: str = "aggregated",
|
| 17 |
data_format: str = "ChatML",
|
| 18 |
**generate_kwargs,
|
| 19 |
):
|
| 20 |
+
super().__init__(
|
| 21 |
+
working_dir=working_dir, kv_backend=kv_backend, op_name="generate"
|
| 22 |
+
)
|
| 23 |
self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
|
| 24 |
+
self.generate_storage: BaseKVStorage = init_storage(
|
| 25 |
+
backend=kv_backend, working_dir=working_dir, namespace="generate"
|
| 26 |
+
)
|
| 27 |
|
| 28 |
self.method = method
|
| 29 |
self.data_format = data_format
|
|
|
|
| 79 |
else:
|
| 80 |
raise ValueError(f"Unsupported generation mode: {method}")
|
| 81 |
|
| 82 |
+
def process(self, batch: list) -> Tuple[list, dict]:
    """
    Generate question-answer pairs based on nodes and edges.

    :param batch: list of records, each carrying "nodes", "edges" and a
        "_trace_id" key.
    :return: (results, meta_updates) — results holds formatted QA records
        (each with its own "_trace_id"); meta_updates maps each input trace
        id to the trace ids of the QA records generated from it.
    """
    logger.info("[Generation] mode: %s, batches: %d", self.method, len(batch))
    triples = [(item["nodes"], item["edges"]) for item in batch]
    results = run_concurrent(
        self.generator.generate,
        triples,
        desc="Generating QAs",
        unit="batch",
    )

    meta_updates = {}
    final_results = []
    # run_concurrent preserves input order, so zip pairs each input record
    # with its generated QA list — presumably; verify against run_concurrent.
    for input_trace_id, qa_pairs in zip(
        [item["_trace_id"] for item in batch], results
    ):
        if not qa_pairs:
            continue
        for qa_pair in qa_pairs:
            res = self.generator.format_generation_results(
                qa_pair, output_data_format=self.data_format
            )
            res["_trace_id"] = self.get_trace_id(res)
            final_results.append(res)
            meta_updates.setdefault(input_trace_id, []).append(res["_trace_id"])
    return final_results, meta_updates
|