# graph_rag / app.py
# Source: Sameer-Handsome173's Hugging Face Space ("Update app.py", commit 7641778, verified)
# ================================
# app.py - FastAPI Graph RAG API with Proper Prompt Templates
# ================================
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from langchain.chains import GraphCypherQAChain, LLMChain
from langchain_community.graphs import Neo4jGraph
from langchain_community.llms import HuggingFaceHub
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.output_parsers import PydanticOutputParser
from typing import List, Optional
import os
import json
import re
import uvicorn

# ================================
# Configuration
# ================================
app = FastAPI(
    title="Graph RAG API",
    description="Document QA using Graph RAG with Qwen model",
    version="1.0.0"
)

# CORS: wide-open policy (any origin / method / header).
# NOTE(review): acceptable for a demo Space; tighten allow_origins for
# production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Environment variables.
# NEO4J_URI, NEO4J_PASSWORD and HF_TOKEN have no defaults -- if unset they
# stay None and startup_event() will fail when connecting.
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
HF_TOKEN = os.getenv("HF_TOKEN")

# Global singletons, populated once in startup_event().
graph = None              # Neo4jGraph connection
llm = None                # HuggingFaceHub-hosted Qwen LLM
qa_chain = None           # GraphCypherQAChain used by /query
extraction_chain = None   # LLMChain used for entity/relationship extraction
# ================================
# Pydantic Models for Extraction
# ================================
class Entity(BaseModel):
    """Single entity extracted from text.

    This model is serialized into the extraction prompt via
    PydanticOutputParser, so the Field descriptions double as
    instructions to the LLM.
    """
    # Entity name as it should appear as the node's `name` property.
    name: str = Field(description="The name of the entity")
    # Coarse category; becomes the Neo4j node label in add_to_graph().
    type: str = Field(description="Type: Person, Organization, Product, Technology, Concept, Location")
    # Free-text description; empty string when the LLM omits it.
    description: Optional[str] = Field(default="", description="Brief description of the entity")
class Relationship(BaseModel):
    """Directed relationship between two extracted entities.

    Field descriptions are shown to the LLM through the Pydantic output
    parser's format instructions.
    """
    # Name of the source entity (matched by node `name` in add_to_graph).
    source: str = Field(description="Source entity name")
    # Name of the target entity.
    target: str = Field(description="Target entity name")
    # Becomes the Neo4j relationship type, e.g. (a)-[:CREATED]->(b).
    type: str = Field(description="Relationship type in UPPER_SNAKE_CASE (e.g., CREATED, FOUNDED, USES)")
    # Sentence-level context where the relationship was found.
    context: Optional[str] = Field(default="", description="Context of the relationship")
class ExtractionResult(BaseModel):
    """Complete extraction result: the structured payload the LLM is asked
    to return for one text chunk (parsed by extraction_parser)."""
    entities: List[Entity] = Field(description="List of extracted entities")
    relationships: List[Relationship] = Field(description="List of extracted relationships")
# ================================
# Pydantic Models for API
# ================================
class Query(BaseModel):
    """Request body for POST /query."""
    # Natural-language question to answer from the knowledge graph.
    question: str
class UploadResponse(BaseModel):
    """Response body for POST /upload."""
    # Human-readable summary of the processed document.
    message: str
    # Number of entities successfully written to Neo4j.
    entities_count: int
    # Number of relationships successfully written to Neo4j.
    relationships_count: int
class QueryResponse(BaseModel):
    """Response body for POST /query."""
    # The question as asked.
    question: str
    # Natural-language answer produced by the QA chain.
    answer: str
    # Generated Cypher query, when intermediate steps were available.
    # Declared Optional[...] instead of the original bare `str = None`:
    # a None default on a non-optional field is invalid typing and is
    # rejected by pydantic v2 validation.
    cypher_query: Optional[str] = None
# ================================
# Prompt Templates with Pydantic Parser
# ================================
# Create parser for structured output -- its format instructions are baked
# into the extraction prompt below so the LLM returns parseable JSON.
extraction_parser = PydanticOutputParser(pydantic_object=ExtractionResult)

# 1. Entity Extraction Prompt with Pydantic
ENTITY_EXTRACTION_TEMPLATE = """Extract entities and relationships from the text.
{format_instructions}
TEXT:
{text}
Important:
- Extract people, organizations, products, technologies, concepts
- For relationships use: CREATED, FOUNDED, USES, BUILT_ON, WORKS_AT, CEO_OF, INTEGRATES_WITH
- Be specific and accurate
Your response:"""

