# ================================
# app.py - FastAPI Graph RAG API with Proper Prompt Templates
# ================================
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from langchain.chains import GraphCypherQAChain, LLMChain
from langchain_community.graphs import Neo4jGraph
from langchain_community.llms import HuggingFaceHub
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.output_parsers import PydanticOutputParser
from typing import List, Optional
import os
import json
import re
import uvicorn

# ================================
# Configuration
# ================================
app = FastAPI(
    title="Graph RAG API",
    description="Document QA using Graph RAG with Qwen model",
    version="1.0.0"
)

# CORS: wide open for demo purposes; restrict allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Environment variables (NEO4J_URI / NEO4J_PASSWORD / HF_TOKEN must be set)
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
HF_TOKEN = os.getenv("HF_TOKEN")

# Global singletons, initialized once in startup_event()
graph = None
llm = None
qa_chain = None
extraction_chain = None

# ================================
# Pydantic Models for Extraction
# ================================
class Entity(BaseModel):
    """Single entity extracted from text"""
    name: str = Field(description="The name of the entity")
    type: str = Field(description="Type: Person, Organization, Product, Technology, Concept, Location")
    description: Optional[str] = Field(default="", description="Brief description of the entity")


class Relationship(BaseModel):
    """Relationship between two entities"""
    source: str = Field(description="Source entity name")
    target: str = Field(description="Target entity name")
    type: str = Field(description="Relationship type in UPPER_SNAKE_CASE (e.g., CREATED, FOUNDED, USES)")
    context: Optional[str] = Field(default="", description="Context of the relationship")


class ExtractionResult(BaseModel):
    """Complete extraction result"""
    entities: List[Entity] = Field(description="List of extracted entities")
    relationships: List[Relationship] = Field(description="List of extracted relationships")


# ================================
# Pydantic Models for API
# ================================
class Query(BaseModel):
    question: str


class UploadResponse(BaseModel):
    message: str
    entities_count: int
    relationships_count: int


class QueryResponse(BaseModel):
    question: str
    answer: str
    # FIX: was `cypher_query: str = None`, which is invalid under Pydantic v2
    # (a None default requires an Optional annotation).
    cypher_query: Optional[str] = None


# ================================
# Prompt Templates with Pydantic Parser
# ================================

# Parser that renders JSON-schema format instructions and parses LLM output
# back into an ExtractionResult.
extraction_parser = PydanticOutputParser(pydantic_object=ExtractionResult)

# 1. Entity Extraction Prompt with Pydantic
ENTITY_EXTRACTION_TEMPLATE = """Extract entities and relationships from the text.

{format_instructions}

TEXT: {text}

Important:
- Extract people, organizations, products, technologies, concepts
- For relationships use: CREATED, FOUNDED, USES, BUILT_ON, WORKS_AT, CEO_OF, INTEGRATES_WITH
- Be specific and accurate

Your response:"""

entity_extraction_prompt = PromptTemplate(
    input_variables=["text"],
    template=ENTITY_EXTRACTION_TEMPLATE,
    partial_variables={"format_instructions": extraction_parser.get_format_instructions()}
)

# 2. Cypher Generation Prompt Template (doubled braces {{ }} are literal
# braces inside the PromptTemplate).
CYPHER_GENERATION_TEMPLATE = """You are a Neo4j Cypher expert. Generate accurate Cypher queries based on the graph schema and user questions.

GRAPH SCHEMA:
{schema}

USER QUESTION: {question}

INSTRUCTIONS:
1. Generate a Cypher query that answers the question
2. Only use node labels and relationship types present in the schema
3. Use MATCH to find patterns
4. Use WHERE for filtering
5. Use RETURN to specify what to return
6. Keep queries simple and efficient
7. For "tell me about X" questions, get the entity and its immediate connections
8. For "who/what/where" questions, find specific information

EXAMPLES:
Question: "Who is the CEO of Tesla?"
Cypher: MATCH (p:Person)-[:CEO_OF]->(c:Company {{name: 'Tesla'}}) RETURN p.name

Question: "What products does Tesla make?"
Cypher: MATCH (c:Company {{name: 'Tesla'}})-[:MANUFACTURES|CREATES]->(p:Product) RETURN p.name

Question: "Tell me about Elon Musk"
Cypher: MATCH (p:Person {{name: 'Elon Musk'}})-[r]-(connected) RETURN p, type(r), connected LIMIT 10

Return ONLY the Cypher query with no explanation or markdown:"""

cypher_generation_prompt = PromptTemplate(
    input_variables=["schema", "question"],
    template=CYPHER_GENERATION_TEMPLATE
)

