from __future__ import annotations

import json
import os

from openai import OpenAI

from adapters.llm.base import LLMProvider

# NOTE:
# - Prefer the proxy if PROXY_API_KEY and PROXY_BASE_URL are set.
# - Otherwise, fall back to OPENAI_API_KEY (+ OPENAI_BASE_URL, which defaults
#   to https://api.openai.com/v1).
# - Do NOT pass base_url/api_key to the constructor; rely on env vars.


def _resolve_api_config() -> tuple[str, str, str]:
    """
    Returns (api_key, base_url, model_id) according to env.
    Resolution order:
      1) Proxy: PROXY_API_KEY + PROXY_BASE_URL [+ PROXY_MODEL_ID]
      2) Direct: OPENAI_API_KEY [+ OPENAI_BASE_URL] [+ OPENAI_MODEL_ID]
    Additionally, LLM_MODEL_ID (if set) overrides model choice.
    """
    # Optional global override for model id
    override_model = os.getenv("LLM_MODEL_ID")

    proxy_key = os.getenv("PROXY_API_KEY")
    proxy_url = os.getenv("PROXY_BASE_URL")
    if proxy_key and proxy_url:
        model = (
            override_model
            or os.getenv("PROXY_MODEL_ID")
            or os.getenv("OPENAI_MODEL_ID")
            or "gpt-4o-mini"
        )
        return proxy_key, proxy_url, model

    openai_key = os.getenv("OPENAI_API_KEY")
    if not openai_key:
        raise RuntimeError(
            "No API credentials found. Set either PROXY_API_KEY/PROXY_BASE_URL or OPENAI_API_KEY."
        )
    openai_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
    model = override_model or os.getenv("OPENAI_MODEL_ID") or "gpt-4o-mini"
    return openai_key, openai_url, model
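
# Worked example of the resolution order above (all values hypothetical):
# with PROXY_API_KEY="pk-example" and PROXY_BASE_URL="https://proxy.example/v1"
# set, _resolve_api_config() returns
# ("pk-example", "https://proxy.example/v1", "gpt-4o-mini"), with LLM_MODEL_ID
# or PROXY_MODEL_ID taking precedence for the model if present. With only
# OPENAI_API_KEY="sk-example" set, it falls back to
# ("sk-example", "https://api.openai.com/v1", "gpt-4o-mini").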


class OpenAIProvider(LLMProvider):
    provider_id = "openai"

    def __init__(self) -> None:
        # Resolve credentials once and export them to the environment so the
        # client can be constructed without explicit arguments (see NOTE above).
        # Note: this mutates os.environ for the whole process.
        api_key, base_url, model = _resolve_api_config()
        os.environ["OPENAI_API_KEY"] = api_key
        os.environ["OPENAI_BASE_URL"] = base_url
        self.client = OpenAI()  # reads OPENAI_API_KEY / OPENAI_BASE_URL from env
        self.model = model

    def plan(self, *, user_query, schema_preview):
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You create SQL query plans."},
                {
                    "role": "user",
                    "content": f"Query: {user_query}\nSchema:\n{schema_preview}",
                },
            ],
            temperature=0,
        )
        msg = completion.choices[0].message.content
        usage = completion.usage
        # usage is Optional in the OpenAI SDK; guard it as generate_sql does.
        t_in = usage.prompt_tokens if usage else None
        t_out = usage.completion_tokens if usage else None
        return msg, t_in, t_out, self._estimate_cost(usage)

    def generate_sql(
        self, *, user_query, schema_preview, plan_text, clarify_answers=None
    ):
        prompt = f"""
        You are a precise SQL generator.
        Return ONLY valid JSON with two keys: "sql" and "rationale".
        Do not include any markdown, backticks, or extra text.

        Example:
        {{
          "sql": "SELECT * FROM singer;",
          "rationale": "The user requested to list all singers."
        }}

        Now generate JSON for this input:

        User query: {user_query}
        Schema preview:
        {schema_preview}
        Plan: {plan_text}
        Clarifications: {clarify_answers}
        """
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You convert natural language to SQL."},
                {"role": "user", "content": prompt},
            ],
            temperature=0,
        )
        # message.content is Optional in the SDK; normalize to a string.
        content = (completion.choices[0].message.content or "").strip()
        usage = completion.usage
        t_in = usage.prompt_tokens if usage else None
        t_out = usage.completion_tokens if usage else None
        cost = self._estimate_cost(usage)  # handles usage being None

        # The model should return bare JSON, but may wrap it in extra text;
        # fall back to extracting the outermost {...} span before giving up.
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            start = content.find("{")
            end = content.rfind("}")
            if start != -1 and end != -1:
                try:
                    parsed = json.loads(content[start : end + 1])
                except json.JSONDecodeError as exc:
                    raise ValueError(
                        f"Invalid LLM JSON output: {content[:200]}"
                    ) from exc
            else:
                raise ValueError(f"Invalid LLM JSON output: {content[:200]}")

        sql = (parsed.get("sql") or "").strip()
        rationale = parsed.get("rationale") or ""
        if not sql:
            raise ValueError("LLM returned empty 'sql'")

        return sql, rationale, t_in, t_out, cost

    def repair(self, *, sql, error_msg, schema_preview):
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "system",
                    "content": "You fix SQL queries keeping them SELECT-only.",
                },
                {
                    "role": "user",
                    "content": f"SQL:\n{sql}\nError:\n{error_msg}\nSchema:\n{schema_preview}",
                },
            ],
            temperature=0,
        )
        msg = completion.choices[0].message.content
        usage = completion.usage
        # Same Optional-usage guard as in plan/generate_sql.
        t_in = usage.prompt_tokens if usage else None
        t_out = usage.completion_tokens if usage else None
        return msg, t_in, t_out, self._estimate_cost(usage)

    def _estimate_cost(self, usage):
        # Rough flat-rate estimate (~$1 per 1M tokens across input + output);
        # replace with per-model pricing if accurate accounting is needed.
        if usage is None:
            return None
        total = usage.prompt_tokens + usage.completion_tokens
        return total * 0.000001
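

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the adapter contract: it assumes
    # real credentials are already in the environment (PROXY_* or
    # OPENAI_API_KEY) and uses a hypothetical single-table schema.
    provider = OpenAIProvider()
    schema = "CREATE TABLE singer (id INTEGER PRIMARY KEY, name TEXT);"
    plan_text, _, _, _ = provider.plan(
        user_query="List all singers", schema_preview=schema
    )
    sql, rationale, *_ = provider.generate_sql(
        user_query="List all singers",
        schema_preview=schema,
        plan_text=plan_text,
    )
    print(sql)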