import os
import pickle
import faiss
import numpy as np
import google.generativeai as genai
import traceback
from dotenv import load_dotenv
# --- 1. Force Load API Key ---
# Pull variables from a local .env file into the process environment so the
# key is present even when the host did not export it.
load_dotenv()

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    print("⚠️ WARNING: GEMINI_API_KEY not found in rag.py environment.")

# Paths — the vector store lives next to this file under vector_store/.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
VECTOR_STORE_DIR = os.path.join(BASE_DIR, "vector_store")
INDEX_PATH = os.path.join(VECTOR_STORE_DIR, "faiss_index.bin")
METADATA_PATH = os.path.join(VECTOR_STORE_DIR, "chunks_metadata.pkl")

# API Config
EMBEDDING_MODEL = "models/text-embedding-004"

# Global Components — populated by initialize_rag(); empty until then.
faiss_index = None
chunks = []
def initialize_rag():
    """Load the FAISS index and chunk metadata from disk into module globals.

    Intended to be called once at startup. If the files are missing or the
    load fails, the globals are left (or reset) in the "not loaded" state
    (faiss_index=None, chunks=[]) so retrieve_context() can detect that RAG
    is unavailable instead of operating on half-loaded data.
    """
    global faiss_index, chunks
    print("--- RAG INITIALIZATION ---")
    if not os.path.exists(INDEX_PATH) or not os.path.exists(METADATA_PATH):
        print(f"CRITICAL: Index files not found at {VECTOR_STORE_DIR}")
        return
    try:
        faiss_index = faiss.read_index(INDEX_PATH)
        with open(METADATA_PATH, "rb") as f:
            data = pickle.load(f)
        chunks = data['chunks']
        print(f"✅ RAG Loaded. {len(chunks)} chunks indexed.")
    except Exception as e:
        # FIX: a partial load (index read OK, metadata read failed) used to
        # leave faiss_index set while chunks stayed empty — roll both globals
        # back to the consistent "not loaded" state before reporting.
        faiss_index = None
        chunks = []
        print(f"❌ Error loading RAG files: {e}")
def retrieve_context(query: str, k: int = 2):
    """Retrieves text chunks using Gemini Embeddings."""
    if not faiss_index:
        print("⚠️ RAG Retrieval Skipped: Index not loaded.")
        return []
    try:
        # Embed the query through the Gemini API.
        response = genai.embed_content(
            model=EMBEDDING_MODEL,
            content=query,
            task_type="retrieval_query"
        )

        # FAISS expects a float32 matrix of shape (1, dim).
        embedding = np.array([response['embedding']]).astype("float32")

        # Guard against an index that was built with a different embedding
        # model (dimension mismatch makes search results meaningless).
        expected_dim = faiss_index.d
        got_dim = embedding.shape[1]
        if expected_dim != got_dim:
            print(
                f"❌ DIMENSION MISMATCH: Index expects {expected_dim}, but Query is {got_dim}")
            print(
                "SOLUTION: Delete backend/vector_store and run create_vector_db.py again.")
            return []

        # Nearest-neighbour search; FAISS pads missing results with -1.
        _, neighbour_ids = faiss_index.search(embedding, k)
        return [chunks[idx]
                for idx in neighbour_ids[0]
                if idx != -1 and idx < len(chunks)]
    except Exception as e:
        print(f"❌ RAG ERROR: {e}")
        traceback.print_exc()  # Full stack trace to the terminal
        return []