fguryel commited on
Commit
8fc10eb
·
1 Parent(s): b943512
Files changed (4) hide show
  1. .gitattributes +0 -4
  2. README.md +0 -19
  3. app.py +414 -232
  4. app_hf.py +0 -309
.gitattributes CHANGED
@@ -33,8 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
-
37
- # Large database and data files
38
- *.sqlite3 filter=lfs diff=lfs merge=lfs -text
39
- *.json filter=lfs diff=lfs merge=lfs -text
40
  chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
36
  chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -10,29 +10,10 @@ pinned: false
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
13
- ---
14
- title: Scikit-learn Documentation Q&A Bot
15
- emoji: 🤖
16
- colorFrom: blue
17
- colorTo: green
18
- sdk: streamlit
19
- sdk_version: 1.50.0
20
- app_file: app.py
21
- pinned: false
22
- license: mit
23
- ---
24
-
25
  # Scikit-learn Documentation Q&A Bot 🤖
26
 
27
  A Retrieval-Augmented Generation (RAG) chatbot that answers questions about Scikit-learn using the official documentation.
28
 
29
- ## How to Use on Hugging Face Spaces
30
-
31
- 1. **Enter OpenAI API Key**: In the sidebar, enter your OpenAI API key
32
- 2. **Ask Questions**: Type any question about Scikit-learn functionality
33
- 3. **Get Answers**: Receive detailed responses with source documentation links
34
- 4. **Explore**: Use the example questions or browse chat history
35
-
36
  ## Features
37
 
38
  - **🔍 Smart Retrieval**: Searches through 1,249+ documentation chunks using semantic similarity
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
13
  # Scikit-learn Documentation Q&A Bot 🤖
14
 
15
  A Retrieval-Augmented Generation (RAG) chatbot that answers questions about Scikit-learn using the official documentation.
16
 
 
 
 
 
 
 
 
17
  ## Features
18
 
19
  - **🔍 Smart Retrieval**: Searches through 1,249+ documentation chunks using semantic similarity
app.py CHANGED
@@ -1,309 +1,491 @@
 
1
  """
2
- Scikit-learn RAG Chatbot - Hugging Face Spaces Optimized Version
3
- A Retrieval-Augmented Generation chatbot for Scikit-learn documentation.
 
 
 
 
 
 
4
  """
5
 
6
- import streamlit as st
7
  import os
8
- import json
9
  import logging
10
- from typing import List, Dict, Optional, Tuple
11
- import warnings
 
 
 
 
12
 
13
- # Suppress warnings for cleaner output
14
- warnings.filterwarnings("ignore")
15
- logging.getLogger("chromadb").setLevel(logging.ERROR)
16
- logging.getLogger("sentence_transformers").setLevel(logging.ERROR)
17
 
18
- # Try imports with error handling
19
- try:
20
- import chromadb
21
- from sentence_transformers import SentenceTransformer
22
- import openai
23
- DEPENDENCIES_AVAILABLE = True
24
- except ImportError as e:
25
- DEPENDENCIES_AVAILABLE = False
26
- st.error(f"Missing dependencies: {e}")
27
 
28
- class SimpleRAGChatbot:
29
- """Simplified RAG chatbot for HF Spaces deployment"""
 
 
 
 
 
 
30
 
31
- def __init__(self):
32
- self.client = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  self.collection = None
34
- self.model = None
35
  self.openai_client = None
36
- self.initialized = False
37
 
38
- def initialize(self):
39
- """Initialize the RAG system with error handling"""
 
 
 
 
 
40
  try:
41
- if not DEPENDENCIES_AVAILABLE:
42
- return False
43
-
44
- # Initialize embedding model
45
- self.model = SentenceTransformer('all-MiniLM-L6-v2')
46
 
47
- # Try to load existing database
48
- if self._load_existing_database():
49
- self.initialized = True
50
- return True
51
-
52
- # If no database exists, try to rebuild
53
- if self._rebuild_from_chunks():
54
- self.initialized = True
55
- return True
56
-
57
- return False
58
 
59
- except Exception as e:
60
- st.error(f"Initialization error: {str(e)}")
61
- return False
62
-
63
- def _load_existing_database(self) -> bool:
64
- """Try to load existing ChromaDB"""
65
- try:
66
- # Check multiple possible paths
67
- db_paths = ['./chroma_db', './chroma', '.']
68
 
69
- for db_path in db_paths:
70
- try:
71
- if os.path.exists(os.path.join(db_path, 'chroma.sqlite3')) or os.path.exists(db_path):
72
- self.client = chromadb.PersistentClient(path=db_path)
73
- collections = self.client.list_collections()
74
-
75
- if collections:
76
- self.collection = collections[0] # Use first available collection
77
- st.success(f"✅ Loaded database from {db_path} with {self.collection.count()} documents")
78
- return True
79
-
80
- except Exception:
81
- continue
82
-
83
- return False
84
 
85
  except Exception as e:
86
- st.warning(f"Could not load existing database: {str(e)}")
87
- return False
88
 
