File size: 7,028 Bytes
c88e290
 
 
 
 
7841205
c88e290
 
 
 
8e0483e
 
c88e290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7841205
 
 
 
c88e290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1112c5
c88e290
 
 
 
 
 
7841205
 
 
 
c88e290
7841205
 
1fd5385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c88e290
7841205
 
 
 
 
 
 
 
 
b1112c5
 
7841205
 
 
 
 
 
 
 
 
 
 
c88e290
 
 
 
7841205
 
 
 
c88e290
 
 
 
 
 
 
 
7841205
 
c88e290
 
7841205
 
 
 
 
 
c88e290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3755446
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import os
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
from core.ChunkingManager import ChunkingManager, ChunkingStrategy
import tracker 

# --- CONFIGURATION ---
UPLOAD_DIR = "/tmp/rag_uploads"  # staging area for uploads; files are removed after indexing
DB_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chroma_db")  # per-user Chroma stores live under here
EMBEDDING_MODEL_NAME = "all-MiniLM-L12-v2"  # sentence-transformers model for embedding chunks/queries
RERANKER_MODEL_NAME = "BAAI/bge-reranker-base"  # cross-encoder used to re-score retrieved passages

# --- LAZY LOADING SINGLETONS ---
# Heavy models are expensive to load, so they are created on first use and
# cached in these module-level globals by the get_* accessors below.
_embedding_fn = None
_reranker = None
_chunk_manager = None

def get_embedding_function():
    """Return the shared HuggingFace embedding model, loading it on first call."""
    global _embedding_fn
    if _embedding_fn is not None:
        return _embedding_fn
    print("⚙️ Loading Embedding Model...")
    _embedding_fn = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    return _embedding_fn

def get_reranker_model():
    """Return the shared CrossEncoder reranker, loading it on first call."""
    global _reranker
    if _reranker is not None:
        return _reranker
    print("⚙️ Loading Reranker Model...")
    _reranker = CrossEncoder(RERANKER_MODEL_NAME)
    return _reranker

def get_chunk_manager():
    """Return the shared ChunkingManager, creating it on first call."""
    global _chunk_manager
    if _chunk_manager is not None:
        return _chunk_manager
    print("⚙️ Loading Chunk Manager...")
    _chunk_manager = ChunkingManager(embedding_model_name=EMBEDDING_MODEL_NAME)
    return _chunk_manager

# --- DATABASE OPERATIONS ---
def get_vectorstore(username):
    """Return a per-user Chroma vector store, creating its directory if needed.

    Args:
        username: Account name; basename()-sanitized so a crafted value
            cannot escape DB_ROOT via path separators.

    Returns:
        A Chroma instance persisted under DB_ROOT/<safe_username> with a
        user-scoped collection name.
    """
    safe_username = os.path.basename(username)
    user_db_path = os.path.join(DB_ROOT, safe_username)

    # exist_ok=True makes creation idempotent; the previous exists() guard
    # was redundant and introduced a TOCTOU race.
    os.makedirs(user_db_path, exist_ok=True)

    return Chroma(
        persist_directory=user_db_path,
        embedding_function=get_embedding_function(),
        collection_name=f"docs_{safe_username}"
    )

def save_uploaded_file(uploaded_file):
    """Persist an uploaded file object to UPLOAD_DIR and return its path.

    Args:
        uploaded_file: File-like object exposing `.name` and `.getbuffer()`
            (e.g. a Streamlit UploadedFile — TODO confirm against caller).

    Returns:
        Absolute path of the written file inside UPLOAD_DIR.
    """
    # exist_ok=True avoids the TOCTOU race of an exists()+makedirs pair.
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    # basename() strips any directory components so the upload cannot
    # escape UPLOAD_DIR via a crafted filename.
    safe_filename = os.path.basename(uploaded_file.name)
    file_path = os.path.join(UPLOAD_DIR, safe_filename)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return file_path

