# NOTE: The "Spaces: Sleeping" banner below the original page header is
# Hugging Face Spaces UI residue from page extraction — it is not source code.
| #!/usr/bin/env python3 | |
| """ | |
| Scikit-learn Documentation Q&A Bot | |
| A Retrieval-Augmented Generation (RAG) chatbot built with Streamlit | |
| that answers questions about Scikit-learn documentation using ChromaDB | |
| for retrieval and OpenAI for generation. | |
| Author: AI Assistant | |
| Date: September 2025 | |
| """ | |
| import os | |
| import logging | |
| from typing import List, Dict, Any, Optional, Tuple | |
| import streamlit as st | |
| import chromadb | |
| from chromadb.config import Settings | |
| from sentence_transformers import SentenceTransformer | |
| from openai import OpenAI | |
# Configure logging for this module. basicConfig is a no-op if the hosting
# process has already installed logging handlers, so this is safe when the
# app is embedded elsewhere.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class RAGChatbot:
    """
    A Retrieval-Augmented Generation chatbot for Scikit-learn documentation.

    This class handles the complete RAG pipeline: retrieval from ChromaDB,
    augmentation with context, and generation using OpenAI's API.
    """

    def __init__(
        self,
        db_path: str = './chroma_db',
        collection_name: str = 'sklearn_docs',
        embedding_model_name: str = 'all-MiniLM-L6-v2'
    ):
        """
        Initialize the RAG chatbot.

        Args:
            db_path (str): Path to ChromaDB database
            collection_name (str): Name of the ChromaDB collection
            embedding_model_name (str): Name of the embedding model
        """
        self.db_path = db_path
        self.collection_name = collection_name
        self.embedding_model_name = embedding_model_name
        # Components are created lazily below / via set_openai_client.
        self.chroma_client = None
        self.collection = None
        self.embedding_model = None
        self.openai_client = None
        # Initialize the retrieval system (ChromaDB + embedding model).
        self._initialize_retrieval_system()

    def _initialize_retrieval_system(self) -> None:
        """
        Initialize ChromaDB client and embedding model for retrieval.

        Raises:
            Exception: Re-raised if the DB cannot be opened, the collection
                is missing, or the embedding model cannot be loaded — the
                chatbot is unusable without retrieval.
        """
        try:
            # Persistent on-disk ChromaDB client; telemetry disabled.
            self.chroma_client = chromadb.PersistentClient(
                path=self.db_path,
                settings=Settings(anonymized_telemetry=False)
            )
            # Fails fast if the collection was never built.
            self.collection = self.chroma_client.get_collection(
                name=self.collection_name
            )
            # Load embedding model (same as used for building the database);
            # queries MUST be embedded with this model so that query and
            # document vectors live in the same space.
            self.embedding_model = SentenceTransformer(self.embedding_model_name)
            logger.info("RAG retrieval system initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize retrieval system: {e}")
            raise

    def set_openai_client(self, api_key: str) -> bool:
        """
        Initialize OpenAI client with API key.

        Args:
            api_key (str): OpenAI API key

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            self.openai_client = OpenAI(api_key=api_key)
            # Cheap round-trip to validate the key before accepting it.
            self.openai_client.models.list()
            logger.info("OpenAI client initialized successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to initialize OpenAI client: {e}")
            st.error(f"Invalid API key or OpenAI connection error: {e}")
            return False

    def retrieve_relevant_chunks(
        self,
        query: str,
        n_results: int = 3,
        min_relevance_score: float = 0.1
    ) -> List[Dict[str, Any]]:
        """
        Retrieve relevant text chunks from the vector database.

        The query is embedded with the same SentenceTransformer model that was
        used to build the database. (Passing raw ``query_texts`` would make
        ChromaDB fall back to its own default embedding function, producing
        vectors from a different model than the indexed documents.)

        Args:
            query (str): User question/query
            n_results (int): Number of chunks to retrieve
            min_relevance_score (float): Minimum relevance threshold. Relevance
                is derived from distance as ``1 / (1 + distance)`` so it lies
                in (0, 1] for any non-negative metric (cosine or L2); smaller
                distance means higher relevance.

        Returns:
            List[Dict[str, Any]]: Retrieved chunks, each with 'content',
                'metadata', raw 'distance', and derived 'relevance'.
        """
        try:
            # Embed the query with the model used at index-build time.
            query_embedding = self.embedding_model.encode(query).tolist()
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results
            )
            retrieved_chunks = []
            if results['documents'] and results['documents'][0]:
                distances = results['distances'] if 'distances' in results else None
                for i, document in enumerate(results['documents'][0]):
                    distance = distances[0][i] if distances else None
                    # Bounded relevance score: lower distance => higher score.
                    relevance = None if distance is None else 1.0 / (1.0 + distance)
                    chunk_data = {
                        'content': document,
                        'metadata': results['metadatas'][0][i],
                        'distance': distance,
                        'relevance': relevance,
                    }
                    # Keep the chunk unless we can tell it is clearly
                    # irrelevant. (The previous `distance >= threshold` test
                    # was inverted: it discarded the *closest* matches.)
                    if relevance is None or relevance >= min_relevance_score:
                        retrieved_chunks.append(chunk_data)
            logger.info(f"Retrieved {len(retrieved_chunks)} relevant chunks for query: {query[:50]}...")
            return retrieved_chunks
        except Exception as e:
            logger.error(f"Error retrieving chunks: {e}")
            st.error(f"Error during retrieval: {e}")
            return []

    def create_rag_prompt(
        self,
        user_question: str,
        retrieved_chunks: List[Dict[str, Any]]
    ) -> str:
        """
        Create an augmented prompt for OpenAI with retrieved context.

        Args:
            user_question (str): Original user question
            retrieved_chunks (List[Dict[str, Any]]): Retrieved relevant chunks

        Returns:
            str: Augmented prompt for OpenAI
        """
        # Build a numbered context section; each chunk is labeled with its
        # source URL so the model can ground its answer.
        context_parts = []
        for i, chunk in enumerate(retrieved_chunks, 1):
            url = chunk['metadata'].get('url', 'Unknown source')
            content = chunk['content'].strip()
            context_part = f"--- Context {i} (Source: {url}) ---\n{content}\n"
            context_parts.append(context_part)
        context = "\n".join(context_parts)
        # Create the RAG prompt
        rag_prompt = f"""You are an expert AI assistant specializing in Scikit-learn, a popular Python machine learning library. Your task is to answer questions about Scikit-learn based ONLY on the provided context from the official documentation.
