"""Chainlit RAG application over pre-processed PDF documents.

Loads a pickled vector database (built offline by ``preprocess.py``),
retrieves the most relevant chunks for each user question, and streams
an LLM answer followed by source attributions (filename + page).
"""

import os
import pickle
import time
from typing import Any, Dict, List

import chainlit as cl
from chainlit.types import AskFileResponse

from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
    AssistantRolePrompt,
)
from aimakerspace.openai_utils.embedding import EmbeddingModel
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI

system_template = """\
Use the following context to answer a users question. If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)

user_prompt_template = """\
Context:
{context}

Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)


def normalize_text(text):
    """Normalize text for better matching by removing extra whitespace and converting to lowercase"""
    return ' '.join(text.lower().split())


class RetrievalAugmentedQAPipeline:
    """Retrieval-augmented QA: vector search for context, LLM for the answer.

    Keeps a mapping from normalized chunk text to its source metadata so
    answers can be attributed back to a filename/page.
    """

    # NOTE: annotation is the *type* ChatOpenAI, not a call — the original
    # `llm: ChatOpenAI()` instantiated a throwaway client at class-definition time.
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase,
                 metadata: List[Dict[str, Any]] = None,
                 texts: List[str] = None) -> None:
        """Store the LLM + retriever and build the text -> metadata index.

        Args:
            llm: chat model used to generate answers.
            vector_db_retriever: populated vector database for similarity search.
            metadata: per-chunk source records (expects 'filename' and 'page' keys).
            texts: the original chunk texts, parallel to ``metadata`` when possible.
        """
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever
        self.metadata = metadata or []
        self.text_to_metadata = {}

        # Debug info about input data
        print(f"Init with metadata length: {len(metadata) if metadata else 0}, texts length: {len(texts) if texts else 0}")

        # Enhanced text-to-metadata mapping with normalization
        if metadata and texts and len(metadata) > 0:
            # Create normalized versions of texts for better matching
            normalized_texts = [normalize_text(t) for t in texts]

            if len(texts) == len(metadata):
                # Exact 1:1 mapping when the two lists line up.
                print(f"Creating direct mapping with {len(texts)} texts")
                for i, text in enumerate(texts):
                    self.text_to_metadata[normalize_text(text)] = metadata[i]
            else:
                # Best-effort positional mapping when lengths diverge.
                print(f"WARN: Length mismatch between texts ({len(texts)}) and metadata ({len(metadata)})")
                current_file = None
                current_page = None
                for i, meta in enumerate(metadata):
                    if i < len(normalized_texts):
                        self.text_to_metadata[normalized_texts[i]] = meta
                        # Track current file and page for debugging.
                        # .get() avoids a KeyError on malformed metadata records
                        # (the rest of the file already uses .get for these keys).
                        meta_file = meta.get('filename')
                        meta_page = meta.get('page')
                        if current_file != meta_file or current_page != meta_page:
                            current_file = meta_file
                            current_page = meta_page
                            print(f"File: {current_file}, Page: {current_page}")

            print(f"Successfully mapped {len(self.text_to_metadata)} text chunks to metadata")

            # Sample a few mappings for verification
            sample_size = min(3, len(self.text_to_metadata))
            sample_items = list(self.text_to_metadata.items())[:sample_size]
            for i, (text, meta) in enumerate(sample_items):
                print(f"Sample {i+1}: {text[:50]}... -> {meta}")
        else:
            print(f"WARNING: Metadata mapping not created. Metadata: {len(metadata) if metadata else 0}, Texts: {len(texts) if texts else 0}")

    async def arun_pipeline(self, user_query: str):
        """Answer ``user_query`` using retrieved context.

        Returns:
            dict with 'response' (async generator of streamed tokens) and
            'sources' (one metadata dict per retrieved chunk; falls back to
            {'filename': 'unknown', 'page': 'unknown'} when no match is found).
        """
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)

        # Debug: print the first retrieved context
        if context_list:
            print(f"Retrieved context: {context_list[0][0][:100]}...")

        context_prompt = ""
        sources = []

        for context in context_list:
            text = context[0]
            context_prompt += text + "\n"

            # Normalize the text for better matching
            normalized_text = normalize_text(text)

            if normalized_text in self.text_to_metadata:
                # Exact (normalized) match.
                sources.append(self.text_to_metadata[normalized_text])
                print(f"✓ Found exact metadata match for: {normalized_text[:50]}...")
            else:
                # Fall back to fuzzy matching by word-overlap ratio.
                print(f"× No exact match for: {normalized_text[:50]}...")
                best_match = None
                best_score = 0
                text_words = set(normalized_text.split())

                for orig_text, meta in self.text_to_metadata.items():
                    orig_words = set(orig_text.split())
                    if not text_words or not orig_words:
                        continue
                    overlap = len(text_words.intersection(orig_words))
                    score = overlap / max(len(text_words), len(orig_words))
                    if score > best_score and score > 0.5:  # Minimum 50% word overlap
                        best_score = score
                        best_match = meta

                if best_match is not None:
                    sources.append(best_match)
                    print(f"✓ Found fuzzy match with score {best_score:.2f}")
                else:
                    print("× No match found at all")
                    sources.append({"filename": "unknown", "page": "unknown"})

        formatted_system_prompt = system_role_prompt.create_message()
        formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)

        async def generate_response():
            async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
                yield chunk

        return {"response": generate_response(), "sources": sources}


