firermsdata-agent / src /utils /_pydantic_agent.py
Aryan Jain
migrate to pinecone and show graph color
411c555
import asyncio
import os
import re
from typing import Optional

from loguru import logger
from pydantic_ai import Agent, ModelHTTPError
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
from sqlalchemy import text

from src.configs import DatabaseConfig
from src.prompts import SQL_QUERY_EXTRACTOR_PROMPT
from src.schemas import SQLQueryExtractor, Message
class PydanticAgent:
def __init__(
self
):
self._system_prompt = SQL_QUERY_EXTRACTOR_PROMPT
self._openai_model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-5.2")
self._openai_model = OpenAIResponsesModel(
model_name=self._openai_model_name,
)
self._agent = Agent(
system_prompt=self._system_prompt,
model=self._openai_model,
output_type=SQLQueryExtractor,
model_settings=OpenAIResponsesModelSettings(temperature=0.0),
tools=[self._verify_sql_query],
retries=5,
)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
pass
async def _verify_sql_query(self, sqlite_query: str) -> bool | str:
logger.info(f"Verifying SQL query: {sqlite_query}")
try:
words_shoould_not_present_in_sql_query = [
"DELETE",
"DROP",
"UPDATE",
"TRUNCATE",
"ALTER",
"INSERT"
]
sql_query = sqlite_query.lower().strip()
if any(
word.lower() in sql_query
for word in words_shoould_not_present_in_sql_query
):
raise Exception(
f"SQL query contains a destructive operation: {sql_query}. Only SELECT queries are allowed."
)
async with DatabaseConfig.async_session() as session:
await session.execute(text(sql_query))
except Exception as e:
logger.error(e)
return str(e) + "\nPlease generate SQL Query again"
return True
async def _run_with_backoff(self, agent, *args, retries=5, **kwargs):
delay = 5
max_delay = 60
for attempt in range(retries):
try:
logger.debug(f"Attempt {attempt + 1}/{retries}")
return await agent.run(*args, **kwargs)
except ModelHTTPError as e:
logger.debug("Rate limit exceeded, backing off...")
logger.debug(f"Backing off for {delay} seconds...")
await asyncio.sleep(delay)
delay = min(delay * 5, max_delay)
continue
raise RuntimeError("Exceeded retries due to rate limiting")
async def run(
self, user_input: str, message_history: Optional[list[Message]] = []
) -> SQLQueryExtractor:
output = await self._run_with_backoff(
self._agent, user_input, message_history=message_history
)
logger.debug(output.output)
return output.output