Spaces:
Sleeping
Sleeping
| import asyncio | |
| import os | |
| import logging | |
| from typing import List, Optional | |
| from contextlib import contextmanager | |
| import numpy as np | |
| from dotenv import load_dotenv | |
| from langchain_groq import ChatGroq | |
| from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.prompts import ChatPromptTemplate | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.documents import Document | |
| from sqlalchemy import select | |
| from .constants import ( | |
| EMBEDDING_MODEL_NAME, | |
| GROQ_MODEL_NAME, | |
| CHUNK_SIZE, | |
| CHUNK_OVERLAP, | |
| SEPARATORS | |
| ) | |
| from .system_prompt import RAG_SYSTEM_PROMPT | |
| from model.db.db_setup import DocumentEmbedding, SessionLocal | |
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# Configure environment
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Read the key with .get(): indexing os.environ directly raises KeyError when
# the variable is missing, which would skip the explicit check (and its log
# message) below.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    logger.error("GROQ_API_KEY environment variable is not set")
    raise ValueError("GROQ_API_KEY environment variable is not set.")
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
@contextmanager
def get_db_session():
    """Context manager for database sessions.

    Yields a session created by ``SessionLocal()`` and closes it when the
    ``with`` block exits, whether or not the block raised.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
class RAGManager:
    """Retrieval-augmented generation helper backed by stored embeddings.

    Splits and embeds documents into ``DocumentEmbedding`` rows, ranks the
    stored rows against a query by cosine similarity, and answers questions
    with a ``ChatGroq`` model using the best-matching rows as context.
    """

    def __init__(self):
        """Create the embedding model now; the chat model is created on first use."""
        self._embed_model: Optional[HuggingFaceEmbeddings] = None
        self._rag_llm: Optional[ChatGroq] = None
        self._initialize_embed_model()
        logger.info("RAGManager initialized")

    def _initialize_embed_model(self):
        """Initialize the embedding model.

        Raises:
            Exception: whatever ``HuggingFaceEmbeddings`` raised, after logging it.
        """
        try:
            logger.debug("Initializing embedding model")
            self._embed_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
            logger.info("Embedding model initialized successfully")
        except Exception as e:
            logger.error("Error initializing embedding model: %s", e)
            raise

    @property
    def embed_model(self):
        """Get or create embedding model."""
        if self._embed_model is None:
            self._initialize_embed_model()
        return self._embed_model

    @property
    def rag_llm(self):
        """Get or create RAG LLM model."""
        if self._rag_llm is None:
            logger.debug("Initializing RAG LLM model")
            self._rag_llm = ChatGroq(model=GROQ_MODEL_NAME)
            logger.info("RAG LLM model initialized successfully")
        return self._rag_llm

    @staticmethod
    def _top_k_indices(embeddings, query_embedding, k):
        """Return the indices of the ``k`` rows most similar to the query.

        Args:
            embeddings: sequence of equal-length vectors, one per stored row.
            query_embedding: vector of the same length as each row.
            k: maximum number of indices to return.

        Returns:
            A list of row indices ordered from most to least cosine-similar.
            Empty when ``k <= 0`` or there are no rows.
        """
        if k <= 0:
            return []
        matrix = np.asarray(embeddings, dtype=float)
        if matrix.size == 0:
            return []
        query = np.asarray(query_embedding, dtype=float)

        dots = matrix @ query
        norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query)
        # A zero-norm vector has no defined cosine similarity. Dividing anyway
        # yields NaN, which argsort places last and would therefore be picked
        # as a "best" match; give those rows -inf so they rank last instead.
        similarities = np.divide(
            dots,
            norms,
            out=np.full(dots.shape, -np.inf),
            where=norms > 0,
        )
        return np.argsort(similarities)[-k:][::-1].tolist()

    async def get_relevant_documents(self, query: str, k: int = 4) -> List["Document"]:
        """Get relevant documents using vector similarity search.

        Args:
            query: text to embed and compare against the stored embeddings.
            k: maximum number of documents to return.

        Returns:
            Up to ``k`` documents, most similar first. Empty when ``k <= 0``
            or the table has no rows.
        """
        if k <= 0:
            return []

        # Embed before opening the session so the session is not held open
        # while the model runs.
        query_embedding = self.embed_model.embed_query(query)

        with get_db_session() as db:
            # Every stored row is loaded and ranked in memory.
            rows = db.execute(select(DocumentEmbedding)).scalars().all()
            if not rows:
                logger.warning("No documents found in database")
                return []

            best = self._top_k_indices(
                [row.embedding for row in rows], query_embedding, k
            )
            return [
                Document(
                    page_content=rows[idx].content,
                    metadata=rows[idx].doc_metadata or {},
                )
                for idx in best
            ]

    def format_docs(self, docs: List["Document"]) -> str:
        """Join the documents' ``page_content`` with blank lines.

        Returns an empty string for an empty list or if formatting fails.
        """
        try:
            if not docs:
                logger.debug("No documents to format")
                return ""
            formatted = "\n\n".join(doc.page_content for doc in docs)
            logger.debug("Formatted %d documents", len(docs))
            return formatted
        except Exception as e:
            logger.error("Error formatting documents: %s", e)
            return ""

    async def process_documents(self, documents: List["Document"]) -> bool:
        """Split, embed and store documents, replacing all existing rows.

        Args:
            documents: documents to split into chunks and embed.

        Returns:
            ``True`` when the new embeddings were committed; ``False`` when
            there was nothing to store or any step failed (failures are logged).
        """
        try:
            if not documents:
                logger.warning("No documents provided")
                return False

            text_splitter = RecursiveCharacterTextSplitter(
                separators=SEPARATORS,
                chunk_size=CHUNK_SIZE,
                chunk_overlap=CHUNK_OVERLAP,
                length_function=len,
                is_separator_regex=False,
            )

            # Split one document at a time so a single bad document is
            # skipped instead of aborting the whole batch.
            split_docs = []
            for doc in documents:
                try:
                    split_docs.extend(text_splitter.split_documents([doc]))
                except Exception as e:
                    logger.warning(
                        "Error splitting document %s: %s",
                        doc.metadata.get('source', 'unknown'),
                        e,
                    )

            if not split_docs:
                logger.warning("No content could be split from documents")
                return False

            # Compute embeddings before touching the database so the
            # delete/insert transaction stays as short as possible.
            contents = [doc.page_content for doc in split_docs]
            embeddings = self.embed_model.embed_documents(contents)

            with get_db_session() as db:
                try:
                    # Delete and insert share one transaction; the rollback
                    # below restores the old rows if anything fails.
                    db.query(DocumentEmbedding).delete()
                    db.add_all(
                        DocumentEmbedding(
                            content=doc.page_content,
                            embedding=embedding,
                            doc_metadata=doc.metadata,
                            source=doc.metadata.get("source"),
                        )
                        for doc, embedding in zip(split_docs, embeddings)
                    )
                    db.commit()
                    logger.info(
                        "Successfully stored %d embeddings in database",
                        len(split_docs),
                    )
                    return True
                except Exception as e:
                    db.rollback()
                    logger.error("Error storing embeddings in database: %s", e)
                    return False
        except Exception as e:
            logger.error("Error processing documents: %s", e)
            return False

    async def get_contextual_response(self, question: str) -> str:
        """Get contextual response using RAG.

        Args:
            question: the user's question.

        Returns:
            The model's answer; a fixed fallback sentence when no stored
            documents were retrieved; or an ``"Error generating response: ..."``
            string if any step raised.
        """
        try:
            relevant_docs = await self.get_relevant_documents(question)
            if not relevant_docs:
                logger.warning("No relevant documents found")
                return "I don't have enough context to answer that question."

            context = self.format_docs(relevant_docs)

            prompt = ChatPromptTemplate.from_messages([
                ("system", RAG_SYSTEM_PROMPT),
                ("human", "{input}"),
            ])
            chain = (
                {"context": lambda x: context, "input": RunnablePassthrough()}
                | prompt
                | self.rag_llm
                | StrOutputParser()
            )

            response = await chain.ainvoke(question)
            logger.info("Successfully generated contextual response")
            return response
        except Exception as e:
            logger.error("Error generating contextual response: %s", e)
            return f"Error generating response: {str(e)}"
# Shared module-level instance; constructing it runs RAGManager.__init__
# at import time.
rag_manager = RAGManager()
logger.info("RAGManager singleton instance created")
# Module-level entry point kept for backward compatibility.
async def get_contextual_response(question: str) -> str:
    """Answer ``question`` by delegating to ``rag_manager.get_contextual_response``."""
    answer = await rag_manager.get_contextual_response(question)
    return answer