Spaces:
Running
Running
File size: 4,175 Bytes
570f7bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
from __future__ import annotations
import os
from typing import Tuple, Dict, Any, List
import json
from adapters.llm.base import LLMProvider
from openai import OpenAI
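# NOTE: Configuration is read from the environment; OpenAIProvider() takes no
# constructor arguments (do not pass base_url to it).
# - OPENAI_API_KEY (required)
# - OPENAI_BASE_URL (optional; defaults to OpenAI's public API, https://api.openai.com/v1)
# - OPENAI_MODEL_ID (optional; defaults to "gpt-4o-mini")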
class OpenAIProvider(LLMProvider):
    """LLM provider backed by the OpenAI Chat Completions API.

    Configuration is read from the environment when the provider is built:

    - ``OPENAI_API_KEY`` (required; ``KeyError`` if missing)
    - ``OPENAI_BASE_URL`` (optional; defaults to ``https://api.openai.com/v1``)
    - ``OPENAI_MODEL_ID`` (optional; defaults to ``gpt-4o-mini``)

    Every public method issues one chat completion with ``temperature=0`` and
    returns prompt/completion token counts plus a rough cost estimate next to
    its result. Those three values are ``None`` when the API response carries
    no ``usage`` data.
    """

    provider_id = "openai"

    # Flat per-token rate used by ``_estimate_cost``. A rough placeholder, not
    # official pricing (currency presumably USD -- confirm); override on a
    # subclass or instance to refine the estimate.
    cost_per_token: float = 0.000001

    def __init__(self) -> None:
        self.client = OpenAI(
            api_key=os.environ["OPENAI_API_KEY"],
            base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"),
        )
        self.model = os.getenv("OPENAI_MODEL_ID", "gpt-4o-mini")

    def plan(self, *, user_query, schema_preview):
        """Ask the model for a natural-language SQL query plan.

        Args:
            user_query: The user's natural-language question.
            schema_preview: Schema text embedded verbatim in the prompt.

        Returns:
            ``(plan_text, prompt_tokens, completion_tokens, cost)``.
            ``plan_text`` is the raw message content as returned by the API;
            the last three entries are ``None`` when usage data is absent.
        """
        completion = self._chat(
            "You create SQL query plans.",
            f"Query: {user_query}\nSchema:\n{schema_preview}",
        )
        msg = completion.choices[0].message.content
        return (msg, *self._usage_stats(completion.usage))

    def generate_sql(self, *, user_query, schema_preview, plan_text, clarify_answers=None):
        """Generate a SQL statement plus rationale as strict JSON.

        Args:
            user_query: The user's natural-language question.
            schema_preview: Schema text embedded verbatim in the prompt.
            plan_text: Plan text (e.g. from ``plan``) embedded in the prompt.
            clarify_answers: Optional clarifications; interpolated as-is
                (``None`` is rendered literally as "None").

        Returns:
            ``(sql, rationale, prompt_tokens, completion_tokens, cost)``; the
            last three are ``None`` when usage data is absent.

        Raises:
            ValueError: if the reply contains no JSON object, or its ``"sql"``
                value is missing, empty, or not a string.
        """
        # Built from literal pieces so the exact prompt text does not depend
        # on source indentation. Only the input lines are interpolated.
        prompt = (
            "\n"
            "You are a precise SQL generator.\n"
            'Return ONLY valid JSON with two keys: "sql" and "rationale".\n'
            "Do not include any markdown, backticks, or extra text.\n"
            "Example:\n"
            "{\n"
            '"sql": "SELECT * FROM singer;",\n'
            '"rationale": "The user requested to list all singers."\n'
            "}\n"
            "Now generate JSON for this input:\n"
            f"User query: {user_query}\n"
            "Schema preview:\n"
            f"{schema_preview}\n"
            f"Plan: {plan_text}\n"
            f"Clarifications: {clarify_answers}\n"
        )
        completion = self._chat("You convert natural language to SQL.", prompt)
        # Guard against a None content so a missing reply surfaces as the
        # ValueError documented above rather than an AttributeError.
        content = (completion.choices[0].message.content or "").strip()
        # Token usage is needed by callers alongside the SQL.
        t_in, t_out, cost = self._usage_stats(completion.usage)

        parsed = self._parse_json_object(content)
        raw_sql = parsed.get("sql") or ""
        if not isinstance(raw_sql, str):
            raise ValueError("LLM returned non-string 'sql'")
        sql = raw_sql.strip()
        if not sql:
            raise ValueError("LLM returned empty 'sql'")
        rationale = parsed.get("rationale") or ""
        # IMPORTANT: callers expect exactly this 5-tuple.
        return sql, rationale, t_in, t_out, cost

    def repair(self, *, sql, error_msg, schema_preview):
        """Ask the model to fix ``sql`` given a database error message.

        Args:
            sql: The failing SQL statement.
            error_msg: The error text produced when running ``sql``.
            schema_preview: Schema text embedded verbatim in the prompt.

        Returns:
            ``(repaired_text, prompt_tokens, completion_tokens, cost)``.
            ``repaired_text`` is the raw message content (not post-processed);
            the last three are ``None`` when usage data is absent.
        """
        completion = self._chat(
            "You fix SQL queries keeping them SELECT-only.",
            f"SQL:\n{sql}\nError:\n{error_msg}\nSchema:\n{schema_preview}",
        )
        msg = completion.choices[0].message.content
        return (msg, *self._usage_stats(completion.usage))

    def _chat(self, system: str, user: str) -> Any:
        """Send one system+user chat completion at temperature 0."""
        return self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=0,
        )

    def _usage_stats(self, usage: Any) -> tuple[int | None, int | None, float | None]:
        """Return ``(prompt_tokens, completion_tokens, cost)``, or Nones if no usage."""
        if usage is None:
            return None, None, None
        return usage.prompt_tokens, usage.completion_tokens, self._estimate_cost(usage)

    @staticmethod
    def _parse_json_object(content: str) -> dict:
        """Parse ``content`` as a JSON object, tolerating surrounding text.

        If the whole string is not valid JSON, retry on the substring from the
        first ``{`` to the last ``}``.

        Raises:
            ValueError: if no JSON object can be extracted.
        """
        error = f"Invalid LLM JSON output: {content[:200]}"
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            start, end = content.find("{"), content.rfind("}")
            if start == -1 or end < start:
                raise ValueError(error) from None
            try:
                parsed = json.loads(content[start:end + 1])
            except json.JSONDecodeError as exc:
                raise ValueError(error) from exc
        # Valid JSON that is not an object (list, string, number) is unusable.
        if not isinstance(parsed, dict):
            raise ValueError(error)
        return parsed

    def _estimate_cost(self, usage):
        """Return a rough cost: total tokens times ``cost_per_token``.

        Missing (None) token counts are treated as 0.
        """
        total = (usage.prompt_tokens or 0) + (usage.completion_tokens or 0)
        return total * self.cost_per_token
|