Spaces:
Sleeping
Sleeping
File size: 6,630 Bytes
a30a065 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import asyncio
import json
import os
import re
from pathlib import Path
from typing import AsyncGenerator

import sqlparse
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import Ollama
from pydantic import BaseModel
app = FastAPI(title="NL to SQL Multi-Agent System")

# The browser UI talks to this API directly, so CORS is opened to every
# origin, method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider restricting allow_origins for a public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Static assets live in ./static next to this module. The directory is
# created up front so mounting it cannot fail on a fresh checkout.
static_dir = Path(__file__).parent.joinpath("static")
static_dir.mkdir(exist_ok=True)
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Request model
class ExecutionRequest(BaseModel):
    """JSON body of POST /execute: the question, the schema, and the agent prompt templates."""

    # Full database schema text; fed to the Schema Analyzer prompt as {schema}.
    # NOTE(review): this field name shadows an attribute of pydantic's BaseModel
    # (`schema`); depending on the pydantic version this warns or errors at
    # class creation -- consider renaming with an alias.
    schema: str
    # Schema Analyzer template; execute_pipeline declares variables {schema}, {question}.
    schema_prompt: str
    # Query Generator template; execute_pipeline declares variables {question}, {relevant_schema}.
    query_prompt: str
    # NOTE(review): not read by execute_pipeline as far as this file shows --
    # syntax is checked deterministically by validate_syntax instead.
    syntax_prompt: str
    # Semantic Verifier template; execute_pipeline declares variables {question}, {sql_query}.
    semantic_prompt: str
    # Natural-language question to translate into SQL.
    question: str
    # Ollama model name handed to get_llm.
    model: str = "phi3"
    # Maximum number of generate -> validate rounds the pipeline runs.
    max_iterations: int = 3
# Initialize LLM
def get_llm(model_name: str):
    """Return an Ollama LLM client for ``model_name`` with temperature fixed at 0.1."""
    client_options = {"model": model_name, "temperature": 0.1}
    return Ollama(**client_options)
# Syntax validator (deterministic)

# An opening fence (optionally followed by a language tag and newline) at the
# very start, or a closing fence at the very end, of an LLM reply.
_CODE_FENCE_RE = re.compile(r"^\s*```(?:[\w-]*[ \t]*\r?\n)?|\s*```\s*$")


def strip_code_fence(text: str) -> str:
    """Remove a Markdown code fence wrapped around ``text``, if there is one.

    ``"```sql\\nSELECT 1\\n```"`` becomes ``"SELECT 1"``; text without a
    surrounding fence is returned unchanged.
    """
    return _CODE_FENCE_RE.sub("", text)


def validate_syntax(sql_query: str):
    """Cheaply check that ``sql_query`` looks like a SQL statement.

    sqlparse is a non-validating parser: it does not raise on malformed input
    and will tokenize plain English. "It parsed" alone therefore accepts any
    non-empty text, including an LLM's prose answer. This check also requires
    the first statement to begin with a keyword sqlparse classifies
    (``Statement.get_type() != "UNKNOWN"``, e.g. SELECT/INSERT/CREATE).

    A surrounding Markdown fence is stripped first. Without that, sqlparse
    would read the whole fenced block as a single backtick-quoted identifier.

    Returns:
        ``(is_valid, message)`` -- a bool and a human-readable explanation.
    """
    try:
        statements = [
            stmt
            for stmt in sqlparse.parse(strip_code_fence(sql_query or ""))
            if str(stmt).strip()  # ignore whitespace-only fragments
        ]
        if not statements:
            return False, "Empty or invalid SQL"
        if statements[0].get_type() == "UNKNOWN":
            return False, "Syntax error: not a recognized SQL statement"
        return True, "Syntax valid"
    except Exception as e:
        return False, f"Syntax error: {str(e)}"
# Stream event helper
def create_event(event_type: str, **kwargs):
    """Format one Server-Sent Events message.

    The payload is a JSON object whose first key is ``"type"`` (set to
    ``event_type``), followed by the keyword arguments in call order. It is
    framed as ``data: <json>`` plus the blank line that terminates an SSE event.
    """
    payload = dict(type=event_type)
    payload.update(kwargs)
    body = json.dumps(payload)
    return "data: " + body + "\n\n"
# Main agent pipeline
async def _run_chain(llm, template: str, input_variables: list, inputs: dict):
    """Render ``template`` with ``inputs``, run it through ``llm``, and return the output.

    ``LLMChain.invoke`` is synchronous. It is therefore run in the event loop's
    default thread pool, because calling it directly from the async pipeline
    would block the event loop -- stalling every other request -- for the whole
    duration of each LLM call.

    Returns the ``"text"`` entry when the chain returns a dict, otherwise the
    raw result (same extraction rule the pipeline used inline before).
    """
    prompt = PromptTemplate(input_variables=input_variables, template=template)
    chain = LLMChain(llm=llm, prompt=prompt)
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, chain.invoke, inputs)
    return result.get("text", result) if isinstance(result, dict) else result


