File size: 10,410 Bytes
40e6b7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
import numpy as np
import faiss
import os
from openai import OpenAI
from loguru import logger
from app.services.embedding_service import EmbeddingService
from app.constants import POLICIES


class RAGService:
    """ Handles policy retrieval with FAISS 
    CRAG = corrective RAG, basically retries if results are bad """
    
    def __init__(self, embedding_service: "EmbeddingService") -> None:
        """Create the OpenAI client and index the bundled policy documents.

        Args:
            embedding_service: Object used to embed text. This class calls its
                ``embed_batch``, ``embed_single`` and ``get_dimension`` methods.

        Raises:
            Exception: Any error raised during setup is logged and re-raised
                unchanged. Index-building errors are handled inside
                ``build_index`` and do not propagate from here.
        """
        try:
            self.embed_service = embedding_service
            # The API key is read from the environment at construction time.
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

            logger.info("Initializing RAG Service...")

            # Copy so that later additions to self.policies do not modify POLICIES.
            self.policies = POLICIES.copy()
            self.build_index()

            # build_index() swallows its own errors and leaves the index as
            # None, so only report success when an index actually exists.
            if self.index is None:
                logger.warning(
                    "RAG Service initialized without a search index; "
                    "retrieval will return no results"
                )
            else:
                logger.success(f"RAG Service initialized with {len(self.policies)} policies")
        except Exception as e:
            logger.error(f"Failed to initialize RAG Service: {str(e)}")
            raise
    
    def build_index(self):
        """Build FAISS index from policy docs"""
        try:
            if not self.policies:
                logger.warning("No policies to index")
                self.index = None
                self.policy_embeddings = None
                return
            
            logger.info(f"Embedding {len(self.policies)} policy documents...")
            
            # embed all policies
            self.policy_embeddings = self.embed_service.embed_batch(self.policies)
            
            # FAISS index - using inner product since vectors are normalized
            dimension = self.embed_service.get_dimension()
            self.index = faiss.IndexFlatIP(dimension)
            
            self.index.add(self.policy_embeddings.astype('float32'))
            
            logger.debug(f"FAISS index built with {self.index.ntotal} vectors")
        except Exception as e:
            logger.error(f"Index building failed: {str(e)}")
            self.index = None
            self.policy_embeddings = None
    
    def add_documents(self, new_docs):
        """Add new docs to index on the fly"""
        try:
            if not new_docs:
                return
            
            logger.info(f"Adding {len(new_docs)} temporary documents...")
            
            new_embeddings = self.embed_service.embed_batch(new_docs)
            self.index.add(new_embeddings.astype('float32'))
            self.policies.extend(new_docs)
            
            # stack new embeddings with old ones
            self.policy_embeddings = np.vstack([self.policy_embeddings, new_embeddings])
            
            logger.debug(f"Index now contains {self.index.ntotal} documents")
        except Exception as e:
            logger.error(f"Failed to add documents: {str(e)}")
    
    def retrieve(self, query, top_k=3):
        """Basic retrieval from FAISS"""
        try:
            if self.index is None or self.index.ntotal == 0:
                logger.warning("Index is empty, returning no results")
                return []
            
            # embed and search
            query_emb = self.embed_service.embed_single(query).reshape(1, -1)
            scores, indices = self.index.search(query_emb.astype('float32'), top_k)
            
            results = []
            for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
                if idx < len(self.policies):
                    results.append({
                        'text': self.policies[idx],
                        'score': float(score),
                        'index': int(idx),
                        'rank': i + 1
                    })
            
            logger.debug(f"Retrieved {len(results)} documents")
            return results
        except Exception as e:
            logger.error(f"Retrieval failed: {str(e)}")
            return []
    
    def judge_relevance(self, query, documents):
        # use llm to score how relevant each doc is
        # helps filter out garbage results
        try:
            if not documents:
                return []
            
            doc_texts = "\n\n".join([
                f"DOCUMENT {i+1}:\n{doc['text']}" 
                for i, doc in enumerate(documents)
            ])
            
            judge_prompt = f"""You are an expert relevance evaluator for a loan application rule generation system.

QUERY: {query}

RETRIEVED DOCUMENTS:
{doc_texts}

Task: Rate the relevance of each document to the query on a scale of 0.0 to 1.0.
- 1.0 = Highly relevant, directly helps answer the query
- 0.5 = Somewhat relevant, provides context
- 0.0 = Not relevant at all

Respond ONLY with a JSON array of scores, one per document in order.
Example: [0.9, 0.6, 0.2]

Scores:"""

            response = self.client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You are a relevance scoring expert. Respond only with a JSON array of numbers."},
                    {"role": "user", "content": judge_prompt}
                ],
                temperature=0.1,
                max_tokens=100
            )
            
            content = response.choices[0].message.content.strip()
            import json
            scores = json.loads(content)
            
            # clamp to 0-1 range
            scores = [max(0.0, min(1.0, float(s))) for s in scores]
            
            # pad if llm didnt return enough scores
            while len(scores) < len(documents):
                scores.append(0.5)
            
            logger.debug(f"LLM judge scores: {scores}")
            return scores[:len(documents)]
            
        except Exception as e:
            logger.error(f"LLM judge failed: {str(e)}")
            # fallback - just use retrieval scores
            return [doc['score'] / (doc['score'] + 1.0) for doc in documents]
    
    def refine_query(self, original_query, low_relevance_docs):
        # if results are bad, ask llm to rewrite the query
        # usually helps by adding more specific terms
        try:
            refine_prompt = f"""Original query: "{original_query}"

The retrieved documents were not very relevant. Suggest a better search query that focuses on key loan application terms like:
- Bureau score, credit score, CIBIL
- Business vintage, age
- Overdue amounts, DPD
- Income, FOIR
- GST, banking metrics

Respond with ONLY the improved query, no explanation.

Improved query:"""

            response = self.client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You are a query refinement expert for loan application rules."},
                    {"role": "user", "content": refine_prompt}
                ],
                temperature=0.3,
                max_tokens=100
            )
            
            refined = response.choices[0].message.content.strip().strip('"')
            logger.info(f"Refined query: {refined}")
            return refined if refined else original_query
            
        except Exception as e:
            logger.error(f"Query refinement failed: {str(e)}")
            return original_query
    
    def retrieve_with_crag(self, query, top_k=2, relevance_threshold=0.7):
        """
        CRAG = Corrective RAG
        retrieves docs, checks if theyre good, retries if not
        """
        try:
            logger.info(f"CRAG: Retrieving for query: '{query[:50]}...'")
            
            docs = self.retrieve(query, top_k=top_k)
            
            if not docs:
                logger.warning("No documents retrieved")
                return [], 0.0
            
            # judge how relevant results are
            relevance_scores = self.judge_relevance(query, docs)
            
            for doc, score in zip(docs, relevance_scores):
                doc['relevance'] = score
            
            avg_relevance = np.mean(relevance_scores)
            logger.debug(f"CRAG: Initial relevance: {avg_relevance:.3f}")
            
            # if relevance sucks, refine and try again
            if avg_relevance < relevance_threshold:
                logger.info("CRAG: Low relevance detected, refining query...")
                
                refined_query = self.refine_query(query, [d['text'] for d in docs])
                logger.debug(f"CRAG: Refined query: '{refined_query[:50]}...'")
                
                # try again with better query
                refined_docs = self.retrieve(refined_query, top_k=top_k)
                
                if refined_docs:
                    refined_relevance = self.judge_relevance(refined_query, refined_docs)
                    
                    for doc, score in zip(refined_docs, refined_relevance):
                        doc['relevance'] = score
                    
                    refined_avg = np.mean(refined_relevance)
                    logger.debug(f"CRAG: Refined relevance: {refined_avg:.3f}")
                    
                    # use refined results only if theyre better
                    if refined_avg > avg_relevance:
                        docs = refined_docs
                        avg_relevance = refined_avg
                        logger.info("CRAG: Using refined results")
                    else:
                        logger.info("CRAG: Keeping original results")
            
            # sort by relevance score
            docs.sort(key=lambda x: x['relevance'], reverse=True)
            
            return docs, avg_relevance
        
        except Exception as e:
            logger.error(f"CRAG failed: {str(e)}")
            return [], 0.0
    
    def format_context(self, documents, max_length=500):
        """Format docs into a string for llm context"""
        try:
            if not documents:
                return "No relevant policies found."
            
            context_parts = []
            for i, doc in enumerate(documents):
                text = doc['text'][:max_length]
                relevance = doc.get('relevance', doc.get('score', 0))
                context_parts.append(
                    f"Policy {i+1} (relevance: {relevance:.2f}):\n{text}"
                )
            
            return "\n\n".join(context_parts)
        except Exception as e:
            logger.error(f"Context formatting failed: {str(e)}")
            return "Error formatting policy context."