Subhadip007 commited on
Commit
5951bbe
·
1 Parent(s): 32cff2f

feat: Add full multi-turn conversation memory and context rewriting

Browse files
frontend-next/app/page.tsx CHANGED
@@ -430,6 +430,11 @@ export default function App() {
430
  setSidebarOpen(false);
431
  };
432
 
 
 
 
 
 
433
  const handleSend = async () => {
434
  if (!query.trim() || isStreaming) return;
435
 
@@ -473,12 +478,22 @@ export default function App() {
473
  setQuery("");
474
  setIsStreaming(true);
475
 
 
 
 
 
 
 
 
 
 
476
  try {
477
  const res = await fetch(`${API_URL}/query/stream`, {
478
  method: "POST",
479
  headers: { "Content-Type": "application/json" },
480
  body: JSON.stringify({
481
  question: originalQuery,
 
482
  top_k: topK,
483
  filter_category: category === "All" ? undefined : category,
484
  filter_year_gte: filterYear === "All" ? undefined : parseInt(filterYear, 10)
@@ -674,11 +689,15 @@ export default function App() {
674
  </motion.button>
675
  )}
676
  </AnimatePresence>
677
- {/* Header API Status */}
678
  <div className="top-api-status" style={{ display: 'flex', gap: '12px', alignItems: 'center' }}>
679
  <button onClick={() => setShowInfo(true)} className="nav-icon-btn" aria-label="Project Info" style={{ background: 'rgba(255,255,255,0.05)', border: '1px solid rgba(255,255,255,0.1)', padding: '6px', borderRadius: '50%', color: 'var(--text-muted)', cursor: 'pointer', display: 'flex', alignItems: 'center', justifyContent: 'center' }}>
680
  <Info size={16} />
681
  </button>
 
 
 
 
 
682
  <div className="nav-status">
683
  <div className={`status-dot ${apiStatus === 'online' ? 'status-online' : 'status-offline'}`} />
684
  {apiStatus === 'online' ? 'API Online' : apiStatus === 'connecting' ? 'Connecting...' : 'API Offline'}
@@ -782,6 +801,11 @@ export default function App() {
782
  <span className="model-badge" style={{ background: 'rgba(255,255,255,0.05)', padding: '2px 8px', borderRadius: '4px', border: '1px solid rgba(255,255,255,0.1)', color: 'var(--text-muted)', fontSize: '0.75rem' }}>
783
  {msg.model_used || "Auto-Detecting..."}
784
  </span>
 
 
 
 
 
785
  </div>
786
 
787
  <>
 
430
  setSidebarOpen(false);
431
  };
432
 
433
+ const handleClearConversation = () => {
434
+ if (!activeSessionId) return;
435
+ setSessions(prev => prev.map(s => s.id === activeSessionId ? { ...s, messages: [] } : s));
436
+ };
437
+
438
  const handleSend = async () => {
439
  if (!query.trim() || isStreaming) return;
440
 
 
478
  setQuery("");
479
  setIsStreaming(true);
480
 
481
+ const history = currentMessages
482
+ .filter(m => m.role === "user" || m.role === "assistant")
483
+ .map(m => ({
484
+ role: m.role,
485
+ content: m.content,
486
+ citations: m.citations || []
487
+ }))
488
+ .slice(-20);
489
+
490
  try {
491
  const res = await fetch(`${API_URL}/query/stream`, {
492
  method: "POST",
493
  headers: { "Content-Type": "application/json" },
494
  body: JSON.stringify({
495
  question: originalQuery,
496
+ history: history,
497
  top_k: topK,
498
  filter_category: category === "All" ? undefined : category,
499
  filter_year_gte: filterYear === "All" ? undefined : parseInt(filterYear, 10)
 
689
  </motion.button>
690
  )}
691
  </AnimatePresence>
 
692
  <div className="top-api-status" style={{ display: 'flex', gap: '12px', alignItems: 'center' }}>
693
  <button onClick={() => setShowInfo(true)} className="nav-icon-btn" aria-label="Project Info" style={{ background: 'rgba(255,255,255,0.05)', border: '1px solid rgba(255,255,255,0.1)', padding: '6px', borderRadius: '50%', color: 'var(--text-muted)', cursor: 'pointer', display: 'flex', alignItems: 'center', justifyContent: 'center' }}>
694
  <Info size={16} />
695
  </button>
696
+ {activeSessionId && currentMessages.length > 0 && (
697
+ <button onClick={handleClearConversation} className="nav-icon-btn" aria-label="Clear Conversation" title="Clear current conversation context" style={{ background: 'rgba(255,255,255,0.05)', border: '1px solid rgba(255,255,255,0.1)', padding: '6px 12px', borderRadius: '16px', color: 'var(--text-muted)', cursor: 'pointer', display: 'flex', alignItems: 'center', justifyContent: 'center', gap: '6px', fontSize: '0.8rem' }}>
698
+ <Trash2 size={14} /> Clear context
699
+ </button>
700
+ )}
701
  <div className="nav-status">
702
  <div className={`status-dot ${apiStatus === 'online' ? 'status-online' : 'status-offline'}`} />
703
  {apiStatus === 'online' ? 'API Online' : apiStatus === 'connecting' ? 'Connecting...' : 'API Offline'}
 
801
  <span className="model-badge" style={{ background: 'rgba(255,255,255,0.05)', padding: '2px 8px', borderRadius: '4px', border: '1px solid rgba(255,255,255,0.1)', color: 'var(--text-muted)', fontSize: '0.75rem' }}>
802
  {msg.model_used || "Auto-Detecting..."}
803
  </span>
804
+ {i >= 2 && (
805
+ <span style={{ fontSize: '0.7rem', background: 'rgba(138, 43, 226, 0.15)', border: '1px solid rgba(138, 43, 226, 0.3)', padding: '2px 8px', borderRadius: '4px', color: 'var(--accent-2)', marginLeft: 'auto', display: 'flex', alignItems: 'center', gap: '4px' }}>
806
+ <Layers size={10} /> Using conversation context
807
+ </span>
808
+ )}
809
  </div>
810
 
811
  <>
src/api/main.py CHANGED
@@ -47,7 +47,7 @@ class FeedbackRequest(BaseModel):
47
  model_used: str
48
  citations_count: int
49
  total_time_ms: float
50
- from src.rag.pipeline import RAGPipeline
51
  from src.utils.logger import setup_logger, get_logger
52
 
53
 
@@ -187,6 +187,7 @@ async def stream_query_papers(
187
  try:
188
  for chunk in pipeline.stream_query(
189
  question = query_input.question,
 
190
  top_k = query_input.top_k,
191
  filter_category = query_input.filter_category,
192
  filter_year_gte = query_input.filter_year_gte,
@@ -265,6 +266,7 @@ async def query_papers(
265
  response = await asyncio.to_thread(
266
  pipeline.query,
267
  query_input.question,
 
268
  query_input.top_k,
269
  query_input.filter_category,
270
  query_input.filter_year_gte,
 
47
  model_used: str
48
  citations_count: int
49
  total_time_ms: float
50
+ from src.rag.pipeline import RAGPipeline, ConversationTurn
51
  from src.utils.logger import setup_logger, get_logger
52
 
53
 
 
187
  try:
188
  for chunk in pipeline.stream_query(
189
  question = query_input.question,
190
+ history = [ConversationTurn(role=t.role, content=t.content, citations=t.citations) for t in query_input.history],
191
  top_k = query_input.top_k,
192
  filter_category = query_input.filter_category,
193
  filter_year_gte = query_input.filter_year_gte,
 
266
  response = await asyncio.to_thread(
267
  pipeline.query,
268
  query_input.question,
269
+ [ConversationTurn(role=t.role, content=t.content, citations=t.citations) for t in query_input.history],
270
  query_input.top_k,
271
  query_input.filter_category,
272
  query_input.filter_year_gte,
src/api/schemas.py CHANGED
@@ -11,7 +11,13 @@ WHY PYDANTIC SCHEMAS IN THE API LAYER:
11
  """
12
 
13
  from pydantic import BaseModel, Field
14
- from typing import Optional
 
 
 
 
 
 
15
 
16
 
17
 
@@ -28,6 +34,10 @@ class QueryRequest(BaseModel):
28
  description = "Research question to answer",
29
  examples = ["How does LoRA reduce trainable parameters?"]
30
  )
 
 
 
 
31
  top_k: int = Field(
32
  default = 5,
33
  ge = 1, # ge = greater than or equal
 
11
  """
12
 
13
  from pydantic import BaseModel, Field
14
+ from typing import Optional, List
15
+
16
+
17
+ class ConversationTurnSchema(BaseModel):
18
+ role: str
19
+ content: str
20
+ citations: list = []
21
 
22
 
23
 
 
34
  description = "Research question to answer",
35
  examples = ["How does LoRA reduce trainable parameters?"]
36
  )
37
+ history: List[ConversationTurnSchema] = Field(
38
+ default=[],
39
+ description="Conversation history for context"
40
+ )
41
  top_k: int = Field(
42
  default = 5,
43
  ge = 1, # ge = greater than or equal
src/rag/llm_client.py CHANGED
@@ -113,6 +113,7 @@ class MultiModelClient:
113
  system_prompt: str,
114
  user_prompt: str,
115
  original_query: str = "",
 
116
  temperature: float = LLM_TEMPERATURE,
117
  max_tokens: int = LLM_MAX_TOKENS,
118
  stream: bool = False
@@ -124,10 +125,10 @@ class MultiModelClient:
124
  Otherwise, result is a string.
125
  """
126
  models_to_try = self.get_model_for_query(original_query)
127
- messages = [
128
- {"role": "system", "content": system_prompt},
129
- {"role": "user", "content": user_prompt}
130
- ]
131
 
132
  for model in models_to_try:
133
  try:
 
113
  system_prompt: str,
114
  user_prompt: str,
115
  original_query: str = "",
116
+ history: list = None,
117
  temperature: float = LLM_TEMPERATURE,
118
  max_tokens: int = LLM_MAX_TOKENS,
119
  stream: bool = False
 
125
  Otherwise, result is a string.
126
  """
127
  models_to_try = self.get_model_for_query(original_query)
128
+ messages = [{"role": "system", "content": system_prompt}]
129
+ if history:
130
+ messages.extend(history)
131
+ messages.append({"role": "user", "content": user_prompt})
132
 
133
  for model in models_to_try:
134
  try:
src/rag/pipeline.py CHANGED
@@ -15,7 +15,13 @@ PIPELINE FLOW:
15
  import time
16
  import json
17
  from dataclasses import dataclass, field
18
- from typing import Optional
 
 
 
 
 
 
19
 
20
  from src.retrieval.retrieval_pipeline import RetrievalPipeline
21
  from src.rag.llm_client import MultiModelClient
@@ -63,22 +69,53 @@ class RAGPipeline:
63
  self.llm = MultiModelClient()
64
  logger.info("RAGPipeline ready")
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def query(
67
  self,
68
  question: str,
 
69
  top_k: int = TOP_K_RERANK,
70
  filter_category: Optional[str] = None,
71
  filter_year_gte: Optional[int] = None,
72
  ) -> RAGResponse:
73
  question = question.strip()
 
74
  if not question:
75
  raise ValueError("Question cannot be empty")
76
 
77
  total_start = time.time()
78
  retrieval_start = time.time()
79
 
 
 
80
  chunks = self.retriever.retrieve(
81
- query = question,
82
  top_k_final = top_k,
83
  filter_category = filter_category,
84
  filter_year_gte = filter_year_gte,
@@ -96,11 +133,20 @@ class RAGPipeline:
96
  f"or broadening their query."
97
  )
98
 
 
 
 
 
 
 
 
 
99
  generation_start = time.time()
100
  answer, model_used = self.llm.generate(
101
  system_prompt = SYSTEM_PROMPT,
102
  user_prompt = user_prompt,
103
  original_query = question,
 
104
  stream=False
105
  )
106
 
@@ -123,18 +169,23 @@ class RAGPipeline:
123
  def stream_query(
124
  self,
125
  question: str,
 
126
  top_k: int = TOP_K_RERANK,
127
  filter_category: Optional[str] = None,
128
  filter_year_gte: Optional[int] = None,
129
  ):
130
  question = question.strip()
 
131
  if not question:
132
  raise ValueError("Question cannot be empty")
133
 
134
  total_start = time.time()
135
  retrieval_start = time.time()
 
 
 
136
  chunks = self.retriever.retrieve(
137
- query = question,
138
  top_k_final = top_k,
139
  filter_category = filter_category,
140
  filter_year_gte = filter_year_gte,
@@ -152,11 +203,20 @@ class RAGPipeline:
152
  f"or broadening their query."
153
  )
154
 
 
 
 
 
 
 
 
 
155
  generation_start = time.time()
156
  generator, model_used = self.llm.generate(
157
  system_prompt = SYSTEM_PROMPT,
158
  user_prompt = user_prompt,
159
  original_query = question,
 
160
  stream=True
161
  )
162
 
 
15
  import time
16
  import json
17
  from dataclasses import dataclass, field
18
+ from typing import Optional, List
19
+
20
+ @dataclass
21
+ class ConversationTurn:
22
+ role: str
23
+ content: str
24
+ citations: list = field(default_factory=list)
25
 
26
  from src.retrieval.retrieval_pipeline import RetrievalPipeline
27
  from src.rag.llm_client import MultiModelClient
 
69
  self.llm = MultiModelClient()
70
  logger.info("RAGPipeline ready")
71
 
72
+ def _build_retrieval_query(
73
+ self,
74
+ question: str,
75
+ history: list[ConversationTurn]
76
+ ) -> str:
77
+ followup_signals = [
78
+ "it", "that", "this", "they", "them",
79
+ "more", "example", "explain", "clarify",
80
+ "simpler", "detail", "elaborate", "again"
81
+ ]
82
+ question_lower = question.lower()
83
+ is_followup = (
84
+ len(question.split()) < 12 and
85
+ any(word in question_lower for word in followup_signals)
86
+ )
87
+
88
+ if is_followup and history:
89
+ last_substantial = ""
90
+ for turn in reversed(history):
91
+ if turn.role == "user" and len(turn.content.split()) > 5:
92
+ last_substantial = turn.content
93
+ break
94
+ if last_substantial:
95
+ return f"{last_substantial} {question}"
96
+
97
+ return question
98
+
99
  def query(
100
  self,
101
  question: str,
102
+ history: list[ConversationTurn] = None,
103
  top_k: int = TOP_K_RERANK,
104
  filter_category: Optional[str] = None,
105
  filter_year_gte: Optional[int] = None,
106
  ) -> RAGResponse:
107
  question = question.strip()
108
+ history = history or []
109
  if not question:
110
  raise ValueError("Question cannot be empty")
111
 
112
  total_start = time.time()
113
  retrieval_start = time.time()
114
 
115
+ retrieval_query = self._build_retrieval_query(question, history)
116
+
117
  chunks = self.retriever.retrieve(
118
+ query = retrieval_query,
119
  top_k_final = top_k,
120
  filter_category = filter_category,
121
  filter_year_gte = filter_year_gte,
 
133
  f"or broadening their query."
134
  )
135
 
136
+ history_messages = []
137
+ if history:
138
+ for turn in history[-10:]:
139
+ history_messages.append({
140
+ "role": turn.role,
141
+ "content": turn.content
142
+ })
143
+
144
  generation_start = time.time()
145
  answer, model_used = self.llm.generate(
146
  system_prompt = SYSTEM_PROMPT,
147
  user_prompt = user_prompt,
148
  original_query = question,
149
+ history = history_messages,
150
  stream=False
151
  )
152
 
 
169
  def stream_query(
170
  self,
171
  question: str,
172
+ history: list[ConversationTurn] = None,
173
  top_k: int = TOP_K_RERANK,
174
  filter_category: Optional[str] = None,
175
  filter_year_gte: Optional[int] = None,
176
  ):
177
  question = question.strip()
178
+ history = history or []
179
  if not question:
180
  raise ValueError("Question cannot be empty")
181
 
182
  total_start = time.time()
183
  retrieval_start = time.time()
184
+
185
+ retrieval_query = self._build_retrieval_query(question, history)
186
+
187
  chunks = self.retriever.retrieve(
188
+ query = retrieval_query,
189
  top_k_final = top_k,
190
  filter_category = filter_category,
191
  filter_year_gte = filter_year_gte,
 
203
  f"or broadening their query."
204
  )
205
 
206
+ history_messages = []
207
+ if history:
208
+ for turn in history[-10:]:
209
+ history_messages.append({
210
+ "role": turn.role,
211
+ "content": turn.content
212
+ })
213
+
214
  generation_start = time.time()
215
  generator, model_used = self.llm.generate(
216
  system_prompt = SYSTEM_PROMPT,
217
  user_prompt = user_prompt,
218
  original_query = question,
219
+ history = history_messages,
220
  stream=True
221
  )
222
 
src/rag/prompt_templates.py CHANGED
@@ -36,6 +36,8 @@ FORMATTING RULES:
36
  7. Use markdown formatting: **bold** for key terms, numbered lists for steps
37
  8. For algorithm explanations, structure as: Intuition -> Math -> Steps
38
  9. Write comprehensive, detailed answers — do not truncate explanations
 
 
39
  """
40
 
41
 
 
36
  7. Use markdown formatting: **bold** for key terms, numbered lists for steps
37
  8. For algorithm explanations, structure as: Intuition -> Math -> Steps
38
  9. Write comprehensive, detailed answers — do not truncate explanations
39
+
40
+ You have access to the conversation history above. Use it to understand follow-up questions, resolve pronouns (like 'it', 'that', 'this method'), and give answers that build on what was already discussed. If the user asks for clarification or says 'explain more' or 'give an example', refer back to your previous answer.
41
  """
42
 
43