File size: 4,402 Bytes
04653e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fec7894
 
 
 
04653e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""Unified vector store interface: FAISS (default) + optional Qdrant Cloud."""
import sys
from pathlib import Path
from typing import List, Tuple, Union

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

sys.path.append(str(Path(__file__).parent.parent))

from app.config import Config


class VectorStore:
    """Embed documents and search them with FAISS (default) or Qdrant Cloud.

    The backend is selected by ``Config.USE_QDRANT``. If importing the Qdrant
    client, connecting, or preparing the collection raises, the store falls
    back to a local FAISS index persisted at ``Config.FAISS_INDEX_PATH``.
    """

    def __init__(self, model_name: str = Config.EMBEDDING_MODEL):
        """Load the embedding model and connect to the configured backend.

        Args:
            model_name: SentenceTransformer model name; loaded on CPU.
        """
        import os

        # Increase the Hugging Face Hub download timeout for HF Spaces.
        # NOTE(review): huggingface_hub appears to read this variable once, when
        # it is first imported (already triggered by the module-level
        # sentence_transformers import), so setting it here may have no effect.
        # Confirm, and consider setting it before those imports instead.
        os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '300'

        self.model = SentenceTransformer(model_name, device='cpu')
        self.use_qdrant = Config.USE_QDRANT
        # FAISS state is initialised unconditionally so these attributes exist
        # on every instance, including when Qdrant is active or fails part-way.
        self.index = None
        self.index_path = Config.FAISS_INDEX_PATH

        if self.use_qdrant:
            try:
                from qdrant_client import QdrantClient

                # Connect to Qdrant Cloud
                self.client = QdrantClient(
                    url=Config.QDRANT_URL,
                    api_key=Config.QDRANT_API_KEY
                )
                self._init_qdrant_collection()
                print("✅ Connected to Qdrant Cloud")
            except Exception as e:
                # Qdrant is a best-effort upgrade: any failure (missing
                # package, network, auth, collection setup) degrades to FAISS.
                print(f"⚠️  Qdrant connection failed: {e}")
                print("   Falling back to FAISS...")
                self.use_qdrant = False

    def _init_qdrant_collection(self):
        """Create ``Config.QDRANT_COLLECTION`` if the server does not list it.

        The collection is created with cosine distance and a vector size equal
        to the loaded model's sentence-embedding dimension.
        """
        from qdrant_client.models import Distance, VectorParams

        collections = self.client.get_collections().collections
        collection_names = [c.name for c in collections]

        if Config.QDRANT_COLLECTION not in collection_names:
            self.client.create_collection(
                collection_name=Config.QDRANT_COLLECTION,
                vectors_config=VectorParams(
                    size=self.model.get_sentence_embedding_dimension(),
                    distance=Distance.COSINE
                )
            )

    def add_documents(self, files: List[Tuple[str, str]]):
        """Embed ``files`` and index them in the active backend.

        Args:
            files: ``(filename, content)`` pairs; only ``content`` is embedded.

        Returns:
            ``(embeddings, files)``: the encoded array and the input list. FAISS
            stores only vectors, so callers need ``files`` to map the indices
            returned by :meth:`search` back to documents.

        Raises:
            ValueError: on the FAISS path, if the embeddings are not 2-D.
        """
        texts = [content for _, content in files]
        embeddings = self.model.encode(texts, show_progress_bar=True)

        if self.use_qdrant:
            from qdrant_client.models import PointStruct

            # NOTE(review): point ids restart at 0 on every call, so a later
            # call overwrites earlier points by id and leaves any higher ids
            # from a larger earlier batch in place. The FAISS branch, by
            # contrast, rebuilds the index from scratch each call.
            points = [
                PointStruct(
                    id=idx,
                    vector=embedding.tolist(),
                    payload={"filename": filename, "content": content}
                )
                for idx, (embedding, (filename, content)) in enumerate(zip(embeddings, files))
            ]
            self.client.upsert(
                collection_name=Config.QDRANT_COLLECTION,
                points=points
            )
            print(f"✅ Indexed {len(files)} files to Qdrant Cloud")
        else:
            # FAISS (default): replace the index with one built from this batch
            # and persist it so search() can reload it later.
            if embeddings.ndim != 2:
                raise ValueError(f"Expected 2D embeddings, got shape {embeddings.shape}")
            self.index = faiss.IndexFlatL2(embeddings.shape[1])
            self.index.add(embeddings.astype('float32'))
            faiss.write_index(self.index, self.index_path)
            print(f"✅ Indexed {len(files)} files to FAISS")

        return embeddings, files

    def search(
        self, query: str, k: int = 5
    ) -> Union[List[Tuple[str, str]], List[int]]:
        """Return up to ``k`` documents most similar to ``query``.

        Args:
            query: Text to embed and search for.
            k: Maximum number of results; ``k <= 0`` returns ``[]``.

        Returns:
            Qdrant backend: ``(filename, content)`` tuples read from the point
            payloads. FAISS backend: integer row indices into the ``files``
            list last passed to :meth:`add_documents`, in FAISS ranking order.
            FAISS stores no payloads, so callers must map indices themselves.
        """
        if k <= 0:
            return []

        query_embedding = self.model.encode([query])

        if self.use_qdrant:
            results = self.client.search(
                collection_name=Config.QDRANT_COLLECTION,
                query_vector=query_embedding[0].tolist(),
                limit=k
            )
            return [(r.payload["filename"], r.payload["content"]) for r in results]

        # FAISS (default): lazily load the persisted index on first use.
        if self.index is None:
            self.index = faiss.read_index(self.index_path)
        _, indices = self.index.search(query_embedding.astype('float32'), k)
        # FAISS pads unfilled result slots with id -1 when the index holds
        # fewer than k vectors. Used as a list index, -1 would silently select
        # the *last* document, so drop those sentinels.
        return [i for i in indices[0].tolist() if i != -1]