# Source: borges-graph / nano-graphrag / examples / using_dspy_entity_extraction.py
# Committed by ArthurSrz
# Commit 70ab3b6: "feat: Add complete nano-graphrag source code"
import os
from openai import AsyncOpenAI
from dotenv import load_dotenv
import logging
import numpy as np
import dspy
from sentence_transformers import SentenceTransformer
from nano_graphrag import GraphRAG, QueryParam
from nano_graphrag._llm import gpt_4o_mini_complete
from nano_graphrag._storage import HNSWVectorStorage
from nano_graphrag.base import BaseKVStorage
from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
from nano_graphrag.entity_extraction.extract import extract_entities_dspy
# Directory used both as GraphRAG's working_dir and as the embedding model's download cache.
WORKING_DIR = "./nano_graphrag_cache_using_dspy_entity_extraction"

# Root logger at WARNING keeps other libraries quiet; nano-graphrag's own logger stays verbose.
logging.basicConfig(level=logging.WARNING)
logging.getLogger("nano-graphrag").setLevel(logging.DEBUG)

# Load variables from a local .env file so the os.environ lookups further down
# (e.g. DEEPSEEK_API_KEY) can be satisfied without exporting them in the shell.
load_dotenv()

# Local sentence-embedding model, forced onto the CPU and cached under WORKING_DIR.
EMBED_MODEL = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device="cpu",
    cache_folder=WORKING_DIR,
)
@wrap_embedding_func_with_attrs(
    embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
    max_token_size=EMBED_MODEL.max_seq_length,
)
async def local_embedding(texts: list[str]) -> np.ndarray:
    """Embed ``texts`` with the module-level SentenceTransformer model.

    The decorator attaches the model's embedding dimension and maximum
    sequence length so GraphRAG can size its vector storage accordingly.

    Args:
        texts: Strings to embed.

    Returns:
        The embeddings produced by ``EMBED_MODEL.encode`` with
        ``normalize_embeddings=True`` (unit-length vectors).
    """
    # NOTE: encode() is synchronous, so the event loop is blocked while it runs.
    vectors = EMBED_MODEL.encode(texts, normalize_embeddings=True)
    return vectors
async def deepseepk_model_if_cache(
    prompt, model: str = "deepseek-chat", system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete ``prompt`` with a DeepSeek chat model, using an optional response cache.

    Args:
        prompt: The user message to send.
        model: Model name passed to the chat-completions endpoint.
        system_prompt: Optional system message placed first in the conversation.
        history_messages: Optional prior chat messages (OpenAI-style dicts)
            inserted between the system prompt and ``prompt``. ``None`` means
            no history.
        **kwargs: Forwarded to ``chat.completions.create``. A ``hashing_kv``
            entry (a ``BaseKVStorage``) is removed first and used as the cache.

    Returns:
        The assistant message content, either from the cache or from the API.
    """
    # Remove the cache handle before kwargs are forwarded to the OpenAI client.
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # ``None`` default avoids a shared mutable default list.
    messages.extend(history_messages or [])
    messages.append({"role": "user", "content": prompt})

    # Serve from the cache when possible, before any client is created.
    args_hash = None
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    # Honour DEEPSEEK_BASE_URL (the variable __main__ uses for the DSPy LM),
    # falling back to the public endpoint. The context manager closes the
    # client so each call does not leak an HTTP connection pool.
    async with AsyncOpenAI(
        api_key=os.environ.get("DEEPSEEK_API_KEY"),
        base_url=os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
    ) as openai_async_client:
        response = await openai_async_client.chat.completions.create(
            model=model, messages=messages, **kwargs
        )
    content = response.choices[0].message.content

    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": content, "model": model}})
    return content
def remove_if_exist(file):
    """Delete ``file`` if it exists; a missing file is silently ignored.

    Args:
        file: Path of the file to delete.

    Raises:
        OSError: For failures other than the file being absent (e.g. the path
            is a directory or permission is denied).
    """
    # EAFP: attempting the removal directly avoids the race where the file
    # disappears between an exists() check and os.remove().
    try:
        os.remove(file)
    except FileNotFoundError:
        pass
def insert():
    """Rebuild the graph index from ``./tests/mock_data.txt`` and print the time taken.

    Previously persisted storage files in ``WORKING_DIR`` are deleted first so
    indexing starts from scratch. Files not listed below (for example any LLM
    response cache written because ``enable_llm_cache=True``) are left in place.
    DeepSeek is used for both model roles, with DSPy-based entity extraction.
    """
    from time import perf_counter

    # utf-8-sig strips a leading BOM if the file has one.
    with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
        fake_text = f.read()

    # NOTE(review): vdb_entities.json looks like the default JSON vector store's
    # file; with HNSWVectorStorage the index may be persisted under other names,
    # so confirm this really resets the vector index.
    for stale_file in (
        "vdb_entities.json",
        "kv_store_full_docs.json",
        "kv_store_text_chunks.json",
        "kv_store_community_reports.json",
        "graph_chunk_entity_relation.graphml",
    ):
        remove_if_exist(f"{WORKING_DIR}/{stale_file}")

    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        vector_db_storage_cls=HNSWVectorStorage,
        vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50},
        best_model_max_async=10,
        cheap_model_max_async=10,
        best_model_func=deepseepk_model_if_cache,
        cheap_model_func=deepseepk_model_if_cache,
        embedding_func=local_embedding,
        entity_extraction_func=extract_entities_dspy,
    )

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = perf_counter()
    rag.insert(fake_text)
    print("indexing time:", perf_counter() - start)
def query():
    """Query the index in ``WORKING_DIR`` and print answers in global, then local, mode.

    Answers are generated with ``gpt_4o_mini_complete`` rather than the DeepSeek
    function used by ``insert()``; presumably this needs OpenAI credentials — confirm.
    """
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        vector_db_storage_cls=HNSWVectorStorage,
        vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50},
        # NOTE(review): 8196 looks like a typo for 8192 — confirm the intended limit.
        best_model_max_token_size=8196,
        cheap_model_max_token_size=8196,
        best_model_max_async=4,
        cheap_model_max_async=4,
        best_model_func=gpt_4o_mini_complete,
        cheap_model_func=gpt_4o_mini_complete,
        embedding_func=local_embedding,
        entity_extraction_func=extract_entities_dspy,
    )

    question = "What are the top themes in this story?"
    for search_mode in ("global", "local"):
        answer = rag.query(question, param=QueryParam(mode=search_mode))
        print(answer)
if __name__ == "__main__":
    # Required settings; a missing variable raises KeyError before the LM is built.
    api_key = os.environ["DEEPSEEK_API_KEY"]
    base_url = os.environ["DEEPSEEK_BASE_URL"]

    # Register a DeepSeek chat model as DSPy's default LM; presumably this is the
    # model extract_entities_dspy calls during insert() — confirm.
    lm = dspy.LM(
        model="deepseek/deepseek-chat",
        model_type="chat",
        api_provider="openai",
        api_key=api_key,
        base_url=base_url,
        temperature=1.0,
        max_tokens=8192,
    )
    dspy.settings.configure(lm=lm, experimental=True)

    # Build the index, then ask questions against it.
    insert()
    query()