# NL to SQL Multi-Agent System — FastAPI app deployed as a HuggingFace Space.
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
import sqlparse
import json
import asyncio
from typing import AsyncGenerator
from pathlib import Path
import os

app = FastAPI(title="NL to SQL Multi-Agent System")

# Enable CORS so a separately hosted web interface can call this API.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# fully open — acceptable for a demo Space, tighten before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the bundled web UI assets from ./static next to this file;
# the directory is created at import time if it does not exist yet.
static_dir = Path(__file__).parent / "static"
static_dir.mkdir(exist_ok=True)
app.mount("/static", StaticFiles(directory=static_dir), name="static")
| # Request model | |
class ExecutionRequest(BaseModel):
    """Payload for one pipeline run: the database schema, the four agent
    prompt templates, the user's question, and generation settings."""

    # Full database schema text handed to the Schema Analyzer agent.
    # NOTE(review): a field named `schema` shadows pydantic's
    # BaseModel.schema attribute — behavior differs across pydantic
    # versions (warning vs. error); confirm with the pinned version.
    schema: str
    # Template for the Schema Analyzer; must contain {schema} and {question}.
    schema_prompt: str
    # Template for the Query Generator; must contain {question} and
    # {relevant_schema}.
    query_prompt: str
    # Currently unused: syntax validation is deterministic via sqlparse
    # (see validate_syntax), so no LLM prompt is consumed for it.
    syntax_prompt: str
    # Template for the Semantic Verifier; must contain {question} and
    # {sql_query}, and should instruct the model to answer YES/NO.
    semantic_prompt: str
    # Natural-language question to translate into SQL.
    question: str
    # Ollama model name to use for all LLM agents.
    model: str = "phi3"
    # Max generate/validate/verify retries before giving up.
    max_iterations: int = 3
| # Initialize LLM | |
def get_llm(model_name: str, temperature: float = 0.1):
    """Create an Ollama LLM client for the given model.

    Args:
        model_name: Name of a locally available Ollama model (e.g. "phi3").
        temperature: Sampling temperature. Defaults to 0.1 (the previously
            hard-coded value) so SQL generation stays near-deterministic;
            exposed as a parameter so callers can tune creativity.

    Returns:
        A configured ``Ollama`` LLM instance.
    """
    return Ollama(model=model_name, temperature=temperature)
| # Syntax validator (deterministic) | |
def validate_syntax(sql_query: str):
    """Deterministically check that *sql_query* is non-empty, parseable SQL.

    Note that sqlparse is a non-validating parser: it rarely rejects input,
    so this catches only gross problems (empty input, parser failures), not
    real SQL grammar errors.

    Args:
        sql_query: Candidate SQL text from the Query Generator agent.

    Returns:
        Tuple ``(is_valid, message)`` where ``message`` explains the result.
    """
    # Guard first: sqlparse.parse("") returns an empty tuple, but a
    # whitespace-only string yields a non-empty parse and would previously
    # be reported as valid SQL.
    if not sql_query or not sql_query.strip():
        return False, "Empty or invalid SQL"
    try:
        parsed = sqlparse.parse(sql_query)
        if not parsed:
            return False, "Empty or invalid SQL"
        return True, "Syntax valid"
    except Exception as e:
        return False, f"Syntax error: {str(e)}"
| # Stream event helper | |
def create_event(event_type: str, **kwargs):
    """Format one Server-Sent-Events message.

    Args:
        event_type: Value stored under the ``"type"`` key of the payload.
        **kwargs: Additional key/value pairs merged into the payload.

    Returns:
        An SSE-framed string: ``data: <json>\\n\\n``.
    """
    payload = {"type": event_type}
    payload.update(kwargs)
    return "data: " + json.dumps(payload) + "\n\n"
