Spaces:
Sleeping
Sleeping
File size: 3,099 Bytes
5f22dc5 f9fd577 5f22dc5 83e6c59 5f22dc5 3bdece0 5f22dc5 f9fd577 5f22dc5 83e6c59 411c555 5f22dc5 71dcc32 5f22dc5 83e6c59 71dcc32 5f22dc5 cdcf836 5f22dc5 53d3e55 cdcf836 53d3e55 5f22dc5 3bdece0 5f22dc5 71dcc32 5f22dc5 f9fd577 5f22dc5 f9fd577 5f22dc5 411c555 5f22dc5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 | import asyncio
from typing import Optional
from loguru import logger
from pydantic_ai import Agent, ModelHTTPError
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
from src.configs import DatabaseConfig
from src.schemas import SQLQueryExtractor, Message
from src.prompts import SQL_QUERY_EXTRACTOR_PROMPT
from sqlalchemy import text
import os
class PydanticAgent:
    """Text-to-SQL agent that asks an OpenAI model for a ``SQLQueryExtractor``.

    The model is given ``_verify_sql_query`` as a tool so it can check a
    candidate query against the database before answering. Model calls go
    through ``_run_with_backoff`` so HTTP 429 / 5xx responses are retried
    with exponential back-off.
    """

    # Statement keywords that must never appear in a generated query. They are
    # matched as whole words (case-insensitively), so identifiers such as
    # ``updated_at`` or ``dropoff_time`` are not rejected by accident.
    # REPLACE is intentionally absent: ``replace()`` is a common SQL string function.
    _FORBIDDEN_SQL_KEYWORDS = frozenset(
        {
            "DELETE",
            "DROP",
            "UPDATE",
            "TRUNCATE",
            "ALTER",
            "INSERT",
            "CREATE",
            "ATTACH",
            "DETACH",
            "PRAGMA",
            "GRANT",
            "REVOKE",
            "VACUUM",
            "REINDEX",
        }
    )

    def __init__(self):
        """Build the underlying pydantic_ai ``Agent``.

        The model name is read from the ``OPENAI_MODEL_NAME`` environment
        variable and defaults to ``"gpt-5.2"``.
        """
        self._system_prompt = SQL_QUERY_EXTRACTOR_PROMPT
        self._openai_model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-5.2")
        self._openai_model = OpenAIResponsesModel(
            model_name=self._openai_model_name,
        )
        self._agent = Agent(
            system_prompt=self._system_prompt,
            model=self._openai_model,
            output_type=SQLQueryExtractor,
            # Temperature 0 keeps SQL generation as deterministic as the model allows.
            model_settings=OpenAIResponsesModelSettings(temperature=0.0),
            tools=[self._verify_sql_query],
            # Agent-level retry budget (pydantic_ai's ``retries`` setting).
            retries=5,
        )

    async def __aenter__(self):
        """Return ``self`` so the agent can be used with ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Hold no resources, so there is nothing to release; exceptions propagate."""
        pass

    @classmethod
    def _find_forbidden_keywords(cls, sql_query: str) -> set[str]:
        """Return the forbidden keywords that occur as whole words in *sql_query*."""
        import re

        words = {word.upper() for word in re.findall(r"\w+", sql_query)}
        return words & cls._FORBIDDEN_SQL_KEYWORDS

    async def _verify_sql_query(self, sqlite_query: str) -> bool | str:
        """Check that a read-only SQL query is allowed and runs against the database.

        Call this with every candidate query before returning it.

        Args:
            sqlite_query: The SQL query to verify. Only SELECT queries are allowed.

        Returns:
            True if the query is allowed and executed without error; otherwise
            the error text followed by a request to generate the query again.
        """
        logger.info(f"Verifying SQL query: {sqlite_query}")
        sql_query = sqlite_query.strip()
        try:
            forbidden = self._find_forbidden_keywords(sql_query)
            if forbidden:
                raise ValueError(
                    f"SQL query contains a destructive operation "
                    f"({', '.join(sorted(forbidden))}): {sql_query}. "
                    "Only SELECT queries are allowed."
                )
            async with DatabaseConfig.async_session() as session:
                # Execute the query verbatim (not case-folded) so string
                # literals and quoted identifiers are validated as written.
                await session.execute(text(sql_query))
        except Exception as e:
            # Deliberately broad: any failure (forbidden keyword, syntax error,
            # unknown table/column, driver error) is sent back to the model as
            # the tool result so it can correct the query, instead of aborting.
            logger.error(e)
            return str(e) + "\nPlease generate SQL Query again"
        return True

    async def _run_with_backoff(self, agent, *args, retries=5, **kwargs):
        """Await ``agent.run(*args, **kwargs)``, retrying transient HTTP failures.

        Only ``ModelHTTPError`` with status 429 (rate limited) or 5xx is
        retried. The wait starts at 5 s, grows 5x per attempt and is capped at
        60 s. There is no wait after the final attempt.

        Raises:
            ModelHTTPError: immediately, for a non-retryable status (e.g. 400, 401).
            RuntimeError: when every attempt failed with a retryable error;
                chained from the last ``ModelHTTPError``.
        """
        delay = 5
        max_delay = 60
        last_error = None
        for attempt in range(1, retries + 1):
            logger.debug(f"Attempt {attempt}/{retries}")
            try:
                return await agent.run(*args, **kwargs)
            except ModelHTTPError as e:
                # Client errors (bad request, auth, ...) will not succeed on retry.
                if e.status_code != 429 and e.status_code < 500:
                    raise
                last_error = e
                if attempt == retries:
                    break
                logger.debug(
                    f"Model returned HTTP {e.status_code}, "
                    f"backing off for {delay} seconds..."
                )
            await asyncio.sleep(delay)
            delay = min(delay * 5, max_delay)
        raise RuntimeError("Exceeded retries due to rate limiting") from last_error

    async def run(
        self, user_input: str, message_history: Optional[list[Message]] = None
    ) -> SQLQueryExtractor:
        """Generate a verified SQL query for *user_input*.

        Args:
            user_input: The natural-language request.
            message_history: Prior conversation messages, or None to start fresh.

        Returns:
            The agent's structured ``SQLQueryExtractor`` output.
        """
        result = await self._run_with_backoff(
            self._agent, user_input, message_history=message_history
        )
        logger.debug(result.output)
        return result.output
|