ahmadsanafarooq committed on
Commit
d0c46e1
·
verified ·
1 Parent(s): b752b49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +281 -128
app.py CHANGED
@@ -12,32 +12,282 @@ import logging
12
  import numpy as np
13
  from sklearn.feature_extraction.text import TfidfVectorizer
14
  from sklearn.metrics.pairwise import cosine_similarity
 
15
  from dotenv import load_dotenv
16
 
17
  # Configure logging
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
 
21
  class SimpleEmbeddings:
 
 
22
  def __init__(self):
23
  self.vectorizer = TfidfVectorizer(max_features=384, stop_words='english')
24
  self.fitted = False
25
-
26
  def embed_documents(self, texts: List[str]) -> List[List[float]]:
 
27
  if not self.fitted:
28
  self.vectorizer.fit(texts)
29
  self.fitted = True
 
30
  embeddings = self.vectorizer.transform(texts)
31
  return embeddings.toarray().tolist()
32
-
33
  def embed_query(self, text: str) -> List[float]:
 
34
  if not self.fitted:
 
35
  return [0.0] * 384
 
36
  embedding = self.vectorizer.transform([text])
37
  return embedding.toarray()[0].tolist()
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  class RetrieverEvaluator:
40
- def __init__(self, retriever, ground_truth, k=3):
 
 
41
  self.retriever = retriever
42
  self.ground_truth = ground_truth
43
  self.k = k
@@ -66,138 +316,41 @@ class RetrieverEvaluator:
66
  print(f"MRR@{self.k}: {mrr:.2f}")
67
  return mrr
68
 
69
- class RAGAssistant:
70
- def __init__(self, groq_api_key: str):
71
- self.groq_api_key = groq_api_key
72
- self.embeddings = self._init_embeddings()
73
- self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
74
- self.learning_vectorstore = None
75
- self.llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama3-70b-8192", temperature=0.1)
76
- self.learning_persist_dir = "./chroma_learning_db"
77
- self._init_vector_store()
78
-
79
- def _init_embeddings(self):
80
- try:
81
- from langchain_huggingface import HuggingFaceEmbeddings
82
- for model_name in ["all-MiniLM-L6-v2", "paraphrase-MiniLM-L3-v2", "all-mpnet-base-v2"]:
83
- try:
84
- return HuggingFaceEmbeddings(model_name=model_name, model_kwargs={'device': 'cpu'})
85
- except:
86
- continue
87
- except ImportError:
88
- pass
89
- return SimpleEmbeddings()
90
-
91
- def _init_vector_store(self):
92
- self.learning_vectorstore = Chroma(
93
- persist_directory=self.learning_persist_dir,
94
- embedding_function=self.embeddings,
95
- collection_name="learning_materials"
96
- )
97
-
98
- def load_documents(self, files: List[str]) -> str:
99
- documents = []
100
- for file_path in files:
101
- try:
102
- loader = PyPDFLoader(file_path) if file_path.endswith(".pdf") else TextLoader(file_path, encoding="utf-8")
103
- docs = loader.load()
104
- documents.extend(docs)
105
- except Exception as e:
106
- print(f"Error loading {file_path}: {e}")
107
- if not documents:
108
- return "No valid documents found."
109
- chunks = self.text_splitter.split_documents(documents)
110
- for chunk in chunks:
111
- chunk.metadata['source'] = chunk.metadata.get('source', 'unknown')
112
- self.learning_vectorstore.add_documents(chunks)
113
- self.learning_vectorstore.persist()
114
- return f"Loaded {len(chunks)} document chunks."
115
-
116
- def get_response(self, query: str) -> str:
117
- if not self.learning_vectorstore:
118
- return "Please upload learning materials first."
119
- qa_chain = RetrievalQA.from_chain_type(
120
- llm=self.llm,
121
- chain_type="stuff",
122
- retriever=self.learning_vectorstore.as_retriever(search_kwargs={"k": 3}),
123
- return_source_documents=True
124
- )
125
- prompt = f"""
126
- You are a helpful educational assistant.
127
- Answer the student's question clearly and provide references if applicable.
128
-
129
- Question: {query}
130
- """
131
- result = qa_chain({"query": prompt})
132
- response = result['result']
133
- if result.get("source_documents"):
134
- response += "\n\n**Sources:**\n"
135
- for doc in result["source_documents"]:
136
- response += f"- {Path(doc.metadata.get('source', 'Unknown')).name}\n"
137
- return response
138
-
139
- def evaluate_retriever(self, user_queries: List[str], file_names: List[str]):
140
- """Evaluate with user-provided queries and expected file names"""
141
- ground_truth = dict(zip(user_queries, file_names))
142
- retriever = self.learning_vectorstore.as_retriever(search_kwargs={"k": 3})
143
- evaluator = RetrieverEvaluator(retriever, ground_truth, k=3)
144
  recall = evaluator.recall_at_k()
