Spaces:
Sleeping
Sleeping
| import os | |
| import pickle | |
| import time | |
| from typing import List, Dict, Any | |
| from chainlit.types import AskFileResponse | |
| from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader | |
| from aimakerspace.openai_utils.prompts import ( | |
| UserRolePrompt, | |
| SystemRolePrompt, | |
| AssistantRolePrompt, | |
| ) | |
| from aimakerspace.openai_utils.embedding import EmbeddingModel | |
| from aimakerspace.vectordatabase import VectorDatabase | |
| from aimakerspace.openai_utils.chatmodel import ChatOpenAI | |
| import chainlit as cl | |
# System prompt: instructs the model to answer from the supplied context and
# to say it does not know when the context lacks the answer.
system_template = """\
Use the following context to answer a users question. If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)

# User prompt: has two placeholders, {context} and {question}, that are
# filled in through create_message(...) when a query is processed.
user_prompt_template = """\
Context:
{context}
Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)
def normalize_text(text):
    """Return *text* lowercased with each whitespace run collapsed to one space.

    Leading and trailing whitespace is dropped, so strings that differ only
    in letter case or spacing normalize to the same value.
    """
    words = text.lower().split()
    return " ".join(words)
class RetrievalAugmentedQAPipeline:
    """Answer a question from retrieved text chunks and report their sources.

    The pipeline looks up the chunks closest to the user query in the vector
    database, concatenates them into the prompt context, streams the LLM
    answer, and maps every retrieved chunk back to its metadata entry.

    Attributes:
        llm: Chat model; must expose an async-iterable ``astream(messages)``.
        vector_db_retriever: Object exposing ``search_by_text(query, k=...)``.
        metadata: The metadata list passed in (empty list when omitted).
        text_to_metadata: Maps each normalized chunk text to its metadata.
    """

    # A fuzzy match must share strictly more than this fraction of words.
    FUZZY_MATCH_THRESHOLD = 0.5

    # The annotations for llm / vector_db_retriever are strings on purpose:
    # an annotation expression is evaluated when the function is defined, and
    # writing ``ChatOpenAI()`` there would construct a client at import time.
    def __init__(
        self,
        llm: "ChatOpenAI",
        vector_db_retriever: "VectorDatabase",
        metadata: List[Dict[str, Any]] = None,
        texts: List[str] = None,
    ) -> None:
        """Store the collaborators and build the chunk-text -> metadata map.

        Args:
            llm: Chat model used to generate the answer.
            vector_db_retriever: Vector database used for retrieval.
            metadata: One metadata dict per chunk; expected to be parallel to
                ``texts``. May be None.
            texts: Original chunk texts. May be None.
        """
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever
        self.metadata = metadata or []
        self.text_to_metadata = {}

        # Debug info about input data
        print(f"Init with metadata length: {len(metadata) if metadata else 0}, texts length: {len(texts) if texts else 0}")

        if not (metadata and texts):
            print(f"WARNING: Metadata mapping not created. Metadata: {len(metadata) if metadata else 0}, Texts: {len(texts) if texts else 0}")
            return

        if len(texts) == len(metadata):
            print(f"Creating direct mapping with {len(texts)} texts")
        else:
            print(f"WARN: Length mismatch between texts ({len(texts)}) and metadata ({len(metadata)})")
            # Debug trace only: print each change of (file, page). Uses .get so
            # an incomplete metadata entry cannot abort construction.
            current_location = (None, None)
            for meta in metadata:
                location = (meta.get('filename', 'unknown'), meta.get('page', 'unknown'))
                if location != current_location:
                    current_location = location
                    print(f"File: {location[0]}, Page: {location[1]}")

        # Pair texts and metadata positionally. zip stops at the shorter list,
        # which covers both the equal-length and the mismatched case.
        for text, meta in zip(texts, metadata):
            self.text_to_metadata[normalize_text(text)] = meta

        print(f"Successfully mapped {len(self.text_to_metadata)} text chunks to metadata")

        # Sample a few mappings for verification
        sample_items = list(self.text_to_metadata.items())[:3]
        for number, (text, meta) in enumerate(sample_items, start=1):
            print(f"Sample {number}: {text[:50]}... -> {meta}")

    def _find_source(self, text: str) -> Dict[str, Any]:
        """Return the metadata for one retrieved chunk.

        Tries an exact match on the normalized text first, then the stored
        text with the highest word-overlap ratio above FUZZY_MATCH_THRESHOLD,
        and finally a placeholder with "unknown" filename and page.
        """
        normalized_text = normalize_text(text)

        if normalized_text in self.text_to_metadata:
            print(f"✓ Found exact metadata match for: {normalized_text[:50]}...")
            return self.text_to_metadata[normalized_text]

        print(f"× No exact match for: {normalized_text[:50]}...")

        # The query word set does not change per candidate, so build it once.
        text_words = set(normalized_text.split())
        best_match = None
        best_score = self.FUZZY_MATCH_THRESHOLD
        if text_words:
            for orig_text, meta in self.text_to_metadata.items():
                orig_words = set(orig_text.split())
                if not orig_words:
                    continue
                overlap = len(text_words & orig_words)
                score = overlap / max(len(text_words), len(orig_words))
                # Strict comparison: the first candidate with the top score wins.
                if score > best_score:
                    best_score = score
                    best_match = meta

        # Truthiness check: an empty metadata dict is treated as "no match" so
        # callers always receive an entry carrying filename and page.
        if best_match:
            print(f"✓ Found fuzzy match with score {best_score:.2f}")
            return best_match

        print("× No match found at all")
        return {"filename": "unknown", "page": "unknown"}

    async def arun_pipeline(self, user_query: str, k: int = 4):
        """Retrieve context for *user_query* and start streaming an answer.

        Args:
            user_query: The question asked by the user.
            k: Number of chunks requested from the vector database.

        Returns:
            A dict with "response", an async generator yielding the chunks
            produced by ``llm.astream``, and "sources", a list with one
            metadata dict per retrieved chunk in retrieval order.
        """
        # Each result is indexed with [0] to get the chunk text; presumably
        # the retriever returns (text, score) pairs - confirm against
        # VectorDatabase.search_by_text.
        context_list = self.vector_db_retriever.search_by_text(user_query, k=k)

        # Debug: print the first retrieved context
        if context_list:
            print(f"Retrieved context: {context_list[0][0][:100]}...")

        chunk_texts = [context[0] for context in context_list]
        context_prompt = "".join(text + "\n" for text in chunk_texts)
        sources = [self._find_source(text) for text in chunk_texts]

        formatted_system_prompt = system_role_prompt.create_message()
        formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)

        async def generate_response():
            async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
                yield chunk

        return {"response": generate_response(), "sources": sources}
