Spaces:
Sleeping
Sleeping
File size: 3,706 Bytes
from typing import Optional
from loguru import logger
from src.schemas import Message, SQLQueryExtractor
from src.entities import (
ApparatusModel,
IncidentModel,
PersonnelModel,
IncidentApparatusModel,
IncidentPersonnelModel,
IncidentBaseModel,
)
from src.configs import DatabaseConfig
from src.utils import PydanticAgent, PineconeClient
from sqlalchemy import inspect, text
class DataExtractorService:
    """Answer a user query by generating a SQL query and executing it.

    The service looks up related documents in the vector store, renders the
    schema of every table those documents reference, asks the agent for a SQL
    query, and runs that query against the database.

    It can be used as an async context manager; entering and exiting do not
    acquire or release anything.
    """

    def __init__(self):
        # Table names (as they appear in vector-store metadata) mapped to the
        # SQLAlchemy model whose schema is rendered for the agent.
        self._table_model_map = {
            "apparatus": ApparatusModel,
            "incident": IncidentModel,
            "personnel": PersonnelModel,
            "auv_incidentapparatus": IncidentApparatusModel,
            "auv_incidentpersonnel": IncidentPersonnelModel,
            "auv_incidentbase": IncidentBaseModel,
        }
        # The classes themselves are stored; a fresh instance is created (and
        # used as an async context manager) on every ``extract`` call.
        self._pydantic_agent = PydanticAgent
        self._pinecone_client = PineconeClient

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Nothing to clean up; returning None lets exceptions propagate.
        pass

    async def _model_schema_to_text(self, model_class) -> str:
        """Convert SQLAlchemy model schema to LLM-friendly text.

        Produces a ``Table: <name>`` header followed by one line per column
        giving its name, type, nullability, and primary/foreign key markers.
        """
        mapper = inspect(model_class)
        table_name = model_class.__tablename__
        lines = [f"Table: {table_name}", "Columns:"]
        for column in mapper.columns:
            col_type = str(column.type)
            nullable = "nullable" if column.nullable else "required"
            pk = " [PRIMARY KEY]" if column.primary_key else ""
            # Only one foreign-key target is shown, even if the column
            # declares several.
            fk = (
                f" [FK -> {next(iter(column.foreign_keys)).target_fullname}]"
                if column.foreign_keys
                else ""
            )
            lines.append(f"  - {column.name}: {col_type} ({nullable}){pk}{fk}")
        return "\n".join(lines)

    async def _get_table_schema(self, table_name: str) -> str:
        """Return the text schema of the model registered for ``table_name``.

        Raises:
            KeyError: if ``table_name`` is not in the table/model map.
        """
        try:
            model_class = self._table_model_map[table_name]
        except KeyError:
            raise KeyError(
                f"No model registered for table {table_name!r}"
            ) from None
        return await self._model_schema_to_text(model_class)

    async def _get_data(self, sql_query: str):
        """Execute ``sql_query`` and return every result row as a mapping.

        Any exception raised while executing is logged and re-raised
        unchanged.
        """
        # NOTE(security): the SQL text is executed verbatim. ``extract`` passes
        # agent-generated SQL here, so the database credentials in use should
        # be restricted (ideally read-only) -- review before widening usage.
        try:
            async with DatabaseConfig.async_session() as session:
                result = await session.execute(text(sql_query))
                return result.mappings().all()
        except Exception:
            logger.exception("Failed to execute SQL query")
            raise

    async def extract(
        self, user_query: str, message_history: Optional[list[Message]] = None
    ):
        """Generate a SQL query for ``user_query`` and run it.

        Args:
            user_query: The question to answer.
            message_history: Earlier messages forwarded to the agent. ``None``
                (the default) is treated as an empty history.

        Returns:
            A ``(data, sql_query, output)`` tuple: the rows returned by the
            query, the SQL text that was executed, and the agent's full
            output object.
        """
        # A fresh list per call avoids sharing one mutable default object
        # between invocations.
        if message_history is None:
            message_history = []

        logger.info("Extracting data...")
        async with self._pinecone_client() as pinecone_client:
            vector_results = await pinecone_client.query(query=user_query, n_results=5)

        vector_document_metadata = [
            (doc, meta)
            for doc, meta in zip(
                vector_results["documents"], vector_results["metadatas"]
            )
            if doc is not None and meta is not None
        ]

        # De-duplicate table names while keeping first-seen order. Missing
        # (None) metadata entries are skipped, matching the filter above,
        # instead of failing on the subscript.
        table_names = dict.fromkeys(
            meta["table_name"]
            for meta in vector_results["metadatas"]
            if meta is not None
        )
        table_schemas = [await self._get_table_schema(name) for name in table_names]

        # Join the schemas as plain text: interpolating the list directly
        # would emit its repr, with brackets and escaped newlines.
        schemas_text = "\n\n".join(table_schemas)
        user_input = (
            "\n"
            "# User Query:\n"
            f"{user_query}\n"
            "# Table Schemas:\n"
            f"{schemas_text}\n"
            "# Description\n"
            f"{vector_document_metadata}\n"
        )

        async with self._pydantic_agent() as pydantic_agent:
            output: SQLQueryExtractor = await pydantic_agent.run(
                user_input=user_input, message_history=message_history
            )

        sql_query = output.sqlite_query
        data = await self._get_data(sql_query)
        return data, sql_query, output
|