Spaces:
Sleeping
Sleeping
feat: Add complete nano-graphrag source code
Browse files- Add all nano-graphrag source files to Space
- Remove submodule reference and add as regular files
- This ensures nano-graphrag can be installed locally with -e ./nano-graphrag
🤖 Generated with Claude Code
Co-Authored-By: Claude <noreply@anthropic.com>
This view is limited to 50 files because it contains too many changes.
See raw diff
- nano-graphrag +0 -1
- nano-graphrag/.coveragerc +11 -0
- nano-graphrag/.env.example.azure +7 -0
- nano-graphrag/.github/workflows/test.yml +58 -0
- nano-graphrag/.gitignore +183 -0
- nano-graphrag/LICENSE +21 -0
- nano-graphrag/MANIFEST.in +1 -0
- nano-graphrag/docs/CONTRIBUTING.md +19 -0
- nano-graphrag/docs/FAQ.md +41 -0
- nano-graphrag/docs/ROADMAP.md +25 -0
- nano-graphrag/docs/benchmark-dspy-entity-extraction.md +276 -0
- nano-graphrag/docs/benchmark-en.md +150 -0
- nano-graphrag/docs/benchmark-zh.md +91 -0
- nano-graphrag/docs/use_neo4j_for_graphrag.md +27 -0
- nano-graphrag/examples/benchmarks/dspy_entity.py +152 -0
- nano-graphrag/examples/benchmarks/eval_naive_graphrag_on_multi_hop.ipynb +432 -0
- nano-graphrag/examples/benchmarks/hnsw_vs_nano_vector_storage.py +78 -0
- nano-graphrag/examples/benchmarks/md5_vs_xxhash.py +54 -0
- nano-graphrag/examples/finetune_entity_relationship_dspy.ipynb +0 -0
- nano-graphrag/examples/generate_entity_relationship_dspy.ipynb +0 -0
- nano-graphrag/examples/graphml_visualize.py +282 -0
- nano-graphrag/examples/no_openai_key_at_all.py +111 -0
- nano-graphrag/examples/using_amazon_bedrock.py +19 -0
- nano-graphrag/examples/using_custom_chunking_method.py +43 -0
- nano-graphrag/examples/using_deepseek_api_as_llm+glm_api_as_embedding.py +136 -0
- nano-graphrag/examples/using_deepseek_as_llm.py +98 -0
- nano-graphrag/examples/using_dspy_entity_extraction.py +144 -0
- nano-graphrag/examples/using_faiss_as_vextorDB.py +97 -0
- nano-graphrag/examples/using_hnsw_as_vectorDB.py +129 -0
- nano-graphrag/examples/using_llm_api_as_llm+ollama_embedding.py +122 -0
- nano-graphrag/examples/using_local_embedding_model.py +38 -0
- nano-graphrag/examples/using_milvus_as_vectorDB.py +94 -0
- nano-graphrag/examples/using_ollama_as_llm.py +96 -0
- nano-graphrag/examples/using_ollama_as_llm_and_embedding.py +120 -0
- nano-graphrag/examples/using_qdrant_as_vectorDB.py +113 -0
- nano-graphrag/nano_graphrag/__init__.py +7 -0
- nano-graphrag/nano_graphrag/_llm.py +294 -0
- nano-graphrag/nano_graphrag/_op.py +1140 -0
- nano-graphrag/nano_graphrag/_splitter.py +94 -0
- nano-graphrag/nano_graphrag/_storage/__init__.py +5 -0
- nano-graphrag/nano_graphrag/_storage/gdb_neo4j.py +529 -0
- nano-graphrag/nano_graphrag/_storage/gdb_networkx.py +268 -0
- nano-graphrag/nano_graphrag/_storage/kv_json.py +46 -0
- nano-graphrag/nano_graphrag/_storage/vdb_hnswlib.py +141 -0
- nano-graphrag/nano_graphrag/_storage/vdb_nanovectordb.py +68 -0
- nano-graphrag/nano_graphrag/_utils.py +305 -0
- nano-graphrag/nano_graphrag/base.py +186 -0
- nano-graphrag/nano_graphrag/entity_extraction/__init__.py +0 -0
- nano-graphrag/nano_graphrag/entity_extraction/extract.py +171 -0
- nano-graphrag/nano_graphrag/entity_extraction/metric.py +62 -0
nano-graphrag
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
Subproject commit 01f429e8c562e8f19b2449f90cec9a4a67d4f6ee
|
|
|
|
|
|
nano-graphrag/.coveragerc
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[report]
|
| 2 |
+
exclude_lines =
|
| 3 |
+
# Have to re-enable the standard pragma
|
| 4 |
+
pragma: no cover
|
| 5 |
+
|
| 6 |
+
# Don't complain if tests don't hit defensive assertion code:
|
| 7 |
+
raise NotImplementedError
|
| 8 |
+
logger.
|
| 9 |
+
omit =
|
| 10 |
+
# Don't have a nice github action for neo4j now, so skip this file:
|
| 11 |
+
nano_graphrag/_storage/gdb_neo4j.py
|
nano-graphrag/.env.example.azure
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
API_KEY_EMB="<your azure openai key for embedding>"
|
| 2 |
+
AZURE_ENDPOINT_EMB="<your azure openai endpoint for embedding>"
|
| 3 |
+
API_VERSION_EMB="<api version>"
|
| 4 |
+
|
| 5 |
+
AZURE_OPENAI_API_KEY="<your azure openai key for embedding>"
|
| 6 |
+
AZURE_OPENAI_ENDPOINT="<AZURE_OPENAI_ENDPOINT>"
|
| 7 |
+
OPENAI_API_VERSION="<OPENAI_API_VERSION>"
|
nano-graphrag/.github/workflows/test.yml
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: test
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
- dev
|
| 8 |
+
paths-ignore:
|
| 9 |
+
- '**/*.md'
|
| 10 |
+
- '**/*.ipynb'
|
| 11 |
+
- 'examples/**'
|
| 12 |
+
pull_request:
|
| 13 |
+
branches:
|
| 14 |
+
- main
|
| 15 |
+
- dev
|
| 16 |
+
paths-ignore:
|
| 17 |
+
- '**/*.md'
|
| 18 |
+
- '**/*.ipynb'
|
| 19 |
+
- 'examples/**'
|
| 20 |
+
|
| 21 |
+
jobs:
|
| 22 |
+
test:
|
| 23 |
+
name: Tests on ${{ matrix.os }} for ${{ matrix.python-version }}
|
| 24 |
+
strategy:
|
| 25 |
+
matrix:
|
| 26 |
+
python-version: [3.9]
|
| 27 |
+
os: [ubuntu-latest]
|
| 28 |
+
runs-on: ${{ matrix.os }}
|
| 29 |
+
timeout-minutes: 10
|
| 30 |
+
steps:
|
| 31 |
+
- uses: actions/checkout@v4
|
| 32 |
+
- name: Set up Python ${{ matrix.python-version }}
|
| 33 |
+
uses: actions/setup-python@v3
|
| 34 |
+
with:
|
| 35 |
+
python-version: ${{ matrix.python-version }}
|
| 36 |
+
- name: Install dependencies
|
| 37 |
+
run: |
|
| 38 |
+
python -m pip install --upgrade pip
|
| 39 |
+
pip install -r requirements.txt
|
| 40 |
+
pip install -r requirements-dev.txt
|
| 41 |
+
- name: Lint with flake8
|
| 42 |
+
run: |
|
| 43 |
+
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
| 44 |
+
- name: Build and Test
|
| 45 |
+
env:
|
| 46 |
+
NANO_GRAPHRAG_TEST_IGNORE_NEO4J: true
|
| 47 |
+
run: |
|
| 48 |
+
python -m pytest -o log_cli=true -o log_cli_level="INFO" --cov=nano_graphrag --cov-report=xml -v ./
|
| 49 |
+
- name: Check codecov file
|
| 50 |
+
id: check_files
|
| 51 |
+
uses: andstor/file-existence-action@v1
|
| 52 |
+
with:
|
| 53 |
+
files: './coverage.xml'
|
| 54 |
+
- name: Upload coverage from test to Codecov
|
| 55 |
+
uses: codecov/codecov-action@v2
|
| 56 |
+
with:
|
| 57 |
+
file: ./coverage.xml
|
| 58 |
+
token: ${{ secrets.CODECOV_TOKEN }}
|
nano-graphrag/.gitignore
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Created by https://www.toptal.com/developers/gitignore/api/python
|
| 2 |
+
# Edit at https://www.toptal.com/developers/gitignore?templates=python
|
| 3 |
+
test_cache.json
|
| 4 |
+
run_test*.py
|
| 5 |
+
nano_graphrag_cache*/
|
| 6 |
+
*.txt
|
| 7 |
+
examples/benchmarks/fixtures/
|
| 8 |
+
tests/original_workflow.txt
|
| 9 |
+
### Python ###
|
| 10 |
+
# Byte-compiled / optimized / DLL files
|
| 11 |
+
__pycache__/
|
| 12 |
+
*.py[cod]
|
| 13 |
+
*$py.class
|
| 14 |
+
.vscode
|
| 15 |
+
.DS_Store
|
| 16 |
+
# C extensions
|
| 17 |
+
*.so
|
| 18 |
+
|
| 19 |
+
# Distribution / packaging
|
| 20 |
+
.Python
|
| 21 |
+
build/
|
| 22 |
+
develop-eggs/
|
| 23 |
+
dist/
|
| 24 |
+
downloads/
|
| 25 |
+
eggs/
|
| 26 |
+
.eggs/
|
| 27 |
+
lib/
|
| 28 |
+
lib64/
|
| 29 |
+
parts/
|
| 30 |
+
sdist/
|
| 31 |
+
var/
|
| 32 |
+
wheels/
|
| 33 |
+
share/python-wheels/
|
| 34 |
+
*.egg-info/
|
| 35 |
+
.installed.cfg
|
| 36 |
+
*.egg
|
| 37 |
+
MANIFEST
|
| 38 |
+
|
| 39 |
+
# PyInstaller
|
| 40 |
+
# Usually these files are written by a python script from a template
|
| 41 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 42 |
+
*.manifest
|
| 43 |
+
*.spec
|
| 44 |
+
|
| 45 |
+
# Installer logs
|
| 46 |
+
pip-log.txt
|
| 47 |
+
pip-delete-this-directory.txt
|
| 48 |
+
|
| 49 |
+
# Unit test / coverage reports
|
| 50 |
+
htmlcov/
|
| 51 |
+
.tox/
|
| 52 |
+
.nox/
|
| 53 |
+
.coverage
|
| 54 |
+
.coverage.*
|
| 55 |
+
.cache
|
| 56 |
+
nosetests.xml
|
| 57 |
+
coverage.xml
|
| 58 |
+
*.cover
|
| 59 |
+
*.py,cover
|
| 60 |
+
.hypothesis/
|
| 61 |
+
.pytest_cache/
|
| 62 |
+
cover/
|
| 63 |
+
|
| 64 |
+
# Translations
|
| 65 |
+
*.mo
|
| 66 |
+
*.pot
|
| 67 |
+
|
| 68 |
+
# Django stuff:
|
| 69 |
+
*.log
|
| 70 |
+
local_settings.py
|
| 71 |
+
db.sqlite3
|
| 72 |
+
db.sqlite3-journal
|
| 73 |
+
|
| 74 |
+
# Flask stuff:
|
| 75 |
+
instance/
|
| 76 |
+
.webassets-cache
|
| 77 |
+
|
| 78 |
+
# Scrapy stuff:
|
| 79 |
+
.scrapy
|
| 80 |
+
|
| 81 |
+
# Sphinx documentation
|
| 82 |
+
docs/_build/
|
| 83 |
+
|
| 84 |
+
# PyBuilder
|
| 85 |
+
.pybuilder/
|
| 86 |
+
target/
|
| 87 |
+
|
| 88 |
+
# Jupyter Notebook
|
| 89 |
+
.ipynb_checkpoints
|
| 90 |
+
|
| 91 |
+
# IPython
|
| 92 |
+
profile_default/
|
| 93 |
+
ipython_config.py
|
| 94 |
+
|
| 95 |
+
# pyenv
|
| 96 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 97 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 98 |
+
# .python-version
|
| 99 |
+
|
| 100 |
+
# pipenv
|
| 101 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 102 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 103 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 104 |
+
# install all needed dependencies.
|
| 105 |
+
#Pipfile.lock
|
| 106 |
+
|
| 107 |
+
# poetry
|
| 108 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 109 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 110 |
+
# commonly ignored for libraries.
|
| 111 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 112 |
+
#poetry.lock
|
| 113 |
+
|
| 114 |
+
# pdm
|
| 115 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 116 |
+
#pdm.lock
|
| 117 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 118 |
+
# in version control.
|
| 119 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 120 |
+
.pdm.toml
|
| 121 |
+
|
| 122 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 123 |
+
__pypackages__/
|
| 124 |
+
|
| 125 |
+
# Celery stuff
|
| 126 |
+
celerybeat-schedule
|
| 127 |
+
celerybeat.pid
|
| 128 |
+
|
| 129 |
+
# SageMath parsed files
|
| 130 |
+
*.sage.py
|
| 131 |
+
|
| 132 |
+
# Environments
|
| 133 |
+
.env
|
| 134 |
+
.venv
|
| 135 |
+
env/
|
| 136 |
+
venv/
|
| 137 |
+
ENV/
|
| 138 |
+
env.bak/
|
| 139 |
+
venv.bak/
|
| 140 |
+
|
| 141 |
+
# Spyder project settings
|
| 142 |
+
.spyderproject
|
| 143 |
+
.spyproject
|
| 144 |
+
|
| 145 |
+
# Rope project settings
|
| 146 |
+
.ropeproject
|
| 147 |
+
|
| 148 |
+
# mkdocs documentation
|
| 149 |
+
/site
|
| 150 |
+
|
| 151 |
+
# mypy
|
| 152 |
+
.mypy_cache/
|
| 153 |
+
.dmypy.json
|
| 154 |
+
dmypy.json
|
| 155 |
+
|
| 156 |
+
# Pyre type checker
|
| 157 |
+
.pyre/
|
| 158 |
+
|
| 159 |
+
# pytype static type analyzer
|
| 160 |
+
.pytype/
|
| 161 |
+
|
| 162 |
+
# Cython debug symbols
|
| 163 |
+
cython_debug/
|
| 164 |
+
|
| 165 |
+
# PyCharm
|
| 166 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 167 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 168 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 169 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 170 |
+
#.idea/
|
| 171 |
+
|
| 172 |
+
### Python Patch ###
|
| 173 |
+
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
| 174 |
+
poetry.toml
|
| 175 |
+
|
| 176 |
+
# ruff
|
| 177 |
+
.ruff_cache/
|
| 178 |
+
|
| 179 |
+
# LSP config files
|
| 180 |
+
pyrightconfig.json
|
| 181 |
+
|
| 182 |
+
# End of https://www.toptal.com/developers/gitignore/api/python
|
| 183 |
+
|
nano-graphrag/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Gustavo Ye
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
nano-graphrag/MANIFEST.in
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
include readme.md
|
nano-graphrag/docs/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to nano-graphrag
|
| 2 |
+
|
| 3 |
+
### Submit your Contribution through PR
|
| 4 |
+
|
| 5 |
+
To make a contribution, follow these steps:
|
| 6 |
+
|
| 7 |
+
1. Fork and clone this repository
|
| 8 |
+
3. If you modified the core code (`./nano_graphrag`), please add tests for it
|
| 9 |
+
4. **Include proper documentation / docstring or examples**
|
| 10 |
+
5. Ensure that all tests pass by running `pytest`
|
| 11 |
+
6. Submit a pull request
|
| 12 |
+
|
| 13 |
+
For more details about pull requests, please read [GitHub's guides](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request).
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
### Only add a dependency when we have to
|
| 18 |
+
|
| 19 |
+
`nano-graphrag` needs to be `nano` and `light`. If we want to add more features, we add them smartly. Don't introduce a huge dependency just for a simple function.
|
nano-graphrag/docs/FAQ.md
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### `Leiden.EmptyNetworkError:EmptyNetworkError`
|
| 2 |
+
|
| 3 |
+
This error is caused by `nano-graphrag` tries to compute communities on an empty network. In most cases, this is caused by the LLM model you're using, it fails to extract any entities or relations, so the graph is empty.
|
| 4 |
+
|
| 5 |
+
Try to use another bigger LLM, or here are some ideas to fix it:
|
| 6 |
+
|
| 7 |
+
- Check the response from the LLM, make sure the result fits the desired response format of the extracting entities prompt.
|
| 8 |
+
|
| 9 |
+
The desired response format is something like that:
|
| 10 |
+
|
| 11 |
+
```text
|
| 12 |
+
("entity"<|>"Cruz"<|>"person"<|>"Cruz is associated with a vision of control and order, influencing the dynamics among other characters.")
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
- Some LLMs may not return the format like above, so one possible solution is to add a system instruction to the input prompt, such like:
|
| 16 |
+
```json
|
| 17 |
+
{
|
| 18 |
+
"role": "system",
|
| 19 |
+
"content": "You are an intelligent assistant and will follow the instructions given to you to fulfill the goal. The answer should be in the format as in the given example."
|
| 20 |
+
}
|
| 21 |
+
```
|
| 22 |
+
You can use this system_prompt as default for your LLM calling funcation
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
### One possible reason of 'Processed 42 chunks,0 entities(duplicated),0 relations(duplicated)WARNING:nano-graphrag:Didn't extract any entities, maybe your LLM is not working WARNING:nano-graphrag:No new entities found'
|
| 26 |
+
|
| 27 |
+
The default num_ctx of ollama is 2048 which is too small for the input prompt of entity extraction. This causes the model to fail to respond correctly.
|
| 28 |
+
|
| 29 |
+
Solution:
|
| 30 |
+
Each model in Ollama has a configuration file. Here, you need to generate a new configuration file based on the original one, and then use this configuration file to generate a new model.
|
| 31 |
+
For example the qwen2, run the following command:
|
| 32 |
+
|
| 33 |
+
`ollama show --modelfile qwen2 > Modelfile`
|
| 34 |
+
|
| 35 |
+
Add a new line into this file below the 'FROM':
|
| 36 |
+
|
| 37 |
+
`PARAMETER num_ctx 32000`
|
| 38 |
+
|
| 39 |
+
`ollama create -f Modelfile qwen2:ctx32k`
|
| 40 |
+
|
| 41 |
+
Afterwards, you can use qwen2:ctx32k to replace qwen2.
|
nano-graphrag/docs/ROADMAP.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Next Version
|
| 2 |
+
|
| 3 |
+
- [ ] Add DSpy for prompt-tuning to make small models(Qwen2 7B, Llama 3.1 8B...) can extract entities. @NumberChiffre @gusye1234
|
| 4 |
+
- [ ] Optimize Algorithm: add `global_local` query method, globally rewrite query then perform local search.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
## In next few versions
|
| 9 |
+
|
| 10 |
+
- [ ] Add rate limiter: support token limit (tokens per second, per minute)
|
| 11 |
+
|
| 12 |
+
- [ ] Add other advanced RAG algorithms, candidates:
|
| 13 |
+
|
| 14 |
+
- [ ] [HybridRAG](https://arxiv.org/abs/2408.04948)
|
| 15 |
+
- [ ] [HippoRAG](https://arxiv.org/abs/2405.14831)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
## Interesting directions
|
| 23 |
+
|
| 24 |
+
- [ ] Add [Sciphi Triplex](https://huggingface.co/SciPhi/Triplex) as the entity extraction model.
|
| 25 |
+
- [ ] Add new components, see [issue](https://github.com/gusye1234/nano-graphrag/issues/2)
|
nano-graphrag/docs/benchmark-dspy-entity-extraction.md
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Chain Of Thought Prompting with DSPy-AI (v2.4.16)
|
| 2 |
+
## Main Takeaways
|
| 3 |
+
- Time difference: 156.99 seconds
|
| 4 |
+
- Execution time with DSPy-AI: 304.38 seconds
|
| 5 |
+
- Execution time without DSPy-AI: 147.39 seconds
|
| 6 |
+
- Entities extracted: 22 (without DSPy-AI) vs 37 (with DSPy-AI)
|
| 7 |
+
- Relationships extracted: 21 (without DSPy-AI) vs 36 (with DSPy-AI)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
## Results
|
| 11 |
+
```markdown
|
| 12 |
+
> python examples/benchmarks/dspy_entity.py
|
| 13 |
+
|
| 14 |
+
Running benchmark with DSPy-AI:
|
| 15 |
+
INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
|
| 16 |
+
INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
|
| 17 |
+
INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
|
| 18 |
+
DEBUG:nano-graphrag:Entities: 14 | Missed Entities: 23 | Total Entities: 37
|
| 19 |
+
DEBUG:nano-graphrag:Relationships: 13 | Missed Relationships: 23 | Total Relationships: 36
|
| 20 |
+
DEBUG:nano-graphrag:Direct Relationships: 31 | Second-order: 5 | Third-order: 0 | Total Relationships: 36
|
| 21 |
+
⠙ Processed 1 chunks, 37 entities(duplicated), 36 relations(duplicated)
|
| 22 |
+
Execution time with DSPy-AI: 304.38 seconds
|
| 23 |
+
|
| 24 |
+
Entities:
|
| 25 |
+
- 朱元璋 (PERSON):
|
| 26 |
+
明朝开国皇帝,原名朱重八,后改名朱元璋。他出身贫农,经历了从放牛娃到皇帝的传奇人生。在元朝末年,他参加了红巾军起义,最终推翻元朝,建立了明朝。
|
| 27 |
+
- 朱五四 (PERSON):
|
| 28 |
+
朱元璋的父亲,农民出身,家境贫寒。他在朱元璋幼年时去世,对朱元璋的成长和人生选择产生了深远影响。
|
| 29 |
+
- 陈氏 (PERSON):
|
| 30 |
+
朱元璋的母亲,农民出身,家境贫寒。她在朱元璋幼年时去世,对朱元璋的成长和人生选择产生了深远影响。
|
| 31 |
+
- 汤和 (PERSON):
|
| 32 |
+
朱元璋的幼年朋友,后来成为朱元璋起义军中的重要将领。他在朱元璋早期的发展中起到了关键作用。
|
| 33 |
+
- 郭子兴 (PERSON):
|
| 34 |
+
红巾军起义的领导人之一,朱元璋的岳父。他在朱元璋早期的发展中起到了重要作用,但后来与朱元璋产生了矛盾。
|
| 35 |
+
- 马姑娘 (PERSON):
|
| 36 |
+
郭子兴的义女,朱元璋的妻子。她在朱元璋最困难的时候给予了极大的支持,是朱元璋成功的重要因素之一。
|
| 37 |
+
- 元朝 (ORGANIZATION):
|
| 38 |
+
中国历史上的一个朝代,由蒙古族建立。元朝末年,社会矛盾激化,最终导致了红巾军起义和明朝的建立。
|
| 39 |
+
- 红巾军 (ORGANIZATION):
|
| 40 |
+
元朝末年起义军的一支,主要由农民组成。朱元璋最初加入的就是红巾军,并在其中逐渐崭露头角。
|
| 41 |
+
- 皇觉寺 (LOCATION):
|
| 42 |
+
朱元璋早年出家的地方,位于安徽凤阳。他在寺庙中度过了几年的时光,这段经历对他的人生观和价值观产生了深远影响。
|
| 43 |
+
- 濠州 (LOCATION):
|
| 44 |
+
朱元璋早期活动的重要地点,也是红巾军的重要据点之一。朱元璋在这里经历了许多重要事件,包括与郭子兴的矛盾和最终的离开。
|
| 45 |
+
- 1328年 (DATE):
|
| 46 |
+
朱元璋出生的年份。这一年标志着明朝开国皇帝传奇人生的开始。
|
| 47 |
+
- 1344年 (DATE):
|
| 48 |
+
朱元璋家庭遭遇重大变故的年份,他的父母在这一年相继去世。这一事件对朱元璋的人生选择产生了深远影响。
|
| 49 |
+
- 1352年 (DATE):
|
| 50 |
+
朱元璋正式加入红巾军起义的年份。这一年标志着朱元璋从农民到起义军领袖的转变。
|
| 51 |
+
- 1368年 (DATE):
|
| 52 |
+
朱元璋推翻元朝,建立明朝的年份。这一年标志着朱元璋从起义军领袖到皇帝的转变。
|
| 53 |
+
- 朱百六 (PERSON):
|
| 54 |
+
朱元璋的高祖,名字具有元朝时期老百姓命名的特点,即以数字命名。
|
| 55 |
+
- 朱四九 (PERSON):
|
| 56 |
+
朱元璋的曾祖,名字同样具有元朝时期老百姓命名的特点,即以数字命名。
|
| 57 |
+
- 朱初一 (PERSON):
|
| 58 |
+
朱元璋的祖父,名字具有元朝时期老百姓命名的特点,即以数字命名。
|
| 59 |
+
- 刘德 (PERSON):
|
| 60 |
+
朱元璋早年为其放牛的地主,对朱元璋的童年生活有重要影响。
|
| 61 |
+
- 韩山童 (PERSON):
|
| 62 |
+
红巾军起义的早期领导人之一,与刘福通共同起义,对朱元璋的起义选择有间接影响。
|
| 63 |
+
- 刘福通 (PERSON):
|
| 64 |
+
红巾军起义的早期领导人之一,与韩山童共同起义,对朱元璋的起义选择有间接影响。
|
| 65 |
+
- 脱脱 (PERSON):
|
| 66 |
+
元朝末年的著名宰相,主张治理黄河,但他的政策间接导致了红巾军起义的爆发。
|
| 67 |
+
- 元顺帝 (PERSON):
|
| 68 |
+
元朝末代皇帝,他在位期间元朝社会矛盾激化,最终导致了红巾军起义和明朝的建立。
|
| 69 |
+
- 孙德崖 (PERSON):
|
| 70 |
+
红巾军起义的领导人之一,与郭子兴有矛盾,曾绑架郭子兴,对朱元璋的早期发展有重要影响。
|
| 71 |
+
- 周德兴 (PERSON):
|
| 72 |
+
朱元璋的早期朋友,曾为朱元璋算卦,对朱元璋的人生选择有一定影响。
|
| 73 |
+
- 徐达 (PERSON):
|
| 74 |
+
朱元璋早期的重要将领,后来成为明朝的开国功臣之一。
|
| 75 |
+
- 明教 (RELIGION):
|
| 76 |
+
朱元璋在起义过程中接触到的宗教信仰,对他的思想和行动有一定影响。
|
| 77 |
+
- 弥勒佛 (RELIGION):
|
| 78 |
+
明教中的重要神祇,朱元璋相信弥勒佛会降世,对他的信仰和行动有一定影响。
|
| 79 |
+
- 颖州 (LOCATION):
|
| 80 |
+
朱元璋早年讨饭的地方,也是红巾军起义的重要地点之一。
|
| 81 |
+
- 定远 (LOCATION):
|
| 82 |
+
朱元璋早期攻打的地点之一,是他军事生涯的起点。
|
| 83 |
+
- 怀远 (LOCATION):
|
| 84 |
+
朱元璋早期攻打的地点之一,是他军事生涯的起点。
|
| 85 |
+
- 安奉 (LOCATION):
|
| 86 |
+
朱元璋早期攻打的地点之一,是他军事生涯的起点。
|
| 87 |
+
- 含山 (LOCATION):
|
| 88 |
+
朱元璋早期攻打的地点之一,是他军事生涯的起点。
|
| 89 |
+
- 虹县 (LOCATION):
|
| 90 |
+
朱元璋早期攻打的地点之一,是他军事生涯的起点。
|
| 91 |
+
- 钟离 (LOCATION):
|
| 92 |
+
朱元璋的家乡,他在此地召集了二十四位重要将领。
|
| 93 |
+
- 黄河 (LOCATION):
|
| 94 |
+
元朝末年黄河泛滥,导致了严重的社会问题,间接引发了红巾军起义。
|
| 95 |
+
- 淮河 (LOCATION):
|
| 96 |
+
元朝末年淮河沿岸遭遇严重瘟疫和旱灾,加剧了社会矛盾。
|
| 97 |
+
- 1351年 (DATE):
|
| 98 |
+
红巾军起义爆发的年份,对朱元璋的人生选择产生了重要影响。
|
| 99 |
+
|
| 100 |
+
Relationships:
|
| 101 |
+
- 朱元璋 -> 朱五四:
|
| 102 |
+
朱元璋是朱五四的儿子,朱五四的去世对朱元璋的成长和人生选择产生了深远影响。
|
| 103 |
+
- 朱元璋 -> 陈氏:
|
| 104 |
+
朱元璋是陈氏的儿子,陈氏的去世对朱元璋的成长和人生选择产生了深远影响。
|
| 105 |
+
- 朱元璋 -> 汤和:
|
| 106 |
+
汤和是朱元璋的幼年朋友,后来成为朱元璋起义军中的重要将领,对朱元璋早期的发展起到了关键作用。
|
| 107 |
+
- 朱元璋 -> 郭子兴:
|
| 108 |
+
郭子兴是朱元璋的岳父,也是红巾军起义的领导人之一。他在朱元璋早期的发展中起到了重要作用,但后来与朱元璋产生了矛盾。
|
| 109 |
+
- 朱元璋 -> 马姑娘:
|
| 110 |
+
马姑娘是朱元璋的妻子,她在朱元璋最困难的时候给予了极大的支持,是朱元璋成功的重要因素之一。
|
| 111 |
+
- 朱元璋 -> 元朝:
|
| 112 |
+
朱元璋在元朝末年参加了红巾军起义,最终推翻了元朝,建立了明朝。
|
| 113 |
+
- 朱元璋 -> 红巾军:
|
| 114 |
+
朱元璋最初加入的是红巾军,并在其中逐渐崭露头角,最终成为起义军的重要领导人。
|
| 115 |
+
- 朱元璋 -> 皇觉寺:
|
| 116 |
+
朱元璋早年出家的地方是皇觉寺,这段经历对他的人生观和价值观产生了深远影响。
|
| 117 |
+
- 朱元璋 -> 濠州:
|
| 118 |
+
濠州是朱元璋早期活动的重要地点,也是红巾军的重要据点之一。朱元璋在这里经历了许多重要事件,包括与郭子兴的矛盾和最终的离开。
|
| 119 |
+
- 朱元璋 -> 1328年:
|
| 120 |
+
1328年是朱元璋出生的年份,这一年标志着明朝开国皇帝传奇人生的开始。
|
| 121 |
+
- 朱元璋 -> 1344年:
|
| 122 |
+
1344年是朱元璋家庭遭遇重大变故的年份,他的父母在这一年相继去世,这一事件对朱元璋的人生选择产生了深远影响。
|
| 123 |
+
- 朱元璋 -> 1352年:
|
| 124 |
+
1352年是朱元璋正式加入红巾军起义的年份,这一年标志着朱元璋从农民到起义军领袖的转变。
|
| 125 |
+
- 朱元璋 -> 1368年:
|
| 126 |
+
1368年是朱元璋推翻元朝,建立明朝的年份,这一年标志着朱元璋从起义军领袖到皇帝的转变。
|
| 127 |
+
- 朱元璋 -> 朱百六:
|
| 128 |
+
朱百六是朱元璋的高祖,对朱元璋的家族背景有重要影响。
|
| 129 |
+
- 朱元璋 -> 朱四九:
|
| 130 |
+
朱四九是朱元璋的曾祖,对朱元璋的家族背景有重要影响。
|
| 131 |
+
- 朱元璋 -> 朱初一:
|
| 132 |
+
朱初一是朱元璋的祖父,对朱元璋的家族背景有重要影响。
|
| 133 |
+
- 朱元璋 -> 刘德:
|
| 134 |
+
刘德是朱元璋早年为其放牛的地主,对朱元璋的童年生活有重要影响。
|
| 135 |
+
- 朱元璋 -> 韩山童:
|
| 136 |
+
韩山童是红巾军起义的早期领导人之一,对朱元璋的起义选择有间接影响。
|
| 137 |
+
- 朱元璋 -> 刘福通:
|
| 138 |
+
刘福通是红巾军起义的早期领导人之一,对朱元璋的起义选择有间接影响。
|
| 139 |
+
- 朱元璋 -> 脱脱:
|
| 140 |
+
脱脱是元朝末年的著名宰相,他的政策间接导致了红巾军起义的爆发,对朱元璋的起义选择有间接影响。
|
| 141 |
+
- 朱元璋 -> 元顺帝:
|
| 142 |
+
元顺帝是元朝末代皇帝,他在位期间社会矛盾激化,最终导致了红巾军起义和明朝的建立,对朱元璋的起义选择有重要影响。
|
| 143 |
+
- 朱元璋 -> 孙德崖:
|
| 144 |
+
孙德崖是红巾军起义的领导人之一,与郭子兴有矛盾,曾绑架郭子兴,对朱元璋的早期发展有重要影响。
|
| 145 |
+
- 朱元璋 -> 周德兴:
|
| 146 |
+
周德兴是朱元璋的早期朋友,曾为朱元璋算卦,对朱元璋的人生选择有一定影响。
|
| 147 |
+
- 朱元璋 -> 徐达:
|
| 148 |
+
徐达是朱元璋早期的重要将领,后来成为明朝的开国功臣之一,对朱元璋的军事生涯有重要影响。
|
| 149 |
+
- 朱元璋 -> 明教:
|
| 150 |
+
朱元璋在起义过程中接触到的宗教信仰,对他的思想和行动有一定影响。
|
| 151 |
+
- 朱元璋 -> 弥勒佛:
|
| 152 |
+
朱元璋相信弥勒佛会降世,对他的信仰和行动有一定影响。
|
| 153 |
+
- 朱元璋 -> 颖州:
|
| 154 |
+
颖州是朱元璋早年讨饭的地方,也是红巾军起义的重要地点之一,对朱元璋的早期生活有重要影响。
|
| 155 |
+
- 朱元璋 -> 定远:
|
| 156 |
+
定远是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。
|
| 157 |
+
- 朱元璋 -> 怀远:
|
| 158 |
+
怀远是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。
|
| 159 |
+
- 朱元璋 -> 安奉:
|
| 160 |
+
安奉是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。
|
| 161 |
+
- 朱元璋 -> 含山:
|
| 162 |
+
含山是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。
|
| 163 |
+
- 朱元璋 -> 虹县:
|
| 164 |
+
虹县是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。
|
| 165 |
+
- 朱元璋 -> 钟离:
|
| 166 |
+
钟离是朱元璋的家乡,他在此地召集了二十四位重要将领,对朱元璋的军事发展有重要影响。
|
| 167 |
+
- 朱元璋 -> 黄河:
|
| 168 |
+
元朝末年黄河泛滥,导致了严重的社会问题,间接引发了红巾军起义,对朱元璋的起义选择有重要影响。
|
| 169 |
+
- 朱元璋 -> 淮河:
|
| 170 |
+
元朝末年淮河沿岸遭遇严重瘟疫和旱灾,加剧了社会矛盾,对朱元璋的起义选择有重要影响。
|
| 171 |
+
- 朱元璋 -> 1351年:
|
| 172 |
+
1351年是红巾军起义爆发的年份,对朱元璋的人生选择产生了重要影响。
|
| 173 |
+
Running benchmark without DSPy-AI:
|
| 174 |
+
INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
|
| 175 |
+
INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
|
| 176 |
+
⠙ Processed 1 chunks, 22 entities(duplicated), 21 relations(duplicated)
|
| 177 |
+
Execution time without DSPy-AI: 147.39 seconds
|
| 178 |
+
|
| 179 |
+
Entities:
|
| 180 |
+
- "朱元璋" ("PERSON"):
|
| 181 |
+
"朱元璋,原名朱重八,后改名朱元璋,是明朝的开国皇帝。他出身贫农,经历了从放牛娃到和尚,再到起义军领袖,最终成为皇帝的传奇人生。"
|
| 182 |
+
- "朱五四" ("PERSON"):
|
| 183 |
+
"朱五四,朱元璋的父亲,是一个农民,为地主种地,家境贫寒。"
|
| 184 |
+
- "陈氏" ("PERSON"):
|
| 185 |
+
"陈氏,朱元璋的母亲,是一个农民,与丈夫朱五四一起辛勤劳作,家境贫寒。"
|
| 186 |
+
- "汤和" ("PERSON"):
|
| 187 |
+
"汤和,朱元璋的幼年朋友,后来成为朱元璋的战友,在朱元璋的崛起过程中起到了重要作用。"
|
| 188 |
+
- "郭子兴" ("PERSON"):
|
| 189 |
+
"郭子兴,濠州城的守卫者,是朱元璋的岳父,也是朱元璋早期的重要支持者。"
|
| 190 |
+
- "韩山童" ("PERSON"):
|
| 191 |
+
"韩山童,与刘福通一起起义反抗元朝统治,是元末农民起义的重要领袖之一。"<SEP>"韩山童,元末农民起义的领袖之一,自称宋朝皇室后裔,与刘福通一起起义。"
|
| 192 |
+
- "刘福通" ("PERSON"):
|
| 193 |
+
"刘福通,与韩山童一起起义反抗元朝统治,是元末农民起义的重要领袖之一。"<SEP>"刘福通,元末农民起义的领袖之一,自称刘光世大将的后人,与韩山童一起起义。"
|
| 194 |
+
- "元朝" ("ORGANIZATION"):
|
| 195 |
+
"元朝,由蒙古族建立的王朝,统治中国时期实行了严格的等级制度,导致社会矛盾激化,最终被朱元璋领导的起义军推翻。"
|
| 196 |
+
- "皇觉寺" ("ORGANIZATION"):
|
| 197 |
+
"皇觉寺,朱元璋曾经在此当和尚,从事杂役工作,后来因饥荒严重,和尚们都被派出去化缘。"
|
| 198 |
+
- "白莲教" ("ORGANIZATION"):
|
| 199 |
+
"白莲教,元末农民起义中的一种宗教组织,韩山童和刘福通起义时利用了这一宗教信仰。"
|
| 200 |
+
- "濠州城" ("GEO"):
|
| 201 |
+
"濠州城,位于今安徽省,是朱元璋早期活动的重要地点,也是郭子兴的驻地。"
|
| 202 |
+
- "定远" ("GEO"):
|
| 203 |
+
"定远,朱元璋奉命攻击的地方,成功攻克后在元军回援前撤出,显示了其军事才能。"
|
| 204 |
+
- "钟离" ("GEO"):
|
| 205 |
+
"钟离,朱元璋的家乡,他在此招收了二十四名壮丁,这些人后来成为明朝的高级干部。"
|
| 206 |
+
- "元末农民起义" ("EVENT"):
|
| 207 |
+
"元末农民起义,是元朝末年由韩山童、刘福通等人领导的反抗元朝统治的大规模起义,最终导致了元朝的灭亡。"
|
| 208 |
+
- "马姑娘" ("PERSON"):
|
| 209 |
+
"马姑娘,郭子兴的义女,后来成为朱元璋的妻子,在朱元璋被关押时,她冒着危险送饭给朱元璋,表现出深厚的感情。"
|
| 210 |
+
- "孙德崖" ("PERSON"):
|
| 211 |
+
"孙德崖,与郭子兴有矛盾的起义军领袖之一,曾参与绑架郭子兴。"
|
| 212 |
+
- "徐达" ("PERSON"):
|
| 213 |
+
"徐达,朱元璋的二十四名亲信之一,后来成为明朝的重要将领。"
|
| 214 |
+
- "周德兴" ("PERSON"):
|
| 215 |
+
"周德兴,朱元璋的二十四名亲信之一,曾为朱元璋算过命。"
|
| 216 |
+
- "脱脱" ("PERSON"):
|
| 217 |
+
"脱脱,元朝的著名宰相,主张治理黄河,但他的政策间接导致了元朝的灭亡。"
|
| 218 |
+
- "元顺帝" ("PERSON"):
|
| 219 |
+
"元顺帝,元朝的最后一位皇帝,统治时期元朝社会矛盾激化,最终导致了元朝的灭亡。"
|
| 220 |
+
- "刘德" ("PERSON"):
|
| 221 |
+
"刘德,地主,朱元璋早年为其放牛。"
|
| 222 |
+
- "吴老太" ("PERSON"):
|
| 223 |
+
"吴老太,村口的媒人,朱元璋曾希望托她找一个媳妇。"
|
| 224 |
+
|
| 225 |
+
Relationships:
|
| 226 |
+
- "朱元璋" -> "朱五四":
|
| 227 |
+
"朱元璋的父亲,对他的成长和早期生活有重要影响。"
|
| 228 |
+
- "朱元璋" -> "陈氏":
|
| 229 |
+
"朱元璋的母亲,对他的成长和早期生活有重要影响。"
|
| 230 |
+
- "朱元璋" -> "汤和":
|
| 231 |
+
"朱元璋的幼年朋友,后来成为他的战友,在朱元璋的崛起过程中起到了重要作用。"
|
| 232 |
+
- "朱元璋" -> "郭子兴":
|
| 233 |
+
"朱元璋的岳父,是他在起义军中的重要支持者。"
|
| 234 |
+
- "朱元璋" -> "韩山童":
|
| 235 |
+
"朱元璋在起义过程中与韩山童有间接联系,韩山童的起义对朱元璋的崛起有重要影响。"
|
| 236 |
+
- "朱元璋" -> "刘福通":
|
| 237 |
+
"朱元璋在起义过程中与刘福通有间接联系,刘福通的起义对朱元璋的崛起有重要影响。"
|
| 238 |
+
- "朱元璋" -> "元朝":
|
| 239 |
+
"朱元璋最终推翻了元朝的统治,建立了明朝。"
|
| 240 |
+
- "朱元璋" -> "皇觉寺":
|
| 241 |
+
"朱元璋曾经在此当和尚,这段经历对他的成长有重要影响。"
|
| 242 |
+
- "朱元璋" -> "白莲教":
|
| 243 |
+
"朱元璋在起义过程中接触到了白莲教,虽然他本人可能并不信仰,但白莲教的起义对他有重要影响。"
|
| 244 |
+
- "朱元璋" -> "濠州城":
|
| 245 |
+
"朱元璋在濠州城的活动对其早期军事和政治生涯有重要影响。"
|
| 246 |
+
- "朱元璋" -> "定远":
|
| 247 |
+
"朱元璋成功攻克定远,显示了其军事才能。"
|
| 248 |
+
- "朱元璋" -> "钟离":
|
| 249 |
+
"朱元璋的家乡,他在此招收了二十四名壮丁,这些人后来成为明朝的高级干部。"
|
| 250 |
+
- "朱元璋" -> "元末农民起义":
|
| 251 |
+
"朱元璋参与并最终领导了元末农民起义,推翻了元朝的统治。"
|
| 252 |
+
- "朱元璋" -> "马姑娘":
|
| 253 |
+
"朱元璋的妻子,在朱元璋被关押时,她冒着危险送饭给朱元璋,表现出深厚的感情。"
|
| 254 |
+
- "朱元璋" -> "孙德崖":
|
| 255 |
+
"朱元璋在孙德崖与郭子兴的矛盾中起到了调解作用,显示了其政治智慧。"
|
| 256 |
+
- "朱元璋" -> "徐达":
|
| 257 |
+
"朱元璋的二十四名亲信之一,后来成为明朝的重要将领。"
|
| 258 |
+
- "朱元璋" -> "周德兴":
|
| 259 |
+
"朱元璋的二十四名亲信之一,曾为朱元璋算过命。"
|
| 260 |
+
- "朱元璋" -> "脱脱":
|
| 261 |
+
"朱元璋在起义过程中间接受到脱脱政策的影响,脱脱的政策间接导致了元朝的灭亡。"
|
| 262 |
+
- "朱元璋" -> "元顺帝":
|
| 263 |
+
"朱元璋最终推翻了元顺帝的统治,建立了明朝。"
|
| 264 |
+
- "朱元璋" -> "刘德":
|
| 265 |
+
"朱元璋早年为刘德放牛,这段经历对他的成长有重要影响。"
|
| 266 |
+
- "朱元璋" -> "吴老太":
|
| 267 |
+
"朱元璋曾希望托吴老太找一个媳妇,显示了他对家庭的渴望。"
|
| 268 |
+
```
|
| 269 |
+
|
| 270 |
+
# Self-Refine with DSPy-AI (v2.5.6)
|
| 271 |
+
## Main Takeaways
|
| 272 |
+
- Time difference: 66.24 seconds
|
| 273 |
+
- Execution time with DSPy-AI: 211.04 seconds
|
| 274 |
+
- Execution time without DSPy-AI: 144.80 seconds
|
| 275 |
+
- Entities extracted: 38 (without DSPy-AI) vs 16 (with DSPy-AI)
|
| 276 |
+
- Relationships extracted: 38 (without DSPy-AI) vs 16 (with DSPy-AI)
|
nano-graphrag/docs/benchmark-en.md
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- We use [A Christmas Carol](https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_data.txt) by Dickens as the benchmark corpus.
|
| 2 |
+
- We use `61b5eea34783c58074b3c53f1689ad8a5ba6b6ee` commit of [Official GraphRAG implementation](https://github.com/microsoft/graphrag/tree/main)
|
| 3 |
+
- Both GraphRAG and `nano-graphrag` use OpenAI Embedding and `gpt-4o`.
|
| 4 |
+
- Not Cache for both. On the same device and network connection.
|
| 5 |
+
- GrapgRAG Max Async API request: 25
|
| 6 |
+
- `nano-graphrag` Max Async API request: 16
|
| 7 |
+
|
| 8 |
+
## Index Benchmark
|
| 9 |
+
|
| 10 |
+
**GraphRAG index time**: more than 5 minutes
|
| 11 |
+
|
| 12 |
+
**`nano-graphrag` index time**: less than 4 minutes
|
| 13 |
+
|
| 14 |
+
## Local Search Results
|
| 15 |
+
|
| 16 |
+
#### GraphRAG
|
| 17 |
+
|
| 18 |
+
"What are the top themes in this story?"
|
| 19 |
+
|
| 20 |
+
```markdown
|
| 21 |
+
# Top Themes in the Story
|
| 22 |
+
|
| 23 |
+
The story revolves around several central themes that are intricately woven into the narrative, each contributing to the overall message and character development. Here are the top themes:
|
| 24 |
+
|
| 25 |
+
## Family and Relationships
|
| 26 |
+
|
| 27 |
+
Family and relationships play a crucial role in the story, particularly through the interactions between Scrooge and his nephew and niece. Scrooge's nephew is depicted as a warm-hearted and persistent character who continually tries to engage Scrooge, especially during the Christmas season. Despite Scrooge's initial disdain for Christmas, his nephew's efforts highlight themes of compassion and familial bonds [Data: Scrooge's Family and Christmas (20); Relationships (149, 152, 154)]. Scrooge's niece also contributes to this theme with her strong opinions and musical talent, which help soften Scrooge's attitude and remind him of his past [Data: Scrooge's Family and Christmas (20); Relationships (154)].
|
| 28 |
+
|
| 29 |
+
## Transformation and Redemption
|
| 30 |
+
|
| 31 |
+
The theme of transformation and redemption is central to the story, primarily illustrated through Scrooge's character arc. Initially, Scrooge is cynical and dismissive of Christmas and its associated values. However, the recurring event of Christmas serves as a catalyst for his transformation. Through various emotional experiences and reflections prompted by the holiday, Scrooge undergoes a significant change, ultimately embracing generosity and kindness [Data: Scrooge's Family and Christmas (20); Relationships (96, 156, 153, 154)].
|
| 32 |
+
|
| 33 |
+
## The Spirit of Christmas
|
| 34 |
+
|
| 35 |
+
Christmas itself is a recurring and significant event in the story, symbolizing themes of goodwill, generosity, and reflection. The holiday serves as the backdrop for many of the story's pivotal moments and emotional responses. It is during Christmas that Scrooge's transformation occurs, highlighting the holiday's themes of kindness and community. The discussions and celebrations around Christmas emphasize its importance in fostering a sense of togetherness and compassion [Data: Scrooge's Family and Christmas (20); Relationships (96, 149, 153, 154, 156)].
|
| 36 |
+
|
| 37 |
+
## Social Commentary
|
| 38 |
+
|
| 39 |
+
The story also provides a social commentary on the issues of poverty and social inequality. Characters like Ignorance and Want symbolize the social issues of the time, revealed by the Ghost of Christmas Present. Scrooge's initial attitude towards the poor and his interactions with the gentlemen seeking charity highlight the harsh realities of social neglect and the need for compassion and support for the less fortunate [Data: Entities (40); Sources (4)].
|
| 40 |
+
|
| 41 |
+
## The Role of Memory and Reflection
|
| 42 |
+
|
| 43 |
+
Memory and reflection are significant themes that contribute to Scrooge's transformation. The story often delves into Scrooge's past, evoking memories that help him reflect on his life choices and their impact on his current state. These reflections are crucial in prompting Scrooge to change his ways and embrace a more generous and compassionate outlook [Data: Scrooge's Family and Christmas (20); Relationships (96, 156)].
|
| 44 |
+
|
| 45 |
+
## Conclusion
|
| 46 |
+
|
| 47 |
+
In summary, the top themes in the story include family and relationships, transformation and redemption, the spirit of Christmas, social commentary, and the role of memory and reflection. Each of these themes is intricately connected to the characters and events, contributing to the story's rich narrative and moral lessons.
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### `nano-graphrag`
|
| 51 |
+
|
| 52 |
+
"What are the top themes in this story? "
|
| 53 |
+
|
| 54 |
+
```markdown
|
| 55 |
+
Sure, let’s explore the top themes in "A Christmas Carol" based on the data provided.
|
| 56 |
+
|
| 57 |
+
### Transformation and Redemption
|
| 58 |
+
|
| 59 |
+
The central theme of "A Christmas Carol" revolves around the profound transformation and redemption of Ebenezer Scrooge. Initially depicted as a miserly and solitary man who disdains Christmas and human connection, Scrooge undergoes a radical change after a series of supernatural visits. These visitations force him to reflect on his past, recognize the suffering of others in the present, and confront the grim future that awaits him if he remains unchanged [Data: Reports (0); Entities (1); Relationships (0, 3, 5, 14, 45, +more)].
|
| 60 |
+
|
| 61 |
+
### Influence of Supernatural Entities
|
| 62 |
+
|
| 63 |
+
Supernatural elements are crucial in facilitating Scrooge's transformation. The Ghosts of Christmas Past, Present, and Yet to Come each play a significant role by exposing Scrooge to various scenes, prompting deep introspection and emotional responses. Jacob Marley's ghost, in particular, acts as the initial catalyst, warning Scrooge of the consequences of his miserly ways and setting the stage for the more impactful supernatural encounters that follow [Data: Reports (0); Entities (17, 18, 46); Relationships (29, 30, 46, 55, 114, +more)].
|
| 64 |
+
|
| 65 |
+
### Social Critique
|
| 66 |
+
|
| 67 |
+
Dickens uses the story to critique the social conditions of his time, emphasizing themes of poverty and social neglect. Characters like Bob Cratchit and his family highlight the struggles of the working class, while the conceptual children, Ignorance and Want, presented by the Ghost of Christmas Present, illustrate the dire consequences of societal indifference. This critique encourages readers to reflect on their own social responsibilities and the broader societal need for compassion and generosity [Data: Reports (0); Entities (57, 107, 141); Relationships (45, 102, 159, +more)].
|
| 68 |
+
|
| 69 |
+
### Family and Community
|
| 70 |
+
|
| 71 |
+
Another prominent theme is the importance of family and community. Scrooge's interactions with Fred, his cheerful and persistent nephew, highlight the warmth and joy that family connections can bring. Fred's continued efforts to include Scrooge in family celebrations play a pivotal role in rekindling their relationship. Likewise, the Cratchit family, despite their poverty, exemplifies love, togetherness, and the true spirit of Christmas, all of which contribute to Scrooge's awakening and ultimate redemption [Data: Reports (0); Entities (14, 44); Relationships (28, 39, 40, 86, 153, +more)].
|
| 72 |
+
|
| 73 |
+
### Generosity and Charity
|
| 74 |
+
|
| 75 |
+
The theme of generosity and charity is woven throughout the narrative. Scrooge's eventual acts of kindness, such as his improved treatment of Bob Cratchit and his charitable gestures towards the end of the story, symbolize the transformative power of giving. This newly embraced generosity contrasts sharply with his former frugality and serves as a testament to the positive changes that have taken place within him [Data: Reports (0); Entities (6, 43); Relationships (7, 8, 84, 127, +more)].
|
| 76 |
+
|
| 77 |
+
### Conclusion
|
| 78 |
+
|
| 79 |
+
In summary, "A Christmas Carol" is rich with themes that explore personal transformation, social responsibility, family, and the spirit of generosity. These themes are intricately linked through the narrative structure and character interactions, making the story a timeless reflection on the human condition and the potential for redemption and change [Data: Reports (0); Entities (1, 17, 18); Relationships (0, 14, 31, 45, 48, +more)].
|
| 80 |
+
|
| 81 |
+
Feel free to ask if you need more detailed information on any specific aspect of these themes!
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
## Global Search Results
|
| 87 |
+
|
| 88 |
+
### GraphRAG
|
| 89 |
+
|
| 90 |
+
"What are the top themes in this story? "
|
| 91 |
+
|
| 92 |
+
```markdown
|
| 93 |
+
#### Transformation
|
| 94 |
+
The central theme of the story is the transformation of Ebenezer Scrooge. Initially depicted as a miserly and unsympathetic individual, Scrooge undergoes a profound change to become a generous and caring person. This transformation is driven by his interactions with various supernatural entities, including Jacob Marley and the Ghosts of Christmas Past, Present, and Yet to Come [Data: Reports (18, 17, 27, 21, 6, 26)].
|
| 95 |
+
|
| 96 |
+
#### Familial Love and Support
|
| 97 |
+
The theme of familial love and support is prominently depicted through the Cratchit Family. Despite their financial struggles, the Cratchit family remains loving and supportive of each other. Key relationships include Bob Cratchit's deep affection for his son Tiny Tim and the family's collective grief over Tiny Tim's potential death [Data: Reports (20, 19)].
|
| 98 |
+
|
| 99 |
+
#### Generosity and Compassion
|
| 100 |
+
Generosity and compassion are illustrated through characters like Fezziwig and the transformed Scrooge. Fezziwig's joyful Christmas parties and Scrooge's eventual acts of kindness towards the Cratchit family and others highlight the positive impact of generosity on the community [Data: Reports (23, 19, 26)].
|
| 101 |
+
|
| 102 |
+
#### Social Responsibility
|
| 103 |
+
The theme of social responsibility is underscored by the allegorical figures of Ignorance and Want, as well as the broader societal issues highlighted by the Ghost of Christmas Present. These elements emphasize the importance of addressing poverty and the consequences of neglecting the less fortunate [Data: Reports (17, 21)].
|
| 104 |
+
|
| 105 |
+
#### The Christmas Spirit
|
| 106 |
+
The Christmas spirit, encompassing joy, celebration, and togetherness, is a recurring motif throughout the story. Various scenes, such as the Cratchit family's Christmas dinner, Fred's Christmas party, and the celebrations shown by the Ghost of Christmas Present, reinforce the importance of embracing the festive spirit and its associated values [Data: Reports (20, 18, 23)].
|
| 107 |
+
|
| 108 |
+
#### Impact of Past Relationships
|
| 109 |
+
Scrooge's past relationships, particularly with 'The Girl' (later 'The Matron') and Belle, highlight themes of lost opportunities and the contrast between his lonely existence and the warmth of family life [Data: Reports (2)].
|
| 110 |
+
|
| 111 |
+
#### Role of Supernatural Elements
|
| 112 |
+
Supernatural elements, including Marley's Ghost and the three spirits, play a crucial role in Scrooge's transformation. These entities guide Scrooge through various significant scenes, influencing his reflections and realizations [Data: Reports (26)].
|
| 113 |
+
|
| 114 |
+
#### Family and Love
|
| 115 |
+
The relationships within the community, such as those between Mr. and Mrs. Fezziwig and Belle's family, underscore the importance of family and love. These relationships serve as a reflection on what Scrooge has missed out on due to his choices [Data: Reports (2, 25)].
|
| 116 |
+
|
| 117 |
+
These themes collectively highlight the story's emphasis on personal growth, the importance of community and family, and the enduring values associated with the Christmas season.
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
### `nano-graghrag`
|
| 123 |
+
|
| 124 |
+
"What are the top themes in this story? "
|
| 125 |
+
|
| 126 |
+
```markdown
|
| 127 |
+
### Key Themes in the Story
|
| 128 |
+
|
| 129 |
+
#### 1. Transformation and Redemption
|
| 130 |
+
The central theme of the story revolves around Ebenezer Scrooge's profound transformation from a miserly, cold-hearted man to a generous and kind individual. This journey of change emphasizes the possibility of redemption and the capacity for personal change. The transformation is predominantly driven by Scrooge's interactions with various supernatural entities such as the Ghosts of Christmas Past, Present, and Yet to Come [Data: Reports (0, 1, 2, 3, 4, 5, 6, 12, 23, 14, +more)].
|
| 131 |
+
|
| 132 |
+
#### 2. The Role of Supernatural Entities
|
| 133 |
+
Supernatural spirits play critical roles in guiding Scrooge toward his transformation. The Ghosts of Christmas Past, Present, and Yet to Come evoke reflection, regret, and fear in Scrooge, highlighting the impact of spiritual guidance on personal redemption. These encounters force him to confront and reflect on his past actions and their consequences [Data: Reports (1, 2, 3, 12, 0)].
|
| 134 |
+
|
| 135 |
+
#### 3. Generosity and Community Support
|
| 136 |
+
Another significant theme is the spirit of generosity and community support. Initially depicted as a miser, Scrooge's eventual acts of generosity and the festive spirit among various characters underscore the importance of kindness and community. The Cratchit family's modest yet joyous Christmas celebration poignantly highlights this theme [Data: Reports (26, 6, 2, 14)].
|
| 137 |
+
|
| 138 |
+
#### 4. Family and Emotional Bonds
|
| 139 |
+
The story places significant emphasis on family and emotional connections. This theme is illustrated through the relationships between Scrooge and his family, such as the close sibling bond with Fan and the dynamics within the Cratchit family. These relationships and their development underscore the importance of familial love and support [Data: Reports (0, 2, 4, 17, 5, 15)].
|
| 140 |
+
|
| 141 |
+
#### 5. Social Justice and Economic Disparity
|
| 142 |
+
The narrative also addresses themes of social justice and economic disparity. Scrooge's initial indifference to others' struggles and the contrasting lifestyles of characters like Bob Cratchit highlight the societal inequalities of the time. This critique is further exemplified through the characters’ interactions and the broader social critique in the story [Data: Reports (0, 1, 24, 9, 25, +more)].
|
| 143 |
+
|
| 144 |
+
#### 6. Role of Memory and the Past
|
| 145 |
+
Closely tied to Scrooge’s transformation is the theme of memory and the past. The Ghosts of Christmas Past and Marley’s Ghost prompt Scrooge to reflect on his own life, evoking memories that significantly impact his journey toward becoming a better man. These reflections play a crucial role in his redemption [Data: Reports (3, 8, 13, 17, 6, +more)].
|
| 146 |
+
|
| 147 |
+
### Conclusion
|
| 148 |
+
These themes collectively create a rich and multifaceted narrative that conveys powerful messages about personal growth, the impact of community, the value of generosity, and the critique of social inequities. The journey of Ebenezer Scrooge serves as a timeless reminder of the potential for change and the importance of compassion and empathy in society.
|
| 149 |
+
```
|
| 150 |
+
|
nano-graphrag/docs/benchmark-zh.md
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- We use [三国演义](https://github.com/tennessine/corpus/blob/master/%E4%B8%89%E5%9B%BD%E6%BC%94%E4%B9%89.txt) by 罗贯中 as the benchmark corpus.
|
| 2 |
+
- We use `61b5eea34783c58074b3c53f1689ad8a5ba6b6ee` commit of [Official GraphRAG implementation](https://github.com/microsoft/graphrag/tree/main)
|
| 3 |
+
- Both GraphRAG and `nano-graphrag` use OpenAI Embedding and `gpt-4o`.
|
| 4 |
+
- Not Cache for both.
|
| 5 |
+
- On the same device and network connection.
|
| 6 |
+
- GrapgRAG Max Async API request: 25
|
| 7 |
+
- `nano-graphrag` Max Async API request: 16
|
| 8 |
+
|
| 9 |
+
## Index Benchmark
|
| 10 |
+
|
| 11 |
+
Both are suck, more than 10 minutes.
|
| 12 |
+
|
| 13 |
+
## Local Search Results
|
| 14 |
+
|
| 15 |
+
### GraphRAG
|
| 16 |
+
|
| 17 |
+
"大乔和曹操的关系"
|
| 18 |
+
|
| 19 |
+
```markdown
|
| 20 |
+
大乔与曹操之间的关系在历史和文学中被描绘为复杂且多维的,尽管直接的互动可能并不频繁。大乔,作为东吴孙策的妻子,以及其美貌和与三国时期重要人物的关联而闻名。曹操,作为三国时期的关键政治和军事领导者,其与大乔的关系主要通过他与东吴的关系以及对大乔及其家族的兴趣来体现。
|
| 21 |
+
|
| 22 |
+
首先,大乔的美貌在三国时期广为人知,甚至引起了曹操的注意。据记载,曹操曾有意俘获大乔以供己用,这一点体现了曹操对大乔的个人兴趣,尽管这种兴趣可能更多地是基于她的美貌和象征意义,而非两人之间有实质性的个人关系 [Data: Entities (266)]。
|
| 23 |
+
|
| 24 |
+
其次,大乔作为孙策的妻子,自然成为了孙家与东吴势力中的一员。孙策和孙权是曹操在政治和军事上的对手,这使得大乔与曹操之间存在着间接的敌对关系。曹操对东吴的军事行动,包括对孙权的攻击,间接地影响到了大乔及其家族的安全和利益 [Data: Entities (266)]。
|
| 25 |
+
|
| 26 |
+
最后,尽管直接的文献记录可能不多,大乔的地位和她所处的时代背景使她与曹操之间的关系充满了历史和文化上的意义。她的故事和曹操的野心相交织,展现了三国时期女性角色与政治军事活动之间的复杂联系。
|
| 27 |
+
|
| 28 |
+
综上所述,大乔与曹操之间的关系主要是通过她的家庭背景、曹操的政治军事活动,以及文化记忆中对她美貌的追捧来定义的。这种关系更多地体现在历史和文学的叙述中,而不是基于两人之间有直接的个人互动。
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
### `nano-graphrag`
|
| 32 |
+
|
| 33 |
+
"大乔和曹操的关系"
|
| 34 |
+
|
| 35 |
+
```markdown
|
| 36 |
+
# 大乔和曹操的关系
|
| 37 |
+
## 大乔的背景
|
| 38 |
+
大乔(Da Qiao)是江东乔公的长女,以其极美的容貌闻名于世。她最为著名的身份是孙策(Sun Ce)的妻子,这一身份在三国时期的政治和军事格局中也占有重要地位。周瑜特别提到她,显示她在周瑜与曹操局势中的重要性 [Data: Entities (1)]。
|
| 39 |
+
## 曹操的欲望
|
| 40 |
+
曹操(Cao Cao)对大乔的渴望主要源自其好色本性。他希望得到大乔和她的妹妹小乔,以满足个人欲望,并且认为获得她们可以帮助他提高政治地位和享乐。曹操曾发誓希望能得到江东二乔(大乔和小乔),并将她们置于铜雀台上,以乐晚年 [Data: Relationships (6, 7); Entities (1, 2)]。
|
| 41 |
+
## 谋略与政治
|
| 42 |
+
在三国时期的权谋斗争中,诸葛亮曾提议使用美人计,将大乔和小乔送给曹操,以此来使曹操班师回朝,减少对东吴的威胁。这一提议显示出大乔在权谋中的重要角色以及她可能对大局产生的影响 [Data: Sources (0); Relationships (12)]。
|
| 43 |
+
## 曹操的策略和行动
|
| 44 |
+
尽管曹操对大乔有很大的兴趣,但实际行动中的证据有限。周瑜得知曹操对大乔的意图后,表达了强烈的反对,并与诸葛亮商议对策,以确保她和小乔的安全 [Data: Sources (2)]。周瑜明确表示大乔是孙策的正妻,而小乔是他的妻子,拒绝将她们送给曹操。
|
| 45 |
+
## 总结
|
| 46 |
+
总体而言,曹操和大乔没有直接的互动记录,但曹操对大乔的欲望及其可能带来的政治影响却是不容忽视的。大乔作为江东乔公的长女,其美貌和身份使她成为多方势力争夺的对象,但她与曹操的关系主要体现在曹操的欲望和诸葛亮、周瑜等人的权谋计划中。
|
| 47 |
+
曹操对大乔的欲望和权谋体现了三国时期复杂的政治动态,同时也揭示了美人计在当时战略中的重要性 [Data: Relationships (6, 12, 13); Entities (1); Sources (0, 2)].
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
## Global Search Results
|
| 53 |
+
|
| 54 |
+
### GraphRAG
|
| 55 |
+
|
| 56 |
+
"大乔与曹操的关系"
|
| 57 |
+
|
| 58 |
+
```markdown
|
| 59 |
+
### 大乔与曹操的关系
|
| 60 |
+
|
| 61 |
+
大乔(Da Qiao)是乔公的女儿,以其非凡的美貌而闻名。这一特质使她在社区中成为一个重要人物,尤其是在曹操(Cao Cao)对她的渴望背景下 [Data: Reports (41)]。
|
| 62 |
+
|
| 63 |
+
#### 曹操的计划
|
| 64 |
+
|
| 65 |
+
曹操计划在统一江南后,将大乔安置在铜雀台(Tongque Tai)。铜雀台是曹操在叶县张河上建造的一座宏伟平台,象征着好运和庆祝。曹操打算将包括大乔在内的美丽女子安置在铜雀台中 [Data: Reports (269)]。这一计划不仅反映了曹操的野心和愿景,也��示了他对大乔的特别关注。
|
| 66 |
+
|
| 67 |
+
#### 乔公与曹操的关系
|
| 68 |
+
|
| 69 |
+
乔公是大乔的父亲,同时也是曹操的朋友。这层关系不仅突显了乔公的影响力,还暗示了由于曹操对大乔的渴望,可能会产生潜在的杠杆作用或冲突 [Data: Reports (41)]。
|
| 70 |
+
|
| 71 |
+
### 结论
|
| 72 |
+
|
| 73 |
+
综上所述,大乔与曹操的关系主要体现在曹操对她的渴望以及他计划将她安置在铜雀台的意图上。此外,乔公与曹操的友谊也为这一关系增添了复杂性,可能会在未来引发更多的动态变化 [Data: Reports (41, 269)]。
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
### `nano-graph`
|
| 79 |
+
|
| 80 |
+
"大乔与曹操的关系"
|
| 81 |
+
|
| 82 |
+
```markdown
|
| 83 |
+
## 曹操与大乔的关系
|
| 84 |
+
|
| 85 |
+
曹操对大乔展现了显著的政治和战略兴趣,这部分是由于大乔的美貌以及她在战略上的重要性。在那个时期,美丽并非单纯的外貌特征,还往往与政治资源和战略权力相关联。曹操认为拥有大乔会对他的政治野心有所帮助,并为其权势扩大增添筹码 [Data: Reports (6)]。
|
| 86 |
+
|
| 87 |
+
大乔作为孙策的妻子,再加上她的地位和美貌,使得她在权力动态中成为一个关键的目标。曹操对大乔的重视,不仅因为她个人的吸引力,更因为她所代表的权力象征。如果能将大乔纳入自己的势力范围,曹操将可能通过这种方式削弱敌对势力,同时增强自己的影响力 [Data: Reports (6)]。
|
| 88 |
+
|
| 89 |
+
总之,曹操和大乔的关系不仅仅停留在个人层面,更深层次的是政治策略和权力博弈的体现。曹操对她的兴趣,反映了他在权力斗争中的精明与野心。
|
| 90 |
+
```
|
| 91 |
+
|
nano-graphrag/docs/use_neo4j_for_graphrag.md
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
1. Install [Neo4j](https://neo4j.com/docs/operations-manual/current/installation/) (please use 5.x version)
|
| 2 |
+
2. Install Neo4j GDS (graph data science) [plugin](https://neo4j.com/docs/graph-data-science/current/installation/neo4j-server/)
|
| 3 |
+
3. Start neo4j server
|
| 4 |
+
4. Get the `NEO4J_URL`, `NEO4J_USER` and `NEO4J_PASSWORD`
|
| 5 |
+
- By default, `NEO4J_URL` is `neo4j://localhost:7687` , `NEO4J_USER` is `neo4j` and `NEO4J_PASSWORD` is `neo4j`
|
| 6 |
+
|
| 7 |
+
Pass your neo4j instance to `GraphRAG`:
|
| 8 |
+
|
| 9 |
+
```python
|
| 10 |
+
from nano_graphrag import GraphRAG
|
| 11 |
+
from nano_graphrag._storage import Neo4jStorage
|
| 12 |
+
|
| 13 |
+
neo4j_config = {
|
| 14 |
+
"neo4j_url": os.environ.get("NEO4J_URL", "neo4j://localhost:7687"),
|
| 15 |
+
"neo4j_auth": (
|
| 16 |
+
os.environ.get("NEO4J_USER", "neo4j"),
|
| 17 |
+
os.environ.get("NEO4J_PASSWORD", "neo4j"),
|
| 18 |
+
)
|
| 19 |
+
}
|
| 20 |
+
GraphRAG(
|
| 21 |
+
graph_storage_cls=Neo4jStorage,
|
| 22 |
+
addon_params=neo4j_config,
|
| 23 |
+
)
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
nano-graphrag/examples/benchmarks/dspy_entity.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dspy
|
| 2 |
+
import os
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
from openai import AsyncOpenAI
|
| 5 |
+
import logging
|
| 6 |
+
import asyncio
|
| 7 |
+
import time
|
| 8 |
+
import shutil
|
| 9 |
+
from nano_graphrag.entity_extraction.extract import extract_entities_dspy
|
| 10 |
+
from nano_graphrag.base import BaseKVStorage
|
| 11 |
+
from nano_graphrag._storage import NetworkXStorage
|
| 12 |
+
from nano_graphrag._utils import compute_mdhash_id, compute_args_hash
|
| 13 |
+
from nano_graphrag._op import extract_entities
|
| 14 |
+
|
| 15 |
+
WORKING_DIR = "./nano_graphrag_cache_dspy_entity"
|
| 16 |
+
|
| 17 |
+
load_dotenv()
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger("nano-graphrag")
|
| 20 |
+
logger.setLevel(logging.DEBUG)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
async def deepseepk_model_if_cache(
|
| 24 |
+
prompt: str, model: str = "deepseek-chat", system_prompt : str = None, history_messages: list = [], **kwargs
|
| 25 |
+
) -> str:
|
| 26 |
+
openai_async_client = AsyncOpenAI(
|
| 27 |
+
api_key=os.environ.get("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com"
|
| 28 |
+
)
|
| 29 |
+
messages = []
|
| 30 |
+
if system_prompt:
|
| 31 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 32 |
+
|
| 33 |
+
# Get the cached response if having-------------------
|
| 34 |
+
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
| 35 |
+
messages.extend(history_messages)
|
| 36 |
+
messages.append({"role": "user", "content": prompt})
|
| 37 |
+
if hashing_kv is not None:
|
| 38 |
+
args_hash = compute_args_hash(model, messages)
|
| 39 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
| 40 |
+
if if_cache_return is not None:
|
| 41 |
+
return if_cache_return["return"]
|
| 42 |
+
# -----------------------------------------------------
|
| 43 |
+
|
| 44 |
+
response = await openai_async_client.chat.completions.create(
|
| 45 |
+
model=model, messages=messages, **kwargs
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# Cache the response if having-------------------
|
| 49 |
+
if hashing_kv is not None:
|
| 50 |
+
await hashing_kv.upsert(
|
| 51 |
+
{args_hash: {"return": response.choices[0].message.content, "model": model}}
|
| 52 |
+
)
|
| 53 |
+
# -----------------------------------------------------
|
| 54 |
+
return response.choices[0].message.content
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
async def benchmark_entity_extraction(text: str, system_prompt: str, use_dspy: bool = False):
|
| 58 |
+
working_dir = os.path.join(WORKING_DIR, f"use_dspy={use_dspy}")
|
| 59 |
+
if os.path.exists(working_dir):
|
| 60 |
+
shutil.rmtree(working_dir)
|
| 61 |
+
|
| 62 |
+
start_time = time.time()
|
| 63 |
+
graph_storage = NetworkXStorage(namespace="test", global_config={
|
| 64 |
+
"working_dir": working_dir,
|
| 65 |
+
"entity_summary_to_max_tokens": 500,
|
| 66 |
+
"cheap_model_func": lambda *args, **kwargs: deepseepk_model_if_cache(*args, system_prompt=system_prompt, **kwargs),
|
| 67 |
+
"best_model_func": lambda *args, **kwargs: deepseepk_model_if_cache(*args, system_prompt=system_prompt, **kwargs),
|
| 68 |
+
"cheap_model_max_token_size": 4096,
|
| 69 |
+
"best_model_max_token_size": 4096,
|
| 70 |
+
"tiktoken_model_name": "gpt-4o",
|
| 71 |
+
"hashing_kv": BaseKVStorage(namespace="test", global_config={"working_dir": working_dir}),
|
| 72 |
+
"entity_extract_max_gleaning": 1,
|
| 73 |
+
"entity_extract_max_tokens": 4096,
|
| 74 |
+
"entity_extract_max_entities": 100,
|
| 75 |
+
"entity_extract_max_relationships": 100,
|
| 76 |
+
})
|
| 77 |
+
chunks = {compute_mdhash_id(text, prefix="chunk-"): {"content": text}}
|
| 78 |
+
|
| 79 |
+
if use_dspy:
|
| 80 |
+
graph_storage = await extract_entities_dspy(chunks, graph_storage, None, graph_storage.global_config)
|
| 81 |
+
else:
|
| 82 |
+
graph_storage = await extract_entities(chunks, graph_storage, None, graph_storage.global_config)
|
| 83 |
+
|
| 84 |
+
end_time = time.time()
|
| 85 |
+
execution_time = end_time - start_time
|
| 86 |
+
|
| 87 |
+
return graph_storage, execution_time
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def print_extraction_results(graph_storage: NetworkXStorage):
|
| 91 |
+
print("\nEntities:")
|
| 92 |
+
entities = []
|
| 93 |
+
for node, data in graph_storage._graph.nodes(data=True):
|
| 94 |
+
entity_type = data.get('entity_type', 'Unknown')
|
| 95 |
+
description = data.get('description', 'No description')
|
| 96 |
+
entities.append(f"- {node} ({entity_type}):\n {description}")
|
| 97 |
+
print("\n".join(entities))
|
| 98 |
+
|
| 99 |
+
print("\nRelationships:")
|
| 100 |
+
relationships = []
|
| 101 |
+
for source, target, data in graph_storage._graph.edges(data=True):
|
| 102 |
+
description = data.get('description', 'No description')
|
| 103 |
+
relationships.append(f"- {source} -> {target}:\n {description}")
|
| 104 |
+
print("\n".join(relationships))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
async def run_benchmark(text: str):
|
| 108 |
+
print("\nRunning benchmark with DSPy-AI:")
|
| 109 |
+
system_prompt = """
|
| 110 |
+
You are an expert system specialized in entity and relationship extraction from complex texts.
|
| 111 |
+
Your task is to thoroughly analyze the given text and extract all relevant entities and their relationships with utmost precision and completeness.
|
| 112 |
+
"""
|
| 113 |
+
system_prompt_dspy = f"{system_prompt} Time: {time.time()}."
|
| 114 |
+
lm = dspy.LM(
|
| 115 |
+
model="deepseek/deepseek-chat",
|
| 116 |
+
model_type="chat",
|
| 117 |
+
api_provider="openai",
|
| 118 |
+
api_key=os.environ["DEEPSEEK_API_KEY"],
|
| 119 |
+
base_url=os.environ["DEEPSEEK_BASE_URL"],
|
| 120 |
+
system_prompt=system_prompt,
|
| 121 |
+
temperature=1.0,
|
| 122 |
+
max_tokens=8192
|
| 123 |
+
)
|
| 124 |
+
dspy.settings.configure(lm=lm, experimental=True)
|
| 125 |
+
graph_storage_with_dspy, time_with_dspy = await benchmark_entity_extraction(text, system_prompt_dspy, use_dspy=True)
|
| 126 |
+
print(f"Execution time with DSPy-AI: {time_with_dspy:.2f} seconds")
|
| 127 |
+
print_extraction_results(graph_storage_with_dspy)
|
| 128 |
+
|
| 129 |
+
print("Running benchmark without DSPy-AI:")
|
| 130 |
+
system_prompt_no_dspy = f"{system_prompt} Time: {time.time()}."
|
| 131 |
+
graph_storage_without_dspy, time_without_dspy = await benchmark_entity_extraction(text, system_prompt_no_dspy, use_dspy=False)
|
| 132 |
+
print(f"Execution time without DSPy-AI: {time_without_dspy:.2f} seconds")
|
| 133 |
+
print_extraction_results(graph_storage_without_dspy)
|
| 134 |
+
|
| 135 |
+
print("\nComparison:")
|
| 136 |
+
print(f"Time difference: {abs(time_with_dspy - time_without_dspy):.2f} seconds")
|
| 137 |
+
print(f"DSPy-AI is {'faster' if time_with_dspy < time_without_dspy else 'slower'}")
|
| 138 |
+
|
| 139 |
+
entities_without_dspy = len(graph_storage_without_dspy._graph.nodes())
|
| 140 |
+
entities_with_dspy = len(graph_storage_with_dspy._graph.nodes())
|
| 141 |
+
relationships_without_dspy = len(graph_storage_without_dspy._graph.edges())
|
| 142 |
+
relationships_with_dspy = len(graph_storage_with_dspy._graph.edges())
|
| 143 |
+
|
| 144 |
+
print(f"Entities extracted: {entities_without_dspy} (without DSPy-AI) vs {entities_with_dspy} (with DSPy-AI)")
|
| 145 |
+
print(f"Relationships extracted: {relationships_without_dspy} (without DSPy-AI) vs {relationships_with_dspy} (with DSPy-AI)")
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
if __name__ == "__main__":
|
| 149 |
+
with open("./tests/zhuyuanzhang.txt", encoding="utf-8-sig") as f:
|
| 150 |
+
text = f.read()
|
| 151 |
+
|
| 152 |
+
asyncio.run(run_benchmark(text=text))
|
nano-graphrag/examples/benchmarks/eval_naive_graphrag_on_multi_hop.ipynb
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"In this tutorial, we are going to evaluate the performance of the naive RAG and the GraphRAG algorithm on a [multi-hop RAG task](https://github.com/yixuantt/MultiHop-RAG)."
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "markdown",
|
| 12 |
+
"metadata": {},
|
| 13 |
+
"source": [
|
| 14 |
+
"## Setup\n",
|
| 15 |
+
"Make sure you install the necessary dependencies by running the following commands:"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"cell_type": "code",
|
| 20 |
+
"execution_count": null,
|
| 21 |
+
"metadata": {},
|
| 22 |
+
"outputs": [],
|
| 23 |
+
"source": [
|
| 24 |
+
"!pip install ragas nest_asyncio datasets"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "markdown",
|
| 29 |
+
"metadata": {},
|
| 30 |
+
"source": [
|
| 31 |
+
"Import the necessary libraries, and set up your openai api key if needed:"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": 21,
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"import os\n",
|
| 41 |
+
"# os.environ[\"OPENAI_API_KEY\"] = \"YOUR_API_KEY\"\n",
|
| 42 |
+
"import json\n",
|
| 43 |
+
"import sys\n",
|
| 44 |
+
"sys.path.append(\"../..\")\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"import nest_asyncio\n",
|
| 47 |
+
"nest_asyncio.apply()\n",
|
| 48 |
+
"import logging\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"logging.basicConfig(level=logging.WARNING)\n",
|
| 51 |
+
"logging.getLogger(\"nano-graphrag\").setLevel(logging.INFO)\n",
|
| 52 |
+
"from nano_graphrag import GraphRAG, QueryParam\n",
|
| 53 |
+
"from datasets import Dataset \n",
|
| 54 |
+
"from ragas import evaluate\n",
|
| 55 |
+
"from ragas.metrics import (\n",
|
| 56 |
+
" answer_correctness,\n",
|
| 57 |
+
" answer_similarity,\n",
|
| 58 |
+
")"
|
| 59 |
+
]
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"cell_type": "markdown",
|
| 63 |
+
"metadata": {},
|
| 64 |
+
"source": [
|
| 65 |
+
"Download the dataset from [Github Repo](https://github.com/yixuantt/MultiHop-RAG/tree/main/dataset). \n",
|
| 66 |
+
"If should contain two files:\n",
|
| 67 |
+
"- `MultiHopRAG.json`\n",
|
| 68 |
+
"- `corpus.json`\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"After downloading the dataset, replace the below paths to the paths on your machine."
|
| 71 |
+
]
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"cell_type": "code",
|
| 75 |
+
"execution_count": 3,
|
| 76 |
+
"metadata": {},
|
| 77 |
+
"outputs": [],
|
| 78 |
+
"source": [
|
| 79 |
+
"\n",
|
| 80 |
+
"multi_hop_rag_file = \"./fixtures/MultiHopRAG.json\"\n",
|
| 81 |
+
"multi_hop_corpus_file = \"./fixtures/corpus.json\""
|
| 82 |
+
]
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"cell_type": "markdown",
|
| 86 |
+
"metadata": {},
|
| 87 |
+
"source": [
|
| 88 |
+
"## Preprocess"
|
| 89 |
+
]
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"cell_type": "code",
|
| 93 |
+
"execution_count": 4,
|
| 94 |
+
"metadata": {},
|
| 95 |
+
"outputs": [],
|
| 96 |
+
"source": [
|
| 97 |
+
"\n",
|
| 98 |
+
"with open(multi_hop_rag_file) as f:\n",
|
| 99 |
+
" multi_hop_rag_dataset = json.load(f)\n",
|
| 100 |
+
"with open(multi_hop_corpus_file) as f:\n",
|
| 101 |
+
" multi_hop_corpus = json.load(f)\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"corups_url_refernces = {}\n",
|
| 104 |
+
"for cor in multi_hop_corpus:\n",
|
| 105 |
+
" corups_url_refernces[cor['url']] = cor"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"cell_type": "markdown",
|
| 110 |
+
"metadata": {},
|
| 111 |
+
"source": [
|
| 112 |
+
"We only use the top-100 queries for evaluation."
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"cell_type": "code",
|
| 117 |
+
"execution_count": 5,
|
| 118 |
+
"metadata": {},
|
| 119 |
+
"outputs": [
|
| 120 |
+
{
|
| 121 |
+
"name": "stdout",
|
| 122 |
+
"output_type": "stream",
|
| 123 |
+
"text": [
|
| 124 |
+
"Queries have types: {'inference_query', 'comparison_query', 'null_query', 'temporal_query'}\n",
|
| 125 |
+
"We will need 139 articles:\n",
|
| 126 |
+
"## ASX set to drop as Wall Street’s September slump deepens\n",
|
| 127 |
+
"Author: Stan Choe, The Sydney Morning Herald\n",
|
| 128 |
+
"Category: business\n",
|
| 129 |
+
"Publised: 2023-09-26T19:11:30+00:00\n",
|
| 130 |
+
"ETF provider Betashares, which manages $ ...\n"
|
| 131 |
+
]
|
| 132 |
+
}
|
| 133 |
+
],
|
| 134 |
+
"source": [
|
| 135 |
+
"multi_hop_rag_dataset = multi_hop_rag_dataset[:100]\n",
|
| 136 |
+
"print(\"Queries have types:\", set([q['question_type'] for q in multi_hop_rag_dataset]))\n",
|
| 137 |
+
"total_urls = set()\n",
|
| 138 |
+
"for q in multi_hop_rag_dataset:\n",
|
| 139 |
+
" total_urls.update([up['url'] for up in q['evidence_list']])\n",
|
| 140 |
+
"corups_url_refernces = {k:v for k, v in corups_url_refernces.items() if k in total_urls}\n",
|
| 141 |
+
"\n",
|
| 142 |
+
"total_corups = [f\"## {cor['title']}\\nAuthor: {cor['author']}, {cor['source']}\\nCategory: {cor['category']}\\nPublised: {cor['published_at']}\\n{cor['body']}\" for cor in corups_url_refernces.values()]\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"print(f\"We will need {len(total_corups)} articles:\")\n",
|
| 145 |
+
"print(total_corups[0][:200], \"...\")"
|
| 146 |
+
]
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"cell_type": "markdown",
|
| 150 |
+
"metadata": {},
|
| 151 |
+
"source": [
|
| 152 |
+
"Add index for the `total_corups` using naive RAG and GraphRAG"
|
| 153 |
+
]
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"cell_type": "code",
|
| 157 |
+
"execution_count": 6,
|
| 158 |
+
"metadata": {},
|
| 159 |
+
"outputs": [
|
| 160 |
+
{
|
| 161 |
+
"name": "stderr",
|
| 162 |
+
"output_type": "stream",
|
| 163 |
+
"text": [
|
| 164 |
+
"INFO:nano-graphrag:Load KV full_docs with 139 data\n",
|
| 165 |
+
"INFO:nano-graphrag:Load KV text_chunks with 408 data\n",
|
| 166 |
+
"INFO:nano-graphrag:Load KV llm_response_cache with 1634 data\n",
|
| 167 |
+
"INFO:nano-graphrag:Load KV community_reports with 794 data\n",
|
| 168 |
+
"INFO:nano-graphrag:Loaded graph from nano_graphrag_cache_multi_hop_rag_test/graph_chunk_entity_relation.graphml with 6181 nodes, 5423 edges\n",
|
| 169 |
+
"WARNING:nano-graphrag:All docs are already in the storage\n",
|
| 170 |
+
"INFO:nano-graphrag:Writing graph with 6181 nodes, 5423 edges\n"
|
| 171 |
+
]
|
| 172 |
+
}
|
| 173 |
+
],
|
| 174 |
+
"source": [
|
| 175 |
+
"# First time indexing will cost many time, roughly 15~20 minutes\n",
|
| 176 |
+
"graphrag_func = GraphRAG(working_dir=\"nano_graphrag_cache_multi_hop_rag_test\", enable_naive_rag=True,\n",
|
| 177 |
+
" embedding_func_max_async=4)\n",
|
| 178 |
+
"graphrag_func.insert(total_corups)"
|
| 179 |
+
]
|
| 180 |
+
},
|
| 181 |
+
{
|
| 182 |
+
"cell_type": "markdown",
|
| 183 |
+
"metadata": {},
|
| 184 |
+
"source": [
|
| 185 |
+
"Look at the response of different RAG methods on the first query:"
|
| 186 |
+
]
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"cell_type": "code",
|
| 190 |
+
"execution_count": 24,
|
| 191 |
+
"metadata": {},
|
| 192 |
+
"outputs": [],
|
| 193 |
+
"source": [
|
| 194 |
+
"response_formate = \"Single phrase or sentence, concise and no redundant explanation needed. If you don't have the answer in context, Just response 'Insufficient information'\"\n",
|
| 195 |
+
"naive_rag_query_param = QueryParam(mode='naive', response_type=response_formate)\n",
|
| 196 |
+
"naive_rag_query_only_context_param = QueryParam(mode='naive', only_need_context=True)\n",
|
| 197 |
+
"local_graphrag_query_param = QueryParam(mode='local', response_type=response_formate)\n",
|
| 198 |
+
"local_graphrag_only_context__param = QueryParam(mode='local', only_need_context=True)"
|
| 199 |
+
]
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"cell_type": "code",
|
| 203 |
+
"execution_count": 8,
|
| 204 |
+
"metadata": {},
|
| 205 |
+
"outputs": [
|
| 206 |
+
{
|
| 207 |
+
"name": "stdout",
|
| 208 |
+
"output_type": "stream",
|
| 209 |
+
"text": [
|
| 210 |
+
"Question: Who is the individual associated with the cryptocurrency industry facing a criminal trial on fraud and conspiracy charges, as reported by both The Verge and TechCrunch, and is accused by prosecutors of committing fraud for personal gain?\n",
|
| 211 |
+
"GroundTruth Answer: Sam Bankman-Fried\n"
|
| 212 |
+
]
|
| 213 |
+
}
|
| 214 |
+
],
|
| 215 |
+
"source": [
|
| 216 |
+
"query = multi_hop_rag_dataset[0]\n",
|
| 217 |
+
"print(\"Question:\", query['query'])\n",
|
| 218 |
+
"print(\"GroundTruth Answer:\", query['answer'])"
|
| 219 |
+
]
|
| 220 |
+
},
|
| 221 |
+
{
|
| 222 |
+
"cell_type": "code",
|
| 223 |
+
"execution_count": 9,
|
| 224 |
+
"metadata": {},
|
| 225 |
+
"outputs": [
|
| 226 |
+
{
|
| 227 |
+
"name": "stderr",
|
| 228 |
+
"output_type": "stream",
|
| 229 |
+
"text": [
|
| 230 |
+
"INFO:nano-graphrag:Truncate 20 to 12 chunks\n"
|
| 231 |
+
]
|
| 232 |
+
},
|
| 233 |
+
{
|
| 234 |
+
"name": "stdout",
|
| 235 |
+
"output_type": "stream",
|
| 236 |
+
"text": [
|
| 237 |
+
"NaiveRAG Answer: Sam Bankman-Fried\n"
|
| 238 |
+
]
|
| 239 |
+
}
|
| 240 |
+
],
|
| 241 |
+
"source": [
|
| 242 |
+
"print(\"NaiveRAG Answer:\", graphrag_func.query(query['query'], param=naive_rag_query_param))"
|
| 243 |
+
]
|
| 244 |
+
},
|
| 245 |
+
{
|
| 246 |
+
"cell_type": "code",
|
| 247 |
+
"execution_count": 10,
|
| 248 |
+
"metadata": {},
|
| 249 |
+
"outputs": [
|
| 250 |
+
{
|
| 251 |
+
"name": "stderr",
|
| 252 |
+
"output_type": "stream",
|
| 253 |
+
"text": [
|
| 254 |
+
"INFO:nano-graphrag:Using 20 entites, 3 communities, 124 relations, 3 text units\n"
|
| 255 |
+
]
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"name": "stdout",
|
| 259 |
+
"output_type": "stream",
|
| 260 |
+
"text": [
|
| 261 |
+
"Local GraphRAG Answer: Sam Bankman-Fried\n"
|
| 262 |
+
]
|
| 263 |
+
}
|
| 264 |
+
],
|
| 265 |
+
"source": [
|
| 266 |
+
"print(\"Local GraphRAG Answer:\", graphrag_func.query(query['query'], param=local_graphrag_query_param))"
|
| 267 |
+
]
|
| 268 |
+
},
|
| 269 |
+
{
|
| 270 |
+
"cell_type": "markdown",
|
| 271 |
+
"metadata": {},
|
| 272 |
+
"source": [
|
| 273 |
+
"Great! Now we're ready to evaluate more detailed metrics. We will use [ragas](https://docs.ragas.io/en/stable/) to evalue the answers' quality."
|
| 274 |
+
]
|
| 275 |
+
},
|
| 276 |
+
{
|
| 277 |
+
"cell_type": "code",
|
| 278 |
+
"execution_count": 11,
|
| 279 |
+
"metadata": {},
|
| 280 |
+
"outputs": [],
|
| 281 |
+
"source": [
|
| 282 |
+
"questions = [q['query'] for q in multi_hop_rag_dataset]\n",
|
| 283 |
+
"labels = [q['answer'] for q in multi_hop_rag_dataset]"
|
| 284 |
+
]
|
| 285 |
+
},
|
| 286 |
+
{
|
| 287 |
+
"cell_type": "code",
|
| 288 |
+
"execution_count": 12,
|
| 289 |
+
"metadata": {},
|
| 290 |
+
"outputs": [
|
| 291 |
+
{
|
| 292 |
+
"name": "stderr",
|
| 293 |
+
"output_type": "stream",
|
| 294 |
+
"text": [
|
| 295 |
+
" 0%| | 0/100 [00:00<?, ?it/s]"
|
| 296 |
+
]
|
| 297 |
+
},
|
| 298 |
+
{
|
| 299 |
+
"name": "stderr",
|
| 300 |
+
"output_type": "stream",
|
| 301 |
+
"text": [
|
| 302 |
+
"100%|██████████| 100/100 [03:53<00:00, 2.33s/it]\n"
|
| 303 |
+
]
|
| 304 |
+
}
|
| 305 |
+
],
|
| 306 |
+
"source": [
|
| 307 |
+
"from tqdm import tqdm\n",
|
| 308 |
+
"logging.getLogger(\"nano-graphrag\").setLevel(logging.WARNING)\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"naive_rag_answers = [\n",
|
| 311 |
+
" graphrag_func.query(q, param=naive_rag_query_param) for q in tqdm(questions)\n",
|
| 312 |
+
"]"
|
| 313 |
+
]
|
| 314 |
+
},
|
| 315 |
+
{
|
| 316 |
+
"cell_type": "code",
|
| 317 |
+
"execution_count": 14,
|
| 318 |
+
"metadata": {},
|
| 319 |
+
"outputs": [
|
| 320 |
+
{
|
| 321 |
+
"name": "stderr",
|
| 322 |
+
"output_type": "stream",
|
| 323 |
+
"text": [
|
| 324 |
+
"100%|██████████| 100/100 [09:10<00:00, 5.50s/it]\n"
|
| 325 |
+
]
|
| 326 |
+
}
|
| 327 |
+
],
|
| 328 |
+
"source": [
|
| 329 |
+
"local_graphrag_answers = [\n",
|
| 330 |
+
" graphrag_func.query(q, param=local_graphrag_query_param) for q in tqdm(questions)\n",
|
| 331 |
+
"]"
|
| 332 |
+
]
|
| 333 |
+
},
|
| 334 |
+
{
|
| 335 |
+
"cell_type": "code",
|
| 336 |
+
"execution_count": 34,
|
| 337 |
+
"metadata": {},
|
| 338 |
+
"outputs": [
|
| 339 |
+
{
|
| 340 |
+
"name": "stderr",
|
| 341 |
+
"output_type": "stream",
|
| 342 |
+
"text": [
|
| 343 |
+
" 70%|███████ | 70/100 [04:25<01:53, 3.79s/it]8, 6.38it/s]\n",
|
| 344 |
+
"Evaluating: 100%|██████████| 200/200 [00:32<00:00, 6.19it/s]\n"
|
| 345 |
+
]
|
| 346 |
+
}
|
| 347 |
+
],
|
| 348 |
+
"source": [
|
| 349 |
+
"naive_results = evaluate(\n",
|
| 350 |
+
" Dataset.from_dict({\n",
|
| 351 |
+
" \"question\": questions,\n",
|
| 352 |
+
" \"ground_truth\": labels,\n",
|
| 353 |
+
" \"answer\": naive_rag_answers,\n",
|
| 354 |
+
" }),\n",
|
| 355 |
+
" metrics=[\n",
|
| 356 |
+
" # answer_relevancy,\n",
|
| 357 |
+
" answer_correctness,\n",
|
| 358 |
+
" answer_similarity,\n",
|
| 359 |
+
" ],\n",
|
| 360 |
+
")"
|
| 361 |
+
]
|
| 362 |
+
},
|
| 363 |
+
{
|
| 364 |
+
"cell_type": "code",
|
| 365 |
+
"execution_count": 36,
|
| 366 |
+
"metadata": {},
|
| 367 |
+
"outputs": [
|
| 368 |
+
{
|
| 369 |
+
"name": "stderr",
|
| 370 |
+
"output_type": "stream",
|
| 371 |
+
"text": [
|
| 372 |
+
"Evaluating: 100%|██████████| 200/200 [00:23<00:00, 8.59it/s]\n"
|
| 373 |
+
]
|
| 374 |
+
}
|
| 375 |
+
],
|
| 376 |
+
"source": [
|
| 377 |
+
"local_graphrag_results = evaluate(\n",
|
| 378 |
+
" Dataset.from_dict({\n",
|
| 379 |
+
" \"question\": questions,\n",
|
| 380 |
+
" \"ground_truth\": labels,\n",
|
| 381 |
+
" \"answer\": local_graphrag_answers,\n",
|
| 382 |
+
" }),\n",
|
| 383 |
+
" metrics=[\n",
|
| 384 |
+
" # answer_relevancy,\n",
|
| 385 |
+
" answer_correctness,\n",
|
| 386 |
+
" answer_similarity,\n",
|
| 387 |
+
" ],\n",
|
| 388 |
+
")"
|
| 389 |
+
]
|
| 390 |
+
},
|
| 391 |
+
{
|
| 392 |
+
"cell_type": "code",
|
| 393 |
+
"execution_count": 39,
|
| 394 |
+
"metadata": {},
|
| 395 |
+
"outputs": [
|
| 396 |
+
{
|
| 397 |
+
"name": "stdout",
|
| 398 |
+
"output_type": "stream",
|
| 399 |
+
"text": [
|
| 400 |
+
"Naive RAG results {'answer_correctness': 0.5896, 'answer_similarity': 0.8935}\n",
|
| 401 |
+
"Local GraphRAG results {'answer_correctness': 0.7380, 'answer_similarity': 0.8619}\n"
|
| 402 |
+
]
|
| 403 |
+
}
|
| 404 |
+
],
|
| 405 |
+
"source": [
|
| 406 |
+
"print(\"Naive RAG results\", naive_results)\n",
|
| 407 |
+
"print(\"Local GraphRAG results\", local_graphrag_results)"
|
| 408 |
+
]
|
| 409 |
+
}
|
| 410 |
+
],
|
| 411 |
+
"metadata": {
|
| 412 |
+
"kernelspec": {
|
| 413 |
+
"display_name": "baai",
|
| 414 |
+
"language": "python",
|
| 415 |
+
"name": "python3"
|
| 416 |
+
},
|
| 417 |
+
"language_info": {
|
| 418 |
+
"codemirror_mode": {
|
| 419 |
+
"name": "ipython",
|
| 420 |
+
"version": 3
|
| 421 |
+
},
|
| 422 |
+
"file_extension": ".py",
|
| 423 |
+
"mimetype": "text/x-python",
|
| 424 |
+
"name": "python",
|
| 425 |
+
"nbconvert_exporter": "python",
|
| 426 |
+
"pygments_lexer": "ipython3",
|
| 427 |
+
"version": "3.9.19"
|
| 428 |
+
}
|
| 429 |
+
},
|
| 430 |
+
"nbformat": 4,
|
| 431 |
+
"nbformat_minor": 2
|
| 432 |
+
}
|
nano-graphrag/examples/benchmarks/hnsw_vs_nano_vector_storage.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import time
|
| 3 |
+
import numpy as np
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from nano_graphrag import GraphRAG
|
| 6 |
+
from nano_graphrag._storage import NanoVectorDBStorage, HNSWVectorStorage
|
| 7 |
+
from nano_graphrag._utils import wrap_embedding_func_with_attrs
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
WORKING_DIR = "./nano_graphrag_cache_benchmark_hnsw_vs_nano_vector_storage"
|
| 11 |
+
DATA_LEN = 100_000
|
| 12 |
+
FAKE_DIM = 1024
|
| 13 |
+
BATCH_SIZE = 100000
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@wrap_embedding_func_with_attrs(embedding_dim=FAKE_DIM, max_token_size=8192)
|
| 17 |
+
async def sample_embedding(texts: list[str]) -> np.ndarray:
|
| 18 |
+
return np.float32(np.random.rand(len(texts), FAKE_DIM))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def generate_test_data():
|
| 22 |
+
return {str(i): {"content": f"Test content {i}"} for i in range(DATA_LEN)}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
async def benchmark_storage(storage_class, name):
|
| 26 |
+
rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=sample_embedding)
|
| 27 |
+
storage = storage_class(
|
| 28 |
+
namespace=f"benchmark_{name}",
|
| 29 |
+
global_config=rag.__dict__,
|
| 30 |
+
embedding_func=sample_embedding,
|
| 31 |
+
meta_fields={"content"},
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
test_data = generate_test_data()
|
| 35 |
+
|
| 36 |
+
print(f"Benchmarking {name}...")
|
| 37 |
+
with tqdm(total=DATA_LEN, desc=f"{name} Benchmark") as pbar:
|
| 38 |
+
start_time = time.time()
|
| 39 |
+
for i in range(0, len(test_data), BATCH_SIZE):
|
| 40 |
+
batch = {k: test_data[k] for k in list(test_data.keys())[i:i+BATCH_SIZE]}
|
| 41 |
+
await storage.upsert(batch)
|
| 42 |
+
pbar.update(min(BATCH_SIZE, DATA_LEN - i))
|
| 43 |
+
|
| 44 |
+
insert_time = time.time() - start_time
|
| 45 |
+
|
| 46 |
+
save_start_time = time.time()
|
| 47 |
+
await storage.index_done_callback()
|
| 48 |
+
save_time = time.time() - save_start_time
|
| 49 |
+
pbar.update(1)
|
| 50 |
+
|
| 51 |
+
query_vector = np.random.rand(FAKE_DIM)
|
| 52 |
+
query_times = []
|
| 53 |
+
for _ in range(100):
|
| 54 |
+
query_start = time.time()
|
| 55 |
+
await storage.query(query_vector, top_k=10)
|
| 56 |
+
query_times.append(time.time() - query_start)
|
| 57 |
+
pbar.update(1)
|
| 58 |
+
|
| 59 |
+
avg_query_time = sum(query_times) / len(query_times)
|
| 60 |
+
|
| 61 |
+
print(f"{name} - Insert: {insert_time:.2f}s, Save: {save_time:.2f}s, Avg Query: {avg_query_time:.4f}s")
|
| 62 |
+
return insert_time, save_time, avg_query_time
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
async def run_benchmarks():
|
| 66 |
+
print("Running NanoVectorDB benchmark...")
|
| 67 |
+
nano_insert_time, nano_save_time, nano_query_time = await benchmark_storage(NanoVectorDBStorage, "nano")
|
| 68 |
+
|
| 69 |
+
print("\nRunning HNSWVectorStorage benchmark...")
|
| 70 |
+
hnsw_insert_time, hnsw_save_time, hnsw_query_time = await benchmark_storage(HNSWVectorStorage, "hnsw")
|
| 71 |
+
|
| 72 |
+
print("\nBenchmark Results:")
|
| 73 |
+
print(f"NanoVectorDB - Insert: {nano_insert_time:.2f}s, Save: {nano_save_time:.2f}s, Avg Query: {nano_query_time:.4f}s")
|
| 74 |
+
print(f"HNSWVectorStorage - Insert: {hnsw_insert_time:.2f}s, Save: {hnsw_save_time:.2f}s, Avg Query: {hnsw_query_time:.4f}s")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
asyncio.run(run_benchmarks())
|
nano-graphrag/examples/benchmarks/md5_vs_xxhash.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import xxhash
|
| 3 |
+
from hashlib import md5
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def xxhash_ids(data: list[str]) -> np.ndarray:
|
| 9 |
+
return np.fromiter(
|
| 10 |
+
(xxhash.xxh32_intdigest(d.encode()) for d in data),
|
| 11 |
+
dtype=np.uint32,
|
| 12 |
+
count=len(data)
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def md5_ids(data: list[str]) -> np.ndarray:
|
| 17 |
+
return np.fromiter(
|
| 18 |
+
(int(md5(d.encode()).hexdigest(), 16) & 0xFFFFFFFF for d in data),
|
| 19 |
+
dtype=np.uint32,
|
| 20 |
+
count=len(data)
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if __name__ == "__main__":
|
| 25 |
+
num_ids = 1000000
|
| 26 |
+
num_iterations = 100
|
| 27 |
+
xxhash_times = []
|
| 28 |
+
md5_times = []
|
| 29 |
+
|
| 30 |
+
for i in tqdm(range(num_iterations)):
|
| 31 |
+
test_data = [f"{i}_{j}" for j in range(num_ids)]
|
| 32 |
+
|
| 33 |
+
start_time = time.time()
|
| 34 |
+
xxhash_result = xxhash_ids(test_data)
|
| 35 |
+
xxhash_times.append(time.time() - start_time)
|
| 36 |
+
|
| 37 |
+
start_time = time.time()
|
| 38 |
+
md5_result = md5_ids(test_data)
|
| 39 |
+
md5_times.append(time.time() - start_time)
|
| 40 |
+
|
| 41 |
+
assert len(xxhash_result) == len(md5_result) == num_ids
|
| 42 |
+
assert not np.array_equal(xxhash_result, md5_result)
|
| 43 |
+
|
| 44 |
+
avg_xxhash_time = np.mean(xxhash_times)
|
| 45 |
+
avg_md5_time = np.mean(md5_times)
|
| 46 |
+
std_xxhash_time = np.std(xxhash_times)
|
| 47 |
+
std_md5_time = np.std(md5_times)
|
| 48 |
+
|
| 49 |
+
print(f"num_ids: {num_ids} | num_iterations: {num_iterations}")
|
| 50 |
+
print(f"\nAverage xxhash time: {avg_xxhash_time:.4f} seconds")
|
| 51 |
+
print(f"Average MD5 time: {avg_md5_time:.4f} seconds")
|
| 52 |
+
print(f"xxhash is {avg_md5_time / avg_xxhash_time:.2f}x faster than MD5 on average")
|
| 53 |
+
print(f"\nxxhash time standard deviation: {std_xxhash_time:.4f} seconds")
|
| 54 |
+
print(f"MD5 time standard deviation: {std_md5_time:.4f} seconds")
|
nano-graphrag/examples/finetune_entity_relationship_dspy.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
nano-graphrag/examples/generate_entity_relationship_dspy.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
nano-graphrag/examples/graphml_visualize.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import networkx as nx
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import webbrowser
|
| 5 |
+
import http.server
|
| 6 |
+
import socketserver
|
| 7 |
+
import threading
|
| 8 |
+
|
| 9 |
+
# load GraphML file and transfer to JSON
|
| 10 |
+
def graphml_to_json(graphml_file):
|
| 11 |
+
G = nx.read_graphml(graphml_file)
|
| 12 |
+
data = nx.node_link_data(G)
|
| 13 |
+
return json.dumps(data)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# create HTML file
|
| 17 |
+
def create_html(html_path):
|
| 18 |
+
html_content = '''
|
| 19 |
+
<!DOCTYPE html>
|
| 20 |
+
<html lang="en">
|
| 21 |
+
<head>
|
| 22 |
+
<meta charset="UTF-8">
|
| 23 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 24 |
+
<title>Graph Visualization</title>
|
| 25 |
+
<script src="https://d3js.org/d3.v7.min.js"></script>
|
| 26 |
+
<style>
|
| 27 |
+
body, html {
|
| 28 |
+
margin: 0;
|
| 29 |
+
padding: 0;
|
| 30 |
+
width: 100%;
|
| 31 |
+
height: 100%;
|
| 32 |
+
overflow: hidden;
|
| 33 |
+
}
|
| 34 |
+
svg {
|
| 35 |
+
width: 100%;
|
| 36 |
+
height: 100%;
|
| 37 |
+
}
|
| 38 |
+
.links line {
|
| 39 |
+
stroke: #999;
|
| 40 |
+
stroke-opacity: 0.6;
|
| 41 |
+
}
|
| 42 |
+
.nodes circle {
|
| 43 |
+
stroke: #fff;
|
| 44 |
+
stroke-width: 1.5px;
|
| 45 |
+
}
|
| 46 |
+
.node-label {
|
| 47 |
+
font-size: 12px;
|
| 48 |
+
pointer-events: none;
|
| 49 |
+
}
|
| 50 |
+
.link-label {
|
| 51 |
+
font-size: 10px;
|
| 52 |
+
fill: #666;
|
| 53 |
+
pointer-events: none;
|
| 54 |
+
opacity: 0;
|
| 55 |
+
transition: opacity 0.3s;
|
| 56 |
+
}
|
| 57 |
+
.link:hover .link-label {
|
| 58 |
+
opacity: 1;
|
| 59 |
+
}
|
| 60 |
+
.tooltip {
|
| 61 |
+
position: absolute;
|
| 62 |
+
text-align: left;
|
| 63 |
+
padding: 10px;
|
| 64 |
+
font: 12px sans-serif;
|
| 65 |
+
background: lightsteelblue;
|
| 66 |
+
border: 0px;
|
| 67 |
+
border-radius: 8px;
|
| 68 |
+
pointer-events: none;
|
| 69 |
+
opacity: 0;
|
| 70 |
+
transition: opacity 0.3s;
|
| 71 |
+
max-width: 300px;
|
| 72 |
+
}
|
| 73 |
+
.legend {
|
| 74 |
+
position: absolute;
|
| 75 |
+
top: 10px;
|
| 76 |
+
right: 10px;
|
| 77 |
+
background-color: rgba(255, 255, 255, 0.8);
|
| 78 |
+
padding: 10px;
|
| 79 |
+
border-radius: 5px;
|
| 80 |
+
}
|
| 81 |
+
.legend-item {
|
| 82 |
+
margin: 5px 0;
|
| 83 |
+
}
|
| 84 |
+
.legend-color {
|
| 85 |
+
display: inline-block;
|
| 86 |
+
width: 20px;
|
| 87 |
+
height: 20px;
|
| 88 |
+
margin-right: 5px;
|
| 89 |
+
vertical-align: middle;
|
| 90 |
+
}
|
| 91 |
+
</style>
|
| 92 |
+
</head>
|
| 93 |
+
<body>
|
| 94 |
+
<svg></svg>
|
| 95 |
+
<div class="tooltip"></div>
|
| 96 |
+
<div class="legend"></div>
|
| 97 |
+
<script type="text/javascript" src="./graph_json.js"></script>
|
| 98 |
+
<script>
|
| 99 |
+
const graphData = graphJson;
|
| 100 |
+
|
| 101 |
+
const svg = d3.select("svg"),
|
| 102 |
+
width = window.innerWidth,
|
| 103 |
+
height = window.innerHeight;
|
| 104 |
+
|
| 105 |
+
svg.attr("viewBox", [0, 0, width, height]);
|
| 106 |
+
|
| 107 |
+
const g = svg.append("g");
|
| 108 |
+
|
| 109 |
+
const entityTypes = [...new Set(graphData.nodes.map(d => d.entity_type))];
|
| 110 |
+
const color = d3.scaleOrdinal(d3.schemeCategory10).domain(entityTypes);
|
| 111 |
+
|
| 112 |
+
const simulation = d3.forceSimulation(graphData.nodes)
|
| 113 |
+
.force("link", d3.forceLink(graphData.links).id(d => d.id).distance(150))
|
| 114 |
+
.force("charge", d3.forceManyBody().strength(-300))
|
| 115 |
+
.force("center", d3.forceCenter(width / 2, height / 2))
|
| 116 |
+
.force("collide", d3.forceCollide().radius(30));
|
| 117 |
+
|
| 118 |
+
const linkGroup = g.append("g")
|
| 119 |
+
.attr("class", "links")
|
| 120 |
+
.selectAll("g")
|
| 121 |
+
.data(graphData.links)
|
| 122 |
+
.enter().append("g")
|
| 123 |
+
.attr("class", "link");
|
| 124 |
+
|
| 125 |
+
const link = linkGroup.append("line")
|
| 126 |
+
.attr("stroke-width", d => Math.sqrt(d.value));
|
| 127 |
+
|
| 128 |
+
const linkLabel = linkGroup.append("text")
|
| 129 |
+
.attr("class", "link-label")
|
| 130 |
+
.text(d => d.description || "");
|
| 131 |
+
|
| 132 |
+
const node = g.append("g")
|
| 133 |
+
.attr("class", "nodes")
|
| 134 |
+
.selectAll("circle")
|
| 135 |
+
.data(graphData.nodes)
|
| 136 |
+
.enter().append("circle")
|
| 137 |
+
.attr("r", 5)
|
| 138 |
+
.attr("fill", d => color(d.entity_type))
|
| 139 |
+
.call(d3.drag()
|
| 140 |
+
.on("start", dragstarted)
|
| 141 |
+
.on("drag", dragged)
|
| 142 |
+
.on("end", dragended));
|
| 143 |
+
|
| 144 |
+
const nodeLabel = g.append("g")
|
| 145 |
+
.attr("class", "node-labels")
|
| 146 |
+
.selectAll("text")
|
| 147 |
+
.data(graphData.nodes)
|
| 148 |
+
.enter().append("text")
|
| 149 |
+
.attr("class", "node-label")
|
| 150 |
+
.text(d => d.id);
|
| 151 |
+
|
| 152 |
+
const tooltip = d3.select(".tooltip");
|
| 153 |
+
|
| 154 |
+
node.on("mouseover", function(event, d) {
|
| 155 |
+
tooltip.transition()
|
| 156 |
+
.duration(200)
|
| 157 |
+
.style("opacity", .9);
|
| 158 |
+
tooltip.html(`<strong>${d.id}</strong><br>Entity Type: ${d.entity_type}<br>Description: ${d.description || "N/A"}`)
|
| 159 |
+
.style("left", (event.pageX + 10) + "px")
|
| 160 |
+
.style("top", (event.pageY - 28) + "px");
|
| 161 |
+
})
|
| 162 |
+
.on("mouseout", function(d) {
|
| 163 |
+
tooltip.transition()
|
| 164 |
+
.duration(500)
|
| 165 |
+
.style("opacity", 0);
|
| 166 |
+
});
|
| 167 |
+
|
| 168 |
+
const legend = d3.select(".legend");
|
| 169 |
+
entityTypes.forEach(type => {
|
| 170 |
+
legend.append("div")
|
| 171 |
+
.attr("class", "legend-item")
|
| 172 |
+
.html(`<span class="legend-color" style="background-color: ${color(type)}"></span>${type}`);
|
| 173 |
+
});
|
| 174 |
+
|
| 175 |
+
simulation
|
| 176 |
+
.nodes(graphData.nodes)
|
| 177 |
+
.on("tick", ticked);
|
| 178 |
+
|
| 179 |
+
simulation.force("link")
|
| 180 |
+
.links(graphData.links);
|
| 181 |
+
|
| 182 |
+
function ticked() {
|
| 183 |
+
link
|
| 184 |
+
.attr("x1", d => d.source.x)
|
| 185 |
+
.attr("y1", d => d.source.y)
|
| 186 |
+
.attr("x2", d => d.target.x)
|
| 187 |
+
.attr("y2", d => d.target.y);
|
| 188 |
+
|
| 189 |
+
linkLabel
|
| 190 |
+
.attr("x", d => (d.source.x + d.target.x) / 2)
|
| 191 |
+
.attr("y", d => (d.source.y + d.target.y) / 2)
|
| 192 |
+
.attr("text-anchor", "middle")
|
| 193 |
+
.attr("dominant-baseline", "middle");
|
| 194 |
+
|
| 195 |
+
node
|
| 196 |
+
.attr("cx", d => d.x)
|
| 197 |
+
.attr("cy", d => d.y);
|
| 198 |
+
|
| 199 |
+
nodeLabel
|
| 200 |
+
.attr("x", d => d.x + 8)
|
| 201 |
+
.attr("y", d => d.y + 3);
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
function dragstarted(event) {
|
| 205 |
+
if (!event.active) simulation.alphaTarget(0.3).restart();
|
| 206 |
+
event.subject.fx = event.subject.x;
|
| 207 |
+
event.subject.fy = event.subject.y;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
function dragged(event) {
|
| 211 |
+
event.subject.fx = event.x;
|
| 212 |
+
event.subject.fy = event.y;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
function dragended(event) {
|
| 216 |
+
if (!event.active) simulation.alphaTarget(0);
|
| 217 |
+
event.subject.fx = null;
|
| 218 |
+
event.subject.fy = null;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
const zoom = d3.zoom()
|
| 222 |
+
.scaleExtent([0.1, 10])
|
| 223 |
+
.on("zoom", zoomed);
|
| 224 |
+
|
| 225 |
+
svg.call(zoom);
|
| 226 |
+
|
| 227 |
+
function zoomed(event) {
|
| 228 |
+
g.attr("transform", event.transform);
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
</script>
|
| 232 |
+
</body>
|
| 233 |
+
</html>
|
| 234 |
+
'''
|
| 235 |
+
|
| 236 |
+
with open(html_path, 'w', encoding='utf-8') as f:
|
| 237 |
+
f.write(html_content)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def create_json(json_data, json_path):
|
| 241 |
+
json_data = "var graphJson = " + json_data.replace('\\"', '').replace("'", "\\'").replace("\n", "")
|
| 242 |
+
with open(json_path, 'w', encoding='utf-8') as f:
|
| 243 |
+
f.write(json_data)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# start simple HTTP server
|
| 247 |
+
def start_server(port):
|
| 248 |
+
handler = http.server.SimpleHTTPRequestHandler
|
| 249 |
+
with socketserver.TCPServer(("", port), handler) as httpd:
|
| 250 |
+
print(f"Server started at http://localhost:{port}")
|
| 251 |
+
httpd.serve_forever()
|
| 252 |
+
|
| 253 |
+
# main function
|
| 254 |
+
def visualize_graphml(graphml_file, html_path, port=8000):
|
| 255 |
+
json_data = graphml_to_json(graphml_file)
|
| 256 |
+
html_dir = os.path.dirname(html_path)
|
| 257 |
+
if not os.path.exists(html_dir):
|
| 258 |
+
os.makedirs(html_dir)
|
| 259 |
+
json_path = os.path.join(html_dir, 'graph_json.js')
|
| 260 |
+
create_json(json_data, json_path)
|
| 261 |
+
create_html(html_path)
|
| 262 |
+
# start server in background
|
| 263 |
+
server_thread = threading.Thread(target=start_server(port))
|
| 264 |
+
server_thread.daemon = True
|
| 265 |
+
server_thread.start()
|
| 266 |
+
|
| 267 |
+
# open default browser
|
| 268 |
+
webbrowser.open(f'http://localhost:{port}/{html_path}')
|
| 269 |
+
|
| 270 |
+
print("Visualization is ready. Press Ctrl+C to exit.")
|
| 271 |
+
try:
|
| 272 |
+
# keep main thread running
|
| 273 |
+
while True:
|
| 274 |
+
pass
|
| 275 |
+
except KeyboardInterrupt:
|
| 276 |
+
print("Shutting down...")
|
| 277 |
+
|
| 278 |
+
# usage
|
| 279 |
+
if __name__ == "__main__":
|
| 280 |
+
graphml_file = r"nano_graphrag_cache_azure_openai_TEST\graph_chunk_entity_relation.graphml" # replace with your GraphML file path
|
| 281 |
+
html_path = "graph_visualization.html"
|
| 282 |
+
visualize_graphml(graphml_file, html_path, 11236)
|
nano-graphrag/examples/no_openai_key_at_all.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import ollama
|
| 4 |
+
import numpy as np
|
| 5 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 6 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 7 |
+
from nano_graphrag.base import BaseKVStorage
|
| 8 |
+
from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
|
| 9 |
+
from sentence_transformers import SentenceTransformer
|
| 10 |
+
|
| 11 |
+
logging.basicConfig(level=logging.WARNING)
|
| 12 |
+
logging.getLogger("nano-graphrag").setLevel(logging.INFO)
|
| 13 |
+
|
| 14 |
+
# !!! qwen2-7B maybe produce unparsable results and cause the extraction of graph to fail.
|
| 15 |
+
WORKING_DIR = "./nano_graphrag_cache_ollama_TEST"
|
| 16 |
+
MODEL = "qwen2"
|
| 17 |
+
|
| 18 |
+
EMBED_MODEL = SentenceTransformer(
|
| 19 |
+
"sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu"
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# We're using Sentence Transformers to generate embeddings for the BGE model
|
| 24 |
+
@wrap_embedding_func_with_attrs(
|
| 25 |
+
embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
|
| 26 |
+
max_token_size=EMBED_MODEL.max_seq_length,
|
| 27 |
+
)
|
| 28 |
+
async def local_embedding(texts: list[str]) -> np.ndarray:
|
| 29 |
+
return EMBED_MODEL.encode(texts, normalize_embeddings=True)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
async def ollama_model_if_cache(
|
| 33 |
+
prompt, system_prompt=None, history_messages=[], **kwargs
|
| 34 |
+
) -> str:
|
| 35 |
+
# remove kwargs that are not supported by ollama
|
| 36 |
+
kwargs.pop("max_tokens", None)
|
| 37 |
+
kwargs.pop("response_format", None)
|
| 38 |
+
|
| 39 |
+
ollama_client = ollama.AsyncClient()
|
| 40 |
+
messages = []
|
| 41 |
+
if system_prompt:
|
| 42 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 43 |
+
|
| 44 |
+
# Get the cached response if having-------------------
|
| 45 |
+
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
| 46 |
+
messages.extend(history_messages)
|
| 47 |
+
messages.append({"role": "user", "content": prompt})
|
| 48 |
+
if hashing_kv is not None:
|
| 49 |
+
args_hash = compute_args_hash(MODEL, messages)
|
| 50 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
| 51 |
+
if if_cache_return is not None:
|
| 52 |
+
return if_cache_return["return"]
|
| 53 |
+
# -----------------------------------------------------
|
| 54 |
+
response = await ollama_client.chat(model=MODEL, messages=messages, **kwargs)
|
| 55 |
+
|
| 56 |
+
result = response["message"]["content"]
|
| 57 |
+
# Cache the response if having-------------------
|
| 58 |
+
if hashing_kv is not None:
|
| 59 |
+
await hashing_kv.upsert({args_hash: {"return": result, "model": MODEL}})
|
| 60 |
+
# -----------------------------------------------------
|
| 61 |
+
return result
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def remove_if_exist(file):
|
| 65 |
+
if os.path.exists(file):
|
| 66 |
+
os.remove(file)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def query():
|
| 70 |
+
rag = GraphRAG(
|
| 71 |
+
working_dir=WORKING_DIR,
|
| 72 |
+
best_model_func=ollama_model_if_cache,
|
| 73 |
+
cheap_model_func=ollama_model_if_cache,
|
| 74 |
+
embedding_func=local_embedding,
|
| 75 |
+
)
|
| 76 |
+
print(
|
| 77 |
+
rag.query(
|
| 78 |
+
"What are the top themes in this story?", param=QueryParam(mode="global")
|
| 79 |
+
)
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def insert():
|
| 84 |
+
from time import time
|
| 85 |
+
|
| 86 |
+
with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
|
| 87 |
+
FAKE_TEXT = f.read()
|
| 88 |
+
|
| 89 |
+
remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
|
| 90 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
|
| 91 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
|
| 92 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
|
| 93 |
+
remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
|
| 94 |
+
|
| 95 |
+
rag = GraphRAG(
|
| 96 |
+
working_dir=WORKING_DIR,
|
| 97 |
+
enable_llm_cache=True,
|
| 98 |
+
best_model_func=ollama_model_if_cache,
|
| 99 |
+
cheap_model_func=ollama_model_if_cache,
|
| 100 |
+
embedding_func=local_embedding,
|
| 101 |
+
)
|
| 102 |
+
start = time()
|
| 103 |
+
rag.insert(FAKE_TEXT)
|
| 104 |
+
print("indexing time:", time() - start)
|
| 105 |
+
# rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
|
| 106 |
+
# rag.insert(FAKE_TEXT[half_len:])
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
if __name__ == "__main__":
|
| 110 |
+
insert()
|
| 111 |
+
query()
|
nano-graphrag/examples/using_amazon_bedrock.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 2 |
+
|
| 3 |
+
graph_func = GraphRAG(
|
| 4 |
+
working_dir="../bedrock_example",
|
| 5 |
+
using_amazon_bedrock=True,
|
| 6 |
+
best_model_id="us.anthropic.claude-3-sonnet-20240229-v1:0",
|
| 7 |
+
cheap_model_id="us.anthropic.claude-3-haiku-20240307-v1:0",
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
with open("../tests/mock_data.txt") as f:
|
| 11 |
+
graph_func.insert(f.read())
|
| 12 |
+
|
| 13 |
+
prompt = "What are the top themes in this story?"
|
| 14 |
+
|
| 15 |
+
# Perform global graphrag search
|
| 16 |
+
print(graph_func.query(prompt, param=QueryParam(mode="global")))
|
| 17 |
+
|
| 18 |
+
# Perform local graphrag search (I think is better and more scalable one)
|
| 19 |
+
print(graph_func.query(prompt, param=QueryParam(mode="local")))
|
nano-graphrag/examples/using_custom_chunking_method.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from nano_graphrag._utils import encode_string_by_tiktoken
|
| 2 |
+
from nano_graphrag.base import QueryParam
|
| 3 |
+
from nano_graphrag.graphrag import GraphRAG
|
| 4 |
+
from nano_graphrag._op import chunking_by_seperators
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def chunking_by_token_size(
|
| 8 |
+
tokens_list: list[list[int]], # nano-graphrag may pass a batch of docs' tokens
|
| 9 |
+
doc_keys: list[str], # nano-graphrag may pass a batch of docs' key ids
|
| 10 |
+
tiktoken_model, # a titoken model
|
| 11 |
+
overlap_token_size=128,
|
| 12 |
+
max_token_size=1024,
|
| 13 |
+
):
|
| 14 |
+
|
| 15 |
+
results = []
|
| 16 |
+
for index, tokens in enumerate(tokens_list):
|
| 17 |
+
chunk_token = []
|
| 18 |
+
lengths = []
|
| 19 |
+
for start in range(0, len(tokens), max_token_size - overlap_token_size):
|
| 20 |
+
|
| 21 |
+
chunk_token.append(tokens[start : start + max_token_size])
|
| 22 |
+
lengths.append(min(max_token_size, len(tokens) - start))
|
| 23 |
+
|
| 24 |
+
chunk_token = tiktoken_model.decode_batch(chunk_token)
|
| 25 |
+
for i, chunk in enumerate(chunk_token):
|
| 26 |
+
|
| 27 |
+
results.append(
|
| 28 |
+
{
|
| 29 |
+
"tokens": lengths[i],
|
| 30 |
+
"content": chunk.strip(),
|
| 31 |
+
"chunk_order_index": i,
|
| 32 |
+
"full_doc_id": doc_keys[index],
|
| 33 |
+
}
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
return results
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
WORKING_DIR = "./nano_graphrag_cache_local_embedding_TEST"
|
| 40 |
+
rag = GraphRAG(
|
| 41 |
+
working_dir=WORKING_DIR,
|
| 42 |
+
chunk_func=chunking_by_seperators,
|
| 43 |
+
)
|
nano-graphrag/examples/using_deepseek_api_as_llm+glm_api_as_embedding.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import numpy as np
|
| 4 |
+
from openai import AsyncOpenAI, OpenAI
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 7 |
+
from nano_graphrag.base import BaseKVStorage
|
| 8 |
+
from nano_graphrag._utils import compute_args_hash
|
| 9 |
+
|
| 10 |
+
logging.basicConfig(level=logging.WARNING)
|
| 11 |
+
logging.getLogger("nano-graphrag").setLevel(logging.INFO)
|
| 12 |
+
|
| 13 |
+
GLM_API_KEY = "XXXX"
|
| 14 |
+
DEEPSEEK_API_KEY = "sk-XXXX"
|
| 15 |
+
|
| 16 |
+
MODEL = "deepseek-chat"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
async def deepseepk_model_if_cache(
|
| 20 |
+
prompt, system_prompt=None, history_messages=[], **kwargs
|
| 21 |
+
) -> str:
|
| 22 |
+
openai_async_client = AsyncOpenAI(
|
| 23 |
+
api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com"
|
| 24 |
+
)
|
| 25 |
+
messages = []
|
| 26 |
+
if system_prompt:
|
| 27 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 28 |
+
|
| 29 |
+
# Get the cached response if having-------------------
|
| 30 |
+
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
| 31 |
+
messages.extend(history_messages)
|
| 32 |
+
messages.append({"role": "user", "content": prompt})
|
| 33 |
+
if hashing_kv is not None:
|
| 34 |
+
args_hash = compute_args_hash(MODEL, messages)
|
| 35 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
| 36 |
+
if if_cache_return is not None:
|
| 37 |
+
return if_cache_return["return"]
|
| 38 |
+
# -----------------------------------------------------
|
| 39 |
+
|
| 40 |
+
response = await openai_async_client.chat.completions.create(
|
| 41 |
+
model=MODEL, messages=messages, **kwargs
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Cache the response if having-------------------
|
| 45 |
+
if hashing_kv is not None:
|
| 46 |
+
await hashing_kv.upsert(
|
| 47 |
+
{args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
|
| 48 |
+
)
|
| 49 |
+
# -----------------------------------------------------
|
| 50 |
+
return response.choices[0].message.content
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def remove_if_exist(file):
|
| 54 |
+
if os.path.exists(file):
|
| 55 |
+
os.remove(file)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass
|
| 59 |
+
class EmbeddingFunc:
|
| 60 |
+
embedding_dim: int
|
| 61 |
+
max_token_size: int
|
| 62 |
+
func: callable
|
| 63 |
+
|
| 64 |
+
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
| 65 |
+
return await self.func(*args, **kwargs)
|
| 66 |
+
|
| 67 |
+
def wrap_embedding_func_with_attrs(**kwargs):
|
| 68 |
+
"""Wrap a function with attributes"""
|
| 69 |
+
|
| 70 |
+
def final_decro(func) -> EmbeddingFunc:
|
| 71 |
+
new_func = EmbeddingFunc(**kwargs, func=func)
|
| 72 |
+
return new_func
|
| 73 |
+
|
| 74 |
+
return final_decro
|
| 75 |
+
|
| 76 |
+
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
| 77 |
+
async def GLM_embedding(texts: list[str]) -> np.ndarray:
|
| 78 |
+
model_name = "embedding-2"
|
| 79 |
+
client = OpenAI(
|
| 80 |
+
api_key=GLM_API_KEY,
|
| 81 |
+
base_url="https://open.bigmodel.cn/api/paas/v4/"
|
| 82 |
+
)
|
| 83 |
+
embedding = client.embeddings.create(
|
| 84 |
+
input=texts,
|
| 85 |
+
model=model_name,
|
| 86 |
+
)
|
| 87 |
+
final_embedding = [d.embedding for d in embedding.data]
|
| 88 |
+
return np.array(final_embedding)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
WORKING_DIR = "./nano_graphrag_cache_deepseek_TEST"
|
| 93 |
+
|
| 94 |
+
def query():
|
| 95 |
+
rag = GraphRAG(
|
| 96 |
+
working_dir=WORKING_DIR,
|
| 97 |
+
best_model_func=deepseepk_model_if_cache,
|
| 98 |
+
cheap_model_func=deepseepk_model_if_cache,
|
| 99 |
+
embedding_func=GLM_embedding,
|
| 100 |
+
)
|
| 101 |
+
print(
|
| 102 |
+
rag.query(
|
| 103 |
+
"What are the top themes in this story?", param=QueryParam(mode="global")
|
| 104 |
+
)
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def insert():
|
| 109 |
+
from time import time
|
| 110 |
+
|
| 111 |
+
with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
|
| 112 |
+
FAKE_TEXT = f.read()
|
| 113 |
+
|
| 114 |
+
remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
|
| 115 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
|
| 116 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
|
| 117 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
|
| 118 |
+
remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
|
| 119 |
+
|
| 120 |
+
rag = GraphRAG(
|
| 121 |
+
working_dir=WORKING_DIR,
|
| 122 |
+
enable_llm_cache=True,
|
| 123 |
+
best_model_func=deepseepk_model_if_cache,
|
| 124 |
+
cheap_model_func=deepseepk_model_if_cache,
|
| 125 |
+
embedding_func=GLM_embedding,
|
| 126 |
+
)
|
| 127 |
+
start = time()
|
| 128 |
+
rag.insert(FAKE_TEXT)
|
| 129 |
+
print("indexing time:", time() - start)
|
| 130 |
+
# rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
|
| 131 |
+
# rag.insert(FAKE_TEXT[half_len:])
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
if __name__ == "__main__":
|
| 135 |
+
insert()
|
| 136 |
+
# query()
|
nano-graphrag/examples/using_deepseek_as_llm.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
from openai import AsyncOpenAI
|
| 4 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 5 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 6 |
+
from nano_graphrag.base import BaseKVStorage
|
| 7 |
+
from nano_graphrag._utils import compute_args_hash
|
| 8 |
+
|
| 9 |
+
logging.basicConfig(level=logging.WARNING)
|
| 10 |
+
logging.getLogger("nano-graphrag").setLevel(logging.INFO)
|
| 11 |
+
|
| 12 |
+
DEEPSEEK_API_KEY = "sk-XXXX"
|
| 13 |
+
MODEL = "deepseek-chat"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
async def deepseepk_model_if_cache(
|
| 17 |
+
prompt, system_prompt=None, history_messages=[], **kwargs
|
| 18 |
+
) -> str:
|
| 19 |
+
openai_async_client = AsyncOpenAI(
|
| 20 |
+
api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com"
|
| 21 |
+
)
|
| 22 |
+
messages = []
|
| 23 |
+
if system_prompt:
|
| 24 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 25 |
+
|
| 26 |
+
# Get the cached response if having-------------------
|
| 27 |
+
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
| 28 |
+
messages.extend(history_messages)
|
| 29 |
+
messages.append({"role": "user", "content": prompt})
|
| 30 |
+
if hashing_kv is not None:
|
| 31 |
+
args_hash = compute_args_hash(MODEL, messages)
|
| 32 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
| 33 |
+
if if_cache_return is not None:
|
| 34 |
+
return if_cache_return["return"]
|
| 35 |
+
# -----------------------------------------------------
|
| 36 |
+
|
| 37 |
+
response = await openai_async_client.chat.completions.create(
|
| 38 |
+
model=MODEL, messages=messages, **kwargs
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Cache the response if having-------------------
|
| 42 |
+
if hashing_kv is not None:
|
| 43 |
+
await hashing_kv.upsert(
|
| 44 |
+
{args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
|
| 45 |
+
)
|
| 46 |
+
# -----------------------------------------------------
|
| 47 |
+
return response.choices[0].message.content
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def remove_if_exist(file):
|
| 51 |
+
if os.path.exists(file):
|
| 52 |
+
os.remove(file)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
WORKING_DIR = "./nano_graphrag_cache_deepseek_TEST"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def query():
|
| 59 |
+
rag = GraphRAG(
|
| 60 |
+
working_dir=WORKING_DIR,
|
| 61 |
+
best_model_func=deepseepk_model_if_cache,
|
| 62 |
+
cheap_model_func=deepseepk_model_if_cache,
|
| 63 |
+
)
|
| 64 |
+
print(
|
| 65 |
+
rag.query(
|
| 66 |
+
"What are the top themes in this story?", param=QueryParam(mode="global")
|
| 67 |
+
)
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def insert():
|
| 72 |
+
from time import time
|
| 73 |
+
|
| 74 |
+
with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
|
| 75 |
+
FAKE_TEXT = f.read()
|
| 76 |
+
|
| 77 |
+
remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
|
| 78 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
|
| 79 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
|
| 80 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
|
| 81 |
+
remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
|
| 82 |
+
|
| 83 |
+
rag = GraphRAG(
|
| 84 |
+
working_dir=WORKING_DIR,
|
| 85 |
+
enable_llm_cache=True,
|
| 86 |
+
best_model_func=deepseepk_model_if_cache,
|
| 87 |
+
cheap_model_func=deepseepk_model_if_cache,
|
| 88 |
+
)
|
| 89 |
+
start = time()
|
| 90 |
+
rag.insert(FAKE_TEXT)
|
| 91 |
+
print("indexing time:", time() - start)
|
| 92 |
+
# rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
|
| 93 |
+
# rag.insert(FAKE_TEXT[half_len:])
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
|
| 97 |
+
insert()
|
| 98 |
+
# query()
|
nano-graphrag/examples/using_dspy_entity_extraction.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from openai import AsyncOpenAI
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
import logging
|
| 5 |
+
import numpy as np
|
| 6 |
+
import dspy
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 9 |
+
from nano_graphrag._llm import gpt_4o_mini_complete
|
| 10 |
+
from nano_graphrag._storage import HNSWVectorStorage
|
| 11 |
+
from nano_graphrag.base import BaseKVStorage
|
| 12 |
+
from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
|
| 13 |
+
from nano_graphrag.entity_extraction.extract import extract_entities_dspy
|
| 14 |
+
|
| 15 |
+
logging.basicConfig(level=logging.WARNING)
|
| 16 |
+
logging.getLogger("nano-graphrag").setLevel(logging.DEBUG)
|
| 17 |
+
|
| 18 |
+
WORKING_DIR = "./nano_graphrag_cache_using_dspy_entity_extraction"
|
| 19 |
+
|
| 20 |
+
load_dotenv()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
EMBED_MODEL = SentenceTransformer(
|
| 24 |
+
"sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu"
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@wrap_embedding_func_with_attrs(
|
| 29 |
+
embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
|
| 30 |
+
max_token_size=EMBED_MODEL.max_seq_length,
|
| 31 |
+
)
|
| 32 |
+
async def local_embedding(texts: list[str]) -> np.ndarray:
|
| 33 |
+
return EMBED_MODEL.encode(texts, normalize_embeddings=True)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
async def deepseepk_model_if_cache(
|
| 37 |
+
prompt, model: str = "deepseek-chat", system_prompt=None, history_messages=[], **kwargs
|
| 38 |
+
) -> str:
|
| 39 |
+
openai_async_client = AsyncOpenAI(
|
| 40 |
+
api_key=os.environ.get("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com"
|
| 41 |
+
)
|
| 42 |
+
messages = []
|
| 43 |
+
if system_prompt:
|
| 44 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 45 |
+
|
| 46 |
+
# Get the cached response if having-------------------
|
| 47 |
+
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
| 48 |
+
messages.extend(history_messages)
|
| 49 |
+
messages.append({"role": "user", "content": prompt})
|
| 50 |
+
if hashing_kv is not None:
|
| 51 |
+
args_hash = compute_args_hash(model, messages)
|
| 52 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
| 53 |
+
if if_cache_return is not None:
|
| 54 |
+
return if_cache_return["return"]
|
| 55 |
+
# -----------------------------------------------------
|
| 56 |
+
|
| 57 |
+
response = await openai_async_client.chat.completions.create(
|
| 58 |
+
model=model, messages=messages, **kwargs
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# Cache the response if having-------------------
|
| 62 |
+
if hashing_kv is not None:
|
| 63 |
+
await hashing_kv.upsert(
|
| 64 |
+
{args_hash: {"return": response.choices[0].message.content, "model": model}}
|
| 65 |
+
)
|
| 66 |
+
# -----------------------------------------------------
|
| 67 |
+
return response.choices[0].message.content
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def remove_if_exist(file):
|
| 72 |
+
if os.path.exists(file):
|
| 73 |
+
os.remove(file)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def insert():
|
| 77 |
+
from time import time
|
| 78 |
+
|
| 79 |
+
with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
|
| 80 |
+
FAKE_TEXT = f.read()
|
| 81 |
+
|
| 82 |
+
remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
|
| 83 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
|
| 84 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
|
| 85 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
|
| 86 |
+
remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
|
| 87 |
+
rag = GraphRAG(
|
| 88 |
+
working_dir=WORKING_DIR,
|
| 89 |
+
enable_llm_cache=True,
|
| 90 |
+
vector_db_storage_cls=HNSWVectorStorage,
|
| 91 |
+
vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50},
|
| 92 |
+
best_model_max_async=10,
|
| 93 |
+
cheap_model_max_async=10,
|
| 94 |
+
best_model_func=deepseepk_model_if_cache,
|
| 95 |
+
cheap_model_func=deepseepk_model_if_cache,
|
| 96 |
+
embedding_func=local_embedding,
|
| 97 |
+
entity_extraction_func=extract_entities_dspy
|
| 98 |
+
)
|
| 99 |
+
start = time()
|
| 100 |
+
rag.insert(FAKE_TEXT)
|
| 101 |
+
print("indexing time:", time() - start)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def query():
|
| 105 |
+
rag = GraphRAG(
|
| 106 |
+
working_dir=WORKING_DIR,
|
| 107 |
+
enable_llm_cache=True,
|
| 108 |
+
vector_db_storage_cls=HNSWVectorStorage,
|
| 109 |
+
vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50},
|
| 110 |
+
best_model_max_token_size=8196,
|
| 111 |
+
cheap_model_max_token_size=8196,
|
| 112 |
+
best_model_max_async=4,
|
| 113 |
+
cheap_model_max_async=4,
|
| 114 |
+
best_model_func=gpt_4o_mini_complete,
|
| 115 |
+
cheap_model_func=gpt_4o_mini_complete,
|
| 116 |
+
embedding_func=local_embedding,
|
| 117 |
+
entity_extraction_func=extract_entities_dspy
|
| 118 |
+
|
| 119 |
+
)
|
| 120 |
+
print(
|
| 121 |
+
rag.query(
|
| 122 |
+
"What are the top themes in this story?", param=QueryParam(mode="global")
|
| 123 |
+
)
|
| 124 |
+
)
|
| 125 |
+
print(
|
| 126 |
+
rag.query(
|
| 127 |
+
"What are the top themes in this story?", param=QueryParam(mode="local")
|
| 128 |
+
)
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
if __name__ == "__main__":
|
| 133 |
+
lm = dspy.LM(
|
| 134 |
+
model="deepseek/deepseek-chat",
|
| 135 |
+
model_type="chat",
|
| 136 |
+
api_provider="openai",
|
| 137 |
+
api_key=os.environ["DEEPSEEK_API_KEY"],
|
| 138 |
+
base_url=os.environ["DEEPSEEK_BASE_URL"],
|
| 139 |
+
temperature=1.0,
|
| 140 |
+
max_tokens=8192
|
| 141 |
+
)
|
| 142 |
+
dspy.settings.configure(lm=lm, experimental=True)
|
| 143 |
+
insert()
|
| 144 |
+
query()
|
nano-graphrag/examples/using_faiss_as_vextorDB.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import asyncio
|
| 3 |
+
import numpy as np
|
| 4 |
+
from nano_graphrag.graphrag import GraphRAG, QueryParam
|
| 5 |
+
from nano_graphrag._utils import logger
|
| 6 |
+
from nano_graphrag.base import BaseVectorStorage
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
import faiss
|
| 9 |
+
import pickle
|
| 10 |
+
import logging
|
| 11 |
+
import xxhash
|
| 12 |
+
logging.getLogger('msal').setLevel(logging.WARNING)
|
| 13 |
+
logging.getLogger('azure').setLevel(logging.WARNING)
|
| 14 |
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 15 |
+
|
| 16 |
+
WORKING_DIR = "./nano_graphrag_cache_faiss_TEST"
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class FAISSStorage(BaseVectorStorage):
|
| 20 |
+
|
| 21 |
+
def __post_init__(self):
|
| 22 |
+
self._index_file_name = os.path.join(
|
| 23 |
+
self.global_config["working_dir"], f"{self.namespace}_faiss.index"
|
| 24 |
+
)
|
| 25 |
+
self._metadata_file_name = os.path.join(
|
| 26 |
+
self.global_config["working_dir"], f"{self.namespace}_metadata.pkl"
|
| 27 |
+
)
|
| 28 |
+
self._max_batch_size = self.global_config["embedding_batch_num"]
|
| 29 |
+
|
| 30 |
+
if os.path.exists(self._index_file_name) and os.path.exists(self._metadata_file_name):
|
| 31 |
+
self._index = faiss.read_index(self._index_file_name)
|
| 32 |
+
with open(self._metadata_file_name, 'rb') as f:
|
| 33 |
+
self._metadata = pickle.load(f)
|
| 34 |
+
else:
|
| 35 |
+
self._index = faiss.IndexIDMap(faiss.IndexFlatIP(self.embedding_func.embedding_dim))
|
| 36 |
+
self._metadata = {}
|
| 37 |
+
|
| 38 |
+
async def upsert(self, data: dict[str, dict]):
|
| 39 |
+
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
| 40 |
+
|
| 41 |
+
contents = [v["content"] for v in data.values()]
|
| 42 |
+
batches = [
|
| 43 |
+
contents[i : i + self._max_batch_size]
|
| 44 |
+
for i in range(0, len(contents), self._max_batch_size)
|
| 45 |
+
]
|
| 46 |
+
embeddings_list = await asyncio.gather(
|
| 47 |
+
*[self.embedding_func(batch) for batch in batches]
|
| 48 |
+
)
|
| 49 |
+
embeddings = np.concatenate(embeddings_list)
|
| 50 |
+
|
| 51 |
+
ids = []
|
| 52 |
+
for k, v in data.items():
|
| 53 |
+
id = xxhash.xxh32_intdigest(k.encode())
|
| 54 |
+
metadata = {k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}
|
| 55 |
+
metadata['id'] = k
|
| 56 |
+
self._metadata[id] = metadata
|
| 57 |
+
ids.append(id)
|
| 58 |
+
|
| 59 |
+
ids = np.array(ids, dtype=np.int64)
|
| 60 |
+
self._index.add_with_ids(embeddings, ids)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
return len(data)
|
| 64 |
+
|
| 65 |
+
async def query(self, query, top_k=5):
|
| 66 |
+
embedding = await self.embedding_func([query])
|
| 67 |
+
distances, indices = self._index.search(embedding, top_k)
|
| 68 |
+
|
| 69 |
+
results = []
|
| 70 |
+
for _, (distance, id) in enumerate(zip(distances[0], indices[0])):
|
| 71 |
+
if id != -1: # FAISS returns -1 for empty slots
|
| 72 |
+
if id in self._metadata:
|
| 73 |
+
metadata = self._metadata[id]
|
| 74 |
+
results.append({**metadata, "distance": 1 - distance}) # Convert to cosine distance
|
| 75 |
+
|
| 76 |
+
return results
|
| 77 |
+
|
| 78 |
+
async def index_done_callback(self):
|
| 79 |
+
faiss.write_index(self._index, self._index_file_name)
|
| 80 |
+
with open(self._metadata_file_name, 'wb') as f:
|
| 81 |
+
pickle.dump(self._metadata, f)
|
| 82 |
+
|
| 83 |
+
if __name__ == "__main__":
|
| 84 |
+
|
| 85 |
+
graph_func = GraphRAG(
|
| 86 |
+
working_dir=WORKING_DIR,
|
| 87 |
+
enable_llm_cache=True,
|
| 88 |
+
vector_db_storage_cls=FAISSStorage,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
with open(r"tests/mock_data.txt", encoding='utf-8') as f:
|
| 92 |
+
graph_func.insert(f.read()[:30000])
|
| 93 |
+
|
| 94 |
+
# Perform global graphrag search
|
| 95 |
+
print(graph_func.query("What are the top themes in this story?"))
|
| 96 |
+
|
| 97 |
+
|
nano-graphrag/examples/using_hnsw_as_vectorDB.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from openai import AsyncOpenAI
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
import logging
|
| 5 |
+
import numpy as np
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 8 |
+
from nano_graphrag._llm import gpt_4o_mini_complete
|
| 9 |
+
from nano_graphrag._storage import HNSWVectorStorage
|
| 10 |
+
from nano_graphrag.base import BaseKVStorage
|
| 11 |
+
from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
|
| 12 |
+
|
| 13 |
+
logging.basicConfig(level=logging.WARNING)
|
| 14 |
+
logging.getLogger("nano-graphrag").setLevel(logging.DEBUG)
|
| 15 |
+
|
| 16 |
+
WORKING_DIR = "./nano_graphrag_cache_using_hnsw_as_vectorDB"
|
| 17 |
+
|
| 18 |
+
load_dotenv()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
EMBED_MODEL = SentenceTransformer(
|
| 22 |
+
"sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu"
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@wrap_embedding_func_with_attrs(
|
| 27 |
+
embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
|
| 28 |
+
max_token_size=EMBED_MODEL.max_seq_length,
|
| 29 |
+
)
|
| 30 |
+
async def local_embedding(texts: list[str]) -> np.ndarray:
|
| 31 |
+
return EMBED_MODEL.encode(texts, normalize_embeddings=True)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
async def deepseepk_model_if_cache(
|
| 35 |
+
prompt, model: str = "deepseek-chat", system_prompt=None, history_messages=[], **kwargs
|
| 36 |
+
) -> str:
|
| 37 |
+
openai_async_client = AsyncOpenAI(
|
| 38 |
+
api_key=os.environ.get("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com"
|
| 39 |
+
)
|
| 40 |
+
messages = []
|
| 41 |
+
if system_prompt:
|
| 42 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 43 |
+
|
| 44 |
+
# Get the cached response if having-------------------
|
| 45 |
+
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
| 46 |
+
messages.extend(history_messages)
|
| 47 |
+
messages.append({"role": "user", "content": prompt})
|
| 48 |
+
if hashing_kv is not None:
|
| 49 |
+
args_hash = compute_args_hash(model, messages)
|
| 50 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
| 51 |
+
if if_cache_return is not None:
|
| 52 |
+
return if_cache_return["return"]
|
| 53 |
+
# -----------------------------------------------------
|
| 54 |
+
|
| 55 |
+
response = await openai_async_client.chat.completions.create(
|
| 56 |
+
model=model, messages=messages, **kwargs
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# Cache the response if having-------------------
|
| 60 |
+
if hashing_kv is not None:
|
| 61 |
+
await hashing_kv.upsert(
|
| 62 |
+
{args_hash: {"return": response.choices[0].message.content, "model": model}}
|
| 63 |
+
)
|
| 64 |
+
# -----------------------------------------------------
|
| 65 |
+
return response.choices[0].message.content
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def remove_if_exist(file):
|
| 70 |
+
if os.path.exists(file):
|
| 71 |
+
os.remove(file)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def insert():
|
| 75 |
+
from time import time
|
| 76 |
+
|
| 77 |
+
with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
|
| 78 |
+
FAKE_TEXT = f.read()
|
| 79 |
+
|
| 80 |
+
remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
|
| 81 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
|
| 82 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
|
| 83 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
|
| 84 |
+
remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
|
| 85 |
+
rag = GraphRAG(
|
| 86 |
+
working_dir=WORKING_DIR,
|
| 87 |
+
enable_llm_cache=True,
|
| 88 |
+
vector_db_storage_cls=HNSWVectorStorage,
|
| 89 |
+
vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50},
|
| 90 |
+
best_model_max_async=10,
|
| 91 |
+
cheap_model_max_async=10,
|
| 92 |
+
best_model_func=deepseepk_model_if_cache,
|
| 93 |
+
cheap_model_func=deepseepk_model_if_cache,
|
| 94 |
+
embedding_func=local_embedding
|
| 95 |
+
)
|
| 96 |
+
start = time()
|
| 97 |
+
rag.insert(FAKE_TEXT)
|
| 98 |
+
print("indexing time:", time() - start)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def query():
|
| 102 |
+
rag = GraphRAG(
|
| 103 |
+
working_dir=WORKING_DIR,
|
| 104 |
+
enable_llm_cache=True,
|
| 105 |
+
vector_db_storage_cls=HNSWVectorStorage,
|
| 106 |
+
vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50},
|
| 107 |
+
best_model_max_token_size=8196,
|
| 108 |
+
cheap_model_max_token_size=8196,
|
| 109 |
+
best_model_max_async=4,
|
| 110 |
+
cheap_model_max_async=4,
|
| 111 |
+
best_model_func=gpt_4o_mini_complete,
|
| 112 |
+
cheap_model_func=gpt_4o_mini_complete,
|
| 113 |
+
embedding_func=local_embedding
|
| 114 |
+
)
|
| 115 |
+
print(
|
| 116 |
+
rag.query(
|
| 117 |
+
"What are the top themes in this story?", param=QueryParam(mode="global")
|
| 118 |
+
)
|
| 119 |
+
)
|
| 120 |
+
print(
|
| 121 |
+
rag.query(
|
| 122 |
+
"What are the top themes in this story?", param=QueryParam(mode="local")
|
| 123 |
+
)
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
if __name__ == "__main__":
|
| 128 |
+
insert()
|
| 129 |
+
query()
|
nano-graphrag/examples/using_llm_api_as_llm+ollama_embedding.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import ollama
|
| 4 |
+
import numpy as np
|
| 5 |
+
from openai import AsyncOpenAI
|
| 6 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 7 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 8 |
+
from nano_graphrag.base import BaseKVStorage
|
| 9 |
+
from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
|
| 10 |
+
|
| 11 |
+
logging.basicConfig(level=logging.WARNING)
|
| 12 |
+
logging.getLogger("nano-graphrag").setLevel(logging.INFO)
|
| 13 |
+
|
| 14 |
+
# Assumed llm model settings
|
| 15 |
+
LLM_BASE_URL = "https://your.api.url"
|
| 16 |
+
LLM_API_KEY = "your_api_key"
|
| 17 |
+
MODEL = "your_model_name"
|
| 18 |
+
|
| 19 |
+
# Assumed embedding model settings
|
| 20 |
+
EMBEDDING_MODEL = "nomic-embed-text"
|
| 21 |
+
EMBEDDING_MODEL_DIM = 768
|
| 22 |
+
EMBEDDING_MODEL_MAX_TOKENS = 8192
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
async def llm_model_if_cache(
|
| 26 |
+
prompt, system_prompt=None, history_messages=[], **kwargs
|
| 27 |
+
) -> str:
|
| 28 |
+
openai_async_client = AsyncOpenAI(
|
| 29 |
+
api_key=LLM_API_KEY, base_url=LLM_BASE_URL
|
| 30 |
+
)
|
| 31 |
+
messages = []
|
| 32 |
+
if system_prompt:
|
| 33 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 34 |
+
|
| 35 |
+
# Get the cached response if having-------------------
|
| 36 |
+
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
| 37 |
+
messages.extend(history_messages)
|
| 38 |
+
messages.append({"role": "user", "content": prompt})
|
| 39 |
+
if hashing_kv is not None:
|
| 40 |
+
args_hash = compute_args_hash(MODEL, messages)
|
| 41 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
| 42 |
+
if if_cache_return is not None:
|
| 43 |
+
return if_cache_return["return"]
|
| 44 |
+
# -----------------------------------------------------
|
| 45 |
+
|
| 46 |
+
response = await openai_async_client.chat.completions.create(
|
| 47 |
+
model=MODEL, messages=messages, **kwargs
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# Cache the response if having-------------------
|
| 51 |
+
if hashing_kv is not None:
|
| 52 |
+
await hashing_kv.upsert(
|
| 53 |
+
{args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
|
| 54 |
+
)
|
| 55 |
+
# -----------------------------------------------------
|
| 56 |
+
return response.choices[0].message.content
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def remove_if_exist(file):
|
| 60 |
+
if os.path.exists(file):
|
| 61 |
+
os.remove(file)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
WORKING_DIR = "./nano_graphrag_cache_llm_TEST"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def query():
|
| 68 |
+
rag = GraphRAG(
|
| 69 |
+
working_dir=WORKING_DIR,
|
| 70 |
+
best_model_func=llm_model_if_cache,
|
| 71 |
+
cheap_model_func=llm_model_if_cache,
|
| 72 |
+
embedding_func=ollama_embedding,
|
| 73 |
+
)
|
| 74 |
+
print(
|
| 75 |
+
rag.query(
|
| 76 |
+
"What are the top themes in this story?", param=QueryParam(mode="global")
|
| 77 |
+
)
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def insert():
|
| 82 |
+
from time import time
|
| 83 |
+
|
| 84 |
+
with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
|
| 85 |
+
FAKE_TEXT = f.read()
|
| 86 |
+
|
| 87 |
+
remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
|
| 88 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
|
| 89 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
|
| 90 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
|
| 91 |
+
remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
|
| 92 |
+
|
| 93 |
+
rag = GraphRAG(
|
| 94 |
+
working_dir=WORKING_DIR,
|
| 95 |
+
enable_llm_cache=True,
|
| 96 |
+
best_model_func=llm_model_if_cache,
|
| 97 |
+
cheap_model_func=llm_model_if_cache,
|
| 98 |
+
embedding_func=ollama_embedding,
|
| 99 |
+
)
|
| 100 |
+
start = time()
|
| 101 |
+
rag.insert(FAKE_TEXT)
|
| 102 |
+
print("indexing time:", time() - start)
|
| 103 |
+
# rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
|
| 104 |
+
# rag.insert(FAKE_TEXT[half_len:])
|
| 105 |
+
|
| 106 |
+
# We're using Ollama to generate embeddings for the BGE model
|
| 107 |
+
@wrap_embedding_func_with_attrs(
|
| 108 |
+
embedding_dim= EMBEDDING_MODEL_DIM,
|
| 109 |
+
max_token_size= EMBEDDING_MODEL_MAX_TOKENS,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
async def ollama_embedding(texts :list[str]) -> np.ndarray:
|
| 113 |
+
embed_text = []
|
| 114 |
+
for text in texts:
|
| 115 |
+
data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
|
| 116 |
+
embed_text.append(data["embedding"])
|
| 117 |
+
|
| 118 |
+
return embed_text
|
| 119 |
+
|
| 120 |
+
if __name__ == "__main__":
|
| 121 |
+
insert()
|
| 122 |
+
query()
|
nano-graphrag/examples/using_local_embedding_model.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
sys.path.append("..")
|
| 4 |
+
import logging
|
| 5 |
+
import numpy as np
|
| 6 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 7 |
+
from nano_graphrag._utils import wrap_embedding_func_with_attrs
|
| 8 |
+
from sentence_transformers import SentenceTransformer
|
| 9 |
+
|
| 10 |
+
logging.basicConfig(level=logging.WARNING)
|
| 11 |
+
logging.getLogger("nano-graphrag").setLevel(logging.INFO)
|
| 12 |
+
|
| 13 |
+
WORKING_DIR = "./nano_graphrag_cache_local_embedding_TEST"
|
| 14 |
+
|
| 15 |
+
EMBED_MODEL = SentenceTransformer(
|
| 16 |
+
"sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu"
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# We're using Sentence Transformers to generate embeddings for the BGE model
|
| 21 |
+
@wrap_embedding_func_with_attrs(
|
| 22 |
+
embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
|
| 23 |
+
max_token_size=EMBED_MODEL.max_seq_length,
|
| 24 |
+
)
|
| 25 |
+
async def local_embedding(texts: list[str]) -> np.ndarray:
|
| 26 |
+
return EMBED_MODEL.encode(texts, normalize_embeddings=True)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
rag = GraphRAG(
|
| 30 |
+
working_dir=WORKING_DIR,
|
| 31 |
+
embedding_func=local_embedding,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
with open("../tests/mock_data.txt", encoding="utf-8-sig") as f:
|
| 35 |
+
FAKE_TEXT = f.read()
|
| 36 |
+
|
| 37 |
+
# rag.insert(FAKE_TEXT)
|
| 38 |
+
print(rag.query("What the main theme of this story?", param=QueryParam(mode="local")))
|
nano-graphrag/examples/using_milvus_as_vectorDB.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import asyncio
|
| 3 |
+
import numpy as np
|
| 4 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 5 |
+
from nano_graphrag._utils import logger
|
| 6 |
+
from nano_graphrag.base import BaseVectorStorage
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class MilvusLiteStorge(BaseVectorStorage):
|
| 12 |
+
|
| 13 |
+
@staticmethod
|
| 14 |
+
def create_collection_if_not_exist(client, collection_name: str, **kwargs):
|
| 15 |
+
if client.has_collection(collection_name):
|
| 16 |
+
return
|
| 17 |
+
# TODO add constants for ID max length to 32
|
| 18 |
+
client.create_collection(
|
| 19 |
+
collection_name, max_length=32, id_type="string", **kwargs
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
def __post_init__(self):
|
| 23 |
+
from pymilvus import MilvusClient
|
| 24 |
+
|
| 25 |
+
self._client_file_name = os.path.join(
|
| 26 |
+
self.global_config["working_dir"], "milvus_lite.db"
|
| 27 |
+
)
|
| 28 |
+
self._client = MilvusClient(self._client_file_name)
|
| 29 |
+
self._max_batch_size = self.global_config["embedding_batch_num"]
|
| 30 |
+
MilvusLiteStorge.create_collection_if_not_exist(
|
| 31 |
+
self._client,
|
| 32 |
+
self.namespace,
|
| 33 |
+
dimension=self.embedding_func.embedding_dim,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
async def upsert(self, data: dict[str, dict]):
|
| 37 |
+
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
| 38 |
+
list_data = [
|
| 39 |
+
{
|
| 40 |
+
"id": k,
|
| 41 |
+
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
|
| 42 |
+
}
|
| 43 |
+
for k, v in data.items()
|
| 44 |
+
]
|
| 45 |
+
contents = [v["content"] for v in data.values()]
|
| 46 |
+
batches = [
|
| 47 |
+
contents[i : i + self._max_batch_size]
|
| 48 |
+
for i in range(0, len(contents), self._max_batch_size)
|
| 49 |
+
]
|
| 50 |
+
embeddings_list = await asyncio.gather(
|
| 51 |
+
*[self.embedding_func(batch) for batch in batches]
|
| 52 |
+
)
|
| 53 |
+
embeddings = np.concatenate(embeddings_list)
|
| 54 |
+
for i, d in enumerate(list_data):
|
| 55 |
+
d["vector"] = embeddings[i]
|
| 56 |
+
results = self._client.upsert(collection_name=self.namespace, data=list_data)
|
| 57 |
+
return results
|
| 58 |
+
|
| 59 |
+
async def query(self, query, top_k=5):
|
| 60 |
+
embedding = await self.embedding_func([query])
|
| 61 |
+
results = self._client.search(
|
| 62 |
+
collection_name=self.namespace,
|
| 63 |
+
data=embedding,
|
| 64 |
+
limit=top_k,
|
| 65 |
+
output_fields=list(self.meta_fields),
|
| 66 |
+
search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
|
| 67 |
+
)
|
| 68 |
+
return [
|
| 69 |
+
{**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
|
| 70 |
+
for dp in results[0]
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def insert():
|
| 75 |
+
data = ["YOUR TEXT DATA HERE", "YOUR TEXT DATA HERE"]
|
| 76 |
+
rag = GraphRAG(
|
| 77 |
+
working_dir="./nano_graphrag_cache_milvus_TEST",
|
| 78 |
+
enable_llm_cache=True,
|
| 79 |
+
vector_db_storage_cls=MilvusLiteStorge,
|
| 80 |
+
)
|
| 81 |
+
rag.insert(data)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def query():
|
| 85 |
+
rag = GraphRAG(
|
| 86 |
+
working_dir="./nano_graphrag_cache_milvus_TEST",
|
| 87 |
+
enable_llm_cache=True,
|
| 88 |
+
vector_db_storage_cls=MilvusLiteStorge,
|
| 89 |
+
)
|
| 90 |
+
print(rag.query("YOUR QUERY HERE", param=QueryParam(mode="local")))
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
insert()
|
| 94 |
+
query()
|
nano-graphrag/examples/using_ollama_as_llm.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import ollama
|
| 4 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 5 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 6 |
+
from nano_graphrag.base import BaseKVStorage
|
| 7 |
+
from nano_graphrag._utils import compute_args_hash
|
| 8 |
+
|
| 9 |
+
logging.basicConfig(level=logging.WARNING)
|
| 10 |
+
logging.getLogger("nano-graphrag").setLevel(logging.INFO)
|
| 11 |
+
|
| 12 |
+
# !!! qwen2-7B maybe produce unparsable results and cause the extraction of graph to fail.
|
| 13 |
+
MODEL = "qwen2"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
async def ollama_model_if_cache(
|
| 17 |
+
prompt, system_prompt=None, history_messages=[], **kwargs
|
| 18 |
+
) -> str:
|
| 19 |
+
# remove kwargs that are not supported by ollama
|
| 20 |
+
kwargs.pop("max_tokens", None)
|
| 21 |
+
kwargs.pop("response_format", None)
|
| 22 |
+
|
| 23 |
+
ollama_client = ollama.AsyncClient()
|
| 24 |
+
messages = []
|
| 25 |
+
if system_prompt:
|
| 26 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 27 |
+
|
| 28 |
+
# Get the cached response if having-------------------
|
| 29 |
+
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
| 30 |
+
messages.extend(history_messages)
|
| 31 |
+
messages.append({"role": "user", "content": prompt})
|
| 32 |
+
if hashing_kv is not None:
|
| 33 |
+
args_hash = compute_args_hash(MODEL, messages)
|
| 34 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
| 35 |
+
if if_cache_return is not None:
|
| 36 |
+
return if_cache_return["return"]
|
| 37 |
+
# -----------------------------------------------------
|
| 38 |
+
response = await ollama_client.chat(model=MODEL, messages=messages, **kwargs)
|
| 39 |
+
|
| 40 |
+
result = response["message"]["content"]
|
| 41 |
+
# Cache the response if having-------------------
|
| 42 |
+
if hashing_kv is not None:
|
| 43 |
+
await hashing_kv.upsert({args_hash: {"return": result, "model": MODEL}})
|
| 44 |
+
# -----------------------------------------------------
|
| 45 |
+
return result
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def remove_if_exist(file):
|
| 49 |
+
if os.path.exists(file):
|
| 50 |
+
os.remove(file)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
WORKING_DIR = "./nano_graphrag_cache_ollama_TEST"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def query():
|
| 57 |
+
rag = GraphRAG(
|
| 58 |
+
working_dir=WORKING_DIR,
|
| 59 |
+
best_model_func=ollama_model_if_cache,
|
| 60 |
+
cheap_model_func=ollama_model_if_cache,
|
| 61 |
+
)
|
| 62 |
+
print(
|
| 63 |
+
rag.query(
|
| 64 |
+
"What are the top themes in this story?", param=QueryParam(mode="global")
|
| 65 |
+
)
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def insert():
|
| 70 |
+
from time import time
|
| 71 |
+
|
| 72 |
+
with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
|
| 73 |
+
FAKE_TEXT = f.read()
|
| 74 |
+
|
| 75 |
+
remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
|
| 76 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
|
| 77 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
|
| 78 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
|
| 79 |
+
remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
|
| 80 |
+
|
| 81 |
+
rag = GraphRAG(
|
| 82 |
+
working_dir=WORKING_DIR,
|
| 83 |
+
enable_llm_cache=True,
|
| 84 |
+
best_model_func=ollama_model_if_cache,
|
| 85 |
+
cheap_model_func=ollama_model_if_cache,
|
| 86 |
+
)
|
| 87 |
+
start = time()
|
| 88 |
+
rag.insert(FAKE_TEXT)
|
| 89 |
+
print("indexing time:", time() - start)
|
| 90 |
+
# rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
|
| 91 |
+
# rag.insert(FAKE_TEXT[half_len:])
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
if __name__ == "__main__":
|
| 95 |
+
insert()
|
| 96 |
+
query()
|
nano-graphrag/examples/using_ollama_as_llm_and_embedding.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
sys.path.append("..")
|
| 5 |
+
import logging
|
| 6 |
+
import ollama
|
| 7 |
+
import numpy as np
|
| 8 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 9 |
+
from nano_graphrag.base import BaseKVStorage
|
| 10 |
+
from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
|
| 11 |
+
|
| 12 |
+
logging.basicConfig(level=logging.WARNING)
|
| 13 |
+
logging.getLogger("nano-graphrag").setLevel(logging.INFO)
|
| 14 |
+
|
| 15 |
+
# Assumed llm model settings
|
| 16 |
+
MODEL = "your_model_name"
|
| 17 |
+
|
| 18 |
+
# Assumed embedding model settings
|
| 19 |
+
EMBEDDING_MODEL = "nomic-embed-text"
|
| 20 |
+
EMBEDDING_MODEL_DIM = 768
|
| 21 |
+
EMBEDDING_MODEL_MAX_TOKENS = 8192
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
async def ollama_model_if_cache(
|
| 25 |
+
prompt, system_prompt=None, history_messages=[], **kwargs
|
| 26 |
+
) -> str:
|
| 27 |
+
# remove kwargs that are not supported by ollama
|
| 28 |
+
kwargs.pop("max_tokens", None)
|
| 29 |
+
kwargs.pop("response_format", None)
|
| 30 |
+
|
| 31 |
+
ollama_client = ollama.AsyncClient()
|
| 32 |
+
messages = []
|
| 33 |
+
if system_prompt:
|
| 34 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 35 |
+
|
| 36 |
+
# Get the cached response if having-------------------
|
| 37 |
+
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
| 38 |
+
messages.extend(history_messages)
|
| 39 |
+
messages.append({"role": "user", "content": prompt})
|
| 40 |
+
if hashing_kv is not None:
|
| 41 |
+
args_hash = compute_args_hash(MODEL, messages)
|
| 42 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
| 43 |
+
if if_cache_return is not None:
|
| 44 |
+
return if_cache_return["return"]
|
| 45 |
+
# -----------------------------------------------------
|
| 46 |
+
response = await ollama_client.chat(model=MODEL, messages=messages, **kwargs)
|
| 47 |
+
|
| 48 |
+
result = response["message"]["content"]
|
| 49 |
+
# Cache the response if having-------------------
|
| 50 |
+
if hashing_kv is not None:
|
| 51 |
+
await hashing_kv.upsert({args_hash: {"return": result, "model": MODEL}})
|
| 52 |
+
# -----------------------------------------------------
|
| 53 |
+
return result
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def remove_if_exist(file):
|
| 57 |
+
if os.path.exists(file):
|
| 58 |
+
os.remove(file)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
WORKING_DIR = "./nano_graphrag_cache_ollama_TEST"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def query():
|
| 65 |
+
rag = GraphRAG(
|
| 66 |
+
working_dir=WORKING_DIR,
|
| 67 |
+
best_model_func=ollama_model_if_cache,
|
| 68 |
+
cheap_model_func=ollama_model_if_cache,
|
| 69 |
+
embedding_func=ollama_embedding,
|
| 70 |
+
)
|
| 71 |
+
print(
|
| 72 |
+
rag.query(
|
| 73 |
+
"What are the top themes in this story?", param=QueryParam(mode="global")
|
| 74 |
+
)
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def insert():
|
| 79 |
+
from time import time
|
| 80 |
+
|
| 81 |
+
with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
|
| 82 |
+
FAKE_TEXT = f.read()
|
| 83 |
+
|
| 84 |
+
remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
|
| 85 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
|
| 86 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
|
| 87 |
+
remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
|
| 88 |
+
remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
|
| 89 |
+
|
| 90 |
+
rag = GraphRAG(
|
| 91 |
+
working_dir=WORKING_DIR,
|
| 92 |
+
enable_llm_cache=True,
|
| 93 |
+
best_model_func=ollama_model_if_cache,
|
| 94 |
+
cheap_model_func=ollama_model_if_cache,
|
| 95 |
+
embedding_func=ollama_embedding,
|
| 96 |
+
)
|
| 97 |
+
start = time()
|
| 98 |
+
rag.insert(FAKE_TEXT)
|
| 99 |
+
print("indexing time:", time() - start)
|
| 100 |
+
# rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
|
| 101 |
+
# rag.insert(FAKE_TEXT[half_len:])
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# We're using Ollama to generate embeddings for the BGE model
|
| 105 |
+
@wrap_embedding_func_with_attrs(
|
| 106 |
+
embedding_dim=EMBEDDING_MODEL_DIM,
|
| 107 |
+
max_token_size=EMBEDDING_MODEL_MAX_TOKENS,
|
| 108 |
+
)
|
| 109 |
+
async def ollama_embedding(texts: list[str]) -> np.ndarray:
|
| 110 |
+
embed_text = []
|
| 111 |
+
for text in texts:
|
| 112 |
+
data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
|
| 113 |
+
embed_text.append(data["embedding"])
|
| 114 |
+
|
| 115 |
+
return embed_text
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
if __name__ == "__main__":
|
| 119 |
+
insert()
|
| 120 |
+
query()
|
nano-graphrag/examples/using_qdrant_as_vectorDB.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import asyncio
|
| 3 |
+
import uuid
|
| 4 |
+
import numpy as np
|
| 5 |
+
from nano_graphrag import GraphRAG, QueryParam
|
| 6 |
+
from nano_graphrag._utils import logger
|
| 7 |
+
from nano_graphrag.base import BaseVectorStorage
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from qdrant_client import QdrantClient
|
| 12 |
+
from qdrant_client.models import VectorParams, Distance, PointStruct, SearchParams
|
| 13 |
+
except ImportError as original_error:
|
| 14 |
+
raise ImportError(
|
| 15 |
+
"Qdrant client is not installed. Install it using: pip install qdrant-client\n"
|
| 16 |
+
) from original_error
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class QdrantStorage(BaseVectorStorage):
|
| 21 |
+
def __post_init__(self):
|
| 22 |
+
|
| 23 |
+
# Use a local file-based Qdrant storage
|
| 24 |
+
# Useful for prototyping and CI.
|
| 25 |
+
# For production, refer to:
|
| 26 |
+
# https://qdrant.tech/documentation/guides/installation/
|
| 27 |
+
self._client_file_path = os.path.join(
|
| 28 |
+
self.global_config["working_dir"], "qdrant_storage"
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
self._client = QdrantClient(path=self._client_file_path)
|
| 32 |
+
|
| 33 |
+
self._max_batch_size = self.global_config["embedding_batch_num"]
|
| 34 |
+
|
| 35 |
+
if not self._client.collection_exists(collection_name=self.namespace):
|
| 36 |
+
self._client.create_collection(
|
| 37 |
+
collection_name=self.namespace,
|
| 38 |
+
vectors_config=VectorParams(
|
| 39 |
+
size=self.embedding_func.embedding_dim, distance=Distance.COSINE
|
| 40 |
+
),
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
async def upsert(self, data: dict[str, dict]):
|
| 44 |
+
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
| 45 |
+
|
| 46 |
+
list_data = [
|
| 47 |
+
{
|
| 48 |
+
"id": k,
|
| 49 |
+
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
|
| 50 |
+
}
|
| 51 |
+
for k, v in data.items()
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
contents = [v["content"] for v in data.values()]
|
| 55 |
+
batches = [
|
| 56 |
+
contents[i : i + self._max_batch_size]
|
| 57 |
+
for i in range(0, len(contents), self._max_batch_size)
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
embeddings_list = await asyncio.gather(
|
| 61 |
+
*[self.embedding_func(batch) for batch in batches]
|
| 62 |
+
)
|
| 63 |
+
embeddings = np.concatenate(embeddings_list)
|
| 64 |
+
|
| 65 |
+
points = [
|
| 66 |
+
PointStruct(
|
| 67 |
+
id=uuid.uuid4().hex,
|
| 68 |
+
vector=embeddings[i].tolist(),
|
| 69 |
+
payload=data,
|
| 70 |
+
)
|
| 71 |
+
for i, data in enumerate(list_data)
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
results = self._client.upsert(collection_name=self.namespace, points=points)
|
| 75 |
+
return results
|
| 76 |
+
|
| 77 |
+
async def query(self, query, top_k=5):
|
| 78 |
+
embedding = await self.embedding_func([query])
|
| 79 |
+
|
| 80 |
+
results = self._client.query_points(
|
| 81 |
+
collection_name=self.namespace,
|
| 82 |
+
query=embedding[0].tolist(),
|
| 83 |
+
limit=top_k,
|
| 84 |
+
).points
|
| 85 |
+
|
| 86 |
+
return [
|
| 87 |
+
{**result.payload, "score": result.score}
|
| 88 |
+
for result in results
|
| 89 |
+
]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def insert():
|
| 93 |
+
data = ["YOUR TEXT DATA HERE", "YOUR TEXT DATA HERE"]
|
| 94 |
+
rag = GraphRAG(
|
| 95 |
+
working_dir="./nano_graphrag_cache_qdrant_TEST",
|
| 96 |
+
enable_llm_cache=True,
|
| 97 |
+
vector_db_storage_cls=QdrantStorage,
|
| 98 |
+
)
|
| 99 |
+
rag.insert(data)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def query():
|
| 103 |
+
rag = GraphRAG(
|
| 104 |
+
working_dir="./nano_graphrag_cache_qdrant_TEST",
|
| 105 |
+
enable_llm_cache=True,
|
| 106 |
+
vector_db_storage_cls=QdrantStorage,
|
| 107 |
+
)
|
| 108 |
+
print(rag.query("YOUR QUERY HERE", param=QueryParam(mode="local")))
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
if __name__ == "__main__":
|
| 112 |
+
insert()
|
| 113 |
+
query()
|
nano-graphrag/nano_graphrag/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .graphrag import GraphRAG, QueryParam
|
| 2 |
+
|
| 3 |
+
__version__ = "0.0.8.2"
|
| 4 |
+
__author__ = "Jianbai Ye"
|
| 5 |
+
__url__ = "https://github.com/gusye1234/nano-graphrag"
|
| 6 |
+
|
| 7 |
+
# dp stands for data pack
|
nano-graphrag/nano_graphrag/_llm.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Optional, List, Any, Callable
|
| 4 |
+
|
| 5 |
+
import aioboto3
|
| 6 |
+
from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
|
| 7 |
+
|
| 8 |
+
from tenacity import (
|
| 9 |
+
retry,
|
| 10 |
+
stop_after_attempt,
|
| 11 |
+
wait_exponential,
|
| 12 |
+
retry_if_exception_type,
|
| 13 |
+
)
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
from ._utils import compute_args_hash, wrap_embedding_func_with_attrs
|
| 17 |
+
from .base import BaseKVStorage
|
| 18 |
+
|
| 19 |
+
global_openai_async_client = None
|
| 20 |
+
global_azure_openai_async_client = None
|
| 21 |
+
global_amazon_bedrock_async_client = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_openai_async_client_instance():
|
| 25 |
+
global global_openai_async_client
|
| 26 |
+
if global_openai_async_client is None:
|
| 27 |
+
global_openai_async_client = AsyncOpenAI()
|
| 28 |
+
return global_openai_async_client
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_azure_openai_async_client_instance():
|
| 32 |
+
global global_azure_openai_async_client
|
| 33 |
+
if global_azure_openai_async_client is None:
|
| 34 |
+
global_azure_openai_async_client = AsyncAzureOpenAI()
|
| 35 |
+
return global_azure_openai_async_client
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_amazon_bedrock_async_client_instance():
|
| 39 |
+
global global_amazon_bedrock_async_client
|
| 40 |
+
if global_amazon_bedrock_async_client is None:
|
| 41 |
+
global_amazon_bedrock_async_client = aioboto3.Session()
|
| 42 |
+
return global_amazon_bedrock_async_client
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@retry(
|
| 46 |
+
stop=stop_after_attempt(5),
|
| 47 |
+
wait=wait_exponential(multiplier=1, min=4, max=10),
|
| 48 |
+
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
|
| 49 |
+
)
|
| 50 |
+
async def openai_complete_if_cache(
|
| 51 |
+
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
| 52 |
+
) -> str:
|
| 53 |
+
openai_async_client = get_openai_async_client_instance()
|
| 54 |
+
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
| 55 |
+
messages = []
|
| 56 |
+
if system_prompt:
|
| 57 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 58 |
+
messages.extend(history_messages)
|
| 59 |
+
messages.append({"role": "user", "content": prompt})
|
| 60 |
+
if hashing_kv is not None:
|
| 61 |
+
args_hash = compute_args_hash(model, messages)
|
| 62 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
| 63 |
+
if if_cache_return is not None:
|
| 64 |
+
return if_cache_return["return"]
|
| 65 |
+
|
| 66 |
+
response = await openai_async_client.chat.completions.create(
|
| 67 |
+
model=model, messages=messages, **kwargs
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
if hashing_kv is not None:
|
| 71 |
+
await hashing_kv.upsert(
|
| 72 |
+
{args_hash: {"return": response.choices[0].message.content, "model": model}}
|
| 73 |
+
)
|
| 74 |
+
await hashing_kv.index_done_callback()
|
| 75 |
+
return response.choices[0].message.content
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@retry(
|
| 79 |
+
stop=stop_after_attempt(5),
|
| 80 |
+
wait=wait_exponential(multiplier=1, min=4, max=10),
|
| 81 |
+
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
|
| 82 |
+
)
|
| 83 |
+
async def amazon_bedrock_complete_if_cache(
|
| 84 |
+
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
| 85 |
+
) -> str:
|
| 86 |
+
amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance()
|
| 87 |
+
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
| 88 |
+
messages = []
|
| 89 |
+
messages.extend(history_messages)
|
| 90 |
+
messages.append({"role": "user", "content": [{"text": prompt}]})
|
| 91 |
+
if hashing_kv is not None:
|
| 92 |
+
args_hash = compute_args_hash(model, messages)
|
| 93 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
| 94 |
+
if if_cache_return is not None:
|
| 95 |
+
return if_cache_return["return"]
|
| 96 |
+
|
| 97 |
+
inference_config = {
|
| 98 |
+
"temperature": 0,
|
| 99 |
+
"maxTokens": 4096 if "max_tokens" not in kwargs else kwargs["max_tokens"],
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
async with amazon_bedrock_async_client.client(
|
| 103 |
+
"bedrock-runtime",
|
| 104 |
+
region_name=os.getenv("AWS_REGION", "us-east-1")
|
| 105 |
+
) as bedrock_runtime:
|
| 106 |
+
if system_prompt:
|
| 107 |
+
response = await bedrock_runtime.converse(
|
| 108 |
+
modelId=model, messages=messages, inferenceConfig=inference_config,
|
| 109 |
+
system=[{"text": system_prompt}]
|
| 110 |
+
)
|
| 111 |
+
else:
|
| 112 |
+
response = await bedrock_runtime.converse(
|
| 113 |
+
modelId=model, messages=messages, inferenceConfig=inference_config,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
if hashing_kv is not None:
|
| 117 |
+
await hashing_kv.upsert(
|
| 118 |
+
{args_hash: {"return": response["output"]["message"]["content"][0]["text"], "model": model}}
|
| 119 |
+
)
|
| 120 |
+
await hashing_kv.index_done_callback()
|
| 121 |
+
return response["output"]["message"]["content"][0]["text"]
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def create_amazon_bedrock_complete_function(model_id: str) -> Callable:
|
| 125 |
+
"""
|
| 126 |
+
Factory function to dynamically create completion functions for Amazon Bedrock
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
model_id (str): Amazon Bedrock model identifier (e.g., "us.anthropic.claude-3-sonnet-20240229-v1:0")
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
Callable: Generated completion function
|
| 133 |
+
"""
|
| 134 |
+
async def bedrock_complete(
|
| 135 |
+
prompt: str,
|
| 136 |
+
system_prompt: Optional[str] = None,
|
| 137 |
+
history_messages: List[Any] = [],
|
| 138 |
+
**kwargs
|
| 139 |
+
) -> str:
|
| 140 |
+
return await amazon_bedrock_complete_if_cache(
|
| 141 |
+
model_id,
|
| 142 |
+
prompt,
|
| 143 |
+
system_prompt=system_prompt,
|
| 144 |
+
history_messages=history_messages,
|
| 145 |
+
**kwargs
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# Set function name for easier debugging
|
| 149 |
+
bedrock_complete.__name__ = f"{model_id}_complete"
|
| 150 |
+
|
| 151 |
+
return bedrock_complete
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
async def gpt_4o_complete(
|
| 155 |
+
prompt, system_prompt=None, history_messages=[], **kwargs
|
| 156 |
+
) -> str:
|
| 157 |
+
return await openai_complete_if_cache(
|
| 158 |
+
"gpt-4o",
|
| 159 |
+
prompt,
|
| 160 |
+
system_prompt=system_prompt,
|
| 161 |
+
history_messages=history_messages,
|
| 162 |
+
**kwargs,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
async def gpt_4o_mini_complete(
|
| 167 |
+
prompt, system_prompt=None, history_messages=[], **kwargs
|
| 168 |
+
) -> str:
|
| 169 |
+
return await openai_complete_if_cache(
|
| 170 |
+
"gpt-4o-mini",
|
| 171 |
+
prompt,
|
| 172 |
+
system_prompt=system_prompt,
|
| 173 |
+
history_messages=history_messages,
|
| 174 |
+
**kwargs,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
| 179 |
+
@retry(
|
| 180 |
+
stop=stop_after_attempt(5),
|
| 181 |
+
wait=wait_exponential(multiplier=1, min=4, max=10),
|
| 182 |
+
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
|
| 183 |
+
)
|
| 184 |
+
async def amazon_bedrock_embedding(texts: list[str]) -> np.ndarray:
|
| 185 |
+
amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance()
|
| 186 |
+
|
| 187 |
+
async with amazon_bedrock_async_client.client(
|
| 188 |
+
"bedrock-runtime",
|
| 189 |
+
region_name=os.getenv("AWS_REGION", "us-east-1")
|
| 190 |
+
) as bedrock_runtime:
|
| 191 |
+
embeddings = []
|
| 192 |
+
for text in texts:
|
| 193 |
+
body = json.dumps(
|
| 194 |
+
{
|
| 195 |
+
"inputText": text,
|
| 196 |
+
"dimensions": 1024,
|
| 197 |
+
}
|
| 198 |
+
)
|
| 199 |
+
response = await bedrock_runtime.invoke_model(
|
| 200 |
+
modelId="amazon.titan-embed-text-v2:0", body=body,
|
| 201 |
+
)
|
| 202 |
+
response_body = await response.get("body").read()
|
| 203 |
+
embeddings.append(json.loads(response_body))
|
| 204 |
+
return np.array([dp["embedding"] for dp in embeddings])
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
| 208 |
+
@retry(
|
| 209 |
+
stop=stop_after_attempt(5),
|
| 210 |
+
wait=wait_exponential(multiplier=1, min=4, max=10),
|
| 211 |
+
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
|
| 212 |
+
)
|
| 213 |
+
async def openai_embedding(texts: list[str]) -> np.ndarray:
|
| 214 |
+
openai_async_client = get_openai_async_client_instance()
|
| 215 |
+
response = await openai_async_client.embeddings.create(
|
| 216 |
+
model="text-embedding-3-small", input=texts, encoding_format="float"
|
| 217 |
+
)
|
| 218 |
+
return np.array([dp.embedding for dp in response.data])
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
@retry(
|
| 222 |
+
stop=stop_after_attempt(3),
|
| 223 |
+
wait=wait_exponential(multiplier=1, min=4, max=10),
|
| 224 |
+
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
|
| 225 |
+
)
|
| 226 |
+
async def azure_openai_complete_if_cache(
|
| 227 |
+
deployment_name, prompt, system_prompt=None, history_messages=[], **kwargs
|
| 228 |
+
) -> str:
|
| 229 |
+
azure_openai_client = get_azure_openai_async_client_instance()
|
| 230 |
+
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
| 231 |
+
messages = []
|
| 232 |
+
if system_prompt:
|
| 233 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 234 |
+
messages.extend(history_messages)
|
| 235 |
+
messages.append({"role": "user", "content": prompt})
|
| 236 |
+
if hashing_kv is not None:
|
| 237 |
+
args_hash = compute_args_hash(deployment_name, messages)
|
| 238 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
| 239 |
+
if if_cache_return is not None:
|
| 240 |
+
return if_cache_return["return"]
|
| 241 |
+
|
| 242 |
+
response = await azure_openai_client.chat.completions.create(
|
| 243 |
+
model=deployment_name, messages=messages, **kwargs
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
if hashing_kv is not None:
|
| 247 |
+
await hashing_kv.upsert(
|
| 248 |
+
{
|
| 249 |
+
args_hash: {
|
| 250 |
+
"return": response.choices[0].message.content,
|
| 251 |
+
"model": deployment_name,
|
| 252 |
+
}
|
| 253 |
+
}
|
| 254 |
+
)
|
| 255 |
+
await hashing_kv.index_done_callback()
|
| 256 |
+
return response.choices[0].message.content
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
async def azure_gpt_4o_complete(
|
| 260 |
+
prompt, system_prompt=None, history_messages=[], **kwargs
|
| 261 |
+
) -> str:
|
| 262 |
+
return await azure_openai_complete_if_cache(
|
| 263 |
+
"gpt-4o",
|
| 264 |
+
prompt,
|
| 265 |
+
system_prompt=system_prompt,
|
| 266 |
+
history_messages=history_messages,
|
| 267 |
+
**kwargs,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
async def azure_gpt_4o_mini_complete(
|
| 272 |
+
prompt, system_prompt=None, history_messages=[], **kwargs
|
| 273 |
+
) -> str:
|
| 274 |
+
return await azure_openai_complete_if_cache(
|
| 275 |
+
"gpt-4o-mini",
|
| 276 |
+
prompt,
|
| 277 |
+
system_prompt=system_prompt,
|
| 278 |
+
history_messages=history_messages,
|
| 279 |
+
**kwargs,
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
| 284 |
+
@retry(
|
| 285 |
+
stop=stop_after_attempt(3),
|
| 286 |
+
wait=wait_exponential(multiplier=1, min=4, max=10),
|
| 287 |
+
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
|
| 288 |
+
)
|
| 289 |
+
async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
|
| 290 |
+
azure_openai_client = get_azure_openai_async_client_instance()
|
| 291 |
+
response = await azure_openai_client.embeddings.create(
|
| 292 |
+
model="text-embedding-3-small", input=texts, encoding_format="float"
|
| 293 |
+
)
|
| 294 |
+
return np.array([dp.embedding for dp in response.data])
|
nano-graphrag/nano_graphrag/_op.py
ADDED
|
@@ -0,0 +1,1140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import json
|
| 3 |
+
import asyncio
|
| 4 |
+
from typing import Union
|
| 5 |
+
from collections import Counter, defaultdict
|
| 6 |
+
from ._splitter import SeparatorSplitter
|
| 7 |
+
from ._utils import (
|
| 8 |
+
logger,
|
| 9 |
+
clean_str,
|
| 10 |
+
compute_mdhash_id,
|
| 11 |
+
is_float_regex,
|
| 12 |
+
list_of_list_to_csv,
|
| 13 |
+
pack_user_ass_to_openai_messages,
|
| 14 |
+
split_string_by_multi_markers,
|
| 15 |
+
truncate_list_by_token_size,
|
| 16 |
+
|
| 17 |
+
TokenizerWrapper
|
| 18 |
+
)
|
| 19 |
+
from .base import (
|
| 20 |
+
BaseGraphStorage,
|
| 21 |
+
BaseKVStorage,
|
| 22 |
+
BaseVectorStorage,
|
| 23 |
+
SingleCommunitySchema,
|
| 24 |
+
CommunitySchema,
|
| 25 |
+
TextChunkSchema,
|
| 26 |
+
QueryParam,
|
| 27 |
+
)
|
| 28 |
+
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def chunking_by_token_size(
|
| 32 |
+
tokens_list: list[list[int]],
|
| 33 |
+
doc_keys,
|
| 34 |
+
tokenizer_wrapper: TokenizerWrapper,
|
| 35 |
+
overlap_token_size=128,
|
| 36 |
+
max_token_size=1024,
|
| 37 |
+
):
|
| 38 |
+
results = []
|
| 39 |
+
for index, tokens in enumerate(tokens_list):
|
| 40 |
+
chunk_token = []
|
| 41 |
+
lengths = []
|
| 42 |
+
for start in range(0, len(tokens), max_token_size - overlap_token_size):
|
| 43 |
+
chunk_token.append(tokens[start : start + max_token_size])
|
| 44 |
+
lengths.append(min(max_token_size, len(tokens) - start))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
chunk_texts = tokenizer_wrapper.decode_batch(chunk_token)
|
| 48 |
+
|
| 49 |
+
for i, chunk in enumerate(chunk_texts):
|
| 50 |
+
results.append(
|
| 51 |
+
{
|
| 52 |
+
"tokens": lengths[i],
|
| 53 |
+
"content": chunk.strip(),
|
| 54 |
+
"chunk_order_index": i,
|
| 55 |
+
"full_doc_id": doc_keys[index],
|
| 56 |
+
}
|
| 57 |
+
)
|
| 58 |
+
return results
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def chunking_by_seperators(
|
| 62 |
+
tokens_list: list[list[int]],
|
| 63 |
+
doc_keys,
|
| 64 |
+
tokenizer_wrapper: TokenizerWrapper,
|
| 65 |
+
overlap_token_size=128,
|
| 66 |
+
max_token_size=1024,
|
| 67 |
+
):
|
| 68 |
+
from .prompt import PROMPTS
|
| 69 |
+
# *** 修改 ***: 直接使用 wrapper 编码,而不是获取底层 tokenizer
|
| 70 |
+
separators = [tokenizer_wrapper.encode(s) for s in PROMPTS["default_text_separator"]]
|
| 71 |
+
splitter = SeparatorSplitter(
|
| 72 |
+
separators=separators,
|
| 73 |
+
chunk_size=max_token_size,
|
| 74 |
+
chunk_overlap=overlap_token_size,
|
| 75 |
+
)
|
| 76 |
+
results = []
|
| 77 |
+
for index, tokens in enumerate(tokens_list):
|
| 78 |
+
chunk_tokens = splitter.split_tokens(tokens)
|
| 79 |
+
lengths = [len(c) for c in chunk_tokens]
|
| 80 |
+
|
| 81 |
+
decoded_chunks = tokenizer_wrapper.decode_batch(chunk_tokens)
|
| 82 |
+
for i, chunk in enumerate(decoded_chunks):
|
| 83 |
+
results.append(
|
| 84 |
+
{
|
| 85 |
+
"tokens": lengths[i],
|
| 86 |
+
"content": chunk.strip(),
|
| 87 |
+
"chunk_order_index": i,
|
| 88 |
+
"full_doc_id": doc_keys[index],
|
| 89 |
+
}
|
| 90 |
+
)
|
| 91 |
+
return results
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_chunks(new_docs, chunk_func=chunking_by_token_size, tokenizer_wrapper: TokenizerWrapper = None, **chunk_func_params):
|
| 95 |
+
inserting_chunks = {}
|
| 96 |
+
new_docs_list = list(new_docs.items())
|
| 97 |
+
docs = [new_doc[1]["content"] for new_doc in new_docs_list]
|
| 98 |
+
doc_keys = [new_doc[0] for new_doc in new_docs_list]
|
| 99 |
+
|
| 100 |
+
tokens = [tokenizer_wrapper.encode(doc) for doc in docs]
|
| 101 |
+
chunks = chunk_func(
|
| 102 |
+
tokens, doc_keys=doc_keys, tokenizer_wrapper=tokenizer_wrapper, overlap_token_size=chunk_func_params.get("overlap_token_size", 128), max_token_size=chunk_func_params.get("max_token_size", 1024)
|
| 103 |
+
)
|
| 104 |
+
for chunk in chunks:
|
| 105 |
+
inserting_chunks.update(
|
| 106 |
+
{compute_mdhash_id(chunk["content"], prefix="chunk-"): chunk}
|
| 107 |
+
)
|
| 108 |
+
return inserting_chunks
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
async def _handle_entity_relation_summary(
|
| 112 |
+
entity_or_relation_name: str,
|
| 113 |
+
description: str,
|
| 114 |
+
global_config: dict,
|
| 115 |
+
tokenizer_wrapper: TokenizerWrapper,
|
| 116 |
+
) -> str:
|
| 117 |
+
use_llm_func: callable = global_config["cheap_model_func"]
|
| 118 |
+
llm_max_tokens = global_config["cheap_model_max_token_size"]
|
| 119 |
+
summary_max_tokens = global_config["entity_summary_to_max_tokens"]
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
tokens = tokenizer_wrapper.encode(description)
|
| 123 |
+
if len(tokens) < summary_max_tokens:
|
| 124 |
+
return description
|
| 125 |
+
prompt_template = PROMPTS["summarize_entity_descriptions"]
|
| 126 |
+
|
| 127 |
+
use_description = tokenizer_wrapper.decode(tokens[:llm_max_tokens])
|
| 128 |
+
context_base = dict(
|
| 129 |
+
entity_name=entity_or_relation_name,
|
| 130 |
+
description_list=use_description.split(GRAPH_FIELD_SEP),
|
| 131 |
+
)
|
| 132 |
+
use_prompt = prompt_template.format(**context_base)
|
| 133 |
+
logger.debug(f"Trigger summary: {entity_or_relation_name}")
|
| 134 |
+
summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
|
| 135 |
+
return summary
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
async def _handle_single_entity_extraction(
|
| 139 |
+
record_attributes: list[str],
|
| 140 |
+
chunk_key: str,
|
| 141 |
+
):
|
| 142 |
+
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
|
| 143 |
+
return None
|
| 144 |
+
# add this record as a node in the G
|
| 145 |
+
entity_name = clean_str(record_attributes[1].upper())
|
| 146 |
+
if not entity_name.strip():
|
| 147 |
+
return None
|
| 148 |
+
entity_type = clean_str(record_attributes[2].upper())
|
| 149 |
+
entity_description = clean_str(record_attributes[3])
|
| 150 |
+
entity_source_id = chunk_key
|
| 151 |
+
return dict(
|
| 152 |
+
entity_name=entity_name,
|
| 153 |
+
entity_type=entity_type,
|
| 154 |
+
description=entity_description,
|
| 155 |
+
source_id=entity_source_id,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
async def _handle_single_relationship_extraction(
|
| 160 |
+
record_attributes: list[str],
|
| 161 |
+
chunk_key: str,
|
| 162 |
+
):
|
| 163 |
+
if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
|
| 164 |
+
return None
|
| 165 |
+
# add this record as edge
|
| 166 |
+
source = clean_str(record_attributes[1].upper())
|
| 167 |
+
target = clean_str(record_attributes[2].upper())
|
| 168 |
+
edge_description = clean_str(record_attributes[3])
|
| 169 |
+
edge_source_id = chunk_key
|
| 170 |
+
weight = (
|
| 171 |
+
float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
|
| 172 |
+
)
|
| 173 |
+
return dict(
|
| 174 |
+
src_id=source,
|
| 175 |
+
tgt_id=target,
|
| 176 |
+
weight=weight,
|
| 177 |
+
description=edge_description,
|
| 178 |
+
source_id=edge_source_id,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
async def _merge_nodes_then_upsert(
|
| 183 |
+
entity_name: str,
|
| 184 |
+
nodes_data: list[dict],
|
| 185 |
+
knwoledge_graph_inst: BaseGraphStorage,
|
| 186 |
+
global_config: dict,
|
| 187 |
+
tokenizer_wrapper,
|
| 188 |
+
):
|
| 189 |
+
already_entitiy_types = []
|
| 190 |
+
already_source_ids = []
|
| 191 |
+
already_description = []
|
| 192 |
+
|
| 193 |
+
already_node = await knwoledge_graph_inst.get_node(entity_name)
|
| 194 |
+
if already_node is not None:
|
| 195 |
+
already_entitiy_types.append(already_node["entity_type"])
|
| 196 |
+
already_source_ids.extend(
|
| 197 |
+
split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
|
| 198 |
+
)
|
| 199 |
+
already_description.append(already_node["description"])
|
| 200 |
+
|
| 201 |
+
entity_type = sorted(
|
| 202 |
+
Counter(
|
| 203 |
+
[dp["entity_type"] for dp in nodes_data] + already_entitiy_types
|
| 204 |
+
).items(),
|
| 205 |
+
key=lambda x: x[1],
|
| 206 |
+
reverse=True,
|
| 207 |
+
)[0][0]
|
| 208 |
+
description = GRAPH_FIELD_SEP.join(
|
| 209 |
+
sorted(set([dp["description"] for dp in nodes_data] + already_description))
|
| 210 |
+
)
|
| 211 |
+
source_id = GRAPH_FIELD_SEP.join(
|
| 212 |
+
set([dp["source_id"] for dp in nodes_data] + already_source_ids)
|
| 213 |
+
)
|
| 214 |
+
description = await _handle_entity_relation_summary(
|
| 215 |
+
entity_name, description, global_config, tokenizer_wrapper
|
| 216 |
+
)
|
| 217 |
+
node_data = dict(
|
| 218 |
+
entity_type=entity_type,
|
| 219 |
+
description=description,
|
| 220 |
+
source_id=source_id,
|
| 221 |
+
)
|
| 222 |
+
await knwoledge_graph_inst.upsert_node(
|
| 223 |
+
entity_name,
|
| 224 |
+
node_data=node_data,
|
| 225 |
+
)
|
| 226 |
+
node_data["entity_name"] = entity_name
|
| 227 |
+
return node_data
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
async def _merge_edges_then_upsert(
|
| 231 |
+
src_id: str,
|
| 232 |
+
tgt_id: str,
|
| 233 |
+
edges_data: list[dict],
|
| 234 |
+
knwoledge_graph_inst: BaseGraphStorage,
|
| 235 |
+
global_config: dict,
|
| 236 |
+
tokenizer_wrapper,
|
| 237 |
+
):
|
| 238 |
+
already_weights = []
|
| 239 |
+
already_source_ids = []
|
| 240 |
+
already_description = []
|
| 241 |
+
already_order = []
|
| 242 |
+
if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
|
| 243 |
+
already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
|
| 244 |
+
already_weights.append(already_edge["weight"])
|
| 245 |
+
already_source_ids.extend(
|
| 246 |
+
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
|
| 247 |
+
)
|
| 248 |
+
already_description.append(already_edge["description"])
|
| 249 |
+
already_order.append(already_edge.get("order", 1))
|
| 250 |
+
|
| 251 |
+
# [numberchiffre]: `Relationship.order` is only returned from DSPy's predictions
|
| 252 |
+
order = min([dp.get("order", 1) for dp in edges_data] + already_order)
|
| 253 |
+
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
|
| 254 |
+
description = GRAPH_FIELD_SEP.join(
|
| 255 |
+
sorted(set([dp["description"] for dp in edges_data] + already_description))
|
| 256 |
+
)
|
| 257 |
+
source_id = GRAPH_FIELD_SEP.join(
|
| 258 |
+
set([dp["source_id"] for dp in edges_data] + already_source_ids)
|
| 259 |
+
)
|
| 260 |
+
for need_insert_id in [src_id, tgt_id]:
|
| 261 |
+
if not (await knwoledge_graph_inst.has_node(need_insert_id)):
|
| 262 |
+
await knwoledge_graph_inst.upsert_node(
|
| 263 |
+
need_insert_id,
|
| 264 |
+
node_data={
|
| 265 |
+
"source_id": source_id,
|
| 266 |
+
"description": description,
|
| 267 |
+
"entity_type": '"UNKNOWN"',
|
| 268 |
+
},
|
| 269 |
+
)
|
| 270 |
+
description = await _handle_entity_relation_summary(
|
| 271 |
+
(src_id, tgt_id), description, global_config, tokenizer_wrapper
|
| 272 |
+
)
|
| 273 |
+
await knwoledge_graph_inst.upsert_edge(
|
| 274 |
+
src_id,
|
| 275 |
+
tgt_id,
|
| 276 |
+
edge_data=dict(
|
| 277 |
+
weight=weight, description=description, source_id=source_id, order=order
|
| 278 |
+
),
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
async def extract_entities(
|
| 283 |
+
chunks: dict[str, TextChunkSchema],
|
| 284 |
+
knwoledge_graph_inst: BaseGraphStorage,
|
| 285 |
+
entity_vdb: BaseVectorStorage,
|
| 286 |
+
tokenizer_wrapper,
|
| 287 |
+
global_config: dict,
|
| 288 |
+
using_amazon_bedrock: bool=False,
|
| 289 |
+
) -> Union[BaseGraphStorage, None]:
|
| 290 |
+
use_llm_func: callable = global_config["best_model_func"]
|
| 291 |
+
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
| 292 |
+
|
| 293 |
+
ordered_chunks = list(chunks.items())
|
| 294 |
+
|
| 295 |
+
entity_extract_prompt = PROMPTS["entity_extraction"]
|
| 296 |
+
context_base = dict(
|
| 297 |
+
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
|
| 298 |
+
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
|
| 299 |
+
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
|
| 300 |
+
entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
|
| 301 |
+
)
|
| 302 |
+
continue_prompt = PROMPTS["entiti_continue_extraction"]
|
| 303 |
+
if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
|
| 304 |
+
|
| 305 |
+
already_processed = 0
|
| 306 |
+
already_entities = 0
|
| 307 |
+
already_relations = 0
|
| 308 |
+
|
| 309 |
+
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
|
| 310 |
+
nonlocal already_processed, already_entities, already_relations
|
| 311 |
+
chunk_key = chunk_key_dp[0]
|
| 312 |
+
chunk_dp = chunk_key_dp[1]
|
| 313 |
+
content = chunk_dp["content"]
|
| 314 |
+
hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
|
| 315 |
+
final_result = await use_llm_func(hint_prompt)
|
| 316 |
+
if isinstance(final_result, list):
|
| 317 |
+
final_result = final_result[0]["text"]
|
| 318 |
+
|
| 319 |
+
history = pack_user_ass_to_openai_messages(hint_prompt, final_result, using_amazon_bedrock)
|
| 320 |
+
for now_glean_index in range(entity_extract_max_gleaning):
|
| 321 |
+
glean_result = await use_llm_func(continue_prompt, history_messages=history)
|
| 322 |
+
|
| 323 |
+
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result, using_amazon_bedrock)
|
| 324 |
+
final_result += glean_result
|
| 325 |
+
if now_glean_index == entity_extract_max_gleaning - 1:
|
| 326 |
+
break
|
| 327 |
+
|
| 328 |
+
if_loop_result: str = await use_llm_func(
|
| 329 |
+
if_loop_prompt, history_messages=history
|
| 330 |
+
)
|
| 331 |
+
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
| 332 |
+
if if_loop_result != "yes":
|
| 333 |
+
break
|
| 334 |
+
|
| 335 |
+
records = split_string_by_multi_markers(
|
| 336 |
+
final_result,
|
| 337 |
+
[context_base["record_delimiter"], context_base["completion_delimiter"]],
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
maybe_nodes = defaultdict(list)
|
| 341 |
+
maybe_edges = defaultdict(list)
|
| 342 |
+
for record in records:
|
| 343 |
+
record = re.search(r"\((.*)\)", record)
|
| 344 |
+
if record is None:
|
| 345 |
+
continue
|
| 346 |
+
record = record.group(1)
|
| 347 |
+
record_attributes = split_string_by_multi_markers(
|
| 348 |
+
record, [context_base["tuple_delimiter"]]
|
| 349 |
+
)
|
| 350 |
+
if_entities = await _handle_single_entity_extraction(
|
| 351 |
+
record_attributes, chunk_key
|
| 352 |
+
)
|
| 353 |
+
if if_entities is not None:
|
| 354 |
+
maybe_nodes[if_entities["entity_name"]].append(if_entities)
|
| 355 |
+
continue
|
| 356 |
+
|
| 357 |
+
if_relation = await _handle_single_relationship_extraction(
|
| 358 |
+
record_attributes, chunk_key
|
| 359 |
+
)
|
| 360 |
+
if if_relation is not None:
|
| 361 |
+
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
|
| 362 |
+
if_relation
|
| 363 |
+
)
|
| 364 |
+
already_processed += 1
|
| 365 |
+
already_entities += len(maybe_nodes)
|
| 366 |
+
already_relations += len(maybe_edges)
|
| 367 |
+
now_ticks = PROMPTS["process_tickers"][
|
| 368 |
+
already_processed % len(PROMPTS["process_tickers"])
|
| 369 |
+
]
|
| 370 |
+
print(
|
| 371 |
+
f"{now_ticks} Processed {already_processed}({already_processed*100//len(ordered_chunks)}%) chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
|
| 372 |
+
end="",
|
| 373 |
+
flush=True,
|
| 374 |
+
)
|
| 375 |
+
return dict(maybe_nodes), dict(maybe_edges)
|
| 376 |
+
|
| 377 |
+
# use_llm_func is wrapped in ascynio.Semaphore, limiting max_async callings
|
| 378 |
+
results = await asyncio.gather(
|
| 379 |
+
*[_process_single_content(c) for c in ordered_chunks]
|
| 380 |
+
)
|
| 381 |
+
print() # clear the progress bar
|
| 382 |
+
maybe_nodes = defaultdict(list)
|
| 383 |
+
maybe_edges = defaultdict(list)
|
| 384 |
+
for m_nodes, m_edges in results:
|
| 385 |
+
for k, v in m_nodes.items():
|
| 386 |
+
maybe_nodes[k].extend(v)
|
| 387 |
+
for k, v in m_edges.items():
|
| 388 |
+
# it's undirected graph
|
| 389 |
+
maybe_edges[tuple(sorted(k))].extend(v)
|
| 390 |
+
all_entities_data = await asyncio.gather(
|
| 391 |
+
*[
|
| 392 |
+
_merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config, tokenizer_wrapper)
|
| 393 |
+
for k, v in maybe_nodes.items()
|
| 394 |
+
]
|
| 395 |
+
)
|
| 396 |
+
await asyncio.gather(
|
| 397 |
+
*[
|
| 398 |
+
_merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config, tokenizer_wrapper)
|
| 399 |
+
for k, v in maybe_edges.items()
|
| 400 |
+
]
|
| 401 |
+
)
|
| 402 |
+
if not len(all_entities_data):
|
| 403 |
+
logger.warning("Didn't extract any entities, maybe your LLM is not working")
|
| 404 |
+
return None
|
| 405 |
+
if entity_vdb is not None:
|
| 406 |
+
data_for_vdb = {
|
| 407 |
+
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
|
| 408 |
+
"content": dp["entity_name"] + dp["description"],
|
| 409 |
+
"entity_name": dp["entity_name"],
|
| 410 |
+
}
|
| 411 |
+
for dp in all_entities_data
|
| 412 |
+
}
|
| 413 |
+
await entity_vdb.upsert(data_for_vdb)
|
| 414 |
+
return knwoledge_graph_inst
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def _pack_single_community_by_sub_communities(
|
| 418 |
+
community: SingleCommunitySchema,
|
| 419 |
+
max_token_size: int,
|
| 420 |
+
already_reports: dict[str, CommunitySchema],
|
| 421 |
+
tokenizer_wrapper: TokenizerWrapper,
|
| 422 |
+
) -> tuple[str, int, set, set]:
|
| 423 |
+
all_sub_communities = [
|
| 424 |
+
already_reports[k] for k in community["sub_communities"] if k in already_reports
|
| 425 |
+
]
|
| 426 |
+
all_sub_communities = sorted(
|
| 427 |
+
all_sub_communities, key=lambda x: x["occurrence"], reverse=True
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
may_trun_all_sub_communities = truncate_list_by_token_size(
|
| 431 |
+
all_sub_communities,
|
| 432 |
+
key=lambda x: x["report_string"],
|
| 433 |
+
max_token_size=max_token_size,
|
| 434 |
+
tokenizer_wrapper=tokenizer_wrapper,
|
| 435 |
+
)
|
| 436 |
+
sub_fields = ["id", "report", "rating", "importance"]
|
| 437 |
+
sub_communities_describe = list_of_list_to_csv(
|
| 438 |
+
[sub_fields]
|
| 439 |
+
+ [
|
| 440 |
+
[
|
| 441 |
+
i,
|
| 442 |
+
c["report_string"],
|
| 443 |
+
c["report_json"].get("rating", -1),
|
| 444 |
+
c["occurrence"],
|
| 445 |
+
]
|
| 446 |
+
for i, c in enumerate(may_trun_all_sub_communities)
|
| 447 |
+
]
|
| 448 |
+
)
|
| 449 |
+
already_nodes = []
|
| 450 |
+
already_edges = []
|
| 451 |
+
for c in may_trun_all_sub_communities:
|
| 452 |
+
already_nodes.extend(c["nodes"])
|
| 453 |
+
already_edges.extend([tuple(e) for e in c["edges"]])
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
return (
|
| 457 |
+
sub_communities_describe,
|
| 458 |
+
len(tokenizer_wrapper.encode(sub_communities_describe)),
|
| 459 |
+
set(already_nodes),
|
| 460 |
+
set(already_edges),
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
async def _pack_single_community_describe(
|
| 465 |
+
knwoledge_graph_inst: BaseGraphStorage,
|
| 466 |
+
community: SingleCommunitySchema,
|
| 467 |
+
tokenizer_wrapper: "TokenizerWrapper",
|
| 468 |
+
max_token_size: int = 12000,
|
| 469 |
+
already_reports: dict[str, CommunitySchema] = {},
|
| 470 |
+
global_config: dict = {},
|
| 471 |
+
) -> str:
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
# 1. 准备原始数据
|
| 476 |
+
nodes_in_order = sorted(community["nodes"])
|
| 477 |
+
edges_in_order = sorted(community["edges"], key=lambda x: x[0] + x[1])
|
| 478 |
+
|
| 479 |
+
nodes_data = await asyncio.gather(
|
| 480 |
+
*[knwoledge_graph_inst.get_node(n) for n in nodes_in_order]
|
| 481 |
+
)
|
| 482 |
+
edges_data = await asyncio.gather(
|
| 483 |
+
*[knwoledge_graph_inst.get_edge(src, tgt) for src, tgt in edges_in_order]
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
# 2. 定义模板和固定开销
|
| 488 |
+
final_template = """-----Reports-----
|
| 489 |
+
```csv
|
| 490 |
+
{reports}
|
| 491 |
+
```
|
| 492 |
+
-----Entities-----
|
| 493 |
+
```csv
|
| 494 |
+
{entities}
|
| 495 |
+
```
|
| 496 |
+
-----Relationships-----
|
| 497 |
+
```csv
|
| 498 |
+
{relationships}
|
| 499 |
+
```"""
|
| 500 |
+
base_template_tokens = len(tokenizer_wrapper.encode(
|
| 501 |
+
final_template.format(reports="", entities="", relationships="")
|
| 502 |
+
))
|
| 503 |
+
remaining_budget = max_token_size - base_template_tokens
|
| 504 |
+
|
| 505 |
+
# 3. 处理子社区报告
|
| 506 |
+
report_describe = ""
|
| 507 |
+
contain_nodes = set()
|
| 508 |
+
contain_edges = set()
|
| 509 |
+
|
| 510 |
+
# 启发式截断检测
|
| 511 |
+
truncated = len(nodes_in_order) > 100 or len(edges_in_order) > 100
|
| 512 |
+
|
| 513 |
+
need_to_use_sub_communities = (
|
| 514 |
+
truncated and
|
| 515 |
+
community["sub_communities"] and
|
| 516 |
+
already_reports
|
| 517 |
+
)
|
| 518 |
+
force_to_use_sub_communities = global_config["addon_params"].get(
|
| 519 |
+
"force_to_use_sub_communities", False
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
if need_to_use_sub_communities or force_to_use_sub_communities:
|
| 523 |
+
logger.debug(f"Community {community['title']} using sub-communities")
|
| 524 |
+
# 获取子社区报告及包含的节点/边
|
| 525 |
+
result = _pack_single_community_by_sub_communities(
|
| 526 |
+
community, remaining_budget, already_reports, tokenizer_wrapper
|
| 527 |
+
)
|
| 528 |
+
report_describe, report_size, contain_nodes, contain_edges = result
|
| 529 |
+
remaining_budget = max(0, remaining_budget - report_size)
|
| 530 |
+
|
| 531 |
+
# 4. 准备节点和边数据(过滤子社区已包含的)
|
| 532 |
+
def format_row(row: list) -> str:
|
| 533 |
+
return ','.join('"{}"'.format(str(item).replace('"', '""')) for item in row)
|
| 534 |
+
|
| 535 |
+
node_fields = ["id", "entity", "type", "description", "degree"]
|
| 536 |
+
edge_fields = ["id", "source", "target", "description", "rank"]
|
| 537 |
+
|
| 538 |
+
# 获取度数并创建数据结构
|
| 539 |
+
node_degrees = await knwoledge_graph_inst.node_degrees_batch(nodes_in_order)
|
| 540 |
+
edge_degrees = await knwoledge_graph_inst.edge_degrees_batch(edges_in_order)
|
| 541 |
+
|
| 542 |
+
# 过滤已存在于子社区的节点/边
|
| 543 |
+
nodes_list_data = [
|
| 544 |
+
[i, name, data.get("entity_type", "UNKNOWN"),
|
| 545 |
+
data.get("description", "UNKNOWN"), node_degrees[i]]
|
| 546 |
+
for i, (name, data) in enumerate(zip(nodes_in_order, nodes_data))
|
| 547 |
+
if name not in contain_nodes # 关键过滤
|
| 548 |
+
]
|
| 549 |
+
|
| 550 |
+
edges_list_data = [
|
| 551 |
+
[i, edge[0], edge[1], data.get("description", "UNKNOWN"), edge_degrees[i]]
|
| 552 |
+
for i, (edge, data) in enumerate(zip(edges_in_order, edges_data))
|
| 553 |
+
if (edge[0], edge[1]) not in contain_edges # 关键过滤
|
| 554 |
+
]
|
| 555 |
+
|
| 556 |
+
# 按重要性排序
|
| 557 |
+
nodes_list_data.sort(key=lambda x: x[-1], reverse=True)
|
| 558 |
+
edges_list_data.sort(key=lambda x: x[-1], reverse=True)
|
| 559 |
+
|
| 560 |
+
# 5. 动态分配预算
|
| 561 |
+
# 计算表头开销
|
| 562 |
+
header_tokens = len(tokenizer_wrapper.encode(
|
| 563 |
+
list_of_list_to_csv([node_fields]) + "\n" + list_of_list_to_csv([edge_fields])
|
| 564 |
+
))
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
data_budget = max(0, remaining_budget - header_tokens)
|
| 569 |
+
total_items = len(nodes_list_data) + len(edges_list_data)
|
| 570 |
+
node_ratio = len(nodes_list_data) / max(1, total_items)
|
| 571 |
+
edge_ratio = 1 - node_ratio
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
# 执行截断
|
| 577 |
+
nodes_final = truncate_list_by_token_size(
|
| 578 |
+
nodes_list_data, key=format_row,
|
| 579 |
+
max_token_size=int(data_budget * node_ratio),
|
| 580 |
+
tokenizer_wrapper=tokenizer_wrapper
|
| 581 |
+
)
|
| 582 |
+
edges_final = truncate_list_by_token_size(
|
| 583 |
+
edges_list_data, key=format_row,
|
| 584 |
+
max_token_size= int(data_budget * edge_ratio),
|
| 585 |
+
tokenizer_wrapper=tokenizer_wrapper
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
# 6. 组装最终输出
|
| 589 |
+
nodes_describe = list_of_list_to_csv([node_fields] + nodes_final)
|
| 590 |
+
edges_describe = list_of_list_to_csv([edge_fields] + edges_final)
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
final_output = final_template.format(
|
| 595 |
+
reports=report_describe,
|
| 596 |
+
entities=nodes_describe,
|
| 597 |
+
relationships=edges_describe
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
return final_output
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def _community_report_json_to_str(parsed_output: dict) -> str:
|
| 604 |
+
"""refer official graphrag: index/graph/extractors/community_reports"""
|
| 605 |
+
title = parsed_output.get("title", "Report")
|
| 606 |
+
summary = parsed_output.get("summary", "")
|
| 607 |
+
findings = parsed_output.get("findings", [])
|
| 608 |
+
|
| 609 |
+
def finding_summary(finding: dict):
|
| 610 |
+
if isinstance(finding, str):
|
| 611 |
+
return finding
|
| 612 |
+
return finding.get("summary")
|
| 613 |
+
|
| 614 |
+
def finding_explanation(finding: dict):
|
| 615 |
+
if isinstance(finding, str):
|
| 616 |
+
return ""
|
| 617 |
+
return finding.get("explanation")
|
| 618 |
+
|
| 619 |
+
report_sections = "\n\n".join(
|
| 620 |
+
f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
|
| 621 |
+
)
|
| 622 |
+
return f"# {title}\n\n{summary}\n\n{report_sections}"
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
async def generate_community_report(
|
| 626 |
+
community_report_kv: BaseKVStorage[CommunitySchema],
|
| 627 |
+
knwoledge_graph_inst: BaseGraphStorage,
|
| 628 |
+
tokenizer_wrapper: TokenizerWrapper,
|
| 629 |
+
global_config: dict,
|
| 630 |
+
):
|
| 631 |
+
llm_extra_kwargs = global_config["special_community_report_llm_kwargs"]
|
| 632 |
+
use_llm_func: callable = global_config["best_model_func"]
|
| 633 |
+
use_string_json_convert_func: callable = global_config["convert_response_to_json_func"]
|
| 634 |
+
|
| 635 |
+
communities_schema = await knwoledge_graph_inst.community_schema()
|
| 636 |
+
community_keys, community_values = list(communities_schema.keys()), list(communities_schema.values())
|
| 637 |
+
already_processed = 0
|
| 638 |
+
|
| 639 |
+
prompt_template = PROMPTS["community_report"]
|
| 640 |
+
|
| 641 |
+
prompt_overhead = len(tokenizer_wrapper.encode(prompt_template.format(input_text="")))
|
| 642 |
+
|
| 643 |
+
async def _form_single_community_report(
|
| 644 |
+
community: SingleCommunitySchema, already_reports: dict[str, CommunitySchema]
|
| 645 |
+
):
|
| 646 |
+
nonlocal already_processed
|
| 647 |
+
describe = await _pack_single_community_describe(
|
| 648 |
+
knwoledge_graph_inst,
|
| 649 |
+
community,
|
| 650 |
+
tokenizer_wrapper=tokenizer_wrapper,
|
| 651 |
+
max_token_size=global_config["best_model_max_token_size"] - prompt_overhead -200, # extra token for chat template and prompt template
|
| 652 |
+
already_reports=already_reports,
|
| 653 |
+
global_config=global_config,
|
| 654 |
+
)
|
| 655 |
+
prompt = prompt_template.format(input_text=describe)
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
response = await use_llm_func(prompt, **llm_extra_kwargs)
|
| 659 |
+
data = use_string_json_convert_func(response)
|
| 660 |
+
already_processed += 1
|
| 661 |
+
now_ticks = PROMPTS["process_tickers"][already_processed % len(PROMPTS["process_tickers"])]
|
| 662 |
+
print(f"{now_ticks} Processed {already_processed} communities\r", end="", flush=True)
|
| 663 |
+
return data
|
| 664 |
+
|
| 665 |
+
levels = sorted(set([c["level"] for c in community_values]), reverse=True)
|
| 666 |
+
logger.info(f"Generating by levels: {levels}")
|
| 667 |
+
community_datas = {}
|
| 668 |
+
for level in levels:
|
| 669 |
+
this_level_community_keys, this_level_community_values = zip(
|
| 670 |
+
*[
|
| 671 |
+
(k, v)
|
| 672 |
+
for k, v in zip(community_keys, community_values)
|
| 673 |
+
if v["level"] == level
|
| 674 |
+
]
|
| 675 |
+
)
|
| 676 |
+
this_level_communities_reports = await asyncio.gather(
|
| 677 |
+
*[
|
| 678 |
+
_form_single_community_report(c, community_datas)
|
| 679 |
+
for c in this_level_community_values
|
| 680 |
+
]
|
| 681 |
+
)
|
| 682 |
+
community_datas.update(
|
| 683 |
+
{
|
| 684 |
+
k: {
|
| 685 |
+
"report_string": _community_report_json_to_str(r),
|
| 686 |
+
"report_json": r,
|
| 687 |
+
**v,
|
| 688 |
+
}
|
| 689 |
+
for k, r, v in zip(
|
| 690 |
+
this_level_community_keys,
|
| 691 |
+
this_level_communities_reports,
|
| 692 |
+
this_level_community_values,
|
| 693 |
+
)
|
| 694 |
+
}
|
| 695 |
+
)
|
| 696 |
+
print() # clear the progress bar
|
| 697 |
+
await community_report_kv.upsert(community_datas)
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
async def _find_most_related_community_from_entities(
|
| 701 |
+
node_datas: list[dict],
|
| 702 |
+
query_param: QueryParam,
|
| 703 |
+
community_reports: BaseKVStorage[CommunitySchema],
|
| 704 |
+
tokenizer_wrapper,
|
| 705 |
+
):
|
| 706 |
+
related_communities = []
|
| 707 |
+
for node_d in node_datas:
|
| 708 |
+
if "clusters" not in node_d:
|
| 709 |
+
continue
|
| 710 |
+
related_communities.extend(json.loads(node_d["clusters"]))
|
| 711 |
+
related_community_dup_keys = [
|
| 712 |
+
str(dp["cluster"])
|
| 713 |
+
for dp in related_communities
|
| 714 |
+
if dp["level"] <= query_param.level
|
| 715 |
+
]
|
| 716 |
+
related_community_keys_counts = dict(Counter(related_community_dup_keys))
|
| 717 |
+
_related_community_datas = await asyncio.gather(
|
| 718 |
+
*[community_reports.get_by_id(k) for k in related_community_keys_counts.keys()]
|
| 719 |
+
)
|
| 720 |
+
related_community_datas = {
|
| 721 |
+
k: v
|
| 722 |
+
for k, v in zip(related_community_keys_counts.keys(), _related_community_datas)
|
| 723 |
+
if v is not None
|
| 724 |
+
}
|
| 725 |
+
related_community_keys = sorted(
|
| 726 |
+
related_community_keys_counts.keys(),
|
| 727 |
+
key=lambda k: (
|
| 728 |
+
related_community_keys_counts[k],
|
| 729 |
+
related_community_datas[k]["report_json"].get("rating", -1),
|
| 730 |
+
),
|
| 731 |
+
reverse=True,
|
| 732 |
+
)
|
| 733 |
+
sorted_community_datas = [
|
| 734 |
+
related_community_datas[k] for k in related_community_keys
|
| 735 |
+
]
|
| 736 |
+
|
| 737 |
+
use_community_reports = truncate_list_by_token_size(
|
| 738 |
+
sorted_community_datas,
|
| 739 |
+
key=lambda x: x["report_string"],
|
| 740 |
+
max_token_size=query_param.local_max_token_for_community_report,
|
| 741 |
+
tokenizer_wrapper=tokenizer_wrapper,
|
| 742 |
+
)
|
| 743 |
+
if query_param.local_community_single_one:
|
| 744 |
+
use_community_reports = use_community_reports[:1]
|
| 745 |
+
return use_community_reports
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
async def _find_most_related_text_unit_from_entities(
|
| 749 |
+
node_datas: list[dict],
|
| 750 |
+
query_param: QueryParam,
|
| 751 |
+
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
| 752 |
+
knowledge_graph_inst: BaseGraphStorage,
|
| 753 |
+
tokenizer_wrapper,
|
| 754 |
+
):
|
| 755 |
+
text_units = [
|
| 756 |
+
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
| 757 |
+
for dp in node_datas
|
| 758 |
+
]
|
| 759 |
+
edges = await knowledge_graph_inst.get_nodes_edges_batch([dp["entity_name"] for dp in node_datas])
|
| 760 |
+
all_one_hop_nodes = set()
|
| 761 |
+
for this_edges in edges:
|
| 762 |
+
if not this_edges:
|
| 763 |
+
continue
|
| 764 |
+
all_one_hop_nodes.update([e[1] for e in this_edges])
|
| 765 |
+
all_one_hop_nodes = list(all_one_hop_nodes)
|
| 766 |
+
all_one_hop_nodes_data = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
|
| 767 |
+
all_one_hop_text_units_lookup = {
|
| 768 |
+
k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
|
| 769 |
+
for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
|
| 770 |
+
if v is not None
|
| 771 |
+
}
|
| 772 |
+
all_text_units_lookup = {}
|
| 773 |
+
for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
|
| 774 |
+
for c_id in this_text_units:
|
| 775 |
+
if c_id in all_text_units_lookup:
|
| 776 |
+
continue
|
| 777 |
+
relation_counts = 0
|
| 778 |
+
for e in this_edges:
|
| 779 |
+
if (
|
| 780 |
+
e[1] in all_one_hop_text_units_lookup
|
| 781 |
+
and c_id in all_one_hop_text_units_lookup[e[1]]
|
| 782 |
+
):
|
| 783 |
+
relation_counts += 1
|
| 784 |
+
all_text_units_lookup[c_id] = {
|
| 785 |
+
"data": await text_chunks_db.get_by_id(c_id),
|
| 786 |
+
"order": index,
|
| 787 |
+
"relation_counts": relation_counts,
|
| 788 |
+
}
|
| 789 |
+
if any([v is None for v in all_text_units_lookup.values()]):
|
| 790 |
+
logger.warning("Text chunks are missing, maybe the storage is damaged")
|
| 791 |
+
all_text_units = [
|
| 792 |
+
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
|
| 793 |
+
]
|
| 794 |
+
all_text_units = sorted(
|
| 795 |
+
all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
|
| 796 |
+
)
|
| 797 |
+
all_text_units = truncate_list_by_token_size(
|
| 798 |
+
all_text_units,
|
| 799 |
+
key=lambda x: x["data"]["content"],
|
| 800 |
+
max_token_size=query_param.local_max_token_for_text_unit,
|
| 801 |
+
tokenizer_wrapper=tokenizer_wrapper, # 传入 wrapper
|
| 802 |
+
)
|
| 803 |
+
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
|
| 804 |
+
return all_text_units
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
async def _find_most_related_edges_from_entities(
|
| 808 |
+
node_datas: list[dict],
|
| 809 |
+
query_param: QueryParam,
|
| 810 |
+
knowledge_graph_inst: BaseGraphStorage,
|
| 811 |
+
tokenizer_wrapper,
|
| 812 |
+
):
|
| 813 |
+
all_related_edges = await knowledge_graph_inst.get_nodes_edges_batch([dp["entity_name"] for dp in node_datas])
|
| 814 |
+
|
| 815 |
+
all_edges = []
|
| 816 |
+
seen = set()
|
| 817 |
+
|
| 818 |
+
for this_edges in all_related_edges:
|
| 819 |
+
for e in this_edges:
|
| 820 |
+
sorted_edge = tuple(sorted(e))
|
| 821 |
+
if sorted_edge not in seen:
|
| 822 |
+
seen.add(sorted_edge)
|
| 823 |
+
all_edges.append(sorted_edge)
|
| 824 |
+
|
| 825 |
+
all_edges_pack = await knowledge_graph_inst.get_edges_batch(all_edges)
|
| 826 |
+
all_edges_degree = await knowledge_graph_inst.edge_degrees_batch(all_edges)
|
| 827 |
+
all_edges_data = [
|
| 828 |
+
{"src_tgt": k, "rank": d, **v}
|
| 829 |
+
for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree)
|
| 830 |
+
if v is not None
|
| 831 |
+
]
|
| 832 |
+
all_edges_data = sorted(
|
| 833 |
+
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
| 834 |
+
)
|
| 835 |
+
all_edges_data = truncate_list_by_token_size(
|
| 836 |
+
all_edges_data,
|
| 837 |
+
key=lambda x: x["description"],
|
| 838 |
+
max_token_size=query_param.local_max_token_for_local_context,
|
| 839 |
+
tokenizer_wrapper=tokenizer_wrapper,
|
| 840 |
+
)
|
| 841 |
+
return all_edges_data
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
async def _build_local_query_context(
|
| 845 |
+
query,
|
| 846 |
+
knowledge_graph_inst: BaseGraphStorage,
|
| 847 |
+
entities_vdb: BaseVectorStorage,
|
| 848 |
+
community_reports: BaseKVStorage[CommunitySchema],
|
| 849 |
+
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
| 850 |
+
query_param: QueryParam,
|
| 851 |
+
tokenizer_wrapper,
|
| 852 |
+
):
|
| 853 |
+
results = await entities_vdb.query(query, top_k=query_param.top_k)
|
| 854 |
+
if not len(results):
|
| 855 |
+
return None
|
| 856 |
+
node_datas = await knowledge_graph_inst.get_nodes_batch([r["entity_name"] for r in results])
|
| 857 |
+
if not all([n is not None for n in node_datas]):
|
| 858 |
+
logger.warning("Some nodes are missing, maybe the storage is damaged")
|
| 859 |
+
node_degrees = await knowledge_graph_inst.node_degrees_batch([r["entity_name"] for r in results])
|
| 860 |
+
node_datas = [
|
| 861 |
+
{**n, "entity_name": k["entity_name"], "rank": d}
|
| 862 |
+
for k, n, d in zip(results, node_datas, node_degrees)
|
| 863 |
+
if n is not None
|
| 864 |
+
]
|
| 865 |
+
use_communities = await _find_most_related_community_from_entities(
|
| 866 |
+
node_datas, query_param, community_reports, tokenizer_wrapper
|
| 867 |
+
)
|
| 868 |
+
use_text_units = await _find_most_related_text_unit_from_entities(
|
| 869 |
+
node_datas, query_param, text_chunks_db, knowledge_graph_inst, tokenizer_wrapper
|
| 870 |
+
)
|
| 871 |
+
use_relations = await _find_most_related_edges_from_entities(
|
| 872 |
+
node_datas, query_param, knowledge_graph_inst, tokenizer_wrapper
|
| 873 |
+
)
|
| 874 |
+
logger.info(
|
| 875 |
+
f"Using {len(node_datas)} entites, {len(use_communities)} communities, {len(use_relations)} relations, {len(use_text_units)} text units"
|
| 876 |
+
)
|
| 877 |
+
entites_section_list = [["id", "entity", "type", "description", "rank"]]
|
| 878 |
+
for i, n in enumerate(node_datas):
|
| 879 |
+
entites_section_list.append(
|
| 880 |
+
[
|
| 881 |
+
i,
|
| 882 |
+
n["entity_name"],
|
| 883 |
+
n.get("entity_type", "UNKNOWN"),
|
| 884 |
+
n.get("description", "UNKNOWN"),
|
| 885 |
+
n["rank"],
|
| 886 |
+
]
|
| 887 |
+
)
|
| 888 |
+
entities_context = list_of_list_to_csv(entites_section_list)
|
| 889 |
+
|
| 890 |
+
relations_section_list = [
|
| 891 |
+
["id", "source", "target", "description", "weight", "rank"]
|
| 892 |
+
]
|
| 893 |
+
for i, e in enumerate(use_relations):
|
| 894 |
+
relations_section_list.append(
|
| 895 |
+
[
|
| 896 |
+
i,
|
| 897 |
+
e["src_tgt"][0],
|
| 898 |
+
e["src_tgt"][1],
|
| 899 |
+
e["description"],
|
| 900 |
+
e["weight"],
|
| 901 |
+
e["rank"],
|
| 902 |
+
]
|
| 903 |
+
)
|
| 904 |
+
relations_context = list_of_list_to_csv(relations_section_list)
|
| 905 |
+
|
| 906 |
+
communities_section_list = [["id", "content"]]
|
| 907 |
+
for i, c in enumerate(use_communities):
|
| 908 |
+
communities_section_list.append([i, c["report_string"]])
|
| 909 |
+
communities_context = list_of_list_to_csv(communities_section_list)
|
| 910 |
+
|
| 911 |
+
text_units_section_list = [["id", "content"]]
|
| 912 |
+
for i, t in enumerate(use_text_units):
|
| 913 |
+
text_units_section_list.append([i, t["content"]])
|
| 914 |
+
text_units_context = list_of_list_to_csv(text_units_section_list)
|
| 915 |
+
return f"""
|
| 916 |
+
-----Reports-----
|
| 917 |
+
```csv
|
| 918 |
+
{communities_context}
|
| 919 |
+
```
|
| 920 |
+
-----Entities-----
|
| 921 |
+
```csv
|
| 922 |
+
{entities_context}
|
| 923 |
+
```
|
| 924 |
+
-----Relationships-----
|
| 925 |
+
```csv
|
| 926 |
+
{relations_context}
|
| 927 |
+
```
|
| 928 |
+
-----Sources-----
|
| 929 |
+
```csv
|
| 930 |
+
{text_units_context}
|
| 931 |
+
```
|
| 932 |
+
"""
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
async def local_query(
|
| 936 |
+
query,
|
| 937 |
+
knowledge_graph_inst: BaseGraphStorage,
|
| 938 |
+
entities_vdb: BaseVectorStorage,
|
| 939 |
+
community_reports: BaseKVStorage[CommunitySchema],
|
| 940 |
+
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
| 941 |
+
query_param: QueryParam,
|
| 942 |
+
tokenizer_wrapper,
|
| 943 |
+
global_config: dict,
|
| 944 |
+
) -> str:
|
| 945 |
+
use_model_func = global_config["best_model_func"]
|
| 946 |
+
context = await _build_local_query_context(
|
| 947 |
+
query,
|
| 948 |
+
knowledge_graph_inst,
|
| 949 |
+
entities_vdb,
|
| 950 |
+
community_reports,
|
| 951 |
+
text_chunks_db,
|
| 952 |
+
query_param,
|
| 953 |
+
tokenizer_wrapper,
|
| 954 |
+
)
|
| 955 |
+
if query_param.only_need_context:
|
| 956 |
+
return context
|
| 957 |
+
if context is None:
|
| 958 |
+
return PROMPTS["fail_response"]
|
| 959 |
+
sys_prompt_temp = PROMPTS["local_rag_response"]
|
| 960 |
+
sys_prompt = sys_prompt_temp.format(
|
| 961 |
+
context_data=context, response_type=query_param.response_type
|
| 962 |
+
)
|
| 963 |
+
response = await use_model_func(
|
| 964 |
+
query,
|
| 965 |
+
system_prompt=sys_prompt,
|
| 966 |
+
)
|
| 967 |
+
return response
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
async def _map_global_communities(
|
| 971 |
+
query: str,
|
| 972 |
+
communities_data: list[CommunitySchema],
|
| 973 |
+
query_param: QueryParam,
|
| 974 |
+
global_config: dict,
|
| 975 |
+
tokenizer_wrapper,
|
| 976 |
+
):
|
| 977 |
+
use_string_json_convert_func = global_config["convert_response_to_json_func"]
|
| 978 |
+
use_model_func = global_config["best_model_func"]
|
| 979 |
+
community_groups = []
|
| 980 |
+
while len(communities_data):
|
| 981 |
+
this_group = truncate_list_by_token_size(
|
| 982 |
+
communities_data,
|
| 983 |
+
key=lambda x: x["report_string"],
|
| 984 |
+
max_token_size=query_param.global_max_token_for_community_report,
|
| 985 |
+
tokenizer_wrapper=tokenizer_wrapper, # 传入 wrapper
|
| 986 |
+
)
|
| 987 |
+
community_groups.append(this_group)
|
| 988 |
+
communities_data = communities_data[len(this_group) :]
|
| 989 |
+
|
| 990 |
+
async def _process(community_truncated_datas: list[CommunitySchema]) -> dict:
|
| 991 |
+
communities_section_list = [["id", "content", "rating", "importance"]]
|
| 992 |
+
for i, c in enumerate(community_truncated_datas):
|
| 993 |
+
communities_section_list.append(
|
| 994 |
+
[
|
| 995 |
+
i,
|
| 996 |
+
c["report_string"],
|
| 997 |
+
c["report_json"].get("rating", 0),
|
| 998 |
+
c["occurrence"],
|
| 999 |
+
]
|
| 1000 |
+
)
|
| 1001 |
+
community_context = list_of_list_to_csv(communities_section_list)
|
| 1002 |
+
sys_prompt_temp = PROMPTS["global_map_rag_points"]
|
| 1003 |
+
sys_prompt = sys_prompt_temp.format(context_data=community_context)
|
| 1004 |
+
response = await use_model_func(
|
| 1005 |
+
query,
|
| 1006 |
+
system_prompt=sys_prompt,
|
| 1007 |
+
**query_param.global_special_community_map_llm_kwargs,
|
| 1008 |
+
)
|
| 1009 |
+
data = use_string_json_convert_func(response)
|
| 1010 |
+
return data.get("points", [])
|
| 1011 |
+
|
| 1012 |
+
logger.info(f"Grouping to {len(community_groups)} groups for global search")
|
| 1013 |
+
responses = await asyncio.gather(*[_process(c) for c in community_groups])
|
| 1014 |
+
return responses
|
| 1015 |
+
|
| 1016 |
+
|
| 1017 |
+
async def global_query(
|
| 1018 |
+
query,
|
| 1019 |
+
knowledge_graph_inst: BaseGraphStorage,
|
| 1020 |
+
entities_vdb: BaseVectorStorage,
|
| 1021 |
+
community_reports: BaseKVStorage[CommunitySchema],
|
| 1022 |
+
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
| 1023 |
+
query_param: QueryParam,
|
| 1024 |
+
tokenizer_wrapper,
|
| 1025 |
+
global_config: dict,
|
| 1026 |
+
) -> str:
|
| 1027 |
+
community_schema = await knowledge_graph_inst.community_schema()
|
| 1028 |
+
community_schema = {
|
| 1029 |
+
k: v for k, v in community_schema.items() if v["level"] <= query_param.level
|
| 1030 |
+
}
|
| 1031 |
+
if not len(community_schema):
|
| 1032 |
+
return PROMPTS["fail_response"]
|
| 1033 |
+
use_model_func = global_config["best_model_func"]
|
| 1034 |
+
|
| 1035 |
+
sorted_community_schemas = sorted(
|
| 1036 |
+
community_schema.items(),
|
| 1037 |
+
key=lambda x: x[1]["occurrence"],
|
| 1038 |
+
reverse=True,
|
| 1039 |
+
)
|
| 1040 |
+
sorted_community_schemas = sorted_community_schemas[
|
| 1041 |
+
: query_param.global_max_consider_community
|
| 1042 |
+
]
|
| 1043 |
+
community_datas = await community_reports.get_by_ids(
|
| 1044 |
+
[k[0] for k in sorted_community_schemas]
|
| 1045 |
+
)
|
| 1046 |
+
community_datas = [c for c in community_datas if c is not None]
|
| 1047 |
+
community_datas = [
|
| 1048 |
+
c
|
| 1049 |
+
for c in community_datas
|
| 1050 |
+
if c["report_json"].get("rating", 0) >= query_param.global_min_community_rating
|
| 1051 |
+
]
|
| 1052 |
+
community_datas = sorted(
|
| 1053 |
+
community_datas,
|
| 1054 |
+
key=lambda x: (x["occurrence"], x["report_json"].get("rating", 0)),
|
| 1055 |
+
reverse=True,
|
| 1056 |
+
)
|
| 1057 |
+
logger.info(f"Revtrieved {len(community_datas)} communities")
|
| 1058 |
+
|
| 1059 |
+
map_communities_points = await _map_global_communities(
|
| 1060 |
+
query, community_datas, query_param, global_config, tokenizer_wrapper
|
| 1061 |
+
)
|
| 1062 |
+
final_support_points = []
|
| 1063 |
+
for i, mc in enumerate(map_communities_points):
|
| 1064 |
+
for point in mc:
|
| 1065 |
+
if "description" not in point:
|
| 1066 |
+
continue
|
| 1067 |
+
final_support_points.append(
|
| 1068 |
+
{
|
| 1069 |
+
"analyst": i,
|
| 1070 |
+
"answer": point["description"],
|
| 1071 |
+
"score": point.get("score", 1),
|
| 1072 |
+
}
|
| 1073 |
+
)
|
| 1074 |
+
final_support_points = [p for p in final_support_points if p["score"] > 0]
|
| 1075 |
+
if not len(final_support_points):
|
| 1076 |
+
return PROMPTS["fail_response"]
|
| 1077 |
+
final_support_points = sorted(
|
| 1078 |
+
final_support_points, key=lambda x: x["score"], reverse=True
|
| 1079 |
+
)
|
| 1080 |
+
final_support_points = truncate_list_by_token_size(
|
| 1081 |
+
final_support_points,
|
| 1082 |
+
key=lambda x: x["answer"],
|
| 1083 |
+
max_token_size=query_param.global_max_token_for_community_report,
|
| 1084 |
+
tokenizer_wrapper=tokenizer_wrapper, # 传入 wrapper
|
| 1085 |
+
)
|
| 1086 |
+
points_context = []
|
| 1087 |
+
for dp in final_support_points:
|
| 1088 |
+
points_context.append(
|
| 1089 |
+
f"""----Analyst {dp['analyst']}----
|
| 1090 |
+
Importance Score: {dp['score']}
|
| 1091 |
+
{dp['answer']}
|
| 1092 |
+
"""
|
| 1093 |
+
)
|
| 1094 |
+
points_context = "\n".join(points_context)
|
| 1095 |
+
if query_param.only_need_context:
|
| 1096 |
+
return points_context
|
| 1097 |
+
sys_prompt_temp = PROMPTS["global_reduce_rag_response"]
|
| 1098 |
+
response = await use_model_func(
|
| 1099 |
+
query,
|
| 1100 |
+
sys_prompt_temp.format(
|
| 1101 |
+
report_data=points_context, response_type=query_param.response_type
|
| 1102 |
+
),
|
| 1103 |
+
)
|
| 1104 |
+
return response
|
| 1105 |
+
|
| 1106 |
+
|
| 1107 |
+
async def naive_query(
|
| 1108 |
+
query,
|
| 1109 |
+
chunks_vdb: BaseVectorStorage,
|
| 1110 |
+
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
| 1111 |
+
query_param: QueryParam,
|
| 1112 |
+
tokenizer_wrapper,
|
| 1113 |
+
global_config: dict,
|
| 1114 |
+
):
|
| 1115 |
+
use_model_func = global_config["best_model_func"]
|
| 1116 |
+
results = await chunks_vdb.query(query, top_k=query_param.top_k)
|
| 1117 |
+
if not len(results):
|
| 1118 |
+
return PROMPTS["fail_response"]
|
| 1119 |
+
chunks_ids = [r["id"] for r in results]
|
| 1120 |
+
chunks = await text_chunks_db.get_by_ids(chunks_ids)
|
| 1121 |
+
|
| 1122 |
+
maybe_trun_chunks = truncate_list_by_token_size(
|
| 1123 |
+
chunks,
|
| 1124 |
+
key=lambda x: x["content"],
|
| 1125 |
+
max_token_size=query_param.naive_max_token_for_text_unit,
|
| 1126 |
+
tokenizer_wrapper=tokenizer_wrapper, # 传入 wrapper
|
| 1127 |
+
)
|
| 1128 |
+
logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
|
| 1129 |
+
section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
|
| 1130 |
+
if query_param.only_need_context:
|
| 1131 |
+
return section
|
| 1132 |
+
sys_prompt_temp = PROMPTS["naive_rag_response"]
|
| 1133 |
+
sys_prompt = sys_prompt_temp.format(
|
| 1134 |
+
content_data=section, response_type=query_param.response_type
|
| 1135 |
+
)
|
| 1136 |
+
response = await use_model_func(
|
| 1137 |
+
query,
|
| 1138 |
+
system_prompt=sys_prompt,
|
| 1139 |
+
)
|
| 1140 |
+
return response
|
nano-graphrag/nano_graphrag/_splitter.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union, Literal
|
| 2 |
+
|
| 3 |
+
class SeparatorSplitter:
|
| 4 |
+
def __init__(
|
| 5 |
+
self,
|
| 6 |
+
separators: Optional[List[List[int]]] = None,
|
| 7 |
+
keep_separator: Union[bool, Literal["start", "end"]] = "end",
|
| 8 |
+
chunk_size: int = 4000,
|
| 9 |
+
chunk_overlap: int = 200,
|
| 10 |
+
length_function: callable = len,
|
| 11 |
+
):
|
| 12 |
+
self._separators = separators or []
|
| 13 |
+
self._keep_separator = keep_separator
|
| 14 |
+
self._chunk_size = chunk_size
|
| 15 |
+
self._chunk_overlap = chunk_overlap
|
| 16 |
+
self._length_function = length_function
|
| 17 |
+
|
| 18 |
+
def split_tokens(self, tokens: List[int]) -> List[List[int]]:
|
| 19 |
+
splits = self._split_tokens_with_separators(tokens)
|
| 20 |
+
return self._merge_splits(splits)
|
| 21 |
+
|
| 22 |
+
def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
|
| 23 |
+
splits = []
|
| 24 |
+
current_split = []
|
| 25 |
+
i = 0
|
| 26 |
+
while i < len(tokens):
|
| 27 |
+
separator_found = False
|
| 28 |
+
for separator in self._separators:
|
| 29 |
+
if tokens[i:i+len(separator)] == separator:
|
| 30 |
+
if self._keep_separator in [True, "end"]:
|
| 31 |
+
current_split.extend(separator)
|
| 32 |
+
if current_split:
|
| 33 |
+
splits.append(current_split)
|
| 34 |
+
current_split = []
|
| 35 |
+
if self._keep_separator == "start":
|
| 36 |
+
current_split.extend(separator)
|
| 37 |
+
i += len(separator)
|
| 38 |
+
separator_found = True
|
| 39 |
+
break
|
| 40 |
+
if not separator_found:
|
| 41 |
+
current_split.append(tokens[i])
|
| 42 |
+
i += 1
|
| 43 |
+
if current_split:
|
| 44 |
+
splits.append(current_split)
|
| 45 |
+
return [s for s in splits if s]
|
| 46 |
+
|
| 47 |
+
def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
|
| 48 |
+
if not splits:
|
| 49 |
+
return []
|
| 50 |
+
|
| 51 |
+
merged_splits = []
|
| 52 |
+
current_chunk = []
|
| 53 |
+
|
| 54 |
+
for split in splits:
|
| 55 |
+
if not current_chunk:
|
| 56 |
+
current_chunk = split
|
| 57 |
+
elif self._length_function(current_chunk) + self._length_function(split) <= self._chunk_size:
|
| 58 |
+
current_chunk.extend(split)
|
| 59 |
+
else:
|
| 60 |
+
merged_splits.append(current_chunk)
|
| 61 |
+
current_chunk = split
|
| 62 |
+
|
| 63 |
+
if current_chunk:
|
| 64 |
+
merged_splits.append(current_chunk)
|
| 65 |
+
|
| 66 |
+
if len(merged_splits) == 1 and self._length_function(merged_splits[0]) > self._chunk_size:
|
| 67 |
+
return self._split_chunk(merged_splits[0])
|
| 68 |
+
|
| 69 |
+
if self._chunk_overlap > 0:
|
| 70 |
+
return self._enforce_overlap(merged_splits)
|
| 71 |
+
|
| 72 |
+
return merged_splits
|
| 73 |
+
|
| 74 |
+
def _split_chunk(self, chunk: List[int]) -> List[List[int]]:
|
| 75 |
+
result = []
|
| 76 |
+
for i in range(0, len(chunk), self._chunk_size - self._chunk_overlap):
|
| 77 |
+
new_chunk = chunk[i:i + self._chunk_size]
|
| 78 |
+
if len(new_chunk) > self._chunk_overlap: # 只有当 chunk 长度大于 overlap 时才添加
|
| 79 |
+
result.append(new_chunk)
|
| 80 |
+
return result
|
| 81 |
+
|
| 82 |
+
def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
|
| 83 |
+
result = []
|
| 84 |
+
for i, chunk in enumerate(chunks):
|
| 85 |
+
if i == 0:
|
| 86 |
+
result.append(chunk)
|
| 87 |
+
else:
|
| 88 |
+
overlap = chunks[i-1][-self._chunk_overlap:]
|
| 89 |
+
new_chunk = overlap + chunk
|
| 90 |
+
if self._length_function(new_chunk) > self._chunk_size:
|
| 91 |
+
new_chunk = new_chunk[:self._chunk_size]
|
| 92 |
+
result.append(new_chunk)
|
| 93 |
+
return result
|
| 94 |
+
|
nano-graphrag/nano_graphrag/_storage/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .gdb_networkx import NetworkXStorage
|
| 2 |
+
from .gdb_neo4j import Neo4jStorage
|
| 3 |
+
from .vdb_hnswlib import HNSWVectorStorage
|
| 4 |
+
from .vdb_nanovectordb import NanoVectorDBStorage
|
| 5 |
+
from .kv_json import JsonKVStorage
|
nano-graphrag/nano_graphrag/_storage/gdb_neo4j.py
ADDED
|
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import asyncio
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from typing import List
|
| 5 |
+
from neo4j import AsyncGraphDatabase
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Union
|
| 8 |
+
from ..base import BaseGraphStorage, SingleCommunitySchema
|
| 9 |
+
from .._utils import logger
|
| 10 |
+
from ..prompt import GRAPH_FIELD_SEP
|
| 11 |
+
|
| 12 |
+
neo4j_lock = asyncio.Lock()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def make_path_idable(path):
|
| 16 |
+
return path.replace(".", "_").replace("/", "__").replace("-", "_").replace(":", "_").replace("\\", "__")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class Neo4jStorage(BaseGraphStorage):
|
| 21 |
+
def __post_init__(self):
|
| 22 |
+
self.neo4j_url = self.global_config["addon_params"].get("neo4j_url", None)
|
| 23 |
+
self.neo4j_auth = self.global_config["addon_params"].get("neo4j_auth", None)
|
| 24 |
+
self.namespace = (
|
| 25 |
+
f"{make_path_idable(self.global_config['working_dir'])}__{self.namespace}"
|
| 26 |
+
)
|
| 27 |
+
logger.info(f"Using the label {self.namespace} for Neo4j as identifier")
|
| 28 |
+
if self.neo4j_url is None or self.neo4j_auth is None:
|
| 29 |
+
raise ValueError("Missing neo4j_url or neo4j_auth in addon_params")
|
| 30 |
+
self.async_driver = AsyncGraphDatabase.driver(
|
| 31 |
+
self.neo4j_url, auth=self.neo4j_auth, max_connection_pool_size=50,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# async def create_database(self):
|
| 35 |
+
# async with self.async_driver.session() as session:
|
| 36 |
+
# try:
|
| 37 |
+
# constraints = await session.run("SHOW CONSTRAINTS")
|
| 38 |
+
# # TODO I don't know why CREATE CONSTRAINT IF NOT EXISTS still trigger error
|
| 39 |
+
# # so have to check if the constrain exists
|
| 40 |
+
# constrain_exists = False
|
| 41 |
+
|
| 42 |
+
# async for record in constraints:
|
| 43 |
+
# if (
|
| 44 |
+
# self.namespace in record["labelsOrTypes"]
|
| 45 |
+
# and "id" in record["properties"]
|
| 46 |
+
# and record["type"] == "UNIQUENESS"
|
| 47 |
+
# ):
|
| 48 |
+
# constrain_exists = True
|
| 49 |
+
# break
|
| 50 |
+
# if not constrain_exists:
|
| 51 |
+
# await session.run(
|
| 52 |
+
# f"CREATE CONSTRAINT FOR (n:{self.namespace}) REQUIRE n.id IS UNIQUE"
|
| 53 |
+
# )
|
| 54 |
+
# logger.info(f"Add constraint for namespace: {self.namespace}")
|
| 55 |
+
|
| 56 |
+
# except Exception as e:
|
| 57 |
+
# logger.error(f"Error accessing or setting up the database: {str(e)}")
|
| 58 |
+
# raise
|
| 59 |
+
|
| 60 |
+
async def _init_workspace(self):
|
| 61 |
+
await self.async_driver.verify_authentication()
|
| 62 |
+
await self.async_driver.verify_connectivity()
|
| 63 |
+
# TODOLater: create database if not exists always cause an error when async
|
| 64 |
+
# await self.create_database()
|
| 65 |
+
|
| 66 |
+
async def index_start_callback(self):
|
| 67 |
+
logger.info("Init Neo4j workspace")
|
| 68 |
+
await self._init_workspace()
|
| 69 |
+
|
| 70 |
+
# create index for faster searching
|
| 71 |
+
try:
|
| 72 |
+
async with self.async_driver.session() as session:
|
| 73 |
+
await session.run(
|
| 74 |
+
f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.id)"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
await session.run(
|
| 78 |
+
f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.entity_type)"
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
await session.run(
|
| 82 |
+
f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.communityIds)"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
await session.run(
|
| 86 |
+
f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.source_id)"
|
| 87 |
+
)
|
| 88 |
+
logger.info("Neo4j indexes created successfully")
|
| 89 |
+
except Exception as e:
|
| 90 |
+
logger.error(f"Failed to create indexes: {e}")
|
| 91 |
+
raise e
|
| 92 |
+
|
| 93 |
+
async def has_node(self, node_id: str) -> bool:
|
| 94 |
+
async with self.async_driver.session() as session:
|
| 95 |
+
result = await session.run(
|
| 96 |
+
f"MATCH (n:`{self.namespace}`) WHERE n.id = $node_id RETURN COUNT(n) > 0 AS exists",
|
| 97 |
+
node_id=node_id,
|
| 98 |
+
)
|
| 99 |
+
record = await result.single()
|
| 100 |
+
return record["exists"] if record else False
|
| 101 |
+
|
| 102 |
+
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
| 103 |
+
async with self.async_driver.session() as session:
|
| 104 |
+
result = await session.run(
|
| 105 |
+
f"""
|
| 106 |
+
MATCH (s:`{self.namespace}`)
|
| 107 |
+
WHERE s.id = $source_id
|
| 108 |
+
MATCH (t:`{self.namespace}`)
|
| 109 |
+
WHERE t.id = $target_id
|
| 110 |
+
RETURN EXISTS((s)-[]->(t)) AS exists
|
| 111 |
+
""",
|
| 112 |
+
source_id=source_node_id,
|
| 113 |
+
target_id=target_node_id,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
record = await result.single()
|
| 117 |
+
return record["exists"] if record else False
|
| 118 |
+
|
| 119 |
+
async def node_degree(self, node_id: str) -> int:
|
| 120 |
+
results = await self.node_degrees_batch([node_id])
|
| 121 |
+
return results[0] if results else 0
|
| 122 |
+
|
| 123 |
+
async def node_degrees_batch(self, node_ids: List[str]) -> List[str]:
|
| 124 |
+
if not node_ids:
|
| 125 |
+
return {}
|
| 126 |
+
|
| 127 |
+
result_dict = {node_id: 0 for node_id in node_ids}
|
| 128 |
+
async with self.async_driver.session() as session:
|
| 129 |
+
result = await session.run(
|
| 130 |
+
f"""
|
| 131 |
+
UNWIND $node_ids AS node_id
|
| 132 |
+
MATCH (n:`{self.namespace}`)
|
| 133 |
+
WHERE n.id = node_id
|
| 134 |
+
OPTIONAL MATCH (n)-[]-(m:`{self.namespace}`)
|
| 135 |
+
RETURN node_id, COUNT(m) AS degree
|
| 136 |
+
""",
|
| 137 |
+
node_ids=node_ids
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
async for record in result:
|
| 141 |
+
result_dict[record["node_id"]] = record["degree"]
|
| 142 |
+
|
| 143 |
+
return [result_dict[node_id] for node_id in node_ids]
|
| 144 |
+
|
| 145 |
+
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
| 146 |
+
results = await self.edge_degrees_batch([(src_id, tgt_id)])
|
| 147 |
+
return results[0] if results else 0
|
| 148 |
+
|
| 149 |
+
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
|
| 150 |
+
if not edge_pairs:
|
| 151 |
+
return []
|
| 152 |
+
|
| 153 |
+
result_dict = {tuple(edge_pair): 0 for edge_pair in edge_pairs}
|
| 154 |
+
|
| 155 |
+
edges_params = [{"src_id": src, "tgt_id": tgt} for src, tgt in edge_pairs]
|
| 156 |
+
|
| 157 |
+
try:
|
| 158 |
+
async with self.async_driver.session() as session:
|
| 159 |
+
result = await session.run(
|
| 160 |
+
f"""
|
| 161 |
+
UNWIND $edges AS edge
|
| 162 |
+
|
| 163 |
+
MATCH (s:`{self.namespace}`)
|
| 164 |
+
WHERE s.id = edge.src_id
|
| 165 |
+
WITH edge, s
|
| 166 |
+
OPTIONAL MATCH (s)-[]-(n1:`{self.namespace}`)
|
| 167 |
+
WITH edge, COUNT(n1) AS src_degree
|
| 168 |
+
|
| 169 |
+
MATCH (t:`{self.namespace}`)
|
| 170 |
+
WHERE t.id = edge.tgt_id
|
| 171 |
+
WITH edge, src_degree, t
|
| 172 |
+
OPTIONAL MATCH (t)-[]-(n2:`{self.namespace}`)
|
| 173 |
+
WITH edge.src_id AS src_id, edge.tgt_id AS tgt_id, src_degree, COUNT(n2) AS tgt_degree
|
| 174 |
+
|
| 175 |
+
RETURN src_id, tgt_id, src_degree + tgt_degree AS degree
|
| 176 |
+
""",
|
| 177 |
+
edges=edges_params
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
async for record in result:
|
| 181 |
+
src_id = record["src_id"]
|
| 182 |
+
tgt_id = record["tgt_id"]
|
| 183 |
+
degree = record["degree"]
|
| 184 |
+
|
| 185 |
+
# 更新结果字典
|
| 186 |
+
edge_pair = (src_id, tgt_id)
|
| 187 |
+
result_dict[edge_pair] = degree
|
| 188 |
+
|
| 189 |
+
return [result_dict[tuple(edge_pair)] for edge_pair in edge_pairs]
|
| 190 |
+
except Exception as e:
|
| 191 |
+
logger.error(f"Error in batch edge degree calculation: {e}")
|
| 192 |
+
return [0] * len(edge_pairs)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
async def get_node(self, node_id: str) -> Union[dict, None]:
|
| 197 |
+
result = await self.get_nodes_batch([node_id])
|
| 198 |
+
return result[0] if result else None
|
| 199 |
+
|
| 200 |
+
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, Union[dict, None]]:
|
| 201 |
+
if not node_ids:
|
| 202 |
+
return {}
|
| 203 |
+
|
| 204 |
+
result_dict = {node_id: None for node_id in node_ids}
|
| 205 |
+
|
| 206 |
+
try:
|
| 207 |
+
async with self.async_driver.session() as session:
|
| 208 |
+
result = await session.run(
|
| 209 |
+
f"""
|
| 210 |
+
UNWIND $node_ids AS node_id
|
| 211 |
+
MATCH (n:`{self.namespace}`)
|
| 212 |
+
WHERE n.id = node_id
|
| 213 |
+
RETURN node_id, properties(n) AS node_data
|
| 214 |
+
""",
|
| 215 |
+
node_ids=node_ids
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
async for record in result:
|
| 219 |
+
node_id = record["node_id"]
|
| 220 |
+
raw_node_data = record["node_data"]
|
| 221 |
+
|
| 222 |
+
if raw_node_data:
|
| 223 |
+
raw_node_data["clusters"] = json.dumps(
|
| 224 |
+
[
|
| 225 |
+
{
|
| 226 |
+
"level": index,
|
| 227 |
+
"cluster": cluster_id,
|
| 228 |
+
}
|
| 229 |
+
for index, cluster_id in enumerate(
|
| 230 |
+
raw_node_data.get("communityIds", [])
|
| 231 |
+
)
|
| 232 |
+
]
|
| 233 |
+
)
|
| 234 |
+
result_dict[node_id] = raw_node_data
|
| 235 |
+
return [result_dict[node_id] for node_id in node_ids]
|
| 236 |
+
except Exception as e:
|
| 237 |
+
logger.error(f"Error in batch node retrieval: {e}")
|
| 238 |
+
raise e
|
| 239 |
+
|
| 240 |
+
async def get_edge(
|
| 241 |
+
self, source_node_id: str, target_node_id: str
|
| 242 |
+
) -> Union[dict, None]:
|
| 243 |
+
results = await self.get_edges_batch([(source_node_id, target_node_id)])
|
| 244 |
+
return results[0] if results else None
|
| 245 |
+
|
| 246 |
+
async def get_edges_batch(
|
| 247 |
+
self, edge_pairs: list[tuple[str, str]]
|
| 248 |
+
) -> list[Union[dict, None]]:
|
| 249 |
+
if not edge_pairs:
|
| 250 |
+
return []
|
| 251 |
+
|
| 252 |
+
result_dict = {tuple(edge_pair): None for edge_pair in edge_pairs}
|
| 253 |
+
|
| 254 |
+
edges_params = [{"source_id": src, "target_id": tgt} for src, tgt in edge_pairs]
|
| 255 |
+
|
| 256 |
+
try:
|
| 257 |
+
async with self.async_driver.session() as session:
|
| 258 |
+
result = await session.run(
|
| 259 |
+
f"""
|
| 260 |
+
UNWIND $edges AS edge
|
| 261 |
+
MATCH (s:`{self.namespace}`)-[r]->(t:`{self.namespace}`)
|
| 262 |
+
WHERE s.id = edge.source_id AND t.id = edge.target_id
|
| 263 |
+
RETURN edge.source_id AS source_id, edge.target_id AS target_id, properties(r) AS edge_data
|
| 264 |
+
""",
|
| 265 |
+
edges=edges_params
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
async for record in result:
|
| 269 |
+
source_id = record["source_id"]
|
| 270 |
+
target_id = record["target_id"]
|
| 271 |
+
edge_data = record["edge_data"]
|
| 272 |
+
|
| 273 |
+
edge_pair = (source_id, target_id)
|
| 274 |
+
result_dict[edge_pair] = edge_data
|
| 275 |
+
|
| 276 |
+
return [result_dict[tuple(edge_pair)] for edge_pair in edge_pairs]
|
| 277 |
+
except Exception as e:
|
| 278 |
+
logger.error(f"Error in batch edge retrieval: {e}")
|
| 279 |
+
return [None] * len(edge_pairs)
|
| 280 |
+
|
| 281 |
+
async def get_node_edges(
|
| 282 |
+
self, source_node_id: str
|
| 283 |
+
) -> list[tuple[str, str]]:
|
| 284 |
+
results = await self.get_nodes_edges_batch([source_node_id])
|
| 285 |
+
return results[0] if results else []
|
| 286 |
+
|
| 287 |
+
async def get_nodes_edges_batch(
|
| 288 |
+
self, node_ids: list[str]
|
| 289 |
+
) -> list[list[tuple[str, str]]]:
|
| 290 |
+
if not node_ids:
|
| 291 |
+
return []
|
| 292 |
+
|
| 293 |
+
result_dict = {node_id: [] for node_id in node_ids}
|
| 294 |
+
|
| 295 |
+
try:
|
| 296 |
+
async with self.async_driver.session() as session:
|
| 297 |
+
result = await session.run(
|
| 298 |
+
f"""
|
| 299 |
+
UNWIND $node_ids AS node_id
|
| 300 |
+
MATCH (s:`{self.namespace}`)-[r]->(t:`{self.namespace}`)
|
| 301 |
+
WHERE s.id = node_id
|
| 302 |
+
RETURN s.id AS source_id, t.id AS target_id
|
| 303 |
+
""",
|
| 304 |
+
node_ids=node_ids
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
async for record in result:
|
| 308 |
+
source_id = record["source_id"]
|
| 309 |
+
target_id = record["target_id"]
|
| 310 |
+
|
| 311 |
+
if source_id in result_dict:
|
| 312 |
+
result_dict[source_id].append((source_id, target_id))
|
| 313 |
+
|
| 314 |
+
return [result_dict[node_id] for node_id in node_ids]
|
| 315 |
+
except Exception as e:
|
| 316 |
+
logger.error(f"Error in batch node edges retrieval: {e}")
|
| 317 |
+
return [[] for _ in node_ids]
|
| 318 |
+
|
| 319 |
+
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
| 320 |
+
await self.upsert_nodes_batch([(node_id, node_data)])
|
| 321 |
+
|
| 322 |
+
async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
|
| 323 |
+
if not nodes_data:
|
| 324 |
+
return []
|
| 325 |
+
|
| 326 |
+
nodes_by_type = {}
|
| 327 |
+
for node_id, node_data in nodes_data:
|
| 328 |
+
node_type = node_data.get("entity_type", "UNKNOWN").strip('"')
|
| 329 |
+
if node_type not in nodes_by_type:
|
| 330 |
+
nodes_by_type[node_type] = []
|
| 331 |
+
nodes_by_type[node_type].append((node_id, node_data))
|
| 332 |
+
|
| 333 |
+
async with self.async_driver.session() as session:
|
| 334 |
+
for node_type, type_nodes in nodes_by_type.items():
|
| 335 |
+
params = [{"id": node_id, "data": node_data} for node_id, node_data in type_nodes]
|
| 336 |
+
|
| 337 |
+
await session.run(
|
| 338 |
+
f"""
|
| 339 |
+
UNWIND $nodes AS node
|
| 340 |
+
MERGE (n:`{self.namespace}`:`{node_type}` {{id: node.id}})
|
| 341 |
+
SET n += node.data
|
| 342 |
+
""",
|
| 343 |
+
nodes=params
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
async def upsert_edge(
|
| 347 |
+
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 348 |
+
):
|
| 349 |
+
await self.upsert_edges_batch([(source_node_id, target_node_id, edge_data)])
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
async def upsert_edges_batch(
|
| 353 |
+
self, edges_data: list[tuple[str, str, dict[str, str]]]
|
| 354 |
+
):
|
| 355 |
+
if not edges_data:
|
| 356 |
+
return
|
| 357 |
+
|
| 358 |
+
edges_params = []
|
| 359 |
+
for source_id, target_id, edge_data in edges_data:
|
| 360 |
+
edge_data_copy = edge_data.copy()
|
| 361 |
+
edge_data_copy.setdefault("weight", 0.0)
|
| 362 |
+
|
| 363 |
+
edges_params.append({
|
| 364 |
+
"source_id": source_id,
|
| 365 |
+
"target_id": target_id,
|
| 366 |
+
"edge_data": edge_data_copy
|
| 367 |
+
})
|
| 368 |
+
|
| 369 |
+
async with self.async_driver.session() as session:
|
| 370 |
+
await session.run(
|
| 371 |
+
f"""
|
| 372 |
+
UNWIND $edges AS edge
|
| 373 |
+
MATCH (s:`{self.namespace}`)
|
| 374 |
+
WHERE s.id = edge.source_id
|
| 375 |
+
WITH edge, s
|
| 376 |
+
MATCH (t:`{self.namespace}`)
|
| 377 |
+
WHERE t.id = edge.target_id
|
| 378 |
+
MERGE (s)-[r:RELATED]->(t)
|
| 379 |
+
SET r += edge.edge_data
|
| 380 |
+
""",
|
| 381 |
+
edges=edges_params
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
async def clustering(self, algorithm: str):
|
| 388 |
+
if algorithm != "leiden":
|
| 389 |
+
raise ValueError(
|
| 390 |
+
f"Clustering algorithm {algorithm} not supported in Neo4j implementation"
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
random_seed = self.global_config["graph_cluster_seed"]
|
| 394 |
+
max_level = self.global_config["max_graph_cluster_size"]
|
| 395 |
+
async with self.async_driver.session() as session:
|
| 396 |
+
try:
|
| 397 |
+
# Project the graph with undirected relationships
|
| 398 |
+
await session.run(
|
| 399 |
+
f"""
|
| 400 |
+
CALL gds.graph.project(
|
| 401 |
+
'graph_{self.namespace}',
|
| 402 |
+
['{self.namespace}'],
|
| 403 |
+
{{
|
| 404 |
+
RELATED: {{
|
| 405 |
+
orientation: 'UNDIRECTED',
|
| 406 |
+
properties: ['weight']
|
| 407 |
+
}}
|
| 408 |
+
}}
|
| 409 |
+
)
|
| 410 |
+
"""
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
# Run Leiden algorithm
|
| 414 |
+
result = await session.run(
|
| 415 |
+
f"""
|
| 416 |
+
CALL gds.leiden.write(
|
| 417 |
+
'graph_{self.namespace}',
|
| 418 |
+
{{
|
| 419 |
+
writeProperty: 'communityIds',
|
| 420 |
+
includeIntermediateCommunities: True,
|
| 421 |
+
relationshipWeightProperty: "weight",
|
| 422 |
+
maxLevels: {max_level},
|
| 423 |
+
tolerance: 0.0001,
|
| 424 |
+
gamma: 1.0,
|
| 425 |
+
theta: 0.01,
|
| 426 |
+
randomSeed: {random_seed}
|
| 427 |
+
}}
|
| 428 |
+
)
|
| 429 |
+
YIELD communityCount, modularities;
|
| 430 |
+
"""
|
| 431 |
+
)
|
| 432 |
+
result = await result.single()
|
| 433 |
+
community_count: int = result["communityCount"]
|
| 434 |
+
modularities = result["modularities"]
|
| 435 |
+
logger.info(
|
| 436 |
+
f"Performed graph clustering with {community_count} communities and modularities {modularities}"
|
| 437 |
+
)
|
| 438 |
+
finally:
|
| 439 |
+
# Drop the projected graph
|
| 440 |
+
await session.run(f"CALL gds.graph.drop('graph_{self.namespace}')")
|
| 441 |
+
|
| 442 |
+
async def community_schema(self) -> dict[str, SingleCommunitySchema]:
|
| 443 |
+
results = defaultdict(
|
| 444 |
+
lambda: dict(
|
| 445 |
+
level=None,
|
| 446 |
+
title=None,
|
| 447 |
+
edges=set(),
|
| 448 |
+
nodes=set(),
|
| 449 |
+
chunk_ids=set(),
|
| 450 |
+
occurrence=0.0,
|
| 451 |
+
sub_communities=[],
|
| 452 |
+
)
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
async with self.async_driver.session() as session:
|
| 456 |
+
# Fetch community data
|
| 457 |
+
result = await session.run(
|
| 458 |
+
f"""
|
| 459 |
+
MATCH (n:`{self.namespace}`)
|
| 460 |
+
WITH n, n.communityIds AS communityIds, [(n)-[]-(m:`{self.namespace}`) | m.id] AS connected_nodes
|
| 461 |
+
RETURN n.id AS node_id, n.source_id AS source_id,
|
| 462 |
+
communityIds AS cluster_key,
|
| 463 |
+
connected_nodes
|
| 464 |
+
"""
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
# records = await result.fetch()
|
| 468 |
+
|
| 469 |
+
max_num_ids = 0
|
| 470 |
+
async for record in result:
|
| 471 |
+
for index, c_id in enumerate(record["cluster_key"]):
|
| 472 |
+
node_id = str(record["node_id"])
|
| 473 |
+
source_id = record["source_id"]
|
| 474 |
+
level = index
|
| 475 |
+
cluster_key = str(c_id)
|
| 476 |
+
connected_nodes = record["connected_nodes"]
|
| 477 |
+
|
| 478 |
+
results[cluster_key]["level"] = level
|
| 479 |
+
results[cluster_key]["title"] = f"Cluster {cluster_key}"
|
| 480 |
+
results[cluster_key]["nodes"].add(node_id)
|
| 481 |
+
results[cluster_key]["edges"].update(
|
| 482 |
+
[
|
| 483 |
+
tuple(sorted([node_id, str(connected)]))
|
| 484 |
+
for connected in connected_nodes
|
| 485 |
+
if connected != node_id
|
| 486 |
+
]
|
| 487 |
+
)
|
| 488 |
+
chunk_ids = source_id.split(GRAPH_FIELD_SEP)
|
| 489 |
+
results[cluster_key]["chunk_ids"].update(chunk_ids)
|
| 490 |
+
max_num_ids = max(
|
| 491 |
+
max_num_ids, len(results[cluster_key]["chunk_ids"])
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
# Process results
|
| 495 |
+
for k, v in results.items():
|
| 496 |
+
v["edges"] = [list(e) for e in v["edges"]]
|
| 497 |
+
v["nodes"] = list(v["nodes"])
|
| 498 |
+
v["chunk_ids"] = list(v["chunk_ids"])
|
| 499 |
+
v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
|
| 500 |
+
|
| 501 |
+
# Compute sub-communities (this is a simplified approach)
|
| 502 |
+
for cluster in results.values():
|
| 503 |
+
cluster["sub_communities"] = [
|
| 504 |
+
sub_key
|
| 505 |
+
for sub_key, sub_cluster in results.items()
|
| 506 |
+
if sub_cluster["level"] > cluster["level"]
|
| 507 |
+
and set(sub_cluster["nodes"]).issubset(set(cluster["nodes"]))
|
| 508 |
+
]
|
| 509 |
+
|
| 510 |
+
return dict(results)
|
| 511 |
+
|
| 512 |
+
async def index_done_callback(self):
|
| 513 |
+
await self.async_driver.close()
|
| 514 |
+
|
| 515 |
+
async def _debug_delete_all_node_edges(self):
|
| 516 |
+
async with self.async_driver.session() as session:
|
| 517 |
+
try:
|
| 518 |
+
# Delete all relationships in the namespace
|
| 519 |
+
await session.run(f"MATCH (n:`{self.namespace}`)-[r]-() DELETE r")
|
| 520 |
+
|
| 521 |
+
# Delete all nodes in the namespace
|
| 522 |
+
await session.run(f"MATCH (n:`{self.namespace}`) DELETE n")
|
| 523 |
+
|
| 524 |
+
logger.info(
|
| 525 |
+
f"All nodes and edges in namespace '{self.namespace}' have been deleted."
|
| 526 |
+
)
|
| 527 |
+
except Exception as e:
|
| 528 |
+
logger.error(f"Error deleting nodes and edges: {str(e)}")
|
| 529 |
+
raise
|
nano-graphrag/nano_graphrag/_storage/gdb_networkx.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import html
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Any, Union, cast, List
|
| 7 |
+
import networkx as nx
|
| 8 |
+
import numpy as np
|
| 9 |
+
import asyncio
|
| 10 |
+
|
| 11 |
+
from .._utils import logger
|
| 12 |
+
from ..base import (
|
| 13 |
+
BaseGraphStorage,
|
| 14 |
+
SingleCommunitySchema,
|
| 15 |
+
)
|
| 16 |
+
from ..prompt import GRAPH_FIELD_SEP
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class NetworkXStorage(BaseGraphStorage):
|
| 21 |
+
@staticmethod
|
| 22 |
+
def load_nx_graph(file_name) -> nx.Graph:
|
| 23 |
+
if os.path.exists(file_name):
|
| 24 |
+
return nx.read_graphml(file_name)
|
| 25 |
+
return None
|
| 26 |
+
|
| 27 |
+
@staticmethod
|
| 28 |
+
def write_nx_graph(graph: nx.Graph, file_name):
|
| 29 |
+
logger.info(
|
| 30 |
+
f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
|
| 31 |
+
)
|
| 32 |
+
nx.write_graphml(graph, file_name)
|
| 33 |
+
|
| 34 |
+
@staticmethod
|
| 35 |
+
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
|
| 36 |
+
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
| 37 |
+
Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
|
| 38 |
+
"""
|
| 39 |
+
from graspologic.utils import largest_connected_component
|
| 40 |
+
|
| 41 |
+
graph = graph.copy()
|
| 42 |
+
graph = cast(nx.Graph, largest_connected_component(graph))
|
| 43 |
+
node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
|
| 44 |
+
graph = nx.relabel_nodes(graph, node_mapping)
|
| 45 |
+
return NetworkXStorage._stabilize_graph(graph)
|
| 46 |
+
|
| 47 |
+
@staticmethod
|
| 48 |
+
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
|
| 49 |
+
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
| 50 |
+
Ensure an undirected graph with the same relationships will always be read the same way.
|
| 51 |
+
"""
|
| 52 |
+
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
|
| 53 |
+
|
| 54 |
+
sorted_nodes = graph.nodes(data=True)
|
| 55 |
+
sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
|
| 56 |
+
|
| 57 |
+
fixed_graph.add_nodes_from(sorted_nodes)
|
| 58 |
+
edges = list(graph.edges(data=True))
|
| 59 |
+
|
| 60 |
+
if not graph.is_directed():
|
| 61 |
+
|
| 62 |
+
def _sort_source_target(edge):
|
| 63 |
+
source, target, edge_data = edge
|
| 64 |
+
if source > target:
|
| 65 |
+
temp = source
|
| 66 |
+
source = target
|
| 67 |
+
target = temp
|
| 68 |
+
return source, target, edge_data
|
| 69 |
+
|
| 70 |
+
edges = [_sort_source_target(edge) for edge in edges]
|
| 71 |
+
|
| 72 |
+
def _get_edge_key(source: Any, target: Any) -> str:
|
| 73 |
+
return f"{source} -> {target}"
|
| 74 |
+
|
| 75 |
+
edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
|
| 76 |
+
|
| 77 |
+
fixed_graph.add_edges_from(edges)
|
| 78 |
+
return fixed_graph
|
| 79 |
+
|
| 80 |
+
def __post_init__(self):
|
| 81 |
+
self._graphml_xml_file = os.path.join(
|
| 82 |
+
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
|
| 83 |
+
)
|
| 84 |
+
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
|
| 85 |
+
if preloaded_graph is not None:
|
| 86 |
+
logger.info(
|
| 87 |
+
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
|
| 88 |
+
)
|
| 89 |
+
self._graph = preloaded_graph or nx.Graph()
|
| 90 |
+
self._clustering_algorithms = {
|
| 91 |
+
"leiden": self._leiden_clustering,
|
| 92 |
+
}
|
| 93 |
+
self._node_embed_algorithms = {
|
| 94 |
+
"node2vec": self._node2vec_embed,
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
async def index_done_callback(self):
|
| 98 |
+
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
|
| 99 |
+
|
| 100 |
+
async def has_node(self, node_id: str) -> bool:
|
| 101 |
+
return self._graph.has_node(node_id)
|
| 102 |
+
|
| 103 |
+
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
| 104 |
+
return self._graph.has_edge(source_node_id, target_node_id)
|
| 105 |
+
|
| 106 |
+
async def get_node(self, node_id: str) -> Union[dict, None]:
|
| 107 |
+
return self._graph.nodes.get(node_id)
|
| 108 |
+
|
| 109 |
+
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, Union[dict, None]]:
|
| 110 |
+
return await asyncio.gather(*[self.get_node(node_id) for node_id in node_ids])
|
| 111 |
+
|
| 112 |
+
async def node_degree(self, node_id: str) -> int:
|
| 113 |
+
# [numberchiffre]: node_id not part of graph returns `DegreeView({})` instead of 0
|
| 114 |
+
return self._graph.degree(node_id) if self._graph.has_node(node_id) else 0
|
| 115 |
+
|
| 116 |
+
async def node_degrees_batch(self, node_ids: List[str]) -> List[str]:
|
| 117 |
+
return await asyncio.gather(*[self.node_degree(node_id) for node_id in node_ids])
|
| 118 |
+
|
| 119 |
+
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
| 120 |
+
return (self._graph.degree(src_id) if self._graph.has_node(src_id) else 0) + (
|
| 121 |
+
self._graph.degree(tgt_id) if self._graph.has_node(tgt_id) else 0
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
|
| 125 |
+
return await asyncio.gather(*[self.edge_degree(src_id, tgt_id) for src_id, tgt_id in edge_pairs])
|
| 126 |
+
|
| 127 |
+
async def get_edge(
|
| 128 |
+
self, source_node_id: str, target_node_id: str
|
| 129 |
+
) -> Union[dict, None]:
|
| 130 |
+
return self._graph.edges.get((source_node_id, target_node_id))
|
| 131 |
+
|
| 132 |
+
async def get_edges_batch(
|
| 133 |
+
self, edge_pairs: list[tuple[str, str]]
|
| 134 |
+
) -> list[Union[dict, None]]:
|
| 135 |
+
return await asyncio.gather(*[self.get_edge(source_node_id, target_node_id) for source_node_id, target_node_id in edge_pairs])
|
| 136 |
+
|
| 137 |
+
async def get_node_edges(self, source_node_id: str):
|
| 138 |
+
if self._graph.has_node(source_node_id):
|
| 139 |
+
return list(self._graph.edges(source_node_id))
|
| 140 |
+
return None
|
| 141 |
+
|
| 142 |
+
async def get_nodes_edges_batch(
|
| 143 |
+
self, node_ids: list[str]
|
| 144 |
+
) -> list[list[tuple[str, str]]]:
|
| 145 |
+
return await asyncio.gather(*[self.get_node_edges(node_id) for node_id
|
| 146 |
+
in node_ids])
|
| 147 |
+
|
| 148 |
+
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
| 149 |
+
self._graph.add_node(node_id, **node_data)
|
| 150 |
+
|
| 151 |
+
async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
|
| 152 |
+
await asyncio.gather(*[self.upsert_node(node_id, node_data) for node_id, node_data in nodes_data])
|
| 153 |
+
|
| 154 |
+
async def upsert_edge(
|
| 155 |
+
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 156 |
+
):
|
| 157 |
+
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
| 158 |
+
|
| 159 |
+
async def upsert_edges_batch(
|
| 160 |
+
self, edges_data: list[tuple[str, str, dict[str, str]]]
|
| 161 |
+
):
|
| 162 |
+
await asyncio.gather(*[self.upsert_edge(source_node_id, target_node_id, edge_data)
|
| 163 |
+
for source_node_id, target_node_id, edge_data in edges_data])
|
| 164 |
+
|
| 165 |
+
async def clustering(self, algorithm: str):
|
| 166 |
+
if algorithm not in self._clustering_algorithms:
|
| 167 |
+
raise ValueError(f"Clustering algorithm {algorithm} not supported")
|
| 168 |
+
await self._clustering_algorithms[algorithm]()
|
| 169 |
+
|
| 170 |
+
async def community_schema(self) -> dict[str, SingleCommunitySchema]:
|
| 171 |
+
results = defaultdict(
|
| 172 |
+
lambda: dict(
|
| 173 |
+
level=None,
|
| 174 |
+
title=None,
|
| 175 |
+
edges=set(),
|
| 176 |
+
nodes=set(),
|
| 177 |
+
chunk_ids=set(),
|
| 178 |
+
occurrence=0.0,
|
| 179 |
+
sub_communities=[],
|
| 180 |
+
)
|
| 181 |
+
)
|
| 182 |
+
max_num_ids = 0
|
| 183 |
+
levels = defaultdict(set)
|
| 184 |
+
for node_id, node_data in self._graph.nodes(data=True):
|
| 185 |
+
if "clusters" not in node_data:
|
| 186 |
+
continue
|
| 187 |
+
clusters = json.loads(node_data["clusters"])
|
| 188 |
+
this_node_edges = self._graph.edges(node_id)
|
| 189 |
+
|
| 190 |
+
for cluster in clusters:
|
| 191 |
+
level = cluster["level"]
|
| 192 |
+
cluster_key = str(cluster["cluster"])
|
| 193 |
+
levels[level].add(cluster_key)
|
| 194 |
+
results[cluster_key]["level"] = level
|
| 195 |
+
results[cluster_key]["title"] = f"Cluster {cluster_key}"
|
| 196 |
+
results[cluster_key]["nodes"].add(node_id)
|
| 197 |
+
results[cluster_key]["edges"].update(
|
| 198 |
+
[tuple(sorted(e)) for e in this_node_edges]
|
| 199 |
+
)
|
| 200 |
+
results[cluster_key]["chunk_ids"].update(
|
| 201 |
+
node_data["source_id"].split(GRAPH_FIELD_SEP)
|
| 202 |
+
)
|
| 203 |
+
max_num_ids = max(max_num_ids, len(results[cluster_key]["chunk_ids"]))
|
| 204 |
+
|
| 205 |
+
ordered_levels = sorted(levels.keys())
|
| 206 |
+
for i, curr_level in enumerate(ordered_levels[:-1]):
|
| 207 |
+
next_level = ordered_levels[i + 1]
|
| 208 |
+
this_level_comms = levels[curr_level]
|
| 209 |
+
next_level_comms = levels[next_level]
|
| 210 |
+
# compute the sub-communities by nodes intersection
|
| 211 |
+
for comm in this_level_comms:
|
| 212 |
+
results[comm]["sub_communities"] = [
|
| 213 |
+
c
|
| 214 |
+
for c in next_level_comms
|
| 215 |
+
if results[c]["nodes"].issubset(results[comm]["nodes"])
|
| 216 |
+
]
|
| 217 |
+
|
| 218 |
+
for k, v in results.items():
|
| 219 |
+
v["edges"] = list(v["edges"])
|
| 220 |
+
v["edges"] = [list(e) for e in v["edges"]]
|
| 221 |
+
v["nodes"] = list(v["nodes"])
|
| 222 |
+
v["chunk_ids"] = list(v["chunk_ids"])
|
| 223 |
+
v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
|
| 224 |
+
return dict(results)
|
| 225 |
+
|
| 226 |
+
def _cluster_data_to_subgraphs(self, cluster_data: dict[str, list[dict[str, str]]]):
|
| 227 |
+
for node_id, clusters in cluster_data.items():
|
| 228 |
+
self._graph.nodes[node_id]["clusters"] = json.dumps(clusters)
|
| 229 |
+
|
| 230 |
+
async def _leiden_clustering(self):
|
| 231 |
+
from graspologic.partition import hierarchical_leiden
|
| 232 |
+
|
| 233 |
+
graph = NetworkXStorage.stable_largest_connected_component(self._graph)
|
| 234 |
+
community_mapping = hierarchical_leiden(
|
| 235 |
+
graph,
|
| 236 |
+
max_cluster_size=self.global_config["max_graph_cluster_size"],
|
| 237 |
+
random_seed=self.global_config["graph_cluster_seed"],
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
node_communities: dict[str, list[dict[str, str]]] = defaultdict(list)
|
| 241 |
+
__levels = defaultdict(set)
|
| 242 |
+
for partition in community_mapping:
|
| 243 |
+
level_key = partition.level
|
| 244 |
+
cluster_id = partition.cluster
|
| 245 |
+
node_communities[partition.node].append(
|
| 246 |
+
{"level": level_key, "cluster": cluster_id}
|
| 247 |
+
)
|
| 248 |
+
__levels[level_key].add(cluster_id)
|
| 249 |
+
node_communities = dict(node_communities)
|
| 250 |
+
__levels = {k: len(v) for k, v in __levels.items()}
|
| 251 |
+
logger.info(f"Each level has communities: {dict(__levels)}")
|
| 252 |
+
self._cluster_data_to_subgraphs(node_communities)
|
| 253 |
+
|
| 254 |
+
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
| 255 |
+
if algorithm not in self._node_embed_algorithms:
|
| 256 |
+
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
| 257 |
+
return await self._node_embed_algorithms[algorithm]()
|
| 258 |
+
|
| 259 |
+
async def _node2vec_embed(self):
|
| 260 |
+
from graspologic import embed
|
| 261 |
+
|
| 262 |
+
embeddings, nodes = embed.node2vec_embed(
|
| 263 |
+
self._graph,
|
| 264 |
+
**self.global_config["node2vec_params"],
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
| 268 |
+
return embeddings, nodes_ids
|
nano-graphrag/nano_graphrag/_storage/kv_json.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
from .._utils import load_json, logger, write_json
|
| 5 |
+
from ..base import (
|
| 6 |
+
BaseKVStorage,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class JsonKVStorage(BaseKVStorage):
|
| 12 |
+
def __post_init__(self):
|
| 13 |
+
working_dir = self.global_config["working_dir"]
|
| 14 |
+
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
| 15 |
+
self._data = load_json(self._file_name) or {}
|
| 16 |
+
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
| 17 |
+
|
| 18 |
+
async def all_keys(self) -> list[str]:
|
| 19 |
+
return list(self._data.keys())
|
| 20 |
+
|
| 21 |
+
async def index_done_callback(self):
|
| 22 |
+
write_json(self._data, self._file_name)
|
| 23 |
+
|
| 24 |
+
async def get_by_id(self, id):
|
| 25 |
+
return self._data.get(id, None)
|
| 26 |
+
|
| 27 |
+
async def get_by_ids(self, ids, fields=None):
|
| 28 |
+
if fields is None:
|
| 29 |
+
return [self._data.get(id, None) for id in ids]
|
| 30 |
+
return [
|
| 31 |
+
(
|
| 32 |
+
{k: v for k, v in self._data[id].items() if k in fields}
|
| 33 |
+
if self._data.get(id, None)
|
| 34 |
+
else None
|
| 35 |
+
)
|
| 36 |
+
for id in ids
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
async def filter_keys(self, data: list[str]) -> set[str]:
|
| 40 |
+
return set([s for s in data if s not in self._data])
|
| 41 |
+
|
| 42 |
+
async def upsert(self, data: dict[str, dict]):
|
| 43 |
+
self._data.update(data)
|
| 44 |
+
|
| 45 |
+
async def drop(self):
|
| 46 |
+
self._data = {}
|
nano-graphrag/nano_graphrag/_storage/vdb_hnswlib.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import os
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Any
|
| 5 |
+
import pickle
|
| 6 |
+
import hnswlib
|
| 7 |
+
import numpy as np
|
| 8 |
+
import xxhash
|
| 9 |
+
|
| 10 |
+
from .._utils import logger
|
| 11 |
+
from ..base import BaseVectorStorage
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class HNSWVectorStorage(BaseVectorStorage):
|
| 16 |
+
ef_construction: int = 100
|
| 17 |
+
M: int = 16
|
| 18 |
+
max_elements: int = 1000000
|
| 19 |
+
ef_search: int = 50
|
| 20 |
+
num_threads: int = -1
|
| 21 |
+
_index: Any = field(init=False)
|
| 22 |
+
_metadata: dict[str, dict] = field(default_factory=dict)
|
| 23 |
+
_current_elements: int = 0
|
| 24 |
+
|
| 25 |
+
def __post_init__(self):
|
| 26 |
+
self._index_file_name = os.path.join(
|
| 27 |
+
self.global_config["working_dir"], f"{self.namespace}_hnsw.index"
|
| 28 |
+
)
|
| 29 |
+
self._metadata_file_name = os.path.join(
|
| 30 |
+
self.global_config["working_dir"], f"{self.namespace}_hnsw_metadata.pkl"
|
| 31 |
+
)
|
| 32 |
+
self._embedding_batch_num = self.global_config.get("embedding_batch_num", 100)
|
| 33 |
+
|
| 34 |
+
hnsw_params = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
| 35 |
+
self.ef_construction = hnsw_params.get("ef_construction", self.ef_construction)
|
| 36 |
+
self.M = hnsw_params.get("M", self.M)
|
| 37 |
+
self.max_elements = hnsw_params.get("max_elements", self.max_elements)
|
| 38 |
+
self.ef_search = hnsw_params.get("ef_search", self.ef_search)
|
| 39 |
+
self.num_threads = hnsw_params.get("num_threads", self.num_threads)
|
| 40 |
+
self._index = hnswlib.Index(
|
| 41 |
+
space="cosine", dim=self.embedding_func.embedding_dim
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
if os.path.exists(self._index_file_name) and os.path.exists(
|
| 45 |
+
self._metadata_file_name
|
| 46 |
+
):
|
| 47 |
+
self._index.load_index(
|
| 48 |
+
self._index_file_name, max_elements=self.max_elements
|
| 49 |
+
)
|
| 50 |
+
with open(self._metadata_file_name, "rb") as f:
|
| 51 |
+
self._metadata, self._current_elements = pickle.load(f)
|
| 52 |
+
logger.info(
|
| 53 |
+
f"Loaded existing index for {self.namespace} with {self._current_elements} elements"
|
| 54 |
+
)
|
| 55 |
+
else:
|
| 56 |
+
self._index.init_index(
|
| 57 |
+
max_elements=self.max_elements,
|
| 58 |
+
ef_construction=self.ef_construction,
|
| 59 |
+
M=self.M,
|
| 60 |
+
)
|
| 61 |
+
self._index.set_ef(self.ef_search)
|
| 62 |
+
self._metadata = {}
|
| 63 |
+
self._current_elements = 0
|
| 64 |
+
logger.info(f"Created new index for {self.namespace}")
|
| 65 |
+
|
| 66 |
+
async def upsert(self, data: dict[str, dict]) -> np.ndarray:
|
| 67 |
+
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
| 68 |
+
if not data:
|
| 69 |
+
logger.warning("You insert an empty data to vector DB")
|
| 70 |
+
return []
|
| 71 |
+
|
| 72 |
+
if self._current_elements + len(data) > self.max_elements:
|
| 73 |
+
raise ValueError(
|
| 74 |
+
f"Cannot insert {len(data)} elements. Current: {self._current_elements}, Max: {self.max_elements}"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
list_data = [
|
| 78 |
+
{
|
| 79 |
+
"id": k,
|
| 80 |
+
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
|
| 81 |
+
}
|
| 82 |
+
for k, v in data.items()
|
| 83 |
+
]
|
| 84 |
+
contents = [v["content"] for v in data.values()]
|
| 85 |
+
batch_size = min(self._embedding_batch_num, len(contents))
|
| 86 |
+
embeddings = np.concatenate(
|
| 87 |
+
await asyncio.gather(
|
| 88 |
+
*[
|
| 89 |
+
self.embedding_func(contents[i : i + batch_size])
|
| 90 |
+
for i in range(0, len(contents), batch_size)
|
| 91 |
+
]
|
| 92 |
+
)
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
ids = np.fromiter(
|
| 96 |
+
(xxhash.xxh32_intdigest(d["id"].encode()) for d in list_data),
|
| 97 |
+
dtype=np.uint32,
|
| 98 |
+
count=len(list_data),
|
| 99 |
+
)
|
| 100 |
+
self._metadata.update(
|
| 101 |
+
{
|
| 102 |
+
id_int: {
|
| 103 |
+
k: v for k, v in d.items() if k in self.meta_fields or k == "id"
|
| 104 |
+
}
|
| 105 |
+
for id_int, d in zip(ids, list_data)
|
| 106 |
+
}
|
| 107 |
+
)
|
| 108 |
+
self._index.add_items(data=embeddings, ids=ids, num_threads=self.num_threads)
|
| 109 |
+
self._current_elements = self._index.get_current_count()
|
| 110 |
+
return ids
|
| 111 |
+
|
| 112 |
+
async def query(self, query: str, top_k: int = 5) -> list[dict]:
|
| 113 |
+
if self._current_elements == 0:
|
| 114 |
+
return []
|
| 115 |
+
|
| 116 |
+
top_k = min(top_k, self._current_elements)
|
| 117 |
+
|
| 118 |
+
if top_k > self.ef_search:
|
| 119 |
+
logger.warning(
|
| 120 |
+
f"Setting ef_search to {top_k} because top_k is larger than ef_search"
|
| 121 |
+
)
|
| 122 |
+
self._index.set_ef(top_k)
|
| 123 |
+
|
| 124 |
+
embedding = await self.embedding_func([query])
|
| 125 |
+
labels, distances = self._index.knn_query(
|
| 126 |
+
data=embedding[0], k=top_k, num_threads=self.num_threads
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
return [
|
| 130 |
+
{
|
| 131 |
+
**self._metadata.get(label, {}),
|
| 132 |
+
"distance": distance,
|
| 133 |
+
"similarity": 1 - distance,
|
| 134 |
+
}
|
| 135 |
+
for label, distance in zip(labels[0], distances[0])
|
| 136 |
+
]
|
| 137 |
+
|
| 138 |
+
async def index_done_callback(self):
|
| 139 |
+
self._index.save_index(self._index_file_name)
|
| 140 |
+
with open(self._metadata_file_name, "wb") as f:
|
| 141 |
+
pickle.dump((self._metadata, self._current_elements), f)
|
nano-graphrag/nano_graphrag/_storage/vdb_nanovectordb.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import os
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
import numpy as np
|
| 5 |
+
from nano_vectordb import NanoVectorDB
|
| 6 |
+
|
| 7 |
+
from .._utils import logger
|
| 8 |
+
from ..base import BaseVectorStorage
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class NanoVectorDBStorage(BaseVectorStorage):
|
| 13 |
+
cosine_better_than_threshold: float = 0.2
|
| 14 |
+
|
| 15 |
+
def __post_init__(self):
|
| 16 |
+
|
| 17 |
+
self._client_file_name = os.path.join(
|
| 18 |
+
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
| 19 |
+
)
|
| 20 |
+
self._max_batch_size = self.global_config["embedding_batch_num"]
|
| 21 |
+
self._client = NanoVectorDB(
|
| 22 |
+
self.embedding_func.embedding_dim, storage_file=self._client_file_name
|
| 23 |
+
)
|
| 24 |
+
self.cosine_better_than_threshold = self.global_config.get(
|
| 25 |
+
"query_better_than_threshold", self.cosine_better_than_threshold
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
async def upsert(self, data: dict[str, dict]):
|
| 29 |
+
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
| 30 |
+
if not len(data):
|
| 31 |
+
logger.warning("You insert an empty data to vector DB")
|
| 32 |
+
return []
|
| 33 |
+
list_data = [
|
| 34 |
+
{
|
| 35 |
+
"__id__": k,
|
| 36 |
+
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
|
| 37 |
+
}
|
| 38 |
+
for k, v in data.items()
|
| 39 |
+
]
|
| 40 |
+
contents = [v["content"] for v in data.values()]
|
| 41 |
+
batches = [
|
| 42 |
+
contents[i : i + self._max_batch_size]
|
| 43 |
+
for i in range(0, len(contents), self._max_batch_size)
|
| 44 |
+
]
|
| 45 |
+
embeddings_list = await asyncio.gather(
|
| 46 |
+
*[self.embedding_func(batch) for batch in batches]
|
| 47 |
+
)
|
| 48 |
+
embeddings = np.concatenate(embeddings_list)
|
| 49 |
+
for i, d in enumerate(list_data):
|
| 50 |
+
d["__vector__"] = embeddings[i]
|
| 51 |
+
results = self._client.upsert(datas=list_data)
|
| 52 |
+
return results
|
| 53 |
+
|
| 54 |
+
async def query(self, query: str, top_k=5):
|
| 55 |
+
embedding = await self.embedding_func([query])
|
| 56 |
+
embedding = embedding[0]
|
| 57 |
+
results = self._client.query(
|
| 58 |
+
query=embedding,
|
| 59 |
+
top_k=top_k,
|
| 60 |
+
better_than_threshold=self.cosine_better_than_threshold,
|
| 61 |
+
)
|
| 62 |
+
results = [
|
| 63 |
+
{**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
|
| 64 |
+
]
|
| 65 |
+
return results
|
| 66 |
+
|
| 67 |
+
async def index_done_callback(self):
|
| 68 |
+
self._client.save()
|
nano-graphrag/nano_graphrag/_utils.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import html
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import numbers
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from functools import wraps
|
| 10 |
+
from hashlib import md5
|
| 11 |
+
from typing import Any, Union, Literal
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import tiktoken
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
from transformers import AutoTokenizer
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger("nano-graphrag")
|
| 20 |
+
logging.getLogger("neo4j").setLevel(logging.ERROR)
|
| 21 |
+
|
| 22 |
+
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
| 23 |
+
try:
|
| 24 |
+
# If there is already an event loop, use it.
|
| 25 |
+
loop = asyncio.get_event_loop()
|
| 26 |
+
except RuntimeError:
|
| 27 |
+
# If in a sub-thread, create a new event loop.
|
| 28 |
+
logger.info("Creating a new event loop in a sub-thread.")
|
| 29 |
+
loop = asyncio.new_event_loop()
|
| 30 |
+
asyncio.set_event_loop(loop)
|
| 31 |
+
return loop
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def extract_first_complete_json(s: str):
|
| 35 |
+
"""Extract the first complete JSON object from the string using a stack to track braces."""
|
| 36 |
+
stack = []
|
| 37 |
+
first_json_start = None
|
| 38 |
+
|
| 39 |
+
for i, char in enumerate(s):
|
| 40 |
+
if char == '{':
|
| 41 |
+
stack.append(i)
|
| 42 |
+
if first_json_start is None:
|
| 43 |
+
first_json_start = i
|
| 44 |
+
elif char == '}':
|
| 45 |
+
if stack:
|
| 46 |
+
start = stack.pop()
|
| 47 |
+
if not stack:
|
| 48 |
+
first_json_str = s[first_json_start:i+1]
|
| 49 |
+
try:
|
| 50 |
+
# Attempt to parse the JSON string
|
| 51 |
+
return json.loads(first_json_str.replace("\n", ""))
|
| 52 |
+
except json.JSONDecodeError as e:
|
| 53 |
+
logger.error(f"JSON decoding failed: {e}. Attempted string: {first_json_str[:50]}...")
|
| 54 |
+
return None
|
| 55 |
+
finally:
|
| 56 |
+
first_json_start = None
|
| 57 |
+
logger.warning("No complete JSON object found in the input string.")
|
| 58 |
+
return None
|
| 59 |
+
|
| 60 |
+
def parse_value(value: str):
|
| 61 |
+
"""Convert a string value to its appropriate type (int, float, bool, None, or keep as string). Work as a more broad 'eval()'"""
|
| 62 |
+
value = value.strip()
|
| 63 |
+
|
| 64 |
+
if value == "null":
|
| 65 |
+
return None
|
| 66 |
+
elif value == "true":
|
| 67 |
+
return True
|
| 68 |
+
elif value == "false":
|
| 69 |
+
return False
|
| 70 |
+
else:
|
| 71 |
+
# Try to convert to int or float
|
| 72 |
+
try:
|
| 73 |
+
if '.' in value: # If there's a dot, it might be a float
|
| 74 |
+
return float(value)
|
| 75 |
+
else:
|
| 76 |
+
return int(value)
|
| 77 |
+
except ValueError:
|
| 78 |
+
# If conversion fails, return the value as-is (likely a string)
|
| 79 |
+
return value.strip('"') # Remove surrounding quotes if they exist
|
| 80 |
+
|
| 81 |
+
def extract_values_from_json(json_string, keys=["reasoning", "answer", "data"], allow_no_quotes=False):
|
| 82 |
+
"""Extract key values from a non-standard or malformed JSON string, handling nested objects."""
|
| 83 |
+
extracted_values = {}
|
| 84 |
+
|
| 85 |
+
# Enhanced pattern to match both quoted and unquoted values, as well as nested objects
|
| 86 |
+
regex_pattern = r'(?P<key>"?\w+"?)\s*:\s*(?P<value>{[^}]*}|".*?"|[^,}]+)'
|
| 87 |
+
|
| 88 |
+
for match in re.finditer(regex_pattern, json_string, re.DOTALL):
|
| 89 |
+
key = match.group('key').strip('"') # Strip quotes from key
|
| 90 |
+
value = match.group('value').strip()
|
| 91 |
+
|
| 92 |
+
# If the value is another nested JSON (starts with '{' and ends with '}'), recursively parse it
|
| 93 |
+
if value.startswith('{') and value.endswith('}'):
|
| 94 |
+
extracted_values[key] = extract_values_from_json(value)
|
| 95 |
+
else:
|
| 96 |
+
# Parse the value into the appropriate type (int, float, bool, etc.)
|
| 97 |
+
extracted_values[key] = parse_value(value)
|
| 98 |
+
|
| 99 |
+
if not extracted_values:
|
| 100 |
+
logger.warning("No values could be extracted from the string.")
|
| 101 |
+
|
| 102 |
+
return extracted_values
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def convert_response_to_json(response: str) -> dict:
|
| 106 |
+
"""Convert response string to JSON, with error handling and fallback to non-standard JSON extraction."""
|
| 107 |
+
prediction_json = extract_first_complete_json(response)
|
| 108 |
+
|
| 109 |
+
if prediction_json is None:
|
| 110 |
+
logger.info("Attempting to extract values from a non-standard JSON string...")
|
| 111 |
+
prediction_json = extract_values_from_json(response, allow_no_quotes=True)
|
| 112 |
+
|
| 113 |
+
if not prediction_json:
|
| 114 |
+
logger.error("Unable to extract meaningful data from the response.")
|
| 115 |
+
else:
|
| 116 |
+
logger.info("JSON data successfully extracted.")
|
| 117 |
+
|
| 118 |
+
return prediction_json
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class TokenizerWrapper:
|
| 124 |
+
def __init__(self, tokenizer_type: Literal["tiktoken", "huggingface"] = "tiktoken", model_name: str = "gpt-4o"):
|
| 125 |
+
self.tokenizer_type = tokenizer_type
|
| 126 |
+
self.model_name = model_name
|
| 127 |
+
self._tokenizer = None
|
| 128 |
+
self._lazy_load_tokenizer()
|
| 129 |
+
|
| 130 |
+
def _lazy_load_tokenizer(self):
|
| 131 |
+
if self._tokenizer is not None:
|
| 132 |
+
return
|
| 133 |
+
logger.info(f"Loading tokenizer: type='{self.tokenizer_type}', name='{self.model_name}'")
|
| 134 |
+
if self.tokenizer_type == "tiktoken":
|
| 135 |
+
self._tokenizer = tiktoken.encoding_for_model(self.model_name)
|
| 136 |
+
elif self.tokenizer_type == "huggingface":
|
| 137 |
+
if AutoTokenizer is None:
|
| 138 |
+
raise ImportError("`transformers` is not installed. Please install it via `pip install transformers` to use HuggingFace tokenizers.")
|
| 139 |
+
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
|
| 140 |
+
else:
|
| 141 |
+
raise ValueError(f"Unknown tokenizer_type: {self.tokenizer_type}")
|
| 142 |
+
|
| 143 |
+
def get_tokenizer(self):
|
| 144 |
+
"""提供对底层 tokenizer 对象的访问,用于特殊情况(如 decode_batch)。"""
|
| 145 |
+
self._lazy_load_tokenizer()
|
| 146 |
+
return self._tokenizer
|
| 147 |
+
|
| 148 |
+
def encode(self, text: str) -> list[int]:
|
| 149 |
+
self._lazy_load_tokenizer()
|
| 150 |
+
return self._tokenizer.encode(text)
|
| 151 |
+
|
| 152 |
+
def decode(self, tokens: list[int]) -> str:
|
| 153 |
+
self._lazy_load_tokenizer()
|
| 154 |
+
return self._tokenizer.decode(tokens)
|
| 155 |
+
|
| 156 |
+
# +++ 新增 +++: 增加一个批量解码的方法以提高效率,并保持接口一致性
|
| 157 |
+
def decode_batch(self, tokens_list: list[list[int]]) -> list[str]:
|
| 158 |
+
self._lazy_load_tokenizer()
|
| 159 |
+
# HuggingFace tokenizer 有 decode_batch,但 tiktoken 没有,我们用列表推导来模拟
|
| 160 |
+
if self.tokenizer_type == "tiktoken":
|
| 161 |
+
return [self._tokenizer.decode(tokens) for tokens in tokens_list]
|
| 162 |
+
elif self.tokenizer_type == "huggingface":
|
| 163 |
+
return self._tokenizer.batch_decode(tokens_list, skip_special_tokens=True)
|
| 164 |
+
else:
|
| 165 |
+
raise ValueError(f"Unknown tokenizer_type: {self.tokenizer_type}")
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def truncate_list_by_token_size(
|
| 170 |
+
list_data: list,
|
| 171 |
+
key: callable,
|
| 172 |
+
max_token_size: int,
|
| 173 |
+
tokenizer_wrapper: TokenizerWrapper
|
| 174 |
+
):
|
| 175 |
+
"""Truncate a list of data by token size using a provided tokenizer wrapper."""
|
| 176 |
+
if max_token_size <= 0:
|
| 177 |
+
return []
|
| 178 |
+
tokens = 0
|
| 179 |
+
for i, data in enumerate(list_data):
|
| 180 |
+
tokens += len(tokenizer_wrapper.encode(key(data))) + 1 # 防御性,模拟通过\n拼接列表的情况
|
| 181 |
+
if tokens > max_token_size:
|
| 182 |
+
return list_data[:i]
|
| 183 |
+
return list_data
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def compute_mdhash_id(content, prefix: str = ""):
|
| 187 |
+
return prefix + md5(content.encode()).hexdigest()
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def write_json(json_obj, file_name):
|
| 191 |
+
with open(file_name, "w", encoding="utf-8") as f:
|
| 192 |
+
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def load_json(file_name):
|
| 196 |
+
if not os.path.exists(file_name):
|
| 197 |
+
return None
|
| 198 |
+
with open(file_name, encoding="utf-8") as f:
|
| 199 |
+
return json.load(f)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# it's dirty to type, so it's a good way to have fun
|
| 203 |
+
def pack_user_ass_to_openai_messages(prompt: str, generated_content: str, using_amazon_bedrock: bool):
|
| 204 |
+
if using_amazon_bedrock:
|
| 205 |
+
return [
|
| 206 |
+
{"role": "user", "content": [{"text": prompt}]},
|
| 207 |
+
{"role": "assistant", "content": [{"text": generated_content}]},
|
| 208 |
+
]
|
| 209 |
+
else:
|
| 210 |
+
return [
|
| 211 |
+
{"role": "user", "content": prompt},
|
| 212 |
+
{"role": "assistant", "content": generated_content},
|
| 213 |
+
]
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def is_float_regex(value):
|
| 217 |
+
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def compute_args_hash(*args):
|
| 221 |
+
return md5(str(args).encode()).hexdigest()
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
|
| 225 |
+
"""Split a string by multiple markers"""
|
| 226 |
+
if not markers:
|
| 227 |
+
return [content]
|
| 228 |
+
results = re.split("|".join(re.escape(marker) for marker in markers), content)
|
| 229 |
+
return [r.strip() for r in results if r.strip()]
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def enclose_string_with_quotes(content: Any) -> str:
|
| 233 |
+
"""Enclose a string with quotes"""
|
| 234 |
+
if isinstance(content, numbers.Number):
|
| 235 |
+
return str(content)
|
| 236 |
+
content = str(content)
|
| 237 |
+
content = content.strip().strip("'").strip('"')
|
| 238 |
+
return f'"{content}"'
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def list_of_list_to_csv(data: list[list]):
|
| 242 |
+
return "\n".join(
|
| 243 |
+
[
|
| 244 |
+
",\t".join([f"{enclose_string_with_quotes(data_dd)}" for data_dd in data_d])
|
| 245 |
+
for data_d in data
|
| 246 |
+
]
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# -----------------------------------------------------------------------------------
|
| 251 |
+
# Refer the utils functions of the official GraphRAG implementation:
|
| 252 |
+
# https://github.com/microsoft/graphrag
|
| 253 |
+
def clean_str(input: Any) -> str:
|
| 254 |
+
"""Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
|
| 255 |
+
# If we get non-string input, just give it back
|
| 256 |
+
if not isinstance(input, str):
|
| 257 |
+
return input
|
| 258 |
+
|
| 259 |
+
result = html.unescape(input.strip())
|
| 260 |
+
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
|
| 261 |
+
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# Utils types -----------------------------------------------------------------------
|
| 265 |
+
@dataclass
|
| 266 |
+
class EmbeddingFunc:
|
| 267 |
+
embedding_dim: int
|
| 268 |
+
max_token_size: int
|
| 269 |
+
func: callable
|
| 270 |
+
|
| 271 |
+
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
| 272 |
+
return await self.func(*args, **kwargs)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# Decorators ------------------------------------------------------------------------
|
| 276 |
+
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
|
| 277 |
+
"""Add restriction of maximum async calling times for a async func"""
|
| 278 |
+
|
| 279 |
+
def final_decro(func):
|
| 280 |
+
"""Not using async.Semaphore to aovid use nest-asyncio"""
|
| 281 |
+
__current_size = 0
|
| 282 |
+
|
| 283 |
+
@wraps(func)
|
| 284 |
+
async def wait_func(*args, **kwargs):
|
| 285 |
+
nonlocal __current_size
|
| 286 |
+
while __current_size >= max_size:
|
| 287 |
+
await asyncio.sleep(waitting_time)
|
| 288 |
+
__current_size += 1
|
| 289 |
+
result = await func(*args, **kwargs)
|
| 290 |
+
__current_size -= 1
|
| 291 |
+
return result
|
| 292 |
+
|
| 293 |
+
return wait_func
|
| 294 |
+
|
| 295 |
+
return final_decro
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def wrap_embedding_func_with_attrs(**kwargs):
|
| 299 |
+
"""Wrap a function with attributes"""
|
| 300 |
+
|
| 301 |
+
def final_decro(func) -> EmbeddingFunc:
|
| 302 |
+
new_func = EmbeddingFunc(**kwargs, func=func)
|
| 303 |
+
return new_func
|
| 304 |
+
|
| 305 |
+
return final_decro
|
nano-graphrag/nano_graphrag/base.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import TypedDict, Union, Literal, Generic, TypeVar, List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from ._utils import EmbeddingFunc
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
|
| 10 |
+
class QueryParam:
|
| 11 |
+
mode: Literal["local", "global", "naive"] = "global"
|
| 12 |
+
only_need_context: bool = False
|
| 13 |
+
response_type: str = "Multiple Paragraphs"
|
| 14 |
+
level: int = 2
|
| 15 |
+
top_k: int = 20
|
| 16 |
+
# naive search
|
| 17 |
+
naive_max_token_for_text_unit = 12000
|
| 18 |
+
# local search
|
| 19 |
+
local_max_token_for_text_unit: int = 4000 # 12000 * 0.33
|
| 20 |
+
local_max_token_for_local_context: int = 4800 # 12000 * 0.4
|
| 21 |
+
local_max_token_for_community_report: int = 3200 # 12000 * 0.27
|
| 22 |
+
local_community_single_one: bool = False
|
| 23 |
+
# global search
|
| 24 |
+
global_min_community_rating: float = 0
|
| 25 |
+
global_max_consider_community: float = 512
|
| 26 |
+
global_max_token_for_community_report: int = 16384
|
| 27 |
+
global_special_community_map_llm_kwargs: dict = field(
|
| 28 |
+
default_factory=lambda: {"response_format": {"type": "json_object"}}
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
TextChunkSchema = TypedDict(
|
| 33 |
+
"TextChunkSchema",
|
| 34 |
+
{"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
SingleCommunitySchema = TypedDict(
|
| 38 |
+
"SingleCommunitySchema",
|
| 39 |
+
{
|
| 40 |
+
"level": int,
|
| 41 |
+
"title": str,
|
| 42 |
+
"edges": list[list[str, str]],
|
| 43 |
+
"nodes": list[str],
|
| 44 |
+
"chunk_ids": list[str],
|
| 45 |
+
"occurrence": float,
|
| 46 |
+
"sub_communities": list[str],
|
| 47 |
+
},
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class CommunitySchema(SingleCommunitySchema):
|
| 52 |
+
report_string: str
|
| 53 |
+
report_json: dict
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
T = TypeVar("T")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass
|
| 60 |
+
class StorageNameSpace:
|
| 61 |
+
namespace: str
|
| 62 |
+
global_config: dict
|
| 63 |
+
|
| 64 |
+
async def index_start_callback(self):
|
| 65 |
+
"""commit the storage operations after indexing"""
|
| 66 |
+
pass
|
| 67 |
+
|
| 68 |
+
async def index_done_callback(self):
|
| 69 |
+
"""commit the storage operations after indexing"""
|
| 70 |
+
pass
|
| 71 |
+
|
| 72 |
+
async def query_done_callback(self):
|
| 73 |
+
"""commit the storage operations after querying"""
|
| 74 |
+
pass
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass
|
| 78 |
+
class BaseVectorStorage(StorageNameSpace):
|
| 79 |
+
embedding_func: EmbeddingFunc
|
| 80 |
+
meta_fields: set = field(default_factory=set)
|
| 81 |
+
|
| 82 |
+
async def query(self, query: str, top_k: int) -> list[dict]:
|
| 83 |
+
raise NotImplementedError
|
| 84 |
+
|
| 85 |
+
async def upsert(self, data: dict[str, dict]):
|
| 86 |
+
"""Use 'content' field from value for embedding, use key as id.
|
| 87 |
+
If embedding_func is None, use 'embedding' field from value
|
| 88 |
+
"""
|
| 89 |
+
raise NotImplementedError
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@dataclass
|
| 93 |
+
class BaseKVStorage(Generic[T], StorageNameSpace):
|
| 94 |
+
async def all_keys(self) -> list[str]:
|
| 95 |
+
raise NotImplementedError
|
| 96 |
+
|
| 97 |
+
async def get_by_id(self, id: str) -> Union[T, None]:
|
| 98 |
+
raise NotImplementedError
|
| 99 |
+
|
| 100 |
+
async def get_by_ids(
|
| 101 |
+
self, ids: list[str], fields: Union[set[str], None] = None
|
| 102 |
+
) -> list[Union[T, None]]:
|
| 103 |
+
raise NotImplementedError
|
| 104 |
+
|
| 105 |
+
async def filter_keys(self, data: list[str]) -> set[str]:
|
| 106 |
+
"""return un-exist keys"""
|
| 107 |
+
raise NotImplementedError
|
| 108 |
+
|
| 109 |
+
async def upsert(self, data: dict[str, T]):
|
| 110 |
+
raise NotImplementedError
|
| 111 |
+
|
| 112 |
+
async def drop(self):
|
| 113 |
+
raise NotImplementedError
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@dataclass
|
| 117 |
+
class BaseGraphStorage(StorageNameSpace):
|
| 118 |
+
async def has_node(self, node_id: str) -> bool:
|
| 119 |
+
raise NotImplementedError
|
| 120 |
+
|
| 121 |
+
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
| 122 |
+
raise NotImplementedError
|
| 123 |
+
|
| 124 |
+
async def node_degree(self, node_id: str) -> int:
|
| 125 |
+
raise NotImplementedError
|
| 126 |
+
|
| 127 |
+
async def node_degrees_batch(self, node_ids: List[str]) -> List[str]:
|
| 128 |
+
raise NotImplementedError
|
| 129 |
+
|
| 130 |
+
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
| 131 |
+
raise NotImplementedError
|
| 132 |
+
|
| 133 |
+
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
|
| 134 |
+
raise NotImplementedError
|
| 135 |
+
|
| 136 |
+
async def get_node(self, node_id: str) -> Union[dict, None]:
|
| 137 |
+
raise NotImplementedError
|
| 138 |
+
|
| 139 |
+
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, Union[dict, None]]:
|
| 140 |
+
raise NotImplementedError
|
| 141 |
+
|
| 142 |
+
async def get_edge(
|
| 143 |
+
self, source_node_id: str, target_node_id: str
|
| 144 |
+
) -> Union[dict, None]:
|
| 145 |
+
raise NotImplementedError
|
| 146 |
+
|
| 147 |
+
async def get_edges_batch(
|
| 148 |
+
self, edge_pairs: list[tuple[str, str]]
|
| 149 |
+
) -> list[Union[dict, None]]:
|
| 150 |
+
raise NotImplementedError
|
| 151 |
+
|
| 152 |
+
async def get_node_edges(
|
| 153 |
+
self, source_node_id: str
|
| 154 |
+
) -> Union[list[tuple[str, str]], None]:
|
| 155 |
+
raise NotImplementedError
|
| 156 |
+
|
| 157 |
+
async def get_nodes_edges_batch(
|
| 158 |
+
self, node_ids: list[str]
|
| 159 |
+
) -> list[list[tuple[str, str]]]:
|
| 160 |
+
raise NotImplementedError
|
| 161 |
+
|
| 162 |
+
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
| 163 |
+
raise NotImplementedError
|
| 164 |
+
|
| 165 |
+
async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
|
| 166 |
+
raise NotImplementedError
|
| 167 |
+
|
| 168 |
+
async def upsert_edge(
|
| 169 |
+
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 170 |
+
):
|
| 171 |
+
raise NotImplementedError
|
| 172 |
+
|
| 173 |
+
async def upsert_edges_batch(
|
| 174 |
+
self, edges_data: list[tuple[str, str, dict[str, str]]]
|
| 175 |
+
):
|
| 176 |
+
raise NotImplementedError
|
| 177 |
+
|
| 178 |
+
async def clustering(self, algorithm: str):
|
| 179 |
+
raise NotImplementedError
|
| 180 |
+
|
| 181 |
+
async def community_schema(self) -> dict[str, SingleCommunitySchema]:
|
| 182 |
+
"""Return the community representation with report and nodes"""
|
| 183 |
+
raise NotImplementedError
|
| 184 |
+
|
| 185 |
+
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
| 186 |
+
raise NotImplementedError("Node embedding is not used in nano-graphrag.")
|
nano-graphrag/nano_graphrag/entity_extraction/__init__.py
ADDED
|
File without changes
|
nano-graphrag/nano_graphrag/entity_extraction/extract.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
|
| 2 |
+
import pickle
|
| 3 |
+
import asyncio
|
| 4 |
+
from openai import BadRequestError
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
import dspy
|
| 7 |
+
from nano_graphrag.base import (
|
| 8 |
+
BaseGraphStorage,
|
| 9 |
+
BaseVectorStorage,
|
| 10 |
+
TextChunkSchema,
|
| 11 |
+
)
|
| 12 |
+
from nano_graphrag.prompt import PROMPTS
|
| 13 |
+
from nano_graphrag._utils import logger, compute_mdhash_id
|
| 14 |
+
from nano_graphrag.entity_extraction.module import TypedEntityRelationshipExtractor
|
| 15 |
+
from nano_graphrag._op import _merge_edges_then_upsert, _merge_nodes_then_upsert
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
async def generate_dataset(
|
| 19 |
+
chunks: dict[str, TextChunkSchema],
|
| 20 |
+
filepath: str,
|
| 21 |
+
save_dataset: bool = True,
|
| 22 |
+
global_config: dict = {},
|
| 23 |
+
) -> list[dspy.Example]:
|
| 24 |
+
entity_extractor = TypedEntityRelationshipExtractor(num_refine_turns=1, self_refine=True)
|
| 25 |
+
|
| 26 |
+
if global_config.get("use_compiled_dspy_entity_relationship", False):
|
| 27 |
+
entity_extractor.load(global_config["entity_relationship_module_path"])
|
| 28 |
+
|
| 29 |
+
ordered_chunks = list(chunks.items())
|
| 30 |
+
already_processed = 0
|
| 31 |
+
already_entities = 0
|
| 32 |
+
already_relations = 0
|
| 33 |
+
|
| 34 |
+
async def _process_single_content(
|
| 35 |
+
chunk_key_dp: tuple[str, TextChunkSchema]
|
| 36 |
+
) -> dspy.Example:
|
| 37 |
+
nonlocal already_processed, already_entities, already_relations
|
| 38 |
+
chunk_dp = chunk_key_dp[1]
|
| 39 |
+
content = chunk_dp["content"]
|
| 40 |
+
try:
|
| 41 |
+
prediction = await asyncio.to_thread(entity_extractor, input_text=content)
|
| 42 |
+
entities, relationships = prediction.entities, prediction.relationships
|
| 43 |
+
except BadRequestError as e:
|
| 44 |
+
logger.error(f"Error in TypedEntityRelationshipExtractor: {e}")
|
| 45 |
+
entities, relationships = [], []
|
| 46 |
+
example = dspy.Example(
|
| 47 |
+
input_text=content, entities=entities, relationships=relationships
|
| 48 |
+
).with_inputs("input_text")
|
| 49 |
+
already_entities += len(entities)
|
| 50 |
+
already_relations += len(relationships)
|
| 51 |
+
already_processed += 1
|
| 52 |
+
now_ticks = PROMPTS["process_tickers"][
|
| 53 |
+
already_processed % len(PROMPTS["process_tickers"])
|
| 54 |
+
]
|
| 55 |
+
print(
|
| 56 |
+
f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
|
| 57 |
+
end="",
|
| 58 |
+
flush=True,
|
| 59 |
+
)
|
| 60 |
+
return example
|
| 61 |
+
|
| 62 |
+
examples = await asyncio.gather(
|
| 63 |
+
*[_process_single_content(c) for c in ordered_chunks]
|
| 64 |
+
)
|
| 65 |
+
filtered_examples = [
|
| 66 |
+
example
|
| 67 |
+
for example in examples
|
| 68 |
+
if len(example.entities) > 0 and len(example.relationships) > 0
|
| 69 |
+
]
|
| 70 |
+
num_filtered_examples = len(examples) - len(filtered_examples)
|
| 71 |
+
if save_dataset:
|
| 72 |
+
with open(filepath, "wb") as f:
|
| 73 |
+
pickle.dump(filtered_examples, f)
|
| 74 |
+
logger.info(
|
| 75 |
+
f"Saved {len(filtered_examples)} examples with keys: {filtered_examples[0].keys()}, filtered {num_filtered_examples} examples"
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
return filtered_examples
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
async def extract_entities_dspy(
|
| 82 |
+
chunks: dict[str, TextChunkSchema],
|
| 83 |
+
knwoledge_graph_inst: BaseGraphStorage,
|
| 84 |
+
entity_vdb: BaseVectorStorage,
|
| 85 |
+
global_config: dict,
|
| 86 |
+
) -> Union[BaseGraphStorage, None]:
|
| 87 |
+
entity_extractor = TypedEntityRelationshipExtractor(num_refine_turns=1, self_refine=True)
|
| 88 |
+
|
| 89 |
+
if global_config.get("use_compiled_dspy_entity_relationship", False):
|
| 90 |
+
entity_extractor.load(global_config["entity_relationship_module_path"])
|
| 91 |
+
|
| 92 |
+
ordered_chunks = list(chunks.items())
|
| 93 |
+
already_processed = 0
|
| 94 |
+
already_entities = 0
|
| 95 |
+
already_relations = 0
|
| 96 |
+
|
| 97 |
+
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
|
| 98 |
+
nonlocal already_processed, already_entities, already_relations
|
| 99 |
+
chunk_key = chunk_key_dp[0]
|
| 100 |
+
chunk_dp = chunk_key_dp[1]
|
| 101 |
+
content = chunk_dp["content"]
|
| 102 |
+
try:
|
| 103 |
+
prediction = await asyncio.to_thread(entity_extractor, input_text=content)
|
| 104 |
+
entities, relationships = prediction.entities, prediction.relationships
|
| 105 |
+
except BadRequestError as e:
|
| 106 |
+
logger.error(f"Error in TypedEntityRelationshipExtractor: {e}")
|
| 107 |
+
entities, relationships = [], []
|
| 108 |
+
|
| 109 |
+
maybe_nodes = defaultdict(list)
|
| 110 |
+
maybe_edges = defaultdict(list)
|
| 111 |
+
|
| 112 |
+
for entity in entities:
|
| 113 |
+
entity["source_id"] = chunk_key
|
| 114 |
+
maybe_nodes[entity["entity_name"]].append(entity)
|
| 115 |
+
already_entities += 1
|
| 116 |
+
|
| 117 |
+
for relationship in relationships:
|
| 118 |
+
relationship["source_id"] = chunk_key
|
| 119 |
+
maybe_edges[(relationship["src_id"], relationship["tgt_id"])].append(
|
| 120 |
+
relationship
|
| 121 |
+
)
|
| 122 |
+
already_relations += 1
|
| 123 |
+
|
| 124 |
+
already_processed += 1
|
| 125 |
+
now_ticks = PROMPTS["process_tickers"][
|
| 126 |
+
already_processed % len(PROMPTS["process_tickers"])
|
| 127 |
+
]
|
| 128 |
+
print(
|
| 129 |
+
f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
|
| 130 |
+
end="",
|
| 131 |
+
flush=True,
|
| 132 |
+
)
|
| 133 |
+
return dict(maybe_nodes), dict(maybe_edges)
|
| 134 |
+
|
| 135 |
+
results = await asyncio.gather(
|
| 136 |
+
*[_process_single_content(c) for c in ordered_chunks]
|
| 137 |
+
)
|
| 138 |
+
print()
|
| 139 |
+
maybe_nodes = defaultdict(list)
|
| 140 |
+
maybe_edges = defaultdict(list)
|
| 141 |
+
for m_nodes, m_edges in results:
|
| 142 |
+
for k, v in m_nodes.items():
|
| 143 |
+
maybe_nodes[k].extend(v)
|
| 144 |
+
for k, v in m_edges.items():
|
| 145 |
+
maybe_edges[k].extend(v)
|
| 146 |
+
all_entities_data = await asyncio.gather(
|
| 147 |
+
*[
|
| 148 |
+
_merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
|
| 149 |
+
for k, v in maybe_nodes.items()
|
| 150 |
+
]
|
| 151 |
+
)
|
| 152 |
+
await asyncio.gather(
|
| 153 |
+
*[
|
| 154 |
+
_merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
|
| 155 |
+
for k, v in maybe_edges.items()
|
| 156 |
+
]
|
| 157 |
+
)
|
| 158 |
+
if not len(all_entities_data):
|
| 159 |
+
logger.warning("Didn't extract any entities, maybe your LLM is not working")
|
| 160 |
+
return None
|
| 161 |
+
if entity_vdb is not None:
|
| 162 |
+
data_for_vdb = {
|
| 163 |
+
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
|
| 164 |
+
"content": dp["entity_name"] + dp["description"],
|
| 165 |
+
"entity_name": dp["entity_name"],
|
| 166 |
+
}
|
| 167 |
+
for dp in all_entities_data
|
| 168 |
+
}
|
| 169 |
+
await entity_vdb.upsert(data_for_vdb)
|
| 170 |
+
|
| 171 |
+
return knwoledge_graph_inst
|
nano-graphrag/nano_graphrag/entity_extraction/metric.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dspy
|
| 2 |
+
from nano_graphrag.entity_extraction.module import Relationship
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class AssessRelationships(dspy.Signature):
|
| 6 |
+
"""
|
| 7 |
+
Assess the similarity between gold and predicted relationships:
|
| 8 |
+
1. Match relationships based on src_id and tgt_id pairs, allowing for slight variations in entity names.
|
| 9 |
+
2. For matched pairs, compare:
|
| 10 |
+
a) Description similarity (semantic meaning)
|
| 11 |
+
b) Weight similarity
|
| 12 |
+
c) Order similarity
|
| 13 |
+
3. Consider unmatched relationships as penalties.
|
| 14 |
+
4. Aggregate scores, accounting for precision and recall.
|
| 15 |
+
5. Return a final similarity score between 0 (no similarity) and 1 (perfect match).
|
| 16 |
+
|
| 17 |
+
Key considerations:
|
| 18 |
+
- Prioritize matching based on entity pairs over exact string matches.
|
| 19 |
+
- Use semantic similarity for descriptions rather than exact matches.
|
| 20 |
+
- Weight the importance of different aspects (e.g., entity matching, description, weight, order).
|
| 21 |
+
- Balance the impact of matched and unmatched relationships in the final score.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
gold_relationships: list[Relationship] = dspy.InputField(
|
| 25 |
+
desc="The gold-standard relationships to compare against."
|
| 26 |
+
)
|
| 27 |
+
predicted_relationships: list[Relationship] = dspy.InputField(
|
| 28 |
+
desc="The predicted relationships to compare against the gold-standard relationships."
|
| 29 |
+
)
|
| 30 |
+
similarity_score: float = dspy.OutputField(
|
| 31 |
+
desc="Similarity score between 0 and 1, with 1 being the highest similarity."
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def relationships_similarity_metric(
|
| 36 |
+
gold: dspy.Example, pred: dspy.Prediction, trace=None
|
| 37 |
+
) -> float:
|
| 38 |
+
model = dspy.ChainOfThought(AssessRelationships)
|
| 39 |
+
gold_relationships = [Relationship(**item) for item in gold["relationships"]]
|
| 40 |
+
predicted_relationships = [Relationship(**item) for item in pred["relationships"]]
|
| 41 |
+
similarity_score = float(
|
| 42 |
+
model(
|
| 43 |
+
gold_relationships=gold_relationships,
|
| 44 |
+
predicted_relationships=predicted_relationships,
|
| 45 |
+
).similarity_score
|
| 46 |
+
)
|
| 47 |
+
return similarity_score
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def entity_recall_metric(
|
| 51 |
+
gold: dspy.Example, pred: dspy.Prediction, trace=None
|
| 52 |
+
) -> float:
|
| 53 |
+
true_set = set(item["entity_name"] for item in gold["entities"])
|
| 54 |
+
pred_set = set(item["entity_name"] for item in pred["entities"])
|
| 55 |
+
true_positives = len(pred_set.intersection(true_set))
|
| 56 |
+
false_negatives = len(true_set - pred_set)
|
| 57 |
+
recall = (
|
| 58 |
+
true_positives / (true_positives + false_negatives)
|
| 59 |
+
if (true_positives + false_negatives) > 0
|
| 60 |
+
else 0
|
| 61 |
+
)
|
| 62 |
+
return recall
|