89
- def _rebuild_from_chunks(self) -> bool:
90
- """Rebuild database from chunks.json if available"""
91
- try:
92
- chunks_file = 'chunks.json'
93
- if not os.path.exists(chunks_file):
94
- st.error("❌ No chunks.json file found. Please upload the required data files.")
95
- return False
96
-
97
- with open(chunks_file, 'r') as f:
98
- chunks = json.load(f)
99
-
100
- if not chunks:
101
- st.error("❌ Chunks file is empty")
102
- return False
103
-
104
- # Create new database
105
- db_path = './chroma_db'
106
- os.makedirs(db_path, exist_ok=True)
107
-
108
- self.client = chromadb.PersistentClient(path=db_path)
109
-
110
- # Create collection
111
- collection_name = "sklearn_docs"
112
- try:
113
- self.collection = self.client.get_collection(collection_name)
114
- except:
115
- self.collection = self.client.create_collection(collection_name)
116
-
117
- # Add chunks in batches
118
- batch_size = 100
119
- total_chunks = len(chunks)
120
-
121
- progress_bar = st.progress(0)
122
- status_text = st.empty()
123
 
124
- for i in range(0, total_chunks, batch_size):
125
- batch = chunks[i:i + batch_size]
126
-
127
- documents = [chunk['content'] for chunk in batch]
128
- metadatas = [{'source': chunk.get('source', 'unknown')} for chunk in batch]
129
- ids = [f"chunk_{i + j}" for j in range(len(batch))]
130
-
131
- self.collection.add(
132
- documents=documents,
133
- metadatas=metadatas,
134
- ids=ids
135
- )
136
-
137
- progress = (i + len(batch)) / total_chunks
138
- progress_bar.progress(progress)
139
- status_text.text(f"Processing chunks: {i + len(batch)}/{total_chunks}")
140
 
141
- progress_bar.empty()
142
- status_text.empty()
143
 
144
- st.success(f" Successfully rebuilt database with {total_chunks} chunks")
145
  return True
146
 
147
  except Exception as e:
148
- st.error(f"Failed to rebuild database: {str(e)}")
 
149
  return False
150
 
151
- def search_documents(self, query: str, n_results: int = 5) -> List[Dict]:
152
- """Search for relevant documents"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  try:
154
- if not self.initialized or not self.collection:
155
- return []
156
-
157
  results = self.collection.query(
158
  query_texts=[query],
159
  n_results=n_results
160
  )
161
 
162
- documents = []
 
 
163
  if results['documents'] and results['documents'][0]:
164
- for i, doc in enumerate(results['documents'][0]):
165
- documents.append({
166
- 'content': doc,
167
- 'source': results['metadatas'][0][i].get('source', 'unknown') if results['metadatas'] else 'unknown'
168
- })
 
 
 
 
 
169
 
170
- return documents
 
171
 
172
  except Exception as e:
173
- st.error(f"Search error: {str(e)}")
 
174
  return []
175
 
176
- def generate_response(self, query: str, context_docs: List[Dict]) -> str:
177
- """Generate response using OpenAI"""
178
- try:
179
- # Check for OpenAI API key
180
- api_key = st.session_state.get('openai_api_key') or os.getenv('OPENAI_API_KEY')
181
-
182
- if not api_key:
183
- return "⚠️ Please provide your OpenAI API key to generate responses."
184
-
185
- if not self.openai_client:
186
- self.openai_client = openai.OpenAI(api_key=api_key)
187
-
188
- # Prepare context
189
- context = "\n\n".join([f"Source: {doc['source']}\nContent: {doc['content']}"
190
- for doc in context_docs])
191
 
192
- if not context.strip():
193
- return "I couldn't find relevant information in the documentation. Please try rephrasing your question."
 
 
 
 
 
 
 
194
 
195
- # Create prompt
196
- prompt = f"""Based on the following Scikit-learn documentation, please answer the user's question accurately and helpfully.
 
 
 
 
 
197
 
198
- Documentation Context:
199
  {context}
200
 
201
- User Question: {query}
 
 
 
 
 
 
 
 
 
202
 
203
- Please provide a clear, accurate answer based on the documentation provided. If the documentation doesn't contain enough information to answer the question completely, please say so."""
204
 
205
- # Generate response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  response = self.openai_client.chat.completions.create(
207
- model="gpt-3.5-turbo",
208
  messages=[
209
- {"role": "system", "content": "You are a helpful assistant that answers questions about Scikit-learn based on provided documentation."},
210
- {"role": "user", "content": prompt}
 
 
 
 
 
 
211
  ],
212
- max_tokens=1000,
213
- temperature=0.3
 
214
  )
215
 
216
- return response.choices[0].message.content
 
 
217
 
218
  except Exception as e:
219
- return f"Error generating response: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
  def main():
222
- """Main Streamlit application"""
 
 
223
  st.set_page_config(
224
- page_title="Scikit-learn RAG Chatbot",
225
  page_icon="🤖",
226
- layout="wide"
 
227
  )
228
 
229
- st.title("🤖 Scikit-learn RAG Chatbot")
230
- st.markdown("Ask questions about Scikit-learn and get answers from the official documentation!")
231
-
232
  # Initialize session state
233
- if 'chatbot' not in st.session_state:
234
- st.session_state.chatbot = SimpleRAGChatbot()
235
- st.session_state.messages = []
236
- st.session_state.initialized = False
237
 
238
- # Initialize the chatbot if not already done
239
- if not st.session_state.initialized:
240
- with st.spinner("Initializing RAG system..."):
241
- success = st.session_state.chatbot.initialize()
242
- st.session_state.initialized = success
243
-
244
- if not success:
245
- st.error("❌ Failed to initialize the system. Please check the data files.")
246
- st.stop()
 