145
  mrr = evaluator.mean_reciprocal_rank()
146
- return f"Recall@3: {recall:.2f}, MRR@3: {mrr:.2f}"
147
-
148
- def create_interface(assistant: RAGAssistant):
149
- def upload_files(files):
150
- file_paths = [f.name for f in files]
151
- return assistant.load_documents(file_paths)
152
-
153
- def chat_fn(message, history):
154
- response = assistant.get_response(message)
155
- history.append((message, response))
156
- return history, ""
157
-
158
- def evaluate_fn(queries, file_names):
159
- query_list = [q.strip() for q in queries.split('\n') if q.strip()]
160
- file_list = [f.strip() for f in file_names.split('\n') if f.strip()]
161
- if len(query_list) != len(file_list):
162
- return "Number of queries and expected file names must match."
163
- return assistant.evaluate_retriever(query_list, file_list)
164
-
165
- with gr.Blocks(title="RAG Assistant") as demo:
166
- gr.Markdown("# 📘 RAG-Based Assistant")
167
- with gr.Tab("📄 Upload & Chat"):
168
- file_input = gr.File(label="Upload PDFs or Text Files", file_count="multiple", file_types=[".pdf", ".txt"])
169
- upload_btn = gr.Button("Load Documents")
170
- status = gr.Textbox(label="Status", interactive=False)
171
- chatbot = gr.Chatbot()
172
- user_input = gr.Textbox(label="Ask a question")
173
- send_btn = gr.Button("Send")
174
-
175
- upload_btn.click(fn=upload_files, inputs=[file_input], outputs=[status])
176
- send_btn.click(fn=chat_fn, inputs=[user_input, chatbot], outputs=[chatbot, user_input])
177
- user_input.submit(fn=chat_fn, inputs=[user_input, chatbot], outputs=[chatbot, user_input])
178
-
179
- with gr.Tab("📊 Evaluate Retriever"):
180
- gr.Markdown("Paste queries and expected file names (one per line).")
181
- queries = gr.Textbox(lines=5, label="Queries")
182
- filenames = gr.Textbox(lines=5, label="Expected File Names")
183
- eval_btn = gr.Button("Run Evaluation")
184
- eval_result = gr.Textbox(label="Evaluation Result")
185
- eval_btn.click(fn=evaluate_fn, inputs=[queries, filenames], outputs=[eval_result])
186
-
187
- gr.Markdown("---")
188
- gr.Markdown("*Powered by LangChain, ChromaDB, and Groq API*")
189
-
190
- return demo
191
 
 
192
  def main():
193
  load_dotenv()
194
  groq_api_key = os.getenv("GROQ_API_KEY")
195
  if not groq_api_key:
196
- print("Missing GROQ_API_KEY. Set it in your environment.")
197
  return
198
- assistant = RAGAssistant(groq_api_key)
199
- app = create_interface(assistant)
200
- app.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  if __name__ == "__main__":
203
- main()
 
12
  import numpy as np
13
  from sklearn.feature_extraction.text import TfidfVectorizer
14
  from sklearn.metrics.pairwise import cosine_similarity
15
+ import pickle
16
  from dotenv import load_dotenv
17
 
