Spaces:
Sleeping
Sleeping
| from typing import Optional | |
| from loguru import logger | |
| from src.schemas import Message, SQLQueryExtractor | |
| from src.entities import ( | |
| ApparatusModel, | |
| IncidentModel, | |
| PersonnelModel, | |
| IncidentApparatusModel, | |
| IncidentPersonnelModel, | |
| IncidentBaseModel, | |
| ) | |
| from src.configs import DatabaseConfig | |
| from src.utils import PydanticAgent, PineconeClient | |
| from sqlalchemy import inspect, text | |
class DataExtractorService:
    """Answer a natural-language question by generating and running SQL.

    ``extract`` looks the question up in the vector store, renders the
    schemas of the tables named in the hits' metadata, asks the agent for a
    SQL query and executes that query against the database.
    """

    def __init__(self):
        # Table names (as they appear in vector-store metadata) -> ORM model.
        self._table_model_map = {
            "apparatus": ApparatusModel,
            "incident": IncidentModel,
            "personnel": PersonnelModel,
            "auv_incidentapparatus": IncidentApparatusModel,
            "auv_incidentpersonnel": IncidentPersonnelModel,
            "auv_incidentbase": IncidentBaseModel,
        }
        # The classes themselves are stored; ``extract`` instantiates each
        # one per call and uses it as an async context manager.
        self._pydantic_agent = PydanticAgent
        self._pinecone_client = PineconeClient

    async def __aenter__(self):
        """Return the service itself; nothing is acquired on entry."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Release nothing; returning ``None`` lets exceptions propagate."""
        return None

    async def _model_schema_to_text(self, model_class) -> str:
        """Convert SQLAlchemy model schema to LLM-friendly text.

        Args:
            model_class: A mapped model class exposing ``__tablename__``.

        Returns:
            A ``Table:`` / ``Columns:`` header followed by one line per
            column with its type, nullability, primary-key marker and
            foreign-key target(s).
        """
        mapper = inspect(model_class)
        lines = [f"Table: {model_class.__tablename__}", "Columns:"]
        for column in mapper.columns:
            nullable = "nullable" if column.nullable else "required"
            pk = " [PRIMARY KEY]" if column.primary_key else ""
            fk = ""
            if column.foreign_keys:
                # ``foreign_keys`` is a set: sort the targets so the text is
                # deterministic, and list every target instead of an
                # arbitrary one. Output is unchanged for a single FK.
                targets = sorted(
                    key.target_fullname for key in column.foreign_keys
                )
                fk = f" [FK -> {', '.join(targets)}]"
            lines.append(
                f" - {column.name}: {column.type} ({nullable}){pk}{fk}"
            )
        return "\n".join(lines)

    async def _get_table_schema(self, table_name: str) -> str:
        """Return the schema text for ``table_name``.

        Raises:
            KeyError: If ``table_name`` is not in the table/model map.
        """
        return await self._model_schema_to_text(
            self._table_model_map[table_name]
        )

    async def _get_data(self, sql_query: str):
        """Execute ``sql_query`` and return all rows as mappings.

        NOTE(review): the string is executed verbatim through ``text()``.
        In ``extract`` it comes from the agent's output, so the database
        credentials in use should presumably be read-only -- confirm.

        Raises:
            Exception: Whatever the session or driver raises; it is logged
                and re-raised unchanged.
        """
        try:
            async with DatabaseConfig.async_session() as session:
                result = await session.execute(text(sql_query))
                return result.mappings().all()
        except Exception:
            # Log with traceback, then re-raise the original exception.
            logger.exception("Failed to execute SQL query")
            raise

    async def extract(
        self,
        user_query: str,
        message_history: "Optional[list[Message]]" = None,
    ):
        """Generate a SQL query for ``user_query`` and run it.

        Args:
            user_query: The user's natural-language question.
            message_history: Prior messages passed through to the agent.
                ``None`` (the default) means an empty history.

        Returns:
            A ``(data, sql_query, output)`` tuple: the fetched rows, the SQL
            string that was executed, and the agent's full output object.

        Raises:
            KeyError: If a metadata entry names a table that is not in the
                table/model map.
        """
        # A fresh list per call; a ``[]`` default would be shared by every
        # call that omits the argument.
        if message_history is None:
            message_history = []

        logger.info("Extracting data...")
        async with self._pinecone_client() as pinecone_client:
            vector_results = await pinecone_client.query(
                query=user_query, n_results=5
            )

        documents = vector_results["documents"]
        metadatas = vector_results["metadatas"]
        vector_document_metadata = [
            (doc, meta)
            for doc, meta in zip(documents, metadatas)
            if doc is not None and meta is not None
        ]

        # One schema per distinct table, in first-seen order.
        table_schemas = []
        processed_tables = set()
        for meta in metadatas:
            if meta is None:
                # Same tolerance as the pair filter above; subscripting
                # ``None`` would raise TypeError.
                continue
            table_name = meta["table_name"]
            if table_name in processed_tables:
                continue
            processed_tables.add(table_name)
            table_schemas.append(await self._get_table_schema(table_name))

        # NOTE(review): the two lists below are interpolated as Python
        # reprs (brackets, escaped newlines); left as-is to keep the prompt
        # unchanged.
        user_input = f"""
        # User Query:
        {user_query}
        # Table Schemas:
        {table_schemas}
        # Description
        {vector_document_metadata}
        """
        async with self._pydantic_agent() as pydantic_agent:
            output: SQLQueryExtractor = await pydantic_agent.run(
                user_input=user_input, message_history=message_history
            )
        sql_query = output.sqlite_query
        data = await self._get_data(sql_query)
        return data, sql_query, output