247
 
248
- # Sidebar for API key
249
  with st.sidebar:
250
- st.header("Configuration")
251
 
 
252
  api_key = st.text_input(
253
- "OpenAI API Key",
254
  type="password",
255
- value=st.session_state.get('openai_api_key', ''),
256
- help="Enter your OpenAI API key to enable response generation"
257
  )
258
 
259
- if api_key:
260
- st.session_state.openai_api_key = api_key
261
- st.success("✅ API key configured")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
  st.markdown("---")
264
- st.markdown("### About")
265
- st.markdown("""
266
- This chatbot uses RAG (Retrieval-Augmented Generation) to answer questions about Scikit-learn.
267
 
268
- - **Data**: Official Scikit-learn documentation
269
- - **Embeddings**: all-MiniLM-L6-v2
270
- - **Vector DB**: ChromaDB
271
- - **LLM**: GPT-3.5-turbo
272
- """)
273
-
274
- # Chat interface
275
- st.header("💬 Chat")
 
 
 
 
 
 
 
 
276
 
277
- # Display chat messages
278
- for message in st.session_state.messages:
279
- with st.chat_message(message["role"]):
280
- st.markdown(message["content"])
281
 
282
- # Chat input
283
- if prompt := st.chat_input("Ask a question about Scikit-learn..."):
284
- # Add user message
285
- st.session_state.messages.append({"role": "user", "content": prompt})
286
- with st.chat_message("user"):
287
- st.markdown(prompt)
288
-
289
- # Generate response
290
- with st.chat_message("assistant"):
291
- with st.spinner("Searching documentation and generating response..."):
292
- # Search for relevant documents
293
- docs = st.session_state.chatbot.search_documents(prompt)
294
-
295
- if docs:
296
- st.markdown("**Found relevant documentation:**")
297
- for i, doc in enumerate(docs[:3], 1):
298
- with st.expander(f"📄 Source {i}: {doc['source']}", expanded=False):
299
- st.markdown(doc['content'][:500] + "..." if len(doc['content']) > 500 else doc['content'])
300
-
301
- # Generate response
302
- response = st.session_state.chatbot.generate_response(prompt, docs)
303
- st.markdown(response)
 
 
 
 
 
 
304
 
305
- # Add assistant message
306
- st.session_state.messages.append({"role": "assistant", "content": response})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
  if __name__ == "__main__":
309
  main()
 
1
+ #!/usr/bin/env python3
2
  """
3
+ Scikit-learn Documentation Q&A Bot
4
+
5
+ A Retrieval-Augmented Generation (RAG) chatbot built with Streamlit
6
+ that answers questions about Scikit-learn documentation using ChromaDB
7
+ for retrieval and OpenAI for generation.
8
+
9
+ Author: AI Assistant
10
+ Date: September 2025
11
  """
12
 
 
13
  import os
 
14
  import logging
15
+ from typing import List, Dict, Any, Optional, Tuple
16
+ import streamlit as st
17
+ import chromadb
18
+ from chromadb.config import Settings
19
+ from sentence_transformers import SentenceTransformer
20
+ from openai import OpenAI
21
 
 
 
 
 
22
 
23
+ # Configure logging
24
+ logging.basicConfig(level=logging.INFO)
25
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
26
 
27
+
28
+ class RAGChatbot:
29
+ """
30
+ A Retrieval-Augmented Generation chatbot for Scikit-learn documentation.
31
+
32
+ This class handles the complete RAG pipeline: retrieval from ChromaDB,
33
+ augmentation with context, and generation using OpenAI's API.
34
+ """
35
 
36
+ def __init__(
37
+ self,
38
+ db_path: str = './chroma_db',
39
+ collection_name: str = 'sklearn_docs',
40
+ embedding_model_name: str = 'all-MiniLM-L6-v2'
41
+ ):
42
+ """
43
+ Initialize the RAG chatbot.
44
+
45
+ Args:
46
+ db_path (str): Path to ChromaDB database
47
+ collection_name (str): Name of the ChromaDB collection
48
+ embedding_model_name (str): Name of the embedding model
49
+ """
50
+ self.db_path = db_path
51
+ self.collection_name = collection_name
52
+ self.embedding_model_name = embedding_model_name
53
+
54
+ # Initialize components
55
+ self.chroma_client = None
56
  self.collection = None
57
+ self.embedding_model = None
58
  self.openai_client = None
 
59
 
60
+ # Initialize the retrieval system
61
+ self._initialize_retrieval_system()
62
+
63
+ def _initialize_retrieval_system(self) -> None:
64
+ """
65
+ Initialize ChromaDB client and embedding model for retrieval.
66
+ """
67
  try:
68
+ # Initialize ChromaDB client
69
+ self.chroma_client = chromadb.PersistentClient(
70
+ path=self.db_path,
71
+ settings=Settings(anonymized_telemetry=False)
72
+ )
73
 
74
+ # Get collection
75
+ self.collection = self.chroma_client.get_collection(
76
+ name=self.collection_name
77
+ )
 
 
 
 
 
 
 
78
 
79
+ # Load embedding model (same as used for building the database)
80
+ self.embedding_model = SentenceTransformer(self.embedding_model_name)
 
 
 
 
 
 
 
81
 
