Rupesx007 committed on
Commit
73ad28e
Β·
verified Β·
1 Parent(s): 49000b3

Upload 8 files

Browse files

added necessary files.

Files changed (8) hide show
  1. .env +5 -0
  2. .gitignore +0 -0
  3. app.py +167 -0
  4. chain.py +111 -0
  5. ingest.py +89 -0
  6. kg_builder.py +145 -0
  7. requirements.txt +16 -0
  8. retriever.py +128 -0
.env ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ HUGGINGFACEHUB_API_TOKEN=<your-huggingface-token>
2
+ NEO4J_URI=neo4j+s://b5a2cd23.databases.neo4j.io
3
+ NEO4J_USERNAME=neo4j
4
+ NEO4J_PASSWORD=<your-neo4j-password>
5
+ GROQ_API=<your-groq-api-key>
.gitignore ADDED
File without changes
app.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os

import streamlit as st

from ingest import load_document, chunk_documents, store_in_chromadb, get_vectorstore
from kg_builder import build_kg
from chain import RAGChain

# NOTE(review): removed unused `sys` and `tempfile` imports from the original.

# ── Page config ───────────────────────────────────────────────────────────────
st.set_page_config(
    page_title="KG-RAG Chatbot",
    page_icon="🧠",
    layout="wide"
)

st.title("🧠 KG-RAG Chatbot")
st.caption("Upload a PDF → Build Knowledge Graph → Chat with your document")

# ── Sidebar: upload + processing + display settings ───────────────────────────
with st.sidebar:
    st.header("Upload Document")

    uploaded_file = st.file_uploader("Upload a PDF", type=["pdf"])

    if uploaded_file:
        # Persist the upload under ./data so ingest/kg_builder can find it by name.
        os.makedirs("data", exist_ok=True)
        save_path = os.path.join("data", uploaded_file.name)

        with open(save_path, "wb") as f:
            f.write(uploaded_file.read())

        st.success(f"Saved: {uploaded_file.name}")

        # Process button
        if st.button("Process PDF", type="primary", use_container_width=True):
            # Clear old chat + cached chain so answers come from the new document.
            st.session_state.messages = []
            st.session_state.chain_ready = False
            st.session_state.pop("chain", None)

            with st.status("Ingesting document...", expanded=True) as status:
                try:
                    # Step 1: Ingest → ChromaDB
                    st.write("Loading PDF pages...")
                    docs = load_document(uploaded_file.name)
                    st.write(f"{len(docs)} pages loaded")

                    st.write("Chunking text...")
                    chunks = chunk_documents(docs)
                    st.write(f"{len(chunks)} chunks created")

                    st.write("Embedding + saving to ChromaDB...")
                    store_in_chromadb(chunks)
                    st.write("ChromaDB ready")

                    # Step 2: Build KG → Neo4j
                    st.write("Extracting triples → Neo4j Knowledge Graph...")
                    build_kg(uploaded_file.name)
                    st.write("Knowledge Graph built")

                    status.update(label="PDF processed! You can now chat.", state="complete")
                    st.session_state.chain_ready = True
                    st.session_state.current_doc = uploaded_file.name

                except Exception as e:
                    status.update(label="Processing failed", state="error")
                    st.error(str(e))

    st.divider()

    # Display settings (read further down when rendering messages)
    st.header("Settings")
    show_kg = st.toggle("Show KG facts", value=True)
    show_chunks = st.toggle("Show source chunks", value=False)

    st.divider()
    st.markdown("""
**How it works:**
1. Upload + Process your PDF
2. **ChromaDB** stores semantic chunks
3. **Neo4j** stores entity relationships
4. **LLaMA via Groq** answers your questions
""")

    if st.button("🗑️ Clear chat history", use_container_width=True):
        st.session_state.messages = []
        st.rerun()


# Load RAG chain (only after PDF is processed); cached across Streamlit reruns.
@st.cache_resource
def load_chain():
    """Build the RAG chain once per session."""
    return RAGChain()


