# ================================
# app.py - FastAPI Graph RAG API with Proper Prompt Templates
# ================================
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from langchain.chains import GraphCypherQAChain, LLMChain
from langchain_community.graphs import Neo4jGraph
from langchain_community.llms import HuggingFaceHub
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.output_parsers import PydanticOutputParser
from typing import List, Optional
import os
import json
import re
import uvicorn

# ================================
# Configuration
# ================================
app = FastAPI(
    title="Graph RAG API",
    description="Document QA using Graph RAG with Qwen model",
    version="1.0.0"
)

# CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
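# Production note: allow_origins=["*"] together with allow_credentials=True is
# the most permissive CORS setup; a real deployment would usually list explicit
# origins instead, e.g.:
#   allow_origins=["https://your-frontend.example"]  # hypothetical origin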
# Environment variables
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
HF_TOKEN = os.getenv("HF_TOKEN")

# Global variables
graph = None
llm = None
qa_chain = None
extraction_chain = None

# ================================
# Pydantic Models for Extraction
# ================================
class Entity(BaseModel):
    """Single entity extracted from text"""
    name: str = Field(description="The name of the entity")
    type: str = Field(description="Type: Person, Organization, Product, Technology, Concept, Location")
    description: Optional[str] = Field(default="", description="Brief description of the entity")

class Relationship(BaseModel):
    """Relationship between two entities"""
    source: str = Field(description="Source entity name")
    target: str = Field(description="Target entity name")
    type: str = Field(description="Relationship type in UPPER_SNAKE_CASE (e.g., CREATED, FOUNDED, USES)")
    context: Optional[str] = Field(default="", description="Context of the relationship")

class ExtractionResult(BaseModel):
    """Complete extraction result"""
    entities: List[Entity] = Field(description="List of extracted entities")
    relationships: List[Relationship] = Field(description="List of extracted relationships")
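# Illustrative only: the JSON shape the parser below expects from the LLM,
# derived from the Pydantic models above (example values are made up):
# {
#   "entities": [
#     {"name": "Tesla", "type": "Organization", "description": "EV maker"}
#   ],
#   "relationships": [
#     {"source": "Elon Musk", "target": "Tesla", "type": "CEO_OF", "context": "..."}
#   ]
# }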
# ================================
# Pydantic Models for API
# ================================
class Query(BaseModel):
    question: str

class UploadResponse(BaseModel):
    message: str
    entities_count: int
    relationships_count: int

class QueryResponse(BaseModel):
    question: str
    answer: str
    cypher_query: Optional[str] = None
# ================================
# Prompt Templates with Pydantic Parser
# ================================
# Create parser for structured output
extraction_parser = PydanticOutputParser(pydantic_object=ExtractionResult)

# 1. Entity Extraction Prompt with Pydantic
ENTITY_EXTRACTION_TEMPLATE = """Extract entities and relationships from the text.

{format_instructions}

TEXT:
{text}

Important:
- Extract people, organizations, products, technologies, concepts
- For relationships use: CREATED, FOUNDED, USES, BUILT_ON, WORKS_AT, CEO_OF, INTEGRATES_WITH
- Be specific and accurate

Your response:"""

entity_extraction_prompt = PromptTemplate(
    input_variables=["text"],
    template=ENTITY_EXTRACTION_TEMPLATE,
    partial_variables={"format_instructions": extraction_parser.get_format_instructions()}
)
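# The parser's format instructions are baked in via partial_variables, so the
# chain only supplies {text} at run time. Quick sanity check (sample sentence
# is hypothetical):
#   print(entity_extraction_prompt.format(text="Elon Musk founded SpaceX."))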
# 2. Cypher Generation Prompt Template
CYPHER_GENERATION_TEMPLATE = """You are a Neo4j Cypher expert. Generate accurate Cypher queries based on the graph schema and user questions.

GRAPH SCHEMA:
{schema}

USER QUESTION:
{question}

INSTRUCTIONS:
1. Generate a Cypher query that answers the question
2. Only use node labels and relationship types present in the schema
3. Use MATCH to find patterns
4. Use WHERE for filtering
5. Use RETURN to specify what to return
6. Keep queries simple and efficient
7. For "tell me about X" questions, get the entity and its immediate connections
8. For "who/what/where" questions, find specific information

EXAMPLES:
Question: "Who is the CEO of Tesla?"
Cypher: MATCH (p:Person)-[:CEO_OF]->(c:Company {{name: 'Tesla'}}) RETURN p.name

Question: "What products does Tesla make?"
Cypher: MATCH (c:Company {{name: 'Tesla'}})-[:MANUFACTURES|CREATES]->(p:Product) RETURN p.name

Question: "Tell me about Elon Musk"
Cypher: MATCH (p:Person {{name: 'Elon Musk'}})-[r]-(connected) RETURN p, type(r), connected LIMIT 10

Return ONLY the Cypher query with no explanation or markdown:"""

cypher_generation_prompt = PromptTemplate(
    input_variables=["schema", "question"],
    template=CYPHER_GENERATION_TEMPLATE
)

# 3. Answer Generation Prompt Template
ANSWER_GENERATION_TEMPLATE = """You are a helpful assistant that provides clear, concise answers based on graph database query results.

USER QUESTION:
{question}

CYPHER QUERY USED:
{query}

QUERY RESULTS:
{context}

INSTRUCTIONS:
1. Answer the user's question based on the query results
2. Be clear and concise
3. If results are empty, say "I don't have information about that in the knowledge graph"
4. Don't mention technical details like Cypher or graph databases
5. Provide a natural, conversational answer

ANSWER:"""

answer_generation_prompt = PromptTemplate(
    input_variables=["question", "query", "context"],
    template=ANSWER_GENERATION_TEMPLATE
)
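# Note: GraphCypherQAChain fills its answer prompt with only {question} and
# {context}, so the template above (which also expects {query}) is not passed
# to the chain in startup_event below. To wire it in via
# qa_prompt=answer_generation_prompt, the CYPHER QUERY USED section would have
# to be dropped first.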
# ================================
# Startup
# ================================
@app.on_event("startup")
async def startup_event():
    global graph, llm, qa_chain, extraction_chain
    try:
        # Initialize Neo4j
        graph = Neo4jGraph(
            url=NEO4J_URI,
            username=NEO4J_USERNAME,
            password=NEO4J_PASSWORD
        )
        print("✅ Connected to Neo4j")

        # Initialize Qwen LLM
        llm = HuggingFaceHub(
            repo_id="Qwen/Qwen2.5-3B-Instruct",
            huggingfacehub_api_token=HF_TOKEN,
            model_kwargs={"temperature": 0.1, "max_new_tokens": 512}
        )
        print("✅ Qwen model loaded")

        # Create extraction chain with prompt template
        extraction_chain = LLMChain(
            llm=llm,
            prompt=entity_extraction_prompt,
            verbose=False
        )
        print("✅ Extraction chain initialized")

        # Setup QA chain with custom prompt
        # (newer langchain releases also require allow_dangerous_requests=True here)
        qa_chain = GraphCypherQAChain.from_llm(
            llm=llm,
            graph=graph,
            verbose=True,
            cypher_prompt=cypher_generation_prompt,
            return_intermediate_steps=True,
            top_k=10
        )
        print("✅ QA chain initialized")
    except Exception as e:
        print(f"❌ Startup error: {e}")
        raise
# ================================
# Helper Functions with Prompt Templates
# ================================
def extract_entities_relationships(text_chunk):
    """Extract entities and relationships using Pydantic structured output"""
    try:
        print(f"\n{'='*60}")
        print(f"Processing chunk: {text_chunk[:100]}...")

        # Use the extraction chain
        response = extraction_chain.run(text=text_chunk)

        print("RAW LLM RESPONSE:")
        print(response[:500])
        print('='*60)

        # Try to parse with Pydantic parser
        try:
            result = extraction_parser.parse(response)
            entities = [e.dict() for e in result.entities]
            relationships = [r.dict() for r in result.relationships]
            print("✅ PARSED with Pydantic:")
            print(f"   Entities: {len(entities)}")
            print(f"   Relationships: {len(relationships)}")
            return {"entities": entities, "relationships": relationships}
        except Exception as parse_error:
            print(f"⚠️ Pydantic parsing failed: {parse_error}")
            print("Trying manual JSON extraction...")

            # Fallback: Try manual JSON extraction
            cleaned = response.strip()

            # Remove markdown fences
            if "```json" in cleaned:
                cleaned = cleaned.split("```json")[1].split("```")[0]
            elif "```" in cleaned:
                cleaned = cleaned.split("```")[1].split("```")[0]

            # Keep only the outermost JSON object
            if "{" in cleaned and "}" in cleaned:
                start = cleaned.find("{")
                end = cleaned.rfind("}") + 1
                cleaned = cleaned[start:end]

            data = json.loads(cleaned)
            print(f"✅ Manual JSON parse successful: {len(data.get('entities', []))} entities")
            return data
    except Exception as e:
        print(f"❌ All parsing failed: {e}")
        print("Using fallback extraction...")
        return fallback_extraction(text_chunk)
def fallback_extraction(text):
    """Simple rule-based fallback extraction"""
    print("⚠️ Using fallback extraction...")
    entities = []
    relationships = []
    seen_entities = set()

    def add_entity(name, sentence):
        """Record a capitalized phrase as a Concept entity, deduplicated."""
        if name not in seen_entities:
            entities.append({
                "name": name,
                "type": "Concept",
                "description": sentence[:100]
            })
            seen_entities.add(name)

    # Split into sentences
    sentences = [s.strip() for s in text.split('.') if s.strip()]

    for sentence in sentences:
        words = sentence.split()

        # Extract capitalized words/phrases as entities
        current_entity = []
        for word in words:
            clean = re.sub(r'[^\w\s]', '', word)
            if clean and clean[0].isupper() and len(clean) > 2:
                current_entity.append(clean)
            elif current_entity:
                add_entity(' '.join(current_entity), sentence)
                current_entity = []
        # Flush a phrase that runs to the end of the sentence
        if current_entity:
            add_entity(' '.join(current_entity), sentence)

        # Check for common relationship patterns
        if ' created ' in sentence.lower() or ' developed ' in sentence.lower():
            # Try to extract creator and creation
            parts = re.split(r' created | developed ', sentence, flags=re.IGNORECASE)
            if len(parts) == 2:
                creator = parts[0].strip().split()[-1]
                creation = parts[1].strip().split()[0]
                relationships.append({
                    "source": creator,
                    "target": creation,
                    "type": "CREATED",
                    "context": sentence[:100]
                })

    print(f"Fallback extracted: {len(entities)} entities, {len(relationships)} relationships")
    return {"entities": entities[:15], "relationships": relationships[:10]}
def add_to_graph(entities, relationships, doc_name):
    """Add entities and relationships to Neo4j with proper sanitization"""
    entities_added = 0
    relationships_added = 0

    # Helper function to sanitize strings for Cypher
    def sanitize(text):
        if not text:
            return ""
        return text.replace("'", "\\'").replace('"', '\\"').replace("\n", " ")

    # Add entities
    for entity in entities:
        try:
            name = sanitize(entity.get('name', ''))
            if not name:
                continue
            entity_type = entity.get('type', 'Entity').replace(" ", "")
            desc = sanitize(entity.get('description', ''))

            query = f"""
            MERGE (e:{entity_type} {{name: '{name}'}})
            ON CREATE SET
                e.description = '{desc}',
                e.source_document = '{doc_name}',
                e.created_at = datetime()
            ON MATCH SET
                e.description = CASE WHEN e.description = '' THEN '{desc}' ELSE e.description END
            """
            graph.query(query)
            entities_added += 1
        except Exception as e:
            print(f"Error adding entity {entity.get('name', 'unknown')}: {e}")

    # Add relationships
    for rel in relationships:
        try:
            source = sanitize(rel.get('source', ''))
            target = sanitize(rel.get('target', ''))
            rel_type = rel.get('type', 'RELATED_TO').replace(" ", "_").upper()
            context = sanitize(rel.get('context', ''))
            if not source or not target:
                continue

            query = f"""
            MATCH (a {{name: '{source}'}})
            MATCH (b {{name: '{target}'}})
            MERGE (a)-[r:{rel_type}]->(b)
            ON CREATE SET
                r.context = '{context}',
                r.source_document = '{doc_name}',
                r.created_at = datetime()
            """
            graph.query(query)
            relationships_added += 1
        except Exception as e:
            # Read from rel directly: source/target may be unbound if sanitize raised
            print(f"Error adding relationship {rel.get('source', '?')} -> {rel.get('target', '?')}: {e}")

    return entities_added, relationships_added
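# A safer alternative to the f-string Cypher above is a parameterized query;
# Neo4jGraph.query accepts a params dict. Minimal sketch (not called anywhere,
# and it uses a fixed :Entity label since labels cannot be query parameters):
def add_entity_parameterized(entity: dict, doc_name: str):
    """Hedged sketch: MERGE one entity using query parameters."""
    query = """
    MERGE (e:Entity {name: $name})
    ON CREATE SET
        e.description = $desc,
        e.source_document = $doc,
        e.created_at = datetime()
    """
    graph.query(query, params={
        "name": entity.get("name", ""),
        "desc": entity.get("description", ""),
        "doc": doc_name,
    })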
# ================================
# API Endpoints
# ================================
@app.get("/")
async def root():
    return {
        "message": "Graph RAG API with Proper Prompt Templates",
        "model": "Qwen/Qwen2.5-3B-Instruct",
        "endpoints": {
            "POST /upload": "Upload document to build knowledge graph",
            "POST /query": "Ask questions about your documents",
            "GET /stats": "Get graph statistics",
            "GET /health": "Check system health",
            "DELETE /clear": "Clear the graph"
        },
        "features": [
            "Structured entity extraction with PromptTemplate",
            "Cypher generation with custom prompt",
            "Answer generation with context",
            "Natural language question answering"
        ]
    }
@app.get("/health")
async def health_check():
    """Check if all systems are operational"""
    try:
        # Test Neo4j connection
        graph.query("MATCH (n) RETURN count(n) as count LIMIT 1")
        neo4j_status = "connected"
    except Exception:
        neo4j_status = "disconnected"

    return {
        "status": "healthy" if neo4j_status == "connected" else "unhealthy",
        "neo4j": neo4j_status,
        "llm": "loaded" if llm else "not loaded",
        "extraction_chain": "ready" if extraction_chain else "not ready",
        "qa_chain": "ready" if qa_chain else "not ready"
    }
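# Example check once the server is running on the port configured below:
#   curl http://localhost:7860/health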
@app.post("/upload", response_model=UploadResponse)
async def upload_document(file: UploadFile = File(...)):
    """Upload document and build knowledge graph using prompt templates"""
    if not file.filename.endswith('.txt'):
        raise HTTPException(400, "Only .txt files supported")

    try:
        # Read file
        content = await file.read()
        text = content.decode('utf-8')
        print(f"Processing document: {file.filename} ({len(text)} characters)")

        # Split into chunks
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=150,
            length_function=len
        )
        chunks = text_splitter.split_text(text)
        print(f"Split into {len(chunks)} chunks")

        total_entities = 0
        total_relationships = 0

        # Process chunks (limit to 10 for demo, remove limit in production)
        max_chunks = min(len(chunks), 10)
        for i, chunk in enumerate(chunks[:max_chunks]):
            print(f"Processing chunk {i+1}/{max_chunks}")

            # Use extraction chain with prompt template
            extracted = extract_entities_relationships(chunk)
            entities = extracted.get("entities", [])
            relationships = extracted.get("relationships", [])
            print(f"  Extracted: {len(entities)} entities, {len(relationships)} relationships")

            # Add to graph
            e_count, r_count = add_to_graph(entities, relationships, file.filename)
            total_entities += e_count
            total_relationships += r_count

        return UploadResponse(
            message=f"Successfully processed {file.filename} ({max_chunks} chunks)",
            entities_count=total_entities,
            relationships_count=total_relationships
        )
    except Exception as e:
        print(f"Upload error: {e}")
        raise HTTPException(500, f"Error processing document: {str(e)}")
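# Example upload (file name is hypothetical):
#   curl -X POST http://localhost:7860/upload -F "file=@notes.txt"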
@app.post("/query", response_model=QueryResponse)
async def query_graph(query: Query):
    """Ask a question using GraphCypherQAChain with custom prompts"""
    try:
        print(f"Question: {query.question}")

        # Use QA chain with custom cypher generation prompt
        response = qa_chain.invoke({"query": query.question})

        # Extract generated Cypher
        cypher = None
        if "intermediate_steps" in response and len(response["intermediate_steps"]) > 0:
            cypher = response["intermediate_steps"][0].get("query", "")

        print(f"Generated Cypher: {cypher}")
        print(f"Answer: {response['result']}")

        return QueryResponse(
            question=query.question,
            answer=response["result"],
            cypher_query=cypher
        )
    except Exception as e:
        print(f"Query error: {e}")
        raise HTTPException(500, f"Error answering question: {str(e)}")
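# Example query (question is illustrative):
#   curl -X POST http://localhost:7860/query \
#     -H "Content-Type: application/json" \
#     -d '{"question": "Who is the CEO of Tesla?"}'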
@app.get("/stats")
async def get_stats():
    """Get graph statistics"""
    try:
        # Count nodes by type
        nodes = graph.query("""
            MATCH (n)
            RETURN labels(n)[0] as type, count(n) as count
            ORDER BY count DESC
        """)

        # Count relationships by type
        rels = graph.query("""
            MATCH ()-[r]->()
            RETURN type(r) as type, count(r) as count
            ORDER BY count DESC
        """)

        # Get total counts
        total_nodes = sum(n['count'] for n in nodes) if nodes else 0
        total_rels = sum(r['count'] for r in rels) if rels else 0

        # Get sample entities
        samples = graph.query("""
            MATCH (n)
            RETURN labels(n)[0] as type, n.name as name
            LIMIT 10
        """)

        return {
            "total_nodes": total_nodes,
            "total_relationships": total_rels,
            "nodes_by_type": nodes,
            "relationships_by_type": rels,
            "sample_entities": samples
        }
    except Exception as e:
        raise HTTPException(500, f"Error getting stats: {str(e)}")
@app.delete("/clear")
async def clear_graph():
    """Clear all data from graph"""
    try:
        result = graph.query("MATCH (n) RETURN count(n) as count")
        count = result[0]['count'] if result else 0
        graph.query("MATCH (n) DETACH DELETE n")
        return {
            "message": f"Graph cleared successfully. Deleted {count} nodes.",
            "nodes_deleted": count
        }
    except Exception as e:
        raise HTTPException(500, f"Error clearing graph: {str(e)}")

# ================================
# Run
# ================================
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)