"""
RAG (Retrieval-Augmented Generation) system with Gemini
"""
import google.generativeai as genai
import time
import logging
from typing import List, Dict, Any, Optional
from embeddings_qdrant import EmbeddingManager, QdrantVectorStore
from index_docs import extract_text_from_path, chunk_text
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class RAGSystem:
    """Complete RAG system with Gemini AI"""

    def __init__(self, gemini_api_key: str, qdrant_url: str, qdrant_api_key: str):
        """Initialize RAG system with Gemini and Qdrant"""
        # Configure Gemini
        genai.configure(api_key=gemini_api_key)
        self.model = genai.GenerativeModel("models/gemini-2.5-flash")

        # Initialize components
        self.embedding_manager = EmbeddingManager(gemini_api_key)

        # Try Qdrant Cloud first, fall back to the simple vector store
        try:
            self.vector_store = QdrantVectorStore(url=qdrant_url, api_key=qdrant_api_key)
            self.vector_store.create_collection(force_recreate=True)
            logger.info("✅ Connected to Qdrant Cloud")
            self.using_qdrant = True
        except Exception as e:
            logger.warning(f"❌ Qdrant Cloud connection failed: {e}")
            logger.info("🔄 Falling back to simple vector store")
            # NOTE: SimpleVectorStore is an assumed local fallback exposing the
            # same interface as QdrantVectorStore (module and class names are
            # illustrative); without it the fallback path has no store to call
            # create_collection() on.
            from simple_vector_store import SimpleVectorStore
            self.vector_store = SimpleVectorStore()
            self.vector_store.create_collection()
            self.using_qdrant = False

    def add_documents(self, file_paths: List[str], session_id: Optional[str] = None) -> bool:
        """Add documents to the vector store"""
        try:
            all_chunks = []
            for file_path in file_paths:
                logger.info(f"Processing {file_path}")

                # Extract text
                text = extract_text_from_path(file_path)
                if not text:
                    logger.warning(f"No text extracted from {file_path}")
                    continue

                # Chunk text
                chunks = chunk_text(text)

                # Add metadata
                for chunk in chunks:
                    all_chunks.append({
                        'text': chunk,
                        'source': file_path,
                        'chunk_id': len(all_chunks)
                    })

            if not all_chunks:
                logger.error("No chunks to process")
                return False

            # Generate embeddings
            logger.info(f"Generating embeddings for {len(all_chunks)} chunks")
            embeddings = []
            texts = []
            metadata_list = []

            for i, chunk in enumerate(all_chunks):
                try:
                    # Generate embedding
                    embedding = self.embedding_manager.generate_embedding(chunk['text'])
                    embeddings.append(embedding)
                    texts.append(chunk['text'])
                    metadata_list.append({
                        'source': chunk['source'],
                        'chunk_id': chunk['chunk_id']
                    })
                    logger.info(f"Generated embedding {i + 1}/{len(all_chunks)}")

                    # Small delay to avoid rate limits
                    time.sleep(0.1)
                except Exception as e:
                    logger.error(f"Error processing chunk {i}: {e}")
                    continue

            # Store all embeddings in the vector database
            if embeddings and texts:
                logger.info(f"Storing {len(embeddings)} embeddings in vector database (session={session_id})")
                # Forward session_id so it is stored with each point
                self.vector_store.add_documents(texts, embeddings, metadata_list, session_id=session_id)

            logger.info("Document processing completed successfully!")
            return True
        except Exception as e:
            logger.error(f"Error adding documents: {e}")
            return False

    def make_rag_prompt(self, query: str, context_passages: List[str]) -> str:
        """Create RAG prompt with the user's specified format"""
        context = "\n\n".join([f"Context {i + 1}: {passage}" for i, passage in enumerate(context_passages)])
        prompt = f"""You are a helpful assistant. Answer the user's question based on the provided context. If the context doesn't contain enough information to answer the question, say so clearly.

Context:
{context}

Question: {query}

Answer:"""
        return prompt
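
    # The rendered prompt looks like this (illustrative passage, shortened):
    #
    #   You are a helpful assistant. Answer the user's question based on the
    #   provided context. If the context doesn't contain enough information to
    #   answer the question, say so clearly.
    #
    #   Context:
    #   Context 1: The Eiffel Tower is 330 metres tall.
    #
    #   Question: How tall is the Eiffel Tower?
    #
    #   Answer: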

    def generate_answer(self, prompt: str, max_retries: int = 3) -> str:
        """Generate answer using Gemini with retry logic"""
        for attempt in range(max_retries):
            try:
                response = self.model.generate_content(prompt)
                if response and response.text:
                    return response.text.strip()
                else:
                    logger.warning(f"Empty response on attempt {attempt + 1}")
            except Exception as e:
                logger.error(f"Error generating answer (attempt {attempt + 1}): {e}")
                if "429" in str(e) or "quota" in str(e).lower():
                    if attempt < max_retries - 1:
                        wait_time = (2 ** attempt) * 2  # Exponential backoff: 2s, 4s, 8s, ...
                        logger.info(f"Rate limit hit, waiting {wait_time} seconds...")
                        time.sleep(wait_time)
                    else:
                        return "I'm sorry, I'm currently experiencing high demand. Please try again in a few minutes."
                elif attempt < max_retries - 1:
                    time.sleep(1)
                else:
                    return f"I encountered an error while generating the answer: {str(e)}"
        return "I'm sorry, I couldn't generate an answer at this time. Please try again."

    def query(self, question: str, top_k: int = 3) -> Dict[str, Any]:
        """Handle the complete RAG query process"""
        try:
            logger.info(f"Processing query: {question}")

            # Generate query embedding
            query_embedding = self.embedding_manager.generate_embedding(question)

            # Search for relevant passages
            search_results = self.vector_store.similarity_search(
                query_embedding=query_embedding,
                top_k=top_k
            )

            if not search_results:
                return {
                    'answer': "I couldn't find relevant information to answer your question.",
                    'sources': [],
                    'context_used': []
                }

            # Extract context passages and sources
            context_passages = [result.get('chunk', '') for result in search_results]
            sources = [result.get('metadata', {}).get('source', 'Unknown') for result in search_results]

            # Create RAG prompt
            rag_prompt = self.make_rag_prompt(question, context_passages)

            # Generate answer
            answer = self.generate_answer(rag_prompt)

            return {
                'answer': answer,
                'sources': list(set(sources)),  # Remove duplicates
                'context_used': context_passages
            }
        except Exception as e:
            logger.error(f"Error in query processing: {e}")
            return {
                'answer': f"I encountered an error while processing your question: {str(e)}",
                'sources': [],
                'context_used': []
            }


def handle_query(rag_system: RAGSystem, query: str) -> Dict[str, Any]:
    """Handle a single query through the RAG system"""
    return rag_system.query(query)
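

if __name__ == "__main__":
    # Minimal usage sketch. Assumes GEMINI_API_KEY, QDRANT_URL, and
    # QDRANT_API_KEY are set in the environment and that sample.pdf exists;
    # these names are illustrative, not defined elsewhere in this module.
    import os

    rag = RAGSystem(
        gemini_api_key=os.environ["GEMINI_API_KEY"],
        qdrant_url=os.environ["QDRANT_URL"],
        qdrant_api_key=os.environ["QDRANT_API_KEY"],
    )
    if rag.add_documents(["sample.pdf"]):
        result = handle_query(rag, "What is this document about?")
        print(result["answer"])
        print("Sources:", result["sources"])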