| # Main agent pipeline | |
async def execute_pipeline(request: ExecutionRequest) -> AsyncGenerator[str, None]:
    """Run the multi-agent NL->SQL pipeline, yielding SSE-formatted events.

    Agents, in order:
      1. Schema Analyzer   - LLM pass that extracts the relevant schema subset.
      2. Query Generator   - LLM pass that drafts SQL (retried per iteration).
      3. Syntax Validator  - deterministic sqlparse check (validate_syntax).
      4. Semantic Verifier - LLM pass expected to answer YES/NO.

    Yields:
        SSE strings produced by create_event: agent_start / agent_input /
        agent_output / iteration / validation events, then one final_result
        (or an error event if anything raises).
    """
    try:
        # One LLM client is shared by all three LLM-backed agents.
        llm = get_llm(request.model)
        yield create_event("agent_start", agent="Schema Analyzer")
        yield create_event("agent_input", content=f"Schema: {request.schema[:100]}... | Question: {request.question}")
        # Agent 1: Schema Analyzer — narrow the full schema down to the
        # tables/columns relevant to the question.
        schema_prompt = PromptTemplate(
            input_variables=["schema", "question"],
            template=request.schema_prompt
        )
        schema_chain = LLMChain(llm=llm, prompt=schema_prompt)
        relevant_schema_result = schema_chain.invoke({
            "schema": request.schema,
            "question": request.question
        })
        # LLMChain.invoke typically returns {"text": ...}; fall back to the
        # raw value for wrappers that return a plain string.
        relevant_schema = relevant_schema_result.get('text', relevant_schema_result) if isinstance(relevant_schema_result, dict) else relevant_schema_result
        yield create_event("agent_output", content=relevant_schema.strip())
        # Generate -> validate -> verify loop, retried up to max_iterations.
        sql_query = None
        for iteration in range(request.max_iterations):
            yield create_event("iteration", iteration=iteration + 1)
            # Agent 2: Query Generator
            yield create_event("agent_start", agent="Query Generator")
            yield create_event("agent_input", content=f"Relevant schema: {relevant_schema[:100]}...")
            query_prompt = PromptTemplate(
                input_variables=["question", "relevant_schema"],
                template=request.query_prompt
            )
            sql_chain = LLMChain(llm=llm, prompt=query_prompt)
            sql_result = sql_chain.invoke({
                "question": request.question,
                "relevant_schema": relevant_schema
            })
            sql_query = sql_result.get('text', sql_result) if isinstance(sql_result, dict) else sql_result
            sql_query = sql_query.strip()
            yield create_event("agent_output", content=sql_query)
            # Agent 3: Syntax Validator (deterministic, no LLM call).
            yield create_event("agent_start", agent="Syntax Validator")
            is_valid, syntax_msg = validate_syntax(sql_query)
            if is_valid:
                yield create_event("validation", content=syntax_msg, status="pass")
            else:
                yield create_event("validation", content=syntax_msg, status="fail")
                # Skip the semantic check and regenerate on the next pass.
                continue
            # Agent 4: Semantic Verifier — asks the LLM whether the SQL
            # answers the question; reply is scanned for "YES" below.
            yield create_event("agent_start", agent="Semantic Verifier")
            yield create_event("agent_input", content=f"Checking if SQL answers: {request.question}")
            verify_prompt = PromptTemplate(
                input_variables=["question", "sql_query"],
                template=request.semantic_prompt
            )
            verify_chain = LLMChain(llm=llm, prompt=verify_prompt)
            verification_result = verify_chain.invoke({
                "question": request.question,
                "sql_query": sql_query
            })
            verification = verification_result.get('text', verification_result) if isinstance(verification_result, dict) else verification_result
            yield create_event("agent_output", content=verification.strip())
            if "YES" in verification.upper():
                yield create_event("validation", content="Query is semantically correct", status="pass")
                break
            else:
                yield create_event("validation", content="Query has semantic issues", status="fail")
        # NOTE(review): if the loop exhausts without a semantic pass, the
        # last generated (possibly unverified) sql_query is still emitted.
        yield create_event("final_result", sql=sql_query if sql_query else "No valid SQL generated")
    except Exception as e:
        # Report failures as an SSE event so the stream terminates cleanly
        # instead of surfacing a 500 mid-stream.
        yield create_event("error", message=str(e))
# NOTE(review): no route decorator was visible in the pasted source, so this
# handler was never registered; "/" restored from the docstring — confirm
# against the original file.
@app.get("/")
async def root():
    """Serve the main web interface (static/index.html), or a hint if the
    file has not been provided yet."""
    index_file = static_dir / "index.html"
    if index_file.exists():
        return FileResponse(index_file)
    return {"message": "Place index.html in the static/ directory"}
# NOTE(review): route decorator missing in the pasted source; a POST route is
# restored here (the handler takes a request body) — confirm the exact path.
@app.post("/execute")
async def execute(request: ExecutionRequest):
    """Execute the multi-agent NL to SQL pipeline with streaming logs.

    Returns a text/event-stream response whose events are produced by
    execute_pipeline (agent progress, validations, final_result/error).
    """
    return StreamingResponse(
        execute_pipeline(request),
        media_type="text/event-stream"
    )
# NOTE(review): route decorator missing in the pasted source; restored with
# the conventional health-check path — confirm against the original file.
@app.get("/health")
async def health():
    """Health check endpoint; returns a static OK payload."""
    return {"status": "ok"}
# NOTE(review): route decorator missing in the pasted source; "/models" is
# inferred from the function's purpose — confirm against the original file.
@app.get("/models")
async def list_models():
    """List available Ollama models.

    TODO: query the local Ollama daemon for installed models instead of
    returning this hard-coded list of common ones.
    """
    return {
        "models": ["phi3", "llama3.2:3b", "gemma2:2b", "mistral"]
    }
if __name__ == "__main__":
    import uvicorn

    # HuggingFace Spaces exposes apps on port 7860; honor $PORT if set.
    serve_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)