Subhadip007 commited on
Commit
2671aea
·
1 Parent(s): f780124

feat: RAG pipeline complete - ResearchPilot end-to-end working

Browse files

- Full RAG pipeline: retrieval → prompt building → LLM generation
- Groq API integration: llama-3.3-70b-versatile
- Structured RAGResponse: answer + citations + timing metadata
- Prompt engineering: grounding rules + citation requirements
- Hallucination resistance: correctly refuses out-of-scope questions
- Citations working: paper IDs cited inline + source list populated
- Fixed hybrid retriever: spread all metadata fields from dense results

Test results:
LoRA question: grounded answer with 4 cited sources
MARL question: synthesized from multiple papers with citations
Python history: correctly refused (no relevant context)
Warm query latency: ~8s (cross-encoder on CPU, optimizing)

config/settings.py CHANGED
@@ -73,16 +73,16 @@ EMBEDDING_DIMENSION = 768 # BGE-base output dimension
73
  # ------------------------------------------
74
  QDRANT_COLLECTION_NAME = 'research_papers'
75
  QDRANT_PATH = str(ROOT_DIR / 'data' / 'qdrant_db') # Local Storage path
76
- TOP_K_RETRIEVAL = 20 # Retieve top 20 candidates
77
  TOP_K_RERANK = 5 # Keep top 5 after reranking
78
 
79
  # ------------------------------------------
80
  # LLM SETTINGS
81
  # ------------------------------------------
82
- GROQ_API_KEY = os.getenv('GROQ_API_KEY') # Loaded from .env
83
- LLM_MODEL_NAME = 'llama3-8b-8192' # Groq model ID
84
- LLM_TEMPERATURE = 0.1 # Low = More factual/consistent
85
- LLM_MAX_TOKENS = 1024 # Max response tokens
86
 
87
  # ------------------------------------------
88
  # API SETTINGS
 
73
  # ------------------------------------------
74
  QDRANT_COLLECTION_NAME = 'research_papers'
75
  QDRANT_PATH = str(ROOT_DIR / 'data' / 'qdrant_db') # Local Storage path
76
+ TOP_K_RETRIEVAL = 20 # Retrieve top 20 candidates
77
  TOP_K_RERANK = 5 # Keep top 5 after reranking
78
 
79
  # ------------------------------------------
80
  # LLM SETTINGS
81
  # ------------------------------------------
82
+ GROQ_API_KEY = os.getenv('GROQ_API_KEY') # Loaded from .env
83
+ LLM_MODEL_NAME = 'llama-3.3-70b-versatile' # Groq model ID
84
+ LLM_TEMPERATURE = 0.1 # Low = More factual/consistent
85
+ LLM_MAX_TOKENS = 1024 # Max response tokens
86
 
87
  # ------------------------------------------
88
  # API SETTINGS
