Kirtan001 committed on
Commit
37be6ad
·
1 Parent(s): 3966fc8

Feat: Implement Conversational Memory (Contextual Rewriting)

Browse files
Files changed (2) hide show
  1. src/app.py +10 -1
  2. src/rag_engine.py +36 -3
src/app.py CHANGED
@@ -103,7 +103,16 @@ if prompt := st.chat_input("Ask about any satellite (e.g., 'What is Gaofen 1?').
103
  if health["warning"]:
104
  st.warning(f"⚠️ Low Memory Warning: Only {health['available_mb']:.0f}MB available. Query might be slow.")
105
 
106
- response, docs = engine.query(prompt)
 
 
 
 
 
 
 
 
 
107
 
108
  st.markdown(response)
109
 
 
103
  if health["warning"]:
104
  st.warning(f"⚠️ Low Memory Warning: Only {health['available_mb']:.0f}MB available. Query might be slow.")
105
 
106
+ # Construct Chat History
107
+ # We need pairs of (User, AI) from session_state.messages
108
+ # Exclude the current prompt: it has already been appended to messages but should not be part of the history for this turn
109
+ chat_history = []
110
+ msgs = st.session_state.messages[:-1]
111
+ for i in range(0, len(msgs) - 1, 2):
112
+ if msgs[i]["role"] == "user" and msgs[i+1]["role"] == "assistant":
113
+ chat_history.append((msgs[i]["content"], msgs[i+1]["content"]))
114
+
115
+ response, docs = engine.query(prompt, chat_history=chat_history)
116
 
117
  st.markdown(response)
118
 
src/rag_engine.py CHANGED
@@ -79,18 +79,50 @@ class SatelliteRAG:
79
  api_key=settings.GROQ_API_KEY
80
  )
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  @retry(
83
  stop=stop_after_attempt(3),
84
  wait=wait_exponential(multiplier=1, min=2, max=10),
85
  reraise=True
86
  )
87
- def query(self, question: str) -> Tuple[str, List[Document]]:
88
  """
89
  Query the RAG system.
90
  Retries up to 3 times on failure (e.g. API Rate Limits).
91
  """
 
 
 
92
  # Retrieval
93
- logger.info(f"Starting query process for: {question}")
94
  try:
95
  # Force GC to clear any previous large objects
96
  import gc
@@ -101,7 +133,7 @@ class SatelliteRAG:
101
  retriever = self.vector_store.as_retriever(search_kwargs={"k": 4})
102
 
103
  logger.info("Step 2: Invoking retriever (Embedding inference)...")
104
- docs = retriever.invoke(question)
105
  logger.info(f"Step 3: Retrieval successful. Found {len(docs)} chunks.")
106
 
107
  context_text = "\n\n".join([d.page_content for d in docs])
@@ -127,6 +159,7 @@ class SatelliteRAG:
127
  prompt = ChatPromptTemplate.from_template(template)
128
  chain = prompt | self.llm | StrOutputParser()
129
 
 
130
  response = chain.invoke({"context": context_text, "question": question})
131
  logger.info("Step 5: LLM generation successful.")
132
  return response, docs
 
79
  api_key=settings.GROQ_API_KEY
80
  )
81
 
82
+ def _rewrite_query(self, question: str, chat_history: List[Tuple[str, str]]) -> str:
83
+ """Rewrite question based on history to be standalone."""
84
+ if not chat_history:
85
+ return question
86
+
87
+ logger.info("Rewriting question with conversational context...")
88
+
89
+ template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
90
+
91
+ Chat History:
92
+ {history}
93
+
94
+ Follow Up Input: {question}
95
+ Standalone Question:"""
96
+
97
+ try:
98
+ prompt = ChatPromptTemplate.from_template(template)
99
+ chain = prompt | self.llm | StrOutputParser()
100
+
101
+ # Format history as a string
102
+ history_str = "\n".join([f"User: {q}\nAssistant: {a}" for q, a in chat_history])
103
+
104
+ standalone_question = chain.invoke({"history": history_str, "question": question})
105
+ logger.info(f"Rephrased '{question}' -> '{standalone_question}'")
106
+ return standalone_question
107
+ except Exception as e:
108
+ logger.error(f"Failed to rewrite question: {e}")
109
+ return question
110
+
111
  @retry(
112
  stop=stop_after_attempt(3),
113
  wait=wait_exponential(multiplier=1, min=2, max=10),
114
  reraise=True
115
  )
116
+ def query(self, question: str, chat_history: List[Tuple[str, str]] = []) -> Tuple[str, List[Document]]:
117
  """
118
  Query the RAG system.
119
  Retries up to 3 times on failure (e.g. API Rate Limits).
120
  """
121
+ # 0. Contextual Rewriting
122
+ standalone_question = self._rewrite_query(question, chat_history)
123
+
124
  # Retrieval
125
+ logger.info(f"Starting query process for: {standalone_question}")
126
  try:
127
  # Force GC to clear any previous large objects
128
  import gc
 
133
  retriever = self.vector_store.as_retriever(search_kwargs={"k": 4})
134
 
135
  logger.info("Step 2: Invoking retriever (Embedding inference)...")
136
+ docs = retriever.invoke(standalone_question)
137
  logger.info(f"Step 3: Retrieval successful. Found {len(docs)} chunks.")
138
 
139
  context_text = "\n\n".join([d.page_content for d in docs])
 
159
  prompt = ChatPromptTemplate.from_template(template)
160
  chain = prompt | self.llm | StrOutputParser()
161
 
162
+ # Use original question for answer generation to keep tone, but context is from standalone
163
  response = chain.invoke({"context": context_text, "question": question})
164
  logger.info("Step 5: LLM generation successful.")
165
  return response, docs