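"""Qdrant-backed storage and retrieval for embedded text chunks."""

from typing import Dict, List, Optional
import uuid

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    FilterSelector,
    MatchValue,
    PointStruct,
    VectorParams,
)

from src.config import config
from src.embeddings import embedding_service


class VectorStore:
    """Thin wrapper around a Qdrant collection that stores one point per text chunk."""

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        collection_name: Optional[str] = None,
    ):
        # Fall back to the values in src.config when arguments are omitted.
        self.host = host or config.QDRANT_HOST
        self.port = port or config.QDRANT_PORT
        self.collection_name = collection_name or config.COLLECTION_NAME
        # The client is created lazily on first use (see connect()).
        self.client: Optional[QdrantClient] = None

    def connect(self):
        """Create the Qdrant client on first call and make sure the collection exists."""
        if self.client is None:
            if config.QDRANT_URL:
                # Deployment addressed by URL (optionally with an API key).
                self.client = QdrantClient(
                    url=config.QDRANT_URL,
                    api_key=config.QDRANT_API_KEY
                )
            else:
                # Instance addressed by host and port.
                self.client = QdrantClient(host=self.host, port=self.port)
            self._ensure_collection()

    def _ensure_collection(self):
        """Create the collection if it is missing, sized to the embedding model's output."""
        collections = self.client.get_collections().collections
        collection_names = [col.name for col in collections]

        if self.collection_name not in collection_names:
            embedding_dim = embedding_service.get_embedding_dimension()
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=embedding_dim,
                    distance=Distance.COSINE
                )
            )
            print(f"Created collection: {self.collection_name}")

    def store_chunks(self, url_id: str, url: str, chunks: List[Dict], embeddings: List[List[float]]) -> int:
        """Upsert one point per (chunk, embedding) pair and return the number of points written."""
        # zip() would silently drop the surplus items, so reject mismatched inputs.
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"Got {len(chunks)} chunks but {len(embeddings)} embeddings"
            )
        if not chunks:
            return 0

        self.connect()

        points = []
        for chunk, embedding in zip(chunks, embeddings):
            point = PointStruct(
                # Random ID per point: re-ingesting the same URL adds new points
                # instead of overwriting, so call delete_by_url_id() first if needed.
                id=str(uuid.uuid4()),
                vector=embedding,
                payload={
                    "url_id": url_id,
                    "url": url,
                    "chunk_id": chunk["id"],
                    "text": chunk["text"],
                    "start_word": chunk["start_word"],
                    "end_word": chunk["end_word"]
                }
            )
            points.append(point)

        self.client.upsert(
            collection_name=self.collection_name,
            points=points
        )

        return len(points)

    def search(self, query_embedding: List[float], top_k: Optional[int] = None) -> List[Dict]:
        """Return the top_k chunks most similar to query_embedding, highest score first."""
        self.connect()

        k = top_k or config.TOP_K_RESULTS

        # Recent qdrant-client releases deprecate search() in favour of query_points().
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=k
        )

        hits = []
        for result in results:
            payload = result.payload or {}
            hits.append(
                {
                    "id": result.id,
                    "score": result.score,
                    "url": payload.get("url"),
                    "url_id": payload.get("url_id"),
                    "text": payload.get("text"),
                    "chunk_id": payload.get("chunk_id")
                }
            )
        return hits

    def delete_by_url_id(self, url_id: str):
        """Delete every point whose payload field "url_id" equals url_id."""
        self.connect()

        # Use the typed selector models rather than a raw dict, which the
        # client does not reliably accept as a points selector.
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=FilterSelector(
                filter=Filter(
                    must=[
                        FieldCondition(
                            key="url_id",
                            match=MatchValue(value=url_id),
                        )
                    ]
                )
            ),
        )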

vector_store = VectorStore()
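

# Optional sketch (hypothetical helper, not used by VectorStore above):
# deterministic point IDs. store_chunks() assigns uuid.uuid4() IDs, so
# re-ingesting a URL adds duplicate points unless delete_by_url_id() runs
# first. Deriving the ID with uuid.uuid5 from (url_id, chunk_id) would instead
# let upsert overwrite the earlier point for the same chunk. Chunks that
# disappear after re-chunking would still linger, so delete_by_url_id() remains
# useful in that case. The helper name and the choice of uuid.NAMESPACE_URL as
# the namespace are assumptions for illustration.
import uuid


def deterministic_point_id(url_id: str, chunk_id: object) -> str:
    """Return a stable UUID string for a (url_id, chunk_id) pair."""
    # uuid5 hashes the name with SHA-1 under a fixed namespace, so identical
    # inputs always produce the same UUID; Qdrant accepts UUID strings as IDs.
    return str(uuid.uuid5(uuid.NAMESPACE_URL, f"{url_id}:{chunk_id}"))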