from __future__ import annotations import json import os import re from typing import Any, Dict, List, Tuple from adapters.llm.base import LLMProvider from openai import OpenAI def _resolve_api_config() -> tuple[str, str, str]: """Returns (api_key, base_url, model_id) according to env.""" override_model = os.getenv("LLM_MODEL_ID") proxy_key = os.getenv("PROXY_API_KEY") proxy_url = os.getenv("PROXY_BASE_URL") if proxy_key and proxy_url: model = ( override_model or os.getenv("PROXY_MODEL_ID") or os.getenv("OPENAI_MODEL_ID") or "gpt-4o-mini" ) return proxy_key, proxy_url, model openai_key = os.getenv("OPENAI_API_KEY") if not openai_key: raise RuntimeError( "No API credentials found. Set either PROXY_API_KEY/PROXY_BASE_URL or OPENAI_API_KEY." ) openai_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1") model = override_model or os.getenv("OPENAI_MODEL_ID") or "gpt-4o-mini" return openai_key, openai_url, model class OpenAIProvider(LLMProvider): """OpenAI LLM provider implementation. Goals for this implementation: - Keep prompts and behavior as close as possible to the current repo version. - Align method signatures + return shapes with the updated LLMProvider Protocol. - Provide a lightweight `used_tables` signal for observability/drift checks. """ PROVIDER_ID = "openai" def __init__(self) -> None: """Initialize OpenAI client with config from environment.""" api_key, base_url, model = _resolve_api_config() os.environ["OPENAI_API_KEY"] = api_key os.environ["OPENAI_BASE_URL"] = base_url self.client = OpenAI(timeout=120.0) self.model = model self._last_usage: dict[str, Any] = {} def get_last_usage(self) -> dict[str, Any]: """Return metadata of the last LLM call (tokens, cost, sql_length, kind).""" return dict(self._last_usage) def _create_chat_completion(self, **kwargs): """OpenAI SDK seam for stable unit testing.""" return self.client.chat.completions.create(**kwargs) # --------------------------------------------------------------------- # Table extraction helpers (best-effort; no heavy parsing). # --------------------------------------------------------------------- def _extract_schema_tables(self, schema_preview: str) -> List[str]: """Extract likely table names from the schema preview string.""" if not schema_preview: return [] tables: List[str] = [] for m in re.finditer( r"(?im)^\s*(?:-\s*)?table\s*[: ]\s*([A-Za-z_][A-Za-z0-9_]*)\b", schema_preview, ): tables.append(m.group(1)) for m in re.finditer( r"(?im)^\s*create\s+table\s+`?([A-Za-z_][A-Za-z0-9_]*)`?\b", schema_preview ): tables.append(m.group(1)) seen = set() uniq: List[str] = [] for t in tables: if t not in seen: uniq.append(t) seen.add(t) return uniq def _extract_tables_from_sql(self, sql: str) -> List[str]: """Very lightweight table extraction from FROM/JOIN clauses.""" if not sql: return [] pairs = re.findall( r"\bfrom\s+([A-Za-z_][A-Za-z0-9_]*)|\bjoin\s+([A-Za-z_][A-Za-z0-9_]*)", sql, flags=re.IGNORECASE, ) out: List[str] = [] for t1, t2 in pairs: if t1: out.append(t1) if t2: out.append(t2) seen = set() uniq: List[str] = [] for t in out: if t not in seen: uniq.append(t) seen.add(t) return uniq def _extract_used_tables_from_plan( self, plan_text: str, schema_preview: str ) -> List[str]: """Best-effort used table list from plan text by intersecting with schema table names.""" candidates = self._extract_schema_tables(schema_preview) if not candidates or not plan_text: return [] used: List[str] = [] for t in candidates: if re.search(rf"\b{re.escape(t)}\b", plan_text, flags=re.IGNORECASE): used.append(t) return used # --------------------------------------------------------------------- # Cost estimation # --------------------------------------------------------------------- def _estimate_cost(self, usage: Any) -> float: """Estimate cost based on token usage.""" if not usage: return 0.0 pricing = { "gpt-4": {"input": 0.03, "output": 0.06}, "gpt-4-turbo": {"input": 0.01, "output": 0.03}, "gpt-4o": {"input": 0.005, "output": 0.015}, "gpt-4o-mini": {"input": 0.00015, "output": 0.0006}, "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015}, } model_pricing = pricing.get(self.model, pricing["gpt-4o-mini"]) input_cost = (usage.prompt_tokens / 1000) * model_pricing["input"] output_cost = (usage.completion_tokens / 1000) * model_pricing["output"] return input_cost + output_cost # --------------------------------------------------------------------- # LLMProvider API # --------------------------------------------------------------------- def plan( self, *, user_query: str, schema_preview: str, constraints: List[str] | None = None, ) -> Tuple[str, List[str], int, int, float]: """Return (plan_text, used_tables, token_in, token_out, cost_usd).""" system_prompt = """You are a SQL query planning expert. Analyze the user's question and database schema to create a clear execution plan. Your plan should: 1. Identify the tables and columns needed 2. Determine any JOINs required 3. Specify filtering conditions (WHERE) 4. Identify aggregations (GROUP BY, COUNT, etc.) 5. Note sorting requirements (ORDER BY) 6. Check for special cases (DISTINCT, LIMIT, etc.) Be concise but thorough.""" user_prompt = f"""Question: {user_query} Database Schema: {schema_preview} Constraints: {constraints or []} Create a step-by-step plan to answer this question with SQL.""" completion = self._create_chat_completion( model=self.model, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], temperature=0.1, ) msg = completion.choices[0].message.content or "" usage = completion.usage plan_text = msg.strip() used_tables = self._extract_used_tables_from_plan(plan_text, schema_preview) if usage: prompt_tokens = usage.prompt_tokens completion_tokens = usage.completion_tokens cost = self._estimate_cost(usage) self._last_usage = { "kind": "plan", "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "cost_usd": cost, } return (plan_text, used_tables, prompt_tokens, completion_tokens, cost) self._last_usage = { "kind": "plan", "prompt_tokens": 0, "completion_tokens": 0, "cost_usd": 0.0, } return (plan_text, used_tables, 0, 0, 0.0) def generate_sql( self, *, user_query: str, schema_preview: str, plan_text: str, constraints: List[str] | None = None, clarify_answers: Dict[str, Any] | None = None, ) -> Tuple[str, str, int, int, float]: """Return (sql, rationale, token_in, token_out, cost_usd).""" system_prompt = """You are an expert SQL generator. CRITICAL RULES: 1. Write the SIMPLEST possible SQL that answers the question 2. NEVER use table prefixes unless absolutely necessary for disambiguation 3. NEVER add aliases (AS) unless specifically requested 4. NEVER add LIMIT unless the question asks for a specific number of results 5. NEVER use DISTINCT with COUNT(*) unless explicitly needed 6. Use lowercase for SQL keywords (select, from, where, etc.) 7. Do not add unnecessary parentheses or formatting 8. Match exact column and table names from the schema (case-sensitive) 9. NEVER return empty SQL. If unsure, return the simplest valid SQL that answers the question. 10. Use exact identifiers from `schema_preview` (case-insensitive match). 11. Do NOT invent or pluralize table names. E.g., use `Artist`, not `artists`. IMPORTANT: - For counting all rows: Use COUNT(*) not COUNT(column_name) - For ordering: Only add ORDER BY if the question asks for sorted results - Keep the SQL as close as possible to the minimal required syntax You must return ONLY valid JSON with exactly two keys: "sql" and "rationale". The SQL should be a single line without unnecessary spaces.""" user_prompt = f"""Based on this information, generate a simple SQL query: Question: {user_query} Database Schema: {schema_preview} Query Plan: {plan_text} Constraints: {constraints or []} Remember: Generate the SIMPLEST possible SQL. Avoid table prefixes, aliases, and unnecessary clauses. Example of what we want: Question: "How many singers are there?" Correct: {{"sql": "select count(*) from singer", "rationale": "Count all rows in singer table"}} Wrong: {{"sql": "SELECT COUNT(singer.singer_id) AS total_singers FROM singer", "rationale": "..."}} Now generate the SQL for the given question:""" if clarify_answers: user_prompt += f"\n\nAdditional context_engineering: {clarify_answers}" completion = self._create_chat_completion( model=self.model, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], temperature=0.1, max_tokens=500, ) text = completion.choices[0].message.content content = text.strip() if text else "" usage = completion.usage try: parsed = json.loads(content) except json.JSONDecodeError: start = content.find("{") end = content.rfind("}") if start != -1 and end != -1: try: parsed = json.loads(content[start : end + 1]) except Exception as e: raise ValueError(f"Invalid LLM JSON output: {content[:200]}") from e else: raise ValueError(f"Invalid LLM JSON output: {content[:200]}") sql = str(parsed.get("sql") or "").strip() rationale = str(parsed.get("rationale") or "") sql = self._simplify_sql(sql) if not sql: raise ValueError("LLM returned empty 'sql'") used_tables = self._extract_tables_from_sql(sql) sql_length = len(sql) if usage: prompt_tokens = usage.prompt_tokens completion_tokens = usage.completion_tokens cost = self._estimate_cost(usage) self._last_usage = { "kind": "generate", "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "cost_usd": cost, "sql_length": sql_length, "used_tables": used_tables, } return (sql, rationale, prompt_tokens, completion_tokens, cost) self._last_usage = { "kind": "generate", "prompt_tokens": 0, "completion_tokens": 0, "cost_usd": 0.0, "sql_length": sql_length, "used_tables": used_tables, } return (sql, rationale, 0, 0, 0.0) def _simplify_sql(self, sql: str) -> str: """Post-process SQL to remove common unnecessary additions.""" if not sql: return sql sql = sql.rstrip(";") if sql.lower().count(" from ") == 1 and " join " not in sql.lower(): match = re.search(r"\bfrom\s+(\w+)", sql, re.IGNORECASE) if match: table = match.group(1) sql = re.sub(rf"\b{table}\.(\w+)\b", r"\1", sql) sql = re.sub( r"count\s*\(\s*distinct\s+\*\s*\)", "count(*)", sql, flags=re.IGNORECASE, ) sql = re.sub( r"\s+limit\s+(100|1000|10000)\b", "", sql, flags=re.IGNORECASE, ) return sql def repair( self, *, sql: str, error_msg: str, schema_preview: str, ) -> Tuple[str, int, int, float]: """Return (patched_sql, token_in, token_out, cost_usd).""" system_prompt = """You are a SQL repair expert. Fix the given SQL query to resolve the error. IMPORTANT RULES: 1. Keep the fix as minimal as possible 2. Don't add complexity - keep it simple 3. Preserve the original intent of the query 4. Follow SQLite syntax rules 5. Don't add aliases or table prefixes unless necessary 6. Use exact identifiers from `schema_preview` (case-insensitive match). 7. Do NOT invent or pluralize table names. E.g., use `Artist`, not `artists`. Return ONLY the corrected SQL query, nothing else.""" user_prompt = f"""Fix this SQL query: Original SQL: {sql} Error: {error_msg} Database Schema: {schema_preview} Return the corrected SQL (keep it simple):""" completion = self._create_chat_completion( model=self.model, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], temperature=0.1, ) text = completion.choices[0].message.content fixed_sql = text.strip() if text else "" if fixed_sql.startswith("```sql"): fixed_sql = fixed_sql[6:] if fixed_sql.startswith("```"): fixed_sql = fixed_sql[3:] if fixed_sql.endswith("```"): fixed_sql = fixed_sql[:-3] fixed_sql = fixed_sql.strip() fixed_sql = self._simplify_sql(fixed_sql) usage = completion.usage if usage: prompt_tokens = usage.prompt_tokens completion_tokens = usage.completion_tokens cost = self._estimate_cost(usage) self._last_usage = { "kind": "repair", "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "cost_usd": cost, "sql_length": len(fixed_sql), } return (fixed_sql, prompt_tokens, completion_tokens, cost) self._last_usage = { "kind": "repair", "prompt_tokens": 0, "completion_tokens": 0, "cost_usd": 0.0, "sql_length": len(fixed_sql), } return (fixed_sql, 0, 0, 0.0)