# ── Main area ─────────────────────────────────────────────────────────────────
if not st.session_state.get("chain_ready"):
    # Friendly landing state before any document has been processed.
    st.info("Upload a PDF from the sidebar and click **Process PDF** to get started.")

    col1, col2, col3 = st.columns(3)
    with col1:
        st.markdown("### Step 1\nUpload any PDF document from the sidebar")
    with col2:
        st.markdown("### Step 2\nClick **Process PDF** to build the knowledge graph")
    with col3:
        st.markdown("### Step 3\nAsk questions — get answers with KG + RAG context")

else:
    # Show which doc is loaded
    st.success(f"Active document: **{st.session_state.get('current_doc', 'Unknown')}**")

    # Load chain (cached); stop the app cleanly if it cannot be constructed.
    try:
        if "chain" not in st.session_state:
            with st.spinner("Loading RAG chain..."):
                st.session_state.chain = load_chain()
        chain = st.session_state.chain
    except Exception as e:
        st.error(f"Failed to load chain: {e}")
        st.stop()

    # ── Chat history ──────────────────────────────────────────────────────────
    if "messages" not in st.session_state:
        st.session_state.messages = []

    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])
            if msg.get("kg_facts") and show_kg:
                with st.expander("Knowledge Graph Facts"):
                    st.code(msg["kg_facts"])
            if msg.get("sources") and show_chunks:
                with st.expander("Source Chunks"):
                    for s in msg["sources"]:
                        st.markdown(f"**Page {s['page']}:** {s['snippet']}...")

    # Chat input
    if question := st.chat_input("Ask anything about your document..."):
        st.session_state.messages.append({"role": "user", "content": question})
        with st.chat_message("user"):
            st.markdown(question)

        with st.chat_message("assistant"):
            with st.spinner("Retrieving from ChromaDB + Neo4j → asking LLM..."):
                try:
                    result = chain.ask(question)
                    answer = result["answer"]
                    st.markdown(answer)

                    if result["kg_facts"] and show_kg:
                        with st.expander("Knowledge Graph Facts used"):
                            st.code(result["kg_facts"])

                    if show_chunks:
                        with st.expander("Source Chunks"):
                            for s in result["sources"]:
                                st.markdown(f"**Page {s['page']}:** {s['snippet']}...")

                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": answer,
                        "kg_facts": result["kg_facts"],
                        "sources": result["sources"]
                    })

                except Exception as e:
                    st.error(f"Error: {e}")
chain.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from huggingface_hub import InferenceClient
4
+ from retriever import HybridRetriever
5
+ from groq import Groq
6
+
7
+ load_dotenv()
8
+
9
+ HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
10
+ HF_MODEL = "llama-3.1-8b-instant"
11
+
12
+
# Prompt builder
# Constructs the final prompt sent to the LLM (llama-3.1 via Groq — the
# original comment claimed a Mistral [INST] format, which is not used here).
# Injects the retrieved context (vector + graph) and the user question,
# with strict instructions to only use the provided context.
def build_prompt(context: str, question: str) -> str:
    """Return the grounded-QA prompt embedding *context* and *question*."""
    return f"""
You are a helpful research assistant. Answer the question using ONLY the context provided.
If the answer is not in the context, say "I don't have enough information to answer that."
Be concise and factual.

Context:
{context}

Question: {question}
"""
28
+
29
+
# Main QA chain
# Runs the complete RAG pipeline:
#   1. Retrieve hybrid context (ChromaDB vectors + Neo4j graph facts)
#   2. Generate an answer with llama-3.1-8b-instant via the Groq API
#   3. Return a structured answer with source attribution
# (Original comments/log said "Mistral", but the model used is llama via Groq.)
class RAGChain:
    def __init__(self):
        # Hybrid retriever (ChromaDB + Neo4j) and the Groq chat client.
        self.retriever = HybridRetriever()
        self.client = Groq(api_key=os.getenv("GROQ_API"))
        print("RAG chain ready")

    def ask(self, question: str, verbose: bool = False) -> dict:
        """
        Full RAG pipeline:
        question → hybrid retrieval → LLM → answer

        Returns a dict with question, answer, kg_facts and sources
        (page + snippet per retrieved chunk) for transparency.
        """
        # Step 1: Retrieve
        retrieval = self.retriever.retrieve(question, k=4)
        context = retrieval["combined_context"]

        if verbose:
            # Print a preview of the context actually sent to the model.
            print("\n Retrieved Context ")
            print(context[:800], "..." if len(context) > 800 else "")

        # Step 2: Generate
        prompt = build_prompt(context, question)

        response = self.client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=HF_MODEL,
            temperature=0.2,   # low temperature: factual, grounded answers
            max_tokens=300
        )

        # Extract text from chat response and trim trailing artifacts.
        answer = response.choices[0].message.content.strip()

        return {
            "question": question,
            "answer": answer,
            "kg_facts": retrieval["kg_facts"],
            "sources": [
                {
                    "page": d.metadata.get("page", "?"),
                    "snippet": d.page_content[:150]
                }
                for d in retrieval["semantic_chunks"]
            ]
        }

    def close(self):
        """Release the underlying Neo4j connection held by the retriever."""
        self.retriever.close()
