Spaces:
Runtime error
Runtime error
Upload 8 files
Browse filesadded necessary files.
- .env +5 -0
- .gitignore +0 -0
- app.py +167 -0
- chain.py +111 -0
- ingest.py +89 -0
- kg_builder.py +145 -0
- requirements.txt +16 -0
- retriever.py +128 -0
.env
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SECURITY: live credentials were committed and published here. Every one of
# these secrets (HuggingFace token, Neo4j password, Groq key) must be rotated
# immediately. Load real values from deployment secrets / environment
# configuration — never commit a populated .env file.
HUGGINGFACEHUB_API_TOKEN=<set-via-deployment-secrets>
NEO4J_URI=neo4j+s://b5a2cd23.databases.neo4j.io
NEO4J_USERNAME=neo4j
NEO4J_PASSWORD=<set-via-deployment-secrets>
GROQ_API=<set-via-deployment-secrets>
|
.gitignore
ADDED
|
File without changes
|
app.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

import streamlit as st

from ingest import load_document, chunk_documents, store_in_chromadb, get_vectorstore
from kg_builder import build_kg
from chain import RAGChain

# ---------------------------------------------------------------------------
# Streamlit front-end for the KG-RAG chatbot.
# Flow: upload PDF -> ingest into ChromaDB -> build Neo4j knowledge graph ->
# chat against the hybrid (vector + graph) retriever via the Groq LLM.
# ---------------------------------------------------------------------------

# Page config
st.set_page_config(
    page_title="KG-RAG Chatbot",
    page_icon="🧠",
    layout="wide",
)

st.title("🧠 KG-RAG Chatbot")
st.caption("Upload a PDF → Build Knowledge Graph → Chat with your document")

# Sidebar: upload, processing, and display settings
with st.sidebar:
    st.header("Upload Document")

    uploaded_file = st.file_uploader("Upload a PDF", type=["pdf"])

    if uploaded_file:
        # Persist the upload under ./data so ingest.py / kg_builder.py can find it.
        os.makedirs("data", exist_ok=True)
        save_path = os.path.join("data", uploaded_file.name)

        with open(save_path, "wb") as f:
            f.write(uploaded_file.read())

        st.success(f"Saved: {uploaded_file.name}")

        # Process button
        if st.button("Process PDF", type="primary", use_container_width=True):
            # Reset any state left over from a previously processed document.
            st.session_state.messages = []
            st.session_state.chain_ready = False
            st.session_state.pop("chain", None)

            # Step 1: Ingest -> ChromaDB
            with st.status("Ingesting document...", expanded=True) as status:
                try:
                    st.write("Loading PDF pages...")
                    docs = load_document(uploaded_file.name)
                    st.write(f"{len(docs)} pages loaded")

                    st.write("Chunking text...")
                    chunks = chunk_documents(docs)
                    st.write(f"{len(chunks)} chunks created")

                    st.write("Embedding + saving to ChromaDB...")
                    store_in_chromadb(chunks)
                    st.write("ChromaDB ready")

                    # Step 2: Build KG -> Neo4j
                    st.write("Extracting triples → Neo4j Knowledge Graph...")
                    build_kg(uploaded_file.name)
                    st.write("Knowledge Graph built")

                    status.update(label="PDF processed! You can now chat.", state="complete")
                    st.session_state.chain_ready = True
                    st.session_state.current_doc = uploaded_file.name

                except Exception as e:
                    status.update(label="Processing failed", state="error")
                    st.error(str(e))

    st.divider()

    # Settings
    st.header("Settings")
    show_kg = st.toggle("Show KG facts", value=True)
    show_chunks = st.toggle("Show source chunks", value=False)

    st.divider()
    st.markdown("""
    **How it works:**
    1. Upload + Process your PDF
    2. **ChromaDB** stores semantic chunks
    3. **Neo4j** stores entity relationships
    4. **LLaMA via Groq** answers your questions
    """)

    if st.button("🗑️ Clear chat history", use_container_width=True):
        st.session_state.messages = []
        st.rerun()