# 3. Answer Generation Prompt Template
# NOTE(review): this prompt is currently NOT wired into the QA chain.
# GraphCypherQAChain's qa_prompt only supplies {context} and {question},
# so passing this template (which also needs {query}) would fail. Kept for
# potential standalone use.
ANSWER_GENERATION_TEMPLATE = """You are a helpful assistant that provides clear, concise answers based on graph database query results.

USER QUESTION: {question}

CYPHER QUERY USED: {query}

QUERY RESULTS: {context}

INSTRUCTIONS:
1. Answer the user's question based on the query results
2. Be clear and concise
3. If results are empty, say "I don't have information about that in the knowledge graph"
4. Don't mention technical details like Cypher or graph databases
5. Provide a natural, conversational answer

ANSWER:"""

answer_generation_prompt = PromptTemplate(
    input_variables=["question", "query", "context"],
    template=ANSWER_GENERATION_TEMPLATE
)


# ================================
# Startup
# ================================
@app.on_event("startup")
async def startup_event():
    """Initialize Neo4j connection, the LLM, and both LangChain chains.

    Raises on any failure so the app refuses to start half-configured.
    """
    global graph, llm, qa_chain, extraction_chain
    try:
        # Initialize Neo4j
        graph = Neo4jGraph(
            url=NEO4J_URI,
            username=NEO4J_USERNAME,
            password=NEO4J_PASSWORD
        )
        print("✅ Connected to Neo4j")

        # Initialize Qwen LLM (hosted inference; low temperature for
        # deterministic extraction/Cypher generation)
        llm = HuggingFaceHub(
            repo_id="Qwen/Qwen2.5-3B-Instruct",
            huggingfacehub_api_token=HF_TOKEN,
            model_kwargs={"temperature": 0.1, "max_new_tokens": 512}
        )
        print("✅ Qwen model loaded")

        # Create extraction chain with prompt template
        extraction_chain = LLMChain(
            llm=llm,
            prompt=entity_extraction_prompt,
            verbose=False
        )
        print("✅ Extraction chain initialized")

        # Setup QA chain with custom Cypher-generation prompt
        qa_chain = GraphCypherQAChain.from_llm(
            llm=llm,
            graph=graph,
            verbose=True,
            cypher_prompt=cypher_generation_prompt,
            return_intermediate_steps=True,
            top_k=10
        )
        print("✅ QA chain initialized")

    except Exception as e:
        print(f"❌ Startup error: {e}")
        raise


# ================================
# Helper Functions with Prompt Templates
# ================================
def extract_entities_relationships(text_chunk):
    """Extract entities and relationships using Pydantic structured output.

    Returns a dict {"entities": [...], "relationships": [...]}. Falls back
    to manual JSON extraction, then to rule-based extraction, if the LLM
    output cannot be parsed.
    """
    try:
        print(f"\n{'='*60}")
        print(f"Processing chunk: {text_chunk[:100]}...")

        # Use the extraction chain
        response = extraction_chain.run(text=text_chunk)

        print(f"RAW LLM RESPONSE:")
        print(response[:500])
        print('='*60)

        # Try to parse with Pydantic parser
        try:
            result = extraction_parser.parse(response)
            entities = [e.dict() for e in result.entities]
            relationships = [r.dict() for r in result.relationships]
            print(f"✅ PARSED with Pydantic:")
            print(f"   Entities: {len(entities)}")
            print(f"   Relationships: {len(relationships)}")
            return {"entities": entities, "relationships": relationships}
        except Exception as parse_error:
            print(f"⚠️ Pydantic parsing failed: {parse_error}")
            print("Trying manual JSON extraction...")

            # Fallback: strip markdown fences, then grab the outermost
            # {...} span and parse it as JSON.
            cleaned = response.strip()
            if "```json" in cleaned:
                cleaned = cleaned.split("```json")[1].split("```")[0]
            elif "```" in cleaned:
                cleaned = cleaned.split("```")[1].split("```")[0]

            if "{" in cleaned and "}" in cleaned:
                start = cleaned.find("{")
                end = cleaned.rfind("}") + 1
                cleaned = cleaned[start:end]

            data = json.loads(cleaned)
            print(f"✅ Manual JSON parse successful: {len(data.get('entities', []))} entities")
            return data

    except Exception as e:
        print(f"❌ All parsing failed: {e}")
        print("Using fallback extraction...")
        return fallback_extraction(text_chunk)


