Commit 0682cc6 · 1 Parent(s): 4b2a9c2
chenzihong-gavin committed

delete

This view is limited to 50 files because it contains too many changes. See raw diff.
- hf-repo/LICENSE +0 -201
- hf-repo/README.md +0 -43
- hf-repo/app.py +0 -587
- hf-repo/graphgen/__init__.py +0 -0
- hf-repo/graphgen/configs/README.md +0 -1
- hf-repo/graphgen/configs/aggregated_config.yaml +0 -21
- hf-repo/graphgen/configs/atomic_config.yaml +0 -21
- hf-repo/graphgen/configs/cot_config.yaml +0 -13
- hf-repo/graphgen/configs/multi_hop_config.yaml +0 -21
- hf-repo/graphgen/evaluate.py +0 -142
- hf-repo/graphgen/generate.py +0 -103
- hf-repo/graphgen/graphgen.py +0 -395
- hf-repo/graphgen/judge.py +0 -60
- hf-repo/graphgen/models/__init__.py +0 -45
- hf-repo/graphgen/models/community/__init__.py +0 -0
- hf-repo/graphgen/models/community/community_detector.py +0 -95
- hf-repo/graphgen/models/embed/__init__.py +0 -0
- hf-repo/graphgen/models/embed/embedding.py +0 -29
- hf-repo/graphgen/models/evaluate/__init__.py +0 -0
- hf-repo/graphgen/models/evaluate/base_evaluator.py +0 -51
- hf-repo/graphgen/models/evaluate/length_evaluator.py +0 -22
- hf-repo/graphgen/models/evaluate/mtld_evaluator.py +0 -76
- hf-repo/graphgen/models/evaluate/reward_evaluator.py +0 -101
- hf-repo/graphgen/models/evaluate/uni_evaluator.py +0 -159
- hf-repo/graphgen/models/llm/__init__.py +0 -0
- hf-repo/graphgen/models/llm/limitter.py +0 -88
- hf-repo/graphgen/models/llm/openai_model.py +0 -155
- hf-repo/graphgen/models/llm/tokenizer.py +0 -73
- hf-repo/graphgen/models/llm/topk_token_model.py +0 -48
- hf-repo/graphgen/models/search/__init__.py +0 -0
- hf-repo/graphgen/models/search/db/__init__.py +0 -0
- hf-repo/graphgen/models/search/db/uniprot_search.py +0 -64
- hf-repo/graphgen/models/search/kg/__init__.py +0 -0
- hf-repo/graphgen/models/search/kg/wiki_search.py +0 -37
- hf-repo/graphgen/models/search/web/__init__.py +0 -0
- hf-repo/graphgen/models/search/web/bing_search.py +0 -43
- hf-repo/graphgen/models/search/web/google_search.py +0 -45
- hf-repo/graphgen/models/storage/__init__.py +0 -0
- hf-repo/graphgen/models/storage/base_storage.py +0 -115
- hf-repo/graphgen/models/storage/json_storage.py +0 -87
- hf-repo/graphgen/models/storage/networkx_storage.py +0 -159
- hf-repo/graphgen/models/strategy/__init__.py +0 -0
- hf-repo/graphgen/models/strategy/base_strategy.py +0 -5
- hf-repo/graphgen/models/strategy/travserse_strategy.py +0 -30
- hf-repo/graphgen/models/text/__init__.py +0 -0
- hf-repo/graphgen/models/text/chunk.py +0 -7
- hf-repo/graphgen/models/text/text_pair.py +0 -9
- hf-repo/graphgen/models/vis/__init__.py +0 -0
- hf-repo/graphgen/models/vis/community_visualizer.py +0 -48
- hf-repo/graphgen/operators/__init__.py +0 -22
hf-repo/LICENSE
DELETED
@@ -1,201 +0,0 @@
-                                 Apache License
-                           Version 2.0, January 2004
-                        http://www.apache.org/licenses/
-
-   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
-   1. Definitions.
-
-      "License" shall mean the terms and conditions for use, reproduction,
-      and distribution as defined by Sections 1 through 9 of this document.
-
-      "Licensor" shall mean the copyright owner or entity authorized by
-      the copyright owner that is granting the License.
-
-      "Legal Entity" shall mean the union of the acting entity and all
-      other entities that control, are controlled by, or are under common
-      control with that entity. For the purposes of this definition,
-      "control" means (i) the power, direct or indirect, to cause the
-      direction or management of such entity, whether by contract or
-      otherwise, or (ii) ownership of fifty percent (50%) or more of the
-      outstanding shares, or (iii) beneficial ownership of such entity.
-
-      "You" (or "Your") shall mean an individual or Legal Entity
-      exercising permissions granted by this License.
-
-      "Source" form shall mean the preferred form for making modifications,
-      including but not limited to software source code, documentation
-      source, and configuration files.
-
-      "Object" form shall mean any form resulting from mechanical
-      transformation or translation of a Source form, including but
-      not limited to compiled object code, generated documentation,
-      and conversions to other media types.
-
-      "Work" shall mean the work of authorship, whether in Source or
-      Object form, made available under the License, as indicated by a
-      copyright notice that is included in or attached to the work
-      (an example is provided in the Appendix below).
-
-      "Derivative Works" shall mean any work, whether in Source or Object
-      form, that is based on (or derived from) the Work and for which the
-      editorial revisions, annotations, elaborations, or other modifications
-      represent, as a whole, an original work of authorship. For the purposes
-      of this License, Derivative Works shall not include works that remain
-      separable from, or merely link (or bind by name) to the interfaces of,
-      the Work and Derivative Works thereof.
-
-      "Contribution" shall mean any work of authorship, including
-      the original version of the Work and any modifications or additions
-      to that Work or Derivative Works thereof, that is intentionally
-      submitted to Licensor for inclusion in the Work by the copyright owner
-      or by an individual or Legal Entity authorized to submit on behalf of
-      the copyright owner. For the purposes of this definition, "submitted"
-      means any form of electronic, verbal, or written communication sent
-      to the Licensor or its representatives, including but not limited to
-      communication on electronic mailing lists, source code control systems,
-      and issue tracking systems that are managed by, or on behalf of, the
-      Licensor for the purpose of discussing and improving the Work, but
-      excluding communication that is conspicuously marked or otherwise
-      designated in writing by the copyright owner as "Not a Contribution."
-
-      "Contributor" shall mean Licensor and any individual or Legal Entity
-      on behalf of whom a Contribution has been received by Licensor and
-      subsequently incorporated within the Work.
-
-   2. Grant of Copyright License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      copyright license to reproduce, prepare Derivative Works of,
-      publicly display, publicly perform, sublicense, and distribute the
-      Work and such Derivative Works in Source or Object form.
-
-   3. Grant of Patent License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      (except as stated in this section) patent license to make, have made,
-      use, offer to sell, sell, import, and otherwise transfer the Work,
-      where such license applies only to those patent claims licensable
-      by such Contributor that are necessarily infringed by their
-      Contribution(s) alone or by combination of their Contribution(s)
-      with the Work to which such Contribution(s) was submitted. If You
-      institute patent litigation against any entity (including a
-      cross-claim or counterclaim in a lawsuit) alleging that the Work
-      or a Contribution incorporated within the Work constitutes direct
-      or contributory patent infringement, then any patent licenses
-      granted to You under this License for that Work shall terminate
-      as of the date such litigation is filed.
-
-   4. Redistribution. You may reproduce and distribute copies of the
-      Work or Derivative Works thereof in any medium, with or without
-      modifications, and in Source or Object form, provided that You
-      meet the following conditions:
-
-      (a) You must give any other recipients of the Work or
-          Derivative Works a copy of this License; and
-
-      (b) You must cause any modified files to carry prominent notices
-          stating that You changed the files; and
-
-      (c) You must retain, in the Source form of any Derivative Works
-          that You distribute, all copyright, patent, trademark, and
-          attribution notices from the Source form of the Work,
-          excluding those notices that do not pertain to any part of
-          the Derivative Works; and
-
-      (d) If the Work includes a "NOTICE" text file as part of its
-          distribution, then any Derivative Works that You distribute must
-          include a readable copy of the attribution notices contained
-          within such NOTICE file, excluding those notices that do not
-          pertain to any part of the Derivative Works, in at least one
-          of the following places: within a NOTICE text file distributed
-          as part of the Derivative Works; within the Source form or
-          documentation, if provided along with the Derivative Works; or,
-          within a display generated by the Derivative Works, if and
-          wherever such third-party notices normally appear. The contents
-          of the NOTICE file are for informational purposes only and
-          do not modify the License. You may add Your own attribution
-          notices within Derivative Works that You distribute, alongside
-          or as an addendum to the NOTICE text from the Work, provided
-          that such additional attribution notices cannot be construed
-          as modifying the License.
-
-      You may add Your own copyright statement to Your modifications and
-      may provide additional or different license terms and conditions
-      for use, reproduction, or distribution of Your modifications, or
-      for any such Derivative Works as a whole, provided Your use,
-      reproduction, and distribution of the Work otherwise complies with
-      the conditions stated in this License.
-
-   5. Submission of Contributions. Unless You explicitly state otherwise,
-      any Contribution intentionally submitted for inclusion in the Work
-      by You to the Licensor shall be under the terms and conditions of
-      this License, without any additional terms or conditions.
-      Notwithstanding the above, nothing herein shall supersede or modify
-      the terms of any separate license agreement you may have executed
-      with Licensor regarding such Contributions.
-
-   6. Trademarks. This License does not grant permission to use the trade
-      names, trademarks, service marks, or product names of the Licensor,
-      except as required for reasonable and customary use in describing the
-      origin of the Work and reproducing the content of the NOTICE file.
-
-   7. Disclaimer of Warranty. Unless required by applicable law or
-      agreed to in writing, Licensor provides the Work (and each
-      Contributor provides its Contributions) on an "AS IS" BASIS,
-      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-      implied, including, without limitation, any warranties or conditions
-      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
-      PARTICULAR PURPOSE. You are solely responsible for determining the
-      appropriateness of using or redistributing the Work and assume any
-      risks associated with Your exercise of permissions under this License.
-
-   8. Limitation of Liability. In no event and under no legal theory,
-      whether in tort (including negligence), contract, or otherwise,
-      unless required by applicable law (such as deliberate and grossly
-      negligent acts) or agreed to in writing, shall any Contributor be
-      liable to You for damages, including any direct, indirect, special,
-      incidental, or consequential damages of any character arising as a
-      result of this License or out of the use or inability to use the
-      Work (including but not limited to damages for loss of goodwill,
-      work stoppage, computer failure or malfunction, or any and all
-      other commercial damages or losses), even if such Contributor
-      has been advised of the possibility of such damages.
-
-   9. Accepting Warranty or Additional Liability. While redistributing
-      the Work or Derivative Works thereof, You may choose to offer,
-      and charge a fee for, acceptance of support, warranty, indemnity,
-      or other liability obligations and/or rights consistent with this
-      License. However, in accepting such obligations, You may act only
-      on Your own behalf and on Your sole responsibility, not on behalf
-      of any other Contributor, and only if You agree to indemnify,
-      defend, and hold each Contributor harmless for any liability
-      incurred by, or claims asserted against, such Contributor by reason
-      of your accepting any such warranty or additional liability.
-
-   END OF TERMS AND CONDITIONS
-
-   APPENDIX: How to apply the Apache License to your work.
-
-      To apply the Apache License to your work, attach the following
-      boilerplate notice, with the fields enclosed by brackets "[]"
-      replaced with your own identifying information. (Don't include
-      the brackets!) The text should be enclosed in the appropriate
-      comment syntax for the file format. We also recommend that a
-      file or class name and description of purpose be included on the
-      same "printed page" as the copyright notice for easier
-      identification within third-party archives.
-
-   Copyright [yyyy] [name of copyright owner]
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
hf-repo/README.md
DELETED
@@ -1,43 +0,0 @@
----
-title: GraphGen Demo
-emoji: 📊
-colorFrom: blue
-colorTo: green
-sdk: gradio
-sdk_version: "5.44.0"
-python_version: "3.10"
-app_file: app.py
-suggested_hardware: cpu-basic
-pinned: false
-short_description: "Knowledge-driven synthetic data generation demo"
-tags:
-  - synthetic-data
-  - knowledge-graph
-  - gradio-demo
----
-
-# GraphGen Space 🤖📊
-
-This is the **official Hugging Face Space** for [GraphGen](https://github.com/open-sciencelab/GraphGen) – a framework that leverages knowledge graphs to generate high-quality synthetic question–answer pairs for supervised fine-tuning of LLMs.
-
-🔗 Paper: [arXiv 2505.20416](https://arxiv.org/abs/2505.20416)
-🐙 GitHub: [open-sciencelab/GraphGen](https://github.com/open-sciencelab/GraphGen)
-
----
-
-## How to use (🖱️ 3 clicks)
-
-1. Open the **Gradio app** above.
-2. Upload or paste your source text → click **Generate KG**.
-3. Download the generated QA pairs directly.
-
----
-
-## Local quick start (optional)
-
-```bash
-git clone https://github.com/open-sciencelab/GraphGen
-cd GraphGen
-uv venv --python 3.10 && uv pip install -r requirements.txt
-uv run webui/app.py  # http://localhost:7860
-```
hf-repo/app.py
DELETED
@@ -1,587 +0,0 @@
-import json
-import os
-import sys
-import tempfile
-
-import gradio as gr
-import pandas as pd
-from gradio_i18n import Translate
-from gradio_i18n import gettext as _
-
-from webui.base import GraphGenParams
-from webui.cache_utils import cleanup_workspace, setup_workspace
-from webui.count_tokens import count_tokens
-from webui.test_api import test_api_connection
-
-# pylint: disable=wrong-import-position
-root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-sys.path.append(root_dir)
-
-from graphgen.graphgen import GraphGen
-from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
-from graphgen.models.llm.limitter import RPM, TPM
-from graphgen.utils import set_logger
-
-css = """
-.center-row {
-    display: flex;
-    justify-content: center;
-    align-items: center;
-}
-"""
-
-
-def init_graph_gen(config: dict, env: dict) -> GraphGen:
-    # Set up working directory
-    log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
-
-    set_logger(log_file, if_stream=False)
-    graph_gen = GraphGen(working_dir=working_dir)
-
-    # Set up LLM clients
-    graph_gen.synthesizer_llm_client = OpenAIModel(
-        model_name=env.get("SYNTHESIZER_MODEL", ""),
-        base_url=env.get("SYNTHESIZER_BASE_URL", ""),
-        api_key=env.get("SYNTHESIZER_API_KEY", ""),
-        request_limit=True,
-        rpm=RPM(env.get("RPM", 1000)),
-        tpm=TPM(env.get("TPM", 50000)),
-    )
-
-    graph_gen.trainee_llm_client = OpenAIModel(
-        model_name=env.get("TRAINEE_MODEL", ""),
-        base_url=env.get("TRAINEE_BASE_URL", ""),
-        api_key=env.get("TRAINEE_API_KEY", ""),
-        request_limit=True,
-        rpm=RPM(env.get("RPM", 1000)),
-        tpm=TPM(env.get("TPM", 50000)),
-    )
-
-    graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
-
-    strategy_config = config.get("traverse_strategy", {})
-    graph_gen.traverse_strategy = TraverseStrategy(
-        qa_form=strategy_config.get("qa_form"),
-        expand_method=strategy_config.get("expand_method"),
-        bidirectional=strategy_config.get("bidirectional"),
-        max_extra_edges=strategy_config.get("max_extra_edges"),
-        max_tokens=strategy_config.get("max_tokens"),
-        max_depth=strategy_config.get("max_depth"),
-        edge_sampling=strategy_config.get("edge_sampling"),
-        isolated_node_strategy=strategy_config.get("isolated_node_strategy"),
-        loss_strategy=str(strategy_config.get("loss_strategy")),
-    )
-
-    return graph_gen
-
-
-# pylint: disable=too-many-statements
-def run_graphgen(params, progress=gr.Progress()):
-    def sum_tokens(client):
-        return sum(u["total_tokens"] for u in client.token_usage)
-
-    config = {
-        "if_trainee_model": params.if_trainee_model,
-        "input_file": params.input_file,
-        "tokenizer": params.tokenizer,
-        "quiz_samples": params.quiz_samples,
-        "traverse_strategy": {
-            "qa_form": params.qa_form,
-            "bidirectional": params.bidirectional,
-            "expand_method": params.expand_method,
-            "max_extra_edges": params.max_extra_edges,
-            "max_tokens": params.max_tokens,
-            "max_depth": params.max_depth,
-            "edge_sampling": params.edge_sampling,
-            "isolated_node_strategy": params.isolated_node_strategy,
-            "loss_strategy": params.loss_strategy,
-        },
-        "chunk_size": params.chunk_size,
-    }
-
-    env = {
-        "SYNTHESIZER_BASE_URL": params.synthesizer_url,
-        "SYNTHESIZER_MODEL": params.synthesizer_model,
-        "TRAINEE_BASE_URL": params.trainee_url,
-        "TRAINEE_MODEL": params.trainee_model,
-        "SYNTHESIZER_API_KEY": params.api_key,
-        "TRAINEE_API_KEY": params.trainee_api_key,
-        "RPM": params.rpm,
-        "TPM": params.tpm,
-    }
-
-    # Test API connection
-    test_api_connection(
-        env["SYNTHESIZER_BASE_URL"],
-        env["SYNTHESIZER_API_KEY"],
-        env["SYNTHESIZER_MODEL"],
-    )
-    if config["if_trainee_model"]:
-        test_api_connection(
-            env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
-        )
-
-    # Initialize GraphGen
-    graph_gen = init_graph_gen(config, env)
-    graph_gen.clear()
-
-    graph_gen.progress_bar = progress
-
-    try:
-        # Load input data
-        file = config["input_file"]
-        if isinstance(file, list):
-            file = file[0]
-
-        data = []
-
-        if file.endswith(".jsonl"):
-            data_type = "raw"
-            with open(file, "r", encoding="utf-8") as f:
-                data.extend(json.loads(line) for line in f)
-        elif file.endswith(".json"):
-            data_type = "chunked"
-            with open(file, "r", encoding="utf-8") as f:
-                data.extend(json.load(f))
-        elif file.endswith(".txt"):
-            # Read the file, then convert it into raw-format data according to chunk_size
-            data_type = "raw"
-            content = ""
-            with open(file, "r", encoding="utf-8") as f:
-                lines = f.readlines()
-                for line in lines:
-                    content += line.strip() + " "
-            size = int(config.get("chunk_size", 512))
-            chunks = [content[i : i + size] for i in range(0, len(content), size)]
-            data.extend([{"content": chunk} for chunk in chunks])
-        else:
-            raise ValueError(f"Unsupported file type: {file}")
-
-        # Process the data
-        graph_gen.insert(data, data_type)
-
-        if config["if_trainee_model"]:
-            # Generate quiz
-            graph_gen.quiz(max_samples=config["quiz_samples"])
-
-            # Judge statements
-            graph_gen.judge()
-        else:
-            graph_gen.traverse_strategy.edge_sampling = "random"
-            # Skip judge statements
-            graph_gen.judge(skip=True)
-
-        # Traverse graph
-        graph_gen.traverse(traverse_strategy=graph_gen.traverse_strategy)
-
-        # Save output
-        output_data = graph_gen.qa_storage.data
-        with tempfile.NamedTemporaryFile(
-            mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
-        ) as tmpfile:
-            json.dump(output_data, tmpfile, ensure_ascii=False)
-            output_file = tmpfile.name
-
-        synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
-        trainee_tokens = (
-            sum_tokens(graph_gen.trainee_llm_client)
-            if config["if_trainee_model"]
-            else 0
-        )
-        total_tokens = synthesizer_tokens + trainee_tokens
-
-        data_frame = params.token_counter
-        try:
-            _update_data = [
-                [data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)]
-            ]
-            new_df = pd.DataFrame(_update_data, columns=data_frame.columns)
-            data_frame = new_df
-
-        except Exception as e:
-            raise gr.Error(f"DataFrame operation error: {str(e)}")
-
-        return output_file, gr.DataFrame(
-            label="Token Stats",
-            headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
-            datatype="str",
-            interactive=False,
-            value=data_frame,
-            visible=True,
-            wrap=True,
-        )
-
-    except Exception as e:  # pylint: disable=broad-except
-        raise gr.Error(f"Error occurred: {str(e)}")
-
-    finally:
-        # Clean up workspace
-        cleanup_workspace(graph_gen.working_dir)
-
-
-with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
-    # Header
-    gr.Image(
-        value=os.path.join(root_dir, "resources", "images", "logo.png"),
-        label="GraphGen Banner",
-        elem_id="banner",
-        interactive=False,
-        container=False,
-        show_download_button=False,
-        show_fullscreen_button=False,
-    )
-    lang_btn = gr.Radio(
-        choices=[
-            ("English", "en"),
-            ("简体中文", "zh"),
-        ],
-        value="en",
-        # label=_("Language"),
-        render=False,
-        container=False,
-        elem_classes=["center-row"],
-    )
-
-    gr.HTML(
-        """
-        <div style="display: flex; gap: 8px; margin-left: auto; align-items: center; justify-content: center;">
-            <a href="https://github.com/open-sciencelab/GraphGen/releases">
-                <img src="https://img.shields.io/badge/Version-v0.1.0-blue" alt="Version">
-            </a>
-            <a href="https://graphgen-docs.example.com">
-                <img src="https://img.shields.io/badge/Docs-Latest-brightgreen" alt="Documentation">
-            </a>
-            <a href="https://github.com/open-sciencelab/GraphGen/issues/10">
-                <img src="https://img.shields.io/github/stars/open-sciencelab/GraphGen?style=social" alt="GitHub Stars">
-            </a>
-            <a href="https://arxiv.org/abs/2505.20416">
-                <img src="https://img.shields.io/badge/arXiv-pdf-yellow" alt="arXiv">
-            </a>
-        </div>
-        """
-    )
-    with Translate(
-        os.path.join(root_dir, "webui", "translation.json"),
-        lang_btn,
-        placeholder_langs=["en", "zh"],
-        persistant=False,  # True to save the language setting in the browser. Requires gradio >= 5.6.0
-    ):
-        lang_btn.render()
-
-        gr.Markdown(
-            value="# "
-            + _("Title")
-            + "\n\n"
-            + "### [GraphGen](https://github.com/open-sciencelab/GraphGen) "
-            + _("Intro")
-        )
-
-        if_trainee_model = gr.Checkbox(
-            label=_("Use Trainee Model"), value=False, interactive=True
-        )
-
-        with gr.Accordion(label=_("Model Config"), open=False):
-            synthesizer_url = gr.Textbox(
-                label="Synthesizer URL",
-                value="https://api.siliconflow.cn/v1",
-                info=_("Synthesizer URL Info"),
-                interactive=True,
-            )
-            synthesizer_model = gr.Textbox(
-                label="Synthesizer Model",
-                value="Qwen/Qwen2.5-7B-Instruct",
-                info=_("Synthesizer Model Info"),
-                interactive=True,
-            )
-            trainee_url = gr.Textbox(
-                label="Trainee URL",
-                value="https://api.siliconflow.cn/v1",
-                info=_("Trainee URL Info"),
-                interactive=True,
-                visible=if_trainee_model.value is True,
-            )
-            trainee_model = gr.Textbox(
-                label="Trainee Model",
-                value="Qwen/Qwen2.5-7B-Instruct",
-                info=_("Trainee Model Info"),
-                interactive=True,
-                visible=if_trainee_model.value is True,
-            )
-            trainee_api_key = gr.Textbox(
-                label=_("SiliconFlow Token for Trainee Model"),
-                type="password",
-                value="",
-                info="https://cloud.siliconflow.cn/account/ak",
-                visible=if_trainee_model.value is True,
-            )
-
-        with gr.Accordion(label=_("Generation Config"), open=False):
-            chunk_size = gr.Slider(
-                label="Chunk Size",
-                minimum=256,
-                maximum=4096,
-                value=512,
-                step=256,
-                interactive=True,
-            )
-            tokenizer = gr.Textbox(
-                label="Tokenizer", value="cl100k_base", interactive=True
-            )
-            qa_form = gr.Radio(
-                choices=["atomic", "multi_hop", "aggregated"],
-                label="QA Form",
-                value="aggregated",
-                interactive=True,
-            )
-            quiz_samples = gr.Number(
-                label="Quiz Samples",
-                value=2,
-                minimum=1,
-                interactive=True,
-                visible=if_trainee_model.value is True,
-            )
-            bidirectional = gr.Checkbox(
-                label="Bidirectional", value=True, interactive=True
-            )
-
-            expand_method = gr.Radio(
-                choices=["max_width", "max_tokens"],
-                label="Expand Method",
-                value="max_tokens",
-                interactive=True,
-            )
-            max_extra_edges = gr.Slider(
-                minimum=1,
-                maximum=10,
-                value=5,
-                label="Max Extra Edges",
-                step=1,
-                interactive=True,
-                visible=expand_method.value == "max_width",
-            )
-            max_tokens = gr.Slider(
-                minimum=64,
-                maximum=1024,
-                value=256,
-                label="Max Tokens",
-                step=64,
-                interactive=True,
-                visible=(expand_method.value != "max_width"),
-            )
-
-            max_depth = gr.Slider(
-                minimum=1,
-                maximum=5,
-                value=2,
-                label="Max Depth",
-                step=1,
-                interactive=True,
-            )
-            edge_sampling = gr.Radio(
-                choices=["max_loss", "min_loss", "random"],
-                label="Edge Sampling",
-                value="max_loss",
-                interactive=True,
-                visible=if_trainee_model.value is True,
-            )
-            isolated_node_strategy = gr.Radio(
-                choices=["add", "ignore"],
-                label="Isolated Node Strategy",
-                value="ignore",
-                interactive=True,
-            )
-            loss_strategy = gr.Radio(
-                choices=["only_edge", "both"],
-                label="Loss Strategy",
-                value="only_edge",
-                interactive=True,
-            )
-
-        with gr.Row(equal_height=True):
-            with gr.Column(scale=3):
-                api_key = gr.Textbox(
-                    label=_("SiliconFlow Token"),
-                    type="password",
-                    value="",
-                    info="https://cloud.siliconflow.cn/account/ak",
-                )
-            with gr.Column(scale=1):
-                test_connection_btn = gr.Button(_("Test Connection"))
-
-        with gr.Blocks():
-            with gr.Row(equal_height=True):
-                with gr.Column():
-                    rpm = gr.Slider(
-                        label="RPM",
-                        minimum=10,
-                        maximum=10000,
-                        value=1000,
-                        step=100,
-                        interactive=True,
-                        visible=True,
-                    )
-                with gr.Column():
-                    tpm = gr.Slider(
-                        label="TPM",
-                        minimum=5000,
-                        maximum=5000000,
-                        value=50000,
-                        step=1000,
-                        interactive=True,
-                        visible=True,
-                    )
-
-        with gr.Blocks():
-            with gr.Row(equal_height=True):
-                with gr.Column(scale=1):
-                    upload_file = gr.File(
-                        label=_("Upload File"),
-                        file_count="single",
-                        file_types=[".txt", ".json", ".jsonl"],
-                        interactive=True,
-                    )
-                    examples_dir = os.path.join(root_dir, "webui", "examples")
-                    gr.Examples(
-                        examples=[
-                            [os.path.join(examples_dir, "txt_demo.txt")],
-                            [os.path.join(examples_dir, "raw_demo.jsonl")],
-                            [os.path.join(examples_dir, "chunked_demo.json")],
-                        ],
-                        inputs=upload_file,
-                        label=_("Example Files"),
-                        examples_per_page=3,
-                    )
-                with gr.Column(scale=1):
-                    output = gr.File(
-                        label="Output(See Github FAQ)",
-                        file_count="single",
-                        interactive=False,
-                    )
-
-        with gr.Blocks():
-            token_counter = gr.DataFrame(
-                label="Token Stats",
-                headers=[
-                    "Source Text Token Count",
-                    "Estimated Token Usage",
-                    "Token Used",
-                ],
-                datatype="str",
-                interactive=False,
-                visible=False,
-                wrap=True,
-            )
-
-        submit_btn = gr.Button(_("Run GraphGen"))
-
-        # Test Connection
-        test_connection_btn.click(
-            test_api_connection,
-            inputs=[synthesizer_url, api_key, synthesizer_model],
-            outputs=[],
-        )
-
-        if if_trainee_model.value:
-            test_connection_btn.click(
-                test_api_connection,
-                inputs=[trainee_url, api_key, trainee_model],
-                outputs=[],
-            )
-
-        expand_method.change(
-            lambda method: (
-                gr.update(visible=method == "max_width"),
-                gr.update(visible=method != "max_width"),
-            ),
-            inputs=expand_method,
-            outputs=[max_extra_edges, max_tokens],
-        )
-
-        if_trainee_model.change(
-            lambda use_trainee: [gr.update(visible=use_trainee)] * 5,
-            inputs=if_trainee_model,
-            outputs=[
-                trainee_url,
-                trainee_model,
-                quiz_samples,
-                edge_sampling,
-                trainee_api_key,
-            ],
-        )
-
-        upload_file.change(
-            lambda x: (gr.update(visible=True)),
-            inputs=[upload_file],
-            outputs=[token_counter],
-        ).then(
-            count_tokens,
-            inputs=[upload_file, tokenizer, token_counter],
-            outputs=[token_counter],
-        )
-
-        # run GraphGen
-        submit_btn.click(
-            lambda x: (gr.update(visible=False)),
-            inputs=[token_counter],
-            outputs=[token_counter],
-        )
-
-        submit_btn.click(
-            lambda *args: run_graphgen(
-                GraphGenParams(
-                    if_trainee_model=args[0],
-                    input_file=args[1],
-                    tokenizer=args[2],
-                    qa_form=args[3],
-                    bidirectional=args[4],
-                    expand_method=args[5],
-                    max_extra_edges=args[6],
-                    max_tokens=args[7],
-                    max_depth=args[8],
-                    edge_sampling=args[9],
-                    isolated_node_strategy=args[10],
-                    loss_strategy=args[11],
-                    synthesizer_url=args[12],
-                    synthesizer_model=args[13],
-                    trainee_model=args[14],
-                    api_key=args[15],
-                    chunk_size=args[16],
-                    rpm=args[17],
-                    tpm=args[18],
-                    quiz_samples=args[19],
-                    trainee_url=args[20],
-                    trainee_api_key=args[21],
-                    token_counter=args[22],
-                )
-            ),
-            inputs=[
-                if_trainee_model,
-                upload_file,
-                tokenizer,
-                qa_form,
-                bidirectional,
-                expand_method,
-                max_extra_edges,
-                max_tokens,
-                max_depth,
-                edge_sampling,
-                isolated_node_strategy,
-                loss_strategy,
-                synthesizer_url,
-                synthesizer_model,
-                trainee_model,
-                api_key,
-                chunk_size,
-                rpm,
-                tpm,
-                quiz_samples,
-                trainee_url,
-                trainee_api_key,
-                token_counter,
-            ],
-            outputs=[output, token_counter],
-        )
-
-if __name__ == "__main__":
-    demo.queue(api_open=False, default_concurrency_limit=2)
-    demo.launch(server_name="0.0.0.0")
hf-repo/graphgen/__init__.py
DELETED
File without changes
hf-repo/graphgen/configs/README.md
DELETED
@@ -1 +0,0 @@
-# Configs for GraphGen
hf-repo/graphgen/configs/aggregated_config.yaml
DELETED
@@ -1,21 +0,0 @@
-input_data_type: raw # raw, chunked
-input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
-output_data_type: aggregated # atomic, aggregated, multi_hop, cot
-output_data_format: ChatML # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-search: # web search configuration
-  enabled: false # whether to enable web search
-  search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
-  enabled: true
-  quiz_samples: 2 # number of quiz samples to generate
-  re_judge: false # whether to re-judge the existing quiz samples
-traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
-  bidirectional: true # whether to traverse the graph in both directions
-  edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-  expand_method: max_width # expand method, support: max_width, max_depth
-  isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-  max_depth: 5 # maximum depth for graph traversal
-  max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
-  max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-  loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
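The four YAML files in this commit share one loading path: generate.py, further down in this diff, reads them with PyYAML and indexes the top-level keys directly. A minimal sketch of that consumption, assuming PyYAML is installed and the config path above; `yaml.safe_load` stands in for the script's `yaml.load(..., Loader=yaml.FullLoader)`:

```python
import yaml

# Minimal sketch: read the deleted aggregated config the way generate.py
# in this diff does, then inspect the traversal block that drives sampling.
with open("graphgen/configs/aggregated_config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

strategy = config["traverse_strategy"]
print(config["output_data_type"])   # aggregated
print(strategy["edge_sampling"])    # max_loss
print(strategy["max_extra_edges"])  # 20
```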
hf-repo/graphgen/configs/atomic_config.yaml
DELETED
@@ -1,21 +0,0 @@
-input_data_type: raw # raw, chunked
-input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
-output_data_type: atomic # atomic, aggregated, multi_hop, cot
-output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-search: # web search configuration
-  enabled: false # whether to enable web search
-  search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
-  enabled: true
-  quiz_samples: 2 # number of quiz samples to generate
-  re_judge: false # whether to re-judge the existing quiz samples
-traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
-  bidirectional: true # whether to traverse the graph in both directions
-  edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-  expand_method: max_width # expand method, support: max_width, max_depth
-  isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-  max_depth: 3 # maximum depth for graph traversal
-  max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
-  max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-  loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
hf-repo/graphgen/configs/cot_config.yaml
DELETED
@@ -1,13 +0,0 @@
-input_data_type: raw # raw, chunked
-input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
-output_data_type: cot # atomic, aggregated, multi_hop, cot
-output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-search: # web search configuration
-  enabled: false # whether to enable web search
-  search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-method_params:
-  method: leiden
-  max_size: 20 # Maximum size of communities
-  use_lcc: false
-  random_seed: 42
hf-repo/graphgen/configs/multi_hop_config.yaml
DELETED
@@ -1,21 +0,0 @@
-input_data_type: raw # raw, chunked
-input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
-output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
-output_data_format: ChatML # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-search: # web search configuration
-  enabled: false # whether to enable web search
-  search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
-  enabled: true
-  quiz_samples: 2 # number of quiz samples to generate
-  re_judge: false # whether to re-judge the existing quiz samples
-traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
-  bidirectional: true # whether to traverse the graph in both directions
-  edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-  expand_method: max_width # expand method, support: max_width, max_depth
-  isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-  max_depth: 1 # maximum depth for graph traversal
-  max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
-  max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-  loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
hf-repo/graphgen/evaluate.py
DELETED
@@ -1,142 +0,0 @@
-"""Evaluate the quality of the generated text using various metrics"""
-
-import os
-import json
-import argparse
-import pandas as pd
-from dotenv import load_dotenv
-from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, TextPair, UniEvaluator
-from .utils import logger, set_logger
-
-sys_path = os.path.abspath(os.path.dirname(__file__))
-set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log"))
-
-load_dotenv()
-
-def evaluate_length(corpus, tokenizer_name):
-    length_evaluator = LengthEvaluator(
-        tokenizer_name=tokenizer_name
-    )
-    logger.info("Length evaluator loaded")
-    scores = length_evaluator.get_average_score(corpus)
-    logger.info("Length scores: %s", scores)
-    return scores
-
-def evaluate_mtld(corpus):
-    mtld_evaluator = MTLDEvaluator()
-    logger.info("MTLD evaluator loaded")
-    scores = mtld_evaluator.get_average_score(corpus)
-    logger.info("MTLD scores: %s", scores)
-    min_max_scores = mtld_evaluator.get_min_max_score(corpus)
-    logger.info("MTLD min max scores: %s", min_max_scores)
-    return scores, min_max_scores
-
-def evaluate_reward(corpus, reward_model_names):
-    scores = []
-    for reward_name in reward_model_names:
-        reward_evaluator = RewardEvaluator(
-            reward_name=reward_name
-        )
-        logger.info("Loaded reward model: %s", reward_name)
-        average_score = reward_evaluator.get_average_score(corpus)
-        logger.info("%s scores: %s", reward_name, average_score)
-        min_max_scores = reward_evaluator.get_min_max_score(corpus)
-        logger.info("%s min max scores: %s", reward_name, min_max_scores)
-        scores.append({
-            'reward_name': reward_name.split('/')[-1],
-            'score': average_score,
-            'min_max_scores': min_max_scores
-        })
-        del reward_evaluator
-        clean_gpu_cache()
-    return scores
-
-def evaluate_uni(corpus, uni_model_name):
-    uni_evaluator = UniEvaluator(
-        model_name=uni_model_name
-    )
-    logger.info("Uni evaluator loaded with model %s", uni_model_name)
-    uni_scores = uni_evaluator.get_average_score(corpus)
-    for key, value in uni_scores.items():
-        logger.info("Uni %s scores: %s", key, value)
-    min_max_scores = uni_evaluator.get_min_max_score(corpus)
-    for key, value in min_max_scores.items():
-        logger.info("Uni %s min max scores: %s", key, value)
-    del uni_evaluator
-    clean_gpu_cache()
-    return (uni_scores['naturalness'], uni_scores['coherence'], uni_scores['understandability'],
-            min_max_scores['naturalness'], min_max_scores['coherence'], min_max_scores['understandability'])
-
-
-def clean_gpu_cache():
-    import torch
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
-
-if __name__ == '__main__':
-    import torch.multiprocessing as mp
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument('--folder', type=str, default='cache/data', help='folder to load data')
-    parser.add_argument('--output', type=str, default='cache/output', help='path to save output')
-
-    parser.add_argument('--tokenizer', type=str, default='cl100k_base', help='tokenizer name')
-    parser.add_argument('--reward', type=str, default='OpenAssistant/reward-model-deberta-v3-large-v2',
-                        help='Comma-separated list of reward models')
-    parser.add_argument('--uni', type=str, default='MingZhong/unieval-sum', help='uni model name')
-
-    args = parser.parse_args()
-
-    if not os.path.exists(args.folder):
-        raise ValueError(f"Folder {args.folder} does not exist")
-
-    if not os.path.exists(args.output):
-        os.makedirs(args.output)
-
-    reward_models = args.reward.split(',')
-
-
-    results = []
-
-    logger.info("Data loaded from %s", args.folder)
-    mp.set_start_method('spawn')
-
-    for file in os.listdir(args.folder):
-        if file.endswith('.json'):
-            logger.info("Processing %s", file)
-            with open(os.path.join(args.folder, file), 'r', encoding='utf-8') as f:
-                data = json.load(f)
-                data = [TextPair(
-                    question=data[key]['question'],
-                    answer=data[key]['answer']
-                ) for key in data]
-
-            length_scores = evaluate_length(data, args.tokenizer)
-            mtld_scores, min_max_mtld_scores = evaluate_mtld(data)
-            reward_scores = evaluate_reward(data, reward_models)
-            uni_naturalness_scores, uni_coherence_scores, uni_understandability_scores, \
-            min_max_uni_naturalness_scores, min_max_uni_coherence_scores, min_max_uni_understandability_scores \
-                = evaluate_uni(data, args.uni)
-
-            result = {
-                'file': file,
-                'number': len(data),
-                'length': length_scores,
-                'mtld': mtld_scores,
-                'mtld_min_max': min_max_mtld_scores,
-                'uni_naturalness': uni_naturalness_scores,
-                'uni_coherence': uni_coherence_scores,
-                'uni_understandability': uni_understandability_scores,
-                'uni_naturalness_min_max': min_max_uni_naturalness_scores,
-                'uni_coherence_min_max': min_max_uni_coherence_scores,
-                'uni_understandability_min_max': min_max_uni_understandability_scores
-            }
-            for reward_score in reward_scores:
-                result[reward_score['reward_name']] = reward_score['score']
-                result[f"{reward_score['reward_name']}_min_max"] = reward_score['min_max_scores']
-
-            results.append(result)
-
-    results = pd.DataFrame(results)
-    results.to_csv(os.path.join(args.output, 'evaluation.csv'), index=False)
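The evaluation loop above indexes each JSON file as `data[key]['question']` / `data[key]['answer']`, so the expected corpus is a dict keyed by sample ID. A sketch of a conforming input file; the path and QA strings are illustrative only:

```python
import json

# Illustrative corpus in the shape evaluate.py iterates over:
# a dict keyed by sample ID, each value a question/answer pair.
corpus = {
    "0": {"question": "What is GraphGen?",
          "answer": "A knowledge-graph-driven synthetic data framework."},
    "1": {"question": "Which tokenizer does the demo default to?",
          "answer": "cl100k_base."},
}
with open("cache/data/demo.json", "w", encoding="utf-8") as f:
    json.dump(corpus, f, ensure_ascii=False, indent=2)
```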
hf-repo/graphgen/generate.py
DELETED
@@ -1,103 +0,0 @@
import argparse
import os
import time
from importlib.resources import files

import yaml
from dotenv import load_dotenv

from .graphgen import GraphGen
from .utils import logger, set_logger

sys_path = os.path.abspath(os.path.dirname(__file__))

load_dotenv()


def set_working_dir(folder):
    os.makedirs(folder, exist_ok=True)
    os.makedirs(os.path.join(folder, "data", "graphgen"), exist_ok=True)
    os.makedirs(os.path.join(folder, "logs"), exist_ok=True)


def save_config(config_path, global_config):
    if not os.path.exists(os.path.dirname(config_path)):
        os.makedirs(os.path.dirname(config_path))
    with open(config_path, "w", encoding="utf-8") as config_file:
        yaml.dump(
            global_config, config_file, default_flow_style=False, allow_unicode=True
        )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        help="Config parameters for GraphGen.",
        default=files("graphgen").joinpath("configs", "aggregated_config.yaml"),
        type=str,
    )
    parser.add_argument(
        "--output_dir",
        help="Output directory for GraphGen.",
        default=sys_path,
        required=True,
        type=str,
    )

    args = parser.parse_args()

    working_dir = args.output_dir
    set_working_dir(working_dir)

    with open(args.config_file, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    output_data_type = config["output_data_type"]
    unique_id = int(time.time())
    set_logger(
        os.path.join(
            working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
        ),
        if_stream=True,
    )
    logger.info(
        "GraphGen with unique ID %s logging to %s",
        unique_id,
        os.path.join(
            working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
        ),
    )

    graph_gen = GraphGen(working_dir=working_dir, unique_id=unique_id, config=config)

    graph_gen.insert()

    if config["search"]["enabled"]:
        graph_gen.search()

    # Use pipeline according to the output data type
    if output_data_type in ["atomic", "aggregated", "multi_hop"]:
        if "quiz_and_judge_strategy" in config and config[
            "quiz_and_judge_strategy"
        ].get("enabled", False):
            graph_gen.quiz()
            graph_gen.judge()
        else:
            logger.warning(
                "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
            )
            graph_gen.traverse_strategy.edge_sampling = "random"
        graph_gen.traverse()
    elif output_data_type == "cot":
        graph_gen.generate_reasoning(method_params=config["method_params"])
    else:
        raise ValueError(f"Unsupported output data type: {output_data_type}")

    output_path = os.path.join(working_dir, "data", "graphgen", str(unique_id))
    save_config(os.path.join(output_path, f"config-{unique_id}.yaml"), config)
    logger.info("GraphGen completed successfully. Data saved to %s", output_path)


if __name__ == "__main__":
    main()
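A minimal sketch of driving this entry point programmatically, assuming the package is importable as graphgen and the SYNTHESIZER_*/TRAINEE_* environment variables are available (for example via a .env file); the config path and output directory below are placeholders:

# Illustrative only: invoke the CLI above by patching sys.argv, which is
# equivalent to `python -m graphgen.generate --config_file ... --output_dir ...`.
import sys
from graphgen.generate import main

sys.argv = [
    "generate",
    "--config_file", "graphgen/configs/aggregated_config.yaml",
    "--output_dir", "./cache",
]
main()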
hf-repo/graphgen/graphgen.py
DELETED
@@ -1,395 +0,0 @@
import asyncio
import os
import time
from dataclasses import dataclass, field
from typing import Dict, List, Union, cast

import gradio as gr
from tqdm.asyncio import tqdm as tqdm_async

from .models import (
    Chunk,
    JsonKVStorage,
    JsonListStorage,
    NetworkXStorage,
    OpenAIModel,
    Tokenizer,
    TraverseStrategy,
)
from .models.storage.base_storage import StorageNameSpace
from .operators import (
    extract_kg,
    generate_cot,
    judge_statement,
    quiz,
    search_all,
    traverse_graph_atomically,
    traverse_graph_by_edge,
    traverse_graph_for_multi_hop,
)
from .utils import (
    compute_content_hash,
    create_event_loop,
    format_generation_results,
    logger,
    read_file,
)

sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))


@dataclass
class GraphGen:
    unique_id: int = int(time.time())
    working_dir: str = os.path.join(sys_path, "cache")
    config: Dict = field(default_factory=dict)

    # llm
    tokenizer_instance: Tokenizer = None
    synthesizer_llm_client: OpenAIModel = None
    trainee_llm_client: OpenAIModel = None

    # text chunking
    # TODO: make it configurable
    chunk_size: int = 1024
    chunk_overlap_size: int = 100

    # search
    search_config: dict = field(
        default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
    )

    # traversal
    traverse_strategy: TraverseStrategy = None

    # webui
    progress_bar: gr.Progress = None

    def __post_init__(self):
        self.tokenizer_instance: Tokenizer = Tokenizer(
            model_name=self.config["tokenizer"]
        )
        self.synthesizer_llm_client: OpenAIModel = OpenAIModel(
            model_name=os.getenv("SYNTHESIZER_MODEL"),
            api_key=os.getenv("SYNTHESIZER_API_KEY"),
            base_url=os.getenv("SYNTHESIZER_BASE_URL"),
            tokenizer_instance=self.tokenizer_instance,
        )
        self.trainee_llm_client: OpenAIModel = OpenAIModel(
            model_name=os.getenv("TRAINEE_MODEL"),
            api_key=os.getenv("TRAINEE_API_KEY"),
            base_url=os.getenv("TRAINEE_BASE_URL"),
            tokenizer_instance=self.tokenizer_instance,
        )
        self.search_config = self.config["search"]

        if "traverse_strategy" in self.config:
            self.traverse_strategy = TraverseStrategy(
                **self.config["traverse_strategy"]
            )

        self.full_docs_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="full_docs"
        )
        self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="text_chunks"
        )
        self.graph_storage: NetworkXStorage = NetworkXStorage(
            self.working_dir, namespace="graph"
        )
        self.search_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="search"
        )
        self.rephrase_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="rephrase"
        )
        self.qa_storage: JsonListStorage = JsonListStorage(
            os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)),
            namespace=f"qa-{self.unique_id}",
        )

    async def async_split_chunks(
        self, data: List[Union[List, Dict]], data_type: str
    ) -> dict:
        # TODO: configurable whether to use coreference resolution
        if len(data) == 0:
            return {}

        inserting_chunks = {}
        if data_type == "raw":
            assert isinstance(data, list) and isinstance(data[0], dict)
            # compute hash for each document
            new_docs = {
                compute_content_hash(doc["content"], prefix="doc-"): {
                    "content": doc["content"]
                }
                for doc in data
            }
            _add_doc_keys = await self.full_docs_storage.filter_keys(
                list(new_docs.keys())
            )
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if len(new_docs) == 0:
                logger.warning("All docs are already in the storage")
                return {}
            logger.info("[New Docs] inserting %d docs", len(new_docs))

            cur_index = 1
            doc_number = len(new_docs)
            async for doc_key, doc in tqdm_async(
                new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
            ):
                chunks = {
                    compute_content_hash(dp["content"], prefix="chunk-"): {
                        **dp,
                        "full_doc_id": doc_key,
                    }
                    for dp in self.tokenizer_instance.chunk_by_token_size(
                        doc["content"], self.chunk_overlap_size, self.chunk_size
                    )
                }
                inserting_chunks.update(chunks)

                if self.progress_bar is not None:
                    self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
                cur_index += 1

            _add_chunk_keys = await self.text_chunks_storage.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
        elif data_type == "chunked":
            assert isinstance(data, list) and isinstance(data[0], list)
            new_docs = {
                compute_content_hash("".join(chunk["content"]), prefix="doc-"): {
                    "content": "".join(chunk["content"])
                }
                for doc in data
                for chunk in doc
            }
            _add_doc_keys = await self.full_docs_storage.filter_keys(
                list(new_docs.keys())
            )
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if len(new_docs) == 0:
                logger.warning("All docs are already in the storage")
                return {}
            logger.info("[New Docs] inserting %d docs", len(new_docs))
            async for doc in tqdm_async(
                data, desc="[1/4]Chunking documents", unit="doc"
            ):
                doc_str = "".join([chunk["content"] for chunk in doc])
                for chunk in doc:
                    chunk_key = compute_content_hash(chunk["content"], prefix="chunk-")
                    inserting_chunks[chunk_key] = {
                        **chunk,
                        "full_doc_id": compute_content_hash(doc_str, prefix="doc-"),
                    }
            _add_chunk_keys = await self.text_chunks_storage.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
        else:
            raise ValueError(f"Unknown data type: {data_type}")

        await self.full_docs_storage.upsert(new_docs)
        await self.text_chunks_storage.upsert(inserting_chunks)

        return inserting_chunks

    def insert(self):
        loop = create_event_loop()
        loop.run_until_complete(self.async_insert())

    async def async_insert(self):
        """
        insert chunks into the graph
        """

        input_file = self.config["input_file"]
        data_type = self.config["input_data_type"]
        data = read_file(input_file)

        inserting_chunks = await self.async_split_chunks(data, data_type)

        if len(inserting_chunks) == 0:
            logger.warning("All chunks are already in the storage")
            return
        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))

        logger.info("[Entity and Relation Extraction]...")
        _add_entities_and_relations = await extract_kg(
            llm_client=self.synthesizer_llm_client,
            kg_instance=self.graph_storage,
            tokenizer_instance=self.tokenizer_instance,
            chunks=[
                Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
            ],
            progress_bar=self.progress_bar,
        )
        if not _add_entities_and_relations:
            logger.warning("No entities or relations extracted")
            return

        await self._insert_done()

    async def _insert_done(self):
        tasks = []
        for storage_instance in [
            self.full_docs_storage,
            self.text_chunks_storage,
            self.graph_storage,
            self.search_storage,
        ]:
            if storage_instance is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
        await asyncio.gather(*tasks)

    def search(self):
        loop = create_event_loop()
        loop.run_until_complete(self.async_search())

    async def async_search(self):
        logger.info(
            "Search is %s", "enabled" if self.search_config["enabled"] else "disabled"
        )
        if self.search_config["enabled"]:
            logger.info(
                "[Search] %s ...", ", ".join(self.search_config["search_types"])
            )
            all_nodes = await self.graph_storage.get_all_nodes()
            all_nodes_names = [node[0] for node in all_nodes]
            new_search_entities = await self.full_docs_storage.filter_keys(
                all_nodes_names
            )
            logger.info(
                "[Search] Found %d entities to search", len(new_search_entities)
            )
            _add_search_data = await search_all(
                search_types=self.search_config["search_types"],
                search_entities=new_search_entities,
            )
            if _add_search_data:
                await self.search_storage.upsert(_add_search_data)
                logger.info("[Search] %d entities searched", len(_add_search_data))

                # Format search results for inserting
                search_results = []
                for _, search_data in _add_search_data.items():
                    search_results.extend(
                        [
                            {"content": search_data[key]}
                            for key in list(search_data.keys())
                        ]
                    )
                # TODO: fix insert after search
                await self.async_insert()

    def quiz(self):
        loop = create_event_loop()
        loop.run_until_complete(self.async_quiz())

    async def async_quiz(self):
        max_samples = self.config["quiz_and_judge_strategy"]["quiz_samples"]
        await quiz(
            self.synthesizer_llm_client,
            self.graph_storage,
            self.rephrase_storage,
            max_samples,
        )
        await self.rephrase_storage.index_done_callback()

    def judge(self):
        loop = create_event_loop()
        loop.run_until_complete(self.async_judge())

    async def async_judge(self):
        re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
        _update_relations = await judge_statement(
            self.trainee_llm_client,
            self.graph_storage,
            self.rephrase_storage,
            re_judge,
        )
        await _update_relations.index_done_callback()

    def traverse(self):
        loop = create_event_loop()
        loop.run_until_complete(self.async_traverse())

    async def async_traverse(self):
        output_data_type = self.config["output_data_type"]

        if output_data_type == "atomic":
            results = await traverse_graph_atomically(
                self.synthesizer_llm_client,
                self.tokenizer_instance,
                self.graph_storage,
                self.traverse_strategy,
                self.text_chunks_storage,
                self.progress_bar,
            )
        elif output_data_type == "multi_hop":
            results = await traverse_graph_for_multi_hop(
                self.synthesizer_llm_client,
                self.tokenizer_instance,
                self.graph_storage,
                self.traverse_strategy,
                self.text_chunks_storage,
                self.progress_bar,
            )
        elif output_data_type == "aggregated":
            results = await traverse_graph_by_edge(
                self.synthesizer_llm_client,
                self.tokenizer_instance,
                self.graph_storage,
                self.traverse_strategy,
                self.text_chunks_storage,
                self.progress_bar,
            )
        else:
            raise ValueError(f"Unknown qa_form: {output_data_type}")

        results = format_generation_results(
            results, output_data_format=self.config["output_data_format"]
        )

        await self.qa_storage.upsert(results)
        await self.qa_storage.index_done_callback()

    def generate_reasoning(self, method_params):
        loop = create_event_loop()
        loop.run_until_complete(self.async_generate_reasoning(method_params))

    async def async_generate_reasoning(self, method_params):
        results = await generate_cot(
            self.graph_storage,
            self.synthesizer_llm_client,
            method_params=method_params,
        )

        results = format_generation_results(
            results, output_data_format=self.config["output_data_format"]
        )

        await self.qa_storage.upsert(results)
        await self.qa_storage.index_done_callback()

    def clear(self):
        loop = create_event_loop()
        loop.run_until_complete(self.async_clear())

    async def async_clear(self):
        await self.full_docs_storage.drop()
        await self.text_chunks_storage.drop()
        await self.search_storage.drop()
        await self.graph_storage.clear()
        await self.rephrase_storage.drop()
        await self.qa_storage.drop()

        logger.info("All caches are cleared")
hf-repo/graphgen/judge.py
DELETED
@@ -1,60 +0,0 @@
import os
import argparse
import asyncio
from dotenv import load_dotenv

from .models import NetworkXStorage, JsonKVStorage, OpenAIModel
from .operators import judge_statement

sys_path = os.path.abspath(os.path.dirname(__file__))

load_dotenv()

def calculate_average_loss(graph: NetworkXStorage):
    """
    Calculate the average loss of the graph.

    :param graph: NetworkXStorage
    :return: float
    """
    edges = asyncio.run(graph.get_all_edges())
    total_loss = 0
    for edge in edges:
        total_loss += edge[2]['loss']
    return total_loss / len(edges)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, default=os.path.join(sys_path, "cache"), help='path to load input graph')
    parser.add_argument('--output', type=str, default='cache/output/new_graph.graphml', help='path to save output')

    args = parser.parse_args()

    llm_client = OpenAIModel(
        model_name=os.getenv("TRAINEE_MODEL"),
        api_key=os.getenv("TRAINEE_API_KEY"),
        base_url=os.getenv("TRAINEE_BASE_URL")
    )

    graph_storage = NetworkXStorage(
        args.input,
        namespace="graph"
    )
    average_loss = calculate_average_loss(graph_storage)
    print(f"Average loss of the graph: {average_loss}")

    rephrase_storage = JsonKVStorage(
        os.path.join(sys_path, "cache"),
        namespace="rephrase"
    )

    new_graph = asyncio.run(judge_statement(llm_client, graph_storage, rephrase_storage, re_judge=True))

    graph_file = asyncio.run(graph_storage.get_graph())

    new_graph.write_nx_graph(graph_file, args.output)

    average_loss = calculate_average_loss(new_graph)
    print(f"Average loss of the graph: {average_loss}")
hf-repo/graphgen/models/__init__.py
DELETED
@@ -1,45 +0,0 @@
from .community.community_detector import CommunityDetector
from .evaluate.length_evaluator import LengthEvaluator
from .evaluate.mtld_evaluator import MTLDEvaluator
from .evaluate.reward_evaluator import RewardEvaluator
from .evaluate.uni_evaluator import UniEvaluator
from .llm.openai_model import OpenAIModel
from .llm.tokenizer import Tokenizer
from .llm.topk_token_model import Token, TopkTokenModel
from .search.db.uniprot_search import UniProtSearch
from .search.kg.wiki_search import WikiSearch
from .search.web.bing_search import BingSearch
from .search.web.google_search import GoogleSearch
from .storage.json_storage import JsonKVStorage, JsonListStorage
from .storage.networkx_storage import NetworkXStorage
from .strategy.travserse_strategy import TraverseStrategy
from .text.chunk import Chunk
from .text.text_pair import TextPair

__all__ = [
    # llm models
    "OpenAIModel",
    "TopkTokenModel",
    "Token",
    "Tokenizer",
    # storage models
    "Chunk",
    "NetworkXStorage",
    "JsonKVStorage",
    "JsonListStorage",
    # search models
    "WikiSearch",
    "GoogleSearch",
    "BingSearch",
    "UniProtSearch",
    # evaluate models
    "TextPair",
    "LengthEvaluator",
    "MTLDEvaluator",
    "RewardEvaluator",
    "UniEvaluator",
    # strategy models
    "TraverseStrategy",
    # community models
    "CommunityDetector",
]
hf-repo/graphgen/models/community/__init__.py
DELETED
File without changes
hf-repo/graphgen/models/community/community_detector.py
DELETED
@@ -1,95 +0,0 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List

from graphgen.models.storage.networkx_storage import NetworkXStorage


@dataclass
class CommunityDetector:
    """Class for community detection algorithms."""

    graph_storage: NetworkXStorage = None
    method: str = "leiden"
    method_params: Dict[str, Any] = None

    async def detect_communities(self) -> Dict[str, int]:
        if self.method == "leiden":
            return await self._leiden_communities(**self.method_params or {})
        raise ValueError(f"Unknown community detection method: {self.method}")

    async def get_graph(self):
        return await self.graph_storage.get_graph()

    async def _leiden_communities(
        self, max_size: int = None, **kwargs
    ) -> Dict[str, int]:
        """
        Detect communities using the Leiden algorithm.
        If max_size is given, any community larger than max_size will be split
        into smaller sub-communities each having at most max_size nodes.
        """
        import igraph as ig
        import networkx as nx
        from leidenalg import ModularityVertexPartition, find_partition

        graph = await self.get_graph()
        graph.remove_nodes_from(list(nx.isolates(graph)))

        ig_graph = ig.Graph.TupleList(graph.edges(), directed=False)

        random_seed = kwargs.get("random_seed", 42)
        use_lcc = kwargs.get("use_lcc", False)

        communities: Dict[str, int] = {}
        if use_lcc:
            lcc = ig_graph.components().giant()
            partition = find_partition(lcc, ModularityVertexPartition, seed=random_seed)
            for part, cluster in enumerate(partition):
                for v in cluster:
                    communities[lcc.vs[v]["name"]] = part
        else:
            offset = 0
            for component in ig_graph.components():
                subgraph = ig_graph.induced_subgraph(component)
                partition = find_partition(
                    subgraph, ModularityVertexPartition, seed=random_seed
                )
                for part, cluster in enumerate(partition):
                    for v in cluster:
                        original_node = subgraph.vs[v]["name"]
                        communities[original_node] = part + offset
                offset += len(partition)

        # split large communities if max_size is specified
        if max_size is None or max_size <= 0:
            return communities

        return await self._split_communities(communities, max_size)

    @staticmethod
    async def _split_communities(
        communities: Dict[str, int], max_size: int
    ) -> Dict[str, int]:
        """
        Split communities larger than max_size into smaller sub-communities.
        """
        cid2nodes: Dict[int, List[str]] = defaultdict(list)
        for node, cid in communities.items():
            cid2nodes[cid].append(node)

        new_communities: Dict[str, int] = {}
        new_cid = 0
        for cid, nodes in cid2nodes.items():
            if len(nodes) <= max_size:
                for n in nodes:
                    new_communities[n] = new_cid
                new_cid += 1
            else:
                for start in range(0, len(nodes), max_size):
                    sub = nodes[start : start + max_size]
                    for n in sub:
                        new_communities[n] = new_cid
                    new_cid += 1

        return new_communities
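A minimal sketch of running the detector, assuming the optional python-igraph and leidenalg dependencies are installed and a graph cache exists under ./cache (both assumptions, not guaranteed by this commit):

# Illustrative only: map each graph node to a community id, capping
# community size at 20 nodes.
import asyncio
from graphgen.models import CommunityDetector, NetworkXStorage

detector = CommunityDetector(
    graph_storage=NetworkXStorage("./cache", namespace="graph"),
    method="leiden",
    method_params={"max_size": 20, "random_seed": 42},
)
node_to_community = asyncio.run(detector.detect_communities())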
hf-repo/graphgen/models/embed/__init__.py
DELETED
File without changes
hf-repo/graphgen/models/embed/embedding.py
DELETED
@@ -1,29 +0,0 @@
from dataclasses import dataclass
import asyncio
import numpy as np

class UnlimitedSemaphore:
    """A context manager that allows unlimited access."""

    async def __aenter__(self):
        pass

    async def __aexit__(self, exc_type, exc, tb):
        pass

@dataclass
class EmbeddingFunc:
    embedding_dim: int
    max_token_size: int
    func: callable
    concurrent_limit: int = 16

    def __post_init__(self):
        if self.concurrent_limit != 0:
            self._semaphore = asyncio.Semaphore(self.concurrent_limit)
        else:
            self._semaphore = UnlimitedSemaphore()

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        async with self._semaphore:
            return await self.func(*args, **kwargs)
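A minimal sketch of wrapping an async embedding call with EmbeddingFunc so that at most 16 requests run concurrently; the embedding function below is a stand-in, not a real API:

# Illustrative only: fake_embed stands in for a real embedding endpoint.
import asyncio
import numpy as np

async def fake_embed(texts: list[str]) -> np.ndarray:
    return np.zeros((len(texts), 8))  # placeholder vectors

embed = EmbeddingFunc(embedding_dim=8, max_token_size=8192, func=fake_embed)
vectors = asyncio.run(embed(["hello", "world"]))  # shape (2, 8)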
hf-repo/graphgen/models/evaluate/__init__.py
DELETED
File without changes
hf-repo/graphgen/models/evaluate/base_evaluator.py
DELETED
@@ -1,51 +0,0 @@
import asyncio

from dataclasses import dataclass
from tqdm.asyncio import tqdm as tqdm_async
from graphgen.utils import create_event_loop
from graphgen.models.text.text_pair import TextPair

@dataclass
class BaseEvaluator:
    max_concurrent: int = 100
    results: list[float] = None

    def evaluate(self, pairs: list[TextPair]) -> list[float]:
        """
        Evaluate the text and return a score.
        """
        return create_event_loop().run_until_complete(self.async_evaluate(pairs))

    async def async_evaluate(self, pairs: list[TextPair]) -> list[float]:
        semaphore = asyncio.Semaphore(self.max_concurrent)

        async def evaluate_with_semaphore(pair):
            async with semaphore:  # acquire the semaphore
                return await self.evaluate_single(pair)

        results = []
        for result in tqdm_async(
            asyncio.as_completed([evaluate_with_semaphore(pair) for pair in pairs]),
            total=len(pairs),
        ):
            results.append(await result)
        return results

    async def evaluate_single(self, pair: TextPair) -> float:
        raise NotImplementedError()

    def get_average_score(self, pairs: list[TextPair]) -> float:
        """
        Get the average score of a batch of texts.
        """
        results = self.evaluate(pairs)
        self.results = results
        return sum(self.results) / len(pairs)

    def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
        """
        Get the min and max score of a batch of texts.
        """
        if self.results is None:
            self.get_average_score(pairs)
        return min(self.results), max(self.results)
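The subclass contract above is simply to implement evaluate_single; a toy sketch (not part of the commit) scoring answers by character count:

# Illustrative only: a hypothetical evaluator built on BaseEvaluator.
from dataclasses import dataclass
from graphgen.models import TextPair

@dataclass
class CharCountEvaluator(BaseEvaluator):
    async def evaluate_single(self, pair: TextPair) -> float:
        return float(len(pair.answer))

pairs = [TextPair(question="q", answer="a longer answer")]
print(CharCountEvaluator(max_concurrent=10).get_average_score(pairs))  # 15.0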
hf-repo/graphgen/models/evaluate/length_evaluator.py
DELETED
@@ -1,22 +0,0 @@
from dataclasses import dataclass
from graphgen.models.evaluate.base_evaluator import BaseEvaluator
from graphgen.models.llm.tokenizer import Tokenizer
from graphgen.models.text.text_pair import TextPair
from graphgen.utils import create_event_loop


@dataclass
class LengthEvaluator(BaseEvaluator):
    tokenizer_name: str = "cl100k_base"

    def __post_init__(self):
        self.tokenizer = Tokenizer(
            model_name=self.tokenizer_name
        )

    async def evaluate_single(self, pair: TextPair) -> float:
        loop = create_event_loop()
        return await loop.run_in_executor(None, self._calculate_length, pair.answer)

    def _calculate_length(self, text: str) -> float:
        tokens = self.tokenizer.encode_string(text)
        return len(tokens)
hf-repo/graphgen/models/evaluate/mtld_evaluator.py
DELETED
@@ -1,76 +0,0 @@
from dataclasses import dataclass, field
from typing import Set

from graphgen.models.evaluate.base_evaluator import BaseEvaluator
from graphgen.models.text.text_pair import TextPair
from graphgen.utils import detect_main_language, NLTKHelper, create_event_loop


nltk_helper = NLTKHelper()

@dataclass
class MTLDEvaluator(BaseEvaluator):
    """
    Metric for measuring the lexical diversity of text.
    """
    stopwords_en: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("english")))
    stopwords_zh: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("chinese")))

    async def evaluate_single(self, pair: TextPair) -> float:
        loop = create_event_loop()
        return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer)

    def _calculate_mtld_score(self, text: str, threshold=0.72) -> float:
        """
        Compute MTLD (average of the forward and backward passes).

        min is 1.0
        higher is better
        """
        if not text or not text.strip():
            return 0.0

        lang = detect_main_language(text)
        tokens = nltk_helper.word_tokenize(text, lang)

        stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en
        filtered_tokens = [word for word in tokens if word not in stopwords]
        filtered_tokens = [word for word in filtered_tokens if word.isalnum()]

        if not filtered_tokens:
            return 0

        # forward MTLD
        forward_factors = self._compute_factors(filtered_tokens, threshold)

        # backward MTLD
        backward_factors = self._compute_factors(filtered_tokens[::-1], threshold)

        # average of the two passes
        return (forward_factors + backward_factors) / 2

    @staticmethod
    def _compute_factors(tokens: list, threshold: float) -> float:
        factors = 0
        current_segment = []
        unique_words = set()

        for token in tokens:
            current_segment.append(token)
            unique_words.add(token)
            ttr = len(unique_words) / len(current_segment)

            if ttr <= threshold:
                factors += 1
                current_segment = []
                unique_words = set()

        # handle the final incomplete segment
        if current_segment:
            ttr = len(unique_words) / len(current_segment)
            if ttr <= threshold:
                factors += 1
            else:
                factors += (1 - (ttr - threshold) / (1 - threshold))

        return len(tokens) / factors if factors > 0 else len(tokens)
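As a quick sanity check of the factor logic above: with the default 0.72 threshold, the type-token ratio of ["a", "b", "a"] drops to 2/3 and closes a segment, so six alternating tokens close two segments and yield an MTLD of 6 / 2 = 3.

# Illustrative only: exercising the static helper directly.
tokens = ["a", "b", "a", "b", "a", "b"]
print(MTLDEvaluator._compute_factors(tokens, threshold=0.72))  # 3.0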
hf-repo/graphgen/models/evaluate/reward_evaluator.py
DELETED
@@ -1,101 +0,0 @@
from dataclasses import dataclass
from tqdm import tqdm
from graphgen.models.text.text_pair import TextPair


@dataclass
class RewardEvaluator:
    """
    Reward Model Evaluator.
    OpenAssistant/reward-model-deberta-v3-large-v2: scores range over [-inf, inf]; higher is better.
    """
    reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2"
    max_length: int = 2560
    results: list[float] = None

    def __post_init__(self):
        import torch
        self.num_gpus = torch.cuda.device_count()

    @staticmethod
    def process_chunk(rank, pairs, reward_name, max_length, return_dict):
        import torch
        from transformers import AutoModelForSequenceClassification, AutoTokenizer
        device = f'cuda:{rank}'
        torch.cuda.set_device(rank)

        rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
        tokenizer = AutoTokenizer.from_pretrained(reward_name)
        rank_model.to(device)
        rank_model.eval()

        results = []
        with torch.no_grad():
            for pair in tqdm(pairs):
                inputs = tokenizer(
                    pair.question,
                    pair.answer,
                    return_tensors="pt",
                    max_length=max_length,
                    truncation=True
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                score = rank_model(**inputs).logits[0].item()
                results.append(score)

        return_dict[rank] = results

    def evaluate(self, pairs: list[TextPair]) -> list[float]:
        import torch.multiprocessing as mp
        chunk_size = len(pairs) // self.num_gpus
        chunks = []
        for i in range(self.num_gpus):
            start = i * chunk_size
            end = start + chunk_size
            if i == self.num_gpus - 1:
                end = len(pairs)
            chunks.append(pairs[start:end])

        # multi-process
        manager = mp.Manager()
        return_dict = manager.dict()
        processes = []

        for rank, chunk in enumerate(chunks):
            p = mp.Process(
                target=self.process_chunk,
                args=(rank, chunk, self.reward_name, self.max_length, return_dict)
            )
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # merge results
        results = []
        for rank in range(len(chunks)):
            results.extend(return_dict[rank])

        for p in processes:
            if p.is_alive():
                p.terminate()
                p.join()

        return results

    def get_average_score(self, pairs: list[TextPair]) -> float:
        """
        Get the average score of a batch of texts.
        """
        results = self.evaluate(pairs)
        self.results = results
        return sum(self.results) / len(pairs)

    def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
        """
        Get the min and max score of a batch of texts.
        """
        if self.results is None:
            self.get_average_score(pairs)
        return min(self.results), max(self.results)
hf-repo/graphgen/models/evaluate/uni_evaluator.py
DELETED
@@ -1,159 +0,0 @@
# https://github.com/maszhongming/UniEval/tree/main

from dataclasses import dataclass, field
from tqdm import tqdm
from graphgen.models.text.text_pair import TextPair


def _add_questions(dimension: str, question: str, answer: str):
    if dimension == "naturalness":
        cur_input = 'question: Is this a natural response in the dialogue? </s> response: ' + answer
    elif dimension == "coherence":
        cur_input = 'question: Is this a coherent response given the dialogue history? </s> response: ' \
                    + answer + ' </s> dialogue history: ' + question
    elif dimension == "understandability":
        cur_input = 'question: Is this an understandable response in the dialogue? </s> response: ' + answer
    else:
        raise NotImplementedError(
            'The input format for this dimension is still undefined. Please customize it first.')
    return cur_input

@dataclass
class UniEvaluator:
    model_name: str = "MingZhong/unieval-sum"
    dimensions: list = field(default_factory=lambda: ['naturalness', 'coherence', 'understandability'])
    max_length: int = 2560
    results: dict = None

    def __post_init__(self):
        import torch
        self.num_gpus = torch.cuda.device_count()
        self.results = {}

    @staticmethod
    def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict):
        import torch
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        device = f'cuda:{rank}'
        torch.cuda.set_device(rank)

        rank_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        rank_model.to(device)
        rank_model.eval()

        softmax = torch.nn.Softmax(dim=1)

        pos_id = tokenizer("Yes")["input_ids"][0]
        neg_id = tokenizer("No")["input_ids"][0]

        results = []
        with torch.no_grad():
            for pair in tqdm(pairs):
                text = _add_questions(dimension, pair.question, pair.answer)

                tgt = "No"

                encoded_src = tokenizer(
                    text,
                    max_length=max_length,
                    truncation=True,
                    padding=True,
                    return_tensors='pt'
                )
                encoded_tgt = tokenizer(
                    tgt,
                    max_length=max_length,
                    truncation=True,
                    padding=True,
                    return_tensors='pt'
                )

                src_tokens = encoded_src['input_ids'].to(device)
                src_mask = encoded_src['attention_mask'].to(device)

                tgt_tokens = encoded_tgt['input_ids'].to(device)[:, 0].unsqueeze(-1)

                output = rank_model(
                    input_ids=src_tokens,
                    attention_mask=src_mask,
                    labels=tgt_tokens,
                    use_cache=False
                )

                logits = output.logits.view(-1, rank_model.config.vocab_size)

                pos_score = softmax(logits)[:, pos_id]  # Yes
                neg_score = softmax(logits)[:, neg_id]
                score = pos_score / (pos_score + neg_score)

                results.append(score.item())

        return_dict[rank] = results

    def evaluate(self, pairs: list[TextPair]) -> list[dict]:
        import torch.multiprocessing as mp
        final_results = []
        for dimension in self.dimensions:
            chunk_size = len(pairs) // self.num_gpus
            chunks = []
            for i in range(self.num_gpus):
                start = i * chunk_size
                end = start + chunk_size
                if i == self.num_gpus - 1:
                    end = len(pairs)
                chunks.append(pairs[start:end])

            # multi-process
            manager = mp.Manager()
            return_dict = manager.dict()
            processes = []

            for rank, chunk in enumerate(chunks):
                p = mp.Process(
                    target=self.process_chunk,
                    args=(rank, chunk, self.model_name, self.max_length, dimension, return_dict)
                )
                p.start()
                processes.append(p)

            for p in processes:
                p.join()

            # merge results
            results = []
            for rank in range(len(chunks)):
                results.extend(return_dict[rank])

            for p in processes:
                if p.is_alive():
                    p.terminate()
                    p.join()

            final_results.append({
                dimension: results
            })
        return final_results

    def get_average_score(self, pairs: list[TextPair]) -> dict:
        """
        Get the average score of a batch of texts.
        """
        results = self.evaluate(pairs)
        final_results = {}
        for result in results:
            for key, value in result.items():
                final_results[key] = sum(value) / len(value)
                self.results[key] = value
        return final_results

    def get_min_max_score(self, pairs: list[TextPair]) -> dict:
        """
        Get the min and max score of a batch of texts.
        """
        if self.results is None:
            self.get_average_score(pairs)
        final_results = {}
        for key, value in self.results.items():
            final_results[key] = min(value), max(value)
        return final_results
hf-repo/graphgen/models/llm/__init__.py
DELETED
File without changes
hf-repo/graphgen/models/llm/limitter.py
DELETED
@@ -1,88 +0,0 @@
import time
from datetime import datetime, timedelta
import asyncio

from graphgen.utils import logger


class RPM:

    def __init__(self, rpm: int = 1000):
        self.rpm = rpm
        self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0}

    def get_minute_slot(self):
        current_time = time.time()
        dt_object = datetime.fromtimestamp(current_time)
        total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute
        return total_minutes_since_midnight

    async def wait(self, silent=False):
        current = time.time()
        dt_object = datetime.fromtimestamp(current)
        minute_slot = self.get_minute_slot()

        if self.record['rpm_slot'] == minute_slot:
            # check RPM exceed
            if self.record['counter'] >= self.rpm:
                # wait until next minute
                next_minute = dt_object.replace(
                    second=0, microsecond=0) + timedelta(minutes=1)
                _next = next_minute.timestamp()
                sleep_time = abs(_next - current)
                if not silent:
                    logger.info('RPM sleep %s', sleep_time)
                await asyncio.sleep(sleep_time)

                self.record = {
                    'rpm_slot': self.get_minute_slot(),
                    'counter': 0
                }
        else:
            self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0}
        self.record['counter'] += 1

        if not silent:
            logger.debug(self.record)


class TPM:

    def __init__(self, tpm: int = 20000):
        self.tpm = tpm
        self.record = {'tpm_slot': self.get_minute_slot(), 'counter': 0}

    def get_minute_slot(self):
        current_time = time.time()
        dt_object = datetime.fromtimestamp(current_time)
        total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute
        return total_minutes_since_midnight

    async def wait(self, token_count, silent=False):
        current = time.time()
        dt_object = datetime.fromtimestamp(current)
        minute_slot = self.get_minute_slot()

        # new minute slot, skip waiting
        if self.record['tpm_slot'] != minute_slot:
            self.record = {'tpm_slot': minute_slot, 'counter': token_count}
            return

        # check TPM exceed
        self.record['counter'] += token_count
        if self.record['counter'] > self.tpm:
            # wait until next minute
            next_minute = dt_object.replace(
                second=0, microsecond=0) + timedelta(minutes=1)
            _next = next_minute.timestamp()
            sleep_time = abs(_next - current)
            logger.info('TPM sleep %s', sleep_time)
            await asyncio.sleep(sleep_time)

            self.record = {
                'tpm_slot': self.get_minute_slot(),
                'counter': token_count
            }

        if not silent:
            logger.debug(self.record)
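A minimal sketch of gating a burst of requests with the two limiters above; with rpm=2 the third call sleeps until the next minute slot, so this may pause for up to a minute when run:

# Illustrative only: rate-limit three dispatches with RPM and TPM.
import asyncio

async def run():
    rpm = RPM(rpm=2)
    tpm = TPM(tpm=1000)
    for i in range(3):
        await rpm.wait(silent=True)                  # blocks on the 3rd call
        await tpm.wait(token_count=100, silent=True)  # well under the budget
        print(f"request {i} dispatched")

asyncio.run(run())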
hf-repo/graphgen/models/llm/openai_model.py
DELETED
@@ -1,155 +0,0 @@
import math
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import openai
from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from graphgen.models.llm.limitter import RPM, TPM
from graphgen.models.llm.tokenizer import Tokenizer
from graphgen.models.llm.topk_token_model import Token, TopkTokenModel


def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
    token_logprobs = response.choices[0].logprobs.content
    tokens = []
    for token_prob in token_logprobs:
        prob = math.exp(token_prob.logprob)
        candidate_tokens = [
            Token(t.token, math.exp(t.logprob)) for t in token_prob.top_logprobs
        ]
        token = Token(token_prob.token, prob, top_candidates=candidate_tokens)
        tokens.append(token)
    return tokens


def filter_think_tags(text: str) -> str:
    """
    Remove <think> tags from the text.
    If the text contains <think> and </think>, it removes everything between them and the tags themselves.
    """
    think_pattern = re.compile(r"<think>.*?</think>", re.DOTALL)
    filtered_text = think_pattern.sub("", text).strip()
    return filtered_text if filtered_text else text.strip()


@dataclass
class OpenAIModel(TopkTokenModel):
    model_name: str = "gpt-4o-mini"
    api_key: str = None
    base_url: str = None

    system_prompt: str = ""
    json_mode: bool = False
    seed: int = None

    token_usage: list = field(default_factory=list)
    request_limit: bool = False
    rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
    tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))

    tokenizer_instance: Tokenizer = field(default_factory=Tokenizer)

    def __post_init__(self):
        assert self.api_key is not None, "Please provide api key to access openai api."
        self.client = AsyncOpenAI(
            api_key=self.api_key or "dummy", base_url=self.base_url
        )

    def _pre_generate(self, text: str, history: List[str]) -> Dict:
        kwargs = {
            "temperature": self.temperature,
            "top_p": self.topp,
            "max_tokens": self.max_tokens,
        }
        if self.seed:
            kwargs["seed"] = self.seed
        if self.json_mode:
            kwargs["response_format"] = {"type": "json_object"}

        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": text})

        if history:
            assert len(history) % 2 == 0, "History should have even number of elements."
            messages = history + messages

        kwargs["messages"] = messages
        return kwargs

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (RateLimitError, APIConnectionError, APITimeoutError)
        ),
    )
    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None
    ) -> List[Token]:
        kwargs = self._pre_generate(text, history)
        if self.topk_per_token > 0:
            kwargs["logprobs"] = True
            kwargs["top_logprobs"] = self.topk_per_token

        # Limit max_tokens to 1 to avoid long completions
        kwargs["max_tokens"] = 1

        completion = await self.client.chat.completions.create(  # pylint: disable=E1125
            model=self.model_name, **kwargs
        )

        tokens = get_top_response_tokens(completion)

        return tokens

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (RateLimitError, APIConnectionError, APITimeoutError)
        ),
    )
    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, temperature: int = 0
    ) -> str:
        kwargs = self._pre_generate(text, history)
        kwargs["temperature"] = temperature

        prompt_tokens = 0
        for message in kwargs["messages"]:
            prompt_tokens += len(
                self.tokenizer_instance.encode_string(message["content"])
            )
        estimated_tokens = prompt_tokens + kwargs["max_tokens"]

        if self.request_limit:
            await self.rpm.wait(silent=True)
            await self.tpm.wait(estimated_tokens, silent=True)

        completion = await self.client.chat.completions.create(  # pylint: disable=E1125
            model=self.model_name, **kwargs
        )
        if hasattr(completion, "usage"):
            self.token_usage.append(
                {
                    "prompt_tokens": completion.usage.prompt_tokens,
                    "completion_tokens": completion.usage.completion_tokens,
                    "total_tokens": completion.usage.total_tokens,
                }
            )
        return filter_think_tags(completion.choices[0].message.content)

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None
    ) -> List[Token]:
        raise NotImplementedError
|
hf-repo/graphgen/models/llm/tokenizer.py
DELETED
@@ -1,73 +0,0 @@
-from dataclasses import dataclass
-from typing import List
-import tiktoken
-
-try:
-    from transformers import AutoTokenizer
-    TRANSFORMERS_AVAILABLE = True
-except ImportError:
-    AutoTokenizer = None
-    TRANSFORMERS_AVAILABLE = False
-
-
-def get_tokenizer(tokenizer_name: str = "cl100k_base"):
-    """
-    Get a tokenizer instance by name.
-
-    :param tokenizer_name: tokenizer name, tiktoken encoding name or Hugging Face model name
-    :return: tokenizer instance
-    """
-    if tokenizer_name in tiktoken.list_encoding_names():
-        return tiktoken.get_encoding(tokenizer_name)
-    if TRANSFORMERS_AVAILABLE:
-        try:
-            return AutoTokenizer.from_pretrained(tokenizer_name)
-        except Exception as e:
-            raise ValueError(f"Failed to load tokenizer from Hugging Face: {e}") from e
-    else:
-        raise ValueError("Hugging Face Transformers is not available, please install it first.")
-
-@dataclass
-class Tokenizer:
-    model_name: str = "cl100k_base"
-
-    def __post_init__(self):
-        self.tokenizer = get_tokenizer(self.model_name)
-
-    def encode_string(self, text: str) -> List[int]:
-        """
-        Encode text to tokens
-
-        :param text
-        :return: tokens
-        """
-        return self.tokenizer.encode(text)
-
-    def decode_tokens(self, tokens: List[int]) -> str:
-        """
-        Decode tokens to text
-
-        :param tokens
-        :return: text
-        """
-        return self.tokenizer.decode(tokens)
-
-    def chunk_by_token_size(
-        self, content: str, overlap_token_size=128, max_token_size=1024
-    ):
-        tokens = self.encode_string(content)
-        results = []
-        for index, start in enumerate(
-            range(0, len(tokens), max_token_size - overlap_token_size)
-        ):
-            chunk_content = self.decode_tokens(
-                tokens[start : start + max_token_size]
-            )
-            results.append(
-                {
-                    "tokens": min(max_token_size, len(tokens) - start),
-                    "content": chunk_content.strip(),
-                    "chunk_order_index": index,
-                }
-            )
-        return results
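
The sliding-window chunking above steps by max_token_size - overlap_token_size tokens, so consecutive chunks share overlap_token_size tokens. A short sketch, not part of the commit, assuming a pre-deletion checkout and the default tiktoken encoding:

# Usage sketch (not part of the commit): the deleted Tokenizer's chunking.
from graphgen.models.llm.tokenizer import Tokenizer

tok = Tokenizer(model_name="cl100k_base")  # tiktoken encoding by default
chunks = tok.chunk_by_token_size(
    "word " * 3000, overlap_token_size=128, max_token_size=1024
)
# Window starts at 0, 896, 1792, ... (step = 1024 - 128), so each chunk
# overlaps the previous one by 128 tokens.
for c in chunks:
    print(c["chunk_order_index"], c["tokens"], c["content"][:30])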
hf-repo/graphgen/models/llm/topk_token_model.py
DELETED
@@ -1,48 +0,0 @@
-import math
-from dataclasses import dataclass, field
-from typing import List, Union, Optional
-
-
-@dataclass
-class Token:
-    text: str
-    prob: float
-    top_candidates: List = field(default_factory=list)
-    ppl: Union[float, None] = field(default=None)
-
-    @property
-    def logprob(self) -> float:
-        return math.log(self.prob)
-
-
-@dataclass
-class TopkTokenModel:
-    do_sample: bool = False
-    temperature: float = 0
-    max_tokens: int = 4096
-    repetition_penalty: float = 1.05
-    num_beams: int = 1
-    topk: int = 50
-    topp: float = 0.95
-
-    topk_per_token: int = 5  # number of topk tokens to generate for each token
-
-    async def generate_topk_per_token(self, text: str) -> List[Token]:
-        """
-        Generate prob, text and candidates for each token of the model's output.
-        This function is used to visualize the inference process.
-        """
-        raise NotImplementedError
-
-    async def generate_inputs_prob(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
-        """
-        Generate prob and text for each token of the input text.
-        This function is used to visualize the ppl.
-        """
-        raise NotImplementedError
-
-    async def generate_answer(self, text: str, history: Optional[List[str]] = None) -> str:
-        """
-        Generate answer from the model.
-        """
-        raise NotImplementedError
hf-repo/graphgen/models/search/__init__.py
DELETED
File without changes
hf-repo/graphgen/models/search/db/__init__.py
DELETED
File without changes
hf-repo/graphgen/models/search/db/uniprot_search.py
DELETED
@@ -1,64 +0,0 @@
-from dataclasses import dataclass
-
-import requests
-from fastapi import HTTPException
-
-from graphgen.utils import logger
-
-UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"
-
-
-@dataclass
-class UniProtSearch:
-    """
-    UniProt Search client to search with UniProt.
-    1) Get the protein by accession number.
-    2) Search with keywords or protein names.
-    """
-
-    def get_entry(self, accession: str) -> dict:
-        """
-        Get the UniProt entry by accession number (e.g., P04637).
-        """
-        url = f"{UNIPROT_BASE}/{accession}.json"
-        return self._safe_get(url).json()
-
-    def search(
-        self,
-        query: str,
-        *,
-        size: int = 10,
-        cursor: str = None,
-        fields: list[str] = None,
-    ) -> dict:
-        """
-        Search UniProt with a query string.
-        :param query: The search query.
-        :param size: The number of results to return.
-        :param cursor: The cursor for pagination.
-        :param fields: The fields to return in the response.
-        :return: A dictionary containing the search results.
-        """
-        params = {
-            "query": query,
-            "size": size,
-        }
-        if cursor:
-            params["cursor"] = cursor
-        if fields:
-            params["fields"] = ",".join(fields)
-        url = UNIPROT_BASE
-        return self._safe_get(url, params=params).json()
-
-    @staticmethod
-    def _safe_get(url: str, params: dict = None) -> requests.Response:
-        r = requests.get(
-            url,
-            params=params,
-            headers={"Accept": "application/json"},
-            timeout=10,
-        )
-        if not r.ok:
-            logger.error("Search engine error: %s", r.text)
-            raise HTTPException(r.status_code, "Search engine error.")
-        return r
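
A usage sketch for the deleted UniProtSearch client. Not part of the commit; it talks to the public UniProt REST search endpoint, and the query string and field names below are illustrative only:

# Usage sketch (not part of the commit): the deleted UniProtSearch client.
from graphgen.models.search.db.uniprot_search import UniProtSearch

client = UniProtSearch()
# "insulin" is an arbitrary query; the field names are illustrative.
hits = client.search("insulin", size=5, fields=["accession", "protein_name"])
for item in hits.get("results", []):
    print(item.get("primaryAccession"))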
hf-repo/graphgen/models/search/kg/__init__.py
DELETED
File without changes
hf-repo/graphgen/models/search/kg/wiki_search.py
DELETED
@@ -1,37 +0,0 @@
-from dataclasses import dataclass
-from typing import List, Union
-
-import wikipedia
-from wikipedia import set_lang
-
-from graphgen.utils import detect_main_language, logger
-
-
-@dataclass
-class WikiSearch:
-    @staticmethod
-    def set_language(language: str):
-        assert language in ["en", "zh"], "Only support English and Chinese"
-        set_lang(language)
-
-    async def search(self, query: str, num_results: int = 1) -> Union[List[str], None]:
-        self.set_language(detect_main_language(query))
-        return wikipedia.search(query, results=num_results, suggestion=False)
-
-    async def summary(self, query: str) -> Union[str, None]:
-        self.set_language(detect_main_language(query))
-        try:
-            result = wikipedia.summary(query, auto_suggest=False, redirect=False)
-        except wikipedia.exceptions.DisambiguationError as e:
-            logger.error("DisambiguationError: %s", e)
-            result = None
-        return result
-
-    async def page(self, query: str) -> Union[str, None]:
-        self.set_language(detect_main_language(query))
-        try:
-            result = wikipedia.page(query, auto_suggest=False, redirect=False).content
-        except wikipedia.exceptions.DisambiguationError as e:
-            logger.error("DisambiguationError: %s", e)
-            result = None
-        return result
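
The WikiSearch helper is fully async, so callers drive it from an event loop. A minimal sketch, not part of the commit, again assuming a pre-deletion checkout:

# Usage sketch (not part of the commit): the deleted WikiSearch helper.
import asyncio

from graphgen.models.search.kg.wiki_search import WikiSearch

async def main():
    ws = WikiSearch()
    titles = await ws.search("knowledge graph", num_results=3)
    summary = await ws.summary("Knowledge graph")  # None on disambiguation
    print(titles)
    print((summary or "")[:100])

asyncio.run(main())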
hf-repo/graphgen/models/search/web/__init__.py
DELETED
File without changes
hf-repo/graphgen/models/search/web/bing_search.py
DELETED
@@ -1,43 +0,0 @@
-from dataclasses import dataclass
-
-import requests
-from fastapi import HTTPException
-
-from graphgen.utils import logger
-
-BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
-BING_MKT = "en-US"
-
-
-@dataclass
-class BingSearch:
-    """
-    Bing Search client to search with Bing.
-    """
-
-    subscription_key: str
-
-    def search(self, query: str, num_results: int = 1):
-        """
-        Search with Bing and return the contexts.
-        :param query: The search query.
-        :param num_results: The number of results to return.
-        :return: A list of search results.
-        """
-        params = {"q": query, "mkt": BING_MKT, "count": num_results}
-        response = requests.get(
-            BING_SEARCH_V7_ENDPOINT,
-            headers={"Ocp-Apim-Subscription-Key": self.subscription_key},
-            params=params,
-            timeout=10,
-        )
-        if not response.ok:
-            logger.error("Search engine error: %s", response.text)
-            raise HTTPException(response.status_code, "Search engine error.")
-        json_content = response.json()
-        try:
-            contexts = json_content["webPages"]["value"][:num_results]
-        except KeyError:
-            logger.error("Error encountered: %s", json_content)
-            return []
-        return contexts
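
A sketch of the deleted BingSearch client (the near-identical GoogleSearch client below follows the same pattern). Not part of the commit; the subscription key is a placeholder for a valid Azure Bing Search v7 key:

# Usage sketch (not part of the commit): the deleted BingSearch client.
from graphgen.models.search.web.bing_search import BingSearch

searcher = BingSearch(subscription_key="YOUR-AZURE-KEY")  # placeholder
results = searcher.search("GraphGen knowledge graph", num_results=3)
for page in results:
    # Bing v7 webPages entries carry "name" and "url" keys, among others
    print(page.get("name"), page.get("url"))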
hf-repo/graphgen/models/search/web/google_search.py
DELETED
@@ -1,45 +0,0 @@
-from dataclasses import dataclass
-
-import requests
-from fastapi import HTTPException
-
-from graphgen.utils import logger
-
-GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"
-
-
-@dataclass
-class GoogleSearch:
-    def __init__(self, subscription_key: str, cx: str):
-        """
-        Initialize the Google Search client with the subscription key and custom search engine ID.
-        :param subscription_key: Your Google API subscription key.
-        :param cx: Your custom search engine ID.
-        """
-        self.subscription_key = subscription_key
-        self.cx = cx
-
-    def search(self, query: str, num_results: int = 1):
-        """
-        Search with Google and return the contexts.
-        :param query: The search query.
-        :param num_results: The number of results to return.
-        :return: A list of search results.
-        """
-        params = {
-            "key": self.subscription_key,
-            "cx": self.cx,
-            "q": query,
-            "num": num_results,
-        }
-        response = requests.get(GOOGLE_SEARCH_ENDPOINT, params=params, timeout=10)
-        if not response.ok:
-            logger.error("Search engine error: %s", response.text)
-            raise HTTPException(response.status_code, "Search engine error.")
-        json_content = response.json()
-        try:
-            contexts = json_content["items"][:num_results]
-        except KeyError:
-            logger.error("Error encountered: %s", json_content)
-            return []
-        return contexts
hf-repo/graphgen/models/storage/__init__.py
DELETED
File without changes
hf-repo/graphgen/models/storage/base_storage.py
DELETED
@@ -1,115 +0,0 @@
-from dataclasses import dataclass
-from typing import Generic, TypeVar, Union
-
-from graphgen.models.embed.embedding import EmbeddingFunc
-
-T = TypeVar("T")
-
-
-@dataclass
-class StorageNameSpace:
-    working_dir: str = None
-    namespace: str = None
-
-    async def index_done_callback(self):
-        """commit the storage operations after indexing"""
-
-    async def query_done_callback(self):
-        """commit the storage operations after querying"""
-
-
-@dataclass
-class BaseListStorage(Generic[T], StorageNameSpace):
-    async def all_items(self) -> list[T]:
-        raise NotImplementedError
-
-    async def get_by_index(self, index: int) -> Union[T, None]:
-        raise NotImplementedError
-
-    async def append(self, data: T):
-        raise NotImplementedError
-
-    async def upsert(self, data: list[T]):
-        raise NotImplementedError
-
-    async def drop(self):
-        raise NotImplementedError
-
-
-@dataclass
-class BaseKVStorage(Generic[T], StorageNameSpace):
-    async def all_keys(self) -> list[str]:
-        raise NotImplementedError
-
-    async def get_by_id(self, id: str) -> Union[T, None]:
-        raise NotImplementedError
-
-    async def get_by_ids(
-        self, ids: list[str], fields: Union[set[str], None] = None
-    ) -> list[Union[T, None]]:
-        raise NotImplementedError
-
-    async def filter_keys(self, data: list[str]) -> set[str]:
-        """return keys that do not yet exist in the storage"""
-        raise NotImplementedError
-
-    async def upsert(self, data: dict[str, T]):
-        raise NotImplementedError
-
-    async def drop(self):
-        raise NotImplementedError
-
-
-@dataclass
-class BaseGraphStorage(StorageNameSpace):
-    embedding_func: EmbeddingFunc = None
-
-    async def has_node(self, node_id: str) -> bool:
-        raise NotImplementedError
-
-    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        raise NotImplementedError
-
-    async def node_degree(self, node_id: str) -> int:
-        raise NotImplementedError
-
-    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        raise NotImplementedError
-
-    async def get_node(self, node_id: str) -> Union[dict, None]:
-        raise NotImplementedError
-
-    async def update_node(self, node_id: str, node_data: dict[str, str]):
-        raise NotImplementedError
-
-    async def get_all_nodes(self) -> Union[list[dict], None]:
-        raise NotImplementedError
-
-    async def get_edge(
-        self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        raise NotImplementedError
-
-    async def update_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-    ):
-        raise NotImplementedError
-
-    async def get_all_edges(self) -> Union[list[dict], None]:
-        raise NotImplementedError
-
-    async def get_node_edges(
-        self, source_node_id: str
-    ) -> Union[list[tuple[str, str]], None]:
-        raise NotImplementedError
-
-    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
-        raise NotImplementedError
-
-    async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-    ):
-        raise NotImplementedError
-
-    async def delete_node(self, node_id: str):
-        raise NotImplementedError
hf-repo/graphgen/models/storage/json_storage.py
DELETED
@@ -1,87 +0,0 @@
-import os
-from dataclasses import dataclass
-
-from graphgen.models.storage.base_storage import BaseKVStorage, BaseListStorage
-from graphgen.utils import load_json, logger, write_json
-
-
-@dataclass
-class JsonKVStorage(BaseKVStorage):
-    _data: dict[str, str] = None
-
-    def __post_init__(self):
-        self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
-        self._data = load_json(self._file_name) or {}
-        logger.info("Load KV %s with %d data", self.namespace, len(self._data))
-
-    @property
-    def data(self):
-        return self._data
-
-    async def all_keys(self) -> list[str]:
-        return list(self._data.keys())
-
-    async def index_done_callback(self):
-        write_json(self._data, self._file_name)
-
-    async def get_by_id(self, id):
-        return self._data.get(id, None)
-
-    async def get_by_ids(self, ids, fields=None) -> list:
-        if fields is None:
-            return [self._data.get(id, None) for id in ids]
-        return [
-            (
-                {k: v for k, v in self._data[id].items() if k in fields}
-                if self._data.get(id, None)
-                else None
-            )
-            for id in ids
-        ]
-
-    async def filter_keys(self, data: list[str]) -> set[str]:
-        return {s for s in data if s not in self._data}
-
-    async def upsert(self, data: dict):
-        left_data = {k: v for k, v in data.items() if k not in self._data}
-        self._data.update(left_data)
-        return left_data
-
-    async def drop(self):
-        self._data = {}
-
-
-@dataclass
-class JsonListStorage(BaseListStorage):
-    _data: list = None
-
-    def __post_init__(self):
-        self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
-        self._data = load_json(self._file_name) or []
-        logger.info("Load List %s with %d data", self.namespace, len(self._data))
-
-    @property
-    def data(self):
-        return self._data
-
-    async def all_items(self) -> list:
-        return self._data
-
-    async def index_done_callback(self):
-        write_json(self._data, self._file_name)
-
-    async def get_by_index(self, index: int):
-        if index < 0 or index >= len(self._data):
-            return None
-        return self._data[index]
-
-    async def append(self, data):
-        self._data.append(data)
-
-    async def upsert(self, data: list):
-        left_data = [d for d in data if d not in self._data]
-        self._data.extend(left_data)
-        return left_data
-
-    async def drop(self):
-        self._data = []
hf-repo/graphgen/models/storage/networkx_storage.py
DELETED
@@ -1,159 +0,0 @@
-import os
-import html
-from typing import Any, Union, cast, Optional
-from dataclasses import dataclass
-import networkx as nx
-
-from graphgen.utils import logger
-from .base_storage import BaseGraphStorage
-
-@dataclass
-class NetworkXStorage(BaseGraphStorage):
-    @staticmethod
-    def load_nx_graph(file_name) -> Optional[nx.Graph]:
-        if os.path.exists(file_name):
-            return nx.read_graphml(file_name)
-        return None
-
-    @staticmethod
-    def write_nx_graph(graph: nx.Graph, file_name):
-        logger.info("Writing graph with %d nodes, %d edges", graph.number_of_nodes(), graph.number_of_edges())
-        nx.write_graphml(graph, file_name)
-
-    @staticmethod
-    def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
-        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
-        Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
-        """
-        from graspologic.utils import largest_connected_component
-
-        graph = graph.copy()
-        graph = cast(nx.Graph, largest_connected_component(graph))
-        node_mapping = {
-            node: html.unescape(node.upper().strip()) for node in graph.nodes()
-        }  # type: ignore
-        graph = nx.relabel_nodes(graph, node_mapping)
-        return NetworkXStorage._stabilize_graph(graph)
-
-    @staticmethod
-    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
-        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
-        Ensure an undirected graph with the same relationships will always be read the same way.
-        This is achieved by sorting the nodes and edges.
-        """
-        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
-
-        sorted_nodes = graph.nodes(data=True)
-        sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
-
-        fixed_graph.add_nodes_from(sorted_nodes)
-        edges = list(graph.edges(data=True))
-
-        if not graph.is_directed():
-
-            def _sort_source_target(edge):
-                source, target, edge_data = edge
-                if source > target:
-                    source, target = target, source
-                return source, target, edge_data
-
-            edges = [_sort_source_target(edge) for edge in edges]
-
-        def _get_edge_key(source: Any, target: Any) -> str:
-            return f"{source} -> {target}"
-
-        edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
-
-        fixed_graph.add_edges_from(edges)
-        return fixed_graph
-
-    def __post_init__(self):
-        """
-        Load the graph file if it exists; otherwise create a new graph.
-        """
-        self._graphml_xml_file = os.path.join(
-            self.working_dir, f"{self.namespace}.graphml"
-        )
-        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-        if preloaded_graph is not None:
-            logger.info(
-                "Loaded graph from %s with %d nodes, %d edges", self._graphml_xml_file,
-                preloaded_graph.number_of_nodes(), preloaded_graph.number_of_edges()
-            )
-        self._graph = preloaded_graph or nx.Graph()
-
-    async def index_done_callback(self):
-        NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
-
-    async def has_node(self, node_id: str) -> bool:
-        return self._graph.has_node(node_id)
-
-    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        return self._graph.has_edge(source_node_id, target_node_id)
-
-    async def get_node(self, node_id: str) -> Union[dict, None]:
-        return self._graph.nodes.get(node_id)
-
-    async def get_all_nodes(self) -> Union[list[dict], None]:
-        return self._graph.nodes(data=True)
-
-    async def node_degree(self, node_id: str) -> int:
-        return self._graph.degree(node_id)
-
-    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        return self._graph.degree(src_id) + self._graph.degree(tgt_id)
-
-    async def get_edge(
-        self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        return self._graph.edges.get((source_node_id, target_node_id))
-
-    async def get_all_edges(self) -> Union[list[dict], None]:
-        return self._graph.edges(data=True)
-
-    async def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
-        if self._graph.has_node(source_node_id):
-            return list(self._graph.edges(source_node_id, data=True))
-        return None
-
-    async def get_graph(self) -> nx.Graph:
-        return self._graph
-
-    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
-        self._graph.add_node(node_id, **node_data)
-
-    async def update_node(self, node_id: str, node_data: dict[str, str]):
-        if self._graph.has_node(node_id):
-            self._graph.nodes[node_id].update(node_data)
-        else:
-            logger.warning("Node %s not found in the graph for update.", node_id)
-
-    async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-    ):
-        self._graph.add_edge(source_node_id, target_node_id, **edge_data)
-
-    async def update_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
-        if self._graph.has_edge(source_node_id, target_node_id):
-            self._graph.edges[(source_node_id, target_node_id)].update(edge_data)
-        else:
-            logger.warning("Edge %s -> %s not found in the graph for update.", source_node_id, target_node_id)
-
-    async def delete_node(self, node_id: str):
-        """
-        Delete a node from the graph based on the specified node_id.
-
-        :param node_id: The node_id to delete
-        """
-        if self._graph.has_node(node_id):
-            self._graph.remove_node(node_id)
-            logger.info("Node %s deleted from the graph.", node_id)
-        else:
-            logger.warning("Node %s not found in the graph for deletion.", node_id)
-
-    async def clear(self):
-        """
-        Clear the graph by removing all nodes and edges.
-        """
-        self._graph.clear()
-        logger.info("Graph %s cleared.", self.namespace)
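
A sketch of the deleted NetworkXStorage. Not part of the commit; the store round-trips through a GraphML file named <working_dir>/<namespace>.graphml:

# Usage sketch (not part of the commit): the deleted NetworkXStorage.
import asyncio

from graphgen.models.storage.networkx_storage import NetworkXStorage

async def main():
    g = NetworkXStorage(working_dir="/tmp", namespace="kg_demo")
    await g.upsert_node("ENTITY_A", {"description": "demo node"})
    await g.upsert_node("ENTITY_B", {"description": "demo node"})
    await g.upsert_edge("ENTITY_A", "ENTITY_B", {"relation": "related_to"})
    print(await g.node_degree("ENTITY_A"))  # -> 1
    await g.index_done_callback()  # writes /tmp/kg_demo.graphml

asyncio.run(main())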
hf-repo/graphgen/models/strategy/__init__.py
DELETED
File without changes
hf-repo/graphgen/models/strategy/base_strategy.py
DELETED
@@ -1,5 +0,0 @@
-from dataclasses import dataclass
-
-@dataclass
-class BaseStrategy:
-    pass
hf-repo/graphgen/models/strategy/travserse_strategy.py
DELETED
@@ -1,30 +0,0 @@
-from dataclasses import dataclass, fields
-
-from graphgen.models.strategy.base_strategy import BaseStrategy
-
-
-@dataclass
-class TraverseStrategy(BaseStrategy):
-    # Form of the generated QA pairs: atomic, multi-hop, or aggregated
-    qa_form: str = "atomic"  # "atomic" or "multi_hop" or "aggregated"
-    # Only one of the two expansion limits takes effect: max edge count or max token count
-    expand_method: str = "max_tokens"  # "max_width" or "max_tokens"
-    # Expand in one direction or in both directions
-    bidirectional: bool = True
-    # Maximum number of extra edges to expand in each direction
-    max_extra_edges: int = 5
-    # Maximum number of tokens
-    max_tokens: int = 256
-    # Maximum expansion depth in each direction
-    max_depth: int = 2
-    # Edge-selection strategy within the same level (for bidirectional expansion, a level is the set of edges connected on both sides)
-    edge_sampling: str = "max_loss"  # "max_loss" or "min_loss" or "random"
-    # Strategy for handling isolated nodes
-    isolated_node_strategy: str = "add"  # "add" or "ignore"
-    loss_strategy: str = "only_edge"  # only_edge, both
-
-    def to_yaml(self):
-        strategy_dict = {}
-        for f in fields(self):
-            strategy_dict[f.name] = getattr(self, f.name)
-        return {"traverse_strategy": strategy_dict}
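
The to_yaml method returns a plain dict keyed by "traverse_strategy", presumably mirroring the YAML configs under graphgen/configs. A sketch, not part of the commit (pyyaml assumed available):

# Usage sketch (not part of the commit): the deleted TraverseStrategy.
import yaml  # pyyaml, assumed available

from graphgen.models.strategy.travserse_strategy import TraverseStrategy

strategy = TraverseStrategy(
    qa_form="multi_hop",
    expand_method="max_width",
    max_extra_edges=3,
    edge_sampling="random",
)
print(yaml.safe_dump(strategy.to_yaml(), sort_keys=False))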
hf-repo/graphgen/models/text/__init__.py
DELETED
File without changes
hf-repo/graphgen/models/text/chunk.py
DELETED
@@ -1,7 +0,0 @@
-from dataclasses import dataclass
-
-
-@dataclass
-class Chunk:
-    id: str
-    content: str
hf-repo/graphgen/models/text/text_pair.py
DELETED
@@ -1,9 +0,0 @@
-from dataclasses import dataclass
-
-@dataclass
-class TextPair:
-    """
-    A pair of input data.
-    """
-    question: str
-    answer: str
hf-repo/graphgen/models/vis/__init__.py
DELETED
File without changes
hf-repo/graphgen/models/vis/community_visualizer.py
DELETED
@@ -1,48 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict
-
-import matplotlib.pyplot as plt
-import networkx as nx
-
-
-@dataclass
-class Visualizer:
-    """
-    Class for visualizing graphs using NetworkX and Matplotlib.
-    """
-
-    graph: nx.Graph = None
-    communities: Dict[str, int] = None
-    layout: str = "spring"
-    max_nodes: int = 1000
-    node_size: int = 10
-    alpha: float = 0.6
-
-    def visualize(self, save_path: str = None):
-        n = self.graph.number_of_nodes()
-        if self.layout == "spring":
-            k = max(0.1, 1.0 / (n**0.5))
-            pos = nx.spring_layout(self.graph, k=k, seed=42)
-        else:
-            raise ValueError(f"Unknown layout: {self.layout}")
-
-        plt.figure(figsize=(10, 10))
-
-        node_colors = [self.communities.get(node, 0) for node in self.graph.nodes()]
-
-        nx.draw_networkx_nodes(
-            self.graph,
-            pos,
-            node_size=self.node_size,
-            node_color=node_colors,
-            cmap=plt.cm.tab20,
-            alpha=self.alpha,
-        )
-        nx.draw_networkx_edges(self.graph, pos, alpha=0.3, width=0.2)
-        plt.axis("off")
-
-        if save_path:
-            plt.savefig(save_path, dpi=300, bbox_inches="tight")
-            print("Saved to", save_path)
-        else:
-            plt.show()
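
A sketch of the deleted Visualizer. Not part of the commit; the karate-club graph and modulo community labels are toy stand-ins for a real graph and community-detection result:

# Usage sketch (not part of the commit): the deleted Visualizer.
import networkx as nx

from graphgen.models.vis.community_visualizer import Visualizer

graph = nx.karate_club_graph()
communities = {n: n % 4 for n in graph.nodes()}  # toy community labels
viz = Visualizer(graph=graph, communities=communities, node_size=30)
viz.visualize(save_path="communities.png")  # omit save_path to show instead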
hf-repo/graphgen/operators/__init__.py
DELETED
@@ -1,22 +0,0 @@
-from graphgen.operators.generate.generate_cot import generate_cot
-from graphgen.operators.kg.extract_kg import extract_kg
-from graphgen.operators.search.search_all import search_all
-
-from .judge import judge_statement
-from .quiz import quiz
-from .traverse_graph import (
-    traverse_graph_atomically,
-    traverse_graph_by_edge,
-    traverse_graph_for_multi_hop,
-)
-
-__all__ = [
-    "extract_kg",
-    "quiz",
-    "judge_statement",
-    "search_all",
-    "traverse_graph_by_edge",
-    "traverse_graph_atomically",
-    "traverse_graph_for_multi_hop",
-    "generate_cot",
-]