"""FastAPI entry point for the Text2SQL LangGraph agent.

Endpoints:

* ``POST /upload_url`` -- schedules ``create_embeddings`` (the ingestion
  pipeline) for a tenant's database as a FastAPI background task.
* ``POST /chat`` -- runs the compiled LangGraph ``workflow`` for one user
  message, persisting conversation state in Postgres through
  ``PostgresSaver`` (connection string taken from ``DATABASE_URI``).
"""

import logging
import os

from fastapi import BackgroundTasks, FastAPI, HTTPException
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.postgres import PostgresSaver
from pydantic import BaseModel, Field

from src.embedding import create_embeddings
from src.graph import workflow

logger = logging.getLogger(__name__)

app = FastAPI(
    title="Text2SQL Agent API",
    description="A production-grade backend powering LangGraph agent.",
    version="1.0.0",
)


class UploadRequest(BaseModel):
    """Request body for ``POST /upload_url``."""

    connection_url: str = Field(..., description="Database URL")
    user_id: str = Field(..., description="The unique identifier for the tenant context.")


class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    message: str = Field(..., description="Input message by the user.")
    thread_id: str = Field(..., description="Unique session ID to maintain short term memory.")
    user_id: str = Field(..., description="The unique identifier for the tenant context.")
    connection_url: str = Field(..., description="Database URL")


@app.post("/upload_url", summary="Receive database URL and invoke ingestion pipeline.")
def upload(request: UploadRequest, background_tasks: BackgroundTasks):
    """Queue ``create_embeddings(connection_url, user_id)`` as a background task.

    The task runs after the response is sent, so ``"success"`` here only means
    the job was scheduled, not that ingestion finished.
    """
    background_tasks.add_task(create_embeddings, request.connection_url, request.user_id)
    return {
        "status": "success",
        "message": "Ingestion Pipeline started !",
    }


def _get_checkpoint_db_uri() -> str:
    """Return the checkpoint-store URI from ``DATABASE_URI``.

    Raises:
        HTTPException: 500 when the variable is unset or empty, instead of
            letting ``PostgresSaver`` fail on ``None`` with an opaque error.
    """
    db_uri = os.getenv("DATABASE_URI")
    if not db_uri:
        # Never log the URI itself: it typically embeds DB credentials.
        logger.error("DATABASE_URI environment variable is not set")
        raise HTTPException(status_code=500, detail="Error : DATABASE_URI is not configured")
    return db_uri


@app.post("/chat", summary="Return the response generated by the agent for the given user query.")
def chat_endpoint(request: ChatRequest):
    """Run the agent workflow on ``request.message`` and return its final answer.

    Conversation state is checkpointed in Postgres under ``request.thread_id``.
    The response carries the graph state's ``final_result`` value (``None`` if
    the graph did not set it).

    Raises:
        HTTPException: 500 if ``DATABASE_URI`` is missing or the agent
            invocation raises.
    """
    db_uri = _get_checkpoint_db_uri()

    with PostgresSaver.from_conn_string(db_uri) as checkpointer:
        # Runs the checkpointer's table setup on every request.
        # NOTE(review): consider moving this to an application startup hook.
        checkpointer.setup()
        agent = workflow.compile(checkpointer=checkpointer)

        # NOTE(review): memory is keyed by thread_id only, not by user_id, so a
        # caller who knows another tenant's thread_id can resume that thread.
        # Namespacing (e.g. f"{user_id}:{thread_id}") would orphan existing
        # checkpoints -- confirm before changing.
        config = {"configurable": {"thread_id": request.thread_id}}
        initial_state = {
            "connection_url": request.connection_url,
            "user_id": request.user_id,
            "messages": [HumanMessage(content=request.message)],
            "retry": 0,
        }

        try:
            result = agent.invoke(initial_state, config=config)
        except Exception as e:
            logger.exception("Agent invocation failed for thread_id=%s", request.thread_id)
            raise HTTPException(status_code=500, detail=f"Error : {e}") from e

    # Diagnostics use .get(): a graph path that never set one of these keys must
    # not turn a successful answer into a KeyError/500. DEBUG level keeps tenant
    # query results out of default logs.
    logger.debug("Scheme : %s", result.get("scheme"))
    logger.debug("Sql Query : %s", result.get("sql_query"))
    logger.debug("Query Result : %s", result.get("query_result"))

    return {
        "status": "success",
        "thread_id": request.thread_id,
        "response": result.get("final_result"),
    }