github-actions[bot] committed on
Commit
0bd1b0f
·
1 Parent(s): b275e29

Auto-sync from demo at Thu Jan 29 12:51:48 UTC 2026

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. graphgen/bases/__init__.py +1 -1
  2. graphgen/bases/base_evaluator.py +21 -2
  3. graphgen/bases/base_generator.py +34 -49
  4. graphgen/bases/base_operator.py +110 -9
  5. graphgen/bases/base_storage.py +6 -0
  6. graphgen/bases/datatypes.py +7 -0
  7. graphgen/common/init_storage.py +16 -4
  8. graphgen/models/__init__.py +0 -8
  9. graphgen/models/evaluator/__init__.py +2 -1
  10. graphgen/models/evaluator/kg/__init__.py +0 -17
  11. graphgen/models/evaluator/kg/accuracy_evaluator.py +0 -350
  12. graphgen/models/evaluator/kg/consistency_evaluator.py +0 -388
  13. graphgen/models/evaluator/kg/structure_evaluator.py +15 -15
  14. graphgen/models/evaluator/qa/length_evaluator.py +8 -7
  15. graphgen/models/evaluator/qa/mtld_evaluator.py +7 -6
  16. graphgen/models/evaluator/qa/reward_evaluator.py +12 -8
  17. graphgen/models/evaluator/qa/uni_evaluator.py +17 -9
  18. graphgen/models/evaluator/triple/__init__.py +1 -0
  19. graphgen/models/evaluator/triple/accuracy_evaluator.py +94 -0
  20. graphgen/models/extractor/schema_guided_extractor.py +5 -33
  21. graphgen/models/generator/aggregated_generator.py +7 -11
  22. graphgen/models/generator/atomic_generator.py +4 -9
  23. graphgen/models/generator/cot_generator.py +6 -9
  24. graphgen/models/generator/fill_in_blank_generator.py +11 -11
  25. graphgen/models/generator/multi_answer_generator.py +14 -12
  26. graphgen/models/generator/multi_choice_generator.py +11 -11
  27. graphgen/models/generator/multi_hop_generator.py +4 -9
  28. graphgen/models/generator/quiz_generator.py +18 -14
  29. graphgen/models/generator/true_false_generator.py +10 -10
  30. graphgen/models/generator/vqa_generator.py +50 -65
  31. graphgen/models/kg_builder/light_rag_kg_builder.py +14 -3
  32. graphgen/models/kg_builder/mm_kg_builder.py +2 -0
  33. graphgen/models/reader/csv_reader.py +1 -1
  34. graphgen/models/reader/json_reader.py +4 -1
  35. graphgen/models/reader/parquet_reader.py +1 -1
  36. graphgen/models/reader/rdf_reader.py +1 -1
  37. graphgen/models/reader/txt_reader.py +2 -1
  38. graphgen/models/storage/__init__.py +0 -6
  39. graphgen/models/storage/rocksdb_cache.py +0 -43
  40. graphgen/models/vis/__init__.py +0 -0
  41. graphgen/models/vis/community_visualizer.py +0 -48
  42. graphgen/operators/build_kg/build_kg_service.py +46 -18
  43. graphgen/operators/build_kg/build_text_kg.py +1 -0
  44. graphgen/operators/chunk/chunk_service.py +39 -45
  45. graphgen/operators/evaluate/evaluate_kg.py +15 -0
  46. graphgen/operators/evaluate/evaluate_qa.py +107 -0
  47. graphgen/operators/evaluate/evaluate_service.py +120 -150
  48. graphgen/operators/evaluate/evaluate_triple.py +39 -0
  49. graphgen/operators/extract/extract_service.py +31 -20
  50. graphgen/operators/generate/generate_service.py +30 -28
graphgen/bases/__init__.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from .base_extractor import BaseExtractor
2
  from .base_generator import BaseGenerator
3
  from .base_kg_builder import BaseKGBuilder
@@ -9,5 +10,4 @@ from .base_searcher import BaseSearcher
9
  from .base_splitter import BaseSplitter
10
  from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
11
  from .base_tokenizer import BaseTokenizer
12
- from .base_evaluator import BaseEvaluator
13
  from .datatypes import Chunk, Config, Node, QAPair, Token
 
1
+ from .base_evaluator import BaseKGEvaluator, BaseQAEvaluator, BaseTripleEvaluator
2
  from .base_extractor import BaseExtractor
3
  from .base_generator import BaseGenerator
4
  from .base_kg_builder import BaseKGBuilder
 
10
  from .base_splitter import BaseSplitter
11
  from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
12
  from .base_tokenizer import BaseTokenizer
 
13
  from .datatypes import Chunk, Config, Node, QAPair, Token
graphgen/bases/base_evaluator.py CHANGED
@@ -1,10 +1,29 @@
1
  from abc import ABC, abstractmethod
 
 
 
2
  from .datatypes import QAPair
3
 
4
 
5
- class BaseEvaluator(ABC):
6
  @abstractmethod
7
- def evaluate(self, pair: QAPair) -> float:
8
  """
9
  Evaluate the text and return a score.
10
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from abc import ABC, abstractmethod
2
+ from typing import Any
3
+
4
+ from .base_storage import BaseGraphStorage
5
  from .datatypes import QAPair
6
 
7
 
8
+ class BaseQAEvaluator(ABC):
9
  @abstractmethod
10
+ async def evaluate(self, pair: QAPair) -> dict[str, float]:
11
  """
12
  Evaluate the text and return a score.
13
  """
14
+
15
+
16
+ class BaseKGEvaluator(ABC):
17
+ @abstractmethod
18
+ def evaluate(self, kg: BaseGraphStorage) -> dict[str, Any]:
19
+ """
20
+ Evaluate the whole graph and return a dict of scores.
21
+ """
22
+
23
+
24
+ class BaseTripleEvaluator(ABC):
25
+ @abstractmethod
26
+ async def evaluate(self, unit: dict) -> dict[str, float]:
27
+ """
28
+ Evaluate a node/edge and return a score.
29
+ """
graphgen/bases/base_generator.py CHANGED
@@ -21,7 +21,7 @@ class BaseGenerator(ABC):
21
 
22
  @staticmethod
23
  @abstractmethod
24
- def parse_response(response: str) -> Any:
25
  """Parse the LLM response and return the generated QAs"""
26
 
27
  async def generate(
@@ -29,64 +29,49 @@ class BaseGenerator(ABC):
29
  batch: tuple[
30
  list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
31
  ],
32
- ) -> dict[str, Any]:
33
  """
34
  Generate QAs based on a given batch.
35
  :param batch
36
  :return: QA pairs
37
  """
38
- result = {}
39
  prompt = self.build_prompt(batch)
40
  response = await self.llm_client.generate_answer(prompt)
41
  qa_pairs = self.parse_response(response) # generate one or more QA pairs
42
- result.update(qa_pairs)
43
- return result
44
 
45
  @staticmethod
46
  def format_generation_results(
47
- results: list[dict], output_data_format: str
48
- ) -> list[dict[str, Any]]:
 
 
 
 
 
 
 
 
49
 
50
- flat_results = []
51
- for item in results:
52
- for _, qa_data in item.items():
53
- question = qa_data.get("question", "")
54
- answer = qa_data.get("answer", "")
55
- if "options" in qa_data and qa_data["options"]:
56
- options = qa_data["options"]
57
- options_str = "\n".join(
58
- [f"{key}. {options[key]}" for key in sorted(options.keys())]
59
- )
60
- question += f"\nOptions:\n{options_str}"
61
 
62
- if output_data_format == "Alpaca":
63
- flat_results.append(
64
- {
65
- "instruction": question,
66
- "input": "",
67
- "output": answer,
68
- }
69
- )
70
- elif output_data_format == "Sharegpt":
71
- flat_results.append(
72
- {
73
- "conversations": [
74
- {"from": "human", "value": question},
75
- {"from": "gpt", "value": answer},
76
- ]
77
- }
78
- )
79
- elif output_data_format == "ChatML":
80
- flat_results.append(
81
- {
82
- "messages": [
83
- {"role": "user", "content": question},
84
- {"role": "assistant", "content": answer},
85
- ]
86
- }
87
- )
88
- else:
89
- raise ValueError(
90
- f"Unknown output data format: {output_data_format}"
91
- )
92
- return flat_results
 
21
 
22
  @staticmethod
23
  @abstractmethod
24
+ def parse_response(response: str) -> list[dict]:
25
  """Parse the LLM response and return the generated QAs"""
26
 
27
  async def generate(
 
29
  batch: tuple[
30
  list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
31
  ],
32
+ ) -> list[dict]:
33
  """
34
  Generate QAs based on a given batch.
35
  :param batch
36
  :return: QA pairs