# {format_instructions} is filled once at construction time via
# partial_variables, so callers of the chain only supply {text}.
entity_extraction_prompt = PromptTemplate(
    input_variables=["text"],
    template=ENTITY_EXTRACTION_TEMPLATE,
    partial_variables={"format_instructions": extraction_parser.get_format_instructions()}
)

# 2. Cypher Generation Prompt Template
# The {{...}} double braces in the examples are literal braces, escaped for
# PromptTemplate's f-string-style substitution.
CYPHER_GENERATION_TEMPLATE = """You are a Neo4j Cypher expert. Generate accurate Cypher queries based on the graph schema and user questions.
GRAPH SCHEMA:
{schema}
USER QUESTION:
{question}
INSTRUCTIONS:
1. Generate a Cypher query that answers the question
2. Only use node labels and relationship types present in the schema
3. Use MATCH to find patterns
4. Use WHERE for filtering
5. Use RETURN to specify what to return
6. Keep queries simple and efficient
7. For "tell me about X" questions, get the entity and its immediate connections
8. For "who/what/where" questions, find specific information
EXAMPLES:
Question: "Who is the CEO of Tesla?"
Cypher: MATCH (p:Person)-[:CEO_OF]->(c:Company {{name: 'Tesla'}}) RETURN p.name
Question: "What products does Tesla make?"
Cypher: MATCH (c:Company {{name: 'Tesla'}})-[:MANUFACTURES|CREATES]->(p:Product) RETURN p.name
Question: "Tell me about Elon Musk"
Cypher: MATCH (p:Person {{name: 'Elon Musk'}})-[r]-(connected) RETURN p, type(r), connected LIMIT 10
Return ONLY the Cypher query with no explanation or markdown:"""

# GraphCypherQAChain fills {schema} from the live graph and {question}
# from the user query.
cypher_generation_prompt = PromptTemplate(
    input_variables=["schema", "question"],
    template=CYPHER_GENERATION_TEMPLATE
)

# 3. Answer Generation Prompt Template
ANSWER_GENERATION_TEMPLATE = """You are a helpful assistant that provides clear, concise answers based on graph database query results.
USER QUESTION:
{question}
CYPHER QUERY USED:
{query}
QUERY RESULTS:
{context}
INSTRUCTIONS:
1. Answer the user's question based on the query results
2. Be clear and concise
3. If results are empty, say "I don't have information about that in the knowledge graph"
4. Don't mention technical details like Cypher or graph databases
5. Provide a natural, conversational answer
ANSWER:"""

# NOTE(review): answer_generation_prompt is defined but never wired into
# any chain -- startup_event builds GraphCypherQAChain with only
# cypher_prompt. Presumably intended as a custom QA prompt; confirm before
# relying on it (its variable set also differs from the chain's default
# qa_prompt, which expects only {context} and {question}).
answer_generation_prompt = PromptTemplate(
    input_variables=["question", "query", "context"],
    template=ANSWER_GENERATION_TEMPLATE
)
# ================================
# Startup
# ================================
@app.on_event("startup")
async def startup_event():
    """Initialize Neo4j, the LLM, and both LangChain chains once at boot.

    Populates the module-level globals; any failure is re-raised so the
    server refuses to start half-configured.

    NOTE(review): @app.on_event is deprecated in current FastAPI in favor
    of lifespan handlers -- works today, but worth migrating.
    """
    global graph, llm, qa_chain, extraction_chain
    try:
        # Initialize Neo4j (credentials come from environment variables).
        graph = Neo4jGraph(
            url=NEO4J_URI,
            username=NEO4J_USERNAME,
            password=NEO4J_PASSWORD
        )
        print("✅ Connected to Neo4j")
        # Initialize Qwen LLM via the hosted HF inference API; low
        # temperature keeps extraction and Cypher generation deterministic.
        llm = HuggingFaceHub(
            repo_id="Qwen/Qwen2.5-3B-Instruct",
            huggingfacehub_api_token=HF_TOKEN,
            model_kwargs={"temperature": 0.1, "max_new_tokens": 512}
        )
        print("✅ Qwen model loaded")
        # Create extraction chain with the structured-output prompt template.
        extraction_chain = LLMChain(
            llm=llm,
            prompt=entity_extraction_prompt,
            verbose=False
        )
        print("✅ Extraction chain initialized")
        # Setup QA chain with the custom Cypher-generation prompt;
        # intermediate steps are returned so /query can surface the
        # generated Cypher to the client.
        qa_chain = GraphCypherQAChain.from_llm(
            llm=llm,
            graph=graph,
            verbose=True,
            cypher_prompt=cypher_generation_prompt,
            return_intermediate_steps=True,
            top_k=10
        )
        print("✅ QA chain initialized")
    except Exception as e:
        # Fail fast: a partially-initialized app would 500 on every request.
        print(f"❌ Startup error: {e}")
        raise
