iamfaham committed on
Commit
8fed3de
·
verified ·
1 Parent(s): 6c457ef

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +178 -0
  2. rag_pipeline.py +258 -0
app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from rag_pipeline import rag_chain # reuse from Step 3 in rag_pipeline.py
3
+
4
+
5
def chat_with_rag(message, history):
    """Run the RAG chain on *message* and append the exchange to *history*.

    Returns the updated history (openai-style message dicts) plus an empty
    string, which Gradio uses to clear the input textbox.
    """
    if not message.strip():
        # Ignore blank submissions; leave the chat untouched.
        return history, ""

    try:
        answer = rag_chain.invoke(message)

        # Trim very long answers so the Gradio chat widget stays responsive.
        limit = 8000
        if len(answer) > limit:
            answer = answer[:limit] + "\n\n... (response truncated due to length)"
    except Exception as exc:
        answer = f"Sorry, I encountered an error: {str(exc)}. Please try again."

    # Record the exchange in the "messages" chat format.
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": answer})
    return history, ""
31
+
32
+
33
def clear_chat():
    """Reset the conversation: empty history and a blank input box."""
    return ([], "")
36
+
37
+
38
# Build the Gradio UI.  The custom CSS caps the chat panel height, wraps
# long messages, and styles the send / clear buttons and the input row.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="""
    .chatbot {
        max-height: 600px !important;
        overflow-y: auto !important;
    }
    .chatbot .message {
        white-space: pre-wrap !important;
        word-wrap: break-word !important;
        max-width: 100% !important;
    }
    .chatbot .user-message, .chatbot .bot-message {
        padding: 10px !important;
        margin: 5px 0 !important;
        border-radius: 8px !important;
    }
    .chatbot .bot-message {
        background-color: #f0f8ff !important;
        border-left: 4px solid #007acc !important;
    }
    .chatbot .user-message {
        background-color: #e6f3ff !important;
        border-left: 4px solid #28a745 !important;
    }
    .send-button {
        background-color: #007acc !important;
        color: white !important;
        border: none !important;
        border-radius: 8px !important;
        padding: 10px 20px !important;
        font-weight: bold !important;
        transition: background-color 0.3s !important;
    }
    .send-button:hover {
        background-color: #005a9e !important;
    }
    .clear-button {
        background-color: #dc3545 !important;
        color: white !important;
        border: none !important;
        border-radius: 8px !important;
        padding: 8px 16px !important;
        font-weight: bold !important;
        transition: background-color 0.3s !important;
    }
    .clear-button:hover {
        background-color: #c82333 !important;
    }
    .input-container {
        display: flex !important;
        gap: 10px !important;
        align-items: flex-end !important;
    }
    .textbox-container {
        flex: 1 !important;
    }
    """,
) as demo:
    gr.Markdown("# 🤖 React Docs Assistant")
    gr.Markdown(
        "Ask questions about React documentation and get comprehensive answers."
    )

    # Chat history panel.  type="messages" means history is a list of
    # {"role": ..., "content": ...} dicts, matching chat_with_rag's output.
    chatbot = gr.Chatbot(
        label="Chat History",
        height=500,  # Slightly reduced to make room for input area
        show_label=True,
        type="messages",  # Use the new messages format
    )

    # Input area with send button (wide textbox, narrow button column)
    with gr.Row():
        with gr.Column(scale=4):
            textbox = gr.Textbox(
                placeholder="Ask a question about React... (Press Enter or click Send)",
                lines=2,  # Allow multiple lines for longer questions
                max_lines=5,
                label="Your Question",
                show_label=True,
            )
        with gr.Column(scale=1):
            send_button = gr.Button(
                "🚀 Send", variant="primary", size="lg", elem_classes=["send-button"]
            )

    # Control buttons
    with gr.Row():
        clear_button = gr.Button(
            "🗑️ Clear Chat", variant="secondary", elem_classes=["clear-button"]
        )

    # Example questions (collapsed by default)
    with gr.Accordion("Example Questions", open=False):
        gr.Markdown(
            """
            Try these example questions:
            - **What is React?**
            - **How do I use useState hook?**
            - **Explain React components**
            - **What are props in React?**
            - **How does React rendering work?**
            - **What are React Hooks?**
            - **How to handle events in React?**
            """
        )

    # Event handlers
    def send_message(message, history):
        # Thin wrapper so the button click and the textbox submit share
        # one handler.
        return chat_with_rag(message, history)

    # Connect the send button
    send_button.click(
        fn=send_message,
        inputs=[textbox, chatbot],
        outputs=[chatbot, textbox],
        api_name="send",
    )

    # Connect Enter key in textbox (same handler as the Send button)
    textbox.submit(
        fn=send_message,
        inputs=[textbox, chatbot],
        outputs=[chatbot, textbox],
        api_name="send_enter",
    )

    # Connect clear button
    clear_button.click(
        fn=clear_chat, inputs=[], outputs=[chatbot, textbox], api_name="clear"
    )
170
+
171
if __name__ == "__main__":
    # Serve the app.  NOTE(review): 127.0.0.1 binds to loopback only —
    # use server_name="0.0.0.0" (or share=True) if external access is wanted;
    # the original comment claimed the opposite.
    demo.launch(
        server_name="127.0.0.1",  # local-only bind (loopback interface)
        server_port=7860,
        share=False,  # Set to True if you want a public Gradio link
        debug=True,  # Enable debug mode for better error messages
        show_error=True,
    )
rag_pipeline.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langchain_pinecone import Pinecone as LangchainPinecone
4
+ from langchain_huggingface import HuggingFaceEmbeddings
5
+ from langchain_core.prompts import PromptTemplate
6
+ from langchain_core.runnables import RunnableLambda
7
+ from langchain_openai import ChatOpenAI
8
+ import json
9
+ from rank_bm25 import BM25Okapi
10
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
11
+ import torch
12
+ import logging
13
+ import re
14
+
15
# Load environment variables (PINECONE_INDEX, OPENROUTER_MODEL,
# OPENROUTER_API_KEY) from a local .env file, if present.
load_dotenv()

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

# Embedding model used for dense retrieval.  normalize_embeddings=True makes
# the vectors unit-length, so similarity search behaves like cosine similarity.
embedder = HuggingFaceEmbeddings(
    model_name="intfloat/e5-large-v2",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)

# Connect to an existing Pinecone index.  Assumes the index was populated
# beforehand with the same embedding model — TODO confirm.
index_name = os.getenv("PINECONE_INDEX")  # NOTE(review): None if the env var is unset — verify deployment config
vectorstore = LangchainPinecone.from_existing_index(
    index_name=index_name,
    embedding=embedder,
)

# Dense retriever returning the top-5 most similar chunks.
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

# LLM served through OpenRouter's OpenAI-compatible endpoint.
llm = ChatOpenAI(
    model=os.getenv("OPENROUTER_MODEL"),
    api_key=os.getenv("OPENROUTER_API_KEY"),
    base_url="https://openrouter.ai/api/v1",
    max_tokens=2000,  # Limit response length to prevent extremely long outputs
    temperature=0.7,  # Add some creativity while keeping responses focused
)
43
+
44
+ # Question decomposition prompt template
45
+ decomposition_template = """Break down the following question into exactly 4 sub-questions that would help provide a comprehensive answer.
46
+ Each sub-question should focus on a different aspect of the main question.
47
+
48
+ Original Question: {question}
49
+
50
+ Please provide exactly 4 sub-questions, one per line, starting with numbers 1-4:
51
+
52
+ 1. [First sub-question]
53
+ 2. [Second sub-question]
54
+ 3. [Third sub-question]
55
+ 4. [Fourth sub-question]
56
+
57
+ Make sure each sub-question is specific and focused on a different aspect of the original question."""
58
+
59
+ decomposition_prompt = PromptTemplate(
60
+ input_variables=["question"],
61
+ template=decomposition_template,
62
+ )
63
+
64
+ # Answer synthesis prompt template
65
+ synthesis_template = """You are a helpful assistant. Based on the answers to the sub-questions below, provide a comprehensive but concise answer to the original question.
66
+
67
+ Original Question: {original_question}
68
+
69
+ Sub-questions and their answers:
70
+ {sub_answers}
71
+
72
+ Please synthesize these answers into a clear, well-structured response that directly addresses the original question.
73
+ Keep the response focused and avoid unnecessary repetition. If any sub-question couldn't be answered with the available context, mention that briefly.
74
+ Include relevant code examples where applicable, but keep them concise."""
75
+
76
+ synthesis_prompt = PromptTemplate(
77
+ input_variables=["original_question", "sub_answers"],
78
+ template=synthesis_template,
79
+ )
80
+
81
+ # Individual answer prompt template
82
+ template = """You are a helpful assistant. Answer the question using ONLY the context below. Also add a code example if applicable.
83
+ If the answer is not in the context, say "I don't know."
84
+
85
+ Context:
86
+ {context}
87
+
88
+ Question:
89
+ {question}
90
+
91
+ Helpful Answer:"""
92
+
93
+ prompt = PromptTemplate(
94
+ input_variables=["context", "question"],
95
+ template=template,
96
+ )
97
+
98
# Load the pre-chunked React docs for sparse (BM25) retrieval.  Expects each
# entry to have a "content" field; "title" is optional.
with open("react_docs_chunks.json", "r", encoding="utf-8") as f:
    docs_json = json.load(f)

# Parallel lists: raw chunk text and its (possibly empty) title, indexed the
# same way as the BM25 corpus below.
bm25_corpus = [doc["content"] for doc in docs_json]
bm25_titles = [doc.get("title", "") for doc in docs_json]
# Plain whitespace tokenization; queries are split the same way at search time.
bm25 = BM25Okapi([doc.split() for doc in bm25_corpus])

# Cross-encoder used to re-rank the merged dense + BM25 candidates.
# Downloads model weights on first run.
cross_encoder_model = "cross-encoder/ms-marco-MiniLM-L-6-v2"
cross_tokenizer = AutoTokenizer.from_pretrained(cross_encoder_model)
cross_model = AutoModelForSequenceClassification.from_pretrained(cross_encoder_model)
110
+
111
+
112
# Hybrid retrieval function
def hybrid_retrieve(query, dense_k=5, bm25_k=5, rerank_k=5):
    """Hybrid dense + BM25 retrieval with cross-encoder re-ranking.

    Args:
        query: user query string.
        dense_k: unused here — the dense k is fixed by the retriever's
            search_kwargs at module level; kept for interface stability.
        bm25_k: number of BM25 candidates to keep before merging.
        rerank_k: number of documents returned after re-ranking.

    Returns:
        A list of lightweight doc objects exposing ``.metadata`` (dict with
        a "title" key) and ``.page_content``, best-ranked first.
    """
    logging.info(f"Hybrid retrieval for query: {query}")

    def _make_doc(title, content):
        # Minimal stand-in matching the interface callers rely on
        # (metadata dict + page_content), mirroring LangChain's Document.
        return type(
            "Doc", (), {"metadata": {"title": title}, "page_content": content}
        )

    # Dense retrieval (top-k fixed by the retriever's search_kwargs).
    dense_docs = retriever.get_relevant_documents(query)
    logging.info(f"Dense docs retrieved: {len(dense_docs)}")
    # .get() avoids a KeyError on chunks that were indexed without a title.
    dense_set = set((d.metadata.get("title", ""), d.page_content) for d in dense_docs)

    # Sparse (BM25) retrieval over the whitespace-tokenized corpus.
    bm25_scores = bm25.get_scores(query.split())
    bm25_indices = sorted(
        range(len(bm25_scores)), key=lambda i: bm25_scores[i], reverse=True
    )[:bm25_k]
    bm25_set = set((bm25_titles[i], bm25_corpus[i]) for i in bm25_indices)
    logging.info(f"BM25 docs retrieved: {len(bm25_set)}")

    # Merge and deduplicate on the (title, content) pair.
    all_doc_objs = [_make_doc(t, c) for t, c in dense_set | bm25_set]
    logging.info(f"Total unique docs before re-ranking: {len(all_doc_objs)}")

    if not all_doc_objs:
        # Nothing to re-rank — the tokenizer would fail on an empty batch.
        return []

    # Re-rank with the cross-encoder (query, passage) pairs.
    pairs = [(query, doc.page_content) for doc in all_doc_objs]
    inputs = cross_tokenizer.batch_encode_plus(
        pairs, padding=True, truncation=True, return_tensors="pt", max_length=512
    )
    with torch.no_grad():
        # reshape(-1) keeps a 1-D array even for a single candidate;
        # squeeze() would collapse (1, 1) logits to a 0-d scalar and break
        # the zip() below.
        scores = cross_model(**inputs).logits.reshape(-1).cpu().numpy()
    ranked = sorted(zip(all_doc_objs, scores), key=lambda x: x[1], reverse=True)[
        :rerank_k
    ]
    logging.info(f"Docs after re-ranking: {len(ranked)}")
    return [doc for doc, _ in ranked]
158
+
159
+
160
# Question decomposition function
def decompose_question(question):
    """Split *question* into exactly four focused sub-questions via the LLM.

    Falls back to four generic variations of the question when the LLM call
    or the parsing of its reply fails.
    """
    try:
        logging.info(f"Decomposing question: {question}")
        reply = llm.invoke(decomposition_prompt.format(question=question))
        logging.info(f"Decomposition response: {reply.content[:200]}...")

        # Pull the numbered lines ("1. ...") out of the LLM reply.
        matches = re.findall(r"\d+\.\s*(.+)", reply.content, re.MULTILINE)
        logging.info(f"Regex matches: {matches}")

        sub_questions = [m.strip() for m in matches[:4] if m.strip()]

        # Pad with generic variations when fewer than four were extracted.
        while len(sub_questions) < 4:
            sub_questions.append(f"Additional aspect of: {question}")

        logging.info(f"Decomposed into {len(sub_questions)} sub-questions")
        return sub_questions[:4]
    except Exception as exc:
        logging.error(f"Error in decompose_question: {str(exc)}")
        # Fallback to simple variations
        return [
            f"What is {question}?",
            f"How does {question} work?",
            f"When to use {question}?",
            f"Examples of {question}",
        ]
200
+
201
+
202
# RAG chain
def format_docs(docs):
    """Join docs as "title:\\ncontent" paragraphs for the LLM context."""
    logging.info(f"Formatting {len(docs)} docs for LLM context.")
    parts = [f"{doc.metadata['title']}:\n{doc.page_content}" for doc in docs]
    return "\n\n".join(parts)
206
+
207
+
208
def process_question_with_decomposition(original_question):
    """Answer *original_question* via decompose → retrieve → answer → synthesize.

    Steps:
      1. Break the question into four sub-questions.
      2. Retrieve hybrid context and answer each sub-question independently.
      3. Ask the LLM to synthesize the sub-answers into one response.

    Returns the final answer text, or an "Error processing question: ..."
    string on failure (the Gradio layer displays whatever string comes back).
    """
    try:
        logging.info(f"Processing question with decomposition: {original_question}")

        # Step 1: Decompose the question
        sub_questions = decompose_question(original_question)
        logging.info(f"Sub-questions: {sub_questions}")

        # Step 2: Answer each sub-question against its own retrieved context
        sub_answers = []
        for i, sub_q in enumerate(sub_questions, 1):
            logging.info(f"Processing sub-question {i}: {sub_q}")

            context = format_docs(hybrid_retrieve(sub_q))
            logging.info(f"Context length for sub-question {i}: {len(context)}")

            sub_answer = llm.invoke(prompt.format(context=context, question=sub_q))
            logging.info(f"Sub-answer {i}: {sub_answer.content[:100]}...")
            sub_answers.append(f"{i}. {sub_q}\nAnswer: {sub_answer.content}")

        # Step 3: Synthesize the final answer from all sub-answers
        sub_answers_text = "\n\n".join(sub_answers)
        logging.info(f"Sub-answers text length: {len(sub_answers_text)}")

        final_answer = llm.invoke(
            synthesis_prompt.format(
                original_question=original_question, sub_answers=sub_answers_text
            )
        )

        logging.info(f"Final answer: {final_answer.content[:100]}...")
        return final_answer.content

    except Exception as e:
        # logging.exception preserves the traceback; logging.error dropped it.
        logging.exception(f"Error in process_question_with_decomposition: {str(e)}")
        return f"Error processing question: {str(e)}"
246
+
247
+
248
# Enhanced RAG chain with decomposition.  Wrapping the pipeline in a
# RunnableLambda lets callers (e.g. app.py) use rag_chain.invoke("question").
rag_chain = RunnableLambda(process_question_with_decomposition)

# Run it for local testing
if __name__ == "__main__":
    # Simple REPL: type a question, or "exit"/"quit" to stop.
    while True:
        query = input("\n Ask a question about React: ")
        if query.lower() in ["exit", "quit"]:
            break
        response = rag_chain.invoke(query)
        print("\n🤖 Answer:\n", response)