37
  """
 
38
  prompt = self.build_prompt(batch)
39
  response = await self.llm_client.generate_answer(prompt)
40
  qa_pairs = self.parse_response(response) # generate one or more QA pairs
41
+ return qa_pairs
 
42
 
43
  @staticmethod
44
  def format_generation_results(
45
+ result: dict, output_data_format: str
46
+ ) -> dict[str, Any]:
47
+ question = result.get("question", "")
48
+ answer = result.get("answer", "")
49
+ if "options" in result and result["options"]:
50
+ options = result["options"]
51
+ options_str = "\n".join(
52
+ [f"{key}. {options[key]}" for key in sorted(options.keys())]
53
+ )
54
+ question += f"\nOptions:\n{options_str}"
55
 
56
+ if output_data_format == "Alpaca":
57
+ return {
58
+ "instruction": question,
59
+ "input": "",
60
+ "output": answer,
61
+ }
 
 
 
 
 
62
 
63
+ if output_data_format == "Sharegpt":
64
+ return {
65
+ "conversations": [
66
+ {"from": "human", "value": question},
67
+ {"from": "gpt", "value": answer},
68
+ ]
69
+ }
70
+ if output_data_format == "ChatML":
71
+ return {
72
+ "messages": [
73
+ {"role": "user", "content": question},
74
+ {"role": "assistant", "content": answer},
75
+ ]
76
+ }
77
+ raise ValueError(f"Unknown output data format: {output_data_format}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/bases/base_operator.py CHANGED
@@ -1,19 +1,43 @@
1
  import inspect
2
  import os
3
  from abc import ABC, abstractmethod
4
- from typing import Iterable, Union
5
 
 
6
  import pandas as pd
7
  import ray
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  class BaseOperator(ABC):
11
- def __init__(self, working_dir: str = "cache", op_name: str = None):
 
 
 
 
 
12
  # lazy import to avoid circular import
 
13
  from graphgen.utils import set_logger
14
 
15
  log_dir = os.path.join(working_dir, "logs")
16
  self.op_name = op_name or self.__class__.__name__
 
 
 
 
 
17
 
18
  try:
19
  ctx = ray.get_runtime_context()
@@ -45,17 +69,94 @@ class BaseOperator(ABC):
45
 
46
  logger_token = CURRENT_LOGGER_VAR.set(self.logger)
47
  try:
48
- result = self.process(batch)
 
 
 
 
 
 
 
 
 
 
49
  if inspect.isgenerator(result):
50
- yield from result
 
 
 
 
51
  else:
52
- yield result
 
53
  finally:
54
  CURRENT_LOGGER_VAR.reset(logger_token)
55
 
56
- @abstractmethod
57
- def process(self, batch):
58
- raise NotImplementedError("Subclasses must implement the process method.")
59
-
60
  def get_logger(self):
61
  return self.logger
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import inspect
2
  import os
3
  from abc import ABC, abstractmethod
4
+ from typing import Iterable, Tuple, Union
5
 
6
+ import numpy as np
7
  import pandas as pd
8
  import ray
9
 
10
 
11
+ def convert_to_serializable(obj):
12
+ if isinstance(obj, np.ndarray):
13
+ return obj.tolist()
14
+ if isinstance(obj, np.generic):
15
+ return obj.item()
16
+ if isinstance(obj, dict):
17
+ return {k: convert_to_serializable(v) for k, v in obj.items()}
18
+ if isinstance(obj, list):
19
+ return [convert_to_serializable(v) for v in obj]
20
+ return obj
21
+
22
+
23
  class BaseOperator(ABC):
24
+ def __init__(
25
+ self,
26
+ working_dir: str = "cache",
27
+ kv_backend: str = "rocksdb",
28
+ op_name: str = None,
29
+ ):
30
  # lazy import to avoid circular import
31
+ from graphgen.common import init_storage
32
  from graphgen.utils import set_logger
33
 
34
  log_dir = os.path.join(working_dir, "logs")
35
  self.op_name = op_name or self.__class__.__name__
36
+ self.working_dir = working_dir
37
+ self.kv_backend = kv_backend
38
+ self.kv_storage = init_storage(
39
+ backend=kv_backend, working_dir=working_dir, namespace=self.op_name
40
+ )
41
 
42
  try:
43
  ctx = ray.get_runtime_context()
 
69
 
70
  logger_token = CURRENT_LOGGER_VAR.set(self.logger)
71
  try:
72
+ self.kv_storage.reload()
73
+ to_process, recovered = self.split(batch)
74
+ # yield recovered chunks first
75
+ if not recovered.empty:
76
+ yield recovered
77
+
78
+ if to_process.empty:
79
+ return
80
+
81
+ data = to_process.to_dict(orient="records")
82
+ result, meta_update = self.process(data)
83
  if inspect.isgenerator(result):
84
+ is_first = True
85
+ for res in result:
86
+ yield pd.DataFrame([res])
87
+ self.store([res], meta_update if is_first else {})
88
+ is_first = False
89
  else:
90
+ yield pd.DataFrame(result)
91
+ self.store(result, meta_update)
92
  finally:
93
  CURRENT_LOGGER_VAR.reset(logger_token)
94
 
 
 
 
 
95
  def get_logger(self):
96
  return self.logger
97
+
98
+ def get_meta_forward(self):
99
+ return self.kv_storage.get_by_id("_meta_forward") or {}
100
+
101
+ def get_meta_inverse(self):
102
+ return self.kv_storage.get_by_id("_meta_inverse") or {}
103
+
104
+ def get_trace_id(self, content: dict) -> str:
105
+ from graphgen.utils import compute_dict_hash
106
+
107
+ return compute_dict_hash(content, prefix=f"{self.op_name}-")
108
+
109
+ def split(self, batch: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
110
+ """
111
+ Split the input batch into to_process & processed based on _meta data in KV_storage
112
+ :param batch
113
+ :return:
114
+ to_process: DataFrame of documents to be chunked
115
+ recovered: Result DataFrame of already chunked documents
116
+ """
117
+ meta_forward = self.get_meta_forward()
118
+ meta_ids = set(meta_forward.keys())
119
+ mask = batch["_trace_id"].isin(meta_ids)
120
+ to_process = batch[~mask]
121
+ processed = batch[mask]
122
+
123
+ if processed.empty:
124
+ return to_process, pd.DataFrame()
125
+
126
+ all_ids = [
127
+ pid for tid in processed["_trace_id"] for pid in meta_forward.get(tid, [])
128
+ ]
129
+
130
+ recovered_chunks = self.kv_storage.get_by_ids(all_ids)
131
+ recovered_chunks = [c for c in recovered_chunks if c is not None]
132
+ return to_process, pd.DataFrame(recovered_chunks)
133
+
134
+ def store(self, results: list, meta_update: dict):
135
+ results = convert_to_serializable(results)
136
+ meta_update = convert_to_serializable(meta_update)
137
+
138
+ batch = {res["_trace_id"]: res for res in results}
139
+ self.kv_storage.upsert(batch)
140
+
141
+ # update forward meta
142
+ forward_meta = self.get_meta_forward()
143
+ forward_meta.update(meta_update)
144
+ self.kv_storage.update({"_meta_forward": forward_meta})
145
+
146
+ # update inverse meta
147
+ inverse_meta = self.get_meta_inverse()
148
+ for k, v_list in meta_update.items():
149
+ for v in v_list:
150
+ inverse_meta[v] = k
151
+ self.kv_storage.update({"_meta_inverse": inverse_meta})
152
+ self.kv_storage.index_done_callback()
153
+
154
+ @abstractmethod
155
+ def process(self, batch: list) -> Tuple[Union[list, Iterable[dict]], dict]:
156
+ """
157
+ Process the input batch and return the result.
158
+ :param batch
159
+ :return:
160
+ result: DataFrame of processed documents
161
+ meta_update: dict of meta data to be updated
162
+ """
graphgen/bases/base_storage.py CHANGED
@@ -39,6 +39,12 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
39
  def upsert(self, data: dict[str, T]):
40
  raise NotImplementedError
41
 
 
 
 
 
 
 
42
  def drop(self):
43
  raise NotImplementedError
44
 
 
39
  def upsert(self, data: dict[str, T]):
40
  raise NotImplementedError
41
 
42
+ def update(self, data: dict[str, T]):
43
+ raise NotImplementedError
44
+
45
+ def delete(self, ids: list[str]):
46
+ raise NotImplementedError
47
+
48
  def drop(self):
49
  raise NotImplementedError
50
 
graphgen/bases/datatypes.py CHANGED
@@ -31,6 +31,13 @@ class QAPair:
31
  question: str
32
  answer: str
33
 
 
 
 
 
 
 
 
34
 
35
  @dataclass
36
  class Token:
 
31
  question: str
32
  answer: str
33
 
34
+ @staticmethod
35
+ def from_dict(data: dict) -> "QAPair":
36
+ return QAPair(
37
+ question=data.get("question", ""),
38
+ answer=data.get("answer", ""),
39
+ )
40
+
41
 
42
  @dataclass
43
  class Token:
graphgen/common/init_storage.py CHANGED
@@ -8,11 +8,11 @@ from graphgen.bases.base_storage import BaseGraphStorage, BaseKVStorage
8
  class KVStorageActor:
9
  def __init__(self, backend: str, working_dir: str, namespace: str):
10
  if backend == "json_kv":
11
- from graphgen.models import JsonKVStorage
12
 
13
  self.kv = JsonKVStorage(working_dir, namespace)
14
  elif backend == "rocksdb":
15
- from graphgen.models import RocksDBKVStorage
16
 
17
  self.kv = RocksDBKVStorage(working_dir, namespace)
18
  else:
@@ -42,6 +42,12 @@ class KVStorageActor:
42
  def upsert(self, data: dict) -> dict:
43
  return self.kv.upsert(data)
44
 
 
 
 
 
 
 
45
  def drop(self):
46
  return self.kv.drop()
47
 
@@ -55,11 +61,11 @@ class KVStorageActor:
55
  class GraphStorageActor:
56
  def __init__(self, backend: str, working_dir: str, namespace: str):
57
  if backend == "networkx":
58
- from graphgen.models import NetworkXStorage
59
 
60
  self.graph = NetworkXStorage(working_dir, namespace)
61
  elif backend == "kuzu":
62
- from graphgen.models import KuzuStorage
63
 
64
  self.graph = KuzuStorage(working_dir, namespace)
65
  else:
@@ -168,6 +174,12 @@ class RemoteKVStorageProxy(BaseKVStorage):
168
  def upsert(self, data: Dict[str, Any]):
169
  return ray.get(self.actor.upsert.remote(data))
170
 
 
 
 
 
 
 
171
  def drop(self):
172
  return ray.get(self.actor.drop.remote())
173
 
 
8
  class KVStorageActor:
9
  def __init__(self, backend: str, working_dir: str, namespace: str):
10
  if backend == "json_kv":
11
+ from graphgen.storage import JsonKVStorage
12
 
13
  self.kv = JsonKVStorage(working_dir, namespace)
14
  elif backend == "rocksdb":
15
+ from graphgen.storage import RocksDBKVStorage
16
 
17
  self.kv = RocksDBKVStorage(working_dir, namespace)
18
  else:
 
42
  def upsert(self, data: dict) -> dict:
43
  return self.kv.upsert(data)
44
 
45
+ def update(self, data: dict):
46
+ return self.kv.update(data)
47
+
48
+ def delete(self, ids: list[str]):
49
+ return self.kv.delete(ids)
50
+
51
  def drop(self):
52
  return self.kv.drop()
53
 
 
61
  class GraphStorageActor:
62
  def __init__(self, backend: str, working_dir: str, namespace: str):
63
  if backend == "networkx":
64
+ from graphgen.storage import NetworkXStorage
65
 
66
  self.graph = NetworkXStorage(working_dir, namespace)
67
  elif backend == "kuzu":
68
+ from graphgen.storage import KuzuStorage
69
 
70
  self.graph = KuzuStorage(working_dir, namespace)
71
  else:
 
174
  def upsert(self, data: Dict[str, Any]):
175
  return ray.get(self.actor.upsert.remote(data))
176
 
177
+ def update(self, data: Dict[str, Any]):
178
+ return ray.get(self.actor.update.remote(data))
179
+
180
+ def delete(self, ids: list[str]):
181
+ return ray.get(self.actor.delete.remote(ids))
182
+
183
  def drop(self):
184
  return ray.get(self.actor.drop.remote())
185
 
graphgen/models/__init__.py CHANGED
@@ -1,6 +1,5 @@
1
  from .evaluator import (
2
  AccuracyEvaluator,
3
- ConsistencyEvaluator,
4
  LengthEvaluator,
5
  MTLDEvaluator,
6
  RewardEvaluator,
@@ -44,11 +43,4 @@ from .searcher.kg.wiki_search import WikiSearch
44
  from .searcher.web.bing_search import BingSearch
45
  from .searcher.web.google_search import GoogleSearch
46
  from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
47
- from .storage import (
48
- JsonKVStorage,
49
- KuzuStorage,
50
- NetworkXStorage,
51
- RocksDBCache,
52
- RocksDBKVStorage,
53
- )
54
  from .tokenizer import Tokenizer
 
1
  from .evaluator import (
2
  AccuracyEvaluator,
 
3
  LengthEvaluator,
4
  MTLDEvaluator,
5
  RewardEvaluator,
 
43
  from .searcher.web.bing_search import BingSearch
44
  from .searcher.web.google_search import GoogleSearch
45
  from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
 
 
 
 
 
 
 
46
  from .tokenizer import Tokenizer
graphgen/models/evaluator/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
- from .kg import AccuracyEvaluator, ConsistencyEvaluator, StructureEvaluator
2
  from .qa import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
 
 
1
+ from .kg import StructureEvaluator
2
  from .qa import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
3
+ from .triple import AccuracyEvaluator
graphgen/models/evaluator/kg/__init__.py CHANGED
@@ -1,18 +1 @@
1
- """
2
- Knowledge Graph Quality Evaluator
3
-
4
- This module provides comprehensive quality evaluation for knowledge graphs,
5
- 1. accuracy assessment (entity/relation/triple validation),
6
- 2. consistency assessment (attribute conflict detection), and structural
7
- 3. robustness assessment (noise ratio, connectivity, degree distribution).
8
- """
9
-
10
- from .accuracy_evaluator import AccuracyEvaluator
11
- from .consistency_evaluator import ConsistencyEvaluator
12
  from .structure_evaluator import StructureEvaluator
13
-
14
- __all__ = [
15
- "AccuracyEvaluator",
16
- "ConsistencyEvaluator",
17
- "StructureEvaluator",
18
- ]
 
 
 
 
 
 
 
 
 
 
 
 
1
  from .structure_evaluator import StructureEvaluator
 
 
 
 
 
 
graphgen/models/evaluator/kg/accuracy_evaluator.py DELETED
@@ -1,350 +0,0 @@
1
- import asyncio
2
- import json
3
- import re
4
- from typing import Any, Dict, List
5
-
6
- from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
7
- from graphgen.bases.datatypes import Chunk
8
- from graphgen.templates import ACCURACY_EVALUATION_PROMPT
9
- from graphgen.utils import detect_main_language, logger
10
-
11
-
12
- class AccuracyEvaluator:
13
- """Evaluates accuracy of entity recognition and relation extraction using LLM-as-a-Judge.
14
-
15
- For each chunk, uses LLM to evaluate the quality of extracted entities and relations
16
- by comparing them with the original chunk content. Provides multi-dimensional quality
17
- scores (accuracy, completeness, precision).
18
- """
19
-
20
- def __init__(
21
- self,
22
- graph_storage: BaseGraphStorage,
23
- chunk_storage: BaseKVStorage,
24
- llm_client: BaseLLMWrapper,
25
- ):
26
- self.graph_storage = graph_storage
27
- self.chunk_storage = chunk_storage
28
- self.llm_client = llm_client
29
-
30
- def evaluate(self) -> Dict[str, Any]:
31
- """Evaluate entity and relation extraction quality using LLM-as-a-Judge.
32
-
33
- Returns:
34
- Dictionary containing entity_accuracy and relation_accuracy metrics.
35
- """
36
- # 1. Load all chunks from storage
37
- chunks = self._load_chunks_from_storage()
38
-
39
- if not chunks:
40
- logger.warning("No chunks found in storage")
41
- return {"error": "No chunks found in storage"}
42
-
43
- logger.info(f"Found {len(chunks)} chunks to evaluate")
44
-
45
- # 2. Evaluate each chunk
46
- entity_evaluations, relation_evaluations = self._evaluate_all_chunks(chunks)
47
-
48
- # 3. Aggregate results
49
- return self._aggregate_evaluation_results(
50
- entity_evaluations, relation_evaluations
51
- )
52
-
53
- def _load_chunks_from_storage(self) -> List[Chunk]:
54
- """Load all chunks from chunk storage."""
55
- chunks = []
56
- all_chunk_data = self.chunk_storage.get_all()
57
-
58
- for chunk_id, chunk_data in all_chunk_data.items():
59
- try:
60
- chunk = Chunk.from_dict(chunk_id, chunk_data)
61
- chunks.append(chunk)
62
- except Exception as e:
63
- logger.warning(f"Failed to load chunk {chunk_id}: {e}")
64
- continue
65
-
66
- return chunks
67
-
68
- def _get_extracted_entities_for_chunk(self, chunk_id: str) -> List[Dict]:
69
- """Get all entities extracted from the specified chunk."""
70
- entities = []
71
- all_nodes = self.graph_storage.get_all_nodes() or []
72
-
73
- for node_id, node_data in all_nodes:
74
- if not isinstance(node_data, dict):
75
- continue
76
- source_ids = node_data.get("source_id", "").split("<SEP>")
77
- # Check if this chunk_id is in the source_ids
78
- if chunk_id in [sid.strip() for sid in source_ids if sid.strip()]:
79
- entities.append(
80
- {
81
- "entity_name": node_data.get("entity_name", node_id),
82
- "entity_type": node_data.get("entity_type", ""),
83
- "description": node_data.get("description", ""),
84
- }
85
- )
86
-
87
- return entities
88
-
89
- def _get_extracted_relations_for_chunk(self, chunk_id: str) -> List[Dict]:
90
- """Get all relations extracted from the specified chunk."""
91
- relations = []
92
- all_edges = self.graph_storage.get_all_edges() or []
93
-
94
- for src_id, dst_id, edge_data in all_edges:
95
- if not isinstance(edge_data, dict):
96
- continue
97
- source_ids = edge_data.get("source_id", "").split("<SEP>")
98
- # Check if this chunk_id is in the source_ids
99
- if chunk_id in [sid.strip() for sid in source_ids if sid.strip()]:
100
- src_node = self.graph_storage.get_node(src_id) or {}
101
- dst_node = self.graph_storage.get_node(dst_id) or {}
102
- relations.append(
103
- {
104
- "source_entity": src_node.get("entity_name", src_id),
105
- "target_entity": dst_node.get("entity_name", dst_id),
106
- "relationship_summary": edge_data.get("description", ""),
107
- }
108
- )
109
-
110
- return relations
111
-
112
- def _evaluate_all_chunks(
113
- self, chunks: List[Chunk]
114
- ) -> tuple[List[Dict], List[Dict]]:
115
- """Evaluate all chunks sequentially."""
116
- entity_evaluations = []
117
- relation_evaluations = []
118
-
119
- for chunk in chunks:
120
- try:
121
- entities = self._get_extracted_entities_for_chunk(chunk.id)
122
- relations = self._get_extracted_relations_for_chunk(chunk.id)
123
-
124
- entity_eval = self._evaluate_entity_extraction(chunk, entities)
125
- relation_eval = self._evaluate_relation_extraction(chunk, relations)
126
-
127
- entity_evaluations.append(entity_eval)
128
- relation_evaluations.append(relation_eval)
129
- except Exception as e:
130
- logger.error(f"Failed to evaluate chunk {chunk.id}: {e}")
131
- continue
132
-
133
- return entity_evaluations, relation_evaluations
134
-
135
- def _evaluate_entity_extraction(
136
- self, chunk: Chunk, extracted_entities: List[Dict]
137
- ) -> Dict[str, Any]:
138
- """Use LLM to evaluate entity extraction quality."""
139
- try:
140
- lang = detect_main_language(chunk.content)
141
-
142
- prompt = ACCURACY_EVALUATION_PROMPT[lang]["ENTITY"].format(
143
- chunk_content=chunk.content,
144
- extracted_entities=json.dumps(
145
- extracted_entities, ensure_ascii=False, indent=2
146
- ),
147
- )
148
-
149
- response = asyncio.run(self.llm_client.generate_answer(prompt))
150
-
151
- # Try to parse JSON response
152
- try:
153
- evaluation_result = json.loads(response)
154
- except json.JSONDecodeError:
155
- # Try to extract JSON from markdown code blocks or other formats
156
- json_match = re.search(r"\{.*\}", response, re.DOTALL)
157
- if json_match:
158
- evaluation_result = json.loads(json_match.group(0))
159
- else:
160
- logger.warning(
161
- f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}"
162
- )
163
- # Return default evaluation
164
- evaluation_result = {
165
- "accuracy": 0.0,
166
- "completeness": 0.0,
167
- "precision": 0.0,
168
- "overall_score": 0.0,
169
- "accuracy_reasoning": "Failed to parse LLM response",
170
- "completeness_reasoning": "",
171
- "precision_reasoning": "",
172
- "issues": ["LLM response parsing failed"],
173
- }
174
-
175
- # Validate and calculate overall_score if not provided
176
- if "overall_score" not in evaluation_result:
177
- accuracy = float(evaluation_result.get("accuracy", 0.0))
178
- completeness = float(evaluation_result.get("completeness", 0.0))
179
- precision = float(evaluation_result.get("precision", 0.0))
180
- evaluation_result["overall_score"] = (
181
- 0.4 * accuracy + 0.4 * completeness + 0.2 * precision
182
- )
183
-
184
- return {
185
- "chunk_id": chunk.id,
186
- "chunk_content": chunk.content[:200]
187
- if chunk.content
188
- else "", # First 200 chars for debugging
189
- "extracted_entities_count": len(extracted_entities),
190
- **evaluation_result,
191
- }
192
- except Exception as e:
193
- logger.error(
194
- f"Error evaluating entity extraction for chunk {chunk.id}: {e}"
195
- )
196
- return {
197
- "chunk_id": chunk.id,
198
- "chunk_content": chunk.content[:200] if chunk.content else "",
199
- "extracted_entities_count": len(extracted_entities),
200
- "accuracy": 0.0,
201
- "completeness": 0.0,
202
- "precision": 0.0,
203
- "overall_score": 0.0,
204
- "accuracy_reasoning": f"Evaluation failed: {str(e)}",
205
- "completeness_reasoning": "",
206
- "precision_reasoning": "",
207
- "issues": [f"Evaluation error: {str(e)}"],
208
- }
209
-
210
- def _evaluate_relation_extraction(
211
- self, chunk: Chunk, extracted_relations: List[Dict]
212
- ) -> Dict[str, Any]:
213
- """Use LLM to evaluate relation extraction quality."""
214
- try:
215
- lang = detect_main_language(chunk.content)
216
- prompt = ACCURACY_EVALUATION_PROMPT[lang]["RELATION"].format(
217
- chunk_content=chunk.content,
218
- extracted_relations=json.dumps(
219
- extracted_relations, ensure_ascii=False, indent=2
220
- ),
221
- )
222
-
223
- response = asyncio.run(self.llm_client.generate_answer(prompt))
224
-
225
- # Try to parse JSON response
226
- try:
227
- evaluation_result = json.loads(response)
228
- except json.JSONDecodeError:
229
- # Try to extract JSON from markdown code blocks or other formats
230
- json_match = re.search(r"\{.*\}", response, re.DOTALL)
231
- if json_match:
232
- evaluation_result = json.loads(json_match.group(0))
233
- else:
234
- logger.warning(
235
- f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}"
236
- )
237
- # Return default evaluation
238
- evaluation_result = {
239
- "accuracy": 0.0,
240
- "completeness": 0.0,
241
- "precision": 0.0,
242
- "overall_score": 0.0,
243
- "accuracy_reasoning": "Failed to parse LLM response",
244
- "completeness_reasoning": "",
245
- "precision_reasoning": "",
246
- "issues": ["LLM response parsing failed"],
247
- }
248
-
249
- # Validate and calculate overall_score if not provided
250
- if "overall_score" not in evaluation_result:
251
- accuracy = float(evaluation_result.get("accuracy", 0.0))
252
- completeness = float(evaluation_result.get("completeness", 0.0))
253
- precision = float(evaluation_result.get("precision", 0.0))
254
- evaluation_result["overall_score"] = (
255
- 0.4 * accuracy + 0.4 * completeness + 0.2 * precision
256
- )
257
-
258
- return {
259
- "chunk_id": chunk.id,
260
- "chunk_content": chunk.content[:200] if chunk.content else "",
261
- "extracted_relations_count": len(extracted_relations),
262
- **evaluation_result,
263
- }
264
- except Exception as e:
265
- logger.error(
266
- f"Error evaluating relation extraction for chunk {chunk.id}: {e}"
267
- )
268
- return {
269
- "chunk_id": chunk.id,
270
- "chunk_content": chunk.content[:200] if chunk.content else "",
271
- "extracted_relations_count": len(extracted_relations),
272
- "accuracy": 0.0,
273
- "completeness": 0.0,
274
- "precision": 0.0,
275
- "overall_score": 0.0,
276
- "accuracy_reasoning": f"Evaluation failed: {str(e)}",
277
- "completeness_reasoning": "",
278
- "precision_reasoning": "",
279
- "issues": [f"Evaluation error: {str(e)}"],
280
- }
281
-
282
- @staticmethod
283
- def _aggregate_evaluation_results(
284
- entity_evaluations: List[Dict], relation_evaluations: List[Dict]
285
- ) -> Dict[str, Any]:
286
- """Aggregate evaluation results from all chunks."""
287
-
288
- def calculate_stats(scores: List[float]) -> Dict[str, float]:
289
- if not scores:
290
- return {"mean": 0.0, "median": 0.0, "min": 0.0, "max": 0.0, "std": 0.0}
291
- sorted_scores = sorted(scores)
292
- n = len(scores)
293
- mean = sum(scores) / n
294
- median = (
295
- sorted_scores[n // 2]
296
- if n % 2 == 1
297
- else (sorted_scores[n // 2 - 1] + sorted_scores[n // 2]) / 2
298
- )
299
- variance = sum((x - mean) ** 2 for x in scores) / n
300
- std = variance**0.5
301
-
302
- return {
303
- "mean": mean,
304
- "median": median,
305
- "min": min(scores),
306
- "max": max(scores),
307
- "std": std,
308
- }
309
-
310
- # Extract scores
311
- entity_overall_scores = [
312
- e.get("overall_score", 0.0) for e in entity_evaluations
313
- ]
314
- entity_accuracy_scores = [e.get("accuracy", 0.0) for e in entity_evaluations]
315
- entity_completeness_scores = [
316
- e.get("completeness", 0.0) for e in entity_evaluations
317
- ]
318
- entity_precision_scores = [e.get("precision", 0.0) for e in entity_evaluations]
319
-
320
- relation_overall_scores = [
321
- r.get("overall_score", 0.0) for r in relation_evaluations
322
- ]
323
- relation_accuracy_scores = [
324
- r.get("accuracy", 0.0) for r in relation_evaluations
325
- ]
326
- relation_completeness_scores = [
327
- r.get("completeness", 0.0) for r in relation_evaluations
328
- ]
329
- relation_precision_scores = [
330
- r.get("precision", 0.0) for r in relation_evaluations
331
- ]
332
-
333
- return {
334
- "entity_accuracy": {
335
- "overall_score": calculate_stats(entity_overall_scores),
336
- "accuracy": calculate_stats(entity_accuracy_scores),
337
- "completeness": calculate_stats(entity_completeness_scores),
338
- "precision": calculate_stats(entity_precision_scores),
339
- "total_chunks": len(entity_evaluations),
340
- "detailed_results": entity_evaluations,
341
- },
342
- "relation_accuracy": {
343
- "overall_score": calculate_stats(relation_overall_scores),
344
- "accuracy": calculate_stats(relation_accuracy_scores),
345
- "completeness": calculate_stats(relation_completeness_scores),
346
- "precision": calculate_stats(relation_precision_scores),
347
- "total_chunks": len(relation_evaluations),
348
- "detailed_results": relation_evaluations,
349
- },
350
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/models/evaluator/kg/consistency_evaluator.py DELETED
@@ -1,388 +0,0 @@
1
- import asyncio
2
- import json
3
- import re
4
- from typing import Any, Dict, List
5
-
6
- from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
7
- from graphgen.bases.datatypes import Chunk
8
- from graphgen.templates.evaluation.kg.consistency_evaluation import (
9
- CONSISTENCY_EVALUATION_PROMPT,
10
- )
11
- from graphgen.utils import detect_main_language, logger
12
-
13
-
14
- class ConsistencyEvaluator:
15
- """Evaluates consistency by detecting semantic conflicts using LLM-as-a-Judge.
16
-
17
- For entities with multiple source chunks, compares entity_type and description
18
- extracted from different chunks to detect semantic conflicts.
19
- """
20
-
21
- def __init__(
22
- self,
23
- graph_storage: BaseGraphStorage,
24
- chunk_storage: BaseKVStorage,
25
- llm_client: BaseLLMWrapper,
26
- ):
27
- self.graph_storage = graph_storage
28
- self.chunk_storage = chunk_storage
29
- self.llm_client = llm_client
30
-
31
- def evaluate(self) -> Dict[str, Any]:
32
- """Evaluate consistency by detecting semantic conflicts."""
33
- all_nodes = self.graph_storage.get_all_nodes() or []
34
- if not all_nodes:
35
- return {"error": "Empty graph"}
36
-
37
- return self._evaluate_consistency(all_nodes)
38
-
39
- def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]:
40
- """Evaluate consistency by detecting semantic conflicts."""
41
- # Filter entities with multiple source chunks
42
- entities_with_multiple_sources = []
43
- for node_id, node_data in all_nodes:
44
- if not isinstance(node_data, dict):
45
- continue
46
- source_ids = node_data.get("source_id", "").split("<SEP>")
47
- source_ids = [sid.strip() for sid in source_ids if sid.strip()]
48
- if len(source_ids) > 1: # Only check entities from multiple chunks
49
- entities_with_multiple_sources.append((node_id, node_data, source_ids))
50
-
51
- if not entities_with_multiple_sources:
52
- logger.info(
53
- "No entities with multiple sources found, skipping consistency check"
54
- )
55
- return {
56
- "conflict_rate": 0.0,
57
- "conflict_entities_count": 0,
58
- "total_entities": len(all_nodes),
59
- "conflicts": [],
60
- }
61
-
62
- logger.info(
63
- f"Checking consistency for {len(entities_with_multiple_sources)} entities with multiple sources"
64
- )
65
-
66
- # Evaluate entities sequentially
67
- conflicts = []
68
- conflict_entities = set()
69
-
70
- for entity_info in entities_with_multiple_sources:
71
- try:
72
- entity_id, entity_conflicts = self._evaluate_entity_consistency(entity_info)
73
- if entity_conflicts:
74
- conflicts.extend(entity_conflicts)
75
- conflict_entities.add(entity_id)
76
- except Exception as e:
77
- logger.error(
78
- f"Failed to evaluate entity {entity_info[0]}: {e}"
79
- )
80
- continue
81
-
82
- total_entities = len(all_nodes)
83
- conflict_rate = (
84
- len(conflict_entities) / total_entities if total_entities > 0 else 0
85
- )
86
-
87
- return {
88
- "conflict_rate": conflict_rate,
89
- "conflict_entities_count": len(conflict_entities),
90
- "total_entities": total_entities,
91
- "entities_checked": len(entities_with_multiple_sources),
92
- "conflicts": conflicts[:100], # Limit to first 100 conflicts
93
- }
94
-
95
- def _clean_entity_id(self, entity_id: str) -> str:
96
- """Clean entity ID by removing surrounding quotes."""
97
- clean_id = entity_id.strip()
98
- if (clean_id.startswith('"') and clean_id.endswith('"')) or (
99
- clean_id.startswith("'") and clean_id.endswith("'")
100
- ):
101
- clean_id = clean_id[1:-1].strip()
102
- return clean_id
103
-
104
- def _evaluate_entity_consistency(
105
- self, entity_info: tuple
106
- ) -> tuple[str, List[Dict]]:
107
- """Evaluate consistency for a single entity."""
108
- entity_id, _node_data, source_ids = entity_info
109
- # Clean entity_id for display
110
- clean_entity_id = self._clean_entity_id(entity_id)
111
- conflicts = []
112
-
113
- # Get chunks for this entity
114
- chunks = self._get_entity_chunks(source_ids)
115
- if len(chunks) < 2:
116
- return entity_id, []
117
-
118
- # Extract entity attributes from each chunk
119
- entity_extractions = {}
120
- for chunk in chunks:
121
- extraction = self._extract_entity_from_chunk(entity_id, chunk)
122
- if extraction:
123
- entity_extractions[chunk.id] = extraction
124
-
125
- if len(entity_extractions) < 2:
126
- return entity_id, []
127
-
128
- # Check entity type consistency
129
- type_extractions = {
130
- chunk_id: ext.get("entity_type", "")
131
- for chunk_id, ext in entity_extractions.items()
132
- }
133
- type_conflict = self._check_entity_type_consistency(
134
- entity_id, type_extractions
135
- )
136
- if type_conflict and type_conflict.get("has_conflict", False):
137
- conflicts.append(
138
- {
139
- "entity_id": clean_entity_id,
140
- "conflict_type": "entity_type",
141
- "conflict_severity": type_conflict.get("conflict_severity", 0.0),
142
- "conflict_reasoning": type_conflict.get("conflict_reasoning", ""),
143
- "conflicting_values": type_conflict.get("conflicting_types", []),
144
- "recommended_value": type_conflict.get("recommended_type", ""),
145
- }
146
- )
147
-
148
- # Check entity description consistency
149
- descriptions = {
150
- chunk_id: ext.get("description", "")
151
- for chunk_id, ext in entity_extractions.items()
152
- }
153
- desc_conflict = self._check_entity_description_consistency(
154
- entity_id, descriptions
155
- )
156
- if desc_conflict and desc_conflict.get("has_conflict", False):
157
- conflicts.append(
158
- {
159
- "entity_id": clean_entity_id,
160
- "conflict_type": "description",
161
- "conflict_severity": desc_conflict.get("conflict_severity", 0.0),
162
- "conflict_reasoning": desc_conflict.get("conflict_reasoning", ""),
163
- "conflicting_values": desc_conflict.get(
164
- "conflicting_descriptions", []
165
- ),
166
- "conflict_details": desc_conflict.get("conflict_details", ""),
167
- }
168
- )
169
-
170
- return entity_id, conflicts
171
-
172
- def _get_entity_chunks(self, source_ids: List[str]) -> List[Chunk]:
173
- """Get all chunks related to an entity."""
174
- chunks = []
175
- for chunk_id in source_ids:
176
- chunk_data = self.chunk_storage.get_by_id(chunk_id)
177
- if chunk_data:
178
- try:
179
- chunk = Chunk.from_dict(chunk_id, chunk_data)
180
- chunks.append(chunk)
181
- except Exception as e:
182
- logger.warning(f"Failed to load chunk {chunk_id}: {e}")
183
- continue
184
- return chunks
185
-
186
- def _extract_entity_from_chunk(
187
- self, entity_id: str, chunk: Chunk
188
- ) -> Dict[str, str]:
189
- """Extract entity attributes from a chunk using LLM."""
190
- try:
191
- # Clean entity_id: remove surrounding quotes if present
192
- clean_entity_id = self._clean_entity_id(entity_id)
193
-
194
- # Detect language and get appropriate prompt
195
- lang = detect_main_language(chunk.content)
196
- prompt = CONSISTENCY_EVALUATION_PROMPT[lang]["ENTITY_EXTRACTION"].format(
197
- entity_name=clean_entity_id,
198
- chunk_content=chunk.content[:2000]
199
- if chunk.content
200
- else "", # Limit content length
201
- )
202
-
203
- response = asyncio.run(self.llm_client.generate_answer(prompt))
204
-
205
- # Try to parse JSON response
206
- try:
207
- extraction = json.loads(response)
208
- except json.JSONDecodeError:
209
- # Try to extract JSON from markdown code blocks
210
- json_match = re.search(r"\{.*\}", response, re.DOTALL)
211
- if json_match:
212
- extraction = json.loads(json_match.group(0))
213
- else:
214
- logger.warning(
215
- f"Failed to parse extraction response for {entity_id} in chunk {chunk.id}"
216
- )
217
- return {}
218
-
219
- # Normalize entity_type to lowercase and validate
220
- entity_type = extraction.get("entity_type", "").lower().strip()
221
- # Valid preset types
222
- valid_types = {
223
- "concept",
224
- "date",
225
- "location",
226
- "keyword",
227
- "organization",
228
- "person",
229
- "event",
230
- "work",
231
- "nature",
232
- "artificial",
233
- "science",
234
- "technology",
235
- "mission",
236
- "gene",
237
- }
238
- # If entity_type is not in valid types, default to "concept"
239
- if entity_type not in valid_types:
240
- if entity_type: # If LLM provided a type but it's invalid
241
- logger.warning(
242
- f"Invalid entity_type '{entity_type}' for entity {clean_entity_id} in chunk {chunk.id}, "
243
- f"defaulting to 'concept'"
244
- )
245
- entity_type = "concept"
246
-
247
- return {
248
- "entity_type": entity_type,
249
- "description": extraction.get("description", ""),
250
- }
251
- except Exception as e:
252
- logger.error(
253
- f"Error extracting entity {entity_id} from chunk {chunk.id}: {e}"
254
- )
255
- return {}
256
-
257
- def _check_entity_type_consistency(
258
- self, entity_id: str, type_extractions: Dict[str, str]
259
- ) -> Dict[str, Any]:
260
- """Check entity type consistency using LLM."""
261
- if len(set(type_extractions.values())) <= 1:
262
- # All types are the same, no conflict
263
- return {"has_conflict": False}
264
-
265
- try:
266
- type_list = [
267
- f"Chunk {chunk_id}: {entity_type}"
268
- for chunk_id, entity_type in type_extractions.items()
269
- if entity_type
270
- ]
271
-
272
- # Detect language from type extraction text
273
- type_text = "\n".join(type_list)
274
- lang = detect_main_language(type_text)
275
- prompt = CONSISTENCY_EVALUATION_PROMPT[lang]["ENTITY_TYPE_CONFLICT"].format(
276
- entity_name=entity_id, type_extractions=type_text
277
- )
278
-
279
- response = asyncio.run(self.llm_client.generate_answer(prompt))
280
-
281
- # Parse JSON response
282
- try:
283
- result = json.loads(response)
284
- except json.JSONDecodeError:
285
- json_match = re.search(r"\{.*\}", response, re.DOTALL)
286
- if json_match:
287
- result = json.loads(json_match.group(0))
288
- else:
289
- logger.warning(
290
- f"Failed to parse conflict detection response for {entity_id}"
291
- )
292
- return {"has_conflict": False}
293
-
294
- return result
295
- except Exception as e:
296
- logger.error(f"Error checking type consistency for {entity_id}: {e}")
297
- return {"has_conflict": False}
298
-
299
- def _check_entity_description_consistency(
300
- self, entity_id: str, descriptions: Dict[str, str]
301
- ) -> Dict[str, Any]:
302
- """Check entity description consistency using LLM."""
303
- # Filter out empty descriptions
304
- valid_descriptions = {k: v for k, v in descriptions.items() if v}
305
- if len(valid_descriptions) < 2:
306
- return {"has_conflict": False}
307
-
308
- if len(set(valid_descriptions.values())) <= 1:
309
- # All descriptions are the same, no conflict
310
- return {"has_conflict": False}
311
-
312
- try:
313
- desc_list = [
314
- f"Chunk {chunk_id}: {description}"
315
- for chunk_id, description in valid_descriptions.items()
316
- ]
317
-
318
- # Detect language from description text
319
- desc_text = "\n".join(desc_list)
320
- lang = detect_main_language(desc_text)
321
- prompt = CONSISTENCY_EVALUATION_PROMPT[lang]["ENTITY_DESCRIPTION_CONFLICT"].format(
322
- entity_name=entity_id, descriptions=desc_text
323
- )
324
-
325
- response = asyncio.run(self.llm_client.generate_answer(prompt))
326
-
327
- # Parse JSON response
328
- try:
329
- result = json.loads(response)
330
- except json.JSONDecodeError:
331
- json_match = re.search(r"\{.*\}", response, re.DOTALL)
332
- if json_match:
333
- result = json.loads(json_match.group(0))
334
- else:
335
- logger.warning(
336
- f"Failed to parse conflict detection response for {entity_id}"
337
- )
338
- return {"has_conflict": False}
339
-
340
- return result
341
- except Exception as e:
342
- logger.error(f"Error checking description consistency for {entity_id}: {e}")
343
- return {"has_conflict": False}
344
-
345
- def _check_relation_consistency(
346
- self, src_id: str, dst_id: str, relation_extractions: Dict[str, str]
347
- ) -> Dict[str, Any]:
348
- """Check relation consistency using LLM."""
349
- if len(set(relation_extractions.values())) <= 1:
350
- return {"has_conflict": False}
351
-
352
- try:
353
- rel_list = [
354
- f"Chunk {chunk_id}: {relation}"
355
- for chunk_id, relation in relation_extractions.items()
356
- if relation
357
- ]
358
-
359
- # Detect language from relation description text
360
- rel_text = "\n".join(rel_list)
361
- lang = detect_main_language(rel_text)
362
- prompt = CONSISTENCY_EVALUATION_PROMPT[lang]["RELATION_CONFLICT"].format(
363
- source_entity=src_id,
364
- target_entity=dst_id,
365
- relation_descriptions=rel_text,
366
- )
367
-
368
- response = asyncio.run(self.llm_client.generate_answer(prompt))
369
-
370
- # Parse JSON response
371
- try:
372
- result = json.loads(response)
373
- except json.JSONDecodeError:
374
- json_match = re.search(r"\{.*\}", response, re.DOTALL)
375
- if json_match:
376
- result = json.loads(json_match.group(0))
377
- else:
378
- logger.warning(
379
- f"Failed to parse relation conflict response for {src_id}->{dst_id}"
380
- )
381
- return {"has_conflict": False}
382
-
383
- return result
384
- except Exception as e:
385
- logger.error(
386
- f"Error checking relation consistency for {src_id}->{dst_id}: {e}"
387
- )
388
- return {"has_conflict": False}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/models/evaluator/kg/structure_evaluator.py CHANGED
@@ -4,49 +4,49 @@ from typing import Any, Dict, Optional
4
  import numpy as np
5
  from scipy import stats
6
 
7
- from graphgen.bases import BaseGraphStorage
8
  from graphgen.utils import logger
9
 
10
 
11
- class StructureEvaluator:
12
  """Evaluates structural robustness of the graph."""
13
 
14
  def __init__(
15
  self,
16
- graph_storage: BaseGraphStorage,
17
  noise_ratio_threshold: float = 0.15,
18
  largest_cc_ratio_threshold: float = 0.90,
19
  avg_degree_min: float = 2.0,
20
  avg_degree_max: float = 5.0,
21
  powerlaw_r2_threshold: float = 0.75,
22
  ):
23
- self.graph_storage = graph_storage
24
  self.noise_ratio_threshold = noise_ratio_threshold
25
  self.largest_cc_ratio_threshold = largest_cc_ratio_threshold
26
  self.avg_degree_min = avg_degree_min
27
  self.avg_degree_max = avg_degree_max
28
  self.powerlaw_r2_threshold = powerlaw_r2_threshold
29
 
30
- def evaluate(self) -> Dict[str, Any]:
31
  """
