|
|
import os |
|
|
import sys |
|
|
import psutil |
|
|
from typing import Tuple, List, Optional, Any |
|
|
from loguru import logger |
|
|
from langchain_chroma import Chroma |
|
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
from langchain_groq import ChatGroq |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from langchain_core.output_parsers import StrOutputParser |
|
|
from langchain_core.documents import Document |
|
|
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type |
|
|
|
|
|
from src.config import settings |
|
|
|
|
|
class SatelliteRAG:
    """Conversational RAG engine over satellite documents.

    Pipeline: rewrite the user's question into a standalone form using chat
    history, retrieve supporting chunks from a persisted Chroma collection,
    then generate an answer with a Groq-hosted LLM.
    """

    def __init__(self) -> None:
        """Initialize the RAG engine with embeddings, vector store, and LLM.

        Raises:
            ValueError: if GROQ_API_KEY is not configured.
            Exception: propagated from embedding or vector-store setup.
        """
        self._log_memory_usage()
        self.embeddings = self._load_embeddings()
        self.vector_store = self._init_vector_store()
        self.llm = self._init_llm()
        logger.info("RAG Engine successfully initialized.")

    def _log_memory_usage(self) -> None:
        """Log the current process RSS in MB (startup diagnostics)."""
        process = psutil.Process(os.getpid())
        mem_mb = process.memory_info().rss / 1024 / 1024
        logger.info(f"RAG Engine initializing... Memory Usage: {mem_mb:.2f} MB")

    def _load_embeddings(self) -> HuggingFaceEmbeddings:
        """Load the HuggingFace embedding model on CPU.

        Returns:
            HuggingFaceEmbeddings configured to emit normalized vectors so
            similarity scores are comparable across queries.
        """
        logger.info(f"Step 1/3: Loading HuggingFace Embeddings ({settings.EMBEDDING_MODEL})...")
        try:
            embeddings = HuggingFaceEmbeddings(
                model_name=settings.EMBEDDING_MODEL,
                model_kwargs={'device': 'cpu'},
                encode_kwargs={'normalize_embeddings': True}
            )
            logger.info("Embeddings loaded successfully.")
            return embeddings
        except Exception:
            # Bare raise preserves the original traceback (was `raise e`).
            logger.exception("Failed to load embeddings")
            raise

    def _init_vector_store(self) -> Chroma:
        """Connect to the persisted Chroma collection using self.embeddings.

        Returns:
            A Chroma store bound to settings.COLLECTION_NAME at settings.CHROMA_PATH.
        """
        logger.info(f"Step 2/3: Connecting to ChromaDB at {settings.CHROMA_PATH}...")
        try:
            vector_store = Chroma(
                collection_name=settings.COLLECTION_NAME,
                embedding_function=self.embeddings,
                persist_directory=str(settings.CHROMA_PATH)
            )
            # NOTE(review): `_collection` is a private langchain-chroma attribute.
            # The count is diagnostic only, so a failure here must not abort init.
            try:
                count = vector_store._collection.count()
                logger.info(f"Vector Store ready. Contains {count} documents.")
            except Exception:
                logger.warning("Vector Store ready, but document count is unavailable.")
            return vector_store
        except Exception:
            logger.exception("Vector Store initialization failed")
            raise

    def _init_llm(self) -> ChatGroq:
        """Build the Groq chat model.

        Raises:
            ValueError: if GROQ_API_KEY is missing from settings.
        """
        if not settings.GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not found in environment variables.")

        return ChatGroq(
            temperature=0,  # deterministic output for factual Q&A
            model_name=settings.LLM_MODEL,
            api_key=settings.GROQ_API_KEY
        )

    def _rewrite_query(self, question: str, chat_history: List[Tuple[str, str]]) -> str:
        """Rewrite *question* into a standalone question using chat history.

        Best-effort: returns the original question unchanged when there is no
        history or when the rewrite call fails. Never raises.
        """
        if not chat_history:
            return question

        logger.info("Rewriting question with conversational context...")

        template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.

Chat History:
{history}

Follow Up Input: {question}
Standalone Question:"""

        try:
            prompt = ChatPromptTemplate.from_template(template)
            chain = prompt | self.llm | StrOutputParser()

            history_str = "\n".join(f"User: {q}\nAssistant: {a}" for q, a in chat_history)

            standalone_question = chain.invoke({"history": history_str, "question": question})
            logger.info(f"Rephrased '{question}' -> '{standalone_question}'")
            return standalone_question
        except Exception as e:
            # Degrade gracefully: an un-rewritten question still yields a usable answer.
            logger.error(f"Failed to rewrite question: {e}")
            return question

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True
    )
    def query(self, question: str, chat_history: Optional[List[Tuple[str, str]]] = None) -> Tuple[str, List[Document]]:
        """Answer *question* using retrieval-augmented generation.

        Args:
            question: The user's (possibly follow-up) question.
            chat_history: Prior (user, assistant) turns; defaults to none.

        Returns:
            Tuple of (answer text, list of retrieved source Documents).

        Raises:
            Exception: after 3 retry attempts (e.g. API rate limits), the last
            error is re-raised (`reraise=True`).
        """
        # BUG FIX: the default was a mutable `[]`, shared across all calls.
        if chat_history is None:
            chat_history = []

        standalone_question = self._rewrite_query(question, chat_history)

        logger.info(f"Starting query process for: {standalone_question}")
        try:
            # Free memory before the embedding inference step (CPU-bound, RAM-heavy).
            import gc
            gc.collect()

            logger.info("Step 1: Initializing retriever...")
            retriever = self.vector_store.as_retriever(search_kwargs={"k": 4})

            logger.info("Step 2: Invoking retriever (Embedding inference)...")
            docs = retriever.invoke(standalone_question)
            logger.info(f"Step 3: Retrieval successful. Found {len(docs)} chunks.")

            context_text = "\n\n".join(d.page_content for d in docs)

            logger.info("Step 4: Constructing prompt and calling Groq LLM...")

            template = """
You are a Space Satellite Assistant, an expert in technical satellite data.
Use the following context to answer the user's question.

Guidelines:
1. **Be Precise:** If the context mentions specific numbers (Mass, Date, Orbit), use them.
2. **Synonyms:** If asked for "Instruments", look for "Payload" or "Cameras".
3. **Honesty:** If the answer is truly not in the context, say "I don't have that specific information."

Context:
{context}

Question: {question}
Answer:
"""

            prompt = ChatPromptTemplate.from_template(template)
            chain = prompt | self.llm | StrOutputParser()

            # BUG FIX: previously the raw `question` was sent to the LLM even though
            # retrieval used the rewritten one — follow-ups like "what about its mass?"
            # lost their resolved subject. Use the standalone question consistently.
            response = chain.invoke({"context": context_text, "question": standalone_question})
            logger.info("Step 5: LLM generation successful.")
            return response, docs
        except Exception:
            logger.exception("Error inside SatelliteRAG.query")
            raise
|
|
|