def fallback_extraction(text):
    """Simple rule-based fallback extraction.

    Treats runs of capitalized words as entities and mines
    "X created/developed Y" sentences for CREATED relationships.
    Caps output at 15 entities / 10 relationships.
    """
    print("⚠️ Using fallback extraction...")
    entities = []
    relationships = []
    seen_entities = set()

    # Split into sentences
    sentences = [s.strip() for s in text.split('.') if s.strip()]

    for sentence in sentences:
        words = sentence.split()

        # Extract capitalized word runs as entities
        current_entity = []

        def flush_entity():
            entity_name = ' '.join(current_entity)
            if entity_name not in seen_entities:
                entities.append({
                    "name": entity_name,
                    "type": "Concept",
                    "description": sentence[:100]
                })
                seen_entities.add(entity_name)
            current_entity.clear()

        for word in words:
            clean = re.sub(r'[^\w\s]', '', word)
            if clean and clean[0].isupper() and len(clean) > 2:
                current_entity.append(clean)
            elif current_entity:
                flush_entity()
        # FIX: flush a trailing run too — the original dropped an entity
        # that ended the sentence.
        if current_entity:
            flush_entity()

        # Check for common relationship patterns
        if ' created ' in sentence.lower() or ' developed ' in sentence.lower():
            parts = re.split(r' created | developed ', sentence, flags=re.IGNORECASE)
            if len(parts) == 2:
                left_words = parts[0].strip().split()
                right_words = parts[1].strip().split()
                # FIX: guard against an empty side (original raised
                # IndexError on e.g. "created X.").
                if left_words and right_words:
                    relationships.append({
                        "source": left_words[-1],
                        "target": right_words[0],
                        "type": "CREATED",
                        "context": sentence[:100]
                    })

    print(f"Fallback extracted: {len(entities)} entities, {len(relationships)} relationships")
    return {"entities": entities[:15], "relationships": relationships[:10]}


def add_to_graph(entities, relationships, doc_name):
    """Add entities and relationships to Neo4j.

    Property values are passed as query parameters (safe against Cypher
    injection from LLM output). Labels and relationship types cannot be
    parameterized in Cypher, so they are whitelist-sanitized to
    [A-Za-z0-9_] instead.

    Returns (entities_added, relationships_added).
    """
    entities_added = 0
    relationships_added = 0

    def sanitize_identifier(raw, default):
        # Labels / rel-types are interpolated into the query text, so only
        # allow identifier-safe characters.
        cleaned = re.sub(r'[^A-Za-z0-9_]', '', raw or '')
        return cleaned if cleaned else default

    def clean_text(text):
        # Collapse newlines; escaping is handled by parameterization.
        if not text:
            return ""
        return text.replace("\n", " ").strip()

    # Add entities
    for entity in entities:
        try:
            name = clean_text(entity.get('name', ''))
            if not name:
                continue

            entity_type = sanitize_identifier(entity.get('type', 'Entity'), 'Entity')
            desc = clean_text(entity.get('description', ''))

            query = f"""
            MERGE (e:{entity_type} {{name: $name}})
            ON CREATE SET e.description = $desc,
                          e.source_document = $doc,
                          e.created_at = datetime()
            ON MATCH SET e.description = CASE WHEN e.description = '' THEN $desc ELSE e.description END
            """
            graph.query(query, {"name": name, "desc": desc, "doc": doc_name})
            entities_added += 1
        except Exception as e:
            print(f"Error adding entity {entity.get('name', 'unknown')}: {e}")

    # Add relationships
    for rel in relationships:
        # Read raw values outside try so the except handler never sees
        # unbound names.
        source = clean_text(rel.get('source', ''))
        target = clean_text(rel.get('target', ''))
        try:
            rel_type = sanitize_identifier(
                rel.get('type', 'RELATED_TO').replace(" ", "_").upper(), 'RELATED_TO')
            context = clean_text(rel.get('context', ''))

            if not source or not target:
                continue

            query = f"""
            MATCH (a {{name: $source}})
            MATCH (b {{name: $target}})
            MERGE (a)-[r:{rel_type}]->(b)
            ON CREATE SET r.context = $context,
                          r.source_document = $doc,
                          r.created_at = datetime()
            """
            graph.query(query, {"source": source, "target": target,
                                "context": context, "doc": doc_name})
            relationships_added += 1
        except Exception as e:
            print(f"Error adding relationship {source} -> {target}: {e}")

    return entities_added, relationships_added


# ================================
# API Endpoints
# ================================
@app.get("/")
async def root():
    """Describe the API: endpoints and feature list."""
    return {
        "message": "Graph RAG API with Proper Prompt Templates",
        "model": "Qwen/Qwen2.5-3B-Instruct",
        "endpoints": {
            "POST /upload": "Upload document to build knowledge graph",
            "POST /query": "Ask questions about your documents",
            "GET /stats": "Get graph statistics",
            "GET /health": "Check system health",
            "DELETE /clear": "Clear the graph"
        },
        "features": [
            "Structured entity extraction with PromptTemplate",
            "Cypher generation with custom prompt",
            "Answer generation with context",
            "Natural language question answering"
        ]
    }