def process_and_add_document(file_path, username, strategy="paragraph"):
    """Chunk a document, tag and index its chunks, then sync and clean up.

    Args:
        file_path: Path of the staged file to ingest (deleted on success).
        username: Owner of the target vector store.
        strategy: One of "token"/"paragraph"/"page"; unknown values fall
            back to paragraph chunking.

    Returns:
        (ok, message) — ok is False when chunking fails or yields nothing.
    """
    try:
        print(f"🧠 Chunking {os.path.basename(file_path)}...")

        selected_strategy = {
            "token": ChunkingStrategy.TOKEN,
            "paragraph": ChunkingStrategy.PARAGRAPH,
            "page": ChunkingStrategy.PAGE,
        }.get(strategy, ChunkingStrategy.PARAGRAPH)

        chunks = get_chunk_manager().process_document(
            file_path=file_path,
            strategy=selected_strategy,
            preprocess=True,
        )

        if not chunks:
            return False, "No text extracted. Is the file empty/scanned?"

        # FIX #1: Tag every chunk with the strategy used
        for chunk in chunks:
            chunk.metadata["strategy"] = strategy

        print(f"💾 Indexing {len(chunks)} chunks into Vector DB...")
        get_vectorstore(username).add_documents(chunks)

        # Push the updated per-user DB upstream.
        tracker.upload_user_db(username)

        # Remove the staged upload now that its content is indexed.
        if os.path.exists(file_path):
            os.remove(file_path)

        return True, f"Successfully added {len(chunks)} chunks to Knowledge Base."

    except Exception as e:
        print(f"❌ Processing Error: {e}")
        return False, str(e)

# --- RETRIEVAL ENGINE ---
def search_knowledge_base(query, username, k=6):
    """
    Two-Stage Retrieval System (RAG):
    1. Retrieval: pull 50 candidates via fast vector search. The net is
       deliberately wide — exact terms (e.g. acronyms like "C&D") can rank
       low in pure embedding space, so we over-fetch and let the reranker
       restore precision.
    2. Reranking: re-score candidates with a Cross-Encoder (slow/precise).
    3. Return the top k, each tagged with its relevance score.

    Returns:
        List of at most k Documents, or [] on empty store / any error.
    """
    try:
        db = get_vectorstore(username)
        if db._collection.count() == 0:
            return []

        # 1. Vector search (wide net).
        # NOTE: the old code ran an extra k=25 search whose results were
        # discarded, plus dead keyword-search scaffolding — both removed.
        results = db.similarity_search(query, k=50)
        if not results:
            return []

        # 2. Cross-encoder reranking of the candidate passages.
        reranker = get_reranker_model()
        passages = [doc.page_content for doc in results]
        ranks = reranker.rank(query, passages)

        # 3. Keep the k best-scoring documents, annotated with their score.
        top_results = []
        for rank in sorted(ranks, key=lambda x: x['score'], reverse=True)[:k]:
            doc = results[rank['corpus_id']]
            doc.metadata["relevance_score"] = round(rank['score'], 4)
            top_results.append(doc)

        return top_results

    except Exception as e:
        print(f"⚠️ Search Error (likely empty DB): {e}")
        return []

def list_documents(username):
    """Summarize a user's indexed documents.

    Returns:
        List of dicts — one per source file — with keys 'source',
        'filename', 'chunks' (count) and 'strategy'; [] when the store
        is empty or on error.
    """
    try:
        db = get_vectorstore(username)
        # Check if empty before fetching to prevent errors
        if db._collection.count() == 0:
            return []

        file_stats = {}
        for meta in db.get()['metadatas']:
            src = meta.get('source', 'unknown')
            stats = file_stats.get(src)
            if stats is None:
                stats = file_stats[src] = {
                    'source': src,
                    'filename': os.path.basename(src),
                    'chunks': 0,
                    # FIX #2: Retrieve the strategy (Default to 'unknown' for old docs)
                    'strategy': meta.get('strategy', 'unknown'),
                }
            stats['chunks'] += 1

        return list(file_stats.values())

    except Exception as e:
        print(f"❌ Error listing docs: {e}")
        return []

def delete_document(username, source_path):
    """Delete every chunk originating from source_path, then sync upstream.

    Returns:
        (ok, message) tuple; ok is False with the error text on failure.
    """
    try:
        print(f"🗑️ Deleting {source_path} for {username}...")
        get_vectorstore(username).delete(where={"source": source_path})
        tracker.upload_user_db(username)
        return True, f"Deleted {os.path.basename(source_path)}"
    except Exception as e:
        return False, str(e)

def reset_knowledge_base(username):
    """Drop the user's whole collection and push the change upstream.

    Returns:
        (ok, message) tuple; ok is False with the error text on failure.
    """
    try:
        get_vectorstore(username).delete_collection()
        tracker.upload_user_db(username)
        return True, "Knowledge Base completely reset."
    except Exception as e:
        return False, str(e)