|
|
import numpy as np |
|
|
import faiss |
|
|
import os |
|
|
from openai import OpenAI |
|
|
from loguru import logger |
|
|
from app.services.embedding_service import EmbeddingService |
|
|
from app.constants import POLICIES |
|
|
|
|
|
|
|
|
class RAGService:
    """Policy retrieval over a FAISS index, with CRAG support.

    CRAG (Corrective RAG) re-checks retrieval quality with an LLM judge
    and retries with a refined query when the first pass looks irrelevant.
    """

    def __init__(self, embedding_service):
        """Wire up dependencies, copy the policy corpus, and index it.

        Args:
            embedding_service: object exposing ``embed_single`` /
                ``embed_batch`` / ``get_dimension`` (see EmbeddingService).

        Raises:
            Exception: re-raised (after logging) if any startup step fails.
        """
        try:
            self.embed_service = embedding_service
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

            logger.info("Initializing RAG Service...")

            # Copy so add_documents() never mutates the shared POLICIES constant.
            self.policies = POLICIES.copy()
            self.build_index()

            logger.success(f"RAG Service initialized with {len(self.policies)} policies")
        except Exception as e:
            logger.error(f"Failed to initialize RAG Service: {str(e)}")
            raise
|
|
|
|
|
def build_index(self): |
|
|
"""Build FAISS index from policy docs""" |
|
|
try: |
|
|
if not self.policies: |
|
|
logger.warning("No policies to index") |
|
|
self.index = None |
|
|
self.policy_embeddings = None |
|
|
return |
|
|
|
|
|
logger.info(f"Embedding {len(self.policies)} policy documents...") |
|
|
|
|
|
|
|
|
self.policy_embeddings = self.embed_service.embed_batch(self.policies) |
|
|
|
|
|
|
|
|
dimension = self.embed_service.get_dimension() |
|
|
self.index = faiss.IndexFlatIP(dimension) |
|
|
|
|
|
self.index.add(self.policy_embeddings.astype('float32')) |
|
|
|
|
|
logger.debug(f"FAISS index built with {self.index.ntotal} vectors") |
|
|
except Exception as e: |
|
|
logger.error(f"Index building failed: {str(e)}") |
|
|
self.index = None |
|
|
self.policy_embeddings = None |
|
|
|
|
|
def add_documents(self, new_docs): |
|
|
"""Add new docs to index on the fly""" |
|
|
try: |
|
|
if not new_docs: |
|
|
return |
|
|
|
|
|
logger.info(f"Adding {len(new_docs)} temporary documents...") |
|
|
|
|
|
new_embeddings = self.embed_service.embed_batch(new_docs) |
|
|
self.index.add(new_embeddings.astype('float32')) |
|
|
self.policies.extend(new_docs) |
|
|
|
|
|
|
|
|
self.policy_embeddings = np.vstack([self.policy_embeddings, new_embeddings]) |
|
|
|
|
|
logger.debug(f"Index now contains {self.index.ntotal} documents") |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to add documents: {str(e)}") |
|
|
|
|
|
def retrieve(self, query, top_k=3): |
|
|
"""Basic retrieval from FAISS""" |
|
|
try: |
|
|
if self.index is None or self.index.ntotal == 0: |
|
|
logger.warning("Index is empty, returning no results") |
|
|
return [] |
|
|
|
|
|
|
|
|
query_emb = self.embed_service.embed_single(query).reshape(1, -1) |
|
|
scores, indices = self.index.search(query_emb.astype('float32'), top_k) |
|
|
|
|
|
results = [] |
|
|
for i, (score, idx) in enumerate(zip(scores[0], indices[0])): |
|
|
if idx < len(self.policies): |
|
|
results.append({ |
|
|
'text': self.policies[idx], |
|
|
'score': float(score), |
|
|
'index': int(idx), |
|
|
'rank': i + 1 |
|
|
}) |
|
|
|
|
|
logger.debug(f"Retrieved {len(results)} documents") |
|
|
return results |
|
|
except Exception as e: |
|
|
logger.error(f"Retrieval failed: {str(e)}") |
|
|
return [] |
|
|
|
|
|
def judge_relevance(self, query, documents): |
|
|
|
|
|
|
|
|
try: |
|
|
if not documents: |
|
|
return [] |
|
|
|
|
|
doc_texts = "\n\n".join([ |
|
|
f"DOCUMENT {i+1}:\n{doc['text']}" |
|
|
for i, doc in enumerate(documents) |
|
|
]) |
|
|
|
|
|
judge_prompt = f"""You are an expert relevance evaluator for a loan application rule generation system. |
|
|
|
|
|
QUERY: {query} |
|
|
|
|
|
RETRIEVED DOCUMENTS: |
|
|
{doc_texts} |
|
|
|
|
|
Task: Rate the relevance of each document to the query on a scale of 0.0 to 1.0. |
|
|
- 1.0 = Highly relevant, directly helps answer the query |
|
|
- 0.5 = Somewhat relevant, provides context |
|
|
- 0.0 = Not relevant at all |
|
|
|
|
|
Respond ONLY with a JSON array of scores, one per document in order. |
|
|
Example: [0.9, 0.6, 0.2] |
|
|
|
|
|
Scores:""" |
|
|
|
|
|
response = self.client.chat.completions.create( |
|
|
model="gpt-4o-mini", |
|
|
messages=[ |
|
|
{"role": "system", "content": "You are a relevance scoring expert. Respond only with a JSON array of numbers."}, |
|
|
{"role": "user", "content": judge_prompt} |
|
|
], |
|
|
temperature=0.1, |
|
|
max_tokens=100 |
|
|
) |
|
|
|
|
|
content = response.choices[0].message.content.strip() |
|
|
import json |
|
|
scores = json.loads(content) |
|
|
|
|
|
|
|
|
scores = [max(0.0, min(1.0, float(s))) for s in scores] |
|
|
|
|
|
|
|
|
while len(scores) < len(documents): |
|
|
scores.append(0.5) |
|
|
|
|
|
logger.debug(f"LLM judge scores: {scores}") |
|
|
return scores[:len(documents)] |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"LLM judge failed: {str(e)}") |
|
|
|
|
|
return [doc['score'] / (doc['score'] + 1.0) for doc in documents] |
|
|
|
|
|
def refine_query(self, original_query, low_relevance_docs): |
|
|
|
|
|
|
|
|
try: |
|
|
refine_prompt = f"""Original query: "{original_query}" |
|
|
|
|
|
The retrieved documents were not very relevant. Suggest a better search query that focuses on key loan application terms like: |
|
|
- Bureau score, credit score, CIBIL |
|
|
- Business vintage, age |
|
|
- Overdue amounts, DPD |
|
|
- Income, FOIR |
|
|
- GST, banking metrics |
|
|
|
|
|
Respond with ONLY the improved query, no explanation. |
|
|
|
|
|
Improved query:""" |
|
|
|
|
|
response = self.client.chat.completions.create( |
|
|
model="gpt-4o-mini", |
|
|
messages=[ |
|
|
{"role": "system", "content": "You are a query refinement expert for loan application rules."}, |
|
|
{"role": "user", "content": refine_prompt} |
|
|
], |
|
|
temperature=0.3, |
|
|
max_tokens=100 |
|
|
) |
|
|
|
|
|
refined = response.choices[0].message.content.strip().strip('"') |
|
|
logger.info(f"Refined query: {refined}") |
|
|
return refined if refined else original_query |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Query refinement failed: {str(e)}") |
|
|
return original_query |
|
|
|
|
|
def retrieve_with_crag(self, query, top_k=2, relevance_threshold=0.7): |
|
|
""" |
|
|
CRAG = Corrective RAG |
|
|
retrieves docs, checks if theyre good, retries if not |
|
|
""" |
|
|
try: |
|
|
logger.info(f"CRAG: Retrieving for query: '{query[:50]}...'") |
|
|
|
|
|
docs = self.retrieve(query, top_k=top_k) |
|
|
|
|
|
if not docs: |
|
|
logger.warning("No documents retrieved") |
|
|
return [], 0.0 |
|
|
|
|
|
|
|
|
relevance_scores = self.judge_relevance(query, docs) |
|
|
|
|
|
for doc, score in zip(docs, relevance_scores): |
|
|
doc['relevance'] = score |
|
|
|
|
|
avg_relevance = np.mean(relevance_scores) |
|
|
logger.debug(f"CRAG: Initial relevance: {avg_relevance:.3f}") |
|
|
|
|
|
|
|
|
if avg_relevance < relevance_threshold: |
|
|
logger.info("CRAG: Low relevance detected, refining query...") |
|
|
|
|
|
refined_query = self.refine_query(query, [d['text'] for d in docs]) |
|
|
logger.debug(f"CRAG: Refined query: '{refined_query[:50]}...'") |
|
|
|
|
|
|
|
|
refined_docs = self.retrieve(refined_query, top_k=top_k) |
|
|
|
|
|
if refined_docs: |
|
|
refined_relevance = self.judge_relevance(refined_query, refined_docs) |
|
|
|
|
|
for doc, score in zip(refined_docs, refined_relevance): |
|
|
doc['relevance'] = score |
|
|
|
|
|
refined_avg = np.mean(refined_relevance) |
|
|
logger.debug(f"CRAG: Refined relevance: {refined_avg:.3f}") |
|
|
|
|
|
|
|
|
if refined_avg > avg_relevance: |
|
|
docs = refined_docs |
|
|
avg_relevance = refined_avg |
|
|
logger.info("CRAG: Using refined results") |
|
|
else: |
|
|
logger.info("CRAG: Keeping original results") |
|
|
|
|
|
|
|
|
docs.sort(key=lambda x: x['relevance'], reverse=True) |
|
|
|
|
|
return docs, avg_relevance |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"CRAG failed: {str(e)}") |
|
|
return [], 0.0 |
|
|
|
|
|
def format_context(self, documents, max_length=500): |
|
|
"""Format docs into a string for llm context""" |
|
|
try: |
|
|
if not documents: |
|
|
return "No relevant policies found." |
|
|
|
|
|
context_parts = [] |
|
|
for i, doc in enumerate(documents): |
|
|
text = doc['text'][:max_length] |
|
|
relevance = doc.get('relevance', doc.get('score', 0)) |
|
|
context_parts.append( |
|
|
f"Policy {i+1} (relevance: {relevance:.2f}):\n{text}" |
|
|
) |
|
|
|
|
|
return "\n\n".join(context_parts) |
|
|
except Exception as e: |
|
|
logger.error(f"Context formatting failed: {str(e)}") |
|
|
return "Error formatting policy context." |
|
|
|