from typing import Optional

from loguru import logger

from src.schemas import Message, SQLQueryExtractor
from src.entities import (
    ApparatusModel,
    IncidentModel,
    PersonnelModel,
    IncidentApparatusModel,
    IncidentPersonnelModel,
    IncidentBaseModel,
)
from src.configs import DatabaseConfig
from src.utils import PydanticAgent, PineconeClient
from sqlalchemy import inspect, text


class DataExtractorService:
    """Turn a natural-language question into SQL via an LLM agent and run it.

    Pipeline (see ``extract``): query the vector store for related entries,
    render the SQLAlchemy schemas of the tables those entries name, ask the
    pydantic agent for a SQLite query, and execute that query against the
    database configured in ``DatabaseConfig``.
    """

    def __init__(self):
        # Table name (as found under the "table_name" key of vector-store
        # metadata) -> SQLAlchemy model describing that table.
        self._table_model_map = {
            "apparatus": ApparatusModel,
            "incident": IncidentModel,
            "personnel": PersonnelModel,
            "auv_incidentapparatus": IncidentApparatusModel,
            "auv_incidentpersonnel": IncidentPersonnelModel,
            "auv_incidentbase": IncidentBaseModel,
        }
        # Stored as classes, not instances: ``extract`` instantiates each one
        # per call and uses it as an async context manager.
        self._pydantic_agent = PydanticAgent
        self._pinecone_client = PineconeClient

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Nothing to release here: the clients are opened and closed inside
        # each ``extract`` call.
        pass

    async def _model_schema_to_text(self, model_class) -> str:
        """Convert SQLAlchemy model schema to LLM-friendly text.

        Produces one header line with the table name followed by one line per
        column: name, SQL type, nullability, and primary/foreign key markers.
        """
        mapper = inspect(model_class)
        table_name = model_class.__tablename__
        lines = [f"Table: {table_name}", "Columns:"]
        for column in mapper.columns:
            col_type = str(column.type)
            nullable = "nullable" if column.nullable else "required"
            pk = " [PRIMARY KEY]" if column.primary_key else ""
            # A column may carry several foreign keys; list them all, sorted so
            # the prompt text is deterministic (foreign_keys is a set).
            fk = "".join(
                f" [FK -> {target}]"
                for target in sorted(
                    fk.target_fullname for fk in column.foreign_keys
                )
            )
            lines.append(f"  - {column.name}: {col_type} ({nullable}){pk}{fk}")
        return "\n".join(lines)

    async def _get_table_schema(self, table_name: str) -> str:
        """Return the text schema for ``table_name``.

        Raises:
            KeyError: if ``table_name`` is not in ``_table_model_map``.
        """
        model_class = self._table_model_map[table_name]
        return await self._model_schema_to_text(model_class)

    async def _get_data(self, sql_query: str):
        """Execute ``sql_query`` and return all rows as mappings.

        NOTE(security): ``sql_query`` is generated by an LLM from user input
        and is executed verbatim; nothing here restricts it to read-only
        statements. Consider a read-only database connection or a statement
        allow-list before exposing this to untrusted users.

        Raises:
            Exception: whatever the database layer raises is logged together
                with the offending query and re-raised unchanged.
        """
        try:
            async with DatabaseConfig.async_session() as session:
                result = await session.execute(text(sql_query))
                return result.mappings().all()
        except Exception:
            # Keep the failing generated SQL in the logs; bare ``raise``
            # preserves the original traceback.
            logger.exception("Failed to execute generated SQL query: {}", sql_query)
            raise

    async def _collect_table_schemas(self, metadatas) -> list[str]:
        """Return schema text for each distinct known table named in ``metadatas``.

        Order follows first appearance. Entries that are ``None``/empty or that
        lack a "table_name" key are ignored; names with no model in
        ``_table_model_map`` are skipped with a warning instead of aborting
        the whole request.
        """
        schemas: list[str] = []
        seen: set = set()
        for meta in metadatas:
            if not meta:
                continue
            table_name = meta.get("table_name")
            if table_name is None or table_name in seen:
                continue
            seen.add(table_name)
            if table_name not in self._table_model_map:
                logger.warning(
                    "Vector result references unknown table {!r}; skipping",
                    table_name,
                )
                continue
            schemas.append(await self._get_table_schema(table_name))
        return schemas

    @staticmethod
    def _build_user_input(user_query: str, table_schemas, descriptions) -> str:
        """Assemble the prompt text sent to the agent.

        Schemas are joined with blank lines (rather than interpolating the
        list, whose repr would show escaped ``\\n`` sequences), and each
        (document, metadata) description goes on its own bullet line.
        """
        schema_text = "\n\n".join(table_schemas)
        description_text = "\n".join(
            f"- {doc} (metadata: {meta})" for doc, meta in descriptions
        )
        return (
            f"# User Query:\n{user_query}\n\n"
            f"# Table Schemas:\n{schema_text}\n\n"
            f"# Description\n{description_text}\n"
        )

    async def extract(
        self,
        user_query: str,
        message_history: Optional[list[Message]] = None,
        *,
        n_results: int = 5,
    ):
        """Answer ``user_query`` by generating a SQL query and running it.

        Args:
            user_query: Natural-language question from the user.
            message_history: Prior conversation forwarded to the agent.
                ``None`` (the default) means no history; a fresh empty list is
                created per call instead of sharing a mutable default.
            n_results: Number of vector-store matches used to choose the
                relevant tables and descriptions.

        Returns:
            A ``(data, sql_query, output)`` tuple: the fetched rows, the SQL
            that was executed, and the full agent output.
        """
        logger.info("Extracting data...")
        if message_history is None:
            message_history = []

        async with self._pinecone_client() as pinecone_client:
            vector_results = await pinecone_client.query(
                query=user_query, n_results=n_results
            )

        descriptions = [
            (doc, meta)
            for doc, meta in zip(
                vector_results["documents"], vector_results["metadatas"]
            )
            if doc is not None and meta is not None
        ]
        table_schemas = await self._collect_table_schemas(
            vector_results["metadatas"]
        )
        user_input = self._build_user_input(user_query, table_schemas, descriptions)

        async with self._pydantic_agent() as pydantic_agent:
            output: SQLQueryExtractor = await pydantic_agent.run(
                user_input=user_input, message_history=message_history
            )

        sql_query = output.sqlite_query
        logger.debug("Generated SQL query: {}", sql_query)
        data = await self._get_data(sql_query)
        return data, sql_query, output