Spaces:
Sleeping
Sleeping
File size: 13,001 Bytes
3c665d2 92cc088 3c665d2 44ef33f 3c665d2 92cc088 3c665d2 f0b682f 3c665d2 44ef33f 3c665d2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 | """
GEPA (Goal-directed Evolutionary Prompt Adaptation) optimizer.
Ported from gepa.ts. Key steps:
1. Reflection: LLM analyzes failure history, outputs diagnosis
2. Mutation: LLM rewrites system prompt based on diagnosis
3. Scoring: Run 3 golden queries with new prompt, compute score
4. Pareto front: Keep the top 3 prompts, ranked by benchmark score
State is persisted to $DATA_DIR/gepa_prompt.json (default: data/gepa_prompt.json).
"""
from __future__ import annotations

import json
import logging
import os
import time
from pathlib import Path
from typing import Optional

from openai import AsyncOpenAI
from pydantic import BaseModel
def _env_int(name: str, default: int, *, minimum: int = 1) -> int:
    """Read an integer setting from the environment.

    Falls back to *default* when the variable is unset or is not a valid
    integer, and clamps the result to at least *minimum*.
    """
    raw = os.environ.get(name)
    if raw is None:
        return max(default, minimum)
    try:
        value = int(raw)
    except ValueError:
        return max(default, minimum)
    return max(value, minimum)


# Directory holding persisted optimizer state; override with DATA_DIR.
_DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
GEPA_PATH = _DATA_DIR / "gepa_prompt.json"
_API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
_MODEL = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
_HF_TOKEN = os.environ.get("HF_TOKEN")  # no default - must be set explicitly
# How many queries between each GEPA optimization cycle.
# Override with the GEPA_OPTIMIZE_EVERY environment variable. Values below 1
# (which would make ``len(history) % N`` divide by zero) are clamped to 1, and
# non-integer values fall back to the default instead of crashing at import.
GEPA_OPTIMIZE_EVERY: int = _env_int("GEPA_OPTIMIZE_EVERY", 4, minimum=1)
SEED_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
Rules:
- Output ONLY the SQL query, nothing else
- No markdown, no code fences, no explanation
- Use SQLite syntax"""
# --- Models ----------------------------------------------------------------
class QueryResult(BaseModel):
    """One query handled by the SQL agent, as stored in the optimizer history.

    The optimizer's reflection step treats a result as a failure when it took
    more than one attempt or ``success`` is False.
    """

    question: str  # natural-language question that was asked
    final_sql: str  # SQL the agent ended up with (shown as "Final SQL" in reflections)
    attempts: int  # number of attempts the agent used for this question
    success: bool  # success flag reported by the caller
    errors: list[str]  # error messages collected across attempts
    timestamp: float  # NOTE(review): presumably a Unix time such as time.time() - confirm with callers


class Candidate(BaseModel):
    """A system-prompt candidate held on the optimizer's Pareto front."""

    system_prompt: str  # full system prompt text
    # Benchmark score from scoring the golden queries (mean of 0 / 0.5 / 1 marks);
    # seed candidates carry an unmeasured 0.5 placeholder.
    score: float
    # Estimated average attempts; optimized candidates get the history mean
    # minus 0.5, floored at 1.0 (a heuristic, not a measurement).
    avg_attempts: float
    success_rate: float  # set equal to ``score`` for optimized candidates
    generation: int  # 0 for the seed; each optimization cycle adds 1 to the front's max
    feedback: list[str]  # reflection text(s) that motivated this prompt
# --- LLM Helper ------------------------------------------------------------
def _make_client() -> AsyncOpenAI:
    """Build an async OpenAI-compatible client for the configured router.

    Uses ``_HF_TOKEN`` as the API key and ``_API_BASE_URL`` as the endpoint.
    The caller owns the returned client and should close it, e.g. with
    ``async with _make_client() as client: ...``.
    """
    return AsyncOpenAI(
        api_key=_HF_TOKEN,
        base_url=_API_BASE_URL,
    )