32
  Evaluate the structural robustness of the graph.
33
- :return:
 
 
 
 
 
 
 
34
  """
35
- storage = self.graph_storage
36
-
37
- total_nodes = storage.get_node_count()
38
- if total_nodes == 0:
39
- return {"error": "Empty graph"}
40
-
41
- total_edges = storage.get_edge_count()
42
- degree_map = storage.get_all_node_degrees()
43
 
44
  # Noise ratio: isolated nodes / total nodes
45
  isolated_nodes = [nid for nid, deg in degree_map.items() if deg == 0]
46
  noise_ratio = len(isolated_nodes) / total_nodes
47
 
48
  # Largest connected component
49
- components = storage.get_connected_components(undirected=True)
50
  largest_cc_ratio = (
51
  len(max(components, key=len)) / total_nodes if components else 0
52
  )
 
4
  import numpy as np
5
  from scipy import stats
6
 
7
+ from graphgen.bases import BaseGraphStorage, BaseKGEvaluator
8
  from graphgen.utils import logger
9
 
10
 
11
+ class StructureEvaluator(BaseKGEvaluator):
12
  """Evaluates structural robustness of the graph."""
13
 
14
  def __init__(
15
  self,
 
16
  noise_ratio_threshold: float = 0.15,
17
  largest_cc_ratio_threshold: float = 0.90,
18
  avg_degree_min: float = 2.0,
19
  avg_degree_max: float = 5.0,
20
  powerlaw_r2_threshold: float = 0.75,
21
  ):
 
22
  self.noise_ratio_threshold = noise_ratio_threshold
23
  self.largest_cc_ratio_threshold = largest_cc_ratio_threshold
24
  self.avg_degree_min = avg_degree_min
25
  self.avg_degree_max = avg_degree_max
26
  self.powerlaw_r2_threshold = powerlaw_r2_threshold
27
 
28
+ def evaluate(self, kg: BaseGraphStorage) -> Dict[str, Any]:
29
  """
30
  Evaluate the structural robustness of the graph.
31
+ :return: Dictionary of structural metrics and robustness verdict. The keys include:
32
+ - total_nodes: Total number of nodes in the graph
33
+ - total_edges: Total number of edges in the graph
34
+ - noise_ratio: Ratio of isolated nodes to total nodes
35
+ - largest_cc_ratio: Ratio of largest connected component size to total nodes
36
+ - avg_degree: Average node degree
37
+ - powerlaw_r2: R² value of power law fit to degree distribution
38
+ - is_robust: Boolean indicating if the graph is structurally robust
39
  """
40
+ total_nodes = kg.get_node_count()
41
+ total_edges = kg.get_edge_count()
42
+ degree_map = kg.get_all_node_degrees()
 
 
 
 
 
43
 
44
  # Noise ratio: isolated nodes / total nodes
45
  isolated_nodes = [nid for nid, deg in degree_map.items() if deg == 0]
46
  noise_ratio = len(isolated_nodes) / total_nodes
47
 
48
  # Largest connected component
49
+ components = kg.get_connected_components(undirected=True)
50
  largest_cc_ratio = (
51
  len(max(components, key=len)) / total_nodes if components else 0
52
  )
graphgen/models/evaluator/qa/length_evaluator.py CHANGED
@@ -1,18 +1,19 @@
1
-
2
  import os
3
- from graphgen.bases import BaseEvaluator, QAPair
 
4
  from graphgen.models.tokenizer import Tokenizer
5
 
6
 
7
- class LengthEvaluator(BaseEvaluator):
8
  def __init__(self, tokenizer_name: str = None):
9
- tokenizer_model = tokenizer_name or os.environ.get("TOKENIZER_MODEL", "cl100k_base")
 
 
10
  self.tokenizer: Tokenizer = Tokenizer(tokenizer_model)
11
 
12
- def evaluate(self, pair: QAPair) -> float:
13
  """
14
  Evaluate the length of the qa pair.
