Spaces:
Running
Running
File size: 5,529 Bytes
570f7bd 260d5c1 570f7bd 260d5c1 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd 260d5c1 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd 260d5c1 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
from __future__ import annotations
import os
import json
from adapters.llm.base import LLMProvider
from openai import OpenAI
# NOTE:
# - Prefer proxy if PROXY_API_KEY and PROXY_BASE_URL are set.
# - Otherwise, fallback to OPENAI_API_KEY (+ OPENAI_BASE_URL defaulting to https://api.openai.com/v1).
# - Do NOT pass base_url/api_key in the constructor; rely on env vars.
def _resolve_api_config() -> tuple[str, str, str]:
"""
Returns (api_key, base_url, model_id) according to env.
Resolution order:
1) Proxy: PROXY_API_KEY + PROXY_BASE_URL [+ PROXY_MODEL_ID]
2) Direct: OPENAI_API_KEY [+ OPENAI_BASE_URL] [+ OPENAI_MODEL_ID]
Additionally, LLM_MODEL_ID (if set) overrides model choice.
"""
# Optional global override for model id
override_model = os.getenv("LLM_MODEL_ID")
proxy_key = os.getenv("PROXY_API_KEY")
proxy_url = os.getenv("PROXY_BASE_URL")
if proxy_key and proxy_url:
model = (
override_model
or os.getenv("PROXY_MODEL_ID")
or os.getenv("OPENAI_MODEL_ID")
or "gpt-4o-mini"
)
return proxy_key, proxy_url, model
openai_key = os.getenv("OPENAI_API_KEY")
if not openai_key:
raise RuntimeError(
"No API credentials found. Set either PROXY_API_KEY/PROXY_BASE_URL or OPENAI_API_KEY."
)
openai_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model = override_model or os.getenv("OPENAI_MODEL_ID") or "gpt-4o-mini"
return openai_key, openai_url, model
class OpenAIProvider(LLMProvider):
    """LLM provider backed by the OpenAI chat-completions API.

    Credentials and base URL are resolved from the environment (proxy first,
    then direct OpenAI) and exported back into ``os.environ`` so the
    ``OpenAI`` client can be constructed with no arguments.
    """

    provider_id = "openai"

    # Flat per-token rate used for rough cost estimation (USD per token).
    _COST_PER_TOKEN = 0.000001

    def __init__(self) -> None:
        # Resolve config and export to env so we don't pass secrets into the
        # client constructor directly (see module-level NOTE).
        api_key, base_url, model = _resolve_api_config()
        os.environ["OPENAI_API_KEY"] = api_key
        os.environ["OPENAI_BASE_URL"] = base_url
        # Create client using env only.
        self.client = OpenAI()
        self.model = model

    def _chat(self, system: str, user: str):
        """Run one deterministic (temperature=0) chat completion.

        Returns ``(message_text, prompt_tokens, completion_tokens, cost)``.
        The token/cost fields are ``None`` when the API response carries no
        usage data (some proxies omit it) — the original ``plan``/``repair``
        dereferenced ``usage`` unconditionally and would crash in that case.
        """
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=0,
        )
        msg = completion.choices[0].message.content
        usage = completion.usage
        if usage is None:
            return msg, None, None, None
        return (
            msg,
            usage.prompt_tokens,
            usage.completion_tokens,
            self._estimate_cost(usage),
        )

    @staticmethod
    def _parse_llm_json(content: str) -> dict:
        """Parse model output as JSON, salvaging the outermost ``{...}`` span.

        Raises:
            ValueError: if no parseable JSON object can be recovered.
        """
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            # The model sometimes wraps JSON in extra prose; try the
            # outermost brace-delimited span before giving up.
            start = content.find("{")
            end = content.rfind("}")
            if start != -1 and end != -1:
                try:
                    return json.loads(content[start : end + 1])
                except json.JSONDecodeError:
                    pass
            raise ValueError(f"Invalid LLM JSON output: {content[:200]}")

    def plan(self, *, user_query, schema_preview):
        """Ask the model for a SQL query plan.

        Returns ``(plan_text, prompt_tokens, completion_tokens, cost)``;
        token/cost fields are ``None`` when usage data is unavailable.
        """
        return self._chat(
            "You create SQL query plans.",
            f"Query: {user_query}\nSchema:\n{schema_preview}",
        )

    def generate_sql(
        self, *, user_query, schema_preview, plan_text, clarify_answers=None
    ):
        """Generate SQL from natural language via a strict-JSON prompt.

        Returns ``(sql, rationale, prompt_tokens, completion_tokens, cost)``.

        Raises:
            ValueError: if the model output is not parseable JSON, or the
                parsed ``"sql"`` field is empty.
        """
        prompt = f"""
You are a precise SQL generator.
Return ONLY valid JSON with two keys: "sql" and "rationale".
Do not include any markdown, backticks, or extra text.
Example:
{{
"sql": "SELECT * FROM singer;",
"rationale": "The user requested to list all singers."
}}
Now generate JSON for this input:
User query: {user_query}
Schema preview:
{schema_preview}
Plan: {plan_text}
Clarifications: {clarify_answers}
"""
        content, t_in, t_out, cost = self._chat(
            "You convert natural language to SQL.", prompt
        )
        parsed = self._parse_llm_json(content.strip())
        sql = (parsed.get("sql") or "").strip()
        rationale = parsed.get("rationale") or ""
        if not sql:
            raise ValueError("LLM returned empty 'sql'")
        return sql, rationale, t_in, t_out, cost

    def repair(self, *, sql, error_msg, schema_preview):
        """Ask the model to fix a failing SQL statement (SELECT-only).

        Returns ``(fixed_sql_text, prompt_tokens, completion_tokens, cost)``;
        token/cost fields are ``None`` when usage data is unavailable.
        """
        return self._chat(
            "You fix SQL queries keeping them SELECT-only.",
            f"SQL:\n{sql}\nError:\n{error_msg}\nSchema:\n{schema_preview}",
        )

    def _estimate_cost(self, usage):
        """Rough USD cost: flat per-token rate over prompt + completion tokens.

        Returns ``None`` when *usage* is ``None`` (defensive; see ``_chat``).
        """
        if usage is None:
            return None
        total = usage.prompt_tokens + usage.completion_tokens
        return total * self._COST_PER_TOKEN
|