# santanche's picture
# feat (start): first setup
# a30a065
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
import sqlparse
import json
import asyncio
from typing import AsyncGenerator
from pathlib import Path
import os
# FastAPI application object; the route handlers below register on it.
app = FastAPI(title="NL to SQL Multi-Agent System")
# Enable CORS for web interface: every origin, method and header is allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is worth
# checking against the FastAPI/Starlette CORS documentation - confirm that
# credentialed cross-origin requests from any site are really intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount static files directory (a "static" folder next to this file, served
# under /static). The folder is created if missing before it is mounted,
# presumably so that StaticFiles always finds an existing directory.
static_dir = Path(__file__).parent / "static"
static_dir.mkdir(exist_ok=True)
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Request model
class ExecutionRequest(BaseModel):
    # Request body accepted by POST /execute. The prompt fields are templates
    # supplied by the client and filled in by the pipeline.

    # Database schema text; passed to the schema prompt as {schema}.
    # NOTE(review): `schema` is also the name of a BaseModel attribute, so
    # pydantic may warn about shadowing - confirm with the pinned version.
    schema: str
    # Template for the Schema Analyzer; receives {schema} and {question}.
    schema_prompt: str
    # Template for the Query Generator; receives {question} and
    # {relevant_schema}.
    query_prompt: str
    # Not read anywhere in the visible code (syntax validation is done
    # deterministically, without an LLM prompt).
    syntax_prompt: str
    # Template for the Semantic Verifier; receives {question} and {sql_query}.
    # The pipeline treats a response containing "YES" as approval.
    semantic_prompt: str
    # Natural-language question to translate into SQL.
    question: str
    # Ollama model name handed to get_llm().
    model: str = "phi3"
    # Upper bound on generate/validate/verify attempts.
    max_iterations: int = 3
# Initialize LLM
def get_llm(model_name: str):
    """Build an Ollama LLM client for ``model_name`` with temperature 0.1."""
    llm_options = {"model": model_name, "temperature": 0.1}
    return Ollama(**llm_options)
# Syntax validator (deterministic)
def validate_syntax(sql_query: str):
    """Run a lightweight, deterministic sanity check on a SQL string.

    The check only confirms that the input is non-blank text which
    ``sqlparse`` can split into at least one statement. ``sqlparse`` is a
    non-validating parser, so a passing result does not guarantee that a
    database would accept the query.

    Args:
        sql_query: SQL text to check (in this file, the stripped LLM output).

    Returns:
        A ``(is_valid, message)`` tuple, where ``message`` is a short
        human-readable description of the outcome.
    """
    # Reject missing or whitespace-only input up front instead of relying on
    # how the parser happens to treat it.
    if not isinstance(sql_query, str) or not sql_query.strip():
        return False, "Empty or invalid SQL"

    try:
        statements = sqlparse.parse(sql_query)
    except Exception as exc:  # any parser failure is reported, not raised
        return False, f"Syntax error: {exc}"

    if not statements:
        return False, "Empty or invalid SQL"
    return True, "Syntax valid"
# Stream event helper
def create_event(event_type: str, **kwargs):
    """Serialize one event as a ``data:`` frame for the event stream.

    The payload is a JSON object whose ``type`` key holds ``event_type``,
    followed by any extra keyword arguments. The frame is terminated by a
    blank line.
    """
    payload = {"type": event_type}
    payload.update(kwargs)
    encoded = json.dumps(payload)
    return "data: " + encoded + "\n\n"
# Main agent pipeline
def _chain_text(result):
    """Return the generated text from a chain result (dict or plain value)."""
    if isinstance(result, dict):
        return result.get("text", result)
    return result


def _build_chain(llm, template, input_variables):
    """Wrap a client-supplied template in an LLMChain bound to ``llm``."""
    prompt = PromptTemplate(input_variables=input_variables, template=template)
    return LLMChain(llm=llm, prompt=prompt)


async def _invoke_chain(chain, inputs):
    """Run the blocking ``chain.invoke`` off the event loop; return its text."""
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, chain.invoke, inputs)
    return _chain_text(result)