text_splitter = CharacterTextSplitter()


def load_preprocessed_data():
    """Load the pickled vector DB, metadata, and chunk texts from disk.

    Returns:
        (vector_db, metadata, texts) tuple.

    Raises:
        FileNotFoundError: when data/preprocessed_data.pkl is missing.
    """
    # Check if preprocessed data exists
    if not os.path.exists('data/preprocessed_data.pkl'):
        raise FileNotFoundError("Preprocessed data not found. Please run the preprocess.py script first.")

    # Load the pre-processed data.
    # SECURITY NOTE: pickle.load executes arbitrary code from the file —
    # only load data/preprocessed_data.pkl from a trusted build step.
    with open('data/preprocessed_data.pkl', 'rb') as f:
        data = pickle.load(f)

    # Debug info about the file contents
    print(f"Loaded preprocessed data with keys: {list(data.keys())}")

    # Create a new vector database
    vector_db = VectorDatabase()

    # Check that vectors dictionary has data
    if 'vectors' in data and data['vectors']:
        print(f"Vectors dictionary has {len(data['vectors'])} entries")
        # Directly populate the vectors dictionary
        for key, vector in data['vectors'].items():
            vector_db.insert(key, vector)
    else:
        print("WARNING: No vectors found in preprocessed data")

    # Get metadata and original texts if available
    metadata = data.get('metadata', [])
    texts = data.get('texts', [])
    print(f"Loaded {len(metadata)} metadata entries and {len(texts)} texts")

    # Verify a sample of metadata to debug page numbering
    if metadata and len(metadata) > 0:
        page_counts = {}
        for meta in metadata:
            filename = meta.get('filename', 'unknown')
            page = meta.get('page', 'unknown')
            if filename not in page_counts:
                page_counts[filename] = set()
            page_counts[filename].add(page)

        print(f"Found {len(page_counts)} unique files with pages:")
        for filename, pages in page_counts.items():
            # BUGFIX: original printed the literal "(unknown)" instead of the filename.
            # NOTE(review): min/max assume homogeneous page values; a mix of ints
            # and the 'unknown' string would raise TypeError — confirm upstream.
            print(f"  - {filename}: {len(pages)} unique pages (min: {min(pages)}, max: {max(pages)})")

    return vector_db, metadata, texts


@cl.on_chat_start
async def on_chat_start():
    """Chainlit startup hook: load the knowledge base and stash the pipeline in the session."""
    # Send welcome message
    msg = cl.Message(content="Loading knowledge base from pre-processed PDF documents...")
    await msg.send()

    try:
        # Check if preprocessed data exists
        if not os.path.exists('data/preprocessed_data.pkl'):
            msg.content = """
## Error: Preprocessed Data Not Found

The application requires preprocessing of PDF documents to build a knowledge base, but the preprocessed data was not found.

**For administrators:**

1. Make sure you've set both OPENAI_API_KEY and HF_TOKEN as build secrets in your Hugging Face Space.
2. Check the build logs for any errors during the preprocessing step.
3. You may need to manually run preprocessing on your local machine and upload the data/preprocessed_data.pkl file.

**Steps to build preprocessed data locally:**

1. Clone this repository
2. Install dependencies with `pip install -r requirements.txt`
3. Set your OpenAI API key: `export OPENAI_API_KEY=your_key_here`
4. Run: `python preprocess.py`
5. Upload the generated `data/preprocessed_data.pkl` file to your Hugging Face Space
"""
            await msg.update()
            return

        # Load pre-processed data
        start_time = time.time()
        vector_db, metadata, texts = load_preprocessed_data()
        load_time = time.time() - start_time
        print(f"Loaded vector database in {load_time:.2f} seconds")

        chat_openai = ChatOpenAI()

        # Create chain
        retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
            vector_db_retriever=vector_db,
            llm=chat_openai,
            metadata=metadata,
            texts=texts
        )

        # Let the user know that the system is ready
        msg.content = "Please ask questions about A/B Testing. We'll use material written by Ronny Kohavi to answer your questions!"
        await msg.update()

        cl.user_session.set("chain", retrieval_augmented_qa_pipeline)

    except Exception as e:
        # Top-level boundary: surface the failure to the user rather than crash.
        msg.content = f"Error loading knowledge base: {str(e)}\n\nPlease make sure you've configured the OPENAI_API_KEY and HF_TOKEN as build secrets in your Hugging Face Space."
        await msg.update()
        print(f"Error details: {e}")


@cl.on_message
async def main(message):
    """Chainlit message hook: stream the answer, then append source attributions."""
    chain = cl.user_session.get("chain")

    # If chain is not initialized, inform the user
    if not chain:
        msg = cl.Message(content="Sorry, the knowledge base is not loaded. Please check the error message at startup.")
        await msg.send()
        return

    msg = cl.Message(content="")
    result = await chain.arun_pipeline(message.content)

    async for stream_resp in result["response"]:
        await msg.stream_token(stream_resp)

    # Add source information after the response
    sources_text = "\n\n**Sources:**"
    for source in result["sources"]:
        sources_text += f"\n- {source['filename']} (Page {source['page']})"

    await msg.stream_token(sources_text)
    await msg.send()