15
  """
16
  content = pair.question + pair.answer
17
- tokens = self.tokenizer.encode(content)
18
- return len(tokens)
 
 
1
  import os
2
+
3
+ from graphgen.bases import BaseQAEvaluator, QAPair
4
  from graphgen.models.tokenizer import Tokenizer
5
 
6
 
7
+ class LengthEvaluator(BaseQAEvaluator):
8
  def __init__(self, tokenizer_name: str = None):
9
+ tokenizer_model = tokenizer_name or os.environ.get(
10
+ "TOKENIZER_MODEL", "cl100k_base"
11
+ )
12
  self.tokenizer: Tokenizer = Tokenizer(tokenizer_model)
13
 
14
+ async def evaluate(self, pair: QAPair) -> dict[str, float]:
15
  """
16
  Evaluate the length of the qa pair.
17
  """
18
  content = pair.question + pair.answer
19
+ return {"length": self.tokenizer.count_tokens(content)}
 
graphgen/models/evaluator/qa/mtld_evaluator.py CHANGED
@@ -1,10 +1,10 @@
1
  from typing import Set
2
 
3
- from graphgen.bases import BaseEvaluator, QAPair
4
  from graphgen.utils import NLTKHelper, detect_main_language
5
 
6
 
7
- class MTLDEvaluator(BaseEvaluator):
8
  """
9
  Metrics for measuring the lexical diversity of text.
10
  """
@@ -15,7 +15,7 @@ class MTLDEvaluator(BaseEvaluator):
15
  self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("zh"))
16
  self.threshold = threshold
17
 
18
- def evaluate(self, pair: QAPair) -> float:
19
  """
20
  Calculate the MTLD (Mean Token Length Diversity) score for a given text.
21
 
@@ -24,7 +24,7 @@ class MTLDEvaluator(BaseEvaluator):
24
  """
25
  text = pair.answer
26
  if not text or not text.strip():
27
- return 0.0
28
 
29
  lang = detect_main_language(text)
30
  tokens = self.nltk_helper.word_tokenize(text, lang)
@@ -34,7 +34,7 @@ class MTLDEvaluator(BaseEvaluator):
34
  filtered_tokens = [word for word in filtered_tokens if word.isalnum()]
35
 
36
  if not filtered_tokens:
37
- return 0
38
 
39
  # Compute forward factors
40
  forward_factors = self._compute_factors(filtered_tokens, self.threshold)
@@ -43,7 +43,8 @@ class MTLDEvaluator(BaseEvaluator):
43
  backward_factors = self._compute_factors(filtered_tokens[::-1], self.threshold)
44
 
45
  # Compute average factors
46
- return (forward_factors + backward_factors) / 2
 
47
 
48
  @staticmethod
49
  def _compute_factors(tokens: list, threshold: float) -> float:
 
1
  from typing import Set
2
 
3
+ from graphgen.bases import BaseQAEvaluator, QAPair
4
  from graphgen.utils import NLTKHelper, detect_main_language
5
 
6
 
7
+ class MTLDEvaluator(BaseQAEvaluator):
8
  """
9
  Metrics for measuring the lexical diversity of text.
10
  """
 
15
  self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("zh"))
16
  self.threshold = threshold
17
 
18
+ async def evaluate(self, pair: QAPair) -> dict[str, float]:
19
  """
20
  Calculate the MTLD (Mean Token Length Diversity) score for a given text.
21
 
 
24
  """
25
  text = pair.answer
26
  if not text or not text.strip():
27
+ return {"mtld": 0}
28
 
29
  lang = detect_main_language(text)
30
  tokens = self.nltk_helper.word_tokenize(text, lang)
 
34
  filtered_tokens = [word for word in filtered_tokens if word.isalnum()]
35
 
36
  if not filtered_tokens:
37
+ return {"mtld": 0}
38
 
39
  # Compute forward factors
40
  forward_factors = self._compute_factors(filtered_tokens, self.threshold)
 
43
  backward_factors = self._compute_factors(filtered_tokens[::-1], self.threshold)
44
 
45
  # Compute average factors
46
+ mtld_score = (forward_factors + backward_factors) / 2
47
+ return {"mtld": mtld_score}
48
 
49
  @staticmethod
50
  def _compute_factors(tokens: list, threshold: float) -> float:
graphgen/models/evaluator/qa/reward_evaluator.py CHANGED
@@ -1,8 +1,9 @@
1
  from typing import Optional
2
- from graphgen.bases import BaseEvaluator, QAPair
3
 
 
4
 
5
- class RewardEvaluator(BaseEvaluator):
 
6
  """
7
  Reward Model Evaluator for single QAPair evaluation.
8
  """
@@ -15,7 +16,7 @@ class RewardEvaluator(BaseEvaluator):
15
  ):
16
  """
17
  Initialize the reward evaluator.
18
-
19
  Args:
20
  reward_name: Model name or path on HuggingFace Hub
21
  max_length: Maximum token length for the model
@@ -26,6 +27,7 @@ class RewardEvaluator(BaseEvaluator):
26
 
27
  import torch
28
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
29
  self.torch = torch
30
 
31
  # Set device (auto-detect if not specified)
@@ -37,15 +39,17 @@ class RewardEvaluator(BaseEvaluator):
37
  self.model.to(self.device)
38
  self.model.eval()
39
  except Exception as e:
40
- raise RuntimeError(f"Failed to load reward model '{reward_name}': {e}") from e
 
 
41
 
42
- def evaluate(self, pair: QAPair) -> float:
43
  """
44
  Evaluate a single question-answer pair using the reward model.
45
-
46
  Args:
47
  pair: QAPair containing question and answer strings
48
-
49
  Returns:
50
  Score as a float
51
  """
@@ -63,4 +67,4 @@ class RewardEvaluator(BaseEvaluator):
63
  with self.torch.no_grad():
64
  score = self.model(**inputs).logits[0].item()
65
 
66
- return score
 
1
  from typing import Optional
 
2
 
3
+ from graphgen.bases import BaseQAEvaluator, QAPair
4
 
5
+
6
+ class RewardEvaluator(BaseQAEvaluator):
7
  """
8
  Reward Model Evaluator for single QAPair evaluation.
9
  """
 
16
  ):
17
  """
18
  Initialize the reward evaluator.
19
+
20
  Args:
21
  reward_name: Model name or path on HuggingFace Hub
22
  max_length: Maximum token length for the model
 
27
 
28
  import torch
29
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
30
+
31
  self.torch = torch
32
 
33
  # Set device (auto-detect if not specified)
 
39
  self.model.to(self.device)
40
  self.model.eval()
41
  except Exception as e:
42
+ raise RuntimeError(
43
+ f"Failed to load reward model '{reward_name}': {e}"
44
+ ) from e
45
 
46
+ async def evaluate(self, pair: QAPair) -> dict[str, float]:
47
  """
48
  Evaluate a single question-answer pair using the reward model.
49
+
50
  Args:
51
  pair: QAPair containing question and answer strings
52
+
53
  Returns:
54
  Score as a float
55
  """
 
67
  with self.torch.no_grad():
68
  score = self.model(**inputs).logits[0].item()
69
 
70
+ return {"reward_score": score}
graphgen/models/evaluator/qa/uni_evaluator.py CHANGED
@@ -1,14 +1,15 @@
1
  # https://github.com/maszhongming/UniEval/tree/main
2
- from typing import Optional, List
3
- from graphgen.bases import BaseEvaluator, QAPair
4
 
 
5
 
6
- class UniEvaluator(BaseEvaluator):
 
7
  """
8
  UniEvaluator for single QAPair evaluation across quality dimensions.
9
-
10
  Dimensions: naturalness, coherence, understandability
11
-
12
  Usage:
13
  evaluator = UniEvaluator()
14
  pair = QAPair(question="...", answer="...")
@@ -34,6 +35,7 @@ class UniEvaluator(BaseEvaluator):
34
  """
35
  import torch
36
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
37
  self.torch = torch
38
 
39
  self.model_name = model_name or self.DEFAULT_MODEL
@@ -58,10 +60,12 @@ class UniEvaluator(BaseEvaluator):
58
  if dimension == "coherence":
59
  return f"question: Is this a coherent response? </s> response: {answer} </s> history: {question}"
60
  if dimension == "understandability":
61
- return f"question: Is this an understandable response? </s> response: {answer}"
 
 
62
  raise NotImplementedError(f"Unsupported dimension '{dimension}'")
63
 
64
- def evaluate(
65
  self,
66
  pair: QAPair,
67
  dimensions: Optional[List[str]] = None,
@@ -72,7 +76,9 @@ class UniEvaluator(BaseEvaluator):
72
  # Validate dimensions
73
  invalid = set(dimensions) - set(self.DEFAULT_DIMS)
74
  if invalid:
75
- raise ValueError(f"Invalid dimensions: {invalid}. Available: {self.DEFAULT_DIMS}")
 
 
76
 
77
  results = {}
78
  no_token = self.torch.tensor([[self._no_id]], device=self.device)
@@ -95,7 +101,9 @@ class UniEvaluator(BaseEvaluator):
95
  attention_mask=src_mask,
96
  labels=no_token,
97
  use_cache=False,
98
- ).logits[:, 0, :] # [1, vocab_size]
 
 
99
 
100
  probs = self.torch.softmax(logits, dim=-1)[0]
101
  score = probs[self._yes_id] / (probs[self._yes_id] + probs[self._no_id])
 
1
  # https://github.com/maszhongming/UniEval/tree/main
2
+ from typing import List, Optional
 
3
 
4
+ from graphgen.bases import BaseQAEvaluator, QAPair
5
 
6
+
7
+ class UniEvaluator(BaseQAEvaluator):
8
  """
9
  UniEvaluator for single QAPair evaluation across quality dimensions.
10
+
11
  Dimensions: naturalness, coherence, understandability
12
+
13
  Usage:
14
  evaluator = UniEvaluator()
15
  pair = QAPair(question="...", answer="...")
 
35
  """
36
  import torch
37
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
38
+
39
  self.torch = torch
40
 
41
  self.model_name = model_name or self.DEFAULT_MODEL
 
60
  if dimension == "coherence":
61
  return f"question: Is this a coherent response? </s> response: {answer} </s> history: {question}"
62
  if dimension == "understandability":
63
+ return (
64
+ f"question: Is this an understandable response? </s> response: {answer}"
65
+ )
66
  raise NotImplementedError(f"Unsupported dimension '{dimension}'")
67
 
68
+ async def evaluate(
69
  self,
70
  pair: QAPair,
71
  dimensions: Optional[List[str]] = None,
 
76
  # Validate dimensions
77
  invalid = set(dimensions) - set(self.DEFAULT_DIMS)
78
  if invalid:
79
+ raise ValueError(
80
+ f"Invalid dimensions: {invalid}. Available: {self.DEFAULT_DIMS}"
81
+ )
82
 
83
  results = {}
84
  no_token = self.torch.tensor([[self._no_id]], device=self.device)
 
101
  attention_mask=src_mask,
102
  labels=no_token,
103
  use_cache=False,
104
+ ).logits[
105
+ :, 0, :
106
+ ] # [1, vocab_size]
107
 
108
  probs = self.torch.softmax(logits, dim=-1)[0]
109
  score = probs[self._yes_id] / (probs[self._yes_id] + probs[self._no_id])
graphgen/models/evaluator/triple/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .accuracy_evaluator import AccuracyEvaluator
graphgen/models/evaluator/triple/accuracy_evaluator.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from typing import Any, Dict
4
+
5
+ from graphgen.bases import BaseLLMWrapper, BaseTripleEvaluator
6
+ from graphgen.templates import ACCURACY_EVALUATION_PROMPT
7
+ from graphgen.utils import detect_main_language, logger
8
+
9
+
10
+ class AccuracyEvaluator(BaseTripleEvaluator):
11
+ """Evaluates accuracy of entity recognition and relation extraction using LLM-as-a-Judge.
12
+
13
+ For each chunk, uses LLM to evaluate the quality of extracted entities and relations
14
+ by comparing them with the original chunk content. Provides multi-dimensional quality
15
+ scores (accuracy, completeness, precision).
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ llm_client: BaseLLMWrapper,
21
+ ):
22
+ self.llm_client = llm_client
23
+
24
+ async def evaluate(self, unit: tuple) -> Dict[str, Any]:
25
+ """Evaluate entity and relation extraction quality using LLM-as-a-Judge.
26
+
27
+ Returns:
28
+ Dictionary containing entity_accuracy and relation_accuracy metrics.
29
+ """
30
+ chunk_content, nodes, edges = unit
31
+ lang = detect_main_language(chunk_content)
32
+
33
+ # node
34
+ prompt = ACCURACY_EVALUATION_PROMPT[lang]["ENTITY"].format(
35
+ chunk_content=chunk_content,
36
+ extracted_entities=json.dumps(nodes, ensure_ascii=False, indent=2),
37
+ )
38
+
39
+ response = await self.llm_client.generate_answer(prompt)
40
+
41
+ # Try to parse JSON response
42
+ try:
43
+ node_evaluation_result = json.loads(response)
44
+ except json.JSONDecodeError:
45
+ # Try to extract JSON from markdown code blocks or other formats
46
+ json_match = re.search(r"\{.*\}", response, re.DOTALL)
47
+ if json_match:
48
+ node_evaluation_result = json.loads(json_match.group(0))
49
+ else:
50
+ logger.warning("Failed to parse LLM response.")
51
+ # default evaluation
52
+ node_evaluation_result = {
53
+ "accuracy": 0.0,
54
+ "completeness": 0.0,
55
+ "precision": 0.0,
56
+ "overall_score": 0.0,
57
+ "accuracy_reasoning": "Failed to parse LLM response",
58
+ "completeness_reasoning": "",
59
+ "precision_reasoning": "",
60
+ "issues": ["LLM response parsing failed"],
61
+ }
62
+
63
+ # edge
64
+ prompt = ACCURACY_EVALUATION_PROMPT[lang]["RELATION"].format(
65
+ chunk_content=chunk_content,
66
+ extracted_relations=json.dumps(edges, ensure_ascii=False, indent=2),
67
+ )
68
+ response = await self.llm_client.generate_answer(prompt)
69
+ # Try to parse JSON response
70
+ try:
71
+ edge_evaluation_result = json.loads(response)
72
+ except json.JSONDecodeError:
73
+ # Try to extract JSON from markdown code blocks or other formats
74
+ json_match = re.search(r"\{.*\}", response, re.DOTALL)
75
+ if json_match:
76
+ edge_evaluation_result = json.loads(json_match.group(0))
77
+ else:
78
+ logger.warning("Failed to parse LLM response.")
79
+ # default evaluation
80
+ edge_evaluation_result = {
81
+ "accuracy": 0.0,
82
+ "completeness": 0.0,
83
+ "precision": 0.0,
84
+ "overall_score": 0.0,
85
+ "accuracy_reasoning": "Failed to parse LLM response",
86
+ "completeness_reasoning": "",
87
+ "precision_reasoning": "",
88
+ "issues": ["LLM response parsing failed"],
89
+ }
90
+
91
+ return {
92
+ "entity_accuracy": node_evaluation_result,
93
+ "relation_accuracy": edge_evaluation_result,
94
+ }
graphgen/models/extractor/schema_guided_extractor.py CHANGED
@@ -1,9 +1,8 @@
1
  import json
2
- from typing import Dict, List
3
 
4
- from graphgen.bases import BaseExtractor, BaseLLMWrapper
5
  from graphgen.templates import SCHEMA_GUIDED_EXTRACTION_PROMPT
6
- from graphgen.utils import compute_dict_hash, detect_main_language, logger
7
 
8
 
9
  class SchemaGuidedExtractor(BaseExtractor):
@@ -59,9 +58,8 @@ class SchemaGuidedExtractor(BaseExtractor):
59
  )
60
  return prompt
61
 
62
- async def extract(self, chunk: dict) -> dict:
63
- _chunk_id = chunk.get("_chunk_id", "")
64
- text = chunk.get("content", "")
65
 
66
  prompt = self.build_prompt(text)
67
  response = await self.llm_client.generate_answer(prompt)
@@ -74,35 +72,9 @@ class SchemaGuidedExtractor(BaseExtractor):
74
  if any(extracted_info[key] == "" for key in self.required_keys):
75
  logger.debug("Missing required keys in extraction: %s", extracted_info)
76
  return {}
77
- main_keys_info = {key: extracted_info[key] for key in self.required_keys}
78
  logger.debug("Extracted info: %s", extracted_info)
 
79
 
80
- # add chunk metadata
81
- extracted_info["_chunk_id"] = _chunk_id
82
-
83
- return {
84
- compute_dict_hash(main_keys_info, prefix="extract-"): extracted_info
85
- }
86
  except json.JSONDecodeError:
87
  logger.error("Failed to parse extraction response: %s", response)
88
  return {}
89
-
90
- @staticmethod
91
- def merge_extractions(extraction_list: List[Dict[str, dict]]) -> Dict[str, dict]:
92
- """
93
- Merge multiple extraction results based on their hashes.
94
- :param extraction_list: List of extraction results, each is a dict with hash as key and record as value.
95
- :return: Merged extraction results.
96
- """
97
- merged: Dict[str, dict] = {}
98
- for ext in extraction_list:
99
- for h, rec in ext.items():
100
- if h not in merged:
101
- merged[h] = rec.copy()
102
- else:
103
- for k, v in rec.items():
104
- if k not in merged[h] or merged[h][k] == v:
105
- merged[h][k] = v
106
- else:
107
- merged[h][k] = f"{merged[h][k]}<SEP>{v}"
108
- return merged
 
1
  import json
 
2
 
3
+ from graphgen.bases import BaseExtractor, BaseLLMWrapper, Chunk
4
  from graphgen.templates import SCHEMA_GUIDED_EXTRACTION_PROMPT
5
+ from graphgen.utils import detect_main_language, logger
6
 
7
 
8
  class SchemaGuidedExtractor(BaseExtractor):
 
58
  )
59
  return prompt
60
 
61
+ async def extract(self, chunk: Chunk) -> dict:
62
+ text = chunk.content
 
63
 
64
  prompt = self.build_prompt(text)
65
  response = await self.llm_client.generate_answer(prompt)
 
72
  if any(extracted_info[key] == "" for key in self.required_keys):
73
  logger.debug("Missing required keys in extraction: %s", extracted_info)
74
  return {}
 
75
  logger.debug("Extracted info: %s", extracted_info)
76
+ return extracted_info
77
 
 
 
 
 
 
 
78
  except json.JSONDecodeError:
79
  logger.error("Failed to parse extraction response: %s", response)
80
  return {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/models/generator/aggregated_generator.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any, Optional
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import AGGREGATED_GENERATION_PROMPT
6
- from graphgen.utils import compute_content_hash, detect_main_language, logger
7
 
8
 
9
  class AggregatedGenerator(BaseGenerator):
@@ -101,30 +101,26 @@ class AggregatedGenerator(BaseGenerator):
101
  batch: tuple[
102
  list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
103
  ],
104
- ) -> dict[str, Any]:
105
  """
106
  Generate QAs based on a given batch.
107
  :param batch
108
  :return: QA pairs
109
  """
110
- result = {}
111
  rephrasing_prompt = self.build_prompt(batch)
