Spaces:
Sleeping
Sleeping
| import os | |
| import asyncio | |
| import uuid | |
| import numpy as np | |
| from nano_graphrag import GraphRAG, QueryParam | |
| from nano_graphrag._utils import logger | |
| from nano_graphrag.base import BaseVectorStorage | |
| from dataclasses import dataclass | |
| try: | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import VectorParams, Distance, PointStruct, SearchParams | |
| except ImportError as original_error: | |
| raise ImportError( | |
| "Qdrant client is not installed. Install it using: pip install qdrant-client\n" | |
| ) from original_error | |
| class QdrantStorage(BaseVectorStorage): | |
| def __post_init__(self): | |
| # Use a local file-based Qdrant storage | |
| # Useful for prototyping and CI. | |
| # For production, refer to: | |
| # https://qdrant.tech/documentation/guides/installation/ | |
| self._client_file_path = os.path.join( | |
| self.global_config["working_dir"], "qdrant_storage" | |
| ) | |
| self._client = QdrantClient(path=self._client_file_path) | |
| self._max_batch_size = self.global_config["embedding_batch_num"] | |
| if not self._client.collection_exists(collection_name=self.namespace): | |
| self._client.create_collection( | |
| collection_name=self.namespace, | |
| vectors_config=VectorParams( | |
| size=self.embedding_func.embedding_dim, distance=Distance.COSINE | |
| ), | |
| ) | |
| async def upsert(self, data: dict[str, dict]): | |
| logger.info(f"Inserting {len(data)} vectors to {self.namespace}") | |
| list_data = [ | |
| { | |
| "id": k, | |
| **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, | |
| } | |
| for k, v in data.items() | |
| ] | |
| contents = [v["content"] for v in data.values()] | |
| batches = [ | |
| contents[i : i + self._max_batch_size] | |
| for i in range(0, len(contents), self._max_batch_size) | |
| ] | |
| embeddings_list = await asyncio.gather( | |
| *[self.embedding_func(batch) for batch in batches] | |
| ) | |
| embeddings = np.concatenate(embeddings_list) | |
| points = [ | |
| PointStruct( | |
| id=uuid.uuid4().hex, | |
| vector=embeddings[i].tolist(), | |
| payload=data, | |
| ) | |
| for i, data in enumerate(list_data) | |
| ] | |
| results = self._client.upsert(collection_name=self.namespace, points=points) | |
| return results | |
| async def query(self, query, top_k=5): | |
| embedding = await self.embedding_func([query]) | |
| results = self._client.query_points( | |
| collection_name=self.namespace, | |
| query=embedding[0].tolist(), | |
| limit=top_k, | |
| ).points | |
| return [ | |
| {**result.payload, "score": result.score} | |
| for result in results | |
| ] | |
| def insert(): | |
| data = ["YOUR TEXT DATA HERE", "YOUR TEXT DATA HERE"] | |
| rag = GraphRAG( | |
| working_dir="./nano_graphrag_cache_qdrant_TEST", | |
| enable_llm_cache=True, | |
| vector_db_storage_cls=QdrantStorage, | |
| ) | |
| rag.insert(data) | |
| def query(): | |
| rag = GraphRAG( | |
| working_dir="./nano_graphrag_cache_qdrant_TEST", | |
| enable_llm_cache=True, | |
| vector_db_storage_cls=QdrantStorage, | |
| ) | |
| print(rag.query("YOUR QUERY HERE", param=QueryParam(mode="local"))) | |
| if __name__ == "__main__": | |
| insert() | |
| query() | |