# Module-level splitter instance. Nothing else in the visible code references
# it; NOTE(review): confirm whether it is still needed.
text_splitter = CharacterTextSplitter()
def _page_bounds(pages):
    """Return (lowest, highest) page of *pages*, tolerating mixed value types."""
    try:
        return min(pages), max(pages)
    except TypeError:
        # Entries without a page contribute the string 'unknown', which cannot
        # be ordered against integer pages; fall back to the numeric ones.
        numbered = [page for page in pages if isinstance(page, (int, float))]
        if numbered:
            return min(numbered), max(numbered)
        return 'unknown', 'unknown'


def load_preprocessed_data(path='data/preprocessed_data.pkl'):
    """Load the pickled knowledge base and rebuild the vector database.

    Args:
        path: Location of the pickle file written by the preprocessing step.

    Returns:
        A tuple ``(vector_db, metadata, texts)``: a VectorDatabase populated
        from the pickle's 'vectors' mapping, plus the 'metadata' and 'texts'
        entries (empty lists when absent).

    Raises:
        FileNotFoundError: If *path* does not exist.
    """
    # Check if preprocessed data exists
    if not os.path.exists(path):
        raise FileNotFoundError("Preprocessed data not found. Please run the preprocess.py script first.")

    # NOTE(security): unpickling can execute arbitrary code. Only load files
    # produced by this project's own preprocessing step.
    with open(path, 'rb') as f:
        data = pickle.load(f)

    # Debug info about the file contents
    print(f"Loaded preprocessed data with keys: {list(data.keys())}")

    # Create a new vector database and populate it entry by entry
    vector_db = VectorDatabase()
    vectors = data.get('vectors')
    if vectors:
        print(f"Vectors dictionary has {len(vectors)} entries")
        for key, vector in vectors.items():
            vector_db.insert(key, vector)
    else:
        print("WARNING: No vectors found in preprocessed data")

    # Get metadata and original texts if available
    metadata = data.get('metadata', [])
    texts = data.get('texts', [])
    print(f"Loaded {len(metadata)} metadata entries and {len(texts)} texts")

    # Summarize the pages seen per file to help debug page numbering
    if metadata:
        page_counts = {}
        for meta in metadata:
            filename = meta.get('filename', 'unknown')
            page_counts.setdefault(filename, set()).add(meta.get('page', 'unknown'))
        print(f"Found {len(page_counts)} unique files with pages:")
        for filename, pages in page_counts.items():
            lowest, highest = _page_bounds(pages)
            print(f" - {filename}: {len(pages)} unique pages (min: {lowest}, max: {highest})")

    return vector_db, metadata, texts