def load_chain():
    """Build a fresh RAGChain (hybrid retriever + Groq client).

    BUG FIX: this was decorated with @st.cache_resource, which caches ONE
    RAGChain globally — across sessions and across newly processed PDFs.
    Popping st.session_state["chain"] during "Process PDF" did not clear
    that global cache, so the chain kept serving the previous document's
    vectorstore. Per-session caching is already done below via
    st.session_state, so the global cache is removed.
    """
    return RAGChain()


# Main area
if not st.session_state.get("chain_ready"):
    # Friendly landing state before any document has been processed.
    st.info("Upload a PDF from the sidebar and click **Process PDF** to get started.")

    col1, col2, col3 = st.columns(3)
    with col1:
        st.markdown("### Step 1\nUpload any PDF document from the sidebar")
    with col2:
        st.markdown("### Step 2\nClick **Process PDF** to build the knowledge graph")
    with col3:
        st.markdown("### Step 3\nAsk questions → get answers with KG + RAG context")

else:
    # Show which doc is loaded
    st.success(f"Active document: **{st.session_state.get('current_doc', 'Unknown')}**")

    # Load chain once per session (kept in session_state).
    try:
        if "chain" not in st.session_state:
            with st.spinner("Loading RAG chain..."):
                st.session_state.chain = load_chain()
        chain = st.session_state.chain
    except Exception as e:
        st.error(f"Failed to load chain: {e}")
        st.stop()

    # ── Chat history ──────────────────────────────────────────────────────────
    if "messages" not in st.session_state:
        st.session_state.messages = []

    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])
            if msg.get("kg_facts") and show_kg:
                with st.expander("Knowledge Graph Facts"):
                    st.code(msg["kg_facts"])
            if msg.get("sources") and show_chunks:
                with st.expander("Source Chunks"):
                    for s in msg["sources"]:
                        st.markdown(f"**Page {s['page']}:** {s['snippet']}...")

    # Chat input
    if question := st.chat_input("Ask anything about your document..."):
        st.session_state.messages.append({"role": "user", "content": question})
        with st.chat_message("user"):
            st.markdown(question)

        with st.chat_message("assistant"):
            with st.spinner("Retrieving from ChromaDB + Neo4j → asking LLM..."):
                try:
                    result = chain.ask(question)
                    answer = result["answer"]
                    st.markdown(answer)

                    if result["kg_facts"] and show_kg:
                        with st.expander("Knowledge Graph Facts used"):
                            st.code(result["kg_facts"])

                    if show_chunks:
                        with st.expander("Source Chunks"):
                            for s in result["sources"]:
                                st.markdown(f"**Page {s['page']}:** {s['snippet']}...")

                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": answer,
                        "kg_facts": result["kg_facts"],
                        "sources": result["sources"],
                    })

                except Exception as e:
                    st.error(f"Error: {e}")
|
chain.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
from huggingface_hub import InferenceClient
|
| 4 |
+
from retriever import HybridRetriever
|
| 5 |
+
from groq import Groq
|
| 6 |
+
|
| 7 |
+
load_dotenv()
|
| 8 |
+
|
| 9 |
+
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 10 |
+
HF_MODEL = "llama-3.1-8b-instant"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Prompt builder
# Assembles the final prompt sent to the LLM: the retrieved context
# (vector + graph facts) followed by the user's question, with strict
# instructions to answer ONLY from the provided context.
def build_prompt(context: str, question: str) -> str:
    """Return the grounded-QA prompt for the given context and question."""
    instructions = (
        "\n"
        "You are a helpful research assistant. Answer the question using ONLY the context provided.\n"
        'If the answer is not in the context, say "I don\'t have enough information to answer that."\n'
        "Be concise and factual.\n"
    )
    return f"{instructions}\nContext:\n{context}\n\nQuestion: {question}\n"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Main QA chain