82
+ logger.info("RAG retrieval system initialized successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  except Exception as e:
85
+ logger.error(f"Failed to initialize retrieval system: {e}")
86
+ raise
87
 
88
+ def set_openai_client(self, api_key: str) -> bool:
89
+ """
90
+ Initialize OpenAI client with API key.
91
+
92
+ Args:
93
+ api_key (str): OpenAI API key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ Returns:
96
+ bool: True if successful, False otherwise
97
+ """
98
+ try:
99
+ self.openai_client = OpenAI(api_key=api_key)
 
 
 
 
 
 
 
 
 
 
 
100
 
101
+ # Test the API key with a simple request
102
+ self.openai_client.models.list()
103
 
104
+ logger.info("OpenAI client initialized successfully")
105
  return True
106
 
107
  except Exception as e:
108
+ logger.error(f"Failed to initialize OpenAI client: {e}")
109
+ st.error(f"Invalid API key or OpenAI connection error: {e}")
110
  return False
111
 
112
+ def retrieve_relevant_chunks(
113
+ self,
114
+ query: str,
115
+ n_results: int = 3,
116
+ min_relevance_score: float = 0.1
117
+ ) -> List[Dict[str, Any]]:
118
+ """
119
+ Retrieve relevant text chunks from the vector database.
120
+
121
+ Args:
122
+ query (str): User question/query
123
+ n_results (int): Number of chunks to retrieve
124
+ min_relevance_score (float): Minimum relevance score threshold
125
+
126
+ Returns:
127
+ List[Dict[str, Any]]: Retrieved chunks with content and metadata
128
+ """
129
  try:
130
+ # Query the collection
 
 
131
  results = self.collection.query(
132
  query_texts=[query],
133
  n_results=n_results
134
  )
135
 
136
+ retrieved_chunks = []
137
+
138
+ # Process results
139
  if results['documents'] and results['documents'][0]:
140
+ for i in range(len(results['documents'][0])):
141
+ chunk_data = {
142
+ 'content': results['documents'][0][i],
143
+ 'metadata': results['metadatas'][0][i],
144
+ 'distance': results['distances'][0][i] if 'distances' in results else None
145
+ }
146
+
147
+ # Filter by relevance score if available
148
+ if chunk_data['distance'] is None or chunk_data['distance'] >= min_relevance_score:
149
+ retrieved_chunks.append(chunk_data)
150
 
151
+ logger.info(f"Retrieved {len(retrieved_chunks)} relevant chunks for query: {query[:50]}...")
152
+ return retrieved_chunks
153
 
154
  except Exception as e:
155
+ logger.error(f"Error retrieving chunks: {e}")
156
+ st.error(f"Error during retrieval: {e}")
157
  return []
158
 
159
+ def create_rag_prompt(
160
+ self,
161
+ user_question: str,
162
+ retrieved_chunks: List[Dict[str, Any]]
163
+ ) -> str:
164
+ """
165
+ Create an augmented prompt for OpenAI with retrieved context.
166
+
167
+ Args:
168
+ user_question (str): Original user question
169
+ retrieved_chunks (List[Dict[str, Any]]): Retrieved relevant chunks
 
 
 
 
170
 
171
+ Returns:
172
+ str: Augmented prompt for OpenAI
173
+ """
174
+ # Build context from retrieved chunks
175
+ context_parts = []
176
+
177
+ for i, chunk in enumerate(retrieved_chunks, 1):
178
+ url = chunk['metadata'].get('url', 'Unknown source')
179
+ content = chunk['content'].strip()
180
 
181
+ context_part = f"--- Context {i} (Source: {url}) ---\n{content}\n"
182
+ context_parts.append(context_part)
183
+
184
+ context = "\n".join(context_parts)
185
+
186
+ # Create the RAG prompt
187
+ rag_prompt = f"""You are an expert AI assistant specializing in Scikit-learn, a popular Python machine learning library. Your task is to answer questions about Scikit-learn based ONLY on the provided context from the official documentation.
188
 
189
+ CONTEXT:
190
  {context}
191
 
192
+ USER QUESTION:
193
+ {user_question}
194
+
195
+ INSTRUCTIONS:
196
+ 1. Answer the question using ONLY the information provided in the context above
197
+ 2. Be accurate, helpful, and specific
198
+ 3. If the context doesn't contain enough information to fully answer the question, say so clearly
199
+ 4. Include relevant code examples if they appear in the context
200
+ 5. Mention specific function names, class names, or parameter names when relevant
201
+ 6. Structure your answer clearly with appropriate formatting
202
 
203
+ ANSWER:"""
204
 
205
+ return rag_prompt
206
+
207
+ def generate_answer(
208
+ self,
209
+ prompt: str,
210
+ model: str = "gpt-3.5-turbo",
211
+ max_tokens: int = 1000,
212
+ temperature: float = 0.1
213
+ ) -> Optional[str]:
214
+ """
215
+ Generate answer using OpenAI API.
216
+
217
+ Args:
218
+ prompt (str): Augmented prompt with context
219
+ model (str): OpenAI model to use
220
+ max_tokens (int): Maximum tokens in response
221
+ temperature (float): Temperature for generation
222
+
223
+ Returns:
224
+ Optional[str]: Generated answer or None if failed
225
+ """
226
+ try:
227
  response = self.openai_client.chat.completions.create(
228
+ model=model,
229
  messages=[
230
+ {
231
+ "role": "system",
232
+ "content": "You are a helpful AI assistant specializing in Scikit-learn documentation. Provide accurate, helpful answers based only on the provided context."
233
+ },
234
+ {
235
+ "role": "user",
236
+ "content": prompt
237
+ }
238
  ],
239
+ max_tokens=max_tokens,
240
+ temperature=temperature,
241
+ top_p=0.9
242
  )
243
 
244
+ answer = response.choices[0].message.content.strip()
245
+ logger.info(f"Generated answer of length: {len(answer)}")
246
+ return answer
247
 
248
  except Exception as e:
249
+ logger.error(f"Error generating answer: {e}")
250
+ st.error(f"Error generating answer: {e}")
251
+ return None
252
+
253
+ def get_answer(
254
+ self,
255
+ user_question: str,
256
+ n_chunks: int = 3,
257
+ model: str = "gpt-3.5-turbo"
258
+ ) -> Tuple[Optional[str], List[str]]:
259
+ """
260
+ Complete RAG pipeline: retrieve, augment, generate.
261
+
262
+ Args:
263
+ user_question (str): User's question
264
+ n_chunks (int): Number of chunks to retrieve
265
+ model (str): OpenAI model to use
266
+
267
+ Returns:
268
+ Tuple[Optional[str], List[str]]: Generated answer and source URLs
269
+ """
270
+ if not self.openai_client:
271
+ st.error("OpenAI client not initialized. Please provide a valid API key.")
272
+ return None, []
273
+
274
+ # Step 1: Retrieve relevant chunks
275
+ with st.spinner("🔍 Searching relevant documentation..."):
276
+ retrieved_chunks = self.retrieve_relevant_chunks(user_question, n_chunks)
277
+
278
+ if not retrieved_chunks:
279
+ return "I couldn't find relevant information in the Scikit-learn documentation to answer your question. Please try rephrasing your question or ask about a different topic.", []
280
+
281
+ # Step 2: Create augmented prompt
282
+ with st.spinner("📝 Preparing context..."):
283
+ rag_prompt = self.create_rag_prompt(user_question, retrieved_chunks)
284
+
285
+ # Step 3: Generate answer
286
+ with st.spinner("🤖 Generating answer..."):
287
+ answer = self.generate_answer(rag_prompt, model)
288
+
289
+ # Extract source URLs
290
+ source_urls = [chunk['metadata'].get('url', 'Unknown') for chunk in retrieved_chunks]
291
+ source_urls = list(dict.fromkeys(source_urls)) # Remove duplicates while preserving order
292
+
293
+ return answer, source_urls
294
+
295
+
296
+ def initialize_session_state():
297
+ """Initialize Streamlit session state variables."""
298
+ if 'chatbot' not in st.session_state:
299
+ try:
300
+ st.session_state.chatbot = RAGChatbot()
301
+ except Exception as e:
302
+ st.error(f"Failed to initialize chatbot: {e}")
303
+ st.stop()
304
+
305
+ if 'openai_initialized' not in st.session_state:
306
+ st.session_state.openai_initialized = False
307
+
308
+ if 'chat_history' not in st.session_state:
309
+ st.session_state.chat_history = []
310
+
311
 
312
  def main():
313
+ """Main Streamlit application."""
314
+
315
+ # Page configuration
316
  st.set_page_config(
317
+ page_title="Scikit-learn Q&A Bot",
318
  page_icon="🤖",
319
+ layout="wide",
320
+ initial_sidebar_state="expanded"
321
  )
322
 
 
 
 
323
  # Initialize session state
324
+ initialize_session_state()
 
 
 
325
 
326
+ # Main title and description
327
+ st.title("🤖 Scikit-learn Documentation Q&A Bot")
328
+ st.markdown("""
329
+ Welcome to the **Scikit-learn Documentation Q&A Bot**! This intelligent assistant can answer your questions about Scikit-learn using the official documentation.
330
+
331
+ **How it works:**
332
+ 1. 🔍 **Retrieval**: Searches through 1,249+ documentation chunks
333
+ 2. 📝 **Augmentation**: Provides relevant context to the AI
334
+ 3. 🤖 **Generation**: Uses OpenAI to generate accurate answers
335
+ """)
336
 
337
+ # Sidebar for API key and settings
338
  with st.sidebar:
339
+ st.header("⚙️ Configuration")
340
 
341
+ # OpenAI API Key input
342
  api_key = st.text_input(
343
+ "🔑 OpenAI API Key",
344
  type="password",
345
+ placeholder="sk-...",
346
+ help="Enter your OpenAI API key to enable the chatbot"
347
  )
348
 
349
+ if api_key and not st.session_state.openai_initialized:
350
+ if st.session_state.chatbot.set_openai_client(api_key):
351
+ st.session_state.openai_initialized = True
352
+ st.success("✅ API key validated!")
353
+ st.rerun()
354
+
355
+ # Model selection
356
+ model = st.selectbox(
357
+ "🧠 AI Model",
358
+ ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo-preview"],
359
+ index=0,
360
+ help="Choose the OpenAI model for generating answers"
361
+ )
362
+
363
+ # Number of context chunks
364
+ n_chunks = st.slider(
365
+ "📄 Context Chunks",
366
+ min_value=1,
367
+ max_value=5,
368
+ value=3,
369
+ help="Number of relevant documentation chunks to use for context"
370
+ )
371
 
372
  st.markdown("---")
 
 
 
373
 
374
+ # Database info
375
+ st.header("📊 Database Info")
376
+ try:
377
+ collection_count = st.session_state.chatbot.collection.count()
378
+ st.metric("Total Documents", f"{collection_count:,}")
379
+ st.metric("Embedding Model", "all-MiniLM-L6-v2")
380
+ st.metric("Vector Dimensions", "384")
381
+ except:
382
+ st.error("Could not load database info")
383
+
384
+ st.markdown("---")
385
+
386
+ # Clear chat history
387
+ if st.button("🗑️ Clear Chat History"):
388
+ st.session_state.chat_history = []
389
+ st.rerun()
390
 
391
+ # Main chat interface
392
+ col1, col2 = st.columns([2, 1])
 
 
393
 
394
+ with col1:
395
+ st.header("💬 Ask Your Question")
396
+
397
+ # Question input
398
+ default_question = st.session_state.get('selected_question', '')
399
+ user_question = st.text_input(
400
+ "Enter your question about Scikit-learn:",
401
+ value=default_question,
402
+ placeholder="e.g., How do I perform cross-validation in scikit-learn?",
403
+ key="question_input"
404
+ )
405
+
406
+ # Clear selected question after using it
407
+ if 'selected_question' in st.session_state:
408
+ del st.session_state['selected_question']
409
+
410
+ # Submit button
411
+ submit_button = st.button("🚀 Get Answer", type="primary")
412
+
413
+ # Process question
414
+ if submit_button and user_question:
415
+ if not st.session_state.openai_initialized:
416
+ st.error("⚠️ Please enter a valid OpenAI API key in the sidebar first.")
417
+ else:
418
+ # Get answer using RAG
419
+ answer, sources = st.session_state.chatbot.get_answer(
420
+ user_question, n_chunks, model
421
+ )
422
 
423
+ if answer:
424
+ # Add to chat history
425
+ st.session_state.chat_history.append({
426
+ 'question': user_question,
427
+ 'answer': answer,
428
+ 'sources': sources
429
+ })
430
+
431
+ # Clear input
432
+ st.rerun()
433
+
434
+ # Display chat history
435
+ if st.session_state.chat_history:
436
+ st.header("📝 Chat History")
437
+
438
+ for i, chat in enumerate(reversed(st.session_state.chat_history)):
439
+ with st.expander(f"Q: {chat['question'][:60]}...", expanded=(i == 0)):
440
+ st.markdown(f"**Question:** {chat['question']}")
441
+ st.markdown(f"**Answer:**\n{chat['answer']}")
442
+
443
+ if chat['sources']:
444
+ st.markdown("**Sources:**")
445
+ for j, source in enumerate(chat['sources'], 1):
446
+ source_name = source.split('/')[-1] if '/' in source else source
447
+ st.markdown(f"{j}. [{source_name}]({source})")
448
+
449
+ with col2:
450
+ st.header("💡 Example Questions")
451
+
452
+ example_questions = [
453
+ "How do I perform cross-validation in scikit-learn?",
454
+ "What is the difference between Ridge and Lasso regression?",
455
+ "How do I use GridSearchCV for parameter tuning?",
456
+ "What clustering algorithms are available in scikit-learn?",
457
+ "How do I preprocess data using StandardScaler?",
458
+ "What is the difference between classification and regression?",
459
+ "How do I handle missing values in my dataset?",
460
+ "What is feature selection and how do I use it?",
461
+ "How do I visualize decision trees?",
462
+ "What is ensemble learning in scikit-learn?"
463
+ ]
464
+
465
+ for question in example_questions:
466
+ if st.button(question, key=f"example_{hash(question)}"):
467
+ # Use a different approach to set the question
468
+ st.session_state['selected_question'] = question
469
+ st.rerun()
470
+
471
+ st.markdown("---")
472
+
473
+ st.header("ℹ️ Tips")
474
+ st.markdown("""
475
+ **For best results:**
476
+ - Be specific in your questions
477
+ - Ask about scikit-learn functionality
478
+ - Include context when possible
479
+ - Check the sources for verification
480
+
481
+ **The bot can help with:**
482
+ - API usage and parameters
483
+ - Algorithm explanations
484
+ - Code examples
485
+ - Best practices
486
+ - Troubleshooting
487
+ """)
488
+
489
 
490
  if __name__ == "__main__":
491
  main()
app_hf.py CHANGED
@@ -1,309 +0,0 @@
1
"""
Scikit-learn RAG Chatbot - Hugging Face Spaces Optimized Version
A Retrieval-Augmented Generation chatbot for Scikit-learn documentation.
"""

import streamlit as st
import os
import json
import logging
from typing import List, Dict, Optional, Tuple
import warnings

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")
logging.getLogger("chromadb").setLevel(logging.ERROR)
logging.getLogger("sentence_transformers").setLevel(logging.ERROR)

# Try imports with error handling: the heavy RAG dependencies may be missing
# on a fresh Space, so record availability in a flag instead of crashing at
# import time.
try:
    import chromadb
    from sentence_transformers import SentenceTransformer
    import openai
    DEPENDENCIES_AVAILABLE = True
except ImportError as e:
    DEPENDENCIES_AVAILABLE = False
    # NOTE(review): this st.error call runs at module import time, i.e. before
    # main() calls st.set_page_config(); Streamlit expects set_page_config to
    # be the first Streamlit command — confirm this does not raise on Spaces.
    st.error(f"Missing dependencies: {e}")
27
-
28
class SimpleRAGChatbot:
    """Simplified RAG chatbot for HF Spaces deployment.

    Retrieval is backed by a persistent ChromaDB collection of documentation
    chunks; answers are generated with the OpenAI chat completions API using
    the retrieved chunks as context. The instance is created once (kept in
    Streamlit session state) and lazily set up via initialize().
    """

    def __init__(self):
        # All resources are created lazily in initialize() so construction is
        # cheap and never raises, even when optional dependencies are missing.
        self.client = None          # chromadb.PersistentClient, set on init
        self.collection = None      # active ChromaDB collection
        self.model = None           # SentenceTransformer embedding model
        self.openai_client = None   # cached openai.OpenAI client
        self._openai_key = None     # API key the cached client was built with
        self.initialized = False    # True once a collection is queryable

    def initialize(self):
        """Initialize the RAG system with error handling.

        Tries, in order: (1) loading an existing ChromaDB from disk,
        (2) rebuilding the database from chunks.json.

        Returns:
            bool: True if a document collection is ready for querying.
        """
        try:
            if not DEPENDENCIES_AVAILABLE:
                return False

            # Initialize embedding model.
            # NOTE(review): this model is not referenced by the query path —
            # collection.query(query_texts=...) uses ChromaDB's own embedding
            # function. Loading is kept for backward compatibility; confirm
            # before removing.
            self.model = SentenceTransformer('all-MiniLM-L6-v2')

            # Prefer an existing on-disk database...
            if self._load_existing_database():
                self.initialized = True
                return True

            # ...otherwise rebuild it from the raw chunks file.
            if self._rebuild_from_chunks():
                self.initialized = True
                return True

            return False

        except Exception as e:
            st.error(f"Initialization error: {str(e)}")
            return False

    def _load_existing_database(self) -> bool:
        """Try to attach to an existing ChromaDB at one of the known paths.

        Returns:
            bool: True if a database with at least one collection was found.
        """
        try:
            # Check multiple possible paths (local dev vs. Spaces layouts).
            db_paths = ['./chroma_db', './chroma', '.']

            for db_path in db_paths:
                try:
                    if os.path.exists(os.path.join(db_path, 'chroma.sqlite3')) or os.path.exists(db_path):
                        self.client = chromadb.PersistentClient(path=db_path)
                        collections = self.client.list_collections()

                        if collections:
                            self.collection = collections[0]  # Use first available collection
                            st.success(f"✅ Loaded database from {db_path} with {self.collection.count()} documents")
                            return True

                except Exception:
                    # This candidate path is unusable; try the next one.
                    continue

            return False

        except Exception as e:
            st.warning(f"Could not load existing database: {str(e)}")
            return False

    def _rebuild_from_chunks(self) -> bool:
        """Rebuild the vector database from chunks.json if available.

        Inserts chunks in batches while showing a Streamlit progress bar.

        Returns:
            bool: True if the database was rebuilt successfully.
        """
        try:
            chunks_file = 'chunks.json'
            if not os.path.exists(chunks_file):
                st.error("❌ No chunks.json file found. Please upload the required data files.")
                return False

            # Explicit encoding avoids platform-dependent decoding of the
            # documentation text.
            with open(chunks_file, 'r', encoding='utf-8') as f:
                chunks = json.load(f)

            if not chunks:
                st.error("❌ Chunks file is empty")
                return False

            # Create new database
            db_path = './chroma_db'
            os.makedirs(db_path, exist_ok=True)

            self.client = chromadb.PersistentClient(path=db_path)

            # get_or_create_collection replaces the previous bare
            # try/except get-then-create dance (bare except: hid real errors).
            collection_name = "sklearn_docs"
            self.collection = self.client.get_or_create_collection(collection_name)

            # Add chunks in batches to bound memory use and allow progress
            # reporting.
            batch_size = 100
            total_chunks = len(chunks)

            progress_bar = st.progress(0)
            status_text = st.empty()

            for i in range(0, total_chunks, batch_size):
                batch = chunks[i:i + batch_size]

                documents = [chunk['content'] for chunk in batch]
                metadatas = [{'source': chunk.get('source', 'unknown')} for chunk in batch]
                # IDs are globally unique: i is the batch offset, j the index
                # within the batch.
                ids = [f"chunk_{i + j}" for j in range(len(batch))]

                self.collection.add(
                    documents=documents,
                    metadatas=metadatas,
                    ids=ids
                )

                progress = (i + len(batch)) / total_chunks
                progress_bar.progress(progress)
                status_text.text(f"Processing chunks: {i + len(batch)}/{total_chunks}")

            progress_bar.empty()
            status_text.empty()

            st.success(f"✅ Successfully rebuilt database with {total_chunks} chunks")
            return True

        except Exception as e:
            st.error(f"Failed to rebuild database: {str(e)}")
            return False

    def search_documents(self, query: str, n_results: int = 5) -> List[Dict]:
        """Search the collection for documents relevant to *query*.

        Args:
            query: Free-text user question.
            n_results: Maximum number of chunks to return.

        Returns:
            List of {'content': str, 'source': str} dicts; empty on error or
            when the system is not initialized.
        """
        try:
            if not self.initialized or not self.collection:
                return []

            results = self.collection.query(
                query_texts=[query],
                n_results=n_results
            )

            documents = []
            # Chroma returns one result list per query text; we issued one.
            if results['documents'] and results['documents'][0]:
                for i, doc in enumerate(results['documents'][0]):
                    documents.append({
                        'content': doc,
                        'source': results['metadatas'][0][i].get('source', 'unknown') if results['metadatas'] else 'unknown'
                    })

            return documents

        except Exception as e:
            st.error(f"Search error: {str(e)}")
            return []

    def generate_response(self, query: str, context_docs: List[Dict]) -> str:
        """Generate an answer to *query* grounded in *context_docs*.

        Args:
            query: The user's question.
            context_docs: Documents from search_documents() to use as context.

        Returns:
            The model's answer, or a user-facing message when the API key is
            missing, no context was found, or the API call failed.
        """
        try:
            # Check for OpenAI API key (sidebar input wins over environment).
            api_key = st.session_state.get('openai_api_key') or os.getenv('OPENAI_API_KEY')

            if not api_key:
                return "⚠️ Please provide your OpenAI API key to generate responses."

            # Rebuild the client whenever the key changes — previously the
            # first client was cached forever, so updating the key in the
            # sidebar silently kept using the stale one.
            if self.openai_client is None or api_key != self._openai_key:
                self.openai_client = openai.OpenAI(api_key=api_key)
                self._openai_key = api_key

            # Prepare context
            context = "\n\n".join([f"Source: {doc['source']}\nContent: {doc['content']}"
                                   for doc in context_docs])

            if not context.strip():
                return "I couldn't find relevant information in the documentation. Please try rephrasing your question."

            # Create prompt
            prompt = f"""Based on the following Scikit-learn documentation, please answer the user's question accurately and helpfully.

Documentation Context:
{context}

User Question: {query}

Please provide a clear, accurate answer based on the documentation provided. If the documentation doesn't contain enough information to answer the question completely, please say so."""

            # Generate response; low temperature keeps answers close to the
            # provided documentation.
            response = self.openai_client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": "You are a helpful assistant that answers questions about Scikit-learn based on provided documentation."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=1000,
                temperature=0.3
            )

            return response.choices[0].message.content

        except Exception as e:
            return f"Error generating response: {str(e)}"
