Spaces:
Sleeping
Sleeping
"""
GEPA (Goal-directed Evolutionary Prompt Adaptation) optimizer.

Ported from gepa.ts. Key steps:
1. Reflection: an LLM analyzes the failure history and outputs a diagnosis
2. Mutation: an LLM rewrites the system prompt based on that diagnosis
3. Scoring: run 3 golden queries with the new prompt and compute a score
4. Pareto front: keep the top 3 prompts, ranked by score

State is persisted to data/gepa_prompt.json.
"""
from __future__ import annotations

import json
import logging
import os
import time
from pathlib import Path
from typing import Optional

from openai import AsyncOpenAI
from pydantic import BaseModel
# Persisted optimizer state lives in DATA_DIR (default: the "data" folder one
# level above this module's directory).
_DATA_DIR = Path(os.getenv("DATA_DIR", Path(__file__).parent.parent / "data"))
GEPA_PATH = _DATA_DIR / "gepa_prompt.json"

# OpenAI-compatible endpoint and model used for reflection, mutation and scoring.
_API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
_MODEL = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
_HF_TOKEN = os.getenv("HF_TOKEN")  # no default — must be set explicitly

# How many queries between each GEPA optimization cycle.
# Override with the GEPA_OPTIMIZE_EVERY environment variable.
GEPA_OPTIMIZE_EVERY: int = int(os.getenv("GEPA_OPTIMIZE_EVERY", "4"))

# Generation-0 system prompt; every evolved prompt descends from this one.
SEED_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
Rules:
- Output ONLY the SQL query, nothing else
- No markdown, no code fences, no explanation
- Use SQLite syntax"""
# ─── Models ──────────────────────────────────────────────────────
class QueryResult(BaseModel):
    """Outcome of one question-to-SQL request, as stored in the GEPA history.

    A result counts as a failure for reflection purposes when it needed more
    than one attempt or did not succeed.
    """

    question: str  # natural-language question that was asked
    final_sql: str  # SQL reported as "Final SQL" in the reflection prompt
    attempts: int  # number of tries; > 1 marks the result as a failure sample
    success: bool  # False marks the result as a failure sample
    errors: list[str]  # error messages, listed one per bullet in the reflection prompt
    timestamp: float  # presumably Unix epoch seconds — TODO confirm with the caller
class Candidate(BaseModel):
    """One system prompt on the Pareto front together with its evaluation."""

    system_prompt: str  # full system prompt text
    score: float  # benchmark score used for ranking (seed starts at 0.5)
    avg_attempts: float  # estimated attempts per query (seed starts at 3.0)
    success_rate: float  # set equal to the benchmark score for evolved prompts
    generation: int  # 0 for the seed, parent generation + 1 for mutations
    feedback: list[str]  # reflection text(s) that motivated this prompt
# ─── LLM Helper ──────────────────────────────────────────────────
def _make_client() -> AsyncOpenAI:
    """Build an async OpenAI-compatible client from the module's endpoint settings."""
    client_options = {"api_key": _HF_TOKEN, "base_url": _API_BASE_URL}
    return AsyncOpenAI(**client_options)
async def _complete(system: str, user: str) -> str:
    """Send one system + user exchange to the configured model.

    Returns the first choice's message text, or an empty string when the
    model returned no content.
    """
    conversation = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    response = await _make_client().chat.completions.create(
        model=_MODEL,
        messages=conversation,
        temperature=0.7,
    )
    reply = response.choices[0].message.content
    return reply if reply else ""
