feat(llm): proxy-first fallback, env-only OpenAI client; docs: update .env.example
260d5c1
from __future__ import annotations

import os
import json

from adapters.llm.base import LLMProvider
from openai import OpenAI
# NOTE:
# - Prefer the proxy if PROXY_API_KEY and PROXY_BASE_URL are set.
# - Otherwise, fall back to OPENAI_API_KEY (+ OPENAI_BASE_URL, defaulting to
#   https://api.openai.com/v1).
# - Do NOT pass base_url/api_key in the constructor; rely on env vars.
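#
# Illustrative .env layout (placeholder values; the commit also updates
# .env.example in the repo — these names simply mirror the resolution
# logic below):
#
#   PROXY_API_KEY=pk-...
#   PROXY_BASE_URL=https://my-proxy.example.com/v1
#   PROXY_MODEL_ID=gpt-4o-mini
#
#   # or, for direct OpenAI access:
#   OPENAI_API_KEY=sk-...
#   OPENAI_BASE_URL=https://api.openai.com/v1
#   OPENAI_MODEL_ID=gpt-4o-mini
#
#   # optional global override of the model id:
#   LLM_MODEL_ID=gpt-4o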

def _resolve_api_config() -> tuple[str, str, str]:
    """
    Returns (api_key, base_url, model_id) according to env.

    Resolution order:
      1) Proxy:  PROXY_API_KEY + PROXY_BASE_URL [+ PROXY_MODEL_ID]
      2) Direct: OPENAI_API_KEY [+ OPENAI_BASE_URL] [+ OPENAI_MODEL_ID]

    Additionally, LLM_MODEL_ID (if set) overrides the model choice.
    """
    # Optional global override for the model id
    override_model = os.getenv("LLM_MODEL_ID")

    proxy_key = os.getenv("PROXY_API_KEY")
    proxy_url = os.getenv("PROXY_BASE_URL")
    if proxy_key and proxy_url:
        model = (
            override_model
            or os.getenv("PROXY_MODEL_ID")
            or os.getenv("OPENAI_MODEL_ID")
            or "gpt-4o-mini"
        )
        return proxy_key, proxy_url, model

    openai_key = os.getenv("OPENAI_API_KEY")
    if not openai_key:
        raise RuntimeError(
            "No API credentials found. Set either PROXY_API_KEY/PROXY_BASE_URL or OPENAI_API_KEY."
        )
    openai_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
    model = override_model or os.getenv("OPENAI_MODEL_ID") or "gpt-4o-mini"
    return openai_key, openai_url, model
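
# A quick sketch of the resolution order above (assumed env values, shown as
# comments so nothing runs at import time):
#
#   PROXY_API_KEY=pk-123, PROXY_BASE_URL=https://proxy.example.com/v1, LLM_MODEL_ID=gpt-4o
#     -> _resolve_api_config() == ("pk-123", "https://proxy.example.com/v1", "gpt-4o")
#
#   only OPENAI_API_KEY=sk-abc set
#     -> _resolve_api_config() == ("sk-abc", "https://api.openai.com/v1", "gpt-4o-mini")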

class OpenAIProvider(LLMProvider):
    provider_id = "openai"

    def __init__(self) -> None:
        # Resolve and export to env so we don't pass credentials into the constructor.
        api_key, base_url, model = _resolve_api_config()
        os.environ["OPENAI_API_KEY"] = api_key
        os.environ["OPENAI_BASE_URL"] = base_url
        # Create the client from env vars only.
        self.client = OpenAI()
        self.model = model
    def plan(self, *, user_query, schema_preview):
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You create SQL query plans."},
                {
                    "role": "user",
                    "content": f"Query: {user_query}\nSchema:\n{schema_preview}",
                },
            ],
            temperature=0,
        )
        msg = completion.choices[0].message.content
        usage = completion.usage  # may be None, so guard like generate_sql does
        t_in = usage.prompt_tokens if usage else None
        t_out = usage.completion_tokens if usage else None
        cost = self._estimate_cost(usage) if usage else None
        return msg, t_in, t_out, cost
    def generate_sql(
        self, *, user_query, schema_preview, plan_text, clarify_answers=None
    ):
        prompt = f"""
You are a precise SQL generator.
Return ONLY valid JSON with two keys: "sql" and "rationale".
Do not include any markdown, backticks, or extra text.

Example:
{{
  "sql": "SELECT * FROM singer;",
  "rationale": "The user requested to list all singers."
}}

Now generate JSON for this input:
User query: {user_query}
Schema preview:
{schema_preview}
Plan: {plan_text}
Clarifications: {clarify_answers}
"""
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You convert natural language to SQL."},
                {"role": "user", "content": prompt},
            ],
            temperature=0,
        )
        content = completion.choices[0].message.content.strip()
        usage = completion.usage
        t_in = usage.prompt_tokens if usage else None
        t_out = usage.completion_tokens if usage else None
        cost = self._estimate_cost(usage) if usage else None

        # Parse the JSON reply; if the model wrapped it in extra text, fall
        # back to extracting the outermost {...} span.
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            start = content.find("{")
            end = content.rfind("}")
            if start != -1 and end != -1:
                try:
                    parsed = json.loads(content[start : end + 1])
                except Exception:
                    raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
            else:
                raise ValueError(f"Invalid LLM JSON output: {content[:200]}")

        sql = (parsed.get("sql") or "").strip()
        rationale = parsed.get("rationale") or ""
        if not sql:
            raise ValueError("LLM returned empty 'sql'")
        return sql, rationale, t_in, t_out, cost
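
    # Example of the extraction fallback above (illustrative): a reply such as
    #   'Here you go: {"sql": "SELECT 1;", "rationale": "trivial"} Hope that helps!'
    # fails json.loads() as-is, but slicing from the first "{" to the last "}"
    # recovers the JSON object.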
    def repair(self, *, sql, error_msg, schema_preview):
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "system",
                    "content": "You fix SQL queries keeping them SELECT-only.",
                },
                {
                    "role": "user",
                    "content": f"SQL:\n{sql}\nError:\n{error_msg}\nSchema:\n{schema_preview}",
                },
            ],
            temperature=0,
        )
        msg = completion.choices[0].message.content
        usage = completion.usage  # may be None, so guard like generate_sql does
        t_in = usage.prompt_tokens if usage else None
        t_out = usage.completion_tokens if usage else None
        cost = self._estimate_cost(usage) if usage else None
        return msg, t_in, t_out, cost
    def _estimate_cost(self, usage):
        # Rough placeholder: a flat $0.000001 per token, not a per-model rate.
        total = usage.prompt_tokens + usage.completion_tokens
        return total * 0.000001
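

# Minimal smoke test (illustrative; assumes valid credentials in the
# environment and a reachable endpoint — the query and schema below are
# made-up examples, not fixtures from this repo):
if __name__ == "__main__":
    provider = OpenAIProvider()
    sql, rationale, t_in, t_out, cost = provider.generate_sql(
        user_query="List all singers",
        schema_preview="CREATE TABLE singer (id INTEGER, name TEXT);",
        plan_text="Select every row from the singer table.",
    )
    print(sql)  # e.g. SELECT * FROM singer;
    print(f"tokens in/out: {t_in}/{t_out}, est. cost: {cost}")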