220
-
221
def main():
    """Main Streamlit application.

    Page flow (order matters for Streamlit):
      1. page config + title,
      2. one-time session-state / chatbot initialization,
      3. sidebar (API key + about),
      4. chat history rendering and the chat input loop.
    """
    st.set_page_config(
        page_title="Scikit-learn RAG Chatbot",
        page_icon="🤖",
        layout="wide"
    )

    st.title("🤖 Scikit-learn RAG Chatbot")
    st.markdown("Ask questions about Scikit-learn and get answers from the official documentation!")

    # Initialize session state — runs only on the first script execution of a
    # session; subsequent reruns reuse the same chatbot and message history.
    if 'chatbot' not in st.session_state:
        st.session_state.chatbot = SimpleRAGChatbot()
        st.session_state.messages = []
        st.session_state.initialized = False

    # Initialize the chatbot if not already done (e.g. after a failed attempt
    # on an earlier rerun).
    if not st.session_state.initialized:
        with st.spinner("Initializing RAG system..."):
            success = st.session_state.chatbot.initialize()
            st.session_state.initialized = success

            if not success:
                # st.stop() halts this script run; the user sees the error.
                st.error("❌ Failed to initialize the system. Please check the data files.")
                st.stop()

    # Sidebar for API key
    with st.sidebar:
        st.header("Configuration")

        api_key = st.text_input(
            "OpenAI API Key",
            type="password",
            value=st.session_state.get('openai_api_key', ''),
            help="Enter your OpenAI API key to enable response generation"
        )

        if api_key:
            # Persist the key so generate_response can read it on any rerun.
            st.session_state.openai_api_key = api_key
            st.success("✅ API key configured")

        st.markdown("---")
        st.markdown("### About")
        st.markdown("""
        This chatbot uses RAG (Retrieval-Augmented Generation) to answer questions about Scikit-learn.

        - **Data**: Official Scikit-learn documentation
        - **Embeddings**: all-MiniLM-L6-v2
        - **Vector DB**: ChromaDB
        - **LLM**: GPT-3.5-turbo
        """)

    # Chat interface
    st.header("💬 Chat")

    # Display chat messages accumulated so far (Streamlit reruns the whole
    # script on each interaction, so history must be re-rendered every time).
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Chat input — the walrus binds the submitted text only when the user
    # actually sent a message this rerun.
    if prompt := st.chat_input("Ask a question about Scikit-learn..."):
        # Add user message
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Generate response
        with st.chat_message("assistant"):
            with st.spinner("Searching documentation and generating response..."):
                # Search for relevant documents
                docs = st.session_state.chatbot.search_documents(prompt)

                if docs:
                    st.markdown("**Found relevant documentation:**")
                    # Show at most the top 3 sources, truncated to 500 chars.
                    for i, doc in enumerate(docs[:3], 1):
                        with st.expander(f"📄 Source {i}: {doc['source']}", expanded=False):
                            st.markdown(doc['content'][:500] + "..." if len(doc['content']) > 500 else doc['content'])

                # Generate response
                response = st.session_state.chatbot.generate_response(prompt, docs)
                st.markdown(response)

        # Add assistant message so it survives the next rerun.
        st.session_state.messages.append({"role": "assistant", "content": response})

if __name__ == "__main__":
    main()