async def execute_pipeline(request: ExecutionRequest) -> AsyncGenerator[str, None]:
    """Run the multi-agent NL-to-SQL pipeline, yielding Server-Sent Events strings.

    Stages:
      1. Schema Analyzer (once) -- condenses ``request.schema`` for the question.
      2. Up to ``request.max_iterations`` rounds of
         Query Generator -> Syntax Validator -> Semantic Verifier.
         A syntax failure skips straight to the next round. The loop stops
         early only when the semantic verifier's reply contains the word YES.

    The closing ``final_result`` event carries:
      - ``sql``: the last generated query, or ``"No valid SQL generated"`` if
        no round ran;
      - ``validated``: True only if that query passed both the syntax and the
        semantic check.

    Any exception is reported as a single ``error`` event and ends the stream.
    """
    try:
        llm = get_llm(request.model)

        # Agent 1: Schema Analyzer
        yield create_event("agent_start", agent="Schema Analyzer")
        yield create_event("agent_input", content=f"Schema: {request.schema[:100]}... | Question: {request.question}")
        relevant_schema = await _run_chain(
            llm,
            request.schema_prompt,
            ["schema", "question"],
            {"schema": request.schema, "question": request.question},
        )
        yield create_event("agent_output", content=relevant_schema.strip())

        # Iteration loop
        sql_query = None
        validated = False
        for iteration in range(request.max_iterations):
            yield create_event("iteration", iteration=iteration + 1)

            # Agent 2: Query Generator
            yield create_event("agent_start", agent="Query Generator")
            yield create_event("agent_input", content=f"Relevant schema: {relevant_schema[:100]}...")
            sql_query = await _run_chain(
                llm,
                request.query_prompt,
                ["question", "relevant_schema"],
                {"question": request.question, "relevant_schema": relevant_schema},
            )
            sql_query = sql_query.strip()
            yield create_event("agent_output", content=sql_query)

            # Agent 3: Syntax Validator (deterministic, no LLM call)
            yield create_event("agent_start", agent="Syntax Validator")
            is_valid, syntax_msg = validate_syntax(sql_query)
            if not is_valid:
                yield create_event("validation", content=syntax_msg, status="fail")
                continue  # regenerate in the next round
            yield create_event("validation", content=syntax_msg, status="pass")

            # Agent 4: Semantic Verifier
            yield create_event("agent_start", agent="Semantic Verifier")
            yield create_event("agent_input", content=f"Checking if SQL answers: {request.question}")
            verification = await _run_chain(
                llm,
                request.semantic_prompt,
                ["question", "sql_query"],
                {"question": request.question, "sql_query": sql_query},
            )
            yield create_event("agent_output", content=verification.strip())
            # Whole-word match: a plain substring test also fired on words
            # such as "YESTERDAY" or "EYES" appearing in the verifier's reply.
            if re.search(r"\bYES\b", verification, re.IGNORECASE):
                validated = True
                yield create_event("validation", content="Query is semantically correct", status="pass")
                break
            yield create_event("validation", content="Query has semantic issues", status="fail")

        # Final result
        yield create_event(
            "final_result",
            sql=sql_query if sql_query else "No valid SQL generated",
            validated=validated,
        )
    except Exception as e:
        yield create_event("error", message=str(e))
@app.get("/")
async def root():
    """Serve static/index.html, or a JSON hint when that file has not been added yet."""
    page = static_dir / "index.html"
    if not page.exists():
        return {"message": "Place index.html in the static/ directory"}
    return FileResponse(page)
@app.post("/execute")
async def execute(request: ExecutionRequest):
    """Stream the NL-to-SQL pipeline's progress to the client as Server-Sent Events.

    The response body is the async generator returned by ``execute_pipeline``.
    Each chunk it yields is one ``data: ...`` event.
    """
    events = execute_pipeline(request)
    return StreamingResponse(events, media_type="text/event-stream")
@app.get("/health")
async def health():
    """Liveness probe: always answers ``{"status": "ok"}``."""
    return dict(status="ok")
@app.get("/models")
async def list_models():
    """Return a fixed list of suggested Ollama model names.

    The list is hard-coded; the local Ollama server is not queried, so entries
    may name models that have not been pulled.
    TODO: read the installed models from Ollama instead.
    """
    suggested = ["phi3", "llama3.2:3b", "gemma2:2b", "mistral"]
    return {"models": suggested}
if __name__ == "__main__":
    import uvicorn

    # HuggingFace Spaces uses port 7860; a PORT environment variable overrides it.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, port=listen_port, host="0.0.0.0")
|