File size: 4,510 Bytes
1a436de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import ast
import logging
import os

import pandas as pd
from dotenv import load_dotenv
from duckdb import DuckDBPyConnection

from src.models import PanderaSchemaModel, SQLQueryModel

# Populate os.environ from a .env file (python-dotenv) so the os.getenv lookups
# below can pick up local configuration. Must run before those lookups.
load_dotenv()

logger = logging.getLogger(__name__)


# Number of retries after the initial attempt in Query2Schema.try_sql_with_retries
# (default 5). int() raises ValueError at import time if the value is not an integer.
SQL_GENERATION_RETRIES = int(os.getenv("SQL_GENERATION_RETRIES", "5"))
# Prompt templates read from the environment; each is None when the variable is
# unset. USER_PROMPT and PANDERA_USER_PROMPT are used with str.format(), so an
# unset one only fails (AttributeError) when the corresponding method is called.
PANDERA_PROMPT = os.getenv("PANDERA_PROMPT")
PANDERA_USER_PROMPT = os.getenv("PANDERA_USER_PROMPT")
SQL_PROMPT = os.getenv("SQL_PROMPT")
USER_PROMPT = os.getenv("USER_PROMPT")


class Query2Schema:
    """Turn a natural-language question into SQL via an LLM chain, run it on
    DuckDB, and generate a Pandera schema for a SQL query.

    Args:
        duckdb: Open DuckDB connection used to execute generated SQL.
        chain: LLM wrapper exposing ``run(system_prompt=..., user_prompt=...,
            format_name=..., response_format=...)``. Its exact return shape is
            not visible here; this class handles a SQL ``str`` or a ``dict``
            carrying a ``"sql_query"`` key.
    """

    def __init__(
        self,
        duckdb: DuckDBPyConnection,
        chain,
    ) -> None:
        self._duckdb = duckdb
        self.chain = chain

    def generate_sql(
        self, user_question: str, context: str, errors: str | None = None
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """Ask the LLM chain for a SQL query answering ``user_question``.

        Args:
            user_question: The question to answer.
            context: Extra context substituted into ``USER_PROMPT``
                (presumably table/schema information — confirm with callers).
            errors: Accumulated error text from earlier failed attempts. When
                non-empty it is appended to the prompt so the model can avoid
                repeating the same mistakes.

        Returns:
            Whatever ``self.chain.run`` returns for ``SQLQueryModel``.
        """
        user_prompt_formatted = USER_PROMPT.format(
            question=user_question, context=context
        )
        if errors:
            # Implicit string concatenation: backslash continuations inside the
            # literal used to embed the source indentation into the prompt, and
            # the text had no separator from the preceding user prompt.
            user_prompt_formatted += (
                "\n\nCarefully review the previous errors or exceptions and rewrite "
                "the SQL so that they do not occur again. Try a different approach "
                f"or rewrite the SQL if needed. Previous errors: {errors}"
            )

        sql = self.chain.run(
            system_prompt=SQL_PROMPT,
            user_prompt=user_prompt_formatted,
            format_name="sql_query",
            response_format=SQLQueryModel,
        )
        logger.info("SQL Generated Successfully: %s", sql)
        return sql

    def run_query(self, sql_query: str) -> pd.DataFrame | None:
        """Execute ``sql_query`` on DuckDB and return the result as a DataFrame.

        Returns:
            The result DataFrame, or None when DuckDB yields no relation for the
            statement (``query`` can return None for statements without a
            result set, e.g. DDL) instead of failing on ``None.df()``.
        """
        logger.info("Query Execution Started.")
        relation = self._duckdb.query(sql_query)
        if relation is None:
            return None
        return relation.df()

    def try_sql_with_retries(
        self,
        user_question: str,
        context: str,
        max_retries: int = SQL_GENERATION_RETRIES,
    ) -> tuple[
        str | dict[str, str | int | float | None] | list[str] | None,
        pd.DataFrame | None,
    ]:
        """Generate and execute SQL, regenerating with error feedback on failure.

        Makes one initial attempt plus up to ``max_retries`` retries. An attempt
        fails when generation/execution raises, or when the query yields no
        result set or an empty DataFrame. Every failure is appended to the
        feedback passed to the next ``generate_sql`` call.

        Returns:
            ``(sql, dataframe)`` from the first attempt that produces a non-empty
            DataFrame; ``(None, None)`` if the chain returns a falsy result (no
            retry in that case) or if every attempt fails.
        """
        all_errors = ""
        total_attempts = max_retries + 1  # the first attempt is not a retry

        for attempt in range(1, total_attempts + 1):
            if attempt > 1:
                logger.info("Retrying: %s", attempt - 1)
            try:
                # No feedback on the first attempt; generate_sql ignores falsy errors.
                sql = self.generate_sql(
                    user_question, context, errors=all_errors or None
                )
                if not sql:
                    return None, None

                sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql
                if not isinstance(sql_query_str, str):
                    raise ValueError(
                        f"Expected SQL query to be a string, got {type(sql_query_str).__name__}"
                    )
                if not sql_query_str.strip():
                    raise ValueError("Expected a non-empty SQL query")

                query_df = self.run_query(sql_query_str)
                if query_df is not None and not query_df.empty:
                    return sql, query_df

                # Not an exception, but still a failed attempt: record why, so the
                # next generation is not just the identical prompt again.
                reason = "no result set" if query_df is None else "no rows"
                last_error = f"\n[Attempt {attempt}] EmptyResult: query returned {reason}"
                logger.warning("SQL returned no data: %s", last_error)
            except Exception as e:  # any failure is retryable feedback for the model
                last_error = f"\n[Attempt {attempt}] {type(e).__name__}: {e}"
                logger.error("Error during SQL generation or execution: %s", last_error)
            all_errors += last_error

        logger.error("Failed after %d attempts. Errors: %s", total_attempts, all_errors)
        return None, None

    def generate_pandera_schema(self, sql_query: str, user_instruction: str) -> str:
        """Ask the LLM chain for a Pandera schema and keep only its class definitions.

        Args:
            sql_query: SQL whose result the schema should describe.
            user_instruction: Extra instructions substituted into the prompt.

        Returns:
            The source text of every top-level class (including its decorators)
            in the generated code, joined with newlines; "" if there are none.

        Raises:
            SyntaxError: if the chain output is not valid Python source.
        """
        schema_str = self.chain.run(
            system_prompt=PANDERA_PROMPT,
            user_prompt=PANDERA_USER_PROMPT.format(
                sql_query=sql_query, instructions=user_instruction
            ),
            format_name="pandera_schema",
            response_format=PanderaSchemaModel,
        )

        parsed = ast.parse(schema_str)
        source_lines = schema_str.splitlines()

        class_lines: list[str] = []
        for node in parsed.body:
            if not isinstance(node, ast.ClassDef):
                continue
            # ClassDef.lineno is the ``class`` line; start at the first decorator
            # (if any) so decorators stay attached. Line numbers are 1-based.
            first_line = min([node.lineno, *(d.lineno for d in node.decorator_list)])
            class_lines.extend(source_lines[first_line - 1 : node.end_lineno])

        if not class_lines:
            logger.warning("No class definitions found in generated Pandera schema.")
        return "\n".join(class_lines)