async def execute_pipeline(request: ExecutionRequest) -> AsyncGenerator[str, None]:
    """Run the NL-to-SQL agent pipeline, yielding progress as event frames.

    Stages: a Schema Analyzer runs once, then up to
    ``request.max_iterations`` attempts of Query Generator -> Syntax
    Validator -> Semantic Verifier. The loop stops at the first attempt whose
    verifier response contains "YES". Every attempt sends the same inputs to
    the generator; validator feedback is not fed back into the prompt.

    Args:
        request: Schema, question, prompt templates, model name and
            iteration limit for this run.

    Yields:
        Strings produced by ``create_event``: ``agent_start``,
        ``agent_input``, ``agent_output``, ``iteration``, ``validation``,
        then one ``final_result`` (fields ``sql`` and ``verified``). If any
        step raises, a single ``error`` event is yielded instead and the
        stream ends.
    """
    try:
        llm = get_llm(request.model)

        # Agent 1: Schema Analyzer
        yield create_event("agent_start", agent="Schema Analyzer")
        yield create_event(
            "agent_input",
            content=f"Schema: {request.schema[:100]}... | Question: {request.question}",
        )
        schema_chain = _build_chain(
            llm, request.schema_prompt, ["schema", "question"]
        )
        # LLM calls are synchronous; running them in the executor keeps the
        # event loop free to serve other requests while the model works.
        relevant_schema = await _invoke_chain(
            schema_chain,
            {"schema": request.schema, "question": request.question},
        )
        yield create_event("agent_output", content=relevant_schema.strip())

        # The chains used inside the loop do not change between attempts, so
        # they are built once. A malformed template therefore fails here,
        # before the first attempt starts.
        query_chain = _build_chain(
            llm, request.query_prompt, ["question", "relevant_schema"]
        )
        verify_chain = _build_chain(
            llm, request.semantic_prompt, ["question", "sql_query"]
        )

        sql_query = None
        verified = False  # becomes True only when the verifier approves
        for attempt in range(1, request.max_iterations + 1):
            yield create_event("iteration", iteration=attempt)

            # Agent 2: Query Generator
            yield create_event("agent_start", agent="Query Generator")
            yield create_event(
                "agent_input",
                content=f"Relevant schema: {relevant_schema[:100]}...",
            )
            generated = await _invoke_chain(
                query_chain,
                {
                    "question": request.question,
                    "relevant_schema": relevant_schema,
                },
            )
            sql_query = generated.strip()
            yield create_event("agent_output", content=sql_query)

            # Agent 3: Syntax Validator
            yield create_event("agent_start", agent="Syntax Validator")
            is_valid, syntax_msg = validate_syntax(sql_query)
            yield create_event(
                "validation",
                content=syntax_msg,
                status="pass" if is_valid else "fail",
            )
            if not is_valid:
                continue  # skip the verifier and try generating again

            # Agent 4: Semantic Verifier
            yield create_event("agent_start", agent="Semantic Verifier")
            yield create_event(
                "agent_input",
                content=f"Checking if SQL answers: {request.question}",
            )
            verification = await _invoke_chain(
                verify_chain,
                {"question": request.question, "sql_query": sql_query},
            )
            yield create_event("agent_output", content=verification.strip())

            # NOTE(review): plain substring match - any response containing
            # the letters "yes" (in any case) counts as approval.
            if "YES" in verification.upper():
                verified = True
                yield create_event(
                    "validation",
                    content="Query is semantically correct",
                    status="pass",
                )
                break
            yield create_event(
                "validation",
                content="Query has semantic issues",
                status="fail",
            )

        # Final result: `sql` is the last generated query, as before.
        # `verified` lets clients tell an approved query from a last attempt
        # that failed validation or verification.
        yield create_event(
            "final_result",
            sql=sql_query if sql_query else "No valid SQL generated",
            verified=verified,
        )
    except Exception as e:
        # Stream boundary: report the failure to the client as an event.
        yield create_event("error", message=str(e))
@app.get("/")
async def root():
    """Serve the main web interface"""
    landing_page = static_dir / "index.html"
    # Until the front-end file is present, answer with a JSON hint instead.
    if not landing_page.exists():
        return {"message": "Place index.html in the static/ directory"}
    return FileResponse(landing_page)
@app.post("/execute")
async def execute(request: ExecutionRequest):
    """Execute the multi-agent NL to SQL pipeline with streaming logs"""
    # The pipeline is an async generator of already-formatted event frames.
    event_stream = execute_pipeline(request)
    return StreamingResponse(event_stream, media_type="text/event-stream")
@app.get("/health")
async def health():
    """Health check endpoint"""
    # Constant payload; nothing else is probed here.
    payload = {"status": "ok"}
    return payload
@app.get("/models")
async def list_models():
    """List available Ollama models"""
    # Hard-coded list of common models for now; a live listing would require
    # calling the ollama CLI or API.
    known_models = ["phi3", "llama3.2:3b", "gemma2:2b", "mistral"]
    return {"models": known_models}
if __name__ == "__main__":
    import uvicorn

    # HuggingFace Spaces uses port 7860; the PORT environment variable
    # overrides it when set.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)