112
  response = await self.llm_client.generate_answer(rephrasing_prompt)
113
  context = self.parse_rephrased_text(response)
114
  if not context:
115
- return result
116
  question_generation_prompt = self._build_prompt_for_question_generation(context)
117
  response = await self.llm_client.generate_answer(question_generation_prompt)
118
  question = self.parse_response(response)["question"]
119
  if not question:
120
- return result
121
  logger.debug("Question: %s", question)
122
  logger.debug("Answer: %s", context)
123
  qa_pairs = {
124
- compute_content_hash(question): {
125
- "question": question,
126
- "answer": context,
127
- }
128
  }
129
- result.update(qa_pairs)
130
- return result
 
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import AGGREGATED_GENERATION_PROMPT
6
+ from graphgen.utils import detect_main_language, logger
7
 
8
 
9
  class AggregatedGenerator(BaseGenerator):
 
101
  batch: tuple[
102
  list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
103
  ],
104
+ ) -> list[dict]:
105
  """
106
  Generate QAs based on a given batch.
107
  :param batch
108
  :return: QA pairs
109
  """
 
110
  rephrasing_prompt = self.build_prompt(batch)
111
  response = await self.llm_client.generate_answer(rephrasing_prompt)
112
  context = self.parse_rephrased_text(response)
113
  if not context:
114
+ return []
115
  question_generation_prompt = self._build_prompt_for_question_generation(context)
116
  response = await self.llm_client.generate_answer(question_generation_prompt)
117
  question = self.parse_response(response)["question"]
118
  if not question:
119
+ return []
120
  logger.debug("Question: %s", question)
121
  logger.debug("Answer: %s", context)
122
  qa_pairs = {
123
+ "question": question,
124
+ "answer": context,
 
 
125
  }
126
+ return [qa_pairs]
 
graphgen/models/generator/atomic_generator.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import ATOMIC_GENERATION_PROMPT
6
- from graphgen.utils import compute_content_hash, detect_main_language, logger
7
 
8
 
9
  class AtomicGenerator(BaseGenerator):
@@ -23,7 +23,7 @@ class AtomicGenerator(BaseGenerator):
23
  return prompt
24
 
25
  @staticmethod
26
- def parse_response(response: str) -> dict:
27
  """
28
  AtomicGenerator normally generates one QA pair per response.
29
  So we just need to parse one QA pair from the response.
@@ -38,15 +38,10 @@ class AtomicGenerator(BaseGenerator):
38
  answer = answer_match.group(1).strip()
39
  else:
40
  logger.warning("Failed to parse response: %s", response)
41
- return {}
42
 
43
  question = question.strip('"').strip("'")
44
  answer = answer.strip('"').strip("'")
45
  logger.debug("Question: %s", question)
46
  logger.debug("Answer: %s", answer)
47
- return {
48
- compute_content_hash(question): {
49
- "question": question,
50
- "answer": answer,
51
- }
52
- }
 
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import ATOMIC_GENERATION_PROMPT
6
+ from graphgen.utils import detect_main_language, logger
7
 
8
 
9
  class AtomicGenerator(BaseGenerator):
 
23
  return prompt
24
 
25
  @staticmethod
26
+ def parse_response(response: str) -> list[dict]:
27
  """
28
  AtomicGenerator normally generates one QA pair per response.
29
  So we just need to parse one QA pair from the response.
 
38
  answer = answer_match.group(1).strip()
39
  else:
40
  logger.warning("Failed to parse response: %s", response)
41
+ return []
42
 
43
  question = question.strip('"').strip("'")
44
  answer = answer.strip('"').strip("'")
45
  logger.debug("Question: %s", question)
46
  logger.debug("Answer: %s", answer)
47
+ return [{"question": question, "answer": answer}]
 
 
 
 
 
graphgen/models/generator/cot_generator.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import COT_GENERATION_PROMPT
6
- from graphgen.utils import compute_content_hash, detect_main_language, logger
7
 
8
 
9
  class CoTGenerator(BaseGenerator):
@@ -100,28 +100,25 @@ class CoTGenerator(BaseGenerator):
100
  batch: tuple[
101
  list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
102
  ],
103
- ) -> dict[str, Any]:
104
  """
105
  Generate QAs based on a given batch.
106
  :param batch
107
  :return: QA pairs
108
  """
109
- result = {}
110
  prompt = self.build_prompt(batch)
111
  response = await self.llm_client.generate_answer(prompt)
112
  response = self.parse_response(response)
113
  if not response:
114
- return result
115
  question, reasoning_path = response["question"], response["reasoning_path"]
116
  prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path)
117
  cot_answer = await self.llm_client.generate_answer(prompt)
118
  logger.debug("CoT Answer: %s", cot_answer)
119
- qa_pairs = {
120
- compute_content_hash(question): {
121
  "question": question,
122
  "answer": cot_answer,
123
  "reasoning_path": reasoning_path,
124
  }
125
- }
126
- result.update(qa_pairs)
127
- return result
 
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import COT_GENERATION_PROMPT
6
+ from graphgen.utils import detect_main_language, logger
7
 
8
 
9
  class CoTGenerator(BaseGenerator):
 
100
  batch: tuple[
101
  list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
102
  ],
103
+ ) -> list[dict]:
104
  """
105
  Generate QAs based on a given batch.
106
  :param batch
107
  :return: QA pairs
108
  """
 
109
  prompt = self.build_prompt(batch)
110
  response = await self.llm_client.generate_answer(prompt)
111
  response = self.parse_response(response)
112
  if not response:
113
+ return []
114
  question, reasoning_path = response["question"], response["reasoning_path"]
115
  prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path)
116
  cot_answer = await self.llm_client.generate_answer(prompt)
117
  logger.debug("CoT Answer: %s", cot_answer)
118
+ return [
119
+ {
120
  "question": question,
121
  "answer": cot_answer,
122
  "reasoning_path": reasoning_path,
123
  }
124
+ ]
 
 
graphgen/models/generator/fill_in_blank_generator.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import FILL_IN_BLANK_GENERATION_PROMPT
6
- from graphgen.utils import compute_content_hash, detect_main_language, logger
7
 
8
 
9
  class FillInBlankGenerator(BaseGenerator):
@@ -12,7 +12,7 @@ class FillInBlankGenerator(BaseGenerator):
12
  self.num_of_questions = num_of_questions
13
 
14
  @staticmethod
15
- def parse_response(response: str) -> Any:
16
  """
17
  Parse fill-in-the-blank QA pairs from the LLM response.
18
  Each QA pair contains question text with placeholders and the correct answer(s).
@@ -21,14 +21,14 @@ class FillInBlankGenerator(BaseGenerator):
21
  :return: Dictionary mapping question hash to question data, where each
22
  value is a dict with "question", "answer", and "answers" keys
23
  """
24
- qa_pairs = {}
25
 
26
  # Extract all QA pair blocks
27
  qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
28
 
29
  if not qa_blocks:
30
  logger.warning("No QA pairs found in response: %s", response)
31
- return {}
32
 
33
  for block in qa_blocks:
34
  # Extract and clean question text
@@ -55,13 +55,13 @@ class FillInBlankGenerator(BaseGenerator):
55
  logger.warning("No valid answers found in: %s", answer_text)
56
  continue
57
 
58
- # Build result entry with question hash as key
59
- question_hash = compute_content_hash(question)
60
- qa_pairs[question_hash] = {
61
- "question": question,
62
- "answer": answer_text, # Original answer text with commas
63
- "answers": answers, # List of individual answers: ["A8X"] or ["A8X", "八百万"]
64
- }
65
 
66
  logger.debug(
67
  "Successfully parsed fill-in-the-blank question: %s", question[:50]
 
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import FILL_IN_BLANK_GENERATION_PROMPT
6
+ from graphgen.utils import detect_main_language, logger
7
 
8
 
9
  class FillInBlankGenerator(BaseGenerator):
 
12
  self.num_of_questions = num_of_questions
13
 
14
  @staticmethod
15
+ def parse_response(response: str) -> list[dict]:
16
  """
17
  Parse fill-in-the-blank QA pairs from the LLM response.
18
  Each QA pair contains question text with placeholders and the correct answer(s).
 
21
  :return: Dictionary mapping question hash to question data, where each
22
  value is a dict with "question", "answer", and "answers" keys
23
  """
24
+ qa_pairs = []
25
 
26
  # Extract all QA pair blocks
27
  qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
28
 
29
  if not qa_blocks:
30
  logger.warning("No QA pairs found in response: %s", response)
31
+ return qa_pairs
32
 
33
  for block in qa_blocks:
34
  # Extract and clean question text
 
55
  logger.warning("No valid answers found in: %s", answer_text)
56
  continue
57
 
58
+ qa_pairs.append(
59
+ {
60
+ "question": question,
61
+ "answer": answer_text, # Original answer text with commas
62
+ "answers": answers, # List of individual answers: ["A8X"] or ["A8X", "八百万"]
63
+ }
64
+ )
65
 
66
  logger.debug(
67
  "Successfully parsed fill-in-the-blank question: %s", question[:50]
graphgen/models/generator/multi_answer_generator.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import MAQ_GENERATION_PROMPT
6
- from graphgen.utils import compute_content_hash, detect_main_language, logger
7
 
8
 
9
  class MultiAnswerGenerator(BaseGenerator):
@@ -12,7 +12,7 @@ class MultiAnswerGenerator(BaseGenerator):
12
  self.num_of_questions = num_of_questions
13
 
14
  @staticmethod
15
- def parse_response(response: str) -> Any:
16
  """
17
  Parse multiple-answer QA pairs from the LLM response.
18
  Each QA pair contains question text, four options, and the correct answers (one or more).
@@ -21,14 +21,14 @@ class MultiAnswerGenerator(BaseGenerator):
21
  :return: Dictionary mapping question hash to question data, where each
22
  value is a dict with "question", "options", and "answer" keys
23
  """
24
- qa_pairs = {}
25
 
26
  # Extract all QA pair blocks
27
  qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
28
 
29
  if not qa_blocks:
30
  logger.warning("No QA pairs found in response: %s", response)
31
- return {}
32
 
33
  for block in qa_blocks:
34
  # Extract and clean question text
@@ -61,7 +61,9 @@ class MultiAnswerGenerator(BaseGenerator):
61
  logger.warning("Failed to parse answer from block: %s", block)
62
  continue
63
  answer_text = ans_match.group(1).strip().strip('"').strip("'")
64
- answers = [ans.strip().upper() for ans in answer_text.split(",") if ans.strip()]
 
 
65
  invalid_answers = [ans for ans in answers if ans not in options]
66
  if invalid_answers:
67
  logger.warning(
@@ -76,13 +78,13 @@ class MultiAnswerGenerator(BaseGenerator):
76
  logger.warning("No valid answers found in: %s", answer_text)
77
  continue
78
 
79
- # Build result entry with question hash as key
80
- question_hash = compute_content_hash(question)
81
- qa_pairs[question_hash] = {
82
- "question": question,
83
- "options": options, # Dict like {"A": "text", "B": "text", ...}
84
- "answer": ", ".join(answers),
85
- }
86
 
87
  logger.debug("Successfully parsed MAQ: %s", question[:50])
88
 
 
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import MAQ_GENERATION_PROMPT
6
+ from graphgen.utils import detect_main_language, logger
7
 
8
 
9
  class MultiAnswerGenerator(BaseGenerator):
 
12
  self.num_of_questions = num_of_questions
13
 
14
  @staticmethod
15
+ def parse_response(response: str) -> list[dict]:
16
  """
17
  Parse multiple-answer QA pairs from the LLM response.
18
  Each QA pair contains question text, four options, and the correct answers (one or more).
 
21
  :return: Dictionary mapping question hash to question data, where each
22
  value is a dict with "question", "options", and "answer" keys
23
  """
24
+ qa_pairs = []
25
 
26
  # Extract all QA pair blocks
27
  qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
28
 
29
  if not qa_blocks:
30
  logger.warning("No QA pairs found in response: %s", response)
31
+ return qa_pairs
32
 
33
  for block in qa_blocks:
34
  # Extract and clean question text
 
61
  logger.warning("Failed to parse answer from block: %s", block)
62
  continue
63
  answer_text = ans_match.group(1).strip().strip('"').strip("'")
64
+ answers = [
65
+ ans.strip().upper() for ans in answer_text.split(",") if ans.strip()
66
+ ]
67
  invalid_answers = [ans for ans in answers if ans not in options]
68
  if invalid_answers:
69
  logger.warning(
 
78
  logger.warning("No valid answers found in: %s", answer_text)
79
  continue
80
 
81
+ qa_pairs.append(
82
+ {
83
+ "question": question,
84
+ "options": options, # Dict like {"A": "text", "B": "text", ...}
85
+ "answers": answers, # List of correct answers: ["A", "C"]
86
+ }
87
+ )
88
 
89
  logger.debug("Successfully parsed MAQ: %s", question[:50])
90
 
graphgen/models/generator/multi_choice_generator.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import MCQ_GENERATION_PROMPT
6
- from graphgen.utils import compute_content_hash, detect_main_language, logger
7
 
8
 
9
  class MultiChoiceGenerator(BaseGenerator):
@@ -12,7 +12,7 @@ class MultiChoiceGenerator(BaseGenerator):
12
  self.num_of_questions = num_of_questions
13
 
14
  @staticmethod
15
- def parse_response(response: str) -> Any:
16
  """
17
  Parse multiple choice QA pairs from the LLM response.
18
  Each QA pair contains question text, four options, and the correct answer.
@@ -21,14 +21,14 @@ class MultiChoiceGenerator(BaseGenerator):
21
  :return: Dictionary mapping question hash to question data, where each
22
  value is a dict with "question", "options", and "answer" keys
23
  """
24
- qa_pairs = {}
25
 
26
  # Extract all QA pair blocks
27
  qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
28
 
29
  if not qa_blocks:
30
  logger.warning("No QA pairs found in response: %s", response)
31
- return {}
32
 
33
  for block in qa_blocks:
34
  # Extract and clean question text
@@ -76,13 +76,13 @@ class MultiChoiceGenerator(BaseGenerator):
76
  )
77
  continue
78
 
79
- # Build result entry with question hash as key
80
- question_hash = compute_content_hash(question)
81
- qa_pairs[question_hash] = {
82
- "question": question,
83
- "options": options, # Dict like {"A": "text", "B": "text", ...}
84
- "answer": answer, # Single letter: "A", "B", "C", or "D"
85
- }
86
 
87
  logger.debug("Successfully parsed MCQ: %s", question[:50])
88
 
 
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import MCQ_GENERATION_PROMPT
6
+ from graphgen.utils import detect_main_language, logger
7
 
8
 
9
  class MultiChoiceGenerator(BaseGenerator):
 
12
  self.num_of_questions = num_of_questions
13
 
14
  @staticmethod
15
+ def parse_response(response: str) -> list[dict]:
16
  """
17
  Parse multiple choice QA pairs from the LLM response.
18
  Each QA pair contains question text, four options, and the correct answer.
 
21
  :return: Dictionary mapping question hash to question data, where each
22
  value is a dict with "question", "options", and "answer" keys
23
  """
24
+ qa_pairs = []
25
 
26
  # Extract all QA pair blocks
27
  qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
28
 
29
  if not qa_blocks:
30
  logger.warning("No QA pairs found in response: %s", response)
31
+ return qa_pairs
32
 
33
  for block in qa_blocks:
34
  # Extract and clean question text
 
76
  )
77
  continue
78
 
79
+ qa_pairs.append(
80
+ {
81
+ "question": question,
82
+ "options": options, # Dict like {"A": "text", "B": "text", ...}
83
+ "answer": answer, # Single letter: "A", "B", "C", or "D"
84
+ }
85
+ )
86
 
87
  logger.debug("Successfully parsed MCQ: %s", question[:50])
88
 
graphgen/models/generator/multi_hop_generator.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import MULTI_HOP_GENERATION_PROMPT
6
- from graphgen.utils import compute_content_hash, detect_main_language, logger
7
 
8
 
9
  class MultiHopGenerator(BaseGenerator):
@@ -32,7 +32,7 @@ class MultiHopGenerator(BaseGenerator):
32
  return prompt
33
 
34
  @staticmethod
35
- def parse_response(response: str) -> dict:
36
  question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
37
  answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
38
 
@@ -41,15 +41,10 @@ class MultiHopGenerator(BaseGenerator):
41
  answer = answer_match.group(1).strip()
42
  else:
43
  logger.warning("Failed to parse response: %s", response)
44
- return {}
45
 
46
  question = question.strip('"').strip("'")
47
  answer = answer.strip('"').strip("'")
48
  logger.debug("Question: %s", question)
49
  logger.debug("Answer: %s", answer)
50
- return {
51
- compute_content_hash(question): {
52
- "question": question,
53
- "answer": answer,
54
- }
55
- }
 
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import MULTI_HOP_GENERATION_PROMPT
6
+ from graphgen.utils import detect_main_language, logger
7
 
8
 
9
  class MultiHopGenerator(BaseGenerator):
 
32
  return prompt
33
 
34
  @staticmethod
35
+ def parse_response(response: str) -> list[dict]:
36
  question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
37
  answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
38
 
 
41
  answer = answer_match.group(1).strip()
42
  else:
43
  logger.warning("Failed to parse response: %s", response)
44
+ return []
45
 
46
  question = question.strip('"').strip("'")
47
  answer = answer.strip('"').strip("'")
48
  logger.debug("Question: %s", question)
49
  logger.debug("Answer: %s", answer)
50
+ return [{"question": question, "answer": answer}]
 
 
 
 
 
graphgen/models/generator/quiz_generator.py CHANGED
@@ -31,12 +31,16 @@ class QuizGenerator(BaseGenerator):
31
  description = edges[0][2].get("description", "")
32
  template_type = edges[0][2].get("template_type", "TEMPLATE")
33
  else:
34
- raise ValueError("Batch must contain at least one node or edge with description")
 
 
35
 
36
  return QuizGenerator.build_prompt_for_description(description, template_type)
37
 
38
  @staticmethod
39
- def build_prompt_for_description(description: str, template_type: str = "TEMPLATE") -> str:
 
 
40
  """
41
  Build prompt for rephrasing a single description.
42
  :param description: The description to rephrase
@@ -49,17 +53,6 @@ class QuizGenerator(BaseGenerator):
49
  )
50
  return prompt
51
 
52
- @staticmethod
53
- def parse_rephrased_text(response: str) -> str:
54
- """
55
- Parse the rephrased text from the response.
56
- :param response:
57
- :return:
58
- """
59
- rephrased_text = response.strip().strip('"')
60
- logger.debug("Rephrased Text: %s", rephrased_text)
61
- return rephrased_text
62
-
63
  @staticmethod
64
  def parse_response(response: str) -> Any:
65
  """
@@ -67,4 +60,15 @@ class QuizGenerator(BaseGenerator):
67
  :param response: LLM response
68
  :return: Rephrased text
69
  """
70
- return QuizGenerator.parse_rephrased_text(response)
 
 
 
 
 
 
 
 
 
 
 
 
31
  description = edges[0][2].get("description", "")
32
  template_type = edges[0][2].get("template_type", "TEMPLATE")
33
  else:
34
+ raise ValueError(
35
+ "Batch must contain at least one node or edge with description"
36
+ )
37
 
38
  return QuizGenerator.build_prompt_for_description(description, template_type)
39
 
40
  @staticmethod
41
+ def build_prompt_for_description(
42
+ description: str, template_type: str = "TEMPLATE"
43
+ ) -> str:
44
  """
45
  Build prompt for rephrasing a single description.
46
  :param description: The description to rephrase
 
53
  )
54
  return prompt
55
 
 
 
 
 
 
 
 
 
 
 
 
56
  @staticmethod
57
  def parse_response(response: str) -> Any:
58
  """
 
60
  :param response: LLM response
61
  :return: Rephrased text
62
  """
63
+
64
+ def parse_rephrased_text(content: str) -> str:
65
+ """
66
+ Parse the rephrased text from the response.
67
+ :param content: LLM response content
68
+ :return:
69
+ """
70
+ rephrased_text = content.strip().strip('"')
71
+ logger.debug("Rephrased Text: %s", rephrased_text)
72
+ return rephrased_text
73
+
74
+ return parse_rephrased_text(response)
graphgen/models/generator/true_false_generator.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import TF_GENERATION_PROMPT
6
- from graphgen.utils import compute_content_hash, detect_main_language, logger
7
 
8
 
9
  class TrueFalseGenerator(BaseGenerator):
@@ -12,7 +12,7 @@ class TrueFalseGenerator(BaseGenerator):
12
  self.num_of_questions = num_of_questions
13
 
14
  @staticmethod
15
- def parse_response(response: str) -> Any:
16
  """
17
  Parse true/false QA pairs from the LLM response.
18
  Each QA pair contains a statement question and True/False answer.
@@ -21,14 +21,14 @@ class TrueFalseGenerator(BaseGenerator):
21
  :return: Dictionary mapping question hash to question data, where each
22
  value is a dict with "question", "options", and "answer" keys
23
  """
24
- qa_pairs: dict[str, dict[str, Any]] = {}
25
 
26
  # Extract all QA pair blocks
27
  qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
28
 
29
  if not qa_blocks:
30
  logger.warning("No QA pairs found in response: %s", response)
31
- return {}
32
 
33
  for block in qa_blocks:
34
  # Extract and clean question text
@@ -50,12 +50,12 @@ class TrueFalseGenerator(BaseGenerator):
50
  logger.warning("Invalid answer '%s' in block: %s", answer, block)
51
  continue
52
 
53
- # Build result entry with question hash as key
54
- question_hash = compute_content_hash(question)
55
- qa_pairs[question_hash] = {
56
- "question": question,
57
- "answer": answer, # "True" or "False"
58
- }
59
 
60
  logger.debug("Successfully parsed TF question: %s", question[:50])
61
 
 
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import TF_GENERATION_PROMPT
6
+ from graphgen.utils import detect_main_language, logger
7
 
8
 
9
  class TrueFalseGenerator(BaseGenerator):
 
12
  self.num_of_questions = num_of_questions
13
 
14
  @staticmethod
15
+ def parse_response(response: str) -> list[dict]:
16
  """
17
  Parse true/false QA pairs from the LLM response.
18
  Each QA pair contains a statement question and True/False answer.
 
21
  :return: Dictionary mapping question hash to question data, where each
22
  value is a dict with "question", "options", and "answer" keys
23
  """
24
+ qa_pairs: list[dict[str, str]] = []
25
 
26
  # Extract all QA pair blocks
27
  qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
28
 
29
  if not qa_blocks:
30
  logger.warning("No QA pairs found in response: %s", response)
31
+ return qa_pairs
32
 
33
  for block in qa_blocks:
34
  # Extract and clean question text
 
50
  logger.warning("Invalid answer '%s' in block: %s", answer, block)
51
  continue
52
 
53
+ qa_pairs.append(
54
+ {
55
+ "question": question,
56
+ "answer": answer, # "True" or "False"
57
+ }
58
+ )
59
 
60
  logger.debug("Successfully parsed TF question: %s", question[:50])
61
 
graphgen/models/generator/vqa_generator.py CHANGED
@@ -1,9 +1,10 @@
 
1
  import re
2
  from typing import Any
3
 
4
  from graphgen.bases import BaseGenerator
5
  from graphgen.templates import VQA_GENERATION_PROMPT
6
- from graphgen.utils import compute_content_hash, detect_main_language, logger
7
 
8
 
9
  class VQAGenerator(BaseGenerator):
@@ -32,13 +33,13 @@ class VQAGenerator(BaseGenerator):
32
  return prompt
33
 
34
  @staticmethod
35
- def parse_response(response: str) -> Any:
36
  """
37
  Parse the LLM response and return the generated QAs
38
  :param response
39
  :return: QA pairs
40
  """
41
- qa_pairs = {}
42
  pattern = r"<question>(.*?)</question>\s*<answer>(.*?)</answer>"
43
  matches = re.findall(pattern, response, re.DOTALL)
44
 
@@ -48,10 +49,12 @@ class VQAGenerator(BaseGenerator):
48
  answer = answer.strip().strip('"').strip("'")
49
  logger.debug("Question: %s", question)
50
  logger.debug("Answer: %s", answer)
51
- qa_pairs[compute_content_hash(question)] = {
52
- "question": question,
53
- "answer": answer,
54
- }
 
 
55
  else:
56
  logger.warning("Error parsing the response %s", response)
57
  return qa_pairs
@@ -61,76 +64,58 @@ class VQAGenerator(BaseGenerator):
61
  batch: tuple[
62
  list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
63
  ],
64
- ) -> dict[str, Any]:
65
  """
66
  Generate QAs based on a given batch.
67
  :param batch
68
  :return: QA pairs
69
  """
