ArthurSrz Claude commited on
Commit
70ab3b6
·
1 Parent(s): d341c77

feat: Add complete nano-graphrag source code

Browse files

- Add all nano-graphrag source files to Space
- Remove submodule reference and add as regular files
- This ensures nano-graphrag can be installed locally with -e ./nano-graphrag

🤖 Generated with Claude Code

Co-Authored-By: Claude <noreply@anthropic.com>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. nano-graphrag +0 -1
  2. nano-graphrag/.coveragerc +11 -0
  3. nano-graphrag/.env.example.azure +7 -0
  4. nano-graphrag/.github/workflows/test.yml +58 -0
  5. nano-graphrag/.gitignore +183 -0
  6. nano-graphrag/LICENSE +21 -0
  7. nano-graphrag/MANIFEST.in +1 -0
  8. nano-graphrag/docs/CONTRIBUTING.md +19 -0
  9. nano-graphrag/docs/FAQ.md +41 -0
  10. nano-graphrag/docs/ROADMAP.md +25 -0
  11. nano-graphrag/docs/benchmark-dspy-entity-extraction.md +276 -0
  12. nano-graphrag/docs/benchmark-en.md +150 -0
  13. nano-graphrag/docs/benchmark-zh.md +91 -0
  14. nano-graphrag/docs/use_neo4j_for_graphrag.md +27 -0
  15. nano-graphrag/examples/benchmarks/dspy_entity.py +152 -0
  16. nano-graphrag/examples/benchmarks/eval_naive_graphrag_on_multi_hop.ipynb +432 -0
  17. nano-graphrag/examples/benchmarks/hnsw_vs_nano_vector_storage.py +78 -0
  18. nano-graphrag/examples/benchmarks/md5_vs_xxhash.py +54 -0
  19. nano-graphrag/examples/finetune_entity_relationship_dspy.ipynb +0 -0
  20. nano-graphrag/examples/generate_entity_relationship_dspy.ipynb +0 -0
  21. nano-graphrag/examples/graphml_visualize.py +282 -0
  22. nano-graphrag/examples/no_openai_key_at_all.py +111 -0
  23. nano-graphrag/examples/using_amazon_bedrock.py +19 -0
  24. nano-graphrag/examples/using_custom_chunking_method.py +43 -0
  25. nano-graphrag/examples/using_deepseek_api_as_llm+glm_api_as_embedding.py +136 -0
  26. nano-graphrag/examples/using_deepseek_as_llm.py +98 -0
  27. nano-graphrag/examples/using_dspy_entity_extraction.py +144 -0
  28. nano-graphrag/examples/using_faiss_as_vextorDB.py +97 -0
  29. nano-graphrag/examples/using_hnsw_as_vectorDB.py +129 -0
  30. nano-graphrag/examples/using_llm_api_as_llm+ollama_embedding.py +122 -0
  31. nano-graphrag/examples/using_local_embedding_model.py +38 -0
  32. nano-graphrag/examples/using_milvus_as_vectorDB.py +94 -0
  33. nano-graphrag/examples/using_ollama_as_llm.py +96 -0
  34. nano-graphrag/examples/using_ollama_as_llm_and_embedding.py +120 -0
  35. nano-graphrag/examples/using_qdrant_as_vectorDB.py +113 -0
  36. nano-graphrag/nano_graphrag/__init__.py +7 -0
  37. nano-graphrag/nano_graphrag/_llm.py +294 -0
  38. nano-graphrag/nano_graphrag/_op.py +1140 -0
  39. nano-graphrag/nano_graphrag/_splitter.py +94 -0
  40. nano-graphrag/nano_graphrag/_storage/__init__.py +5 -0
  41. nano-graphrag/nano_graphrag/_storage/gdb_neo4j.py +529 -0
  42. nano-graphrag/nano_graphrag/_storage/gdb_networkx.py +268 -0
  43. nano-graphrag/nano_graphrag/_storage/kv_json.py +46 -0
  44. nano-graphrag/nano_graphrag/_storage/vdb_hnswlib.py +141 -0
  45. nano-graphrag/nano_graphrag/_storage/vdb_nanovectordb.py +68 -0
  46. nano-graphrag/nano_graphrag/_utils.py +305 -0
  47. nano-graphrag/nano_graphrag/base.py +186 -0
  48. nano-graphrag/nano_graphrag/entity_extraction/__init__.py +0 -0
  49. nano-graphrag/nano_graphrag/entity_extraction/extract.py +171 -0
  50. nano-graphrag/nano_graphrag/entity_extraction/metric.py +62 -0
