# firermsdata-agent/src/services/_data_extractor.py
# Aryan Jain
# migrate to pinecone and show graph color
# 411c555
from typing import Optional
from loguru import logger
from src.schemas import Message, SQLQueryExtractor
from src.entities import (
ApparatusModel,
IncidentModel,
PersonnelModel,
IncidentApparatusModel,
IncidentPersonnelModel,
IncidentBaseModel,
)
from src.configs import DatabaseConfig
from src.utils import PydanticAgent, PineconeClient
from sqlalchemy import inspect, text
class DataExtractorService:
    """Generate a SQL query for a user question and execute it.

    ``extract`` looks the question up in the vector store, renders the schema
    of every table named in the returned metadata, hands the question plus
    that context to the agent, and runs the query the agent produced.
    """

    def __init__(self):
        # Table names (as they appear in vector-store metadata) -> ORM model
        # whose columns are rendered into the prompt.
        self._table_model_map = {
            "apparatus": ApparatusModel,
            "incident": IncidentModel,
            "personnel": PersonnelModel,
            "auv_incidentapparatus": IncidentApparatusModel,
            "auv_incidentpersonnel": IncidentPersonnelModel,
            "auv_incidentbase": IncidentBaseModel,
        }
        # The classes themselves are stored, not instances: ``extract`` calls
        # them to open a fresh async context for each request.
        self._pydantic_agent = PydanticAgent
        self._pinecone_client = PineconeClient

    async def __aenter__(self):
        """Return the service itself; nothing is acquired on entry."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Do nothing on exit; exceptions are never suppressed."""
        pass

    async def _model_schema_to_text(self, model_class) -> str:
        """Convert SQLAlchemy model schema to LLM-friendly text.

        Args:
            model_class: A mapped SQLAlchemy model exposing ``__tablename__``.

        Returns:
            A ``Table:`` header, a ``Columns:`` line, then one line per column
            with its type, nullability, and primary/foreign-key markers.
        """
        mapper = inspect(model_class)
        lines = [f"Table: {model_class.__tablename__}", "Columns:"]
        for column in mapper.columns:
            col_type = str(column.type)
            nullable = "nullable" if column.nullable else "required"
            pk = " [PRIMARY KEY]" if column.primary_key else ""
            fk = ""
            if column.foreign_keys:
                # Only one foreign-key target is reported per column.
                target = next(iter(column.foreign_keys)).target_fullname
                fk = f" [FK -> {target}]"
            lines.append(f" - {column.name}: {col_type} ({nullable}){pk}{fk}")
        return "\n".join(lines)

    async def _get_table_schema(self, table_name: str) -> str:
        """Return the schema text for ``table_name``.

        Raises:
            KeyError: If ``table_name`` is not in the table/model map.
        """
        model_class = self._table_model_map[table_name]
        return await self._model_schema_to_text(model_class)

    async def _get_data(self, sql_query: str):
        """Execute ``sql_query`` and return all rows as mappings.

        Any exception raised by the session or the driver propagates
        unchanged to the caller.
        """
        # NOTE(security): ``sql_query`` is executed verbatim. ``extract``
        # passes agent-generated SQL here, so the database credentials in use
        # should be read-only / least-privilege -- TODO confirm.
        async with DatabaseConfig.async_session() as session:
            result = await session.execute(text(sql_query))
            return result.mappings().all()

    @staticmethod
    def _unique_table_names(metadatas) -> list[str]:
        """Return distinct ``table_name`` values in first-seen order, skipping ``None`` entries."""
        seen = set()
        names = []
        for meta in metadatas:
            if meta is None:
                continue
            table_name = meta["table_name"]
            if table_name in seen:
                continue
            seen.add(table_name)
            names.append(table_name)
        return names

    @staticmethod
    def _build_prompt(
        user_query: str,
        table_schemas: list[str],
        vector_document_metadata: list[tuple],
    ) -> str:
        """Assemble the agent input from the query, schema texts and vector hits."""
        # Join the schemas so their newlines reach the model as real line
        # breaks instead of the escaped ``\n`` of a Python list repr.
        schemas_text = "\n\n".join(table_schemas)
        return (
            "\n"
            "# User Query:\n"
            f"{user_query}\n"
            "# Table Schemas:\n"
            f"{schemas_text}\n"
            "# Description\n"
            f"{vector_document_metadata}\n"
        )

    async def extract(
        self, user_query: str, message_history: Optional[list["Message"]] = None
    ):
        """Generate SQL for ``user_query``, run it, and return the rows.

        Args:
            user_query: The question to answer.
            message_history: Prior messages forwarded to the agent. ``None``
                (the default) is replaced by a new empty list on each call.

        Returns:
            A tuple ``(data, sql_query, output)``: the fetched rows, the SQL
            string that was executed, and the agent's full output object.

        Raises:
            KeyError: If the vector-store metadata names a table that has no
                entry in the table/model map.
        """
        logger.info("Extracting data...")
        # A fresh list per call: a shared default list would leak history
        # between unrelated requests if anything downstream mutated it.
        history = [] if message_history is None else message_history

        async with self._pinecone_client() as pinecone_client:
            vector_results = await pinecone_client.query(query=user_query, n_results=5)

        vector_document_metadata = [
            (doc, meta)
            for doc, meta in zip(
                vector_results["documents"], vector_results["metadatas"]
            )
            if doc is not None and meta is not None
        ]

        table_schemas = [
            await self._get_table_schema(table_name)
            for table_name in self._unique_table_names(vector_results["metadatas"])
        ]

        user_input = self._build_prompt(
            user_query, table_schemas, vector_document_metadata
        )

        async with self._pydantic_agent() as pydantic_agent:
            output: SQLQueryExtractor = await pydantic_agent.run(
                user_input=user_input, message_history=history
            )

        sql_query = output.sqlite_query
        data = await self._get_data(sql_query)
        return data, sql_query, output