| """ |
| RAG (Retrieval-Augmented Generation) system with Gemini |
| """ |
|
|
| import google.generativeai as genai |
| import time |
| import logging |
| from typing import List, Dict, Any, Optional |
| from embeddings_qdrant import EmbeddingManager, QdrantVectorStore |
| from index_docs import extract_text_from_path, chunk_text |
|
|
| |
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |
|
|
class RAGSystem:
    """Complete RAG system with Gemini AI"""

    def __init__(self, gemini_api_key: str, qdrant_url: str, qdrant_api_key: str):
        """Initialize RAG system with Gemini and Qdrant"""
        genai.configure(api_key=gemini_api_key)
        self.model = genai.GenerativeModel("models/gemini-2.5-flash")

        self.embedding_manager = EmbeddingManager(gemini_api_key)

        # Connect to Qdrant Cloud. The store must be constructed for the
        # system to work at all; only the collection setup has a retry path.
        self.vector_store = QdrantVectorStore(url=qdrant_url, api_key=qdrant_api_key)
        try:
            self.vector_store.create_collection(force_recreate=True)
            logger.info("✅ Connected to Qdrant Cloud")
            self.using_qdrant = True
        except Exception as e:
            logger.warning(f"❌ Qdrant collection recreation failed: {e}")
            # No separate local vector store exists in this module, so fall
            # back to creating the collection without force_recreate.
            logger.info("🔄 Retrying collection creation without force_recreate")
            self.vector_store.create_collection()
            self.using_qdrant = False

    def add_documents(self, file_paths: List[str], session_id: Optional[str] = None) -> bool:
        """Add documents to the vector store"""
        try:
            all_chunks = []

            for file_path in file_paths:
                logger.info(f"Processing {file_path}")

                text = extract_text_from_path(file_path)
                if not text:
                    logger.warning(f"No text extracted from {file_path}")
                    continue

                chunks = chunk_text(text)
                for chunk in chunks:
                    all_chunks.append({
                        'text': chunk,
                        'source': file_path,
                        'chunk_id': len(all_chunks)
                    })

            if not all_chunks:
                logger.error("No chunks to process")
                return False

            logger.info(f"Generating embeddings for {len(all_chunks)} chunks")

            embeddings = []
            texts = []
            metadata_list = []

            for i, chunk in enumerate(all_chunks):
                try:
                    embedding = self.embedding_manager.generate_embedding(chunk['text'])
                    embeddings.append(embedding)
                    texts.append(chunk['text'])
                    metadata_list.append({
                        'source': chunk['source'],
                        'chunk_id': chunk['chunk_id']
                    })
                    logger.info(f"Generated embedding {i + 1}/{len(all_chunks)}")

                    # Brief pause between requests to avoid hitting the
                    # embedding API's rate limits.
                    time.sleep(0.1)
                except Exception as e:
                    logger.error(f"Error processing chunk {i}: {e}")
                    continue

            if not embeddings:
                logger.error("No embeddings were generated; nothing to store")
                return False

            logger.info(f"Storing {len(embeddings)} embeddings in vector database (session={session_id})")
            self.vector_store.add_documents(texts, embeddings, metadata_list, session_id=session_id)

            logger.info("Document processing completed successfully!")
            return True

        except Exception as e:
            logger.error(f"Error adding documents: {e}")
            return False

    def make_rag_prompt(self, query: str, context_passages: List[str]) -> str:
        """Build the RAG prompt from the query and the retrieved context passages"""
        context = "\n\n".join(
            f"Context {i + 1}: {passage}" for i, passage in enumerate(context_passages)
        )

        prompt = f"""You are a helpful assistant. Answer the user's question based on the provided context. If the context doesn't contain enough information to answer the question, say so clearly.

Context:
{context}

Question: {query}

Answer:"""

        return prompt

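    # Illustrative example (not executed): for the question "What is RAG?"
    # and one retrieved passage "RAG combines retrieval with generation.",
    # make_rag_prompt produces roughly:
    #
    #   You are a helpful assistant. Answer the user's question based on the
    #   provided context. ... say so clearly.
    #
    #   Context:
    #   Context 1: RAG combines retrieval with generation.
    #
    #   Question: What is RAG?
    #
    #   Answer:
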
    def generate_answer(self, prompt: str, max_retries: int = 3) -> str:
        """Generate answer using Gemini with retry logic"""
        for attempt in range(max_retries):
            try:
                response = self.model.generate_content(prompt)
                if response and response.text:
                    return response.text.strip()
                logger.warning(f"Empty response on attempt {attempt + 1}")
            except Exception as e:
                logger.error(f"Error generating answer (attempt {attempt + 1}): {e}")

                if "429" in str(e) or "quota" in str(e).lower():
                    if attempt < max_retries - 1:
                        # Exponential backoff: wait 2s, 4s, 8s, ...
                        wait_time = (2 ** attempt) * 2
                        logger.info(f"Rate limit hit, waiting {wait_time} seconds...")
                        time.sleep(wait_time)
                    else:
                        return "I'm sorry, I'm currently experiencing high demand. Please try again in a few minutes."
                elif attempt < max_retries - 1:
                    time.sleep(1)
                else:
                    return f"I encountered an error while generating the answer: {str(e)}"

        return "I'm sorry, I couldn't generate an answer at this time. Please try again."

    def query(self, question: str, top_k: int = 3) -> Dict[str, Any]:
        """Handle the complete RAG query process"""
        try:
            logger.info(f"Processing query: {question}")

            # Embed the question and retrieve the most similar chunks.
            query_embedding = self.embedding_manager.generate_embedding(question)
            search_results = self.vector_store.similarity_search(
                query_embedding=query_embedding,
                top_k=top_k
            )

            if not search_results:
                return {
                    'answer': "I couldn't find relevant information to answer your question.",
                    'sources': [],
                    'context_used': []
                }

            context_passages = [result.get('chunk', '') for result in search_results]
            sources = [result.get('metadata', {}).get('source', 'Unknown') for result in search_results]

            # Build the prompt from the retrieved context and generate the answer.
            rag_prompt = self.make_rag_prompt(question, context_passages)
            answer = self.generate_answer(rag_prompt)

            return {
                'answer': answer,
                'sources': list(set(sources)),
                'context_used': context_passages
            }

        except Exception as e:
            logger.error(f"Error in query processing: {e}")
            return {
                'answer': f"I encountered an error while processing your question: {str(e)}",
                'sources': [],
                'context_used': []
            }


def handle_query(rag_system: RAGSystem, query: str) -> Dict[str, Any]:
    """Handle a single query through the RAG system"""
    return rag_system.query(query)
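

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions: GEMINI_API_KEY, QDRANT_URL, and
    # QDRANT_API_KEY are set in the environment, and "docs/sample.pdf" is a
    # placeholder path to one of your own documents.
    import os

    rag = RAGSystem(
        gemini_api_key=os.environ["GEMINI_API_KEY"],
        qdrant_url=os.environ["QDRANT_URL"],
        qdrant_api_key=os.environ["QDRANT_API_KEY"],
    )

    if rag.add_documents(["docs/sample.pdf"]):
        result = handle_query(rag, "What is this document about?")
        print("Answer:", result["answer"])
        print("Sources:", result["sources"])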