88
+
89
+
# CLI test
if __name__ == "__main__":
    chain = RAGChain()
    print("\n KG-RAG Chatbot (type 'exit' to quit)\n")

    while True:
        user_input = input("You: ").strip()
        # Exit commands first, then skip blank lines.
        if user_input.lower() in ["exit", "quit", "q"]:
            break
        if not user_input:
            continue

        reply = chain.ask(user_input, verbose=False)
        print(f"\n🤖 Answer:\n{reply['answer']}")

        if reply["kg_facts"]:
            print(f"\n📊 KG Facts used:\n{reply['kg_facts']}")

        print(f"\n📄 Sources: pages {[s['page'] for s in reply['sources']]}")
        print("─" * 50)

    chain.close()
ingest.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langchain_community.document_loaders import PyMuPDFLoader
4
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
5
+ from langchain_huggingface import HuggingFaceEmbeddings
6
+ from langchain_community.vectorstores import Chroma
7
+
8
+ load_dotenv()
9
+
10
+ # Config
11
+ DATA_DIR = "data"
12
+ CHROMA_DB_DIR = "chroma_db"
13
+ EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
14
+
# Step 1: Load PDF
def load_document(filename: str):
    """Load data/<filename> with PyMuPDF and return one Document per page.

    Raises FileNotFoundError when the file is not present in /data.
    """
    path = os.path.join(DATA_DIR, filename)
    if not os.path.exists(path):
        raise FileNotFoundError(f"No file found at {path}. Drop your PDF inside the /data folder.")

    # Fix: the original printed a literal "(unknown)" instead of the filename.
    print(f"Loading document: {filename}")
    loader = PyMuPDFLoader(path)
    docs = loader.load()
    print(f"Loaded {len(docs)} pages")
    return docs
