File size: 12,724 Bytes
968e24d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7c31c1
 
 
968e24d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
"""RAG Query Engine with LLM"""

# pyrefly: ignore [missing-import]
import faiss
import json
import sqlite3
import re
# pyrefly: ignore [missing-import]
from sentence_transformers import SentenceTransformer, CrossEncoder
import os
from groq import Groq
from dotenv import load_dotenv

from src.summarization.ranker import ImportanceRanker
from src.summarization.utils import split_sentences

load_dotenv()

class QueryEngine:
    """Retrieval-augmented question answering over indexed judgment paragraphs.

    Retrieval is hybrid: a dense FAISS search and a SQLite FTS5 (BM25) keyword
    search are merged with Reciprocal Rank Fusion, and the fused candidates are
    re-scored by a cross-encoder. Answers are generated by a Groq-hosted LLM.

    A SQLite connection is opened and closed inside each ``search`` call, so
    the engine holds no long-lived database handle.
    """

    # On-disk artifacts, relative to the working directory. They are class
    # attributes so a subclass or an instance can point at other locations.
    INDEX_PATH = "data/processed/faiss/faiss_index.bin"
    PARA_IDS_PATH = "data/processed/embeddings/paragraph_ids.json"
    DB_PATH = "data/processed/indexed/paragraphs.db"
    RANKER_PATH = "outputs/summarization/final"

    RRF_K = 60                # smoothing constant of Reciprocal Rank Fusion
    MAX_HISTORY_MESSAGES = 6  # last 3 user/assistant turns
    MAX_DOC_CHARS = 30000     # document budget for query_with_document

    def __init__(self):
        """Load the index, id mapping, embedding/rerank models and LLM client.

        Raises:
            ValueError: if the ``GROQ_API_KEY`` environment variable is unset.
        """
        print("Loading RAG components...")

        # FAISS + SQLite + Embeddings
        self.index = faiss.read_index(self.INDEX_PATH)

        # Position i in this list is the paragraph id of FAISS vector i.
        with open(self.PARA_IDS_PATH, encoding="utf-8") as f:
            self.para_ids = json.load(f)

        self.model = SentenceTransformer("BAAI/bge-base-en-v1.5")
        self.reranker = CrossEncoder("BAAI/bge-reranker-base")
        self.importance_ranker = ImportanceRanker(self.RANKER_PATH)

        # LLM Setup
        api_key = os.getenv('GROQ_API_KEY')
        if not api_key:
            raise ValueError("GROQ_API_KEY not set")

        self.llm = Groq(api_key=api_key)
        self.llm_model = 'llama-3.1-8b-instant'

        print(f"✓ Ready with {self.index.ntotal:,} vectors")
        print("✓ LLM: Groq (Llama 3.1 8B)")

    def _get_db(self):
        """Open a new SQLite connection to the paragraph database."""
        return sqlite3.connect(self.DB_PATH)

    @staticmethod
    def _build_fts_query(query: str) -> str:
        """Turn free text into an FTS5 MATCH expression that ORs all words.

        Punctuation is stripped and every word is wrapped in double quotes so
        that words such as ``AND``, ``OR``, ``NOT`` or ``NEAR`` are matched as
        literal terms instead of being parsed as FTS5 operators.

        Returns:
            The MATCH expression, or ``""`` when the query contains no words.
        """
        words = re.sub(r'[^\w\s]', '', query).split()
        return " OR ".join(f'"{word}"' for word in words)

    @staticmethod
    def _rrf_fuse(ranked_lists, k: int = 60) -> dict:
        """Combine ranked id lists with Reciprocal Rank Fusion.

        Each id receives ``1 / (k + rank)`` (rank starting at 1) from every
        list it appears in; the contributions are summed.

        Returns:
            Dict mapping id -> fused score, in first-seen order.
        """
        fused = {}
        for ranked_ids in ranked_lists:
            for rank, pid in enumerate(ranked_ids, 1):
                fused[pid] = fused.get(pid, 0.0) + 1.0 / (k + rank)
        return fused

    def _dense_search(self, query: str, limit: int) -> list:
        """Return paragraph ids from the FAISS index, best match first."""
        query_vec = self.model.encode([query], normalize_embeddings=True)
        _scores, indices = self.index.search(query_vec, k=limit)
        # FAISS pads with -1 when it finds fewer than `limit` neighbours.
        # Indexing para_ids with -1 would silently return the LAST paragraph,
        # so those slots must be dropped.
        return [self.para_ids[int(idx)] for idx in indices[0] if idx >= 0]

    def _keyword_search(self, cursor, query: str, limit: int) -> list:
        """Return paragraph ids from the FTS5 table, best BM25 match first.

        Keyword retrieval is best-effort: if the FTS table is missing or the
        MATCH expression is rejected, an empty list is returned so the dense
        results can still be used.
        """
        fts_query = self._build_fts_query(query)
        if not fts_query:
            return []
        try:
            cursor.execute("""
                SELECT id, bm25(paragraphs_fts) as bm25_score
                FROM paragraphs_fts 
                WHERE paragraphs_fts MATCH ?
                ORDER BY bm25_score LIMIT ?
            """, (fts_query, limit))
            return [row[0] for row in cursor.fetchall()]
        except sqlite3.OperationalError as exc:
            print(f"Keyword search unavailable ({exc}); using dense results only.")
            return []

    def search(self, query: str, top_k: int = 5):
        """Hybrid Search: FAISS (Dense) + SQLite FTS5 (BM25) with RRF.

        Args:
            query: Natural-language search text.
            top_k: Number of paragraphs to return.

        Returns:
            Up to ``top_k`` dicts, sorted by cross-encoder score descending,
            each with keys ``rrf_score``, ``judgment_id``, ``text``,
            ``page_no``, ``id`` and ``score`` (the cross-encoder score).
            Empty when ``top_k`` is not positive or nothing was found.
        """
        if top_k <= 0:
            return []

        fetch_k = top_k * 2  # Fetch extra from each retriever for fusion

        # --- 1. Dense Search (FAISS) ---
        dense_ids = self._dense_search(query, fetch_k)

        db = self._get_db()
        try:
            cursor = db.cursor()

            # --- 2. Keyword Search (SQLite FTS5 BM25) ---
            keyword_ids = self._keyword_search(cursor, query, fetch_k)

            # --- 3. Reciprocal Rank Fusion (RRF) ---
            rrf_scores = self._rrf_fuse([dense_ids, keyword_ids], k=self.RRF_K)

            # Keep a pool larger than top_k so the reranker has room to reorder.
            candidate_pool_size = top_k * 3
            sorted_rrf = sorted(
                rrf_scores.items(), key=lambda item: item[1], reverse=True
            )[:candidate_pool_size]

            # --- 4. Fetch paragraph details ---
            candidates = []
            for pid, rrf_score in sorted_rrf:
                cursor.execute(
                    "SELECT judgment_id, text, page_no FROM paragraphs WHERE id = ?",
                    (pid,)
                )
                row = cursor.fetchone()
                if row:
                    candidates.append({
                        'rrf_score': rrf_score,
                        'judgment_id': row[0],
                        'text': row[1],
                        'page_no': row[2],
                        'id': pid
                    })
        finally:
            # Always release the connection, even if a query above raised.
            db.close()

        if not candidates:
            return []

        # --- 5. Final Rerank (Cross-Encoder) ---
        # Cross-encoder input is a list of [query, document_text] pairs.
        cross_inp = [[query, doc['text']] for doc in candidates]
        rerank_scores = self.reranker.predict(cross_inp)

        for doc, score in zip(candidates, rerank_scores):
            doc['score'] = float(score)  # Use cross-encoder score as final score

        candidates.sort(key=lambda doc: doc['score'], reverse=True)
        return candidates[:top_k]

    def generate_answer(self, question: str, context: str, sources: list = None, chat_history: list = None):
        """Generate answer using Groq LLM with strict Legal Guardrails.

        The retrieved sources are listed in the prompt as a numbered registry
        and the model is instructed to cite only from that list.

        Args:
            question: The user's question.
            context: Retrieved paragraphs, already formatted as text.
            sources: Dicts with a ``judgment_id`` key; entry i becomes
                citation ``[i]``. ``None`` means no sources.
            chat_history: Prior chat messages (role/content dicts). Only the
                last ``MAX_HISTORY_MESSAGES`` are sent; the list passed in is
                not modified.

        Returns:
            The content of the first completion choice.
        """
        sources = sources or []
        chat_history = (chat_history or [])[-self.MAX_HISTORY_MESSAGES:]  # Cap to last 3 turns

        # Build a numbered source registry for the LLM
        source_registry = "".join(
            f"[{i}] {s.get('judgment_id', 'Unknown')}\n"
            for i, s in enumerate(sources, 1)
        )

        prompt = f"""You are a strict, brilliant legal research assistant specializing in Indian Supreme Court judgments.

GUARDRAIL: You MUST ONLY answer questions related to law, legal processes, or the provided context.
If the question is entirely unrelated to law (e.g., "how to bake a cake"), reply EXACTLY with:
"I am a legal AI assistant. I can only answer questions related to law."

CITATION RULES — THIS IS CRITICAL:
- You may ONLY cite sources from the APPROVED SOURCE LIST below.
- Do NOT cite any case from your training memory that is not in the APPROVED SOURCE LIST.
- If you cite a case not in this list, you are hallucinating and failing your task.
- Use [1], [2], [3] etc. to refer to sources from the list below.

APPROVED SOURCE LIST (cite ONLY these):
{source_registry}
CONTEXT (retrieved paragraphs):
{context}

QUESTION: {question}

INSTRUCTIONS:
- Provide a detailed, comprehensive legal answer in a professional conversational tone.
- Explain concepts clearly so a lawyer finds it extremely useful.
- Cite ONLY from the APPROVED SOURCE LIST above using [1], [2], [3] format.
- Use proper legal terminology.
- Do NOT invent case names, citations, or dates.
- TEMPORAL AWARENESS: Look at the years in the judgment titles (e.g. 2023_CaseName). Newer judgments (e.g. 2023) supersede older judgments (e.g. 2010). If the retrieved context contains conflicting rulings, you MUST prioritize the newer judgment and explicitly warn the user that the older precedent may have been superseded.

ANSWER:"""

        messages = chat_history + [{"role": "user", "content": prompt}]

        response = self.llm.chat.completions.create(
            model=self.llm_model,
            messages=messages,
            temperature=0.2,
            max_tokens=1024
        )

        return response.choices[0].message.content

    def query(self, question: str, top_k: int = 5, chat_history: list = None):
        """Retrieve supporting paragraphs and answer ``question`` from them.

        Args:
            question: The user's question.
            top_k: Number of paragraphs to retrieve and pass to the LLM.
            chat_history: Prior chat messages forwarded to ``generate_answer``.

        Returns:
            Dict with ``question``, ``answer`` and ``sources`` (the list
            returned by ``search``).
        """
        print(f"\n{'='*70}")
        print(f"QUERY: {question}")
        print('='*70)

        # Search
        print("\nSearching FAISS index...")
        results = self.search(question, top_k)

        print(f"Found {len(results)} relevant paragraphs")

        # Number the context entries the same way generate_answer numbers the
        # source registry, so [i] in the answer refers to results[i - 1].
        context = "\n\n".join(
            f"[{i}] {r['judgment_id']}\n{r['text']}"
            for i, r in enumerate(results, 1)
        )

        # Generate answer — pass sources so LLM can only cite what was retrieved
        print("Generating answer with LLM...")
        answer = self.generate_answer(question, context, sources=results, chat_history=chat_history)

        return {
            'question': question,
            'answer': answer,
            'sources': results
        }

    def _semantic_truncate(self, doc_text: str, limit: int) -> str:
        """Shrink ``doc_text`` to about ``limit`` chars, keeping top sentences.

        Sentences longer than 20 characters are scored by the importance
        ranker and picked greedily from the highest score down while they fit
        in the budget; picking stops once the budget is within 1000 characters
        of full. The kept sentences are emitted in their original order.
        """
        sentences = [s for s in split_sentences(doc_text) if len(s.strip()) > 20]
        scores = self.importance_ranker.score(sentences)
        indexed = list(enumerate(zip(sentences, scores)))
        by_score = sorted(indexed, key=lambda item: item[1][1], reverse=True)

        selected_indices = []
        current_chars = 0
        for idx, (sentence, _score) in by_score:
            if current_chars + len(sentence) > limit:
                continue  # too long for the remaining budget; try the next one
            selected_indices.append(idx)
            current_chars += len(sentence)
            if current_chars > limit - 1000:
                break

        # Restore original document order
        kept = " ".join(sentences[i] for i in sorted(selected_indices))
        return kept + "\n\n... [TRUNCATED SEMANTICALLY FOR LLM] ..."

    def query_with_document(self, question: str, filepath: str, chat_history: list = None):
        """Answer ``question`` using only the text of the file at ``filepath``.

        Documents longer than ``MAX_DOC_CHARS`` are shortened first: by
        importance-ranked sentence selection, or by plain prefix truncation if
        that fails.

        Args:
            question: The user's question.
            filepath: Path of a text document, read as UTF-8 with undecodable
                bytes ignored.
            chat_history: Prior chat messages. Only the last
                ``MAX_HISTORY_MESSAGES`` are sent, as in ``generate_answer``.

        Returns:
            Dict with ``question``, ``answer`` and ``sources``. If the file
            cannot be read, ``answer`` holds the error message and ``sources``
            is empty.
        """
        chat_history = (chat_history or [])[-self.MAX_HISTORY_MESSAGES:]
        try:
            with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                doc_text = f.read()
        except (OSError, TypeError, ValueError) as e:
            return {
                'question': question,
                'answer': f"Error reading document: {e}",
                'sources': []
            }

        limit = self.MAX_DOC_CHARS
        if len(doc_text) > limit:
            print("Document exceeds 30k chars. Applying Semantic Truncation...")
            try:
                doc_text = self._semantic_truncate(doc_text, limit)
            except Exception as e:
                # Best effort: any ranker/splitter failure degrades to a cut.
                print(f"Semantic Truncation failed: {e}. Falling back to naive truncation.")
                doc_text = doc_text[:limit] + "\n\n... [TRUNCATED DUE TO SIZE] ..."

        print("--- DOCUMENT QA START ---")
        print(f"File: {os.path.basename(filepath)}")
        print(f"Size: {len(doc_text)} chars")
        print(f"Question: {question}")

        prompt = f"""You are a strict Legal Document Auditor.
Your ONLY source of information is the text provided below. 

STRICT RULES:
1. Answer the QUESTION using ONLY the DOCUMENT text.
2. If the answer is not in the text, say "I cannot find this in the uploaded document."
3. DO NOT cite external cases (like Venkata Reddy or V.C. Shukla) unless they are explicitly mentioned in the text below.
4. If you use your own internal knowledge instead of the document, you are failing your task.

DOCUMENT TEXT:
{doc_text}

QUESTION: {question}

DETAILED ANSWER (citing specific paragraphs if possible):"""

        messages = chat_history + [{"role": "user", "content": prompt}]

        response = self.llm.chat.completions.create(
            model=self.llm_model,
            messages=messages,
            temperature=0.1,
            max_tokens=1024
        )

        # Guard against a missing content field before stripping.
        answer = (response.choices[0].message.content or "").strip()

        return {
            'question': question,
            'answer': answer,
            'sources': [{'judgment_id': os.path.basename(filepath), 'score': 1.0}]
        }

    def close(self):
        """Release any database handle attached to the engine.

        ``search`` opens and closes its own connection, so normally there is
        nothing to release and this is a no-op. A connection stored on
        ``self.db`` (e.g. by a subclass) is closed if present.
        """
        db = getattr(self, "db", None)
        if db is not None:
            db.close()
            self.db = None

# Smoke test: run a few sample legal questions through the full pipeline.
if __name__ == "__main__":
    rag = QueryEngine()

    sample_questions = (
        "What are the conditions for granting anticipatory bail?",
        "Explain the doctrine of legitimate expectation",
        "What is the burden of proof in criminal cases?",
    )
    divider = "\n" + "="*70 + "\n"

    for sample in sample_questions:
        result = rag.query(sample, top_k=3)

        print(f"\nANSWER:\n{result['answer']}\n")

        # List each retrieved paragraph next to its final (reranker) score.
        print("SOURCES:")
        for position, source in enumerate(result['sources'], start=1):
            print(f"  [{position}] {source['judgment_id']} (score: {source['score']:.3f})")

        print(divider)

    rag.close()