70
- result = {}
71
  prompt = self.build_prompt(batch)
72
  response = await self.llm_client.generate_answer(prompt)
73
  qa_pairs = self.parse_response(response) # generate one or more QA pairs
74
  nodes, _ = batch
75
  for node in nodes:
76
  node_data = node[1]
77
- if "image_data" in node_data and node_data["image_data"]:
78
- img_path = node_data["image_data"]["img_path"]
79
- for qa in qa_pairs.values():
 
80
  qa["img_path"] = img_path
81
- result.update(qa_pairs)
82
- return result
83
 
84
  @staticmethod
85
- def format_generation_results(
86
- results: list[dict], output_data_format: str
87
- ) -> list[dict[str, Any]]:
 
88
  if output_data_format == "Alpaca":
89
- results = [
90
- {
91
- "instruction": v["question"],
92
- "input": "",
93
- "output": v["answer"],
94
- "image": v.get("img_path", ""),
95
- }
96
- for item in results
97
- for k, v in item.items()
98
- ]
99
- elif output_data_format == "Sharegpt":
100
- results = [
101
- {
102
- "conversations": [
103
- {
104
- "from": "human",
105
- "value": [
106
- {"text": v["question"], "image": v.get("img_path", "")}
107
- ],
108
- },
109
- {"from": "gpt", "value": [{"text": v["answer"]}]},
110
- ]
111
- }
112
- for item in results
113
- for k, v in item.items()
114
- ]
115
- elif output_data_format == "ChatML":
116
- results = [
117
- {
118
- "messages": [
119
- {
120
- "role": "user",
121
- "content": [
122
- {"text": v["question"], "image": v.get("img_path", "")}
123
- ],
124
- },
125
- {
126
- "role": "assistant",
127
- "content": [{"type": "text", "text": v["answer"]}],
128
- },
129
- ]
130
- }
131
- for item in results
132
- for k, v in item.items()
133
- ]
134
- else:
135
- raise ValueError(f"Unknown output data format: {output_data_format}")
136
- return results
 
1
+ import json
2
  import re
3
  from typing import Any
4
 
5
  from graphgen.bases import BaseGenerator
6
  from graphgen.templates import VQA_GENERATION_PROMPT
7
+ from graphgen.utils import detect_main_language, logger
8
 
9
 
10
  class VQAGenerator(BaseGenerator):
 
33
  return prompt
34
 
35
  @staticmethod
36
+ def parse_response(response: str) -> list[dict]:
37
  """
38
  Parse the LLM response and return the generated QAs
39
  :param response
40
  :return: QA pairs
41
  """
42
+ qa_pairs = []
43
  pattern = r"<question>(.*?)</question>\s*<answer>(.*?)</answer>"
44
  matches = re.findall(pattern, response, re.DOTALL)
45
 
 
49
  answer = answer.strip().strip('"').strip("'")
50
  logger.debug("Question: %s", question)
51
  logger.debug("Answer: %s", answer)
52
+ qa_pairs.append(
53
+ {
54
+ "question": question,
55
+ "answer": answer,
56
+ }
57
+ )
58
  else:
59
  logger.warning("Error parsing the response %s", response)
60
  return qa_pairs
 
64
  batch: tuple[
65
  list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
66
  ],
67
+ ) -> list[dict]:
68
  """
69
  Generate QAs based on a given batch.
70
  :param batch
71
  :return: QA pairs
72
  """
 
73
  prompt = self.build_prompt(batch)
74
  response = await self.llm_client.generate_answer(prompt)
75
  qa_pairs = self.parse_response(response) # generate one or more QA pairs
76
  nodes, _ = batch
77
  for node in nodes:
78
  node_data = node[1]
79
+ if "metadata" in node_data and node_data["metadata"]:
80
+ metadata = json.loads(node_data["metadata"])["metadata"]
81
+ img_path = metadata.get("path", "")
82
+ for qa in qa_pairs:
83
  qa["img_path"] = img_path
84
+ return qa_pairs
 
85
 
86
  @staticmethod
87
+ def format_generation_results(result: dict, output_data_format: str) -> dict:
88
+ question = result.get("question", "")
89
+ answer = result.get("answer", "")
90
+ img_path = result.get("img_path", "")
91
  if output_data_format == "Alpaca":
92
+ return {
93
+ "instruction": question,
94
+ "input": "",
95
+ "output": answer,
96
+ "image": img_path,
97
+ }
98
+ if output_data_format == "Sharegpt":
99
+ return {
100
+ "conversations": [
101
+ {
102
+ "from": "human",
103
+ "value": [{"text": question, "image": img_path}],
104
+ },
105
+ {"from": "gpt", "value": [{"text": answer}]},
106
+ ]
107
+ }
108
+ if output_data_format == "ChatML":
109
+ return {
110
+ "messages": [
111
+ {
112
+ "role": "user",
113
+ "content": [{"text": question, "image": img_path}],
114
+ },
115
+ {
116
+ "role": "assistant",
117
+ "content": [{"type": "text", "text": answer}],
118
+ },
119
+ ]
120
+ }
121
+ raise ValueError(f"Unknown output data format: {output_data_format}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/models/kg_builder/light_rag_kg_builder.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import re
2
  from collections import Counter, defaultdict
3
  from typing import Dict, List, Tuple
@@ -130,15 +131,25 @@ class LightRAGKGBuilder(BaseKGBuilder):
130
  set([dp["source_id"] for dp in node_data] + source_ids)
131
  )
132
 
133
- node_data = {
134
  "entity_type": entity_type,
135
  "entity_name": entity_name,
136
  "description": description,
137
  "source_id": source_id,
138
  "length": self.tokenizer.count_tokens(description),
139
  }
140
- kg_instance.upsert_node(entity_name, node_data=node_data)
141
- return node_data
 
 
 
 
 
 
 
 
 
 
142
 
143
  async def merge_edges(
144
  self,
 
1
+ import json
2
  import re
3
  from collections import Counter, defaultdict
4
  from typing import Dict, List, Tuple
 
131
  set([dp["source_id"] for dp in node_data] + source_ids)
132
  )
133
 
134
+ node_data_dict = {
135
  "entity_type": entity_type,
136
  "entity_name": entity_name,
137
  "description": description,
138
  "source_id": source_id,
139
  "length": self.tokenizer.count_tokens(description),
140
  }
141
+
142
+ if entity_type in ("IMAGE", "TABLE", "FORMULA"):
143
+ metadata = next(
144
+ (dp["metadata"] for dp in node_data if dp.get("metadata")), None
145
+ )
146
+ if metadata:
147
+ node_data_dict["metadata"] = json.dumps(
148
+ metadata, ensure_ascii=False, default=str
149
+ )
150
+
151
+ kg_instance.upsert_node(entity_name, node_data=node_data_dict)
152
+ return node_data_dict
153
 
154
  async def merge_edges(
155
  self,
graphgen/models/kg_builder/mm_kg_builder.py CHANGED
@@ -70,6 +70,8 @@ class MMKGBuilder(LightRAGKGBuilder):
70
 
71
  entity = await handle_single_entity_extraction(attributes, chunk_id)
72
  if entity is not None:
 
 
73
  nodes[entity["entity_name"]].append(entity)
74
  continue
75
 
 
70
 
71
  entity = await handle_single_entity_extraction(attributes, chunk_id)
72
  if entity is not None:
73
+ if entity["entity_type"] == "IMAGE":
74
+ entity["metadata"] = chunk.metadata
75
  nodes[entity["entity_name"]].append(entity)
76
  continue
77
 
graphgen/models/reader/csv_reader.py CHANGED
@@ -22,7 +22,7 @@ class CSVReader(BaseReader):
22
  :return: Ray Dataset containing validated and filtered data.
23
  """
24
 
25
- ds = ray.data.read_csv(input_path)
26
  ds = ds.map_batches(self._validate_batch, batch_format="pandas")
27
  ds = ds.filter(self._should_keep_item)
28
  return ds
 
22
  :return: Ray Dataset containing validated and filtered data.
23
  """
24
 
25
+ ds = ray.data.read_csv(input_path, include_paths=True)
26
  ds = ds.map_batches(self._validate_batch, batch_format="pandas")
27
  ds = ds.filter(self._should_keep_item)
28
  return ds
graphgen/models/reader/json_reader.py CHANGED
@@ -34,10 +34,13 @@ class JSONReader(BaseReader):
34
  with open(file, "r", encoding="utf-8") as f:
35
  data = json.load(f)
36
  data = self._unify_schema(data)
 
 
 
37
  file_ds: ray.data.Dataset = ray.data.from_items(data)
38
  ds = ds.union(file_ds) # type: ignore
39
  else:
40
- ds = ray.data.read_json(input_path)
41
  ds = ds.map_batches(self._validate_batch, batch_format="pandas")
42
  ds = ds.filter(self._should_keep_item)
43
  return ds
 
34
  with open(file, "r", encoding="utf-8") as f:
35
  data = json.load(f)
36
  data = self._unify_schema(data)
37
+ # add path
38
+ for item in data:
39
+ item["path"] = file
40
  file_ds: ray.data.Dataset = ray.data.from_items(data)
41
  ds = ds.union(file_ds) # type: ignore
42
  else:
43
+ ds = ray.data.read_json(input_path, include_paths=True)
44
  ds = ds.map_batches(self._validate_batch, batch_format="pandas")
45
  ds = ds.filter(self._should_keep_item)
46
  return ds
graphgen/models/reader/parquet_reader.py CHANGED
@@ -24,7 +24,7 @@ class ParquetReader(BaseReader):
24
  if not ray.is_initialized():
25
  ray.init()
26
 
27
- ds = ray.data.read_parquet(input_path)
28
  ds = ds.map_batches(self._validate_batch, batch_format="pandas")
29
  ds = ds.filter(self._should_keep_item)
30
  return ds
 
24
  if not ray.is_initialized():
25
  ray.init()
26
 
27
+ ds = ray.data.read_parquet(input_path, include_paths=True)
28
  ds = ds.map_batches(self._validate_batch, batch_format="pandas")
29
  ds = ds.filter(self._should_keep_item)
30
  return ds
graphgen/models/reader/rdf_reader.py CHANGED
@@ -118,7 +118,7 @@ class RDFReader(BaseReader):
118
  "id": str(subj),
119
  self.text_column: text,
120
  "properties": props,
121
- "source_file": str(file_path),
122
  }
123
  docs.append(doc)
124
 
 
118
  "id": str(subj),
119
  self.text_column: text,
120
  "properties": props,
121
+ "path": str(file_path),
122
  }
123
  docs.append(doc)
124
 
graphgen/models/reader/txt_reader.py CHANGED
@@ -18,13 +18,14 @@ class TXTReader(BaseReader):
18
  """
19
  docs_ds = ray.data.read_binary_files(
20
  input_path,
21
- include_paths=False,
22
  )
23
 
24
  docs_ds = docs_ds.map(
25
  lambda row: {
26
  "type": "text",
27
  self.text_column: row["bytes"].decode("utf-8"),
 
28
  }
29
  )
30
 
 
18
  """
19
  docs_ds = ray.data.read_binary_files(
20
  input_path,
21
+ include_paths=True,
22
  )
23
 
24
  docs_ds = docs_ds.map(
25
  lambda row: {
26
  "type": "text",
27
  self.text_column: row["bytes"].decode("utf-8"),
28
+ "path": row["path"],
29
  }
30
  )
31
 
graphgen/models/storage/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- from graphgen.models.storage.graph.kuzu_storage import KuzuStorage
2
- from graphgen.models.storage.graph.networkx_storage import NetworkXStorage
3
- from graphgen.models.storage.kv.json_storage import JsonKVStorage
4
- from graphgen.models.storage.kv.rocksdb_storage import RocksDBKVStorage
5
-
6
- from .rocksdb_cache import RocksDBCache
 
 
 
 
 
 
 