@app.get("/health")
async def health_check():
    """Check if all systems are operational"""
    try:
        # Test Neo4j connection
        graph.query("MATCH (n) RETURN count(n) as count LIMIT 1")
        neo4j_status = "connected"
    except Exception:  # FIX: was a bare except (also caught KeyboardInterrupt)
        neo4j_status = "disconnected"

    return {
        "status": "healthy" if neo4j_status == "connected" else "unhealthy",
        "neo4j": neo4j_status,
        "llm": "loaded" if llm else "not loaded",
        "extraction_chain": "ready" if extraction_chain else "not ready",
        "qa_chain": "ready" if qa_chain else "not ready"
    }


@app.post("/upload", response_model=UploadResponse)
async def upload_document(file: UploadFile = File(...)):
    """Upload a .txt document, extract its entities/relationships chunk by
    chunk, and write them into the knowledge graph.
    """
    if not file.filename.endswith('.txt'):
        raise HTTPException(400, "Only .txt files supported")

    try:
        # Read file
        content = await file.read()
        text = content.decode('utf-8')
        print(f"Processing document: {file.filename} ({len(text)} characters)")

        # Split into overlapping chunks so entities spanning boundaries
        # are still seen whole.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=150,
            length_function=len
        )
        chunks = text_splitter.split_text(text)
        print(f"Split into {len(chunks)} chunks")

        total_entities = 0
        total_relationships = 0

        # Process chunks (limit to 10 for demo, remove limit in production)
        max_chunks = min(len(chunks), 10)
        for i, chunk in enumerate(chunks[:max_chunks]):
            print(f"Processing chunk {i+1}/{max_chunks}")

            # Use extraction chain with prompt template
            extracted = extract_entities_relationships(chunk)
            entities = extracted.get("entities", [])
            relationships = extracted.get("relationships", [])
            print(f"  Extracted: {len(entities)} entities, {len(relationships)} relationships")

            # Add to graph
            e_count, r_count = add_to_graph(entities, relationships, file.filename)
            total_entities += e_count
            total_relationships += r_count

        return UploadResponse(
            message=f"Successfully processed {file.filename} ({max_chunks} chunks)",
            entities_count=total_entities,
            relationships_count=total_relationships
        )

    except Exception as e:
        print(f"Upload error: {e}")
        raise HTTPException(500, f"Error processing document: {str(e)}")


@app.post("/query", response_model=QueryResponse)
async def query_graph(query: Query):
    """Ask a question using GraphCypherQAChain with custom prompts"""
    try:
        print(f"Question: {query.question}")

        # Use QA chain with custom cypher generation prompt
        response = qa_chain.invoke({"query": query.question})

        # Extract generated Cypher from the intermediate steps
        cypher = None
        if "intermediate_steps" in response and len(response["intermediate_steps"]) > 0:
            cypher = response["intermediate_steps"][0].get("query", "")

        print(f"Generated Cypher: {cypher}")
        print(f"Answer: {response['result']}")

        return QueryResponse(
            question=query.question,
            answer=response["result"],
            cypher_query=cypher
        )

    except Exception as e:
        print(f"Query error: {e}")
        raise HTTPException(500, f"Error answering question: {str(e)}")


@app.get("/stats")
async def get_stats():
    """Get graph statistics: counts by label/type plus a sample of nodes."""
    try:
        # Count nodes by type
        nodes = graph.query("""
            MATCH (n)
            RETURN labels(n)[0] as type, count(n) as count
            ORDER BY count DESC
        """)

        # Count relationships by type
        rels = graph.query("""
            MATCH ()-[r]->()
            RETURN type(r) as type, count(r) as count
            ORDER BY count DESC
        """)

        # Get total counts
        total_nodes = sum([n['count'] for n in nodes]) if nodes else 0
        total_rels = sum([r['count'] for r in rels]) if rels else 0

        # Get sample entities
        samples = graph.query("""
            MATCH (n)
            RETURN labels(n)[0] as type, n.name as name
            LIMIT 10
        """)

        return {
            "total_nodes": total_nodes,
            "total_relationships": total_rels,
            "nodes_by_type": nodes,
            "relationships_by_type": rels,
            "sample_entities": samples
        }

    except Exception as e:
        raise HTTPException(500, f"Error getting stats: {str(e)}")
@app.delete("/clear")
async def clear_graph():
    """Delete every node (and its relationships) from the knowledge graph."""
    try:
        # Capture the node count first so the response can report it.
        count_rows = graph.query("MATCH (n) RETURN count(n) as count")
        deleted = count_rows[0]['count'] if count_rows else 0

        # DETACH DELETE removes relationships along with their nodes.
        graph.query("MATCH (n) DETACH DELETE n")

        return {
            "message": f"Graph cleared successfully. Deleted {deleted} nodes.",
            "nodes_deleted": deleted,
        }
    except Exception as e:
        raise HTTPException(500, f"Error clearing graph: {str(e)}")


# ================================
# Run
# ================================
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)