async def on_chat_start():
    """Load the knowledge base and store a ready pipeline in the user session.

    Sends a status message, replaces it with setup instructions when the
    preprocessed pickle is missing, and otherwise builds the retrieval
    pipeline and saves it in the session under the key "chain". Any exception
    is reported in the status message and printed.
    """
    # NOTE(review): no chainlit decorator is visible above this handler in
    # this view; confirm it is registered (presumably via @cl.on_chat_start).
    status = cl.Message(content="Loading knowledge base from pre-processed PDF documents...")
    await status.send()

    try:
        # Without the pickle there is nothing to load: show instructions and stop.
        data_file_present = os.path.exists('data/preprocessed_data.pkl')
        if not data_file_present:
            status.content = """
## Error: Preprocessed Data Not Found
The application requires preprocessing of PDF documents to build a knowledge base, but the preprocessed data was not found.
**For administrators:**
1. Make sure you've set both OPENAI_API_KEY and HF_TOKEN as build secrets in your Hugging Face Space.
2. Check the build logs for any errors during the preprocessing step.
3. You may need to manually run preprocessing on your local machine and upload the data/preprocessed_data.pkl file.
**Steps to build preprocessed data locally:**
1. Clone this repository
2. Install dependencies with `pip install -r requirements.txt`
3. Set your OpenAI API key: `export OPENAI_API_KEY=your_key_here`
4. Run: `python preprocess.py`
5. Upload the generated `data/preprocessed_data.pkl` file to your Hugging Face Space
"""
            await status.update()
            return

        # Time the load so slow startups show up in the logs.
        began = time.time()
        vector_db, metadata, texts = load_preprocessed_data()
        elapsed = time.time() - began
        print(f"Loaded vector database in {elapsed:.2f} seconds")

        # Assemble the question-answering chain.
        pipeline = RetrievalAugmentedQAPipeline(
            vector_db_retriever=vector_db,
            llm=ChatOpenAI(),
            metadata=metadata,
            texts=texts,
        )

        # Tell the user the system is ready, then publish the chain.
        status.content = "Please ask questions about A/B Testing. We'll use material written by Ronny Kohavi to answer your questions!"
        await status.update()
        cl.user_session.set("chain", pipeline)
    except Exception as err:
        status.content = f"Error loading knowledge base: {err}\n\nPlease make sure you've configured the OPENAI_API_KEY and HF_TOKEN as build secrets in your Hugging Face Space."
        await status.update()
        print(f"Error details: {err}")
async def main(message):
    """Answer one user message and append the list of sources.

    Reads the pipeline stored in the session under "chain". When it is
    missing, an apology is sent instead. Otherwise the answer is streamed
    token by token, followed by one "filename (Page n)" line per source.
    """
    # NOTE(review): no chainlit decorator is visible above this handler in
    # this view; confirm it is registered (presumably via @cl.on_message).
    pipeline = cl.user_session.get("chain")

    # Startup failed or never ran: nothing to query.
    if not pipeline:
        await cl.Message(content="Sorry, the knowledge base is not loaded. Please check the error message at startup.").send()
        return

    reply = cl.Message(content="")
    result = await pipeline.arun_pipeline(message.content)

    async for token in result["response"]:
        await reply.stream_token(token)

    # Add source information after the response
    source_lines = [
        f"\n- {source['filename']} (Page {source['page']})"
        for source in result["sources"]
    ]
    await reply.stream_token("\n\n**Sources:**" + "".join(source_lines))
    await reply.send()