# src/rag_engine.py
# Author: Kirtan001 — commit 808e67d ("UI: Rebranding and cleanup")
import os
import sys
import psutil
from typing import Tuple, List, Optional, Any
from loguru import logger
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from src.config import settings
class SatelliteRAG:
    """Retrieval-Augmented Generation engine over a Chroma store of satellite data.

    Wires together three components at construction time:
      1. HuggingFace sentence embeddings (CPU, L2-normalized).
      2. A persistent Chroma vector store holding pre-ingested document chunks.
      3. A Groq-hosted chat LLM, used both to rewrite follow-up questions into
         standalone ones and to generate the final answer.
    """

    def __init__(self) -> None:
        """Initialize the RAG engine with embeddings, vector store, and LLM."""
        self._log_memory_usage()
        # 1. Load Embeddings
        self.embeddings = self._load_embeddings()
        # 2. Initialize Vector Store (Chroma) — requires self.embeddings
        self.vector_store = self._init_vector_store()
        # 3. Initialize LLM
        self.llm = self._init_llm()
        logger.info("RAG Engine successfully initialized.")

    def _log_memory_usage(self) -> None:
        """Log current RSS of this process (useful on memory-limited hosts)."""
        process = psutil.Process(os.getpid())
        mem_mb = process.memory_info().rss / 1024 / 1024
        logger.info(f"RAG Engine initializing... Memory Usage: {mem_mb:.2f} MB")

    def _load_embeddings(self) -> HuggingFaceEmbeddings:
        """Load the HuggingFace embedding model named in settings.

        Runs on CPU with normalized outputs so cosine similarity is well-defined.

        Raises:
            Exception: re-raised from the underlying model loader on failure.
        """
        logger.info(f"Step 1/3: Loading HuggingFace Embeddings ({settings.EMBEDDING_MODEL})...")
        try:
            embeddings = HuggingFaceEmbeddings(
                model_name=settings.EMBEDDING_MODEL,
                model_kwargs={'device': 'cpu'},
                encode_kwargs={'normalize_embeddings': True},
            )
        except Exception as e:
            logger.error(f"Failed to load embeddings: {e}")
            # Bare `raise` preserves the original traceback (unlike `raise e`).
            raise
        logger.info("Embeddings loaded successfully.")
        return embeddings

    def _init_vector_store(self) -> Chroma:
        """Open (or create) the persistent Chroma collection from settings.

        Raises:
            Exception: re-raised if the store cannot be opened or counted.
        """
        logger.info(f"Step 2/3: Connecting to ChromaDB at {settings.CHROMA_PATH}...")
        try:
            vector_store = Chroma(
                collection_name=settings.COLLECTION_NAME,
                embedding_function=self.embeddings,
                persist_directory=str(settings.CHROMA_PATH),
            )
            # Basic check to see if we can access the collection.
            # Accessing ._collection is internal API, but effective as a smoke test.
            count = vector_store._collection.count()
            logger.info(f"Vector Store ready. Contains {count} documents.")
            return vector_store
        except Exception as e:
            logger.error(f"Vector Store initialization failed: {e}")
            raise

    def _init_llm(self) -> ChatGroq:
        """Initialize the Groq chat model (deterministic: temperature=0).

        Raises:
            ValueError: if GROQ_API_KEY is missing from settings.
        """
        if not settings.GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not found in environment variables.")
        return ChatGroq(
            temperature=0,
            model_name=settings.LLM_MODEL,
            api_key=settings.GROQ_API_KEY,
        )

    def _rewrite_query(self, question: str, chat_history: List[Tuple[str, str]]) -> str:
        """Rewrite *question* using *chat_history* so it stands alone.

        Returns the original question unchanged when there is no history, or
        when the rewrite LLM call fails (best-effort behavior, never raises).
        """
        if not chat_history:
            return question
        logger.info("Rewriting question with conversational context...")
        template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
Chat History:
{history}
Follow Up Input: {question}
Standalone Question:"""
        try:
            prompt = ChatPromptTemplate.from_template(template)
            chain = prompt | self.llm | StrOutputParser()
            # Format history as a string
            history_str = "\n".join([f"User: {q}\nAssistant: {a}" for q, a in chat_history])
            standalone_question = chain.invoke({"history": history_str, "question": question})
            logger.info(f"Rephrased '{question}' -> '{standalone_question}'")
            return standalone_question
        except Exception as e:
            # Deliberate best-effort: fall back to the raw question on failure.
            logger.error(f"Failed to rewrite question: {e}")
            return question

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True
    )
    def query(
        self,
        question: str,
        chat_history: Optional[List[Tuple[str, str]]] = None,
    ) -> Tuple[str, List[Document]]:
        """
        Query the RAG system.
        Retries up to 3 times on failure (e.g. API Rate Limits).

        Args:
            question: The user's (possibly follow-up) question.
            chat_history: Prior (user, assistant) turns; None or [] for a fresh chat.
                (Default is None rather than a mutable [] to avoid shared state
                across calls.)

        Returns:
            A tuple of (LLM answer text, retrieved source Documents).
        """
        chat_history = chat_history if chat_history is not None else []
        # 0. Contextual Rewriting
        standalone_question = self._rewrite_query(question, chat_history)
        # Retrieval
        logger.info(f"Starting query process for: {standalone_question}")
        try:
            # Force GC to clear any previous large objects
            import gc
            gc.collect()
            logger.info("Step 1: Initializing retriever...")
            # Reduced k from 10 to 4 to prevent Memory OOM on free tier spaces
            retriever = self.vector_store.as_retriever(search_kwargs={"k": 4})
            logger.info("Step 2: Invoking retriever (Embedding inference)...")
            docs = retriever.invoke(standalone_question)
            logger.info(f"Step 3: Retrieval successful. Found {len(docs)} chunks.")
            context_text = "\n\n".join([d.page_content for d in docs])
            logger.info("Step 4: Constructing prompt and calling Groq LLM...")
            template = """
You are a Space Satellite Assistant, an expert in technical satellite data.
Use the following context to answer the user's question.
Guidelines:
1. **Be Precise:** If the context mentions specific numbers (Mass, Date, Orbit), use them.
2. **Synonyms:** If asked for "Instruments", look for "Payload" or "Cameras".
3. **Honesty:** If the answer is truly not in the context, say "I don't have that specific information."
Context:
{context}
Question: {question}
Answer:
"""
            prompt = ChatPromptTemplate.from_template(template)
            chain = prompt | self.llm | StrOutputParser()
            # Use original question for answer generation to keep tone, but context is from standalone
            response = chain.invoke({"context": context_text, "question": question})
            logger.info("Step 5: LLM generation successful.")
            return response, docs
        except Exception as e:
            logger.error(f"Error inside SatelliteRAG.query: {e}")
            # Bare `raise` keeps the original traceback; @retry(reraise=True)
            # surfaces it to the caller after the final attempt.
            raise