# ================================
# Helper Functions with Prompt Templates
# ================================
def extract_entities_relationships(text_chunk):
    """Extract entities and relationships from one text chunk.

    Best-effort strategy (never raises):
      1. Run the LLM extraction chain and parse with the Pydantic parser.
      2. If that fails, strip markdown fences and parse the JSON by hand.
      3. If everything fails, fall back to rule-based extraction.

    Returns a dict with "entities" and "relationships" lists of dicts.
    """
    try:
        print(f"\n{'='*60}")
        print(f"Processing chunk: {text_chunk[:100]}...")

        raw = extraction_chain.run(text=text_chunk)

        print(f"RAW LLM RESPONSE:")
        print(raw[:500])
        print('='*60)

        try:
            # Preferred path: structured parse via the Pydantic parser.
            parsed = extraction_parser.parse(raw)
            ents = [e.dict() for e in parsed.entities]
            rels = [r.dict() for r in parsed.relationships]
            print(f"✅ PARSED with Pydantic:")
            print(f" Entities: {len(ents)}")
            print(f" Relationships: {len(rels)}")
            return {"entities": ents, "relationships": rels}
        except Exception as parse_error:
            print(f"⚠️ Pydantic parsing failed: {parse_error}")
            print("Trying manual JSON extraction...")

            # Fallback path: peel off markdown code fences, then isolate
            # the outermost {...} span and hand it to json.loads.
            candidate = raw.strip()
            if "```json" in candidate:
                candidate = candidate.split("```json")[1].split("```")[0]
            elif "```" in candidate:
                candidate = candidate.split("```")[1].split("```")[0]
            if "{" in candidate and "}" in candidate:
                candidate = candidate[candidate.find("{"):candidate.rfind("}") + 1]

            data = json.loads(candidate)
            print(f"✅ Manual JSON parse successful: {len(data.get('entities', []))} entities")
            return data
    except Exception as e:
        # Any failure above (LLM call, JSON parse) lands here.
        print(f"❌ All parsing failed: {e}")
        print("Using fallback extraction...")
        return fallback_extraction(text_chunk)
def fallback_extraction(text):
    """Rule-based extraction used when LLM output cannot be parsed.

    Heuristics:
      - Consecutive capitalized words (each longer than 2 chars) become
        "Concept" entities, deduplicated by name.
      - Sentences containing " created " / " developed " yield a CREATED
        relationship between the words adjacent to the verb.

    Returns {"entities": [...], "relationships": [...]}, capped at 15
    entities and 10 relationships to keep graph writes bounded.
    """
    print("⚠️ Using fallback extraction...")
    entities = []
    relationships = []
    seen_entities = set()

    def flush(current_entity, sentence):
        # Register a completed capitalized run as an entity (once per name).
        entity_name = ' '.join(current_entity)
        if entity_name and entity_name not in seen_entities:
            entities.append({
                "name": entity_name,
                "type": "Concept",
                "description": sentence[:100]
            })
            seen_entities.add(entity_name)

    # Naive sentence split on '.' -- good enough for a fallback path.
    sentences = [s.strip() for s in text.split('.') if s.strip()]
    for sentence in sentences:
        current_entity = []
        for word in sentence.split():
            clean = re.sub(r'[^\w\s]', '', word)
            if clean and clean[0].isupper() and len(clean) > 2:
                current_entity.append(clean)
            elif current_entity:
                flush(current_entity, sentence)
                current_entity = []
        # BUG FIX: the original dropped a capitalized run sitting at the
        # end of a sentence (flush only happened on the next lowercase
        # word); flush any trailing run too.
        if current_entity:
            flush(current_entity, sentence)

        # Check for common relationship patterns.
        if ' created ' in sentence.lower() or ' developed ' in sentence.lower():
            parts = re.split(r' created | developed ', sentence, flags=re.IGNORECASE)
            # Guard both halves being non-empty -- the original indexed
            # split()[-1] / split()[0] and could raise IndexError.
            if len(parts) == 2 and parts[0].split() and parts[1].split():
                relationships.append({
                    "source": parts[0].split()[-1],
                    "target": parts[1].split()[0],
                    "type": "CREATED",
                    "context": sentence[:100]
                })

    print(f"Fallback extracted: {len(entities)} entities, {len(relationships)} relationships")
    return {"entities": entities[:15], "relationships": relationships[:10]}