src/rag/__init__.py ADDED
File without changes
src/rag/llm_client.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Groq API client for LLM inference.
3
+
4
+ WHY GROQ:
5
+ - Free tier: 14,400 requests/day with Llama3
6
+ - Speed: ~500 tokens/second (vs 10 tokens/second local CPU)
7
+ - No GPU needed on our machine
8
+ - Production-quality latency for demos
9
+
10
+ WHY LLAMA3-8B:
11
+ - Free on Groq
12
+ - 8B parameters: strong reasoning for research QA
13
+ - 8192 token context window: fits our 5 retrieved chunks
14
+ - Fast: ~1-2 seconds for a full response
15
+ """
16
+
17
+ import os
18
+ from groq import Groq
19
+ from src.utils.logger import get_logger
20
+ from config.settings import (
21
+ GROQ_API_KEY,
22
+ LLM_MODEL_NAME,
23
+ LLM_TEMPERATURE,
24
+ LLM_MAX_TOKENS,
25
+ )
26
+
27
+ logger = get_logger(__name__)
28
+
29
+
30
+ class LLMClient:
31
+ """
32
+ Wrapper around Groq API for LLM inference.
33
+
34
+ Designed as a simple interface so we can swap
35
+ to any other LLM provider (OpenAI, Anthropic, local)
36
+ by changing only this file.
37
+ """
38
+
39
+
40
+ def __init__(self):
41
+ if not GROQ_API_KEY:
42
+ raise ValueError(
43
+ "GROQ_API_KEY not found. "
44
+ "Add it to your .env file: GROQ_API_KEY=gsk_..."
45
+ )
46
+ self.client = Groq(api_key = GROQ_API_KEY)
47
+ self.model = LLM_MODEL_NAME
48
+ logger.info(f"LLMClient initialized with model: {self.model}")
49
+
50
+
51
+ def generate(
52
+ self,
53
+ system_prompt: str,
54
+ user_prompt: str,
55
+ temperature: float = LLM_TEMPERATURE,
56
+ max_tokens: int = LLM_MAX_TOKENS,
57
+ ) -> str:
58
+ """
59
+ Generate a response from the LLM.
60
+
61
+ Args:
62
+ system_prompt: Instructions for the LLM's behavior
63
+ user_prompt: The actual question + context
64
+ temperature: 0.0 = deterministic, 1.0 = creative
65
+ We use 0.1 for factual research QA
66
+ max_tokens: Maximum response length
67
+
68
+ Returns:
69
+ Generated text string
70
+
71
+ GROQ API STRUCTURE:
72
+ Uses OpenAI-compatible chat format:
73
+ [{"role": "system", "content": "..."},
74
+ {"role": "user", "content": "..."}]
75
+ """
76
+
77
+ try:
78
+ response = self.client.chat.completions.create(
79
+ model = self.model,
80
+ messages = [
81
+ {"role": "system", "content": system_prompt},
82
+ {"role": "user", "content": user_prompt}
83
+ ],
84
+ temperature = temperature,
85
+ max_tokens = max_tokens,
86
+ )
87
+
88
+ answer = response.choices[0].message.content
89
+
90
+ # Log token usage for monitoring
91
+ usage = response.usage
92
+ logger.debug(
93
+ f"LLM usage - "
94
+ f"prompt: {usage.prompt_tokens} tokens, "
95
+ f"completion: {usage.completion_tokens} tokens, "
96
+ f"total: {usage.total_tokens} tokens"
97
+ )
98
+
99
+ return answer
100
+
101
+ except Exception as e:
102
+ logger.error(f"LLM generation failed: {e}")
103
+ raise
src/rag/pipeline.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The complete RAG pipeline - orchestrates retrieval + generation.
3
+
4
+ This is the core of ResearchPilot. Every user query goes through this.
5
+
6
+ PIPELINE FLOW:
7
+ 1. Validate and clean the query
8
+ 2. Retrieve top-5 relevant chunks (Phase 8 pipeline)
9
+ 3. Build prompt with context
10
+ 4. Generate answer via Groq LLM
11
+ 5. Parse and structure the response
12
+ 6. Return answer + citations + metadata
13
+ """
14
+
15
+ import time
16
+ from dataclasses import dataclass, field
17
+ from typing import Optional
18
+
19
+ from src.retrieval.retrieval_pipeline import RetrievalPipeline
20
+ from src.rag.llm_client import LLMClient
21
+ from src.rag.prompt_templates import (
22
+ SYSTEM_PROMPT,
23
+ build_rag_prompt,
24
+ build_citation_list,
25
+ )
26
+ from src.utils.logger import get_logger
27
+ from config.settings import TOP_K_RERANK
28
+
29
+ logger = get_logger(__name__)
30
+
31
+
32
+ @dataclass
33
+ class RAGResponse:
34
+ """
35
+ Structured response from the RAG pipeline.
36
+
37
+ WHY A DATACLASS INSTEAD OF A DICT:
38
+ Dicts can have any keys - you never know what's in them.
39
+ A dataclass defines the exact contract. The FastAPI layer
40
+ (Phase 11) and frontend (Phase 12) can rely on these
41
+ fields always being present.
42
+ """
43
+ # The generated answer
44
+ answer: str
45
+
46
+ # Source papers used to generate the answer
47
+ citations: list[dict]
48
+
49
+ # Raw retrieved chunks (for debugging / evaluation)
50
+ retrieved_chunks: list[dict]
51
+
52
+ # Performance metadata
53
+ query: str
54
+ retrieval_time_ms: float
55
+ generation_time_ms: float
56
+ total_time_ms: float
57
+
58
+ # Whether retrieval found retrieval content
59
+ has_context: bool
60
+
61
+
62
+ def to_dict(self) -> dict:
63
+ return {
64
+ "answer": self.answer,
65
+ "citations": self.citations,
66
+ "query": self.query,
67
+ "retrieval_time_ms": round(self.retrieval_time_ms, 1),
68
+ "generation_time_ms": round(self.generation_time_ms, 1),
69
+ "total_time_ms": round(self.total_time_ms, 1),
70
+ "has_context": self.has_context,
71
+ "chunks_used": len(self.retrieved_chunks),
72
+ }
73
+
74
+
75
+
76
+
77
+ class RAGPipeline:
78
+ """
79
+ End-to-end RAG pipeline: query -> retrieve -> generate -> respond.
80
+
81
+ Usage:
82
+ pipeline = RAGPipeline()
83
+ response = pipeline.query("How does LoRA reduce training parameters?")
84
+ print(response.answer)
85
+ for cite in response.citations:
86
+ print(cite["title"], cite["arxiv_url"])
87
+ """
88
+
89
+ def __init__(self):
90
+ logger.info("Initializing RAGPipeline...")
91
+
92
+ self.retriever = RetrievalPipeline()
93
+ self.llm = LLMClient()
94
+
95
+ logger.info("RAGPipeline ready")
96
+
97
+ def query(
98
+ self,
99
+ question: str,
100
+ top_k: int = TOP_K_RERANK,
101
+ filter_category: Optional[str] = None,
102
+ filter_year_gte: Optional[int] = None,
103
+ ) -> RAGResponse:
104
+ """
105
+ Process a user question through the full RAG pipeline.
106
+
107
+ Args:
108
+ question: User's natural language question
109
+ top_k: Number of chunks to retrieve
110
+ filter_category: Optional ArXiv category filter
111
+ filter_year_gte: Optional year filter
112
+
113
+ Returns:
114
+ RAGResponse with answer, citations, and timing metadata
115
+ """
116
+ question = question.strip()
117
+
118
+ if not question:
119
+ raise ValueError("Question cannot be empty")
120
+
121
+ total_start = time.time()
122
+
123
+ # ------------ Stage 1: Retrieval ------------
124
+ retrieval_start = time.time()
125
+
126
+ chunks = self.retriever.retrieve(
127
+ query = question,
128
+ top_k_final = top_k,
129
+ filter_category = filter_category,
130
+ filter_year_gte = filter_year_gte,
131
+ )
132
+
133
+ retrieval_ms = (time.time() - retrieval_start) * 1000
134
+
135
+ logger.info(
136
+ f"Retrieved: {len(chunks)} chunks in {retrieval_ms:.0f}ms"
137
+ )
138
+
139
+ has_context = len(chunks) > 0
140
+
141
+ # ------------ Stage 2: Prompt Construction ------------
142
+ if has_context:
143
+ user_prompt = build_rag_prompt(question, chunks)
144
+ else:
145
+ # Fallback prompt when no relevant context found
146
+ user_prompt = (
147
+ f"The user asked: {question}\n\n"
148
+ f"No relevant research papers were found in the database. "
149
+ f"Politely inform the user and suggest they try rephrasing "
150
+ f"or broadening their query."
151
+ )
152
+
153
+ # ------------ Stage 3: LLM Generation ------------
154
+ generation_start = time.time()
155
+
156
+ answer = self.llm.generate(
157
+ system_prompt = SYSTEM_PROMPT,
158
+ user_prompt = user_prompt,
159
+ )
160
+
161
+ generation_ms = (time.time() - generation_start) * 1000
162
+ total_ms = (time.time() - total_start) * 1000
163
+
164
+ logger.info(
165
+ f"Generated answer in {generation_ms:.0f}ms | "
166
+ f"Total: {total_ms:.0f}ms"
167
+ )
168
+
169
+ # ------------ Stage 4: Build Citations ------------
170
+ citations = build_citation_list(chunks)
171
+
172
+ return RAGResponse(
173
+ answer = answer,
174
+ citations = citations,
175
+ retrieved_chunks = chunks,
176
+ query = question,
177
+ retrieval_time_ms = retrieval_ms,
178
+ generation_time_ms = generation_ms,
179
+ total_time_ms = total_ms,
180
+ has_context = has_context,
181
+ )
src/rag/prompt_templates.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prompt templates for the ResearchPilot RAG system.
3
+
4
+ PROMPT ENGINEERING IS NOT OPTIONAL.
5
+ The difference between a good RAG system and a bad one
6
+ is often entirely in the prompt design.
7
+
8
+ Key principles we apply:
9
+ 1. EXPLICIT GROUNDING: Tell the LLM to ONLY use provided context
10
+ 2. CITATION REQUIREMENT: Force the LLM to cite which paper it used
11
+ 3. UNCERTAINTY ACKNOWLEDGMENT: If context doesn't answer, say so
12
+ 4. STRUCTURED OUTPUT: Consistent format makes parsing reliable
13
+ """
14
+
15
+ SYSTEM_PROMPT = """You are ResearchPilot, an expert AI research assistant
16
+ specialized in machine learning and AI research papers.
17
+
18
+ Your job is to answer questions based EXCLUSIVELY on the research paper
19
+ excerpts provided in the context below.
20
+
21
+ STRICT RULES:
22
+ 1. Only use information from the provided context excerpts
23
+ 2. Always cite the paper title and ID when using information from it
24
+ 3. If the context does not contain enough information to answer,
25
+ say "The provided papers do not contain sufficient information
26
+ to answer this question" - do NOT make up information
27
+ 4. Be precise and technical - your users are ML researchers and engineers
28
+ 5. When multiple papers discuss the same topic, synthesize their findings
29
+ 6. Keep answers focused and well-structured
30
+ """
31
+
32
+
33
+ def build_rag_prompt(query: str, context_chunks: list[dict]) -> str:
34
+ """
35
+ Build the full prompt for the LLM with retrieved context.
36
+
37
+ Args:
38
+ query: User's question
39
+ context_chunks: List of retrieved chunk dicts from RetrievalPipeline
40
+
41
+ Returns:
42
+ Formatted prompt string ready to send to the LLM
43
+
44
+ PROMPT STRUCTURE:
45
+ [System prompt]
46
+ [Context block - all retrieved chunks with citations]
47
+ [User question]
48
+
49
+ WHY WE FORMAT CONTEXT THIS WAY:
50
+ Each chunk is labeled with its paper title and ID.
51
+ This enables the LLM to produce citations like:
52
+ "According to [2603.12248], LoRA constrains..."
53
+
54
+ Without this labeling, the LLM cannot cite sources
55
+ even if it wanted to.
56
+ """
57
+
58
+ # Build context block from retrieved chunks
59
+ context_parts = []
60
+
61
+
62
+ for i, chunk in enumerate(context_chunks, 1):
63
+ paper_id = chunk.get("paper_id", "unknown")
64
+ title = chunk.get("title", "Unknown Paper")
65
+ date = chunk.get("published_date", "")
66
+ text = chunk.get("text", "")
67
+
68
+ context_parts.append(
69
+ f"[SOURCE {i}]\n"
70
+ f"Paper ID: {paper_id}\n"
71
+ f"Title: {title}\n"
72
+ f"Published: {date}\n"
73
+ f"Excerpt:\n{text}\n"
74
+ )
75
+
76
+
77
+ context_block = "\n---\n".join(context_parts)
78
+
79
+
80
+ prompt = f"""
81
+ CONTEXT - Research Paper Excerpts:
82
+ {context_block}
83
+
84
+ ---
85
+
86
+ QUESTION: {query}
87
+
88
+ INSTRUCTIONS: Answer the question using ONLY the context above.
89
+ Cite sources using their Paper ID in brackets, e.g. [2603.12248].
90
+ If the context is insufficient, say so clearly.
91
+ """
92
+
93
+ return prompt
94
+
95
+
96
+
97
+ def build_citation_list(context_chunks: list[dict]) -> list[dict]:
98
+ """
99
+ Build a structured list of cited sources from retrieved chunks.
100
+
101
+ Returns deduplicated list of papers used as sources.
102
+ """
103
+ seen_papers = set()
104
+ citations = []
105
+
106
+
107
+ for chunk in context_chunks:
108
+ paper_id = chunk.get("paper_id", "")
109
+ if paper_id and paper_id not in seen_papers:
110
+ seen_papers.add(paper_id)
111
+ citations.append(
112
+ {
113
+ "paper_id": paper_id,
114
+ "title": chunk.get("title", ""),
115
+ "authors": chunk.get("authors", []),
116
+ "published_date": chunk.get("published_date", ""),
117
+ "arxiv_url": chunk.get("arxiv_url", ""),
118
+ }
119
+ )
120
+
121
+ return citations
src/retrieval/hybrid_retriever.py CHANGED
@@ -116,14 +116,23 @@ class HybridRetriever:
116
  # -------------- Step 3: Build chunk_id -> full data lookup --------------
117
  # Dense results have full payload (text, metadata)
118
  # Sparse results only have chunk_id and text
 
119
  chunk_data = {}
 
 
 
 
 
 
 
 
 
 
 
120
  for r in dense_results:
121
  if r["chunk_id"] not in chunk_data:
122
- chunk_data[r["chunk_id"]] = {
123
- "chunk_id": r["chunk_id"],
124
- "text": r["text"],
125
- "score": 0.0,
126
- }
127
 
128
  # -------------- Step 4: Compute RRF score --------------
129
  RRF_scores = {}
 
116
  # -------------- Step 3: Build chunk_id -> full data lookup --------------
117
  # Dense results have full payload (text, metadata)
118
  # Sparse results only have chunk_id and text
119
+
120
  chunk_data = {}
121
+
122
+ # -------------------------------------------------------
123
+ # for r in dense_results:
124
+ # if r["chunk_id"] not in chunk_data:
125
+ # chunk_data[r["chunk_id"]] = {
126
+ # "chunk_id": r["chunk_id"],
127
+ # "text": r["text"],
128
+ # "score": 0.0,
129
+ # }
130
+ # -------------------------------------------------------
131
+
132
  for r in dense_results:
133
  if r["chunk_id"] not in chunk_data:
134
+ chunk_data[r["chunk_id"]] = {**r}
135
+
 
 
 
136
 
137
  # -------------- Step 4: Compute RRF score --------------
138
  RRF_scores = {}
src/retrieval/retrieval_pipeline.py CHANGED
@@ -98,8 +98,8 @@ class RetrievalPipeline:
98
  # Stage 2: Cross-encoder re-ranking -> top-5
99
  reranked = self.reranker.rerank(
100
  query = query,
101
- results = candidates,
102
- top_k = top_k_final * 2, # Keep extra before diversity filter
103
  )
104
 
105
  # Stage 3: Diversity filter -> max 2 chunks per paper
 
98
  # Stage 2: Cross-encoder re-ranking -> top-5
99
  reranked = self.reranker.rerank(
100
  query = query,
101
+ results = candidates[:10],
102
+ top_k = TOP_K_RETRIEVAL * 2, # Keep extra before diversity filter
103
  )
104
 
105
  # Stage 3: Diversity filter -> max 2 chunks per paper
test_rag.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ End-to-end test of the RAG pipeline.
3
+ This is the most important test in the project.
4
+ """
5
+
6
+ from src.utils.logger import setup_logger, get_logger
7
+ from src.rag.pipeline import RAGPipeline
8
+
9
+
10
+ setup_logger()
11
+ logger = get_logger(__name__)
12
+
13
+
14
+
15
+ def ask(pipeline: RAGPipeline, question: str, **kwargs):
16
+ print(f"\n{'='*65}")
17
+ print(f"Q: {question}")
18
+ print(f"{'='*65}")
19
+
20
+ response = pipeline.query(question, **kwargs)
21
+
22
+ print(f"\nANSWER:\n{response.answer}")
23
+
24
+ print(f"\nSOURCES ({len(response.citations)}):")
25
+ for i, cite in enumerate(response.citations, 1):
26
+ print(f" [{i}] {cite['paper_id']} — {cite['title'][:60]}...")
27
+ print(f" {cite['arxiv_url']}")
28
+
29
+ print(f"\nTIMING:")
30
+ print(f" Retrieval: {response.retrieval_time_ms:.0f}ms")
31
+ print(f" Generation: {response.generation_time_ms:.0f}ms")
32
+ print(f" Total: {response.total_time_ms:.0f}ms")
33
+ print(f" Chunks used: {len(response.retrieved_chunks)}")
34
+
35
+
36
+ def main():
37
+ logger.info("Initializing RAG pipeline...")
38
+ pipeline = RAGPipeline()
39
+
40
+ # Test 1: Specific technical question
41
+ ask(
42
+ pipeline,
43
+ "What is LoRA and how does it reduce the number of trainable parameters?"
44
+ )
45
+
46
+ # Test 2: Comparison question
47
+ ask(
48
+ pipeline,
49
+ "What are the main challenges in multi-agent reinforcement learning?"
50
+ )
51
+
52
+ # Test 3: Question that may not be in corpus
53
+ ask(
54
+ pipeline,
55
+ "What is the history of the Python programming language?"
56
+ )
57
+
58
+
59
+ if __name__ == "__main__":
60
+ main()