File size: 4,601 Bytes
41ac698
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# ─────────────────────────────────────────────────────────────
# app/rag_app.py
# Main RAG application — runs locally, calls HF for everything
# ─────────────────────────────────────────────────────────────

import os
import sys

# Populate the process environment from a local .env file (python-dotenv),
# so the os.getenv() lookups in the config section can see those values.
from dotenv import load_dotenv
load_dotenv()

# The project root is the parent of this file's directory; putting it on
# sys.path lets the `utils.*` imports below resolve when this file is run
# directly as a script. It is appended, so earlier sys.path entries win.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(_PROJECT_ROOT)

from utils.embedder  import HFEmbedder
from utils.retriever import FAISSRetriever
from utils.generator import HFGenerator


# ── Config ────────────────────────────────────────────────────
# Every setting can be overridden through the environment (including values
# loaded by load_dotenv() above); the second os.getenv argument is the default.
# Relative paths are resolved against the current working directory.
DOCS_PATH        = os.getenv("DOCS_PATH",        "data/sample_docs.txt")
FAISS_INDEX_PATH = os.getenv("FAISS_INDEX_PATH", "vector_store/index.faiss")
# Number of chunks retrieved per question.
TOP_K            = int(os.getenv("TOP_K", "3"))
if TOP_K < 1:
    raise ValueError(f"TOP_K must be a positive integer, got {TOP_K}")


# ── Load documents ────────────────────────────────────────────
def load_documents(path: str) -> list:
    """Read one document per non-blank line of the UTF-8 text file at *path*.

    Leading/trailing whitespace is stripped from every line and lines that
    are empty after stripping are skipped.

    Args:
        path: location of the documents file.

    Returns:
        The stripped, non-empty lines, in file order.

    Raises:
        FileNotFoundError: if no file exists at *path*.
    """
    try:
        # Explicit encoding: the locale default (e.g. cp1252 on Windows)
        # would mis-decode or reject UTF-8 documents.
        with open(path, encoding="utf-8") as f:
            # Strip each line once, keep only the non-empty results.
            docs = [text for text in (line.strip() for line in f) if text]
    except FileNotFoundError as err:
        raise FileNotFoundError(f"No documents found at {path}") from err
    print(f"Loaded {len(docs)} documents from {path}")
    return docs


# ── Build or load index ───────────────────────────────────────
def setup_retriever(embedder: HFEmbedder, force_rebuild: bool = False) -> FAISSRetriever:
    """Return a FAISSRetriever bound to FAISS_INDEX_PATH, ready to search.

    If a file already exists at FAISS_INDEX_PATH and *force_rebuild* is
    false, the stored index is loaded. Otherwise the documents in DOCS_PATH
    are embedded with *embedder*, indexed, and persisted via
    ``retriever.save()``.
    """
    retriever = FAISSRetriever(FAISS_INDEX_PATH)
    needs_build = not os.path.exists(FAISS_INDEX_PATH) or force_rebuild

    if not needs_build:
        # Reuse the index produced by an earlier run.
        print("Loading existing FAISS index...")
        retriever.load()
        return retriever

    print("Building new FAISS index...")
    corpus = load_documents(DOCS_PATH)
    vectors = embedder.embed_batch(corpus)
    retriever.build(corpus, vectors)
    retriever.save()
    return retriever


# ── Main RAG pipeline ─────────────────────────────────────────
class RAGPipeline:
    """Question answering grounded in the local document set.

    Combines an ``HFEmbedder`` (embeds the question), a ``FAISSRetriever``
    (searches the FAISS index with that embedding) and an ``HFGenerator``
    (produces the answer from the question and the retrieved chunks).
    """

    def __init__(self, force_rebuild: bool = False):
        """Create all pipeline components, printing progress to stdout.

        Args:
            force_rebuild: build a fresh FAISS index from DOCS_PATH even if an
                index file already exists at FAISS_INDEX_PATH.
        """
        print("\n" + "=" * 55)
        print("  RAG Pipeline — Your Own HF Model")
        print("=" * 55)

        # Initialize components; the retriever needs the embedder in case the
        # index has to be (re)built.
        self.embedder  = HFEmbedder()
        self.retriever = setup_retriever(self.embedder, force_rebuild)
        self.generator = HFGenerator()
        print("\nAll components ready!\n")

    def ask(self, question: str, verbose: bool = True, top_k=None) -> dict:
        """Ask a question and get an answer grounded in your documents.

        Args:
            question: the user's question.
            verbose: print the question, a 60-character preview of each
                retrieved chunk, and the answer.
            top_k: number of chunks to retrieve; ``None`` (the default) uses
                the module-level TOP_K setting.

        Returns:
            A dict with keys ``"question"`` (the input), ``"answer"`` (the
            generator's output) and ``"sources"`` (the text of each retrieved
            chunk, in the order the retriever returned them).
        """
        if top_k is None:
            top_k = TOP_K

        if verbose:
            print(f"Question : {question}")

        # Step 1: Embed query
        query_vec = self.embedder.embed(question)

        # Step 2: Retrieve relevant chunks (each chunk is a dict with a "text" key)
        chunks = self.retriever.search(query_vec, top_k=top_k)

        if verbose:
            print(f"Retrieved : {[c['text'][:60] for c in chunks]}")

        # Step 3: Generate answer
        answer = self.generator.generate(question, chunks)

        if verbose:
            print(f"Answer   : {answer}\n")

        return {
            "question": question,
            "answer"  : answer,
            "sources" : [c["text"] for c in chunks],
        }


# ── Run interactively ─────────────────────────────────────────
if __name__ == "__main__":
    rag = RAGPipeline()

    # Demo questions, answered once before interactive mode starts.
    demo_questions = [
        "What is the refund policy?",
        "How do I reset my password?",
        "When can I contact support?",
        "How long can I return a product?",
    ]

    print("=" * 55)
    print("  Demo Questions")
    print("=" * 55)

    for q in demo_questions:
        result = rag.ask(q)
        print(f"Q: {result['question']}")
        print(f"A: {result['answer']}")
        print("-" * 55)

    # Interactive mode
    print("\nInteractive mode — type your question (or 'quit' to exit)")
    while True:
        try:
            user_input = input("\nYou: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C, or stdin exhausted (e.g. piped input): exit
            # cleanly instead of crashing with a traceback.
            print("\nGoodbye!")
            break
        if user_input.lower() in {"quit", "exit", "q"}:
            print("Goodbye!")
            break
        if user_input:
            result = rag.ask(user_input)
            print(f"Bot: {result['answer']}")