Spaces:
Sleeping
Sleeping
import asyncio
import os
import re
from typing import Optional

from loguru import logger
from pydantic_ai import Agent, ModelHTTPError
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
from sqlalchemy import text

from src.configs import DatabaseConfig
from src.prompts import SQL_QUERY_EXTRACTOR_PROMPT
from src.schemas import SQLQueryExtractor, Message
class PydanticAgent:
    """Wrap a pydantic-ai ``Agent`` that turns user input into a ``SQLQueryExtractor``.

    The agent uses ``SQL_QUERY_EXTRACTOR_PROMPT`` as its system prompt, an
    OpenAI Responses model (name from ``OPENAI_MODEL_NAME``), and exposes
    ``_verify_sql_query`` as a tool so candidate queries can be checked
    against the database before the agent answers.
    """

    # A query must begin with SELECT or WITH (optionally after opening
    # parentheses); anything else is rejected before touching the database.
    _READ_ONLY_PREFIX = re.compile(r"[\s(]*(?:select|with)\b", re.IGNORECASE)

    # Keywords of statements that write data/schema or change connection
    # state. Matched as whole words so identifiers like ``last_updated`` or
    # ``inserted_at`` are not rejected. (``replace`` is deliberately absent:
    # it is also a common scalar function; ``REPLACE INTO`` is already
    # rejected by the SELECT/WITH prefix check.)
    _FORBIDDEN_SQL_KEYWORDS = re.compile(
        r"\b(?:delete|drop|update|truncate|alter|insert|create"
        r"|attach|detach|pragma|vacuum|reindex)\b",
        re.IGNORECASE,
    )

    # HTTP statuses worth retrying besides 5xx (checked separately):
    # request timeout and rate limiting.
    _RETRYABLE_HTTP_STATUSES = frozenset({408, 429})

    def __init__(self):
        """Build the OpenAI model and the pydantic-ai agent."""
        self._system_prompt = SQL_QUERY_EXTRACTOR_PROMPT
        self._openai_model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-5.2")
        self._openai_model = OpenAIResponsesModel(
            model_name=self._openai_model_name,
        )
        self._agent = Agent(
            system_prompt=self._system_prompt,
            model=self._openai_model,
            output_type=SQLQueryExtractor,
            model_settings=OpenAIResponsesModelSettings(temperature=0.0),
            tools=[self._verify_sql_query],
            # Agent-level retries (tool/output validation), independent of
            # the HTTP backoff in _run_with_backoff.
            retries=5,
        )

    async def __aenter__(self):
        """Return this instance; no resources are acquired."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Do nothing; returning None lets any exception propagate."""
        pass

    @staticmethod
    def _rejection_message(reason: str) -> str:
        """Log *reason* and turn it into the tool's retry instruction for the model."""
        logger.error(reason)
        return reason + "\nPlease generate SQL Query again"

    async def _verify_sql_query(self, sqlite_query: str) -> bool | str:
        """Check that a SQLite query is a read-only SELECT and executes without error.

        Args:
            sqlite_query: The SELECT (or WITH ... SELECT) query to validate.

        Returns:
            True if the query is valid, otherwise an error message explaining
            why it was rejected and asking for the query to be generated again.
        """
        logger.info(f"Verifying SQL query: {sqlite_query}")
        query = sqlite_query.strip()

        if not self._READ_ONLY_PREFIX.match(query):
            return self._rejection_message(
                f"SQL query is not a SELECT statement: {query}. Only SELECT queries are allowed."
            )
        if self._FORBIDDEN_SQL_KEYWORDS.search(query):
            return self._rejection_message(
                f"SQL query contains a destructive operation: {query}. Only SELECT queries are allowed."
            )

        # Execute the query exactly as the agent wrote it (not lower-cased),
        # so string literals keep their case. The session is never committed.
        # NOTE(review): this fully runs the SELECT; wrapping it in EXPLAIN
        # would validate it without scanning data — confirm before changing.
        try:
            async with DatabaseConfig.async_session() as session:
                await session.execute(text(query))
        except Exception as e:  # any DB error is fed back to the model
            return self._rejection_message(str(e))
        return True

    @classmethod
    def _is_retryable_status(cls, status_code: int) -> bool:
        """Return True for HTTP statuses that indicate a transient failure."""
        return status_code in cls._RETRYABLE_HTTP_STATUSES or status_code >= 500

    async def _run_with_backoff(self, agent, *args, retries=5, **kwargs):
        """Call ``agent.run(*args, **kwargs)``, retrying transient HTTP failures.

        After a retryable ``ModelHTTPError`` the wait starts at 5 seconds and
        is multiplied by 5 after each failure, capped at 60 seconds. There is
        no wait after the final attempt.

        Args:
            agent: Object exposing an async ``run`` method (the pydantic-ai agent).
            *args: Positional arguments forwarded to ``agent.run``.
            retries: Maximum number of attempts.
            **kwargs: Keyword arguments forwarded to ``agent.run``.

        Returns:
            Whatever ``agent.run`` returns on the first successful attempt.

        Raises:
            RuntimeError: on a non-retryable ``ModelHTTPError``, or when every
                attempt failed with a retryable one. The HTTP error is chained
                as ``__cause__``.
        """
        delay = 5
        max_delay = 60
        last_error: ModelHTTPError | None = None
        for attempt in range(1, retries + 1):
            try:
                logger.debug(f"Attempt {attempt}/{retries}")
                return await agent.run(*args, **kwargs)
            except ModelHTTPError as e:
                if not self._is_retryable_status(e.status_code):
                    # Auth/validation errors won't fix themselves; fail fast
                    # but keep RuntimeError as the type callers already catch.
                    raise RuntimeError(
                        f"Model request failed with non-retryable HTTP status {e.status_code}"
                    ) from e
                last_error = e
                if attempt == retries:
                    break
                logger.debug(
                    f"Transient HTTP {e.status_code} from model, backing off for {delay} seconds..."
                )
                await asyncio.sleep(delay)
                delay = min(delay * 5, max_delay)
        raise RuntimeError(
            "Exceeded retries due to rate limiting or transient server errors"
        ) from last_error

    async def run(
        self, user_input: str, message_history: Optional[list[Message]] = None
    ) -> SQLQueryExtractor:
        """Run the agent on *user_input* and return its structured output.

        Args:
            user_input: The user's natural-language request.
            message_history: Prior conversation messages; None means no history.

        Returns:
            The agent's ``SQLQueryExtractor`` output.
        """
        # Fresh list per call instead of a shared mutable default.
        history = [] if message_history is None else message_history
        result = await self._run_with_backoff(
            self._agent, user_input, message_history=history
        )
        logger.debug(result.output)
        return result.output