File size: 13,001 Bytes
3c665d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92cc088
 
 
3c665d2
44ef33f
 
 
 
3c665d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92cc088
 
3c665d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0b682f
 
 
 
 
 
3c665d2
44ef33f
3c665d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
"""
GEPA (Goal-directed Evolutionary Prompt Adaptation) optimizer.

Ported from gepa.ts. Key steps:
  1. Reflection: LLM analyzes failure history, outputs diagnosis
  2. Mutation: LLM rewrites system prompt based on diagnosis
  3. Scoring: Run 3 golden queries with new prompt, compute score
  4. Pareto front: Keep the top 3 prompts ranked by benchmark score

State is persisted to gepa_prompt.json inside DATA_DIR (default: the data/
directory one level above this module's directory).
"""

from __future__ import annotations

import json
import logging
import os
import time
from pathlib import Path
from typing import Optional

from openai import AsyncOpenAI
from pydantic import BaseModel

# Directory for persisted optimizer state. Override with the DATA_DIR
# environment variable; by default it is a "data" directory one level above
# the directory containing this module.
_DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
GEPA_PATH = _DATA_DIR / "gepa_prompt.json"

# OpenAI-compatible endpoint and model used for reflection, mutation and
# scoring calls. Each can be overridden via its environment variable.
_API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
_MODEL = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
_HF_TOKEN = os.environ.get("HF_TOKEN")  # no default — must be set explicitly

# How many queries between each GEPA optimization cycle.
# Override with the GEPA_OPTIMIZE_EVERY environment variable; a non-integer
# value raises ValueError when this module is imported.
GEPA_OPTIMIZE_EVERY: int = int(os.environ.get("GEPA_OPTIMIZE_EVERY", "4"))

# Generation-0 prompt: the initial Pareto-front entry, and the fallback
# returned when the front is empty.
SEED_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.

