# NOTE(review): removed stray "Spaces: / Sleeping / Sleeping" text — a
# HuggingFace Spaces page artifact from copy/paste, not valid Python.
| import os | |
| from typing import List, Dict, Any | |
| import tempfile | |
| import shutil | |
| import logging | |
| import time | |
| import traceback | |
| import asyncio | |
# Module-wide logging: timestamped INFO-level records, one logger per module.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
# Make sure aimakerspace is in the path
import sys
# Append the grandparent directory of this file so the local ``aimakerspace``
# package resolves at import time. ``os.path.join(..., "")`` only adds a
# trailing separator to the path before appending.
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), ""))
| # Import from local aimakerspace module | |
| from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader | |
| from aimakerspace.vectordatabase import VectorDatabase | |
| from aimakerspace.openai_utils.embedding import EmbeddingModel | |
| from openai import OpenAI | |
# Initialize OpenAI client — module-level singleton used by the RAG pipeline.
# The API key is read from the OPENAI_API_KEY environment variable; the log
# line below records whether a key was present at startup (it does not
# validate the key against the API).
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
logger.info(f"Initialized OpenAI client with API key: {'valid key' if os.getenv('OPENAI_API_KEY') else 'API KEY MISSING!'}")
class RetrievalAugmentedQAPipeline:
    """Retrieval-augmented QA pipeline.

    Retrieves the most similar document chunks from a vector database and
    streams an LLM answer grounded in those chunks.
    """

    def __init__(self, vector_db_retriever: VectorDatabase) -> None:
        # Vector store used for similarity search over document chunks.
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str) -> Dict[str, Any]:
        """
        Run the RAG pipeline with the given user query.

        Returns:
            A dict with a single key ``"response"`` whose value is an async
            generator yielding response text chunks. Both the success and the
            error path yield from an *async* generator so callers can always
            consume the result with ``async for``.
        """
        try:
            # 1. Retrieve relevant documents (top-4 by similarity)
            logger.info(f"RAG Pipeline: Retrieving documents for query: '{user_query}'")
            relevant_docs = self.vector_db_retriever.search_by_text(user_query, k=4)

            if not relevant_docs:
                logger.warning("No relevant documents found in vector database")
                documents_context = "No relevant information found in the document."
            else:
                logger.info(f"Found {len(relevant_docs)} relevant document chunks")
                # Each result is a (text, score) pair; join the texts only.
                documents_context = "\n\n".join([doc[0] for doc in relevant_docs])
                # Debug similarity scores
                doc_scores = [f"{i+1}. Score: {doc[1]:.4f}" for i, doc in enumerate(relevant_docs)]
                logger.info(f"Document similarity scores: {', '.join(doc_scores) if doc_scores else 'No documents'}")

            # 2. Create messaging payload
            messages = [
                {"role": "system", "content": f"""You are a helpful AI assistant that answers questions based on the provided document context.
If the answer is not in the context, say that you don't know based on the available information.
Use the following document extracts to answer the user's question:
{documents_context}"""},
                {"role": "user", "content": user_query}
            ]

            # 3. Call LLM and stream the output
            async def generate_response():
                try:
                    logger.info("Initiating streaming completion from OpenAI")
                    # NOTE(review): this is the *synchronous* OpenAI client, so
                    # iterating the stream blocks the event loop between chunks;
                    # AsyncOpenAI would avoid that — confirm before changing.
                    stream = client.chat.completions.create(
                        model="gpt-3.5-turbo",
                        messages=messages,
                        temperature=0.2,
                        stream=True
                    )
                    for chunk in stream:
                        # Guard against chunks with an empty choices list in
                        # addition to empty/None delta content.
                        if chunk.choices and chunk.choices[0].delta.content:
                            yield chunk.choices[0].delta.content
                except Exception as e:
                    logger.error(f"Error generating stream: {str(e)}")
                    yield f"\n\nI apologize, but I encountered an error while generating a response: {str(e)}"

            return {
                "response": generate_response()
            }
        except Exception as e:
            logger.error(f"Error in RAG pipeline: {str(e)}")
            logger.error(traceback.format_exc())

            # BUG FIX: the fallback used to be a plain (sync) generator
            # expression while the success path returns an async generator,
            # so callers iterating with ``async for`` crashed on errors.
            # The message is bound eagerly because Python unbinds ``e`` when
            # the except block exits.
            error_message = f"I apologize, but an error occurred: {str(e)}"

            async def error_stream(msg: str = error_message):
                yield msg

            return {
                "response": error_stream()
            }
def process_file(file_path: str, file_name: str) -> List[str]:
    """Process an uploaded file and convert it to text chunks.

    Picks a loader based on the file extension (.txt or .pdf), loads the
    document, and splits it into overlapping chunks. On any failure a
    single-element list carrying a human-readable error message is returned
    instead of raising.
    """
    logger.info(f"Processing file: {file_name} at path: {file_path}")
    try:
        lowered = file_name.lower()

        # Pick the loader by extension; bail out early on anything else.
        if lowered.endswith('.txt'):
            logger.info(f"Using TextFileLoader for {file_name}")
            loader = TextFileLoader(file_path)
        elif lowered.endswith('.pdf'):
            logger.info(f"Using PDFLoader for {file_name}")
            loader = PDFLoader(file_path)
        else:
            logger.warning(f"Unsupported file type: {file_name}")
            return ["Unsupported file format. Please upload a .txt or .pdf file."]

        loader.load()
        documents = loader.documents
        if not documents:
            logger.warning("No document content loaded")
            return ["No content found in the document"]
        logger.info(f"Loaded document with {len(documents[0])} characters")

        # Split into overlapping chunks for embedding.
        splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunks = splitter.split_texts(documents)
        logger.info(f"Split document into {len(chunks)} chunks")
        return chunks
    except Exception as e:
        logger.error(f"Error processing file: {str(e)}")
        logger.error(traceback.format_exc())
        return [f"Error processing file: {str(e)}"]
async def setup_vector_db(texts: List[str]) -> VectorDatabase:
    """Create a vector database from text chunks.

    Embeds and indexes every chunk in *texts*. If building fails, a fallback
    database containing a single apology entry is returned instead of
    raising, so the caller always receives a usable VectorDatabase.
    """
    logger.info(f"Setting up vector database with {len(texts)} text chunks")
    embedder = EmbeddingModel()
    vector_db = VectorDatabase(embedding_model=embedder)
    try:
        await vector_db.abuild_from_list(texts)
        vector_db.documents = texts
        logger.info(f"Vector database built with {len(texts)} documents")
        return vector_db
    except Exception as e:
        logger.error(f"Error setting up vector database: {str(e)}")
        logger.error(traceback.format_exc())
        # Degrade gracefully: fresh DB holding one canned apology entry.
        # The zero vector is 1536-dimensional — presumably matching the
        # embedding model's output size; confirm against EmbeddingModel.
        fallback_db = VectorDatabase(embedding_model=embedder)
        error_text = "I'm sorry, but there was an error processing the document."
        fallback_db.insert(error_text, [0.0] * 1536)
        fallback_db.documents = [error_text]
        return fallback_db