# builder.py — upload revision 0727871 (by adi-123); non-Python hub-page
# header removed so the module parses.
"""GraphRAG builder for PDF ingestion."""
from __future__ import annotations
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, Generator, List, Optional, Tuple
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from src.config import get_logger, trace_flow, log_step
# LangChain imports with compatibility handling
try:
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Neo4jVector
except ImportError:
from langchain.document_loaders import PyPDFLoader
from langchain.vectorstores import Neo4jVector
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_together import ChatTogether, TogetherEmbeddings
from src.config.schema import SchemaPolicy
from src.config.settings import Neo4jConfig, TogetherAIConfig
from src.models.state import AppState
from src.parsers.project_parser import ProjectReportParser
from src.parsers.smart_chunker import SemanticChunker
from src.services.neo4j_service import Neo4jService, Neo4jConnectionError
# Module logger
logger = get_logger(__name__)
class GraphRAGBuilder:
    """Builds and populates Neo4j-backed GraphRAG resources from uploaded PDFs.

    Responsibilities:
        - Configure Together AI chat + embeddings models.
        - Parse PDFs into pages and chunks with provenance metadata.
        - Upsert deterministic structured graph nodes for stable Q/A.
        - Run LLMGraphTransformer for broader entity/relationship extraction.
        - Create/refresh Neo4jVector hybrid indexes.
        - Create GraphCypherQAChain for graph-native Q/A.

    This class is intentionally stateless across runs; it returns AppState
    for query-time usage.

    Attributes:
        llm: Chat model instance.
        embeddings: Embeddings model instance.

    Example:
        >>> builder = GraphRAGBuilder(
        ...     together_config=TogetherAIConfig(api_key="key")
        ... )
        >>> message, state = builder.ingest(pdf_files, neo4j_config)
    """

    # Chunk configuration: character budget/overlap for the
    # RecursiveCharacterTextSplitter fallback. The semantic chunker derives
    # its own (slightly larger) budget from DEFAULT_CHUNK_SIZE in __init__.
    DEFAULT_CHUNK_SIZE = 900
    DEFAULT_CHUNK_OVERLAP = 150
    # Parallel LLM extraction tuning: chunks per batch and worker threads
    # (raised from 5/3 for throughput; LLM calls are IO-bound).
    EXTRACTION_BATCH_SIZE = 8
    MAX_EXTRACTION_WORKERS = 5
    # Vector index configuration: dense + keyword (BM25) index names and the
    # node label chunk documents are stored under in Neo4j.
    INDEX_NAME = "project_chunks_vector"
    KEYWORD_INDEX_NAME = "project_chunks_keyword"
    NODE_LABEL = "Chunk"
    # Prompt for GraphCypherQAChain: schema placeholder, worked Cypher
    # patterns, and generation rules. This text is runtime behavior — it is
    # sent to the LLM verbatim; edit with care.
    CYPHER_PROMPT_TEMPLATE = """You are a Neo4j Cypher expert. Generate a Cypher query to answer the question.
## Schema
{schema}
## Key Patterns
1. **Project with Budget and Location:**
```cypher
MATCH (p:Project)
OPTIONAL MATCH (p)-[:HAS_BUDGET]->(b:Budget)
OPTIONAL MATCH (p)-[:LOCATED_IN]->(l:Location)
RETURN p.name, b.amount, b.currency, l.city, l.country
```
2. **Project Milestones/Timeline:**
```cypher
MATCH (p:Project)-[:HAS_MILESTONE]->(m:Milestone)
RETURN p.name, m.name AS milestone, m.dateText
ORDER BY p.name, m.dateText
```
3. **Challenges and Risks:**
```cypher
MATCH (p:Project)-[:HAS_CHALLENGE]->(c:Challenge)
RETURN p.name, collect(c.text) AS challenges
```
4. **Cross-Project Comparison:**
```cypher
MATCH (p:Project)
OPTIONAL MATCH (p)-[:HAS_BUDGET]->(b:Budget)
OPTIONAL MATCH (p)-[:HAS_MILESTONE]->(m:Milestone)
WITH p, b, collect(m) AS milestones
RETURN p.name, b.amount, size(milestones) AS milestone_count
ORDER BY b.amount DESC
```
5. **Entity Relationships:**
```cypher
MATCH (p:Project)-[r]->(related)
WHERE NOT related:Chunk
RETURN p.name, type(r) AS relationship, labels(related)[0] AS entity_type,
coalesce(related.name, related.text, related.amount) AS value
LIMIT 50
```
## Rules
- Use OPTIONAL MATCH when relationships may not exist
- Always include ORDER BY for consistent results
- Use collect() to aggregate multiple related nodes
- Limit results if the query could return many rows
- Return human-readable names, not IDs
- For comparisons across projects, ensure all projects are included
## Question
{question}
Return ONLY the Cypher query, no explanation.""".strip()
def __init__(
self,
together_config: Optional[TogetherAIConfig] = None,
together_api_key: Optional[str] = None,
chat_model: str = "deepseek-ai/DeepSeek-V3",
embedding_model: str = "togethercomputer/m2-bert-80M-8k-retrieval",
) -> None:
"""Initialize GraphRAG builder.
Args:
together_config: Together AI configuration object.
together_api_key: API key (alternative to config object).
chat_model: Chat model identifier.
embedding_model: Embedding model identifier.
Raises:
ValueError: If no API key is provided.
"""
# Handle configuration
if together_config:
api_key = together_config.api_key
chat_model = together_config.chat_model or chat_model
embedding_model = together_config.embedding_model or embedding_model
else:
api_key = together_api_key
if not api_key:
raise ValueError("Together API key is required.")
# Set environment variable for SDK
os.environ["TOGETHER_API_KEY"] = api_key
# Initialize models
self.llm = ChatTogether(model=chat_model, temperature=0)
self.embeddings = TogetherEmbeddings(model=embedding_model)
# Initialize parsers and chunkers
self._parser = ProjectReportParser()
self._chunker = SemanticChunker(
max_chunk_size=self.DEFAULT_CHUNK_SIZE + 300, # Slightly larger for semantic chunks
min_chunk_size=200,
overlap_sentences=2,
)
def _load_pdf_pages(
self,
pdf_files: List[Any]
) -> Tuple[List[Document], List[Tuple[str, str]]]:
"""Load PDF files and extract pages with metadata.
Args:
pdf_files: List of gradio-uploaded file handles.
Returns:
Tuple of (all pages as Documents, list of (source_name, full_text)).
"""
all_pages: List[Document] = []
raw_texts: List[Tuple[str, str]] = []
with log_step(logger, "Load PDF files", f"{len(pdf_files)} file(s)"):
for f in pdf_files:
# Handle both file objects (from uploads) and string paths (from sample files)
if isinstance(f, str):
file_path = f
src_name = f
else:
file_path = f.name
src_name = (
getattr(f, "name", None) or
getattr(f, "orig_name", None) or
"uploaded.pdf"
)
logger.substep(f"Loading: {os.path.basename(src_name)}")
loader = PyPDFLoader(file_path)
pages = loader.load()
all_pages.extend(pages)
logger.substep(f"Extracted {len(pages)} pages")
joined = "\n".join([p.page_content for p in pages])
raw_texts.append((os.path.basename(src_name), joined))
logger.info(f"Total pages loaded: {len(all_pages)}")
return all_pages, raw_texts
def _create_chunks(
self,
pages: List[Document],
use_semantic_chunking: bool = True,
) -> List[Document]:
"""Split pages into chunks with normalized metadata.
Args:
pages: List of page Documents.
use_semantic_chunking: If True, uses section-aware chunking.
Returns:
List of chunk Documents with metadata.
"""
chunking_type = "semantic" if use_semantic_chunking else "character-based"
with log_step(logger, "Create document chunks", chunking_type):
if use_semantic_chunking:
# Use semantic chunker that respects document structure
logger.substep("Using section-aware semantic chunking")
chunks = self._chunker.chunk_pages(pages, adaptive=True)
else:
# Fallback to simple character-based splitting
logger.substep("Using RecursiveCharacterTextSplitter")
splitter = RecursiveCharacterTextSplitter(
chunk_size=self.DEFAULT_CHUNK_SIZE,
chunk_overlap=self.DEFAULT_CHUNK_OVERLAP,
)
chunks = splitter.split_documents(pages)
logger.substep(f"Raw chunks created: {len(chunks)}")
processed_chunks: List[Document] = []
for chunk in chunks:
meta = dict(chunk.metadata or {})
meta["source"] = os.path.basename(meta.get("source", "")) or "uploaded.pdf"
# Normalize page numbers (PyPDFLoader uses 0-index)
if "page" in meta and isinstance(meta["page"], int):
if meta["page"] == 0 or (not use_semantic_chunking):
meta["page"] = int(meta["page"]) + 1
processed_chunks.append(Document(
page_content=chunk.page_content.replace("\n", " "),
metadata=meta,
))
logger.info(f"Final chunks: {len(processed_chunks)}")
return processed_chunks
def _extract_structured_data(
self,
neo4j: Neo4jService,
raw_texts: List[Tuple[str, str]],
) -> List[Dict[str, Any]]:
"""Extract and upsert structured project data.
Args:
neo4j: Neo4j service instance.
raw_texts: List of (source_name, full_text) tuples.
Returns:
List of project dictionaries with results/warnings.
"""
projects_created: List[Dict[str, Any]] = []
with log_step(logger, "Extract structured data", f"{len(raw_texts)} document(s)"):
for source, full_text in raw_texts:
logger.substep(f"Parsing: {source}")
record = self._parser.parse(full_text, source)
try:
proj = neo4j.upsert_structured_project(record)
projects_created.append(proj)
logger.substep(f"Created project: {proj.get('name', source)}")
except Exception as e:
logger.warning(f"Failed to create project {source}: {e}")
projects_created.append({
"projectId": record.project_id or source,
"name": record.project_name or source,
"warning": str(e),
})
logger.info(f"Structured extraction complete: {len(projects_created)} project(s)")
return projects_created
def _extract_llm_graph(
self,
neo4j: Neo4jService,
chunks: List[Document],
parallel: bool = True,
) -> None:
"""Extract entities/relationships using LLM and add to graph.
Args:
neo4j: Neo4j service instance.
chunks: Document chunks for extraction.
parallel: If True, uses parallel batch processing.
"""
mode = "parallel" if parallel else "sequential"
with log_step(logger, "LLM graph extraction", f"{len(chunks)} chunks, {mode}"):
logger.substep("Initializing LLMGraphTransformer")
transformer = LLMGraphTransformer(
llm=self.llm,
allowed_nodes=SchemaPolicy.ALLOWED_NODES,
allowed_relationships=SchemaPolicy.ALLOWED_RELATIONSHIPS,
node_properties=True, # Enable property extraction for richer graph
)
if not parallel or len(chunks) <= self.EXTRACTION_BATCH_SIZE:
# Sequential extraction for small chunk sets
logger.substep("Using sequential extraction (small chunk set)")
graph_documents = transformer.convert_to_graph_documents(chunks)
neo4j.graph.add_graph_documents(graph_documents, include_source=True)
logger.info(f"Added {len(graph_documents)} graph documents")
return
# Parallel extraction for larger chunk sets
def process_batch(batch: List[Document]) -> List:
"""Process a batch of chunks."""
try:
return transformer.convert_to_graph_documents(batch)
except Exception:
return []
# Split into batches
batches = [
chunks[i:i + self.EXTRACTION_BATCH_SIZE]
for i in range(0, len(chunks), self.EXTRACTION_BATCH_SIZE)
]
logger.substep(f"Split into {len(batches)} batches ({self.EXTRACTION_BATCH_SIZE} chunks each)")
all_graph_docs = []
failed_batches = 0
# Process batches with thread pool for IO-bound LLM calls
logger.substep(f"Starting parallel extraction with {self.MAX_EXTRACTION_WORKERS} workers")
with ThreadPoolExecutor(max_workers=self.MAX_EXTRACTION_WORKERS) as executor:
futures = {
executor.submit(process_batch, batch): i
for i, batch in enumerate(batches)
}
for future in as_completed(futures):
batch_idx = futures[future]
try:
result = future.result(timeout=120)
all_graph_docs.extend(result)
logger.substep(f"Batch {batch_idx + 1}/{len(batches)} complete")
except Exception as e:
failed_batches += 1
logger.warning(f"Batch {batch_idx + 1} failed: {e}")
# Bulk add to graph
if all_graph_docs:
logger.substep(f"Adding {len(all_graph_docs)} graph documents to Neo4j")
neo4j.graph.add_graph_documents(all_graph_docs, include_source=True)
if failed_batches > 0:
logger.warning(f"{failed_batches} batch(es) failed during extraction")
logger.info(f"LLM extraction complete: {len(all_graph_docs)} graph documents")
def _create_vector_index(
self,
chunks: List[Document],
neo4j_config: Neo4jConfig,
) -> Neo4jVector:
"""Create or refresh vector index for chunks.
Args:
chunks: Document chunks to index.
neo4j_config: Neo4j connection configuration.
Returns:
Neo4jVector index instance.
"""
with log_step(logger, "Create vector index", f"{len(chunks)} chunks"):
logger.substep(f"Index name: {self.INDEX_NAME}")
logger.substep(f"Keyword index: {self.KEYWORD_INDEX_NAME}")
logger.substep("Creating hybrid search index (dense + BM25)")
vector = Neo4jVector.from_documents(
documents=chunks,
embedding=self.embeddings,
url=neo4j_config.uri,
username=neo4j_config.username,
password=neo4j_config.password,
database=neo4j_config.database or "neo4j",
index_name=self.INDEX_NAME,
keyword_index_name=self.KEYWORD_INDEX_NAME,
node_label=self.NODE_LABEL,
embedding_node_property="embedding",
search_type="hybrid",
)
logger.info("Vector index created successfully")
return vector
def _create_qa_chain(self, neo4j: Neo4jService) -> GraphCypherQAChain:
"""Create Cypher QA chain for graph querying.
Args:
neo4j: Neo4j service instance.
Returns:
GraphCypherQAChain instance.
"""
with log_step(logger, "Create Cypher QA chain"):
logger.substep("Configuring enhanced Cypher prompt template")
cypher_prompt = PromptTemplate(
template=self.CYPHER_PROMPT_TEMPLATE,
input_variables=["schema", "question"],
)
logger.substep("Initializing GraphCypherQAChain")
chain = GraphCypherQAChain.from_llm(
llm=self.llm,
graph=neo4j.graph,
cypher_prompt=cypher_prompt,
verbose=False,
allow_dangerous_requests=True,
)
logger.info("Cypher QA chain ready")
return chain
@trace_flow("PDF Ingestion Pipeline")
def ingest(
self,
pdf_files: List[Any],
neo4j_config: Optional[Neo4jConfig] = None,
neo4j_uri: Optional[str] = None,
neo4j_user: Optional[str] = None,
neo4j_password: Optional[str] = None,
neo4j_database: str = "neo4j",
clear_db: bool = True,
) -> Tuple[str, AppState]:
"""Ingest one or more PDF reports into Neo4j and build GraphRAG indices.
Args:
pdf_files: List of gradio-uploaded file handles.
neo4j_config: Neo4j configuration object (preferred).
neo4j_uri: Neo4j connection URI (alternative).
neo4j_user: Username (alternative).
neo4j_password: Password (alternative).
neo4j_database: Database name.
clear_db: If True, deletes all existing nodes prior to ingestion.
Returns:
Tuple of (human-readable status message, AppState).
Notes:
- The ingestion process can be compute-heavy due to LLM graph extraction.
- Even if the deterministic parser yields partial results, chunk retrieval
still works.
"""
# Validate inputs
if not pdf_files:
logger.warning("No PDF files provided")
return "Please upload at least one PDF.", AppState()
logger.info(f"Starting ingestion of {len(pdf_files)} PDF file(s)")
# Build config from parameters if not provided
if neo4j_config is None:
neo4j_config = Neo4jConfig(
uri=neo4j_uri or "",
username=neo4j_user or "neo4j",
password=neo4j_password or "",
database=neo4j_database,
)
if not neo4j_config.is_valid():
logger.error("Invalid Neo4j configuration")
return "Please provide Neo4j connection details.", AppState()
# Connect to Neo4j
with log_step(logger, "Connect to Neo4j"):
try:
neo4j = Neo4jService(
uri=neo4j_config.uri,
user=neo4j_config.username,
password=neo4j_config.password,
database=neo4j_config.database,
)
logger.substep(f"Connected to {neo4j_config.uri}")
except Neo4jConnectionError as e:
logger.error(f"Neo4j connection failed: {e}")
return (
f"Neo4j connection failed. For Aura, use the exact URI shown in the "
f"console (typically starts with neo4j+s://...). Error: {e}",
AppState(),
)
# Ensure constraints
with log_step(logger, "Ensure database constraints"):
neo4j.ensure_constraints()
# Clear database if requested
if clear_db:
with log_step(logger, "Clear existing data"):
neo4j.clear()
# 1) Load PDF pages
all_pages, raw_texts = self._load_pdf_pages(pdf_files)
# 2) Structured extraction (high precision)
projects_created = self._extract_structured_data(neo4j, raw_texts)
# 3) Create chunks
chunks = self._create_chunks(all_pages)
# 4) LLM-based KG extraction (high recall)
self._extract_llm_graph(neo4j, chunks)
# 5) Vector index
vector = self._create_vector_index(chunks, neo4j_config)
# 6) Cypher QA chain
qa_chain = self._create_qa_chain(neo4j)
# Build status message
proj_lines = []
for p in projects_created:
warn = f" (warning: {p.get('warning')})" if "warning" in p else ""
proj_lines.append(f"- {p.get('name')} [{p.get('projectId')}]{warn}")
msg = (
"Ingestion complete.\n\n"
f"Neo4j database: `{neo4j_config.database}`\n\n"
"Projects found:\n" + "\n".join(proj_lines)
)
logger.info(f"Ingestion complete: {len(projects_created)} project(s), {len(chunks)} chunks")
return msg, AppState(
neo4j=neo4j,
vector=vector,
qa_chain=qa_chain,
llm=self.llm,
)
def ingest_with_progress(
self,
pdf_files: List[Any],
neo4j_config: Optional[Neo4jConfig] = None,
neo4j_uri: Optional[str] = None,
neo4j_user: Optional[str] = None,
neo4j_password: Optional[str] = None,
neo4j_database: str = "neo4j",
clear_db: bool = True,
skip_llm_extraction: bool = True, # Skip LLM extraction for faster ingestion
) -> Generator[Tuple[str, float, Optional[AppState]], None, None]:
"""Ingest PDFs with progress updates for UI.
This generator yields progress updates during ingestion, allowing
the UI to display a progress bar with status messages.
Args:
pdf_files: List of gradio-uploaded file handles.
neo4j_config: Neo4j configuration object (preferred).
neo4j_uri: Neo4j connection URI (alternative).
neo4j_user: Username (alternative).
neo4j_password: Password (alternative).
neo4j_database: Database name.
clear_db: If True, deletes all existing nodes prior to ingestion.
skip_llm_extraction: If True, skips LLM graph extraction for faster ingestion.
Yields:
Tuple of (status_message, progress_fraction, optional_state)
- progress_fraction is 0.0 to 1.0
- optional_state is None until final yield, then contains AppState
Example:
>>> for status, progress, state in builder.ingest_with_progress(files, config):
... print(f"{progress*100:.0f}%: {status}")
... if state:
... print("Done!")
"""
start_time = time.time()
# Validate inputs
if not pdf_files:
yield "❌ Please upload at least one PDF file.", 0.0, None
return
# Build config from parameters if not provided
if neo4j_config is None:
neo4j_config = Neo4jConfig(
uri=neo4j_uri or "",
username=neo4j_user or "neo4j",
password=neo4j_password or "",
database=neo4j_database,
)
if not neo4j_config.is_valid():
yield "❌ Please provide Neo4j connection details.", 0.0, None
return
# Step 1: Connect to Neo4j (5%)
yield "πŸ”Œ Connecting to Neo4j...", 0.05, None
try:
neo4j = Neo4jService(
uri=neo4j_config.uri,
user=neo4j_config.username,
password=neo4j_config.password,
database=neo4j_config.database,
)
except Neo4jConnectionError as e:
yield f"❌ Neo4j connection failed: {e}", 0.05, None
return
# Step 2: Ensure constraints (10%)
yield "πŸ“‹ Setting up database constraints...", 0.10, None
neo4j.ensure_constraints()
# Step 3: Clear database if requested (15%)
if clear_db:
yield "πŸ—‘οΈ Clearing existing data...", 0.15, None
neo4j.clear()
# Step 4: Load PDF pages (25%)
yield f"πŸ“„ Loading {len(pdf_files)} PDF file(s)...", 0.20, None
all_pages, raw_texts = self._load_pdf_pages(pdf_files)
yield f"πŸ“„ Loaded {len(all_pages)} pages from PDFs", 0.25, None
# Step 5: Structured extraction (35%)
yield "πŸ” Extracting structured project data...", 0.30, None
projects_created = self._extract_structured_data(neo4j, raw_texts)
project_names = [p.get('name', 'Unknown') for p in projects_created]
yield f"βœ… Found {len(projects_created)} project(s): {', '.join(project_names)}", 0.35, None
# Step 6: Create chunks (45%)
yield "βœ‚οΈ Creating document chunks...", 0.40, None
chunks = self._create_chunks(all_pages)
yield f"βœ… Created {len(chunks)} chunks", 0.45, None
# Step 7: LLM Graph Extraction (optional) (45-70%)
if not skip_llm_extraction:
yield f"🧠 Extracting entities with LLM ({len(chunks)} chunks)...", 0.50, None
# This is the slowest step - show batch progress
total_batches = (len(chunks) + self.EXTRACTION_BATCH_SIZE - 1) // self.EXTRACTION_BATCH_SIZE
for batch_num in range(total_batches):
progress = 0.50 + (0.20 * (batch_num + 1) / total_batches)
yield f"🧠 LLM extraction: batch {batch_num + 1}/{total_batches}...", progress, None
self._extract_llm_graph(neo4j, chunks)
yield "βœ… LLM graph extraction complete", 0.70, None
else:
yield "⏩ Skipping LLM extraction (using fast mode)", 0.70, None
# Step 8: Create vector index (90%)
yield f"πŸ“Š Creating vector index ({len(chunks)} chunks)...", 0.75, None
vector = self._create_vector_index(chunks, neo4j_config)
yield "βœ… Vector index created", 0.90, None
# Step 9: Create QA chain (95%)
yield "βš™οΈ Initializing QA chain...", 0.95, None
qa_chain = self._create_qa_chain(neo4j)
# Final step: Complete (100%)
elapsed = time.time() - start_time
proj_lines = []
for p in projects_created:
warn = f" ⚠️ {p.get('warning')}" if "warning" in p else ""
proj_lines.append(f"- **{p.get('name')}** [{p.get('projectId')}]{warn}")
final_msg = (
f"## βœ… Ingestion Complete ({elapsed:.1f}s)\n\n"
f"**Database:** `{neo4j_config.database}`\n\n"
f"**Projects found:**\n" + "\n".join(proj_lines) + "\n\n"
f"**Stats:** {len(chunks)} chunks indexed"
)
yield final_msg, 1.0, AppState(
neo4j=neo4j,
vector=vector,
qa_chain=qa_chain,
llm=self.llm,
)