#!/usr/bin/env python3
"""
Scikit-learn Documentation Q&A Bot
A Retrieval-Augmented Generation (RAG) chatbot built with Streamlit
that answers questions about Scikit-learn documentation using ChromaDB
for retrieval and OpenAI for generation.
Author: AI Assistant
Date: September 2025
"""
import os
import logging
from typing import List, Dict, Any, Optional, Tuple
import streamlit as st
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from openai import OpenAI
# Logging setup: root logger configured at INFO level, plus a logger named
# after this module that the rest of the file writes to.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class RAGChatbot:
    """
    A Retrieval-Augmented Generation chatbot for Scikit-learn documentation.

    This class handles the complete RAG pipeline: retrieval from ChromaDB,
    augmentation with context, and generation using OpenAI's API.
    """

    def __init__(
        self,
        db_path: str = './chroma_db',
        collection_name: str = 'sklearn_docs',
        embedding_model_name: str = 'all-MiniLM-L6-v2'
    ):
        """
        Initialize the RAG chatbot and its retrieval components.

        Args:
            db_path (str): Path to the persistent ChromaDB database.
            collection_name (str): Name of the ChromaDB collection to query.
            embedding_model_name (str): SentenceTransformer model used to embed
                queries; it should be the model the collection was built with.

        Raises:
            Exception: Whatever ChromaDB or SentenceTransformer raised while
                opening the database, the collection, or the model.
        """
        self.db_path = db_path
        self.collection_name = collection_name
        self.embedding_model_name = embedding_model_name
        # Populated by _initialize_retrieval_system() / set_openai_client().
        self.chroma_client = None
        self.collection = None
        self.embedding_model = None
        self.openai_client = None
        self._initialize_retrieval_system()

    def _initialize_retrieval_system(self) -> None:
        """
        Open the ChromaDB collection and load the query embedding model.

        Raises:
            Exception: Re-raised after logging if any component fails to load.
        """
        try:
            self.chroma_client = chromadb.PersistentClient(
                path=self.db_path,
                settings=Settings(anonymized_telemetry=False)
            )
            self.collection = self.chroma_client.get_collection(
                name=self.collection_name
            )
            # Load embedding model (same as used for building the database)
            self.embedding_model = SentenceTransformer(self.embedding_model_name)
            logger.info("RAG retrieval system initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize retrieval system: %s", e)
            raise

    def set_openai_client(self, api_key: str) -> bool:
        """
        Create an OpenAI client and verify the key with a cheap request.

        The client is stored on ``self.openai_client`` only when validation
        succeeds, so a rejected key never leaves an unusable client behind.

        Args:
            api_key (str): OpenAI API key.

        Returns:
            bool: True if the key was accepted, False otherwise (the error is
            logged and shown in the UI).
        """
        try:
            client = OpenAI(api_key=api_key)
            # Test the API key with a simple request
            client.models.list()
        except Exception as e:
            logger.error("Failed to initialize OpenAI client: %s", e)
            st.error(f"Invalid API key or OpenAI connection error: {e}")
            return False
        self.openai_client = client
        logger.info("OpenAI client initialized successfully")
        return True

    @staticmethod
    def _first_result_row(results: Dict[str, Any], key: str) -> List[Any]:
        """
        Return the per-hit values stored under *key* for the first query.

        ChromaDB results hold one inner list per query; a field may be None
        (e.g. when excluded from the result), in which case [] is returned.
        """
        rows = results.get(key)
        if not rows:
            return []
        return rows[0] or []

    def _distance_space(self) -> str:
        """
        Return the collection's distance function name, defaulting to 'l2'.

        Read from the ``hnsw:space`` collection metadata key.
        NOTE(review): newer ChromaDB versions may also expose this through the
        collection configuration — confirm if the collection sets it there.
        """
        metadata = getattr(self.collection, 'metadata', None) or {}
        return str(metadata.get('hnsw:space', 'l2')).lower()

    @staticmethod
    def _distance_to_relevance(distance: float, space: str = 'l2') -> float:
        """
        Convert a ChromaDB distance (lower = closer) to a relevance score
        (higher = closer).

        - 'cosine': d = 1 - cosine similarity  -> relevance = 1 - d
        - 'ip':     d = 1 - dot product        -> relevance = 1 - d
        - 'l2':     d = squared L2 distance; for unit-length vectors
                    d = 2 - 2*cos, so relevance = 1 - d / 2.
        The 'l2' case assumes unit-normalized embeddings — TODO confirm for
        the model the collection was built with.
        """
        if space == 'l2':
            return 1.0 - distance / 2.0
        return 1.0 - distance

    def retrieve_relevant_chunks(
        self,
        query: str,
        n_results: int = 3,
        min_relevance_score: float = 0.1
    ) -> List[Dict[str, Any]]:
        """
        Retrieve relevant text chunks from the vector database.

        The query is embedded with ``self.embedding_model`` and matched against
        the collection. Each hit's distance is converted to a relevance score
        (see ``_distance_to_relevance``); hits scoring below
        ``min_relevance_score`` are dropped.

        Args:
            query (str): User question/query.
            n_results (int): Number of chunks to request from ChromaDB.
            min_relevance_score (float): Minimum relevance a chunk must reach.

        Returns:
            List[Dict[str, Any]]: Kept chunks with keys ``content``,
            ``metadata`` (a dict, never None) and ``distance`` (raw ChromaDB
            distance, or None when not returned). Empty on any error, which is
            logged and shown in the UI.
        """
        try:
            # Embed with the model loaded for this collection instead of the
            # collection's default embedding function.
            query_embedding = self.embedding_model.encode(query).tolist()
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results
            )
            documents = self._first_result_row(results, 'documents')
            metadatas = self._first_result_row(results, 'metadatas')
            distances = self._first_result_row(results, 'distances')
            space = self._distance_space()

            retrieved_chunks = []
            for i, content in enumerate(documents):
                if content is None:
                    # No text to put into the prompt for this hit.
                    continue
                metadata = metadatas[i] if i < len(metadatas) else None
                distance = distances[i] if i < len(distances) else None
                # Distances are "lower is better": compare the derived
                # relevance (higher is better) against the minimum.
                if (distance is not None
                        and self._distance_to_relevance(distance, space) < min_relevance_score):
                    continue
                retrieved_chunks.append({
                    'content': content,
                    'metadata': metadata or {},
                    'distance': distance,
                })
            logger.info(
                "Retrieved %d relevant chunks for query: %.50s...",
                len(retrieved_chunks), query
            )
            return retrieved_chunks
        except Exception as e:
            logger.error("Error retrieving chunks: %s", e)
            st.error(f"Error during retrieval: {e}")
            return []

    def create_rag_prompt(
        self,
        user_question: str,
        retrieved_chunks: List[Dict[str, Any]]
    ) -> str:
        """
        Build the augmented prompt: numbered context sections followed by the
        question and answering instructions.

        Args:
            user_question (str): Original user question.
            retrieved_chunks (List[Dict[str, Any]]): Chunks with ``content``
                and ``metadata`` (a missing/None metadata or url is reported
                as 'Unknown source').

        Returns:
            str: Augmented prompt for OpenAI.
        """
        context_parts = []
        for i, chunk in enumerate(retrieved_chunks, 1):
            url = (chunk.get('metadata') or {}).get('url', 'Unknown source')
            content = (chunk.get('content') or '').strip()
            context_parts.append(f"--- Context {i} (Source: {url}) ---\n{content}\n")
        context = "\n".join(context_parts)
        rag_prompt = f"""You are an expert AI assistant specializing in Scikit-learn, a popular Python machine learning library. Your task is to answer questions about Scikit-learn based ONLY on the provided context from the official documentation.
CONTEXT:
{context}
USER QUESTION:
{user_question}
INSTRUCTIONS:
1. Answer the question using ONLY the information provided in the context above
2. Be accurate, helpful, and specific
3. If the context doesn't contain enough information to fully answer the question, say so clearly
4. Include relevant code examples if they appear in the context
5. Mention specific function names, class names, or parameter names when relevant
6. Structure your answer clearly with appropriate formatting
ANSWER:"""
        return rag_prompt

    def generate_answer(
        self,
        prompt: str,
        model: str = "gpt-3.5-turbo",
        max_tokens: int = 1000,
        temperature: float = 0.1
    ) -> Optional[str]:
        """
        Generate an answer with the OpenAI chat completions API.

        Args:
            prompt (str): Augmented prompt with context.
            model (str): OpenAI model to use.
            max_tokens (int): Maximum tokens in the response.
            temperature (float): Sampling temperature.

        Returns:
            Optional[str]: The stripped answer text, or None if the request
            failed (the error is logged and shown in the UI).
        """
        try:
            response = self.openai_client.chat.completions.create(
                model=model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful AI assistant specializing in Scikit-learn documentation. Provide accurate, helpful answers based only on the provided context."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9
            )
            answer = response.choices[0].message.content.strip()
            logger.info("Generated answer of length: %d", len(answer))
            return answer
        except Exception as e:
            logger.error("Error generating answer: %s", e)
            st.error(f"Error generating answer: {e}")
            return None

    def get_answer(
        self,
        user_question: str,
        n_chunks: int = 3,
        model: str = "gpt-3.5-turbo"
    ) -> Tuple[Optional[str], List[str]]:
        """
        Run the complete RAG pipeline: retrieve, augment, generate.

        Args:
            user_question (str): User's question.
            n_chunks (int): Number of chunks to retrieve.
            model (str): OpenAI model to use.

        Returns:
            Tuple[Optional[str], List[str]]: The answer (None on failure, or a
            fixed message when nothing relevant was retrieved) and the
            de-duplicated source URLs in retrieval order.
        """
        if not self.openai_client:
            st.error("OpenAI client not initialized. Please provide a valid API key.")
            return None, []
        # Use a single spinner for the entire process to prevent flickering
        with st.spinner("π€ Generating answer..."):
            retrieved_chunks = self.retrieve_relevant_chunks(user_question, n_chunks)
            if not retrieved_chunks:
                return "I couldn't find relevant information in the Scikit-learn documentation to answer your question. Please try rephrasing your question or ask about a different topic.", []
            rag_prompt = self.create_rag_prompt(user_question, retrieved_chunks)
            answer = self.generate_answer(rag_prompt, model)
        source_urls = [(chunk.get('metadata') or {}).get('url', 'Unknown') for chunk in retrieved_chunks]
        # dict.fromkeys removes duplicates while preserving order.
        source_urls = list(dict.fromkeys(source_urls))
        return answer, source_urls
