tharu280's picture
Initial commit
bc620e9
import os
import pickle
import faiss
import numpy as np
import google.generativeai as genai
import traceback
from dotenv import load_dotenv
# --- 1. Force Load API Key ---
# Pull variables from a .env file (via python-dotenv) into the process
# environment before reading the key, so local development works without
# exporting it manually.
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if GEMINI_API_KEY:
    # Configure the Gemini client once, at import time, for every later call.
    genai.configure(api_key=GEMINI_API_KEY)
else:
    # Missing or empty key: warn but keep importing; API calls made later
    # will fail and be reported by their own error handling.
    print("⚠️ WARNING: GEMINI_API_KEY not found in rag.py environment.")
# Paths: the vector store lives next to this module, so lookups do not depend
# on the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
VECTOR_STORE_DIR = os.path.join(BASE_DIR, "vector_store")
INDEX_PATH, METADATA_PATH = (
    os.path.join(VECTOR_STORE_DIR, filename)
    for filename in ("faiss_index.bin", "chunks_metadata.pkl")
)
# API Config: Gemini embedding model used to embed queries.
EMBEDDING_MODEL = "models/text-embedding-004"
# Global Components, filled in by initialize_rag().
faiss_index = None  # FAISS index object; None until loaded
chunks = []  # metadata chunks, looked up by FAISS result id in retrieve_context
def initialize_rag():
    """Load the FAISS index and chunk metadata into the module globals.

    Reads the FAISS index from ``INDEX_PATH`` and a pickle from
    ``METADATA_PATH`` whose ``'chunks'`` entry holds the indexed chunks.
    Both globals (``faiss_index`` and ``chunks``) are replaced together, and
    only if everything loads successfully. If either file is missing, or
    loading fails, a message is printed and the globals are left as they were.
    Callers are never left with a new index paired with stale chunks.

    Returns:
        None. The outcome is reported via printed messages only.
    """
    global faiss_index, chunks
    print("--- RAG INITIALIZATION ---")
    if not os.path.exists(INDEX_PATH) or not os.path.exists(METADATA_PATH):
        print(f"CRITICAL: Index files not found at {VECTOR_STORE_DIR}")
        return
    try:
        # Load into locals first. A failure part-way through (corrupt pickle,
        # missing 'chunks' key) must not publish a half-updated state.
        index = faiss.read_index(INDEX_PATH)
        # NOTE(security): pickle.load can execute arbitrary code. Only load
        # metadata files produced by this project's own build step.
        with open(METADATA_PATH, "rb") as f:
            data = pickle.load(f)
        loaded_chunks = data['chunks']
        count = len(loaded_chunks)
    except Exception as e:
        print(f"❌ Error loading RAG files: {e}")
        traceback.print_exc()  # full stack, matching retrieve_context's reporting
        return
    faiss_index, chunks = index, loaded_chunks
    print(f"✅ RAG Loaded. {count} chunks indexed.")
def retrieve_context(query: str, k: int = 2):
    """Retrieve the text chunks nearest to ``query`` using Gemini embeddings.

    The query is embedded with ``EMBEDDING_MODEL`` (task type
    ``"retrieval_query"``). The embedding is then searched against the
    module-level FAISS index loaded by ``initialize_rag``.

    Args:
        query: Text to search for.
        k: Maximum number of chunks to return. Values below 1 return ``[]``
            immediately, without calling the embedding API.

    Returns:
        A list of entries from the module-level ``chunks`` list, in the order
        FAISS returned their ids. An empty list is returned in these cases:
        the index is not loaded, the embedding dimension does not match the
        index, or any error occurs. Errors are printed with a traceback rather
        than raised.
    """
    if faiss_index is None:
        print("⚠️ RAG Retrieval Skipped: Index not loaded.")
        return []
    if k < 1:
        # Nothing can be returned: skip the network call and never hand a
        # non-positive k to FAISS.
        return []
    try:
        # 1. Get embedding from API
        result = genai.embed_content(
            model=EMBEDDING_MODEL,
            content=query,
            task_type="retrieval_query",
        )
        # 2. FAISS searches take a 2-D float32 array: one row per query.
        query_vec = np.array([result['embedding']], dtype="float32")
        # 3. Guard against an index built with a different embedding model.
        if faiss_index.d != query_vec.shape[1]:
            print(
                f"❌ DIMENSION MISMATCH: Index expects {faiss_index.d}, but Query is {query_vec.shape[1]}")
            print(
                "SOLUTION: Delete backend/vector_store and run create_vector_db.py again.")
            return []
        # 4. Search FAISS (distances are not needed).
        _, indices = faiss_index.search(query_vec, k)
        # FAISS pads with -1 when fewer than k neighbours exist. Accepting
        # only in-range, non-negative ids ensures Python's negative indexing
        # can never silently return the wrong chunk.
        n_chunks = len(chunks)
        return [chunks[i] for i in indices[0] if 0 <= i < n_chunks]
    except Exception as e:
        print(f"❌ RAG ERROR: {e}")
        traceback.print_exc()  # Prints the full error to the terminal
        return []