Spaces:
Sleeping
Sleeping
| import time | |
| from nl2sql.types import StageTrace, StageResult | |
| from adapters.llm.base import LLMProvider | |
# Repair instructions for the LLM. Repair.run prepends this text (followed by
# a blank line) to the error message it hands to llm.repair(), so every
# character here is runtime prompt content -- change the wording deliberately.
| GUIDELINES = """ | |
| When repairing: | |
| 1. Keep query SELECT-only. | |
| 2. Explicitly qualify ambiguous columns with table names. | |
| 3. Match GROUP BY fields with aggregations. | |
| 4. Use known foreign keys for JOIN. | |
| 5. Add a reasonable LIMIT if missing. | |
| Return only the corrected SQL. | |
| """ | |
class Repair:
    """Pipeline stage that asks an LLM provider to repair a failing SQL query.

    The stage forwards the broken SQL, the error message (prefixed with the
    module-level ``GUIDELINES`` text) and a schema preview to
    ``llm.repair`` and wraps the answer in a ``StageResult``.
    """

    # Stage identifier recorded in every trace this stage emits.
    name = "repair"

    def __init__(self, llm: LLMProvider):
        """Store the LLM provider whose ``repair`` method performs the fix."""
        self.llm = llm

    def run(self, sql: str, error_msg: str, schema_preview: str) -> StageResult:
        """Request a corrected version of ``sql`` from the LLM provider.

        Args:
            sql: The SQL statement to repair.
            error_msg: The error text associated with ``sql``; it is sent to
                the provider with ``GUIDELINES`` and a blank line prepended.
            schema_preview: Schema description passed through to the provider
                unchanged.

        Returns:
            A ``StageResult`` with ``ok=True``, the provider's SQL under
            ``data["sql"]``, and a ``StageTrace`` holding the elapsed time in
            milliseconds, the token counts and cost reported by the provider,
            and the lengths of the old and new SQL strings.
        """
        started = time.perf_counter()

        # The repair guidelines travel inside the error message argument.
        guided_error = f"{GUIDELINES}\n\n{error_msg}"
        reply = self.llm.repair(
            sql=sql,
            error_msg=guided_error,
            schema_preview=schema_preview,
        )
        # The provider is expected to answer with exactly four values.
        fixed_sql, t_in, t_out, cost = reply

        # perf_counter() yields seconds; the trace stores milliseconds.
        elapsed_ms = (time.perf_counter() - started) * 1000
        length_notes = {"old_sql_len": len(sql), "new_sql_len": len(fixed_sql)}

        trace = StageTrace(
            stage=self.name,
            duration_ms=elapsed_ms,
            token_in=t_in,
            token_out=t_out,
            cost_usd=cost,
            notes=length_notes,
        )
        payload = {"sql": fixed_sql}
        return StageResult(ok=True, data=payload, trace=trace)