Spaces:
Running
Running
Commit
·
42d0898
1
Parent(s):
07a5b4f
Production grade bot with re-index button
Browse files- src/config.py +9 -0
- src/ingest.py +20 -29
- src/rag.py +28 -36
- src/streamlit_app.py +15 -6
src/config.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# config.py
# Central configuration for the RAG pipeline (shared by ingest.py and rag.py).

# Directory where the persisted retrieval artifacts live.
INDEX_DIR = "src/index"
# Serialized FAISS vector index.
FAISS_INDEX_PATH = f"{INDEX_DIR}/faiss.index"
# Pickled list of document records, parallel to the FAISS vectors.
DOC_STORE_PATH = f"{INDEX_DIR}/documents.pkl"

# Hugging Face dataset supplying the source PDF(s).
HF_DATASET = "OnlyTheTruth03/ott"
HF_SPLIT = "train"

# Number of nearest-neighbour pages returned per query.
TOP_K = 4
src/ingest.py
CHANGED
|
@@ -1,57 +1,48 @@
|
|
| 1 |
-
#
|
| 2 |
import os
|
| 3 |
import pickle
|
| 4 |
-
import faiss
|
| 5 |
from datasets import load_dataset
|
| 6 |
from pypdf import PdfReader
|
| 7 |
from sentence_transformers import SentenceTransformer
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
INDEX_DIR = "src/index"
|
| 11 |
|
| 12 |
-
os.makedirs(INDEX_DIR, exist_ok=True)
|
| 13 |
|
| 14 |
-
|
|
|
|
| 15 |
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
print("π₯ Loading HF dataset...")
|
| 19 |
-
dataset = load_dataset(DATASET_NAME, split="train")
|
| 20 |
|
| 21 |
documents = []
|
|
|
|
| 22 |
|
| 23 |
for row in dataset:
|
| 24 |
-
# HF auto
|
| 25 |
-
|
| 26 |
|
| 27 |
-
|
| 28 |
-
pdf_path = pdf_obj.path
|
| 29 |
-
|
| 30 |
-
print(f"π Reading PDF from: {pdf_path}")
|
| 31 |
-
reader = PdfReader(pdf_path)
|
| 32 |
-
|
| 33 |
-
for page_no, page in enumerate(reader.pages, start=1):
|
| 34 |
text = page.extract_text()
|
| 35 |
if not text:
|
| 36 |
continue
|
| 37 |
|
| 38 |
documents.append({
|
| 39 |
-
"text": text
|
| 40 |
-
"page": page_no
|
|
|
|
| 41 |
})
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
raise RuntimeError("β No text extracted from PDF")
|
| 45 |
-
|
| 46 |
-
texts = [d["text"] for d in documents]
|
| 47 |
-
embeddings = embedder.encode(texts).astype("float32")
|
| 48 |
|
| 49 |
-
index = faiss.IndexFlatL2(
|
| 50 |
-
index.add(
|
| 51 |
|
| 52 |
-
faiss.write_index(index,
|
| 53 |
|
| 54 |
-
with open(
|
| 55 |
pickle.dump(documents, f)
|
| 56 |
|
| 57 |
print("β
FAISS index built successfully")
|
|
|
|
| 1 |
+
# ingest.py
|
| 2 |
import os
|
| 3 |
import pickle
|
|
|
|
| 4 |
from datasets import load_dataset
|
| 5 |
from pypdf import PdfReader
|
| 6 |
from sentence_transformers import SentenceTransformer
|
| 7 |
+
import faiss
|
| 8 |
|
| 9 |
+
from config import INDEX_DIR, FAISS_INDEX_PATH, DOC_STORE_PATH, HF_DATASET, HF_SPLIT
|
|
|
|
| 10 |
|
|
|
|
| 11 |
|
| 12 |
+
def build_index():
|
| 13 |
+
os.makedirs(INDEX_DIR, exist_ok=True)
|
| 14 |
|
| 15 |
+
dataset = load_dataset(HF_DATASET, split=HF_SPLIT)
|
| 16 |
|
| 17 |
+
embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
|
|
|
|
|
|
| 18 |
|
| 19 |
documents = []
|
| 20 |
+
embeddings = []
|
| 21 |
|
| 22 |
for row in dataset:
|
| 23 |
+
pdf_obj = row["pdf"] # HF auto column
|
| 24 |
+
reader = PdfReader(pdf_obj)
|
| 25 |
|
| 26 |
+
for page_no, page in enumerate(reader.pages):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
text = page.extract_text()
|
| 28 |
if not text:
|
| 29 |
continue
|
| 30 |
|
| 31 |
documents.append({
|
| 32 |
+
"text": text,
|
| 33 |
+
"page": page_no + 1,
|
| 34 |
+
"source": "dataset_pdf"
|
| 35 |
})
|
| 36 |
+
embeddings.append(text)
|
| 37 |
|
| 38 |
+
vectors = embedder.encode(embeddings, show_progress_bar=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
+
index = faiss.IndexFlatL2(vectors.shape[1])
|
| 41 |
+
index.add(vectors)
|
| 42 |
|
| 43 |
+
faiss.write_index(index, FAISS_INDEX_PATH)
|
| 44 |
|
| 45 |
+
with open(DOC_STORE_PATH, "wb") as f:
|
| 46 |
pickle.dump(documents, f)
|
| 47 |
|
| 48 |
print("β
FAISS index built successfully")
|
src/rag.py
CHANGED
|
@@ -1,69 +1,61 @@
|
|
| 1 |
-
#
|
| 2 |
import os
|
| 3 |
import pickle
|
| 4 |
import faiss
|
| 5 |
-
import numpy as np
|
| 6 |
-
from sentence_transformers import SentenceTransformer
|
| 7 |
from groq import Groq
|
| 8 |
-
from
|
| 9 |
|
| 10 |
-
|
| 11 |
-
TOP_K = 4
|
| 12 |
-
|
| 13 |
-
embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
| 14 |
|
| 15 |
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
|
| 16 |
if not GROQ_API_KEY:
|
| 17 |
-
raise RuntimeError("β GROQ_API_KEY not set
|
| 18 |
|
| 19 |
client = Groq(api_key=GROQ_API_KEY)
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
def load_index():
|
| 23 |
-
|
| 24 |
-
|
| 25 |
|
| 26 |
-
|
| 27 |
-
if not os.path.exists(index_path) or not os.path.exists(docs_path):
|
| 28 |
-
print("β οΈ FAISS index missing. Running ingestion...")
|
| 29 |
-
build_index()
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
with open(docs_path, "rb") as f:
|
| 34 |
documents = pickle.load(f)
|
| 35 |
|
| 36 |
return index, documents
|
| 37 |
|
| 38 |
|
| 39 |
-
def retrieve(query
|
| 40 |
index, documents = load_index()
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
results = []
|
| 46 |
-
for idx in indices[0]:
|
| 47 |
-
if idx == -1:
|
| 48 |
-
continue
|
| 49 |
-
results.append(documents[idx])
|
| 50 |
|
| 51 |
-
return
|
| 52 |
|
| 53 |
|
| 54 |
def ask_llm(query, contexts):
|
| 55 |
context_text = "\n\n".join(
|
| 56 |
-
f"
|
| 57 |
-
for c in contexts
|
| 58 |
)
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
response = client.chat.completions.create(
|
| 61 |
-
model="llama-3.1-
|
| 62 |
-
messages=[
|
| 63 |
-
|
| 64 |
-
{"role": "user", "content": f"{context_text}\n\nQuestion: {query}"}
|
| 65 |
-
],
|
| 66 |
-
temperature=0.2
|
| 67 |
)
|
| 68 |
|
| 69 |
return response.choices[0].message.content
|
|
|
|
| 1 |
+
# rag.py
|
| 2 |
import os
|
| 3 |
import pickle
|
| 4 |
import faiss
|
|
|
|
|
|
|
| 5 |
from groq import Groq
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
|
| 8 |
+
from config import FAISS_INDEX_PATH, DOC_STORE_PATH, TOP_K
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
|
| 11 |
if not GROQ_API_KEY:
|
| 12 |
+
raise RuntimeError("β GROQ_API_KEY not set")
|
| 13 |
|
| 14 |
client = Groq(api_key=GROQ_API_KEY)
|
| 15 |
+
embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
def load_index():
    """Load the FAISS index and its parallel document store from disk.

    Returns:
        tuple: (faiss index, list of document dicts) where documents[i]
        holds the metadata for vector i in the index.

    Raises:
        RuntimeError: if either artifact is missing (ingestion not run yet).
    """
    # Check BOTH artifacts: a partial build with only one of the two files
    # would otherwise die below with an opaque FileNotFoundError.
    if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(DOC_STORE_PATH):
        raise RuntimeError("❌ FAISS index not found. Run ingestion first.")

    index = faiss.read_index(FAISS_INDEX_PATH)

    with open(DOC_STORE_PATH, "rb") as f:
        documents = pickle.load(f)

    return index, documents
| 28 |
|
| 29 |
|
| 30 |
+
def retrieve(query):
|
| 31 |
index, documents = load_index()
|
| 32 |
|
| 33 |
+
q_vec = embedder.encode([query])
|
| 34 |
+
_, indices = index.search(q_vec, TOP_K)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
return [documents[i] for i in indices[0]]
|
| 37 |
|
| 38 |
|
| 39 |
def ask_llm(query, contexts):
|
| 40 |
context_text = "\n\n".join(
|
| 41 |
+
f"(Page {c['page']}): {c['text']}" for c in contexts
|
|
|
|
| 42 |
)
|
| 43 |
|
| 44 |
+
prompt = f"""
|
| 45 |
+
Answer the question using only the context below.
|
| 46 |
+
If the answer is not found, say so.
|
| 47 |
+
|
| 48 |
+
Context:
|
| 49 |
+
{context_text}
|
| 50 |
+
|
| 51 |
+
Question:
|
| 52 |
+
{query}
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
response = client.chat.completions.create(
|
| 56 |
+
model="llama-3.1-70b-versatile",
|
| 57 |
+
messages=[{"role": "user", "content": prompt}],
|
| 58 |
+
temperature=0.2,
|
|
|
|
|
|
|
|
|
|
| 59 |
)
|
| 60 |
|
| 61 |
return response.choices[0].message.content
|
src/streamlit_app.py
CHANGED
|
@@ -1,16 +1,25 @@
|
|
| 1 |
-
#
|
| 2 |
import streamlit as st
|
|
|
|
| 3 |
from rag import retrieve, ask_llm
|
| 4 |
|
| 5 |
-
st.
|
| 6 |
-
st.title("πͺ OTT Astrology Assistant")
|
| 7 |
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
if query:
|
| 11 |
-
with st.spinner("
|
| 12 |
contexts = retrieve(query)
|
| 13 |
answer = ask_llm(query, contexts)
|
| 14 |
|
| 15 |
-
st.markdown("###
|
| 16 |
st.write(answer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# streamlit_app.py — Streamlit UI for the OTT RAG bot.
import streamlit as st
from ingest import build_index
from rag import retrieve, ask_llm

st.title("🔮 OTT Knowledge Bot")

# Manual (re)build of the FAISS index from the HF dataset.
if st.button("🔄 Build / Rebuild Index"):
    with st.spinner("Building index..."):
        build_index()
    st.success("Index built successfully")

query = st.text_input("Ask a question")

if query:
    try:
        with st.spinner("Searching..."):
            contexts = retrieve(query)
            answer = ask_llm(query, contexts)
    except RuntimeError as err:
        # e.g. the index is missing because ingestion has not run yet —
        # show a friendly message instead of crashing with a traceback.
        st.error(str(err))
    else:
        st.markdown("### ✅ Answer")
        st.write(answer)

        with st.expander("📚 Sources"):
            for c in contexts:
                st.markdown(f"- Page {c['page']}")