26
+
# Step 2: Chunk the document
def chunk_documents(docs):
    """Split page-level documents into ~500-character overlapping chunks."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,      # characters per chunk
        chunk_overlap=50,    # overlap so context isn't lost at boundaries
        separators=["\n\n", "\n", ".", " "]
    )
    pieces = text_splitter.split_documents(docs)
    print(f"Split into {len(pieces)} chunks")
    return pieces
37
+
38
+
# Step 3: Embed + Store in ChromaDB
def store_in_chromadb(chunks):
    """Embed the chunks with MiniLM and persist them to the local Chroma store."""
    print(f"Loading embedding model: {EMBED_MODEL}")
    embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)

    print(f"Storing chunks in ChromaDB at ./{CHROMA_DB_DIR} ...")
    store = Chroma.from_documents(
        documents=chunks,
        embedding=embedder,
        persist_directory=CHROMA_DB_DIR
    )
    print(f"ChromaDB ready — {len(chunks)} chunks stored")
    return store
51
+
52
+
# Step 4: Load existing ChromaDB
def get_vectorstore():
    """Load an already-persisted ChromaDB (used by retriever.py)."""
    return Chroma(
        persist_directory=CHROMA_DB_DIR,
        embedding_function=HuggingFaceEmbeddings(model_name=EMBED_MODEL)
    )
61
+
# Step 4: Test retrieval
def test_retrieval(vectorstore, query: str = "What is this document about?"):
    """Sanity check: print the top-3 similar chunks for a sample query."""
    print(f"\nTest query: '{query}'")
    hits = vectorstore.similarity_search(query, k=3)
    for idx, hit in enumerate(hits, start=1):
        print(f"\n--- Chunk {idx} (page {hit.metadata.get('page', '?')}) ---")
        print(hit.page_content[:300])
69
+
# Main
if __name__ == "__main__":
    import sys

    # Pass filename as argument: python ingest.py mypaper.pdf
    # Or default to first PDF found in /data
    if len(sys.argv) > 1:
        filename = sys.argv[1]
    else:
        pdfs = [f for f in os.listdir(DATA_DIR) if f.endswith(".pdf")]
        if not pdfs:
            print("No PDF found in /data. Add one and retry.")
            sys.exit(1)
        filename = pdfs[0]
        # Fix: the original printed a literal "(unknown)" instead of the filename.
        print(f"Auto-detected: {filename}")

    docs = load_document(filename)
    chunks = chunk_documents(docs)
    vs = store_in_chromadb(chunks)
    test_retrieval(vs)
kg_builder.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, re, json
2
+ from dotenv import load_dotenv
3
+ from neo4j import GraphDatabase
4
+ from ingest import load_document, chunk_documents
5
+ from groq import Groq
6
+
load_dotenv()

# Connection and API credentials, all read from .env
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
GROK_API = os.getenv("GROQ_API")

# use llama-3.1-8b-instant via GROQ API
# BUG FIX: the original line ended with a trailing comma, which made
# HF_MODEL a 1-tuple ("llama-3.1-8b-instant",) instead of a string.
HF_MODEL = "llama-3.1-8b-instant"
17
+
18
+
# knowledge graph builder
class KnowledgeGraph:
    """Thin wrapper around the Neo4j driver for storing and querying triples."""

    def __init__(self):
        # Establish a connection with the Neo4j database; the driver manages
        # all communication between this code and the graph database.
        self.driver = GraphDatabase.driver(
            NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
        )
        print("Connected to Neo4j")

    def close(self):
        """Close the driver connection."""
        self.driver.close()

    def clear(self):
        """Reset the graph: delete every node and relationship."""
        with self.driver.session() as session:
            session.run("MATCH (n) DETACH DELETE n")
        print("Graph cleared")

    def insert_triple(self, subject: str, relation: str, obj: str):
        """Insert one (subject, relation, object) triple.

        MERGE is used for both entities and the relation so duplicates
        are only created if they don't already exist.
        """
        cypher = """
        MERGE (a:Entity {name: $subject})
        MERGE (b:Entity {name: $obj})
        MERGE (a)-[r:RELATION {type: $relation}]->(b)
        """
        with self.driver.session() as session:
            session.run(cypher, subject=subject.strip(), relation=relation.strip(), obj=obj.strip())

    def query_entity(self, entity: str) -> list[dict]:
        """Return up to 10 (subject, relation, object) triples touching *entity*.

        Matches case-insensitively on either end of the relation.
        """
        cypher = """
        MATCH (a:Entity)-[r]->(b:Entity)
        WHERE toLower(a.name) CONTAINS toLower($entity)
           OR toLower(b.name) CONTAINS toLower($entity)
        RETURN a.name AS subject, r.type AS relation, b.name AS object
        LIMIT 10
        """
        with self.driver.session() as session:
            rows = session.run(cypher, entity=entity)
            return [dict(row) for row in rows]
66
+
67
+
# Extract up to 8 (subject, relation, object) triples from text via the LLM.
# (Original comment said "Mistral 7B"; the model actually called is
# llama-3.1-8b-instant through the Groq API.) Returns a list of tuples.
def extract_triples(text: str, client: Groq) -> list[tuple]:
    """Ask the LLM for a JSON array of triples and parse it defensively.

    Returns [] when the response contains no parseable JSON array.
    """
    prompt = f"""Extract factual (subject, relation, object) triples from the text below.