18
  # Configure logging
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
22
+ # ---------------------- TF-IDF Embedding Fallback ----------------------
class SimpleEmbeddings:
    """TF-IDF based embeddings used as a fallback embedding provider.

    Every vector returned by this class has exactly ``DIMENSION`` components.
    ``TfidfVectorizer(max_features=...)`` only caps the vocabulary size, so a
    small corpus yields fewer columns than the cap; those vectors are
    zero-padded so they agree with the zero vector returned before fitting.
    """

    # Fixed width of every embedding produced by this class.
    DIMENSION = 384

    def __init__(self):
        self.vectorizer = TfidfVectorizer(max_features=self.DIMENSION, stop_words='english')
        # The vocabulary is learned once, from the first non-empty batch of documents.
        self.fitted = False

    def _to_fixed_width(self, matrix) -> List[List[float]]:
        """Densify a TF-IDF matrix and right-pad each row with zeros to DIMENSION."""
        dense = matrix.toarray()
        missing = self.DIMENSION - dense.shape[1]
        if missing > 0:
            dense = np.pad(dense, ((0, 0), (0, missing)))
        return dense.tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents.

        The vectorizer is fitted on the first non-empty batch it receives and
        reused unchanged for every later call.

        Args:
            texts: Documents to embed.

        Returns:
            One ``DIMENSION``-long vector per input text; ``[]`` for empty input.
        """
        if not texts:
            # Nothing to embed, and nothing to learn a vocabulary from.
            return []

        if not self.fitted:
            self.vectorizer.fit(texts)
            self.fitted = True

        return self._to_fixed_width(self.vectorizer.transform(texts))

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query.

        Args:
            text: Query string.

        Returns:
            A ``DIMENSION``-long vector; all zeros when no documents have been
            embedded yet, because no vocabulary exists at that point.
        """
        if not self.fitted:
            return [0.0] * self.DIMENSION

        return self._to_fixed_width(self.vectorizer.transform([text]))[0]
47
 