CONTEXT:
{context}
USER QUESTION:
{user_question}
INSTRUCTIONS:
1. Answer the question using ONLY the information provided in the context above
2. Be accurate, helpful, and specific
3. If the context doesn't contain enough information to fully answer the question, say so clearly
4. Include relevant code examples if they appear in the context
5. Mention specific function names, class names, or parameter names when relevant
6. Structure your answer clearly with appropriate formatting
ANSWER:"""
        return rag_prompt

    def generate_answer(
        self,
        prompt: str,
        model: str = "gpt-3.5-turbo",
        max_tokens: int = 1000,
        temperature: float = 0.1
    ) -> Optional[str]:
        """
        Generate answer using OpenAI API.

        Args:
            prompt (str): Augmented prompt with context
            model (str): OpenAI model to use
            max_tokens (int): Maximum tokens in response
            temperature (float): Temperature for generation

        Returns:
            Optional[str]: Generated answer or None if failed
        """
        try:
            response = self.openai_client.chat.completions.create(
                model=model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful AI assistant specializing in Scikit-learn documentation. Provide accurate, helpful answers based only on the provided context."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9
            )
            # message.content may be None (e.g. on a filtered completion);
            # guard it so .strip() cannot raise.
            content = response.choices[0].message.content
            answer = (content or "").strip()
            logger.info(f"Generated answer of length: {len(answer)}")
            return answer
        except Exception as e:
            logger.error(f"Error generating answer: {e}")
            st.error(f"Error generating answer: {e}")
            return None

    def get_answer(
        self,
        user_question: str,
        n_chunks: int = 3,
        model: str = "gpt-3.5-turbo"
    ) -> Tuple[Optional[str], List[str]]:
        """
        Complete RAG pipeline: retrieve, augment, generate.

        Args:
            user_question (str): User's question
            n_chunks (int): Number of chunks to retrieve
            model (str): OpenAI model to use

        Returns:
            Tuple[Optional[str], List[str]]: Generated answer and source URLs
        """
        if not self.openai_client:
            st.error("OpenAI client not initialized. Please provide a valid API key.")
            return None, []
        # Use a single spinner for the entire process to prevent flickering
        with st.spinner("π€ Generating answer..."):
            # Step 1: Retrieve relevant chunks
            retrieved_chunks = self.retrieve_relevant_chunks(user_question, n_chunks)
            if not retrieved_chunks:
                return "I couldn't find relevant information in the Scikit-learn documentation to answer your question. Please try rephrasing your question or ask about a different topic.", []
            # Step 2: Create augmented prompt
            rag_prompt = self.create_rag_prompt(user_question, retrieved_chunks)
            # Step 3: Generate answer
            answer = self.generate_answer(rag_prompt, model)
            # Extract source URLs
            source_urls = [chunk['metadata'].get('url', 'Unknown') for chunk in retrieved_chunks]
            source_urls = list(dict.fromkeys(source_urls))  # Remove duplicates while preserving order
            return answer, source_urls
def initialize_session_state():
    """Initialize Streamlit session state variables."""
    # The chatbot wraps ChromaDB plus the embedding model; constructing it
    # can fail (e.g. missing database), and the app cannot continue without it.
    if 'chatbot' not in st.session_state:
        try:
            st.session_state.chatbot = RAGChatbot()
        except Exception as e:
            st.error(f"Failed to initialize chatbot: {e}")
            st.stop()
    # Remaining state entries are plain defaults, set only when absent.
    for state_key, default_value in (('openai_initialized', False), ('chat_history', [])):
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value
def main():
    """Main Streamlit application: page setup, sidebar config, chat UI."""
    # Page configuration (must run before any other st.* call).
    st.set_page_config(
        page_title="Scikit-learn Q&A Bot",
        page_icon="π€",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    # Initialize session state (chatbot instance, flags, chat history).
    initialize_session_state()
    # Main title and description
    st.title("π€ Scikit-learn Documentation Q&A Bot")
    st.markdown("""
