chenzihong-gavin committed on
Commit
0682cc6
·
1 Parent(s): 4b2a9c2
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. hf-repo/LICENSE +0 -201
  2. hf-repo/README.md +0 -43
  3. hf-repo/app.py +0 -587
  4. hf-repo/graphgen/__init__.py +0 -0
  5. hf-repo/graphgen/configs/README.md +0 -1
  6. hf-repo/graphgen/configs/aggregated_config.yaml +0 -21
  7. hf-repo/graphgen/configs/atomic_config.yaml +0 -21
  8. hf-repo/graphgen/configs/cot_config.yaml +0 -13
  9. hf-repo/graphgen/configs/multi_hop_config.yaml +0 -21
  10. hf-repo/graphgen/evaluate.py +0 -142
  11. hf-repo/graphgen/generate.py +0 -103
  12. hf-repo/graphgen/graphgen.py +0 -395
  13. hf-repo/graphgen/judge.py +0 -60
  14. hf-repo/graphgen/models/__init__.py +0 -45
  15. hf-repo/graphgen/models/community/__init__.py +0 -0
  16. hf-repo/graphgen/models/community/community_detector.py +0 -95
  17. hf-repo/graphgen/models/embed/__init__.py +0 -0
  18. hf-repo/graphgen/models/embed/embedding.py +0 -29
  19. hf-repo/graphgen/models/evaluate/__init__.py +0 -0
  20. hf-repo/graphgen/models/evaluate/base_evaluator.py +0 -51
  21. hf-repo/graphgen/models/evaluate/length_evaluator.py +0 -22
  22. hf-repo/graphgen/models/evaluate/mtld_evaluator.py +0 -76
  23. hf-repo/graphgen/models/evaluate/reward_evaluator.py +0 -101
  24. hf-repo/graphgen/models/evaluate/uni_evaluator.py +0 -159
  25. hf-repo/graphgen/models/llm/__init__.py +0 -0
  26. hf-repo/graphgen/models/llm/limitter.py +0 -88
  27. hf-repo/graphgen/models/llm/openai_model.py +0 -155
  28. hf-repo/graphgen/models/llm/tokenizer.py +0 -73
  29. hf-repo/graphgen/models/llm/topk_token_model.py +0 -48
  30. hf-repo/graphgen/models/search/__init__.py +0 -0
  31. hf-repo/graphgen/models/search/db/__init__.py +0 -0
  32. hf-repo/graphgen/models/search/db/uniprot_search.py +0 -64
  33. hf-repo/graphgen/models/search/kg/__init__.py +0 -0
  34. hf-repo/graphgen/models/search/kg/wiki_search.py +0 -37
  35. hf-repo/graphgen/models/search/web/__init__.py +0 -0
  36. hf-repo/graphgen/models/search/web/bing_search.py +0 -43
  37. hf-repo/graphgen/models/search/web/google_search.py +0 -45
  38. hf-repo/graphgen/models/storage/__init__.py +0 -0
  39. hf-repo/graphgen/models/storage/base_storage.py +0 -115
  40. hf-repo/graphgen/models/storage/json_storage.py +0 -87
  41. hf-repo/graphgen/models/storage/networkx_storage.py +0 -159
  42. hf-repo/graphgen/models/strategy/__init__.py +0 -0
  43. hf-repo/graphgen/models/strategy/base_strategy.py +0 -5
  44. hf-repo/graphgen/models/strategy/travserse_strategy.py +0 -30
  45. hf-repo/graphgen/models/text/__init__.py +0 -0
  46. hf-repo/graphgen/models/text/chunk.py +0 -7
  47. hf-repo/graphgen/models/text/text_pair.py +0 -9
  48. hf-repo/graphgen/models/vis/__init__.py +0 -0
  49. hf-repo/graphgen/models/vis/community_visualizer.py +0 -48
  50. hf-repo/graphgen/operators/__init__.py +0 -22
hf-repo/LICENSE DELETED
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
hf-repo/README.md DELETED
@@ -1,43 +0,0 @@
- ---
- title: GraphGen Demo
- emoji: 📊
- colorFrom: blue
- colorTo: green
- sdk: gradio
- sdk_version: "5.44.0"
- python_version: "3.10"
- app_file: app.py
- suggested_hardware: cpu-basic
- pinned: false
- short_description: "Knowledge-driven synthetic data generation demo"
- tags:
-   - synthetic-data
-   - knowledge-graph
-   - gradio-demo
- ---
-
- # GraphGen Space 🤖📊
-
- This is the **official Hugging Face Space** for [GraphGen](https://github.com/open-sciencelab/GraphGen) – a framework that leverages knowledge graphs to generate high-quality synthetic question–answer pairs for supervised fine-tuning of LLMs.
-
- 🔗 Paper: [arXiv 2505.20416](https://arxiv.org/abs/2505.20416)
- 🐙 GitHub: [open-sciencelab/GraphGen](https://github.com/open-sciencelab/GraphGen)
-
- ---
-
- ## How to use (🖱️ 3 clicks)
-
- 1. Open the **Gradio app** above.
- 2. Upload or paste your source text → click **Generate KG**.
- 3. Download the generated QA pairs directly.
-
- ---
-
- ## Local quick start (optional)
-
- ```bash
- git clone https://github.com/open-sciencelab/GraphGen
- cd GraphGen
- uv venv --python 3.10 && uv pip install -r requirements.txt
- uv run webui/app.py # http://localhost:7860
- ```
hf-repo/app.py DELETED
@@ -1,587 +0,0 @@
- import json
- import os
- import sys
- import tempfile
-
- import gradio as gr
- import pandas as pd
- from gradio_i18n import Translate
- from gradio_i18n import gettext as _
-
- from webui.base import GraphGenParams
- from webui.cache_utils import cleanup_workspace, setup_workspace
- from webui.count_tokens import count_tokens
- from webui.test_api import test_api_connection
-
- # pylint: disable=wrong-import-position
- root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
- sys.path.append(root_dir)
-
- from graphgen.graphgen import GraphGen
- from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
- from graphgen.models.llm.limitter import RPM, TPM
- from graphgen.utils import set_logger
-
- css = """
- .center-row {
-     display: flex;
-     justify-content: center;
-     align-items: center;
- }
- """
-
-
- def init_graph_gen(config: dict, env: dict) -> GraphGen:
-     # Set up working directory
-     log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
-
-     set_logger(log_file, if_stream=False)
-     graph_gen = GraphGen(working_dir=working_dir)
-
-     # Set up LLM clients
-     graph_gen.synthesizer_llm_client = OpenAIModel(
-         model_name=env.get("SYNTHESIZER_MODEL", ""),
-         base_url=env.get("SYNTHESIZER_BASE_URL", ""),
-         api_key=env.get("SYNTHESIZER_API_KEY", ""),
-         request_limit=True,
-         rpm=RPM(env.get("RPM", 1000)),
-         tpm=TPM(env.get("TPM", 50000)),
-     )
-
-     graph_gen.trainee_llm_client = OpenAIModel(
-         model_name=env.get("TRAINEE_MODEL", ""),
-         base_url=env.get("TRAINEE_BASE_URL", ""),
-         api_key=env.get("TRAINEE_API_KEY", ""),
-         request_limit=True,
-         rpm=RPM(env.get("RPM", 1000)),
-         tpm=TPM(env.get("TPM", 50000)),
-     )
-
-     graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
-
-     strategy_config = config.get("traverse_strategy", {})
-     graph_gen.traverse_strategy = TraverseStrategy(
-         qa_form=strategy_config.get("qa_form"),
-         expand_method=strategy_config.get("expand_method"),
-         bidirectional=strategy_config.get("bidirectional"),
-         max_extra_edges=strategy_config.get("max_extra_edges"),
-         max_tokens=strategy_config.get("max_tokens"),
-         max_depth=strategy_config.get("max_depth"),
-         edge_sampling=strategy_config.get("edge_sampling"),
-         isolated_node_strategy=strategy_config.get("isolated_node_strategy"),
-         loss_strategy=str(strategy_config.get("loss_strategy")),
-     )
-
-     return graph_gen
-
-
- # pylint: disable=too-many-statements
- def run_graphgen(params, progress=gr.Progress()):
-     def sum_tokens(client):
-         return sum(u["total_tokens"] for u in client.token_usage)
-
-     config = {
-         "if_trainee_model": params.if_trainee_model,
-         "input_file": params.input_file,
-         "tokenizer": params.tokenizer,
-         "quiz_samples": params.quiz_samples,
-         "traverse_strategy": {
-             "qa_form": params.qa_form,
-             "bidirectional": params.bidirectional,
-             "expand_method": params.expand_method,
-             "max_extra_edges": params.max_extra_edges,
-             "max_tokens": params.max_tokens,
-             "max_depth": params.max_depth,
-             "edge_sampling": params.edge_sampling,
-             "isolated_node_strategy": params.isolated_node_strategy,
-             "loss_strategy": params.loss_strategy,
-         },
-         "chunk_size": params.chunk_size,
-     }
-
-     env = {
-         "SYNTHESIZER_BASE_URL": params.synthesizer_url,
-         "SYNTHESIZER_MODEL": params.synthesizer_model,
-         "TRAINEE_BASE_URL": params.trainee_url,
-         "TRAINEE_MODEL": params.trainee_model,
-         "SYNTHESIZER_API_KEY": params.api_key,
-         "TRAINEE_API_KEY": params.trainee_api_key,
-         "RPM": params.rpm,
-         "TPM": params.tpm,
-     }
-
-     # Test API connection
-     test_api_connection(
-         env["SYNTHESIZER_BASE_URL"],
-         env["SYNTHESIZER_API_KEY"],
-         env["SYNTHESIZER_MODEL"],
-     )
-     if config["if_trainee_model"]:
-         test_api_connection(
-             env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
-         )
-
-     # Initialize GraphGen
-     graph_gen = init_graph_gen(config, env)
-     graph_gen.clear()
-
-     graph_gen.progress_bar = progress
-
-     try:
-         # Load input data
-         file = config["input_file"]
-         if isinstance(file, list):
-             file = file[0]
-
-         data = []
-
-         if file.endswith(".jsonl"):
-             data_type = "raw"
-             with open(file, "r", encoding="utf-8") as f:
-                 data.extend(json.loads(line) for line in f)
-         elif file.endswith(".json"):
-             data_type = "chunked"
-             with open(file, "r", encoding="utf-8") as f:
-                 data.extend(json.load(f))
-         elif file.endswith(".txt"):
-             # Read the file, then convert it into raw-format chunks of chunk_size characters
-             data_type = "raw"
-             content = ""
-             with open(file, "r", encoding="utf-8") as f:
-                 lines = f.readlines()
-                 for line in lines:
-                     content += line.strip() + " "
-             size = int(config.get("chunk_size", 512))
-             chunks = [content[i : i + size] for i in range(0, len(content), size)]
-             data.extend([{"content": chunk} for chunk in chunks])
-         else:
-             raise ValueError(f"Unsupported file type: {file}")
-
-         # Process the data
-         graph_gen.insert(data, data_type)
-
-         if config["if_trainee_model"]:
-             # Generate quiz
-             graph_gen.quiz(max_samples=config["quiz_samples"])
-
-             # Judge statements
-             graph_gen.judge()
-         else:
-             graph_gen.traverse_strategy.edge_sampling = "random"
-             # Skip judge statements
-             graph_gen.judge(skip=True)
-
-         # Traverse graph
-         graph_gen.traverse(traverse_strategy=graph_gen.traverse_strategy)
-
-         # Save output
-         output_data = graph_gen.qa_storage.data
-         with tempfile.NamedTemporaryFile(
-             mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
-         ) as tmpfile:
-             json.dump(output_data, tmpfile, ensure_ascii=False)
-             output_file = tmpfile.name
-
-         synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
-         trainee_tokens = (
-             sum_tokens(graph_gen.trainee_llm_client)
-             if config["if_trainee_model"]
-             else 0
-         )
-         total_tokens = synthesizer_tokens + trainee_tokens
-
-         data_frame = params.token_counter
-         try:
-             _update_data = [
-                 [data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)]
-             ]
-             new_df = pd.DataFrame(_update_data, columns=data_frame.columns)
-             data_frame = new_df
-
-         except Exception as e:
-             raise gr.Error(f"DataFrame operation error: {str(e)}")
-
-         return output_file, gr.DataFrame(
-             label="Token Stats",
-             headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
-             datatype="str",
-             interactive=False,
-             value=data_frame,
-             visible=True,
-             wrap=True,
-         )
-
-     except Exception as e:  # pylint: disable=broad-except
-         raise gr.Error(f"Error occurred: {str(e)}")
-
-     finally:
-         # Clean up workspace
-         cleanup_workspace(graph_gen.working_dir)
-
-
- with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
-     # Header
-     gr.Image(
-         value=os.path.join(root_dir, "resources", "images", "logo.png"),
-         label="GraphGen Banner",
-         elem_id="banner",
-         interactive=False,
-         container=False,
-         show_download_button=False,
-         show_fullscreen_button=False,
-     )
-     lang_btn = gr.Radio(
-         choices=[
-             ("English", "en"),
-             ("简体中文", "zh"),
-         ],
-         value="en",
-         # label=_("Language"),
-         render=False,
-         container=False,
-         elem_classes=["center-row"],
-     )
-
-     gr.HTML(
-         """
-         <div style="display: flex; gap: 8px; margin-left: auto; align-items: center; justify-content: center;">
-             <a href="https://github.com/open-sciencelab/GraphGen/releases">
-                 <img src="https://img.shields.io/badge/Version-v0.1.0-blue" alt="Version">
-             </a>
-             <a href="https://graphgen-docs.example.com">
-                 <img src="https://img.shields.io/badge/Docs-Latest-brightgreen" alt="Documentation">
-             </a>
-             <a href="https://github.com/open-sciencelab/GraphGen/issues/10">
-                 <img src="https://img.shields.io/github/stars/open-sciencelab/GraphGen?style=social" alt="GitHub Stars">
-             </a>
-             <a href="https://arxiv.org/abs/2505.20416">
-                 <img src="https://img.shields.io/badge/arXiv-pdf-yellow" alt="arXiv">
-             </a>
-         </div>
-         """
-     )
-     with Translate(
-         os.path.join(root_dir, "webui", "translation.json"),
-         lang_btn,
-         placeholder_langs=["en", "zh"],
-         persistant=False,  # True to save the language setting in the browser. Requires gradio >= 5.6.0
-     ):
-         lang_btn.render()
-
-         gr.Markdown(
-             value="# "
-             + _("Title")
-             + "\n\n"
-             + "### [GraphGen](https://github.com/open-sciencelab/GraphGen) "
-             + _("Intro")
-         )
-
-         if_trainee_model = gr.Checkbox(
-             label=_("Use Trainee Model"), value=False, interactive=True
-         )
-
-         with gr.Accordion(label=_("Model Config"), open=False):
-             synthesizer_url = gr.Textbox(
-                 label="Synthesizer URL",
-                 value="https://api.siliconflow.cn/v1",
-                 info=_("Synthesizer URL Info"),
-                 interactive=True,
-             )
-             synthesizer_model = gr.Textbox(
-                 label="Synthesizer Model",
-                 value="Qwen/Qwen2.5-7B-Instruct",
-                 info=_("Synthesizer Model Info"),
-                 interactive=True,
-             )
-             trainee_url = gr.Textbox(
-                 label="Trainee URL",
-                 value="https://api.siliconflow.cn/v1",
-                 info=_("Trainee URL Info"),
-                 interactive=True,
-                 visible=if_trainee_model.value is True,
-             )
-             trainee_model = gr.Textbox(
-                 label="Trainee Model",
-                 value="Qwen/Qwen2.5-7B-Instruct",
-                 info=_("Trainee Model Info"),
-                 interactive=True,
-                 visible=if_trainee_model.value is True,
-             )
-             trainee_api_key = gr.Textbox(
-                 label=_("SiliconFlow Token for Trainee Model"),
-                 type="password",
-                 value="",
-                 info="https://cloud.siliconflow.cn/account/ak",
-                 visible=if_trainee_model.value is True,
-             )
-
-         with gr.Accordion(label=_("Generation Config"), open=False):
-             chunk_size = gr.Slider(
-                 label="Chunk Size",
-                 minimum=256,
-                 maximum=4096,
-                 value=512,
-                 step=256,
-                 interactive=True,
-             )
-             tokenizer = gr.Textbox(
-                 label="Tokenizer", value="cl100k_base", interactive=True
-             )
-             qa_form = gr.Radio(
-                 choices=["atomic", "multi_hop", "aggregated"],
-                 label="QA Form",
-                 value="aggregated",
-                 interactive=True,
-             )
-             quiz_samples = gr.Number(
-                 label="Quiz Samples",
-                 value=2,
-                 minimum=1,
-                 interactive=True,
-                 visible=if_trainee_model.value is True,
-             )
-             bidirectional = gr.Checkbox(
-                 label="Bidirectional", value=True, interactive=True
-             )
-
-             expand_method = gr.Radio(
-                 choices=["max_width", "max_tokens"],
-                 label="Expand Method",
-                 value="max_tokens",
-                 interactive=True,
-             )
-             max_extra_edges = gr.Slider(
-                 minimum=1,
-                 maximum=10,
-                 value=5,
-                 label="Max Extra Edges",
-                 step=1,
-                 interactive=True,
-                 visible=expand_method.value == "max_width",
-             )
-             max_tokens = gr.Slider(
-                 minimum=64,
-                 maximum=1024,
-                 value=256,
-                 label="Max Tokens",
-                 step=64,
-                 interactive=True,
-                 visible=(expand_method.value != "max_width"),
-             )
-
-             max_depth = gr.Slider(
-                 minimum=1,
-                 maximum=5,
-                 value=2,
-                 label="Max Depth",
-                 step=1,
-                 interactive=True,
-             )
-             edge_sampling = gr.Radio(
-                 choices=["max_loss", "min_loss", "random"],
-                 label="Edge Sampling",
-                 value="max_loss",
-                 interactive=True,
-                 visible=if_trainee_model.value is True,
-             )
-             isolated_node_strategy = gr.Radio(
-                 choices=["add", "ignore"],
-                 label="Isolated Node Strategy",
-                 value="ignore",
-                 interactive=True,
-             )
-             loss_strategy = gr.Radio(
-                 choices=["only_edge", "both"],
-                 label="Loss Strategy",
-                 value="only_edge",
-                 interactive=True,
-             )
-
-         with gr.Row(equal_height=True):
-             with gr.Column(scale=3):
-                 api_key = gr.Textbox(
-                     label=_("SiliconFlow Token"),
-                     type="password",
-                     value="",
-                     info="https://cloud.siliconflow.cn/account/ak",
-                 )
-             with gr.Column(scale=1):
-                 test_connection_btn = gr.Button(_("Test Connection"))
-
-         with gr.Blocks():
-             with gr.Row(equal_height=True):
-                 with gr.Column():
-                     rpm = gr.Slider(
-                         label="RPM",
-                         minimum=10,
-                         maximum=10000,
-                         value=1000,
-                         step=100,
-                         interactive=True,
-                         visible=True,
-                     )
-                 with gr.Column():
-                     tpm = gr.Slider(
-                         label="TPM",
-                         minimum=5000,
-                         maximum=5000000,
-                         value=50000,
-                         step=1000,
-                         interactive=True,
-                         visible=True,
-                     )
-
-         with gr.Blocks():
-             with gr.Row(equal_height=True):
-                 with gr.Column(scale=1):
-                     upload_file = gr.File(
-                         label=_("Upload File"),
-                         file_count="single",
-                         file_types=[".txt", ".json", ".jsonl"],
-                         interactive=True,
-                     )
-                     examples_dir = os.path.join(root_dir, "webui", "examples")
-                     gr.Examples(
-                         examples=[
-                             [os.path.join(examples_dir, "txt_demo.txt")],
-                             [os.path.join(examples_dir, "raw_demo.jsonl")],
-                             [os.path.join(examples_dir, "chunked_demo.json")],
-                         ],
-                         inputs=upload_file,
-                         label=_("Example Files"),
-                         examples_per_page=3,
-                     )
-                 with gr.Column(scale=1):
-                     output = gr.File(
-                         label="Output(See Github FAQ)",
-                         file_count="single",
-                         interactive=False,
-                     )
-
-         with gr.Blocks():
-             token_counter = gr.DataFrame(
-                 label="Token Stats",
-                 headers=[
-                     "Source Text Token Count",
-                     "Estimated Token Usage",
-                     "Token Used",
-                 ],
-                 datatype="str",
-                 interactive=False,
-                 visible=False,
-                 wrap=True,
-             )
-
-         submit_btn = gr.Button(_("Run GraphGen"))
-
-         # Test Connection
-         test_connection_btn.click(
-             test_api_connection,
-             inputs=[synthesizer_url, api_key, synthesizer_model],
-             outputs=[],
-         )
-
-         if if_trainee_model.value:
-             test_connection_btn.click(
-                 test_api_connection,
-                 inputs=[trainee_url, api_key, trainee_model],
-                 outputs=[],
-             )
-
-         expand_method.change(
-             lambda method: (
-                 gr.update(visible=method == "max_width"),
-                 gr.update(visible=method != "max_width"),
-             ),
-             inputs=expand_method,
-             outputs=[max_extra_edges, max_tokens],
-         )
-
-         if_trainee_model.change(
-             lambda use_trainee: [gr.update(visible=use_trainee)] * 5,
-             inputs=if_trainee_model,
-             outputs=[
-                 trainee_url,
-                 trainee_model,
-                 quiz_samples,
-                 edge_sampling,
-                 trainee_api_key,
-             ],
-         )
-
-         upload_file.change(
-             lambda x: (gr.update(visible=True)),
-             inputs=[upload_file],
-             outputs=[token_counter],
-         ).then(
-             count_tokens,
-             inputs=[upload_file, tokenizer, token_counter],
-             outputs=[token_counter],
-         )
-
-         # run GraphGen
-         submit_btn.click(
-             lambda x: (gr.update(visible=False)),
-             inputs=[token_counter],
-             outputs=[token_counter],
-         )
-
-         submit_btn.click(
-             lambda *args: run_graphgen(
-                 GraphGenParams(
-                     if_trainee_model=args[0],
-                     input_file=args[1],
-                     tokenizer=args[2],
-                     qa_form=args[3],
-                     bidirectional=args[4],
-                     expand_method=args[5],
-                     max_extra_edges=args[6],
-                     max_tokens=args[7],
-                     max_depth=args[8],
-                     edge_sampling=args[9],
-                     isolated_node_strategy=args[10],
-                     loss_strategy=args[11],
-                     synthesizer_url=args[12],
-                     synthesizer_model=args[13],
-                     trainee_model=args[14],
-                     api_key=args[15],
-                     chunk_size=args[16],
-                     rpm=args[17],
-                     tpm=args[18],
-                     quiz_samples=args[19],
-                     trainee_url=args[20],
-                     trainee_api_key=args[21],
-                     token_counter=args[22],
-                 )
-             ),
-             inputs=[
-                 if_trainee_model,
-                 upload_file,
-                 tokenizer,
-                 qa_form,
-                 bidirectional,
-                 expand_method,
-                 max_extra_edges,
-                 max_tokens,
-                 max_depth,
-                 edge_sampling,
-                 isolated_node_strategy,
-                 loss_strategy,
-                 synthesizer_url,
-                 synthesizer_model,
-                 trainee_model,
-                 api_key,
-                 chunk_size,
-                 rpm,
-                 tpm,
-                 quiz_samples,
-                 trainee_url,
-                 trainee_api_key,
-                 token_counter,
-             ],
-             outputs=[output, token_counter],
-         )
-
- if __name__ == "__main__":
-     demo.queue(api_open=False, default_concurrency_limit=2)
-     demo.launch(server_name="0.0.0.0")
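
The `.txt` branch in `run_graphgen` above is the only place where the demo does its own chunking. The following standalone sketch reproduces that logic for reference; `chunk_text` is an illustrative name, not a function from this repo:

```python
# Sketch of the .txt preprocessing in run_graphgen: lines are stripped and
# space-joined, then sliced into fixed-size, non-overlapping character windows.
def chunk_text(path: str, size: int = 512) -> list[dict]:
    with open(path, "r", encoding="utf-8") as f:
        content = " ".join(line.strip() for line in f)
    # Windows are counted in characters, while chunk sizes elsewhere in the
    # pipeline are token-oriented, so a window can split a word mid-way.
    return [{"content": content[i : i + size]} for i in range(0, len(content), size)]
```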
hf-repo/graphgen/__init__.py DELETED
File without changes
hf-repo/graphgen/configs/README.md DELETED
@@ -1 +0,0 @@
- # Configs for GraphGen
hf-repo/graphgen/configs/aggregated_config.yaml DELETED
@@ -1,21 +0,0 @@
- input_data_type: raw # raw, chunked
- input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
- output_data_type: aggregated # atomic, aggregated, multi_hop, cot
- output_data_format: ChatML # Alpaca, Sharegpt, ChatML
- tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
- search: # web search configuration
-   enabled: false # whether to enable web search
-   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
- quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
-   enabled: true
-   quiz_samples: 2 # number of quiz samples to generate
-   re_judge: false # whether to re-judge the existing quiz samples
- traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
-   bidirectional: true # whether to traverse the graph in both directions
-   edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-   expand_method: max_width # expand method, support: max_width, max_tokens
-   isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-   max_depth: 5 # maximum depth for graph traversal
-   max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
-   max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-   loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
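
As a usage note, a config like the one above is consumed by `generate.py` and `graphgen.py` later in this diff. A minimal sketch of that flow, assuming the package layout shown here (the paths are placeholders, and the `SYNTHESIZER_*`/`TRAINEE_*` environment variables must be set):

```python
# Minimal pipeline sketch based on generate.py / graphgen.py below.
import yaml

from graphgen.graphgen import GraphGen

with open("graphgen/configs/aggregated_config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# GraphGen.__post_init__ builds TraverseStrategy(**config["traverse_strategy"])
graph_gen = GraphGen(working_dir="cache", config=config)
graph_gen.insert()    # chunk input_file and extract the knowledge graph
graph_gen.quiz()      # only meaningful when quiz_and_judge_strategy.enabled is true
graph_gen.judge()
graph_gen.traverse()  # emit QA pairs in output_data_format
```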
hf-repo/graphgen/configs/atomic_config.yaml DELETED
@@ -1,21 +0,0 @@
- input_data_type: raw # raw, chunked
- input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
- output_data_type: atomic # atomic, aggregated, multi_hop, cot
- output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
- tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
- search: # web search configuration
-   enabled: false # whether to enable web search
-   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
- quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
-   enabled: true
-   quiz_samples: 2 # number of quiz samples to generate
-   re_judge: false # whether to re-judge the existing quiz samples
- traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
-   bidirectional: true # whether to traverse the graph in both directions
-   edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-   expand_method: max_width # expand method, support: max_width, max_tokens
-   isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-   max_depth: 3 # maximum depth for graph traversal
-   max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
-   max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-   loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
hf-repo/graphgen/configs/cot_config.yaml DELETED
@@ -1,13 +0,0 @@
- input_data_type: raw # raw, chunked
- input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
- output_data_type: cot # atomic, aggregated, multi_hop, cot
- output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
- tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
- search: # web search configuration
-   enabled: false # whether to enable web search
-   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
- method_params:
-   method: leiden
-   max_size: 20 # Maximum size of communities
-   use_lcc: false
-   random_seed: 42
hf-repo/graphgen/configs/multi_hop_config.yaml DELETED
@@ -1,21 +0,0 @@
- input_data_type: raw # raw, chunked
- input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
- output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
- output_data_format: ChatML # Alpaca, Sharegpt, ChatML
- tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
- search: # web search configuration
-   enabled: false # whether to enable web search
-   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
- quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
-   enabled: true
-   quiz_samples: 2 # number of quiz samples to generate
-   re_judge: false # whether to re-judge the existing quiz samples
- traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
-   bidirectional: true # whether to traverse the graph in both directions
-   edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-   expand_method: max_width # expand method, support: max_width, max_tokens
-   isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-   max_depth: 1 # maximum depth for graph traversal
-   max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
-   max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-   loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
hf-repo/graphgen/evaluate.py DELETED
@@ -1,142 +0,0 @@
- """Evaluate the quality of the generated text using various metrics"""
-
- import os
- import json
- import argparse
- import pandas as pd
- from dotenv import load_dotenv
- from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, TextPair, UniEvaluator
- from .utils import logger, set_logger
-
- sys_path = os.path.abspath(os.path.dirname(__file__))
- set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log"))
-
- load_dotenv()
-
-
- def evaluate_length(corpus, tokenizer_name):
-     length_evaluator = LengthEvaluator(
-         tokenizer_name=tokenizer_name
-     )
-     logger.info("Length evaluator loaded")
-     scores = length_evaluator.get_average_score(corpus)
-     logger.info("Length scores: %s", scores)
-     return scores
-
-
- def evaluate_mtld(corpus):
-     mtld_evaluator = MTLDEvaluator()
-     logger.info("MTLD evaluator loaded")
-     scores = mtld_evaluator.get_average_score(corpus)
-     logger.info("MTLD scores: %s", scores)
-     min_max_scores = mtld_evaluator.get_min_max_score(corpus)
-     logger.info("MTLD min max scores: %s", min_max_scores)
-     return scores, min_max_scores
-
-
- def evaluate_reward(corpus, reward_model_names):
-     scores = []
-     for reward_name in reward_model_names:
-         reward_evaluator = RewardEvaluator(
-             reward_name=reward_name
-         )
-         logger.info("Loaded reward model: %s", reward_name)
-         average_score = reward_evaluator.get_average_score(corpus)
-         logger.info("%s scores: %s", reward_name, average_score)
-         min_max_scores = reward_evaluator.get_min_max_score(corpus)
-         logger.info("%s min max scores: %s", reward_name, min_max_scores)
-         scores.append({
-             'reward_name': reward_name.split('/')[-1],
-             'score': average_score,
-             'min_max_scores': min_max_scores
-         })
-         del reward_evaluator
-         clean_gpu_cache()
-     return scores
-
-
- def evaluate_uni(corpus, uni_model_name):
-     uni_evaluator = UniEvaluator(
-         model_name=uni_model_name
-     )
-     logger.info("Uni evaluator loaded with model %s", uni_model_name)
-     uni_scores = uni_evaluator.get_average_score(corpus)
-     for key, value in uni_scores.items():
-         logger.info("Uni %s scores: %s", key, value)
-     min_max_scores = uni_evaluator.get_min_max_score(corpus)
-     for key, value in min_max_scores.items():
-         logger.info("Uni %s min max scores: %s", key, value)
-     del uni_evaluator
-     clean_gpu_cache()
-     return (uni_scores['naturalness'], uni_scores['coherence'], uni_scores['understandability'],
-             min_max_scores['naturalness'], min_max_scores['coherence'], min_max_scores['understandability'])
-
-
- def clean_gpu_cache():
-     import torch
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()
-
-
- if __name__ == '__main__':
-     import torch.multiprocessing as mp
-     parser = argparse.ArgumentParser()
-
-     parser.add_argument('--folder', type=str, default='cache/data', help='folder to load data')
-     parser.add_argument('--output', type=str, default='cache/output', help='path to save output')
-
-     parser.add_argument('--tokenizer', type=str, default='cl100k_base', help='tokenizer name')
-     parser.add_argument('--reward', type=str, default='OpenAssistant/reward-model-deberta-v3-large-v2',
-                         help='Comma-separated list of reward models')
-     parser.add_argument('--uni', type=str, default='MingZhong/unieval-sum', help='uni model name')
-
-     args = parser.parse_args()
-
-     if not os.path.exists(args.folder):
-         raise ValueError(f"Folder {args.folder} does not exist")
-
-     if not os.path.exists(args.output):
-         os.makedirs(args.output)
-
-     reward_models = args.reward.split(',')
-
-     results = []
-
-     logger.info("Data loaded from %s", args.folder)
-     mp.set_start_method('spawn')
-
-     for file in os.listdir(args.folder):
-         if file.endswith('.json'):
-             logger.info("Processing %s", file)
-             with open(os.path.join(args.folder, file), 'r', encoding='utf-8') as f:
-                 data = json.load(f)
-                 data = [TextPair(
-                     question=data[key]['question'],
-                     answer=data[key]['answer']
-                 ) for key in data]
-
-             length_scores = evaluate_length(data, args.tokenizer)
-             mtld_scores, min_max_mtld_scores = evaluate_mtld(data)
-             reward_scores = evaluate_reward(data, reward_models)
-             uni_naturalness_scores, uni_coherence_scores, uni_understandability_scores, \
-                 min_max_uni_naturalness_scores, min_max_uni_coherence_scores, min_max_uni_understandability_scores \
-                 = evaluate_uni(data, args.uni)
-
-             result = {
-                 'file': file,
-                 'number': len(data),
-                 'length': length_scores,
-                 'mtld': mtld_scores,
-                 'mtld_min_max': min_max_mtld_scores,
-                 'uni_naturalness': uni_naturalness_scores,
-                 'uni_coherence': uni_coherence_scores,
-                 'uni_understandability': uni_understandability_scores,
-                 'uni_naturalness_min_max': min_max_uni_naturalness_scores,
-                 'uni_coherence_min_max': min_max_uni_coherence_scores,
-                 'uni_understandability_min_max': min_max_uni_understandability_scores
-             }
-             for reward_score in reward_scores:
-                 result[reward_score['reward_name']] = reward_score['score']
-                 result[f"{reward_score['reward_name']}_min_max"] = reward_score['min_max_scores']
-
-             results.append(result)
-
-     results = pd.DataFrame(results)
-     results.to_csv(os.path.join(args.output, 'evaluation.csv'), index=False)
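
The `MTLDEvaluator` used above lives outside this 50-file view. For context, the metric it is named after, MTLD (Measure of Textual Lexical Diversity, McCarthy & Jarvis 2010), can be sketched as follows; this is the textbook formulation, not necessarily the repo's implementation:

```python
# Textbook MTLD sketch: count "factors" (segments whose type-token ratio
# falls to a threshold, conventionally 0.72) and divide token count by them.
def mtld_one_pass(tokens: list[str], ttr_threshold: float = 0.72) -> float:
    factors, types, count = 0.0, set(), 0
    for tok in tokens:
        count += 1
        types.add(tok.lower())
        if len(types) / count <= ttr_threshold:
            factors += 1.0       # a full factor is complete; reset the window
            types, count = set(), 0
    if count > 0:                # partial credit for the leftover segment
        ttr = len(types) / count
        factors += (1.0 - ttr) / (1.0 - ttr_threshold)
    return len(tokens) / factors if factors else float(len(tokens))


def mtld(tokens: list[str]) -> float:
    # The original paper averages a forward and a reversed pass.
    return (mtld_one_pass(tokens) + mtld_one_pass(list(reversed(tokens)))) / 2
```

Higher MTLD means the text sustains a high type-token ratio for longer, i.e. richer vocabulary.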
hf-repo/graphgen/generate.py DELETED
@@ -1,103 +0,0 @@
- import argparse
- import os
- import time
- from importlib.resources import files
-
- import yaml
- from dotenv import load_dotenv
-
- from .graphgen import GraphGen
- from .utils import logger, set_logger
-
- sys_path = os.path.abspath(os.path.dirname(__file__))
-
- load_dotenv()
-
-
- def set_working_dir(folder):
-     os.makedirs(folder, exist_ok=True)
-     os.makedirs(os.path.join(folder, "data", "graphgen"), exist_ok=True)
-     os.makedirs(os.path.join(folder, "logs"), exist_ok=True)
-
-
- def save_config(config_path, global_config):
-     if not os.path.exists(os.path.dirname(config_path)):
-         os.makedirs(os.path.dirname(config_path))
-     with open(config_path, "w", encoding="utf-8") as config_file:
-         yaml.dump(
-             global_config, config_file, default_flow_style=False, allow_unicode=True
-         )
-
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "--config_file",
-         help="Config parameters for GraphGen.",
-         default=files("graphgen").joinpath("configs", "aggregated_config.yaml"),
-         type=str,
-     )
-     parser.add_argument(
-         "--output_dir",
-         help="Output directory for GraphGen.",
-         default=sys_path,
-         required=True,
-         type=str,
-     )
-
-     args = parser.parse_args()
-
-     working_dir = args.output_dir
-     set_working_dir(working_dir)
-
-     with open(args.config_file, "r", encoding="utf-8") as f:
-         config = yaml.load(f, Loader=yaml.FullLoader)
-
-     output_data_type = config["output_data_type"]
-     unique_id = int(time.time())
-     set_logger(
-         os.path.join(
-             working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
-         ),
-         if_stream=True,
-     )
-     logger.info(
-         "GraphGen with unique ID %s logging to %s",
-         unique_id,
-         os.path.join(
-             working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
-         ),
-     )
-
-     graph_gen = GraphGen(working_dir=working_dir, unique_id=unique_id, config=config)
-
-     graph_gen.insert()
-
-     if config["search"]["enabled"]:
-         graph_gen.search()
-
-     # Use pipeline according to the output data type
-     if output_data_type in ["atomic", "aggregated", "multi_hop"]:
-         if "quiz_and_judge_strategy" in config and config[
-             "quiz_and_judge_strategy"
-         ].get("enabled", False):
-             graph_gen.quiz()
-             graph_gen.judge()
-         else:
-             logger.warning(
-                 "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
-             )
-             graph_gen.traverse_strategy.edge_sampling = "random"
-         graph_gen.traverse()
-     elif output_data_type == "cot":
-         graph_gen.generate_reasoning(method_params=config["method_params"])
-     else:
-         raise ValueError(f"Unsupported output data type: {output_data_type}")
-
-     output_path = os.path.join(working_dir, "data", "graphgen", str(unique_id))
-     save_config(os.path.join(output_path, f"config-{unique_id}.yaml"), config)
-     logger.info("GraphGen completed successfully. Data saved to %s", output_path)
-
-
- if __name__ == "__main__":
-     main()
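
A hypothetical end-to-end invocation of this entry point, for orientation only: the model name, key, and paths below are placeholders, and the environment variable names are the ones read in `graphgen.py`'s `__post_init__` further down (the `TRAINEE_*` variables matter when `quiz_and_judge_strategy` is enabled):

```python
# Hypothetical driver; values are placeholders, not defaults of the repo.
import os
import sys

os.environ["SYNTHESIZER_MODEL"] = "Qwen/Qwen2.5-7B-Instruct"
os.environ["SYNTHESIZER_BASE_URL"] = "https://api.siliconflow.cn/v1"
os.environ["SYNTHESIZER_API_KEY"] = "sk-..."  # placeholder

# generate.py reads argparse flags, so patch argv before calling main().
sys.argv = [
    "graphgen.generate",
    "--config_file", "graphgen/configs/aggregated_config.yaml",
    "--output_dir", "cache",
]

from graphgen.generate import main  # noqa: E402  (relative imports need the package)

main()
```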
hf-repo/graphgen/graphgen.py DELETED
@@ -1,395 +0,0 @@
- import asyncio
- import os
- import time
- from dataclasses import dataclass, field
- from typing import Dict, List, Union, cast
-
- import gradio as gr
- from tqdm.asyncio import tqdm as tqdm_async
-
- from .models import (
-     Chunk,
-     JsonKVStorage,
-     JsonListStorage,
-     NetworkXStorage,
-     OpenAIModel,
-     Tokenizer,
-     TraverseStrategy,
- )
- from .models.storage.base_storage import StorageNameSpace
- from .operators import (
-     extract_kg,
-     generate_cot,
-     judge_statement,
-     quiz,
-     search_all,
-     traverse_graph_atomically,
-     traverse_graph_by_edge,
-     traverse_graph_for_multi_hop,
- )
- from .utils import (
-     compute_content_hash,
-     create_event_loop,
-     format_generation_results,
-     logger,
-     read_file,
- )
-
- sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-
-
- @dataclass
- class GraphGen:
-     unique_id: int = int(time.time())
-     working_dir: str = os.path.join(sys_path, "cache")
-     config: Dict = field(default_factory=dict)
-
-     # llm
-     tokenizer_instance: Tokenizer = None
-     synthesizer_llm_client: OpenAIModel = None
-     trainee_llm_client: OpenAIModel = None
-
-     # text chunking
-     # TODO: make it configurable
-     chunk_size: int = 1024
-     chunk_overlap_size: int = 100
-
-     # search
-     search_config: dict = field(
-         default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
-     )
-
-     # traversal
-     traverse_strategy: TraverseStrategy = None
-
-     # webui
-     progress_bar: gr.Progress = None
-
-     def __post_init__(self):
-         self.tokenizer_instance: Tokenizer = Tokenizer(
-             model_name=self.config["tokenizer"]
-         )
-         self.synthesizer_llm_client: OpenAIModel = OpenAIModel(
-             model_name=os.getenv("SYNTHESIZER_MODEL"),
-             api_key=os.getenv("SYNTHESIZER_API_KEY"),
-             base_url=os.getenv("SYNTHESIZER_BASE_URL"),
-             tokenizer_instance=self.tokenizer_instance,
-         )
-         self.trainee_llm_client: OpenAIModel = OpenAIModel(
-             model_name=os.getenv("TRAINEE_MODEL"),
-             api_key=os.getenv("TRAINEE_API_KEY"),
-             base_url=os.getenv("TRAINEE_BASE_URL"),
-             tokenizer_instance=self.tokenizer_instance,
-         )
-         self.search_config = self.config["search"]
-
-         if "traverse_strategy" in self.config:
-             self.traverse_strategy = TraverseStrategy(
-                 **self.config["traverse_strategy"]
-             )
-
-         self.full_docs_storage: JsonKVStorage = JsonKVStorage(
-             self.working_dir, namespace="full_docs"
-         )
-         self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
-             self.working_dir, namespace="text_chunks"
-         )
-         self.graph_storage: NetworkXStorage = NetworkXStorage(
-             self.working_dir, namespace="graph"
-         )
-         self.search_storage: JsonKVStorage = JsonKVStorage(
-             self.working_dir, namespace="search"
-         )
-         self.rephrase_storage: JsonKVStorage = JsonKVStorage(
-             self.working_dir, namespace="rephrase"
-         )
-         self.qa_storage: JsonListStorage = JsonListStorage(
-             os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)),
-             namespace=f"qa-{self.unique_id}",
-         )
-
-     async def async_split_chunks(
-         self, data: List[Union[List, Dict]], data_type: str
-     ) -> dict:
-         # TODO: configurable whether to use coreference resolution
-         if len(data) == 0:
-             return {}
-
-         inserting_chunks = {}
-         if data_type == "raw":
-             assert isinstance(data, list) and isinstance(data[0], dict)
-             # compute hash for each document
-             new_docs = {
-                 compute_content_hash(doc["content"], prefix="doc-"): {
-                     "content": doc["content"]
-                 }
-                 for doc in data
-             }
-             _add_doc_keys = await self.full_docs_storage.filter_keys(
-                 list(new_docs.keys())
-             )
-             new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
-             if len(new_docs) == 0:
-                 logger.warning("All docs are already in the storage")
-                 return {}
-             logger.info("[New Docs] inserting %d docs", len(new_docs))
-
-             cur_index = 1
-             doc_number = len(new_docs)
-             async for doc_key, doc in tqdm_async(
-                 new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
-             ):
-                 chunks = {
-                     compute_content_hash(dp["content"], prefix="chunk-"): {
-                         **dp,
-                         "full_doc_id": doc_key,
-                     }
-                     for dp in self.tokenizer_instance.chunk_by_token_size(
-                         doc["content"], self.chunk_overlap_size, self.chunk_size
-                     )
-                 }
-                 inserting_chunks.update(chunks)
-
-                 if self.progress_bar is not None:
-                     self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
-                     cur_index += 1
-
-             _add_chunk_keys = await self.text_chunks_storage.filter_keys(
-                 list(inserting_chunks.keys())
-             )
-             inserting_chunks = {
-                 k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-             }
-         elif data_type == "chunked":
-             assert isinstance(data, list) and isinstance(data[0], list)
-             new_docs = {
-                 compute_content_hash("".join(chunk["content"]), prefix="doc-"): {
-                     "content": "".join(chunk["content"])
-                 }
-                 for doc in data
-                 for chunk in doc
-             }
-             _add_doc_keys = await self.full_docs_storage.filter_keys(
-                 list(new_docs.keys())
-             )
-             new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
-             if len(new_docs) == 0:
-                 logger.warning("All docs are already in the storage")
-                 return {}
-             logger.info("[New Docs] inserting %d docs", len(new_docs))
-             async for doc in tqdm_async(
-                 data, desc="[1/4]Chunking documents", unit="doc"
-             ):
-                 doc_str = "".join([chunk["content"] for chunk in doc])
-                 for chunk in doc:
-                     chunk_key = compute_content_hash(chunk["content"], prefix="chunk-")
-                     inserting_chunks[chunk_key] = {
-                         **chunk,
-                         "full_doc_id": compute_content_hash(doc_str, prefix="doc-"),
-                     }
-             _add_chunk_keys = await self.text_chunks_storage.filter_keys(
-                 list(inserting_chunks.keys())
-             )
-             inserting_chunks = {
-                 k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-             }
-         else:
-             raise ValueError(f"Unknown data type: {data_type}")
-
-         await self.full_docs_storage.upsert(new_docs)
-         await self.text_chunks_storage.upsert(inserting_chunks)
-
-         return inserting_chunks
-
-     def insert(self):
-         loop = create_event_loop()
-         loop.run_until_complete(self.async_insert())
-
-     async def async_insert(self):
-         """
-         insert chunks into the graph
-         """
-
-         input_file = self.config["input_file"]
-         data_type = self.config["input_data_type"]
-         data = read_file(input_file)
-
-         inserting_chunks = await self.async_split_chunks(data, data_type)
-
-         if len(inserting_chunks) == 0:
-             logger.warning("All chunks are already in the storage")
-             return
-         logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
-
-         logger.info("[Entity and Relation Extraction]...")
-         _add_entities_and_relations = await extract_kg(
-             llm_client=self.synthesizer_llm_client,
-             kg_instance=self.graph_storage,
-             tokenizer_instance=self.tokenizer_instance,
-             chunks=[
-                 Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
-             ],
-             progress_bar=self.progress_bar,
-         )
-         if not _add_entities_and_relations:
-             logger.warning("No entities or relations extracted")
-             return
-
-         await self._insert_done()
-
-     async def _insert_done(self):
-         tasks = []
-         for storage_instance in [
-             self.full_docs_storage,
-             self.text_chunks_storage,
-             self.graph_storage,
-             self.search_storage,
-         ]:
-             if storage_instance is None:
-                 continue
-             tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
-         await asyncio.gather(*tasks)
-
-     def search(self):
-         loop = create_event_loop()
-         loop.run_until_complete(self.async_search())
-
-     async def async_search(self):
-         logger.info(
-             "Search is %s", "enabled" if self.search_config["enabled"] else "disabled"
-         )
-         if self.search_config["enabled"]:
-             logger.info(
-                 "[Search] %s ...", ", ".join(self.search_config["search_types"])
-             )
-             all_nodes = await self.graph_storage.get_all_nodes()
-             all_nodes_names = [node[0] for node in all_nodes]
-             new_search_entities = await self.full_docs_storage.filter_keys(
-                 all_nodes_names
-             )
-             logger.info(
-                 "[Search] Found %d entities to search", len(new_search_entities)
-             )
-             _add_search_data = await search_all(
-                 search_types=self.search_config["search_types"],
-                 search_entities=new_search_entities,
-             )
-             if _add_search_data:
-                 await self.search_storage.upsert(_add_search_data)
-                 logger.info("[Search] %d entities searched", len(_add_search_data))
-
-                 # Format search results for inserting
-                 search_results = []
-                 for _, search_data in _add_search_data.items():
-                     search_results.extend(
-                         [
-                             {"content": search_data[key]}
-                             for key in list(search_data.keys())
-                         ]
-                     )
-                 # TODO: fix insert after search
-                 await self.async_insert()
-
-     def quiz(self):
-         loop = create_event_loop()
-         loop.run_until_complete(self.async_quiz())
-
-     async def async_quiz(self):
-         max_samples = self.config["quiz_and_judge_strategy"]["quiz_samples"]
-         await quiz(
-             self.synthesizer_llm_client,
-             self.graph_storage,
-             self.rephrase_storage,
-             max_samples,
-         )
-         await self.rephrase_storage.index_done_callback()
-
-     def judge(self):
-         loop = create_event_loop()
-         loop.run_until_complete(self.async_judge())
-
-     async def async_judge(self):
-         re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
-         _update_relations = await judge_statement(
-             self.trainee_llm_client,
-             self.graph_storage,
-             self.rephrase_storage,
-             re_judge,
-         )
-         await _update_relations.index_done_callback()
-
-     def traverse(self):
-         loop = create_event_loop()
-         loop.run_until_complete(self.async_traverse())
-
-     async def async_traverse(self):
-         output_data_type = self.config["output_data_type"]
-
-         if output_data_type == "atomic":
-             results = await traverse_graph_atomically(
-                 self.synthesizer_llm_client,
-                 self.tokenizer_instance,
-                 self.graph_storage,
-                 self.traverse_strategy,
-                 self.text_chunks_storage,
-                 self.progress_bar,
-             )
-         elif output_data_type == "multi_hop":
-             results = await traverse_graph_for_multi_hop(
-                 self.synthesizer_llm_client,
-                 self.tokenizer_instance,
-                 self.graph_storage,
-                 self.traverse_strategy,
-                 self.text_chunks_storage,
-                 self.progress_bar,
345
- )
346
- elif output_data_type == "aggregated":
347
- results = await traverse_graph_by_edge(
348
- self.synthesizer_llm_client,
349
- self.tokenizer_instance,
350
- self.graph_storage,
351
- self.traverse_strategy,
352
- self.text_chunks_storage,
353
- self.progress_bar,
354
- )
355
- else:
356
- raise ValueError(f"Unknown qa_form: {output_data_type}")
357
-
358
- results = format_generation_results(
359
- results, output_data_format=self.config["output_data_format"]
360
- )
361
-
362
- await self.qa_storage.upsert(results)
363
- await self.qa_storage.index_done_callback()
364
-
365
- def generate_reasoning(self, method_params):
366
- loop = create_event_loop()
367
- loop.run_until_complete(self.async_generate_reasoning(method_params))
368
-
369
- async def async_generate_reasoning(self, method_params):
370
- results = await generate_cot(
371
- self.graph_storage,
372
- self.synthesizer_llm_client,
373
- method_params=method_params,
374
- )
375
-
376
- results = format_generation_results(
377
- results, output_data_format=self.config["output_data_format"]
378
- )
379
-
380
- await self.qa_storage.upsert(results)
381
- await self.qa_storage.index_done_callback()
382
-
383
- def clear(self):
384
- loop = create_event_loop()
385
- loop.run_until_complete(self.async_clear())
386
-
387
- async def async_clear(self):
388
- await self.full_docs_storage.drop()
389
- await self.text_chunks_storage.drop()
390
- await self.search_storage.drop()
391
- await self.graph_storage.clear()
392
- await self.rephrase_storage.drop()
393
- await self.qa_storage.drop()
394
-
395
- logger.info("All caches are cleared")
hf-repo/graphgen/judge.py DELETED
@@ -1,60 +0,0 @@
- import os
- import argparse
- import asyncio
- from dotenv import load_dotenv
-
- from .models import NetworkXStorage, JsonKVStorage, OpenAIModel
- from .operators import judge_statement
-
- sys_path = os.path.abspath(os.path.dirname(__file__))
-
- load_dotenv()
-
- def calculate_average_loss(graph: NetworkXStorage):
-     """
-     Calculate the average loss of the graph.
-
-     :param graph: NetworkXStorage
-     :return: float
-     """
-     edges = asyncio.run(graph.get_all_edges())
-     total_loss = 0
-     for edge in edges:
-         total_loss += edge[2]['loss']
-     return total_loss / len(edges)
-
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser()
-     parser.add_argument('--input', type=str, default=os.path.join(sys_path, "cache"), help='path to load input graph')
-     parser.add_argument('--output', type=str, default='cache/output/new_graph.graphml', help='path to save output')
-
-     args = parser.parse_args()
-
-     llm_client = OpenAIModel(
-         model_name=os.getenv("TRAINEE_MODEL"),
-         api_key=os.getenv("TRAINEE_API_KEY"),
-         base_url=os.getenv("TRAINEE_BASE_URL")
-     )
-
-     graph_storage = NetworkXStorage(
-         args.input,
-         namespace="graph"
-     )
-     average_loss = calculate_average_loss(graph_storage)
-     print(f"Average loss of the graph: {average_loss}")
-
-     rephrase_storage = JsonKVStorage(
-         os.path.join(sys_path, "cache"),
-         namespace="rephrase"
-     )
-
-     new_graph = asyncio.run(judge_statement(llm_client, graph_storage, rephrase_storage, re_judge=True))
-
-     graph_file = asyncio.run(graph_storage.get_graph())
-
-     new_graph.write_nx_graph(graph_file, args.output)
-
-     average_loss = calculate_average_loss(new_graph)
-     print(f"Average loss of the graph: {average_loss}")
hf-repo/graphgen/models/__init__.py DELETED
@@ -1,45 +0,0 @@
- from .community.community_detector import CommunityDetector
- from .evaluate.length_evaluator import LengthEvaluator
- from .evaluate.mtld_evaluator import MTLDEvaluator
- from .evaluate.reward_evaluator import RewardEvaluator
- from .evaluate.uni_evaluator import UniEvaluator
- from .llm.openai_model import OpenAIModel
- from .llm.tokenizer import Tokenizer
- from .llm.topk_token_model import Token, TopkTokenModel
- from .search.db.uniprot_search import UniProtSearch
- from .search.kg.wiki_search import WikiSearch
- from .search.web.bing_search import BingSearch
- from .search.web.google_search import GoogleSearch
- from .storage.json_storage import JsonKVStorage, JsonListStorage
- from .storage.networkx_storage import NetworkXStorage
- from .strategy.travserse_strategy import TraverseStrategy
- from .text.chunk import Chunk
- from .text.text_pair import TextPair
-
- __all__ = [
-     # llm models
-     "OpenAIModel",
-     "TopkTokenModel",
-     "Token",
-     "Tokenizer",
-     # storage models
-     "Chunk",
-     "NetworkXStorage",
-     "JsonKVStorage",
-     "JsonListStorage",
-     # search models
-     "WikiSearch",
-     "GoogleSearch",
-     "BingSearch",
-     "UniProtSearch",
-     # evaluate models
-     "TextPair",
-     "LengthEvaluator",
-     "MTLDEvaluator",
-     "RewardEvaluator",
-     "UniEvaluator",
-     # strategy models
-     "TraverseStrategy",
-     # community models
-     "CommunityDetector",
- ]
hf-repo/graphgen/models/community/__init__.py DELETED
File without changes
hf-repo/graphgen/models/community/community_detector.py DELETED
@@ -1,95 +0,0 @@
- from collections import defaultdict
- from dataclasses import dataclass
- from typing import Any, Dict, List
-
- from graphgen.models.storage.networkx_storage import NetworkXStorage
-
-
- @dataclass
- class CommunityDetector:
-     """Class for community detection algorithms."""
-
-     graph_storage: NetworkXStorage = None
-     method: str = "leiden"
-     method_params: Dict[str, Any] = None
-
-     async def detect_communities(self) -> Dict[str, int]:
-         if self.method == "leiden":
-             return await self._leiden_communities(**self.method_params or {})
-         raise ValueError(f"Unknown community detection method: {self.method}")
-
-     async def get_graph(self):
-         return await self.graph_storage.get_graph()
-
-     async def _leiden_communities(
-         self, max_size: int = None, **kwargs
-     ) -> Dict[str, int]:
-         """
-         Detect communities using the Leiden algorithm.
-         If max_size is given, any community larger than max_size will be split
-         into smaller sub-communities each having at most max_size nodes.
-         """
-         import igraph as ig
-         import networkx as nx
-         from leidenalg import ModularityVertexPartition, find_partition
-
-         graph = await self.get_graph()
-         graph.remove_nodes_from(list(nx.isolates(graph)))
-
-         ig_graph = ig.Graph.TupleList(graph.edges(), directed=False)
-
-         random_seed = kwargs.get("random_seed", 42)
-         use_lcc = kwargs.get("use_lcc", False)
-
-         communities: Dict[str, int] = {}
-         if use_lcc:
-             lcc = ig_graph.components().giant()
-             partition = find_partition(lcc, ModularityVertexPartition, seed=random_seed)
-             for part, cluster in enumerate(partition):
-                 for v in cluster:
-                     communities[lcc.vs[v]["name"]] = part
-         else:
-             offset = 0
-             for component in ig_graph.components():
-                 subgraph = ig_graph.induced_subgraph(component)
-                 partition = find_partition(
-                     subgraph, ModularityVertexPartition, seed=random_seed
-                 )
-                 for part, cluster in enumerate(partition):
-                     for v in cluster:
-                         original_node = subgraph.vs[v]["name"]
-                         communities[original_node] = part + offset
-                 offset += len(partition)
-
-         # split large communities if max_size is specified
-         if max_size is None or max_size <= 0:
-             return communities
-
-         return await self._split_communities(communities, max_size)
-
-     @staticmethod
-     async def _split_communities(
-         communities: Dict[str, int], max_size: int
-     ) -> Dict[str, int]:
-         """
-         Split communities larger than max_size into smaller sub-communities.
-         """
-         cid2nodes: Dict[int, List[str]] = defaultdict(list)
-         for node, cid in communities.items():
-             cid2nodes[cid].append(node)
-
-         new_communities: Dict[str, int] = {}
-         new_cid = 0
-         for cid, nodes in cid2nodes.items():
-             if len(nodes) <= max_size:
-                 for n in nodes:
-                     new_communities[n] = new_cid
-                 new_cid += 1
-             else:
-                 for start in range(0, len(nodes), max_size):
-                     sub = nodes[start : start + max_size]
-                     for n in sub:
-                         new_communities[n] = new_cid
-                     new_cid += 1
-
-         return new_communities
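For context, a minimal sketch of driving this detector; it assumes a populated graph already exists at cache/graph.graphml and that the leidenalg and python-igraph packages are installed:

import asyncio

from graphgen.models import CommunityDetector, NetworkXStorage

storage = NetworkXStorage(working_dir="cache", namespace="graph")
detector = CommunityDetector(
    graph_storage=storage,
    method="leiden",
    method_params={"max_size": 20, "random_seed": 42},
)
node2cid = asyncio.run(detector.detect_communities())  # {node name -> community id}
print(f"{len(set(node2cid.values()))} communities over {len(node2cid)} nodes")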
hf-repo/graphgen/models/embed/__init__.py DELETED
File without changes
hf-repo/graphgen/models/embed/embedding.py DELETED
@@ -1,29 +0,0 @@
- from dataclasses import dataclass
- import asyncio
- import numpy as np
-
- class UnlimitedSemaphore:
-     """A context manager that allows unlimited access."""
-
-     async def __aenter__(self):
-         pass
-
-     async def __aexit__(self, exc_type, exc, tb):
-         pass
-
- @dataclass
- class EmbeddingFunc:
-     embedding_dim: int
-     max_token_size: int
-     func: callable
-     concurrent_limit: int = 16
-
-     def __post_init__(self):
-         if self.concurrent_limit != 0:
-             self._semaphore = asyncio.Semaphore(self.concurrent_limit)
-         else:
-             self._semaphore = UnlimitedSemaphore()
-
-     async def __call__(self, *args, **kwargs) -> np.ndarray:
-         async with self._semaphore:
-             return await self.func(*args, **kwargs)
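A minimal sketch of wrapping an async embedding callable with EmbeddingFunc; the random vectors below are a stand-in for a real model, not part of the library:

import asyncio

import numpy as np

from graphgen.models.embed.embedding import EmbeddingFunc

async def fake_embed(texts: list[str]) -> np.ndarray:
    # stand-in for a real embedding model call
    return np.random.default_rng(0).random((len(texts), 768))

embed = EmbeddingFunc(embedding_dim=768, max_token_size=8192, func=fake_embed)
vectors = asyncio.run(embed(["hello", "world"]))  # concurrency-gated by the semaphore
print(vectors.shape)  # (2, 768)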
hf-repo/graphgen/models/evaluate/__init__.py DELETED
File without changes
hf-repo/graphgen/models/evaluate/base_evaluator.py DELETED
@@ -1,51 +0,0 @@
- import asyncio
-
- from dataclasses import dataclass
- from tqdm.asyncio import tqdm as tqdm_async
- from graphgen.utils import create_event_loop
- from graphgen.models.text.text_pair import TextPair
-
- @dataclass
- class BaseEvaluator:
-     max_concurrent: int = 100
-     results: list[float] = None
-
-     def evaluate(self, pairs: list[TextPair]) -> list[float]:
-         """
-         Evaluate the text and return a score.
-         """
-         return create_event_loop().run_until_complete(self.async_evaluate(pairs))
-
-     async def async_evaluate(self, pairs: list[TextPair]) -> list[float]:
-         semaphore = asyncio.Semaphore(self.max_concurrent)
-
-         async def evaluate_with_semaphore(pair):
-             async with semaphore:  # acquire the semaphore
-                 return await self.evaluate_single(pair)
-
-         results = []
-         for result in tqdm_async(
-             asyncio.as_completed([evaluate_with_semaphore(pair) for pair in pairs]),
-             total=len(pairs),
-         ):
-             results.append(await result)
-         return results
-
-     async def evaluate_single(self, pair: TextPair) -> float:
-         raise NotImplementedError()
-
-     def get_average_score(self, pairs: list[TextPair]) -> float:
-         """
-         Get the average score of a batch of texts.
-         """
-         results = self.evaluate(pairs)
-         self.results = results
-         return sum(self.results) / len(pairs)
-
-     def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
-         """
-         Get the min and max score of a batch of texts.
-         """
-         if self.results is None:
-             self.get_average_score(pairs)
-         return min(self.results), max(self.results)
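Concrete evaluators only need to override evaluate_single. A minimal sketch, assuming TextPair exposes question and answer fields (consistent with how the other evaluators use it):

from dataclasses import dataclass

from graphgen.models.evaluate.base_evaluator import BaseEvaluator
from graphgen.models.text.text_pair import TextPair

@dataclass
class WordCountEvaluator(BaseEvaluator):
    async def evaluate_single(self, pair: TextPair) -> float:
        # score an answer by its whitespace-token count
        return float(len(pair.answer.split()))

pairs = [TextPair(question="q", answer="a short answer")]
print(WordCountEvaluator(max_concurrent=10).get_average_score(pairs))  # 3.0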
hf-repo/graphgen/models/evaluate/length_evaluator.py DELETED
@@ -1,22 +0,0 @@
- from dataclasses import dataclass
- from graphgen.models.evaluate.base_evaluator import BaseEvaluator
- from graphgen.models.llm.tokenizer import Tokenizer
- from graphgen.models.text.text_pair import TextPair
- from graphgen.utils import create_event_loop
-
-
- @dataclass
- class LengthEvaluator(BaseEvaluator):
-     tokenizer_name: str = "cl100k_base"
-
-     def __post_init__(self):
-         self.tokenizer = Tokenizer(
-             model_name=self.tokenizer_name
-         )
-
-     async def evaluate_single(self, pair: TextPair) -> float:
-         loop = create_event_loop()
-         return await loop.run_in_executor(None, self._calculate_length, pair.answer)
-
-     def _calculate_length(self, text: str) -> float:
-         tokens = self.tokenizer.encode_string(text)
-         return len(tokens)
hf-repo/graphgen/models/evaluate/mtld_evaluator.py DELETED
@@ -1,76 +0,0 @@
- from dataclasses import dataclass, field
- from typing import Set
-
- from graphgen.models.evaluate.base_evaluator import BaseEvaluator
- from graphgen.models.text.text_pair import TextPair
- from graphgen.utils import detect_main_language, NLTKHelper, create_event_loop
-
-
- nltk_helper = NLTKHelper()
-
- @dataclass
- class MTLDEvaluator(BaseEvaluator):
-     """
-     Measure of the lexical diversity of a text.
-     """
-     stopwords_en: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("english")))
-     stopwords_zh: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("chinese")))
-
-     async def evaluate_single(self, pair: TextPair) -> float:
-         loop = create_event_loop()
-         return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer)
-
-     def _calculate_mtld_score(self, text: str, threshold=0.72) -> float:
-         """
-         Compute MTLD (the mean of the forward and backward passes).
-
-         min is 1.0
-         higher is better
-         """
-         if not text or not text.strip():
-             return 0.0
-
-         lang = detect_main_language(text)
-         tokens = nltk_helper.word_tokenize(text, lang)
-
-         stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en
-         filtered_tokens = [word for word in tokens if word not in stopwords]
-         filtered_tokens = [word for word in filtered_tokens if word.isalnum()]
-
-         if not filtered_tokens:
-             return 0
-
-         # forward MTLD
-         forward_factors = self._compute_factors(filtered_tokens, threshold)
-
-         # backward MTLD
-         backward_factors = self._compute_factors(filtered_tokens[::-1], threshold)
-
-         # average of the two passes
-         return (forward_factors + backward_factors) / 2
-
-     @staticmethod
-     def _compute_factors(tokens: list, threshold: float) -> float:
-         factors = 0
-         current_segment = []
-         unique_words = set()
-
-         for token in tokens:
-             current_segment.append(token)
-             unique_words.add(token)
-             ttr = len(unique_words) / len(current_segment)
-
-             if ttr <= threshold:
-                 factors += 1
-                 current_segment = []
-                 unique_words = set()
-
-         # handle the final partial segment
-         if current_segment:
-             ttr = len(unique_words) / len(current_segment)
-             if ttr <= threshold:
-                 factors += 1
-             else:
-                 factors += (1 - (ttr - threshold) / (1 - threshold))
-
-         return len(tokens) / factors if factors > 0 else len(tokens)
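The factor count walks the token stream and closes a segment whenever its type-token ratio falls to the 0.72 threshold; MTLD is tokens divided by factors, averaged over the forward and backward passes. A minimal sketch (again assuming the TextPair question/answer constructor; only pair.answer is scored):

from graphgen.models import MTLDEvaluator, TextPair

evaluator = MTLDEvaluator()
pairs = [
    TextPair(question="", answer="The quick brown fox jumps over the lazy dog."),
    TextPair(question="", answer="word word word word word word word word"),
]
# the varied sentence should score well above the repetitive one
print(evaluator.get_average_score(pairs))
print(evaluator.get_min_max_score(pairs))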
hf-repo/graphgen/models/evaluate/reward_evaluator.py DELETED
@@ -1,101 +0,0 @@
- from dataclasses import dataclass
- from tqdm import tqdm
- from graphgen.models.text.text_pair import TextPair
-
-
- @dataclass
- class RewardEvaluator:
-     """
-     Reward Model Evaluator.
-     OpenAssistant/reward-model-deberta-v3-large-v2: scores range over (-inf, inf); higher is better.
-     """
-     reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2"
-     max_length: int = 2560
-     results: list[float] = None
-
-     def __post_init__(self):
-         import torch
-         self.num_gpus = torch.cuda.device_count()
-
-     @staticmethod
-     def process_chunk(rank, pairs, reward_name, max_length, return_dict):
-         import torch
-         from transformers import AutoModelForSequenceClassification, AutoTokenizer
-         device = f'cuda:{rank}'
-         torch.cuda.set_device(rank)
-
-         rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
-         tokenizer = AutoTokenizer.from_pretrained(reward_name)
-         rank_model.to(device)
-         rank_model.eval()
-
-         results = []
-         with torch.no_grad():
-             for pair in tqdm(pairs):
-                 inputs = tokenizer(
-                     pair.question,
-                     pair.answer,
-                     return_tensors="pt",
-                     max_length=max_length,
-                     truncation=True
-                 )
-                 inputs = {k: v.to(device) for k, v in inputs.items()}
-                 score = rank_model(**inputs).logits[0].item()
-                 results.append(score)
-
-         return_dict[rank] = results
-
-     def evaluate(self, pairs: list[TextPair]) -> list[float]:
-         import torch.multiprocessing as mp
-         chunk_size = len(pairs) // self.num_gpus
-         chunks = []
-         for i in range(self.num_gpus):
-             start = i * chunk_size
-             end = start + chunk_size
-             if i == self.num_gpus - 1:
-                 end = len(pairs)
-             chunks.append(pairs[start:end])
-
-         # multi-process
-         manager = mp.Manager()
-         return_dict = manager.dict()
-         processes = []
-
-         for rank, chunk in enumerate(chunks):
-             p = mp.Process(
-                 target=self.process_chunk,
-                 args=(rank, chunk, self.reward_name, self.max_length, return_dict)
-             )
-             p.start()
-             processes.append(p)
-
-         for p in processes:
-             p.join()
-
-         # merge per-GPU results
-         results = []
-         for rank in range(len(chunks)):
-             results.extend(return_dict[rank])
-
-         for p in processes:
-             if p.is_alive():
-                 p.terminate()
-                 p.join()
-
-         return results
-
-     def get_average_score(self, pairs: list[TextPair]) -> float:
-         """
-         Get the average score of a batch of texts.
-         """
-         results = self.evaluate(pairs)
-         self.results = results
-         return sum(self.results) / len(pairs)
-
-     def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
-         """
-         Get the min and max score of a batch of texts.
-         """
-         if self.results is None:
-             self.get_average_score(pairs)
-         return min(self.results), max(self.results)
hf-repo/graphgen/models/evaluate/uni_evaluator.py DELETED
@@ -1,159 +0,0 @@
- # https://github.com/maszhongming/UniEval/tree/main
-
- from dataclasses import dataclass, field
- from tqdm import tqdm
- from graphgen.models.text.text_pair import TextPair
-
-
- def _add_questions(dimension: str, question: str, answer: str):
-     if dimension == "naturalness":
-         cur_input = 'question: Is this a natural response in the dialogue? </s> response: ' + answer
-     elif dimension == "coherence":
-         cur_input = 'question: Is this a coherent response given the dialogue history? </s> response: ' \
-                     + answer + ' </s> dialogue history: ' + question
-     elif dimension == "understandability":
-         cur_input = 'question: Is this an understandable response in the dialogue? </s> response: ' + answer
-     else:
-         raise NotImplementedError(
-             'The input format for this dimension is still undefined. Please customize it first.')
-     return cur_input
-
- @dataclass
- class UniEvaluator:
-     model_name: str = "MingZhong/unieval-sum"
-     dimensions: list = field(default_factory=lambda: ['naturalness', 'coherence', 'understandability'])
-     max_length: int = 2560
-     results: dict = None
-
-     def __post_init__(self):
-         import torch
-         self.num_gpus = torch.cuda.device_count()
-         self.results = {}
-
-     @staticmethod
-     def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict):
-         import torch
-         from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-         device = f'cuda:{rank}'
-         torch.cuda.set_device(rank)
-
-         rank_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-         tokenizer = AutoTokenizer.from_pretrained(model_name)
-         rank_model.to(device)
-         rank_model.eval()
-
-         softmax = torch.nn.Softmax(dim=1)
-
-         pos_id = tokenizer("Yes")["input_ids"][0]
-         neg_id = tokenizer("No")["input_ids"][0]
-
-         results = []
-         with torch.no_grad():
-             for pair in tqdm(pairs):
-                 text = _add_questions(dimension, pair.question, pair.answer)
-
-                 tgt = "No"
-
-                 encoded_src = tokenizer(
-                     text,
-                     max_length=max_length,
-                     truncation=True,
-                     padding=True,
-                     return_tensors='pt'
-                 )
-                 encoded_tgt = tokenizer(
-                     tgt,
-                     max_length=max_length,
-                     truncation=True,
-                     padding=True,
-                     return_tensors='pt'
-                 )
-
-                 src_tokens = encoded_src['input_ids'].to(device)
-                 src_mask = encoded_src['attention_mask'].to(device)
-
-                 tgt_tokens = encoded_tgt['input_ids'].to(device)[:, 0].unsqueeze(-1)
-
-                 output = rank_model(
-                     input_ids=src_tokens,
-                     attention_mask=src_mask,
-                     labels=tgt_tokens,
-                     use_cache=False
-                 )
-
-                 logits = output.logits.view(-1, rank_model.config.vocab_size)
-
-                 pos_score = softmax(logits)[:, pos_id]  # Yes
-                 neg_score = softmax(logits)[:, neg_id]
-                 score = pos_score / (pos_score + neg_score)
-
-                 results.append(score.item())
-
-         return_dict[rank] = results
-
-     def evaluate(self, pairs: list[TextPair]) -> list[dict]:
-         import torch.multiprocessing as mp
-         final_results = []
-         for dimension in self.dimensions:
-             chunk_size = len(pairs) // self.num_gpus
-             chunks = []
-             for i in range(self.num_gpus):
-                 start = i * chunk_size
-                 end = start + chunk_size
-                 if i == self.num_gpus - 1:
-                     end = len(pairs)
-                 chunks.append(pairs[start:end])
-
-             # multi-process
-             manager = mp.Manager()
-             return_dict = manager.dict()
-             processes = []
-
-             for rank, chunk in enumerate(chunks):
-                 p = mp.Process(
-                     target=self.process_chunk,
-                     args=(rank, chunk, self.model_name, self.max_length, dimension, return_dict)
-                 )
-                 p.start()
-                 processes.append(p)
-
-             for p in processes:
-                 p.join()
-
-             # merge per-GPU results
-             results = []
-             for rank in range(len(chunks)):
-                 results.extend(return_dict[rank])
-
-             for p in processes:
-                 if p.is_alive():
-                     p.terminate()
-                     p.join()
-
-             final_results.append({
-                 dimension: results
-             })
-         return final_results
-
-     def get_average_score(self, pairs: list[TextPair]) -> dict:
-         """
-         Get the average score of a batch of texts.
-         """
-         results = self.evaluate(pairs)
-         final_results = {}
-         for result in results:
-             for key, value in result.items():
-                 final_results[key] = sum(value) / len(value)
-                 self.results[key] = value
-         return final_results
-
-     def get_min_max_score(self, pairs: list[TextPair]) -> dict:
-         """
-         Get the min and max score of a batch of texts.
-         """
-         if self.results is None:
-             self.get_average_score(pairs)
-         final_results = {}
-         for key, value in self.results.items():
-             final_results[key] = min(value), max(value)
-         return final_results
hf-repo/graphgen/models/llm/__init__.py DELETED
File without changes
hf-repo/graphgen/models/llm/limitter.py DELETED
@@ -1,88 +0,0 @@
- import time
- from datetime import datetime, timedelta
- import asyncio
-
- from graphgen.utils import logger
-
-
- class RPM:
-
-     def __init__(self, rpm: int = 1000):
-         self.rpm = rpm
-         self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0}
-
-     def get_minute_slot(self):
-         current_time = time.time()
-         dt_object = datetime.fromtimestamp(current_time)
-         total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute
-         return total_minutes_since_midnight
-
-     async def wait(self, silent=False):
-         current = time.time()
-         dt_object = datetime.fromtimestamp(current)
-         minute_slot = self.get_minute_slot()
-
-         if self.record['rpm_slot'] == minute_slot:
-             # check whether the RPM limit is exceeded
-             if self.record['counter'] >= self.rpm:
-                 # wait until the next minute
-                 next_minute = dt_object.replace(
-                     second=0, microsecond=0) + timedelta(minutes=1)
-                 _next = next_minute.timestamp()
-                 sleep_time = abs(_next - current)
-                 if not silent:
-                     logger.info('RPM sleep %s', sleep_time)
-                 await asyncio.sleep(sleep_time)
-
-                 self.record = {
-                     'rpm_slot': self.get_minute_slot(),
-                     'counter': 0
-                 }
-         else:
-             self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0}
-         self.record['counter'] += 1
-
-         if not silent:
-             logger.debug(self.record)
-
-
- class TPM:
-
-     def __init__(self, tpm: int = 20000):
-         self.tpm = tpm
-         self.record = {'tpm_slot': self.get_minute_slot(), 'counter': 0}
-
-     def get_minute_slot(self):
-         current_time = time.time()
-         dt_object = datetime.fromtimestamp(current_time)
-         total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute
-         return total_minutes_since_midnight
-
-     async def wait(self, token_count, silent=False):
-         current = time.time()
-         dt_object = datetime.fromtimestamp(current)
-         minute_slot = self.get_minute_slot()
-
-         # a new minute slot has started: reset the counter and return
-         if self.record['tpm_slot'] != minute_slot:
-             self.record = {'tpm_slot': minute_slot, 'counter': token_count}
-             return
-
-         # check whether the TPM limit is exceeded
-         self.record['counter'] += token_count
-         if self.record['counter'] > self.tpm:
-             # wait until the next minute
-             next_minute = dt_object.replace(
-                 second=0, microsecond=0) + timedelta(minutes=1)
-             _next = next_minute.timestamp()
-             sleep_time = abs(_next - current)
-             logger.info('TPM sleep %s', sleep_time)
-             await asyncio.sleep(sleep_time)
-
-             self.record = {
-                 'tpm_slot': self.get_minute_slot(),
-                 'counter': token_count
-             }
-
-         if not silent:
-             logger.debug(self.record)
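A minimal sketch of gating calls with the two limiters above; the prompt list and budgets are placeholders:

import asyncio

from graphgen.models.llm.limitter import RPM, TPM

async def main():
    rpm, tpm = RPM(rpm=60), TPM(tpm=10_000)
    for prompt in ["first prompt", "second prompt"]:
        await rpm.wait(silent=True)       # sleeps into the next minute after 60 requests
        await tpm.wait(500, silent=True)  # same, once the 10k-token budget is spent
        # ... issue the actual API request here ...

asyncio.run(main())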
hf-repo/graphgen/models/llm/openai_model.py DELETED
@@ -1,155 +0,0 @@
- import math
- import re
- from dataclasses import dataclass, field
- from typing import Dict, List, Optional
-
- import openai
- from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
- from tenacity import (
-     retry,
-     retry_if_exception_type,
-     stop_after_attempt,
-     wait_exponential,
- )
-
- from graphgen.models.llm.limitter import RPM, TPM
- from graphgen.models.llm.tokenizer import Tokenizer
- from graphgen.models.llm.topk_token_model import Token, TopkTokenModel
-
-
- def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
-     token_logprobs = response.choices[0].logprobs.content
-     tokens = []
-     for token_prob in token_logprobs:
-         prob = math.exp(token_prob.logprob)
-         candidate_tokens = [
-             Token(t.token, math.exp(t.logprob)) for t in token_prob.top_logprobs
-         ]
-         token = Token(token_prob.token, prob, top_candidates=candidate_tokens)
-         tokens.append(token)
-     return tokens
-
-
- def filter_think_tags(text: str) -> str:
-     """
-     Remove <think> tags from the text.
-     If the text contains <think> and </think>, it removes everything between them and the tags themselves.
-     """
-     think_pattern = re.compile(r"<think>.*?</think>", re.DOTALL)
-     filtered_text = think_pattern.sub("", text).strip()
-     return filtered_text if filtered_text else text.strip()
-
-
- @dataclass
- class OpenAIModel(TopkTokenModel):
-     model_name: str = "gpt-4o-mini"
-     api_key: str = None
-     base_url: str = None
-
-     system_prompt: str = ""
-     json_mode: bool = False
-     seed: int = None
-
-     token_usage: list = field(default_factory=list)
-     request_limit: bool = False
-     rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
-     tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))
-
-     tokenizer_instance: Tokenizer = field(default_factory=Tokenizer)
-
-     def __post_init__(self):
-         assert self.api_key is not None, "Please provide api key to access openai api."
-         self.client = AsyncOpenAI(
-             api_key=self.api_key or "dummy", base_url=self.base_url
-         )
-
-     def _pre_generate(self, text: str, history: List[str]) -> Dict:
-         kwargs = {
-             "temperature": self.temperature,
-             "top_p": self.topp,
-             "max_tokens": self.max_tokens,
-         }
-         if self.seed:
-             kwargs["seed"] = self.seed
-         if self.json_mode:
-             kwargs["response_format"] = {"type": "json_object"}
-
-         messages = []
-         if self.system_prompt:
-             messages.append({"role": "system", "content": self.system_prompt})
-         messages.append({"role": "user", "content": text})
-
-         if history:
-             assert len(history) % 2 == 0, "History should have even number of elements."
-             messages = history + messages
-
-         kwargs["messages"] = messages
-         return kwargs
-
-     @retry(
-         stop=stop_after_attempt(5),
-         wait=wait_exponential(multiplier=1, min=4, max=10),
-         retry=retry_if_exception_type(
-             (RateLimitError, APIConnectionError, APITimeoutError)
-         ),
-     )
-     async def generate_topk_per_token(
-         self, text: str, history: Optional[List[str]] = None
-     ) -> List[Token]:
-         kwargs = self._pre_generate(text, history)
-         if self.topk_per_token > 0:
-             kwargs["logprobs"] = True
-             kwargs["top_logprobs"] = self.topk_per_token
-
-         # Limit max_tokens to 1 to avoid long completions
-         kwargs["max_tokens"] = 1
-
-         completion = await self.client.chat.completions.create(  # pylint: disable=E1125
-             model=self.model_name, **kwargs
-         )
-
-         tokens = get_top_response_tokens(completion)
-
-         return tokens
-
-     @retry(
-         stop=stop_after_attempt(5),
-         wait=wait_exponential(multiplier=1, min=4, max=10),
-         retry=retry_if_exception_type(
-             (RateLimitError, APIConnectionError, APITimeoutError)
-         ),
-     )
-     async def generate_answer(
-         self, text: str, history: Optional[List[str]] = None, temperature: float = 0
-     ) -> str:
-         kwargs = self._pre_generate(text, history)
-         kwargs["temperature"] = temperature
-
-         prompt_tokens = 0
-         for message in kwargs["messages"]:
-             prompt_tokens += len(
-                 self.tokenizer_instance.encode_string(message["content"])
-             )
-         estimated_tokens = prompt_tokens + kwargs["max_tokens"]
-
-         if self.request_limit:
-             await self.rpm.wait(silent=True)
-             await self.tpm.wait(estimated_tokens, silent=True)
-
-         completion = await self.client.chat.completions.create(  # pylint: disable=E1125
-             model=self.model_name, **kwargs
-         )
-         if hasattr(completion, "usage"):
-             self.token_usage.append(
-                 {
-                     "prompt_tokens": completion.usage.prompt_tokens,
-                     "completion_tokens": completion.usage.completion_tokens,
-                     "total_tokens": completion.usage.total_tokens,
-                 }
-             )
-         return filter_think_tags(completion.choices[0].message.content)
-
-     async def generate_inputs_prob(
-         self, text: str, history: Optional[List[str]] = None
-     ) -> List[Token]:
-         raise NotImplementedError
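A minimal sketch of calling this client directly; the model name is only a default, and the env var names below are assumptions, not part of the deleted code:

import asyncio
import os

from graphgen.models import OpenAIModel

llm = OpenAIModel(
    model_name="gpt-4o-mini",
    api_key=os.environ["SYNTHESIZER_API_KEY"],  # assumed env var name
    base_url=os.getenv("SYNTHESIZER_BASE_URL"),
    request_limit=True,  # opt in to the RPM/TPM limiters above
)
print(asyncio.run(llm.generate_answer("Name one use of a knowledge graph.")))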
hf-repo/graphgen/models/llm/tokenizer.py DELETED
@@ -1,73 +0,0 @@
- from dataclasses import dataclass
- from typing import List
- import tiktoken
-
- try:
-     from transformers import AutoTokenizer
-     TRANSFORMERS_AVAILABLE = True
- except ImportError:
-     AutoTokenizer = None
-     TRANSFORMERS_AVAILABLE = False
-
-
- def get_tokenizer(tokenizer_name: str = "cl100k_base"):
-     """
-     Get a tokenizer instance by name.
-
-     :param tokenizer_name: tokenizer name, tiktoken encoding name or Hugging Face model name
-     :return: tokenizer instance
-     """
-     if tokenizer_name in tiktoken.list_encoding_names():
-         return tiktoken.get_encoding(tokenizer_name)
-     if TRANSFORMERS_AVAILABLE:
-         try:
-             return AutoTokenizer.from_pretrained(tokenizer_name)
-         except Exception as e:
-             raise ValueError(f"Failed to load tokenizer from Hugging Face: {e}") from e
-     else:
-         raise ValueError("Hugging Face Transformers is not available, please install it first.")
-
- @dataclass
- class Tokenizer:
-     model_name: str = "cl100k_base"
-
-     def __post_init__(self):
-         self.tokenizer = get_tokenizer(self.model_name)
-
-     def encode_string(self, text: str) -> List[int]:
-         """
-         Encode text to tokens
-
-         :param text
-         :return: tokens
-         """
-         return self.tokenizer.encode(text)
-
-     def decode_tokens(self, tokens: List[int]) -> str:
-         """
-         Decode tokens to text
-
-         :param tokens
-         :return: text
-         """
-         return self.tokenizer.decode(tokens)
-
-     def chunk_by_token_size(
-         self, content: str, overlap_token_size=128, max_token_size=1024
-     ):
-         tokens = self.encode_string(content)
-         results = []
-         for index, start in enumerate(
-             range(0, len(tokens), max_token_size - overlap_token_size)
-         ):
-             chunk_content = self.decode_tokens(
-                 tokens[start : start + max_token_size]
-             )
-             results.append(
-                 {
-                     "tokens": min(max_token_size, len(tokens) - start),
-                     "content": chunk_content.strip(),
-                     "chunk_order_index": index,
-                 }
-             )
-         return results
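The chunker slides a window of max_token_size tokens with a stride of max_token_size minus overlap_token_size. A minimal sketch (cl100k_base ships with tiktoken, so no download is needed):

from graphgen.models import Tokenizer

tok = Tokenizer(model_name="cl100k_base")  # a Hugging Face model name also works
chunks = tok.chunk_by_token_size(
    "some long text " * 400, overlap_token_size=16, max_token_size=64
)
print(len(chunks), chunks[0]["tokens"], chunks[0]["chunk_order_index"])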
hf-repo/graphgen/models/llm/topk_token_model.py DELETED
@@ -1,48 +0,0 @@
- import math
- from dataclasses import dataclass, field
- from typing import List, Union, Optional
-
-
- @dataclass
- class Token:
-     text: str
-     prob: float
-     top_candidates: List = field(default_factory=list)
-     ppl: Union[float, None] = field(default=None)
-
-     @property
-     def logprob(self) -> float:
-         return math.log(self.prob)
-
-
- @dataclass
- class TopkTokenModel:
-     do_sample: bool = False
-     temperature: float = 0
-     max_tokens: int = 4096
-     repetition_penalty: float = 1.05
-     num_beams: int = 1
-     topk: int = 50
-     topp: float = 0.95
-
-     topk_per_token: int = 5  # number of topk tokens to generate for each token
-
-     async def generate_topk_per_token(self, text: str) -> List[Token]:
-         """
-         Generate prob, text and candidates for each token of the model's output.
-         This function is used to visualize the inference process.
-         """
-         raise NotImplementedError
-
-     async def generate_inputs_prob(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
-         """
-         Generate prob and text for each token of the input text.
-         This function is used to visualize the ppl.
-         """
-         raise NotImplementedError
-
-     async def generate_answer(self, text: str, history: Optional[List[str]] = None) -> str:
-         """
-         Generate answer from the model.
-         """
-         raise NotImplementedError
hf-repo/graphgen/models/search/__init__.py DELETED
File without changes
hf-repo/graphgen/models/search/db/__init__.py DELETED
File without changes
hf-repo/graphgen/models/search/db/uniprot_search.py DELETED
@@ -1,64 +0,0 @@
- from dataclasses import dataclass
-
- import requests
- from fastapi import HTTPException
-
- from graphgen.utils import logger
-
- UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"
-
-
- @dataclass
- class UniProtSearch:
-     """
-     UniProt Search client to search with UniProt.
-     1) Get the protein by accession number.
-     2) Search with keywords or protein names.
-     """
-
-     def get_entry(self, accession: str) -> dict:
-         """
-         Get the UniProt entry by accession number (e.g., P04637).
-         """
-         url = f"{UNIPROT_BASE}/{accession}.json"
-         return self._safe_get(url).json()
-
-     def search(
-         self,
-         query: str,
-         *,
-         size: int = 10,
-         cursor: str = None,
-         fields: list[str] = None,
-     ) -> dict:
-         """
-         Search UniProt with a query string.
-         :param query: The search query.
-         :param size: The number of results to return.
-         :param cursor: The cursor for pagination.
-         :param fields: The fields to return in the response.
-         :return: A dictionary containing the search results.
-         """
-         params = {
-             "query": query,
-             "size": size,
-         }
-         if cursor:
-             params["cursor"] = cursor
-         if fields:
-             params["fields"] = ",".join(fields)
-         url = UNIPROT_BASE
-         return self._safe_get(url, params=params).json()
-
-     @staticmethod
-     def _safe_get(url: str, params: dict = None) -> requests.Response:
-         r = requests.get(
-             url,
-             params=params,
-             headers={"Accept": "application/json"},
-             timeout=10,
-         )
-         if not r.ok:
-             logger.error("Search engine error: %s", r.text)
-             raise HTTPException(r.status_code, "Search engine error.")
-         return r
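A minimal sketch of a keyword search against the live UniProt REST API; the query syntax and the primaryAccession result key are assumptions about UniProt's JSON, not guaranteed by the deleted code:

from graphgen.models import UniProtSearch

client = UniProtSearch()
page = client.search("p53 AND organism_id:9606", size=3)
for entry in page.get("results", []):
    print(entry.get("primaryAccession"))  # assumed key in the UniProt response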
hf-repo/graphgen/models/search/kg/__init__.py DELETED
File without changes
hf-repo/graphgen/models/search/kg/wiki_search.py DELETED
@@ -1,37 +0,0 @@
- from dataclasses import dataclass
- from typing import List, Union
-
- import wikipedia
- from wikipedia import set_lang
-
- from graphgen.utils import detect_main_language, logger
-
-
- @dataclass
- class WikiSearch:
-     @staticmethod
-     def set_language(language: str):
-         assert language in ["en", "zh"], "Only support English and Chinese"
-         set_lang(language)
-
-     async def search(self, query: str, num_results: int = 1) -> Union[List[str], None]:
-         self.set_language(detect_main_language(query))
-         return wikipedia.search(query, results=num_results, suggestion=False)
-
-     async def summary(self, query: str) -> Union[str, None]:
-         self.set_language(detect_main_language(query))
-         try:
-             result = wikipedia.summary(query, auto_suggest=False, redirect=False)
-         except wikipedia.exceptions.DisambiguationError as e:
-             logger.error("DisambiguationError: %s", e)
-             result = None
-         return result
-
-     async def page(self, query: str) -> Union[str, None]:
-         self.set_language(detect_main_language(query))
-         try:
-             result = wikipedia.page(query, auto_suggest=False, redirect=False).content
-         except wikipedia.exceptions.DisambiguationError as e:
-             logger.error("DisambiguationError: %s", e)
-             result = None
-         return result
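A minimal sketch of the Wikipedia client; it needs network access and the wikipedia PyPI package, and the language is picked automatically from the query:

import asyncio

from graphgen.models import WikiSearch

wiki = WikiSearch()
print(asyncio.run(wiki.search("knowledge graph", num_results=3)))  # candidate titles
print(asyncio.run(wiki.summary("Knowledge graph")))                # lead-section text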
hf-repo/graphgen/models/search/web/__init__.py DELETED
File without changes
hf-repo/graphgen/models/search/web/bing_search.py DELETED
@@ -1,43 +0,0 @@
- from dataclasses import dataclass
-
- import requests
- from fastapi import HTTPException
-
- from graphgen.utils import logger
-
- BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
- BING_MKT = "en-US"
-
-
- @dataclass
- class BingSearch:
-     """
-     Bing Search client to search with Bing.
-     """
-
-     subscription_key: str
-
-     def search(self, query: str, num_results: int = 1):
-         """
-         Search with Bing and return the contexts.
-         :param query: The search query.
-         :param num_results: The number of results to return.
-         :return: A list of search results.
-         """
-         params = {"q": query, "mkt": BING_MKT, "count": num_results}
-         response = requests.get(
-             BING_SEARCH_V7_ENDPOINT,
-             headers={"Ocp-Apim-Subscription-Key": self.subscription_key},
-             params=params,
-             timeout=10,
-         )
-         if not response.ok:
-             logger.error("Search engine error: %s", response.text)
-             raise HTTPException(response.status_code, "Search engine error.")
-         json_content = response.json()
-         try:
-             contexts = json_content["webPages"]["value"][:num_results]
-         except KeyError:
-             logger.error("Error encountered: %s", json_content)
-             return []
-         return contexts
hf-repo/graphgen/models/search/web/google_search.py DELETED
@@ -1,45 +0,0 @@
- from dataclasses import dataclass
-
- import requests
- from fastapi import HTTPException
-
- from graphgen.utils import logger
-
- GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"
-
-
- @dataclass
- class GoogleSearch:
-     def __init__(self, subscription_key: str, cx: str):
-         """
-         Initialize the Google Search client with the subscription key and custom search engine ID.
-         :param subscription_key: Your Google API subscription key.
-         :param cx: Your custom search engine ID.
-         """
-         self.subscription_key = subscription_key
-         self.cx = cx
-
-     def search(self, query: str, num_results: int = 1):
-         """
-         Search with Google and return the contexts.
-         :param query: The search query.
-         :param num_results: The number of results to return.
-         :return: A list of search results.
-         """
-         params = {
-             "key": self.subscription_key,
-             "cx": self.cx,
-             "q": query,
-             "num": num_results,
-         }
-         response = requests.get(GOOGLE_SEARCH_ENDPOINT, params=params, timeout=10)
-         if not response.ok:
-             logger.error("Search engine error: %s", response.text)
-             raise HTTPException(response.status_code, "Search engine error.")
-         json_content = response.json()
-         try:
-             contexts = json_content["items"][:num_results]
-         except KeyError:
-             logger.error("Error encountered: %s", json_content)
-             return []
-         return contexts
hf-repo/graphgen/models/storage/__init__.py DELETED
File without changes
hf-repo/graphgen/models/storage/base_storage.py DELETED
@@ -1,115 +0,0 @@
- from dataclasses import dataclass
- from typing import Generic, TypeVar, Union
-
- from graphgen.models.embed.embedding import EmbeddingFunc
-
- T = TypeVar("T")
-
-
- @dataclass
- class StorageNameSpace:
-     working_dir: str = None
-     namespace: str = None
-
-     async def index_done_callback(self):
-         """commit the storage operations after indexing"""
-
-     async def query_done_callback(self):
-         """commit the storage operations after querying"""
-
-
- @dataclass
- class BaseListStorage(Generic[T], StorageNameSpace):
-     async def all_items(self) -> list[T]:
-         raise NotImplementedError
-
-     async def get_by_index(self, index: int) -> Union[T, None]:
-         raise NotImplementedError
-
-     async def append(self, data: T):
-         raise NotImplementedError
-
-     async def upsert(self, data: list[T]):
-         raise NotImplementedError
-
-     async def drop(self):
-         raise NotImplementedError
-
-
- @dataclass
- class BaseKVStorage(Generic[T], StorageNameSpace):
-     async def all_keys(self) -> list[str]:
-         raise NotImplementedError
-
-     async def get_by_id(self, id: str) -> Union[T, None]:
-         raise NotImplementedError
-
-     async def get_by_ids(
-         self, ids: list[str], fields: Union[set[str], None] = None
-     ) -> list[Union[T, None]]:
-         raise NotImplementedError
-
-     async def filter_keys(self, data: list[str]) -> set[str]:
-         """return the keys that do not yet exist in storage"""
-         raise NotImplementedError
-
-     async def upsert(self, data: dict[str, T]):
-         raise NotImplementedError
-
-     async def drop(self):
-         raise NotImplementedError
-
-
- @dataclass
- class BaseGraphStorage(StorageNameSpace):
-     embedding_func: EmbeddingFunc = None
-
-     async def has_node(self, node_id: str) -> bool:
-         raise NotImplementedError
-
-     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-         raise NotImplementedError
-
-     async def node_degree(self, node_id: str) -> int:
-         raise NotImplementedError
-
-     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-         raise NotImplementedError
-
-     async def get_node(self, node_id: str) -> Union[dict, None]:
-         raise NotImplementedError
-
-     async def update_node(self, node_id: str, node_data: dict[str, str]):
-         raise NotImplementedError
-
-     async def get_all_nodes(self) -> Union[list[dict], None]:
-         raise NotImplementedError
-
-     async def get_edge(
-         self, source_node_id: str, target_node_id: str
-     ) -> Union[dict, None]:
-         raise NotImplementedError
-
-     async def update_edge(
-         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-     ):
-         raise NotImplementedError
-
-     async def get_all_edges(self) -> Union[list[dict], None]:
-         raise NotImplementedError
-
-     async def get_node_edges(
-         self, source_node_id: str
-     ) -> Union[list[tuple[str, str]], None]:
-         raise NotImplementedError
-
-     async def upsert_node(self, node_id: str, node_data: dict[str, str]):
-         raise NotImplementedError
-
-     async def upsert_edge(
-         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-     ):
-         raise NotImplementedError
-
-     async def delete_node(self, node_id: str):
-         raise NotImplementedError
hf-repo/graphgen/models/storage/json_storage.py DELETED
@@ -1,87 +0,0 @@
- import os
- from dataclasses import dataclass
-
- from graphgen.models.storage.base_storage import BaseKVStorage, BaseListStorage
- from graphgen.utils import load_json, logger, write_json
-
-
- @dataclass
- class JsonKVStorage(BaseKVStorage):
-     _data: dict[str, str] = None
-
-     def __post_init__(self):
-         self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
-         self._data = load_json(self._file_name) or {}
-         logger.info("Load KV %s with %d data", self.namespace, len(self._data))
-
-     @property
-     def data(self):
-         return self._data
-
-     async def all_keys(self) -> list[str]:
-         return list(self._data.keys())
-
-     async def index_done_callback(self):
-         write_json(self._data, self._file_name)
-
-     async def get_by_id(self, id):
-         return self._data.get(id, None)
-
-     async def get_by_ids(self, ids, fields=None) -> list:
-         if fields is None:
-             return [self._data.get(id, None) for id in ids]
-         return [
-             (
-                 {k: v for k, v in self._data[id].items() if k in fields}
-                 if self._data.get(id, None)
-                 else None
-             )
-             for id in ids
-         ]
-
-     async def filter_keys(self, data: list[str]) -> set[str]:
-         return {s for s in data if s not in self._data}
-
-     async def upsert(self, data: dict):
-         left_data = {k: v for k, v in data.items() if k not in self._data}
-         self._data.update(left_data)
-         return left_data
-
-     async def drop(self):
-         self._data = {}
-
-
- @dataclass
- class JsonListStorage(BaseListStorage):
-     _data: list = None
-
-     def __post_init__(self):
-         self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
-         self._data = load_json(self._file_name) or []
-         logger.info("Load List %s with %d data", self.namespace, len(self._data))
-
-     @property
-     def data(self):
-         return self._data
-
-     async def all_items(self) -> list:
-         return self._data
-
-     async def index_done_callback(self):
-         write_json(self._data, self._file_name)
-
-     async def get_by_index(self, index: int):
-         if index < 0 or index >= len(self._data):
-             return None
-         return self._data[index]
-
-     async def append(self, data):
-         self._data.append(data)
-
-     async def upsert(self, data: list):
-         left_data = [d for d in data if d not in self._data]
-         self._data.extend(left_data)
-         return left_data
-
-     async def drop(self):
-         self._data = []
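A minimal sketch of the KV store's insert-once semantics, assuming a cache directory already exists:

import asyncio

from graphgen.models import JsonKVStorage

async def main():
    kv = JsonKVStorage(working_dir="cache", namespace="full_docs")
    new_keys = await kv.filter_keys(["doc-1", "doc-2"])  # keys not yet stored
    await kv.upsert({k: {"content": "..."} for k in new_keys})
    await kv.index_done_callback()  # persists to cache/full_docs.json

asyncio.run(main())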
hf-repo/graphgen/models/storage/networkx_storage.py DELETED
@@ -1,159 +0,0 @@
- import os
- import html
- from typing import Any, Union, cast, Optional
- from dataclasses import dataclass
- import networkx as nx
-
- from graphgen.utils import logger
- from .base_storage import BaseGraphStorage
-
- @dataclass
- class NetworkXStorage(BaseGraphStorage):
-     @staticmethod
-     def load_nx_graph(file_name) -> Optional[nx.Graph]:
-         if os.path.exists(file_name):
-             return nx.read_graphml(file_name)
-         return None
-
-     @staticmethod
-     def write_nx_graph(graph: nx.Graph, file_name):
-         logger.info("Writing graph with %d nodes, %d edges", graph.number_of_nodes(), graph.number_of_edges())
-         nx.write_graphml(graph, file_name)
-
-     @staticmethod
-     def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
-         """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
-         Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
-         """
-         from graspologic.utils import largest_connected_component
-
-         graph = graph.copy()
-         graph = cast(nx.Graph, largest_connected_component(graph))
-         node_mapping = {
-             node: html.unescape(node.upper().strip()) for node in graph.nodes()
-         }  # type: ignore
-         graph = nx.relabel_nodes(graph, node_mapping)
-         return NetworkXStorage._stabilize_graph(graph)
-
-     @staticmethod
-     def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
-         """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
-         Ensure an undirected graph with the same relationships will always be read the same way.
-         Achieved by sorting the nodes and edges.
-         """
-         fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
-
-         sorted_nodes = graph.nodes(data=True)
-         sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
-
-         fixed_graph.add_nodes_from(sorted_nodes)
-         edges = list(graph.edges(data=True))
-
-         if not graph.is_directed():
-
-             def _sort_source_target(edge):
-                 source, target, edge_data = edge
-                 if source > target:
-                     source, target = target, source
-                 return source, target, edge_data
-
-             edges = [_sort_source_target(edge) for edge in edges]
-
-         def _get_edge_key(source: Any, target: Any) -> str:
-             return f"{source} -> {target}"
-
-         edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
-
-         fixed_graph.add_edges_from(edges)
-         return fixed_graph
-
-     def __post_init__(self):
-         """
-         Load the graph file if it exists; otherwise create a new graph.
-         """
-         self._graphml_xml_file = os.path.join(
-             self.working_dir, f"{self.namespace}.graphml"
-         )
-         preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-         if preloaded_graph is not None:
-             logger.info(
-                 "Loaded graph from %s with %d nodes, %d edges", self._graphml_xml_file,
-                 preloaded_graph.number_of_nodes(), preloaded_graph.number_of_edges()
-             )
-         self._graph = preloaded_graph or nx.Graph()
-
-     async def index_done_callback(self):
-         NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
-
-     async def has_node(self, node_id: str) -> bool:
-         return self._graph.has_node(node_id)
-
-     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-         return self._graph.has_edge(source_node_id, target_node_id)
-
-     async def get_node(self, node_id: str) -> Union[dict, None]:
-         return self._graph.nodes.get(node_id)
-
-     async def get_all_nodes(self) -> Union[list[dict], None]:
-         return list(self._graph.nodes(data=True))
-
-     async def node_degree(self, node_id: str) -> int:
-         return self._graph.degree(node_id)
-
-     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-         return self._graph.degree(src_id) + self._graph.degree(tgt_id)
-
-     async def get_edge(
-         self, source_node_id: str, target_node_id: str
-     ) -> Union[dict, None]:
-         return self._graph.edges.get((source_node_id, target_node_id))
-
-     async def get_all_edges(self) -> Union[list[dict], None]:
-         return list(self._graph.edges(data=True))
-
-     async def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str, dict]], None]:
-         if self._graph.has_node(source_node_id):
-             return list(self._graph.edges(source_node_id, data=True))
-         return None
-
-     async def get_graph(self) -> nx.Graph:
-         return self._graph
-
-     async def upsert_node(self, node_id: str, node_data: dict[str, str]):
-         self._graph.add_node(node_id, **node_data)
-
-     async def update_node(self, node_id: str, node_data: dict[str, str]):
-         if self._graph.has_node(node_id):
-             self._graph.nodes[node_id].update(node_data)
-         else:
-             logger.warning("Node %s not found in the graph for update.", node_id)
-
-     async def upsert_edge(
-         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-     ):
-         self._graph.add_edge(source_node_id, target_node_id, **edge_data)
-
-     async def update_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
-         if self._graph.has_edge(source_node_id, target_node_id):
-             self._graph.edges[(source_node_id, target_node_id)].update(edge_data)
-         else:
-             logger.warning("Edge %s -> %s not found in the graph for update.", source_node_id, target_node_id)
-
-     async def delete_node(self, node_id: str):
-         """
-         Delete a node from the graph based on the specified node_id.
-
-         :param node_id: The node_id to delete
-         """
-         if self._graph.has_node(node_id):
-             self._graph.remove_node(node_id)
-             logger.info("Node %s deleted from the graph.", node_id)
-         else:
-             logger.warning("Node %s not found in the graph for deletion.", node_id)
-
-     async def clear(self):
-         """
-         Clear the graph by removing all nodes and edges.
-         """
-         self._graph.clear()
-         logger.info("Graph %s cleared.", self.namespace)
hf-repo/graphgen/models/strategy/__init__.py DELETED
File without changes
hf-repo/graphgen/models/strategy/base_strategy.py DELETED
@@ -1,5 +0,0 @@
- from dataclasses import dataclass
-
- @dataclass
- class BaseStrategy:
-     pass
hf-repo/graphgen/models/strategy/travserse_strategy.py DELETED
@@ -1,30 +0,0 @@
- from dataclasses import dataclass, fields
-
- from graphgen.models.strategy.base_strategy import BaseStrategy
-
-
- @dataclass
- class TraverseStrategy(BaseStrategy):
-     # Form of the generated QA pairs: atomic, multi-hop, or aggregated
-     qa_form: str = "atomic"  # "atomic" or "multi_hop" or "aggregated"
-     # Only one of the max-edge and max-token expansion methods takes effect
-     expand_method: str = "max_tokens"  # "max_width" or "max_tokens"
-     # Expand in one direction or in both directions
-     bidirectional: bool = True
-     # Maximum number of extra edges expanded in each direction
-     max_extra_edges: int = 5
-     # Maximum number of tokens
-     max_tokens: int = 256
-     # Maximum expansion depth in each direction
-     max_depth: int = 2
-     # Edge-selection strategy within a level (for bidirectional expansion, a level is the set of edges connecting the two sides)
-     edge_sampling: str = "max_loss"  # "max_loss" or "min_loss" or "random"
-     # How to handle isolated nodes
-     isolated_node_strategy: str = "add"  # "add" or "ignore"
-     loss_strategy: str = "only_edge"  # only_edge, both
-
-     def to_yaml(self):
-         strategy_dict = {}
-         for f in fields(self):
-             strategy_dict[f.name] = getattr(self, f.name)
-         return {"traverse_strategy": strategy_dict}
hf-repo/graphgen/models/text/__init__.py DELETED
File without changes
hf-repo/graphgen/models/text/chunk.py DELETED
@@ -1,7 +0,0 @@
- from dataclasses import dataclass
-
-
- @dataclass
- class Chunk:
-     id: str
-     content: str
hf-repo/graphgen/models/text/text_pair.py DELETED
@@ -1,9 +0,0 @@
- from dataclasses import dataclass
-
- @dataclass
- class TextPair:
-     """
-     A question-answer pair of input data.
-     """
-     question: str
-     answer: str
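Both text models are plain dataclasses, so constructing them is a one-liner each; the values below are purely illustrative:

```python
from graphgen.models.text.chunk import Chunk
from graphgen.models.text.text_pair import TextPair

chunk = Chunk(id="chunk-0001", content="NetworkX stores graphs as adjacency dicts.")
pair = TextPair(
    question="How does NetworkX store graphs?",
    answer="As adjacency dictionaries.",
)
```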
hf-repo/graphgen/models/vis/__init__.py DELETED
File without changes
hf-repo/graphgen/models/vis/community_visualizer.py DELETED
@@ -1,48 +0,0 @@
- from dataclasses import dataclass
- from typing import Dict
-
- import matplotlib.pyplot as plt
- import networkx as nx
-
-
- @dataclass
- class Visualizer:
-     """
-     Class for visualizing graphs using NetworkX and Matplotlib.
-     """
-
-     graph: nx.Graph = None
-     communities: Dict[str, int] = None
-     layout: str = "spring"
-     max_nodes: int = 1000
-     node_size: int = 10
-     alpha: float = 0.6
-
-     def visualize(self, save_path: str = None):
-         n = self.graph.number_of_nodes()
-         if self.layout == "spring":
-             k = max(0.1, 1.0 / (n**0.5))
-             pos = nx.spring_layout(self.graph, k=k, seed=42)
-         else:
-             raise ValueError(f"Unknown layout: {self.layout}")
-
-         plt.figure(figsize=(10, 10))
-
-         node_colors = [self.communities.get(node, 0) for node in self.graph.nodes()]
-
-         nx.draw_networkx_nodes(
-             self.graph,
-             pos,
-             node_size=self.node_size,
-             node_color=node_colors,
-             cmap=plt.cm.tab20,
-             alpha=self.alpha,
-         )
-         nx.draw_networkx_edges(self.graph, pos, alpha=0.3, width=0.2)
-         plt.axis("off")
-
-         if save_path:
-             plt.savefig(save_path, dpi=300, bbox_inches="tight")
-             print("Saved to", save_path)
-         else:
-             plt.show()
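A minimal sketch of driving the visualizer with a toy graph and a made-up node-to-community mapping; in the real pipeline, `communities` would come from the package's community detector rather than the modular arithmetic used here:

```python
import networkx as nx

from graphgen.models.vis.community_visualizer import Visualizer

# Toy graph and an illustrative node -> community id mapping.
g = nx.karate_club_graph()
communities = {node: node % 3 for node in g.nodes()}

vis = Visualizer(graph=g, communities=communities, node_size=30)
vis.visualize(save_path="communities.png")  # omit save_path to show interactively
```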
hf-repo/graphgen/operators/__init__.py DELETED
@@ -1,22 +0,0 @@
- from graphgen.operators.generate.generate_cot import generate_cot
- from graphgen.operators.kg.extract_kg import extract_kg
- from graphgen.operators.search.search_all import search_all
-
- from .judge import judge_statement
- from .quiz import quiz
- from .traverse_graph import (
-     traverse_graph_atomically,
-     traverse_graph_by_edge,
-     traverse_graph_for_multi_hop,
- )
-
- __all__ = [
-     "extract_kg",
-     "quiz",
-     "judge_statement",
-     "search_all",
-     "traverse_graph_by_edge",
-     "traverse_graph_atomically",
-     "traverse_graph_for_multi_hop",
-     "generate_cot",
- ]
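The names re-exported above were the package's public operator surface; downstream code would have pulled them straight from `graphgen.operators`. The import below uses only the names declared in `__all__` (their call signatures are not shown here and are not assumed):

```python
from graphgen.operators import (
    extract_kg,
    generate_cot,
    judge_statement,
    quiz,
    search_all,
    traverse_graph_atomically,
    traverse_graph_by_edge,
    traverse_graph_for_multi_hop,
)
```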