import os
import sys
import psutil
from typing import Tuple, List, Optional, Any
from loguru import logger
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from src.config import settings
class SatelliteRAG:
    """Retrieval-Augmented Generation engine for satellite technical data.

    Wires together three components at construction time:
      1. HuggingFace sentence embeddings (CPU, normalized vectors),
      2. a persistent Chroma vector store,
      3. a Groq-hosted chat LLM,
    and exposes ``query()`` for conversational question answering with
    history-aware query rewriting.
    """

    def __init__(self) -> None:
        """Initialize the RAG engine with embeddings, vector store, and LLM."""
        self._log_memory_usage()
        # 1. Load Embeddings
        self.embeddings = self._load_embeddings()
        # 2. Initialize Vector Store (Chroma)
        self.vector_store = self._init_vector_store()
        # 3. Initialize LLM
        self.llm = self._init_llm()
        logger.info("RAG Engine successfully initialized.")

    def _log_memory_usage(self) -> None:
        """Log the resident set size (RSS) of the current process in MB."""
        process = psutil.Process(os.getpid())
        mem_mb = process.memory_info().rss / 1024 / 1024
        logger.info(f"RAG Engine initializing... Memory Usage: {mem_mb:.2f} MB")

    def _load_embeddings(self) -> HuggingFaceEmbeddings:
        """Load the HuggingFace embedding model named in settings.

        Returns:
            A CPU-bound embedding model producing L2-normalized vectors.

        Raises:
            Exception: whatever the underlying loader raised, re-raised
                with its original traceback after logging.
        """
        logger.info(f"Step 1/3: Loading HuggingFace Embeddings ({settings.EMBEDDING_MODEL})...")
        try:
            embeddings = HuggingFaceEmbeddings(
                model_name=settings.EMBEDDING_MODEL,
                model_kwargs={'device': 'cpu'},
                encode_kwargs={'normalize_embeddings': True}
            )
            logger.info("Embeddings loaded successfully.")
            return embeddings
        except Exception as e:
            logger.error(f"Failed to load embeddings: {e}")
            # Bare raise preserves the original traceback (unlike `raise e`).
            raise

    def _init_vector_store(self) -> Chroma:
        """Open the persistent Chroma collection configured in settings.

        Returns:
            A ready Chroma vector store bound to ``self.embeddings``.

        Raises:
            Exception: re-raised from Chroma initialization after logging.
        """
        logger.info(f"Step 2/3: Connecting to ChromaDB at {settings.CHROMA_PATH}...")
        try:
            vector_store = Chroma(
                collection_name=settings.COLLECTION_NAME,
                embedding_function=self.embeddings,
                persist_directory=str(settings.CHROMA_PATH)
            )
            # Basic check to see if we can access the collection
            # accessing ._collection is a bit internal but effective for quick check
            count = vector_store._collection.count()
            logger.info(f"Vector Store ready. Contains {count} documents.")
            return vector_store
        except Exception as e:
            logger.error(f"Vector Store initialization failed: {e}")
            # Bare raise preserves the original traceback (unlike `raise e`).
            raise

    def _init_llm(self) -> ChatGroq:
        """Build the Groq chat model (deterministic, temperature=0).

        Raises:
            ValueError: if ``settings.GROQ_API_KEY`` is unset/empty.
        """
        if not settings.GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not found in environment variables.")
        return ChatGroq(
            temperature=0,
            model_name=settings.LLM_MODEL,
            api_key=settings.GROQ_API_KEY
        )

    def _rewrite_query(self, question: str, chat_history: List[Tuple[str, str]]) -> str:
        """Rewrite ``question`` into a standalone question using the history.

        Args:
            question: The raw follow-up question from the user.
            chat_history: Prior (user, assistant) message pairs; when empty,
                the question is already standalone and returned unchanged.

        Returns:
            The rephrased standalone question, or the original question if
            the history is empty or the LLM rewrite fails (best-effort).
        """
        if not chat_history:
            return question
        logger.info("Rewriting question with conversational context...")
        template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
Chat History:
{history}
Follow Up Input: {question}
Standalone Question:"""
        try:
            prompt = ChatPromptTemplate.from_template(template)
            chain = prompt | self.llm | StrOutputParser()
            # Format history as a string
            history_str = "\n".join([f"User: {q}\nAssistant: {a}" for q, a in chat_history])
            standalone_question = chain.invoke({"history": history_str, "question": question})
            logger.info(f"Rephrased '{question}' -> '{standalone_question}'")
            return standalone_question
        except Exception as e:
            # Deliberate best-effort: fall back to the raw question so a
            # rewrite failure never blocks the main query path.
            logger.error(f"Failed to rewrite question: {e}")
            return question

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True
    )
    def query(self, question: str, chat_history: Optional[List[Tuple[str, str]]] = None) -> Tuple[str, List[Document]]:
        """Answer ``question`` using retrieved context and the Groq LLM.

        Retries up to 3 times on failure (e.g. API rate limits), with
        exponential backoff, re-raising the last error.

        Args:
            question: The user's question (original tone is kept for the
                final answer generation).
            chat_history: Optional prior (user, assistant) pairs used only
                to rewrite the question for retrieval. Defaults to no
                history. (Previously a mutable ``[]`` default — fixed.)

        Returns:
            A ``(answer, documents)`` tuple: the LLM's answer string and
            the retrieved source chunks.
        """
        # Normalize the optional history; avoids the mutable-default pitfall.
        history = chat_history or []
        # 0. Contextual Rewriting
        standalone_question = self._rewrite_query(question, history)
        # Retrieval
        logger.info(f"Starting query process for: {standalone_question}")
        try:
            # Force GC to clear any previous large objects
            import gc
            gc.collect()
            logger.info("Step 1: Initializing retriever...")
            # Reduced k from 10 to 4 to prevent Memory OOM on free tier spaces
            retriever = self.vector_store.as_retriever(search_kwargs={"k": 4})
            logger.info("Step 2: Invoking retriever (Embedding inference)...")
            docs = retriever.invoke(standalone_question)
            logger.info(f"Step 3: Retrieval successful. Found {len(docs)} chunks.")
            context_text = "\n\n".join([d.page_content for d in docs])
            logger.info("Step 4: Constructing prompt and calling Groq LLM...")
            template = """
You are a Space Satellite Assistant, an expert in technical satellite data.
Use the following context to answer the user's question.
Guidelines:
1. **Be Precise:** If the context mentions specific numbers (Mass, Date, Orbit), use them.
2. **Synonyms:** If asked for "Instruments", look for "Payload" or "Cameras".
3. **Honesty:** If the answer is truly not in the context, say "I don't have that specific information."
Context:
{context}
Question: {question}
Answer:
"""
            prompt = ChatPromptTemplate.from_template(template)
            chain = prompt | self.llm | StrOutputParser()
            # Use original question for answer generation to keep tone, but context is from standalone
            response = chain.invoke({"context": context_text, "question": question})
            logger.info("Step 5: LLM generation successful.")
            return response, docs
        except Exception as e:
            logger.error(f"Error inside SatelliteRAG.query: {e}")
            # Bare raise preserves the traceback; tenacity handles retries.
            raise