# Runs the complete RAG pipeline:
#   1. Retrieve hybrid context (ChromaDB vectors + Neo4j graph facts)
#   2. Send the assembled prompt to llama-3.1-8b-instant via the Groq API
#   3. Return a structured answer with source attribution
class RAGChain:
    def __init__(self):
        """Wire up the hybrid retriever and the Groq chat client."""
        self.retriever = HybridRetriever()
        self.client = Groq(api_key=os.getenv("GROQ_API"))
        # FIX: the log line previously said "Mistral chain ready", but the
        # model actually served is llama-3.1 via Groq — misleading in logs.
        print("RAG chain ready")

    def ask(self, question: str, verbose: bool = False) -> dict:
        """
        Full RAG pipeline:
        question → hybrid retrieval → llama-3.1 (Groq) → answer

        Returns a dict with the answer plus KG facts and page-level
        sources for transparency.
        """
        # Step 1: Retrieve hybrid context
        retrieval = self.retriever.retrieve(question, k=4)
        context = retrieval["combined_context"]

        if verbose:
            print("\n Retrieved Context ")
            print(context[:800], "..." if len(context) > 800 else "")

        # Step 2: Generate
        prompt = build_prompt(context, question)

        response = self.client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=HF_MODEL,
            temperature=0.2,   # low temperature: factual, grounded answers
            max_tokens=300,
        )

        # Extract and tidy the text of the first (only) choice
        answer = response.choices[0].message.content.strip()

        return {
            "question": question,
            "answer": answer,
            "kg_facts": retrieval["kg_facts"],
            "sources": [
                {
                    "page": d.metadata.get("page", "?"),
                    "snippet": d.page_content[:150],
                }
                for d in retrieval["semantic_chunks"]
            ],
        }

    def close(self):
        """Release the retriever's Neo4j connection."""
        self.retriever.close()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# CLI test harness: interactive REPL over the RAG chain.
if __name__ == "__main__":
    chain = RAGChain()
    print("\n KG-RAG Chatbot (type 'exit' to quit)\n")

    while True:
        user_q = input("You: ").strip()
        if user_q.lower() in ("exit", "quit", "q"):
            break
        if not user_q:
            continue

        result = chain.ask(user_q, verbose=False)
        print(f"\n🤖 Answer:\n{result['answer']}")

        if result["kg_facts"]:
            print(f"\n📊 KG Facts used:\n{result['kg_facts']}")

        print(f"\n📄 Sources: pages {[s['page'] for s in result['sources']]}")
        print("─" * 50)

    chain.close()
|
ingest.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
from langchain_community.document_loaders import PyMuPDFLoader
|
| 4 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 5 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 6 |
+
from langchain_community.vectorstores import Chroma
|
| 7 |
+
|
| 8 |
+
load_dotenv()
|
| 9 |
+
|
| 10 |
+
# Config
|
| 11 |
+
DATA_DIR = "data"
|
| 12 |
+
CHROMA_DB_DIR = "chroma_db"
|
| 13 |
+
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
| 14 |
+
|
| 15 |
+
# Step 1: Load PDF
def load_document(filename: str):
    """Load `filename` from ./data with PyMuPDF; returns one Document per page.

    Raises FileNotFoundError when the file is not present under DATA_DIR.
    """
    path = os.path.join(DATA_DIR, filename)
    if not os.path.exists(path):
        raise FileNotFoundError(f"No file found at {path}. Drop your PDF inside the /data folder.")

    # BUG FIX: the original printed a literal placeholder ("(unknown)")
    # instead of interpolating the filename.
    print(f"Loading document: {filename}")
    loader = PyMuPDFLoader(path)
    docs = loader.load()
    print(f"Loaded {len(docs)} pages")
    return docs
