# Hello / model/contextual_response/contextual_response.py
# Author: GaneshNaiknavare
# Last commit: 5b4c0a7 (verified) - "Update model/contextual_response/contextual_response.py"
import asyncio
import os
import logging
from typing import List, Optional
from contextlib import contextmanager
import numpy as np
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from sqlalchemy import select
from .constants import (
EMBEDDING_MODEL_NAME,
GROQ_MODEL_NAME,
CHUNK_SIZE,
CHUNK_OVERLAP,
SEPARATORS
)
from .system_prompt import RAG_SYSTEM_PROMPT
from model.db.db_setup import DocumentEmbedding, SessionLocal
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Load environment variables (python-dotenv reads a .env file into os.environ)
# before any of them are read below.
load_dotenv()
# Configure environment
# NOTE(review): presumably disables HuggingFace tokenizers' internal
# parallelism to avoid its fork warning - confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Fail fast with a logged, descriptive error when the Groq API key is missing
# or empty. `.get()` is required: `os.environ[...]` would raise a bare KeyError
# for a missing key, and the explicit check below would never run.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    logger.error("GROQ_API_KEY environment variable is not set")
    raise ValueError("GROQ_API_KEY environment variable is not set.")
@contextmanager
def get_db_session():
    """Yield a fresh session from ``SessionLocal`` and always close it afterwards.

    No commit or rollback happens here; callers own transaction handling.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs on normal exit and when the caller's block raises.
        session.close()
class RAGManager:
    """Retrieval-augmented generation backed by embeddings stored in the database.

    The HuggingFace embedding model is created eagerly in ``__init__``; the Groq
    chat model is created lazily on first access to ``rag_llm``.
    ``process_documents`` chunks, embeds and stores documents, replacing every
    previously stored ``DocumentEmbedding`` row. ``get_contextual_response``
    retrieves the chunks most similar to a question and asks the LLM to answer
    with them as context.
    """

    def __init__(self):
        self._embed_model: Optional[HuggingFaceEmbeddings] = None
        self._rag_llm: Optional[ChatGroq] = None
        # Load the embedding model up front so a bad configuration fails at
        # construction time rather than on the first request.
        self._initialize_embed_model()
        logger.info("RAGManager initialized")

    def _initialize_embed_model(self):
        """Create the embedding model named by ``EMBEDDING_MODEL_NAME``.

        Raises:
            Exception: whatever ``HuggingFaceEmbeddings`` raises; it is logged
                with its traceback and re-raised unchanged.
        """
        try:
            logger.debug("Initializing embedding model")
            self._embed_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
            logger.info("Embedding model initialized successfully")
        except Exception as e:
            logger.exception(f"Error initializing embedding model: {str(e)}")
            raise

    @property
    def embed_model(self):
        """Return the embedding model, creating it if it does not exist yet."""
        if self._embed_model is None:
            self._initialize_embed_model()
        return self._embed_model

    @property
    def rag_llm(self):
        """Return the Groq chat model, creating it on first access."""
        if self._rag_llm is None:
            logger.debug("Initializing RAG LLM model")
            self._rag_llm = ChatGroq(model=GROQ_MODEL_NAME)
            logger.info("RAG LLM model initialized successfully")
        return self._rag_llm

    async def get_relevant_documents(self, query: str, k: int = 4) -> List[Document]:
        """Return up to ``k`` stored chunks most similar to ``query``, best first.

        Similarity is cosine similarity between the query embedding and every
        ``DocumentEmbedding`` row (a full table scan). The embedding call and
        the database query are synchronous, so they run in a worker thread to
        keep the event loop responsive.

        Args:
            query: Text to search for.
            k: Maximum number of documents to return.

        Returns:
            ``Document`` objects holding each chunk's content and stored
            metadata. Returns ``[]`` when ``k <= 0`` or the table is empty.
        """
        if k <= 0:
            # Guard: ``argsort(...)[-0:]`` would select *every* row, not none.
            return []
        return await asyncio.to_thread(self._search_documents, query, k)

    def _search_documents(self, query: str, k: int) -> List[Document]:
        """Blocking implementation of ``get_relevant_documents`` (requires ``k > 0``)."""
        # Embed before opening the session so no DB connection is held during
        # model inference.
        query_embedding = np.asarray(self.embed_model.embed_query(query), dtype=float)
        with get_db_session() as db:
            rows = db.execute(select(DocumentEmbedding)).scalars().all()
            if not rows:
                logger.warning("No documents found in database")
                return []
            embeddings = np.asarray([row.embedding for row in rows], dtype=float)
            similarities = self._cosine_similarities(embeddings, query_embedding)
            # Indices of the k highest scores, highest first.
            top_k_indices = np.argsort(similarities)[-k:][::-1]
            # Read row attributes while the session is still open, as before.
            return [
                Document(
                    page_content=rows[idx].content,
                    metadata=rows[idx].doc_metadata or {},
                )
                for idx in top_k_indices
            ]

    @staticmethod
    def _cosine_similarities(embeddings: np.ndarray, query: np.ndarray) -> np.ndarray:
        """Cosine similarity of each row of ``embeddings`` against ``query``.

        A zero-norm row (or a zero-norm query) scores 0.0 instead of NaN.
        ``np.argsort`` sorts NaN last, so an unguarded NaN would be ranked as
        the best match.
        """
        dots = embeddings @ query
        norms = np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query)
        return np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)

    def format_docs(self, docs: List[Document]) -> str:
        """Join the documents' page contents with blank lines.

        Returns ``""`` for an empty list, or if formatting fails (the error is
        logged, not raised).
        """
        try:
            if not docs:
                logger.debug("No documents to format")
                return ""
            formatted = "\n\n".join(doc.page_content for doc in docs)
            logger.debug(f"Formatted {len(docs)} documents")
            return formatted
        except Exception as e:
            logger.error(f"Error formatting documents: {str(e)}")
            return ""

    async def process_documents(self, documents: List[Document]) -> bool:
        """Split, embed and store ``documents``, replacing all stored embeddings.

        Existing rows are deleted and the new chunks are inserted in one
        session. On a database error that session is rolled back before
        returning.

        Returns:
            True on success. False if ``documents`` is empty, no chunks could
            be produced, or any step failed. Errors are logged, not raised.
        """
        try:
            if not documents:
                logger.warning("No documents provided")
                return False
            split_docs = self._split_documents(documents)
            if not split_docs:
                logger.warning("No content could be split from documents")
                return False
            # Embedding and the database write are blocking; run them off the
            # event loop.
            return await asyncio.to_thread(self._store_embeddings, split_docs)
        except Exception as e:
            logger.exception(f"Error processing documents: {str(e)}")
            return False

    @staticmethod
    def _split_documents(documents: List[Document]) -> List[Document]:
        """Chunk each document; a document that fails to split is logged and skipped."""
        text_splitter = RecursiveCharacterTextSplitter(
            separators=SEPARATORS,
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            length_function=len,
            is_separator_regex=False,
        )
        split_docs = []
        for doc in documents:
            try:
                split_docs.extend(text_splitter.split_documents([doc]))
            except Exception as e:
                logger.warning(f"Error splitting document {doc.metadata.get('source', 'unknown')}: {str(e)}")
        return split_docs

    def _store_embeddings(self, split_docs: List[Document]) -> bool:
        """Embed ``split_docs`` and replace all stored rows with them (blocking).

        Returns True once committed. On a database error the session is rolled
        back, the error is logged, and False is returned.
        """
        # Compute embeddings before opening the session so the delete/insert
        # transaction is not held open during model inference.
        contents = [doc.page_content for doc in split_docs]
        embeddings = self.embed_model.embed_documents(contents)
        with get_db_session() as db:
            try:
                # Replace, not append: clear every existing embedding first.
                db.query(DocumentEmbedding).delete()
                db.add_all([
                    DocumentEmbedding(
                        content=doc.page_content,
                        embedding=embedding,
                        doc_metadata=doc.metadata,
                        source=doc.metadata.get("source"),
                    )
                    for doc, embedding in zip(split_docs, embeddings)
                ])
                db.commit()
                logger.info(f"Successfully stored {len(split_docs)} embeddings in database")
                return True
            except Exception as e:
                db.rollback()
                logger.exception(f"Error storing embeddings in database: {str(e)}")
                return False

    async def get_contextual_response(self, question: str) -> str:
        """Answer ``question`` with the LLM, using retrieved chunks as context.

        Returns:
            The model's answer. Returns a fixed "not enough context" message
            when nothing is retrieved. Returns an ``"Error generating
            response: ..."`` string if any step fails; exceptions are not
            raised.
        """
        try:
            relevant_docs = await self.get_relevant_documents(question)
            if not relevant_docs:
                logger.warning("No relevant documents found")
                return "I don't have enough context to answer that question."
            context = self.format_docs(relevant_docs)
            # NOTE(review): RAG_SYSTEM_PROMPT presumably contains a {context}
            # placeholder filled by the "context" key below - confirm.
            prompt = ChatPromptTemplate.from_messages([
                ("system", RAG_SYSTEM_PROMPT),
                ("human", "{input}"),
            ])
            chain = (
                {"context": lambda _: context, "input": RunnablePassthrough()}
                | prompt
                | self.rag_llm
                | StrOutputParser()
            )
            response = await chain.ainvoke(question)
            logger.info("Successfully generated contextual response")
            return response
        except Exception as e:
            logger.exception(f"Error generating contextual response: {str(e)}")
            return f"Error generating response: {str(e)}"
# Module-wide shared instance. Constructing it loads the embedding model
# (RAGManager.__init__), so this happens at import time.
rag_manager: RAGManager = RAGManager()
logger.info("RAGManager singleton instance created")
# Module-level entry point kept for backward compatibility with callers that
# import this function instead of using ``rag_manager`` directly.
async def get_contextual_response(question: str) -> str:
    """Delegate ``question`` to the shared ``rag_manager`` and return its answer."""
    manager = rag_manager
    answer = await manager.get_contextual_response(question)
    return answer