# ─── Golden Queries for Scoring ──────────────────────────────────
# Benchmark questions used to score a candidate prompt. Each entry carries an
# "id", the natural-language "question", and "expected_min_rows": a generated
# query earns full credit only when it runs without error and returns at least
# that many rows. Only the first three entries are used by the scorer.
_GOLDEN_QUERIES = [
    {
        "id": "gq-01",
        "question": "List all users from the USA.",
        "expected_min_rows": 10,
    },
    {
        "id": "gq-02",
        "question": "Show all products in the 'Electronics' category.",
        "expected_min_rows": 8,
    },
    {
        "id": "gq-03",
        "question": "Find the total number of orders per user.",
        "expected_min_rows": 10,
    },
    {
        "id": "gq-04",
        "question": "Show the average rating for each product category.",
        "expected_min_rows": 5,
    },
    {
        "id": "gq-05",
        "question": "List products along with their seller name.",
        "expected_min_rows": 20,
    },
]
# ─── Optimizer Class ─────────────────────────────────────────────
class GEPAOptimizer:
    """Evolve the SQL agent's system prompt from recorded query outcomes.

    The optimizer keeps a history of ``QueryResult`` objects and a small front
    of ``Candidate`` prompts ranked by score. Both are loaded from
    ``GEPA_PATH`` on construction and written back after every mutation of
    state. Persistence is best-effort: I/O or parsing problems are logged and
    otherwise ignored so the agent keeps working without stored state.
    """

    # Number of candidates kept on the front after each optimization cycle.
    _FRONT_SIZE = 3
    # How many of the most recent failures are shown to the reflection LLM.
    _MAX_REFLECTION_FAILURES = 8
    # Only this many of the most recent results are written to disk.
    _MAX_PERSISTED_HISTORY = 100

    _log = logging.getLogger(__name__)

    def __init__(self) -> None:
        """Start from the seed prompt, then overlay any persisted state."""
        self._history: list[QueryResult] = []
        self._pareto_front: list[Candidate] = [self._seed_candidate()]
        self._load()

    @staticmethod
    def _seed_candidate(prompt: str = SEED_SYSTEM_PROMPT) -> Candidate:
        """Return a generation-0 candidate with neutral placeholder metrics."""
        return Candidate(
            system_prompt=prompt,
            score=0.5,
            avg_attempts=3.0,
            success_rate=0.5,
            generation=0,
            feedback=[],
        )

    # ─── Public Interface ────────────────────────────────────────

    def record_result(self, result: QueryResult) -> None:
        """Append one query outcome to the history and persist."""
        self._history.append(result)
        self._save()

    def get_current_prompt(self) -> str:
        """Return the highest-scoring prompt, or the seed if the front is empty."""
        if not self._pareto_front:
            return SEED_SYSTEM_PROMPT
        return max(self._pareto_front, key=lambda c: c.score).system_prompt

    def get_history(self) -> list[QueryResult]:
        """Return a shallow copy of the recorded results."""
        return list(self._history)

    def get_pareto_front(self) -> list[Candidate]:
        """Return a shallow copy of the current candidate front."""
        return list(self._pareto_front)

    def set_current_prompt(self, prompt: str) -> None:
        """Overwrite the text of the best candidate (its metrics are kept).

        When the front is empty a fresh generation-0 candidate is created
        for ``prompt`` instead. The new state is persisted.
        """
        if self._pareto_front:
            best = max(self._pareto_front, key=lambda c: c.score)
            best.system_prompt = prompt
        else:
            self._pareto_front.append(self._seed_candidate(prompt))
        self._save()

    def current_generation(self) -> int:
        """Return the highest generation number on the front (0 if empty)."""
        if not self._pareto_front:
            return 0
        return max(c.generation for c in self._pareto_front)

    def should_optimize(self) -> bool:
        """Tell whether the history length has hit an optimization boundary.

        True when the history is non-empty and its length is a multiple of
        ``GEPA_OPTIMIZE_EVERY``. A non-positive interval disables optimization
        instead of raising ``ZeroDivisionError``.
        """
        if GEPA_OPTIMIZE_EVERY <= 0:
            return False
        recorded = len(self._history)
        return recorded > 0 and recorded % GEPA_OPTIMIZE_EVERY == 0

    def reset(self) -> None:
        """Drop all history and candidates, restore the seed, and persist."""
        self._history.clear()
        self._pareto_front.clear()
        self._pareto_front.append(self._seed_candidate())
        self._save()

    async def run_optimization_cycle(
        self,
        user_feedback_context: Optional[str] = None,
        dialect: str = "SQLite",
    ) -> Optional[dict]:
        """
        Run one GEPA cycle: reflect → mutate → score → update Pareto front.

        Args:
            user_feedback_context: optional conversation text appended to the
                reflection request.
            dialect: SQL dialect named in the reflection and mutation prompts.

        Returns:
            ``{"new_prompt": ..., "reflection": ...}``, or ``None`` when there
            are fewer than two recorded results, fewer than two failures to
            learn from, or the mutation step produced an empty prompt.
        """
        if len(self._history) < 2:
            return None

        recent_failures = [
            h for h in self._history if h.attempts > 1 or not h.success
        ][-self._MAX_REFLECTION_FAILURES:]
        if len(recent_failures) < 2:
            return None

        current_best = self.get_current_prompt()

        # ── Step 1: Reflect ──────────────────────────────────────
        failure_summary = "\n\n---\n\n".join(
            self._describe_failure(position, failure)
            for position, failure in enumerate(recent_failures, start=1)
        )
        user_ctx_block = (
            f"\n\nUser conversation:\n{user_feedback_context}"
            if user_feedback_context
            else ""
        )
        reflection = await _complete(
            f"You are an expert SQL prompt engineer analyzing why an LLM SQL agent is failing.\n"
            f"The target database is {dialect} — all rules must use {dialect} syntax.\n"
            "Your job: identify specific, recurring patterns in these failures and state EXACTLY "
            "what rules or knowledge the system prompt is missing.\n"
            "Be very specific — name the exact functions, syntax patterns, or schema reasoning gaps.\n"
            "Output a concise diagnosis (3-5 bullet points max).",
            f"Current system prompt:\n{current_best}\n\n"
            f"Recent failures:\n{failure_summary}{user_ctx_block}",
        )

        # ── Step 2: Mutate ───────────────────────────────────────
        new_prompt = (
            await _complete(
                f"You are an expert prompt engineer. Improve a system prompt for a {dialect} SQL generation agent.\n"
                "Rules for the new prompt:\n"
                "- Keep it concise and actionable\n"
                f"- The target database is {dialect} — use ONLY {dialect} syntax and functions\n"
                "- Add specific rules that address the diagnosed failure patterns\n"
                "- Do NOT add generic fluff — every rule must be earned by a real failure\n"
                "- Output ONLY the improved system prompt text, nothing else",
                f"Current system prompt:\n{current_best}\n\n"
                f"Diagnosed failure patterns:\n{reflection}\n\n"
                "Write the improved system prompt:",
            )
        ).strip()
        if not new_prompt:
            # An empty prompt must never reach the front, where it could be
            # served to the SQL agent; skip the cycle instead.
            self._log.warning("GEPA mutation returned an empty prompt; cycle skipped")
            return None

        # ── Step 3: Score ────────────────────────────────────────
        benchmark_score = await self._score_prompt(new_prompt)
        # The history may have been cleared while awaiting the LLM calls.
        current_avg_attempts = (
            sum(h.attempts for h in self._history) / len(self._history)
            if self._history
            else 3.0
        )
        new_candidate = Candidate(
            system_prompt=new_prompt,
            score=benchmark_score,
            # Not measured: the historical average minus 0.5, floored at 1.
            avg_attempts=max(current_avg_attempts - 0.5, 1.0),
            success_rate=benchmark_score,
            generation=self.current_generation() + 1,
            feedback=[reflection],
        )

        # ── Step 4: Update Pareto front ──────────────────────────
        self._pareto_front.append(new_candidate)
        self._pareto_front.sort(key=lambda c: c.score, reverse=True)
        if len(self._pareto_front) > self._FRONT_SIZE:
            self._pareto_front = self._pareto_front[: self._FRONT_SIZE]
        self._save()

        return {"new_prompt": new_prompt, "reflection": reflection}

    @staticmethod
    def _describe_failure(position: int, failure: QueryResult) -> str:
        """Render one failed result as a text block for the reflection prompt."""
        return (
            f'Query {position}: "{failure.question}"\n'
            f"Attempts: {failure.attempts}\n"
            f"Errors:\n" + "\n".join(f" - {e}" for e in failure.errors) + "\n"
            f"Final SQL: {failure.final_sql}"
        )

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Remove a surrounding markdown code fence and trailing semicolons."""
        import re

        unfenced = re.sub(r"^```(?:sql)?\s*", "", text.strip(), flags=re.IGNORECASE)
        return re.sub(r"\s*```$", "", unfenced).strip().rstrip(";")

    async def _score_prompt(self, prompt: str, num_queries: int = 3) -> float:
        """
        Score a prompt by running golden queries and measuring success rate.

        Each of the first ``num_queries`` golden questions is answered with
        ``prompt`` as the system message and the generated SQL is executed:
        1.0 when it runs without error and returns at least
        ``expected_min_rows`` rows, 0.5 when it runs and returns some rows,
        0.0 otherwise (including any exception raised while handling that
        question). Returns the mean, or 0.3 when no query was evaluated.
        """
        from env.database import execute_query, get_schema_info

        schema = get_schema_info()
        client = _make_client()
        scores = []
        for gq in _GOLDEN_QUERIES[:num_queries]:
            try:
                resp = await client.chat.completions.create(
                    model=_MODEL,
                    messages=[
                        {"role": "system", "content": prompt},
                        {
                            "role": "user",
                            "content": (
                                f"Schema:\n{schema}\n\n"
                                f"Question: {gq['question']}\n\n"
                                "Write a SQL query."
                            ),
                        },
                    ],
                    temperature=0.1,
                )
                sql = self._strip_code_fences(resp.choices[0].message.content or "")
                rows, error = execute_query(sql)
                if error is None and len(rows) >= gq["expected_min_rows"]:
                    scores.append(1.0)
                elif error is None and rows:
                    scores.append(0.5)
                else:
                    scores.append(0.0)
            except Exception:
                # A failing benchmark question counts against the prompt
                # rather than aborting the whole cycle.
                self._log.debug("Golden query %s failed", gq.get("id"), exc_info=True)
                scores.append(0.0)
        return sum(scores) / len(scores) if scores else 0.3

    # ─── Persistence ─────────────────────────────────────────────

    def _save(self) -> None:
        """Write history (most recent entries only) and the front to ``GEPA_PATH``.

        Best-effort: any failure is logged and swallowed.
        """
        try:
            GEPA_PATH.parent.mkdir(parents=True, exist_ok=True)
            data = {
                "history": [
                    r.model_dump()
                    for r in self._history[-self._MAX_PERSISTED_HISTORY:]
                ],
                "pareto_front": [c.model_dump() for c in self._pareto_front],
            }
            GEPA_PATH.write_text(json.dumps(data, default=str), encoding="utf-8")
        except Exception:
            self._log.warning(
                "Could not persist GEPA state to %s", GEPA_PATH, exc_info=True
            )

    def _load(self) -> None:
        """Restore history and front from ``GEPA_PATH`` if the file exists.

        Best-effort: a missing file is a no-op, and an unreadable or invalid
        file is logged and ignored. An empty stored front leaves the current
        front untouched.
        """
        try:
            if not GEPA_PATH.exists():
                return
            data = json.loads(GEPA_PATH.read_text(encoding="utf-8"))
            self._history = [QueryResult(**r) for r in data.get("history", [])]
            loaded_front = [Candidate(**c) for c in data.get("pareto_front", [])]
            if loaded_front:
                self._pareto_front = loaded_front
        except Exception:
            self._log.warning(
                "Could not load GEPA state from %s", GEPA_PATH, exc_info=True
            )
# ─── Singleton ───────────────────────────────────────────────────
_gepa_instance: Optional[GEPAOptimizer] = None


def get_gepa() -> GEPAOptimizer:
    """Return the shared optimizer, constructing it on the first call."""
    global _gepa_instance
    if _gepa_instance is not None:
        return _gepa_instance
    _gepa_instance = GEPAOptimizer()
    return _gepa_instance