graphgen/models/storage/rocksdb_cache.py DELETED
@@ -1,43 +0,0 @@
1
- from pathlib import Path
2
- from typing import Any, Iterator, Optional
3
-
4
- # rocksdict is a lightweight C wrapper around RocksDB for Python, pylint may not recognize it
5
- # pylint: disable=no-name-in-module
6
- from rocksdict import Rdict
7
-
8
-
9
- class RocksDBCache:
10
- def __init__(self, cache_dir: str):
11
- self.db_path = Path(cache_dir)
12
- self.db = Rdict(str(self.db_path))
13
-
14
- def get(self, key: str) -> Optional[Any]:
15
- return self.db.get(key)
16
-
17
- def set(self, key: str, value: Any):
18
- self.db[key] = value
19
-
20
- def delete(self, key: str):
21
- try:
22
- del self.db[key]
23
- except KeyError:
24
- # If the key does not exist, do nothing (deletion is idempotent for caches)
25
- pass
26
-
27
- def close(self):
28
- if hasattr(self, "db") and self.db is not None:
29
- self.db.close()
30
- self.db = None
31
-
32
- def __del__(self):
33
- # Ensure the database is closed when the object is destroyed
34
- self.close()
35
-
36
- def __enter__(self):
37
- return self
38
-
39
- def __exit__(self, exc_type, exc_val, exc_tb):
40
- self.close()
41
-
42
- def __iter__(self) -> Iterator[str]:
43
- return iter(self.db.keys())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/models/vis/__init__.py DELETED
File without changes
graphgen/models/vis/community_visualizer.py DELETED
@@ -1,48 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Dict
3
-
4
- import matplotlib.pyplot as plt
5
- import networkx as nx
6
-
7
-
8
- @dataclass
9
- class Visualizer:
10
- """
11
- Class for visualizing graphs using NetworkX and Matplotlib.
12
- """
13
-
14
- graph: nx.Graph = None
15
- communities: Dict[str, int] = None
16
- layout: str = "spring"
17
- max_nodes: int = 1000
18
- node_size: int = 10
19
- alpha: float = 0.6
20
-
21
- def visualize(self, save_path: str = None):
22
- n = self.graph.number_of_nodes()
23
- if self.layout == "spring":
24
- k = max(0.1, 1.0 / (n**0.5))
25
- pos = nx.spring_layout(self.graph, k=k, seed=42)
26
- else:
27
- raise ValueError(f"Unknown layout: {self.layout}")
28
-
29
- plt.figure(figsize=(10, 10))
30
-
31
- node_colors = [self.communities.get(node, 0) for node in self.graph.nodes()]
32
-
33
- nx.draw_networkx_nodes(
34
- self.graph,
35
- pos,
36
- node_size=self.node_size,
37
- node_color=node_colors,
38
- cmap=plt.cm.tab20,
39
- alpha=self.alpha,
40
- )
41
- nx.draw_networkx_edges(self.graph, pos, alpha=0.3, width=0.2)
42
- plt.axis("off")
43
-
44
- if save_path:
45
- plt.savefig(save_path, dpi=300, bbox_inches="tight")
46
- print("Saved to", save_path)
47
- else:
48
- plt.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/operators/build_kg/build_kg_service.py CHANGED
@@ -1,6 +1,4 @@
1
- from typing import List
2
-
3
- import pandas as pd
4
 
5
  from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator
6
  from graphgen.bases.datatypes import Chunk
@@ -13,9 +11,15 @@ from .build_text_kg import build_text_kg
13
 
14
  class BuildKGService(BaseOperator):
15
  def __init__(
16
- self, working_dir: str = "cache", graph_backend: str = "kuzu", **build_kwargs
 
 
 
 
17
  ):
18
- super().__init__(working_dir=working_dir, op_name="build_kg_service")
 
 
19
  self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
20
  self.graph_storage: BaseGraphStorage = init_storage(
21
  backend=graph_backend, working_dir=working_dir, namespace="graph"
@@ -23,21 +27,15 @@ class BuildKGService(BaseOperator):
23
  self.build_kwargs = build_kwargs
24
  self.max_loop: int = int(self.build_kwargs.get("max_loop", 3))
25
 
26
- def process(self, batch: pd.DataFrame) -> pd.DataFrame:
27
- docs = batch.to_dict(orient="records")
28
- docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs]
29
-
30
- # consume the chunks and build kg
31
- nodes, edges = self.build_kg(docs)
32
- return pd.DataFrame(
33
- [{"node": node, "edge": []} for node in nodes]
34
- + [{"node": [], "edge": edge} for edge in edges]
35
- )
36
-
37
- def build_kg(self, chunks: List[Chunk]) -> tuple:
38
  """
39
  Build knowledge graph (KG) and merge into kg_instance
 
 
 
 
40
  """
 
41
  text_chunks = [chunk for chunk in chunks if chunk.type == "text"]
