sql-agent-openenv / backend /gepa /optimizer.py
ar9avg's picture
fix
92cc088
"""
GEPA (Goal-directed Evolutionary Prompt Adaptation) optimizer.
Ported from gepa.ts. Key steps:
1. Reflection: LLM analyzes failure history, outputs diagnosis
2. Mutation: LLM rewrites system prompt based on diagnosis
3. Scoring: Run 3 golden queries with new prompt, compute score
4. Pareto front: Keep top 3 prompts ranked by score (diversity is not currently factored in)
State is persisted to data/gepa_prompt.json.
"""
from __future__ import annotations

import json
import logging
import os
import time
from pathlib import Path
from typing import Optional

from openai import AsyncOpenAI
from pydantic import BaseModel
# Where optimizer state is persisted. DATA_DIR overrides the default, which is
# the "data" directory two levels above this file.
_DATA_DIR = Path(os.getenv("DATA_DIR", Path(__file__).parent.parent / "data"))
GEPA_PATH = _DATA_DIR / "gepa_prompt.json"

# OpenAI-compatible endpoint, model name and credential used for LLM calls.
_API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
_MODEL = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
_HF_TOKEN = os.getenv("HF_TOKEN")  # no default — must be set explicitly

# How many queries between each GEPA optimization cycle.
# Override with the GEPA_OPTIMIZE_EVERY environment variable.
GEPA_OPTIMIZE_EVERY: int = int(os.getenv("GEPA_OPTIMIZE_EVERY", "4"))
# Generation-0 system prompt: the optimizer's initial candidate, and the value
# returned by GEPAOptimizer.get_current_prompt() when its front is empty.
SEED_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
Rules:
- Output ONLY the SQL query, nothing else
- No markdown, no code fences, no explanation
- Use SQLite syntax"""
# ─── Models ──────────────────────────────────────────────────────
class QueryResult(BaseModel):
    """Record of one question handled by the SQL agent, kept in GEPA history.

    Entries with ``attempts > 1`` or ``success`` false are the ones the
    optimizer treats as failures when building its reflection prompt.
    """

    # Natural-language question; quoted verbatim in the failure summary.
    question: str
    # SQL text reported as "Final SQL" in the failure summary.
    final_sql: str
    # Number of attempts recorded for this question.
    attempts: int
    # Whether the question was ultimately answered successfully.
    success: bool
    # Error messages collected along the way; listed one per line for the LLM.
    errors: list[str]
    # Time of the record — presumably epoch seconds; confirm against the caller.
    timestamp: float
class Candidate(BaseModel):
    """One system prompt held in the optimizer's front, with its metrics."""

    # The system prompt text itself.
    system_prompt: str
    # Ranking key for the front: 0.5 for the seed, otherwise the golden-query
    # benchmark score computed when the candidate was created.
    score: float
    # 3.0 for the seed; for generated candidates, the history's mean attempts
    # minus 0.5, floored at 1.0.
    avg_attempts: float
    # 0.5 for the seed; generated candidates reuse their benchmark score here.
    success_rate: float
    # 0 for the seed; each optimization cycle adds one to the current maximum.
    generation: int
    # Reflection text that led to this candidate (empty for the seed).
    feedback: list[str]
# ─── LLM Helper ──────────────────────────────────────────────────
def _make_client() -> AsyncOpenAI:
    """Build a fresh async client pointed at the configured endpoint.

    The API key comes from ``_HF_TOKEN`` and the base URL from
    ``_API_BASE_URL``; a new client object is returned on every call.
    """
    connection_settings = {
        "api_key": _HF_TOKEN,
        "base_url": _API_BASE_URL,
    }
    return AsyncOpenAI(**connection_settings)
async def _complete(system: str, user: str) -> str:
    """Send a two-message chat (system + user) and return the reply text.

    Uses the module's configured model at temperature 0.7. An absent or empty
    reply is returned as the empty string.
    """
    conversation = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    llm = _make_client()
    response = await llm.chat.completions.create(
        model=_MODEL,
        messages=conversation,
        temperature=0.7,
    )
    reply = response.choices[0].message.content
    return reply if reply else ""
