File size: 1,942 Bytes
a98b8cc
 
 
 
 
 
 
 
 
 
 
f0381b3
a98b8cc
 
 
 
 
 
 
 
 
 
 
 
 
f0381b3
a98b8cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f708f6
a98b8cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f708f6
a98b8cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import re
from sentence_transformers import SentenceTransformer
import chromadb
import google.generativeai as genai
import re


def load_documents(file_path, chunk_size=30):
    """Read a UTF-8 text file and split it into whitespace-joined chunks.

    Tokens are runs of word characters or single non-space characters
    (``\\w+|\\S``); every ``chunk_size`` consecutive tokens are joined with
    spaces into one chunk.

    Args:
        file_path: Path of the text file to read.
        chunk_size: Number of tokens per chunk (default 30).

    Returns:
        List of non-empty, stripped chunk strings.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        raw_text = handle.read()

    tokens = re.findall(r"\w+|\S", raw_text)

    pieces = []
    for start in range(0, len(tokens), chunk_size):
        piece = " ".join(tokens[start : start + chunk_size]).strip()
        if piece:
            pieces.append(piece)
    return pieces


def embed_documents(docs, model):
    """Encode each document into a dense embedding vector.

    Args:
        docs: List of text chunks to embed.
        model: Encoder exposing ``encode`` (e.g. a SentenceTransformer);
            its return value must provide ``tolist``.

    Returns:
        Embeddings as a plain list of lists of floats.
    """
    vectors = model.encode(docs)
    return vectors.tolist()


def build_chroma_db(docs, embeddings):
    """Create an in-memory Chroma collection populated with the documents.

    Args:
        docs: List of document strings.
        embeddings: Parallel list of embedding vectors, one per document.

    Returns:
        The populated chromadb collection named ``"rag_docs"``.
    """
    client = chromadb.Client()
    collection = client.get_or_create_collection("rag_docs")
    if docs:
        # Chroma's add() accepts batched lists: one call for the whole
        # corpus instead of one round trip per document.
        collection.add(
            documents=list(docs),
            embeddings=list(embeddings),
            ids=[str(i) for i in range(len(docs))],
        )
    return collection


def retrieve(query, collection, model, top_k=3):
    """Return the ``top_k`` stored documents most similar to ``query``.

    Args:
        query: Natural-language query string.
        collection: chromadb collection to search.
        model: Encoder used to embed the query (must match index-time model).
        top_k: Number of documents to fetch (default 3).

    Returns:
        List of retrieved document strings for this single query.
    """
    embedded = model.encode([query]).tolist()
    hits = collection.query(
        query_embeddings=[embedded[0]],
        n_results=top_k,
        include=["documents"],
    )
    return hits["documents"][0]


def call_gemini(query, context, api_key):
    """Ask Gemini to answer ``query`` grounded only in ``context``.

    Args:
        query: The user's question.
        context: Retrieved passages the answer must be based on.
        api_key: Google Generative AI API key.

    Returns:
        The model's answer text with surrounding whitespace stripped.
    """
    genai.configure(api_key=api_key)
    gemini = genai.GenerativeModel("gemini-2.0-flash")
    prompt = (
        "Answer the following question based only on the provided context.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n"
        "Answer:"
    )
    reply = gemini.generate_content(prompt)
    return reply.text.strip()


def build_rag_chain(
    api_key,
    data_path="data/info.txt",
    embed_model_name="all-MiniLM-L6-v2",
):
    """Build a question-answering closure over a local document corpus.

    Loads and chunks the corpus, embeds the chunks, indexes them in an
    in-memory Chroma collection, and returns a function that answers
    queries with Gemini using the retrieved context.

    Args:
        api_key: Google Generative AI API key forwarded to ``call_gemini``.
        data_path: Path of the corpus text file (default ``"data/info.txt"``,
            preserving the previous hard-coded behavior).
        embed_model_name: SentenceTransformer model name used for both
            document and query embeddings (default ``"all-MiniLM-L6-v2"``).

    Returns:
        ``rag_qa(query) -> (answer, retrieved_docs)``: the Gemini answer
        string plus the list of context chunks it was grounded in.
    """
    docs = load_documents(data_path)
    model = SentenceTransformer(embed_model_name)
    embeddings = embed_documents(docs, model)
    collection = build_chroma_db(docs, embeddings)

    def rag_qa(query):
        # Embed the query with the same model used at index time so the
        # vector spaces match.
        retrieved_docs = retrieve(query, collection, model)
        context = "\n\n".join(retrieved_docs)
        answer = call_gemini(query, context, api_key)
        return answer, retrieved_docs

    return rag_qa