|
| 26 |
+
|
| 27 |
+
# Step 2: Chunk the document
def chunk_documents(docs):
    """Split page documents into overlapping ~500-character chunks."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,       # characters per chunk
        chunk_overlap=50,     # overlap so context isn't lost at boundaries
        separators=["\n\n", "\n", ".", " "],
    )
    pieces = text_splitter.split_documents(docs)
    print(f"Split into {len(pieces)} chunks")
    return pieces
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Step 3: Embed + Store in ChromaDB
def store_in_chromadb(chunks):
    """Embed `chunks` with MiniLM and persist them into a local Chroma store."""
    print(f"Loading embedding model: {EMBED_MODEL}")
    embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)

    print(f"Storing chunks in ChromaDB at ./{CHROMA_DB_DIR} ...")
    store = Chroma.from_documents(
        documents=chunks,
        embedding=embedder,
        persist_directory=CHROMA_DB_DIR,
    )
    print(f"ChromaDB ready → {len(chunks)} chunks stored")
    return store
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Step 4: Load existing ChromaDB
def get_vectorstore():
    """Load an already-persisted ChromaDB (used by retriever.py)."""
    return Chroma(
        persist_directory=CHROMA_DB_DIR,
        embedding_function=HuggingFaceEmbeddings(model_name=EMBED_MODEL),
    )
|
| 61 |
+
|
| 62 |
+
# Step 4: Test retrieval
|
| 63 |
+
def test_retrieval(vectorstore, query: str = "What is this document about?"):
|
| 64 |
+
print(f"\nTest query: '{query}'")
|
| 65 |
+
results = vectorstore.similarity_search(query, k=3)
|
| 66 |
+
for i, r in enumerate(results):
|
| 67 |
+
print(f"\n--- Chunk {i+1} (page {r.metadata.get('page', '?')}) ---")
|
| 68 |
+
print(r.page_content[:300])
|
| 69 |
+
|
| 70 |
+
# Main: `python ingest.py mypaper.pdf`, or auto-detect the first PDF in ./data
if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        filename = sys.argv[1]
    else:
        pdfs = [f for f in os.listdir(DATA_DIR) if f.endswith(".pdf")]
        if not pdfs:
            print("No PDF found in /data. Add one and retry.")
            sys.exit(1)
        filename = pdfs[0]
        # BUG FIX: the original printed a literal placeholder ("(unknown)")
        # instead of the detected filename.
        print(f"Auto-detected: {filename}")

    docs = load_document(filename)
    chunks = chunk_documents(docs)
    vs = store_in_chromadb(chunks)
    test_retrieval(vs)
|
kg_builder.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, re, json
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
from neo4j import GraphDatabase
|
| 4 |
+
from ingest import load_document, chunk_documents
|
| 5 |
+
from groq import Groq
|
| 6 |
+
|
| 7 |
+
load_dotenv()
|
| 8 |
+
|
| 9 |
+
NEO4J_URI = os.getenv("NEO4J_URI")
|
| 10 |
+
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
|
| 11 |
+
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
|
| 12 |
+
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 13 |
+
GROK_API = os.getenv("GROQ_API")
|
| 14 |
+
|
| 15 |
+
# use llama-3.1-8b-instant via the Groq API
# BUG FIX: the original line ended with a trailing comma
# (`HF_MODEL = "llama-3.1-8b-instant",`), which made this a 1-tuple
# instead of a string.
HF_MODEL = "llama-3.1-8b-instant"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Knowledge graph builder: thin wrapper around the Neo4j driver
# for storing and querying (subject, relation, object) triples.
class KnowledgeGraph:
    def __init__(self):
        # The driver manages all communication between this Python process
        # and the Neo4j graph database.
        self.driver = GraphDatabase.driver(
            NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
        )
        print("Connected to Neo4j")

    def close(self):
        """Release the driver and its connection pool."""
        self.driver.close()

    def clear(self):
        """Reset the graph: delete every node and relationship."""
        with self.driver.session() as session:
            session.run("MATCH (n) DETACH DELETE n")
        print("Graph cleared")

    def insert_triple(self, subject: str, relation: str, obj: str):
        """Insert one triple (two entities + one relation).

        Uses MERGE so repeated inserts do not create duplicates.
        """
        cypher = """
        MERGE (a:Entity {name: $subject})
        MERGE (b:Entity {name: $obj})
        MERGE (a)-[r:RELATION {type: $relation}]->(b)
        """
        params = {
            "subject": subject.strip(),
            "relation": relation.strip(),
            "obj": obj.strip(),
        }
        with self.driver.session() as session:
            session.run(cypher, params)

    def query_entity(self, entity: str) -> list[dict]:
        """Return up to 10 triples touching any entity whose name contains `entity`."""
        cypher = """
        MATCH (a:Entity)-[r]->(b:Entity)
        WHERE toLower(a.name) CONTAINS toLower($entity)
           OR toLower(b.name) CONTAINS toLower($entity)
        RETURN a.name AS subject, r.type AS relation, b.name AS object
        LIMIT 10
        """
        with self.driver.session() as session:
            return [dict(record) for record in session.run(cypher, entity=entity)]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# Extract up to 8 factual (subject, relation, object) triples from `text`
# by asking llama-3.1 (via Groq) to emit a JSON array.
# NOTE: earlier comment said "Mistral 7B" — the model used is llama-3.1.
def extract_triples(text: str, client: "Groq") -> list[tuple]:
    """Return a list of (subject, relation, object) tuples, or [] on failure.

    The LLM is asked for a bare JSON array; the first bracketed span of the
    reply is parsed, and malformed entries (missing keys) are dropped.
    """
    prompt = f"""Extract factual (subject, relation, object) triples from the text below.
