import asyncio
import os
import re
from typing import Optional

from loguru import logger
from pydantic_ai import Agent, ModelHTTPError
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
from sqlalchemy import text

from src.configs import DatabaseConfig
from src.prompts import SQL_QUERY_EXTRACTOR_PROMPT
from src.schemas import SQLQueryExtractor, Message

# Statements that can modify data, schema, or database state. Matched as whole
# words, case-insensitively, so identifiers such as ``updated_at``,
# ``is_deleted``, ``alternate_name`` or ``dropdown`` are NOT rejected (a plain
# substring test would reject all of them).
_FORBIDDEN_SQL_PATTERN = re.compile(
    r"\b(?:delete|drop|update|truncate|alter|insert|create|attach|detach"
    r"|pragma|vacuum|reindex|replace\s+into)\b",
    re.IGNORECASE,
)

# HTTP statuses that indicate a transient failure worth retrying
# (request timeout, rate limiting); any 5xx is also treated as transient.
_RETRYABLE_STATUS_CODES = frozenset({408, 429})


def _is_retryable_status(status_code: int) -> bool:
    """Return True if an HTTP status from the model provider is worth retrying."""
    return status_code in _RETRYABLE_STATUS_CODES or status_code >= 500


class PydanticAgent:
    """Agent that turns a user request into a SQL query via an OpenAI model.

    Wraps a pydantic-ai ``Agent`` backed by ``OpenAIResponsesModel`` whose
    structured output type is ``SQLQueryExtractor``. The model is given
    ``_verify_sql_query`` as a tool so it can check candidate queries against
    the database before answering.
    """

    def __init__(self):
        self._system_prompt = SQL_QUERY_EXTRACTOR_PROMPT
        # Model name can be overridden through the environment.
        self._openai_model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-5.2")
        self._openai_model = OpenAIResponsesModel(
            model_name=self._openai_model_name,
        )
        self._agent = Agent(
            system_prompt=self._system_prompt,
            model=self._openai_model,
            output_type=SQLQueryExtractor,
            # Temperature 0 to minimise sampling variation in generated SQL.
            model_settings=OpenAIResponsesModelSettings(temperature=0.0),
            tools=[self._verify_sql_query],
            retries=5,
        )

    async def __aenter__(self):
        # Nothing to acquire; supports ``async with PydanticAgent() as agent``.
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Nothing to release; exceptions are not suppressed (returns None).
        pass

    async def _verify_sql_query(self, sqlite_query: str) -> bool | str:
        """Check that a SQL query is read-only and executes without errors.

        Args:
            sqlite_query: The SQL query to verify.

        Returns:
            True if the query is valid, otherwise an error message explaining
            why it was rejected.
        """
        # NOTE: pydantic-ai sends this method's docstring to the model as the
        # tool description, so keep it short and model-facing.
        logger.info(f"Verifying SQL query: {sqlite_query}")
        # Do NOT lowercase: that would alter quoted identifiers and string
        # literals and change the query that actually gets executed.
        sql_query = sqlite_query.strip()

        if not sql_query:
            logger.error("Empty SQL query")
            return "SQL query is empty.\nPlease generate SQL Query again"

        if _FORBIDDEN_SQL_PATTERN.search(sql_query):
            message = (
                f"SQL query contains a destructive operation: {sql_query}. "
                "Only SELECT queries are allowed."
            )
            logger.error(message)
            return message + "\nPlease generate SQL Query again"

        try:
            # The session is never committed. NOTE(review): this relies on
            # closing the session rolling back anything the query did; confirm
            # against DatabaseConfig.async_session.
            async with DatabaseConfig.async_session() as session:
                await session.execute(text(sql_query))
        except Exception as e:
            # Deliberately broad: this is a tool boundary, and any database
            # error is reported back to the model so it can fix its query.
            logger.error(e)
            return str(e) + "\nPlease generate SQL Query again"
        return True

    async def _run_with_backoff(self, agent, *args, retries=5, **kwargs):
        """Call ``agent.run(*args, **kwargs)``, retrying transient HTTP errors.

        Rate limiting (429), timeouts (408) and server errors (5xx) are retried
        up to ``retries`` attempts in total, sleeping 5s first and multiplying
        the delay by 5 each time, capped at 60s. No sleep follows the final
        attempt.

        Raises:
            RuntimeError: on a non-retryable HTTP error (raised immediately),
                or when every attempt failed. The underlying
                ``ModelHTTPError`` is chained as ``__cause__``.
        """
        delay = 5
        max_delay = 60
        last_error: ModelHTTPError | None = None
        for attempt in range(1, retries + 1):
            try:
                logger.debug(f"Attempt {attempt}/{retries}")
                return await agent.run(*args, **kwargs)
            except ModelHTTPError as e:
                last_error = e
                if not _is_retryable_status(e.status_code):
                    # Retrying e.g. 400/401/404 cannot succeed; fail fast.
                    raise RuntimeError(
                        f"Model request failed with non-retryable HTTP status {e.status_code}"
                    ) from e
                if attempt == retries:
                    break
                logger.debug(
                    f"HTTP {e.status_code} from model, backing off for {delay} seconds..."
                )
                await asyncio.sleep(delay)
                delay = min(delay * 5, max_delay)
        raise RuntimeError(
            "Exceeded retries due to rate limiting or transient model errors"
        ) from last_error

    async def run(
        self, user_input: str, message_history: Optional[list[Message]] = None
    ) -> SQLQueryExtractor:
        """Generate a SQL query for ``user_input``.

        Args:
            user_input: The user's natural-language request.
            message_history: Optional prior conversation passed through to the
                agent. Defaults to None rather than a shared mutable list.

        Returns:
            The agent's structured ``SQLQueryExtractor`` output.

        Raises:
            RuntimeError: propagated from ``_run_with_backoff`` when the model
                request fails.
        """
        result = await self._run_with_backoff(
            self._agent, user_input, message_history=message_history
        )
        logger.debug(result.output)
        return result.output