# scikit-rag / app.py (author: fguryel)
# Last commit 6c1cd9e: "Fix: Remove session state modification after widget instantiation"
#!/usr/bin/env python3
"""
Scikit-learn Documentation Q&A Bot
A Retrieval-Augmented Generation (RAG) chatbot built with Streamlit
that answers questions about Scikit-learn documentation using ChromaDB
for retrieval and OpenAI for generation.
Author: AI Assistant
Date: September 2025
"""
import os
import logging
from typing import List, Dict, Any, Optional, Tuple
import streamlit as st
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from openai import OpenAI
# Module-wide logging setup: route INFO-and-above records through the root
# logger, and expose a module-scoped logger for the rest of this file.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class RAGChatbot:
    """
    A Retrieval-Augmented Generation chatbot for Scikit-learn documentation.

    This class handles the complete RAG pipeline: retrieval from ChromaDB,
    augmentation with context, and generation using OpenAI's API.
    """

    def __init__(
        self,
        db_path: str = './chroma_db',
        collection_name: str = 'sklearn_docs',
        embedding_model_name: str = 'all-MiniLM-L6-v2'
    ):
        """
        Initialize the RAG chatbot and its retrieval system.

        Args:
            db_path (str): Path to ChromaDB database
            collection_name (str): Name of the ChromaDB collection
            embedding_model_name (str): Name of the SentenceTransformer model
                used to embed queries; it should be the model the collection
                was built with.

        Raises:
            Exception: Re-raised from ``_initialize_retrieval_system`` when the
                database, collection, or embedding model cannot be loaded.
        """
        self.db_path = db_path
        self.collection_name = collection_name
        self.embedding_model_name = embedding_model_name
        # Components are created lazily below / by set_openai_client
        self.chroma_client = None
        self.collection = None
        self.embedding_model = None
        self.openai_client = None
        # Initialize the retrieval system
        self._initialize_retrieval_system()

    def _initialize_retrieval_system(self) -> None:
        """
        Open the persistent ChromaDB collection and load the embedding model.

        Raises:
            Exception: Any error from ChromaDB or SentenceTransformer is logged
                (with traceback) and re-raised.
        """
        try:
            self.chroma_client = chromadb.PersistentClient(
                path=self.db_path,
                settings=Settings(anonymized_telemetry=False)
            )
            self.collection = self.chroma_client.get_collection(
                name=self.collection_name
            )
            # Load embedding model (same as used for building the database);
            # it is used to embed every query in _query_collection.
            self.embedding_model = SentenceTransformer(self.embedding_model_name)
            logger.info("RAG retrieval system initialized successfully")
        except Exception as e:
            logger.exception(f"Failed to initialize retrieval system: {e}")
            raise

    def set_openai_client(self, api_key: str) -> bool:
        """
        Create an OpenAI client for ``api_key`` and validate it.

        The client is only stored on ``self.openai_client`` after the
        validation request succeeds, so an invalid key never replaces a
        previously working client.

        Args:
            api_key (str): OpenAI API key

        Returns:
            bool: True if the key was validated, False otherwise (an error is
            also shown in the Streamlit UI).
        """
        try:
            client = OpenAI(api_key=api_key)
            # Test the API key with a simple authenticated request
            client.models.list()
        except Exception as e:
            logger.error(f"Failed to initialize OpenAI client: {e}")
            st.error(f"Invalid API key or OpenAI connection error: {e}")
            return False
        self.openai_client = client
        logger.info("OpenAI client initialized successfully")
        return True

    def _query_collection(self, query: str, n_results: int) -> Dict[str, Any]:
        """Run the nearest-neighbour query, embedding ``query`` with self.embedding_model."""
        if self.embedding_model is not None:
            # Embed with the configured model so `embedding_model_name`
            # actually governs retrieval (query_texts would instead use the
            # collection's own default embedding function).
            query_embeddings = self.embedding_model.encode([query]).tolist()
            return self.collection.query(
                query_embeddings=query_embeddings,
                n_results=n_results
            )
        return self.collection.query(
            query_texts=[query],
            n_results=n_results
        )

    @staticmethod
    def _parse_query_results(
        results: Dict[str, Any],
        min_relevance_score: float
    ) -> List[Dict[str, Any]]:
        """
        Turn a single-query ChromaDB result into chunk dicts, filtered by relevance.

        ChromaDB reports *distances* (smaller = closer). Each distance is
        mapped to a relevance score ``1 / (1 + max(distance, 0))``, which lies
        in (0, 1] and grows as the match gets closer. A chunk is kept when its
        score is >= ``min_relevance_score``, or when no distance is available.

        Args:
            results (Dict[str, Any]): Result of ``collection.query`` for one query
            min_relevance_score (float): Minimum relevance score to keep a chunk

        Returns:
            List[Dict[str, Any]]: Dicts with 'content', 'metadata' (never None),
            'distance' and 'relevance' keys, in result order.
        """
        # Any of these fields may be missing/None depending on `include`.
        documents = (results.get('documents') or [[]])[0] or []
        metadatas = (results.get('metadatas') or [[]])[0] or []
        distances = (results.get('distances') or [[]])[0] or []

        chunks = []
        for i, content in enumerate(documents):
            metadata = metadatas[i] if i < len(metadatas) else None
            distance = distances[i] if i < len(distances) else None
            relevance = None
            if distance is not None:
                # Clamp tiny negative values (floating error) before inverting.
                relevance = 1.0 / (1.0 + max(distance, 0.0))
                if relevance < min_relevance_score:
                    continue
            chunks.append({
                'content': content,
                'metadata': metadata or {},
                'distance': distance,
                'relevance': relevance,
            })
        return chunks

    def retrieve_relevant_chunks(
        self,
        query: str,
        n_results: int = 3,
        min_relevance_score: float = 0.1
    ) -> List[Dict[str, Any]]:
        """
        Retrieve relevant text chunks from the vector database.

        Args:
            query (str): User question/query
            n_results (int): Number of chunks to retrieve
            min_relevance_score (float): Minimum relevance score in (0, 1],
                where relevance = 1 / (1 + distance); the default 0.1 keeps
                chunks with distance <= 9.

        Returns:
            List[Dict[str, Any]]: Retrieved chunks with content and metadata;
            empty on error (the error is also shown in the UI).
        """
        try:
            results = self._query_collection(query, n_results)
            retrieved_chunks = self._parse_query_results(results, min_relevance_score)
            logger.info(f"Retrieved {len(retrieved_chunks)} relevant chunks for query: {query[:50]}...")
            return retrieved_chunks
        except Exception as e:
            logger.error(f"Error retrieving chunks: {e}")
            st.error(f"Error during retrieval: {e}")
            return []

    def create_rag_prompt(
        self,
        user_question: str,
        retrieved_chunks: List[Dict[str, Any]]
    ) -> str:
        """
        Create an augmented prompt for OpenAI with retrieved context.

        Args:
            user_question (str): Original user question
            retrieved_chunks (List[Dict[str, Any]]): Retrieved relevant chunks;
                a missing/None 'metadata' is treated as having no URL.

        Returns:
            str: Augmented prompt for OpenAI
        """
        context_parts = []
        for i, chunk in enumerate(retrieved_chunks, 1):
            url = (chunk.get('metadata') or {}).get('url', 'Unknown source')
            content = chunk['content'].strip()
            context_parts.append(f"--- Context {i} (Source: {url}) ---\n{content}\n")
        context = "\n".join(context_parts)
        # The template's continuation lines are intentionally unindented:
        # any leading whitespace would become part of the prompt text.
        rag_prompt = f"""You are an expert AI assistant specializing in Scikit-learn, a popular Python machine learning library. Your task is to answer questions about Scikit-learn based ONLY on the provided context from the official documentation.
CONTEXT:
{context}
USER QUESTION:
{user_question}
INSTRUCTIONS:
1. Answer the question using ONLY the information provided in the context above
2. Be accurate, helpful, and specific
3. If the context doesn't contain enough information to fully answer the question, say so clearly
4. Include relevant code examples if they appear in the context
5. Mention specific function names, class names, or parameter names when relevant
6. Structure your answer clearly with appropriate formatting
ANSWER:"""
        return rag_prompt

    def generate_answer(
        self,
        prompt: str,
        model: str = "gpt-3.5-turbo",
        max_tokens: int = 1000,
        temperature: float = 0.1
    ) -> Optional[str]:
        """
        Generate answer using OpenAI API.

        Args:
            prompt (str): Augmented prompt with context
            model (str): OpenAI model to use
            max_tokens (int): Maximum tokens in response
            temperature (float): Temperature for generation

        Returns:
            Optional[str]: Generated answer, or None if the request failed or
            the model returned no text (an error is shown in the UI).
        """
        try:
            response = self.openai_client.chat.completions.create(
                model=model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful AI assistant specializing in Scikit-learn documentation. Provide accurate, helpful answers based only on the provided context."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9
            )
            content = response.choices[0].message.content
        except Exception as e:
            logger.error(f"Error generating answer: {e}")
            st.error(f"Error generating answer: {e}")
            return None
        # message.content may be None (the SDK types it as Optional[str]);
        # report that explicitly instead of failing on .strip().
        answer = (content or "").strip()
        if not answer:
            logger.warning("OpenAI returned an empty answer")
            st.error("Error generating answer: the model returned an empty response.")
            return None
        logger.info(f"Generated answer of length: {len(answer)}")
        return answer

    def get_answer(
        self,
        user_question: str,
        n_chunks: int = 3,
        model: str = "gpt-3.5-turbo"
    ) -> Tuple[Optional[str], List[str]]:
        """
        Complete RAG pipeline: retrieve, augment, generate.

        Args:
            user_question (str): User's question
            n_chunks (int): Number of chunks to retrieve
            model (str): OpenAI model to use

        Returns:
            Tuple[Optional[str], List[str]]: Generated answer (None on failure)
            and de-duplicated source URLs in retrieval order.
        """
        if not self.openai_client:
            st.error("OpenAI client not initialized. Please provide a valid API key.")
            return None, []
        # Use a single spinner for the entire process to prevent flickering
        with st.spinner("πŸ€– Generating answer..."):
            retrieved_chunks = self.retrieve_relevant_chunks(user_question, n_chunks)
            if not retrieved_chunks:
                return "I couldn't find relevant information in the Scikit-learn documentation to answer your question. Please try rephrasing your question or ask about a different topic.", []
            rag_prompt = self.create_rag_prompt(user_question, retrieved_chunks)
            answer = self.generate_answer(rag_prompt, model)
        source_urls = [
            (chunk.get('metadata') or {}).get('url', 'Unknown')
            for chunk in retrieved_chunks
        ]
        # Remove duplicates while preserving order
        source_urls = list(dict.fromkeys(source_urls))
        return answer, source_urls