nano-graphrag DELETED
@@ -1 +0,0 @@
1
- Subproject commit 01f429e8c562e8f19b2449f90cec9a4a67d4f6ee
 
 
nano-graphrag/.coveragerc ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [report]
2
+ exclude_lines =
3
+ # Have to re-enable the standard pragma
4
+ pragma: no cover
5
+
6
+ # Don't complain if tests don't hit defensive assertion code:
7
+ raise NotImplementedError
8
+ logger.
9
+ omit =
10
+ # Don't have a nice github action for neo4j now, so skip this file:
11
+ nano_graphrag/_storage/gdb_neo4j.py
nano-graphrag/.env.example.azure ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ API_KEY_EMB="<your azure openai key for embedding>"
2
+ AZURE_ENDPOINT_EMB="<your azure openai endpoint for embedding>"
3
+ API_VERSION_EMB="<api version>"
4
+
5
+ AZURE_OPENAI_API_KEY="<your azure openai key for embedding>"
6
+ AZURE_OPENAI_ENDPOINT="<AZURE_OPENAI_ENDPOINT>"
7
+ OPENAI_API_VERSION="<OPENAI_API_VERSION>"
nano-graphrag/.github/workflows/test.yml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: test
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ - dev
8
+ paths-ignore:
9
+ - '**/*.md'
10
+ - '**/*.ipynb'
11
+ - 'examples/**'
12
+ pull_request:
13
+ branches:
14
+ - main
15
+ - dev
16
+ paths-ignore:
17
+ - '**/*.md'
18
+ - '**/*.ipynb'
19
+ - 'examples/**'
20
+
21
+ jobs:
22
+ test:
23
+ name: Tests on ${{ matrix.os }} for ${{ matrix.python-version }}
24
+ strategy:
25
+ matrix:
26
+ python-version: [3.9]
27
+ os: [ubuntu-latest]
28
+ runs-on: ${{ matrix.os }}
29
+ timeout-minutes: 10
30
+ steps:
31
+ - uses: actions/checkout@v4
32
+ - name: Set up Python ${{ matrix.python-version }}
33
+ uses: actions/setup-python@v3
34
+ with:
35
+ python-version: ${{ matrix.python-version }}
36
+ - name: Install dependencies
37
+ run: |
38
+ python -m pip install --upgrade pip
39
+ pip install -r requirements.txt
40
+ pip install -r requirements-dev.txt
41
+ - name: Lint with flake8
42
+ run: |
43
+ flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
44
+ - name: Build and Test
45
+ env:
46
+ NANO_GRAPHRAG_TEST_IGNORE_NEO4J: true
47
+ run: |
48
+ python -m pytest -o log_cli=true -o log_cli_level="INFO" --cov=nano_graphrag --cov-report=xml -v ./
49
+ - name: Check codecov file
50
+ id: check_files
51
+ uses: andstor/file-existence-action@v1
52
+ with:
53
+ files: './coverage.xml'
54
+ - name: Upload coverage from test to Codecov
55
+ uses: codecov/codecov-action@v2
56
+ with:
57
+ file: ./coverage.xml
58
+ token: ${{ secrets.CODECOV_TOKEN }}
nano-graphrag/.gitignore ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by https://www.toptal.com/developers/gitignore/api/python
2
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
3
+ test_cache.json
4
+ run_test*.py
5
+ nano_graphrag_cache*/
6
+ *.txt
7
+ examples/benchmarks/fixtures/
8
+ tests/original_workflow.txt
9
+ ### Python ###
10
+ # Byte-compiled / optimized / DLL files
11
+ __pycache__/
12
+ *.py[cod]
13
+ *$py.class
14
+ .vscode
15
+ .DS_Store
16
+ # C extensions
17
+ *.so
18
+
19
+ # Distribution / packaging
20
+ .Python
21
+ build/
22
+ develop-eggs/
23
+ dist/
24
+ downloads/
25
+ eggs/
26
+ .eggs/
27
+ lib/
28
+ lib64/
29
+ parts/
30
+ sdist/
31
+ var/
32
+ wheels/
33
+ share/python-wheels/
34
+ *.egg-info/
35
+ .installed.cfg
36
+ *.egg
37
+ MANIFEST
38
+
39
+ # PyInstaller
40
+ # Usually these files are written by a python script from a template
41
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
42
+ *.manifest
43
+ *.spec
44
+
45
+ # Installer logs
46
+ pip-log.txt
47
+ pip-delete-this-directory.txt
48
+
49
+ # Unit test / coverage reports
50
+ htmlcov/
51
+ .tox/
52
+ .nox/
53
+ .coverage
54
+ .coverage.*
55
+ .cache
56
+ nosetests.xml
57
+ coverage.xml
58
+ *.cover
59
+ *.py,cover
60
+ .hypothesis/
61
+ .pytest_cache/
62
+ cover/
63
+
64
+ # Translations
65
+ *.mo
66
+ *.pot
67
+
68
+ # Django stuff:
69
+ *.log
70
+ local_settings.py
71
+ db.sqlite3
72
+ db.sqlite3-journal
73
+
74
+ # Flask stuff:
75
+ instance/
76
+ .webassets-cache
77
+
78
+ # Scrapy stuff:
79
+ .scrapy
80
+
81
+ # Sphinx documentation
82
+ docs/_build/
83
+
84
+ # PyBuilder
85
+ .pybuilder/
86
+ target/
87
+
88
+ # Jupyter Notebook
89
+ .ipynb_checkpoints
90
+
91
+ # IPython
92
+ profile_default/
93
+ ipython_config.py
94
+
95
+ # pyenv
96
+ # For a library or package, you might want to ignore these files since the code is
97
+ # intended to run in multiple environments; otherwise, check them in:
98
+ # .python-version
99
+
100
+ # pipenv
101
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
103
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
104
+ # install all needed dependencies.
105
+ #Pipfile.lock
106
+
107
+ # poetry
108
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
109
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
110
+ # commonly ignored for libraries.
111
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
112
+ #poetry.lock
113
+
114
+ # pdm
115
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
116
+ #pdm.lock
117
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
118
+ # in version control.
119
+ # https://pdm.fming.dev/#use-with-ide
120
+ .pdm.toml
121
+
122
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
123
+ __pypackages__/
124
+
125
+ # Celery stuff
126
+ celerybeat-schedule
127
+ celerybeat.pid
128
+
129
+ # SageMath parsed files
130
+ *.sage.py
131
+
132
+ # Environments
133
+ .env
134
+ .venv
135
+ env/
136
+ venv/
137
+ ENV/
138
+ env.bak/
139
+ venv.bak/
140
+
141
+ # Spyder project settings
142
+ .spyderproject
143
+ .spyproject
144
+
145
+ # Rope project settings
146
+ .ropeproject
147
+
148
+ # mkdocs documentation
149
+ /site
150
+
151
+ # mypy
152
+ .mypy_cache/
153
+ .dmypy.json
154
+ dmypy.json
155
+
156
+ # Pyre type checker
157
+ .pyre/
158
+
159
+ # pytype static type analyzer
160
+ .pytype/
161
+
162
+ # Cython debug symbols
163
+ cython_debug/
164
+
165
+ # PyCharm
166
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
167
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
168
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
169
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
170
+ #.idea/
171
+
172
+ ### Python Patch ###
173
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
174
+ poetry.toml
175
+
176
+ # ruff
177
+ .ruff_cache/
178
+
179
+ # LSP config files
180
+ pyrightconfig.json
181
+
182
+ # End of https://www.toptal.com/developers/gitignore/api/python
183
+
nano-graphrag/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Gustavo Ye
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
nano-graphrag/MANIFEST.in ADDED
@@ -0,0 +1 @@
 
 
1
+ include readme.md
nano-graphrag/docs/CONTRIBUTING.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to nano-graphrag
2
+
3
+ ### Submit your Contribution through PR
4
+
5
+ To make a contribution, follow these steps:
6
+
7
+ 1. Fork and clone this repository
8
+ 2. If you modified the core code (`./nano_graphrag`), please add tests for it
9
+ 3. **Include proper documentation / docstring or examples**
10
+ 4. Ensure that all tests pass by running `pytest`
11
+ 5. Submit a pull request
12
+
13
+ For more details about pull requests, please read [GitHub's guides](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request).
14
+
15
+
16
+
17
+ ### Only add a dependency when we have to
18
+
19
+ `nano-graphrag` needs to be `nano` and `light`. If we want to add more features, we add them smartly. Don't introduce a huge dependency just for a simple function.
nano-graphrag/docs/FAQ.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### `Leiden.EmptyNetworkError:EmptyNetworkError`
2
+
3
+ This error occurs when `nano-graphrag` tries to compute communities on an empty network. In most cases, this is caused by the LLM model you're using: it fails to extract any entities or relations, so the graph is empty.
4
+
5
+ Try to use another bigger LLM, or here are some ideas to fix it:
6
+
7
+ - Check the response from the LLM, make sure the result fits the desired response format of the extracting entities prompt.
8
+
9
+ The desired response format is something like that:
10
+
11
+ ```text
12
+ ("entity"<|>"Cruz"<|>"person"<|>"Cruz is associated with a vision of control and order, influencing the dynamics among other characters.")
13
+ ```
14
+
15
+ - Some LLMs may not return the format like above, so one possible solution is to add a system instruction to the input prompt, such like:
16
+ ```json
17
+ {
18
+ "role": "system",
19
+ "content": "You are an intelligent assistant and will follow the instructions given to you to fulfill the goal. The answer should be in the format as in the given example."
20
+ }
21
+ ```
22
+ You can use this system_prompt as the default for your LLM calling function.
23
+
24
+
25
+ ### One possible reason of 'Processed 42 chunks,0 entities(duplicated),0 relations(duplicated)WARNING:nano-graphrag:Didn't extract any entities, maybe your LLM is not working WARNING:nano-graphrag:No new entities found'
26
+
27
+ The default num_ctx of ollama is 2048 which is too small for the input prompt of entity extraction. This causes the model to fail to respond correctly.
28
+
29
+ Solution:
30
+ Each model in Ollama has a configuration file. Here, you need to generate a new configuration file based on the original one, and then use this configuration file to generate a new model.
31
+ For example, for the qwen2 model, run the following command:
32
+
33
+ `ollama show --modelfile qwen2 > Modelfile`
34
+
35
+ Add a new line into this file below the 'FROM':
36
+
37
+ `PARAMETER num_ctx 32000`
38
+
39
+ `ollama create -f Modelfile qwen2:ctx32k`
40
+
41
+ Afterwards, you can use qwen2:ctx32k to replace qwen2.
nano-graphrag/docs/ROADMAP.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Next Version
2
+
3
+ - [ ] Add DSPy for prompt-tuning so that small models (Qwen2 7B, Llama 3.1 8B, ...) can extract entities. @NumberChiffre @gusye1234
4
+ - [ ] Optimize Algorithm: add `global_local` query method, globally rewrite query then perform local search.
5
+
6
+
7
+
8
+ ## In next few versions
9
+
10
+ - [ ] Add rate limiter: support token limit (tokens per second, per minute)
11
+
12
+ - [ ] Add other advanced RAG algorithms, candidates:
13
+
14
+ - [ ] [HybridRAG](https://arxiv.org/abs/2408.04948)
15
+ - [ ] [HippoRAG](https://arxiv.org/abs/2405.14831)
16
+
17
+
18
+
19
+
20
+
21
+
22
+ ## Interesting directions
23
+
24
+ - [ ] Add [Sciphi Triplex](https://huggingface.co/SciPhi/Triplex) as the entity extraction model.
25
+ - [ ] Add new components, see [issue](https://github.com/gusye1234/nano-graphrag/issues/2)
nano-graphrag/docs/benchmark-dspy-entity-extraction.md ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Chain Of Thought Prompting with DSPy-AI (v2.4.16)
2
+ ## Main Takeaways
3
+ - Time difference: 156.99 seconds
4
+ - Execution time with DSPy-AI: 304.38 seconds
5
+ - Execution time without DSPy-AI: 147.39 seconds
6
+ - Entities extracted: 22 (without DSPy-AI) vs 37 (with DSPy-AI)
7
+ - Relationships extracted: 21 (without DSPy-AI) vs 36 (with DSPy-AI)
8
+
9
+
10
+ ## Results
11
+ ```markdown
12
+ > python examples/benchmarks/dspy_entity.py
13
+
14
+ Running benchmark with DSPy-AI:
15
+ INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
16
+ INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
17
+ INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
18
+ DEBUG:nano-graphrag:Entities: 14 | Missed Entities: 23 | Total Entities: 37
19
+ DEBUG:nano-graphrag:Relationships: 13 | Missed Relationships: 23 | Total Relationships: 36
20
+ DEBUG:nano-graphrag:Direct Relationships: 31 | Second-order: 5 | Third-order: 0 | Total Relationships: 36
21
+ ⠙ Processed 1 chunks, 37 entities(duplicated), 36 relations(duplicated)
22
+ Execution time with DSPy-AI: 304.38 seconds
23
+
24
+ Entities:
25
+ - 朱元璋 (PERSON):
26
+ 明朝开国皇帝,原名朱重八,后改名朱元璋。他出身贫农,经历了从放牛娃到皇帝的传奇人生。在元朝末年,他参加了红巾军起义,最终推翻元朝,建立了明朝。
27
+ - 朱五四 (PERSON):
28
+ 朱元璋的父亲,农民出身,家境贫寒。他在朱元璋幼年时去世,对朱元璋的成长和人生选择产生了深远影响。
29
+ - 陈氏 (PERSON):
30
+ 朱元璋的母亲,农民出身,家境贫寒。她在朱元璋幼年时去世,对朱元璋的成长和人生选择产生了深远影响。
31
+ - 汤和 (PERSON):
32
+ 朱元璋的幼年朋友,后来成为朱元璋起义军中的重要将领。他在朱元璋早期的发展中起到了关键作用。
33
+ - 郭子兴 (PERSON):
34
+ 红巾军起义的领导人之一,朱元璋的岳父。他在朱元璋早期的发展中起到了重要作用,但后来与朱元璋产生了矛盾。
35
+ - 马姑娘 (PERSON):
36
+ 郭子兴的义女,朱元璋的妻子。她在朱元璋最困难的时候给予了极大的支持,是朱元璋成功的重要因素之一。
37
+ - 元朝 (ORGANIZATION):
38
+ 中国历史上的一个朝代,由蒙古族建立。元朝末年,社会矛盾激化,最终导致了红巾军起义和明朝的建立。
39
+ - 红巾军 (ORGANIZATION):
40
+ 元朝末年起义军的一支,主要由农民组成。朱元璋最初加入的就是红巾军,并在其中逐渐崭露头角。
41
+ - 皇觉寺 (LOCATION):
42
+ 朱元璋早年出家的地方,位于安徽凤阳。他在寺庙中度过了几年的时光,这段经历对他的人生观和价值观产生了深远影响。
43
+ - 濠州 (LOCATION):
44
+ 朱元璋早期活动的重要地点,也是红巾军的重要据点之一。朱元璋在这里经历了许多重要事件,包括与郭子兴的矛盾和最终的离开。
45
+ - 1328年 (DATE):
46
+ 朱元璋出生的年份。这一年标志着明朝开国皇帝传奇人生的开始。
47
+ - 1344年 (DATE):
48
+ 朱元璋家庭遭遇重大变故的年份,他的父母在这一年相继去世。这一事件对朱元璋的人生选择产生了深远影响。
49
+ - 1352年 (DATE):
50
+ 朱元璋正式加入红巾军起义的年份。这一年标志着朱元璋从农民到起义军领袖的转变。
51
+ - 1368年 (DATE):
52
+ 朱元璋推翻元朝,建立明朝的年份。这一年标志着朱元璋从起义军领袖到皇帝的转变。
53
+ - 朱百六 (PERSON):
54
+ 朱元璋的高祖,名字具有元朝时期老百姓命名的特点,即以数字命名。
55
+ - 朱四九 (PERSON):
56
+ 朱元璋的曾祖,名字同样具有元朝时期老百姓命名的特点,即以数字命名。
57
+ - 朱初一 (PERSON):
58
+ 朱元璋的祖父,名字具有元朝时期老百姓命名的特点,即以数字命名。
59
+ - 刘德 (PERSON):
60
+ 朱元璋早年为其放牛的地主,对朱元璋的童年生活有重要影响。
61
+ - 韩山童 (PERSON):
62
+ 红巾军起义的早期领导人之一,与刘福通共同起义,对朱元璋的起义选择有间接影响。
63
+ - 刘福通 (PERSON):
64
+ 红巾军起义的早期领导人之一,与韩山童共同起义,对朱元璋的起义选择有间接影响。
65
+ - 脱脱 (PERSON):
66
+ 元朝末年的著名宰相,主张治理黄河,但他的政策间接导致了红巾军起义的爆发。
67
+ - 元顺帝 (PERSON):
68
+ 元朝末代皇帝,他在位期间元朝社会矛盾激化,最终导致了红巾军起义和明朝的建立。
69
+ - 孙德崖 (PERSON):
70
+ 红巾军起义的领导人之一,与郭子兴有矛盾,曾绑架郭子兴,对朱元璋的早期发展有重要影响。
71
+ - 周德兴 (PERSON):
72
+ 朱元璋的早期朋友,曾为朱元璋算卦,对朱元璋的人生选择有一定影响。
73
+ - 徐达 (PERSON):
74
+ 朱元璋早期的重要将领,后来成为明朝的开国功臣之一。
75
+ - 明教 (RELIGION):
76
+ 朱元璋在起义过程中接触到的宗教信仰,对他的思想和行动有一定影响。
77
+ - 弥勒佛 (RELIGION):
78
+ 明教中的重要神祇,朱元璋相信弥勒佛会降世,对他的信仰和行动有一定影响。
79
+ - 颖州 (LOCATION):
80
+ 朱元璋早年讨饭的地方,也是红巾军起义的重要地点之一。
81
+ - 定远 (LOCATION):
82
+ 朱元璋早期攻打的地点之一,是他军事生涯的起点。
83
+ - 怀远 (LOCATION):
84
+ 朱元璋早期攻打的地点之一,是他军事生涯的起点。
85
+ - 安奉 (LOCATION):
86
+ 朱元璋早期攻打的地点之一,是他军事生涯的起点。
87
+ - 含山 (LOCATION):
88
+ 朱元璋早期攻打的地点之一,是他军事生涯的起点。
89
+ - 虹县 (LOCATION):
90
+ 朱元璋早期攻打的地点之一,是他军事生涯的起点。
91
+ - 钟离 (LOCATION):
92
+ 朱元璋的家乡,他在此地召集了二十四位重要将领。
93
+ - 黄河 (LOCATION):
94
+ 元朝末年黄河泛滥,导致了严重的社会问题,间接引发了红巾军起义。
95
+ - 淮河 (LOCATION):
96
+ 元朝末年淮河沿岸遭遇严重瘟疫和旱灾,加剧了社会矛盾。
97
+ - 1351年 (DATE):
98
+ 红巾军起义爆发的年份,对朱元璋的人生选择产生了重要影响。
99
+
100
+ Relationships:
101
+ - 朱元璋 -> 朱五四:
102
+ 朱元璋是朱五四的儿子,朱五四的去世对朱元璋的成长和人生选择产生了深远影响。
103
+ - 朱元璋 -> 陈氏:
104
+ 朱元璋是陈氏的儿子,陈氏的去世对朱元璋的成长和人生选择产生了深远影响。
105
+ - 朱元璋 -> 汤和:
106
+ 汤和是朱元璋的幼年朋友,后来成为朱元璋起义军中的重要将领,对朱元璋早期的发展起到了关键作用。
107
+ - 朱元璋 -> 郭子兴:
108
+ 郭子兴是朱元璋的岳父,也是红巾军起义的领导人之一。他在朱元璋早期的发展中起到了重要作用,但后来与朱元璋产生了矛盾。
109
+ - 朱元璋 -> 马姑娘:
110
+ 马姑娘是朱元璋的妻子,她在朱元璋最困难的时候给予了极大的支持,是朱元璋成功的重要因素之一。
111
+ - 朱元璋 -> 元朝:
112
+ 朱元璋在元朝末年参加了红巾军起义,最终推翻了元朝,建立了明朝。
113
+ - 朱元璋 -> 红巾军:
114
+ 朱元璋最初加入的是红巾军,并在其中逐渐崭露头角,最终成为起义军的重要领导人。
115
+ - 朱元璋 -> 皇觉寺:
116
+ 朱元璋早年出家的地方是皇觉寺,这段经历对他的人生观和价值观产生了深远影响。
117
+ - 朱元璋 -> 濠州:
118
+ 濠州是朱元璋早期活动的重要地点,也是红巾军的重要据点之一。朱元璋在这里经历了许多重要事件,包括与郭子兴的矛盾和最终的离开。
119
+ - 朱元璋 -> 1328年:
120
+ 1328年是朱元璋出生的年份,这一年标志着明朝开国皇帝传奇人生的开始。
121
+ - 朱元璋 -> 1344年:
122
+ 1344年是朱元璋家庭遭遇重大变故的年份,他的父母在这一年相继去世,这一事件对朱元璋的人生选择产生了深远影响。
123
+ - 朱元璋 -> 1352年:
124
+ 1352年是朱元璋正式加入红巾军起义的年份,这一年标志着朱元璋从农民到起义军领袖的转变。
125
+ - 朱元璋 -> 1368年:
126
+ 1368年是朱元璋推翻元朝,建立明朝的年份,这一年标志着朱元璋从起义军领袖到皇帝的转变。
127
+ - 朱元璋 -> 朱百六:
128
+ 朱百六是朱元璋的高祖,对朱元璋的家族背景有重要影响。
129
+ - 朱元璋 -> 朱四九:
130
+ 朱四九是朱元璋的曾祖,对朱元璋的家族背景有重要影响。
131
+ - 朱元璋 -> 朱初一:
132
+ 朱初一是朱元璋的祖父,对朱元璋的家族背景有重要影响。
133
+ - 朱元璋 -> 刘德:
134
+ 刘德是朱元璋早年为其放牛的地主,对朱元璋的童年生活有重要影响。
135
+ - 朱元璋 -> 韩山童:
136
+ 韩山童是红巾军起义的早期领导人之一,对朱元璋的起义选择有间接影响。
137
+ - 朱元璋 -> 刘福通:
138
+ 刘福通是红巾军起义的早期领导人之一,对朱元璋的起义选择有间接影响。
139
+ - 朱元璋 -> 脱脱:
140
+ 脱脱是元朝末年的著名宰相,他的政策间接导致了红巾军起义的爆发,对朱元璋的起义选择有间接影响。
141
+ - 朱元璋 -> 元顺帝:
142
+ 元顺帝是元朝末代皇帝,他在位期间社会矛盾激化,最终导致了红巾军起义和明朝的建立,对朱元璋的起义选择有重要影响。
143
+ - 朱元璋 -> 孙德崖:
144
+ 孙德崖是红巾军起义的领导人之一,与郭子兴有矛盾,曾绑架郭子兴,对朱元璋的早期发展有重要影响。
145
+ - 朱元璋 -> 周德兴:
146
+ 周德兴是朱元璋的早期朋友,曾为朱元璋算卦,对朱元璋的人生选择有一定影响。
147
+ - 朱元璋 -> 徐达:
148
+ 徐达是朱元璋早期的重要将领,后来成为明朝的开国功臣之一,对朱元璋的军事生涯有重要影响。
149
+ - 朱元璋 -> 明教:
150
+ 朱元璋在起义过程中接触到的宗教信仰,对他的思想和行动有一定影响。
151
+ - 朱元璋 -> 弥勒佛:
152
+ 朱元璋相信弥勒佛会降世,对他的信仰和行动有一定影响。
153
+ - 朱元璋 -> 颖州:
154
+ 颖州是朱元璋早年讨饭的地方,也是红巾军起义的重要地点之一,对朱元璋的早期生活有重要影响。
155
+ - 朱元璋 -> 定远:
156
+ 定远是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。
157
+ - 朱元璋 -> 怀远:
158
+ 怀远是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。
159
+ - 朱元璋 -> 安奉:
160
+ 安奉是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。
161
+ - 朱元璋 -> 含山:
162
+ 含山是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。
163
+ - 朱元璋 -> 虹县:
164
+ 虹县是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。
165
+ - 朱元璋 -> 钟离:
166
+ 钟离是朱元璋的家乡,他在此地召集了二十四位重要将领,对朱元璋的军事发展有重要影响。
167
+ - 朱元璋 -> 黄河:
168
+ 元朝末年黄河泛滥,导致了严重的社会问题,间接引发了红巾军起义,对朱元璋的起义选择有重要影响。
169
+ - 朱元璋 -> 淮河:
170
+ 元朝末年淮河沿岸遭遇严重瘟疫和旱灾,加剧了社会矛盾,对朱元璋的起义选择有重要影响。
171
+ - 朱元璋 -> 1351年:
172
+ 1351年是红巾军起义爆发的年份,对朱元璋的人生选择产生了重要影响。
173
+ Running benchmark without DSPy-AI:
174
+ INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
175
+ INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
176
+ ⠙ Processed 1 chunks, 22 entities(duplicated), 21 relations(duplicated)
177
+ Execution time without DSPy-AI: 147.39 seconds
178
+
179
+ Entities:
180
+ - "朱元璋" ("PERSON"):
181
+ "朱元璋,原名朱重八,后改名朱元璋,是明朝的开国皇帝。他出身贫农,经历了从放牛娃到和尚,再到起义军领袖,最终成为皇帝的传奇人生。"
182
+ - "朱五四" ("PERSON"):
183
+ "朱五四,朱元璋的父亲,是一个农民,为地主种地,家境贫寒。"
184
+ - "陈氏" ("PERSON"):
185
+ "陈氏,朱元璋的母亲,是一个农民,与丈夫朱五四一起辛勤劳作,家境贫寒。"
186
+ - "汤和" ("PERSON"):
187
+ "汤和,朱元璋的幼年朋友,后来成为朱元璋的战友,在朱元璋的崛起过程中起到了重要作用。"
188
+ - "郭子兴" ("PERSON"):
189
+ "郭子兴,濠州城的守卫者,是朱元璋的岳父,也是朱元璋早期的重要支持者。"
190
+ - "韩山童" ("PERSON"):
191
+ "韩山童,与刘福通一起起义反抗元朝统治,是元末农民起义的重要领袖之一。"<SEP>"韩山童,元末农民起义的领袖之一,自称宋朝皇室后裔,与刘福通一起起义。"
192
+ - "刘福通" ("PERSON"):
193
+ "刘福通,与韩山童一起起义反抗元朝统治,是元末农民起义的重要领袖之一。"<SEP>"刘福通,元末农民起义的领袖之一,自称刘光世大将的后人,与韩山童一起起义。"
194
+ - "元朝" ("ORGANIZATION"):
195
+ "元朝,由蒙古族建立的王朝,统治中国时期实行了严格的等级制度,导致社会矛盾激化,最终被朱元璋领导的起义军推翻。"
196
+ - "皇觉寺" ("ORGANIZATION"):
197
+ "皇觉寺,朱元璋曾经在此当和尚,从事杂役工作,后来因饥荒严重,和尚们都被派出去化缘。"
198
+ - "白莲教" ("ORGANIZATION"):
199
+ "白莲教,元末农民起义中的一种宗教组织,韩山童和刘福通起义时利用了这一宗教信仰。"
200
+ - "濠州城" ("GEO"):
201
+ "濠州城,位于今安徽省,是朱元璋早期活动的重要地点,也是郭子兴的驻地。"
202
+ - "定远" ("GEO"):
203
+ "定远,朱元璋奉命攻击的地方,成功攻克后在元军回援前撤出,显示了其军事才能。"
204
+ - "钟离" ("GEO"):
205
+ "钟离,朱元璋的家乡,他在此招收了二十四名壮丁,这些人后来成为明朝的高级干部。"
206
+ - "元末农民起义" ("EVENT"):
207
+ "元末农民起义,是元朝末年由韩山童、刘福通等人领导的反抗元朝统治的大规模起义,最终导致了元朝的灭亡。"
208
+ - "马姑娘" ("PERSON"):
209
+ "马姑娘,郭子兴的义女,后来成为朱元璋的妻子,在朱元璋被关押时,她冒着危险送饭给朱元璋,表现出深厚的感情。"
210
+ - "孙德崖" ("PERSON"):
211
+ "孙德崖,与郭子兴有矛盾的起义军领袖之一,曾参与绑架郭子兴。"
212
+ - "徐达" ("PERSON"):
213
+ "徐达,朱元璋的二十四名亲信之一,后来成为明朝的重要将领。"
214
+ - "周德兴" ("PERSON"):
215
+ "周德兴,朱元璋的二十四名亲信之一,曾为朱元璋算过命。"
216
+ - "脱脱" ("PERSON"):
217
+ "脱脱,元朝的著名宰相,主张治理黄河,但他的政策间接导致了元朝的灭亡。"
218
+ - "元顺帝" ("PERSON"):
219
+ "元顺帝,元朝的最后一位皇帝,统治时期元朝社会矛盾激化,最终导致了元朝的灭亡。"
220
+ - "刘德" ("PERSON"):
221
+ "刘德,地主,朱元璋早年为其放牛。"
222
+ - "吴老太" ("PERSON"):
223
+ "吴老太,村口的媒人,朱元璋曾希望托她找一个媳妇。"
224
+
225
+ Relationships:
226
+ - "朱元璋" -> "朱五四":
227
+ "朱元璋的父亲,对他的成长和早期生活有重要影响。"
228
+ - "朱元璋" -> "陈氏":
229
+ "朱元璋的母亲,对他的成长和早期生活有重要影响。"
230
+ - "朱元璋" -> "汤和":
231
+ "朱元璋的幼年朋友,后来成为他的战友,在朱元璋的崛起过程中起到了重要作用。"
232
+ - "朱元璋" -> "郭子兴":
233
+ "朱元璋的岳父,是他在起义军中的重要支持者。"
234
+ - "朱元璋" -> "韩山童":
235
+ "朱元璋在起义过程中与韩山童有间接联系,韩山童的起义对朱元璋的崛起有重要影响。"
236
+ - "朱元璋" -> "刘福通":
237
+ "朱元璋在起义过程中与刘福通有间接联系,刘福通的起义对朱元璋的崛起有重要影响。"
238
+ - "朱元璋" -> "元朝":
239
+ "朱元璋最终推翻了元朝的统治,建立了明朝。"
240
+ - "朱元璋" -> "皇觉寺":
241
+ "朱元璋曾经在此当和尚,这段经历对他的成长有重要影响。"
242
+ - "朱元璋" -> "白莲教":
243
+ "朱元璋在起义过程中接触到了白莲教,虽然他本人可能并不信仰,但白莲教的起义对他有重要影响。"
244
+ - "朱元璋" -> "濠州城":
245
+ "朱元璋在濠州城的活动对其早期军事和政治生涯有重要影响。"
246
+ - "朱元璋" -> "定远":
247
+ "朱元璋成功攻克定远,显示了其军事才能。"
248
+ - "朱元璋" -> "钟离":
249
+ "朱元璋的家乡,他在此招收了二十四名壮丁,这些人后来成为明朝的高级干部。"
250
+ - "朱元璋" -> "元末农民起义":
251
+ "朱元璋参与并最终领导了元末农民起义,推翻了元朝的统治。"
252
+ - "朱元璋" -> "马姑娘":
253
+ "朱元璋的妻子,在朱元璋被关押时,她冒着危险送饭给朱元璋,表现出深厚的感情。"
254
+ - "朱元璋" -> "孙德崖":
255
+ "朱元璋在孙德崖与郭子兴的矛盾中起到了调解作用,显示了其政治智慧。"
256
+ - "朱元璋" -> "徐达":
257
+ "朱元璋的二十四名亲信之一,后来成为明朝的重要将领。"
258
+ - "朱元璋" -> "周德兴":
259
+ "朱元璋的二十四名亲信之一,曾为朱元璋算过命。"
260
+ - "朱元璋" -> "脱脱":
261
+ "朱元璋在起义过程中间接受到脱脱政策的影响,脱脱的政策间接导致了元朝的灭亡。"
262
+ - "朱元璋" -> "元顺帝":
263
+ "朱元璋最终推翻了元顺帝的统治,建立了明朝。"
264
+ - "朱元璋" -> "刘德":
265
+ "朱元璋早年为刘德放牛,这段经历对他的成长有重要影响。"
266
+ - "朱元璋" -> "吴老太":
267
+ "朱元璋曾希望托吴老太找一个媳妇,显示了他对家庭的渴望。"
268
+ ```
269
+
270
+ # Self-Refine with DSPy-AI (v2.5.6)
271
+ ## Main Takeaways
272
+ - Time difference: 66.24 seconds
273
+ - Execution time with DSPy-AI: 211.04 seconds
274
+ - Execution time without DSPy-AI: 144.80 seconds
275
+ - Entities extracted: 38 (without DSPy-AI) vs 16 (with DSPy-AI)
276
+ - Relationships extracted: 38 (without DSPy-AI) vs 16 (with DSPy-AI)
nano-graphrag/docs/benchmark-en.md ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - We use [A Christmas Carol](https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_data.txt) by Dickens as the benchmark corpus.
2
+ - We use `61b5eea34783c58074b3c53f1689ad8a5ba6b6ee` commit of [Official GraphRAG implementation](https://github.com/microsoft/graphrag/tree/main)
3
+ - Both GraphRAG and `nano-graphrag` use OpenAI Embedding and `gpt-4o`.
4
+ - No cache for either. Both ran on the same device and network connection.
5
+ - GraphRAG Max Async API request: 25
6
+ - `nano-graphrag` Max Async API request: 16
7
+
8
+ ## Index Benchmark
9
+
10
+ **GraphRAG index time**: more than 5 minutes
11
+
12
+ **`nano-graphrag` index time**: less than 4 minutes
13
+
14
+ ## Local Search Results
15
+
16
+ #### GraphRAG
17
+
18
+ "What are the top themes in this story?"
19
+
20
+ ```markdown
21
+ # Top Themes in the Story
22
+
23
+ The story revolves around several central themes that are intricately woven into the narrative, each contributing to the overall message and character development. Here are the top themes:
24
+
25
+ ## Family and Relationships
26
+
27
+ Family and relationships play a crucial role in the story, particularly through the interactions between Scrooge and his nephew and niece. Scrooge's nephew is depicted as a warm-hearted and persistent character who continually tries to engage Scrooge, especially during the Christmas season. Despite Scrooge's initial disdain for Christmas, his nephew's efforts highlight themes of compassion and familial bonds [Data: Scrooge's Family and Christmas (20); Relationships (149, 152, 154)]. Scrooge's niece also contributes to this theme with her strong opinions and musical talent, which help soften Scrooge's attitude and remind him of his past [Data: Scrooge's Family and Christmas (20); Relationships (154)].
28
+
29
+ ## Transformation and Redemption
30
+
31
+ The theme of transformation and redemption is central to the story, primarily illustrated through Scrooge's character arc. Initially, Scrooge is cynical and dismissive of Christmas and its associated values. However, the recurring event of Christmas serves as a catalyst for his transformation. Through various emotional experiences and reflections prompted by the holiday, Scrooge undergoes a significant change, ultimately embracing generosity and kindness [Data: Scrooge's Family and Christmas (20); Relationships (96, 156, 153, 154)].
32
+
33
+ ## The Spirit of Christmas
34
+
35
+ Christmas itself is a recurring and significant event in the story, symbolizing themes of goodwill, generosity, and reflection. The holiday serves as the backdrop for many of the story's pivotal moments and emotional responses. It is during Christmas that Scrooge's transformation occurs, highlighting the holiday's themes of kindness and community. The discussions and celebrations around Christmas emphasize its importance in fostering a sense of togetherness and compassion [Data: Scrooge's Family and Christmas (20); Relationships (96, 149, 153, 154, 156)].
36
+
37
+ ## Social Commentary
38
+
39
+ The story also provides a social commentary on the issues of poverty and social inequality. Characters like Ignorance and Want symbolize the social issues of the time, revealed by the Ghost of Christmas Present. Scrooge's initial attitude towards the poor and his interactions with the gentlemen seeking charity highlight the harsh realities of social neglect and the need for compassion and support for the less fortunate [Data: Entities (40); Sources (4)].
40
+
41
+ ## The Role of Memory and Reflection
42
+
43
+ Memory and reflection are significant themes that contribute to Scrooge's transformation. The story often delves into Scrooge's past, evoking memories that help him reflect on his life choices and their impact on his current state. These reflections are crucial in prompting Scrooge to change his ways and embrace a more generous and compassionate outlook [Data: Scrooge's Family and Christmas (20); Relationships (96, 156)].
44
+
45
+ ## Conclusion
46
+
47
+ In summary, the top themes in the story include family and relationships, transformation and redemption, the spirit of Christmas, social commentary, and the role of memory and reflection. Each of these themes is intricately connected to the characters and events, contributing to the story's rich narrative and moral lessons.
48
+ ```
49
+
50
+ ### `nano-graphrag`
51
+
52
+ "What are the top themes in this story? "
53
+
54
+ ```markdown
55
+ Sure, let’s explore the top themes in "A Christmas Carol" based on the data provided.
56
+
57
+ ### Transformation and Redemption
58
+
59
+ The central theme of "A Christmas Carol" revolves around the profound transformation and redemption of Ebenezer Scrooge. Initially depicted as a miserly and solitary man who disdains Christmas and human connection, Scrooge undergoes a radical change after a series of supernatural visits. These visitations force him to reflect on his past, recognize the suffering of others in the present, and confront the grim future that awaits him if he remains unchanged [Data: Reports (0); Entities (1); Relationships (0, 3, 5, 14, 45, +more)].
60
+
61
+ ### Influence of Supernatural Entities
62
+
63
+ Supernatural elements are crucial in facilitating Scrooge's transformation. The Ghosts of Christmas Past, Present, and Yet to Come each play a significant role by exposing Scrooge to various scenes, prompting deep introspection and emotional responses. Jacob Marley's ghost, in particular, acts as the initial catalyst, warning Scrooge of the consequences of his miserly ways and setting the stage for the more impactful supernatural encounters that follow [Data: Reports (0); Entities (17, 18, 46); Relationships (29, 30, 46, 55, 114, +more)].
64
+
65
+ ### Social Critique
66
+
67
+ Dickens uses the story to critique the social conditions of his time, emphasizing themes of poverty and social neglect. Characters like Bob Cratchit and his family highlight the struggles of the working class, while the conceptual children, Ignorance and Want, presented by the Ghost of Christmas Present, illustrate the dire consequences of societal indifference. This critique encourages readers to reflect on their own social responsibilities and the broader societal need for compassion and generosity [Data: Reports (0); Entities (57, 107, 141); Relationships (45, 102, 159, +more)].
68
+
69
+ ### Family and Community
70
+
71
+ Another prominent theme is the importance of family and community. Scrooge's interactions with Fred, his cheerful and persistent nephew, highlight the warmth and joy that family connections can bring. Fred's continued efforts to include Scrooge in family celebrations play a pivotal role in rekindling their relationship. Likewise, the Cratchit family, despite their poverty, exemplifies love, togetherness, and the true spirit of Christmas, all of which contribute to Scrooge's awakening and ultimate redemption [Data: Reports (0); Entities (14, 44); Relationships (28, 39, 40, 86, 153, +more)].
72
+
73
+ ### Generosity and Charity
74
+
75
+ The theme of generosity and charity is woven throughout the narrative. Scrooge's eventual acts of kindness, such as his improved treatment of Bob Cratchit and his charitable gestures towards the end of the story, symbolize the transformative power of giving. This newly embraced generosity contrasts sharply with his former frugality and serves as a testament to the positive changes that have taken place within him [Data: Reports (0); Entities (6, 43); Relationships (7, 8, 84, 127, +more)].
76
+
77
+ ### Conclusion
78
+
79
+ In summary, "A Christmas Carol" is rich with themes that explore personal transformation, social responsibility, family, and the spirit of generosity. These themes are intricately linked through the narrative structure and character interactions, making the story a timeless reflection on the human condition and the potential for redemption and change [Data: Reports (0); Entities (1, 17, 18); Relationships (0, 14, 31, 45, 48, +more)].
80
+
81
+ Feel free to ask if you need more detailed information on any specific aspect of these themes!
82
+ ```
83
+
84
+
85
+
86
+ ## Global Search Results
87
+
88
+ ### GraphRAG
89
+
90
+ "What are the top themes in this story? "
91
+
92
+ ```markdown
93
+ #### Transformation
94
+ The central theme of the story is the transformation of Ebenezer Scrooge. Initially depicted as a miserly and unsympathetic individual, Scrooge undergoes a profound change to become a generous and caring person. This transformation is driven by his interactions with various supernatural entities, including Jacob Marley and the Ghosts of Christmas Past, Present, and Yet to Come [Data: Reports (18, 17, 27, 21, 6, 26)].
95
+
96
+ #### Familial Love and Support
97
+ The theme of familial love and support is prominently depicted through the Cratchit Family. Despite their financial struggles, the Cratchit family remains loving and supportive of each other. Key relationships include Bob Cratchit's deep affection for his son Tiny Tim and the family's collective grief over Tiny Tim's potential death [Data: Reports (20, 19)].
98
+
99
+ #### Generosity and Compassion
100
+ Generosity and compassion are illustrated through characters like Fezziwig and the transformed Scrooge. Fezziwig's joyful Christmas parties and Scrooge's eventual acts of kindness towards the Cratchit family and others highlight the positive impact of generosity on the community [Data: Reports (23, 19, 26)].
101
+
102
+ #### Social Responsibility
103
+ The theme of social responsibility is underscored by the allegorical figures of Ignorance and Want, as well as the broader societal issues highlighted by the Ghost of Christmas Present. These elements emphasize the importance of addressing poverty and the consequences of neglecting the less fortunate [Data: Reports (17, 21)].
104
+
105
+ #### The Christmas Spirit
106
+ The Christmas spirit, encompassing joy, celebration, and togetherness, is a recurring motif throughout the story. Various scenes, such as the Cratchit family's Christmas dinner, Fred's Christmas party, and the celebrations shown by the Ghost of Christmas Present, reinforce the importance of embracing the festive spirit and its associated values [Data: Reports (20, 18, 23)].
107
+
108
+ #### Impact of Past Relationships
109
+ Scrooge's past relationships, particularly with 'The Girl' (later 'The Matron') and Belle, highlight themes of lost opportunities and the contrast between his lonely existence and the warmth of family life [Data: Reports (2)].
110
+
111
+ #### Role of Supernatural Elements
112
+ Supernatural elements, including Marley's Ghost and the three spirits, play a crucial role in Scrooge's transformation. These entities guide Scrooge through various significant scenes, influencing his reflections and realizations [Data: Reports (26)].
113
+
114
+ #### Family and Love
115
+ The relationships within the community, such as those between Mr. and Mrs. Fezziwig and Belle's family, underscore the importance of family and love. These relationships serve as a reflection on what Scrooge has missed out on due to his choices [Data: Reports (2, 25)].
116
+
117
+ These themes collectively highlight the story's emphasis on personal growth, the importance of community and family, and the enduring values associated with the Christmas season.
118
+ ```
119
+
120
+
121
+
122
+ ### `nano-graphrag`
123
+
124
+ "What are the top themes in this story? "
125
+
126
+ ```markdown
127
+ ### Key Themes in the Story
128
+
129
+ #### 1. Transformation and Redemption
130
+ The central theme of the story revolves around Ebenezer Scrooge's profound transformation from a miserly, cold-hearted man to a generous and kind individual. This journey of change emphasizes the possibility of redemption and the capacity for personal change. The transformation is predominantly driven by Scrooge's interactions with various supernatural entities such as the Ghosts of Christmas Past, Present, and Yet to Come [Data: Reports (0, 1, 2, 3, 4, 5, 6, 12, 23, 14, +more)].
131
+
132
+ #### 2. The Role of Supernatural Entities
133
+ Supernatural spirits play critical roles in guiding Scrooge toward his transformation. The Ghosts of Christmas Past, Present, and Yet to Come evoke reflection, regret, and fear in Scrooge, highlighting the impact of spiritual guidance on personal redemption. These encounters force him to confront and reflect on his past actions and their consequences [Data: Reports (1, 2, 3, 12, 0)].
134
+
135
+ #### 3. Generosity and Community Support
136
+ Another significant theme is the spirit of generosity and community support. Initially depicted as a miser, Scrooge's eventual acts of generosity and the festive spirit among various characters underscore the importance of kindness and community. The Cratchit family's modest yet joyous Christmas celebration poignantly highlights this theme [Data: Reports (26, 6, 2, 14)].
137
+
138
+ #### 4. Family and Emotional Bonds
139
+ The story places significant emphasis on family and emotional connections. This theme is illustrated through the relationships between Scrooge and his family, such as the close sibling bond with Fan and the dynamics within the Cratchit family. These relationships and their development underscore the importance of familial love and support [Data: Reports (0, 2, 4, 17, 5, 15)].
140
+
141
+ #### 5. Social Justice and Economic Disparity
142
+ The narrative also addresses themes of social justice and economic disparity. Scrooge's initial indifference to others' struggles and the contrasting lifestyles of characters like Bob Cratchit highlight the societal inequalities of the time. This critique is further exemplified through the characters’ interactions and the broader social critique in the story [Data: Reports (0, 1, 24, 9, 25, +more)].
143
+
144
+ #### 6. Role of Memory and the Past
145
+ Closely tied to Scrooge’s transformation is the theme of memory and the past. The Ghosts of Christmas Past and Marley’s Ghost prompt Scrooge to reflect on his own life, evoking memories that significantly impact his journey toward becoming a better man. These reflections play a crucial role in his redemption [Data: Reports (3, 8, 13, 17, 6, +more)].
146
+
147
+ ### Conclusion
148
+ These themes collectively create a rich and multifaceted narrative that conveys powerful messages about personal growth, the impact of community, the value of generosity, and the critique of social inequities. The journey of Ebenezer Scrooge serves as a timeless reminder of the potential for change and the importance of compassion and empathy in society.
149
+ ```
150
+
nano-graphrag/docs/benchmark-zh.md ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - We use [三国演义](https://github.com/tennessine/corpus/blob/master/%E4%B8%89%E5%9B%BD%E6%BC%94%E4%B9%89.txt) by 罗贯中 as the benchmark corpus.
2
+ - We use `61b5eea34783c58074b3c53f1689ad8a5ba6b6ee` commit of [Official GraphRAG implementation](https://github.com/microsoft/graphrag/tree/main)
3
+ - Both GraphRAG and `nano-graphrag` use OpenAI Embedding and `gpt-4o`.
4
+ - Not Cache for both.
5
+ - On the same device and network connection.
6
+ - GraphRAG Max Async API request: 25
7
+ - `nano-graphrag` Max Async API request: 16
8
+
9
+ ## Index Benchmark
10
+
11
+ Both are slow, taking more than 10 minutes.
12
+
13
+ ## Local Search Results
14
+
15
+ ### GraphRAG
16
+
17
+ "大乔和曹操的关系"
18
+
19
+ ```markdown
20
+ 大乔与曹操之间的关系在历史和文学中被描绘为复杂且多维的,尽管直接的互动可能并不频繁。大乔,作为东吴孙策的妻子,以及其美貌和与三国时期重要人物的关联而闻名。曹操,作为三国时期的关键政治和军事领导者,其与大乔的关系主要通过他与东吴的关系以及对大乔及其家族的兴趣来体现。
21
+
22
+ 首先,大乔的美貌在三国时期广为人知,甚至引起了曹操的注意。据记载,曹操曾有意俘获大乔以供己用,这一点体现了曹操对大乔的个人兴趣,尽管这种兴趣可能更多地是基于她的美貌和象征意义,而非两人之间有实质性的个人关系 [Data: Entities (266)]。
23
+
24
+ 其次,大乔作为孙策的妻子,自然成为了孙家与东吴势力中的一员。孙策和孙权是曹操在政治和军事上的对手,这使得大乔与曹操之间存在着间接的敌对关系。曹操对东吴的军事行动,包括对孙权的攻击,间接地影响到了大乔及其家族的安全和利益 [Data: Entities (266)]。
25
+
26
+ 最后,尽管直接的文献记录可能不多,大乔的地位和她所处的时代背景使她与曹操之间的关系充满了历史和文化上的意义。她的故事和曹操的野心相交织,展现了三国时期女性角色与政治军事活动之间的复杂联系。
27
+
28
+ 综上所述,大乔与曹操之间的关系主要是通过她的家庭背景、曹操的政治军事活动,以及文化记忆中对她美貌的追捧来定义的。这种关系更多地体现在历史和文学的叙述中,而不是基于两人之间有直接的个人互动。
29
+ ```
30
+
31
+ ### `nano-graphrag`
32
+
33
+ "大乔和曹操的关系"
34
+
35
+ ```markdown
36
+ # 大乔和曹操的关系
37
+ ## 大乔的背景
38
+ 大乔(Da Qiao)是江东乔公的长女,以其极美的容貌闻名于世。她最为著名的身份是孙策(Sun Ce)的妻子,这一身份在三国时期的政治和军事格局中也占有重要地位。周瑜特别提到她,显示她在周瑜与曹操局势中的重要性 [Data: Entities (1)]。
39
+ ## 曹操的欲望
40
+ 曹操(Cao Cao)对大乔的渴望主要源自其好色本性。他希望得到大乔和她的妹妹小乔,以满足个人欲望,并且认为获得她们可以帮助他提高政治地位和享乐。曹操曾发誓希望能得到江东二乔(大乔和小乔),并将她们置于铜雀台上,以乐晚年 [Data: Relationships (6, 7); Entities (1, 2)]。
41
+ ## 谋略与政治
42
+ 在三国时期的权谋斗争中,诸葛亮曾提议使用美人计,将大乔和小乔送给曹操,以此来使曹操班师回朝,减少对东吴的威胁。这一提议显示出大乔在权谋中的重要角色以及她可能对大局产生的影响 [Data: Sources (0); Relationships (12)]。
43
+ ## 曹操的策略和行动
44
+ 尽管曹操对大乔有很大的兴趣,但实际行动中的证据有限。周瑜得知曹操对大乔的意图后,表达了强烈的反对,并与诸葛亮商议对策,以确保她和小乔的安全 [Data: Sources (2)]。周瑜明确表示大乔是孙策的正妻,而小乔是他的妻子,拒绝将她们送给曹操。
45
+ ## 总结
46
+ 总体而言,曹操和大乔没有直接的互动记录,但曹操对大乔的欲望及其可能带来的政治影响却是不容忽视的。大乔作为江东乔公的长女,其美貌和身份使她成为多方势力争夺的对象,但她与曹操的关系主要体现在曹操的欲望和诸葛亮、周瑜等人的权谋计划中。
47
+ 曹操对大乔的欲望和权谋体现了三国时期复杂的政治动态,同时也揭示了美人计在当时战略中的重要性 [Data: Relationships (6, 12, 13); Entities (1); Sources (0, 2)].
48
+ ```
49
+
50
+
51
+
52
+ ## Global Search Results
53
+
54
+ ### GraphRAG
55
+
56
+ "大乔与曹操的关系"
57
+
58
+ ```markdown
59
+ ### 大乔与曹操的关系
60
+
61
+ 大乔(Da Qiao)是乔公的女儿,以其非凡的美貌而闻名。这一特质使她在社区中成为一个重要人物,尤其是在曹操(Cao Cao)对她的渴望背景下 [Data: Reports (41)]。
62
+
63
+ #### 曹操的计划
64
+
65
+ 曹操计划在统一江南后,将大乔安置在铜雀台(Tongque Tai)。铜雀台是曹操在叶县张河上建造的一座宏伟平台,象征着好运和庆祝。曹操打算将包括大乔在内的美丽女子安置在铜雀台中 [Data: Reports (269)]。这一计划不仅反映了曹操的野心和愿景,也显示了他对大乔的特别关注。
66
+
67
+ #### 乔公与曹操的关系
68
+
69
+ 乔公是大乔的父亲,同时也是曹操的朋友。这层关系不仅突显了乔公的影响力,还暗示了由于曹操对大乔的渴望,可能会产生潜在的杠杆作用或冲突 [Data: Reports (41)]。
70
+
71
+ ### 结论
72
+
73
+ 综上所述,大乔与曹操的关系主要体现在曹操对她的渴望以及他计划将她安置在铜雀台的意图上。此外,乔公与曹操的友谊也为这一关系增添了复杂性,可能会在未来引发更多的动态变化 [Data: Reports (41, 269)]。
74
+ ```
75
+
76
+
77
+
78
+ ### `nano-graphrag`
79
+
80
+ "大乔与曹操的关系"
81
+
82
+ ```markdown
83
+ ## 曹操与大乔的关系
84
+
85
+ 曹操对大乔展现了显著的政治和战略兴趣,这部分是由于大乔的美貌以及她在战略上的重要性。在那个时期,美丽并非单纯的外貌特征,还往往与政治资源和战略权力相关联。曹操认为拥有大乔会对他的政治野心有所帮助,并为其权势扩大增添筹码 [Data: Reports (6)]。
86
+
87
+ 大乔作为孙策的妻子,再加上她的地位和美貌,使得她在权力动态中成为一个关键的目标。曹操对大乔的重视,不仅因为她个人的吸引力,更因为她所代表的权力象征。如果能将大乔纳入自己的势力范围,曹操将可能通过这种方式削弱敌对势力,同时增强自己的影响力 [Data: Reports (6)]。
88
+
89
+ 总之,曹操和大乔的关系不仅仅停留在个人层面,更深层次的是政治策略和权力博弈的体现。曹操对她的兴趣,反映了他在权力斗争中的精明与野心。
90
+ ```
91
+
nano-graphrag/docs/use_neo4j_for_graphrag.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 1. Install [Neo4j](https://neo4j.com/docs/operations-manual/current/installation/) (please use 5.x version)
2
+ 2. Install Neo4j GDS (graph data science) [plugin](https://neo4j.com/docs/graph-data-science/current/installation/neo4j-server/)
3
+ 3. Start neo4j server
4
+ 4. Get the `NEO4J_URL`, `NEO4J_USER` and `NEO4J_PASSWORD`
5
+ - By default, `NEO4J_URL` is `neo4j://localhost:7687` , `NEO4J_USER` is `neo4j` and `NEO4J_PASSWORD` is `neo4j`
6
+
7
+ Pass your neo4j instance to `GraphRAG`:
8
+
9
+ ```python
10
+ from nano_graphrag import GraphRAG
11
+ from nano_graphrag._storage import Neo4jStorage
12
+
13
+ neo4j_config = {
14
+ "neo4j_url": os.environ.get("NEO4J_URL", "neo4j://localhost:7687"),
15
+ "neo4j_auth": (
16
+ os.environ.get("NEO4J_USER", "neo4j"),
17
+ os.environ.get("NEO4J_PASSWORD", "neo4j"),
18
+ )
19
+ }
20
+ GraphRAG(
21
+ graph_storage_cls=Neo4jStorage,
22
+ addon_params=neo4j_config,
23
+ )
24
+ ```
25
+
26
+
27
+
nano-graphrag/examples/benchmarks/dspy_entity.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dspy
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from openai import AsyncOpenAI
5
+ import logging
6
+ import asyncio
7
+ import time
8
+ import shutil
9
+ from nano_graphrag.entity_extraction.extract import extract_entities_dspy
10
+ from nano_graphrag.base import BaseKVStorage
11
+ from nano_graphrag._storage import NetworkXStorage
12
+ from nano_graphrag._utils import compute_mdhash_id, compute_args_hash
13
+ from nano_graphrag._op import extract_entities
14
+
15
# Cache directory for benchmark artifacts (one sub-directory per use_dspy setting).
WORKING_DIR = "./nano_graphrag_cache_dspy_entity"

# Load environment variables (e.g. DEEPSEEK_API_KEY, DEEPSEEK_BASE_URL) from a local .env file.
load_dotenv()

# Enable verbose library logging so extraction progress is visible during the benchmark.
logger = logging.getLogger("nano-graphrag")
logger.setLevel(logging.DEBUG)
21
+
22
+
23
async def deepseepk_model_if_cache(
    prompt: str,
    model: str = "deepseek-chat",
    system_prompt: str = None,
    history_messages: list = None,
    **kwargs,
) -> str:
    """Query the DeepSeek chat API, consulting an optional KV cache first.

    Args:
        prompt: The user prompt for this turn.
        model: DeepSeek model name.
        system_prompt: Optional system message prepended to the conversation.
        history_messages: Prior conversation turns (list of message dicts).
        **kwargs: Extra options forwarded to the OpenAI client; the special
            key ``hashing_kv`` (a ``BaseKVStorage``) is popped off and used
            as a response cache.

    Returns:
        The assistant's reply text.
    """
    # Fix: the original used a mutable default (``history_messages=[]``),
    # which is shared across calls; use a None sentinel instead.
    history_messages = history_messages or []

    openai_async_client = AsyncOpenAI(
        api_key=os.environ.get("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com"
    )
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Return the cached response if this exact request was seen before.
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )

    # Store the fresh response so identical future requests hit the cache.
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": model}}
        )
    return response.choices[0].message.content
55
+
56
+
57
async def benchmark_entity_extraction(text: str, system_prompt: str, use_dspy: bool = False):
    """Time a single entity-extraction run over *text*.

    Uses the DSPy-based extractor when *use_dspy* is True, otherwise the
    prompt-based extractor. Returns ``(graph_storage, elapsed_seconds)``.
    """
    run_dir = os.path.join(WORKING_DIR, f"use_dspy={use_dspy}")
    if os.path.exists(run_dir):
        shutil.rmtree(run_dir)

    started = time.time()

    def _model_func(*args, **kwargs):
        # Both the "cheap" and "best" model slots use the same DeepSeek caller,
        # with the benchmark's system prompt bound in.
        return deepseepk_model_if_cache(*args, system_prompt=system_prompt, **kwargs)

    storage = NetworkXStorage(
        namespace="test",
        global_config={
            "working_dir": run_dir,
            "entity_summary_to_max_tokens": 500,
            "cheap_model_func": _model_func,
            "best_model_func": _model_func,
            "cheap_model_max_token_size": 4096,
            "best_model_max_token_size": 4096,
            "tiktoken_model_name": "gpt-4o",
            "hashing_kv": BaseKVStorage(namespace="test", global_config={"working_dir": run_dir}),
            "entity_extract_max_gleaning": 1,
            "entity_extract_max_tokens": 4096,
            "entity_extract_max_entities": 100,
            "entity_extract_max_relationships": 100,
        },
    )
    chunks = {compute_mdhash_id(text, prefix="chunk-"): {"content": text}}

    extractor = extract_entities_dspy if use_dspy else extract_entities
    storage = await extractor(chunks, storage, None, storage.global_config)

    return storage, time.time() - started
88
+
89
+
90
def print_extraction_results(graph_storage: NetworkXStorage):
    """Pretty-print every extracted entity (node) and relationship (edge)."""
    print("\nEntities:")
    entity_lines = [
        f"- {name} ({attrs.get('entity_type', 'Unknown')}):\n {attrs.get('description', 'No description')}"
        for name, attrs in graph_storage._graph.nodes(data=True)
    ]
    print("\n".join(entity_lines))

    print("\nRelationships:")
    edge_lines = [
        f"- {src} -> {dst}:\n {attrs.get('description', 'No description')}"
        for src, dst, attrs in graph_storage._graph.edges(data=True)
    ]
    print("\n".join(edge_lines))
105
+
106
+
107
async def run_benchmark(text: str):
    """Run entity extraction with and without DSPy and report a comparison."""
    print("\nRunning benchmark with DSPy-AI:")
    base_prompt = """
    You are an expert system specialized in entity and relationship extraction from complex texts.
    Your task is to thoroughly analyze the given text and extract all relevant entities and their relationships with utmost precision and completeness.
    """
    # A timestamp is appended so every run produces a distinct prompt and
    # therefore bypasses any response cache keyed on the prompt.
    dspy_prompt = f"{base_prompt} Time: {time.time()}."
    deepseek_lm = dspy.LM(
        model="deepseek/deepseek-chat",
        model_type="chat",
        api_provider="openai",
        api_key=os.environ["DEEPSEEK_API_KEY"],
        base_url=os.environ["DEEPSEEK_BASE_URL"],
        system_prompt=base_prompt,
        temperature=1.0,
        max_tokens=8192,
    )
    dspy.settings.configure(lm=deepseek_lm, experimental=True)
    dspy_storage, dspy_seconds = await benchmark_entity_extraction(text, dspy_prompt, use_dspy=True)
    print(f"Execution time with DSPy-AI: {dspy_seconds:.2f} seconds")
    print_extraction_results(dspy_storage)

    print("Running benchmark without DSPy-AI:")
    plain_prompt = f"{base_prompt} Time: {time.time()}."
    plain_storage, plain_seconds = await benchmark_entity_extraction(text, plain_prompt, use_dspy=False)
    print(f"Execution time without DSPy-AI: {plain_seconds:.2f} seconds")
    print_extraction_results(plain_storage)

    print("\nComparison:")
    print(f"Time difference: {abs(dspy_seconds - plain_seconds):.2f} seconds")
    print(f"DSPy-AI is {'faster' if dspy_seconds < plain_seconds else 'slower'}")

    plain_graph = plain_storage._graph
    dspy_graph = dspy_storage._graph
    print(
        f"Entities extracted: {len(plain_graph.nodes())} (without DSPy-AI) "
        f"vs {len(dspy_graph.nodes())} (with DSPy-AI)"
    )
    print(
        f"Relationships extracted: {len(plain_graph.edges())} (without DSPy-AI) "
        f"vs {len(dspy_graph.edges())} (with DSPy-AI)"
    )
146
+
147
+
148
+ if __name__ == "__main__":
149
+ with open("./tests/zhuyuanzhang.txt", encoding="utf-8-sig") as f:
150
+ text = f.read()
151
+
152
+ asyncio.run(run_benchmark(text=text))
nano-graphrag/examples/benchmarks/eval_naive_graphrag_on_multi_hop.ipynb ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "In this tutorial, we are going to evaluate the performance of the naive RAG and the GraphRAG algorithm on a [multi-hop RAG task](https://github.com/yixuantt/MultiHop-RAG)."
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "## Setup\n",
15
+ "Make sure you install the necessary dependencies by running the following commands:"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": null,
21
+ "metadata": {},
22
+ "outputs": [],
23
+ "source": [
24
+ "!pip install ragas nest_asyncio datasets"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "metadata": {},
30
+ "source": [
31
+ "Import the necessary libraries, and set up your openai api key if needed:"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": 21,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "import os\n",
41
+ "# os.environ[\"OPENAI_API_KEY\"] = \"YOUR_API_KEY\"\n",
42
+ "import json\n",
43
+ "import sys\n",
44
+ "sys.path.append(\"../..\")\n",
45
+ "\n",
46
+ "import nest_asyncio\n",
47
+ "nest_asyncio.apply()\n",
48
+ "import logging\n",
49
+ "\n",
50
+ "logging.basicConfig(level=logging.WARNING)\n",
51
+ "logging.getLogger(\"nano-graphrag\").setLevel(logging.INFO)\n",
52
+ "from nano_graphrag import GraphRAG, QueryParam\n",
53
+ "from datasets import Dataset \n",
54
+ "from ragas import evaluate\n",
55
+ "from ragas.metrics import (\n",
56
+ " answer_correctness,\n",
57
+ " answer_similarity,\n",
58
+ ")"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "markdown",
63
+ "metadata": {},
64
+ "source": [
65
+ "Download the dataset from [Github Repo](https://github.com/yixuantt/MultiHop-RAG/tree/main/dataset). \n",
66
+ "If should contain two files:\n",
67
+ "- `MultiHopRAG.json`\n",
68
+ "- `corpus.json`\n",
69
+ "\n",
70
+ "After downloading the dataset, replace the below paths to the paths on your machine."
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 3,
76
+ "metadata": {},
77
+ "outputs": [],
78
+ "source": [
79
+ "\n",
80
+ "multi_hop_rag_file = \"./fixtures/MultiHopRAG.json\"\n",
81
+ "multi_hop_corpus_file = \"./fixtures/corpus.json\""
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "markdown",
86
+ "metadata": {},
87
+ "source": [
88
+ "## Preprocess"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": 4,
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "\n",
98
+ "with open(multi_hop_rag_file) as f:\n",
99
+ " multi_hop_rag_dataset = json.load(f)\n",
100
+ "with open(multi_hop_corpus_file) as f:\n",
101
+ " multi_hop_corpus = json.load(f)\n",
102
+ "\n",
103
+ "corups_url_refernces = {}\n",
104
+ "for cor in multi_hop_corpus:\n",
105
+ " corups_url_refernces[cor['url']] = cor"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "markdown",
110
+ "metadata": {},
111
+ "source": [
112
+ "We only use the top-100 queries for evaluation."
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 5,
118
+ "metadata": {},
119
+ "outputs": [
120
+ {
121
+ "name": "stdout",
122
+ "output_type": "stream",
123
+ "text": [
124
+ "Queries have types: {'inference_query', 'comparison_query', 'null_query', 'temporal_query'}\n",
125
+ "We will need 139 articles:\n",
126
+ "## ASX set to drop as Wall Street’s September slump deepens\n",
127
+ "Author: Stan Choe, The Sydney Morning Herald\n",
128
+ "Category: business\n",
129
+ "Publised: 2023-09-26T19:11:30+00:00\n",
130
+ "ETF provider Betashares, which manages $ ...\n"
131
+ ]
132
+ }
133
+ ],
134
+ "source": [
135
+ "multi_hop_rag_dataset = multi_hop_rag_dataset[:100]\n",
136
+ "print(\"Queries have types:\", set([q['question_type'] for q in multi_hop_rag_dataset]))\n",
137
+ "total_urls = set()\n",
138
+ "for q in multi_hop_rag_dataset:\n",
139
+ " total_urls.update([up['url'] for up in q['evidence_list']])\n",
140
+ "corups_url_refernces = {k:v for k, v in corups_url_refernces.items() if k in total_urls}\n",
141
+ "\n",
142
+ "total_corups = [f\"## {cor['title']}\\nAuthor: {cor['author']}, {cor['source']}\\nCategory: {cor['category']}\\nPublised: {cor['published_at']}\\n{cor['body']}\" for cor in corups_url_refernces.values()]\n",
143
+ "\n",
144
+ "print(f\"We will need {len(total_corups)} articles:\")\n",
145
+ "print(total_corups[0][:200], \"...\")"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "markdown",
150
+ "metadata": {},
151
+ "source": [
152
+ "Add index for the `total_corups` using naive RAG and GraphRAG"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": 6,
158
+ "metadata": {},
159
+ "outputs": [
160
+ {
161
+ "name": "stderr",
162
+ "output_type": "stream",
163
+ "text": [
164
+ "INFO:nano-graphrag:Load KV full_docs with 139 data\n",
165
+ "INFO:nano-graphrag:Load KV text_chunks with 408 data\n",
166
+ "INFO:nano-graphrag:Load KV llm_response_cache with 1634 data\n",
167
+ "INFO:nano-graphrag:Load KV community_reports with 794 data\n",
168
+ "INFO:nano-graphrag:Loaded graph from nano_graphrag_cache_multi_hop_rag_test/graph_chunk_entity_relation.graphml with 6181 nodes, 5423 edges\n",
169
+ "WARNING:nano-graphrag:All docs are already in the storage\n",
170
+ "INFO:nano-graphrag:Writing graph with 6181 nodes, 5423 edges\n"
171
+ ]
172
+ }
173
+ ],
174
+ "source": [
175
+ "# First time indexing will cost many time, roughly 15~20 minutes\n",
176
+ "graphrag_func = GraphRAG(working_dir=\"nano_graphrag_cache_multi_hop_rag_test\", enable_naive_rag=True,\n",
177
+ " embedding_func_max_async=4)\n",
178
+ "graphrag_func.insert(total_corups)"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "markdown",
183
+ "metadata": {},
184
+ "source": [
185
+ "Look at the response of different RAG methods on the first query:"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": 24,
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": [
194
+ "response_formate = \"Single phrase or sentence, concise and no redundant explanation needed. If you don't have the answer in context, Just response 'Insufficient information'\"\n",
195
+ "naive_rag_query_param = QueryParam(mode='naive', response_type=response_formate)\n",
196
+ "naive_rag_query_only_context_param = QueryParam(mode='naive', only_need_context=True)\n",
197
+ "local_graphrag_query_param = QueryParam(mode='local', response_type=response_formate)\n",
198
+ "local_graphrag_only_context__param = QueryParam(mode='local', only_need_context=True)"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": 8,
204
+ "metadata": {},
205
+ "outputs": [
206
+ {
207
+ "name": "stdout",
208
+ "output_type": "stream",
209
+ "text": [
210
+ "Question: Who is the individual associated with the cryptocurrency industry facing a criminal trial on fraud and conspiracy charges, as reported by both The Verge and TechCrunch, and is accused by prosecutors of committing fraud for personal gain?\n",
211
+ "GroundTruth Answer: Sam Bankman-Fried\n"
212
+ ]
213
+ }
214
+ ],
215
+ "source": [
216
+ "query = multi_hop_rag_dataset[0]\n",
217
+ "print(\"Question:\", query['query'])\n",
218
+ "print(\"GroundTruth Answer:\", query['answer'])"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 9,
224
+ "metadata": {},
225
+ "outputs": [
226
+ {
227
+ "name": "stderr",
228
+ "output_type": "stream",
229
+ "text": [
230
+ "INFO:nano-graphrag:Truncate 20 to 12 chunks\n"
231
+ ]
232
+ },
233
+ {
234
+ "name": "stdout",
235
+ "output_type": "stream",
236
+ "text": [
237
+ "NaiveRAG Answer: Sam Bankman-Fried\n"
238
+ ]
239
+ }
240
+ ],
241
+ "source": [
242
+ "print(\"NaiveRAG Answer:\", graphrag_func.query(query['query'], param=naive_rag_query_param))"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "code",
247
+ "execution_count": 10,
248
+ "metadata": {},
249
+ "outputs": [
250
+ {
251
+ "name": "stderr",
252
+ "output_type": "stream",
253
+ "text": [
254
+ "INFO:nano-graphrag:Using 20 entites, 3 communities, 124 relations, 3 text units\n"
255
+ ]
256
+ },
257
+ {
258
+ "name": "stdout",
259
+ "output_type": "stream",
260
+ "text": [
261
+ "Local GraphRAG Answer: Sam Bankman-Fried\n"
262
+ ]
263
+ }
264
+ ],
265
+ "source": [
266
+ "print(\"Local GraphRAG Answer:\", graphrag_func.query(query['query'], param=local_graphrag_query_param))"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "markdown",
271
+ "metadata": {},
272
+ "source": [
273
+ "Great! Now we're ready to evaluate more detailed metrics. We will use [ragas](https://docs.ragas.io/en/stable/) to evalue the answers' quality."
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": 11,
279
+ "metadata": {},
280
+ "outputs": [],
281
+ "source": [
282
+ "questions = [q['query'] for q in multi_hop_rag_dataset]\n",
283
+ "labels = [q['answer'] for q in multi_hop_rag_dataset]"
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "code",
288
+ "execution_count": 12,
289
+ "metadata": {},
290
+ "outputs": [
291
+ {
292
+ "name": "stderr",
293
+ "output_type": "stream",
294
+ "text": [
295
+ " 0%| | 0/100 [00:00<?, ?it/s]"
296
+ ]
297
+ },
298
+ {
299
+ "name": "stderr",
300
+ "output_type": "stream",
301
+ "text": [
302
+ "100%|██████████| 100/100 [03:53<00:00, 2.33s/it]\n"
303
+ ]
304
+ }
305
+ ],
306
+ "source": [
307
+ "from tqdm import tqdm\n",
308
+ "logging.getLogger(\"nano-graphrag\").setLevel(logging.WARNING)\n",
309
+ "\n",
310
+ "naive_rag_answers = [\n",
311
+ " graphrag_func.query(q, param=naive_rag_query_param) for q in tqdm(questions)\n",
312
+ "]"
313
+ ]
314
+ },
315
+ {
316
+ "cell_type": "code",
317
+ "execution_count": 14,
318
+ "metadata": {},
319
+ "outputs": [
320
+ {
321
+ "name": "stderr",
322
+ "output_type": "stream",
323
+ "text": [
324
+ "100%|██████████| 100/100 [09:10<00:00, 5.50s/it]\n"
325
+ ]
326
+ }
327
+ ],
328
+ "source": [
329
+ "local_graphrag_answers = [\n",
330
+ " graphrag_func.query(q, param=local_graphrag_query_param) for q in tqdm(questions)\n",
331
+ "]"
332
+ ]
333
+ },
334
+ {
335
+ "cell_type": "code",
336
+ "execution_count": 34,
337
+ "metadata": {},
338
+ "outputs": [
339
+ {
340
+ "name": "stderr",
341
+ "output_type": "stream",
342
+ "text": [
343
+ " 70%|███████ | 70/100 [04:25<01:53, 3.79s/it]8, 6.38it/s]\n",
344
+ "Evaluating: 100%|██████████| 200/200 [00:32<00:00, 6.19it/s]\n"
345
+ ]
346
+ }
347
+ ],
348
+ "source": [
349
+ "naive_results = evaluate(\n",
350
+ " Dataset.from_dict({\n",
351
+ " \"question\": questions,\n",
352
+ " \"ground_truth\": labels,\n",
353
+ " \"answer\": naive_rag_answers,\n",
354
+ " }),\n",
355
+ " metrics=[\n",
356
+ " # answer_relevancy,\n",
357
+ " answer_correctness,\n",
358
+ " answer_similarity,\n",
359
+ " ],\n",
360
+ ")"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": 36,
366
+ "metadata": {},
367
+ "outputs": [
368
+ {
369
+ "name": "stderr",
370
+ "output_type": "stream",
371
+ "text": [
372
+ "Evaluating: 100%|██████████| 200/200 [00:23<00:00, 8.59it/s]\n"
373
+ ]
374
+ }
375
+ ],
376
+ "source": [
377
+ "local_graphrag_results = evaluate(\n",
378
+ " Dataset.from_dict({\n",
379
+ " \"question\": questions,\n",
380
+ " \"ground_truth\": labels,\n",
381
+ " \"answer\": local_graphrag_answers,\n",
382
+ " }),\n",
383
+ " metrics=[\n",
384
+ " # answer_relevancy,\n",
385
+ " answer_correctness,\n",
386
+ " answer_similarity,\n",
387
+ " ],\n",
388
+ ")"
389
+ ]
390
+ },
391
+ {
392
+ "cell_type": "code",
393
+ "execution_count": 39,
394
+ "metadata": {},
395
+ "outputs": [
396
+ {
397
+ "name": "stdout",
398
+ "output_type": "stream",
399
+ "text": [
400
+ "Naive RAG results {'answer_correctness': 0.5896, 'answer_similarity': 0.8935}\n",
401
+ "Local GraphRAG results {'answer_correctness': 0.7380, 'answer_similarity': 0.8619}\n"
402
+ ]
403
+ }
404
+ ],
405
+ "source": [
406
+ "print(\"Naive RAG results\", naive_results)\n",
407
+ "print(\"Local GraphRAG results\", local_graphrag_results)"
408
+ ]
409
+ }
410
+ ],
411
+ "metadata": {
412
+ "kernelspec": {
413
+ "display_name": "baai",
414
+ "language": "python",
415
+ "name": "python3"
416
+ },
417
+ "language_info": {
418
+ "codemirror_mode": {
419
+ "name": "ipython",
420
+ "version": 3
421
+ },
422
+ "file_extension": ".py",
423
+ "mimetype": "text/x-python",
424
+ "name": "python",
425
+ "nbconvert_exporter": "python",
426
+ "pygments_lexer": "ipython3",
427
+ "version": "3.9.19"
428
+ }
429
+ },
430
+ "nbformat": 4,
431
+ "nbformat_minor": 2
432
+ }
nano-graphrag/examples/benchmarks/hnsw_vs_nano_vector_storage.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import time
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from nano_graphrag import GraphRAG
6
+ from nano_graphrag._storage import NanoVectorDBStorage, HNSWVectorStorage
7
+ from nano_graphrag._utils import wrap_embedding_func_with_attrs
8
+
9
+
10
+ WORKING_DIR = "./nano_graphrag_cache_benchmark_hnsw_vs_nano_vector_storage"
11
+ DATA_LEN = 100_000
12
+ FAKE_DIM = 1024
13
+ BATCH_SIZE = 100000
14
+
15
+
16
+ @wrap_embedding_func_with_attrs(embedding_dim=FAKE_DIM, max_token_size=8192)
17
+ async def sample_embedding(texts: list[str]) -> np.ndarray:
18
+ return np.float32(np.random.rand(len(texts), FAKE_DIM))
19
+
20
+
21
def generate_test_data():
    """Build DATA_LEN dummy records keyed by their stringified index."""
    records = {}
    for index in range(DATA_LEN):
        records[str(index)] = {"content": f"Test content {index}"}
    return records
23
+
24
+
25
async def benchmark_storage(storage_class, name):
    """Benchmark one vector-storage backend: bulk insert, persist, then query.

    Returns a (insert_seconds, save_seconds, avg_query_seconds) tuple and
    prints a one-line summary. NOTE(review): the embeddings are random (see
    sample_embedding), so only timing — not retrieval quality — is measured.
    """
    rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=sample_embedding)
    storage = storage_class(
        namespace=f"benchmark_{name}",
        global_config=rag.__dict__,
        embedding_func=sample_embedding,
        meta_fields={"content"},
    )

    test_data = generate_test_data()

    print(f"Benchmarking {name}...")
    with tqdm(total=DATA_LEN, desc=f"{name} Benchmark") as pbar:
        insert_started = time.time()
        ordered_keys = list(test_data.keys())
        for offset in range(0, len(test_data), BATCH_SIZE):
            batch = {
                key: test_data[key]
                for key in ordered_keys[offset : offset + BATCH_SIZE]
            }
            await storage.upsert(batch)
            pbar.update(min(BATCH_SIZE, DATA_LEN - offset))

        insert_time = time.time() - insert_started

        save_started = time.time()
        await storage.index_done_callback()
        save_time = time.time() - save_started
        pbar.update(1)

        # Time 100 top-10 lookups against a random query vector.
        query_vector = np.random.rand(FAKE_DIM)
        query_times = []
        for _ in range(100):
            query_started = time.time()
            await storage.query(query_vector, top_k=10)
            query_times.append(time.time() - query_started)
        pbar.update(1)

    avg_query_time = sum(query_times) / len(query_times)

    print(f"{name} - Insert: {insert_time:.2f}s, Save: {save_time:.2f}s, Avg Query: {avg_query_time:.4f}s")
    return insert_time, save_time, avg_query_time
63
+
64
+
65
async def run_benchmarks():
    """Run the NanoVectorDB and HNSW benchmarks back to back and print a summary."""
    print("Running NanoVectorDB benchmark...")
    nano_insert_time, nano_save_time, nano_query_time = await benchmark_storage(NanoVectorDBStorage, "nano")

    print("\nRunning HNSWVectorStorage benchmark...")
    hnsw_insert_time, hnsw_save_time, hnsw_query_time = await benchmark_storage(HNSWVectorStorage, "hnsw")

    print("\nBenchmark Results:")
    print(f"NanoVectorDB - Insert: {nano_insert_time:.2f}s, Save: {nano_save_time:.2f}s, Avg Query: {nano_query_time:.4f}s")
    print(f"HNSWVectorStorage - Insert: {hnsw_insert_time:.2f}s, Save: {hnsw_save_time:.2f}s, Avg Query: {hnsw_query_time:.4f}s")
75
+
76
+
77
+ if __name__ == "__main__":
78
+ asyncio.run(run_benchmarks())
nano-graphrag/examples/benchmarks/md5_vs_xxhash.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import xxhash
3
+ from hashlib import md5
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+
7
+
8
def xxhash_ids(data: list[str]) -> np.ndarray:
    """Map each string to a 32-bit id via xxHash32, returned as a uint32 array."""
    digests = (xxhash.xxh32_intdigest(item.encode()) for item in data)
    return np.fromiter(digests, dtype=np.uint32, count=len(data))
14
+
15
+
16
def md5_ids(data: list[str]) -> np.ndarray:
    """Map each string to a 32-bit id: the low 32 bits of its MD5 digest."""
    truncated = (
        int(md5(item.encode()).hexdigest(), 16) & 0xFFFFFFFF for item in data
    )
    return np.fromiter(truncated, dtype=np.uint32, count=len(data))
22
+
23
+
24
+ if __name__ == "__main__":
25
+ num_ids = 1000000
26
+ num_iterations = 100
27
+ xxhash_times = []
28
+ md5_times = []
29
+
30
+ for i in tqdm(range(num_iterations)):
31
+ test_data = [f"{i}_{j}" for j in range(num_ids)]
32
+
33
+ start_time = time.time()
34
+ xxhash_result = xxhash_ids(test_data)
35
+ xxhash_times.append(time.time() - start_time)
36
+
37
+ start_time = time.time()
38
+ md5_result = md5_ids(test_data)
39
+ md5_times.append(time.time() - start_time)
40
+
41
+ assert len(xxhash_result) == len(md5_result) == num_ids
42
+ assert not np.array_equal(xxhash_result, md5_result)
43
+
44
+ avg_xxhash_time = np.mean(xxhash_times)
45
+ avg_md5_time = np.mean(md5_times)
46
+ std_xxhash_time = np.std(xxhash_times)
47
+ std_md5_time = np.std(md5_times)
48
+
49
+ print(f"num_ids: {num_ids} | num_iterations: {num_iterations}")
50
+ print(f"\nAverage xxhash time: {avg_xxhash_time:.4f} seconds")
51
+ print(f"Average MD5 time: {avg_md5_time:.4f} seconds")
52
+ print(f"xxhash is {avg_md5_time / avg_xxhash_time:.2f}x faster than MD5 on average")
53
+ print(f"\nxxhash time standard deviation: {std_xxhash_time:.4f} seconds")
54
+ print(f"MD5 time standard deviation: {std_md5_time:.4f} seconds")
nano-graphrag/examples/finetune_entity_relationship_dspy.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
nano-graphrag/examples/generate_entity_relationship_dspy.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
nano-graphrag/examples/graphml_visualize.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import networkx as nx
2
+ import json
3
+ import os
4
+ import webbrowser
5
+ import http.server
6
+ import socketserver
7
+ import threading
8
+
9
def graphml_to_json(graphml_file):
    """Load a GraphML file and return its node-link representation as a JSON string."""
    graph = nx.read_graphml(graphml_file)
    node_link = nx.node_link_data(graph)
    return json.dumps(node_link)
14
+
15
+
16
+ # create HTML file
17
+ def create_html(html_path):
18
+ html_content = '''
19
+ <!DOCTYPE html>
20
+ <html lang="en">
21
+ <head>
22
+ <meta charset="UTF-8">
23
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
24
+ <title>Graph Visualization</title>
25
+ <script src="https://d3js.org/d3.v7.min.js"></script>
26
+ <style>
27
+ body, html {
28
+ margin: 0;
29
+ padding: 0;
30
+ width: 100%;
31
+ height: 100%;
32
+ overflow: hidden;
33
+ }
34
+ svg {
35
+ width: 100%;
36
+ height: 100%;
37
+ }
38
+ .links line {
39
+ stroke: #999;
40
+ stroke-opacity: 0.6;
41
+ }
42
+ .nodes circle {
43
+ stroke: #fff;
44
+ stroke-width: 1.5px;
45
+ }
46
+ .node-label {
47
+ font-size: 12px;
48
+ pointer-events: none;
49
+ }
50
+ .link-label {
51
+ font-size: 10px;
52
+ fill: #666;
53
+ pointer-events: none;
54
+ opacity: 0;
55
+ transition: opacity 0.3s;
56
+ }
57
+ .link:hover .link-label {
58
+ opacity: 1;
59
+ }
60
+ .tooltip {
61
+ position: absolute;
62
+ text-align: left;
63
+ padding: 10px;
64
+ font: 12px sans-serif;
65
+ background: lightsteelblue;
66
+ border: 0px;
67
+ border-radius: 8px;
68
+ pointer-events: none;
69
+ opacity: 0;
70
+ transition: opacity 0.3s;
71
+ max-width: 300px;
72
+ }
73
+ .legend {
74
+ position: absolute;
75
+ top: 10px;
76
+ right: 10px;
77
+ background-color: rgba(255, 255, 255, 0.8);
78
+ padding: 10px;
79
+ border-radius: 5px;
80
+ }
81
+ .legend-item {
82
+ margin: 5px 0;
83
+ }
84
+ .legend-color {
85
+ display: inline-block;
86
+ width: 20px;
87
+ height: 20px;
88
+ margin-right: 5px;
89
+ vertical-align: middle;
90
+ }
91
+ </style>
92
+ </head>
93
+ <body>
94
+ <svg></svg>
95
+ <div class="tooltip"></div>
96
+ <div class="legend"></div>
97
+ <script type="text/javascript" src="./graph_json.js"></script>
98
+ <script>
99
+ const graphData = graphJson;
100
+
101
+ const svg = d3.select("svg"),
102
+ width = window.innerWidth,
103
+ height = window.innerHeight;
104
+
105
+ svg.attr("viewBox", [0, 0, width, height]);
106
+
107
+ const g = svg.append("g");
108
+
109
+ const entityTypes = [...new Set(graphData.nodes.map(d => d.entity_type))];
110
+ const color = d3.scaleOrdinal(d3.schemeCategory10).domain(entityTypes);
111
+
112
+ const simulation = d3.forceSimulation(graphData.nodes)
113
+ .force("link", d3.forceLink(graphData.links).id(d => d.id).distance(150))
114
+ .force("charge", d3.forceManyBody().strength(-300))
115
+ .force("center", d3.forceCenter(width / 2, height / 2))
116
+ .force("collide", d3.forceCollide().radius(30));
117
+
118
+ const linkGroup = g.append("g")
119
+ .attr("class", "links")
120
+ .selectAll("g")
121
+ .data(graphData.links)
122
+ .enter().append("g")
123
+ .attr("class", "link");
124
+
125
+ const link = linkGroup.append("line")
126
+ .attr("stroke-width", d => Math.sqrt(d.value));
127
+
128
+ const linkLabel = linkGroup.append("text")
129
+ .attr("class", "link-label")
130
+ .text(d => d.description || "");
131
+
132
+ const node = g.append("g")
133
+ .attr("class", "nodes")
134
+ .selectAll("circle")
135
+ .data(graphData.nodes)
136
+ .enter().append("circle")
137
+ .attr("r", 5)
138
+ .attr("fill", d => color(d.entity_type))
139
+ .call(d3.drag()
140
+ .on("start", dragstarted)
141
+ .on("drag", dragged)
142
+ .on("end", dragended));
143
+
144
+ const nodeLabel = g.append("g")
145
+ .attr("class", "node-labels")
146
+ .selectAll("text")
147
+ .data(graphData.nodes)
148
+ .enter().append("text")
149
+ .attr("class", "node-label")
150
+ .text(d => d.id);
151
+
152
+ const tooltip = d3.select(".tooltip");
153
+
154
+ node.on("mouseover", function(event, d) {
155
+ tooltip.transition()
156
+ .duration(200)
157
+ .style("opacity", .9);
158
+ tooltip.html(`<strong>${d.id}</strong><br>Entity Type: ${d.entity_type}<br>Description: ${d.description || "N/A"}`)
159
+ .style("left", (event.pageX + 10) + "px")
160
+ .style("top", (event.pageY - 28) + "px");
161
+ })
162
+ .on("mouseout", function(d) {
163
+ tooltip.transition()
164
+ .duration(500)
165
+ .style("opacity", 0);
166
+ });
167
+
168
+ const legend = d3.select(".legend");
169
+ entityTypes.forEach(type => {
170
+ legend.append("div")
171
+ .attr("class", "legend-item")
172
+ .html(`<span class="legend-color" style="background-color: ${color(type)}"></span>${type}`);
173
+ });
174
+
175
+ simulation
176
+ .nodes(graphData.nodes)
177
+ .on("tick", ticked);
178
+
179
+ simulation.force("link")
180
+ .links(graphData.links);
181
+
182
+ function ticked() {
183
+ link
184
+ .attr("x1", d => d.source.x)
185
+ .attr("y1", d => d.source.y)
186
+ .attr("x2", d => d.target.x)
187
+ .attr("y2", d => d.target.y);
188
+
189
+ linkLabel
190
+ .attr("x", d => (d.source.x + d.target.x) / 2)
191
+ .attr("y", d => (d.source.y + d.target.y) / 2)
192
+ .attr("text-anchor", "middle")
193
+ .attr("dominant-baseline", "middle");
194
+
195
+ node
196
+ .attr("cx", d => d.x)
197
+ .attr("cy", d => d.y);
198
+
199
+ nodeLabel
200
+ .attr("x", d => d.x + 8)
201
+ .attr("y", d => d.y + 3);
202
+ }
203
+
204
+ function dragstarted(event) {
205
+ if (!event.active) simulation.alphaTarget(0.3).restart();
206
+ event.subject.fx = event.subject.x;
207
+ event.subject.fy = event.subject.y;
208
+ }
209
+
210
+ function dragged(event) {
211
+ event.subject.fx = event.x;
212
+ event.subject.fy = event.y;
213
+ }
214
+
215
+ function dragended(event) {
216
+ if (!event.active) simulation.alphaTarget(0);
217
+ event.subject.fx = null;
218
+ event.subject.fy = null;
219
+ }
220
+
221
+ const zoom = d3.zoom()
222
+ .scaleExtent([0.1, 10])
223
+ .on("zoom", zoomed);
224
+
225
+ svg.call(zoom);
226
+
227
+ function zoomed(event) {
228
+ g.attr("transform", event.transform);
229
+ }
230
+
231
+ </script>
232
+ </body>
233
+ </html>
234
+ '''
235
+
236
+ with open(html_path, 'w', encoding='utf-8') as f:
237
+ f.write(html_content)
238
+
239
+
240
def create_json(json_data, json_path):
    """Write the graph JSON as a JS global (``var graphJson = ...``) so the
    viewer HTML can load it via a plain <script src> tag.

    The payload is flattened to a single line; escaped double quotes are
    stripped and apostrophes are backslash-escaped, matching what the
    generated page expects.
    """
    cleaned = json_data.replace('\\"', '').replace("'", "\\'").replace("\n", "")
    with open(json_path, 'w', encoding='utf-8') as out_file:
        out_file.write("var graphJson = " + cleaned)
244
+
245
+
246
def start_server(port):
    """Serve the current working directory over HTTP on *port*.

    Blocks forever (serve_forever); intended to run on a daemon thread.
    """
    with socketserver.TCPServer(
        ("", port), http.server.SimpleHTTPRequestHandler
    ) as httpd:
        print(f"Server started at http://localhost:{port}")
        httpd.serve_forever()
252
+
253
# main function
def visualize_graphml(graphml_file, html_path, port=8000):
    """Convert a GraphML file to JSON, emit the D3 viewer page, and serve it.

    Writes ``graph_json.js`` next to *html_path*, starts a background HTTP
    server on *port*, opens the default browser on the page, then idles until
    Ctrl+C.
    """
    import time  # local import: only needed for the idle loop below

    json_data = graphml_to_json(graphml_file)
    html_dir = os.path.dirname(html_path)
    # BUG FIX: when html_path has no directory part, dirname() is "" and the
    # original os.makedirs("") raised FileNotFoundError. Only create real dirs.
    if html_dir and not os.path.exists(html_dir):
        os.makedirs(html_dir)
    json_path = os.path.join(html_dir, 'graph_json.js')
    create_json(json_data, json_path)
    create_html(html_path)

    # BUG FIX: the original passed ``target=start_server(port)``, which CALLS
    # start_server immediately (blocking this thread forever) instead of
    # handing the callable to the thread. Pass function and args separately.
    server_thread = threading.Thread(target=start_server, args=(port,))
    server_thread.daemon = True
    server_thread.start()

    # open default browser
    webbrowser.open(f'http://localhost:{port}/{html_path}')

    print("Visualization is ready. Press Ctrl+C to exit.")
    try:
        # keep main thread alive without the original's 100%-CPU busy-wait
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("Shutting down...")
277
+
278
+ # usage
279
+ if __name__ == "__main__":
280
+ graphml_file = r"nano_graphrag_cache_azure_openai_TEST\graph_chunk_entity_relation.graphml" # replace with your GraphML file path
281
+ html_path = "graph_visualization.html"
282
+ visualize_graphml(graphml_file, html_path, 11236)
nano-graphrag/examples/no_openai_key_at_all.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import ollama
4
+ import numpy as np
5
+ from nano_graphrag import GraphRAG, QueryParam
6
+ from nano_graphrag import GraphRAG, QueryParam
7
+ from nano_graphrag.base import BaseKVStorage
8
+ from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
9
+ from sentence_transformers import SentenceTransformer
10
+
11
+ logging.basicConfig(level=logging.WARNING)
12
+ logging.getLogger("nano-graphrag").setLevel(logging.INFO)
13
+
14
+ # !!! qwen2-7B maybe produce unparsable results and cause the extraction of graph to fail.
15
+ WORKING_DIR = "./nano_graphrag_cache_ollama_TEST"
16
+ MODEL = "qwen2"
17
+
18
+ EMBED_MODEL = SentenceTransformer(
19
+ "sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu"
20
+ )
21
+
22
+
23
+ # We're using Sentence Transformers to generate embeddings for the BGE model
24
+ @wrap_embedding_func_with_attrs(
25
+ embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
26
+ max_token_size=EMBED_MODEL.max_seq_length,
27
+ )
28
+ async def local_embedding(texts: list[str]) -> np.ndarray:
29
+ return EMBED_MODEL.encode(texts, normalize_embeddings=True)
30
+
31
+
32
+ async def ollama_model_if_cache(
33
+ prompt, system_prompt=None, history_messages=[], **kwargs
34
+ ) -> str:
35
+ # remove kwargs that are not supported by ollama
36
+ kwargs.pop("max_tokens", None)
37
+ kwargs.pop("response_format", None)
38
+
39
+ ollama_client = ollama.AsyncClient()
40
+ messages = []
41
+ if system_prompt:
42
+ messages.append({"role": "system", "content": system_prompt})
43
+
44
+ # Get the cached response if having-------------------
45
+ hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
46
+ messages.extend(history_messages)
47
+ messages.append({"role": "user", "content": prompt})
48
+ if hashing_kv is not None:
49
+ args_hash = compute_args_hash(MODEL, messages)
50
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
51
+ if if_cache_return is not None:
52
+ return if_cache_return["return"]
53
+ # -----------------------------------------------------
54
+ response = await ollama_client.chat(model=MODEL, messages=messages, **kwargs)
55
+
56
+ result = response["message"]["content"]
57
+ # Cache the response if having-------------------
58
+ if hashing_kv is not None:
59
+ await hashing_kv.upsert({args_hash: {"return": result, "model": MODEL}})
60
+ # -----------------------------------------------------
61
+ return result
62
+
63
+
64
+ def remove_if_exist(file):
65
+ if os.path.exists(file):
66
+ os.remove(file)
67
+
68
+
69
+ def query():
70
+ rag = GraphRAG(
71
+ working_dir=WORKING_DIR,
72
+ best_model_func=ollama_model_if_cache,
73
+ cheap_model_func=ollama_model_if_cache,
74
+ embedding_func=local_embedding,
75
+ )
76
+ print(
77
+ rag.query(
78
+ "What are the top themes in this story?", param=QueryParam(mode="global")
79
+ )
80
+ )
81
+
82
+
83
+ def insert():
84
+ from time import time
85
+
86
+ with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
87
+ FAKE_TEXT = f.read()
88
+
89
+ remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
90
+ remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
91
+ remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
92
+ remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
93
+ remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
94
+
95
+ rag = GraphRAG(
96
+ working_dir=WORKING_DIR,
97
+ enable_llm_cache=True,
98
+ best_model_func=ollama_model_if_cache,
99
+ cheap_model_func=ollama_model_if_cache,
100
+ embedding_func=local_embedding,
101
+ )
102
+ start = time()
103
+ rag.insert(FAKE_TEXT)
104
+ print("indexing time:", time() - start)
105
+ # rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
106
+ # rag.insert(FAKE_TEXT[half_len:])
107
+
108
+
109
+ if __name__ == "__main__":
110
+ insert()
111
+ query()
nano-graphrag/examples/using_amazon_bedrock.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nano_graphrag import GraphRAG, QueryParam
2
+
3
+ graph_func = GraphRAG(
4
+ working_dir="../bedrock_example",
5
+ using_amazon_bedrock=True,
6
+ best_model_id="us.anthropic.claude-3-sonnet-20240229-v1:0",
7
+ cheap_model_id="us.anthropic.claude-3-haiku-20240307-v1:0",
8
+ )
9
+
10
+ with open("../tests/mock_data.txt") as f:
11
+ graph_func.insert(f.read())
12
+
13
+ prompt = "What are the top themes in this story?"
14
+
15
+ # Perform global graphrag search
16
+ print(graph_func.query(prompt, param=QueryParam(mode="global")))
17
+
18
+ # Perform local graphrag search (I think is better and more scalable one)
19
+ print(graph_func.query(prompt, param=QueryParam(mode="local")))
nano-graphrag/examples/using_custom_chunking_method.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nano_graphrag._utils import encode_string_by_tiktoken
2
+ from nano_graphrag.base import QueryParam
3
+ from nano_graphrag.graphrag import GraphRAG
4
+ from nano_graphrag._op import chunking_by_seperators
5
+
6
+
7
def chunking_by_token_size(
    tokens_list: list[list[int]],  # nano-graphrag may pass a batch of docs' tokens
    doc_keys: list[str],  # nano-graphrag may pass a batch of docs' key ids
    tiktoken_model,  # a tiktoken model used only for decode_batch
    overlap_token_size=128,
    max_token_size=1024,
):
    """Split each document's token list into overlapping fixed-size chunks.

    Windows of ``max_token_size`` tokens advance by
    ``max_token_size - overlap_token_size`` so consecutive chunks share an
    overlap. Returns one dict per chunk with its token count, stripped text,
    order index within its document, and the originating doc key.
    """
    results = []
    stride = max_token_size - overlap_token_size
    for doc_index, tokens in enumerate(tokens_list):
        window_starts = range(0, len(tokens), stride)
        token_windows = [tokens[s : s + max_token_size] for s in window_starts]
        window_lengths = [
            min(max_token_size, len(tokens) - s) for s in window_starts
        ]

        decoded_chunks = tiktoken_model.decode_batch(token_windows)
        for chunk_index, chunk_text in enumerate(decoded_chunks):
            results.append(
                {
                    "tokens": window_lengths[chunk_index],
                    "content": chunk_text.strip(),
                    "chunk_order_index": chunk_index,
                    "full_doc_id": doc_keys[doc_index],
                }
            )

    return results
37
+
38
+
39
+ WORKING_DIR = "./nano_graphrag_cache_local_embedding_TEST"
40
+ rag = GraphRAG(
41
+ working_dir=WORKING_DIR,
42
+ chunk_func=chunking_by_seperators,
43
+ )
nano-graphrag/examples/using_deepseek_api_as_llm+glm_api_as_embedding.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import numpy as np
4
+ from openai import AsyncOpenAI, OpenAI
5
+ from dataclasses import dataclass
6
+ from nano_graphrag import GraphRAG, QueryParam
7
+ from nano_graphrag.base import BaseKVStorage
8
+ from nano_graphrag._utils import compute_args_hash
9
+
10
+ logging.basicConfig(level=logging.WARNING)
11
+ logging.getLogger("nano-graphrag").setLevel(logging.INFO)
12
+
13
+ GLM_API_KEY = "XXXX"
14
+ DEEPSEEK_API_KEY = "sk-XXXX"
15
+
16
+ MODEL = "deepseek-chat"
17
+
18
+
19
+ async def deepseepk_model_if_cache(
20
+ prompt, system_prompt=None, history_messages=[], **kwargs
21
+ ) -> str:
22
+ openai_async_client = AsyncOpenAI(
23
+ api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com"
24
+ )
25
+ messages = []
26
+ if system_prompt:
27
+ messages.append({"role": "system", "content": system_prompt})
28
+
29
+ # Get the cached response if having-------------------
30
+ hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
31
+ messages.extend(history_messages)
32
+ messages.append({"role": "user", "content": prompt})
33
+ if hashing_kv is not None:
34
+ args_hash = compute_args_hash(MODEL, messages)
35
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
36
+ if if_cache_return is not None:
37
+ return if_cache_return["return"]
38
+ # -----------------------------------------------------
39
+
40
+ response = await openai_async_client.chat.completions.create(
41
+ model=MODEL, messages=messages, **kwargs
42
+ )
43
+
44
+ # Cache the response if having-------------------
45
+ if hashing_kv is not None:
46
+ await hashing_kv.upsert(
47
+ {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
48
+ )
49
+ # -----------------------------------------------------
50
+ return response.choices[0].message.content
51
+
52
+
53
def remove_if_exist(file):
    """Delete *file* from disk; a missing file is silently ignored."""
    try:
        os.remove(file)
    except FileNotFoundError:
        pass
56
+
57
+
58
@dataclass
class EmbeddingFunc:
    """Awaitable wrapper pairing an async embedding function with its metadata.

    Calling the instance forwards all arguments to the wrapped coroutine
    function and awaits its result.
    """

    embedding_dim: int  # dimensionality of the vectors the function returns
    max_token_size: int  # maximum input length the embedding model accepts
    func: callable  # the wrapped async embedding function

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        # Pure delegation: no caching or batching happens here.
        result = await self.func(*args, **kwargs)
        return result
66
+
67
def wrap_embedding_func_with_attrs(**kwargs):
    """Decorator factory: package an async embedding function into an
    EmbeddingFunc carrying the given metadata attributes
    (e.g. embedding_dim, max_token_size)."""

    def decorate(func) -> EmbeddingFunc:
        return EmbeddingFunc(func=func, **kwargs)

    return decorate
75
+
76
+ @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
77
+ async def GLM_embedding(texts: list[str]) -> np.ndarray:
78
+ model_name = "embedding-2"
79
+ client = OpenAI(
80
+ api_key=GLM_API_KEY,
81
+ base_url="https://open.bigmodel.cn/api/paas/v4/"
82
+ )
83
+ embedding = client.embeddings.create(
84
+ input=texts,
85
+ model=model_name,
86
+ )
87
+ final_embedding = [d.embedding for d in embedding.data]
88
+ return np.array(final_embedding)
89
+
90
+
91
+
92
+ WORKING_DIR = "./nano_graphrag_cache_deepseek_TEST"
93
+
94
+ def query():
95
+ rag = GraphRAG(
96
+ working_dir=WORKING_DIR,
97
+ best_model_func=deepseepk_model_if_cache,
98
+ cheap_model_func=deepseepk_model_if_cache,
99
+ embedding_func=GLM_embedding,
100
+ )
101
+ print(
102
+ rag.query(
103
+ "What are the top themes in this story?", param=QueryParam(mode="global")
104
+ )
105
+ )
106
+
107
+
108
+ def insert():
109
+ from time import time
110
+
111
+ with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
112
+ FAKE_TEXT = f.read()
113
+
114
+ remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
115
+ remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
116
+ remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
117
+ remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
118
+ remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
119
+
120
+ rag = GraphRAG(
121
+ working_dir=WORKING_DIR,
122
+ enable_llm_cache=True,
123
+ best_model_func=deepseepk_model_if_cache,
124
+ cheap_model_func=deepseepk_model_if_cache,
125
+ embedding_func=GLM_embedding,
126
+ )
127
+ start = time()
128
+ rag.insert(FAKE_TEXT)
129
+ print("indexing time:", time() - start)
130
+ # rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
131
+ # rag.insert(FAKE_TEXT[half_len:])
132
+
133
+
134
+ if __name__ == "__main__":
135
+ insert()
136
+ # query()
nano-graphrag/examples/using_deepseek_as_llm.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from openai import AsyncOpenAI
4
+ from nano_graphrag import GraphRAG, QueryParam
5
+ from nano_graphrag import GraphRAG, QueryParam
6
+ from nano_graphrag.base import BaseKVStorage
7
+ from nano_graphrag._utils import compute_args_hash
8
+
9
+ logging.basicConfig(level=logging.WARNING)
10
+ logging.getLogger("nano-graphrag").setLevel(logging.INFO)
11
+
12
+ DEEPSEEK_API_KEY = "sk-XXXX"
13
+ MODEL = "deepseek-chat"
14
+
15
+
16
+ async def deepseepk_model_if_cache(
17
+ prompt, system_prompt=None, history_messages=[], **kwargs
18
+ ) -> str:
19
+ openai_async_client = AsyncOpenAI(
20
+ api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com"
21
+ )
22
+ messages = []
23
+ if system_prompt:
24
+ messages.append({"role": "system", "content": system_prompt})
25
+
26
+ # Get the cached response if having-------------------
27
+ hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
28
+ messages.extend(history_messages)
29
+ messages.append({"role": "user", "content": prompt})
30
+ if hashing_kv is not None:
31
+ args_hash = compute_args_hash(MODEL, messages)
32
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
33
+ if if_cache_return is not None:
34
+ return if_cache_return["return"]
35
+ # -----------------------------------------------------
36
+
37
+ response = await openai_async_client.chat.completions.create(
38
+ model=MODEL, messages=messages, **kwargs
39
+ )
40
+
41
+ # Cache the response if having-------------------
42
+ if hashing_kv is not None:
43
+ await hashing_kv.upsert(
44
+ {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
45
+ )
46
+ # -----------------------------------------------------
47
+ return response.choices[0].message.content
48
+
49
+
50
def remove_if_exist(file):
    """Best-effort delete: remove *file* when present, ignore when absent."""
    try:
        os.remove(file)
    except FileNotFoundError:
        pass  # already gone — nothing to clean up
53
+
54
+
55
+ WORKING_DIR = "./nano_graphrag_cache_deepseek_TEST"
56
+
57
+
58
+ def query():
59
+ rag = GraphRAG(
60
+ working_dir=WORKING_DIR,
61
+ best_model_func=deepseepk_model_if_cache,
62
+ cheap_model_func=deepseepk_model_if_cache,
63
+ )
64
+ print(
65
+ rag.query(
66
+ "What are the top themes in this story?", param=QueryParam(mode="global")
67
+ )
68
+ )
69
+
70
+
71
+ def insert():
72
+ from time import time
73
+
74
+ with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
75
+ FAKE_TEXT = f.read()
76
+
77
+ remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
78
+ remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
79
+ remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
80
+ remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
81
+ remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
82
+
83
+ rag = GraphRAG(
84
+ working_dir=WORKING_DIR,
85
+ enable_llm_cache=True,
86
+ best_model_func=deepseepk_model_if_cache,
87
+ cheap_model_func=deepseepk_model_if_cache,
88
+ )
89
+ start = time()
90
+ rag.insert(FAKE_TEXT)
91
+ print("indexing time:", time() - start)
92
+ # rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
93
+ # rag.insert(FAKE_TEXT[half_len:])
94
+
95
+
96
+ if __name__ == "__main__":
97
+ insert()
98
+ # query()
nano-graphrag/examples/using_dspy_entity_extraction.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from openai import AsyncOpenAI
3
+ from dotenv import load_dotenv
4
+ import logging
5
+ import numpy as np
6
+ import dspy
7
+ from sentence_transformers import SentenceTransformer
8
+ from nano_graphrag import GraphRAG, QueryParam
9
+ from nano_graphrag._llm import gpt_4o_mini_complete
10
+ from nano_graphrag._storage import HNSWVectorStorage
11
+ from nano_graphrag.base import BaseKVStorage
12
+ from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
13
+ from nano_graphrag.entity_extraction.extract import extract_entities_dspy
14
+
15
+ logging.basicConfig(level=logging.WARNING)
16
+ logging.getLogger("nano-graphrag").setLevel(logging.DEBUG)
17
+
18
+ WORKING_DIR = "./nano_graphrag_cache_using_dspy_entity_extraction"
19
+
20
+ load_dotenv()
21
+
22
+
23
+ EMBED_MODEL = SentenceTransformer(
24
+ "sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu"
25
+ )
26
+
27
+
28
+ @wrap_embedding_func_with_attrs(
29
+ embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
30
+ max_token_size=EMBED_MODEL.max_seq_length,
31
+ )
32
+ async def local_embedding(texts: list[str]) -> np.ndarray:
33
+ return EMBED_MODEL.encode(texts, normalize_embeddings=True)
34
+
35
+
36
+ async def deepseepk_model_if_cache(
37
+ prompt, model: str = "deepseek-chat", system_prompt=None, history_messages=[], **kwargs
38
+ ) -> str:
39
+ openai_async_client = AsyncOpenAI(
40
+ api_key=os.environ.get("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com"
41
+ )
42
+ messages = []
43
+ if system_prompt:
44
+ messages.append({"role": "system", "content": system_prompt})
45
+
46
+ # Get the cached response if having-------------------
47
+ hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
48
+ messages.extend(history_messages)
49
+ messages.append({"role": "user", "content": prompt})
50
+ if hashing_kv is not None:
51
+ args_hash = compute_args_hash(model, messages)
52
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
53
+ if if_cache_return is not None:
54
+ return if_cache_return["return"]
55
+ # -----------------------------------------------------
56
+
57
+ response = await openai_async_client.chat.completions.create(
58
+ model=model, messages=messages, **kwargs
59
+ )
60
+
61
+ # Cache the response if having-------------------
62
+ if hashing_kv is not None:
63
+ await hashing_kv.upsert(
64
+ {args_hash: {"return": response.choices[0].message.content, "model": model}}
65
+ )
66
+ # -----------------------------------------------------
67
+ return response.choices[0].message.content
68
+
69
+
70
+
71
+ def remove_if_exist(file):
72
+ if os.path.exists(file):
73
+ os.remove(file)
74
+
75
+
76
+ def insert():
77
+ from time import time
78
+
79
+ with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
80
+ FAKE_TEXT = f.read()
81
+
82
+ remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
83
+ remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
84
+ remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
85
+ remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
86
+ remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
87
+ rag = GraphRAG(
88
+ working_dir=WORKING_DIR,
89
+ enable_llm_cache=True,
90
+ vector_db_storage_cls=HNSWVectorStorage,
91
+ vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50},
92
+ best_model_max_async=10,
93
+ cheap_model_max_async=10,
94
+ best_model_func=deepseepk_model_if_cache,
95
+ cheap_model_func=deepseepk_model_if_cache,
96
+ embedding_func=local_embedding,
97
+ entity_extraction_func=extract_entities_dspy
98
+ )
99
+ start = time()
100
+ rag.insert(FAKE_TEXT)
101
+ print("indexing time:", time() - start)
102
+
103
+
104
+ def query():
105
+ rag = GraphRAG(
106
+ working_dir=WORKING_DIR,
107
+ enable_llm_cache=True,
108
+ vector_db_storage_cls=HNSWVectorStorage,
109
+ vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50},
110
+ best_model_max_token_size=8196,
111
+ cheap_model_max_token_size=8196,
112
+ best_model_max_async=4,
113
+ cheap_model_max_async=4,
114
+ best_model_func=gpt_4o_mini_complete,
115
+ cheap_model_func=gpt_4o_mini_complete,
116
+ embedding_func=local_embedding,
117
+ entity_extraction_func=extract_entities_dspy
118
+
119
+ )
120
+ print(
121
+ rag.query(
122
+ "What are the top themes in this story?", param=QueryParam(mode="global")
123
+ )
124
+ )
125
+ print(
126
+ rag.query(
127
+ "What are the top themes in this story?", param=QueryParam(mode="local")
128
+ )
129
+ )
130
+
131
+
132
+ if __name__ == "__main__":
133
+ lm = dspy.LM(
134
+ model="deepseek/deepseek-chat",
135
+ model_type="chat",
136
+ api_provider="openai",
137
+ api_key=os.environ["DEEPSEEK_API_KEY"],
138
+ base_url=os.environ["DEEPSEEK_BASE_URL"],
139
+ temperature=1.0,
140
+ max_tokens=8192
141
+ )
142
+ dspy.settings.configure(lm=lm, experimental=True)
143
+ insert()
144
+ query()
nano-graphrag/examples/using_faiss_as_vextorDB.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ import numpy as np
4
+ from nano_graphrag.graphrag import GraphRAG, QueryParam
5
+ from nano_graphrag._utils import logger
6
+ from nano_graphrag.base import BaseVectorStorage
7
+ from dataclasses import dataclass
8
+ import faiss
9
+ import pickle
10
+ import logging
11
+ import xxhash
12
+ logging.getLogger('msal').setLevel(logging.WARNING)
13
+ logging.getLogger('azure').setLevel(logging.WARNING)
14
+ logging.getLogger("httpx").setLevel(logging.WARNING)
15
+
16
+ WORKING_DIR = "./nano_graphrag_cache_faiss_TEST"
17
+
18
@dataclass
class FAISSStorage(BaseVectorStorage):
    """Vector storage backed by a local FAISS ``IndexIDMap(IndexFlatIP)``.

    Vectors are persisted per-namespace in ``<ns>_faiss.index``; payload
    metadata (keyed by a 32-bit xxhash of the original string id) is kept in
    a companion ``<ns>_metadata.pkl`` pickle.
    """

    def __post_init__(self):
        # Per-namespace persistence paths under the GraphRAG working dir.
        self._index_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_faiss.index"
        )
        self._metadata_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_metadata.pkl"
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]

        if os.path.exists(self._index_file_name) and os.path.exists(self._metadata_file_name):
            # Resume from a previous run: both files must exist together.
            self._index = faiss.read_index(self._index_file_name)
            with open(self._metadata_file_name, 'rb') as f:
                self._metadata = pickle.load(f)
        else:
            # Inner-product index; "1 - distance" in query() treats scores as
            # cosine similarity, which assumes L2-normalized embeddings —
            # TODO confirm embedding_func normalizes.
            self._index = faiss.IndexIDMap(faiss.IndexFlatIP(self.embedding_func.embedding_dim))
            self._metadata = {}

    async def upsert(self, data: dict[str, dict]):
        """Embed and add ``{string_id: {"content": ..., ...}}`` entries.

        Returns the number of entries processed.
        NOTE(review): faiss ``add_with_ids`` appends rather than replaces, so
        re-upserting an existing key duplicates its vector — confirm intended.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")

        # Embed contents in bounded batches, concurrently.
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)

        ids = []
        for k, v in data.items():
            # FAISS ids must be ints: hash the string key; keep the original
            # key (plus selected meta fields) in the side-car metadata dict.
            id = xxhash.xxh32_intdigest(k.encode())
            metadata = {k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}
            metadata['id'] = k
            self._metadata[id] = metadata
            ids.append(id)

        ids = np.array(ids, dtype=np.int64)
        self._index.add_with_ids(embeddings, ids)


        return len(data)

    async def query(self, query, top_k=5):
        """Return up to ``top_k`` metadata dicts, each with a ``distance`` field."""
        embedding = await self.embedding_func([query])
        distances, indices = self._index.search(embedding, top_k)

        results = []
        for _, (distance, id) in enumerate(zip(distances[0], indices[0])):
            if id != -1:  # FAISS returns -1 for empty slots
                if id in self._metadata:
                    metadata = self._metadata[id]
                    results.append({**metadata, "distance": 1 - distance})  # Convert to cosine distance

        return results

    async def index_done_callback(self):
        # Persist the vectors and the id->metadata side-car together so the
        # two files stay in sync for the next run's __post_init__.
        faiss.write_index(self._index, self._index_file_name)
        with open(self._metadata_file_name, 'wb') as f:
            pickle.dump(self._metadata, f)
82
+
83
if __name__ == "__main__":

    # Demo driver: index the first 30k chars of the mock corpus with the
    # FAISS-backed vector storage, then run a default (global) query.
    graph_func = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        vector_db_storage_cls=FAISSStorage,
    )

    with open(r"tests/mock_data.txt", encoding='utf-8') as f:
        graph_func.insert(f.read()[:30000])

    # Perform global graphrag search
    print(graph_func.query("What are the top themes in this story?"))
96
+
97
+
nano-graphrag/examples/using_hnsw_as_vectorDB.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from openai import AsyncOpenAI
3
+ from dotenv import load_dotenv
4
+ import logging
5
+ import numpy as np
6
+ from sentence_transformers import SentenceTransformer
7
+ from nano_graphrag import GraphRAG, QueryParam
8
+ from nano_graphrag._llm import gpt_4o_mini_complete
9
+ from nano_graphrag._storage import HNSWVectorStorage
10
+ from nano_graphrag.base import BaseKVStorage
11
+ from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
12
+
13
+ logging.basicConfig(level=logging.WARNING)
14
+ logging.getLogger("nano-graphrag").setLevel(logging.DEBUG)
15
+
16
+ WORKING_DIR = "./nano_graphrag_cache_using_hnsw_as_vectorDB"
17
+
18
+ load_dotenv()
19
+
20
+
21
+ EMBED_MODEL = SentenceTransformer(
22
+ "sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu"
23
+ )
24
+
25
+
26
@wrap_embedding_func_with_attrs(
    embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
    max_token_size=EMBED_MODEL.max_seq_length,
)
async def local_embedding(texts: list[str]) -> np.ndarray:
    """Embed *texts* with the local SentenceTransformer (normalized vectors)."""
    return EMBED_MODEL.encode(texts, normalize_embeddings=True)
32
+
33
+
34
async def deepseepk_model_if_cache(
    prompt, model: str = "deepseek-chat", system_prompt=None, history_messages=[], **kwargs
) -> str:
    """Call DeepSeek's OpenAI-compatible chat API with read-through caching.

    When GraphRAG passes a ``hashing_kv`` KV store in kwargs, a hash of
    (model, messages) is used to return a cached completion and to store
    new ones. ``history_messages`` is read-only here, so the mutable
    default is benign.
    """
    openai_async_client = AsyncOpenAI(
        api_key=os.environ.get("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com"
    )
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Get the cached response if having-------------------
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    # -----------------------------------------------------

    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )

    # Cache the response if having-------------------
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": model}}
        )
    # -----------------------------------------------------
    return response.choices[0].message.content
66
+
67
+
68
+
69
def remove_if_exist(file):
    """Delete *file* if it is present; silently do nothing when missing."""
    # EAFP: attempt the removal and ignore a missing path, avoiding the
    # exists()/remove() check-then-act race of the original form.
    try:
        os.remove(file)
    except FileNotFoundError:
        pass
72
+
73
+
74
def insert():
    """Rebuild the index from scratch over the mock corpus and time it."""
    from time import time

    with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
        FAKE_TEXT = f.read()

    # Drop previous artifacts so this run re-indexes from a clean state
    # (the LLM response cache is intentionally kept).
    remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
    remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        vector_db_storage_cls=HNSWVectorStorage,
        vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50},
        best_model_max_async=10,
        cheap_model_max_async=10,
        best_model_func=deepseepk_model_if_cache,
        cheap_model_func=deepseepk_model_if_cache,
        embedding_func=local_embedding
    )
    start = time()
    rag.insert(FAKE_TEXT)
    print("indexing time:", time() - start)
99
+
100
+
101
def query():
    """Answer a demo question in both 'global' and 'local' modes.

    Reuses the storage written by insert() but answers with gpt-4o-mini.
    NOTE(review): 8196 looks like a typo for 8192 — confirm intended.
    """
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        vector_db_storage_cls=HNSWVectorStorage,
        vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50},
        best_model_max_token_size=8196,
        cheap_model_max_token_size=8196,
        best_model_max_async=4,
        cheap_model_max_async=4,
        best_model_func=gpt_4o_mini_complete,
        cheap_model_func=gpt_4o_mini_complete,
        embedding_func=local_embedding
    )
    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="global")
        )
    )
    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="local")
        )
    )
125
+
126
+
127
if __name__ == "__main__":
    # Build the index first, then query it.
    insert()
    query()
nano-graphrag/examples/using_llm_api_as_llm+ollama_embedding.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import ollama
4
+ import numpy as np
5
+ from openai import AsyncOpenAI
6
+ from nano_graphrag import GraphRAG, QueryParam
7
+ from nano_graphrag import GraphRAG, QueryParam
8
+ from nano_graphrag.base import BaseKVStorage
9
+ from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
10
+
11
+ logging.basicConfig(level=logging.WARNING)
12
+ logging.getLogger("nano-graphrag").setLevel(logging.INFO)
13
+
14
+ # Assumed llm model settings
15
+ LLM_BASE_URL = "https://your.api.url"
16
+ LLM_API_KEY = "your_api_key"
17
+ MODEL = "your_model_name"
18
+
19
+ # Assumed embedding model settings
20
+ EMBEDDING_MODEL = "nomic-embed-text"
21
+ EMBEDDING_MODEL_DIM = 768
22
+ EMBEDDING_MODEL_MAX_TOKENS = 8192
23
+
24
+
25
async def llm_model_if_cache(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """Call an OpenAI-compatible endpoint (LLM_BASE_URL/MODEL) with caching.

    When GraphRAG passes a ``hashing_kv`` KV store in kwargs, a hash of
    (MODEL, messages) is used to return a cached completion and to store
    new ones. ``history_messages`` is read-only here, so the mutable
    default is benign.
    """
    openai_async_client = AsyncOpenAI(
        api_key=LLM_API_KEY, base_url=LLM_BASE_URL
    )
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Get the cached response if having-------------------
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(MODEL, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    # -----------------------------------------------------

    response = await openai_async_client.chat.completions.create(
        model=MODEL, messages=messages, **kwargs
    )

    # Cache the response if having-------------------
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
        )
    # -----------------------------------------------------
    return response.choices[0].message.content
+ return response.choices[0].message.content
57
+
58
+
59
+ def remove_if_exist(file):
60
+ if os.path.exists(file):
61
+ os.remove(file)
62
+
63
+
64
+ WORKING_DIR = "./nano_graphrag_cache_llm_TEST"
65
+
66
+
67
+ def query():
68
+ rag = GraphRAG(
69
+ working_dir=WORKING_DIR,
70
+ best_model_func=llm_model_if_cache,
71
+ cheap_model_func=llm_model_if_cache,
72
+ embedding_func=ollama_embedding,
73
+ )
74
+ print(
75
+ rag.query(
76
+ "What are the top themes in this story?", param=QueryParam(mode="global")
77
+ )
78
+ )
79
+
80
+
81
+ def insert():
82
+ from time import time
83
+
84
+ with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
85
+ FAKE_TEXT = f.read()
86
+
87
+ remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
88
+ remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
89
+ remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
90
+ remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
91
+ remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
92
+
93
+ rag = GraphRAG(
94
+ working_dir=WORKING_DIR,
95
+ enable_llm_cache=True,
96
+ best_model_func=llm_model_if_cache,
97
+ cheap_model_func=llm_model_if_cache,
98
+ embedding_func=ollama_embedding,
99
+ )
100
+ start = time()
101
+ rag.insert(FAKE_TEXT)
102
+ print("indexing time:", time() - start)
103
+ # rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
104
+ # rag.insert(FAKE_TEXT[half_len:])
105
+
106
+ # We're using Ollama to generate embeddings for the BGE model
107
+ @wrap_embedding_func_with_attrs(
108
+ embedding_dim= EMBEDDING_MODEL_DIM,
109
+ max_token_size= EMBEDDING_MODEL_MAX_TOKENS,
110
+ )
111
+
112
+ async def ollama_embedding(texts :list[str]) -> np.ndarray:
113
+ embed_text = []
114
+ for text in texts:
115
+ data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
116
+ embed_text.append(data["embedding"])
117
+
118
+ return embed_text
119
+
120
if __name__ == "__main__":
    # Build the index first, then query it.
    insert()
    query()
nano-graphrag/examples/using_local_embedding_model.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ sys.path.append("..")
4
+ import logging
5
+ import numpy as np
6
+ from nano_graphrag import GraphRAG, QueryParam
7
+ from nano_graphrag._utils import wrap_embedding_func_with_attrs
8
+ from sentence_transformers import SentenceTransformer
9
+
10
+ logging.basicConfig(level=logging.WARNING)
11
+ logging.getLogger("nano-graphrag").setLevel(logging.INFO)
12
+
13
+ WORKING_DIR = "./nano_graphrag_cache_local_embedding_TEST"
14
+
15
+ EMBED_MODEL = SentenceTransformer(
16
+ "sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu"
17
+ )
18
+
19
+
20
+ # We're using Sentence Transformers to generate embeddings for the BGE model
21
@wrap_embedding_func_with_attrs(
    embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
    max_token_size=EMBED_MODEL.max_seq_length,
)
async def local_embedding(texts: list[str]) -> np.ndarray:
    """Embed *texts* with the local SentenceTransformer (normalized vectors)."""
    return EMBED_MODEL.encode(texts, normalize_embeddings=True)
27
+
28
+
29
# Demo script body: build a GraphRAG over the mock corpus using only the
# local embedding model (default OpenAI models answer the query).
rag = GraphRAG(
    working_dir=WORKING_DIR,
    embedding_func=local_embedding,
)

with open("../tests/mock_data.txt", encoding="utf-8-sig") as f:
    FAKE_TEXT = f.read()

# Indexing is commented out: this run assumes the working dir already
# contains an index from a previous execution.
# rag.insert(FAKE_TEXT)
print(rag.query("What the main theme of this story?", param=QueryParam(mode="local")))
nano-graphrag/examples/using_milvus_as_vectorDB.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ import numpy as np
4
+ from nano_graphrag import GraphRAG, QueryParam
5
+ from nano_graphrag._utils import logger
6
+ from nano_graphrag.base import BaseVectorStorage
7
+ from dataclasses import dataclass
8
+
9
+
10
@dataclass
class MilvusLiteStorge(BaseVectorStorage):
    """Vector storage backed by a local Milvus Lite database file.

    One Milvus collection per namespace, stored in ``milvus_lite.db`` under
    the GraphRAG working dir. (Class name typo "Storge" kept for
    compatibility with existing callers.)
    """

    @staticmethod
    def create_collection_if_not_exist(client, collection_name: str, **kwargs):
        """Create *collection_name* with string ids unless it already exists."""
        if client.has_collection(collection_name):
            return
        # TODO add constants for ID max length to 32
        client.create_collection(
            collection_name, max_length=32, id_type="string", **kwargs
        )

    def __post_init__(self):
        # Imported lazily so the example only needs pymilvus when used.
        from pymilvus import MilvusClient

        self._client_file_name = os.path.join(
            self.global_config["working_dir"], "milvus_lite.db"
        )
        self._client = MilvusClient(self._client_file_name)
        self._max_batch_size = self.global_config["embedding_batch_num"]
        MilvusLiteStorge.create_collection_if_not_exist(
            self._client,
            self.namespace,
            dimension=self.embedding_func.embedding_dim,
        )

    async def upsert(self, data: dict[str, dict]):
        """Embed and upsert ``{id: {"content": ..., ...}}`` entries.

        Returns the raw Milvus upsert result.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        # Keep only the configured meta fields alongside each string id.
        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        # Embed contents in bounded batches, concurrently.
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["vector"] = embeddings[i]
        results = self._client.upsert(collection_name=self.namespace, data=list_data)
        return results

    async def query(self, query, top_k=5):
        """Cosine-search the namespace; returns meta dicts with id/distance."""
        embedding = await self.embedding_func([query])
        # radius=0.2 filters out low-similarity matches server-side.
        results = self._client.search(
            collection_name=self.namespace,
            data=embedding,
            limit=top_k,
            output_fields=list(self.meta_fields),
            search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
        )
        return [
            {**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
            for dp in results[0]
        ]
+ ]
72
+
73
+
74
def insert():
    """Index placeholder documents using the Milvus Lite vector storage."""
    data = ["YOUR TEXT DATA HERE", "YOUR TEXT DATA HERE"]
    rag = GraphRAG(
        working_dir="./nano_graphrag_cache_milvus_TEST",
        enable_llm_cache=True,
        vector_db_storage_cls=MilvusLiteStorge,
    )
    rag.insert(data)
+ rag.insert(data)
82
+
83
+
84
def query():
    """Run a local-mode query against the Milvus-backed index."""
    rag = GraphRAG(
        working_dir="./nano_graphrag_cache_milvus_TEST",
        enable_llm_cache=True,
        vector_db_storage_cls=MilvusLiteStorge,
    )
    print(rag.query("YOUR QUERY HERE", param=QueryParam(mode="local")))
+ print(rag.query("YOUR QUERY HERE", param=QueryParam(mode="local")))
91
+
92
+
93
+ insert()
94
+ query()
nano-graphrag/examples/using_ollama_as_llm.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import ollama
4
+ from nano_graphrag import GraphRAG, QueryParam
5
+ from nano_graphrag import GraphRAG, QueryParam
6
+ from nano_graphrag.base import BaseKVStorage
7
+ from nano_graphrag._utils import compute_args_hash
8
+
9
+ logging.basicConfig(level=logging.WARNING)
10
+ logging.getLogger("nano-graphrag").setLevel(logging.INFO)
11
+
12
+ # !!! qwen2-7B maybe produce unparsable results and cause the extraction of graph to fail.
13
+ MODEL = "qwen2"
14
+
15
+
16
async def ollama_model_if_cache(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """Chat with a local Ollama model (MODEL) with read-through caching.

    When GraphRAG passes a ``hashing_kv`` KV store in kwargs, a hash of
    (MODEL, messages) is used to return a cached completion and to store
    new ones. ``history_messages`` is read-only here, so the mutable
    default is benign.
    """
    # remove kwargs that are not supported by ollama
    kwargs.pop("max_tokens", None)
    kwargs.pop("response_format", None)

    ollama_client = ollama.AsyncClient()
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Get the cached response if having-------------------
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(MODEL, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    # -----------------------------------------------------
    response = await ollama_client.chat(model=MODEL, messages=messages, **kwargs)

    result = response["message"]["content"]
    # Cache the response if having-------------------
    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": result, "model": MODEL}})
    # -----------------------------------------------------
    return result
46
+
47
+
48
def remove_if_exist(file):
    """Delete *file* if it is present; silently do nothing when missing."""
    # EAFP: attempt the removal and ignore a missing path, avoiding the
    # exists()/remove() check-then-act race of the original form.
    try:
        os.remove(file)
    except FileNotFoundError:
        pass
51
+
52
+
53
+ WORKING_DIR = "./nano_graphrag_cache_ollama_TEST"
54
+
55
+
56
def query():
    """Run one 'global' mode query using the local Ollama model."""
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        best_model_func=ollama_model_if_cache,
        cheap_model_func=ollama_model_if_cache,
    )
    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="global")
        )
    )
67
+
68
+
69
def insert():
    """Rebuild the index from scratch over the mock corpus and time it."""
    from time import time

    with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
        FAKE_TEXT = f.read()

    # Drop previous artifacts so this run re-indexes from a clean state
    # (the LLM response cache is intentionally kept).
    remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
    remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")

    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        best_model_func=ollama_model_if_cache,
        cheap_model_func=ollama_model_if_cache,
    )
    start = time()
    rag.insert(FAKE_TEXT)
    print("indexing time:", time() - start)
    # rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
    # rag.insert(FAKE_TEXT[half_len:])
92
+
93
+
94
if __name__ == "__main__":
    # Build the index first, then query it.
    insert()
    query()
nano-graphrag/examples/using_ollama_as_llm_and_embedding.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append("..")
5
+ import logging
6
+ import ollama
7
+ import numpy as np
8
+ from nano_graphrag import GraphRAG, QueryParam
9
+ from nano_graphrag.base import BaseKVStorage
10
+ from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
11
+
12
+ logging.basicConfig(level=logging.WARNING)
13
+ logging.getLogger("nano-graphrag").setLevel(logging.INFO)
14
+
15
+ # Assumed llm model settings
16
+ MODEL = "your_model_name"
17
+
18
+ # Assumed embedding model settings
19
+ EMBEDDING_MODEL = "nomic-embed-text"
20
+ EMBEDDING_MODEL_DIM = 768
21
+ EMBEDDING_MODEL_MAX_TOKENS = 8192
22
+
23
+
24
async def ollama_model_if_cache(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """Chat with a local Ollama model (MODEL) with read-through caching.

    When GraphRAG passes a ``hashing_kv`` KV store in kwargs, a hash of
    (MODEL, messages) is used to return a cached completion and to store
    new ones. ``history_messages`` is read-only here, so the mutable
    default is benign.
    """
    # remove kwargs that are not supported by ollama
    kwargs.pop("max_tokens", None)
    kwargs.pop("response_format", None)

    ollama_client = ollama.AsyncClient()
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Get the cached response if having-------------------
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(MODEL, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    # -----------------------------------------------------
    response = await ollama_client.chat(model=MODEL, messages=messages, **kwargs)

    result = response["message"]["content"]
    # Cache the response if having-------------------
    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": result, "model": MODEL}})
    # -----------------------------------------------------
    return result
+ return result
54
+
55
+
56
+ def remove_if_exist(file):
57
+ if os.path.exists(file):
58
+ os.remove(file)
59
+
60
+
61
+ WORKING_DIR = "./nano_graphrag_cache_ollama_TEST"
62
+
63
+
64
def query():
    """Run one 'global' mode query using Ollama for both LLM and embeddings."""
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        best_model_func=ollama_model_if_cache,
        cheap_model_func=ollama_model_if_cache,
        embedding_func=ollama_embedding,
    )
    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="global")
        )
    )
+
77
+
78
def insert():
    """Rebuild the index from scratch over the mock corpus and time it."""
    from time import time

    with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
        FAKE_TEXT = f.read()

    # Drop previous artifacts so this run re-indexes from a clean state
    # (the LLM response cache is intentionally kept).
    remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
    remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")

    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        best_model_func=ollama_model_if_cache,
        cheap_model_func=ollama_model_if_cache,
        embedding_func=ollama_embedding,
    )
    start = time()
    rag.insert(FAKE_TEXT)
    print("indexing time:", time() - start)
    # rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
    # rag.insert(FAKE_TEXT[half_len:])
+ # rag.insert(FAKE_TEXT[half_len:])
102
+
103
+
104
+ # We're using Ollama to generate embeddings for the BGE model
105
@wrap_embedding_func_with_attrs(
    embedding_dim=EMBEDDING_MODEL_DIM,
    max_token_size=EMBEDDING_MODEL_MAX_TOKENS,
)
async def ollama_embedding(texts: list[str]) -> np.ndarray:
    """Embed each text with the local Ollama embedding model.

    Args:
        texts: batch of strings to embed.

    Returns:
        Float array of shape (len(texts), EMBEDDING_MODEL_DIM).
    """
    embed_text = []
    for text in texts:
        # Ollama embeds one prompt per call; loop over the batch.
        data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
        embed_text.append(data["embedding"])

    # Fix: return an ndarray as annotated — the original returned a plain
    # list, which breaks callers that np.concatenate/index the batches.
    return np.array(embed_text)
116
+
117
+
118
if __name__ == "__main__":
    # Build the index first, then query it.
    insert()
    query()
nano-graphrag/examples/using_qdrant_as_vectorDB.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ import uuid
4
+ import numpy as np
5
+ from nano_graphrag import GraphRAG, QueryParam
6
+ from nano_graphrag._utils import logger
7
+ from nano_graphrag.base import BaseVectorStorage
8
+ from dataclasses import dataclass
9
+
10
+ try:
11
+ from qdrant_client import QdrantClient
12
+ from qdrant_client.models import VectorParams, Distance, PointStruct, SearchParams
13
+ except ImportError as original_error:
14
+ raise ImportError(
15
+ "Qdrant client is not installed. Install it using: pip install qdrant-client\n"
16
+ ) from original_error
17
+
18
+
19
@dataclass
class QdrantStorage(BaseVectorStorage):
    """Vector storage backed by a local file-based Qdrant collection.

    One cosine-distance collection per namespace, stored under
    ``<working_dir>/qdrant_storage``.
    """

    def __post_init__(self):

        # Use a local file-based Qdrant storage
        # Useful for prototyping and CI.
        # For production, refer to:
        # https://qdrant.tech/documentation/guides/installation/
        self._client_file_path = os.path.join(
            self.global_config["working_dir"], "qdrant_storage"
        )

        self._client = QdrantClient(path=self._client_file_path)

        self._max_batch_size = self.global_config["embedding_batch_num"]

        if not self._client.collection_exists(collection_name=self.namespace):
            self._client.create_collection(
                collection_name=self.namespace,
                vectors_config=VectorParams(
                    size=self.embedding_func.embedding_dim, distance=Distance.COSINE
                ),
            )

    async def upsert(self, data: dict[str, dict]):
        """Embed and upsert ``{id: {"content": ..., ...}}`` entries.

        NOTE(review): point ids are fresh ``uuid4`` values rather than a
        hash of the key, so re-upserting the same key adds a duplicate
        point instead of replacing it — confirm intended.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")

        # Keep only the configured meta fields alongside each string id.
        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]

        # Embed contents in bounded batches, concurrently.
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)

        points = [
            PointStruct(
                id=uuid.uuid4().hex,
                vector=embeddings[i].tolist(),
                payload=data,
            )
            for i, data in enumerate(list_data)
        ]

        results = self._client.upsert(collection_name=self.namespace, points=points)
        return results

    async def query(self, query, top_k=5):
        """Return up to ``top_k`` payload dicts, each with a ``score`` field."""
        embedding = await self.embedding_func([query])

        results = self._client.query_points(
            collection_name=self.namespace,
            query=embedding[0].tolist(),
            limit=top_k,
        ).points

        return [
            {**result.payload, "score": result.score}
            for result in results
        ]
+ ]
90
+
91
+
92
def insert():
    """Index placeholder documents using the Qdrant vector storage."""
    data = ["YOUR TEXT DATA HERE", "YOUR TEXT DATA HERE"]
    rag = GraphRAG(
        working_dir="./nano_graphrag_cache_qdrant_TEST",
        enable_llm_cache=True,
        vector_db_storage_cls=QdrantStorage,
    )
    rag.insert(data)
100
+
101
+
102
def query():
    """Run a local-mode query against the Qdrant-backed index."""
    rag = GraphRAG(
        working_dir="./nano_graphrag_cache_qdrant_TEST",
        enable_llm_cache=True,
        vector_db_storage_cls=QdrantStorage,
    )
    print(rag.query("YOUR QUERY HERE", param=QueryParam(mode="local")))
109
+
110
+
111
if __name__ == "__main__":
    # Build the index first, then query it.
    insert()
    query()
nano-graphrag/nano_graphrag/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .graphrag import GraphRAG, QueryParam
2
+
3
+ __version__ = "0.0.8.2"
4
+ __author__ = "Jianbai Ye"
5
+ __url__ = "https://github.com/gusye1234/nano-graphrag"
6
+
7
+ # dp stands for data pack
nano-graphrag/nano_graphrag/_llm.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ from typing import Optional, List, Any, Callable
4
+
5
+ import aioboto3
6
+ from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
7
+
8
+ from tenacity import (
9
+ retry,
10
+ stop_after_attempt,
11
+ wait_exponential,
12
+ retry_if_exception_type,
13
+ )
14
+ import os
15
+
16
+ from ._utils import compute_args_hash, wrap_embedding_func_with_attrs
17
+ from .base import BaseKVStorage
18
+
19
+ global_openai_async_client = None
20
+ global_azure_openai_async_client = None
21
+ global_amazon_bedrock_async_client = None
22
+
23
+
24
def get_openai_async_client_instance():
    """Return the module-wide AsyncOpenAI client, creating it on first use.

    Lazy singleton so importing this module never requires OPENAI_API_KEY;
    the client is only constructed when an OpenAI call is actually made.
    """
    global global_openai_async_client
    if global_openai_async_client is None:
        global_openai_async_client = AsyncOpenAI()
    return global_openai_async_client
29
+
30
+
31
def get_azure_openai_async_client_instance():
    """Return the module-wide AsyncAzureOpenAI client, creating it on first use.

    Lazy singleton: Azure credentials are only needed when this is called.
    """
    global global_azure_openai_async_client
    if global_azure_openai_async_client is None:
        global_azure_openai_async_client = AsyncAzureOpenAI()
    return global_azure_openai_async_client
36
+
37
+
38
def get_amazon_bedrock_async_client_instance():
    """Return the module-wide aioboto3 Session, creating it on first use.

    Note this returns a Session, not a service client; callers open a
    'bedrock-runtime' client from it per request.
    """
    global global_amazon_bedrock_async_client
    if global_amazon_bedrock_async_client is None:
        global_amazon_bedrock_async_client = aioboto3.Session()
    return global_amazon_bedrock_async_client
43
+
44
+
45
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_complete_if_cache(
    model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """Chat-complete via OpenAI with optional read-through caching.

    If kwargs contains a ``hashing_kv`` KV storage, a hash of
    (model, messages) keys a cache lookup before calling the API and a
    cache write (plus index_done_callback flush) after. Retries up to 5
    times with exponential backoff on rate-limit/connection errors.
    ``history_messages`` is only read, so the mutable default is benign.
    """
    openai_async_client = get_openai_async_client_instance()
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )

    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": model}}
        )
        # Flush the cache immediately so a crash doesn't lose the entry.
        await hashing_kv.index_done_callback()
    return response.choices[0].message.content
76
+
77
+
78
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def amazon_bedrock_complete_if_cache(
    model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """Chat-complete via Amazon Bedrock Converse with optional caching.

    Same caching contract as openai_complete_if_cache: a ``hashing_kv``
    in kwargs enables a (model, messages)-hash cache lookup/write.
    Uses temperature 0 and ``max_tokens`` (default 4096) as maxTokens.
    NOTE(review): retries trigger on the OpenAI SDK's RateLimitError /
    APIConnectionError, which boto3 does not raise — Bedrock throttling
    is likely not retried here; confirm.
    """
    amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance()
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    messages.extend(history_messages)
    # Bedrock Converse uses content blocks, not plain strings.
    messages.append({"role": "user", "content": [{"text": prompt}]})
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    inference_config = {
        "temperature": 0,
        "maxTokens": 4096 if "max_tokens" not in kwargs else kwargs["max_tokens"],
    }

    async with amazon_bedrock_async_client.client(
        "bedrock-runtime",
        region_name=os.getenv("AWS_REGION", "us-east-1")
    ) as bedrock_runtime:
        # Converse takes the system prompt as a separate top-level field.
        if system_prompt:
            response = await bedrock_runtime.converse(
                modelId=model, messages=messages, inferenceConfig=inference_config,
                system=[{"text": system_prompt}]
            )
        else:
            response = await bedrock_runtime.converse(
                modelId=model, messages=messages, inferenceConfig=inference_config,
            )

    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response["output"]["message"]["content"][0]["text"], "model": model}}
        )
        # Flush the cache immediately so a crash doesn't lose the entry.
        await hashing_kv.index_done_callback()
    return response["output"]["message"]["content"][0]["text"]
122
+
123
+
124
def create_amazon_bedrock_complete_function(model_id: str) -> Callable:
    """
    Factory that builds an async completion function bound to one Bedrock model.

    Args:
        model_id (str): Amazon Bedrock model identifier (e.g., "us.anthropic.claude-3-sonnet-20240229-v1:0")

    Returns:
        Callable: an async completion function that delegates to
        ``amazon_bedrock_complete_if_cache`` with *model_id* pre-bound.
    """
    async def bedrock_complete(
        prompt: str,
        system_prompt: Optional[str] = None,
        history_messages: List[Any] = [],
        **kwargs
    ) -> str:
        return await amazon_bedrock_complete_if_cache(
            model_id, prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )

    # Name the closure after the model so logs and tracebacks identify it.
    bedrock_complete.__name__ = f"{model_id}_complete"
    return bedrock_complete
152
+
153
+
154
async def gpt_4o_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """Complete *prompt* with the "gpt-4o" model via the cache-aware helper."""
    return await openai_complete_if_cache(
        "gpt-4o", prompt, system_prompt=system_prompt,
        history_messages=history_messages, **kwargs,
    )
164
+
165
+
166
async def gpt_4o_mini_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """Complete *prompt* with the "gpt-4o-mini" model via the cache-aware helper."""
    return await openai_complete_if_cache(
        "gpt-4o-mini", prompt, system_prompt=system_prompt,
        history_messages=history_messages, **kwargs,
    )
176
+
177
+
178
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def amazon_bedrock_embedding(texts: list[str]) -> np.ndarray:
    """Embed *texts* with Amazon Titan ("amazon.titan-embed-text-v2:0", 1024-d).

    Each text is embedded with a separate invoke_model request; returns an
    array of shape (len(texts), 1024).
    """
    amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance()

    async with amazon_bedrock_async_client.client(
        "bedrock-runtime",
        region_name=os.getenv("AWS_REGION", "us-east-1")
    ) as bedrock_runtime:
        embeddings = []
        # NOTE(review): requests are sequential, one per text — batching or
        # asyncio.gather could speed this up; confirm API limits first.
        for text in texts:
            body = json.dumps(
                {
                    "inputText": text,
                    "dimensions": 1024,
                }
            )
            response = await bedrock_runtime.invoke_model(
                modelId="amazon.titan-embed-text-v2:0", body=body,
            )
            response_body = await response.get("body").read()
            embeddings.append(json.loads(response_body))
    return np.array([dp["embedding"] for dp in embeddings])
205
+
206
+
207
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_embedding(texts: list[str]) -> np.ndarray:
    """Embed *texts* with OpenAI's "text-embedding-3-small" model (1536-d)."""
    client = get_openai_async_client_instance()
    resp = await client.embeddings.create(
        model="text-embedding-3-small", input=texts, encoding_format="float"
    )
    vectors = [item.embedding for item in resp.data]
    return np.array(vectors)
219
+
220
+
221
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def azure_openai_complete_if_cache(
    deployment_name, prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Chat-complete *prompt* against an Azure OpenAI deployment, with caching.

    Args:
        deployment_name: Azure deployment name passed as the ``model`` field.
        prompt: The new user message content.
        system_prompt: Optional system message prepended to the conversation.
        history_messages: Prior chat messages inserted before the new prompt.
        **kwargs: Forwarded to ``chat.completions.create``; the special key
            ``hashing_kv`` (a ``BaseKVStorage``) enables response caching.

    Returns:
        The assistant message content as a string.
    """
    azure_openai_client = get_azure_openai_async_client_instance()
    # Pop the cache handle so it is not forwarded to the Azure client.
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # NOTE: default changed from a mutable `[]` to None (same behavior for
    # callers); None means "no history".
    if history_messages:
        messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        # Cache key covers the deployment and the full message list.
        args_hash = compute_args_hash(deployment_name, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    response = await azure_openai_client.chat.completions.create(
        model=deployment_name, messages=messages, **kwargs
    )

    if hashing_kv is not None:
        await hashing_kv.upsert(
            {
                args_hash: {
                    "return": response.choices[0].message.content,
                    "model": deployment_name,
                }
            }
        )
        await hashing_kv.index_done_callback()
    return response.choices[0].message.content
257
+
258
+
259
async def azure_gpt_4o_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """Complete *prompt* with the "gpt-4o" Azure deployment via the cache-aware helper."""
    return await azure_openai_complete_if_cache(
        "gpt-4o", prompt, system_prompt=system_prompt,
        history_messages=history_messages, **kwargs,
    )
269
+
270
+
271
async def azure_gpt_4o_mini_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """Complete *prompt* with the "gpt-4o-mini" Azure deployment via the cache-aware helper."""
    return await azure_openai_complete_if_cache(
        "gpt-4o-mini", prompt, system_prompt=system_prompt,
        history_messages=history_messages, **kwargs,
    )
281
+
282
+
283
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
    """Embed *texts* through Azure OpenAI's "text-embedding-3-small" deployment (1536-d)."""
    client = get_azure_openai_async_client_instance()
    resp = await client.embeddings.create(
        model="text-embedding-3-small", input=texts, encoding_format="float"
    )
    return np.array([item.embedding for item in resp.data])
nano-graphrag/nano_graphrag/_op.py ADDED
@@ -0,0 +1,1140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import json
3
+ import asyncio
4
+ from typing import Union
5
+ from collections import Counter, defaultdict
6
+ from ._splitter import SeparatorSplitter
7
+ from ._utils import (
8
+ logger,
9
+ clean_str,
10
+ compute_mdhash_id,
11
+ is_float_regex,
12
+ list_of_list_to_csv,
13
+ pack_user_ass_to_openai_messages,
14
+ split_string_by_multi_markers,
15
+ truncate_list_by_token_size,
16
+
17
+ TokenizerWrapper
18
+ )
19
+ from .base import (
20
+ BaseGraphStorage,
21
+ BaseKVStorage,
22
+ BaseVectorStorage,
23
+ SingleCommunitySchema,
24
+ CommunitySchema,
25
+ TextChunkSchema,
26
+ QueryParam,
27
+ )
28
+ from .prompt import GRAPH_FIELD_SEP, PROMPTS
29
+
30
+
31
def chunking_by_token_size(
    tokens_list: list[list[int]],
    doc_keys,
    tokenizer_wrapper: TokenizerWrapper,
    overlap_token_size=128,
    max_token_size=1024,
):
    """Split each token sequence into overlapping windows.

    Windows hold at most ``max_token_size`` tokens and advance by
    ``max_token_size - overlap_token_size``.  Returns one dict per chunk with
    its token count, decoded (stripped) text, per-document chunk index, and
    the originating doc key.
    """
    results = []
    step = max_token_size - overlap_token_size
    for doc_index, tokens in enumerate(tokens_list):
        starts = range(0, len(tokens), step)
        windows = [tokens[s : s + max_token_size] for s in starts]
        sizes = [min(max_token_size, len(tokens) - s) for s in starts]

        for order, text in enumerate(tokenizer_wrapper.decode_batch(windows)):
            results.append(
                {
                    "tokens": sizes[order],
                    "content": text.strip(),
                    "chunk_order_index": order,
                    "full_doc_id": doc_keys[doc_index],
                }
            )
    return results
59
+
60
+
61
def chunking_by_seperators(
    tokens_list: list[list[int]],
    doc_keys,
    tokenizer_wrapper: TokenizerWrapper,
    overlap_token_size=128,
    max_token_size=1024,
):
    """Chunk token sequences with SeparatorSplitter on the default separators.

    Separators come from PROMPTS["default_text_separator"]; chunk size and
    overlap are bounded by *max_token_size* / *overlap_token_size*.  Output
    dicts match ``chunking_by_token_size``.
    """
    from .prompt import PROMPTS
    # Encode separators via the wrapper directly rather than reaching into the
    # underlying tokenizer.
    separators = [tokenizer_wrapper.encode(s) for s in PROMPTS["default_text_separator"]]
    splitter = SeparatorSplitter(
        separators=separators,
        chunk_size=max_token_size,
        chunk_overlap=overlap_token_size,
    )
    results = []
    for index, tokens in enumerate(tokens_list):
        chunk_tokens = splitter.split_tokens(tokens)
        lengths = [len(c) for c in chunk_tokens]

        decoded_chunks = tokenizer_wrapper.decode_batch(chunk_tokens)
        for i, chunk in enumerate(decoded_chunks):
            results.append(
                {
                    "tokens": lengths[i],
                    "content": chunk.strip(),
                    "chunk_order_index": i,
                    "full_doc_id": doc_keys[index],
                }
            )
    return results
92
+
93
+
94
def get_chunks(new_docs, chunk_func=chunking_by_token_size, tokenizer_wrapper: TokenizerWrapper = None, **chunk_func_params):
    """Tokenize every document in *new_docs* and chunk it with *chunk_func*.

    Only ``overlap_token_size`` and ``max_token_size`` are read from
    *chunk_func_params* (defaulting to 128 / 1024).  Returns a mapping from
    content-hash chunk ids ("chunk-...") to chunk dicts.
    """
    doc_keys = list(new_docs.keys())
    contents = [new_docs[key]["content"] for key in doc_keys]

    token_lists = [tokenizer_wrapper.encode(text) for text in contents]
    chunks = chunk_func(
        token_lists,
        doc_keys=doc_keys,
        tokenizer_wrapper=tokenizer_wrapper,
        overlap_token_size=chunk_func_params.get("overlap_token_size", 128),
        max_token_size=chunk_func_params.get("max_token_size", 1024),
    )
    return {
        compute_mdhash_id(chunk["content"], prefix="chunk-"): chunk
        for chunk in chunks
    }
109
+
110
+
111
async def _handle_entity_relation_summary(
    entity_or_relation_name: str,
    description: str,
    global_config: dict,
    tokenizer_wrapper: TokenizerWrapper,
) -> str:
    """Summarize *description* with the cheap LLM when it is too long.

    Descriptions shorter than ``entity_summary_to_max_tokens`` tokens are
    returned unchanged; longer ones are truncated to the cheap model's input
    limit and summarized.
    """
    use_llm_func: callable = global_config["cheap_model_func"]
    llm_max_tokens = global_config["cheap_model_max_token_size"]
    summary_max_tokens = global_config["entity_summary_to_max_tokens"]

    tokens = tokenizer_wrapper.encode(description)
    if len(tokens) < summary_max_tokens:
        # already short enough — skip the LLM call
        return description
    prompt_template = PROMPTS["summarize_entity_descriptions"]
    # truncate to fit the cheap model's context window
    use_description = tokenizer_wrapper.decode(tokens[:llm_max_tokens])
    context_base = dict(
        entity_name=entity_or_relation_name,
        description_list=use_description.split(GRAPH_FIELD_SEP),
    )
    use_prompt = prompt_template.format(**context_base)
    logger.debug(f"Trigger summary: {entity_or_relation_name}")
    summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
    return summary
136
+
137
+
138
async def _handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    """Turn one parsed record into an entity dict, or None when it is not a
    well-formed ``"entity"`` record (name, type, description required)."""
    is_entity_record = (
        len(record_attributes) >= 4 and record_attributes[0] == '"entity"'
    )
    if not is_entity_record:
        return None
    # add this record as a node in the G
    entity_name = clean_str(record_attributes[1].upper())
    if not entity_name.strip():
        # a name that is empty after cleaning cannot become a graph node
        return None
    return dict(
        entity_name=entity_name,
        entity_type=clean_str(record_attributes[2].upper()),
        description=clean_str(record_attributes[3]),
        source_id=chunk_key,
    )
157
+
158
+
159
async def _handle_single_relationship_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    """Turn one parsed record into a relationship dict, or None when it is not
    a well-formed ``"relationship"`` record (src, tgt, description, weight)."""
    if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
        return None
    # add this record as edge
    source = clean_str(record_attributes[1].upper())
    target = clean_str(record_attributes[2].upper())
    edge_description = clean_str(record_attributes[3])
    raw_weight = record_attributes[-1]
    # fall back to weight 1.0 when the model emitted a non-numeric weight
    weight = float(raw_weight) if is_float_regex(raw_weight) else 1.0
    return dict(
        src_id=source,
        tgt_id=target,
        weight=weight,
        description=edge_description,
        source_id=chunk_key,
    )
180
+
181
+
182
async def _merge_nodes_then_upsert(
    entity_name: str,
    nodes_data: list[dict],
    knwoledge_graph_inst: BaseGraphStorage,
    global_config: dict,
    tokenizer_wrapper,
):
    """Merge freshly-extracted node records for *entity_name* with any existing
    graph node, then upsert the combined node.

    Merge rules: the most frequent entity type wins; descriptions are
    de-duplicated, sorted, joined with GRAPH_FIELD_SEP and possibly summarized
    by the cheap LLM; source chunk ids are unioned.  Returns the upserted node
    data with "entity_name" added.
    """
    already_entitiy_types = []
    already_source_ids = []
    already_description = []

    already_node = await knwoledge_graph_inst.get_node(entity_name)
    if already_node is not None:
        # fold the existing node's fields into the merge inputs
        already_entitiy_types.append(already_node["entity_type"])
        already_source_ids.extend(
            split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
        )
        already_description.append(already_node["description"])

    # majority vote over entity types (most common wins)
    entity_type = sorted(
        Counter(
            [dp["entity_type"] for dp in nodes_data] + already_entitiy_types
        ).items(),
        key=lambda x: x[1],
        reverse=True,
    )[0][0]
    description = GRAPH_FIELD_SEP.join(
        sorted(set([dp["description"] for dp in nodes_data] + already_description))
    )
    source_id = GRAPH_FIELD_SEP.join(
        set([dp["source_id"] for dp in nodes_data] + already_source_ids)
    )
    # compress an over-long combined description with the cheap model
    description = await _handle_entity_relation_summary(
        entity_name, description, global_config, tokenizer_wrapper
    )
    node_data = dict(
        entity_type=entity_type,
        description=description,
        source_id=source_id,
    )
    await knwoledge_graph_inst.upsert_node(
        entity_name,
        node_data=node_data,
    )
    node_data["entity_name"] = entity_name
    return node_data
228
+
229
+
230
async def _merge_edges_then_upsert(
    src_id: str,
    tgt_id: str,
    edges_data: list[dict],
    knwoledge_graph_inst: BaseGraphStorage,
    global_config: dict,
    tokenizer_wrapper,
):
    """Merge freshly-extracted edge records for (src_id, tgt_id) with any
    existing edge, then upsert the combined edge.

    Merge rules: weights are summed; descriptions are de-duplicated, sorted,
    joined (and possibly summarized); source ids are unioned; `order` takes
    the minimum.  Endpoint nodes missing from the graph are created with an
    '"UNKNOWN"' entity type.
    """
    already_weights = []
    already_source_ids = []
    already_description = []
    already_order = []
    if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
        already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
        already_weights.append(already_edge["weight"])
        already_source_ids.extend(
            split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
        )
        already_description.append(already_edge["description"])
        already_order.append(already_edge.get("order", 1))

    # [numberchiffre]: `Relationship.order` is only returned from DSPy's predictions
    order = min([dp.get("order", 1) for dp in edges_data] + already_order)
    weight = sum([dp["weight"] for dp in edges_data] + already_weights)
    description = GRAPH_FIELD_SEP.join(
        sorted(set([dp["description"] for dp in edges_data] + already_description))
    )
    source_id = GRAPH_FIELD_SEP.join(
        set([dp["source_id"] for dp in edges_data] + already_source_ids)
    )
    # make sure both endpoints exist before inserting the edge
    for need_insert_id in [src_id, tgt_id]:
        if not (await knwoledge_graph_inst.has_node(need_insert_id)):
            await knwoledge_graph_inst.upsert_node(
                need_insert_id,
                node_data={
                    "source_id": source_id,
                    "description": description,
                    "entity_type": '"UNKNOWN"',
                },
            )
    description = await _handle_entity_relation_summary(
        (src_id, tgt_id), description, global_config, tokenizer_wrapper
    )
    await knwoledge_graph_inst.upsert_edge(
        src_id,
        tgt_id,
        edge_data=dict(
            weight=weight, description=description, source_id=source_id, order=order
        ),
    )
280
+
281
+
282
async def extract_entities(
    chunks: dict[str, TextChunkSchema],
    knwoledge_graph_inst: BaseGraphStorage,
    entity_vdb: BaseVectorStorage,
    tokenizer_wrapper,
    global_config: dict,
    using_amazon_bedrock: bool=False,
) -> Union[BaseGraphStorage, None]:
    """Extract entities/relationships from all chunks with the best LLM and
    merge them into the knowledge graph (and the entity vector DB, if given).

    Every chunk is processed concurrently: the extraction prompt is sent, then
    up to ``entity_extract_max_gleaning`` follow-up "glean" rounds collect
    records the model missed.  Parsed records are merged across chunks and
    upserted.  Returns the graph storage, or None when nothing was extracted.
    """
    use_llm_func: callable = global_config["best_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]

    ordered_chunks = list(chunks.items())

    entity_extract_prompt = PROMPTS["entity_extraction"]
    context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
    )
    continue_prompt = PROMPTS["entiti_continue_extraction"]
    if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]

    # progress counters shared (via nonlocal) by the per-chunk tasks
    already_processed = 0
    already_entities = 0
    already_relations = 0

    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
        # Extract nodes/edges from one chunk; returns (maybe_nodes, maybe_edges).
        nonlocal already_processed, already_entities, already_relations
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
        final_result = await use_llm_func(hint_prompt)
        # Bedrock responses arrive as a list of content blocks; unwrap the text.
        if isinstance(final_result, list):
            final_result = final_result[0]["text"]

        history = pack_user_ass_to_openai_messages(hint_prompt, final_result, using_amazon_bedrock)
        for now_glean_index in range(entity_extract_max_gleaning):
            glean_result = await use_llm_func(continue_prompt, history_messages=history)

            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result, using_amazon_bedrock)
            final_result += glean_result
            if now_glean_index == entity_extract_max_gleaning - 1:
                break

            # ask the model whether more records remain; continue only on "yes"
            if_loop_result: str = await use_llm_func(
                if_loop_prompt, history_messages=history
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break

        records = split_string_by_multi_markers(
            final_result,
            [context_base["record_delimiter"], context_base["completion_delimiter"]],
        )

        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)
        for record in records:
            # each record is parenthesized; keep only the inside
            record = re.search(r"\((.*)\)", record)
            if record is None:
                continue
            record = record.group(1)
            record_attributes = split_string_by_multi_markers(
                record, [context_base["tuple_delimiter"]]
            )
            if_entities = await _handle_single_entity_extraction(
                record_attributes, chunk_key
            )
            if if_entities is not None:
                maybe_nodes[if_entities["entity_name"]].append(if_entities)
                continue

            if_relation = await _handle_single_relationship_extraction(
                record_attributes, chunk_key
            )
            if if_relation is not None:
                maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
                    if_relation
                )
        already_processed += 1
        already_entities += len(maybe_nodes)
        already_relations += len(maybe_edges)
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed}({already_processed*100//len(ordered_chunks)}%) chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return dict(maybe_nodes), dict(maybe_edges)

    # use_llm_func is wrapped in ascynio.Semaphore, limiting max_async callings
    results = await asyncio.gather(
        *[_process_single_content(c) for c in ordered_chunks]
    )
    print()  # clear the progress bar
    maybe_nodes = defaultdict(list)
    maybe_edges = defaultdict(list)
    for m_nodes, m_edges in results:
        for k, v in m_nodes.items():
            maybe_nodes[k].extend(v)
        for k, v in m_edges.items():
            # it's undirected graph
            maybe_edges[tuple(sorted(k))].extend(v)
    all_entities_data = await asyncio.gather(
        *[
            _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config, tokenizer_wrapper)
            for k, v in maybe_nodes.items()
        ]
    )
    await asyncio.gather(
        *[
            _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config, tokenizer_wrapper)
            for k, v in maybe_edges.items()
        ]
    )
    if not len(all_entities_data):
        logger.warning("Didn't extract any entities, maybe your LLM is not working")
        return None
    if entity_vdb is not None:
        # index entities by a hash id; content = name + description for recall
        data_for_vdb = {
            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                "content": dp["entity_name"] + dp["description"],
                "entity_name": dp["entity_name"],
            }
            for dp in all_entities_data
        }
        await entity_vdb.upsert(data_for_vdb)
    return knwoledge_graph_inst
415
+
416
+
417
def _pack_single_community_by_sub_communities(
    community: SingleCommunitySchema,
    max_token_size: int,
    already_reports: dict[str, CommunitySchema],
    tokenizer_wrapper: TokenizerWrapper,
) -> tuple[str, int, set, set]:
    """Describe *community* via its sub-communities' already-generated reports.

    Sub-communities that have reports are sorted by occurrence, truncated to
    *max_token_size* tokens, and rendered as a CSV table.

    Returns:
        (csv_description, token_count_of_description,
         nodes_covered_by_the_kept_sub_reports,
         edges_covered_by_the_kept_sub_reports)
    """
    all_sub_communities = [
        already_reports[k] for k in community["sub_communities"] if k in already_reports
    ]
    all_sub_communities = sorted(
        all_sub_communities, key=lambda x: x["occurrence"], reverse=True
    )

    may_trun_all_sub_communities = truncate_list_by_token_size(
        all_sub_communities,
        key=lambda x: x["report_string"],
        max_token_size=max_token_size,
        tokenizer_wrapper=tokenizer_wrapper,
    )
    sub_fields = ["id", "report", "rating", "importance"]
    sub_communities_describe = list_of_list_to_csv(
        [sub_fields]
        + [
            [
                i,
                c["report_string"],
                c["report_json"].get("rating", -1),
                c["occurrence"],
            ]
            for i, c in enumerate(may_trun_all_sub_communities)
        ]
    )
    # collect the nodes/edges the kept sub-reports already describe so the
    # caller can avoid listing them again
    already_nodes = []
    already_edges = []
    for c in may_trun_all_sub_communities:
        already_nodes.extend(c["nodes"])
        already_edges.extend([tuple(e) for e in c["edges"]])

    return (
        sub_communities_describe,
        len(tokenizer_wrapper.encode(sub_communities_describe)),
        set(already_nodes),
        set(already_edges),
    )
462
+
463
+
464
async def _pack_single_community_describe(
    knwoledge_graph_inst: BaseGraphStorage,
    community: SingleCommunitySchema,
    tokenizer_wrapper: "TokenizerWrapper",
    max_token_size: int = 12000,
    already_reports: dict[str, CommunitySchema] = {},
    global_config: dict = {},
) -> str:
    """Render one community as a token-budgeted text block with three CSV
    sections: sub-community reports, entities, and relationships.

    Large communities fall back to their sub-communities' existing reports;
    nodes/edges those reports already cover are filtered out and the remaining
    rows are truncated to fit *max_token_size* tokens.
    """
    # 1. Gather the community's raw node/edge data
    nodes_in_order = sorted(community["nodes"])
    edges_in_order = sorted(community["edges"], key=lambda x: x[0] + x[1])

    nodes_data = await asyncio.gather(
        *[knwoledge_graph_inst.get_node(n) for n in nodes_in_order]
    )
    edges_data = await asyncio.gather(
        *[knwoledge_graph_inst.get_edge(src, tgt) for src, tgt in edges_in_order]
    )

    # 2. Define the output template and measure its fixed token overhead
    final_template = """-----Reports-----
```csv
{reports}
```
-----Entities-----
```csv
{entities}
```
-----Relationships-----
```csv
{relationships}
```"""
    base_template_tokens = len(tokenizer_wrapper.encode(
        final_template.format(reports="", entities="", relationships="")
    ))
    remaining_budget = max_token_size - base_template_tokens

    # 3. Fold in sub-community reports when needed
    report_describe = ""
    contain_nodes = set()
    contain_edges = set()

    # heuristic: communities this large are unlikely to fit as raw rows
    truncated = len(nodes_in_order) > 100 or len(edges_in_order) > 100

    need_to_use_sub_communities = (
        truncated and
        community["sub_communities"] and
        already_reports
    )
    force_to_use_sub_communities = global_config["addon_params"].get(
        "force_to_use_sub_communities", False
    )

    if need_to_use_sub_communities or force_to_use_sub_communities:
        logger.debug(f"Community {community['title']} using sub-communities")
        # sub-community reports plus the nodes/edges they already cover
        result = _pack_single_community_by_sub_communities(
            community, remaining_budget, already_reports, tokenizer_wrapper
        )
        report_describe, report_size, contain_nodes, contain_edges = result
        remaining_budget = max(0, remaining_budget - report_size)

    # 4. Build node/edge rows, skipping anything a sub-report already covers
    def format_row(row: list) -> str:
        # CSV-escape a row; used as the token-measuring key during truncation
        return ','.join('"{}"'.format(str(item).replace('"', '""')) for item in row)

    node_fields = ["id", "entity", "type", "description", "degree"]
    edge_fields = ["id", "source", "target", "description", "rank"]

    # batch-fetch degrees, used below as importance ranks
    node_degrees = await knwoledge_graph_inst.node_degrees_batch(nodes_in_order)
    edge_degrees = await knwoledge_graph_inst.edge_degrees_batch(edges_in_order)

    nodes_list_data = [
        [i, name, data.get("entity_type", "UNKNOWN"),
         data.get("description", "UNKNOWN"), node_degrees[i]]
        for i, (name, data) in enumerate(zip(nodes_in_order, nodes_data))
        if name not in contain_nodes  # skip nodes already in sub-reports
    ]

    edges_list_data = [
        [i, edge[0], edge[1], data.get("description", "UNKNOWN"), edge_degrees[i]]
        for i, (edge, data) in enumerate(zip(edges_in_order, edges_data))
        if (edge[0], edge[1]) not in contain_edges  # skip edges already in sub-reports
    ]

    # highest degree/rank first, so truncation keeps the most important rows
    nodes_list_data.sort(key=lambda x: x[-1], reverse=True)
    edges_list_data.sort(key=lambda x: x[-1], reverse=True)

    # 5. Split the remaining token budget between nodes and edges
    # header rows cost tokens too
    header_tokens = len(tokenizer_wrapper.encode(
        list_of_list_to_csv([node_fields]) + "\n" + list_of_list_to_csv([edge_fields])
    ))

    data_budget = max(0, remaining_budget - header_tokens)
    total_items = len(nodes_list_data) + len(edges_list_data)
    # budget is apportioned by each side's share of the row count
    node_ratio = len(nodes_list_data) / max(1, total_items)
    edge_ratio = 1 - node_ratio

    # truncate each list to its proportional share of the budget
    nodes_final = truncate_list_by_token_size(
        nodes_list_data, key=format_row,
        max_token_size=int(data_budget * node_ratio),
        tokenizer_wrapper=tokenizer_wrapper
    )
    edges_final = truncate_list_by_token_size(
        edges_list_data, key=format_row,
        max_token_size= int(data_budget * edge_ratio),
        tokenizer_wrapper=tokenizer_wrapper
    )

    # 6. Assemble the final description
    nodes_describe = list_of_list_to_csv([node_fields] + nodes_final)
    edges_describe = list_of_list_to_csv([edge_fields] + edges_final)

    final_output = final_template.format(
        reports=report_describe,
        entities=nodes_describe,
        relationships=edges_describe
    )

    return final_output
601
+
602
+
603
+ def _community_report_json_to_str(parsed_output: dict) -> str:
604
+ """refer official graphrag: index/graph/extractors/community_reports"""
605
+ title = parsed_output.get("title", "Report")
606
+ summary = parsed_output.get("summary", "")
607
+ findings = parsed_output.get("findings", [])
608
+
609
+ def finding_summary(finding: dict):
610
+ if isinstance(finding, str):
611
+ return finding
612
+ return finding.get("summary")
613
+
614
+ def finding_explanation(finding: dict):
615
+ if isinstance(finding, str):
616
+ return ""
617
+ return finding.get("explanation")
618
+
619
+ report_sections = "\n\n".join(
620
+ f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
621
+ )
622
+ return f"# {title}\n\n{summary}\n\n{report_sections}"
623
+
624
+
625
async def generate_community_report(
    community_report_kv: BaseKVStorage[CommunitySchema],
    knwoledge_graph_inst: BaseGraphStorage,
    tokenizer_wrapper: TokenizerWrapper,
    global_config: dict,
):
    """Generate an LLM report for every graph community and store them.

    Communities are processed level by level, deepest first, so a parent
    community can reuse its sub-communities' already-generated reports when
    its own description exceeds the token budget.  Results (report string,
    parsed JSON, and community metadata) are upserted into
    *community_report_kv*.
    """
    llm_extra_kwargs = global_config["special_community_report_llm_kwargs"]
    use_llm_func: callable = global_config["best_model_func"]
    use_string_json_convert_func: callable = global_config["convert_response_to_json_func"]

    communities_schema = await knwoledge_graph_inst.community_schema()
    community_keys, community_values = list(communities_schema.keys()), list(communities_schema.values())
    already_processed = 0

    prompt_template = PROMPTS["community_report"]
    # token cost of the prompt with an empty payload, used to size the payload
    prompt_overhead = len(tokenizer_wrapper.encode(prompt_template.format(input_text="")))

    async def _form_single_community_report(
        community: SingleCommunitySchema, already_reports: dict[str, CommunitySchema]
    ):
        # Describe one community, ask the LLM for a report, parse it to JSON.
        nonlocal already_processed
        describe = await _pack_single_community_describe(
            knwoledge_graph_inst,
            community,
            tokenizer_wrapper=tokenizer_wrapper,
            max_token_size=global_config["best_model_max_token_size"] - prompt_overhead -200, # extra token for chat template and prompt template
            already_reports=already_reports,
            global_config=global_config,
        )
        prompt = prompt_template.format(input_text=describe)

        response = await use_llm_func(prompt, **llm_extra_kwargs)
        data = use_string_json_convert_func(response)
        already_processed += 1
        now_ticks = PROMPTS["process_tickers"][already_processed % len(PROMPTS["process_tickers"])]
        print(f"{now_ticks} Processed {already_processed} communities\r", end="", flush=True)
        return data

    # deepest levels first, so parents can reference their children's reports
    levels = sorted(set([c["level"] for c in community_values]), reverse=True)
    logger.info(f"Generating by levels: {levels}")
    community_datas = {}
    for level in levels:
        this_level_community_keys, this_level_community_values = zip(
            *[
                (k, v)
                for k, v in zip(community_keys, community_values)
                if v["level"] == level
            ]
        )
        this_level_communities_reports = await asyncio.gather(
            *[
                _form_single_community_report(c, community_datas)
                for c in this_level_community_values
            ]
        )
        community_datas.update(
            {
                k: {
                    "report_string": _community_report_json_to_str(r),
                    "report_json": r,
                    **v,
                }
                for k, r, v in zip(
                    this_level_community_keys,
                    this_level_communities_reports,
                    this_level_community_values,
                )
            }
        )
    print()  # clear the progress bar
    await community_report_kv.upsert(community_datas)
+
699
+
700
async def _find_most_related_community_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    community_reports: BaseKVStorage[CommunitySchema],
    tokenizer_wrapper,
):
    """Pick the community reports most relevant to the given entity nodes.

    Each entity node carries a JSON-encoded "clusters" field listing the
    communities it belongs to per level. Communities (up to
    ``query_param.level``) are ranked by how many of the query entities they
    contain (primary) and by the report's rating (secondary), then truncated
    to the configured token budget.

    Returns:
        A list of community report dicts (only the top one when
        ``query_param.local_community_single_one`` is set).
    """
    related_communities = []
    for node_d in node_datas:
        if "clusters" not in node_d:
            continue
        related_communities.extend(json.loads(node_d["clusters"]))
    related_community_dup_keys = [
        str(dp["cluster"])
        for dp in related_communities
        if dp["level"] <= query_param.level
    ]
    # How many query entities reference each community.
    related_community_keys_counts = dict(Counter(related_community_dup_keys))
    _related_community_datas = await asyncio.gather(
        *[community_reports.get_by_id(k) for k in related_community_keys_counts.keys()]
    )
    related_community_datas = {
        k: v
        for k, v in zip(related_community_keys_counts.keys(), _related_community_datas)
        if v is not None
    }
    # Bug fix: sort only communities whose report was actually found.
    # Iterating all counted keys raised KeyError on
    # ``related_community_datas[k]`` whenever a report was missing from
    # storage (reports with ``get_by_id(k) is None`` were filtered out above).
    related_community_keys = sorted(
        related_community_datas.keys(),
        key=lambda k: (
            related_community_keys_counts[k],
            related_community_datas[k]["report_json"].get("rating", -1),
        ),
        reverse=True,
    )
    sorted_community_datas = [
        related_community_datas[k] for k in related_community_keys
    ]

    use_community_reports = truncate_list_by_token_size(
        sorted_community_datas,
        key=lambda x: x["report_string"],
        max_token_size=query_param.local_max_token_for_community_report,
        tokenizer_wrapper=tokenizer_wrapper,
    )
    if query_param.local_community_single_one:
        use_community_reports = use_community_reports[:1]
    return use_community_reports
746
+
747
+
748
async def _find_most_related_text_unit_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    knowledge_graph_inst: BaseGraphStorage,
    tokenizer_wrapper,
):
    """Collect the text chunks most relevant to the given entity nodes.

    Chunks are ordered by the rank of the entity that referenced them and,
    within that, by how many one-hop neighbours also reference the chunk
    ("relation_counts"), then truncated to the local text-unit token budget.

    Returns:
        A list of TextChunkSchema dicts.
    """
    text_units = [
        split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
        for dp in node_datas
    ]
    edges = await knowledge_graph_inst.get_nodes_edges_batch([dp["entity_name"] for dp in node_datas])
    all_one_hop_nodes = set()
    for this_edges in edges:
        if not this_edges:
            continue
        all_one_hop_nodes.update([e[1] for e in this_edges])
    all_one_hop_nodes = list(all_one_hop_nodes)
    all_one_hop_nodes_data = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
    # Map one-hop neighbour -> set of chunk ids it was extracted from.
    all_one_hop_text_units_lookup = {
        k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
        for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
        if v is not None
    }
    all_text_units_lookup = {}
    for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
        for c_id in this_text_units:
            if c_id in all_text_units_lookup:
                continue
            relation_counts = 0
            for e in this_edges:
                if (
                    e[1] in all_one_hop_text_units_lookup
                    and c_id in all_one_hop_text_units_lookup[e[1]]
                ):
                    relation_counts += 1
            all_text_units_lookup[c_id] = {
                "data": await text_chunks_db.get_by_id(c_id),
                "order": index,
                "relation_counts": relation_counts,
            }
    # Bug fix: lookup values are always dicts, so the previous ``v is None``
    # checks never fired; a chunk missing from storage shows up as
    # ``v["data"] is None`` and previously crashed later at
    # ``x["data"]["content"]`` inside truncate_list_by_token_size.
    if any([v["data"] is None for v in all_text_units_lookup.values()]):
        logger.warning("Text chunks are missing, maybe the storage is damaged")
    all_text_units = [
        {"id": k, **v} for k, v in all_text_units_lookup.items() if v["data"] is not None
    ]
    all_text_units = sorted(
        all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
    )
    all_text_units = truncate_list_by_token_size(
        all_text_units,
        key=lambda x: x["data"]["content"],
        max_token_size=query_param.local_max_token_for_text_unit,
        tokenizer_wrapper=tokenizer_wrapper,
    )
    all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
    return all_text_units
805
+
806
+
807
async def _find_most_related_edges_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
    tokenizer_wrapper,
):
    """Gather, rank and truncate the relations attached to the given entities.

    Edges are deduplicated order-insensitively (as sorted endpoint tuples,
    first-seen order preserved), annotated with their degree-based rank,
    sorted by (rank, weight) descending, and truncated to the local-context
    token budget.
    """
    all_related_edges = await knowledge_graph_inst.get_nodes_edges_batch(
        [dp["entity_name"] for dp in node_datas]
    )

    # Deduplicate edges regardless of direction, keeping first-seen order.
    seen = set()
    all_edges = []
    for edge_group in all_related_edges:
        for edge in edge_group:
            key = tuple(sorted(edge))
            if key in seen:
                continue
            seen.add(key)
            all_edges.append(key)

    all_edges_pack = await knowledge_graph_inst.get_edges_batch(all_edges)
    all_edges_degree = await knowledge_graph_inst.edge_degrees_batch(all_edges)
    all_edges_data = [
        {"src_tgt": pair, "rank": degree, **payload}
        for pair, payload, degree in zip(all_edges, all_edges_pack, all_edges_degree)
        if payload is not None
    ]
    all_edges_data.sort(key=lambda x: (x["rank"], x["weight"]), reverse=True)
    return truncate_list_by_token_size(
        all_edges_data,
        key=lambda x: x["description"],
        max_token_size=query_param.local_max_token_for_local_context,
        tokenizer_wrapper=tokenizer_wrapper,
    )
842
+
843
+
844
async def _build_local_query_context(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    tokenizer_wrapper,
):
    """Assemble the CSV-formatted context block for a local query.

    Looks up the top-k entities nearest to ``query`` in the vector store,
    expands them into related communities, text chunks and relations, and
    renders everything as four fenced CSV sections (Reports / Entities /
    Relationships / Sources). Returns None when the vector search yields
    no entities.
    """
    results = await entities_vdb.query(query, top_k=query_param.top_k)
    if not len(results):
        return None
    node_datas = await knowledge_graph_inst.get_nodes_batch([r["entity_name"] for r in results])
    if not all([n is not None for n in node_datas]):
        logger.warning("Some nodes are missing, maybe the storage is damaged")
    node_degrees = await knowledge_graph_inst.node_degrees_batch([r["entity_name"] for r in results])
    # Attach the vdb entity name and the degree-based rank; drop missing nodes.
    node_datas = [
        {**n, "entity_name": k["entity_name"], "rank": d}
        for k, n, d in zip(results, node_datas, node_degrees)
        if n is not None
    ]
    use_communities = await _find_most_related_community_from_entities(
        node_datas, query_param, community_reports, tokenizer_wrapper
    )
    use_text_units = await _find_most_related_text_unit_from_entities(
        node_datas, query_param, text_chunks_db, knowledge_graph_inst, tokenizer_wrapper
    )
    use_relations = await _find_most_related_edges_from_entities(
        node_datas, query_param, knowledge_graph_inst, tokenizer_wrapper
    )
    logger.info(
        f"Using {len(node_datas)} entites, {len(use_communities)} communities, {len(use_relations)} relations, {len(use_text_units)} text units"
    )
    # Render each section as a list-of-lists with a header row, then to CSV.
    entites_section_list = [["id", "entity", "type", "description", "rank"]]
    for i, n in enumerate(node_datas):
        entites_section_list.append(
            [
                i,
                n["entity_name"],
                n.get("entity_type", "UNKNOWN"),
                n.get("description", "UNKNOWN"),
                n["rank"],
            ]
        )
    entities_context = list_of_list_to_csv(entites_section_list)

    relations_section_list = [
        ["id", "source", "target", "description", "weight", "rank"]
    ]
    for i, e in enumerate(use_relations):
        relations_section_list.append(
            [
                i,
                e["src_tgt"][0],
                e["src_tgt"][1],
                e["description"],
                e["weight"],
                e["rank"],
            ]
        )
    relations_context = list_of_list_to_csv(relations_section_list)

    communities_section_list = [["id", "content"]]
    for i, c in enumerate(use_communities):
        communities_section_list.append([i, c["report_string"]])
    communities_context = list_of_list_to_csv(communities_section_list)

    text_units_section_list = [["id", "content"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"]])
    text_units_context = list_of_list_to_csv(text_units_section_list)
    return f"""
-----Reports-----
```csv
{communities_context}
```
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""
933
+
934
+
935
async def local_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    tokenizer_wrapper,
    global_config: dict,
) -> str:
    """Answer ``query`` with the local (entity-centred) GraphRAG strategy.

    Builds a CSV context from the entities nearest to the query, then asks
    the configured best model. Returns the raw context when
    ``query_param.only_need_context`` is set, and the canned fail response
    when no context could be built.
    """
    model_func = global_config["best_model_func"]
    context = await _build_local_query_context(
        query,
        knowledge_graph_inst,
        entities_vdb,
        community_reports,
        text_chunks_db,
        query_param,
        tokenizer_wrapper,
    )
    # Note: the raw context is returned even when it is None.
    if query_param.only_need_context:
        return context
    if context is None:
        return PROMPTS["fail_response"]
    system_prompt = PROMPTS["local_rag_response"].format(
        context_data=context, response_type=query_param.response_type
    )
    return await model_func(
        query,
        system_prompt=system_prompt,
    )
968
+
969
+
970
async def _map_global_communities(
    query: str,
    communities_data: list[CommunitySchema],
    query_param: QueryParam,
    global_config: dict,
    tokenizer_wrapper,
):
    """Map step of the global query: score answer points per community group.

    The community reports are greedily packed into token-budgeted groups;
    each group is sent to the LLM with the "global_map_rag_points" prompt
    concurrently, and the parsed "points" lists are returned (one list per
    group, in group order).
    """
    use_string_json_convert_func = global_config["convert_response_to_json_func"]
    use_model_func = global_config["best_model_func"]
    community_groups = []
    # Greedily slice the community list into groups that fit the token budget.
    while len(communities_data):
        this_group = truncate_list_by_token_size(
            communities_data,
            key=lambda x: x["report_string"],
            max_token_size=query_param.global_max_token_for_community_report,
            tokenizer_wrapper=tokenizer_wrapper,  # tokenizer used for budgeting
        )
        community_groups.append(this_group)
        communities_data = communities_data[len(this_group) :]

    async def _process(community_truncated_datas: list[CommunitySchema]) -> dict:
        # One LLM map call over a single group of community reports.
        communities_section_list = [["id", "content", "rating", "importance"]]
        for i, c in enumerate(community_truncated_datas):
            communities_section_list.append(
                [
                    i,
                    c["report_string"],
                    c["report_json"].get("rating", 0),
                    c["occurrence"],
                ]
            )
        community_context = list_of_list_to_csv(communities_section_list)
        sys_prompt_temp = PROMPTS["global_map_rag_points"]
        sys_prompt = sys_prompt_temp.format(context_data=community_context)
        response = await use_model_func(
            query,
            system_prompt=sys_prompt,
            **query_param.global_special_community_map_llm_kwargs,
        )
        data = use_string_json_convert_func(response)
        return data.get("points", [])

    logger.info(f"Grouping to {len(community_groups)} groups for global search")
    responses = await asyncio.gather(*[_process(c) for c in community_groups])
    return responses
1015
+
1016
+
1017
async def global_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    tokenizer_wrapper,
    global_config: dict,
) -> str:
    """Answer ``query`` with the global (community map/reduce) GraphRAG strategy.

    Selects the highest-occurrence communities up to ``query_param.level``,
    runs the map step (one LLM call per community group), then reduces the
    scored answer points into a final response. Returns the canned fail
    response when no usable communities or answer points are found, and the
    raw points context when ``query_param.only_need_context`` is set.
    """
    community_schema = await knowledge_graph_inst.community_schema()
    community_schema = {
        k: v for k, v in community_schema.items() if v["level"] <= query_param.level
    }
    if not len(community_schema):
        return PROMPTS["fail_response"]
    use_model_func = global_config["best_model_func"]

    # Consider only the most frequently occurring communities.
    sorted_community_schemas = sorted(
        community_schema.items(),
        key=lambda x: x[1]["occurrence"],
        reverse=True,
    )
    sorted_community_schemas = sorted_community_schemas[
        : query_param.global_max_consider_community
    ]
    community_datas = await community_reports.get_by_ids(
        [k[0] for k in sorted_community_schemas]
    )
    community_datas = [c for c in community_datas if c is not None]
    community_datas = [
        c
        for c in community_datas
        if c["report_json"].get("rating", 0) >= query_param.global_min_community_rating
    ]
    community_datas = sorted(
        community_datas,
        key=lambda x: (x["occurrence"], x["report_json"].get("rating", 0)),
        reverse=True,
    )
    # Bug fix: corrected log message typo ("Revtrieved" -> "Retrieved").
    logger.info(f"Retrieved {len(community_datas)} communities")

    map_communities_points = await _map_global_communities(
        query, community_datas, query_param, global_config, tokenizer_wrapper
    )
    # Flatten the per-analyst point lists, keeping only scored answers.
    final_support_points = []
    for i, mc in enumerate(map_communities_points):
        for point in mc:
            if "description" not in point:
                continue
            final_support_points.append(
                {
                    "analyst": i,
                    "answer": point["description"],
                    "score": point.get("score", 1),
                }
            )
    final_support_points = [p for p in final_support_points if p["score"] > 0]
    if not len(final_support_points):
        return PROMPTS["fail_response"]
    final_support_points = sorted(
        final_support_points, key=lambda x: x["score"], reverse=True
    )
    final_support_points = truncate_list_by_token_size(
        final_support_points,
        key=lambda x: x["answer"],
        max_token_size=query_param.global_max_token_for_community_report,
        tokenizer_wrapper=tokenizer_wrapper,
    )
    points_context = []
    for dp in final_support_points:
        points_context.append(
            f"""----Analyst {dp['analyst']}----
Importance Score: {dp['score']}
{dp['answer']}
"""
        )
    points_context = "\n".join(points_context)
    if query_param.only_need_context:
        return points_context
    sys_prompt_temp = PROMPTS["global_reduce_rag_response"]
    response = await use_model_func(
        query,
        sys_prompt_temp.format(
            report_data=points_context, response_type=query_param.response_type
        ),
    )
    return response
1105
+
1106
+
1107
async def naive_query(
    query,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    tokenizer_wrapper,
    global_config: dict,
):
    """Answer ``query`` with plain chunk-retrieval RAG (no graph involved).

    Retrieves the top-k chunks from the vector store, truncates them to the
    naive token budget, and asks the configured best model over the joined
    chunks. Returns the canned fail response when retrieval yields nothing,
    or the raw context when ``query_param.only_need_context`` is set.
    """
    model_func = global_config["best_model_func"]
    hits = await chunks_vdb.query(query, top_k=query_param.top_k)
    if not len(hits):
        return PROMPTS["fail_response"]
    chunks = await text_chunks_db.get_by_ids([hit["id"] for hit in hits])

    maybe_trun_chunks = truncate_list_by_token_size(
        chunks,
        key=lambda x: x["content"],
        max_token_size=query_param.naive_max_token_for_text_unit,
        tokenizer_wrapper=tokenizer_wrapper,
    )
    logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
    section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
    if query_param.only_need_context:
        return section
    system_prompt = PROMPTS["naive_rag_response"].format(
        content_data=section, response_type=query_param.response_type
    )
    return await model_func(
        query,
        system_prompt=system_prompt,
    )
nano-graphrag/nano_graphrag/_splitter.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union, Literal
2
+
3
class SeparatorSplitter:
    """Split token sequences on token-level separators and re-chunk them.

    Tokens are first cut at any of the configured separator sequences
    (optionally keeping the separator at the start or end of a piece), then
    greedily merged into chunks of at most ``chunk_size`` tokens with
    ``chunk_overlap`` tokens of overlap between consecutive chunks.
    """

    def __init__(
        self,
        separators: Optional[List[List[int]]] = None,
        keep_separator: Union[bool, Literal["start", "end"]] = "end",
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: callable = len,
    ):
        self._separators = separators or []
        self._keep_separator = keep_separator
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function

    def split_tokens(self, tokens: List[int]) -> List[List[int]]:
        """Split ``tokens`` on separators, then merge the pieces into chunks."""
        return self._merge_splits(self._split_tokens_with_separators(tokens))

    def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
        # Cut the token stream at every occurrence of any separator sequence.
        pieces: List[List[int]] = []
        buffer: List[int] = []
        pos = 0
        while pos < len(tokens):
            matched = None
            for sep in self._separators:
                if tokens[pos:pos + len(sep)] == sep:
                    matched = sep
                    break
            if matched is None:
                buffer.append(tokens[pos])
                pos += 1
                continue
            # keep_separator True behaves like "end" (appended to the piece
            # that precedes the cut).
            if self._keep_separator in (True, "end"):
                buffer.extend(matched)
            if buffer:
                pieces.append(buffer)
            buffer = []
            if self._keep_separator == "start":
                buffer.extend(matched)
            pos += len(matched)
        if buffer:
            pieces.append(buffer)
        return [p for p in pieces if p]

    def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
        # Greedily pack consecutive pieces into chunks of <= chunk_size tokens.
        if not splits:
            return []

        chunks: List[List[int]] = []
        current: List[int] = []
        for piece in splits:
            if not current:
                current = piece
            elif self._length_function(current) + self._length_function(piece) <= self._chunk_size:
                current.extend(piece)
            else:
                chunks.append(current)
                current = piece
        if current:
            chunks.append(current)

        # A single oversized chunk (no separator ever matched) is split by a
        # sliding window instead.
        if len(chunks) == 1 and self._length_function(chunks[0]) > self._chunk_size:
            return self._split_chunk(chunks[0])

        if self._chunk_overlap > 0:
            return self._enforce_overlap(chunks)
        return chunks

    def _split_chunk(self, chunk: List[int]) -> List[List[int]]:
        # Sliding window of chunk_size advancing by (chunk_size - overlap).
        step = self._chunk_size - self._chunk_overlap
        windows: List[List[int]] = []
        for start in range(0, len(chunk), step):
            window = chunk[start:start + self._chunk_size]
            # Drop a trailing window that would consist purely of overlap.
            if len(window) > self._chunk_overlap:
                windows.append(window)
        return windows

    def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
        # Prefix each chunk (after the first) with the tail of its
        # predecessor, truncating back to chunk_size if needed.
        overlapped: List[List[int]] = []
        for idx, chunk in enumerate(chunks):
            if idx == 0:
                overlapped.append(chunk)
                continue
            merged = chunks[idx - 1][-self._chunk_overlap:] + chunk
            if self._length_function(merged) > self._chunk_size:
                merged = merged[:self._chunk_size]
            overlapped.append(merged)
        return overlapped
94
+
nano-graphrag/nano_graphrag/_storage/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .gdb_networkx import NetworkXStorage
2
+ from .gdb_neo4j import Neo4jStorage
3
+ from .vdb_hnswlib import HNSWVectorStorage
4
+ from .vdb_nanovectordb import NanoVectorDBStorage
5
+ from .kv_json import JsonKVStorage
nano-graphrag/nano_graphrag/_storage/gdb_neo4j.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import asyncio
3
+ from collections import defaultdict
4
+ from typing import List
5
+ from neo4j import AsyncGraphDatabase
6
+ from dataclasses import dataclass
7
+ from typing import Union
8
+ from ..base import BaseGraphStorage, SingleCommunitySchema
9
+ from .._utils import logger
10
+ from ..prompt import GRAPH_FIELD_SEP
11
+
12
+ neo4j_lock = asyncio.Lock()
13
+
14
+
15
def make_path_idable(path):
    """Sanitize a filesystem path into a string usable as a Neo4j label.

    Single characters are mapped to underscores: ".", "-" and ":" become
    "_", while "/" and "\\" become "__".
    """
    return path.translate(str.maketrans({".": "_", "/": "__", "-": "_", ":": "_", "\\": "__"}))
17
+
18
+
19
+ @dataclass
20
+ class Neo4jStorage(BaseGraphStorage):
21
    def __post_init__(self):
        """Read Neo4j settings from ``addon_params`` and open an async driver.

        The node label for this storage is derived from the working directory
        plus the namespace, so multiple working dirs can share one Neo4j
        instance without colliding.

        Raises:
            ValueError: when neo4j_url or neo4j_auth is missing.
        """
        self.neo4j_url = self.global_config["addon_params"].get("neo4j_url", None)
        self.neo4j_auth = self.global_config["addon_params"].get("neo4j_auth", None)
        self.namespace = (
            f"{make_path_idable(self.global_config['working_dir'])}__{self.namespace}"
        )
        logger.info(f"Using the label {self.namespace} for Neo4j as identifier")
        if self.neo4j_url is None or self.neo4j_auth is None:
            raise ValueError("Missing neo4j_url or neo4j_auth in addon_params")
        self.async_driver = AsyncGraphDatabase.driver(
            self.neo4j_url, auth=self.neo4j_auth, max_connection_pool_size=50,
        )
33
+
34
+ # async def create_database(self):
35
+ # async with self.async_driver.session() as session:
36
+ # try:
37
+ # constraints = await session.run("SHOW CONSTRAINTS")
38
+ # # TODO I don't know why CREATE CONSTRAINT IF NOT EXISTS still trigger error
39
+ # # so have to check if the constrain exists
40
+ # constrain_exists = False
41
+
42
+ # async for record in constraints:
43
+ # if (
44
+ # self.namespace in record["labelsOrTypes"]
45
+ # and "id" in record["properties"]
46
+ # and record["type"] == "UNIQUENESS"
47
+ # ):
48
+ # constrain_exists = True
49
+ # break
50
+ # if not constrain_exists:
51
+ # await session.run(
52
+ # f"CREATE CONSTRAINT FOR (n:{self.namespace}) REQUIRE n.id IS UNIQUE"
53
+ # )
54
+ # logger.info(f"Add constraint for namespace: {self.namespace}")
55
+
56
+ # except Exception as e:
57
+ # logger.error(f"Error accessing or setting up the database: {str(e)}")
58
+ # raise
59
+
60
    async def _init_workspace(self):
        """Fail fast if the credentials are wrong or the server is unreachable."""
        await self.async_driver.verify_authentication()
        await self.async_driver.verify_connectivity()
        # TODOLater: create database if not exists always cause an error when async
        # await self.create_database()
65
+
66
    async def index_start_callback(self):
        """Verify connectivity and create the per-namespace lookup indexes.

        Raises:
            Exception: re-raised when index creation fails.
        """
        logger.info("Init Neo4j workspace")
        await self._init_workspace()

        # create index for faster searching
        try:
            async with self.async_driver.session() as session:
                await session.run(
                    f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.id)"
                )

                await session.run(
                    f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.entity_type)"
                )

                await session.run(
                    f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.communityIds)"
                )

                await session.run(
                    f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.source_id)"
                )
                logger.info("Neo4j indexes created successfully")
        except Exception as e:
            logger.error(f"Failed to create indexes: {e}")
            raise e
92
+
93
+ async def has_node(self, node_id: str) -> bool:
94
+ async with self.async_driver.session() as session:
95
+ result = await session.run(
96
+ f"MATCH (n:`{self.namespace}`) WHERE n.id = $node_id RETURN COUNT(n) > 0 AS exists",
97
+ node_id=node_id,
98
+ )
99
+ record = await result.single()
100
+ return record["exists"] if record else False
101
+
102
    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return True when any directed edge source -> target exists."""
        async with self.async_driver.session() as session:
            result = await session.run(
                f"""
                MATCH (s:`{self.namespace}`)
                WHERE s.id = $source_id
                MATCH (t:`{self.namespace}`)
                WHERE t.id = $target_id
                RETURN EXISTS((s)-[]->(t)) AS exists
                """,
                source_id=source_node_id,
                target_id=target_node_id,
            )

            record = await result.single()
            return record["exists"] if record else False
118
+
119
    async def node_degree(self, node_id: str) -> int:
        """Degree (neighbour count) of one node; 0 when it is absent."""
        results = await self.node_degrees_batch([node_id])
        return results[0] if results else 0
122
+
123
+ async def node_degrees_batch(self, node_ids: List[str]) -> List[str]:
124
+ if not node_ids:
125
+ return {}
126
+
127
+ result_dict = {node_id: 0 for node_id in node_ids}
128
+ async with self.async_driver.session() as session:
129
+ result = await session.run(
130
+ f"""
131
+ UNWIND $node_ids AS node_id
132
+ MATCH (n:`{self.namespace}`)
133
+ WHERE n.id = node_id
134
+ OPTIONAL MATCH (n)-[]-(m:`{self.namespace}`)
135
+ RETURN node_id, COUNT(m) AS degree
136
+ """,
137
+ node_ids=node_ids
138
+ )
139
+
140
+ async for record in result:
141
+ result_dict[record["node_id"]] = record["degree"]
142
+
143
+ return [result_dict[node_id] for node_id in node_ids]
144
+
145
    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Combined degree (src degree + tgt degree) of one edge's endpoints."""
        results = await self.edge_degrees_batch([(src_id, tgt_id)])
        return results[0] if results else 0
148
+
149
    async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
        """Return src_degree + tgt_degree for each (src, tgt) pair, in order.

        Pairs whose endpoints are missing keep a degree of 0; on query
        failure a list of zeros is returned instead of raising.
        """
        if not edge_pairs:
            return []

        result_dict = {tuple(edge_pair): 0 for edge_pair in edge_pairs}

        edges_params = [{"src_id": src, "tgt_id": tgt} for src, tgt in edge_pairs]

        try:
            async with self.async_driver.session() as session:
                result = await session.run(
                    f"""
                    UNWIND $edges AS edge

                    MATCH (s:`{self.namespace}`)
                    WHERE s.id = edge.src_id
                    WITH edge, s
                    OPTIONAL MATCH (s)-[]-(n1:`{self.namespace}`)
                    WITH edge, COUNT(n1) AS src_degree

                    MATCH (t:`{self.namespace}`)
                    WHERE t.id = edge.tgt_id
                    WITH edge, src_degree, t
                    OPTIONAL MATCH (t)-[]-(n2:`{self.namespace}`)
                    WITH edge.src_id AS src_id, edge.tgt_id AS tgt_id, src_degree, COUNT(n2) AS tgt_degree

                    RETURN src_id, tgt_id, src_degree + tgt_degree AS degree
                    """,
                    edges=edges_params
                )

                async for record in result:
                    src_id = record["src_id"]
                    tgt_id = record["tgt_id"]
                    degree = record["degree"]

                    # update the result dict
                    edge_pair = (src_id, tgt_id)
                    result_dict[edge_pair] = degree

            return [result_dict[tuple(edge_pair)] for edge_pair in edge_pairs]
        except Exception as e:
            logger.error(f"Error in batch edge degree calculation: {e}")
            return [0] * len(edge_pairs)
193
+
194
+
195
+
196
    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Property dict of one node, or None when it is absent."""
        result = await self.get_nodes_batch([node_id])
        return result[0] if result else None
199
+
200
+ async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, Union[dict, None]]:
201
+ if not node_ids:
202
+ return {}
203
+
204
+ result_dict = {node_id: None for node_id in node_ids}
205
+
206
+ try:
207
+ async with self.async_driver.session() as session:
208
+ result = await session.run(
209
+ f"""
210
+ UNWIND $node_ids AS node_id
211
+ MATCH (n:`{self.namespace}`)
212
+ WHERE n.id = node_id
213
+ RETURN node_id, properties(n) AS node_data
214
+ """,
215
+ node_ids=node_ids
216
+ )
217
+
218
+ async for record in result:
219
+ node_id = record["node_id"]
220
+ raw_node_data = record["node_data"]
221
+
222
+ if raw_node_data:
223
+ raw_node_data["clusters"] = json.dumps(
224
+ [
225
+ {
226
+ "level": index,
227
+ "cluster": cluster_id,
228
+ }
229
+ for index, cluster_id in enumerate(
230
+ raw_node_data.get("communityIds", [])
231
+ )
232
+ ]
233
+ )
234
+ result_dict[node_id] = raw_node_data
235
+ return [result_dict[node_id] for node_id in node_ids]
236
+ except Exception as e:
237
+ logger.error(f"Error in batch node retrieval: {e}")
238
+ raise e
239
+
240
    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Property dict of the edge source -> target, or None when absent."""
        results = await self.get_edges_batch([(source_node_id, target_node_id)])
        return results[0] if results else None
245
+
246
    async def get_edges_batch(
        self, edge_pairs: list[tuple[str, str]]
    ) -> list[Union[dict, None]]:
        """Fetch edge property dicts for each (source, target) pair, in order.

        Missing edges yield None; on query failure a list of None is returned
        instead of raising.
        """
        if not edge_pairs:
            return []

        result_dict = {tuple(edge_pair): None for edge_pair in edge_pairs}

        edges_params = [{"source_id": src, "target_id": tgt} for src, tgt in edge_pairs]

        try:
            async with self.async_driver.session() as session:
                result = await session.run(
                    f"""
                    UNWIND $edges AS edge
                    MATCH (s:`{self.namespace}`)-[r]->(t:`{self.namespace}`)
                    WHERE s.id = edge.source_id AND t.id = edge.target_id
                    RETURN edge.source_id AS source_id, edge.target_id AS target_id, properties(r) AS edge_data
                    """,
                    edges=edges_params
                )

                async for record in result:
                    source_id = record["source_id"]
                    target_id = record["target_id"]
                    edge_data = record["edge_data"]

                    edge_pair = (source_id, target_id)
                    result_dict[edge_pair] = edge_data

            return [result_dict[tuple(edge_pair)] for edge_pair in edge_pairs]
        except Exception as e:
            logger.error(f"Error in batch edge retrieval: {e}")
            return [None] * len(edge_pairs)
280
+
281
    async def get_node_edges(
        self, source_node_id: str
    ) -> list[tuple[str, str]]:
        """Outgoing edges of one node as (source_id, target_id) tuples."""
        results = await self.get_nodes_edges_batch([source_node_id])
        return results[0] if results else []
286
+
287
    async def get_nodes_edges_batch(
        self, node_ids: list[str]
    ) -> list[list[tuple[str, str]]]:
        """Outgoing edges for each node id, in input order.

        Nodes without edges (or missing nodes) yield an empty list; on query
        failure all entries are empty lists instead of raising.
        """
        if not node_ids:
            return []

        result_dict = {node_id: [] for node_id in node_ids}

        try:
            async with self.async_driver.session() as session:
                result = await session.run(
                    f"""
                    UNWIND $node_ids AS node_id
                    MATCH (s:`{self.namespace}`)-[r]->(t:`{self.namespace}`)
                    WHERE s.id = node_id
                    RETURN s.id AS source_id, t.id AS target_id
                    """,
                    node_ids=node_ids
                )

                async for record in result:
                    source_id = record["source_id"]
                    target_id = record["target_id"]

                    if source_id in result_dict:
                        result_dict[source_id].append((source_id, target_id))

            return [result_dict[node_id] for node_id in node_ids]
        except Exception as e:
            logger.error(f"Error in batch node edges retrieval: {e}")
            return [[] for _ in node_ids]
318
+
319
+ async def upsert_node(self, node_id: str, node_data: dict[str, str]):
320
+ await self.upsert_nodes_batch([(node_id, node_data)])
321
+
322
+ async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
323
+ if not nodes_data:
324
+ return []
325
+
326
+ nodes_by_type = {}
327
+ for node_id, node_data in nodes_data:
328
+ node_type = node_data.get("entity_type", "UNKNOWN").strip('"')
329
+ if node_type not in nodes_by_type:
330
+ nodes_by_type[node_type] = []
331
+ nodes_by_type[node_type].append((node_id, node_data))
332
+
333
+ async with self.async_driver.session() as session:
334
+ for node_type, type_nodes in nodes_by_type.items():
335
+ params = [{"id": node_id, "data": node_data} for node_id, node_data in type_nodes]
336
+
337
+ await session.run(
338
+ f"""
339
+ UNWIND $nodes AS node
340
+ MERGE (n:`{self.namespace}`:`{node_type}` {{id: node.id}})
341
+ SET n += node.data
342
+ """,
343
+ nodes=params
344
+ )
345
+
346
+ async def upsert_edge(
347
+ self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
348
+ ):
349
+ await self.upsert_edges_batch([(source_node_id, target_node_id, edge_data)])
350
+
351
+
352
+ async def upsert_edges_batch(
353
+ self, edges_data: list[tuple[str, str, dict[str, str]]]
354
+ ):
355
+ if not edges_data:
356
+ return
357
+
358
+ edges_params = []
359
+ for source_id, target_id, edge_data in edges_data:
360
+ edge_data_copy = edge_data.copy()
361
+ edge_data_copy.setdefault("weight", 0.0)
362
+
363
+ edges_params.append({
364
+ "source_id": source_id,
365
+ "target_id": target_id,
366
+ "edge_data": edge_data_copy
367
+ })
368
+
369
+ async with self.async_driver.session() as session:
370
+ await session.run(
371
+ f"""
372
+ UNWIND $edges AS edge
373
+ MATCH (s:`{self.namespace}`)
374
+ WHERE s.id = edge.source_id
375
+ WITH edge, s
376
+ MATCH (t:`{self.namespace}`)
377
+ WHERE t.id = edge.target_id
378
+ MERGE (s)-[r:RELATED]->(t)
379
+ SET r += edge.edge_data
380
+ """,
381
+ edges=edges_params
382
+ )
383
+
384
+
385
+
386
+
387
    async def clustering(self, algorithm: str):
        """Run Leiden community detection via the Neo4j GDS plugin.

        Projects the namespace subgraph (undirected, carrying the 'weight'
        relationship property), runs ``gds.leiden.write`` to store per-node
        'communityIds', and always drops the projection afterwards.

        Raises:
            ValueError: if ``algorithm`` is anything other than "leiden".
        """
        if algorithm != "leiden":
            raise ValueError(
                f"Clustering algorithm {algorithm} not supported in Neo4j implementation"
            )

        random_seed = self.global_config["graph_cluster_seed"]
        # NOTE(review): this reads max_graph_cluster_size but feeds it to
        # GDS's maxLevels (hierarchy depth), not a cluster-size cap — confirm
        # this is intentional.
        max_level = self.global_config["max_graph_cluster_size"]
        async with self.async_driver.session() as session:
            try:
                # Project the graph with undirected relationships
                await session.run(
                    f"""
                    CALL gds.graph.project(
                        'graph_{self.namespace}',
                        ['{self.namespace}'],
                        {{
                            RELATED: {{
                                orientation: 'UNDIRECTED',
                                properties: ['weight']
                            }}
                        }}
                    )
                    """
                )

                # Run Leiden algorithm
                result = await session.run(
                    f"""
                    CALL gds.leiden.write(
                        'graph_{self.namespace}',
                        {{
                            writeProperty: 'communityIds',
                            includeIntermediateCommunities: True,
                            relationshipWeightProperty: "weight",
                            maxLevels: {max_level},
                            tolerance: 0.0001,
                            gamma: 1.0,
                            theta: 0.01,
                            randomSeed: {random_seed}
                        }}
                    )
                    YIELD communityCount, modularities;
                    """
                )
                result = await result.single()
                community_count: int = result["communityCount"]
                modularities = result["modularities"]
                logger.info(
                    f"Performed graph clustering with {community_count} communities and modularities {modularities}"
                )
            finally:
                # Drop the projected graph
                await session.run(f"CALL gds.graph.drop('graph_{self.namespace}')")
441
+
442
    async def community_schema(self) -> dict[str, SingleCommunitySchema]:
        """Build the community schema from the 'communityIds' written by clustering().

        For every node, each position in its communityIds list is treated as
        one (level, cluster) membership; nodes, undirected edge pairs and
        source chunk ids are accumulated per cluster. 'occurrence' is the
        cluster's chunk count normalized by the largest cluster's chunk count.
        """
        results = defaultdict(
            lambda: dict(
                level=None,
                title=None,
                edges=set(),
                nodes=set(),
                chunk_ids=set(),
                occurrence=0.0,
                sub_communities=[],
            )
        )

        async with self.async_driver.session() as session:
            # Fetch community data
            result = await session.run(
                f"""
                MATCH (n:`{self.namespace}`)
                WITH n, n.communityIds AS communityIds, [(n)-[]-(m:`{self.namespace}`) | m.id] AS connected_nodes
                RETURN n.id AS node_id, n.source_id AS source_id,
                       communityIds AS cluster_key,
                       connected_nodes
                """
            )

            # records = await result.fetch()

            max_num_ids = 0
            async for record in result:
                # Index within communityIds encodes the hierarchy level.
                # NOTE(review): assumes GDS cluster ids are unique across
                # levels, since `cluster_key` alone keys `results` — confirm.
                for index, c_id in enumerate(record["cluster_key"]):
                    node_id = str(record["node_id"])
                    source_id = record["source_id"]
                    level = index
                    cluster_key = str(c_id)
                    connected_nodes = record["connected_nodes"]

                    results[cluster_key]["level"] = level
                    results[cluster_key]["title"] = f"Cluster {cluster_key}"
                    results[cluster_key]["nodes"].add(node_id)
                    # Sorted pairs make undirected edges deduplicate.
                    results[cluster_key]["edges"].update(
                        [
                            tuple(sorted([node_id, str(connected)]))
                            for connected in connected_nodes
                            if connected != node_id
                        ]
                    )
                    chunk_ids = source_id.split(GRAPH_FIELD_SEP)
                    results[cluster_key]["chunk_ids"].update(chunk_ids)
                    max_num_ids = max(
                        max_num_ids, len(results[cluster_key]["chunk_ids"])
                    )

            # Process results: convert sets to lists and normalize occurrence.
            for k, v in results.items():
                v["edges"] = [list(e) for e in v["edges"]]
                v["nodes"] = list(v["nodes"])
                v["chunk_ids"] = list(v["chunk_ids"])
                v["occurrence"] = len(v["chunk_ids"]) / max_num_ids

            # Compute sub-communities (this is a simplified approach)
            for cluster in results.values():
                cluster["sub_communities"] = [
                    sub_key
                    for sub_key, sub_cluster in results.items()
                    if sub_cluster["level"] > cluster["level"]
                    and set(sub_cluster["nodes"]).issubset(set(cluster["nodes"]))
                ]

        return dict(results)
511
+
512
    async def index_done_callback(self):
        # NOTE(review): this closes the shared async driver when indexing
        # finishes; any later query through this storage instance would need
        # a new driver — confirm this lifecycle is intended.
        await self.async_driver.close()
514
+
515
    async def _debug_delete_all_node_edges(self):
        """Debug helper: delete every relationship and node in this namespace.

        Relationships are removed first because Neo4j refuses to DELETE a
        node that still has attached relationships. Errors are logged and
        re-raised.
        """
        async with self.async_driver.session() as session:
            try:
                # Delete all relationships in the namespace
                await session.run(f"MATCH (n:`{self.namespace}`)-[r]-() DELETE r")

                # Delete all nodes in the namespace
                await session.run(f"MATCH (n:`{self.namespace}`) DELETE n")

                logger.info(
                    f"All nodes and edges in namespace '{self.namespace}' have been deleted."
                )
            except Exception as e:
                logger.error(f"Error deleting nodes and edges: {str(e)}")
                raise
nano-graphrag/nano_graphrag/_storage/gdb_networkx.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import html
2
+ import json
3
+ import os
4
+ from collections import defaultdict
5
+ from dataclasses import dataclass
6
+ from typing import Any, Union, cast, List
7
+ import networkx as nx
8
+ import numpy as np
9
+ import asyncio
10
+
11
+ from .._utils import logger
12
+ from ..base import (
13
+ BaseGraphStorage,
14
+ SingleCommunitySchema,
15
+ )
16
+ from ..prompt import GRAPH_FIELD_SEP
17
+
18
+
19
@dataclass
class NetworkXStorage(BaseGraphStorage):
    """Graph storage backed by an in-memory networkx graph persisted as GraphML.

    All batch methods simply fan out to the single-item methods via
    asyncio.gather; there is no I/O until index_done_callback() writes the
    GraphML file.
    """

    @staticmethod
    def load_nx_graph(file_name) -> Union[nx.Graph, None]:
        """Load a GraphML file if it exists; return None otherwise."""
        if os.path.exists(file_name):
            return nx.read_graphml(file_name)
        return None

    @staticmethod
    def write_nx_graph(graph: nx.Graph, file_name):
        """Persist the graph to ``file_name`` in GraphML format."""
        logger.info(
            f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
        )
        nx.write_graphml(graph, file_name)

    @staticmethod
    def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
        """
        from graspologic.utils import largest_connected_component

        graph = graph.copy()
        graph = cast(nx.Graph, largest_connected_component(graph))
        # Normalize node labels (uppercase, HTML-unescaped) so repeated runs
        # over equivalent data produce identical graphs.
        node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()}  # type: ignore
        graph = nx.relabel_nodes(graph, node_mapping)
        return NetworkXStorage._stabilize_graph(graph)

    @staticmethod
    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Ensure an undirected graph with the same relationships will always be read the same way.
        """
        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()

        sorted_nodes = graph.nodes(data=True)
        sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])

        fixed_graph.add_nodes_from(sorted_nodes)
        edges = list(graph.edges(data=True))

        if not graph.is_directed():

            # For undirected graphs, canonicalize each edge's endpoint order
            # before sorting so (a, b) and (b, a) compare equal.
            def _sort_source_target(edge):
                source, target, edge_data = edge
                if source > target:
                    temp = source
                    source = target
                    target = temp
                return source, target, edge_data

            edges = [_sort_source_target(edge) for edge in edges]

        def _get_edge_key(source: Any, target: Any) -> str:
            return f"{source} -> {target}"

        edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))

        fixed_graph.add_edges_from(edges)
        return fixed_graph

    def __post_init__(self):
        # GraphML file path derived from the working dir and namespace.
        self._graphml_xml_file = os.path.join(
            self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
        )
        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
        if preloaded_graph is not None:
            logger.info(
                f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
            )
        self._graph = preloaded_graph or nx.Graph()
        self._clustering_algorithms = {
            "leiden": self._leiden_clustering,
        }
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    async def index_done_callback(self):
        NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)

    async def has_node(self, node_id: str) -> bool:
        return self._graph.has_node(node_id)

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        return self._graph.has_edge(source_node_id, target_node_id)

    async def get_node(self, node_id: str) -> Union[dict, None]:
        return self._graph.nodes.get(node_id)

    # Annotation corrected: asyncio.gather returns a list aligned with
    # node_ids, not a dict.
    async def get_nodes_batch(self, node_ids: list[str]) -> list[Union[dict, None]]:
        return await asyncio.gather(*[self.get_node(node_id) for node_id in node_ids])

    async def node_degree(self, node_id: str) -> int:
        # [numberchiffre]: node_id not part of graph returns `DegreeView({})` instead of 0
        return self._graph.degree(node_id) if self._graph.has_node(node_id) else 0

    # Annotation corrected: degrees are ints.
    async def node_degrees_batch(self, node_ids: List[str]) -> List[int]:
        return await asyncio.gather(*[self.node_degree(node_id) for node_id in node_ids])

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        # Edge degree = sum of the endpoint degrees (missing nodes count 0).
        return (self._graph.degree(src_id) if self._graph.has_node(src_id) else 0) + (
            self._graph.degree(tgt_id) if self._graph.has_node(tgt_id) else 0
        )

    async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
        return await asyncio.gather(*[self.edge_degree(src_id, tgt_id) for src_id, tgt_id in edge_pairs])

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        return self._graph.edges.get((source_node_id, target_node_id))

    async def get_edges_batch(
        self, edge_pairs: list[tuple[str, str]]
    ) -> list[Union[dict, None]]:
        return await asyncio.gather(*[self.get_edge(source_node_id, target_node_id) for source_node_id, target_node_id in edge_pairs])

    async def get_node_edges(self, source_node_id: str):
        # Returns None (not []) for an unknown node.
        if self._graph.has_node(source_node_id):
            return list(self._graph.edges(source_node_id))
        return None

    # NOTE: entries may be None for unknown nodes (see get_node_edges).
    async def get_nodes_edges_batch(
        self, node_ids: list[str]
    ) -> list[list[tuple[str, str]]]:
        return await asyncio.gather(*[self.get_node_edges(node_id) for node_id
                                      in node_ids])

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        self._graph.add_node(node_id, **node_data)

    async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
        await asyncio.gather(*[self.upsert_node(node_id, node_data) for node_id, node_data in nodes_data])

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        self._graph.add_edge(source_node_id, target_node_id, **edge_data)

    async def upsert_edges_batch(
        self, edges_data: list[tuple[str, str, dict[str, str]]]
    ):
        await asyncio.gather(*[self.upsert_edge(source_node_id, target_node_id, edge_data)
                               for source_node_id, target_node_id, edge_data in edges_data])

    async def clustering(self, algorithm: str):
        """Dispatch to a registered clustering algorithm (currently only leiden)."""
        if algorithm not in self._clustering_algorithms:
            raise ValueError(f"Clustering algorithm {algorithm} not supported")
        await self._clustering_algorithms[algorithm]()

    async def community_schema(self) -> dict[str, SingleCommunitySchema]:
        """Build the community schema from the per-node 'clusters' JSON written by clustering."""
        results = defaultdict(
            lambda: dict(
                level=None,
                title=None,
                edges=set(),
                nodes=set(),
                chunk_ids=set(),
                occurrence=0.0,
                sub_communities=[],
            )
        )
        max_num_ids = 0
        levels = defaultdict(set)
        for node_id, node_data in self._graph.nodes(data=True):
            # Nodes untouched by clustering carry no 'clusters' attribute.
            if "clusters" not in node_data:
                continue
            clusters = json.loads(node_data["clusters"])
            this_node_edges = self._graph.edges(node_id)

            for cluster in clusters:
                level = cluster["level"]
                cluster_key = str(cluster["cluster"])
                levels[level].add(cluster_key)
                results[cluster_key]["level"] = level
                results[cluster_key]["title"] = f"Cluster {cluster_key}"
                results[cluster_key]["nodes"].add(node_id)
                # Sorted tuples deduplicate undirected edges.
                results[cluster_key]["edges"].update(
                    [tuple(sorted(e)) for e in this_node_edges]
                )
                results[cluster_key]["chunk_ids"].update(
                    node_data["source_id"].split(GRAPH_FIELD_SEP)
                )
                max_num_ids = max(max_num_ids, len(results[cluster_key]["chunk_ids"]))

        ordered_levels = sorted(levels.keys())
        for i, curr_level in enumerate(ordered_levels[:-1]):
            next_level = ordered_levels[i + 1]
            this_level_comms = levels[curr_level]
            next_level_comms = levels[next_level]
            # compute the sub-communities by nodes intersection
            for comm in this_level_comms:
                results[comm]["sub_communities"] = [
                    c
                    for c in next_level_comms
                    if results[c]["nodes"].issubset(results[comm]["nodes"])
                ]

        # Convert sets to lists and normalize occurrence by the largest
        # cluster's chunk count.
        for k, v in results.items():
            v["edges"] = list(v["edges"])
            v["edges"] = [list(e) for e in v["edges"]]
            v["nodes"] = list(v["nodes"])
            v["chunk_ids"] = list(v["chunk_ids"])
            v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
        return dict(results)

    def _cluster_data_to_subgraphs(self, cluster_data: dict[str, list[dict[str, str]]]):
        # Store each node's cluster memberships as a JSON string attribute so
        # they survive the GraphML round-trip.
        for node_id, clusters in cluster_data.items():
            self._graph.nodes[node_id]["clusters"] = json.dumps(clusters)

    async def _leiden_clustering(self):
        """Run hierarchical Leiden on the stabilized largest connected component."""
        from graspologic.partition import hierarchical_leiden

        graph = NetworkXStorage.stable_largest_connected_component(self._graph)
        community_mapping = hierarchical_leiden(
            graph,
            max_cluster_size=self.global_config["max_graph_cluster_size"],
            random_seed=self.global_config["graph_cluster_seed"],
        )

        node_communities: dict[str, list[dict[str, str]]] = defaultdict(list)
        __levels = defaultdict(set)
        for partition in community_mapping:
            level_key = partition.level
            cluster_id = partition.cluster
            node_communities[partition.node].append(
                {"level": level_key, "cluster": cluster_id}
            )
            __levels[level_key].add(cluster_id)
        node_communities = dict(node_communities)
        __levels = {k: len(v) for k, v in __levels.items()}
        logger.info(f"Each level has communities: {dict(__levels)}")
        self._cluster_data_to_subgraphs(node_communities)

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        """Dispatch to a registered node-embedding algorithm (currently only node2vec)."""
        if algorithm not in self._node_embed_algorithms:
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
        return await self._node_embed_algorithms[algorithm]()

    async def _node2vec_embed(self):
        from graspologic import embed

        embeddings, nodes = embed.node2vec_embed(
            self._graph,
            **self.global_config["node2vec_params"],
        )

        # NOTE(review): assumes every node carries an "id" attribute — raises
        # KeyError otherwise; confirm upstream always sets it.
        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
        return embeddings, nodes_ids
+ return embeddings, nodes_ids
nano-graphrag/nano_graphrag/_storage/kv_json.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
+ from .._utils import load_json, logger, write_json
5
+ from ..base import (
6
+ BaseKVStorage,
7
+ )
8
+
9
+
10
@dataclass
class JsonKVStorage(BaseKVStorage):
    """Key-value storage kept fully in memory and persisted to one JSON file.

    The file lives in the configured working directory as
    ``kv_store_<namespace>.json``; mutations only hit disk when
    :meth:`index_done_callback` runs.
    """

    def __post_init__(self):
        work_dir = self.global_config["working_dir"]
        self._file_name = os.path.join(work_dir, f"kv_store_{self.namespace}.json")
        self._data = load_json(self._file_name) or {}
        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

    async def all_keys(self) -> list[str]:
        """Return every stored key."""
        return [key for key in self._data]

    async def index_done_callback(self):
        """Flush the in-memory store to its JSON file."""
        write_json(self._data, self._file_name)

    async def get_by_id(self, id):
        """Return the record for ``id``, or None when absent."""
        return self._data.get(id, None)

    async def get_by_ids(self, ids, fields=None):
        """Return records for ``ids`` in order; missing ids yield None.

        When ``fields`` is given, each found record is reduced to those keys.
        """
        if fields is None:
            return [self._data.get(id, None) for id in ids]
        picked = []
        for id in ids:
            record = self._data.get(id, None)
            if record:
                picked.append({k: v for k, v in record.items() if k in fields})
            else:
                picked.append(None)
        return picked

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the subset of ``data`` not yet present in the store."""
        return {s for s in data if s not in self._data}

    async def upsert(self, data: dict[str, dict]):
        """Merge ``data`` into the store, overwriting existing keys."""
        for key, value in data.items():
            self._data[key] = value

    async def drop(self):
        """Discard all in-memory data (the file is untouched until the next flush)."""
        self._data = {}
nano-graphrag/nano_graphrag/_storage/vdb_hnswlib.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ from dataclasses import dataclass, field
4
+ from typing import Any
5
+ import pickle
6
+ import hnswlib
7
+ import numpy as np
8
+ import xxhash
9
+
10
+ from .._utils import logger
11
+ from ..base import BaseVectorStorage
12
+
13
+
14
@dataclass
class HNSWVectorStorage(BaseVectorStorage):
    """Vector storage backed by an hnswlib cosine index persisted to disk.

    Element ids are 32-bit xxhash digests of the string keys; per-element
    metadata is kept in a dict pickled next to the index file.
    """

    # Index tunables; overridable via global_config["vector_db_storage_cls_kwargs"].
    ef_construction: int = 100
    M: int = 16
    max_elements: int = 1000000
    ef_search: int = 50
    num_threads: int = -1
    _index: Any = field(init=False)
    # Keyed by the xxhash integer id (despite the str annotation inherited here).
    _metadata: dict[str, dict] = field(default_factory=dict)
    _current_elements: int = 0

    def __post_init__(self):
        self._index_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_hnsw.index"
        )
        self._metadata_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_hnsw_metadata.pkl"
        )
        self._embedding_batch_num = self.global_config.get("embedding_batch_num", 100)

        hnsw_params = self.global_config.get("vector_db_storage_cls_kwargs", {})
        self.ef_construction = hnsw_params.get("ef_construction", self.ef_construction)
        self.M = hnsw_params.get("M", self.M)
        self.max_elements = hnsw_params.get("max_elements", self.max_elements)
        self.ef_search = hnsw_params.get("ef_search", self.ef_search)
        self.num_threads = hnsw_params.get("num_threads", self.num_threads)
        self._index = hnswlib.Index(
            space="cosine", dim=self.embedding_func.embedding_dim
        )

        # Resume from disk only when both the index and its metadata exist;
        # otherwise start a fresh index.
        if os.path.exists(self._index_file_name) and os.path.exists(
            self._metadata_file_name
        ):
            self._index.load_index(
                self._index_file_name, max_elements=self.max_elements
            )
            with open(self._metadata_file_name, "rb") as f:
                self._metadata, self._current_elements = pickle.load(f)
            logger.info(
                f"Loaded existing index for {self.namespace} with {self._current_elements} elements"
            )
        else:
            self._index.init_index(
                max_elements=self.max_elements,
                ef_construction=self.ef_construction,
                M=self.M,
            )
            self._index.set_ef(self.ef_search)
            self._metadata = {}
            self._current_elements = 0
            logger.info(f"Created new index for {self.namespace}")

    async def upsert(self, data: dict[str, dict]) -> np.ndarray:
        """Embed and insert records keyed by string id.

        Returns the array of integer ids added (an empty list for empty input,
        despite the ndarray annotation).

        Raises:
            ValueError: if the insert would exceed max_elements.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You insert an empty data to vector DB")
            return []

        if self._current_elements + len(data) > self.max_elements:
            raise ValueError(
                f"Cannot insert {len(data)} elements. Current: {self._current_elements}, Max: {self.max_elements}"
            )

        # Keep only declared meta fields alongside each record's id.
        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batch_size = min(self._embedding_batch_num, len(contents))
        # Embed batches concurrently and stack into one matrix.
        embeddings = np.concatenate(
            await asyncio.gather(
                *[
                    self.embedding_func(contents[i : i + batch_size])
                    for i in range(0, len(contents), batch_size)
                ]
            )
        )

        # hnswlib needs integer labels: hash the string keys to uint32.
        # NOTE(review): 32-bit hashes can collide, silently overwriting an
        # unrelated element — confirm this risk is acceptable at this scale.
        ids = np.fromiter(
            (xxhash.xxh32_intdigest(d["id"].encode()) for d in list_data),
            dtype=np.uint32,
            count=len(list_data),
        )
        self._metadata.update(
            {
                id_int: {
                    k: v for k, v in d.items() if k in self.meta_fields or k == "id"
                }
                for id_int, d in zip(ids, list_data)
            }
        )
        self._index.add_items(data=embeddings, ids=ids, num_threads=self.num_threads)
        self._current_elements = self._index.get_current_count()
        return ids

    async def query(self, query: str, top_k: int = 5) -> list[dict]:
        """Embed the query and return up to top_k nearest records with metadata.

        Each hit carries 'distance' (cosine distance) and 'similarity' (1 - distance).
        """
        if self._current_elements == 0:
            return []

        top_k = min(top_k, self._current_elements)

        # ef must be >= k for hnswlib to return k results.
        # NOTE(review): ef is raised here but never lowered back to ef_search.
        if top_k > self.ef_search:
            logger.warning(
                f"Setting ef_search to {top_k} because top_k is larger than ef_search"
            )
            self._index.set_ef(top_k)

        embedding = await self.embedding_func([query])
        labels, distances = self._index.knn_query(
            data=embedding[0], k=top_k, num_threads=self.num_threads
        )

        return [
            {
                **self._metadata.get(label, {}),
                "distance": distance,
                "similarity": 1 - distance,
            }
            for label, distance in zip(labels[0], distances[0])
        ]

    async def index_done_callback(self):
        # Persist both the index and its metadata/count together so the two
        # files stay in sync for the next load.
        self._index.save_index(self._index_file_name)
        with open(self._metadata_file_name, "wb") as f:
            pickle.dump((self._metadata, self._current_elements), f)
nano-graphrag/nano_graphrag/_storage/vdb_nanovectordb.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ from dataclasses import dataclass
4
+ import numpy as np
5
+ from nano_vectordb import NanoVectorDB
6
+
7
+ from .._utils import logger
8
+ from ..base import BaseVectorStorage
9
+
10
+
11
+ @dataclass
12
+ class NanoVectorDBStorage(BaseVectorStorage):
13
+ cosine_better_than_threshold: float = 0.2
14
+
15
+ def __post_init__(self):
16
+
17
+ self._client_file_name = os.path.join(
18
+ self.global_config["working_dir"], f"vdb_{self.namespace}.json"
19
+ )
20
+ self._max_batch_size = self.global_config["embedding_batch_num"]
21
+ self._client = NanoVectorDB(
22
+ self.embedding_func.embedding_dim, storage_file=self._client_file_name
23
+ )
24
+ self.cosine_better_than_threshold = self.global_config.get(
25
+ "query_better_than_threshold", self.cosine_better_than_threshold
26
+ )
27
+
28
+ async def upsert(self, data: dict[str, dict]):
29
+ logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
30
+ if not len(data):
31
+ logger.warning("You insert an empty data to vector DB")
32
+ return []
33
+ list_data = [
34
+ {
35
+ "__id__": k,
36
+ **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
37
+ }
38
+ for k, v in data.items()
39
+ ]
40
+ contents = [v["content"] for v in data.values()]
41
+ batches = [
42
+ contents[i : i + self._max_batch_size]
43
+ for i in range(0, len(contents), self._max_batch_size)
44
+ ]
45
+ embeddings_list = await asyncio.gather(
46
+ *[self.embedding_func(batch) for batch in batches]
47
+ )
48
+ embeddings = np.concatenate(embeddings_list)
49
+ for i, d in enumerate(list_data):
50
+ d["__vector__"] = embeddings[i]
51
+ results = self._client.upsert(datas=list_data)
52
+ return results
53
+
54
+ async def query(self, query: str, top_k=5):
55
+ embedding = await self.embedding_func([query])
56
+ embedding = embedding[0]
57
+ results = self._client.query(
58
+ query=embedding,
59
+ top_k=top_k,
60
+ better_than_threshold=self.cosine_better_than_threshold,
61
+ )
62
+ results = [
63
+ {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
64
+ ]
65
+ return results
66
+
67
+ async def index_done_callback(self):
68
+ self._client.save()
nano-graphrag/nano_graphrag/_utils.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import html
3
+ import json
4
+ import logging
5
+ import os
6
+ import re
7
+ import numbers
8
+ from dataclasses import dataclass
9
+ from functools import wraps
10
+ from hashlib import md5
11
+ from typing import Any, Union, Literal
12
+
13
+ import numpy as np
14
+ import tiktoken
15
+
16
+
17
+ from transformers import AutoTokenizer
18
+
19
+ logger = logging.getLogger("nano-graphrag")
20
+ logging.getLogger("neo4j").setLevel(logging.ERROR)
21
+
22
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """Return the current event loop, creating and installing one when absent.

    Falls back to a fresh loop when get_event_loop() raises (e.g. in a
    sub-thread with no loop set).
    """
    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        # No loop for this thread: create one and register it.
        logger.info("Creating a new event loop in a sub-thread.")
        fresh_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(fresh_loop)
        return fresh_loop
32
+
33
+
34
def extract_first_complete_json(s: str):
    """Extract the first complete JSON object from the string using a stack to track braces.

    Returns the parsed dict, or None when no balanced ``{...}`` exists or the
    first balanced candidate fails to parse (it does not scan further).
    """
    stack = []
    first_json_start = None

    for i, char in enumerate(s):
        if char == '{':
            stack.append(i)
            if first_json_start is None:
                first_json_start = i
        elif char == '}':
            if stack:
                start = stack.pop()
                if not stack:
                    # Stack emptied: s[first_json_start:i+1] is a balanced
                    # top-level object candidate.
                    first_json_str = s[first_json_start:i+1]
                    try:
                        # Attempt to parse the JSON string
                        # NOTE(review): stripping "\n" also removes newlines
                        # inside string values — confirm that is acceptable.
                        return json.loads(first_json_str.replace("\n", ""))
                    except json.JSONDecodeError as e:
                        logger.error(f"JSON decoding failed: {e}. Attempted string: {first_json_str[:50]}...")
                        return None
                    finally:
                        first_json_start = None
    logger.warning("No complete JSON object found in the input string.")
    return None
59
+
60
def parse_value(value: str):
    """Convert a string value to its appropriate type (int, float, bool, None, or keep as string). Work as a more broad 'eval()'"""
    text = value.strip()

    # JSON-style literals first.
    if text == "null":
        return None
    if text == "true":
        return True
    if text == "false":
        return False

    # Numeric conversion: a dot signals float, otherwise try int.
    try:
        return float(text) if '.' in text else int(text)
    except ValueError:
        # Not numeric: treat as a string, dropping surrounding quotes.
        return text.strip('"')
80
+
81
def extract_values_from_json(json_string, keys=["reasoning", "answer", "data"], allow_no_quotes=False):
    """Extract key values from a non-standard or malformed JSON string, handling nested objects.

    NOTE(review): the `keys` and `allow_no_quotes` parameters are accepted but
    never used by the body — every key the regex finds is extracted
    regardless; confirm whether filtering was intended. (`keys` is also a
    mutable default argument; harmless while unused, but worth fixing.)
    """
    extracted_values = {}

    # Enhanced pattern to match both quoted and unquoted values, as well as nested objects
    regex_pattern = r'(?P<key>"?\w+"?)\s*:\s*(?P<value>{[^}]*}|".*?"|[^,}]+)'

    for match in re.finditer(regex_pattern, json_string, re.DOTALL):
        key = match.group('key').strip('"')  # Strip quotes from key
        value = match.group('value').strip()

        # If the value is another nested JSON (starts with '{' and ends with '}'), recursively parse it
        if value.startswith('{') and value.endswith('}'):
            extracted_values[key] = extract_values_from_json(value)
        else:
            # Parse the value into the appropriate type (int, float, bool, etc.)
            extracted_values[key] = parse_value(value)

    if not extracted_values:
        logger.warning("No values could be extracted from the string.")

    return extracted_values
103
+
104
+
105
def convert_response_to_json(response: str) -> dict:
    """Convert response string to JSON, with error handling and fallback to non-standard JSON extraction."""
    # First try to locate and parse a well-formed JSON object.
    prediction_json = extract_first_complete_json(response)

    if prediction_json is None:
        logger.info("Attempting to extract values from a non-standard JSON string...")
        # Fallback: regex-based salvage of key/value pairs from malformed JSON.
        prediction_json = extract_values_from_json(response, allow_no_quotes=True)

    # May still be empty ({}) when nothing could be salvaged; callers should
    # treat a falsy result as a failed extraction.
    if not prediction_json:
        logger.error("Unable to extract meaningful data from the response.")
    else:
        logger.info("JSON data successfully extracted.")

    return prediction_json
119
+
120
+
121
+
122
+
123
class TokenizerWrapper:
    """Uniform encode/decode facade over tiktoken and HuggingFace tokenizers."""

    def __init__(self, tokenizer_type: Literal["tiktoken", "huggingface"] = "tiktoken", model_name: str = "gpt-4o"):
        self.tokenizer_type = tokenizer_type
        self.model_name = model_name
        self._tokenizer = None
        # Loads eagerly despite the "lazy" name; later calls are no-ops.
        self._lazy_load_tokenizer()

    def _lazy_load_tokenizer(self):
        """Instantiate the underlying tokenizer once; subsequent calls return immediately."""
        if self._tokenizer is not None:
            return
        logger.info(f"Loading tokenizer: type='{self.tokenizer_type}', name='{self.model_name}'")
        if self.tokenizer_type == "tiktoken":
            self._tokenizer = tiktoken.encoding_for_model(self.model_name)
        elif self.tokenizer_type == "huggingface":
            # NOTE(review): AutoTokenizer is imported unconditionally at the
            # top of this module, so this guard can never fire as written —
            # it suggests the import was meant to be optional; confirm.
            if AutoTokenizer is None:
                raise ImportError("`transformers` is not installed. Please install it via `pip install transformers` to use HuggingFace tokenizers.")
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
        else:
            raise ValueError(f"Unknown tokenizer_type: {self.tokenizer_type}")

    def get_tokenizer(self):
        """Provide access to the underlying tokenizer object for special cases (e.g. decode_batch)."""
        self._lazy_load_tokenizer()
        return self._tokenizer

    def encode(self, text: str) -> list[int]:
        """Encode text into a list of token ids."""
        self._lazy_load_tokenizer()
        return self._tokenizer.encode(text)

    def decode(self, tokens: list[int]) -> str:
        """Decode a list of token ids back into text."""
        self._lazy_load_tokenizer()
        return self._tokenizer.decode(tokens)

    # Batch decoding helper, added for efficiency and interface consistency.
    def decode_batch(self, tokens_list: list[list[int]]) -> list[str]:
        self._lazy_load_tokenizer()
        # HuggingFace tokenizers provide native batch decoding; tiktoken does
        # not, so emulate it with a list comprehension.
        if self.tokenizer_type == "tiktoken":
            return [self._tokenizer.decode(tokens) for tokens in tokens_list]
        elif self.tokenizer_type == "huggingface":
            return self._tokenizer.batch_decode(tokens_list, skip_special_tokens=True)
        else:
            raise ValueError(f"Unknown tokenizer_type: {self.tokenizer_type}")
+ raise ValueError(f"Unknown tokenizer_type: {self.tokenizer_type}")
166
+
167
+
168
+
169
def truncate_list_by_token_size(
    list_data: list,
    key: callable,
    max_token_size: int,
    tokenizer_wrapper: TokenizerWrapper,
):
    """Return the longest prefix of *list_data* whose total token count stays
    within *max_token_size*.

    Each item contributes ``len(encode(key(item))) + 1`` tokens; the extra
    token defensively models joining the items with a newline separator.
    """
    if max_token_size <= 0:
        return []
    used = 0
    for idx, item in enumerate(list_data):
        used += len(tokenizer_wrapper.encode(key(item))) + 1
        if used > max_token_size:
            # Exclude the item that pushed us over the budget.
            return list_data[:idx]
    return list_data
184
+
185
+
186
def compute_mdhash_id(content, prefix: str = ""):
    """Return a deterministic id for *content*: optional prefix + MD5 hex digest."""
    digest = md5(content.encode()).hexdigest()
    return f"{prefix}{digest}"
188
+
189
+
190
def write_json(json_obj, file_name):
    """Serialize *json_obj* to *file_name* as pretty-printed UTF-8 JSON."""
    with open(file_name, "w", encoding="utf-8") as fp:
        json.dump(json_obj, fp, indent=2, ensure_ascii=False)
193
+
194
+
195
def load_json(file_name):
    """Read a UTF-8 JSON file and return its contents, or None if the file is missing."""
    if not os.path.exists(file_name):
        return None
    with open(file_name, encoding="utf-8") as fp:
        return json.load(fp)
200
+
201
+
202
+ # it's dirty to type, so it's a good way to have fun
203
def pack_user_ass_to_openai_messages(prompt: str, generated_content: str, using_amazon_bedrock: bool):
    """Build a two-message (user, assistant) history in the target API's shape.

    Amazon Bedrock expects each message content to be a list of text parts,
    while the OpenAI chat API takes plain strings.
    """
    if using_amazon_bedrock:
        user_content = [{"text": prompt}]
        assistant_content = [{"text": generated_content}]
    else:
        user_content = prompt
        assistant_content = generated_content
    return [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": assistant_content},
    ]
214
+
215
+
216
def is_float_regex(value):
    """Return True when *value* looks like a (signed) decimal number string."""
    # Keep `re.match` with `$` (not fullmatch) to preserve the original
    # tolerance for a trailing newline.
    return re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value) is not None
218
+
219
+
220
def compute_args_hash(*args):
    """Return the MD5 hex digest of the stringified positional-argument tuple."""
    serialized = str(args).encode()
    return md5(serialized).hexdigest()
222
+
223
+
224
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split *content* on any of the given markers, dropping empty pieces."""
    if not markers:
        return [content]
    pattern = "|".join(re.escape(marker) for marker in markers)
    stripped = (piece.strip() for piece in re.split(pattern, content))
    return [piece for piece in stripped if piece]
230
+
231
+
232
def enclose_string_with_quotes(content: Any) -> str:
    """Render *content* for CSV-style output: numbers bare, everything else double-quoted."""
    if isinstance(content, numbers.Number):
        return str(content)
    # Drop surrounding whitespace and any pre-existing quotes before re-quoting.
    text = str(content).strip().strip("'").strip('"')
    return f'"{text}"'
239
+
240
+
241
def list_of_list_to_csv(data: list[list]):
    """Render rows of values as a pseudo-CSV string with ',\\t' field separators."""
    rows = [
        ",\t".join(enclose_string_with_quotes(cell) for cell in row)
        for row in data
    ]
    return "\n".join(rows)
248
+
249
+
250
+ # -----------------------------------------------------------------------------------
251
+ # Refer the utils functions of the official GraphRAG implementation:
252
+ # https://github.com/microsoft/graphrag
253
# -----------------------------------------------------------------------------------
# Refer the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
    """Clean a string: unescape HTML entities and drop control characters.

    Non-string inputs are returned unchanged.
    """
    if not isinstance(input, str):
        return input

    unescaped = html.unescape(input.strip())
    # Strip ASCII control characters plus DEL and the C1 control range.
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
    return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", unescaped)
262
+
263
+
264
+ # Utils types -----------------------------------------------------------------------
265
@dataclass
class EmbeddingFunc:
    """An awaitable embedding function bundled with its model limits.

    Instances are called like the wrapped async function and return whatever
    it produces (expected to be an ``np.ndarray`` of embeddings).
    """

    # Dimensionality of the vectors produced by `func`.
    embedding_dim: int
    # Maximum number of tokens `func` accepts per input.
    max_token_size: int
    # The wrapped async callable that actually computes embeddings.
    func: callable

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        """Delegate directly to the wrapped async embedding function."""
        return await self.func(*args, **kwargs)
273
+
274
+
275
+ # Decorators ------------------------------------------------------------------------
276
# Decorators ------------------------------------------------------------------------
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
    """Decorator factory that caps the number of concurrent calls to an async func.

    Args:
        max_size: Maximum number of in-flight calls allowed at once.
        waitting_time: Poll interval (seconds) while waiting for a free slot.
            (Parameter name kept misspelled for backward compatibility.)
    """

    def final_decro(func):
        """Busy-wait on a plain counter instead of asyncio.Semaphore to avoid nest-asyncio."""
        __current_size = 0

        @wraps(func)
        async def wait_func(*args, **kwargs):
            nonlocal __current_size
            while __current_size >= max_size:
                await asyncio.sleep(waitting_time)
            __current_size += 1
            try:
                return await func(*args, **kwargs)
            finally:
                # BUGFIX: always release the slot, even when func raises;
                # previously a failing call never decremented the counter,
                # permanently consuming capacity and eventually deadlocking.
                __current_size -= 1

        return wait_func

    return final_decro
296
+
297
+
298
def wrap_embedding_func_with_attrs(**kwargs):
    """Decorator factory that wraps an async function into an EmbeddingFunc.

    The keyword arguments (e.g. ``embedding_dim``, ``max_token_size``) become
    the resulting EmbeddingFunc's attributes.
    """

    def final_decro(func) -> EmbeddingFunc:
        return EmbeddingFunc(func=func, **kwargs)

    return final_decro
nano-graphrag/nano_graphrag/base.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import TypedDict, Union, Literal, Generic, TypeVar, List
3
+
4
+ import numpy as np
5
+
6
+ from ._utils import EmbeddingFunc
7
+
8
+
9
@dataclass
class QueryParam:
    """Tuning knobs for a GraphRAG query.

    The token budgets roughly split a 12k context window between text units,
    local graph context and community reports (see the inline ratios).
    """

    mode: Literal["local", "global", "naive"] = "global"
    # Return only the assembled context instead of asking the LLM.
    only_need_context: bool = False
    response_type: str = "Multiple Paragraphs"
    # Community hierarchy level to consider.
    level: int = 2
    top_k: int = 20
    # naive search
    # BUGFIX: annotation added so this is a real dataclass field; it was a
    # plain class attribute before (unannotated), so it could not be set via
    # the constructor like its siblings.
    naive_max_token_for_text_unit: int = 12000
    # local search
    local_max_token_for_text_unit: int = 4000  # 12000 * 0.33
    local_max_token_for_local_context: int = 4800  # 12000 * 0.4
    local_max_token_for_community_report: int = 3200  # 12000 * 0.27
    local_community_single_one: bool = False
    # global search
    global_min_community_rating: float = 0
    # NOTE(review): used as a count but annotated float; kept as-is for
    # backward compatibility.
    global_max_consider_community: float = 512
    global_max_token_for_community_report: int = 16384
    global_special_community_map_llm_kwargs: dict = field(
        default_factory=lambda: {"response_format": {"type": "json_object"}}
    )
30
+
31
+
32
class TextChunkSchema(TypedDict):
    """A chunk of a source document plus its bookkeeping metadata."""

    tokens: int
    content: str
    full_doc_id: str
    chunk_order_index: int
36
+
37
# A community detected in the entity graph, before report generation.
SingleCommunitySchema = TypedDict(
    "SingleCommunitySchema",
    {
        "level": int,
        "title": str,
        # Each edge is a [source, target] node-id pair.
        # BUGFIX: was `list[list[str, str]]` — builtin `list` takes a single
        # type parameter; multiple arguments are invalid typing.
        "edges": list[list[str]],
        "nodes": list[str],
        "chunk_ids": list[str],
        "occurrence": float,
        "sub_communities": list[str],
    },
)
49
+
50
+
51
class CommunitySchema(SingleCommunitySchema):
    """A community enriched with its generated report."""

    # Rendered text form of the community report.
    report_string: str
    # Structured (JSON) form of the community report.
    report_json: dict
54
+
55
+
56
# Generic value type stored by BaseKVStorage.
T = TypeVar("T")
57
+
58
+
59
@dataclass
class StorageNameSpace:
    """Base class for all storage backends: a namespace plus the global config dict."""

    # Logical name of this storage instance.
    namespace: str
    # The GraphRAG-wide configuration dictionary.
    global_config: dict

    async def index_start_callback(self):
        """Hook invoked when indexing starts (no-op by default)."""
        pass

    async def index_done_callback(self):
        """commit the storage operations after indexing"""
        pass

    async def query_done_callback(self):
        """commit the storage operations after querying"""
        pass
75
+
76
+
77
@dataclass
class BaseVectorStorage(StorageNameSpace):
    """Abstract interface for vector storages used for similarity search."""

    # Async embedding function used to vectorize texts.
    embedding_func: EmbeddingFunc
    # Names of extra metadata fields implementations should persist
    # alongside each entry (presumably copied from the upserted values —
    # TODO confirm against concrete implementations).
    meta_fields: set = field(default_factory=set)

    async def query(self, query: str, top_k: int) -> list[dict]:
        """Return the top_k most similar stored entries for *query*."""
        raise NotImplementedError

    async def upsert(self, data: dict[str, dict]):
        """Use 'content' field from value for embedding, use key as id.
        If embedding_func is None, use 'embedding' field from value
        """
        raise NotImplementedError
90
+
91
+
92
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
    """Abstract key-value storage holding values of type T."""

    async def all_keys(self) -> list[str]:
        """Return every key currently stored."""
        raise NotImplementedError

    async def get_by_id(self, id: str) -> Union[T, None]:
        """Return the value stored under *id*, or None when absent."""
        raise NotImplementedError

    async def get_by_ids(
        self, ids: list[str], fields: Union[set[str], None] = None
    ) -> list[Union[T, None]]:
        """Return the values for *ids* in order (None for misses),
        optionally restricted to the given *fields*."""
        raise NotImplementedError

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the subset of *data* keys that do NOT yet exist in storage."""
        raise NotImplementedError

    async def upsert(self, data: dict[str, T]):
        """Insert or update the given key/value pairs."""
        raise NotImplementedError

    async def drop(self):
        """Delete all data held by this storage namespace."""
        raise NotImplementedError
114
+
115
+
116
@dataclass
class BaseGraphStorage(StorageNameSpace):
    """Abstract interface for the knowledge-graph storage backend.

    Implementations provide node/edge CRUD (single and batch), clustering
    into communities, and (optionally) node embedding.
    """

    async def has_node(self, node_id: str) -> bool:
        raise NotImplementedError

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        raise NotImplementedError

    async def node_degree(self, node_id: str) -> int:
        raise NotImplementedError

    # BUGFIX: return annotation was `List[str]`; degrees are integers, as
    # established by `node_degree` and `edge_degrees_batch`.
    async def node_degrees_batch(self, node_ids: List[str]) -> List[int]:
        raise NotImplementedError

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        raise NotImplementedError

    async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
        raise NotImplementedError

    async def get_node(self, node_id: str) -> Union[dict, None]:
        raise NotImplementedError

    async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, Union[dict, None]]:
        raise NotImplementedError

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        raise NotImplementedError

    async def get_edges_batch(
        self, edge_pairs: list[tuple[str, str]]
    ) -> list[Union[dict, None]]:
        raise NotImplementedError

    async def get_node_edges(
        self, source_node_id: str
    ) -> Union[list[tuple[str, str]], None]:
        raise NotImplementedError

    async def get_nodes_edges_batch(
        self, node_ids: list[str]
    ) -> list[list[tuple[str, str]]]:
        raise NotImplementedError

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        raise NotImplementedError

    async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
        raise NotImplementedError

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        raise NotImplementedError

    async def upsert_edges_batch(
        self, edges_data: list[tuple[str, str, dict[str, str]]]
    ):
        raise NotImplementedError

    async def clustering(self, algorithm: str):
        """Run the named community-clustering algorithm over the graph."""
        raise NotImplementedError

    async def community_schema(self) -> dict[str, SingleCommunitySchema]:
        """Return the community representation with report and nodes"""
        raise NotImplementedError

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        raise NotImplementedError("Node embedding is not used in nano-graphrag.")
nano-graphrag/nano_graphrag/entity_extraction/__init__.py ADDED
File without changes
nano-graphrag/nano_graphrag/entity_extraction/extract.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+ import pickle
3
+ import asyncio
4
+ from openai import BadRequestError
5
+ from collections import defaultdict
6
+ import dspy
7
+ from nano_graphrag.base import (
8
+ BaseGraphStorage,
9
+ BaseVectorStorage,
10
+ TextChunkSchema,
11
+ )
12
+ from nano_graphrag.prompt import PROMPTS
13
+ from nano_graphrag._utils import logger, compute_mdhash_id
14
+ from nano_graphrag.entity_extraction.module import TypedEntityRelationshipExtractor
15
+ from nano_graphrag._op import _merge_edges_then_upsert, _merge_nodes_then_upsert
16
+
17
+
18
async def generate_dataset(
    chunks: dict[str, TextChunkSchema],
    filepath: str,
    save_dataset: bool = True,
    global_config: dict = None,
) -> list[dspy.Example]:
    """Run dspy entity/relationship extraction over *chunks* and build a dataset.

    Args:
        chunks: Mapping of chunk id -> chunk payload (the 'content' field is used).
        filepath: Where to pickle the dataset when *save_dataset* is True.
        save_dataset: Persist the filtered examples to *filepath*.
        global_config: Optional config; may point at a compiled dspy module.

    Returns:
        Examples that produced at least one entity AND one relationship.
    """
    # BUGFIX: the default used to be a shared mutable dict (`{}`).
    if global_config is None:
        global_config = {}

    entity_extractor = TypedEntityRelationshipExtractor(num_refine_turns=1, self_refine=True)

    if global_config.get("use_compiled_dspy_entity_relationship", False):
        entity_extractor.load(global_config["entity_relationship_module_path"])

    ordered_chunks = list(chunks.items())
    # Progress counters shared by the per-chunk workers below.
    already_processed = 0
    already_entities = 0
    already_relations = 0

    async def _process_single_content(
        chunk_key_dp: tuple[str, TextChunkSchema]
    ) -> dspy.Example:
        nonlocal already_processed, already_entities, already_relations
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        try:
            # The extractor is synchronous; run it off the event loop thread.
            prediction = await asyncio.to_thread(entity_extractor, input_text=content)
            entities, relationships = prediction.entities, prediction.relationships
        except BadRequestError as e:
            logger.error(f"Error in TypedEntityRelationshipExtractor: {e}")
            entities, relationships = [], []
        example = dspy.Example(
            input_text=content, entities=entities, relationships=relationships
        ).with_inputs("input_text")
        already_entities += len(entities)
        already_relations += len(relationships)
        already_processed += 1
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return example

    examples = await asyncio.gather(
        *[_process_single_content(c) for c in ordered_chunks]
    )
    # Drop chunks where extraction produced nothing usable.
    filtered_examples = [
        example
        for example in examples
        if len(example.entities) > 0 and len(example.relationships) > 0
    ]
    num_filtered_examples = len(examples) - len(filtered_examples)
    if save_dataset:
        with open(filepath, "wb") as f:
            pickle.dump(filtered_examples, f)
        if filtered_examples:
            logger.info(
                f"Saved {len(filtered_examples)} examples with keys: {filtered_examples[0].keys()}, filtered {num_filtered_examples} examples"
            )
        else:
            # BUGFIX: the log line above indexed filtered_examples[0] and
            # raised IndexError whenever every example was filtered out.
            logger.warning(
                f"Saved 0 examples, filtered {num_filtered_examples} examples"
            )

    return filtered_examples
79
+
80
+
81
async def extract_entities_dspy(
    chunks: dict[str, TextChunkSchema],
    knwoledge_graph_inst: BaseGraphStorage,
    entity_vdb: BaseVectorStorage,
    global_config: dict,
) -> Union[BaseGraphStorage, None]:
    """Extract entities/relationships from chunks via dspy and upsert them.

    Args:
        chunks: Mapping of chunk id -> chunk payload ('content' is extracted from).
        knwoledge_graph_inst: Graph storage that nodes/edges are merged into.
            (NOTE(review): parameter name is misspelled but kept — it is the
            public interface.)
        entity_vdb: Vector storage for entity embeddings; skipped when None.
        global_config: May reference a compiled dspy entity-relationship module.

    Returns:
        The graph storage on success, or None when no entities were extracted.
    """
    entity_extractor = TypedEntityRelationshipExtractor(num_refine_turns=1, self_refine=True)

    if global_config.get("use_compiled_dspy_entity_relationship", False):
        entity_extractor.load(global_config["entity_relationship_module_path"])

    ordered_chunks = list(chunks.items())
    # Progress counters shared by the per-chunk workers below.
    already_processed = 0
    already_entities = 0
    already_relations = 0

    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
        nonlocal already_processed, already_entities, already_relations
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        try:
            # The dspy extractor is synchronous; run it in a worker thread.
            prediction = await asyncio.to_thread(entity_extractor, input_text=content)
            entities, relationships = prediction.entities, prediction.relationships
        except BadRequestError as e:
            logger.error(f"Error in TypedEntityRelationshipExtractor: {e}")
            entities, relationships = [], []

        # Group candidates by node name / (src, tgt) pair; duplicates from
        # different chunks are merged later.
        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)

        for entity in entities:
            entity["source_id"] = chunk_key  # remember which chunk produced it
            maybe_nodes[entity["entity_name"]].append(entity)
            already_entities += 1

        for relationship in relationships:
            relationship["source_id"] = chunk_key
            maybe_edges[(relationship["src_id"], relationship["tgt_id"])].append(
                relationship
            )
            already_relations += 1

        already_processed += 1
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return dict(maybe_nodes), dict(maybe_edges)

    results = await asyncio.gather(
        *[_process_single_content(c) for c in ordered_chunks]
    )
    print()  # newline after the in-place '\r' progress ticker
    # Merge the per-chunk candidate maps into global ones.
    maybe_nodes = defaultdict(list)
    maybe_edges = defaultdict(list)
    for m_nodes, m_edges in results:
        for k, v in m_nodes.items():
            maybe_nodes[k].extend(v)
        for k, v in m_edges.items():
            maybe_edges[k].extend(v)
    # Merge duplicate nodes/edges and persist them in the graph storage.
    all_entities_data = await asyncio.gather(
        *[
            _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
            for k, v in maybe_nodes.items()
        ]
    )
    await asyncio.gather(
        *[
            _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
            for k, v in maybe_edges.items()
        ]
    )
    if not len(all_entities_data):
        logger.warning("Didn't extract any entities, maybe your LLM is not working")
        return None
    if entity_vdb is not None:
        # Index each entity in the vector store; the embedded text is the
        # entity name concatenated with its merged description.
        data_for_vdb = {
            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                "content": dp["entity_name"] + dp["description"],
                "entity_name": dp["entity_name"],
            }
            for dp in all_entities_data
        }
        await entity_vdb.upsert(data_for_vdb)

    return knwoledge_graph_inst
nano-graphrag/nano_graphrag/entity_extraction/metric.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dspy
2
+ from nano_graphrag.entity_extraction.module import Relationship
3
+
4
+
5
# NOTE: in dspy, a Signature's docstring is the instruction prompt sent to the
# LLM, and the field `desc` strings are part of the prompt too — do not edit
# them casually; changing them changes runtime behavior.
class AssessRelationships(dspy.Signature):
    """
    Assess the similarity between gold and predicted relationships:
    1. Match relationships based on src_id and tgt_id pairs, allowing for slight variations in entity names.
    2. For matched pairs, compare:
       a) Description similarity (semantic meaning)
       b) Weight similarity
       c) Order similarity
    3. Consider unmatched relationships as penalties.
    4. Aggregate scores, accounting for precision and recall.
    5. Return a final similarity score between 0 (no similarity) and 1 (perfect match).

    Key considerations:
    - Prioritize matching based on entity pairs over exact string matches.
    - Use semantic similarity for descriptions rather than exact matches.
    - Weight the importance of different aspects (e.g., entity matching, description, weight, order).
    - Balance the impact of matched and unmatched relationships in the final score.
    """

    gold_relationships: list[Relationship] = dspy.InputField(
        desc="The gold-standard relationships to compare against."
    )
    predicted_relationships: list[Relationship] = dspy.InputField(
        desc="The predicted relationships to compare against the gold-standard relationships."
    )
    similarity_score: float = dspy.OutputField(
        desc="Similarity score between 0 and 1, with 1 being the highest similarity."
    )
33
+
34
+
35
def relationships_similarity_metric(
    gold: dspy.Example, pred: dspy.Prediction, trace=None
) -> float:
    """LLM-judged similarity (0..1) between gold and predicted relationships.

    The `trace` parameter is unused but required by the dspy metric protocol.
    """
    judge = dspy.ChainOfThought(AssessRelationships)
    gold_rels = [Relationship(**item) for item in gold["relationships"]]
    pred_rels = [Relationship(**item) for item in pred["relationships"]]
    verdict = judge(
        gold_relationships=gold_rels,
        predicted_relationships=pred_rels,
    )
    return float(verdict.similarity_score)
48
+
49
+
50
def entity_recall_metric(
    gold: dspy.Example, pred: dspy.Prediction, trace=None
) -> float:
    """Recall of predicted entity names against the gold entity names.

    The `trace` parameter is unused but required by the dspy metric protocol.
    """
    gold_names = {item["entity_name"] for item in gold["entities"]}
    pred_names = {item["entity_name"] for item in pred["entities"]}
    hits = len(gold_names & pred_names)
    misses = len(gold_names - pred_names)
    total = hits + misses
    # An empty gold set yields 0 rather than dividing by zero.
    return hits / total if total > 0 else 0