Spaces:
Running
Running
Commit ·
5951bbe
1
Parent(s): 32cff2f
feat: Add full multi-turn conversation memory and context rewriting
Browse files- frontend-next/app/page.tsx +25 -1
- src/api/main.py +3 -1
- src/api/schemas.py +11 -1
- src/rag/llm_client.py +5 -4
- src/rag/pipeline.py +63 -3
- src/rag/prompt_templates.py +2 -0
frontend-next/app/page.tsx
CHANGED
|
@@ -430,6 +430,11 @@ export default function App() {
|
|
| 430 |
setSidebarOpen(false);
|
| 431 |
};
|
| 432 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
const handleSend = async () => {
|
| 434 |
if (!query.trim() || isStreaming) return;
|
| 435 |
|
|
@@ -473,12 +478,22 @@ export default function App() {
|
|
| 473 |
setQuery("");
|
| 474 |
setIsStreaming(true);
|
| 475 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
try {
|
| 477 |
const res = await fetch(`${API_URL}/query/stream`, {
|
| 478 |
method: "POST",
|
| 479 |
headers: { "Content-Type": "application/json" },
|
| 480 |
body: JSON.stringify({
|
| 481 |
question: originalQuery,
|
|
|
|
| 482 |
top_k: topK,
|
| 483 |
filter_category: category === "All" ? undefined : category,
|
| 484 |
filter_year_gte: filterYear === "All" ? undefined : parseInt(filterYear, 10)
|
|
@@ -674,11 +689,15 @@ export default function App() {
|
|
| 674 |
</motion.button>
|
| 675 |
)}
|
| 676 |
</AnimatePresence>
|
| 677 |
-
{/* Header API Status */}
|
| 678 |
<div className="top-api-status" style={{ display: 'flex', gap: '12px', alignItems: 'center' }}>
|
| 679 |
<button onClick={() => setShowInfo(true)} className="nav-icon-btn" aria-label="Project Info" style={{ background: 'rgba(255,255,255,0.05)', border: '1px solid rgba(255,255,255,0.1)', padding: '6px', borderRadius: '50%', color: 'var(--text-muted)', cursor: 'pointer', display: 'flex', alignItems: 'center', justifyContent: 'center' }}>
|
| 680 |
<Info size={16} />
|
| 681 |
</button>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 682 |
<div className="nav-status">
|
| 683 |
<div className={`status-dot ${apiStatus === 'online' ? 'status-online' : 'status-offline'}`} />
|
| 684 |
{apiStatus === 'online' ? 'API Online' : apiStatus === 'connecting' ? 'Connecting...' : 'API Offline'}
|
|
@@ -782,6 +801,11 @@ export default function App() {
|
|
| 782 |
<span className="model-badge" style={{ background: 'rgba(255,255,255,0.05)', padding: '2px 8px', borderRadius: '4px', border: '1px solid rgba(255,255,255,0.1)', color: 'var(--text-muted)', fontSize: '0.75rem' }}>
|
| 783 |
{msg.model_used || "Auto-Detecting..."}
|
| 784 |
</span>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 785 |
</div>
|
| 786 |
|
| 787 |
<>
|
|
|
|
| 430 |
setSidebarOpen(false);
|
| 431 |
};
|
| 432 |
|
| 433 |
+
const handleClearConversation = () => {
|
| 434 |
+
if (!activeSessionId) return;
|
| 435 |
+
setSessions(prev => prev.map(s => s.id === activeSessionId ? { ...s, messages: [] } : s));
|
| 436 |
+
};
|
| 437 |
+
|
| 438 |
const handleSend = async () => {
|
| 439 |
if (!query.trim() || isStreaming) return;
|
| 440 |
|
|
|
|
| 478 |
setQuery("");
|
| 479 |
setIsStreaming(true);
|
| 480 |
|
| 481 |
+
const history = currentMessages
|
| 482 |
+
.filter(m => m.role === "user" || m.role === "assistant")
|
| 483 |
+
.map(m => ({
|
| 484 |
+
role: m.role,
|
| 485 |
+
content: m.content,
|
| 486 |
+
citations: m.citations || []
|
| 487 |
+
}))
|
| 488 |
+
.slice(-20);
|
| 489 |
+
|
| 490 |
try {
|
| 491 |
const res = await fetch(`${API_URL}/query/stream`, {
|
| 492 |
method: "POST",
|
| 493 |
headers: { "Content-Type": "application/json" },
|
| 494 |
body: JSON.stringify({
|
| 495 |
question: originalQuery,
|
| 496 |
+
history: history,
|
| 497 |
top_k: topK,
|
| 498 |
filter_category: category === "All" ? undefined : category,
|
| 499 |
filter_year_gte: filterYear === "All" ? undefined : parseInt(filterYear, 10)
|
|
|
|
| 689 |
</motion.button>
|
| 690 |
)}
|
| 691 |
</AnimatePresence>
|
|
|
|
| 692 |
<div className="top-api-status" style={{ display: 'flex', gap: '12px', alignItems: 'center' }}>
|
| 693 |
<button onClick={() => setShowInfo(true)} className="nav-icon-btn" aria-label="Project Info" style={{ background: 'rgba(255,255,255,0.05)', border: '1px solid rgba(255,255,255,0.1)', padding: '6px', borderRadius: '50%', color: 'var(--text-muted)', cursor: 'pointer', display: 'flex', alignItems: 'center', justifyContent: 'center' }}>
|
| 694 |
<Info size={16} />
|
| 695 |
</button>
|
| 696 |
+
{activeSessionId && currentMessages.length > 0 && (
|
| 697 |
+
<button onClick={handleClearConversation} className="nav-icon-btn" aria-label="Clear Conversation" title="Clear current conversation context" style={{ background: 'rgba(255,255,255,0.05)', border: '1px solid rgba(255,255,255,0.1)', padding: '6px 12px', borderRadius: '16px', color: 'var(--text-muted)', cursor: 'pointer', display: 'flex', alignItems: 'center', justifyContent: 'center', gap: '6px', fontSize: '0.8rem' }}>
|
| 698 |
+
<Trash2 size={14} /> Clear context
|
| 699 |
+
</button>
|
| 700 |
+
)}
|
| 701 |
<div className="nav-status">
|
| 702 |
<div className={`status-dot ${apiStatus === 'online' ? 'status-online' : 'status-offline'}`} />
|
| 703 |
{apiStatus === 'online' ? 'API Online' : apiStatus === 'connecting' ? 'Connecting...' : 'API Offline'}
|
|
|
|
| 801 |
<span className="model-badge" style={{ background: 'rgba(255,255,255,0.05)', padding: '2px 8px', borderRadius: '4px', border: '1px solid rgba(255,255,255,0.1)', color: 'var(--text-muted)', fontSize: '0.75rem' }}>
|
| 802 |
{msg.model_used || "Auto-Detecting..."}
|
| 803 |
</span>
|
| 804 |
+
{i >= 2 && (
|
| 805 |
+
<span style={{ fontSize: '0.7rem', background: 'rgba(138, 43, 226, 0.15)', border: '1px solid rgba(138, 43, 226, 0.3)', padding: '2px 8px', borderRadius: '4px', color: 'var(--accent-2)', marginLeft: 'auto', display: 'flex', alignItems: 'center', gap: '4px' }}>
|
| 806 |
+
<Layers size={10} /> Using conversation context
|
| 807 |
+
</span>
|
| 808 |
+
)}
|
| 809 |
</div>
|
| 810 |
|
| 811 |
<>
|
src/api/main.py
CHANGED
|
@@ -47,7 +47,7 @@ class FeedbackRequest(BaseModel):
|
|
| 47 |
model_used: str
|
| 48 |
citations_count: int
|
| 49 |
total_time_ms: float
|
| 50 |
-
from src.rag.pipeline import RAGPipeline
|
| 51 |
from src.utils.logger import setup_logger, get_logger
|
| 52 |
|
| 53 |
|
|
@@ -187,6 +187,7 @@ async def stream_query_papers(
|
|
| 187 |
try:
|
| 188 |
for chunk in pipeline.stream_query(
|
| 189 |
question = query_input.question,
|
|
|
|
| 190 |
top_k = query_input.top_k,
|
| 191 |
filter_category = query_input.filter_category,
|
| 192 |
filter_year_gte = query_input.filter_year_gte,
|
|
@@ -265,6 +266,7 @@ async def query_papers(
|
|
| 265 |
response = await asyncio.to_thread(
|
| 266 |
pipeline.query,
|
| 267 |
query_input.question,
|
|
|
|
| 268 |
query_input.top_k,
|
| 269 |
query_input.filter_category,
|
| 270 |
query_input.filter_year_gte,
|
|
|
|
| 47 |
model_used: str
|
| 48 |
citations_count: int
|
| 49 |
total_time_ms: float
|
| 50 |
+
from src.rag.pipeline import RAGPipeline, ConversationTurn
|
| 51 |
from src.utils.logger import setup_logger, get_logger
|
| 52 |
|
| 53 |
|
|
|
|
| 187 |
try:
|
| 188 |
for chunk in pipeline.stream_query(
|
| 189 |
question = query_input.question,
|
| 190 |
+
history = [ConversationTurn(role=t.role, content=t.content, citations=t.citations) for t in query_input.history],
|
| 191 |
top_k = query_input.top_k,
|
| 192 |
filter_category = query_input.filter_category,
|
| 193 |
filter_year_gte = query_input.filter_year_gte,
|
|
|
|
| 266 |
response = await asyncio.to_thread(
|
| 267 |
pipeline.query,
|
| 268 |
query_input.question,
|
| 269 |
+
[ConversationTurn(role=t.role, content=t.content, citations=t.citations) for t in query_input.history],
|
| 270 |
query_input.top_k,
|
| 271 |
query_input.filter_category,
|
| 272 |
query_input.filter_year_gte,
|
src/api/schemas.py
CHANGED
|
@@ -11,7 +11,13 @@ WHY PYDANTIC SCHEMAS IN THE API LAYER:
|
|
| 11 |
"""
|
| 12 |
|
| 13 |
from pydantic import BaseModel, Field
|
| 14 |
-
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
|
|
@@ -28,6 +34,10 @@ class QueryRequest(BaseModel):
|
|
| 28 |
description = "Research question to answer",
|
| 29 |
examples = ["How does LoRA reduce trainable parameters?"]
|
| 30 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
top_k: int = Field(
|
| 32 |
default = 5,
|
| 33 |
ge = 1, # ge = greater than or equal
|
|
|
|
| 11 |
"""
|
| 12 |
|
| 13 |
from pydantic import BaseModel, Field
|
| 14 |
+
from typing import Optional, List
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ConversationTurnSchema(BaseModel):
|
| 18 |
+
role: str
|
| 19 |
+
content: str
|
| 20 |
+
citations: list = []
|
| 21 |
|
| 22 |
|
| 23 |
|
|
|
|
| 34 |
description = "Research question to answer",
|
| 35 |
examples = ["How does LoRA reduce trainable parameters?"]
|
| 36 |
)
|
| 37 |
+
history: List[ConversationTurnSchema] = Field(
|
| 38 |
+
default=[],
|
| 39 |
+
description="Conversation history for context"
|
| 40 |
+
)
|
| 41 |
top_k: int = Field(
|
| 42 |
default = 5,
|
| 43 |
ge = 1, # ge = greater than or equal
|
src/rag/llm_client.py
CHANGED
|
@@ -113,6 +113,7 @@ class MultiModelClient:
|
|
| 113 |
system_prompt: str,
|
| 114 |
user_prompt: str,
|
| 115 |
original_query: str = "",
|
|
|
|
| 116 |
temperature: float = LLM_TEMPERATURE,
|
| 117 |
max_tokens: int = LLM_MAX_TOKENS,
|
| 118 |
stream: bool = False
|
|
@@ -124,10 +125,10 @@ class MultiModelClient:
|
|
| 124 |
Otherwise, result is a string.
|
| 125 |
"""
|
| 126 |
models_to_try = self.get_model_for_query(original_query)
|
| 127 |
-
messages = [
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
|
| 132 |
for model in models_to_try:
|
| 133 |
try:
|
|
|
|
| 113 |
system_prompt: str,
|
| 114 |
user_prompt: str,
|
| 115 |
original_query: str = "",
|
| 116 |
+
history: list = None,
|
| 117 |
temperature: float = LLM_TEMPERATURE,
|
| 118 |
max_tokens: int = LLM_MAX_TOKENS,
|
| 119 |
stream: bool = False
|
|
|
|
| 125 |
Otherwise, result is a string.
|
| 126 |
"""
|
| 127 |
models_to_try = self.get_model_for_query(original_query)
|
| 128 |
+
messages = [{"role": "system", "content": system_prompt}]
|
| 129 |
+
if history:
|
| 130 |
+
messages.extend(history)
|
| 131 |
+
messages.append({"role": "user", "content": user_prompt})
|
| 132 |
|
| 133 |
for model in models_to_try:
|
| 134 |
try:
|
src/rag/pipeline.py
CHANGED
|
@@ -15,7 +15,13 @@ PIPELINE FLOW:
|
|
| 15 |
import time
|
| 16 |
import json
|
| 17 |
from dataclasses import dataclass, field
|
| 18 |
-
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
from src.retrieval.retrieval_pipeline import RetrievalPipeline
|
| 21 |
from src.rag.llm_client import MultiModelClient
|
|
@@ -63,22 +69,53 @@ class RAGPipeline:
|
|
| 63 |
self.llm = MultiModelClient()
|
| 64 |
logger.info("RAGPipeline ready")
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
def query(
|
| 67 |
self,
|
| 68 |
question: str,
|
|
|
|
| 69 |
top_k: int = TOP_K_RERANK,
|
| 70 |
filter_category: Optional[str] = None,
|
| 71 |
filter_year_gte: Optional[int] = None,
|
| 72 |
) -> RAGResponse:
|
| 73 |
question = question.strip()
|
|
|
|
| 74 |
if not question:
|
| 75 |
raise ValueError("Question cannot be empty")
|
| 76 |
|
| 77 |
total_start = time.time()
|
| 78 |
retrieval_start = time.time()
|
| 79 |
|
|
|
|
|
|
|
| 80 |
chunks = self.retriever.retrieve(
|
| 81 |
-
query =
|
| 82 |
top_k_final = top_k,
|
| 83 |
filter_category = filter_category,
|
| 84 |
filter_year_gte = filter_year_gte,
|
|
@@ -96,11 +133,20 @@ class RAGPipeline:
|
|
| 96 |
f"or broadening their query."
|
| 97 |
)
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
generation_start = time.time()
|
| 100 |
answer, model_used = self.llm.generate(
|
| 101 |
system_prompt = SYSTEM_PROMPT,
|
| 102 |
user_prompt = user_prompt,
|
| 103 |
original_query = question,
|
|
|
|
| 104 |
stream=False
|
| 105 |
)
|
| 106 |
|
|
@@ -123,18 +169,23 @@ class RAGPipeline:
|
|
| 123 |
def stream_query(
|
| 124 |
self,
|
| 125 |
question: str,
|
|
|
|
| 126 |
top_k: int = TOP_K_RERANK,
|
| 127 |
filter_category: Optional[str] = None,
|
| 128 |
filter_year_gte: Optional[int] = None,
|
| 129 |
):
|
| 130 |
question = question.strip()
|
|
|
|
| 131 |
if not question:
|
| 132 |
raise ValueError("Question cannot be empty")
|
| 133 |
|
| 134 |
total_start = time.time()
|
| 135 |
retrieval_start = time.time()
|
|
|
|
|
|
|
|
|
|
| 136 |
chunks = self.retriever.retrieve(
|
| 137 |
-
query =
|
| 138 |
top_k_final = top_k,
|
| 139 |
filter_category = filter_category,
|
| 140 |
filter_year_gte = filter_year_gte,
|
|
@@ -152,11 +203,20 @@ class RAGPipeline:
|
|
| 152 |
f"or broadening their query."
|
| 153 |
)
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
generation_start = time.time()
|
| 156 |
generator, model_used = self.llm.generate(
|
| 157 |
system_prompt = SYSTEM_PROMPT,
|
| 158 |
user_prompt = user_prompt,
|
| 159 |
original_query = question,
|
|
|
|
| 160 |
stream=True
|
| 161 |
)
|
| 162 |
|
|
|
|
| 15 |
import time
|
| 16 |
import json
|
| 17 |
from dataclasses import dataclass, field
|
| 18 |
+
from typing import Optional, List
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class ConversationTurn:
|
| 22 |
+
role: str
|
| 23 |
+
content: str
|
| 24 |
+
citations: list = field(default_factory=list)
|
| 25 |
|
| 26 |
from src.retrieval.retrieval_pipeline import RetrievalPipeline
|
| 27 |
from src.rag.llm_client import MultiModelClient
|
|
|
|
| 69 |
self.llm = MultiModelClient()
|
| 70 |
logger.info("RAGPipeline ready")
|
| 71 |
|
| 72 |
+
def _build_retrieval_query(
|
| 73 |
+
self,
|
| 74 |
+
question: str,
|
| 75 |
+
history: list[ConversationTurn]
|
| 76 |
+
) -> str:
|
| 77 |
+
followup_signals = [
|
| 78 |
+
"it", "that", "this", "they", "them",
|
| 79 |
+
"more", "example", "explain", "clarify",
|
| 80 |
+
"simpler", "detail", "elaborate", "again"
|
| 81 |
+
]
|
| 82 |
+
question_lower = question.lower()
|
| 83 |
+
is_followup = (
|
| 84 |
+
len(question.split()) < 12 and
|
| 85 |
+
any(word in question_lower for word in followup_signals)
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
if is_followup and history:
|
| 89 |
+
last_substantial = ""
|
| 90 |
+
for turn in reversed(history):
|
| 91 |
+
if turn.role == "user" and len(turn.content.split()) > 5:
|
| 92 |
+
last_substantial = turn.content
|
| 93 |
+
break
|
| 94 |
+
if last_substantial:
|
| 95 |
+
return f"{last_substantial} {question}"
|
| 96 |
+
|
| 97 |
+
return question
|
| 98 |
+
|
| 99 |
def query(
|
| 100 |
self,
|
| 101 |
question: str,
|
| 102 |
+
history: list[ConversationTurn] = None,
|
| 103 |
top_k: int = TOP_K_RERANK,
|
| 104 |
filter_category: Optional[str] = None,
|
| 105 |
filter_year_gte: Optional[int] = None,
|
| 106 |
) -> RAGResponse:
|
| 107 |
question = question.strip()
|
| 108 |
+
history = history or []
|
| 109 |
if not question:
|
| 110 |
raise ValueError("Question cannot be empty")
|
| 111 |
|
| 112 |
total_start = time.time()
|
| 113 |
retrieval_start = time.time()
|
| 114 |
|
| 115 |
+
retrieval_query = self._build_retrieval_query(question, history)
|
| 116 |
+
|
| 117 |
chunks = self.retriever.retrieve(
|
| 118 |
+
query = retrieval_query,
|
| 119 |
top_k_final = top_k,
|
| 120 |
filter_category = filter_category,
|
| 121 |
filter_year_gte = filter_year_gte,
|
|
|
|
| 133 |
f"or broadening their query."
|
| 134 |
)
|
| 135 |
|
| 136 |
+
history_messages = []
|
| 137 |
+
if history:
|
| 138 |
+
for turn in history[-10:]:
|
| 139 |
+
history_messages.append({
|
| 140 |
+
"role": turn.role,
|
| 141 |
+
"content": turn.content
|
| 142 |
+
})
|
| 143 |
+
|
| 144 |
generation_start = time.time()
|
| 145 |
answer, model_used = self.llm.generate(
|
| 146 |
system_prompt = SYSTEM_PROMPT,
|
| 147 |
user_prompt = user_prompt,
|
| 148 |
original_query = question,
|
| 149 |
+
history = history_messages,
|
| 150 |
stream=False
|
| 151 |
)
|
| 152 |
|
|
|
|
| 169 |
def stream_query(
|
| 170 |
self,
|
| 171 |
question: str,
|
| 172 |
+
history: list[ConversationTurn] = None,
|
| 173 |
top_k: int = TOP_K_RERANK,
|
| 174 |
filter_category: Optional[str] = None,
|
| 175 |
filter_year_gte: Optional[int] = None,
|
| 176 |
):
|
| 177 |
question = question.strip()
|
| 178 |
+
history = history or []
|
| 179 |
if not question:
|
| 180 |
raise ValueError("Question cannot be empty")
|
| 181 |
|
| 182 |
total_start = time.time()
|
| 183 |
retrieval_start = time.time()
|
| 184 |
+
|
| 185 |
+
retrieval_query = self._build_retrieval_query(question, history)
|
| 186 |
+
|
| 187 |
chunks = self.retriever.retrieve(
|
| 188 |
+
query = retrieval_query,
|
| 189 |
top_k_final = top_k,
|
| 190 |
filter_category = filter_category,
|
| 191 |
filter_year_gte = filter_year_gte,
|
|
|
|
| 203 |
f"or broadening their query."
|
| 204 |
)
|
| 205 |
|
| 206 |
+
history_messages = []
|
| 207 |
+
if history:
|
| 208 |
+
for turn in history[-10:]:
|
| 209 |
+
history_messages.append({
|
| 210 |
+
"role": turn.role,
|
| 211 |
+
"content": turn.content
|
| 212 |
+
})
|
| 213 |
+
|
| 214 |
generation_start = time.time()
|
| 215 |
generator, model_used = self.llm.generate(
|
| 216 |
system_prompt = SYSTEM_PROMPT,
|
| 217 |
user_prompt = user_prompt,
|
| 218 |
original_query = question,
|
| 219 |
+
history = history_messages,
|
| 220 |
stream=True
|
| 221 |
)
|
| 222 |
|
src/rag/prompt_templates.py
CHANGED
|
@@ -36,6 +36,8 @@ FORMATTING RULES:
|
|
| 36 |
7. Use markdown formatting: **bold** for key terms, numbered lists for steps
|
| 37 |
8. For algorithm explanations, structure as: Intuition -> Math -> Steps
|
| 38 |
9. Write comprehensive, detailed answers — do not truncate explanations
|
|
|
|
|
|
|
| 39 |
"""
|
| 40 |
|
| 41 |
|
|
|
|
| 36 |
7. Use markdown formatting: **bold** for key terms, numbered lists for steps
|
| 37 |
8. For algorithm explanations, structure as: Intuition -> Math -> Steps
|
| 38 |
9. Write comprehensive, detailed answers — do not truncate explanations
|
| 39 |
+
|
| 40 |
+
You have access to the conversation history above. Use it to understand follow-up questions, resolve pronouns (like 'it', 'that', 'this method'), and give answers that build on what was already discussed. If the user asks for clarification or says 'explain more' or 'give an example', refer back to your previous answer.
|
| 41 |
"""
|
| 42 |
|
| 43 |
|