"""RAG retrieval helpers.

Loads a FAISS index plus pickled chunk metadata from ``vector_store/`` next to
this file, and retrieves stored chunks for a query using Gemini embeddings.
"""

import os
import pickle
import traceback

import faiss
import google.generativeai as genai
import numpy as np
from dotenv import load_dotenv

# --- 1. Force Load API Key ---
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

if not GEMINI_API_KEY:
    print("⚠️ WARNING: GEMINI_API_KEY not found in rag.py environment.")
else:
    genai.configure(api_key=GEMINI_API_KEY)

# Paths
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
VECTOR_STORE_DIR = os.path.join(BASE_DIR, "vector_store")
INDEX_PATH = os.path.join(VECTOR_STORE_DIR, "faiss_index.bin")
METADATA_PATH = os.path.join(VECTOR_STORE_DIR, "chunks_metadata.pkl")

# API Config
EMBEDDING_MODEL = "models/text-embedding-004"

# Global Components (populated by initialize_rag)
faiss_index = None
chunks = []


def initialize_rag() -> None:
    """Load the FAISS index and chunk metadata into the module globals.

    Prints a diagnostic and returns without raising when the files are
    missing or cannot be loaded. The globals ``faiss_index`` and ``chunks``
    are only updated when BOTH files load successfully, so a failed load
    never leaves an index paired with missing chunk data.
    """
    global faiss_index, chunks

    print("--- RAG INITIALIZATION ---")

    if not os.path.exists(INDEX_PATH) or not os.path.exists(METADATA_PATH):
        print(f"CRITICAL: Index files not found at {VECTOR_STORE_DIR}")
        return

    try:
        loaded_index = faiss.read_index(INDEX_PATH)
        # NOTE: pickle.load can execute arbitrary code from a crafted file;
        # only load metadata files written by trusted code.
        with open(METADATA_PATH, "rb") as f:
            data = pickle.load(f)
        loaded_chunks = data['chunks']
    except Exception as e:
        # Best-effort startup: report the failure and leave the globals as
        # they were so retrieve_context keeps skipping retrieval.
        print(f"❌ Error loading RAG files: {e}")
        return

    # Publish both pieces together only after everything succeeded.
    faiss_index = loaded_index
    chunks = loaded_chunks
    print(f"✅ RAG Loaded. {len(chunks)} chunks indexed.")


def retrieve_context(query: str, k: int = 2) -> list:
    """Retrieves text chunks using Gemini Embeddings.

    Args:
        query: Text to embed and search for.
        k: Maximum number of nearest neighbours to request from the index.

    Returns:
        The entries of ``chunks`` for the neighbours found, in the order the
        index returned them. An empty list is returned when the index is not
        loaded, the query is empty, ``k`` is not positive, the embedding
        dimension does not match the index, or any error occurs.
    """
    if faiss_index is None:
        print("⚠️ RAG Retrieval Skipped: Index not loaded.")
        return []

    # Nothing useful can be retrieved; skip the remote embedding call.
    if not query or k <= 0:
        print("⚠️ RAG Retrieval Skipped: Empty query or non-positive k.")
        return []

    try:
        # 1. Get embedding from API
        result = genai.embed_content(
            model=EMBEDDING_MODEL,
            content=query,
            task_type="retrieval_query",
        )

        # 2. Convert to Numpy: a single-row float32 matrix for the search call
        query_vec = np.array([result['embedding']]).astype("float32")

        # 3. Check Dimensions (Debug Step)
        if faiss_index.d != query_vec.shape[1]:
            print(
                f"❌ DIMENSION MISMATCH: Index expects {faiss_index.d}, but Query is {query_vec.shape[1]}")
            print(
                "SOLUTION: Delete backend/vector_store and run create_vector_db.py again.")
            return []

        # 4. Search FAISS
        _, indices = faiss_index.search(query_vec, k)

        # Keep only labels that address a stored chunk; this drops the -1
        # placeholder and anything else outside the chunk list.
        return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]

    except Exception as e:
        print(f"❌ RAG ERROR: {e}")
        traceback.print_exc()  # Prints the full error to the terminal
        return []