#!/usr/bin/env python3
"""
Scikit-learn Documentation Q&A Bot

A Retrieval-Augmented Generation (RAG) chatbot built with Streamlit that answers
questions about Scikit-learn documentation using ChromaDB for retrieval and OpenAI
for generation.

Author: AI Assistant
Date: September 2025
"""

import os
import logging
from typing import List, Dict, Any, Optional, Tuple

import streamlit as st
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from openai import OpenAI

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# System message sent with every chat completion request.
_SYSTEM_MESSAGE = (
    "You are a helpful AI assistant specializing in Scikit-learn documentation. "
    "Provide accurate, helpful answers based only on the provided context."
)

# Template for the augmented user prompt. Filled with str.format(), which only
# parses the template itself, so braces inside the retrieved context or the
# user's question are inserted verbatim.
_RAG_PROMPT_TEMPLATE = """You are an expert AI assistant specializing in Scikit-learn, a popular Python machine learning library. Your task is to answer questions about Scikit-learn based ONLY on the provided context from the official documentation.

CONTEXT:
{context}

USER QUESTION: {user_question}

INSTRUCTIONS:
1. Answer the question using ONLY the information provided in the context above
2. Be accurate, helpful, and specific
3. If the context doesn't contain enough information to fully answer the question, say so clearly
4. Include relevant code examples if they appear in the context
5. Mention specific function names, class names, or parameter names when relevant
6. Structure your answer clearly with appropriate formatting

ANSWER:"""

# Answer returned when retrieval yields no usable chunks.
_NO_CONTEXT_ANSWER = (
    "I couldn't find relevant information in the Scikit-learn documentation to "
    "answer your question. Please try rephrasing your question or ask about a "
    "different topic."
)

# Clickable example questions shown in the right-hand column.
_EXAMPLE_QUESTIONS = [
    "How do I perform cross-validation in scikit-learn?",
    "What is the difference between Ridge and Lasso regression?",
    "How do I use GridSearchCV for parameter tuning?",
    "What clustering algorithms are available in scikit-learn?",
    "How do I preprocess data using StandardScaler?",
    "What is the difference between classification and regression?",
    "How do I handle missing values in my dataset?",
    "What is feature selection and how do I use it?",
    "How do I visualize decision trees?",
    "What is ensemble learning in scikit-learn?",
]


def _distance_to_relevance(distance: float) -> float:
    """
    Map a vector distance to a relevance score in (0, 1].

    A distance of 0 maps to 1.0 and larger distances approach 0, so a higher
    score always means a closer match. Negative distances (possible through
    floating-point error) are clamped to 0.

    Args:
        distance (float): Distance reported by the vector store

    Returns:
        float: Relevance score, 1 / (1 + distance)
    """
    return 1.0 / (1.0 + max(float(distance), 0.0))


def _first_query_row(results: Dict[str, Any], key: str, size: int) -> List[Any]:
    """
    Return the per-result values for the first (only) query in a query result.

    Args:
        results (Dict[str, Any]): Mapping returned by ``collection.query``
        key (str): Field to read, e.g. ``'metadatas'`` or ``'distances'``
        size (int): Number of results; used to build a placeholder row

    Returns:
        List[Any]: The values for the first query, or ``[None] * size`` when
        the field is absent or empty
    """
    rows = results.get(key)
    # Explicit None/len checks instead of truthiness so array-like rows
    # cannot raise on ambiguous boolean evaluation.
    if rows is None or len(rows) == 0 or rows[0] is None:
        return [None] * size
    return list(rows[0])


