File size: 3,396 Bytes
70ab3b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import asyncio
import uuid
import numpy as np
from nano_graphrag import GraphRAG, QueryParam
from nano_graphrag._utils import logger
from nano_graphrag.base import BaseVectorStorage
from dataclasses import dataclass

try:
    from qdrant_client import QdrantClient
    from qdrant_client.models import VectorParams, Distance, PointStruct, SearchParams
except ImportError as original_error:
    raise ImportError(
        "Qdrant client is not installed. Install it using: pip install qdrant-client\n"
    ) from original_error


@dataclass
class QdrantStorage(BaseVectorStorage):
    """nano-graphrag vector storage backed by a local, file-based Qdrant store.

    One Qdrant collection (named after ``self.namespace``) is used per storage
    namespace. Records passed to ``upsert`` are embedded from their
    ``"content"`` field. The record key (as ``"id"``) and any fields listed in
    ``meta_fields`` are stored as the point payload and returned by ``query``.
    """

    def __post_init__(self):

        # Use a local file-based Qdrant storage
        # Useful for prototyping and CI.
        # For production, refer to:
        # https://qdrant.tech/documentation/guides/installation/
        self._client_file_path = os.path.join(
            self.global_config["working_dir"], "qdrant_storage"
        )

        self._client = QdrantClient(path=self._client_file_path)

        # Maximum number of texts sent to the embedding function per call.
        self._max_batch_size = self.global_config["embedding_batch_num"]

        # Create the collection only if missing, so an existing collection
        # (and its stored vectors) is reused rather than recreated.
        if not self._client.collection_exists(collection_name=self.namespace):
            self._client.create_collection(
                collection_name=self.namespace,
                vectors_config=VectorParams(
                    size=self.embedding_func.embedding_dim, distance=Distance.COSINE
                ),
            )

    @staticmethod
    def _point_id(key: str) -> str:
        """Map a storage key to a deterministic Qdrant point ID.

        Qdrant point IDs must be unsigned integers or UUIDs, so the string key
        cannot be used directly. A name-based UUIDv5 yields the same ID for the
        same key every time, so upserting an existing key overwrites its point
        instead of adding a duplicate.
        """
        return str(uuid.uuid5(uuid.NAMESPACE_URL, key))

    async def upsert(self, data: dict[str, dict]):
        """Embed ``data`` and upsert it into this namespace's collection.

        Args:
            data: Mapping from record key to record. Every record must contain
                a ``"content"`` entry, which is the text that gets embedded.
                The key (stored as ``"id"``) plus the record fields named in
                ``self.meta_fields`` become the point payload.

        Returns:
            The result of ``QdrantClient.upsert``, or ``[]`` when ``data`` is
            empty.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")

        if not data:
            # Nothing to embed; np.concatenate([]) below would raise ValueError.
            logger.warning(f"No data to insert into {self.namespace}")
            return []

        payloads = [
            {
                "id": key,
                **{
                    field: value
                    for field, value in record.items()
                    if field in self.meta_fields
                },
            }
            for key, record in data.items()
        ]

        contents = [record["content"] for record in data.values()]
        batches = [
            contents[start : start + self._max_batch_size]
            for start in range(0, len(contents), self._max_batch_size)
        ]

        # Batches are embedded concurrently; gather keeps them in order, so
        # row i of `embeddings` belongs to payloads[i].
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)

        points = [
            PointStruct(
                id=self._point_id(payload["id"]),
                vector=vector.tolist(),
                payload=payload,
            )
            for payload, vector in zip(payloads, embeddings)
        ]

        return self._client.upsert(collection_name=self.namespace, points=points)

    async def query(self, query, top_k=5):
        """Return up to ``top_k`` stored records most similar to ``query``.

        Each result is the stored payload (``"id"`` plus meta fields) with the
        Qdrant similarity score added under ``"score"``.
        """
        embedding = await self.embedding_func([query])

        results = self._client.query_points(
            collection_name=self.namespace,
            query=embedding[0].tolist(),
            limit=top_k,
        ).points

        # A point's payload may be None; treat that as an empty payload.
        return [
            {**(result.payload or {}), "score": result.score}
            for result in results
        ]


def insert(data=None, working_dir="./nano_graphrag_cache_qdrant_TEST"):
    """Insert texts into a GraphRAG index that uses ``QdrantStorage``.

    Args:
        data: List of texts to insert. Defaults to placeholder strings that
            should be replaced with real content.
        working_dir: GraphRAG working directory. It is also handed to
            ``QdrantStorage`` through the GraphRAG config.
    """
    # None sentinel avoids a shared mutable default list.
    if data is None:
        data = ["YOUR TEXT DATA HERE", "YOUR TEXT DATA HERE"]
    rag = GraphRAG(
        working_dir=working_dir,
        enable_llm_cache=True,
        vector_db_storage_cls=QdrantStorage,
    )
    rag.insert(data)


def query(
    question="YOUR QUERY HERE",
    working_dir="./nano_graphrag_cache_qdrant_TEST",
    mode="local",
):
    """Run a query against a Qdrant-backed GraphRAG index and print the answer.

    Args:
        question: The query text. Defaults to a placeholder.
        working_dir: GraphRAG working directory; use the same one as
            ``insert`` so the previously built index is found.
        mode: ``QueryParam`` mode, ``"local"`` by default.

    Returns:
        The answer returned by ``GraphRAG.query`` (also printed).
    """
    rag = GraphRAG(
        working_dir=working_dir,
        enable_llm_cache=True,
        vector_db_storage_cls=QdrantStorage,
    )
    answer = rag.query(question, param=QueryParam(mode=mode))
    print(answer)
    return answer


if __name__ == "__main__":
    # Build the index first so the query step has data to search.
    for step in (insert, query):
        step()