48
+ # ---------------------- RAG Assistant ----------------------
class RAGAssistant:
    """Question-answering assistant over two Chroma collections.

    One collection ("learning_materials") backs the learning tutor and the
    other ("code_documentation") backs the code helper. Both use the same
    embedding provider and the same Groq chat model.
    """

    # Assistant types accepted by load_documents; each maps to one vector store.
    _ASSISTANT_TYPES = ("learning", "code")

    def __init__(self, groq_api_key: str):
        """Initialize the RAG Assistant with Groq API key."""
        self.groq_api_key = groq_api_key

        # Initialize embeddings with fallback
        self.embeddings = self._init_embeddings()

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len
        )

        self.learning_vectorstore = None
        self.code_vectorstore = None

        self.llm = ChatGroq(
            groq_api_key=groq_api_key,
            model_name="llama3-70b-8192",
            temperature=0.1
        )

        self.learning_persist_dir = "./chroma_learning_db"
        self.code_persist_dir = "./chroma_code_db"

        self._init_vector_stores()

    def _init_embeddings(self):
        """Return the first HuggingFace embedding model that loads, else TF-IDF.

        Each candidate model is tried in order; a failure is printed and the
        next one is attempted. If the HuggingFace integration cannot be
        imported, or every model fails, a ``SimpleEmbeddings`` instance is
        returned instead.
        """
        try:
            from langchain_huggingface import HuggingFaceEmbeddings
            print("Trying HuggingFace embeddings...")
            models_to_try = [
                "all-MiniLM-L6-v2",
                "paraphrase-MiniLM-L3-v2",
                "all-mpnet-base-v2"
            ]
            for model_name in models_to_try:
                try:
                    embeddings = HuggingFaceEmbeddings(
                        model_name=model_name,
                        model_kwargs={'device': 'cpu'},
                        encode_kwargs={'normalize_embeddings': False}
                    )
                    print(f"Successfully loaded HuggingFace model: {model_name}")
                    return embeddings
                except Exception as e:
                    # Best-effort: report and fall through to the next candidate.
                    print(f"Failed to load {model_name}: {e}")
        except ImportError:
            print("HuggingFace embeddings not available")

        print("Using TF-IDF embeddings as fallback...")
        return SimpleEmbeddings()

    def _init_vector_stores(self):
        """Create both Chroma stores; log and re-raise if either cannot be built."""
        try:
            self.learning_vectorstore = Chroma(
                persist_directory=self.learning_persist_dir,
                embedding_function=self.embeddings,
                collection_name="learning_materials"
            )
            self.code_vectorstore = Chroma(
                persist_directory=self.code_persist_dir,
                embedding_function=self.embeddings,
                collection_name="code_documentation"
            )
        except Exception as e:
            logger.error("Error initializing vector stores: %s", e)
            raise

    def load_documents(self, files: List[str], assistant_type: str) -> str:
        """Load, chunk and index files into the store for ``assistant_type``.

        Args:
            files: Paths of the files to ingest. A ``.pdf`` suffix (any letter
                case) is read with ``PyPDFLoader``; everything else is read as
                UTF-8 text.
            assistant_type: ``"learning"`` or ``"code"``.

        Returns:
            A human-readable status message. Failures are reported through the
            returned message rather than raised.
        """
        if assistant_type not in self._ASSISTANT_TYPES:
            # Previously an unknown type stored nothing yet reported success.
            return f"Error loading documents: unknown assistant type '{assistant_type}'."

        try:
            documents = []
            for file_path in files:
                try:
                    if Path(file_path).suffix.lower() == '.pdf':
                        loader = PyPDFLoader(file_path)
                    else:
                        loader = TextLoader(file_path, encoding='utf-8')
                    documents.extend(loader.load())
                except Exception as e:
                    # One unreadable file must not abort the whole upload.
                    logger.warning("Error loading %s: %s", file_path, e)

            if not documents:
                return "No documents could be loaded. Please check your files."

            chunks = self.text_splitter.split_documents(documents)
            for chunk in chunks:
                chunk.metadata['assistant_type'] = assistant_type

            if assistant_type == "learning":
                store = self.learning_vectorstore
            else:
                store = self.code_vectorstore
            store.add_documents(chunks)
            store.persist()

            return f"Successfully loaded {len(chunks)} chunks from {len(documents)} documents into {assistant_type} assistant."
        except Exception as e:
            logger.error("Error loading documents: %s", e)
            return f"Error loading documents: {e}"

    def _answer(self, vectorstore, prompt: str, *, missing_message: str,
                sources_heading: str, log_label: str) -> str:
        """Run a "stuff" RetrievalQA chain over ``vectorstore`` and format the reply.

        Args:
            vectorstore: Store to retrieve from; ``None`` yields ``missing_message``.
            prompt: Full text passed to the chain as its ``query``.
            missing_message: Reply used when no store is available.
            sources_heading: Markdown heading placed above the source list.
            log_label: Name used in the error log line.

        Returns:
            The chain's answer, followed by the distinct file names of up to
            three source documents, or an error message if anything fails.
        """
        if vectorstore is None:
            return missing_message

        try:
            qa_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                chain_type="stuff",
                retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
                return_source_documents=True
            )
            # NOTE(review): the whole instruction prompt is sent as the chain's
            # "query", so retrieval presumably embeds the guidelines along with
            # the question — confirm, and consider a prompt template instead.
            result = qa_chain({"query": prompt})
            response = result['result']

            sources = result.get('source_documents')
            if sources:
                names = [Path(doc.metadata.get('source', 'Unknown')).name for doc in sources[:3]]
                # Several chunks can come from one file; list each file once, in order.
                listed = "".join(f"- {name}\n" for name in dict.fromkeys(names))
                response += f"\n\n{sources_heading}\n{listed}"
            return response
        except Exception as e:
            logger.error("Error in %s: %s", log_label, e)
            return f"Error generating response: {e}"

    def get_learning_tutor_response(self, question: str) -> str:
        """Answer a student's question from the learning-materials store.

        Args:
            question: The student's question.

        Returns:
            The answer with a "Sources" list appended, a prompt to upload
            materials when no store exists, or an error message.
        """
        learning_prompt = f"""
            You are an AI learning assistant that helps students understand academic concepts.
            Based on the provided course materials, answer the student's question clearly and educationally.

            Guidelines:
            - Provide clear, educational explanations
            - Use examples when helpful
            - Reference specific sources when possible
            - Adapt to the student's level of understanding
            - Offer additional practice questions or related concepts when relevant
            - Maintain an encouraging, supportive tone

            Student's question: {question}
            """
        return self._answer(
            self.learning_vectorstore,
            learning_prompt,
            missing_message="Please upload some learning materials first.",
            sources_heading="**Sources:**",
            log_label="learning tutor",
        )

    def get_code_helper_response(self, question: str) -> str:
        """Answer a developer's question from the code-documentation store.

        Args:
            question: The developer's question.

        Returns:
            The answer with a "Documentation Sources" list appended, a prompt
            to upload documentation when no store exists, or an error message.
        """
        code_prompt = f"""
            You are a technical assistant that helps developers understand codebases and APIs.
            Based on the provided documentation and code examples, answer the developer's question.

            Guidelines:
            - Provide practical, actionable guidance
            - Include relevant code snippets with explanations
            - Reference specific documentation sections when possible
            - Highlight important considerations (security, performance, errors)
            - Suggest related APIs or patterns that might be useful
            - Use clear, technical language appropriate for developers

            Developer's question: {question}
            """
        return self._answer(
            self.code_vectorstore,
            code_prompt,
            missing_message="Please upload some code documentation first.",
            sources_heading="**Documentation Sources:**",
            log_label="code helper",
        )
