Feat: Implement Conversational Memory (Contextual Rewriting)
- src/app.py (+10 -1)
- src/rag_engine.py +36 -3
src/app.py
CHANGED
|
@@ -103,7 +103,16 @@ if prompt := st.chat_input("Ask about any satellite (e.g., 'What is Gaofen 1?').
|
|
| 103 |
if health["warning"]:
|
| 104 |
st.warning(f"⚠️ Low Memory Warning: Only {health['available_mb']:.0f}MB available. Query might be slow.")
|
| 105 |
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
st.markdown(response)
|
| 109 |
|
|
|
|
| 103 |
if health["warning"]:
    st.warning(f"⚠️ Low Memory Warning: Only {health['available_mb']:.0f}MB available. Query might be slow.")

# Build (user, assistant) pairs from the conversation so far.
# st.session_state.messages already contains the current prompt as its
# last entry, so it is excluded from the history used for rewriting.
chat_history = []
msgs = st.session_state.messages[:-1]
i = 0
while i < len(msgs) - 1:
    # Pair a user turn with the assistant turn that immediately follows it.
    # Unlike a fixed step of 2, this resynchronises after any misaligned
    # entry (e.g. two consecutive user messages after an earlier error),
    # so later valid pairs are never silently dropped.
    if msgs[i]["role"] == "user" and msgs[i + 1]["role"] == "assistant":
        chat_history.append((msgs[i]["content"], msgs[i + 1]["content"]))
        i += 2
    else:
        i += 1

response, docs = engine.query(prompt, chat_history=chat_history)

st.markdown(response)
|
| 118 |
|
src/rag_engine.py
CHANGED
|
@@ -79,18 +79,50 @@ class SatelliteRAG:
|
|
| 79 |
api_key=settings.GROQ_API_KEY
|
| 80 |
)
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
@retry(
|
| 83 |
stop=stop_after_attempt(3),
|
| 84 |
wait=wait_exponential(multiplier=1, min=2, max=10),
|
| 85 |
reraise=True
|
| 86 |
)
|
| 87 |
-
def query(self, question: str) -> Tuple[str, List[Document]]:
|
| 88 |
"""
|
| 89 |
Query the RAG system.
|
| 90 |
Retries up to 3 times on failure (e.g. API Rate Limits).
|
| 91 |
"""
|
|
|
|
|
|
|
|
|
|
| 92 |
# Retrieval
|
| 93 |
-
logger.info(f"Starting query process for: {
|
| 94 |
try:
|
| 95 |
# Force GC to clear any previous large objects
|
| 96 |
import gc
|
|
@@ -101,7 +133,7 @@ class SatelliteRAG:
|
|
| 101 |
retriever = self.vector_store.as_retriever(search_kwargs={"k": 4})
|
| 102 |
|
| 103 |
logger.info("Step 2: Invoking retriever (Embedding inference)...")
|
| 104 |
-
docs = retriever.invoke(
|
| 105 |
logger.info(f"Step 3: Retrieval successful. Found {len(docs)} chunks.")
|
| 106 |
|
| 107 |
context_text = "\n\n".join([d.page_content for d in docs])
|
|
@@ -127,6 +159,7 @@ class SatelliteRAG:
|
|
| 127 |
prompt = ChatPromptTemplate.from_template(template)
|
| 128 |
chain = prompt | self.llm | StrOutputParser()
|
| 129 |
|
|
|
|
| 130 |
response = chain.invoke({"context": context_text, "question": question})
|
| 131 |
logger.info("Step 5: LLM generation successful.")
|
| 132 |
return response, docs
|
|
|
|
| 79 |
api_key=settings.GROQ_API_KEY
|
| 80 |
)
|
| 81 |
|
| 82 |
+
def _rewrite_query(self, question: str, chat_history: List[Tuple[str, str]]) -> str:
|
| 83 |
+
"""Rewrite question based on history to be standalone."""
|
| 84 |
+
if not chat_history:
|
| 85 |
+
return question
|
| 86 |
+
|
| 87 |
+
logger.info("Rewriting question with conversational context...")
|
| 88 |
+
|
| 89 |
+
template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
|
| 90 |
+
|
| 91 |
+
Chat History:
|
| 92 |
+
{history}
|
| 93 |
+
|
| 94 |
+
Follow Up Input: {question}
|
| 95 |
+
Standalone Question:"""
|
| 96 |
+
|
| 97 |
+
try:
|
| 98 |
+
prompt = ChatPromptTemplate.from_template(template)
|
| 99 |
+
chain = prompt | self.llm | StrOutputParser()
|
| 100 |
+
|
| 101 |
+
# Format history as a string
|
| 102 |
+
history_str = "\n".join([f"User: {q}\nAssistant: {a}" for q, a in chat_history])
|
| 103 |
+
|
| 104 |
+
standalone_question = chain.invoke({"history": history_str, "question": question})
|
| 105 |
+
logger.info(f"Rephrased '{question}' -> '{standalone_question}'")
|
| 106 |
+
return standalone_question
|
| 107 |
+
except Exception as e:
|
| 108 |
+
logger.error(f"Failed to rewrite question: {e}")
|
| 109 |
+
return question
|
| 110 |
+
|
| 111 |
@retry(
|
| 112 |
stop=stop_after_attempt(3),
|
| 113 |
wait=wait_exponential(multiplier=1, min=2, max=10),
|
| 114 |
reraise=True
|
| 115 |
)
|
| 116 |
+
def query(self, question: str, chat_history: List[Tuple[str, str]] = []) -> Tuple[str, List[Document]]:
|
| 117 |
"""
|
| 118 |
Query the RAG system.
|
| 119 |
Retries up to 3 times on failure (e.g. API Rate Limits).
|
| 120 |
"""
|
| 121 |
+
# 0. Contextual Rewriting
|
| 122 |
+
standalone_question = self._rewrite_query(question, chat_history)
|
| 123 |
+
|
| 124 |
# Retrieval
|
| 125 |
+
logger.info(f"Starting query process for: {standalone_question}")
|
| 126 |
try:
|
| 127 |
# Force GC to clear any previous large objects
|
| 128 |
import gc
|
|
|
|
| 133 |
retriever = self.vector_store.as_retriever(search_kwargs={"k": 4})
|
| 134 |
|
| 135 |
logger.info("Step 2: Invoking retriever (Embedding inference)...")
|
| 136 |
+
docs = retriever.invoke(standalone_question)
|
| 137 |
logger.info(f"Step 3: Retrieval successful. Found {len(docs)} chunks.")
|
| 138 |
|
| 139 |
context_text = "\n\n".join([d.page_content for d in docs])
|
|
|
|
| 159 |
prompt = ChatPromptTemplate.from_template(template)
|
| 160 |
chain = prompt | self.llm | StrOutputParser()
|
| 161 |
|
| 162 |
+
# Use original question for answer generation to keep tone, but context is from standalone
|
| 163 |
response = chain.invoke({"context": context_text, "question": question})
|
| 164 |
logger.info("Step 5: LLM generation successful.")
|
| 165 |
return response, docs
|