Spaces:
Sleeping
Sleeping
| import os | |
| from openai import AsyncOpenAI | |
| from dotenv import load_dotenv | |
| import logging | |
| import numpy as np | |
| import dspy | |
| from sentence_transformers import SentenceTransformer | |
| from nano_graphrag import GraphRAG, QueryParam | |
| from nano_graphrag._llm import gpt_4o_mini_complete | |
| from nano_graphrag._storage import HNSWVectorStorage | |
| from nano_graphrag.base import BaseKVStorage | |
| from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs | |
| from nano_graphrag.entity_extraction.extract import extract_entities_dspy | |
# Keep the root logger at WARNING but let nano-graphrag's own logger emit DEBUG records.
logging.basicConfig(level=logging.WARNING)
logging.getLogger("nano-graphrag").setLevel(logging.DEBUG)

# Read variables from a local .env file (if any) into os.environ, e.g. the
# DEEPSEEK_* settings consulted later in this script.
load_dotenv()

# Directory for GraphRAG's storage files; also used as the model download cache.
WORKING_DIR = "./nano_graphrag_cache_using_dspy_entity_extraction"

# Local sentence-embedding model, pinned to CPU and cached under WORKING_DIR.
EMBED_MODEL = SentenceTransformer(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
    cache_folder=WORKING_DIR,
    device="cpu",
)
@wrap_embedding_func_with_attrs(
    embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
    max_token_size=EMBED_MODEL.max_seq_length,
)
async def local_embedding(texts: list[str]) -> np.ndarray:
    """Embed ``texts`` with the local SentenceTransformer model.

    The decorator attaches ``embedding_dim`` and ``max_token_size`` metadata.
    nano-graphrag's vector storages (e.g. HNSWVectorStorage) read
    ``embedding_func.embedding_dim`` to size their index, so a bare coroutine
    function cannot be passed as ``embedding_func``.

    Args:
        texts: Strings to embed.

    Returns:
        The embeddings returned by ``EMBED_MODEL.encode``, L2-normalised.
    """
    # NOTE: encode() is synchronous, so the event loop is blocked while it runs.
    return EMBED_MODEL.encode(texts, normalize_embeddings=True)
async def deepseepk_model_if_cache(
    prompt, model: str = "deepseek-chat", system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Call the DeepSeek chat-completions API, memoising replies in ``hashing_kv``.

    Args:
        prompt: User message, appended after any history.
        model: DeepSeek model name.
        system_prompt: Optional system message placed first.
        history_messages: Optional prior chat messages, inserted between the
            system prompt and ``prompt``. The list is copied, never mutated.
        **kwargs: ``hashing_kv`` (a BaseKVStorage used as a response cache) is
            consumed here. Every other keyword is forwarded unchanged to
            ``chat.completions.create``.

    Returns:
        The assistant message content, served from the cache when available.
    """
    if history_messages is None:
        history_messages = []

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Cache lookup, keyed on the model name plus the full message list.
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    args_hash = None
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    # Build the HTTP client only on a cache miss, and close it when done so
    # connections are not leaked per call. DEEPSEEK_BASE_URL (also read by
    # the dspy setup) overrides the default endpoint.
    async with AsyncOpenAI(
        api_key=os.environ.get("DEEPSEEK_API_KEY"),
        base_url=os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
    ) as openai_async_client:
        response = await openai_async_client.chat.completions.create(
            model=model, messages=messages, **kwargs
        )
    content = response.choices[0].message.content

    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": content, "model": model}})
    return content
def remove_if_exist(file):
    """Delete ``file`` if it exists; a missing file is silently ignored.

    Uses EAFP rather than an ``os.path.exists`` pre-check, so there is no window
    between the check and the removal in which the file can vanish.

    Args:
        file: Path (str or path-like) of the file to delete.

    Raises:
        OSError: For failures other than the file being absent, e.g. the path
            is a directory or permissions forbid removal.
    """
    try:
        os.remove(file)
    except FileNotFoundError:
        pass
def insert():
    """Rebuild the GraphRAG index from ``./tests/mock_data.txt`` and print the time taken.

    Storage files from a previous run are deleted first so the rebuild starts
    clean. Files not listed below stay in WORKING_DIR, presumably so the LLM
    response cache (``enable_llm_cache=True``) survives between runs.
    Entity extraction uses ``extract_entities_dspy``; the LLM calls go through
    ``deepseepk_model_if_cache``.
    """
    from time import perf_counter

    with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
        fake_text = f.read()

    for name in (
        "vdb_entities.json",
        # HNSWVectorStorage persists the entity index under these names, not
        # vdb_entities.json; without removing them, stale vectors survive the
        # rebuild. NOTE(review): names follow nano_graphrag's HNSWVectorStorage
        # ("<namespace>_hnsw.index" / "_hnsw_metadata.pkl"); re-check on upgrade.
        "entities_hnsw.index",
        "entities_hnsw_metadata.pkl",
        "kv_store_full_docs.json",
        "kv_store_text_chunks.json",
        "kv_store_community_reports.json",
        "graph_chunk_entity_relation.graphml",
    ):
        remove_if_exist(os.path.join(WORKING_DIR, name))

    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        vector_db_storage_cls=HNSWVectorStorage,
        vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50},
        best_model_max_async=10,
        cheap_model_max_async=10,
        best_model_func=deepseepk_model_if_cache,
        cheap_model_func=deepseepk_model_if_cache,
        embedding_func=local_embedding,
        entity_extraction_func=extract_entities_dspy,
    )

    # perf_counter is monotonic and intended for measuring elapsed intervals.
    start = perf_counter()
    rag.insert(fake_text)
    print("indexing time:", perf_counter() - start)
def query():
    """Open the index in WORKING_DIR and print one answer per query mode.

    The same question is asked in ``global`` mode first, then ``local`` mode.

    NOTE(review): answers are generated with ``gpt_4o_mini_complete``, while
    indexing uses the DeepSeek function, so this step presumably needs OpenAI
    credentials too — confirm that mix is intended. The 8196 token size may
    have been meant as 8192.
    """
    storage_kwargs = dict(max_elements=1_000_000, ef_search=200, M=50)
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        vector_db_storage_cls=HNSWVectorStorage,
        vector_db_storage_cls_kwargs=storage_kwargs,
        embedding_func=local_embedding,
        entity_extraction_func=extract_entities_dspy,
        best_model_func=gpt_4o_mini_complete,
        best_model_max_token_size=8196,
        best_model_max_async=4,
        cheap_model_func=gpt_4o_mini_complete,
        cheap_model_max_token_size=8196,
        cheap_model_max_async=4,
    )

    question = "What are the top themes in this story?"
    for mode in ("global", "local"):
        answer = rag.query(question, param=QueryParam(mode=mode))
        print(answer)
if __name__ == "__main__":
    # Global dspy language model, presumably consumed by extract_entities_dspy
    # during indexing; it targets DeepSeek through litellm's "deepseek/" prefix.
    lm = dspy.LM(
        model="deepseek/deepseek-chat",
        model_type="chat",
        api_provider="openai",
        # The API key stays mandatory: fail fast with KeyError if it is missing.
        api_key=os.environ["DEEPSEEK_API_KEY"],
        # Fall back to DeepSeek's public endpoint (the one the LLM wrapper
        # defaults to) instead of crashing when DEEPSEEK_BASE_URL is unset.
        base_url=os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
        temperature=1.0,
        max_tokens=8192,
    )
    dspy.settings.configure(lm=lm, experimental=True)
    insert()
    query()