# Source: datasciencesage — "Upload 5 files" (commit 40e6b7a, verified)
import json
import os

import faiss
import numpy as np
from loguru import logger
from openai import OpenAI

from app.constants import POLICIES
from app.services.embedding_service import EmbeddingService
class RAGService:
    """Policy retrieval over a FAISS inner-product index, with Corrective RAG.

    CRAG (Corrective RAG) retrieves candidate policies, asks an LLM judge to
    score their relevance, and — when the average score falls below a
    threshold — rewrites the query and retries, keeping whichever result set
    the judge rates higher.
    """

    def __init__(self, embedding_service):
        """Wire up the embedding backend, the OpenAI client, and the index.

        Args:
            embedding_service: object exposing ``embed_batch(texts)``,
                ``embed_single(text)`` and ``get_dimension()``.

        Raises:
            Exception: anything raised during setup is logged and re-raised.
        """
        try:
            self.embed_service = embedding_service
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            logger.info("Initializing RAG Service...")
            # Copy so add_documents() never mutates the shared POLICIES constant.
            self.policies = POLICIES.copy()
            self.build_index()
            logger.success(f"RAG Service initialized with {len(self.policies)} policies")
        except Exception as e:
            logger.error(f"Failed to initialize RAG Service: {str(e)}")
            raise

    def build_index(self):
        """Build the FAISS index from policy docs.

        On any failure (or with no policies) leaves ``self.index`` and
        ``self.policy_embeddings`` set to ``None`` so callers can detect the
        empty state.
        """
        try:
            if not self.policies:
                logger.warning("No policies to index")
                self.index = None
                self.policy_embeddings = None
                return
            logger.info(f"Embedding {len(self.policies)} policy documents...")
            self.policy_embeddings = self.embed_service.embed_batch(self.policies)
            # Inner product == cosine similarity here, assuming the embedding
            # service returns L2-normalized vectors — TODO confirm upstream.
            dimension = self.embed_service.get_dimension()
            self.index = faiss.IndexFlatIP(dimension)
            self.index.add(self.policy_embeddings.astype('float32'))
            logger.debug(f"FAISS index built with {self.index.ntotal} vectors")
        except Exception as e:
            logger.error(f"Index building failed: {str(e)}")
            self.index = None
            self.policy_embeddings = None

    def add_documents(self, new_docs):
        """Add new docs to the index on the fly.

        Fix: previously, if the index had never been built (service started
        with zero policies, or a failed build), ``self.index.add`` raised an
        AttributeError that was swallowed by the except — the documents were
        silently dropped. Now the index is (re)built from scratch instead.

        Args:
            new_docs: list of document strings; empty/None is a no-op.
        """
        try:
            if not new_docs:
                return
            logger.info(f"Adding {len(new_docs)} temporary documents...")
            if self.index is None or self.policy_embeddings is None:
                # No usable index yet — fold the docs in and build from scratch.
                self.policies.extend(new_docs)
                self.build_index()
                return
            new_embeddings = self.embed_service.embed_batch(new_docs)
            self.index.add(new_embeddings.astype('float32'))
            self.policies.extend(new_docs)
            # Keep the raw embedding matrix in sync with the index.
            self.policy_embeddings = np.vstack([self.policy_embeddings, new_embeddings])
            logger.debug(f"Index now contains {self.index.ntotal} documents")
        except Exception as e:
            logger.error(f"Failed to add documents: {str(e)}")

    def retrieve(self, query, top_k=3):
        """Return up to ``top_k`` policies most similar to ``query``.

        Args:
            query: natural-language search string.
            top_k: maximum number of results.

        Returns:
            List of dicts with ``text``, ``score``, ``index``, ``rank``;
            empty list when the index is empty or retrieval fails.
        """
        try:
            if self.index is None or self.index.ntotal == 0:
                logger.warning("Index is empty, returning no results")
                return []
            query_emb = self.embed_service.embed_single(query).reshape(1, -1)
            scores, indices = self.index.search(query_emb.astype('float32'), top_k)
            results = []
            for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
                # FAISS pads with -1 when fewer than top_k vectors exist.
                # The old check (idx < len) let -1 through, and Python's
                # negative indexing then returned the LAST policy — fixed.
                if 0 <= idx < len(self.policies):
                    results.append({
                        'text': self.policies[idx],
                        'score': float(score),
                        'index': int(idx),
                        'rank': i + 1
                    })
            logger.debug(f"Retrieved {len(results)} documents")
            return results
        except Exception as e:
            logger.error(f"Retrieval failed: {str(e)}")
            return []

    def judge_relevance(self, query, documents):
        """Ask the LLM to rate each document's relevance to ``query``.

        Filters out garbage retrieval hits. Scores are clamped to [0, 1] and
        padded with a neutral 0.5 if the model returns too few values.

        Returns:
            List of floats, one per document; on LLM failure falls back to a
            monotone squash of the raw retrieval scores into (0, 1).
        """
        try:
            if not documents:
                return []
            doc_texts = "\n\n".join([
                f"DOCUMENT {i+1}:\n{doc['text']}"
                for i, doc in enumerate(documents)
            ])
            judge_prompt = f"""You are an expert relevance evaluator for a loan application rule generation system.
QUERY: {query}
RETRIEVED DOCUMENTS:
{doc_texts}
Task: Rate the relevance of each document to the query on a scale of 0.0 to 1.0.
- 1.0 = Highly relevant, directly helps answer the query
- 0.5 = Somewhat relevant, provides context
- 0.0 = Not relevant at all
Respond ONLY with a JSON array of scores, one per document in order.
Example: [0.9, 0.6, 0.2]
Scores:"""
            response = self.client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You are a relevance scoring expert. Respond only with a JSON array of numbers."},
                    {"role": "user", "content": judge_prompt}
                ],
                temperature=0.1,
                max_tokens=100
            )
            content = response.choices[0].message.content.strip()
            scores = json.loads(content)
            # Clamp to the documented [0, 1] range.
            scores = [max(0.0, min(1.0, float(s))) for s in scores]
            # Pad with a neutral score if the LLM returned too few values.
            while len(scores) < len(documents):
                scores.append(0.5)
            logger.debug(f"LLM judge scores: {scores}")
            return scores[:len(documents)]
        except Exception as e:
            logger.error(f"LLM judge failed: {str(e)}")
            # Fallback: squash raw retrieval scores into (0, 1).
            return [doc['score'] / (doc['score'] + 1.0) for doc in documents]

    def refine_query(self, original_query, low_relevance_docs):
        """Ask the LLM to rewrite a query whose results scored poorly.

        Args:
            original_query: the query that produced low-relevance hits.
            low_relevance_docs: the poorly-scoring doc texts.
                NOTE(review): currently unused in the prompt — kept for
                interface compatibility; consider feeding it to the LLM.

        Returns:
            The refined query, or ``original_query`` on any failure.
        """
        try:
            refine_prompt = f"""Original query: "{original_query}"
The retrieved documents were not very relevant. Suggest a better search query that focuses on key loan application terms like:
- Bureau score, credit score, CIBIL
- Business vintage, age
- Overdue amounts, DPD
- Income, FOIR
- GST, banking metrics
Respond with ONLY the improved query, no explanation.
Improved query:"""
            response = self.client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You are a query refinement expert for loan application rules."},
                    {"role": "user", "content": refine_prompt}
                ],
                temperature=0.3,
                max_tokens=100
            )
            refined = response.choices[0].message.content.strip().strip('"')
            logger.info(f"Refined query: {refined}")
            return refined if refined else original_query
        except Exception as e:
            logger.error(f"Query refinement failed: {str(e)}")
            return original_query

    def retrieve_with_crag(self, query, top_k=2, relevance_threshold=0.7):
        """Corrective RAG: retrieve, judge, and retry with a refined query.

        Retrieves ``top_k`` docs, scores them with the LLM judge, and — if
        the average relevance is below ``relevance_threshold`` — refines the
        query and retries once, keeping whichever set scored higher.

        Returns:
            ``(docs, avg_relevance)`` where docs are sorted by relevance
            (descending) and avg_relevance is a plain float; ``([], 0.0)``
            on failure or empty retrieval.
        """
        try:
            logger.info(f"CRAG: Retrieving for query: '{query[:50]}...'")
            docs = self.retrieve(query, top_k=top_k)
            if not docs:
                logger.warning("No documents retrieved")
                return [], 0.0
            relevance_scores = self.judge_relevance(query, docs)
            for doc, score in zip(docs, relevance_scores):
                doc['relevance'] = score
            avg_relevance = float(np.mean(relevance_scores))
            logger.debug(f"CRAG: Initial relevance: {avg_relevance:.3f}")
            # One corrective pass when the judge rates results poorly.
            if avg_relevance < relevance_threshold:
                logger.info("CRAG: Low relevance detected, refining query...")
                refined_query = self.refine_query(query, [d['text'] for d in docs])
                logger.debug(f"CRAG: Refined query: '{refined_query[:50]}...'")
                refined_docs = self.retrieve(refined_query, top_k=top_k)
                if refined_docs:
                    refined_relevance = self.judge_relevance(refined_query, refined_docs)
                    for doc, score in zip(refined_docs, refined_relevance):
                        doc['relevance'] = score
                    refined_avg = float(np.mean(refined_relevance))
                    logger.debug(f"CRAG: Refined relevance: {refined_avg:.3f}")
                    # Keep the refined set only if it actually improved.
                    if refined_avg > avg_relevance:
                        docs = refined_docs
                        avg_relevance = refined_avg
                        logger.info("CRAG: Using refined results")
                    else:
                        logger.info("CRAG: Keeping original results")
            docs.sort(key=lambda x: x['relevance'], reverse=True)
            return docs, avg_relevance
        except Exception as e:
            logger.error(f"CRAG failed: {str(e)}")
            return [], 0.0

    def format_context(self, documents, max_length=500):
        """Format retrieved docs into a single string for LLM context.

        Args:
            documents: result dicts from retrieve/retrieve_with_crag; each
                doc's text is truncated to ``max_length`` characters.
            max_length: per-document character cap.

        Returns:
            A human-readable context string; a placeholder message when
            ``documents`` is empty or formatting fails.
        """
        try:
            if not documents:
                return "No relevant policies found."
            context_parts = []
            for i, doc in enumerate(documents):
                text = doc['text'][:max_length]
                # Prefer the LLM judge's relevance; fall back to raw score.
                relevance = doc.get('relevance', doc.get('score', 0))
                context_parts.append(
                    f"Policy {i+1} (relevance: {relevance:.2f}):\n{text}"
                )
            return "\n\n".join(context_parts)
        except Exception as e:
            logger.error(f"Context formatting failed: {str(e)}")
            return "Error formatting policy context."