File size: 3,706 Bytes
8e073b5
6500c31
 
12ed030
 
 
 
 
 
 
 
 
 
411c555
3bdece0
6500c31
 
 
 
12ed030
 
 
 
 
 
 
 
 
411c555
6500c31
 
 
 
 
 
 
71dcc32
12ed030
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bdece0
8770042
12ed030
 
 
8e073b5
 
 
12ed030
411c555
 
12ed030
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9fd577
12ed030
 
cdcf836
12ed030
103e7ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from typing import Optional
from loguru import logger

from src.schemas import Message, SQLQueryExtractor
from src.entities import (
    ApparatusModel,
    IncidentModel,
    PersonnelModel,
    IncidentApparatusModel,
    IncidentPersonnelModel,
    IncidentBaseModel,
)
from src.configs import DatabaseConfig
from src.utils import PydanticAgent, PineconeClient
from sqlalchemy import inspect, text


class DataExtractorService:
    """Turn a natural-language question into a SQL query and run it.

    ``extract`` looks up documents related to the question through the vector
    store client, renders the schemas of the tables those documents point at,
    asks the agent for a SQL query, and executes that query.

    The service can be used as an async context manager; entering and leaving
    the context currently acquire and release nothing.
    """

    def __init__(self):
        # Table names as they appear in vector-store metadata -> ORM model.
        self._table_model_map = {
            "apparatus": ApparatusModel,
            "incident": IncidentModel,
            "personnel": PersonnelModel,
            "auv_incidentapparatus": IncidentApparatusModel,
            "auv_incidentpersonnel": IncidentPersonnelModel,
            "auv_incidentbase": IncidentBaseModel,
        }
        # The classes themselves are stored, not instances: ``extract`` calls
        # them to open a fresh async context on every request.
        self._pydantic_agent = PydanticAgent
        self._pinecone_client = PineconeClient

    async def __aenter__(self):
        """Return the service itself; nothing is acquired on entry."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Release nothing and let any in-flight exception propagate."""
        return None

    async def _model_schema_to_text(self, model_class) -> str:
        """Convert SQLAlchemy model schema to LLM-friendly text.

        The result starts with a ``Table: <name>`` line and a ``Columns:``
        line, followed by one line per mapped column giving its name, type,
        nullability and, where applicable, primary-key / foreign-key markers.
        """
        mapper = inspect(model_class)
        lines = [f"Table: {model_class.__tablename__}", "Columns:"]

        for column in mapper.columns:
            nullable = "nullable" if column.nullable else "required"
            pk = " [PRIMARY KEY]" if column.primary_key else ""
            fk = ""
            if column.foreign_keys:
                # Only one foreign-key target is reported per column.
                target = next(iter(column.foreign_keys)).target_fullname
                fk = f" [FK -> {target}]"
            lines.append(f"  - {column.name}: {column.type} ({nullable}){pk}{fk}")

        return "\n".join(lines)

    async def _get_table_schema(self, table_name: str) -> str:
        """Return the text schema of the model registered under ``table_name``.

        Raises:
            KeyError: if ``table_name`` is not a key of the table/model map.
        """
        return await self._model_schema_to_text(self._table_model_map[table_name])

    async def _get_data(self, sql_query: str):
        """Execute ``sql_query`` and return every row as a mapping.

        Errors raised while opening the session or executing the statement
        propagate to the caller unchanged.
        """
        # NOTE(review): the statement is executed verbatim. ``extract`` passes
        # model-generated SQL here, so the database role behind
        # ``DatabaseConfig.async_session`` should be restricted (ideally
        # read-only) -- confirm against the deployment configuration.
        async with DatabaseConfig.async_session() as session:
            result = await session.execute(text(sql_query))
            return result.mappings().all()

    @staticmethod
    def _unique_table_names(metadatas) -> list:
        """Return the distinct ``table_name`` values in first-seen order.

        ``None`` entries are skipped, matching the filter ``extract`` applies
        when it pairs documents with their metadata.
        """
        return list(
            dict.fromkeys(
                meta["table_name"] for meta in metadatas if meta is not None
            )
        )

    async def extract(
        self,
        user_query: str,
        message_history: "Optional[list[Message]]" = None,
    ):
        """Generate a SQL query for ``user_query``, run it, and return the rows.

        Args:
            user_query: The question to answer.
            message_history: Earlier messages forwarded to the agent. When
                omitted (or ``None``) an empty history is sent.

        Returns:
            A ``(data, sql_query, output)`` tuple: the rows produced by the
            query, the query string taken from ``output.sqlite_query``, and
            the agent's full structured output.

        Raises:
            KeyError: if a metadata entry names a table that is not in the
                table/model map.
        """
        # Build the list per call: a mutable default would be one object
        # shared by every invocation that omits the argument.
        history = [] if message_history is None else message_history

        logger.info("Extracting data...")
        async with self._pinecone_client() as pinecone_client:
            vector_results = await pinecone_client.query(query=user_query, n_results=5)
            metadatas = vector_results["metadatas"]
            vector_document_metadata = [
                (doc, meta)
                for doc, meta in zip(vector_results["documents"], metadatas)
                if doc is not None and meta is not None
            ]
            # One schema per distinct table, in the order the tables first appear.
            table_schemas = []
            for table_name in self._unique_table_names(metadatas):
                table_schemas.append(await self._get_table_schema(table_name))

        user_input = f"""
        # User Query: 
        {user_query}

        # Table Schemas: 
        {table_schemas}

        # Description
        {vector_document_metadata}
        """

        async with self._pydantic_agent() as pydantic_agent:
            output: SQLQueryExtractor = await pydantic_agent.run(
                user_input=user_input, message_history=history
            )

        sql_query = output.sqlite_query
        data = await self._get_data(sql_query)
        return data, sql_query, output