def initialize_session_state():
    """
    Seed st.session_state with everything the app expects to find there.

    The RAGChatbot is built once per session; if construction fails, the
    error is displayed and the script run is halted. Afterwards the
    OpenAI-validated flag (False) and an empty chat history are added when
    absent.
    """
    state = st.session_state
    if 'chatbot' not in state:
        try:
            state.chatbot = RAGChatbot()
        except Exception as exc:
            st.error(f"Failed to initialize chatbot: {exc}")
            st.stop()
    # Each factory is called per missing key, so every session gets its own
    # fresh value (bool() -> False, list() -> []).
    for key, factory in (('openai_initialized', bool), ('chat_history', list)):
        if key not in state:
            state[key] = factory()
def main():
    """
    Main Streamlit application.

    Configures the page and session state, renders the title/intro, the
    configuration sidebar, and a two-column body: question input plus chat
    history on the left, example questions and tips on the right.
    """
    # Page configuration
    st.set_page_config(
        page_title="Scikit-learn Q&A Bot",
        page_icon="πŸ€–",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    # Initialize session state
    initialize_session_state()
    # Main title and description
    st.title("πŸ€– Scikit-learn Documentation Q&A Bot")
    st.markdown("""
Welcome to the **Scikit-learn Documentation Q&A Bot**! This intelligent assistant can answer your questions about Scikit-learn using the official documentation.
**How it works:**
1. πŸ” **Retrieval**: Searches through 1,249+ documentation chunks
2. πŸ“ **Augmentation**: Provides relevant context to the AI
3. πŸ€– **Generation**: Uses OpenAI to generate accurate answers
""")
    # Sidebar for API key and settings
    model, n_chunks = _render_sidebar()
    # Main chat interface
    col1, col2 = st.columns([2, 1])
    with col1:
        _render_question_panel(model, n_chunks)
        _render_chat_history()
    with col2:
        _render_examples_panel()


def _render_sidebar() -> Tuple[str, int]:
    """
    Render the sidebar: API key, model/chunk settings, DB info, clear button.

    Returns:
        Tuple[str, int]: Selected OpenAI model name and number of context chunks.
    """
    with st.sidebar:
        st.header("βš™οΈ Configuration")
        # OpenAI API Key input
        api_key = st.text_input(
            "πŸ”‘ OpenAI API Key",
            type="password",
            placeholder="sk-...",
            help="Enter your OpenAI API key to enable the chatbot"
        )
        # Validate whenever the entered key differs from the last key that
        # passed validation, so editing the field later actually takes effect.
        if api_key and api_key != st.session_state.get('validated_api_key'):
            is_valid = st.session_state.chatbot.set_openai_client(api_key)
            st.session_state.openai_initialized = is_valid
            st.session_state.validated_api_key = api_key if is_valid else None
        # Shown on every run while the key is valid (no rerun erases it).
        if st.session_state.openai_initialized:
            st.success("βœ… API key validated!")
        # Model selection
        model = st.selectbox(
            "🧠 AI Model",
            ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo-preview"],
            index=0,
            help="Choose the OpenAI model for generating answers"
        )
        # Number of context chunks
        n_chunks = st.slider(
            "πŸ“„ Context Chunks",
            min_value=1,
            max_value=5,
            value=3,
            help="Number of relevant documentation chunks to use for context"
        )
        st.markdown("---")
        # Database info
        st.header("πŸ“Š Database Info")
        try:
            collection_count = st.session_state.chatbot.collection.count()
            st.metric("Total Documents", f"{collection_count:,}")
            st.metric("Embedding Model", "all-MiniLM-L6-v2")
            st.metric("Vector Dimensions", "384")
        except Exception:
            st.error("Could not load database info")
        st.markdown("---")
        # Clear chat history
        if st.button("πŸ—‘οΈ Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()
    return model, n_chunks


def _render_question_panel(model: str, n_chunks: int) -> None:
    """
    Render the question input, run the RAG pipeline on submit, and append
    successful answers to ``st.session_state.chat_history``.

    Args:
        model (str): OpenAI model name chosen in the sidebar
        n_chunks (int): Number of documentation chunks to retrieve
    """
    st.header("πŸ’¬ Ask Your Question")
    # Initialize question input key if it doesn't exist
    if 'question_input' not in st.session_state:
        st.session_state.question_input = ''
    # Example buttons store their text in 'selected_question' and rerun.
    # Copy it into the widget key here, before st.text_input is created:
    # writing a widget's key after the widget exists is rejected.
    if 'selected_question' in st.session_state:
        st.session_state.question_input = st.session_state['selected_question']
        # Clear after setting to prevent re-triggering
        del st.session_state['selected_question']
    user_question = st.text_input(
        "Enter your question about Scikit-learn:",
        placeholder="e.g., How do I perform cross-validation in scikit-learn?",
        key="question_input"
    )
    # Submit button
    submit_button = st.button("πŸš€ Get Answer", type="primary")
    # Ignore whitespace-only input instead of sending an empty query
    question = (user_question or "").strip()
    if not (submit_button and question):
        return
    if not st.session_state.openai_initialized:
        st.error("⚠️ Please enter a valid OpenAI API key in the sidebar first.")
        return
    # Get answer using RAG
    answer, sources = st.session_state.chatbot.get_answer(
        question, n_chunks, model
    )
    if answer:
        st.session_state.chat_history.append({
            'question': question,
            'answer': answer,
            'sources': sources
        })
        st.success("βœ… Answer generated successfully! Check the chat history below.")


def _render_chat_history() -> None:
    """Render past Q&A pairs newest-first, expanding only the newest one."""
    if not st.session_state.chat_history:
        return
    st.header("πŸ“ Chat History")
    for i, chat in enumerate(reversed(st.session_state.chat_history)):
        question = chat['question']
        # Only add an ellipsis when the label was actually truncated.
        label = question if len(question) <= 60 else f"{question[:60]}..."
        with st.expander(f"Q: {label}", expanded=(i == 0)):
            st.markdown(f"**Question:** {question}")
            st.markdown(f"**Answer:**\n{chat['answer']}")
            if chat['sources']:
                st.markdown("**Sources:**")
                for j, source in enumerate(chat['sources'], 1):
                    # Last non-empty path segment as link text; a trailing '/'
                    # would otherwise yield an empty (invisible) label.
                    source_name = source.rstrip('/').rsplit('/', 1)[-1] or source
                    st.markdown(f"{j}. [{source_name}]({source})")


def _render_examples_panel() -> None:
    """Render clickable example questions followed by usage tips."""
    st.header("πŸ’‘ Example Questions")
    example_questions = [
        "How do I perform cross-validation in scikit-learn?",
        "What is the difference between Ridge and Lasso regression?",
        "How do I use GridSearchCV for parameter tuning?",
        "What clustering algorithms are available in scikit-learn?",
        "How do I preprocess data using StandardScaler?",
        "What is the difference between classification and regression?",
        "How do I handle missing values in my dataset?",
        "What is feature selection and how do I use it?",
        "How do I visualize decision trees?",
        "What is ensemble learning in scikit-learn?"
    ]
    for question in example_questions:
        if st.button(question, key=f"example_{hash(question)}"):
            # Hand the question to the next run; the question panel copies
            # it into the input widget before that widget is created.
            st.session_state['selected_question'] = question
            st.rerun()
    st.markdown("---")
    st.header("ℹ️ Tips")
    st.markdown("""
**For best results:**
- Be specific in your questions
- Ask about scikit-learn functionality
- Include context when possible
- Check the sources for verification
**The bot can help with:**
- API usage and parameters
- Algorithm explanations
- Code examples
- Best practices
- Troubleshooting
""")
if __name__ == "__main__":
    # Script entry point: build and render the Streamlit app via main().
    # Note: For deployment environments like HuggingFace Spaces,
    # Streamlit warnings about missing ScriptRunContext can be safely ignored
    main()