Return ONLY a JSON array like: [{{"subject":"X","relation":"Y","object":"Z"}}]
Do not add explanation. Max 8 triples.

Text:
{text}

JSON:"""

    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.1-8b-instant",
        temperature=0.2,
        max_tokens=300
    )

    # Extract text from chat response
    response_text = chat_completion.choices[0].message.content

    # Parse JSON from response: find the first JSON array in the (possibly
    # chatty) reply, then keep only well-formed triples.
    try:
        match = re.search(r'\[.*?\]', response_text, re.DOTALL)
        if match:
            triples_raw = json.loads(match.group())
            return [
                (t["subject"], t["relation"], t["object"])
                for t in triples_raw
                if all(k in t for k in ["subject", "relation", "object"])
            ]
    except Exception as e:
        # Best-effort: a malformed reply yields no triples rather than a crash.
        print(f"Parse error: {e}")
    return []
101
+
102
+
# ── Main ──────────────────────────────────────────────────────────────────────
def build_kg(filename: str):
    """Build the Neo4j knowledge graph for data/<filename>.

    Loads and chunks the PDF, extracts triples from the first 20 chunks via
    llama-3.1 (Groq), clears the existing graph and stores the new triples.
    (Removed the original's unused `import sys, os` and local DATA_DIR;
    the misleading "via Mistral" log message is also corrected.)
    """
    docs = load_document(filename)
    chunks = chunk_documents(docs)

    # Only process first 20 chunks to bound API cost and latency.
    chunks = chunks[:20]
    print(f"\n Extracting triples from {len(chunks)} chunks via llama-3.1 (Groq)")

    client = Groq(api_key=GROK_API)
    kg = KnowledgeGraph()
    kg.clear()

    total_triples = 0
    for i, chunk in enumerate(chunks):
        print(f" Chunk {i+1}/{len(chunks)} ...", end=" ")
        triples = extract_triples(chunk.page_content, client)
        for s, r, o in triples:
            kg.insert_triple(s, r, o)
        total_triples += len(triples)
        print(f"{len(triples)} triples")

    print(f"\n Knowledge Graph built — {total_triples} triples stored in Neo4j")
    kg.close()
130
+
131
+
if __name__ == "__main__":
    import sys, os

    DATA_DIR = "data"
    # Use the filename passed on the command line, else the first PDF in /data.
    if len(sys.argv) > 1:
        target = sys.argv[1]
    else:
        candidates = [name for name in os.listdir(DATA_DIR) if name.endswith(".pdf")]
        if not candidates:
            print(" No PDF in /data.")
            sys.exit(1)
        target = candidates[0]

    build_kg(target)
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ langchain
3
+ langchain-community
4
+ langchain-text-splitters
5
+ langchain-huggingface
6
+ chromadb
7
+ sentence-transformers
8
+ neo4j
9
+ python-dotenv
10
+ pymupdf
11
+ streamlit
12
+ huggingface_hub
13
+ transformers
14
+ accelerate
15
+ spacy
16
+
retriever.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from neo4j import GraphDatabase
4
+ from ingest import get_vectorstore
5
+
6
+ load_dotenv()
7
+
8
+ NEO4J_URI = os.getenv("NEO4J_URI")
9
+ NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
10
+ NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
11
+
12
+
# Neo4j retrieval
# Connects to Neo4j and retrieves structured knowledge (triples) based on
# entity keywords extracted from the user query. Returns formatted facts.
class Neo4jRetriever:
    def __init__(self):
        # Establishes connection to Neo4j
        self.driver = GraphDatabase.driver(
            NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
        )

    def close(self):
        """Cleanup utility: closes the database connection when done."""
        self.driver.close()

    def query(self, entity_keywords: list[str]) -> str:
        """
        Given a list of keywords, find related triples in the graph.

        A triple matches when either its subject or its object contains any
        keyword (case-insensitive). Returns a formatted string of facts,
        or "" when nothing matches.
        """
        cypher = """
        MATCH (a:Entity)-[r]->(b:Entity)
        WHERE ANY(kw IN $keywords WHERE
              toLower(a.name) CONTAINS toLower(kw) OR
              toLower(b.name) CONTAINS toLower(kw))
        RETURN a.name AS subject, r.type AS relation, b.name AS object
        LIMIT 15
        """
        fact_lines = []
        with self.driver.session() as session:
            for record in session.run(cypher, keywords=entity_keywords):
                fact_lines.append(f"{record['subject']} → {record['relation']} → {record['object']}")

        if not fact_lines:
            return ""
        return "Knowledge Graph Facts:\n" + "\n".join(fact_lines)
52
+
53
+
# Keyword extractor (simple, no extra model needed)
def extract_keywords(query: str) -> list[str]:
    """
    Naive keyword extractor: lowercases, strips punctuation, filters stopwords.
    Good enough for KG lookup without needing spaCy/NER.

    BUG FIX: the original stripped punctuation AFTER the stopword/length
    checks, so tokens like "who," or "the." slipped past the stopword set
    and were returned as keywords. Punctuation is now stripped first.
    """
    stopwords = {
        "what","is","are","the","a","an","of","in","on","at","to","for",
        "how","why","who","when","where","does","do","was","were","has",
        "have","had","be","been","being","and","or","but","with","from",
        "this","that","these","those","it","its","their","there","about",
        "can","could","would","should","will","tell","me","explain","give"
    }
    cleaned = (tok.strip("?.!,") for tok in query.lower().split())
    return [t for t in cleaned if t not in stopwords and len(t) > 2]
70
+
71
+
# Hybrid Retriever
# Combines Vector Search (ChromaDB for semantic similarity) with Graph Search
# (Neo4j for logical connections) to provide comprehensive context to the LLM.
# This is the main retrieval engine of the RAG system.
class HybridRetriever:
    def __init__(self):
        print("Loading ChromaDB vectorstore...")
        self.vectorstore = get_vectorstore()
        self.neo4j = Neo4jRetriever()
        print("Hybrid retriever ready")

    def retrieve(self, query: str, k: int = 4) -> dict:
        """
        Returns:
            {
                "semantic_chunks": [...],  # from ChromaDB
                "kg_facts": "...",         # from Neo4j
                "combined_context": "..."  # merged string for LLM
            }
        """
        # 1. Semantic retrieval from ChromaDB
        chunk_docs = self.vectorstore.similarity_search(query, k=k)
        excerpt_text = "\n\n".join(doc.page_content for doc in chunk_docs)

        # 2. KG retrieval from Neo4j, keyed on keywords from the query
        graph_facts = self.neo4j.query(extract_keywords(query))

        # 3. Combine — KG facts first (when present), then document excerpts
        if graph_facts:
            merged = f"{graph_facts}\n\nDocument Excerpts:\n{excerpt_text}"
        else:
            merged = f"Document Excerpts:\n{excerpt_text}"

        return {
            "semantic_chunks": chunk_docs,
            "kg_facts": graph_facts,
            "combined_context": merged
        }

    def close(self):
        """Close the Neo4j connection held by the graph retriever."""
        self.neo4j.close()
115
+
116
+
# ── Quick test ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    hybrid = HybridRetriever()
    user_query = input("Enter a test query: ")
    output = hybrid.retrieve(user_query)

    print("\n── KG Facts ──────────────────────────────")
    print(output["kg_facts"] or "(none found)")
    print("\n── Semantic Chunks ───────────────────────")
    for idx, doc in enumerate(output["semantic_chunks"]):
        print(f"\nChunk {idx+1}: {doc.page_content[:200]}")
    hybrid.close()