File size: 3,283 Bytes
36425a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
from qdrant_client import QdrantClient, models
from typing import List, Dict, Any

class QdrantService:
    """Wrapper around a ``QdrantClient`` bound to the ``textbook_chunks`` collection.

    Connection settings are read from the ``QDRANT_URL`` (required) and
    ``QDRANT_API_KEY`` (optional) environment variables. Constructing the
    service makes sure the collection exists.
    """

    def __init__(self):
        """
        Initializes the QdrantService, setting up the client and ensuring the collection exists.

        Raises:
            ValueError: If ``QDRANT_URL`` is not set in the environment.
        """
        self.qdrant_url = os.getenv("QDRANT_URL")
        self.qdrant_api_key = os.getenv("QDRANT_API_KEY")
        self.collection_name = "textbook_chunks"
        self.vector_size = 384  # Based on the all-MiniLM-L6-v2 model

        if not self.qdrant_url:
            raise ValueError("QDRANT_URL must be set in environment variables.")

        # The QdrantClient can be initialized with or without an API key.
        # If the key is None, it will connect to a local or unsecured Qdrant instance.
        self.client = QdrantClient(
            url=self.qdrant_url,
            api_key=self.qdrant_api_key,
        )
        self.ensure_collection()

    def _collection_exists(self) -> bool:
        """Return True if ``self.collection_name`` is present on the server.

        Errors raised by the client (e.g. connectivity or auth problems) are
        deliberately NOT caught here: "could not ask" must never be mistaken
        for "does not exist".
        """
        exists = getattr(self.client, "collection_exists", None)
        if exists is not None:
            return bool(exists(collection_name=self.collection_name))

        # Fallback for clients without `collection_exists`: list the
        # collections and look for ours by name.
        listing = self.client.get_collections()
        return any(
            described.name == self.collection_name
            for described in listing.collections
        )

    def ensure_collection(self):
        """
        Checks if the required collection exists in Qdrant and creates it if it doesn't.

        An existing collection is never dropped or modified. Any error raised
        while checking for or creating the collection propagates to the caller.
        """
        if self._collection_exists():
            print(f"Collection '{self.collection_name}' already exists.")
            return

        print(f"Collection '{self.collection_name}' not found, creating it...")
        # `create_collection` (not `recreate_collection`) so that a collection
        # created concurrently by someone else is not deleted.
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(
                size=self.vector_size,
                distance=models.Distance.COSINE,
            ),
        )
        print(f"Collection '{self.collection_name}' created successfully.")

    def upsert_chunks(self, ids: List[str], vectors: List[List[float]], payloads: List[Dict[str, Any]]):
        """
        Inserts or updates points in the collection and waits for completion.

        Args:
            ids: Point identifiers, one per chunk.
            vectors: Embedding vectors, aligned with ``ids``.
            payloads: Payload dictionaries, aligned with ``ids``.

        Raises:
            ValueError: If the three lists do not have the same length.
        """
        if not (len(ids) == len(vectors) == len(payloads)):
            raise ValueError("ids, vectors, and payloads must have the same length")

        if not ids:
            # Nothing to send; skip the network round trip.
            print("No chunks to upsert; skipping.")
            return

        points = [
            models.PointStruct(id=id_, vector=vector, payload=payload)
            for id_, vector, payload in zip(ids, vectors, payloads)
        ]

        self.client.upsert(
            collection_name=self.collection_name,
            points=points,
            wait=True,
        )
        print(f"Upserted {len(points)} chunks successfully.")

    def search(self, query_vector: List[float], limit: int = 5) -> List[Dict[str, Any]]:
        """
        Performs a vector search in the Qdrant collection.

        Args:
            query_vector: The vector representation of the query.
            limit: The maximum number of results to return.

        Returns:
            A list of search results, each a dict with the keys ``"payload"``
            and ``"score"``.
        """
        print("Searching Qdrant with a vector...")
        search_results = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            limit=limit,
            with_payload=True,  # Ensure the payload is returned with the results
        )

        # Each hit exposes `.payload` and `.score`; flatten to plain dicts.
        results = [
            {"payload": hit.payload, "score": hit.score}
            for hit in search_results
        ]

        print(f"Qdrant search completed. Found {len(results)} results.")
        return results