# text2sql_backend/src/main.py
# Committed by: LightRT
# Commit: 52adb86 ("Project Completion Commit")
# Size: 2.74 kB
import logging
import os

from fastapi import FastAPI, HTTPException, BackgroundTasks
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.postgres import PostgresSaver
from pydantic import BaseModel, Field

from src.embedding import create_embeddings
from src.graph import workflow
# The FastAPI application that all endpoints below are registered on.
# Title, description and version are published in the generated OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="Text2SQL Agent API",
    description="A production-grade backend powering LangGraph agent.",
)
class UploadRequest(BaseModel):
    # Request body for POST /upload_url. Both fields are required; they are
    # forwarded unchanged to ``create_embeddings`` by the upload endpoint.
    # (Comments rather than a docstring: Pydantic would publish a class
    # docstring as the model description in the JSON schema.)

    # URL of the database to ingest.
    connection_url: str = Field(
        default=...,
        description="Database URL",
    )
    # Tenant identifier the ingestion is performed for.
    user_id: str = Field(
        default=...,
        description="The unique identifier for the tenant context.",
    )
class ChatRequest(BaseModel):
    # Request body for POST /chat: one user turn plus the context the agent
    # needs. All four fields are required. (Comments rather than a docstring,
    # so the generated JSON schema stays unchanged.)

    # The user's input; wrapped in a HumanMessage for the agent.
    message: str = Field(
        default=...,
        description="Input message by the user.",
    )
    # Session id; passed to LangGraph as ``configurable.thread_id`` so the
    # checkpointer can restore earlier turns of the same conversation.
    thread_id: str = Field(
        default=...,
        description="Unique session ID to maintain short term memory.",
    )
    # Tenant identifier, placed in the agent's initial state.
    user_id: str = Field(
        default=...,
        description="The unique identifier for the tenant context.",
    )
    # Database URL, placed in the agent's initial state (presumably the
    # database the generated SQL runs against -- confirm in src.graph).
    connection_url: str = Field(
        default=...,
        description="Database URL",
    )
@app.post("/upload_url", summary="Receive database URL and invoke ingestion pipeline.")
def upload(request: UploadRequest, background_tasks: BackgroundTasks):
    """Queue ingestion of the given database for the given tenant.

    The ingestion pipeline runs in the background after this response is
    sent. A ``success`` status therefore means the job was scheduled, not
    that ingestion has completed. Errors raised by the pipeline are not
    reported back through this response.
    """
    # BackgroundTasks defers the call until after the response is returned,
    # so a long ingestion does not block the client.
    background_tasks.add_task(
        create_embeddings, request.connection_url, request.user_id
    )
    return {
        "status": "success",
        "message": "Ingestion Pipeline started !",
    }
_logger = logging.getLogger(__name__)


def _log_agent_trace(result: dict) -> None:
    """Log the agent's intermediate outputs (schema, SQL, query result) at DEBUG."""
    # ``.get`` rather than ``[...]``: a run that ends early may not populate
    # every key, and a missing debug field must not fail an otherwise
    # successful request.
    _logger.debug(
        "Scheme: %s | Sql Query: %s | Query Result: %s",
        result.get("scheme"),
        result.get("sql_query"),
        result.get("query_result"),
    )


@app.post("/chat", summary="Return the response generated by the agent for the given user query.")
def chat_endpoint(request: ChatRequest):
    """Send the user's message to the Text2SQL agent and return its final answer.

    Conversation state is checkpointed per ``thread_id``, so repeated calls
    with the same ``thread_id`` continue the same session.
    """
    db_uri = os.getenv("DATABASE_URI")
    if not db_uri:
        # Fail with a clear message instead of an opaque error from
        # PostgresSaver when given a ``None`` connection string.
        raise HTTPException(status_code=500, detail="Error : DATABASE_URI is not configured")

    with PostgresSaver.from_conn_string(db_uri) as checkpointer:
        # NOTE(review): setup() (and compile) run on every request; setup is
        # intended as one-time table creation -- consider moving both to
        # application startup.
        checkpointer.setup()
        agent = workflow.compile(checkpointer=checkpointer)

        # thread_id keys the checkpoint, giving the session its memory.
        config = {"configurable": {"thread_id": request.thread_id}}
        initial_state = {
            "connection_url": request.connection_url,
            "user_id": request.user_id,
            "messages": [HumanMessage(content=request.message)],
            "retry": 0,
        }

        try:
            result = agent.invoke(initial_state, config=config)
        except Exception as e:
            # Keep the full traceback server-side and chain the cause.
            _logger.exception("Agent invocation failed (thread_id=%s)", request.thread_id)
            # NOTE(review): echoing str(e) to clients can expose internal
            # details (SQL text, hosts, credentials embedded in URLs);
            # consider a generic message.
            raise HTTPException(status_code=500, detail=f"Error : {str(e)}") from e

        _log_agent_trace(result)
        return {
            "status": "success",
            "thread_id": request.thread_id,
            "response": result.get("final_result"),
        }