Return ONLY a JSON array like: [{{"subject":"X","relation":"Y","object":"Z"}}]
Do not add explanation. Max 8 triples.

Text:
{text}

JSON:"""

    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.1-8b-instant",
        temperature=0.2,
        max_tokens=300,
    )

    # Extract text from chat response
    response_text = chat_completion.choices[0].message.content

    # Parse the first JSON array in the response; be tolerant of chatter
    # around it and of malformed JSON.
    try:
        match = re.search(r'\[.*?\]', response_text, re.DOTALL)
        if match:
            triples_raw = json.loads(match.group())
            return [
                (t["subject"], t["relation"], t["object"])
                for t in triples_raw
                if all(k in t for k in ("subject", "relation", "object"))
            ]
    except Exception as e:
        print(f"Parse error: {e}")
    # No JSON array found, or parsing failed -> no triples.
    return []
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ── Main ──────────────────────────────────────────────────────────────────────
def build_kg(filename: str):
    """Build the Neo4j knowledge graph from a PDF under ./data.

    Loads and chunks the PDF, extracts triples from the first 20 chunks via
    the Groq LLM, and MERGEs them into Neo4j (after clearing the old graph).
    """
    # (removed: unused `import sys, os` and dead local DATA_DIR)
    docs = load_document(filename)
    chunks = chunk_documents(docs)

    # Only process the first 20 chunks to cap LLM cost/latency.
    chunks = chunks[:20]
    print(f"\n Extracting triples from {len(chunks)} chunks via llama-3.1 (Groq)")

    client = Groq(api_key=GROK_API)
    kg = KnowledgeGraph()
    try:
        kg.clear()

        total_triples = 0
        for i, chunk in enumerate(chunks):
            print(f" Chunk {i+1}/{len(chunks)} ...", end=" ")
            triples = extract_triples(chunk.page_content, client)
            for s, r, o in triples:
                kg.insert_triple(s, r, o)
            total_triples += len(triples)
            print(f"{len(triples)} triples")

        print(f"\n Knowledge Graph built — {total_triples} triples stored in Neo4j")
    finally:
        # Always release the Neo4j driver, even if extraction/insert raises.
        kg.close()
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# CLI entry: `python kg_builder.py mypaper.pdf`, or auto-pick the first
# PDF found under ./data.
if __name__ == "__main__":
    import sys, os

    DATA_DIR = "data"
    if len(sys.argv) > 1:
        target = sys.argv[1]
    else:
        found = [name for name in os.listdir(DATA_DIR) if name.endswith(".pdf")]
        if not found:
            print(" No PDF in /data.")
            sys.exit(1)
        target = found[0]

    build_kg(target)
|
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
langchain
langchain-community
langchain-text-splitters
langchain-huggingface
chromadb
sentence-transformers
neo4j
python-dotenv
pymupdf
streamlit
huggingface_hub
transformers
accelerate
spacy
# `groq` is imported by chain.py and kg_builder.py but was missing here,
# so a clean install crashed on `from groq import Groq`.
groq
|
| 16 |
+
|
retriever.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
from neo4j import GraphDatabase
|
| 4 |
+
from ingest import get_vectorstore
|
| 5 |
+
|
| 6 |
+
load_dotenv()
|
| 7 |
+
|
| 8 |
+
NEO4J_URI = os.getenv("NEO4J_URI")
|
| 9 |
+
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
|
| 10 |
+
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Neo4j retrieval
# Connects to Neo4j and retrieves structured knowledge (triples) based on
# entity keywords extracted from the user query. Returns formatted facts.
class Neo4jRetriever:
    def __init__(self):
        # Open a driver for the configured Neo4j instance.
        self.driver = GraphDatabase.driver(
            NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
        )

    def close(self):
        """Cleanup utility: close the database connection when done."""
        self.driver.close()

    def query(self, entity_keywords: list[str]) -> str:
        """
        Given a list of keywords, find triples whose subject OR object name
        contains any keyword (case-insensitive). Returns a formatted string
        of facts, or "" when nothing matches.
        """
        cypher = """
        MATCH (a:Entity)-[r]->(b:Entity)
        WHERE ANY(kw IN $keywords WHERE
              toLower(a.name) CONTAINS toLower(kw) OR
              toLower(b.name) CONTAINS toLower(kw))
        RETURN a.name AS subject, r.type AS relation, b.name AS object
        LIMIT 15
        """
        with self.driver.session() as session:
            rows = session.run(cypher, keywords=entity_keywords)
            facts = [
                f"{row['subject']} → {row['relation']} → {row['object']}"
                for row in rows
            ]

        if facts:
            return "Knowledge Graph Facts:\n" + "\n".join(facts)
        return ""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Keyword extractor (simple, no extra model needed)
def extract_keywords(query: str) -> list[str]:
    """
    Naive keyword extractor: strips punctuation, filters stopwords, and
    returns meaningful tokens. Good enough for KG lookup without spaCy/NER.
    """
    stopwords = {
        "what","is","are","the","a","an","of","in","on","at","to","for",
        "how","why","who","when","where","does","do","was","were","has",
        "have","had","be","been","being","and","or","but","with","from",
        "this","that","these","those","it","its","their","there","about",
        "can","could","would","should","will","tell","me","explain","give"
    }
    keywords = []
    for token in query.lower().split():
        word = token.strip("?.!,")
        # BUG FIX: strip punctuation BEFORE the stopword/length tests.
        # Previously "the?" slipped past the stopword filter and "is?"
        # passed the length check, polluting the KG lookup with noise.
        if word not in stopwords and len(word) > 2:
            keywords.append(word)
    return keywords
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# Hybrid Retriever
# Combines Vector Search (ChromaDB, semantic similarity) with Graph Search
# (Neo4j, logical connections) to give the LLM comprehensive context.
# This is the main retrieval engine of the RAG system.
class HybridRetriever:
    def __init__(self):
        print("Loading ChromaDB vectorstore...")
        self.vectorstore = get_vectorstore()
        self.neo4j = Neo4jRetriever()
        print("Hybrid retriever ready")

    def retrieve(self, query: str, k: int = 4) -> dict:
        """
        Returns a dict:
            semantic_chunks  — Documents from ChromaDB
            kg_facts         — formatted triples from Neo4j ("" if none)
            combined_context — merged string handed to the LLM
        """
        # 1. Semantic retrieval from ChromaDB
        docs = self.vectorstore.similarity_search(query, k=k)
        excerpt_text = "\n\n".join(doc.page_content for doc in docs)

        # 2. KG retrieval from Neo4j (keywords pulled from the query)
        kg_facts = self.neo4j.query(extract_keywords(query))

        # 3. Merge: graph facts first (when present), then document excerpts
        parts = []
        if kg_facts:
            parts.append(f"{kg_facts}\n\n")
        parts.append(f"Document Excerpts:\n{excerpt_text}")

        return {
            "semantic_chunks": docs,
            "kg_facts": kg_facts,
            "combined_context": "".join(parts),
        }

    def close(self):
        """Release the Neo4j connection held by the graph retriever."""
        self.neo4j.close()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ── Quick test ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    hybrid = HybridRetriever()
    user_query = input("Enter a test query: ")
    outcome = hybrid.retrieve(user_query)

    print("\n── KG Facts ──────────────────────────────")
    print(outcome["kg_facts"] or "(none found)")
    print("\n── Semantic Chunks ───────────────────────")
    for idx, doc in enumerate(outcome["semantic_chunks"]):
        print(f"\nChunk {idx+1}: {doc.page_content[:200]}")
    hybrid.close()
|