def add_to_graph(entities, relationships, doc_name):
    """Write extracted entities and relationships into Neo4j.

    Security fix: all value fields (names, descriptions, context, document
    name) are passed as Cypher query parameters instead of being escaped
    into the query string by hand, so quotes/backslashes in LLM output can
    neither break the query nor inject Cypher. Labels and relationship
    types cannot be parameterized in Cypher, so those are whitelisted to
    [A-Za-z0-9_] characters.

    Per-item failures are logged and skipped; returns the tuple
    (entities_added, relationships_added).
    """
    entities_added = 0
    relationships_added = 0

    def safe_identifier(value, default):
        # Keep only characters legal in a bare Cypher label / rel type;
        # fall back to a default when nothing survives.
        cleaned = re.sub(r'[^A-Za-z0-9_]', '', value or '')
        return cleaned or default

    # Add entities
    for entity in entities:
        try:
            name = (entity.get('name') or '').strip()
            if not name:
                continue
            entity_type = safe_identifier(
                entity.get('type', 'Entity').replace(" ", ""), 'Entity')
            desc = (entity.get('description') or '').replace("\n", " ")
            query = f"""
            MERGE (e:{entity_type} {{name: $name}})
            ON CREATE SET
                e.description = $desc,
                e.source_document = $doc,
                e.created_at = datetime()
            ON MATCH SET
                e.description = CASE WHEN e.description = '' THEN $desc ELSE e.description END
            """
            graph.query(query, {"name": name, "desc": desc, "doc": doc_name})
            entities_added += 1
        except Exception as e:
            print(f"Error adding entity {entity.get('name', 'unknown')}: {e}")

    # Add relationships
    for rel in relationships:
        # Pre-bind so the except clause below can't hit a NameError when
        # an error occurs before the assignments (original bug).
        source = target = ''
        try:
            source = (rel.get('source') or '').strip()
            target = (rel.get('target') or '').strip()
            if not source or not target:
                continue
            rel_type = safe_identifier(
                rel.get('type', 'RELATED_TO').replace(" ", "_").upper(), 'RELATED_TO')
            context = (rel.get('context') or '').replace("\n", " ")
            query = f"""
            MATCH (a {{name: $source}})
            MATCH (b {{name: $target}})
            MERGE (a)-[r:{rel_type}]->(b)
            ON CREATE SET
                r.context = $context,
                r.source_document = $doc,
                r.created_at = datetime()
            """
            graph.query(query, {"source": source, "target": target,
                                "context": context, "doc": doc_name})
            relationships_added += 1
        except Exception as e:
            print(f"Error adding relationship {source} -> {target}: {e}")

    return entities_added, relationships_added
# ================================
# API Endpoints
# ================================
@app.get("/")
async def root():
    """Landing endpoint: describes the API surface and feature set."""
    endpoints = {
        "POST /upload": "Upload document to build knowledge graph",
        "POST /query": "Ask questions about your documents",
        "GET /stats": "Get graph statistics",
        "GET /health": "Check system health",
        "DELETE /clear": "Clear the graph"
    }
    features = [
        "Structured entity extraction with PromptTemplate",
        "Cypher generation with custom prompt",
        "Answer generation with context",
        "Natural language question answering"
    ]
    return {
        "message": "Graph RAG API with Proper Prompt Templates",
        "model": "Qwen/Qwen2.5-3B-Instruct",
        "endpoints": endpoints,
        "features": features
    }
@app.get("/health")
async def health_check():
    """Report liveness of Neo4j, the LLM and both chains."""
    try:
        # Cheap round-trip to verify the Neo4j connection is alive.
        graph.query("MATCH (n) RETURN count(n) as count LIMIT 1")
        neo4j_status = "connected"
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are not swallowed; any driver/query error means "disconnected".
        neo4j_status = "disconnected"
    return {
        "status": "healthy" if neo4j_status == "connected" else "unhealthy",
        "neo4j": neo4j_status,
        "llm": "loaded" if llm else "not loaded",
        "extraction_chain": "ready" if extraction_chain else "not ready",
        "qa_chain": "ready" if qa_chain else "not ready"
    }