219
+
220
+ # ---------------------- Gradio UI ----------------------
def create_gradio_interface(assistant: RAGAssistant):
    """Assemble the two-tab Gradio Blocks app around ``assistant`` and return it.

    The first tab uploads learning materials and chats with the learning
    tutor; the second does the same for code documentation. The app is only
    built here — launching is left to the caller.
    """

    def upload_learning_files(files):
        # A falsy ``files`` value means there is nothing to ingest.
        if files:
            return assistant.load_documents([f.name for f in files], "learning")
        return "No files uploaded."

    def upload_code_files(files):
        if files:
            return assistant.load_documents([f.name for f in files], "code")
        return "No files uploaded."

    def learning_chat(message, history):
        # Whitespace-only input leaves the history untouched; the box is cleared either way.
        if message.strip():
            reply = assistant.get_learning_tutor_response(message)
            history.append((message, reply))
        return history, ""

    def code_chat(message, history):
        if message.strip():
            reply = assistant.get_code_helper_response(message)
            history.append((message, reply))
        return history, ""

    with gr.Blocks(title="RAG-Based Learning & Code Assistant", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎓 RAG-Based Learning & Code Assistant")
        gr.Markdown("Upload your documents and ask questions to get intelligent responses!")

        with gr.Tabs():
            with gr.TabItem("📚 Learning Tutor"):
                with gr.Row():
                    with gr.Column(scale=1):
                        tutor_files = gr.File(
                            label="Upload Learning Materials",
                            file_count="multiple",
                            file_types=[".pdf", ".txt", ".md"],
                        )
                        tutor_upload = gr.Button("Upload Materials", variant="primary")
                        tutor_status = gr.Textbox(label="Upload Status", interactive=False)
                    with gr.Column(scale=2):
                        tutor_chat = gr.Chatbot(label="Learning Tutor Chat", height=400)
                        tutor_box = gr.Textbox(
                            label="Ask a question",
                            placeholder="e.g., What is regression?",
                        )
                        tutor_ask = gr.Button("Ask Question", variant="primary")

                # The button click and Enter in the textbox share one handler and wiring.
                tutor_in = [tutor_box, tutor_chat]
                tutor_out = [tutor_chat, tutor_box]
                tutor_upload.click(upload_learning_files, inputs=[tutor_files], outputs=[tutor_status])
                tutor_ask.click(learning_chat, inputs=tutor_in, outputs=tutor_out)
                tutor_box.submit(learning_chat, inputs=tutor_in, outputs=tutor_out)

            with gr.TabItem("💻 Code Documentation Helper"):
                with gr.Row():
                    with gr.Column(scale=1):
                        helper_files = gr.File(
                            label="Upload Code Documentation",
                            file_count="multiple",
                            file_types=[".pdf", ".txt", ".md", ".py", ".js", ".json"],
                        )
                        helper_upload = gr.Button("Upload Documentation", variant="primary")
                        helper_status = gr.Textbox(label="Upload Status", interactive=False)
                    with gr.Column(scale=2):
                        helper_chat = gr.Chatbot(label="Code Helper Chat", height=400)
                        helper_box = gr.Textbox(
                            label="Ask about code or APIs",
                            placeholder="e.g., How to use this function?",
                        )
                        helper_ask = gr.Button("Ask Question", variant="primary")

                helper_in = [helper_box, helper_chat]
                helper_out = [helper_chat, helper_box]
                helper_upload.click(upload_code_files, inputs=[helper_files], outputs=[helper_status])
                helper_ask.click(code_chat, inputs=helper_in, outputs=helper_out)
                helper_box.submit(code_chat, inputs=helper_in, outputs=helper_out)

        gr.Markdown("---")
        gr.Markdown("*Powered by LangChain, ChromaDB, and Groq API*")

    return demo
285
+
286
+ # ---------------------- Evaluation Additions ----------------------
287
  class RetrieverEvaluator:
288
+ """Evaluation class for computing Recall@k and MRR@k"""
289
+
290
+ def __init__(self, retriever, ground_truth: dict, k=3):
291
  self.retriever = retriever
292
  self.ground_truth = ground_truth
293
  self.k = k
 
316
  print(f"MRR@{self.k}: {mrr:.2f}")
317
  return mrr
318
 
def evaluate_retriever_example(assistant, ground_truth=None, k=3):
    """Evaluate the learning-store retriever and return a short text report.

    Args:
        assistant: Object exposing a ``learning_vectorstore`` attribute.
        ground_truth: Mapping of query text to the list of expected source
            file names. Defaults to a small mock mapping for demonstration.
        k: Number of documents retrieved per query; also used as the cut-off
            in the reported metric names.

    Returns:
        A report with Recall@k and MRR@k, or a notice that no documents are
        available when the assistant has no learning store.
    """
    store = assistant.learning_vectorstore
    if not store:
        return "No documents uploaded for evaluation."

    if ground_truth is None:
        # Mock ground truth; the file names only match if such files were uploaded.
        ground_truth = {
            "What is machine learning?": ["ml_intro.txt"],
            "What is API authentication?": ["api_guide.pdf"],
        }

    retriever = store.as_retriever(search_kwargs={"k": k})
    evaluator = RetrieverEvaluator(retriever, ground_truth, k=k)
    recall = evaluator.recall_at_k()
    mrr = evaluator.mean_reciprocal_rank()
    return f"Evaluation Results:\nRecall@{k}: {recall:.2f}\nMRR@{k}: {mrr:.2f}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
+ # ---------------------- Entry Point ----------------------
def main():
    """Load environment settings, build the assistant and serve the Gradio app.

    Returns early with a printed hint when ``GROQ_API_KEY`` is not set. Any
    exception raised while constructing or launching the app is logged and
    printed rather than propagated.
    """
    load_dotenv()
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        print("Please set your GROQ_API_KEY in the environment.")
        return

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "debug": True,
    }

    try:
        print("Initializing RAG Assistant...")
        rag = RAGAssistant(api_key)

        # Retriever evaluation via evaluate_retriever_example(rag) is optional
        # and only meaningful once documents have been uploaded.

        ui = create_gradio_interface(rag)
        print("Launching app...")
        ui.launch(**launch_options)
    except Exception as exc:
        logger.error(f"Error starting application: {str(exc)}")
        print(f"Error: {str(exc)}")


if __name__ == "__main__":
    main()