Rules:
- Output ONLY the SQL query, nothing else
- No markdown, no code fences, no explanation
- Use SQLite syntax"""


# ─── Models ──────────────────────────────────────────────────────

class QueryResult(BaseModel):
    """Outcome of one SQL-generation query, appended to the optimizer history.

    Field order is also the key order of the persisted JSON records.
    """

    question: str  # natural-language question that was asked
    final_sql: str  # last SQL produced for the question
    # Number of generation attempts. Values > 1 mark the query as a failure
    # for the reflection step, as does success=False.
    attempts: int
    success: bool  # whether the query ultimately succeeded, as reported by the caller
    errors: list[str]  # error messages collected across attempts
    timestamp: float  # presumably a time.time() value set by the caller — TODO confirm


class Candidate(BaseModel):
    """One system prompt kept on the Pareto front, together with its metrics.

    Field order is also the key order of the persisted JSON records.
    """

    system_prompt: str
    # Mean benchmark score in [0, 1] from the golden-query scorer.
    # The seed candidate uses a 0.5 placeholder that was never measured.
    score: float
    # Estimate, not a measurement: the history's average attempt count
    # minus 0.5, floored at 1.0 (the seed uses 3.0).
    avg_attempts: float
    success_rate: float  # set to the same value as score for optimized candidates
    # 0 for the seed; an optimized candidate gets the front's highest
    # generation + 1.
    generation: int
    feedback: list[str]  # reflection text(s) that produced this prompt


# ─── LLM Helper ──────────────────────────────────────────────────

def _make_client() -> AsyncOpenAI:
    """Create a new async client for the configured OpenAI-compatible endpoint.

    Every call returns a fresh client; nothing is cached here.
    """
    settings = {"base_url": _API_BASE_URL, "api_key": _HF_TOKEN}
    return AsyncOpenAI(**settings)


async def _complete(system: str, user: str) -> str:
    """Send one system + user chat turn to the configured model.

    Args:
        system: system-role message content.
        user: user-role message content.

    Returns:
        The assistant reply text, or "" when the response has no choices or
        the first choice has no content.
    """
    # Close the per-call client so its HTTP connection pool is released
    # instead of lingering until garbage collection.
    async with _make_client() as client:
        resp = await client.chat.completions.create(
            model=_MODEL,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=0.7,
        )
    if not resp.choices:
        return ""
    return resp.choices[0].message.content or ""


# ─── Golden Queries for Scoring ──────────────────────────────────

# Benchmark questions for prompt scoring. Each pairs a question with the
# minimum number of rows a correct answer should return. The scorer
# currently uses only the first three.
_GOLDEN_QUERIES = [
    {"id": gq_id, "question": question, "expected_min_rows": min_rows}
    for gq_id, question, min_rows in (
        ("gq-01", "List all users from the USA.", 10),
        ("gq-02", "Show all products in the 'Electronics' category.", 8),
        ("gq-03", "Find the total number of orders per user.", 10),
        ("gq-04", "Show the average rating for each product category.", 5),
        ("gq-05", "List products along with their seller name.", 20),
    )
]


# ─── Optimizer Class ──────────────────────────────────────────────

class GEPAOptimizer:
    """Evolve the SQL agent's system prompt from recorded query outcomes.

    Keeps a history of QueryResult records and a front of at most
    ``_FRONT_SIZE`` Candidate prompts ranked by benchmark score. State is
    written to GEPA_PATH after every change and reloaded on construction.
    """

    # Maximum number of candidates kept on the front after each cycle.
    _FRONT_SIZE = 3
    # Number of most-recent history entries written to disk.
    _PERSISTED_HISTORY = 100
    # Number of most-recent failures shown to the reflection step.
    _MAX_REFLECTED_FAILURES = 8

    def __init__(self) -> None:
        self._history: list[QueryResult] = []
        self._pareto_front: list[Candidate] = [self._seed_candidate(SEED_SYSTEM_PROMPT)]
        self._load()

    # ─── Internal Helpers ────────────────────────────────────────

    @staticmethod
    def _seed_candidate(prompt: str) -> Candidate:
        """Build an unscored generation-0 candidate with placeholder metrics."""
        return Candidate(
            system_prompt=prompt,
            score=0.5,
            avg_attempts=3.0,
            success_rate=0.5,
            generation=0,
            feedback=[],
        )

    def _best(self) -> Optional[Candidate]:
        """Return the highest-scoring candidate (first one on ties), or None if the front is empty."""
        if not self._pareto_front:
            return None
        return max(self._pareto_front, key=lambda c: c.score)

    @staticmethod
    def _format_failures(failures: list[QueryResult]) -> str:
        """Render failures as numbered text blocks separated by '---' lines."""
        blocks = []
        for idx, fail in enumerate(failures, start=1):
            error_lines = "\n".join(f"  - {e}" for e in fail.errors)
            blocks.append(
                f'Query {idx}: "{fail.question}"\n'
                f"Attempts: {fail.attempts}\n"
                f"Errors:\n{error_lines}\n"
                f"Final SQL: {fail.final_sql}"
            )
        return "\n\n---\n\n".join(blocks)

    # ─── Public Interface ─────────────────────────────────────────

    def record_result(self, result: QueryResult) -> None:
        """Append one query outcome to the history and persist state."""
        self._history.append(result)
        self._save()

    def get_current_prompt(self) -> str:
        """Return the best-scoring prompt, or SEED_SYSTEM_PROMPT when the front is empty."""
        best = self._best()
        return best.system_prompt if best is not None else SEED_SYSTEM_PROMPT

    def get_history(self) -> list[QueryResult]:
        """Return a shallow copy of the recorded history."""
        return list(self._history)

    def get_pareto_front(self) -> list[Candidate]:
        """Return a shallow copy of the front (the Candidate objects are shared)."""
        return list(self._pareto_front)

    def set_current_prompt(self, prompt: str) -> None:
        """Replace the best candidate's prompt text, then persist.

        The candidate's scores are kept unchanged (the new text is not
        re-benchmarked). If the front is empty, *prompt* is added as an
        unscored generation-0 candidate.
        """
        best = self._best()
        if best is not None:
            best.system_prompt = prompt
        else:
            self._pareto_front.append(self._seed_candidate(prompt))
        self._save()

    @property
    def current_generation(self) -> int:
        """Highest generation number on the front (0 when the front is empty)."""
        return max((c.generation for c in self._pareto_front), default=0)

    def should_optimize(self) -> bool:
        """Return True when the history length is a positive multiple of GEPA_OPTIMIZE_EVERY.

        A non-positive GEPA_OPTIMIZE_EVERY disables scheduled optimization
        instead of raising ZeroDivisionError.
        """
        if GEPA_OPTIMIZE_EVERY <= 0:
            return False
        count = len(self._history)
        return count > 0 and count % GEPA_OPTIMIZE_EVERY == 0

    def reset(self) -> None:
        """Clear the history, restore the front to the single seed candidate, and persist."""
        self._history.clear()
        self._pareto_front.clear()
        self._pareto_front.append(self._seed_candidate(SEED_SYSTEM_PROMPT))
        self._save()

    async def run_optimization_cycle(
        self,
        user_feedback_context: Optional[str] = None,
        dialect: str = "SQLite",
    ) -> Optional[dict]:
        """
        Run one GEPA cycle: reflect → mutate → score → update Pareto front.

        Args:
            user_feedback_context: optional conversation text appended to the
                reflection request.
            dialect: SQL dialect named in the reflection and mutation
                instructions.

        Returns:
            {"new_prompt": ..., "reflection": ...} after the front has been
            updated. Returns None when the history has fewer than 2 entries,
            when fewer than 2 of them are failures (attempts > 1 or
            success=False), or when the model returns an empty prompt.
        """
        if len(self._history) < 2:
            return None

        recent_failures = [
            h for h in self._history if h.attempts > 1 or not h.success
        ][-self._MAX_REFLECTED_FAILURES:]
        if len(recent_failures) < 2:
            return None

        current_best = self.get_current_prompt()

        # ── Step 1: Reflect ──────────────────────────────────────
        failure_summary = self._format_failures(recent_failures)
        user_ctx_block = (
            f"\n\nUser conversation:\n{user_feedback_context}"
            if user_feedback_context
            else ""
        )

        reflection = await _complete(
            f"You are an expert SQL prompt engineer analyzing why an LLM SQL agent is failing.\n"
            f"The target database is {dialect} — all rules must use {dialect} syntax.\n"
            "Your job: identify specific, recurring patterns in these failures and state EXACTLY "
            "what rules or knowledge the system prompt is missing.\n"
            "Be very specific — name the exact functions, syntax patterns, or schema reasoning gaps.\n"
            "Output a concise diagnosis (3-5 bullet points max).",
            f"Current system prompt:\n{current_best}\n\n"
            f"Recent failures:\n{failure_summary}{user_ctx_block}",
        )

        # ── Step 2: Mutate ───────────────────────────────────────
        current_generation = self.current_generation

        new_prompt = (
            await _complete(
                f"You are an expert prompt engineer. Improve a system prompt for a {dialect} SQL generation agent.\n"
                "Rules for the new prompt:\n"
                "- Keep it concise and actionable\n"
                f"- The target database is {dialect} — use ONLY {dialect} syntax and functions\n"
                "- Add specific rules that address the diagnosed failure patterns\n"
                "- Do NOT add generic fluff — every rule must be earned by a real failure\n"
                "- Output ONLY the improved system prompt text, nothing else",
                f"Current system prompt:\n{current_best}\n\n"
                f"Diagnosed failure patterns:\n{reflection}\n\n"
                "Write the improved system prompt:",
            )
        ).strip()
        # _complete returns "" when the model sends no content; an empty
        # prompt must never become a candidate.
        if not new_prompt:
            return None

        # ── Step 3: Score ────────────────────────────────────────
        benchmark_score = await self._score_prompt(new_prompt)

        current_avg_attempts = (
            sum(h.attempts for h in self._history) / len(self._history)
            if self._history
            else 3.0
        )

        new_candidate = Candidate(
            system_prompt=new_prompt,
            score=benchmark_score,
            # Heuristic estimate, not measured: assume half an attempt saved
            # versus the historical average, but never fewer than one attempt.
            avg_attempts=max(current_avg_attempts - 0.5, 1.0),
            success_rate=benchmark_score,
            generation=current_generation + 1,
            feedback=[reflection],
        )

        # ── Step 4: Update Pareto front ──────────────────────────
        self._pareto_front.append(new_candidate)
        # Stable sort: on equal scores, older candidates keep their place.
        self._pareto_front.sort(key=lambda c: c.score, reverse=True)
        del self._pareto_front[self._FRONT_SIZE:]

        self._save()
        return {"new_prompt": new_prompt, "reflection": reflection}

    async def _score_prompt(self, prompt: str) -> float:
        """
        Score *prompt* on the first three golden queries.

        Each query earns 1.0 if its SQL runs without error and returns at
        least ``expected_min_rows`` rows, 0.5 if it runs and returns some
        rows, and 0.0 otherwise (including any exception during the call,
        parsing or execution). Returns the mean score, or 0.3 if no golden
        queries are configured.
        """
        from env.database import execute_query, get_schema_info
        import re

        schema = get_schema_info()

        scores: list[float] = []
        # One client for all golden queries, closed when scoring ends.
        async with _make_client() as client:
            for gq in _GOLDEN_QUERIES[:3]:
                try:
                    resp = await client.chat.completions.create(
                        model=_MODEL,
                        messages=[
                            {"role": "system", "content": prompt},
                            {
                                "role": "user",
                                "content": (
                                    f"Schema:\n{schema}\n\n"
                                    f"Question: {gq['question']}\n\n"
                                    "Write a SQL query."
                                ),
                            },
                        ],
                        temperature=0.1,
                    )
                    sql = resp.choices[0].message.content or ""
                    # Strip markdown code fences and a trailing semicolon the model may add.
                    sql = re.sub(r"^```(?:sql)?\s*", "", sql.strip(), flags=re.IGNORECASE)
                    sql = re.sub(r"\s*```$", "", sql).strip().rstrip(";")

                    rows, error = execute_query(sql)
                    if error is None and len(rows) >= gq["expected_min_rows"]:
                        scores.append(1.0)
                    elif error is None and rows:
                        scores.append(0.5)
                    else:
                        scores.append(0.0)
                except Exception:
                    scores.append(0.0)

        return sum(scores) / len(scores) if scores else 0.3

    # ─── Persistence ─────────────────────────────────────────────

    def _save(self) -> None:
        """Persist the most recent history entries and the front to GEPA_PATH.

        Best-effort: failures are logged, never raised. The JSON is written
        to a temporary sibling file and then swapped in, so an interrupted
        write cannot leave a truncated state file behind.
        """
        try:
            GEPA_PATH.parent.mkdir(parents=True, exist_ok=True)
            data = {
                "history": [
                    r.model_dump() for r in self._history[-self._PERSISTED_HISTORY:]
                ],
                "pareto_front": [c.model_dump() for c in self._pareto_front],
            }
            tmp_path = GEPA_PATH.with_name(GEPA_PATH.name + ".tmp")
            tmp_path.write_text(json.dumps(data, default=str), encoding="utf-8")
            tmp_path.replace(GEPA_PATH)
        except Exception:
            logging.getLogger(__name__).warning(
                "Failed to save GEPA state to %s", GEPA_PATH, exc_info=True
            )

    def _load(self) -> None:
        """Restore the history and front from GEPA_PATH if the file exists.

        Both sections are parsed and validated before either is applied, so a
        corrupt file leaves the in-memory defaults untouched. An empty saved
        front keeps the current (seed) front. Errors are logged, not raised.
        """
        try:
            if not GEPA_PATH.exists():
                return
            data = json.loads(GEPA_PATH.read_text(encoding="utf-8"))
            history = [QueryResult(**r) for r in data.get("history", [])]
            loaded_front = [Candidate(**c) for c in data.get("pareto_front", [])]
        except Exception:
            logging.getLogger(__name__).warning(
                "Failed to load GEPA state from %s; using defaults", GEPA_PATH, exc_info=True
            )
            return
        self._history = history
        if loaded_front:
            self._pareto_front = loaded_front


# ─── Singleton ────────────────────────────────────────────────────

# Process-wide optimizer, created lazily by get_gepa().
_gepa_instance: Optional[GEPAOptimizer] = None


def get_gepa() -> GEPAOptimizer:
    """Return the shared GEPAOptimizer, constructing it on the first call.

    Construction loads any persisted state from disk. No lock is taken, so
    the lazy initialisation is not guarded against concurrent first calls.
    """
    global _gepa_instance
    if _gepa_instance is not None:
        return _gepa_instance
    _gepa_instance = GEPAOptimizer()
    return _gepa_instance