#!/usr/bin/env python3
"""
Document Intelligence RAG End-to-End Example
Demonstrates the complete RAG workflow:
1. Parse documents into semantic chunks
2. Index chunks into vector store
3. Semantic retrieval with filters
4. Grounded question answering with evidence
5. Evidence visualization
Requirements:
- ChromaDB: pip install chromadb
- Ollama running with nomic-embed-text model: ollama pull nomic-embed-text
- PyMuPDF: pip install pymupdf
"""
import sys
from pathlib import Path
from typing import Optional

# Add project root to path so `src` imports resolve
sys.path.insert(0, str(Path(__file__).parent.parent))
def check_dependencies():
    """Check that required dependencies are available."""
    missing = []
    try:
        import chromadb  # noqa: F401
    except ImportError:
        missing.append("chromadb")
    try:
        import fitz  # PyMuPDF  # noqa: F401
    except ImportError:
        missing.append("pymupdf")

    if missing:
        print("Missing dependencies:")
        for dep in missing:
            print(f"  - {dep}")
        print("\nInstall with: pip install " + " ".join(missing))
        return False

    # Check that the Ollama server is reachable (used for embeddings)
    try:
        import requests
        response = requests.get("http://localhost:11434/api/tags", timeout=2)
        if response.status_code != 200:
            print("Warning: Ollama server not responding")
            print("Start Ollama with: ollama serve")
            print("Then pull the embedding model: ollama pull nomic-embed-text")
    except Exception:
        print("Warning: Could not connect to Ollama server")
        print("The example will still work but with mock embeddings")

    return True
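# Optional: verify that the embedding model itself has been pulled, not just
# that the server is up. This sketch uses Ollama's /api/tags endpoint, whose
# JSON response lists pulled models under "models", each with a "name" field:
#
#     import requests
#     tags = requests.get("http://localhost:11434/api/tags", timeout=2).json()
#     names = [m.get("name", "") for m in tags.get("models", [])]
#     if not any(n.startswith("nomic-embed-text") for n in names):
#         print("Run: ollama pull nomic-embed-text")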
def demo_parse_and_index(doc_paths: list):
    """
    Demo: Parse documents and index into vector store.

    Args:
        doc_paths: List of document file paths
    """
    print("\n" + "=" * 60)
    print("STEP 1: PARSE AND INDEX DOCUMENTS")
    print("=" * 60)

    from src.document_intelligence import DocumentParser, ParserConfig
    from src.document_intelligence.tools import get_rag_tool

    # Get the index tool
    index_tool = get_rag_tool("index_document")

    results = []
    for doc_path in doc_paths:
        print(f"\nProcessing: {doc_path}")

        # Parse document first (optional - the tool can also do this)
        config = ParserConfig(render_dpi=200, max_pages=10)
        parser = DocumentParser(config=config)

        try:
            parse_result = parser.parse(doc_path)
            print(f"  Parsed: {len(parse_result.chunks)} chunks, {parse_result.num_pages} pages")

            # Index the parse result
            result = index_tool.execute(parse_result=parse_result)
            if result.success:
                print(f"  Indexed: {result.data['chunks_indexed']} chunks")
                print(f"  Document ID: {result.data['document_id']}")
                results.append({
                    "path": doc_path,
                    "doc_id": result.data["document_id"],
                    "chunks": result.data["chunks_indexed"],
                })
            else:
                print(f"  Error: {result.error}")
        except Exception as e:
            print(f"  Failed: {e}")

    return results
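# The explicit parse above is optional; the in-loop comment notes that the
# index tool can parse on its own. A plausible shorter path (the `file_path`
# keyword is an assumption for illustration, not confirmed by this file):
#
#     index_tool = get_rag_tool("index_document")
#     result = index_tool.execute(file_path=doc_path)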
def demo_semantic_retrieval(query: str, document_id: Optional[str] = None):
    """
    Demo: Semantic retrieval from vector store.

    Args:
        query: Search query
        document_id: Optional document filter
    """
    print("\n" + "=" * 60)
    print("STEP 2: SEMANTIC RETRIEVAL")
    print("=" * 60)

    from src.document_intelligence.tools import get_rag_tool

    retrieve_tool = get_rag_tool("retrieve_chunks")

    print(f'\nQuery: "{query}"')
    if document_id:
        print(f"Document filter: {document_id}")

    result = retrieve_tool.execute(
        query=query,
        top_k=5,
        document_id=document_id,
        include_evidence=True,
    )

    if result.success:
        chunks = result.data.get("chunks", [])
        print(f"\nFound {len(chunks)} relevant chunks:\n")
        for i, chunk in enumerate(chunks, 1):
            print(f"{i}. [similarity={chunk['similarity']:.3f}]")
            print(f"   Page {chunk.get('page', '?')}, Type: {chunk.get('chunk_type', 'unknown')}")
            print(f"   Text: {chunk['text'][:150]}...")
            print()

        # Show evidence
        if result.evidence:
            print("Evidence references:")
            for ev in result.evidence[:3]:
                print(f"  - Chunk {ev['chunk_id'][:12]}... Page {ev.get('page', '?')}")
        return chunks
    else:
        print(f"Error: {result.error}")
        return []
def demo_grounded_qa(question: str, document_id: Optional[str] = None):
    """
    Demo: Grounded question answering with evidence.

    Args:
        question: Question to answer
        document_id: Optional document filter
    """
    print("\n" + "=" * 60)
    print("STEP 3: GROUNDED QUESTION ANSWERING")
    print("=" * 60)

    from src.document_intelligence.tools import get_rag_tool

    qa_tool = get_rag_tool("rag_answer")

    print(f'\nQuestion: "{question}"')

    result = qa_tool.execute(
        question=question,
        document_id=document_id,
        top_k=5,
    )

    if result.success:
        data = result.data
        print(f"\nAnswer: {data.get('answer', 'No answer')}")
        print(f"Confidence: {data.get('confidence', 0):.2f}")
        if data.get("abstained"):
            print("Note: System abstained due to low confidence")

        # Show citations if any
        citations = data.get("citations", [])
        if citations:
            print("\nCitations:")
            for cit in citations:
                print(f"  [{cit['index']}] {cit.get('text', '')[:80]}...")

        # Show evidence
        if result.evidence:
            print("\nEvidence locations:")
            for ev in result.evidence:
                print(f"  - Page {ev.get('page', '?')}: {ev.get('snippet', '')[:60]}...")
        return data
    else:
        print(f"Error: {result.error}")
        return None
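# Callers might branch on abstention instead of just printing. A minimal
# sketch, reusing the 'abstained' and 'confidence' keys shown above (the
# 0.5 cutoff and `handle_answer` are illustrative placeholders):
#
#     data = demo_grounded_qa("What is the filing date?", document_id=doc_id)
#     if data and not data.get("abstained") and data.get("confidence", 0) >= 0.5:
#         handle_answer(data["answer"])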
def demo_filtered_retrieval():
    """
    Demo: Retrieval with various filters.
    """
    print("\n" + "=" * 60)
    print("STEP 4: FILTERED RETRIEVAL")
    print("=" * 60)

    from src.document_intelligence.tools import get_rag_tool

    retrieve_tool = get_rag_tool("retrieve_chunks")

    # Filter by chunk type
    print("\n--- Retrieving only table chunks ---")
    result = retrieve_tool.execute(
        query="data values",
        top_k=3,
        chunk_types=["table"],
    )
    if result.success:
        chunks = result.data.get("chunks", [])
        print(f"Found {len(chunks)} table chunks")
        for chunk in chunks:
            print(f"  - Page {chunk.get('page', '?')}: {chunk['text'][:80]}...")

    # Filter by page range
    print("\n--- Retrieving from pages 1-3 only ---")
    result = retrieve_tool.execute(
        query="content",
        top_k=3,
        page_range=(1, 3),
    )
    if result.success:
        chunks = result.data.get("chunks", [])
        print(f"Found {len(chunks)} chunks from pages 1-3")
        for chunk in chunks:
            print(f"  - Page {chunk.get('page', '?')}: {chunk['text'][:80]}...")
def demo_index_stats():
    """
    Demo: Show index statistics.
    """
    print("\n" + "=" * 60)
    print("INDEX STATISTICS")
    print("=" * 60)

    from src.document_intelligence.tools import get_rag_tool

    stats_tool = get_rag_tool("get_index_stats")
    result = stats_tool.execute()

    if result.success:
        data = result.data
        print(f"\nTotal chunks indexed: {data.get('total_chunks', 0)}")
        print(f"Embedding model: {data.get('embedding_model', 'unknown')}")
        print(f"Embedding dimension: {data.get('embedding_dimension', 'unknown')}")
    else:
        print(f"Error: {result.error}")
def main():
    """Run the complete RAG demo."""
    print("=" * 60)
    print("SPARKNET Document Intelligence RAG Demo")
    print("=" * 60)

    # Check dependencies
    if not check_dependencies():
        print("\nPlease install missing dependencies and try again.")
        return

    # Documents passed on the command line take precedence over samples
    if len(sys.argv) > 1:
        doc_paths = sys.argv[1:]
    else:
        # Otherwise, fall back to the first sample document that exists
        sample_paths = [
            Path("Dataset/Patent_1.pdf"),
            Path("data/sample.pdf"),
            Path("tests/fixtures/sample.pdf"),
        ]
        doc_paths = []
        for path in sample_paths:
            if path.exists():
                doc_paths.append(str(path))
                break
        if not doc_paths:
            print("\nNo sample documents found.")
            print("Please provide a PDF file path as argument.")
            print("\nUsage: python document_rag_end_to_end.py [path/to/document.pdf]")
            return

    print(f"\nUsing documents: {doc_paths}")

    try:
        # Step 1: Parse and index
        indexed_docs = demo_parse_and_index(doc_paths)
        if not indexed_docs:
            print("\nNo documents were indexed. Exiting.")
            return

        # Use the first document ID for filtered retrieval and Q&A
        first_doc_id = indexed_docs[0]["doc_id"]

        # Step 2: Semantic retrieval
        demo_semantic_retrieval(
            query="main topic content",
            document_id=first_doc_id,
        )

        # Step 3: Grounded Q&A
        demo_grounded_qa(
            question="What is this document about?",
            document_id=first_doc_id,
        )

        # Step 4: Filtered retrieval
        demo_filtered_retrieval()

        # Show index statistics
        demo_index_stats()

        print("\n" + "=" * 60)
        print("Demo complete!")
        print("=" * 60)
        print("\nNext steps:")
        print("  1. Try the CLI: sparknet docint index your_document.pdf")
        print("  2. Query the index: sparknet docint retrieve 'your query'")
        print("  3. Ask questions: sparknet docint ask doc.pdf 'question' --use-rag")

    except ImportError as e:
        print(f"\nImport error: {e}")
        print("Make sure all dependencies are installed:")
        print("  pip install pymupdf pillow numpy pydantic chromadb")
    except Exception as e:
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()