def initialize_session_state():
    """Seed st.session_state with the objects the app relies on.

    The RAGChatbot is built once per session; if construction fails, the
    error is shown and st.stop() is called. The OpenAI flag and the chat
    history receive default values only when they are not already present.
    """
    state = st.session_state
    if 'chatbot' not in state:
        try:
            state.chatbot = RAGChatbot()
        except Exception as e:
            st.error(f"Failed to initialize chatbot: {e}")
            st.stop()
    # The tuple (and its empty list) is rebuilt on every call, so each
    # session gets its own history list.
    defaults = (('openai_initialized', False), ('chat_history', []))
    for key, value in defaults:
        if key not in state:
            state[key] = value
def _truncate_label(text: str, limit: int = 60) -> str:
    """Return *text* cut to *limit* characters.

    '...' is appended only when characters were actually removed, so short
    questions are not displayed as if truncated.
    """
    if len(text) <= limit:
        return text
    return f"{text[:limit]}..."


def _source_label(source: str) -> str:
    """Return short link text for a source URL: its last path segment.

    Trailing slashes are ignored ('.../modules/' -> 'modules'); if nothing is
    left, the full *source* string is used so the link text is never empty.
    """
    return source.rstrip('/').rsplit('/', 1)[-1] or source


def _render_sidebar() -> Tuple[str, int]:
    """Render the sidebar (API key, model, chunk count, database info, clear
    button) and return the selected ``(model, n_chunks)``."""
    chatbot = st.session_state.chatbot
    with st.sidebar:
        st.header("βοΈ Configuration")
        api_key = st.text_input(
            "π OpenAI API Key",
            type="password",
            placeholder="sk-...",
            help="Enter your OpenAI API key to enable the chatbot"
        )
        # Validate once per session; the flag skips re-validation on reruns.
        if api_key and not st.session_state.openai_initialized:
            if chatbot.set_openai_client(api_key):
                st.session_state.openai_initialized = True
                st.success("✅ API key validated!")
                st.rerun()
        model = st.selectbox(
            "π§ AI Model",
            ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo-preview"],
            index=0,
            help="Choose the OpenAI model for generating answers"
        )
        n_chunks = st.slider(
            "π Context Chunks",
            min_value=1,
            max_value=5,
            value=3,
            help="Number of relevant documentation chunks to use for context"
        )
        st.markdown("---")
        st.header("π Database Info")
        try:
            collection_count = chatbot.collection.count()
            st.metric("Total Documents", f"{collection_count:,}")
            # Report the model actually loaded instead of hard-coded values.
            st.metric("Embedding Model", chatbot.embedding_model_name)
            dimension = chatbot.embedding_model.get_sentence_embedding_dimension()
            st.metric("Vector Dimensions", str(dimension) if dimension is not None else "Unknown")
        except Exception as e:
            # Informational panel only: report the problem, keep the app usable.
            logger.warning("Could not load database info: %s", e)
            st.error("Could not load database info")
        st.markdown("---")
        if st.button("ποΈ Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()
    return model, n_chunks


def _render_question_panel(model: str, n_chunks: int) -> None:
    """Render the question input and, on submit, answer it via RAG and append
    the result to the chat history."""
    st.header("π¬ Ask Your Question")
    if 'question_input' not in st.session_state:
        st.session_state.question_input = ''
    # An example button stores its text under 'selected_question'; copy it
    # into the input's state before the widget is created, then drop it so it
    # is applied only once.
    if 'selected_question' in st.session_state:
        st.session_state.question_input = st.session_state['selected_question']
        del st.session_state['selected_question']
    user_question = st.text_input(
        "Enter your question about Scikit-learn:",
        placeholder="e.g., How do I perform cross-validation in scikit-learn?",
        key="question_input"
    )
    submit_button = st.button("π Get Answer", type="primary")
    # Whitespace-only input is treated as "no question".
    if not (submit_button and user_question.strip()):
        return
    if not st.session_state.openai_initialized:
        st.error("β οΈ Please enter a valid OpenAI API key in the sidebar first.")
        return
    answer, sources = st.session_state.chatbot.get_answer(
        user_question, n_chunks, model
    )
    if answer:
        st.session_state.chat_history.append({
            'question': user_question,
            'answer': answer,
            'sources': sources
        })
        st.success("✅ Answer generated successfully! Check the chat history below.")


def _render_chat_history() -> None:
    """Render past Q&A pairs newest first, with only the newest expanded."""
    history = st.session_state.chat_history
    if not history:
        return
    st.header("π Chat History")
    for i, chat in enumerate(reversed(history)):
        with st.expander(f"Q: {_truncate_label(chat['question'])}", expanded=(i == 0)):
            st.markdown(f"**Question:** {chat['question']}")
            st.markdown(f"**Answer:**\n{chat['answer']}")
            if chat['sources']:
                st.markdown("**Sources:**")
                for j, source in enumerate(chat['sources'], 1):
                    st.markdown(f"{j}. [{_source_label(source)}]({source})")


def _render_example_questions() -> None:
    """Render example-question buttons and the usage tips."""
    st.header("π‘ Example Questions")
    example_questions = [
        "How do I perform cross-validation in scikit-learn?",
        "What is the difference between Ridge and Lasso regression?",
        "How do I use GridSearchCV for parameter tuning?",
        "What clustering algorithms are available in scikit-learn?",
        "How do I preprocess data using StandardScaler?",
        "What is the difference between classification and regression?",
        "How do I handle missing values in my dataset?",
        "What is feature selection and how do I use it?",
        "How do I visualize decision trees?",
        "What is ensemble learning in scikit-learn?"
    ]
    # Index-based keys are stable across processes, unlike salted str hashes.
    for idx, question in enumerate(example_questions):
        if st.button(question, key=f"example_{idx}"):
            st.session_state['selected_question'] = question
            st.rerun()
    st.markdown("---")
    st.header("βΉοΈ Tips")
    st.markdown("""
**For best results:**
- Be specific in your questions
- Ask about scikit-learn functionality
- Include context when possible
- Check the sources for verification
**The bot can help with:**
- API usage and parameters
- Algorithm explanations
- Code examples
- Best practices
- Troubleshooting
""")


def main():
    """Main Streamlit application.

    Lays out the page: title and intro, the configuration sidebar, the
    question/answer column with chat history, and the example-questions
    column.

    NOTE(review): several UI labels (e.g. "π€", "βοΈ", "π") look like emoji
    that were mis-decoded upstream; restore them from the original source
    rather than guessing.
    """
    st.set_page_config(
        page_title="Scikit-learn Q&A Bot",
        page_icon="π€",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    initialize_session_state()
    st.title("π€ Scikit-learn Documentation Q&A Bot")
    st.markdown("""
Welcome to the **Scikit-learn Documentation Q&A Bot**! This intelligent assistant can answer your questions about Scikit-learn using the official documentation.
**How it works:**
1. π **Retrieval**: Searches through 1,249+ documentation chunks
2. π **Augmentation**: Provides relevant context to the AI
3. π€ **Generation**: Uses OpenAI to generate accurate answers
""")
    model, n_chunks = _render_sidebar()
    col1, col2 = st.columns([2, 1])
    with col1:
        _render_question_panel(model, n_chunks)
        _render_chat_history()
    with col2:
        _render_example_questions()
if __name__ == "__main__":
    # Run the main application only when executed as a script, so the module
    # can be imported without starting the UI.
    # Note: For deployment environments like HuggingFace Spaces,
    # Streamlit warnings about missing ScriptRunContext can be safely ignored
    main()