@app.post("/upload", response_model=UploadResponse)
async def upload_document(file: UploadFile = File(...)):
    """Upload a .txt document, extract entities/relationships per chunk,
    and persist them into the Neo4j knowledge graph.

    Raises HTTP 400 for non-.txt uploads or non-UTF-8 content, and
    HTTP 500 for processing failures.
    """
    # file.filename can be None on some clients -- the original would
    # raise AttributeError; treat it as unsupported instead.
    if not (file.filename and file.filename.endswith('.txt')):
        raise HTTPException(400, "Only .txt files supported")
    try:
        # Read file
        content = await file.read()
        try:
            text = content.decode('utf-8')
        except UnicodeDecodeError:
            # Bad upload is a client error (400), not a server error (500).
            raise HTTPException(400, "File must be UTF-8 encoded text")
        print(f"Processing document: {file.filename} ({len(text)} characters)")

        # Split into overlapping chunks so relationships spanning a chunk
        # boundary remain recoverable.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=150,
            length_function=len
        )
        chunks = text_splitter.split_text(text)
        print(f"Split into {len(chunks)} chunks")

        total_entities = 0
        total_relationships = 0
        # Process chunks (limit to 10 for demo, remove limit in production)
        max_chunks = min(len(chunks), 10)
        for i, chunk in enumerate(chunks[:max_chunks]):
            print(f"Processing chunk {i+1}/{max_chunks}")
            # Use extraction chain with prompt template
            extracted = extract_entities_relationships(chunk)
            entities = extracted.get("entities", [])
            relationships = extracted.get("relationships", [])
            print(f" Extracted: {len(entities)} entities, {len(relationships)} relationships")
            # Add to graph
            e_count, r_count = add_to_graph(entities, relationships, file.filename)
            total_entities += e_count
            total_relationships += r_count

        return UploadResponse(
            message=f"Successfully processed {file.filename} ({max_chunks} chunks)",
            entities_count=total_entities,
            relationships_count=total_relationships
        )
    except HTTPException:
        # Don't let the generic handler below downgrade deliberate 4xx
        # responses to 500.
        raise
    except Exception as e:
        print(f"Upload error: {e}")
        raise HTTPException(500, f"Error processing document: {str(e)}")
@app.post("/query", response_model=QueryResponse)
async def query_graph(query: Query):
    """Answer a natural-language question via GraphCypherQAChain.

    Returns the answer plus the generated Cypher (when the chain exposes
    intermediate steps); maps any failure to HTTP 500.
    """
    try:
        print(f"Question: {query.question}")
        # Run the QA chain (built with the custom Cypher-generation prompt).
        chain_output = qa_chain.invoke({"query": query.question})

        # Surface the generated Cypher from the intermediate steps, if any.
        steps = chain_output["intermediate_steps"] if "intermediate_steps" in chain_output else []
        cypher = steps[0].get("query", "") if steps else None

        print(f"Generated Cypher: {cypher}")
        print(f"Answer: {chain_output['result']}")
        return QueryResponse(
            question=query.question,
            answer=chain_output["result"],
            cypher_query=cypher
        )
    except Exception as e:
        print(f"Query error: {e}")
        raise HTTPException(500, f"Error answering question: {str(e)}")
@app.get("/stats")
async def get_stats():
    """Return node/relationship counts by type plus a sample of entities."""
    try:
        # Per-label node counts, most common first.
        nodes = graph.query("""
        MATCH (n)
        RETURN labels(n)[0] as type, count(n) as count
        ORDER BY count DESC
        """)
        # Per-type relationship counts, most common first.
        rels = graph.query("""
        MATCH ()-[r]->()
        RETURN type(r) as type, count(r) as count
        ORDER BY count DESC
        """)
        # Aggregate totals across all types.
        total_nodes = sum(row['count'] for row in nodes) if nodes else 0
        total_rels = sum(row['count'] for row in rels) if rels else 0
        # A handful of concrete entities for quick inspection.
        samples = graph.query("""
        MATCH (n)
        RETURN labels(n)[0] as type, n.name as name
        LIMIT 10
        """)
        return {
            "total_nodes": total_nodes,
            "total_relationships": total_rels,
            "nodes_by_type": nodes,
            "relationships_by_type": rels,
            "sample_entities": samples
        }
    except Exception as e:
        raise HTTPException(500, f"Error getting stats: {str(e)}")
@app.delete("/clear")
async def clear_graph():
    """Delete every node (and, via DETACH, every relationship) from the
    graph. Returns the number of nodes removed; failures map to HTTP 500."""
    try:
        # Count first so the response can report what was removed.
        result = graph.query("MATCH (n) RETURN count(n) as count")
        count = result[0]['count'] if result else 0
        graph.query("MATCH (n) DETACH DELETE n")
        return {
            "message": f"Graph cleared successfully. Deleted {count} nodes.",
            "nodes_deleted": count
        }
    except Exception as e:
        raise HTTPException(500, f"Error clearing graph: {str(e)}")
# ================================
# Run
# ================================
if __name__ == "__main__":
    # Port 7860 is the default port exposed by Hugging Face Spaces.
    uvicorn.run(app, host="0.0.0.0", port=7860)