File size: 3,922 Bytes
e2eff86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from typing import List, Optional
import numpy as np
import json
import os
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams

class VectorDB:
    """
    Thin wrapper around one Qdrant collection holding book-content embeddings.

    Every point is stored with the payload
    ``{"content_id": <id>, "metadata": <dict>}``.

    NOTE: Qdrant accepts only unsigned integers or UUIDs as point IDs, so
    ``content_id`` values passed to this class should be UUID strings;
    arbitrary strings are rejected by the server.
    """

    def __init__(
        self,
        url: Optional[str] = None,
        collection_name: str = "book_content",
        vector_size: int = 768,
    ):
        """
        Connect to Qdrant and make sure the target collection exists.

        Args:
            url: Qdrant endpoint. When omitted, the ``VECTOR_DB_URL``
                environment variable is used, falling back to
                ``http://localhost:6333``.
            collection_name: Name of the collection to read from / write to.
            vector_size: Dimensionality used when the collection has to be
                created (768 by default). Ignored if the collection
                already exists.
        """
        if url is None:
            url = os.getenv("VECTOR_DB_URL", "http://localhost:6333")
        self.client = QdrantClient(url=url)

        self.collection_name = collection_name
        self.vector_size = vector_size
        self._init_collection()

    def _init_collection(self):
        """
        Create the collection (cosine distance) if it cannot be fetched.
        """
        try:
            self.client.get_collection(self.collection_name)
        except Exception:
            # The lookup failure is treated as "collection missing". The
            # handler is deliberately `Exception` rather than a bare
            # `except:` so KeyboardInterrupt / SystemExit still propagate.
            # If the real cause was something else (e.g. the server is
            # unreachable), create_collection raises and surfaces it.
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.vector_size,
                    distance=Distance.COSINE,
                ),
            )

    @staticmethod
    def _payload_fields(payload: Optional[dict]) -> dict:
        """
        Extract ``content_id`` and ``metadata`` from a point payload.

        Missing keys (or a missing payload) yield ``None`` values instead of
        raising, so one malformed point cannot break a whole result list.
        """
        payload = payload or {}
        return {
            "content_id": payload.get("content_id"),
            "metadata": payload.get("metadata"),
        }

    def store_embedding(self, content_id: str, embedding: List[float], content_metadata: dict):
        """
        Insert or overwrite the point identified by ``content_id``.

        Args:
            content_id: Point ID (see the class note on accepted ID formats).
            embedding: Vector to store.
            content_metadata: Arbitrary metadata saved under the
                ``"metadata"`` payload key.
        """
        self.client.upsert(
            collection_name=self.collection_name,
            points=[
                models.PointStruct(
                    id=content_id,
                    vector=embedding,
                    payload={
                        "content_id": content_id,
                        "metadata": content_metadata,
                    },
                )
            ],
        )

    def search_similar(self, query_embedding: List[float], limit: int = 5) -> List[dict]:
        """
        Return up to ``limit`` stored entries closest to ``query_embedding``.

        Returns:
            A list of dicts with keys ``content_id``, ``metadata`` and
            ``score``, in the order returned by the client.
        """
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=limit,
        )
        return [
            {**self._payload_fields(hit.payload), "score": hit.score}
            for hit in hits
        ]

    def delete_content(self, content_id: str):
        """
        Delete the point identified by ``content_id`` from the collection.
        """
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=models.PointIdsList(points=[content_id]),
        )

    def update_content(self, content_id: str, embedding: List[float], content_metadata: dict):
        """
        Replace the vector and metadata of an existing entry.

        Upsert semantics make this identical to ``store_embedding``; the
        method is kept as a separate name for callers.
        """
        self.store_embedding(content_id, embedding, content_metadata)

    def get_content(self, content_id: str) -> Optional[dict]:
        """
        Fetch a single entry by ID.

        Returns:
            A dict with keys ``content_id``, ``metadata`` and ``vector``,
            or ``None`` when no point with that ID is returned.
        """
        points = self.client.retrieve(
            collection_name=self.collection_name,
            ids=[content_id],
            with_payload=True,
            # Vectors are not returned unless explicitly requested; without
            # this flag the "vector" field below would always be None.
            with_vectors=True,
        )
        if not points:
            return None

        point = points[0]
        return {**self._payload_fields(point.payload), "vector": point.vector}

# Module-level shared instance. Constructing it here runs VectorDB.__init__ at
# import time, which builds the Qdrant client and checks for (or creates) the
# collection — so importing this module performs that work as a side effect.
vector_db = VectorDB()