File size: 1,137 Bytes
570f7bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1bc4eb
570f7bd
 
c1bc4eb
570f7bd
 
 
c1bc4eb
570f7bd
c1bc4eb
 
 
 
 
 
 
 
 
 
 
 
 
570f7bd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import time


from nl2sql.types import StageTrace, StageResult
from adapters.llm.base import LLMProvider

# Instructions sent to the LLM alongside the failing SQL.
# Written one literal per line (implicitly concatenated) so each rule is easy
# to diff; the resulting string starts and ends with a newline.
GUIDELINES = (
    "\n"
    "When repairing:\n"
    "1. Keep query SELECT-only.\n"
    "2. Explicitly qualify ambiguous columns with table names.\n"
    "3. Match GROUP BY fields with aggregations.\n"
    "4. Use known foreign keys for JOIN.\n"
    "5. Add a reasonable LIMIT if missing.\n"
    "Return only the corrected SQL.\n"
)


class Repair:
    """Pipeline stage that asks the LLM to fix a SQL query that failed.

    The failing SQL, its error message (prefixed with the module's repair
    ``GUIDELINES``) and a schema preview are handed to ``self.llm.repair``.
    The reply is wrapped in a ``StageResult`` whose ``StageTrace`` records
    duration, token counts and cost.
    """

    name = "repair"

    def __init__(self, llm: LLMProvider):
        # Provider whose ``repair`` method performs the actual rewrite.
        self.llm = llm

    def run(self, sql: str, error_msg: str, schema_preview: str) -> StageResult:
        """Request a corrected version of ``sql`` from the LLM.

        Args:
            sql: The SQL statement that needs repairing.
            error_msg: Error text describing why ``sql`` failed.
            schema_preview: Schema summary forwarded to the LLM.

        Returns:
            A ``StageResult`` whose ``data["sql"]`` holds the repaired SQL with
            surrounding whitespace removed. ``ok`` is False when the LLM
            returned no SQL (``None`` or blank), so callers do not go on to
            execute an empty query.
        """
        t0 = time.perf_counter()
        # The repair call only takes sql / error_msg / schema_preview, so the
        # repair guidelines travel inside the error message.
        fixed_sql, t_in, t_out, cost = self.llm.repair(
            sql=sql,
            error_msg=f"{GUIDELINES}\n\n{error_msg}",
            schema_preview=schema_preview,
        )
        duration_ms = (time.perf_counter() - t0) * 1000

        # Normalize the reply: a None answer used to crash in len(), and a
        # blank one was reported as a successful repair.
        fixed_sql = (fixed_sql or "").strip()

        trace = StageTrace(
            stage=self.name,
            duration_ms=duration_ms,
            token_in=t_in,
            token_out=t_out,
            cost_usd=cost,
            notes={"old_sql_len": len(sql), "new_sql_len": len(fixed_sql)},
        )
        return StageResult(ok=bool(fixed_sql), data={"sql": fixed_sql}, trace=trace)