File size: 4,175 Bytes
570f7bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from __future__ import annotations
import os
from typing import Tuple, Dict, Any, List
import json
from adapters.llm.base import LLMProvider
from openai import OpenAI

# NOTE: All configuration is read from the environment. The constructor passes
# the resolved base URL explicitly so the effective endpoint is visible here:
#  - OPENAI_API_KEY   (required)
#  - OPENAI_BASE_URL  (optional; defaults to the public OpenAI endpoint)
#  - OPENAI_MODEL_ID  (optional; e.g., "gpt-4o-mini")



class OpenAIProvider(LLMProvider):
    """LLM provider backed by the OpenAI Chat Completions API.

    Configuration is read from the environment:
      - OPENAI_API_KEY   (required)
      - OPENAI_BASE_URL  (optional; defaults to the public OpenAI endpoint)
      - OPENAI_MODEL_ID  (optional; defaults to "gpt-4o-mini")
    """

    provider_id = "openai"

    def __init__(self) -> None:
        self.client = OpenAI(
            api_key=os.environ["OPENAI_API_KEY"],
            base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"),
        )
        self.model = os.getenv("OPENAI_MODEL_ID", "gpt-4o-mini")

    def _chat(self, system: str, user: str):
        """Run one deterministic (temperature=0) chat completion.

        Returns (content, t_in, t_out, cost). Token counts and cost are
        None when the backend omits usage data — some OpenAI-compatible
        servers do, and the original plan/repair paths crashed on that.
        """
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=0,
        )
        content = completion.choices[0].message.content
        usage = completion.usage
        t_in = usage.prompt_tokens if usage else None
        t_out = usage.completion_tokens if usage else None
        cost = self._estimate_cost(usage) if usage else None
        return content, t_in, t_out, cost

    @staticmethod
    def _parse_json_object(content: str) -> Dict[str, Any]:
        """Parse a JSON object from model output, tolerating surrounding text.

        First tries the whole string; on failure, retries the outermost
        {...} substring (models sometimes wrap JSON in extra prose).
        Raises ValueError when no valid JSON object can be recovered.
        """
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            start = content.find("{")
            end = content.rfind("}")
            if start != -1 and end > start:
                try:
                    return json.loads(content[start:end + 1])
                except json.JSONDecodeError:
                    pass
            raise ValueError(f"Invalid LLM JSON output: {content[:200]}")

    def plan(self, *, user_query, schema_preview):
        """Produce a free-text SQL query plan.

        Returns (plan_text, t_in, t_out, cost); the last three are None
        when the backend reports no usage.
        """
        return self._chat(
            "You create SQL query plans.",
            f"Query: {user_query}\nSchema:\n{schema_preview}",
        )

    def generate_sql(self, *, user_query, schema_preview, plan_text, clarify_answers=None):
        """Generate SQL for *user_query* against the previewed schema.

        Returns the 5-tuple (sql, rationale, t_in, t_out, cost).
        Raises ValueError if the model output is not valid JSON or the
        recovered "sql" field is empty.
        """
        # Prompt text is part of the provider's observable behavior —
        # kept verbatim (including its indentation inside the f-string).
        prompt = f"""
        You are a precise SQL generator. 
        Return ONLY valid JSON with two keys: "sql" and "rationale".
        Do not include any markdown, backticks, or extra text.

        Example:
        {{
          "sql": "SELECT * FROM singer;",
          "rationale": "The user requested to list all singers."
        }}

        Now generate JSON for this input:

        User query: {user_query}
        Schema preview:
        {schema_preview}
        Plan: {plan_text}
        Clarifications: {clarify_answers}
        """
        content, t_in, t_out, cost = self._chat(
            "You convert natural language to SQL.", prompt
        )
        parsed = self._parse_json_object(content.strip())

        sql = (parsed.get("sql") or "").strip()
        rationale = parsed.get("rationale") or ""

        if not sql:
            raise ValueError("LLM returned empty 'sql'")

        return sql, rationale, t_in, t_out, cost

    def repair(self, *, sql, error_msg, schema_preview):
        """Ask the model to fix a failing SQL query (SELECT-only policy).

        Returns (repaired_text, t_in, t_out, cost); the last three are
        None when the backend reports no usage.
        """
        return self._chat(
            "You fix SQL queries keeping them SELECT-only.",
            f"SQL:\n{sql}\nError:\n{error_msg}\nSchema:\n{schema_preview}",
        )

    def _estimate_cost(self, usage) -> float:
        """Rough flat-rate estimate (USD 1 per 1M tokens).

        Placeholder pricing — refine with the official per-model token
        rates when accurate accounting is needed.
        """
        total = usage.prompt_tokens + usage.completion_tokens
        return total * 0.000001