class RAGChatbot:
    """
    A Retrieval-Augmented Generation chatbot for Scikit-learn documentation.

    This class handles the complete RAG pipeline: retrieval from ChromaDB,
    augmentation with context, and generation using OpenAI's API.
    """

    def __init__(
        self,
        db_path: str = './chroma_db',
        collection_name: str = 'sklearn_docs',
        embedding_model_name: str = 'all-MiniLM-L6-v2'
    ):
        """
        Initialize the RAG chatbot.

        Args:
            db_path (str): Path to ChromaDB database
            collection_name (str): Name of the ChromaDB collection
            embedding_model_name (str): Name of the embedding model

        Raises:
            Exception: Re-raised from the retrieval-system initialization if
                the database, collection, or embedding model cannot be loaded
        """
        self.db_path = db_path
        self.collection_name = collection_name
        self.embedding_model_name = embedding_model_name

        # Initialize components
        self.chroma_client = None
        self.collection = None
        self.embedding_model = None
        self.openai_client = None

        # Initialize the retrieval system
        self._initialize_retrieval_system()

    def _initialize_retrieval_system(self) -> None:
        """
        Initialize ChromaDB client and embedding model for retrieval.

        Raises:
            Exception: Any failure is logged and re-raised to the caller
        """
        try:
            # Initialize ChromaDB client
            self.chroma_client = chromadb.PersistentClient(
                path=self.db_path,
                settings=Settings(anonymized_telemetry=False)
            )

            # Get collection
            self.collection = self.chroma_client.get_collection(
                name=self.collection_name
            )

            # Load embedding model (same as used for building the database)
            # NOTE(review): this model is loaded but queries below go through
            # `query_texts`, i.e. the collection's own embedding function.
            # Confirm which one matches how the database was built.
            self.embedding_model = SentenceTransformer(self.embedding_model_name)

            logger.info("RAG retrieval system initialized successfully")

        except Exception as e:
            logger.error("Failed to initialize retrieval system: %s", e)
            raise

    def set_openai_client(self, api_key: str) -> bool:
        """
        Initialize OpenAI client with API key.

        The client is stored on the instance only after a test request
        succeeds, so a failed validation never leaves an unusable client
        installed.

        Args:
            api_key (str): OpenAI API key

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            client = OpenAI(api_key=api_key)
            # Test the API key with a simple request
            client.models.list()
        except Exception as e:
            logger.error("Failed to initialize OpenAI client: %s", e)
            st.error(f"Invalid API key or OpenAI connection error: {e}")
            return False

        self.openai_client = client
        logger.info("OpenAI client initialized successfully")
        return True

    def retrieve_relevant_chunks(
        self,
        query: str,
        n_results: int = 3,
        min_relevance_score: float = 0.1
    ) -> List[Dict[str, Any]]:
        """
        Retrieve relevant text chunks from the vector database.

        The store reports a distance per result (lower means closer). It is
        converted to a relevance score with ``1 / (1 + distance)`` and chunks
        scoring below ``min_relevance_score`` are dropped. Chunks without a
        reported distance are kept.

        NOTE(review): the score mapping is metric-agnostic; confirm the
        collection's distance metric if a tighter threshold is needed.

        Args:
            query (str): User question/query
            n_results (int): Number of chunks to retrieve
            min_relevance_score (float): Minimum relevance score threshold

        Returns:
            List[Dict[str, Any]]: Retrieved chunks, each with ``'content'``,
            ``'metadata'`` (always a dict) and ``'distance'`` keys. Empty on
            error or when nothing passes the filter.
        """
        try:
            # Query the collection
            results = self.collection.query(
                query_texts=[query],
                n_results=n_results
            )

            retrieved_chunks = []

            documents = results.get('documents')
            if documents and documents[0]:
                contents = documents[0]
                metadatas = _first_query_row(results, 'metadatas', len(contents))
                distances = _first_query_row(results, 'distances', len(contents))

                for content, metadata, distance in zip(contents, metadatas, distances):
                    # A chunk stored without document text cannot be used as context.
                    if content is None:
                        continue
                    # Filter by relevance score if a distance is available
                    if (
                        distance is not None
                        and _distance_to_relevance(distance) < min_relevance_score
                    ):
                        continue
                    retrieved_chunks.append({
                        'content': content,
                        # Downstream code calls .get() on the metadata.
                        'metadata': metadata or {},
                        'distance': distance
                    })

            logger.info(
                "Retrieved %d relevant chunks for query: %s...",
                len(retrieved_chunks), query[:50]
            )
            return retrieved_chunks

        except Exception as e:
            logger.error("Error retrieving chunks: %s", e)
            st.error(f"Error during retrieval: {e}")
            return []

    def create_rag_prompt(
        self,
        user_question: str,
        retrieved_chunks: List[Dict[str, Any]]
    ) -> str:
        """
        Create an augmented prompt for OpenAI with retrieved context.

        Args:
            user_question (str): Original user question
            retrieved_chunks (List[Dict[str, Any]]): Retrieved relevant chunks

        Returns:
            str: Augmented prompt for OpenAI
        """
        # Build context from retrieved chunks
        context_parts = []
        for i, chunk in enumerate(retrieved_chunks, 1):
            url = (chunk.get('metadata') or {}).get('url', 'Unknown source')
            content = chunk['content'].strip()
            context_parts.append(f"--- Context {i} (Source: {url}) ---\n{content}\n")

        context = "\n".join(context_parts)

        # Create the RAG prompt
        return _RAG_PROMPT_TEMPLATE.format(
            context=context,
            user_question=user_question
        )

    def generate_answer(
        self,
        prompt: str,
        model: str = "gpt-3.5-turbo",
        max_tokens: int = 1000,
        temperature: float = 0.1
    ) -> Optional[str]:
        """
        Generate answer using OpenAI API.

        Args:
            prompt (str): Augmented prompt with context
            model (str): OpenAI model to use
            max_tokens (int): Maximum tokens in response
            temperature (float): Temperature for generation

        Returns:
            Optional[str]: Generated answer or None if failed
        """
        try:
            response = self.openai_client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": _SYSTEM_MESSAGE},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9
            )

            content = response.choices[0].message.content
            # The message content may be None; report that explicitly rather
            # than failing on `.strip()`.
            if content is None:
                logger.error("Error generating answer: model returned no content")
                st.error("Error generating answer: the model returned an empty response")
                return None

            answer = content.strip()
            logger.info("Generated answer of length: %d", len(answer))
            return answer

        except Exception as e:
            logger.error("Error generating answer: %s", e)
            st.error(f"Error generating answer: {e}")
            return None

    def get_answer(
        self,
        user_question: str,
        n_chunks: int = 3,
        model: str = "gpt-3.5-turbo"
    ) -> Tuple[Optional[str], List[str]]:
        """
        Complete RAG pipeline: retrieve, augment, generate.

        Args:
            user_question (str): User's question
            n_chunks (int): Number of chunks to retrieve
            model (str): OpenAI model to use

        Returns:
            Tuple[Optional[str], List[str]]: Generated answer and source URLs.
            The answer is None when the client is missing or generation fails.
        """
        if not self.openai_client:
            st.error("OpenAI client not initialized. Please provide a valid API key.")
            return None, []

        # Use a single spinner for the entire process to prevent flickering
        with st.spinner("🤖 Generating answer..."):
            # Step 1: Retrieve relevant chunks
            retrieved_chunks = self.retrieve_relevant_chunks(user_question, n_chunks)

            if not retrieved_chunks:
                return _NO_CONTEXT_ANSWER, []

            # Step 2: Create augmented prompt
            rag_prompt = self.create_rag_prompt(user_question, retrieved_chunks)

            # Step 3: Generate answer
            answer = self.generate_answer(rag_prompt, model)

            # Extract source URLs
            source_urls = [
                (chunk.get('metadata') or {}).get('url', 'Unknown')
                for chunk in retrieved_chunks
            ]
            # Remove duplicates while preserving order
            source_urls = list(dict.fromkeys(source_urls))

            return answer, source_urls


def initialize_session_state():
    """Initialize Streamlit session state variables."""
    if 'chatbot' not in st.session_state:
        try:
            st.session_state.chatbot = RAGChatbot()
        except Exception as e:
            st.error(f"Failed to initialize chatbot: {e}")
            st.stop()

    if 'openai_initialized' not in st.session_state:
        st.session_state.openai_initialized = False

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []


def _render_sidebar() -> Tuple[str, int]:
    """
    Render the configuration sidebar.

    Returns:
        Tuple[str, int]: Selected OpenAI model name and number of context chunks
    """
    with st.sidebar:
        st.header("⚙️ Configuration")

        # OpenAI API Key input
        api_key = st.text_input(
            "🔑 OpenAI API Key",
            type="password",
            placeholder="sk-...",
            help="Enter your OpenAI API key to enable the chatbot"
        )

        if api_key and not st.session_state.openai_initialized:
            if st.session_state.chatbot.set_openai_client(api_key):
                st.session_state.openai_initialized = True
                st.success("✅ API key validated!")
                st.rerun()

        # Model selection
        model = st.selectbox(
            "🧠 AI Model",
            ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo-preview"],
            index=0,
            help="Choose the OpenAI model for generating answers"
        )

        # Number of context chunks
        n_chunks = st.slider(
            "📄 Context Chunks",
            min_value=1,
            max_value=5,
            value=3,
            help="Number of relevant documentation chunks to use for context"
        )

        st.markdown("---")

        # Database info
        st.header("📊 Database Info")
        try:
            collection_count = st.session_state.chatbot.collection.count()
            st.metric("Total Documents", f"{collection_count:,}")
            st.metric("Embedding Model", "all-MiniLM-L6-v2")
            st.metric("Vector Dimensions", "384")
        except Exception as e:
            # `except Exception` (not bare `except`) so interpreter-exit and
            # Streamlit control-flow signals are not swallowed.
            logger.warning("Could not load database info: %s", e)
            st.error("Could not load database info")

        st.markdown("---")

        # Clear chat history
        if st.button("🗑️ Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()

    return model, n_chunks


def _render_chat_panel(model: str, n_chunks: int) -> None:
    """
    Render the question input, run the RAG pipeline on submit, and show history.

    Args:
        model (str): OpenAI model to use for generation
        n_chunks (int): Number of context chunks to retrieve
    """
    st.header("💬 Ask Your Question")

    # Initialize question input key if it doesn't exist
    if 'question_input' not in st.session_state:
        st.session_state.question_input = ''

    # Handle selected question from examples. This must run before the
    # text_input below is created, since it writes to that widget's key.
    if 'selected_question' in st.session_state:
        st.session_state.question_input = st.session_state['selected_question']
        # Clear after setting to prevent re-triggering
        del st.session_state['selected_question']

    user_question = st.text_input(
        "Enter your question about Scikit-learn:",
        placeholder="e.g., How do I perform cross-validation in scikit-learn?",
        key="question_input"
    )

    # Submit button
    submit_button = st.button("🚀 Get Answer", type="primary")

    # Process question
    if submit_button and user_question:
        if not st.session_state.openai_initialized:
            st.error("⚠️ Please enter a valid OpenAI API key in the sidebar first.")
        else:
            # Get answer using RAG
            answer, sources = st.session_state.chatbot.get_answer(
                user_question, n_chunks, model
            )

            if answer:
                # Add to chat history
                st.session_state.chat_history.append({
                    'question': user_question,
                    'answer': answer,
                    'sources': sources
                })

                # Success message
                st.success("✅ Answer generated successfully! Check the chat history below.")

    # Display chat history, newest first, with the newest entry expanded
    if st.session_state.chat_history:
        st.header("📝 Chat History")

        for i, chat in enumerate(reversed(st.session_state.chat_history)):
            with st.expander(f"Q: {chat['question'][:60]}...", expanded=(i == 0)):
                st.markdown(f"**Question:** {chat['question']}")
                st.markdown(f"**Answer:**\n{chat['answer']}")

                if chat['sources']:
                    st.markdown("**Sources:**")
                    for j, source in enumerate(chat['sources'], 1):
                        source_name = source.split('/')[-1] if '/' in source else source
                        st.markdown(f"{j}. [{source_name}]({source})")


def _render_help_panel() -> None:
    """Render the example-question buttons and the usage tips."""
    st.header("💡 Example Questions")

    for idx, question in enumerate(_EXAMPLE_QUESTIONS):
        # Index-based keys are stable and unique regardless of string hashing.
        if st.button(question, key=f"example_{idx}"):
            # Set the question and rerun only once
            st.session_state['selected_question'] = question
            st.rerun()

    st.markdown("---")

    st.header("ℹ️ Tips")
    st.markdown("""
    **For best results:**
    - Be specific in your questions
    - Ask about scikit-learn functionality
    - Include context when possible
    - Check the sources for verification

    **The bot can help with:**
    - API usage and parameters
    - Algorithm explanations
    - Code examples
    - Best practices
    - Troubleshooting
    """)


def main():
    """Main Streamlit application."""
    # Page configuration
    st.set_page_config(
        page_title="Scikit-learn Q&A Bot",
        page_icon="🤖",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    # Initialize session state
    initialize_session_state()

    # Main title and description
    st.title("🤖 Scikit-learn Documentation Q&A Bot")
    st.markdown("""
    Welcome to the **Scikit-learn Documentation Q&A Bot**! This intelligent assistant can answer your questions
    about Scikit-learn using the official documentation.

    **How it works:**
    1. 🔍 **Retrieval**: Searches through 1,249+ documentation chunks
    2. 📝 **Augmentation**: Provides relevant context to the AI
    3. 🤖 **Generation**: Uses OpenAI to generate accurate answers
    """)

    # Sidebar for API key and settings
    model, n_chunks = _render_sidebar()

    # Main chat interface
    col1, col2 = st.columns([2, 1])

    with col1:
        _render_chat_panel(model, n_chunks)

    with col2:
        _render_help_panel()


if __name__ == "__main__":
    # Run the main application
    # Note: For deployment environments like HuggingFace Spaces,
    # Streamlit warnings about missing ScriptRunContext can be safely ignored
    main()