File size: 6,859 Bytes
ad06665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37be6ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad06665
 
 
 
 
37be6ad
ad06665
 
 
 
37be6ad
 
 
ad06665
37be6ad
ad06665
 
 
 
 
 
 
 
 
 
37be6ad
ad06665
 
 
 
 
 
 
808e67d
ad06665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37be6ad
ad06665
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import sys
import psutil
from typing import Tuple, List, Optional, Any
from loguru import logger
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

from src.config import settings

class SatelliteRAG:
    """Conversational RAG engine over satellite documents.

    Pipeline: HuggingFace embeddings (CPU) -> Chroma vector store for
    retrieval -> Groq-hosted LLM for query rewriting and answer generation.
    Designed to run on memory-constrained hosts (see the small retrieval
    ``k`` and the explicit GC call in :meth:`query`).
    """

    def __init__(self) -> None:
        """Initialize the RAG engine with embeddings, vector store, and LLM.

        Raises:
            ValueError: if GROQ_API_KEY is not configured.
            Exception: re-raised from embedding or vector-store setup failures.
        """
        self._log_memory_usage()

        # 1. Load Embeddings
        self.embeddings = self._load_embeddings()

        # 2. Initialize Vector Store (Chroma)
        self.vector_store = self._init_vector_store()

        # 3. Initialize LLM
        self.llm = self._init_llm()

        logger.info("RAG Engine successfully initialized.")

    def _log_memory_usage(self) -> None:
        """Log the current process RSS in MB (startup diagnostic for small hosts)."""
        process = psutil.Process(os.getpid())
        mem_mb = process.memory_info().rss / 1024 / 1024
        logger.info(f"RAG Engine initializing... Memory Usage: {mem_mb:.2f} MB")

    def _load_embeddings(self) -> HuggingFaceEmbeddings:
        """Load the HuggingFace embedding model configured in settings.

        Returns:
            A CPU-bound embedding model producing normalized vectors
            (normalization keeps similarity scores comparable in Chroma).

        Raises:
            Exception: re-raised if the model fails to load.
        """
        logger.info(f"Step 1/3: Loading HuggingFace Embeddings ({settings.EMBEDDING_MODEL})...")
        try:
            embeddings = HuggingFaceEmbeddings(
                model_name=settings.EMBEDDING_MODEL,
                model_kwargs={'device': 'cpu'},
                encode_kwargs={'normalize_embeddings': True}
            )
            logger.info("Embeddings loaded successfully.")
            return embeddings
        except Exception as e:
            logger.error(f"Failed to load embeddings: {e}")
            # Bare raise preserves the original traceback.
            raise

    def _init_vector_store(self) -> Chroma:
        """Open the persisted Chroma collection and sanity-check connectivity.

        Raises:
            Exception: re-raised if the store cannot be opened or counted.
        """
        logger.info(f"Step 2/3: Connecting to ChromaDB at {settings.CHROMA_PATH}...")
        try:
            vector_store = Chroma(
                collection_name=settings.COLLECTION_NAME,
                embedding_function=self.embeddings,
                persist_directory=str(settings.CHROMA_PATH)
            )
            # Basic check to see if we can access the collection.
            # Accessing ._collection is internal API, but it is an effective
            # cheap probe that the persisted collection is readable.
            count = vector_store._collection.count()
            logger.info(f"Vector Store ready. Contains {count} documents.")
            return vector_store
        except Exception as e:
            logger.error(f"Vector Store initialization failed: {e}")
            raise

    def _init_llm(self) -> ChatGroq:
        """Build the Groq chat model (temperature 0 for deterministic answers).

        Raises:
            ValueError: if no API key is available in settings.
        """
        if not settings.GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not found in environment variables.")

        return ChatGroq(
            temperature=0,
            model_name=settings.LLM_MODEL,
            api_key=settings.GROQ_API_KEY
        )

    def _rewrite_query(self, question: str, chat_history: List[Tuple[str, str]]) -> str:
        """Rewrite a follow-up question into a standalone one using chat history.

        Args:
            question: The latest user question.
            chat_history: Prior (user, assistant) turns; empty means no rewrite.

        Returns:
            The standalone question, or the original question when there is
            no history or the rewrite call fails (best-effort by design).
        """
        if not chat_history:
            return question

        logger.info("Rewriting question with conversational context...")

        template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
        
        Chat History:
        {history}
        
        Follow Up Input: {question}
        Standalone Question:"""

        try:
            prompt = ChatPromptTemplate.from_template(template)
            chain = prompt | self.llm | StrOutputParser()

            # Format history as a string
            history_str = "\n".join([f"User: {q}\nAssistant: {a}" for q, a in chat_history])

            standalone_question = chain.invoke({"history": history_str, "question": question})
            logger.info(f"Rephrased '{question}' -> '{standalone_question}'")
            return standalone_question
        except Exception as e:
            # Deliberate best-effort: fall back to the raw question rather
            # than failing the whole query over a rewrite error.
            logger.error(f"Failed to rewrite question: {e}")
            return question

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True
    )
    def query(
        self,
        question: str,
        chat_history: Optional[List[Tuple[str, str]]] = None
    ) -> Tuple[str, List[Document]]:
        """
        Query the RAG system.
        Retries up to 3 times on failure (e.g. API Rate Limits).

        Args:
            question: The user's question (used verbatim for generation).
            chat_history: Optional prior (user, assistant) turns used only
                for contextual query rewriting. Defaults to no history.
                (A ``None`` default avoids the shared-mutable-default pitfall
                of ``chat_history=[]``.)

        Returns:
            A tuple of (LLM answer text, retrieved source documents).

        Raises:
            Exception: re-raised after logging if retrieval or generation
                fails on all retry attempts.
        """
        if chat_history is None:
            chat_history = []

        # 0. Contextual Rewriting
        standalone_question = self._rewrite_query(question, chat_history)

        # Retrieval
        logger.info(f"Starting query process for: {standalone_question}")
        try:
            # Force GC to clear any previous large objects
            import gc
            gc.collect()

            logger.info("Step 1: Initializing retriever...")
            # Reduced k from 10 to 4 to prevent Memory OOM on free tier spaces
            retriever = self.vector_store.as_retriever(search_kwargs={"k": 4})

            logger.info("Step 2: Invoking retriever (Embedding inference)...")
            docs = retriever.invoke(standalone_question)
            logger.info(f"Step 3: Retrieval successful. Found {len(docs)} chunks.")

            context_text = "\n\n".join([d.page_content for d in docs])

            logger.info("Step 4: Constructing prompt and calling Groq LLM...")

            template = """
            You are a Space Satellite Assistant, an expert in technical satellite data.
            Use the following context to answer the user's question.
            
            Guidelines:
            1. **Be Precise:** If the context mentions specific numbers (Mass, Date, Orbit), use them.
            2. **Synonyms:** If asked for "Instruments", look for "Payload" or "Cameras".
            3. **Honesty:** If the answer is truly not in the context, say "I don't have that specific information."
            
            Context:
            {context}
            
            Question: {question}
            Answer:
            """

            prompt = ChatPromptTemplate.from_template(template)
            chain = prompt | self.llm | StrOutputParser()

            # Use original question for answer generation to keep tone, but context is from standalone
            response = chain.invoke({"context": context_text, "question": question})
            logger.info("Step 5: LLM generation successful.")
            return response, docs
        except Exception as e:
            logger.error(f"Error inside SatelliteRAG.query: {e}")
            raise