File size: 2,792 Bytes
52adb86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c87f5a
 
52adb86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import logging
import os

from fastapi import FastAPI, HTTPException, BackgroundTasks
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.postgres import PostgresSaver
from pydantic import BaseModel, Field

from src.embedding import create_embeddings
from src.graph import workflow

# The ASGI application object. The title/description/version below populate
# the automatically generated OpenAPI schema and docs pages.
app = FastAPI(
    version="1.0.0",
    title="Text2SQL Agent API",
    description="A production-grade backend powering LangGraph agent.",
)

# Request body for POST /upload_url. Both fields are required.
class UploadRequest(BaseModel):
    # Connection URL of the database whose schema should be ingested.
    connection_url: str = Field(default=..., description="Database URL")
    # Tenant identifier forwarded to the ingestion pipeline.
    user_id: str = Field(default=..., description="The unique identifier for the tenant context.")

# Request body for POST /chat. All four fields are required.
class ChatRequest(BaseModel):
    # The user's natural-language question for the agent.
    message: str = Field(default=..., description="Input message by the user.")
    # Session key; requests sharing it share the agent's short-term memory.
    thread_id: str = Field(default=..., description="Unique session ID to maintain short term memory.")
    # Tenant identifier placed into the agent's initial state.
    user_id: str = Field(default=..., description="The unique identifier for the tenant context.")
    # Connection URL of the tenant database the agent should query.
    connection_url: str = Field(default=..., description="Database URL")

# POST /upload_url: queue ingestion of a tenant's database and return at once.
# create_embeddings(connection_url, user_id) is registered as a FastAPI
# background task, so it runs after the response has been sent. The
# "success" status therefore means "scheduled", not "finished".
# NOTE(review): "Recieve" in the summary is a typo, kept verbatim because the
# summary is part of the published OpenAPI schema.
@app.post("/upload_url", summary="Recieve database URL and invoke ingestion pipeline.")
def upload(request: UploadRequest, background_tasks: BackgroundTasks):
    url, tenant = request.connection_url, request.user_id
    background_tasks.add_task(create_embeddings, url, tenant)

    payload = dict(status="success", message="Ingestion Pipeline started !")
    return payload

# POST /chat: run one turn of the LangGraph Text2SQL agent for a user query.
#
# A Postgres-backed checkpointer is opened from the DATABASE_URI environment
# variable, ``workflow`` is compiled against it, and the agent is invoked with
# the user's message under config["configurable"]["thread_id"]. Per the
# ChatRequest.thread_id description, this is what keeps short-term memory
# per session.
#
# Returns {"status": "success", "thread_id": ..., "response": ...}, where
# "response" is the agent state's "final_result" (None if the graph never
# set it).
# Raises HTTPException(500) carrying the error text if the agent invocation
# fails.
@app.post("/chat" , summary="Return the response generated by the agent for the given user query.")
def chat_endpoint(request : ChatRequest) :
    logger = logging.getLogger(__name__)

    db_uri = os.getenv("DATABASE_URI")
    # Log only whether the URI is present. The URI normally embeds the
    # database password, so printing it (as this code used to) leaked
    # credentials into the logs.
    # NOTE(review): when unset, None is still passed to PostgresSaver as before.
    logger.debug("DATABASE_URI is %s", "set" if db_uri else "NOT set")

    with PostgresSaver.from_conn_string(db_uri) as checkpointer:
        # Presumably creates the checkpoint tables if missing (per LangGraph
        # docs). NOTE(review): this runs on every request; consider doing it
        # once at application startup instead.
        checkpointer.setup()

        agent = workflow.compile(checkpointer=checkpointer)
        config = {"configurable": {"thread_id": request.thread_id}}

        initial_state = {
            "connection_url": request.connection_url,
            "user_id": request.user_id,
            "messages": [HumanMessage(content=request.message)],
            "retry": 0,
        }

        # Only the agent call itself is guarded. Failures become a 500 for the
        # client, and the full traceback goes to the server log.
        try:
            result = agent.invoke(initial_state, config=config)
        except Exception as e:
            logger.exception("Agent invocation failed (thread_id=%s)", request.thread_id)
            # NOTE(review): str(e) is returned to the client verbatim and may
            # expose internal details (SQL text, hostnames); consider a
            # generic message.
            raise HTTPException(status_code=500, detail=f"Error : {str(e)}") from e

        # Diagnostics only. .get() means a run that did not populate one of
        # these keys no longer turns a successful answer into a 500 (the old
        # result['...'] lookups raised KeyError inside the try). They are
        # logged at DEBUG because query results may contain tenant data.
        logger.debug("Scheme : %s", result.get("scheme"))
        logger.debug("Sql Query : %s", result.get("sql_query"))
        logger.debug("Query Result : %s", result.get("query_result"))

        return {
            "status": "success",
            "thread_id": request.thread_id,
            "response": result.get("final_result"),
        }