# ─── Golden Queries for Scoring ──────────────────────────────────
# Benchmark questions used to score candidate prompts. Each entry carries an
# id, the natural-language question, and the row count a result must reach to
# earn full credit.
_GOLDEN_QUERIES = [
    {"id": query_id, "question": question, "expected_min_rows": min_rows}
    for query_id, question, min_rows in (
        ("gq-01", "List all users from the USA.", 10),
        ("gq-02", "Show all products in the 'Electronics' category.", 8),
        ("gq-03", "Find the total number of orders per user.", 10),
        ("gq-04", "Show the average rating for each product category.", 5),
        ("gq-05", "List products along with their seller name.", 20),
    )
]
# ─── Optimizer Class ──────────────────────────────────────────────
class GEPAOptimizer:
    """Evolve the SQL agent's system prompt from recorded query outcomes.

    State consists of a history of ``QueryResult`` records and a small front
    of ``Candidate`` prompts ranked by ``score``. The state is written to
    ``GEPA_PATH`` as JSON after every change and reloaded on construction.
    The highest-scoring candidate supplies the current system prompt.
    """

    # Candidates kept in the front after each optimization cycle.
    _MAX_FRONT_SIZE = 3
    # Most recent problematic queries shown to the reflection step.
    _MAX_REFLECTION_FAILURES = 8
    # Golden queries executed when scoring a candidate prompt.
    _SCORING_QUERY_COUNT = 3
    # Newest history entries written to disk on each save.
    _MAX_PERSISTED_HISTORY = 100

    def __init__(self) -> None:
        """Start from the seed prompt, then overlay any persisted state."""
        self._history: list[QueryResult] = []
        self._pareto_front: list[Candidate] = [self._seed_candidate()]
        self._load()

    @staticmethod
    def _seed_candidate(prompt: Optional[str] = None) -> Candidate:
        """Build a generation-0 candidate with the placeholder seed metrics.

        ``prompt`` defaults to ``SEED_SYSTEM_PROMPT``; it is resolved at call
        time so an explicitly passed empty string is kept as-is.
        """
        return Candidate(
            system_prompt=SEED_SYSTEM_PROMPT if prompt is None else prompt,
            score=0.5,
            avg_attempts=3.0,
            success_rate=0.5,
            generation=0,
            feedback=[],
        )

    # ─── Public Interface ─────────────────────────────────────────

    def record_result(self, result: QueryResult) -> None:
        """Append ``result`` to the history and persist the state."""
        self._history.append(result)
        self._save()

    def get_current_prompt(self) -> str:
        """Return the highest-scoring prompt, or the seed if the front is empty."""
        if not self._pareto_front:
            return SEED_SYSTEM_PROMPT
        return max(self._pareto_front, key=lambda c: c.score).system_prompt

    def get_history(self) -> list[QueryResult]:
        """Return a shallow copy of the recorded history."""
        return list(self._history)

    def get_pareto_front(self) -> list[Candidate]:
        """Return a shallow copy of the current candidate front."""
        return list(self._pareto_front)

    def set_current_prompt(self, prompt: str) -> None:
        """Overwrite the best candidate's prompt text in place and persist.

        The candidate's score and other metrics are left untouched. When the
        front is empty, a generation-0 candidate holding ``prompt`` is added.
        """
        if self._pareto_front:
            best = max(self._pareto_front, key=lambda c: c.score)
            best.system_prompt = prompt
        else:
            self._pareto_front.append(self._seed_candidate(prompt))
        self._save()

    @property
    def current_generation(self) -> int:
        """Highest generation number in the front (0 when the front is empty)."""
        if not self._pareto_front:
            return 0
        return max(c.generation for c in self._pareto_front)

    def should_optimize(self) -> bool:
        """Tell whether the history length has hit an optimization boundary.

        True when the history is non-empty and its length is a multiple of
        ``GEPA_OPTIMIZE_EVERY``. A non-positive interval disables optimization
        instead of raising ``ZeroDivisionError``.
        """
        if GEPA_OPTIMIZE_EVERY <= 0:
            return False
        recorded = len(self._history)
        return recorded > 0 and recorded % GEPA_OPTIMIZE_EVERY == 0

    def reset(self) -> None:
        """Drop all history and candidates, restore the seed, and persist."""
        self._history.clear()
        self._pareto_front.clear()
        self._pareto_front.append(self._seed_candidate())
        self._save()

    @staticmethod
    def _summarize_failures(failures: list[QueryResult]) -> str:
        """Render failure records as the text block shown to the reflection LLM."""
        sections = []
        for index, item in enumerate(failures, start=1):
            error_lines = "\n".join(f" - {err}" for err in item.errors)
            sections.append(
                f'Query {index}: "{item.question}"\n'
                f"Attempts: {item.attempts}\n"
                f"Errors:\n{error_lines}\n"
                f"Final SQL: {item.final_sql}"
            )
        return "\n\n---\n\n".join(sections)

    async def run_optimization_cycle(
        self,
        user_feedback_context: Optional[str] = None,
        dialect: str = "SQLite",
    ) -> Optional[dict]:
        """Run one GEPA cycle: reflect → mutate → score → update Pareto front.

        Args:
            user_feedback_context: Optional conversation text appended to the
                reflection prompt.
            dialect: SQL dialect name interpolated into both LLM prompts.

        Returns:
            ``{"new_prompt": ..., "reflection": ...}``, or ``None`` when the
            history has fewer than two entries, fewer than two of them are
            problematic (more than one attempt, or unsuccessful), or the
            model returned a blank prompt.
        """
        if len(self._history) < 2:
            return None

        recent_failures = [
            h for h in self._history if h.attempts > 1 or not h.success
        ][-self._MAX_REFLECTION_FAILURES:]
        if len(recent_failures) < 2:
            return None

        current_best = self.get_current_prompt()

        # ── Step 1: Reflect ──────────────────────────────────────
        failure_summary = self._summarize_failures(recent_failures)
        user_ctx_block = (
            f"\n\nUser conversation:\n{user_feedback_context}"
            if user_feedback_context
            else ""
        )
        reflection = await _complete(
            f"You are an expert SQL prompt engineer analyzing why an LLM SQL agent is failing.\n"
            f"The target database is {dialect} — all rules must use {dialect} syntax.\n"
            "Your job: identify specific, recurring patterns in these failures and state EXACTLY "
            "what rules or knowledge the system prompt is missing.\n"
            "Be very specific — name the exact functions, syntax patterns, or schema reasoning gaps.\n"
            "Output a concise diagnosis (3-5 bullet points max).",
            f"Current system prompt:\n{current_best}\n\n"
            f"Recent failures:\n{failure_summary}{user_ctx_block}",
        )

        # ── Step 2: Mutate ───────────────────────────────────────
        # Captured before the mutation call, matching the original ordering.
        parent_generation = self.current_generation
        new_prompt = await _complete(
            f"You are an expert prompt engineer. Improve a system prompt for a {dialect} SQL generation agent.\n"
            "Rules for the new prompt:\n"
            "- Keep it concise and actionable\n"
            f"- The target database is {dialect} — use ONLY {dialect} syntax and functions\n"
            "- Add specific rules that address the diagnosed failure patterns\n"
            "- Do NOT add generic fluff — every rule must be earned by a real failure\n"
            "- Output ONLY the improved system prompt text, nothing else",
            f"Current system prompt:\n{current_best}\n\n"
            f"Diagnosed failure patterns:\n{reflection}\n\n"
            "Write the improved system prompt:",
        )
        if not new_prompt.strip():
            # A blank prompt must never enter the front: it could still score
            # well on the golden queries and displace a real prompt.
            logging.getLogger(__name__).warning(
                "GEPA mutation returned an empty prompt; skipping this cycle"
            )
            return None

        # ── Step 3: Score ────────────────────────────────────────
        benchmark_score = await self._score_prompt(new_prompt)
        # The history can be emptied by reset() while the calls above are
        # awaited, so the empty-history fallback is kept.
        current_avg_attempts = (
            sum(h.attempts for h in self._history) / len(self._history)
            if self._history
            else 3.0
        )
        new_candidate = Candidate(
            system_prompt=new_prompt,
            score=benchmark_score,
            avg_attempts=max(current_avg_attempts - 0.5, 1.0),
            success_rate=benchmark_score,
            generation=parent_generation + 1,
            feedback=[reflection],
        )

        # ── Step 4: Update Pareto front ──────────────────────────
        # The sort is stable, so on equal scores older candidates stay ahead.
        self._pareto_front.append(new_candidate)
        self._pareto_front.sort(key=lambda c: c.score, reverse=True)
        if len(self._pareto_front) > self._MAX_FRONT_SIZE:
            self._pareto_front = self._pareto_front[: self._MAX_FRONT_SIZE]
        self._save()

        return {"new_prompt": new_prompt, "reflection": reflection}

    async def _score_prompt(self, prompt: str) -> float:
        """Score ``prompt`` by running the first golden queries through the LLM.

        Each query earns 1.0 when it executes without error and returns at
        least ``expected_min_rows`` rows, 0.5 when it executes and returns
        some rows, and 0.0 otherwise (including any exception). The mean is
        returned; 0.3 is the fallback if no query was scored.
        """
        # Kept as function-level imports, as in the original; presumably this
        # avoids an import cycle with env.database — confirm before hoisting.
        from env.database import execute_query, get_schema_info
        import re

        schema = get_schema_info()
        client = _make_client()
        scores: list[float] = []
        for gq in _GOLDEN_QUERIES[: self._SCORING_QUERY_COUNT]:
            try:
                resp = await client.chat.completions.create(
                    model=_MODEL,
                    messages=[
                        {"role": "system", "content": prompt},
                        {
                            "role": "user",
                            "content": (
                                f"Schema:\n{schema}\n\n"
                                f"Question: {gq['question']}\n\n"
                                "Write a SQL query."
                            ),
                        },
                    ],
                    temperature=0.1,
                )
                sql = resp.choices[0].message.content or ""
                # Strip a surrounding markdown code fence and trailing ";".
                sql = re.sub(r"^```(?:sql)?\s*", "", sql.strip(), flags=re.IGNORECASE)
                sql = re.sub(r"\s*```$", "", sql).strip().rstrip(";")
                rows, error = execute_query(sql)
                if error is None and len(rows) >= gq["expected_min_rows"]:
                    scores.append(1.0)
                elif error is None and rows:
                    scores.append(0.5)
                else:
                    scores.append(0.0)
            except Exception:
                # Best effort: a failing query counts as a miss, not a crash.
                logging.getLogger(__name__).debug(
                    "Golden query %s failed during scoring", gq["id"], exc_info=True
                )
                scores.append(0.0)
        return sum(scores) / len(scores) if scores else 0.3

    # ─── Persistence ─────────────────────────────────────────────

    def _save(self) -> None:
        """Write history (newest entries only) and the front to ``GEPA_PATH``.

        Persistence is best effort: failures are logged and never raised. The
        payload is written to a sibling temp file and then moved into place so
        an interrupted write cannot leave a truncated state file.
        """
        try:
            GEPA_PATH.parent.mkdir(parents=True, exist_ok=True)
            data = {
                "history": [
                    r.model_dump()
                    for r in self._history[-self._MAX_PERSISTED_HISTORY:]
                ],
                "pareto_front": [c.model_dump() for c in self._pareto_front],
            }
            tmp_path = GEPA_PATH.with_name(GEPA_PATH.name + ".tmp")
            tmp_path.write_text(json.dumps(data, default=str))
            tmp_path.replace(GEPA_PATH)
        except Exception:
            logging.getLogger(__name__).warning(
                "Could not persist GEPA state", exc_info=True
            )

    def _load(self) -> None:
        """Restore history and front from ``GEPA_PATH`` if the file exists.

        Loading is best effort: a missing file is a no-op, and an unreadable
        or invalid file is logged and otherwise ignored. An empty persisted
        front does not replace the in-memory one.
        """
        try:
            if not GEPA_PATH.exists():
                return
            data = json.loads(GEPA_PATH.read_text())
            self._history = [QueryResult(**r) for r in data.get("history", [])]
            loaded_front = [Candidate(**c) for c in data.get("pareto_front", [])]
            if loaded_front:
                self._pareto_front = loaded_front
        except Exception:
            logging.getLogger(__name__).warning(
                "Could not load GEPA state; keeping in-memory state", exc_info=True
            )
# ─── Singleton ────────────────────────────────────────────────────
# Process-wide optimizer, created on first request.
_gepa_instance: Optional[GEPAOptimizer] = None


def get_gepa() -> GEPAOptimizer:
    """Return the shared ``GEPAOptimizer``, constructing it on first use."""
    global _gepa_instance
    if _gepa_instance is not None:
        return _gepa_instance
    _gepa_instance = GEPAOptimizer()
    return _gepa_instance