File size: 3,099 Bytes
5f22dc5
f9fd577
5f22dc5
 
 
 
 
83e6c59
 
5f22dc5
3bdece0
 
5f22dc5
 
 
 
 
f9fd577
5f22dc5
83e6c59
411c555
5f22dc5
71dcc32
5f22dc5
 
 
 
83e6c59
71dcc32
5f22dc5
 
 
 
 
 
 
 
 
 
cdcf836
 
5f22dc5
53d3e55
 
 
 
 
 
 
 
cdcf836
53d3e55
 
 
 
 
 
 
5f22dc5
3bdece0
5f22dc5
71dcc32
 
5f22dc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9fd577
 
 
5f22dc5
f9fd577
5f22dc5
411c555
5f22dc5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import asyncio
import os
import re
from typing import Optional

from loguru import logger
from pydantic_ai import Agent, ModelHTTPError
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
from sqlalchemy import text

from src.configs import DatabaseConfig
from src.prompts import SQL_QUERY_EXTRACTOR_PROMPT
from src.schemas import SQLQueryExtractor, Message


class PydanticAgent:
    """Wrapper around a ``pydantic_ai.Agent`` that produces ``SQLQueryExtractor`` output.

    The agent uses ``SQL_QUERY_EXTRACTOR_PROMPT`` as its system prompt and is
    given ``_verify_sql_query`` as a tool, which checks a candidate query with a
    keyword filter and by executing it against ``DatabaseConfig``'s session.
    """

    # Statement keywords that indicate a write/DDL operation. They are matched as
    # whole words (case-insensitive), so identifiers such as ``last_updated``,
    # ``created_at`` or ``dropoff_point`` are not mistaken for statements.
    # NOTE(review): keyword filtering is a heuristic, not a security boundary;
    # a read-only database connection/user is the reliable safeguard.
    _FORBIDDEN_SQL_KEYWORDS = re.compile(
        r"\b(?:DELETE|DROP|UPDATE|TRUNCATE|ALTER|INSERT|CREATE|ATTACH|DETACH)\b",
        re.IGNORECASE,
    )

    def __init__(
        self
    ):
        """Build the agent.

        The model name is read from the ``OPENAI_MODEL_NAME`` environment
        variable, defaulting to ``"gpt-5.2"``. Sampling temperature is fixed at
        0.0 and the agent is created with ``retries=5``.
        """
        self._system_prompt = SQL_QUERY_EXTRACTOR_PROMPT
        self._openai_model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-5.2")
        self._openai_model = OpenAIResponsesModel(
            model_name=self._openai_model_name,
        )
        self._agent = Agent(
            system_prompt=self._system_prompt,
            model=self._openai_model,
            output_type=SQLQueryExtractor,
            model_settings=OpenAIResponsesModelSettings(temperature=0.0),
            tools=[self._verify_sql_query],
            retries=5,
        )

    async def __aenter__(self):
        """Return ``self``; there are no resources to acquire."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Release nothing; returning ``None`` lets any exception propagate."""
        pass

    async def _verify_sql_query(self, sqlite_query: str) -> bool | str:
        """Tool for the model: check that a generated SQL query is read-only and runs.

        Args:
            sqlite_query: SQL text proposed by the model.

        Returns:
            ``True`` if the query passes the forbidden-keyword filter and executes
            without error. Otherwise, an error description ending with
            ``"\\nPlease generate SQL Query again"``, which is returned to the
            model as the tool result.
        """
        logger.info(f"Verifying SQL query: {sqlite_query}")
        # Keep the original casing: lowercasing would also rewrite string
        # literals and change what the query actually matches.
        sql_query = sqlite_query.strip()

        if not sql_query:
            error = "SQL query is empty."
            logger.error(error)
            return error + "\nPlease generate SQL Query again"

        if self._FORBIDDEN_SQL_KEYWORDS.search(sql_query):
            error = (
                f"SQL query contains a destructive operation: {sql_query}. Only SELECT queries are allowed."
            )
            logger.error(error)
            return error + "\nPlease generate SQL Query again"

        try:
            # NOTE(review): this runs the full query and discards the rows; for
            # expensive queries an EXPLAIN-based check may be cheaper — confirm
            # with the target database dialect.
            async with DatabaseConfig.async_session() as session:
                await session.execute(text(sql_query))
        except Exception as e:
            # Deliberately broad: any database error is reported back to the
            # model so it can correct the query, rather than aborting the run.
            logger.error(e)
            return str(e) + "\nPlease generate SQL Query again"
        return True

    async def _run_with_backoff(self, agent, *args, retries=5, **kwargs):
        """Call ``agent.run(*args, **kwargs)``, retrying transient HTTP errors.

        Only HTTP 429 (rate limit) and 5xx responses are retried. The delay
        starts at 5 s and grows fivefold per retry, capped at 60 s. There is no
        sleep after the final attempt.

        Args:
            agent: Object exposing an awaitable ``run`` method.
            *args: Positional arguments forwarded to ``agent.run``.
            retries: Maximum number of attempts.
            **kwargs: Keyword arguments forwarded to ``agent.run``.

        Returns:
            Whatever ``agent.run`` returns on the first successful attempt.

        Raises:
            RuntimeError: Raised when all attempts fail with retryable errors,
                or immediately on a non-retryable HTTP status. It is chained
                from the underlying ``ModelHTTPError``.
        """
        delay = 5
        max_delay = 60
        last_error = None

        for attempt in range(retries):
            try:
                logger.debug(f"Attempt {attempt + 1}/{retries}")
                return await agent.run(*args, **kwargs)
            except ModelHTTPError as e:
                status = e.status_code
                if status != 429 and status < 500:
                    # Client errors (bad request, auth, not found, ...) will
                    # not succeed on retry; fail fast instead of sleeping.
                    raise RuntimeError(
                        f"Model request failed with non-retryable HTTP status {status}"
                    ) from e
                last_error = e
                if attempt + 1 >= retries:
                    break
                logger.debug(f"Transient model error (HTTP {status}), backing off...")
                logger.debug(f"Backing off for {delay} seconds...")
                await asyncio.sleep(delay)
                delay = min(delay * 5, max_delay)

        raise RuntimeError("Exceeded retries due to rate limiting") from last_error

    async def run(
        self, user_input: str, message_history: Optional[list[Message]] = None
    ) -> SQLQueryExtractor:
        """Run the agent on ``user_input`` and return its structured output.

        Args:
            user_input: The user's natural-language request.
            message_history: Prior conversation messages, or ``None`` for none.
                ``None`` replaces the previous shared ``[]`` default.

        Returns:
            The agent's ``SQLQueryExtractor`` output.

        Raises:
            RuntimeError: Propagated from ``_run_with_backoff`` when the model
                request ultimately fails.
        """
        result = await self._run_with_backoff(
            self._agent, user_input, message_history=message_history
        )
        logger.debug(result.output)
        return result.output