File size: 2,613 Bytes
bc620e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import pickle
import faiss
import numpy as np
import google.generativeai as genai
import traceback
from dotenv import load_dotenv

# --- 1. Force Load API Key ---
# Load .env so GEMINI_API_KEY is available even when this module is imported
# before the main app configures the environment.
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

if not GEMINI_API_KEY:
    # Degrade gracefully: embedding calls in retrieve_context will fail later,
    # but the module stays importable.
    print("⚠️ WARNING: GEMINI_API_KEY not found in rag.py environment.")
else:
    genai.configure(api_key=GEMINI_API_KEY)

# Paths
# Vector store lives next to this file: a FAISS binary index plus a pickle
# holding the chunk texts (see initialize_rag for the expected pickle schema).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
VECTOR_STORE_DIR = os.path.join(BASE_DIR, "vector_store")
INDEX_PATH = os.path.join(VECTOR_STORE_DIR, "faiss_index.bin")
METADATA_PATH = os.path.join(VECTOR_STORE_DIR, "chunks_metadata.pkl")

# API Config
# NOTE(review): embedding dimension of this model must match faiss_index.d;
# retrieve_context checks this at query time.
EMBEDDING_MODEL = "models/text-embedding-004"

# Global Components
# Populated by initialize_rag(); faiss_index stays None (and retrieval is
# skipped) if loading fails.
faiss_index = None
chunks = []


def initialize_rag():
    """Load the FAISS index and chunk metadata from disk into module globals.

    Reads INDEX_PATH (FAISS binary) and METADATA_PATH (pickle with a
    'chunks' key holding the list of text chunks). On any failure the
    globals are reset to their unloaded state (faiss_index=None, chunks=[])
    so retrieve_context() cleanly skips retrieval instead of searching a
    half-initialized store.
    """
    global faiss_index, chunks

    print("--- RAG INITIALIZATION ---")
    if not os.path.exists(INDEX_PATH) or not os.path.exists(METADATA_PATH):
        print(f"CRITICAL: Index files not found at {VECTOR_STORE_DIR}")
        return

    try:
        faiss_index = faiss.read_index(INDEX_PATH)
        with open(METADATA_PATH, "rb") as f:
            data = pickle.load(f)
            chunks = data['chunks']
        print(f"✅ RAG Loaded. {len(chunks)} chunks indexed.")
    except Exception as e:
        # Roll back to a consistent "not loaded" state: without this, a
        # pickle failure after a successful read_index would leave a live
        # index with zero chunks, and retrieval would silently return [].
        faiss_index = None
        chunks = []
        print(f"❌ Error loading RAG files: {e}")


def retrieve_context(query: str, k: int = 2):
    """Retrieve up to *k* text chunks relevant to *query* via Gemini embeddings.

    Args:
        query: Natural-language search string.
        k: Number of nearest chunks to return (non-positive k yields []).

    Returns:
        list[str]: Matching chunk texts, best first; [] when the index is
        not loaded, on dimension mismatch, or on any API/search error.
    """
    # Identity check, not truthiness: a loaded-but-empty index object is
    # still truthy, so `not faiss_index` would only catch the None case
    # anyway — make that intent explicit.
    if faiss_index is None:
        print("⚠️ RAG Retrieval Skipped: Index not loaded.")
        return []

    if k <= 0:
        return []

    try:
        # 1. Embed the query via the Gemini API.
        result = genai.embed_content(
            model=EMBEDDING_MODEL,
            content=query,
            task_type="retrieval_query"
        )

        # 2. FAISS expects a 2-D float32 array of shape (1, dim).
        query_vec = np.array([result['embedding']]).astype("float32")

        # 3. Guard against an index built with a different embedding model.
        if faiss_index.d != query_vec.shape[1]:
            print(
                f"❌ DIMENSION MISMATCH: Index expects {faiss_index.d}, but Query is {query_vec.shape[1]}")
            print(
                "SOLUTION: Delete backend/vector_store and run create_vector_db.py again.")
            return []

        # 4. Nearest-neighbor search; FAISS pads missing hits with -1.
        distances, indices = faiss_index.search(query_vec, k)

        retrieved_text = []
        for i in indices[0]:
            # Skip padding (-1) and any index beyond the loaded chunk list
            # (possible if metadata and index files are out of sync).
            if i != -1 and i < len(chunks):
                retrieved_text.append(chunks[int(i)])

        return retrieved_text

    except Exception as e:
        print(f"❌ RAG ERROR: {e}")
        traceback.print_exc()  # Prints the full error to the terminal
        return []