Welcome to the **Scikit-learn Documentation Q&A Bot**! This intelligent assistant can answer your questions about Scikit-learn using the official documentation.
**How it works:**
1. π **Retrieval**: Searches through 1,249+ documentation chunks
2. π **Augmentation**: Provides relevant context to the AI
3. π€ **Generation**: Uses OpenAI to generate accurate answers
""")
    # Sidebar for API key and settings
    with st.sidebar:
        st.header("βοΈ Configuration")
        # OpenAI API Key input
        api_key = st.text_input(
            "π OpenAI API Key",
            type="password",
            placeholder="sk-...",
            help="Enter your OpenAI API key to enable the chatbot"
        )
        # Validate the key only once; the flag persists across reruns.
        if api_key and not st.session_state.openai_initialized:
            if st.session_state.chatbot.set_openai_client(api_key):
                st.session_state.openai_initialized = True
                st.success("β API key validated!")
                st.rerun()
        # Model selection
        model = st.selectbox(
            "π§  AI Model",
            ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo-preview"],
            index=0,
            help="Choose the OpenAI model for generating answers"
        )
        # Number of context chunks
        n_chunks = st.slider(
            "π Context Chunks",
            min_value=1,
            max_value=5,
            value=3,
            help="Number of relevant documentation chunks to use for context"
        )
        st.markdown("---")
        # Database info
        st.header("π Database Info")
        try:
            collection_count = st.session_state.chatbot.collection.count()
            st.metric("Total Documents", f"{collection_count:,}")
            st.metric("Embedding Model", "all-MiniLM-L6-v2")
            st.metric("Vector Dimensions", "384")
        except Exception:
            # Narrowed from a bare `except:` — a bare clause would also
            # swallow Streamlit's control-flow exceptions (rerun / st.stop).
            st.error("Could not load database info")
        st.markdown("---")
        # Clear chat history
        if st.button("ποΈ Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()
    # Main chat interface
    col1, col2 = st.columns([2, 1])
    with col1:
        st.header("π¬ Ask Your Question")
        # Question input with better state management.
        # Initialize question input key if it doesn't exist.
        if 'question_input' not in st.session_state:
            st.session_state.question_input = ''
        # Handle a question selected from the examples column. This runs
        # before the text_input widget is instantiated, which is the only
        # point at which its backing session-state key may be written.
        if 'selected_question' in st.session_state:
            st.session_state.question_input = st.session_state['selected_question']
            # Clear after setting to prevent re-triggering
            del st.session_state['selected_question']
        user_question = st.text_input(
            "Enter your question about Scikit-learn:",
            placeholder="e.g., How do I perform cross-validation in scikit-learn?",
            key="question_input"
        )
        # Submit button
        submit_button = st.button("π Get Answer", type="primary")
        # Process question
        if submit_button and user_question:
            if not st.session_state.openai_initialized:
                st.error("β οΈ Please enter a valid OpenAI API key in the sidebar first.")
            else:
                # Get answer using RAG
                answer, sources = st.session_state.chatbot.get_answer(
                    user_question, n_chunks, model
                )
                if answer:
                    # Add to chat history
                    st.session_state.chat_history.append({
                        'question': user_question,
                        'answer': answer,
                        'sources': sources
                    })
                    # Success message
                    st.success("β Answer generated successfully! Check the chat history below.")
        # Display chat history, newest first; only the latest is expanded.
        if st.session_state.chat_history:
            st.header("π Chat History")
            for i, chat in enumerate(reversed(st.session_state.chat_history)):
                with st.expander(f"Q: {chat['question'][:60]}...", expanded=(i == 0)):
                    st.markdown(f"**Question:** {chat['question']}")
                    st.markdown(f"**Answer:**\n{chat['answer']}")
                    if chat['sources']:
                        st.markdown("**Sources:**")
                        for j, source in enumerate(chat['sources'], 1):
                            source_name = source.split('/')[-1] if '/' in source else source
                            st.markdown(f"{j}. [{source_name}]({source})")
    with col2:
        st.header("π‘ Example Questions")
        example_questions = [
            "How do I perform cross-validation in scikit-learn?",
            "What is the difference between Ridge and Lasso regression?",
            "How do I use GridSearchCV for parameter tuning?",
            "What clustering algorithms are available in scikit-learn?",
            "How do I preprocess data using StandardScaler?",
            "What is the difference between classification and regression?",
            "How do I handle missing values in my dataset?",
            "What is feature selection and how do I use it?",
            "How do I visualize decision trees?",
            "What is ensemble learning in scikit-learn?"
        ]
        for question in example_questions:
            # hash() is stable within one interpreter process, so widget keys
            # stay consistent across Streamlit reruns of the same session.
            if st.button(question, key=f"example_{hash(question)}"):
                # Set the question and rerun only once
                st.session_state['selected_question'] = question
                st.rerun()
        st.markdown("---")
        st.header("βΉοΈ Tips")
        st.markdown("""
**For best results:**
- Be specific in your questions
- Ask about scikit-learn functionality
- Include context when possible
- Check the sources for verification
**The bot can help with:**
- API usage and parameters
- Algorithm explanations
- Code examples
- Best practices
- Troubleshooting
""")
if __name__ == "__main__":
    # Run the main application. Intended to be launched via `streamlit run`.
    # Note: For deployment environments like HuggingFace Spaces,
    # Streamlit warnings about missing ScriptRunContext can be safely ignored
    main()