File size: 5,213 Bytes
925e68e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# app.py
import os
import gradio as gr
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
from langchain_community.graphs import Neo4jGraph
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Neo4jVector
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnableLambda
from langchain_core.pydantic_v1 import BaseModel, Field
# --- API & DB Setup ---
# SECURITY NOTE (review): these credentials are committed in source control and
# must be rotated. They are kept only as *fallbacks* so the app still runs out
# of the box; real deployments should supply GROQ_API_KEY / NEO4J_URI /
# NEO4J_USERNAME / NEO4J_PASSWORD via environment variables.
os.environ.setdefault("GROQ_API_KEY", "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")
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j+s://491b8299.databases.neo4j.io")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "W3i8UiePw9QyaSJxK9l_apbzUnzh10YWxZQtnpSS02I")

# Shared clients: Neo4j graph handle, Groq-hosted Llama 3 chat model,
# sentence-transformer embeddings, and the LLM-based graph extractor.
graph = Neo4jGraph(url=NEO4J_URI, username=NEO4J_USERNAME, password=NEO4J_PASSWORD)
llm = ChatGroq(model="llama3-8b-8192")
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
llm_transformer = LLMGraphTransformer(llm=llm)
# --- Entity Extraction Schema ---
# Pydantic schema used with `llm.with_structured_output`: the LLM fills
# `names` with every person/organization/business mentioned in the question.
# NOTE(review): intentionally no class docstring — pydantic would expose it
# as the schema description sent to the model, changing the prompt.
class Entities(BaseModel):
    names: list[str] = Field(..., description="All person, org, or business names")
# Prompt + structured-output parser that pulls named entities out of a
# user question; the reply is coerced into the `Entities` schema.
_entity_messages = [
    ("system", "you are extracting organization and person entities from the text"),
    ("human", "Use the given format to extract entities:\ninput: {question}"),
]
entity_prompt = ChatPromptTemplate.from_messages(_entity_messages)
entity_chain = entity_prompt | llm.with_structured_output(Entities)
# --- Helpers ---
def generate_full_text_query(input: str) -> str:
    """Build a Lucene full-text query string from free text.

    Lucene special characters are stripped, then each remaining word is
    given a ``~2`` fuzzy-match suffix (edit distance 2) and the words are
    AND-ed together. Returns "" for empty/whitespace-only input.

    Note: the parameter name shadows the ``input`` builtin; kept for
    backward compatibility with existing callers.
    """
    # str.split() with no separator never yields empty tokens, so no
    # additional filtering is needed.
    words = remove_lucene_chars(input).split()
    return " AND ".join(f"{word}~2" for word in words)
def structured_retriever(question: str) -> str:
    """Return graph-structured context for *question*.

    Extracts entities from the question via ``entity_chain``, then for each
    entity runs a full-text lookup against the Neo4j ``entity`` index and
    collects the (node)-[rel]->(neighbor) edges around each match
    (MENTIONS edges excluded). Sections for different entities are joined
    with newlines; returns "" when no entities are found.
    """
    entities = entity_chain.invoke({"question": question})
    sections = []
    for entity in entities.names:
        cypher = """
        CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
        YIELD node,score
        CALL {
            WITH node
            MATCH (node)-[r:!MENTIONS]->(neighbor)
            RETURN node.id + '-' + type(r) + '->' + neighbor.id AS output
            UNION ALL
            WITH node
            MATCH (node)<-[r:!MENTIONS]-(neighbor)
            RETURN neighbor.id + '-' + type(r) + '->' + node.id AS output
        }
        RETURN output LIMIT 50
        """
        response = graph.query(cypher, {"query": generate_full_text_query(entity)})
        sections.append("\n".join(el["output"] for el in response))
    # Bug fix: the original concatenated sections with no separator, gluing
    # the last edge of one entity to the first edge of the next.
    return "\n".join(sections)
def retriever(question: str) -> str:
    """Assemble combined context: graph edges plus similar document chunks."""
    graph_context = structured_retriever(question)
    # NOTE(review): relies on the module-level `vector_index` being set by
    # process_pdf() before this runs.
    similar_docs = vector_index.similarity_search(question)
    doc_context = "\n---\n".join(doc.page_content for doc in similar_docs)
    return f"Structured Data:\n{graph_context}\n\nUnstructured Data:\n{doc_context}"
# --- RAG Chain ---
template = """Answer the question based only on the context:
{context}
Question: {question}
Use natural language and be concise.
Answer:"""
qa_prompt = ChatPromptTemplate.from_template(template)

# Fan the input dict out into retrieved context plus the pass-through
# question, then run prompt -> LLM -> plain-string parser.
_gather_inputs = RunnableParallel(
    {
        "context": RunnableLambda(lambda x: retriever(x["question"])),
        "question": RunnableLambda(lambda x: x["question"]),
    }
)
chain = _gather_inputs | qa_prompt | llm | StrOutputParser()
# --- Gradio Pipeline ---
vector_index = None  # set by process_pdf(); None until a PDF is processed


def process_pdf(pdf_file):
    """Ingest an uploaded PDF into the Neo4j knowledge graph.

    Loads and chunks the PDF, converts chunks to graph documents with the
    LLM transformer (best-effort, batch of 2 at a time), writes them to
    Neo4j, ensures the 'entity' full-text index exists, and builds the
    hybrid vector index over Document nodes. Returns a status string for
    the Gradio UI.
    """
    global vector_index
    # Guard: the button can be clicked with no file selected, in which case
    # Gradio passes None and `.name` would raise AttributeError.
    if pdf_file is None:
        return "Please upload a PDF file first."
    loader = PyPDFLoader(pdf_file.name)
    docs = loader.load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    docs_split = splitter.split_documents(docs)
    graph_docs = []
    # Convert in small batches so one failing batch doesn't abort the run;
    # failures are logged and skipped (deliberate best-effort).
    for i in range(0, len(docs_split), 2):
        try:
            graph_docs.extend(llm_transformer.convert_to_graph_documents(docs_split[i:i + 2]))
        except Exception as e:
            print(f"Error: {e}")
    graph.add_graph_documents(graph_docs, baseEntityLabel=True, include_source=True)
    graph.query("CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]")
    vector_index = Neo4jVector.from_existing_graph(
        embedding_model,
        search_type="hybrid",
        graph=graph,
        node_label="Document",
        embedding_node_property="embedding",
        text_node_properties=["text"],
    )
    return "PDF uploaded and processed successfully!"
def chat_with_doc(question):
    """Answer *question* via the RAG chain once a PDF has been processed."""
    if vector_index is not None:
        return chain.invoke({"question": question})
    return "Please upload and process a PDF first."
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Graph RAG PDF Q&A")
    with gr.Row():
        pdf_input = gr.File(label="Upload PDF")
        upload_btn = gr.Button("Process PDF")
    output_info = gr.Textbox(label="Status", interactive=False)
    with gr.Row():
        question_input = gr.Textbox(label="Ask a Question")
        ask_btn = gr.Button("Get Answer")
    answer_output = gr.Textbox(label="Answer")
    upload_btn.click(process_pdf, inputs=[pdf_input], outputs=[output_info])
    ask_btn.click(chat_with_doc, inputs=[question_input], outputs=[answer_output])

# Guarded entry point so importing app.py doesn't start the server.
if __name__ == "__main__":
    demo.launch()