42
  mm_chunks = [
43
  chunk
@@ -75,4 +73,34 @@ class BuildKGService(BaseOperator):
75
  self.graph_storage.index_done_callback()
76
  logger.info("Knowledge graph building completed.")
77
 
78
- return nodes, edges
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
 
 
2
 
3
  from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator
4
  from graphgen.bases.datatypes import Chunk
 
11
 
12
  class BuildKGService(BaseOperator):
13
  def __init__(
14
+ self,
15
+ working_dir: str = "cache",
16
+ kv_backend: str = "rocksdb",
17
+ graph_backend: str = "kuzu",
18
+ **build_kwargs
19
  ):
20
+ super().__init__(
21
+ working_dir=working_dir, kv_backend=kv_backend, op_name="build_kg"
22
+ )
23
  self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
24
  self.graph_storage: BaseGraphStorage = init_storage(
25
  backend=graph_backend, working_dir=working_dir, namespace="graph"
 
27
  self.build_kwargs = build_kwargs
28
  self.max_loop: int = int(self.build_kwargs.get("max_loop", 3))
29
 
30
+ def process(self, batch: list) -> Tuple[list, dict]:
 
 
 
 
 
 
 
 
 
 
 
31
  """
32
  Build knowledge graph (KG) and merge into kg_instance
33
+ :return: A tuple of (results, meta_updates)
34
+ results: A list of dicts containing nodes and edges added to the KG. Each dict has the structure:
35
+ {"_trace_id": str, "node": dict, "edge": dict}
36
+ meta_updates: A dict mapping source IDs to lists of trace IDs for nodes and edges added.
37
  """
38
+ chunks = [Chunk.from_dict(doc["_trace_id"], doc) for doc in batch]
39
  text_chunks = [chunk for chunk in chunks if chunk.type == "text"]
40
  mm_chunks = [
41
  chunk
 
73
  self.graph_storage.index_done_callback()
74
  logger.info("Knowledge graph building completed.")
75
 
76
+ meta_updates = {}
77
+ results = []
78
+ for node in nodes:
79
+ if not node:
80
+ continue
81
+ trace_id = node["entity_name"]
82
+ results.append(
83
+ {
84
+ "_trace_id": trace_id,
85
+ "node": node,
86
+ "edge": {},
87
+ }
88
+ )
89
+ source_ids = node.get("source_id", "").split("<SEP>")
90
+ for source_id in source_ids:
91
+ meta_updates.setdefault(source_id, []).append(trace_id)
92
+ for edge in edges:
93
+ if not edge:
94
+ continue
95
+ trace_id = frozenset((edge["src_id"], edge["tgt_id"]))
96
+ results.append(
97
+ {
98
+ "_trace_id": str(trace_id),
99
+ "node": {},
100
+ "edge": edge,
101
+ }
102
+ )
103
+ source_ids = edge.get("source_id", "").split("<SEP>")
104
+ for source_id in source_ids:
105
+ meta_updates.setdefault(source_id, []).append(str(trace_id))
106
+ return results, meta_updates
graphgen/operators/build_kg/build_text_kg.py CHANGED
@@ -30,6 +30,7 @@ def build_text_kg(
30
  desc="[2/4]Extracting entities and relationships from chunks",
31
  unit="chunk",
32
  )
 
33
 
34
  nodes = defaultdict(list)
35
  edges = defaultdict(list)
 
30
  desc="[2/4]Extracting entities and relationships from chunks",
31
  unit="chunk",
32
  )
33
+ results = [res for res in results if res]
34
 
35
  nodes = defaultdict(list)
36
  edges = defaultdict(list)
graphgen/operators/chunk/chunk_service.py CHANGED
@@ -1,17 +1,14 @@
1
  import os
2
  from functools import lru_cache
3
- from typing import Union
4
-
5
- import pandas as pd
6
 
7
  from graphgen.bases import BaseOperator
8
- from graphgen.common import init_storage
9
  from graphgen.models import (
10
  ChineseRecursiveTextSplitter,
11
  RecursiveCharacterSplitter,
12
  Tokenizer,
13
  )
14
- from graphgen.utils import compute_content_hash, detect_main_language
15
 
16
  _MAPPING = {
17
  "en": RecursiveCharacterSplitter,
@@ -45,26 +42,25 @@ class ChunkService(BaseOperator):
45
  def __init__(
46
  self, working_dir: str = "cache", kv_backend: str = "rocksdb", **chunk_kwargs
47
  ):
48
- super().__init__(working_dir=working_dir, op_name="chunk_service")
 
 
49
  tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
50
  self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model)
51
- self.chunk_storage = init_storage(
52
- backend=kv_backend,
53
- working_dir=working_dir,
54
- namespace="chunk",
55
- )
56
  self.chunk_kwargs = chunk_kwargs
57
 
58
- def process(self, batch: pd.DataFrame) -> pd.DataFrame:
59
- docs = batch.to_dict(orient="records")
60
- return pd.DataFrame(self.chunk_documents(docs))
61
-
62
- def chunk_documents(self, new_docs: list) -> list:
63
- chunks = []
64
- for doc in new_docs:
65
- doc_id = doc.get("_doc_id")
 
 
 
66
  doc_type = doc.get("type")
67
-
68
  if doc_type == "text":
69
  doc_language = detect_main_language(doc["content"])
70
  text_chunks = split_chunks(
@@ -72,32 +68,30 @@ class ChunkService(BaseOperator):
72
  language=doc_language,
73
  **self.chunk_kwargs,
74
  )
75
-
76
- chunks.extend(
77
- [
78
- {
79
- "_chunk_id": compute_content_hash(
80
- chunk_text, prefix="chunk-"
81
- ),
82
- "content": chunk_text,
83
- "type": "text",
84
- "_doc_id": doc_id,
85
- "length": len(self.tokenizer_instance.encode(chunk_text))
86
  if self.tokenizer_instance
87
- else len(chunk_text),
88
  "language": doc_language,
89
- }
90
- for chunk_text in text_chunks
91
- ]
92
- )
 
 
 
93
  else:
94
  # other types of documents(images, sequences) are not chunked
95
- chunks.append(
96
- {
97
- "_chunk_id": doc_id.replace("doc-", f"{doc_type}-"),
98
- **doc,
99
- }
100
- )
101
- self.chunk_storage.upsert({chunk["_chunk_id"]: chunk for chunk in chunks})
102
- self.chunk_storage.index_done_callback()
103
- return chunks
 
1
  import os
2
  from functools import lru_cache
3
+ from typing import Tuple, Union
 
 
4
 
5
  from graphgen.bases import BaseOperator
 
6
  from graphgen.models import (
7
  ChineseRecursiveTextSplitter,
8
  RecursiveCharacterSplitter,
9
  Tokenizer,
10
  )
11
+ from graphgen.utils import detect_main_language
12
 
13
  _MAPPING = {
14
  "en": RecursiveCharacterSplitter,
 
42
  def __init__(
43
  self, working_dir: str = "cache", kv_backend: str = "rocksdb", **chunk_kwargs
44
  ):
45
+ super().__init__(
46
+ working_dir=working_dir, kv_backend=kv_backend, op_name="chunk"
47
+ )
48
  tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
49
  self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model)
 
 
 
 
 
50
  self.chunk_kwargs = chunk_kwargs
51
 
52
+ def process(self, batch: list) -> Tuple[list, dict]:
53
+ """
54
+ Chunk the documents in the batch.
55
+ :return: A tuple of (results, meta_updates)
56
+ results: A list of chunked documents. Each chunked document is a dict with the structure:
57
+ {"_trace_id": str, "content": str, "type": str, "metadata": {"length": int, "language": str, ...}
58
+ meta_updates: A dict mapping source document IDs to lists of trace IDs for the chunked documents.
59
+ """
60
+ results = []
61
+ meta_updates = {}
62
+ for doc in batch:
63
  doc_type = doc.get("type")
 
64
  if doc_type == "text":
65
  doc_language = detect_main_language(doc["content"])
66
  text_chunks = split_chunks(
 
68
  language=doc_language,
69
  **self.chunk_kwargs,
70
  )
71
+ for text_chunk in text_chunks:
72
+ chunk = {
73
+ "content": text_chunk,
74
+ "type": "text",
75
+ "metadata": {
76
+ "length": len(self.tokenizer_instance.encode(text_chunk))
 
 
 
 
 
77
  if self.tokenizer_instance
78
+ else len(text_chunk),
79
  "language": doc_language,
80
+ },
81
+ }
82
+ chunk["_trace_id"] = self.get_trace_id(chunk)
83
+ results.append(chunk)
84
+ meta_updates.setdefault(doc["_trace_id"], []).append(
85
+ chunk["_trace_id"]
86
+ )
87
  else:
88
  # other types of documents(images, sequences) are not chunked
89
+ data = doc.copy()
90
+ input_trace_id = data.pop("_trace_id")
91
+ content = data.pop("content") if "content" in data else ""
92
+ doc_type = data.pop("type")
93
+ chunk = {"content": content, "type": doc_type, "metadata": data}
94
+ chunk["_trace_id"] = self.get_trace_id(chunk)
95
+ results.append(chunk)
96
+ meta_updates.setdefault(input_trace_id, []).append(chunk["_trace_id"])
97
+ return results, meta_updates
graphgen/operators/evaluate/evaluate_kg.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+
3
+ from graphgen.bases import BaseGraphStorage
4
+ from graphgen.utils import logger
5
+
6
+
7
+ def evaluate_kg(
8
+ kg_evaluators: Dict[str, Any],
9
+ kg_instance: BaseGraphStorage,
10
+ ) -> Dict[str, Any]:
11
+ results = {}
12
+ for key, kg_evaluator in kg_evaluators.items():
13
+ results[key] = kg_evaluator.evaluate(kg_instance)
14
+ logger.info(f"KG Evaluation result for {key}: {results[key]}")
15
+ return results
graphgen/operators/evaluate/evaluate_qa.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from graphgen.bases import QAPair
4
+ from graphgen.utils import run_concurrent
5
+
6
+
7
+ def transform_to_qa_format(
8
+ items: list[dict], format_hint: str = "auto"
9
+ ) -> list[dict[str, str]]:
10
+ extractors = {
11
+ "ChatML": lambda x: (
12
+ next(
13
+ (
14
+ m["content"]
15
+ for m in x.get("messages", [])
16
+ if m.get("role") == "user"
17
+ ),
18
+ "",
19
+ ),
20
+ next(
21
+ (
22
+ m["content"]
23
+ for m in x.get("messages", [])
24
+ if m.get("role") == "assistant"
25
+ ),
26
+ "",
27
+ ),
28
+ ),
29
+ "Alpaca": lambda x: (
30
+ f"{x.get('instruction', '')}\n\n{x['input']}".strip()
31
+ if x.get("input")
32
+ else x.get("instruction", ""),
33
+ x.get("output", ""),
34
+ ),
35
+ "Sharegpt": lambda x: (
36
+ next(
37
+ (
38
+ c["value"]
39
+ for c in x.get("conversations", [])
40
+ if c.get("from") == "human"
41
+ ),
42
+ "",
43
+ ),
44
+ next(
45
+ (
46
+ c["value"]
47
+ for c in x.get("conversations", [])
48
+ if c.get("from") in ("gpt", "assistant")
49
+ ),
50
+ "",
51
+ ),
52
+ ),
53
+ }
54
+
55
+ auto_detect = {
56
+ "messages": "ChatML",
57
+ "conversations": "Sharegpt",
58
+ "instruction": "Alpaca",
59
+ }
60
+
61
+ transformed = []
62
+ for item in items:
63
+ fmt = format_hint
64
+ if fmt == "auto":
65
+ fmt = next(
66
+ (fmt_name for key, fmt_name in auto_detect.items() if key in item), None
67
+ )
68
+ if not fmt:
69
+ raise ValueError(
70
+ "Could not auto-detect format. Please specify format_hint."
71
+ )
72
+
73
+ question, answer = extractors[fmt](item)
74
+ options = None
75
+ if "\nOptions:\n" in question:
76
+ q_part, opt_part = question.split("\nOptions:\n", 1)
77
+ question = q_part
78
+ options = {
79
+ k.strip(): v.strip()
80
+ for line in opt_part.strip().split("\n")
81
+ if "." in line
82
+ for k, v in [line.split(".", 1)]
83
+ }
84
+
85
+ result = {"question": question.strip(), "answer": answer.strip()}
86
+ if options:
87
+ result["options"] = options
88
+ transformed.append(result)
89
+
90
+ return transformed
91
+
92
+
93
+ def evaluate_qa(
94
+ qa_evaluators: dict[str, Any], items: list[dict[str, Any]]
95
+ ) -> dict[str, Any]:
96
+ items = transform_to_qa_format(items)
97
+ items = [QAPair.from_dict(item) for item in items]
98
+
99
+ results = {}
100
+ for key, qa_evaluator in qa_evaluators.items():
101
+ result = run_concurrent(
102
+ qa_evaluator.evaluate,
103
+ items,
104
+ desc=f"Evaluating QA with {key}",
105
+ )
106
+ results[key] = result
107
+ return results
graphgen/operators/evaluate/evaluate_service.py CHANGED
@@ -1,10 +1,12 @@
1
- from typing import Any, Dict
2
 
3
- import pandas as pd
4
-
5
- from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair
6
  from graphgen.common import init_llm, init_storage
7
- from graphgen.utils import logger, run_concurrent
 
 
 
 
8
 
9
 
10
  class EvaluateService(BaseOperator):
@@ -15,167 +17,135 @@ class EvaluateService(BaseOperator):
15
 
16
  def __init__(
17
  self,
 
 
18
  working_dir: str = "cache",
19
- metrics: list[str] = None,
20
  graph_backend: str = "kuzu",
21
  kv_backend: str = "rocksdb",
22
  **kwargs,
23
  ):
24
- super().__init__(working_dir=working_dir, op_name="evaluate_service")
 
 
25
  self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
26
  self.metrics = metrics or []
27
  self.kwargs = kwargs
28
  self.graph_storage = init_storage(
29
  backend=graph_backend, working_dir=working_dir, namespace="graph"
30
  )
31
- self.chunk_storage = init_storage(
32
- backend=kv_backend, working_dir=working_dir, namespace="chunk"
33
- )
34
 
35
  # Initialize evaluators
36
- self.qa_evaluators = {}
37
- self.kg_evaluators = {}
38
- self._init_evaluators()
39
-
40
- def _init_evaluators(self):
41
- """Initialize QA and KG evaluators based on metrics."""
42
- for metric in self.metrics:
43
- if metric == "qa_length":
44
- from graphgen.models import LengthEvaluator
45
-
46
- self.qa_evaluators[metric] = LengthEvaluator()
47
- elif metric == "qa_mtld":
48
- from graphgen.models import MTLDEvaluator
49
-
50
- self.qa_evaluators[metric] = MTLDEvaluator(
51
- **self.kwargs.get("mtld_params", {})
52
- )
53
- elif metric == "qa_reward_score":
54
- from graphgen.models import RewardEvaluator
55
-
56
- self.qa_evaluators[metric] = RewardEvaluator(
57
- **self.kwargs.get("reward_params", {})
58
- )
59
- elif metric == "qa_uni_score":
60
- from graphgen.models import UniEvaluator
61
-
62
- self.qa_evaluators[metric] = UniEvaluator(
63
- **self.kwargs.get("uni_params", {})
64
- )
65
- elif metric == "kg_accuracy":
66
- from graphgen.models import AccuracyEvaluator
67
-
68
- self.kg_evaluators[metric] = AccuracyEvaluator(
69
- graph_storage=self.graph_storage,
70
- chunk_storage=self.chunk_storage,
71
- llm_client=self.llm_client,
72
- )
73
- elif metric == "kg_consistency":
74
- from graphgen.models import ConsistencyEvaluator
75
-
76
- self.kg_evaluators[metric] = ConsistencyEvaluator(
77
- graph_storage=self.graph_storage,
78
- chunk_storage=self.chunk_storage,
79
- llm_client=self.llm_client,
80
- )
81
- elif metric == "kg_structure":
82
- from graphgen.models import StructureEvaluator
83
-
84
- self.kg_evaluators[metric] = StructureEvaluator(
85
- graph_storage=self.graph_storage,
86
- **self.kwargs.get("structure_params", {}),
87
- )
88
- else:
89
- raise ValueError(f"Unknown QA metric: {metric}")
90
-
91
- async def _process_single_qa(self, item: dict[str, Any]) -> dict[str, Any]:
92
- try:
93
- qa_pair = QAPair(
94
- question=str(item.get("question", "")),
95
- answer=str(item.get("answer", "")),
96
  )
97
- if not qa_pair.question or not qa_pair.answer:
98
- logger.error("Empty question or answer, skipping.")
99
- return {}
100
- except Exception as e:
101
- logger.error("Error in QAPair creation: %s", str(e))
102
- return {}
103
-
104
- for metric, evaluator in self.qa_evaluators.items():
105
- try:
106
- score = evaluator.evaluate(qa_pair)
107
- if isinstance(score, dict):
108
- for sub_metric, sub_score in score.items():
109
- item[f"{metric}_{sub_metric}"] = float(sub_score)
110
- else:
111
- item[metric] = float(score)
112
- except Exception as e:
113
- logger.error("Error in %s evaluation: %s", metric, str(e))
114
- item[metric] = None
115
- return item
116
-
117
- def _evaluate_qa(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]:
118
- def transform_messages_format(items: list[dict]) -> list[dict]:
119
- """
120
- Transform from [{'messages': [...]}, ...] to [{'question': '...', 'answer': '...'}, ...]
121
- """
122
- transformed = []
123
- for item in items:
124
- messages = item.get("messages", [])
125
- question = next(
126
- (m["content"] for m in messages if m.get("role") == "user"), ""
127
- )
128
- answer = next(
129
- (m["content"] for m in messages if m.get("role") == "assistant"), ""
130
- )
131
-
132
- transformed.append({"question": question, "answer": answer})
133
- return transformed
134
-
135
- if not items:
136
- return []
137
-
138
- if not self.qa_evaluators:
139
- logger.warning("No QA evaluators initialized, skipping QA evaluation")
140
- return []
141
-
142
- items = transform_messages_format(items)
143
- results = run_concurrent(
144
- self._process_single_qa,
145
- items,
146
- desc="Evaluating QA items",
147
- unit="item",
148
  )
149
 
150
- results = [item for item in results if item]
151
- return results
152
-
153
- def _evaluate_kg(self) -> Dict[str, Any]:
154
- results = {}
155
-
156
- for metric, evaluator in self.kg_evaluators.items():
157
- try:
158
- logger.info("Running %s evaluation...", metric)
159
- score = evaluator.evaluate()
160
- results[metric] = score
161
- except Exception as e:
162
- logger.error("Error in %s evaluation: %s", metric, str(e))
163
- results[metric] = {"error": str(e)}
164
- return results
165
-
166
- def process(self, batch: pd.DataFrame) -> pd.DataFrame:
167
- # QA evaluation
168
- if len(self.qa_evaluators) > 0:
169
- items = batch.to_dict(orient="records")
170
- results = self._evaluate_qa(items)
171
- return pd.DataFrame(results)
172
-
173
- # KG evaluation
174
- if len(self.kg_evaluators) > 0:
175
- results = self._evaluate_kg()
176
- # Convert dict to DataFrame (single row)
177
- return pd.DataFrame([results])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
  # No metrics specified
180
  logger.warning("No metrics specified, returning empty DataFrame")
181
- return pd.DataFrame()
 
1
+ from typing import Tuple
2
 
3
+ from graphgen.bases import BaseLLMWrapper, BaseOperator
 
 
4
  from graphgen.common import init_llm, init_storage
5
+ from graphgen.utils import logger
6
+
7
+ from .evaluate_kg import evaluate_kg
8
+ from .evaluate_qa import evaluate_qa
9
+ from .evaluate_triple import evaluate_triple
10
 
11
 
12
  class EvaluateService(BaseOperator):
 
17
 
18
  def __init__(
19
  self,
20
+ target: str,
21
+ metrics: list[str],
22
  working_dir: str = "cache",
 
23
  graph_backend: str = "kuzu",
24
  kv_backend: str = "rocksdb",
25
  **kwargs,
26
  ):
27
+ super().__init__(
28
+ working_dir=working_dir, kv_backend=kv_backend, op_name="evaluate"
29
+ )
30
  self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
31
  self.metrics = metrics or []
32
  self.kwargs = kwargs
33
  self.graph_storage = init_storage(
34
  backend=graph_backend, working_dir=working_dir, namespace="graph"
35
  )
 
 
 
36
 
37
  # Initialize evaluators
38
+ self.target = target
39
+ self.src_storage = None
40
+ self.tgt_storage = None
41
+ self.evaluators = {}
42
+ self._init_evaluators(self.target, metrics)
43
+
44
+ def _init_evaluators(self, target: str, metrics: list[str]):
45
+ """Initialize evaluators based on target and metrics."""
46
+ if target not in {"qa", "kg", "triple"}:
47
+ raise ValueError(f"Unknown evaluation target: {target}")
48
+
49
+ # Delegate to target-specific initializer
50
+ getattr(self, f"_init_{target}_evaluators")(metrics)
51
+
52
+ def _init_qa_evaluators(self, metrics: list[str]):
53
+ """Initialize QA evaluators."""
54
+ for metric in metrics:
55
+ self.evaluators[metric] = self._create_qa_evaluator(metric)
56
+
57
+ def _create_qa_evaluator(self, metric: str):
58
+ """Factory method for QA evaluator instances."""
59
+ if metric == "length":
60
+ from graphgen.models import LengthEvaluator
61
+
62
+ return LengthEvaluator()
63
+ if metric == "mtld":
64
+ from graphgen.models import MTLDEvaluator
65
+
66
+ return MTLDEvaluator(**self.kwargs.get("mtld_params", {}))
67
+ if metric == "reward_score":
68
+ from graphgen.models import RewardEvaluator
69
+
70
+ return RewardEvaluator(**self.kwargs.get("reward_params", {}))
71
+ if metric == "uni_score":
72
+ from graphgen.models import UniEvaluator
73
+
74
+ return UniEvaluator(**self.kwargs.get("uni_params", {}))
75
+ raise ValueError(f"Unknown QA metric: {metric}")
76
+
77
+ def _init_kg_evaluators(self, metrics: list[str]):
78
+ """Initialize KG evaluators."""
79
+ for metric in metrics:
80
+ if metric != "structure":
81
+ raise ValueError(f"Unknown KG metric: {metric}")
82
+ from graphgen.models import StructureEvaluator
83
+
84
+ self.evaluators[metric] = StructureEvaluator(
85
+ **self.kwargs.get("structure_params", {})
 
 
 
 
 
 
 
 
 
 
 
 
86
  )
87
+
88
+ def _init_triple_evaluators(self, metrics: list[str]):
89
+ """Initialize Triple evaluators."""
90
+ self.src_storage = init_storage(
91
+ backend=self.kv_backend,
92
+ working_dir=self.working_dir,
93
+ namespace=self.kwargs["src_namespace"],
94
+ )
95
+ self.tgt_storage = init_storage(
96
+ backend=self.kv_backend,
97
+ working_dir=self.working_dir,
98
+ namespace=self.kwargs["tgt_namespace"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  )
100
 
101
+ for metric in metrics:
102
+ if metric != "accuracy":
103
+ raise ValueError(f"Unknown Triple metric: {metric}")
104
+ from graphgen.models import AccuracyEvaluator
105
+
106
+ self.evaluators[metric] = AccuracyEvaluator(llm_client=self.llm_client)
107
+
108
+ def process(self, batch: list) -> Tuple[list, dict]:
109
+ final_results = []
110
+ meta_updates = {}
111
+
112
+ # 1. QA Evaluation (per item)
113
+ if self.target == "qa" and self.evaluators:
114
+ results: dict = evaluate_qa(self.evaluators, batch)
115
+ for i, item in enumerate(batch):
116
+ metrics = {}
117
+ for _, scores in results.items():
118
+ metrics.update(scores[i])
119
+ item.update({"metrics": metrics})
120
+ input_trace_id = item.pop("_trace_id")
121
+ item["_trace_id"] = self.get_trace_id(item)
122
+ final_results.append(item)
123
+ meta_updates.setdefault(input_trace_id, []).append(item["_trace_id"])
124
+
125
+ return final_results, meta_updates
126
+
127
+ # 2. KG evaluation
128
+ if self.target == "kg" and self.evaluators:
129
+ results = evaluate_kg(
130
+ self.evaluators,
131
+ self.graph_storage,
132
+ )
133
+ if not results:
134
+ logger.warning("No KG evaluation results, returning empty DataFrame")
135
+ return [], {}
136
+ results["_trace_id"] = self.get_trace_id(results)
137
+ final_results.append(results)
138
+ return final_results, {}
139
+
140
+ # 3. Triple evaluation
141
+ if self.target == "triple" and self.evaluators:
142
+ results = evaluate_triple(
143
+ self.evaluators, self.src_storage, self.tgt_storage
144
+ )
145
+ results["_trace_id"] = "evaluate-triple-result"
146
+ final_results.append(results)
147
+ return final_results, {}
148
 
149
  # No metrics specified
150
  logger.warning("No metrics specified, returning empty DataFrame")
151
+ return [], {}
graphgen/operators/evaluate/evaluate_triple.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from graphgen.bases import BaseKVStorage
4
+ from graphgen.utils import logger, run_concurrent
5
+
6
+
7
+ def evaluate_triple(
8
+ triple_evaluators: dict[str, Any],
9
+ src_storage: BaseKVStorage,
10
+ tgt_storage: BaseKVStorage,
11
+ ) -> dict[str, Any]:
12
+ forward_meta = tgt_storage.get_by_id("_meta_forward")
13
+
14
+ tasks = []
15
+ for chunk_id, unit_ids in forward_meta.items():
16
+ chunk_content = str(src_storage.get_by_id(chunk_id))
17
+
18
+ nodes = []
19
+ edges = []
20
+
21
+ for unit_id in unit_ids:
22
+ unit_data = tgt_storage.get_by_id(unit_id)
23
+ if "node" in unit_data and unit_data["node"]:
24
+ nodes.append(unit_data["node"])
25
+ if "edge" in unit_data and unit_data["edge"]:
26
+ edges.append(unit_data["edge"])
27
+
28
+ tasks.append((chunk_content, nodes, edges))
29
+
30
+ results = {}
31
+ for key, triple_evaluator in triple_evaluators.items():
32
+ logger.info(f"Evaluating Triples with metric: {key}...")
33
+ result = run_concurrent(
34
+ triple_evaluator.evaluate,
35
+ tasks,
36
+ desc=f"Evaluating Triples with {key}",
37
+ )
38
+ results[key] = result
39
+ return results
graphgen/operators/extract/extract_service.py CHANGED
@@ -1,16 +1,19 @@
1
  import json
 
2
 
3
- import pandas as pd
4
-
5
- from graphgen.bases import BaseLLMWrapper, BaseOperator
6
  from graphgen.common import init_llm
7
  from graphgen.models.extractor import SchemaGuidedExtractor
8
  from graphgen.utils import logger, run_concurrent
9
 
10
 
11
  class ExtractService(BaseOperator):
12
- def __init__(self, working_dir: str = "cache", **extract_kwargs):
13
- super().__init__(working_dir=working_dir, op_name="extract_service")
 
 
 
 
14
  self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
15
  self.extract_kwargs = extract_kwargs
16
  self.method = self.extract_kwargs.get("method")
@@ -22,24 +25,32 @@ class ExtractService(BaseOperator):
22
  else:
23
  raise ValueError(f"Unsupported extraction method: {self.method}")
24
 
25
- def process(self, batch: pd.DataFrame) -> pd.DataFrame:
26
- items = batch.to_dict(orient="records")
27
- return pd.DataFrame(self.extract(items))
28
-
29
- def extract(self, items: list[dict]) -> list[dict]:
30
-
31
- logger.info("Start extracting information from %d items", len(items))
32
-
 
 
33
  results = run_concurrent(
34
  self.extractor.extract,
35
- items,
36
  desc="Extracting information",
37
  unit="item",
38
  )
39
- results = self.extractor.merge_extractions(results)
40
 
41
- results = [
42
- {"_extract_id": key, "extracted_data": value}
43
- for key, value in results.items()
44
- ]
45
- return results
 
 
 
 
 
 
 
 
1
  import json
2
+ from typing import Tuple
3
 
4
+ from graphgen.bases import BaseLLMWrapper, BaseOperator, Chunk
 
 
5
  from graphgen.common import init_llm
6
  from graphgen.models.extractor import SchemaGuidedExtractor
7
  from graphgen.utils import logger, run_concurrent
8
 
9
 
10
  class ExtractService(BaseOperator):
11
+ def __init__(
12
+ self, working_dir: str = "cache", kv_backend: str = "rocksdb", **extract_kwargs
13
+ ):
14
+ super().__init__(
15
+ working_dir=working_dir, kv_backend=kv_backend, op_name="extract"
16
+ )
17
  self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
18
  self.extract_kwargs = extract_kwargs
19
  self.method = self.extract_kwargs.get("method")
 
25
  else:
26
  raise ValueError(f"Unsupported extraction method: {self.method}")
27
 
28
+ def process(self, batch: list) -> Tuple[list, dict]:
29
+ """
30
+ Extract information from the batch of chunks.
31
+ :return: A tuple of (results, meta_updates)
32
+ results: A list of dicts containing extracted information. Each dict has the structure:
33
+ {"_trace_id": str, "content": dict}
34
+ meta_updates: A dict mapping source IDs to lists of trace IDs for the extracted information.
35
+ """
36
+ logger.info("Start extracting information from %d items", len(batch))
37
+ chunks = [Chunk.from_dict(item["_trace_id"], item) for item in batch]
38
  results = run_concurrent(
39
  self.extractor.extract,
40
+ chunks,
41
  desc="Extracting information",
42
  unit="item",
43
  )
 
44
 
45
+ meta_updates = {}
46
+ final_results = []
47
+ # chunk -> extracted info
48
+ for input_trace_id, result in zip(
49
+ [item["_trace_id"] for item in batch], results
50
+ ):
51
+ if not result:
52
+ continue
53
+ result = {"_trace_id": self.get_trace_id(result), "content": result}
54
+ meta_updates.setdefault(input_trace_id, []).append(result["_trace_id"])
55
+ final_results.append(result)
56
+ return final_results, meta_updates
graphgen/operators/generate/generate_service.py CHANGED
@@ -1,9 +1,6 @@
1
- import json
2
-
3
- import pandas as pd
4
-
5
- from graphgen.bases import BaseLLMWrapper, BaseOperator
6
- from graphgen.common import init_llm
7
  from graphgen.utils import logger, run_concurrent
8
 
9
 
@@ -15,12 +12,18 @@ class GenerateService(BaseOperator):
15
  def __init__(
16
  self,
17
  working_dir: str = "cache",
 
18
  method: str = "aggregated",
19
  data_format: str = "ChatML",
20
  **generate_kwargs,
21
  ):
22
- super().__init__(working_dir=working_dir, op_name="generate_service")
 
 
23
  self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
 
 
 
24
 
25
  self.method = method
26
  self.data_format = data_format
@@ -76,32 +79,31 @@ class GenerateService(BaseOperator):
76
  else:
77
  raise ValueError(f"Unsupported generation mode: {method}")
78
 
79
- def process(self, batch: pd.DataFrame) -> pd.DataFrame:
80
- items = batch.to_dict(orient="records")
81
- return pd.DataFrame(self.generate(items))
82
-
83
- def generate(self, items: list[dict]) -> list[dict]:
84
  """
85
  Generate question-answer pairs based on nodes and edges.
86
- :param items
87
- :return: QA pairs
88
  """
89
- logger.info("[Generation] mode: %s, batches: %d", self.method, len(items))
90
- items = [
91
- (json.loads(item["nodes"]), json.loads(item["edges"])) for item in items
92
- ]
93
  results = run_concurrent(
94
  self.generator.generate,
95
- items,
96
- desc="[4/4]Generating QAs",
97
  unit="batch",
98
  )
99
 
100
- # Filter out empty results
101
- results = [res for res in results if res]
102
-
103
- results = self.generator.format_generation_results(
104
- results, output_data_format=self.data_format
105
- )
106
-
107
- return results
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ from graphgen.bases import BaseKVStorage, BaseLLMWrapper, BaseOperator
3
+ from graphgen.common import init_llm, init_storage
 
 
 
4
  from graphgen.utils import logger, run_concurrent
5
 
6
 
 
12
  def __init__(
13
  self,
14
  working_dir: str = "cache",
15
+ kv_backend: str = "rocksdb",
16
  method: str = "aggregated",
17
  data_format: str = "ChatML",
18
  **generate_kwargs,
19
  ):
20
+ super().__init__(
21
+ working_dir=working_dir, kv_backend=kv_backend, op_name="generate"
22
+ )
23
  self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
24
+ self.generate_storage: BaseKVStorage = init_storage(
25
+ backend=kv_backend, working_dir=working_dir, namespace="generate"
26
+ )
27
 
28
  self.method = method
29
  self.data_format = data_format
 
79
  else:
80
  raise ValueError(f"Unsupported generation mode: {method}")
81
 
82
+ def process(self, batch: list) -> Tuple[list, dict]:
 
 
 
 
83
  """
84
  Generate question-answer pairs based on nodes and edges.
 
 
85
  """
86
+ logger.info("[Generation] mode: %s, batches: %d", self.method, len(batch))
87
+ triples = [(item["nodes"], item["edges"]) for item in batch]
 
 
88
  results = run_concurrent(
89
  self.generator.generate,
90
+ triples,
91
+ desc="Generating QAs",
92
  unit="batch",
93
  )
94
 
95
+ meta_updates = {}
96
+ final_results = []
97
+ for input_trace_id, qa_pairs in zip(
98
+ [item["_trace_id"] for item in batch], results
99
+ ):
100
+ if not qa_pairs:
101
+ continue
102
+ for qa_pair in qa_pairs:
103
+ res = self.generator.format_generation_results(
104
+ qa_pair, output_data_format=self.data_format
105
+ )
106
+ res["_trace_id"] = self.get_trace_id(res)
107
+ final_results.append(res)
108
+ meta_updates.setdefault(input_trace_id, []).append(res["_trace_id"])
109
+ return final_results, meta_updates