async def _complete(system: str, user: str) -> str:
    """Run one chat completion and return the assistant's text.

    Args:
        system: Content of the system-role message.
        user: Content of the user-role message.

    Returns:
        The first choice's message content, or ``""`` when the response has
        no choices or the content is ``None``.
    """
    # A fresh client is created per call; close it deterministically instead
    # of leaving its HTTP connection pool for the garbage collector.
    async with _make_client() as client:
        resp = await client.chat.completions.create(
            model=_MODEL,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=0.7,
        )
    if not resp.choices:
        return ""
    return resp.choices[0].message.content or ""
# --- Golden Queries for Scoring ---------------------------------------------
# Each entry pairs a benchmark question with the minimum number of result rows
# an answer must return to earn full credit; ``id`` is a stable label.
_GOLDEN_QUERIES = [
    {"id": query_id, "question": text, "expected_min_rows": min_rows}
    for query_id, text, min_rows in (
        ("gq-01", "List all users from the USA.", 10),
        ("gq-02", "Show all products in the 'Electronics' category.", 8),
        ("gq-03", "Find the total number of orders per user.", 10),
        ("gq-04", "Show the average rating for each product category.", 5),
        ("gq-05", "List products along with their seller name.", 20),
    )
]
# --- Optimizer Class -------------------------------------------------------
class GEPAOptimizer:
    """Evolve the SQL agent's system prompt from its recorded query history.

    State is the query history plus a "Pareto front" of at most three prompt
    candidates ranked by benchmark score. The highest-scoring candidate's
    prompt is the one returned by ``get_current_prompt``. State is written to
    ``GEPA_PATH`` after every change and reloaded on construction.
    """

    # Candidates kept on the front after each optimization cycle.
    _FRONT_SIZE = 3
    # Most recent history entries written to disk by ``_save``.
    _SAVED_HISTORY = 100
    # Most recent failing queries shown to the reflection step.
    _MAX_FAILURES = 8

    def __init__(self) -> None:
        self._history: list[QueryResult] = []
        self._pareto_front: list[Candidate] = [
            self._seed_candidate(SEED_SYSTEM_PROMPT)
        ]
        # Persisted state, if present and valid, replaces the defaults above.
        self._load()

    @staticmethod
    def _seed_candidate(prompt: str) -> Candidate:
        """Return an unscored generation-0 candidate for *prompt*.

        The score/success_rate of 0.5 and avg_attempts of 3.0 are
        placeholders, not measurements.
        """
        return Candidate(
            system_prompt=prompt,
            score=0.5,
            avg_attempts=3.0,
            success_rate=0.5,
            generation=0,
            feedback=[],
        )

    def _best_candidate(self) -> Optional[Candidate]:
        """Return the highest-scoring candidate (first one on ties), or None."""
        if not self._pareto_front:
            return None
        return max(self._pareto_front, key=lambda c: c.score)

    # --- Public Interface ---------------------------------------------------
    def record_result(self, result: QueryResult) -> None:
        """Append *result* to the history and persist state."""
        self._history.append(result)
        self._save()

    def get_current_prompt(self) -> str:
        """Return the best candidate's prompt, or the seed prompt if none."""
        best = self._best_candidate()
        return SEED_SYSTEM_PROMPT if best is None else best.system_prompt

    def get_history(self) -> list[QueryResult]:
        """Return a shallow copy of the recorded history."""
        return list(self._history)

    def get_pareto_front(self) -> list[Candidate]:
        """Return a shallow copy of the current Pareto front."""
        return list(self._pareto_front)

    def set_current_prompt(self, prompt: str) -> None:
        """Replace the best candidate's prompt text and persist state.

        The candidate's score and other metrics are left unchanged. If the
        front is empty, a placeholder-scored seed candidate is added instead.
        """
        best = self._best_candidate()
        if best is None:
            self._pareto_front.append(self._seed_candidate(prompt))
        else:
            best.system_prompt = prompt
        self._save()

    @property
    def current_generation(self) -> int:
        """Highest generation number on the front (0 when it is empty)."""
        return max((c.generation for c in self._pareto_front), default=0)

    def should_optimize(self) -> bool:
        """Return True after every ``GEPA_OPTIMIZE_EVERY``-th recorded query."""
        # Clamp so a misconfigured value of 0 cannot raise ZeroDivisionError.
        every = max(GEPA_OPTIMIZE_EVERY, 1)
        return len(self._history) > 0 and len(self._history) % every == 0

    def reset(self) -> None:
        """Clear the history, restore a seed-only front, and persist state."""
        self._history.clear()
        self._pareto_front.clear()
        self._pareto_front.append(self._seed_candidate(SEED_SYSTEM_PROMPT))
        self._save()

    async def run_optimization_cycle(
        self,
        user_feedback_context: Optional[str] = None,
        dialect: str = "SQLite",
    ) -> Optional[dict]:
        """Run one GEPA cycle: reflect -> mutate -> score -> update Pareto front.

        Args:
            user_feedback_context: Optional conversation text appended to the
                reflection request.
            dialect: SQL dialect name injected into the reflection and
                mutation instructions.

        Returns:
            ``{"new_prompt": ..., "reflection": ...}`` on success, or None
            when there are fewer than two recorded queries, fewer than two
            failures (queries that needed retries or did not succeed), or the
            mutation step returned an empty prompt.
        """
        if len(self._history) < 2:
            return None
        recent_failures = [
            h for h in self._history if h.attempts > 1 or not h.success
        ][-self._MAX_FAILURES:]
        if len(recent_failures) < 2:
            return None
        current_best = self.get_current_prompt()

        # -- Step 1: Reflect ----------------------------------------------
        failure_summary = "\n\n---\n\n".join(
            f'Query {i}: "{f.question}"\n'
            f"Attempts: {f.attempts}\n"
            "Errors:\n"
            + "\n".join(f" - {e}" for e in f.errors)
            + f"\nFinal SQL: {f.final_sql}"
            for i, f in enumerate(recent_failures, start=1)
        )
        user_ctx_block = (
            f"\n\nUser conversation:\n{user_feedback_context}"
            if user_feedback_context
            else ""
        )
        reflection = await _complete(
            "You are an expert SQL prompt engineer analyzing why an LLM SQL agent is failing.\n"
            f"The target database is {dialect} — all rules must use {dialect} syntax.\n"
            "Your job: identify specific, recurring patterns in these failures and state EXACTLY "
            "what rules or knowledge the system prompt is missing.\n"
            "Be very specific — name the exact functions, syntax patterns, or schema reasoning gaps.\n"
            "Output a concise diagnosis (3-5 bullet points max).",
            f"Current system prompt:\n{current_best}\n\n"
            f"Recent failures:\n{failure_summary}{user_ctx_block}",
        )

        # -- Step 2: Mutate -----------------------------------------------
        current_generation = self.current_generation
        new_prompt = (
            await _complete(
                f"You are an expert prompt engineer. Improve a system prompt for a {dialect} SQL generation agent.\n"
                "Rules for the new prompt:\n"
                "- Keep it concise and actionable\n"
                f"- The target database is {dialect} — use ONLY {dialect} syntax and functions\n"
                "- Add specific rules that address the diagnosed failure patterns\n"
                "- Do NOT add generic fluff — every rule must be earned by a real failure\n"
                "- Output ONLY the improved system prompt text, nothing else",
                f"Current system prompt:\n{current_best}\n\n"
                f"Diagnosed failure patterns:\n{reflection}\n\n"
                "Write the improved system prompt:",
            )
        ).strip()
        if not new_prompt:
            # An empty system prompt must never enter the front: if it scored
            # well it would silently become the prompt served to the agent.
            return None

        # -- Step 3: Score ------------------------------------------------
        benchmark_score = await self._score_prompt(new_prompt)
        # Guarded because the history may be reset by another coroutine
        # while the awaits above are pending.
        current_avg_attempts = (
            sum(h.attempts for h in self._history) / len(self._history)
            if self._history
            else 3.0
        )
        new_candidate = Candidate(
            system_prompt=new_prompt,
            score=benchmark_score,
            # Heuristic estimate, not a measurement.
            avg_attempts=max(current_avg_attempts - 0.5, 1.0),
            success_rate=benchmark_score,
            generation=current_generation + 1,
            feedback=[reflection],
        )

        # -- Step 4: Update Pareto front ----------------------------------
        self._pareto_front.append(new_candidate)
        # list.sort is stable even with reverse=True, so on equal scores the
        # older candidates keep priority over the newcomer.
        self._pareto_front.sort(key=lambda c: c.score, reverse=True)
        del self._pareto_front[self._FRONT_SIZE:]
        self._save()
        return {"new_prompt": new_prompt, "reflection": reflection}

    async def _score_prompt(self, prompt: str) -> float:
        """Score *prompt* on the first three golden queries.

        Each query earns 1.0 when the generated SQL runs without error and
        returns at least ``expected_min_rows`` rows, 0.5 when it runs and
        returns fewer (but some) rows, and 0.0 on an execution error, an empty
        result, or any exception while generating or executing. Returns the
        mean mark, or 0.3 if no query was scored.
        """
        import re

        from env.database import execute_query, get_schema_info

        schema = get_schema_info()
        scores: list[float] = []
        # Close the client when scoring is done instead of leaking it.
        async with _make_client() as client:
            for gq in _GOLDEN_QUERIES[:3]:
                try:
                    resp = await client.chat.completions.create(
                        model=_MODEL,
                        messages=[
                            {"role": "system", "content": prompt},
                            {
                                "role": "user",
                                "content": (
                                    f"Schema:\n{schema}\n\n"
                                    f"Question: {gq['question']}\n\n"
                                    "Write a SQL query."
                                ),
                            },
                        ],
                        temperature=0.1,
                    )
                    sql = resp.choices[0].message.content or ""
                    # Strip markdown code fences and a trailing semicolon.
                    sql = re.sub(r"^```(?:sql)?\s*", "", sql.strip(), flags=re.IGNORECASE)
                    sql = re.sub(r"\s*```$", "", sql).strip().rstrip(";")
                    rows, error = execute_query(sql)
                    if error is None and len(rows) >= gq["expected_min_rows"]:
                        scores.append(1.0)
                    elif error is None and rows:
                        scores.append(0.5)
                    else:
                        scores.append(0.0)
                except Exception:
                    # Best-effort benchmark: any failure on one query scores 0.
                    scores.append(0.0)
        return sum(scores) / len(scores) if scores else 0.3

    # --- Persistence --------------------------------------------------------
    def _save(self) -> None:
        """Persist the recent history and the front to ``GEPA_PATH``.

        The JSON is written to a sibling temp file and atomically renamed into
        place, so a crash mid-write cannot leave a truncated file that
        ``_load`` would then discard. Persistence is best-effort: failures are
        logged and otherwise ignored.
        """
        tmp_path = GEPA_PATH.with_name(GEPA_PATH.name + ".tmp")
        try:
            data = {
                "history": [
                    r.model_dump() for r in self._history[-self._SAVED_HISTORY:]
                ],
                "pareto_front": [c.model_dump() for c in self._pareto_front],
            }
            GEPA_PATH.parent.mkdir(parents=True, exist_ok=True)
            tmp_path.write_text(json.dumps(data, default=str))
            os.replace(tmp_path, GEPA_PATH)
        except Exception:
            logging.getLogger(__name__).warning(
                "Failed to save GEPA state to %s", GEPA_PATH, exc_info=True
            )

    def _load(self) -> None:
        """Restore history and Pareto front from ``GEPA_PATH`` if present.

        A missing file leaves the defaults untouched. An unreadable or invalid
        file is logged and ignored as a whole, so history and front are never
        restored half-way. An empty saved front keeps the current (seed) front.
        """
        try:
            if not GEPA_PATH.exists():
                return
            data = json.loads(GEPA_PATH.read_text())
            history = [QueryResult(**r) for r in data.get("history", [])]
            loaded_front = [Candidate(**c) for c in data.get("pareto_front", [])]
        except Exception:
            logging.getLogger(__name__).warning(
                "Failed to load GEPA state from %s; using defaults",
                GEPA_PATH,
                exc_info=True,
            )
            return
        self._history = history
        if loaded_front:
            self._pareto_front = loaded_front
# --- Singleton -------------------------------------------------------------
_gepa_instance: Optional[GEPAOptimizer] = None


def get_gepa() -> GEPAOptimizer:
    """Return the shared optimizer, constructing it on the first call.

    The first call builds a ``GEPAOptimizer`` and caches it in
    ``_gepa_instance``; every later call returns that same object.
    """
    global _gepa_instance
    instance = _gepa_instance
    if instance is None:
        instance = GEPAOptimizer()
        _gepa_instance = instance
    return instance
|